package org.apache.samoa.learners.classifiers;

import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.learners.InstanceContentEvent;
import org.apache.samoa.learners.ResultContentEvent;
import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector;
import org.apache.samoa.moa.core.Utils;
import org.apache.samoa.topology.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/samoa/learners/classifiers/LocalLearnerProcessor.class */
/**
 * Processor that wraps a single {@link LocalLearner}.
 *
 * <p>Testing events are answered with the learner's votes on the output stream; training events
 * train the learner. An optional {@link ChangeDetector} is fed the learner's 0/1 error on each
 * training instance, and when it signals a change with a rising error estimate the learner and
 * the detector are both reset.
 */
public final class LocalLearnerProcessor implements Processor {
    private static final long serialVersionUID = -1577910988699148691L;
    private static final Logger logger = LoggerFactory.getLogger(LocalLearnerProcessor.class);

    /** The wrapped learner; replaced by a fresh instance from {@code create()} in {@link #onCreate(int)}. */
    private LocalLearner model;
    /** Stream on which {@link ResultContentEvent}s are emitted. */
    private Stream outputStream;
    /** Id assigned in {@link #onCreate(int)} and stamped on every emitted result as the classifier index. */
    private int modelId;
    /** Number of instances this processor has trained on. */
    private long instancesCount = 0;
    // NOTE(review): never read or written in this class; kept because it is non-private and may be
    // referenced from elsewhere in the package.
    protected int test;
    /** Optional drift detector fed with this learner's error; {@code null} disables drift handling. */
    protected ChangeDetector changeDetector;

    /** Sets the learner to wrap. */
    public void setLearner(LocalLearner localLearner) {
        this.model = localLearner;
    }

    /** Returns the wrapped learner, possibly {@code null} if none was set. */
    public LocalLearner getLearner() {
        return this.model;
    }

    /** Sets the stream on which prediction results are emitted. */
    public void setOutputStream(Stream stream) {
        this.outputStream = stream;
    }

    /** Returns the stream on which prediction results are emitted. */
    public Stream getOutputStream() {
        return this.outputStream;
    }

    /** Returns how many instances this processor has trained on. */
    public long getInstancesCount() {
        return this.instancesCount;
    }

    /**
     * Trains the model on the event's instance and, when a change detector is configured, feeds it
     * the model's 0/1 loss on that instance (evaluated after training). If the detector reports a
     * change and its error estimate did not decrease-or-stay, the model and detector are reset.
     */
    private void updateStats(InstanceContentEvent instanceContentEvent) {
        Instance instance = instanceContentEvent.getInstance();
        this.model.trainOnInstance(instance);
        this.instancesCount++;
        if (this.changeDetector == null) {
            return;
        }
        boolean correct = correctlyClassifies(instance);
        double previousEstimation = this.changeDetector.getEstimation();
        this.changeDetector.input(correct ? 0.0d : 1.0d);
        // Only a change that raises the error estimate triggers a reset.
        if (!this.changeDetector.getChange() || this.changeDetector.getEstimation() <= previousEstimation) {
            return;
        }
        this.model.resetLearning();
        this.changeDetector.resetLearning();
    }

    /** Returns whether the index of the model's highest vote equals the instance's class value. */
    private boolean correctlyClassifies(Instance instance) {
        return Utils.maxIndex(this.model.getVotesForInstance(instance)) == ((int) instance.classValue());
    }

    /**
     * Handles one instance event.
     *
     * <p>A negative instance index emits an empty result (index -1, class 0, no votes) and stops.
     * Otherwise a testing event emits the model's votes, and a training event then trains the model.
     *
     * @return always {@code false}
     */
    @Override // org.apache.samoa.core.Processor
    public boolean process(ContentEvent contentEvent) {
        InstanceContentEvent event = (InstanceContentEvent) contentEvent;
        Instance instance = event.getInstance();
        if (event.getInstanceIndex() < 0) {
            ResultContentEvent endEvent = new ResultContentEvent(-1L, instance, 0, new double[0], event.isLastEvent(), event.getArrivalTimestamp());
            endEvent.setClassifierIndex(this.modelId);
            endEvent.setEvaluationIndex(event.getEvaluationIndex());
            this.outputStream.put(endEvent);
            return false;
        }
        if (event.isTesting()) {
            double[] votes = this.model.getVotesForInstance(instance);
            ResultContentEvent result = new ResultContentEvent(event.getInstanceIndex(), instance, event.getClassId(), votes, event.isLastEvent(), event.getArrivalTimestamp());
            result.setClassifierIndex(this.modelId);
            result.setEvaluationIndex(event.getEvaluationIndex());
            // Fully parameterized so no string is built when TRACE is disabled.
            logger.trace("{} {} {}", event.getInstanceIndex(), this.modelId, votes);
            this.outputStream.put(result);
        }
        if (event.isTraining()) {
            updateStats(event);
        }
        return false;
    }

    /**
     * Records the processor id and replaces the configured learner with a fresh instance.
     *
     * @throws IllegalStateException if no learner has been set
     */
    @Override // org.apache.samoa.core.Processor
    public void onCreate(int i) {
        this.modelId = i;
        if (this.model == null) {
            throw new IllegalStateException("LocalLearnerProcessor has no LocalLearner; call setLearner before onCreate");
        }
        this.model = this.model.create();
    }

    /**
     * Creates a replica of {@code processor} with a fresh learner, its own change detector and the
     * same output stream.
     */
    @Override // org.apache.samoa.core.Processor
    public Processor newProcessor(Processor processor) {
        LocalLearnerProcessor source = (LocalLearnerProcessor) processor;
        LocalLearnerProcessor replica = new LocalLearnerProcessor();
        if (source.getLearner() != null) {
            replica.setLearner(source.getLearner().create());
        }
        if (source.getChangeDetector() != null) {
            // Each replica needs its own detector: sharing one stateful instance would mix the error
            // streams of independent learners and let one learner's reset clear another's drift state.
            // NOTE(review): relies on copy() returning an independent detector instance.
            replica.setChangeDetector((ChangeDetector) source.getChangeDetector().copy());
        }
        replica.setOutputStream(source.getOutputStream());
        return replica;
    }

    /** Returns the configured change detector, or {@code null} if drift handling is disabled. */
    public ChangeDetector getChangeDetector() {
        return this.changeDetector;
    }

    /** Sets the change detector; {@code null} disables drift handling. */
    public void setChangeDetector(ChangeDetector changeDetector) {
        this.changeDetector = changeDetector;
    }
}
