package hivemall.regression;

import hivemall.GeneralLearnerBaseUDTF;
import hivemall.model.FeatureValue;
import hivemall.optimizer.LossFunctions;
import javax.annotation.Nonnull;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;

/**
 * Hive UDTF registered as {@code train_regressor}: builds a prediction model with a generic
 * regressor, reusing the training machinery of {@link GeneralLearnerBaseUDTF}.
 *
 * <p>This subclass only supplies the regression-specific pieces:
 * <ul>
 *   <li>the default loss ({@code SquaredLoss});</li>
 *   <li>a check that the configured loss reports itself as suitable for regression;</li>
 *   <li>a (no-op) target-value check;</li>
 *   <li>the per-row training step (predict, then update).</li>
 * </ul>
 */
@Description(name = "train_regressor",
        value = "_FUNC_(list<string|int|bigint> features, double label [, const string options])"
                + " - Returns a relation consists of <string|int|bigint feature, float weight>",
        extended = "Build a prediction model by a generic regressor")
public final class GeneralRegressorUDTF extends GeneralLearnerBaseUDTF {

    /**
     * Returns the help text for the loss-function option, listing the regression losses.
     *
     * @return human-readable description of the selectable loss functions
     */
    @Override
    protected String getLossOptionDescription() {
        return "Loss function [SquaredLoss (default), QuantileLoss, EpsilonInsensitiveLoss, "
                + "SquaredEpsilonInsensitiveLoss, HuberLoss]";
    }

    /**
     * Returns the loss used when none is specified.
     *
     * @return {@link LossFunctions.LossType#SquaredLoss}
     */
    @Override
    protected LossFunctions.LossType getDefaultLossType() {
        return LossFunctions.LossType.SquaredLoss;
    }

    /**
     * Rejects any loss function that does not report itself as usable for regression.
     *
     * @param lossFunction the configured loss function
     * @throws UDFArgumentException if {@code lossFunction.forRegression()} is {@code false}
     */
    @Override
    protected void checkLossFunction(@Nonnull LossFunctions.LossFunction lossFunction)
            throws UDFArgumentException {
        if (lossFunction.forRegression()) {
            return; // suitable for regression: nothing to complain about
        }
        throw new UDFArgumentException("The loss function `" + lossFunction.getType()
                + "` is not designed for regression");
    }

    /**
     * Validates a target value; this regressor places no restriction on it.
     *
     * @param label the target value of the current row
     * @throws UDFArgumentException never thrown by this implementation
     */
    @Override
    protected void checkTargetValue(final float label) throws UDFArgumentException {
        // Intentionally empty: every float label is accepted.
        // NOTE(review): NaN / infinite labels are not rejected here — confirm whether the
        // base class or the caller filters them before training.
    }

    /**
     * Performs one training step for a single row: computes the current prediction for
     * {@code features}, then hands features, target and prediction to the base-class update.
     *
     * @param features feature vector of the row
     * @param target   target value of the row
     */
    @Override
    protected void train(@Nonnull final FeatureValue[] features, final float target) {
        final float predicted = predict(features);
        update(features, target, predicted);
    }
}
