package org.apache.samoa.evaluation;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Vector;
import org.apache.samoa.instances.Attribute;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Utils;
import org.apache.samoa.moa.AbstractMOAObject;
import org.apache.samoa.moa.core.Measurement;
import org.apache.samoa.moa.core.Vote;

/* loaded from: input_file:org/apache/samoa/evaluation/F1ClassificationPerformanceEvaluator.class */
/**
 * Streaming classification evaluator that reports per-class support, precision, recall and
 * F1-score.
 *
 * <p>Per-class confusion counts (true/false positives, true/false negatives) are accumulated by
 * {@link #addResult}. The number of classes is taken from the first instance seen, unless
 * {@link #reset(int)} has been called with an explicit count. Instance weights and the delay
 * argument are not used: every result counts exactly once.
 *
 * <p>Precision, recall and F1 are reported as {@code 0.0} when they are undefined (zero
 * denominator), e.g. for a class that has never been predicted or has never occurred.
 */
public class F1ClassificationPerformanceEvaluator extends AbstractMOAObject implements ClassificationPerformanceEvaluator {
    private static final long serialVersionUID = 1;

    /** Value of {@link #numClasses} while the counters have not been sized yet. */
    private static final int UNINITIALIZED = -1;

    /** Number of tracked classes, or {@code -1} until the counters are sized. */
    protected int numClasses = UNINITIALIZED;
    /** Per class: number of instances whose true label is that class. */
    protected long[] support;
    /** Per class: predicted as that class and truly that class. */
    protected long[] truePos;
    /** Per class: predicted as that class but truly another class. */
    protected long[] falsePos;
    /** Per class: neither predicted as that class nor truly that class. */
    protected long[] trueNeg;
    /** Per class: truly that class but predicted as another class. */
    protected long[] falseNeg;
    /** Identifier passed with the most recent result, or null if none since the last reset. */
    private String instanceIdentifier;
    /** Most recent instance passed to addResult, or null if none since the last reset. */
    private Instance lastSeenInstance;
    /** Copy of the class votes of the most recent result, or null if none since the last reset. */
    protected double[] classVotes;

    /**
     * Clears all counters while keeping the current number of classes. If the counters have not
     * been sized yet, they stay unsized and are sized by the next {@link #addResult} call.
     */
    @Override // org.apache.samoa.evaluation.PerformanceEvaluator
    public void reset() {
        reset(this.numClasses);
    }

    /**
     * Clears all counters and sizes them for {@code classCount} classes.
     *
     * @param classCount number of classes to track; a negative value returns the evaluator to the
     *     unsized state, so the next result determines the class count
     */
    public void reset(int classCount) {
        // A negative count (notably the -1 sentinel reached via reset() before any result) would
        // otherwise throw NegativeArraySizeException.
        this.numClasses = classCount < 0 ? UNINITIALIZED : classCount;
        int size = Math.max(classCount, 0);
        // New Java arrays are zero-initialized; no explicit clearing loop is needed.
        this.support = new long[size];
        this.truePos = new long[size];
        this.falsePos = new long[size];
        this.trueNeg = new long[size];
        this.falseNeg = new long[size];
        this.lastSeenInstance = null;
        this.classVotes = null;
        this.instanceIdentifier = null;
    }

    /**
     * Records one prediction.
     *
     * @param instance the labelled instance; its class value is the true class index
     * @param votes class votes from the classifier; the index of the largest vote is the
     *     predicted class
     * @param instanceId identifier of the instance, reported by {@link #getPredictionVotes}
     * @param delay not used by this evaluator
     */
    @Override // org.apache.samoa.evaluation.PerformanceEvaluator
    public void addResult(Instance instance, double[] votes, String instanceId, long delay) {
        if (this.numClasses == UNINITIALIZED) {
            reset(instance.numClasses());
        }
        int trueClass = (int) instance.classValue();
        int predictedClass = Utils.maxIndex(votes);

        this.support[trueClass]++;
        if (predictedClass == trueClass) {
            this.truePos[trueClass]++;
        } else {
            this.falsePos[predictedClass]++;
            this.falseNeg[trueClass]++;
        }
        // Every class that is neither the true nor the predicted class was correctly rejected.
        for (int c = 0; c < this.numClasses; c++) {
            if (c != trueClass && c != predictedClass) {
                this.trueNeg[c]++;
            }
        }

        // Remember the latest prediction so getPredictionVotes() has something to report.
        // The votes are copied because callers may reuse the array.
        this.lastSeenInstance = instance;
        this.classVotes = votes.clone();
        this.instanceIdentifier = instanceId;
    }

    /**
     * Returns support, precision, recall and F1 measurements, one per class for each metric, in
     * that order. The result is empty while no class count is known.
     */
    @Override // org.apache.samoa.evaluation.PerformanceEvaluator
    public Measurement[] getPerformanceMeasurements() {
        List<Measurement> measurements = new ArrayList<Measurement>();
        Collections.addAll(measurements, getSupportMeasurements());
        Collections.addAll(measurements, getPrecisionMeasurements());
        Collections.addAll(measurements, getRecallMeasurements());
        Collections.addAll(measurements, getF1Measurements());
        return measurements.toArray(new Measurement[measurements.size()]);
    }

    /**
     * Describes the most recent prediction. The result contains the instance identifier, the true
     * class value and the predicted class value, followed by one {@code votes_<class>} entry per
     * class value.
     *
     * @throws IllegalStateException if no result has been added since the last reset
     */
    @Override // org.apache.samoa.evaluation.PerformanceEvaluator
    public Vote[] getPredictionVotes() {
        if (this.lastSeenInstance == null) {
            throw new IllegalStateException("No result has been added since the last reset");
        }
        Attribute classAttribute = this.lastSeenInstance.dataset().classAttribute();
        List<String> classValues = classAttribute.getAttributeValues();
        String trueValue = classValues.get((int) this.lastSeenInstance.classValue());
        String predictedValue = classValues.get(Utils.maxIndex(this.classVotes));

        // Three header entries, then exactly one vote entry per class value (no null slots).
        Vote[] votes = new Vote[3 + classValues.size()];
        votes[0] = new Vote("instance number", this.instanceIdentifier);
        votes[1] = new Vote("true class value", trueValue);
        votes[2] = new Vote("predicted class value", predictedValue);
        for (int c = 0; c < classValues.size(); c++) {
            // A classifier may emit fewer votes than there are classes; missing votes count as 0.
            double vote = c < this.classVotes.length ? this.classVotes[c] : 0.0d;
            votes[3 + c] = new Vote("votes_" + classValues.get(c), vote);
        }
        return votes;
    }

    /** Number of classes to report measurements for; 0 while the counters are unsized. */
    private int reportedClassCount() {
        return Math.max(this.numClasses, 0);
    }

    /** One "class N support" measurement per class (raw count of true labels). */
    private Measurement[] getSupportMeasurements() {
        Measurement[] measurements = new Measurement[reportedClassCount()];
        for (int c = 0; c < measurements.length; c++) {
            measurements[c] = new Measurement(String.format("class %s support", c), this.support[c]);
        }
        return measurements;
    }

    /** One "class N precision" measurement per class. */
    private Measurement[] getPrecisionMeasurements() {
        Measurement[] measurements = new Measurement[reportedClassCount()];
        for (int c = 0; c < measurements.length; c++) {
            measurements[c] = new Measurement(String.format("class %s precision", c), getPrecision(c), 10);
        }
        return measurements;
    }

    /** One "class N recall" measurement per class. */
    private Measurement[] getRecallMeasurements() {
        Measurement[] measurements = new Measurement[reportedClassCount()];
        for (int c = 0; c < measurements.length; c++) {
            measurements[c] = new Measurement(String.format("class %s recall", c), getRecall(c), 10);
        }
        return measurements;
    }

    /** One "class N f1-score" measurement per class. */
    private Measurement[] getF1Measurements() {
        Measurement[] measurements = new Measurement[reportedClassCount()];
        for (int c = 0; c < measurements.length; c++) {
            measurements[c] = new Measurement(String.format("class %s f1-score", c), getF1Score(c), 10);
        }
        return measurements;
    }

    @Override // org.apache.samoa.moa.MOAObject
    public void getDescription(StringBuilder sb, int i) {
        Measurement.getMeasurementsDescription(getSupportMeasurements(), sb, i);
        Measurement.getMeasurementsDescription(getPrecisionMeasurements(), sb, i);
        Measurement.getMeasurementsDescription(getRecallMeasurements(), sb, i);
        Measurement.getMeasurementsDescription(getF1Measurements(), sb, i);
    }

    /** Precision of class {@code i}: TP / (TP + FP), or 0.0 if the class was never predicted. */
    private double getPrecision(int i) {
        return ratio(this.truePos[i], this.truePos[i] + this.falsePos[i]);
    }

    /** Recall of class {@code i}: TP / (TP + FN), or 0.0 if the class never occurred. */
    private double getRecall(int i) {
        return ratio(this.truePos[i], this.truePos[i] + this.falseNeg[i]);
    }

    /** F1 of class {@code i}: harmonic mean of precision and recall, or 0.0 if both are 0. */
    private double getF1Score(int i) {
        double precision = getPrecision(i);
        double recall = getRecall(i);
        double sum = precision + recall;
        return sum == 0.0d ? 0.0d : (2.0d * precision * recall) / sum;
    }

    /**
     * Floating-point ratio of two counts. Plain {@code long / long} would truncate to an integer
     * and throw ArithmeticException on a zero denominator; here a zero denominator yields 0.0.
     */
    private static double ratio(long numerator, long denominator) {
        return denominator == 0 ? 0.0d : (double) numerator / (double) denominator;
    }
}
