package hex.genmodel.algos.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.learner.ObjFunction;
import biz.k11i.xgboost.util.FVec;
import hex.genmodel.GenModel;
import hex.genmodel.algos.tree.SharedTreeGraph;
import java.io.ByteArrayInputStream;
import java.io.IOException;

/* loaded from: input_file:www/3/h2o-genmodel.jar:hex/genmodel/algos/xgboost/XGBoostJavaMojoModel.class */
public final class XGBoostJavaMojoModel extends XGBoostMojoModel {
    private Predictor _predictor;
    private OneHotEncoderFactory _1hotFactory;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:www/3/h2o-genmodel.jar:hex/genmodel/algos/xgboost/XGBoostJavaMojoModel$OneHotEncoderFVec.class */
    /**
     * Feature-vector view that exposes one row as a one-hot-encoded vector
     * without materializing the expanded array.
     *
     * <p>Indices below {@code _catMap.length} address the one-hot (categorical)
     * columns; indices at or above it address the numeric columns stored in
     * {@code _numValues}.
     */
    public class OneHotEncoderFVec implements FVec {
        /** For every one-hot column index, the position of the categorical feature it belongs to. */
        private final int[] _catMap;
        /**
         * Per categorical feature, the value compared against the requested column index;
         * presumably the absolute one-hot column of the active level - confirm against GenModel.setCats.
         */
        private final int[] _catValues;
        /** Numeric feature values, addressed relative to the end of the one-hot columns. */
        private final float[] _numValues;
        /** Value reported for a one-hot column that is not the active one. */
        private final float _notHot;

        private OneHotEncoderFVec(int[] iArr, int[] iArr2, float[] fArr, float f) {
            _catMap = iArr;
            _catValues = iArr2;
            _numValues = fArr;
            _notHot = f;
        }

        /**
         * Returns the value of the encoded feature at the given index.
         *
         * @param i index into the expanded (one-hot columns followed by numeric columns) vector
         * @return {@code 1.0f} for the active one-hot column, {@code _notHot} for any other
         *         one-hot column, or the stored numeric value for a numeric column
         */
        @Override
        public final float fvalue(int i) {
            final int oneHotWidth = _catMap.length;
            if (i < oneHotWidth) {
                // Categorical region: the column is "hot" only when the owning
                // feature's recorded value equals this very column index.
                final int owner = _catMap[i];
                return (_catValues[owner] == i) ? 1.0f : _notHot;
            }
            // Numeric region: shift the index past the one-hot columns.
            return _numValues[i - oneHotWidth];
        }
    }

    /* loaded from: input_file:www/3/h2o-genmodel.jar:hex/genmodel/algos/xgboost/XGBoostJavaMojoModel$OneHotEncoderFactory.class */
    /**
     * Creates {@link OneHotEncoderFVec} views for input rows, sharing one precomputed
     * column-to-categorical lookup table across all rows.
     */
    private class OneHotEncoderFactory {
        /** For every one-hot column index, the position of the categorical feature that owns it. */
        private final int[] _catMap;
        /** Value used for inactive one-hot columns: NaN when the model is sparse, 0 otherwise. */
        private final float _notHot;

        /**
         * Precomputes the lookup table from the enclosing model's categorical offsets.
         * When the model has no offsets array the table is empty.
         */
        OneHotEncoderFactory() {
            _notHot = _sparse ? Float.NaN : 0.0f;
            final int[] offsets = _catOffsets;
            if (offsets == null) {
                _catMap = new int[0];
            } else {
                final int catCount = _cats;
                // offsets[catCount] is the total number of one-hot columns.
                final int[] owners = new int[offsets[catCount]];
                for (int cat = 0; cat < catCount; cat++) {
                    final int end = offsets[cat + 1];
                    for (int col = offsets[cat]; col < end; col++) {
                        owners[col] = cat;
                    }
                }
                _catMap = owners;
            }
        }

        /**
         * Wraps a raw input row in a one-hot-encoding feature vector.
         *
         * @param dArr input row; the first {@code _cats} entries are handed to
         *             {@code GenModel.setCats}, the following {@code _nums} entries are
         *             read as numeric values
         * @return a feature vector backed by freshly allocated categorical and numeric arrays
         */
        OneHotEncoderFVec fromArray(double[] dArr) {
            final float[] numeric = new float[_nums];
            final int[] categorical = new int[_cats];
            GenModel.setCats(dArr, categorical, _cats, _catOffsets, _useAllFactorLevels);
            for (int n = 0; n < numeric.length; n++) {
                final float value = (float) dArr[_cats + n];
                // In sparse mode a zero is encoded as NaN instead of 0.
                if (_sparse && value == 0.0f) {
                    numeric[n] = Float.NaN;
                } else {
                    numeric[n] = value;
                }
            }
            return new OneHotEncoderFVec(_catMap, categorical, numeric, _notHot);
        }
    }

    /* loaded from: input_file:www/3/h2o-genmodel.jar:hex/genmodel/algos/xgboost/XGBoostJavaMojoModel$RegObjFunction.class */
    /**
     * Objective function whose prediction transform is the exponential of the raw score.
     * Instances are registered by this class's static initializer.
     */
    private static class RegObjFunction extends ObjFunction {
        private RegObjFunction() {
        }

        /**
         * Exponentiates the single raw prediction in place.
         *
         * @param fArr raw predictions; must contain exactly one element
         * @return the same array instance, with element 0 replaced by its exponential
         * @throws IllegalStateException if the array does not hold exactly one value
         */
        @Override
        public float[] predTransform(float[] fArr) {
            final int count = fArr.length;
            if (count == 1) {
                fArr[0] = predTransform(fArr[0]);
                return fArr;
            }
            throw new IllegalStateException("Regression problem is supposed to have just a single predicted value, got " + count + " instead.");
        }

        /**
         * Exponentiates a single raw prediction.
         *
         * @param f raw prediction
         * @return {@code exp(f)} narrowed to {@code float}
         */
        @Override
        public float predTransform(float f) {
            final double transformed = Math.exp(f);
            return (float) transformed;
        }
    }

    /**
     * Creates the model and eagerly builds the pure-Java predictor from the serialized booster.
     *
     * @param bArr    serialized booster bytes, parsed by {@link Predictor}
     * @param strArr  forwarded to the superclass constructor (presumably column names - confirm)
     * @param strArr2 forwarded to the superclass constructor (presumably column domains - confirm)
     * @param str     forwarded to the superclass constructor (presumably the response column - confirm)
     * @throws IllegalStateException if the booster bytes cannot be read
     */
    public XGBoostJavaMojoModel(byte[] bArr, String[] strArr, String[][] strArr2, String str) {
        super(strArr, strArr2, str);
        _predictor = makePredictor(bArr);
    }

    /**
     * Builds the one-hot encoder factory. The factory reads the model's sparse flag,
     * categorical count and categorical offsets at construction time, so those fields
     * need to be populated before this method runs.
     */
    @Override
    public void postReadInit() {
        _1hotFactory = new OneHotEncoderFactory();
    }

    /**
     * Deserializes an XGBoost booster into a pure-Java {@link Predictor}.
     *
     * <p>The stream is managed by try-with-resources so it is closed on both the
     * success and the failure path.
     *
     * @param bArr serialized booster bytes
     * @return the predictor parsed from {@code bArr}
     * @throws IllegalStateException wrapping any {@link IOException} raised while
     *                               reading the booster or closing the stream
     */
    private static Predictor makePredictor(byte[] bArr) {
        try (ByteArrayInputStream boosterStream = new ByteArrayInputStream(bArr)) {
            return new Predictor(boosterStream);
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Scores a single row: one-hot encodes it, runs the predictor and converts the raw
     * output into the prediction array via {@code toPreds}.
     *
     * @param dArr  input row
     * @param d     offset; only {@code 0} is supported
     * @param dArr2 prediction array handed to {@code toPreds}
     * @return the value returned by {@code toPreds}
     * @throws UnsupportedOperationException if {@code d} is not zero
     */
    @Override
    public final double[] score0(double[] dArr, double d, double[] dArr2) {
        if (d != 0.0d) {
            throw new UnsupportedOperationException("Unsupported: offset != 0");
        }
        final FVec encodedRow = _1hotFactory.fromArray(dArr);
        return toPreds(
                dArr,
                _predictor.predict(encodedRow),
                dArr2,
                _nclasses,
                _priorClassDistrib,
                _defaultThreshold);
    }

    /**
     * Looks up an objective function by its registered name.
     *
     * @param str objective name, e.g. one of the names registered in this class's static initializer
     * @return whatever {@link ObjFunction#fromName(String)} yields for that name
     */
    static ObjFunction getObjFunction(String str) {
        final ObjFunction objective = ObjFunction.fromName(str);
        return objective;
    }

    /**
     * Drops the reference to the predictor. Methods that dereference the predictor
     * ({@code score0}, {@code convert}) will fail with a {@link NullPointerException}
     * if invoked afterwards.
     */
    @Override
    public void close() {
        _predictor = null;
    }

    /**
     * Converts the predictor's booster into a {@link SharedTreeGraph} by delegating
     * to {@code _computeGraph}.
     *
     * @param i   passed through to {@code _computeGraph} (presumably the tree number - confirm)
     * @param str not used by this implementation
     * @return the graph produced by {@code _computeGraph}
     */
    @Override
    public SharedTreeGraph convert(int i, String str) {
        return _computeGraph(_predictor.getBooster(), i);
    }

    static {
        ObjFunction.register("reg:gamma", new RegObjFunction());
        ObjFunction.register("reg:tweedie", new RegObjFunction());
        ObjFunction.register("count:poisson", new RegObjFunction());
    }
}
