package org.jpmml.evaluator;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.Lists;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.dmg.pmml.EmbeddedModel;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MissingValueStrategyType;
import org.dmg.pmml.NoTrueChildStrategyType;
import org.dmg.pmml.Node;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.TreeModel;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;

/**
 * Evaluator for PMML {@link TreeModel} elements.
 *
 * <p>Supports the {@code REGRESSION} and {@code CLASSIFICATION} mining functions. Scoring walks the
 * tree from its root node, honouring the model's {@code missingValueStrategy} and
 * {@code noTrueChildStrategy}, and turns the selected node into a {@link NodeClassificationMap}.
 */
public class TreeModelEvaluator extends ModelEvaluator<TreeModel> implements HasEntityRegistry<Node> {

    /**
     * Lazily built registry of every node of a tree, keyed as populated by {@code EntityUtil.put}.
     *
     * <p>Keys are held weakly (Guava {@code weakKeys()}), so a cached entry does not keep its
     * {@link TreeModel} reachable.
     */
    private static final LoadingCache<TreeModel, BiMap<String, Node>> entityCache = CacheBuilder.newBuilder()
            .weakKeys()
            .build(new CacheLoader<TreeModel, BiMap<String, Node>>() {

                @Override
                public BiMap<String, Node> load(TreeModel treeModel) {
                    BiMap<String, Node> result = HashBiMap.create();
                    // NOTE(review): unlike evaluateTree, a missing root node is not guarded here;
                    // confirm the intended failure mode for such models.
                    collectNodes(treeModel.getNode(), result);
                    return result;
                }

                /** Registers {@code node} and then, recursively, each of its descendants. */
                private void collectNodes(Node node, BiMap<String, Node> result) {
                    EntityUtil.put(node, result);
                    for (Node child : node.getNodes()) {
                        collectNodes(child, result);
                    }
                }
            });

    /**
     * A {@link NodeResult} that is definitive: its node (possibly {@code null}) is returned as the
     * prediction without consulting the model's {@code noTrueChildStrategy}.
     */
    public static class FinalNodeResult extends NodeResult {

        public FinalNodeResult(Node node) {
            super(node);
        }

        @Override
        public boolean isFinal() {
            return true;
        }
    }

    /**
     * Outcome of walking a (sub)tree.
     *
     * <p>A non-final result with a {@code null} node means that no child predicate evaluated to
     * {@code true}; the caller then applies the model's {@code noTrueChildStrategy}.
     */
    public static class NodeResult {

        private final Node node;

        public NodeResult(Node node) {
            this.node = node;
        }

        /** Returns whether this result must be used as-is. */
        public boolean isFinal() {
            return false;
        }

        /** Returns the selected node, or {@code null} if none was selected. */
        public Node getNode() {
            return this.node;
        }
    }

    /** Creates an evaluator for the {@link TreeModel} located in {@code pmml}. */
    public TreeModelEvaluator(PMML pmml) {
        this(pmml, (TreeModel) find(pmml.getModels(), TreeModel.class));
    }

    /** Creates an evaluator for the given {@code treeModel} belonging to {@code pmml}. */
    public TreeModelEvaluator(PMML pmml, TreeModel treeModel) {
        super(pmml, treeModel);
    }

    @Override
    public String getSummary() {
        return "Tree model";
    }

    /** Returns the (shared, cached) registry of all nodes of this evaluator's tree. */
    @Override
    @SuppressWarnings("unchecked")
    public BiMap<String, Node> getEntityRegistry() {
        // The cache only ever stores BiMap<String, Node> values, so this cast is safe.
        return (BiMap<String, Node>) getValue(entityCache);
    }

    /**
     * Scores the tree against the inputs held by {@code modelEvaluationContext}.
     *
     * @throws InvalidResultException if the model is marked as not scorable
     * @throws UnsupportedFeatureException for mining functions other than regression/classification
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext modelEvaluationContext) {
        TreeModel treeModel = (TreeModel) getModel();
        if (!treeModel.isScorable()) {
            throw new InvalidResultException(treeModel);
        }

        MiningFunctionType functionName = treeModel.getFunctionName();
        switch (functionName) {
            case REGRESSION:
            case CLASSIFICATION:
                break;
            default:
                throw new UnsupportedFeatureException(treeModel, functionName);
        }

        Node node = evaluateTree(modelEvaluationContext);
        // A null node (no prediction) is passed on as a null map.
        NodeClassificationMap values = (node != null) ? createNodeClassificationMap(node) : null;

        return OutputUtil.evaluate(TargetUtil.evaluateClassification((ClassificationMap<?>) values, modelEvaluationContext), modelEvaluationContext);
    }

    /**
     * Walks the tree from its root and returns the node selected as the prediction.
     *
     * @return the selected node, or {@code null} for a "null prediction"
     * @throws InvalidFeatureException if the model has no root node
     * @throws MissingResultException if the missing value strategy is {@code NONE} at the root, or
     *     a "last prediction" is requested while no non-leaf node has been entered
     * @throws UnsupportedFeatureException for an unsupported {@code noTrueChildStrategy}
     */
    private Node evaluateTree(ModelEvaluationContext context) {
        TreeModel treeModel = (TreeModel) getModel();

        Node root = treeModel.getNode();
        if (root == null) {
            throw new InvalidFeatureException(treeModel);
        }

        // Non-leaf nodes whose predicate evaluated to true, outermost first.
        List<Node> trail = Lists.newArrayList();

        NodeResult result;

        Boolean status = evaluateNode(root, context);
        if (status == null) {
            result = handleMissingValue(root, trail);
        } else if (status.booleanValue()) {
            result = handleTrue(root, trail, context);
        } else {
            result = new NodeResult(null);
        }

        if (result == null) {
            throw new MissingResultException(root);
        }

        Node node = result.getNode();
        if (node != null || result.isFinal()) {
            return node;
        }

        // No child of the last entered node matched: fall back to the configured strategy.
        NoTrueChildStrategyType noTrueChildStrategy = treeModel.getNoTrueChildStrategy();
        switch (noTrueChildStrategy) {
            case RETURN_NULL_PREDICTION:
                return null;
            case RETURN_LAST_PREDICTION:
                return lastPrediction(root, trail);
            default:
                throw new UnsupportedFeatureException(treeModel, noTrueChildStrategy);
        }
    }

    /**
     * Applies the model's {@code missingValueStrategy} to a node whose predicate is unknown.
     *
     * @param node the node whose predicate evaluated to {@code null}
     * @param trail the non-leaf nodes entered so far, outermost first
     * @return a final result, or {@code null} for {@code NONE} (the caller decides how to proceed)
     * @throws MissingResultException for {@code LAST_PREDICTION} when {@code trail} is empty
     * @throws UnsupportedFeatureException for any other strategy
     */
    private NodeResult handleMissingValue(Node node, List<Node> trail) {
        TreeModel treeModel = (TreeModel) getModel();

        MissingValueStrategyType missingValueStrategy = treeModel.getMissingValueStrategy();
        switch (missingValueStrategy) {
            case NULL_PREDICTION:
                return new FinalNodeResult(null);
            case LAST_PREDICTION:
                return new FinalNodeResult(lastPrediction(node, trail));
            case NONE:
                return null;
            default:
                throw new UnsupportedFeatureException(treeModel, missingValueStrategy);
        }
    }

    /**
     * Descends from a node whose predicate evaluated to {@code true}.
     *
     * <p>A leaf is returned directly. Otherwise the node is appended to {@code trail} and its
     * children are tried in list order. The first child that evaluates to {@code true} is descended
     * into. A child with an unknown predicate is handed to the missing value strategy, and the next
     * sibling is tried when that strategy yields no result.
     *
     * @return the selected leaf, a final result from the missing value strategy, or a non-final
     *     result with a {@code null} node when no child matched
     */
    private NodeResult handleTrue(Node node, List<Node> trail, EvaluationContext context) {
        List<Node> children = node.getNodes();
        if (children.isEmpty()) {
            return new NodeResult(node);
        }

        trail.add(node);

        for (Node child : children) {
            Boolean status = evaluateNode(child, context);

            if (status == null) {
                NodeResult result = handleMissingValue(child, trail);
                if (result != null) {
                    return result;
                }
            } else if (status.booleanValue()) {
                return handleTrue(child, trail, context);
            }
        }

        return new NodeResult(null);
    }

    /**
     * Returns the deepest non-leaf node entered so far.
     *
     * @param node the node reported in the exception when there is no such node
     * @throws MissingResultException if {@code trail} is empty
     */
    private static Node lastPrediction(Node node, List<Node> trail) {
        if (trail.isEmpty()) {
            throw new MissingResultException(node);
        }
        return trail.get(trail.size() - 1);
    }

    /**
     * Evaluates the predicate of {@code node}.
     *
     * @return the predicate's value; {@code null} is treated by callers as a missing value
     * @throws InvalidFeatureException if the node has no predicate
     * @throws UnsupportedFeatureException if the node carries an embedded model
     */
    private Boolean evaluateNode(Node node, EvaluationContext context) {
        Predicate predicate = node.getPredicate();
        if (predicate == null) {
            throw new InvalidFeatureException(node);
        }

        EmbeddedModel embeddedModel = node.getEmbeddedModel();
        if (embeddedModel != null) {
            throw new UnsupportedFeatureException(embeddedModel);
        }

        return PredicateUtil.evaluate(predicate, context);
    }

    /**
     * Builds the classification map for a selected node from its score distributions.
     *
     * <p>An explicit probability is used as-is. Otherwise the probability is the distribution's
     * record count divided by the node's total record count.
     *
     * @throws InvalidFeatureException if a probability must be derived but the total record count
     *     is not positive (the division would otherwise yield {@code NaN})
     */
    private static NodeClassificationMap createNodeClassificationMap(Node node) {
        NodeClassificationMap result = new NodeClassificationMap(node);

        List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();

        double sum = 0d;
        for (ScoreDistribution scoreDistribution : scoreDistributions) {
            sum += scoreDistribution.getRecordCount();
        }

        for (ScoreDistribution scoreDistribution : scoreDistributions) {
            Double probability = scoreDistribution.getProbability();

            if (probability == null) {
                if (sum <= 0d) {
                    throw new InvalidFeatureException(node);
                }
                probability = scoreDistribution.getRecordCount() / sum;
            }

            result.put(scoreDistribution.getValue(), probability);
        }

        return result;
    }
}
