package hivemall.smile.tools;

import hivemall.UDFWithOptions;
import hivemall.smile.classification.DecisionTree;
import hivemall.smile.regression.RegressionTree;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.utils.codec.Base91;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.SizeOf;
import hivemall.utils.lang.mutable.MutableInt;
import java.util.Arrays;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.Text;

/**
 * Hive UDF {@code tree_export}: renders a serialized tree model as JavaScript or Graphviz DOT text.
 *
 * <p>Arguments, in order:
 * <ol>
 * <li>the model as a Base91-encoded string,</li>
 * <li>a constant option string (see {@link #getOptions()}),</li>
 * <li>optional feature names as {@code array<string>},</li>
 * <li>optional class names as {@code array<string>}.</li>
 * </ol>
 */
@UDFType(deterministic = true, stateful = false)
@Description(name = "tree_export", value = "_FUNC_(string model, const string options, optional array<string> featureNames=null, optional array<string> classNames=null) - exports a Decision Tree model as javascript/dot]")
public final class TreeExportUDF extends UDFWithOptions {
    /** Output type used when the {@code -type} option is omitted, as advertised in the option help. */
    private static final String DEFAULT_OUTPUT_TYPE = "js";

    /** Output name used when the {@code -output_name} option is omitted. */
    private static final String DEFAULT_OUTPUT_NAME = "predicted";

    /** Exporter configured from the option string; assigned in {@link #processOptions(String)}. */
    private transient Evaluator evaluator;

    /** Inspector for the first (model) argument; assigned in {@link #initialize(ObjectInspector[])}. */
    private transient StringObjectInspector modelOI;

    /** Inspector for the feature-names argument; stays {@code null} when fewer than 3 arguments are given. */
    @Nullable
    private transient ListObjectInspector featureNamesOI;

    /** Inspector for the class-names argument; stays {@code null} when fewer than 4 arguments are given. */
    @Nullable
    private transient ListObjectInspector classNamesOI;

    /**
     * Decodes a Base91 model string and exports it in the configured output format, either as a
     * regression tree or as a classification tree.
     */
    public static class Evaluator {

        /** Opening lines of the DOT document emitted before the tree nodes. */
        private static final String DOT_HEADER =
                "digraph Tree {\n node [shape=box, style=\"filled, rounded\", color=\"black\", fontname=helvetica];\n edge [fontname=helvetica];\n";

        /** Initial capacity of the buffer that collects the exported text. */
        private static final int INITIAL_BUFFER_SIZE = 8192;

        @Nonnull
        private final OutputType outputType;

        @Nonnull
        private final String outputName;
        private final boolean regression;

        /**
         * @param outputType format to export to
         * @param outputName name handed to the Graphviz exporters
         * @param regression {@code true} to treat models as regression trees, {@code false} for
         *        classification trees
         */
        public Evaluator(@Nonnull OutputType outputType, @Nonnull String outputName,
                boolean regression) {
            this.outputType = outputType;
            this.outputName = outputName;
            this.regression = regression;
        }

        /**
         * Exports the given model.
         *
         * @param model Base91-encoded serialized tree
         * @param featureNames optional feature names forwarded to the tree exporter
         * @param classNames optional class names; only used for classification trees
         * @return the exported JavaScript or DOT text
         * @throws HiveException if the configured output type is not supported
         */
        @Nonnull
        public Text export(@Nonnull Text model, @Nullable String[] featureNames,
                @Nullable String[] classNames) throws HiveException {
            // Text's backing array may be longer than the content, hence the explicit length.
            byte[] serialized = Base91.decode(model.getBytes(), 0, model.getLength());
            final String exported;
            if (this.regression) {
                exported = exportRegressor(serialized, featureNames);
            } else {
                exported = exportClassifier(serialized, featureNames, classNames);
            }
            return new Text(exported);
        }

        /**
         * Deserializes a classification tree and renders it in the configured format.
         *
         * @throws HiveException if the configured output type is not supported
         */
        @Nonnull
        private String exportClassifier(@Nonnull byte[] serialized, @Nullable String[] featureNames,
                @Nullable String[] classNames) throws HiveException {
            // NOTE(review): the trailing `true` presumably marks the bytes as compressed - confirm
            // against DecisionTree.deserialize.
            DecisionTree.Node root = DecisionTree.deserialize(serialized, serialized.length, true);
            StringBuilder buf = new StringBuilder(INITIAL_BUFFER_SIZE);
            switch (this.outputType) {
                case javascript:
                    root.exportJavascript(buf, featureNames, classNames, 0);
                    break;
                case graphvis:
                    buf.append(DOT_HEADER);
                    // A color palette is only requested when class names are available.
                    root.exportGraphviz(buf, featureNames, classNames, this.outputName,
                        classNames == null ? null : SmileExtUtils.getColorBrew(classNames.length),
                        new MutableInt(0), 0);
                    buf.append("}");
                    break;
                default:
                    throw new HiveException("Unsupported outputType: " + this.outputType);
            }
            return buf.toString();
        }

        /**
         * Deserializes a regression tree and renders it in the configured format.
         *
         * @throws HiveException if the configured output type is not supported
         */
        @Nonnull
        private String exportRegressor(@Nonnull byte[] serialized, @Nullable String[] featureNames)
                throws HiveException {
            // NOTE(review): the trailing `true` presumably marks the bytes as compressed - confirm
            // against RegressionTree.deserialize.
            RegressionTree.Node root = RegressionTree.deserialize(serialized, serialized.length, true);
            StringBuilder buf = new StringBuilder(INITIAL_BUFFER_SIZE);
            switch (this.outputType) {
                case javascript:
                    root.exportJavascript(buf, featureNames, 0);
                    break;
                case graphvis:
                    buf.append(DOT_HEADER);
                    root.exportGraphviz(buf, featureNames, this.outputName, new MutableInt(0), 0);
                    buf.append("}");
                    break;
                default:
                    throw new HiveException("Unsupported outputType: " + this.outputType);
            }
            return buf.toString();
        }
    }

    /** Supported export formats. */
    public enum OutputType {
        javascript,
        graphvis;

        /**
         * Maps a user-supplied {@code -type} value to an output type, ignoring case.
         *
         * <p>Accepted values: {@code js}/{@code javascript}, and {@code dot}/{@code graphviz}
         * (the historical spelling {@code graphvis} is still accepted).
         *
         * @param name the option value
         * @return the matching output type
         * @throws UDFArgumentException if {@code name} matches none of the accepted values
         */
        @Nonnull
        public static OutputType resolve(@Nonnull String name) throws UDFArgumentException {
            if ("js".equalsIgnoreCase(name) || "javascript".equalsIgnoreCase(name)) {
                return javascript;
            }
            if ("dot".equalsIgnoreCase(name) || "graphviz".equalsIgnoreCase(name)
                    || "graphvis".equalsIgnoreCase(name)) {
                return graphvis;
            }
            throw new UDFArgumentException("Please provide a valid `-type` option from [javascript, graphvis]: " + name);
        }
    }

    /**
     * Declares the options understood in the second UDF argument: {@code -t/-type},
     * {@code -r/-regression} and {@code -output_name/-outputName}.
     */
    @Override // hivemall.UDFWithOptions
    protected Options getOptions() {
        Options opts = new Options();
        opts.addOption("t", "type", true, "Type of output [default: js, javascript/js, graphvis/dot");
        opts.addOption("r", "regression", false, "Is regression tree or not");
        opts.addOption("output_name", "outputName", true, "output name [default: predicted]");
        return opts;
    }

    /**
     * Parses the option string and builds the {@link Evaluator} used by
     * {@link #evaluate(GenericUDF.DeferredObject[])}.
     *
     * @param optionValue the raw option string
     * @return the parsed command line
     * @throws UDFArgumentException if the {@code -type} value is not recognized
     */
    @Override // hivemall.UDFWithOptions
    protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
        CommandLine cl = parseOptions(optionValue);
        // Fall back to the documented default rather than rejecting a missing -type option.
        OutputType outputType = OutputType.resolve(cl.getOptionValue("type", DEFAULT_OUTPUT_TYPE));
        String outputName = cl.getOptionValue("output_name", DEFAULT_OUTPUT_NAME);
        boolean regression = cl.hasOption("regression");
        this.evaluator = new Evaluator(outputType, outputName, regression);
        return cl;
    }

    /**
     * Validates the argument inspectors, parses the constant option string, and remembers the
     * inspectors needed at evaluation time.
     *
     * @param objectInspectorArr inspectors for the 2 to 4 call arguments
     * @return a writable-string inspector describing the UDF result
     * @throws UDFArgumentException if the argument count is outside 2..4 or an optional argument
     *         is not an {@code array<string>}
     */
    public ObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        final int numArgs = objectInspectorArr.length;
        if (numArgs < 2 || numArgs > 4) {
            throw new UDFArgumentException("_FUNC_ takes 2~4 arguments: " + numArgs);
        }
        this.modelOI = HiveUtils.asStringOI(objectInspectorArr[0]);
        processOptions(HiveUtils.getConstString(objectInspectorArr[1]));
        if (numArgs >= 3) {
            this.featureNamesOI = asStringListOI(objectInspectorArr[2], "featureNames");
        }
        if (numArgs == 4) {
            this.classNamesOI = asStringListOI(objectInspectorArr[3], "classNames");
        }
        return PrimitiveObjectInspectorFactory.writableStringObjectInspector;
    }

    /** Returns the inspector as a list inspector after checking that its elements are strings. */
    private static ListObjectInspector asStringListOI(ObjectInspector argOI, String argName)
            throws UDFArgumentException {
        ListObjectInspector listOI = HiveUtils.asListOI(argOI);
        if (!HiveUtils.isStringOI(listOI.getListElementObjectInspector())) {
            throw new UDFArgumentException("_FUNC_ expected array<string> for " + argName + ": " + listOI.getTypeName());
        }
        return listOI;
    }

    /**
     * Exports the model held in the first argument.
     *
     * @param deferredObjectArr the call arguments; index 1 (options) is not read here because it
     *        was consumed during initialization
     * @return the exported text, or {@code null} when the model argument is {@code null}
     * @throws HiveException if the export fails; non-Hive failures are wrapped
     */
    public Object evaluate(GenericUDF.DeferredObject[] deferredObjectArr) throws HiveException {
        Object modelArg = deferredObjectArr[0].get();
        if (modelArg == null) {
            return null;
        }
        Text model = this.modelOI.getPrimitiveWritableObject(modelArg);

        String[] featureNames = null;
        String[] classNames = null;
        if (deferredObjectArr.length >= 3) {
            featureNames = HiveUtils.asStringArray(deferredObjectArr[2], this.featureNamesOI);
            if (deferredObjectArr.length >= 4) {
                classNames = HiveUtils.asStringArray(deferredObjectArr[3], this.classNamesOI);
            }
        }

        try {
            return this.evaluator.export(model, featureNames, classNames);
        } catch (HiveException e) {
            throw e;
        } catch (Throwable th) {
            // Surface any other failure as a HiveException while keeping the cause.
            throw new HiveException(th);
        }
    }

    /** Returns a display string of the form {@code tree_export([child1, child2, ...])}. */
    public String getDisplayString(String[] strArr) {
        return "tree_export(" + Arrays.toString(strArr) + ")";
    }
}
