package org.apache.lens.ml;

import java.io.IOException;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.lazy.LazyDouble;
import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyDoubleObjectInspector;
import org.apache.hadoop.hive.serde2.lazy.objectinspector.primitive.LazyPrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.mapred.JobConf;

@Description(name = HiveMLUDF.UDF_NAME, value = "_FUNC_(algorithm, modelID, features...) - Run prediction algorithm with given algorithm name, model ID and input feature columns")
/* loaded from: input_file:org/apache/lens/ml/HiveMLUDF.class */
public final class HiveMLUDF extends GenericUDF {
    public static final String UDF_NAME = "predict";
    public static final Log LOG = LogFactory.getLog(HiveMLUDF.class);
    private JobConf conf;
    private StringObjectInspector soi;
    private LazyDoubleObjectInspector doi;
    private MLModel model;

    private HiveMLUDF() {
    }

    public ObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        if (objectInspectorArr.length < 3) {
            throw new UDFArgumentLengthException("Algo name, model ID and at least one feature should be passed to predict");
        }
        LOG.info("predict initialized");
        return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
    }

    public Object evaluate(GenericUDF.DeferredObject[] deferredObjectArr) throws HiveException {
        String primitiveJavaObject = this.soi.getPrimitiveJavaObject(deferredObjectArr[0].get());
        String primitiveJavaObject2 = this.soi.getPrimitiveJavaObject(deferredObjectArr[1].get());
        Double[] dArr = new Double[deferredObjectArr.length - 2];
        for (int i = 2; i < deferredObjectArr.length; i++) {
            LazyDouble lazyDouble = (LazyDouble) deferredObjectArr[i].get();
            dArr[i - 2] = Double.valueOf(lazyDouble == null ? 0.0d : this.doi.get(lazyDouble));
        }
        try {
            if (this.model == null) {
                this.model = ModelLoader.loadModel(this.conf, primitiveJavaObject, primitiveJavaObject2);
            }
            return this.model.predict(dArr);
        } catch (IOException e) {
            throw new HiveException(e);
        }
    }

    public String getDisplayString(String[] strArr) {
        return UDF_NAME;
    }

    public void configure(MapredContext mapredContext) {
        super.configure(mapredContext);
        this.conf = mapredContext.getJobConf();
        this.soi = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        this.doi = LazyPrimitiveObjectInspectorFactory.LAZY_DOUBLE_OBJECT_INSPECTOR;
        LOG.info("predict configured. Model base dir path: " + this.conf.get(ModelLoader.MODEL_PATH_BASE_DIR));
    }
}
