package org.apache.flink.test.optimizer.examples;

import java.util.Arrays;
import java.util.Collection;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.operators.GenericDataSourceBase;
import org.apache.flink.api.common.operators.util.FieldList;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.optimizer.plan.NamedChannel;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plan.SingleInputPlanNode;
import org.apache.flink.optimizer.plan.SinkPlanNode;
import org.apache.flink.optimizer.util.CompilerTestBase;
import org.apache.flink.optimizer.util.OperatorResolver;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.runtime.operators.util.LocalStrategy;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;

/**
 * Optimizer (compiler) test for a single step of a k-means clustering program written against the
 * DataSet API.
 *
 * <p>The program reads points and centers from CSV files. It assigns every point to its nearest
 * center, with the centers shipped to the mapper as a broadcast set. It then aggregates the points
 * per center id with a group-reduce that also serves as combiner, projects fields 0 and 1, and
 * writes the result as CSV. {@link #checkPlan(OptimizedPlan)} asserts the ship, local and driver
 * strategies the optimizer chose for the mapper, the combiner in front of the reducer, the reducer
 * and the sink. The plan is compiled both with and without source statistics.
 */
public class KMeansSingleStepTest extends CompilerTestBase {
    private static final String DATAPOINTS = "Data Points";
    private static final String CENTERS = "Centers";
    private static final String MAPPER_NAME = "Find Nearest Centers";
    private static final String REDUCER_NAME = "Recompute Center Positions";
    private static final String SINK = "New Center Positions";

    /** Name under which the centers are registered as broadcast set and read back by the mapper. */
    private static final String CENTROIDS_BROADCAST = "centroids";

    /** Key field list {@code [0]}: grouping / sorting on the centroid id field. */
    private final FieldList set0 = new FieldList(0);

    /** A cluster center: {@code f0} is the center id, {@code f1} its position. */
    public static class Centroid extends Tuple2<Integer, Point> {
        public Centroid(int id, double x, double y) {
            this(id, new Point(x, y));
        }

        public Centroid(int id, Point position) {
            this.f0 = id;
            this.f1 = position;
        }
    }

    /** A two-dimensional point stored as an {@code (x, y)} tuple of doubles. */
    public static class Point extends Tuple2<Double, Double> {
        public Point(double x, double y) {
            this.f0 = x;
            this.f1 = y;
        }

        /** Adds {@code other} to this point component-wise, in place, and returns {@code this}. */
        public Point add(Point other) {
            this.f0 = this.f0 + other.f0;
            this.f1 = this.f1 + other.f1;
            return this;
        }

        /** Divides both coordinates by {@code divisor}, in place, and returns {@code this}. */
        public Point div(long divisor) {
            this.f0 = this.f0 / divisor;
            this.f1 = this.f1 / divisor;
            return this;
        }

        /** Returns the Euclidean distance between this point and {@code other}. */
        public double euclideanDistance(Point other) {
            double dx = this.f0 - other.f0;
            double dy = this.f1 - other.f1;
            return Math.sqrt(dx * dx + dy * dy);
        }

        /** Returns the Euclidean distance between this point and the position of {@code centroid}. */
        public double euclideanDistance(Centroid centroid) {
            // Same formula as the Point overload; delegate instead of duplicating it.
            return euclideanDistance(centroid.f1);
        }
    }

    /**
     * Aggregates all {@code (centroidId, point, count)} records of one group. It sums the point
     * coordinates and the counts and keeps the centroid id of the last record seen. The same logic
     * serves as the combiner.
     *
     * <p>Note: the emitted point holds coordinate <em>sums</em>; this function never divides by the
     * count. For an empty input it would emit {@code (-1, (0.0, 0.0), 0)}.
     */
    public static final class RecomputeClusterCenter
            implements GroupReduceFunction<Tuple3<Integer, Point, Integer>, Tuple3<Integer, Point, Integer>>,
                    GroupCombineFunction<Tuple3<Integer, Point, Integer>, Tuple3<Integer, Point, Integer>> {
        private RecomputeClusterCenter() {
        }

        @Override
        public void reduce(
                Iterable<Tuple3<Integer, Point, Integer>> values,
                Collector<Tuple3<Integer, Point, Integer>> out)
                throws Exception {
            int centroidId = -1;
            double sumX = 0.0d;
            double sumY = 0.0d;
            int count = 0;
            for (Tuple3<Integer, Point, Integer> value : values) {
                centroidId = value.f0;
                sumX += value.f1.f0;
                sumY += value.f1.f1;
                count += value.f2;
            }
            out.collect(new Tuple3<>(centroidId, new Point(sumX, sumY), count));
        }

        @Override
        public void combine(
                Iterable<Tuple3<Integer, Point, Integer>> values,
                Collector<Tuple3<Integer, Point, Integer>> out)
                throws Exception {
            // Partial aggregation has the same shape as the final one.
            reduce(values, out);
        }
    }

    /**
     * Maps each point to {@code (idOfNearestCentroid, point, 1)}. The candidate centroids are read
     * from the broadcast variable in {@link #open(OpenContext)}. If no centroid is available, the id
     * is {@code -1}.
     */
    public static final class SelectNearestCenter
            extends RichMapFunction<Point, Tuple3<Integer, Point, Integer>> {
        private Collection<Centroid> centroids;

        private SelectNearestCenter() {
        }

        @Override
        public void open(OpenContext openContext) throws Exception {
            this.centroids = getRuntimeContext().getBroadcastVariable(CENTROIDS_BROADCAST);
        }

        @Override
        public Tuple3<Integer, Point, Integer> map(Point point) throws Exception {
            double minDistance = Double.MAX_VALUE;
            int nearestId = -1;
            for (Centroid centroid : this.centroids) {
                double distance = point.euclideanDistance(centroid);
                if (distance < minDistance) {
                    minDistance = distance;
                    nearestId = centroid.f0;
                }
            }
            return new Tuple3<>(nearestId, point, 1);
        }
    }

    /** Compiles the plan with size statistics on both sources: 100 GiB of points, 1 MiB of centers. */
    @Test
    public void testCompileKMeansSingleStepWithStats() throws Exception {
        Plan plan = getKMeansPlan();
        plan.setExecutionConfig(new ExecutionConfig());

        OperatorResolver resolver = getContractResolver(plan);
        GenericDataSourceBase<?, ?> pointsSource =
                (GenericDataSourceBase<?, ?>) resolver.getNode(DATAPOINTS);
        GenericDataSourceBase<?, ?> centersSource =
                (GenericDataSourceBase<?, ?>) resolver.getNode(CENTERS);
        setSourceStatistics(pointsSource, 100L * 1024 * 1024 * 1024, 32.0f);
        setSourceStatistics(centersSource, 1024L * 1024, 32.0f);

        checkPlan(compileWithStats(plan));
    }

    /** Compiles the plan without any source statistics. */
    @Test
    public void testCompileKMeansSingleStepWithOutStats() throws Exception {
        Plan plan = getKMeansPlan();
        plan.setExecutionConfig(new ExecutionConfig());
        checkPlan(compileNoStats(plan));
    }

    /** Asserts the strategies chosen for the mapper, combiner, reducer and sink of the k-means plan. */
    private void checkPlan(OptimizedPlan optimizedPlan) {
        CompilerTestBase.OptimizerPlanNodeResolver resolver =
                getOptimizerPlanNodeResolver(optimizedPlan);

        SinkPlanNode sink = resolver.getNode(SINK);
        SingleInputPlanNode reducer = resolver.getNode(REDUCER_NAME);
        // The predecessor is handed back as a generic plan node; the combiner in front of the
        // reducer is a single-input node, so narrow it explicitly.
        SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getPredecessor();
        SingleInputPlanNode mapper = resolver.getNode(MAPPER_NAME);

        // Mapper: forwarded point input plus exactly one broadcast input (the centers).
        Assert.assertEquals(1, mapper.getBroadcastInputs().size());
        NamedChannel broadcastInput = mapper.getBroadcastInputs().get(0);
        Assert.assertEquals(ShipStrategyType.FORWARD, mapper.getInput().getShipStrategy());
        Assert.assertEquals(ShipStrategyType.BROADCAST, broadcastInput.getShipStrategy());
        Assert.assertEquals(LocalStrategy.NONE, mapper.getInput().getLocalStrategy());
        Assert.assertEquals(LocalStrategy.NONE, broadcastInput.getLocalStrategy());
        Assert.assertEquals(DriverStrategy.MAP, mapper.getDriverStrategy());
        Assert.assertNull(mapper.getInput().getLocalStrategyKeys());
        Assert.assertNull(mapper.getInput().getLocalStrategySortOrder());
        Assert.assertNull(broadcastInput.getLocalStrategyKeys());
        Assert.assertNull(broadcastInput.getLocalStrategySortOrder());

        // Combiner: sorted group-combine on the centroid id, fed by a forward channel.
        Assert.assertNotNull(combiner);
        Assert.assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy());
        Assert.assertEquals(LocalStrategy.NONE, combiner.getInput().getLocalStrategy());
        Assert.assertEquals(DriverStrategy.SORTED_GROUP_COMBINE, combiner.getDriverStrategy());
        Assert.assertNull(combiner.getInput().getLocalStrategyKeys());
        Assert.assertNull(combiner.getInput().getLocalStrategySortOrder());
        Assert.assertEquals(this.set0, combiner.getKeys(0));
        Assert.assertEquals(this.set0, combiner.getKeys(1));

        // Reducer: hash-partitioned input, combining sort on the key, sorted group-reduce.
        Assert.assertEquals(ShipStrategyType.PARTITION_HASH, reducer.getInput().getShipStrategy());
        Assert.assertEquals(LocalStrategy.COMBININGSORT, reducer.getInput().getLocalStrategy());
        Assert.assertEquals(DriverStrategy.SORTED_GROUP_REDUCE, reducer.getDriverStrategy());
        Assert.assertEquals(this.set0, reducer.getKeys(0));
        Assert.assertEquals(this.set0, reducer.getInput().getLocalStrategyKeys());
        Assert.assertTrue(
                Arrays.equals(
                        reducer.getInput().getLocalStrategySortOrder(), reducer.getSortOrders(0)));

        // Sink: plain forward without local strategy.
        Assert.assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
        Assert.assertEquals(LocalStrategy.NONE, sink.getInput().getLocalStrategy());
    }

    /**
     * Builds the k-means plan using the test base's {@code IN_FILE} for both inputs and
     * {@code OUT_FILE} as output. The fourth argument ({@code "20"}) is passed along but is not
     * read by {@link #kmeans(String[])}.
     */
    public static Plan getKMeansPlan() throws Exception {
        return kmeans(new String[] {IN_FILE, IN_FILE, OUT_FILE, "20"});
    }

    /**
     * Builds the single-step k-means program and returns its plan.
     *
     * @param args {@code args[0]}: space-delimited CSV of points (first two fields {@code x, y});
     *     {@code args[1]}: space-delimited CSV of centers (first three fields {@code id, x, y});
     *     {@code args[2]}: output path for the new center positions. Further entries are ignored.
     * @return the program plan named {@code "KMeans Example"}
     */
    public static Plan kmeans(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        env.readCsvFile(args[0])
                .fieldDelimiter(" ")
                .includeFields(true, true)
                .types(Double.class, Double.class)
                .name(DATAPOINTS)
                .map(
                        new MapFunction<Tuple2<Double, Double>, Point>() {
                            @Override
                            public Point map(Tuple2<Double, Double> value) throws Exception {
                                return new Point(value.f0, value.f1);
                            }
                        })
                .map(new SelectNearestCenter())
                .name(MAPPER_NAME)
                .withBroadcastSet(
                        env.readCsvFile(args[1])
                                .fieldDelimiter(" ")
                                .includeFields(true, true, true)
                                .types(Integer.class, Double.class, Double.class)
                                .name(CENTERS)
                                .map(
                                        new MapFunction<Tuple3<Integer, Double, Double>, Centroid>() {
                                            @Override
                                            public Centroid map(Tuple3<Integer, Double, Double> value)
                                                    throws Exception {
                                                return new Centroid(value.f0, value.f1, value.f2);
                                            }
                                        }),
                        CENTROIDS_BROADCAST)
                .groupBy(0)
                .reduceGroup(new RecomputeClusterCenter())
                .name(REDUCER_NAME)
                .project(0, 1)
                .writeAsCsv(args[2], "\n", " ")
                .name(SINK);

        return env.createProgramPlan("KMeans Example");
    }
}
