package hex.genmodel.algos.gam;

import hex.genmodel.utils.DistributionFamily;

/* loaded from: input_file:hex/genmodel/algos/gam/GamMojoMultinomialModel.class */
/**
 * GAM MOJO scorer for models with more than two response classes.
 *
 * <p>For every class a raw score ("eta") is obtained from {@code generateEta} and then
 * converted into class probabilities. When the model family equals
 * {@link DistributionFamily#multinomial} the conversion is a softmax; for any other
 * family the ordinal (cumulative logistic) conversion is used.
 *
 * <p>In the prediction array, slot {@code 0} carries the predicted class index and slot
 * {@code c + 1} carries the value for class {@code c}.
 */
public class GamMojoMultinomialModel extends GamMojoModelBase {
    /** {@code true} when {@code _family} equals {@link DistributionFamily#multinomial}; set in {@link #init()}. */
    private boolean _trueMultinomial;

    /**
     * Forwards all three arguments unchanged to the {@link GamMojoModelBase} constructor.
     *
     * @param strArr  first base-class constructor argument (presumably column names - confirm against base class)
     * @param strArr2 second base-class constructor argument (presumably per-column domains - confirm against base class)
     * @param str     third base-class constructor argument (presumably the response column name - confirm against base class)
     */
    public GamMojoMultinomialModel(String[] strArr, String[][] strArr2, String str) {
        super(strArr, strArr2, str);
    }

    /**
     * Runs the base-class initialisation and then caches whether the model family is
     * multinomial, which decides the post-processing used by {@link #gamScore0}.
     */
    @Override // hex.genmodel.algos.gam.GamMojoModelBase
    public void init() {
        super.init();
        this._trueMultinomial = this._family.equals(DistributionFamily.multinomial);
    }

    /**
     * Scores one row.
     *
     * <p>Picks the coefficient set by comparing the row length with {@code nfeatures()}:
     * equal lengths select {@code _beta_multinomial_center}, anything else selects
     * {@code _beta_multinomial_no_center}. The selection is stored in the
     * {@code _beta_multinomial} field. The eta of class {@code c} is written to
     * {@code dArr2[c + 1]} and the array is then post-processed in place.
     *
     * @param dArr  input row handed to {@code generateEta}
     * @param dArr2 prediction buffer; must have at least {@code _nclasses + 1} slots
     * @return {@code dArr2}, holding the predicted class in slot 0 and class probabilities after it
     */
    @Override // hex.genmodel.algos.gam.GamMojoModelBase
    double[] gamScore0(double[] dArr, double[] dArr2) {
        // The field is still assigned (not replaced by a local) because it belongs to the
        // base class and may be read elsewhere.
        this._beta_multinomial = dArr.length == nfeatures()
                ? this._beta_multinomial_center
                : this._beta_multinomial_no_center;
        for (int classIdx = 0; classIdx < this._nclasses; classIdx++) {
            dArr2[classIdx + 1] = generateEta(this._beta_multinomial[classIdx], dArr);
        }
        return this._trueMultinomial ? postPredMultinomial(dArr2) : postPredOrdinal(dArr2);
    }

    /**
     * Converts the etas in {@code dArr[1 .. length-1]} into softmax probabilities in place
     * and stores the zero-based index of the most probable class in {@code dArr[0]}.
     *
     * @param dArr prediction buffer with etas in slots {@code 1 .. length-1}
     * @return the same array instance
     */
    double[] postPredMultinomial(double[] dArr) {
        // Seed the maximum with negative infinity so it always ends up being the true
        // largest eta. Subtracting the true maximum guarantees that the largest term is
        // exp(0) == 1, so the normaliser cannot underflow to zero (which would turn every
        // probability into NaN) even when all etas are strongly negative.
        double maxEta = Double.NEGATIVE_INFINITY;
        for (int c = 1; c < dArr.length; c++) {
            if (dArr[c] > maxEta) {
                maxEta = dArr[c];
            }
        }

        double expSum = 0.0d;
        for (int c = 1; c < dArr.length; c++) {
            double shiftedExp = Math.exp(dArr[c] - maxEta);
            dArr[c] = shiftedExp;
            expSum += shiftedExp;
        }

        double scale = 1.0d / expSum;
        double bestProb = 0.0d;
        for (int c = 1; c < dArr.length; c++) {
            double prob = dArr[c] * scale;
            dArr[c] = prob;
            // Strict comparison: on ties the first (lowest-index) class wins.
            if (prob > bestProb) {
                bestProb = prob;
                dArr[0] = c - 1;
            }
        }
        return dArr;
    }

    /**
     * Converts the etas in {@code dArr[1 .. _lastClass]} into ordinal class probabilities
     * in place.
     *
     * <p>Each eta is mapped through the logistic function to a cumulative probability and
     * the slot receives the difference to the previous cumulative value. The predicted
     * class written to {@code dArr[0]} is the index of the first eta that is strictly
     * positive, or {@code _lastClass} when none is. Finally {@code dArr[_nclasses]} is set
     * to one minus the last cumulative probability.
     *
     * <p>NOTE(review): this layout assumes {@code _lastClass == _nclasses - 1}; confirm in
     * the base class.
     *
     * @param dArr prediction buffer with etas in slots {@code 1 .. _lastClass}
     * @return the same array instance
     */
    double[] postPredOrdinal(double[] dArr) {
        double previousCdf = 0.0d;
        boolean classFound = false;
        dArr[0] = this._lastClass; // default when no eta is positive
        for (int c = 0; c < this._lastClass; c++) {
            double eta = dArr[c + 1];
            double cdf = logistic(eta);
            dArr[c + 1] = cdf - previousCdf;
            previousCdf = cdf;
            // Only the first positive eta decides the class; later slots still need
            // their probability mass computed, so the loop keeps going.
            if (!classFound && eta > 0.0d) {
                dArr[0] = c;
                classFound = true;
            }
        }
        dArr[this._nclasses] = 1.0d - previousCdf;
        return dArr;
    }

    /** Standard logistic function {@code 1 / (1 + e^-x)}. */
    private static double logistic(double x) {
        return 1.0d / (1.0d + Math.exp(-x));
    }
}
