package au.csiro.pathling.sql;

import au.csiro.pathling.test.SpringBootUnitTest;
import au.csiro.pathling.test.assertions.DatasetAssert;
import au.csiro.pathling.test.builders.DatasetBuilder;
import java.util.Objects;
import java.util.stream.Stream;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;

@SpringBootUnitTest
/* loaded from: input_file:au/csiro/pathling/sql/PruneSyntheticFieldsTest.class */
public class PruneSyntheticFieldsTest {

    @Autowired
    private SparkSession spark;
    private final Metadata metadata = Metadata.empty();
    private final StructType testStructType = DataTypes.createStructType(new StructField[]{new StructField("id", DataTypes.IntegerType, true, this.metadata), new StructField("name", DataTypes.StringType, true, this.metadata), new StructField("_fid", DataTypes.StringType, true, this.metadata)});

    @Test
    public void testPruneSyntheticFields() {
        Dataset repartition = new DatasetBuilder(this.spark).withIdColumn().withColumn("active", DataTypes.BooleanType).withColumn("gender", DataTypes.createArrayType(DataTypes.StringType)).withStructTypeColumns(this.testStructType).withRow("patient-1", true, new String[]{"array_value-00-00"}, RowFactory.create(new Object[]{1, "Test-1", "fid_value-00"})).withRow("patient-2", false, new String[]{"array_value-01-00", "array_value-01-01"}, RowFactory.create(new Object[]{2, "Test-2", "fid_value_01"})).withRow("patient-3", null, null, null).buildWithStructValue().repartition(1);
        Stream of = Stream.of((Object[]) repartition.columns());
        Objects.requireNonNull(repartition);
        DatasetAssert.of(repartition.select((Column[]) of.map(repartition::col).map(SqlExpressions::pruneSyntheticFields).toArray(i -> {
            return new Column[i];
        }))).hasRows(new DatasetBuilder(this.spark).withIdColumn().withColumn("active", DataTypes.BooleanType).withColumn("gender", DataTypes.createArrayType(DataTypes.StringType)).withStructColumn("id", DataTypes.IntegerType).withStructColumn("name", DataTypes.StringType).withRow("patient-1", true, new String[]{"array_value-00-00"}, RowFactory.create(new Object[]{1, "Test-1"})).withRow("patient-2", false, new String[]{"array_value-01-00", "array_value-01-01"}, RowFactory.create(new Object[]{2, "Test-2"})).withRow("patient-3", null, null, null).buildWithStructValue().repartition(1));
    }

    @Test
    public void testPruneInGroupBy() {
        Dataset repartition = new DatasetBuilder(this.spark).withIdColumn().withColumn("gender", DataTypes.StringType).withColumn("active", DataTypes.BooleanType).withStructTypeColumns(this.testStructType).withRow("patient-1", "male", true, RowFactory.create(new Object[]{1, "Test-1", "fid-00"})).withRow("patient-2", "female", false, RowFactory.create(new Object[]{2, "Test-2", "fid-01"})).withRow("patient-3", "male", true, null).withRow("patient-4", null, true, null).withRow("patient-5", "female", false, RowFactory.create(new Object[]{2, "Test-2", "fid-02"})).buildWithStructValue().repartition(1);
        DatasetAssert.of(repartition.groupBy(new Column[]{SqlExpressions.pruneSyntheticFields(repartition.col(repartition.columns()[repartition.columns().length - 1]))}).agg(functions.count(repartition.col("gender")), new Column[0])).hasRows(RowFactory.create(new Object[]{RowFactory.create(new Object[]{1, "Test-1"}), 1}), RowFactory.create(new Object[]{RowFactory.create(new Object[]{2, "Test-2"}), 2}), RowFactory.create(new Object[]{null, 1}));
    }
}
