package water.rapids;

import java.util.Arrays;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.util.ArrayUtils;

/**
 * Rapids primitive {@code (var x y use)}: sample covariance between the columns of two frames.
 * Passing the same single column as both {@code x} and {@code y} yields its sample variance.
 *
 * <p>Both frames must have the same number of rows. The {@code use} argument selects how missing
 * values (NaN) are handled:
 * <ul>
 *   <li>{@code "everything"}: NaNs are summed as-is, so affected results are NaN.</li>
 *   <li>{@code "all.obs"}: like {@code "everything"}, but throws {@link IllegalArgumentException}
 *       when a NaN is detected.</li>
 *   <li>{@code "complete.obs"}: rows with a NaN in <em>any</em> x or y column are dropped from
 *       every result entry.</li>
 *   <li>{@code "pairwise.complete.obs"}: for each (x, y) column pair, rows where either value is
 *       NaN are dropped from that pair only.</li>
 * </ul>
 *
 * <p>Result shape:
 * <ul>
 *   <li>If the frames have exactly one row, that row is treated as the sample (columns are the
 *       observations) and a scalar is returned.</li>
 *   <li>If both frames have a single column, a scalar is returned.</li>
 *   <li>Otherwise a frame is returned with one column per column of {@code y} (named like
 *       {@code y}'s columns). Entry {@code i} of column {@code j} is cov(x_i, y_j).</li>
 * </ul>
 *
 * <p>NOTE(review): all paths use the one-pass formula
 * {@code (sum(xy) - sum(x) * sum(y) / n) / (n - 1)}. This can lose precision when the means are
 * large relative to the spread of the data.
 */
public class ASTVariance extends ASTPrim {

    /** Message used whenever {@code use = "all.obs"} encounters a missing value. */
    private static final String ALL_OBS_NA_MSG = "Mode is 'all.obs' but NAs are present";

    /**
     * Accumulates, in a single pass, the cross-products and sums needed for the covariance of
     * every (x, y) column pair, over the rows in which no x or y column is NaN.
     *
     * <p>The input frame must hold the {@code ncoly} y columns first, followed by the
     * {@code ncolx} x columns.
     */
    public static class CoVarTaskComplete extends MRTask<CoVarTaskComplete> {
        /** {@code _ss[y][x]}: sum of {@code x * y} over complete rows. */
        double[][] _ss;
        /** Per-x-column sum over complete rows. */
        double[] _xsum;
        /** Per-y-column sum over complete rows. */
        double[] _ysum;
        /** Number of rows skipped because at least one x or y value was NaN. */
        long _NACount;
        /** Number of x columns (stored after the y columns in the input frame). */
        int _ncolx;
        /** Number of y columns (stored first in the input frame). */
        int _ncoly;

        CoVarTaskComplete(int ncolx, int ncoly) {
            this._ncolx = ncolx;
            this._ncoly = ncoly;
        }

        @Override
        public void map(Chunk[] cs) {
            _ss = new double[_ncoly][_ncolx];
            _xsum = new double[_ncolx];
            _ysum = new double[_ncoly];
            double[] xs = new double[_ncolx];
            double[] ys = new double[_ncoly];
            int len = cs[0]._len;
            for (int row = 0; row < len; row++) {
                // Short-circuit: the x columns are only read when every y value is present, and
                // an incomplete row is counted exactly once.
                if (!readRow(cs, row, 0, ys) || !readRow(cs, row, _ncoly, xs)) {
                    _NACount++;
                    continue;
                }
                for (int y = 0; y < _ncoly; y++) {
                    double[] ssRow = _ss[y];
                    for (int x = 0; x < _ncolx; x++) {
                        ssRow[x] += xs[x] * ys[y];
                    }
                }
                ArrayUtils.add(_xsum, xs);
                ArrayUtils.add(_ysum, ys);
            }
        }

        /**
         * Copies {@code dst.length} consecutive columns, starting at chunk index {@code offset},
         * from row {@code row} into {@code dst}.
         *
         * @return {@code false} as soon as a NaN is met ({@code dst} is then only partially
         *     filled and must not be used), {@code true} otherwise
         */
        private static boolean readRow(Chunk[] cs, int row, int offset, double[] dst) {
            for (int c = 0; c < dst.length; c++) {
                double v = cs[offset + c].atd(row);
                if (Double.isNaN(v)) {
                    return false;
                }
                dst[c] = v;
            }
            return true;
        }

        @Override
        public void reduce(CoVarTaskComplete other) {
            ArrayUtils.add(_ss, other._ss);
            ArrayUtils.add(_xsum, other._xsum);
            ArrayUtils.add(_ysum, other._ysum);
            _NACount += other._NACount;
        }
    }

    /**
     * Accumulates sums for the covariance between one y column (input column 0) and each of the
     * remaining x columns. NaNs are not filtered, so they propagate into the sums.
     */
    public static class CoVarTaskEverything extends MRTask<CoVarTaskEverything> {
        /** Per-x-column sum of {@code x * y}. */
        double[] _ss;
        /** Per-x-column sum of x. */
        double[] _xsum;
        /** Sum of y, repeated once per x column. */
        double[] _ysum;

        CoVarTaskEverything() {
        }

        @Override
        public void map(Chunk[] cs) {
            Chunk ychunk = cs[0];
            int nx = cs.length - 1;
            int len = ychunk._len;
            _ss = new double[nx];
            _xsum = new double[nx];
            _ysum = new double[nx];
            for (int x = 0; x < nx; x++) {
                Chunk xchunk = cs[x + 1];
                double ss = 0.0d;
                double xsum = 0.0d;
                double ysum = 0.0d;
                for (int row = 0; row < len; row++) {
                    double xv = xchunk.atd(row);
                    double yv = ychunk.atd(row);
                    xsum += xv;
                    ysum += yv;
                    ss += xv * yv;
                }
                _ss[x] = ss;
                _xsum[x] = xsum;
                _ysum[x] = ysum;
            }
        }

        @Override
        public void reduce(CoVarTaskEverything other) {
            ArrayUtils.add(_ss, other._ss);
            ArrayUtils.add(_xsum, other._xsum);
            ArrayUtils.add(_ysum, other._ysum);
        }
    }

    /**
     * Like {@link CoVarTaskEverything}, but for each x column it skips the rows where that x
     * value or the y value is NaN, and counts those rows per x column.
     */
    public static class CoVarTaskPairwise extends MRTask<CoVarTaskPairwise> {
        /** Per-x-column sum of {@code x * y} over rows where both are present. */
        double[] _ss;
        /** Per-x-column sum of x over rows where both are present. */
        double[] _xsum;
        /** Per-x-column sum of y over rows where both are present. */
        double[] _ysum;
        /** Per-x-column number of rows skipped because x or y was NaN. */
        long[] _NACount;

        CoVarTaskPairwise() {
        }

        @Override
        public void map(Chunk[] cs) {
            Chunk ychunk = cs[0];
            int nx = cs.length - 1;
            int len = ychunk._len;
            _ss = new double[nx];
            _xsum = new double[nx];
            _ysum = new double[nx];
            _NACount = new long[nx];
            for (int x = 0; x < nx; x++) {
                Chunk xchunk = cs[x + 1];
                double ss = 0.0d;
                double xsum = 0.0d;
                double ysum = 0.0d;
                long nas = 0;
                for (int row = 0; row < len; row++) {
                    double xv = xchunk.atd(row);
                    double yv = ychunk.atd(row);
                    if (Double.isNaN(xv) || Double.isNaN(yv)) {
                        nas++;
                        continue;
                    }
                    xsum += xv;
                    ysum += yv;
                    ss += xv * yv;
                }
                _ss[x] = ss;
                _xsum[x] = xsum;
                _ysum[x] = ysum;
                _NACount[x] = nas;
            }
        }

        @Override
        public void reduce(CoVarTaskPairwise other) {
            ArrayUtils.add(_ss, other._ss);
            ArrayUtils.add(_xsum, other._xsum);
            ArrayUtils.add(_ysum, other._ysum);
            ArrayUtils.add(_NACount, other._NACount);
        }
    }

    /** Missing-value handling strategy, parsed from the {@code use} argument. */
    public enum Mode {
        /** {@code "everything"}: NaNs propagate into the result. */
        Everything,
        /** {@code "all.obs"}: a NaN is an error. */
        AllObs,
        /** {@code "complete.obs"}: drop rows with a NaN in any column. */
        CompleteObs,
        /** {@code "pairwise.complete.obs"}: drop rows per (x, y) pair. */
        PairwiseCompleteObs
    }

    @Override
    public String[] args() {
        return new String[]{"ary", "x", "y", "use"};
    }

    /** The operator itself plus its three operands {@code x}, {@code y} and {@code use}. */
    @Override
    public int nargs() {
        return 4;
    }

    @Override
    public String str() {
        return "var";
    }

    /**
     * Evaluates {@code x}, {@code y} and {@code use}, then dispatches to the scalar (single row)
     * or column-wise implementation.
     *
     * @throws IllegalArgumentException if the row counts differ or {@code use} is not recognised
     */
    @Override
    public Val apply(Env env, Env.StackHelp stk, AST[] asts) {
        Frame frx = stk.track(asts[1].exec(env)).getFrame();
        Frame fry = stk.track(asts[2].exec(env)).getFrame();
        if (frx.numRows() != fry.numRows()) {
            throw new IllegalArgumentException("Frames must have the same number of rows, found " + frx.numRows() + " and " + fry.numRows());
        }
        String use = stk.track(asts[3].exec(env)).getStr();
        Mode mode;
        switch (use) {
            case "everything":
                mode = Mode.Everything;
                break;
            case "all.obs":
                mode = Mode.AllObs;
                break;
            case "complete.obs":
                mode = Mode.CompleteObs;
                break;
            case "pairwise.complete.obs":
                mode = Mode.PairwiseCompleteObs;
                break;
            default:
                throw new IllegalArgumentException("unknown use mode, found: " + use);
        }
        return frx.numRows() == 1 ? scalar(frx, fry, mode) : array(frx, fry, mode);
    }

    /**
     * Covariance of two single-row frames, treating each column as one observation. Positions
     * where either value is NaN are skipped (complete and pairwise modes behave the same here).
     *
     * @throws IllegalArgumentException if the column counts differ, or if {@code mode} is
     *     {@link Mode#AllObs} and a NaN is present
     */
    private ValNum scalar(Frame frx, Frame fry, Mode mode) {
        if (frx.numCols() != fry.numCols()) {
            throw new IllegalArgumentException("Single rows must have the same number of columns, found " + frx.numCols() + " and " + fry.numCols());
        }
        Vec[] xvecs = frx.vecs();
        Vec[] yvecs = fry.vecs();
        double ncols = frx.numCols();
        double xsum = 0.0d;
        double ysum = 0.0d;
        double ss = 0.0d;
        double nas = 0.0d;
        for (int c = 0; c < xvecs.length; c++) {
            double xv = xvecs[c].at(0L);
            double yv = yvecs[c].at(0L);
            if (Double.isNaN(xv) || Double.isNaN(yv)) {
                nas += 1.0d;
                continue;
            }
            xsum += xv;
            ysum += yv;
            ss += xv * yv;
        }
        if (nas > 0.0d) {
            if (mode == Mode.AllObs) {
                throw new IllegalArgumentException(ALL_OBS_NA_MSG);
            }
            if (mode == Mode.Everything) {
                return new ValNum(Double.NaN);
            }
        }
        return new ValNum((ss - (xsum * ysum) / (ncols - nas)) / (ncols - 1.0d - nas));
    }

    /** Column-wise covariance; dispatches on {@code mode}. */
    private Val array(Frame frx, Frame fry, Mode mode) {
        if (mode == Mode.Everything || mode == Mode.AllObs) {
            return arrayEverything(frx, fry, mode == Mode.AllObs);
        }
        if (mode == Mode.CompleteObs) {
            return arrayComplete(frx, fry);
        }
        if (mode == Mode.PairwiseCompleteObs) {
            return arrayPairwise(frx, fry);
        }
        throw new IllegalArgumentException("unknown use mode, found: " + mode);
    }

    /**
     * "everything"/"all.obs": runs one {@link CoVarTaskEverything} per y column in parallel.
     * When {@code allObs} is set, a NaN in any cross-product sum raises an error.
     */
    private Val arrayEverything(Frame frx, Frame fry, boolean allObs) {
        int ncolx = frx.vecs().length;
        Vec[] yvecs = fry.vecs();
        int ncoly = yvecs.length;
        CoVarTaskEverything[] tasks = new CoVarTaskEverything[ncoly];
        for (int y = 0; y < ncoly; y++) {
            tasks[y] = new CoVarTaskEverything().dfork(new Frame(yvecs[y]).add(frx));
        }
        long nrows = frx.numRows();
        if (ncolx == 1 && ncoly == 1) {
            CoVarTaskEverything res = tasks[0].getResult();
            if (allObs && Double.isNaN(res._ss[0])) {
                throw new IllegalArgumentException(ALL_OBS_NA_MSG);
            }
            return new ValNum((res._ss[0] - (res._xsum[0] * res._ysum[0]) / nrows) / (nrows - 1));
        }
        Vec[] out = new Vec[ncoly];
        Key<Vec>[] keys = Vec.VectorGroup.VG_LEN1.addVecs(ncoly);
        for (int y = 0; y < ncoly; y++) {
            CoVarTaskEverything res = tasks[y].getResult();
            if (allObs) {
                for (double ss : res._ss) {
                    if (Double.isNaN(ss)) {
                        throw new IllegalArgumentException(ALL_OBS_NA_MSG);
                    }
                }
            }
            // The ArrayUtils helpers operate in place on their first argument, which is fine
            // because each task result is consumed exactly once.
            double[] cov = ArrayUtils.div(
                    ArrayUtils.subtract(res._ss, ArrayUtils.mult(res._xsum, ArrayUtils.div(res._ysum, nrows))),
                    nrows - 1);
            out[y] = Vec.makeVec(cov, keys[y]);
        }
        return new ValFrame(new Frame(fry._names, out));
    }

    /**
     * "complete.obs": a single {@link CoVarTaskComplete} pass over y's columns followed by x's
     * columns. Every entry uses the same number of complete rows.
     */
    private Val arrayComplete(Frame frx, Frame fry) {
        int ncolx = frx.vecs().length;
        int ncoly = fry.vecs().length;
        CoVarTaskComplete res = new CoVarTaskComplete(ncolx, ncoly).doAll(new Frame(fry).add(frx));
        long nrows = frx.numRows();
        long n = nrows - res._NACount;
        if (ncolx == 1 && ncoly == 1) {
            return new ValNum((res._ss[0][0] - (res._xsum[0] * res._ysum[0]) / n) / (n - 1));
        }
        Vec[] out = new Vec[ncoly];
        Key<Vec>[] keys = Vec.VectorGroup.VG_LEN1.addVecs(ncoly);
        for (int y = 0; y < ncoly; y++) {
            // _xsum is shared by every y column, so scale a copy rather than the original.
            double[] meanCorrection = ArrayUtils.mult(res._xsum.clone(), res._ysum[y] / n);
            out[y] = Vec.makeVec(ArrayUtils.div(ArrayUtils.subtract(res._ss[y], meanCorrection), n - 1), keys[y]);
        }
        return new ValFrame(new Frame(fry._names, out));
    }

    /**
     * "pairwise.complete.obs": one {@link CoVarTaskPairwise} per y column in parallel. Each
     * (x, y) entry uses its own count of usable rows.
     */
    private Val arrayPairwise(Frame frx, Frame fry) {
        int ncolx = frx.vecs().length;
        Vec[] yvecs = fry.vecs();
        int ncoly = yvecs.length;
        CoVarTaskPairwise[] tasks = new CoVarTaskPairwise[ncoly];
        for (int y = 0; y < ncoly; y++) {
            tasks[y] = new CoVarTaskPairwise().dfork(new Frame(yvecs[y]).add(frx));
        }
        long nrows = frx.numRows();
        if (ncolx == 1 && ncoly == 1) {
            CoVarTaskPairwise res = tasks[0].getResult();
            long n = nrows - res._NACount[0];
            return new ValNum((res._ss[0] - (res._xsum[0] * res._ysum[0]) / n) / (n - 1));
        }
        Vec[] out = new Vec[ncoly];
        Key<Vec>[] keys = Vec.VectorGroup.VG_LEN1.addVecs(ncoly);
        for (int y = 0; y < ncoly; y++) {
            CoVarTaskPairwise res = tasks[y].getResult();
            double[] cov = ArrayUtils.div(
                    ArrayUtils.subtract(res._ss, ArrayUtils.mult(res._xsum, ArrayUtils.div(res._ysum, ArrayUtils.subtract(nrows, res._NACount.clone())))),
                    ArrayUtils.subtract(nrows - 1, res._NACount.clone()));
            out[y] = Vec.makeVec(cov, keys[y]);
        }
        return new ValFrame(new Frame(fry._names, out));
    }

    /**
     * Sample variance of a single column. NaNs are not filtered, so a NaN anywhere in
     * {@code vec} yields NaN (same handling as {@code use = "everything"}).
     */
    public static double getVar(Vec vec) {
        CoVarTaskEverything res = new CoVarTaskEverything().doAll(new Frame(vec, vec));
        long n = vec.length();
        return (res._ss[0] - (res._xsum[0] * res._ysum[0]) / n) / (n - 1);
    }
}
