001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.math3.distribution;
018
019 import java.util.List;
020 import java.util.ArrayList;
021 import org.apache.commons.math3.exception.DimensionMismatchException;
022 import org.apache.commons.math3.exception.NotPositiveException;
023 import org.apache.commons.math3.exception.MathArithmeticException;
024 import org.apache.commons.math3.exception.util.LocalizedFormats;
025 import org.apache.commons.math3.random.RandomGenerator;
026 import org.apache.commons.math3.random.Well19937c;
027 import org.apache.commons.math3.util.Pair;
028
029 /**
030 * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
031 * mixture model</a> distributions.
032 *
033 * @param <T> Type of the mixture components.
034 *
035 * @version $Id: MixtureMultivariateRealDistribution.java 1416643 2012-12-03 19:37:14Z tn $
036 * @since 3.1
037 */
038 public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
039 extends AbstractMultivariateRealDistribution {
040 /** Normalized weight of each mixture component. */
041 private final double[] weight;
042 /** Mixture components. */
043 private final List<T> distribution;
044
045 /**
046 * Creates a mixture model from a list of distributions and their
047 * associated weights.
048 *
049 * @param components List of (weight, distribution) pairs from which to sample.
050 */
051 public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
052 this(new Well19937c(), components);
053 }
054
055 /**
056 * Creates a mixture model from a list of distributions and their
057 * associated weights.
058 *
059 * @param rng Random number generator.
060 * @param components Distributions from which to sample.
061 * @throws NotPositiveException if any of the weights is negative.
062 * @throws DimensionMismatchException if not all components have the same
063 * number of variables.
064 */
065 public MixtureMultivariateRealDistribution(RandomGenerator rng,
066 List<Pair<Double, T>> components) {
067 super(rng, components.get(0).getSecond().getDimension());
068
069 final int numComp = components.size();
070 final int dim = getDimension();
071 double weightSum = 0;
072 for (int i = 0; i < numComp; i++) {
073 final Pair<Double, T> comp = components.get(i);
074 if (comp.getSecond().getDimension() != dim) {
075 throw new DimensionMismatchException(comp.getSecond().getDimension(), dim);
076 }
077 if (comp.getFirst() < 0) {
078 throw new NotPositiveException(comp.getFirst());
079 }
080 weightSum += comp.getFirst();
081 }
082
083 // Check for overflow.
084 if (Double.isInfinite(weightSum)) {
085 throw new MathArithmeticException(LocalizedFormats.OVERFLOW);
086 }
087
088 // Store each distribution and its normalized weight.
089 distribution = new ArrayList<T>();
090 weight = new double[numComp];
091 for (int i = 0; i < numComp; i++) {
092 final Pair<Double, T> comp = components.get(i);
093 weight[i] = comp.getFirst() / weightSum;
094 distribution.add(comp.getSecond());
095 }
096 }
097
098 /** {@inheritDoc} */
099 public double density(final double[] values) {
100 double p = 0;
101 for (int i = 0; i < weight.length; i++) {
102 p += weight[i] * distribution.get(i).density(values);
103 }
104 return p;
105 }
106
107 /** {@inheritDoc} */
108 public double[] sample() {
109 // Sampled values.
110 double[] vals = null;
111
112 // Determine which component to sample from.
113 final double randomValue = random.nextDouble();
114 double sum = 0;
115
116 for (int i = 0; i < weight.length; i++) {
117 sum += weight[i];
118 if (randomValue <= sum) {
119 // pick model i
120 vals = distribution.get(i).sample();
121 break;
122 }
123 }
124
125 if (vals == null) {
126 // This should never happen, but it ensures we won't return a null in
127 // case the loop above has some floating point inequality problem on
128 // the final iteration.
129 vals = distribution.get(weight.length - 1).sample();
130 }
131
132 return vals;
133 }
134
135 /** {@inheritDoc} */
136 public void reseedRandomGenerator(long seed) {
137 // Seed needs to be propagated to underlying components
138 // in order to maintain consistency between runs.
139 super.reseedRandomGenerator(seed);
140
141 for (int i = 0; i < distribution.size(); i++) {
142 // Make each component's seed different in order to avoid
143 // using the same sequence of random numbers.
144 distribution.get(i).reseedRandomGenerator(i + 1 + seed);
145 }
146 }
147
148 /**
149 * Gets the distributions that make up the mixture model.
150 *
151 * @return the component distributions and associated weights.
152 */
153 public List<Pair<Double, T>> getComponents() {
154 final List<Pair<Double, T>> list = new ArrayList<Pair<Double, T>>();
155
156 for (int i = 0; i < weight.length; i++) {
157 list.add(new Pair<Double, T>(weight[i], distribution.get(i)));
158 }
159
160 return list;
161 }
162 }