package org.apache.crunch.impl.spark;

import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.ByteStreams;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import java.util.Map;
import org.apache.crunch.CrunchRuntimeException;
import org.apache.crunch.DoFn;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.mapred.SparkCounter;
import org.apache.hadoop.mapreduce.Counter;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.JobID;
import org.apache.hadoop.mapreduce.OutputCommitter;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.RecordWriter;
import org.apache.hadoop.mapreduce.StatusReporter;
import org.apache.hadoop.mapreduce.TaskAttemptID;
import org.apache.hadoop.mapreduce.TaskID;
import org.apache.hadoop.mapreduce.TaskInputOutputContext;
import org.apache.hadoop.mapreduce.task.MapContextImpl;
import org.apache.spark.Accumulator;
import org.apache.spark.SparkFiles;
import org.apache.spark.broadcast.Broadcast;

/* loaded from: input_file:org/apache/crunch/impl/spark/SparkRuntimeContext.class */
/**
 * Serializable bundle of the per-job state needed to prepare a {@link DoFn} for execution:
 * the job name, the broadcast bytes of the Hadoop {@link Configuration}, and the accumulator
 * that backs counters.
 *
 * <p>The decoded {@link Configuration}, the task context and the last task id are
 * {@code transient}; they are rebuilt on demand by {@link #getConfiguration()} and
 * {@link #initialize(DoFn, Integer)}.
 */
public class SparkRuntimeContext implements Serializable {
    private String jobName;
    private Broadcast<byte[]> broadConf;
    private final Accumulator<Map<String, Map<String, Long>>> counters;
    /** Lazily decoded from {@link #broadConf}; see {@link #getConfiguration()}. */
    private transient Configuration conf;
    /** Task context handed to DoFns; rebuilt when the task id or the configuration changes. */
    private transient TaskInputOutputContext<?, ?, ?, ?> context;
    /** Task id the current {@link #context} was built for ({@code null} if none was given). */
    private transient Integer lastTID;

    /**
     * {@link StatusReporter} whose counters are {@link SparkCounter}s backed by the shared
     * accumulator. Progress and status reporting are no-ops.
     */
    private static class SparkReporter extends StatusReporter implements Serializable {
        Accumulator<Map<String, Map<String, Long>>> accum;
        // Created lazily in getCounter(String, String): a transient field's inline initializer
        // is not re-run on deserialization, so the field would otherwise come back as null.
        private transient Map<String, Map<String, Counter>> counters;

        public SparkReporter(Accumulator<Map<String, Map<String, Long>>> accumulator) {
            this.accum = accumulator;
        }

        /**
         * Returns the counter for an enum constant; the group is the string form of the
         * constant's declaring class and the counter name is the constant's name.
         */
        public Counter getCounter(Enum<?> counterKey) {
            return getCounter(counterKey.getDeclaringClass().toString(), counterKey.name());
        }

        /**
         * Returns the counter registered under {@code group}/{@code name}, creating and
         * caching it on first use so repeated lookups yield the same instance.
         *
         * @param group counter group name
         * @param name counter name within the group
         * @return the cached or newly created counter
         */
        public Counter getCounter(String group, String name) {
            if (this.counters == null) {
                this.counters = Maps.newHashMap();
            }
            Map<String, Counter> groupCounters = this.counters.get(group);
            if (groupCounters == null) {
                groupCounters = Maps.newTreeMap();
                this.counters.put(group, groupCounters);
            }
            Counter counter = groupCounters.get(name);
            if (counter == null) {
                counter = new SparkCounter(group, name, this.accum);
                groupCounters.put(name, counter);
            }
            return counter;
        }

        /** No-op. */
        public void progress() {
        }

        /** Always reports {@code 0.0f}. */
        public float getProgress() {
            return 0.0f;
        }

        /** No-op; the status message is discarded. */
        public void setStatus(String status) {
        }
    }

    /**
     * @param jobName name used when building task attempt ids
     * @param accumulator accumulator that backs the counters handed to DoFns
     * @param broadcast broadcast holding the serialized {@link Configuration} bytes
     */
    public SparkRuntimeContext(String jobName, Accumulator<Map<String, Map<String, Long>>> accumulator, Broadcast<byte[]> broadcast) {
        this.jobName = jobName;
        this.counters = accumulator;
        this.broadConf = broadcast;
    }

    /**
     * Replaces the broadcast configuration and discards everything derived from the old one.
     *
     * @param broadcast broadcast holding the new serialized {@link Configuration} bytes
     */
    public void setConf(Broadcast<byte[]> broadcast) {
        this.broadConf = broadcast;
        this.conf = null;
        // The cached task context wraps the previous Configuration object; drop it too so the
        // next initialize() rebuilds it (and re-applies the local cache-file settings).
        this.context = null;
    }

    /**
     * Attaches a task context to {@code doFn} and calls its {@code initialize()} method.
     *
     * <p>The context is reused across calls and rebuilt only when none exists yet or when
     * {@code taskId} differs from the one used for the previous call.
     *
     * @param doFn function to prepare
     * @param taskId task (partition) id used to build the task attempt id; when {@code null}
     *     a default {@link TaskAttemptID} is used
     */
    public void initialize(DoFn<?, ?> doFn, Integer taskId) {
        if (this.context == null || !Objects.equal(this.lastTID, taskId)) {
            TaskAttemptID attemptId;
            if (taskId != null) {
                attemptId = new TaskAttemptID(new TaskID(new JobID(this.jobName, 0), false, taskId.intValue()), 0);
            } else {
                attemptId = new TaskAttemptID();
            }
            this.lastTID = taskId;
            configureLocalFiles();
            this.context = new MapContextImpl<Object, Object, Object, Object>(
                    getConfiguration(), attemptId, null, null, null, new SparkReporter(this.counters), null);
        }
        doFn.setContext(this.context);
        doFn.initialize();
    }

    /**
     * Resolves each distributed-cache file listed in the configuration to its
     * {@link SparkFiles} path (by file name) and records the comma-joined list under both
     * local-cache-files configuration keys. Does nothing when no cache files are configured.
     *
     * @throws CrunchRuntimeException if reading the cache file list fails
     */
    private void configureLocalFiles() {
        Configuration configuration = getConfiguration();
        try {
            URI[] cacheFiles = DistributedCache.getCacheFiles(configuration);
            if (cacheFiles == null) {
                return;
            }
            ArrayList<String> localPaths = Lists.newArrayList();
            for (URI cacheFile : cacheFiles) {
                String fileName = new File(cacheFile.getPath()).getName();
                localPaths.add(SparkFiles.get(fileName));
            }
            String joined = Joiner.on(',').join(localPaths);
            configuration.set("mapreduce.job.cache.local.files", joined);
            configuration.set("mapred.cache.localFiles", joined);
        } catch (IOException e) {
            throw new CrunchRuntimeException(e);
        }
    }

    /**
     * Returns the job {@link Configuration}, decoding it from the broadcast bytes on first
     * access and caching the result.
     *
     * @return the cached configuration instance
     * @throws RuntimeException if the broadcast bytes cannot be read
     */
    public Configuration getConfiguration() {
        if (this.conf == null) {
            this.conf = new Configuration();
            try {
                this.conf.readFields(ByteStreams.newDataInput(this.broadConf.value()));
            } catch (Exception e) {
                throw new RuntimeException("Error reading broadcast configuration", e);
            }
        }
        return this.conf;
    }
}
