package hex.genmodel.algos.tree;

import ai.h2o.algos.tree.INode;
import ai.h2o.algos.tree.INodeStat;
import hex.genmodel.tools.PrintMojo;
import hex.genmodel.utils.GenmodelBitSet;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Iterator;
import java.util.Objects;

/* loaded from: input_file:www/3/h2o-genmodel.jar:hex/genmodel/algos/tree/SharedTreeNode.class */
/**
 * One node of a binary decision tree.
 *
 * <p>A node knows its parent and its (optional) left and right children, the column it splits on,
 * and either a numeric split value or a bitset over the column's domain values. Besides plain
 * accessors the class can route a row of doubles to one of its children ({@link #next(double[])})
 * and render itself and its outgoing edges as Graphviz DOT text.
 *
 * <p>Instances are mutable and carry no synchronization.
 */
public class SharedTreeNode implements INode<double[]>, INodeStat {
    /** Identifier returned by {@link #getLeftChildIndex()} / {@link #getRightChildIndex()} of the parent. */
    final int internalId;
    /** Parent node, or {@code null} for a node created without a parent. */
    final SharedTreeNode parent;
    /** Number used as the {@code SG_<n>} prefix of the DOT node name. */
    final int subgraphNumber;
    /** Number shown in printed output ("Node n", "N n") and used in the DOT node name. */
    int nodeNumber;
    float weight;
    final int depth;
    /** Index into the row passed to {@link #next(double[])}. */
    int colId;
    String colName;
    /** Direction taken by {@link #next(double[])} for NaN or out-of-range values: {@code true} means left. */
    boolean leftward;
    /** When {@code true}, every value that is not NaN / out of range is routed to the left child. */
    boolean naVsRest;
    /** Domain (level names) of the split column; non-null marks this node as a "bitset" node. */
    String[] domainValues;
    GenmodelBitSet bs;
    SharedTreeNode leftChild;
    public SharedTreeNode rightChild;
    /** Set by the parent in {@link #setLeftChild}/{@link #setRightChild}; rendered as an "[NA]" edge label. */
    private boolean inclusiveNa;
    /** Set by the parent (bitset nodes only): domain levels listed on the edge leading to this node. */
    private BitSet inclusiveLevels;
    float splitValue = Float.NaN;
    float predValue = Float.NaN;
    float squaredError = Float.NaN;

    /**
     * Creates a node without children.
     *
     * @param i              internal id of the node
     * @param sharedTreeNode parent node, may be {@code null}
     * @param i2             subgraph number
     * @param i3             depth of the node
     */
    public SharedTreeNode(int i, SharedTreeNode sharedTreeNode, int i2, int i3) {
        this.internalId = i;
        this.parent = sharedTreeNode;
        this.subgraphNumber = i2;
        this.depth = i3;
    }

    /** @return the depth given at construction */
    public int getDepth() {
        return this.depth;
    }

    /** @return the node number */
    public int getNodeNumber() {
        return this.nodeNumber;
    }

    /** @return the weight of this node */
    @Override // ai.h2o.algos.tree.INodeStat
    public float getWeight() {
        return this.weight;
    }

    /** @param i the new node number */
    public void setNodeNumber(int i) {
        this.nodeNumber = i;
    }

    /** @param f the new weight */
    public void setWeight(float f) {
        this.weight = f;
    }

    /**
     * Sets the split column.
     *
     * @param i   column index
     * @param str column name
     */
    public void setCol(int i, String str) {
        this.colId = i;
        this.colName = str;
    }

    /** @return the index of the split column */
    public int getColId() {
        return this.colId;
    }

    /** @param z the new value of the {@code leftward} flag */
    public void setLeftward(boolean z) {
        this.leftward = z;
    }

    /** @param z the new value of the {@code naVsRest} flag */
    public void setNaVsRest(boolean z) {
        this.naVsRest = z;
    }

    /** @param f the numeric split threshold */
    public void setSplitValue(float f) {
        this.splitValue = f;
    }

    /** @param str the name of the split column */
    public void setColName(String str) {
        this.colName = str;
    }

    /**
     * Stores the domain values and the bitset describing a bitset split.
     *
     * @param strArr         domain values of the split column; must not be {@code null}
     *                       (checked only when assertions are enabled)
     * @param genmodelBitSet the bitset consulted by {@link #next(double[])}
     */
    public void setBitset(String[] strArr, GenmodelBitSet genmodelBitSet) {
        assert strArr != null;
        this.domainValues = strArr;
        this.bs = genmodelBitSet;
    }

    /** @param f the prediction value */
    public void setPredValue(float f) {
        this.predValue = f;
    }

    /** @param f the squared error */
    public void setSquaredError(float f) {
        this.squaredError = f;
    }

    /**
     * Walks towards the root looking for an ancestor that splits on column {@code i}.
     *
     * @return {@code true} when this node has no parent; this node's own {@code inclusiveNa} when
     *         the direct parent splits on {@code i}; otherwise whatever the parent reports
     */
    private boolean findInclusiveNa(int i) {
        if (this.parent == null) {
            return true;
        }
        if (this.parent.getColId() == i) {
            return this.inclusiveNa;
        }
        return this.parent.findInclusiveNa(i);
    }

    /** Combines the inherited NA flag for this node's column with the direction flag {@code z}. */
    private boolean calculateChildInclusiveNa(boolean z) {
        return findInclusiveNa(this.colId) && z;
    }

    /**
     * Counterpart of {@link #findInclusiveNa(int)} for the level set.
     *
     * @return {@code null} when this node has no parent; this node's own {@code inclusiveLevels}
     *         when the direct parent splits on {@code i}; otherwise whatever the parent reports
     */
    private BitSet findInclusiveLevels(int i) {
        if (this.parent == null) {
            return null;
        }
        if (this.parent.getColId() == i) {
            return this.inclusiveLevels;
        }
        return this.parent.findInclusiveLevels(i);
    }

    /** A {@code null} set places no restriction, so every level counts as included. */
    private boolean calculateIncludeThisLevel(BitSet bitSet, int i) {
        return bitSet == null || bitSet.get(i);
    }

    /**
     * Builds the level set handed to a child of this node.
     *
     * @param z  when {@code true} the child receives exactly the inherited levels
     * @param z2 when {@code true} the child receives no levels at all (takes precedence over {@code z})
     * @param z3 {@code false} when called for the left child, {@code true} for the right child
     * @return a fresh bit set with one candidate bit per entry of {@code domainValues}
     */
    private BitSet calculateChildInclusiveLevels(boolean z, boolean z2, boolean z3) {
        final BitSet inherited = findInclusiveLevels(this.colId);
        final BitSet result = new BitSet();
        for (int level = 0; level < this.domainValues.length; level++) {
            final boolean include;
            if (z2) {
                include = false;
            } else if (z) {
                include = calculateIncludeThisLevel(inherited, level);
            } else if (!Float.isNaN(this.splitValue)) {
                // Numeric threshold over level indices: z3 selects levels above the threshold,
                // !z3 selects the remaining ones.
                include = (this.splitValue < ((float) level)) ^ (!z3);
            } else if (this.bs.isInRange(level) && this.bs.contains(level) == z3) {
                include = calculateIncludeThisLevel(inherited, level);
            } else {
                include = false;
            }
            if (include) {
                result.set(level);
            }
        }
        return result;
    }

    /**
     * Attaches the left child and computes its NA flag and, for bitset nodes, its level set.
     *
     * @param sharedTreeNode the new left child; must not be {@code null}
     */
    public void setLeftChild(SharedTreeNode sharedTreeNode) {
        this.leftChild = sharedTreeNode;
        sharedTreeNode.setInclusiveNa(calculateChildInclusiveNa(this.leftward));
        if (isBitset()) {
            sharedTreeNode.setInclusiveLevels(calculateChildInclusiveLevels(this.naVsRest, false, false));
        }
    }

    /**
     * Attaches the right child and computes its NA flag and, for bitset nodes, its level set.
     *
     * @param sharedTreeNode the new right child; must not be {@code null}
     */
    public void setRightChild(SharedTreeNode sharedTreeNode) {
        this.rightChild = sharedTreeNode;
        sharedTreeNode.setInclusiveNa(calculateChildInclusiveNa(!this.leftward));
        if (isBitset()) {
            sharedTreeNode.setInclusiveLevels(calculateChildInclusiveLevels(false, this.naVsRest, true));
        }
    }

    /** @param z the new value of the {@code inclusiveNa} flag */
    public void setInclusiveNa(boolean z) {
        this.inclusiveNa = z;
    }

    /** @return the {@code inclusiveNa} flag */
    public boolean getInclusiveNa() {
        return this.inclusiveNa;
    }

    private void setInclusiveLevels(BitSet bitSet) {
        this.inclusiveLevels = bitSet;
    }

    /** @return the level set assigned by the parent, or {@code null} if none was assigned */
    public BitSet getInclusiveLevels() {
        return this.inclusiveLevels;
    }

    /** @return {@code "Node "} followed by the node number */
    public String getName() {
        return "Node " + this.nodeNumber;
    }

    /** Prints this node to {@code System.out} without a description. */
    public void print() {
        print(System.out, null);
    }

    /**
     * Prints a multi-line, human-readable dump of this node.
     *
     * @param printStream destination
     * @param str         optional description appended in parentheses to the header; may be {@code null}
     */
    public void print(PrintStream printStream, String str) {
        final String description = str != null ? " (" + str + ")" : "";
        final String column = this.colName != null ? this.colName : "";
        final String left = this.leftChild != null ? this.leftChild.getName() : "";
        final String right = this.rightChild != null ? this.rightChild.getName() : "";

        printStream.println("        Node " + this.nodeNumber + description);
        printStream.println("            weight:      " + this.weight);
        printStream.println("            depth:       " + this.depth);
        printStream.println("            colId:       " + this.colId);
        printStream.println("            colName:     " + column);
        printStream.println("            leftward:    " + this.leftward);
        printStream.println("            naVsRest:    " + this.naVsRest);
        printStream.println("            splitVal:    " + this.splitValue);
        printStream.println("            isBitset:    " + isBitset());
        printStream.println("            predValue:   " + this.predValue);
        printStream.println("            squaredErr:  " + this.squaredError);
        printStream.println("            leftChild:   " + left);
        printStream.println("            rightChild:  " + right);
    }

    /** Recursively prints every parent-to-child edge below this node to {@code System.out}. */
    public void printEdges() {
        if (this.leftChild != null) {
            System.out.println("        " + getName() + " ---left---> " + this.leftChild.getName());
            this.leftChild.printEdges();
        }
        if (this.rightChild != null) {
            System.out.println("        " + getName() + " ---right--> " + this.rightChild.getName());
            this.rightChild.printEdges();
        }
    }

    /** Name of this node in DOT output, built from the subgraph number and the node number. */
    private String getDotName() {
        return "SG_" + this.subgraphNumber + "_Node_" + this.nodeNumber;
    }

    /** @return {@code true} when domain values have been set on this node */
    public boolean isBitset() {
        return this.domainValues != null;
    }

    /**
     * Escapes double quotes so the text can be placed inside a double-quoted DOT string.
     *
     * @param str text to escape; must not be {@code null}
     * @return {@code str} with every {@code "} preceded by a backslash
     */
    public static String escapeQuotes(String str) {
        return str.replace("\"", "\\\"");
    }

    /**
     * Emits the DOT statement for this node.
     *
     * @param printStream      destination
     * @param z                when {@code true}, extra details (node number, prediction, squared
     *                         error, weight, child numbers) are appended to the label
     * @param printTreeOptions font size, rounding and "internal" rendering options
     */
    private void printDotNode(PrintStream printStream, boolean z, PrintMojo.PrintTreeOptions printTreeOptions) {
        printStream.print("\"" + getDotName() + "\"");
        printStream.print(" [");
        if (isLeaf()) {
            printStream.print("fontsize=" + printTreeOptions._fontSize + ", label=\"");
            printStream.print(printTreeOptions._setDecimalPlace ? printTreeOptions.roundNPlace(this.predValue) : this.predValue);
        } else if (isBitset() && (Float.isNaN(this.splitValue) || !printTreeOptions._internal)) {
            // Bitset split: only the column name goes into the label.
            printStream.print("shape=box, fontsize=" + printTreeOptions._fontSize + ", label=\"");
            printStream.print(escapeQuotes(this.colName));
        } else {
            assert !Float.isNaN(this.splitValue);
            float threshold = printTreeOptions._setDecimalPlace ? printTreeOptions.roundNPlace(this.splitValue) : this.splitValue;
            printStream.print("shape=box, fontsize=" + printTreeOptions._fontSize + ", label=\"");
            printStream.print(escapeQuotes(this.colName) + " < " + threshold);
        }
        if (z) {
            printStream.print("\\n\\nN" + getNodeNumber() + "\\n");
            if (!isLeaf() && !Float.isNaN(this.predValue)) {
                printStream.print("\\nPred: " + (printTreeOptions._setDecimalPlace ? printTreeOptions.roundNPlace(this.predValue) : this.predValue));
            }
            if (!Float.isNaN(this.squaredError)) {
                printStream.print("\\nSE: " + this.squaredError);
            }
            printStream.print("\\nW: " + getWeight());
            if (this.naVsRest) {
                printStream.print("\\nnasVsRest");
            }
            if (this.leftChild != null) {
                printStream.print("\\nL: N" + this.leftChild.getNodeNumber());
            }
            if (this.rightChild != null) {
                printStream.print("\\nR: N" + this.rightChild.getNodeNumber());
            }
        }
        printStream.print("\"]");
        printStream.println("");
    }

    /**
     * Emits the DOT statements of all nodes in this subtree whose depth equals {@code i}.
     *
     * @param printStream      destination
     * @param i                depth to print; expected to be at least this node's depth
     * @param z                forwarded to the node printer as the "details" switch
     * @param printTreeOptions rendering options
     */
    public void printDotNodesAtLevel(PrintStream printStream, int i, boolean z, PrintMojo.PrintTreeOptions printTreeOptions) {
        if (getDepth() == i) {
            printDotNode(printStream, z, printTreeOptions);
            return;
        }
        assert getDepth() < i;
        if (this.leftChild != null) {
            this.leftChild.printDotNodesAtLevel(printStream, i, z, printTreeOptions);
        }
        if (this.rightChild != null) {
            this.rightChild.printDotNodesAtLevel(printStream, i, z, printTreeOptions);
        }
    }

    /**
     * Writes the attribute list of one edge (everything after the opening bracket).
     *
     * @param printStream      destination
     * @param i                maximum number of level names to spell out; larger sets are
     *                         summarized as a count followed by " levels"
     * @param arrayList        label lines collected so far; level names are appended to it
     * @param sharedTreeNode   the child the edge points to
     * @param f                divisor used to scale the child's weight into a pen width
     * @param z                when {@code true} a {@code penwidth} attribute is emitted
     * @param printTreeOptions rendering options
     */
    private void printDotEdgesCommon(PrintStream printStream, int i, ArrayList<String> arrayList, SharedTreeNode sharedTreeNode, float f, boolean z, PrintMojo.PrintTreeOptions printTreeOptions) {
        if (isBitset() || (!Float.isNaN(this.splitValue) && printTreeOptions._internal)) {
            BitSet levels = sharedTreeNode.getInclusiveLevels();
            // Children only receive a level set from a bitset parent, so a plain numeric split
            // rendered with the internal option has none; there is nothing to list in that case.
            if (levels != null) {
                int total = levels.cardinality();
                if (total > 0 && total <= i) {
                    for (int level = levels.nextSetBit(0); level >= 0; level = levels.nextSetBit(level + 1)) {
                        arrayList.add(this.domainValues[level]);
                    }
                } else {
                    arrayList.add(total + " levels");
                }
            }
        }
        if (z) {
            try {
                int penWidth = Math.round((sharedTreeNode.getWeight() / f) * 14.0f) + 1;
                printStream.print("penwidth=");
                printStream.print(penWidth);
                printStream.print(",");
            } catch (Exception ignored) {
                // Kept from the original: a failure while emitting the pen width is ignored so
                // the rest of the edge is still written.
            }
        }
        printStream.print("fontsize=" + printTreeOptions._fontSize + ", label=\"");
        for (String labelLine : arrayList) {
            printStream.print(escapeQuotes(labelLine) + "\\n");
        }
        printStream.print("\"");
        printStream.println("]");
    }

    /**
     * Emits the DOT edges from this node to its children (not recursive).
     *
     * @param printStream      destination
     * @param i                maximum number of level names to spell out per edge
     * @param f                divisor used to scale child weights into pen widths
     * @param z                when {@code true} a {@code penwidth} attribute is emitted
     * @param printTreeOptions rendering options
     */
    public void printDotEdges(PrintStream printStream, int i, float f, boolean z, PrintMojo.PrintTreeOptions printTreeOptions) {
        // A node is expected to have either both children or none.
        assert (this.leftChild == null) == (this.rightChild == null);

        final boolean showComparison = !isBitset() || (!Float.isNaN(this.splitValue) && printTreeOptions._internal);

        if (this.leftChild != null) {
            printStream.print("\"" + getDotName() + "\" -> \"" + this.leftChild.getDotName() + "\" [");
            ArrayList<String> leftLabels = new ArrayList<>();
            if (this.leftChild.getInclusiveNa()) {
                leftLabels.add("[NA]");
            }
            if (this.naVsRest) {
                leftLabels.add("[Not NA]");
            } else if (showComparison) {
                leftLabels.add("<");
            }
            printDotEdgesCommon(printStream, i, leftLabels, this.leftChild, f, z, printTreeOptions);
        }
        if (this.rightChild != null) {
            printStream.print("\"" + getDotName() + "\" -> \"" + this.rightChild.getDotName() + "\" [");
            ArrayList<String> rightLabels = new ArrayList<>();
            if (this.rightChild.getInclusiveNa()) {
                rightLabels.add("[NA]");
            }
            if (!this.naVsRest && showComparison) {
                rightLabels.add(">=");
            }
            printDotEdgesCommon(printStream, i, rightLabels, this.rightChild, f, z, printTreeOptions);
        }
    }

    /** @return the parent node, or {@code null} */
    public SharedTreeNode getParent() {
        return this.parent;
    }

    /** @return the subgraph number given at construction */
    public int getSubgraphNumber() {
        return this.subgraphNumber;
    }

    /** @return the name of the split column, possibly {@code null} */
    public String getColName() {
        return this.colName;
    }

    /** @return the {@code leftward} flag */
    public boolean isLeftward() {
        return this.leftward;
    }

    /** @return the {@code naVsRest} flag */
    public boolean isNaVsRest() {
        return this.naVsRest;
    }

    /** @return the numeric split threshold, NaN if never set */
    public float getSplitValue() {
        return this.splitValue;
    }

    /** @return the domain values array itself (not a copy), possibly {@code null} */
    public String[] getDomainValues() {
        return this.domainValues;
    }

    /** @param strArr the domain values; a non-null array makes {@link #isBitset()} return {@code true} */
    public void setDomainValues(String[] strArr) {
        this.domainValues = strArr;
    }

    /** @return the split bitset, possibly {@code null} */
    public GenmodelBitSet getBs() {
        return this.bs;
    }

    /** @return the prediction value, NaN if never set */
    public float getPredValue() {
        return this.predValue;
    }

    /** @return the squared error, NaN if never set */
    public float getSquaredError() {
        return this.squaredError;
    }

    /** @return the left child, or {@code null} */
    public SharedTreeNode getLeftChild() {
        return this.leftChild;
    }

    /** @return the right child, or {@code null} */
    public SharedTreeNode getRightChild() {
        return this.rightChild;
    }

    /** @return the {@code inclusiveNa} flag (same value as {@link #getInclusiveNa()}) */
    public boolean isInclusiveNa() {
        return this.inclusiveNa;
    }

    /**
     * Structural equality: compares the scalar fields, the column name, the domain values, the
     * inclusive level set and, recursively, both children. The parent, the internal id and the
     * split bitset take no part in the comparison.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        SharedTreeNode other = (SharedTreeNode) obj;
        return this.subgraphNumber == other.subgraphNumber
                && this.nodeNumber == other.nodeNumber
                && Float.compare(other.weight, this.weight) == 0
                && this.depth == other.depth
                && this.colId == other.colId
                && this.leftward == other.leftward
                && this.naVsRest == other.naVsRest
                && Float.compare(other.splitValue, this.splitValue) == 0
                && Float.compare(other.predValue, this.predValue) == 0
                && Float.compare(other.squaredError, this.squaredError) == 0
                && this.inclusiveNa == other.inclusiveNa
                && Objects.equals(this.colName, other.colName)
                && Arrays.equals(this.domainValues, other.domainValues)
                && Objects.equals(this.leftChild, other.leftChild)
                && Objects.equals(this.rightChild, other.rightChild)
                && Objects.equals(this.inclusiveLevels, other.inclusiveLevels);
    }

    /** Hash over the subgraph number and the node number, both of which {@link #equals} compares. */
    @Override
    public int hashCode() {
        return Objects.hash(this.subgraphNumber, this.nodeNumber);
    }

    /** @return {@code true} when the node has neither a left nor a right child */
    @Override // ai.h2o.algos.tree.INode
    public final boolean isLeaf() {
        return this.leftChild == null && this.rightChild == null;
    }

    /** @return the prediction value of this node */
    @Override // ai.h2o.algos.tree.INode
    public final float getLeafValue() {
        return this.predValue;
    }

    /** @return the index of the split column */
    @Override // ai.h2o.algos.tree.INode
    public final int getSplitIndex() {
        return this.colId;
    }

    /**
     * Chooses the child a row continues to.
     *
     * <p>The value at {@code colId} follows the {@code leftward} flag when it is NaN, when a bitset
     * is present and the value (truncated to int) is outside its range, or when domain values are
     * present and the truncated value is not a valid index into them. Otherwise the row goes left
     * if {@code naVsRest} is set, if the bitset does not contain the value, or (without a bitset)
     * if the value is below the split threshold.
     *
     * @param dArr the row; must have an element at index {@code colId}
     * @return the internal id of the chosen child, or -1 if that child is missing
     */
    @Override // ai.h2o.algos.tree.INode
    public final int next(double[] dArr) {
        final double value = dArr[this.colId];
        final int level = (int) value;

        final boolean unusable = Double.isNaN(value)
                || (this.bs != null && !this.bs.isInRange(level))
                || (this.domainValues != null && this.domainValues.length <= level);

        final boolean goLeft;
        if (unusable) {
            goLeft = this.leftward;
        } else {
            goLeft = this.naVsRest
                    || (this.bs != null ? !this.bs.contains(level) : value < ((double) this.splitValue));
        }
        return goLeft ? getLeftChildIndex() : getRightChildIndex();
    }

    /** @return the internal id of the left child, or -1 if there is none */
    @Override // ai.h2o.algos.tree.INode
    public final int getLeftChildIndex() {
        return this.leftChild != null ? this.leftChild.internalId : -1;
    }

    /** @return the internal id of the right child, or -1 if there is none */
    @Override // ai.h2o.algos.tree.INode
    public final int getRightChildIndex() {
        return this.rightChild != null ? this.rightChild.internalId : -1;
    }
}
