package water.rapids;

import java.util.Arrays;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.util.ArrayUtils;

/* loaded from: input_file:water/rapids/ASTVariance.class */
/**
 * Rapids primitive {@code var}: sample covariance between the columns of {@code x} and the columns
 * of {@code y} (the variance when both refer to the same column).
 *
 * <p>Every result is divided by {@code n - 1}, where {@code n} is the number of rows used. The
 * {@code use} argument selects how missing values (NaN) are handled; the accepted strings mirror a
 * subset of the {@code use} options of R's {@code cov}:
 * <ul>
 *   <li>{@code "everything"}: NAs propagate and produce NaN results;</li>
 *   <li>{@code "all.obs"}: any NA raises {@link IllegalArgumentException};</li>
 *   <li>{@code "complete.obs"}: rows with an NA in any participating column are dropped.</li>
 * </ul>
 *
 * <p>When {@code y} has exactly one row, both frames are treated as vectors laid out across their
 * columns and a single number is returned. Otherwise the result is a frame with one column per
 * column of {@code y} (named after {@code y}'s columns), except that single-column cases collapse
 * to a number.
 */
class ASTVariance extends ASTPrim {

    /**
     * Second pass of the non-symmetric "complete.obs" covariance.
     *
     * <p>The input frame is laid out as {@code [y_0 .. y_{ny-1}, x_0 .. x_{nx-1}]}. Rows with a NaN
     * in any of those columns are skipped. After the task, {@code _covs[y][x]} holds the sum over
     * the kept rows of {@code (x - xmean) * (y - ymean)}.
     */
    public static class CoVarTaskCompleteObs extends MRTask<CoVarTaskCompleteObs> {
        double[][] _covs;
        final double[] _xmeans;
        final double[] _ymeans;

        /**
         * @param ymeans per-column means of y, used to center y values
         * @param xmeans per-column means of x, used to center x values
         */
        CoVarTaskCompleteObs(double[] ymeans, double[] xmeans) {
            this._ymeans = ymeans;
            this._xmeans = xmeans;
        }

        @Override // water.MRTask
        public void map(Chunk[] cs) {
            int ncolx = _xmeans.length;
            int ncoly = _ymeans.length;
            double[] xvals = new double[ncolx];
            double[] yvals = new double[ncoly];
            _covs = new double[ncoly][ncolx];
            int len = cs[0]._len;
            for (int row = 0; row < len; row++) {
                // y columns come first; x columns start at offset ncoly.
                if (!readCompleteRow(cs, 0, row, yvals) || !readCompleteRow(cs, ncoly, row, xvals)) {
                    continue;
                }
                for (int y = 0; y < ncoly; y++) {
                    double[] covy = _covs[y];
                    double dy = yvals[y] - _ymeans[y];
                    for (int x = 0; x < ncolx; x++) {
                        covy[x] += (xvals[x] - _xmeans[x]) * dy;
                    }
                }
            }
        }

        @Override // water.MRTask
        public void reduce(CoVarTaskCompleteObs other) {
            ArrayUtils.add(this._covs, other._covs);
        }
    }

    /**
     * First pass of the non-symmetric "complete.obs" covariance.
     *
     * <p>It sums the y and x columns over rows without any NaN. The input frame uses the same
     * {@code [y..., x...]} layout as {@link CoVarTaskCompleteObs}. {@code _NACount} is the number
     * of rows that were skipped; each such row is counted once.
     */
    public static class CoVarTaskCompleteObsMean extends MRTask<CoVarTaskCompleteObsMean> {
        double[] _xsum;
        double[] _ysum;
        long _NACount;
        int _ncolx;
        int _ncoly;

        /**
         * @param ncoly number of leading y columns in the input frame
         * @param ncolx number of trailing x columns in the input frame
         */
        CoVarTaskCompleteObsMean(int ncoly, int ncolx) {
            this._ncolx = ncolx;
            this._ncoly = ncoly;
        }

        @Override // water.MRTask
        public void map(Chunk[] cs) {
            _xsum = new double[_ncolx];
            _ysum = new double[_ncoly];
            double[] xvals = new double[_ncolx];
            double[] yvals = new double[_ncoly];
            int len = cs[0]._len;
            for (int row = 0; row < len; row++) {
                if (readCompleteRow(cs, 0, row, yvals) && readCompleteRow(cs, _ncoly, row, xvals)) {
                    ArrayUtils.add(_xsum, xvals);
                    ArrayUtils.add(_ysum, yvals);
                } else {
                    _NACount++;
                }
            }
        }

        @Override // water.MRTask
        public void reduce(CoVarTaskCompleteObsMean other) {
            ArrayUtils.add(this._xsum, other._xsum);
            ArrayUtils.add(this._ysum, other._ysum);
            this._NACount += other._NACount;
        }
    }

    /**
     * First pass of the symmetric "complete.obs" covariance.
     *
     * <p>It sums every column of the input frame over rows without any NaN. {@code _NACount} is
     * the number of skipped rows.
     */
    public static class CoVarTaskCompleteObsMeanSym extends MRTask<CoVarTaskCompleteObsMeanSym> {
        double[] _ysum;
        long _NACount;

        private CoVarTaskCompleteObsMeanSym() {
        }

        @Override // water.MRTask
        public void map(Chunk[] cs) {
            int ncol = cs.length;
            _ysum = new double[ncol];
            double[] vals = new double[ncol];
            int len = cs[0]._len;
            for (int row = 0; row < len; row++) {
                if (readCompleteRow(cs, 0, row, vals)) {
                    ArrayUtils.add(_ysum, vals);
                } else {
                    _NACount++;
                }
            }
        }

        @Override // water.MRTask
        public void reduce(CoVarTaskCompleteObsMeanSym other) {
            ArrayUtils.add(this._ysum, other._ysum);
            this._NACount += other._NACount;
        }
    }

    /**
     * Second pass of the symmetric "complete.obs" covariance.
     *
     * <p>Over rows without any NaN, it accumulates only the upper triangle
     * {@code _covs[i][j], j >= i}, each entry being the sum of
     * {@code (y_j - mean_j) * (y_i - mean_i)}. Entries below the diagonal stay 0.
     */
    public static class CoVarTaskCompleteObsSym extends MRTask<CoVarTaskCompleteObsSym> {
        double[][] _covs;
        final double[] _ymeans;

        /** @param ymeans per-column means, used to center values */
        CoVarTaskCompleteObsSym(double[] ymeans) {
            this._ymeans = ymeans;
        }

        @Override // water.MRTask
        public void map(Chunk[] cs) {
            int ncol = _ymeans.length;
            double[] vals = new double[ncol];
            _covs = new double[ncol][ncol];
            int len = cs[0]._len;
            for (int row = 0; row < len; row++) {
                if (!readCompleteRow(cs, 0, row, vals)) {
                    continue;
                }
                for (int i = 0; i < ncol; i++) {
                    double[] covi = _covs[i];
                    double di = vals[i] - _ymeans[i];
                    for (int j = i; j < ncol; j++) {
                        covi[j] += (vals[j] - _ymeans[j]) * di;
                    }
                }
            }
        }

        @Override // water.MRTask
        public void reduce(CoVarTaskCompleteObsSym other) {
            ArrayUtils.add(this._covs, other._covs);
        }
    }

    /**
     * Covariance sums of one y column against several x columns, without any NA filtering.
     *
     * <p>The input frame is {@code [y, x_0 .. x_{k-1}]}. {@code _covs[x]} is the sum over all rows
     * of {@code (x - xmean) * (y - ymean)}. Any NaN makes the corresponding sum NaN.
     */
    public static class CoVarTaskEverything extends MRTask<CoVarTaskEverything> {
        double[] _covs;
        final double[] _xmeans;
        final double _ymean;

        /**
         * @param ymean  mean of the single y column
         * @param xmeans means of the x columns, aligned with the frame's columns 1..k
         */
        CoVarTaskEverything(double ymean, double[] xmeans) {
            this._ymean = ymean;
            this._xmeans = xmeans;
        }

        @Override // water.MRTask
        public void map(Chunk[] cs) {
            int ncolx = cs.length - 1;
            Chunk ys = cs[0];
            int len = ys._len;
            _covs = new double[ncolx];
            for (int x = 0; x < ncolx; x++) {
                Chunk xs = cs[x + 1];
                double xmean = _xmeans[x];
                double sum = 0.0d;
                for (int row = 0; row < len; row++) {
                    sum += (xs.atd(row) - xmean) * (ys.atd(row) - _ymean);
                }
                _covs[x] = sum;
            }
        }

        @Override // water.MRTask
        public void reduce(CoVarTaskEverything other) {
            ArrayUtils.add(this._covs, other._covs);
        }
    }

    /** NA-handling strategy selected by the {@code use} argument. */
    public enum Mode {
        Everything,
        AllObs,
        CompleteObs
    }

    @Override // water.rapids.ASTPrim
    public String[] args() {
        return new String[]{"ary", "x", "y", "use", "symmetric"};
    }

    /** Five slots, matching {@link #args()}; {@link #apply} reads slots 1 through 4. */
    @Override // water.rapids.AST
    public int nargs() {
        return 5;
    }

    @Override // water.rapids.AST
    public String str() {
        return "var";
    }

    /**
     * Evaluates {@code (var x y use symmetric)}.
     *
     * @throws IllegalArgumentException if the frames' row counts differ, if {@code use} is not a
     *     recognized mode, or if the mode is "all.obs" and NAs are present
     */
    @Override // water.rapids.AST
    public Val apply(Env env, Env.StackHelp stk, AST[] asts) {
        Frame frx = stk.track(asts[1].exec(env)).getFrame();
        Frame fry = stk.track(asts[2].exec(env)).getFrame();
        if (frx.numRows() != fry.numRows()) {
            throw new IllegalArgumentException("Frames must have the same number of rows, found " + frx.numRows() + " and " + fry.numRows());
        }
        String use = stk.track(asts[3].exec(env)).getStr();
        boolean symmetric = asts[4].exec(env).getNum() == 1.0d;
        Mode mode = parseMode(use);
        return fry.numRows() == 1 ? scalar(frx, fry, mode) : array(frx, fry, mode, symmetric);
    }

    /**
     * Maps a {@code use} string to its {@link Mode}.
     *
     * @throws IllegalArgumentException for any other string
     */
    private static Mode parseMode(String use) {
        switch (use) {
            case "everything":
                return Mode.Everything;
            case "all.obs":
                return Mode.AllObs;
            case "complete.obs":
                return Mode.CompleteObs;
            default:
                throw new IllegalArgumentException("unknown use mode: " + use);
        }
    }

    /**
     * Covariance of two single-row frames, treating each row as a vector across its columns.
     *
     * <p>Only column positions where both values are present contribute. In "everything" mode any
     * NA yields NaN; in "all.obs" mode any NA raises an exception.
     */
    private ValNum scalar(Frame frx, Frame fry, Mode mode) {
        if (frx.numCols() != fry.numCols()) {
            throw new IllegalArgumentException("Single rows must have the same number of columns, found " + frx.numCols() + " and " + fry.numCols());
        }
        int ncols = frx.numCols();
        Vec[] xvecs = frx.vecs();
        Vec[] yvecs = fry.vecs();
        // Cache the row values so each Vec is read only once.
        double[] xs = new double[ncols];
        double[] ys = new double[ncols];
        double xsum = 0.0d;
        double ysum = 0.0d;
        int nas = 0;
        for (int i = 0; i < ncols; i++) {
            xs[i] = xvecs[i].at(0L);
            ys[i] = yvecs[i].at(0L);
            if (Double.isNaN(xs[i]) || Double.isNaN(ys[i])) {
                nas++;
            } else {
                xsum += xs[i];
                ysum += ys[i];
            }
        }
        if (nas != 0) {
            if (mode == Mode.AllObs) {
                throw new IllegalArgumentException("Mode is 'all.obs' but NAs are present");
            }
            if (mode == Mode.Everything) {
                return new ValNum(Double.NaN);
            }
        }
        double n = ncols - nas;
        double xmean = xsum / n;
        double ymean = ysum / n;
        double cov = 0.0d;
        for (int i = 0; i < ncols; i++) {
            if (!Double.isNaN(xs[i]) && !Double.isNaN(ys[i])) {
                cov += (xs[i] - xmean) * (ys[i] - ymean);
            }
        }
        return new ValNum(cov / (n - 1.0d));
    }

    /**
     * Column-wise covariance for multi-row frames; dispatches on mode and symmetry.
     *
     * <p>NOTE(review): the symmetric paths appear to assume {@code x} and {@code y} hold the same
     * columns. The complete.obs path ignores {@code x} entirely, and the everything path takes its
     * off-diagonal columns from {@code x}.
     */
    private Val array(Frame frx, Frame fry, Mode mode, boolean symmetric) {
        if (mode == Mode.CompleteObs) {
            return symmetric ? arrayCompleteObsSymmetric(fry) : arrayCompleteObs(frx, fry);
        }
        if (mode == Mode.AllObs) {
            requireNoNAs(frx.vecs());
            if (!symmetric) {
                requireNoNAs(fry.vecs());
            }
        }
        return symmetric ? arrayEverythingSymmetric(frx, fry) : arrayEverything(frx, fry);
    }

    /**
     * "complete.obs" covariance of every y column against every x column.
     *
     * <p>It uses two passes: the means over complete rows, then the centered cross-products.
     */
    private static Val arrayCompleteObs(Frame frx, Frame fry) {
        int ncolx = frx.vecs().length;
        int ncoly = fry.vecs().length;
        Frame yx = new Frame(fry).add(frx);
        CoVarTaskCompleteObsMean sums = new CoVarTaskCompleteObsMean(ncoly, ncolx).doAll(yx);
        long nComplete = fry.numRows() - sums._NACount;
        CoVarTaskCompleteObs cvs = new CoVarTaskCompleteObs(
                ArrayUtils.div(sums._ysum, nComplete),
                ArrayUtils.div(sums._xsum, nComplete)).doAll(yx);
        long denom = nComplete - 1;
        if (ncolx == 1 && ncoly == 1) {
            return new ValNum(cvs._covs[0][0] / denom);
        }
        double[][] cols = new double[ncoly][];
        for (int y = 0; y < ncoly; y++) {
            cols[y] = ArrayUtils.div(cvs._covs[y], denom);
        }
        return covFrame(fry._names, cols);
    }

    /**
     * Symmetric "complete.obs" covariance matrix of the y columns.
     *
     * <p>Only the upper triangle is computed; it is then mirrored.
     */
    private static Val arrayCompleteObsSymmetric(Frame fry) {
        Vec[] yvecs = fry.vecs();
        int n = yvecs.length;
        if (n == 1) {
            double sigma = yvecs[0].sigma();
            return new ValNum(sigma * sigma);
        }
        CoVarTaskCompleteObsMeanSym sums = new CoVarTaskCompleteObsMeanSym().doAll(fry);
        long nComplete = fry.numRows() - sums._NACount;
        CoVarTaskCompleteObsSym cvs = new CoVarTaskCompleteObsSym(ArrayUtils.div(sums._ysum, nComplete)).doAll(new Frame(fry));
        long denom = nComplete - 1;
        double[][] cov = new double[n][n];
        for (int i = 0; i < n; i++) {
            // Only entries i..n-1 of row i were accumulated by the task.
            System.arraycopy(ArrayUtils.div(cvs._covs[i], denom), i, cov[i], i, n - i);
        }
        mirrorUpperTriangle(cov);
        return covFrame(fry._names, cov);
    }

    /**
     * "everything"/"all.obs" covariance of every y column against every x column.
     *
     * <p>One task is forked per y column.
     */
    private static Val arrayEverything(Frame frx, Frame fry) {
        Vec[] xvecs = frx.vecs();
        Vec[] yvecs = fry.vecs();
        int ncolx = xvecs.length;
        int ncoly = yvecs.length;
        // One mean per x column (previously bounded by ncoly, which overflowed when ncoly > ncolx).
        double[] xmeans = new double[ncolx];
        for (int x = 0; x < ncolx; x++) {
            xmeans[x] = xvecs[x].mean();
        }
        CoVarTaskEverything[] tasks = new CoVarTaskEverything[ncoly];
        for (int y = 0; y < ncoly; y++) {
            tasks[y] = new CoVarTaskEverything(yvecs[y].mean(), xmeans).dfork(new Frame(yvecs[y]).add(frx));
        }
        long denom = fry.numRows() - 1;
        if (ncolx == 1 && ncoly == 1) {
            return new ValNum(tasks[0].getResult()._covs[0] / denom);
        }
        double[][] cols = new double[ncoly][];
        for (int y = 0; y < ncoly; y++) {
            cols[y] = ArrayUtils.div(tasks[y].getResult()._covs, denom);
        }
        return covFrame(fry._names, cols);
    }

    /**
     * Symmetric "everything"/"all.obs" covariance matrix.
     *
     * <p>Row {@code i}'s task covers only the x columns {@code i+1 .. n-1}. The diagonal comes from
     * {@link #getVar}, and the upper triangle is then mirrored.
     */
    private static Val arrayEverythingSymmetric(Frame frx, Frame fry) {
        Vec[] xvecs = frx.vecs();
        Vec[] yvecs = fry.vecs();
        int n = yvecs.length;
        if (n == 1) {
            return new ValNum(getVar(yvecs[0]));
        }
        CoVarTaskEverything[] tasks = new CoVarTaskEverything[n - 1];
        for (int i = 0; i < n - 1; i++) {
            int[] rest = new int[n - 1 - i];
            // The means must line up with the forwarded columns, not with columns 0..k-1.
            double[] restMeans = new double[rest.length];
            for (int k = 0; k < rest.length; k++) {
                rest[k] = i + 1 + k;
                restMeans[k] = xvecs[rest[k]].mean();
            }
            tasks[i] = new CoVarTaskEverything(yvecs[i].mean(), restMeans)
                    .dfork(new Frame(yvecs[i]).add(new Frame(frx.vecs(rest))));
        }
        double[][] cov = new double[n][n];
        for (int i = 0; i < n; i++) {
            cov[i][i] = getVar(yvecs[i]);
        }
        long denom = fry.numRows() - 1;
        for (int i = 0; i < n - 1; i++) {
            System.arraycopy(ArrayUtils.div(tasks[i].getResult()._covs, denom), 0, cov[i], i + 1, n - 1 - i);
        }
        mirrorUpperTriangle(cov);
        return covFrame(fry._names, cov);
    }

    /** Throws the "all.obs" error if any of {@code vecs} has a missing value. */
    private static void requireNoNAs(Vec[] vecs) {
        for (Vec v : vecs) {
            if (v.naCnt() != 0) {
                throw new IllegalArgumentException("Mode is 'all.obs' but NAs are present");
            }
        }
    }

    /** Copies {@code m[i][j]} into {@code m[j][i]} for every {@code j > i}; {@code m} must be square. */
    private static void mirrorUpperTriangle(double[][] m) {
        for (int i = 0; i < m.length - 1; i++) {
            for (int j = i + 1; j < m.length; j++) {
                m[j][i] = m[i][j];
            }
        }
    }

    /** Wraps each {@code cols[i]} as a new Vec and returns them as a frame with the given names. */
    private static ValFrame covFrame(String[] names, double[][] cols) {
        Key<Vec>[] keys = Vec.VectorGroup.VG_LEN1.addVecs(cols.length);
        Vec[] vecs = new Vec[cols.length];
        for (int i = 0; i < cols.length; i++) {
            vecs[i] = Vec.makeVec(cols[i], keys[i]);
        }
        return new ValFrame(new Frame(names, vecs));
    }

    /**
     * Reads row {@code row} of chunks {@code from .. from + dst.length - 1} into {@code dst}.
     *
     * @return {@code false} as soon as a NaN is seen; {@code dst} may then be partially written
     */
    private static boolean readCompleteRow(Chunk[] cs, int from, int row, double[] dst) {
        for (int c = 0; c < dst.length; c++) {
            double v = cs[from + c].atd(row);
            if (Double.isNaN(v)) {
                return false;
            }
            dst[c] = v;
        }
        return true;
    }

    /** Returns {@code sigma^2} of {@code vec}, or NaN if the vec has any missing values. */
    public static double getVar(Vec vec) {
        if (vec.naCnt() == 0) {
            return vec.sigma() * vec.sigma();
        }
        return Double.NaN;
    }
}
