package hivemall.evaluation;

import hivemall.HivemallConstants;
import hivemall.UDAFEvaluatorWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.SizeOf;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.util.JavaDataModel;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;

@Description(name = "fmeasure", value = "_FUNC_(array|int|boolean actual, array|int| boolean predicted [, const string options]) - Return a F-measure (f1score is the special with beta=1.0)")
/* loaded from: input_file:hivemall/evaluation/FMeasureUDAF.class */
public final class FMeasureUDAF extends AbstractGenericUDAFResolver {

    /**
     * Hive UDAF evaluator that feeds rows (or partial aggregates) into a
     * {@link FMeasureAggregationBuffer} and returns the resulting F-measure as a double.
     *
     * <p>In {@code PARTIAL1}/{@code COMPLETE} mode the evaluator receives the raw
     * {@code actual}, {@code predicted} and optional constant option-string arguments. In
     * {@code PARTIAL2}/{@code FINAL} mode it receives a single struct with the fields
     * {@code tp}, {@code totalActual}, {@code totalPredicted}, {@code beta} and {@code average}.
     */
    public static class Evaluator extends UDAFEvaluatorWithOptions {
        /** Value of {@code -average} selecting micro averaging; also the default. */
        private static final String AVERAGE_MICRO = "micro";
        /** Value of {@code -average} that is recognised but rejected as unsupported. */
        private static final String AVERAGE_MACRO = "macro";

        private ObjectInspector actualOI;
        private ObjectInspector predictedOI;
        private StructObjectInspector internalMergeOI;
        private StructField tpField;
        private StructField totalActualField;
        private StructField totalPredictedField;
        private StructField betaOptionField;
        private StructField averageOptionField;
        // Both stay at their Java defaults (0.0 / null) when init() runs in a merge mode, because
        // processOptions() is only invoked for raw-input modes.
        private double beta;
        private String average;

        /**
         * Declares the {@code -beta} and {@code -average} options accepted in the option string.
         *
         * @return the option definitions
         */
        @Override // hivemall.UDAFEvaluatorWithOptions
        protected Options getOptions() {
            Options options = new Options();
            options.addOption("beta", true, "The weight of precision [default: 1.]");
            options.addOption("average", true, "The way of average calculation [default: micro]");
            return options;
        }

        /**
         * Parses the optional third argument and stores the resulting {@code beta} and
         * {@code average} settings on this evaluator.
         *
         * @param objectInspectorArr inspectors of the UDAF arguments; the option string is read from
         *        index 2 when present
         * @return the parsed command line, or {@code null} when no option string was supplied
         * @throws UDFArgumentException if {@code beta} is not a positive number or {@code average}
         *         is unsupported
         */
        @Override // hivemall.UDAFEvaluatorWithOptions
        protected CommandLine processOptions(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
            CommandLine commandLine = null;
            double betaOption = 1.0d;
            String averageOption = AVERAGE_MICRO;
            if (objectInspectorArr.length >= 3) {
                commandLine = parseOptions(HiveUtils.getConstString(objectInspectorArr[2]));
                betaOption = Primitives.parseDouble(commandLine.getOptionValue("beta"), 1.0d);
                // Written as !(x > 0) so that NaN is rejected together with zero and negatives.
                if (!(betaOption > 0.0d)) {
                    throw new UDFArgumentException("The third argument `double beta` must be greater than 0.0: " + betaOption);
                }
                averageOption = commandLine.getOptionValue("average", averageOption);
                if (AVERAGE_MACRO.equals(averageOption)) {
                    throw new UDFArgumentException("\"-average macro\" is not supported");
                }
                // NOTE(review): HivemallConstants.BINARY_TYPE_NAME is used as the "binary" average
                // keyword; presumably its value is "binary" - confirm against HivemallConstants.
                if (!HivemallConstants.BINARY_TYPE_NAME.equals(averageOption) && !AVERAGE_MICRO.equals(averageOption)) {
                    throw new UDFArgumentException("The third argument `String average` must be one of the {binary, micro, macro}: " + averageOption);
                }
            }
            this.beta = betaOption;
            this.average = averageOption;
            return commandLine;
        }

        /**
         * Captures the input inspectors for the given mode and reports the output inspector.
         *
         * @param mode aggregation phase
         * @param objectInspectorArr raw argument inspectors (PARTIAL1/COMPLETE) or the inspector of
         *        the partial struct at index 0 (PARTIAL2/FINAL)
         * @return the partial-struct inspector for PARTIAL1/PARTIAL2, otherwise a writable double
         *         inspector
         * @throws HiveException if option parsing fails or the superclass rejects the arguments
         */
        @Override
        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] objectInspectorArr) throws HiveException {
            final boolean rawInput = mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE;
            // The 2-or-3 argument contract only applies to raw input. In merge modes only the
            // partial struct at index 0 is read, so the check must not run there.
            assert !rawInput || objectInspectorArr.length == 2 || objectInspectorArr.length == 3 : objectInspectorArr.length;
            super.init(mode, objectInspectorArr);
            if (rawInput) {
                processOptions(objectInspectorArr);
                this.actualOI = objectInspectorArr[0];
                this.predictedOI = objectInspectorArr[1];
            } else {
                StructObjectInspector partialOI = (StructObjectInspector) objectInspectorArr[0];
                this.internalMergeOI = partialOI;
                this.tpField = partialOI.getStructFieldRef("tp");
                this.totalActualField = partialOI.getStructFieldRef("totalActual");
                this.totalPredictedField = partialOI.getStructFieldRef("totalPredicted");
                this.betaOptionField = partialOI.getStructFieldRef("beta");
                this.averageOptionField = partialOI.getStructFieldRef("average");
            }
            final boolean emitsPartial = mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2;
            return emitsPartial ? internalMergeOI() : PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
        }

        /**
         * Builds the inspector describing the struct produced by {@link #terminatePartial}.
         *
         * @return struct inspector with fields tp, totalActual, totalPredicted, beta, average
         */
        @Nonnull
        private static StructObjectInspector internalMergeOI() {
            List<String> fieldNames = new ArrayList<String>();
            List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
            fieldNames.add("tp");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
            fieldNames.add("totalActual");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
            fieldNames.add("totalPredicted");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
            fieldNames.add("beta");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            fieldNames.add("average");
            fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        /**
         * Creates a zeroed buffer carrying this evaluator's current options.
         *
         * @return a fresh aggregation buffer
         * @throws HiveException declared by the Hive contract
         */
        @Override
        public FMeasureAggregationBuffer getNewAggregationBuffer() throws HiveException {
            FMeasureAggregationBuffer buffer = new FMeasureAggregationBuffer();
            reset(buffer);
            return buffer;
        }

        /**
         * Alias kept under the decompiler-assigned name so existing references still resolve.
         *
         * @deprecated use {@link #getNewAggregationBuffer()}
         */
        @Deprecated
        public FMeasureAggregationBuffer m33getNewAggregationBuffer() throws HiveException {
            return getNewAggregationBuffer();
        }

        /**
         * Clears the counters of the buffer and re-applies this evaluator's options to it.
         *
         * @param aggregationBuffer a {@link FMeasureAggregationBuffer}
         */
        @Override
        public void reset(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            FMeasureAggregationBuffer buffer = (FMeasureAggregationBuffer) aggregationBuffer;
            buffer.reset();
            buffer.setOptions(this.beta, this.average);
        }

        /**
         * Converts one input row into actual/predicted label lists and adds them to the buffer.
         *
         * @param aggregationBuffer a {@link FMeasureAggregationBuffer}
         * @param objArr row values; index 0 is {@code actual}, index 1 is {@code predicted}
         * @throws HiveException if binary averaging is requested for array input or an int label is
         *         outside {-1, 0, 1}
         */
        @Override
        public void iterate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object[] objArr) throws HiveException {
            FMeasureAggregationBuffer buffer = (FMeasureAggregationBuffer) aggregationBuffer;
            final boolean binaryAverage = HivemallConstants.BINARY_TYPE_NAME.equals(this.average);
            final List<?> actualLabels;
            final List<?> predictedLabels;
            if (HiveUtils.isListOI(this.actualOI) && HiveUtils.isListOI(this.predictedOI)) {
                if (binaryAverage) {
                    throw new UDFArgumentException("\"-average binary\" is not supported when `predict` is array");
                }
                actualLabels = ((ListObjectInspector) this.actualOI).getList(objArr[0]);
                predictedLabels = ((ListObjectInspector) this.predictedOI).getList(objArr[1]);
            } else if (HiveUtils.isBooleanOI(this.actualOI)) {
                // NOTE(review): unlike the int branch below, a false label is kept even under
                // binary averaging, so a false/false row counts as a match - confirm this is intended.
                actualLabels = Collections.singletonList(asIntLabel(objArr[0], (BooleanObjectInspector) this.actualOI));
                predictedLabels = Collections.singletonList(asIntLabel(objArr[1], (BooleanObjectInspector) this.predictedOI));
            } else {
                actualLabels = toLabelList(asIntLabel(objArr[0], (IntObjectInspector) this.actualOI), binaryAverage);
                predictedLabels = toLabelList(asIntLabel(objArr[1], (IntObjectInspector) this.predictedOI), binaryAverage);
            }
            buffer.iterate(actualLabels, predictedLabels);
        }

        /** Wraps a label in a one-element list, or returns an empty list for label 0 under binary averaging. */
        private static List<Integer> toLabelList(int label, boolean binaryAverage) {
            if (label == 0 && binaryAverage) {
                return Collections.emptyList();
            }
            return Collections.singletonList(label);
        }

        /** Maps a boolean label to 1 (true) or 0 (false). */
        private static int asIntLabel(@Nonnull Object obj, @Nonnull BooleanObjectInspector booleanObjectInspector) {
            return booleanObjectInspector.get(obj) ? 1 : 0;
        }

        /**
         * Normalises an int label: -1 and 0 become 0, 1 stays 1.
         *
         * @throws UDFArgumentException for any other value
         */
        private static int asIntLabel(@Nonnull Object obj, @Nonnull IntObjectInspector intObjectInspector) throws UDFArgumentException {
            int label = intObjectInspector.get(obj);
            switch (label) {
                case -1:
                case 0:
                    return 0;
                case 1:
                    return 1;
                default:
                    throw new UDFArgumentException("Int label must be 1, 0 or -1: " + label);
            }
        }

        /**
         * Serialises the buffer as the struct described by {@link #internalMergeOI()}.
         *
         * @param aggregationBuffer a {@link FMeasureAggregationBuffer}
         * @return {tp, totalActual, totalPredicted, beta, average} in that order
         */
        @Override
        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            FMeasureAggregationBuffer buffer = (FMeasureAggregationBuffer) aggregationBuffer;
            Object[] partial = new Object[5];
            partial[0] = new LongWritable(buffer.tp);
            partial[1] = new LongWritable(buffer.totalActual);
            partial[2] = new LongWritable(buffer.totalPredicted);
            partial[3] = new DoubleWritable(buffer.beta);
            partial[4] = buffer.average;
            return partial;
        }

        /**
         * Folds one partial struct into the buffer; a {@code null} partial is ignored.
         *
         * @param aggregationBuffer a {@link FMeasureAggregationBuffer}
         * @param obj partial struct read through the inspector captured in {@link #init}
         */
        @Override
        public void merge(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object obj) throws HiveException {
            if (obj == null) {
                return;
            }
            Object tpObj = this.internalMergeOI.getStructFieldData(obj, this.tpField);
            Object totalActualObj = this.internalMergeOI.getStructFieldData(obj, this.totalActualField);
            Object totalPredictedObj = this.internalMergeOI.getStructFieldData(obj, this.totalPredictedField);
            Object betaObj = this.internalMergeOI.getStructFieldData(obj, this.betaOptionField);
            Object averageObj = this.internalMergeOI.getStructFieldData(obj, this.averageOptionField);

            long tp = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpObj);
            long totalActual = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(totalActualObj);
            long totalPredicted = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(totalPredictedObj);
            double partialBeta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(betaObj);
            // NOTE(review): the partial struct declares this field with javaStringObjectInspector
            // but it is read back with the writable string inspector; presumably the value arrives
            // in writable form after serialisation - confirm.
            String partialAverage = PrimitiveObjectInspectorFactory.writableStringObjectInspector.getPrimitiveJavaObject(averageObj);

            ((FMeasureAggregationBuffer) aggregationBuffer).merge(tp, totalActual, totalPredicted, partialBeta, partialAverage);
        }

        /**
         * Produces the final F-measure for the buffer.
         *
         * @param aggregationBuffer a {@link FMeasureAggregationBuffer}
         * @return the F-measure wrapped in a {@link DoubleWritable}
         */
        @Override
        public DoubleWritable terminate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            return new DoubleWritable(((FMeasureAggregationBuffer) aggregationBuffer).get());
        }

        /**
         * Alias kept under the decompiler-assigned name so existing references still resolve.
         *
         * @deprecated use {@link #terminate(GenericUDAFEvaluator.AggregationBuffer)}
         */
        @Deprecated
        public DoubleWritable m32terminate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            return terminate(aggregationBuffer);
        }
    }

    @GenericUDAFEvaluator.AggregationType(estimable = true)
    /* loaded from: input_file:hivemall/evaluation/FMeasureUDAF$FMeasureAggregationBuffer.class */
    public static class FMeasureAggregationBuffer extends GenericUDAFEvaluator.AbstractAggregationBuffer {
        long tp;
        long totalActual;
        long totalPredicted;
        double beta;
        String average;

        public int estimate() {
            JavaDataModel javaDataModel = JavaDataModel.get();
            return (javaDataModel.primitive2() * 4) + javaDataModel.lengthFor(this.average);
        }

        void setOptions(double d, String str) {
            this.beta = d;
            this.average = str;
        }

        void reset() {
            this.tp = 0L;
            this.totalActual = 0L;
            this.totalPredicted = 0L;
        }

        void merge(long j, long j2, long j3, double d, String str) {
            this.tp += j;
            this.totalActual += j2;
            this.totalPredicted += j3;
            this.beta = d;
            this.average = str;
        }

        double get() {
            double d;
            double d2;
            double d3 = this.beta * this.beta;
            if ("micro".equals(this.average)) {
                d = denom(this.tp, this.totalActual, this.totalPredicted, d3);
                d2 = (1.0d + d3) * this.tp;
            } else {
                double precision = precision(this.tp, this.totalPredicted);
                double recall = recall(this.tp, this.totalActual);
                d = (d3 * precision) + recall;
                d2 = (1.0d + d3) * precision * recall;
            }
            if (d > 0.0d) {
                return d2 / d;
            }
            return 0.0d;
        }

        private static double denom(long j, long j2, long j3, double d) {
            return (d * (j + (j2 - j))) + j + (j3 - j);
        }

        private static double precision(long j, long j2) {
            if (j2 == 0) {
                return 0.0d;
            }
            return j / j2;
        }

        private static double recall(long j, long j2) {
            if (j2 == 0) {
                return 0.0d;
            }
            return j / j2;
        }

        void iterate(@Nonnull List<?> list, @Nonnull List<?> list2) {
            int size = list.size();
            int size2 = list2.size();
            int i = 0;
            Iterator<?> it = list2.iterator();
            while (it.hasNext()) {
                if (list.contains(it.next())) {
                    i++;
                }
            }
            this.tp += i;
            this.totalActual += size;
            this.totalPredicted += size2;
        }
    }

    /**
     * Validates the argument types of a call and returns the evaluator for it.
     *
     * <p>Two or three arguments are accepted. The first two must each be a list, integer or
     * boolean type, and both must have the same type.
     *
     * @param typeInfoArr type information of the call arguments
     * @return a new {@link Evaluator}
     * @throws SemanticException (as {@link UDFArgumentTypeException}) when the argument count or
     *         types are not accepted
     */
    public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfoArr) throws SemanticException {
        final int argCount = typeInfoArr.length;
        if (argCount < 2 || argCount > 3) {
            throw new UDFArgumentTypeException(argCount - 1, "_FUNC_ takes two or three arguments");
        }

        final TypeInfo actualType = typeInfoArr[0];
        if (!isLabelType(actualType)) {
            throw new UDFArgumentTypeException(0, "The first argument `array/int/boolean actual` is invalid form: " + actualType);
        }

        final TypeInfo predictedType = typeInfoArr[1];
        if (!isLabelType(predictedType)) {
            throw new UDFArgumentTypeException(1, "The second argument `array/int/boolean predicted` is invalid form: " + predictedType);
        }

        if (!actualType.equals(predictedType)) {
            throw new UDFArgumentTypeException(1, "The first argument `actual`'s type is " + actualType + ", but the second argument `predicted`'s type is not match: " + predictedType);
        }
        return new Evaluator();
    }

    /** Tells whether the type is one of the accepted label kinds: list, integer or boolean. */
    private static boolean isLabelType(TypeInfo typeInfo) {
        return HiveUtils.isListTypeInfo(typeInfo) || HiveUtils.isIntegerTypeInfo(typeInfo) || HiveUtils.isBooleanTypeInfo(typeInfo);
    }
}
