package org.apache.flink.ml.clustering.agglomerativeclustering;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple4;
import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.common.distance.DistanceMeasure;
import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.VectorWithNorm;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.functions.windowing.ProcessAllWindowFunction;
import org.apache.flink.streaming.api.windowing.windows.Window;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;
import org.apache.flink.util.Preconditions;

/**
 * An {@link AlgoOperator} that runs agglomerative (hierarchical) clustering independently on each
 * window of the input data.
 *
 * <p>{@link #transform} returns two tables:
 *
 * <ol>
 *   <li>the input rows with an extra {@code INT} prediction column holding the cluster label;
 *   <li>the merge steps, with columns {@code clusterId1}, {@code clusterId2}, {@code distance} and
 *       {@code sizeOfMergedCluster}.
 * </ol>
 *
 * <p>Exactly one of {@code numClusters} and {@code distanceThreshold} must be set. The {@code ward}
 * linkage is only accepted together with the euclidean distance measure.
 */
public class AgglomerativeClustering
        implements AlgoOperator<AgglomerativeClustering>,
                AgglomerativeClusteringParams<AgglomerativeClustering> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /**
     * Clusters all rows of one window using the nearest-neighbour chain algorithm, emits every row
     * with its cluster label appended, and emits merge information to a side output.
     */
    private static class LocalAgglomerativeClusteringFunction<W extends Window>
            extends ProcessAllWindowFunction<Row, Row, W> implements ResultTypeQueryable<Row> {
        private final String featuresCol;
        private final String linkage;
        private final DistanceMeasure distanceMeasure;
        /** Target number of clusters; null when {@link #distanceThreshold} is used instead. */
        private final Integer numCluster;
        /** Maximum merge distance; null when {@link #numCluster} is used instead. */
        private final Double distanceThreshold;
        /** When true, merge info for every merge step is emitted, not only the applied ones. */
        private final boolean computeFullTree;

        private final OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag;
        private final RowTypeInfo outputTypeInfo;

        /**
         * Symmetric pairwise distance matrix stored in condensed (strict upper-triangular) form.
         * Callers must never pass {@code i == j}.
         */
        private static class DistanceMatrix {
            private final double[] distances;
            private final int n;

            /** Creates a matrix for ids {@code 0..n-1}. */
            DistanceMatrix(int n) {
                long numEntries = (long) n * (n - 1) / 2;
                // Guard against silent int overflow in both the array size and the index math.
                Preconditions.checkArgument(
                        numEntries <= Integer.MAX_VALUE,
                        "Too many cluster ids (%s) in one window: %s pairwise distances exceed the array capacity.",
                        n,
                        numEntries);
                this.distances = new double[(int) numEntries];
                this.n = n;
            }

            /** Maps an unordered pair (i != j) to its position in the condensed array. */
            private int index(int i, int j) {
                int lo = Math.min(i, j);
                int hi = Math.max(i, j);
                // (2n - 1 - lo) * lo is always even, so the division is exact.
                long rowStart = ((2L * n - 1 - lo) * lo) / 2;
                return (int) (rowStart + (hi - lo - 1));
            }

            void set(int i, int j, double distance) {
                distances[index(i, j)] = distance;
            }

            double get(int i, int j) {
                return distances[index(i, j)];
            }
        }

        /**
         * Union-find in which every union creates a new root label, allocated sequentially from
         * {@code numPoints}. Applying merges in order therefore reproduces their merged-cluster ids.
         */
        private static class UnionFind {
            private final int[] parent;
            private int nextLabel;

            UnionFind(int numPoints) {
                this.parent = new int[2 * numPoints - 1];
                Arrays.fill(parent, -1);
                this.nextLabel = numPoints;
            }

            /** Attaches two roots under a freshly allocated label. */
            void union(int root1, int root2) {
                parent[root1] = nextLabel;
                parent[root2] = nextLabel;
                nextLabel++;
            }

            /** Returns the root of {@code node}, re-pointing every node on the path to that root. */
            int find(int node) {
                int root = node;
                while (parent[root] != -1) {
                    root = parent[root];
                }
                int current = node;
                while (current != root) {
                    int next = parent[current];
                    parent[current] = root;
                    current = next;
                }
                return root;
            }
        }

        public LocalAgglomerativeClusteringFunction(
                String featuresCol,
                String linkage,
                String distanceMeasureName,
                Integer numCluster,
                Double distanceThreshold,
                boolean computeFullTree,
                OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag,
                RowTypeInfo outputTypeInfo) {
            this.featuresCol = featuresCol;
            this.linkage = linkage;
            this.numCluster = numCluster;
            this.distanceThreshold = distanceThreshold;
            this.computeFullTree = computeFullTree;
            this.mergeInfoOutputTag = mergeInfoOutputTag;
            this.distanceMeasure = DistanceMeasure.getInstance(distanceMeasureName);
            this.outputTypeInfo = outputTypeInfo;
        }

        @Override
        public void process(Context context, Iterable<Row> elements, Collector<Row> out) {
            List<Row> dataPoints = new ArrayList<>();
            for (Row row : elements) {
                dataPoints.add(row);
            }
            int numPoints = dataPoints.size();
            if (numPoints == 0) {
                return;
            }

            DistanceMatrix distanceMatrix = computeDistanceMatrix(dataPoints);
            HashSet<Integer> activeClusters = new HashSet<>(numPoints);
            for (int i = 0; i < numPoints; i++) {
                activeClusters.add(i);
            }

            // Merge tuples are (clusterId1, clusterId2, mergedClusterId, distance).
            List<Tuple4<Integer, Integer, Integer, Double>> merges =
                    nnChainCore(activeClusters, distanceMatrix, linkage);
            // List.sort is stable, so equal-distance merges keep their NN-chain order. That order
            // always creates a cluster before it is used.
            merges.sort(Comparator.comparingDouble(merge -> merge.f3));
            reOrderNnChain(merges);

            int numMergesToApply = countMergesToApply(merges, numPoints);
            int[] labels = label(merges.subList(0, numMergesToApply), numPoints);
            compactLabels(labels);
            for (int i = 0; i < numPoints; i++) {
                out.collect(Row.join(dataPoints.get(i), Row.of(labels[i])));
            }

            int numMergesToOutput = computeFullTree ? merges.size() : numMergesToApply;
            // Sizes must be indexed by the renumbered ids produced by reOrderNnChain.
            int[] clusterSizes = computeClusterSizes(merges, numPoints);
            for (int i = 0; i < numMergesToOutput; i++) {
                Tuple4<Integer, Integer, Integer, Double> merge = merges.get(i);
                int smallerId = Math.min(merge.f0, merge.f1);
                int largerId = Math.max(merge.f0, merge.f1);
                context.output(
                        mergeInfoOutputTag,
                        Tuple4.of(
                                smallerId,
                                largerId,
                                merge.f3,
                                clusterSizes[smallerId] + clusterSizes[largerId]));
            }
        }

        /**
         * Builds the pairwise distance matrix of the input points. It has room for
         * {@code 2 * numPoints - 1} ids so merged clusters can be stored too.
         */
        private DistanceMatrix computeDistanceMatrix(List<Row> dataPoints) {
            int numPoints = dataPoints.size();
            VectorWithNorm[] points = new VectorWithNorm[numPoints];
            for (int i = 0; i < numPoints; i++) {
                Vector features = dataPoints.get(i).getFieldAs(featuresCol);
                points[i] = new VectorWithNorm(features);
            }
            DistanceMatrix matrix = new DistanceMatrix(2 * numPoints - 1);
            for (int i = 0; i < numPoints; i++) {
                for (int j = i + 1; j < numPoints; j++) {
                    matrix.set(i, j, distanceMeasure.distance(points[i], points[j]));
                }
            }
            return matrix;
        }

        /**
         * Returns how many of the sorted merges to apply.
         *
         * <p>With a threshold, this counts the merges whose distance is at most the threshold.
         * Otherwise it is {@code numPoints - numCluster}, clamped to {@code [0, merges.size()]} so
         * windows with fewer points than requested clusters do not fail.
         */
        private int countMergesToApply(
                List<Tuple4<Integer, Integer, Integer, Double>> sortedMerges, int numPoints) {
            if (distanceThreshold != null) {
                int count = 0;
                for (Tuple4<Integer, Integer, Integer, Double> merge : sortedMerges) {
                    if (merge.f3 <= distanceThreshold) {
                        count++;
                    }
                }
                return count;
            }
            return Math.max(0, Math.min(sortedMerges.size(), numPoints - numCluster));
        }

        /**
         * Renumbers merged-cluster ids so that the k-th merge in list order creates cluster
         * {@code numPoints + k}. Updates {@code f0}, {@code f1} and {@code f2} in place.
         */
        private void reOrderNnChain(List<Tuple4<Integer, Integer, Integer, Double>> merges) {
            int nextClusterId = merges.size() + 1;
            Map<Integer, Integer> oldToNewId = new HashMap<>();
            for (Tuple4<Integer, Integer, Integer, Double> merge : merges) {
                merge.f0 = oldToNewId.getOrDefault(merge.f0, merge.f0);
                merge.f1 = oldToNewId.getOrDefault(merge.f1, merge.f1);
                oldToNewId.put(merge.f2, nextClusterId);
                merge.f2 = nextClusterId;
                nextClusterId++;
            }
        }

        /**
         * Computes the size of every cluster id (points and merged clusters) from merges already
         * renumbered by {@link #reOrderNnChain}.
         */
        private static int[] computeClusterSizes(
                List<Tuple4<Integer, Integer, Integer, Double>> merges, int numPoints) {
            int[] sizes = new int[numPoints + merges.size()];
            Arrays.fill(sizes, 0, numPoints, 1);
            for (Tuple4<Integer, Integer, Integer, Double> merge : merges) {
                sizes[merge.f2] = sizes[merge.f0] + sizes[merge.f1];
            }
            return sizes;
        }

        /** Applies {@code merges} with union-find and returns the root id of every point. */
        private int[] label(List<Tuple4<Integer, Integer, Integer, Double>> merges, int numPoints) {
            UnionFind unionFind = new UnionFind(numPoints);
            for (Tuple4<Integer, Integer, Integer, Double> merge : merges) {
                unionFind.union(unionFind.find(merge.f0), unionFind.find(merge.f1));
            }
            int[] roots = new int[numPoints];
            for (int i = 0; i < numPoints; i++) {
                roots[i] = unionFind.find(i);
            }
            return roots;
        }

        /**
         * Replaces root ids with consecutive labels 0, 1, 2, ... in order of first appearance.
         */
        private static void compactLabels(int[] labels) {
            Map<Integer, Integer> rootToLabel = new HashMap<>();
            for (int i = 0; i < labels.length; i++) {
                Integer compact = rootToLabel.get(labels[i]);
                if (compact == null) {
                    compact = rootToLabel.size();
                    rootToLabel.put(labels[i], compact);
                }
                labels[i] = compact;
            }
        }

        /**
         * Runs the nearest-neighbour chain algorithm.
         *
         * <p>Merged clusters get ids {@code numPoints, numPoints + 1, ...} in creation order. The
         * distance matrix is updated with the Lance-Williams formula of the given linkage.
         *
         * @param activeClusters the ids of the input points; consumed by this method
         * @return merge tuples {@code (clusterId1, clusterId2, mergedClusterId, distance)} in NN-chain
         *     order
         */
        private List<Tuple4<Integer, Integer, Integer, Double>> nnChainCore(
                HashSet<Integer> activeClusters, DistanceMatrix distanceMatrix, String linkage) {
            int numPoints = activeClusters.size();
            int nextClusterId = numPoints;
            List<Tuple4<Integer, Integer, Integer, Double>> merges = new ArrayList<>(numPoints);
            List<Integer> chain = new ArrayList<>();
            int[] clusterSizes = new int[2 * numPoints - 1];
            Arrays.fill(clusterSizes, 0, numPoints, 1);

            while (activeClusters.size() > 1) {
                int current;
                int previous;
                if (chain.size() <= 3) {
                    // Restart the chain from an arbitrary active cluster.
                    Iterator<Integer> it = activeClusters.iterator();
                    current = it.next();
                    chain.clear();
                    chain.add(current);
                    previous = it.next();
                } else {
                    // Resume from the element before the pair that was just merged.
                    int chainSize = chain.size();
                    current = chain.get(chainSize - 4);
                    previous = chain.get(chainSize - 3);
                    chain.remove(chainSize - 1);
                    chain.remove(chainSize - 2);
                    chain.remove(chainSize - 3);
                }

                // Grow the chain until its last two clusters are reciprocal nearest neighbours.
                while (chain.size() < 3 || chain.get(chain.size() - 3).intValue() != current) {
                    int nearest = -1;
                    double nearestDistance = Double.MAX_VALUE;
                    for (int candidate : activeClusters) {
                        if (candidate == current) {
                            continue;
                        }
                        double distance = distanceMatrix.get(current, candidate);
                        if (distance < nearestDistance) {
                            nearest = candidate;
                            nearestDistance = distance;
                        }
                    }
                    // Prefer the previous chain element on ties so the chain terminates.
                    if (nearestDistance == distanceMatrix.get(current, previous)
                            && activeClusters.contains(previous)) {
                        nearest = previous;
                    }
                    if (nearest == -1) {
                        throw new IllegalStateException(
                                "Failed to find the nearest cluster of cluster "
                                        + current
                                        + "; the computed distances may contain NaN or overflow values.");
                    }
                    previous = current;
                    current = nearest;
                    chain.add(current);
                }

                int mergedId = nextClusterId++;
                merges.add(
                        Tuple4.of(
                                current, previous, mergedId, distanceMatrix.get(current, previous)));
                activeClusters.remove(current);
                activeClusters.remove(previous);
                clusterSizes[mergedId] = clusterSizes[current] + clusterSizes[previous];
                for (int other : activeClusters) {
                    distanceMatrix.set(
                            other,
                            mergedId,
                            computeClusterDistances(
                                    distanceMatrix.get(current, other),
                                    distanceMatrix.get(previous, other),
                                    distanceMatrix.get(current, previous),
                                    clusterSizes[current],
                                    clusterSizes[previous],
                                    clusterSizes[other],
                                    linkage));
                }
                activeClusters.add(mergedId);
            }
            return merges;
        }

        /**
         * Computes the distance between the cluster formed by merging i and j and another cluster
         * k, using the Lance-Williams update for the given linkage.
         *
         * @param distIk distance between clusters i and k
         * @param distJk distance between clusters j and k
         * @param distIj distance between clusters i and j
         * @param sizeI size of cluster i
         * @param sizeJ size of cluster j
         * @param sizeK size of cluster k
         * @throws UnsupportedOperationException if the linkage is unknown
         */
        private static double computeClusterDistances(
                double distIk,
                double distJk,
                double distIj,
                int sizeI,
                int sizeJ,
                int sizeK,
                String linkage) {
            switch (linkage) {
                case AgglomerativeClusteringParams.LINKAGE_SINGLE:
                    return Math.min(distIk, distJk);
                case AgglomerativeClusteringParams.LINKAGE_COMPLETE:
                    return Math.max(distIk, distJk);
                case AgglomerativeClusteringParams.LINKAGE_AVERAGE:
                    return (sizeI * distIk + sizeJ * distJk) / (sizeI + sizeJ);
                case AgglomerativeClusteringParams.LINKAGE_WARD:
                    return Math.sqrt(
                            ((sizeI + sizeK) * distIk * distIk
                                            + (sizeJ + sizeK) * distJk * distJk
                                            - sizeK * distIj * distIj)
                                    / (sizeI + sizeJ + sizeK));
                default:
                    throw new UnsupportedOperationException(
                            "Unsupported "
                                    + AgglomerativeClusteringParams.LINKAGE
                                    + " type: "
                                    + linkage
                                    + ".");
            }
        }

        @Override
        public TypeInformation<Row> getProducedType() {
            return outputTypeInfo;
        }
    }

    public AgglomerativeClustering() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Clusters the single input table window by window.
     *
     * @return the input rows with a prediction column appended, and the merge-info table
     * @throws IllegalArgumentException if not exactly one input is given, if both or neither of
     *     numClusters/distanceThreshold are set, or if ward linkage is combined with a non-euclidean
     *     distance measure
     */
    @Override
    public Table[] transform(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        Integer numClusters = getNumClusters();
        Double distanceThreshold = getDistanceThreshold();
        Preconditions.checkArgument(
                (numClusters == null) != (distanceThreshold == null),
                "One of param numCluster and distanceThreshold should be null.");
        if (getLinkage().equals(AgglomerativeClusteringParams.LINKAGE_WARD)) {
            String distanceMeasure = getDistanceMeasure();
            Preconditions.checkArgument(
                    distanceMeasure.equals(EuclideanDistanceMeasure.NAME),
                    distanceMeasure
                            + " was provided as distance measure while linkage was ward. Ward only works with euclidean.");
        }

        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        DataStream<Row> inputData = tEnv.toDataStream(inputs[0]);
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
                        ArrayUtils.addAll(
                                inputTypeInfo.getFieldTypes(),
                                new TypeInformation<?>[] {Types.INT}),
                        ArrayUtils.addAll(
                                inputTypeInfo.getFieldNames(), new String[] {getPredictionCol()}));

        // Anonymous subclass so Flink can extract the generic element type of the side output.
        OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag =
                new OutputTag<Tuple4<Integer, Integer, Double, Integer>>("MERGE_INFO") {};

        SingleOutputStreamOperator<Row> output =
                DataStreamUtils.windowAllAndProcess(
                        inputData,
                        getWindows(),
                        new LocalAgglomerativeClusteringFunction<>(
                                getFeaturesCol(),
                                getLinkage(),
                                getDistanceMeasure(),
                                numClusters,
                                distanceThreshold,
                                getComputeFullTree(),
                                mergeInfoOutputTag,
                                outputTypeInfo));

        Table outputTable =
                tEnv.fromDataStream(
                        output,
                        Schema.newBuilder()
                                .fromResolvedSchema(inputs[0].getResolvedSchema())
                                .column(getPredictionCol(), DataTypes.INT())
                                .build());

        DataStream<Tuple4<Integer, Integer, Double, Integer>> mergeInfo =
                output.getSideOutput(mergeInfoOutputTag);
        mergeInfo.getTransformation().setParallelism(1);
        Table mergeInfoTable =
                tEnv.fromDataStream(mergeInfo)
                        .as("clusterId1", "clusterId2", "distance", "sizeOfMergedCluster");

        return new Table[] {outputTable, mergeInfoTable};
    }

    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    public static AgglomerativeClustering load(StreamTableEnvironment tEnv, String path)
            throws IOException {
        return (AgglomerativeClustering) ReadWriteUtils.loadStageParam(path);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }
}
