001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.apache.commons.math3.optimization.fitting;
019
020 import java.util.ArrayList;
021 import java.util.List;
022
023 import org.apache.commons.math3.analysis.DifferentiableMultivariateVectorFunction;
024 import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
025 import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
026 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
027 import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableVectorFunction;
028 import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
029 import org.apache.commons.math3.optimization.MultivariateDifferentiableVectorOptimizer;
030 import org.apache.commons.math3.optimization.PointVectorValuePair;
031
032 /** Fitter for parametric univariate real functions y = f(x).
033 * <br/>
034 * When a univariate real function y = f(x) does depend on some
035 * unknown parameters p<sub>0</sub>, p<sub>1</sub> ... p<sub>n-1</sub>,
036 * this class can be used to find these parameters. It does this
037 * by <em>fitting</em> the curve so it remains very close to a set of
038 * observed points (x<sub>0</sub>, y<sub>0</sub>), (x<sub>1</sub>,
039 * y<sub>1</sub>) ... (x<sub>k-1</sub>, y<sub>k-1</sub>). This fitting
040 * is done by finding the parameters values that minimizes the objective
041 * function ∑(y<sub>i</sub>-f(x<sub>i</sub>))<sup>2</sup>. This is
042 * really a least squares problem.
043 *
044 * @param <T> Function to use for the fit.
045 *
046 * @version $Id: CurveFitter.java 1422230 2012-12-15 12:11:13Z erans $
047 * @deprecated As of 3.1 (to be removed in 4.0).
048 * @since 2.0
049 */
050 @Deprecated
051 public class CurveFitter<T extends ParametricUnivariateFunction> {
052
053 /** Optimizer to use for the fitting.
054 * @deprecated as of 3.1 replaced by {@link #optimizer}
055 */
056 @Deprecated
057 private final DifferentiableMultivariateVectorOptimizer oldOptimizer;
058
059 /** Optimizer to use for the fitting. */
060 private final MultivariateDifferentiableVectorOptimizer optimizer;
061
062 /** Observed points. */
063 private final List<WeightedObservedPoint> observations;
064
065 /** Simple constructor.
066 * @param optimizer optimizer to use for the fitting
067 * @deprecated as of 3.1 replaced by {@link #CurveFitter(MultivariateDifferentiableVectorOptimizer)}
068 */
069 public CurveFitter(final DifferentiableMultivariateVectorOptimizer optimizer) {
070 this.oldOptimizer = optimizer;
071 this.optimizer = null;
072 observations = new ArrayList<WeightedObservedPoint>();
073 }
074
075 /** Simple constructor.
076 * @param optimizer optimizer to use for the fitting
077 * @since 3.1
078 */
079 public CurveFitter(final MultivariateDifferentiableVectorOptimizer optimizer) {
080 this.oldOptimizer = null;
081 this.optimizer = optimizer;
082 observations = new ArrayList<WeightedObservedPoint>();
083 }
084
085 /** Add an observed (x,y) point to the sample with unit weight.
086 * <p>Calling this method is equivalent to call
087 * {@code addObservedPoint(1.0, x, y)}.</p>
088 * @param x abscissa of the point
089 * @param y observed value of the point at x, after fitting we should
090 * have f(x) as close as possible to this value
091 * @see #addObservedPoint(double, double, double)
092 * @see #addObservedPoint(WeightedObservedPoint)
093 * @see #getObservations()
094 */
095 public void addObservedPoint(double x, double y) {
096 addObservedPoint(1.0, x, y);
097 }
098
099 /** Add an observed weighted (x,y) point to the sample.
100 * @param weight weight of the observed point in the fit
101 * @param x abscissa of the point
102 * @param y observed value of the point at x, after fitting we should
103 * have f(x) as close as possible to this value
104 * @see #addObservedPoint(double, double)
105 * @see #addObservedPoint(WeightedObservedPoint)
106 * @see #getObservations()
107 */
108 public void addObservedPoint(double weight, double x, double y) {
109 observations.add(new WeightedObservedPoint(weight, x, y));
110 }
111
112 /** Add an observed weighted (x,y) point to the sample.
113 * @param observed observed point to add
114 * @see #addObservedPoint(double, double)
115 * @see #addObservedPoint(double, double, double)
116 * @see #getObservations()
117 */
118 public void addObservedPoint(WeightedObservedPoint observed) {
119 observations.add(observed);
120 }
121
122 /** Get the observed points.
123 * @return observed points
124 * @see #addObservedPoint(double, double)
125 * @see #addObservedPoint(double, double, double)
126 * @see #addObservedPoint(WeightedObservedPoint)
127 */
128 public WeightedObservedPoint[] getObservations() {
129 return observations.toArray(new WeightedObservedPoint[observations.size()]);
130 }
131
132 /**
133 * Remove all observations.
134 */
135 public void clearObservations() {
136 observations.clear();
137 }
138
139 /**
140 * Fit a curve.
141 * This method compute the coefficients of the curve that best
142 * fit the sample of observed points previously given through calls
143 * to the {@link #addObservedPoint(WeightedObservedPoint)
144 * addObservedPoint} method.
145 *
146 * @param f parametric function to fit.
147 * @param initialGuess first guess of the function parameters.
148 * @return the fitted parameters.
149 * @throws org.apache.commons.math3.exception.DimensionMismatchException
150 * if the start point dimension is wrong.
151 */
152 public double[] fit(T f, final double[] initialGuess) {
153 return fit(Integer.MAX_VALUE, f, initialGuess);
154 }
155
156 /**
157 * Fit a curve.
158 * This method compute the coefficients of the curve that best
159 * fit the sample of observed points previously given through calls
160 * to the {@link #addObservedPoint(WeightedObservedPoint)
161 * addObservedPoint} method.
162 *
163 * @param f parametric function to fit.
164 * @param initialGuess first guess of the function parameters.
165 * @param maxEval Maximum number of function evaluations.
166 * @return the fitted parameters.
167 * @throws org.apache.commons.math3.exception.TooManyEvaluationsException
168 * if the number of allowed evaluations is exceeded.
169 * @throws org.apache.commons.math3.exception.DimensionMismatchException
170 * if the start point dimension is wrong.
171 * @since 3.0
172 */
173 public double[] fit(int maxEval, T f,
174 final double[] initialGuess) {
175 // prepare least squares problem
176 double[] target = new double[observations.size()];
177 double[] weights = new double[observations.size()];
178 int i = 0;
179 for (WeightedObservedPoint point : observations) {
180 target[i] = point.getY();
181 weights[i] = point.getWeight();
182 ++i;
183 }
184
185 // perform the fit
186 final PointVectorValuePair optimum;
187 if (optimizer == null) {
188 // to be removed in 4.0
189 optimum = oldOptimizer.optimize(maxEval, new OldTheoreticalValuesFunction(f),
190 target, weights, initialGuess);
191 } else {
192 optimum = optimizer.optimize(maxEval, new TheoreticalValuesFunction(f),
193 target, weights, initialGuess);
194 }
195
196 // extract the coefficients
197 return optimum.getPointRef();
198 }
199
200 /** Vectorial function computing function theoretical values. */
201 @Deprecated
202 private class OldTheoreticalValuesFunction
203 implements DifferentiableMultivariateVectorFunction {
204 /** Function to fit. */
205 private final ParametricUnivariateFunction f;
206
207 /** Simple constructor.
208 * @param f function to fit.
209 */
210 public OldTheoreticalValuesFunction(final ParametricUnivariateFunction f) {
211 this.f = f;
212 }
213
214 /** {@inheritDoc} */
215 public MultivariateMatrixFunction jacobian() {
216 return new MultivariateMatrixFunction() {
217 public double[][] value(double[] point) {
218 final double[][] jacobian = new double[observations.size()][];
219
220 int i = 0;
221 for (WeightedObservedPoint observed : observations) {
222 jacobian[i++] = f.gradient(observed.getX(), point);
223 }
224
225 return jacobian;
226 }
227 };
228 }
229
230 /** {@inheritDoc} */
231 public double[] value(double[] point) {
232 // compute the residuals
233 final double[] values = new double[observations.size()];
234 int i = 0;
235 for (WeightedObservedPoint observed : observations) {
236 values[i++] = f.value(observed.getX(), point);
237 }
238
239 return values;
240 }
241 }
242
243 /** Vectorial function computing function theoretical values. */
244 private class TheoreticalValuesFunction implements MultivariateDifferentiableVectorFunction {
245
246 /** Function to fit. */
247 private final ParametricUnivariateFunction f;
248
249 /** Simple constructor.
250 * @param f function to fit.
251 */
252 public TheoreticalValuesFunction(final ParametricUnivariateFunction f) {
253 this.f = f;
254 }
255
256 /** {@inheritDoc} */
257 public double[] value(double[] point) {
258 // compute the residuals
259 final double[] values = new double[observations.size()];
260 int i = 0;
261 for (WeightedObservedPoint observed : observations) {
262 values[i++] = f.value(observed.getX(), point);
263 }
264
265 return values;
266 }
267
268 /** {@inheritDoc} */
269 public DerivativeStructure[] value(DerivativeStructure[] point) {
270
271 // extract parameters
272 final double[] parameters = new double[point.length];
273 for (int k = 0; k < point.length; ++k) {
274 parameters[k] = point[k].getValue();
275 }
276
277 // compute the residuals
278 final DerivativeStructure[] values = new DerivativeStructure[observations.size()];
279 int i = 0;
280 for (WeightedObservedPoint observed : observations) {
281
282 // build the DerivativeStructure by adding first the value as a constant
283 // and then adding derivatives
284 DerivativeStructure vi = new DerivativeStructure(point.length, 1, f.value(observed.getX(), parameters));
285 for (int k = 0; k < point.length; ++k) {
286 vi = vi.add(new DerivativeStructure(point.length, 1, k, 0.0));
287 }
288
289 values[i++] = vi;
290
291 }
292
293 return values;
294 }
295
296 }
297
298 }