package org.apache.wayang.apps.sgd;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.wayang.api.DataQuantaBuilder;
import org.apache.wayang.api.JavaPlanBuilder;
import org.apache.wayang.commons.util.profiledb.model.Experiment;
import org.apache.wayang.core.api.Configuration;
import org.apache.wayang.core.api.WayangContext;
import org.apache.wayang.core.plugin.Plugin;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.core.util.WayangCollections;

/* loaded from: input_file:org/apache/wayang/apps/sgd/SGDImpl.class */
/**
 * Builds and runs an iterative gradient-descent plan on Wayang.
 *
 * <p>Each call to {@code apply} creates a fresh {@link WayangContext} from the stored
 * configuration and plugins. It loads an all-zero weight vector and then repeatedly samples
 * the transformed input, computes gradients ({@code ComputeLogisticGradient}), sums them
 * ({@code Sum}) and updates the weights ({@code WeightsUpdate}). The loop runs until
 * {@code LoopCondition} stops it.
 */
public class SGDImpl {

    /** Configuration used to create a new {@link WayangContext} for every run. */
    private final Configuration configuration;

    /** Plugins registered with the {@link WayangContext} of every run (unmodifiable). */
    private final List<Plugin> plugins;

    /**
     * Creates a new SGD runner.
     *
     * @param configuration the Wayang configuration for the execution contexts
     * @param plugins       the plugins to register on each context; the array is copied, so
     *                      later modifications by the caller do not affect this instance
     */
    public SGDImpl(Configuration configuration, Plugin[] plugins) {
        this.configuration = configuration;
        // Copy first: Arrays.asList alone would be a live view backed by the caller's array.
        this.plugins = Collections.unmodifiableList(Arrays.asList(plugins.clone()));
    }

    /**
     * Runs the SGD plan without an experiment attached.
     *
     * @see #apply(String, int, int, int, double, int, Experiment)
     */
    public double[] apply(String datasetUrl, int datasetSize, int features, int maxIterations,
                          double accuracy, int sampleSize) {
        return apply(datasetUrl, datasetSize, features, maxIterations, accuracy, sampleSize, null);
    }

    /**
     * Builds and executes the SGD plan.
     *
     * @param datasetUrl    location of the text input, read via {@code readTextFile}
     * @param datasetSize   size hint passed to the sampling operator ({@code withDatasetSize})
     * @param features      length of the weight vector; also passed to {@code Transform}
     * @param maxIterations passed to {@code LoopCondition} and used as the expected number of
     *                      loop iterations
     * @param accuracy      passed to {@code LoopCondition} (presumably the convergence
     *                      threshold — confirm against {@code LoopCondition})
     * @param sampleSize    number of data points sampled in each iteration
     * @param experiment    optional experiment to attach to the plan; ignored when {@code null}
     * @return the single weight vector produced by the loop, as returned by
     *         {@link WayangCollections#getSingleOrNull} (so possibly {@code null})
     */
    public double[] apply(String datasetUrl, int datasetSize, int features, int maxIterations,
                          double accuracy, int sampleSize, Experiment experiment) {
        WayangContext wayangContext = new WayangContext(this.configuration);
        for (Plugin plugin : this.plugins) {
            wayangContext.withPlugin(plugin);
        }

        JavaPlanBuilder planBuilder = new JavaPlanBuilder(wayangContext);
        if (experiment != null) {
            planBuilder.withExperiment(experiment);
        }
        planBuilder.withUdfJarOf(getClass());

        // Exactly one all-zero weight vector seeds the loop; singletonList makes the element
        // type double[] explicit (a double[] cannot itself hold a double[]).
        List<double[]> initialWeights = Collections.singletonList(new double[features]);
        DataQuantaBuilder<?, double[]> weightsBuilder = planBuilder
                .loadCollection(initialWeights).withName("init weights");

        // Parse each input line into a feature array once, outside the loop body.
        DataQuantaBuilder<?, double[]> pointsBuilder = planBuilder
                .readTextFile(datasetUrl).withName("source")
                .map(new Transform(features)).withName("transform");

        return WayangCollections.getSingleOrNull(weightsBuilder
                .doWhile(new LoopCondition(accuracy, maxIterations), weights -> {
                    // The current weights are broadcast to every operator that needs them.
                    DataQuantaBuilder<?, double[]> newWeights = pointsBuilder
                            .sample(sampleSize).withDatasetSize(datasetSize)
                            .withBroadcast(weights, "weights")
                            .map(new ComputeLogisticGradient()).withBroadcast(weights, "weights")
                            .withName("compute")
                            .reduce(new Sum()).withName("reduce")
                            .map(new WeightsUpdate()).withBroadcast(weights, "weights")
                            .withName("update");
                    // First element: next loop state. Second element: ComputeNorm output,
                    // presumably what LoopCondition evaluates to decide termination.
                    return new Tuple<>(newWeights,
                            newWeights.map(new ComputeNorm()).withBroadcast(weights, "weights"));
                })
                .withExpectedNumberOfIterations(maxIterations)
                .collect());
    }
}
