package org.apache.beam.sdk.io.snowflake.services;

import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Locale;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.apache.beam.sdk.io.snowflake.data.SnowflakeTableSchema;
import org.apache.beam.sdk.io.snowflake.enums.CreateDisposition;
import org.apache.beam.sdk.io.snowflake.enums.WriteDisposition;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;

/* loaded from: input_file:org/apache/beam/sdk/io/snowflake/services/SnowflakeBatchServiceImpl.class */
/**
 * Batch implementation of {@link SnowflakeService}: builds Snowflake {@code COPY INTO} statements
 * from a {@link SnowflakeBatchServiceConfig} and executes them over JDBC.
 *
 * <p>{@link #read} issues a {@code COPY INTO '<stage dir>' FROM <table|query>} statement;
 * {@link #write} issues a {@code COPY INTO <table> FROM <stage dir|query>} statement after applying
 * the configured create and write dispositions.
 *
 * <p>NOTE(review): table names, queries and integration names are interpolated into the SQL text
 * with {@code String.format}; callers are presumably expected to pass trusted values - confirm.
 */
public class SnowflakeBatchServiceImpl implements SnowflakeService<SnowflakeBatchServiceConfig> {
    private static final String SNOWFLAKE_GCS_PREFIX = "gcs://";
    private static final String GCS_PREFIX = "gs://";

    /**
     * Loads the staged files listed in the config into the target table.
     *
     * @param snowflakeBatchServiceConfig configuration describing the target table and staged files
     * @throws Exception if any of the underlying statements fails
     */
    @Override // org.apache.beam.sdk.io.snowflake.services.SnowflakeService
    public void write(SnowflakeBatchServiceConfig snowflakeBatchServiceConfig) throws Exception {
        copyToTable(snowflakeBatchServiceConfig);
    }

    /**
     * Unloads a table or query result into the staging bucket directory.
     *
     * @param snowflakeBatchServiceConfig configuration describing the source and staging directory
     * @return the staging bucket directory with {@code *} appended
     * @throws Exception if the COPY statement fails
     */
    @Override // org.apache.beam.sdk.io.snowflake.services.SnowflakeService
    public String read(SnowflakeBatchServiceConfig snowflakeBatchServiceConfig) throws Exception {
        return copyIntoStage(snowflakeBatchServiceConfig);
    }

    /**
     * Builds and runs the {@code COPY INTO '<stage>' FROM ...} statement.
     *
     * <p>The source is the parenthesised query when one is configured, otherwise the fully
     * qualified {@code database.schema.table} path.
     */
    private String copyIntoStage(SnowflakeBatchServiceConfig snowflakeBatchServiceConfig) throws SQLException {
        SerializableFunction<Void, DataSource> dataSourceProviderFn = snowflakeBatchServiceConfig.getDataSourceProviderFn();
        String database = snowflakeBatchServiceConfig.getDatabase();
        String schema = snowflakeBatchServiceConfig.getSchema();
        String table = snowflakeBatchServiceConfig.getTable();
        String query = snowflakeBatchServiceConfig.getQuery();
        String storageIntegrationName = snowflakeBatchServiceConfig.getStorageIntegrationName();
        String stagingBucketDir = snowflakeBatchServiceConfig.getStagingBucketDir();

        String source = query != null ? String.format("(%s)", query) : getTablePath(database, schema, table);
        String copyQuery =
                String.format(
                        "COPY INTO '%s' FROM %s STORAGE_INTEGRATION=%s FILE_FORMAT=(TYPE=CSV COMPRESSION=GZIP FIELD_OPTIONALLY_ENCLOSED_BY='%s');",
                        getProperBucketDir(stagingBucketDir),
                        source,
                        storageIntegrationName,
                        getASCIICharRepresentation(snowflakeBatchServiceConfig.getQuotationMark()));

        // runStatement takes ownership of the connection and closes it.
        runStatement(copyQuery, getConnection(dataSourceProviderFn), null);
        return stagingBucketDir.concat("*");
    }

    /** Renders the UTF-8 bytes of {@code str} as a single {@code 0x...} hexadecimal literal. */
    private String getASCIICharRepresentation(String str) {
        return String.format("0x%x", new BigInteger(1, str.getBytes(StandardCharsets.UTF_8)));
    }

    /**
     * Applies the create/write dispositions and then runs the {@code COPY INTO <table>} statement
     * for the configured list of staged files.
     */
    private void copyToTable(SnowflakeBatchServiceConfig snowflakeBatchServiceConfig) throws SQLException {
        SerializableFunction<Void, DataSource> dataSourceProviderFn = snowflakeBatchServiceConfig.getDataSourceProviderFn();
        List<String> filesList = snowflakeBatchServiceConfig.getFilesList();
        String database = snowflakeBatchServiceConfig.getDatabase();
        String schema = snowflakeBatchServiceConfig.getSchema();
        String table = snowflakeBatchServiceConfig.getTable();
        String query = snowflakeBatchServiceConfig.getQuery();
        SnowflakeTableSchema tableSchema = snowflakeBatchServiceConfig.getTableSchema();
        CreateDisposition createDisposition = snowflakeBatchServiceConfig.getCreateDisposition();
        WriteDisposition writeDisposition = snowflakeBatchServiceConfig.getWriteDisposition();
        String storageIntegrationName = snowflakeBatchServiceConfig.getStorageIntegrationName();
        String stagingBucketDir = snowflakeBatchServiceConfig.getStagingBucketDir();

        String source = query != null ? String.format("(%s)", query) : String.format("'%s'", stagingBucketDir);

        // Quote every file, join them, and strip the staging directory so the names are relative
        // to the COPY source. The directory is removed literally (String.replace): it is a path,
        // not a regular expression, and may contain regex metacharacters such as '.' or '+'.
        String files =
                filesList.stream()
                        .map(file -> String.format("'%s'", file))
                        .collect(Collectors.joining(", "))
                        .replace(stagingBucketDir, "");

        DataSource dataSource = dataSourceProviderFn.apply(null);
        prepareTableAccordingCreateDisposition(dataSource, table, tableSchema, createDisposition);
        prepareTableAccordingWriteDisposition(dataSource, table, writeDisposition);

        String quotationMark = getASCIICharRepresentation(snowflakeBatchServiceConfig.getQuotationMark());
        String copyQuery;
        if (!storageIntegrationName.isEmpty()) {
            copyQuery =
                    String.format(
                            "COPY INTO %s FROM %s FILES=(%s) FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' COMPRESSION=GZIP) STORAGE_INTEGRATION=%s;",
                            getTablePath(database, schema, table),
                            getProperBucketDir(source),
                            files,
                            quotationMark,
                            storageIntegrationName);
        } else {
            // Without a storage integration the bare table name and the untranslated source are used.
            copyQuery =
                    String.format(
                            "COPY INTO %s FROM %s FILES=(%s) FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' COMPRESSION=GZIP);",
                            table,
                            source,
                            files,
                            quotationMark);
        }

        // runStatement takes ownership of the connection and closes it.
        runStatement(copyQuery, dataSource.getConnection(), null);
    }

    /** Runs {@code TRUNCATE <table>;} on a fresh connection. */
    private void truncateTable(DataSource dataSource, String str) throws SQLException {
        runConnectionWithStatement(dataSource, String.format("TRUNCATE %s;", str), null);
    }

    /**
     * Counts the rows of the given table and fails when the table is not empty.
     *
     * @throws RuntimeException if the table has rows or the result cannot be read
     */
    private static void checkIfTableIsEmpty(DataSource dataSource, String str) throws SQLException {
        runConnectionWithStatement(
                dataSource,
                String.format("SELECT count(*) FROM %s LIMIT 1;", str),
                resultSet -> {
                    assert resultSet != null;
                    checkIfTableIsEmpty(resultSet);
                });
    }

    /** Throws unless the first row of {@code resultSet} exists and reports a count below one. */
    private static void checkIfTableIsEmpty(ResultSet resultSet) {
        try {
            if (!(resultSet.next() && checkIfTableIsEmpty(resultSet, 1))) {
                throw new RuntimeException("Table is not empty. Aborting COPY with disposition EMPTY");
            }
        } catch (SQLException e) {
            throw new RuntimeException("Unable run pipeline with EMPTY disposition.", e);
        }
    }

    /** Returns {@code true} when integer column {@code i} of the current row is less than one. */
    private static boolean checkIfTableIsEmpty(ResultSet resultSet, int i) throws SQLException {
        return resultSet.getInt(i) < 1;
    }

    /** Creates the table when the disposition is {@code CREATE_IF_NEEDED}; otherwise does nothing. */
    private void prepareTableAccordingCreateDisposition(DataSource dataSource, String str, SnowflakeTableSchema snowflakeTableSchema, CreateDisposition createDisposition) throws SQLException {
        switch (createDisposition) {
            case CREATE_IF_NEEDED:
                createTableIfNotExists(dataSource, str, snowflakeTableSchema);
                break;
            case CREATE_NEVER:
            default:
                break;
        }
    }

    /**
     * Truncates the table for {@code TRUNCATE}, verifies emptiness for {@code EMPTY}, and does
     * nothing for {@code APPEND} or any other value.
     */
    private void prepareTableAccordingWriteDisposition(DataSource dataSource, String str, WriteDisposition writeDisposition) throws SQLException {
        switch (writeDisposition) {
            case TRUNCATE:
                truncateTable(dataSource, str);
                break;
            case EMPTY:
                checkIfTableIsEmpty(dataSource, str);
                break;
            case APPEND:
            default:
                break;
        }
    }

    /**
     * Looks the table up in {@code information_schema.tables} and creates it when it is missing.
     *
     * <p>The name is upper-cased with {@link Locale#ROOT} so the lookup does not depend on the
     * JVM's default locale.
     */
    private void createTableIfNotExists(DataSource dataSource, String str, SnowflakeTableSchema snowflakeTableSchema) throws SQLException {
        String existsQuery =
                String.format(
                        "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = '%s');",
                        str.toUpperCase(Locale.ROOT));
        runConnectionWithStatement(
                dataSource,
                existsQuery,
                resultSet -> {
                    assert resultSet != null;
                    if (checkResultIfTableExists(resultSet)) {
                        return;
                    }
                    try {
                        createTable(dataSource, str, snowflakeTableSchema);
                    } catch (SQLException e) {
                        // Consumer cannot throw checked exceptions, so wrap and keep the cause.
                        throw new RuntimeException("Unable to create table.", e);
                    }
                });
    }

    /**
     * Reads the boolean answer of the existence query.
     *
     * @throws RuntimeException if the result set has no row or cannot be read
     */
    private static boolean checkResultIfTableExists(ResultSet resultSet) {
        try {
            if (resultSet.next()) {
                return checkIfResultIsTrue(resultSet);
            }
            throw new RuntimeException("Unable run pipeline with CREATE IF NEEDED - no response.");
        } catch (SQLException e) {
            throw new RuntimeException("Unable run pipeline with CREATE IF NEEDED disposition.", e);
        }
    }

    /**
     * Runs {@code CREATE TABLE <name> (<schema sql>);}.
     *
     * @throws IllegalArgumentException if no table schema was supplied
     */
    private void createTable(DataSource dataSource, String str, SnowflakeTableSchema snowflakeTableSchema) throws SQLException {
        Preconditions.checkArgument(snowflakeTableSchema != null, "The CREATE_IF_NEEDED disposition requires schema if table doesn't exists");
        runConnectionWithStatement(dataSource, String.format("CREATE TABLE %s (%s);", str, snowflakeTableSchema.sql()), null);
    }

    /** Returns the first column of the current row as a boolean. */
    private static boolean checkIfResultIsTrue(ResultSet resultSet) throws SQLException {
        return resultSet.getBoolean(1);
    }

    /**
     * Obtains a connection from {@code dataSource} and runs the statement on it.
     *
     * <p>The connection is handed over to {@link #runStatement}, which closes it.
     */
    private static void runConnectionWithStatement(DataSource dataSource, String str, Consumer<ResultSet> consumer) throws SQLException {
        runStatement(str, dataSource.getConnection(), consumer);
    }

    /**
     * Prepares and executes {@code str} on {@code connection}, then closes the statement and the
     * connection.
     *
     * <p>When {@code consumer} is non-null the statement is executed as a query and its result set
     * is passed to the consumer (and closed afterwards); otherwise the statement is simply
     * executed. try-with-resources guarantees that the connection is closed even if preparing the
     * statement or closing another resource fails.
     *
     * @param str SQL text to execute
     * @param connection connection to use; always closed by this method
     * @param consumer optional handler for the query result, may be {@code null}
     */
    private static void runStatement(String str, Connection connection, Consumer<ResultSet> consumer) throws SQLException {
        try (Connection ownedConnection = connection;
                PreparedStatement statement = ownedConnection.prepareStatement(str)) {
            if (consumer != null) {
                try (ResultSet resultSet = statement.executeQuery()) {
                    consumer.accept(resultSet);
                }
            } else {
                statement.execute();
            }
        }
    }

    /** Applies the provider function to {@code null} and opens a connection on the result. */
    private Connection getConnection(SerializableFunction<Void, DataSource> serializableFunction) throws SQLException {
        return serializableFunction.apply(null).getConnection();
    }

    /** Rewrites every {@code gs://} occurrence to the {@code gcs://} form used in the COPY statements. */
    private String getProperBucketDir(String str) {
        // String.replace returns the input unchanged when the prefix is absent.
        return str.replace(GCS_PREFIX, SNOWFLAKE_GCS_PREFIX);
    }

    /** Joins database, schema and table into {@code database.schema.table}. */
    private String getTablePath(String str, String str2, String str3) {
        return String.format("%s.%s.%s", str, str2, str3);
    }
}
