package org.apache.flink.ml.regression;

import org.apache.flink.api.scala.DataSet;
import org.apache.flink.ml.common.LabeledVector;
import org.apache.flink.ml.common.ParameterMap;
import org.apache.flink.ml.math.Vector;
import org.apache.flink.ml.optimization.GenericLossFunction;
import org.apache.flink.ml.optimization.IterativeSolver;
import org.apache.flink.ml.optimization.LinearPrediction$;
import org.apache.flink.ml.optimization.SimpleGradientDescent;
import org.apache.flink.ml.optimization.SimpleGradientDescent$;
import org.apache.flink.ml.optimization.SquaredLoss$;
import org.apache.flink.ml.package$;
import org.apache.flink.ml.pipeline.FitOperation;
import org.apache.flink.ml.pipeline.PredictOperation;
import scala.MatchError;
import scala.None$;
import scala.Some;
import scala.Tuple2;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: MultipleLinearRegression.scala */
/* loaded from: input_file:org/apache/flink/ml/regression/MultipleLinearRegression$.class */
/**
 * Java view of the Scala companion object of {@code MultipleLinearRegression}, as emitted by a
 * decompiler for the {@code MultipleLinearRegression$} class file.
 *
 * <p>It exposes a factory ({@link #apply()}), a shared squared-loss / linear-prediction loss
 * function, the fit operation ({@link #fitMLR()}) and two predict operations
 * ({@link #predictVectors()}, {@link #predictLabeledVectors()}).
 *
 * <p>The single instance is stored in {@link #MODULE$}: the static initializer invokes the
 * private constructor, which assigns itself to {@code MODULE$} before initializing the fields.
 *
 * <p>NOTE(review): this is decompiler output. Some constructs are not valid Java source as
 * written:
 * <ul>
 *   <li>assigning the {@code static final} field {@code MODULE$} inside the constructor;
 *   <li>assigning {@code BoxedUnit.UNIT} to an {@code IterativeSolver} local in the fit
 *       operation.
 * </ul>
 * Presumably this file documents the bytecode rather than being compiled as-is; confirm before
 * editing logic here.
 */
public final class MultipleLinearRegression$ {
    /** Singleton instance; assigned by the private constructor during static initialization. */
    public static final MultipleLinearRegression$ MODULE$ = null;
    /** Name string {@code "weights_broadcast"}; its consumers are not visible in this file. */
    private final String WEIGHTVECTOR_BROADCAST;
    /** Loss function built from {@code SquaredLoss} and {@code LinearPrediction}. */
    private final GenericLossFunction lossFunction;
    /** The {@code FitOperation<MultipleLinearRegression, LabeledVector>} instance, typed as Object. */
    private final Object fitMLR;

    // Forces creation of the singleton; the constructor stores it in MODULE$.
    static {
        new MultipleLinearRegression$();
    }

    /**
     * Returns the broadcast-variable name constant.
     *
     * @return the string {@code "weights_broadcast"}
     */
    public String WEIGHTVECTOR_BROADCAST() {
        return this.WEIGHTVECTOR_BROADCAST;
    }

    /**
     * Returns the shared loss function.
     *
     * @return a {@code GenericLossFunction(SquaredLoss, LinearPrediction)} created once in the
     *     constructor
     */
    public GenericLossFunction lossFunction() {
        return this.lossFunction;
    }

    /**
     * Factory method (Scala {@code MultipleLinearRegression()}).
     *
     * @return a new, unfitted {@link MultipleLinearRegression}
     */
    public MultipleLinearRegression apply() {
        return new MultipleLinearRegression();
    }

    /**
     * Returns the fit operation created in the constructor.
     *
     * <p>It is typed as {@code Object} here. Presumably it is exposed as a Scala implicit value;
     * confirm against the Scala source.
     *
     * @return the {@code FitOperation<MultipleLinearRegression, LabeledVector>} instance
     */
    public Object fitMLR() {
        return this.fitMLR;
    }

    /**
     * Creates a predict operation for data sets of {@link Vector} subtypes.
     *
     * <p>The operation reads the model's {@code weightsOption()}. If it is a {@code Some}, the
     * contained weights {@code DataSet} is passed to {@code mapWithBcVariable}, which maps the
     * input to a {@code DataSet<LabeledVector>}. The per-element function lives in the synthetic
     * class {@code MultipleLinearRegression$$anon$6$$anonfun$predict$1}, which is not visible
     * here. Presumably it labels each vector with its linear prediction; confirm in the Scala
     * source.
     *
     * <p>If {@code weightsOption()} is {@code None}, a {@link RuntimeException} is thrown stating
     * that the model has not been fitted. Any other value raises {@link MatchError}.
     *
     * @param <T> element type of the input data set
     * @return a {@code PredictOperation<MultipleLinearRegression, T, LabeledVector>}, typed as
     *     {@code Object}
     */
    public <T extends Vector> Object predictVectors() {
        return new PredictOperation<MultipleLinearRegression, T, LabeledVector>() { // from class: org.apache.flink.ml.regression.MultipleLinearRegression$$anon$6
            @Override // org.apache.flink.ml.pipeline.PredictOperation
            public DataSet<LabeledVector> predict(MultipleLinearRegression multipleLinearRegression, ParameterMap parameterMap, DataSet<T> dataSet) {
                // NOTE(review): declared as Some by the decompiler, but the value is checked
                // against both Some and None below (Scala Option pattern match).
                Some weightsOption = multipleLinearRegression.weightsOption();
                if (weightsOption instanceof Some) {
                    // Fitted: weights are supplied to the mapping function via mapWithBcVariable.
                    return package$.MODULE$.RichDataSet(dataSet).mapWithBcVariable((DataSet) weightsOption.x(), new MultipleLinearRegression$$anon$6$$anonfun$predict$1(this), new MultipleLinearRegression$$anon$6$$anon$3(this), ClassTag$.MODULE$.apply(LabeledVector.class));
                }
                // Lowered "case None": any value that is neither Some nor None is a MatchError.
                None$ none$ = None$.MODULE$;
                if (none$ != null ? !none$.equals(weightsOption) : weightsOption != null) {
                    throw new MatchError(weightsOption);
                }
                // Not fitted yet: there are no weights to predict with.
                throw new RuntimeException("The MultipleLinearRegression has not been fitted to the data. This is necessary to learn the weight vector of the linear function.");
            }
        };
    }

    /**
     * Creates a predict operation for data sets of {@link LabeledVector}.
     *
     * <p>The control flow matches {@link #predictVectors()}, but the mapping produces a
     * {@code DataSet<Tuple2<Object, Object>>}. The per-element function lives in the synthetic
     * class {@code MultipleLinearRegression$$anon$7$$anonfun$predict$2}, which is not visible
     * here. Presumably each tuple is (true label, predicted value); confirm in the Scala source.
     *
     * <p>If {@code weightsOption()} is {@code None}, a {@link RuntimeException} is thrown stating
     * that the model has not been fitted. Any other non-Some value raises {@link MatchError}.
     *
     * @return a {@code PredictOperation<MultipleLinearRegression, LabeledVector, Tuple2<Object,
     *     Object>>}, typed as {@code Object}
     */
    public Object predictLabeledVectors() {
        return new PredictOperation<MultipleLinearRegression, LabeledVector, Tuple2<Object, Object>>() { // from class: org.apache.flink.ml.regression.MultipleLinearRegression$$anon$7
            @Override // org.apache.flink.ml.pipeline.PredictOperation
            public DataSet<Tuple2<Object, Object>> predict(MultipleLinearRegression multipleLinearRegression, ParameterMap parameterMap, DataSet<LabeledVector> dataSet) {
                // NOTE(review): declared as Some by the decompiler; see the None check below.
                Some weightsOption = multipleLinearRegression.weightsOption();
                if (weightsOption instanceof Some) {
                    // Fitted: weights are supplied to the mapping function via mapWithBcVariable.
                    return package$.MODULE$.RichDataSet(dataSet).mapWithBcVariable((DataSet) weightsOption.x(), new MultipleLinearRegression$$anon$7$$anonfun$predict$2(this), new MultipleLinearRegression$$anon$7$$anon$4(this), ClassTag$.MODULE$.apply(Tuple2.class));
                }
                // Lowered "case None": any value that is neither Some nor None is a MatchError.
                None$ none$ = None$.MODULE$;
                if (none$ != null ? !none$.equals(weightsOption) : weightsOption != null) {
                    throw new MatchError(weightsOption);
                }
                // Not fitted yet: there are no weights to predict with.
                throw new RuntimeException("The MultipleLinearRegression has not been fitted to the data. This is necessary to learn the weight vector of the linear function.");
            }
        };
    }

    /**
     * Initializes the singleton: publishes it in {@code MODULE$}, then creates the constant, the
     * shared loss function and the fit operation.
     */
    private MultipleLinearRegression$() {
        MODULE$ = this;
        this.WEIGHTVECTOR_BROADCAST = "weights_broadcast";
        this.lossFunction = new GenericLossFunction(SquaredLoss$.MODULE$, LinearPrediction$.MODULE$);
        // Fit operation: trains the model with SimpleGradientDescent and stores the learned
        // weights in the model's weightsOption.
        this.fitMLR = new FitOperation<MultipleLinearRegression, LabeledVector>() { // from class: org.apache.flink.ml.regression.MultipleLinearRegression$$anon$5
            @Override // org.apache.flink.ml.pipeline.FitOperation
            public void fit(MultipleLinearRegression multipleLinearRegression, ParameterMap parameterMap, DataSet<LabeledVector> dataSet) {
                // NOTE(review): assigned below but never read; see the match lowering.
                IterativeSolver iterativeSolver;
                // Combine the model's own parameters with the fit-time parameters via
                // ParameterMap.++. Precedence on key clashes is defined by ParameterMap
                // (presumably the fit-time map wins; confirm).
                ParameterMap $plus$plus = multipleLinearRegression.parameters().$plus$plus(parameterMap);
                // Number of gradient-descent iterations and step size (read with apply).
                int unboxToInt = BoxesRunTime.unboxToInt($plus$plus.apply(MultipleLinearRegression$Iterations$.MODULE$));
                double unboxToDouble = BoxesRunTime.unboxToDouble($plus$plus.apply(MultipleLinearRegression$Stepsize$.MODULE$));
                // Optional convergence threshold (read with get, so it may be None).
                Some some = $plus$plus.get(MultipleLinearRegression$ConvergenceThreshold$.MODULE$);
                // Configure the solver. The loss function is rebuilt here with the same
                // constructor arguments as the lossFunction field, not read from that field.
                SimpleGradientDescent simpleGradientDescent = (SimpleGradientDescent) SimpleGradientDescent$.MODULE$.apply().setIterations(unboxToInt).setStepsize(unboxToDouble).setLossFunction(new GenericLossFunction(SquaredLoss$.MODULE$, LinearPrediction$.MODULE$));
                // Lowered Scala match on the optional threshold. The result of
                // setConvergenceThreshold is stored in iterativeSolver but never used; optimize()
                // is called on simpleGradientDescent. So this relies on the setter configuring
                // that same object (presumably it mutates and returns it; confirm).
                if (some instanceof Some) {
                    iterativeSolver = simpleGradientDescent.setConvergenceThreshold(BoxesRunTime.unboxToDouble(some.x()));
                } else {
                    None$ none$ = None$.MODULE$;
                    if (none$ != null ? !none$.equals(some) : some != null) {
                        throw new MatchError(some);
                    }
                    // NOTE(review): decompiler rendering of the Unit-valued "case None" branch.
                    // BoxedUnit is not an IterativeSolver, so this line is not type-correct Java.
                    iterativeSolver = BoxedUnit.UNIT;
                }
                // Run the optimizer with None$ as its second argument (presumably "no initial
                // weights"; confirm) and record the resulting weights on the model.
                multipleLinearRegression.weightsOption_$eq(new Some(simpleGradientDescent.optimize(dataSet, None$.MODULE$)));
            }
        };
    }
}
