package hivemall.ftvec.selection;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Preconditions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFParameterInfo;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

/**
 * Hive UDAF resolver for {@code snr(features, labels)}: computes a Signal-to-Noise Ratio score per
 * feature from numeric feature vectors and one-hot encoded class labels.
 *
 * <p>For every feature {@code f}, the result is the sum over class pairs {@code (a, b)} of
 * {@code |mean_a(f) - mean_b(f)| / (sd_a(f) + sd_b(f))}, where {@code sd} is the square root of the
 * per-class population variance (divisor {@code n}) accumulated by the evaluator.
 */
@Description(name = "snr", value = "_FUNC_(array<number> features, array<int> one-hot class label) - Returns Signal Noise Ratio for each feature as array<double>")
public class SignalNoiseRatioUDAF extends AbstractGenericUDAFResolver {

    /**
     * Evaluator that keeps, for each class, the row count and the running per-feature mean and
     * population variance, then combines them into SNR scores in {@link #terminate}.
     */
    static class SignalNoiseRatioUDAFEvaluator extends GenericUDAFEvaluator {
        // Original input inspectors (PARTIAL1 / COMPLETE).
        private ListObjectInspector featuresOI;
        private PrimitiveObjectInspector featureOI;
        private ListObjectInspector labelsOI;
        private PrimitiveObjectInspector labelOI;

        // Partial-aggregation struct inspectors (PARTIAL2 / FINAL).
        private StructObjectInspector structOI;
        private StructField countsField;
        private StructField meansField;
        private StructField variancesField;
        private ListObjectInspector countsOI;
        private LongObjectInspector countOI;
        private ListObjectInspector meansOI;
        private ListObjectInspector meanListOI;
        private DoubleObjectInspector meanElemOI;
        private ListObjectInspector variancesOI;
        private ListObjectInspector varianceListOI;
        private DoubleObjectInspector varianceElemOI;

        /**
         * Per-group state. The arrays are allocated lazily on the first row or partial seen, so
         * {@code counts == null} means "nothing aggregated yet".
         */
        @GenericUDAFEvaluator.AggregationType(estimable = true)
        public static class SignalNoiseRatioAggregationBuffer extends GenericUDAFEvaluator.AbstractAggregationBuffer {
            /** Number of rows seen per class. */
            long[] counts;
            /** {@code means[class][feature]}: running mean. */
            double[][] means;
            /** {@code variances[class][feature]}: running population variance. */
            double[][] variances;

            SignalNoiseRatioAggregationBuffer() {
            }

            /**
             * Estimates the buffer's memory footprint in bytes. The size is computed in {@code long}
             * and clamped so that large class/feature counts cannot overflow into a negative value.
             */
            @Override
            public int estimate() {
                if (counts == null) {
                    return 0;
                }
                final long bytes = 8L * counts.length
                        + 8L * means.length * means[0].length
                        + 8L * variances.length * variances[0].length;
                return (int) Math.min(bytes, Integer.MAX_VALUE);
            }

            /** Allocates zeroed state for {@code nClasses} classes and {@code nFeatures} features. */
            public void init(int nClasses, int nFeatures) {
                this.counts = new long[nClasses];
                this.means = new double[nClasses][nFeatures];
                this.variances = new double[nClasses][nFeatures];
            }

            /** Zeroes any allocated state, keeping the current dimensions. */
            public void reset() {
                if (counts == null) {
                    return;
                }
                Arrays.fill(counts, 0L);
                for (double[] row : means) {
                    Arrays.fill(row, 0.0d);
                }
                for (double[] row : variances) {
                    Arrays.fill(row, 0.0d);
                }
            }
        }

        SignalNoiseRatioUDAFEvaluator() {
        }

        /**
         * Resolves the input inspectors for the given mode. It returns the partial struct
         * {@code {counts, means, variances}} for PARTIAL1/PARTIAL2 and {@code array<double>}
         * otherwise.
         */
        @Override
        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] argOIs) throws HiveException {
            super.init(mode, argOIs);

            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                // Raw rows: (array<number> features, array<int> one-hot labels).
                this.featuresOI = HiveUtils.asListOI(argOIs[0]);
                this.featureOI = HiveUtils.asDoubleCompatibleOI(featuresOI.getListElementObjectInspector());
                this.labelsOI = HiveUtils.asListOI(argOIs[1]);
                this.labelOI = HiveUtils.asIntegerOI(labelsOI.getListElementObjectInspector());
            } else {
                // Partial aggregations produced by terminatePartial().
                this.structOI = (StructObjectInspector) argOIs[0];
                this.countsField = structOI.getStructFieldRef("counts");
                this.countsOI = HiveUtils.asListOI(countsField.getFieldObjectInspector());
                this.countOI = HiveUtils.asLongOI(countsOI.getListElementObjectInspector());
                this.meansField = structOI.getStructFieldRef("means");
                this.meansOI = HiveUtils.asListOI(meansField.getFieldObjectInspector());
                this.meanListOI = HiveUtils.asListOI(meansOI.getListElementObjectInspector());
                this.meanElemOI = HiveUtils.asDoubleOI(meanListOI.getListElementObjectInspector());
                this.variancesField = structOI.getStructFieldRef("variances");
                this.variancesOI = HiveUtils.asListOI(variancesField.getFieldObjectInspector());
                this.varianceListOI = HiveUtils.asListOI(variancesOI.getListElementObjectInspector());
                this.varianceElemOI = HiveUtils.asDoubleOI(varianceListOI.getListElementObjectInspector());
            }

            if (mode != GenericUDAFEvaluator.Mode.PARTIAL1 && mode != GenericUDAFEvaluator.Mode.PARTIAL2) {
                return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            }

            final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(3);
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableLongObjectInspector));
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)));
            return ObjectInspectorFactory.getStandardStructObjectInspector(Arrays.asList("counts", "means", "variances"), fieldOIs);
        }

        /**
         * Creates a fresh aggregation buffer. The decompiled name {@code m81getNewAggregationBuffer}
         * did not override the abstract {@code getNewAggregationBuffer()}, so it is restored here.
         */
        @Override
        public GenericUDAFEvaluator.AbstractAggregationBuffer getNewAggregationBuffer() throws HiveException {
            final SignalNoiseRatioAggregationBuffer buf = new SignalNoiseRatioAggregationBuffer();
            reset(buf);
            return buf;
        }

        @Override
        public void reset(@SuppressWarnings("deprecation") GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            ((SignalNoiseRatioAggregationBuffer) agg).reset();
        }

        /**
         * Folds one row into the buffer, updating the mean and population variance of the row's hot
         * class with a Welford-style incremental update.
         *
         * @throws UDFArgumentException if there are fewer than 2 classes or no features, if the
         *         dimensions differ from earlier rows, or if the labels are not a valid one-hot vector
         */
        @Override
        public void iterate(@SuppressWarnings("deprecation") GenericUDAFEvaluator.AggregationBuffer agg, Object[] parameters) throws HiveException {
            final Object featuresObj = parameters[0];
            final Object labelsObj = parameters[1];
            Preconditions.checkNotNull(featuresObj);
            Preconditions.checkNotNull(labelsObj);

            final SignalNoiseRatioAggregationBuffer buf = (SignalNoiseRatioAggregationBuffer) agg;

            final List<?> labels = labelsOI.getList(labelsObj);
            final int nClasses = labels.size();
            Preconditions.checkArgument(nClasses >= 2, UDFArgumentException.class);

            final List<?> features = featuresOI.getList(featuresObj);
            final int nFeatures = features.size();
            Preconditions.checkArgument(nFeatures >= 1, UDFArgumentException.class);

            if (buf.counts == null) {
                buf.init(nClasses, nFeatures);
            } else {
                Preconditions.checkArgument(nClasses == buf.counts.length, UDFArgumentException.class);
                Preconditions.checkArgument(nFeatures == buf.means[0].length, UDFArgumentException.class);
            }

            final int clazz = hotIndex(labels, labelOI);
            final long n = buf.counts[clazz];
            buf.counts[clazz] = n + 1;
            for (int f = 0; f < nFeatures; f++) {
                final double x = PrimitiveObjectInspectorUtils.getDouble(features.get(f), featureOI);
                final double oldMean = buf.means[clazz][f];
                final double oldVar = buf.variances[clazz][f];
                final double newMean = (n * oldMean + x) / (n + 1.0d);
                buf.means[clazz][f] = newMean;
                // Population variance: M2_{n+1} = M2_n + (x - oldMean)(x - newMean), divided by n+1.
                buf.variances[clazz][f] = (n * oldVar + (x - oldMean) * (x - newMean)) / (n + 1.0d);
            }
        }

        /**
         * Returns the index of the single element equal to 1 in a 0/1 label vector.
         *
         * @throws UDFArgumentException if an element is neither 0 nor 1, or if the number of 1s is
         *         not exactly one
         */
        private static int hotIndex(@Nonnull List<?> labels, PrimitiveObjectInspector labelOI) throws UDFArgumentException {
            final int size = labels.size();
            int hot = -1;
            for (int i = 0; i < size; i++) {
                final int label = PrimitiveObjectInspectorUtils.getInt(labels.get(i), labelOI);
                if (label == 1) {
                    if (hot != -1) {
                        throw new UDFArgumentException("Specify one-hot vectorized array. Multiple hot elements found.");
                    }
                    hot = i;
                } else if (label != 0) {
                    throw new UDFArgumentException("Assumed one-hot encoding (0/1) but found an invalid label: " + label);
                }
            }
            if (hot == -1) {
                throw new UDFArgumentException("Specify one-hot vectorized array for label. Hot element not found.");
            }
            return hot;
        }

        /**
         * Merges a partial aggregation into the buffer using the pairwise (Chan et al.) update.
         *
         * <p>Partials hold <em>population</em> variances, the same as {@link #iterate}. The previous
         * implementation combined them with the sample-variance (n-1) formula, so results depended
         * on how Hive partitioned the data. With n, m counts and delta = meanM - meanN:
         * {@code var = (n*varN + m*varM + delta^2 * n*m/(n+m)) / (n+m)}.
         *
         * @throws UDFArgumentException if the partial's class or feature count differs from the
         *         buffer's
         */
        @Override
        public void merge(@SuppressWarnings("deprecation") GenericUDAFEvaluator.AggregationBuffer agg, Object other) throws HiveException {
            if (other == null) {
                return;
            }
            final SignalNoiseRatioAggregationBuffer buf = (SignalNoiseRatioAggregationBuffer) agg;

            final List<?> counts = countsOI.getList(structOI.getStructFieldData(other, countsField));
            final List<?> means = meansOI.getList(structOI.getStructFieldData(other, meansField));
            final List<?> variances = variancesOI.getList(structOI.getStructFieldData(other, variancesField));
            if (counts == null || counts.isEmpty()) {
                return; // an empty partial carries no observations
            }

            final int nClasses = counts.size();
            final int nFeatures = meanListOI.getListLength(means.get(0));
            if (buf.counts == null) {
                buf.init(nClasses, nFeatures);
            } else {
                Preconditions.checkArgument(nClasses == buf.counts.length, UDFArgumentException.class);
                Preconditions.checkArgument(nFeatures == buf.means[0].length, UDFArgumentException.class);
            }

            for (int c = 0; c < nClasses; c++) {
                final long m = PrimitiveObjectInspectorUtils.getLong(counts.get(c), countOI);
                if (m == 0L) {
                    continue; // nothing to merge for this class
                }
                final long n = buf.counts[c];
                final List<?> otherMeans = meanListOI.getList(means.get(c));
                final List<?> otherVariances = varianceListOI.getList(variances.get(c));
                buf.counts[c] = n + m;

                final double total = (double) n + (double) m;
                for (int f = 0; f < nFeatures; f++) {
                    final double meanM = PrimitiveObjectInspectorUtils.getDouble(otherMeans.get(f), meanElemOI);
                    final double varM = PrimitiveObjectInspectorUtils.getDouble(otherVariances.get(f), varianceElemOI);
                    if (n == 0L) {
                        // The buffer holds nothing for this class yet: adopt the partial as-is.
                        buf.means[c][f] = meanM;
                        buf.variances[c][f] = varM;
                    } else {
                        final double meanN = buf.means[c][f];
                        final double varN = buf.variances[c][f];
                        final double delta = meanM - meanN;
                        buf.means[c][f] = ((double) n * meanN + (double) m * meanM) / total;
                        buf.variances[c][f] = ((double) n * varN + (double) m * varM
                                + delta * delta * (double) n * (double) m / total) / total;
                    }
                }
            }
        }

        /**
         * Serializes the buffer as {@code {counts, means, variances}}. It returns {@code null} when
         * no rows were aggregated, which {@link #merge} ignores.
         */
        @Override
        public Object terminatePartial(@SuppressWarnings("deprecation") GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            final SignalNoiseRatioAggregationBuffer buf = (SignalNoiseRatioAggregationBuffer) agg;
            if (buf.counts == null) {
                return null;
            }

            final List<Object> meansList = new ArrayList<Object>(buf.means.length);
            for (double[] row : buf.means) {
                meansList.add(WritableUtils.toWritableList(row));
            }
            final List<Object> variancesList = new ArrayList<Object>(buf.variances.length);
            for (double[] row : buf.variances) {
                variancesList.add(WritableUtils.toWritableList(row));
            }
            return new Object[] {WritableUtils.toWritableList(buf.counts), meansList, variancesList};
        }

        /**
         * Computes the per-feature SNR. For every class pair it adds
         * {@code |mean_j - mean_k| / (sd_j + sd_k)} to the feature's score. A pair is skipped when
         * either class is empty or both classes hold exactly one row, and NaN terms are skipped.
         * Returns {@code null} when no rows were aggregated.
         */
        @Override
        public Object terminate(@SuppressWarnings("deprecation") GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            final SignalNoiseRatioAggregationBuffer buf = (SignalNoiseRatioAggregationBuffer) agg;
            if (buf.counts == null) {
                return null;
            }

            final int nClasses = buf.counts.length;
            final int nFeatures = buf.means[0].length;
            final double[] snr = new double[nFeatures];
            final double[] sds = new double[nClasses];

            for (int f = 0; f < nFeatures; f++) {
                sds[0] = Math.sqrt(buf.variances[0][f]);
                for (int j = 1; j < nClasses; j++) {
                    sds[j] = Math.sqrt(buf.variances[j][f]);
                    if (buf.counts[j] == 0) {
                        continue;
                    }
                    for (int k = 0; k < j; k++) {
                        if (buf.counts[k] == 0) {
                            continue;
                        }
                        if (buf.counts[j] == 1 && buf.counts[k] == 1) {
                            continue; // both std devs are 0: the ratio is undefined
                        }
                        final double score = Math.abs(buf.means[j][f] - buf.means[k][f]) / (sds[j] + sds[k]);
                        if (!Double.isNaN(score)) {
                            snr[f] += score;
                        }
                    }
                }
            }
            return WritableUtils.toWritableList(snr);
        }
    }

    /**
     * Validates the argument types {@code (array<number>, array<int>)} and returns the evaluator.
     *
     * @throws UDFArgumentLengthException if the argument count is not 2
     * @throws UDFArgumentTypeException if an argument has an unsupported type
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(GenericUDAFParameterInfo info) throws SemanticException {
        final ObjectInspector[] argOIs = info.getParameterObjectInspectors();
        if (argOIs.length != 2) {
            throw new UDFArgumentLengthException("Specify two arguments: " + argOIs.length);
        }
        if (!HiveUtils.isNumberListOI(argOIs[0])) {
            throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but " + argOIs[0].getTypeName() + " was passed as `features`");
        }
        if (!HiveUtils.isListOI(argOIs[1]) || !HiveUtils.isIntegerOI(((ListObjectInspector) argOIs[1]).getListElementObjectInspector())) {
            throw new UDFArgumentTypeException(1, "Only array<int> type argument is acceptable but " + argOIs[1].getTypeName() + " was passed as `labels`");
        }
        return new SignalNoiseRatioUDAFEvaluator();
    }
}
