package com.clust4j.algo;

import com.clust4j.NamedEntity;
import com.clust4j.kernel.CircularKernel;
import com.clust4j.kernel.LogKernel;
import com.clust4j.log.Log;
import com.clust4j.log.LogTimer;
import com.clust4j.metrics.pairwise.Distance;
import com.clust4j.metrics.pairwise.GeometricallySeparable;
import com.clust4j.metrics.scoring.SupervisedMetric;
import com.clust4j.metrics.scoring.UnsupervisedMetric;
import com.clust4j.utils.MatUtils;
import com.clust4j.utils.SimpleHeap;
import com.clust4j.utils.VecUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Objects;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.FastMath;

/* loaded from: input_file:com/clust4j/algo/HierarchicalAgglomerative.class */
public final class HierarchicalAgglomerative extends AbstractPartitionalClusterer implements UnsupervisedClassifier {
    private static final long serialVersionUID = 7563413590708853735L;
    // Default linkage constant (WARD); presumably the default used by HierarchicalAgglomerativeParameters — confirm.
    public static final Linkage DEF_LINKAGE = Linkage.WARD;
    // Metric classes rejected by the AVERAGE and COMPLETE linkages; populated in the static initializer.
    static final HashSet<Class<? extends GeometricallySeparable>> comp_avg_unsupported = new HashSet<>();
    // Linkage strategy taken from the parameters at construction time.
    final Linkage linkage;
    // Number of rows in the training matrix.
    private final int m;
    // Cluster labels produced by fit(); null before fit() has run.
    private volatile int[] labels;
    // Condensed pairwise distance matrix; built during fit() and set back to null after a successful fit.
    private volatile EfficientDistanceMatrix dist_vec;
    // Dendrogram built by fit() via linkage.buildTree(this).
    volatile HierarchicalDendrogram tree;
    // Requested number of clusters (copied from k in the constructor).
    private volatile int num_clusters;

    /**
     * Average-linkage dendrogram: the distance from another cluster to a newly merged cluster is the
     * size-weighted mean of its distances to the two clusters that were merged.
     */
    public class AverageLinkageTree extends LinkageTree {
        private static final long serialVersionUID = 5891407873391751152L;

        /** Creates the tree; the implicit super constructor prepares the shared distance matrix. */
        public AverageLinkageTree() {
            // LinkageTree() is invoked implicitly.
        }

        /**
         * @param dXi   distance from the other cluster to merged cluster x
         * @param dYi   distance from the other cluster to merged cluster y
         * @param dXy   distance between x and y (unused by average linkage)
         * @param sizeX size of cluster x
         * @param sizeY size of cluster y
         * @param sizeI size of the other cluster (unused by average linkage)
         * @return size-weighted mean of {@code dXi} and {@code dYi}
         */
        @Override
        protected double getDist(double dXi, double dYi, double dXy, int sizeX, int sizeY, int sizeI) {
            final double weightedSum = sizeX * dXi + sizeY * dYi;
            return weightedSum / (sizeX + sizeY);
        }

        @Override
        public String getName() {
            return "Avg Linkage Tree";
        }
    }

    /**
     * Complete-linkage dendrogram: the distance from another cluster to a newly merged cluster is the
     * larger of its distances to the two clusters that were merged.
     */
    public class CompleteLinkageTree extends LinkageTree {
        private static final long serialVersionUID = 7407993870975009576L;

        /** Creates the tree; the implicit super constructor prepares the shared distance matrix. */
        public CompleteLinkageTree() {
            // LinkageTree() is invoked implicitly.
        }

        /**
         * @param dXi   distance from the other cluster to merged cluster x
         * @param dYi   distance from the other cluster to merged cluster y
         * @param dXy   distance between x and y (unused)
         * @param sizeX size of cluster x (unused)
         * @param sizeY size of cluster y (unused)
         * @param sizeI size of the other cluster (unused)
         * @return {@code FastMath.max(dXi, dYi)}
         */
        @Override
        protected double getDist(double dXi, double dYi, double dXy, int sizeX, int sizeY, int sizeI) {
            final double farthest = FastMath.max(dXi, dYi);
            return farthest;
        }

        @Override
        public String getName() {
            return "Complete Linkage Tree";
        }
    }

    /**
     * Pairwise distances between the rows of a matrix, stored as a condensed upper triangle.
     *
     * <p>For {@code n} rows the backing array has {@code n * (n - 1) / 2} entries, laid out row by
     * row: the pair {@code (i, j)} with {@code i < j} lives at index
     * {@code n * i - i * (i + 1) / 2 + (j - i - 1)}. Index arithmetic is carried out in {@code long}
     * so that large {@code n} cannot silently overflow {@code int}.
     */
    public static class EfficientDistanceMatrix implements Serializable {
        private static final long serialVersionUID = -7329893729526766664L;
        protected final double[] dists;

        /**
         * @param realMatrix             rows to measure
         * @param geometricallySeparable metric used between rows
         * @param z                      when true use {@code getPartialDistance}, otherwise {@code getDistance}
         */
        EfficientDistanceMatrix(RealMatrix realMatrix, GeometricallySeparable geometricallySeparable, boolean z) {
            this.dists = build(realMatrix.getData(), geometricallySeparable, z);
        }

        /**
         * Builds the condensed distance vector for the rows of {@code dArr}.
         *
         * @throws IllegalArgumentException if the condensed matrix would exceed the maximum array length
         */
        static double[] build(double[][] dArr, GeometricallySeparable geometricallySeparable, boolean z) {
            final int n = dArr.length;
            final double[] condensed = new double[condensedSize(n)];
            int idx = 0;
            for (int i = 0; i < n - 1; i++) {
                for (int j = i + 1; j < n; j++) {
                    condensed[idx++] = z
                        ? geometricallySeparable.getPartialDistance(dArr[i], dArr[j])
                        : geometricallySeparable.getDistance(dArr[i], dArr[j]);
                }
            }
            return condensed;
        }

        /** Returns n(n-1)/2, computed without int overflow, or fails if it cannot back an array. */
        private static int condensedSize(int n) {
            final long size = (long) n * (n - 1) / 2;
            if (size > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("too many rows (" + n + ") for a condensed distance matrix: "
                    + size + " entries exceed the maximum array length");
            }
            return (int) size;
        }

        /**
         * Maps the unordered pair {@code (i2, i3)} of an {@code i}-row matrix to its condensed index.
         *
         * @throws IllegalArgumentException if {@code i2 == i3}
         */
        static int getIndexFromFlattenedVec(int i, int i2, int i3) {
            if (i2 == i3) {
                throw new IllegalArgumentException(i2 + ", " + i3 + "; i should not equal j");
            }
            // long arithmetic: n * row alone overflows int once n exceeds ~46341.
            final long row = Math.min(i2, i3);
            final long col = Math.max(i2, i3);
            return (int) ((long) i * row - row * (row + 1) / 2 + (col - row - 1));
        }

        /** Returns the stored distance between rows {@code i2} and {@code i3} of an {@code i}-row matrix. */
        double navigate(int i, int i2, int i3) {
            return this.dists[getIndexFromFlattenedVec(i, i2, i3)];
        }
    }

    /**
     * Base class for the merge-tree builders.
     *
     * <p>Runs a naive agglomerative scheme over the enclosing model's condensed distance matrix. Each of
     * the {@code n - 1} iterations scans all live pairs for the minimum distance, records the merge, and
     * updates distances to the merged cluster via {@link #getDist}. That is O(n^3) time overall.
     * Subclasses define the linkage-specific distance update.
     */
    public abstract class HierarchicalDendrogram implements Serializable, NamedEntity {
        private static final long serialVersionUID = 5295537901834851676L;
        public final HierarchicalAgglomerative ref;
        public final GeometricallySeparable dist;

        HierarchicalDendrogram() {
            this.ref = HierarchicalAgglomerative.this;
            this.dist = this.ref.getSeparabilityMetric();
            // fit() normally builds the matrix first; build it here if the tree is created some other way.
            if (null == HierarchicalAgglomerative.this.dist_vec) {
                HierarchicalAgglomerative.this.dist_vec = new EfficientDistanceMatrix(HierarchicalAgglomerative.this.data, this.dist, true);
            }
        }

        /**
         * Builds the full merge record and returns only its two child-id columns.
         *
         * @return an {@code (m - 1) x 2} matrix; row k holds the ids merged to form node {@code m + k}
         */
        double[][] linkage() {
            final int n = HierarchicalAgglomerative.this.m;
            // Columns: [smaller child id, larger child id, merge distance, merged cluster size].
            final double[][] merges = new double[n - 1][4];
            link(HierarchicalAgglomerative.this.dist_vec, merges, n);
            return MatUtils.getColumns(merges, new int[]{0, 1});
        }

        /**
         * Performs the {@code n - 1} merges, mutating {@code dists} in place and filling {@code z}.
         *
         * @throws IllegalStateException if an iteration finds no remaining pair with a distance below
         *                               +infinity (all remaining distances infinite or NaN)
         */
        private void link(EfficientDistanceMatrix dists, double[][] z, int n) {
            this.ref.info("initializing node mappings (" + getClass().getName().split("\\$")[1] + ")");

            // idMap[row] is the cluster id currently held by that row, or -1 once the row is retired.
            final int[] idMap = new int[n];
            for (int row = 0; row < n; row++) {
                idMap[row] = row;
            }

            final LogTimer totalTimer = new LogTimer();
            final int progressInterval = n / 10;
            int progressTick = 1;

            for (int k = 0; k < n - 1; k++) {
                if (progressInterval > 0 && k % progressInterval == 0) {
                    this.ref.info("node mapping progress - " + (10 * progressTick++) + "%. Total link time: " + totalTimer.toString());
                }

                double currentMin = Double.POSITIVE_INFINITY;
                // Reset every iteration: a failed search must not reuse the previous iteration's pair.
                int x = -1;
                int y = -1;
                final LogTimer iterTimer = new LogTimer();
                for (int i = 0; i < n - 1; i++) {
                    if (idMap[i] == -1) {
                        continue;
                    }
                    final int rowStart = EfficientDistanceMatrix.getIndexFromFlattenedVec(n, i, i + 1);
                    for (int j = 0; j < n - i - 1; j++) {
                        final double candidate = dists.dists[rowStart + j];
                        if (candidate < currentMin) {
                            currentMin = candidate;
                            x = i;
                            y = i + j + 1;
                        }
                    }
                }
                if (x == -1) {
                    throw new IllegalStateException("no pairwise distance below +Infinity left to merge at link iteration "
                        + k + " (remaining distances are infinite or NaN)");
                }

                final int idX = idMap[x];
                final int idY = idMap[y];
                final int sizeX = idX < n ? 1 : (int) z[idX - n][3];
                final int sizeY = idY < n ? 1 : (int) z[idY - n][3];
                z[k][0] = FastMath.min(idX, idY);
                z[k][1] = FastMath.max(idY, idX);
                z[k][2] = currentMin;
                z[k][3] = sizeX + sizeY;
                // Row x is retired; row y now represents the new cluster n + k.
                idMap[x] = -1;
                idMap[y] = n + k;

                int skipped = 0;
                for (int i = 0; i < n; i++) {
                    final int idI = idMap[i];
                    if (idI == -1 || idI == n + k) {
                        skipped++;
                        continue;
                    }
                    final int sizeI = idI < n ? 1 : (int) z[idI - n][3];
                    final int iy = EfficientDistanceMatrix.getIndexFromFlattenedVec(n, i, y);
                    dists.dists[iy] = getDist(dists.navigate(n, i, x), dists.dists[iy], currentMin, sizeX, sizeY, sizeI);
                    // Entries (i, x) with i > x live in retired row x and are never scanned again;
                    // entries with i < x sit in live rows and must be masked out explicitly.
                    if (i < x) {
                        dists.dists[EfficientDistanceMatrix.getIndexFromFlattenedVec(n, i, x)] = Double.POSITIVE_INFINITY;
                    }
                }
                HierarchicalAgglomerative.this.fitSummary.add(new Object[]{Integer.valueOf(k), Double.valueOf(currentMin), Integer.valueOf(skipped), iterTimer.formatTime(), totalTimer.formatTime(), totalTimer.wallMsg()});
            }
        }

        /**
         * Linkage-specific distance from some other cluster i to the cluster formed by merging x and y.
         *
         * @param d  distance from i to x
         * @param d2 distance from i to y
         * @param d3 distance between x and y
         * @param i  size of x
         * @param i2 size of y
         * @param i3 size of i
         * @return the updated distance from i to the merged cluster
         */
        protected abstract double getDist(double d, double d2, double d3, int i, int i2, int i3);
    }

    /**
     * Available linkage criteria. Each constant builds its dendrogram bound to a given model and
     * decides which metrics it accepts.
     */
    public enum Linkage implements Serializable, LinkageTreeBuilder {
        /** Average linkage; rejects the metric classes listed in {@code comp_avg_unsupported}. */
        AVERAGE {
            @Override // com.clust4j.algo.HierarchicalAgglomerative.LinkageTreeBuilder
            public AverageLinkageTree buildTree(HierarchicalAgglomerative hierarchicalAgglomerative) {
                Objects.requireNonNull(hierarchicalAgglomerative);
                // The tree is an inner class: it must be created against the owning model instance.
                return hierarchicalAgglomerative.new AverageLinkageTree();
            }

            @Override // com.clust4j.algo.MetricValidator
            public boolean isValidMetric(GeometricallySeparable geometricallySeparable) {
                return !HierarchicalAgglomerative.comp_avg_unsupported.contains(geometricallySeparable.getClass());
            }
        },
        /** Complete linkage; rejects the metric classes listed in {@code comp_avg_unsupported}. */
        COMPLETE {
            @Override // com.clust4j.algo.HierarchicalAgglomerative.LinkageTreeBuilder
            public CompleteLinkageTree buildTree(HierarchicalAgglomerative hierarchicalAgglomerative) {
                Objects.requireNonNull(hierarchicalAgglomerative);
                return hierarchicalAgglomerative.new CompleteLinkageTree();
            }

            @Override // com.clust4j.algo.MetricValidator
            public boolean isValidMetric(GeometricallySeparable geometricallySeparable) {
                return !HierarchicalAgglomerative.comp_avg_unsupported.contains(geometricallySeparable.getClass());
            }
        },
        /** Ward linkage; only accepts {@code Distance.EUCLIDEAN}. */
        WARD {
            @Override // com.clust4j.algo.HierarchicalAgglomerative.LinkageTreeBuilder
            public WardTree buildTree(HierarchicalAgglomerative hierarchicalAgglomerative) {
                Objects.requireNonNull(hierarchicalAgglomerative);
                return hierarchicalAgglomerative.new WardTree();
            }

            @Override // com.clust4j.algo.MetricValidator
            public boolean isValidMetric(GeometricallySeparable geometricallySeparable) {
                return geometricallySeparable.equals(Distance.EUCLIDEAN);
            }
        }
    }

    /** Shared parent of the average- and complete-linkage dendrograms. */
    abstract class LinkageTree extends HierarchicalDendrogram {
        private static final long serialVersionUID = -252115690411913842L;

        /** Delegates (implicitly) to {@code HierarchicalDendrogram()}. */
        public LinkageTree() {
            // no additional state
        }
    }

    /**
     * Strategy that validates metrics (via {@link MetricValidator}) and builds the dendrogram for a
     * model.
     */
    public interface LinkageTreeBuilder extends MetricValidator {
        /**
         * Builds the dendrogram bound to {@code model}.
         *
         * @param model the owning clusterer
         * @return a new dendrogram for {@code model}
         */
        HierarchicalDendrogram buildTree(HierarchicalAgglomerative model);
    }

    /**
     * Ward dendrogram. The distance to a merged cluster uses the size-weighted recurrence
     * sqrt(((nI + nX) dXi^2 + (nI + nY) dYi^2 - nI dXy^2) / (nX + nY + nI)).
     */
    public class WardTree extends HierarchicalDendrogram {
        private static final long serialVersionUID = -2336170779406847047L;

        /** Creates the tree; the implicit super constructor prepares the shared distance matrix. */
        public WardTree() {
            // HierarchicalDendrogram() is invoked implicitly.
        }

        /**
         * @param dXi   distance from the other cluster to merged cluster x
         * @param dYi   distance from the other cluster to merged cluster y
         * @param dXy   distance between x and y
         * @param sizeX size of cluster x
         * @param sizeY size of cluster y
         * @param sizeI size of the other cluster
         * @return the Ward-updated distance (NaN if the radicand is negative)
         */
        @Override
        protected double getDist(double dXi, double dYi, double dXy, int sizeX, int sizeY, int sizeI) {
            final double inv = 1.0d / (sizeX + sizeY + sizeI);
            final double fromX = (sizeI + sizeX) * inv * dXi * dXi;
            final double fromY = (sizeI + sizeY) * inv * dYi * dYi;
            final double between = sizeI * inv * dXy * dXy;
            return FastMath.sqrt(fromX + fromY - between);
        }

        @Override
        public String getName() {
            return "Ward Tree";
        }
    }

    /**
     * Delegates metric validation to the configured linkage.
     *
     * @param geometricallySeparable metric to check
     * @return whether the current linkage accepts the metric
     */
    @Override
    public final boolean isValidMetric(GeometricallySeparable geometricallySeparable) {
        final Linkage current = this.linkage;
        return current.isValidMetric(geometricallySeparable);
    }

    /**
     * Creates a clusterer over {@code realMatrix} using a freshly constructed
     * {@code HierarchicalAgglomerativeParameters}.
     */
    protected HierarchicalAgglomerative(RealMatrix realMatrix) {
        this(realMatrix,
            new HierarchicalAgglomerativeParameters());
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public HierarchicalAgglomerative(RealMatrix realMatrix, HierarchicalAgglomerativeParameters hierarchicalAgglomerativeParameters) {
        super(realMatrix, hierarchicalAgglomerativeParameters, hierarchicalAgglomerativeParameters.getNumClusters());
        this.labels = null;
        this.dist_vec = null;
        this.tree = null;
        this.linkage = hierarchicalAgglomerativeParameters.getLinkage();
        if (!isValidMetric(this.dist_metric)) {
            warn(this.dist_metric.getName() + " is invalid for " + this.linkage + ". Falling back to default Euclidean dist");
            setSeparabilityMetric(DEF_DIST);
        }
        this.m = realMatrix.getRowDimension();
        this.num_clusters = this.k;
        logModelSummary();
    }

    /**
     * Two-row summary table: a header row and the corresponding model values.
     *
     * <p>The rows are passed as an {@code Object[][]}. The decompiled {@code new Object[]{...}} lost
     * that element type (see the original JADX type-inference warning).
     */
    @Override // com.clust4j.algo.AbstractClusterer
    protected final ModelSummary modelSummary() {
        final Object[] header = new Object[]{"Num Rows", "Num Cols", "Metric", "Linkage", "Allow Par.", "Num. Clusters"};
        final Object[] values = new Object[]{
            Integer.valueOf(this.data.getRowDimension()),
            Integer.valueOf(this.data.getColumnDimension()),
            getSeparabilityMetric(),
            this.linkage,
            Boolean.valueOf(this.parallel),
            Integer.valueOf(this.num_clusters)
        };
        return new ModelSummary(new Object[][]{header, values});
    }

    /** @return the display name of this algorithm */
    @Override
    public String getName() {
        final String name = "Agglomerative";
        return name;
    }

    /** @return the linkage chosen at construction time */
    public Linkage getLinkage() {
        final Linkage current = this.linkage;
        return current;
    }

    /**
     * Fits the model: builds the distance matrix and dendrogram, then cuts the tree into
     * {@code num_clusters} clusters.
     *
     * <p>Labels are published only after the fit has succeeded, so a failed fit leaves the model
     * unfitted and a later call retries instead of returning placeholder labels. The distance matrix
     * is released whether or not fitting succeeds.
     *
     * @return this model
     */
    @Override // com.clust4j.algo.AbstractClusterer, com.clust4j.algo.BaseModel
    public HierarchicalAgglomerative fit() {
        synchronized (this.fitLock) {
            if (null != this.labels) {
                return this; // already fitted
            }

            final LogTimer timer = new LogTimer();

            if (1 == this.k) {
                // A single cluster: every row gets label 0 and no linkage iterations run.
                this.labels = new int[this.m];
                this.fitSummary.add(new Object[]{Integer.valueOf(0), Double.valueOf(Double.NaN), Integer.valueOf(0), timer.formatTime(), timer.formatTime(), timer.wallMsg()});
                warn("converged immediately due to " + (this.singular_value ? "singular nature of input matrix" : "k = 1"));
                sayBye(timer);
                return this;
            }

            try {
                this.dist_vec = new EfficientDistanceMatrix(this.data, getSeparabilityMetric(), true);
                info("computed distance matrix in " + timer.toString());

                final LogTimer treeTimer = new LogTimer();
                this.tree = this.linkage.buildTree(this);
                info("constructed " + this.tree.getName() + " HierarchicalDendrogram in " + treeTimer.toString());

                final int[] rawLabels = hcCut(this.num_clusters, this.tree.linkage(), this.m);
                this.labels = new SafeLabelEncoder(rawLabels).fit().getEncodedLabels();
                sayBye(timer);
                return this;
            } finally {
                // The condensed matrix is O(m^2) memory; drop it even when fitting fails.
                this.dist_vec = null;
            }
        }
    }

    /**
     * Cuts a merge tree into {@code i} flat clusters and returns one label per leaf.
     *
     * <p>Starting from the root, a frontier node is repeatedly replaced by its two children until
     * {@code i} nodes remain. Every leaf under the j-th frontier node (in heap iteration order) then
     * receives label j.
     *
     * @param i    number of clusters to produce
     * @param dArr children matrix: row r holds the two node ids merged to create node {@code i2 + r}
     * @param i2   number of leaves (original rows)
     * @return labels of length {@code i2}
     * @throws InternalError if {@code i > i2}
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // SimpleHeap is used as a raw type, as in the original
    static int[] hcCut(int i, double[][] dArr, int i2) {
        if (i > i2) {
            throw new InternalError(i + " > " + i2);
        }
        // With at most one cluster every leaf shares label 0. Returning early also avoids indexing
        // dArr[-1] when the children matrix is empty (a single leaf).
        if (i <= 1) {
            return new int[i2];
        }

        // Node ids are stored negated. Assuming SimpleHeap is a min-heap (like Python's heapq), get(0)
        // is then the largest id, i.e. the most recently created node — confirm against SimpleHeap.
        // The root id is one more than the largest child of the final merge.
        final SimpleHeap frontier = new SimpleHeap(Integer.valueOf(-(((int) VecUtils.max(dArr[dArr.length - 1])) + 1)));
        for (int split = 0; split < i - 1; split++) {
            int row = (-((Integer) frontier.get(0)).intValue()) - i2;
            if (row < 0) {
                // Mirrors Python negative indexing from the ported algorithm.
                row = dArr.length + row;
            }
            final double[] children = dArr[row];
            frontier.push(Integer.valueOf(-((int) children[0])));
            frontier.pushPop(Integer.valueOf(-((int) children[1])));
        }

        final int[] labels = new int[i2];
        int label = 0;
        final Iterator<?> it = frontier.iterator();
        while (it.hasNext()) {
            final int node = -((Integer) it.next()).intValue();
            for (Integer leaf : hcGetDescendents(node, dArr, i2)) {
                labels[leaf.intValue()] = label;
            }
            label++;
        }
        return labels;
    }

    /**
     * Collects the leaf ids under node {@code i} of the merge tree.
     *
     * @param i    node id; ids below {@code i2} are leaves
     * @param dArr children matrix: row r holds the two children of node {@code i2 + r}
     * @param i2   number of leaves
     * @return the leaves descending from {@code i} ({@code {i}} itself when it is a leaf)
     */
    static Integer[] hcGetDescendents(int i, double[][] dArr, int i2) {
        if (i < i2) {
            return new Integer[]{Integer.valueOf(i)};
        }

        final SimpleHeap pending = new SimpleHeap(Integer.valueOf(i));
        final ArrayList<Integer> leaves = new ArrayList<>();
        // Count of nodes still waiting in `pending`: an internal node is replaced by two children (+1),
        // and a leaf is consumed (-1).
        int outstanding = 1;
        while (outstanding > 0) {
            final int node = ((Integer) pending.popInPlace()).intValue();
            if (node >= i2) {
                for (double child : dArr[node - i2]) {
                    pending.add(Integer.valueOf((int) child));
                }
                outstanding++;
            } else {
                leaves.add(Integer.valueOf(node));
                outstanding--;
            }
        }
        return leaves.toArray(new Integer[0]);
    }

    /**
     * Returns the fitted labels through {@code handleLabelCopy}, presumably a defensive copy that
     * also rejects an unfitted model — confirm in the superclass.
     */
    @Override
    public int[] getLabels() {
        final int[] current = this.labels;
        return super.handleLabelCopy(current);
    }

    /** @return the logging tag for this algorithm */
    @Override
    public Log.Tag.Algo getLoggerTag() {
        final Log.Tag.Algo tag = Log.Tag.Algo.AGGLOMERATIVE;
        return tag;
    }

    /**
     * Column headers for the rows appended to {@code fitSummary}: iteration index, that iteration's
     * minimum merge distance, number of rows skipped during the update, and timings.
     */
    @Override
    protected final Object[] getModelFitSummaryHeaders() {
        final Object[] headers = {
            "Link Iter. #", "Iter. Min", "Continues",
            "Iter. Time", "Total Time", "Wall"
        };
        return headers;
    }

    /**
     * Scores the fitted labels against ground-truth labels with {@code SupervisedMetric.INDEX_AFFINITY}.
     *
     * @param iArr ground-truth labels
     * @return the index-affinity score
     */
    @Override
    public double indexAffinityScore(int[] iArr) {
        final SupervisedMetric metric = SupervisedMetric.INDEX_AFFINITY;
        final int[] truth = iArr;
        final int[] predicted = getLabels();
        return metric.evaluate(truth, predicted);
    }

    /** @return the silhouette score of the fitted labels on this model's data */
    @Override
    public double silhouetteScore() {
        final int[] predicted = getLabels();
        return UnsupervisedMetric.SILHOUETTE.evaluate(this, predicted);
    }

    /**
     * Assigns labels to new rows.
     *
     * <p>With a single cluster every row gets the one fitted label. Otherwise a NearestCentroid model is
     * trained on the fitted data and labels (same metric, non-verbose) and used to predict.
     *
     * @param realMatrix rows to label; must have the training column count
     * @return one label per row
     * @throws DimensionMismatchException if the column counts differ
     */
    @Override
    public int[] predict(RealMatrix realMatrix) {
        final int[] fitted = getLabels();
        final int nRows = realMatrix.getRowDimension();
        final int nCols = realMatrix.getColumnDimension();
        final int expectedCols = this.data.getColumnDimension();
        if (nCols != expectedCols) {
            throw new DimensionMismatchException(nCols, expectedCols);
        }

        if (1 == this.num_clusters) {
            return VecUtils.repInt(fitted[0], nRows);
        }
        return new NearestCentroidParameters()
            .setMetric(this.dist_metric)
            .setVerbose(false)
            .fitNewModel(getData(), fitted)
            .predict(realMatrix);
    }

    // Kernel classes that AVERAGE and COMPLETE linkage refuse (see Linkage.isValidMetric).
    static {
        comp_avg_unsupported.add(LogKernel.class);
        comp_avg_unsupported.add(CircularKernel.class);
    }
}
