package hivemall.fm;

import hivemall.optimizer.EtaEstimator;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.SizeOf;
import hivemall.utils.math.MathUtils;
import java.util.Arrays;
import java.util.Random;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.metadata.HiveException;

/* loaded from: input_file:hivemall/fm/FactorizationMachineModel.class */
/**
 * Base class of a factorization machine (FM) model updated by per-example gradient steps.
 *
 * <p>This class holds the hyperparameters copied from {@link FMHyperParameters} and implements:
 * <ul>
 * <li>prediction: {@code w0 + sum_j w_j x_j + 0.5 * sum_f [(sum_j v_jf x_j)^2 - sum_j (v_jf x_j)^2]};</li>
 * <li>the derivative of the loss with respect to the prediction ({@link #dloss(double, double)});</li>
 * <li>gradient updates for {@code w0}, {@code W}, {@code V} and their L2 regularization
 * parameters.</li>
 * </ul>
 *
 * <p>Storage of the bias {@code w0}, the linear weights {@code W} and the factor matrix
 * {@code V} is delegated to subclasses through the abstract getters/setters.
 */
public abstract class FactorizationMachineModel {

    /** {@code true}: logistic-loss classification; {@code false}: squared-loss regression. */
    protected final boolean _classification;
    /** Number of latent factors {@code k}, i.e. the length of each row of {@code V}. */
    protected final int _factor;
    // Copied from the hyperparameters; not read by this class.
    protected final double _sigma;
    // Copied from the hyperparameters; not read by this class (update methods take eta directly).
    protected final EtaEstimator _eta;
    /** Scheme used by {@link #initV()} to initialize a new row of {@code V}. */
    protected final VInitScheme _initScheme;
    // Seeded from the hyperparameter seed; not read by this class.
    protected final Random _rnd;
    /** Lower bound applied to the prediction in the regression loss derivative. */
    protected final double _min_target;
    /** Upper bound applied to the prediction in the regression loss derivative. */
    protected final double _max_target;
    /** L2 regularization for {@code w0}; adapted by {@link #updateLambdaW0(double, float)}. */
    protected float _lambdaW0;
    /** L2 regularization for {@code W}; adapted by {@link #updateLambdaW(Feature[], double, float)}. */
    protected float _lambdaW;
    /** Per-factor L2 regularization for {@code V}; adapted by {@link #updateLambdaV(Feature[], double, float)}. */
    protected final float[] _lambdaV;

    /**
     * Strategy for initializing a row of the factor matrix {@code V}.
     *
     * <p>NOTE(review): {@code maxInitValue}, {@code initStdDev} and {@code rand} are stored on the
     * enum constants. They are therefore shared by every model in the JVM that uses the same
     * scheme, and configuring one model overwrites another's settings. Consider moving this state
     * into the model.
     */
    public enum VInitScheme {
        /** Uniform in {@code [0, maxInitValue / factors)}. */
        adjustedRandom,
        /** Uniform in {@code [0, maxInitValue / sqrt(factors))}. */
        libffmRandom,
        /** Uniform in {@code [0, maxInitValue)}. */
        random,
        /** Drawn by {@code MathUtils.gaussian(0, initStdDev, rand[f])}, one {@link Random} per factor. */
        gaussian;

        @Nonnegative
        float maxInitValue;

        @Nonnegative
        double initStdDev;

        /** One generator for the uniform schemes, {@code factors} generators for {@link #gaussian}. */
        Random[] rand;

        /**
         * Parses a scheme name, falling back to {@link #adjustedRandom}.
         *
         * @param opt scheme name (case-insensitive), or {@code null}
         * @return the matching scheme, or {@link #adjustedRandom} if {@code opt} is null or unknown
         */
        @Nonnull
        public static VInitScheme resolve(@Nullable String opt) {
            return resolve(opt, adjustedRandom);
        }

        /**
         * Parses a scheme name.
         *
         * <p>Accepted names (case-insensitive):
         * <ul>
         * <li>{@code adjusted_random} / {@code adjustedRandom}</li>
         * <li>{@code libffm_random} / {@code libffmRandom} / {@code libffm}</li>
         * <li>{@code random}</li>
         * <li>{@code gaussian}</li>
         * </ul>
         *
         * @param opt scheme name, or {@code null}
         * @param defaultScheme value returned when {@code opt} is null or not recognized
         * @return the matching scheme or {@code defaultScheme}
         */
        @Nonnull
        public static VInitScheme resolve(@Nullable String opt,
                @Nonnull VInitScheme defaultScheme) {
            if (opt == null) {
                return defaultScheme;
            }
            if ("adjusted_random".equalsIgnoreCase(opt) || "adjustedRandom".equalsIgnoreCase(opt)) {
                return adjustedRandom;
            }
            if ("libffm_random".equalsIgnoreCase(opt) || "libffmRandom".equalsIgnoreCase(opt)
                    || "libffm".equalsIgnoreCase(opt)) {
                return libffmRandom;
            }
            if ("random".equalsIgnoreCase(opt)) {
                return random;
            }
            if ("gaussian".equalsIgnoreCase(opt)) {
                return gaussian;
            }
            return defaultScheme;
        }

        /** Sets the upper bound used by the uniform schemes. */
        public void setMaxInitValue(float maxInitValue) {
            this.maxInitValue = maxInitValue;
        }

        /** Sets the standard-deviation argument passed to {@code MathUtils.gaussian}. */
        public void setInitStdDev(double initStdDev) {
            this.initStdDev = initStdDev;
        }

        /**
         * Creates the random generators used by {@link FactorizationMachineModel#initV()}.
         *
         * <p>Generator {@code i} is seeded with {@code seed + i}. {@link #gaussian} gets
         * {@code factor} generators (one per factor); every other scheme gets exactly one.
         *
         * @param factor number of latent factors
         * @param seed base seed
         */
        public void initRandom(int factor, long seed) {
            final int size = (this == gaussian) ? factor : 1;
            final Random[] generators = new Random[size];
            for (int i = 0; i < size; i++) {
                generators[i] = new Random(seed + i);
            }
            this.rand = generators;
        }
    }

    /**
     * Copies the hyperparameters into this model.
     *
     * <p>Every entry of the per-factor {@code V} regularization starts at {@code params.lambdaV}.
     *
     * @param params training hyperparameters
     */
    public FactorizationMachineModel(@Nonnull FMHyperParameters params) {
        this._classification = params.classification;
        this._factor = params.factors;
        this._sigma = params.sigma;
        this._eta = params.eta;
        this._initScheme = params.vInit;
        this._rnd = new Random(params.seed);
        this._min_target = params.minTarget;
        this._max_target = params.maxTarget;
        this._lambdaW0 = params.lambdaW0;
        this._lambdaW = params.lambdaW;
        this._lambdaV = new float[params.factors];
        Arrays.fill(this._lambdaV, params.lambdaV);
    }

    /** @return model size as defined by the concrete storage implementation */
    public abstract int getSize();

    /**
     * @return smallest feature index stored by index-based models
     * @throws UnsupportedOperationException unless overridden by an index-based subclass
     */
    public int getMinIndex() {
        throw new UnsupportedOperationException();
    }

    /**
     * @return largest feature index stored by index-based models
     * @throws UnsupportedOperationException unless overridden by an index-based subclass
     */
    public int getMaxIndex() {
        throw new UnsupportedOperationException();
    }

    /** @return the global bias {@code w0} */
    public abstract float getW0();

    /** Stores a new global bias {@code w0}. */
    protected abstract void setW0(float nextW0);

    /**
     * @param i feature index
     * @return linear weight of feature {@code i}
     * @throws UnsupportedOperationException unless overridden by an index-based subclass
     */
    public float getW(int i) {
        throw new UnsupportedOperationException();
    }

    /** @return linear weight {@code w_j} of the given feature */
    public abstract float getW(@Nonnull Feature x);

    /** Stores a new linear weight for the given feature. */
    protected abstract void setW(@Nonnull Feature x, float nextWi);

    /**
     * @param i feature index
     * @param init meaning is defined by the overriding subclass
     * @return the {@code V} row of feature {@code i}, or {@code null}
     * @throws UnsupportedOperationException unless overridden by an index-based subclass
     */
    @Nullable
    public float[] getV(int i, boolean init) {
        throw new UnsupportedOperationException();
    }

    /** @return factor {@code f} of the {@code V} row of the given feature */
    public abstract float getV(@Nonnull Feature x, int f);

    /** Stores a new value for factor {@code f} of the given feature's {@code V} row. */
    protected abstract void setV(@Nonnull Feature x, int f, float nextVif);

    /** @return current L2 regularization of factor {@code f} */
    public float getLambdaV(int f) {
        return _lambdaV[f];
    }

    /**
     * Predicts {@code x} and returns the loss derivative for target {@code y}.
     *
     * @throws HiveException if the prediction is not finite
     */
    final double dloss(@Nonnull Feature[] x, double y) throws HiveException {
        return dloss(predict(x), y);
    }

    /**
     * Derivative of the loss with respect to the prediction {@code p}.
     *
     * <ul>
     * <li>Classification: {@code (sigmoid(p * y) - 1) * y}, the derivative of
     * {@code log(1 + exp(-y p))}. This presumably expects {@code y} in {-1, +1}; confirm with
     * callers.</li>
     * <li>Regression: {@code clamp(p, min_target, max_target) - y}, the derivative of
     * {@code 0.5 * (p - y)^2} evaluated at the clamped prediction.</li>
     * </ul>
     *
     * @param p raw model prediction
     * @param y target value
     * @return loss derivative
     */
    public final double dloss(double p, double y) {
        if (_classification) {
            return (MathUtils.sigmoid(p * y) - 1.0d) * y;
        }
        final double clamped = Math.max(Math.min(p, _max_target), _min_target);
        return clamped - y;
    }

    /**
     * Computes the FM prediction
     * {@code w0 + sum_j w_j x_j + 0.5 * sum_f [(sum_j v_jf x_j)^2 - sum_j (v_jf x_j)^2]}.
     *
     * @param x features of one example
     * @return the prediction
     * @throws HiveException if the result is not finite; the message includes a dump of the
     *         variables involved
     */
    public double predict(@Nonnull Feature[] x) throws HiveException {
        double ret = getW0();

        // linear term
        for (Feature e : x) {
            ret += getW(e) * e.getValue();
        }

        // pairwise interactions, computed in O(k * n) via the square-of-sum identity
        for (int f = 0; f < _factor; f++) {
            double sumVjfXj = 0.d;
            double sumV2X2 = 0.d;
            for (Feature e : x) {
                double vx = getV(e, f) * e.getValue();
                sumVjfXj += vx;
                sumV2X2 += vx * vx;
            }
            ret += 0.5d * (sumVjfXj * sumVjfXj - sumV2X2);
            assert !Double.isNaN(ret);
        }

        if (!NumberUtils.isFinite(ret)) {
            throw new HiveException("Detected " + ret
                    + " in predict. We recommend to normalize training examples.\nDumping variables ...\n"
                    + varDump(x));
        }
        return ret;
    }

    /**
     * Renders the inputs, {@code w0}, the {@code W} entries and the {@code V} entries touched by
     * {@code x} as a multi-line string, for diagnostics.
     */
    protected String varDump(@Nonnull Feature[] x) {
        final StringBuilder buf = new StringBuilder(1024);

        for (int j = 0; j < x.length; j++) {
            Feature e = x[j];
            if (j != 0) {
                buf.append(", ");
            }
            buf.append("x[").append(e.getFeature()).append("] = ").append(e.getValue());
        }
        buf.append("\n");

        buf.append("W0 = ").append(getW0()).append('\n');

        for (int j = 0; j < x.length; j++) {
            Feature e = x[j];
            if (j != 0) {
                buf.append(", ");
            }
            buf.append("W[").append(e.getFeature()).append("] = ").append(getW(e));
        }
        buf.append("\n");

        for (int f = 0; f < _factor; f++) {
            for (int j = 0; j < x.length; j++) {
                Feature e = x[j];
                if (j != 0) {
                    buf.append(", ");
                }
                buf.append('V').append(f).append('[').append(e.getFeature()).append("] = ")
                   .append(getV(e, f));
            }
            buf.append('\n');
        }
        return buf.toString();
    }

    /**
     * Gradient step on the bias: {@code w0 <- w0 - eta * (dloss + 2 * lambdaW0 * w0)}.
     *
     * @throws IllegalStateException if the updated value is not finite
     */
    public final void updateW0(double dloss, float eta) {
        final float gradW0 = (float) dloss;
        final float prevW0 = getW0();
        final float nextW0 = prevW0 - eta * (gradW0 + 2.f * _lambdaW0 * prevW0);
        if (!NumberUtils.isFinite(nextW0)) {
            throw new IllegalStateException("Got " + nextW0 + " for next W0\ngradW0=" + gradW0
                    + ", prevW0=" + prevW0 + ", dloss=" + dloss + ", eta=" + eta);
        }
        setW0(nextW0);
    }

    /**
     * Gradient step on {@code w_j}: {@code w_j <- w_j - eta * (dloss * x_j + 2 * lambdaW * w_j)}.
     *
     * @throws IllegalStateException if the updated value is not finite
     */
    public void updateWi(double dloss, @Nonnull Feature x, float eta) {
        final double xi = x.getValue();
        final float gradWi = (float) (dloss * xi);
        final float wi = getW(x);
        final float nextWi = wi - eta * (gradWi + 2.f * _lambdaW * wi);
        if (!NumberUtils.isFinite(nextWi)) {
            throw new IllegalStateException("Got " + nextWi + " for next W[" + x.getFeature()
                    + "]\nXi=" + xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + dloss
                    + ", eta=" + eta);
        }
        setW(x, nextWi);
    }

    /**
     * Gradient step on {@code v_jf}:
     * {@code v_jf <- v_jf - eta * (dloss * h + 2 * lambdaV_f * v_jf)}, where
     * {@code h = x_j * (sumViX - v_jf * x_j)} is the derivative of the prediction w.r.t.
     * {@code v_jf}.
     *
     * @param sumViX {@code sum_i v_if x_i} for factor {@code f}; presumably
     *        {@code sumVfX(x)[f]} computed before any V update of this example (confirm at call sites)
     * @throws IllegalStateException if the updated value is not finite
     */
    public final void updateV(double dloss, @Nonnull Feature x, int f, double sumViX, float eta) {
        final double xi = x.getValue();
        final float vif = getV(x, f);
        final double h = gradV(xi, vif, sumViX);
        final float gradVif = (float) (dloss * h);
        final float lambdaVf = getLambdaV(f);
        final float nextVif = vif - eta * (gradVif + 2.f * lambdaVf * vif);
        if (!NumberUtils.isFinite(nextVif)) {
            throw new IllegalStateException("Got " + nextVif + " for next V" + f + '['
                    + x.getFeature() + "]\nXi=" + xi + ", Vif=" + vif + ", h=" + h + ", gradV="
                    + gradVif + ", lambdaVf=" + lambdaVf + ", dloss=" + dloss + ", sumViX="
                    + sumViX + ", eta=" + eta);
        }
        setV(x, f, nextVif);
    }

    /**
     * Gradient step on the bias regularization, clipped at zero:
     * {@code lambdaW0 <- max(0, lambdaW0 - eta * dloss * (-2 * eta * w0))}.
     */
    public final void updateLambdaW0(double dloss, float eta) {
        final float lambdaW0Grad = -2.f * eta * getW0();
        final float lambdaW0 = _lambdaW0 - (float) (eta * dloss * lambdaW0Grad);
        this._lambdaW0 = Math.max(0.f, lambdaW0);
    }

    /**
     * Gradient step on the linear-weight regularization, clipped at zero:
     * {@code lambdaW <- max(0, lambdaW - eta * dloss * (-2 * eta * sum_j w_j x_j))}.
     */
    public final void updateLambdaW(@Nonnull Feature[] x, double dloss, float eta) {
        double sumWX = 0.d;
        for (Feature e : x) {
            assert e != null : Arrays.toString(x);
            sumWX += getW(e) * e.getValue();
        }
        final double lambdaWGrad = -2.f * eta * sumWX;
        final float lambdaW = _lambdaW - (float) (eta * dloss * lambdaWGrad);
        this._lambdaW = Math.max(0.f, lambdaW);
    }

    /**
     * Gradient step on each per-factor regularization {@code lambdaV_f}, clipped at zero.
     *
     * <p>For each factor, {@code v'_jf} is the value {@code v_jf} would take after one step of
     * {@link #updateV}-style descent. The gradient is
     * {@code -2 * eta * (sum_j x_j v'_jf * sum_j x_j v_jf - sum_j x_j v'_jf x_j v_jf)}.
     */
    public final void updateLambdaV(@Nonnull Feature[] x, double dloss, float eta) {
        for (int f = 0; f < _factor; f++) {
            double sumVdashX = 0.d;
            double sumVX = 0.d;
            double sumVdashVX2 = 0.d;
            final float lambdaVf = getLambdaV(f);
            final double sumVfX = sumVfX(x, f);
            for (Feature e : x) {
                assert e != null : Arrays.toString(x);
                double xj = e.getValue();
                float vjf = getV(e, f);
                double vDash = vjf - eta * (gradV(xj, vjf, sumVfX) + 2.d * lambdaVf * vjf);
                sumVdashX += xj * vDash;
                sumVX += xj * vjf;
                sumVdashVX2 += xj * vDash * xj * vjf;
            }
            final double lambdaVGrad = -2.f * eta * (sumVdashX * sumVX - sumVdashVX2);
            _lambdaV[f] = Math.max(0.f, (float) (lambdaVf - eta * dloss * lambdaVGrad));
        }
    }

    /**
     * @param x features of one example
     * @return array of length {@code factors} whose entry {@code f} is {@code sum_j v_jf x_j}
     * @throws IllegalStateException if any entry is not finite
     */
    public double[] sumVfX(@Nonnull Feature[] x) {
        final double[] ret = new double[_factor];
        for (int f = 0; f < _factor; f++) {
            ret[f] = sumVfX(x, f);
        }
        return ret;
    }

    /** Computes {@code sum_j v_jf x_j} for one factor, failing fast on a non-finite result. */
    private double sumVfX(@Nonnull Feature[] x, int f) {
        double ret = 0.d;
        for (Feature e : x) {
            ret += getV(e, f) * e.getValue();
        }
        if (!NumberUtils.isFinite(ret)) {
            throw new IllegalStateException(
                "Got " + ret + " for sumV[ " + f + "]X.\nx = " + Arrays.toString(x));
        }
        return ret;
    }

    /** Derivative of the prediction w.r.t. {@code v_jf}: {@code x_j * (sumVfX - v_jf * x_j)}. */
    private double gradV(double xj, float vjf, double sumVfX) {
        return xj * (sumVfX - vjf * xj);
    }

    /** Validation hook for an input example; the base implementation does nothing. */
    public void check(@Nonnull Feature[] x) throws HiveException {}

    /**
     * Creates a new row of {@code V} according to {@link #_initScheme}.
     *
     * @return a freshly allocated array of length {@code factors}
     * @throws IllegalStateException if {@link VInitScheme#initRandom(int, long)} has not been
     *         called on the scheme
     */
    @Nonnull
    public final float[] initV() {
        final float[] ret = new float[_factor];
        final VInitScheme scheme = _initScheme;
        if (scheme.rand == null) {
            throw new IllegalStateException("VInitScheme." + scheme
                    + " is not initialized; call initRandom() before initV()");
        }
        switch (scheme) {
            case adjustedRandom:
                adjustedRandomFill(ret, scheme.rand[0], scheme.maxInitValue);
                break;
            case libffmRandom:
                libffmRandomFill(ret, scheme.rand[0], scheme.maxInitValue);
                break;
            case random:
                randomFill(ret, scheme.rand[0], scheme.maxInitValue);
                break;
            case gaussian:
                gaussianFill(ret, scheme.rand, scheme.initStdDev);
                break;
            default:
                throw new IllegalStateException("Unsupported V initialization scheme: " + scheme);
        }
        return ret;
    }

    /** Fills {@code a} uniformly in {@code [0, maxInitValue / a.length)}. */
    protected static final void adjustedRandomFill(@Nonnull float[] a, @Nonnull Random rand,
            float maxInitValue) {
        final int len = a.length;
        final float basev = maxInitValue / len;
        for (int i = 0; i < len; i++) {
            a[i] = rand.nextFloat() * basev;
        }
    }

    /** Fills {@code a} uniformly in {@code [0, maxInitValue / sqrt(a.length))}. */
    protected static final void libffmRandomFill(@Nonnull float[] a, @Nonnull Random rand,
            float maxInitValue) {
        final int len = a.length;
        final float basev = maxInitValue / ((float) Math.sqrt(len));
        for (int i = 0; i < len; i++) {
            a[i] = rand.nextFloat() * basev;
        }
    }

    /** Fills {@code a} uniformly in {@code [0, maxInitValue)}. */
    protected static final void randomFill(@Nonnull float[] a, @Nonnull Random rand,
            float maxInitValue) {
        for (int i = 0, len = a.length; i < len; i++) {
            a[i] = rand.nextFloat() * maxInitValue;
        }
    }

    /**
     * Fills {@code a[i]} with {@code MathUtils.gaussian(0, stddev, rand[i])}, presumably a
     * normal draw with mean 0 and the given standard deviation. Requires
     * {@code rand.length >= a.length}.
     */
    protected static final void gaussianFill(@Nonnull float[] a, @Nonnull Random[] rand,
            double stddev) {
        for (int i = 0, len = a.length; i < len; i++) {
            a[i] = (float) MathUtils.gaussian(0.d, stddev, rand[i]);
        }
    }
}
