package hex.aggregator;

import hex.DataInfo;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.aggregator.AggregatorModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import water.AutoBuffer;
import water.DKV;
import water.Iced;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;

/* loaded from: input_file:hex/aggregator/Aggregator.class */
public class Aggregator extends ModelBuilder<AggregatorModel, AggregatorModel.AggregatorParameters, AggregatorModel.AggregatorOutput> {

    /* loaded from: input_file:hex/aggregator/Aggregator$AggregateTask.class */
    private static class AggregateTask extends MRTask<AggregateTask> {
        // Squared distance threshold: the constructor stores the square of its distance argument, and a
        // point joins an exemplar only when squaredEuclideanDistance(...) is strictly below this value.
        final double _delta;
        // Key of the DataInfo fetched in map() to extract dense numeric rows.
        final Key _dataInfoKey;
        // Key of the job whose update(1, "Aggregating.") is called at the end of every map() and reduce().
        final Key _jobKey;
        // Output: exemplars found so far; index-aligned with _counts.
        Exemplar[] _exemplars;
        // Output: _counts[k] is the number of rows represented by _exemplars[k].
        long[] _counts;
        // Output: gid renames (merged-away exemplar gid -> surviving exemplar gid) recorded by reduce().
        GIDMapping _mapping;
        // Explicit assertion flag (decompiler form of javac's synthetic field), set in the static block from
        // Aggregator's desired assertion status. NOTE(review): keep this name and the explicit checks —
        // using `assert` statements in this class would make javac synthesize a conflicting field.
        static final /* synthetic */ boolean $assertionsDisabled;

        /**
         * Growable list of {@code (first, second)} gid pairs meaning "exemplar gid {@code first} was merged
         * into exemplar gid {@code second}".
         *
         * <p>Field names are read directly by {@code AggregateTask.reduce} and serialized as part of the task,
         * so they are kept as-is. (Decompiler note: this class was private in the original build.)
         */
        public static class GIDMapping extends Iced<GIDMapping> {
            int capacity = 32;
            int len = 0;
            MyPair[] pairSet = new MyPair[this.capacity];

            /**
             * Records the rename {@code from -> to}.
             *
             * <p>Every already-recorded pair whose target is {@code from} is redirected to {@code to} first, so
             * earlier renames that pointed at the now merged-away gid follow it to its new home. The backing
             * array doubles when full.
             */
            void set(long from, long to) {
                int k = 0;
                while (k < len) {
                    final MyPair existing = pairSet[k];
                    if (existing.second == from) {
                        existing.second = to;
                    }
                    k++;
                }
                if (len == capacity) {
                    capacity *= 2;
                    pairSet = Arrays.copyOf(pairSet, capacity);
                }
                pairSet[len++] = new MyPair(from, to);
            }

            /**
             * Returns the recorded pairs as two parallel arrays in insertion order: row 0 holds the
             * {@code first} gids, row 1 the matching {@code second} gids.
             */
            long[][] unsortedList() {
                final long[] firsts = new long[len];
                final long[] seconds = new long[len];
                for (int k = 0; k < len; k++) {
                    final MyPair pair = pairSet[k];
                    firsts[k] = pair.first;
                    seconds[k] = pair.second;
                }
                return new long[][]{firsts, seconds};
            }
        }

        /**
         * Mutable {@code (first, second)} pair of longs, ordered by {@code first} only.
         *
         * <p>NOTE(review): {@code compareTo} ignores {@code second} and {@code equals} is not overridden, so the
         * ordering is not consistent with equals. (Decompiler note: originally package-private.)
         */
        public static class MyPair extends Iced<MyPair> implements Comparable<MyPair> {
            long first;
            long second;

            public MyPair(long first, long second) {
                this.first = first;
                this.second = second;
            }

            /** No-arg constructor; presumably needed for Iced deserialization — confirm. */
            public MyPair() {
            }

            /** Orders pairs by {@code first}: returns -1, 0 or 1. */
            @Override
            public int compareTo(MyPair other) {
                return Long.compare(first, other.first);
            }
        }

        /**
         * @param dataInfoKey key of the DataInfo used to extract dense numeric rows
         * @param radius      distance threshold; stored squared so it can be compared with squared distances
         * @param jobKey      key of the job updated after each map/reduce step
         */
        public AggregateTask(Key<DataInfo> dataInfoKey, double radius, Key<Job> jobKey) {
            this._jobKey = jobKey;
            this._dataInfoKey = dataInfoKey;
            this._delta = radius * radius;
        }

        /**
         * Greedy single-pass aggregation of one chunk.
         *
         * <p>The last chunk of {@code chunkArr} receives each row's exemplar gid; the preceding chunks are the
         * data columns. Each row becomes a dense numeric point and is compared with the exemplars found so far
         * in this chunk. The first exemplar (in list order) whose squared distance is below {@code _delta}
         * absorbs the row: its count grows and the row is assigned its gid. Otherwise the row becomes a new
         * exemplar whose gid is {@code chunk start() + row offset}. The first row always starts a new exemplar.
         *
         * <p>Results: {@code _exemplars} / {@code _counts} (index-aligned) and a fresh, empty {@code _mapping}.
         */
        @Override
        public void map(Chunk[] chunkArr) {
            this._mapping = new GIDMapping();
            final Chunk assignments = chunkArr[chunkArr.length - 1];
            final Chunk[] dataChunks = Arrays.copyOf(chunkArr, chunkArr.length - 1);
            final DataInfo dinfo = (DataInfo) this._dataInfoKey.get();
            if (!$assertionsDisabled && dinfo == null) {
                throw new AssertionError();
            }
            DataInfo.Row row = dinfo.newDenseRow();
            final int nNums = row.nNums;

            final ArrayList<Exemplar> exemplars = new ArrayList<>();
            long[] counts = new long[16]; // counts[k] = rows represented by exemplars.get(k)
            final Chunk first = chunkArr[0];
            final long chunkStart = first.start();
            final int nRows = first._len;

            for (int r = 0; r < nRows; r++) {
                final long rowGid = chunkStart + r;
                row = dinfo.extractDenseRow(dataChunks, r, row);
                // The Row object is handed back into extractDenseRow, so keep a private copy of the values.
                final double[] point = Arrays.copyOf(row.numVals, nNums);

                boolean joins = false;
                int nearest = 0;
                long nearestGid = -1L;
                if (!exemplars.isEmpty()) {
                    double best = Double.MAX_VALUE;
                    for (int k = 0; k < exemplars.size(); k++) {
                        final Exemplar candidate = exemplars.get(k);
                        final double dist = squaredEuclideanDistance(candidate.data, point, nNums, best);
                        if (dist < best) {
                            best = dist;
                            nearest = k;
                            nearestGid = candidate.gid;
                        }
                        if (best < this._delta) {
                            break; // first exemplar within the radius wins
                        }
                    }
                    joins = best < this._delta;
                }

                if (joins) {
                    counts[nearest]++;
                    assignments.set(r, nearestGid);
                } else {
                    final int slot = exemplars.size();
                    if (slot == counts.length) {
                        counts = Arrays.copyOf(counts, 2 * slot);
                    }
                    counts[slot] = 1L;
                    exemplars.add(new Exemplar(point, rowGid));
                    assignments.set(r, rowGid);
                }
            }

            this._exemplars = exemplars.toArray(new Exemplar[0]);
            this._counts = Arrays.copyOf(counts, exemplars.size());
            if (!$assertionsDisabled && this._exemplars.length > first.len()) {
                throw new AssertionError();
            }
            if (!$assertionsDisabled && this._counts.length != this._exemplars.length) {
                throw new AssertionError();
            }
            long total = 0L;
            for (long c : this._counts) {
                total += c;
            }
            if (!$assertionsDisabled && total > first.len()) {
                throw new AssertionError();
            }
            this._jobKey.get().update(1L, "Aggregating.");
        }

        /**
         * Folds {@code other}'s exemplars into this task's.
         *
         * <p>First, {@code other}'s pending gid renames are replayed into {@code _mapping}, so renames recorded
         * below also redirect them. Then each incoming exemplar is compared with every exemplar currently held,
         * including ones appended earlier in this call. If the nearest one is within {@code _delta}, the incoming
         * row count is added there and the rename {@code incoming gid -> surviving gid} is recorded. Otherwise a
         * deep copy of the incoming exemplar is appended with its count. {@code other}'s arrays are released
         * afterwards.
         *
         * <p>The surviving gid is read from the working list rather than the stale {@code this._exemplars}
         * array, which is only reassigned after the loop. A merge also requires a real nearest candidate: the
         * old code fell back to index 0, which throws when this side holds no exemplars.
         */
        @Override
        public void reduce(AggregateTask other) {
            for (int p = 0; p < other._mapping.len; p++) {
                final MyPair pair = other._mapping.pairSet[p];
                this._mapping.set(pair.first, pair.second);
            }
            final Exemplar[] incoming = other._exemplars;
            final long[] incomingCounts = other._counts;
            long rowsHere = 0L;
            for (long c : this._counts) {
                rowsHere += c;
            }
            long rowsThere = 0L;
            for (long c : incomingCounts) {
                rowsThere += c;
            }

            // Working copies, index-aligned: merged.get(k) represents mergedCounts[k] rows. Sized once so that
            // appending never re-copies the counts array.
            final ArrayList<Exemplar> merged = new ArrayList<>(this._exemplars.length + incoming.length);
            merged.addAll(Arrays.asList(this._exemplars));
            final long[] mergedCounts = Arrays.copyOf(this._counts, this._counts.length + incoming.length);

            for (int e = 0; e < incoming.length; e++) {
                final double[] point = incoming[e].data;
                double best = Double.MAX_VALUE;
                int nearest = -1;
                for (int k = 0; k < merged.size(); k++) {
                    final double dist = squaredEuclideanDistance(merged.get(k).data, point, point.length, best);
                    if (dist < best) {
                        best = dist;
                        nearest = k;
                    }
                    if (best < this._delta) {
                        break; // first exemplar within the radius wins
                    }
                }
                if (nearest >= 0 && best < this._delta) {
                    mergedCounts[nearest] += incomingCounts[e];
                    this._mapping.set(incoming[e].gid, merged.get(nearest).gid);
                } else {
                    mergedCounts[merged.size()] = incomingCounts[e];
                    merged.add(incoming[e].deepClone());
                }
            }

            this._exemplars = merged.toArray(new Exemplar[0]);
            this._counts = Arrays.copyOf(mergedCounts, merged.size());
            other._exemplars = null;
            other._counts = null;
            if (!$assertionsDisabled && this._exemplars.length > rowsHere + rowsThere) {
                throw new AssertionError();
            }
            if (!$assertionsDisabled && this._counts.length != this._exemplars.length) {
                throw new AssertionError();
            }
            long total = 0L;
            for (long c : this._counts) {
                total += c;
            }
            if (!$assertionsDisabled && total != rowsHere + rowsThere) {
                throw new AssertionError();
            }
            this._jobKey.get().update(1L, "Aggregating.");
        }

        /**
         * Squared Euclidean distance over the first {@code n} coordinates of {@code a} and {@code b}.
         *
         * <p>Coordinates where either side is missing (NaN) are skipped. The partial sum is then rescaled by
         * {@code n / compared}, so points with missing values stay comparable to complete ones. The ratio is
         * computed in floating point: the old {@code i / i2} was integer division, which truncated the rescale
         * and threw {@code ArithmeticException} when nothing was comparable.
         *
         * <p>While no missing value has been seen, the scan stops once the running sum exceeds
         * {@code threshold}. The returned value is then only guaranteed to be greater than {@code threshold},
         * which is all callers need, since they only test {@code result < best}.
         *
         * @param a         first point
         * @param b         second point
         * @param n         number of leading coordinates to compare
         * @param threshold early-exit bound (the caller's current best distance)
         * @return the rescaled squared distance. This is {@code 0.0} when {@code n == 0} (empty vectors are
         *         identical). It is {@code Double.POSITIVE_INFINITY} when every coordinate is missing on at least
         *         one side, since such points cannot be shown to be close.
         */
        private static double squaredEuclideanDistance(double[] a, double[] b, int n, double threshold) {
            double sum = 0.0d;
            int compared = 0;
            boolean sawMissing = false;
            for (int c = 0; c < n; c++) {
                final double x = a[c];
                final double y = b[c];
                if (isMissing(x) || isMissing(y)) {
                    sawMissing = true;
                } else {
                    final double diff = x - y;
                    sum += diff * diff;
                    compared++;
                }
                if (!sawMissing && sum > threshold) {
                    break; // early cut-out: already worse than the caller's best
                }
            }
            if (compared == 0) {
                return n == 0 ? 0.0d : Double.POSITIVE_INFINITY;
            }
            // Exactly 1.0 when nothing was skipped, so complete rows are not perturbed.
            return sum * ((double) n / compared);
        }

        /** Treats NaN as a missing value. */
        private static boolean isMissing(double value) {
            return Double.isNaN(value);
        }

        static {
            $assertionsDisabled = !Aggregator.class.desiredAssertionStatus();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:hex/aggregator/Aggregator$AggregatorDriver.class */
    public class AggregatorDriver extends ModelBuilder<AggregatorModel, AggregatorModel.AggregatorParameters, AggregatorModel.AggregatorOutput>.Driver {
        AggregatorDriver() {
            super(Aggregator.this);
        }

        public void compute2() {
            AggregatorModel aggregatorModel = null;
            DataInfo dataInfo = null;
            try {
                Aggregator.this.init(true);
                ((AggregatorModel.AggregatorParameters) Aggregator.this._parms).read_lock_frames(Aggregator.this._job);
                if (Aggregator.this.error_count() > 0) {
                    throw new IllegalArgumentException("Found validation errors: " + Aggregator.this.validationErrors());
                }
                AggregatorModel aggregatorModel2 = new AggregatorModel(Aggregator.this.dest(), (AggregatorModel.AggregatorParameters) Aggregator.this._parms, new AggregatorModel.AggregatorOutput(Aggregator.this));
                aggregatorModel2.delete_and_lock(Aggregator.this._job);
                Frame train = Aggregator.this.train();
                Aggregator.this._job.update(1L, "Preprocessing data.");
                DataInfo dataInfo2 = new DataInfo(train, null, true, ((AggregatorModel.AggregatorParameters) Aggregator.this._parms)._transform, false, false, false);
                DKV.put(dataInfo2);
                double pow = (((AggregatorModel.AggregatorParameters) Aggregator.this._parms)._radius_scale * 0.1d) / Math.pow(Math.log(train.numRows()), 1.0d / train.numCols());
                Vec[] vecArr = (Vec[]) Arrays.copyOf(train.vecs(), train.vecs().length + 1);
                int length = vecArr.length - 1;
                Vec makeZero = train.anyVec().makeZero();
                vecArr[length] = makeZero;
                Aggregator.this._job.update(1L, "Aggregating.");
                AggregateTask aggregateTask = (AggregateTask) new AggregateTask(dataInfo2._key, pow, Aggregator.this._job._key).doAll(vecArr);
                Aggregator.this._job.update(1L, "Aggregating exemplar assignments.");
                new RenumberTask(aggregateTask._mapping).doAll(new Vec[]{makeZero});
                aggregatorModel2._exemplars = aggregateTask._exemplars;
                aggregatorModel2._counts = aggregateTask._counts;
                aggregatorModel2._exemplar_assignment_vec_key = makeZero._key;
                ((AggregatorModel.AggregatorOutput) aggregatorModel2._output)._output_frame = Key.make("aggregated_" + ((AggregatorModel.AggregatorParameters) Aggregator.this._parms)._train.toString() + "_by_" + aggregatorModel2._key);
                Aggregator.this._job.update(1L, "Creating output frame.");
                aggregatorModel2.createFrameOfExemplars(((AggregatorModel.AggregatorOutput) aggregatorModel2._output)._output_frame);
                Aggregator.this._job.update(1L, "Done.");
                aggregatorModel2.update(Aggregator.this._job);
                ((AggregatorModel.AggregatorParameters) Aggregator.this._parms).read_unlock_frames(Aggregator.this._job);
                if (aggregatorModel2 != null) {
                    aggregatorModel2.unlock(Aggregator.this._job);
                }
                if (dataInfo2 != null) {
                    dataInfo2.remove();
                }
                tryComplete();
            } catch (Throwable th) {
                ((AggregatorModel.AggregatorParameters) Aggregator.this._parms).read_unlock_frames(Aggregator.this._job);
                if (0 != 0) {
                    aggregatorModel.unlock(Aggregator.this._job);
                }
                if (0 != 0) {
                    dataInfo.remove();
                }
                throw th;
            }
        }
    }

    /**
     * A representative point: its numeric coordinates ({@code data}) plus an id ({@code gid}).
     * AggregateTask uses the id of the row that created the exemplar as its gid.
     */
    public static class Exemplar extends Iced<Exemplar> {
        final double[] data;
        final long gid;

        Exemplar(double[] coords, long id) {
            this.gid = id;
            this.data = coords;
        }

        /** Returns an independent copy, made by writing this exemplar to an AutoBuffer and reading it back. */
        Exemplar deepClone() {
            final AutoBuffer roundTrip = new AutoBuffer().put(this).flipForReading();
            return roundTrip.get();
        }
    }

    /* loaded from: input_file:hex/aggregator/Aggregator$RenumberTask.class */
    private static class RenumberTask extends MRTask<RenumberTask> {
        final long[][] _map;

        public RenumberTask(AggregateTask.GIDMapping gIDMapping) {
            this._map = gIDMapping.unsortedList();
        }

        public void map(Chunk chunk) {
            for (int i = 0; i < chunk._len; i++) {
                int find = ArrayUtils.find(this._map[0], chunk.at8(i));
                if (find >= 0) {
                    chunk.set(i, this._map[1][find]);
                }
            }
        }
    }

    /**
     * Creates the driver that runs one Aggregator build.
     *
     * <p>This name was assigned by the decompiler when it merged the original {@code trainModelImpl} with its
     * bridge method. It is kept for source compatibility; {@link #trainModelImpl()} is the actual override.
     */
    public AggregatorDriver m5trainModelImpl() {
        return new AggregatorDriver();
    }

    /**
     * Restores the {@code trainModelImpl} override that the decompiler renamed away. Without it, the
     * superclass's factory method is not implemented. NOTE(review): assumes ModelBuilder declares
     * {@code protected Driver trainModelImpl()}, as the bridge-method note and original protected modifier
     * suggest — confirm.
     */
    @Override
    protected AggregatorDriver trainModelImpl() {
        return m5trainModelImpl();
    }

    /** Model categories this builder produces: clustering only. A fresh array is returned on each call. */
    @Override
    public ModelCategory[] can_build() {
        final ModelCategory[] categories = {ModelCategory.Clustering};
        return categories;
    }

    /**
     * Creates a builder for {@code parms} and immediately runs {@code init(false)}.
     *
     * @param parms the Aggregator parameters, passed to the ModelBuilder constructor
     */
    public Aggregator(AggregatorModel.AggregatorParameters parms) {
        super(parms);
        init(false);
    }

    /**
     * Creates a builder with default {@link AggregatorModel.AggregatorParameters}.
     *
     * @param startupOnce forwarded unchanged to the ModelBuilder constructor. NOTE(review): the name follows
     *                    the usual ModelBuilder convention; confirm its exact meaning there.
     */
    public Aggregator(boolean startupOnce) {
        super(new AggregatorModel.AggregatorParameters(), startupOnce);
    }

    /**
     * Initializes and validates the builder. It adds nothing beyond ModelBuilder's own {@code init}.
     *
     * @param expensive forwarded unchanged to {@code super.init}. The constructor passes {@code false} and
     *                  the driver passes {@code true}.
     */
    @Override
    public void init(boolean expensive) {
        super.init(expensive);
    }
}
