package hivemall.ftvec.trans;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hadoop.WritableUtils;
import hivemall.utils.lang.Identifier;
import hivemall.utils.lang.Preconditions;
import hivemall.utils.lang.SizeOf;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.CheckForNull;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Writable;

@UDFType(deterministic = true, stateful = true)
@Description(name = "onehot_encoding", value = "_FUNC_(PRIMITIVE feature, ...) - Compute onehot encoded label for each feature")
/**
 * Aggregate function that builds a one-hot encoding dictionary for one or more primitive columns.
 *
 * <p>For every argument column the distinct non-null values are collected. The final result is a
 * struct with fields {@code f1 .. fN} (one per argument). Each field is a map from the distinct
 * value to an integer index. Indices of column {@code k} are shifted by the total number of
 * distinct values of columns {@code 1 .. k-1}, so indices do not collide across columns.
 *
 * <p>Partial aggregates are exchanged as a struct with fields {@code f0 .. f(N-1)}, each a list of
 * the distinct values seen so far for that column.
 */
public final class OnehotEncodingUDAF extends AbstractGenericUDAFResolver {

    /**
     * Aggregation state: one {@link Identifier} per input column, allocated lazily on the first
     * {@link #iterate} or {@link #merge} call.
     */
    public static final class EncodingBuffer extends GenericUDAFEvaluator.AbstractAggregationBuffer {

        @Nullable
        private Identifier<Writable>[] identifiers;

        /** Drops all collected state. */
        void reset() {
            this.identifiers = null;
        }

        /** Allocates a typed array of identifiers; generic array creation needs an unchecked cast. */
        @SuppressWarnings("unchecked")
        @Nonnull
        private static Identifier<Writable>[] newIdentifiers(final int size) {
            return (Identifier<Writable>[]) new Identifier<?>[size];
        }

        /**
         * Registers one input row.
         *
         * @param args column values of the row; {@code null} values are ignored
         * @param argOIs inspectors for {@code args}, one per column
         * @throws HiveException when a value cannot be copied to a writable
         */
        void iterate(@Nonnull Object[] args, @Nonnull PrimitiveObjectInspector[] argOIs)
                throws HiveException {
            Preconditions.checkArgument(args.length == argOIs.length);
            final int numColumns = args.length;
            if (identifiers == null) {
                identifiers = newIdentifiers(numColumns);
                for (int i = 0; i < numColumns; i++) {
                    // NOTE(review): presumably the argument is the first id Identifier assigns.
                    identifiers[i] = new Identifier<>(1);
                }
            }
            for (int i = 0; i < numColumns; i++) {
                final Object arg = args[i];
                if (arg != null) {
                    // Copy: Hive may reuse the input object for subsequent rows.
                    identifiers[i].put(WritableUtils.copyToWritable(arg, argOIs[i]));
                }
            }
        }

        /**
         * Builds the partial aggregate.
         *
         * @return per column, a list of the distinct values seen, or {@code null} if nothing was
         *         aggregated yet
         */
        @Nullable
        Object[] partial() throws HiveException {
            if (identifiers == null) {
                return null;
            }
            final int numColumns = identifiers.length;
            final Object[] partial = new Object[numColumns];
            for (int i = 0; i < numColumns; i++) {
                final Set<Writable> keys = identifiers[i].getMap().keySet();
                final List<Writable> values = new ArrayList<>(keys.size());
                for (Writable key : keys) {
                    Preconditions.checkNotNull(key);
                    values.add(key);
                }
                partial[i] = values;
            }
            return partial;
        }

        /**
         * Merges one partial aggregate produced by {@link #partial()}.
         *
         * @param partial the partial aggregate struct
         * @param mergeOI inspector of {@code partial}
         * @param fields struct fields of {@code partial}, one per column
         * @param fieldOIs list inspectors of each field, one per column
         */
        void merge(@Nonnull Object partial, @Nonnull StructObjectInspector mergeOI,
                @Nonnull StructField[] fields, @Nonnull ListObjectInspector[] fieldOIs) {
            Preconditions.checkArgument(fields.length == fieldOIs.length);
            final int numColumns = fieldOIs.length;
            if (identifiers == null) {
                identifiers = newIdentifiers(numColumns);
            }
            Preconditions.checkArgument(fields.length == identifiers.length);
            for (int i = 0; i < numColumns; i++) {
                Identifier<Writable> id = identifiers[i];
                if (id == null) {
                    id = new Identifier<>(1);
                    identifiers[i] = id;
                }
                final Object fieldData = mergeOI.getStructFieldData(partial, fields[i]);
                if (fieldData == null) {
                    continue;
                }
                final ListObjectInspector listOI = fieldOIs[i];
                final ObjectInspector elemOI = listOI.getListElementObjectInspector();
                final int listSize = listOI.getListLength(fieldData);
                for (int j = 0; j < listSize; j++) {
                    final Object elem = listOI.getListElement(fieldData, j);
                    Preconditions.checkNotNull(elem);
                    // Deserializers may reuse element objects across rows; store a private copy
                    // so keys already in the identifier map are never mutated.
                    final Writable key = (Writable) ObjectInspectorUtils.copyToStandardObject(elem,
                        elemOI, ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE);
                    id.valueOf(key);
                }
            }
        }

        /**
         * Builds the final result without modifying the buffer.
         *
         * @return per column, a map from distinct value to its index, where indices of column
         *         {@code i} are offset by the number of distinct values of all previous columns;
         *         {@code null} if nothing was aggregated
         */
        @Nullable
        Object[] terminate() {
            if (identifiers == null) {
                return null;
            }
            final Object[] result = new Object[identifiers.length];
            int offset = 0;
            for (int i = 0; i < identifiers.length; i++) {
                final Map<Writable, Integer> ids = identifiers[i].getMap();
                if (offset == 0) {
                    result[i] = ids;
                } else {
                    // Shift into a copy instead of rewriting the Identifier's own map, so that a
                    // repeated terminate() does not apply the offset twice.
                    final Map<Writable, Integer> shifted =
                            new LinkedHashMap<>((int) (ids.size() / 0.75f) + 1);
                    for (Map.Entry<Writable, Integer> e : ids.entrySet()) {
                        shifted.put(e.getKey(), Integer.valueOf(offset + e.getValue().intValue()));
                    }
                    result[i] = shifted;
                }
                offset += ids.size();
            }
            return result;
        }
    }

    /** Evaluator that drives {@link EncodingBuffer} through Hive's aggregation modes. */
    public static final class GenericUDAFOnehotEncodingEvaluator extends GenericUDAFEvaluator {
        private PrimitiveObjectInspector[] inputElemOIs;
        private StructObjectInspector mergeOI;
        private StructField[] fields;
        private ListObjectInspector[] fieldOIs;

        @Override
        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] argOIs)
                throws HiveException {
            super.init(mode, argOIs);
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1
                    || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                // Original input rows: one primitive argument per column.
                this.inputElemOIs = new PrimitiveObjectInspector[argOIs.length];
                for (int i = 0; i < argOIs.length; i++) {
                    inputElemOIs[i] = HiveUtils.asPrimitiveObjectInspector(argOIs[i]);
                }
            } else {
                // Partial aggregates: a single struct of lists named f0 .. f(N-1).
                Preconditions.checkArgument(argOIs.length == 1);
                this.mergeOI = HiveUtils.asStructOI(argOIs[0]);
                final int numColumns = mergeOI.getAllStructFieldRefs().size();
                this.fields = new StructField[numColumns];
                this.fieldOIs = new ListObjectInspector[numColumns];
                this.inputElemOIs = new PrimitiveObjectInspector[numColumns];
                for (int i = 0; i < numColumns; i++) {
                    final StructField field = mergeOI.getStructFieldRef("f" + String.valueOf(i));
                    fields[i] = field;
                    final ListObjectInspector listOI =
                            HiveUtils.asListOI(field.getFieldObjectInspector());
                    fieldOIs[i] = listOI;
                    inputElemOIs[i] =
                            HiveUtils.asPrimitiveObjectInspector(listOI.getListElementObjectInspector());
                }
            }

            switch (mode) {
                case PARTIAL1:
                case PARTIAL2:
                    return internalMergeOutputOI(inputElemOIs);
                case COMPLETE:
                case FINAL:
                    return terminalOutputOI(inputElemOIs);
                default:
                    throw new IllegalStateException("Illegal mode: " + mode);
            }
        }

        /** Struct of writable lists, fields {@code f0 .. f(N-1)}, used for partial aggregates. */
        @Nonnull
        private static StructObjectInspector internalMergeOutputOI(
                @CheckForNull PrimitiveObjectInspector[] inputOIs) throws UDFArgumentException {
            Preconditions.checkNotNull(inputOIs);
            final int numColumns = inputOIs.length;
            final List<String> fieldNames = new ArrayList<>(numColumns);
            final List<ObjectInspector> fieldOIs = new ArrayList<>(numColumns);
            for (int i = 0; i < numColumns; i++) {
                fieldNames.add("f" + String.valueOf(i));
                final ObjectInspector elemOI = ObjectInspectorUtils.getStandardObjectInspector(
                    inputOIs[i], ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE);
                fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(elemOI));
            }
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        /** Struct of maps (value to int index), fields {@code f1 .. fN}, used for the final result. */
        @Nonnull
        private static StructObjectInspector terminalOutputOI(
                @CheckForNull PrimitiveObjectInspector[] inputOIs) {
            Preconditions.checkNotNull(inputOIs);
            Preconditions.checkArgument(inputOIs.length >= 1, Integer.valueOf(inputOIs.length));
            final List<String> fieldNames = new ArrayList<>(inputOIs.length);
            final List<ObjectInspector> fieldOIs = new ArrayList<>(inputOIs.length);
            for (int i = 0; i < inputOIs.length; i++) {
                fieldNames.add("f" + String.valueOf(i + 1));
                final ObjectInspector keyOI = ObjectInspectorUtils.getStandardObjectInspector(
                    inputOIs[i], ObjectInspectorUtils.ObjectInspectorCopyOption.WRITABLE);
                fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(keyOI,
                    PrimitiveObjectInspectorFactory.javaIntObjectInspector));
            }
            return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
        }

        @Override
        public GenericUDAFEvaluator.AggregationBuffer getNewAggregationBuffer() throws HiveException {
            final EncodingBuffer buffer = new EncodingBuffer();
            reset(buffer);
            return buffer;
        }

        @Override
        public void reset(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            ((EncodingBuffer) agg).reset();
        }

        @Override
        public void iterate(GenericUDAFEvaluator.AggregationBuffer agg, Object[] args)
                throws HiveException {
            Preconditions.checkNotNull(inputElemOIs);
            ((EncodingBuffer) agg).iterate(args, inputElemOIs);
        }

        @Override
        public Object[] terminatePartial(GenericUDAFEvaluator.AggregationBuffer agg)
                throws HiveException {
            return ((EncodingBuffer) agg).partial();
        }

        @Override
        public void merge(GenericUDAFEvaluator.AggregationBuffer agg, Object partial)
                throws HiveException {
            if (partial == null) {
                return;
            }
            ((EncodingBuffer) agg).merge(partial, mergeOI, fields, fieldOIs);
        }

        @Override
        public Object[] terminate(GenericUDAFEvaluator.AggregationBuffer agg) throws HiveException {
            return ((EncodingBuffer) agg).terminate();
        }
    }

    /**
     * Validates that there is at least one argument and that every argument is primitive.
     *
     * @throws UDFArgumentException when no argument is given
     * @throws UDFArgumentTypeException when an argument type is null or not primitive
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfoArr)
            throws SemanticException {
        final int numArgs = typeInfoArr.length;
        if (numArgs == 0) {
            throw new UDFArgumentException("onehot_encoding requires at least 1 argument");
        }
        for (int i = 0; i < numArgs; i++) {
            final TypeInfo argType = typeInfoArr[i];
            if (argType == null) {
                throw new UDFArgumentTypeException(i,
                    "Null type is found. Only primitive type arguments are accepted.");
            }
            if (argType.getCategory() != ObjectInspector.Category.PRIMITIVE) {
                throw new UDFArgumentTypeException(i,
                    "Only primitive type arguments are accepted but " + argType.getTypeName()
                            + " was passed as parameter " + (i + 1) + ".");
            }
        }
        return new GenericUDAFOnehotEncodingEvaluator();
    }
}
