package water.rapids.ast.prims.advmath;

import java.math.BigDecimal;
import java.math.MathContext;
import water.Key;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.Merge;
import water.rapids.Val;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValNum;
import water.util.FrameUtils;

/* loaded from: input_file:water/rapids/ast/prims/advmath/AstSpearman.class */
/**
 * Rapids primitive {@code spearman}: Spearman's rank correlation coefficient between two columns of
 * a frame.
 *
 * <p>Rows in which either column is NA are removed first. A numeric column is replaced by the
 * 0-based position of each value in sorted order (tied values share the lowest position of their
 * group); a categorical column is used as-is, its level indices acting as ranks. The result is the
 * Pearson correlation of the two rank vectors.
 *
 * <p>Every temporary Vec created while computing the coefficient is registered with {@link Scope}
 * inside {@link #apply} and removed when {@code apply} returns.
 */
public class AstSpearman extends AstPrimitive<AstSpearman> {

    /** Temporary column holding each row's original position while a column is being sorted. */
    private static final String LABEL_COLUMN = "label";

    /**
     * Computes the mean of every input column, using only rows in which no column is NaN.
     *
     * <p>Each chunk sums its values as {@link BigDecimal} (DECIMAL128) and divides by its own row
     * count; {@link #reduce} then combines chunk means weighted by their row counts.
     */
    private static class MeanTask extends MRTask<MeanTask> {
        private double[] _means;
        private long _linesVisited;

        private MeanTask() {
            _linesVisited = 0L;
        }

        @Override // water.MRTask
        public void map(Chunk[] cs) {
            final BigDecimal[] sums = new BigDecimal[cs.length];
            for (int col = 0; col < sums.length; col++) {
                sums[col] = new BigDecimal(0, MathContext.DECIMAL128);
            }
            final double[] values = new double[cs.length];
            rows:
            for (int row = 0; row < cs[0].len(); row++) {
                for (int col = 0; col < cs.length; col++) {
                    values[col] = cs[col].atd(row);
                    if (Double.isNaN(values[col])) {
                        // Skip just this incomplete row, not the remainder of the chunk.
                        continue rows;
                    }
                }
                _linesVisited++;
                for (int col = 0; col < cs.length; col++) {
                    sums[col] = sums[col].add(
                            new BigDecimal(values[col], MathContext.DECIMAL128), MathContext.DECIMAL128);
                }
            }
            _means = new double[cs.length];
            if (_linesVisited == 0) {
                // No complete row in this chunk: keep zeros (they get weight 0 in reduce) instead of
                // dividing by zero, which BigDecimal reports as an ArithmeticException.
                return;
            }
            final BigDecimal count = BigDecimal.valueOf(_linesVisited);
            for (int col = 0; col < cs.length; col++) {
                _means[col] = sums[col].divide(count, MathContext.DECIMAL64).doubleValue();
            }
        }

        @Override // water.MRTask
        public void reduce(MeanTask other) {
            final long total = _linesVisited + other._linesVisited;
            if (total == 0) {
                return; // neither side saw a complete row; avoid 0/0 = NaN
            }
            for (int col = 0; col < _means.length; col++) {
                _means[col] = (_means[col] * _linesVisited + other._means[col] * other._linesVisited) / total;
            }
            _linesVisited = total;
        }
    }

    /**
     * Pearson correlation of two NA-free vectors whose means are supplied up front.
     *
     * <p>Accumulates the centered sums {@code Σ(x-x̄)(y-ȳ)}, {@code Σ(x-x̄)²} and {@code Σ(y-ȳ)²}.
     * This is mathematically equal to the {@code Σxy - n·x̄·ȳ} form but avoids catastrophic
     * cancellation when ranks (and therefore {@code Σxy}) are large.
     */
    private static class SpearmanCorrelationCoefficientTask extends MRTask<SpearmanCorrelationCoefficientTask> {
        private final double _xMean;
        private final double _yMean;
        private double _xDiffSquared;
        private double _yDiffSquared;
        private double _xyDiffProduct;
        private double _spearmanCorrelationCoefficient;

        private SpearmanCorrelationCoefficientTask(double xMean, double yMean) {
            _xMean = xMean;
            _yMean = yMean;
        }

        @Override // water.MRTask
        public void map(Chunk[] cs) {
            assert cs.length == 2;
            final Chunk xs = cs[0];
            final Chunk ys = cs[1];
            for (int row = 0; row < xs.len(); row++) {
                final double dx = xs.atd(row) - _xMean;
                final double dy = ys.atd(row) - _yMean;
                _xyDiffProduct += dx * dy;
                _xDiffSquared += dx * dx;
                _yDiffSquared += dy * dy;
            }
        }

        @Override // water.MRTask
        public void reduce(SpearmanCorrelationCoefficientTask other) {
            _xDiffSquared += other._xDiffSquared;
            _yDiffSquared += other._yDiffSquared;
            _xyDiffProduct += other._xyDiffProduct;
        }

        @Override // water.MRTask
        protected void postGlobal() {
            // NaN when either vector has zero variance (0/0): the correlation is undefined there.
            _spearmanCorrelationCoefficient =
                    _xyDiffProduct / (Math.sqrt(_xDiffSquared) * Math.sqrt(_yDiffSquared));
        }

        public double getSpearmanCorrelationCoefficient() {
            return _spearmanCorrelationCoefficient;
        }
    }

    /** The rank vectors of the two compared columns; both have the same length. */
    private static final class SpearmanRankedVectors {
        private final Vec _x;
        private final Vec _y;

        private SpearmanRankedVectors(Vec x, Vec y) {
            _x = x;
            _y = y;
        }
    }

    @Override // water.rapids.ast.AstPrimitive
    public int nargs() {
        return 4; // apply() reads its three arguments from asts[1..3]
    }

    @Override // water.rapids.ast.AstPrimitive
    public String[] args() {
        return new String[]{"frame", "first_column", "second_column"};
    }

    /**
     * Evaluates {@code (spearman frame first_column second_column)}.
     *
     * @throws IllegalArgumentException if a column name is not present in the frame, or a column is
     *     neither numeric nor categorical
     */
    @Override // water.rapids.ast.AstPrimitive
    public Val apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        final Frame frame = stk.track(asts[1].exec(env).getFrame());
        final int xIndex = columnIndex(frame, asts[2].exec(env).getStr());
        final int yIndex = columnIndex(frame, asts[3].exec(env).getStr());
        Scope.enter();
        try {
            final SpearmanRankedVectors ranked = rankedVectors(frame, xIndex, yIndex);
            final double[] means = calculateMeans(ranked._x, ranked._y);
            final double coefficient = new SpearmanCorrelationCoefficientTask(means[0], means[1])
                    .doAll(ranked._x, ranked._y)
                    .getSpearmanCorrelationCoefficient();
            return new ValNum(coefficient);
        } finally {
            Scope.exit();
        }
    }

    /** Resolves a column name, failing with a clear message instead of an index error on -1. */
    private static int columnIndex(Frame frame, String name) {
        final int index = frame.find(name);
        if (index < 0) {
            throw new IllegalArgumentException(String.format("Column '%s' not found in the frame.", name));
        }
        return index;
    }

    /**
     * Builds the rank vectors for columns {@code xIndex} and {@code yIndex} of {@code frame}.
     *
     * <p>Must be called inside an entered {@link Scope}: every Vec created here is tracked and will
     * be deleted on scope exit. The user's frame is never modified.
     */
    private SpearmanRankedVectors rankedVectors(Frame frame, int xIndex, int yIndex) {
        final Frame withNAs = new Frame(frame.vec(xIndex).makeCopy(), frame.vec(yIndex).makeCopy());
        Scope.track(withNAs);
        // Drop every row where either column is NA so both rank vectors cover the same rows.
        final Frame withoutNAs = new Merge.RemoveNAsTask(0, 1)
                .doAll(withNAs.types(), withNAs)
                .outputFrame(withNAs.names(), withNAs.domains());
        Scope.track(withoutNAs);

        final Vec rankedX = rankColumn(withoutNAs.vec(0));
        final Vec rankedY = rankColumn(withoutNAs.vec(1));
        assert rankedX.length() == rankedY.length();
        return new SpearmanRankedVectors(rankedX, rankedY);
    }

    /**
     * Returns the ranks of an NA-free column.
     *
     * <p>Categorical columns are returned unchanged (their level indices serve as ranks). For other
     * columns, a new Vec is returned holding, for each row, the 0-based position of its value in
     * ascending order. Tied values all receive the position of the first element of their group.
     * NOTE(review): the textbook Spearman definition assigns tied values their *average* rank, so
     * results differ from e.g. scipy when ties are present — confirm this is intended.
     */
    private Vec rankColumn(Vec values) {
        if (!needsOrdering(values)) {
            return values;
        }
        final Frame labeled = new Frame(values);
        FrameUtils.labelRows(labeled, LABEL_COLUMN);
        Scope.track(labeled); // tracked after labelling so the added label Vec is cleaned up too
        final Frame sorted = labeled.sort(new int[]{0});
        Scope.track(sorted);

        final Vec ranks = Vec.makeZero(values.length());
        Scope.track(ranks);
        final Vec.Reader sortedValues = sorted.vec(0).new Reader();
        final Vec.Reader originalRows = sorted.vec(LABEL_COLUMN).new Reader();
        final Vec.Writer writer = ranks.open();
        try {
            double previous = Double.NaN;
            long tieRun = 0;
            for (long position = 0; position < ranks.length(); position++) {
                final double current = sortedValues.at(position);
                tieRun = current == previous ? tieRun + 1 : 0L;
                previous = current;
                // Labels are assumed to be 1-based row numbers, hence the -1 to get the row index.
                writer.set(originalRows.at8(position) - 1, position - tieRun);
            }
        } finally {
            writer.close();
        }
        return ranks;
    }

    /** Categorical columns already carry ordinal level indices; everything else must be sorted. */
    private boolean needsOrdering(Vec vec) {
        return !vec.isCategorical();
    }

    @Override // water.rapids.ast.AstRoot
    public String str() {
        return "spearman";
    }

    /**
     * Computes the mean of each vector over rows where no vector is NaN.
     *
     * @throws IllegalArgumentException if no vectors are given, a vector is neither numeric nor
     *     categorical, or the vectors differ in length
     */
    private static double[] calculateMeans(Vec... vecArr) throws IllegalArgumentException {
        if (vecArr.length < 1) {
            throw new IllegalArgumentException("There are no vectors to calculate means from.");
        }
        final long length = vecArr[0].length();
        for (Vec vec : vecArr) {
            if (!vec.isCategorical() && !vec.isNumeric()) {
                throw new IllegalArgumentException(String.format("Given vector '%s' is not numerical or categorical.", vec._key.toString()));
            }
            if (length != vec.length()) {
                throw new IllegalArgumentException("Vectors to calculate means from do not have the same length." + String.format(" Vector '%s' is of length '%d'", vec._key.toString(), Long.valueOf(vec.length())));
            }
        }
        return new MeanTask().doAll(vecArr)._means;
    }
}
