package hex.genmodel.algos.gbm;

import hex.genmodel.GenModel;
import hex.genmodel.PredictContributions;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeMojoModelWithContributions;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/* loaded from: input_file:hex/genmodel/algos/gbm/GbmMojoModel.class */
/**
 * MOJO scoring model for a gradient boosting machine.
 *
 * <p>Scoring sums the trees via the superclass ({@code scoreAllTrees}) and then
 * {@link #unifyPreds(double[], double, double[])} turns the raw sums into final
 * predictions according to {@link #_family} and {@link #_link_function}.
 */
public final class GbmMojoModel extends SharedTreeMojoModelWithContributions implements SharedTreeGraphConverter {
    /** Minimum magnitude of the raw score before it is inverted by the {@code inverse} link. */
    private static final double INVERSE_LINK_EPSILON = 1.0E-5d;
    /** Upper bound applied by {@link #exp(double)}. */
    private static final double MAX_EXP = 1.0E19d;
    /** Lower bound applied by {@link #log(double)}. */
    private static final double MIN_LOG = -19.0d;

    /** Distribution family; selects the post-processing branch in {@code unifyPreds} and {@code getOutputNames}. */
    public DistributionFamily _family;
    /** Link function whose inverse is applied to raw scores. */
    public LinkFunctionType _link_function;
    /** Constant added to the raw tree score before the inverse link is applied. */
    public double _init_f;

    /**
     * Creates the model; all arguments are forwarded unchanged to the superclass constructor.
     *
     * @param columns        first superclass constructor argument (presumably column names; confirm against GenModel)
     * @param domains        second superclass constructor argument (presumably per-column categorical domains)
     * @param responseColumn third superclass constructor argument (presumably the response column name)
     */
    public GbmMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /**
     * Wraps the given TreeSHAP predictor in a {@code SharedTreeContributionsPredictor} bound to this model.
     *
     * @param treeSHAPPredictor predictor to wrap
     * @return contributions predictor for this model
     */
    @Override
    protected PredictContributions getContributionsPredictor(TreeSHAPPredictor<double[]> treeSHAPPredictor) {
        return new SharedTreeMojoModelWithContributions.SharedTreeContributionsPredictor(this, treeSHAPPredictor);
    }

    /**
     * @return the value of {@link #_init_f}
     */
    @Override
    public double getInitF() {
        return _init_f;
    }

    /**
     * Scores one row: accumulates all trees into {@code preds}, then post-processes them
     * with {@link #unifyPreds(double[], double, double[])}.
     *
     * @param row    input row
     * @param offset value added to the raw score before the inverse link
     * @param preds  output array, filled in place
     * @return {@code preds}
     */
    @Override
    public final double[] score0(double[] row, double offset, double[] preds) {
        super.scoreAllTrees(row, preds);
        return unifyPreds(row, offset, preds);
    }

    /**
     * Converts raw tree sums held in {@code preds} into final predictions, in place.
     *
     * <ul>
     *   <li>bernoulli / quasibinomial / modified_huber: {@code preds[2]} becomes the inverse link of
     *       {@code preds[1] + _init_f + offset} and {@code preds[1]} its complement.</li>
     *   <li>multinomial: with exactly two classes {@code preds[1]} is shifted by
     *       {@code _init_f + offset} and {@code preds[2]} is set to its negation; then
     *       {@code GenModel.GBM_rescale} is applied.</li>
     *   <li>any other family: {@code preds[0]} becomes the inverse link of
     *       {@code preds[0] + _init_f + offset} and the method returns immediately.</li>
     * </ul>
     * For the first two cases the probabilities are optionally corrected when {@code _balanceClasses}
     * is set, and {@code preds[0]} receives the result of {@code GenModel.getPrediction}.
     *
     * @param row    input row, forwarded to {@code GenModel.getPrediction}
     * @param offset value added to the raw score
     * @param preds  raw scores on entry, final predictions on exit
     * @return {@code preds}
     */
    @Override
    public final double[] unifyPreds(double[] row, double offset, double[] preds) {
        final boolean binomialLike = _family == DistributionFamily.bernoulli
                || _family == DistributionFamily.quasibinomial
                || _family == DistributionFamily.modified_huber;

        if (binomialLike) {
            double rawScore = preds[1] + _init_f + offset;
            preds[2] = linkInv(_link_function, rawScore);
            preds[1] = 1.0d - preds[2];
        } else if (_family == DistributionFamily.multinomial) {
            if (_nclasses == 2) {
                preds[1] = preds[1] + _init_f + offset;
                preds[2] = -preds[1];
            }
            GenModel.GBM_rescale(preds);
        } else {
            // Single-output case: no class probabilities to correct and no label to pick.
            double rawScore = preds[0] + _init_f + offset;
            preds[0] = linkInv(_link_function, rawScore);
            return preds;
        }

        if (_balanceClasses) {
            GenModel.correctProbabilities(preds, _priorClassDistrib, _modelClassDistrib);
        }
        preds[0] = GenModel.getPrediction(preds, _priorClassDistrib, row, _defaultThreshold);
        return preds;
    }

    /**
     * Applies the inverse of the given link function to a raw score.
     *
     * @param linkFunction link function to invert
     * @param rawScore     raw score on the link scale
     * @return the score mapped back through the inverse link; unknown links behave like {@code identity}
     */
    private double linkInv(LinkFunctionType linkFunction, double rawScore) {
        switch (linkFunction) {
            case log:
                return exp(rawScore);
            case logit:
            case ologit:
                return 1.0d / (1.0d + exp(-rawScore));
            case ologlog:
                return 1.0d - exp(-1.0d * exp(rawScore));
            case oprobit:
                // NOTE(review): constant 0 regardless of input; looks like a placeholder - confirm.
                return 0.0d;
            case inverse: {
                // Keep the denominator at least INVERSE_LINK_EPSILON away from zero on either side.
                // The non-negative side previously clamped to -1e-5, which never took effect and
                // let a raw score of 0 produce Infinity.
                double denominator = rawScore < 0.0d
                        ? Math.min(-INVERSE_LINK_EPSILON, rawScore)
                        : Math.max(INVERSE_LINK_EPSILON, rawScore);
                return 1.0d / denominator;
            }
            case identity:
            default:
                return rawScore;
        }
    }

    /**
     * Exponential capped from above.
     *
     * @param x exponent
     * @return {@code Math.exp(x)}, but never more than 1e19
     */
    public static double exp(double x) {
        return Math.min(MAX_EXP, Math.exp(x));
    }

    /**
     * Natural logarithm bounded from below.
     *
     * @param x argument; values at or below zero are treated as zero
     * @return {@code Math.log(x)}, but never less than -19 (which is also the result for {@code x <= 0})
     */
    public static double log(double x) {
        double nonNegative = Math.max(0.0d, x);
        if (nonNegative == 0.0d) {
            return MIN_LOG;
        }
        return Math.max(MIN_LOG, Math.log(nonNegative));
    }

    /**
     * Scores one row with a zero offset.
     *
     * @param row   input row
     * @param preds output array, filled in place
     * @return {@code preds}
     */
    @Override
    public double[] score0(double[] row, double[] preds) {
        return score0(row, 0.0d, preds);
    }

    /**
     * Returns the result of {@code getDecisionPath} for the given row.
     *
     * @param row input row
     * @return decision path strings as produced by {@code getDecisionPath}
     */
    public String[] leaf_node_assignment(double[] row) {
        return getDecisionPath(row);
    }

    /**
     * Returns fixed output names for a quasibinomial model whose response has no domain values;
     * otherwise defers to the superclass.
     *
     * @return names of the prediction output columns
     */
    @Override
    public String[] getOutputNames() {
        if (_family == DistributionFamily.quasibinomial && getDomainValues(getResponseIdx()) == null) {
            return new String[]{"predict", "pVal0", "pVal1"};
        }
        return super.getOutputNames();
    }
}
