package org.apache.seatunnel.connectors.seatunnel.neo4j.sink;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.seatunnel.api.common.SeaTunnelAPIErrorCode;
import org.apache.seatunnel.api.sink.SinkWriter;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.common.constants.PluginType;
import org.apache.seatunnel.common.exception.SeaTunnelErrorCode;
import org.apache.seatunnel.connectors.seatunnel.neo4j.config.Neo4jSinkQueryInfo;
import org.apache.seatunnel.connectors.seatunnel.neo4j.constants.CypherEnum;
import org.apache.seatunnel.connectors.seatunnel.neo4j.exception.Neo4jConnectorErrorCode;
import org.apache.seatunnel.connectors.seatunnel.neo4j.exception.Neo4jConnectorException;
import org.apache.seatunnel.connectors.seatunnel.neo4j.internal.SeaTunnelRowNeo4jValue;
import org.neo4j.driver.Driver;
import org.neo4j.driver.Query;
import org.neo4j.driver.Session;
import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.Values;
import org.neo4j.driver.exceptions.ClientException;
import org.neo4j.driver.exceptions.Neo4jException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/seatunnel/connectors/seatunnel/neo4j/sink/Neo4jSinkWriter.class */
public class Neo4jSinkWriter implements SinkWriter<SeaTunnelRow, Void, Void> {
    private static final Logger log = LoggerFactory.getLogger(Neo4jSinkWriter.class);
    private final Neo4jSinkQueryInfo neo4jSinkQueryInfo;
    private final transient Driver driver;
    private final transient Session session;
    private final SeaTunnelRowType seaTunnelRowType;
    private final List<SeaTunnelRowNeo4jValue> writeBuffer;
    private final Integer maxBatchSize;

    public Neo4jSinkWriter(Neo4jSinkQueryInfo neo4jSinkQueryInfo, SeaTunnelRowType seaTunnelRowType) {
        this.neo4jSinkQueryInfo = neo4jSinkQueryInfo;
        this.driver = this.neo4jSinkQueryInfo.getDriverBuilder().build();
        this.session = this.driver.session(SessionConfig.forDatabase(neo4jSinkQueryInfo.getDriverBuilder().getDatabase()));
        this.seaTunnelRowType = seaTunnelRowType;
        this.maxBatchSize = (Integer) Optional.ofNullable(neo4jSinkQueryInfo.getMaxBatchSize()).orElse(0);
        this.writeBuffer = new ArrayList(this.maxBatchSize.intValue());
    }

    public void write(SeaTunnelRow seaTunnelRow) throws IOException {
        if (this.neo4jSinkQueryInfo.batchMode()) {
            writeByBatchSize(seaTunnelRow);
        } else {
            writeOneByOne(seaTunnelRow);
        }
    }

    private void writeOneByOne(SeaTunnelRow seaTunnelRow) {
        writeByQuery(new Query(this.neo4jSinkQueryInfo.getQuery(), (Map<String, Object>) this.neo4jSinkQueryInfo.getQueryParamPosition().entrySet().stream().collect(Collectors.toMap((v0) -> {
            return v0.getKey();
        }, entry -> {
            return seaTunnelRow.getField(((Integer) entry.getValue()).intValue());
        }))));
    }

    private void writeByBatchSize(SeaTunnelRow seaTunnelRow) {
        this.writeBuffer.add(new SeaTunnelRowNeo4jValue(this.seaTunnelRowType, seaTunnelRow));
        tryWriteByBatchSize();
    }

    private void tryWriteByBatchSize() {
        if (this.writeBuffer.isEmpty() || this.writeBuffer.size() < this.maxBatchSize.intValue()) {
            return;
        }
        writeByQuery(batchQuery());
        this.writeBuffer.clear();
    }

    private Query batchQuery() {
        try {
            return new Query(this.neo4jSinkQueryInfo.getQuery(), Values.parameters(CypherEnum.BATCH.getValue(), this.writeBuffer));
        } catch (ClientException e) {
            log.error("Failed to build cypher statement", e);
            throw new Neo4jConnectorException((SeaTunnelErrorCode) SeaTunnelAPIErrorCode.CONFIG_VALIDATION_FAILED, String.format("PluginName: %s, PluginType: %s, Message: %s", "Neo4j", PluginType.SINK, e.getMessage()));
        }
    }

    private void writeByQuery(Query query) {
        try {
            this.session.writeTransaction(transaction -> {
                transaction.run(query);
                return null;
            });
        } catch (Neo4jException e) {
            throw new Neo4jConnectorException(Neo4jConnectorErrorCode.DATE_BASE_ERROR, e.getMessage());
        }
    }

    public Optional<Void> prepareCommit() throws IOException {
        return Optional.empty();
    }

    public void abortPrepare() {
    }

    public void close() throws IOException {
        flushWriteBuffer();
        this.session.close();
        this.driver.close();
    }

    private void flushWriteBuffer() {
        if (this.writeBuffer.isEmpty()) {
            return;
        }
        writeByQuery(batchQuery());
        this.writeBuffer.clear();
    }
}
