package org.apache.samoa.learners.classifiers.ensemble;

import com.github.javacliparser.ClassOption;
import com.github.javacliparser.Configurable;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.google.common.collect.ImmutableSet;
import java.util.Set;
import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.learners.ClassificationLearner;
import org.apache.samoa.learners.classifiers.ensemble.BoostVHTProcessor;
import org.apache.samoa.learners.classifiers.trees.BoostVHTActiveLearningNode;
import org.apache.samoa.learners.classifiers.trees.LocalStatisticsProcessor;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.DiscreteAttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.attributeclassobservers.NumericAttributeClassObserver;
import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
import org.apache.samoa.topology.Stream;
import org.apache.samoa.topology.TopologyBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/samoa/learners/classifiers/ensemble/BoostVHT.class */
/**
 * Classification learner that wires a {@link BoostVHTProcessor} and a group of
 * {@link LocalStatisticsProcessor} instances into a topology.
 *
 * <p>The layout built by {@link #setLayout()} is:
 * <ul>
 *   <li>one {@link BoostVHTProcessor} (added with parallelism 1), which is also the input
 *       processor of this learner;</li>
 *   <li>a {@link LocalStatisticsProcessor} added with a parallelism equal to the configured
 *       ensemble size;</li>
 *   <li>an attribute stream (key-grouped) and a control stream (sent to all) going from the
 *       {@link BoostVHTProcessor} to the {@link LocalStatisticsProcessor};</li>
 *   <li>a compute stream (sent to all) going back from the {@link LocalStatisticsProcessor}
 *       to the {@link BoostVHTProcessor};</li>
 *   <li>a result stream created from the {@link BoostVHTProcessor}.</li>
 * </ul>
 *
 * <p>{@link #init(TopologyBuilder, Instances, int)} must be called before
 * {@link #getInputProcessor()} or {@link #getResultStreams()} return anything useful.
 */
public class BoostVHT implements ClassificationLearner, Configurable {
    private static final long serialVersionUID = -7523211543185584536L;
    private static final Logger logger = LoggerFactory.getLogger(BoostVHT.class);

    /** Model processor; created in {@link #setLayout()}. */
    private BoostVHTProcessor boostVHTProcessor;

    /** Stream created from the model processor and registered as its result stream. */
    protected Stream resultStream;

    /** Key-grouped stream from the model processor to the local statistics processor. */
    protected Stream attributeStream;

    /** Stream from the model processor delivered to all local statistics processor instances. */
    protected Stream controlStream;

    /** Stream from the local statistics processor back to the model processor. */
    protected Stream computeStream;

    /** Dataset header handed to {@link #init(TopologyBuilder, Instances, int)}. */
    private Instances dataset;

    /** Parallelism passed to {@code init}; stored here but not read anywhere in this class. */
    protected int parallelism;

    /** Builder used to register processors and streams. */
    private TopologyBuilder topologyBuilder;

    public ClassOption numericEstimatorOption = new ClassOption("numericEstimator", 'n', "Numeric estimator to use.", NumericAttributeClassObserver.class, "GaussianNumericAttributeClassObserver");
    public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator", 'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class, "NominalAttributeClassObserver");
    public ClassOption splitCriterionOption = new ClassOption("splitCriterion", 'r', "Split criterion to use.", SplitCriterion.class, "InfoGainSplitCriterion");
    public FloatOption splitConfidenceOption = new FloatOption("splitConfidence", 'c', "The allowable error in split decision, values closer to 0 will take longer to decide.", 1.0E-7d, 0.0d, 1.0d);
    public FloatOption tieThresholdOption = new FloatOption("tieThreshold", 't', "Threshold below which a split will be forced to break ties.", 0.05d, 0.0d, 1.0d);
    public IntOption gracePeriodOption = new IntOption("gracePeriod", 'g', "The number of instances a leaf should observe between split attempts.", 200, 0, Integer.MAX_VALUE);
    public IntOption timeOutOption = new IntOption("timeOut", 'o', "The duration to wait all distributed computation results from local statistics PI, in miliseconds", Integer.MAX_VALUE, 1, Integer.MAX_VALUE);
    public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b', "Only allow binary splits.");
    public FlagOption splittingOption = new FlagOption("keepInstanceWhileSplitting", 'q', "Keep instances in a buffer while splitting");
    public IntOption maxBufferSizeOption = new IntOption("maxBufferSizeWhileSplitting", 'z', "Maximum buffer size while splitting, use in conjunction with 'q' option. Size 0 means we don't use buffer while splitting", 0, 0, Integer.MAX_VALUE);
    public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);
    public IntOption seedOption = new IntOption("seed", 'u', "the seed for the rng.", (int) System.currentTimeMillis());
    public IntOption numberOfClassesOption = new IntOption("numberOfClasses", 'k', "The number of classes.", 2, 2, Integer.MAX_VALUE);

    /**
     * Creates the processors and streams of this learner and registers them with the
     * topology builder supplied to {@link #init(TopologyBuilder, Instances, int)}.
     *
     * @throws IllegalStateException if the {@link BoostVHTProcessor} cannot be built from the
     *         current option values (the original failure is kept as the cause)
     */
    protected void setLayout() {
        int ensembleSize = this.ensembleSizeOption.getValue();

        // Model processor and the two streams it emits towards the statistics processors.
        this.boostVHTProcessor = createBoostVHTProcessor(ensembleSize);
        this.topologyBuilder.addProcessor(this.boostVHTProcessor, 1);
        this.attributeStream = this.topologyBuilder.createStream(this.boostVHTProcessor);
        this.controlStream = this.topologyBuilder.createStream(this.boostVHTProcessor);

        // Local statistics processor, registered with one instance per ensemble member.
        LocalStatisticsProcessor localStatistics = createLocalStatisticsProcessor();
        this.topologyBuilder.addProcessor(localStatistics, ensembleSize);
        this.topologyBuilder.connectInputKeyStream(this.attributeStream, localStatistics);
        this.topologyBuilder.connectInputAllStream(this.controlStream, localStatistics);

        // Feedback path: computation results flow back to the model processor.
        this.computeStream = this.topologyBuilder.createStream(localStatistics);
        localStatistics.setComputationResultStream(this.computeStream);
        this.topologyBuilder.connectInputAllStream(this.computeStream, this.boostVHTProcessor);

        // Output of the learner.
        this.resultStream = this.topologyBuilder.createStream(this.boostVHTProcessor);
        this.boostVHTProcessor.setResultStream(this.resultStream);
        this.boostVHTProcessor.setAttributeStream(this.attributeStream);
        this.boostVHTProcessor.setControlStream(this.controlStream);
    }

    /**
     * Builds the model processor from the current option values.
     *
     * <p>A failure here is not recoverable: every later wiring step dereferences the
     * processor, so the error is rethrown immediately instead of being printed and ignored.
     *
     * @param ensembleSize value of {@link #ensembleSizeOption}; used both as the ensemble size
     *        and as the parallelism hint
     * @return the configured processor, never {@code null}
     * @throws IllegalStateException if instantiating the split criterion or building the
     *         processor fails
     */
    private BoostVHTProcessor createBoostVHTProcessor(int ensembleSize) {
        try {
            return new BoostVHTProcessor.Builder(this.dataset)
                    .ensembleSize(ensembleSize)
                    .numberOfClasses(this.numberOfClassesOption.getValue())
                    .splitCriterion((SplitCriterion) ClassOption.createObject(
                            this.splitCriterionOption.getValueAsCLIString(),
                            this.splitCriterionOption.getRequiredType()))
                    .splitConfidence(this.splitConfidenceOption.getValue())
                    .tieThreshold(this.tieThresholdOption.getValue())
                    .gracePeriod(this.gracePeriodOption.getValue())
                    .parallelismHint(ensembleSize)
                    .timeOut(this.timeOutOption.getValue())
                    .splittingOption(this.splittingOption.isSet()
                            ? BoostVHTActiveLearningNode.SplittingOption.KEEP
                            : BoostVHTActiveLearningNode.SplittingOption.THROW_AWAY)
                    .maxBufferSize(this.maxBufferSizeOption.getValue())
                    .seed(this.seedOption.getValue())
                    .build();
        } catch (Exception e) {
            throw new IllegalStateException(
                    "Unable to build BoostVHTProcessor (splitCriterion="
                            + this.splitCriterionOption.getValueAsCLIString() + ")", e);
        }
    }

    /**
     * Builds the local statistics processor from the split criterion, binary-split flag and
     * the nominal/numeric estimator options.
     *
     * @return the configured processor
     */
    private LocalStatisticsProcessor createLocalStatisticsProcessor() {
        return new LocalStatisticsProcessor.Builder()
                .splitCriterion((SplitCriterion) this.splitCriterionOption.getValue())
                .binarySplit(this.binarySplitsOption.isSet())
                .nominalClassObserver((AttributeClassObserver) this.nominalEstimatorOption.getValue())
                .numericClassObserver((AttributeClassObserver) this.numericEstimatorOption.getValue())
                .build();
    }

    /**
     * Stores the given arguments and builds the topology layout.
     *
     * @param topologyBuilder builder that receives the processors and streams
     * @param instances dataset header passed to the {@link BoostVHTProcessor} builder
     * @param i requested parallelism; stored in {@link #parallelism}
     * @throws IllegalStateException if the layout cannot be built, see {@link #setLayout()}
     */
    @Override
    public void init(TopologyBuilder topologyBuilder, Instances instances, int i) {
        this.topologyBuilder = topologyBuilder;
        this.dataset = instances;
        this.parallelism = i;
        setLayout();
    }

    /**
     * Returns the processor that receives the learner's input.
     *
     * @return the {@link BoostVHTProcessor}, or {@code null} if {@code init} has not been called
     */
    @Override
    public Processor getInputProcessor() {
        return this.boostVHTProcessor;
    }

    /**
     * Returns the streams on which this learner publishes its results.
     *
     * @return an immutable set containing only {@link #resultStream}
     */
    @Override
    public Set<Stream> getResultStreams() {
        return ImmutableSet.of(this.resultStream);
    }
}
