/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.error;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.extensions.sql.SqlTransform;
import org.apache.beam.sdk.extensions.sql.TestUtils;
import org.apache.beam.sdk.extensions.sql.error.CalculatePrice;
import org.apache.beam.sdk.extensions.sql.error.CastUdf;
import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestTableUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.Row;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.Mockito;
import org.mockito.verification.VerificationMode;

/**
 * Checks that rows failing inside SQL user-defined functions are handed to the transformer
 * registered via {@link SqlTransform#withErrorsTransformer}, while the remaining rows are
 * produced as the normal query output.
 */
public class BeamSqlErrorTest {
    private static final String ID = "id";
    private static final String AMOUNT = "amount";
    private static final String COUNTRY_CODE = "country_code";
    private static final String CURRENCY = "currency";
    private static final String INVALID_AMOUNT = "100$";
    private static final String INVALID_CURRENCY = "*";
    private static final String SUM_AMOUNT = "sum_amount";
    private static final String F_3 = "f3";
    private static final String ROW = "row";
    private static final String ERROR = "error";

    @Rule
    public final TestPipeline pipeline = TestPipeline.create();

    /** Bounded PCollection built from {@link #inputRows}; recreated before every test. */
    protected PCollection<Row> boundedInputBytes;

    static List<Row> inputRows;
    static Schema inputType;

    /** Builds the shared input schema and the three input rows used by every test. */
    @BeforeClass
    public static void prepareClass() {
        inputType =
            Schema.builder()
                .addStringField(ID)
                .addStringField(AMOUNT)
                .addStringField(COUNTRY_CODE)
                .addStringField(CURRENCY)
                .build();
        // Row "1" is well-formed, row "2" carries a malformed amount and
        // row "3" carries a currency symbol the test expects to be rejected.
        inputRows =
            TestUtils.RowsBuilder.of(inputType)
                .addRows(
                    "1", "100", "US", "$",
                    "2", INVALID_AMOUNT, "US", "$",
                    "3", "100", "US", INVALID_CURRENCY)
                .getRows();
    }

    /** Materializes {@link #inputRows} into the test pipeline. */
    @Before
    public void preparePCollections() {
        boundedInputBytes =
            pipeline.apply("boundedInput", Create.of(inputRows).withRowSchema(inputType));
    }

    @Test
    @SuppressWarnings({"rawtypes", "unchecked"}) // the spied transform and captor are raw
    public void testFailedExpression() {
        String sql =
            "SELECT id,country_code,CalculatePrice(sum(CastUdf(amount)),currency) as sum_amount FROM PCOLLECTION group by id,country_code,currency";
        Schema midResultType =
            Schema.builder()
                .addStringField(ID)
                .addStringField(COUNTRY_CODE)
                .addStringField(CURRENCY)
                .addInt64Field(F_3)
                .build();
        Schema resultType =
            Schema.builder()
                .addStringField(ID)
                .addStringField(COUNTRY_CODE)
                .addDoubleField(SUM_AMOUNT)
                .build();

        // Pass-through errors transformer: expand() hands back its input so that each
        // captured error PCollection can be asserted on once the query is expanded.
        PTransform errorsTransformer = Mockito.spy(PTransform.class);
        Mockito.when(errorsTransformer.expand(Matchers.any()))
            .thenAnswer(invocation -> invocation.getArgument(0, PCollection.class));
        ArgumentCaptor<PCollection> errorCaptor = ArgumentCaptor.forClass(PCollection.class);

        PCollection<Row> validRows =
            boundedInputBytes
                .apply(
                    "calculate",
                    SqlTransform.query(sql)
                        .withAutoLoading(false)
                        .withErrorsTransformer(errorsTransformer)
                        .registerUdf("CastUdf", CastUdf.class)
                        .registerUdf("CalculatePrice", CalculatePrice.class))
                .setCoder(SchemaCoder.of(resultType));
        PAssert.that(validRows)
            .containsInAnyOrder(
                TestUtils.RowsBuilder.of(resultType).addRows("1", "US", 100.0).getRows());

        Row failedOnFirstUdf =
            errorRow(
                inputType,
                Arrays.asList("2", INVALID_AMOUNT, "US", "$"),
                "Found invalid value 100$");
        Row failedOnSecondUdf =
            errorRow(
                midResultType,
                Arrays.asList("3", "US", INVALID_CURRENCY, 100L),
                "Currency isn't supported *");

        // The transformer must be expanded exactly twice, once per UDF; the assertions
        // below rely on the CastUdf error stream being captured before the CalculatePrice one.
        Mockito.verify(errorsTransformer, Mockito.times(2)).expand(errorCaptor.capture());
        List<PCollection> capturedErrors = errorCaptor.getAllValues();
        PAssert.that(capturedErrors.get(0)).containsInAnyOrder(failedOnFirstUdf);
        PAssert.that(capturedErrors.get(1)).containsInAnyOrder(failedOnSecondUdf);

        pipeline.run().waitUntilFinish();
    }

    /**
     * Builds the row expected on an errors stream: the offending row nested under field
     * {@code row}, paired with its message in field {@code error}.
     *
     * @param failedRowType schema of the row that failed
     * @param failedRowValues field values of the row that failed, in schema order
     * @param message the expected error message
     * @return a single row of schema {@code (row: failedRowType, error: STRING)}
     */
    private static Row errorRow(Schema failedRowType, List<?> failedRowValues, String message) {
        Schema errorType =
            Schema.builder().addRowField(ROW, failedRowType).addStringField(ERROR).build();
        Row failedRow = TestTableUtils.buildRows(failedRowType, failedRowValues).get(0);
        return TestTableUtils.buildRows(errorType, Arrays.asList(failedRow, message)).get(0);
    }
}

