package hivemall.classifier.multiclass;

import hivemall.HivemallConstants;
import hivemall.LearnerBaseUDTF;
import hivemall.fm.Feature;
import hivemall.model.FeatureValue;
import hivemall.model.IWeightValue;
import hivemall.model.Margin;
import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.model.WeightValue;
import hivemall.utils.collections.IMapIterator;
import hivemall.utils.datetime.StopWatch;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.IOUtils;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableFloatObjectInspector;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.Text;

/**
 * Base class of multiclass online classifier UDTFs.
 *
 * <p>Each input row is {@code (features, label [, constant text options])}. One
 * {@link PredictionModel} is kept per observed label in {@link #label2model}; the concrete
 * learning rule is supplied by subclasses through {@link #train(FeatureValue[], Object)}.
 * On {@link #close()} every touched weight is forwarded as
 * {@code (label, feature, weight)}, or {@code (label, feature, weight, covar)} when
 * {@link #useCovariance()} is true.
 */
public abstract class MulticlassOnlineClassifierUDTF extends LearnerBaseUDTF {
    private static final Log logger = LogFactory.getLog(MulticlassOnlineClassifierUDTF.class);

    private ListObjectInspector featureListOI;
    /** True when list elements are strings and are parsed with {@code FeatureValue.parse}. */
    private boolean parseFeature;
    private PrimitiveObjectInspector labelInputOI;

    /** One model per label; set to null after the models are forwarded in {@link #close()}. */
    protected Map<Object, PredictionModel> label2model;
    /** Number of rows passed to {@link #train(FeatureValue[], Object)}. */
    protected int count;

    public MulticlassOnlineClassifierUDTF() {
        this(false);
    }

    /** @param z forwarded unchanged to {@code LearnerBaseUDTF(boolean)}. */
    public MulticlassOnlineClassifierUDTF(boolean z) {
        super(z);
    }

    /**
     * Validates the argument types, processes options and builds the output row inspector.
     *
     * @param argOIs {@code [0]} list of Int/BigInt/Text features, {@code [1]} Int/BigInt/Text
     *        label, optional constant text options after that
     * @return struct inspector for {@code (label, feature, weight[, covar])}
     * @throws UDFArgumentException when fewer than two arguments are given or a type is unsupported
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 2) {
            throw new UDFArgumentException(getClass().getSimpleName()
                    + " takes 2 arguments: List<Int|BigInt|Text> features, {Int|BigInt|Text} label [, constant text options]");
        }
        PrimitiveObjectInspector featureInputOI = processFeaturesOI(argOIs[0]);
        this.labelInputOI = HiveUtils.asPrimitiveObjectInspector(argOIs[1]);
        String labelTypeName = labelInputOI.getTypeName();
        if (!isSupportedKeyType(labelTypeName)) {
            // the label is the 2nd argument, so report argument index 1
            throw new UDFArgumentTypeException(1, "label must be a type [Int|BigInt|Text]: "
                    + labelTypeName);
        }
        processOptions(argOIs);
        this.label2model = new HashMap<Object, PredictionModel>(64);
        this.count = 0;
        return getReturnOI(labelInputOI, getFeatureOutputOI(featureInputOI));
    }

    @Override
    protected int getInitialModelSize() {
        return 8192;
    }

    /** Returns true for the type names accepted as features and labels: string, int, bigint. */
    private static boolean isSupportedKeyType(@Nonnull String typeName) {
        return HivemallConstants.STRING_TYPE_NAME.equals(typeName)
                || HivemallConstants.INT_TYPE_NAME.equals(typeName)
                || HivemallConstants.BIGINT_TYPE_NAME.equals(typeName);
    }

    /**
     * Validates the feature argument and remembers its list inspector.
     *
     * @param arg inspector of the 1st argument; must be a list of Int/BigInt/Text
     * @return the primitive inspector of the list elements
     * @throws UDFArgumentException when the argument is not a list or its element type is
     *         unsupported
     */
    @Nonnull
    protected PrimitiveObjectInspector processFeaturesOI(@Nonnull ObjectInspector arg)
            throws UDFArgumentException {
        if (arg.getCategory() != ObjectInspector.Category.LIST) {
            throw new UDFArgumentTypeException(0, "1st argument must be List of [Int|BigInt|Text]: "
                    + arg.getTypeName());
        }
        this.featureListOI = (ListObjectInspector) arg;
        ObjectInspector featureElemOI = featureListOI.getListElementObjectInspector();
        String featureTypeName = featureElemOI.getTypeName();
        if (!isSupportedKeyType(featureTypeName)) {
            throw new UDFArgumentTypeException(0, "1st argument must be List of [Int|BigInt|Text]: "
                    + featureTypeName);
        }
        this.parseFeature = HivemallConstants.STRING_TYPE_NAME.equals(featureTypeName);
        return HiveUtils.asPrimitiveObjectInspector(featureElemOI);
    }

    /**
     * Builds the output struct {@code (label, feature, weight[, covar])}.
     *
     * @param labelRawOI inspector of the input label; its standard form is used for output
     * @param featureOutputOI inspector used for the forwarded feature column
     */
    @Nonnull
    protected StructObjectInspector getReturnOI(@Nonnull ObjectInspector labelRawOI,
            @Nonnull ObjectInspector featureOutputOI) {
        List<String> fieldNames = new ArrayList<String>();
        List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("label");
        fieldOIs.add(ObjectInspectorUtils.getStandardObjectInspector(labelRawOI));
        fieldNames.add("feature");
        fieldOIs.add(featureOutputOI);
        fieldNames.add("weight");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        if (useCovariance()) {
            fieldNames.add("covar");
            fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        }
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Parses one row and hands it to {@link #train(FeatureValue[], Object)}.
     *
     * <p>Rows whose feature list is NULL or empty are skipped.
     *
     * @throws UDFArgumentException when the label is NULL
     */
    @Override
    public void process(Object[] args) throws HiveException {
        List<?> features = featureListOI.getList(args[0]);
        if (features == null) {
            return; // NULL feature column: skip, same as an empty list
        }
        FeatureValue[] featureVector = parseFeatures(features);
        if (featureVector == null) {
            return;
        }
        Object label = ObjectInspectorUtils.copyToStandardObject(args[1], labelInputOI);
        if (label == null) {
            throw new UDFArgumentException("label value must not be NULL");
        }
        count++;
        train(featureVector, label);
    }

    /**
     * Converts the raw feature list into {@link FeatureValue}s.
     *
     * <p>String elements are parsed with {@code FeatureValue.parse}; other elements get the
     * value {@code 1.0f}. NULL elements stay NULL in the returned array.
     *
     * @return the parsed features, or null when the list is empty
     */
    @Nullable
    protected final FeatureValue[] parseFeatures(@Nonnull List<?> features) {
        final int size = features.size();
        if (size == 0) {
            return null;
        }
        final ObjectInspector featureElemOI = featureListOI.getListElementObjectInspector();
        final FeatureValue[] featureVector = new FeatureValue[size];
        for (int i = 0; i < size; i++) {
            Object raw = features.get(i);
            if (raw == null) {
                continue;
            }
            featureVector[i] = parseFeature ? FeatureValue.parse(raw)
                    : new FeatureValue(ObjectInspectorUtils.copyToStandardObject(raw, featureElemOI), 1.0f);
        }
        return featureVector;
    }

    /** Updates the per-label models with one training example. */
    protected abstract void train(@Nonnull FeatureValue[] features, @Nonnull Object actualLabel);

    /**
     * Returns the label whose model gives the highest score, together with that score.
     *
     * <p>NOTE(review): when {@link #label2model} is empty the result carries a null label and
     * the initial sentinel score {@code Float.MIN_VALUE} (the smallest positive float).
     */
    public final PredictionResult classify(@Nonnull FeatureValue[] features) {
        float maxScore = Float.MIN_VALUE;
        Object maxScoredLabel = null;
        for (Map.Entry<Object, PredictionModel> e : label2model.entrySet()) {
            float score = calcScore(e.getValue(), features);
            if (maxScoredLabel == null || score > maxScore) {
                maxScore = score;
                maxScoredLabel = e.getKey();
            }
        }
        return new PredictionResult(maxScoredLabel, maxScore);
    }

    /**
     * Computes the score of {@code actualLabel} and the best-scoring other label.
     *
     * <p>If {@code actualLabel} has no model yet, its score is 0.
     */
    public Margin getMargin(@Nonnull FeatureValue[] features, Object actualLabel) {
        float correctScore = 0.0f;
        Object maxAnotherLabel = null;
        float maxAnotherScore = 0.0f;
        for (Map.Entry<Object, PredictionModel> e : label2model.entrySet()) {
            Object label = e.getKey();
            float score = calcScore(e.getValue(), features);
            if (label.equals(actualLabel)) {
                correctScore = score;
            } else if (maxAnotherLabel == null || score > maxAnotherScore) {
                maxAnotherLabel = label;
                maxAnotherScore = score;
            }
        }
        return new Margin(correctScore, maxAnotherLabel, maxAnotherScore);
    }

    /** Same as {@code getMarginAndVariance(features, actualLabel, false)}. */
    public Margin getMarginAndVariance(@Nonnull FeatureValue[] features, Object actualLabel) {
        return getMarginAndVariance(features, actualLabel, false);
    }

    /**
     * Like {@link #getMargin(FeatureValue[], Object)}, but also sets the margin's variance to the
     * sum of the variances of the actual label and the best-scoring other label.
     *
     * @param varianceIfEmpty when true and no model exists yet, return a zero margin whose variance
     *        is {@code 2 * squaredNorm(features)}
     */
    public Margin getMarginAndVariance(@Nonnull FeatureValue[] features, Object actualLabel,
            boolean varianceIfEmpty) {
        if (varianceIfEmpty && label2model.isEmpty()) {
            return new Margin(0.0f, null, 0.0f).variance(2.0f * calcVariance(features));
        }
        float correctScore = 0.0f;
        float correctVariance = 0.0f;
        Object maxAnotherLabel = null;
        float maxAnotherScore = 0.0f;
        float maxAnotherVariance = 0.0f;
        for (Map.Entry<Object, PredictionModel> e : label2model.entrySet()) {
            Object label = e.getKey();
            PredictionResult predicted = calcScoreAndVariance(e.getValue(), features);
            float score = predicted.getScore();
            if (label.equals(actualLabel)) {
                correctScore = score;
                correctVariance = predicted.getVariance();
            } else if (maxAnotherLabel == null || score > maxAnotherScore) {
                maxAnotherLabel = label;
                maxAnotherScore = score;
                maxAnotherVariance = predicted.getVariance();
            }
        }
        return new Margin(correctScore, maxAnotherLabel, maxAnotherScore)
                .variance(correctVariance + maxAnotherVariance);
    }

    /** Sum of squared feature values, ignoring NULL entries. */
    public final float squaredNorm(@Nonnull FeatureValue[] features) {
        float sqNorm = 0.0f;
        for (FeatureValue fv : features) {
            if (fv != null) {
                float v = fv.getValueAsFloat();
                sqNorm += v * v;
            }
        }
        return sqNorm;
    }

    /** Dot product of the model weights and the feature values; NULL entries are skipped. */
    protected final float calcScore(@Nonnull PredictionModel model,
            @Nonnull FeatureValue[] features) {
        float score = 0.0f;
        for (FeatureValue fv : features) {
            if (fv == null) {
                continue;
            }
            float weight = model.getWeight(fv.getFeature());
            if (weight != 0.0f) {
                score += weight * fv.getValueAsFloat();
            }
        }
        return score;
    }

    /** Variance when every covariance is 1, i.e. the squared norm of the feature values. */
    protected final float calcVariance(@Nonnull FeatureValue[] features) {
        return squaredNorm(features);
    }

    /**
     * Computes the score and variance of {@code features} under {@code model}.
     *
     * <p>Features with no entry in the model contribute nothing to the score and use a
     * covariance of 1 for the variance.
     */
    protected final PredictionResult calcScoreAndVariance(@Nonnull PredictionModel model,
            @Nonnull FeatureValue[] features) {
        float score = 0.0f;
        float variance = 0.0f;
        for (FeatureValue fv : features) {
            if (fv == null) {
                continue;
            }
            float v = fv.getValueAsFloat();
            IWeightValue w = model.get(fv.getFeature());
            if (w == null) {
                variance += v * v;
            } else {
                score += w.get() * v;
                variance += w.getCovariance() * v * v;
            }
        }
        return new PredictionResult(score).variance(variance);
    }

    /**
     * Adds {@code coeff * x} to the actual label's weights and subtracts it from the missed
     * label's weights. Missing models are created on demand.
     *
     * @param actualLabel the true label; must not be null
     * @param missedLabel the wrongly preferred label, or null to update only the actual label
     * @throws IllegalArgumentException when both labels are equal
     */
    public void update(@Nonnull FeatureValue[] features, float coeff, Object actualLabel,
            Object missedLabel) {
        assert actualLabel != null;
        if (actualLabel.equals(missedLabel)) {
            throw new IllegalArgumentException("Actual label equals to missed label: " + actualLabel);
        }
        PredictionModel actualModel = getOrCreateModel(label2model, actualLabel);
        PredictionModel missedModel = (missedLabel == null) ? null
                : getOrCreateModel(label2model, missedLabel);
        for (FeatureValue fv : features) {
            if (fv == null) {
                continue;
            }
            Object k = fv.getFeature();
            float delta = coeff * fv.getValueAsFloat();
            actualModel.set(k, new WeightValue(actualModel.getWeight(k) + delta));
            if (missedModel != null) {
                missedModel.set(k, new WeightValue(missedModel.getWeight(k) - delta));
            }
        }
    }

    /** Returns the model registered for {@code label}, creating and registering one if absent. */
    @Nonnull
    private PredictionModel getOrCreateModel(@Nonnull Map<Object, PredictionModel> models,
            @Nonnull Object label) {
        PredictionModel model = models.get(label);
        if (model == null) {
            model = createModel();
            models.put(label, model);
        }
        return model;
    }

    /**
     * Forwards every touched weight of every label model, then releases the models.
     */
    @Override
    public final void close() throws HiveException {
        super.close();
        if (label2model == null) {
            return;
        }
        final boolean withCovar = useCovariance();
        final WeightValue probe = withCovar ? new WeightValue.WeightValueWithCovar()
                : new WeightValue();
        final long[] numMixed = new long[1];
        final long numForwarded = forwardModels(probe, withCovar, numMixed);
        this.label2model = null;
        logger.info("Trained a prediction model using " + this.count + " training examples"
                + (numMixed[0] > 0 ? "( numMixed: " + numMixed[0] + " )" : ""));
        logger.info("Forwarded the prediction model of " + numForwarded + " rows");
    }

    /**
     * Forwards {@code (label, feature, weight[, covar])} for every entry whose value reports
     * {@code isTouched()}.
     *
     * @param probe reusable value holder; a {@code WeightValueWithCovar} when {@code withCovar}
     * @param withCovar whether to emit the 4th covariance column
     * @param numMixedOut {@code [0]} receives the sum of {@code getNumMixed()} over all models
     * @return number of forwarded rows
     */
    private long forwardModels(@Nonnull WeightValue probe, boolean withCovar,
            @Nonnull long[] numMixedOut) throws HiveException {
        final Object[] row = new Object[withCovar ? 4 : 3];
        final FloatWritable weightOut = new FloatWritable();
        final FloatWritable covarOut = new FloatWritable();
        long numForwarded = 0L;
        long numMixed = 0L;
        for (Map.Entry<Object, PredictionModel> e : label2model.entrySet()) {
            row[0] = e.getKey();
            PredictionModel model = e.getValue();
            numMixed += model.getNumMixed();
            IMapIterator itor = model.entries();
            while (itor.next() != -1) {
                itor.getValue(probe);
                if (!probe.isTouched()) {
                    continue;
                }
                row[1] = itor.getKey2();
                weightOut.set(probe.get());
                row[2] = weightOut;
                if (withCovar) {
                    covarOut.set(probe.getCovariance());
                    row[3] = covarOut;
                }
                forward(row);
                numForwarded++;
            }
        }
        numMixedOut[0] = numMixed;
        return numForwarded;
    }

    /**
     * Loads per-label models from a file or directory into {@code models}.
     *
     * <p>Lines are read with a line SerDe as {@code (label, feature, weight[, covar])}, depending
     * on {@link #useCovariance()}. Nothing is logged when no model was loaded.
     *
     * @throws RuntimeException wrapping any I/O or SerDe failure
     */
    protected void loadPredictionModel(Map<Object, PredictionModel> models, String filename,
            PrimitiveObjectInspector labelOI, PrimitiveObjectInspector featureOI) {
        final StopWatch elapsed = new StopWatch();
        final long lines;
        try {
            final File file = new File(filename);
            lines = useCovariance()
                    ? loadPredictionModel(models, file, labelOI, featureOI,
                        PrimitiveObjectInspectorFactory.writableFloatObjectInspector,
                        PrimitiveObjectInspectorFactory.writableFloatObjectInspector)
                    : loadPredictionModel(models, file, labelOI, featureOI,
                        PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        } catch (SerDeException | IOException e) {
            throw new RuntimeException("Failed to load a model: " + filename, e);
        }
        if (models.isEmpty()) {
            return;
        }
        long totalFeatures = 0L;
        final StringBuilder perLabel = new StringBuilder(Feature.DEFAULT_NUM_FIELDS);
        for (Map.Entry<Object, PredictionModel> e : models.entrySet()) {
            int size = e.getValue().size();
            perLabel.append('\n').append("Label: ").append(e.getKey())
                    .append(", Number of Features: ").append(size);
            totalFeatures += size;
        }
        logger.info("Loaded total " + totalFeatures + " features from distributed cache '"
                + filename + "' (" + lines + " lines) in " + elapsed + perLabel);
    }

    /**
     * Loads {@code (label, feature, weight)} lines; directories are traversed recursively and
     * {@code .crc} files are skipped. Lines with any NULL column are counted but ignored.
     *
     * @return number of lines read
     */
    private long loadPredictionModel(Map<Object, PredictionModel> models, File file,
            PrimitiveObjectInspector labelOI, PrimitiveObjectInspector featureOI,
            WritableFloatObjectInspector weightOI) throws IOException, SerDeException {
        if (!file.exists() || file.getName().endsWith(".crc")) {
            return 0L;
        }
        if (file.isDirectory()) {
            long lines = 0L;
            File[] children = file.listFiles();
            if (children != null) { // null on I/O error
                for (File child : children) {
                    lines += loadPredictionModel(models, child, labelOI, featureOI, weightOI);
                }
            }
            return lines;
        }

        final LazySimpleSerDe serde = HiveUtils.getLineSerde(labelOI, featureOI, weightOI);
        final StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector();
        final PrimitiveObjectInspector c1OI =
                (PrimitiveObjectInspector) lineOI.getStructFieldRef("c1").getFieldObjectInspector();
        final PrimitiveObjectInspector c2OI =
                (PrimitiveObjectInspector) lineOI.getStructFieldRef("c2").getFieldObjectInspector();
        final FloatObjectInspector c3OI =
                (FloatObjectInspector) lineOI.getStructFieldRef("c3").getFieldObjectInspector();

        long lines = 0L;
        BufferedReader reader = null;
        try {
            reader = HadoopUtils.getBufferedReader(file);
            String line;
            while ((line = reader.readLine()) != null) {
                lines++;
                List<Object> fields =
                        lineOI.getStructFieldsDataAsList(serde.deserialize(new Text(line)));
                Object label = fields.get(0);
                Object feature = fields.get(1);
                Object weight = fields.get(2);
                if (label == null || feature == null || weight == null) {
                    continue;
                }
                Object labelKey = c1OI.getPrimitiveWritableObject(c1OI.copyObject(label));
                PredictionModel model = getOrCreateModel(models, labelKey);
                Object featureKey = c2OI.getPrimitiveWritableObject(c2OI.copyObject(feature));
                model.set(featureKey, new WeightValue(c3OI.get(weight), false));
            }
        } finally {
            IOUtils.closeQuietly(reader);
        }
        return lines;
    }

    /**
     * Loads {@code (label, feature, weight, covar)} lines; directories are traversed recursively
     * and {@code .crc} files are skipped. Lines with a NULL label, feature or weight are counted
     * but ignored; a NULL covariance defaults to 1.
     *
     * @return number of lines read
     */
    private long loadPredictionModel(Map<Object, PredictionModel> models, File file,
            PrimitiveObjectInspector labelOI, PrimitiveObjectInspector featureOI,
            WritableFloatObjectInspector weightOI, WritableFloatObjectInspector covarOI)
            throws IOException, SerDeException {
        if (!file.exists() || file.getName().endsWith(".crc")) {
            return 0L;
        }
        if (file.isDirectory()) {
            long lines = 0L;
            File[] children = file.listFiles();
            if (children != null) { // null on I/O error
                for (File child : children) {
                    lines += loadPredictionModel(models, child, labelOI, featureOI, weightOI,
                        covarOI);
                }
            }
            return lines;
        }

        final LazySimpleSerDe serde =
                HiveUtils.getLineSerde(labelOI, featureOI, weightOI, covarOI);
        final StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector();
        final PrimitiveObjectInspector c1OI =
                (PrimitiveObjectInspector) lineOI.getStructFieldRef("c1").getFieldObjectInspector();
        final PrimitiveObjectInspector c2OI =
                (PrimitiveObjectInspector) lineOI.getStructFieldRef("c2").getFieldObjectInspector();
        final FloatObjectInspector c3OI =
                (FloatObjectInspector) lineOI.getStructFieldRef("c3").getFieldObjectInspector();
        final FloatObjectInspector c4OI =
                (FloatObjectInspector) lineOI.getStructFieldRef("c4").getFieldObjectInspector();

        long lines = 0L;
        BufferedReader reader = null;
        try {
            reader = HadoopUtils.getBufferedReader(file);
            String line;
            while ((line = reader.readLine()) != null) {
                lines++;
                List<Object> fields =
                        lineOI.getStructFieldsDataAsList(serde.deserialize(new Text(line)));
                Object label = fields.get(0);
                Object feature = fields.get(1);
                Object weight = fields.get(2);
                Object covar = fields.get(3);
                if (label == null || feature == null || weight == null) {
                    continue;
                }
                Object labelKey = c1OI.getPrimitiveWritableObject(c1OI.copyObject(label));
                PredictionModel model = getOrCreateModel(models, labelKey);
                Object featureKey = c2OI.getPrimitiveWritableObject(c2OI.copyObject(feature));
                float cov = (covar == null) ? 1.0f : c4OI.get(covar);
                model.set(featureKey,
                    new WeightValue.WeightValueWithCovar(c3OI.get(weight), cov, false));
            }
        } finally {
            IOUtils.closeQuietly(reader);
        }
        return lines;
    }
}
