package hivemall.factorization.fm;

import hivemall.factorization.fm.FMHyperParameters;
import hivemall.utils.collections.arrays.DoubleArray3D;
import hivemall.utils.collections.lists.IntArrayList;
import hivemall.utils.lang.NumberUtils;
import hivemall.utils.math.MathUtils;
import java.util.Arrays;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.hadoop.hive.ql.metadata.HiveException;

/* loaded from: input_file:hivemall/factorization/fm/FieldAwareFactorizationMachineModel.class */
public abstract class FieldAwareFactorizationMachineModel extends FactorizationMachineModel {

    /** Hyperparameter object handed to the constructor; kept for subclasses. */
    @Nonnull
    protected final FMHyperParameters.FFMHyperParameters _params;
    /** Added to the accumulated squared gradients before the square root in the AdaGrad learning rate. */
    protected final float _eps;
    /** When true, {@code eta(...)} scales the base learning rate by the entry's accumulated squared gradients. */
    protected final boolean _useAdaGrad;
    /** When true, {@code updateWi} and {@code updateV} delegate to their FTRL variants. */
    protected final boolean _useFTRL;
    /** FTRL alpha: passed to {@code Entry.updateZ} and used as the divisor of {@code (beta + sqrt(n))}. */
    private final float _alpha;
    /** FTRL beta: added to {@code sqrt(n)} in the denominator of the FTRL weight formula. */
    private final float _beta;
    /** FTRL lambda1: weights whose {@code |z|} does not exceed this value are zeroed or removed. */
    private final float _lambda1;
    /** FTRL lambda2: added to the denominator of the FTRL weight formula. */
    private final float _lambda2;
    /** Decompiled form of the compiler's assertion switch; set in the static initializer and read in {@code predict}. */
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Creates the model and caches the settings this class reads on every update.
     *
     * @param fFMHyperParameters hyperparameters; forwarded to the superclass constructor and
     *        retained in {@code _params}
     */
    public FieldAwareFactorizationMachineModel(@Nonnull FMHyperParameters.FFMHyperParameters fFMHyperParameters) {
        super(fFMHyperParameters);
        _params = fFMHyperParameters;

        // Optimizer switches.
        _useAdaGrad = fFMHyperParameters.useAdaGrad;
        _useFTRL = fFMHyperParameters.useFTRL;

        // AdaGrad setting.
        _eps = fFMHyperParameters.eps;

        // FTRL settings.
        _alpha = fFMHyperParameters.alphaFTRL;
        _beta = fFMHyperParameters.betaFTRL;
        _lambda1 = fFMHyperParameters.lambda1;
        _lambda2 = fFMHyperParameters.lambda2;
    }

    /**
     * Returns one latent-factor value of a feature with respect to a field.
     *
     * <p>Within this class it is called as {@code getV(feature, otherField, f)} where
     * {@code otherField} is the field of the feature being paired with and {@code f} ranges over
     * {@code [0, _factor)}.
     *
     * @param feature the feature whose factor is read
     * @param i the field of the paired feature
     * @param i2 the factor index
     * @return the factor value
     */
    public abstract float getV(@Nonnull Feature feature, @Nonnull int i, int i2);

    /**
     * Field-aware setter with the same (feature, int, int) addressing as
     * {@link #getV(Feature, int, int)} plus the value to store.
     *
     * <p>Marked deprecated and not called anywhere in this class; the update methods here write
     * factors through {@code Entry.setV} instead.
     *
     * @param feature the feature whose factor is written
     * @param i presumably the field, mirroring {@code getV} (not confirmed by any call in this class)
     * @param i2 presumably the factor index, mirroring {@code getV}
     * @param f the value to store
     */
    @Deprecated
    protected abstract void setV(@Nonnull Feature feature, @Nonnull int i, int i2, float f);

    /**
     * Field-less factor lookup inherited from the superclass; not supported by this model.
     * Use {@link #getV(Feature, int, int)} instead.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // hivemall.factorization.fm.FactorizationMachineModel
    public float getV(Feature feature, int i) {
        throw new UnsupportedOperationException();
    }

    /**
     * Field-less factor setter inherited from the superclass; not supported by this model.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // hivemall.factorization.fm.FactorizationMachineModel
    protected void setV(Feature feature, int i, float f) {
        throw new UnsupportedOperationException();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Computes the model output for one example.
     *
     * <p>The result is {@code w0}, plus {@code w[x] * value(x)} for every feature, plus, for every
     * unordered pair of features {@code (a, b)} and every factor {@code f} in {@code [0, _factor)},
     * {@code getV(a, field(b), f) * getV(b, field(a), f) * value(a) * value(b)}.
     *
     * @param featureArr the features of the example
     * @return the prediction
     * @throws HiveException if the accumulated value is not finite; the message contains the
     *         output of {@link #varDump(Feature[])}
     */
    @Override // hivemall.factorization.fm.FactorizationMachineModel
    public final double predict(@Nonnull Feature[] featureArr) throws HiveException {
        double w0 = getW0();

        // Linear terms. The decompiled source passed an undeclared register name to getW here;
        // the loop variable is the intended argument (varDump performs the same lookup).
        for (Feature feature : featureArr) {
            w0 += getW(feature) * feature.getValue();
        }

        // Pairwise interaction terms: each feature is looked up under the *other* feature's field.
        for (int i = 0; i < featureArr.length; i++) {
            Feature feature2 = featureArr[i];
            double value = feature2.getValue();
            short field = feature2.getField();
            for (int i2 = i + 1; i2 < featureArr.length; i2++) {
                Feature feature3 = featureArr[i2];
                double value2 = feature3.getValue();
                short field2 = feature3.getField();
                int i3 = this._factor;
                for (int i4 = 0; i4 < i3; i4++) {
                    w0 += getV(feature2, field2, i4) * getV(feature3, field, i4) * value * value2;
                    // Explicit form of `assert !Double.isNaN(w0)`; kept explicit because this
                    // class declares $assertionsDisabled itself.
                    if (!$assertionsDisabled && Double.isNaN(w0)) {
                        throw new AssertionError();
                    }
                }
            }
        }

        if (NumberUtils.isFinite(w0)) {
            return w0;
        }
        throw new HiveException("Detected " + w0 + " in predict. We recommend to normalize training examples.\nDumping variables ...\n" + varDump(featureArr));
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Applies one gradient step to the linear weight of a feature.
     *
     * <p>With FTRL enabled the call is forwarded to {@code updateWi_FTRL}. Otherwise the new
     * weight is {@code w - eta * (grad + 2 * lambdaW * w)} with {@code grad = d * value(feature)};
     * a result close to zero removes the entry instead of storing it.
     *
     * @param d the loss derivative
     * @param feature the feature whose weight is updated
     * @param j the step counter passed to the learning-rate estimator
     * @throws IllegalStateException if the new weight is not finite
     */
    public void updateWi(double d, @Nonnull Feature feature, long j) {
        if (!this._useFTRL) {
            final double xi = feature.getValue();
            final float gradWi = (float) (d * xi);

            final Entry theta = getEntryW(feature);
            final float wi = theta.getW();
            // eta(...) may record the gradient on the entry, so it is evaluated exactly once.
            final float learningRate = eta(theta, j, gradWi);
            final float nextWi = wi - (learningRate * (gradWi + ((2.0f * this._lambdaW) * wi)));

            if (!NumberUtils.isFinite(nextWi)) {
                throw new IllegalStateException("Got " + nextWi + " for next W[" + feature.getFeature() + "]\nXi=" + xi + ", gradWi=" + gradWi + ", wi=" + wi + ", dloss=" + d + ", eta=" + learningRate + ", t=" + j);
            }

            final boolean vanished = MathUtils.closeToZero(nextWi, 1.0E-9f);
            if (vanished) {
                removeEntry(theta);
            } else {
                theta.setW(nextWi);
            }
        } else {
            updateWi_FTRL(d, feature);
        }
    }

    /**
     * FTRL update of the linear weight of a feature.
     *
     * <p>Updates the entry's {@code z} and {@code n} accumulators with the gradient
     * {@code d * value(feature)}. If {@code |z| <= lambda1} the entry is removed; otherwise the
     * weight becomes {@code (sign(z) * lambda1 - z) / ((beta + sqrt(n)) / alpha + lambda2)}, and is
     * removed instead of stored when that value is close to zero.
     *
     * @param d the loss derivative
     * @param feature the feature whose weight is updated
     * @throws IllegalStateException if the new weight is not finite
     */
    private void updateWi_FTRL(double d, @Nonnull Feature feature) {
        final double xi = feature.getValue();
        final float gradWi = (float) (d * xi);

        final Entry theta = getEntryW(feature);
        final float z = theta.updateZ(gradWi, this._alpha);
        final double n = theta.updateN(gradWi);

        // Written as a negated "<=" so that a NaN z still reaches the finite check below.
        if (!(Math.abs(z) <= this._lambda1)) {
            final double denominator = ((this._beta + Math.sqrt(n)) / this._alpha) + this._lambda2;
            final float nextWi = (float) (((MathUtils.sign(z) * this._lambda1) - z) / denominator);
            if (!NumberUtils.isFinite(nextWi)) {
                throw new IllegalStateException("Got " + nextWi + " for next W[" + feature.getFeature() + "]\nXi=" + xi + ", gradWi=" + gradWi + ", wi=" + theta.getW() + ", dloss=" + d + ", n=" + n + ", z=" + z);
            }
            if (!MathUtils.closeToZero(nextWi, 1.0E-9f)) {
                theta.setW(nextWi);
                return;
            }
        }
        // Either z is within the lambda1 threshold or the new weight vanished.
        removeEntry(theta);
    }

    /**
     * Discards an entry from the model. Called by the update methods of this class when a linear
     * weight becomes (close to) zero, or when a factor was zeroed and the entry reports itself as
     * {@code removable()}.
     *
     * @param entry the entry to discard
     */
    protected abstract void removeEntry(@Nonnull Entry entry);

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Applies one gradient step to a single latent factor of a feature.
     *
     * <p>With FTRL enabled the call is forwarded to {@code updateV_FTRL}. Otherwise, if an entry
     * exists for {@code (feature, i)}, the new factor is
     * {@code v - eta * (grad + 2 * lambdaV * v)} with {@code grad = d * value(feature) * d2}. A
     * result close to zero is stored as {@code 0} and the entry is removed once it reports itself
     * removable.
     *
     * @param d the loss derivative
     * @param feature the feature whose factor is updated
     * @param i the field used to look up the entry
     * @param i2 the factor index
     * @param d2 the precomputed sum that multiplies the feature value in the gradient
     * @param j the step counter passed to the learning-rate estimator
     * @throws IllegalStateException if the new factor is not finite
     */
    public void updateV(double d, @Nonnull Feature feature, @Nonnull int i, int i2, double d2, long j) {
        if (this._useFTRL) {
            updateV_FTRL(d, feature, i, i2, d2);
            return;
        }

        final Entry theta = getEntryV(feature, i);
        if (theta == null) {
            return; // nothing stored for this (feature, field) pair
        }

        final double xi = feature.getValue();
        final double h = xi * d2;
        final float gradV = (float) (d * h);
        final float lambdaVf = getLambdaV(i2);
        final float currentV = theta.getV(i2);
        // eta(...) may record the gradient on the entry, so it is evaluated exactly once.
        final float learningRate = eta(theta, i2, j, gradV);
        final float nextV = currentV - (learningRate * (gradV + ((2.0f * lambdaVf) * currentV)));

        if (!NumberUtils.isFinite(nextV)) {
            throw new IllegalStateException("Got " + nextV + " for next V" + i2 + '[' + feature.getFeatureIndex() + "]\nXi=" + xi + ", Vif=" + currentV + ", h=" + h + ", gradV=" + gradV + ", lambdaVf=" + lambdaVf + ", dloss=" + d + ", sumViX=" + d2 + ", t=" + j);
        }

        if (MathUtils.closeToZero(nextV, 1.0E-9f)) {
            theta.setV(i2, 0.0f);
            if (theta.removable()) {
                removeEntry(theta);
            }
        } else {
            theta.setV(i2, nextV);
        }
    }

    /**
     * FTRL update of a single latent factor of a feature.
     *
     * <p>Does nothing when no entry exists for {@code (feature, i)}. Otherwise updates the entry's
     * {@code z} and {@code n} accumulators for factor {@code i2}. If {@code |z| <= lambda1}, or the
     * newly computed factor is close to zero, the factor is stored as {@code 0} and the entry is
     * removed once it reports itself removable; otherwise the factor becomes
     * {@code (sign(z) * lambda1 - z) / ((beta + sqrt(n)) / alpha + lambda2)}.
     *
     * @param d the loss derivative
     * @param feature the feature whose factor is updated
     * @param i the field used to look up the entry
     * @param i2 the factor index
     * @param d2 the precomputed sum that multiplies the feature value in the gradient
     * @throws IllegalStateException if the new factor is not finite
     */
    private void updateV_FTRL(double d, @Nonnull Feature feature, @Nonnull int i, int i2, double d2) {
        final Entry theta = getEntryV(feature, i);
        if (theta == null) {
            return;
        }

        final double xi = feature.getValue();
        final double h = xi * d2;
        final float gradV = (float) (d * h);

        final float z = theta.updateZ(i2, theta.getV(i2), gradV, this._alpha);
        final double n = theta.updateN(i2, gradV);

        // Written as a negated "<=" so that a NaN z still reaches the finite check below.
        if (!(Math.abs(z) <= this._lambda1)) {
            final double denominator = ((this._beta + Math.sqrt(n)) / this._alpha) + this._lambda2;
            final float nextV = (float) (((MathUtils.sign(z) * this._lambda1) - z) / denominator);
            if (!NumberUtils.isFinite(nextV)) {
                throw new IllegalStateException("Got " + nextV + " for next V" + i2 + '[' + feature.getFeatureIndex() + "]\nXi=" + xi + ", Vif=" + theta.getV(i2) + ", h=" + h + ", gradV=" + gradV + ", dloss=" + d + ", sumViX=" + d2 + ", n=" + n + ", z=" + z);
            }
            if (!MathUtils.closeToZero(nextV, 1.0E-9f)) {
                theta.setV(i2, nextV);
                return;
            }
        }

        // Either z is within the lambda1 threshold or the new factor vanished.
        theta.setV(i2, 0.0f);
        if (theta.removable()) {
            removeEntry(theta);
        }
    }

    /**
     * Learning rate for an update that has no factor index; delegates to
     * {@link #eta(Entry, int, long, float)} with index {@code 0}.
     *
     * @param entry the entry being updated
     * @param j the step counter
     * @param f the gradient of this step
     * @return the learning rate to apply
     */
    protected final float eta(@Nonnull Entry entry, long j, float f) {
        return eta(entry, 0, j, f);
    }

    /**
     * Learning rate for one update step.
     *
     * <p>Without AdaGrad this is the base rate {@code _eta.eta(j)}. With AdaGrad the base rate is
     * divided by {@code sqrt(_eps + G)}, where {@code G} is the entry's sum of squared gradients
     * read <em>before</em> the current gradient is added to it.
     *
     * @param entry the entry being updated; receives the gradient when AdaGrad is enabled
     * @param i the slot index within the entry
     * @param j the step counter
     * @param f the gradient of this step
     * @return the learning rate to apply
     */
    protected final float eta(@Nonnull Entry entry, @Nonnegative int i, long j, float f) {
        if (this._useAdaGrad) {
            // Read the accumulator first, then record this step's gradient.
            final double accumulated = entry.getSumOfSquaredGradients(i);
            entry.addGradient(i, f);
            final double scale = Math.sqrt(this._eps + accumulated);
            return (float) (this._eta.eta(j) / scale);
        }
        return this._eta.eta(j);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Fills a 3-D array with the per-(feature, field, factor) sums computed by the private
     * {@code sumVfX} overload.
     *
     * <p>Cell {@code (a, b, f)} holds the sum for feature position {@code a}, the field stored at
     * position {@code b} of {@code intArrayList}, and factor {@code f}.
     *
     * @param featureArr the features of the example
     * @param intArrayList the fields to evaluate
     * @param doubleArray3D an array to reuse, or {@code null} to allocate a new one with its
     *        sanity check switched off
     * @return the (re)configured and filled array
     */
    @Nonnull
    public final DoubleArray3D sumVfX(@Nonnull Feature[] featureArr, @Nonnull IntArrayList intArrayList, @Nullable DoubleArray3D doubleArray3D) {
        final int numFeatures = featureArr.length;
        final int numFields = intArrayList.size();
        final int numFactors = this._factor;

        final DoubleArray3D cube;
        if (doubleArray3D != null) {
            cube = doubleArray3D;
        } else {
            cube = new DoubleArray3D();
            cube.setSanityCheck(false);
        }
        cube.configure(numFeatures, numFields, numFactors);

        for (int a = 0; a < numFeatures; a++) {
            for (int b = 0; b < numFields; b++) {
                final int fieldId = intArrayList.get(b);
                for (int f = 0; f < numFactors; f++) {
                    final double cell = sumVfX(featureArr, a, fieldId, f);
                    cube.set(a, b, f, cell);
                }
            }
        }
        return cube;
    }

    /**
     * Sums, over every feature that belongs to field {@code i2} and has a feature index different
     * from {@code featureArr[i]}, the factor {@code getV(other, field(featureArr[i]), i3)}
     * multiplied by the value of {@code featureArr[i]}.
     *
     * @param featureArr the features of the example
     * @param i position of the reference feature in {@code featureArr}
     * @param i2 the field the other features must belong to
     * @param i3 the factor index
     * @return the sum
     * @throws IllegalStateException if the sum is not finite
     */
    private double sumVfX(@Nonnull Feature[] featureArr, int i, @Nonnull int i2, int i3) {
        final Feature target = featureArr[i];
        final int targetIndex = target.getFeatureIndex();
        final double targetValue = target.getValue();
        final short targetField = target.getField();

        double total = 0.0d;
        for (Feature other : featureArr) {
            if (other.getFeatureIndex() == targetIndex) {
                continue; // same feature as the reference one
            }
            if (other.getField() != i2) {
                continue; // belongs to a different field
            }
            total += getV(other, targetField, i3) * targetValue;
        }

        if (!NumberUtils.isFinite(total)) {
            throw new IllegalStateException("Got " + total + " for sumV[ " + i + "][ " + i3 + "]X.\nx = " + Arrays.toString(featureArr));
        }
        return total;
    }

    /**
     * Returns the entry that holds the linear weight of a feature. The update methods of this
     * class read and write the weight through it and pass it to {@code removeEntry} when the
     * weight vanishes.
     *
     * @param feature the feature to look up
     * @return the entry; declared non-null
     */
    @Nonnull
    protected abstract Entry getEntryW(@Nonnull Feature feature);

    /**
     * Returns the entry that holds the latent factors of a feature for a given field.
     *
     * @param feature the feature to look up
     * @param i the field
     * @return the entry, or {@code null}; the update methods of this class skip the update when
     *         {@code null} is returned
     */
    @Nullable
    protected abstract Entry getEntryV(@Nonnull Feature feature, @Nonnull int i);

    /**
     * Builds a human-readable trace of how the prediction for an example is computed.
     *
     * <p>The output has three parts: a line listing every {@code x[name] = value}; the prediction
     * written symbolically; and, on the following line, the same expression with the numbers
     * filled in. The running total is recomputed while the terms are emitted; as soon as it stops
     * being finite the remaining terms are replaced by {@code " + ... = "} and the dump ends.
     *
     * @param featureArr the features of the example
     * @return the dump text
     */
    @Override // hivemall.factorization.fm.FactorizationMachineModel
    protected final String varDump(@Nonnull Feature[] featureArr) {
        final StringBuilder symbolic = new StringBuilder(1024);
        final StringBuilder numeric = new StringBuilder(1024);
        final int numFeatures = featureArr.length;

        // Line 1: the input values.
        String separator = "";
        for (Feature x : featureArr) {
            final String name = x.getFeature();
            final double xValue = x.getValue();
            symbolic.append(separator);
            symbolic.append("x[").append(name).append("] = ").append(xValue);
            separator = ", ";
        }
        symbolic.append("\n");

        // Bias term.
        double total = getW0();
        symbolic.append("predict(x) = w0");
        numeric.append("predict(x) = ").append(total);

        // Linear terms.
        for (Feature x : featureArr) {
            final String name = x.getFeature();
            final double xValue = x.getValue();
            final float weight = getW(x);
            symbolic.append(" + (w[").append(name).append("] * x[").append(name).append("])");
            numeric.append(" + (").append(weight).append(" * ").append(xValue).append(')');
            total += weight * xValue;
            if (!NumberUtils.isFinite(total)) {
                symbolic.append(" + ... = ").append(total).append('\n');
                symbolic.append((CharSequence) numeric).append(" + ... = ").append(total);
                return symbolic.toString();
            }
        }

        // Pairwise interaction terms.
        for (int a = 0; a < numFeatures; a++) {
            final Feature left = featureArr[a];
            final String leftName = left.getFeature();
            final double leftValue = left.getValue();
            final short leftField = left.getField();
            for (int b = a + 1; b < numFeatures; b++) {
                final Feature right = featureArr[b];
                final String rightName = right.getFeature();
                final double rightValue = right.getValue();
                final short rightField = right.getField();
                final int numFactors = this._factor;
                for (int f = 0; f < numFactors; f++) {
                    final float vLeft = getV(left, rightField, f);
                    final float vRight = getV(right, leftField, f);

                    symbolic.append(" + (v[i").append(leftName);
                    symbolic.append("-j").append((int) rightField);
                    symbolic.append("-f").append(f);
                    symbolic.append("] * v[j").append(rightName);
                    symbolic.append("-i").append((int) leftField);
                    symbolic.append("-f").append(f);
                    symbolic.append("] * x[").append(leftName);
                    symbolic.append("] * x[").append(rightName).append("])");

                    numeric.append(" + (").append(vLeft);
                    numeric.append(" * ").append(vRight);
                    numeric.append(" * ").append(leftValue);
                    numeric.append(" * ").append(rightValue).append(')');

                    total += vLeft * vRight * leftValue * rightValue;
                    if (!NumberUtils.isFinite(total)) {
                        symbolic.append(" + ... = ").append(total).append('\n');
                        symbolic.append((CharSequence) numeric).append(" + ... = ").append(total);
                        return symbolic.toString();
                    }
                }
            }
        }

        symbolic.append(" = ").append(total).append('\n');
        symbolic.append((CharSequence) numeric).append(" = ").append(total);
        return symbolic.toString();
    }

    // Decompiled form of the compiler-generated assertion switch: the flag is true when
    // assertions are NOT enabled for this class, and is consulted by the NaN check in predict.
    static {
        $assertionsDisabled = !FieldAwareFactorizationMachineModel.class.desiredAssertionStatus();
    }
}
