package hivemall.smile.tools;

import hivemall.UDFWithOptions;
import hivemall.math.vector.DenseVector;
import hivemall.math.vector.SparseVector;
import hivemall.math.vector.Vector;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.classification.PredictionHandler;
import hivemall.smile.regression.RegressionTree;
import hivemall.utils.codec.Base91;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Preconditions;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

@UDFType(deterministic = true, stateful = false)
@Description(name = "tree_predict", value = "_FUNC_(string modelId, string model, array<double|string> features [, const string options | const boolean classification=false]) - Returns a prediction result of a random forest in <int value, array<double> a posteriori> for classification and <double> for regression")
/* loaded from: input_file:hivemall/smile/tools/TreePredictUDF.class */
public final class TreePredictUDF extends UDFWithOptions {
    /**
     * Whether predictions are produced as classification results. Set in {@code initialize} from the
     * optional fourth argument (a constant boolean, or an option string containing
     * {@code classification}); false when only three arguments are given.
     */
    private boolean classification;
    /** Inspector for the serialized model argument (argument index 1). */
    private StringObjectInspector modelOI;
    /** Inspector for the feature list argument (argument index 2). */
    private ListObjectInspector featureListOI;
    /**
     * Inspector for the elements of the feature list: a double-compatible inspector for numeric
     * elements, a string inspector otherwise.
     */
    private PrimitiveObjectInspector featureElemOI;
    /** True when the feature list holds numeric elements (dense input), false for string elements. */
    private boolean denseInput;

    /** Feature vector handed back into {@code parseFeatures} on every {@code evaluate} call for reuse. */
    @Nullable
    private Vector featuresProbe;

    /** Created lazily on the first {@code evaluate} call according to {@link #classification}. */
    @Nullable
    private transient Evaluator evaluator;

    /* loaded from: input_file:hivemall/smile/tools/TreePredictUDF$ClassificationEvaluator.class */
    /**
     * Evaluator used when the UDF runs in classification mode.
     *
     * <p>The deserialized tree is cached and only reloaded when the model id passed to
     * {@link #evaluate} differs from the previously loaded one. The returned array is a single
     * instance reused across calls.
     */
    static final class ClassificationEvaluator implements Evaluator {

        /** Id of the model whose tree is cached in {@link #cNode}; null until a model was loaded. */
        @Nullable
        private String prevModelId = null;
        /** Root node of the most recently deserialized decision tree. */
        private DecisionTree.Node cNode = null;

        /** Reused output holder; slot 0 receives an IntWritable, slot 1 a writable list of doubles. */
        @Nonnull
        private final Object[] result = new Object[2];

        ClassificationEvaluator() {
        }

        /**
         * Predicts with the decision tree identified by {@code str}.
         *
         * @param str model id used as the cache key for the deserialized tree
         * @param text Base91-encoded serialized tree
         * @param vector feature vector to predict for
         * @return the reused two-element array filled by the prediction handler
         * @throws HiveException declared for failures while loading the model
         */
        @Override // hivemall.smile.tools.TreePredictUDF.Evaluator
        @Nonnull
        public Object[] evaluate(@Nonnull String str, @Nonnull Text text, @Nonnull Vector vector) throws HiveException {
            if (!str.equals(this.prevModelId)) {
                final byte[] serialized = Base91.decode(text.getBytes(), 0, text.getLength());
                final DecisionTree.Node root = DecisionTree.deserialize(serialized, serialized.length, true);
                // Commit the tree and its id together, and only after decoding succeeded. Recording
                // the id first would let a retry with the same id skip loading and silently predict
                // with the previously cached tree of a different model.
                this.cNode = root;
                this.prevModelId = str;
            }
            // Drop the previous prediction so stale values never leak into this result.
            Arrays.fill(this.result, null);
            Preconditions.checkNotNull(this.cNode);
            this.cNode.predict(vector, new PredictionHandler() {
                @Override // hivemall.smile.classification.PredictionHandler
                public void handle(int i, double[] dArr) {
                    result[0] = new IntWritable(i);
                    result[1] = WritableUtils.toWritableList(dArr);
                }
            });
            return this.result;
        }
    }

    /* loaded from: input_file:hivemall/smile/tools/TreePredictUDF$Evaluator.class */
    /**
     * Strategy for producing a prediction from a serialized tree model and a feature vector.
     * Implemented in this file by ClassificationEvaluator and RegressionEvaluator.
     */
    interface Evaluator {
        /**
         * Produces the prediction for one input row.
         *
         * @param str model id; the implementations in this file use it to decide whether the model
         *        has to be deserialized again
         * @param text serialized model
         * @param vector feature vector to predict for
         * @return the prediction result; its concrete type depends on the implementation
         * @throws HiveException declared for failures during evaluation
         */
        @Nonnull
        Object evaluate(@Nonnull String str, @Nonnull Text text, @Nonnull Vector vector) throws HiveException;
    }

    /* loaded from: input_file:hivemall/smile/tools/TreePredictUDF$RegressionEvaluator.class */
    /**
     * Evaluator used when the UDF runs in regression mode.
     *
     * <p>The deserialized tree is cached and only reloaded when the model id passed to
     * {@link #evaluate} differs from the previously loaded one. The returned writable is a single
     * instance reused across calls.
     */
    static final class RegressionEvaluator implements Evaluator {

        /** Id of the model whose tree is cached in {@link #rNode}; null until a model was loaded. */
        @Nullable
        private String prevModelId = null;
        /** Root node of the most recently deserialized regression tree. */
        private RegressionTree.Node rNode = null;

        /** Reused output holder for the predicted value. */
        @Nonnull
        private final DoubleWritable result = new DoubleWritable();

        RegressionEvaluator() {
        }

        /**
         * Predicts with the regression tree identified by {@code str}.
         *
         * @param str model id used as the cache key for the deserialized tree
         * @param text Base91-encoded serialized tree
         * @param vector feature vector to predict for
         * @return the reused writable holding the predicted value
         * @throws HiveException declared for failures while loading the model
         */
        @Override // hivemall.smile.tools.TreePredictUDF.Evaluator
        @Nonnull
        public DoubleWritable evaluate(@Nonnull String str, @Nonnull Text text, @Nonnull Vector vector) throws HiveException {
            if (!str.equals(this.prevModelId)) {
                final byte[] serialized = Base91.decode(text.getBytes(), 0, text.getLength());
                final RegressionTree.Node root = RegressionTree.deserialize(serialized, serialized.length, true);
                // Commit the tree and its id together, and only after decoding succeeded. Recording
                // the id first would let a retry with the same id skip loading and silently predict
                // with the previously cached tree of a different model.
                this.rNode = root;
                this.prevModelId = str;
            }
            Preconditions.checkNotNull(this.rNode);
            this.result.set(this.rNode.predict(vector));
            return this.result;
        }
    }

    /**
     * Declares the command line options understood by this UDF: a single flag
     * {@code -c} / {@code --classification} that takes no argument.
     *
     * @return a freshly built option set
     */
    @Override // hivemall.UDFWithOptions
    protected Options getOptions() {
        return new Options().addOption("c", "classification", false,
            "Predict as classification [default: not enabled]");
    }

    /**
     * Parses the option string and records whether classification mode was requested.
     *
     * @param str raw option string passed as the fourth UDF argument
     * @return the parsed command line
     * @throws UDFArgumentException declared for failures while parsing the options
     */
    @Override // hivemall.UDFWithOptions
    protected CommandLine processOptions(@Nonnull String str) throws UDFArgumentException {
        final CommandLine cl = parseOptions(str);
        final boolean classificationRequested = cl.hasOption("classification");
        this.classification = classificationRequested;
        return cl;
    }

    /**
     * Validates the argument inspectors and decides the output type.
     *
     * <p>Expected arguments: {@code (modelId, model, features [, options | classification])}.
     * The model (index 1) must be a string; the features (index 2) must be a list of numbers
     * (dense input) or a list of strings (sparse input). The optional fourth argument is either a
     * constant boolean enabling classification or a constant option string.
     *
     * @param objectInspectorArr inspectors of the call arguments
     * @return a writable double inspector for regression, or a struct
     *         {@code <value:int, posteriori:array<double>>} inspector for classification
     * @throws UDFArgumentException if the argument count or any argument type is not accepted
     */
    public ObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        final int numArgs = objectInspectorArr.length;
        if (numArgs != 3 && numArgs != 4) {
            throw new UDFArgumentException("tree_predict takes 3 or 4 arguments");
        }

        this.modelOI = HiveUtils.asStringOI(objectInspectorArr[1]);

        final ListObjectInspector listOI = HiveUtils.asListOI(objectInspectorArr[2]);
        this.featureListOI = listOI;
        final ObjectInspector elemOI = listOI.getListElementObjectInspector();
        if (HiveUtils.isNumberOI(elemOI)) {
            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
            this.denseInput = true;
        } else if (HiveUtils.isStringOI(elemOI)) {
            this.featureElemOI = HiveUtils.asStringOI(elemOI);
            this.denseInput = false;
        } else {
            // The features array is the third argument (index 2) of tree_predict.
            throw new UDFArgumentException("tree_predict takes array<double> or array<string> for the third argument: " + listOI.getTypeName());
        }

        if (numArgs == 4) {
            final ObjectInspector optionOI = objectInspectorArr[3];
            if (HiveUtils.isConstBoolean(optionOI)) {
                this.classification = HiveUtils.getConstBoolean(optionOI);
            } else if (HiveUtils.isConstString(optionOI)) {
                // processOptions() sets this.classification from the parsed flags.
                processOptions(HiveUtils.getConstString(optionOI));
            } else {
                throw new UDFArgumentException("tree_predict expects <const boolean> or <const string> for the fourth argument: " + optionOI.getTypeName());
            }
        } else {
            this.classification = false;
        }

        if (!this.classification) {
            return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }

        final ArrayList<String> fieldNames = new ArrayList<String>(2);
        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);
        fieldNames.add("value");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("posteriori");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Computes the prediction for one row.
     *
     * @param deferredObjectArr call arguments: model id (index 0), serialized model (index 1) and
     *        feature list (index 2)
     * @return the evaluator's result, or null when the model argument is null
     * @throws HiveException if the model id or the feature list is null, or evaluation fails
     */
    public Object evaluate(@Nonnull GenericUDF.DeferredObject[] deferredObjectArr) throws HiveException {
        final Object modelIdArg = deferredObjectArr[0].get();
        if (modelIdArg == null) {
            throw new HiveException("modelId should not be null");
        }
        final String modelId = modelIdArg.toString();

        final Object modelArg = deferredObjectArr[1].get();
        if (modelArg == null) {
            return null;
        }
        final Text model = this.modelOI.getPrimitiveWritableObject(modelArg);

        final Object featuresArg = deferredObjectArr[2].get();
        if (featuresArg == null) {
            throw new HiveException("features was null");
        }
        // The previous vector is passed back in so parseFeatures can reuse it.
        this.featuresProbe = parseFeatures(featuresArg, this.featuresProbe);

        // The evaluator is created once, on first use, according to the configured mode.
        if (this.evaluator == null) {
            if (this.classification) {
                this.evaluator = new ClassificationEvaluator();
            } else {
                this.evaluator = new RegressionEvaluator();
            }
        }
        return this.evaluator.evaluate(modelId, model, this.featuresProbe);
    }

    /**
     * Converts the feature list argument into a vector, reusing {@code vector} when possible.
     *
     * <p>Dense input (numeric elements): element {@code i} is stored at position {@code i}; null
     * elements are stored as 0.0. A new dense vector is allocated when none is given or when the
     * list length differs from the given vector's size.
     *
     * <p>Sparse input (string elements): each non-null element is either {@code <index>:<value>}
     * or a bare {@code <index>} (value 1.0); null elements are skipped. The given vector is cleared
     * before it is filled.
     *
     * @param obj the feature list object as delivered by Hive
     * @param vector vector from a previous call to reuse, or null
     * @return the filled vector (either {@code vector} or a newly allocated one)
     * @throws UDFArgumentException if a sparse feature is malformed: empty index, more than one
     *         {@code ':'}, a non-numeric index or value, or a negative index
     */
    @Nonnull
    private Vector parseFeatures(@Nonnull Object obj, @Nullable Vector vector) throws UDFArgumentException {
        if (this.denseInput) {
            final int length = this.featureListOI.getListLength(obj);
            if (vector == null || length != vector.size()) {
                vector = new DenseVector(length);
            }
            for (int i = 0; i < length; i++) {
                final Object elem = this.featureListOI.getListElement(obj, i);
                final double value = (elem == null) ? 0.0d
                        : PrimitiveObjectInspectorUtils.getDouble(elem, this.featureElemOI);
                vector.set(i, value);
            }
            return vector;
        }

        if (vector == null) {
            vector = new SparseVector();
        } else {
            vector.clear();
        }
        final int length = this.featureListOI.getListLength(obj);
        for (int i = 0; i < length; i++) {
            final Object elem = this.featureListOI.getListElement(obj, i);
            if (elem == null) {
                continue;
            }
            final String feature = elem.toString();
            final int pos = feature.indexOf(':');
            if (pos == 0) {
                throw new UDFArgumentException("Invalid feature value representation: " + feature);
            }

            final String indexPart;
            final String valuePart; // null means "no explicit value", i.e. 1.0
            if (pos > 0) {
                indexPart = feature.substring(0, pos);
                valuePart = feature.substring(pos + 1);
                // The index part is cut at the first ':' and can never contain one, so a second
                // separator can only show up in the value part.
                if (valuePart.indexOf(':') != -1) {
                    throw new UDFArgumentException("Invalid feature format `<index>:<value>`: " + feature);
                }
            } else {
                indexPart = feature;
                valuePart = null;
            }

            int index;
            double value;
            try {
                value = (valuePart == null) ? 1.0d : Double.parseDouble(valuePart);
                index = Integer.parseInt(indexPart);
            } catch (NumberFormatException e) {
                // Report malformed numbers like every other malformed feature instead of leaking
                // an unchecked exception out of the UDF.
                throw invalidFeatureFormat(feature, e);
            }
            if (index < 0) {
                throw new UDFArgumentException("Col index MUST be greater than or equals to 0: " + index);
            }
            vector.set(index, value);
        }
        return vector;
    }

    /** Builds the exception for a sparse feature whose index or value is not a number, keeping the cause. */
    @Nonnull
    private static UDFArgumentException invalidFeatureFormat(@Nonnull String feature, @Nonnull NumberFormatException cause) {
        final UDFArgumentException ex = new UDFArgumentException("Invalid feature format `<index>:<value>`: " + feature);
        ex.initCause(cause);
        return ex;
    }

    /**
     * Releases the references held by this UDF instance: the argument inspectors, the cached
     * evaluator (and with it the deserialized model) and the reusable feature vector.
     *
     * @throws IOException declared by the signature; nothing in this body throws it
     */
    public void close() throws IOException {
        this.modelOI = null;
        this.featureElemOI = null;
        this.featureListOI = null;
        this.evaluator = null;
        // The feature buffer is per-instance state as well; parseFeatures() allocates a new one
        // when it receives null, so dropping it here is safe.
        this.featuresProbe = null;
    }

    /**
     * Renders this function call for display.
     *
     * @param strArr display strings of the call arguments
     * @return {@code tree_predict(} followed by the array rendering of the arguments and {@code )}
     */
    public String getDisplayString(String[] strArr) {
        final StringBuilder display = new StringBuilder("tree_predict(");
        display.append(Arrays.toString(strArr));
        display.append(")");
        return display.toString();
    }
}
