001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.math3.stat.regression;
018
019 import org.apache.commons.math3.exception.MathIllegalArgumentException;
020 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
021 import org.apache.commons.math3.linear.LUDecomposition;
022 import org.apache.commons.math3.linear.QRDecomposition;
023 import org.apache.commons.math3.linear.RealMatrix;
024 import org.apache.commons.math3.linear.RealVector;
025 import org.apache.commons.math3.stat.StatUtils;
026 import org.apache.commons.math3.stat.descriptive.moment.SecondMoment;
027
028 /**
029 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
030 * multiple linear regression model.</p>
031 *
032 * <p>The regression coefficients, <code>b</code>, satisfy the normal equations:
033 * <pre><code> X<sup>T</sup> X b = X<sup>T</sup> y </code></pre></p>
034 *
035 * <p>To solve the normal equations, this implementation uses QR decomposition
036 * of the <code>X</code> matrix. (See {@link QRDecomposition} for details on the
037 * decomposition algorithm.) The <code>X</code> matrix, also known as the <i>design matrix,</i>
038 * has rows corresponding to sample observations and columns corresponding to independent
039 * variables. When the model is estimated using an intercept term (i.e. when
040 * {@link #isNoIntercept() isNoIntercept} is false as it is by default), the <code>X</code>
041 * matrix includes an initial column identically equal to 1. We solve the normal equations
042 * as follows:
043 * <pre><code> X<sup>T</sup>X b = X<sup>T</sup> y
044 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y
045 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y
046 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y
047 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y
048 * R b = Q<sup>T</sup> y </code></pre></p>
049 *
050 * <p>Given <code>Q</code> and <code>R</code>, the last equation is solved by back-substitution.</p>
051 *
052 * @version $Id: OLSMultipleLinearRegression.java 1416643 2012-12-03 19:37:14Z tn $
053 * @since 2.0
054 */
055 public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
056
057 /** Cached QR decomposition of X matrix */
058 private QRDecomposition qr = null;
059
060 /**
061 * Loads model x and y sample data, overriding any previous sample.
062 *
063 * Computes and caches QR decomposition of the X matrix.
064 * @param y the [n,1] array representing the y sample
065 * @param x the [n,k] array representing the x sample
066 * @throws MathIllegalArgumentException if the x and y array data are not
067 * compatible for the regression
068 */
069 public void newSampleData(double[] y, double[][] x) throws MathIllegalArgumentException {
070 validateSampleData(x, y);
071 newYSampleData(y);
072 newXSampleData(x);
073 }
074
075 /**
076 * {@inheritDoc}
077 * <p>This implementation computes and caches the QR decomposition of the X matrix.</p>
078 */
079 @Override
080 public void newSampleData(double[] data, int nobs, int nvars) {
081 super.newSampleData(data, nobs, nvars);
082 qr = new QRDecomposition(getX());
083 }
084
085 /**
086 * <p>Compute the "hat" matrix.
087 * </p>
088 * <p>The hat matrix is defined in terms of the design matrix X
089 * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
090 * </p>
091 * <p>The implementation here uses the QR decomposition to compute the
092 * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
093 * p-dimensional identity matrix augmented by 0's. This computational
094 * formula is from "The Hat Matrix in Regression and ANOVA",
095 * David C. Hoaglin and Roy E. Welsch,
096 * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
097 * </p>
098 * <p>Data for the model must have been successfully loaded using one of
099 * the {@code newSampleData} methods before invoking this method; otherwise
100 * a {@code NullPointerException} will be thrown.</p>
101 *
102 * @return the hat matrix
103 */
104 public RealMatrix calculateHat() {
105 // Create augmented identity matrix
106 RealMatrix Q = qr.getQ();
107 final int p = qr.getR().getColumnDimension();
108 final int n = Q.getColumnDimension();
109 // No try-catch or advertised NotStrictlyPositiveException - NPE above if n < 3
110 Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
111 double[][] augIData = augI.getDataRef();
112 for (int i = 0; i < n; i++) {
113 for (int j =0; j < n; j++) {
114 if (i == j && i < p) {
115 augIData[i][j] = 1d;
116 } else {
117 augIData[i][j] = 0d;
118 }
119 }
120 }
121
122 // Compute and return Hat matrix
123 // No DME advertised - args valid if we get here
124 return Q.multiply(augI).multiply(Q.transpose());
125 }
126
127 /**
128 * <p>Returns the sum of squared deviations of Y from its mean.</p>
129 *
130 * <p>If the model has no intercept term, <code>0</code> is used for the
131 * mean of Y - i.e., what is returned is the sum of the squared Y values.</p>
132 *
133 * <p>The value returned by this method is the SSTO value used in
134 * the {@link #calculateRSquared() R-squared} computation.</p>
135 *
136 * @return SSTO - the total sum of squares
137 * @throws MathIllegalArgumentException if the sample has not been set or does
138 * not contain at least 3 observations
139 * @see #isNoIntercept()
140 * @since 2.2
141 */
142 public double calculateTotalSumOfSquares() throws MathIllegalArgumentException {
143 if (isNoIntercept()) {
144 return StatUtils.sumSq(getY().toArray());
145 } else {
146 return new SecondMoment().evaluate(getY().toArray());
147 }
148 }
149
150 /**
151 * Returns the sum of squared residuals.
152 *
153 * @return residual sum of squares
154 * @since 2.2
155 */
156 public double calculateResidualSumOfSquares() {
157 final RealVector residuals = calculateResiduals();
158 // No advertised DME, args are valid
159 return residuals.dotProduct(residuals);
160 }
161
162 /**
163 * Returns the R-Squared statistic, defined by the formula <pre>
164 * R<sup>2</sup> = 1 - SSR / SSTO
165 * </pre>
166 * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}
167 * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}
168 *
169 * @return R-square statistic
170 * @throws MathIllegalArgumentException if the sample has not been set or does
171 * not contain at least 3 observations
172 * @since 2.2
173 */
174 public double calculateRSquared() throws MathIllegalArgumentException {
175 return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares();
176 }
177
178 /**
179 * <p>Returns the adjusted R-squared statistic, defined by the formula <pre>
180 * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)]
181 * </pre>
182 * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals},
183 * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number
184 * of observations and p is the number of parameters estimated (including the intercept).</p>
185 *
186 * <p>If the regression is estimated without an intercept term, what is returned is <pre>
187 * <code> 1 - (1 - {@link #calculateRSquared()}) * (n / (n - p)) </code>
188 * </pre></p>
189 *
190 * @return adjusted R-Squared statistic
191 * @throws MathIllegalArgumentException if the sample has not been set or does
192 * not contain at least 3 observations
193 * @see #isNoIntercept()
194 * @since 2.2
195 */
196 public double calculateAdjustedRSquared() throws MathIllegalArgumentException {
197 final double n = getX().getRowDimension();
198 if (isNoIntercept()) {
199 return 1 - (1 - calculateRSquared()) * (n / (n - getX().getColumnDimension()));
200 } else {
201 return 1 - (calculateResidualSumOfSquares() * (n - 1)) /
202 (calculateTotalSumOfSquares() * (n - getX().getColumnDimension()));
203 }
204 }
205
206 /**
207 * {@inheritDoc}
208 * <p>This implementation computes and caches the QR decomposition of the X matrix
209 * once it is successfully loaded.</p>
210 */
211 @Override
212 protected void newXSampleData(double[][] x) {
213 super.newXSampleData(x);
214 qr = new QRDecomposition(getX());
215 }
216
217 /**
218 * Calculates the regression coefficients using OLS.
219 *
220 * <p>Data for the model must have been successfully loaded using one of
221 * the {@code newSampleData} methods before invoking this method; otherwise
222 * a {@code NullPointerException} will be thrown.</p>
223 *
224 * @return beta
225 */
226 @Override
227 protected RealVector calculateBeta() {
228 return qr.getSolver().solve(getY());
229 }
230
231 /**
232 * <p>Calculates the variance-covariance matrix of the regression parameters.
233 * </p>
234 * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
235 * </p>
236 * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
237 * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
238 * R included, where p = the length of the beta vector.</p>
239 *
240 * <p>Data for the model must have been successfully loaded using one of
241 * the {@code newSampleData} methods before invoking this method; otherwise
242 * a {@code NullPointerException} will be thrown.</p>
243 *
244 * @return The beta variance-covariance matrix
245 */
246 @Override
247 protected RealMatrix calculateBetaVariance() {
248 int p = getX().getColumnDimension();
249 RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
250 RealMatrix Rinv = new LUDecomposition(Raug).getSolver().getInverse();
251 return Rinv.multiply(Rinv.transpose());
252 }
253
254 }