package org.apache.mahout.clustering.kmeans;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.List;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.mahout.clustering.ClusterBase;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.matrix.AbstractVector;
import org.apache.mahout.matrix.SquareRootFunction;
import org.apache.mahout.matrix.Vector;

/* loaded from: input_file:WEB-INF/lib/mahout-core-0.2.jar:org/apache/mahout/clustering/kmeans/Cluster.class */
/**
 * A k-means cluster: a center vector plus running totals of the points that
 * have been assigned to it since the center was last (re)computed.
 *
 * <p>The distance measure, the convergence threshold and the counter used to
 * number new clusters are held in static fields, so they are shared by every
 * instance in the JVM. They are set by {@link #configure(JobConf)} or
 * {@link #config(DistanceMeasure, double)}; this class performs no
 * synchronization around them.
 */
public class Cluster extends ClusterBase {
    private static final String ERROR_UNKNOWN_CLUSTER_FORMAT = "Unknown cluster format:\n";

    /** Separator written between the identifier and the vector by {@link #formatCluster(Cluster)}. */
    private static final String ID_SEPARATOR = ": ";

    /** Job property holding the class name of the {@link DistanceMeasure} to instantiate. */
    public static final String DISTANCE_MEASURE_KEY = "org.apache.mahout.clustering.kmeans.measure";
    public static final String CLUSTER_PATH_KEY = "org.apache.mahout.clustering.kmeans.path";
    /** Job property holding the convergence threshold, parsed as a double. */
    public static final String CLUSTER_CONVERGENCE_KEY = "org.apache.mahout.clustering.kmeans.convergence";
    public static final String ITERATION_NUMBER = "org.apache.mahout.clustering.kmeans.iteration";
    public static final String CANOPY_INPUT = "org.apache.mahout.clustering.kmeans.canopyInput";

    /** Cached result of {@link #computeCentroid()}; cleared whenever the totals change. */
    private Vector centroid;
    /** Running sum of the element-wise squares of the added vectors. */
    private Vector pointSquaredTotal;
    private boolean converged;

    private static DistanceMeasure measure;
    private static int nextClusterId = 0;
    private static double convergenceDelta = 0.0d;

    /**
     * Formats a cluster as its identifier, the separator {@code ": "} and the
     * format string of its current centroid.
     *
     * @param cluster the cluster to format
     * @return the formatted representation
     */
    public static String formatCluster(Cluster cluster) {
        return cluster.getIdentifier() + ID_SEPARATOR + cluster.computeCentroid().asFormatString();
    }

    @Override // org.apache.mahout.clustering.ClusterBase
    public String asFormatString() {
        return formatCluster(this);
    }

    /**
     * Parses a string in the layout written by {@link #formatCluster(Cluster)}.
     *
     * <p>The first character must be {@code 'C'} or {@code 'V'} ({@code 'V'}
     * marks the cluster as converged), the characters between it and the
     * separator are parsed as the integer id, and everything from the first
     * <code>'{'</code> onward is handed to {@code AbstractVector.decodeVector}.
     *
     * @param str the formatted cluster
     * @return a new cluster with the decoded id, center and converged flag
     * @throws IllegalArgumentException if the prefix before the first
     *         <code>'{'</code> is too short to hold a type letter and the
     *         separator, or does not start with {@code 'C'} or {@code 'V'}
     *         (a non-numeric id surfaces as {@link NumberFormatException})
     */
    public static Cluster decodeCluster(String str) {
        int braceIndex = str.indexOf('{');
        // The prefix must hold at least the type letter and the separator; shorter
        // prefixes previously escaped as StringIndexOutOfBoundsException.
        if (braceIndex < 1 + ID_SEPARATOR.length()) {
            throw new IllegalArgumentException(ERROR_UNKNOWN_CLUSTER_FORMAT + str);
        }
        char typeLetter = str.charAt(0);
        boolean isConverged = typeLetter == 'V';
        if (typeLetter != 'C' && !isConverged) {
            throw new IllegalArgumentException(ERROR_UNKNOWN_CLUSTER_FORMAT + str);
        }
        int clusterId = Integer.parseInt(str.substring(1, braceIndex - ID_SEPARATOR.length()));
        Cluster cluster = new Cluster(AbstractVector.decodeVector(str.substring(braceIndex)), clusterId);
        cluster.setConverged(isConverged);
        return cluster;
    }

    /**
     * Writes the superclass fields, then the converged flag, then the current
     * centroid vector.
     */
    @Override // org.apache.mahout.clustering.ClusterBase
    public void write(DataOutput dataOutput) throws IOException {
        super.write(dataOutput);
        dataOutput.writeBoolean(this.converged);
        AbstractVector.writeVector(dataOutput, computeCentroid());
    }

    /**
     * Reads the fields in the order {@link #write(DataOutput)} emits them and
     * resets the point count, both running totals and the cached centroid, so
     * no accumulated state from before the read survives.
     */
    @Override // org.apache.mahout.clustering.ClusterBase
    public void readFields(DataInput dataInput) throws IOException {
        super.readFields(dataInput);
        this.converged = dataInput.readBoolean();
        setCenter(AbstractVector.readVector(dataInput));
        resetTotals(getCenter());
    }

    /**
     * Instantiates the distance measure named by {@link #DISTANCE_MEASURE_KEY},
     * configures it with the job, reads the convergence threshold from
     * {@link #CLUSTER_CONVERGENCE_KEY} and restarts cluster id numbering at 0.
     *
     * @param jobConf the job configuration to read from
     * @throws IllegalStateException if either property is missing, or the
     *         measure class cannot be loaded or instantiated
     */
    public static void configure(JobConf jobConf) {
        try {
            String measureClassName = requireProperty(jobConf, DISTANCE_MEASURE_KEY);
            measure = (DistanceMeasure) Thread.currentThread().getContextClassLoader().loadClass(measureClassName).newInstance();
            measure.configure(jobConf);
            convergenceDelta = Double.parseDouble(requireProperty(jobConf, CLUSTER_CONVERGENCE_KEY));
            nextClusterId = 0;
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException(e);
        } catch (InstantiationException e) {
            throw new IllegalStateException(e);
        }
    }

    /** Returns the value of a job property, failing with a descriptive error when it is absent. */
    private static String requireProperty(JobConf jobConf, String key) {
        String value = jobConf.get(key);
        if (value == null) {
            throw new IllegalStateException("Missing required job property: " + key);
        }
        return value;
    }

    /**
     * Sets the shared distance measure and convergence threshold directly and
     * restarts cluster id numbering at 0.
     *
     * @param distanceMeasure the measure used for all distance computations
     * @param d the convergence threshold
     */
    public static void config(DistanceMeasure distanceMeasure, double d) {
        measure = distanceMeasure;
        convergenceDelta = d;
        nextClusterId = 0;
    }

    /**
     * Finds the cluster whose center is nearest to the point and collects the
     * pair (cluster identifier, {@code KMeansInfo(1, vector)}).
     *
     * @param vector the point to assign
     * @param list the candidate clusters; must not be empty
     * @param outputCollector receives the emitted pair
     * @throws IllegalArgumentException if {@code list} is empty
     * @throws IOException if the collector fails
     */
    public static void emitPointToNearestCluster(Vector vector, List<Cluster> list, OutputCollector<Text, KMeansInfo> outputCollector) throws IOException {
        Cluster nearest = findNearestCluster(vector, list);
        outputCollector.collect(new Text(nearest.getIdentifier()), new KMeansInfo(1, vector));
    }

    /**
     * Finds the cluster whose center is nearest to the point and collects the
     * pair (point key, cluster id). The key is the vector's name, or its format
     * string when the name is null or empty.
     *
     * @param vector the point to assign
     * @param list the candidate clusters; must not be empty
     * @param outputCollector receives the emitted pair
     * @throws IllegalArgumentException if {@code list} is empty
     * @throws IOException if the collector fails
     */
    public static void outputPointWithClusterInfo(Vector vector, List<Cluster> list, OutputCollector<Text, Text> outputCollector) throws IOException {
        Cluster nearest = findNearestCluster(vector, list);
        String name = vector.getName();
        String key = (name == null || name.length() == 0) ? vector.asFormatString() : name;
        outputCollector.collect(new Text(key), new Text(String.valueOf(nearest.getId())));
    }

    /**
     * Returns the cluster whose center has the smallest distance to the point
     * under the shared measure; on ties the earliest cluster in the list wins.
     */
    private static Cluster findNearestCluster(Vector point, List<Cluster> clusters) {
        Cluster nearest = null;
        double nearestDistance = Double.MAX_VALUE;
        for (Cluster candidate : clusters) {
            Vector center = candidate.getCenter();
            double distance = measure.distance(center.getLengthSquared(), center, point);
            // The null test makes the first cluster win even when its distance
            // does not compare below the initial bound.
            if (distance < nearestDistance || nearest == null) {
                nearest = candidate;
                nearestDistance = distance;
            }
        }
        if (nearest == null) {
            // Previously surfaced as a NullPointerException at the call site.
            throw new IllegalArgumentException("Cannot assign a point to a cluster: the cluster list is empty");
        }
        return nearest;
    }

    /**
     * Returns the cluster's own center when no points have been added,
     * otherwise the point total divided by the point count (cached until the
     * totals change).
     */
    private Vector computeCentroid() {
        if (getNumPoints() == 0) {
            return getCenter();
        }
        if (this.centroid == null) {
            this.centroid = getPointTotal().divide(getNumPoints());
        }
        return this.centroid;
    }

    /** Points both running totals at fresh {@code like()} copies of the template and clears the count and cache. */
    private void resetTotals(Vector template) {
        setNumPoints(0);
        setPointTotal(template.like());
        this.pointSquaredTotal = template.like();
        this.centroid = null;
    }

    /**
     * Creates a cluster around the given center, taking the next id from the
     * shared counter.
     *
     * @param vector the initial center
     */
    public Cluster(Vector vector) {
        this(vector, nextClusterId++);
    }

    /** Creates an empty cluster with no center and no totals, e.g. to be filled by {@link #readFields(DataInput)}. */
    public Cluster() {
    }

    /**
     * Creates a cluster around the given center with an explicit id.
     *
     * @param vector the initial center
     * @param i the cluster id
     */
    public Cluster(Vector vector, int i) {
        setId(i);
        setCenter(vector);
        resetTotals(vector);
    }

    /**
     * Creates a center-less cluster from an identifier: the text after the
     * first character is parsed as the id and a leading {@code "V"} marks the
     * cluster as converged. The running totals stay unset until the first
     * {@link #addPoints(int, Vector)} call.
     *
     * @param str the cluster identifier
     */
    public Cluster(String str) {
        setId(Integer.parseInt(str.substring(1)));
        setNumPoints(0);
        this.converged = str.startsWith("V");
    }

    @Override
    public String toString() {
        return getIdentifier() + " - " + getCenter().asFormatString();
    }

    /**
     * @return {@code "V"} followed by the id when converged, otherwise
     *         {@code "C"} followed by the id
     */
    public String getIdentifier() {
        return (this.converged ? "V" : "C") + getId();
    }

    /**
     * Adds a single point to the running totals.
     *
     * @param vector the point to add
     */
    public void addPoint(Vector vector) {
        addPoints(1, vector);
    }

    /**
     * Increases the point count by {@code i} and adds {@code vector} once to
     * the running totals, invalidating the cached centroid.
     *
     * @param i the number of points the vector stands for
     * @param vector the vector to add to the totals
     */
    public void addPoints(int i, Vector vector) {
        this.centroid = null;
        setNumPoints(getNumPoints() + i);
        if (getPointTotal() == null) {
            // First contribution to a cluster created without a center.
            setPointTotal(vector.mo560clone());
            this.pointSquaredTotal = vector.times(vector);
        } else {
            vector.addTo(getPointTotal());
            vector.times(vector).addTo(this.pointSquaredTotal);
        }
    }

    /**
     * Moves the center to the current centroid and clears the point count,
     * both running totals and the cached centroid, ready for the next round of
     * {@link #addPoints(int, Vector)} calls.
     */
    public void recomputeCenter() {
        setCenter(computeCentroid());
        // Clearing the squared total too keeps getStd() from mixing in squared
        // sums of points that were counted towards a previous center.
        resetTotals(getCenter());
    }

    /**
     * Marks the cluster converged when the distance between the current
     * centroid and the center is at most the configured threshold.
     *
     * @return the updated converged flag
     */
    public boolean computeConvergence() {
        Vector currentCentroid = computeCentroid();
        this.converged = measure.distance(currentCentroid.getLengthSquared(), currentCentroid, getCenter()) <= convergenceDelta;
        return this.converged;
    }

    public boolean isConverged() {
        return this.converged;
    }

    private void setConverged(boolean isConverged) {
        this.converged = isConverged;
    }

    /**
     * Computes, per element, {@code sqrt(n * sumOfSquares - sum * sum) / n}
     * from the running totals and returns half of the sum over all elements.
     * No guard is applied for a zero point count or for totals that have not
     * been initialized.
     *
     * @return the summed per-element value divided by 2
     */
    public double getStd() {
        return this.pointSquaredTotal.times(getNumPoints())
                .minus(getPointTotal().times(getPointTotal()))
                .assign(new SquareRootFunction())
                .divide(getNumPoints())
                .zSum() / 2.0d;
    }
}
