package hivemall.classifier;

import hivemall.LearnerBaseUDTF;
import hivemall.annotations.VisibleForTesting;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.FloatAccumulator;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.io.FloatWritable;

/**
 * Base class for online binary classifiers exposed as Hive UDTFs.
 *
 * <p>Each input row carries a list of features and an integer label (-1, 0 or 1). Rows are
 * consumed one at a time by {@link #process(Object[])}. When the UDTF is closed, every touched
 * weight is forwarded as a {@code (feature, weight)} row, or as {@code (feature, weight, covar)}
 * when {@code useCovariance()} is true.
 *
 * <p>Subclasses supply the learning rule by overriding {@code update(FeatureValue[], float, float)}
 * (or {@link #train(FeatureValue[], int)} itself).
 */
public abstract class BinaryOnlineClassifierUDTF extends LearnerBaseUDTF {
    private static final Log logger = LogFactory.getLog(BinaryOnlineClassifierUDTF.class);

    /** Inspector for the first argument: the list of features. */
    protected ListObjectInspector featureListOI;
    /** Inspector for the second argument: the (int-compatible) label. */
    protected PrimitiveObjectInspector labelOI;
    /** True when list elements are strings that must be parsed with {@code FeatureValue.parse}. */
    private boolean parseFeature;

    /** Model under training; set to {@code null} once it has been forwarded in {@link #close()}. */
    protected PredictionModel model;
    /** Number of training examples consumed by {@link #process(Object[])}. */
    protected int count;
    /** Pending per-feature updates; allocated lazily in {@link #process} when mini-batch is on. */
    protected transient Map<Object, FloatAccumulator> accumulated;
    /** Reset to 0 in {@link #initialize}; not otherwise updated here (presumably used by subclasses). */
    protected int sampled;

    public BinaryOnlineClassifierUDTF() {
        this(false);
    }

    /**
     * @param flag passed through unchanged to {@code LearnerBaseUDTF(boolean)}; its meaning is
     *        defined there
     */
    public BinaryOnlineClassifierUDTF(boolean flag) {
        super(flag);
    }

    /**
     * Validates the arguments, processes options and creates an empty model.
     *
     * @param argOIs inspectors for {@code features, label [, options]}
     * @return the struct inspector describing the forwarded model rows
     * @throws UDFArgumentException if fewer than two arguments are given or they have the wrong type
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 2) {
            throw new UDFArgumentException(getClass().getSimpleName()
                    + " takes 2 arguments: List<Int|BigInt|Text> features, int label [, constant string options]");
        }
        PrimitiveObjectInspector featureInputOI = processFeaturesOI(argOIs[0]);
        this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);

        processOptions(argOIs);

        this.model = createModel();
        this.count = 0;
        this.sampled = 0;

        return getReturnOI(getFeatureOutputOI(featureInputOI));
    }

    /**
     * Records the feature-list inspector and decides whether features need string parsing.
     *
     * @param featureOI inspector of the first UDTF argument
     * @return the primitive inspector of the list elements
     * @throws UDFArgumentException if the argument is not a list of valid feature values
     */
    @Nonnull
    protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector featureOI)
            throws UDFArgumentException {
        if (!(featureOI instanceof ListObjectInspector)) {
            // Report a proper argument error instead of an unexplained ClassCastException.
            throw new UDFArgumentException(
                "The first argument must be a list of features: " + featureOI.getTypeName());
        }
        this.featureListOI = (ListObjectInspector) featureOI;
        ObjectInspector elementOI = featureListOI.getListElementObjectInspector();
        HiveUtils.validateFeatureOI(elementOI);
        this.parseFeature = HiveUtils.isStringOI(elementOI);
        return HiveUtils.asPrimitiveObjectInspector(elementOI);
    }

    /**
     * Builds the output schema: {@code feature, weight} plus {@code covar} when covariance is used.
     *
     * @param featureOutputOI inspector for the {@code feature} column
     */
    @Nonnull
    protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector featureOutputOI) {
        List<String> fieldNames = new ArrayList<String>();
        List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

        fieldNames.add("feature");
        fieldOIs.add(featureOutputOI);
        fieldNames.add("weight");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        if (useCovariance()) {
            fieldNames.add("covar");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Consumes one training row {@code (features, label)}.
     *
     * <p>Rows with a NULL or empty feature list are skipped without being counted.
     *
     * @throws UDFArgumentException if the label is not -1, 0 or 1
     */
    @Override
    public void process(Object[] args) throws HiveException {
        if (is_mini_batch && accumulated == null) {
            this.accumulated = new HashMap<Object, FloatAccumulator>(1024);
        }

        List<?> featureList = featureListOI.getList(args[0]);
        if (featureList == null) {
            return; // NULL feature column: nothing to learn from
        }
        FeatureValue[] features = parseFeatures(featureList);
        if (features == null) {
            return; // empty feature list
        }
        int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI);
        checkLabelValue(label);

        count++;
        train(features, label);
    }

    /**
     * Converts raw list elements into {@link FeatureValue}s.
     *
     * <p>String elements are parsed via {@code FeatureValue.parse}. Other elements are copied to a
     * standard object with value {@code 1.0f}. NULL elements leave a {@code null} slot in the
     * result array.
     *
     * @return the parsed features, or {@code null} when {@code features} is empty
     */
    @Nullable
    FeatureValue[] parseFeatures(@Nonnull List<?> features) {
        final int size = features.size();
        if (size == 0) {
            return null;
        }
        ObjectInspector elementOI = featureListOI.getListElementObjectInspector();
        FeatureValue[] parsed = new FeatureValue[size];
        for (int i = 0; i < size; i++) {
            Object raw = features.get(i);
            if (raw == null) {
                continue;
            }
            if (parseFeature) {
                parsed[i] = FeatureValue.parse(raw);
            } else {
                // Copy because Hive may reuse the underlying object across rows.
                parsed[i] = new FeatureValue(ObjectInspectorUtils.copyToStandardObject(raw, elementOI), 1.0f);
            }
        }
        return parsed;
    }

    /**
     * Ensures the label is one of -1, 0 or 1.
     *
     * @throws UDFArgumentException for any other value
     */
    protected void checkLabelValue(int label) throws UDFArgumentException {
        switch (label) {
            case -1:
            case 0:
            case 1:
                return;
            default:
                throw new UDFArgumentException("Invalid label value for classification: " + label);
        }
    }

    /** Test hook: parses {@code features} and trains on them; an empty list is ignored. */
    @VisibleForTesting
    void train(List<?> features, int label) {
        FeatureValue[] parsed = parseFeatures(features);
        if (parsed == null) {
            return; // empty input: consistent with process(), which skips such rows
        }
        train(parsed, label);
    }

    /**
     * Mistake-driven training step. Labels greater than 0 map to +1 and all others to -1.
     * {@code update(features, y, score)} is invoked only when the current score does not have
     * the label's sign (i.e. {@code score * y <= 0}).
     */
    protected void train(@Nonnull FeatureValue[] features, int label) {
        final float y = label > 0 ? 1.0f : -1.0f;
        final float score = predict(features);
        if (score * y <= 0.0f) {
            update(features, y, score);
        }
    }

    /**
     * Returns the dot product of the model weights with {@code features} (null entries skipped).
     *
     * <p>NOTE(review): originally package-private; the decompiler widened it to public. It is
     * kept public so existing callers continue to compile.
     */
    public float predict(@Nonnull FeatureValue[] features) {
        float score = 0.0f;
        for (FeatureValue fv : features) {
            if (fv == null) {
                continue;
            }
            float weight = model.getWeight(fv.getFeature());
            if (weight != 0.0f) {
                score += weight * fv.getValueAsFloat();
            }
        }
        return score;
    }

    /**
     * Computes the score (weight dot product) together with the squared L2 norm of the feature
     * values.
     *
     * <p>NOTE(review): originally protected; kept public for compatibility.
     */
    @Nonnull
    public PredictionResult calcScoreAndNorm(@Nonnull FeatureValue[] features) {
        float score = 0.0f;
        float squaredNorm = 0.0f;
        for (FeatureValue fv : features) {
            if (fv == null) {
                continue;
            }
            float value = fv.getValueAsFloat();
            float weight = model.getWeight(fv.getFeature());
            if (weight != 0.0f) {
                score += weight * value;
            }
            squaredNorm += value * value;
        }
        return new PredictionResult(score).squaredNorm(squaredNorm);
    }

    /**
     * Computes the score together with the variance {@code sum(covar_i * x_i^2)}. Features absent
     * from the model contribute nothing to the score and use a covariance of 1.
     *
     * <p>NOTE(review): originally protected; kept public for compatibility.
     */
    @Nonnull
    public PredictionResult calcScoreAndVariance(@Nonnull FeatureValue[] features) {
        float score = 0.0f;
        float variance = 0.0f;
        for (FeatureValue fv : features) {
            if (fv == null) {
                continue;
            }
            float value = fv.getValueAsFloat();
            IWeightValue old = model.get(fv.getFeature());
            if (old == null) {
                variance += 1.0f * value * value; // default covariance for unseen features
            } else {
                score += old.get() * value;
                variance += old.getCovariance() * value * value;
            }
        }
        return new PredictionResult(score).variance(variance);
    }

    /**
     * Learning rule hook called by {@link #train(FeatureValue[], int)}; subclasses using the
     * default {@code train} must override it.
     *
     * @throws IllegalStateException always, in this base implementation
     */
    protected void update(@Nonnull FeatureValue[] features, float y, float score) {
        throw new IllegalStateException("update() should not be called");
    }

    /**
     * Additive update: {@code w_i += coeff * x_i} for every non-null feature.
     *
     * <p>NOTE(review): originally protected; kept public for compatibility.
     */
    public void update(@Nonnull FeatureValue[] features, float coeff) {
        for (FeatureValue fv : features) {
            if (fv == null) {
                continue;
            }
            Object feature = fv.getFeature();
            float newWeight = model.getWeight(feature) + coeff * fv.getValueAsFloat();
            model.set(feature, new WeightValue(newWeight));
        }
    }

    /** Mini-batch hook; unsupported unless overridden. */
    protected void accumulateUpdate(@Nonnull FeatureValue[] features, float coeff) {
        throw new UnsupportedOperationException();
    }

    /** Mini-batch flush hook, called from {@link #close()}; unsupported unless overridden. */
    protected void batchUpdate() {
        throw new UnsupportedOperationException();
    }

    /** Online update hook; unsupported unless overridden. */
    protected void onlineUpdate(@Nonnull FeatureValue[] features, float coeff) {
        throw new UnsupportedOperationException();
    }

    /**
     * Flushes pending mini-batch updates, forwards all touched weights, and releases the model.
     */
    @Override // hivemall.LearnerBaseUDTF
    public void close() throws HiveException {
        super.close();
        if (model == null) {
            return;
        }
        if (accumulated != null) { // apply updates still pending from the last mini-batch
            batchUpdate();
            this.accumulated = null;
        }

        final int numForwarded = useCovariance() ? forwardWeightsWithCovariance() : forwardWeights();

        long numMixed = model.getNumMixed();
        this.model = null;
        logger.info("Trained a prediction model using " + count + " training examples"
                + (numMixed > 0 ? "( numMixed: " + numMixed + " )" : ""));
        logger.info("Forwarded the prediction model of " + numForwarded + " rows");
    }

    /** Forwards {@code (feature, weight, covar)} for every touched entry; returns the row count. */
    private int forwardWeightsWithCovariance() throws HiveException {
        final WeightValue.WeightValueWithCovar probe = new WeightValue.WeightValueWithCovar();
        final FloatWritable weight = new FloatWritable();
        final FloatWritable covar = new FloatWritable();
        final Object[] row = new Object[3];
        int forwarded = 0;
        IMapIterator entries = model.entries();
        while (entries.next() != -1) {
            entries.getValue(probe);
            if (!probe.isTouched()) {
                continue; // skip entries never updated by training
            }
            weight.set(probe.get());
            covar.set(probe.getCovariance());
            row[0] = entries.getKey2();
            row[1] = weight;
            row[2] = covar;
            forward(row);
            forwarded++;
        }
        return forwarded;
    }

    /** Forwards {@code (feature, weight)} for every touched entry; returns the row count. */
    private int forwardWeights() throws HiveException {
        final WeightValue probe = new WeightValue();
        final FloatWritable weight = new FloatWritable();
        final Object[] row = new Object[2];
        int forwarded = 0;
        IMapIterator entries = model.entries();
        while (entries.next() != -1) {
            entries.getValue(probe);
            if (!probe.isTouched()) {
                continue; // skip entries never updated by training
            }
            weight.set(probe.get());
            row[0] = entries.getKey2();
            row[1] = weight;
            forward(row);
            forwarded++;
        }
        return forwarded;
    }
}
