package org.apache.flink.ml.clustering.kmeans;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections.IteratorUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.iteration.DataStreamList;
import org.apache.flink.iteration.IterationBody;
import org.apache.flink.iteration.IterationBodyResult;
import org.apache.flink.iteration.Iterations;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.distance.DistanceMeasure;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.VectorWithNorm;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:org/apache/flink/ml/clustering/kmeans/OnlineKMeans.class */
public class OnlineKMeans implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
    /** Parameter values of this estimator; filled with default values by the constructor. */
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /**
     * Table holding the initial model data. Stays {@code null} until it is assigned by
     * {@link #setInitialModelData(Table)} or by {@link #load}.
     */
    private Table initModelDataTable;

    /**
     * Map function that reads the features field of each input row and returns it as a
     * {@link DenseVector}.
     */
    public static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
        /** Name of the row field that holds the feature vector. */
        private final String featuresCol;

        private FeaturesExtractor(String featuresColumnName) {
            featuresCol = featuresColumnName;
        }

        /**
         * Looks up the configured field in the row and converts it to a dense vector.
         *
         * @param row input row; the configured field is cast to {@link Vector}
         * @return the result of calling {@code toDense()} on that field
         */
        public DenseVector map(Row row) {
            Vector features = (Vector) row.getField(featuresCol);
            return features.toDense();
        }
    }

    /**
     * Reduce function that merges two model data records into one.
     *
     * <p>For every centroid index the merged centroid is the weight-weighted average of the two
     * input centroids, and the merged weight is the sum of the two input weights. The centroid and
     * weight storage of the first argument is updated in place and wrapped in the returned record.
     */
    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
        private ModelDataGlobalReducer() {
        }

        /**
         * Merges {@code kMeansModelData2} into {@code kMeansModelData}.
         *
         * @param kMeansModelData record whose centroids and weights are overwritten with the result
         * @param kMeansModelData2 record that is only read; it determines the centroid count and
         *     the vector dimension used by the loops
         * @return a new record wrapping the updated centroid array and weight vector of the first
         *     argument
         */
        public KMeansModelData reduce(KMeansModelData kMeansModelData, KMeansModelData kMeansModelData2) {
            DenseVector[] mergedCentroids = kMeansModelData.centroids;
            DenseVector mergedWeights = kMeansModelData.weights;
            DenseVector[] otherCentroids = kMeansModelData2.centroids;
            DenseVector otherWeights = kMeansModelData2.weights;

            int numCentroids = otherCentroids.length;
            int dim = otherCentroids[0].size();

            for (int c = 0; c < numCentroids; c++) {
                double[] target = mergedCentroids[c].values;
                double[] other = otherCentroids[c].values;
                double targetWeight = mergedWeights.values[c];
                double otherWeight = otherWeights.values[c];
                // The denominator is floored at 1.0E-16 so a zero total weight never divides by zero.
                double totalWeight = Math.max(targetWeight + otherWeight, 1.0E-16d);

                for (int d = 0; d < dim; d++) {
                    target[d] = ((target[d] * targetWeight) + (other[d] * otherWeight)) / totalWeight;
                }
                // The stored weight is the plain sum, without the floor applied above.
                mergedWeights.values[c] = targetWeight + otherWeight;
            }
            return new KMeansModelData(mergedCentroids, mergedWeights);
        }
    }

    /**
     * Two-input operator that updates model data with one locally buffered batch of points at a
     * time.
     *
     * <p>Batches (input 1) and model data records (input 2) are buffered in operator list state.
     * After every arrival the operator checks whether both a model data record and at least one
     * batch are buffered; if so it consumes the model data and the oldest batch, updates the
     * centroids and weights, and emits the result.
     */
    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData> implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
        private final DistanceMeasure distanceMeasure;
        private final int k;
        private final double decayFactor;
        private ListState<DenseVector[]> localBatchDataState;
        private ListState<KMeansModelData> modelDataState;

        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, int numClusters, double decay) {
            this.distanceMeasure = distanceMeasure;
            this.k = numClusters;
            this.decayFactor = decay;
        }

        /** Registers the two operator list states, named "localBatch" and "modelData". */
        public void initializeState(StateInitializationContext stateInitializationContext) throws Exception {
            super.initializeState(stateInitializationContext);

            TypeInformation<DenseVector[]> batchType = ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
            localBatchDataState = stateInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor<>("localBatch", batchType));
            modelDataState = stateInitializationContext.getOperatorStateStore().getListState(new ListStateDescriptor<>("modelData", KMeansModelData.class));
        }

        /** Buffers an incoming batch of points, then attempts a model update. */
        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
            DenseVector[] batch = streamRecord.getValue();
            localBatchDataState.add(batch);
            alignAndComputeModelData();
        }

        /**
         * Buffers incoming model data after checking that it has exactly {@code k} centroids, then
         * attempts a model update.
         */
        public void processElement2(StreamRecord<KMeansModelData> streamRecord) throws Exception {
            KMeansModelData incoming = streamRecord.getValue();
            Preconditions.checkArgument(incoming.centroids.length == k);
            modelDataState.add(incoming);
            alignAndComputeModelData();
        }

        /**
         * Performs one update step if both a model data record and a batch are buffered; otherwise
         * returns without doing anything.
         */
        private void alignAndComputeModelData() throws Exception {
            boolean hasModelData = modelDataState.get().iterator().hasNext();
            if (!hasModelData || !localBatchDataState.get().iterator().hasNext()) {
                return;
            }

            // Take the single buffered model data record and remove it from state.
            KMeansModelData modelData = OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
            DenseVector[] centroids = modelData.centroids;
            VectorWithNorm[] centroidsWithNorm = new VectorWithNorm[modelData.centroids.length];
            for (int c = 0; c < centroidsWithNorm.length; c++) {
                centroidsWithNorm[c] = new VectorWithNorm(modelData.centroids[c]);
            }
            DenseVector weights = modelData.weights;
            modelDataState.clear();

            // Take the oldest buffered batch and write the remaining batches back to state.
            @SuppressWarnings("unchecked")
            List<DenseVector[]> pendingBatches = IteratorUtils.toList(localBatchDataState.get().iterator());
            DenseVector[] points = pendingBatches.remove(0);
            localBatchDataState.update(pendingBatches);

            int dim = centroids[0].size();
            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();

            // Per-centroid sum and count of the points whose findClosest result is that centroid.
            DenseVector[] sums = new DenseVector[k];
            int[] counts = new int[k];
            for (int c = 0; c < k; c++) {
                sums[c] = new DenseVector(dim);
            }
            for (DenseVector point : points) {
                int closest = distanceMeasure.findClosest(centroidsWithNorm, new VectorWithNorm(point));
                counts[closest]++;
                BLAS.axpy(1.0d, point, sums[closest]);
            }

            // Scale the existing weights by decayFactor / parallelism before adding the new counts.
            BLAS.scal(decayFactor / parallelism, weights);
            for (int c = 0; c < k; c++) {
                if (counts[c] == 0) {
                    continue;
                }
                DenseVector centroid = centroids[c];
                weights.values[c] = weights.values[c] + counts[c];
                // Fraction of the updated weight contributed by this batch.
                double lambda = counts[c] / weights.values[c];
                BLAS.scal(1.0d - lambda, centroid);
                BLAS.axpy(lambda / counts[c], sums[c], centroid);
            }

            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
        }
    }

    /**
     * Iteration body of the online training loop: it wires batched points and broadcast model data
     * into {@link ModelDataLocalUpdater}, then merges the per-subtask results with
     * {@link ModelDataGlobalReducer}.
     */
    public static class OnlineKMeansIterationBody implements IterationBody {
        private final DistanceMeasure distanceMeasure;
        private final int k;
        private final double decayFactor;
        private final int batchSize;

        /**
         * @param distanceMeasure distance measure handed to the local updater
         * @param i value stored as {@code k} and handed to the local updater
         * @param d value stored as the decay factor and handed to the local updater
         * @param i2 value stored as the batch size used when batching the input points
         */
        public OnlineKMeansIterationBody(DistanceMeasure distanceMeasure, int i, double d, int i2) {
            this.distanceMeasure = distanceMeasure;
            this.k = i;
            this.decayFactor = d;
            this.batchSize = i2;
        }

        /**
         * Builds one round of the iteration.
         *
         * @param dataStreamList variable streams; element 0 is used as the model data stream
         * @param dataStreamList2 data streams; element 0 is used as the points stream
         * @return a result whose first list holds the newly computed model data stream and whose
         *     second list holds the incoming model data stream
         * @throws IllegalStateException if the points stream parallelism exceeds the batch size
         */
        @Override // org.apache.flink.iteration.IterationBody
        public IterationBodyResult process(DataStreamList dataStreamList, DataStreamList dataStreamList2) {
            DataStream<KMeansModelData> modelData = dataStreamList.get(0);
            DataStream<DenseVector> points = dataStreamList2.get(0);

            int parallelism = points.getParallelism();
            Preconditions.checkState(parallelism <= batchSize, "There are more subtasks in the training process than the number of elements in each batch. Some subtasks might be idling forever.");

            DataStream<DenseVector[]> batches = DataStreamUtils.generateBatchData(points, parallelism, batchSize);

            // Every updater subtask receives the broadcast model data together with its own batches.
            DataStream<KMeansModelData> localUpdates = batches
                    .connect(modelData.broadcast())
                    .transform("ModelDataLocalUpdater", TypeInformation.of(KMeansModelData.class), new ModelDataLocalUpdater(distanceMeasure, k, decayFactor))
                    .setParallelism(parallelism);

            // One result per updater subtask is collected in a count window and reduced to one record.
            DataStream<KMeansModelData> newModelData = localUpdates
                    .countWindowAll(parallelism)
                    .reduce(new ModelDataGlobalReducer());

            return new IterationBodyResult(DataStreamList.of(newModelData), DataStreamList.of(modelData));
        }
    }

    /** Creates an estimator whose parameter map is initialized with the default parameter values. */
    public OnlineKMeans() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Builds the online training pipeline and returns a model backed by its output.
     *
     * <p>The features column of the input table is mapped to dense vectors and, together with the
     * initial model data stream (forced to parallelism 1), fed into an unbounded iteration driven
     * by {@link OnlineKMeansIterationBody}. The first output stream of the iteration becomes the
     * model data table of the returned model, and this estimator's parameters are copied onto it.
     *
     * @param tableArr exactly one table containing the features column
     * @return a model whose model data table is the output of the training iteration
     * @throws IllegalArgumentException if the number of input tables is not one
     * @throws NullPointerException if no initial model data table has been set
     */
    @Override // org.apache.flink.ml.api.Estimator
    public OnlineKMeansModel fit(Table... tableArr) {
        Preconditions.checkArgument(tableArr.length == 1);
        // Same check and message as save(): fail with a clear message before any part of the
        // pipeline is built, instead of an opaque NullPointerException later on.
        Preconditions.checkNotNull(this.initModelDataTable, "Initial Model Data Table should have been set.");

        // getTableEnvironment() is declared to return the base TableEnvironment type, so an
        // explicit downcast is required for the stream-specific API used below.
        StreamTableEnvironment tableEnvironment = (StreamTableEnvironment) ((TableImpl) tableArr[0]).getTableEnvironment();

        DataStream<DenseVector> points = tableEnvironment.toDataStream(tableArr[0]).map(new FeaturesExtractor(getFeaturesCol()));

        DataStream<KMeansModelData> initModelData = KMeansModelData.getModelDataStream(this.initModelDataTable);
        initModelData.getTransformation().setParallelism(1);

        OnlineKMeansIterationBody body = new OnlineKMeansIterationBody(DistanceMeasure.getInstance(getDistanceMeasure()), getK(), getDecayFactor(), getGlobalBatchSize());
        DataStream<KMeansModelData> onlineModelData = Iterations.iterateUnboundedStreams(DataStreamList.of(initModelData), DataStreamList.of(points), body).get(0);

        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(tableEnvironment.fromDataStream(onlineModelData));
        ParamUtils.updateExistingParams(model, this.paramMap);
        return model;
    }

    /**
     * Saves this estimator's metadata and its initial model data under the given path.
     *
     * @param str target path passed to the metadata and model data writers
     * @throws IOException declared by the interface and propagated from the write helpers
     * @throws NullPointerException if no initial model data table has been set
     */
    @Override // org.apache.flink.ml.api.Stage
    public void save(String str) throws IOException {
        Preconditions.checkNotNull(initModelDataTable, "Initial Model Data Table should have been set.");

        ReadWriteUtils.saveMetadata(this, str);

        DataStream<KMeansModelData> initModelData = KMeansModelData.getModelDataStream(initModelDataTable);
        KMeansModelData.ModelDataEncoder encoder = new KMeansModelData.ModelDataEncoder();
        ReadWriteUtils.saveModelData(initModelData, str, encoder);
    }

    /**
     * Restores an estimator from the given path: first its stage parameters, then its initial
     * model data table.
     *
     * @param streamTableEnvironment environment used to load the model data table
     * @param str path that the estimator was saved to
     * @return the restored estimator with its initial model data table assigned
     * @throws IOException propagated from the load helpers
     */
    public static OnlineKMeans load(StreamTableEnvironment streamTableEnvironment, String str) throws IOException {
        OnlineKMeans estimator = (OnlineKMeans) ReadWriteUtils.loadStageParam(str);

        KMeansModelData.ModelDataDecoder decoder = new KMeansModelData.ModelDataDecoder();
        Table restoredModelData = ReadWriteUtils.loadModelData(streamTableEnvironment, str, decoder);
        estimator.initModelDataTable = restoredModelData;

        return estimator;
    }

    /**
     * Returns the live parameter map of this estimator (not a copy).
     *
     * @return the map from parameter definition to its current value
     */
    @Override // org.apache.flink.ml.param.WithParams
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Sets the table holding the initial model data used by {@code fit} and {@code save}.
     *
     * @param table table containing the initial model data
     * @return this estimator, to allow call chaining
     */
    public OnlineKMeans setInitialModelData(Table table) {
        initModelDataTable = table;
        return this;
    }
}
