package org.apache.spark.ml.feature;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/* loaded from: input_file:org/apache/spark/ml/feature/JavaVectorSlicerSuite.class */
/**
 * Java API tests for {@link VectorSlicer}.
 *
 * <p>Each test runs against a local {@link SparkSession} that is obtained before and stopped
 * after every test method.
 */
public class JavaVectorSlicerSuite {
    private transient SparkSession spark;

    /** Obtains a local Spark session (via {@code getOrCreate}) for the test. */
    @Before
    public void setUp() {
        this.spark = SparkSession.builder()
            .master("local")
            .appName("JavaVectorSlicerSuite")
            .getOrCreate();
    }

    /**
     * Stops the session. Guarded against a null session so that a failure in {@link #setUp()}
     * is not masked by a NullPointerException here.
     */
    @After
    public void tearDown() {
        if (this.spark != null) {
            this.spark.stop();
            this.spark = null;
        }
    }

    /**
     * Slices a three-feature column by index ({@code 1}) and by attribute name ({@code "f3"})
     * and checks that every output vector has exactly two elements, for both a sparse and a
     * dense input row.
     */
    @Test
    public void vectorSlice() {
        // Named attribute metadata on the column is what allows selecting features by name.
        Attribute[] attrs = new Attribute[]{
            NumericAttribute.defaultAttr().withName("f1"),
            NumericAttribute.defaultAttr().withName("f2"),
            NumericAttribute.defaultAttr().withName("f3")
        };
        AttributeGroup group = new AttributeGroup("userFeatures", attrs);

        // The same logical vector [-2.0, 2.3, 0.0], once sparse and once dense.
        List<Row> data = Arrays.asList(
            RowFactory.create(Vectors.sparse(3, new int[]{0, 1}, new double[]{-2.0, 2.3})),
            RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)));

        Dataset<Row> dataset =
            spark.createDataFrame(data, new StructType().add(group.toStructField()));

        VectorSlicer vectorSlicer = new VectorSlicer()
            .setInputCol("userFeatures")
            .setOutputCol("features");
        vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});

        List<Row> rows = vectorSlicer.transform(dataset)
            .select("userFeatures", "features")
            .takeAsList(2);

        // Without this check the loop below would pass vacuously on an empty result.
        Assert.assertEquals(data.size(), rows.size());
        for (Row row : rows) {
            Vector features = row.getAs(1);
            // One feature selected by index plus one selected by name; JUnit order is (expected, actual).
            Assert.assertEquals(2, features.size());
        }
    }
}
