package hex.tree.xgboost.exec;

import hex.genmodel.utils.IOUtils;
import hex.tree.xgboost.BoosterParms;
import hex.tree.xgboost.XGBoostModel;
import hex.tree.xgboost.XGBoostOutput;
import hex.tree.xgboost.exec.XGBoostExecReq;
import hex.tree.xgboost.matrix.FrameMatrixLoader;
import hex.tree.xgboost.matrix.MatrixLoader;
import hex.tree.xgboost.matrix.RemoteMatrixLoader;
import hex.tree.xgboost.rabit.RabitTrackerH2O;
import hex.tree.xgboost.remote.RemoteXGBoostUploadServlet;
import hex.tree.xgboost.task.XGBoostCleanupTask;
import hex.tree.xgboost.task.XGBoostSetupTask;
import hex.tree.xgboost.task.XGBoostUpdateTask;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import water.H2O;
import water.Key;
import water.fvec.Frame;

/**
 * {@link XGBoostExecutor} that runs XGBoost training on the members of the local H2O cloud.
 *
 * <p>There are two construction paths:
 * <ul>
 *   <li>one driven by a remote {@link XGBoostExecReq.Init} request, where training data is
 *       obtained through a {@link RemoteMatrixLoader} and an optional checkpoint is read from a
 *       file obtained via {@link RemoteXGBoostUploadServlet#getCheckpointFile};</li>
 *   <li>one driven directly by an {@link XGBoostModel} and a {@link Frame}, where data is loaded
 *       by a {@link FrameMatrixLoader} and the checkpoint comes from the model itself.</li>
 * </ul>
 * A {@link RabitTrackerH2O} is started only when more than one node participates.
 */
public class LocalXGBoostExecutor implements XGBoostExecutor {
    /** Key of the model being trained; also used to locate an uploaded checkpoint file. */
    public final Key modelKey;
    private final BoosterParms boosterParams;
    private final MatrixLoader loader;
    /** Evaluated lazily by {@link #setup()}; yields booster bytes to resume from, or {@code null}. */
    private final CheckpointProvider checkpointProvider;
    /** Node flags (indexed by cloud member) handed to {@link XGBoostSetupTask}. */
    private final boolean[] nodes;
    private final String saveMatrixDirectory;
    /** Rabit tracker for multi-node training; {@code null} when at most one node participates. */
    private final RabitTrackerH2O rt;
    private XGBoostSetupTask setupTask;
    private XGBoostUpdateTask updateTask;

    /** Supplies serialized booster bytes to resume training from, or {@code null} if none. */
    @FunctionalInterface
    interface CheckpointProvider {
        byte[] get();
    }

    /**
     * Creates an executor for a training request received from a remote client.
     *
     * @param key  key of the model being trained
     * @param init initialization request: booster parameters, participating nodes, optional
     *             matrix save path and whether a checkpoint was uploaded
     */
    public LocalXGBoostExecutor(Key key, XGBoostExecReq.Init init) {
        this.modelKey = key;
        this.boosterParams = BoosterParms.fromMap(init.parms);
        this.nodes = new boolean[H2O.CLOUD.size()];
        // NOTE(review): assumes init.nodes has at least num_nodes entries and that
        // num_nodes does not exceed the cloud size — confirm against the request producer.
        for (int i = 0; i < init.num_nodes; i++) {
            this.nodes[i] = init.nodes[i] != null;
        }
        this.loader = new RemoteMatrixLoader(this.modelKey);
        this.saveMatrixDirectory = init.save_matrix_path;
        this.checkpointProvider = () -> {
            if (!init.has_checkpoint) {
                return null;
            }
            return readAndDeleteCheckpoint(
                    RemoteXGBoostUploadServlet.getCheckpointFile(this.modelKey.toString()));
        };
        // Started last so that a failure in any step above cannot leave a started tracker behind.
        this.rt = setupRabitTracker(init.num_nodes);
    }

    /**
     * Creates an executor that trains {@code xGBoostModel} on {@code frame}.
     *
     * <p>Side effect: stores the resolved native booster parameters in the model output's
     * {@code _native_parameters}.
     *
     * @param xGBoostModel model being trained; supplies parameters, output and checkpoint bytes
     * @param frame        training frame; its home nodes determine the participating nodes
     */
    public LocalXGBoostExecutor(XGBoostModel xGBoostModel, Frame frame) {
        this.modelKey = xGBoostModel._key;
        final XGBoostModel.XGBoostParameters parms = (XGBoostModel.XGBoostParameters) xGBoostModel._parms;
        final XGBoostOutput output = (XGBoostOutput) xGBoostModel._output;
        final XGBoostSetupTask.FrameNodes frameNodes = XGBoostSetupTask.findFrameNodes(frame);
        this.boosterParams = XGBoostModel.createParams(
                parms, output.nclasses(), xGBoostModel.model_info().dataInfo().coefNames());
        output._native_parameters = this.boosterParams.toTwoDimTable();
        this.loader = new FrameMatrixLoader(xGBoostModel, frame);
        this.nodes = frameNodes._nodes;
        this.saveMatrixDirectory = parms._save_matrix_directory;
        // Read lazily so setup() sees the model's parameters and booster bytes at that time.
        this.checkpointProvider = () ->
                ((XGBoostModel.XGBoostParameters) xGBoostModel._parms).hasCheckpoint()
                        ? xGBoostModel.model_info()._boosterBytes
                        : null;
        // Started last so that a failure in any step above cannot leave a started tracker behind.
        this.rt = setupRabitTracker(frameNodes.getNumNodes());
    }

    /**
     * Reads the whole checkpoint file into memory, then deletes it.
     *
     * <p>The file is deleted whether or not reading succeeded. Deletion happens after the
     * stream is closed.
     *
     * @param checkpointFile file holding the uploaded checkpoint
     * @return the file's contents
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read
     */
    private static byte[] readAndDeleteCheckpoint(File checkpointFile) {
        try (FileInputStream in = new FileInputStream(checkpointFile)) {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            IOUtils.copyStream(in, out);
            return out.toByteArray();
        } catch (IOException e) {
            throw new RuntimeException("Failed reading checkpoint file " + checkpointFile + ".", e);
        } finally {
            // Always remove the file, even when reading failed; the delete result is ignored.
            checkpointFile.delete();
        }
    }

    /**
     * Runs the setup task (with the checkpoint, if any) and then a first update task with index 0.
     *
     * @return booster bytes produced by that first update task
     */
    @Override // hex.tree.xgboost.exec.XGBoostExecutor
    public byte[] setup() {
        this.setupTask = new XGBoostSetupTask(this.modelKey, this.saveMatrixDirectory, this.boosterParams,
                this.checkpointProvider.get(), getRabitEnv(), this.nodes, this.loader);
        this.setupTask.run();
        this.updateTask = new XGBoostUpdateTask(this.setupTask, 0).run();
        return this.updateTask.getBoosterBytes();
    }

    /**
     * Creates and starts a Rabit tracker when more than one node participates.
     *
     * @param numNodes number of participating nodes
     * @return the started tracker, or {@code null} when {@code numNodes <= 1}
     */
    private RabitTrackerH2O setupRabitTracker(int numNodes) {
        if (numNodes <= 1) {
            return null;
        }
        RabitTrackerH2O tracker = new RabitTrackerH2O(numNodes);
        tracker.start(0L);
        return tracker;
    }

    /** Waits on and then stops the tracker, if one was started; stop() runs even if waiting fails. */
    private void stopRabitTracker() {
        if (this.rt == null) {
            return;
        }
        try {
            this.rt.waitFor(0L);
        } finally {
            this.rt.stop();
        }
    }

    /** Returns the tracker's worker environment, or an empty mutable map when there is no tracker. */
    private Map<String, String> getRabitEnv() {
        if (this.rt == null) {
            return new HashMap<>();
        }
        return this.rt.getWorkerEnvs();
    }

    /**
     * Runs one update task; its booster bytes can then be fetched with {@link #updateBooster()}.
     *
     * @param i index passed to {@link XGBoostUpdateTask} (presumably the boosting iteration)
     */
    @Override // hex.tree.xgboost.exec.XGBoostExecutor
    public void update(int i) {
        this.updateTask = new XGBoostUpdateTask(this.setupTask, i);
        this.updateTask.run();
    }

    /**
     * Returns the booster bytes of the most recent update task and forgets that task.
     *
     * @return the booster bytes, or {@code null} if no update ran since the last call
     */
    @Override // hex.tree.xgboost.exec.XGBoostExecutor
    public byte[] updateBooster() {
        if (this.updateTask == null) {
            return null;
        }
        byte[] boosterBytes = this.updateTask.getBoosterBytes();
        this.updateTask = null;
        return boosterBytes;
    }

    /** Cleans up the setup task's state and always stops the tracker, even if cleanup throws. */
    @Override // java.lang.AutoCloseable
    public void close() {
        try {
            XGBoostCleanupTask.cleanUp(this.setupTask);
        } finally {
            stopRabitTracker();
        }
    }
}
