package org.apache.flink.test.recordJobs.kmeans;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.java.record.functions.MapFunction;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.io.CsvInputFormat;
import org.apache.flink.api.java.record.io.FileOutputFormat;
import org.apache.flink.api.java.record.operators.BulkIteration;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.MapOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.client.LocalExecutor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.types.Value;
import org.apache.flink.util.Collector;

/* loaded from: input_file:org/apache/flink/test/recordJobs/kmeans/KMeansBroadcast.class */
public class KMeansBroadcast implements Program, ProgramDescription {
    private static final long serialVersionUID = 1;

    /* loaded from: input_file:org/apache/flink/test/recordJobs/kmeans/KMeansBroadcast$Point.class */
    public static final class Point implements Value {
        private static final long serialVersionUID = 1;
        public double x;
        public double y;
        public double z;

        public Point() {
        }

        public Point(double d, double d2, double d3) {
            this.x = d;
            this.y = d2;
            this.z = d3;
        }

        public void add(Point point) {
            this.x += point.x;
            this.y += point.y;
            this.z += point.z;
        }

        public Point div(long j) {
            this.x /= j;
            this.y /= j;
            this.z /= j;
            return this;
        }

        public double euclideanDistance(Point point) {
            return Math.sqrt(((this.x - point.x) * (this.x - point.x)) + ((this.y - point.y) * (this.y - point.y)) + ((this.z - point.z) * (this.z - point.z)));
        }

        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r3v0, types: [org.apache.flink.test.recordJobs.kmeans.KMeansBroadcast$Point] */
        public void clear() {
            ?? r3 = 0;
            this.z = 0.0d;
            this.y = 0.0d;
            r3.x = this;
        }

        public void write(DataOutputView dataOutputView) throws IOException {
            dataOutputView.writeDouble(this.x);
            dataOutputView.writeDouble(this.y);
            dataOutputView.writeDouble(this.z);
        }

        public void read(DataInputView dataInputView) throws IOException {
            this.x = dataInputView.readDouble();
            this.y = dataInputView.readDouble();
            this.z = dataInputView.readDouble();
        }

        public String toString() {
            return "(" + this.x + "|" + this.y + "|" + this.z + ")";
        }
    }

    /**
     * Map function that packs the three {@link DoubleValue} fields 1, 2 and 3 of a
     * record into a single {@link Point} stored in field 1.
     *
     * <p>Field 0 and the original fields 2 and 3 are left untouched; the same record
     * instance is forwarded to the collector.
     */
    public static final class PointBuilder extends MapFunction {
        private static final long serialVersionUID = 1;

        /**
         * Replaces field 1 of {@code record} with a {@link Point} built from fields 1-3
         * and emits the record.
         *
         * @param record the incoming record; modified in place
         * @param collector receives the modified record
         * @throws Exception declared by the map contract; not thrown explicitly here
         */
        public void map(Record record, Collector<Record> collector) throws Exception {
            // Read all coordinates before overwriting field 1.
            double x = record.getField(1, DoubleValue.class).getValue();
            double y = record.getField(2, DoubleValue.class).getValue();
            double z = record.getField(3, DoubleValue.class).getValue();
            record.setField(1, new Point(x, y, z));
            collector.collect(record);
        }
    }

    /**
     * Output format that writes each record as one line of the form
     * {@code id|x|y|z|}, with the coordinates printed with one decimal digit.
     */
    public static final class PointOutFormat extends FileOutputFormat {
        private static final long serialVersionUID = 1;
        private static final String FORMAT = "%d|%.1f|%.1f|%.1f|\n";

        /**
         * Writes the id in field 0 and the {@link Point} in field 1 of {@code record}
         * to the output stream.
         *
         * @param record the record to write
         * @throws IOException if writing to the stream fails
         */
        public void writeRecord(Record record) throws IOException {
            int id = record.getField(0, IntValue.class).getValue();
            Point point = record.getField(1, Point.class);
            // Locale.ROOT keeps '.' as the decimal separator regardless of the JVM's
            // default locale; the charset is fixed so output does not depend on the platform.
            String line = String.format(Locale.ROOT, FORMAT, id, point.x, point.y, point.z);
            this.stream.write(line.getBytes(StandardCharsets.UTF_8));
        }
    }

    /**
     * Simple holder that pairs a {@link Point} with an integer id.
     *
     * <p>Both members are public, mutable fields.
     */
    public static final class PointWithId {
        /** The id associated with {@link #point}. */
        public int id;
        /** The point itself; the reference is stored as given, without copying. */
        public Point point;

        /**
         * Creates a holder for the given id and point.
         *
         * @param id the id to store
         * @param point the point to store
         */
        public PointWithId(int id, Point point) {
            this.point = point;
            this.id = id;
        }
    }

    /**
     * Reduce function that turns a group of records into a single record holding
     * the average of the points in the group.
     *
     * <p>Each input record carries a {@link Point} in field 1 and a count in field 2.
     * {@link #combine} pre-aggregates by summing points and counts; {@link #reduce}
     * does the same and then divides the summed point by the summed count.
     */
    @ReduceOperator.Combinable
    public static final class RecomputeClusterCenter extends ReduceFunction {
        private static final long serialVersionUID = 1;
        // Reused accumulator; cleared at the start of every aggregation.
        private final Point p = new Point();

        /**
         * Sums all points and counts of the group, divides the point sum by the
         * count sum and emits the resulting record. Emits nothing for an empty group.
         *
         * @param it the records of one group
         * @param collector receives the averaged record
         */
        public void reduce(Iterator<Record> it, Collector<Record> collector) {
            Record summed = sumPointsAndCount(it);
            if (summed == null) {
                return;
            }
            Point sum = summed.getField(1, Point.class);
            int count = summed.getField(2, IntValue.class).getValue();
            summed.setField(1, sum.div(count));
            collector.collect(summed);
        }

        /**
         * Emits one record holding the summed point and summed count of the group.
         * Emits nothing for an empty group.
         *
         * @param it the records of one group
         * @param collector receives the partially aggregated record
         */
        public void combine(Iterator<Record> it, Collector<Record> collector) {
            Record summed = sumPointsAndCount(it);
            if (summed != null) {
                collector.collect(summed);
            }
        }

        /**
         * Adds up field 1 (points) and field 2 (counts) of all records.
         *
         * @param it the records to aggregate
         * @return the last record read, with field 1 set to the point sum and field 2
         *         set to the count sum, or {@code null} if {@code it} has no elements
         */
        private Record sumPointsAndCount(Iterator<Record> it) {
            this.p.clear();
            if (!it.hasNext()) {
                // Previously this path dereferenced a null record.
                return null;
            }
            Record last = null;
            int count = 0;
            while (it.hasNext()) {
                last = it.next();
                this.p.add(last.getField(1, Point.class));
                count += last.getField(2, IntValue.class).getValue();
            }
            // The last record read is reused as the output carrier; its other fields are kept.
            last.setField(1, this.p);
            last.setField(2, new IntValue(count));
            return last;
        }
    }

    /**
     * Map function that assigns each point to the closest center.
     *
     * <p>The centers are read from the broadcast variable named {@code "centers"} in
     * {@link #open(Configuration)}. For every input record an output record
     * {@code (centerId, point, 1)} is emitted.
     */
    public static final class SelectNearestCenter extends MapFunction {
        private static final long serialVersionUID = 1;
        private final IntValue one = new IntValue(1);
        // Reused output record; the same instance is handed to the collector on every call.
        private final Record result = new Record(3);
        private List<PointWithId> centers = new ArrayList<PointWithId>();

        /**
         * Rebuilds the local list of centers from the {@code "centers"} broadcast variable.
         *
         * @param configuration not used by this implementation
         * @throws Exception declared by the open contract; not thrown explicitly here
         */
        public void open(Configuration configuration) throws Exception {
            List<Record> broadcastVariable = getRuntimeContext().getBroadcastVariable("centers");
            this.centers.clear();
            // NOTE(review): the lock presumably guards against the broadcast list being
            // shared between parallel instances - confirm against the runtime.
            synchronized (broadcastVariable) {
                for (Record record : broadcastVariable) {
                    int id = record.getField(0, IntValue.class).getValue();
                    Point center = record.getField(1, Point.class);
                    this.centers.add(new PointWithId(id, center));
                }
            }
        }

        /**
         * Finds the center with the smallest Euclidean distance to the point in field 1
         * of {@code record} and emits {@code (centerId, point, 1)}.
         *
         * <p>If no center is strictly closer than {@link Double#MAX_VALUE} (for example
         * when there are no centers), the emitted id is {@code -1}.
         *
         * @param record the incoming record holding a {@link Point} in field 1
         * @param collector receives the result record
         */
        public void map(Record record, Collector<Record> collector) {
            Point point = record.getField(1, Point.class);
            double nearestDistance = Double.MAX_VALUE;
            int nearestId = -1;
            for (PointWithId center : this.centers) {
                double distance = point.euclideanDistance(center.point);
                if (distance < nearestDistance) {
                    nearestDistance = distance;
                    nearestId = center.id;
                }
            }
            this.result.setField(0, new IntValue(nearestId));
            this.result.setField(1, point);
            this.result.setField(2, this.one);
            collector.collect(this.result);
        }
    }

    /**
     * Assembles the k-means plan.
     *
     * <p>Positional arguments, all optional: parallelism (default 1), data point
     * path, center path, output path (each default empty string) and the maximum
     * number of iterations (default 2).
     *
     * @param strArr the positional arguments described above
     * @return the plan, with its default parallelism set
     * @throws NumberFormatException if a supplied numeric argument cannot be parsed
     */
    public Plan getPlan(String... strArr) {
        final int argCount = strArr.length;

        int parallelism = 1;
        if (argCount > 0) {
            parallelism = Integer.parseInt(strArr[0]);
        }
        String pointsPath = "";
        if (argCount > 1) {
            pointsPath = strArr[1];
        }
        String centersPath = "";
        if (argCount > 2) {
            centersPath = strArr[2];
        }
        String outputPath = "";
        if (argCount > 3) {
            outputPath = strArr[3];
        }
        int maxIterations = 2;
        if (argCount > 4) {
            maxIterations = Integer.parseInt(strArr[4]);
        }

        // Both inputs are '|'-separated lines of the form: id, x, y, z.
        CsvInputFormat pointsFormat = new CsvInputFormat('|', new Class[]{IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class});
        FileDataSource pointsSource = new FileDataSource(pointsFormat, pointsPath, "Data Points");
        CsvInputFormat centersFormat = new CsvInputFormat('|', new Class[]{IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class});
        FileDataSource centersSource = new FileDataSource(centersFormat, centersPath, "Centers");

        MapOperator dataPoints = MapOperator.builder(new PointBuilder())
                .name("Build data points")
                .input(pointsSource)
                .build();
        MapOperator clusterPoints = MapOperator.builder(new PointBuilder())
                .name("Build cluster points")
                .input(centersSource)
                .build();

        // The iteration starts from the initial centers and yields new centers each round.
        BulkIteration iteration = new BulkIteration("k-means loop");
        iteration.setInput(clusterPoints);
        iteration.setMaximumNumberOfIterations(maxIterations);

        MapOperator findNearest = MapOperator.builder(new SelectNearestCenter())
                .setBroadcastVariable("centers", iteration.getPartialSolution())
                .input(dataPoints)
                .name("Find Nearest Centers")
                .build();
        ReduceOperator recompute = ReduceOperator.builder(new RecomputeClusterCenter(), IntValue.class, 0)
                .input(findNearest)
                .name("Recompute Center Positions")
                .build();
        iteration.setNextPartialSolution(recompute);

        FileDataSink sink = new FileDataSink(new PointOutFormat(), outputPath, iteration, "New Center Positions");
        Plan kMeansPlan = new Plan(sink, "K-Means");
        kMeansPlan.setDefaultParallelism(parallelism);
        return kMeansPlan;
    }

    /**
     * Describes the positional arguments accepted by {@code getPlan}.
     *
     * @return a one-line usage string
     */
    public String getDescription() {
        return "Parameters: <numSubTasks> <dataPoints> <clusterCenters> <output> <numIterations>";
    }

    /**
     * Builds the plan with fixed sample arguments and prints the optimizer plan
     * as JSON to standard output.
     *
     * @param strArr ignored
     * @throws Exception if building or dumping the plan fails
     */
    public static void main(String[] strArr) throws Exception {
        KMeansBroadcast job = new KMeansBroadcast();
        Plan plan = job.getPlan("4", "/dev/random", "/dev/random", "/tmp", "20");
        String json = LocalExecutor.optimizerPlanAsJSON(plan);
        System.out.println(json);
    }
}
