package org.apache.flink.test.broadcastvars;

import java.io.BufferedReader;
import java.util.Collection;
import java.util.Random;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
import org.apache.flink.api.java.record.functions.MapFunction;
import org.apache.flink.api.java.record.io.CsvInputFormat;
import org.apache.flink.api.java.record.io.CsvOutputFormat;
import org.apache.flink.api.java.typeutils.runtime.record.RecordSerializerFactory;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.io.network.channels.ChannelType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobGraphDefinitionException;
import org.apache.flink.runtime.jobgraph.JobInputVertex;
import org.apache.flink.runtime.jobgraph.JobOutputVertex;
import org.apache.flink.runtime.jobgraph.JobTaskVertex;
import org.apache.flink.runtime.operators.CollectorMapDriver;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.RegularPactTask;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.runtime.operators.util.LocalStrategy;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.test.iterative.nephele.JobGraphUtils;
import org.apache.flink.test.util.RecordAPITestBase;
import org.apache.flink.types.LongValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
import org.junit.Assert;

/**
 * Integration test for broadcast variables on a hand-built Nephele {@link JobGraph}.
 *
 * <p>{@link #preSubmit()} writes two text inputs: {@code NUM_POINTS} points and {@code NUM_MODELS}
 * models. Each line holds an id followed by {@code NUM_FEATURES} non-negative integer features.
 *
 * <p>The job works as follows:
 * <ul>
 *   <li>The points are forwarded to the {@link DotProducts} mapper.</li>
 *   <li>The models are shipped to the same mapper as the broadcast variable {@code "models"}.</li>
 *   <li>The mapper emits one {@code modelId pointId dotProduct} line per (point, model) pair.</li>
 * </ul>
 *
 * <p>{@link #postSubmit()} regenerates the same data from the fixed seeds. It asserts that every pair
 * appears exactly once and carries the correct dot product.
 */
public class BroadcastVarsNepheleITCase extends RecordAPITestBase {
    private static final long SEED_POINTS = 3287269182979823L;
    private static final long SEED_MODELS = 1004078042382130L;
    private static final int NUM_POINTS = 10000;
    private static final int NUM_MODELS = 42;
    private static final int NUM_FEATURES = 3;
    private static final int DOP = 4;

    /** Exclusive upper bound of generated point feature values. */
    private static final int POINT_VALUE_BOUND = 1000;
    /** Exclusive upper bound of generated model feature values. */
    private static final int MODEL_VALUE_BOUND = 100;
    /** Largest point count accepted by {@link #getInputPoints(int, int, long)}. */
    private static final int MAX_POINTS = 1000000;
    /** Largest model count accepted by {@link #getInputModels(int, int, long)}. */
    private static final int MAX_MODELS = 100;
    /** The RNG for the record with 1-based id {@code id} is seeded with {@code seed + SEED_STRIDE * id}. */
    private static final long SEED_STRIDE = 1000L;
    /** Rough number of characters per written field, used only to presize the StringBuilder. */
    private static final int ESTIMATED_CHARS_PER_FIELD = 3;
    /** Expected shape of an output line: model id, point id, dot product. */
    private static final Pattern RESULT_LINE = Pattern.compile("(\\d+) (\\d+) (\\d+)");

    protected String pointsPath;
    protected String modelsPath;
    protected String resultPath;

    /**
     * Computes the dot product of each incoming point with every model in the broadcast variable
     * {@code "models"}.
     *
     * <p>Input records (points and models alike) are read as follows:
     * <ul>
     *   <li>field 0 is the id;</li>
     *   <li>fields {@code 1..NUM_FEATURES} are the features.</li>
     * </ul>
     *
     * <p>Each output record is laid out as follows:
     * <ul>
     *   <li>field 0 is the model id;</li>
     *   <li>field 1 is the point id;</li>
     *   <li>field 2 is the dot product.</li>
     * </ul>
     */
    public static final class DotProducts extends MapFunction {
        private static final long serialVersionUID = 1L;

        /**
         * Output record. It is reused for every emitted pair. NOTE(review): this assumes the collector
         * serializes or copies the record before the next {@code collect} call.
         */
        private final Record result = new Record(3);
        private final LongValue lft = new LongValue();
        private final LongValue rgt = new LongValue();
        private final LongValue prd = new LongValue();

        // Field-position arrays for Record.copyFrom. They are hoisted here so that map() does not
        // allocate new arrays for every emitted pair.
        private final int[] idField = {0};
        private final int[] modelIdTarget = {0};
        private final int[] pointIdTarget = {1};

        private Collection<Record> models;

        /** Fetches the broadcast model set once per task. */
        @Override
        public void open(Configuration parameters) throws Exception {
            this.models = getRuntimeContext().getBroadcastVariable("models");
        }

        /**
         * Emits one {@code (modelId, pointId, dotProduct)} record per broadcast model.
         *
         * @param point the incoming point record
         * @param out collector receiving the result records
         */
        @Override
        public void map(Record point, Collector<Record> out) throws Exception {
            for (Record model : this.models) {
                long dotProduct = 0;
                // Field 0 holds the id, so the features start at field 1.
                for (int i = 1; i <= NUM_FEATURES; i++) {
                    dotProduct += model.getField(i, this.lft).getValue() * point.getField(i, this.rgt).getValue();
                }
                this.prd.setValue(dotProduct);
                this.result.copyFrom(model, this.idField, this.modelIdTarget);
                this.result.copyFrom(point, this.idField, this.pointIdTarget);
                this.result.setField(2, this.prd);
                out.collect(this.result);
            }
        }
    }

    public BroadcastVarsNepheleITCase() {
        setTaskManagerNumSlots(DOP);
    }

    /**
     * Generates the text for the points input.
     *
     * <p>Each line looks like {@code "id f1 f2 ... fn \n"}, with 1-based ids and features in
     * {@code [0, 1000)}. The line for a given id depends only on {@code seed} and that id.
     *
     * @param numPoints number of points, in {@code [1, 1000000]}
     * @param numDimensions number of features per point, must be non-negative
     * @param seed base RNG seed
     * @return the generated text
     * @throws IllegalArgumentException if an argument is out of range
     */
    public static final String getInputPoints(int numPoints, int numDimensions, long seed) {
        checkArguments(numPoints, MAX_POINTS, numDimensions);
        return formatRecords(generateFeatures(numPoints, numDimensions, seed, POINT_VALUE_BOUND), numDimensions);
    }

    /**
     * Generates the text for the models input.
     *
     * <p>The format is the same as for {@link #getInputPoints(int, int, long)}, but features lie in
     * {@code [0, 100)}.
     *
     * @param numModels number of models, in {@code [1, 100]}
     * @param numDimensions number of features per model, must be non-negative
     * @param seed base RNG seed
     * @return the generated text
     * @throws IllegalArgumentException if an argument is out of range
     */
    public static final String getInputModels(int numModels, int numDimensions, long seed) {
        checkArguments(numModels, MAX_MODELS, numDimensions);
        return formatRecords(generateFeatures(numModels, numDimensions, seed, MODEL_VALUE_BOUND), numDimensions);
    }

    /** Validates the record count and dimension count passed to the input generators. */
    private static void checkArguments(int numRecords, int maxRecords, int numDimensions) {
        if (numRecords < 1 || numRecords > maxRecords) {
            throw new IllegalArgumentException("Number of records must be in [1, " + maxRecords + "], was " + numRecords);
        }
        if (numDimensions < 0) {
            throw new IllegalArgumentException("Number of dimensions must be non-negative, was " + numDimensions);
        }
    }

    /**
     * Generates the features of {@code numRecords} records. Row {@code r} belongs to the record with
     * id {@code r + 1}.
     *
     * <p>This is the single source of truth shared by the input writer and by the result verifier.
     */
    private static int[][] generateFeatures(int numRecords, int numDimensions, long seed, int valueBound) {
        int[][] features = new int[numRecords][numDimensions];
        Random random = new Random();
        for (int r = 0; r < numRecords; r++) {
            // Re-seed per record, so that a record's values depend only on (seed, id) and not on the
            // order in which records are generated.
            random.setSeed(seed + SEED_STRIDE * (r + 1));
            for (int d = 0; d < numDimensions; d++) {
                features[r][d] = random.nextInt(valueBound);
            }
        }
        return features;
    }

    /**
     * Renders the records as lines of the form {@code "id f1 ... fn \n"}. Each field is followed by a
     * single space, so every line ends with a trailing space.
     */
    private static String formatRecords(int[][] features, int numDimensions) {
        long estimate = (long) ESTIMATED_CHARS_PER_FIELD * (1 + numDimensions) * features.length;
        StringBuilder sb = new StringBuilder((int) Math.min(1 << 24, estimate));
        for (int r = 0; r < features.length; r++) {
            sb.append(r + 1).append(' ');
            for (int value : features[r]) {
                sb.append(value).append(' ');
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    @Override
    protected void preSubmit() throws Exception {
        this.pointsPath = createTempFile("points.txt", getInputPoints(NUM_POINTS, NUM_FEATURES, SEED_POINTS));
        this.modelsPath = createTempFile("models.txt", getInputModels(NUM_MODELS, NUM_FEATURES, SEED_MODELS));
        this.resultPath = getTempFilePath("results");
    }

    @Override
    protected JobGraph getJobGraph() throws Exception {
        return createJobGraphV1(this.pointsPath, this.modelsPath, this.resultPath, DOP);
    }

    /**
     * Verifies the job output.
     *
     * <p>Every line must match {@code modelId pointId dotProduct} and name a known pair. The dot
     * product must be correct, no pair may appear twice, and every pair must appear.
     */
    @Override
    protected void postSubmit() throws Exception {
        long[][] expected = computeExpectedDotProducts();
        boolean[][] seen = new boolean[NUM_POINTS][NUM_MODELS];

        for (BufferedReader reader : getResultReader(this.resultPath)) {
            try {
                String line;
                // Stop at end-of-file. The previous while(true) loop never terminated here.
                while ((line = reader.readLine()) != null) {
                    Matcher matcher = RESULT_LINE.matcher(line);
                    Assert.assertTrue("Malformed result line: '" + line + "'", matcher.matches());

                    int modelId = Integer.parseInt(matcher.group(1));
                    int pointId = Integer.parseInt(matcher.group(2));
                    long dotProduct = Long.parseLong(matcher.group(3));

                    Assert.assertTrue("Unknown point id in line '" + line + "'", pointId >= 1 && pointId <= NUM_POINTS);
                    Assert.assertTrue("Unknown model id in line '" + line + "'", modelId >= 1 && modelId <= NUM_MODELS);

                    int p = pointId - 1;
                    int m = modelId - 1;
                    Assert.assertFalse("Dot product for record (" + pointId + ", " + modelId + ") occurs more than once", seen[p][m]);
                    Assert.assertEquals(String.format("Bad product for (%04d, %04d)", pointId, modelId), expected[p][m], dotProduct);
                    seen[p][m] = true;
                }
            } finally {
                reader.close();
            }
        }

        for (int p = 0; p < NUM_POINTS; p++) {
            for (int m = 0; m < NUM_MODELS; m++) {
                Assert.assertTrue("Dot product for record (" + (p + 1) + ", " + (m + 1) + ") does not occur", seen[p][m]);
            }
        }
    }

    /**
     * Returns the expected dot product for every (point, model) pair, indexed {@code [pointId - 1][modelId - 1]}.
     * It uses the same generator as the input files.
     */
    private static long[][] computeExpectedDotProducts() {
        int[][] points = generateFeatures(NUM_POINTS, NUM_FEATURES, SEED_POINTS, POINT_VALUE_BOUND);
        int[][] models = generateFeatures(NUM_MODELS, NUM_FEATURES, SEED_MODELS, MODEL_VALUE_BOUND);
        long[][] expected = new long[NUM_POINTS][NUM_MODELS];
        for (int p = 0; p < NUM_POINTS; p++) {
            for (int m = 0; m < NUM_MODELS; m++) {
                long sum = 0;
                for (int d = 0; d < NUM_FEATURES; d++) {
                    sum += (long) points[p][d] * models[m][d];
                }
                expected[p][m] = sum;
            }
        }
        return expected;
    }

    /** Creates the points input vertex. Its output is shipped with FORWARD. */
    private static JobInputVertex createPointsInput(JobGraph jobGraph, String pointsPath, int numSubTasks, TypeSerializerFactory<?> serializer) {
        return createLongCsvInput(jobGraph, pointsPath, "Input[Points]", numSubTasks, ShipStrategyType.FORWARD, serializer);
    }

    /** Creates the models input vertex. Its output is shipped with BROADCAST. */
    private static JobInputVertex createModelsInput(JobGraph jobGraph, String modelsPath, int numSubTasks, TypeSerializerFactory<?> serializer) {
        return createLongCsvInput(jobGraph, modelsPath, "Input[Models]", numSubTasks, ShipStrategyType.BROADCAST, serializer);
    }

    /**
     * Creates a space-delimited CSV input vertex that reads four {@link LongValue} fields (the id plus
     * three features). The vertex uses the given output ship strategy and serializer.
     */
    private static JobInputVertex createLongCsvInput(JobGraph jobGraph, String path, String name, int numSubTasks,
            ShipStrategyType shipStrategy, TypeSerializerFactory<?> serializer) {
        CsvInputFormat format = new CsvInputFormat(' ', new Class[]{LongValue.class, LongValue.class, LongValue.class, LongValue.class});
        JobInputVertex input = JobGraphUtils.createInput(format, path, name, jobGraph, numSubTasks);
        TaskConfig config = new TaskConfig(input.getConfiguration());
        config.addOutputShipStrategy(shipStrategy);
        config.setOutputSerializer(serializer);
        return input;
    }

    /**
     * Creates the {@link DotProducts} map vertex. Its layout is:
     * <ul>
     *   <li>regular input 0 carries the points;</li>
     *   <li>broadcast input 0 carries the models, registered under the name {@code "models"}.</li>
     * </ul>
     */
    private static JobTaskVertex createMapper(JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> serializer) {
        JobTaskVertex mapper = JobGraphUtils.createTask(RegularPactTask.class, "Map[DotProducts]", jobGraph, numSubTasks);
        TaskConfig config = new TaskConfig(mapper.getConfiguration());
        config.setStubWrapper(new UserCodeClassWrapper<DotProducts>(DotProducts.class));
        config.addOutputShipStrategy(ShipStrategyType.FORWARD);
        config.setOutputSerializer(serializer);
        config.setDriver(CollectorMapDriver.class);
        config.setDriverStrategy(DriverStrategy.COLLECTOR_MAP);

        // Regular input 0: points.
        config.addInputToGroup(0);
        config.setInputLocalStrategy(0, LocalStrategy.NONE);
        config.setInputSerializer(serializer, 0);

        // Broadcast input 0: models, which DotProducts#open looks up by this name.
        config.setBroadcastInputName("models", 0);
        config.addBroadcastInputToGroup(0);
        config.setBroadcastInputSerializer(serializer, 0);
        return mapper;
    }

    /**
     * Creates the file output vertex. It writes each record as three space-separated {@link LongValue}
     * fields, one record per line.
     */
    private static JobOutputVertex createOutput(JobGraph jobGraph, String resultPath, int numSubTasks, TypeSerializerFactory<?> serializer) {
        JobOutputVertex output = JobGraphUtils.createFileOutput(jobGraph, "Output", numSubTasks);
        TaskConfig config = new TaskConfig(output.getConfiguration());
        config.addInputToGroup(0);
        config.setInputSerializer(serializer, 0);
        CsvOutputFormat format = new CsvOutputFormat("\n", " ", new Class[]{LongValue.class, LongValue.class, LongValue.class});
        format.setOutputFilePath(new Path(resultPath));
        config.setStubWrapper(new UserCodeObjectWrapper<CsvOutputFormat>(format));
        return output;
    }

    /**
     * Assembles the job graph: points input, models input, mapper, and output. The edges are:
     * <ul>
     *   <li>points to mapper: POINTWISE;</li>
     *   <li>models to mapper: BIPARTITE;</li>
     *   <li>mapper to output: POINTWISE.</li>
     * </ul>
     * All vertices are set to share instances with the output vertex.
     */
    private JobGraph createJobGraphV1(String pointsPath, String modelsPath, String resultPath, int numSubTasks)
            throws JobGraphDefinitionException {
        RecordSerializerFactory serializer = RecordSerializerFactory.get();
        JobGraph jobGraph = new JobGraph("Distance Builder");

        JobInputVertex pointsInput = createPointsInput(jobGraph, pointsPath, numSubTasks, serializer);
        JobInputVertex modelsInput = createModelsInput(jobGraph, modelsPath, numSubTasks, serializer);
        JobTaskVertex mapper = createMapper(jobGraph, numSubTasks, serializer);
        JobOutputVertex output = createOutput(jobGraph, resultPath, numSubTasks, serializer);

        JobGraphUtils.connect(pointsInput, mapper, ChannelType.NETWORK, DistributionPattern.POINTWISE);
        JobGraphUtils.connect(modelsInput, mapper, ChannelType.NETWORK, DistributionPattern.BIPARTITE);
        JobGraphUtils.connect(mapper, output, ChannelType.NETWORK, DistributionPattern.POINTWISE);

        pointsInput.setVertexToShareInstancesWith(output);
        modelsInput.setVertexToShareInstancesWith(output);
        mapper.setVertexToShareInstancesWith(output);
        return jobGraph;
    }
}
