001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018 package org.apache.commons.math3.analysis.function;
019
020 import org.apache.commons.math3.analysis.DifferentiableUnivariateFunction;
021 import org.apache.commons.math3.analysis.FunctionUtils;
022 import org.apache.commons.math3.analysis.ParametricUnivariateFunction;
023 import org.apache.commons.math3.analysis.UnivariateFunction;
024 import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
025 import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
026 import org.apache.commons.math3.exception.DimensionMismatchException;
027 import org.apache.commons.math3.exception.NullArgumentException;
028 import org.apache.commons.math3.exception.OutOfRangeException;
029 import org.apache.commons.math3.util.FastMath;
030
031 /**
032 * <a href="http://en.wikipedia.org/wiki/Logit">
033 * Logit</a> function.
034 * It is the inverse of the {@link Sigmoid sigmoid} function.
035 *
036 * @since 3.0
037 * @version $Id: Logit.java 1391927 2012-09-30 00:03:30Z erans $
038 */
039 public class Logit implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
040 /** Lower bound. */
041 private final double lo;
042 /** Higher bound. */
043 private final double hi;
044
045 /**
046 * Usual logit function, where the lower bound is 0 and the higher
047 * bound is 1.
048 */
049 public Logit() {
050 this(0, 1);
051 }
052
053 /**
054 * Logit function.
055 *
056 * @param lo Lower bound of the function domain.
057 * @param hi Higher bound of the function domain.
058 */
059 public Logit(double lo,
060 double hi) {
061 this.lo = lo;
062 this.hi = hi;
063 }
064
065 /** {@inheritDoc} */
066 public double value(double x)
067 throws OutOfRangeException {
068 return value(x, lo, hi);
069 }
070
071 /** {@inheritDoc}
072 * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
073 */
074 @Deprecated
075 public UnivariateFunction derivative() {
076 return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
077 }
078
079 /**
080 * Parametric function where the input array contains the parameters of
081 * the logit function, ordered as follows:
082 * <ul>
083 * <li>Lower bound</li>
084 * <li>Higher bound</li>
085 * </ul>
086 */
087 public static class Parametric implements ParametricUnivariateFunction {
088 /**
089 * Computes the value of the logit at {@code x}.
090 *
091 * @param x Value for which the function must be computed.
092 * @param param Values of lower bound and higher bounds.
093 * @return the value of the function.
094 * @throws NullArgumentException if {@code param} is {@code null}.
095 * @throws DimensionMismatchException if the size of {@code param} is
096 * not 2.
097 */
098 public double value(double x, double ... param)
099 throws NullArgumentException,
100 DimensionMismatchException {
101 validateParameters(param);
102 return Logit.value(x, param[0], param[1]);
103 }
104
105 /**
106 * Computes the value of the gradient at {@code x}.
107 * The components of the gradient vector are the partial
108 * derivatives of the function with respect to each of the
109 * <em>parameters</em> (lower bound and higher bound).
110 *
111 * @param x Value at which the gradient must be computed.
112 * @param param Values for lower and higher bounds.
113 * @return the gradient vector at {@code x}.
114 * @throws NullArgumentException if {@code param} is {@code null}.
115 * @throws DimensionMismatchException if the size of {@code param} is
116 * not 2.
117 */
118 public double[] gradient(double x, double ... param)
119 throws NullArgumentException,
120 DimensionMismatchException {
121 validateParameters(param);
122
123 final double lo = param[0];
124 final double hi = param[1];
125
126 return new double[] { 1 / (lo - x), 1 / (hi - x) };
127 }
128
129 /**
130 * Validates parameters to ensure they are appropriate for the evaluation of
131 * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
132 * methods.
133 *
134 * @param param Values for lower and higher bounds.
135 * @throws NullArgumentException if {@code param} is {@code null}.
136 * @throws DimensionMismatchException if the size of {@code param} is
137 * not 2.
138 */
139 private void validateParameters(double[] param)
140 throws NullArgumentException,
141 DimensionMismatchException {
142 if (param == null) {
143 throw new NullArgumentException();
144 }
145 if (param.length != 2) {
146 throw new DimensionMismatchException(param.length, 2);
147 }
148 }
149 }
150
151 /**
152 * @param x Value at which to compute the logit.
153 * @param lo Lower bound.
154 * @param hi Higher bound.
155 * @return the value of the logit function at {@code x}.
156 * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}.
157 */
158 private static double value(double x,
159 double lo,
160 double hi)
161 throws OutOfRangeException {
162 if (x < lo || x > hi) {
163 throw new OutOfRangeException(x, lo, hi);
164 }
165 return FastMath.log((x - lo) / (hi - x));
166 }
167
168 /** {@inheritDoc}
169 * @since 3.1
170 * @exception OutOfRangeException if parameter is outside of function domain
171 */
172 public DerivativeStructure value(final DerivativeStructure t)
173 throws OutOfRangeException {
174 final double x = t.getValue();
175 if (x < lo || x > hi) {
176 throw new OutOfRangeException(x, lo, hi);
177 }
178 double[] f = new double[t.getOrder() + 1];
179
180 // function value
181 f[0] = FastMath.log((x - lo) / (hi - x));
182
183 if (Double.isInfinite(f[0])) {
184
185 if (f.length > 1) {
186 f[1] = Double.POSITIVE_INFINITY;
187 }
188 // fill the array with infinities
189 // (for x close to lo the signs will flip between -inf and +inf,
190 // for x close to hi the signs will always be +inf)
191 // this is probably overkill, since the call to compose at the end
192 // of the method will transform most infinities into NaN ...
193 for (int i = 2; i < f.length; ++i) {
194 f[i] = f[i - 2];
195 }
196
197 } else {
198
199 // function derivatives
200 final double invL = 1.0 / (x - lo);
201 double xL = invL;
202 final double invH = 1.0 / (hi - x);
203 double xH = invH;
204 for (int i = 1; i < f.length; ++i) {
205 f[i] = xL + xH;
206 xL *= -i * invL;
207 xH *= i * invH;
208 }
209 }
210
211 return t.compose(f);
212 }
213 }