/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree.impurity.gini;

import java.util.Arrays;
import java.util.Map;
import org.apache.ignite.ml.tree.TreeFilter;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.TreeDataIndex;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasure;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
import org.apache.ignite.ml.tree.impurity.gini.GiniImpurityMeasure;
import org.apache.ignite.ml.tree.impurity.util.StepFunction;

/**
 * Impurity measure calculator that builds, for every feature column, a step function whose steps
 * carry {@link GiniImpurityMeasure} values computed from per-label counts on the left and right
 * side of each candidate threshold.
 *
 * <p>Label values are mapped to counter positions through {@code lbEncoder}.
 */
public class GiniImpurityMeasureCalculator
extends ImpurityMeasureCalculator<GiniImpurityMeasure> {
    private static final long serialVersionUID = -522995134128519679L;

    /**
     * Maps a label value to the index of its counter in the per-label count arrays. Arrays are sized
     * {@code lbEncoder.size()}, so codes are assumed to lie in {@code [0, lbEncoder.size())}
     * (NOTE(review): confirm with whoever builds the encoder).
     */
    private final Map<Double, Integer> lbEncoder;

    /**
     * Constructs a new Gini impurity measure calculator.
     *
     * @param lbEncoder Label encoder mapping label values to counter indices.
     * @param useIdx Whether rows are accessed through a {@link TreeDataIndex} (built per call)
     *     instead of filtering and sorting the data itself.
     */
    public GiniImpurityMeasureCalculator(Map<Double, Integer> lbEncoder, boolean useIdx) {
        super(useIdx);
        this.lbEncoder = lbEncoder;
    }

    /**
     * Calculates one impurity step function per feature column for the rows accepted by
     * {@code filter}.
     *
     * @param data Decision tree data.
     * @param filter Filter selecting the rows of the current node.
     * @param depth Depth of the current node (used when building the index).
     * @return One step function per column, or {@code null} if no rows pass the filter.
     */
    @Override
    public StepFunction<GiniImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
        TreeDataIndex idx = null;
        boolean canCalculate;
        if (this.useIdx) {
            idx = data.createIndexByFilter(depth, filter);
            canCalculate = idx.rowsCount() > 0;
        } else {
            data = data.filter(filter);
            canCalculate = data.getFeatures().length > 0;
        }
        if (!canCalculate) {
            return null;
        }

        int rowsCnt = this.rowsCount(data, idx);
        int colsCnt = this.columnsCount(data, idx);

        // Generic array creation is impossible in Java; the raw array only ever holds
        // StepFunction<GiniImpurityMeasure> instances created below.
        @SuppressWarnings("unchecked")
        StepFunction<GiniImpurityMeasure>[] res = new StepFunction[colsCnt];

        long[] right = this.countLabels(data, idx, rowsCnt);

        for (int col = 0; col < colsCnt; ++col) {
            if (!this.useIdx) {
                // Without an index the (already filtered) data is re-sorted by the current column.
                data.sort(col);
            }
            res[col] = this.buildStepFunction(data, idx, col, rowsCnt, right);
        }
        return res;
    }

    /**
     * Counts, for each encoded label, how many of the {@code rowsCnt} rows carry it.
     *
     * @return Array of size {@code lbEncoder.size()} indexed by label code.
     */
    private long[] countLabels(DecisionTreeData data, TreeDataIndex idx, int rowsCnt) {
        long[] cnts = new long[this.lbEncoder.size()];
        for (int i = 0; i < rowsCnt; ++i) {
            // Rows are visited via column 0; the totals do not depend on the visiting order.
            cnts[this.getLabelCode(this.getLabelValue(data, idx, 0, i))]++;
        }
        return cnts;
    }

    /**
     * Builds the step function for a single column. Rows are expected to be visited in sorted order
     * of that column's feature value (presumably ascending, given the leading
     * {@code NEGATIVE_INFINITY} step — confirm with {@code data.sort} / {@link TreeDataIndex}).
     *
     * @param right Total per-label counts; not modified.
     * @return Step function with one step at {@code -Infinity} plus one step per distinct feature value.
     */
    private StepFunction<GiniImpurityMeasure> buildStepFunction(DecisionTreeData data, TreeDataIndex idx,
        int col, int rowsCnt, long[] right) {
        double[] x = new double[rowsCnt + 1];
        GiniImpurityMeasure[] y = new GiniImpurityMeasure[rowsCnt + 1];
        long[] left = new long[right.length];
        long[] rightCp = Arrays.copyOf(right, right.length);

        int ptr = 0;
        // Leading step: no rows on the left, all rows on the right.
        x[ptr] = Double.NEGATIVE_INFINITY;
        y[ptr++] = new GiniImpurityMeasure(Arrays.copyOf(left, left.length), Arrays.copyOf(rightCp, rightCp.length));

        for (int i = 0; i < rowsCnt; ++i) {
            // Look the label code up once and move that row from the right side to the left side.
            int code = this.getLabelCode(this.getLabelValue(data, idx, col, i));
            left[code]++;
            rightCp[code]--;

            double featureVal = this.getFeatureValue(data, idx, col, i);
            // Emit a step only after the last row of a run of equal feature values, so equal values
            // always end up on the same side of a threshold.
            if (i < rowsCnt - 1 && this.getFeatureValue(data, idx, col, i + 1) == featureVal) {
                continue;
            }
            x[ptr] = featureVal;
            // Snapshot the counts: left/rightCp keep mutating in later iterations.
            y[ptr++] = new GiniImpurityMeasure(Arrays.copyOf(left, left.length), Arrays.copyOf(rightCp, rightCp.length));
        }
        return new StepFunction<>(Arrays.copyOf(x, ptr), Arrays.copyOf(y, ptr));
    }

    /**
     * Returns the counter index assigned to the given label.
     *
     * @param lb Label value.
     * @return Label code taken from {@code lbEncoder}.
     * @throws IllegalArgumentException If {@code lb} has no entry in the label encoder.
     */
    int getLabelCode(double lb) {
        Integer code = this.lbEncoder.get(lb);
        if (code == null) {
            // Explicit check rather than an assert: with assertions disabled (the JVM default) the
            // unboxing below would otherwise fail with a context-free NullPointerException.
            throw new IllegalArgumentException("Can't find code for label " + lb);
        }
        return code;
    }
}

