/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.classification.meta;

import com.google.common.base.Preconditions;
import de.jungblut.classification.AbstractClassifier;
import de.jungblut.classification.Classifier;
import de.jungblut.classification.ClassifierFactory;
import de.jungblut.classification.meta.TrainingSplit;
import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.partition.BlockPartitioner;
import de.jungblut.partition.Boundaries;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Ensemble classifier that holds several {@link Classifier} instances and
 * combines their individual predictions into one.
 *
 * <p>During {@link #train(DoubleVector[], DoubleVector[])}, each member is
 * trained on its own split of the training data in a fixed-size thread pool.
 * How the data is split is controlled by {@link SelectionType}. During
 * {@link #predict(DoubleVector)}, the members' outputs are merged according to
 * the configured {@link CombiningType}.
 *
 * @param <A> the type of the member classifiers.
 */
public final class Voter<A extends Classifier>
extends AbstractClassifier {
    private static final Logger LOG = LogManager.getLogger(Voter.class);
    private final Classifier[] classifier;
    private CombiningType type;
    private SelectionType selection = SelectionType.NONE;
    private int threads = 1;
    private boolean verbose;

    /**
     * Creates a voter with {@code numClassifiers} fresh members obtained from
     * the given factory.
     *
     * @throws IllegalArgumentException if {@code numClassifiers} is not positive.
     * @throws NullPointerException if {@code classifierFactory} is null.
     */
    private Voter(CombiningType type, int numClassifiers, ClassifierFactory<A> classifierFactory) {
        Preconditions.checkArgument(numClassifiers > 0,
                "numClassifiers must be positive, but was %s", numClassifiers);
        Preconditions.checkNotNull(classifierFactory, "classifierFactory");
        this.type = type;
        this.classifier = new Classifier[numClassifiers];
        for (int i = 0; i < numClassifiers; ++i) {
            this.classifier[i] = classifierFactory.newInstance();
        }
    }

    /**
     * Creates a voter from already existing classifiers.
     *
     * <p>The combining type is left unset here. It must be provided through
     * {@link #setCombiningType(CombiningType)} before {@link #predict(DoubleVector)}
     * is called.
     *
     * @throws IllegalArgumentException if the list is empty.
     * @throws NullPointerException if the list or one of its elements is null.
     */
    private Voter(List<A> classifierCollection) {
        Preconditions.checkArgument(!classifierCollection.isEmpty(),
                "At least one classifier is required!");
        this.classifier = new Classifier[classifierCollection.size()];
        for (int i = 0; i < this.classifier.length; ++i) {
            this.classifier[i] = Preconditions.checkNotNull(classifierCollection.get(i),
                    "classifier at index %s is null", i);
        }
    }

    /**
     * Trains every member classifier on its own split of the data, using
     * {@link #numThreads(int)} worker threads.
     *
     * <p>If a member's training throws, the failure is propagated as an
     * {@link IllegalStateException} whose cause is the original exception. If
     * the calling thread is interrupted while waiting, the interrupt flag is
     * restored, an error is logged and the method returns early.
     *
     * @param features the training features.
     * @param outcome the training outcomes, index-aligned with {@code features}.
     * @throws IllegalArgumentException if the two arrays differ in length.
     */
    @Override
    public void train(DoubleVector[] features, DoubleVector[] outcome) {
        Preconditions.checkArgument(features.length == outcome.length,
                "features and outcome must have the same length, but were %s and %s",
                features.length, outcome.length);
        List<TrainingSplit> splits = this.createSplits(features, outcome);
        ExecutorService pool = Executors.newFixedThreadPool(this.threads);
        try {
            ExecutorCompletionService<Boolean> completionService = new ExecutorCompletionService<Boolean>(pool);
            for (int i = 0; i < this.classifier.length; ++i) {
                completionService.submit(new TrainingWorker(this.classifier[i], splits.get(i)));
            }
            for (int i = 0; i < this.classifier.length; ++i) {
                // get() rethrows (wrapped) any exception raised inside a worker,
                // so a failed member can no longer go unnoticed.
                completionService.take().get();
                if (this.verbose) {
                    LOG.info("Finished with training classifier {} of {}", i + 1, this.classifier.length);
                }
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            LOG.error("Interrupted while waiting for the classifiers to finish training!", e);
            return;
        }
        catch (ExecutionException e) {
            throw new IllegalStateException("Training of a member classifier failed!", e.getCause());
        }
        finally {
            pool.shutdownNow();
        }
        if (this.verbose) {
            LOG.info("Successfully finished training!");
        }
    }

    /**
     * Predicts by combining the outputs of all member classifiers.
     *
     * <ul>
     * <li>MAJORITY: each member votes for its predicted class. A single-output
     * model returns the winning class index in position 0. Otherwise the
     * result is a one-hot vector with the winning class set to 1.</li>
     * <li>PROBABILITY: member outputs are summed and normalised to sum to 1. A
     * single-output model instead gets the mean output, because normalising a
     * single value by itself would always yield 1.</li>
     * <li>AVERAGE: element-wise mean of the member outputs.</li>
     * </ul>
     *
     * @param features the input to classify.
     * @return the combined prediction.
     * @throws NullPointerException if no combining type has been set.
     * @throws UnsupportedOperationException for an unknown combining type.
     */
    @Override
    public DoubleVector predict(DoubleVector features) {
        Preconditions.checkNotNull(this.type,
                "No CombiningType set, call setCombiningType() before predicting!");
        DoubleVector[] result = new DoubleVector[this.classifier.length];
        for (int i = 0; i < this.classifier.length; ++i) {
            result[i] = this.classifier[i].predict(features);
        }
        int dimension = result[0].getDimension();
        // a single output unit encodes a binary decision, i.e. two outcomes
        int numPossibleOutcomes = dimension == 1 ? 2 : dimension;
        DoubleVector toReturn;
        switch (this.type) {
            case MAJORITY: {
                double[] histogram = this.createPredictionHistogram(result, numPossibleOutcomes);
                int winner = ArrayUtils.maxIndex(histogram);
                DenseDoubleVector vote = new DenseDoubleVector(dimension);
                // Check the output dimension, not the outcome count. Two-output
                // models must get a one-hot vector, otherwise decoding would
                // always yield class 0.
                if (dimension == 1) {
                    vote.set(0, (double) winner);
                } else {
                    vote.set(winner, 1.0);
                }
                toReturn = vote;
                break;
            }
            case PROBABILITY: {
                DoubleVector sum = result[0];
                for (int i = 1; i < result.length; ++i) {
                    sum = sum.add(result[i]);
                }
                toReturn = dimension == 1
                        ? sum.divide((double) result.length)
                        : sum.divide(sum.sum());
                break;
            }
            case AVERAGE: {
                DoubleVector sum = new DenseDoubleVector(dimension);
                for (int i = 0; i < result.length; ++i) {
                    sum = sum.add(result[i]);
                }
                toReturn = sum.divide((double) this.classifier.length);
                break;
            }
            default: {
                throw new UnsupportedOperationException("Type " + this.type + " isn't supported yet!");
            }
        }
        return toReturn;
    }

    /**
     * Returns the member classifiers. This is the internal array, not a copy.
     */
    public Classifier[] getClassifier() {
        return this.classifier;
    }

    /** Enables progress logging during training. */
    public Voter<A> verbose() {
        return this.verbose(true);
    }

    /** Enables or disables progress logging during training. */
    public Voter<A> verbose(boolean verb) {
        this.verbose = verb;
        return this;
    }

    /**
     * Sets how the training data is distributed to the members.
     *
     * @throws NullPointerException if {@code type} is null.
     */
    public Voter<A> selectionType(SelectionType type) {
        this.selection = Preconditions.checkNotNull(type, "selection type");
        return this;
    }

    /**
     * Sets the number of threads used for training.
     *
     * @throws IllegalArgumentException if {@code threads} is not positive.
     */
    public Voter<A> numThreads(int threads) {
        Preconditions.checkArgument(threads > 0, "threads must be positive, but was %s", threads);
        this.threads = threads;
        return this;
    }

    /**
     * Sets how member predictions are combined.
     *
     * @throws NullPointerException if {@code type} is null.
     */
    public Voter<A> setCombiningType(CombiningType type) {
        this.type = Preconditions.checkNotNull(type, "combining type");
        return this;
    }

    /** Counts how many members predicted each class index. */
    private double[] createPredictionHistogram(DoubleVector[] result, int possibleOutcomes) {
        double[] histogram = new double[possibleOutcomes];
        for (int i = 0; i < this.classifier.length; ++i) {
            histogram[this.classifier[i].extractPredictedClass(result[i])]++;
        }
        return histogram;
    }

    /** Builds one training split per member according to the selection type. */
    private List<TrainingSplit> createSplits(DoubleVector[] features, DoubleVector[] outcome) {
        switch (this.selection) {
            case BAGGING: {
                return this.bag(features, outcome);
            }
            case SHUFFLE: {
                return this.partition(features, outcome, true);
            }
            default: {
                return this.partition(features, outcome, false);
            }
        }
    }

    /**
     * Bootstrap sampling: each member gets {@code features.length} samples
     * drawn uniformly with replacement.
     */
    private List<TrainingSplit> bag(DoubleVector[] features, DoubleVector[] outcome) {
        List<TrainingSplit> splits = new ArrayList<TrainingSplit>(this.classifier.length);
        Random rand = new Random();
        for (int i = 0; i < this.classifier.length; ++i) {
            DoubleVector[] featureBag = new DoubleVector[features.length];
            DoubleVector[] outcomeBag = new DoubleVector[features.length];
            for (int n = 0; n < features.length; ++n) {
                int pick = rand.nextInt(features.length);
                featureBag[n] = features[pick];
                outcomeBag[n] = outcome[pick];
            }
            splits.add(new TrainingSplit(featureBag, outcomeBag));
        }
        return splits;
    }

    /**
     * Cuts the data into contiguous, disjoint blocks, one per member, that
     * together cover every sample exactly once.
     *
     * @param shuffle if true, a shuffled copy of the data is partitioned. The
     *        caller's arrays are not modified.
     */
    private List<TrainingSplit> partition(DoubleVector[] features, DoubleVector[] outcome, boolean shuffle) {
        DoubleVector[] featureSource = features;
        DoubleVector[] outcomeSource = outcome;
        if (shuffle) {
            // shuffle copies so the caller's training arrays are left untouched
            featureSource = features.clone();
            outcomeSource = outcome.clone();
            ArrayUtils.multiShuffle(featureSource, new DoubleVector[][]{outcomeSource});
        }
        List<Boundaries.Range> partitions = new ArrayList<Boundaries.Range>(
                new BlockPartitioner().partition(this.classifier.length, featureSource.length).getBoundaries());
        // Cut points: split i covers the half-open index range
        // [splitRanges[i], splitRanges[i + 1]), so no sample is duplicated or dropped.
        int[] splitRanges = new int[this.classifier.length + 1];
        for (int i = 1; i < this.classifier.length; ++i) {
            splitRanges[i] = partitions.get(i).getStart();
        }
        splitRanges[this.classifier.length] = featureSource.length;
        if (this.verbose) {
            LOG.info("Computed split ranges for 0-{}: {}\n", featureSource.length, Arrays.toString(splitRanges));
        }
        List<TrainingSplit> splits = new ArrayList<TrainingSplit>(this.classifier.length);
        for (int i = 0; i < this.classifier.length; ++i) {
            DoubleVector[] featureSplit = Arrays.copyOfRange(featureSource, splitRanges[i], splitRanges[i + 1]);
            DoubleVector[] outcomeSplit = Arrays.copyOfRange(outcomeSource, splitRanges[i], splitRanges[i + 1]);
            splits.add(new TrainingSplit(featureSplit, outcomeSplit));
        }
        return splits;
    }

    /**
     * Creates a voter with {@code numClassifiers} new members built by the factory.
     *
     * @throws IllegalArgumentException if {@code numClassifiers} is not positive.
     */
    public static <K extends Classifier> Voter<K> create(int numClassifiers, CombiningType type, ClassifierFactory<K> classifierFactory) {
        return new Voter<K>(type, numClassifiers, classifierFactory);
    }

    /**
     * Creates a voter from existing, typically already trained, classifiers.
     * Call {@link #setCombiningType(CombiningType)} before predicting.
     *
     * @throws IllegalArgumentException if the list is empty.
     */
    public static <K extends Classifier> Voter<K> fromTrainedModels(List<K> classifier) {
        return new Voter<K>(classifier);
    }

    /**
     * Trains a single member on its split. This is a static nested class so it
     * holds no hidden reference to the enclosing voter.
     */
    static final class TrainingWorker
    implements Callable<Boolean> {
        private final Classifier cls;
        private final TrainingSplit split;

        TrainingWorker(Classifier classifier, TrainingSplit split) {
            this.cls = classifier;
            this.split = split;
        }

        @Override
        public Boolean call() throws Exception {
            this.cls.train(this.split.getTrainFeatures(), this.split.getTrainOutcome());
            return true;
        }
    }

    /** How training data is distributed to the member classifiers. */
    public static enum SelectionType {
        /** Contiguous blocks of the data in their original order. */
        NONE,
        /** Contiguous blocks of a shuffled copy of the data. */
        SHUFFLE,
        /** Bootstrap samples drawn with replacement. */
        BAGGING;

    }

    /** How the member predictions are merged into one result. */
    public static enum CombiningType {
        /** Majority vote over the predicted classes. */
        MAJORITY,
        /** Element-wise mean of the member outputs. */
        AVERAGE,
        /** Summed outputs normalised to one (mean for single-output models). */
        PROBABILITY;

    }
}

