package org.apache.spark.ml.feature;

import java.util.Arrays;
import java.util.List;
import org.apache.spark.SharedSparkSession;
import org.apache.spark.ml.feature.VectorIndexerSuite;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/spark/ml/feature/JavaVectorIndexerSuite.class */
/**
 * Java API test for {@link VectorIndexer}: fits a model on a tiny in-memory dataset and checks
 * the fitted model through its Java-facing accessors.
 */
public class JavaVectorIndexerSuite extends SharedSparkSession {

    /**
     * Fits a {@link VectorIndexer} with {@code maxCategories = 2} on three 2-dimensional points,
     * then transforms the same data.
     *
     * <p>Feature 0 takes the values {0.0, 1.0} and feature 1 takes {-2.0, 3.0, 4.0}. The test
     * expects the model to report 2 features and exactly 1 category map. It also expects the
     * transformed data to keep every input row and to carry the {@code "indexed"} output column.
     */
    @Test
    public void vectorIndexerAPI() {
        List<VectorIndexerSuite.FeatureData> points = Arrays.asList(
            new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)),
            new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
            new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)));
        Dataset<Row> data =
            spark.createDataFrame(jsc.parallelize(points, 2), VectorIndexerSuite.FeatureData.class);

        VectorIndexerModel model = new VectorIndexer()
            .setInputCol("features")
            .setOutputCol("indexed")
            .setMaxCategories(2)
            .fit(data);

        // JUnit's assertEquals signature is (expected, actual); keep that order so failure
        // messages report the values the right way round.
        Assert.assertEquals(2, model.numFeatures());
        Assert.assertEquals(1, model.javaCategoryMaps().size());

        // Exercise transform and check its output rather than discarding it.
        Dataset<Row> indexed = model.transform(data);
        Assert.assertTrue(
            "transformed data should contain the output column",
            Arrays.asList(indexed.columns()).contains("indexed"));
        Assert.assertEquals(points.size(), indexed.count());
    }
}
