package io.kgraph.library.clustering;

import io.kgraph.EdgeWithValue;
import io.kgraph.VertexWithValue;
import io.kgraph.pregel.ComputeFunction;
import io.kgraph.pregel.aggregators.LongSumAggregator;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Pregel-style K-means clustering over vertices whose value is a {@link KMeansVertexValue}.
 *
 * <p>Flow, as implemented below:
 * <ul>
 *   <li>Superstep 0: every vertex contributes its point coordinates to the
 *       {@link #INITIAL_CENTERS} aggregator (skipped when {@link #TEST_INITIAL_CENTERS} is
 *       configured).</li>
 *   <li>Master, superstep 1: the first {@code clustersCount} entries of the configured test
 *       centers (or of the aggregated initial points) become the initial centers. Whether that
 *       choice is random depends on the ordering produced by the aggregator, which is not visible
 *       here.</li>
 *   <li>Supersteps &gt;= 1: each vertex assigns itself to the nearest center (Euclidean distance),
 *       stores the assignment in its value, and contributes its coordinates plus a count of one to
 *       that cluster's aggregators.</li>
 *   <li>Master, supersteps &gt;= 2: each center becomes its aggregated coordinates divided by its
 *       assigned-point count. The computation halts once the summed absolute movement of all
 *       centers is at most {@code 0.001f}, or once {@link #MAX_ITERATIONS} is exceeded.</li>
 * </ul>
 *
 * @param <EV> edge value type (edges are not read by this algorithm)
 * @param <Message> message type (this class sends no messages)
 */
public class KMeansClustering<EV, Message> implements ComputeFunction<Long, KMeansVertexValue, EV, Message> {
    /** Prefix of the per-cluster aggregators that carry center coordinates. */
    public static final String CENTER_AGGR_PREFIX = "center.aggr.prefix";
    /** Prefix of the per-cluster aggregators that count assigned points. */
    public static final String ASSIGNED_POINTS_PREFIX = "assigned.points.prefix";
    /** Aggregator (and config key) collecting candidate initial centers. */
    public static final String INITIAL_CENTERS = "kmeans.initial.centers";
    /** Config key: superstep after which the computation is halted. */
    public static final String MAX_ITERATIONS = "kmeans.iterations";
    public static final int ITERATIONS_DEFAULT = 100;
    /** Config key: number of clusters (k). */
    public static final String CLUSTER_CENTERS_COUNT = "kmeans.cluster.centers.count";
    public static final int CLUSTER_CENTERS_COUNT_DEFAULT = 3;
    /** Config key: number of coordinates compared by the convergence check (0 = all). */
    public static final String DIMENSIONS = "kmeans.points.dimensions";
    public static final String POINTS_COUNT = "kmeans.points.count";
    /** Config key: print the final centers to stdout when the computation halts. */
    public static final String PRINT_FINAL_CENTERS = "kmeans.print.final.centers";
    public static final boolean PRINT_FINAL_CENTERS_DEFAULT = false;
    /** Config key: explicit list of initial centers; disables point aggregation in superstep 0. */
    public static final String TEST_INITIAL_CENTERS = "test.initial.centers";

    /** Separator between an aggregator prefix and the cluster index. */
    private static final String CLUSTER_SUFFIX = "C_";

    /**
     * Convergence threshold on the summed absolute center movement. Declared from the float literal
     * {@code 0.001f} so the widened double value is exactly the threshold used previously.
     */
    private static final double CONVERGENCE_THRESHOLD = 0.001f;

    private Map<String, ?> configs;
    private int maxIterations;
    /** Master-side snapshot of the centers broadcast for the current superstep. */
    private List<Double>[] currentClusterCenters;
    private int clustersCount;
    private int dimensions;

    /**
     * Superstep-0 step: contributes the vertex's coordinates as a candidate initial center, unless
     * explicit test centers are configured.
     */
    public class RandomCentersInitialization implements ComputeFunction<Long, KMeansVertexValue, EV, Message> {
        public RandomCentersInitialization() {
        }

        @Override // io.kgraph.pregel.ComputeFunction
        public void compute(int i, VertexWithValue<Long, KMeansVertexValue> vertexWithValue, Iterable<Message> iterable, Iterable<EdgeWithValue<Long, EV>> iterable2, ComputeFunction.Callback<Long, KMeansVertexValue, EV, Message> callback) {
            if (KMeansClustering.this.configs.get(TEST_INITIAL_CENTERS) != null) {
                return;
            }
            // The initial-centers aggregator takes a list of points, so wrap this single point.
            List<List<Double>> candidate = new ArrayList<>(1);
            candidate.add(vertexWithValue.value().getPointCoordinates());
            callback.aggregate(INITIAL_CENTERS, candidate);
        }
    }

    /**
     * Assigns the vertex to its nearest current center, feeds that cluster's aggregators, and
     * records the assignment in the vertex value.
     */
    public void superstepCompute(int i, VertexWithValue<Long, KMeansVertexValue> vertexWithValue, Iterable<Message> iterable, Iterable<EdgeWithValue<Long, EV>> iterable2, ComputeFunction.Callback<Long, KMeansVertexValue, EV, Message> callback) {
        KMeansVertexValue value = vertexWithValue.value();
        List<Double> point = value.getPointCoordinates();
        int closest = findClosestCenter(readClusterCenters(callback, CENTER_AGGR_PREFIX), point);
        callback.aggregate(centerAggregatorName(closest), point);
        callback.aggregate(assignedPointsAggregatorName(closest), 1L);
        // Fetch the coordinates again (as before) rather than reuse the list handed to the aggregator.
        callback.setNewVertexValue(new KMeansVertexValue(value.getPointCoordinates(), closest));
    }

    /** Reads the {@code clustersCount} center lists broadcast under {@code prefix + "C_" + index}. */
    private List<Double>[] readClusterCenters(ComputeFunction.Callback<Long, KMeansVertexValue, EV, Message> callback, String prefix) {
        List<Double>[] centers = newCenterArray(this.clustersCount);
        for (int c = 0; c < this.clustersCount; c++) {
            @SuppressWarnings("unchecked")
            List<Double> center = (List<Double>) callback.getAggregatedValue(prefix + CLUSTER_SUFFIX + c);
            centers[c] = center;
        }
        return centers;
    }

    /** Returns the index of the center nearest to {@code point}; ties keep the lowest index. */
    private int findClosestCenter(List<Double>[] centers, List<Double> point) {
        double best = Double.MAX_VALUE;
        int bestIndex = 0;
        for (int c = 0; c < centers.length; c++) {
            double distance = euclideanDistance(centers[c], point, centers[c].size());
            if (distance < best) {
                best = distance;
                bestIndex = c;
            }
        }
        return bestIndex;
    }

    /** Euclidean distance over the first {@code dims} coordinates of both points. */
    private double euclideanDistance(List<Double> a, List<Double> b, int dims) {
        double sumOfSquares = 0.0d;
        for (int d = 0; d < dims; d++) {
            double diff = a.get(d) - b.get(d);
            sumOfSquares += diff * diff;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Reads the configuration and registers one initial-centers aggregator plus a coordinate
     * aggregator and a count aggregator per cluster.
     *
     * @throws IllegalArgumentException if the configured cluster count is below 1
     */
    @Override // io.kgraph.pregel.ComputeFunction
    public final void init(Map<String, ?> map, ComputeFunction.InitCallback initCallback) {
        this.configs = map;
        this.maxIterations = intConfig(MAX_ITERATIONS, ITERATIONS_DEFAULT);
        this.clustersCount = intConfig(CLUSTER_CENTERS_COUNT, CLUSTER_CENTERS_COUNT_DEFAULT);
        this.dimensions = intConfig(DIMENSIONS, 0);
        if (this.clustersCount < 1) {
            throw new IllegalArgumentException(CLUSTER_CENTERS_COUNT + " must be >= 1 but was " + this.clustersCount);
        }
        this.currentClusterCenters = newCenterArray(this.clustersCount);
        initCallback.registerAggregator(INITIAL_CENTERS, ListOfDoubleListAggregator.class);
        for (int c = 0; c < this.clustersCount; c++) {
            initCallback.registerAggregator(centerAggregatorName(c), DoubleListAggregator.class);
            initCallback.registerAggregator(assignedPointsAggregatorName(c), LongSumAggregator.class);
        }
    }

    /**
     * Master step. Superstep 1 seeds the centers. Later supersteps recompute the centers and either
     * halt (converged or past {@link #MAX_ITERATIONS}) or broadcast the new centers.
     */
    @Override // io.kgraph.pregel.ComputeFunction
    public final void masterCompute(int i, ComputeFunction.MasterCallback masterCallback) {
        if (i == 1) {
            initializeCenters(masterCallback);
            return;
        }
        if (i > 1) {
            List<Double>[] newCenters = computeClusterCenters(masterCallback);
            if (i > this.maxIterations || centersConverged(this.currentClusterCenters, newCenters)) {
                if (booleanConfig(PRINT_FINAL_CENTERS, PRINT_FINAL_CENTERS_DEFAULT)) {
                    printFinalCentersCoordinates();
                }
                masterCallback.haltComputation();
            } else {
                for (int c = 0; c < this.clustersCount; c++) {
                    // Broadcast a copy so the aggregator never shares the master's snapshot list.
                    masterCallback.setAggregatedValue(centerAggregatorName(c), new ArrayList<>(newCenters[c]));
                }
                this.currentClusterCenters = newCenters;
            }
        }
    }

    /**
     * Seeds the first {@code clustersCount} initial centers, taken from the test config or from the
     * initial-centers aggregator.
     *
     * @throws IllegalStateException if fewer candidate centers than clusters are available
     */
    private void initializeCenters(ComputeFunction.MasterCallback masterCallback) {
        List<?> candidates = (List<?>) this.configs.get(TEST_INITIAL_CENTERS);
        if (candidates == null) {
            candidates = (List<?>) masterCallback.getAggregatedValue(INITIAL_CENTERS);
        }
        int available = candidates == null ? 0 : candidates.size();
        if (available < this.clustersCount) {
            throw new IllegalStateException("K-means needs at least " + this.clustersCount
                    + " initial centers but only " + available + " are available");
        }
        for (int c = 0; c < this.clustersCount; c++) {
            @SuppressWarnings("unchecked")
            List<Double> center = (List<Double>) candidates.get(c);
            masterCallback.setAggregatedValue(centerAggregatorName(c), center);
            // Snapshot copy: later aggregation must not alter what the convergence check compares against.
            this.currentClusterCenters[c] = new ArrayList<>(center);
        }
    }

    /**
     * Computes the new centers as aggregated coordinates divided by the assigned-point count.
     * A cluster with no assigned points keeps its previous center instead of becoming NaN/Infinity.
     * The aggregated lists are left untouched.
     */
    private List<Double>[] computeClusterCenters(ComputeFunction.MasterCallback masterCallback) {
        List<Double>[] centers = newCenterArray(this.clustersCount);
        for (int c = 0; c < this.clustersCount; c++) {
            @SuppressWarnings("unchecked")
            List<Double> coordinateSums = (List<Double>) masterCallback.getAggregatedValue(centerAggregatorName(c));
            Long assigned = (Long) masterCallback.getAggregatedValue(assignedPointsAggregatorName(c));
            long count = assigned == null ? 0L : assigned.longValue();
            if (count == 0L || coordinateSums == null || coordinateSums.isEmpty()) {
                centers[c] = new ArrayList<>(this.currentClusterCenters[c]);
                continue;
            }
            List<Double> mean = new ArrayList<>(coordinateSums.size());
            for (Double sum : coordinateSums) {
                mean.add(sum.doubleValue() / count);
            }
            centers[c] = mean;
        }
        return centers;
    }

    /**
     * Returns {@code true} when the summed absolute per-coordinate movement of all centers is at most
     * the convergence threshold. Compares the configured {@link #DIMENSIONS} coordinates, or every
     * shared coordinate when that setting is absent/0. The previous version compared nothing in that
     * case and always reported convergence.
     */
    private boolean centersConverged(List<Double>[] previous, List<Double>[] next) {
        double totalShift = 0.0d;
        for (int c = 0; c < this.clustersCount; c++) {
            int dims = this.dimensions > 0
                    ? this.dimensions
                    : Math.min(previous[c].size(), next[c].size());
            for (int d = 0; d < dims; d++) {
                totalShift += Math.abs(previous[c].get(d) - next[c].get(d));
            }
        }
        return totalShift <= CONVERGENCE_THRESHOLD;
    }

    /** Prints the current centers to stdout, one line per cluster. */
    private void printFinalCentersCoordinates() {
        System.out.println("Centers Coordinates: ");
        for (int c = 0; c < this.clustersCount; c++) {
            StringBuilder line = new StringBuilder("cluster id ").append(c).append(": ");
            for (Double coordinate : this.currentClusterCenters[c]) {
                line.append(coordinate).append(' ');
            }
            System.out.println(line);
        }
    }

    /** Name of the coordinate aggregator for {@code cluster}. */
    private static String centerAggregatorName(int cluster) {
        return CENTER_AGGR_PREFIX + CLUSTER_SUFFIX + cluster;
    }

    /** Name of the assigned-point counter for {@code cluster}. */
    private static String assignedPointsAggregatorName(int cluster) {
        return ASSIGNED_POINTS_PREFIX + CLUSTER_SUFFIX + cluster;
    }

    /** Allocates a generic array of center lists (Java cannot create {@code List<Double>[]} directly). */
    @SuppressWarnings("unchecked")
    private static List<Double>[] newCenterArray(int size) {
        return (List<Double>[]) new List<?>[size];
    }

    /** Reads an int config value, accepting any {@link Number} or a numeric string. */
    private int intConfig(String key, int defaultValue) {
        Object raw = this.configs.get(key);
        if (raw == null) {
            return defaultValue;
        }
        if (raw instanceof Number) {
            return ((Number) raw).intValue();
        }
        return Integer.parseInt(raw.toString().trim());
    }

    /** Reads a boolean config value, accepting a {@link Boolean} or its string form. */
    private boolean booleanConfig(String key, boolean defaultValue) {
        Object raw = this.configs.get(key);
        if (raw == null) {
            return defaultValue;
        }
        if (raw instanceof Boolean) {
            return (Boolean) raw;
        }
        return Boolean.parseBoolean(raw.toString().trim());
    }

    /** Superstep 0 seeds candidate centers; every later superstep assigns the vertex to a cluster. */
    @Override // io.kgraph.pregel.ComputeFunction
    public void compute(int i, VertexWithValue<Long, KMeansVertexValue> vertexWithValue, Iterable<Message> iterable, Iterable<EdgeWithValue<Long, EV>> iterable2, ComputeFunction.Callback<Long, KMeansVertexValue, EV, Message> callback) {
        if (i == 0) {
            new RandomCentersInitialization().compute(i, vertexWithValue, iterable, iterable2, callback);
        } else {
            superstepCompute(i, vertexWithValue, iterable, iterable2, callback);
        }
    }
}
