package org.apache.flink.ml.regression;

import breeze.linalg.Vector$;
import org.apache.flink.api.scala.DataSet;
import org.apache.flink.ml.common.LabeledVector;
import org.apache.flink.ml.common.ParameterMap;
import org.apache.flink.ml.common.WeightVector;
import org.apache.flink.ml.math.Breeze$;
import org.apache.flink.ml.math.Vector;
import org.apache.flink.ml.optimization.GenericLossFunction;
import org.apache.flink.ml.optimization.GradientDescent;
import org.apache.flink.ml.optimization.GradientDescent$;
import org.apache.flink.ml.optimization.IterativeSolver;
import org.apache.flink.ml.optimization.LearningRateMethod;
import org.apache.flink.ml.optimization.LinearPrediction$;
import org.apache.flink.ml.optimization.SquaredLoss$;
import org.apache.flink.ml.pipeline.FitOperation;
import org.apache.flink.ml.pipeline.PredictOperation;
import scala.MatchError;
import scala.None$;
import scala.Some;
import scala.Tuple2;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: MultipleLinearRegression.scala */
/* loaded from: input_file:org/apache/flink/ml/regression/MultipleLinearRegression$.class */
/**
 * Singleton holder for the static side of {@link MultipleLinearRegression}
 * (the Java view of the Scala companion object).
 *
 * <p>It exposes the broadcast-variable name, the loss function, a factory for new
 * {@link MultipleLinearRegression} instances, and the fit / predict operations.
 *
 * <p>The single instance is reachable through {@link #MODULE$}.
 */
public final class MultipleLinearRegression$ {
    /**
     * The one and only instance. It is initialized directly here because a {@code static final}
     * field cannot be assigned from inside the constructor in Java source.
     */
    public static final MultipleLinearRegression$ MODULE$ = new MultipleLinearRegression$();

    private final String WEIGHTVECTOR_BROADCAST;
    private final GenericLossFunction lossFunction;
    private final Object fitMLR;

    /** Returns the constant {@code "weights_broadcast"}. */
    public String WEIGHTVECTOR_BROADCAST() {
        return this.WEIGHTVECTOR_BROADCAST;
    }

    /** Returns the loss function built from {@code SquaredLoss} and {@code LinearPrediction}. */
    public GenericLossFunction lossFunction() {
        return this.lossFunction;
    }

    /** Creates a new, unfitted {@link MultipleLinearRegression}. */
    public MultipleLinearRegression apply() {
        return new MultipleLinearRegression();
    }

    /**
     * Returns the fit operation, a {@code FitOperation<MultipleLinearRegression, LabeledVector>}.
     * The declared type stays {@code Object} to keep the existing signature.
     */
    public Object fitMLR() {
        return this.fitMLR;
    }

    /**
     * Creates a prediction operation for feature vectors of type {@code T}.
     *
     * @param <T> the concrete {@link Vector} type of the inputs
     * @return a {@code PredictOperation<MultipleLinearRegression, WeightVector, T, Object>};
     *     the declared type stays {@code Object} to keep the existing signature
     */
    public <T extends Vector> Object predictVectors() {
        return new PredictOperation<MultipleLinearRegression, WeightVector, T, Object>() { // from class: org.apache.flink.ml.regression.MultipleLinearRegression$$anon$1
            /**
             * Returns the weight data set stored in the instance.
             *
             * @throws RuntimeException if the instance holds no weights yet (its
             *     {@code weightsOption} is {@code None})
             */
            @Override // org.apache.flink.ml.pipeline.PredictOperation
            @SuppressWarnings("unchecked") // the payload of weightsOption is the weights data set
            public DataSet<WeightVector> getModel(MultipleLinearRegression instance, ParameterMap predictParameters) {
                // Typed as Object so that both the Some and the None case are representable.
                Object storedWeights = instance.weightsOption();
                if (storedWeights instanceof Some) {
                    return (DataSet<WeightVector>) ((Some<?>) storedWeights).x();
                }
                if (None$.MODULE$.equals(storedWeights)) {
                    throw new RuntimeException("The MultipleLinearRegression has not been fitted to the data. This is necessary to learn the weight vector of the linear function.");
                }
                throw new MatchError(storedWeights);
            }

            /**
             * Computes {@code value . weights + intercept} and returns it as a boxed double.
             *
             * @throws MatchError if {@code model} is {@code null}
             */
            @Override // org.apache.flink.ml.pipeline.PredictOperation
            public Object predict(T value, WeightVector model) {
                if (model == null) {
                    throw new MatchError(model);
                }
                Vector weights = model.weights();
                double intercept = model.intercept();
                double dotProduct = BoxesRunTime.unboxToDouble(
                        Breeze$.MODULE$.Vector2BreezeConverter(value).asBreeze().dot(
                                Breeze$.MODULE$.Vector2BreezeConverter(weights).asBreeze(),
                                Vector$.MODULE$.canDot_V_V_Double()));
                return BoxesRunTime.boxToDouble(dotProduct + intercept);
            }
        };
    }

    private MultipleLinearRegression$() {
        this.WEIGHTVECTOR_BROADCAST = "weights_broadcast";
        this.lossFunction = new GenericLossFunction(SquaredLoss$.MODULE$, LinearPrediction$.MODULE$);
        this.fitMLR = new FitOperation<MultipleLinearRegression, LabeledVector>() { // from class: org.apache.flink.ml.regression.MultipleLinearRegression$$anon$2
            /**
             * Trains the instance with gradient descent and stores the result in its
             * {@code weightsOption}.
             *
             * <p>The instance parameters are combined with {@code fitParameters} via {@code ++}.
             * Iterations and step size are read unconditionally; the convergence threshold and
             * the learning-rate method are applied only when present.
             *
             * @param instance the regression to fit; its {@code weightsOption} is overwritten
             * @param fitParameters parameters combined with those of the instance
             * @param input the labeled training data
             * @throws MatchError if an optional parameter is neither {@code Some} nor {@code None}
             */
            @Override // org.apache.flink.ml.pipeline.FitOperation
            public void fit(MultipleLinearRegression instance, ParameterMap fitParameters, DataSet<LabeledVector> input) {
                ParameterMap resolved = instance.parameters().$plus$plus(fitParameters);
                int iterations = BoxesRunTime.unboxToInt(resolved.apply(MultipleLinearRegression$Iterations$.MODULE$));
                double stepsize = BoxesRunTime.unboxToDouble(resolved.apply(MultipleLinearRegression$Stepsize$.MODULE$));
                // Typed as Object: each of these is either a Some or None$.
                Object convergenceThreshold = resolved.get(MultipleLinearRegression$ConvergenceThreshold$.MODULE$);
                Object learningRateMethod = resolved.get(MultipleLinearRegression$LearningRateMethodValue$.MODULE$);

                GradientDescent optimizer = (GradientDescent) GradientDescent$.MODULE$.apply()
                        .setIterations(iterations)
                        .setStepsize(stepsize)
                        .setLossFunction(new GenericLossFunction(SquaredLoss$.MODULE$, LinearPrediction$.MODULE$));

                // The setters' return values are deliberately unused, as in the original code;
                // the `optimizer` reference itself is what is used for optimization below.
                if (convergenceThreshold instanceof Some) {
                    optimizer.setConvergenceThreshold(
                            BoxesRunTime.unboxToDouble(((Some<?>) convergenceThreshold).x()));
                } else if (!None$.MODULE$.equals(convergenceThreshold)) {
                    throw new MatchError(convergenceThreshold);
                }

                if (learningRateMethod instanceof Some) {
                    optimizer.setLearningRateMethod(
                            (LearningRateMethod.LearningRateMethodTrait) ((Some<?>) learningRateMethod).x());
                } else if (!None$.MODULE$.equals(learningRateMethod)) {
                    throw new MatchError(learningRateMethod);
                }

                instance.weightsOption_$eq(new Some(optimizer.optimize(input, None$.MODULE$)));
            }
        };
    }
}
