package ml.dmlc.xgboost4j.java;

import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.XGBoostModel;
import java.lang.Thread;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.TimeUnit;
import water.H2O;
import water.Key;
import water.nbhm.NonBlockingHashMap;
import water.util.Log;

/* loaded from: input_file:ml/dmlc/xgboost4j/java/XGBoostUpdater.class */
/**
 * Worker thread that performs all operations on the {@link Booster} of a single XGBoost model.
 *
 * <p>Each instance is registered in a static registry under its model key (see {@link #make},
 * {@link #getUpdater} and {@link #terminate}). Once started, the thread initializes Rabit, then
 * repeatedly takes a {@link BoosterCallable} from the input queue, executes it and publishes the
 * result on the output queue. Callers submit work through {@link #doUpdate(int)} and
 * {@link #getBoosterBytes()}, which block until the updater thread has produced the result.
 *
 * <p>When the thread leaves its loop (interrupt, training error or any other failure) it clears
 * both queues, deregisters itself, disposes the training matrix and booster and shuts Rabit down.
 * {@link #make} only registers the updater; it does not start the thread.
 */
public class XGBoostUpdater extends Thread {
    /** Seconds a caller waits for the updater thread to accept a submitted task. */
    private static final long WORK_START_TIMEOUT_SECS = 300;
    /** Seconds between two checks that the updater is still active while a caller waits for a result. */
    private static final long INACTIVE_CHECK_INTERVAL_SECS = 60;
    /** Registry of updaters known on this node, keyed by model key. */
    private static final NonBlockingHashMap<Key<XGBoostModel>, XGBoostUpdater> updaters = new NonBlockingHashMap<>();

    private final Key<XGBoostModel> _modelKey;
    private final DMatrix _trainMat;
    private final BoosterParms _boosterParms;
    private final Map<String, String> _rabitEnv;
    /** Hand-over point for tasks; set to {@code null} once the updater thread has finished. */
    private volatile SynchronousQueue<BoosterCallable<?>> _in;
    /** Hand-over point for task results; set to {@code null} once the updater thread has finished. */
    private volatile SynchronousQueue<Object> _out;
    /** Booster created by the first {@link UpdateBooster} task with tid 0; {@code null} until then. */
    private Booster _booster;

    /**
     * A unit of work executed on the updater thread.
     *
     * @param <E> type of the result handed back to the caller
     */
    public interface BoosterCallable<E> {
        E call() throws XGBoostError;
    }

    /** Uncaught-exception handler that logs the failing thread's name and the throwable. */
    private static class LoggingExceptionHandler implements Thread.UncaughtExceptionHandler {
        private static final LoggingExceptionHandler INSTANCE = new LoggingExceptionHandler();

        private LoggingExceptionHandler() {
        }

        @Override
        public void uncaughtException(Thread thread, Throwable th) {
            Log.err("Uncaught exception in " + thread.getName());
            Log.err(th);
        }
    }

    /** Task that serializes the current booster into a byte array. */
    private class SerializeBooster implements BoosterCallable<byte[]> {
        @Override
        public byte[] call() throws XGBoostError {
            return _booster.toByteArray();
        }

        @Override
        public String toString() {
            return "SerializeBooster";
        }
    }

    /**
     * Task that either creates the initial booster (no booster yet and tid 0) or runs one update
     * with the given tid on the existing booster and saves a Rabit checkpoint.
     */
    private class UpdateBooster implements BoosterCallable<Booster> {
        private final int _tid;

        private UpdateBooster(int tid) {
            _tid = tid;
        }

        @Override
        public Booster call() throws XGBoostError {
            if (_booster == null && _tid == 0) {
                // Training for 0 rounds only builds the initial booster.
                _booster = XGBoost.train(_trainMat, _boosterParms.get(), 0, new HashMap<String, DMatrix>(), (IObjective) null, (IEvaluation) null);
                Log.info("Initial (0 tree) Booster created, size=" + _booster.toByteArray().length);
            } else {
                assert _booster != null;
                _booster.update(_trainMat, _tid);
                _booster.saveRabitCheckpoint();
            }
            return _booster;
        }

        @Override
        public String toString() {
            return "Boosting Iteration (tid=" + _tid + ")";
        }
    }

    private XGBoostUpdater(Key<XGBoostModel> modelKey, DMatrix trainMat, BoosterParms boosterParms, Map<String, String> rabitEnv) {
        super("XGBoostUpdater-" + modelKey);
        _modelKey = modelKey;
        _trainMat = trainMat;
        _boosterParms = boosterParms;
        _rabitEnv = rabitEnv;
        _in = new SynchronousQueue<>();
        _out = new SynchronousQueue<>();
    }

    /**
     * Initializes Rabit and serves submitted tasks until the thread is interrupted or a task fails.
     * Resources are released on every exit path; throwables other than {@link InterruptedException}
     * and {@link XGBoostError} propagate to the thread's uncaught-exception handler after cleanup.
     */
    @Override
    public void run() {
        try {
            Rabit.init(_rabitEnv);
            while (!Thread.interrupted()) {
                BoosterCallable<?> task = _in.take();
                _out.put(task.call());
            }
        } catch (InterruptedException e) {
            // terminate() deregisters the updater before interrupting it, so still being
            // registered here means the interrupt did not come from a regular termination.
            if (updaters.get(_modelKey) == this) {
                Log.err("Updater thread was interrupted while it was still registered, name=" + getName());
                Log.err(e);
            } else {
                Log.debug("Updater thread interrupted.", e);
            }
        } catch (XGBoostError e) {
            Log.err("XGBoost training iteration failed");
            Log.err(e);
        } finally {
            releaseResources();
        }
    }

    /** Marks the updater inactive, deregisters it and frees the native matrix, booster and Rabit. */
    private void releaseResources() {
        // Nulling the queues is how waiting and future callers of invoke() learn the updater is gone.
        _in = null;
        _out = null;
        // Remove only our own registration: a new updater may already be registered under the
        // same key (terminate() removes the entry before this thread gets to clean up).
        updaters.remove(_modelKey, this);
        try {
            _trainMat.dispose();
            if (_booster != null) {
                _booster.dispose();
            }
        } catch (Exception e) {
            Log.warn("Failed to dispose of training matrix/booster", e);
        }
        try {
            Rabit.shutdown();
        } catch (Exception e) {
            Log.warn("Rabit shutdown during update failed", e);
        }
    }

    /**
     * Hands a task to the updater thread and blocks until its result is available.
     *
     * @param callable task to execute on the updater thread
     * @param <T> result type of the task
     * @return the (non-null) result produced by the task
     * @throws IllegalStateException if the updater is inactive, becomes inactive while waiting, or
     *         does not accept the task within {@link #WORK_START_TIMEOUT_SECS} seconds
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    private <T> T invoke(BoosterCallable<T> callable) throws InterruptedException {
        final SynchronousQueue<BoosterCallable<?>> inQ = _in;
        if (inQ == null) {
            throw new IllegalStateException("Updater is inactive on node " + H2O.SELF);
        }
        if (!inQ.offer(callable, WORK_START_TIMEOUT_SECS, TimeUnit.SECONDS)) {
            throw new IllegalStateException("XGBoostUpdater couldn't start work on task " + callable + " in " + WORK_START_TIMEOUT_SECS + "s.");
        }
        int attempts = 0;
        while (true) {
            // Re-read the queue on every round: the updater thread nulls it when it exits.
            final SynchronousQueue<Object> outQ = _out;
            if (outQ == null) {
                throw new IllegalStateException("Cannot perform booster operation: updater is inactive on node " + H2O.SELF);
            }
            attempts++;
            @SuppressWarnings("unchecked") // the updater thread publishes the result of the task we submitted
            T result = (T) outQ.poll(INACTIVE_CHECK_INTERVAL_SECS, TimeUnit.SECONDS);
            if (result != null) {
                return result;
            }
            if (attempts > 5) {
                Log.warn(String.format("Exceeded waiting interval of %d seconds for a task of type '%s' to finish on node '%s'. ", INACTIVE_CHECK_INTERVAL_SECS * attempts, callable, H2O.SELF));
            }
        }
    }

    /** Returns the current booster reference without going through the updater thread. */
    Booster getBooster() {
        return _booster;
    }

    /**
     * Serializes the booster on the updater thread.
     *
     * @return serialized booster bytes
     * @throws IllegalStateException if the updater is inactive, times out accepting the task, or the
     *         calling thread is interrupted (its interrupt status is restored)
     */
    public byte[] getBoosterBytes() {
        try {
            return invoke(new SerializeBooster());
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // keep the interrupt visible to the caller
            throw new IllegalStateException("Failed to serialize Booster - operation was interrupted", e);
        }
    }

    /**
     * Runs one boosting step on the updater thread; creates the initial booster when none exists
     * yet and {@code i} is 0.
     *
     * @param i tree/iteration id passed to the booster update
     * @return the booster after the step
     * @throws IllegalStateException if the updater is inactive, times out accepting the task, or the
     *         calling thread is interrupted (its interrupt status is restored)
     */
    public Booster doUpdate(int i) {
        try {
            return invoke(new UpdateBooster(i));
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // keep the interrupt visible to the caller
            throw new IllegalStateException("Boosting iteration failed - operation was interrupted", e);
        }
    }

    /**
     * Creates an updater and registers it under the given model key. The returned thread is not
     * started by this method.
     *
     * @param key model key the updater is registered under
     * @param dMatrix training matrix used for booster creation and updates
     * @param boosterParms provider of the booster parameters
     * @param map environment passed to {@code Rabit.init}
     * @return the newly registered updater
     * @throws IllegalStateException if an updater is already registered for {@code key}
     */
    public static XGBoostUpdater make(Key<XGBoostModel> key, DMatrix dMatrix, BoosterParms boosterParms, Map<String, String> map) {
        XGBoostUpdater updater = new XGBoostUpdater(key, dMatrix, boosterParms, map);
        updater.setUncaughtExceptionHandler(LoggingExceptionHandler.INSTANCE);
        if (updaters.putIfAbsent(key, updater) != null) {
            throw new IllegalStateException("XGBoostUpdater for modelKey=" + key + " already exists!");
        }
        return updater;
    }

    /**
     * Deregisters the updater of the given model (if any) and interrupts its thread.
     *
     * @param key model key of the updater to terminate
     */
    public static void terminate(Key<XGBoostModel> key) {
        XGBoostUpdater updater = updaters.remove(key);
        if (updater == null) {
            Log.debug("XGBoostUpdater for modelKey=" + key + " was already clean-up on node " + H2O.SELF);
        } else {
            updater.interrupt();
        }
    }

    /**
     * Looks up the registered updater of the given model.
     *
     * @param key model key to look up
     * @return the registered updater
     * @throws IllegalStateException if no updater is registered for {@code key}
     */
    public static XGBoostUpdater getUpdater(Key<XGBoostModel> key) {
        XGBoostUpdater updater = updaters.get(key);
        if (updater == null) {
            throw new IllegalStateException("XGBoostUpdater for modelKey=" + key + " was not found!");
        }
        return updater;
    }
}
