package hivemall.ftvec.hashing;

import hivemall.HivemallConstants;
import hivemall.UDFWithOptions;
import hivemall.annotations.VisibleForTesting;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.hashing.MurmurHash3;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.StringUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

/**
 * Hive UDF {@code feature_hashing} that replaces the feature-name part of each feature string
 * with a MurmurHash3-based integer index in {@code [1, numFeatures]}.
 *
 * <p>The first argument is either a list (each non-null element is hashed and a list of strings
 * is returned) or any other value (its {@code toString()} is hashed and a single string is
 * returned). The optional second argument is a constant option string; see
 * {@link #getOptions()}.
 *
 * <p>Accepted feature shapes, as handled by {@link #featureHashing(String, int, boolean)}:
 * {@code name}, {@code name:value} and {@code field:name:value}.
 */
@UDFType(deterministic = true, stateful = false)
@Description(name = "feature_hashing",
        value = "_FUNC_(array<string> features [, const string options]) - returns a hashed feature vector in array<string>",
        extended = "select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-libsvm');\n"
                + " [\"4063537:1.0\",\"4063537:1\",\"8459207:2.0\"]\n\n"
                + "select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-features 10');\n"
                + " [\"7:1.0\",\"7\",\"1:2.0\"]\n\n"
                + "select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'), '-features 10 -libsvm');\n"
                + " [\"1:2.0\",\"7:1.0\",\"7:1\"]\n")
public final class FeatureHashingUDF extends UDFWithOptions {
    /** Separator between field, feature name and value inside a feature string. */
    private static final char SEPARATOR = ':';

    /** Default size of the hashed feature space (2^24 = 16777216). */
    private static final int DEFAULT_NUM_FEATURES = 1 << 24;

    /** Seed passed to MurmurHash3 (same bit pattern as the int -1756908916). */
    private static final int HASH_SEED = 0x9747b28c;

    /** Shared comparator used to order libsvm-formatted output by feature index. */
    private static final IndexComparator indexCmp = new IndexComparator();

    /** Inspector for the first argument when it is a list; {@code null} for scalar input. */
    @Nullable
    private ListObjectInspector _listOI;

    /** Whether the {@code -libsvm} option was given. */
    private boolean _libsvmFormat = false;

    /** Size of the hashed feature space; hashed indices fall in {@code [1, _numFeatures]}. */
    private int _numFeatures = DEFAULT_NUM_FEATURES;

    /** Output list reused across {@code evaluate} calls for list input. */
    @Nullable
    private transient List<String> _returnObj;

    /**
     * Orders hashed feature strings by their integer index.
     *
     * <p>The index is the whole string when it has no separator, the text before the separator
     * for {@code index:value}, and the text between the first and last separator for
     * {@code field:index:value}.
     */
    public static final class IndexComparator implements Comparator<String>, Serializable {
        private static final long serialVersionUID = -260142385860586255L;

        private IndexComparator() {
        }

        /**
         * Compares two hashed feature strings by index.
         *
         * @throws NumberFormatException if an index part is not a valid integer
         */
        @Override
        public int compare(@Nonnull String str, @Nonnull String str2) {
            return Integer.compare(getIndex(str), getIndex(str2));
        }

        /**
         * Extracts the integer index from a hashed feature string.
         *
         * @throws NumberFormatException if the index part is not a valid integer
         */
        private static int getIndex(@Nonnull String feature) {
            final int headPos = feature.indexOf(SEPARATOR);
            if (headPos == -1) {
                // featureHashing() returns the bias clause unchanged, i.e. without any
                // separator. substring(0, -1) would throw StringIndexOutOfBoundsException,
                // so treat the whole string as the index instead.
                return Integer.parseInt(feature);
            }
            final int tailPos = feature.lastIndexOf(SEPARATOR);
            final String index = (headPos == tailPos)
                    ? feature.substring(0, headPos)
                    : feature.substring(headPos + 1, tailPos);
            return Integer.parseInt(index);
        }
    }

    /**
     * Declares the options accepted in the second argument: {@code -libsvm} (flag) and
     * {@code -features} / {@code -num_features} (takes a value).
     */
    @Override // hivemall.UDFWithOptions
    protected Options getOptions() {
        Options options = new Options();
        options.addOption("libsvm", false, "Returns in libsvm format (<index>:<value>)* sorted by index ascending order");
        // 2^24 is 16777216, matching DEFAULT_NUM_FEATURES.
        options.addOption("features", "num_features", true, "The number of features [default: 16777216 (2^24)]");
        return options;
    }

    /**
     * Parses the option string and stores the resulting settings on this instance.
     *
     * @param str the raw option string
     * @return the parsed command line
     * @throws UDFArgumentException if the number of features is not a positive integer
     *         (or, presumably, if {@code parseOptions} rejects the string)
     */
    @Override // hivemall.UDFWithOptions
    protected CommandLine processOptions(@Nonnull String str) throws UDFArgumentException {
        final CommandLine parsed = parseOptions(str);
        // Presumably falls back to the current value when the option is absent.
        final int numFeatures = Primitives.parseInt(parsed.getOptionValue("num_features"), this._numFeatures);
        if (numFeatures <= 0) {
            // mhash() computes "hash % numFeatures": zero would raise ArithmeticException at
            // evaluation time and a negative value would yield out-of-range indices.
            throw new UDFArgumentException("The number of features must be a positive integer: " + numFeatures);
        }
        this._libsvmFormat = parsed.hasOption("libsvm");
        this._numFeatures = numFeatures;
        return parsed;
    }

    /**
     * Validates the argument count, records whether the first argument is a list and applies
     * the options from the optional second argument.
     *
     * @param objectInspectorArr inspectors for the call arguments (1 or 2 expected)
     * @return a string inspector for scalar input, or a list-of-strings inspector for list input
     * @throws UDFArgumentException if the options are invalid
     */
    public ObjectInspector initialize(@Nonnull ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        final int numArgs = objectInspectorArr.length;
        if (numArgs != 1 && numArgs != 2) {
            // NOTE(review): showHelp presumably throws; otherwise the index below would fail
            // for zero arguments.
            showHelp("The feature_hashing function takes 1 or 2 arguments: " + numArgs);
        }
        final ObjectInspector firstOI = objectInspectorArr[0];
        this._listOI = HiveUtils.isListOI(firstOI) ? (ListObjectInspector) firstOI : null;
        if (numArgs == 2) {
            processOptions(HiveUtils.getConstString(objectInspectorArr[1]));
        }
        if (this._listOI == null) {
            return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
        }
        return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    }

    /**
     * Hashes the first argument.
     *
     * @return {@code null} for null input, a {@code String} for scalar input, or a
     *         {@code List<String>} for list input
     * @throws HiveException if sorting libsvm output fails on a non-integer index
     */
    public Object evaluate(@Nonnull GenericUDF.DeferredObject[] deferredObjectArr) throws HiveException {
        final Object arg = deferredObjectArr[0].get();
        if (arg == null) {
            return null;
        }
        if (this._listOI == null) {
            return evaluateScalar(arg);
        }
        return evaluateList(arg);
    }

    /** Hashes a single feature given as any object; its {@code toString()} is used. */
    @Nonnull
    private String evaluateScalar(@Nonnull Object obj) {
        return featureHashing(obj.toString(), this._numFeatures, this._libsvmFormat);
    }

    /**
     * Hashes every non-null element of the list; null elements are skipped. In libsvm mode the
     * result is sorted by index in ascending order.
     *
     * <p>The returned list is the same instance on every call and is cleared at the start of
     * the next call.
     *
     * @throws HiveException if an index cannot be parsed as an integer while sorting
     */
    @Nonnull
    private List<String> evaluateList(@Nonnull Object obj) throws HiveException {
        final int size = this._listOI.getListLength(obj);
        List<String> hashed = this._returnObj;
        if (hashed == null) {
            hashed = new ArrayList<>(size);
            this._returnObj = hashed;
        } else {
            hashed.clear();
        }
        final int numFeatures = this._numFeatures;
        final boolean libsvm = this._libsvmFormat;
        for (int pos = 0; pos < size; pos++) {
            final Object element = this._listOI.getListElement(obj, pos);
            if (element != null) {
                hashed.add(featureHashing(element.toString(), numFeatures, libsvm));
            }
        }
        if (libsvm) {
            try {
                Collections.sort(hashed, indexCmp);
            } catch (NumberFormatException e) {
                throw new HiveException(e);
            }
        }
        return hashed;
    }

    /** Same as {@link #featureHashing(String, int, boolean)} with libsvm formatting disabled. */
    @VisibleForTesting
    @Nonnull
    static String featureHashing(@Nonnull String str, int i) {
        return featureHashing(str, i, false);
    }

    /**
     * Replaces the feature name inside {@code str} with its hashed index.
     *
     * <ul>
     * <li>{@code name}: returns the index, or {@code index:1} when {@code z} is set. The bias
     * clause ({@code HivemallConstants.BIAS_CLAUSE}) is returned unchanged in both modes.</li>
     * <li>{@code name:value}: returns {@code index:value}. A bias clause whose value parses to
     * exactly 1.0 is returned unchanged.</li>
     * <li>{@code field:name:value} (more than one separator): only the text between the first
     * and the last separator is hashed.</li>
     * </ul>
     *
     * @param str the feature string
     * @param i the number of features (size of the hash space), expected to be positive
     * @param z whether to emit libsvm format for value-less features
     * @return the feature string with its name replaced by a hashed index
     * @throws NumberFormatException if the name is the bias clause and its value is not a number
     */
    @Nonnull
    static String featureHashing(@Nonnull String str, int i, boolean z) {
        final int headPos = str.indexOf(SEPARATOR);
        if (headPos == -1) {
            if (str.equals(HivemallConstants.BIAS_CLAUSE)) {
                return str;
            }
            final int index = mhash(str, i);
            return z ? index + ":1" : String.valueOf(index);
        }

        final int tailPos = str.lastIndexOf(SEPARATOR);
        if (headPos != tailPos) {
            // field:name:value -- keep "field:" and ":value", hash only the name.
            final String fieldWithSeparator = str.substring(0, headPos + 1);
            final String name = str.substring(headPos + 1, tailPos);
            final String valueWithSeparator = str.substring(tailPos);
            return fieldWithSeparator + mhash(name, i) + valueWithSeparator;
        }

        // name:value
        final String name = str.substring(0, headPos);
        if (name.equals(HivemallConstants.BIAS_CLAUSE)
                && Double.parseDouble(str.substring(headPos + 1)) == 1.0d) {
            return str;
        }
        return mhash(name, i) + str.substring(headPos);
    }

    /**
     * Maps {@code str} to an index in {@code [1, i]} using MurmurHash3 with a fixed seed.
     *
     * @param str the text to hash
     * @param i the number of features; must be positive
     * @return the 1-based hashed index
     */
    static int mhash(@Nonnull String str, int i) {
        int bucket = MurmurHash3.murmurhash3_x86_32(str, 0, str.length(), HASH_SEED) % i;
        if (bucket < 0) {
            // Java's % keeps the sign of the dividend; shift into [0, i).
            bucket += i;
        }
        return bucket + 1;
    }

    /** Renders the call as {@code feature_hashing(arg1,arg2,...)}. */
    public String getDisplayString(String[] strArr) {
        return "feature_hashing(" + StringUtils.join((Object[]) strArr, ',') + ')';
    }
}
