package org.apache.beam.sdk.extensions.sql.error;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.extensions.sql.SqlTransform;
import org.apache.beam.sdk.extensions.sql.TestUtils;
import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestTableUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.Row;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.Mockito;

/**
 * Tests that rows failing UDF evaluation inside a {@link SqlTransform} query are handed to the
 * transform registered via {@link SqlTransform#withErrorsTransformer}, while the remaining rows
 * still produce query output.
 *
 * <p>The input has three rows: a valid one, one with an invalid {@code amount} ({@value
 * #INVALID_AMOUNT}) and one with an invalid {@code currency} ({@value #INVALID_CURRENCY}). The
 * test expects only the valid row in the query result, and each invalid row (wrapped together
 * with its error message) in its own error {@link PCollection}.
 */
public class BeamSqlErrorTest {
    private static final String ID = "id";
    private static final String AMOUNT = "amount";
    private static final String COUNTRY_CODE = "country_code";
    private static final String CURRENCY = "currency";
    private static final String INVALID_AMOUNT = "100$";
    private static final String INVALID_CURRENCY = "*";
    private static final String SUM_AMOUNT = "sum_amount";
    private static final String F_3 = "f3";
    private static final String ROW = "row";
    private static final String ERROR = "error";

    /** Aggregating query whose two UDF calls are expected to reject the two invalid rows. */
    private static final String QUERY =
            "SELECT id,country_code,CalculatePrice(sum(CastUdf(amount)),currency) as sum_amount "
                    + "FROM PCOLLECTION group by id,country_code,currency";

    @Rule
    public final TestPipeline pipeline = TestPipeline.create();

    protected PCollection<Row> boundedInputBytes;

    static List<Row> inputRows;
    static Schema inputType;

    /** Builds the shared input schema and the three input rows (one valid, two invalid). */
    @BeforeClass
    public static void prepareClass() {
        inputType =
                Schema.builder()
                        .addStringField(ID)
                        .addStringField(AMOUNT)
                        .addStringField(COUNTRY_CODE)
                        .addStringField(CURRENCY)
                        .build();
        inputRows =
                TestUtils.RowsBuilder.of(inputType)
                        .addRows(
                                "1", "100", "US", "$",
                                "2", INVALID_AMOUNT, "US", "$",
                                "3", "100", "US", INVALID_CURRENCY)
                        .getRows();
    }

    /** Applies the input rows to the per-test pipeline as a bounded, schema-aware collection. */
    @Before
    public void preparePCollections() {
        boundedInputBytes =
                pipeline.apply("boundedInput", Create.of(inputRows).withRowSchema(inputType));
    }

    /**
     * Runs {@link #QUERY} with a pass-through spy as the errors transformer and checks the valid
     * output row plus the content of both captured error collections.
     */
    @Test
    // Mockito can only spy/capture via raw class literals, so raw/unchecked conversions are
    // unavoidable when bridging to the generic Beam types below.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public void testFailedExpression() {
        Schema resultType =
                Schema.builder()
                        .addStringField(ID)
                        .addStringField(COUNTRY_CODE)
                        .addDoubleField(SUM_AMOUNT)
                        .build();
        // Schema nested in the second expected error row: `amount` has been replaced by the
        // INT64 column `f3` (presumably the output of the CastUdf stage).
        Schema intermediateType =
                Schema.builder()
                        .addStringField(ID)
                        .addStringField(COUNTRY_CODE)
                        .addStringField(CURRENCY)
                        .addInt64Field(F_3)
                        .build();

        PTransform<PCollection<Row>, PCollection<Row>> errorsTransform =
                Mockito.spy(PTransform.class);
        // Pass-through stub: expand() hands back its input so it can be captured and asserted on.
        // NOTE(review): org.mockito.Matchers is deprecated since Mockito 2; prefer
        // org.mockito.ArgumentMatchers.any() once the file's imports are updated.
        Mockito.when(errorsTransform.expand(Matchers.any()))
                .thenAnswer(invocation -> invocation.getArgument(0, PCollection.class));
        ArgumentCaptor<PCollection> errorsCaptor = ArgumentCaptor.forClass(PCollection.class);

        PCollection<Row> result =
                boundedInputBytes.apply(
                        "calculate",
                        SqlTransform.query(QUERY)
                                .withAutoLoading(false)
                                .withErrorsTransformer(errorsTransform)
                                .registerUdf("CastUdf", CastUdf.class)
                                .registerUdf("CalculatePrice", CalculatePrice.class));
        PAssert.that(result.setCoder(SchemaCoder.of(resultType)))
                .containsInAnyOrder(
                        TestUtils.RowsBuilder.of(resultType).addRows("1", "US", 100.0).getRows());

        Row amountError =
                errorRow(
                        inputType,
                        Arrays.asList("2", INVALID_AMOUNT, "US", "$"),
                        "Found invalid value " + INVALID_AMOUNT);
        Row currencyError =
                errorRow(
                        intermediateType,
                        Arrays.asList("3", "US", INVALID_CURRENCY, 100L),
                        "Currency isn't supported " + INVALID_CURRENCY);

        // The errors transformer is expanded while the query is applied above (pipeline
        // construction), so both error collections are already captured before run().
        Mockito.verify(errorsTransform, Mockito.times(2)).expand(errorsCaptor.capture());
        List<PCollection> errorCollections = errorsCaptor.getAllValues();
        PCollection<Row> amountErrors = errorCollections.get(0);
        PCollection<Row> currencyErrors = errorCollections.get(1);
        PAssert.that(amountErrors).containsInAnyOrder(amountError);
        PAssert.that(currencyErrors).containsInAnyOrder(currencyError);

        pipeline.run().waitUntilFinish();
    }

    /**
     * Builds one expected error row: the failed row nested under {@value #ROW} and its message
     * under {@value #ERROR}.
     *
     * @param failedRowSchema schema of the row that failed
     * @param failedRowValues field values of the failed row, in schema order
     * @param message the error message expected alongside the failed row
     * @return the expected error row
     */
    private static Row errorRow(Schema failedRowSchema, List<?> failedRowValues, String message) {
        Schema errorSchema =
                Schema.builder().addRowField(ROW, failedRowSchema).addStringField(ERROR).build();
        Row failedRow = TestTableUtils.buildRows(failedRowSchema, failedRowValues).get(0);
        return TestTableUtils.buildRows(errorSchema, Arrays.asList(failedRow, message)).get(0);
    }
}
