package hivemall.tools.array;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Preconditions;
import java.io.IOException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

@UDFType(deterministic = true, stateful = false)
@Description(name = "select_k_best", value = "_FUNC_(array<number> array, const array<number> importance, const int k) - Returns selected top-k elements as array<double>")
/* loaded from: input_file:hivemall/tools/array/SelectKBestUDF.class */
public final class SelectKBestUDF extends GenericUDF {
    private ListObjectInspector featuresOI;
    private PrimitiveObjectInspector featureOI;
    private ListObjectInspector importanceListOI;
    private PrimitiveObjectInspector importanceElemOI;
    private int _k;
    private List<DoubleWritable> _result;
    private int[] _topKIndices;

    public ObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        if (objectInspectorArr.length != 3) {
            throw new UDFArgumentLengthException("Specify three arguments: " + objectInspectorArr.length);
        }
        if (!HiveUtils.isNumberListOI(objectInspectorArr[0])) {
            throw new UDFArgumentTypeException(0, "Only array<number> type argument is acceptable but " + objectInspectorArr[0].getTypeName() + " was passed as `features`");
        }
        if (!HiveUtils.isNumberListOI(objectInspectorArr[1])) {
            throw new UDFArgumentTypeException(1, "Only array<number> type argument is acceptable but " + objectInspectorArr[1].getTypeName() + " was passed as `importance_list`");
        }
        if (!HiveUtils.isIntegerOI(objectInspectorArr[2])) {
            throw new UDFArgumentTypeException(2, "Only int type argument is acceptable but " + objectInspectorArr[2].getTypeName() + " was passed as `k`");
        }
        this.featuresOI = HiveUtils.asListOI(objectInspectorArr[0]);
        this.featureOI = HiveUtils.asDoubleCompatibleOI(this.featuresOI.getListElementObjectInspector());
        this.importanceListOI = HiveUtils.asListOI(objectInspectorArr[1]);
        this.importanceElemOI = HiveUtils.asDoubleCompatibleOI(this.importanceListOI.getListElementObjectInspector());
        this._k = HiveUtils.getConstInt(objectInspectorArr[2]);
        Preconditions.checkArgument(this._k >= 1, UDFArgumentException.class);
        ArrayList arrayList = new ArrayList(this._k);
        for (int i = 0; i < this._k; i++) {
            arrayList.add(new DoubleWritable());
        }
        this._result = arrayList;
        return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
    }

    /* renamed from: evaluate, reason: merged with bridge method [inline-methods] */
    public List<DoubleWritable> m177evaluate(GenericUDF.DeferredObject[] deferredObjectArr) throws HiveException {
        double[] asDoubleArray = HiveUtils.asDoubleArray(deferredObjectArr[0].get(), this.featuresOI, this.featureOI);
        double[] asDoubleArray2 = HiveUtils.asDoubleArray(deferredObjectArr[1].get(), this.importanceListOI, this.importanceElemOI);
        Preconditions.checkNotNull(asDoubleArray, UDFArgumentException.class);
        Preconditions.checkNotNull(asDoubleArray2, UDFArgumentException.class);
        Preconditions.checkArgument(asDoubleArray.length == asDoubleArray2.length, UDFArgumentException.class);
        Preconditions.checkArgument(asDoubleArray.length >= this._k, UDFArgumentException.class);
        int[] iArr = this._topKIndices;
        if (iArr == null) {
            ArrayList arrayList = new ArrayList();
            for (int i = 0; i < asDoubleArray2.length; i++) {
                arrayList.add(new AbstractMap.SimpleEntry(Integer.valueOf(i), Double.valueOf(asDoubleArray2[i])));
            }
            Collections.sort(arrayList, new Comparator<Map.Entry<Integer, Double>>() { // from class: hivemall.tools.array.SelectKBestUDF.1
                @Override // java.util.Comparator
                public int compare(Map.Entry<Integer, Double> entry, Map.Entry<Integer, Double> entry2) {
                    return entry.getValue().doubleValue() > entry2.getValue().doubleValue() ? -1 : 1;
                }
            });
            iArr = new int[this._k];
            for (int i2 = 0; i2 < iArr.length; i2++) {
                iArr[i2] = ((Integer) ((Map.Entry) arrayList.get(i2)).getKey()).intValue();
            }
            this._topKIndices = iArr;
        }
        List<DoubleWritable> list = this._result;
        for (int i3 = 0; i3 < iArr.length; i3++) {
            list.get(i3).set(asDoubleArray[iArr[i3]]);
        }
        return list;
    }

    public void close() throws IOException {
        this._result = null;
        this._topKIndices = null;
    }

    public String getDisplayString(String[] strArr) {
        StringBuilder sb = new StringBuilder();
        sb.append("select_k_best");
        sb.append("(");
        if (strArr.length > 0) {
            sb.append(strArr[0]);
            for (int i = 1; i < strArr.length; i++) {
                sb.append(", ");
                sb.append(strArr[i]);
            }
        }
        sb.append(")");
        return sb.toString();
    }
}
