/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.client.solrj.io.eval;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.solr.client.solrj.io.eval.ManyValueWorker;
import org.apache.solr.client.solrj.io.eval.Matrix;
import org.apache.solr.client.solrj.io.eval.RecursiveObjectEvaluator;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;

/**
 * Stream-expression evaluator for {@code knn(matrix, vector, k[, distanceMeasure])}.
 *
 * <p>Returns a {@link Matrix} holding the (at most) {@code k} rows of {@code matrix} that are
 * closest to {@code vector}, ordered from nearest to farthest. When no fourth parameter is given,
 * distances are computed with {@link EuclideanDistance}.
 */
public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
    protected static final long serialVersionUID = 1L;

    public KnnEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
        super(expression, factory);
    }

    /**
     * Validates the evaluated parameters and runs the nearest-neighbor search.
     *
     * @param values2 evaluated parameters: a {@link Matrix}, a {@link List} of numbers (the query
     *     vector), a numeric {@code k}, and optionally a {@link DistanceMeasure} as the fourth value
     * @return the matrix produced by {@link #search(Matrix, double[], int, DistanceMeasure)}
     * @throws IOException if fewer than three parameters are supplied, or if any parameter (or any
     *     element of the query vector) has the wrong type
     */
    @Override
    public Object doWork(Object... values2) throws IOException {
        if (values2.length < 3) {
            throw new IOException("knn expects three parameters a Matrix, numeric array and k");
        }
        if (!(values2[0] instanceof Matrix)) {
            throw new IOException("The first parameter for knn should be a matrix.");
        }
        Matrix matrix = (Matrix) values2[0];
        double[] vec = toVector(values2[1]);
        if (!(values2[2] instanceof Number)) {
            throw new IOException("The third parameter for knn should be k.");
        }
        int k = ((Number) values2[2]).intValue();
        DistanceMeasure distanceMeasure = toDistanceMeasure(values2);
        return search(matrix, vec, k, distanceMeasure);
    }

    /** Converts the second knn parameter, which must be a list of numbers, into a double array. */
    private static double[] toVector(Object param) throws IOException {
        if (!(param instanceof List)) {
            throw new IOException("The second parameter for knn should be a numeric array.");
        }
        List<?> nums = (List<?>) param;
        double[] vec = new double[nums.size()];
        for (int i = 0; i < vec.length; i++) {
            Object num = nums.get(i);
            // Null or non-numeric elements get the same checked error as a non-list parameter
            // rather than escaping as an unchecked ClassCastException/NullPointerException.
            if (!(num instanceof Number)) {
                throw new IOException("The second parameter for knn should be a numeric array.");
            }
            vec[i] = ((Number) num).doubleValue();
        }
        return vec;
    }

    /**
     * Returns the distance measure for the search. A custom measure is read only when exactly four
     * parameters are supplied; every other arity uses {@link EuclideanDistance}.
     */
    private static DistanceMeasure toDistanceMeasure(Object[] values) throws IOException {
        if (values.length != 4) {
            return new EuclideanDistance();
        }
        if (!(values[3] instanceof DistanceMeasure)) {
            throw new IOException("The fourth parameter for knn should be a distance measure.");
        }
        return (DistanceMeasure) values[3];
    }

    /**
     * Finds the {@code k} rows of {@code observations} nearest to {@code vec}.
     *
     * <p>The returned matrix lists the selected rows in ascending distance order. Equal distances are
     * ordered by ascending original row index. Row labels, when present, are carried over for the
     * selected rows. Column labels are copied from {@code observations}. The attributes
     * {@code "distances"} and {@code "indexes"} hold each selected row's distance and original row
     * index. The selected row arrays are placed into the result without being copied. A {@code k} of
     * zero or less yields a matrix with no rows.
     *
     * @param observations matrix whose rows are the candidate points
     * @param vec query vector passed to {@code distanceMeasure} against every row
     * @param k maximum number of rows to return
     * @param distanceMeasure measure used to compare {@code vec} with each row
     * @return a new matrix containing the nearest rows
     */
    public static Matrix search(Matrix observations, double[] vec, int k, DistanceMeasure distanceMeasure) {
        double[][] data = observations.getData();

        // The set never grows beyond k + 1 entries: once it exceeds k, the farthest entry is evicted,
        // so the whole scan costs O(n log k).
        TreeSet<Neighbor> neighbors = new TreeSet<>();
        for (int row = 0; row < data.length; row++) {
            neighbors.add(new Neighbor(row, distanceMeasure.compute(vec, data[row])));
            if (neighbors.size() > k) {
                neighbors.pollLast();
            }
        }

        int count = neighbors.size();
        double[][] out = new double[count][];
        List<String> rowLabels = observations.getRowLabels();
        List<String> newRowLabels = new ArrayList<>(count);
        List<Integer> indexes = new ArrayList<>(count);
        List<Double> distances = new ArrayList<>(count);

        int pos = 0;
        for (Neighbor neighbor : neighbors) { // TreeSet iteration is ascending: nearest first
            int rowIndex = neighbor.getRow();
            if (rowLabels != null) {
                // Assumes the label list has one entry per row of observations.
                newRowLabels.add(rowLabels.get(rowIndex));
            }
            out[pos++] = data[rowIndex];
            distances.add(neighbor.getDistance());
            indexes.add(rowIndex);
        }

        Matrix knn = new Matrix(out);
        if (rowLabels != null) {
            knn.setRowLabels(newRowLabels);
        }
        knn.setColumnLabels(observations.getColumnLabels());
        knn.setAttribute("distances", distances);
        knn.setAttribute("indexes", indexes);
        return knn;
    }

    /**
     * A candidate row paired with its distance from the query vector.
     *
     * <p>The natural ordering is by distance, then by row index. This lets distinct rows at the same
     * distance coexist in a {@link TreeSet}. {@link #equals} and {@link #hashCode} agree with that
     * ordering.
     */
    public static class Neighbor implements Comparable<Neighbor> {
        private final Double distance;
        private final int row;

        public Neighbor(int row, double distance) {
            this.distance = distance;
            this.row = row;
        }

        public int getRow() {
            return this.row;
        }

        public Double getDistance() {
            return this.distance;
        }

        @Override
        public int compareTo(Neighbor neighbor) {
            int byDistance = Double.compare(this.distance, neighbor.getDistance());
            if (byDistance != 0) {
                return byDistance;
            }
            // Integer.compare instead of subtraction: subtraction is the overflow-prone idiom.
            return Integer.compare(this.row, neighbor.getRow());
        }

        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof Neighbor)) {
                return false;
            }
            return compareTo((Neighbor) other) == 0;
        }

        @Override
        public int hashCode() {
            // Double.hashCode hashes the same bit pattern that Double.compare treats as equal.
            return 31 * Double.hashCode(this.distance) + this.row;
        }
    }
}

