package org.apache.samoa.learners.classifiers.ensemble;

import java.util.Random;
import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.learners.InstanceContentEvent;
import org.apache.samoa.learners.ResultContentEvent;
import org.apache.samoa.learners.classifiers.ensemble.BoostMAProcessor;
import org.apache.samoa.learners.classifiers.trees.BoostVHTActiveLearningNode;
import org.apache.samoa.learners.classifiers.trees.LocalResultContentEvent;
import org.apache.samoa.moa.classifiers.core.splitcriteria.InfoGainSplitCriterion;
import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
import org.apache.samoa.moa.core.DoubleVector;
import org.apache.samoa.moa.core.MiscUtils;
import org.apache.samoa.topology.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Processor that runs an online boosting ensemble of {@link BoostMAProcessor} members.
 *
 * <p>Training (see {@link #train(InstanceContentEvent)}) passes each instance sequentially through
 * the members. Member {@code i} is trained on a copy of the instance whose weight is multiplied by
 * a Poisson({@code lambda}) draw. {@code lambda} is then rescaled depending on whether member
 * {@code i} classifies the instance correctly (OzaBoost-style online boosting).
 *
 * <p>The accumulated correct/wrong weights ({@link #scms} / {@link #swms}) give each member's
 * voting weight at prediction time (see {@link #computeBoosting(InstanceContentEvent)}).
 */
public class BoostVHTProcessor implements Processor {
    private static final long serialVersionUID = -1550901409625192730L;
    private static final Logger logger = LoggerFactory.getLogger(BoostVHTProcessor.class);

    // Tree-learning parameters forwarded to every ensemble member in onCreate().
    private SplitCriterion splitCriterion;
    private Double splitConfidence;
    private Double tieThreshold;
    private int gracePeriod;
    private int parallelismHint;
    private int timeOut;
    private BoostVHTActiveLearningNode.SplittingOption splittingOption;
    private int maxBufferSize;

    private Instances dataset;
    private int ensembleSize;
    private int numberOfClasses;
    private int seed;

    // Output stream for predictions, and streams handed to each ensemble member.
    private Stream resultStream;
    private Stream controlStream;
    private Stream attributeStream;

    /** Ensemble members; allocated in {@link #onCreate(int)}. */
    protected BoostMAProcessor[] mAPEnsemble;
    /** Source of the Poisson draws used during training; seeded from {@code seed}. */
    protected Random random;
    /** Per-member sum of {@code lambda} over instances the member classified correctly. */
    protected double[] scms;
    /** Per-member sum of {@code lambda} over instances the member misclassified. */
    protected double[] swms;
    /** Total instance weight passed to {@link #train(InstanceContentEvent)} so far. */
    private double trainingWeightSeenByModel;

    /** Fluent builder for {@link BoostVHTProcessor}. */
    public static class Builder {
        private final Instances dataset;
        private int ensembleSize;
        private int numberOfClasses;
        private SplitCriterion splitCriterion = new InfoGainSplitCriterion();
        private double splitConfidence;
        private double tieThreshold;
        private int gracePeriod;
        private int parallelismHint;
        private int timeOut = Integer.MAX_VALUE;
        private BoostVHTActiveLearningNode.SplittingOption splittingOption;
        private int maxBufferSize;
        private int seed;

        /**
         * Starts a builder for the given dataset header.
         *
         * <p>Defaults: info-gain split criterion, {@code timeOut = Integer.MAX_VALUE}.
         *
         * @param instances dataset header shared with the ensemble members
         */
        public Builder(Instances instances) {
            this.dataset = instances;
        }

        /**
         * Starts a builder that copies every configuration value of an existing processor.
         *
         * @param boostVHTProcessor processor whose configuration is copied (streams are not copied)
         */
        public Builder(BoostVHTProcessor boostVHTProcessor) {
            this.dataset = boostVHTProcessor.getDataset();
            this.ensembleSize = boostVHTProcessor.getEnsembleSize();
            this.numberOfClasses = boostVHTProcessor.getNumberOfClasses();
            this.splitCriterion = boostVHTProcessor.getSplitCriterion();
            this.splitConfidence = boostVHTProcessor.getSplitConfidence().doubleValue();
            this.tieThreshold = boostVHTProcessor.getTieThreshold().doubleValue();
            this.gracePeriod = boostVHTProcessor.getGracePeriod();
            this.parallelismHint = boostVHTProcessor.getParallelismHint();
            this.timeOut = boostVHTProcessor.getTimeOut();
            this.splittingOption = boostVHTProcessor.splittingOption;
            // Previously omitted, so copies built their members with a buffer size of 0.
            this.maxBufferSize = boostVHTProcessor.getMaxBufferSize();
            this.seed = boostVHTProcessor.getSeed();
        }

        public Builder ensembleSize(int ensembleSize) {
            this.ensembleSize = ensembleSize;
            return this;
        }

        public Builder numberOfClasses(int numberOfClasses) {
            this.numberOfClasses = numberOfClasses;
            return this;
        }

        public Builder splitCriterion(SplitCriterion splitCriterion) {
            this.splitCriterion = splitCriterion;
            return this;
        }

        public Builder splitConfidence(double splitConfidence) {
            this.splitConfidence = splitConfidence;
            return this;
        }

        public Builder tieThreshold(double tieThreshold) {
            this.tieThreshold = tieThreshold;
            return this;
        }

        public Builder gracePeriod(int gracePeriod) {
            this.gracePeriod = gracePeriod;
            return this;
        }

        public Builder parallelismHint(int parallelismHint) {
            this.parallelismHint = parallelismHint;
            return this;
        }

        public Builder timeOut(int timeOut) {
            this.timeOut = timeOut;
            return this;
        }

        public Builder splittingOption(BoostVHTActiveLearningNode.SplittingOption splittingOption) {
            this.splittingOption = splittingOption;
            return this;
        }

        public Builder maxBufferSize(int maxBufferSize) {
            this.maxBufferSize = maxBufferSize;
            return this;
        }

        public Builder seed(int seed) {
            this.seed = seed;
            return this;
        }

        /** Creates a new processor from the current builder settings. */
        public BoostVHTProcessor build() {
            return new BoostVHTProcessor(this);
        }
    }

    private BoostVHTProcessor(Builder builder) {
        this.dataset = builder.dataset;
        this.ensembleSize = builder.ensembleSize;
        this.seed = builder.seed;
        this.numberOfClasses = builder.numberOfClasses;
        this.splitCriterion = builder.splitCriterion;
        this.splitConfidence = Double.valueOf(builder.splitConfidence);
        this.tieThreshold = Double.valueOf(builder.tieThreshold);
        this.gracePeriod = builder.gracePeriod;
        this.parallelismHint = builder.parallelismHint;
        this.timeOut = builder.timeOut;
        this.splittingOption = builder.splittingOption;
        this.maxBufferSize = builder.maxBufferSize;
    }

    /**
     * Handles an incoming event.
     *
     * <p>An {@link InstanceContentEvent} is first predicted (if flagged for testing), then trained on
     * (if flagged for training). A {@link LocalResultContentEvent} is routed to the ensemble member
     * named by its ensemble id. Other events are ignored.
     *
     * @return always {@code true}
     */
    @Override // org.apache.samoa.core.Processor
    public boolean process(ContentEvent contentEvent) {
        if (contentEvent instanceof InstanceContentEvent) {
            InstanceContentEvent inEvent = (InstanceContentEvent) contentEvent;
            // Test-then-train: the prediction is made before this instance updates the model.
            if (inEvent.isTesting()) {
                this.resultStream.put(newResultContentEvent(computeBoosting(inEvent), inEvent));
            }
            if (inEvent.isTraining()) {
                train(inEvent);
            }
        } else if (contentEvent instanceof LocalResultContentEvent) {
            LocalResultContentEvent localResult = (LocalResultContentEvent) contentEvent;
            this.mAPEnsemble[localResult.getEnsembleId()].updateModel(localResult);
        }
        return true;
    }

    /**
     * Allocates per-member state and builds the ensemble members from this processor's settings.
     *
     * @param id processor id supplied by the framework (unused here)
     */
    @Override // org.apache.samoa.core.Processor
    public void onCreate(int id) {
        this.mAPEnsemble = new BoostMAProcessor[this.ensembleSize];
        this.random = new Random(this.seed);
        this.scms = new double[this.ensembleSize];
        this.swms = new double[this.ensembleSize];
        for (int m = 0; m < this.ensembleSize; m++) {
            BoostMAProcessor member = new BoostMAProcessor.BoostMABuilder(this.dataset)
                    .splitCriterion(this.splitCriterion)
                    .splitConfidence(this.splitConfidence.doubleValue())
                    .tieThreshold(this.tieThreshold.doubleValue())
                    .gracePeriod(this.gracePeriod)
                    .parallelismHint(this.parallelismHint)
                    .timeOut(this.timeOut)
                    .processorID(m)
                    .maxBufferSize(this.maxBufferSize)
                    .splittingOption(this.splittingOption)
                    .build();
            member.setAttributeStream(this.attributeStream);
            member.setControlStream(this.controlStream);
            this.mAPEnsemble[m] = member;
        }
    }

    /**
     * Combines the members' votes for the event's instance.
     *
     * <p>Members are visited in order. Each member with non-empty votes contributes its normalized
     * votes scaled by its weight. Iteration stops at the first member whose weight is not strictly
     * positive.
     *
     * @return the combined (unnormalized) class votes; may be empty
     */
    private double[] computeBoosting(InstanceContentEvent instanceContentEvent) {
        Instance testInstance = instanceContentEvent.getInstance();
        DoubleVector combined = new DoubleVector();
        for (int m = 0; m < this.ensembleSize; m++) {
            double memberWeight = getEnsembleMemberWeight(m);
            // Written as !(w > 0) rather than w <= 0 so that a NaN weight also stops the loop.
            if (!(memberWeight > 0.0d)) {
                break;
            }
            DoubleVector vote = new DoubleVector(this.mAPEnsemble[m].getVotesForInstance(testInstance));
            if (vote.sumOfValues() > 0.0d) {
                vote.normalize();
                vote.scaleValues(memberWeight);
                combined.addValues(vote);
            }
        }
        return combined.getArrayRef();
    }

    /**
     * Trains the ensemble on one instance, updating each member's correct/wrong weight sums.
     *
     * <p>{@code lambda} starts at 1. For each member in turn:
     * <ul>
     *   <li>the member is trained on a copy weighted by {@code weight * Poisson(lambda)};</li>
     *   <li>{@code lambda} is added to {@code scms[m]} or {@code swms[m]}, depending on whether the
     *       member now classifies the instance correctly;</li>
     *   <li>{@code lambda} is then rescaled by {@code totalWeight / (2 * thatSum)}.</li>
     * </ul>
     */
    protected void train(InstanceContentEvent instanceContentEvent) {
        Instance trainInstance = instanceContentEvent.getInstance();
        this.trainingWeightSeenByModel += trainInstance.weight();
        double lambda = 1.0d;
        for (int m = 0; m < this.ensembleSize; m++) {
            BoostMAProcessor member = this.mAPEnsemble[m];
            int k = MiscUtils.poisson(lambda, this.random);
            if (k > 0) {
                Instance weighted = trainInstance.copy();
                weighted.setWeight(trainInstance.weight() * k);
                member.trainOnInstanceImpl(weighted);
            }
            if (member.correctlyClassifies(trainInstance, member.getVotesForInstance(trainInstance))) {
                this.scms[m] += lambda;
                lambda *= this.trainingWeightSeenByModel / (2.0d * this.scms[m]);
            } else {
                this.swms[m] += lambda;
                lambda *= this.trainingWeightSeenByModel / (2.0d * this.swms[m]);
            }
        }
    }

    /**
     * Computes member {@code i}'s voting weight.
     *
     * <p>The error rate is {@code e = swms[i] / (scms[i] + swms[i])}, and the weight is
     * {@code log((1 - e) / e) + log(K - 1)}, where {@code K = numberOfClasses}.
     *
     * <p>Returns 0 in three cases: the member has accumulated no weight yet, {@code e == 0}, or
     * {@code e > 1 - 1/K}.
     */
    private double getEnsembleMemberWeight(int i) {
        double observed = this.scms[i] + this.swms[i];
        if (observed <= 0.0d) {
            // Nothing recorded yet: the error rate would be 0/0 = NaN.
            return 0.0d;
        }
        double errorRate = this.swms[i] / observed;
        if (errorRate == 0.0d || errorRate > 1.0d - (1.0d / this.numberOfClasses)) {
            return 0.0d;
        }
        double beta = errorRate / (1.0d - errorRate);
        return Math.log(1.0d / beta) + Math.log(this.numberOfClasses - 1);
    }

    /** Wraps the combined votes in a {@link ResultContentEvent} mirroring the input event's metadata. */
    private ResultContentEvent newResultContentEvent(double[] votes, InstanceContentEvent instanceContentEvent) {
        ResultContentEvent result = new ResultContentEvent(instanceContentEvent.getInstanceIndex(),
                instanceContentEvent.getInstance(), instanceContentEvent.getClassId(), votes,
                instanceContentEvent.isLastEvent(), instanceContentEvent.getArrivalTimestamp());
        result.setEvaluationIndex(instanceContentEvent.getEvaluationIndex());
        return result;
    }

    public Instances getInputInstances() {
        return this.dataset;
    }

    public void setInputInstances(Instances instances) {
        this.dataset = instances;
    }

    public Stream getResultStream() {
        return this.resultStream;
    }

    public void setResultStream(Stream stream) {
        this.resultStream = stream;
    }

    public int getEnsembleSize() {
        return this.ensembleSize;
    }

    public Stream getControlStream() {
        return this.controlStream;
    }

    public void setControlStream(Stream stream) {
        this.controlStream = stream;
    }

    public Stream getAttributeStream() {
        return this.attributeStream;
    }

    public void setAttributeStream(Stream stream) {
        this.attributeStream = stream;
    }

    public SplitCriterion getSplitCriterion() {
        return this.splitCriterion;
    }

    public Double getSplitConfidence() {
        return this.splitConfidence;
    }

    public Double getTieThreshold() {
        return this.tieThreshold;
    }

    public int getSeed() {
        return this.seed;
    }

    public int getGracePeriod() {
        return this.gracePeriod;
    }

    public int getParallelismHint() {
        return this.parallelismHint;
    }

    public int getTimeOut() {
        return this.timeOut;
    }

    public void setTimeOut(int timeOut) {
        this.timeOut = timeOut;
    }

    /** Returns the buffer size forwarded to each ensemble member. */
    public int getMaxBufferSize() {
        return this.maxBufferSize;
    }

    public int getNumberOfClasses() {
        return this.numberOfClasses;
    }

    public void setNumberOfClasses(int numberOfClasses) {
        this.numberOfClasses = numberOfClasses;
    }

    public Instances getDataset() {
        return this.dataset;
    }

    /**
     * Creates a copy of {@code processor} with the same configuration.
     *
     * <p>The result, control and attribute streams are copied only when the source has a result
     * stream.
     */
    @Override // org.apache.samoa.core.Processor
    public Processor newProcessor(Processor processor) {
        BoostVHTProcessor source = (BoostVHTProcessor) processor;
        BoostVHTProcessor copy = new Builder(source).build();
        if (source.getResultStream() != null) {
            copy.setResultStream(source.getResultStream());
            copy.setControlStream(source.getControlStream());
            copy.setAttributeStream(source.getAttributeStream());
        }
        return copy;
    }
}
