package org.apache.hadoop.hive.ql.udf.generic;

import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Text;
import org.apache.hudi.org.apache.hadoop.hive.ql.metadata.HiveException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Description(name = "context_ngrams", value = "_FUNC_(expr, array<string1, string2, ...>, k, pf) estimates the top-k most frequent n-grams that fit into the specified context. The second parameter specifies a string of words that specify the positions of the n-gram elements, with a null value standing in for a 'blank' that must be filled by an n-gram element.", extended = "The primary expression must be an array of strings, or an array of arrays of strings, such as the return type of the sentences() UDF. The second parameter specifies the context -- for example, array(\"i\", \"love\", null) -- which would estimate the top 'k' words that follow the phrase \"i love\" in the primary expression. The optional fourth parameter 'pf' controls the memory used by the heuristic. Larger values will yield better accuracy, but use more memory. Example usage:\n  SELECT context_ngrams(sentences(lower(review)), array(\"i\", \"love\", null, null), 10) FROM movies\nwould attempt to determine the 10 most common two-word phrases that follow \"i love\" in a database of free-form natural language movie reviews.")
/* loaded from: input_file:org/apache/hadoop/hive/ql/udf/generic/GenericUDAFContextNGrams.class */
public class GenericUDAFContextNGrams implements GenericUDAFResolver {
    /** Class-wide SLF4J logger, registered under this resolver's fully-qualified class name. */
    static final Logger LOG = LoggerFactory.getLogger(GenericUDAFContextNGrams.class.getName());

    /* loaded from: input_file:org/apache/hadoop/hive/ql/udf/generic/GenericUDAFContextNGrams$GenericUDAFContextNGramEvaluator.class */
    public static class GenericUDAFContextNGramEvaluator extends GenericUDAFEvaluator {
        // Inspector for the first argument (the primary expression); assigned in init() for
        // PARTIAL1/COMPLETE and used by iterate() to read the outer list.
        private transient ListObjectInspector outerInputOI;
        // Inspector for the nested lists when the first argument is a list of lists;
        // init() sets it to null when the first argument is a flat list.
        private transient StandardListObjectInspector innerInputOI;
        // Inspector for the second argument (the context array).
        private transient ListObjectInspector contextListOI;
        // Inspector for the individual elements of the context array.
        private PrimitiveObjectInspector contextOI;
        // Inspector for the individual words of the primary expression.
        private PrimitiveObjectInspector inputOI;
        // Inspector for the third argument ('k').
        private transient PrimitiveObjectInspector kOI;
        // Inspector for the optional fourth argument ('pf'); null when only three arguments are given.
        private transient PrimitiveObjectInspector pOI;
        // Inspector for the single input received in the remaining modes; merge() uses it to
        // read a serialized partial aggregation.
        private transient ListObjectInspector loi;
        // Decompiler rendering of the compiler-generated assertion flag: assigned once in the
        // static initializer of this class and consulted wherever the source used `assert`.
        static final /* synthetic */ boolean $assertionsDisabled;

        /* JADX INFO: Access modifiers changed from: package-private */
        /* loaded from: input_file:org/apache/hadoop/hive/ql/udf/generic/GenericUDAFContextNGrams$GenericUDAFContextNGramEvaluator$NGramAggBuf.class */
        /**
         * Aggregation state for one group: the context pattern and the n-gram estimator
         * that accumulates the counts.
         */
        public static class NGramAggBuf extends GenericUDAFEvaluator.AbstractAggregationBuffer {
            // Context pattern; a null entry marks a blank position whose word is collected
            // into the n-gram (see how processNgrams() reads this list).
            ArrayList<String> context;
            // Estimator that receives the collected n-grams.
            NGramEstimator nge;

            // Package-private constructor; both fields are populated by the evaluator's
            // getNewAggregationBuffer().
            NGramAggBuf() {
            }
        }

        /**
         * Captures the object inspectors required by the given evaluation mode and returns
         * the inspector that describes this evaluator's output in that mode.
         *
         * <p>In PARTIAL1 and COMPLETE the inspectors of the original arguments (primary
         * expression, context array, k and the optional pf) are stored. In every other mode
         * the single input is stored as {@code loi}, which {@code merge} uses to read a
         * serialized partial aggregation.
         *
         * @param mode evaluation mode
         * @param objectInspectorArr inspectors for the inputs of this mode
         * @return a list-of-strings inspector for PARTIAL1/PARTIAL2; otherwise an inspector
         *         for a list of structs with fields {@code ngram} (list of strings) and
         *         {@code estfrequency} (double)
         * @throws HiveException if the superclass initialization fails
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public ObjectInspector init(GenericUDAFEvaluator.Mode mode, ObjectInspector[] objectInspectorArr) throws HiveException {
            super.init(mode, objectInspectorArr);
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.COMPLETE) {
                this.outerInputOI = (ListObjectInspector) objectInspectorArr[0];
                ObjectInspector outerElementOI = this.outerInputOI.getListElementObjectInspector();
                if (outerElementOI.getCategory() == ObjectInspector.Category.LIST) {
                    // array<array<string>>: words live one level further down.
                    this.innerInputOI = (StandardListObjectInspector) outerElementOI;
                    this.inputOI = (PrimitiveObjectInspector) this.innerInputOI.getListElementObjectInspector();
                } else {
                    // array<string>: the outer elements are the words themselves.
                    this.inputOI = (PrimitiveObjectInspector) outerElementOI;
                    this.innerInputOI = null;
                }
                this.contextListOI = (ListObjectInspector) objectInspectorArr[1];
                this.contextOI = (PrimitiveObjectInspector) this.contextListOI.getListElementObjectInspector();
                this.kOI = (PrimitiveObjectInspector) objectInspectorArr[2];
                // The precision factor is optional; remember its absence as null.
                this.pOI = objectInspectorArr.length == 4 ? (PrimitiveObjectInspector) objectInspectorArr[3] : null;
            } else {
                this.loi = (ListObjectInspector) objectInspectorArr[0];
            }
            if (mode == GenericUDAFEvaluator.Mode.PARTIAL1 || mode == GenericUDAFEvaluator.Mode.PARTIAL2) {
                // Partial results travel as a flat list of strings.
                return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
            }
            // Final result: array<struct<ngram:array<string>, estfrequency:double>>.
            List<ObjectInspector> fieldInspectors = new ArrayList<>();
            fieldInspectors.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector));
            fieldInspectors.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
            List<String> fieldNames = new ArrayList<>();
            fieldNames.add("ngram");
            fieldNames.add("estfrequency");
            return ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldInspectors));
        }

        /**
         * Folds a serialized partial aggregation into the given buffer.
         *
         * <p>The partial is read as a list whose last element is the context length and whose
         * preceding {@code contextLength} elements are the context words, with an empty
         * string standing for a blank (null) position. Everything before that suffix is
         * handed to the estimator.
         *
         * <p>If the buffer has no context yet, it is adopted from the partial. If it already
         * has one, only the lengths are compared. In both cases the context suffix is
         * stripped and the remaining estimator state is merged, so no partial is discarded.
         *
         * @param aggregationBuffer the {@link NGramAggBuf} to merge into
         * @param obj the serialized partial aggregation, or null (ignored)
         * @throws HiveException if the partial's context length differs from the buffer's
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public void merge(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object obj) throws HiveException {
            if (obj == null) {
                return;
            }
            NGramAggBuf nGramAggBuf = (NGramAggBuf) aggregationBuffer;
            // NOTE: the list obtained from the inspector is trimmed in place below.
            List<?> partial = this.loi.getList(obj);

            // The trailing element carries the number of context words that precede it.
            int lastIndex = partial.size() - 1;
            int contextSize = Integer.parseInt(partial.get(lastIndex).toString());
            partial.remove(lastIndex);
            int contextStart = partial.size() - contextSize;

            if (nGramAggBuf.context.size() > 0) {
                if (contextSize != nGramAggBuf.context.size()) {
                    throw new HiveException(getClass().getSimpleName() + ": found a mismatch in the context string lengths. This is usually caused by passing a non-constant expression for the context.");
                }
            } else {
                // First partial seen by this buffer: adopt its context, decoding "" as a blank.
                for (int i = contextStart; i < partial.size(); i++) {
                    String word = partial.get(i).toString();
                    nGramAggBuf.context.add(word.equals("") ? null : word);
                }
            }

            // Drop the context suffix so only estimator state remains, then always merge it.
            partial.subList(contextStart, partial.size()).clear();
            nGramAggBuf.nge.merge(partial);
        }

        /**
         * Serializes the buffer into a partial result: the estimator's own serialization,
         * followed by one entry per context word (an empty string for a null/blank position),
         * followed by the number of context words.
         *
         * @param aggregationBuffer the {@link NGramAggBuf} to serialize
         * @return the list of {@link Text} values described above
         * @throws HiveException declared by the evaluator contract
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public Object terminatePartial(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            NGramAggBuf buffer = (NGramAggBuf) aggregationBuffer;
            ArrayList<Text> partial = buffer.nge.serialize();
            for (String contextWord : buffer.context) {
                // Blanks are stored as null in memory and written out as "".
                String encoded = contextWord == null ? "" : contextWord;
                partial.add(new Text(encoded));
            }
            int contextLength = buffer.context.size();
            partial.add(new Text(Integer.toString(contextLength)));
            return partial;
        }

        /**
         * Slides the context pattern over a word sequence, from the last possible start
         * position down to position 0, and feeds the estimator one n-gram for every position
         * where all non-null context words match. The n-gram consists of the sequence words
         * found at the pattern's null (blank) positions.
         *
         * @param nGramAggBuf buffer holding the context pattern and the estimator
         * @param arrayList the word sequence to scan
         * @throws HiveException declared for the estimator call
         */
        private void processNgrams(NGramAggBuf nGramAggBuf, ArrayList<String> arrayList) throws HiveException {
            if (!$assertionsDisabled && nGramAggBuf.context.size() <= 0) {
                throw new AssertionError();
            }
            final int contextLength = nGramAggBuf.context.size();
            ArrayList<String> blanks = new ArrayList<>();
            int start = arrayList.size() - contextLength;
            while (start >= 0) {
                blanks.clear();
                boolean matches = true;
                for (int offset = 0; matches && offset < contextLength; offset++) {
                    String expected = nGramAggBuf.context.get(offset);
                    String actual = arrayList.get(start + offset);
                    if (expected == null) {
                        // Blank position: whatever word is here becomes part of the n-gram.
                        blanks.add(actual);
                    } else {
                        matches = expected.equals(actual);
                    }
                }
                if (matches) {
                    nGramAggBuf.nge.add(blanks);
                    // Start a fresh list so the one just handed over is not cleared next round.
                    blanks = new ArrayList<>();
                }
                start--;
            }
        }

        /**
         * Processes one input row.
         *
         * <p>Rows whose first three arguments contain a null are skipped. On the first usable
         * row the estimator is initialized from k, the optional pf and the context array.
         * The primary expression is then scanned: directly when it is a flat list of words,
         * or sentence by sentence when it is a list of lists. Null lists (outer or inner)
         * are skipped instead of causing a NullPointerException.
         *
         * @param aggregationBuffer the {@link NGramAggBuf} for the current group
         * @param objArr the row's arguments: expression, context, k and optionally pf
         * @throws HiveException if k or pf is below 1, the context is empty, or the context
         *         contains no null entry
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public void iterate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer, Object[] objArr) throws HiveException {
            if (!$assertionsDisabled && objArr.length != 3 && objArr.length != 4) {
                throw new AssertionError();
            }
            if (objArr[0] == null || objArr[1] == null || objArr[2] == null) {
                return;
            }
            NGramAggBuf nGramAggBuf = (NGramAggBuf) aggregationBuffer;
            if (!nGramAggBuf.nge.isInitialized()) {
                initializeEstimator(nGramAggBuf, objArr);
            }
            List<?> outer = this.outerInputOI.getList(objArr[0]);
            if (outer == null) {
                return;
            }
            if (this.innerInputOI == null) {
                // Flat array<string>: the whole row is one word sequence.
                processNgrams(nGramAggBuf, toWords(outer));
                return;
            }
            // array<array<string>>: every inner list is scanned as its own sequence.
            for (int i = 0; i < outer.size(); i++) {
                List<?> inner = this.innerInputOI.getList(outer.get(i));
                if (inner != null) {
                    processNgrams(nGramAggBuf, toWords(inner));
                }
            }
        }

        /** Reads k, pf and the context from the row, validates them and initializes the estimator. */
        private void initializeEstimator(NGramAggBuf nGramAggBuf, Object[] objArr) throws HiveException {
            int k = PrimitiveObjectInspectorUtils.getInt(objArr[2], this.kOI);
            if (k < 1) {
                throw new HiveException(getClass().getSimpleName() + " needs 'k' to be at least 1, but you supplied " + k);
            }
            int pf = 1;
            if (objArr.length == 4) {
                pf = PrimitiveObjectInspectorUtils.getInt(objArr[3], this.pOI);
                if (pf < 1) {
                    throw new HiveException(getClass().getSimpleName() + " needs 'pf' to be at least 1, but you supplied " + pf);
                }
            }
            nGramAggBuf.context.clear();
            List<?> contextWords = this.contextListOI.getList(objArr[1]);
            if (contextWords == null || contextWords.isEmpty()) {
                throw new HiveException(getClass().getSimpleName() + " needs a context array with at least one element.");
            }
            // Each null entry is a blank to be filled; their count is passed on as the n-gram length.
            int blankCount = 0;
            for (int i = 0; i < contextWords.size(); i++) {
                String word = PrimitiveObjectInspectorUtils.getString(contextWords.get(i), this.contextOI);
                if (word == null) {
                    blankCount++;
                }
                nGramAggBuf.context.add(word);
            }
            if (blankCount == 0) {
                throw new HiveException(getClass().getSimpleName() + " the context array needs to contain at least one 'null' value to indicate what should be counted.");
            }
            nGramAggBuf.nge.initialize(k, pf, blankCount);
        }

        /** Converts a raw list of word objects into Java strings using the word inspector. */
        private ArrayList<String> toWords(List<?> rawWords) {
            ArrayList<String> words = new ArrayList<>(rawWords.size());
            for (int i = 0; i < rawWords.size(); i++) {
                words.add(PrimitiveObjectInspectorUtils.getString(rawWords.get(i), this.inputOI));
            }
            return words;
        }

        /**
         * Produces the final result for a group by asking the buffer's estimator for its n-grams.
         *
         * @param aggregationBuffer the {@link NGramAggBuf} of the group
         * @return whatever the estimator's {@code getNGrams()} returns
         * @throws HiveException declared by the evaluator contract
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public Object terminate(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            NGramAggBuf buffer = (NGramAggBuf) aggregationBuffer;
            return buffer.nge.getNGrams();
        }

        /**
         * Creates a buffer with a fresh estimator and an empty context list, then passes it
         * through {@code reset} before returning it.
         *
         * @return the new {@link NGramAggBuf}
         * @throws HiveException declared by the evaluator contract
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public GenericUDAFEvaluator.AggregationBuffer getNewAggregationBuffer() throws HiveException {
            NGramAggBuf fresh = new NGramAggBuf();
            fresh.nge = new NGramEstimator();
            fresh.context = new ArrayList<>();
            reset(fresh);
            return fresh;
        }

        /**
         * Returns a buffer to its empty state: the context list is cleared and the estimator
         * is reset.
         *
         * @param aggregationBuffer the {@link NGramAggBuf} to clear
         * @throws HiveException declared by the evaluator contract
         */
        @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
        public void reset(GenericUDAFEvaluator.AggregationBuffer aggregationBuffer) throws HiveException {
            NGramAggBuf buffer = (NGramAggBuf) aggregationBuffer;
            buffer.context.clear();
            buffer.nge.reset();
        }

        // Decompiler rendering of the compiler-generated assertion setup: the flag is true
        // when assertions are NOT enabled for the outer class, so the guarded AssertionError
        // checks in this evaluator only fire when assertions are enabled.
        static {
            $assertionsDisabled = !GenericUDAFContextNGrams.class.desiredAssertionStatus();
        }
    }

    /**
     * Validates the argument types of a {@code context_ngrams} call and returns its evaluator.
     *
     * <p>Accepted signature: an {@code array<string>} or {@code array<array<string>>}
     * expression, an {@code array<string>} context, 'k', and optionally 'pf'. 'k' and 'pf'
     * must be of primitive category BYTE, SHORT, INT, LONG or TIMESTAMP. A nested array
     * whose innermost element is not primitive is rejected with a
     * {@link UDFArgumentTypeException} rather than failing on a cast.
     *
     * @param typeInfoArr type information of the call's arguments
     * @return a new {@link GenericUDAFContextNGramEvaluator}
     * @throws SemanticException (as {@link UDFArgumentTypeException}) if the argument count
     *         or any argument type is not accepted
     */
    @Override // org.apache.hadoop.hive.ql.udf.generic.GenericUDAFResolver
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] typeInfoArr) throws SemanticException {
        if (typeInfoArr.length != 3 && typeInfoArr.length != 4) {
            throw new UDFArgumentTypeException(typeInfoArr.length - 1, "Please specify either three or four arguments.");
        }

        // Parameter 1: array<string> or array<array<string>>.
        if (typeInfoArr[0].getCategory() != ObjectInspector.Category.LIST) {
            throw new UDFArgumentTypeException(0, "Only list type arguments are accepted but " + typeInfoArr[0].getTypeName() + " was passed as parameter 1.");
        }
        TypeInfo wordType = ((ListTypeInfo) typeInfoArr[0]).getListElementTypeInfo();
        if (wordType.getCategory() == ObjectInspector.Category.LIST) {
            // One level of nesting is allowed; look at the inner element type.
            wordType = ((ListTypeInfo) wordType).getListElementTypeInfo();
        }
        if (wordType.getCategory() != ObjectInspector.Category.PRIMITIVE) {
            // Checked before the cast below so unsupported nesting yields a type error, not a ClassCastException.
            throw new UDFArgumentTypeException(0, "Only arrays of strings or arrays of arrays of strings are accepted but " + typeInfoArr[0].getTypeName() + " was passed as parameter 1.");
        }
        if (((PrimitiveTypeInfo) wordType).getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
            throw new UDFArgumentTypeException(0, "Only array<string> or array<array<string>> is allowed, but " + typeInfoArr[0].getTypeName() + " was passed as parameter 1.");
        }

        // Parameter 2: array<string> context.
        if (typeInfoArr[1].getCategory() != ObjectInspector.Category.LIST || ((ListTypeInfo) typeInfoArr[1]).getListElementTypeInfo().getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(1, "Only arrays of strings are accepted but " + typeInfoArr[1].getTypeName() + " was passed as parameter 2.");
        }
        if (((PrimitiveTypeInfo) ((ListTypeInfo) typeInfoArr[1]).getListElementTypeInfo()).getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
            throw new UDFArgumentTypeException(1, "Only arrays of strings are accepted but " + typeInfoArr[1].getTypeName() + " was passed as parameter 2.");
        }

        // Parameter 3 ('k') and optional parameter 4 ('pf').
        requireIntegerArgument(typeInfoArr, 2);
        if (typeInfoArr.length == 4) {
            requireIntegerArgument(typeInfoArr, 3);
        }
        return new GenericUDAFContextNGramEvaluator();
    }

    /** Rejects the argument at the given zero-based index unless its type is one of the accepted integer-like primitives. */
    private static void requireIntegerArgument(TypeInfo[] typeInfoArr, int index) throws UDFArgumentTypeException {
        TypeInfo argumentType = typeInfoArr[index];
        if (argumentType.getCategory() == ObjectInspector.Category.PRIMITIVE) {
            switch (((PrimitiveTypeInfo) argumentType).getPrimitiveCategory()) {
                case BYTE:
                case SHORT:
                case INT:
                case LONG:
                case TIMESTAMP:
                    return;
                default:
                    break;
            }
        }
        throw new UDFArgumentTypeException(index, "Only integers are accepted but " + argumentType.getTypeName() + " was passed as parameter " + (index + 1) + ".");
    }
}
