/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree.randomforest;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Random;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.feature.BucketMeta;
import org.apache.ignite.ml.dataset.feature.FeatureMeta;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetBuilder;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.apache.ignite.ml.tree.randomforest.RandomForestModel;
import org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies;
import org.apache.ignite.ml.tree.randomforest.data.NodeId;
import org.apache.ignite.ml.tree.randomforest.data.NodeSplit;
import org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel;
import org.apache.ignite.ml.tree.randomforest.data.TreeNode;
import org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityComputer;
import org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityHistogramsComputer;
import org.apache.ignite.ml.tree.randomforest.data.statistics.LeafValuesComputer;
import org.apache.ignite.ml.tree.randomforest.data.statistics.NormalDistributionStatistics;
import org.apache.ignite.ml.tree.randomforest.data.statistics.NormalDistributionStatisticsComputer;

public abstract class RandomForestTrainer<L, S extends ImpurityComputer<BootstrappedVector, S>, T extends RandomForestTrainer<L, S, T>>
extends SingleLabelDatasetTrainer<RandomForestModel> {
    private static final double BUCKET_SIZE_FACTOR = 0.1;
    private int amountOfTrees = 1;
    private double subSampleSize = 1.0;
    private int maxDepth = 5;
    private double minImpurityDelta;
    private List<FeatureMeta> meta;
    private int featuresPerTree = 5;
    private long seed = 1234L;
    private Random random = new Random(this.seed);
    private Function<Queue<TreeNode>, List<TreeNode>> nodesToLearnSelectionStrgy = this::defaultNodesToLearnSelectionStrgy;

    public RandomForestTrainer(List<FeatureMeta> meta) {
        this.meta = meta;
        this.featuresPerTree = (Integer)FeaturesCountSelectionStrategies.ALL.apply(meta);
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    @Override
    public <K, V> RandomForestModel fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        List<RandomForestTreeModel> models = null;
        try (Dataset<EmptyContext, BootstrappedDatasetPartition> dataset = datasetBuilder.build(this.envBuilder, new EmptyContextBuilder(), new BootstrappedDatasetBuilder<K, V>(preprocessor, this.amountOfTrees, this.subSampleSize), this.learningEnvironment());){
            if (!this.init(dataset)) {
                RandomForestModel randomForestModel = this.buildComposition(Collections.emptyList());
                return randomForestModel;
            }
            models = this.fit(dataset);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        if ($assertionsDisabled) return this.buildComposition(models);
        if (models != null) return this.buildComposition(models);
        throw new AssertionError();
    }

    protected abstract T instance();

    public T withAmountOfTrees(int amountOfTrees) {
        this.amountOfTrees = amountOfTrees;
        return this.instance();
    }

    public T withSubSampleSize(double subSampleSize) {
        this.subSampleSize = subSampleSize;
        return this.instance();
    }

    public T withMaxDepth(int maxDepth) {
        this.maxDepth = maxDepth;
        return this.instance();
    }

    public T withMinImpurityDelta(double minImpurityDelta) {
        this.minImpurityDelta = minImpurityDelta;
        return this.instance();
    }

    public T withFeaturesCountSelectionStrgy(Function<List<FeatureMeta>, Integer> strgy) {
        this.featuresPerTree = strgy.apply(this.meta);
        return this.instance();
    }

    public T withNodesToLearnSelectionStrgy(Function<Queue<TreeNode>, List<TreeNode>> strgy) {
        this.nodesToLearnSelectionStrgy = strgy;
        return this.instance();
    }

    public T withSeed(long seed) {
        this.seed = seed;
        this.random = new Random(seed);
        return this.instance();
    }

    protected boolean init(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
        return true;
    }

    private List<RandomForestTreeModel> fit(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
        Queue<TreeNode> treesQueue = this.createRootsQueue();
        ArrayList<RandomForestTreeModel> roots = this.initTrees(treesQueue);
        Map<Integer, BucketMeta> histMeta = this.computeHistogramMeta(this.meta, dataset);
        if (histMeta.isEmpty()) {
            return Collections.emptyList();
        }
        ImpurityHistogramsComputer<S> histogramsComputer = this.createImpurityHistogramsComputer();
        while (!treesQueue.isEmpty()) {
            Map<NodeId, TreeNode> nodesToLearn = this.getNodesToLearn(treesQueue);
            Map<NodeId, ImpurityHistogramsComputer.NodeImpurityHistograms<S>> nodesImpHists = histogramsComputer.aggregateImpurityStatistics(roots, histMeta, nodesToLearn, dataset);
            if (nodesToLearn.size() != nodesImpHists.size()) {
                throw new IllegalStateException();
            }
            for (NodeId nodeId : nodesImpHists.keySet()) {
                this.split(treesQueue, nodesToLearn, nodesImpHists.get(nodeId));
            }
        }
        this.createLeafStatisticsAggregator().setValuesForLeaves(roots, dataset);
        return roots;
    }

    @Override
    public boolean isUpdateable(RandomForestModel mdl) {
        RandomForestModel fakeComposition = this.buildComposition(Collections.emptyList());
        return mdl.getPredictionsAggregator().getClass() == fakeComposition.getPredictionsAggregator().getClass();
    }

    @Override
    protected <K, V> RandomForestModel updateModel(RandomForestModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
        ArrayList<RandomForestTreeModel> oldModels = new ArrayList<RandomForestTreeModel>(mdl.getModels());
        RandomForestModel newModels = (RandomForestModel)this.fit(datasetBuilder, preprocessor);
        oldModels.addAll(newModels.getModels());
        return new RandomForestModel((List<RandomForestTreeModel>)oldModels, mdl.getPredictionsAggregator());
    }

    private void split(Queue<TreeNode> learningQueue, Map<NodeId, TreeNode> nodesToLearn, ImpurityHistogramsComputer.NodeImpurityHistograms<S> nodeImpurityHistograms) {
        Optional<NodeSplit> bestSplit;
        TreeNode cornerNode = nodesToLearn.get(nodeImpurityHistograms.getNodeId());
        if (this.needSplit(cornerNode, bestSplit = nodeImpurityHistograms.findBestSplit())) {
            List<TreeNode> children = bestSplit.get().split(cornerNode);
            learningQueue.addAll(children);
        } else if (bestSplit.isPresent()) {
            bestSplit.get().createLeaf(cornerNode);
        } else {
            cornerNode.setImpurity(Double.NEGATIVE_INFINITY);
            cornerNode.toLeaf(0.0);
        }
    }

    protected abstract ImpurityHistogramsComputer<S> createImpurityHistogramsComputer();

    protected abstract LeafValuesComputer<L> createLeafStatisticsAggregator();

    protected ArrayList<RandomForestTreeModel> initTrees(Queue<TreeNode> treesQueue) {
        assert (this.featuresPerTree > 0);
        ArrayList<RandomForestTreeModel> roots = new ArrayList<RandomForestTreeModel>();
        List allFeatureIds = IntStream.range(0, this.meta.size()).boxed().collect(Collectors.toList());
        for (TreeNode node : treesQueue) {
            Collections.shuffle(allFeatureIds, this.random);
            Set<Integer> featuresSubspace = allFeatureIds.stream().limit(this.featuresPerTree).collect(Collectors.toSet());
            roots.add(new RandomForestTreeModel(node, featuresSubspace));
        }
        return roots;
    }

    private Map<Integer, BucketMeta> computeHistogramMeta(List<FeatureMeta> meta, Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
        List<NormalDistributionStatistics> stats = new NormalDistributionStatisticsComputer().computeStatistics(meta, dataset);
        if (stats == null) {
            return Collections.emptyMap();
        }
        HashMap<Integer, BucketMeta> bucketsMeta = new HashMap<Integer, BucketMeta>();
        for (int i = 0; i < stats.size(); ++i) {
            BucketMeta bucketMeta = new BucketMeta(meta.get(i));
            if (!bucketMeta.getFeatureMeta().isCategoricalFeature()) {
                NormalDistributionStatistics stat = stats.get(i);
                bucketMeta.setMinVal(stat.min());
                bucketMeta.setBucketSize(stat.std() * 0.1);
            }
            bucketsMeta.put(i, bucketMeta);
        }
        return bucketsMeta;
    }

    private Queue<TreeNode> createRootsQueue() {
        LinkedList<TreeNode> roots = new LinkedList<TreeNode>();
        for (int i = 0; i < this.amountOfTrees; ++i) {
            roots.add(new TreeNode(1L, i));
        }
        return roots;
    }

    private Map<NodeId, TreeNode> getNodesToLearn(Queue<TreeNode> queue) {
        return this.nodesToLearnSelectionStrgy.apply(queue).stream().collect(Collectors.toMap(TreeNode::getId, node -> node));
    }

    private List<TreeNode> defaultNodesToLearnSelectionStrgy(Queue<TreeNode> queue) {
        ArrayList<TreeNode> res = new ArrayList<TreeNode>(queue);
        queue.clear();
        return res;
    }

    boolean needSplit(TreeNode parentNode, Optional<NodeSplit> split) {
        return split.isPresent() && parentNode.getImpurity() - split.get().getImpurity() > this.minImpurityDelta && parentNode.getDepth() < this.maxDepth + 1;
    }

    protected abstract RandomForestModel buildComposition(List<RandomForestTreeModel> var1);
}

