package hivemall.evaluation;

import hivemall.utils.hadoop.WritableUtils;
import java.util.List;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;

/**
 * Hive UDAF that computes an F1 score over rows of (actual, predicted) integer label arrays.
 *
 * <p>True-positive, predicted and actual counts are summed over all aggregated rows, and
 * precision/recall/F1 are derived once from those totals in {@link Evaluator#terminate()}.
 */
@Description(name = "f1score", value = "_FUNC_(array[int], array[int]) - Return a F1 score")
public final class F1ScoreUDAF extends UDAF {

    /** Evaluator holding the running counts for one aggregation group. */
    public static class Evaluator implements UDAFEvaluator {
        /** Running counts; stays {@code null} until the first row or non-null partial arrives. */
        private PartialResult partial;

        /** Mutable accumulator of the counts needed to derive precision and recall. */
        public static class PartialResult {
            /** Number of predicted elements that were found in the matching actual list. */
            long tp = 0;
            /** Total number of predicted elements seen so far. */
            long totalPredicted = 0;
            /** Total number of actual elements seen so far. */
            long totalActual = 0;

            PartialResult() {
            }

            /**
             * Folds one row into the running counts.
             *
             * <p>Every element of {@code predicted} that is contained in {@code actual} counts as
             * one true positive, so a value repeated in {@code predicted} is counted once per
             * occurrence.
             *
             * @param actual the actual labels of the row
             * @param predicted the predicted labels of the row
             */
            void updateScore(List<IntWritable> actual, List<IntWritable> predicted) {
                int truePositives = 0;
                for (IntWritable label : predicted) {
                    if (actual.contains(label)) {
                        truePositives++;
                    }
                }
                this.tp += truePositives;
                this.totalActual += actual.size();
                this.totalPredicted += predicted.size();
            }

            /**
             * Adds the counts of another partial result to this one.
             *
             * @param partialResult the counts to add; must not be {@code null}
             */
            void merge(PartialResult partialResult) {
                this.tp += partialResult.tp;
                this.totalActual += partialResult.totalActual;
                this.totalPredicted += partialResult.totalPredicted;
            }
        }

        /** Resets the evaluator by discarding any accumulated counts. */
        public void init() {
            this.partial = null;
        }

        /**
         * Consumes one input row.
         *
         * @param actual the actual labels of the row
         * @param predicted the predicted labels of the row
         * @return always {@code true}
         */
        public boolean iterate(List<IntWritable> actual, List<IntWritable> predicted) {
            if (this.partial == null) {
                this.partial = new PartialResult();
            }
            this.partial.updateScore(actual, predicted);
            return true;
        }

        /**
         * Returns the counts accumulated so far.
         *
         * @return the running counts, or {@code null} if nothing has been accumulated
         */
        public PartialResult terminatePartial() {
            return this.partial;
        }

        /**
         * Merges counts produced by another evaluator into this one.
         *
         * @param partialResult the counts to merge; {@code null} is ignored
         * @return always {@code true}
         */
        public boolean merge(PartialResult partialResult) {
            if (partialResult == null) {
                return true;
            }
            if (this.partial == null) {
                this.partial = new PartialResult();
            }
            this.partial.merge(partialResult);
            return true;
        }

        /**
         * Produces the final F1 score from the accumulated counts.
         *
         * @return the F1 score, {@code -1.0} when precision + recall is not positive, or
         *         {@code null} if nothing was accumulated
         */
        public DoubleWritable terminate() {
            if (this.partial == null) {
                return null;
            }
            return WritableUtils.val(f1Score(this.partial));
        }

        /**
         * Computes {@code 2 * precision * recall / (precision + recall)}.
         *
         * @return the F1 score, or {@code -1.0} when precision + recall is not positive
         */
        private static double f1Score(PartialResult partialResult) {
            double precision = precision(partialResult);
            double recall = recall(partialResult);
            double divisor = precision + recall;
            if (divisor > 0.0d) {
                return ((2.0d * precision) * recall) / divisor;
            }
            return -1.0d;
        }

        /**
         * Computes {@code tp / totalPredicted}.
         *
         * @return the precision, or {@code 0.0} when there are no predicted elements
         */
        private static double precision(PartialResult partialResult) {
            if (partialResult.totalPredicted == 0) {
                return 0.0d;
            }
            // Cast before dividing: long / long would truncate the ratio to 0 or 1.
            return ((double) partialResult.tp) / partialResult.totalPredicted;
        }

        /**
         * Computes {@code tp / totalActual}.
         *
         * @return the recall, or {@code 0.0} when there are no actual elements
         */
        private static double recall(PartialResult partialResult) {
            if (partialResult.totalActual == 0) {
                return 0.0d;
            }
            // Cast before dividing: long / long would truncate the ratio to a whole number.
            return ((double) partialResult.tp) / partialResult.totalActual;
        }
    }
}
