package org.apache.shardingsphere.driver.executor.batch;

import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import lombok.Generated;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.database.type.DatabaseType;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroup;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupContext;
import org.apache.shardingsphere.infra.executor.sql.context.ExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.SQLExecutorExceptionHandler;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutorCallback;
import org.apache.shardingsphere.infra.rule.identifier.type.DataNodeContainedRule;
import org.apache.shardingsphere.infra.util.eventbus.EventBusContext;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;

/**
 * Executor that collects execution units across repeated {@code addBatch} calls and runs them through
 * {@link Statement#executeBatch()} on the grouped JDBC statements.
 */
public final class BatchPreparedStatementExecutor {
    
    private final MetaDataContexts metaDataContexts;
    
    private final JDBCExecutor jdbcExecutor;
    
    private ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext = new ExecutionGroupContext<>(new LinkedList<>());
    
    private final Collection<BatchExecutionUnit> batchExecutionUnits = new LinkedList<>();
    
    // Number of addBatchForExecutionUnits calls since construction or the last clear(); also sizes accumulate()'s result.
    private int batchCount;
    
    private final String databaseName;
    
    private final EventBusContext eventBusContext;
    
    public BatchPreparedStatementExecutor(final MetaDataContexts metaDataContexts, final JDBCExecutor jdbcExecutor, final String databaseName, final EventBusContext eventBusContext) {
        this.metaDataContexts = metaDataContexts;
        this.jdbcExecutor = jdbcExecutor;
        this.databaseName = databaseName;
        this.eventBusContext = eventBusContext;
    }
    
    /**
     * Replace the execution group context used by {@link #executeBatch} and the statement lookups.
     *
     * @param executionGroupContext execution group context
     */
    public void init(final ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext) {
        this.executionGroupContext = executionGroupContext;
    }
    
    /**
     * Register the execution units produced by one {@code addBatch} call.
     *
     * <p>Units equal to an already registered unit have their parameters appended to that unit; the remaining units are
     * registered as new. Each affected unit records the current batch count via {@code mapAddBatchCount}, after which the
     * batch count is incremented.</p>
     *
     * @param executionUnits execution units of the current {@code addBatch} call
     */
    public void addBatchForExecutionUnits(final Collection<ExecutionUnit> executionUnits) {
        Collection<BatchExecutionUnit> newBatchExecutionUnits = createBatchExecutionUnits(executionUnits);
        handleOldBatchExecutionUnits(newBatchExecutionUnits);
        handleNewBatchExecutionUnits(newBatchExecutionUnits);
        batchCount++;
    }
    
    // Wraps each execution unit; the returned list must stay mutable because handleNewBatchExecutionUnits removes from it.
    private Collection<BatchExecutionUnit> createBatchExecutionUnits(final Collection<ExecutionUnit> executionUnits) {
        List<BatchExecutionUnit> result = new ArrayList<>(executionUnits.size());
        for (ExecutionUnit each : executionUnits) {
            result.add(new BatchExecutionUnit(each));
        }
        return result;
    }
    
    private void handleOldBatchExecutionUnits(final Collection<BatchExecutionUnit> newBatchExecutionUnits) {
        newBatchExecutionUnits.forEach(this::reviseBatchExecutionUnits);
    }
    
    // Merges the new unit into every registered unit that equals it.
    private void reviseBatchExecutionUnits(final BatchExecutionUnit newBatchExecutionUnit) {
        for (BatchExecutionUnit each : batchExecutionUnits) {
            if (each.equals(newBatchExecutionUnit)) {
                reviseBatchExecutionUnit(each, newBatchExecutionUnit);
            }
        }
    }
    
    // Appends the new unit's parameters to the registered unit and records the current batch count on it.
    private void reviseBatchExecutionUnit(final BatchExecutionUnit oldBatchExecutionUnit, final BatchExecutionUnit newBatchExecutionUnit) {
        oldBatchExecutionUnit.getExecutionUnit().getSqlUnit().getParameters().addAll(newBatchExecutionUnit.getExecutionUnit().getSqlUnit().getParameters());
        oldBatchExecutionUnit.mapAddBatchCount(batchCount);
    }
    
    // Registers only the units that were not merged into an existing unit.
    private void handleNewBatchExecutionUnits(final Collection<BatchExecutionUnit> newBatchExecutionUnits) {
        newBatchExecutionUnits.removeAll(batchExecutionUnits);
        for (BatchExecutionUnit each : newBatchExecutionUnits) {
            each.mapAddBatchCount(batchCount);
        }
        batchExecutionUnits.addAll(newBatchExecutionUnits);
    }
    
    /**
     * Execute the collected batch by calling {@link Statement#executeBatch()} on every grouped statement.
     *
     * @param sqlStatementContext SQL statement context
     * @return an empty array when the executor returns no results; per-{@code addBatch} update counts accumulated across
     *         statements when a data node contained rule requires accumulation; otherwise the first statement's result
     * @throws SQLException SQL exception
     */
    public int[] executeBatch(final SQLStatementContext<?> sqlStatementContext) throws SQLException {
        JDBCExecutorCallback<int[]> callback = new JDBCExecutorCallback<int[]>(
                metaDataContexts.getMetaData().getDatabase(databaseName).getProtocolType(),
                metaDataContexts.getMetaData().getDatabase(databaseName).getResourceMetaData().getStorageTypes(),
                sqlStatementContext.getSqlStatement(), SQLExecutorExceptionHandler.isExceptionThrown(), eventBusContext) {
            
            @Override
            protected int[] executeSQL(final String sql, final Statement statement, final ConnectionMode connectionMode, final DatabaseType storageType) throws SQLException {
                return statement.executeBatch();
            }
            
            @Override
            protected Optional<int[]> getSaneResult(final SQLStatement sqlStatement, final SQLException ex) {
                return Optional.empty();
            }
        };
        List<int[]> results = jdbcExecutor.execute(executionGroupContext, callback);
        if (results.isEmpty()) {
            return new int[0];
        }
        return isNeedAccumulate(sqlStatementContext) ? accumulate(results) : results.get(0);
    }
    
    // True when any rule of the database is a DataNodeContainedRule that asks for accumulation of the statement's tables.
    private boolean isNeedAccumulate(final SQLStatementContext<?> sqlStatementContext) {
        // Iterate as Object: not every rule is a DataNodeContainedRule, so narrowing must happen after the instanceof check.
        for (Object each : metaDataContexts.getMetaData().getDatabase(databaseName).getRuleMetaData().getRules()) {
            if (each instanceof DataNodeContainedRule && ((DataNodeContainedRule) each).isNeedAccumulate(sqlStatementContext.getTablesContext().getTableNames())) {
                return true;
            }
        }
        return false;
    }
    
    // Sums each statement's batch results back onto the addBatch index recorded in the matching unit's call-times map.
    private int[] accumulate(final List<int[]> executeResults) {
        int[] result = new int[batchCount];
        int count = 0;
        for (ExecutionGroup<JDBCExecutionUnit> each : executionGroupContext.getInputGroups()) {
            for (JDBCExecutionUnit eachUnit : each.getInputs()) {
                Map<Integer, Integer> jdbcAndActualAddBatchCallTimesMap = findJdbcAndActualAddBatchCallTimesMap(eachUnit);
                if (!jdbcAndActualAddBatchCallTimesMap.isEmpty()) {
                    // Fetch once per unit: executeResults may be a LinkedList, where get(int) is linear.
                    int[] unitResult = executeResults.get(count);
                    // A null result contributes nothing to the totals.
                    if (null != unitResult) {
                        for (Map.Entry<Integer, Integer> entry : jdbcAndActualAddBatchCallTimesMap.entrySet()) {
                            result[entry.getKey()] += unitResult[entry.getValue()];
                        }
                    }
                }
                count++;
            }
        }
        return result;
    }
    
    // Returns the call-times map of the first registered unit with the same data source and SQL, or an empty map.
    private Map<Integer, Integer> findJdbcAndActualAddBatchCallTimesMap(final JDBCExecutionUnit jdbcExecutionUnit) {
        for (BatchExecutionUnit each : batchExecutionUnits) {
            if (isSameDataSourceAndSQL(each, jdbcExecutionUnit)) {
                return each.getJdbcAndActualAddBatchCallTimesMap();
            }
        }
        return Collections.emptyMap();
    }
    
    private boolean isSameDataSourceAndSQL(final BatchExecutionUnit batchExecutionUnit, final JDBCExecutionUnit jdbcExecutionUnit) {
        return batchExecutionUnit.getExecutionUnit().getDataSourceName().equals(jdbcExecutionUnit.getExecutionUnit().getDataSourceName())
                && batchExecutionUnit.getExecutionUnit().getSqlUnit().getSql().equals(jdbcExecutionUnit.getExecutionUnit().getSqlUnit().getSql());
    }
    
    /**
     * Get the statements of all execution units, in group order.
     *
     * @return a new mutable list of statements
     */
    public List<Statement> getStatements() {
        List<Statement> result = new ArrayList<>();
        for (ExecutionGroup<JDBCExecutionUnit> each : executionGroupContext.getInputGroups()) {
            for (JDBCExecutionUnit eachUnit : each.getInputs()) {
                result.add(eachUnit.getStorageResource());
            }
        }
        return result;
    }
    
    /**
     * Get the parameter sets bound to the given statement.
     *
     * @param statement statement
     * @return parameter sets, or an empty list when no execution unit owns the statement
     * @throws IllegalStateException when the statement's unit has no registered batch execution unit
     */
    public List<List<Object>> getParameterSet(final Statement statement) {
        for (ExecutionGroup<JDBCExecutionUnit> each : executionGroupContext.getInputGroups()) {
            Optional<JDBCExecutionUnit> jdbcExecutionUnit = findJDBCExecutionUnit(statement, each);
            if (jdbcExecutionUnit.isPresent()) {
                return getParameterSets(jdbcExecutionUnit.get());
            }
        }
        return Collections.emptyList();
    }
    
    private Optional<JDBCExecutionUnit> findJDBCExecutionUnit(final Statement statement, final ExecutionGroup<JDBCExecutionUnit> executionGroup) {
        for (JDBCExecutionUnit each : executionGroup.getInputs()) {
            if (each.getStorageResource().equals(statement)) {
                return Optional.of(each);
            }
        }
        return Optional.empty();
    }
    
    private List<List<Object>> getParameterSets(final JDBCExecutionUnit jdbcExecutionUnit) {
        for (BatchExecutionUnit each : batchExecutionUnits) {
            if (isSameDataSourceAndSQL(each, jdbcExecutionUnit)) {
                return each.getParameterSets();
            }
        }
        throw new IllegalStateException(String.format("Can not find batch execution unit for data source `%s` and SQL `%s`.",
                jdbcExecutionUnit.getExecutionUnit().getDataSourceName(), jdbcExecutionUnit.getExecutionUnit().getSqlUnit().getSql()));
    }
    
    /**
     * Reset the executor: clear the input groups, the batch count and the registered batch execution units.
     */
    public void clear() {
        // getStatements() builds a fresh copy, so clearing it would be a no-op; only the real state is reset here.
        executionGroupContext.getInputGroups().clear();
        batchCount = 0;
        batchExecutionUnits.clear();
    }
    
    /**
     * Get the registered batch execution units.
     *
     * @return the live internal collection of batch execution units
     */
    @Generated
    public Collection<BatchExecutionUnit> getBatchExecutionUnits() {
        return batchExecutionUnits;
    }
}
