package hivemall.factorization.fm;

import hivemall.HivemallConstants;
import hivemall.UDTFWithOptions;
import hivemall.annotations.VisibleForTesting;
import hivemall.common.ConversionState;
import hivemall.factorization.fm.FMStringFeatureMapModel;
import hivemall.optimizer.EtaEstimator;
import hivemall.optimizer.LossFunctions;
import hivemall.sketch.bloom.BloomFilterUtils;
import hivemall.utils.collections.Fastutil;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.NioStatefulSegment;
import hivemall.utils.lang.Primitives;
import hivemall.utils.math.MathUtils;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;

/**
 * Hive UDTF {@code train_fm}: trains a Factorization Machine model from rows of
 * {@code (array<string> x, double y [, const string options])}.
 *
 * <p>Each input row updates the model online. On {@link #close()} the UDTF emits one bias row
 * (feature 0 / the bias clause, {@code W_i} = w0, {@code V_if} = NULL) followed by one row per
 * feature carrying its weight {@code W_i} and latent vector {@code V_if}.
 *
 * <p>When more than one iteration is configured, every example is recorded into a direct
 * {@link ByteBuffer} that spills to a temporary file, so that it can be replayed later by
 * {@link #runTrainingIteration(int)}.
 *
 * <p>NOTE(review): this source was produced by a decompiler and
 * {@link #runTrainingIteration(int)} could not be decompiled. It currently always throws
 * {@link UnsupportedOperationException}, so any configuration with more than one iteration
 * (documented default: 10) fails in {@link #close()}. Restore that method from upstream source.
 */
@Description(name = "train_fm", value = "_FUNC_(array<string> x, double y [, const string options]) - Returns a prediction model")
/* loaded from: input_file:hivemall/factorization/fm/FactorizationMachineUDTF.class */
public class FactorizationMachineUDTF extends UDTFWithOptions {
    private static final Log LOG = LogFactory.getLog(FactorizationMachineUDTF.class);

    /**
     * Capacity in bytes of the in-memory replay buffer. NOTE(review): the decompiled code reused
     * {@code BloomFilterUtils.DEFAULT_BLOOM_FILTER_SIZE}, most likely a decompiler substitution
     * for an equal plain literal. The value is kept identical here; confirm against upstream.
     */
    private static final int INPUT_BUFFER_BYTES = BloomFilterUtils.DEFAULT_BLOOM_FILTER_SIZE;

    protected ListObjectInspector _xOI;
    protected PrimitiveObjectInspector _yOI;

    /** Reusable feature array handed to {@code Feature.parseFeatures}; cleared in close(). */
    @Nullable
    protected Feature[] _probes;
    protected FMHyperParameters _params;
    protected boolean _classification;
    protected int _iterations;
    protected int _factors;
    protected boolean _parseFeatureAsInt;
    protected boolean _earlyStopping;
    /** Accumulates validation loss (used when early stopping is enabled). */
    protected ConversionState _validationState;
    protected boolean _adaptiveRegularization;

    /** Non-null only when early stopping or adaptive regularization needs a validation split. */
    @Nullable
    protected Random _va_rand;
    protected float _validationRatio;
    protected int _validationThreshold;
    protected LossFunctions.LossFunction _lossFunction;
    protected EtaEstimator _etaEstimator;
    /** Accumulates training loss for the convergence check. */
    protected ConversionState _cvState;
    protected transient FactorizationMachineModel _model;
    /** Number of parameter updates performed so far (validation examples are not counted). */
    protected long _t;
    /** Number of recorded examples that were held out for validation. */
    protected long _numValidations;
    private transient ByteBuffer _inputBuf;
    private transient NioStatefulSegment _fileIO;

    /**
     * Declares the options that may be passed as the third (constant string) argument.
     *
     * <p>The decompiler notes that the original declaration was {@code protected}. It stays
     * {@code public} here for compatibility with existing callers.
     */
    @Override // hivemall.UDTFWithOptions
    public Options getOptions() {
        Options options = new Options();
        options.addOption("c", "classification", false, "Act as classification");
        options.addOption("seed", true, "Seed value [default: -1 (random)]");
        options.addOption("iters", "iterations", true, "The number of iterations [default: 10]");
        options.addOption("iter", true, "The number of iterations [default: 10]. Note this is alias of `iters` for backward compatibility");
        options.addOption("p", "num_features", true, "The size of feature dimensions [default: -1]");
        options.addOption("f", "factors", true, "The number of the latent variables [default: 5]");
        options.addOption("k", "factor", true, "The number of the latent variables [default: 5] Alias of `-factors` option");
        options.addOption("sigma", true, "The standard deviation for initializing V [default: 0.1]");
        options.addOption("lambda0", "lambda", true, "The initial lambda value for regularization [default: 1.0E-4]");
        options.addOption("lambdaW0", "lambda_w0", true, "The initial lambda value for W0 regularization [default: 1.0E-4]");
        options.addOption("lambdaWi", "lambda_wi", true, "The initial lambda value for Wi regularization [default: 1.0E-4]");
        options.addOption("lambdaV", "lambda_v", true, "The initial lambda value for V regularization [default: 1.0E-4]");
        options.addOption("min", "min_target", true, "The minimum value of target variable");
        options.addOption("max", "max_target", true, "The maximum value of target variable");
        options.addOption("eta", true, "The initial learning rate [default: 0.3]");
        options.addOption("eta0", true, "The initial learning rate [default: 0.1]");
        options.addOption("t", "total_steps", true, "The total number of training examples");
        options.addOption("power_t", true, "The exponent for inverse scaling learning rate [default: 0.1]");
        options.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: OFF]");
        options.addOption("cv_rate", "convergence_rate", true, "Threshold to determine convergence [default: 0.005]");
        options.addOption("early_stopping", false, "Stop at the iteration that achieves the best validation on partial samples [default: OFF]");
        options.addOption("va_ratio", "validation_ratio", true, "Ratio of training data used for validation [default: 0.05f]");
        options.addOption("va_threshold", "validation_threshold", true, "Threshold to start validation. At least N training examples are used before validation [default: 1000]");
        if (isAdaptiveRegularizationSupported()) {
            options.addOption("adareg", "adaptive_regularization", false, "Whether to enable adaptive regularization [default: OFF]");
        }
        options.addOption("init_v", true, "Initialization strategy of matrix V [adjusted_random, libffm, random, gaussian](FM default: 'adjusted_random' for regression, 'gaussian' for classification, FFM default: random)");
        options.addOption("maxval", "max_init_value", true, "The maximum initial value in the matrix V [default: 0.5]");
        options.addOption("min_init_stddev", true, "The minimum standard deviation of initial matrix V [default: 0.1]");
        options.addOption("int_feature", "feature_as_integer", false, "Parse a feature as integer [default: OFF]");
        options.addOption("enable_norm", "l2norm", false, "Enable instance-wise L2 normalization");
        return options;
    }

    /** Whether the {@code -adareg} option is offered; subclasses may return false to hide it. */
    protected boolean isAdaptiveRegularizationSupported() {
        return true;
    }

    /**
     * Applies the options string (the third argument, when present) to {@link #_params}. It then
     * copies the frequently used settings into fields and builds the loss function and the
     * convergence state. Without an options argument, the defaults of {@link #_params} are used.
     *
     * @param objectInspectorArr the UDTF argument inspectors; index 2 must be a constant string
     * @return the parsed command line, or {@code null} when no options argument was given
     * @throws UDFArgumentException if the options cannot be parsed
     */
    @Override // hivemall.UDTFWithOptions
    public CommandLine processOptions(@Nonnull ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        final FMHyperParameters params = this._params;
        CommandLine cl = null;
        if (objectInspectorArr.length >= 3) {
            cl = parseOptions(HiveUtils.getConstString(objectInspectorArr, 2));
            params.processOptions(cl);
        }

        this._classification = params.classification;
        this._iterations = params.iters;
        this._factors = params.factors;
        this._parseFeatureAsInt = params.parseFeatureAsInt;
        this._earlyStopping = params.earlyStopping;
        this._adaptiveRegularization = params.adaptiveRegularization;

        // A validation split is only needed by early stopping or adaptive regularization.
        if (this._earlyStopping || this._adaptiveRegularization) {
            this._va_rand = new Random(params.seed + 31);
        }
        this._validationState = new ConversionState();
        this._validationRatio = params.validationRatio;
        this._validationThreshold = params.validationThreshold;

        final LossFunctions.LossType lossType = params.classification
                ? LossFunctions.LossType.LogLoss
                : LossFunctions.LossType.SquaredLoss;
        this._lossFunction = LossFunctions.getLossFunction(lossType);
        this._etaEstimator = params.eta;
        this._cvState = new ConversionState(params.conversionCheck, params.convergenceRate);
        return cl;
    }

    /**
     * Validates the arguments ({@code array<string> x, double y [, const string options]}),
     * parses the options and returns the output schema {@code (feature, W_i, V_if)}.
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        if (objectInspectorArr.length != 2 && objectInspectorArr.length != 3) {
            showHelp(String.format("%s takes 2 or 3 arguments: array<string> x, double y [, CONSTANT string options]: %s", getClass().getSimpleName(), Arrays.toString(objectInspectorArr)));
        }
        this._xOI = HiveUtils.asListOI(objectInspectorArr, 0);
        HiveUtils.validateFeatureOI(this._xOI.getListElementObjectInspector());
        this._yOI = HiveUtils.asDoubleCompatibleOI(objectInspectorArr, 1);

        this._params = newHyperParameters();
        processOptions(objectInspectorArr);

        this._model = null;
        this._t = 0L;
        this._numValidations = 0L;

        if (LOG.isInfoEnabled()) {
            LOG.info(this._params);
        }
        return getOutputOI(this._params);
    }

    /** Factory for the hyper-parameter holder; subclasses may return a specialized type. */
    @Nonnull
    protected FMHyperParameters newHyperParameters() {
        return new FMHyperParameters();
    }

    /**
     * Builds the output row schema: {@code feature} (int or string, depending on
     * {@code parseFeatureAsInt}), {@code W_i} (float) and {@code V_if} (array of float).
     */
    @Nonnull
    protected StructObjectInspector getOutputOI(@Nonnull FMHyperParameters fMHyperParameters) {
        final List<String> fieldNames = new ArrayList<>();
        final List<ObjectInspector> fieldOIs = new ArrayList<>();

        fieldNames.add("feature");
        if (fMHyperParameters.parseFeatureAsInt) {
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        } else {
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
        }
        fieldNames.add("W_i");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        fieldNames.add("V_if");
        fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableFloatObjectInspector));

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Creates the model, stores it in {@link #_model} and returns it.
     *
     * <ul>
     * <li>String features use {@code FMStringFeatureMapModel}.</li>
     * <li>Integer features with {@code numFeatures == -1} use {@code FMIntFeatureMapModel}
     * (presumably for an unknown feature dimension).</li>
     * <li>Other integer features use {@code FMArrayModel}.</li>
     * </ul>
     */
    @Nonnull
    protected FactorizationMachineModel initModel(@Nonnull FMHyperParameters fMHyperParameters) throws UDFArgumentException {
        final FactorizationMachineModel model;
        if (!fMHyperParameters.parseFeatureAsInt) {
            model = new FMStringFeatureMapModel(fMHyperParameters);
        } else if (fMHyperParameters.numFeatures == -1) {
            model = new FMIntFeatureMapModel(fMHyperParameters);
        } else {
            model = new FMArrayModel(fMHyperParameters);
        }
        this._model = model;
        return model;
    }

    /**
     * Consumes one input row. Rows whose feature list parses to {@code null} are skipped. In
     * classification mode the target is mapped to +1 (y > 0) or -1. The example is recorded for
     * later iterations and then used either for training or for validation.
     */
    @Override
    public void process(Object[] objArr) throws HiveException {
        if (this._model == null) {
            this._model = initModel(this._params);
        }
        final Feature[] features = parseFeatures(objArr[0]);
        if (features == null) {
            return;
        }
        this._probes = features;
        this._model.check(features);

        double y = PrimitiveObjectInspectorUtils.getDouble(objArr[1], this._yOI);
        if (this._classification) {
            y = y > 0.0d ? 1.0d : -1.0d;
        }

        final boolean validation = isValidationExample();
        recordTrain(features, y, validation);
        train(features, y, validation);
    }

    /**
     * Decides whether the current example is held out for validation. This is randomly true with
     * probability {@link #_validationRatio}, and only once at least {@link #_validationThreshold}
     * updates have been made and a validation RNG exists.
     */
    private boolean isValidationExample() {
        return this._va_rand != null && this._t >= ((long) this._validationThreshold) && this._va_rand.nextFloat() < this._validationRatio;
    }

    /**
     * Parses the feature list, applying instance-wise L2 normalization when {@code -l2norm} is
     * set.
     *
     * @return the parsed features, or {@code null} when the parser yields none (callers skip the
     *         row). Normalization is only attempted on a non-null result; previously a null
     *         result combined with {@code -l2norm} could fail with a NullPointerException.
     */
    @Nullable
    protected Feature[] parseFeatures(@Nonnull Object obj) throws HiveException {
        final Feature[] features = Feature.parseFeatures(obj, this._xOI, this._probes, this._parseFeatureAsInt);
        if (features != null && this._params.l2norm) {
            Feature.l2normalize(features);
        }
        return features;
    }

    /**
     * Appends one example to the replay buffer so that later iterations can re-read it. Does
     * nothing when at most one iteration is configured.
     *
     * <p>Record layout: {@code [int recordBytes][int featureCount][features...][double y]
     * [byte isValidation]}. The buffer is flushed to a temporary file whenever it cannot hold the
     * next record.
     *
     * @throws UDFArgumentException if the temporary file or buffer cannot be set up
     * @throws HiveException if a single record is larger than the whole buffer, or flushing
     *         fails
     */
    private void recordTrain(@Nonnull Feature[] featureArr, double d, boolean z) throws HiveException {
        if (this._iterations <= 1) {
            return;
        }
        ByteBuffer buf = this._inputBuf;
        NioStatefulSegment segment = this._fileIO;
        if (buf == null) {
            try {
                final File file = File.createTempFile("hivemall_fm", ".sgmt");
                file.deleteOnExit();
                if (!file.canWrite()) {
                    throw new UDFArgumentException("Cannot write a temporary file: " + file.getAbsolutePath());
                }
                LOG.info("Record training examples to a file: " + file.getAbsolutePath());
                buf = ByteBuffer.allocateDirect(INPUT_BUFFER_BYTES);
                segment = new NioStatefulSegment(file, false);
            } catch (IOException | RuntimeException e) {
                // The UDFArgumentException thrown above propagates as-is (no double wrapping),
                // and JVM Errors are no longer swallowed into an argument exception.
                throw new UDFArgumentException(e);
            }
            // Publish both resources only once both were created successfully.
            this._inputBuf = buf;
            this._fileIO = segment;
        }

        // Payload: int featureCount + features + double y + byte isValidation.
        final int recordBytes = 4 + Feature.requiredBytes(featureArr) + 8 + 1;
        // Plus the leading int that stores recordBytes itself.
        final int requiredBytes = 4 + recordBytes;
        if (requiredBytes > buf.capacity()) {
            throw new HiveException("A training example requires " + requiredBytes
                    + " bytes, which exceeds the record buffer capacity of " + buf.capacity()
                    + " bytes");
        }
        if (buf.remaining() < requiredBytes) {
            writeBuffer(buf, segment);
        }

        buf.putInt(recordBytes);
        buf.putInt(featureArr.length);
        for (Feature f : featureArr) {
            f.writeTo(buf);
        }
        buf.putDouble(d);
        if (z) {
            this._numValidations++;
            buf.put(Primitives.TRUE_BYTE.byteValue());
        } else {
            buf.put(Primitives.FALSE_BYTE.byteValue());
        }
    }

    /** Flushes the buffered bytes to the segment and clears the buffer for reuse. */
    private static void writeBuffer(@Nonnull ByteBuffer byteBuffer, @Nonnull NioStatefulSegment nioStatefulSegment) throws HiveException {
        byteBuffer.flip();
        try {
            nioStatefulSegment.write(byteBuffer);
            byteBuffer.clear();
        } catch (IOException e) {
            throw new HiveException("Exception causes while writing a buffer to file", e);
        }
    }

    /**
     * Routes an example either to validation or to a parameter update. Only parameter updates
     * advance {@link #_t}. Any failure is rethrown with the current update count for context.
     */
    private void train(@Nonnull Feature[] featureArr, double d, boolean z) throws HiveException {
        try {
            if (z) {
                processValidationSample(featureArr, d);
            } else {
                this._t++;
                trainTheta(featureArr, d);
            }
        } catch (Exception e) {
            throw new HiveException("Exception caused in the " + this._t + "-th call of train()", e);
        }
    }

    /**
     * Handles a held-out example: accumulates its loss when early stopping is enabled and updates
     * the regularization parameters when adaptive regularization is enabled.
     */
    protected void processValidationSample(@Nonnull Feature[] featureArr, double d) throws HiveException {
        if (this._earlyStopping) {
            final double predicted = this._model.predict(featureArr);
            this._validationState.incrLoss(this._lossFunction.loss(predicted, d));
        }
        if (this._adaptiveRegularization) {
            trainLambda(featureArr, d);
        }
    }

    /**
     * Performs one gradient step on w0, w_i and V for a training example. The step is skipped
     * when the loss derivative is within 1e-9 of zero. The loss is always accumulated for the
     * convergence check.
     */
    protected void trainTheta(Feature[] featureArr, double d) throws HiveException {
        final float eta = this._etaEstimator.eta(this._t);
        final double predicted = this._model.predict(featureArr);
        final double dloss = this._model.dloss(predicted, d);
        this._cvState.incrLoss(this._lossFunction.loss(predicted, d));
        if (MathUtils.closeToZero(dloss, 1.0E-9d)) {
            return;
        }

        this._model.updateW0(dloss, eta);
        // sum_j V_jf * x_j per factor f, computed once before the per-feature updates.
        final double[] sumVfX = this._model.sumVfX(featureArr);
        final int factors = this._factors;
        for (Feature feature : featureArr) {
            this._model.updateWi(dloss, feature, eta);
            for (int f = 0; f < factors; f++) {
                this._model.updateV(dloss, feature, f, sumVfX[f], eta);
            }
        }
    }

    /** Updates the regularization (lambda) parameters using a held-out example. */
    private void trainLambda(Feature[] featureArr, double d) throws HiveException {
        final float eta = this._etaEstimator.eta(this._t);
        final double dloss = this._model.dloss(this._model.predict(featureArr), d);
        this._model.updateLambdaW0(dloss, eta);
        this._model.updateLambdaW(featureArr, dloss, eta);
        this._model.updateLambdaV(featureArr, dloss, eta);
    }

    /**
     * Finishes training and emits the model.
     *
     * <p>If no parameter update happened, nothing is emitted. In that case any replay buffer is
     * still released: this can happen when every recorded example was held out for validation.
     * Otherwise, the extra iterations are run (when configured) and the model is forwarded if it
     * is non-empty.
     */
    @Override
    public void close() throws HiveException {
        this._probes = null;
        if (this._t == 0) {
            this._model = null;
            releaseRecordedExamples();
            return;
        }
        if (this._iterations > 1) {
            runTrainingIteration(this._iterations);
        }
        final int size = this._model.getSize();
        if (size <= 0) {
            LOG.warn("Model size P was not positive: " + size);
        } else {
            forwardModel();
        }
        this._model = null;
    }

    /**
     * Drops the replay buffer and closes and deletes the temporary segment file, if one was
     * created.
     */
    private void releaseRecordedExamples() throws HiveException {
        final NioStatefulSegment segment = this._fileIO;
        this._inputBuf = null;
        this._fileIO = null;
        if (segment != null) {
            try {
                segment.close(true);
            } catch (IOException e) {
                throw new HiveException("Failed to close a temporary file of training examples", e);
            }
        }
    }

    /** Test hook: runs the extra training iterations without forwarding the model. */
    @VisibleForTesting
    void finalizeTraining() throws HiveException {
        if (this._iterations > 1) {
            runTrainingIteration(this._iterations);
        }
    }

    /** Emits the model rows using the feature representation selected by the options. */
    protected void forwardModel() throws HiveException {
        if (this._parseFeatureAsInt) {
            forwardAsIntFeature(this._model, this._factors);
        } else {
            forwardAsStringFeature((FMStringFeatureMapModel) this._model, this._factors);
        }
    }

    /**
     * Emits the bias row (feature 0, w0, NULL V). It then emits one row for each index in
     * [minIndex, maxIndex] that has a latent vector. The writables are reused across rows.
     */
    private void forwardAsIntFeature(@Nonnull FactorizationMachineModel factorizationMachineModel, int i) throws HiveException {
        final IntWritable feature = new IntWritable(0);
        final FloatWritable weight = new FloatWritable(0.0f);
        final FloatWritable[] vector = HiveUtils.newFloatArray(i, 0.0f);
        final Object[] row = {feature, weight, null};

        feature.set(0);
        weight.set(factorizationMachineModel.getW0());
        forward(row);

        row[2] = Arrays.asList(vector);
        final int maxIndex = factorizationMachineModel.getMaxIndex();
        for (int idx = factorizationMachineModel.getMinIndex(); idx <= maxIndex; idx++) {
            final float[] v = factorizationMachineModel.getV(idx, false);
            if (v == null) {
                continue;
            }
            feature.set(idx);
            weight.set(factorizationMachineModel.getW(idx));
            for (int f = 0; f < i; f++) {
                vector[f].set(v[f]);
            }
            forward(row);
        }
    }

    /**
     * Emits the bias row ({@code HivemallConstants.BIAS_CLAUSE}, w0, NULL V). It then emits one
     * row per entry of the string-keyed model map. The writables are reused across rows.
     */
    private void forwardAsStringFeature(@Nonnull FMStringFeatureMapModel fMStringFeatureMapModel, int i) throws HiveException {
        final Text feature = new Text();
        final FloatWritable weight = new FloatWritable(0.0f);
        final FloatWritable[] vector = HiveUtils.newFloatArray(i, 0.0f);
        final Object[] row = {feature, weight, null};

        feature.set(HivemallConstants.BIAS_CLAUSE);
        weight.set(fMStringFeatureMapModel.getW0());
        forward(row);

        row[2] = Arrays.asList(vector);
        final ObjectIterator it = Fastutil.fastIterable(fMStringFeatureMapModel.getMap()).iterator();
        while (it.hasNext()) {
            final Map.Entry entry = (Map.Entry) it.next();
            final String key = (String) entry.getKey();
            assert key != null;
            feature.set(key);
            final FMStringFeatureMapModel.Entry value = (FMStringFeatureMapModel.Entry) entry.getValue();
            weight.set(value.W);
            final float[] vf = value.Vf;
            for (int f = 0; f < i; f++) {
                vector[f].set(vf[f]);
            }
            forward(row);
        }
    }

    /**
     * Replays the recorded examples for up to {@code r7} iterations.
     *
     * <p>NOTE(review): the decompiler could not recover this body, so as written it ALWAYS throws
     * {@link UnsupportedOperationException}. The fragments below reference
     * {@code _validationState.isLossIncreased()} and {@code _cvState.isConverged(...)}, which
     * suggests early-stopping and convergence checks per iteration. Restore from upstream source.
     */
    /* JADX WARN: Code restructure failed: missing block: B:119:0x0348, code lost:
    
        r0 = r6._validationState.isLossIncreased();
     */
    /* JADX WARN: Code restructure failed: missing block: B:120:0x0353, code lost:
    
        if (r12 == false) goto L99;
     */
    /* JADX WARN: Code restructure failed: missing block: B:122:0x0358, code lost:
    
        if (r0 != false) goto L138;
     */
    /* JADX WARN: Code restructure failed: missing block: B:126:0x0364, code lost:
    
        if (r6._cvState.isConverged(r0) == false) goto L102;
     */
    /* JADX WARN: Code restructure failed: missing block: B:127:0x036a, code lost:
    
        r12 = r0;
        r15 = r15 + 1;
     */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    protected void runTrainingIteration(int r7) throws org.apache.hadoop.hive.ql.metadata.HiveException {
        /*
            Method dump skipped, instructions count: 1088
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: hivemall.factorization.fm.FactorizationMachineUDTF.runTrainingIteration(int):void");
    }

    /** Reads one recorded feature back from the buffer, as an int or string feature. */
    @Nonnull
    protected Feature instantiateFeature(@Nonnull ByteBuffer byteBuffer) {
        return this._parseFeatureAsInt ? new IntFeature(byteBuffer) : new StringFeature(byteBuffer);
    }
}
