package hivemall.topicmodel;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.CommandLineUtils;
import hivemall.utils.lang.Primitives;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;

@Description(name = "plsa_predict", value = "_FUNC_(string word, float value, int label, float prob[, const string options]) - Returns a list which consists of <int label, float prob>")
/* loaded from: input_file:hivemall/topicmodel/PLSAPredictUDAF.class */
public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver {

    /* loaded from: input_file:hivemall/topicmodel/PLSAPredictUDAF$Evaluator.class */
    public static class Evaluator extends GenericUDAFEvaluator {
        // Inspectors for the four raw input columns (word, value, label, prob).
        // Assigned in init() when the mode is PARTIAL1 or COMPLETE.
        private PrimitiveObjectInspector wordOI;
        private PrimitiveObjectInspector valueOI;
        private PrimitiveObjectInspector labelOI;
        private PrimitiveObjectInspector probOI;

        // Model options. Set by processOptions() and overwritten by merge() with the
        // values carried in each partial aggregation.
        private int topics;
        private float alpha;
        private double delta;

        // Inspector of the partial-aggregation struct and references to its fields.
        // Assigned in init() for the modes other than PARTIAL1/COMPLETE.
        private StructObjectInspector internalMergeOI;
        private StructField wcListField;
        private StructField probMapField;
        private StructField topicsOptionField;
        private StructField alphaOptionField;
        private StructField deltaOptionField;

        // Inspectors used by merge() to read the "wcList" and "probMap" fields of a partial.
        private PrimitiveObjectInspector wcListElemOI;
        private StandardListObjectInspector wcListOI;
        private StandardMapObjectInspector probMapOI;
        private PrimitiveObjectInspector probMapKeyOI;
        private StandardListObjectInspector probMapValueOI;
        private PrimitiveObjectInspector probMapValueElemOI;

        // Assigned once in the static initializer; init() consults it before running its
        // argument-count sanity check.
        static final /* synthetic */ boolean $assertionsDisabled;

        /**
         * Builds the command-line options understood by the optional fifth argument.
         *
         * <p>The defaults quoted in the help texts must match the fallbacks applied in
         * {@code processOptions}: 10 topics, alpha 0.5 and delta 0.001.
         *
         * @return a fresh {@link Options} instance holding {@code -k/--topics},
         *         {@code -alpha} and {@code -delta}
         */
        protected Options getOptions() {
            Options options = new Options();
            options.addOption("k", "topics", true, "The number of topics [default: 10]");
            options.addOption("alpha", true, "The hyperparameter for P(w|z) update [default: 0.5]");
            // The effective fallback for delta is 0.001, so advertise 1E-3 rather than 1E-5.
            options.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]");
            return options;
        }

        /**
         * Parses a whitespace-separated option string.
         *
         * <p>When {@code -help} is present the method never returns normally: it renders the
         * usage text and throws it inside a {@link UDFArgumentException} so that the text is
         * surfaced to the caller of the query.
         *
         * @param str the raw option string, split on runs of whitespace
         * @return the parsed command line (only when {@code -help} was not requested)
         * @throws UDFArgumentException carrying the help text when {@code -help} is given, or
         *         whatever {@code CommandLineUtils.parseOptions} raises for invalid input
         */
        @Nonnull
        protected final CommandLine parseOptions(String str) throws UDFArgumentException {
            String[] args = str.split("\\s+");
            Options options = getOptions();
            options.addOption("help", false, "Show function help");
            CommandLine commandLine = CommandLineUtils.parseOptions(args, options);
            if (!commandLine.hasOption("help")) {
                return commandLine;
            }

            // The @Description annotation is declared on the enclosing resolver class, not on
            // this nested evaluator, so look outwards until one is found.
            Description annotation = null;
            for (Class<?> clazz = getClass(); clazz != null && annotation == null; clazz = clazz.getEnclosingClass()) {
                annotation = clazz.getAnnotation(Description.class);
            }

            String usage;
            if (annotation == null || annotation.name() == null) {
                usage = getClass().getSimpleName();
            } else {
                usage = annotation.value().replace("_FUNC_", annotation.name());
            }

            StringWriter helpText = new StringWriter();
            helpText.write('\n');
            PrintWriter writer = new PrintWriter(helpText);
            new HelpFormatter().printHelp(writer, 74, usage, (String) null, options, 1, 3, (String) null, true);
            writer.flush();
            throw new UDFArgumentException(helpText.toString());
        }

        /**
         * Initialises {@code topics}, {@code alpha} and {@code delta} from the optional fifth
         * argument, or from the built-in defaults (10, 0.5, 0.001) when fewer than five
         * arguments are supplied.
         *
         * @param objectInspectorArr inspectors of the UDAF arguments; index 4, if present, is
         *        read as a constant option string
         * @return the parsed command line, or {@code null} when no option string was given
         * @throws UDFArgumentException if the option string is rejected or the number of
         *         topics is smaller than one
         */
        @Nullable
        protected CommandLine processOptions(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
            if (objectInspectorArr.length < 5) {
                // No option string: fall back to the defaults.
                this.topics = 10;
                this.alpha = 0.5f;
                this.delta = 0.001d;
                return null;
            }

            String rawOptions = HiveUtils.getConstString(objectInspectorArr[4]);
            CommandLine parsed = parseOptions(rawOptions);

            this.topics = Primitives.parseInt(parsed.getOptionValue("topics"), 10);
            if (this.topics < 1) {
                throw new UDFArgumentException("A positive integer MUST be set to an option `-topics`: " + this.topics);
            }
            this.alpha = Primitives.parseFloat(parsed.getOptionValue("alpha"), 0.5f);
            this.delta = Primitives.parseDouble(parsed.getOptionValue("delta"), 0.001d);
            return parsed;
        }

        /**
         * Prepares the evaluator for the given aggregation mode and reports the inspector of
         * the value this stage produces.
         *
         * <p>For {@code PARTIAL1} and {@code COMPLETE} the arguments are the raw columns
         * (word, value, label, prob and optionally an option string). For every other mode
         * the single argument is treated as the partial-aggregation struct whose fields are
         * {@code wcList}, {@code probMap}, {@code topics}, {@code alpha} and {@code delta}.
         *
         * @param mode the aggregation mode
         * @param objectInspectorArr inspectors of the arguments for this mode
         * @return the partial-aggregation struct inspector for {@code PARTIAL1}/{@code PARTIAL2},
         *         otherwise an inspector for a list of {@code <label:int, probability:float>} structs
         * @throws HiveException if option parsing or inspector conversion fails
         */
        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] objectInspectorArr) throws HiveException {
            // Decompiled form of an `assert` on the argument count; it has to keep using the
            // explicitly declared $assertionsDisabled flag.
            if (!$assertionsDisabled && objectInspectorArr.length != 1 && objectInspectorArr.length != 4 && objectInspectorArr.length != 5) {
                throw new AssertionError();
            }
            super.init(mode, objectInspectorArr);

            // Input side.
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                processOptions(objectInspectorArr);
                this.wordOI = HiveUtils.asStringOI(objectInspectorArr[0]);
                this.valueOI = HiveUtils.asDoubleCompatibleOI(objectInspectorArr[1]);
                this.labelOI = HiveUtils.asIntegerOI(objectInspectorArr[2]);
                this.probOI = HiveUtils.asDoubleCompatibleOI(objectInspectorArr[3]);
            } else {
                StructObjectInspector partialOI = (StructObjectInspector) objectInspectorArr[0];
                this.internalMergeOI = partialOI;
                this.wcListField = partialOI.getStructFieldRef("wcList");
                this.probMapField = partialOI.getStructFieldRef("probMap");
                this.topicsOptionField = partialOI.getStructFieldRef("topics");
                this.alphaOptionField = partialOI.getStructFieldRef("alpha");
                this.deltaOptionField = partialOI.getStructFieldRef("delta");
                this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
                this.wcListOI = ObjectInspectorFactory.getStandardListObjectInspector(this.wcListElemOI);
                this.probMapKeyOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
                // NOTE(review): a string inspector is used for the map's list elements although
                // internalMergeOI() declares them as floats; kept as-is, confirm this is intended.
                this.probMapValueElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
                this.probMapValueOI = ObjectInspectorFactory.getStandardListObjectInspector(this.probMapValueElemOI);
                this.probMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(this.probMapKeyOI, this.probMapValueOI);
            }

            // Output side. The two branches yield different inspector types (struct vs. list),
            // so each returns directly instead of sharing one narrowly typed local.
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                return internalMergeOI();
            }

            List<String> fieldNames = new ArrayList<String>();
            List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
            fieldNames.add("label");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
            fieldNames.add("probability");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
            return ObjectInspectorFactory.getStandardListObjectInspector(
                ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs));
        }

        /**
         * Describes the struct exchanged between partial aggregations.
         *
         * <p>Fields, in order: {@code wcList} (list of strings), {@code probMap} (map from
         * string to list of floats), {@code topics} (int), {@code alpha} (float) and
         * {@code delta} (double).
         *
         * @return a standard struct inspector for the partial-aggregation layout
         */
        private static StructObjectInspector internalMergeOI() {
            List<String> names = new ArrayList<String>();
            names.add("wcList");
            names.add("probMap");
            names.add("topics");
            names.add("alpha");
            names.add("delta");

            // Inspectors are listed in the same order as the names above.
            List<ObjectInspector> inspectors = new ArrayList<ObjectInspector>();
            inspectors.add(ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector));
            inspectors.add(ObjectInspectorFactory.getStandardMapObjectInspector(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector,
                ObjectInspectorFactory.getStandardListObjectInspector(
                    PrimitiveObjectInspectorFactory.javaFloatObjectInspector)));
            inspectors.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
            inspectors.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
            inspectors.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);

            return ObjectInspectorFactory.getStandardStructObjectInspector(names, inspectors);
        }

        /**
         * Creates an aggregation buffer and puts it into its initial state via
         * {@link #reset(GenericUDAFEvaluator.AggregationBuffer)}.
         *
         * @return a freshly reset {@code PLSAPredictAggregationBuffer}
         * @throws HiveException declared by the evaluator contract
         */
        public GenericUDAFEvaluator.AggregationBuffer getNewAggregationBuffer() throws HiveException {
            final PLSAPredictAggregationBuffer fresh = new PLSAPredictAggregationBuffer();
            this.reset(fresh);
            return fresh;
        }

        /**
         * Clears the buffer's collected data and re-applies the evaluator's current
         * {@code topics}, {@code alpha} and {@code delta} settings to it.
         *
         * @param aggregationBuffer the buffer to reset; must be a {@code PLSAPredictAggregationBuffer}
         * @throws HiveException declared by the evaluator contract
         */
        public void reset(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            final PLSAPredictAggregationBuffer target = (PLSAPredictAggregationBuffer) aggregationBuffer;
            target.reset();
            target.setOptions(this.topics, this.alpha, this.delta);
        }

        /**
         * Feeds one input row (word, value, label, prob) into the buffer.
         *
         * <p>Rows in which any of the four columns is {@code null} are skipped.
         *
         * @param aggregationBuffer the buffer to update; must be a {@code PLSAPredictAggregationBuffer}
         * @param objArr the row's column values, read through the inspectors captured in {@code init}
         * @throws HiveException if the label does not address one of the configured topics
         */
        public void iterate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object[] objArr) throws HiveException {
            PLSAPredictAggregationBuffer buffer = (PLSAPredictAggregationBuffer) aggregationBuffer;
            if (objArr[0] == null || objArr[1] == null || objArr[2] == null || objArr[3] == null) {
                return;
            }

            String word = PrimitiveObjectInspectorUtils.getString(objArr[0], this.wordOI);
            float value = PrimitiveObjectInspectorUtils.getFloat(objArr[1], this.valueOI);
            int label = PrimitiveObjectInspectorUtils.getInt(objArr[2], this.labelOI);
            float prob = PrimitiveObjectInspectorUtils.getFloat(objArr[3], this.probOI);

            // The buffer stores per-word probabilities in a list of exactly `topics` slots and
            // indexes it by label; reject bad labels here with an actionable message instead
            // of letting List.set() fail with a bare IndexOutOfBoundsException.
            if (label < 0 || label >= buffer.topics) {
                throw new UDFArgumentException("Topic label must be in range [0, " + buffer.topics
                        + ") but got " + label + "; check the value of the option `-topics`");
            }

            buffer.iterate(word, value, label, prob);
        }

        /**
         * Serialises the buffer into the partial-aggregation struct.
         *
         * @param aggregationBuffer the buffer to export; must be a {@code PLSAPredictAggregationBuffer}
         * @return {@code null} when no word has been collected, otherwise a five-element array
         *         holding wcList, probMap, topics, alpha and delta (in that order)
         * @throws HiveException declared by the evaluator contract
         */
        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            final PLSAPredictAggregationBuffer buf = (PLSAPredictAggregationBuffer) aggregationBuffer;
            if (buf.wcList.isEmpty()) {
                return null;
            }

            // Slot order matches the field order declared by internalMergeOI().
            final Object[] partial = new Object[5];
            partial[0] = buf.wcList;
            partial[1] = buf.probMap;
            partial[2] = new IntWritable(buf.topics);
            partial[3] = new FloatWritable(buf.alpha);
            partial[4] = new DoubleWritable(buf.delta);
            return partial;
        }

        /**
         * Folds one partial aggregation into the buffer.
         *
         * <p>The partial's {@code wcList} and {@code probMap} are copied into plain Java
         * collections, and its {@code topics}, {@code alpha} and {@code delta} replace the
         * evaluator's and the buffer's current settings before the data is merged.
         *
         * @param aggregationBuffer the buffer to update; must be a {@code PLSAPredictAggregationBuffer}
         * @param obj the partial-aggregation struct, or {@code null} (ignored)
         * @throws HiveException declared by the evaluator contract
         */
        public void merge(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object obj) throws HiveException {
            if (obj == null) {
                return;
            }

            // wcList: list of strings.
            Object wcListData = HiveUtils.castLazyBinaryObject(this.internalMergeOI.getStructFieldData(obj, this.wcListField));
            List<?> rawWcList = this.wcListOI.getList(wcListData);
            int wcCount = rawWcList.size();
            List<String> wcList = new ArrayList<String>();
            for (int i = 0; i < wcCount; i++) {
                wcList.add(PrimitiveObjectInspectorUtils.getString(rawWcList.get(i), this.wcListElemOI));
            }

            // probMap: word -> list of per-topic probabilities. Wildcard types are required
            // here; iterating a raw Map's entry set as Map.Entry does not compile.
            Object probMapData = HiveUtils.castLazyBinaryObject(this.internalMergeOI.getStructFieldData(obj, this.probMapField));
            Map<?, ?> rawProbMap = this.probMapOI.getMap(probMapData);
            Map<String, List<Float>> probMap = new HashMap<String, List<Float>>();
            for (Map.Entry<?, ?> entry : rawProbMap.entrySet()) {
                String word = PrimitiveObjectInspectorUtils.getString(entry.getKey(), this.probMapKeyOI);
                List<?> rawProbs = this.probMapValueOI.getList(HiveUtils.castLazyBinaryObject(entry.getValue()));
                int probCount = rawProbs.size();
                List<Float> probs = new ArrayList<Float>();
                for (int j = 0; j < probCount; j++) {
                    probs.add(Float.valueOf(HiveUtils.getFloat(rawProbs.get(j), this.probMapValueElemOI)));
                }
                probMap.put(word, probs);
            }

            // Options travel with every partial; adopt them before merging the data.
            this.topics = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(this.internalMergeOI.getStructFieldData(obj, this.topicsOptionField));
            this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(this.internalMergeOI.getStructFieldData(obj, this.alphaOptionField));
            this.delta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(this.internalMergeOI.getStructFieldData(obj, this.deltaOptionField));

            PLSAPredictAggregationBuffer buffer = (PLSAPredictAggregationBuffer) aggregationBuffer;
            buffer.setOptions(this.topics, this.alpha, this.delta);
            buffer.merge(wcList, probMap);
        }

        /**
         * Produces the final result: one {@code <label, probability>} pair per topic, ordered
         * by descending probability.
         *
         * <p>Topics sharing the same probability are all reported, in ascending label order.
         *
         * @param aggregationBuffer the buffer to evaluate; must be a {@code PLSAPredictAggregationBuffer}
         * @return a list of two-element arrays {@code {IntWritable label, FloatWritable probability}}
         * @throws HiveException declared by the evaluator contract
         */
        public Object terminate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            float[] distribution = ((PLSAPredictAggregationBuffer) aggregationBuffer).get();

            // Group labels by probability. Mapping probability -> single label would let a
            // later topic overwrite an earlier one with an equal probability and silently
            // drop it from the output.
            TreeMap<Float, List<Integer>> labelsByProb =
                    new TreeMap<Float, List<Integer>>(Collections.<Float>reverseOrder());
            for (int label = 0; label < distribution.length; label++) {
                Float prob = Float.valueOf(distribution[label]);
                List<Integer> labels = labelsByProb.get(prob);
                if (labels == null) {
                    labels = new ArrayList<Integer>();
                    labelsByProb.put(prob, labels);
                }
                labels.add(Integer.valueOf(label));
            }

            List<Object[]> result = new ArrayList<Object[]>(distribution.length);
            for (Map.Entry<Float, List<Integer>> entry : labelsByProb.entrySet()) {
                float prob = entry.getKey().floatValue();
                for (Integer label : entry.getValue()) {
                    result.add(new Object[]{new IntWritable(label.intValue()), new FloatWritable(prob)});
                }
            }
            return result;
        }

        // Captures the inverse of the desired assertion status of PLSAPredictUDAF once at
        // class initialisation; init() reads the flag to decide whether to run its
        // argument-count check.
        static {
            $assertionsDisabled = !PLSAPredictUDAF.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:hivemall/topicmodel/PLSAPredictUDAF$PLSAPredictAggregationBuffer.class */
    /**
     * Aggregation state for {@code plsa_predict}: the collected {@code word:value} entries,
     * the per-word topic probabilities, and the model options needed to evaluate them.
     */
    public static class PLSAPredictAggregationBuffer extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        /** Marks a topic slot for which no probability has been supplied. */
        private static final float UNSET = -1.0f;

        /** Collected entries in the form {@code word + ":" + value}. */
        private List<String> wcList;
        /** Word to per-topic probabilities; each list holds one slot per topic. */
        private Map<String, List<Float>> probMap;
        private int topics;
        private float alpha;
        private double delta;

        PLSAPredictAggregationBuffer() {
        }

        /**
         * Stores the model options.
         *
         * @param i number of topics
         * @param f alpha
         * @param d delta
         */
        void setOptions(int i, float f, double d) {
            this.topics = i;
            this.alpha = f;
            this.delta = d;
        }

        /** Replaces the collected data with empty collections; options are left untouched. */
        void reset() {
            this.wcList = new ArrayList<String>();
            this.probMap = new HashMap<String, List<Float>>();
        }

        /**
         * Records one input row.
         *
         * @param str the word
         * @param f the value attached to the word
         * @param i the topic label, used as an index into the word's probability list
         * @param f2 the probability to store for that word and topic
         */
        void iterate(@Nonnull String str, float f, int i, float f2) {
            this.wcList.add(str + ":" + f);

            List<Float> perTopic = this.probMap.get(str);
            if (perTopic == null) {
                // First sighting of this word: start with every topic slot unset.
                perTopic = new ArrayList<Float>(Collections.nCopies(this.topics, Float.valueOf(UNSET)));
                this.probMap.put(str, perTopic);
            }
            perTopic.set(i, Float.valueOf(f2));
        }

        /**
         * Merges data from another partial aggregation into this buffer.
         *
         * @param list entries to append to {@code wcList}
         * @param map per-word probabilities; slots that are set overwrite this buffer's slots
         */
        void merge(@Nonnull List<String> list, @Nonnull Map<String, List<Float>> map) {
            this.wcList.addAll(list);

            for (Map.Entry<String, List<Float>> incoming : map.entrySet()) {
                final String word = incoming.getKey();
                final List<Float> theirs = incoming.getValue();
                final List<Float> ours = this.probMap.get(word);

                if (ours == null) {
                    // Unknown word: adopt the incoming list as-is.
                    this.probMap.put(word, theirs);
                    continue;
                }
                for (int topic = 0; topic < this.topics; topic++) {
                    final float candidate = theirs.get(topic).floatValue();
                    if (candidate != UNSET) {
                        ours.set(topic, Float.valueOf(candidate));
                    }
                }
            }
        }

        /**
         * Evaluates the collected data with an {@code IncrementalPLSAModel}.
         *
         * <p>Every set probability is passed to {@code setWordScore}, then the collected
         * entries are handed to {@code getTopicDistribution}.
         *
         * @return the array returned by {@code getTopicDistribution}
         */
        float[] get() {
            final IncrementalPLSAModel model = new IncrementalPLSAModel(this.topics, this.alpha, this.delta);

            for (Map.Entry<String, List<Float>> wordEntry : this.probMap.entrySet()) {
                final String word = wordEntry.getKey();
                final List<Float> perTopic = wordEntry.getValue();
                for (int topic = 0; topic < this.topics; topic++) {
                    final float score = perTopic.get(topic).floatValue();
                    if (score != UNSET) {
                        model.setWordScore(word, topic, score);
                    }
                }
            }

            final String[] entries = this.wcList.toArray(new String[this.wcList.size()]);
            return model.getTopicDistribution(entries);
        }
    }

    /**
     * Validates the argument types and creates the evaluator.
     *
     * <p>The decompiler renamed this method from {@code getEvaluator} (it was merged with
     * its bridge method); the name is kept for compatibility and
     * {@link #getEvaluator(TypeInfo[])} delegates to it.
     *
     * @param typeInfoArr argument types: word (string), value (number), label (integer),
     *        prob (number) and optionally an option string
     * @return a new {@link Evaluator}
     * @throws SemanticException if the number of arguments is not 4 or 5, or an argument has
     *         an unexpected type
     */
    public Evaluator m200getEvaluator(TypeInfo[] typeInfoArr) throws SemanticException {
        if (typeInfoArr.length != 4 && typeInfoArr.length != 5) {
            throw new UDFArgumentLengthException("Expected argument length is 4 or 5 but given argument length was " + typeInfoArr.length);
        }
        if (!HiveUtils.isStringTypeInfo(typeInfoArr[0])) {
            throw new UDFArgumentTypeException(0, "String type is expected for the first argument word: " + typeInfoArr[0].getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(typeInfoArr[1])) {
            throw new UDFArgumentTypeException(1, "Number type is expected for the second argument value: " + typeInfoArr[1].getTypeName());
        }
        if (!HiveUtils.isIntegerTypeInfo(typeInfoArr[2])) {
            throw new UDFArgumentTypeException(2, "Integer type is expected for the third argument label: " + typeInfoArr[2].getTypeName());
        }
        if (!HiveUtils.isNumberTypeInfo(typeInfoArr[3])) {
            throw new UDFArgumentTypeException(3, "Number type is expected for the forth argument prob: " + typeInfoArr[3].getTypeName());
        }
        if (typeInfoArr.length == 5 && !HiveUtils.isStringTypeInfo(typeInfoArr[4])) {
            // The fifth argument is the option string, not another probability.
            throw new UDFArgumentTypeException(4, "String type is expected for the fifth argument options: " + typeInfoArr[4].getTypeName());
        }
        return new Evaluator();
    }

    /**
     * Resolver entry point with the name and signature declared by the resolver base class.
     * Without it the renamed method above would never be reached through that contract.
     *
     * @param typeInfoArr argument types, see {@link #m200getEvaluator(TypeInfo[])}
     * @return a new {@link Evaluator}
     * @throws SemanticException if the arguments are rejected
     */
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] typeInfoArr) throws SemanticException {
        return m200getEvaluator(typeInfoArr);
    }
}
