package hex.glm;

import hex.DataInfo;
import hex.glm.GLMModel;
import hex.util.LinearAlgebraUtils;
import java.util.Arrays;
import water.Job;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.util.ArrayUtils;

/* loaded from: input_file:hex/glm/RegressionInfluenceDiagnosticsTasks.class */
public class RegressionInfluenceDiagnosticsTasks {

    /* loaded from: input_file:hex/glm/RegressionInfluenceDiagnosticsTasks$ComputeNewBetaVarEstimatedGaussian.class */
    /**
     * Per-row task that writes, for every input row, {@code _reducedBetaSize} coefficient values
     * followed by one additional value computed by {@code genVarEstimate} into the output chunks.
     *
     * <p>Rows flagged {@code response_bad} yield an all-NaN output row, rows whose weight is zero
     * yield an all-zero output row; every other row is processed by {@code genNewBetas} and
     * {@code genVarEstimate}.
     *
     * <p>NOTE(review): the arithmetic has the shape of a "refit with this one row removed"
     * computation, assuming {@code _cholInv} is the inverse of the gram matrix {@code _xTx}
     * - confirm against the caller.
     */
    public static class ComputeNewBetaVarEstimatedGaussian extends MRTask<ComputeNewBetaVarEstimatedGaussian> {
        /** Square matrix over the non-redundant coefficients; its length defines {@code _reducedBetaSize}. */
        final double[][] _cholInv;
        /** Full-length copy of {@code _xTransYReduced}; slots whose {@code _stdErr} entry is NaN stay 0. */
        final double[] _xTransY;
        /** Vector supplied by the caller, indexed over the non-redundant coefficients. */
        final double[] _xTransYReduced;
        /** Length of {@code _stdErr}. */
        final int _betaSize;
        /** Length of {@code _cholInv}. */
        final int _reducedBetaSize;
        /** {@code _betaSize + 1}; not read inside this class. */
        final int _newChunkWidth;
        /** Optional job used for stop requests and progress updates; may be null. */
        final Job _j;
        final DataInfo _dinfo;
        /** Matrix multiplied with the new coefficients inside {@code genVarEstimate}. */
        final double[][] _xTx;
        /** Constructor argument {@code nobs} minus {@code _reducedBetaSize}. */
        final double _weightedNobs;
        final double _sumRespSq;
        /** True when {@code _betaSize} and {@code _reducedBetaSize} differ. */
        final boolean _foundRedCols;
        /** One entry per full coefficient; NaN marks an entry that is skipped when mapping reduced to full. */
        final double[] _stdErr;

        /**
         * Stores the inputs and builds the full-length {@code _xTransY} vector.
         *
         * @param cholInv        square matrix over the non-redundant coefficients
         * @param xTransYReduced vector aligned with {@code cholInv}
         * @param job            owning job, or null
         * @param dinfo          data layout used to extract dense rows
         * @param xTx            matrix used by the variance-style estimate
         * @param nobs           count from which {@code cholInv.length} is subtracted
         * @param sumRespSq      leading term of the variance-style estimate
         * @param stdErr         per-coefficient values; NaN entries mark skipped coefficients
         */
        public ComputeNewBetaVarEstimatedGaussian(double[][] cholInv, double[] xTransYReduced, Job job, DataInfo dinfo, double[][] xTx, double nobs, double sumRespSq, double[] stdErr) {
            _cholInv = cholInv;
            _xTransYReduced = xTransYReduced;
            _betaSize = stdErr.length;
            _reducedBetaSize = cholInv.length;
            _foundRedCols = _betaSize != _reducedBetaSize;
            _newChunkWidth = _betaSize + 1;
            _j = job;
            _dinfo = dinfo;
            _xTx = xTx;
            _weightedNobs = nobs - _reducedBetaSize;
            _sumRespSq = sumRespSq;
            _stdErr = stdErr;
            _xTransY = new double[_betaSize];
            if (_foundRedCols) {
                // Scatter the reduced vector into the full one, skipping NaN-stdErr slots.
                int reducedIdx = 0;
                for (int fullIdx = 0; fullIdx < _betaSize; fullIdx++) {
                    if (!Double.isNaN(stdErr[fullIdx])) {
                        _xTransY[fullIdx] = _xTransYReduced[reducedIdx++];
                    }
                }
            } else {
                System.arraycopy(_xTransYReduced, 0, _xTransY, 0, _reducedBetaSize);
            }
        }

        /**
         * Processes every row of the chunk, appending one output row per input row, then reports
         * one unit of progress to the job (when a job is present).
         * Does nothing when the task is cancelled or the job asked to stop.
         */
        public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
            if (isCancelled() || (_j != null && _j.stop_requested())) {
                return;
            }
            // Scratch buffers are allocated once per chunk and reused for every row.
            final double[] newBeta = new double[_betaSize];
            final double[] newBetaReduced = new double[_reducedBetaSize];
            final double[] rowValues = new double[_betaSize];
            final double[] rowValuesReduced = new double[_reducedBetaSize];
            final double[][] scratchMatrix = new double[_reducedBetaSize][_reducedBetaSize];
            final double[] scratch = new double[_betaSize];
            final double[] scratchReduced = new double[_reducedBetaSize];
            final int rowCount = chunkArr[0]._len;
            final DataInfo.Row row = _dinfo.newDenseRow();
            for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) {
                _dinfo.extractDenseRow(chunkArr, rowIdx, row);
                getNewBetaVarEstimate(row, newChunkArr, rowValues, rowValuesReduced, newBeta, newBetaReduced, scratch, scratchReduced, scratchMatrix);
            }
            if (_j != null) {
                _j.update(1L);
            }
        }

        /**
         * Computes and writes the output row for a single input row.
         * The reduced buffers are used when redundant columns were found, the full ones otherwise.
         */
        private void getNewBetaVarEstimate(DataInfo.Row row, NewChunk[] newChunkArr, double[] rowValues, double[] rowValuesReduced, double[] newBeta, double[] newBetaReduced, double[] scratch, double[] scratchReduced, double[][] scratchMatrix) {
            final double[] emitted = _foundRedCols ? newBetaReduced : newBeta;
            if (row.response_bad || row.weight == 0.0d) {
                // Bad response takes precedence and yields NaN; a zero-weight row yields zeros.
                final double filler = row.response_bad ? Double.NaN : 0.0d;
                Arrays.fill(emitted, filler);
                writeNewChunk(emitted, newChunkArr, filler);
                return;
            }
            row.expandCatsPredsOnly(rowValues);
            final double[] predictors;
            final double[] work;
            if (_foundRedCols) {
                GLMUtils.removeRedCols(rowValues, rowValuesReduced, _stdErr);
                predictors = rowValuesReduced;
                work = scratchReduced;
            } else {
                predictors = rowValues;
                work = scratch;
            }
            ArrayUtils.outerProduct(scratchMatrix, predictors, predictors);
            // _cholInv * (x x') * _cholInv
            final double[][] sandwich = LinearAlgebraUtils.matrixMultiply(LinearAlgebraUtils.matrixMultiply(_cholInv, scratchMatrix), _cholInv);
            genNewBetas(predictors, work, emitted, row, sandwich);
            if (_foundRedCols) {
                // The residual below needs coefficients in full layout.
                fillBetaRed2Full(newBetaReduced, newBeta);
            }
            final double varEstimate = genVarEstimate(row, work, emitted, newBeta);
            writeNewChunk(emitted, newChunkArr, varEstimate);
        }

        /**
         * Expands reduced coefficients into full layout: positions whose {@code _stdErr} entry is
         * NaN receive 0, the remaining positions receive the reduced values in order.
         */
        private void fillBetaRed2Full(double[] reduced, double[] full) {
            int reducedIdx = 0;
            for (int fullIdx = 0; fullIdx < _betaSize; fullIdx++) {
                if (Double.isNaN(_stdErr[fullIdx])) {
                    full[fullIdx] = 0.0d;
                } else {
                    full[fullIdx] = reduced[reducedIdx++];
                }
            }
        }

        /**
         * Fills {@code newBeta} for one row.
         * Scales {@code sandwich} by {@code 1 / (1 - x' * _cholInv * x)}, adds {@code _cholInv} to it,
         * and multiplies the result with {@code -y * x + _xTransYReduced}.
         * NOTE(review): this matches a Sherman-Morrison style rank-one correction if
         * {@code _cholInv} is an inverse gram matrix - confirm.
         * The {@code ArrayUtils.mult}/{@code add} calls presumably work in place, so
         * {@code predictors} and {@code sandwich} should be treated as clobbered afterwards.
         */
        private void genNewBetas(double[] predictors, double[] work, double[] newBeta, DataInfo.Row row, double[][] sandwich) {
            ArrayUtils.multArrVec(_cholInv, predictors, work);
            final double scale = 1.0d / (1.0d - ArrayUtils.innerProduct(predictors, work));
            ArrayUtils.mult(sandwich, scale);
            ArrayUtils.add(sandwich, _cholInv);
            final double[] rightHandSide = ArrayUtils.mult(predictors, -row.response(0));
            ArrayUtils.add(rightHandSide, _xTransYReduced);
            ArrayUtils.multArrVec(sandwich, rightHandSide, newBeta);
        }

        /** Appends the first {@code _reducedBetaSize} values and then {@code lastValue} to the output chunks. */
        private void writeNewChunk(double[] values, NewChunk[] newChunkArr, double lastValue) {
            int col = 0;
            while (col < _reducedBetaSize) {
                newChunkArr[col].addNum(values[col]);
                col++;
            }
            newChunkArr[_reducedBetaSize].addNum(lastValue);
        }

        /**
         * Returns
         * {@code (_sumRespSq - 2 * b'xTy + b' * _xTx * b - w * r^2) / (_weightedNobs - w)}
         * where {@code b} is {@code beta}, {@code xTy} is {@code _xTransYReduced}, {@code w} is the
         * row weight and {@code r} is the row response minus {@code row.innerProduct(betaFull)}.
         * {@code work} is overwritten with {@code _xTx * beta}.
         */
        private double genVarEstimate(DataInfo.Row row, double[] work, double[] beta, double[] betaFull) {
            final double residual = row.response(0) - row.innerProduct(betaFull);
            final double weightedSqResidual = row.weight * residual * residual;
            ArrayUtils.multArrVec(_xTx, beta, work);
            final double crossTerm = 2.0d * ArrayUtils.innerProduct(beta, _xTransYReduced);
            final double quadraticTerm = ArrayUtils.innerProduct(beta, work);
            final double numerator = ((_sumRespSq - crossTerm) + quadraticTerm) - weightedSqResidual;
            return numerator / (_weightedNobs - row.weight);
        }
    }

    /* loaded from: input_file:hex/glm/RegressionInfluenceDiagnosticsTasks$RegressionInfluenceDiagBinomial.class */
    /**
     * Per-row task that writes {@code _reducedBetaSize} values per input row into the output
     * chunks, one per coefficient whose {@code _stdErr} entry is not NaN.
     *
     * <p>Each value is
     * {@code weight * residual / (1 - weight * mu * (1 - mu) * x' G x) * (G x)[k] / stdErr[k]}
     * with {@code G = _gramInv}, {@code mu = linkInv(row . beta + offset)} and
     * {@code residual = response - mu}.
     * NOTE(review): this looks like a DFBETAS-style influence measure - confirm with the caller.
     *
     * <p>Rows flagged {@code response_bad} produce NaN in every output column and rows with zero
     * weight produce 0, so the output always holds exactly one entry per input row.
     */
    public static class RegressionInfluenceDiagBinomial extends MRTask<RegressionInfluenceDiagBinomial> {
        /** Coefficients used to compute the linear predictor of a row. */
        final double[] _beta;
        /** Square matrix over the non-redundant coefficients; its length defines {@code _reducedBetaSize}. */
        final double[][] _gramInv;
        /** Optional job used for stop requests and progress updates; may be null. */
        final Job _j;
        /** Length of {@code _beta}. */
        final int _betaSize;
        /** Length of {@code _gramInv}. */
        final int _reducedBetaSize;
        final GLMModel.GLMParameters _parms;
        final DataInfo _dinfo;
        /** One entry per full coefficient; NaN marks a coefficient that is left out of the output. */
        final double[] _stdErr;
        /** True when {@code _betaSize} and {@code _reducedBetaSize} differ. */
        final boolean _foundRedCols;
        /** Element-wise reciprocal of {@code _stdErr}. */
        final double[] _oneOverStdErr;

        /**
         * @param job     owning job, or null
         * @param beta    coefficients in full layout
         * @param gramInv square matrix over the non-redundant coefficients
         * @param parms   model parameters providing {@code linkInv}
         * @param dinfo   data layout used to extract dense rows
         * @param stdErr  per-coefficient values in full layout; NaN marks a skipped coefficient
         */
        public RegressionInfluenceDiagBinomial(Job job, double[] beta, double[][] gramInv, GLMModel.GLMParameters parms, DataInfo dinfo, double[] stdErr) {
            _j = job;
            _beta = beta;
            _betaSize = beta.length;
            _reducedBetaSize = gramInv.length;
            _foundRedCols = _betaSize != _reducedBetaSize;
            _gramInv = gramInv;
            _parms = parms;
            _dinfo = dinfo;
            _stdErr = stdErr;
            _oneOverStdErr = Arrays.stream(_stdErr).map(se -> 1.0d / se).toArray();
        }

        /**
         * Processes every row of the chunk, appending one output row per input row, then reports
         * one unit of progress to the job (when a job is present).
         * Does nothing when the task is cancelled or the job asked to stop.
         */
        public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
            if (isCancelled() || (_j != null && _j.stop_requested())) {
                return;
            }
            // Scratch buffers are allocated once per chunk and reused for every row.
            final double[] dfbetas = new double[_betaSize];
            final double[] dfbetasReduced = new double[_reducedBetaSize];
            final double[] rowValues = new double[_betaSize];
            final double[] rowValuesReduced = new double[_reducedBetaSize];
            final double[] xTimesGramInv = new double[_reducedBetaSize];
            final DataInfo.Row row = _dinfo.newDenseRow();
            for (int rowIdx = 0; rowIdx < chunkArr[0]._len; rowIdx++) {
                _dinfo.extractDenseRow(chunkArr, rowIdx, row);
                genDfBetasRow(row, newChunkArr, rowValues, rowValuesReduced, dfbetas, dfbetasReduced, xTimesGramInv);
            }
            if (_j != null) {
                _j.update(1L);
            }
        }

        /**
         * Emits the output row for one input row.
         * The reduced buffers are used when redundant columns were found, the full ones otherwise.
         */
        private void genDfBetasRow(DataInfo.Row row, NewChunk[] newChunkArr, double[] rowValues, double[] rowValuesReduced, double[] dfbetas, double[] dfbetasReduced, double[] xTimesGramInv) {
            // Skipped rows must still emit a value per column; previously they only filled a
            // scratch buffer and returned, leaving the output shorter than the input.
            if (row.response_bad) {
                writeConstantRow(newChunkArr, Double.NaN);
                return;
            }
            if (row.weight == 0.0d) {
                writeConstantRow(newChunkArr, 0.0d);
                return;
            }
            row.expandCatsPredsOnly(rowValues);
            if (_foundRedCols) {
                GLMUtils.removeRedCols(rowValues, rowValuesReduced, _stdErr);
                genDfBeta(row, rowValuesReduced, xTimesGramInv, dfbetasReduced, newChunkArr);
            } else {
                genDfBeta(row, rowValues, xTimesGramInv, dfbetas, newChunkArr);
            }
        }

        /** Appends {@code value} to each of the first {@code _reducedBetaSize} output chunks. */
        private void writeConstantRow(NewChunk[] newChunkArr, double value) {
            for (int col = 0; col < _reducedBetaSize; col++) {
                newChunkArr[col].addNum(value);
            }
        }

        /**
         * Computes the per-coefficient values for a row with a usable response and non-zero weight
         * and appends them to the output chunks.
         */
        private void genDfBeta(DataInfo.Row row, double[] predictors, double[] xTimesGramInv, double[] dfbetas, NewChunk[] newChunkArr) {
            final double mu = _parms.linkInv(row.innerProduct(_beta) + row.offset);
            final double oneOverMLL = gen1OverMLL(predictors, xTimesGramInv, mu, row.weight);
            final double residual = row.response(0) - mu;
            genDfBetas(oneOverMLL, residual, predictors, dfbetas, row.weight);
            for (int col = 0; col < _reducedBetaSize; col++) {
                newChunkArr[col].addNum(dfbetas[col]);
            }
        }

        /**
         * Fills the leading entries of {@code dfbetas}, one per coefficient whose {@code _stdErr}
         * entry is not NaN, with
         * {@code oneOverMLL * residual * weight / stdErr[k] * (predictors . _gramInv[j])},
         * where {@code k} indexes the full layout and {@code j} the reduced layout.
         *
         * @param oneOverMLL value returned by {@link #gen1OverMLL}
         * @param residual   response minus fitted mean
         * @param predictors row predictors in the same layout as {@code _gramInv}
         * @param dfbetas    output buffer; only the first non-NaN-count entries are written
         * @param weight     row weight
         */
        public void genDfBetas(double oneOverMLL, double residual, double[] predictors, double[] dfbetas, double weight) {
            final double rowScale = oneOverMLL * residual * weight;
            int reducedIdx = 0;
            for (int fullIdx = 0; fullIdx < _betaSize; fullIdx++) {
                if (!Double.isNaN(_stdErr[fullIdx])) {
                    dfbetas[reducedIdx] = rowScale * _oneOverStdErr[fullIdx] * ArrayUtils.innerProduct(predictors, _gramInv[reducedIdx]);
                    reducedIdx++;
                }
            }
        }

        /**
         * Stores {@code _gramInv * predictors} in {@code xTimesGramInv} and returns
         * {@code 1 / (1 - weight * mu * (1 - mu) * (predictors . xTimesGramInv))}.
         *
         * @param predictors    row predictors in the same layout as {@code _gramInv}
         * @param xTimesGramInv output buffer of length {@code _reducedBetaSize}
         * @param mu            fitted mean of the row
         * @param weight        row weight
         * @return the reciprocal described above
         */
        public double gen1OverMLL(double[] predictors, double[] xTimesGramInv, double mu, double weight) {
            for (int k = 0; k < _reducedBetaSize; k++) {
                xTimesGramInv[k] = ArrayUtils.innerProduct(predictors, _gramInv[k]);
            }
            return 1.0d / (1.0d - (((weight * mu) * (1.0d - mu)) * ArrayUtils.innerProduct(xTimesGramInv, predictors)));
        }
    }

    /* loaded from: input_file:hex/glm/RegressionInfluenceDiagnosticsTasks$RegressionInfluenceDiagGaussian.class */
    /**
     * Per-row task that converts each input row into {@code _betaSize} scaled coefficient
     * differences.
     *
     * <p>For an input row {@code v}, output column {@code k} receives
     * {@code (_betas[k] - v[k]) / sqrt(v[_betaSize]) * _oneOverSqrtXTXDiag[k]}.
     * When {@code v[0]} is not finite the whole output row is NaN.
     * NOTE(review): the input presumably holds per-row coefficients followed by a variance-style
     * estimate - confirm with the caller.
     */
    public static class RegressionInfluenceDiagGaussian extends MRTask<RegressionInfluenceDiagGaussian> {
        /** {@code 1 / sqrt(d)} for each diagonal entry {@code d} of the matrix passed to the constructor. */
        final double[] _oneOverSqrtXTXDiag;
        /** Reference coefficients each row is compared against. */
        final double[] _betas;
        /** Length of {@code _betas}. */
        final int _betaSize;
        /** Optional job polled for stop requests; may be null. */
        final Job _j;

        /**
         * @param xTx   matrix whose first {@code betas.length} diagonal entries are used for scaling
         * @param betas reference coefficients
         * @param job   owning job, or null
         */
        public RegressionInfluenceDiagGaussian(double[][] xTx, double[] betas, Job job) {
            _betas = betas;
            _betaSize = betas.length;
            _j = job;
            _oneOverSqrtXTXDiag = new double[_betaSize];
            int k = 0;
            while (k < _betaSize) {
                _oneOverSqrtXTXDiag[k] = 1.0d / Math.sqrt(xTx[k][k]);
                k++;
            }
        }

        /**
         * Reads every row of the input chunks and appends the corresponding output row.
         * Does nothing when the task is cancelled or the job asked to stop.
         */
        public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
            if (isCancelled() || (_j != null && _j.stop_requested())) {
                return;
            }
            // Buffers are allocated once per chunk and reused for every row.
            final double[] betaDiff = new double[_betaSize];
            final int columnCount = chunkArr.length;
            final double[] rowValues = new double[columnCount];
            final int rowCount = chunkArr[0]._len;
            for (int rowIdx = 0; rowIdx < rowCount; rowIdx++) {
                readRow2Array(rowValues, chunkArr, rowIdx, columnCount);
                setBetaDiff(betaDiff, rowValues, newChunkArr);
            }
        }

        /** Computes the scaled differences for one row into {@code betaDiff} and appends them to the output. */
        private void setBetaDiff(double[] betaDiff, double[] rowValues, NewChunk[] newChunkArr) {
            if (!Double.isFinite(rowValues[0])) {
                Arrays.fill(betaDiff, Double.NaN);
            } else {
                final double oneOverSqrtLast = 1.0d / Math.sqrt(rowValues[_betaSize]);
                for (int k = 0; k < _betaSize; k++) {
                    betaDiff[k] = (_betas[k] - rowValues[k]) * oneOverSqrtLast * _oneOverSqrtXTXDiag[k];
                }
            }
            for (int col = 0; col < _betaSize; col++) {
                newChunkArr[col].addNum(betaDiff[col]);
            }
        }

        /** Copies the values of row {@code rowIdx} from the first {@code columnCount} chunks into {@code target}. */
        private void readRow2Array(double[] target, Chunk[] chunkArr, int rowIdx, int columnCount) {
            int col = 0;
            while (col < columnCount) {
                target[col] = chunkArr[col].atd(rowIdx);
                col++;
            }
        }
    }
}
