package org.apache.flink.ml.benchmark.datagenerator.clustering;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import org.apache.flink.ml.benchmark.datagenerator.DataGenerator;
import org.apache.flink.ml.benchmark.datagenerator.common.DenseVectorArrayGenerator;
import org.apache.flink.ml.benchmark.datagenerator.param.HasArraySize;
import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.functions.ScalarFunction;
import org.apache.flink.table.types.inference.TypeInference;

/**
 * Generates model data for KMeans benchmarks.
 *
 * <p>The generated table has a {@code centroids} column and a {@code weights} column. The
 * centroids come from a {@link DenseVectorArrayGenerator} that is configured from this
 * generator's params (vector dimension, array size, and any others both generators share). The
 * weights are derived from the centroids by {@link GenerateWeightsFunction}.
 */
public class KMeansModelDataGenerator
        implements DataGenerator<KMeansModelDataGenerator>,
                HasVectorDim<KMeansModelDataGenerator>,
                HasArraySize<KMeansModelDataGenerator> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /**
     * Scalar function that derives a weights vector from an array of centroids.
     *
     * <p>The returned vector is built with {@code new DenseVector(centroids.length)}: it has one
     * entry per centroid. Its values are presumably all zero (confirm against the {@code
     * DenseVector(int)} constructor).
     */
    public static class GenerateWeightsFunction extends ScalarFunction {
        /**
         * Builds the weights vector for the given centroids.
         *
         * @param centroids the centroids of one model-data row
         * @return a dense vector whose size equals the number of centroids
         */
        public DenseVector eval(DenseVector[] centroids) {
            return new DenseVector(centroids.length);
        }

        /**
         * Declares the output type explicitly as the DenseVector type info. The planner cannot
         * infer a type for the custom {@link DenseVector} class on its own.
         */
        @Override
        public TypeInference getTypeInference(DataTypeFactory typeFactory) {
            return TypeInference.newBuilder()
                    .outputTypeStrategy(
                            callContext ->
                                    Optional.of(
                                            DataTypes.of(DenseVectorTypeInfo.INSTANCE)
                                                    .toDataType(typeFactory)))
                    .build();
        }
    }

    /** Creates a generator whose params are initialized to their default values. */
    public KMeansModelDataGenerator() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Generates the KMeans model data.
     *
     * @param tEnv the table environment used to create the tables
     * @return a single-element array holding a table with columns {@code centroids} and {@code
     *     weights}
     */
    @Override // org.apache.flink.ml.benchmark.datagenerator.DataGenerator
    public Table[] getData(StreamTableEnvironment tEnv) {
        DenseVectorArrayGenerator centroidsGenerator = new DenseVectorArrayGenerator();
        // Copy over this generator's params that the array generator also defines
        // (e.g. vector dimension and array size).
        ParamUtils.updateExistingParams(centroidsGenerator, paramMap);
        // Presumably one row: the model is a single set of centroids.
        centroidsGenerator.setNumValues(1L);
        // The column-name param is two-dimensional (presumably one name array per generated
        // table), so it needs an explicit String[][] literal.
        centroidsGenerator.setColNames(new String[][] {{"centroids"}});

        Table centroidsTable = centroidsGenerator.getData(tEnv)[0];
        Table modelDataTable =
                centroidsTable.select(
                        Expressions.$("centroids"),
                        Expressions.call(GenerateWeightsFunction.class, Expressions.$("centroids"))
                                .as("weights"));
        return new Table[] {modelDataTable};
    }

    @Override // org.apache.flink.ml.param.WithParams
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }
}
