package org.apache.crunch.impl.spark;

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.lang.reflect.Method;
import java.net.URI;
import java.util.ArrayList;
import java.util.Map;
import javassist.util.proxy.MethodFilter;
import javassist.util.proxy.MethodHandler;
import javassist.util.proxy.ProxyFactory;
import org.apache.crunch.CrunchRuntimeException;
import org.apache.crunch.DoFn;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.mapreduce.Counter;
import org.apache.hadoop.mapreduce.OutputCommitter;
import org.apache.hadoop.mapreduce.RecordWriter;
import org.apache.hadoop.mapreduce.StatusReporter;
import org.apache.hadoop.mapreduce.TaskAttemptID;
import org.apache.hadoop.mapreduce.TaskInputOutputContext;
import org.apache.spark.Accumulator;
import org.apache.spark.SparkFiles;
import org.apache.spark.broadcast.Broadcast;

/* loaded from: input_file:org/apache/crunch/impl/spark/SparkRuntimeContext.class */
/**
 * Serializable holder for the state a {@link DoFn} needs when it runs inside Spark: the broadcast
 * Hadoop {@link Configuration} and the accumulator that collects counter increments.
 *
 * <p>The Hadoop {@link TaskInputOutputContext} handed to each {@code DoFn} is a javassist proxy
 * that is built lazily on first {@link #initialize(DoFn)} call; it is {@code transient}, so it is
 * rebuilt after this object has been deserialized.
 */
public class SparkRuntimeContext implements Serializable {
    private Broadcast<Configuration> broadConf;
    private Accumulator<Map<String, Long>> counters;
    private transient TaskInputOutputContext context;

    /**
     * @param broadcast   broadcast variable wrapping the job configuration
     * @param accumulator accumulator that receives counter increments keyed by {@code group:name}
     */
    public SparkRuntimeContext(Broadcast<Configuration> broadcast, Accumulator<Map<String, Long>> accumulator) {
        this.broadConf = broadcast;
        this.counters = accumulator;
    }

    /**
     * Attaches the task context to {@code doFn} and calls its {@code initialize()} hook.
     *
     * <p>On the first call for this instance the distributed-cache file locations are rewritten to
     * their local Spark paths and the proxy task context is created; later calls reuse it.
     *
     * @param doFn the function to prepare for execution
     */
    public void initialize(DoFn<?, ?> doFn) {
        if (this.context == null) {
            configureLocalFiles();
            this.context = getTaskIOContext(this.broadConf, this.counters);
        }
        doFn.setContext(this.context);
        doFn.initialize();
    }

    /**
     * Resolves every distributed-cache URI in the configuration to the path returned by
     * {@link SparkFiles#get(String)} for its file name, and stores the comma-joined result under
     * both the new and the legacy "local cache files" configuration keys.
     *
     * @throws CrunchRuntimeException if the cache file list cannot be read from the configuration
     */
    private void configureLocalFiles() {
        try {
            Configuration conf = getConfiguration();
            URI[] cacheFiles = DistributedCache.getCacheFiles(conf);
            if (cacheFiles == null) {
                return;
            }
            ArrayList<String> localPaths = new ArrayList<String>(cacheFiles.length);
            for (URI uri : cacheFiles) {
                // Only the file name matters: SparkFiles maps it to the local download location.
                String fileName = new File(uri.getPath()).getName();
                localPaths.add(SparkFiles.get(fileName));
            }
            String joined = Joiner.on(',').join(localPaths);
            conf.set("mapreduce.job.cache.local.files", joined);
            conf.set("mapred.cache.localFiles", joined);
        } catch (IOException e) {
            throw new CrunchRuntimeException(e);
        }
    }

    /**
     * @return the configuration carried by the broadcast variable
     */
    public Configuration getConfiguration() {
        return this.broadConf.value();
    }

    /**
     * Builds a javassist proxy for {@link TaskInputOutputContext} that answers only
     * {@code getConfiguration}, {@code getCounter}, {@code progress} and {@code getTaskAttemptID}.
     *
     * <p>{@code TaskInputOutputContext} is checked at runtime for being an interface or a class,
     * and the proxy is created via {@code setInterfaces} or {@code setSuperclass} accordingly.
     *
     * @param broadcast   broadcast variable wrapping the job configuration
     * @param accumulator accumulator that receives counter increments
     * @return the proxy context
     * @throws RuntimeException if the proxy cannot be created
     */
    public static TaskInputOutputContext getTaskIOContext(final Broadcast<Configuration> broadcast, final Accumulator<Map<String, Long>> accumulator) {
        ProxyFactory proxyFactory = new ProxyFactory();
        Class<?>[] ctorTypes = new Class<?>[0];
        Object[] ctorArgs = new Object[0];
        final TaskAttemptID taskAttemptID = new TaskAttemptID();
        if (TaskInputOutputContext.class.isInterface()) {
            proxyFactory.setInterfaces(new Class<?>[]{TaskInputOutputContext.class});
        } else {
            // Class form: the superclass constructor needs arguments; the unused ones are null.
            ctorTypes = new Class<?>[]{Configuration.class, TaskAttemptID.class, RecordWriter.class, OutputCommitter.class, StatusReporter.class};
            ctorArgs = new Object[]{broadcast.value(), taskAttemptID, null, null, null};
            proxyFactory.setSuperclass(TaskInputOutputContext.class);
        }
        final ImmutableSet<String> handledMethods = ImmutableSet.of("getConfiguration", "getCounter", "progress", "getTaskAttemptID");
        proxyFactory.setFilter(new MethodFilter() {
            @Override
            public boolean isHandled(Method method) {
                return handledMethods.contains(method.getName());
            }
        });
        try {
            return (TaskInputOutputContext) proxyFactory.create(ctorTypes, ctorArgs, new MethodHandler() {
                @Override
                public Object invoke(Object self, Method method, Method proceed, Object[] args) throws Throwable {
                    String name = method.getName();
                    if ("getConfiguration".equals(name)) {
                        return broadcast.value();
                    }
                    if ("progress".equals(name)) {
                        // Progress reporting is a no-op here.
                        return null;
                    }
                    if ("getTaskAttemptID".equals(name)) {
                        return taskAttemptID;
                    }
                    if ("getCounter".equals(name)) {
                        if (args.length == 1) {
                            // getDeclaringClass() rather than getClass(): an enum constant with a
                            // constant-specific body has an anonymous subclass as its runtime class,
                            // which would yield a group name like "MyEnum$1".
                            Enum<?> key = (Enum<?>) args[0];
                            return SparkRuntimeContext.getCounter(accumulator, key.getDeclaringClass().getName(), key.name());
                        }
                        return SparkRuntimeContext.getCounter(accumulator, (String) args[0], (String) args[1]);
                    }
                    throw new IllegalStateException("Unhandled method " + name);
                }
            });
        } catch (Exception e) {
            throw new RuntimeException("Failed to create TaskInputOutputContext proxy", e);
        }
    }

    /**
     * Builds a javassist proxy for a Hadoop {@link Counter} whose increments are forwarded to the
     * accumulator under the key {@code group + ":" + name}.
     *
     * <p>{@code getName} and {@code getDisplayName} return the counter name, {@code setDisplayName}
     * is ignored, and {@code setValue} / {@code getValue} throw
     * {@link UnsupportedOperationException}.
     *
     * @param accumulator accumulator that receives the increments
     * @param str         counter group
     * @param str2        counter name
     * @return the proxy counter
     * @throws RuntimeException if the proxy cannot be created
     */
    public static Counter getCounter(final Accumulator<Map<String, Long>> accumulator, final String str, final String str2) {
        ProxyFactory proxyFactory = new ProxyFactory();
        Class<?>[] ctorTypes = new Class<?>[0];
        Object[] ctorArgs = new Object[0];
        if (Counter.class.isInterface()) {
            proxyFactory.setInterfaces(new Class<?>[]{Counter.class});
        } else {
            ctorTypes = new Class<?>[]{String.class, String.class};
            ctorArgs = new Object[]{str, str2};
            proxyFactory.setSuperclass(Counter.class);
        }
        final ImmutableSet<String> handledMethods = ImmutableSet.of("getDisplayName", "getName", "getValue", "increment", "setValue", "setDisplayName");
        proxyFactory.setFilter(new MethodFilter() {
            @Override
            public boolean isHandled(Method method) {
                return handledMethods.contains(method.getName());
            }
        });
        try {
            return (Counter) proxyFactory.create(ctorTypes, ctorArgs, new MethodHandler() {
                @Override
                public Object invoke(Object self, Method method, Method proceed, Object[] args) throws Throwable {
                    String name = method.getName();
                    if ("increment".equals(name)) {
                        accumulator.add(ImmutableMap.of(str + ":" + str2, (Long) args[0]));
                        return null;
                    }
                    if ("getDisplayName".equals(name) || "getName".equals(name)) {
                        return str2;
                    }
                    if ("setDisplayName".equals(name)) {
                        // Display names are not tracked; silently ignore.
                        return null;
                    }
                    if ("setValue".equals(name)) {
                        throw new UnsupportedOperationException("Cannot set counter values in Spark, only increment them");
                    }
                    if ("getValue".equals(name)) {
                        throw new UnsupportedOperationException("Cannot read counters during Spark execution");
                    }
                    throw new IllegalStateException("Unhandled method " + name);
                }
            });
        } catch (Exception e) {
            throw new RuntimeException("Failed to create Counter proxy for " + str + ":" + str2, e);
        }
    }
}
