package hex.tree.xgboost.rabit;

import hex.tree.xgboost.rabit.util.LinkMap;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import ml.dmlc.xgboost4j.java.IRabitTracker;
import water.H2O;
import water.util.Log;

/* loaded from: input_file:hex/tree/xgboost/rabit/RabitTrackerH2O.class */
/**
 * Rabit tracker that runs inside the H2O process, implementing xgboost4j's {@link IRabitTracker}.
 *
 * <p>{@link #start(long)} binds a TCP server socket on {@code H2O.SELF_ADDRESS}, trying ports from
 * 9091 up to 9999, and launches a daemon thread. That thread accepts Rabit worker connections,
 * assigns ranks and records "shutdown" signals until every expected worker has shut down.
 *
 * <p>NOTE(review): {@code start}, {@code stop} and {@code waitFor} are not synchronized; they look
 * intended for use from a single controlling thread — confirm against callers.
 */
public class RabitTrackerH2O implements IRabitTracker {
    /** Protocol magic number; not referenced in this class — presumably checked by RabitWorker. */
    public static final int MAGIC = 65433;

    /** First port tried when binding; {@link #stop()} resets {@link #port} back to it. */
    private static final int DEFAULT_PORT = 9091;
    /** Last port tried before {@link #start(long)} gives up. */
    private static final int MAX_PORT = 9999;
    private static final int RECEIVE_BUFFER_SIZE = 65536;

    private ServerSocketChannel sock;
    private final int workers;
    private RabitTrackerH2OThread trackerThread;
    private int port = DEFAULT_PORT;
    private final Map<String, String> envs = new HashMap<>();

    /** Daemon thread serving the Rabit tracker protocol on the tracker's server socket. */
    private static final class RabitTrackerH2OThread extends Thread {
        private static final String PRINT_CMD = "print";
        private static final String SHUTDOWN_CMD = "shutdown";
        private static final String START_CMD = "start";
        private static final String RECOVER_CMD = "recover";
        private static final String NULL_STR = "null";

        private final RabitTrackerH2O tracker;
        /** Ranks already handed out, keyed by worker job id (job id "null" is never recorded). */
        private final Map<String, Integer> jobToRankMap = new HashMap<>();
        /** Created when the first "start" command arrives. */
        private LinkMap linkMap;

        private RabitTrackerH2OThread(RabitTrackerH2O tracker) {
            this.tracker = tracker;
            setPriority(9);
            setName("TCP-" + tracker.sock);
        }

        @Override // java.lang.Thread, java.lang.Runnable
        public void run() {
            // Capture the socket once so that a later start() on the tracker cannot redirect this thread.
            final ServerSocketChannel server = this.tracker.sock;
            final int expectedWorkers = this.tracker.workers;
            Set<Integer> shutdown = new HashSet<>();
            Map<Integer, RabitWorker> waitConn = new HashMap<>();
            List<RabitWorker> pending = new ArrayList<>();
            Queue<Integer> todoNodes = new ArrayDeque<>(expectedWorkers);
            while (!interrupted() && shutdown.size() != expectedWorkers) {
                try {
                    RabitWorker worker = acceptWorker(server);
                    if (PRINT_CMD.equals(worker.cmd)) {
                        Log.warn("Rabit worker: ", worker.receiver().getStr());
                    } else if (SHUTDOWN_CMD.equals(worker.cmd)) {
                        assert worker.rank >= 0 && !shutdown.contains(worker.rank);
                        // waitConn is keyed by rank, so look up the rank, not the worker object.
                        assert !waitConn.containsKey(worker.rank);
                        shutdown.add(worker.rank);
                        Log.debug("Received ", worker.cmd, " signal from ", worker.rank);
                    } else {
                        handleStartOrRecover(worker, waitConn, pending, todoNodes);
                    }
                } catch (IOException e) {
                    Log.debug("Exception in Rabit tracker.", e);
                    if (!server.isOpen()) {
                        // A closed server socket makes accept() fail on every further iteration.
                        break;
                    }
                }
            }
            Log.debug("All Rabit nodes finished.");
        }

        /**
         * Accepts the next connection and performs the worker handshake.
         * If the handshake throws, the accepted channel is closed before the exception is rethrown.
         */
        private static RabitWorker acceptWorker(ServerSocketChannel server) throws IOException {
            SocketChannel channel = server.accept();
            try {
                return new RabitWorker(channel);
            } catch (Exception e) {
                // No RabitWorker was created, so nothing else will ever close this connection.
                try {
                    channel.close();
                } catch (IOException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
                throw e;
            }
        }

        /**
         * Handles a "start" or "recover" command.
         *
         * <p>The worker first decides its rank from {@link #jobToRankMap}; -1 means undecided.
         * Decided workers get their rank immediately. Undecided workers wait in {@code pending}
         * until one is queued for every free rank, and then all receive ranks together.
         * Workers with {@code waitAccept > 0} are registered in {@code waitConn} under their rank.
         */
        private void handleStartOrRecover(RabitWorker worker, Map<Integer, RabitWorker> waitConn,
                                          List<RabitWorker> pending, Queue<Integer> todoNodes)
                throws IOException {
            assert START_CMD.equals(worker.cmd) || RECOVER_CMD.equals(worker.cmd);
            final int expectedWorkers = this.tracker.workers;
            if (null == this.linkMap) {
                // The first rank request must be a fresh "start"; it initialises the link map and
                // the pool of free ranks 0..workers-1.
                assert START_CMD.equals(worker.cmd);
                this.linkMap = new LinkMap(expectedWorkers);
                for (int rank = 0; rank < expectedWorkers; rank++) {
                    todoNodes.add(rank);
                }
            } else {
                assert worker.worldSize == -1 || worker.worldSize == expectedWorkers;
            }
            assert !RECOVER_CMD.equals(worker.cmd) || worker.rank >= 0;

            int decided = worker.decideRank(this.jobToRankMap);
            if (-1 != decided) {
                worker.assignRank(decided, waitConn, this.linkMap);
                if (worker.waitAccept > 0) {
                    waitConn.put(decided, worker);
                }
                return;
            }

            assert !todoNodes.isEmpty();
            pending.add(worker);
            if (pending.size() == todoNodes.size()) {
                assignPendingRanks(waitConn, pending, todoNodes);
            }
            if (todoNodes.isEmpty()) {
                Log.debug("All " + expectedWorkers + " Rabit workers are getting started.");
            }
        }

        /**
         * Hands out the free ranks to every pending worker, in the workers' natural sort order.
         * Clears {@code pending} afterwards.
         */
        private void assignPendingRanks(Map<Integer, RabitWorker> waitConn, List<RabitWorker> pending,
                                        Queue<Integer> todoNodes) throws IOException {
            Collections.sort(pending);
            for (RabitWorker p : pending) {
                int rank = todoNodes.poll();
                if (!NULL_STR.equals(p.jobId)) {
                    this.jobToRankMap.put(p.jobId, rank);
                }
                p.assignRank(rank, waitConn, this.linkMap);
                if (p.waitAccept > 0) {
                    waitConn.put(rank, p);
                }
                Log.debug("Received " + p.cmd + " signal from " + p.host + ":" + p.workerPort + ". Assigned rank " + p.rank);
            }
            // These workers now hold ranks; never consider them for assignment again.
            pending.clear();
        }
    }

    /**
     * Creates a tracker for a fixed number of workers. No socket is bound until {@link #start(long)}.
     *
     * @param i number of Rabit workers expected to connect; must be at least 1
     * @throws IllegalStateException if {@code i < 1}
     */
    public RabitTrackerH2O(int i) {
        if (i < 1) {
            throw new IllegalStateException("workers must be greater than or equal to one (1).");
        }
        this.workers = i;
    }

    /**
     * Returns the DMLC/Rabit environment variables workers need to reach this tracker.
     * The tracker port reflects the most recent successful {@link #start(long)}.
     * Before the first start it reports the default port.
     *
     * @return the tracker's own (mutable) environment map, refreshed on every call
     */
    public Map<String, String> getWorkerEnvs() {
        this.envs.put("DMLC_NUM_WORKER", String.valueOf(this.workers));
        this.envs.put("DMLC_NUM_SERVER", "0");
        this.envs.put("DMLC_TRACKER_URI", H2O.SELF_ADDRESS.getHostAddress());
        this.envs.put("DMLC_TRACKER_PORT", Integer.toString(this.port));
        this.envs.put("rabit_world_size", Integer.toString(this.workers));
        return this.envs;
    }

    /**
     * Binds the tracker's server socket and launches the tracker thread.
     *
     * @param j timeout argument from the {@link IRabitTracker} contract; ignored by this implementation
     * @return always {@code true}
     * @throws IllegalStateException if the tracker thread is already running
     * @throws RuntimeException if no port in the range 9091-9999 could be bound
     */
    public boolean start(long j) {
        // Check before binding so a repeated start() cannot leak a socket or clobber the running one.
        if (null != this.trackerThread) {
            throw new IllegalStateException("Rabit tracker already started.");
        }
        this.sock = bindServerSocket();
        Log.debug("Rabit tracker started on port ", this.port);
        RabitTrackerH2OThread thread = new RabitTrackerH2OThread(this);
        thread.setDaemon(true);
        thread.start();
        this.trackerThread = thread;
        return true;
    }

    /**
     * Opens a server socket and binds it to the first free port, starting at {@link #port} and
     * going up to {@link #MAX_PORT}. On success, {@link #port} holds the bound port.
     * Channels that fail to bind are closed.
     *
     * @throws RuntimeException if every port in the range fails; {@link #port} is reset to the default
     */
    private ServerSocketChannel bindServerSocket() {
        IOException lastFailure = null;
        for (int candidate = this.port; candidate <= MAX_PORT; candidate++) {
            ServerSocketChannel channel = null;
            try {
                channel = ServerSocketChannel.open();
                channel.socket().setReceiveBufferSize(RECEIVE_BUFFER_SIZE);
                channel.socket().bind(new InetSocketAddress(H2O.SELF_ADDRESS, candidate));
                this.port = candidate;
                return channel;
            } catch (IOException e) {
                lastFailure = e;
                if (null != channel) {
                    try {
                        channel.close();
                    } catch (IOException closeFailure) {
                        e.addSuppressed(closeFailure);
                    }
                }
            }
        }
        // Reset so the next start() scans the whole range again instead of failing immediately.
        this.port = DEFAULT_PORT;
        throw new RuntimeException("Failed to bind Rabit tracker to a socket in range " + DEFAULT_PORT + "-" + MAX_PORT, lastFailure);
    }

    /**
     * Interrupts the tracker thread, closes the server socket and resets the port to the default.
     * Does nothing if no tracker thread is registered.
     *
     * @throws RuntimeException if closing the socket fails (the port is reset regardless)
     */
    public void stop() {
        if (null == this.trackerThread) {
            return;
        }
        this.trackerThread.interrupt();
        this.trackerThread = null;
        try {
            this.sock.close();
        } catch (IOException e) {
            throw new RuntimeException("Failed to close Rabit tracker socket.", e);
        } finally {
            this.port = DEFAULT_PORT;
        }
    }

    /**
     * Waits for the tracker thread to finish.
     *
     * <p>If the thread has finished, the tracker is stopped so the server socket is released. If the
     * wait times out, the thread stays registered so {@link #stop()} can still shut it down. If the
     * calling thread is interrupted, its interrupt status is restored and the method returns early.
     *
     * @param j maximum time to wait in milliseconds, as for {@link Thread#join(long)}; 0 waits indefinitely
     * @return always 0
     */
    public int waitFor(long j) {
        RabitTrackerH2OThread thread = this.trackerThread;
        if (null == thread) {
            return 0;
        }
        try {
            thread.join(j);
        } catch (InterruptedException e) {
            Log.debug("Rabit tracker thread got suddenly interrupted.", e);
            Thread.currentThread().interrupt();
            return 0;
        }
        if (!thread.isAlive()) {
            stop();
        }
        return 0;
    }

    /**
     * {@code Thread.UncaughtExceptionHandler} callback (presumably inherited through
     * {@link IRabitTracker} — confirm); shuts the tracker down via {@link #stop()}.
     */
    public void uncaughtException(Thread thread, Throwable th) {
        stop();
    }
}
