package org.apache.mahout.clustering.streaming.cluster;

import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.Random;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
import org.apache.mahout.math.random.WeightedThing;

/* loaded from: input_file:org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.class */
/**
 * Incrementally clusters a stream of {@link Centroid} points into an {@link UpdatableSearcher}.
 *
 * <p>Each incoming point either becomes a new centroid or is merged into the closest existing
 * one; the choice is random, with a probability driven by the point's weight, the weight of the
 * closest match returned by the searcher and the current distance cutoff. When the searcher
 * holds more centroids than {@code clusterOvershoot * numClusters}, the centroids themselves are
 * shuffled and re-clustered, and the distance cutoff is multiplied by {@code beta} if there are
 * still more centroids than the cluster estimate.
 *
 * <p>Iterating over this object yields the centroids currently held by the searcher.
 */
public class StreamingKMeans implements Iterable<Centroid> {
    /** Value used for {@code beta} by the short constructors. */
    private static final double DEFAULT_BETA = 1.3d;

    /** Value used for {@code clusterLogFactor} by the short constructors. */
    private static final double DEFAULT_CLUSTER_LOG_FACTOR = 20.0d;

    /** Value used for {@code clusterOvershoot} by the short constructors. */
    private static final double DEFAULT_CLUSTER_OVERSHOOT = 2.0d;

    /**
     * Second argument handed to {@link UpdatableSearcher#remove} when a centroid is taken out of
     * the searcher before being updated; presumably a matching tolerance - TODO confirm against
     * the searcher's contract.
     */
    private static final double REMOVE_EPSILON = 1.0E-6d;

    /** Searcher that stores the current centroids; it is mutated in place and returned to callers. */
    private final UpdatableSearcher centroids;

    /** Current estimate of the number of clusters; it only ever grows while clustering. */
    private int numClusters;

    /** Number of points consumed so far; feeds the logarithmic growth of {@link #numClusters}. */
    private int numProcessedDatapoints;

    /** Divisor of the new-centroid probability; multiplied by {@link #beta} after a collapse that leaves too many centroids. */
    private double distanceCutoff;

    /** Factor by which {@link #distanceCutoff} is multiplied. */
    private final double beta;

    /** Multiplier of {@code log(numProcessedDatapoints)} when re-estimating {@link #numClusters}. */
    private final double clusterLogFactor;

    /** Multiplier of {@link #numClusters} giving the centroid count above which a collapse is triggered. */
    private final double clusterOvershoot;

    /** Source of the samples that decide between creating a centroid and merging into one. */
    private final Random random;

    /**
     * Creates a clusterer with an initial distance cutoff of {@code 1.0 / numClusters} and the
     * default {@code beta} (1.3), {@code clusterLogFactor} (20) and {@code clusterOvershoot} (2).
     *
     * @param searcher    searcher that will hold the centroids
     * @param numClusters initial estimate of the number of clusters
     */
    public StreamingKMeans(UpdatableSearcher searcher, int numClusters) {
        this(searcher, numClusters, 1.0d / numClusters);
    }

    /**
     * Creates a clusterer with the given initial distance cutoff and the default {@code beta}
     * (1.3), {@code clusterLogFactor} (20) and {@code clusterOvershoot} (2).
     *
     * @param searcher       searcher that will hold the centroids
     * @param numClusters    initial estimate of the number of clusters
     * @param distanceCutoff initial distance cutoff
     */
    public StreamingKMeans(UpdatableSearcher searcher, int numClusters, double distanceCutoff) {
        this(searcher, numClusters, distanceCutoff,
                DEFAULT_BETA, DEFAULT_CLUSTER_LOG_FACTOR, DEFAULT_CLUSTER_OVERSHOOT);
    }

    /**
     * Creates a fully parameterized clusterer.
     *
     * @param searcher         searcher that will hold the centroids
     * @param numClusters      initial estimate of the number of clusters
     * @param distanceCutoff   initial distance cutoff
     * @param beta             factor by which the distance cutoff is multiplied when a collapse
     *                         still leaves more centroids than the cluster estimate
     * @param clusterLogFactor multiplier of {@code log(points processed)} used to grow the
     *                         cluster estimate
     * @param clusterOvershoot multiplier of the cluster estimate above which the centroids are
     *                         collapsed
     */
    public StreamingKMeans(UpdatableSearcher searcher, int numClusters, double distanceCutoff,
                           double beta, double clusterLogFactor, double clusterOvershoot) {
        this.numProcessedDatapoints = 0;
        this.random = RandomUtils.getRandom();
        this.centroids = searcher;
        this.numClusters = numClusters;
        this.distanceCutoff = distanceCutoff;
        this.beta = beta;
        this.clusterLogFactor = clusterLogFactor;
        this.clusterOvershoot = clusterOvershoot;
    }

    /**
     * Returns an iterator over the centroids currently stored in the searcher.
     *
     * <p>Every vector coming out of the searcher is cast to {@link Centroid}; this class only
     * ever adds {@code Centroid} instances to it.
     *
     * @return iterator over the current centroids
     */
    @Override // java.lang.Iterable
    public Iterator<Centroid> iterator() {
        return Iterators.transform(this.centroids.iterator(), new Function<Vector, Centroid>() {
            @Override // com.google.common.base.Function
            public Centroid apply(Vector vector) {
                return (Centroid) vector;
            }
        });
    }

    /**
     * Clusters the slices of a matrix, turning each slice into a centroid built from the slice's
     * index and vector.
     *
     * @param matrix matrix whose slices are the points to cluster
     * @return the searcher holding the resulting centroids
     */
    public UpdatableSearcher cluster(Matrix matrix) {
        return cluster(Iterables.transform(matrix, new Function<MatrixSlice, Centroid>() {
            @Override // com.google.common.base.Function
            public Centroid apply(MatrixSlice slice) {
                return Centroid.create(slice.index(), slice.vector());
            }
        }));
    }

    /**
     * Clusters the given points, adding to whatever centroids the searcher already holds.
     *
     * @param iterable points to cluster
     * @return the searcher holding the resulting centroids
     */
    public UpdatableSearcher cluster(Iterable<Centroid> iterable) {
        return clusterInternal(iterable, false);
    }

    /**
     * Clusters a single point.
     *
     * @param centroid point to cluster
     * @return the searcher holding the resulting centroids
     */
    public UpdatableSearcher cluster(Centroid centroid) {
        // A singleton list gives a contract-abiding one-element iterator: next() past the end
        // throws NoSuchElementException and remove() is unsupported.
        Iterable<Centroid> single = Collections.singletonList(centroid);
        return cluster(single);
    }

    /**
     * Returns the number of centroids currently stored in the searcher. This is the searcher's
     * size, not the internal cluster estimate passed to the constructor.
     *
     * @return current number of centroids
     */
    public int getNumClusters() {
        return this.centroids.size();
    }

    /**
     * Core clustering loop shared by normal clustering and by the collapse step.
     *
     * @param datapoints       points to cluster
     * @param collapseClusters {@code true} when re-clustering the existing centroids: the
     *                         searcher is cleared first, no nested collapse is triggered, and the
     *                         processed-point counter is restored afterwards
     * @return the searcher holding the resulting centroids
     * @throws RuntimeException if the closest centroid cannot be removed from the searcher
     */
    private UpdatableSearcher clusterInternal(Iterable<Centroid> datapoints, boolean collapseClusters) {
        Iterator<Centroid> points = datapoints.iterator();
        if (!points.hasNext()) {
            return this.centroids;
        }

        int savedNumProcessed = this.numProcessedDatapoints;
        if (collapseClusters) {
            // The old centroids are the input of this pass, so start from an empty searcher.
            this.centroids.clear();
            this.numProcessedDatapoints = 0;
        }

        if (this.centroids.size() == 0) {
            // The first point seeds the first centroid. A copy is stored because stored
            // centroids are mutated later through update().
            this.centroids.add(points.next().mo1229clone());
            this.numProcessedDatapoints++;
        }

        while (points.hasNext()) {
            Centroid point = points.next();
            WeightedThing<Vector> closest = this.centroids.searchFirst(point, false);

            // Draw first, then compare against (point weight * match weight) / cutoff: the larger
            // that ratio, the more likely the point starts a centroid of its own.
            double sample = this.random.nextDouble();
            if (sample < (point.getWeight() * closest.getWeight()) / this.distanceCutoff) {
                this.centroids.add(point.mo1229clone());
            } else {
                // Merge into the closest centroid. It is removed before being mutated and
                // re-added afterwards so the searcher never holds a vector that changed under it.
                Centroid nearest = (Centroid) closest.getValue();
                if (!this.centroids.remove(nearest, REMOVE_EPSILON)) {
                    throw new RuntimeException("Unable to remove centroid");
                }
                nearest.update(point);
                this.centroids.add(nearest);
            }
            this.numProcessedDatapoints++;

            if (!collapseClusters && this.centroids.size() > this.clusterOvershoot * this.numClusters) {
                // Too many centroids: grow the estimate logarithmically with the number of
                // points seen, then re-cluster the centroids themselves.
                this.numClusters = (int) Math.max(this.numClusters,
                        this.clusterLogFactor * Math.log(this.numProcessedDatapoints));

                // Snapshot the centroids first; the recursive call clears the searcher.
                ArrayList<Centroid> shuffled = Lists.newArrayList();
                for (Vector vector : this.centroids) {
                    shuffled.add((Centroid) vector);
                }
                // NOTE(review): this overload uses Collections' default source of randomness,
                // not this.random, so the shuffle order is not tied to the RandomUtils generator.
                Collections.shuffle(shuffled);
                clusterInternal(shuffled, true);

                if (this.centroids.size() > this.numClusters) {
                    this.distanceCutoff *= this.beta;
                }
            }
        }

        if (collapseClusters) {
            // Re-clustering centroids must not count as having processed new points.
            this.numProcessedDatapoints = savedNumProcessed;
        }
        return this.centroids;
    }

    /**
     * Renumbers the current centroids 0, 1, 2, ... in iteration order.
     */
    public void reindexCentroids() {
        int index = 0;
        for (Centroid centroid : this) {
            centroid.setIndex(index++);
        }
    }

    /**
     * Returns the current distance cutoff.
     *
     * @return current distance cutoff
     */
    public double getDistanceCutoff() {
        return this.distanceCutoff;
    }

    /**
     * Overrides the current distance cutoff.
     *
     * @param d new distance cutoff
     */
    public void setDistanceCutoff(double d) {
        this.distanceCutoff = d;
    }

    /**
     * Returns the distance measure of the underlying searcher.
     *
     * @return the searcher's distance measure
     */
    public DistanceMeasure getDistanceMeasure() {
        return this.centroids.getDistanceMeasure();
    }
}
