package hivemall.optimizer;

import hivemall.utils.lang.ExceptionUtils;
import hivemall.utils.lang.SizeOf;
import hivemall.utils.math.MathUtils;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/* loaded from: input_file:hivemall/optimizer/LossFunctions.class */
public final class LossFunctions {

    /**
     * Switch map reconstructed by the decompiler (originally the synthetic class
     * {@code LossFunctions$1}). It translates a {@link LossType} ordinal into a 1-based case
     * label; entries left at 0 are unmapped.
     */
    static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$hivemall$optimizer$LossFunctions$LossType;

        static {
            // Labels are handed out as 1, 2, 3, ... following this listing order.
            final LossType[] labelOrder = {
                    LossType.SquaredLoss,
                    LossType.QuantileLoss,
                    LossType.EpsilonInsensitiveLoss,
                    LossType.SquaredEpsilonInsensitiveLoss,
                    LossType.HuberLoss,
                    LossType.HingeLoss,
                    LossType.LogLoss,
                    LossType.SquaredHingeLoss,
                    LossType.ModifiedHuberLoss};
            final int[] table = new int[LossType.values().length];
            for (int i = 0; i < labelOrder.length; i++) {
                table[labelOrder[i].ordinal()] = i + 1;
            }
            $SwitchMap$hivemall$optimizer$LossFunctions$LossType = table;
        }
    }

    /**
     * Common base for binary-classification losses whose target labels are expected to be
     * {@code +1} or {@code -1}.
     */
    public static abstract class BinaryLoss implements LossFunction {

        /**
         * Rejects any single-precision target other than {@code +1} or {@code -1}.
         *
         * @param target label to validate
         * @throws IllegalArgumentException if {@code target} is not exactly {@code 1.0f} or
         *         {@code -1.0f} (NaN is rejected as well)
         */
        protected static void checkTarget(float target) {
            final boolean isBinaryLabel = target == 1.0f || target == -1.0f;
            if (!isBinaryLabel) {
                throw new IllegalArgumentException("target must be [+1,-1]: " + target);
            }
        }

        /**
         * Double-precision counterpart of {@link #checkTarget(float)}.
         *
         * @param target label to validate
         * @throws IllegalArgumentException if {@code target} is not exactly {@code 1.0} or
         *         {@code -1.0}
         */
        protected static void checkTarget(double target) {
            final boolean isBinaryLabel = target == 1.0d || target == -1.0d;
            if (!isBinaryLabel) {
                throw new IllegalArgumentException("target must be [+1,-1]: " + target);
            }
        }

        /** @return always {@code true}: subclasses are binary-classification losses */
        @Override
        public boolean forBinaryClassification() {
            return true;
        }

        /** @return always {@code false}: subclasses are not meant for regression */
        @Override
        public boolean forRegression() {
            return false;
        }
    }

    /**
     * Epsilon-insensitive regression loss {@code max(|y - p| - epsilon, 0)}, where {@code p} is
     * the prediction and {@code y} the target. Epsilon defaults to {@code 0.1}.
     */
    public static final class EpsilonInsensitiveLoss extends RegressionLoss {
        /** Half-width of the zero-loss band around the target (not validated). */
        private float epsilon;

        /** Creates the loss with {@code epsilon = 0.1}. */
        public EpsilonInsensitiveLoss() {
            this(0.1f);
        }

        /** @param epsilon half-width of the zero-loss band */
        public EpsilonInsensitiveLoss(float epsilon) {
            this.epsilon = epsilon;
        }

        /** @param epsilon new half-width of the zero-loss band */
        public void setEpsilon(float epsilon) {
            this.epsilon = epsilon;
        }

        @Override
        public float loss(float predicted, float target) {
            final float excess = Math.abs(target - predicted) - epsilon;
            return excess > 0.0f ? excess : 0.0f;
        }

        @Override
        public double loss(double predicted, double target) {
            final double excess = Math.abs(target - predicted) - epsilon;
            return excess > 0.0d ? excess : 0.0d;
        }

        /**
         * Sub-gradient with respect to the prediction. Returns {@code -1} when the prediction is
         * more than epsilon below the target, {@code +1} when it is more than epsilon above, and
         * {@code 0} otherwise.
         */
        @Override
        public float dloss(float predicted, float target) {
            final float residual = target - predicted;
            if (residual > epsilon) {
                return -1.0f;
            }
            if (predicted - target > epsilon) {
                return 1.0f;
            }
            return 0.0f;
        }

        @Override
        public LossType getType() {
            return LossType.EpsilonInsensitiveLoss;
        }
    }

    /**
     * Hinge loss {@code max(threshold - y * p, 0)} for targets {@code y} in {+1, -1}. The
     * threshold defaults to {@code 1}.
     */
    public static final class HingeLoss extends BinaryLoss {
        /** Margin below which a non-zero loss is incurred. */
        private float threshold;

        /** Creates the loss with {@code threshold = 1}. */
        public HingeLoss() {
            this(1.0f);
        }

        /** @param threshold required margin */
        public HingeLoss(float threshold) {
            this.threshold = threshold;
        }

        /** @param threshold new required margin */
        public void setThreshold(float threshold) {
            this.threshold = threshold;
        }

        /** @throws IllegalArgumentException if {@code target} is not +1 or -1 */
        @Override
        public float loss(float predicted, float target) {
            final float margin = LossFunctions.hingeLoss(predicted, target, threshold);
            return margin > 0.0f ? margin : 0.0f;
        }

        /** @throws IllegalArgumentException if {@code target} is not +1 or -1 */
        @Override
        public double loss(double predicted, double target) {
            final double margin = LossFunctions.hingeLoss(predicted, target, threshold);
            return margin > 0.0d ? margin : 0.0d;
        }

        /**
         * Sub-gradient with respect to the prediction: {@code -y} while the margin is violated,
         * {@code 0} otherwise.
         *
         * @throws IllegalArgumentException if {@code target} is not +1 or -1
         */
        @Override
        public float dloss(float predicted, float target) {
            final boolean violated = LossFunctions.hingeLoss(predicted, target, threshold) > 0.0f;
            return violated ? -target : 0.0f;
        }

        @Override
        public LossType getType() {
            return LossType.HingeLoss;
        }
    }

    /**
     * Huber regression loss with transition point {@code c} (default {@code 1}). With residual
     * {@code r = p - y}: {@code 0.5 * r^2} when {@code |r| <= c}, otherwise
     * {@code c * |r| - 0.5 * c^2}.
     */
    public static final class HuberLoss extends RegressionLoss {
        /** Residual magnitude at which the loss switches from quadratic to linear. */
        private float c;

        /** Creates the loss with {@code c = 1}. */
        public HuberLoss() {
            this(1.0f);
        }

        /** @param c transition point (not validated) */
        public HuberLoss(float c) {
            this.c = c;
        }

        /** @param c new transition point (not validated) */
        public void setC(float c) {
            this.c = c;
        }

        @Override
        public float loss(float predicted, float target) {
            final float residual = predicted - target;
            final float magnitude = Math.abs(residual);
            if (magnitude <= c) {
                return 0.5f * residual * residual;
            }
            return (c * magnitude) - ((0.5f * c) * c);
        }

        @Override
        public double loss(double predicted, double target) {
            final double residual = predicted - target;
            final double magnitude = Math.abs(residual);
            if (magnitude <= (double) c) {
                return 0.5d * residual * residual;
            }
            return (c * magnitude) - ((0.5d * c) * c);
        }

        /** Derivative w.r.t. the prediction: the residual, clipped to {@code [-c, c]}. */
        @Override
        public float dloss(float predicted, float target) {
            final float residual = predicted - target;
            if (Math.abs(residual) <= c) {
                return residual;
            }
            return residual > 0.0f ? c : -c;
        }

        @Override
        public LossType getType() {
            return LossType.HuberLoss;
        }
    }

    /**
     * Logistic loss {@code log(1 + exp(-y * p))} for targets {@code y} in {+1, -1}. Cheap
     * approximations are used when {@code |y * p| > 18}.
     */
    public static final class LogLoss extends BinaryLoss {

        /**
         * Delegates to {@link LossFunctions#logLoss(float, float)}, which implements the same
         * computation.
         *
         * @throws IllegalArgumentException if {@code target} is not +1 or -1
         */
        @Override
        public float loss(float predicted, float target) {
            return LossFunctions.logLoss(predicted, target);
        }

        /**
         * Delegates to {@link LossFunctions#logLoss(double, double)}.
         *
         * @throws IllegalArgumentException if {@code target} is not +1 or -1
         */
        @Override
        public double loss(double predicted, double target) {
            return LossFunctions.logLoss(predicted, target);
        }

        /**
         * Derivative w.r.t. the prediction, {@code -y / (1 + exp(y * p))}.
         *
         * @throws IllegalArgumentException if {@code target} is not +1 or -1
         */
        @Override
        public float dloss(float predicted, float target) {
            checkTarget(target);
            final float z = target * predicted;
            if (z > 18.0f) {
                // exp(z) dominates the denominator, so -y / (1 + exp(z)) ~= -y * exp(-z).
                return ((float) Math.exp(-z)) * (-target);
            }
            if (z < -18.0f) {
                // exp(z) ~= 0, so the gradient saturates at -y.
                return -target;
            }
            return (-target) / (((float) Math.exp(z)) + 1.0f);
        }

        @Override
        public LossType getType() {
            return LossType.LogLoss;
        }
    }

    /** Contract shared by every loss function defined in {@code LossFunctions}. */
    public interface LossFunction {

        /**
         * Computes the loss in single precision.
         *
         * @param predicted model output
         * @param target ground truth; binary-classification losses expect {@code +1} or
         *        {@code -1}
         * @return the loss value
         */
        float loss(float predicted, float target);

        /**
         * Computes the loss in double precision.
         *
         * @param predicted model output
         * @param target ground truth
         * @return the loss value
         */
        double loss(double predicted, double target);

        /**
         * Computes the (sub-)gradient of the loss with respect to {@code predicted}.
         *
         * @param predicted model output
         * @param target ground truth
         * @return derivative of the loss w.r.t. the prediction
         */
        float dloss(float predicted, float target);

        /** @return whether this loss is intended for binary classification */
        boolean forBinaryClassification();

        /** @return whether this loss is intended for regression */
        boolean forRegression();

        /** @return the enum constant identifying this loss */
        @Nonnull
        LossType getType();
    }

    /**
     * Identifies each loss function implemented in {@code LossFunctions}.
     * {@link LossFunctions#getLossFunction(LossType)} creates a default-configured instance for
     * each constant. In the formulas below, {@code p} is the prediction and {@code y} the target.
     *
     * <p>NOTE(review): keep the declaration order stable unless it is confirmed that no code
     * persists or compares ordinals.
     */
    public enum LossType {
        /** {@code 0.5 * (p - y)^2}. */
        SquaredLoss,
        /** Pinball loss at quantile level {@code tau}. */
        QuantileLoss,
        /** {@code max(|y - p| - epsilon, 0)}. */
        EpsilonInsensitiveLoss,
        /** {@code max(|y - p| - epsilon, 0)^2}. */
        SquaredEpsilonInsensitiveLoss,
        /** Quadratic within {@code c} of the target, linear beyond it. */
        HuberLoss,
        /** {@code max(threshold - y * p, 0)} with {@code y} in {+1, -1}. */
        HingeLoss,
        /** {@code log(1 + exp(-y * p))} with {@code y} in {+1, -1}. */
        LogLoss,
        /** {@code max(1 - y * p, 0)^2} with {@code y} in {+1, -1}. */
        SquaredHingeLoss,
        /** Piecewise quadratic/linear in {@code z = y * p}. */
        ModifiedHuberLoss
    }

    /**
     * Modified Huber loss for targets {@code y} in {+1, -1}. With {@code z = y * p}, the loss is:
     * <ul>
     * <li>{@code 0} if {@code z >= 1};</li>
     * <li>{@code (1 - z)^2} if {@code -1 <= z < 1};</li>
     * <li>{@code -4z} otherwise.</li>
     * </ul>
     *
     * <p>Like the other {@link BinaryLoss} implementations, every method validates the target.
     * Without this check, a 0/1 label would silently produce a meaningless loss and gradient.
     */
    public static final class ModifiedHuberLoss extends BinaryLoss {

        /** @throws IllegalArgumentException if {@code target} is not +1 or -1 */
        @Override
        public float loss(float predicted, float target) {
            checkTarget(target);
            final float z = predicted * target;
            if (z >= 1.0f) {
                return 0.0f;
            }
            if (z >= -1.0f) {
                return (1.0f - z) * (1.0f - z);
            }
            return (-4.0f) * z;
        }

        /** @throws IllegalArgumentException if {@code target} is not +1 or -1 */
        @Override
        public double loss(double predicted, double target) {
            checkTarget(target);
            final double z = predicted * target;
            if (z >= 1.0d) {
                return 0.0d;
            }
            if (z >= -1.0d) {
                return (1.0d - z) * (1.0d - z);
            }
            return (-4.0d) * z;
        }

        /**
         * Derivative w.r.t. the prediction: {@code 0}, {@code -2y(1 - z)}, or {@code -4y} in the
         * three regions respectively.
         *
         * @throws IllegalArgumentException if {@code target} is not +1 or -1
         */
        @Override
        public float dloss(float predicted, float target) {
            checkTarget(target);
            final float z = predicted * target;
            if (z >= 1.0f) {
                return 0.0f;
            }
            if (z >= -1.0f) {
                return 2.0f * (1.0f - z) * (-target);
            }
            return (-4.0f) * target;
        }

        @Override
        public LossType getType() {
            return LossType.ModifiedHuberLoss;
        }
    }

    /**
     * Quantile (pinball) regression loss at level {@code tau} in the open interval (0, 1). With
     * residual {@code e = y - p}, the loss is {@code tau * e} if {@code e > 0}, otherwise
     * {@code -(1 - tau) * e}. {@code tau} defaults to {@code 0.5}.
     */
    public static final class QuantileLoss extends RegressionLoss {
        /** Quantile level; always strictly between 0 and 1. */
        private float tau;

        /** Creates the loss with {@code tau = 0.5}. */
        public QuantileLoss() {
            this.tau = 0.5f;
        }

        /**
         * @param tau quantile level
         * @throws IllegalArgumentException unless {@code 0 < tau < 1}
         */
        public QuantileLoss(float tau) {
            setTau(tau);
        }

        /**
         * Sets the quantile level.
         *
         * @param tau quantile level
         * @throws IllegalArgumentException unless {@code 0 < tau < 1}; NaN is rejected too
         */
        public void setTau(float tau) {
            // Written as a negated range check: every comparison with NaN is false, so the
            // previous form (tau <= 0 || tau >= 1) silently accepted NaN.
            if (!(tau > 0.0f && tau < 1.0f)) {
                throw new IllegalArgumentException("tau must be in range (0, 1): " + tau);
            }
            this.tau = tau;
        }

        @Override
        public float loss(float predicted, float target) {
            final float e = target - predicted;
            return e > 0.0f ? tau * e : (-(1.0f - tau)) * e;
        }

        @Override
        public double loss(double predicted, double target) {
            final double e = target - predicted;
            return e > 0.0d ? tau * e : (-(1.0d - tau)) * e;
        }

        /**
         * Sub-gradient w.r.t. the prediction. Returns {@code -tau} when the prediction is below
         * the target, {@code 1 - tau} when it is above, and {@code 0} on an exact hit.
         */
        @Override
        public float dloss(float predicted, float target) {
            final float e = target - predicted;
            if (e == 0.0f) {
                return 0.0f;
            }
            return e > 0.0f ? -tau : 1.0f - tau;
        }

        @Override
        public LossType getType() {
            return LossType.QuantileLoss;
        }
    }

    /** Common base for losses intended for regression with real-valued targets. */
    public static abstract class RegressionLoss implements LossFunction {

        /** @return always {@code true}: subclasses are regression losses */
        @Override
        public boolean forRegression() {
            return true;
        }

        /** @return always {@code false}: subclasses are not binary-classification losses */
        @Override
        public boolean forBinaryClassification() {
            return false;
        }
    }

    /**
     * Squared epsilon-insensitive regression loss {@code max(|y - p| - epsilon, 0)^2}. Epsilon
     * defaults to {@code 0.1}.
     */
    public static final class SquaredEpsilonInsensitiveLoss extends RegressionLoss {
        /** Half-width of the zero-loss band around the target (not validated). */
        private float epsilon;

        /** Creates the loss with {@code epsilon = 0.1}. */
        public SquaredEpsilonInsensitiveLoss() {
            this(0.1f);
        }

        /** @param epsilon half-width of the zero-loss band */
        public SquaredEpsilonInsensitiveLoss(float epsilon) {
            this.epsilon = epsilon;
        }

        /** @param epsilon new half-width of the zero-loss band */
        public void setEpsilon(float epsilon) {
            this.epsilon = epsilon;
        }

        @Override
        public float loss(float predicted, float target) {
            final float excess = Math.abs(target - predicted) - epsilon;
            return excess > 0.0f ? excess * excess : 0.0f;
        }

        @Override
        public double loss(double predicted, double target) {
            final double excess = Math.abs(target - predicted) - epsilon;
            return excess > 0.0d ? excess * excess : 0.0d;
        }

        /** Derivative w.r.t. the prediction; zero inside the epsilon band. */
        @Override
        public float dloss(float predicted, float target) {
            final float residual = target - predicted;
            final float overshoot = -residual;
            if (residual > epsilon) {
                // Prediction is too low: the gradient pushes it upward.
                return (-2.0f) * (residual - epsilon);
            } else if (overshoot > epsilon) {
                // Prediction is too high: the gradient pushes it downward.
                return 2.0f * (overshoot - epsilon);
            } else {
                return 0.0f;
            }
        }

        @Override
        public LossType getType() {
            return LossType.SquaredEpsilonInsensitiveLoss;
        }
    }

    /** Squared hinge loss {@code max(1 - y * p, 0)^2} for targets {@code y} in {+1, -1}. */
    public static final class SquaredHingeLoss extends BinaryLoss {

        /** @throws IllegalArgumentException if {@code target} is not +1 or -1 */
        @Override
        public float loss(float predicted, float target) {
            return LossFunctions.squaredHingeLoss(predicted, target);
        }

        /** @throws IllegalArgumentException if {@code target} is not +1 or -1 */
        @Override
        public double loss(double predicted, double target) {
            return LossFunctions.squaredHingeLoss(predicted, target);
        }

        /**
         * Derivative w.r.t. the prediction: {@code -2 * (1 - y * p) * y} while the margin is
         * violated, {@code 0} otherwise.
         *
         * @throws IllegalArgumentException if {@code target} is not +1 or -1
         */
        @Override
        public float dloss(float predicted, float target) {
            checkTarget(target);
            final float slack = 1.0f - (target * predicted);
            return slack > 0.0f ? (-2.0f) * slack * target : 0.0f;
        }

        @Override
        public LossType getType() {
            return LossType.SquaredHingeLoss;
        }
    }

    /** Squared-error regression loss {@code 0.5 * (p - y)^2}. */
    public static final class SquaredLoss extends RegressionLoss {

        /** Delegates to {@link LossFunctions#squaredLoss(float, float)}. */
        @Override
        public float loss(float predicted, float target) {
            return LossFunctions.squaredLoss(predicted, target);
        }

        /** Delegates to {@link LossFunctions#squaredLoss(double, double)}. */
        @Override
        public double loss(double predicted, double target) {
            return LossFunctions.squaredLoss(predicted, target);
        }

        /** Derivative w.r.t. the prediction, i.e. the raw residual {@code p - y}. */
        @Override
        public float dloss(float predicted, float target) {
            final float residual = predicted - target;
            return residual;
        }

        @Override
        public LossType getType() {
            return LossType.SquaredLoss;
        }
    }

    /**
     * Resolves a loss function from its name, ignoring case.
     *
     * <p>Each loss accepts its class-style name (e.g. {@code "HingeLoss"}) and a short alias
     * (e.g. {@code "hinge"}). Log loss additionally accepts {@code "LogisticLoss"} and
     * {@code "logistic"}.
     *
     * @param lossName name to resolve; {@code null} is treated as an unknown name
     * @return a new instance with default parameters
     * @throws IllegalArgumentException if the name is not recognized
     */
    @Nonnull
    public static LossFunction getLossFunction(@Nullable String lossName) {
        if (lossNameMatches(lossName, "SquaredLoss", "squared")) {
            return new SquaredLoss();
        } else if (lossNameMatches(lossName, "QuantileLoss", "quantile")) {
            return new QuantileLoss();
        } else if (lossNameMatches(lossName, "EpsilonInsensitiveLoss", "epsilon_insensitive")) {
            return new EpsilonInsensitiveLoss();
        } else if (lossNameMatches(lossName, "SquaredEpsilonInsensitiveLoss",
            "squared_epsilon_insensitive")) {
            return new SquaredEpsilonInsensitiveLoss();
        } else if (lossNameMatches(lossName, "HuberLoss", "huber")) {
            return new HuberLoss();
        } else if (lossNameMatches(lossName, "HingeLoss", "hinge")) {
            return new HingeLoss();
        } else if (lossNameMatches(lossName, "LogLoss", "log", "LogisticLoss", "logistic")) {
            return new LogLoss();
        } else if (lossNameMatches(lossName, "SquaredHingeLoss", "squared_hinge")) {
            return new SquaredHingeLoss();
        } else if (lossNameMatches(lossName, "ModifiedHuberLoss", "modified_huber")) {
            return new ModifiedHuberLoss();
        }
        throw new IllegalArgumentException("Unsupported loss function name: " + lossName);
    }

    /** Returns whether {@code name} equals any of {@code aliases}, ignoring case; null-safe. */
    private static boolean lossNameMatches(@Nullable String name, @Nonnull String... aliases) {
        for (String alias : aliases) {
            if (alias.equalsIgnoreCase(name)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Creates a loss function with default parameters for the given type.
     *
     * <p>This switches directly on the enum instead of on the decompiler's ordinal switch map.
     * The old form used unrelated constants ({@code SizeOf.BYTE},
     * {@code ExceptionUtils.TRACE_CAUSE_DEPTH}) as case labels and depended on a hand-maintained
     * label table.
     *
     * @param lossType type to instantiate; passing {@code null} throws
     *        {@link NullPointerException}
     * @return a new instance of the matching loss
     * @throws IllegalArgumentException if {@code lossType} has no mapping (e.g. a constant added
     *         later)
     */
    @Nonnull
    public static LossFunction getLossFunction(@Nonnull LossType lossType) {
        switch (lossType) {
            case SquaredLoss:
                return new SquaredLoss();
            case QuantileLoss:
                return new QuantileLoss();
            case EpsilonInsensitiveLoss:
                return new EpsilonInsensitiveLoss();
            case SquaredEpsilonInsensitiveLoss:
                return new SquaredEpsilonInsensitiveLoss();
            case HuberLoss:
                return new HuberLoss();
            case HingeLoss:
                return new HingeLoss();
            case LogLoss:
                return new LogLoss();
            case SquaredHingeLoss:
                return new SquaredHingeLoss();
            case ModifiedHuberLoss:
                return new ModifiedHuberLoss();
            default:
                throw new IllegalArgumentException("Unsupported loss function name: " + lossType);
        }
    }

    /**
     * Returns {@code target - (float) MathUtils.sigmoid(predicted)}. Despite the method name, the
     * result is this residual, not a loss value.
     *
     * <p>The sigmoid term is skipped and {@code target} is returned unchanged when
     * {@code predicted} is not greater than {@code -100}; this includes NaN. Presumably the reason
     * is that the sigmoid is negligible there — NOTE(review): this assumes
     * {@code MathUtils.sigmoid} is the logistic function; confirm.
     *
     * @param target label value
     * @param predicted raw model score
     * @return the residual described above
     */
    public static float logisticLoss(float target, float predicted) {
        if (predicted > -100.0d) {
            final float probability = (float) MathUtils.sigmoid(predicted);
            return target - probability;
        }
        return target;
    }

    /**
     * Logistic loss {@code log(1 + exp(-y * p))} for a target {@code y} in {+1, -1}.
     *
     * <p>With {@code z = y * p}: for {@code z > 18} the approximation {@code exp(-z)} is used,
     * and for {@code z < -18} the approximation {@code -z}.
     *
     * @param predicted model score {@code p}
     * @param target label {@code y}
     * @return the loss
     * @throws IllegalArgumentException if {@code target} is not +1 or -1
     */
    public static float logLoss(float predicted, float target) {
        BinaryLoss.checkTarget(target);
        final float z = target * predicted;
        if (z > 18.0f) {
            return (float) Math.exp(-z);
        } else if (z < -18.0f) {
            return -z;
        }
        return (float) Math.log(1.0d + Math.exp(-z));
    }

    /**
     * Double-precision counterpart of {@link #logLoss(float, float)}.
     *
     * @throws IllegalArgumentException if {@code target} is not +1 or -1
     */
    public static double logLoss(double predicted, double target) {
        BinaryLoss.checkTarget(target);
        final double z = target * predicted;
        if (z > 18.0d) {
            return Math.exp(-z);
        } else if (z < -18.0d) {
            return -z;
        }
        return Math.log(1.0d + Math.exp(-z));
    }

    /**
     * Squared-error loss {@code 0.5 * (p - y)^2}.
     *
     * @param predicted prediction {@code p}
     * @param target target {@code y}
     * @return half the squared residual
     */
    public static float squaredLoss(float predicted, float target) {
        final float residual = predicted - target;
        final float squared = residual * residual;
        return squared * 0.5f;
    }

    /** Double-precision counterpart of {@link #squaredLoss(float, float)}. */
    public static double squaredLoss(double predicted, double target) {
        final double residual = predicted - target;
        final double squared = residual * residual;
        return squared * 0.5d;
    }

    /**
     * Raw (unclamped) hinge margin {@code threshold - y * p}. A positive value means the margin
     * is violated; callers apply {@code max(., 0)} themselves.
     *
     * @param predicted model score {@code p}
     * @param target label {@code y}
     * @param threshold required margin
     * @return {@code threshold - target * predicted}
     * @throws IllegalArgumentException if {@code target} is not +1 or -1
     */
    public static float hingeLoss(float predicted, float target, float threshold) {
        BinaryLoss.checkTarget(target);
        final float signedScore = target * predicted;
        return threshold - signedScore;
    }

    /** Double-precision counterpart of {@link #hingeLoss(float, float, float)}. */
    public static double hingeLoss(double predicted, double target, double threshold) {
        BinaryLoss.checkTarget(target);
        final double signedScore = target * predicted;
        return threshold - signedScore;
    }

    /** Raw hinge margin with the conventional threshold of {@code 1}. */
    public static float hingeLoss(float predicted, float target) {
        final float unitThreshold = 1.0f;
        return hingeLoss(predicted, target, unitThreshold);
    }

    /** Double-precision raw hinge margin with threshold {@code 1}. */
    public static double hingeLoss(double predicted, double target) {
        final double unitThreshold = 1.0d;
        return hingeLoss(predicted, target, unitThreshold);
    }

    /**
     * Squared hinge loss {@code max(1 - y * p, 0)^2}.
     *
     * @param predicted model score {@code p}
     * @param target label {@code y}
     * @return the loss
     * @throws IllegalArgumentException if {@code target} is not +1 or -1
     */
    public static float squaredHingeLoss(float predicted, float target) {
        BinaryLoss.checkTarget(target);
        final float slack = 1.0f - (target * predicted);
        return slack > 0.0f ? slack * slack : 0.0f;
    }

    /**
     * Double-precision counterpart of {@link #squaredHingeLoss(float, float)}.
     *
     * @throws IllegalArgumentException if {@code target} is not +1 or -1
     */
    public static double squaredHingeLoss(double predicted, double target) {
        BinaryLoss.checkTarget(target);
        final double slack = 1.0d - (target * predicted);
        return slack > 0.0d ? slack * slack : 0.0d;
    }

    /**
     * Raw (unclamped) epsilon-insensitive margin {@code |y - p| - epsilon}. A positive result
     * means the prediction lies outside the epsilon band; the value is not clamped at zero.
     *
     * @param predicted prediction {@code p}
     * @param target target {@code y}
     * @param epsilon half-width of the tolerance band
     * @return {@code |target - predicted| - epsilon}
     */
    public static float epsilonInsensitiveLoss(float predicted, float target, float epsilon) {
        final float distance = Math.abs(target - predicted);
        return distance - epsilon;
    }
}
