package hivemall.ftvec.ranking;

import hivemall.UDTFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.BitUtils;
import hivemall.utils.lang.Primitives;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Random;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.IntWritable;

@Description(name = "item_pairs_sampling", value = "_FUNC_(array<int|long> pos_items, const int max_item_id [, const string options])- Returns a relation consists of <int pos_item_id, int neg_item_id>")
/* loaded from: input_file:hivemall/ftvec/ranking/ItemPairsSamplingUDTF.class */
/**
 * Table-generating function that samples (positive item, negative item) pairs.
 *
 * <p>For each input row the set of positive item ids is loaded into a {@link BitSet}. A positive
 * item is an id whose bit is set; a negative item is an id in {@code [0, maxItemId]} whose bit is
 * clear. Each sampled pair is forwarded as {@code <pos_item, neg_item>}.
 *
 * <p>Supported options (third, constant string argument):
 * <ul>
 * <li>{@code -bitset_input}: the first argument is the {@code long[]} word representation of a
 * bitset instead of a list of item ids</li>
 * <li>{@code -sampling_rate}: number of emitted pairs relative to the number of positive items
 * (default 1.0)</li>
 * <li>{@code -with_replacement}: allow the same positive/negative item to be drawn repeatedly</li>
 * </ul>
 */
public final class ItemPairsSamplingUDTF extends UDTFWithOptions {
    /** Fixed seed of the sampler; the same seed is used every time the generator is created. */
    private static final long RANDOM_SEED = 43L;

    private ListObjectInspector listOI;
    private PrimitiveObjectInspector listElemOI;
    private int maxItemId;
    private boolean bitsetInput;
    private float samplingRate;
    private boolean withReplacement;
    /** Reused output row; holds {@link #posItemId} and {@link #negItemId}. */
    private Object[] forwardObjs;
    private IntWritable posItemId;
    private IntWritable negItemId;
    /** Reused work bitset for the non-bitset input mode. */
    private BitSet _bitset;
    private Random _rand;

    /**
     * Declares the command line style options accepted in the optional third argument.
     *
     * @return the option definitions
     */
    @Override // hivemall.UDTFWithOptions
    protected Options getOptions() {
        Options options = new Options();
        options.addOption("bitset", "bitset_input", false, "Use Bitset for the input of pos_items [default:false]");
        options.addOption("sampling", "sampling_rate", true, "Sampling rates of positive items [default: 1.0]");
        options.addOption("with_replacement", false, "Do sampling with-replacement [default: false]");
        return options;
    }

    /**
     * Parses the option string (when a third argument is given) and stores the resulting
     * settings in this instance. Without a third argument the defaults are used.
     *
     * @param objectInspectorArr inspectors of the function arguments
     * @return the parsed command line, or {@code null} when no option argument was given
     * @throws UDFArgumentException if the sampling rate exceeds 1 while sampling without
     *         replacement
     */
    @Override // hivemall.UDTFWithOptions
    protected CommandLine processOptions(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        CommandLine commandLine = null;
        boolean useBitsetInput = false;
        boolean replacement = false;
        float rate = 1.0f;

        if (objectInspectorArr.length == 3) {
            commandLine = parseOptions(HiveUtils.getConstString(objectInspectorArr[2]));
            useBitsetInput = commandLine.hasOption("bitset_input");
            replacement = commandLine.hasOption("with_replacement");
            rate = Primitives.parseFloat(commandLine.getOptionValue("sampling_rate"), 1.0f);
            // Without replacement at most one pair per positive item can be produced.
            if (!replacement && rate > 1.0f) {
                throw new UDFArgumentException("sampling_rate MUST be in less than or equals to 1 where withReplacement is false: " + rate);
            }
        }

        this.bitsetInput = useBitsetInput;
        this.samplingRate = rate;
        this.withReplacement = replacement;
        return commandLine;
    }

    /**
     * Validates the arguments, reads the options and the constant max item id, and prepares the
     * reusable output row.
     *
     * @param objectInspectorArr inspectors of the two or three function arguments
     * @return struct inspector describing the output columns {@code pos_item} and
     *         {@code neg_item}, both writable ints
     * @throws UDFArgumentException on a wrong argument count or a non-positive max item id
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        if (objectInspectorArr.length != 2 && objectInspectorArr.length != 3) {
            throw new UDFArgumentException("_FUNC_(array<long>, const long max_item_id [, const string options]) takes at least two arguments");
        }
        this.listOI = HiveUtils.asListOI(objectInspectorArr[0]);
        this.listElemOI = HiveUtils.asPrimitiveObjectInspector(this.listOI.getListElementObjectInspector());
        processOptions(objectInspectorArr);
        this.maxItemId = HiveUtils.getAsConstInt(objectInspectorArr[1]);
        if (this.maxItemId <= 0) {
            throw new UDFArgumentException("maxItemId MUST be greater than 0: " + this.maxItemId);
        }

        this.posItemId = new IntWritable();
        this.negItemId = new IntWritable();
        this.forwardObjs = new Object[]{this.posItemId, this.negItemId};

        ArrayList<String> fieldNames = new ArrayList<String>(2);
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(2);
        fieldNames.add("pos_item");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("neg_item");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Samples item pairs for one input row and forwards them.
     *
     * <p>Nothing is emitted when the row has no positive item or no negative item.
     *
     * @param objArr the row; {@code objArr[0]} carries the positive items
     * @throws HiveException if the number of positive items exceeds {@code maxItemId + 1}, or if
     *         sampling or forwarding fails
     */
    @Override
    public void process(Object[] objArr) throws HiveException {
        final BitSet posItems;
        final int numPosItems;
        if (this.bitsetInput) {
            if (this._rand == null) {
                this._rand = new Random(RANDOM_SEED);
            }
            final long[] words = HiveUtils.asLongArray(objArr[0], this.listOI, this.listElemOI);
            if (words == null) {
                // BitSet.valueOf(null) would throw a NullPointerException; a row without any
                // bitset words has no positive items, so there is nothing to emit.
                return;
            }
            // A fresh bitset per row: sampling without replacement mutates it.
            posItems = BitSet.valueOf(words);
            numPosItems = posItems.cardinality();
        } else {
            if (this._bitset == null) {
                this._bitset = new BitSet();
                this._rand = new Random(RANDOM_SEED);
            } else {
                this._bitset.clear();
            }
            posItems = this._bitset;
            numPosItems = HiveUtils.setBits(objArr[0], this.listOI, this.listElemOI, posItems);
        }

        if (numPosItems == 0) {
            return;
        }
        // Item ids are taken from [0, maxItemId], i.e. maxItemId + 1 candidates in total.
        final int numNegItems = (this.maxItemId + 1) - numPosItems;
        if (numNegItems == 0) {
            return;
        }
        if (numNegItems < 0) {
            throw new UDFArgumentException("maxItemId + 1 - numPosItems = " + this.maxItemId + " + 1 - " + numPosItems + " = " + numNegItems);
        }

        if (this.withReplacement) {
            sampleWithReplacement(numPosItems, numNegItems, posItems);
        } else {
            sampleWithoutReplacement(numPosItems, numNegItems, posItems);
        }
    }

    /**
     * Draws pairs independently; the same positive or negative item may appear more than once.
     *
     * @param i number of positive items (set bits)
     * @param i2 number of negative items
     * @param bitSet positive items; not modified
     * @throws HiveException if an item cannot be located or forwarding fails
     */
    private void sampleWithReplacement(int i, int i2, @Nonnull BitSet bitSet) throws HiveException {
        final int numSamples = numSamples(i);
        for (int n = 0; n < numSamples; n++) {
            int posItem = pickPositive(bitSet, i);
            int negItem = pickNegative(bitSet, i2);
            emit(posItem, negItem);
        }
    }

    /**
     * Draws pairs such that no positive and no negative item is used twice. Stops early once
     * either side is exhausted.
     *
     * @param i number of positive items (set bits)
     * @param i2 number of negative items
     * @param bitSet positive items; drawn positives are cleared from it
     * @throws HiveException if an item cannot be located or forwarding fails
     */
    private void sampleWithoutReplacement(int i, int i2, @Nonnull BitSet bitSet) throws HiveException {
        // Copy in which already drawn negatives are marked as set, so that they are skipped
        // by subsequent clear-bit lookups.
        final BitSet unavailableNegatives = BitSet.valueOf(bitSet.toLongArray());
        final int numSamples = numSamples(i);
        int remainingPos = i;
        int remainingNeg = i2;
        for (int n = 0; n < numSamples; n++) {
            int posItem = pickPositive(bitSet, remainingPos);
            bitSet.set(posItem, false);
            remainingPos--;

            int negItem = pickNegative(unavailableNegatives, remainingNeg);
            unavailableNegatives.set(negItem, true);
            remainingNeg--;

            emit(posItem, negItem);
            if (remainingPos <= 0 || remainingNeg <= 0) {
                return;
            }
        }
    }

    /** Number of pairs to draw for the given positive item count; always at least one. */
    private int numSamples(int numPosItems) {
        return Math.max(1, Math.round(numPosItems * this.samplingRate));
    }

    /** Picks a random set bit (positive item) out of {@code numPosItems} set bits. */
    private int pickPositive(@Nonnull BitSet posItems, int numPosItems) throws HiveException {
        int nth = this._rand.nextInt(numPosItems);
        int itemId = BitUtils.indexOfSetBit(posItems, nth);
        if (itemId == -1) {
            throw new UDFArgumentException("Cannot find a value for " + nth + "-th element in bitset " + posItems.toString() + " where numPosItems = " + numPosItems);
        }
        return itemId;
    }

    /** Picks a random clear bit (negative item) and verifies it lies in {@code [0, maxItemId]}. */
    private int pickNegative(@Nonnull BitSet excluded, int numNegItems) throws HiveException {
        int nth = this._rand.nextInt(numNegItems);
        int itemId = BitUtils.indexOfClearBit(excluded, nth, this.maxItemId);
        if (itemId < 0 || itemId > this.maxItemId) {
            throw new UDFArgumentException("j MUST be in [0," + this.maxItemId + "] but j was " + itemId);
        }
        return itemId;
    }

    /** Writes the pair into the reused output row and forwards it. */
    private void emit(int posItem, int negItem) throws HiveException {
        this.posItemId.set(posItem);
        this.negItemId.set(negItem);
        forward(this.forwardObjs);
    }

    /** Drops the references to inspectors, output buffers and sampling state. */
    @Override
    public void close() throws HiveException {
        this.listOI = null;
        this.listElemOI = null;
        this.forwardObjs = null;
        this.posItemId = null;
        this.negItemId = null;
        this._bitset = null;
        this._rand = null;
    }
}
