package hivemall.smile.tools;

import hivemall.smile.classification.DecisionTree;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.vm.StackMachine;
import hivemall.smile.vm.VMRuntimeException;
import hivemall.utils.codec.Base91;
import hivemall.utils.codec.DeflateCodec;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.ExceptionUtils;
import hivemall.utils.lang.SizeOf;

import java.io.Closeable;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.script.Bindings;
import javax.script.Compilable;
import javax.script.CompiledScript;
import javax.script.ScriptEngine;
import javax.script.ScriptEngineManager;
import javax.script.ScriptException;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;

/**
 * Hive UDF {@code tree_predict_v1}: applies a single tree model to a feature vector and returns
 * its prediction.
 *
 * <p>The {@code modelType} argument (see {@link ModelType}) selects how the {@code script}
 * argument is interpreted:
 * <ul>
 * <li>serialization: Base91-encoded serialized {@code DecisionTree}/{@code RegressionTree}
 * nodes;</li>
 * <li>opscode: a {@link StackMachine} program;</li>
 * <li>javascript: a JavaScript program evaluated with the {@code features} bound to {@code x}.</li>
 * </ul>
 * Negative type ids denote the compressed variants. For opscode and javascript these are
 * Base91-encoded, Deflate-compressed scripts.
 *
 * <p>The result is an {@code int} when the optional constant 5th argument
 * ({@code classification}) is {@code true}, and a {@code double} otherwise. JavaScript
 * evaluation is disabled when the job configuration contains {@code td.jar.version}.
 */
@UDFType(deterministic = true, stateful = false)
@Description(name = "tree_predict_v1", value = "_FUNC_(string modelId, int modelType, string script, array<double> features [, const boolean classification]) - Returns a prediction result of a random forest")
@Deprecated
public final class TreePredictUDFv1 extends GenericUDF {

    private boolean classification;
    private PrimitiveObjectInspector modelTypeOI;
    private StringObjectInspector stringOI;
    private ListObjectInspector featureListOI;
    private PrimitiveObjectInspector featureElemOI;

    /** Lazily created evaluator for the model family seen most recently. */
    @Nullable
    private transient Evaluator evaluator;
    /** {@code |ModelType.getId()|} of the family {@link #evaluator} was created for. */
    private transient int evaluatorKind;
    private boolean supportJavascriptEval = true;

    @Override
    public void configure(MapredContext mapredContext) {
        super.configure(mapredContext);
        // JavaScript evaluation is switched off when running in the Treasure Data environment,
        // which is detected by the presence of "td.jar.version" in the job configuration.
        if (mapredContext != null && mapredContext.getJobConf().get("td.jar.version") != null) {
            this.supportJavascriptEval = false;
        }
    }

    @Override
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 4 && argOIs.length != 5) {
            throw new UDFArgumentException("tree_predict_v1 takes 4 or 5 arguments");
        }
        this.modelTypeOI = HiveUtils.asIntegerOI(argOIs[1]);
        this.stringOI = HiveUtils.asStringOI(argOIs[2]);
        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[3]);
        this.featureListOI = listOI;
        this.featureElemOI = HiveUtils.asDoubleCompatibleOI(listOI.getListElementObjectInspector());

        boolean isClassification = false;
        if (argOIs.length == 5) {
            isClassification = HiveUtils.getConstBoolean(argOIs[4]);
        }
        this.classification = isClassification;

        return isClassification ? PrimitiveObjectInspectorFactory.writableIntObjectInspector
                : PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
    }

    /**
     * Predicts one row.
     *
     * @return an {@link IntWritable} (classification) or {@link DoubleWritable} (regression), or
     *         {@code null} when {@code script} is null or the model yields no result
     * @throws HiveException if {@code modelId}, {@code modelType} or {@code features} is null, or
     *         if decoding, compiling or evaluating the model fails
     */
    @Override
    public Writable evaluate(@Nonnull GenericUDF.DeferredObject[] arguments) throws HiveException {
        Object modelIdObj = arguments[0].get();
        if (modelIdObj == null) {
            throw new HiveException("ModelId was null");
        }
        String modelId = modelIdObj.toString();

        Object modelTypeObj = arguments[1].get();
        if (modelTypeObj == null) {
            throw new HiveException("ModelType was null");
        }
        ModelType modelType =
                ModelType.resolve(PrimitiveObjectInspectorUtils.getInt(modelTypeObj, modelTypeOI));

        Object scriptObj = arguments[2].get();
        if (scriptObj == null) {
            return null;
        }
        Text script = stringOI.getPrimitiveWritableObject(scriptObj);

        Object featuresObj = arguments[3].get();
        if (featuresObj == null) {
            throw new HiveException("array<double> features was null");
        }
        double[] features = HiveUtils.asDoubleArray(featuresObj, featureListOI, featureElemOI);

        Evaluator current = getOrCreateEvaluator(modelType);
        return current.evaluate(modelId, modelType.isCompressed(), script, features,
            classification);
    }

    /**
     * Returns the cached evaluator. It is replaced when {@code modelType} belongs to a
     * different model family than the one it was created for.
     */
    @Nonnull
    private Evaluator getOrCreateEvaluator(@Nonnull ModelType modelType)
            throws UDFArgumentException {
        // Compressed and uncompressed variants share one evaluator (the compression flag is
        // passed per call); their ids differ only in sign.
        int kind = Math.abs(modelType.getId());
        Evaluator current = this.evaluator;
        if (current == null || kind != evaluatorKind) {
            IOUtils.closeQuietly(current);
            this.evaluator = null;
            current = getEvaluator(modelType, supportJavascriptEval);
            this.evaluator = current;
            this.evaluatorKind = kind;
        }
        return current;
    }

    @Nonnull
    private static Evaluator getEvaluator(@Nonnull ModelType modelType,
            boolean supportJavascriptEval) throws UDFArgumentException {
        switch (modelType) {
            case serialization:
            case serialization_compressed:
                return new JavaSerializationEvaluator();
            case opscode:
            case opscode_compressed:
                return new StackmachineEvaluator();
            case javascript:
            case javascript_compressed:
                if (!supportJavascriptEval) {
                    throw new UDFArgumentException(
                        "Javascript evaluation is not allowed in Treasure Data env");
                }
                return new JavascriptEvaluator();
            default:
                throw new UDFArgumentException("Unexpected model type was detected: " + modelType);
        }
    }

    @Override
    public void close() throws IOException {
        this.modelTypeOI = null;
        this.stringOI = null;
        this.featureElemOI = null;
        this.featureListOI = null;
        IOUtils.closeQuietly(evaluator);
        this.evaluator = null;
    }

    @Override
    public String getDisplayString(String[] children) {
        return "tree_predict_v1(" + Arrays.toString(children) + ")";
    }

    /**
     * Base91-decodes and then inflates a compressed model script.
     *
     * @throws HiveException if decompression fails
     */
    @Nonnull
    private static String decompressScript(@Nonnull DeflateCodec codec, @Nonnull Text script)
            throws HiveException {
        try {
            byte[] deflated = Base91.decode(script.getBytes(), 0, script.getLength());
            // Decode as UTF-8 to match Text#toString() on the uncompressed path, instead of
            // relying on the platform default charset.
            return new String(codec.decompress(deflated), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new HiveException("decompression failed", e);
        }
    }

    /** Evaluates one model against one feature vector. */
    public interface Evaluator extends Closeable {

        /**
         * @param modelId model identifier; implementations cache the decoded model per id
         * @param compressed whether {@code script} is a compressed encoding
         * @param script the encoded model
         * @param features the feature vector
         * @param classification {@code true} to return an {@link IntWritable}, otherwise a
         *        {@link DoubleWritable}
         * @return the prediction, or {@code null} when the model produced no result
         */
        @Nullable
        Writable evaluate(@Nonnull String modelId, boolean compressed, @Nonnull Text script,
                @Nonnull double[] features, boolean classification) throws HiveException;
    }

    /**
     * Evaluates Base91-encoded, serialized {@code DecisionTree} / {@code RegressionTree} nodes.
     */
    static final class JavaSerializationEvaluator implements Evaluator {

        @Nullable
        private String prevModelId = null;
        @Nullable
        private DecisionTree.Node cNode = null;
        @Nullable
        private RegressionTree.Node rNode = null;

        JavaSerializationEvaluator() {}

        @Override
        public Writable evaluate(@Nonnull String modelId, boolean compressed,
                @Nonnull Text script, double[] features, boolean classification)
                throws HiveException {
            return classification ? evaluateClassification(modelId, compressed, script, features)
                    : evaluateRegression(modelId, compressed, script, features);
        }

        private IntWritable evaluateClassification(@Nonnull String modelId, boolean compressed,
                @Nonnull Text script, double[] features) throws HiveException {
            if (!modelId.equals(prevModelId)) {
                byte[] serialized = Base91.decode(script.getBytes(), 0, script.getLength());
                this.cNode = DecisionTree.deserialize(serialized, serialized.length, compressed);
                // Record the id only after deserialization succeeded, so a failure is retried
                // instead of being treated as the cached model.
                this.prevModelId = modelId;
            }
            assert cNode != null;
            return new IntWritable(cNode.predict(features));
        }

        private DoubleWritable evaluateRegression(@Nonnull String modelId, boolean compressed,
                @Nonnull Text script, double[] features) throws HiveException {
            if (!modelId.equals(prevModelId)) {
                byte[] serialized = Base91.decode(script.getBytes(), 0, script.getLength());
                this.rNode = RegressionTree.deserialize(serialized, serialized.length, compressed);
                this.prevModelId = modelId; // only after a successful deserialization
            }
            assert rNode != null;
            return new DoubleWritable(rNode.predict(features));
        }

        @Override
        public void close() throws IOException {}
    }

    /** Compiles and runs JavaScript models, caching the compiled script per model id. */
    static final class JavascriptEvaluator implements Evaluator {

        private final ScriptEngine scriptEngine;
        private final Compilable compilableEngine;
        @Nullable
        private String prevModelId = null;
        @Nullable
        private CompiledScript prevCompiled = null;
        @Nullable
        private DeflateCodec codec = null;

        JavascriptEvaluator() throws UDFArgumentException {
            ScriptEngine engine = new ScriptEngineManager().getEngineByExtension("js");
            if (engine == null) {
                // e.g. JDKs that no longer bundle a JavaScript engine
                throw new UDFArgumentException("ScriptEngine for JavaScript was not found");
            }
            if (!(engine instanceof Compilable)) {
                throw new UDFArgumentException("ScriptEngine was not compilable: "
                        + engine.getFactory().getEngineName() + " version "
                        + engine.getFactory().getEngineVersion());
            }
            this.scriptEngine = engine;
            this.compilableEngine = (Compilable) engine;
        }

        @Override
        public Writable evaluate(@Nonnull String modelId, boolean compressed,
                @Nonnull Text script, double[] features, boolean classification)
                throws HiveException {
            final CompiledScript compiled;
            if (modelId.equals(prevModelId)) {
                compiled = prevCompiled;
            } else {
                // Decode/decompress only on a cache miss; cached rows reuse the compiled script.
                String source = decode(compressed, script);
                try {
                    compiled = compilableEngine.compile(source);
                } catch (ScriptException e) {
                    throw new HiveException("failed to compile: \n" + script, e);
                }
                this.prevCompiled = compiled;
                this.prevModelId = modelId;
            }

            Bindings bindings = scriptEngine.createBindings();
            final Object result;
            try {
                bindings.put("x", features);
                result = compiled.eval(bindings);
            } catch (Throwable e) {
                throw new HiveException("failed to evaluate: \n" + script, e);
            } finally {
                bindings.clear();
            }

            if (result == null) {
                return null;
            }
            if (!(result instanceof Number)) {
                throw new HiveException("Got an unexpected non-number result: " + result);
            }
            Number number = (Number) result;
            return classification ? new IntWritable(number.intValue())
                    : new DoubleWritable(number.doubleValue());
        }

        @Nonnull
        private String decode(boolean compressed, @Nonnull Text script) throws HiveException {
            if (!compressed) {
                return script.toString();
            }
            if (codec == null) {
                this.codec = new DeflateCodec(false, true);
            }
            return decompressScript(codec, script);
        }

        @Override
        public void close() throws IOException {
            IOUtils.closeQuietly(codec);
        }
    }

    /**
     * Encodings accepted in the {@code modelType} argument. A negative id is the compressed
     * variant of the positive id with the same absolute value.
     */
    enum ModelType {
        opscode(1, false),
        javascript(2, false),
        serialization(3, false),
        opscode_compressed(-1, true),
        javascript_compressed(-2, true),
        serialization_compressed(-3, true);

        private final int id;
        private final boolean compressed;

        ModelType(int id, boolean compressed) {
            this.id = id;
            this.compressed = compressed;
        }

        int getId() {
            return id;
        }

        boolean isCompressed() {
            return compressed;
        }

        /**
         * @throws IllegalStateException if {@code id} does not denote a model type
         */
        @Nonnull
        static ModelType resolve(int id) {
            switch (id) {
                case 1:
                    return opscode;
                case 2:
                    return javascript;
                case 3:
                    return serialization;
                case -1:
                    return opscode_compressed;
                case -2:
                    return javascript_compressed;
                case -3:
                    return serialization_compressed;
                default:
                    throw new IllegalStateException("Unexpected ID for ModelType: " + id);
            }
        }
    }

    /** Compiles and runs {@link StackMachine} programs, caching the machine per model id. */
    static final class StackmachineEvaluator implements Evaluator {

        @Nullable
        private String prevModelId = null;
        @Nullable
        private StackMachine prevVM = null;
        @Nullable
        private DeflateCodec codec = null;

        StackmachineEvaluator() {}

        @Override
        public Writable evaluate(@Nonnull String modelId, boolean compressed,
                @Nonnull Text script, double[] features, boolean classification)
                throws HiveException {
            final StackMachine vm;
            if (modelId.equals(prevModelId)) {
                vm = prevVM;
            } else {
                // Decode/decompress only on a cache miss; cached rows reuse the compiled VM.
                String source = decode(compressed, script);
                vm = new StackMachine();
                try {
                    vm.compile(source);
                } catch (VMRuntimeException e) {
                    throw new HiveException("failed to compile StackMachine", e);
                }
                this.prevModelId = modelId;
                this.prevVM = vm;
            }

            final Double result;
            try {
                vm.eval(features);
                result = vm.getResult();
            } catch (Throwable e) {
                throw new HiveException("failed to eval StackMachine", e);
            }
            if (result == null) {
                return null;
            }
            return classification ? new IntWritable(result.intValue())
                    : new DoubleWritable(result.doubleValue());
        }

        @Nonnull
        private String decode(boolean compressed, @Nonnull Text script) throws HiveException {
            if (!compressed) {
                return script.toString();
            }
            if (codec == null) {
                this.codec = new DeflateCodec(false, true);
            }
            return decompressScript(codec, script);
        }

        @Override
        public void close() throws IOException {
            IOUtils.closeQuietly(codec);
        }
    }
}
