package water.rapids.ast.prims.advmath;

import java.util.Arrays;
import java.util.Locale;
import water.H2O;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.Val;
import water.rapids.ast.AstBuiltin;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.ArrayUtils;
import water.util.Log;

/**
 * Rapids primitive {@code (distance ary x y measure)}: computes a pairwise measure between every
 * row of a reference frame {@code x} and every row of a query frame {@code y}.
 *
 * <p>Supported measures (matched case-insensitively): {@code cosine}, {@code cosine_sq},
 * {@code l1} and {@code l2}. The result is a numeric frame with one row per reference row and one
 * column per query row.
 */
public class AstDistance extends AstBuiltin<AstDistance> {

    /** Lower-case names of the accepted measures. */
    private static final String[] MEASURES = {"cosine", "cosine_sq", "l1", "l2"};

    /**
     * Map-only task over the reference frame that emits one output column per query row.
     *
     * <p>Each reference row of a chunk is compared with every row of {@link #_queries}, which is
     * read through random-access {@link Vec.Reader}s. The caller guarantees that the query row
     * count fits in an {@code int} and that no value is missing.
     */
    public static class DistanceComputer extends MRTask<DistanceComputer> {
        /** Query frame; each of its rows produces one output column. */
        Frame _queries;
        /** Requested measure name; compared case-insensitively. */
        String _measure;

        DistanceComputer(Frame frame, String str) {
            this._queries = frame;
            this._measure = str;
        }

        @Override // water.MRTask
        public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
            final int numCols = chunkArr.length;
            final int numQueries = (int) _queries.numRows();
            final int numRefs = chunkArr[0]._len;

            // Vec.Reader is an inner class of Vec, so it must be created from a Vec instance.
            final Vec.Reader[] queryReaders = new Vec.Reader[numCols];
            for (int c = 0; c < numCols; c++) {
                queryReaders[c] = _queries.vec(c).new Reader();
            }

            // Locale.ROOT keeps matching independent of the JVM default locale
            // (e.g. the Turkish dotless i would otherwise break "COSINE").
            final String measure = _measure.toLowerCase(Locale.ROOT);
            final boolean cosine = measure.equals("cosine");
            final boolean cosineSq = measure.equals("cosine_sq");
            final boolean l1 = measure.equals("l1");
            final boolean l2 = measure.equals("l2");

            // Squared L2 norms of every reference row (chunk-local) and every query row (global),
            // needed only by the cosine measures.
            double[] refSqNorms = null;
            double[] querySqNorms = null;
            if (cosine || cosineSq) {
                refSqNorms = new double[numRefs];
                for (int r = 0; r < numRefs; r++) {
                    for (Chunk chunk : chunkArr) {
                        final double v = chunk.atd(r);
                        refSqNorms[r] += v * v;
                    }
                }
                querySqNorms = new double[numQueries];
                for (int q = 0; q < numQueries; q++) {
                    for (int c = 0; c < numCols; c++) {
                        final double v = queryReaders[c].at(q);
                        querySqNorms[q] += v * v;
                    }
                }
            }

            for (int r = 0; r < numRefs; r++) {
                for (int q = 0; q < numQueries; q++) {
                    double d = 0.0d;
                    if (l1) {
                        for (int c = 0; c < numCols; c++) {
                            d += Math.abs(chunkArr[c].atd(r) - queryReaders[c].at(q));
                        }
                    } else if (l2) {
                        for (int c = 0; c < numCols; c++) {
                            final double diff = chunkArr[c].atd(r) - queryReaders[c].at(q);
                            d += diff * diff;
                        }
                        d = Math.sqrt(d);
                    } else if (cosine || cosineSq) {
                        // Dot product, normalised by the row norms.
                        for (int c = 0; c < numCols; c++) {
                            d += chunkArr[c].atd(r) * queryReaders[c].at(q);
                        }
                        final double normProduct = refSqNorms[r] * querySqNorms[q];
                        d = cosineSq ? (d * d) / normProduct : d / Math.sqrt(normProduct);
                    }
                    newChunkArr[q].addNum(d);
                }
            }
        }
    }

    @Override // water.rapids.ast.AstPrimitive
    public String[] args() {
        return new String[]{"ary", "x", "y", "measure"};
    }

    @Override // water.rapids.ast.AstPrimitive
    public int nargs() {
        return 4;
    }

    @Override // water.rapids.ast.AstRoot
    public String str() {
        return "distance";
    }

    @Override // water.rapids.ast.AstPrimitive, water.rapids.ast.AstRoot
    public String description() {
        return "Compute a pairwise distance measure between all rows of two numeric H2OFrames.\nFor a given (usually larger) reference frame (N rows x p cols),\nand a (usually smaller) query frame (M rows x p cols), we return a numeric Frame of size (N rows x M cols),\nwhere the ij-th element is the distance measure between the i-th reference row and the j-th query row.\nNote1: The output frame is symmetric only when the reference and query frames are identical.\nNote2: Since N x M can be very large, it may be more efficient (memory-wise) to make multiple calls with smaller query Frames.";
    }

    /** Evaluates the reference frame, query frame and measure name, then delegates. */
    @Override // water.rapids.ast.AstBuiltin, water.rapids.ast.AstPrimitive
    public Val apply(Env env, Env.StackHelp stackHelp, AstRoot[] astRootArr) {
        return computeCosineDistances(stackHelp.track(astRootArr[1].exec(env)).getFrame(), stackHelp.track(astRootArr[2].exec(env)).getFrame(), stackHelp.track(astRootArr[3].exec(env)).getStr());
    }

    /**
     * Validates the inputs and computes the requested pairwise measure. Despite its name, this
     * method handles every supported measure, not only cosine.
     *
     * @param frame  reference frame (N rows x p numeric columns, no missing values)
     * @param frame2 query frame (M rows x p numeric columns, no missing values)
     * @param str    measure name: one of cosine, cosine_sq, l1, l2 (case-insensitive)
     * @return an N x M frame of measures
     * @throws IllegalArgumentException if the measure is unknown, the frames have no columns or
     *     differ in width, a column is non-numeric or has missing values, the query frame exceeds
     *     {@code Integer.MAX_VALUE} rows, or the result would not fit in free cluster memory
     */
    public Val computeCosineDistances(Frame frame, Frame frame2, String str) {
        Log.info("Number of references: " + frame.numRows());
        Log.info("Number of queries   : " + frame2.numRows());
        if (!ArrayUtils.contains(MEASURES, str.toLowerCase(Locale.ROOT))) {
            throw new IllegalArgumentException("Invalid distance measure provided: " + str + ". Must be one of " + Arrays.toString(MEASURES));
        }
        // Estimate in double arithmetic so that N * M * 8 cannot overflow a long.
        if ((double) frame.numRows() * (double) frame2.numRows() * 8.0d > (double) H2O.CLOUD.free_mem()) {
            throw new IllegalArgumentException("Not enough free memory to allocate the distance matrix (" + frame.numRows() + " rows and " + frame2.numRows() + " cols). Try specifying a smaller query frame.");
        }
        if (frame.numCols() != frame2.numCols()) {
            throw new IllegalArgumentException("Frames must have the same number of cols, found " + frame.numCols() + " and " + frame2.numCols());
        }
        if (frame.numCols() == 0) {
            throw new IllegalArgumentException("Frames must have at least one column.");
        }
        if (frame2.numRows() > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Queries can't be larger than 2 billion rows.");
        }
        for (int i = 0; i < frame2.numCols(); i++) {
            if (!frame.vec(i).isNumeric()) {
                throw new IllegalArgumentException("References column " + frame.name(i) + " is not numeric.");
            }
            if (!frame2.vec(i).isNumeric()) {
                throw new IllegalArgumentException("Queries column " + frame2.name(i) + " is not numeric.");
            }
            if (frame.vec(i).naCnt() > 0) {
                throw new IllegalArgumentException("References column " + frame.name(i) + " contains missing values.");
            }
            if (frame2.vec(i).naCnt() > 0) {
                throw new IllegalArgumentException("Queries column " + frame2.name(i) + " contains missing values.");
            }
        }
        // One numeric (type 3) output column per query row.
        return new ValFrame(new DistanceComputer(frame2, str).doAll((int) frame2.numRows(), (byte) 3, frame).outputFrame());
    }
}
