package com.clust4j.algo;

import com.clust4j.algo.BaseNeighborsModel;
import com.clust4j.except.ModelNotFitException;
import com.clust4j.log.Log;
import com.clust4j.log.LogTimer;
import com.clust4j.utils.MatUtils;
import com.clust4j.utils.VecUtils;
import java.util.concurrent.RejectedExecutionException;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.FastMath;
import org.apache.pdfbox.contentstream.operator.OperatorName;

/* loaded from: input_file:com/clust4j/algo/NearestNeighbors.class */
/**
 * Unsupervised k-nearest-neighbors model.
 *
 * <p>Fitting queries the neighbor-search {@code tree} held by the base model for the
 * {@code k} nearest neighbors of every training row, excluding each row's own index.
 * After fitting, {@link #getNeighbors(RealMatrix)} and its overloads query the same tree
 * for arbitrary data.
 *
 * <p>Note: when {@code k == m} (number of training rows), only {@code m} neighbors can be
 * queried, so the fitted neighborhood has {@code m - 1} columns rather than {@code k}.
 */
public final class NearestNeighbors extends BaseNeighborsModel {
    private static final long serialVersionUID = 8306843374522289973L;

    /** Fork/join task that runs a k-neighbors query over chunks of the input rows. */
    public static class ParallelNNSearch extends BaseNeighborsModel.ParallelNeighborhoodSearch {
        private static final long serialVersionUID = -1600812794470325448L;

        /** Number of neighbors requested per row. */
        final int k;

        public ParallelNNSearch(double[][] X, NearestNeighbors model, int k) {
            super(X, model);
            this.k = k;
        }

        /** Creates a sub-task over {@code [lo, hi)} that inherits the parent's {@code k}. */
        public ParallelNNSearch(ParallelNNSearch parent, int lo, int hi) {
            super(parent, lo, hi);
            this.k = parent.k;
        }

        /**
         * Runs the search on the shared thread pool.
         *
         * @throws RejectedExecutionException if the pool refuses the task
         */
        static Neighborhood doAll(double[][] X, NearestNeighbors model, int k) {
            return (Neighborhood) getThreadPool().invoke(new ParallelNNSearch(X, model, k));
        }

        @Override
        public ParallelNNSearch newInstance(
                BaseNeighborsModel.ParallelNeighborhoodSearch parent, int lo, int hi) {
            return new ParallelNNSearch((ParallelNNSearch) parent, lo, hi);
        }

        @Override
        Neighborhood query(NearestNeighborHeapSearch tree, double[][] X) {
            return tree.query(X, this.k, false, true);
        }
    }

    protected NearestNeighbors(RealMatrix data) {
        this(data, 5);
    }

    protected NearestNeighbors(AbstractClusterer caller) {
        this(caller, 5);
    }

    protected NearestNeighbors(RealMatrix data, int k) {
        this(data, new NearestNeighborsParameters(k));
    }

    protected NearestNeighbors(AbstractClusterer caller, int k) {
        this(caller, new NearestNeighborsParameters(k));
    }

    public NearestNeighbors(RealMatrix data, NearestNeighborsParameters planner) {
        this(data, planner, false);
    }

    /**
     * Builds the model from another clusterer.
     *
     * @throws IllegalArgumentException if {@code k < 1} or {@code k > m}
     */
    public NearestNeighbors(AbstractClusterer caller, NearestNeighborsParameters planner) {
        super(caller, planner);
        validateK(this.kNeighbors.intValue(), this.m);
        logModelSummary();
    }

    /**
     * Builds the model from a data matrix; {@code as} is forwarded unchanged to the base
     * constructor.
     *
     * @throws IllegalArgumentException if {@code k < 1} or {@code k > m}
     */
    public NearestNeighbors(RealMatrix data, NearestNeighborsParameters planner, boolean as) {
        super(data, planner, as);
        validateK(this.kNeighbors.intValue(), this.m);
        logModelSummary();
    }

    /**
     * Checks that {@code 1 <= k <= m}.
     *
     * @throws IllegalArgumentException if {@code k} is out of range
     */
    private static void validateK(int k, int m) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be positive");
        }
        if (k > m) {
            throw new IllegalArgumentException("k must be <= number of samples");
        }
    }

    @Override
    protected final ModelSummary modelSummary() {
        // A header row and a value row. The outer array must be Object[][] (not Object[])
        // so that a varargs/2-D ModelSummary constructor receives two rows, not one.
        return new ModelSummary(new Object[][] {
            new Object[] {
                "Num Rows", "Num Cols", "Metric", "Algo", "K", "Leaf Size", "Allow Par."
            },
            new Object[] {
                Integer.valueOf(this.m),
                Integer.valueOf(this.data.getColumnDimension()),
                getSeparabilityMetric(),
                this.alg,
                this.kNeighbors,
                Integer.valueOf(this.leafSize),
                Boolean.valueOf(this.parallel)
            }
        });
    }

    /**
     * Two models are equal when the base class considers them equal and they share the
     * same {@code k}, leaf size and (exactly) the same training matrix.
     *
     * <p>NOTE(review): no {@code hashCode} override is visible here; confirm the superclass
     * provides one consistent with this {@code equals}.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof NearestNeighbors)) {
            return false;
        }
        final NearestNeighbors other = (NearestNeighbors) obj;
        final Integer theirK = other.kNeighbors;
        final Integer ourK = this.kNeighbors;
        // Null-safe k comparison: two nulls are equal, one null is not.
        final boolean sameK = (null == theirK || null == ourK)
                ? theirK == ourK
                : theirK.intValue() == ourK.intValue();
        return super.equals(obj)
                && sameK
                && other.leafSize == this.leafSize
                && MatUtils.equalsExactly(other.fit_X, this.fit_X);
    }

    @Override
    public String getName() {
        return "NearestNeighbors";
    }

    /** @return the number of neighbors this model was configured with */
    public int getK() {
        return this.kNeighbors.intValue();
    }

    /**
     * Computes, for every training row, its {@code k} nearest training neighbors excluding
     * the row itself. The call is idempotent: once {@code res} is set, it returns at once.
     *
     * @return this model
     */
    @Override
    public NearestNeighbors fit() {
        synchronized (this.fitLock) {
            if (null != this.res) {
                return this;
            }

            // Ask for one extra neighbor because each training row normally finds itself;
            // cap at m so the query stays valid when k == m.
            final int nQuery = FastMath.min(this.kNeighbors.intValue() + 1, this.m);
            final LogTimer timer = new LogTimer();

            final Neighborhood raw = queryTrainingData(nQuery);
            info("queried " + this.alg + " for nearest neighbors in " + timer.toString());

            this.res = dropSelfNeighbors(raw, nQuery - 1, timer);
            sayBye(timer);
            return this;
        }
    }

    /**
     * Queries {@code nQuery} neighbors for every training row. Parallel search is used when
     * enabled; serial search is used otherwise, or when the pool rejects the task or
     * returns null.
     */
    private Neighborhood queryTrainingData(int nQuery) {
        Neighborhood result = null;
        if (this.parallel) {
            try {
                result = ParallelNNSearch.doAll(this.fit_X, this, nQuery);
            } catch (RejectedExecutionException e) {
                // Deliberate best-effort: fall through to the serial query below.
                warn("parallel neighborhood search failed; falling back to serial query");
            }
        }
        if (null == result) {
            result = new Neighborhood(this.tree.query(this.fit_X, nQuery, false, true));
        }
        return result;
    }

    /**
     * Copies {@code raw} into an {@code m x nKeep} neighborhood, omitting each row's own
     * index. Also records one fit-summary line per row: the min and max kept distance, plus
     * the wall time.
     */
    private Neighborhood dropSelfNeighbors(Neighborhood raw, int nKeep, LogTimer timer) {
        final double[][] distances = raw.getDistances();
        final int[][] indices = raw.getIndices();
        final double[][] keptDistances = new double[this.m][nKeep];
        final int[][] keptIndices = new int[this.m][nKeep];

        for (int row = 0; row < this.m; row++) {
            final boolean[] keep = selfExclusionMask(indices[row], row);
            double nearest = Double.POSITIVE_INFINITY;
            double farthest = Double.NEGATIVE_INFINITY;
            int col = 0;
            for (int j = 0; j < keep.length; j++) {
                if (!keep[j]) {
                    continue;
                }
                final double dist = distances[row][j];
                keptIndices[row][col] = indices[row][j];
                keptDistances[row][col] = dist;
                nearest = FastMath.min(dist, nearest);
                farthest = FastMath.max(dist, farthest);
                col++;
            }
            this.fitSummary.add(new Object[] {
                Integer.valueOf(row), Double.valueOf(nearest), Double.valueOf(farthest),
                timer.wallTime()
            });
        }
        return new Neighborhood(keptDistances, keptIndices);
    }

    /**
     * Returns a mask that is {@code true} for every entry of {@code rowIndices} that should
     * be kept. The entry equal to {@code self} is dropped. If {@code self} is absent, the
     * first entry is dropped instead, so that exactly one entry is always removed.
     * Presumably {@code self} can be absent when duplicate points crowd it out of the
     * results; this also assumes the results are sorted nearest-first — TODO confirm
     * against {@code tree.query}.
     */
    private static boolean[] selfExclusionMask(int[] rowIndices, int self) {
        final boolean[] keep = new boolean[rowIndices.length];
        boolean selfFound = false;
        for (int j = 0; j < rowIndices.length; j++) {
            final boolean isSelf = rowIndices[j] == self;
            keep[j] = !isSelf;
            selfFound |= isSelf;
        }
        if (!selfFound && keep.length > 0) {
            keep[0] = false;
        }
        return keep;
    }

    @Override
    protected final Object[] getModelFitSummaryHeaders() {
        return new Object[] {"Instance", "Nrst-Nbr. Dist", "Max-Nbr. Dist", "Wall"};
    }

    /**
     * Queries the fitted tree for the configured {@code k} neighbors of each row of
     * {@code X}. Parallel search is used if the model allows it.
     */
    @Override
    public Neighborhood getNeighbors(RealMatrix X) {
        return getNeighbors(X, this.kNeighbors.intValue());
    }

    /** Queries the configured {@code k} neighbors, optionally in parallel. */
    public Neighborhood getNeighbors(double[][] X, boolean parallelize) {
        return getNeighbors(X, this.kNeighbors.intValue(), parallelize);
    }

    /** Queries the configured {@code k} neighbors serially. */
    public Neighborhood getNeighbors(double[][] X) {
        return getNeighbors(X, this.kNeighbors.intValue(), false);
    }

    /** Queries {@code k} neighbors; parallel search is used if the model allows it. */
    public Neighborhood getNeighbors(RealMatrix X, int k) {
        return getNeighbors(X.getData(), k, this.parallel);
    }

    /**
     * Queries the fitted tree for {@code k} neighbors of each row of {@code X}.
     *
     * @throws ModelNotFitException if {@link #fit()} has not completed
     * @throws IllegalArgumentException if {@code k < 1} or {@code k > m}
     */
    protected Neighborhood getNeighbors(double[][] X, int k, boolean parallelize) {
        if (null == this.res) {
            throw new ModelNotFitException("model not yet fit");
        }
        validateK(k, this.m);
        if (parallelize) {
            try {
                return ParallelNNSearch.doAll(X, this, k);
            } catch (RejectedExecutionException e) {
                // Deliberate best-effort: fall through to the serial query below.
                warn("parallel neighborhood search failed; falling back to serial search");
            }
        }
        return this.tree.query(X, k, false, true);
    }

    @Override
    public Log.Tag.Algo getLoggerTag() {
        return Log.Tag.Algo.NEAREST;
    }
}
