/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.clustering;

import de.jungblut.clustering.Cluster;
import de.jungblut.distance.DistanceMeasurer;
import de.jungblut.math.DoubleVector;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.stream.IntStream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Iterative k-means clustering over {@link DoubleVector}s.
 *
 * <p>Each call to {@link #cluster} alternates between two steps. First, every vector is assigned to
 * its nearest center according to the supplied {@link DistanceMeasurer}. Then every center is moved
 * to the mean of the vectors assigned to it. Both steps run on parallel streams.
 *
 * <p>Instances are stateful: {@link #cluster} updates the centers in place, so a second call
 * continues from the centers the first call left behind. {@link #cluster} writes instance state
 * without synchronization, so it should not run concurrently on the same instance.
 */
public final class KMeansClustering {
    private static final Logger LOG = LogManager.getLogger(KMeansClustering.class);

    /** Current cluster centers; replaced element-wise by {@link #cluster}. */
    private final DoubleVector[] centers;
    /** The vectors being clustered; never modified by this class after construction. */
    private final List<DoubleVector> vectors;
    /** Number of clusters. */
    private final int k;
    /** Cost of the clustering returned by the most recent {@link #cluster} call. */
    private double clusteringCost;

    /**
     * Creates a clustering whose initial centers are taken from {@code vectors}.
     *
     * @param k number of clusters; must be between 0 and {@code vectors.length}.
     * @param vectors the vectors to cluster. The array itself is never reordered.
     * @param random if true, the initial centers are {@code k} randomly chosen vectors; otherwise
     *     they are the first {@code k} vectors.
     * @throws IllegalArgumentException if {@code k} is negative or exceeds the number of vectors.
     */
    public KMeansClustering(int k, DoubleVector[] vectors, boolean random) {
        this(k, Arrays.asList(vectors), random);
    }

    /**
     * Creates a clustering whose initial centers are taken from {@code vectors}.
     *
     * @param k number of clusters; must be between 0 and {@code vectors.size()}.
     * @param vectors the vectors to cluster. The caller's list is never reordered, so unmodifiable
     *     lists are accepted.
     * @param random if true, the initial centers are {@code k} randomly chosen vectors; otherwise
     *     they are the first {@code k} vectors.
     * @throws IllegalArgumentException if {@code k} is negative or exceeds the number of vectors.
     */
    public KMeansClustering(int k, List<DoubleVector> vectors, boolean random) {
        if (k < 0 || k > vectors.size()) {
            throw new IllegalArgumentException(
                "k must be between 0 and the number of vectors (" + vectors.size() + "), but was " + k);
        }
        this.k = k;
        // Shuffle a private copy: shuffling the caller's list would silently reorder their data
        // (including the array backing Arrays.asList) and would fail on unmodifiable lists.
        List<DoubleVector> data = random ? new ArrayList<>(vectors) : vectors;
        if (random) {
            Collections.shuffle(data);
        }
        this.vectors = data;
        this.centers = data.subList(0, k).toArray(new DoubleVector[0]);
    }

    /**
     * Creates a clustering with explicitly given initial centers.
     *
     * @param centers the initial centers; {@code k} is {@code centers.size()}. They are copied into
     *     a private array, so the caller's list is not modified by {@link #cluster}.
     * @param vectors the vectors to cluster.
     */
    public KMeansClustering(List<DoubleVector> centers, List<DoubleVector> vectors) {
        this.k = centers.size();
        this.vectors = vectors;
        this.centers = centers.toArray(new DoubleVector[0]);
    }

    /**
     * Runs k-means until convergence or until {@code iterations} rounds have been performed.
     *
     * <p>The cost of a round is the sum, over all vectors, of the distance to the nearest center.
     * Iteration stops early once the absolute change in cost between two consecutive rounds is
     * below {@code delta}.
     *
     * @param iterations maximum number of assign/update rounds.
     * @param distanceMeasurer distance used to find the nearest center.
     * @param delta convergence threshold on the change in cost.
     * @param verbose if true, the cost of every round is logged at INFO level.
     * @return one {@link Cluster} per center, in center order, with the vectors assigned to it.
     */
    public List<Cluster> cluster(int iterations, DistanceMeasurer distanceMeasurer, double delta, boolean verbose) {
        Deque<DoubleVector>[] assignments = this.setupAssignments();
        double lastCost = Double.MAX_VALUE;
        for (int iteration = 0; iteration < iterations; ++iteration) {
            double cost = this.assignAll(distanceMeasurer, assignments);
            this.computeCenters(assignments);
            if (verbose) {
                LOG.info("Iteration {} | Cost: {}", iteration, cost);
            }
            if (Math.abs(lastCost - cost) < delta) {
                break;
            }
            lastCost = cost;
        }
        // computeCenters() moved the centers after the last assignment (and drained the deques), so
        // assign once more. The returned clusters and the reported cost both describe this final
        // assignment; previously the cost of an earlier round was reported instead.
        this.clusteringCost = this.assignAll(distanceMeasurer, assignments);
        List<Cluster> clusters = new ArrayList<Cluster>(this.centers.length);
        for (int i = 0; i < this.centers.length; ++i) {
            clusters.add(new Cluster(this.centers[i], new ArrayList<DoubleVector>(assignments[i])));
        }
        return clusters;
    }

    /**
     * Returns the total distance of every vector to its assigned center for the clustering returned
     * by the most recent {@link #cluster} call, or 0.0 if {@link #cluster} has not been called.
     */
    public double getClusteringCost() {
        return this.clusteringCost;
    }

    /**
     * Clears {@code assignments}, then assigns every vector to its nearest center in parallel.
     *
     * @return the sum of the distances of all vectors to their assigned centers.
     */
    private double assignAll(DistanceMeasurer distanceMeasurer, Deque<DoubleVector>[] assignments) {
        Arrays.stream(assignments).forEach(Collection::clear);
        return IntStream.range(0, this.vectors.size())
            .parallel()
            .mapToDouble(x -> this.assign(distanceMeasurer, assignments, x))
            .sum();
    }

    /**
     * Moves every center to the mean of the vectors assigned to it. The center of an empty cluster
     * is left unchanged. Drains {@code assignments} in the process.
     */
    private void computeCenters(Deque<DoubleVector>[] assignments) {
        // Each parallel task touches only its own deque and its own centers[i] slot.
        IntStream.range(0, assignments.length).parallel().forEach(i -> {
            Deque<DoubleVector> members = assignments[i];
            int len = members.size();
            if (len > 0) {
                DoubleVector sum = members.pop();
                while (!members.isEmpty()) {
                    sum = sum.add(members.pop());
                }
                this.centers[i] = sum.divide((double) len);
            }
        });
    }

    /** Creates one empty concurrent deque per cluster; they are filled from parallel streams. */
    @SuppressWarnings("unchecked") // generic array creation; every element is a Deque<DoubleVector>
    private Deque<DoubleVector>[] setupAssignments() {
        Deque<DoubleVector>[] assignments = (Deque<DoubleVector>[]) new Deque<?>[this.k];
        for (int i = 0; i < assignments.length; ++i) {
            assignments[i] = new ConcurrentLinkedDeque<DoubleVector>();
        }
        return assignments;
    }

    /**
     * Adds the vector at {@code vectorIndex} to the deque of its nearest center.
     *
     * <p>Ties go to the lowest center index. If no distance compares below {@code Double.MAX_VALUE}
     * (e.g. all are NaN), the vector goes to center 0.
     *
     * @return the distance to the chosen center, or {@code Double.MAX_VALUE} in the fallback case.
     */
    private double assign(DistanceMeasurer distanceMeasurer, Deque<DoubleVector>[] assignments, int vectorIndex) {
        DoubleVector v = this.vectors.get(vectorIndex);
        int lowestDistantCenter = 0;
        double lowestDistance = Double.MAX_VALUE;
        for (int i = 0; i < this.centers.length; ++i) {
            double estimatedDistance = distanceMeasurer.measureDistance(this.centers[i], v);
            if (estimatedDistance < lowestDistance) {
                lowestDistance = estimatedDistance;
                lowestDistantCenter = i;
            }
        }
        assignments[lowestDistantCenter].add(v);
        return lowestDistance;
    }

    /**
     * Returns the current centers. This is the live internal array, not a copy: it reflects later
     * {@link #cluster} calls, and writes to it affect this instance.
     */
    public DoubleVector[] getCenters() {
        return this.centers;
    }
}

