package org.apache.iceberg.mr.hive;

import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.ql.plan.TableDesc;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.JobContextImpl;
import org.apache.hadoop.mapred.JobID;
import org.apache.hadoop.mapred.OutputCommitter;
import org.apache.hadoop.mapred.TaskAttemptContext;
import org.apache.hadoop.mapred.TaskAttemptContextImpl;
import org.apache.hadoop.mapred.TaskAttemptID;
import org.apache.hadoop.mapreduce.JobStatus;
import org.apache.hadoop.mapreduce.TaskType;
import org.apache.iceberg.FileFormat;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.data.GenericAppenderFactory;
import org.apache.iceberg.data.Record;
import org.apache.iceberg.hadoop.HadoopTables;
import org.apache.iceberg.io.FileIO;
import org.apache.iceberg.io.OutputFileFactory;
import org.apache.iceberg.mr.TestHelper;
import org.apache.iceberg.mr.mapred.Container;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.SerializationUtil;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/iceberg/mr/hive/TestHiveIcebergOutputCommitter.class */
/**
 * Tests for {@link HiveIcebergOutputCommitter}.
 *
 * <p>Each test creates a fresh location-based table under a JUnit-managed temporary directory,
 * writes random records through {@link HiveIcebergRecordWriter} for one or more simulated map
 * task attempts, and then checks what {@code HiveIcebergTestUtils.validateFiles} and
 * {@code HiveIcebergTestUtils.validateData} report after the tasks/job are committed or aborted.
 */
public class TestHiveIcebergOutputCommitter {
    /** Target data file size handed to every {@link HiveIcebergRecordWriter} (128 MiB). */
    private static final long TARGET_FILE_SIZE = 128 * 1024 * 1024;

    /** Number of random records generated per simulated task. */
    private static final int RECORD_NUM = 5;

    private static final String QUERY_ID = "query_id";
    private static final JobID JOB_ID = new JobID("test", 0);
    private static final TaskAttemptID MAP_TASK_ID =
            new TaskAttemptID(JOB_ID.getJtIdentifier(), JOB_ID.getId(), TaskType.MAP, 0, 0);
    private static final TaskAttemptID REDUCE_TASK_ID =
            new TaskAttemptID(JOB_ID.getJtIdentifier(), JOB_ID.getId(), TaskType.REDUCE, 0, 0);

    private static final Schema CUSTOMER_SCHEMA = new Schema(
            Types.NestedField.required(1, "customer_id", Types.LongType.get()),
            Types.NestedField.required(2, "first_name", Types.StringType.get()));

    /** Partition spec used by the "partitioned" tests: 3 buckets on {@code customer_id}. */
    private static final PartitionSpec PARTITIONED_SPEC =
            PartitionSpec.builderFor(CUSTOMER_SCHEMA).bucket("customer_id", 3).build();

    @TempDir
    private Path temp;

    /**
     * Checks {@code needsTaskCommit} for map-only jobs and for jobs that also have reducers:
     * map tasks need a commit only when there are no reduce tasks; reduce tasks need one.
     */
    @Test
    public void testNeedsTaskCommit() {
        HiveIcebergOutputCommitter committer = new HiveIcebergOutputCommitter();

        // Map-only job: the map task itself has to commit.
        JobConf mapOnlyConf = new JobConf();
        mapOnlyConf.setNumMapTasks(10);
        mapOnlyConf.setNumReduceTasks(0);
        Assertions.assertThat(committer.needsTaskCommit(new TaskAttemptContextImpl(mapOnlyConf, MAP_TASK_ID)))
                .isTrue();

        // Map-reduce job: only the reduce task has to commit.
        JobConf mapReduceConf = new JobConf();
        mapReduceConf.setNumMapTasks(10);
        mapReduceConf.setNumReduceTasks(10);
        Assertions.assertThat(committer.needsTaskCommit(new TaskAttemptContextImpl(mapReduceConf, MAP_TASK_ID)))
                .isFalse();
        Assertions.assertThat(committer.needsTaskCommit(new TaskAttemptContextImpl(mapReduceConf, REDUCE_TASK_ID)))
                .isTrue();
    }

    /** One committed task on an unpartitioned table: validation is run with a file count of 1. */
    @Test
    public void testSuccessfulUnpartitionedWrite() throws IOException {
        HiveIcebergOutputCommitter committer = new HiveIcebergOutputCommitter();
        Table table = table(temp.toFile().getPath(), false);
        JobConf conf = jobConf(table, 1);

        List<Record> expected = writeRecords(table.name(), 1, 0, true, false, conf);
        committer.commitJob(new JobContextImpl(conf, JOB_ID));

        HiveIcebergTestUtils.validateFiles(table, conf, JOB_ID, 1);
        HiveIcebergTestUtils.validateData(table, expected, 0);
    }

    /** One committed task on the bucketed table: validation is run with a file count of 3. */
    @Test
    public void testSuccessfulPartitionedWrite() throws IOException {
        HiveIcebergOutputCommitter committer = new HiveIcebergOutputCommitter();
        Table table = table(temp.toFile().getPath(), true);
        JobConf conf = jobConf(table, 1);

        List<Record> expected = writeRecords(table.name(), 1, 0, true, false, conf);
        committer.commitJob(new JobContextImpl(conf, JOB_ID));

        HiveIcebergTestUtils.validateFiles(table, conf, JOB_ID, 3);
        HiveIcebergTestUtils.validateData(table, expected, 0);
    }

    /** Two committed tasks on an unpartitioned table: validation is run with a file count of 2. */
    @Test
    public void testSuccessfulMultipleTasksUnpartitionedWrite() throws IOException {
        HiveIcebergOutputCommitter committer = new HiveIcebergOutputCommitter();
        Table table = table(temp.toFile().getPath(), false);
        JobConf conf = jobConf(table, 2);

        List<Record> expected = writeRecords(table.name(), 2, 0, true, false, conf);
        committer.commitJob(new JobContextImpl(conf, JOB_ID));

        HiveIcebergTestUtils.validateFiles(table, conf, JOB_ID, 2);
        HiveIcebergTestUtils.validateData(table, expected, 0);
    }

    /** Two committed tasks on the bucketed table: validation is run with a file count of 6. */
    @Test
    public void testSuccessfulMultipleTasksPartitionedWrite() throws IOException {
        HiveIcebergOutputCommitter committer = new HiveIcebergOutputCommitter();
        Table table = table(temp.toFile().getPath(), true);
        JobConf conf = jobConf(table, 2);

        List<Record> expected = writeRecords(table.name(), 2, 0, true, false, conf);
        committer.commitJob(new JobContextImpl(conf, JOB_ID));

        HiveIcebergTestUtils.validateFiles(table, conf, JOB_ID, 6);
        HiveIcebergTestUtils.validateData(table, expected, 0);
    }

    /**
     * Simulates three attempts of the same two tasks: an aborted attempt, an attempt that is
     * neither committed nor aborted, and finally a committed attempt followed by a job commit.
     * Only the records of the committed attempt are expected to be readable.
     */
    @Test
    public void testRetryTask() throws IOException {
        HiveIcebergOutputCommitter committer = new HiveIcebergOutputCommitter();
        Table table = table(temp.toFile().getPath(), false);
        JobConf conf = jobConf(table, 2);

        // Attempt 0: write and abort the tasks.
        writeRecords(table.name(), 2, 0, false, true, conf);
        HiveIcebergTestUtils.validateFiles(table, conf, JOB_ID, 0);
        HiveIcebergTestUtils.validateData(table, Collections.<Record>emptyList(), 0);

        // Attempt 1: write without committing or aborting. The file validation is run with a
        // count of 2, while the table is still expected to contain no data.
        writeRecords(table.name(), 2, 1, false, false, conf);
        HiveIcebergTestUtils.validateFiles(table, conf, JOB_ID, 2);
        HiveIcebergTestUtils.validateData(table, Collections.<Record>emptyList(), 0);

        // Attempt 2: write, commit the tasks and then commit the job.
        List<Record> expected = writeRecords(table.name(), 2, 2, true, false, conf);
        committer.commitJob(new JobContextImpl(conf, JOB_ID));
        HiveIcebergTestUtils.validateFiles(table, conf, JOB_ID, 4);
        HiveIcebergTestUtils.validateData(table, expected, 0);
    }

    /** Commits a task, aborts the job, and expects neither files nor data afterwards. */
    @Test
    public void testAbortJob() throws IOException {
        HiveIcebergOutputCommitter committer = new HiveIcebergOutputCommitter();
        Table table = table(temp.toFile().getPath(), false);
        JobConf conf = jobConf(table, 1);

        writeRecords(table.name(), 1, 0, true, false, conf);
        committer.abortJob(new JobContextImpl(conf, JOB_ID), JobStatus.State.FAILED);

        HiveIcebergTestUtils.validateFiles(table, conf, JOB_ID, 0);
        HiveIcebergTestUtils.validateData(table, Collections.<Record>emptyList(), 0);
    }

    /**
     * Forces {@code commitTask} to throw via a Mockito spy, then verifies that the writers
     * registered for the failed task attempt are still present and that a subsequent
     * {@code abortTask} removes them.
     */
    @Test
    public void writerIsClosedAfterTaskCommitFailure() throws IOException {
        HiveIcebergOutputCommitter failingCommitter = Mockito.spy(new HiveIcebergOutputCommitter());
        ArgumentCaptor<TaskAttemptContextImpl> contextCaptor =
                ArgumentCaptor.forClass(TaskAttemptContextImpl.class);
        String failureMessage = "Commit task failed!";
        Mockito.doThrow(new RuntimeException(failureMessage))
                .when(failingCommitter)
                .commitTask(contextCaptor.capture());

        Table table = table(temp.toFile().getPath(), false);
        JobConf conf = jobConf(table, 1);

        Assertions.assertThatThrownBy(
                        () -> writeRecords(table.name(), 1, 0, true, false, conf, failingCommitter))
                .isInstanceOf(RuntimeException.class)
                .hasMessage(failureMessage);

        // Exactly one commitTask call was attempted before the failure propagated.
        Assertions.assertThat(contextCaptor.getAllValues()).hasSize(1);
        TaskAttemptID failedAttempt =
                TezUtil.taskAttemptWrapper(contextCaptor.getValue().getTaskAttemptID());

        // The writers are still registered after the failed commit ...
        Assertions.assertThat(HiveIcebergRecordWriter.getWriters(failedAttempt)).isNotNull();

        // ... and are gone once the task attempt is aborted.
        failingCommitter.abortTask(new TaskAttemptContextImpl(conf, failedAttempt));
        Assertions.assertThat(HiveIcebergRecordWriter.getWriters(failedAttempt)).isNull();
    }

    /**
     * Creates a table with {@link #CUSTOMER_SCHEMA} at the given location using
     * {@link HadoopTables}.
     *
     * @param location filesystem location of the new table
     * @param partitioned whether to use {@link #PARTITIONED_SPEC} or an unpartitioned spec
     * @return the newly created table
     */
    private Table table(String location, boolean partitioned) {
        PartitionSpec spec = partitioned ? PARTITIONED_SPEC : PartitionSpec.unpartitioned();
        return new HadoopTables().create(
                CUSTOMER_SCHEMA,
                spec,
                ImmutableMap.of("iceberg.catalog", "location_based_table"),
                location);
    }

    /**
     * Builds a map-only {@link JobConf} that targets the given table.
     *
     * @param table the output table
     * @param taskNum number of map tasks to configure
     * @return a configuration carrying the query id, the output table name, the table's catalog
     *     setting, the serialized table, and the properties overlaid by
     *     {@code HiveIcebergStorageHandler.overlayTableProperties}
     */
    private JobConf jobConf(Table table, int taskNum) {
        JobConf conf = new JobConf();
        conf.setNumMapTasks(taskNum);
        conf.setNumReduceTasks(0);
        conf.set(HiveConf.ConfVars.HIVEQUERYID.varname, QUERY_ID);
        conf.set("iceberg.mr.output.tables", table.name());
        conf.set("iceberg.mr.table.catalog." + table.name(), table.properties().get("iceberg.catalog"));
        conf.set("iceberg.mr.serialized.table." + table.name(), SerializationUtil.serializeToBase64(table));

        TableDesc tableDesc = new TableDesc();
        tableDesc.setProperties(new Properties());
        Properties tableProperties = tableDesc.getProperties();
        tableProperties.setProperty("name", table.name());
        tableProperties.setProperty("location", table.location());
        tableProperties.setProperty("iceberg.catalog", table.properties().get("iceberg.catalog"));

        // Collect the handler's overlay into a typed map, then copy every entry into the conf.
        HashMap<String, String> overlay = Maps.newHashMap();
        HiveIcebergStorageHandler.overlayTableProperties(conf, tableDesc, overlay);
        overlay.forEach(conf::set);
        return conf;
    }

    /**
     * Writes {@link #RECORD_NUM} random records for each of {@code taskNum} simulated map task
     * attempts and then commits or aborts every task through the supplied committer.
     *
     * @param name name of the table to write to, resolved from {@code conf}
     * @param taskNum number of tasks to simulate (task ids {@code 0..taskNum-1})
     * @param attemptNum attempt number used for every task; also offsets the random seed
     * @param commitTasks if true, {@code commitTask} is called for every task
     * @param abortTasks if true and {@code commitTasks} is false, {@code abortTask} is called
     *     for every task
     * @param conf job configuration
     * @param committer committer used to commit or abort the tasks
     * @return the records of all tasks that were committed (empty when nothing was committed)
     * @throws IOException if writing, committing or aborting fails
     */
    private List<Record> writeRecords(
            String name,
            int taskNum,
            int attemptNum,
            boolean commitTasks,
            boolean abortTasks,
            JobConf conf,
            OutputCommitter committer)
            throws IOException {
        List<Record> committed = Lists.newArrayListWithExpectedSize(RECORD_NUM * taskNum);

        Table table = HiveIcebergStorageHandler.table(conf, name);
        FileIO io = table.io();
        Schema schema = HiveIcebergStorageHandler.schema(conf);
        PartitionSpec spec = table.spec();

        for (int task = 0; task < taskNum; task++) {
            List<Record> records = TestHelper.generateRandomRecords(schema, RECORD_NUM, task + attemptNum);
            TaskAttemptID taskId = new TaskAttemptID(
                    JOB_ID.getJtIdentifier(), JOB_ID.getId(), TaskType.MAP, task, attemptNum);
            int partitionId = taskId.getTaskID().getId();
            String operationId = QUERY_ID + "-" + JOB_ID;
            FileFormat fileFormat = FileFormat.PARQUET;
            OutputFileFactory outputFileFactory = OutputFileFactory.builderFor(table, partitionId, attemptNum)
                    .format(fileFormat)
                    .operationId(operationId)
                    .build();
            HiveIcebergRecordWriter writer = new HiveIcebergRecordWriter(
                    schema,
                    spec,
                    fileFormat,
                    new GenericAppenderFactory(schema),
                    outputFileFactory,
                    io,
                    TARGET_FILE_SIZE,
                    TezUtil.taskAttemptWrapper(taskId),
                    conf.get("name"));

            // A single container instance is reused for every record of the task.
            Container<Record> container = new Container<>();
            for (Record record : records) {
                container.set(record);
                writer.write(container);
            }
            writer.close(false);

            if (commitTasks) {
                committer.commitTask(new TaskAttemptContextImpl(conf, taskId));
                // Only count the records once the task commit has succeeded.
                committed.addAll(records);
            } else if (abortTasks) {
                committer.abortTask(new TaskAttemptContextImpl(conf, taskId));
            }
        }
        return committed;
    }

    /**
     * Same as the seven-argument overload, using a fresh {@link HiveIcebergOutputCommitter} to
     * commit or abort the tasks.
     */
    private List<Record> writeRecords(
            String name,
            int taskNum,
            int attemptNum,
            boolean commitTasks,
            boolean abortTasks,
            JobConf conf)
            throws IOException {
        return writeRecords(name, taskNum, attemptNum, commitTasks, abortTasks, conf, new HiveIcebergOutputCommitter());
    }
}
