package org.apache.lens.ml.server;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.ws.rs.BadRequestException;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.GET;
import javax.ws.rs.NotFoundException;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriInfo;
import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.lens.api.LensSessionHandle;
import org.apache.lens.api.StringList;
import org.apache.lens.ml.algo.api.MLModel;
import org.apache.lens.ml.api.MLTestReport;
import org.apache.lens.ml.api.ModelMetadata;
import org.apache.lens.ml.api.TestReport;
import org.apache.lens.ml.impl.ModelLoader;
import org.apache.lens.server.api.ServiceProvider;
import org.apache.lens.server.api.ServiceProviderFactory;
import org.apache.lens.server.api.error.LensException;
import org.glassfish.jersey.media.multipart.FormDataParam;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * JAX-RS resource exposing the Lens ML service under the {@code /ml} path.
 *
 * <p>Endpoints cover listing algorithms and their parameters, training models, inspecting and deleting
 * models and test reports, testing a model against a table, and running a prediction. The backing
 * {@link MLService} is resolved lazily through the {@link ServiceProviderFactory} class configured under
 * {@code lens.server.service.provider.factory}.
 */
@Produces({"application/json", "application/xml"})
@Path("/ml")
public class MLServiceResource {

    /** Message returned by the base {@code GET /ml} endpoint. */
    public static final String ML_UP_MESSAGE = "ML service is up";

    private static final Logger log = LoggerFactory.getLogger(MLServiceResource.class);

    /** Server configuration, populated from the lensserver-default.xml and lens-site.xml resources (in that order). */
    private static final HiveConf HIVE_CONF = new HiveConf();

    static {
        HIVE_CONF.addResource("lensserver-default.xml");
        HIVE_CONF.addResource("lens-site.xml");
    }

    /** ML service, resolved on first use by {@link #getMlService()}. */
    MLService mlService;

    /** Service provider, resolved on first use by {@link #getServiceProvider()}. */
    ServiceProvider serviceProvider;

    /** Factory instantiated from the class named by {@code lens.server.service.provider.factory}. */
    ServiceProviderFactory serviceProviderFactory = getServiceProviderFactory(HIVE_CONF);

    /**
     * Returns the service provider, obtaining it from {@link #serviceProviderFactory} on first call.
     */
    private ServiceProvider getServiceProvider() {
        if (serviceProvider == null) {
            serviceProvider = serviceProviderFactory.getServiceProvider();
        }
        return serviceProvider;
    }

    /**
     * Instantiates the service provider factory class configured in {@code hiveConf}.
     *
     * @param hiveConf configuration holding {@code lens.server.service.provider.factory}
     * @return a new factory instance
     * @throws RuntimeException if the configured class cannot be instantiated or accessed; the original
     *     reflective exception is kept as the cause
     */
    private static ServiceProviderFactory getServiceProviderFactory(HiveConf hiveConf) {
        Class<?> factoryClass =
            hiveConf.getClass("lens.server.service.provider.factory", ServiceProviderFactory.class);
        try {
            return (ServiceProviderFactory) factoryClass.newInstance();
        } catch (InstantiationException e) {
            throw new RuntimeException("Could not instantiate service provider factory " + factoryClass.getName(), e);
        } catch (IllegalAccessException e) {
            throw new RuntimeException("Could not access service provider factory " + factoryClass.getName(), e);
        }
    }

    /**
     * Returns the {@code "ml"} service, looking it up through the service provider on first call.
     */
    private MLService getMlService() {
        if (mlService == null) {
            mlService = getServiceProvider().getService("ml");
        }
        return mlService;
    }

    /**
     * Health check for the ML resource.
     *
     * @return {@link #ML_UP_MESSAGE}
     */
    @GET
    public String mlResourceUp() {
        return ML_UP_MESSAGE;
    }

    /**
     * Lists the names of the algorithms known to the ML service.
     *
     * @return algorithm names
     */
    @GET
    @Path("algos")
    public StringList getAlgoNames() {
        return new StringList(getMlService().getAlgorithms());
    }

    /**
     * Describes the parameters accepted by an algorithm.
     *
     * @param algorithm algorithm name
     * @return one {@code "name : description"} entry per parameter
     * @throws NotFoundException if the service has no parameter description for {@code algorithm}
     */
    @GET
    @Path("algos/{algorithm}")
    public StringList getParamDescription(@PathParam("algorithm") String algorithm) {
        Map<String, String> paramDescriptions = getMlService().getAlgoParamDescription(algorithm);
        if (paramDescriptions == null) {
            throw new NotFoundException("Param description not found for " + algorithm);
        }
        List<String> descriptions = new ArrayList<String>(paramDescriptions.size());
        for (Map.Entry<String, String> entry : paramDescriptions.entrySet()) {
            descriptions.add(entry.getKey() + " : " + entry.getValue());
        }
        return new StringList(descriptions);
    }

    /**
     * Lists the model IDs trained with an algorithm.
     *
     * @param algorithm algorithm name
     * @return model IDs
     * @throws NotFoundException if no models exist for {@code algorithm}
     * @throws LensException propagated from the ML service
     */
    @GET
    @Path("models/{algorithm}")
    public StringList getModelsForAlgo(@PathParam("algorithm") String algorithm) throws LensException {
        List<String> models = getMlService().getModels(algorithm);
        if (models == null || models.isEmpty()) {
            throw new NotFoundException("No models found for algorithm " + algorithm);
        }
        return new StringList(models);
    }

    /**
     * Returns metadata describing a trained model.
     *
     * @param algorithm algorithm name
     * @param modelID model ID
     * @return metadata: ID, table, algorithm, space-joined params, creation time, model path, label column
     *     and comma-joined feature columns
     * @throws NotFoundException if the model does not exist
     * @throws LensException propagated from the ML service
     */
    @GET
    @Path("models/{algorithm}/{modelID}")
    public ModelMetadata getModelMetadata(@PathParam("algorithm") String algorithm,
        @PathParam("modelID") String modelID) throws LensException {
        MLModel model = getMlService().getModel(algorithm, modelID);
        if (model == null) {
            throw new NotFoundException("Model not found " + modelID + ", algo=" + algorithm);
        }
        return new ModelMetadata(
            model.getId(),
            model.getTable(),
            model.getAlgoName(),
            StringUtils.join(model.getParams(), ' '),
            model.getCreatedAt().toString(),
            getMlService().getModelPath(algorithm, modelID),
            model.getLabelColumn(),
            StringUtils.join(model.getFeatureColumns(), ","));
    }

    /**
     * Deletes a trained model.
     *
     * @param algorithm algorithm name
     * @param modelID model ID
     * @return a confirmation message naming the deleted model
     * @throws LensException propagated from the ML service
     */
    @Path("models/{algorithm}/{modelID}")
    @Consumes({"application/json", "application/xml", "text/plain"})
    @DELETE
    public String deleteModel(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID)
        throws LensException {
        getMlService().deleteModel(algorithm, modelID);
        return "DELETED model=" + modelID + " algorithm=" + algorithm;
    }

    /**
     * Trains a model from form parameters.
     *
     * <p>Required form fields are {@code table}, {@code label} and at least one {@code feature}. Each
     * {@code feature} value is forwarded; for every other field except {@code algorithm} and {@code table}
     * only the first value is forwarded.
     *
     * @param algorithm algorithm name
     * @param form URL-encoded form parameters
     * @return the ID of the trained model
     * @throws NotFoundException if the algorithm is unknown
     * @throws BadRequestException if {@code table} or {@code label} is blank, or no {@code feature} is given
     * @throws LensException propagated from the ML service
     */
    @POST
    @Path("{algorithm}/train")
    @Consumes({"application/x-www-form-urlencoded"})
    public String train(@PathParam("algorithm") String algorithm, MultivaluedMap<String, String> form)
        throws LensException {
        if (getMlService().getAlgoForName(algorithm) == null) {
            throw new NotFoundException("Algo for algo: " + algorithm + " not found");
        }
        String table = form.getFirst("table");
        if (StringUtils.isBlank(table)) {
            throw new BadRequestException("table parameter is required");
        }
        if (StringUtils.isBlank(form.getFirst("label"))) {
            throw new BadRequestException("label parameter is required");
        }
        // get() returns null when the field is absent; report that as a client error rather than an NPE.
        List<String> features = form.get("feature");
        if (features == null || features.isEmpty()) {
            throw new BadRequestException("At least one feature is required");
        }

        List<String> trainingArgs = buildTrainingArgs(form);
        log.info("Training table {} with algo {} params={}", new Object[]{table, algorithm, trainingArgs.toString()});
        String modelId = getMlService().train(table, algorithm, trainingArgs.toArray(new String[trainingArgs.size()]));
        log.info("Done training {} modelid = {}", table, modelId);
        return modelId;
    }

    /**
     * Flattens the training form into alternating key/value arguments, skipping {@code algorithm} and
     * {@code table} (which are passed to the service separately) and keys with no values.
     */
    private static List<String> buildTrainingArgs(MultivaluedMap<String, String> form) {
        List<String> args = new ArrayList<String>();
        for (Map.Entry<String, List<String>> entry : form.entrySet()) {
            String key = entry.getKey();
            List<String> values = entry.getValue();
            if ("algorithm".equals(key) || "table".equals(key) || values == null || values.isEmpty()) {
                continue;
            }
            if ("feature".equals(key)) {
                // Every feature value is forwarded, each preceded by its own "feature" key.
                for (String feature : values) {
                    args.add(key);
                    args.add(feature);
                }
            } else {
                // For all other keys (including "label") only the first value is used.
                args.add(key);
                args.add(values.get(0));
            }
        }
        return args;
    }

    /**
     * Clears the {@link ModelLoader} model cache.
     *
     * @return a plain-text confirmation response
     */
    @Produces({"text/plain"})
    @Path("clearModelCache")
    @DELETE
    public Response clearModelCache() {
        ModelLoader.clearCache();
        log.info("Cleared model cache");
        return Response.ok("Cleared cache", MediaType.TEXT_PLAIN_TYPE).build();
    }

    /**
     * Tests a model against a table.
     *
     * @param algorithm algorithm name
     * @param modelID model ID
     * @param table table to test against
     * @param session Lens session handle (multipart field {@code sessionid})
     * @param outputTable output table name (multipart field {@code outputTable})
     * @return the ID of the generated test report
     * @throws LensException propagated from the ML service
     */
    @POST
    @Path("test/{table}/{algorithm}/{modelID}")
    @Consumes({"multipart/form-data"})
    public String test(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID,
        @PathParam("table") String table, @FormDataParam("sessionid") LensSessionHandle session,
        @FormDataParam("outputTable") String outputTable) throws LensException {
        return getMlService().testModel(session, table, algorithm, modelID, outputTable).getReportID();
    }

    /**
     * Lists test report IDs for an algorithm.
     *
     * @param algorithm algorithm name
     * @return report IDs
     * @throws NotFoundException if there are no reports for {@code algorithm}
     * @throws LensException propagated from the ML service
     */
    @GET
    @Path("reports/{algorithm}")
    public StringList getReportsForAlgorithm(@PathParam("algorithm") String algorithm) throws LensException {
        List<String> reports = getMlService().getTestReports(algorithm);
        if (reports == null || reports.isEmpty()) {
            throw new NotFoundException("No test reports found for " + algorithm);
        }
        return new StringList(reports);
    }

    /**
     * Returns a single test report.
     *
     * @param algorithm algorithm name
     * @param reportID report ID
     * @return the report, with feature columns comma-joined
     * @throws NotFoundException if the report does not exist
     * @throws LensException propagated from the ML service
     */
    @GET
    @Path("reports/{algorithm}/{reportID}")
    public TestReport getTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID)
        throws LensException {
        MLTestReport report = getMlService().getTestReport(algorithm, reportID);
        if (report == null) {
            throw new NotFoundException("Test report: " + reportID + " not found for algorithm " + algorithm);
        }
        return new TestReport(
            report.getTestTable(),
            report.getOutputTable(),
            report.getOutputColumn(),
            report.getLabelColumn(),
            StringUtils.join(report.getFeatureColumns(), ","),
            report.getAlgorithm(),
            report.getModelID(),
            report.getReportID(),
            report.getLensQueryID());
    }

    /**
     * Deletes a test report.
     *
     * @param algorithm algorithm name
     * @param reportID report ID
     * @return a confirmation message naming the deleted report
     * @throws LensException propagated from the ML service
     */
    @Path("reports/{algorithm}/{reportID}")
    @Consumes({"application/json", "application/xml", "text/plain"})
    @DELETE
    public String deleteTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID)
        throws LensException {
        getMlService().deleteTestReport(algorithm, reportID);
        return "DELETED report=" + reportID + " algorithm=" + algorithm;
    }

    /**
     * Runs a prediction with a trained model.
     *
     * <p>Each feature value is read from the query parameter named after the corresponding feature column
     * of the model, in column order; a missing parameter yields a {@code null} feature value.
     *
     * @param algorithm algorithm name
     * @param modelID model ID
     * @param uriInfo request URI, supplying the feature values as query parameters
     * @return the prediction rendered with {@code toString()}
     * @throws NotFoundException if the model does not exist
     * @throws LensException propagated from the ML service
     */
    @GET
    @Produces({"application/atom+xml", "application/json"})
    @Path("/predict/{algorithm}/{modelID}")
    public String predict(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID,
        @Context UriInfo uriInfo) throws LensException {
        MLModel model = getMlService().getModel(algorithm, modelID);
        if (model == null) {
            throw new NotFoundException("Model not found " + modelID + ", algo=" + algorithm);
        }
        List<String> featureColumns = model.getFeatureColumns();
        MultivaluedMap<String, String> queryParams = uriInfo.getQueryParameters();
        String[] features = new String[featureColumns.size()];
        int index = 0;
        for (String column : featureColumns) {
            features[index++] = queryParams.getFirst(column);
        }
        return getMlService().predict(algorithm, modelID, features).toString();
    }
}
