/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.cluster;

import cc.mallet.cluster.Clusterer;
import cc.mallet.cluster.Clustering;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Metric;
import cc.mallet.types.SparseVector;
import cc.mallet.util.VectorStats;
import java.util.ArrayList;
import java.util.Random;
import java.util.logging.Logger;

/**
 * K-means clusterer over instances whose data are {@link SparseVector}s.
 *
 * <p>Initial centroids are chosen by {@link #initializeMeansSample}; afterwards
 * the algorithm alternates between assigning every instance to its nearest
 * centroid (according to {@link #metric}) and recomputing each centroid as the
 * mean of its members. Iteration stops when the summed centroid movement drops
 * to {@link #MEANS_TOLERANCE}, when {@link #MAX_ITER} iterations have run, or
 * when the number of instances that changed cluster drops to
 * {@code instances.size() * POINTS_TOLERANCE}.
 *
 * <p>What happens when a cluster ends up with no members is controlled by the
 * {@code emptyAction} constructor argument ({@link #EMPTY_ERROR},
 * {@link #EMPTY_DROP} or {@link #EMPTY_SINGLE}).
 */
public class KMeans extends Clusterer {
    private static final long serialVersionUID = 1L;

    /** Iteration stops once the summed centroid displacement is at or below this value. */
    static double MEANS_TOLERANCE = 0.01;
    /** Upper bound on the number of assign/update iterations. */
    static int MAX_ITER = 100;
    /** Iteration stops once at most this fraction of the instances changed cluster. */
    static double POINTS_TOLERANCE = 0.005;

    /** Empty-cluster policy: {@link #cluster} gives up and returns {@code null}. */
    public static final int EMPTY_ERROR = 0;
    /** Empty-cluster policy: the empty cluster is removed and higher labels are shifted down. */
    public static final int EMPTY_DROP = 1;
    /**
     * Empty-cluster policy: the empty cluster's centroid is replaced by the
     * instance lying farthest from the centroid of the cluster it was assigned to.
     */
    public static final int EMPTY_SINGLE = 2;

    /** Created in the constructor; not read anywhere in this class. */
    Random randinator;
    Metric metric;
    /** Current number of clusters; decremented when {@link #EMPTY_DROP} removes one. */
    int numClusters;
    int emptyAction;
    ArrayList<SparseVector> clusterMeans;
    private static Logger logger = Logger.getLogger("edu.umass.cs.mallet.base.cluster.KMeans");

    /**
     * @param instancePipe pipe passed to the {@code Clusterer} superclass and used for the
     *                     per-cluster instance lists
     * @param numClusters  number of clusters to look for
     * @param metric       distance function between two sparse vectors
     * @param emptyAction  one of {@link #EMPTY_ERROR}, {@link #EMPTY_DROP}, {@link #EMPTY_SINGLE};
     *                     any other value makes {@link #cluster} return {@code null} as soon as
     *                     an empty cluster is encountered
     */
    public KMeans(Pipe instancePipe, int numClusters, Metric metric, int emptyAction) {
        super(instancePipe);
        this.emptyAction = emptyAction;
        this.metric = metric;
        this.numClusters = numClusters;
        this.clusterMeans = new ArrayList<SparseVector>(numClusters);
        this.randinator = new Random();
    }

    /** Same as the four-argument constructor with the {@link #EMPTY_ERROR} policy. */
    public KMeans(Pipe instancePipe, int numClusters, Metric metric) {
        this(instancePipe, numClusters, metric, EMPTY_ERROR);
    }

    /**
     * Clusters {@code instances} with k-means.
     *
     * @param instances instances to cluster; each instance's data must be a {@link SparseVector}
     * @return the resulting clustering, or {@code null} when an empty cluster is found and the
     *         policy is {@link #EMPTY_ERROR} (or unrecognised), or when {@link #EMPTY_SINGLE}
     *         cannot find an instance to re-seed the empty cluster with
     */
    @Override
    public Clustering cluster(InstanceList instances) {
        assert (instances.getPipe() == this.instancePipe);
        this.initializeMeansSample(instances, this.metric);

        // Labels start at 0, so an instance assigned to cluster 0 in the first
        // pass is not counted as having moved.
        int[] clusterLabels = new int[instances.size()];
        ArrayList<InstanceList> instanceClusters = new ArrayList<InstanceList>(this.numClusters);
        double deltaMeans = Double.MAX_VALUE;
        double deltaPoints = instances.size();
        int iterations = 0;
        for (int c = 0; c < this.numClusters; ++c) {
            instanceClusters.add(new InstanceList(this.instancePipe));
        }

        logger.info("Entering KMeans iteration");
        while (deltaMeans > MEANS_TOLERANCE && iterations < MAX_ITER && deltaPoints > (double)instances.size() * POINTS_TOLERANCE) {
            ++iterations;

            // Assignment step: put every instance into the cluster of its nearest centroid.
            deltaPoints = 0.0;
            for (int n = 0; n < instances.size(); ++n) {
                Instance instance = (Instance)instances.get(n);
                SparseVector data = (SparseVector)instance.getData();
                int nearest = 0;
                double nearestDist = Double.MAX_VALUE;
                for (int c = 0; c < this.numClusters; ++c) {
                    double dist = this.metric.distance(this.clusterMeans.get(c), data);
                    if (dist < nearestDist) {
                        nearest = c;
                        nearestDist = dist;
                    }
                }
                instanceClusters.get(nearest).add(instance);
                if (clusterLabels[n] != nearest) {
                    clusterLabels[n] = nearest;
                    deltaPoints += 1.0;
                }
            }

            // Update step: recompute centroids and deal with empty clusters.
            deltaMeans = 0.0;
            for (int c = 0; c < this.numClusters; ++c) {
                InstanceList members = instanceClusters.get(c);
                if (members.size() > 0) {
                    SparseVector clusterMean = VectorStats.mean(members);
                    deltaMeans += this.metric.distance(this.clusterMeans.get(c), clusterMean);
                    this.clusterMeans.set(c, clusterMean);
                    continue;
                }

                logger.info("Empty cluster found.");
                switch (this.emptyAction) {
                    case EMPTY_ERROR: {
                        return null;
                    }
                    case EMPTY_DROP: {
                        logger.fine("Removing cluster " + c);
                        this.clusterMeans.remove(c);
                        instanceClusters.remove(c);
                        for (int n = 0; n < instances.size(); ++n) {
                            // The message uses the list captured before removal; index c
                            // now refers to the next cluster (or is out of range).
                            assert (clusterLabels[n] != c) : "Cluster size is " + members.size() + "+ yet clusterLabels[n] is " + clusterLabels[n];
                            if (clusterLabels[n] > c) {
                                clusterLabels[n] = clusterLabels[n] - 1;
                            }
                        }
                        --this.numClusters;
                        // Index c now holds the following cluster; revisit it.
                        --c;
                        break;
                    }
                    case EMPTY_SINGLE: {
                        SparseVector seed = this.findFarthestMember(instanceClusters);
                        if (seed == null) {
                            logger.info("Can't find an instance to move.  Exiting.");
                            return null;
                        }
                        // Count the re-seeded centroid as movement so that the loop
                        // does not report convergence on the strength of the other
                        // centroids alone.
                        deltaMeans += this.metric.distance(this.clusterMeans.get(c), seed);
                        this.clusterMeans.set(c, seed);
                        // Without this break control fell through to `default`
                        // and the method returned null right after re-seeding.
                        break;
                    }
                    default: {
                        return null;
                    }
                }
            }

            // Member lists are cleared only after all centroids are updated so
            // that the EMPTY_SINGLE search above can still see every cluster's
            // members, not just those of clusters not yet processed.
            for (int c = 0; c < this.numClusters; ++c) {
                instanceClusters.set(c, new InstanceList(this.instancePipe));
            }
            logger.info("Iter " + iterations + " deltaMeans = " + deltaMeans);
        }

        if (deltaMeans <= MEANS_TOLERANCE) {
            logger.info("KMeans converged with deltaMeans = " + deltaMeans);
        } else if (iterations >= MAX_ITER) {
            logger.info("Maximum number of iterations (" + MAX_ITER + ") reached.");
        } else if (deltaPoints <= (double)instances.size() * POINTS_TOLERANCE) {
            logger.info("Minimum number of points (np*" + POINTS_TOLERANCE + "=" + (int)((double)instances.size() * POINTS_TOLERANCE) + ") moved in last iteration. Saying converged.");
        }
        return new Clustering(instances, this.numClusters, clusterLabels);
    }

    /**
     * Finds, among clusters holding more than one instance, the instance with the
     * largest strictly positive distance to its own cluster's centroid.
     *
     * @param instanceClusters current members of each cluster, indexed like {@link #clusterMeans}
     * @return the data of that instance, or {@code null} if no such instance exists
     */
    private SparseVector findFarthestMember(ArrayList<InstanceList> instanceClusters) {
        double farthestDist = 0.0;
        int farthestIndex = 0;
        InstanceList farthestSource = null;
        for (int k = 0; k < instanceClusters.size(); ++k) {
            InstanceList members = instanceClusters.get(k);
            // Never take the only member of a cluster: that would just create
            // another empty cluster.
            if (members.size() <= 1) {
                continue;
            }
            SparseVector centroid = this.clusterMeans.get(k);
            for (int n = 0; n < members.size(); ++n) {
                double dist = this.metric.distance(centroid, (SparseVector)((Instance)members.get(n)).getData());
                if (dist > farthestDist) {
                    farthestIndex = n;
                    farthestDist = dist;
                    farthestSource = members;
                }
            }
        }
        if (farthestSource == null) {
            return null;
        }
        return (SparseVector)((Instance)farthestSource.get(farthestIndex)).getData();
    }

    /**
     * Appends {@link #numClusters} initial centroids to {@link #clusterMeans}.
     *
     * <p>Instances whose vector has no locations are ignored. Each round picks
     * the remaining candidate whose distance to its closest existing centroid is
     * largest and removes it from the candidate pool. When no centroid exists
     * yet, every candidate has the same score and the first one is taken.
     *
     * <p>NOTE(review): {@code clusterMeans} is not cleared here, so calling
     * {@link #cluster} more than once on the same object appends to the means
     * left by the previous call - confirm whether that is intended.
     *
     * @param instList instances to draw the centroids from
     * @param metric   distance function used to compare candidates with centroids
     * @throws IndexOutOfBoundsException if the candidate pool runs out before
     *         {@code numClusters} centroids have been picked
     */
    private void initializeMeansSample(InstanceList instList, Metric metric) {
        ArrayList<Instance> candidates = new ArrayList<Instance>(instList.size());
        for (int i = 0; i < instList.size(); ++i) {
            Instance ins = (Instance)instList.get(i);
            SparseVector sparse = (SparseVector)ins.getData();
            if (sparse.numLocations() != 0) {
                candidates.add(ins);
            }
        }

        for (int i = 0; i < this.numClusters; ++i) {
            double bestScore = 0.0;
            int selected = 0;
            for (int k = 0; k < candidates.size(); ++k) {
                SparseVector candidate = (SparseVector)candidates.get(k).getData();
                // Score of a candidate = distance to its nearest existing centroid.
                double nearest = Double.MAX_VALUE;
                for (int j = 0; j < this.clusterMeans.size(); ++j) {
                    double dist = metric.distance(this.clusterMeans.get(j), candidate);
                    if (dist < nearest) {
                        nearest = dist;
                    }
                }
                if (nearest > bestScore) {
                    selected = k;
                    bestScore = nearest;
                }
            }
            Instance newCenter = candidates.remove(selected);
            this.clusterMeans.add((SparseVector)newCenter.getData());
        }
    }

    /** Returns the live list of current centroids (not a copy). */
    public ArrayList<SparseVector> getClusterMeans() {
        return this.clusterMeans;
    }
}

