001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.apache.commons.math3.analysis.function;
019
020 import org.apache.commons.math3.analysis.FunctionUtils;
021 import org.apache.commons.math3.analysis.UnivariateFunction;
022 import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
023 import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
024 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
025 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
026 import org.apache.commons.math3.exception.NotStrictlyPositiveException;
027 import org.apache.commons.math3.exception.NullArgumentException;
028 import org.apache.commons.math3.exception.DimensionMismatchException;
029 import org.apache.commons.math3.util.FastMath;
030
031 /**
032 * <a href="http://en.wikipedia.org/wiki/Generalised_logistic_function">
033 * Generalised logistic</a> function.
034 *
035 * @since 3.0
036 * @version $Id: Logistic.java 1391927 2012-09-30 00:03:30Z erans $
037 */
038 public class Logistic implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
039 /** Lower asymptote. */
040 private final double a;
041 /** Upper asymptote. */
042 private final double k;
043 /** Growth rate. */
044 private final double b;
045 /** Parameter that affects near which asymptote maximum growth occurs. */
046 private final double oneOverN;
047 /** Parameter that affects the position of the curve along the ordinate axis. */
048 private final double q;
049 /** Abscissa of maximum growth. */
050 private final double m;
051
052 /**
053 * @param k If {@code b > 0}, value of the function for x going towards +∞.
054 * If {@code b < 0}, value of the function for x going towards -∞.
055 * @param m Abscissa of maximum growth.
056 * @param b Growth rate.
057 * @param q Parameter that affects the position of the curve along the
058 * ordinate axis.
059 * @param a If {@code b > 0}, value of the function for x going towards -∞.
060 * If {@code b < 0}, value of the function for x going towards +∞.
061 * @param n Parameter that affects near which asymptote the maximum
062 * growth occurs.
063 * @throws NotStrictlyPositiveException if {@code n <= 0}.
064 */
065 public Logistic(double k,
066 double m,
067 double b,
068 double q,
069 double a,
070 double n)
071 throws NotStrictlyPositiveException {
072 if (n <= 0) {
073 throw new NotStrictlyPositiveException(n);
074 }
075
076 this.k = k;
077 this.m = m;
078 this.b = b;
079 this.q = q;
080 this.a = a;
081 oneOverN = 1 / n;
082 }
083
084 /** {@inheritDoc} */
085 public double value(double x) {
086 return value(m - x, k, b, q, a, oneOverN);
087 }
088
089 /** {@inheritDoc}
090 * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
091 */
092 @Deprecated
093 public UnivariateFunction derivative() {
094 return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
095 }
096
097 /**
098 * Parametric function where the input array contains the parameters of
099 * the logit function, ordered as follows:
100 * <ul>
101 * <li>Lower asymptote</li>
102 * <li>Higher asymptote</li>
103 * </ul>
104 */
105 public static class Parametric implements ParametricUnivariateFunction {
106 /**
107 * Computes the value of the sigmoid at {@code x}.
108 *
109 * @param x Value for which the function must be computed.
110 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
111 * {@code a} and {@code n}.
112 * @return the value of the function.
113 * @throws NullArgumentException if {@code param} is {@code null}.
114 * @throws DimensionMismatchException if the size of {@code param} is
115 * not 6.
116 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
117 */
118 public double value(double x, double ... param)
119 throws NullArgumentException,
120 DimensionMismatchException,
121 NotStrictlyPositiveException {
122 validateParameters(param);
123 return Logistic.value(param[1] - x, param[0],
124 param[2], param[3],
125 param[4], 1 / param[5]);
126 }
127
128 /**
129 * Computes the value of the gradient at {@code x}.
130 * The components of the gradient vector are the partial
131 * derivatives of the function with respect to each of the
132 * <em>parameters</em>.
133 *
134 * @param x Value at which the gradient must be computed.
135 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
136 * {@code a} and {@code n}.
137 * @return the gradient vector at {@code x}.
138 * @throws NullArgumentException if {@code param} is {@code null}.
139 * @throws DimensionMismatchException if the size of {@code param} is
140 * not 6.
141 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
142 */
143 public double[] gradient(double x, double ... param)
144 throws NullArgumentException,
145 DimensionMismatchException,
146 NotStrictlyPositiveException {
147 validateParameters(param);
148
149 final double b = param[2];
150 final double q = param[3];
151
152 final double mMinusX = param[1] - x;
153 final double oneOverN = 1 / param[5];
154 final double exp = FastMath.exp(b * mMinusX);
155 final double qExp = q * exp;
156 final double qExp1 = qExp + 1;
157 final double factor1 = (param[0] - param[4]) * oneOverN / FastMath.pow(qExp1, oneOverN);
158 final double factor2 = -factor1 / qExp1;
159
160 // Components of the gradient.
161 final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN);
162 final double gm = factor2 * b * qExp;
163 final double gb = factor2 * mMinusX * qExp;
164 final double gq = factor2 * exp;
165 final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN);
166 final double gn = factor1 * Math.log(qExp1) * oneOverN;
167
168 return new double[] { gk, gm, gb, gq, ga, gn };
169 }
170
171 /**
172 * Validates parameters to ensure they are appropriate for the evaluation of
173 * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
174 * methods.
175 *
176 * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
177 * {@code a} and {@code n}.
178 * @throws NullArgumentException if {@code param} is {@code null}.
179 * @throws DimensionMismatchException if the size of {@code param} is
180 * not 6.
181 * @throws NotStrictlyPositiveException if {@code param[5] <= 0}.
182 */
183 private void validateParameters(double[] param)
184 throws NullArgumentException,
185 DimensionMismatchException,
186 NotStrictlyPositiveException {
187 if (param == null) {
188 throw new NullArgumentException();
189 }
190 if (param.length != 6) {
191 throw new DimensionMismatchException(param.length, 6);
192 }
193 if (param[5] <= 0) {
194 throw new NotStrictlyPositiveException(param[5]);
195 }
196 }
197 }
198
199 /**
200 * @param mMinusX {@code m - x}.
201 * @param k {@code k}.
202 * @param b {@code b}.
203 * @param q {@code q}.
204 * @param a {@code a}.
205 * @param oneOverN {@code 1 / n}.
206 * @return the value of the function.
207 */
208 private static double value(double mMinusX,
209 double k,
210 double b,
211 double q,
212 double a,
213 double oneOverN) {
214 return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * mMinusX), oneOverN);
215 }
216
217 /** {@inheritDoc}
218 * @since 3.1
219 */
220 public DerivativeStructure value(final DerivativeStructure t) {
221 return t.negate().add(m).multiply(b).exp().multiply(q).add(1).pow(oneOverN).reciprocal().multiply(k - a).add(a);
222 }
223
224 }