package org.apache.spark.sql;

import java.io.Serializable;
import java.math.BigDecimal;
import java.util.Arrays;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.test.SparkConnectServerUtils;
import org.apache.spark.sql.types.StructType;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;

/* loaded from: input_file:org/apache/spark/sql/JavaEncoderSuite.class */
/**
 * JUnit tests that exercise Spark SQL {@link Encoders} from Java, using a
 * {@link SparkSession} obtained from {@link SparkConnectServerUtils}.
 *
 * <p>NOTE(review): the class is {@link Serializable} presumably because the
 * anonymous {@link MapFunction} in {@link #testRowEncoder()} captures the
 * enclosing instance and is serialized along with it — confirm before removing.
 */
public class JavaEncoderSuite implements Serializable {
    /** Session shared by every test; created in {@link #setup()}, cleared in {@link #tearDown()}. */
    private static SparkSession spark;

    /** Creates the shared session once before any test runs. */
    @BeforeClass
    public static void setup() {
        spark = SparkConnectServerUtils.createSparkSession();
    }

    /**
     * Stops the shared session and then the test server.
     *
     * <p>JUnit runs {@code @AfterClass} even when {@code @BeforeClass} threw, so
     * the session may still be {@code null} here; guarding against that keeps a
     * setup failure from being masked by a {@link NullPointerException}. The
     * {@code finally} block makes sure the server is stopped even if stopping
     * the session fails.
     */
    @AfterClass
    public static void tearDown() {
        try {
            if (spark != null) {
                spark.stop();
            }
        } finally {
            spark = null;
            SparkConnectServerUtils.stop();
        }
    }

    /**
     * Builds a {@link BigDecimal} equal to {@code unscaled * 10^-scale}.
     *
     * @param unscaled the unscaled value
     * @param scale the number of digits to the right of the decimal point
     * @return the corresponding decimal
     */
    private static BigDecimal bigDec(long unscaled, int scale) {
        return BigDecimal.valueOf(unscaled, scale);
    }

    /**
     * Creates a dataset on the shared session from the given elements.
     *
     * <p>The varargs array is only read (wrapped by {@link Arrays#asList}), never
     * written to, so {@code @SafeVarargs} is justified. The method is static
     * because it touches only the static session, which also keeps the
     * annotation legal on Java 8 (where it is not allowed on non-final instance
     * methods).
     *
     * @param encoder encoder for the element type
     * @param elements the dataset contents, in order
     * @param <T> element type
     * @return a dataset holding {@code elements}
     */
    @SafeVarargs
    private static <T> Dataset<T> dataset(Encoder<T> encoder, T... elements) {
        return spark.createDataset(Arrays.asList(elements), encoder);
    }

    /**
     * Round-trips values through each primitive/decimal encoder and checks one
     * aggregate per type.
     *
     * <p>The explicit {@code longValue()}/{@code shortValue()}/... calls are
     * deliberate: passing a boxed value next to a primitive literal would make
     * the {@code Assert.assertEquals} overload resolution ambiguous.
     */
    @Test
    public void testSimpleEncoders() {
        final Column value = functions.col("value");

        Assert.assertFalse(
            dataset(Encoders.BOOLEAN(), false, true, false)
                .select(functions.every(value))
                .as(Encoders.BOOLEAN())
                .head());

        // -120 + 127 == 7
        Assert.assertEquals(
            7L,
            dataset(Encoders.BYTE(), (byte) -120, Byte.MAX_VALUE)
                .select(functions.sum(value))
                .as(Encoders.LONG())
                .head()
                .longValue());

        Assert.assertEquals(
            16L,
            dataset(Encoders.SHORT(), (short) 16, (short) 2334)
                .select(functions.min(value))
                .as(Encoders.SHORT())
                .head()
                .shortValue());

        Assert.assertEquals(
            10L,
            dataset(Encoders.INT(), 1, 2, 3, 4)
                .select(functions.sum(value))
                .as(Encoders.LONG())
                .head()
                .longValue());

        Assert.assertEquals(
            96L,
            dataset(Encoders.LONG(), 77L, 19L)
                .select(functions.sum(value))
                .as(Encoders.LONG())
                .head()
                .longValue());

        Assert.assertEquals(
            0.12f,
            dataset(Encoders.FLOAT(), 0.12f, 0.3f, 44.0f)
                .select(functions.min(value))
                .as(Encoders.FLOAT())
                .head()
                .floatValue(),
            0.0001f);

        // The tolerance is a float literal widened to double, which is exactly
        // the delta the double overload of assertEquals received before.
        Assert.assertEquals(
            789.0d,
            dataset(Encoders.DOUBLE(), 789.0d, 12.213d, 10.01d)
                .select(functions.max(value))
                .as(Encoders.DOUBLE())
                .head()
                .doubleValue(),
            0.0001f);

        // 10.00 + 0.02 == 10.02; the result is rescaled to 2 so that
        // BigDecimal.equals (which is scale-sensitive) can match.
        Assert.assertEquals(
            bigDec(1002L, 2),
            dataset(Encoders.DECIMAL(), bigDec(1000L, 2), bigDec(2L, 2))
                .select(functions.sum(value))
                .as(Encoders.DECIMAL())
                .head()
                .setScale(2));
    }

    /**
     * Maps {@code range(3)} to rows of {@code (a: int, b: string)} through a row
     * encoder, keeps the rows with {@code a >= 1}, and checks the collected
     * result.
     */
    @Test
    public void testRowEncoder() {
        final StructType schema = new StructType().add("a", "int").add("b", "string");

        final Dataset<Row> rows = spark.range(3L)
            .map(new MapFunction<Long, Row>() {
                @Override
                public Row call(Long id) {
                    return RowFactory.create(id.intValue(), "s" + id);
                }
            }, Encoders.row(schema))
            .filter(functions.col("a").geq(1));

        Assert.assertEquals(
            Arrays.asList(RowFactory.create(1, "s1"), RowFactory.create(2, "s2")),
            rows.collectAsList());
    }
}
