package org.apache.beam.sdk.io.gcp.spanner;

import com.google.cloud.Timestamp;
import com.google.cloud.spanner.DatabaseClient;
import com.google.cloud.spanner.KeySet;
import com.google.cloud.spanner.Options;
import com.google.cloud.spanner.ReadOnlyTransaction;
import com.google.cloud.spanner.ResultSets;
import com.google.cloud.spanner.Statement;
import com.google.cloud.spanner.Struct;
import com.google.cloud.spanner.TimestampBound;
import com.google.cloud.spanner.Type;
import com.google.cloud.spanner.Value;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.io.gcp.spanner.SpannerIO;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFnTester;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Matchers;
import org.mockito.Mockito;

/**
 * Unit tests for the read side of {@link SpannerIO}.
 *
 * <p>No real Spanner service is contacted. The database client is the Mockito mock exposed by
 * {@link FakeServiceFactory#mockDatabaseClient()}, and each test stubs it to hand out the mock
 * {@link ReadOnlyTransaction} created in {@link #setUp()}, which in turn is stubbed to return
 * the fixed rows in {@code FAKE_ROWS}.
 */
@RunWith(JUnit4.class)
public class SpannerIOReadTest implements Serializable {

    @Rule
    public final transient TestPipeline pipeline = TestPipeline.create();

    @Rule
    public final transient ExpectedException thrown = ExpectedException.none();

    /** Row type shared by every fake result set: {@code (id INT64, name STRING)}. */
    private static final Type FAKE_TYPE = Type.struct(
            Type.StructField.of("id", Type.int64()),
            Type.StructField.of("name", Type.string()));

    /** The four rows that the stubbed transaction returns. */
    private static final List<Struct> FAKE_ROWS = Arrays.asList(
            userRow(1, "Alice"),
            userRow(2, "Bob"),
            userRow(3, "Carl"),
            userRow(4, "Dan"));

    private FakeServiceFactory serviceFactory;
    private ReadOnlyTransaction mockTx;

    /** Builds one row matching {@code FAKE_TYPE}. */
    private static Struct userRow(long id, String name) {
        return Struct.newBuilder()
                .add("id", Value.int64(id))
                .add("name", Value.string(name))
                .build();
    }

    /** Creates a fresh service factory and a fresh transaction mock before every test. */
    @Before
    public void setUp() throws Exception {
        serviceFactory = new FakeServiceFactory();
        mockTx = Mockito.mock(ReadOnlyTransaction.class);
    }

    /** Makes the mock database client return {@code mockTx} for any timestamp bound. */
    private void stubReadOnlyTransaction() {
        Mockito.when(serviceFactory.mockDatabaseClient()
                        .readOnlyTransaction(Matchers.any(TimestampBound.class)))
                .thenReturn(mockTx);
    }

    /**
     * Stubbing shared by the pipeline tests: any read-only transaction is {@code mockTx}, the
     * statement {@code SELECT 1} yields an empty result set, and the transaction reports
     * {@code readTimestamp} as its read timestamp.
     */
    private void stubTransactionAt(Timestamp readTimestamp) {
        stubReadOnlyTransaction();
        // Presumably "SELECT 1" is the probe issued while the transaction is created - confirm.
        Mockito.when(mockTx.executeQuery(Statement.of("SELECT 1")))
                .thenReturn(ResultSets.forRows(Type.struct(), Collections.<Struct>emptyList()));
        Mockito.when(mockTx.getReadTimestamp()).thenReturn(readTimestamp);
    }

    /** Connection settings used by the pipeline tests, wired to the fake service factory. */
    private SpannerConfig fakeSpannerConfig() {
        return SpannerConfig.create()
                .withProjectId("test")
                .withInstanceId("123")
                .withDatabaseId("aaa")
                .withServiceFactory(serviceFactory);
    }

    /**
     * A query read run through {@link NaiveSpannerReadFn} emits every stubbed row, opens a
     * strong read-only transaction and executes exactly the configured statement.
     */
    @Test
    public void runQuery() throws Exception {
        SpannerIO.Read read = SpannerIO.read()
                .withProjectId("test")
                .withInstanceId("123")
                .withDatabaseId("aaa")
                .withQuery("SELECT * FROM users")
                .withServiceFactory(serviceFactory);
        DoFnTester<ReadOperation, Struct> fnTester =
                DoFnTester.of(new NaiveSpannerReadFn(read.getSpannerConfig()));

        stubReadOnlyTransaction();
        Mockito.when(mockTx.executeQuery(Matchers.any(Statement.class)))
                .thenReturn(ResultSets.forRows(FAKE_TYPE, FAKE_ROWS));

        List<Struct> rows = fnTester.processBundle(read.getReadOperation());
        Assert.assertThat(rows, org.hamcrest.Matchers.containsInAnyOrder(FAKE_ROWS.toArray()));

        Mockito.verify(serviceFactory.mockDatabaseClient())
                .readOnlyTransaction(TimestampBound.strong());
        Mockito.verify(mockTx).executeQuery(Statement.of("SELECT * FROM users"));
    }

    /**
     * A table read run through {@link NaiveSpannerReadFn} emits every stubbed row, opens a
     * strong read-only transaction and reads all keys of the configured table and columns.
     */
    @Test
    public void runRead() throws Exception {
        SpannerIO.Read read = SpannerIO.read()
                .withProjectId("test")
                .withInstanceId("123")
                .withDatabaseId("aaa")
                .withTable("users")
                .withColumns("id", "name")
                .withServiceFactory(serviceFactory);
        DoFnTester<ReadOperation, Struct> fnTester =
                DoFnTester.of(new NaiveSpannerReadFn(read.getSpannerConfig()));

        stubReadOnlyTransaction();
        Mockito.when(mockTx.read("users", KeySet.all(), Arrays.asList("id", "name")))
                .thenReturn(ResultSets.forRows(FAKE_TYPE, FAKE_ROWS));

        List<Struct> rows = fnTester.processBundle(read.getReadOperation());
        Assert.assertThat(rows, org.hamcrest.Matchers.containsInAnyOrder(FAKE_ROWS.toArray()));

        Mockito.verify(serviceFactory.mockDatabaseClient())
                .readOnlyTransaction(TimestampBound.strong());
        Mockito.verify(mockTx).read("users", KeySet.all(), Arrays.asList("id", "name"));
    }

    /**
     * A table read with an index goes through {@code readUsingIndex} with the configured index
     * name and still emits every stubbed row.
     */
    @Test
    public void runReadUsingIndex() throws Exception {
        SpannerIO.Read read = SpannerIO.read()
                .withProjectId("test")
                .withInstanceId("123")
                .withDatabaseId("aaa")
                .withTimestamp(Timestamp.now())
                .withTable("users")
                .withColumns("id", "name")
                .withIndex("theindex")
                .withServiceFactory(serviceFactory);
        // Only the connection config of the read is handed to the function under test.
        DoFnTester<ReadOperation, Struct> fnTester =
                DoFnTester.of(new NaiveSpannerReadFn(read.getSpannerConfig()));

        stubReadOnlyTransaction();
        Mockito.when(mockTx.readUsingIndex(
                        "users", "theindex", KeySet.all(), Arrays.asList("id", "name")))
                .thenReturn(ResultSets.forRows(FAKE_TYPE, FAKE_ROWS));

        List<Struct> rows = fnTester.processBundle(read.getReadOperation());
        Assert.assertThat(rows, org.hamcrest.Matchers.containsInAnyOrder(FAKE_ROWS.toArray()));

        // A strong bound is expected even though the read was given a timestamp; presumably
        // because the timestamp is not part of the config passed to the function - confirm.
        Mockito.verify(serviceFactory.mockDatabaseClient())
                .readOnlyTransaction(TimestampBound.strong());
        Mockito.verify(mockTx)
                .readUsingIndex("users", "theindex", KeySet.all(), Arrays.asList("id", "name"));
    }

    /**
     * Two reads (one query, one table read) that share a single transaction view both produce
     * the stubbed rows, and the client is asked twice for a read-only transaction bound to the
     * timestamp reported by the shared transaction.
     */
    @Test
    @Category(NeedsRunner.class)
    public void readPipeline() throws Exception {
        Timestamp readTimestamp = Timestamp.ofTimeMicroseconds(12345L);
        SpannerConfig spannerConfig = fakeSpannerConfig();

        // NOTE(review): left raw because the view's element type is not named anywhere in
        // this file; parameterize it once the return type of createTransaction() is confirmed.
        PCollectionView txView = pipeline.apply(
                "tx", SpannerIO.createTransaction().withSpannerConfig(spannerConfig));
        PCollection<Struct> queryRows = pipeline.apply(
                "read q",
                SpannerIO.read()
                        .withSpannerConfig(spannerConfig)
                        .withQuery("SELECT * FROM users")
                        .withTransaction(txView));
        PCollection<Struct> tableRows = pipeline.apply(
                "read r",
                SpannerIO.read()
                        .withSpannerConfig(spannerConfig)
                        .withTimestamp(Timestamp.now())
                        .withTable("users")
                        .withColumns("id", "name")
                        .withTransaction(txView));

        stubTransactionAt(readTimestamp);
        Mockito.when(mockTx.executeQuery(Statement.of("SELECT * FROM users")))
                .thenReturn(ResultSets.forRows(FAKE_TYPE, FAKE_ROWS));
        Mockito.when(mockTx.read("users", KeySet.all(), Arrays.asList("id", "name")))
                .thenReturn(ResultSets.forRows(FAKE_TYPE, FAKE_ROWS));

        PAssert.that(queryRows).containsInAnyOrder(FAKE_ROWS);
        PAssert.that(tableRows).containsInAnyOrder(FAKE_ROWS);
        pipeline.run();

        Mockito.verify(serviceFactory.mockDatabaseClient(), Mockito.times(2))
                .readOnlyTransaction(TimestampBound.ofReadTimestamp(readTimestamp));
    }

    /**
     * {@code SpannerIO.readAll()} applied to one query operation and one table operation emits
     * the union of both stubbed result sets (two rows each), again using two read-only
     * transactions bound to the shared transaction's read timestamp.
     */
    @Test
    @Category(NeedsRunner.class)
    public void readAllPipeline() throws Exception {
        Timestamp readTimestamp = Timestamp.ofTimeMicroseconds(12345L);
        SpannerConfig spannerConfig = fakeSpannerConfig();

        PCollection<ReadOperation> operations = pipeline.apply(
                Create.of(
                        ReadOperation.create().withQuery("SELECT * FROM users"),
                        ReadOperation.create().withTable("users").withColumns("id", "name")));
        PCollection<Struct> rows = operations.apply(
                "read all",
                SpannerIO.readAll()
                        .withSpannerConfig(spannerConfig)
                        .withTransaction(pipeline.apply(
                                "tx",
                                SpannerIO.createTransaction().withSpannerConfig(spannerConfig))));

        stubTransactionAt(readTimestamp);
        // The query serves the first half of the rows and the table read the second half, so
        // the assertion below only holds if both operations were executed.
        Mockito.when(mockTx.executeQuery(Statement.of("SELECT * FROM users")))
                .thenReturn(ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(0, 2)));
        Mockito.when(mockTx.read("users", KeySet.all(), Arrays.asList("id", "name")))
                .thenReturn(ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(2, 4)));

        PAssert.that(rows).containsInAnyOrder(FAKE_ROWS);
        pipeline.run();

        Mockito.verify(serviceFactory.mockDatabaseClient(), Mockito.times(2))
                .readOnlyTransaction(TimestampBound.ofReadTimestamp(readTimestamp));
    }
}
