package org.apache.beam.sdk.io.jdbc;

import com.google.cloud.Timestamp;
import java.sql.SQLException;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.io.common.DatabaseTestHelper;
import org.apache.beam.sdk.io.common.HashingFn;
import org.apache.beam.sdk.io.common.IOITHelper;
import org.apache.beam.sdk.io.common.PostgresIOTestPipelineOptions;
import org.apache.beam.sdk.io.common.TestRow;
import org.apache.beam.sdk.io.jdbc.JdbcIO;
import org.apache.beam.sdk.io.jdbc.JdbcTestHelper;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testutils.NamedTestResult;
import org.apache.beam.sdk.testutils.metrics.IOITMetrics;
import org.apache.beam.sdk.testutils.metrics.MetricsReader;
import org.apache.beam.sdk.testutils.metrics.TimeMonitor;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Top;
import org.apache.beam.sdk.values.PCollection;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.postgresql.ds.PGSimpleDataSource;

/**
 * Integration test that writes rows to a Postgres table through {@link JdbcIO}, reads them back,
 * and verifies the result.
 *
 * <p>The flow, as implemented below:
 *
 * <ol>
 *   <li>{@link #setup()} reads {@link PostgresIOTestPipelineOptions}, builds the data source,
 *       creates a uniquely named test table, and records the table's initial size.
 *   <li>{@link #testWriteThenRead()} runs the write pipeline, then the read pipeline, and finally
 *       publishes timing/size metrics through {@link IOITMetrics}.
 *   <li>{@link #tearDown()} drops the test table.
 * </ol>
 *
 * <p>NOTE(review): the read assertions compare the first and last {@value #SAMPLE_SIZE} rows, so
 * the configured number of records is presumably expected to be at least that large - confirm
 * against the test configuration.
 */
@RunWith(JUnit4.class)
public class JdbcIOIT {
    /** Metrics namespace shared by the time monitors and the published results. */
    private static final String NAMESPACE = JdbcIOIT.class.getName();

    /** Metric name recorded by the time monitor in the write pipeline. */
    private static final String WRITE_TIME_METRIC = "write_time";

    /** Metric name recorded by the time monitor in the read pipeline. */
    private static final String READ_TIME_METRIC = "read_time";

    /** Metric name under which the growth of the table size is reported. */
    private static final String TOTAL_SIZE_METRIC = "total_size";

    /** Number of rows compared at each end of the sorted read result. */
    private static final int SAMPLE_SIZE = 500;

    private static int numberOfRows;
    private static PGSimpleDataSource dataSource;
    private static String tableName;
    private static String bigQueryDataset;
    private static String bigQueryTable;

    /** Table size captured right after table creation; 0 when the size could not be obtained. */
    private static Long tableSize;

    @Rule
    public TestPipeline pipelineWrite = TestPipeline.create();

    @Rule
    public TestPipeline pipelineRead = TestPipeline.create();

    /**
     * Reads the pipeline options, prepares the data source and creates the test table.
     *
     * @throws Exception if the table cannot be created, even after retries
     */
    @BeforeClass
    public static void setup() throws Exception {
        PostgresIOTestPipelineOptions options =
                IOITHelper.readIOTestPipelineOptions(PostgresIOTestPipelineOptions.class);

        bigQueryDataset = options.getBigQueryDataset();
        bigQueryTable = options.getBigQueryTable();
        numberOfRows = options.getNumberOfRecords().intValue();
        dataSource = DatabaseTestHelper.getPostgresDataSource(options);
        tableName = DatabaseTestHelper.getTestTableName("IT");

        IOITHelper.executeWithRetry(JdbcIOIT::createTable);

        // Baseline used later to report how much the table grew during the write.
        tableSize = DatabaseTestHelper.getPostgresTableSize(dataSource, tableName).orElse(0L);
    }

    /** Creates the test table; invoked through the retry helper in {@link #setup()}. */
    private static void createTable() throws SQLException {
        DatabaseTestHelper.createTable(dataSource, tableName);
    }

    /**
     * Drops the test table.
     *
     * @throws Exception if the table cannot be deleted, even after retries
     */
    @AfterClass
    public static void tearDown() throws Exception {
        IOITHelper.executeWithRetry(JdbcIOIT::deleteTable);
    }

    /** Deletes the test table; invoked through the retry helper in {@link #tearDown()}. */
    private static void deleteTable() throws SQLException {
        DatabaseTestHelper.deleteTable(dataSource, tableName);
    }

    /** Runs the write pipeline, then the read pipeline, then publishes metrics for both. */
    @Test
    public void testWriteThenRead() {
        PipelineResult writeResult = runWrite();
        writeResult.waitUntilFinish();

        PipelineResult readResult = runRead();
        readResult.waitUntilFinish();

        gatherAndPublishMetrics(writeResult, readResult);
    }

    /**
     * Publishes the write and read metrics under a shared, freshly generated test id and timestamp.
     *
     * @param writeResult result of the finished write pipeline
     * @param readResult result of the finished read pipeline
     */
    private void gatherAndPublishMetrics(PipelineResult writeResult, PipelineResult readResult) {
        String uuid = UUID.randomUUID().toString();
        String timestamp = Timestamp.now().toString();

        IOITMetrics writeMetrics =
                new IOITMetrics(
                        getWriteMetricSuppliers(uuid, timestamp), writeResult, NAMESPACE, uuid, timestamp);
        writeMetrics.publish(bigQueryDataset, bigQueryTable);

        IOITMetrics readMetrics =
                new IOITMetrics(
                        getReadMetricSuppliers(uuid, timestamp), readResult, NAMESPACE, uuid, timestamp);
        readMetrics.publish(bigQueryDataset, bigQueryTable);
    }

    /**
     * Builds the metric suppliers for the write pipeline: the elapsed write time and, when the
     * current table size is available, the table growth since {@link #setup()}.
     *
     * @param uuid test run identifier attached to every result
     * @param timestamp timestamp attached to every result
     * @return suppliers producing the named results for the write run
     */
    private Set<Function<MetricsReader, NamedTestResult>> getWriteMetricSuppliers(
            String uuid, String timestamp) {
        Set<Function<MetricsReader, NamedTestResult>> suppliers = new HashSet<>();
        Optional<Long> postgresTableSize =
                DatabaseTestHelper.getPostgresTableSize(dataSource, tableName);

        // Difference of end/start metric divided by 1000 (presumably millis -> seconds; confirm
        // against TimeMonitor's unit).
        suppliers.add(
                reader ->
                        NamedTestResult.create(
                                uuid,
                                timestamp,
                                WRITE_TIME_METRIC,
                                (reader.getEndTimeMetric(WRITE_TIME_METRIC)
                                                - reader.getStartTimeMetric(WRITE_TIME_METRIC))
                                        / 1000.0d));

        // The size metric is only reported when the database returned a size for the table.
        postgresTableSize.ifPresent(
                tableFinalSize ->
                        suppliers.add(
                                ignored ->
                                        NamedTestResult.create(
                                                uuid,
                                                timestamp,
                                                TOTAL_SIZE_METRIC,
                                                tableFinalSize.longValue() - tableSize.longValue())));

        return suppliers;
    }

    /**
     * Builds the metric suppliers for the read pipeline: the elapsed read time.
     *
     * @param uuid test run identifier attached to every result
     * @param timestamp timestamp attached to every result
     * @return suppliers producing the named results for the read run
     */
    private Set<Function<MetricsReader, NamedTestResult>> getReadMetricSuppliers(
            String uuid, String timestamp) {
        Set<Function<MetricsReader, NamedTestResult>> suppliers = new HashSet<>();

        suppliers.add(
                reader ->
                        NamedTestResult.create(
                                uuid,
                                timestamp,
                                READ_TIME_METRIC,
                                (reader.getEndTimeMetric(READ_TIME_METRIC)
                                                - reader.getStartTimeMetric(READ_TIME_METRIC))
                                        / 1000.0d));

        return suppliers;
    }

    /**
     * Assembles and starts the write pipeline: a generated sequence of {@code numberOfRows}
     * elements is converted to {@link TestRow}s, passed through a time monitor and inserted into
     * the test table.
     *
     * @return the result handle of the started pipeline
     */
    private PipelineResult runWrite() {
        this.pipelineWrite
                .apply(GenerateSequence.from(0L).to(numberOfRows))
                .apply(ParDo.of(new TestRow.DeterministicallyConstructTestRowFn()))
                .apply(ParDo.of(new TimeMonitor<TestRow>(NAMESPACE, WRITE_TIME_METRIC)))
                .apply(
                        JdbcIO.<TestRow>write()
                                .withDataSourceConfiguration(
                                        JdbcIO.DataSourceConfiguration.create(dataSource))
                                .withStatement(String.format("insert into %s values(?, ?)", tableName))
                                .withPreparedStatementSetter(
                                        new JdbcTestHelper.PrepareStatementFromTestRow()));

        return this.pipelineWrite.run();
    }

    /**
     * Assembles and starts the read pipeline and registers its assertions: total row count, a
     * combined hash of the row names, and the smallest/largest {@value #SAMPLE_SIZE} rows.
     *
     * @return the result handle of the started pipeline
     */
    private PipelineResult runRead() {
        PCollection<TestRow> namesAndIds =
                this.pipelineRead
                        .apply(
                                JdbcIO.<TestRow>read()
                                        .withDataSourceConfiguration(
                                                JdbcIO.DataSourceConfiguration.create(dataSource))
                                        .withQuery(String.format("select name,id from %s;", tableName))
                                        .withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId())
                                        .withCoder(SerializableCoder.of(TestRow.class)))
                        .apply(ParDo.of(new TimeMonitor<TestRow>(NAMESPACE, READ_TIME_METRIC)));

        // Every written row must come back.
        PAssert.thatSingleton(namesAndIds.apply("Count All", Count.<TestRow>globally()))
                .isEqualTo((long) numberOfRows);

        // The combined hash of all names must match the precomputed hash for this row count.
        PCollection<String> consolidatedHashcode =
                namesAndIds
                        .apply(ParDo.of(new TestRow.SelectNameFn()))
                        .apply("Hash row contents", Combine.globally(new HashingFn()).withoutDefaults());
        PAssert.that(consolidatedHashcode)
                .containsInAnyOrder(TestRow.getExpectedHashForRowCount(numberOfRows));

        // Spot-check both ends of the ordered result against the expected rows.
        PAssert.thatSingletonIterable(namesAndIds.apply(Top.<TestRow>smallest(SAMPLE_SIZE)))
                .containsInAnyOrder(TestRow.getExpectedValues(0, SAMPLE_SIZE));
        PAssert.thatSingletonIterable(namesAndIds.apply(Top.<TestRow>largest(SAMPLE_SIZE)))
                .containsInAnyOrder(
                        TestRow.getExpectedValues(numberOfRows - SAMPLE_SIZE, numberOfRows));

        return this.pipelineRead.run();
    }
}
