package fact;

import java.io.InputStream;
import java.io.Serializable;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.model.ImportFilter;
import org.jpmml.model.JAXBUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;
import stream.Data;
import stream.ProcessContext;
import stream.StatefulProcessor;
import stream.annotations.Parameter;
import stream.io.SourceURL;

/* loaded from: input_file:fact/ApplyModel.class */
/**
 * Stream processor that applies a PMML model to each data item.
 *
 * <p>The model is read once in {@link #init(ProcessContext)} from the configured {@link #url}.
 * For every item passed to {@link #process(Data)} the model's active fields are looked up in the
 * item by name, the model is evaluated, and the value of the model's target field is written back
 * into the item under the target field's name.
 *
 * <p>NOTE(review): the argument map is a shared mutable field that is reused between calls to
 * {@link #process(Data)}; this looks like it assumes single-threaded use per instance — confirm.
 */
public class ApplyModel implements StatefulProcessor {
    static Logger log = LoggerFactory.getLogger(ApplyModel.class);

    /** The unmarshalled PMML document; set by {@link #init(ProcessContext)}. */
    PMML pmml;

    /** Prepared model inputs, keyed by active field; entries are overwritten on every item. */
    Map<FieldName, FieldValue> arguments = new LinkedHashMap<>();

    private ModelEvaluator<? extends Model> modelEvaluator;
    private FieldName targetName;
    private List<FieldName> activeFields;

    @Parameter(required = true, description = "URL point to the .pmml model")
    SourceURL url;

    /**
     * Loads the PMML model from {@link #url} and prepares the evaluator.
     *
     * @param processContext the context supplied by the framework (not used here)
     * @throws Exception if the stream cannot be opened or the model cannot be parsed; a
     *     {@link SAXException} is logged and rethrown rather than swallowed, because continuing
     *     without a model would only fail later with a less helpful error
     */
    @Override // stream.StatefulProcessor
    public void init(ProcessContext processContext) throws Exception {
        log.info("Loading .pmml model");
        try (InputStream modelStream = this.url.openStream()) {
            this.pmml = JAXBUtil.unmarshalPMML(ImportFilter.apply(new InputSource(modelStream)));
        } catch (SAXException e) {
            log.error("Could not load model from file provided at " + this.url, e);
            throw e;
        }

        this.modelEvaluator = ModelEvaluatorFactory.newInstance().newModelManager(this.pmml);
        log.info("Loaded model requires the following fields: " + this.modelEvaluator.getActiveFields().toString());
        log.info("Loaded model has targets: " + this.modelEvaluator.getTargetFields().toString());
        if (this.modelEvaluator.getTargetFields().size() > 1) {
            // Only reported, not enforced: the single target is still requested below.
            log.error("Only models with one target variable are supported for now");
        }
        this.targetName = this.modelEvaluator.getTargetField();
        this.activeFields = this.modelEvaluator.getActiveFields();
    }

    /** No state to reset. */
    @Override // stream.StatefulProcessor
    public void resetState() throws Exception {
    }

    /** Nothing to release. */
    @Override // stream.StatefulProcessor
    public void finish() throws Exception {
    }

    /**
     * Evaluates the loaded model on one data item.
     *
     * <p>Each active field of the model is read from {@code data} using the field name's string
     * form as key and prepared by the evaluator. The resulting target value is stored in
     * {@code data} under the target field's name. If the value is not {@link Serializable}, a
     * warning is logged and its {@code toString()} representation is stored instead.
     *
     * @param data the item to classify; it is modified in place
     * @return the same {@code data} instance, with the prediction added
     */
    @Override // stream.Processor
    public Data process(Data data) {
        for (FieldName fieldName : this.activeFields) {
            Object rawValue = data.get(fieldName.toString());
            this.arguments.put(fieldName, this.modelEvaluator.prepare(fieldName, rawValue));
        }

        Object prediction = this.modelEvaluator.evaluate(this.arguments).get(this.targetName);
        log.info("Prediction: " + prediction);

        String targetKey = this.targetName.getValue();
        // A null prediction passes through unchanged, exactly as a plain cast would allow.
        if (prediction == null || prediction instanceof Serializable) {
            data.put(targetKey, (Serializable) prediction);
        } else {
            log.warn("Cannot cast target type to serializable type");
            data.put(targetKey, prediction.toString());
        }
        return data;
    }

    /**
     * Sets the location of the .pmml model file.
     *
     * @param sourceURL the URL the model is read from in {@link #init(ProcessContext)}
     */
    public void setUrl(SourceURL sourceURL) {
        this.url = sourceURL;
    }
}
