package hivemall.xgboost;

import hivemall.UDTFWithOptions;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import ml.dmlc.xgboost4j.LabeledPoint;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

/**
 * Base class for Hive UDTFs that train an XGBoost model.
 *
 * <p>Arguments: a feature list ({@code array<string>}), a target convertible to double and,
 * optionally, a constant option string as the third argument. {@link #process(Object[])}
 * buffers one training row per input row (rows whose features
 * {@code XGBoostUtils.parseFeatures} returns {@code null} for are skipped). {@link #close()}
 * trains a booster on all buffered rows and forwards a single row
 * {@code (model_id string, pred_model binary)}.
 *
 * <p>Subclasses can adjust {@link #params}, which is passed verbatim to the booster, and can
 * reject targets by overriding {@link #checkTargetValue(double)}.
 */
public abstract class XGBoostUDTF extends UDTFWithOptions {
    private static final Log logger = LogFactory.getLog(XGBoostUDTF.class);

    // The parameter key is always identical to the option name, so a misspelled key can no
    // longer silently drop a user-supplied value.
    /** Options copied into {@link #params} as String values. */
    private static final String[] STRING_OPTIONS = {"booster", "eval_metric"};
    /** Options parsed into {@link #params} as Integer values. */
    private static final String[] INT_OPTIONS = {"num_round", "silent", "nthread",
            "num_pbuffer", "num_feature", "max_depth", "min_child_weight", "max_delta_step"};
    /** Options parsed into {@link #params} as Double values. */
    private static final String[] DOUBLE_OPTIONS = {"alpha", "lambda", "eta", "gamma",
            "subsample", "colsample_bytree", "colsample_bylevel", "lambda_bias", "base_score"};

    /** Training rows buffered by {@link #process(Object[])} until {@link #close()}. */
    private final List<LabeledPoint> featuresList;
    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;
    private PrimitiveObjectInspector targetOI;

    /** Booster parameters: defaults from the constructor, overridden by user options. */
    @Nonnull
    protected final Map<String, Object> params = new HashMap<String, Object>();

    public XGBoostUDTF() {
        this.params.put("booster", "gbtree");
        this.params.put("num_round", 8);
        this.params.put("silent", 1);
        this.params.put("nthread", 1);
        this.params.put("alpha", Double.valueOf(0.0d));
        this.params.put("eta", Double.valueOf(0.3d));
        this.params.put("gamma", Double.valueOf(0.0d));
        this.params.put("max_depth", 6);
        this.params.put("min_child_weight", 1);
        this.params.put("max_delta_step", 0);
        this.params.put("subsample", Double.valueOf(1.0d));
        this.params.put("colsample_bytree", Double.valueOf(1.0d));
        this.params.put("colsample_bylevel", Double.valueOf(1.0d));
        this.params.put("tree_method", "exact");
        this.params.put("base_score", Double.valueOf(0.5d));
        this.featuresList = new ArrayList<LabeledPoint>(1024);
    }

    /**
     * Declares the command-line style options accepted in the third UDTF argument.
     *
     * @return the supported options
     */
    public Options getOptions() {
        Options options = new Options();
        options.addOption("booster", true, "Set a booster to use, gbtree or gblinear. [default: gbtree]");
        options.addOption("num_round", true, "Number of boosting iterations [default: 8]");
        options.addOption("silent", true, "0 means printing running messages, 1 means silent mode [default: 1]");
        options.addOption("nthread", true, "Number of parallel threads used to run xgboost [default: 1]");
        options.addOption("num_pbuffer", true, "Size of prediction buffer [set automatically by xgboost]");
        options.addOption("num_feature", true, "Feature dimension used in boosting [default: set automatically by xgboost]");
        options.addOption("alpha", true, "L1 regularization term on weights [default: 0.0]");
        options.addOption("lambda", true, "L2 regularization term on weights [default: 1.0 for gbtree, 0.0 for gblinear]");
        options.addOption("eta", true, "Step size shrinkage used in update to prevents overfitting [default: 0.3]");
        options.addOption("gamma", true, "Minimum loss reduction required to make a further partition on a leaf node of the tree [default: 0.0]");
        options.addOption("max_depth", true, "Max depth of decision tree [default: 6]");
        options.addOption("min_child_weight", true, "Minimum sum of instance weight(hessian) needed in a child [default: 1]");
        options.addOption("max_delta_step", true, "Maximum delta step we allow each tree's weight estimation to be [default: 0]");
        options.addOption("subsample", true, "Subsample ratio of the training instance [default: 1.0]");
        options.addOption("colsample_bytree", true, "Subsample ratio of columns when constructing each tree [default: 1.0]");
        options.addOption("colsample_bylevel", true, "Subsample ratio of columns for each split, in each level [default: 1.0]");
        options.addOption("lambda_bias", true, "L2 regularization term on bias [default: 0.0]");
        options.addOption("base_score", true, "Initial prediction score of all instances, global bias [default: 0.5]");
        options.addOption("eval_metric", true, "Evaluation metrics for validation data [default according to objective]");
        return options;
    }

    /**
     * Parses the optional third argument into {@link #params} and checks that a booster can be
     * created with the resulting parameters.
     *
     * @param objectInspectorArr the UDTF argument inspectors
     * @return the parsed command line, or {@code null} when no option argument was given
     * @throws UDFArgumentException if an option value is malformed or booster creation fails
     */
    public CommandLine processOptions(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        CommandLine commandLine = null;
        if (objectInspectorArr.length >= 3) {
            commandLine = parseOptions(HiveUtils.getConstString(objectInspectorArr[2]));
            copyOptionsToParams(commandLine);
        }
        validateParams();
        return commandLine;
    }

    /** Copies every option present on {@code cl} into {@link #params} with its typed value. */
    private void copyOptionsToParams(@Nonnull CommandLine cl) throws UDFArgumentException {
        for (String name : STRING_OPTIONS) {
            if (cl.hasOption(name)) {
                this.params.put(name, cl.getOptionValue(name));
            }
        }
        for (String name : INT_OPTIONS) {
            if (cl.hasOption(name)) {
                final String value = cl.getOptionValue(name);
                try {
                    this.params.put(name, Integer.valueOf(value));
                } catch (NumberFormatException e) {
                    throw new UDFArgumentException("Option -" + name + " expects an integer but got: " + value);
                }
            }
        }
        for (String name : DOUBLE_OPTIONS) {
            if (cl.hasOption(name)) {
                final String value = cl.getOptionValue(name);
                try {
                    this.params.put(name, Double.valueOf(value));
                } catch (NumberFormatException e) {
                    throw new UDFArgumentException("Option -" + name + " expects a number but got: " + value);
                }
            }
        }
    }

    /**
     * Builds, and immediately frees, a booster over the (normally still empty) row buffer.
     * Failures to create a booster with the configured params therefore surface here as a
     * UDFArgumentException rather than only in close().
     */
    private void validateParams() throws UDFArgumentException {
        DMatrix probeData = null;
        Booster probe = null;
        try {
            probeData = new DMatrix(this.featuresList.iterator(), "");
            probe = createXGBooster(this.params, probeData);
        } catch (Exception e) {
            throw new UDFArgumentException(e);
        } finally {
            releaseNativeResources(probe, probeData);
        }
    }

    /** Output schema: {@code model_id string, pred_model binary}. */
    @Nonnull
    private static StructObjectInspector getReturnOIs() {
        final List<String> fieldNames = new ArrayList<String>(2);
        final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);
        fieldNames.add("model_id");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        fieldNames.add("pred_model");
        fieldOIs.add(PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Validates the arguments, processes options and captures the input inspectors.
     *
     * @param objectInspectorArr features ({@code array<string>}), target, optional options string
     * @return the output row inspector
     * @throws UDFArgumentException on a wrong argument count, bad types or invalid options
     */
    public StructObjectInspector initialize(@Nonnull ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        if (objectInspectorArr.length < 2) {
            throw new UDFArgumentException(
                "Expected at least 2 arguments (array<string> features, target [, const string options]) but got "
                        + objectInspectorArr.length);
        }
        processOptions(objectInspectorArr);
        final ListObjectInspector listOI = HiveUtils.asListOI(objectInspectorArr[0]);
        this.featureListOI = listOI;
        this.featureElemOI = HiveUtils.asStringOI(listOI.getListElementObjectInspector());
        this.targetOI = HiveUtils.asDoubleCompatibleOI(objectInspectorArr[1]);
        return getReturnOIs();
    }

    /**
     * Hook for subclasses to reject target values; the default accepts every value.
     *
     * @param d the target of the current row
     * @throws HiveException if the target is not acceptable
     */
    protected void checkTargetValue(double d) throws HiveException {
    }

    /**
     * Buffers one training row. Rows with a null feature list are ignored, as are rows for which
     * {@code XGBoostUtils.parseFeatures} returns null.
     */
    public void process(@Nonnull Object[] objArr) throws HiveException {
        if (objArr[0] == null) {
            return;
        }
        final List<?> features = this.featureListOI.getList(objArr[0]);
        final int size = features.size();
        final String[] featureStrs = new String[size];
        for (int i = 0; i < size; i++) {
            featureStrs[i] = (String) this.featureElemOI.getPrimitiveJavaObject(features.get(i));
        }
        final double target = PrimitiveObjectInspectorUtils.getDouble(objArr[1], this.targetOI);
        checkTargetValue(target);
        final LabeledPoint point = XGBoostUtils.parseFeatures(target, featureStrs);
        if (point != null) {
            this.featuresList.add(point);
        }
    }

    @Nonnull
    private static String generateUniqueModelId() {
        return "xgbmodel-" + HadoopUtils.getUniqueTaskIdString();
    }

    /**
     * Instantiates a Booster through its non-public (Map, DMatrix[]) constructor, registering
     * {@code dtrain} as its single cache matrix. The caller owns and must dispose both objects.
     */
    @Nonnull
    private static Booster createXGBooster(@Nonnull Map<String, Object> params, @Nonnull DMatrix dtrain)
            throws NoSuchMethodException, IllegalAccessException, InvocationTargetException,
            InstantiationException {
        final Constructor<Booster> ctor = Booster.class.getDeclaredConstructor(Map.class, DMatrix[].class);
        ctor.setAccessible(true);
        return ctor.newInstance(params, new DMatrix[] {dtrain});
    }

    /** Frees the native memory of a booster and its matrix; either argument may be null. */
    private static void releaseNativeResources(Booster booster, DMatrix data) {
        if (booster != null) {
            booster.dispose();
        }
        if (data != null) {
            data.dispose();
        }
    }

    /**
     * Trains the booster for {@code num_round} iterations on all buffered rows and forwards
     * {@code (model_id, serialized model)}.
     */
    public void close() throws HiveException {
        DMatrix trainData = null;
        Booster booster = null;
        try {
            // One matrix serves both as the booster's cache and as the training input; building
            // a second copy from the same rows would double the native memory footprint.
            trainData = new DMatrix(this.featuresList.iterator(), "");
            booster = createXGBooster(this.params, trainData);
            final int numRound = ((Integer) this.params.get("num_round")).intValue();
            for (int i = 0; i < numRound; i++) {
                booster.update(trainData, i);
            }
            final String modelId = generateUniqueModelId();
            final byte[] predModel = booster.toByteArray();
            logger.info("model_id:" + modelId + " size:" + predModel.length);
            forward(new Object[] {modelId, predModel});
        } catch (HiveException e) {
            throw e;
        } catch (Exception e) {
            throw new HiveException(e);
        } finally {
            releaseNativeResources(booster, trainData);
        }
    }

    static {
        NativeLibLoader.initXGBoost();
    }
}
