package org.apache.reef.examples.group.bgd.parameters;

import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.inject.Inject;
import org.apache.reef.examples.group.bgd.loss.LogisticLossFunction;
import org.apache.reef.examples.group.bgd.loss.LossFunction;
import org.apache.reef.examples.group.bgd.loss.SquaredErrorLossFunction;
import org.apache.reef.examples.group.bgd.loss.WeightedLogisticLossFunction;
import org.apache.reef.tang.annotations.Parameter;

/**
 * Resolves the loss-function name supplied through the {@link LossFunctionType} parameter
 * to the corresponding {@link LossFunction} implementation class.
 *
 * <p>Supported names are {@code logLoss}, {@code weightedLogLoss} and {@code squaredError}.
 */
public class BGDLossType {

    /**
     * Supported loss-function names mapped to their implementation classes.
     * Unmodifiable; iteration order is insertion order (used to build the error message).
     */
    private static final Map<String, Class<? extends LossFunction>> LOSS_FUNCTIONS;

    static {
        // A plain static initializer instead of double-brace initialization: no anonymous
        // HashMap subclass is generated, and the published map cannot be mutated.
        final Map<String, Class<? extends LossFunction>> functions =
            new LinkedHashMap<String, Class<? extends LossFunction>>();
        functions.put("logLoss", LogisticLossFunction.class);
        functions.put("weightedLogLoss", WeightedLogisticLossFunction.class);
        functions.put("squaredError", SquaredErrorLossFunction.class);
        LOSS_FUNCTIONS = Collections.unmodifiableMap(functions);
    }

    private final Class<? extends LossFunction> lossFunction;
    private final String lossFunctionStr;

    /**
     * Looks up the loss-function class for the given name.
     *
     * @param str the loss-function name, injected from the {@link LossFunctionType} parameter
     * @throws IllegalArgumentException if {@code str} is null or is not one of the supported names
     */
    @Inject
    public BGDLossType(@Parameter(LossFunctionType.class) final String str) {
        this.lossFunctionStr = str;
        this.lossFunction = LOSS_FUNCTIONS.get(str);
        if (this.lossFunction == null) {
            // IllegalArgumentException is a RuntimeException, so existing handlers still apply.
            throw new IllegalArgumentException("Specified loss function type: " + str
                + " is not implemented. Supported types are " + supportedTypes());
        }
    }

    /**
     * Builds the '|'-separated list of supported names from the lookup table, so the error
     * message can never drift out of sync with the names that are actually accepted.
     */
    private static String supportedTypes() {
        final StringBuilder joined = new StringBuilder();
        for (final String name : LOSS_FUNCTIONS.keySet()) {
            if (joined.length() > 0) {
                joined.append('|');
            }
            joined.append(name);
        }
        return joined.toString();
    }

    /**
     * @return the {@link LossFunction} implementation class selected at construction
     */
    public Class<? extends LossFunction> getLossFunction() {
        return this.lossFunction;
    }

    /**
     * @return the loss-function name this instance was constructed with
     */
    public String lossFunctionString() {
        return this.lossFunctionStr;
    }
}
