package org.apache.flink.test.compiler.iterations;

import java.util.Arrays;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.operators.util.FieldList;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.compiler.plan.BulkIterationPlanNode;
import org.apache.flink.compiler.plan.NamedChannel;
import org.apache.flink.compiler.plan.OptimizedPlan;
import org.apache.flink.compiler.plan.SingleInputPlanNode;
import org.apache.flink.compiler.plan.SinkPlanNode;
import org.apache.flink.compiler.plantranslate.NepheleJobGraphGenerator;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.runtime.operators.util.LocalStrategy;
import org.apache.flink.test.compiler.util.CompilerTestBase;
import org.apache.flink.test.compiler.util.OperatorResolver;
import org.apache.flink.test.iterative.nephele.JobGraphUtils;
import org.apache.flink.test.recordJobs.kmeans.KMeansBroadcast;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/test/compiler/iterations/IterativeKMeansTest.class */
/**
 * Compiler test for the {@link KMeansBroadcast} record job.
 *
 * <p>The plan is compiled once with and once without source statistics. In both cases the
 * optimized plan must pass the same strategy checks (see {@link #checkPlan(OptimizedPlan)})
 * and must be translatable into a job graph by the {@link NepheleJobGraphGenerator}.
 */
public class IterativeKMeansTest extends CompilerTestBase {
    private static final String DATAPOINTS = "Data Points";
    private static final String CENTERS = "Centers";
    private static final String MAPPER_NAME = "Find Nearest Centers";
    private static final String REDUCER_NAME = "Recompute Center Positions";
    private static final String ITERATION_NAME = "k-means loop";
    private static final String SINK = "New Center Positions";

    /** Key field list containing only field 0; the expected key of combiner and reducer. */
    private final FieldList set0 = new FieldList(0);

    /**
     * Compiles the k-means plan after attaching statistics to both data sources, checks the
     * chosen strategies and generates the job graph.
     */
    @Test
    public void testCompileKMeansSingleStepWithStats() {
        Plan plan = createKMeansPlan();

        OperatorResolver resolver = getContractResolver(plan);
        FileDataSource pointsSource = resolver.getNode(DATAPOINTS);
        FileDataSource centersSource = resolver.getNode(CENTERS);
        // NOTE(review): the numeric arguments are presumably the source size in bytes and the
        // average record width - confirm against CompilerTestBase#setSourceStatistics.
        setSourceStatistics(pointsSource, 100L * 1024 * 1024 * 1024, 32.0f);
        setSourceStatistics(centersSource, JobGraphUtils.MEGABYTE, 32.0f);

        OptimizedPlan optimizedPlan = compileWithStats(plan);
        checkPlan(optimizedPlan);
        new NepheleJobGraphGenerator().compileJobGraph(optimizedPlan);
    }

    /**
     * Compiles the k-means plan without any source statistics, checks the chosen strategies
     * and generates the job graph.
     */
    @Test
    public void testCompileKMeansSingleStepWithOutStats() {
        OptimizedPlan optimizedPlan = compileNoStats(createKMeansPlan());
        checkPlan(optimizedPlan);
        new NepheleJobGraphGenerator().compileJobGraph(optimizedPlan);
    }

    /**
     * Builds the k-means plan shared by both tests.
     *
     * @return the plan produced by {@link KMeansBroadcast}
     */
    private Plan createKMeansPlan() {
        // NOTE(review): the first and the last argument look like the degree of parallelism
        // and the number of iterations - confirm against KMeansBroadcast#getPlan.
        return new KMeansBroadcast().getPlan(
                String.valueOf(8), IN_FILE, IN_FILE, OUT_FILE, String.valueOf(20));
    }

    /**
     * Verifies the strategies the optimizer selected for the sink, the bulk iteration, the
     * mapper, the reducer's predecessor (expected to be a combiner) and the reducer.
     *
     * @param optimizedPlan the optimized plan to inspect
     */
    private void checkPlan(OptimizedPlan optimizedPlan) {
        CompilerTestBase.OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(optimizedPlan);

        SinkPlanNode sink = resolver.getNode(SINK);
        SingleInputPlanNode reducer = resolver.getNode(REDUCER_NAME);
        // Explicit downcast: the predecessor is looked at as a single-input node below.
        SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getPredecessor();
        SingleInputPlanNode mapper = resolver.getNode(MAPPER_NAME);
        BulkIterationPlanNode iteration = resolver.getNode(ITERATION_NAME);

        checkSink(sink);
        checkIteration(iteration);
        checkMapper(mapper);
        checkCombiner(combiner);
        checkReducer(reducer);
    }

    /** Checks that the sink input is forwarded without a local strategy. */
    private void checkSink(SinkPlanNode sink) {
        Assert.assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
        Assert.assertEquals(LocalStrategy.NONE, sink.getInput().getLocalStrategy());
    }

    /** Checks that the iteration input is forwarded without a local strategy. */
    private void checkIteration(BulkIterationPlanNode iteration) {
        Assert.assertEquals(ShipStrategyType.FORWARD, iteration.getInput().getShipStrategy());
        Assert.assertEquals(LocalStrategy.NONE, iteration.getInput().getLocalStrategy());
    }

    /**
     * Checks the mapper: a forwarded, cached main input that is not on the dynamic path and
     * exactly one broadcast input that is on the dynamic path; no local strategies at all.
     */
    private void checkMapper(SingleInputPlanNode mapper) {
        // Assert the size first so that a missing broadcast input fails as an assertion
        // rather than as an index error on the lookup below.
        Assert.assertEquals(1, mapper.getBroadcastInputs().size());
        NamedChannel broadcast = (NamedChannel) mapper.getBroadcastInputs().get(0);

        Assert.assertEquals(DriverStrategy.COLLECTOR_MAP, mapper.getDriverStrategy());

        // main input
        Assert.assertEquals(ShipStrategyType.FORWARD, mapper.getInput().getShipStrategy());
        Assert.assertFalse(mapper.getInput().isOnDynamicPath());
        Assert.assertTrue(mapper.getInput().getTempMode().isCached());
        Assert.assertEquals(LocalStrategy.NONE, mapper.getInput().getLocalStrategy());
        Assert.assertNull(mapper.getInput().getLocalStrategyKeys());
        Assert.assertNull(mapper.getInput().getLocalStrategySortOrder());

        // broadcast input
        Assert.assertEquals(ShipStrategyType.BROADCAST, broadcast.getShipStrategy());
        Assert.assertTrue(broadcast.isOnDynamicPath());
        Assert.assertEquals(LocalStrategy.NONE, broadcast.getLocalStrategy());
        Assert.assertNull(broadcast.getLocalStrategyKeys());
        Assert.assertNull(broadcast.getLocalStrategySortOrder());
    }

    /**
     * Checks the reducer's predecessor: it must exist, use the sorted group-combine driver on
     * key field 0 and receive a forwarded input on the dynamic path without a local strategy.
     */
    private void checkCombiner(SingleInputPlanNode combiner) {
        Assert.assertNotNull(combiner);
        Assert.assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
        Assert.assertTrue(combiner.getInput().isOnDynamicPath());
        Assert.assertEquals(LocalStrategy.NONE, combiner.getInput().getLocalStrategy());
        Assert.assertEquals(DriverStrategy.SORTED_GROUP_COMBINE, combiner.getDriverStrategy());
        Assert.assertNull(combiner.getInput().getLocalStrategyKeys());
        Assert.assertNull(combiner.getInput().getLocalStrategySortOrder());
        Assert.assertEquals(this.set0, combiner.getKeys(0));
        Assert.assertEquals(this.set0, combiner.getKeys(1));
    }

    /**
     * Checks the reducer: hash-partitioned input on the dynamic path, combining sort as local
     * strategy on key field 0 and a sort order matching the driver's sort order.
     */
    private void checkReducer(SingleInputPlanNode reducer) {
        Assert.assertEquals(ShipStrategyType.PARTITION_HASH, reducer.getInput().getShipStrategy());
        Assert.assertTrue(reducer.getInput().isOnDynamicPath());
        Assert.assertEquals(LocalStrategy.COMBININGSORT, reducer.getInput().getLocalStrategy());
        Assert.assertEquals(DriverStrategy.SORTED_GROUP_REDUCE, reducer.getDriverStrategy());
        Assert.assertEquals(this.set0, reducer.getKeys(0));
        Assert.assertEquals(this.set0, reducer.getInput().getLocalStrategyKeys());
        Assert.assertTrue(Arrays.equals(reducer.getInput().getLocalStrategySortOrder(), reducer.getSortOrders(0)));
    }
}
