package hivemall.fm;

import hivemall.fm.FactorizationMachineModel;
import hivemall.optimizer.EtaEstimator;
import hivemall.utils.lang.Primitives;
import hivemall.utils.lang.SizeOf;
import java.util.Locale;
import javax.annotation.Nonnull;
import org.apache.commons.cli.CommandLine;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;

/**
 * Hyper-parameters for Factorization Machines. Fields hold defaults and are overwritten from a
 * parsed {@link CommandLine} by {@link #processOptions(CommandLine)}.
 */
class FMHyperParameters {
    private static final float DEFAULT_ETA0 = 0.05f;

    FactorizationMachineModel.VInitScheme vInit;
    EtaEstimator eta;
    boolean l2norm;
    boolean classification = false;
    int factors = 5;
    float lambda = 0.01f;
    float lambdaW0 = 0.01f;
    float lambdaW = 0.01f;
    float lambdaV = 0.01f;
    double sigma = 0.1d;
    // -1 means "not given"; processOptions() then substitutes System.nanoTime().
    long seed = -1;
    // NOTE(review): Double.MIN_VALUE is the smallest *positive* double, not the most negative one.
    // If callers clamp predictions to [minTarget, maxTarget], negative regression targets would be
    // cut off at ~0. Kept unchanged to preserve behavior; confirm whether -Double.MAX_VALUE was meant.
    double minTarget = Double.MIN_VALUE;
    double maxTarget = Double.MAX_VALUE;
    int numFeatures = -1;
    int iters = 1;
    boolean conversionCheck = true;
    double convergenceRate = 0.005d;
    boolean adaptiveReglarization = false;
    // Independent of DEFAULT_ETA0 (a learning rate); the two merely share the value 0.05.
    float validationRatio = 0.05f;
    int validationThreshold = 1000;
    boolean parseFeatureAsInt = false;

    /**
     * Hyper-parameters for Field-aware Factorization Machines: everything in
     * {@link FMHyperParameters} plus field and optimizer (FTRL / AdaGrad / SGD) settings.
     */
    public static final class FFMHyperParameters extends FMHyperParameters {
        boolean globalBias = false;
        boolean linearCoeff = true;
        int numFields = Feature.DEFAULT_NUM_FIELDS;
        // AdaGrad settings
        boolean useAdaGrad = false;
        float eps = 1.0f;
        // FTRL settings
        boolean useFTRL = false;
        float alphaFTRL = 0.2f;
        float betaFTRL = 1.0f;
        float lambda1 = 0.001f;
        float lamdda2 = 1.0E-4f;

        /**
         * Applies the base FM options, then the FFM-specific ones.
         *
         * @param commandLine parsed UDTF options
         * @throws UDFArgumentException if {@code -int_feature} is given (unsupported for FFM), if
         *         {@code -feature_hashing} is outside [18,31], if {@code -num_fields} is not greater
         *         than 1, if {@code -alphaFTRL} is 0 with the FTRL optimizer, or if any base-class
         *         validation fails
         */
        @Override
        public void processOptions(@Nonnull CommandLine commandLine) throws UDFArgumentException {
            super.processOptions(commandLine);
            if (commandLine.hasOption("int_feature")) {
                throw new UDFArgumentException("int_feature option is not supported yet for FFM");
            }
            this.globalBias = commandLine.hasOption("global_bias");
            this.linearCoeff = !commandLine.hasOption("no_coeff");

            // -feature_hashing is only consulted when -num_features did not set numFeatures.
            if (this.numFeatures == -1) {
                int hashBits = Primitives.parseInt(commandLine.getOptionValue("feature_hashing"), -1);
                if (hashBits != -1) {
                    if (hashBits < 18 || hashBits > 31) {
                        throw new UDFArgumentException(
                            "-feature_hashing MUST be in range [18,31]: " + hashBits);
                    }
                    // NOTE(review): 1 << 31 wraps to Integer.MIN_VALUE; confirm downstream hashing
                    // copes with that before relying on hashBits == 31.
                    this.numFeatures = 1 << hashBits;
                }
            }

            this.numFields = Primitives.parseInt(commandLine.getOptionValue("num_fields"), this.numFields);
            if (this.numFields <= 1) {
                throw new UDFArgumentException("-num_fields MUST be greater than 1: " + this.numFields);
            }

            // Locale.ROOT keeps optimizer-name matching independent of the JVM's default locale.
            String optimizer = commandLine.getOptionValue("optimizer", "ftrl").toLowerCase(Locale.ROOT);
            switch (optimizer) {
                case "ftrl":
                    this.useFTRL = true;
                    this.useAdaGrad = false;
                    this.alphaFTRL = Primitives.parseFloat(commandLine.getOptionValue("alphaFTRL"), this.alphaFTRL);
                    if (this.alphaFTRL == 0.0f) {
                        throw new UDFArgumentException("-alphaFTRL SHOULD NOT be 0");
                    }
                    this.betaFTRL = Primitives.parseFloat(commandLine.getOptionValue("betaFTRL"), this.betaFTRL);
                    this.lambda1 = Primitives.parseFloat(commandLine.getOptionValue("lambda1"), this.lambda1);
                    // NOTE(review): option key is spelled "lamdda2"; verify it matches the name
                    // registered in the UDTF's Options, otherwise this value can never be set.
                    this.lamdda2 = Primitives.parseFloat(commandLine.getOptionValue("lamdda2"), this.lamdda2);
                    break;
                case "adagrad":
                    this.useAdaGrad = true;
                    this.useFTRL = false;
                    this.eps = Primitives.parseFloat(commandLine.getOptionValue("eps"), this.eps);
                    break;
                case "sgd":
                default:
                    // Unrecognized optimizer names fall back to plain SGD, as before.
                    this.useFTRL = false;
                    this.useAdaGrad = false;
                    break;
            }
        }

        @Override
        public String toString() {
            return "FFMHyperParameters [globalBias=" + this.globalBias
                    + ", linearCoeff=" + this.linearCoeff
                    + ", numFields=" + this.numFields
                    + ", useAdaGrad=" + this.useAdaGrad
                    + ", eps=" + this.eps
                    + ", useFTRL=" + this.useFTRL
                    + ", alphaFTRL=" + this.alphaFTRL
                    + ", betaFTRL=" + this.betaFTRL
                    + ", lambda1=" + this.lambda1
                    + ", lamdda2=" + this.lamdda2
                    + "], " + super.toString();
        }
    }

    @Override
    public String toString() {
        return "FMHyperParameters [classification=" + this.classification
                + ", factors=" + this.factors
                + ", lambda=" + this.lambda
                + ", lambdaW0=" + this.lambdaW0
                + ", lambdaW=" + this.lambdaW
                + ", lambdaV=" + this.lambdaV
                + ", sigma=" + this.sigma
                + ", seed=" + this.seed
                + ", vInit=" + this.vInit
                + ", minTarget=" + this.minTarget
                + ", maxTarget=" + this.maxTarget
                + ", eta=" + this.eta
                + ", numFeatures=" + this.numFeatures
                + ", l2norm=" + this.l2norm
                + ", iters=" + this.iters
                + ", conversionCheck=" + this.conversionCheck
                + ", convergenceRate=" + this.convergenceRate
                + ", adaptiveReglarization=" + this.adaptiveReglarization
                + ", validationRatio=" + this.validationRatio
                + ", validationThreshold=" + this.validationThreshold
                + ", parseFeatureAsInt=" + this.parseFeatureAsInt
                + "]";
    }

    /**
     * Overwrites the fields of this object from the parsed command line. Options that are absent
     * keep the current field value.
     *
     * @param commandLine parsed UDTF options
     * @throws UDFArgumentException if {@code -validation_ratio} is not in [0, 1) (NaN included)
     */
    public void processOptions(@Nonnull CommandLine commandLine) throws UDFArgumentException {
        this.classification = commandLine.hasOption("classification");
        this.factors = Primitives.parseInt(commandLine.getOptionValue("factors"), this.factors);
        this.lambda = Primitives.parseFloat(commandLine.getOptionValue("lambda"), this.lambda);
        // The per-term regularizers default to the (possibly just overridden) global lambda.
        this.lambdaW0 = Primitives.parseFloat(commandLine.getOptionValue("lambda_w0"), this.lambda);
        this.lambdaW = Primitives.parseFloat(commandLine.getOptionValue("lambda_w"), this.lambda);
        this.lambdaV = Primitives.parseFloat(commandLine.getOptionValue("lambda_v"), this.lambda);
        this.sigma = Primitives.parseDouble(commandLine.getOptionValue("sigma"), this.sigma);
        this.seed = Primitives.parseLong(commandLine.getOptionValue("seed"), this.seed);
        if (this.seed == -1) {
            this.seed = System.nanoTime();
        }
        // Must run after factors, seed and classification are resolved: it depends on all three.
        this.vInit = instantiateVInit(commandLine, this.factors, this.seed, this.classification);
        this.minTarget = Primitives.parseDouble(commandLine.getOptionValue("min_target"), this.minTarget);
        this.maxTarget = Primitives.parseDouble(commandLine.getOptionValue("max_target"), this.maxTarget);
        this.eta = EtaEstimator.get(commandLine, DEFAULT_ETA0);
        this.numFeatures = Primitives.parseInt(commandLine.getOptionValue("num_features"), this.numFeatures);
        this.l2norm = commandLine.hasOption("enable_norm");
        this.iters = Primitives.parseInt(commandLine.getOptionValue("iterations"), this.iters);
        this.conversionCheck = !commandLine.hasOption("disable_cvtest");
        this.convergenceRate = Primitives.parseDouble(commandLine.getOptionValue("cv_rate"), this.convergenceRate);
        this.adaptiveReglarization = commandLine.hasOption("adaptive_regularizaion");
        this.validationRatio = Primitives.parseFloat(commandLine.getOptionValue("validation_ratio"), this.validationRatio);
        // Written as a negated in-range test so that NaN is rejected as well.
        if (!(this.validationRatio >= 0.0f && this.validationRatio < 1.0f)) {
            throw new UDFArgumentException("validation_ratio should be in range [0, 1): " + this.validationRatio);
        }
        this.validationThreshold = Primitives.parseInt(commandLine.getOptionValue("validation_threshold"), this.validationThreshold);
        this.parseFeatureAsInt = commandLine.hasOption("int_feature");
    }

    /**
     * Builds the initialization scheme for the factor matrix V.
     *
     * @param commandLine parsed options ({@code -init_v}, {@code -max_init_value},
     *        {@code -min_init_stddev})
     * @param factors number of latent factors
     * @param seed random seed passed to the scheme
     * @param classification when {@code -init_v} is absent, selects {@code gaussian} (true) or
     *        {@code random} (false) as the fallback scheme
     * @return the configured scheme
     */
    @Nonnull
    private static FactorizationMachineModel.VInitScheme instantiateVInit(@Nonnull CommandLine commandLine,
            int factors, long seed, boolean classification) {
        String schemeName = commandLine.getOptionValue("init_v");
        float maxInitValue = Primitives.parseFloat(commandLine.getOptionValue("max_init_value"), 0.5f);
        double minInitStdDev = Primitives.parseDouble(commandLine.getOptionValue("min_init_stddev"), 0.1d);

        FactorizationMachineModel.VInitScheme fallback = classification
                ? FactorizationMachineModel.VInitScheme.gaussian
                : FactorizationMachineModel.VInitScheme.random;
        FactorizationMachineModel.VInitScheme scheme =
                FactorizationMachineModel.VInitScheme.resolve(schemeName, fallback);
        scheme.setMaxInitValue(maxInitValue);
        // The std-dev is floored at 1/factors.
        scheme.setInitStdDev(Math.max(minInitStdDev, 1.0d / factors));
        scheme.initRandom(factors, seed);
        return scheme;
    }
}
