package com.linkedin.dagli.xgboost;

import com.linkedin.dagli.annotation.equality.ValueEquality;
import com.linkedin.dagli.function.FunctionResult1;
import com.linkedin.dagli.generator.Constant;
import com.linkedin.dagli.input.DenseFeatureVectorInput;
import com.linkedin.dagli.math.vector.DenseFloatArrayVector;
import com.linkedin.dagli.math.vector.DenseVector;
import com.linkedin.dagli.math.vector.Vector;
import com.linkedin.dagli.preparer.AbstractStreamPreparer3;
import com.linkedin.dagli.preparer.PreparerContext;
import com.linkedin.dagli.preparer.PreparerResult;
import com.linkedin.dagli.producer.MissingInput;
import com.linkedin.dagli.producer.Producer;
import com.linkedin.dagli.transformer.AbstractPreparableTransformer3;
import com.linkedin.dagli.transformer.AbstractPreparedTransformer2;
import com.linkedin.dagli.transformer.AbstractPreparedTransformer3;
import com.linkedin.dagli.transformer.PreparedTransformer;
import com.linkedin.dagli.tuple.Tuple3;
import com.linkedin.dagli.util.array.ArraysEx;
import com.linkedin.dagli.util.invariant.Arguments;
import com.linkedin.dagli.vector.CategoricalFeatureVector;
import com.linkedin.dagli.vector.DensifiedVector;
import com.linkedin.dagli.view.PreparedTransformerView;
import com.linkedin.dagli.xgboost.AbstractXGBoostModel;
import com.linkedin.dagli.xgboost.AbstractXGBoostModel.Prepared;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.HashMap;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.IEvaluation;
import ml.dmlc.xgboost4j.java.IObjective;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

/**
 * Base class for Dagli transformers backed by an XGBoost gradient-boosted tree model.
 *
 * <p>The transformer has three inputs:
 * <ul>
 *   <li>an optional per-example weight ({@link Number}), which defaults to a constant {@code null}, treated as a
 *       weight of 1;</li>
 *   <li>a label of type {@code L};</li>
 *   <li>a {@link DenseVector} of features.</li>
 * </ul>
 *
 * <p>During preparation every example is buffered in memory and converted to an XGBoost {@link DMatrix}, which is
 * then used to train a {@link Booster}. The derived class wraps the trained booster in its {@link Prepared}
 * transformer via {@link #createPrepared(Object2IntOpenHashMap, Booster)}.
 *
 * @param <L> the label type
 * @param <R> the result type produced by the prepared transformer
 * @param <P> the type of the prepared transformer
 * @param <S> the concrete type of the derived class (returned by the fluent {@code with...} methods)
 */
public abstract class AbstractXGBoostModel<L, R, P extends Prepared<L, R, ?>, S extends AbstractXGBoostModel<L, R, P, S>> extends AbstractPreparableTransformer3<Number, L, DenseVector, R, P, S> {
    private static final long serialVersionUID = 1;

    /** Default value of the label-to-ID map; signals that a label has not been assigned an ID yet. */
    private static final int MISSING_ID_MARKER = -1;

    /** Default thread count. Any value {@code <= 0} means "use {@link Runtime#availableProcessors()} threads". */
    private static final int AUTOMATIC_THREAD_COUNT = -1;

    /** Early-stopping-rounds value meaning early stopping is disabled (see {@link #isEarlyStopping()}). */
    private static final int EARLY_STOPPING_DISABLED = -1;

    /** Early-stopping-rounds value passed to {@link XGBoost#train} when early stopping is enabled. */
    private static final int EARLY_STOPPING_ROUNDS_WHEN_ENABLED = 2;

    /** Group ID attached to every training {@link LabeledPoint}; presumably means "no query group" — TODO confirm. */
    private static final int NO_GROUP = -1;

    protected double _learningRateMultiplier;
    protected int _maxDepth;
    protected boolean _silent;
    protected int _rounds;
    protected int _threadCount;
    protected int _earlyStoppingRounds;

    /**
     * Base class for the prepared (trained) form of an XGBoost model, which holds the trained {@link Booster}.
     *
     * @param <L> the label type
     * @param <R> the result type
     * @param <S> the concrete type of the derived class
     */
    public static abstract class Prepared<L, R, S extends Prepared<L, R, S>> extends AbstractPreparedTransformer3<Number, L, DenseVector, R, S> {
        private static final long serialVersionUID = 1;
        protected Booster _booster;

        /**
         * Creates a new prepared model wrapping the given booster.
         *
         * @param booster the trained XGBoost booster
         */
        public Prepared(Booster booster) {
            _booster = booster;
        }

        /**
         * @return the trained XGBoost booster used by this model
         */
        public Booster getBooster() {
            return _booster;
        }
    }

    /**
     * Buffers all training examples in memory and, once all of them have been seen, trains an XGBoost booster
     * using the owning model's hyperparameters.
     */
    public static class Preparer<L, R, P extends Prepared<L, R, ?>> extends AbstractStreamPreparer3<Number, L, DenseVector, R, P> {
        private final ArrayList<Tuple3<Number, L, DenseVector>> _labeledVectorList = new ArrayList<>();
        private final AbstractXGBoostModel<L, R, P, ?> _owner;

        /**
         * @param owner the model whose hyperparameters, objective and prepared-model factory will be used
         */
        public Preparer(AbstractXGBoostModel<L, R, P, ?> owner) {
            _owner = owner;
        }

        @Override
        public void process(Number weight, L label, DenseVector features) {
            _labeledVectorList.add(Tuple3.of(weight, label, features));
        }

        /**
         * Trains a booster on all examples seen so far.
         *
         * <p>For regression objectives the label itself (as a {@link Number}) is the XGBoost label. For every other
         * objective, each distinct label is mapped to a 0-based integer ID in order of first appearance. That map is
         * then handed to {@link AbstractXGBoostModel#createPrepared(Object2IntOpenHashMap, Booster)}.
         *
         * @return the result containing the prepared model
         * @throws RuntimeException wrapping any {@link XGBoostError} raised while building the data or training
         */
        @Override
        public PreparerResult<P> finish() {
            final boolean isRegression = _owner.getObjectiveType() == XGBoostObjectiveType.REGRESSION;

            // Label -> class ID; stays empty for regression.
            final Object2IntOpenHashMap<L> labelToId = new Object2IntOpenHashMap<>();
            labelToId.defaultReturnValue(MISSING_ID_MARKER);

            try {
                // The DMatrix constructor drains this lazy iterator. Class IDs are therefore assigned as examples are
                // converted, so labelToId is complete once the constructor returns.
                DMatrix trainingData = new DMatrix(_labeledVectorList.stream().map(example -> {
                    float xgboostLabel = isRegression
                        ? ((Number) example.get1()).floatValue()
                        : getOrAssignLabelId(labelToId, example.get1());
                    return makeDenseLabeledPoint(example.get0(), xgboostLabel, example.get2());
                }).iterator(), (String) null);

                // Dispose of the native matrix no matter how we leave this block (including a failing setWeight).
                try {
                    float[] weights = computeWeights();
                    if (weights != null) {
                        trainingData.setWeight(weights);
                    }
                    return new PreparerResult<>(_owner.createPrepared(labelToId, train(trainingData, labelToId.size())));
                } finally {
                    trainingData.dispose();
                }
            } catch (XGBoostError e) {
                throw new RuntimeException("Encountered an XGBoostError while preparing the XGBoost training data", e);
            }
        }

        /**
         * Returns the ID already assigned to {@code label}, or assigns it the next sequential ID.
         *
         * @param labelToId map of assigned IDs; its default return value must be {@link #MISSING_ID_MARKER}
         * @param label the label to look up (may be null; fastutil open hash maps accept null keys)
         * @return the label's 0-based ID
         */
        private static <K> int getOrAssignLabelId(Object2IntOpenHashMap<K> labelToId, K label) {
            int id = labelToId.getInt(label);
            if (id == MISSING_ID_MARKER) {
                id = labelToId.size();
                labelToId.put(label, id);
            }
            return id;
        }

        /**
         * Builds the per-example weight array, using 1 for examples with no weight.
         *
         * @return the weights, or {@code null} if no example has an explicit weight
         */
        private float[] computeWeights() {
            float[] weights = new float[_labeledVectorList.size()];
            boolean anyWeightSpecified = false;
            for (int i = 0; i < weights.length; i++) {
                Number weight = _labeledVectorList.get(i).get0();
                if (weight == null) {
                    weights[i] = 1.0f;
                } else {
                    anyWeightSpecified = true;
                    weights[i] = weight.floatValue();
                }
            }
            return anyWeightSpecified ? weights : null;
        }

        /**
         * Runs {@link XGBoost#train} with the owner's hyperparameters, watching the training set itself.
         *
         * @param trainingData the training matrix
         * @param labelCount the number of distinct labels (0 for regression)
         * @return the trained booster
         * @throws RuntimeException wrapping an {@link XGBoostError} thrown during training
         */
        private Booster train(DMatrix trainingData, int labelCount) {
            XGBoostObjective objective = _owner.getObjective(labelCount);

            HashMap<String, Object> params = new HashMap<>();
            params.put("eta", _owner.getLearningRateMultiplier());
            params.put("max_depth", _owner.getMaxDepth());
            params.put("silent", _owner.isSilent() ? 1 : 0);
            params.put("objective", objective.getObjectiveName());
            params.put("nthread", _owner._threadCount <= 0 ? Runtime.getRuntime().availableProcessors() : _owner._threadCount);
            if (objective.shouldSpecifyNumberOfClasses()) {
                params.put("num_class", labelCount);
            }

            HashMap<String, DMatrix> watches = new HashMap<>();
            watches.put("train", trainingData);

            try {
                // Explicitly typed nulls disambiguate the XGBoost.train overloads.
                Booster booster = XGBoost.train(trainingData, params, _owner.getRounds(), watches, (float[][]) null,
                    (IObjective) null, (IEvaluation) null, _owner._earlyStoppingRounds);
                // NOTE(review): presumably resets XGBoostModel's cached prediction-thread configuration after training
                // — confirm against XGBoostModel.
                XGBoostModel.IS_THREAD_CONFIGURED_FOR_SINGLE_THREADED_PREDICTION.set(false);
                return booster;
            } catch (XGBoostError e) {
                throw new RuntimeException("Encountered an XGBoostException while training model", e);
            }
        }
    }

    /**
     * Given a prepared XGBoost model and a dense feature vector, outputs the leaf indices reported by
     * {@link Booster#predictLeaf} for that example, converted to ints.
     */
    @ValueEquality
    public static class XGBoostLeaves extends AbstractPreparedTransformer2<Prepared<?, ?, ?>, DenseVector, int[], XGBoostLeaves> {
        private static final long serialVersionUID = 1;

        XGBoostLeaves(Producer<? extends Prepared<?, ?, ?>> model, Producer<? extends DenseVector> features) {
            super(model, features);
        }

        @Override
        public int[] apply(Prepared<?, ?, ?> model, DenseVector features) {
            return ArraysEx.toIntegersLossy(XGBoostModel.predictAsFloats(model.getBooster(), features,
                (booster, matrix) -> booster.predictLeaf(matrix, 0)[0]));
        }
    }

    /** The XGBoost objective functions supported by this class, together with their XGBoost parameter names. */
    public enum XGBoostObjective {
        REGRESSION_SQUARED_ERROR("reg:squarederror", XGBoostObjectiveType.REGRESSION, false),
        CLASSIFICATION_SOFTMAX("multi:softprob", XGBoostObjectiveType.CLASSIFICATON, true),
        CLASSIFICATION_LOGISTIC_REGRESSION("binary:logistic", XGBoostObjectiveType.CLASSIFICATON, false);

        private final String _objectiveName;
        private final XGBoostObjectiveType _type;
        private final boolean _shouldSpecifyNumberOfClasses;

        XGBoostObjective(String objectiveName, XGBoostObjectiveType type, boolean shouldSpecifyNumberOfClasses) {
            _objectiveName = objectiveName;
            _type = type;
            _shouldSpecifyNumberOfClasses = shouldSpecifyNumberOfClasses;
        }

        /**
         * @return the value passed to XGBoost as its {@code objective} parameter
         */
        public String getObjectiveName() {
            return _objectiveName;
        }

        /**
         * @return whether this is a regression or classification objective
         */
        public XGBoostObjectiveType getType() {
            return _type;
        }

        /**
         * @return true if XGBoost's {@code num_class} parameter must be set for this objective
         */
        public boolean shouldSpecifyNumberOfClasses() {
            return _shouldSpecifyNumberOfClasses;
        }
    }

    /**
     * Broad category of an objective. (The misspelled {@code CLASSIFICATON} constant is kept as-is because code
     * outside this file may reference it.)
     */
    public enum XGBoostObjectiveType {
        REGRESSION,
        CLASSIFICATON
    }

    /**
     * Creates a model with the defaults: learning rate 0.3, max depth 3, silent, 4 rounds, automatic thread count
     * and no early stopping. Examples are unweighted unless a weight input is set.
     */
    public AbstractXGBoostModel() {
        super(Constant.nullValue(), MissingInput.get(), MissingInput.get());
        _learningRateMultiplier = 0.3;
        _maxDepth = 3;
        _silent = true;
        _rounds = 4;
        _threadCount = AUTOMATIC_THREAD_COUNT;
        _earlyStoppingRounds = EARLY_STOPPING_DISABLED;
    }

    /**
     * @param weightInput producer of per-example weights; null weights are treated as 1
     * @return a copy of this instance that uses the given weight input
     */
    public S withWeightInput(Producer<? extends Number> weightInput) {
        return clone(model -> model._input1 = weightInput);
    }

    /**
     * @param labelInput producer of the labels to learn
     * @return a copy of this instance that uses the given label input
     */
    public S withLabelInput(Producer<? extends L> labelInput) {
        return clone(model -> model._input2 = labelInput);
    }

    /**
     * @param featuresInput producer of feature vectors; sparse vectors are densified first
     * @return a copy of this instance that uses the given features input
     */
    public S withFeaturesInput(Producer<? extends Vector> featuresInput) {
        return withInput3(DensifiedVector.densifyIfSparse(featuresInput));
    }

    /**
     * @return a builder for configuring the dense feature-vector input of this model
     */
    public DenseFeatureVectorInput<S> withFeaturesInput() {
        return new DenseFeatureVectorInput<>(producer -> this.withInput3(producer));
    }

    /**
     * @return the learning rate multiplier (XGBoost's {@code eta})
     */
    public double getLearningRateMultiplier() {
        return _learningRateMultiplier;
    }

    /**
     * @param learningRateMultiplier the learning rate multiplier (XGBoost's {@code eta}); must be > 0
     * @return a copy of this instance with the given learning rate multiplier
     */
    public S withLearningRateMultiplier(double learningRateMultiplier) {
        Arguments.check(learningRateMultiplier > 0.0, "Learning rate multiplier must be > 0");
        return clone(model -> model._learningRateMultiplier = learningRateMultiplier);
    }

    /**
     * @return the maximum depth of each tree
     */
    public int getMaxDepth() {
        return _maxDepth;
    }

    /**
     * @param maxDepth the maximum depth of each tree; must be at least 1
     * @return a copy of this instance with the given maximum tree depth
     */
    public S withMaxDepth(int maxDepth) {
        Arguments.check(maxDepth >= 1, "Maximum tree depth must be at least 1");
        return clone(model -> model._maxDepth = maxDepth);
    }

    /**
     * @return true if XGBoost's {@code silent} parameter will be set to 1
     */
    public boolean isSilent() {
        return _silent;
    }

    /**
     * @param silent whether XGBoost's {@code silent} parameter should be set to 1
     * @return a copy of this instance with the given silence setting
     */
    public S withSilent(boolean silent) {
        return clone(model -> model._silent = silent);
    }

    /**
     * @return the number of boosting rounds passed to {@link XGBoost#train}
     */
    public int getRounds() {
        return _rounds;
    }

    /**
     * @param rounds the number of boosting rounds passed to {@link XGBoost#train}
     * @return a copy of this instance with the given number of rounds
     */
    public S withRounds(int rounds) {
        return clone(model -> model._rounds = rounds);
    }

    /**
     * @return the configured training thread count; values {@code <= 0} mean "all available processors"
     */
    public int getThreadCount() {
        return _threadCount;
    }

    /**
     * @param threadCount the number of training threads; values {@code <= 0} mean "all available processors"
     * @return a copy of this instance with the given thread count
     */
    public S withThreadCount(int threadCount) {
        return clone(model -> model._threadCount = threadCount);
    }

    /**
     * @param earlyStopping whether training should use XGBoost early stopping, evaluated on the training set
     * @return a copy of this instance with early stopping enabled or disabled
     */
    public S withEarlyStopping(boolean earlyStopping) {
        return clone(model -> model._earlyStoppingRounds =
            earlyStopping ? EARLY_STOPPING_ROUNDS_WHEN_ENABLED : EARLY_STOPPING_DISABLED);
    }

    /**
     * @return true if early stopping is enabled
     */
    public boolean isEarlyStopping() {
        return _earlyStoppingRounds >= 1;
    }

    /**
     * @param labelCount the number of distinct labels seen in training (0 for regression)
     * @return the XGBoost objective to train with
     */
    protected abstract XGBoostObjective getObjective(int labelCount);

    /**
     * @return whether this model performs regression or classification
     */
    protected abstract XGBoostObjectiveType getObjectiveType();

    /**
     * Converts an example into an XGBoost {@link LabeledPoint} with dense features.
     *
     * @param weight the example weight, or null for a weight of 1
     * @param label the numeric label as seen by XGBoost
     * @param denseVector the example's features
     * @return the labeled point
     */
    public static LabeledPoint makeDenseLabeledPoint(Number weight, float label, DenseVector denseVector) {
        float[] values = denseVector instanceof DenseFloatArrayVector
            ? ((DenseFloatArrayVector) denseVector).getArray()
            : denseVector.toFloatArray();
        if (values.length == 0) {
            // An empty feature vector is replaced by a single zero feature. NOTE(review): presumably because XGBoost
            // cannot build a point with no values — confirm.
            values = new float[]{0.0f};
        }
        return new LabeledPoint(label, (int[]) null, values, weight != null ? weight.floatValue() : 1.0f, NO_GROUP,
            Float.NaN);
    }

    /**
     * Creates the prepared model from the trained booster.
     *
     * @param labelToId map from each distinct training label to its 0-based class ID (empty for regression)
     * @param booster the trained booster
     * @return the prepared model
     */
    protected abstract P createPrepared(Object2IntOpenHashMap<L> labelToId, Booster booster);

    @Override
    public Preparer<L, R, P> getPreparer(PreparerContext context) {
        return new Preparer<>(this);
    }

    /**
     * @return a transformer producing, for each example, the leaf indices the trained booster reports for the
     *         example's features
     */
    public PreparedTransformer<int[]> asLeafIDArray() {
        return new XGBoostLeaves(new PreparedTransformerView<>(this), getInput3());
    }

    /**
     * @return a transformer producing a categorical feature vector built from the leaf indices of
     *         {@link #asLeafIDArray()}
     */
    public PreparedTransformer<Vector> asLeafFeatures() {
        return new CategoricalFeatureVector().withInputList(
            new FunctionResult1<int[], IntArrayList>(ArraysEx::asList).withInput(asLeafIDArray()));
    }
}
