/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.classification.knn;

import de.jungblut.classification.AbstractClassifier;
import de.jungblut.jrpt.VectorDistanceTuple;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.dense.SingleEntryDoubleVector;
import java.util.List;

/**
 * Base class for k-nearest-neighbour classifiers that classify a query by counting the outcome
 * votes of the {@code k} closest training examples.
 *
 * <p>Subclasses only supply the neighbour search via {@link #getNearestNeighbours}; vote
 * counting and normalisation are handled here.
 *
 * <p>Outcome encoding, as read by {@link #predict}:
 * <ul>
 *   <li>With exactly two outcomes, each neighbour's outcome vector holds the class label in its
 *       first entry.</li>
 *   <li>Otherwise, the class is the index of the outcome vector's largest entry (presumably a
 *       one-hot encoding).</li>
 * </ul>
 */
public abstract class AbstractKNearestNeighbours extends AbstractClassifier {
    /** Number of distinct outcomes (classes); the value 2 selects the binary-label encoding. */
    protected final int numOutcomes;
    /** Number of neighbours requested from {@link #getNearestNeighbours} per prediction. */
    protected final int k;

    /**
     * Creates a classifier for the given number of outcomes and neighbours.
     *
     * @param numOutcomes number of distinct classes; 2 switches to binary-label handling
     * @param k number of nearest neighbours requested per prediction
     */
    public AbstractKNearestNeighbours(int numOutcomes, int k) {
        this.numOutcomes = numOutcomes;
        this.k = k;
    }

    /**
     * Classifies {@code features} by counting one vote per nearest neighbour.
     *
     * @param features the query vector
     * @return for two outcomes, a single-entry vector holding the index of the class with the most
     *     votes; otherwise a vector of length {@code numOutcomes} with the raw vote count per
     *     class. NOTE(review): ties are resolved by {@code maxIndex()} of the histogram; confirm
     *     its tie-breaking order.
     */
    @Override
    public DoubleVector predict(DoubleVector features) {
        List<VectorDistanceTuple<DoubleVector>> nearestNeighbours =
            this.getNearestNeighbours(features, this.k);
        DenseDoubleVector outcomeHistogram = new DenseDoubleVector(this.numOutcomes);
        for (VectorDistanceTuple<DoubleVector> tuple : nearestNeighbours) {
            int classIndex = this.classIndexOf(tuple.getValue());
            outcomeHistogram.set(classIndex, outcomeHistogram.get(classIndex) + 1.0);
        }
        if (this.numOutcomes == 2) {
            return new SingleEntryDoubleVector(outcomeHistogram.maxIndex());
        }
        return outcomeHistogram;
    }

    /**
     * Returns class probabilities derived from the neighbour votes.
     *
     * <p>For more than two outcomes, the vote histogram from {@link #predict} is normalised to
     * sum to 1. If the histogram holds no votes at all (no neighbours were returned), it is
     * returned unchanged as an all-zero vector. Dividing by a zero total would produce NaN
     * entries.
     *
     * <p>For two outcomes, the result of {@link #predict} is returned as-is. That is the predicted
     * class index, not a vote fraction.
     *
     * @param features the query vector
     * @return the normalised vote histogram, or the binary prediction unchanged
     */
    @Override
    public DoubleVector predictProbability(DoubleVector features) {
        DoubleVector prediction = this.predict(features);
        if (this.numOutcomes == 2) {
            return prediction;
        }
        double totalVotes = prediction.sum();
        if (totalVotes == 0.0) {
            // no neighbours voted: avoid 0/0 -> NaN and report zero support for every class
            return prediction;
        }
        return prediction.divide(totalVotes);
    }

    /** Maps a neighbour's outcome vector to the index of the class it votes for. */
    private int classIndexOf(DoubleVector outcome) {
        // binary outcomes store the label itself in entry 0; otherwise the class is the position
        // of the largest entry
        return this.numOutcomes == 2 ? (int) outcome.get(0) : outcome.maxIndex();
    }

    /**
     * Finds the training examples closest to {@code features}.
     *
     * @param features the query vector
     * @param numNeighbours number of neighbours requested ({@link #predict} passes {@code k})
     * @return neighbour tuples whose {@code getValue()} is the neighbour's outcome vector; each
     *     returned tuple casts exactly one vote in {@link #predict}
     */
    protected abstract List<VectorDistanceTuple<DoubleVector>> getNearestNeighbours(
        DoubleVector features, int numNeighbours);
}

