package hivemall.optimizer;

import hivemall.utils.lang.NumberUtils;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.StringUtils;
import java.util.Map;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;

/**
 * Strategy for the learning rate (eta) used at training step {@code t}, together with factories
 * that build a concrete estimator from command-line options or from a string option map.
 */
public abstract class EtaEstimator {

    /** Default initial learning rate used when {@code -eta0} is not given. */
    public static final float DEFAULT_ETA0 = 0.1f;
    /** Default learning rate for {@link AdjustingEtaEstimator} when {@code -eta} is not given. */
    public static final float DEFAULT_ETA = 0.3f;
    /** Default exponent of {@link InvscalingEtaEstimator} when {@code -power_t} is not given. */
    public static final double DEFAULT_POWER_T = 0.1d;

    /** Initial (base) learning rate. */
    protected final float eta0;

    /**
     * @param eta0 initial (base) learning rate
     */
    public EtaEstimator(float eta0) {
        this.eta0 = eta0;
    }

    /**
     * @return short name of this scheme; written under the {@code "eta"} key by
     *         {@link #getHyperParameters(Map)}
     */
    @Nonnull
    public abstract String typeName();

    /**
     * @return the initial (base) learning rate
     */
    public float eta0() {
        return eta0;
    }

    /**
     * @param t training step
     * @return the learning rate to use at step {@code t}
     */
    public abstract float eta(long t);

    /**
     * Feeds a rescaling factor to the estimator. The base implementation ignores it; only
     * {@link AdjustingEtaEstimator} reacts to it.
     *
     * @param multiplier factor to apply to the current learning rate
     */
    public void update(@Nonnegative float multiplier) {}

    /**
     * Records this estimator's settings: {@code "eta"} maps to {@link #typeName()} and
     * {@code "eta0"} to {@link #eta0()}. Subclasses append their own parameters.
     *
     * @param hyperParams destination map, modified in place
     */
    public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
        hyperParams.put("eta", typeName());
        hyperParams.put("eta0", Float.valueOf(eta0()));
    }

    /**
     * Equivalent to {@code get(cl, DEFAULT_ETA0)}.
     *
     * @param cl parsed command-line options, or {@code null}
     * @return the selected estimator
     * @throws UDFArgumentException if a numeric option value cannot be parsed
     */
    @Nonnull
    public static EtaEstimator get(@Nullable CommandLine cl) throws UDFArgumentException {
        return get(cl, DEFAULT_ETA0);
    }

    /**
     * Builds an estimator from command-line options. Selection order:
     * <ol>
     * <li>{@code cl == null}: {@link InvscalingEtaEstimator}({@code defaultEta0},
     * {@link #DEFAULT_POWER_T})</li>
     * <li>{@code -boldDriver}: {@link AdjustingEtaEstimator} starting at {@code -eta} (default
     * {@link #DEFAULT_ETA})</li>
     * <li>{@code -eta} given: {@link FixedEtaEstimator} at that value</li>
     * <li>{@code -t} given: {@link SimpleEtaEstimator}({@code -eta0}, {@code -t})</li>
     * <li>otherwise: {@link InvscalingEtaEstimator}({@code -eta0}, {@code -power_t})</li>
     * </ol>
     *
     * @param cl parsed command-line options, or {@code null}
     * @param defaultEta0 initial learning rate used when {@code -eta0} is absent
     * @return the selected estimator
     * @throws UDFArgumentException if a numeric option value cannot be parsed
     */
    @Nonnull
    public static EtaEstimator get(@Nullable CommandLine cl, float defaultEta0)
            throws UDFArgumentException {
        if (cl == null) {
            return new InvscalingEtaEstimator(defaultEta0, DEFAULT_POWER_T);
        }
        try {
            return fromOptions(cl, defaultEta0);
        } catch (NumberFormatException e) {
            // Report malformed values through the declared exception type instead of letting an
            // unchecked NumberFormatException escape; keep the original as the cause.
            final UDFArgumentException ex = new UDFArgumentException(
                "Invalid numeric value for an eta option: " + e.getMessage());
            ex.initCause(e);
            throw ex;
        }
    }

    /** Selects the estimator from non-null CLI options; numeric parse errors propagate. */
    @Nonnull
    private static EtaEstimator fromOptions(@Nonnull CommandLine cl, float defaultEta0) {
        if (cl.hasOption("boldDriver")) {
            final float eta = Primitives.parseFloat(cl.getOptionValue("eta"), DEFAULT_ETA);
            return new AdjustingEtaEstimator(eta);
        }

        final String etaValue = cl.getOptionValue("eta");
        if (etaValue != null) {
            return new FixedEtaEstimator(Float.parseFloat(etaValue));
        }

        final float eta0 = Primitives.parseFloat(cl.getOptionValue("eta0"), defaultEta0);
        if (cl.hasOption("t")) {
            final long totalSteps = Long.parseLong(cl.getOptionValue("t"));
            return new SimpleEtaEstimator(eta0, totalSteps);
        }

        final double powerT =
                Primitives.parseDouble(cl.getOptionValue("power_t"), DEFAULT_POWER_T);
        return new InvscalingEtaEstimator(eta0, powerT);
    }

    /**
     * Builds an estimator from a string option map. The {@code "eta"} entry selects the scheme
     * (case-insensitive): absent or {@code inv}/{@code inverse}/{@code invscaling} gives
     * {@link InvscalingEtaEstimator}, {@code fixed} gives {@link FixedEtaEstimator} at
     * {@code eta0}, {@code simple} gives {@link SimpleEtaEstimator} (requires
     * {@code "total_steps"}), and a numeric value gives a {@link FixedEtaEstimator} at that
     * value. {@code "eta0"} and {@code "power_t"} default to {@link #DEFAULT_ETA0} and
     * {@link #DEFAULT_POWER_T}.
     *
     * @param options option name to raw string value
     * @return the selected estimator
     * @throws IllegalArgumentException if the scheme is unknown, {@code total_steps} is missing
     *         for {@code simple}, or a numeric value is malformed ({@link NumberFormatException}
     *         is a subclass)
     */
    @Nonnull
    public static EtaEstimator get(@Nonnull Map<String, String> options)
            throws IllegalArgumentException {
        final float eta0 = Primitives.parseFloat(options.get("eta0"), DEFAULT_ETA0);
        final double powerT = Primitives.parseDouble(options.get("power_t"), DEFAULT_POWER_T);

        final String etaScheme = options.get("eta");
        if (etaScheme == null) {
            return new InvscalingEtaEstimator(eta0, powerT);
        }
        if ("fixed".equalsIgnoreCase(etaScheme)) {
            return new FixedEtaEstimator(eta0);
        }
        if ("simple".equalsIgnoreCase(etaScheme)) {
            // Check the value, not just the key: a key mapped to null must also be reported
            // as missing rather than surfacing as an opaque parse error.
            final String totalSteps = options.get("total_steps");
            if (totalSteps == null) {
                throw new IllegalArgumentException(
                    "-total_steps MUST be provided when `-eta simple` is specified");
            }
            return new SimpleEtaEstimator(eta0, Long.parseLong(totalSteps));
        }
        if ("inv".equalsIgnoreCase(etaScheme) || "inverse".equalsIgnoreCase(etaScheme)
                || "invscaling".equalsIgnoreCase(etaScheme)) {
            return new InvscalingEtaEstimator(eta0, powerT);
        }
        if (StringUtils.isNumber(etaScheme)) {
            return new FixedEtaEstimator(Float.parseFloat(etaScheme));
        }
        throw new IllegalArgumentException("Unsupported ETA name: " + etaScheme);
    }

    /**
     * Learning rate that starts at {@code eta0} and is rescaled by {@link #update(float)}, never
     * rising above {@code eta0}. Selected by the {@code -boldDriver} option.
     */
    public static final class AdjustingEtaEstimator extends EtaEstimator {

        /** Current learning rate; changed only by {@link #update(float)}. */
        private float eta;

        /**
         * @param eta initial learning rate; also serves as the upper bound
         */
        public AdjustingEtaEstimator(float eta) {
            super(eta);
            this.eta = eta;
        }

        @Override
        @Nonnull
        public String typeName() {
            return "boldDriver";
        }

        /** Returns the current learning rate; the step {@code t} is ignored. */
        @Override
        public float eta(long t) {
            return eta;
        }

        /**
         * Multiplies the current learning rate by {@code multiplier} and caps the result at
         * {@code eta0}. A product rejected by {@code NumberUtils.isFinite} (presumably NaN or
         * infinity) is discarded and the rate is left unchanged.
         */
        @Override
        public void update(@Nonnegative float multiplier) {
            final float newEta = eta * multiplier;
            if (!NumberUtils.isFinite(newEta)) {
                return;
            }
            this.eta = Math.min(eta0, newEta);
        }

        @Override
        public String toString() {
            return "AdjustingEtaEstimator [ eta0 = " + eta0 + ", eta = " + eta + " ]";
        }
    }

    /** Constant learning rate equal to {@code eta0} at every step. */
    public static final class FixedEtaEstimator extends EtaEstimator {

        /**
         * @param eta0 the constant learning rate
         */
        public FixedEtaEstimator(float eta0) {
            super(eta0);
        }

        @Override
        @Nonnull
        public String typeName() {
            return "Fixed";
        }

        /** Returns {@code eta0}; the step {@code t} is ignored. */
        @Override
        public float eta(long t) {
            return eta0;
        }

        @Override
        public String toString() {
            return "FixedEtaEstimator [ eta0 = " + eta0 + " ]";
        }
    }

    /** Inverse-scaling learning rate: {@code eta(t) = eta0 / t^power_t}. */
    public static final class InvscalingEtaEstimator extends EtaEstimator {

        /** Exponent applied to the step count. */
        private final double power_t;

        /**
         * @param eta0 initial learning rate
         * @param powerT exponent applied to the step count
         */
        public InvscalingEtaEstimator(float eta0, double powerT) {
            super(eta0);
            this.power_t = powerT;
        }

        @Override
        @Nonnull
        public String typeName() {
            return "Invscaling";
        }

        /**
         * Returns {@code eta0 / t^power_t}.
         * <p>
         * NOTE(review): with {@code power_t > 0} and {@code t == 0}, {@code Math.pow} yields 0
         * and the result is infinite; callers presumably count steps from 1 -- confirm.
         */
        @Override
        public float eta(long t) {
            return (float) (eta0 / Math.pow(t, power_t));
        }

        @Override
        public String toString() {
            return "InvscalingEtaEstimator [ eta0 = " + eta0 + ", power_t = " + power_t + " ]";
        }

        /** Adds {@code "power_t"} to the base hyper-parameters. */
        @Override
        public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("power_t", Double.valueOf(power_t));
        }
    }

    /**
     * Learning rate {@code eta0 / (1 + t / total_steps)} up to {@code total_steps}, then fixed at
     * {@code eta0 / 2} (the value the formula reaches at {@code t == total_steps}).
     */
    public static final class SimpleEtaEstimator extends EtaEstimator {

        /** Rate used once {@code t} exceeds {@code total_steps}: {@code eta0 / 2}. */
        private final float finalEta;
        /** Step count over which the rate decays; kept as double for the division. */
        private final double total_steps;

        /**
         * @param eta0 initial learning rate
         * @param totalSteps number of steps over which the rate decays to {@code eta0 / 2}
         */
        public SimpleEtaEstimator(float eta0, long totalSteps) {
            super(eta0);
            this.finalEta = (float) (eta0 / 2.d);
            this.total_steps = totalSteps;
        }

        @Override
        @Nonnull
        public String typeName() {
            return "Simple";
        }

        @Override
        public float eta(long t) {
            if (t > total_steps) {
                return finalEta;
            }
            return (float) (eta0 / (1.d + (t / total_steps)));
        }

        @Override
        public String toString() {
            return "SimpleEtaEstimator [ eta0 = " + eta0 + ", totalSteps = " + total_steps
                    + ", finalEta = " + finalEta + " ]";
        }

        /** Adds {@code "total_steps"} to the base hyper-parameters. */
        @Override
        public void getHyperParameters(@Nonnull Map<String, Object> hyperParams) {
            super.getHyperParameters(hyperParams);
            hyperParams.put("total_steps", Double.valueOf(total_steps));
        }
    }
}
