package biz.k11i.xgboost.learner;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
import net.jafama.FastMath;

/* loaded from: input_file:www/3/h2o-genmodel.jar:biz/k11i/xgboost/learner/ObjFunction.class */
public class ObjFunction implements Serializable {
    /**
     * Registry of objective functions keyed by objective name
     * (for example {@code "binary:logistic"}).
     *
     * <p>NOTE(review): this is a plain, unsynchronized {@link HashMap} that
     * {@link #register(String, ObjFunction)} mutates at any time; confirm that
     * callers never register concurrently with lookups.
     */
    private static final Map<String, ObjFunction> FUNCTIONS = new HashMap<String, ObjFunction>();

    /* loaded from: input_file:www/3/h2o-genmodel.jar:biz/k11i/xgboost/learner/ObjFunction$RegLossObjLogistic.class */
    /**
     * Objective whose prediction transform is the logistic sigmoid and whose
     * inverse transform is the corresponding logit.
     */
    static class RegLossObjLogistic extends ObjFunction {
        RegLossObjLogistic() {
        }

        /**
         * Replaces every element of the array with its sigmoid, in place.
         *
         * @param fArr raw margins; overwritten with the transformed values
         * @return the same array instance that was passed in
         */
        @Override
        public float[] predTransform(float[] fArr) {
            int idx = 0;
            while (idx < fArr.length) {
                fArr[idx] = sigmoid(fArr[idx]);
                idx++;
            }
            return fArr;
        }

        /**
         * Transforms a single raw margin with the sigmoid.
         *
         * @param f raw margin
         * @return {@code sigmoid(f)}
         */
        @Override
        public float predTransform(float f) {
            return sigmoid(f);
        }

        /**
         * Computes {@code 1 / (1 + exp(-f))} using {@link Math#exp(double)}.
         * Kept overridable so a subclass can swap in another exponential.
         *
         * @param f raw margin
         * @return the sigmoid of {@code f}
         */
        float sigmoid(float f) {
            final float expOfNegated = (float) Math.exp(-f);
            return 1.0f / (1.0f + expOfNegated);
        }

        /**
         * Inverse of the sigmoid: computes {@code -log(1 / f - 1)}.
         *
         * @param f value to convert back to a margin
         * @return the margin whose sigmoid is {@code f}
         */
        @Override
        public float probToMargin(float f) {
            // (1 / f) - 1 is evaluated in float precision before being widened for log.
            final float inverseOdds = (1.0f / f) - 1.0f;
            return (float) (-Math.log(inverseOdds));
        }
    }

    /* loaded from: input_file:www/3/h2o-genmodel.jar:biz/k11i/xgboost/learner/ObjFunction$RegLossObjLogistic_Jafama.class */
    /**
     * Variant of {@link RegLossObjLogistic} that evaluates the exponential
     * with {@code FastMath.exp} instead of {@link Math#exp(double)}.
     */
    static class RegLossObjLogistic_Jafama extends RegLossObjLogistic {
        RegLossObjLogistic_Jafama() {
        }

        /**
         * Computes {@code 1 / (1 + exp(-f))} in double precision using
         * {@code FastMath.exp}, then narrows the result to float.
         *
         * @param f raw margin
         * @return the sigmoid of {@code f}
         */
        @Override
        float sigmoid(float f) {
            final double expOfNegated = FastMath.exp(-f);
            return (float) (1.0d / (1.0d + expOfNegated));
        }
    }

    /* loaded from: input_file:www/3/h2o-genmodel.jar:biz/k11i/xgboost/learner/ObjFunction$RegObjFunction.class */
    /**
     * Objective whose prediction transform is the natural exponential and
     * whose inverse transform is the natural logarithm.
     */
    static class RegObjFunction extends ObjFunction {
        RegObjFunction() {
        }

        /**
         * Exponentiates the single predicted value in place.
         *
         * @param fArr array that must hold exactly one raw prediction
         * @return the same array instance, with element 0 replaced by its exponential
         * @throws IllegalStateException if the array length is not 1
         */
        @Override
        public float[] predTransform(float[] fArr) {
            if (fArr.length == 1) {
                fArr[0] = (float) Math.exp(fArr[0]);
                return fArr;
            }
            throw new IllegalStateException("Regression problem is supposed to have just a single predicted value, got " + fArr.length + " instead.");
        }

        /**
         * Exponentiates a single raw prediction.
         *
         * @param f raw prediction
         * @return {@code exp(f)} narrowed to float
         */
        @Override
        public float predTransform(float f) {
            final double exponentiated = Math.exp(f);
            return (float) exponentiated;
        }

        /**
         * Inverse of {@link #predTransform(float)}: the natural logarithm.
         *
         * @param f value to convert back to a margin
         * @return {@code log(f)} narrowed to float
         */
        @Override
        public float probToMargin(float f) {
            final double logarithm = Math.log(f);
            return (float) logarithm;
        }
    }

    /* loaded from: input_file:www/3/h2o-genmodel.jar:biz/k11i/xgboost/learner/ObjFunction$SoftmaxMultiClassObjClassify.class */
    /**
     * Objective that reduces a vector of per-class margins to the index of
     * the largest one.
     */
    static class SoftmaxMultiClassObjClassify extends ObjFunction {
        SoftmaxMultiClassObjClassify() {
        }

        /**
         * Finds the position of the largest margin; on ties the earliest
         * position wins because only a strictly greater value replaces the
         * current best.
         *
         * @param fArr per-class margins; must contain at least one element
         * @return a new single-element array holding the winning index as a float
         */
        @Override
        public float[] predTransform(float[] fArr) {
            int winner = 0;
            float best = fArr[0];
            for (int candidate = 1; candidate < fArr.length; candidate++) {
                final float margin = fArr[candidate];
                if (margin > best) {
                    best = margin;
                    winner = candidate;
                }
            }
            final float[] result = new float[1];
            result[0] = winner;
            return result;
        }

        /**
         * Not supported: this objective has no single-value transform.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public float predTransform(float f) {
            throw new UnsupportedOperationException();
        }
    }

    /* loaded from: input_file:www/3/h2o-genmodel.jar:biz/k11i/xgboost/learner/ObjFunction$SoftmaxMultiClassObjProb.class */
    /**
     * Objective that turns a vector of per-class margins into softmax
     * probabilities.
     */
    static class SoftmaxMultiClassObjProb extends ObjFunction {
        SoftmaxMultiClassObjProb() {
        }

        /**
         * Applies softmax to the array in place: subtracts the maximum margin,
         * exponentiates each element via {@link #exp(float)}, and divides by
         * the sum of the exponentials.
         *
         * @param fArr per-class margins; must contain at least one element;
         *             overwritten with the normalized values
         * @return the same array instance that was passed in
         */
        @Override
        public float[] predTransform(float[] fArr) {
            // Pass 1: largest margin, subtracted below before exponentiating.
            float peak = fArr[0];
            for (int k = 1; k < fArr.length; k++) {
                peak = Math.max(fArr[k], peak);
            }

            // Pass 2: exponentiate shifted margins, accumulating the sum in double.
            double total = 0.0d;
            for (int k = 0; k < fArr.length; k++) {
                final float shiftedExp = exp(fArr[k] - peak);
                fArr[k] = shiftedExp;
                total += shiftedExp;
            }

            // Pass 3: normalize by the sum narrowed to float.
            final float normalizer = (float) total;
            for (int k = 0; k < fArr.length; k++) {
                fArr[k] /= normalizer;
            }
            return fArr;
        }

        /**
         * Not supported: softmax needs the full vector of margins.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public float predTransform(float f) {
            throw new UnsupportedOperationException();
        }

        /**
         * Exponential used by the softmax, computed with {@link Math#exp(double)}.
         * Kept overridable so a subclass can swap in another implementation.
         *
         * @param f exponent
         * @return {@code exp(f)} narrowed to float
         */
        float exp(float f) {
            final double exponentiated = Math.exp(f);
            return (float) exponentiated;
        }
    }

    /* loaded from: input_file:www/3/h2o-genmodel.jar:biz/k11i/xgboost/learner/ObjFunction$SoftmaxMultiClassObjProb_Jafama.class */
    /**
     * Variant of {@link SoftmaxMultiClassObjProb} that evaluates the
     * exponential with {@code FastMath.exp} instead of {@link Math#exp(double)}.
     */
    static class SoftmaxMultiClassObjProb_Jafama extends SoftmaxMultiClassObjProb {
        SoftmaxMultiClassObjProb_Jafama() {
        }

        /**
         * Computes the exponential with {@code FastMath.exp}.
         *
         * @param f exponent
         * @return {@code exp(f)} narrowed to float
         */
        @Override
        float exp(float f) {
            final double exponentiated = FastMath.exp(f);
            return (float) exponentiated;
        }
    }

    /**
     * Looks up the objective function registered under the given name.
     *
     * @param str objective name, e.g. {@code "binary:logistic"}
     * @return the registered objective function
     * @throws IllegalArgumentException if the registry yields no function for the name
     */
    public static ObjFunction fromName(String str) {
        final ObjFunction registered = FUNCTIONS.get(str);
        if (registered != null) {
            return registered;
        }
        throw new IllegalArgumentException(str + " is not supported objective function.");
    }

    /**
     * Registers an objective function under the given name, replacing any
     * function previously registered under that name.
     *
     * @param str         objective name used as the registry key
     * @param objFunction implementation to return from {@link #fromName(String)}
     */
    public static void register(String str, ObjFunction objFunction) {
        FUNCTIONS.put(str, objFunction);
    }

    /**
     * Chooses the exponential implementation behind the
     * {@code "binary:logistic"} and {@code "multi:softprob"} objectives by
     * re-registering both names with fresh instances.
     *
     * @param z {@code true} to register the Jafama-backed variants,
     *          {@code false} to register the {@link Math}-backed ones
     */
    public static void useFastMathExp(boolean z) {
        register("binary:logistic", z ? new RegLossObjLogistic_Jafama() : new RegLossObjLogistic());
        register("multi:softprob", z ? new SoftmaxMultiClassObjProb_Jafama() : new SoftmaxMultiClassObjProb());
    }

    /**
     * Transforms a vector of raw predictions. This base implementation is the
     * identity; subclasses override it with objective-specific transforms.
     *
     * @param fArr raw predictions
     * @return the same array instance, unmodified
     */
    public float[] predTransform(float[] fArr) {
        return fArr;
    }

    /**
     * Transforms a single raw prediction. This base implementation is the
     * identity; subclasses override it with objective-specific transforms.
     *
     * @param f raw prediction
     * @return {@code f} unchanged
     */
    public float predTransform(float f) {
        return f;
    }

    /**
     * Converts a transformed value back to a raw margin. This base
     * implementation is the identity; subclasses override it with the inverse
     * of their prediction transform.
     *
     * @param f transformed value
     * @return {@code f} unchanged
     */
    public float probToMargin(float f) {
        return f;
    }

    // Populate the registry with the built-in objectives. Each name receives
    // its own instance; the Math-backed variants are the defaults.
    static {
        FUNCTIONS.put("rank:pairwise", new ObjFunction());
        FUNCTIONS.put("binary:logistic", new RegLossObjLogistic());
        FUNCTIONS.put("binary:logitraw", new ObjFunction());

        FUNCTIONS.put("multi:softmax", new SoftmaxMultiClassObjClassify());
        FUNCTIONS.put("multi:softprob", new SoftmaxMultiClassObjProb());

        FUNCTIONS.put("reg:linear", new ObjFunction());
        FUNCTIONS.put("reg:squarederror", new ObjFunction());

        FUNCTIONS.put("reg:gamma", new RegObjFunction());
        FUNCTIONS.put("reg:tweedie", new RegObjFunction());
        FUNCTIONS.put("count:poisson", new RegObjFunction());
    }
}
