package org.apache.flink.ml.clustering.kmeans;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBody;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.IterationConfig;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.ReplayableDataStreamList;
import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.distance.DistanceMeasure;
import org.apache.flink.ml.common.iteration.ForwardInputsOfLastRound;
import org.apache.flink.ml.common.iteration.TerminateOnMaxIter;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.VectorWithNorm;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.VectorWithNormSerializer;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/clustering/kmeans/KMeans.class */
/**
 * An {@link Estimator} that clusters dense feature vectors with the k-means algorithm.
 *
 * <p>Training runs as a bounded Flink iteration:
 *
 * <ul>
 *   <li>The initial centroids are obtained by sampling {@link #getK()} vectors from the input
 *       with {@link #getSeed()}.
 *   <li>In each round, every point is assigned to its closest centroid under the configured
 *       {@link DistanceMeasure}. Per-cluster point counts and vector sums are reduced across all
 *       subtasks and divided to obtain the next centroids.
 *   <li>The termination criterion is a {@link TerminateOnMaxIter} configured with
 *       {@link #getMaxIter()}.
 * </ul>
 */
public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMeans> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /** Creates a KMeans estimator whose parameters are initialized to their default values. */
    public KMeans() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Trains a {@link KMeansModel} on the given input.
     *
     * @param inputs exactly one table; its {@link #getFeaturesCol()} column must hold
     *     {@link Vector} values, which are converted to {@link DenseVector}s before training
     * @return a new model whose model data is the final centroids stream (as a table) and whose
     *     existing params are updated from this estimator's params
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     */
    @Override
    public KMeansModel fit(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        // Resolve the column name once, so the map function captures only a String instead of
        // the whole estimator and does not re-read the param for every row.
        String featuresCol = getFeaturesCol();
        DataStream<DenseVector> points =
                tEnv.toDataStream(inputs[0])
                        .map(row -> ((Vector) row.getField(featuresCol)).toDense());

        DataStream<DenseVector[]> initCentroids = selectRandomCentroids(points, getK(), getSeed());

        IterationConfig config =
                IterationConfig.newBuilder()
                        .setOperatorLifeCycle(IterationConfig.OperatorLifeCycle.ALL_ROUND)
                        .build();
        IterationBody body =
                new KMeansIterationBody(
                        getMaxIter(), DistanceMeasure.getInstance(getDistanceMeasure()));

        DataStream<KMeansModelData> finalModelData =
                Iterations.iterateBoundedStreamsUntilTermination(
                                DataStreamList.of(initCentroids),
                                ReplayableDataStreamList.notReplay(points),
                                config,
                                body)
                        .get(0);

        KMeansModel model = new KMeansModel().m21setModelData(tEnv.fromDataStream(finalModelData));
        ReadWriteUtils.updateExistingParams(model, paramMap);
        return model;
    }

    /**
     * Alias of {@link #fit(Table...)}, retained only so that callers written against the
     * decompiler-generated name keep compiling.
     *
     * @deprecated use {@link #fit(Table...)} instead.
     */
    @Deprecated
    public KMeansModel m20fit(Table... tableArr) {
        return fit(tableArr);
    }

    /**
     * Saves this estimator's metadata (including its params) under the given path.
     *
     * @param path directory to write the metadata to
     * @throws IOException if writing fails
     */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Loads a KMeans estimator previously written with {@link #save(String)}.
     *
     * @param tEnv table environment; not used by this implementation
     * @param path directory the estimator was saved to
     * @return the loaded estimator
     * @throws IOException if reading fails
     */
    public static KMeans load(StreamTableEnvironment tEnv, String path) throws IOException {
        return ReadWriteUtils.loadStageParam(path);
    }

    /** Returns the live (mutable) map holding this estimator's param values. */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Samples {@code k} vectors from {@code data} and emits them together as a single array.
     *
     * <p>The collecting operator runs with parallelism 1, so the returned stream carries the whole
     * sample in one array.
     *
     * @param data the training points
     * @param k number of vectors to sample
     * @param seed random seed passed to the sampler
     * @return a stream emitting the sampled vectors as one {@code DenseVector[]}
     */
    public static DataStream<DenseVector[]> selectRandomCentroids(
            DataStream<DenseVector> data, int k, long seed) {
        DataStream<DenseVector[]> resultStream =
                DataStreamUtils.mapPartition(
                        DataStreamUtils.sample(data, k, seed),
                        new MapPartitionFunction<DenseVector, DenseVector[]>() {
                            @Override
                            public void mapPartition(
                                    Iterable<DenseVector> values, Collector<DenseVector[]> out) {
                                List<DenseVector> sampled = new ArrayList<>();
                                values.forEach(sampled::add);
                                out.collect(sampled.toArray(new DenseVector[0]));
                            }
                        });
        resultStream.getTransformation().setParallelism(1);
        return resultStream;
    }

    /**
     * The body of one k-means round inside the bounded iteration.
     *
     * <p>The single feedback variable stream carries the current centroids, and the single data
     * stream carries the training points. The body performs these steps:
     *
     * <ol>
     *   <li>Broadcasts the centroids to a {@link CentroidsUpdateAccumulator}.
     *   <li>Reduces the per-subtask counts and sums with {@link CentroidsUpdateReducer}.
     *   <li>Turns the reduced counts and sums into {@link KMeansModelData} via
     *       {@link ModelDataGenerator}.
     * </ol>
     *
     * The new centroids are fed back. The model data is forwarded as output through
     * {@link ForwardInputsOfLastRound}, presumably only for the final round.
     */
    public static class KMeansIterationBody implements IterationBody {
        private final int maxIterationNum;
        private final DistanceMeasure distanceMeasure;

        /**
         * @param maxIterationNum value handed to {@link TerminateOnMaxIter} as the round limit
         * @param distanceMeasure measure used to find each point's closest centroid
         */
        public KMeansIterationBody(int maxIterationNum, DistanceMeasure distanceMeasure) {
            this.maxIterationNum = maxIterationNum;
            this.distanceMeasure = distanceMeasure;
        }

        @Override
        public IterationBodyResult process(
                DataStreamList variableStreams, DataStreamList dataStreams) {
            DataStream<DenseVector[]> centroids = variableStreams.get(0);
            DataStream<DenseVector> points = dataStreams.get(0);

            DataStream<?> terminationCriteria =
                    centroids.flatMap(new TerminateOnMaxIter<>(maxIterationNum));

            DataStream<Tuple2<Integer[], DenseVector[]>> countsAndSums =
                    points.connect(centroids.broadcast())
                            .transform(
                                    "CentroidsUpdateAccumulator",
                                    new TupleTypeInfo<>(
                                            BasicArrayTypeInfo.INT_ARRAY_TYPE_INFO,
                                            ObjectArrayTypeInfo.getInfoFor(
                                                    DenseVectorTypeInfo.INSTANCE)),
                                    new CentroidsUpdateAccumulator(distanceMeasure));
            DataStreamUtils.setManagedMemoryWeight(countsAndSums.getTransformation(), 100L);

            // Every accumulator subtask emits exactly one partial result per round, so a count
            // window of size `parallelism` collects one round's worth of partials.
            int parallelism = countsAndSums.getParallelism();
            DataStream<KMeansModelData> newModelData =
                    countsAndSums
                            .countWindowAll(parallelism)
                            .reduce(new CentroidsUpdateReducer())
                            .map(new ModelDataGenerator());

            DataStream<DenseVector[]> newCentroids =
                    newModelData.map(modelData -> modelData.centroids).setParallelism(1);

            DataStream<KMeansModelData> finalModelData =
                    newModelData.flatMap(new ForwardInputsOfLastRound<>());

            return new IterationBodyResult(
                    DataStreamList.of(newCentroids),
                    DataStreamList.of(finalModelData),
                    terminationCriteria);
        }
    }

    /**
     * Per-subtask operator that caches the training points and, at each epoch watermark, assigns
     * every cached point to its closest centroid.
     *
     * <p>Input 1 carries training points and input 2 carries the current round's centroids. For
     * each centroid index, the operator emits the number of assigned points and the vector sum of
     * those points.
     */
    private static class CentroidsUpdateAccumulator
            extends AbstractStreamOperator<Tuple2<Integer[], DenseVector[]>>
            implements TwoInputStreamOperator<
                            DenseVector, DenseVector[], Tuple2<Integer[], DenseVector[]>>,
                    IterationListener<Tuple2<Integer[], DenseVector[]>> {

        private final DistanceMeasure distanceMeasure;

        /** Centroids of the current round; holds at most one array (checked on arrival). */
        private ListState<DenseVector[]> centroids;

        /** Training points received by this subtask; cleared only when the iteration ends. */
        private ListStateWithCache<VectorWithNorm> points;

        public CentroidsUpdateAccumulator(DistanceMeasure distanceMeasure) {
            this.distanceMeasure = distanceMeasure;
        }

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            centroids =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "centroids",
                                            ObjectArrayTypeInfo.getInfoFor(
                                                    DenseVectorTypeInfo.INSTANCE)));
            points =
                    new ListStateWithCache<>(
                            new VectorWithNormSerializer(),
                            getContainingTask(),
                            getRuntimeContext(),
                            context,
                            config.getOperatorID());
        }

        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            points.snapshotState(context);
        }

        @Override
        public void processElement1(StreamRecord<DenseVector> element) throws Exception {
            points.add(new VectorWithNorm(element.getValue()));
        }

        @Override
        public void processElement2(StreamRecord<DenseVector[]> element) throws Exception {
            // One centroids array per round; the previous one is cleared after it is consumed
            // in onEpochWatermarkIncremented.
            Preconditions.checkState(!centroids.get().iterator().hasNext());
            centroids.add(element.getValue());
        }

        @Override
        public void onEpochWatermarkIncremented(
                int epochWatermark,
                IterationListener.Context context,
                Collector<Tuple2<Integer[], DenseVector[]>> collector)
                throws Exception {
            DenseVector[] currentCentroids =
                    Objects.requireNonNull(
                            OperatorStateUtils.getUniqueElement(centroids, "centroids")
                                    .orElse(null),
                            "No centroids were received for the current round.");

            int numCentroids = currentCentroids.length;
            VectorWithNorm[] centroidsWithNorm = new VectorWithNorm[numCentroids];
            DenseVector[] sums = new DenseVector[numCentroids];
            Integer[] counts = new Integer[numCentroids];
            Arrays.fill(counts, 0);
            for (int i = 0; i < numCentroids; i++) {
                centroidsWithNorm[i] = new VectorWithNorm(currentCentroids[i]);
                sums[i] = new DenseVector(currentCentroids[i].size());
            }

            for (VectorWithNorm point : points.get()) {
                int closest = distanceMeasure.findClosest(centroidsWithNorm, point);
                BLAS.axpy(1.0, point.vector, sums[closest]);
                counts[closest]++;
            }

            output.collect(new StreamRecord<>(Tuple2.of(counts, sums)));
            centroids.clear();
        }

        @Override
        public void onIterationTerminated(
                IterationListener.Context context,
                Collector<Tuple2<Integer[], DenseVector[]>> collector) {
            centroids.clear();
            points.clear();
        }
    }

    /**
     * Merges two partial (counts, sums) results by element-wise addition, accumulating into and
     * returning the first argument.
     */
    private static class CentroidsUpdateReducer
            implements ReduceFunction<Tuple2<Integer[], DenseVector[]>> {
        @Override
        public Tuple2<Integer[], DenseVector[]> reduce(
                Tuple2<Integer[], DenseVector[]> accumulated,
                Tuple2<Integer[], DenseVector[]> partial) {
            for (int i = 0; i < accumulated.f0.length; i++) {
                accumulated.f0[i] += partial.f0[i];
                BLAS.axpy(1.0, partial.f1[i], accumulated.f1[i]);
            }
            return accumulated;
        }
    }

    /**
     * Turns reduced (counts, sums) into {@link KMeansModelData}.
     *
     * <p>Each sum is scaled in place by {@code 1 / count} to become the new centroid, and the
     * counts are kept as the per-centroid weights.
     */
    private static class ModelDataGenerator
            implements MapFunction<Tuple2<Integer[], DenseVector[]>, KMeansModelData> {
        @Override
        public KMeansModelData map(Tuple2<Integer[], DenseVector[]> countsAndSums) {
            Integer[] counts = countsAndSums.f0;
            DenseVector[] sums = countsAndSums.f1;
            double[] weights = new double[counts.length];
            for (int i = 0; i < counts.length; i++) {
                // NOTE(review): a cluster that received no points has count 0, so the factor is
                // Infinity. Assuming BLAS.scal multiplies element-wise, 0 * Infinity makes that
                // centroid all-NaN. Consider keeping the previous centroid instead — confirm the
                // intended behavior.
                BLAS.scal(1.0 / counts[i], sums[i]);
                weights[i] = counts[i];
            }
            return new KMeansModelData(sums, new DenseVector(weights));
        }
    }
}
