/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.structures;

import java.io.Serializable;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
import java.util.TreeSet;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.jetbrains.annotations.NotNull;

/**
 * Random split of a {@link LabeledVectorSet} into two disjoint subsets: a train subset and a test subset.
 * <p>
 * Every row of the source dataset is placed in exactly one of the two subsets. Within each subset, rows keep the
 * relative order they had in the source dataset. The row objects returned by {@code dataset.getRow(i)} are placed
 * into the subsets directly; no copying of rows is done here.
 */
public class LabeledVectorSetTestTrainPair
implements Serializable {
    /** Rows that were not selected for the test subset. */
    private final LabeledVectorSet train;

    /** Rows selected uniformly at random for the test subset. */
    private final LabeledVectorSet test;

    /**
     * Splits {@code dataset} into train and test subsets using a fresh, unseeded {@link Random}.
     *
     * @param dataset Source dataset; expected to contain more than two rows.
     * @param testPercentage Fraction of rows to place in the test subset, strictly between 0 and 1. The test subset
     *     size is {@code floor(dataset.rowSize() * testPercentage)}, so it can be zero for small datasets.
     */
    public LabeledVectorSetTestTrainPair(LabeledVectorSet dataset, double testPercentage) {
        this(dataset, testPercentage, new Random());
    }

    /**
     * Splits {@code dataset} into train and test subsets using the supplied random source. Passing a seeded
     * {@link Random} makes the split reproducible.
     *
     * @param dataset Source dataset; expected to contain more than two rows.
     * @param testPercentage Fraction of rows to place in the test subset, strictly between 0 and 1. The test subset
     *     size is {@code floor(dataset.rowSize() * testPercentage)}, so it can be zero for small datasets.
     * @param rnd Random source used to choose the test rows.
     */
    public LabeledVectorSetTestTrainPair(LabeledVectorSet dataset, double testPercentage, Random rnd) {
        assert testPercentage > 0.0 : "testPercentage must be > 0, got " + testPercentage;
        assert testPercentage < 1.0 : "testPercentage must be < 1, got " + testPercentage;
        assert rnd != null : "rnd must not be null";

        int datasetSize = dataset.rowSize();
        assert datasetSize > 2 : "dataset must contain more than 2 rows, got " + datasetSize;

        int testSize = (int)Math.floor((double)datasetSize * testPercentage);
        int trainSize = datasetSize - testSize;

        // Allocate before sampling. If assertions are disabled and testPercentage is out of range, this fails
        // fast with NegativeArraySizeException. It also guarantees 0 <= testSize <= datasetSize for the sampler.
        LabeledVector[] testVectors = new LabeledVector[testSize];
        LabeledVector[] trainVectors = new LabeledVector[trainSize];

        boolean[] isTest = selectTestRows(datasetSize, testSize, rnd);

        // A single ordered pass keeps each subset in source-dataset order.
        int testCntr = 0;
        int trainCntr = 0;
        for (int i = 0; i < datasetSize; i++) {
            LabeledVector row = (LabeledVector)dataset.getRow(i);

            if (isTest[i])
                testVectors[testCntr++] = row;
            else
                trainVectors[trainCntr++] = row;
        }

        this.test = new LabeledVectorSet(testVectors, dataset.colSize());
        this.train = new LabeledVectorSet(trainVectors, dataset.colSize());
    }

    /**
     * Chooses {@code testSize} distinct row indices uniformly at random from {@code [0, datasetSize)}.
     * <p>
     * Uses a partial Fisher-Yates shuffle. After {@code k} iterations, the first {@code k} slots of {@code idx}
     * hold a uniformly random {@code k}-subset of the indices. Unlike sorting rows by random {@code double} keys,
     * this cannot lose rows to key collisions, and it runs in O(datasetSize).
     *
     * @param datasetSize Number of rows in the source dataset.
     * @param testSize Number of rows to select; must satisfy {@code 0 <= testSize <= datasetSize}.
     * @param rnd Random source.
     * @return Mask of length {@code datasetSize} where {@code mask[i]} is {@code true} iff row {@code i} belongs
     *     to the test subset.
     */
    @NotNull
    private static boolean[] selectTestRows(int datasetSize, int testSize, Random rnd) {
        int[] idx = new int[datasetSize];
        for (int i = 0; i < datasetSize; i++)
            idx[i] = i;

        boolean[] isTest = new boolean[datasetSize];
        for (int i = 0; i < testSize; i++) {
            // Bound is >= 1 because i < testSize <= datasetSize.
            int j = i + rnd.nextInt(datasetSize - i);

            int tmp = idx[i];
            idx[i] = idx[j];
            idx[j] = tmp;

            isTest[idx[i]] = true;
        }

        return isTest;
    }

    /**
     * @return Subset of rows not selected for testing, in source-dataset order.
     */
    public LabeledVectorSet train() {
        return this.train;
    }

    /**
     * @return Randomly selected test subset, in source-dataset order.
     */
    public LabeledVectorSet test() {
        return this.test;
    }
}

