package org.apache.samoa.learners.classifiers.trees;

import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.instances.InstancesHeader;
import org.apache.samoa.learners.InstanceContentEvent;
import org.apache.samoa.learners.InstancesContentEvent;
import org.apache.samoa.learners.ResultContentEvent;
import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion;
import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector;
import org.apache.samoa.moa.classifiers.core.splitcriteria.InfoGainSplitCriterion;
import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
import org.apache.samoa.moa.core.Utils;
import org.apache.samoa.topology.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Processor that owns the decision-tree model, predicts and trains on incoming instance batches,
 * and coordinates split decisions whose per-attribute evaluation happens elsewhere.
 *
 * <p>For every {@link InstancesContentEvent} the processor optionally emits a prediction on the
 * result stream (testing) and optionally trains the tree (training). Attribute statistics gathered
 * by the touched leaves are forwarded on the attribute stream. Once a leaf has seen at least
 * {@code gracePeriod} weight since its last split evaluation, split suggestions are requested via
 * the leaf (which sends on the control stream). The split decision is a Hoeffding-bound test
 * against the "do not split" option. It is taken when all suggestions for the request have
 * arrived as {@link LocalResultContentEvent}s, or after {@code timeOut} seconds, whichever comes
 * first.
 *
 * <p>Threading: tree state is initialised in {@link #onCreate(int)} and mutated from
 * {@link #process(ContentEvent)}. The time-out tasks run on a private scheduled executor and only
 * enqueue split ids into {@code timedOutSplittingNodes}, which {@code process} drains.
 */
public final class ModelAggregatorProcessor implements Processor {
    private static final long serialVersionUID = -1685875718300564886L;
    private static final Logger logger = LoggerFactory.getLogger(ModelAggregatorProcessor.class);

    /**
     * Batch-buffering threshold used by {@link #processInstanceContentEvent}. The first batch is
     * processed immediately and batches 2..WARMUP_BATCHES are held back. From then on, every new
     * batch releases the oldest buffered one.
     */
    private static final int WARMUP_BATCHES = 4;

    /** Number of threads of the executor that fires split time-outs. */
    private static final int TIMEOUT_THREAD_POOL_SIZE = 8;

    private int processorId;
    /** Root of the tree; {@code null} before the first training instance and after a reset. */
    private Node treeRoot;
    private int activeLeafNodeCount;
    private int inactiveLeafNodeCount;
    private int decisionNodeCount;
    private boolean growthAllowed;
    private final Instances dataset;
    /** Id of the most recently issued split request; not reset by {@link #resetLearning()}. */
    private long splitId;
    /** Split requests still waiting for suggestions, keyed by split id. */
    private ConcurrentMap<Long, SplittingNodeInfo> splittingNodes;
    /** Split ids whose time-out fired; filled by executor threads, drained by process(). */
    private BlockingQueue<Long> timedOutSplittingNodes;
    private Stream resultStream;
    private Stream attributeStream;
    private Stream controlStream;
    private transient ScheduledExecutorService executor;
    private final SplitCriterion splitCriterion;
    private final double splitConfidence;
    private final double tieThreshold;
    private final int gracePeriod;
    private final int parallelismHint;
    /** Split time-out, in seconds (see {@link #attemptToSplit}). */
    private final long timeOut;
    /** Active leaves touched by the batch currently being processed; cleared after each batch. */
    protected Set<FoundNode> foundNodeSet;
    /** Instance batches held back by the warm-up buffer. */
    private List<InstancesContentEvent> contentEventList;
    /** Batch counter; saturates at WARMUP_BATCHES + 1 so it can never overflow. */
    private int numBatches;
    /** Optional drift detector fed with the 0/1 prediction error; {@code null} disables it. */
    protected ChangeDetector changeDetector;

    /**
     * Time-out task scheduled for every split request. When it fires it hands the split id back to
     * the processor thread through a queue; it never touches the tree itself.
     */
    public static class AggregationTimeOutHandler implements Runnable {
        private static final Logger logger = LoggerFactory.getLogger(AggregationTimeOutHandler.class);
        private final Long splitId;
        private final BlockingQueue<Long> toBeSplittedNodes;

        AggregationTimeOutHandler(Long splitId, BlockingQueue<Long> toBeSplittedNodes) {
            this.splitId = splitId;
            this.toBeSplittedNodes = toBeSplittedNodes;
        }

        @Override
        public void run() {
            logger.debug("Time out is reached. AggregationTimeOutHandler is started.");
            try {
                this.toBeSplittedNodes.put(this.splitId);
            } catch (InterruptedException e) {
                logger.warn("Interrupted while trying to put the ID into the queue");
                // Restore the interrupt status so the owning executor can observe it.
                Thread.currentThread().interrupt();
            }
            logger.debug("AggregationTimeOutHandler is finished.");
        }
    }

    /** Fluent builder for {@link ModelAggregatorProcessor}; unset options keep the defaults below. */
    public static class Builder {
        private final Instances dataset;
        private SplitCriterion splitCriterion = new InfoGainSplitCriterion();
        private double splitConfidence = 1.0E-7d;
        private double tieThreshold = 0.05d;
        private int gracePeriod = 200;
        private int parallelismHint = 1;
        /** Split time-out in seconds. */
        private long timeOut = 30L;
        private ChangeDetector changeDetector = null;

        /** Starts a builder for the given dataset header with default options. */
        public Builder(Instances instances) {
            this.dataset = instances;
        }

        /**
         * Starts a builder that copies every configuration option of an existing processor,
         * including its change detector. NOTE(review): the detector instance is shared, not
         * copied — confirm this is acceptable if several replicas run concurrently.
         */
        Builder(ModelAggregatorProcessor modelAggregatorProcessor) {
            this.dataset = modelAggregatorProcessor.dataset;
            this.splitCriterion = modelAggregatorProcessor.splitCriterion;
            this.splitConfidence = modelAggregatorProcessor.splitConfidence;
            this.tieThreshold = modelAggregatorProcessor.tieThreshold;
            this.gracePeriod = modelAggregatorProcessor.gracePeriod;
            this.parallelismHint = modelAggregatorProcessor.parallelismHint;
            this.timeOut = modelAggregatorProcessor.timeOut;
            this.changeDetector = modelAggregatorProcessor.changeDetector;
        }

        public Builder splitCriterion(SplitCriterion splitCriterion) {
            this.splitCriterion = splitCriterion;
            return this;
        }

        public Builder splitConfidence(double d) {
            this.splitConfidence = d;
            return this;
        }

        public Builder tieThreshold(double d) {
            this.tieThreshold = d;
            return this;
        }

        public Builder gracePeriod(int i) {
            this.gracePeriod = i;
            return this;
        }

        public Builder parallelismHint(int i) {
            this.parallelismHint = i;
            return this;
        }

        /** Sets the split time-out, in seconds. */
        public Builder timeOut(long j) {
            this.timeOut = j;
            return this;
        }

        public Builder changeDetector(ChangeDetector changeDetector) {
            this.changeDetector = changeDetector;
            return this;
        }

        public ModelAggregatorProcessor build() {
            return new ModelAggregatorProcessor(this);
        }
    }

    /** Bookkeeping for one outstanding split request. */
    public static class SplittingNodeInfo {
        private final ActiveLearningNode activeLearningNode;
        private final FoundNode foundNode;
        /** Handle of the scheduled time-out, cancelled when all suggestions arrive in time. */
        private final ScheduledFuture<?> scheduledFuture;

        SplittingNodeInfo(ActiveLearningNode activeLearningNode, FoundNode foundNode, ScheduledFuture<?> scheduledFuture) {
            this.activeLearningNode = activeLearningNode;
            this.foundNode = foundNode;
            this.scheduledFuture = scheduledFuture;
        }
    }

    /** Creates daemon threads so pending split time-outs never keep the JVM alive. */
    private static final class TimeoutThreadFactory implements ThreadFactory {
        @Override
        public Thread newThread(Runnable task) {
            Thread thread = new Thread(task, "ModelAggregatorProcessor-timeout");
            thread.setDaemon(true);
            return thread;
        }
    }

    private ModelAggregatorProcessor(Builder builder) {
        this.contentEventList = new LinkedList<>();
        this.numBatches = 0;
        this.dataset = builder.dataset;
        this.splitCriterion = builder.splitCriterion;
        this.splitConfidence = builder.splitConfidence;
        this.tieThreshold = builder.tieThreshold;
        this.gracePeriod = builder.gracePeriod;
        this.parallelismHint = builder.parallelismHint;
        this.timeOut = builder.timeOut;
        this.changeDetector = builder.changeDetector;
        setModelContext(new InstancesHeader(this.dataset));
    }

    /**
     * Handles one incoming event.
     *
     * <p>First resolves every split whose time-out has fired. Then either trains/predicts on an
     * instance batch, or records a partial split suggestion. Other event types are ignored.
     *
     * @return always {@code false}
     */
    @Override
    public boolean process(ContentEvent contentEvent) {
        drainTimedOutSplits();
        if (contentEvent instanceof InstancesContentEvent) {
            processInstanceContentEvent((InstancesContentEvent) contentEvent);
            dispatchTouchedLeaves();
        } else if (contentEvent instanceof LocalResultContentEvent) {
            processLocalResult((LocalResultContentEvent) contentEvent);
        }
        return false;
    }

    /** Finalizes every split request whose time-out task has queued its id. */
    private void drainTimedOutSplits() {
        Long timedOutId;
        while ((timedOutId = this.timedOutSplittingNodes.poll()) != null) {
            // remove() returns null if the request already completed or was discarded by a reset.
            SplittingNodeInfo info = this.splittingNodes.remove(timedOutId);
            if (info != null) {
                continueAttemptToSplit(info.activeLearningNode, info.foundNode);
            }
        }
    }

    /** Records one local split suggestion and finalizes the split once all have been collected. */
    private void processLocalResult(LocalResultContentEvent event) {
        Long id = Long.valueOf(event.getSplitId());
        SplittingNodeInfo info = this.splittingNodes.get(id);
        if (info == null) {
            // Unknown id: already timed out, completed, or discarded by a reset.
            return;
        }
        ActiveLearningNode node = info.activeLearningNode;
        node.addDistributedSuggestions(event.getBestSuggestion(), event.getSecondBestSuggestion());
        if (node.isAllSuggestionsCollected()) {
            info.scheduledFuture.cancel(false);
            this.splittingNodes.remove(id);
            continueAttemptToSplit(node, info.foundNode);
        }
    }

    /**
     * Sends the attribute batches accumulated by the leaves touched in this batch and starts a
     * split attempt on leaves past their grace period. Clears {@code foundNodeSet} afterwards.
     */
    private void dispatchTouchedLeaves() {
        if (this.foundNodeSet != null) {
            for (FoundNode foundNode : this.foundNodeSet) {
                Node node = foundNode.getNode();
                if (!(node instanceof ActiveLearningNode)) {
                    continue;
                }
                ActiveLearningNode leaf = (ActiveLearningNode) node;
                AttributeBatchContentEvent[] batches = leaf.getAttributeBatchContentEvent();
                if (batches != null) {
                    for (AttributeBatchContentEvent batch : batches) {
                        if (batch != null) {
                            sendToAttributeStream(batch);
                        }
                    }
                }
                leaf.setAttributeBatchContentEvent(null);
                boolean graceElapsed = leaf.getWeightSeen() - leaf.getWeightSeenAtLastSplitEvaluation() >= this.gracePeriod;
                if (!leaf.isSplitting() && graceElapsed) {
                    attemptToSplit(leaf, foundNode);
                }
            }
        }
        this.foundNodeSet = null;
    }

    /**
     * Initialises per-instance runtime state.
     *
     * <p>Creates the daemon time-out executor and shuts down any executor left from a previous
     * call.
     */
    @Override
    public void onCreate(int i) {
        this.processorId = i;
        this.activeLeafNodeCount = 0;
        this.inactiveLeafNodeCount = 0;
        this.decisionNodeCount = 0;
        this.growthAllowed = true;
        this.splittingNodes = new ConcurrentHashMap<>();
        this.timedOutSplittingNodes = new LinkedBlockingQueue<>();
        this.splitId = 0L;
        if (this.executor != null) {
            this.executor.shutdownNow();
        }
        this.executor = Executors.newScheduledThreadPool(TIMEOUT_THREAD_POOL_SIZE, new TimeoutThreadFactory());
    }

    /** Creates a fresh processor with the same configuration and streams as {@code processor}. */
    @Override
    public Processor newProcessor(Processor processor) {
        ModelAggregatorProcessor source = (ModelAggregatorProcessor) processor;
        ModelAggregatorProcessor copy = new Builder(source).build();
        copy.setResultStream(source.resultStream);
        copy.setAttributeStream(source.attributeStream);
        copy.setControlStream(source.controlStream);
        return copy;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(" ActiveLeafNodeCount: ").append(this.activeLeafNodeCount);
        sb.append(", InactiveLeafNodeCount: ").append(this.inactiveLeafNodeCount);
        sb.append(", DecisionNodeCount: ").append(this.decisionNodeCount);
        sb.append(", Growth allowed: ").append(this.growthAllowed);
        return sb.toString();
    }

    public void setResultStream(Stream stream) {
        this.resultStream = stream;
    }

    public void setAttributeStream(Stream stream) {
        this.attributeStream = stream;
    }

    public void setControlStream(Stream stream) {
        this.controlStream = stream;
    }

    void sendToAttributeStream(ContentEvent contentEvent) {
        this.attributeStream.put(contentEvent);
    }

    public void sendToControlStream(ContentEvent contentEvent) {
        this.controlStream.put(contentEvent);
    }

    /** Builds a result event for a single-instance event. Not called within this class. */
    private ResultContentEvent newResultContentEvent(double[] dArr, InstanceContentEvent instanceContentEvent) {
        ResultContentEvent result = new ResultContentEvent(instanceContentEvent.getInstanceIndex(), instanceContentEvent.getInstance(), instanceContentEvent.getClassId(), dArr, instanceContentEvent.isLastEvent());
        result.setClassifierIndex(this.processorId);
        result.setEvaluationIndex(instanceContentEvent.getEvaluationIndex());
        return result;
    }

    /** Builds a result event for one instance of a batch, using its class value as the label. */
    private ResultContentEvent newResultContentEvent(double[] dArr, Instance instance, InstancesContentEvent instancesContentEvent) {
        ResultContentEvent result = new ResultContentEvent(instancesContentEvent.getInstanceIndex(), instance, (int) instance.classValue(), dArr, instancesContentEvent.isLastEvent());
        result.setClassifierIndex(this.processorId);
        result.setEvaluationIndex(instancesContentEvent.getEvaluationIndex());
        return result;
    }

    /**
     * Passes a batch through the warm-up buffer.
     *
     * <p>The first batch is processed immediately and the next {@code WARMUP_BATCHES - 1} are held
     * back. Afterwards each arriving batch releases the oldest buffered one. A last-event batch
     * flushes the buffer.
     */
    private void processInstanceContentEvent(InstancesContentEvent instancesContentEvent) {
        // Saturate instead of counting forever: an overflow to negative would stop releasing
        // batches and let the buffer grow without bound on an unbounded stream.
        if (this.numBatches <= WARMUP_BATCHES) {
            this.numBatches++;
        }
        this.contentEventList.add(instancesContentEvent);
        if (this.numBatches == 1 || this.numBatches > WARMUP_BATCHES) {
            processInstances(this.contentEventList.remove(0));
        }
        if (instancesContentEvent.isLastEvent()) {
            while (!this.contentEventList.isEmpty()) {
                processInstances(this.contentEventList.remove(0));
            }
        }
    }

    private void processInstances(InstancesContentEvent instancesContentEvent) {
        Instance[] instances = instancesContentEvent.getInstances();
        boolean isTesting = instancesContentEvent.isTesting();
        boolean isTraining = instancesContentEvent.isTraining();
        for (Instance instance : instances) {
            processInstance(instance, instancesContentEvent, isTesting, isTraining);
        }
    }

    /**
     * Predicts and/or trains on one instance.
     *
     * <p>When a change detector is configured, it is fed whether the model's prediction made
     * <em>before</em> training on this instance was correct. A detected change resets the tree.
     */
    private void processInstance(Instance instance, InstancesContentEvent instancesContentEvent, boolean isTesting, boolean isTraining) {
        instance.setDataset(this.dataset);
        ChangeDetector detector = this.changeDetector;
        double[] prediction = null;
        if (isTesting) {
            prediction = getVotesForInstance(instance);
            this.resultStream.put(newResultContentEvent(prediction, instance, instancesContentEvent));
        }
        if (isTraining) {
            if (detector != null && prediction == null) {
                // Predict before learning so the error signal is not biased by this instance.
                prediction = getVotesForInstance(instance);
            }
            trainOnInstanceImpl(instance);
            if (detector != null) {
                boolean correct = correctlyClassifies(instance, prediction);
                double previousEstimation = detector.getEstimation();
                detector.input(correct ? 0.0d : 1.0d);
                if (detector.getEstimation() > previousEstimation) {
                    logger.info("Change detected, resetting the classifier");
                    resetLearning();
                    detector.resetLearning();
                }
            }
        }
    }

    /** Returns whether the arg-max of {@code dArr} equals the instance's class value. */
    private boolean correctlyClassifies(Instance instance, double[] dArr) {
        return Utils.maxIndex(dArr) == ((int) instance.classValue());
    }

    /**
     * Discards the whole model.
     *
     * <p>Detaches the old tree and cancels and forgets every pending split request, so a late
     * split cannot re-attach old nodes as the new root. Also clears the per-batch touched-leaf set
     * and resets the node counters. Split ids keep increasing so stale ids can never match a new
     * request.
     */
    private void resetLearning() {
        // Collect nodes before dropping the root; otherwise findNodes() sees an empty tree.
        FoundNode[] oldNodes = findNodes();
        this.treeRoot = null;
        for (FoundNode foundNode : oldNodes) {
            Node node = foundNode.getNode();
            if (node instanceof SplitNode) {
                SplitNode splitNode = (SplitNode) node;
                for (int i = 0; i < splitNode.numChildren(); i++) {
                    splitNode.setChild(i, null);
                }
            }
        }
        for (SplittingNodeInfo info : this.splittingNodes.values()) {
            info.scheduledFuture.cancel(false);
        }
        this.splittingNodes.clear();
        this.timedOutSplittingNodes.clear();
        this.foundNodeSet = null;
        this.activeLeafNodeCount = 0;
        this.inactiveLeafNodeCount = 0;
        this.decisionNodeCount = 0;
    }

    /** Returns every node of the current tree in depth-first pre-order; empty when there is no tree. */
    protected FoundNode[] findNodes() {
        List<FoundNode> nodes = new LinkedList<>();
        findNodes(this.treeRoot, null, -1, nodes);
        return nodes.toArray(new FoundNode[nodes.size()]);
    }

    /** Appends {@code node} and, recursively, all of its descendants to {@code list}. */
    protected void findNodes(Node node, SplitNode splitNode, int i, List<FoundNode> list) {
        if (node == null) {
            return;
        }
        list.add(new FoundNode(node, splitNode, i));
        if (node instanceof SplitNode) {
            SplitNode split = (SplitNode) node;
            for (int branch = 0; branch < split.numChildren(); branch++) {
                findNodes(split.getChild(branch), split, branch, list);
            }
        }
    }

    private double[] getVotesForInstance(Instance instance) {
        return getVotesForInstance(instance, false);
    }

    /**
     * Returns the class votes of the leaf reached by {@code instance}, falling back to the parent
     * split node when that branch has no child yet. Returns an all-zero vector when there is no
     * tree.
     *
     * @param trainAfterVoting when true, also trains on the instance (creating the root if needed)
     */
    private double[] getVotesForInstance(Instance instance, boolean trainAfterVoting) {
        double[] votes;
        FoundNode foundNode = null;
        if (this.treeRoot != null) {
            foundNode = this.treeRoot.filterInstanceToLeaf(instance, null, -1);
            Node node = foundNode.getNode();
            if (node == null) {
                node = foundNode.getParent();
            }
            votes = node.getClassVotes(instance, this);
        } else {
            votes = new double[this.dataset.numClasses()];
        }
        if (trainAfterVoting) {
            if (this.treeRoot == null) {
                this.treeRoot = newLearningNode(this.parallelismHint);
                this.activeLeafNodeCount = 1;
                foundNode = this.treeRoot.filterInstanceToLeaf(instance, null, -1);
            }
            trainOnInstanceImpl(foundNode, instance);
        }
        return votes;
    }

    /** Trains on {@code instance}, creating the root leaf on first use. */
    private void trainOnInstanceImpl(Instance instance) {
        if (this.treeRoot == null) {
            this.treeRoot = newLearningNode(this.parallelismHint);
            this.activeLeafNodeCount = 1;
        }
        trainOnInstanceImpl(this.treeRoot.filterInstanceToLeaf(instance, null, -1), instance);
    }

    /**
     * Trains the leaf described by {@code foundNode}, creating a new leaf for an empty branch, and
     * remembers active leaves in {@code foundNodeSet} for post-batch dispatch.
     */
    private void trainOnInstanceImpl(FoundNode foundNode, Instance instance) {
        FoundNode target = foundNode;
        Node node = foundNode.getNode();
        if (node == null) {
            node = newLearningNode(this.parallelismHint);
            foundNode.getParent().setChild(foundNode.getParentBranch(), node);
            this.activeLeafNodeCount++;
            // Record the leaf we actually created; the incoming FoundNode still points at null.
            target = new FoundNode(node, foundNode.getParent(), foundNode.getParentBranch());
        }
        if (node instanceof LearningNode) {
            ((LearningNode) node).learnFromInstance(instance, this);
        }
        // A SplitNode can be returned here (e.g. no branch for the instance); only active leaves
        // carry attribute batches and split state, so only they are tracked.
        if (node instanceof ActiveLearningNode) {
            if (this.foundNodeSet == null) {
                this.foundNodeSet = new HashSet<>();
            }
            this.foundNodeSet.add(target);
        }
    }

    /**
     * Registers a new split request with a time-out of {@code timeOut} seconds and asks the leaf
     * to request distributed split suggestions.
     */
    private void attemptToSplit(ActiveLearningNode activeLearningNode, FoundNode foundNode) {
        this.splitId++;
        Long id = Long.valueOf(this.splitId);
        ScheduledFuture<?> timeoutTask = this.executor.schedule(new AggregationTimeOutHandler(id, this.timedOutSplittingNodes), this.timeOut, TimeUnit.SECONDS);
        this.splittingNodes.put(id, new SplittingNodeInfo(activeLearningNode, foundNode, timeoutTask));
        activeLearningNode.requestDistributedSuggestions(this.splitId, this);
    }

    /**
     * Decides whether to split {@code activeLearningNode} using the collected suggestions.
     *
     * <p>The "no split" candidate takes part in the ranking. A split happens when the best
     * candidate has a split test and either beats the runner-up by more than the Hoeffding bound,
     * or the bound has fallen below {@code tieThreshold}. The node leaves the splitting state
     * either way.
     */
    private void continueAttemptToSplit(ActiveLearningNode activeLearningNode, FoundNode foundNode) {
        AttributeSplitSuggestion best = activeLearningNode.getDistributedBestSuggestion();
        AttributeSplitSuggestion secondBest = activeLearningNode.getDistributedSecondBestSuggestion();

        // "Do not split" candidate: no test, no child distributions.
        double[] preSplitDist = activeLearningNode.getObservedClassDistribution();
        AttributeSplitSuggestion nullSplit = new AttributeSplitSuggestion(null, new double[0][],
                this.splitCriterion.getMeritOfSplit(preSplitDist, new double[][] {preSplitDist}));
        if (best == null || nullSplit.compareTo(best) > 0) {
            secondBest = best;
            best = nullSplit;
        } else if (secondBest == null || nullSplit.compareTo(secondBest) > 0) {
            secondBest = nullSplit;
        }

        boolean shouldSplit;
        if (secondBest == null) {
            shouldSplit = best != null;
        } else {
            double hoeffdingBound = computeHoeffdingBound(
                    this.splitCriterion.getRangeOfMerit(activeLearningNode.getObservedClassDistribution()),
                    this.splitConfidence, activeLearningNode.getWeightSeen());
            shouldSplit = best.merit - secondBest.merit > hoeffdingBound || hoeffdingBound < this.tieThreshold;
        }

        if (shouldSplit && best.splitTest != null) {
            SplitNode splitNode = new SplitNode(best.splitTest, activeLearningNode.getObservedClassDistribution());
            for (int i = 0; i < best.numSplits(); i++) {
                splitNode.setChild(i, newLearningNode(best.resultingClassDistributionFromSplit(i), this.parallelismHint));
            }
            this.activeLeafNodeCount--;
            this.decisionNodeCount++;
            this.activeLeafNodeCount += best.numSplits();
            SplitNode parent = foundNode.getParent();
            if (parent == null) {
                this.treeRoot = splitNode;
            } else {
                parent.setChild(foundNode.getParentBranch(), splitNode);
            }
        }
        activeLearningNode.endSplitting();
        activeLearningNode.setWeightSeenAtLastSplitEvaluation(activeLearningNode.getWeightSeen());
    }

    /** Replaces an active leaf with an inactive one. Not called within this class. */
    private void deactivateLearningNode(ActiveLearningNode activeLearningNode, SplitNode splitNode, int i) {
        InactiveLearningNode inactive = new InactiveLearningNode(activeLearningNode.getObservedClassDistribution());
        if (splitNode == null) {
            this.treeRoot = inactive;
        } else {
            splitNode.setChild(i, inactive);
        }
        this.activeLeafNodeCount--;
        this.inactiveLeafNodeCount++;
    }

    private LearningNode newLearningNode(int i) {
        return newLearningNode(new double[0], i);
    }

    private LearningNode newLearningNode(double[] dArr, int i) {
        return new ActiveLearningNode(dArr, i);
    }

    /**
     * Validates that the header declares a class attribute and logs it at trace level.
     *
     * @throws IllegalArgumentException if the header has no class index
     */
    private void setModelContext(InstancesHeader instancesHeader) {
        if (instancesHeader != null && instancesHeader.classIndex() < 0) {
            throw new IllegalArgumentException("Context for a classifier must include a class to learn");
        }
        // Parameterized: formatted lazily and null-safe.
        logger.trace("Model context: {}", instancesHeader);
    }

    /** Hoeffding bound: sqrt(range^2 * ln(1 / confidence) / (2 * n)). */
    private static double computeHoeffdingBound(double range, double confidence, double n) {
        return Math.sqrt((range * range * Math.log(1.0d / confidence)) / (2.0d * n));
    }

    public ChangeDetector getChangeDetector() {
        return this.changeDetector;
    }

    public void setChangeDetector(ChangeDetector changeDetector) {
        this.changeDetector = changeDetector;
    }
}
