/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.selection.scoring.evaluator.aggregator;

import java.io.Serializable;
import java.util.Objects;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.selection.scoring.evaluator.aggregator.MetricStatsAggregator;
import org.apache.ignite.ml.selection.scoring.evaluator.context.BinaryClassificationEvaluationContext;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * Accumulates the pointwise confusion-matrix counts (true/false positives and true/false negatives)
 * of a binary classifier.
 *
 * <p>The two class labels come from a {@link BinaryClassificationEvaluationContext} passed to
 * {@link #initByContext}. The context's first class label is treated as the negative ("false")
 * label and the second as the positive ("truth") label. A sample whose predicted or real label
 * matches neither of them (including {@code null}) is not counted.
 *
 * @param <L> Label type.
 */
public class BinaryClassificationPointwiseMetricStatsAggregator<L extends Serializable>
implements MetricStatsAggregator<L, BinaryClassificationEvaluationContext<L>, BinaryClassificationPointwiseMetricStatsAggregator<L>> {
    /** Serial version UID. */
    private static final long serialVersionUID = -7677193556950322385L;

    /** Label treated as the negative class; {@code null} until initialized. */
    private L falseLabel;

    /** Label treated as the positive class; {@code null} until initialized. */
    private L truthLabel;

    /** Count of samples predicted as {@code truthLabel} whose real label is {@code truthLabel}. */
    private int truePositive;

    // NOTE(review): the three counters below are package-private while truePositive is private;
    // the visibility is kept as-is in case other classes in this package access them directly.

    /** Count of samples predicted as {@code truthLabel} whose real label is {@code falseLabel}. */
    int falsePositive;

    /** Count of samples predicted as {@code falseLabel} whose real label is {@code falseLabel}. */
    int trueNegative;

    /** Count of samples predicted as {@code falseLabel} whose real label is {@code truthLabel}. */
    int falseNegative;

    /**
     * Creates an aggregator with zero counts and no labels.
     * The labels must be set via {@link #initByContext} before use.
     */
    public BinaryClassificationPointwiseMetricStatsAggregator() {
    }

    /**
     * Creates an aggregator with the given labels and counts.
     *
     * @param falseLabel Negative class label.
     * @param truthLabel Positive class label.
     * @param truePositive True positive count.
     * @param falsePositive False positive count.
     * @param trueNegative True negative count.
     * @param falseNegative False negative count.
     */
    public BinaryClassificationPointwiseMetricStatsAggregator(L falseLabel, L truthLabel, int truePositive, int falsePositive, int trueNegative, int falseNegative) {
        this.falseLabel = falseLabel;
        this.truthLabel = truthLabel;
        this.truePositive = truePositive;
        this.falsePositive = falsePositive;
        this.trueNegative = trueNegative;
        this.falseNegative = falseNegative;
    }

    /**
     * Classifies one sample with {@code mdl} and increments the matching confusion-matrix counter.
     * Samples whose prediction or real label is {@code null}, or equals neither class label,
     * leave all counters unchanged.
     *
     * @param mdl Model under evaluation.
     * @param vector Labeled sample.
     */
    @Override
    public void aggregate(IgniteModel<Vector, L> mdl, LabeledVector<L> vector) {
        L modelAns = mdl.predict(vector.features());
        L realAns = vector.label();

        boolean predictedFalse = matches(modelAns, this.falseLabel);
        boolean predictedTrue = matches(modelAns, this.truthLabel);
        boolean actualFalse = matches(realAns, this.falseLabel);
        boolean actualTrue = matches(realAns, this.truthLabel);

        // Branch order matters only in the degenerate case falseLabel.equals(truthLabel);
        // it is kept identical to the original so the first matching branch wins.
        if (predictedFalse && actualFalse) {
            ++this.trueNegative;
        } else if (predictedFalse && actualTrue) {
            ++this.falseNegative;
        } else if (predictedTrue && actualTrue) {
            ++this.truePositive;
        } else if (predictedTrue && actualFalse) {
            ++this.falsePositive;
        }
    }

    /**
     * Returns {@code true} if {@code val} is non-null and equal to {@code lbl}.
     * Calls {@code val.equals(lbl)}, so a {@code null} value never matches (and never throws).
     */
    private static boolean matches(Object val, Object lbl) {
        return val != null && val.equals(lbl);
    }

    /**
     * Returns a new aggregator whose counts are the sums of this and {@code other}.
     * Neither input is modified.
     *
     * @param other Aggregator to merge with; it must use the same labels.
     * @return Merged aggregator.
     */
    @Override
    public BinaryClassificationPointwiseMetricStatsAggregator<L> mergeWith(BinaryClassificationPointwiseMetricStatsAggregator<L> other) {
        // Objects.equals avoids an NPE when labels are not initialized yet.
        A.ensure(Objects.equals(this.falseLabel, other.falseLabel), "this.falseLabel == other.falseLabel");
        A.ensure(Objects.equals(this.truthLabel, other.truthLabel), "this.truthLabel == other.truthLabel");
        return new BinaryClassificationPointwiseMetricStatsAggregator<L>(
            this.falseLabel,
            this.truthLabel,
            this.truePositive + other.truePositive,
            this.falsePositive + other.falsePositive,
            this.trueNegative + other.trueNegative,
            this.falseNegative + other.falseNegative
        );
    }

    /**
     * Creates a fresh evaluation context using its no-argument constructor.
     *
     * @return New context.
     */
    @Override
    public BinaryClassificationEvaluationContext<L> createInitializedContext() {
        return new BinaryClassificationEvaluationContext<>();
    }

    /**
     * Takes the class labels from the context.
     * The first label becomes the negative label and the second the positive label.
     *
     * @param ctx Evaluation context.
     */
    @Override
    public void initByContext(BinaryClassificationEvaluationContext<L> ctx) {
        this.falseLabel = ctx.getFirstClsLbl();
        this.truthLabel = ctx.getSecondClsLbl();
    }

    /** @return Negative class label (may be {@code null} before initialization). */
    public L getFalseLabel() {
        return this.falseLabel;
    }

    /** @return Positive class label (may be {@code null} before initialization). */
    public L getTruthLabel() {
        return this.truthLabel;
    }

    /** @return True positive count. */
    public int getTruePositive() {
        return this.truePositive;
    }

    /** @return False positive count. */
    public int getFalsePositive() {
        return this.falsePositive;
    }

    /** @return True negative count. */
    public int getTrueNegative() {
        return this.trueNegative;
    }

    /** @return False negative count. */
    public int getFalseNegative() {
        return this.falseNegative;
    }

    /**
     * Returns the total number of counted samples, i.e. the sum of the four counters.
     * The sum uses {@code int} arithmetic and wraps around past {@link Integer#MAX_VALUE}.
     *
     * @return Number of counted samples.
     */
    public int getN() {
        return this.truePositive + this.falsePositive + this.trueNegative + this.falseNegative;
    }

    /**
     * Aggregator whose class labels are supplied by the caller rather than computed from the data.
     * The context it creates reports {@code needToCompute() == false}.
     *
     * @param <L> Label type.
     */
    public static class WithCustomLabelsAggregator<L extends Serializable>
    extends BinaryClassificationPointwiseMetricStatsAggregator<L> {
        // These fields intentionally share names with the parent's private label fields. They are
        // separate fields, used only to build the context. Names are kept unchanged to preserve the
        // default serialized form of this class, which declares no serialVersionUID.

        /** User-supplied positive class label. */
        private final L truthLabel;

        /** User-supplied negative class label. */
        private final L falseLabel;

        /**
         * Creates the aggregator.
         * The argument order (truth, false) is the reverse of the parent's full constructor.
         *
         * @param truthLabel Positive class label.
         * @param falseLabel Negative class label.
         */
        public WithCustomLabelsAggregator(L truthLabel, L falseLabel) {
            // Set the parent's labels eagerly so the aggregator is usable even before initByContext;
            // initByContext later assigns the same values from the context created below.
            super(falseLabel, truthLabel, 0, 0, 0, 0);
            this.truthLabel = truthLabel;
            this.falseLabel = falseLabel;
        }

        /**
         * Creates a context preset with the custom labels.
         * The negative label is first, the positive label second, and the context reports that
         * nothing needs to be computed.
         *
         * @return Preset context.
         */
        @Override
        public BinaryClassificationEvaluationContext<L> createInitializedContext() {
            return new BinaryClassificationEvaluationContext<L>(this.falseLabel, this.truthLabel) {
                private static final long serialVersionUID = 4739649114414953828L;

                @Override
                public boolean needToCompute() {
                    return false;
                }
            };
        }
    }
}

