001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.math3.fitting;
018
019 import java.util.ArrayList;
020 import java.util.List;
021 import org.apache.commons.math3.analysis.MultivariateVectorFunction;
022 import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
023 import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
024 import org.apache.commons.math3.optim.MaxEval;
025 import org.apache.commons.math3.optim.InitialGuess;
026 import org.apache.commons.math3.optim.PointVectorValuePair;
027 import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer;
028 import org.apache.commons.math3.optim.nonlinear.vector.ModelFunction;
029 import org.apache.commons.math3.optim.nonlinear.vector.ModelFunctionJacobian;
030 import org.apache.commons.math3.optim.nonlinear.vector.Target;
031 import org.apache.commons.math3.optim.nonlinear.vector.Weight;
032
033 /**
034 * Fitter for parametric univariate real functions y = f(x).
035 * <br/>
036 * When a univariate real function y = f(x) does depend on some
037 * unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
038 * this class can be used to find these parameters. It does this
039 * by <em>fitting</em> the curve so it remains very close to a set of
040 * observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
041 * y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
042 * is done by finding the parameters values that minimizes the objective
043 * function ∑(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
044 * really a least squares problem.
045 *
046 * @param <T> Function to use for the fit.
047 *
048 * @version $Id: CurveFitter.java 1416643 2012-12-03 19:37:14Z tn $
049 * @since 2.0
050 */
051 public class CurveFitter<T extends ParametricUnivariateFunction> {
052 /** Optimizer to use for the fitting. */
053 private final MultivariateVectorOptimizer optimizer;
054 /** Observed points. */
055 private final List<WeightedObservedPoint> observations;
056
057 /**
058 * Simple constructor.
059 *
060 * @param optimizer Optimizer to use for the fitting.
061 * @since 3.1
062 */
063 public CurveFitter(final MultivariateVectorOptimizer optimizer) {
064 this.optimizer = optimizer;
065 observations = new ArrayList<WeightedObservedPoint>();
066 }
067
068 /** Add an observed (x,y) point to the sample with unit weight.
069 * <p>Calling this method is equivalent to call
070 * {@code addObservedPoint(1.0, x, y)}.</p>
071 * @param x abscissa of the point
072 * @param y observed value of the point at x, after fitting we should
073 * have f(x) as close as possible to this value
074 * @see #addObservedPoint(double, double, double)
075 * @see #addObservedPoint(WeightedObservedPoint)
076 * @see #getObservations()
077 */
078 public void addObservedPoint(double x, double y) {
079 addObservedPoint(1.0, x, y);
080 }
081
082 /** Add an observed weighted (x,y) point to the sample.
083 * @param weight weight of the observed point in the fit
084 * @param x abscissa of the point
085 * @param y observed value of the point at x, after fitting we should
086 * have f(x) as close as possible to this value
087 * @see #addObservedPoint(double, double)
088 * @see #addObservedPoint(WeightedObservedPoint)
089 * @see #getObservations()
090 */
091 public void addObservedPoint(double weight, double x, double y) {
092 observations.add(new WeightedObservedPoint(weight, x, y));
093 }
094
095 /** Add an observed weighted (x,y) point to the sample.
096 * @param observed observed point to add
097 * @see #addObservedPoint(double, double)
098 * @see #addObservedPoint(double, double, double)
099 * @see #getObservations()
100 */
101 public void addObservedPoint(WeightedObservedPoint observed) {
102 observations.add(observed);
103 }
104
105 /** Get the observed points.
106 * @return observed points
107 * @see #addObservedPoint(double, double)
108 * @see #addObservedPoint(double, double, double)
109 * @see #addObservedPoint(WeightedObservedPoint)
110 */
111 public WeightedObservedPoint[] getObservations() {
112 return observations.toArray(new WeightedObservedPoint[observations.size()]);
113 }
114
115 /**
116 * Remove all observations.
117 */
118 public void clearObservations() {
119 observations.clear();
120 }
121
122 /**
123 * Fit a curve.
124 * This method compute the coefficients of the curve that best
125 * fit the sample of observed points previously given through calls
126 * to the {@link #addObservedPoint(WeightedObservedPoint)
127 * addObservedPoint} method.
128 *
129 * @param f parametric function to fit.
130 * @param initialGuess first guess of the function parameters.
131 * @return the fitted parameters.
132 * @throws org.apache.commons.math3.exception.DimensionMismatchException
133 * if the start point dimension is wrong.
134 */
135 public double[] fit(T f, final double[] initialGuess) {
136 return fit(Integer.MAX_VALUE, f, initialGuess);
137 }
138
139 /**
140 * Fit a curve.
141 * This method compute the coefficients of the curve that best
142 * fit the sample of observed points previously given through calls
143 * to the {@link #addObservedPoint(WeightedObservedPoint)
144 * addObservedPoint} method.
145 *
146 * @param f parametric function to fit.
147 * @param initialGuess first guess of the function parameters.
148 * @param maxEval Maximum number of function evaluations.
149 * @return the fitted parameters.
150 * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
151 * if the number of allowed evaluations is exceeded.
152 * @throws org.apache.commons.math3.exception.DimensionMismatchException
153 * if the start point dimension is wrong.
154 * @since 3.0
155 */
156 public double[] fit(int maxEval, T f,
157 final double[] initialGuess) {
158 // Prepare least squares problem.
159 double[] target = new double[observations.size()];
160 double[] weights = new double[observations.size()];
161 int i = 0;
162 for (WeightedObservedPoint point : observations) {
163 target[i] = point.getY();
164 weights[i] = point.getWeight();
165 ++i;
166 }
167
168 // Input to the optimizer: the model and its Jacobian.
169 final TheoreticalValuesFunction model = new TheoreticalValuesFunction(f);
170
171 // Perform the fit.
172 final PointVectorValuePair optimum
173 = optimizer.optimize(new MaxEval(maxEval),
174 model.getModelFunction(),
175 model.getModelFunctionJacobian(),
176 new Target(target),
177 new Weight(weights),
178 new InitialGuess(initialGuess));
179 // Extract the coefficients.
180 return optimum.getPointRef();
181 }
182
183 /** Vectorial function computing function theoretical values. */
184 private class TheoreticalValuesFunction {
185 /** Function to fit. */
186 private final ParametricUnivariateFunction f;
187
188 /**
189 * @param f function to fit.
190 */
191 public TheoreticalValuesFunction(final ParametricUnivariateFunction f) {
192 this.f = f;
193 }
194
195 /**
196 * @return the model function values.
197 */
198 public ModelFunction getModelFunction() {
199 return new ModelFunction(new MultivariateVectorFunction() {
200 /** {@inheritDoc} */
201 public double[] value(double[] point) {
202 // compute the residuals
203 final double[] values = new double[observations.size()];
204 int i = 0;
205 for (WeightedObservedPoint observed : observations) {
206 values[i++] = f.value(observed.getX(), point);
207 }
208
209 return values;
210 }
211 });
212 }
213
214 /**
215 * @return the model function Jacobian.
216 */
217 public ModelFunctionJacobian getModelFunctionJacobian() {
218 return new ModelFunctionJacobian(new MultivariateMatrixFunction() {
219 public double[][] value(double[] point) {
220 final double[][] jacobian = new double[observations.size()][];
221 int i = 0;
222 for (WeightedObservedPoint observed : observations) {
223 jacobian[i++] = f.gradient(observed.getX(), point);
224 }
225 return jacobian;
226 }
227 });
228 }
229 }
230 }