package hex.aggregator;

import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ToEigenVec;
import hex.aggregator.AggregatorModel;
import hex.util.LinearAlgebraUtils;
import java.util.Arrays;
import java.util.Collections;
import water.DKV;
import water.Iced;
import water.IcedUtils;
import water.Job;
import water.Key;
import water.MRTask;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.IcedInt;
import water.util.Log;

/* loaded from: input_file:hex/aggregator/Aggregator.class */
public class Aggregator extends ModelBuilder<AggregatorModel, AggregatorModel.AggregatorParameters, AggregatorModel.AggregatorOutput> {

    /* loaded from: input_file:hex/aggregator/Aggregator$AggregateTask.class */
    private static class AggregateTask extends MRTask<AggregateTask> {
        final double _delta;
        final Key _dataInfoKey;
        final Key _jobKey;
        final int _maxExemplars;
        Exemplar[] _exemplars;
        Key _terminateKey;
        GIDMapping _mapping;
        static final /* synthetic */ boolean $assertionsDisabled;

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:hex/aggregator/Aggregator$AggregateTask$GIDMapping.class */
        public static class GIDMapping extends Iced<GIDMapping> {
            int capacity = 32;
            int len = 0;
            MyPair[] pairSet = new MyPair[this.capacity];

            void set(long j, long j2) {
                for (int i = 0; i < this.len; i++) {
                    MyPair myPair = this.pairSet[i];
                    if (myPair.second == j) {
                        myPair.second = j2;
                    }
                }
                MyPair myPair2 = new MyPair(j, j2);
                if (this.len == this.capacity) {
                    this.capacity *= 2;
                    this.pairSet = (MyPair[]) Arrays.copyOf(this.pairSet, this.capacity);
                }
                MyPair[] myPairArr = this.pairSet;
                int i2 = this.len;
                this.len = i2 + 1;
                myPairArr[i2] = myPair2;
            }

            long[][] unsortedList() {
                long[][] jArr = new long[2][this.len];
                MyPair[] myPairArr = this.pairSet;
                for (int i = 0; i < this.len; i++) {
                    jArr[0][i] = myPairArr[i].first;
                    jArr[1][i] = myPairArr[i].second;
                }
                return jArr;
            }
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        /* loaded from: input_file:hex/aggregator/Aggregator$AggregateTask$MyPair.class */
        public static class MyPair extends Iced<MyPair> implements Comparable<MyPair> {
            long first;
            long second;

            public MyPair(long j, long j2) {
                this.first = j;
                this.second = j2;
            }

            public MyPair() {
            }

            @Override // java.lang.Comparable
            public int compareTo(MyPair myPair) {
                if (this.first < myPair.first) {
                    return -1;
                }
                return this.first == myPair.first ? 0 : 1;
            }
        }

        public AggregateTask(Key<DataInfo> key, double d, Key<Job> key2, int i, Key key3) {
            this._delta = d * d;
            this._dataInfoKey = key;
            this._jobKey = key2;
            this._maxExemplars = i;
            this._terminateKey = key3;
            if (this._terminateKey != null) {
                DKV.put(this._terminateKey, new IcedInt(0));
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        public boolean isTerminated() {
            return this._terminateKey != null && DKV.getGet(this._terminateKey)._val == 1;
        }

        private void terminate() {
            if (this._terminateKey != null) {
                DKV.put(this._terminateKey, new IcedInt(1));
            }
        }

        public void map(Chunk[] chunkArr) {
            Exemplar exemplar;
            this._mapping = new GIDMapping();
            Exemplar[] exemplarArr = new Exemplar[4];
            Chunk[] chunkArr2 = (Chunk[]) Arrays.copyOf(chunkArr, chunkArr.length - 1);
            Chunk chunk = chunkArr[chunkArr.length - 1];
            DataInfo dataInfo = (DataInfo) this._dataInfoKey.get();
            if (!$assertionsDisabled && dataInfo == null) {
                throw new AssertionError();
            }
            DataInfo.Row newDenseRow = dataInfo.newDenseRow();
            int i = newDenseRow.nNums;
            for (int i2 = 0; i2 < chunkArr[0]._len; i2++) {
                if (i2 % 100 == 0 && isTerminated()) {
                    return;
                }
                long start = chunkArr[0].start() + i2;
                newDenseRow = dataInfo.extractDenseRow(chunkArr2, i2, newDenseRow);
                double[] copyOf = Arrays.copyOf(newDenseRow.numVals, i);
                int[] copyOf2 = Arrays.copyOf(newDenseRow.binIds, newDenseRow.binIds.length);
                if (i2 == 0) {
                    Exemplar exemplar2 = new Exemplar(copyOf, copyOf2, start);
                    exemplarArr = Exemplar.addExemplar(exemplarArr, exemplar2);
                    chunk.set(i2, exemplar2.gid);
                } else {
                    double d = Double.MAX_VALUE;
                    int i3 = 0;
                    int i4 = 0;
                    long j = -1;
                    Exemplar[] exemplarArr2 = exemplarArr;
                    int length = exemplarArr2.length;
                    for (int i5 = 0; i5 < length && null != (exemplar = exemplarArr2[i5]); i5++) {
                        if (Arrays.equals(copyOf2, exemplar.cats)) {
                            double squaredEuclideanDistance = exemplar.squaredEuclideanDistance(copyOf, d);
                            if (squaredEuclideanDistance < d) {
                                d = squaredEuclideanDistance;
                                i3 = i4;
                                j = exemplar.gid;
                            }
                            if (d < this._delta) {
                                break;
                            }
                        }
                        i4++;
                    }
                    if (d < this._delta) {
                        exemplarArr[i3]._cnt++;
                        chunk.set(i2, j);
                    } else {
                        Exemplar exemplar3 = new Exemplar(copyOf, copyOf2, start);
                        if (!$assertionsDisabled && !Arrays.equals(copyOf2, exemplar3.cats)) {
                            throw new AssertionError();
                        }
                        exemplarArr = Exemplar.addExemplar(exemplarArr, exemplar3);
                        if (exemplarArr.length > 2 * this._maxExemplars) {
                            terminate();
                        }
                        chunk.set(i2, start);
                    }
                }
            }
            this._exemplars = Exemplar.trim(exemplarArr);
            if (this._exemplars.length > this._maxExemplars) {
                terminate();
            }
            if (isTerminated()) {
                return;
            }
            if (!$assertionsDisabled && this._exemplars.length > chunkArr[0].len()) {
                throw new AssertionError();
            }
            long j2 = 0;
            for (Exemplar exemplar4 : this._exemplars) {
                j2 += exemplar4._cnt;
            }
            if (!$assertionsDisabled && j2 > chunkArr[0].len()) {
                throw new AssertionError();
            }
            this._jobKey.get().update(1L, "Aggregating.");
        }

        public void reduce(AggregateTask aggregateTask) {
            Exemplar exemplar;
            if (isTerminated() || this._exemplars == null || aggregateTask._exemplars == null || this._exemplars.length > this._maxExemplars || aggregateTask._exemplars.length > this._maxExemplars) {
                terminate();
                this._mapping = null;
                this._exemplars = null;
                aggregateTask._exemplars = null;
            }
            if (isTerminated()) {
                return;
            }
            for (int i = 0; i < aggregateTask._mapping.len; i++) {
                this._mapping.set(aggregateTask._mapping.pairSet[i].first, aggregateTask._mapping.pairSet[i].second);
            }
            Exemplar[] exemplarArr = aggregateTask._exemplars;
            long j = 0;
            for (Exemplar exemplar2 : this._exemplars) {
                j += exemplar2._cnt;
            }
            long j2 = 0;
            for (Exemplar exemplar3 : aggregateTask._exemplars) {
                j2 += exemplar3._cnt;
            }
            for (int i2 = 0; i2 < aggregateTask._exemplars.length; i2++) {
                double d = Double.MAX_VALUE;
                int i3 = 0;
                int i4 = 0;
                Exemplar[] exemplarArr2 = this._exemplars;
                int length = exemplarArr2.length;
                for (int i5 = 0; i5 < length && null != (exemplar = exemplarArr2[i5]); i5++) {
                    double squaredEuclideanDistance = exemplar.squaredEuclideanDistance(aggregateTask._exemplars[i2].data, d);
                    if (squaredEuclideanDistance < d) {
                        d = squaredEuclideanDistance;
                        i3 = i4;
                    }
                    if (d < this._delta) {
                        break;
                    }
                    i4++;
                }
                if (d < this._delta) {
                    this._exemplars[i3]._cnt += aggregateTask._exemplars[i2]._cnt;
                    this._mapping.set(exemplarArr[i2].gid, this._exemplars[i3].gid);
                } else {
                    this._exemplars = Exemplar.addExemplar(this._exemplars, (Exemplar) IcedUtils.deepCopy(aggregateTask._exemplars[i2]));
                }
            }
            aggregateTask._exemplars = null;
            this._exemplars = Exemplar.trim(this._exemplars);
            if (!$assertionsDisabled && this._exemplars.length > j + j2) {
                throw new AssertionError();
            }
            long j3 = 0;
            for (Exemplar exemplar4 : this._exemplars) {
                j3 += exemplar4._cnt;
            }
            if (!$assertionsDisabled && j3 != j + j2) {
                throw new AssertionError();
            }
            this._jobKey.get().update(1L, "Aggregating.");
        }

        static {
            $assertionsDisabled = !Aggregator.class.desiredAssertionStatus();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/aggregator/Aggregator$AggregatorDriver.class */
    public class AggregatorDriver extends ModelBuilder<AggregatorModel, AggregatorModel.AggregatorParameters, AggregatorModel.AggregatorOutput>.Driver {
        static final /* synthetic */ boolean $assertionsDisabled;

        AggregatorDriver() {
            super(Aggregator.this);
        }

        public void computeImpl() {
            Vec makeZero;
            AggregateTask aggregateTask;
            int length;
            AggregatorModel aggregatorModel = null;
            DataInfo dataInfo = null;
            try {
                try {
                    Aggregator.this.init(true);
                    if (Aggregator.this.error_count() > 0) {
                        throw new IllegalArgumentException("Found validation errors: " + Aggregator.this.validationErrors());
                    }
                    AggregatorModel aggregatorModel2 = new AggregatorModel(Aggregator.this.dest(), (AggregatorModel.AggregatorParameters) Aggregator.this._parms, new AggregatorModel.AggregatorOutput(Aggregator.this));
                    aggregatorModel2.delete_and_lock(Aggregator.this._job);
                    Frame train = Aggregator.this.train();
                    Aggregator.this._job.update(1L, "Preprocessing data.");
                    DataInfo dataInfo2 = new DataInfo(train, null, true, ((AggregatorModel.AggregatorParameters) Aggregator.this._parms)._transform, false, false, false);
                    DKV.put(dataInfo2);
                    double pow = 0.1d / Math.pow(Math.log(train.numRows()), 1.0d / train.numCols());
                    int min = (int) Math.min(((AggregatorModel.AggregatorParameters) Aggregator.this._parms)._target_num_exemplars, train.numRows());
                    Aggregator.this._job.update(0L, "Aggregating.");
                    double d = 0.0d;
                    double d2 = 256.0d;
                    double d3 = 8.0d;
                    double d4 = ((AggregatorModel.AggregatorParameters) Aggregator.this._parms)._rel_tol_num_exemplars;
                    int i = (int) ((1.0d + d4) * min);
                    int i2 = (int) ((1.0d - d4) * min);
                    Key make = Key.make();
                    while (true) {
                        Log.info(new Object[]{"radius_scale lo/mid/hi: " + d + "/" + d3 + "/" + d2});
                        double d5 = d3 * pow;
                        if (min == train.numRows()) {
                            d5 = 0.0d;
                        }
                        Vec[] vecArr = (Vec[]) Arrays.copyOf(train.vecs(), train.vecs().length + 1);
                        int length2 = vecArr.length - 1;
                        makeZero = train.anyVec().makeZero();
                        vecArr[length2] = makeZero;
                        Log.info(new Object[]{"Aggregating with radius " + String.format("%5f", Double.valueOf(d5)) + ":"});
                        aggregateTask = (AggregateTask) new AggregateTask(dataInfo2._key, d5, Aggregator.this._job._key, i, d5 == 0.0d ? null : make).doAll(vecArr);
                        if (d5 != 0.0d) {
                            if (aggregateTask.isTerminated() && Math.abs(d2 - d) < 0.001d * Math.abs(d + d2)) {
                                aggregateTask = (AggregateTask) new AggregateTask(dataInfo2._key, d5, Aggregator.this._job._key, (int) train.numRows(), make).doAll(vecArr);
                                Log.info(new Object[]{" Running again without early cutout."});
                                length = aggregateTask._exemplars.length;
                                break;
                            }
                            if (!aggregateTask.isTerminated() && aggregateTask._exemplars.length <= i) {
                                length = aggregateTask._exemplars.length;
                                Log.info(new Object[]{" " + length + " exemplars."});
                                if (length >= i2 && length <= i) {
                                    Log.info(new Object[]{" Within " + (100.0d * d4) + "% of target number of exemplars. Done."});
                                    break;
                                } else {
                                    Log.info(new Object[]{" Too few exemplars."});
                                    d2 = d3;
                                }
                            } else {
                                Log.info(new Object[]{" Too many exemplars."});
                                d = d3;
                            }
                            d3 = d + ((d2 - d) / 2.0d);
                        } else {
                            Log.info(new Object[]{" Returning original dataset."});
                            length = aggregateTask._exemplars.length;
                            if (!$assertionsDisabled && length != train.numRows()) {
                                throw new AssertionError();
                            }
                        }
                    }
                    Aggregator.this._job.update(1L, "Aggregation finished. Got " + length + " examplars");
                    if (!$assertionsDisabled && aggregateTask.isTerminated()) {
                        throw new AssertionError();
                    }
                    DKV.remove(make);
                    Log.info(new Object[]{"Creating exemplar assignments."});
                    Aggregator.this._job.update(1L, "Creating exemplar assignments.");
                    new RenumberTask(aggregateTask._mapping).doAll(new Vec[]{makeZero});
                    aggregatorModel2._exemplars = aggregateTask._exemplars;
                    aggregatorModel2._counts = new long[aggregateTask._exemplars.length];
                    for (int i3 = 0; i3 < aggregateTask._exemplars.length; i3++) {
                        aggregatorModel2._counts[i3] = aggregateTask._exemplars[i3]._cnt;
                    }
                    aggregatorModel2._exemplar_assignment_vec_key = makeZero._key;
                    ((AggregatorModel.AggregatorOutput) aggregatorModel2._output)._output_frame = Key.make("aggregated_" + ((AggregatorModel.AggregatorParameters) Aggregator.this._parms)._train.toString() + "_by_" + aggregatorModel2._key);
                    Log.info(new Object[]{"Creating output frame."});
                    Aggregator.this._job.update(1L, "Creating output frame.");
                    aggregatorModel2.createFrameOfExemplars((Frame) ((AggregatorModel.AggregatorParameters) Aggregator.this._parms)._train.get(), ((AggregatorModel.AggregatorOutput) aggregatorModel2._output)._output_frame);
                    Aggregator.this._job.update(1L, "Done.");
                    aggregatorModel2.update(Aggregator.this._job);
                    if (aggregatorModel2 != null) {
                        aggregatorModel2.unlock(Aggregator.this._job);
                        Scope.untrack(Collections.singletonList(aggregatorModel2._exemplar_assignment_vec_key));
                        Frame frame = ((AggregatorModel.AggregatorOutput) aggregatorModel2._output)._output_frame != null ? (Frame) ((AggregatorModel.AggregatorOutput) aggregatorModel2._output)._output_frame.get() : null;
                        if (frame != null) {
                            Scope.untrack(frame.keysList());
                        }
                    }
                    if (dataInfo2 != null) {
                        dataInfo2.remove();
                    }
                } catch (Throwable th) {
                    th.printStackTrace();
                    throw th;
                }
            } catch (Throwable th2) {
                if (0 != 0) {
                    aggregatorModel.unlock(Aggregator.this._job);
                    Scope.untrack(Collections.singletonList(aggregatorModel._exemplar_assignment_vec_key));
                    Frame frame2 = ((AggregatorModel.AggregatorOutput) aggregatorModel._output)._output_frame != null ? (Frame) ((AggregatorModel.AggregatorOutput) aggregatorModel._output)._output_frame.get() : null;
                    if (frame2 != null) {
                        Scope.untrack(frame2.keysList());
                    }
                }
                if (0 != 0) {
                    dataInfo.remove();
                }
                throw th2;
            }
        }

        static {
            $assertionsDisabled = !Aggregator.class.desiredAssertionStatus();
        }
    }

    /**
     * A representative row: its numeric values and categorical bin ids, plus the number of rows it
     * stands for.
     */
    public static class Exemplar extends Iced<Exemplar> {
        /** Numeric values of the representative row. */
        final double[] data;
        /** Categorical bin ids of the representative row. */
        final int[] cats;
        /** Group id: the global row index of the row that created this exemplar. */
        final long gid;
        /** Number of rows represented; starts at 1 for the creating row. */
        long _cnt = 1;

        Exemplar(double[] data, int[] cats, long gid) {
            this.data = data;
            this.cats = cats;
            this.gid = gid;
        }

        /**
         * Stores {@code e} in the slot after the last non-null entry. When there is no free slot,
         * the array is first doubled in size.
         *
         * @return the array holding {@code e} (the input array, or a new larger copy)
         */
        public static Exemplar[] addExemplar(Exemplar[] es, Exemplar e) {
            if (es.length == 0) {
                return new Exemplar[]{e};
            }
            int last = es.length - 1;
            while (last >= 0 && es[last] == null) {
                last--;
            }
            if (last != es.length - 1) {
                es[last + 1] = e;
                return es;
            }
            Exemplar[] grown = Arrays.copyOf(es, es.length << 1);
            grown[es.length] = e;
            return grown;
        }

        /** Returns a copy of {@code es} without its trailing null slots. */
        public static Exemplar[] trim(Exemplar[] es) {
            int last = es.length - 1;
            while (last >= 0 && null == es[last]) {
                last--;
            }
            return Arrays.copyOf(es, last + 1);
        }

        /**
         * Squared Euclidean distance between this exemplar's numeric values and {@code other}.
         *
         * <p>Only columns where neither value is NaN are summed. The sum is then rescaled by
         * {@code ncols / nonMissing} to compensate for the skipped columns. While no missing value
         * has been seen, the loop stops once the partial sum exceeds {@code threshold}.
         *
         * <p>With no numeric columns at all the distance is 0, so rows are separated only by their
         * categorical bins (previously this evaluated to NaN and never matched). If every column is
         * missing on at least one side, the result is NaN, which never compares below a threshold.
         */
        public double squaredEuclideanDistance(double[] other, double threshold) {
            final double[] mine = this.data;
            if (mine.length == 0) {
                return 0.0d;
            }
            final double ncols = mine.length;
            double sum = 0.0d;
            int nonMissing = 0;
            boolean missing = false;
            for (int j = 0; j < mine.length; j++) {
                final double a = mine[j];
                final double b = other[j];
                if (isMissing(a) || isMissing(b)) {
                    missing = true;
                } else {
                    final double diff = a - b;
                    sum += diff * diff;
                    nonMissing++;
                }
                if (!missing && sum > threshold) {
                    break; // early cutout: cannot beat the current best
                }
            }
            return sum * (ncols / nonMissing);
        }

        private static boolean isMissing(double d) {
            return Double.isNaN(d);
        }
    }

    /* loaded from: input_file:hex/aggregator/Aggregator$RenumberTask.class */
    private static class RenumberTask extends MRTask<RenumberTask> {
        final long[][] _map;

        public RenumberTask(AggregateTask.GIDMapping gIDMapping) {
            this._map = gIDMapping.unsortedList();
        }

        public void map(Chunk chunk) {
            for (int i = 0; i < chunk._len; i++) {
                int find = ArrayUtils.find(this._map[0], chunk.at8(i));
                if (find >= 0) {
                    chunk.set(i, this._map[1][find]);
                }
            }
        }
    }

    /** Returns the converter from {@code LinearAlgebraUtils.toEigen}. */
    @Override
    public ToEigenVec getToEigenVec() {
        final ToEigenVec converter = LinearAlgebraUtils.toEigen;
        return converter;
    }

    /** Reports this builder as {@code Stable}. */
    @Override
    public ModelBuilder.BuilderVisibility builderVisibility() {
        final ModelBuilder.BuilderVisibility visibility = ModelBuilder.BuilderVisibility.Stable;
        return visibility;
    }

    /** Aggregation is unsupervised; this always returns {@code false}. */
    @Override
    public boolean isSupervised() {
        final boolean supervised = false;
        return supervised;
    }

    /**
     * Creates the driver that runs the aggregation job.
     *
     * <p>This restores the {@code trainModelImpl()} override of ModelBuilder's driver-factory hook.
     * The decompiled source only had the renamed alias below.
     * NOTE(review): assumes ModelBuilder declares {@code protected Driver trainModelImpl()}.
     */
    @Override
    protected AggregatorDriver trainModelImpl() {
        return new AggregatorDriver();
    }

    /** Alias of {@link #trainModelImpl()}, kept so existing callers of this name still work. */
    public AggregatorDriver m6trainModelImpl() {
        return trainModelImpl();
    }

    /** The only model category this builder produces is {@code Clustering}. */
    @Override
    public ModelCategory[] can_build() {
        final ModelCategory[] categories = {ModelCategory.Clustering};
        return categories;
    }

    /**
     * Creates an aggregator for the given parameters and immediately runs the non-expensive
     * validation pass ({@code init(false)}), which throws if any parameter is invalid.
     *
     * @param parms aggregator parameters
     */
    public Aggregator(AggregatorModel.AggregatorParameters parms) {
        super(parms);
        this.init(false);
    }

    /**
     * Creates an aggregator with default parameters.
     *
     * @param startupOnce passed straight through to ModelBuilder's two-argument constructor
     */
    public Aggregator(boolean startupOnce) {
        super(new AggregatorModel.AggregatorParameters(), startupOnce);
    }

    /**
     * Validates the aggregator parameters and then delegates to {@code ModelBuilder.init}.
     *
     * @param expensive when true, an {@code AUTO} categorical encoding is replaced by {@code Eigen}
     * @throws H2OModelBuilderIllegalArgumentException if any validation error was recorded,
     *         including errors reported by the superclass
     */
    @Override
    public void init(boolean expensive) {
        final AggregatorModel.AggregatorParameters p = (AggregatorModel.AggregatorParameters) this._parms;
        if (expensive) {
            if (p._categorical_encoding == Model.Parameters.CategoricalEncodingScheme.AUTO) {
                p._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.Eigen;
            }
        }
        if (p._target_num_exemplars <= 0) {
            error("_target_num_exemplars", "target_num_exemplars must be > 0.");
        }
        // Kept as an explicit <=/>= test: a NaN tolerance deliberately passes this check.
        final double tol = p._rel_tol_num_exemplars;
        if (tol <= 0.0d || tol >= 1.0d) {
            error("_rel_tol_num_exemplars", "rel_tol_num_exemplars must be inside 0...1.");
        }
        super.init(expensive);
        if (error_count() > 0) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
        }
    }
}
