package com.github.keenon.loglinear.simple;

import com.github.keenon.loglinear.inference.CliqueTree;
import com.github.keenon.loglinear.learning.LogLikelihoodDifferentiableFunction;
import com.github.keenon.loglinear.model.ConcatVector;
import com.github.keenon.loglinear.model.GraphicalModel;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

/* loaded from: input_file:com/github/keenon/loglinear/simple/DurableMulticlassPredictor.class */
/**
 * A multiclass classifier over CoreNLP {@link Annotation}s, built on a one-variable log-linear
 * {@link GraphicalModel}.
 *
 * <p>Each model has a single variable (index 0) whose value is an index into {@link #tags}. Features
 * are registered by name with {@link #addStringFeature} and {@link #addEmbeddingFeature}. Each
 * feature name is conjoined with the candidate tag, so every tag gets its own weights.
 */
public class DurableMulticlassPredictor extends SimpleDurablePredictor<Annotation> {
    /** Candidate labels; the value of model variable 0 is an index into this array. */
    public String[] tags;

    /** Named features that map an annotation to a single sparse (string-valued) feature. */
    private final Map<String, Function<Annotation, String>> stringFeatures = new HashMap<>();

    /** Named features that map an annotation to a dense vector. */
    private final Map<String, Function<Annotation, double[]>> embeddingFeatures = new HashMap<>();

    /** Pipeline used to re-annotate text when a stored model is turned back into an annotation. */
    private final StanfordCoreNLP coreNLP;

    /** Model-metadata key under which {@link #createModelInternal} stores the annotation's string form. */
    private static final String SOURCE_TEXT = "com.github.keenon.loglinear.simple.DurableMulticlassPredictor.SOURCE_TEXT";

    /**
     * Creates a predictor.
     *
     * @param str passed unchanged to the superclass constructor (presumably a storage path or
     *     name for the durable predictor; TODO confirm against SimpleDurablePredictor)
     * @param strArr the candidate labels; stored by reference as {@link #tags}
     * @param stanfordCoreNLP pipeline used by {@link #restoreContextObjectFromModelTags}
     * @throws IOException propagated from the superclass constructor
     */
    public DurableMulticlassPredictor(String str, String[] strArr, StanfordCoreNLP stanfordCoreNLP) throws IOException {
        super(str);
        this.tags = strArr;
        this.coreNLP = stanfordCoreNLP;
    }

    /**
     * Registers (or replaces) a named sparse feature.
     *
     * @param str feature name; it is prefixed with each candidate tag when the model is featurized
     * @param function maps an annotation to the feature's string value
     */
    public void addStringFeature(String str, Function<Annotation, String> function) {
        this.stringFeatures.put(str, function);
    }

    /**
     * Registers (or replaces) a named dense feature.
     *
     * @param str feature name; it is prefixed with each candidate tag when the model is featurized
     * @param function maps an annotation to the feature's dense vector
     */
    public void addEmbeddingFeature(String str, Function<Annotation, double[]> function) {
        this.embeddingFeatures.put(str, function);
    }

    /**
     * Predicts the most likely tag for an annotation.
     *
     * <p>Builds a model with {@code createModel} and runs MAP inference with a {@link CliqueTree}
     * under the current {@code weights}. It then returns the tag chosen for variable 0.
     *
     * @param annotation the input to classify
     * @return the element of {@link #tags} at the MAP assignment of variable 0
     */
    public String labelSequence(Annotation annotation) {
        int bestTag = new CliqueTree(createModel(annotation), this.weights).calculateMAP()[0];
        return this.tags[bestTag];
    }

    /**
     * Adds a labeled example to the training data.
     *
     * @param annotation the input
     * @param str the gold label; it must be one of {@link #tags}
     * @throws IllegalArgumentException if {@code str} is not in {@link #tags}. This check now
     *     always runs. Previously it was an assertion, so with assertions disabled an unknown
     *     label was stored as training value "-1".
     */
    public void addTrainingExample(Annotation annotation, String str) {
        int tagIndex = indexOfTag(str);
        if (tagIndex == -1) {
            throw new IllegalArgumentException("Unknown tag \"" + str + "\"; expected one of [" + String.join(", ", this.tags) + "]");
        }
        GraphicalModel model = createModel(annotation);
        // The training value is stored as the tag's index, in decimal string form.
        model.getVariableMetaDataByReference(0).put(LogLikelihoodDifferentiableFunction.VARIABLE_TRAINING_VALUE, Integer.toString(tagIndex));
        addLabeledTrainingExample(model);
    }

    /**
     * Returns the index of {@code tag} in {@link #tags}, or -1 if it is absent.
     *
     * <p>The scan runs from the end, so duplicate entries resolve to the last occurrence. This
     * matches the index the original forward loop (which never broke early) selected.
     */
    private int indexOfTag(String tag) {
        for (int i = this.tags.length - 1; i >= 0; i--) {
            if (this.tags[i].equals(tag)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Creates a one-variable model for {@code annotation}.
     *
     * <p>Records {@code annotation.toString()} under the {@code SOURCE_TEXT} model-metadata key,
     * then adds the feature factor. {@link #restoreContextObjectFromModelTags} later rebuilds an
     * annotation from that string.
     */
    @Override // com.github.keenon.loglinear.simple.SimpleDurablePredictor
    public GraphicalModel createModelInternal(Annotation annotation) {
        GraphicalModel graphicalModel = new GraphicalModel();
        graphicalModel.getModelMetaDataByReference().put(SOURCE_TEXT, annotation.toString());
        featurizeModel(graphicalModel, annotation);
        return graphicalModel;
    }

    /**
     * Rebuilds an annotation from the string stored by {@link #createModelInternal}.
     *
     * <p>Constructs a new {@link Annotation} from that string and re-runs the CoreNLP pipeline on
     * it. Any annotations that are not recomputed by the pipeline are therefore not restored.
     */
    @Override // com.github.keenon.loglinear.simple.SimpleDurablePredictor
    public Annotation restoreContextObjectFromModelTags(GraphicalModel graphicalModel) {
        Annotation annotation = new Annotation(graphicalModel.getModelMetaDataByReference().get(SOURCE_TEXT));
        this.coreNLP.annotate(annotation);
        return annotation;
    }

    /**
     * Adds the single factor over variable 0, with one assignment per tag.
     *
     * <p>For an assignment naming tag {@code t}, each registered feature {@code f} is written to
     * the vector under the key {@code t + ":" + f}. The featurizer lambda captures
     * {@code annotation}. The feature functions are re-applied on every call, i.e. once per tag
     * assignment evaluated.
     */
    @Override // com.github.keenon.loglinear.simple.SimpleDurablePredictor
    public void featurizeModel(GraphicalModel graphicalModel, Annotation annotation) {
        graphicalModel.addFactor(new int[]{0}, new int[]{this.tags.length}, assignment -> {
            ConcatVector features = this.namespace.newVector();
            String tag = this.tags[assignment[0]];
            for (Map.Entry<String, Function<Annotation, String>> feature : this.stringFeatures.entrySet()) {
                this.namespace.setSparseFeature(features, tag + ":" + feature.getKey(), feature.getValue().apply(annotation), 1.0d);
            }
            for (Map.Entry<String, Function<Annotation, double[]>> feature : this.embeddingFeatures.entrySet()) {
                this.namespace.setDenseFeature(features, tag + ":" + feature.getKey(), feature.getValue().apply(annotation));
            }
            return features;
        });
    }
}
