package hex.tree;

import hex.Model;
import hex.ModelCategory;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.schemas.TreeV3;
import hex.tree.SharedTreeModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.List;
import java.util.Objects;
import java.util.stream.IntStream;
import water.MemoryManager;
import water.api.Handler;

/* loaded from: input_file:hex/tree/TreeHandler.class */
public class TreeHandler extends Handler {
    private static final int NO_CHILD = -1;
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:hex/tree/TreeHandler$TreeProperties.class */
    /**
     * Flat, array-based description of a single tree. All arrays are allocated with one slot per
     * entry of the sub-graph's {@code nodesArray}; slots are filled in the order in which
     * {@code append} visits the nodes (level by level, starting with the root at slot 0).
     */
    public static class TreeProperties {
        // Node number of each node's left child, or -1 when the node has no left child.
        public int[] _leftChildren;
        // Node number of each node's right child, or -1 when the node has no right child.
        public int[] _rightChildren;
        // Free-text description of each node; every entry starts with a deprecation warning.
        public String[] _descriptions;
        // Split value (getSplitValue()) of each node.
        public float[] _thresholds;
        // Column name (getColName()) of each node.
        public String[] _features;
        // Per node: indices of the bits set in the node's inclusive-levels bit set when its parent is a
        // bitset split, otherwise null.
        public int[][] levels;
        // Per node: "LEFT", "RIGHT" or null, as computed by getNaDirection.
        public String[] _nas;
        // Prediction value (getPredValue()) of each node.
        public float[] _predictions;
        // If/else style rendering of the whole tree produced by getLanguageRepresentation.
        public String _treeDecisionPath;
        // Allocated by convertSharedTreeSubgraph; not written anywhere else in this class.
        public String[] _leafNodeAssignments;
        // Per node: textual path from the node up towards the root, produced by fillNodePath.
        public String[] _decisionPaths;
        // Same as _leftChildren, but holding the child's slot index instead of its node number (-1 = none).
        private int[] _leftChildrenNormalized;
        // Same as _rightChildren, but holding the child's slot index instead of its node number (-1 = none).
        private int[] _rightChildrenNormalized;
        // Per node: the node's domain values (getDomainValues()); null entries are treated as "not categorical".
        private String[][] _domainValues;
    }

    /**
     * Fills the given schema with the structure of one tree of a tree-based model.
     *
     * <p>The model is looked up through {@code treeV3.model}; it has to be either a
     * {@link SharedTreeGraphConverter} or a {@link SharedTreeModel}. On return
     * {@code treeV3.tree_class} is set to {@code null} for non-classifiers.
     *
     * @param i      not read by this method
     * @param treeV3 request/response schema; {@code tree_number}, {@code tree_class} and
     *               {@code model} are read, the tree description fields are written
     * @return the same {@code treeV3} instance, populated
     * @throws IllegalArgumentException when the tree number is negative, the model does not exist,
     *                                  or the model is not tree-based
     */
    public TreeV3 getTree(int i, TreeV3 treeV3) {
        if (treeV3.tree_number < 0) {
            throw new IllegalArgumentException("Invalid tree number: " + treeV3.tree_number + ". Tree number must be >= 0.");
        }
        final Model requested = treeV3.model.key().get();
        if (requested == null) {
            throw new IllegalArgumentException("Given model does not exist: " + treeV3.model.key().toString());
        }
        final boolean convertible = requested instanceof SharedTreeGraphConverter;
        if (!convertible && !(requested instanceof SharedTreeModel)) {
            throw new IllegalArgumentException("Given model is not tree-based.");
        }

        final SharedTreeSubgraph subgraph;
        if (!convertible) {
            final SharedTreeModel treeModel = (SharedTreeModel) requested;
            final SharedTreeModel.SharedTreeOutput output = (SharedTreeModel.SharedTreeOutput) treeModel._output;
            final int classIndex = getResponseLevelIndex(treeV3.tree_class, output);
            subgraph = treeModel.getSharedTreeSubgraph(treeV3.tree_number, classIndex);
            // The class is always reported back on output; null for non-classifiers.
            if (output.isClassifier()) {
                treeV3.tree_class = output.classNames()[classIndex];
            } else {
                treeV3.tree_class = null;
            }
        } else {
            final SharedTreeGraph graph = ((SharedTreeGraphConverter) requested).convert(treeV3.tree_number, treeV3.tree_class);
            if (!$assertionsDisabled && graph.subgraphArray.size() != 1) {
                throw new AssertionError();
            }
            subgraph = (SharedTreeSubgraph) graph.subgraphArray.get(0);
            if (!requested._output.isClassifier()) {
                treeV3.tree_class = null;
            }
        }

        final TreeProperties props = convertSharedTreeSubgraph(subgraph);
        treeV3.left_children = props._leftChildren;
        treeV3.right_children = props._rightChildren;
        treeV3.descriptions = props._descriptions;
        treeV3.root_node_id = subgraph.rootNode.getNodeNumber();
        treeV3.thresholds = props._thresholds;
        treeV3.features = props._features;
        treeV3.nas = props._nas;
        treeV3.levels = props.levels;
        treeV3.predictions = props._predictions;
        treeV3.tree_decision_path = props._treeDecisionPath;
        treeV3.decision_paths = props._decisionPaths;
        return treeV3;
    }

    /**
     * Renders the whole tree as nested if/else pseudo-code, starting from the sub-graph's root node
     * at indentation depth 0.
     *
     * @param sharedTreeSubgraph tree to render
     * @return the rendered text
     */
    private static String getLanguageRepresentation(SharedTreeSubgraph sharedTreeSubgraph) {
        final StringBuilder rendered = new StringBuilder();
        final StringBuilder result = getNodeRepresentation(sharedTreeSubgraph.rootNode, rendered, 0);
        return result.toString();
    }

    /**
     * Appends the pseudo-code for {@code sharedTreeNode} and everything below it to {@code sb}.
     *
     * <p>A node without a right child is rendered as a "Predicted value" line. Any other node is
     * rendered as a condition line, the right subtree, an else line, the left subtree and a
     * closing brace.
     *
     * @param sharedTreeNode node to render
     * @param sb             builder receiving the text
     * @param i              indentation depth (number of tabs) of this node
     * @return the builder holding the rendered text
     */
    private static StringBuilder getNodeRepresentation(SharedTreeNode sharedTreeNode, StringBuilder sb, int i) {
        final SharedTreeNode rightBranch = sharedTreeNode.getRightChild();
        if (rightBranch == null) {
            sb.append(getNewPaddedLine(i));
            final float prediction = sharedTreeNode.getPredValue();
            sb.append(Float.isNaN(prediction) ? "Predicted value: NaN" : "Predicted value: " + prediction);
            sb.append(getNewPaddedLine(i));
            return sb;
        }

        sb.append(getConditionLine(sharedTreeNode, i)).append(getNewPaddedLine(i));
        StringBuilder out = getNodeRepresentation(rightBranch, sb, i + 1);
        out.append(getNewPaddedLine(i)).append(getElseLine(sharedTreeNode)).append(getNewPaddedLine(i));
        out = getNodeRepresentation(sharedTreeNode.getLeftChild(), out, i + 1);
        out.append(getNewPaddedLine(i)).append("}");
        return out;
    }

    /**
     * Builds a line break followed by {@code i} tab characters.
     *
     * @param i number of tabs to emit after the newline; values below one yield just the newline
     * @return a fresh builder holding the newline and the indentation
     */
    private static StringBuilder getNewPaddedLine(int i) {
        final char[] indent = new char[Math.max(i, 0)];
        Arrays.fill(indent, '\t');
        return new StringBuilder("\n").append(indent);
    }

    /**
     * Builds the "else" line that introduces the left branch of {@code sharedTreeNode}.
     *
     * <p>Nodes without domain values get a plain {@code "} else {"}. Otherwise the line lists the
     * domain values selected by the left child's inclusive-levels bit set; nothing is listed when
     * that bit set is {@code null}.
     *
     * @param sharedTreeNode node whose left branch is being introduced
     * @return a fresh builder holding the line
     */
    private static StringBuilder getElseLine(SharedTreeNode sharedTreeNode) {
        final StringBuilder line = new StringBuilder();
        final String[] domain = sharedTreeNode.getDomainValues();
        if (domain == null) {
            return line.append("} else {");
        }

        final SharedTreeNode leftBranch = sharedTreeNode.getLeftChild();
        line.append("} else if ( ").append(sharedTreeNode.getColName()).append(" is in [ ");
        final BitSet leftLevels = leftBranch.getInclusiveLevels();
        if (leftLevels != null) {
            if (leftLevels.isEmpty()) {
                line.append("Missing set of levels for underlying node");
            } else {
                // Visit the set bits in ascending order, one domain value per bit.
                for (int level = leftLevels.nextSetBit(0); level >= 0; level = leftLevels.nextSetBit(level + 1)) {
                    line.append(domain[level]).append(" ");
                }
            }
        }
        return line.append("]) {");
    }

    /**
     * Builds the "If ( ... ) {" line that introduces the right branch of {@code sharedTreeNode}.
     *
     * <ul>
     *   <li>Node with domain values: lists the domain values selected by the right child's
     *       inclusive-levels bit set. A missing right child or a {@code null} bit set lists
     *       nothing, mirroring how {@code getElseLine} treats a {@code null} bit set, instead of
     *       failing with a {@code NullPointerException}.</li>
     *   <li>Node whose split value is NaN: {@code "<col> is NaN"}.</li>
     *   <li>Otherwise: {@code "<col> >= <split>"}, extended with {@code "or <col> is NaN"} when
     *       {@code getNaDirection} reports {@code "RIGHT"}.</li>
     * </ul>
     *
     * @param sharedTreeNode node whose right branch is being introduced
     * @param i              indentation depth; for any value other than 0 the line is prefixed with
     *                       a padded line break
     * @return a fresh builder holding the line
     */
    private static StringBuilder getConditionLine(SharedTreeNode sharedTreeNode, int i) {
        final StringBuilder sb = new StringBuilder();
        if (i != 0) {
            sb.append(getNewPaddedLine(i));
        }

        final String column = sharedTreeNode.getColName();
        final String[] domain = sharedTreeNode.getDomainValues();
        if (domain != null) {
            sb.append("If ( ").append(column).append(" is in [ ");
            final SharedTreeNode rightChild = sharedTreeNode.getRightChild();
            final BitSet levels = rightChild != null ? rightChild.getInclusiveLevels() : null;
            if (levels != null) {
                if (levels.isEmpty()) {
                    sb.append("Missing set of levels for underlying node");
                } else {
                    // Walk the bit set directly rather than parsing its toString() form.
                    for (int level = levels.nextSetBit(0); level >= 0; level = levels.nextSetBit(level + 1)) {
                        sb.append(domain[level]).append(" ");
                    }
                }
            }
            sb.append("]) {");
            return sb;
        }

        final float splitValue = sharedTreeNode.getSplitValue();
        if (Float.isNaN(splitValue)) {
            sb.append("If ( ").append(column).append(" is NaN ) {");
            return sb;
        }

        sb.append("If ( ").append(column).append(" >= ").append(splitValue);
        if ("RIGHT".equals(getNaDirection(sharedTreeNode))) {
            sb.append(" or ").append(column).append(" is NaN ) {");
        } else {
            sb.append(" ) {");
        }
        return sb;
    }

    /**
     * Resolves the tree class requested by the caller to an index into the response domain.
     *
     * @param str              requested class name; {@code null} is treated as empty and
     *                         surrounding whitespace is trimmed
     * @param sharedTreeOutput model output providing the model category and the response domain
     * @return 0 for non-classifiers and binomial models, otherwise the index of the matching
     *         response level (exact, case-sensitive match)
     * @throws IllegalArgumentException when a class is given for a non-classifier, when a binomial
     *                                  request names anything but the first level, or when no level
     *                                  matches
     */
    private static int getResponseLevelIndex(String str, SharedTreeModel.SharedTreeOutput sharedTreeOutput) {
        final String requestedClass = (str == null) ? "" : str.trim();
        final boolean classRequested = !requestedClass.isEmpty();

        if (!sharedTreeOutput.isClassifier()) {
            if (classRequested) {
                throw new IllegalArgumentException("There are no tree classes for " + sharedTreeOutput.getModelCategory() + ".");
            }
            return 0;
        }

        final String[] responseDomain = sharedTreeOutput._domains[sharedTreeOutput.responseIdx()];
        if (sharedTreeOutput.getModelCategory() == ModelCategory.Binomial) {
            if (classRequested && !requestedClass.equals(responseDomain[0])) {
                throw new IllegalArgumentException("For binomial, only one tree class has been built per each iteration: " + responseDomain[0]);
            }
            return 0;
        }

        final int levelIndex = Arrays.asList(responseDomain).indexOf(requestedClass);
        if (levelIndex < 0) {
            throw new IllegalArgumentException("There is no such tree class. Given categorical level does not exist in response column: " + requestedClass);
        }
        return levelIndex;
    }

    /**
     * Flattens a tree sub-graph into the array-based {@link TreeProperties} form.
     *
     * <p>All arrays get one slot per entry of {@code nodesArray}. Slot 0 is seeded from the root
     * node here, because {@code append} only writes the description and child links for the root;
     * the remaining slots are filled by {@code append} and the decision paths by
     * {@code fillLanguagePathRepresentation}.
     *
     * @param sharedTreeSubgraph tree to convert; must not be {@code null}
     * @return the populated properties
     * @throws NullPointerException when {@code sharedTreeSubgraph} is {@code null}
     */
    static TreeProperties convertSharedTreeSubgraph(SharedTreeSubgraph sharedTreeSubgraph) {
        Objects.requireNonNull(sharedTreeSubgraph);
        final int nodeCount = sharedTreeSubgraph.nodesArray.size();
        final SharedTreeNode root = sharedTreeSubgraph.rootNode;

        final TreeProperties treeProperties = new TreeProperties();
        treeProperties._leftChildren = MemoryManager.malloc4(nodeCount);
        treeProperties._rightChildren = MemoryManager.malloc4(nodeCount);
        treeProperties._descriptions = new String[nodeCount];
        treeProperties._thresholds = MemoryManager.malloc4f(nodeCount);
        treeProperties._features = new String[nodeCount];
        treeProperties._nas = new String[nodeCount];
        treeProperties._predictions = MemoryManager.malloc4f(nodeCount);
        treeProperties._leafNodeAssignments = new String[nodeCount];
        treeProperties._decisionPaths = new String[nodeCount];
        treeProperties._leftChildrenNormalized = MemoryManager.malloc4(nodeCount);
        treeProperties._rightChildrenNormalized = MemoryManager.malloc4(nodeCount);
        // Arrays of arrays: only the outer dimension is allocated, rows are assigned per node.
        treeProperties.levels = new int[nodeCount][];
        treeProperties._domainValues = new String[nodeCount][];

        // Seed slot 0 with the root node's data.
        final int rootLeftId = root.getLeftChild() != null ? root.getLeftChild().getNodeNumber() : -1;
        final int rootRightId = root.getRightChild() != null ? root.getRightChild().getNodeNumber() : -1;
        treeProperties._rightChildren[0] = rootRightId;
        treeProperties._leftChildren[0] = rootLeftId;
        treeProperties._thresholds[0] = root.getSplitValue();
        treeProperties._features[0] = root.getColName();
        treeProperties._nas[0] = getNaDirection(root);
        // The traversal never records the root's prediction, so store it here; otherwise slot 0
        // would keep the default 0.0 and the root's decision path would report a wrong value.
        treeProperties._predictions[0] = root.getPredValue();
        treeProperties._treeDecisionPath = getLanguageRepresentation(sharedTreeSubgraph);
        treeProperties._decisionPaths[0] = "Predicted value: " + root.getPredValue();
        treeProperties._leftChildrenNormalized[0] = rootLeftId;
        treeProperties._rightChildrenNormalized[0] = rootRightId;
        treeProperties._domainValues[0] = root.getDomainValues();

        final List<SharedTreeNode> nodesToTraverse = new ArrayList<>();
        nodesToTraverse.add(root);
        append(treeProperties._rightChildren, treeProperties._leftChildren, treeProperties._descriptions, treeProperties._thresholds, treeProperties._features, treeProperties._nas, treeProperties.levels, treeProperties._predictions, nodesToTraverse, -1, false, treeProperties._domainValues);
        fillLanguagePathRepresentation(treeProperties);
        return treeProperties;
    }

    /**
     * Walks the tree level by level and records every visited node in the next free slot of the
     * output arrays.
     *
     * <p>Until the first node has been processed (while {@code z} is {@code false}) a node only
     * gets a "Root node ..." description; every later node is delegated to
     * {@code fillnodeDescriptions}. Child node numbers (or -1) are recorded for every node.
     *
     * @param iArr     receives right-child node numbers
     * @param iArr2    receives left-child node numbers
     * @param strArr   receives node descriptions
     * @param fArr     receives split values
     * @param strArr2  receives column names
     * @param strArr3  receives NA directions
     * @param iArr3    receives categorical levels
     * @param fArr2    receives predictions
     * @param list     nodes of the level to start from
     * @param i        last slot already used (-1 when nothing has been written yet)
     * @param z        whether the root has already been described
     * @param strArr4  receives domain values
     */
    private static void append(int[] iArr, int[] iArr2, String[] strArr, float[] fArr, String[] strArr2, String[] strArr3, int[][] iArr3, float[] fArr2, List<SharedTreeNode> list, int i, boolean z, String[][] strArr4) {
        List<SharedTreeNode> currentLevel = list;
        int slot = i;
        boolean rootDescribed = z;

        while (!currentLevel.isEmpty()) {
            final List<SharedTreeNode> nextLevel = new ArrayList<>();
            for (SharedTreeNode node : currentLevel) {
                slot++;
                final SharedTreeNode left = node.getLeftChild();
                final SharedTreeNode right = node.getRightChild();

                if (!rootDescribed) {
                    final StringBuilder rootText = new StringBuilder("*** WARNING: This property is deprecated! *** ")
                            .append("Root node has id ")
                            .append(node.getNodeNumber())
                            .append(" and splits on column '")
                            .append(node.getColName())
                            .append("'. ");
                    fillNodeSplitTowardsChildren(rootText, node);
                    strArr[slot] = rootText.toString();
                    rootDescribed = true;
                } else {
                    fillnodeDescriptions(node, strArr, fArr, strArr2, iArr3, fArr2, strArr3, slot, strArr4);
                }

                if (left == null) {
                    iArr2[slot] = -1;
                } else {
                    nextLevel.add(left);
                    iArr2[slot] = left.getNodeNumber();
                }
                if (right == null) {
                    iArr[slot] = -1;
                } else {
                    nextLevel.add(right);
                    iArr[slot] = right.getNodeNumber();
                }
            }
            currentLevel = nextLevel;
            rootDescribed = true;
        }
    }

    /**
     * Collects the child node numbers in slot order and, as a side effect, fills the normalized
     * child arrays.
     *
     * <p>The returned list starts with a 0 entry, followed by every child node number that is not
     * -1, left before right for each slot. For each such child the matching
     * {@code _leftChildrenNormalized}/{@code _rightChildrenNormalized} entry is set to the child's
     * position in the returned list; missing children are recorded as -1.
     *
     * @param treeProperties properties whose child arrays are read and whose normalized arrays are
     *                       written
     * @return the collected ids
     */
    private static List<Integer> extractInternalIds(TreeProperties treeProperties) {
        final List<Integer> ids = new ArrayList<>();
        ids.add(0);

        final int[] leftIds = treeProperties._leftChildren;
        final int[] rightIds = treeProperties._rightChildren;
        for (int slot = 0; slot < leftIds.length; slot++) {
            final int leftId = leftIds[slot];
            if (leftId == -1) {
                treeProperties._leftChildrenNormalized[slot] = -1;
            } else {
                ids.add(leftId);
                treeProperties._leftChildrenNormalized[slot] = ids.size() - 1;
            }

            final int rightId = rightIds[slot];
            if (rightId == -1) {
                treeProperties._rightChildrenNormalized[slot] = -1;
            } else {
                ids.add(rightId);
                treeProperties._rightChildrenNormalized[slot] = ids.size() - 1;
            }
        }
        return ids;
    }

    /**
     * Fills {@code _decisionPaths} with one textual decision path per collected node id.
     *
     * <p>The target slot is the id's position in the list returned by {@code extractInternalIds}.
     * Iterating by position avoids a linear {@code indexOf} search per node and guarantees that
     * every position is written, even if the same id value were to appear more than once.
     *
     * @param treeProperties properties to read the tree layout from and to write the paths into
     */
    private static void fillLanguagePathRepresentation(TreeProperties treeProperties) {
        final List<Integer> nodeIds = extractInternalIds(treeProperties);
        for (int slot = 0; slot < nodeIds.size(); slot++) {
            treeProperties._decisionPaths[slot] = fillNodePath(nodeIds.get(slot).intValue(), nodeIds, false, treeProperties);
        }
    }

    /**
     * Builds the textual path for the node with id {@code i}.
     *
     * <p>The first call ({@code z == false}) emits the "Predicted value" header and then re-enters
     * with {@code z == true}. Each {@code z == true} step looks for the slot whose normalized
     * left or right child entry equals this node's position in {@code list}; when one is found,
     * an arrow, the condition of that slot and the path of that slot's id are appended. When none
     * is found the step yields an empty string.
     *
     * @param i              node id to describe
     * @param list           ids as returned by {@code extractInternalIds}
     * @param z              {@code false} for the initial call, {@code true} while climbing
     * @param treeProperties flattened tree data
     * @return the path text
     */
    private static String fillNodePath(int i, List<Integer> list, boolean z, TreeProperties treeProperties) {
        final int position = list.indexOf(Integer.valueOf(i));
        if (!z) {
            final String header = "Predicted value: " + treeProperties._predictions[position] + "\n";
            return header + fillNodePath(i, list, true, treeProperties);
        }

        final int[] leftSlots = treeProperties._leftChildrenNormalized;
        final int[] rightSlots = treeProperties._rightChildrenNormalized;
        int parentSlot = -1;
        int parentId = -1;
        String condition = "";

        final int viaLeft = IntStream.range(0, leftSlots.length)
                .filter(slot -> leftSlots[slot] == position)
                .findFirst()
                .orElse(-1);
        if (viaLeft != -1) {
            parentSlot = viaLeft;
            parentId = list.get(parentSlot).intValue();
            condition = getConditionByIndex(parentSlot, "R", treeProperties);
        }

        // A match among the right children takes precedence over one among the left children.
        final int viaRight = IntStream.range(0, rightSlots.length)
                .filter(slot -> rightSlots[slot] == position)
                .findFirst()
                .orElse(-1);
        if (viaRight != -1) {
            parentSlot = viaRight;
            parentId = list.get(parentSlot).intValue();
            condition = getConditionByIndex(parentSlot, "L", treeProperties);
        }

        if (parentSlot == -1) {
            return "";
        }
        return "^\n" + "|\n" + "|\n" + "|\n" + condition + fillNodePath(parentId, list, true, treeProperties);
    }

    /**
     * Renders the condition stored at slot {@code i} as one "If ( ... )" line.
     *
     * <ul>
     *   <li>Slot with domain values: lists the domain values named by the levels of the left child
     *       slot when {@code str} is {@code "R"}, otherwise of the right child slot.</li>
     *   <li>Slot whose threshold is NaN: {@code "is not NaN"} for {@code "R"}, {@code "is NaN"}
     *       otherwise.</li>
     *   <li>Otherwise: {@code "<"} for {@code "R"} and {@code ">="} otherwise, followed by an
     *       "or ... is NaN" clause when the slot's NA direction is {@code "LEFT"} (for
     *       {@code "R"}) or {@code "RIGHT"} (otherwise).</li>
     * </ul>
     *
     * @param i              slot whose condition is rendered
     * @param str            side marker; only {@code "R"} is distinguished from other values
     * @param treeProperties flattened tree data
     * @return the condition line, terminated by a newline
     */
    private static String getConditionByIndex(int i, String str, TreeProperties treeProperties) {
        final String feature = treeProperties._features[i];
        final boolean leftBranch = "R".equals(str);
        final StringBuilder condition = new StringBuilder("If ( ").append(feature);

        final String[] domain = treeProperties._domainValues[i];
        if (domain != null) {
            condition.append(" is in [");
            final int childSlot = leftBranch
                    ? treeProperties._leftChildrenNormalized[i]
                    : treeProperties._rightChildrenNormalized[i];
            final int[] childLevels = treeProperties.levels[childSlot];
            if (childLevels == null) {
                condition.append(" ");
            } else {
                for (int level : childLevels) {
                    condition.append(domain[level]).append(" ");
                }
            }
            return condition.append(" ])\n").toString();
        }

        final float threshold = treeProperties._thresholds[i];
        if (Float.isNaN(threshold)) {
            return condition.append(leftBranch ? " is not " : " is ").append("NaN )\n").toString();
        }

        condition.append(leftBranch ? " < " : " >= ").append(threshold);
        final String naSide = leftBranch ? "LEFT" : "RIGHT";
        if (naSide.equals(treeProperties._nas[i])) {
            condition.append(" or ").append(feature).append(" is NaN");
        }
        return condition.append(" )\n").toString();
    }

    /**
     * Records a non-root node in slot {@code i} of the output arrays and builds its textual
     * description.
     *
     * <p>A node is described as terminal when it has no column name or is a leaf; any other node
     * is described as splitting on its column. (The leaf test used to be negated, which labelled
     * splitting nodes as terminal and vice versa.)
     *
     * @param sharedTreeNode node to record; its parent is dereferenced, so it must have one
     * @param strArr         receives the description
     * @param fArr           receives the node's split value
     * @param strArr2        receives the node's column name
     * @param iArr           receives the node's categorical levels when the parent is a bitset
     *                       split, otherwise {@code null}
     * @param fArr2          receives the node's prediction value
     * @param strArr3        receives the node's NA direction
     * @param i              slot to write
     * @param strArr4        receives the node's domain values
     */
    private static void fillnodeDescriptions(SharedTreeNode sharedTreeNode, String[] strArr, float[] fArr, String[] strArr2, int[][] iArr, float[] fArr2, String[] strArr3, int i, String[][] strArr4) {
        final SharedTreeNode parent = sharedTreeNode.getParent();
        final boolean parentIsBitset = parent.isBitset();
        // Levels are only meaningful below a bitset split; computed once and reused below.
        final int[] nodeLevels = parentIsBitset ? extractNodeLevels(sharedTreeNode) : null;

        final StringBuilder sb = new StringBuilder();
        sb.append("*** WARNING: This property is deprecated! *** ");
        sb.append("Node has id ");
        sb.append(sharedTreeNode.getNodeNumber());
        if (sharedTreeNode.getColName() == null || sharedTreeNode.isLeaf()) {
            sb.append(" and is a terminal node. ");
        } else {
            sb.append(" and splits on column '");
            sb.append(sharedTreeNode.getColName());
            sb.append("'. ");
        }
        fillNodeSplitTowardsChildren(sb, sharedTreeNode);

        final float parentSplit = parent.getSplitValue();
        if (!Float.isNaN(parentSplit)) {
            sb.append(" Parent node split threshold is ");
            sb.append(parentSplit);
            sb.append(". Prediction: ");
            sb.append(sharedTreeNode.getPredValue());
            sb.append(".");
        } else if (parentIsBitset) {
            sb.append(" Parent node split on column [");
            sb.append(parent.getColName());
            if (nodeLevels != null) {
                sb.append("]. Inherited categorical levels from parent split: ");
                final String[] parentDomain = parent.getDomainValues();
                for (int k = 0; k < nodeLevels.length; k++) {
                    if (k > 0) {
                        sb.append(",");
                    }
                    sb.append(parentDomain[nodeLevels[k]]);
                }
            } else {
                sb.append("]. No categoricals levels inherited from parent.");
            }
        } else {
            sb.append("Split value is NA.");
        }

        strArr[i] = sb.toString();
        strArr2[i] = sharedTreeNode.getColName();
        strArr3[i] = getNaDirection(sharedTreeNode);
        iArr[i] = nodeLevels;
        fArr2[i] = sharedTreeNode.getPredValue();
        fArr[i] = sharedTreeNode.getSplitValue();
        strArr4[i] = sharedTreeNode.getDomainValues();
    }

    /**
     * Appends a description of how {@code sharedTreeNode} routes rows to its children.
     *
     * <p>For a NaN split value the categorical description is appended when the node is a bitset
     * split, and nothing otherwise. For a numeric split the threshold is described separately for
     * each child that exists. The right-hand part is guarded by the right child itself (it used to
     * be guarded by the left child, which threw a {@code NullPointerException} when only the left
     * child existed and dropped the right side when only the right child existed).
     *
     * @param sb             builder receiving the text
     * @param sharedTreeNode node being described
     */
    private static void fillNodeSplitTowardsChildren(StringBuilder sb, SharedTreeNode sharedTreeNode) {
        final float splitValue = sharedTreeNode.getSplitValue();
        if (Float.isNaN(splitValue)) {
            if (sharedTreeNode.isBitset()) {
                fillNodeCategoricalSplitDescription(sb, sharedTreeNode);
            }
            return;
        }

        final SharedTreeNode leftChild = sharedTreeNode.getLeftChild();
        final SharedTreeNode rightChild = sharedTreeNode.getRightChild();
        sb.append("Split threshold is ");
        if (leftChild != null) {
            sb.append(" < ");
            sb.append(splitValue);
            sb.append(" to the left node (");
            sb.append(leftChild.getNodeNumber());
            sb.append(")");
        }
        if (rightChild != null) {
            // Separator only when the left part was actually written.
            if (leftChild != null) {
                sb.append(", ");
            }
            sb.append(" >= ");
            sb.append(splitValue);
            sb.append(" to the right node (");
            sb.append(rightChild.getNodeNumber());
            sb.append(")");
        }
        sb.append(".");
    }

    /**
     * Returns the indices of the bits set in the node's inclusive-levels bit set, in ascending
     * order.
     *
     * @param sharedTreeNode node to inspect; may be {@code null}
     * @return the set bit indices, or {@code null} when the node is {@code null}, has no bit set,
     *         or its bit set is empty
     */
    private static int[] extractNodeLevels(SharedTreeNode sharedTreeNode) {
        if (sharedTreeNode == null) {
            return null;
        }
        final BitSet inclusiveLevels = sharedTreeNode.getInclusiveLevels();
        // A missing bit set is treated like an empty one: there are no levels to report.
        if (inclusiveLevels == null || inclusiveLevels.isEmpty()) {
            return null;
        }

        final int[] levels = MemoryManager.malloc4(inclusiveLevels.cardinality());
        int next = 0;
        for (int bit = inclusiveLevels.nextSetBit(0); bit >= 0; bit = inclusiveLevels.nextSetBit(bit + 1)) {
            levels[next++] = bit;
        }
        return levels;
    }

    /**
     * Appends, for each existing child of a categorical split, the child's node number and the
     * names of the categorical levels it inherits, followed by a closing {@code ". "}.
     *
     * <p>Levels are extracted only for children that exist; previously they were extracted before
     * the null checks, so a missing child caused a {@code NullPointerException}.
     *
     * @param sb             builder receiving the text
     * @param sharedTreeNode split node whose children are described
     */
    private static void fillNodeCategoricalSplitDescription(StringBuilder sb, SharedTreeNode sharedTreeNode) {
        final String[] domain = sharedTreeNode.getDomainValues();
        appendInheritedLevels(sb, " Left child node (", sharedTreeNode.getLeftChild(), domain);
        appendInheritedLevels(sb, ". Right child node (", sharedTreeNode.getRightChild(), domain);
        sb.append(". ");
    }

    /** Appends one child's node number and its comma-separated level names; no-op for a null child. */
    private static void appendInheritedLevels(StringBuilder sb, String prefix, SharedTreeNode child, String[] domain) {
        if (child == null) {
            return;
        }
        sb.append(prefix);
        sb.append(child.getNodeNumber());
        sb.append(") inherits categorical levels: ");

        final int[] childLevels = extractNodeLevels(child);
        if (childLevels == null) {
            return;
        }
        for (int k = 0; k < childLevels.length; k++) {
            if (k > 0) {
                sb.append(",");
            }
            sb.append(domain[childLevels[k]]);
        }
    }

    /**
     * Reports which child of {@code sharedTreeNode} is flagged as including NA values.
     *
     * @param sharedTreeNode node to inspect
     * @return {@code "LEFT"} when the left child is flagged, {@code "RIGHT"} when only the right
     *         child is flagged, {@code null} when neither is
     * @throws AssertionError when assertions are enabled and both children are flagged
     */
    private static String getNaDirection(SharedTreeNode sharedTreeNode) {
        final SharedTreeNode left = sharedTreeNode.getLeftChild();
        final SharedTreeNode right = sharedTreeNode.getRightChild();
        final boolean naGoesLeft = left != null && left.isInclusiveNa();
        final boolean naGoesRight = right != null && right.isInclusiveNa();

        // At most one side may be flagged.
        if (!$assertionsDisabled && naGoesLeft && naGoesRight) {
            throw new AssertionError();
        }
        return naGoesLeft ? "LEFT" : (naGoesRight ? "RIGHT" : null);
    }

    // Initializes the assertion switch declared (and marked synthetic) at the top of the class:
    // the flag is the negation of desiredAssertionStatus(), and every
    // "!$assertionsDisabled && ..." check in this class throws AssertionError only while it is false.
    static {
        $assertionsDisabled = !TreeHandler.class.desiredAssertionStatus();
    }
}
