package com.github.keenon.loglinear.simple;

import com.github.keenon.loglinear.inference.CliqueTree;
import com.github.keenon.loglinear.learning.LogLikelihoodDifferentiableFunction;
import com.github.keenon.loglinear.model.ConcatVector;
import com.github.keenon.loglinear.model.GraphicalModel;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.StanfordCoreNLP;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;

/* loaded from: input_file:com/github/keenon/loglinear/simple/DurableSequencePredictor.class */
/**
 * A {@link SimpleDurablePredictor} that assigns one tag from a fixed tag set to every token of a
 * CoreNLP {@link Annotation}.
 *
 * <p>The model built for a sentence is a chain: every token gets a unary factor over the tag set,
 * and every pair of adjacent tokens gets a binary factor over tag pairs. Feature functions are
 * registered by name via {@link #addUnaryStringFeature}, {@link #addUnaryEmbeddingFeature} and
 * {@link #addBinaryStringFeature}; they are looked up when a factor is evaluated, not when the
 * model is built.
 */
public class DurableSequencePredictor extends SimpleDurablePredictor<Annotation> {
    /**
     * The tag set. A model variable value {@code k} corresponds to {@code tags[k]}, both when
     * decoding in {@link #labelSequence} and when recording gold labels in
     * {@link #addTrainingExample}.
     */
    public String[] tags;

    /** Named unary features: (annotation, token index) -> sparse feature value. */
    private final Map<String, BiFunction<Annotation, Integer, String>> unaryStringFeatures = new HashMap<>();

    /** Named unary features: (annotation, token index) -> dense feature vector. */
    private final Map<String, BiFunction<Annotation, Integer, double[]>> unaryEmbeddingFeatures = new HashMap<>();

    /**
     * Named features on adjacent token pairs: (annotation, index of the LEFT token of the pair)
     * -> sparse feature value.
     */
    private final Map<String, BiFunction<Annotation, Integer, String>> binaryStringFeatures = new HashMap<>();

    /** Pipeline used to re-annotate stored sentence text in {@link #restoreContextObjectFromModelTags}. */
    private final StanfordCoreNLP coreNLP;

    /** Model metadata key under which the annotation's string form is stored. */
    private static final String SOURCE_TEXT = "com.github.keenon.loglinear.simple.DurableSequencePredictor.SOURCE_TEXT";

    /**
     * Creates a sequence predictor.
     *
     * @param str passed unchanged to the {@link SimpleDurablePredictor} constructor (presumably
     *     identifies its backing store — confirm against the superclass)
     * @param tags the tag set; see {@link #tags}
     * @param coreNLP pipeline used to re-annotate sentences restored from model metadata
     * @throws IOException if the superclass constructor throws it
     */
    public DurableSequencePredictor(String str, String[] tags, StanfordCoreNLP coreNLP) throws IOException {
        super(str);
        this.tags = tags;
        this.coreNLP = coreNLP;
    }

    /**
     * Predicts the most likely tag for every token of {@code annotation}.
     *
     * @param annotation a tokenized annotation
     * @return one tag per token, taken from the MAP assignment of the chain model under the
     *     current weights
     */
    public String[] labelSequence(Annotation annotation) {
        int numTokens = tokenCount(annotation);
        int[] mapAssignment = new CliqueTree(createModel(annotation), this.weights).calculateMAP();
        String[] labels = new String[numTokens];
        for (int i = 0; i < numTokens; i++) {
            labels[i] = this.tags[mapAssignment[i]];
        }
        return labels;
    }

    /**
     * Adds a gold-labeled sentence to the training data.
     *
     * <p>All validation happens before any model is built, so a rejected example has no effect.
     *
     * @param annotation a tokenized, non-empty annotation
     * @param labels one tag per token; every entry must appear in {@link #tags}
     * @throws IllegalStateException if the sentence has no tokens, the label count differs from
     *     the token count, or the built model's variable count differs from the label count
     * @throws IllegalArgumentException if a label is not in {@link #tags}
     */
    public void addTrainingExample(Annotation annotation, String[] labels) {
        int numTokens = tokenCount(annotation);
        if (numTokens != labels.length) {
            throw new IllegalStateException("Shouldn't pass a training example with a different number of labels from tokens. Got a sentence \"" + annotation + "\" with " + numTokens + " tokens, but got labels " + Arrays.toString(labels) + " with length " + labels.length);
        }
        if (numTokens == 0) {
            throw new IllegalStateException("Can't add a training example with no tokens: \"" + annotation + "\"");
        }

        // Resolve every label up front. Previously an unknown label was only caught by an
        // assert, so with assertions disabled "-1" was silently recorded as the gold value.
        int[] tagIndices = new int[labels.length];
        for (int i = 0; i < labels.length; i++) {
            int tagIndex = indexOfTag(labels[i]);
            if (tagIndex == -1) {
                throw new IllegalArgumentException("Unknown label \"" + labels[i] + "\" at position " + i + "; expected one of " + Arrays.toString(this.tags));
            }
            tagIndices[i] = tagIndex;
        }

        GraphicalModel model = createModel(annotation);

        // Sanity check: the model's variables must be exactly 0 .. labels.length - 1.
        int maxVariable = -1;
        for (GraphicalModel.Factor factor : model.factors) {
            for (int neighbor : factor.neigborIndices) {
                maxVariable = Math.max(maxVariable, neighbor);
            }
        }
        if (maxVariable + 1 != labels.length) {
            throw new IllegalStateException("Have the wrong number of labels! Model has " + (maxVariable + 1) + " variables, but got " + labels.length + " labels");
        }

        for (int i = 0; i < tagIndices.length; i++) {
            model.getVariableMetaDataByReference(i).put(LogLikelihoodDifferentiableFunction.VARIABLE_TRAINING_VALUE, Integer.toString(tagIndices[i]));
        }
        addLabeledTrainingExample(model);
    }

    /**
     * Registers (or replaces) a named unary feature whose string value becomes a sparse feature
     * keyed by {@code tag + ":" + name}.
     */
    public void addUnaryStringFeature(String name, BiFunction<Annotation, Integer, String> feature) {
        this.unaryStringFeatures.put(name, feature);
    }

    /**
     * Registers (or replaces) a named unary feature whose vector becomes a dense feature keyed by
     * {@code tag + ":" + name}.
     */
    public void addUnaryEmbeddingFeature(String name, BiFunction<Annotation, Integer, double[]> feature) {
        this.unaryEmbeddingFeatures.put(name, feature);
    }

    /**
     * Registers (or replaces) a named feature on adjacent token pairs. The function receives the
     * index of the left token; its value becomes a sparse feature keyed by
     * {@code leftTag + ":" + rightTag + ":" + name}.
     */
    public void addBinaryStringFeature(String name, BiFunction<Annotation, Integer, String> feature) {
        this.binaryStringFeatures.put(name, feature);
    }

    /**
     * Builds a featurized chain model for {@code annotation} and stores the annotation's string
     * form under {@link #SOURCE_TEXT} so the context can be rebuilt later.
     */
    @Override
    public GraphicalModel createModelInternal(Annotation annotation) {
        GraphicalModel graphicalModel = new GraphicalModel();
        graphicalModel.getModelMetaDataByReference().put(SOURCE_TEXT, annotation.toString());
        featurizeModel(graphicalModel, annotation);
        return graphicalModel;
    }

    /**
     * Rebuilds an {@link Annotation} from the text stored under {@link #SOURCE_TEXT} and runs
     * {@link #coreNLP} over it. NOTE(review): relies on {@code Annotation.toString()} having
     * produced the raw sentence text in {@link #createModelInternal} — confirm.
     */
    @Override
    public Annotation restoreContextObjectFromModelTags(GraphicalModel graphicalModel) {
        Annotation annotation = new Annotation(graphicalModel.getModelMetaDataByReference().get(SOURCE_TEXT));
        this.coreNLP.annotate(annotation);
        return annotation;
    }

    /**
     * Adds one unary factor per token and one binary factor per adjacent token pair to
     * {@code graphicalModel}. Feature functions are looked up inside the factor lambdas, so the
     * registered features are read at evaluation time.
     */
    @Override
    public void featurizeModel(GraphicalModel graphicalModel, Annotation annotation) {
        int numTokens = tokenCount(annotation);
        for (int i = 0; i < numTokens; i++) {
            final int token = i;

            graphicalModel.addFactor(new int[]{token}, new int[]{this.tags.length}, assignment -> {
                ConcatVector features = this.namespace.newVector();
                String tag = this.tags[assignment[0]];
                for (Map.Entry<String, BiFunction<Annotation, Integer, String>> feature : this.unaryStringFeatures.entrySet()) {
                    this.namespace.setSparseFeature(features, tag + ":" + feature.getKey(), feature.getValue().apply(annotation, token), 1.0d);
                }
                for (Map.Entry<String, BiFunction<Annotation, Integer, double[]>> feature : this.unaryEmbeddingFeatures.entrySet()) {
                    this.namespace.setDenseFeature(features, tag + ":" + feature.getKey(), feature.getValue().apply(annotation, token));
                }
                return features;
            });

            // Every token except the last is linked to its right neighbour.
            if (token < numTokens - 1) {
                graphicalModel.addFactor(new int[]{token, token + 1}, new int[]{this.tags.length, this.tags.length}, assignment -> {
                    ConcatVector features = this.namespace.newVector();
                    String leftTag = this.tags[assignment[0]];
                    String rightTag = this.tags[assignment[1]];
                    for (Map.Entry<String, BiFunction<Annotation, Integer, String>> feature : this.binaryStringFeatures.entrySet()) {
                        this.namespace.setSparseFeature(features, leftTag + ":" + rightTag + ":" + feature.getKey(), feature.getValue().apply(annotation, token), 1.0d);
                    }
                    return features;
                });
            }
        }
    }

    /** Returns the index of {@code tag} in {@link #tags}, or -1 if it is not a known tag. */
    private int indexOfTag(String tag) {
        for (int k = 0; k < this.tags.length; k++) {
            if (this.tags[k].equals(tag)) {
                return k;
            }
        }
        return -1;
    }

    /**
     * Returns the number of tokens in {@code annotation}.
     *
     * @throws IllegalArgumentException if the annotation has not been tokenized
     */
    private static int tokenCount(Annotation annotation) {
        List<?> tokens = annotation.get(CoreAnnotations.TokensAnnotation.class);
        if (tokens == null) {
            throw new IllegalArgumentException("Annotation has no TokensAnnotation; it must be tokenized first");
        }
        return tokens.size();
    }
}
