package hivemall.ftvec.text;

import hivemall.UDFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.StringUtils;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

/**
 * Hive UDF that computes an Okapi BM25 (optionally BM25+) relevance score for a single
 * term/document pair.
 *
 * <p>The score computed by this class is
 *
 * <pre>
 *   max(minIDF, idf) * ( tf * (k1 + 1) / (tf + k1 * (1 - b + b * docLength / avgDocLength)) + delta )
 *   idf = log10(1 + (numDocs - numDocsWithTerm + 0.5) / (numDocsWithTerm + 0.5))
 * </pre>
 *
 * <p>The hyperparameters {@code k1}, {@code b}, {@code delta} and {@code min_idf} can be
 * overridden through the optional sixth (constant string) argument.
 */
@UDFType(deterministic = true, stateful = false)
@Description(name = "bm25", value = "_FUNC_(double termFrequency, int docLength, double avgDocLength, int numDocs, int numDocsWithTerm [, const string options]) - Return an Okapi BM25 score in double. Refer http://hivemall.incubator.apache.org/userguide/ft_engineering/bm25.html for usage")
public final class OkapiBM25UDF extends UDFWithOptions {

    /** Minimum number of arguments (the five required statistics). */
    private static final int MIN_ARGS = 5;
    /** Maximum number of arguments (required statistics plus the options string). */
    private static final int MAX_ARGS = 6;

    // Object inspectors for the five required arguments, resolved in initialize().
    private PrimitiveObjectInspector frequencyOI;
    private PrimitiveObjectInspector docLengthOI;
    private PrimitiveObjectInspector averageDocLengthOI;
    private PrimitiveObjectInspector numDocsOI;
    private PrimitiveObjectInspector numDocsWithTermOI;

    // Hyperparameters; the initial values are the documented defaults.
    private double k1 = 1.2d;
    private double b = 0.75d;
    private double delta = 0.0d;
    private double minIDF = 1.0E-8d;

    /** Reused output holder; the same instance is returned from every evaluate() call. */
    @Nonnull
    private final DoubleWritable result = new DoubleWritable();

    /**
     * Declares the command-line style options accepted in the optional options argument.
     *
     * @return the option definitions for {@code k1}, {@code b}, {@code delta} and {@code min_idf}
     */
    @Override // hivemall.UDFWithOptions
    protected Options getOptions() {
        Options options = new Options();
        options.addOption("k1", true, "Hyperparameter with type double, usually in range 1.2 and 2.0 [default: 1.2]");
        options.addOption("b", true, "Hyperparameter with type double in range 0.0 and 1.0 [default: 0.75]");
        options.addOption("d", "delta", true, "Hyperparameter delta of BM25+ [default: 0.0]");
        options.addOption("min_idf", "epsilon", true, "Lower bound (epsilon) applied to the IDF term [default: 1e-8]");
        return options;
    }

    /**
     * Parses the options string and validates every hyperparameter.
     *
     * <p>An option that is not supplied keeps the current field value (presumably
     * {@code Primitives.parseDouble} returns its second argument for a missing value).
     *
     * @param optionValue the raw options string passed as the sixth UDF argument
     * @return the parsed command line
     * @throws UDFArgumentException if any hyperparameter is outside its valid range
     */
    @Override // hivemall.UDFWithOptions
    protected CommandLine processOptions(@Nonnull String optionValue) throws UDFArgumentException {
        final CommandLine cl = parseOptions(optionValue);

        this.k1 = Primitives.parseDouble(cl.getOptionValue("k1"), this.k1);
        if (!Primitives.isFinite(this.k1) || this.k1 < 0.0d) {
            throw new UDFArgumentException("k1 must be a non-negative finite value: " + this.k1);
        }

        this.b = Primitives.parseDouble(cl.getOptionValue("b"), this.b);
        if (Double.isNaN(this.b) || this.b < 0.0d || this.b > 1.0d) {
            throw new UDFArgumentException("b hyperparameter must be in the range [0.0, 1.0]: " + this.b);
        }

        this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), this.delta);
        if (!Primitives.isFinite(this.delta)) {
            throw new UDFArgumentException("Delta must be a finite value: " + this.delta);
        }

        this.minIDF = Primitives.parseDouble(cl.getOptionValue("min_idf"), this.minIDF);
        // A plain "< 0" test would let NaN through, and Math.max(NaN, idf) is NaN,
        // which would poison every score; reject NaN and infinities explicitly.
        if (!Primitives.isFinite(this.minIDF) || this.minIDF < 0.0d) {
            throw new UDFArgumentException("min_idf must be a non-negative finite value: " + this.minIDF);
        }
        return cl;
    }

    /**
     * Validates the argument count, processes the optional options string and resolves
     * the object inspectors of the five required arguments.
     *
     * @param argOIs object inspectors of the call arguments (5 or 6 entries expected)
     * @return the writable double object inspector describing the result
     * @throws UDFArgumentException if the argument count, types or options are invalid
     */
    @Override
    public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs) throws UDFArgumentException {
        final int numArgOIs = argOIs.length;
        if (numArgOIs < MIN_ARGS) {
            showHelp("#arguments must be greater than or equal to 5: " + numArgOIs);
        } else if (numArgOIs > MAX_ARGS) {
            // Extra arguments used to be silently ignored; reject them instead.
            showHelp("#arguments must be less than or equal to 6: " + numArgOIs);
        } else if (numArgOIs == MAX_ARGS) {
            processOptions(HiveUtils.getConstString(argOIs[5]));
        }

        this.frequencyOI = HiveUtils.asDoubleCompatibleOI(argOIs[0]);
        this.docLengthOI = HiveUtils.asIntegerOI(argOIs[1]);
        this.averageDocLengthOI = HiveUtils.asDoubleCompatibleOI(argOIs[2]);
        this.numDocsOI = HiveUtils.asIntegerOI(argOIs[3]);
        this.numDocsWithTermOI = HiveUtils.asIntegerOI(argOIs[4]);

        return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
    }

    /**
     * Computes the BM25 score for one row.
     *
     * @param arguments deferred values of termFrequency, docLength, avgDocLength, numDocs
     *        and numDocsWithTerm, in that order
     * @return the shared {@link DoubleWritable} holding the score
     * @throws HiveException if a required argument is null or out of range
     */
    @Override
    public DoubleWritable evaluate(@Nonnull GenericUDF.DeferredObject[] arguments) throws HiveException {
        final Object frequencyArg = arguments[0].get();
        final Object docLengthArg = arguments[1].get();
        final Object averageDocLengthArg = arguments[2].get();
        final Object numDocsArg = arguments[3].get();
        final Object numDocsWithTermArg = arguments[4].get();
        if (frequencyArg == null || docLengthArg == null || averageDocLengthArg == null
                || numDocsArg == null || numDocsWithTermArg == null) {
            throw new UDFArgumentException("Required arguments cannot be null");
        }

        final double frequency = PrimitiveObjectInspectorUtils.getDouble(frequencyArg, this.frequencyOI);
        final int docLength = PrimitiveObjectInspectorUtils.getInt(docLengthArg, this.docLengthOI);
        final double averageDocLength =
                PrimitiveObjectInspectorUtils.getDouble(averageDocLengthArg, this.averageDocLengthOI);
        final int numDocs = PrimitiveObjectInspectorUtils.getInt(numDocsArg, this.numDocsOI);
        final int numDocsWithTerm =
                PrimitiveObjectInspectorUtils.getInt(numDocsWithTermArg, this.numDocsWithTermOI);

        // A frequency of zero is accepted, so the requirement is "non-negative".
        assumeFalse(frequency < 0.0d, "#frequency must be non-negative");
        assumeFalse(docLength < 1, "#docLength must be greater than or equal to 1");
        assumeFalse(averageDocLength <= 0.0d, "#averageDocLength must be positive");
        assumeFalse(numDocs < 1, "#numDocs must be greater than or equal to 1");
        assumeFalse(numDocsWithTerm < 1, "#numDocsWithTerm must be greater than or equal to 1");

        this.result.set(bm25(frequency, docLength, averageDocLength, numDocs, numDocsWithTerm));
        return this.result;
    }

    /**
     * Evaluates the BM25/BM25+ formula with the configured hyperparameters.
     *
     * @return {@code max(minIDF, idf) * (normalizedTermFrequency + delta)}
     */
    private double bm25(double frequency, int docLength, double averageDocLength, int numDocs,
            int numDocsWithTerm) {
        // Document-length normalization: 1 - b + b * |D| / avgdl
        final double lengthNorm = (1.0d - this.b) + ((this.b * docLength) / averageDocLength);
        final double numerator = frequency * (this.k1 + 1.0d);
        final double denominator = frequency + (this.k1 * lengthNorm);
        // The IDF can drop to or below zero for very common terms; clamp it from below.
        final double boundedIdf = Math.max(this.minIDF, idf(numDocs, numDocsWithTerm));
        return boundedIdf * ((numerator / denominator) + this.delta);
    }

    /**
     * Inverse document frequency:
     * {@code log10(1 + (numDocs - numDocsWithTerm + 0.5) / (numDocsWithTerm + 0.5))}.
     */
    private static double idf(int numDocs, int numDocsWithTerm) {
        return Math.log10(1.0d + (((numDocs - numDocsWithTerm) + 0.5d) / (numDocsWithTerm + 0.5d)));
    }

    /**
     * Renders the call for query plans as {@code bm25(child1,child2,...)}.
     *
     * @param children display strings of the argument expressions
     * @return the display string of this function call
     */
    @Override
    public String getDisplayString(String[] children) {
        return "bm25(" + StringUtils.join((Object[]) children, ',') + ")";
    }
}
