001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.math3.stat.regression;
018
019 import org.apache.commons.math3.exception.DimensionMismatchException;
020 import org.apache.commons.math3.exception.MathIllegalArgumentException;
021 import org.apache.commons.math3.exception.NoDataException;
022 import org.apache.commons.math3.exception.NullArgumentException;
023 import org.apache.commons.math3.exception.NumberIsTooSmallException;
024 import org.apache.commons.math3.exception.util.LocalizedFormats;
025 import org.apache.commons.math3.linear.NonSquareMatrixException;
026 import org.apache.commons.math3.linear.RealMatrix;
027 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
028 import org.apache.commons.math3.linear.RealVector;
029 import org.apache.commons.math3.linear.ArrayRealVector;
030 import org.apache.commons.math3.stat.descriptive.moment.Variance;
031 import org.apache.commons.math3.util.FastMath;
032
033 /**
034 * Abstract base class for implementations of MultipleLinearRegression.
035 * @version $Id: AbstractMultipleLinearRegression.java 1416643 2012-12-03 19:37:14Z tn $
036 * @since 2.0
037 */
038 public abstract class AbstractMultipleLinearRegression implements
039 MultipleLinearRegression {
040
041 /** X sample data. */
042 private RealMatrix xMatrix;
043
044 /** Y sample data. */
045 private RealVector yVector;
046
047 /** Whether or not the regression model includes an intercept. True means no intercept. */
048 private boolean noIntercept = false;
049
050 /**
051 * @return the X sample data.
052 */
053 protected RealMatrix getX() {
054 return xMatrix;
055 }
056
057 /**
058 * @return the Y sample data.
059 */
060 protected RealVector getY() {
061 return yVector;
062 }
063
064 /**
065 * @return true if the model has no intercept term; false otherwise
066 * @since 2.2
067 */
068 public boolean isNoIntercept() {
069 return noIntercept;
070 }
071
072 /**
073 * @param noIntercept true means the model is to be estimated without an intercept term
074 * @since 2.2
075 */
076 public void setNoIntercept(boolean noIntercept) {
077 this.noIntercept = noIntercept;
078 }
079
080 /**
081 * <p>Loads model x and y sample data from a flat input array, overriding any previous sample.
082 * </p>
083 * <p>Assumes that rows are concatenated with y values first in each row. For example, an input
084 * <code>data</code> array containing the sequence of values (1, 2, 3, 4, 5, 6, 7, 8, 9) with
085 * <code>nobs = 3</code> and <code>nvars = 2</code> creates a regression dataset with two
086 * independent variables, as below:
087 * <pre>
088 * y x[0] x[1]
089 * --------------
090 * 1 2 3
091 * 4 5 6
092 * 7 8 9
093 * </pre>
094 * </p>
095 * <p>Note that there is no need to add an initial unitary column (column of 1's) when
096 * specifying a model including an intercept term. If {@link #isNoIntercept()} is <code>true</code>,
097 * the X matrix will be created without an initial column of "1"s; otherwise this column will
098 * be added.
099 * </p>
100 * <p>Throws IllegalArgumentException if any of the following preconditions fail:
101 * <ul><li><code>data</code> cannot be null</li>
102 * <li><code>data.length = nobs * (nvars + 1)</li>
103 * <li><code>nobs > nvars</code></li></ul>
104 * </p>
105 *
106 * @param data input data array
107 * @param nobs number of observations (rows)
108 * @param nvars number of independent variables (columns, not counting y)
109 * @throws NullArgumentException if the data array is null
110 * @throws DimensionMismatchException if the length of the data array is not equal
111 * to <code>nobs * (nvars + 1)</code>
112 * @throws NumberIsTooSmallException if <code>nobs</code> is smaller than
113 * <code>nvars</code>
114 */
115 public void newSampleData(double[] data, int nobs, int nvars) {
116 if (data == null) {
117 throw new NullArgumentException();
118 }
119 if (data.length != nobs * (nvars + 1)) {
120 throw new DimensionMismatchException(data.length, nobs * (nvars + 1));
121 }
122 if (nobs <= nvars) {
123 throw new NumberIsTooSmallException(nobs, nvars, false);
124 }
125 double[] y = new double[nobs];
126 final int cols = noIntercept ? nvars: nvars + 1;
127 double[][] x = new double[nobs][cols];
128 int pointer = 0;
129 for (int i = 0; i < nobs; i++) {
130 y[i] = data[pointer++];
131 if (!noIntercept) {
132 x[i][0] = 1.0d;
133 }
134 for (int j = noIntercept ? 0 : 1; j < cols; j++) {
135 x[i][j] = data[pointer++];
136 }
137 }
138 this.xMatrix = new Array2DRowRealMatrix(x);
139 this.yVector = new ArrayRealVector(y);
140 }
141
142 /**
143 * Loads new y sample data, overriding any previous data.
144 *
145 * @param y the array representing the y sample
146 * @throws NullArgumentException if y is null
147 * @throws NoDataException if y is empty
148 */
149 protected void newYSampleData(double[] y) {
150 if (y == null) {
151 throw new NullArgumentException();
152 }
153 if (y.length == 0) {
154 throw new NoDataException();
155 }
156 this.yVector = new ArrayRealVector(y);
157 }
158
159 /**
160 * <p>Loads new x sample data, overriding any previous data.
161 * </p>
162 * The input <code>x</code> array should have one row for each sample
163 * observation, with columns corresponding to independent variables.
164 * For example, if <pre>
165 * <code> x = new double[][] {{1, 2}, {3, 4}, {5, 6}} </code></pre>
166 * then <code>setXSampleData(x) </code> results in a model with two independent
167 * variables and 3 observations:
168 * <pre>
169 * x[0] x[1]
170 * ----------
171 * 1 2
172 * 3 4
173 * 5 6
174 * </pre>
175 * </p>
176 * <p>Note that there is no need to add an initial unitary column (column of 1's) when
177 * specifying a model including an intercept term.
178 * </p>
179 * @param x the rectangular array representing the x sample
180 * @throws NullArgumentException if x is null
181 * @throws NoDataException if x is empty
182 * @throws DimensionMismatchException if x is not rectangular
183 */
184 protected void newXSampleData(double[][] x) {
185 if (x == null) {
186 throw new NullArgumentException();
187 }
188 if (x.length == 0) {
189 throw new NoDataException();
190 }
191 if (noIntercept) {
192 this.xMatrix = new Array2DRowRealMatrix(x, true);
193 } else { // Augment design matrix with initial unitary column
194 final int nVars = x[0].length;
195 final double[][] xAug = new double[x.length][nVars + 1];
196 for (int i = 0; i < x.length; i++) {
197 if (x[i].length != nVars) {
198 throw new DimensionMismatchException(x[i].length, nVars);
199 }
200 xAug[i][0] = 1.0d;
201 System.arraycopy(x[i], 0, xAug[i], 1, nVars);
202 }
203 this.xMatrix = new Array2DRowRealMatrix(xAug, false);
204 }
205 }
206
207 /**
208 * Validates sample data. Checks that
209 * <ul><li>Neither x nor y is null or empty;</li>
210 * <li>The length (i.e. number of rows) of x equals the length of y</li>
211 * <li>x has at least one more row than it has columns (i.e. there is
212 * sufficient data to estimate regression coefficients for each of the
213 * columns in x plus an intercept.</li>
214 * </ul>
215 *
216 * @param x the [n,k] array representing the x data
217 * @param y the [n,1] array representing the y data
218 * @throws NullArgumentException if {@code x} or {@code y} is null
219 * @throws DimensionMismatchException if {@code x} and {@code y} do not
220 * have the same length
221 * @throws NoDataException if {@code x} or {@code y} are zero-length
222 * @throws MathIllegalArgumentException if the number of rows of {@code x}
223 * is not larger than the number of columns + 1
224 */
225 protected void validateSampleData(double[][] x, double[] y) throws MathIllegalArgumentException {
226 if ((x == null) || (y == null)) {
227 throw new NullArgumentException();
228 }
229 if (x.length != y.length) {
230 throw new DimensionMismatchException(y.length, x.length);
231 }
232 if (x.length == 0) { // Must be no y data either
233 throw new NoDataException();
234 }
235 if (x[0].length + 1 > x.length) {
236 throw new MathIllegalArgumentException(
237 LocalizedFormats.NOT_ENOUGH_DATA_FOR_NUMBER_OF_PREDICTORS,
238 x.length, x[0].length);
239 }
240 }
241
242 /**
243 * Validates that the x data and covariance matrix have the same
244 * number of rows and that the covariance matrix is square.
245 *
246 * @param x the [n,k] array representing the x sample
247 * @param covariance the [n,n] array representing the covariance matrix
248 * @throws DimensionMismatchException if the number of rows in x is not equal
249 * to the number of rows in covariance
250 * @throws NonSquareMatrixException if the covariance matrix is not square
251 */
252 protected void validateCovarianceData(double[][] x, double[][] covariance) {
253 if (x.length != covariance.length) {
254 throw new DimensionMismatchException(x.length, covariance.length);
255 }
256 if (covariance.length > 0 && covariance.length != covariance[0].length) {
257 throw new NonSquareMatrixException(covariance.length, covariance[0].length);
258 }
259 }
260
261 /**
262 * {@inheritDoc}
263 */
264 public double[] estimateRegressionParameters() {
265 RealVector b = calculateBeta();
266 return b.toArray();
267 }
268
269 /**
270 * {@inheritDoc}
271 */
272 public double[] estimateResiduals() {
273 RealVector b = calculateBeta();
274 RealVector e = yVector.subtract(xMatrix.operate(b));
275 return e.toArray();
276 }
277
278 /**
279 * {@inheritDoc}
280 */
281 public double[][] estimateRegressionParametersVariance() {
282 return calculateBetaVariance().getData();
283 }
284
285 /**
286 * {@inheritDoc}
287 */
288 public double[] estimateRegressionParametersStandardErrors() {
289 double[][] betaVariance = estimateRegressionParametersVariance();
290 double sigma = calculateErrorVariance();
291 int length = betaVariance[0].length;
292 double[] result = new double[length];
293 for (int i = 0; i < length; i++) {
294 result[i] = FastMath.sqrt(sigma * betaVariance[i][i]);
295 }
296 return result;
297 }
298
299 /**
300 * {@inheritDoc}
301 */
302 public double estimateRegressandVariance() {
303 return calculateYVariance();
304 }
305
306 /**
307 * Estimates the variance of the error.
308 *
309 * @return estimate of the error variance
310 * @since 2.2
311 */
312 public double estimateErrorVariance() {
313 return calculateErrorVariance();
314
315 }
316
317 /**
318 * Estimates the standard error of the regression.
319 *
320 * @return regression standard error
321 * @since 2.2
322 */
323 public double estimateRegressionStandardError() {
324 return Math.sqrt(estimateErrorVariance());
325 }
326
327 /**
328 * Calculates the beta of multiple linear regression in matrix notation.
329 *
330 * @return beta
331 */
332 protected abstract RealVector calculateBeta();
333
334 /**
335 * Calculates the beta variance of multiple linear regression in matrix
336 * notation.
337 *
338 * @return beta variance
339 */
340 protected abstract RealMatrix calculateBetaVariance();
341
342
343 /**
344 * Calculates the variance of the y values.
345 *
346 * @return Y variance
347 */
348 protected double calculateYVariance() {
349 return new Variance().evaluate(yVector.toArray());
350 }
351
352 /**
353 * <p>Calculates the variance of the error term.</p>
354 * Uses the formula <pre>
355 * var(u) = u · u / (n - k)
356 * </pre>
357 * where n and k are the row and column dimensions of the design
358 * matrix X.
359 *
360 * @return error variance estimate
361 * @since 2.2
362 */
363 protected double calculateErrorVariance() {
364 RealVector residuals = calculateResiduals();
365 return residuals.dotProduct(residuals) /
366 (xMatrix.getRowDimension() - xMatrix.getColumnDimension());
367 }
368
369 /**
370 * Calculates the residuals of multiple linear regression in matrix
371 * notation.
372 *
373 * <pre>
374 * u = y - X * b
375 * </pre>
376 *
377 * @return The residuals [n,1] matrix
378 */
379 protected RealVector calculateResiduals() {
380 RealVector b = calculateBeta();
381 return yVector.subtract(xMatrix.operate(b));
382 }
383
384 }