package hivemall.classifier;

import hivemall.annotations.Experimental;
import hivemall.annotations.VisibleForTesting;
import hivemall.model.FeatureValue;
import hivemall.model.PredictionModel;
import hivemall.model.PredictionResult;
import hivemall.optimizer.LossFunctions;
import hivemall.utils.collections.Fastutil;
import hivemall.utils.hashing.HashFunction;
import hivemall.utils.lang.Preconditions;
import it.unimi.dsi.fastutil.ints.Int2FloatMap;
import it.unimi.dsi.fastutil.ints.Int2FloatOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.IntWritable;

@Experimental
@Description(name = "train_kpa", value = "_FUNC_(array<string|int|bigint> features, int label [, const string options]) - returns a relation <h int, w0 float, w1 float, w2 float, hk int, w3 float>")
/**
 * Passive-Aggressive binary classifier over an explicitly expanded degree-2 polynomial kernel
 * {@code K(x, z) = (dot(x, z) + c)^2}, where {@code c} is the {@code -pkc} option.
 *
 * <p>Instead of storing support vectors, each update with step {@code alpha} is folded into four
 * weight groups:
 * <ul>
 * <li>{@code w0 += alpha * c^2} (bias)</li>
 * <li>{@code w1[h] += 2 * c * alpha * z[h]} (linear term)</li>
 * <li>{@code w2[h] += alpha * z[h]^2} (squared term)</li>
 * <li>{@code w3[hash(h, k)] += 2 * alpha * z[h] * z[k]} for every pair {@code h} before {@code k}
 * (interaction term)</li>
 * </ul>
 *
 * <p>On {@link #close()} the model is emitted as rows of {@code <h, w0, w1, w2, hk, w3>}:
 * <ul>
 * <li>one row with {@code h = 0} and {@code w0};</li>
 * <li>one row per feature {@code h} with {@code w1} and {@code w2};</li>
 * <li>one row per hashed feature pair {@code hk} with {@code w3}.</li>
 * </ul>
 */
public final class KernelExpansionPassiveAggressiveUDTF extends BinaryOnlineClassifierUDTF {

    /** Initial capacity of each of the w1/w2/w3 hash maps. */
    private static final int INITIAL_CAPACITY = 16384;

    // ------------------------------------
    // Hyper parameters

    /** Constant {@code c} inside the polynomial kernel {@code (dot(xi,xj) + c)^2}. */
    private float _pkc;
    /** Step-size rule (PA, PA-1 or PA-2) selected by {@code -algo}. */
    private Algorithm _algo;

    // ------------------------------------
    // Model parameters

    private float _w0;
    private Int2FloatMap _w1;
    private Int2FloatMap _w2;
    private Int2FloatMap _w3;

    // ------------------------------------

    /** Hinge loss of the most recent training example. */
    private float _loss;

    /** Computes the PA step size (eta) from the loss and the prediction's squared norm. */
    public interface Algorithm {
        float eta(float loss, @Nonnull PredictionResult margin);
    }

    /** Plain PA: {@code eta = loss / ||x||^2}. */
    static class PA implements Algorithm {
        PA() {}

        @Override
        public float eta(final float loss, final PredictionResult margin) {
            return loss / margin.getSquaredNorm();
        }
    }

    /** PA-1: {@code eta = min(C, loss / ||x||^2)}. */
    static class PA1 implements Algorithm {
        private final float c;

        PA1(final float c) {
            this.c = c;
        }

        @Override
        public float eta(final float loss, final PredictionResult margin) {
            return Math.min(c, loss / margin.getSquaredNorm());
        }
    }

    /** PA-2: {@code eta = loss / (||x||^2 + 1 / (2C))}. */
    static class PA2 implements Algorithm {
        private final float c;

        PA2(final float c) {
            this.c = c;
        }

        @Override
        public float eta(final float loss, final PredictionResult margin) {
            return loss / (margin.getSquaredNorm() + (0.5f / c));
        }
    }

    /** Returns the hinge loss of the last example passed to {@link #train}. */
    @VisibleForTesting
    float getLoss() {
        return _loss;
    }

    @Override
    public Options getOptions() {
        Options options = new Options();
        options.addOption("pkc", true, "Constant c inside polynomial kernel K = (dot(xi,xj) + c)^2 [default 1.0]");
        options.addOption("algo", "algorithm", true, "Algorithm for calculating loss [pa, pa1 (default), pa2]");
        options.addOption("c", "aggressiveness", true, "Aggressiveness parameter C for PA-1 and PA-2 [default 1.0]");
        return options;
    }

    /**
     * Parses {@code -pkc}, {@code -c} and {@code -algo}.
     *
     * @throws UDFArgumentException if a numeric option is malformed, if {@code pkc} is not finite,
     *         if {@code C} is not strictly positive (NaN included), or if {@code algo} is unknown
     */
    @Override
    public CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
        float pkc = 1.0f;
        float c = 1.0f;
        String algo = "pa1";

        final CommandLine cl = super.processOptions(argOIs);
        if (cl != null) {
            final String pkcStr = cl.getOptionValue("pkc");
            if (pkcStr != null) {
                pkc = parseFloatOption("pkc", pkcStr);
                if (Float.isNaN(pkc) || Float.isInfinite(pkc)) {
                    throw new UDFArgumentException("Polynomial kernel constant pkc must be finite: " + pkc);
                }
            }
            final String cStr = cl.getOptionValue("c");
            if (cStr != null) {
                c = parseFloatOption("c", cStr);
                // written as !(c > 0) so that NaN is rejected too
                if (!(c > 0.0f)) {
                    throw new UDFArgumentException("Aggressiveness parameter C must be C > 0: " + c);
                }
            }
            algo = cl.getOptionValue("algo", algo);
        }

        if ("pa1".equalsIgnoreCase(algo)) {
            this._algo = new PA1(c);
        } else if ("pa2".equalsIgnoreCase(algo)) {
            this._algo = new PA2(c);
        } else if ("pa".equalsIgnoreCase(algo)) {
            this._algo = new PA();
        } else {
            throw new UDFArgumentException("Unsupported algorithm: " + algo);
        }
        this._pkc = pkc;
        return cl;
    }

    /** Parses a float option value, reporting malformed input as a UDF argument error. */
    private static float parseFloatOption(@Nonnull final String name, @Nonnull final String value)
            throws UDFArgumentException {
        try {
            return Float.parseFloat(value);
        } catch (NumberFormatException e) {
            UDFArgumentException ex =
                    new UDFArgumentException("Invalid value for option -" + name + ": " + value);
            ex.initCause(e);
            throw ex;
        }
    }

    /**
     * Initializes the expanded-kernel weights. The model lives in the four fields rather than a
     * {@link PredictionModel}, so this returns {@code null}.
     */
    @Override
    public PredictionModel createModel() {
        this._w0 = 0.0f;
        this._w1 = new Int2FloatOpenHashMap(INITIAL_CAPACITY);
        _w1.defaultReturnValue(0.0f);
        this._w2 = new Int2FloatOpenHashMap(INITIAL_CAPACITY);
        _w2.defaultReturnValue(0.0f);
        this._w3 = new Int2FloatOpenHashMap(INITIAL_CAPACITY);
        _w3.defaultReturnValue(0.0f);
        return null;
    }

    /** Output schema, in column order: {@code h int, w0 float, w1 float, w2 float, hk int, w3 float}. */
    @Override
    protected StructObjectInspector getReturnOI(ObjectInspector featureRawOI) {
        final ArrayList<String> fieldNames = new ArrayList<String>();
        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("h");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("w0");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        fieldNames.add("w1");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        fieldNames.add("w2");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        fieldNames.add("hk");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("w3");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Parses raw features with feature hashing into integer indices.
     *
     * @return {@code null} for an empty list; otherwise an array the same size as {@code features}
     *         in which null input elements stay {@code null}
     */
    @Override
    @Nullable
    FeatureValue[] parseFeatures(@Nonnull final List<?> features) {
        final int size = features.size();
        if (size == 0) {
            return null;
        }
        final FeatureValue[] featureVector = new FeatureValue[size];
        for (int i = 0; i < size; i++) {
            final Object f = features.get(i);
            if (f != null) {
                featureVector[i] = FeatureValue.parse(f, true);
            }
        }
        return featureVector;
    }

    /** Performs one PA update if the example incurs positive hinge loss. */
    @Override
    protected void train(@Nonnull final FeatureValue[] features, final int label) {
        final float y = label > 0 ? 1.0f : -1.0f;
        final PredictionResult margin = calcScoreWithKernelAndNorm(features);
        // hinge loss of score p against label y (presumably max(0, 1 - y*p) -- see LossFunctions)
        final float loss = LossFunctions.hingeLoss(margin.getScore(), y);
        this._loss = loss;
        if (loss > 0.0f) {
            updateKernel(y, loss, margin, features);
        }
    }

    /**
     * Returns the decision value for {@code features}, including the {@code w0} bias, so that it
     * agrees with the score used during training.
     */
    @Override
    public float predict(@Nonnull final FeatureValue[] features) {
        return calcScoreWithKernelAndNorm(features).getScore();
    }

    /**
     * Computes the score {@code w0 + sum w1[h] x_h + sum w2[h] x_h^2 + sum_{i<j} w3[hash(h_i,h_j)] x_i x_j}
     * together with the squared L2 norm {@code sum x_h^2}. Null entries are skipped.
     */
    @Nonnull
    final PredictionResult calcScoreWithKernelAndNorm(@Nonnull final FeatureValue[] features) {
        float score = _w0;
        float norm = 0.0f;
        for (int i = 0; i < features.length; i++) {
            final FeatureValue fi = features[i];
            if (fi == null) {
                continue;
            }
            final int h = fi.getFeatureAsInt();
            final double xi = fi.getValue();
            final double xx = xi * xi;
            score += _w1.get(h) * xi;
            score += _w2.get(h) * xx;
            norm += xx;
            for (int j = i + 1; j < features.length; j++) {
                final FeatureValue fj = features[j];
                if (fj == null) {
                    continue; // parseFeatures leaves null holes; skip them here as in the outer loop
                }
                final int hk = HashFunction.hash(h, fj.getFeatureAsInt(), true);
                score += xi * fj.getValue() * _w3.get(hk);
            }
        }
        return new PredictionResult(score).squaredNorm(norm);
    }

    /**
     * Applies one update with step {@code eta * label}, where eta comes from the configured
     * algorithm. A non-finite step is skipped: plain PA divides by the squared norm, so a
     * zero-norm input would otherwise write Inf/NaN into the weights.
     */
    protected void updateKernel(final float label, final float loss,
            @Nonnull final PredictionResult margin, @Nonnull final FeatureValue[] features) {
        final float eta = _algo.eta(loss, margin);
        if (Float.isNaN(eta) || Float.isInfinite(eta)) {
            return;
        }
        expandKernel(features, eta * label);
    }

    /** Folds support vector {@code supportVector} with coefficient {@code alpha} into w0..w3. */
    private void expandKernel(@Nonnull final FeatureValue[] supportVector, final float alpha) {
        final float pkc = _pkc;
        // W0 += alpha * c^2
        this._w0 += alpha * pkc * pkc;

        for (int i = 0; i < supportVector.length; i++) {
            final FeatureValue si = supportVector[i];
            if (si == null) {
                continue;
            }
            final int h = si.getFeatureAsInt();
            final float zih = si.getValueAsFloat();
            final float alphaZih = alpha * zih;
            final float alphaZih2 = alphaZih * 2.0f;

            // W1[h] += 2 * c * alpha * Z[h]
            _w1.put(h, _w1.get(h) + pkc * alphaZih2);
            // W2[h] += alpha * Z[h]^2
            _w2.put(h, _w2.get(h) + alphaZih * zih);

            for (int j = i + 1; j < supportVector.length; j++) {
                final FeatureValue sj = supportVector[j];
                if (sj == null) {
                    continue;
                }
                final int hk = HashFunction.hash(h, sj.getFeatureAsInt(), true);
                // W3[hk] += 2 * alpha * Z[h] * Z[k]
                _w3.put(hk, _w3.get(hk) + alphaZih2 * sj.getValueAsFloat());
            }
        }
    }

    /**
     * Emits the learned model as rows, then releases the weight maps.
     *
     * @throws HiveException if a stored key is not positive, or if forwarding fails
     */
    @Override
    public void close() throws HiveException {
        final IntWritable h = new IntWritable(0);        // row[0]
        final FloatWritable w0 = new FloatWritable(_w0); // row[1]
        final FloatWritable w1 = new FloatWritable();    // row[2]
        final FloatWritable w2 = new FloatWritable();    // row[3]
        final IntWritable hk = new IntWritable(0);       // row[4]
        final FloatWritable w3 = new FloatWritable();    // row[5]

        final Object[] row = new Object[] {h, w0, null, null, null, null};
        forward(row); // bias row: h = 0, w0
        row[1] = null;

        // per-feature rows: h, w1, w2 (w1 and w2 are always updated with the same keys)
        row[2] = w1;
        row[3] = w2;
        final Int2FloatMap w2map = _w2;
        for (Int2FloatMap.Entry e : Fastutil.fastIterable(_w1)) {
            final int k = e.getIntKey();
            Preconditions.checkArgument(k > 0, HiveException.class);
            h.set(k);
            w1.set(e.getFloatValue());
            w2.set(w2map.get(k));
            forward(row);
        }
        this._w1 = null;
        this._w2 = null;

        // per-pair rows: hk, w3
        row[0] = null;
        row[2] = null;
        row[3] = null;
        row[4] = hk;
        row[5] = w3;
        for (Int2FloatMap.Entry e : Fastutil.fastIterable(_w3)) {
            final int k = e.getIntKey();
            Preconditions.checkArgument(k > 0, HiveException.class);
            hk.set(k);
            w3.set(e.getFloatValue());
            forward(row);
        }
        this._w3 = null;
    }
}
