/*
 * Decompiled with CFR 0.152.
 */
package mulan.classifier.lazy;

import java.util.ArrayList;
import java.util.Random;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.lazy.MultiLabelKNN;
import mulan.core.Util;
import mulan.data.MultiLabelInstances;
import weka.classifiers.lazy.IBk;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.Utils;

public class BRkNN
extends MultiLabelKNN {
    private Random random;
    private int avgPredictedLabels;
    private int cvMaxK;
    private boolean cvkSelection = false;
    private ExtensionType extension = ExtensionType.NONE;

    public BRkNN() {
        this(10, ExtensionType.NONE);
    }

    public BRkNN(int numOfNeighbors) {
        this(numOfNeighbors, ExtensionType.NONE);
    }

    public BRkNN(int numOfNeighbors, ExtensionType ext) {
        super(numOfNeighbors);
        this.random = new Random(1L);
        this.extension = ext;
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Eleftherios Spyromitros, Grigorios Tsoumakas, Ioannis Vlahavas");
        result.setValue(TechnicalInformation.Field.TITLE, "An Empirical Study of Lazy Multilabel Classification Algorithms");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Proc. 5th Hellenic Conference on Artificial Intelligence (SETN 2008)");
        result.setValue(TechnicalInformation.Field.LOCATION, "Syros, Greece");
        result.setValue(TechnicalInformation.Field.YEAR, "2008");
        return result;
    }

    @Override
    protected void buildInternal(MultiLabelInstances aTrain) throws Exception {
        super.buildInternal(aTrain);
        if (this.cvkSelection) {
            this.crossValidate();
        }
    }

    public void setkSelectionViaCV(boolean flag) {
        this.cvkSelection = flag;
    }

    private void crossValidate() throws Exception {
        try {
            int i;
            double[] hammingLoss = new double[this.cvMaxK];
            for (int i2 = 0; i2 < this.cvMaxK; ++i2) {
                hammingLoss[i2] = 0.0;
            }
            Instances dataSet = this.train;
            for (i = 0; i < dataSet.numInstances(); ++i) {
                if (this.getDebug() && i % 50 == 0) {
                    this.debug("Cross validating " + i + "/" + dataSet.numInstances() + "\r");
                }
                Instance instance = dataSet.instance(i);
                Instances neighbours = this.lnn.kNearestNeighbours(instance, this.cvMaxK);
                double[] origDistances = this.lnn.getDistances();
                boolean[] trueLabels = new boolean[this.numLabels];
                for (int counter = 0; counter < this.numLabels; ++counter) {
                    int classIdx = this.labelIndices[counter];
                    String classValue = instance.attribute(classIdx).value((int)instance.value(classIdx));
                    trueLabels[counter] = classValue.equals("1");
                }
                for (int j = this.cvMaxK; j > 0; --j) {
                    double[] convertedDistances = new double[origDistances.length];
                    System.arraycopy(origDistances, 0, convertedDistances, 0, origDistances.length);
                    double[] confidences = this.getConfidences(neighbours, convertedDistances);
                    boolean[] bipartition = null;
                    switch (this.extension) {
                        case NONE: {
                            MultiLabelOutput results = new MultiLabelOutput(confidences, 0.5);
                            bipartition = results.getBipartition();
                            break;
                        }
                        case EXTA: {
                            bipartition = this.labelsFromConfidences2(confidences);
                            break;
                        }
                        case EXTB: {
                            bipartition = this.labelsFromConfidences3(confidences);
                        }
                    }
                    double symmetricDifference = 0.0;
                    for (int labelIndex = 0; labelIndex < this.numLabels; ++labelIndex) {
                        boolean predicted = bipartition[labelIndex];
                        boolean actual = trueLabels[labelIndex];
                        if (predicted == actual) continue;
                        symmetricDifference += 1.0;
                    }
                    int n = j - 1;
                    hammingLoss[n] = hammingLoss[n] + symmetricDifference / (double)this.numLabels;
                    neighbours = new IBk().pruneToK(neighbours, convertedDistances, j - 1);
                }
            }
            if (this.getDebug()) {
                for (i = this.cvMaxK; i > 0; --i) {
                    this.debug("Hold-one-out performance of " + i + " neighbors ");
                    this.debug("(Hamming Loss) = " + hammingLoss[i - 1] / (double)dataSet.numInstances());
                }
            }
            double[] searchStats = hammingLoss;
            double bestPerformance = Double.NaN;
            int bestK = 1;
            for (int i3 = 0; i3 < this.cvMaxK; ++i3) {
                if (!Double.isNaN(bestPerformance) && !(bestPerformance > searchStats[i3])) continue;
                bestPerformance = searchStats[i3];
                bestK = i3 + 1;
            }
            this.numOfNeighbors = bestK;
            if (this.getDebug()) {
                System.err.println("Selected k = " + bestK);
            }
        }
        catch (Exception ex) {
            throw new Error("Couldn't optimize by cross-validation: " + ex.getMessage());
        }
    }

    @Override
    protected MultiLabelOutput makePredictionInternal(Instance instance) throws Exception {
        Instances knn = this.lnn.kNearestNeighbours(instance, this.numOfNeighbors);
        double[] distances = this.lnn.getDistances();
        double[] confidences = this.getConfidences(knn, distances);
        MultiLabelOutput results = null;
        switch (this.extension) {
            case NONE: {
                results = new MultiLabelOutput(confidences, 0.5);
                break;
            }
            case EXTA: {
                boolean[] bipartition = this.labelsFromConfidences2(confidences);
                results = new MultiLabelOutput(bipartition, confidences);
                break;
            }
            case EXTB: {
                boolean[] bipartition = this.labelsFromConfidences3(confidences);
                results = new MultiLabelOutput(bipartition, confidences);
            }
        }
        return results;
    }

    private double[] getConfidences(Instances neighbours, double[] distances) {
        int i;
        double neighborLabels = 0.0;
        double[] confidences = new double[this.numLabels];
        for (i = 0; i < this.numLabels; ++i) {
            confidences[i] = 1.0 / (double)Math.max(1, this.train.numInstances());
        }
        double total = (double)this.numLabels / (double)Math.max(1, this.train.numInstances());
        for (i = 0; i < neighbours.numInstances(); ++i) {
            Instance current = neighbours.instance(i);
            distances[i] = distances[i] * distances[i];
            distances[i] = Math.sqrt(distances[i] / (double)(this.train.numAttributes() - this.numLabels));
            double weight = 1.0;
            weight *= current.weight();
            for (int j = 0; j < this.numLabels; ++j) {
                double value = Double.parseDouble(current.attribute(this.labelIndices[j]).value((int)current.value(this.labelIndices[j])));
                if (!Utils.eq((double)value, (double)1.0)) continue;
                int n = j;
                confidences[n] = confidences[n] + weight;
                neighborLabels += weight;
            }
            total += weight;
        }
        this.avgPredictedLabels = (int)Math.round(neighborLabels / total);
        if (total > 0.0) {
            Utils.normalize((double[])confidences, (double)total);
        }
        return confidences;
    }

    protected boolean[] labelsFromConfidences2(double[] confidences) {
        boolean[] bipartition = new boolean[this.numLabels];
        boolean flag = false;
        for (int i = 0; i < this.numLabels; ++i) {
            if (!(confidences[i] >= 0.5)) continue;
            bipartition[i] = true;
            flag = true;
        }
        if (!flag) {
            int index = Util.RandomIndexOfMax(confidences, this.random);
            bipartition[index] = true;
        }
        return bipartition;
    }

    protected boolean[] labelsFromConfidences3(double[] confidences) {
        boolean[] bipartition = new boolean[this.numLabels];
        int[] indices = Utils.stableSort((double[])confidences);
        ArrayList<Integer> lastindices = new ArrayList<Integer>();
        int counter = 0;
        for (int i = this.numLabels - 1; i > 0; --i) {
            if (confidences[indices[i]] > confidences[indices[this.numLabels - this.avgPredictedLabels]]) {
                bipartition[indices[i]] = true;
                ++counter;
                continue;
            }
            if (confidences[indices[i]] != confidences[indices[this.numLabels - this.avgPredictedLabels]]) break;
            lastindices.add(indices[i]);
        }
        int size = lastindices.size();
        int j = this.avgPredictedLabels - counter;
        while (j > 0) {
            int next = this.random.nextInt(size);
            if (bipartition[(Integer)lastindices.get(next)]) continue;
            bipartition[((Integer)lastindices.get((int)next)).intValue()] = true;
            --j;
        }
        return bipartition;
    }

    public void setCvMaxK(int cvMaxK) {
        this.cvMaxK = cvMaxK;
    }

    @Override
    public String globalInfo() {
        return "Simple BR implementation of the KNN algorithm.For more information, see\n\n" + this.getTechnicalInformation().toString();
    }

    public static enum ExtensionType {
        NONE,
        EXTA,
        EXTB;

    }
}

