package org.apache.seatunnel.translation.spark.sink.writer;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.seatunnel.api.sink.MultiTableResourceManager;
import org.apache.seatunnel.api.sink.SeaTunnelSink;
import org.apache.seatunnel.api.sink.SinkAggregatedCommitter;
import org.apache.seatunnel.api.sink.SupportResourceShare;
import org.apache.seatunnel.api.table.catalog.CatalogTable;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.sources.v2.writer.DataSourceWriter;
import org.apache.spark.sql.sources.v2.writer.DataWriterFactory;
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;

/* loaded from: input_file:org/apache/seatunnel/translation/spark/sink/writer/SparkDataSourceWriter.class */
/**
 * Spark DataSource V2 {@link DataSourceWriter} that bridges a SeaTunnel {@link SeaTunnelSink} into
 * Spark's write path.
 *
 * <p>Row writing is delegated to the {@link SparkDataWriterFactory} returned by
 * {@link #createWriterFactory()}. If the sink provides a {@link SinkAggregatedCommitter}, it is
 * created and initialised in the constructor. {@link #commit} and {@link #abort} then use it: the
 * per-writer commit infos are combined into a single aggregated commit info, which is handed to the
 * committer.
 *
 * @param <StateT> writer state type of the wrapped sink
 * @param <CommitInfoT> per-writer commit info type of the wrapped sink
 * @param <AggregatedCommitInfoT> aggregated commit info type of the wrapped sink
 */
public class SparkDataSourceWriter<StateT, CommitInfoT, AggregatedCommitInfoT> implements DataSourceWriter {
    /** The SeaTunnel sink being written to; handed to the writer factory. */
    protected final SeaTunnelSink<SeaTunnelRow, StateT, CommitInfoT, AggregatedCommitInfoT> sink;

    /** Aggregated committer created by the sink, or {@code null} when the sink does not provide one. */
    @Nullable
    protected final SinkAggregatedCommitter<CommitInfoT, AggregatedCommitInfoT> sinkAggregatedCommitter;

    /** Catalog table handed to the writer factory. */
    protected final CatalogTable catalogTable;

    /** Shared-resource manager; only set when the aggregated committer implements {@link SupportResourceShare}. */
    private MultiTableResourceManager resourceManager;

    /**
     * Creates the writer and, if the sink provides one, creates and initialises its aggregated
     * committer (wiring up a shared-resource manager when the committer supports it).
     *
     * @param seaTunnelSink the SeaTunnel sink to write to
     * @param catalogTable the catalog table passed on to the writer factory
     * @throws IOException propagated from creating or initialising the aggregated committer
     */
    public SparkDataSourceWriter(
            SeaTunnelSink<SeaTunnelRow, StateT, CommitInfoT, AggregatedCommitInfoT> seaTunnelSink,
            CatalogTable catalogTable)
            throws IOException {
        this.sink = seaTunnelSink;
        this.catalogTable = catalogTable;
        this.sinkAggregatedCommitter = seaTunnelSink.createAggregatedCommitter().orElse(null);
        if (this.sinkAggregatedCommitter != null) {
            if (this.sinkAggregatedCommitter instanceof SupportResourceShare) {
                this.resourceManager =
                        ((SupportResourceShare) this.sinkAggregatedCommitter).initMultiTableResourceManager(1, 1);
            }
            this.sinkAggregatedCommitter.init();
            // The manager is created before init() but attached only afterwards; the order is kept as-is.
            if (this.resourceManager != null) {
                ((SupportResourceShare) this.sinkAggregatedCommitter)
                        .setMultiTableResourceManager(this.resourceManager, 0);
            }
        }
    }

    /** Returns a factory producing data writers for {@link #sink} and {@link #catalogTable}. */
    @Override
    public DataWriterFactory<InternalRow> createWriterFactory() {
        return new SparkDataWriterFactory(this.sink, this.catalogTable);
    }

    /**
     * Commits the job by combining all writers' commit infos and passing the result to the
     * aggregated committer. Does nothing when the sink has no aggregated committer.
     *
     * @param writerCommitMessageArr commit messages produced by the data writers
     * @throws RuntimeException wrapping the {@link IOException} thrown by the aggregated committer
     */
    @Override
    public void commit(WriterCommitMessage[] writerCommitMessageArr) {
        if (this.sinkAggregatedCommitter != null) {
            try {
                this.sinkAggregatedCommitter.commit(combineCommitMessage(writerCommitMessageArr));
            } catch (IOException e) {
                throw new RuntimeException("SinkAggregatedCommitter commit failed in driver", e);
            }
        }
    }

    /**
     * Aborts the job by combining whatever commit infos are available and passing them to the
     * aggregated committer's abort. Does nothing when the sink has no aggregated committer.
     *
     * <p>Spark may pass {@code null} slots here for writers that failed before committing; those are
     * skipped rather than dereferenced.
     *
     * @param writerCommitMessageArr commit messages from the data writers, possibly with null slots
     * @throws RuntimeException wrapping any exception raised while combining or aborting
     */
    @Override
    public void abort(WriterCommitMessage[] writerCommitMessageArr) {
        if (this.sinkAggregatedCommitter != null) {
            try {
                this.sinkAggregatedCommitter.abort(combineCommitMessage(writerCommitMessageArr));
            } catch (Exception e) {
                throw new RuntimeException("SinkAggregatedCommitter abort failed in driver", e);
            }
        }
    }

    /**
     * Unwraps the SeaTunnel commit info from every non-null Spark commit message and combines them
     * into a single aggregated commit info.
     *
     * @param writerCommitMessageArr Spark commit messages; may be {@code null}, empty, or contain null slots
     * @return an empty list when there is no aggregated committer or no messages, otherwise a
     *     singleton list holding the combined aggregated commit info
     */
    @Nonnull
    private List<AggregatedCommitInfoT> combineCommitMessage(WriterCommitMessage[] writerCommitMessageArr) {
        if (this.sinkAggregatedCommitter == null
                || writerCommitMessageArr == null
                || writerCommitMessageArr.length == 0) {
            return Collections.emptyList();
        }
        List<CommitInfoT> commitInfos =
                Arrays.stream(writerCommitMessageArr)
                        // Null slots (failed writers, see abort) must be dropped BEFORE unwrapping.
                        .filter(Objects::nonNull)
                        .map(this::unwrapCommitInfo)
                        // A writer may also have committed without producing any commit info.
                        .filter(Objects::nonNull)
                        .collect(Collectors.toList());
        return Collections.singletonList(this.sinkAggregatedCommitter.combine(commitInfos));
    }

    /**
     * Extracts the SeaTunnel commit info carried by a non-null Spark commit message.
     *
     * <p>Assumes every message was produced by this sink's writers and is therefore a
     * {@link SparkWriterCommitMessage} carrying a {@code CommitInfoT}; the cast is unchecked.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private CommitInfoT unwrapCommitInfo(WriterCommitMessage writerCommitMessage) {
        return (CommitInfoT) ((SparkWriterCommitMessage) writerCommitMessage).getMessage();
    }
}
