package org.apache.flink.ml.optimization;

import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.scala.DataSet;
import org.apache.flink.ml.common.LabeledVector;
import org.apache.flink.ml.common.WeightVector;
import org.apache.flink.ml.math.Vector;
import org.apache.flink.ml.package$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Some;
import scala.Tuple2;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: GradientDescent.scala */
@ScalaSignature(bytes = "\u0006\u0001]4Q!\u0001\u0002\u0002\u00025\u0011qb\u0012:bI&,g\u000e\u001e#fg\u000e,g\u000e\u001e\u0006\u0003\u0007\u0011\tAb\u001c9uS6L'0\u0019;j_:T!!\u0002\u0004\u0002\u00055d'BA\u0004\t\u0003\u00151G.\u001b8l\u0015\tI!\"\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002\u0017\u0005\u0019qN]4\u0004\u0001M\u0011\u0001A\u0004\t\u0003\u001fAi\u0011AA\u0005\u0003#\t\u0011q\"\u0013;fe\u0006$\u0018N^3T_24XM\u001d\u0005\u0006'\u0001!\t\u0001F\u0001\u0007y%t\u0017\u000e\u001e \u0015\u0003U\u0001\"a\u0004\u0001\t\u000b]\u0001A\u0011\t\r\u0002\u0011=\u0004H/[7ju\u0016$2!G\u0014.!\rQr$I\u0007\u00027)\u0011A$H\u0001\u0006g\u000e\fG.\u0019\u0006\u0003=\u0019\t1!\u00199j\u0013\t\u00013DA\u0004ECR\f7+\u001a;\u0011\u0005\t*S\"A\u0012\u000b\u0005\u0011\"\u0011AB2p[6|g.\u0003\u0002'G\taq+Z5hQR4Vm\u0019;pe\")\u0001F\u0006a\u0001S\u0005!A-\u0019;b!\rQrD\u000b\t\u0003E-J!\u0001L\u0012\u0003\u001b1\u000b'-\u001a7fIZ+7\r^8s\u0011\u0015qc\u00031\u00010\u00039Ig.\u001b;jC2<V-[4iiN\u00042\u0001\r\u001a\u001a\u001b\u0005\t$\"\u0001\u000f\n\u0005M\n$AB(qi&|g\u000eC\u00036\u0001\u0011\u0005a'\u0001\u0011paRLW.\u001b>f/&$\bnQ8om\u0016\u0014x-\u001a8dK\u000e\u0013\u0018\u000e^3sS>tG\u0003C\r8sm\u0002UiR%\t\u000ba\"\u0004\u0019A\u0015\u0002\u0015\u0011\fG/\u0019)pS:$8\u000fC\u0003;i\u0001\u0007\u0011$\u0001\tj]&$\u0018.\u00197XK&<\u0007\u000e^:E'\")A\b\u000ea\u0001{\u0005\u0011b.^7cKJ|e-\u0013;fe\u0006$\u0018n\u001c8t!\t\u0001d(\u0003\u0002@c\t\u0019\u0011J\u001c;\t\u000b\u0005#\u0004\u0019\u0001\"\u0002-I,w-\u001e7be&T\u0018\r^5p]\u000e{gn\u001d;b]R\u0004\"\u0001M\"\n\u0005\u0011\u000b$A\u0002#pk\ndW\rC\u0003Gi\u0001\u0007!)\u0001\u0007mK\u0006\u0014h.\u001b8h%\u0006$X\rC\u0003Ii\u0001\u0007!)\u0001\u000bd_:4XM]4f]\u000e,G\u000b\u001b:fg\"|G\u000e\u001a\u0005\u0006\u0015R\u0002\raS\u0001\rY>\u001c8OR;oGRLwN\u001c\t\u0003\u001f1K!!\u0014\u0002\u0003\u00191{7o\u001d$v]\u000e$\u0018n\u001c8\t\u000b=\u0003A\u0011\u0001)\u0002G=\u0004H/[7ju\u0016<\u0016\u000e\u001e5pkR\u001c
uN\u001c<fe\u001e,gnY3De&$XM]5p]R9\u0011$\u0015*T)V3\u0006\"\u0002\u0015O\u0001\u0004I\u0003\"\u0002\u001eO\u0001\u0004I\u0002\"\u0002\u001fO\u0001\u0004i\u0004\"B!O\u0001\u0004\u0011\u0005\"\u0002$O\u0001\u0004\u0011\u0005\"\u0002&O\u0001\u0004Y\u0005\"\u0002-\u0001\t\u0013I\u0016aB*H\tN#X\r\u001d\u000b\u00073i[VLX0\t\u000b!:\u0006\u0019A\u0015\t\u000bq;\u0006\u0019A\r\u0002\u001d\r,(O]3oi^+\u0017n\u001a5ug\")!j\u0016a\u0001\u0017\")\u0011i\u0016a\u0001\u0005\")ai\u0016a\u0001\u0005\")\u0011\r\u0001D\u0001E\u0006AA/Y6f'R,\u0007\u000fF\u0003dS.lg\u000e\u0005\u0002eO6\tQM\u0003\u0002g\t\u0005!Q.\u0019;i\u0013\tAWM\u0001\u0004WK\u000e$xN\u001d\u0005\u0006U\u0002\u0004\raY\u0001\ro\u0016Lw\r\u001b;WK\u000e$xN\u001d\u0005\u0006Y\u0002\u0004\raY\u0001\tOJ\fG-[3oi\")\u0011\t\u0019a\u0001\u0005\")a\t\u0019a\u0001\u0005\")\u0001\u000f\u0001C\u0005c\u0006i1-\u00197dk2\fG/\u001a'pgN$BA]:umB\u0019!d\b\"\t\u000b!z\u0007\u0019A\u0015\t\u000bU|\u0007\u0019A\r\u0002\u0011],\u0017n\u001a5u\tNCQAS8A\u0002-\u0003")
/* loaded from: input_file:org/apache/flink/ml/optimization/GradientDescent.class */
/*
 * Base class for gradient-descent solvers. The update rule itself is supplied by subclasses via
 * the abstract takeStep method; this class wires the Flink DataSet iteration around it.
 *
 * NOTE: this is decompiled Scala. Methods named org$apache$flink$ml$optimization$GradientDescent$$...
 * are Scala "expanded" names of private members and are presumably invoked from the compiled
 * closure classes (GradientDescent$$anonfun$...), so they must not be renamed.
 */
public abstract class GradientDescent extends IterativeSolver {

    /**
     * Runs gradient descent over {@code dataSet} using this solver's configured parameters.
     *
     * <p>Reads the iteration count, the optional convergence threshold, the loss function, the
     * learning rate and the regularization constant from {@code parameters()}. If a threshold is
     * configured ({@code Some}), delegates to {@link #optimizeWithConvergenceCriterion};
     * if it is absent ({@code None}), delegates to {@link #optimizeWithoutConvergenceCriterion}.
     *
     * @param dataSet the labeled training data
     * @param option optional initial weights, forwarded to {@code createInitialWeightsDS}
     * @return the data set produced by the selected optimization variant
     * @throws MatchError if the convergence-threshold lookup yields neither {@code Some} nor
     *     {@code None} (e.g. {@code null}), mirroring the original Scala pattern match
     */
    @Override // org.apache.flink.ml.optimization.Solver
    public DataSet<WeightVector> optimize(DataSet<LabeledVector> dataSet, Option<DataSet<WeightVector>> option) {
        // Parameters are read in the same order as the compiled Scala code.
        int iterations = BoxesRunTime.unboxToInt(parameters().apply(IterativeSolver$Iterations$.MODULE$));
        // The lookup may yield None as well as Some, so it must be held as an Option. The
        // decompiled code typed this local as Some, which does not type-check.
        Option<?> convergenceThreshold = parameters().get(IterativeSolver$ConvergenceThreshold$.MODULE$);
        LossFunction lossFunction = (LossFunction) parameters().apply(Solver$LossFunction$.MODULE$);
        double learningRate = BoxesRunTime.unboxToDouble(parameters().apply(IterativeSolver$LearningRate$.MODULE$));
        double regularizationConstant = BoxesRunTime.unboxToDouble(parameters().apply(Solver$RegularizationConstant$.MODULE$));
        DataSet<WeightVector> initialWeights = createInitialWeightsDS(option, dataSet);

        if (convergenceThreshold instanceof Some) {
            double threshold = BoxesRunTime.unboxToDouble(((Some<?>) convergenceThreshold).x());
            return optimizeWithConvergenceCriterion(
                    dataSet, initialWeights, iterations, regularizationConstant, learningRate, threshold, lossFunction);
        }
        if (None$.MODULE$.equals(convergenceThreshold)) {
            return optimizeWithoutConvergenceCriterion(
                    dataSet, initialWeights, iterations, regularizationConstant, learningRate, lossFunction);
        }
        // Neither Some nor None (only possible for a null lookup result): same failure as the
        // non-exhaustive Scala match in the original.
        throw new MatchError(convergenceThreshold);
    }

    /**
     * Iterates gradient descent with an early-termination criterion.
     *
     * <p>Computes the loss of the initial weights with {@code calculateLoss}. It then pairs the
     * weights with that loss (closure {@code GradientDescent$$anonfun$2}, with the loss as a
     * broadcast variable), yielding {@code Tuple2} records. Next it runs
     * {@code iterateWithTermination} for at most {@code i} iterations with the step closure
     * {@code GradientDescent$$anonfun$3}. That closure captures the data, {@code d}, {@code d2},
     * {@code d3} and the loss function; its termination logic is not visible here. Finally it
     * maps the result back to {@link WeightVector}.
     *
     * @param dataSet the labeled training data
     * @param dataSet2 the initial weights
     * @param i maximum number of iterations
     * @param d regularization constant (as passed by {@link #optimize})
     * @param d2 learning rate (as passed by {@link #optimize})
     * @param d3 convergence threshold (as passed by {@link #optimize})
     * @param lossFunction loss function used for the loss and gradient computations
     * @return the data set of weights after the iteration ends
     */
    public DataSet<WeightVector> optimizeWithConvergenceCriterion(DataSet<LabeledVector> dataSet, DataSet<WeightVector> dataSet2, int i, double d, double d2, double d3, LossFunction lossFunction) {
        return package$.MODULE$.RichDataSet(dataSet2)
                .mapWithBcVariable(
                        org$apache$flink$ml$optimization$GradientDescent$$calculateLoss(dataSet, dataSet2, lossFunction),
                        new GradientDescent$$anonfun$2(this),
                        new GradientDescent$$anon$14(this),
                        ClassTag$.MODULE$.apply(Tuple2.class))
                .iterateWithTermination(i, new GradientDescent$$anonfun$3(this, dataSet, d, d2, d3, lossFunction))
                .map(
                        new GradientDescent$$anonfun$optimizeWithConvergenceCriterion$1(this),
                        new GradientDescent$$anon$21(this),
                        ClassTag$.MODULE$.apply(WeightVector.class));
    }

    /**
     * Iterates gradient descent for a fixed number of iterations with no early termination.
     *
     * <p>Runs {@code dataSet2.iterate(i, ...)} with the step closure
     * {@code GradientDescent$$anonfun$optimizeWithoutConvergenceCriterion$1}. That closure
     * captures the data, {@code d}, {@code d2} and the loss function; presumably it calls
     * {@code SGDStep} each round (TODO confirm).
     *
     * @param dataSet the labeled training data
     * @param dataSet2 the initial weights
     * @param i number of iterations
     * @param d regularization constant (as passed by {@link #optimize})
     * @param d2 learning rate (as passed by {@link #optimize})
     * @param lossFunction loss function used for the gradient computation
     * @return the data set of weights after {@code i} iterations
     */
    public DataSet<WeightVector> optimizeWithoutConvergenceCriterion(DataSet<LabeledVector> dataSet, DataSet<WeightVector> dataSet2, int i, double d, double d2, LossFunction lossFunction) {
        return dataSet2.iterate(i, new GradientDescent$$anonfun$optimizeWithoutConvergenceCriterion$1(this, dataSet, d, d2, lossFunction));
    }

    /**
     * Performs one gradient-descent step (Scala-private {@code SGDStep}).
     *
     * <p>The step has three stages:
     * <ol>
     *   <li>Maps every data point with the current weights {@code dataSet2} as a broadcast
     *       variable (closure {@code ...SGDStep$1}, capturing the loss function), producing a
     *       {@code Tuple2} per point.</li>
     *   <li>Reduces those tuples with {@code ...SGDStep$2}.</li>
     *   <li>Combines the reduced value with the current weights via
     *       {@code mapWithBcVariableIteration} and {@code ...SGDStep$3}, capturing {@code d} and
     *       {@code d2}.</li>
     * </ol>
     * NOTE(review): presumably the tuple is (summed gradient, count) and {@code ...SGDStep$3}
     * calls {@link #takeStep}. Confirm against the closure classes.
     *
     * @param dataSet the labeled training data
     * @param dataSet2 the current weights
     * @param lossFunction loss function used for the per-point computation
     * @param d presumably the regularization constant (TODO confirm)
     * @param d2 presumably the learning rate (TODO confirm)
     * @return the data set holding the updated weights
     */
    public DataSet<WeightVector> org$apache$flink$ml$optimization$GradientDescent$$SGDStep(DataSet<LabeledVector> dataSet, DataSet<WeightVector> dataSet2, LossFunction lossFunction, double d, double d2) {
        return package$.MODULE$.RichDataSet(
                        package$.MODULE$.RichDataSet(dataSet)
                                .mapWithBcVariable(
                                        dataSet2,
                                        new GradientDescent$$anonfun$org$apache$flink$ml$optimization$GradientDescent$$SGDStep$1(this, lossFunction),
                                        new GradientDescent$$anon$22(this),
                                        ClassTag$.MODULE$.apply(Tuple2.class))
                                .reduce(new GradientDescent$$anonfun$org$apache$flink$ml$optimization$GradientDescent$$SGDStep$2(this)))
                .mapWithBcVariableIteration(
                        dataSet2,
                        new GradientDescent$$anonfun$org$apache$flink$ml$optimization$GradientDescent$$SGDStep$3(this, d, d2),
                        new GradientDescent$$anon$25(this),
                        ClassTag$.MODULE$.apply(WeightVector.class));
    }

    /**
     * Computes an updated weight vector. Subclasses define the concrete update rule.
     *
     * <p>NOTE(review): argument meanings are not visible here. Presumably they are (current
     * weights, gradient, regularization constant, learning rate); confirm against the
     * implementations.
     *
     * @param vector presumably the current weight vector
     * @param vector2 presumably the gradient
     * @param d presumably the regularization constant
     * @param d2 presumably the learning rate
     * @return the new weight vector
     */
    public abstract Vector takeStep(Vector vector, Vector vector2, double d, double d2);

    /**
     * Computes a scalar loss for the given weights (Scala-private {@code calculateLoss}).
     *
     * <p>Maps every data point with {@code dataSet2} as a broadcast variable (closure
     * {@code ...calculateLoss$1}, capturing the loss function), producing a {@code Tuple2} per
     * point. These are reduced with {@code ...calculateLoss$2} and mapped to a {@code double}
     * with {@code ...calculateLoss$3}. The element type is {@code Object} because the Scala
     * {@code Double} is boxed under erasure. Presumably the result is the average loss
     * (TODO confirm against the closures).
     *
     * @param dataSet the labeled training data
     * @param dataSet2 the weights to evaluate
     * @param lossFunction loss function applied per data point
     * @return a data set holding the (boxed) loss value
     */
    public DataSet<Object> org$apache$flink$ml$optimization$GradientDescent$$calculateLoss(DataSet<LabeledVector> dataSet, DataSet<WeightVector> dataSet2, LossFunction lossFunction) {
        return package$.MODULE$.RichDataSet(dataSet)
                .mapWithBcVariable(
                        dataSet2,
                        new GradientDescent$$anonfun$org$apache$flink$ml$optimization$GradientDescent$$calculateLoss$1(this, lossFunction),
                        new GradientDescent$$anon$26(this),
                        ClassTag$.MODULE$.apply(Tuple2.class))
                .reduce(new GradientDescent$$anonfun$org$apache$flink$ml$optimization$GradientDescent$$calculateLoss$2(this))
                .map(
                        new GradientDescent$$anonfun$org$apache$flink$ml$optimization$GradientDescent$$calculateLoss$3(this),
                        BasicTypeInfo.getInfoFor(Double.TYPE),
                        ClassTag$.MODULE$.Double());
    }
}
