package org.jpmml.evaluator;

import com.google.common.base.Objects;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import com.google.common.collect.Table;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldUsageType;
import org.dmg.pmml.InlineTable;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.ModelVerification;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Target;
import org.dmg.pmml.Targets;
import org.dmg.pmml.TransformationDictionary;
import org.dmg.pmml.TypeDefinitionField;
import org.dmg.pmml.VerificationField;
import org.dmg.pmml.VerificationFields;

/* loaded from: input_file:org/jpmml/evaluator/ModelEvaluator.class */
/**
 * Base class for evaluators of a single PMML {@link Model}.
 *
 * <p>At construction time, the field and function declarations of the enclosing {@link PMML}
 * document and of the selected model are indexed into lookup maps. The indexes are obtained
 * through static caches keyed by the corresponding PMML container elements
 * ({@link DataDictionary}, {@link TransformationDictionary}, {@link MiningSchema}, ...).
 *
 * @param <M> the model element type handled by this evaluator
 */
public abstract class ModelEvaluator<M extends Model> extends ModelManager<M> implements Evaluator {

    /**
     * Usage type sets for which {@link #parseMiningFieldNames(List)} pre-computes field name lists.
     * Lookups for any other usage type set are delegated to the superclass.
     */
    private static final ImmutableSet<EnumSet<FieldUsageType>> INDEXED_USAGE_TYPES = ImmutableSet.of(
            ModelManager.ACTIVE_TYPES, ModelManager.GROUP_TYPES, ModelManager.ORDER_TYPES, ModelManager.TARGET_TYPES);

    // Lookup indexes; each defaults to an empty collection and is replaced in the constructor
    // when the corresponding PMML element declares any entries.
    private Map<FieldName, DataField> dataFields = Collections.emptyMap();
    private Map<FieldName, DerivedField> derivedFields = Collections.emptyMap();
    private Map<String, DefineFunction> functions = Collections.emptyMap();
    private Map<FieldName, MiningField> miningFields = Collections.emptyMap();
    private ListMultimap<EnumSet<FieldUsageType>, FieldName> miningFieldNames = ImmutableListMultimap.of();
    private Map<FieldName, DerivedField> localDerivedFields = Collections.emptyMap();
    private Map<FieldName, Target> targets = Collections.emptyMap();
    private Map<FieldName, OutputField> outputFields = Collections.emptyMap();

    private static final LoadingCache<DataDictionary, Map<FieldName, DataField>> dataFieldCache = CacheUtil.buildLoadingCache(new CacheLoader<DataDictionary, Map<FieldName, DataField>>() {
        @Override
        public Map<FieldName, DataField> load(DataDictionary dataDictionary) {
            return IndexableUtil.buildMap(dataDictionary.getDataFields());
        }
    });

    private static final LoadingCache<TransformationDictionary, Map<FieldName, DerivedField>> derivedFieldCache = CacheUtil.buildLoadingCache(new CacheLoader<TransformationDictionary, Map<FieldName, DerivedField>>() {
        @Override
        public Map<FieldName, DerivedField> load(TransformationDictionary transformationDictionary) {
            return IndexableUtil.buildMap(transformationDictionary.getDerivedFields());
        }
    });

    private static final LoadingCache<TransformationDictionary, Map<String, DefineFunction>> functionCache = CacheUtil.buildLoadingCache(new CacheLoader<TransformationDictionary, Map<String, DefineFunction>>() {
        @Override
        public Map<String, DefineFunction> load(TransformationDictionary transformationDictionary) {
            return IndexableUtil.buildMap(transformationDictionary.getDefineFunctions());
        }
    });

    private static final LoadingCache<MiningSchema, Map<FieldName, MiningField>> miningFieldCache = CacheUtil.buildLoadingCache(new CacheLoader<MiningSchema, Map<FieldName, MiningField>>() {
        @Override
        public Map<FieldName, MiningField> load(MiningSchema miningSchema) {
            return IndexableUtil.buildMap(miningSchema.getMiningFields());
        }
    });

    private static final LoadingCache<MiningSchema, ListMultimap<EnumSet<FieldUsageType>, FieldName>> miningFieldNameCache = CacheUtil.buildLoadingCache(new CacheLoader<MiningSchema, ListMultimap<EnumSet<FieldUsageType>, FieldName>>() {
        @Override
        public ListMultimap<EnumSet<FieldUsageType>, FieldName> load(MiningSchema miningSchema) {
            // Immutable copy, because the cached value may be handed out to many evaluators.
            return ImmutableListMultimap.copyOf(ModelEvaluator.parseMiningFieldNames(miningSchema.getMiningFields()));
        }
    });

    private static final LoadingCache<LocalTransformations, Map<FieldName, DerivedField>> localDerivedFieldCache = CacheUtil.buildLoadingCache(new CacheLoader<LocalTransformations, Map<FieldName, DerivedField>>() {
        @Override
        public Map<FieldName, DerivedField> load(LocalTransformations localTransformations) {
            return IndexableUtil.buildMap(localTransformations.getDerivedFields());
        }
    });

    private static final LoadingCache<Targets, Map<FieldName, Target>> targetCache = CacheUtil.buildLoadingCache(new CacheLoader<Targets, Map<FieldName, Target>>() {
        @Override
        public Map<FieldName, Target> load(Targets targets) {
            return IndexableUtil.buildMap(targets.getTargets());
        }
    });

    private static final LoadingCache<Output, Map<FieldName, OutputField>> outputFieldCache = CacheUtil.buildLoadingCache(new CacheLoader<Output, Map<FieldName, OutputField>>() {
        @Override
        public Map<FieldName, OutputField> load(Output output) {
            return IndexableUtil.buildMap(output.getOutputFields());
        }
    });

    private static final LoadingCache<ModelVerification, VerificationBatch> batchCache = CacheUtil.buildLoadingCache(new CacheLoader<ModelVerification, VerificationBatch>() {
        @Override
        public VerificationBatch load(ModelVerification modelVerification) {
            return ModelEvaluator.parseModelVerification(modelVerification);
        }
    });

    /**
     * Parsed contents of a {@link ModelVerification} element.
     *
     * <p>The map entries index the declared {@link VerificationField}s by field name; the
     * {@linkplain #getRecords() records} hold the inline table rows re-keyed by field name.
     */
    public static class VerificationBatch extends LinkedHashMap<FieldName, VerificationField> {

        private List<Map<FieldName, Object>> records = null;

        private VerificationBatch() {
        }

        /**
         * @return the verification records, or {@code null} if they have not been set
         */
        public List<Map<FieldName, Object>> getRecords() {
            return this.records;
        }

        public void setRecords(List<Map<FieldName, Object>> records) {
            this.records = records;
        }
    }

    /**
     * Creates an evaluator for the first model in {@code pmml} that is an instance of {@code clazz}.
     *
     * @throws InvalidFeatureException if no such model exists
     */
    public ModelEvaluator(PMML pmml, Class<? extends M> clazz) {
        this(pmml, selectModel(pmml, clazz));
    }

    /**
     * Creates an evaluator for the given model of the given PMML document, indexing its
     * data fields, derived fields, functions, mining fields, local transformations, targets
     * and output fields.
     */
    public ModelEvaluator(PMML pmml, M model) {
        super(pmml, model);

        DataDictionary dataDictionary = pmml.getDataDictionary();
        if (dataDictionary.hasDataFields()) {
            this.dataFields = CacheUtil.getValue(dataDictionary, dataFieldCache);
        }

        TransformationDictionary transformationDictionary = pmml.getTransformationDictionary();
        if (transformationDictionary != null) {
            if (transformationDictionary.hasDerivedFields()) {
                this.derivedFields = CacheUtil.getValue(transformationDictionary, derivedFieldCache);
            }
            if (transformationDictionary.hasDefineFunctions()) {
                this.functions = CacheUtil.getValue(transformationDictionary, functionCache);
            }
        }

        MiningSchema miningSchema = model.getMiningSchema();
        if (miningSchema.hasMiningFields()) {
            this.miningFields = CacheUtil.getValue(miningSchema, miningFieldCache);
            this.miningFieldNames = CacheUtil.getValue(miningSchema, miningFieldNameCache);
        }

        LocalTransformations localTransformations = model.getLocalTransformations();
        if (localTransformations != null && localTransformations.hasDerivedFields()) {
            this.localDerivedFields = CacheUtil.getValue(localTransformations, localDerivedFieldCache);
        }

        Targets targets = model.getTargets();
        if (targets != null) {
            this.targets = CacheUtil.getValue(targets, targetCache);
        }

        Output output = model.getOutput();
        if (output != null) {
            this.outputFields = CacheUtil.getValue(output, outputFieldCache);
        }
    }

    /**
     * Evaluates the model within the given evaluation context.
     *
     * @return the evaluation result, keyed by field name
     */
    public abstract Map<FieldName, ?> evaluate(ModelEvaluationContext context);

    /**
     * Looks up a data field by name. The default target name ({@code TargetUtil.DEFAULT_NAME})
     * resolves to the synthetic field returned by {@link #getDataField()}.
     */
    @Override
    public DataField getDataField(FieldName name) {
        if (Objects.equal(TargetUtil.DEFAULT_NAME, name)) {
            return getDataField();
        }

        return this.dataFields.get(name);
    }

    /**
     * Builds a synthetic data field for the default target, based on the model's function name:
     * continuous double for regression, categorical string for classification and clustering.
     *
     * @return the synthetic field, or {@code null} for any other function name
     */
    public DataField getDataField() {
        switch (getModel().getFunctionName()) {
            case REGRESSION:
                return new DataField(TargetUtil.DEFAULT_NAME, OpType.CONTINUOUS, DataType.DOUBLE);
            case CLASSIFICATION:
            case CLUSTERING:
                return new DataField(TargetUtil.DEFAULT_NAME, OpType.CATEGORICAL, DataType.STRING);
            default:
                return null;
        }
    }

    @Override
    public DerivedField getDerivedField(FieldName name) {
        return this.derivedFields.get(name);
    }

    @Override
    public DefineFunction getFunction(String name) {
        return this.functions.get(name);
    }

    /**
     * @return the mining field with the given name, or {@code null} if absent or if {@code name} is {@code null}
     */
    @Override
    public MiningField getMiningField(FieldName name) {
        if (name == null) {
            return null;
        }

        return this.miningFields.get(name);
    }

    /**
     * Returns the names of the mining fields whose usage type belongs to {@code fieldUsageTypes}.
     *
     * <p>The usage type sets listed in {@link #INDEXED_USAGE_TYPES} are served from the pre-computed
     * index. Any other set is delegated to the superclass implementation.
     */
    @Override
    protected List<FieldName> getMiningFields(EnumSet<FieldUsageType> fieldUsageTypes) {
        // ListMultimap#get never returns null (absent keys map to an empty list), so a null check
        // cannot tell an unindexed usage type set apart; test index membership explicitly instead.
        if (INDEXED_USAGE_TYPES.contains(fieldUsageTypes)) {
            return this.miningFieldNames.get(fieldUsageTypes);
        }

        return super.getMiningFields(fieldUsageTypes);
    }

    @Override
    public DerivedField getLocalDerivedField(FieldName name) {
        return this.localDerivedFields.get(name);
    }

    /**
     * @return the target with the given name, or {@code null} if absent or if {@code name} is {@code null}
     */
    @Override
    public Target getTarget(FieldName name) {
        if (name == null) {
            return null;
        }

        return this.targets.get(name);
    }

    @Override
    public OutputField getOutputField(FieldName name) {
        return this.outputFields.get(name);
    }

    /**
     * Converts a raw input value into a {@link FieldValue} according to the field's data field and
     * mining field declarations.
     *
     * @throws MissingFieldException if either declaration is missing
     */
    @Override
    public FieldValue prepare(FieldName name, Object value) {
        DataField dataField = getDataField(name);
        MiningField miningField = getMiningField(name);

        if (dataField == null || miningField == null) {
            throw new MissingFieldException(name, getModel());
        }

        return FieldValueUtil.prepare(dataField, miningField, value);
    }

    /**
     * Re-evaluates every record of the model's {@link ModelVerification} element (if any) and
     * compares the results with the expected values.
     *
     * <p>If the verification fields reference at least one output field, output values are compared.
     * Otherwise, the decoded target values are compared.
     *
     * @throws EvaluationException if the model declares more than one group field, or if some
     * computed value is not acceptably close to the expected one
     */
    public void verify() {
        ModelVerification modelVerification = getModel().getModelVerification();
        if (modelVerification == null) {
            return;
        }

        VerificationBatch batch = CacheUtil.getValue(modelVerification, batchCache);
        List<Map<FieldName, Object>> records = batch.getRecords();

        List<FieldName> activeFields = getActiveFields();

        List<FieldName> groupFields = getGroupFields();
        if (groupFields.size() == 1) {
            records = EvaluatorUtil.groupRows(groupFields.get(0), records);
        } else if (groupFields.size() > 1) {
            throw new EvaluationException();
        }

        List<FieldName> targetFields = getTargetFields();
        List<FieldName> outputFields = getOutputFields();

        // Loop-invariant: decide once whether outputs or (decoded) targets are verified.
        boolean verifyOutputs = !Sets.intersection(batch.keySet(), ImmutableSet.copyOf(outputFields)).isEmpty();
        List<FieldName> verifiedFields = verifyOutputs ? outputFields : targetFields;

        for (Map<FieldName, Object> record : records) {
            Map<FieldName, Object> arguments = new HashMap<FieldName, Object>();
            for (FieldName activeField : activeFields) {
                arguments.put(activeField, EvaluatorUtil.prepare(this, activeField, record.get(activeField)));
            }

            Map<FieldName, ?> result = evaluate(arguments);

            for (FieldName verifiedField : verifiedFields) {
                VerificationField verificationField = batch.get(verifiedField);
                if (verificationField == null) {
                    continue;
                }

                Object actual = result.get(verifiedField);
                if (!verifyOutputs) {
                    actual = EvaluatorUtil.decode(actual);
                }

                verify(record.get(verifiedField), actual,
                        verificationField.getPrecision().doubleValue(),
                        verificationField.getZeroThreshold().doubleValue());
            }
        }
    }

    /**
     * Compares a single expected value with the computed one. A {@code null} expected value is skipped.
     * A non-collection expected value is first converted to the data type of the computed value.
     *
     * @throws EvaluationException if the values are not acceptably close
     */
    private void verify(Object expected, Object actual, double precision, double zeroThreshold) {
        if (expected == null) {
            return;
        }

        if (!(actual instanceof Collection)) {
            expected = TypeUtil.parseOrCast(TypeUtil.getDataType(actual), expected);
        }

        if (!VerificationUtil.acceptable(expected, actual, precision, zeroThreshold)) {
            throw new EvaluationException();
        }
    }

    /**
     * Creates an evaluation context for this evaluator.
     *
     * @param parent the enclosing context, or {@code null} for a top-level evaluation
     */
    public ModelEvaluationContext createContext(ModelEvaluationContext parent) {
        return new ModelEvaluationContext(parent, this);
    }

    /**
     * Evaluates the model in a fresh top-level context populated with {@code arguments}.
     */
    @Override
    public Map<FieldName, ?> evaluate(Map<FieldName, ?> arguments) {
        ModelEvaluationContext context = createContext(null);
        context.setArguments(arguments);

        return evaluate(context);
    }

    /**
     * Resolves a field name to a data field. If there is no data field, falls back to a global
     * derived field and then to a local derived field.
     *
     * @return the declaration, or {@code null} if none is found
     */
    public TypeDefinitionField resolveField(FieldName name) {
        TypeDefinitionField field = getDataField(name);
        if (field == null) {
            field = resolveDerivedField(name);
        }

        return field;
    }

    /**
     * Resolves a field name to a global derived field, falling back to a local derived field.
     *
     * @return the declaration, or {@code null} if none is found
     */
    public DerivedField resolveDerivedField(FieldName name) {
        DerivedField derivedField = getDerivedField(name);
        if (derivedField == null) {
            derivedField = getLocalDerivedField(name);
        }

        return derivedField;
    }

    /**
     * Looks up (or loads) the value associated with this evaluator's model in the given cache.
     */
    public <V> V getValue(LoadingCache<M, V> cache) {
        return CacheUtil.getValue(getModel(), cache);
    }

    /**
     * Looks up the value associated with this evaluator's model in the given cache, computing it
     * with {@code loader} when absent.
     */
    public <V> V getValue(Callable<? extends V> loader, Cache<M, V> cache) {
        return CacheUtil.getValue(getModel(), loader, cache);
    }

    /**
     * @return the first model of {@code pmml} that is an instance of {@code clazz}
     * @throws InvalidFeatureException if there is no such model
     */
    private static <M extends Model> M selectModel(PMML pmml, Class<? extends M> clazz) {
        M model = Iterables.getFirst(Iterables.filter(pmml.getModels(), clazz), null);
        if (model == null) {
            throw new InvalidFeatureException(pmml);
        }

        return model;
    }

    /**
     * Groups mining field names by each indexed usage type set that contains the field's usage type.
     * A field may appear under several sets. Insertion order follows {@code miningFields}.
     */
    public static ListMultimap<EnumSet<FieldUsageType>, FieldName> parseMiningFieldNames(List<MiningField> miningFields) {
        ListMultimap<EnumSet<FieldUsageType>, FieldName> result = ArrayListMultimap.create();

        for (MiningField miningField : miningFields) {
            FieldUsageType usageType = miningField.getUsageType();

            for (EnumSet<FieldUsageType> usageTypes : INDEXED_USAGE_TYPES) {
                if (usageTypes.contains(usageType)) {
                    result.put(usageTypes, miningField.getName());
                }
            }
        }

        return result;
    }

    /**
     * Parses a {@link ModelVerification} element into a {@link VerificationBatch}.
     *
     * <p>Each inline table row becomes one record. For every verification field, the record maps
     * the field name to the row's cell in the field's column; the column defaults to the field name
     * when no column is declared. Cells absent from the row are omitted from the record.
     *
     * @throws InvalidFeatureException if the verification fields or the inline table are missing,
     * or if the declared record count does not match the number of rows
     */
    public static VerificationBatch parseModelVerification(ModelVerification modelVerification) {
        VerificationFields verificationFields = modelVerification.getVerificationFields();
        if (verificationFields == null) {
            throw new InvalidFeatureException(modelVerification);
        }

        VerificationBatch batch = new VerificationBatch();
        for (VerificationField verificationField : verificationFields) {
            batch.put(verificationField.getField(), verificationField);
        }

        InlineTable inlineTable = modelVerification.getInlineTable();
        if (inlineTable == null) {
            throw new InvalidFeatureException(modelVerification);
        }

        Table<Integer, String, String> content = InlineTableUtil.getContent(inlineTable);

        List<Map<FieldName, Object>> records = new ArrayList<Map<FieldName, Object>>();
        for (Integer rowKey : content.rowKeySet()) {
            Map<String, String> row = content.row(rowKey);

            Map<FieldName, Object> record = new LinkedHashMap<FieldName, Object>();
            for (VerificationField verificationField : verificationFields) {
                FieldName field = verificationField.getField();

                String column = verificationField.getColumn();
                if (column == null) {
                    column = field.getValue();
                }

                if (row.containsKey(column)) {
                    record.put(field, row.get(column));
                }
            }

            records.add(record);
        }

        Integer recordCount = modelVerification.getRecordCount();
        if (recordCount != null && recordCount.intValue() != records.size()) {
            throw new InvalidFeatureException(inlineTable);
        }

        batch.setRecords(records);

        return batch;
    }
}
