package org.apache.beam.sdk.io.snowflake.services;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.apache.beam.sdk.io.snowflake.enums.CloudProvider;
import org.apache.beam.sdk.io.snowflake.enums.WriteDisposition;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * JDBC-backed implementation of {@link SnowflakeService} that issues Snowflake {@code COPY INTO}
 * statements to unload gzip-compressed CSV files into a staging location ({@link #read}) and to
 * load staged files into a table ({@link #write}).
 *
 * <p>Every statement runs on a freshly obtained {@link Connection}, which is closed (together with
 * its statement and result set) as soon as the statement completes, whether or not it succeeds.
 *
 * <p>NOTE(review): table names, queries, staging paths and storage-integration names are
 * interpolated verbatim into the SQL text. They come from {@link SnowflakeServiceConfig}, and since
 * they are identifiers / statement fragments they cannot be bound as JDBC {@code ?} parameters;
 * they must not originate from untrusted input.
 */
public class SnowflakeServiceImpl implements SnowflakeService<SnowflakeServiceConfig> {
    private static final Logger LOG = LoggerFactory.getLogger(SnowflakeServiceImpl.class);

    /** Prefix substituted for {@link CloudProvider#GCS}'s prefix in bucket paths sent to Snowflake. */
    private static final String SNOWFLAKE_GCS_PREFIX = "gcs://";

    /** Loads the configured staged files into the configured table; see {@link #copyToTable}. */
    @Override
    public void write(SnowflakeServiceConfig snowflakeServiceConfig) throws Exception {
        copyToTable(snowflakeServiceConfig);
    }

    /**
     * Unloads the configured table or query into the staging directory; see {@link #copyIntoStage}.
     *
     * @return the staging directory with a trailing {@code *} appended
     */
    @Override
    public String read(SnowflakeServiceConfig snowflakeServiceConfig) throws Exception {
        return copyIntoStage(snowflakeServiceConfig);
    }

    /**
     * Runs {@code COPY INTO '<staging dir>'} to export gzip-compressed CSV data from the configured
     * query (when non-null) or otherwise from the configured table.
     *
     * @param snowflakeServiceConfig supplies the data source, table or query, storage integration
     *     name and staging directory
     * @return the staging directory with a trailing {@code *} appended
     * @throws SQLException if obtaining a connection or executing the statement fails
     */
    public String copyIntoStage(SnowflakeServiceConfig snowflakeServiceConfig) throws SQLException {
        SerializableFunction<Void, DataSource> dataSourceProviderFn =
                snowflakeServiceConfig.getDataSourceProviderFn();
        String table = snowflakeServiceConfig.getTable();
        String query = snowflakeServiceConfig.getQuery();
        String storageIntegrationName = snowflakeServiceConfig.getstorageIntegrationName();
        String stagingBucketDir = snowflakeServiceConfig.getStagingBucketDir();

        // A query used as a COPY source has to be wrapped in parentheses.
        String source = query != null ? String.format("(%s)", query) : table;
        String copyQuery =
                String.format(
                        "COPY INTO '%s' FROM %s STORAGE_INTEGRATION=%s "
                                + "FILE_FORMAT=(TYPE=CSV COMPRESSION=GZIP FIELD_OPTIONALLY_ENCLOSED_BY='%s');",
                        getProperBucketDir(stagingBucketDir),
                        source,
                        storageIntegrationName,
                        SnowflakeService.CSV_QUOTE_CHAR_FOR_COPY);
        runStatement(copyQuery, getConnection(dataSourceProviderFn), null);
        return stagingBucketDir.concat("*");
    }

    /**
     * Prepares the target table according to the configured {@link WriteDisposition}, then runs
     * {@code COPY INTO <table>} for the configured files.
     *
     * <p>When a storage integration name is set (non-null and non-empty), the source location is
     * passed through {@link #getProperBucketDir} and a {@code STORAGE_INTEGRATION} clause is added;
     * otherwise the source is used as-is and the clause is omitted.
     *
     * @param snowflakeServiceConfig supplies the data source, files, table, optional query, write
     *     disposition, storage integration name and staging directory
     * @throws SQLException if preparing the table, obtaining a connection or running the COPY fails
     */
    public void copyToTable(SnowflakeServiceConfig snowflakeServiceConfig) throws SQLException {
        SerializableFunction<Void, DataSource> dataSourceProviderFn =
                snowflakeServiceConfig.getDataSourceProviderFn();
        List<String> filesList = snowflakeServiceConfig.getFilesList();
        String table = snowflakeServiceConfig.getTable();
        String query = snowflakeServiceConfig.getQuery();
        WriteDisposition writeDisposition = snowflakeServiceConfig.getWriteDisposition();
        String storageIntegrationName = snowflakeServiceConfig.getstorageIntegrationName();
        String stagingBucketDir = snowflakeServiceConfig.getStagingBucketDir();

        // A query source needs parentheses; a location source needs single quotes.
        String source =
                query != null
                        ? String.format("(%s)", query)
                        : String.format("'%s'", stagingBucketDir);

        // Quote each file and remove the staging directory from the paths. String.replace is a
        // literal match: bucket paths may contain regex metacharacters ('.', '+', '$', ...), which
        // replaceAll would misinterpret.
        String files =
                filesList.stream()
                        .map(file -> String.format("'%s'", file))
                        .collect(Collectors.joining(", "))
                        .replace(stagingBucketDir, "");

        // Decide on the statement before touching the table so that a missing storage integration
        // name can no longer fail (NPE) after a TRUNCATE has already emptied the table.
        String copyQuery;
        if (storageIntegrationName != null && !storageIntegrationName.isEmpty()) {
            copyQuery =
                    String.format(
                            "COPY INTO %s FROM %s FILES=(%s) "
                                    + "FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' COMPRESSION=GZIP) "
                                    + "STORAGE_INTEGRATION=%s;",
                            table,
                            getProperBucketDir(source),
                            files,
                            SnowflakeService.CSV_QUOTE_CHAR_FOR_COPY,
                            storageIntegrationName);
        } else {
            copyQuery =
                    String.format(
                            "COPY INTO %s FROM %s FILES=(%s) "
                                    + "FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' COMPRESSION=GZIP);",
                            table,
                            source,
                            files,
                            SnowflakeService.CSV_QUOTE_CHAR_FOR_COPY);
        }

        DataSource dataSource = dataSourceProviderFn.apply(null);
        prepareTableAccordingWriteDisposition(dataSource, table, writeDisposition);
        runStatement(copyQuery, dataSource.getConnection(), null);
    }

    /** Executes {@code TRUNCATE <table>;} on a new connection from {@code dataSource}. */
    private void truncateTable(DataSource dataSource, String table) throws SQLException {
        runConnectionWithStatement(dataSource, String.format("TRUNCATE %s;", table), null);
    }

    /**
     * Queries the row count of {@code table} and throws if it is not empty.
     *
     * @throws RuntimeException if the table has at least one row or the count cannot be read
     * @throws SQLException if obtaining a connection or executing the query fails
     */
    private static void checkIfTableIsEmpty(DataSource dataSource, String table) throws SQLException {
        String selectQuery = String.format("SELECT count(*) FROM %s LIMIT 1;", table);
        runConnectionWithStatement(
                dataSource,
                selectQuery,
                resultSet -> {
                    assert resultSet != null;
                    checkIfTableIsEmpty(resultSet);
                });
    }

    /**
     * Reads the first row of a {@code count(*)} result and throws unless the count is below 1.
     * A result with no rows is also treated as "not empty".
     */
    private static void checkIfTableIsEmpty(ResultSet resultSet) {
        try {
            if (!resultSet.next() || !checkIfTableIsEmpty(resultSet, 1)) {
                throw new RuntimeException("Table is not empty. Aborting COPY with disposition EMPTY");
            }
        } catch (SQLException e) {
            throw new RuntimeException("Unable run pipeline with EMPTY disposition.", e);
        }
    }

    /** Returns whether the integer in column {@code columnIndex} of the current row is below 1. */
    private static boolean checkIfTableIsEmpty(ResultSet resultSet, int columnIndex)
            throws SQLException {
        return resultSet.getInt(columnIndex) < 1;
    }

    /**
     * Applies the write disposition to {@code table} before loading: {@code TRUNCATE} empties it,
     * {@code EMPTY} fails unless it has no rows, and {@code APPEND} (or any other value) does nothing.
     */
    private void prepareTableAccordingWriteDisposition(
            DataSource dataSource, String table, WriteDisposition writeDisposition)
            throws SQLException {
        switch (writeDisposition) {
            case TRUNCATE:
                truncateTable(dataSource, table);
                break;
            case EMPTY:
                checkIfTableIsEmpty(dataSource, table);
                break;
            case APPEND:
            default:
                break;
        }
    }

    /** Opens a connection from {@code dataSource} and runs {@code query} on it; see {@link #runStatement}. */
    private static void runConnectionWithStatement(
            DataSource dataSource, String query, Consumer<ResultSet> resultSetConsumer)
            throws SQLException {
        // runStatement takes ownership of the connection and closes it.
        runStatement(query, dataSource.getConnection(), resultSetConsumer);
    }

    /**
     * Prepares and executes {@code query} on {@code connection}, then closes the statement and the
     * connection, even when preparation or execution throws.
     *
     * @param resultSetConsumer if non-null, the statement is run with {@code executeQuery()} and the
     *     consumer receives the result set (closed afterwards); if null, {@code execute()} is used
     */
    private static void runStatement(
            String query, Connection connection, Consumer<ResultSet> resultSetConsumer)
            throws SQLException {
        // Binding the connection as a resource guarantees it is closed even if prepareStatement
        // throws; resources close in reverse order (statement first, then connection).
        try (Connection ownedConnection = connection;
                PreparedStatement statement = ownedConnection.prepareStatement(query)) {
            if (resultSetConsumer != null) {
                try (ResultSet resultSet = statement.executeQuery()) {
                    resultSetConsumer.accept(resultSet);
                }
            } else {
                statement.execute();
            }
        }
    }

    /** Obtains a new connection from the data source produced by {@code dataSourceProviderFn}. */
    private Connection getConnection(SerializableFunction<Void, DataSource> dataSourceProviderFn)
            throws SQLException {
        return dataSourceProviderFn.apply(null).getConnection();
    }

    /** Replaces every occurrence of the GCS cloud-provider prefix with {@code "gcs://"}. */
    private String getProperBucketDir(String bucketDir) {
        return bucketDir.replace(CloudProvider.GCS.getPrefix(), SNOWFLAKE_GCS_PREFIX);
    }
}
