package hex.genmodel.algos.xgboost;

import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.tree.ConvertTreeOptions;
import hex.genmodel.algos.tree.PlattScalingMojoHelper;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.algos.tree.TreeBackedMojoModel;
import java.io.Closeable;
import java.util.Arrays;

/* loaded from: input_file:www/3/h2o-genmodel.jar:hex/genmodel/algos/xgboost/XGBoostMojoModel.class */
/**
 * Base class of XGBoost MOJO models.
 *
 * <p>Holds the model parameters exposed as public fields, converts raw booster output into the
 * prediction array layout ({@link #toPreds}) and converts the trees of a tree-based booster into
 * a {@link SharedTreeGraph} ({@link #computeGraph}).
 */
public abstract class XGBoostMojoModel extends MojoModel implements TreeBackedMojoModel, SharedTreeGraphConverter, PlattScalingMojoHelper.MojoModelWithCalibration, Closeable {
    /** Separator between the tokens of a single feature-map line. */
    private static final String SPACE = " ";
    public String _boosterType;
    /** Value reported by {@link #getNTreeGroups()}. */
    public int _ntrees;
    public int _nums;
    public int _cats;
    public int[] _catOffsets;
    public boolean _useAllFactorLevels;
    public boolean _sparse;
    /**
     * Feature map as text, one feature per line. Each line is split on a single space; the
     * second token is used as the column name and the third token as the feature type.
     */
    public String _featureMap;
    /** When {@code true} the model requires an offset for scoring (see {@link #requiresOffset()}). */
    public boolean _hasOffset;
    /** Value returned by {@link #getCalibGlmBeta()}. */
    protected double[] _calib_glm_beta;

    /* loaded from: input_file:www/3/h2o-genmodel.jar:hex/genmodel/algos/xgboost/XGBoostMojoModel$ObjectiveType.class */
    /** Objective functions known to this class, each paired with its XGBoost identifier string. */
    public enum ObjectiveType {
        BINARY_LOGISTIC("binary:logistic"),
        REG_GAMMA("reg:gamma"),
        REG_TWEEDIE("reg:tweedie"),
        COUNT_POISSON("count:poisson"),
        REG_SQUAREDERROR("reg:squarederror"),
        REG_LINEAR("reg:linear"),
        MULTI_SOFTPROB("multi:softprob"),
        RANK_PAIRWISE("rank:pairwise");

        private final String _id;

        ObjectiveType(String id) {
            this._id = id;
        }

        /** @return the XGBoost identifier string of this objective */
        public String getId() {
            return this._id;
        }

        /**
         * Looks up an objective by its XGBoost identifier string.
         *
         * @param str identifier such as {@code "binary:logistic"}
         * @return the matching constant, or {@code null} when no constant has that identifier
         */
        public static ObjectiveType fromXGBoost(String str) {
            for (ObjectiveType candidate : values()) {
                if (candidate.getId().equals(str)) {
                    return candidate;
                }
            }
            return null;
        }
    }

    public XGBoostMojoModel(String[] strArr, String[][] strArr2, String str) {
        super(strArr, strArr2, str);
    }

    /**
     * Hook with an empty default implementation.
     * NOTE(review): presumably invoked once the model fields have been read - confirm with the reader.
     */
    public void postReadInit() {
    }

    @Override // hex.genmodel.GenModel
    public boolean requiresOffset() {
        return this._hasOffset;
    }

    /**
     * Scores a row without an offset by delegating to the offset-aware {@code score0} with an
     * offset of {@code 0.0}.
     *
     * @throws IllegalStateException when the model requires an offset ({@link #_hasOffset})
     */
    @Override // hex.genmodel.GenModel
    public final double[] score0(double[] dArr, double[] dArr2) {
        if (this._hasOffset) {
            throw new IllegalStateException("Model was trained with offset, use score0 with offset");
        }
        return score0(dArr, 0.0d, dArr2);
    }

    /**
     * Copies raw booster output into the prediction array.
     *
     * <ul>
     *   <li>{@code nclasses > 2}: {@code out[k]} is stored in {@code preds[k + 1]} and
     *       {@code preds[0]} is set from {@code GenModel.getPrediction}.</li>
     *   <li>{@code nclasses == 2}: {@code preds[1] = 1 - out[0]}, {@code preds[2] = out[0]} and
     *       {@code preds[0]} is set from {@code GenModel.getPrediction}.</li>
     *   <li>otherwise: {@code preds[0] = out[0]}.</li>
     * </ul>
     *
     * @param row               forwarded to {@code GenModel.getPrediction} as its third argument
     * @param out               raw booster output
     * @param preds             array that receives the result; must be large enough for the case at hand
     * @param nclasses          selects which of the three layouts above is produced
     * @param priorClassDistrib forwarded to {@code GenModel.getPrediction} as its second argument
     * @param threshold         forwarded to {@code GenModel.getPrediction} as its fourth argument
     * @return {@code preds}
     */
    public static double[] toPreds(double[] row, float[] out, double[] preds, int nclasses, double[] priorClassDistrib, double threshold) {
        if (nclasses > 2) {
            for (int k = 0; k < out.length; k++) {
                preds[1 + k] = out[k];
            }
            preds[0] = GenModel.getPrediction(preds, priorClassDistrib, row, threshold);
        } else if (nclasses == 2) {
            // The subtraction is deliberately done in float precision, as before.
            preds[1] = 1.0f - out[0];
            preds[2] = out[0];
            preds[0] = GenModel.getPrediction(preds, priorClassDistrib, row, threshold);
        } else {
            preds[0] = out[0];
        }
        return preds;
    }

    @Override // hex.genmodel.algos.tree.TreeBackedMojoModel
    public int getNTreeGroups() {
        return this._ntrees;
    }

    /** @return {@code _nclasses} when it is greater than 2, otherwise 1 */
    @Override // hex.genmodel.algos.tree.TreeBackedMojoModel
    public int getNTreesPerGroup() {
        return this._nclasses > 2 ? this._nclasses : 1;
    }

    @Override // hex.genmodel.algos.tree.PlattScalingMojoHelper.MojoModelWithCalibration
    public double[] getCalibGlmBeta() {
        return this._calib_glm_beta;
    }

    @Override // hex.genmodel.GenModel
    public boolean calibrateClassProbabilities(double[] dArr) {
        return PlattScalingMojoHelper.calibrateClassProbabilities(this, dArr);
    }

    /**
     * Recursively copies the XGBoost node at {@code nodeIndex}, and everything below it, into
     * {@code subgraph}.
     *
     * @param nodes         all nodes of the tree being converted
     * @param node          graph node that receives the data of {@code nodes[nodeIndex]}
     * @param nodeIndex     index into {@code nodes} of the node to copy
     * @param subgraph      subgraph in which child nodes are created
     * @param oneHotEncoded per-feature flags as produced by {@link #markOneHotEncodedCategoricals}
     * @param inclusiveNa   value passed to {@code setInclusiveNa} for this node
     * @param featureMap    feature-map lines; the second token of a line is used as the column name
     */
    protected void constructSubgraph(RegTreeNode[] nodes, SharedTreeNode node, int nodeIndex, SharedTreeSubgraph subgraph, boolean[] oneHotEncoded, boolean inclusiveNa, String[] featureMap) {
        final RegTreeNode xgbNode = nodes[nodeIndex];
        final int splitIndex = xgbNode.getSplitIndex();

        // Features flagged as one-hot encoded get a fixed split value of 1.
        if (oneHotEncoded[splitIndex]) {
            node.setSplitValue(1.0f);
        } else {
            node.setSplitValue(xgbNode.getSplitCondition());
        }
        node.setPredValue(xgbNode.getLeafValue());
        node.setCol(splitIndex, featureMap[splitIndex].split(SPACE)[1]);
        node.setInclusiveNa(inclusiveNa);
        node.setNodeNumber(nodeIndex);

        // A child index of -1 means there is no child on that side.
        if (xgbNode.getLeftChildIndex() != -1) {
            constructSubgraph(nodes, subgraph.makeLeftChildNode(node), xgbNode.getLeftChildIndex(), subgraph, oneHotEncoded, xgbNode.default_left(), featureMap);
        }
        if (xgbNode.getRightChildIndex() != -1) {
            constructSubgraph(nodes, subgraph.makeRightChildNode(node), xgbNode.getRightChildIndex(), subgraph, oneHotEncoded, !xgbNode.default_left(), featureMap);
        }
    }

    /**
     * Splits {@link #_featureMap} into lines and returns the lines that precede the first blank
     * (empty or whitespace-only) line, or all lines when there is no blank line.
     *
     * <p>The blank line itself is excluded: it has no tokens, so the consumers of the result,
     * which index into the tokens of every line they visit, could not parse it.
     */
    private String[] constructFeatureMap() {
        final String[] lines = this._featureMap.split("\n");
        int end = lines.length;
        for (int i = 0; i < lines.length; i++) {
            if (lines[i].trim().isEmpty()) {
                end = i;
                break;
            }
        }
        return Arrays.copyOfRange(lines, 0, end);
    }

    /**
     * Flags the leading run of features whose type token (third token of the feature-map line)
     * is {@code "i"}.
     *
     * @param featureMap feature-map lines, each consisting of at least three space-separated tokens
     * @return array of the same length as {@code featureMap}; element {@code k} is {@code true}
     *         iff every line up to and including {@code k} has type {@code "i"}
     */
    protected boolean[] markOneHotEncodedCategoricals(String[] featureMap) {
        // Index of the first line whose type is not "i"; stays at length when all lines are "i".
        int firstOther = featureMap.length;
        for (int i = 0; i < featureMap.length; i++) {
            final String[] tokens = featureMap[i].split(SPACE);
            // A line has the form "<id> <name> <type>"; only index 2 is read below.
            assert tokens.length >= 3;
            if (!"i".equals(tokens[2])) {
                firstOther = i;
                break;
            }
        }
        final boolean[] oneHotEncoded = new boolean[featureMap.length];
        Arrays.fill(oneHotEncoded, 0, firstOther, true);
        return oneHotEncoded;
    }

    /**
     * Converts the tree with index {@code i} of every tree group of the booster into a graph with
     * one subgraph per group, named {@code "Class <group index>"}.
     *
     * @param gradBooster booster to convert; must be a {@link GBTree}
     * @param i           index of the tree within each group
     * @return the resulting graph
     * @throws IllegalArgumentException when the booster is not a {@link GBTree} or when
     *         {@code i} is not a valid tree index for one of the groups
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public SharedTreeGraph computeGraph(GradBooster gradBooster, int i) {
        if (!(gradBooster instanceof GBTree)) {
            throw new IllegalArgumentException(String.format("Given XGBoost model is not backed by a tree-based booster. Booster class is %s", gradBooster.getClass().getCanonicalName()));
        }
        final RegTree[][] groupedTrees = ((GBTree) gradBooster).getGroupedTrees();
        final SharedTreeGraph graph = new SharedTreeGraph();
        for (int group = 0; group < groupedTrees.length; group++) {
            final RegTree[] trees = groupedTrees[group];
            if (i < 0 || i >= trees.length) {
                throw new IllegalArgumentException(String.format("There is no such tree number for given class. Total number of trees is %d.", Integer.valueOf(trees.length)));
            }
            final RegTreeNode[] nodes = trees[i].getNodes();
            assert nodes.length >= 1;
            final SharedTreeSubgraph subgraph = graph.makeSubgraph(String.format("Class %d", Integer.valueOf(group)));
            final String[] featureMap = constructFeatureMap();
            constructSubgraph(nodes, subgraph.makeRootNode(), 0, subgraph, markOneHotEncodedCategoricals(featureMap), true, featureMap);
        }
        return graph;
    }

    /**
     * Delegates to the two-argument {@code convert}; {@code convertTreeOptions} is ignored.
     */
    public SharedTreeGraph convert(int i, String str, ConvertTreeOptions convertTreeOptions) {
        return convert(i, str);
    }
}
