/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.clustering;

import de.jungblut.distance.EuclidianDistance;
import de.jungblut.jrpt.KDTree;
import de.jungblut.jrpt.VectorDistanceTuple;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.tuple.Tuple;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.math3.util.FastMath;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public final class MeanShiftClustering {
    private static final Logger LOG = LogManager.getLogger(MeanShiftClustering.class);
    private static final double SQRT_2_PI = FastMath.sqrt(2.0 * Math.PI);
    /** L1 distance a center must move in one iteration to count as "still converging". */
    private static final double CONVERGENCE_THRESHOLD = 1.0E-5;

    /**
     * Runs mean-shift clustering over {@code points} and returns the discovered cluster centers.
     *
     * <p>Initial centers are seeded by a greedy grouping pass. Each iteration then moves every
     * center to the Gaussian-weighted mean of the data points strictly within {@code windowSize}
     * of it, and merges centers that are closer than {@code mergeWindow}. Iteration stops after
     * {@code maxIterations} rounds, or earlier once no center moved.
     *
     * @param points the input data; each point's list index is stored as its kd-tree payload
     * @param windowSize bandwidth h; only points strictly closer than this contribute to a shift
     * @param mergeWindow centers whose Euclidean distance is strictly below this are merged
     * @param maxIterations upper bound on the number of shift/merge rounds
     * @param verbose whether to log progress at INFO level
     * @return a mutable list of cluster centers (empty if {@code points} is empty)
     * @throws IllegalArgumentException if {@code windowSize} is not a positive number
     */
    public static List<DoubleVector> cluster(List<DoubleVector> points, double windowSize, double mergeWindow, int maxIterations, boolean verbose) {
        // distances are normalized by windowSize, so it must be strictly positive (also rejects NaN)
        if (!(windowSize > 0.0)) {
            throw new IllegalArgumentException("windowSize must be positive, but was " + windowSize);
        }
        if (points.isEmpty()) {
            return new ArrayList<>();
        }
        KDTree<Integer> kdTree = new KDTree<>();
        // payload = index into points, used to track which points were already grouped
        Stream<Tuple<DoubleVector, Integer>> payloadStream = IntStream.range(0, points.size())
                .mapToObj(i -> new Tuple<DoubleVector, Integer>(points.get(i), i));
        kdTree.constructWithPayload(payloadStream);
        kdTree.balanceBySort();
        List<DoubleVector> centers = observeCenters(kdTree, points, windowSize, verbose);
        for (int iteration = 0; iteration < maxIterations; ++iteration) {
            int converging = meanShift(kdTree, centers, windowSize);
            merge(centers, mergeWindow);
            if (verbose) {
                LOG.info("Iteration: {} | Remaining centers converging: {}/{}", iteration, converging, centers.size());
            }
            if (converging == 0) {
                break;
            }
        }
        return centers;
    }

    /**
     * Merges, in place, every center that lies strictly within {@code mergeWindow} of an earlier
     * center. The surviving center becomes the arithmetic mean of all centers it absorbed.
     */
    private static void merge(List<DoubleVector> centers, double mergeWindow) {
        EuclidianDistance distance = EuclidianDistance.get();
        for (int i = 0; i < centers.size(); ++i) {
            DoubleVector merged = centers.get(i);
            int mergedCount = 1;
            int j = i + 1;
            while (j < centers.size()) {
                DoubleVector candidate = centers.get(j);
                if (distance.measureDistance(merged, candidate) < mergeWindow) {
                    // Incremental mean: every absorbed center contributes equally. Averaging only
                    // against the original reference would discard earlier merges.
                    ++mergedCount;
                    merged = merged.add(candidate.subtract(merged).divide((double) mergedCount));
                    // removal shifts later elements left, so j already points at the next one
                    centers.remove(j);
                } else {
                    ++j;
                }
            }
            centers.set(i, merged);
        }
    }

    /**
     * Performs one mean-shift step for every center. Each center is replaced by the weighted mean
     * {@code sum(w_k * x_k) / sum(w_k)} of the data points {@code x_k} strictly within {@code h}
     * of it, where {@code w_k} is the Gaussian kernel of the normalized distance.
     *
     * @return the number of centers that moved by more than {@link #CONVERGENCE_THRESHOLD} (L1)
     */
    private static int meanShift(KDTree<Integer> kdTree, List<DoubleVector> centers, double h) {
        int remainingConvergence = 0;
        for (int i = 0; i < centers.size(); ++i) {
            DoubleVector v = centers.get(i);
            List<VectorDistanceTuple<Integer>> neighbours = kdTree.getNearestNeighbours(v, h);
            double weightSum = 0.0;
            DoubleVector numerator = new DenseDoubleVector(v.getLength());
            for (VectorDistanceTuple<Integer> neighbour : neighbours) {
                if (!(neighbour.getDistance() < h)) {
                    continue;
                }
                // Gaussian profile: the mean-shift weight g = -k' is proportional to the kernel itself
                double weight = gaussianKernel(neighbour.getDistance() / h);
                numerator = numerator.add(neighbour.getVector().multiply(weight));
                weightSum += weight;
            }
            if (!(weightSum > 0.0)) {
                // no data point inside the window; leave this center where it is
                continue;
            }
            DoubleVector newCenter = numerator.divide(weightSum);
            if (v.subtract(newCenter).abs().sum() > CONVERGENCE_THRESHOLD) {
                ++remainingConvergence;
                centers.set(i, newCenter);
            }
        }
        return remainingConvergence;
    }

    /**
     * Seeds the initial centers with a greedy single pass over {@code points}.
     *
     * <p>For each point that is not yet assigned, the unassigned neighbours returned by the kd-tree
     * strictly within {@code h} are averaged and marked as assigned. Only groups with more than one
     * member yield a center. Every visited point is marked as assigned afterwards, even when its
     * group was dropped.
     */
    private static List<DoubleVector> observeCenters(KDTree<Integer> kdTree, List<DoubleVector> points, double h, boolean verbose) {
        List<DoubleVector> centers = new ArrayList<>();
        BitSet assignedIndices = new BitSet(kdTree.size());
        for (int i = 0; i < points.size(); ++i) {
            if (!assignedIndices.get(i)) {
                DoubleVector v = points.get(i);
                List<VectorDistanceTuple<Integer>> neighbours = kdTree.getNearestNeighbours(v, h);
                DoubleVector sum = new DenseDoubleVector(v.getLength());
                int added = 0;
                for (VectorDistanceTuple<Integer> neighbour : neighbours) {
                    int index = neighbour.getValue();
                    if (assignedIndices.get(index) || !(neighbour.getDistance() < h)) {
                        continue;
                    }
                    sum = sum.add(neighbour.getVector());
                    assignedIndices.set(index);
                    ++added;
                }
                if (added > 1) {
                    centers.add(sum.divide((double) added));
                    if (verbose && centers.size() % 1000 == 0) {
                        LOG.info("#Centers found: {}", centers.size());
                    }
                }
            }
            assignedIndices.set(i);
        }
        return centers;
    }

    /**
     * Standard normal density at {@code normDistance} (distance divided by the bandwidth). Unlike
     * the previous gradient formulation, it is finite at distance 0.
     */
    private static double gaussianKernel(double normDistance) {
        return FastMath.exp(-(normDistance * normDistance) / 2.0) / SQRT_2_PI;
    }
}

