package org.apache.beam.sdk.io.singlestore;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import javax.sql.DataSource;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.common.TestRow;
import org.apache.beam.sdk.io.singlestore.SingleStoreIO;
import org.apache.beam.sdk.io.singlestore.TestHelper;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.values.PCollection;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mockito;
import org.mockito.stubbing.OngoingStubbing;

@RunWith(JUnit4.class)
public class ReadWithPartitionsTest {

    /** Total number of rows served across both mocked partitions. */
    private static final int EXPECTED_ROW_COUNT = 10;

    /** Rows {@code [0, FIRST_PARTITION_ROW_COUNT)} come from partition 0; the rest from partition 1. */
    private static final int FIRST_PARTITION_ROW_COUNT = 5;

    /**
     * Query passed to {@code withQuery()}. The mocked connection only answers partition queries
     * wrapping exactly this text, so the {@code withTable("t")} test relies on the table form
     * expanding to it as well.
     */
    private static final String QUERY = "SELECT * FROM `t`";

    @Rule
    public final transient TestPipeline pipeline = TestPipeline.create();

    /**
     * Pipeline used only to trigger construction-time validation errors; it is never run.
     * Presumably kept separate from {@link #pipeline} so the TestPipeline rule does not see
     * transforms that were applied but not executed.
     */
    public final transient Pipeline pipelineForErrorChecks = Pipeline.create();

    /** Mocked configuration, rebuilt before every test by {@link #init()}. */
    private SingleStoreIO.DataSourceConfiguration dataSourceConfiguration;

    /**
     * Builds a serializable mock {@link ResultSet} that yields one row per integer in
     * {@code [from, to)}: column 1 holds the integer and column 2 holds
     * {@code TestRow.getNameForSeed(integer)}.
     *
     * @param from first row id (inclusive)
     * @param to last row id (exclusive)
     * @return a mock whose {@code next()} returns {@code true} {@code to - from} times, then
     *     {@code false}
     * @throws SQLException declared only because the stubbed JDBC methods declare it
     */
    ResultSet getMockResultSet(int from, int to) throws SQLException {
        ResultSet resultSet = Mockito.mock(ResultSet.class, Mockito.withSettings().serializable());

        OngoingStubbing<Boolean> hasNext = Mockito.when(resultSet.next());
        for (int id = from; id < to; id++) {
            hasNext = hasNext.thenReturn(true);
        }
        hasNext.thenReturn(false);

        // For an empty range, skip the column stubs entirely: a when(...) without any
        // thenReturn(...) is reported by Mockito as an unfinished stubbing.
        if (from < to) {
            OngoingStubbing<Integer> ids = Mockito.when(resultSet.getInt(1));
            for (int id = from; id < to; id++) {
                ids = ids.thenReturn(id);
            }
            OngoingStubbing<String> names = Mockito.when(resultSet.getString(2));
            for (int id = from; id < to; id++) {
                names = names.thenReturn(TestRow.getNameForSeed(id));
            }
        }
        return resultSet;
    }

    /**
     * Wires a mocked {@link DataSource} for database {@code db}. It reports two partitions, and the
     * partition queries for partitions 0 and 1 return rows {@code [0, 5)} and {@code [5, 10)}.
     */
    @Before
    public void init() throws SQLException {
        ResultSet firstPartitionRows = getMockResultSet(0, FIRST_PARTITION_ROW_COUNT);
        ResultSet secondPartitionRows = getMockResultSet(FIRST_PARTITION_ROW_COUNT, EXPECTED_ROW_COUNT);

        // Single-row answer to the num_partitions lookup: the database has 2 partitions.
        ResultSet numPartitions = Mockito.mock(ResultSet.class, Mockito.withSettings().serializable());
        Mockito.when(numPartitions.next()).thenReturn(true).thenReturn(false);
        Mockito.when(numPartitions.getInt(1)).thenReturn(2);

        Statement statement = Mockito.mock(Statement.class, Mockito.withSettings().serializable());
        Mockito.when(statement.executeQuery("SELECT num_partitions FROM information_schema.DISTRIBUTED_DATABASES WHERE database_name = 'db'")).thenReturn(numPartitions);

        PreparedStatement firstPartitionStatement = Mockito.mock(PreparedStatement.class, Mockito.withSettings().serializable());
        Mockito.when(firstPartitionStatement.executeQuery()).thenReturn(firstPartitionRows);
        PreparedStatement secondPartitionStatement = Mockito.mock(PreparedStatement.class, Mockito.withSettings().serializable());
        Mockito.when(secondPartitionStatement.executeQuery()).thenReturn(secondPartitionRows);

        Connection connection = Mockito.mock(Connection.class, Mockito.withSettings().serializable());
        Mockito.when(connection.createStatement()).thenReturn(statement);
        Mockito.when(connection.prepareStatement("SELECT * FROM (SELECT * FROM `t`) WHERE partition_id()=0")).thenReturn(firstPartitionStatement);
        Mockito.when(connection.prepareStatement("SELECT * FROM (SELECT * FROM `t`) WHERE partition_id()=1")).thenReturn(secondPartitionStatement);

        DataSource dataSource = Mockito.mock(DataSource.class, Mockito.withSettings().serializable());
        Mockito.when(dataSource.getConnection()).thenReturn(connection);

        dataSourceConfiguration = Mockito.mock(SingleStoreIO.DataSourceConfiguration.class, Mockito.withSettings().serializable());
        Mockito.when(dataSourceConfiguration.getDataSource()).thenReturn(dataSource);
        Mockito.when(dataSourceConfiguration.getDatabase()).thenReturn("db");
    }

    @Test
    public void testReadWithPartitions() {
        PCollection<TestRow> rows = pipeline.apply(
                SingleStoreIO.<TestRow>readWithPartitions()
                        .withDataSourceConfiguration(dataSourceConfiguration)
                        .withQuery(QUERY)
                        .withRowMapper(new TestHelper.TestRowMapper()));
        assertContainsAllExpectedRows(rows);
        pipeline.run();
    }

    @Test
    public void testReadWithPartitionsWithTable() {
        PCollection<TestRow> rows = pipeline.apply(
                SingleStoreIO.<TestRow>readWithPartitions()
                        .withDataSourceConfiguration(dataSourceConfiguration)
                        .withTable("t")
                        .withRowMapper(new TestHelper.TestRowMapper()));
        assertContainsAllExpectedRows(rows);
        pipeline.run();
    }

    @Test
    public void testReadWithPartitionsNoTableAndQuery() {
        // assertThrows' optional String argument is only the failure description, so the
        // exception message has to be checked explicitly.
        IllegalArgumentException e = Assert.assertThrows(
                IllegalArgumentException.class,
                () -> pipelineForErrorChecks.apply(
                        SingleStoreIO.<TestRow>readWithPartitions()
                                .withDataSourceConfiguration(dataSourceConfiguration)
                                .withRowMapper(new TestHelper.TestRowMapper())));
        Assert.assertEquals("One of withTable() or withQuery() is required", e.getMessage());
    }

    @Test
    public void testReadWithPartitionsBothTableAndQuery() {
        IllegalArgumentException e = Assert.assertThrows(
                IllegalArgumentException.class,
                () -> pipelineForErrorChecks.apply(
                        SingleStoreIO.<TestRow>readWithPartitions()
                                .withDataSourceConfiguration(dataSourceConfiguration)
                                .withTable("t")
                                .withQuery(QUERY)
                                .withRowMapper(new TestHelper.TestRowMapper())));
        Assert.assertEquals("withTable() can not be used together with withQuery()", e.getMessage());
    }

    /**
     * Registers PAssert checks that {@code rows} has exactly {@link #EXPECTED_ROW_COUNT} elements
     * and that they equal {@code TestRow.getExpectedValues(0, EXPECTED_ROW_COUNT)} in any order.
     */
    private static void assertContainsAllExpectedRows(PCollection<TestRow> rows) {
        PAssert.thatSingleton(rows.apply("Count All", Count.globally())).isEqualTo((long) EXPECTED_ROW_COUNT);
        PAssert.that(rows).containsInAnyOrder(TestRow.getExpectedValues(0, EXPECTED_ROW_COUNT));
    }
}
