package org.apache.samoa.learners.classifiers.ensemble;

import java.util.HashMap;
import java.util.Map;
import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.learners.ResultContentEvent;
import org.apache.samoa.moa.core.DoubleVector;
import org.apache.samoa.topology.Stream;

/* loaded from: input_file:org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.class */
/**
 * Processor that combines the class votes produced by the members of an ensemble.
 *
 * <p>Each incoming {@link ResultContentEvent} carries the votes of one ensemble member for one
 * instance. Votes are accumulated per instance index. A single {@link ResultContentEvent} holding
 * the combined votes is put on the output stream, and that instance's accumulated statistics are
 * discarded, when either of these happens:
 * <ul>
 *   <li>{@code ensembleSize} events have been counted for the instance;</li>
 *   <li>an event flagged as the last event arrives.</li>
 * </ul>
 */
public class PredictionCombinerProcessor implements Processor {
    private static final long serialVersionUID = -1606045723451191132L;

    /** Number of member events expected per instance before the combined vote is emitted. */
    protected int ensembleSize;

    /** Stream on which combined predictions are emitted. */
    protected Stream outputStream;

    /** Number of member events received so far, keyed by instance index; created lazily. */
    protected Map<Integer, Integer> mapCountsforInstanceReceived;

    /** Running sum of normalized, weighted member votes, keyed by instance index; created lazily. */
    protected Map<Integer, DoubleVector> mapVotesforInstanceReceived;

    /**
     * Sets the stream on which combined predictions are emitted.
     *
     * @param stream the output stream
     */
    public void setOutputStream(Stream stream) {
        this.outputStream = stream;
    }

    /**
     * Returns the stream on which combined predictions are emitted.
     *
     * @return the output stream, or {@code null} if none has been set
     */
    public Stream getOutputStream() {
        return this.outputStream;
    }

    /**
     * Returns the number of member events expected per instance.
     *
     * @return the ensemble size
     */
    public int getEnsembleSize() {
        return this.ensembleSize;
    }

    /**
     * Sets the number of member events expected per instance.
     *
     * @param i the ensemble size
     */
    public void setEnsembleSize(int i) {
        this.ensembleSize = i;
    }

    /**
     * Accumulates the votes carried by one ensemble member's result event. When the instance is
     * complete, emits the combined prediction.
     *
     * @param contentEvent a {@link ResultContentEvent}; any other type causes a
     *     {@link ClassCastException}
     * @return {@code true} if a combined result was emitted for the event's instance,
     *     {@code false} if more member votes are still expected
     */
    @Override
    public boolean process(ContentEvent contentEvent) {
        ResultContentEvent inEvent = (ResultContentEvent) contentEvent;
        double[] memberVotes = inEvent.getClassVotes();
        // NOTE(review): the instance index is narrowed to int by this cast; indices outside the int
        // range would collide in the per-instance maps — confirm they stay within range.
        int instanceIndex = (int) inEvent.getInstanceIndex();
        addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), memberVotes, 1);

        // A last event flushes the instance even if some member votes have not arrived.
        if (!inEvent.isLastEvent() && !hasAllVotesArrivedInstance(instanceIndex)) {
            return false;
        }

        DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex);
        if (combinedVote == null) {
            // No member contributed votes with a positive sum: emit all-zero votes, one per class.
            combinedVote = new DoubleVector(new double[inEvent.getInstance().numClasses()]);
        }
        ResultContentEvent outEvent = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
                inEvent.getClassId(), combinedVote.getArrayCopy(), inEvent.isLastEvent());
        outEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
        this.outputStream.put(outEvent);
        clearStatisticsInstance(instanceIndex);
        return true;
    }

    /**
     * Resets the processor when it is created.
     *
     * @param id identifier assigned to this processor instance (not used here)
     */
    @Override
    public void onCreate(int id) {
        reset();
    }

    /**
     * Discards all partially accumulated per-instance statistics.
     * The maps are re-created lazily on the next received vote.
     */
    public void reset() {
        this.mapCountsforInstanceReceived = null;
        this.mapVotesforInstanceReceived = null;
    }

    /**
     * Creates a new combiner that copies the output stream and ensemble size of the given one.
     * Accumulated vote statistics are not copied.
     *
     * @param processor the processor to copy configuration from; must be a
     *     {@code PredictionCombinerProcessor}
     * @return a new, freshly configured combiner
     */
    @Override
    public Processor newProcessor(Processor processor) {
        PredictionCombinerProcessor source = (PredictionCombinerProcessor) processor;
        PredictionCombinerProcessor copy = new PredictionCombinerProcessor();
        if (source.getOutputStream() != null) {
            copy.setOutputStream(source.getOutputStream());
        }
        copy.setEnsembleSize(source.getEnsembleSize());
        return copy;
    }

    /**
     * Records one ensemble member's votes for an instance.
     *
     * <p>If the votes sum to a positive value, they are normalized, scaled by
     * {@link #getEnsembleMemberWeight(int)} and added to the instance's running vote total.
     * Otherwise only the event count is updated.
     *
     * @param instanceIndex index of the instance the votes refer to
     * @param classifierIndex index of the ensemble member that produced the votes
     * @param prediction the member's class votes
     * @param add amount added to the instance's received-event count
     */
    public void addStatisticsForInstanceReceived(int instanceIndex, int classifierIndex, double[] prediction,
            int add) {
        if (this.mapCountsforInstanceReceived == null) {
            this.mapCountsforInstanceReceived = new HashMap<Integer, Integer>();
        }
        if (this.mapVotesforInstanceReceived == null) {
            this.mapVotesforInstanceReceived = new HashMap<Integer, DoubleVector>();
        }
        DoubleVector vote = new DoubleVector(prediction);
        if (vote.sumOfValues() > 0.0d) {
            vote.normalize();
            DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex);
            if (combinedVote == null) {
                combinedVote = new DoubleVector();
            }
            vote.scaleValues(getEnsembleMemberWeight(classifierIndex));
            combinedVote.addValues(vote);
            this.mapVotesforInstanceReceived.put(instanceIndex, combinedVote);
        }
        Integer count = this.mapCountsforInstanceReceived.get(instanceIndex);
        int previous = (count == null) ? 0 : count.intValue();
        this.mapCountsforInstanceReceived.put(instanceIndex, previous + add);
    }

    /**
     * Tells whether exactly {@code ensembleSize} member events have been counted for an instance.
     *
     * @param instanceIndex index of the instance
     * @return {@code true} if the instance's count equals the ensemble size; {@code false} if it
     *     differs or nothing has been recorded for the instance
     */
    public boolean hasAllVotesArrivedInstance(int instanceIndex) {
        if (this.mapCountsforInstanceReceived == null) {
            return false;
        }
        Integer count = this.mapCountsforInstanceReceived.get(instanceIndex);
        return count != null && count.intValue() == this.ensembleSize;
    }

    /**
     * Removes the accumulated count and votes for an instance.
     * Does nothing if no statistics exist.
     *
     * @param instanceIndex index of the instance
     */
    public void clearStatisticsInstance(int instanceIndex) {
        if (this.mapCountsforInstanceReceived != null) {
            this.mapCountsforInstanceReceived.remove(instanceIndex);
        }
        if (this.mapVotesforInstanceReceived != null) {
            this.mapVotesforInstanceReceived.remove(instanceIndex);
        }
    }

    /**
     * Returns the weight applied to the normalized votes of an ensemble member.
     * This implementation weights all members equally; subclasses may override it.
     *
     * @param i index of the ensemble member
     * @return the member's weight (always {@code 1.0} here)
     */
    protected double getEnsembleMemberWeight(int i) {
        return 1.0d;
    }
}
