/*
 * Decompiled with CFR 0.152.
 */
package test.org.apache.spark.sql;

import java.io.Serializable;
import java.util.List;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.types.DataTypes;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

public class JavaUDFSuite
implements Serializable {
    private transient SparkSession spark;

    @Before
    public void setUp() {
        this.spark = SparkSession.builder().master("local[*]").appName("testing").getOrCreate();
    }

    @After
    public void tearDown() {
        this.spark.stop();
        this.spark = null;
    }

    @Test
    public void udf1Test() {
        this.spark.udf().register("stringLengthTest", (UDF1)new UDF1<String, Integer>(){

            public Integer call(String str) {
                return str.length();
            }
        }, DataTypes.IntegerType);
        Row result = (Row)this.spark.sql("SELECT stringLengthTest('test')").head();
        Assert.assertEquals((long)4L, (long)result.getInt(0));
    }

    @Test
    public void udf2Test() {
        this.spark.udf().register("stringLengthTest", (UDF2)new UDF2<String, String, Integer>(){

            public Integer call(String str1, String str2) {
                return str1.length() + str2.length();
            }
        }, DataTypes.IntegerType);
        Row result = (Row)this.spark.sql("SELECT stringLengthTest('test', 'test2')").head();
        Assert.assertEquals((long)9L, (long)result.getInt(0));
    }

    @Test
    public void udf3Test() {
        this.spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(), DataTypes.IntegerType);
        Row result = (Row)this.spark.sql("SELECT stringLengthTest('test', 'test2')").head();
        Assert.assertEquals((long)9L, (long)result.getInt(0));
        this.spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null);
        result = (Row)this.spark.sql("SELECT stringLengthTest('test', 'test2')").head();
        Assert.assertEquals((long)9L, (long)result.getInt(0));
    }

    @Test
    public void udf4Test() {
        this.spark.udf().register("inc", (UDF1)new UDF1<Long, Long>(){

            public Long call(Long i) {
                return i + 1L;
            }
        }, DataTypes.LongType);
        this.spark.range(10L).toDF(new String[]{"x"}).createOrReplaceTempView("tmp");
        List results = this.spark.sql("SELECT inc(x) FROM tmp GROUP BY inc(x)").collectAsList();
        Assert.assertEquals((long)10L, (long)results.size());
        long sum = 0L;
        for (Row result : results) {
            sum += result.getLong(0);
        }
        Assert.assertEquals((long)55L, (long)sum);
    }

    public static class StringLengthTest
    implements UDF2<String, String, Integer> {
        public Integer call(String str1, String str2) throws Exception {
            return new Integer(str1.length() + str2.length());
        }
    }
}

