/*
 * Decompiled with CFR 0.152.
 */
package de.jungblut.classification.eval;

import com.google.common.base.Preconditions;
import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.math.DoubleVector;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.LinkedList;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Holds a train/test partition of feature vectors and their outcome vectors, and offers factory
 * methods to build one: {@link #create} cuts the data at a fixed position, while
 * {@link #createStratified} keeps the class proportions of the outcomes in the training set.
 *
 * <p>All arrays are stored and handed out by reference; no defensive copies are made.
 */
public class EvaluationSplit {
    private static final Logger LOG = LogManager.getLogger(EvaluationSplit.class);
    private final DoubleVector[] trainFeatures;
    private final DoubleVector[] trainOutcome;
    private final DoubleVector[] testFeatures;
    private final DoubleVector[] testOutcome;

    /**
     * Creates a split from already partitioned arrays. The arrays are stored as given, without
     * copying or validation.
     *
     * @param trainFeatures the training feature vectors.
     * @param trainOutcome the training outcome vectors, expected to be index-aligned with
     *     {@code trainFeatures}.
     * @param testFeatures the test feature vectors.
     * @param testOutcome the test outcome vectors, expected to be index-aligned with
     *     {@code testFeatures}.
     */
    public EvaluationSplit(DoubleVector[] trainFeatures, DoubleVector[] trainOutcome, DoubleVector[] testFeatures, DoubleVector[] testOutcome) {
        this.trainFeatures = trainFeatures;
        this.trainOutcome = trainOutcome;
        this.testFeatures = testFeatures;
        this.testOutcome = testOutcome;
    }

    /** @return the training feature vectors (the internal array, not a copy). */
    public DoubleVector[] getTrainFeatures() {
        return this.trainFeatures;
    }

    /** @return the training outcome vectors (the internal array, not a copy). */
    public DoubleVector[] getTrainOutcome() {
        return this.trainOutcome;
    }

    /** @return the test feature vectors (the internal array, not a copy). */
    public DoubleVector[] getTestFeatures() {
        return this.testFeatures;
    }

    /** @return the test outcome vectors (the internal array, not a copy). */
    public DoubleVector[] getTestOutcome() {
        return this.testOutcome;
    }

    /**
     * Splits the vectors into a training set made of the first {@code (int) (n * splitFraction)}
     * vectors and a test set made of all remaining vectors.
     *
     * @param features the feature vectors.
     * @param outcome the outcome vectors, index-aligned with {@code features}.
     * @param splitFraction fraction in [0, 1] of the vectors that go into the training set.
     * @param random if true, {@code features} and {@code outcome} are shuffled before splitting. The
     *     shuffle operates on the given arrays, so the caller's arrays are reordered.
     * @return the resulting split; its arrays are fresh copies of the respective ranges.
     * @throws IllegalArgumentException if the arrays differ in length or {@code splitFraction} is
     *     outside [0, 1].
     */
    public static EvaluationSplit create(DoubleVector[] features, DoubleVector[] outcome, float splitFraction, boolean random) {
        checkSplitArguments(features, outcome, splitFraction);
        if (random) {
            // The result is not reassigned, so this reorders the caller's arrays in place
            // (presumably with one shared permutation so feature/outcome pairs stay aligned).
            ArrayUtils.multiShuffle(features, new DoubleVector[][]{outcome});
        }
        int splitIndex = trainSize(features.length, splitFraction);
        // Half-open ranges [0, splitIndex) and [splitIndex, n) stay well-defined when splitIndex
        // is 0 or n, unlike inclusive end indices such as splitIndex - 1.
        DoubleVector[] trainFeatures = Arrays.copyOfRange(features, 0, splitIndex);
        DoubleVector[] trainOutcome = Arrays.copyOfRange(outcome, 0, splitIndex);
        DoubleVector[] testFeatures = Arrays.copyOfRange(features, splitIndex, features.length);
        DoubleVector[] testOutcome = Arrays.copyOfRange(outcome, splitIndex, outcome.length);
        return new EvaluationSplit(trainFeatures, trainOutcome, testFeatures, testOutcome);
    }

    /**
     * Splits the vectors so that every class is represented in the training set in proportion to
     * its frequency in the whole set.
     *
     * <p>Class labels are derived from the outcome vectors. A one-dimensional outcome is read as
     * the class index {@code (int) outcome.get(0)} (two classes, 0 and 1). A wider outcome is read
     * as {@code outcome.maxIndex()}.
     *
     * <p>Each class contributes its first {@code (int) (trainSize * classSize / n)} vectors, in
     * input order, to the training set; all other vectors form the test set. Every vector keeps its
     * own outcome vector.
     *
     * <p>Because each class share is floored, the training set can end up smaller than
     * {@code (int) (n * splitFraction)}; this is logged as a warning. The input arrays are not
     * modified.
     *
     * @param features the feature vectors.
     * @param outcome the outcome vectors, index-aligned with {@code features}.
     * @param splitFraction fraction in [0, 1] of the vectors that should go into the training set.
     * @param random if true, the resulting training and test sets are each shuffled.
     * @return the resulting split.
     * @throws IllegalArgumentException if the arrays differ in length or are empty, if
     *     {@code splitFraction} is outside [0, 1], if an outcome maps to a class outside the
     *     expected label range, or if a class would get no training vector.
     * @throws NullPointerException if a label in {@code [0, max(2, outcomeDimension))} has no
     *     vectors at all.
     */
    public static EvaluationSplit createStratified(DoubleVector[] features, DoubleVector[] outcome, float splitFraction, boolean random) {
        checkSplitArguments(features, outcome, splitFraction);
        Preconditions.checkArgument(features.length > 0, "Can't stratify an empty set of vectors!");

        Deque<Integer>[] classQueues = groupByClass(outcome);
        int numClasses = classQueues.length;

        int targetSize = trainSize(features.length, splitFraction);
        double[] samplingProbabilities = new double[numClasses];
        int[] classTrainSizes = new int[numClasses];
        int splitSize = 0;
        for (int c = 0; c < numClasses; ++c) {
            samplingProbabilities[c] = (double) classQueues[c].size() / (double) features.length;
            classTrainSizes[c] = (int) ((double) targetSize * samplingProbabilities[c]);
            Preconditions.checkArgument(classTrainSizes[c] > 0,
                    "Can't stratify the class %s because the split size was too small to satisfy the sampling requirement.", c);
            splitSize += classTrainSizes[c];
        }
        if (splitSize != targetSize) {
            // Flooring every class share can lose a few vectors; the training set shrinks to match.
            LOG.warn("Correcting the split size from {} to {}, to satisfy the sampling target.", targetSize, splitSize);
        }
        LOG.info("Sampling probabilities by class: {}", Arrays.toString(samplingProbabilities));

        // Training set: the first classTrainSizes[c] samples of every class, in input order.
        DoubleVector[] trainFeatures = new DoubleVector[splitSize];
        DoubleVector[] trainOutcomes = new DoubleVector[splitSize];
        int offset = 0;
        for (int c = 0; c < numClasses; ++c) {
            for (int k = 0; k < classTrainSizes[c]; ++k) {
                int sample = classQueues[c].removeFirst();
                trainFeatures[offset] = features[sample];
                trainOutcomes[offset] = outcome[sample];
                ++offset;
            }
        }
        Preconditions.checkArgument(offset == trainFeatures.length,
                "Didn't fill up the targeted split size of %s vectors in the training set!", splitSize);

        // Test set: every sample still left in the class queues.
        DoubleVector[] testFeatures = new DoubleVector[features.length - splitSize];
        DoubleVector[] testOutcomes = new DoubleVector[features.length - splitSize];
        offset = 0;
        for (int c = 0; c < numClasses; ++c) {
            Deque<Integer> queue = classQueues[c];
            while (!queue.isEmpty()) {
                Preconditions.checkArgument(offset < testFeatures.length,
                        "Features are overflowing the calculated testset size, stratifying failed.");
                int sample = queue.removeFirst();
                testFeatures[offset] = features[sample];
                testOutcomes[offset] = outcome[sample];
                ++offset;
            }
        }

        if (random) {
            ArrayUtils.multiShuffle(trainFeatures, new DoubleVector[][]{trainOutcomes});
            ArrayUtils.multiShuffle(testFeatures, new DoubleVector[][]{testOutcomes});
        }
        return new EvaluationSplit(trainFeatures, trainOutcomes, testFeatures, testOutcomes);
    }

    /** Validates the arguments shared by both factory methods. */
    private static void checkSplitArguments(DoubleVector[] features, DoubleVector[] outcome, float splitFraction) {
        Preconditions.checkArgument(features.length == outcome.length,
                "Feature vector and outcome vector must match in length!");
        Preconditions.checkArgument(splitFraction >= 0.0f && splitFraction <= 1.0f,
                "splitFraction must be between 0 and 1! Given: %s", splitFraction);
    }

    /**
     * Computes the number of training vectors for {@code total} vectors.
     *
     * <p>The product is deliberately computed in float. For example, 10 * 0.7f rounds to 7.0f,
     * whereas the same product in double is 6.99999988... and would truncate to 6.
     */
    private static int trainSize(int total, float splitFraction) {
        return (int) ((float) total * splitFraction);
    }

    /**
     * Groups sample indices by class label, preserving input order within each class.
     *
     * <p>There are {@code max(2, dimension)} classes, where {@code dimension} is that of
     * {@code outcome[0]}.
     *
     * @return one queue of sample indices per class label; every entry is non-null.
     * @throws IllegalArgumentException if an outcome maps to a label outside that range.
     * @throws NullPointerException if some label in that range has no samples.
     */
    private static Deque<Integer>[] groupByClass(DoubleVector[] outcome) {
        int dimension = outcome[0].getDimension();
        // A single output column still encodes two classes (labels 0 and 1).
        int numClasses = Math.max(2, dimension);
        @SuppressWarnings("unchecked") // generic array creation; the array never leaves this class
        Deque<Integer>[] classQueues = new Deque[numClasses];
        for (int i = 0; i < outcome.length; ++i) {
            int key = dimension < 2 ? (int) outcome[i].get(0) : outcome[i].maxIndex();
            Preconditions.checkArgument(key >= 0 && key < numClasses,
                    "Outcome %s maps to class %s, which is outside [0, %s).", i, key, numClasses);
            Deque<Integer> queue = classQueues[key];
            if (queue == null) {
                queue = new ArrayDeque<>();
                classQueues[key] = queue;
            }
            queue.addLast(i);
        }
        for (int c = 0; c < numClasses; ++c) {
            Preconditions.checkNotNull(classQueues[c],
                    "Queue for class %s couldn't be found. This happens when the mentioned class label doesn't exists in the given set of vectors.", c);
        }
        return classQueues;
    }
}

