package org.apache.beam.sdk.io.singlestore;

import com.singlestore.jdbc.Statement;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import javax.sql.DataSource;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.io.common.TestRow;
import org.apache.beam.sdk.io.singlestore.SingleStoreIO;
import org.apache.beam.sdk.io.singlestore.TestHelper;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.commons.dbcp2.DelegatingStatement;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

@RunWith(JUnit4.class)
/* loaded from: input_file:org/apache/beam/sdk/io/singlestore/WriteTest.class */
public class WriteTest {

    @Rule
    public final transient TestPipeline pipeline = TestPipeline.create();
    public final transient Pipeline pipelineForErrorChecks = Pipeline.create();
    private static SingleStoreIO.DataSourceConfiguration dataSourceConfiguration;
    private static final int EXPECTED_ROW_COUNT = 1000;
    private static final List<TestRow> writtenRows = Collections.synchronizedList(new ArrayList());

    /* loaded from: input_file:org/apache/beam/sdk/io/singlestore/WriteTest$BatchSizeChecker.class */
    private static class BatchSizeChecker implements SerializableFunction<Iterable<Integer>, Void> {
        Integer maxBatchSize;

        BatchSizeChecker(Integer num) {
            this.maxBatchSize = num;
        }

        public Void apply(Iterable<Integer> iterable) {
            Iterator<Integer> it = iterable.iterator();
            while (it.hasNext()) {
                Assert.assertTrue(it.next().intValue() <= this.maxBatchSize.intValue());
            }
            return null;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/beam/sdk/io/singlestore/WriteTest$ExecuteUpdate.class */
    public static class ExecuteUpdate implements Serializable, Answer<Integer> {
        SetInputStream inputStreamSetter;

        ExecuteUpdate(SetInputStream setInputStream) {
            this.inputStreamSetter = setInputStream;
        }

        /* renamed from: answer, reason: merged with bridge method [inline-methods] */
        public Integer m7answer(InvocationOnMock invocationOnMock) {
            InputStream inputStream = this.inputStreamSetter.getInputStream();
            StringBuilder sb = new StringBuilder();
            byte[] bArr = new byte[100];
            while (true) {
                try {
                    int read = inputStream.read(bArr);
                    if (read == -1) {
                        break;
                    }
                    for (int i = 0; i < read; i++) {
                        sb.append((char) bArr[i]);
                    }
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }
            List splitToList = Splitter.on('\n').omitEmptyStrings().splitToList(sb.toString());
            Iterator it = splitToList.iterator();
            while (it.hasNext()) {
                List splitToList2 = Splitter.on('\t').splitToList((String) it.next());
                WriteTest.writtenRows.add(TestRow.create(Integer.valueOf((String) splitToList2.get(0)), (String) splitToList2.get(1)));
            }
            return Integer.valueOf(splitToList.size());
        }
    }

    /* loaded from: input_file:org/apache/beam/sdk/io/singlestore/WriteTest$GetDataSource.class */
    private static class GetDataSource implements Serializable, Answer<DataSource> {
        private GetDataSource() {
        }

        /* renamed from: answer, reason: merged with bridge method [inline-methods] */
        public DataSource m8answer(InvocationOnMock invocationOnMock) throws SQLException {
            Statement statement = (Statement) Mockito.mock(Statement.class);
            SetInputStream setInputStream = new SetInputStream();
            ((Statement) Mockito.doAnswer(setInputStream).when(statement)).setNextLocalInfileInputStream((InputStream) Mockito.any());
            DelegatingStatement delegatingStatement = (DelegatingStatement) Mockito.mock(DelegatingStatement.class);
            Mockito.when(delegatingStatement.getInnermostDelegate()).thenReturn(statement);
            ((DelegatingStatement) Mockito.doAnswer(new ExecuteUpdate(setInputStream)).when(delegatingStatement)).executeUpdate("LOAD DATA LOCAL INFILE '###.tsv' INTO TABLE `t`");
            Connection connection = (Connection) Mockito.mock(Connection.class);
            Mockito.when(connection.createStatement()).thenReturn(delegatingStatement);
            DataSource dataSource = (DataSource) Mockito.mock(DataSource.class);
            Mockito.when(dataSource.getConnection()).thenReturn(connection);
            return dataSource;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/beam/sdk/io/singlestore/WriteTest$SetInputStream.class */
    public static class SetInputStream implements Serializable, Answer<Void> {
        InputStream inputStream;

        private SetInputStream() {
        }

        /* renamed from: answer, reason: merged with bridge method [inline-methods] */
        public Void m9answer(InvocationOnMock invocationOnMock) {
            this.inputStream = (InputStream) invocationOnMock.getArgument(0, InputStream.class);
            return null;
        }

        public InputStream getInputStream() {
            return this.inputStream;
        }
    }

    void checkRows() {
        Assert.assertEquals(new HashSet(Lists.newArrayList(TestRow.getExpectedValues(0, EXPECTED_ROW_COUNT))), new HashSet(writtenRows));
    }

    @After
    public void cleanup() {
        writtenRows.clear();
    }

    @Before
    public void init() {
        dataSourceConfiguration = (SingleStoreIO.DataSourceConfiguration) Mockito.mock(TestHelper.MockDataSourceConfiguration.class, Mockito.withSettings().serializable());
        ((SingleStoreIO.DataSourceConfiguration) Mockito.doAnswer(new GetDataSource()).when(dataSourceConfiguration)).getDataSource();
    }

    @Test
    public void testWrite() {
        PCollection apply = this.pipeline.apply(GenerateSequence.from(0L).to(1000L)).apply(ParDo.of(new TestRow.DeterministicallyConstructTestRowFn())).apply(SingleStoreIO.write().withDataSourceConfiguration(dataSourceConfiguration).withTable("t").withUserDataMapper(new TestHelper.TestUserDataMapper()).withBatchSize(334));
        PAssert.thatSingleton(apply.apply("Sum All", Sum.integersGlobally())).isEqualTo(Integer.valueOf(EXPECTED_ROW_COUNT));
        PAssert.that(apply).satisfies(new BatchSizeChecker(334));
        this.pipeline.run();
        checkRows();
    }

    @Test
    public void testWriteSmallBatchSize() {
        PCollection apply = this.pipeline.apply(GenerateSequence.from(0L).to(1000L)).apply(ParDo.of(new TestRow.DeterministicallyConstructTestRowFn())).apply(SingleStoreIO.write().withDataSourceConfiguration(dataSourceConfiguration).withTable("t").withUserDataMapper(new TestHelper.TestUserDataMapper()).withBatchSize(1));
        PAssert.thatSingleton(apply.apply("Sum All", Sum.integersGlobally())).isEqualTo(Integer.valueOf(EXPECTED_ROW_COUNT));
        PAssert.that(apply).satisfies(new BatchSizeChecker(1));
        this.pipeline.run();
        checkRows();
    }

    @Test
    public void testWriteBigBatchSize() {
        PCollection apply = this.pipeline.apply(GenerateSequence.from(0L).to(1000L)).apply(ParDo.of(new TestRow.DeterministicallyConstructTestRowFn())).apply(SingleStoreIO.write().withDataSourceConfiguration(dataSourceConfiguration).withTable("t").withUserDataMapper(new TestHelper.TestUserDataMapper()).withBatchSize(Integer.valueOf(EXPECTED_ROW_COUNT)));
        PAssert.thatSingleton(apply.apply("Sum All", Sum.integersGlobally())).isEqualTo(Integer.valueOf(EXPECTED_ROW_COUNT));
        PAssert.that(apply).satisfies(new BatchSizeChecker(Integer.valueOf(EXPECTED_ROW_COUNT)));
        this.pipeline.run();
        checkRows();
    }

    @Test
    public void testWriteInvalidBatchSize() {
        Assert.assertThrows("batchSize should be greater then 0", IllegalArgumentException.class, () -> {
            this.pipelineForErrorChecks.apply(GenerateSequence.from(0L).to(1000L)).apply(ParDo.of(new TestRow.DeterministicallyConstructTestRowFn())).apply(SingleStoreIO.write().withDataSourceConfiguration(dataSourceConfiguration).withTable("t").withUserDataMapper(new TestHelper.TestUserDataMapper()).withBatchSize(0));
        });
    }
}
