package hex.tree.xgboost.remote;

import hex.genmodel.utils.IOUtils;
import hex.schemas.XGBoostExecRespV3;
import hex.tree.xgboost.matrix.RemoteMatrixLoader;
import hex.tree.xgboost.matrix.SparseMatrixDimensions;
import hex.tree.xgboost.task.XGBoostUploadMatrixTask;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Path;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.log4j.Logger;
import water.AutoBuffer;
import water.BootstrapFreezable;
import water.H2O;
import water.Key;
import water.TypeMap;
import water.server.ServletUtils;

/* loaded from: input_file:hex/tree/xgboost/remote/RemoteXGBoostUploadServlet.class */
public class RemoteXGBoostUploadServlet extends HttpServlet {
    private static final Logger LOG = Logger.getLogger(RemoteXGBoostUploadServlet.class);

    /* loaded from: input_file:hex/tree/xgboost/remote/RemoteXGBoostUploadServlet$RequestType.class */
    public enum RequestType {
        checkpoint,
        sparseMatrixDimensions,
        sparseMatrixChunk,
        denseMatrixDimensions,
        denseMatrixChunk,
        matrixData
    }

    public static File getUploadDir(String str) {
        return new File(H2O.ICE_ROOT.toString(), str);
    }

    public static File getCheckpointFile(String str) {
        File uploadDir = getUploadDir(str);
        if (uploadDir.mkdirs()) {
            LOG.debug("Created temporary directory " + uploadDir);
        }
        return new File(getUploadDir(str), "checkpoint.bin");
    }

    protected void doPost(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
        String decodedUri = ServletUtils.getDecodedUri(httpServletRequest);
        try {
            try {
                String parameter = httpServletRequest.getParameter("model_key");
                String parameter2 = httpServletRequest.getParameter("data_type");
                LOG.info("Upload request for " + parameter + " " + parameter2 + " received");
                RequestType valueOf = RequestType.valueOf(parameter2);
                if (valueOf == RequestType.checkpoint) {
                    saveIntoFile(getCheckpointFile(parameter), httpServletRequest);
                } else {
                    handleMatrixRequest(parameter, valueOf, httpServletRequest);
                }
                httpServletResponse.setContentType("application/json");
                httpServletResponse.getWriter().write(new XGBoostExecRespV3(Key.make(parameter)).toJsonString());
                ServletUtils.logRequest("POST", httpServletRequest, httpServletResponse);
            } catch (Exception e) {
                ServletUtils.sendErrorResponse(httpServletResponse, e, decodedUri);
                ServletUtils.logRequest("POST", httpServletRequest, httpServletResponse);
            }
        } catch (Throwable th) {
            ServletUtils.logRequest("POST", httpServletRequest, httpServletResponse);
            throw th;
        }
    }

    private void handleMatrixRequest(String str, RequestType requestType, HttpServletRequest httpServletRequest) throws IOException {
        AutoBuffer autoBuffer = new AutoBuffer(httpServletRequest.getInputStream(), TypeMap.bootstrapClasses());
        Throwable th = null;
        try {
            try {
                BootstrapFreezable bootstrapFreezable = autoBuffer.get();
                if (autoBuffer != null) {
                    if (0 != 0) {
                        try {
                            autoBuffer.close();
                        } catch (Throwable th2) {
                            th.addSuppressed(th2);
                        }
                    } else {
                        autoBuffer.close();
                    }
                }
                switch (requestType) {
                    case sparseMatrixDimensions:
                        RemoteMatrixLoader.initSparse(str, (SparseMatrixDimensions) bootstrapFreezable);
                        return;
                    case sparseMatrixChunk:
                        RemoteMatrixLoader.sparseChunk(str, (XGBoostUploadMatrixTask.SparseMatrixChunk) bootstrapFreezable);
                        return;
                    case denseMatrixDimensions:
                        RemoteMatrixLoader.initDense(str, (XGBoostUploadMatrixTask.DenseMatrixDimensions) bootstrapFreezable);
                        return;
                    case denseMatrixChunk:
                        RemoteMatrixLoader.denseChunk(str, (XGBoostUploadMatrixTask.DenseMatrixChunk) bootstrapFreezable);
                        return;
                    case matrixData:
                        RemoteMatrixLoader.matrixData(str, (XGBoostUploadMatrixTask.MatrixData) bootstrapFreezable);
                        return;
                    default:
                        throw new IllegalArgumentException("Unexpected request type: " + requestType);
                }
            } catch (Throwable th3) {
                th = th3;
                throw th3;
            }
        } catch (Throwable th4) {
            if (autoBuffer != null) {
                if (th != null) {
                    try {
                        autoBuffer.close();
                    } catch (Throwable th5) {
                        th.addSuppressed(th5);
                    }
                } else {
                    autoBuffer.close();
                }
            }
            throw th4;
        }
    }

    private void saveIntoFile(File file, HttpServletRequest httpServletRequest) throws IOException {
        LOG.debug("Saving contents into " + file);
        ServletInputStream inputStream = httpServletRequest.getInputStream();
        FileOutputStream fileOutputStream = new FileOutputStream(file);
        Throwable th = null;
        try {
            IOUtils.copyStream(inputStream, fileOutputStream);
            if (fileOutputStream != null) {
                if (0 == 0) {
                    fileOutputStream.close();
                    return;
                }
                try {
                    fileOutputStream.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
        } catch (Throwable th3) {
            if (fileOutputStream != null) {
                if (0 != 0) {
                    try {
                        fileOutputStream.close();
                    } catch (Throwable th4) {
                        th.addSuppressed(th4);
                    }
                } else {
                    fileOutputStream.close();
                }
            }
            throw th3;
        }
    }
}
