package org.apache.flink.ml.benchmark.datagenerator.common;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.param.IntParam;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.ParamValidators;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/**
 * Generates a single table of labeled points with weights.
 *
 * <p>Each row has three columns, named by the single entry of {@code getColNames()}:
 *
 * <ol>
 *   <li>a dense feature vector of length {@code getVectorDim()},
 *   <li>a {@code double} label,
 *   <li>a {@code double} weight drawn from the row generator's {@code random.nextDouble()}.
 * </ol>
 *
 * <p>Feature and label values are integers (stored as doubles) in {@code [0, arity - 1]} when the
 * corresponding arity is positive, and {@code random.nextDouble()} draws when the arity is zero.
 */
public class LabeledPointWithWeightGenerator
        extends InputTableGenerator<LabeledPointWithWeightGenerator>
        implements HasVectorDim<LabeledPointWithWeightGenerator> {

    /** Arity of each feature value; must be {@code >= 0}. Defaults to 2. */
    public static final Param<Integer> FEATURE_ARITY = new IntParam("featureArity", "Arity of each feature. If set to positive value, each feature would be an integer in range [0, arity - 1]. If set to zero, each feature would be a continuous double in range [0, 1).", 2, ParamValidators.gtEq(0.0d));

    /** Arity of the label value; must be {@code >= 0}. Defaults to 2. */
    public static final Param<Integer> LABEL_ARITY = new IntParam("labelArity", "Arity of label. If set to positive value, the label would be an integer in range [0, arity - 1]. If set to zero, the label would be a continuous double in range [0, 1).", 2, ParamValidators.gtEq(0.0d));

    /** Creates a generator whose parameters are initialized to their declared defaults. */
    public LabeledPointWithWeightGenerator() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /** Returns the configured value of {@link #FEATURE_ARITY}. */
    public int getFeatureArity() {
        return get(FEATURE_ARITY);
    }

    /**
     * Sets {@link #FEATURE_ARITY}.
     *
     * @param value the feature arity; must be {@code >= 0}
     * @return this generator, for chaining
     */
    public LabeledPointWithWeightGenerator setFeatureArity(int value) {
        return set(FEATURE_ARITY, value);
    }

    /** Returns the configured value of {@link #LABEL_ARITY}. */
    public int getLabelArity() {
        return get(LABEL_ARITY);
    }

    /**
     * Sets {@link #LABEL_ARITY}.
     *
     * @param value the label arity; must be {@code >= 0}
     * @return this generator, for chaining
     */
    public LabeledPointWithWeightGenerator setLabelArity(int value) {
        return set(LABEL_ARITY, value);
    }

    /**
     * Builds the single row generator backing the output table.
     *
     * @return a one-element array holding the generator for the (features, label, weight) table
     * @throws IllegalStateException if the column names do not describe exactly one table with
     *     exactly three columns
     */
    @Override
    protected RowGenerator[] getRowGenerators() {
        final String[][] colNames = getColNames();
        // Report the offending sizes so a misconfigured colNames param is easy to diagnose.
        Preconditions.checkState(
                colNames.length == 1,
                "Expected exactly 1 group of column names, but got %s.",
                colNames.length);
        Preconditions.checkState(
                colNames[0].length == 3,
                "Expected exactly 3 column names (features, label, weight), but got %s.",
                colNames[0].length);

        // Read params once, eagerly, so the anonymous generator captures plain values.
        final int vectorDim = getVectorDim();
        final int featureArity = getFeatureArity();
        final int labelArity = getLabelArity();

        return new RowGenerator[] {
            new RowGenerator(getNumValues(), getSeed()) {
                @Override
                protected Row getRow() {
                    // Draw order (features, then label, then weight) determines the output for
                    // a given seed; keep it stable.
                    double[] features = new double[vectorDim];
                    for (int i = 0; i < vectorDim; i++) {
                        features[i] = getValue(featureArity);
                    }
                    double label = getValue(labelArity);
                    double weight = random.nextDouble();
                    return Row.of(Vectors.dense(features), label, weight);
                }

                @Override
                public RowTypeInfo getRowTypeInfo() {
                    return new RowTypeInfo(
                            new TypeInformation[] {
                                DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
                            },
                            colNames[0]);
                }

                /**
                 * Draws one value.
                 *
                 * @param arity a positive arity yields {@code random.nextInt(arity)}; otherwise
                 *     {@code random.nextDouble()}
                 */
                private double getValue(int arity) {
                    if (arity > 0) {
                        return random.nextInt(arity);
                    }
                    return random.nextDouble();
                }
            }
        };
    }
}
