package org.apache.beam.sdk.io.singlestore;

import com.google.cloud.Timestamp;
import com.singlestore.jdbc.SingleStoreDataSource;
import java.util.HashSet;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.io.common.DatabaseTestHelper;
import org.apache.beam.sdk.io.common.HashingFn;
import org.apache.beam.sdk.io.common.IOITHelper;
import org.apache.beam.sdk.io.common.TestRow;
import org.apache.beam.sdk.io.singlestore.SingleStoreIO;
import org.apache.beam.sdk.io.singlestore.TestHelper;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testutils.NamedTestResult;
import org.apache.beam.sdk.testutils.metrics.IOITMetrics;
import org.apache.beam.sdk.testutils.metrics.MetricsReader;
import org.apache.beam.sdk.testutils.metrics.TimeMonitor;
import org.apache.beam.sdk.testutils.publishing.InfluxDBSettings;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.Top;
import org.apache.beam.sdk.values.PCollection;
import org.junit.Assert;
import org.junit.Assume;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/**
 * Performance integration test for {@link SingleStoreIO}.
 *
 * <p>Writes {@code numberOfRecords} deterministically generated {@link TestRow}s into a fresh
 * table, reads them back with both {@code read()} and {@code readWithPartitions()}, verifies the
 * read results, and publishes the timing of each of the three pipelines to InfluxDB. The whole
 * class is skipped (via {@link Assume}) when the SingleStore IT pipeline options cannot be read.
 */
@RunWith(JUnit4.class)
public class SingleStoreIOPerformanceIT {
    private static final String NAMESPACE = SingleStoreIOPerformanceIT.class.getName();
    private static final String DATABASE_NAME = "SingleStoreIOIT";

    /** Number of rows checked at each end of the id range by the Top smallest/largest asserts. */
    private static final int EDGE_SAMPLE_SIZE = 500;

    private static int numberOfRows;
    private static String tableName;
    private static String serverName;
    private static String username;
    private static String password;
    private static Integer port;
    private static SingleStoreIO.DataSourceConfiguration dataSourceConfiguration;
    private static InfluxDBSettings settings;

    @Rule
    public TestPipeline pipelineWrite = TestPipeline.create();

    @Rule
    public TestPipeline pipelineRead = TestPipeline.create();

    @Rule
    public TestPipeline pipelineReadWithPartitions = TestPipeline.create();

    /**
     * Reads the IT pipeline options and derives the connection, table and InfluxDB settings.
     *
     * <p>If the options cannot be parsed ({@link IllegalArgumentException}), the test class is
     * skipped rather than failed.
     */
    @BeforeClass
    public static void setup() {
        SingleStoreIOTestPipelineOptions options;
        try {
            options = IOITHelper.readIOTestPipelineOptions(SingleStoreIOTestPipelineOptions.class);
        } catch (IllegalArgumentException e) {
            // Missing/invalid options: treat as "environment not configured" and skip below.
            options = null;
        }
        Assume.assumeNotNull(options);

        numberOfRows = options.getNumberOfRecords().intValue();
        serverName = options.getSingleStoreServerName();
        username = options.getSingleStoreUsername();
        password = options.getSingleStorePassword();
        port = options.getSingleStorePort();
        tableName = DatabaseTestHelper.getTestTableName("IT");
        dataSourceConfiguration =
                SingleStoreIO.DataSourceConfiguration.create(serverName + ":" + port)
                        .withDatabase(DATABASE_NAME)
                        .withPassword(password)
                        .withUsername(username);
        settings =
                InfluxDBSettings.builder()
                        .withHost(options.getInfluxHost())
                        .withDatabase(options.getInfluxDatabase())
                        .withMeasurement(options.getInfluxMeasurement())
                        .get();
    }

    /**
     * Creates the test table, runs the write, read and partitioned-read pipelines in order,
     * publishes their timings, and always drops the table afterwards.
     */
    @Test
    @Category({NeedsRunner.class})
    public void testWriteThenRead() throws Exception {
        TestHelper.createDatabaseIfNotExists(serverName, port, username, password, DATABASE_NAME);
        // NOTE(review): user/password are interpolated into the URL unescaped; values containing
        // '&', '=' or '?' would corrupt the query string — confirm how the driver handles this.
        SingleStoreDataSource dataSource =
                new SingleStoreDataSource(
                        String.format(
                                "jdbc:singlestore://%s:%d/%s?user=%s&password=%s&allowLocalInfile=TRUE",
                                serverName, port, DATABASE_NAME, username, password));
        DatabaseTestHelper.createTable(dataSource, tableName);
        try {
            PipelineResult writeResult = runWrite();
            Assert.assertEquals(PipelineResult.State.DONE, writeResult.waitUntilFinish());
            PipelineResult readResult = runRead();
            Assert.assertEquals(PipelineResult.State.DONE, readResult.waitUntilFinish());
            PipelineResult readWithPartitionsResult = runReadWithPartitions();
            Assert.assertEquals(
                    PipelineResult.State.DONE, readWithPartitionsResult.waitUntilFinish());
            gatherAndPublishMetrics(writeResult, readResult, readWithPartitionsResult);
        } finally {
            // Single cleanup path for both success and failure.
            DatabaseTestHelper.deleteTable(dataSource, tableName);
        }
    }

    /**
     * Publishes the elapsed time of each pipeline to InfluxDB under a shared run id and timestamp.
     */
    private void gatherAndPublishMetrics(
            PipelineResult writeResult,
            PipelineResult readResult,
            PipelineResult readWithPartitionsResult) {
        String uuid = UUID.randomUUID().toString();
        String timestamp = Timestamp.now().toString();
        new IOITMetrics(
                        getMetricSuppliers(uuid, timestamp, "write_time"),
                        writeResult,
                        NAMESPACE,
                        uuid,
                        timestamp)
                .publishToInflux(settings);
        new IOITMetrics(
                        getMetricSuppliers(uuid, timestamp, "read_time"),
                        readResult,
                        NAMESPACE,
                        uuid,
                        timestamp)
                .publishToInflux(settings);
        new IOITMetrics(
                        getMetricSuppliers(uuid, timestamp, "read_with_partitions_time"),
                        readWithPartitionsResult,
                        NAMESPACE,
                        uuid,
                        timestamp)
                .publishToInflux(settings);
    }

    /**
     * Builds a single metric supplier reporting {@code (end - start) / 1000.0} for the
     * {@link TimeMonitor} metric named {@code metricName}.
     *
     * @param uuid run identifier attached to the result
     * @param timestamp run timestamp attached to the result
     * @param metricName name of the TimeMonitor metric (also used as the result name)
     * @return a mutable set containing one supplier
     */
    private Set<Function<MetricsReader, NamedTestResult>> getMetricSuppliers(
            String uuid, String timestamp, String metricName) {
        Set<Function<MetricsReader, NamedTestResult>> suppliers = new HashSet<>();
        suppliers.add(
                reader -> {
                    long startTime = reader.getStartTimeMetric(metricName);
                    long endTime = reader.getEndTimeMetric(metricName);
                    // TimeMonitor metrics are presumably in milliseconds; divided by 1000.0 here.
                    return NamedTestResult.create(
                            uuid, timestamp, metricName, (endTime - startTime) / 1000.0d);
                });
        return suppliers;
    }

    /** Writes {@code numberOfRows} rows and asserts that the sink reports that many written. */
    private PipelineResult runWrite() {
        PCollection<Integer> writtenCounts =
                pipelineWrite
                        .apply(GenerateSequence.from(0L).to(numberOfRows))
                        .apply(ParDo.of(new TestRow.DeterministicallyConstructTestRowFn()))
                        .apply(ParDo.of(new TimeMonitor<>(NAMESPACE, "write_time")))
                        .apply(
                                SingleStoreIO.<TestRow>write()
                                        .withDataSourceConfiguration(dataSourceConfiguration)
                                        .withTable(tableName)
                                        .withUserDataMapper(new TestHelper.TestUserDataMapper()));
        PAssert.thatSingleton(writtenCounts.apply("Sum All", Sum.integersGlobally()))
                .isEqualTo(numberOfRows);
        return pipelineWrite.run();
    }

    /** Reads the table with {@code SingleStoreIO.read()} and verifies the rows. */
    private PipelineResult runRead() {
        PCollection<TestRow> rows =
                pipelineRead
                        .apply(
                                SingleStoreIO.<TestRow>read()
                                        .withDataSourceConfiguration(dataSourceConfiguration)
                                        .withTable(tableName)
                                        .withRowMapper(new TestHelper.TestRowMapper()))
                        .apply(ParDo.of(new TimeMonitor<>(NAMESPACE, "read_time")));
        testReadResult(rows);
        return pipelineRead.run();
    }

    /** Reads the table with {@code SingleStoreIO.readWithPartitions()} and verifies the rows. */
    private PipelineResult runReadWithPartitions() {
        PCollection<TestRow> rows =
                pipelineReadWithPartitions
                        .apply(
                                SingleStoreIO.<TestRow>readWithPartitions()
                                        .withDataSourceConfiguration(dataSourceConfiguration)
                                        .withTable(tableName)
                                        .withRowMapper(new TestHelper.TestRowMapper()))
                        .apply(ParDo.of(new TimeMonitor<>(NAMESPACE, "read_with_partitions_time")));
        testReadResult(rows);
        return pipelineReadWithPartitions.run();
    }

    /**
     * Asserts the row count, the combined hash of row names, and the expected rows at both ends
     * of the id range. Assumes {@code numberOfRows > 0}.
     */
    private void testReadResult(PCollection<TestRow> rows) {
        PAssert.thatSingleton(rows.apply("Count All", Count.globally()))
                .isEqualTo((long) numberOfRows);

        PAssert.that(
                        rows.apply(ParDo.of(new TestRow.SelectNameFn()))
                                .apply(
                                        "Hash row contents",
                                        Combine.globally(new HashingFn()).withoutDefaults()))
                .containsInAnyOrder(TestRow.getExpectedHashForRowCount(numberOfRows));

        // Never sample more rows than were written; otherwise the expected "largest" range
        // would start at a negative id and the assertion could never pass.
        int sampleSize = Math.min(EDGE_SAMPLE_SIZE, numberOfRows);
        PAssert.thatSingletonIterable(rows.apply(Top.smallest(sampleSize)))
                .containsInAnyOrder(TestRow.getExpectedValues(0, sampleSize));
        PAssert.thatSingletonIterable(rows.apply(Top.largest(sampleSize)))
                .containsInAnyOrder(
                        TestRow.getExpectedValues(numberOfRows - sampleSize, numberOfRows));
    }
}
