/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.test.recordJobs.kmeans;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.Program;
import org.apache.flink.api.common.ProgramDescription;
import org.apache.flink.api.common.io.FileInputFormat;
import org.apache.flink.api.common.operators.GenericDataSinkBase;
import org.apache.flink.api.common.operators.Operator;
import org.apache.flink.api.java.record.functions.MapFunction;
import org.apache.flink.api.java.record.functions.ReduceFunction;
import org.apache.flink.api.java.record.io.CsvInputFormat;
import org.apache.flink.api.java.record.io.FileOutputFormat;
import org.apache.flink.api.java.record.operators.BulkIteration;
import org.apache.flink.api.java.record.operators.FileDataSink;
import org.apache.flink.api.java.record.operators.FileDataSource;
import org.apache.flink.api.java.record.operators.MapOperator;
import org.apache.flink.api.java.record.operators.ReduceOperator;
import org.apache.flink.client.LocalExecutor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.types.Value;
import org.apache.flink.util.Collector;

public class KMeansBroadcast
implements Program,
ProgramDescription {
    private static final long serialVersionUID = 1L;

    /**
     * Builds the k-means plan from positional arguments.
     *
     * <p>Arguments (all optional, in order): parallelism, data points path, cluster centers
     * path, output path, number of iterations. Missing arguments default to {@code 1},
     * {@code ""}, {@code ""}, {@code ""} and {@code 2} respectively.
     *
     * @param args positional program arguments
     * @return a plan that repeatedly assigns points to their nearest center and recomputes the
     *         centers inside a bulk iteration, then writes the resulting centers to the output path
     * @throws NumberFormatException if the parallelism or the iteration count is not an integer
     */
    @Override
    public Plan getPlan(String ... args) {
        final int parallelism = Integer.parseInt(argOrDefault(args, 0, "1"));
        final String pointsPath = argOrDefault(args, 1, "");
        final String centersPath = argOrDefault(args, 2, "");
        final String outputPath = argOrDefault(args, 3, "");
        final int maxIterations = Integer.parseInt(argOrDefault(args, 4, "2"));

        // Both inputs are '|'-separated rows typed (IntValue, DoubleValue, DoubleValue, DoubleValue).
        FileDataSource pointsSource = pipeDelimitedPointSource(pointsPath, "Data Points");
        FileDataSource clustersSource = pipeDelimitedPointSource(centersPath, "Centers");

        MapOperator dataPoints = MapOperator.builder((MapFunction) new PointBuilder())
                .name("Build data points")
                .input((Operator) pointsSource)
                .build();
        MapOperator clusterPoints = MapOperator.builder((MapFunction) new PointBuilder())
                .name("Build cluster points")
                .input((Operator) clustersSource)
                .build();

        // The iteration's partial solution is the set of cluster centers; it is handed to the
        // point-assignment step as the broadcast variable "centers".
        BulkIteration loop = new BulkIteration("k-means loop");
        loop.setInput((Operator) clusterPoints);
        loop.setMaximumNumberOfIterations(maxIterations);

        MapOperator assignToNearest = MapOperator.builder((MapFunction) new SelectNearestCenter())
                .setBroadcastVariable("centers", loop.getPartialSolution())
                .input((Operator) dataPoints)
                .name("Find Nearest Centers")
                .build();
        ReduceOperator averagePerCenter = ReduceOperator
                .builder((ReduceFunction) new RecomputeClusterCenter(), IntValue.class, 0)
                .input((Operator) assignToNearest)
                .name("Recompute Center Positions")
                .build();
        loop.setNextPartialSolution((Operator) averagePerCenter);

        FileDataSink sink = new FileDataSink(
                (org.apache.flink.api.common.io.FileOutputFormat) new PointOutFormat(),
                outputPath, (Operator) loop, "New Center Positions");

        Plan kMeansPlan = new Plan((GenericDataSinkBase) sink, "K-Means");
        kMeansPlan.setDefaultParallelism(parallelism);
        return kMeansPlan;
    }

    /** Returns {@code args[index]} when that argument was supplied, otherwise {@code fallback}. */
    private static String argOrDefault(String[] args, int index, String fallback) {
        return index < args.length ? args[index] : fallback;
    }

    /** Creates a file source reading '|'-delimited (IntValue, DoubleValue x3) rows from {@code path}. */
    private static FileDataSource pipeDelimitedPointSource(String path, String name) {
        Class[] columns = new Class[]{IntValue.class, DoubleValue.class, DoubleValue.class, DoubleValue.class};
        return new FileDataSource((FileInputFormat) new CsvInputFormat('|', columns), path, name);
    }

    /**
     * Returns the usage line listing this program's positional parameters.
     *
     * <p>The first parameter was previously misspelled as {@code <numSubStasks>}.
     *
     * @return the human-readable parameter description
     */
    @Override
    public String getDescription() {
        return "Parameters: <numSubTasks> <dataPoints> <clusterCenters> <output> <numIterations>";
    }

    /**
     * Prints the optimizer plan, rendered as JSON, for a fixed sample argument set.
     *
     * @param args ignored
     * @throws Exception propagated from plan construction or optimization
     */
    public static void main(String[] args) throws Exception {
        Plan samplePlan = new KMeansBroadcast().getPlan("4", "/dev/random", "/dev/random", "/tmp", "20");
        System.out.println(LocalExecutor.optimizerPlanAsJSON(samplePlan));
    }

    public static final class PointOutFormat
    extends FileOutputFormat {
        private static final long serialVersionUID = 1L;
        private static final String format = "%d|%.1f|%.1f|%.1f|\n";

        public void writeRecord(Record record) throws IOException {
            int id = ((IntValue)record.getField(0, IntValue.class)).getValue();
            Point p = (Point)record.getField(1, Point.class);
            byte[] bytes = String.format(format, id, p.x, p.y, p.z).getBytes();
            this.stream.write(bytes);
        }
    }

    public static final class PointBuilder
    extends MapFunction {
        private static final long serialVersionUID = 1L;

        public void map(Record record, Collector<Record> out) throws Exception {
            double x = ((DoubleValue)record.getField(1, DoubleValue.class)).getValue();
            double y = ((DoubleValue)record.getField(2, DoubleValue.class)).getValue();
            double z = ((DoubleValue)record.getField(3, DoubleValue.class)).getValue();
            record.setField(1, (Value)new Point(x, y, z));
            out.collect((Object)record);
        }
    }

    @ReduceOperator.Combinable
    public static final class RecomputeClusterCenter
    extends ReduceFunction {
        private static final long serialVersionUID = 1L;
        private final Point p = new Point();

        public void reduce(Iterator<Record> points, Collector<Record> out) {
            Record sum = this.sumPointsAndCount(points);
            sum.setField(1, (Value)((Point)sum.getField(1, Point.class)).div(((IntValue)sum.getField(2, IntValue.class)).getValue()));
            out.collect((Object)sum);
        }

        public void combine(Iterator<Record> points, Collector<Record> out) {
            out.collect((Object)this.sumPointsAndCount(points));
        }

        private final Record sumPointsAndCount(Iterator<Record> dataPoints) {
            Record next = null;
            this.p.clear();
            int count = 0;
            while (dataPoints.hasNext()) {
                next = dataPoints.next();
                this.p.add((Point)next.getField(1, Point.class));
                count += ((IntValue)next.getField(2, IntValue.class)).getValue();
            }
            next.setField(1, (Value)this.p);
            next.setField(2, (Value)new IntValue(count));
            return next;
        }
    }

    public static final class SelectNearestCenter
    extends MapFunction {
        private static final long serialVersionUID = 1L;
        private final IntValue one = new IntValue(1);
        private final Record result = new Record(3);
        private List<PointWithId> centers = new ArrayList<PointWithId>();

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void open(Configuration parameters) throws Exception {
            List clusterCenters = this.getRuntimeContext().getBroadcastVariable("centers");
            this.centers.clear();
            List list = clusterCenters;
            synchronized (list) {
                for (Record r : clusterCenters) {
                    this.centers.add(new PointWithId(((IntValue)r.getField(0, IntValue.class)).getValue(), (Point)r.getField(1, Point.class)));
                }
            }
        }

        public void map(Record dataPointRecord, Collector<Record> out) {
            Point p = (Point)dataPointRecord.getField(1, Point.class);
            double nearestDistance = Double.MAX_VALUE;
            int centerId = -1;
            for (PointWithId center : this.centers) {
                double distance = p.euclideanDistance(center.point);
                if (!(distance < nearestDistance)) continue;
                nearestDistance = distance;
                centerId = center.id;
            }
            this.result.setField(0, (Value)new IntValue(centerId));
            this.result.setField(1, (Value)p);
            this.result.setField(2, (Value)this.one);
            out.collect((Object)this.result);
        }
    }

    /** Pairs a cluster center's integer id with its coordinates. */
    public static final class PointWithId {
        /** Identifier of the center. */
        public int id;
        /** Coordinates of the center. */
        public Point point;

        /**
         * @param id center identifier
         * @param p center coordinates; stored by reference, not copied
         */
        public PointWithId(int id, Point p) {
            point = p;
            this.id = id;
        }
    }

    public static final class Point
    implements Value {
        private static final long serialVersionUID = 1L;
        public double x;
        public double y;
        public double z;

        public Point() {
        }

        public Point(double x, double y, double z) {
            this.x = x;
            this.y = y;
            this.z = z;
        }

        public void add(Point other) {
            this.x += other.x;
            this.y += other.y;
            this.z += other.z;
        }

        public Point div(long val) {
            this.x /= (double)val;
            this.y /= (double)val;
            this.z /= (double)val;
            return this;
        }

        public double euclideanDistance(Point other) {
            return Math.sqrt((this.x - other.x) * (this.x - other.x) + (this.y - other.y) * (this.y - other.y) + (this.z - other.z) * (this.z - other.z));
        }

        public void clear() {
            this.z = 0.0;
            this.y = 0.0;
            this.x = 0.0;
        }

        public void write(DataOutputView out) throws IOException {
            out.writeDouble(this.x);
            out.writeDouble(this.y);
            out.writeDouble(this.z);
        }

        public void read(DataInputView in) throws IOException {
            this.x = in.readDouble();
            this.y = in.readDouble();
            this.z = in.readDouble();
        }

        public String toString() {
            return "(" + this.x + "|" + this.y + "|" + this.z + ")";
        }
    }
}

