package hivemall;

import hivemall.annotations.VisibleForTesting;
import hivemall.common.ConversionState;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionModel;
import hivemall.model.WeightValue;
import hivemall.optimizer.LossFunctions;
import hivemall.optimizer.Optimizer;
import hivemall.optimizer.OptimizerOptions;
import hivemall.sketch.bloom.BloomFilterUtils;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.NIOUtils;
import hivemall.utils.io.NioStatefulSegment;
import hivemall.utils.lang.FloatAccumulator;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.SizeOf;
import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaIntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.io.FloatWritable;

/* loaded from: input_file:hivemall/GeneralLearnerBaseUDTF.class */
public abstract class GeneralLearnerBaseUDTF extends LearnerBaseUDTF {
    // Class logger; assigned in the static initializer at the end of the class.
    private static final Log logger;
    // Bounds used by update() to clip the loss gradient (dloss) before it reaches the optimizer.
    private static final float MAX_DLOSS = 1.0E12f;
    private static final float MIN_DLOSS = -1.0E12f;
    // Inspector of the feature list argument (argument 0); set in initialize().
    private ListObjectInspector featureListOI;
    // Inspector of the target argument (argument 1); set in initialize().
    private PrimitiveObjectInspector targetOI;
    // Java type of the feature keys, detected from featureListOI in initialize().
    private FeatureType featureType;

    // Optimizer settings: created in the constructor, filled by processOptions(),
    // and handed to createOptimizer() in initialize().
    @Nonnull
    private final Map<String, String> optimizerOptions;
    // Weight-update rule; created in initialize().
    private Optimizer optimizer;
    // Loss used by update(); chosen in processOptions() (-loss option or getDefaultLossType()).
    private LossFunctions.LossFunction lossFunction;
    // Model being trained; created in initialize(), set to null by finalizeTraining() when no
    // example was trained on, and by close().
    private PredictionModel model;
    // Number of input rows that reached train() in process().
    private long count;

    // Per-feature accumulators for mini-batch training: lazily created in process(), drained by
    // batchUpdate(), released in close().
    @Nullable
    private transient Map<Object, FloatAccumulator> accumulated;
    // Number of examples accumulated since the last batchUpdate().
    private int sampled;

    // Spill file for recorded training examples when iterations > 1; lazily created by
    // recordTrainSampleToTempFile().
    @Nullable
    protected transient NioStatefulSegment fileIO;

    // Direct buffer staging recorded training examples before they are written to fileIO.
    @Nullable
    protected transient ByteBuffer inputBuf;
    // Value of the -iterations option (default 10, validated to be >= 1 in processOptions()).
    private int iterations;
    // Convergence-check state built in processOptions(); update() adds each example's loss to it.
    protected ConversionState cvState;
    // Synthetic assertion flag preserved by the decompiler; assigned in the static initializer.
    // Not referenced by the visible code — presumably used by the undecompiled body of
    // runIterativeTraining(). TODO confirm.
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Decompiler-synthesized lookup table that maps {@link FeatureType#ordinal()} to the 1-based
     * case labels used by the ordinal switches elsewhere in this class
     * (STRING -> 1, INT -> 2, LONG -> 3).
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$hivemall$GeneralLearnerBaseUDTF$FeatureType;

        static {
            // Position in this array + 1 is the case label assigned to each constant.
            final FeatureType[] caseOrder = {FeatureType.STRING, FeatureType.INT, FeatureType.LONG};
            final int[] table = new int[FeatureType.values().length];
            for (int label = 1; label <= caseOrder.length; label++) {
                table[caseOrder[label - 1].ordinal()] = label;
            }
            $SwitchMap$hivemall$GeneralLearnerBaseUDTF$FeatureType = table;
        }
    }

    /**
     * Java type of the feature keys found in the input feature list.
     *
     * <p>The declaration order is significant: ordinals are used by the switch tables in this
     * class.
     */
    public enum FeatureType {
        /** Features given as strings. */
        STRING,
        /** Features given as ints. */
        INT,
        /** Features given as longs (Hive bigint). */
        LONG;
    }

    /** Creates a learner, passing {@code true} to the single-argument constructor. */
    public GeneralLearnerBaseUDTF() {
        this(true);
    }

    /**
     * Creates a learner and an empty set of optimizer options (filled later by processOptions()).
     *
     * @param baseFlag forwarded unchanged to the {@code LearnerBaseUDTF} constructor
     */
    public GeneralLearnerBaseUDTF(boolean baseFlag) {
        super(baseFlag);
        this.optimizerOptions = OptimizerOptions.create();
    }

    /**
     * Returns the help text registered for the {@code -loss} option in getOptions().
     */
    @Nonnull
    protected abstract String getLossOptionDescription();

    /**
     * Returns the loss type used when the {@code -loss} option is not given.
     */
    @Nonnull
    protected abstract LossFunctions.LossType getDefaultLossType();

    /**
     * Validates that the selected loss function is acceptable for this learner.
     *
     * @param lossFunction the loss chosen in processOptions()
     * @throws UDFArgumentException if the loss function is not supported
     */
    protected abstract void checkLossFunction(@Nonnull LossFunctions.LossFunction lossFunction) throws UDFArgumentException;

    /**
     * Validates the target value of one input row; called by process() before training.
     *
     * @param target the target read from the second UDTF argument
     * @throws UDFArgumentException if the target is not acceptable
     */
    protected abstract void checkTargetValue(float target) throws UDFArgumentException;

    /**
     * Trains the model on a single example.
     *
     * @param features the parsed feature vector (may contain null slots for NULL list elements)
     * @param target the target value of the example
     */
    protected abstract void train(@Nonnull FeatureValue[] features, float target);

    /**
     * Validates the UDTF arguments, parses options, and creates the model and optimizer.
     *
     * @param objectInspectorArr inspectors for (features, target [, options])
     * @return the output row inspector: (feature, weight [, covar])
     * @throws UDFArgumentException if fewer than two arguments are given, an argument has an
     *         unsupported type, or optimizer/output setup fails
     */
    public StructObjectInspector initialize(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        final int numArgs = objectInspectorArr.length;
        if (numArgs < 2) {
            throw new UDFArgumentException("_FUNC_ takes 2 arguments: List<Int|BigInt|Text> features, float target [, constant string options]");
        }

        final ListObjectInspector listOI = HiveUtils.asListOI(objectInspectorArr[0]);
        this.featureListOI = listOI;
        this.featureType = getFeatureType(listOI);
        this.targetOI = HiveUtils.asDoubleCompatibleOI(objectInspectorArr[1]);

        processOptions(objectInspectorArr);
        this.model = createModel();

        final StructObjectInspector returnOI;
        try {
            this.optimizer = createOptimizer(this.optimizerOptions);
            this.count = 0L;
            this.sampled = 0;
            returnOI = getReturnOI(getFeatureOutputOI(this.featureType));
        } catch (Throwable th) {
            // any failure while building the optimizer or the output schema is reported as an argument error
            throw new UDFArgumentException(th);
        }
        return returnOI;
    }

    /**
     * Registers the learner options on top of the superclass options: loss function, iteration
     * count, convergence check, and all optimizer options.
     */
    @Override // hivemall.LearnerBaseUDTF, hivemall.UDTFWithOptions
    public Options getOptions() {
        final String iterationsHelp = "The maximum number of iterations [default: 10]";
        final Options opts = super.getOptions();
        opts.addOption("loss", "loss_function", true, getLossOptionDescription());
        // "-iter" and "-iters" are aliases that share the long name "iterations"
        opts.addOption("iter", "iterations", true, iterationsHelp);
        opts.addOption("iters", "iterations", true, iterationsHelp);
        opts.addOption("disable_cv", "disable_cvtest", false, "Whether to disable convergence check [default: OFF]");
        opts.addOption("cv_rate", "convergence_rate", true, "Threshold to determine convergence [default: 0.005]");
        OptimizerOptions.setup(opts);
        return opts;
    }

    /**
     * Parses the learner options and stores the loss function, iteration count and
     * convergence-check state. Optimizer options are copied into {@code optimizerOptions}.
     *
     * <p>Defaults when no options are given: the loss from getDefaultLossType(), 10 iterations,
     * convergence check enabled, and a convergence rate of 0.005.
     *
     * @param objectInspectorArr the UDTF argument inspectors
     * @return the parsed command line, or null if the superclass returned none
     * @throws UDFArgumentException if the loss function is unknown or rejected by
     *         checkLossFunction(), or if {@code -iterations} is less than 1
     */
    @Override // hivemall.LearnerBaseUDTF, hivemall.UDTFWithOptions
    public CommandLine processOptions(ObjectInspector[] objectInspectorArr) throws UDFArgumentException {
        final CommandLine cl = super.processOptions(objectInspectorArr);

        LossFunctions.LossFunction lossFunction = LossFunctions.getLossFunction(getDefaultLossType());
        int iterations = 10;
        boolean checkConvergence = true;
        double convergenceRate = 0.005d;
        if (cl != null) {
            if (cl.hasOption("loss_function")) {
                try {
                    lossFunction = LossFunctions.getLossFunction(cl.getOptionValue("loss_function"));
                } catch (Throwable th) {
                    // Keep the user-facing message, but chain the cause so the stack trace survives.
                    final UDFArgumentException ex = new UDFArgumentException(th.getMessage());
                    ex.initCause(th);
                    throw ex;
                }
            }
            checkLossFunction(lossFunction);
            iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10);
            if (iterations < 1) {
                throw new UDFArgumentException("'-iterations' must be greater than or equals to 1: " + iterations);
            }
            checkConvergence = !cl.hasOption("disable_cvtest");
            convergenceRate = Primitives.parseDouble(cl.getOptionValue("cv_rate"), 0.005d);
        }
        this.lossFunction = lossFunction;
        this.iterations = iterations;
        this.cvState = new ConversionState(checkConvergence, convergenceRate);
        OptimizerOptions.processOptions(cl, this.optimizerOptions);
        return cl;
    }

    /**
     * Maps the element inspector of the feature list to a {@link FeatureType}.
     *
     * @param listObjectInspector inspector of the feature list argument
     * @return the detected feature key type
     * @throws UDFArgumentException if the elements are not string, int or long
     */
    @Nonnull
    private static FeatureType getFeatureType(@Nonnull ListObjectInspector listObjectInspector) throws UDFArgumentException {
        final ObjectInspector elementOI = listObjectInspector.getListElementObjectInspector();
        final FeatureType detected;
        if (elementOI instanceof StringObjectInspector) {
            detected = FeatureType.STRING;
        } else if (elementOI instanceof IntObjectInspector) {
            detected = FeatureType.INT;
        } else if (elementOI instanceof LongObjectInspector) {
            detected = FeatureType.LONG;
        } else {
            throw new UDFArgumentException("Feature object inspector must be one of [StringObjectInspector, IntObjectInspector, LongObjectInspector]: " + elementOI.toString());
        }
        return detected;
    }

    /**
     * Returns the inspector for the "feature" output column.
     *
     * <p>Dense models always emit int feature keys. Otherwise the output type mirrors the input
     * feature type: string, int or long.
     *
     * @param featureType the detected input feature type
     * @return a Java-object inspector for the feature column
     * @throws UDFArgumentException declared for subclasses/callers; not thrown here
     * @throws IllegalStateException if {@code featureType} is not a known constant
     */
    @Nonnull
    protected final ObjectInspector getFeatureOutputOI(@Nonnull FeatureType featureType) throws UDFArgumentException {
        if (this.dense_model) {
            return PrimitiveObjectInspectorFactory.javaIntObjectInspector;
        }
        // Switch on the enum directly: the previous code kept the result in a JavaIntObjectInspector
        // local, which cannot hold the string/long inspectors assigned to it.
        switch (featureType) {
            case STRING:
                return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
            case INT:
                return PrimitiveObjectInspectorFactory.javaIntObjectInspector;
            case LONG:
                return PrimitiveObjectInspectorFactory.javaLongObjectInspector;
            default:
                throw new IllegalStateException("Unexpected feature type: " + featureType);
        }
    }

    /**
     * Builds the output row inspector.
     *
     * <p>The struct is (feature, weight), with weight as a writable float. When useCovariance()
     * is true, a third writable-float column "covar" is appended.
     *
     * @param objectInspector inspector for the "feature" column
     * @return the struct inspector describing one forwarded row
     */
    @Nonnull
    protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector objectInspector) {
        final List<String> fieldNames = new ArrayList<String>(3);
        final List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(3);
        fieldNames.add("feature");
        fieldOIs.add(objectInspector);
        fieldNames.add("weight");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        if (useCovariance()) {
            fieldNames.add("covar");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Trains on one input row.
     *
     * <p>The row's feature list is parsed, its target is validated, one training step runs, and
     * the example is recorded for later iterations when more than one iteration is configured.
     * Rows whose feature list is empty, or NULL (getList() yielding null), are skipped and not
     * counted.
     *
     * @param objArr the row: (features, target [, options])
     * @throws HiveException if the target is rejected or recording the example fails
     */
    public void process(Object[] objArr) throws HiveException {
        if (this.is_mini_batch && this.accumulated == null) {
            // created lazily on the first row; released again in close()
            this.accumulated = new HashMap<Object, FloatAccumulator>(1024);
        }
        final List<?> features = this.featureListOI.getList(objArr[0]);
        if (features == null) {
            return; // NULL feature column: nothing to learn from; treated like an empty list
        }
        final FeatureValue[] featureVector = parseFeatures(features);
        if (featureVector == null) {
            return; // empty feature list
        }
        final float target = PrimitiveObjectInspectorUtils.getFloat(objArr[1], this.targetOI);
        checkTargetValue(target);
        this.count++;
        train(featureVector, target);
        recordTrainSampleToTempFile(featureVector, target);
    }

    /**
     * Appends one training example to a direct buffer so it can be replayed on later iterations.
     * The buffer is flushed to a temporary file whenever it runs out of room. The reader is
     * presumably runIterativeTraining(), whose body is not decompiled here.
     *
     * <p>Nothing is recorded when only one iteration is configured. Record layout:
     * <pre>
     *   int   recordBytes      size of everything below
     *   int   numFeatures      number of serialized (non-null) features
     *   numFeatures x feature  see writeFeatureValue(): string via NIOUtils.putString, then double value
     *   float target
     * </pre>
     * Null slots in {@code featureValueArr} are skipped, and {@code numFeatures} counts only the
     * features actually written.
     *
     * @param featureValueArr the parsed feature vector; may contain null slots
     * @param f the target value of the example
     * @throws HiveException if the temporary file cannot be created or written, or if one
     *         example does not fit into the record buffer
     */
    protected void recordTrainSampleToTempFile(@Nonnull FeatureValue[] featureValueArr, float f) throws HiveException {
        if (this.iterations == 1) {
            return; // a single pass never replays examples
        }

        ByteBuffer buf = this.inputBuf;
        NioStatefulSegment dst = this.fileIO;
        if (buf == null) {
            final File file;
            try {
                file = File.createTempFile("hivemall_general_learner", ".sgmt");
                file.deleteOnExit();
            } catch (Throwable th) {
                throw new UDFArgumentException(th);
            }
            if (!file.canWrite()) {
                throw new UDFArgumentException("Cannot write a temporary file: " + file.getAbsolutePath());
            }
            logger.info("Record training samples to a file: " + file.getAbsolutePath());
            try {
                buf = ByteBuffer.allocateDirect(BloomFilterUtils.DEFAULT_BLOOM_FILTER_SIZE);
                dst = new NioStatefulSegment(file, false);
            } catch (Throwable th) {
                throw new UDFArgumentException(th);
            }
            // Publish both only after both exist, so a failure cannot leave inputBuf set while
            // fileIO is still null.
            this.fileIO = dst;
            this.inputBuf = buf;
        }

        int numFeatures = 0;
        int featureBytes = 0;
        for (FeatureValue fv : featureValueArr) {
            if (fv == null) {
                continue;
            }
            numFeatures++;
            // 2 bytes per char of the feature string + its int length prefix + the double value
            // (the size assumed for NIOUtils.putString + putDouble in writeFeatureValue)
            featureBytes += 2 * fv.getFeatureAsString().length() + 4 + 8;
        }
        // numFeatures (int) + features + target (float)
        final int recordBytes = 4 + featureBytes + 4;
        // the leading recordBytes int must fit as well
        final int requiredBytes = 4 + recordBytes;
        if (buf.remaining() < requiredBytes) {
            writeBuffer(buf, dst);
            if (buf.remaining() < requiredBytes) {
                throw new HiveException("A training example of " + requiredBytes
                        + " bytes does not fit in the record buffer of " + buf.capacity() + " bytes");
            }
        }

        buf.putInt(recordBytes);
        buf.putInt(numFeatures);
        for (FeatureValue fv : featureValueArr) {
            if (fv != null) {
                writeFeatureValue(buf, fv);
            }
        }
        buf.putFloat(f);
    }

    /**
     * Serializes one feature: its string form via NIOUtils.putString(), followed by its value as
     * a double.
     *
     * @param dst buffer to append to
     * @param fv non-null feature to write
     */
    private static void writeFeatureValue(@Nonnull ByteBuffer dst, @Nonnull FeatureValue fv) {
        final String featureName = fv.getFeatureAsString();
        final double featureWeight = fv.getValue();
        NIOUtils.putString(featureName, dst);
        dst.putDouble(featureWeight);
    }

    /**
     * Reads back one feature written by writeFeatureValue().
     *
     * <p>The stored string is converted back to the key type of {@code featureType} (String,
     * Integer or Long), and then the double value is read.
     *
     * @param byteBuffer buffer positioned at a serialized feature
     * @param featureType key type to restore
     * @return the deserialized feature
     * @throws NumberFormatException if an INT/LONG feature string is not a valid number
     * @throws IllegalStateException if {@code featureType} is not a known constant
     */
    @Nonnull
    private static FeatureValue readFeatureValue(@Nonnull ByteBuffer byteBuffer, @Nonnull FeatureType featureType) {
        final String rawFeature = NIOUtils.getString(byteBuffer);
        final Object feature;
        // switch on the enum itself rather than on a synthetic ordinal table
        switch (featureType) {
            case STRING:
                feature = rawFeature;
                break;
            case INT:
                feature = Integer.valueOf(rawFeature);
                break;
            case LONG:
                feature = Long.valueOf(rawFeature);
                break;
            default:
                throw new IllegalStateException("Unexpected feature type " + featureType + " for feature: " + rawFeature);
        }
        final double value = byteBuffer.getDouble(); // the value follows the string in the record
        return new FeatureValue(feature, value);
    }

    /**
     * Converts a raw Hive feature list into a FeatureValue array of the same length.
     *
     * <p>STRING features are parsed with FeatureValue.parseFeatureAsString(). INT/LONG features
     * are copied to standard Java objects and given the value 1.0. NULL list elements leave a
     * null slot at the same index.
     *
     * @param list the raw feature list
     * @return the parsed features, or null when the list is empty
     */
    @Nullable
    public final FeatureValue[] parseFeatures(@Nonnull List<?> list) {
        final int numElements = list.size();
        if (numElements == 0) {
            return null;
        }
        final ObjectInspector elementOI = this.featureListOI.getListElementObjectInspector();
        final boolean parseAsString = this.featureType == FeatureType.STRING;
        final FeatureValue[] result = new FeatureValue[numElements];
        int idx = 0;
        while (idx < numElements) {
            final Object element = list.get(idx);
            if (element != null) {
                if (parseAsString) {
                    result[idx] = FeatureValue.parseFeatureAsString(element.toString());
                } else {
                    final Object copied = ObjectInspectorUtils.copyToStandardObject(element, elementOI, ObjectInspectorUtils.ObjectInspectorCopyOption.JAVA);
                    result[idx] = new FeatureValue(copied, 1.0f);
                }
            }
            idx++;
        }
        return result;
    }

    /**
     * Drains the filled part of {@code buf} into {@code dst}, then clears the buffer for reuse.
     * The buffer is cleared only if the write succeeds.
     *
     * @param buf buffer in fill mode
     * @param dst destination segment
     * @throws HiveException wrapping the IOException raised by the segment write
     */
    private static void writeBuffer(@Nonnull ByteBuffer buf, @Nonnull NioStatefulSegment dst) throws HiveException {
        buf.flip(); // switch from filling to draining
        try {
            dst.write(buf);
        } catch (IOException ioe) {
            throw new HiveException("Exception causes while writing a buffer to file", ioe);
        }
        buf.clear();
    }

    /**
     * Computes the linear score: the sum of weight * value over the non-null features.
     * Features whose current model weight is 0 contribute nothing.
     *
     * @param featureValueArr the feature vector; null slots are ignored
     * @return the dot product of the model weights and the feature values
     */
    public float predict(@Nonnull FeatureValue[] featureValueArr) {
        float score = 0.0f;
        for (int i = 0; i < featureValueArr.length; i++) {
            final FeatureValue fv = featureValueArr[i];
            if (fv == null) {
                continue;
            }
            final Object key = fv.getFeature();
            final float x = fv.getValueAsFloat();
            final float w = this.model.getWeight(key);
            if (w == 0.0f) {
                continue;
            }
            score += w * x;
        }
        return score;
    }

    /**
     * Performs one learning step for a single example.
     *
     * <p>The example's loss is added to the convergence state. The loss gradient is then clipped
     * to [MIN_DLOSS, MAX_DLOSS] and applied either to the mini-batch accumulators (flushing once
     * the batch is full) or directly to the model. proceedStep() is called for every example,
     * including those with a zero gradient.
     *
     * @param featureValueArr the example's features
     * @param f value passed as the second argument of loss()/dloss() (presumably the target — confirm)
     * @param f2 value passed as the first argument of loss()/dloss() (presumably the prediction — confirm)
     */
    public void update(@Nonnull FeatureValue[] featureValueArr, float f, float f2) {
        this.cvState.incrLoss(this.lossFunction.loss(f2, f));
        final float gradient = this.lossFunction.dloss(f2, f);
        if (gradient != 0.0f) {
            // clip extreme gradients; a NaN gradient passes through unchanged
            final float clipped = Math.max(MIN_DLOSS, Math.min(MAX_DLOSS, gradient));
            if (this.is_mini_batch) {
                accumulateUpdate(featureValueArr, clipped);
                if (this.sampled >= this.mini_batch_size) {
                    batchUpdate();
                }
            } else {
                onlineUpdate(featureValueArr, clipped);
            }
        }
        this.optimizer.proceedStep();
    }

    /**
     * Computes each feature's new weight with the optimizer and adds it to that feature's
     * mini-batch accumulator without touching the model. The model is updated later by
     * batchUpdate().
     *
     * <p>Null slots are skipped. parseFeatures() leaves them for NULL list elements, and
     * predict() already ignores them, so the update must not dereference them.
     *
     * @param featureValueArr the example's features
     * @param f the clipped loss gradient
     */
    protected void accumulateUpdate(@Nonnull FeatureValue[] featureValueArr, float f) {
        for (FeatureValue featureValue : featureValueArr) {
            if (featureValue == null) {
                continue;
            }
            final Object feature = featureValue.getFeature();
            final float newWeight = this.optimizer.update(feature, this.model.getWeight(feature), f * featureValue.getValueAsFloat());
            final FloatAccumulator acc = this.accumulated.get(feature);
            if (acc == null) {
                this.accumulated.put(feature, new FloatAccumulator(newWeight));
            } else {
                acc.add(newWeight);
            }
        }
        this.sampled++;
    }

    /**
     * Writes the accumulated mini-batch values into the model and resets the mini-batch.
     *
     * <p>Each feature's value is taken from FloatAccumulator.get(). A value of 0 deletes the
     * feature from the model; any other value is stored as its weight.
     */
    protected void batchUpdate() {
        final Map<Object, FloatAccumulator> pending = this.accumulated;
        if (!pending.isEmpty()) {
            for (Map.Entry<Object, FloatAccumulator> entry : pending.entrySet()) {
                final Object feature = entry.getKey();
                final float newWeight = entry.getValue().get();
                if (newWeight != 0.0f) {
                    this.model.setWeight(feature, newWeight);
                } else {
                    this.model.delete(feature);
                }
            }
            pending.clear();
        }
        this.sampled = 0;
    }

    /**
     * Applies the optimizer update to every feature of one example immediately.
     * A resulting weight of 0 deletes the feature from the model.
     *
     * <p>Null slots, which parseFeatures() leaves for NULL list elements, are skipped, in the
     * same way predict() skips them.
     *
     * @param featureValueArr the example's features
     * @param f the clipped loss gradient
     */
    protected void onlineUpdate(@Nonnull FeatureValue[] featureValueArr, float f) {
        for (FeatureValue featureValue : featureValueArr) {
            if (featureValue == null) {
                continue;
            }
            final Object feature = featureValue.getFeature();
            final float newWeight = this.optimizer.update(feature, this.model.getWeight(feature), f * featureValue.getValueAsFloat());
            if (newWeight == 0.0f) {
                this.model.delete(feature);
            } else {
                this.model.setWeight(feature, newWeight);
            }
        }
    }

    /**
     * Finishes training and emits the model rows, then drops the references to the model and
     * the mini-batch state.
     */
    @Override // hivemall.LearnerBaseUDTF
    public final void close() throws HiveException {
        super.close();
        // pending mini-batch flush and extra iterations first, then emit rows
        finalizeTraining();
        forwardModel();
        // release references; the UDTF produces no more output after close()
        this.model = null;
        this.accumulated = null;
    }

    /**
     * Completes training after the last input row.
     *
     * <p>A pending mini-batch is flushed, and runIterativeTraining() runs when more than one
     * iteration was configured. If no row was trained on, the model reference is dropped instead.
     *
     * @throws HiveException propagated from runIterativeTraining()
     */
    @VisibleForTesting
    public void finalizeTraining() throws HiveException {
        if (this.count != 0L) {
            if (this.is_mini_batch) {
                batchUpdate();
            }
            if (this.iterations > 1) {
                runIterativeTraining(this.iterations);
            }
        } else {
            this.model = null;
        }
    }

    /* JADX WARN: Code restructure failed: missing block: B:119:0x030d, code lost:
    
        if (r7.is_mini_batch == false) goto L101;
     */
    /* JADX WARN: Code restructure failed: missing block: B:120:0x0310, code lost:
    
        batchUpdate();
     */
    /* JADX WARN: Code restructure failed: missing block: B:122:0x031d, code lost:
    
        if (r7.cvState.isConverged(r0) == false) goto L104;
     */
    /* JADX WARN: Code restructure failed: missing block: B:123:0x0323, code lost:
    
        r15 = r15 + 1;
     */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    /**
     * Runs the additional training iterations requested by {@code -iterations}.
     * It presumably replays the examples recorded by recordTrainSampleToTempFile() — TODO confirm
     * against the upstream source.
     *
     * <p>NOTE(review): the body failed to decompile and currently only throws
     * UnsupportedOperationException. finalizeTraining() calls this method whenever
     * {@code iterations > 1}, so every run with more than one iteration fails in close() until
     * the real implementation is restored.
     *
     * <p>The surviving fragments above suggest (unconfirmed) that each iteration ends with
     * batchUpdate() when mini-batch training is on, and that the loop stops early once
     * cvState.isConverged(...) is true.
     *
     * @param r8 the configured number of iterations (finalizeTraining() passes {@code iterations})
     * @throws org.apache.hadoop.hive.ql.metadata.HiveException as declared; presumably on I/O
     *         failures of the spill file — TODO confirm
     */
    protected final void runIterativeTraining(@javax.annotation.Nonnegative int r8) throws org.apache.hadoop.hive.ql.metadata.HiveException {
        /*
            Method dump skipped, instructions count: 1015
            To view this dump add '--comments-level debug' option
        */
        // NOTE(review): decompiler placeholder; restore the original body before relying on iterations > 1.
        throw new UnsupportedOperationException("Method not decompiled: hivemall.GeneralLearnerBaseUDTF.runIterativeTraining(int):void");
    }

    /**
     * Emits one output row per touched model entry whose weight is non-zero, then logs summary
     * counts.
     *
     * <p>Rows follow the schema built by getReturnOI(): (feature, weight), or
     * (feature, weight, covar) when useCovariance() is true. With covariance, an entry is emitted
     * if either its weight or its covariance is non-zero. The row array and Writables are reused
     * across forward() calls.
     *
     * <p>Nothing is forwarded when the model reference is null. finalizeTraining() sets it to null
     * when no example was trained on; previously this path dereferenced the null model.
     *
     * @throws HiveException propagated from forward()
     */
    protected void forwardModel() throws HiveException {
        final PredictionModel model = this.model;
        if (model == null) {
            logger.info("No training examples were given; no prediction model is forwarded");
            return;
        }

        int numForwarded = 0;
        if (useCovariance()) {
            final WeightValue.WeightValueWithCovar probe = new WeightValue.WeightValueWithCovar();
            final Object[] row = new Object[3];
            final FloatWritable weightOut = new FloatWritable();
            final FloatWritable covarOut = new FloatWritable();
            final IMapIterator itor = model.entries();
            while (itor.next() != -1) {
                itor.getValue(probe);
                if (!probe.isTouched()) {
                    continue;
                }
                final float weight = probe.get();
                final float covar = probe.getCovariance();
                if (weight == 0.0f && covar == 0.0f) {
                    continue;
                }
                weightOut.set(weight);
                covarOut.set(covar);
                row[0] = itor.getKey2();
                row[1] = weightOut;
                row[2] = covarOut;
                forward(row);
                numForwarded++;
            }
        } else {
            final WeightValue probe = new WeightValue();
            final Object[] row = new Object[2];
            final FloatWritable weightOut = new FloatWritable();
            final IMapIterator itor = model.entries();
            while (itor.next() != -1) {
                itor.getValue(probe);
                if (!probe.isTouched()) {
                    continue;
                }
                final float weight = probe.get();
                if (weight == 0.0f) {
                    continue;
                }
                weightOut.set(weight);
                row[0] = itor.getKey2();
                row[1] = weightOut;
                forward(row);
                numForwarded++;
            }
        }

        final long numMixed = model.getNumMixed();
        logger.info("Trained a prediction model using " + this.count + " training examples" + (numMixed > 0 ? "( numMixed: " + numMixed + " )" : ""));
        logger.info("Forwarded the prediction model of " + numForwarded + " rows");
    }

    /**
     * Returns the cumulative loss held by the convergence state.
     *
     * @return the cumulative loss, or NaN while cvState is null (it is assigned in processOptions())
     */
    @VisibleForTesting
    public double getCumulativeLoss() {
        final ConversionState state = this.cvState;
        return (state == null) ? Double.NaN : state.getCumulativeLoss();
    }

    static {
        // Shared class logger.
        logger = LogFactory.getLog(GeneralLearnerBaseUDTF.class);
        // Decompiler-preserved synthetic flag: true when assertions are disabled for this class.
        $assertionsDisabled = !GeneralLearnerBaseUDTF.class.desiredAssertionStatus();
    }
}
