package org.apache.ctakes.coreference.ae;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.ctakes.core.pipeline.PipeBitInfo;
import org.apache.ctakes.core.util.ListFactory;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterAgreementFeaturesExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterAttributeFeaturesExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterDepHeadExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterSalienceFeaturesExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterSectionFeaturesExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterSemTypeDepPrefsFeatureExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterStackFeaturesExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterStringFeaturesExtractor;
import org.apache.ctakes.coreference.ae.features.cluster.MentionClusterUMLSFeatureExtractor;
import org.apache.ctakes.coreference.ae.pairing.cluster.ClusterMentionPairer_ImplBase;
import org.apache.ctakes.coreference.ae.pairing.cluster.ClusterPairer;
import org.apache.ctakes.coreference.ae.pairing.cluster.HeadwordPairer;
import org.apache.ctakes.coreference.ae.pairing.cluster.SectionHeaderPairer;
import org.apache.ctakes.coreference.ae.pairing.cluster.SentenceDistancePairer;
import org.apache.ctakes.coreference.util.MarkableUtilities;
import org.apache.ctakes.relationextractor.ae.features.RelationFeaturesExtractor;
import org.apache.ctakes.relationextractor.eval.RelationExtractorEvaluation;
import org.apache.ctakes.typesystem.type.refsem.AnatomicalSite;
import org.apache.ctakes.typesystem.type.refsem.DiseaseDisorder;
import org.apache.ctakes.typesystem.type.refsem.Event;
import org.apache.ctakes.typesystem.type.refsem.Medication;
import org.apache.ctakes.typesystem.type.refsem.Procedure;
import org.apache.ctakes.typesystem.type.refsem.SignSymptom;
import org.apache.ctakes.typesystem.type.relation.CollectionTextRelation;
import org.apache.ctakes.typesystem.type.relation.CollectionTextRelationIdentifiedAnnotationRelation;
import org.apache.ctakes.typesystem.type.relation.CoreferenceRelation;
import org.apache.ctakes.typesystem.type.textsem.AnatomicalSiteMention;
import org.apache.ctakes.typesystem.type.textsem.DiseaseDisorderMention;
import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation;
import org.apache.ctakes.typesystem.type.textsem.Markable;
import org.apache.ctakes.typesystem.type.textsem.MedicationMention;
import org.apache.ctakes.typesystem.type.textsem.ProcedureMention;
import org.apache.ctakes.typesystem.type.textsem.SignSymptomMention;
import org.apache.ctakes.typesystem.type.textspan.Segment;
import org.apache.ctakes.utils.struct.CounterMap;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.FeatureStructure;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.EmptyFSList;
import org.apache.uima.jcas.cas.FSArray;
import org.apache.uima.jcas.cas.NonEmptyFSList;
import org.apache.uima.resource.ResourceInitializationException;
import org.cleartk.ml.CleartkAnnotator;
import org.cleartk.ml.CleartkProcessingException;
import org.cleartk.ml.DataWriter;
import org.cleartk.ml.Feature;
import org.cleartk.ml.Instance;
import org.cleartk.ml.feature.extractor.FeatureExtractor1;
import org.cleartk.util.ViewUriUtil;

@PipeBitInfo(name = "Coreference (Clusters)", description = "Coreference annotator using mention-synchronous paradigm.", dependencies = {PipeBitInfo.TypeProduct.BASE_TOKEN, PipeBitInfo.TypeProduct.SENTENCE, PipeBitInfo.TypeProduct.SECTION, PipeBitInfo.TypeProduct.IDENTIFIED_ANNOTATION, PipeBitInfo.TypeProduct.MARKABLE}, products = {PipeBitInfo.TypeProduct.COREFERENCE_RELATION})
/* loaded from: input_file:org/apache/ctakes/coreference/ae/MentionClusterCoreferenceAnnotator.class */
public class MentionClusterCoreferenceAnnotator extends CleartkAnnotator<String> {
    /** Label string {@code "-NONE-"}; the same value this class uses for a cluster/mention pair that is not linked. */
    public static final String NO_RELATION_CATEGORY = "-NONE-";
    /** Configuration parameter name bound to {@link #probabilityOfKeepingANegativeExample}. */
    public static final String PARAM_PROBABILITY_OF_KEEPING_A_NEGATIVE_EXAMPLE = "ProbabilityOfKeepingANegativeExample";
    /** Configuration parameter name bound to {@code useExistingEncoders}. */
    public static final String PARAM_USE_EXISTING_ENCODERS = "UseExistingEncoders";
    // Static, so shared by every instance of this annotator loaded by the same class loader.
    // Written in initialize() by a training instance and read back there when useExistingEncoders is set.
    // NOTE(review): no synchronization guards this field - confirm engines are initialized from one thread.
    private static DataWriter<String> classDataWriter = null;

    // Compared against a random draw in getRelationCategory(): a pair without a gold relation is kept as a
    // "-NONE-" example only when the draw is <= this value.
    @ConfigurationParameter(name = "ProbabilityOfKeepingANegativeExample", mandatory = false, description = "probability that a negative example should be retained for training")
    protected double probabilityOfKeepingANegativeExample = 0.5d;

    // When true and a shared data writer has already been stored, initialize() swaps it in for this instance's writer.
    @ConfigurationParameter(name = PARAM_USE_EXISTING_ENCODERS, mandatory = false, description = "Whether to use encoders in output directory during data writing; if we are making multiple calls")
    private boolean useExistingEncoders = false;
    // Fixed seed (0), so the sequence of draws used for negative-example sampling is repeatable.
    protected Random coin = new Random(0);
    // true: process() links a mention to the first candidate cluster classified as a relation and stops looking.
    // false: process() keeps the candidate with the highest classifier score and links it after all are seen.
    boolean greedyFirst = true;
    // The three lists below are filled by overridable methods at field-initialization time, i.e. before any
    // subclass constructor body has run; overrides must not rely on subclass state.
    private List<RelationFeaturesExtractor<CollectionTextRelation, IdentifiedAnnotation>> relationExtractors = getFeatureExtractors();
    private List<FeatureExtractor1<Markable>> mentionExtractors = getMentionExtractors();
    private List<ClusterMentionPairer_ImplBase> pairExtractors = getPairExtractors();
    /* loaded from: input_file:org/apache/ctakes/coreference/ae/MentionClusterCoreferenceAnnotator$CollectionTextRelationIdentifiedAnnotationPair.class */
    /**
     * Immutable (cluster, mention) pair used as a hash key.
     *
     * <p>Two pairs are equal when they refer to the very same cluster and mention
     * objects (reference identity, not {@code equals}).
     */
    public static class CollectionTextRelationIdentifiedAnnotationPair {
        private final CollectionTextRelation cluster;
        private final IdentifiedAnnotation mention;

        /**
         * @param collectionTextRelation the cluster side of the pair
         * @param identifiedAnnotation the mention side of the pair
         */
        public CollectionTextRelationIdentifiedAnnotationPair(CollectionTextRelation collectionTextRelation, IdentifiedAnnotation identifiedAnnotation) {
            this.cluster = collectionTextRelation;
            this.mention = identifiedAnnotation;
        }

        /** @return the cluster given at construction */
        public final CollectionTextRelation getCluster() {
            return this.cluster;
        }

        /** @return the mention given at construction */
        public final IdentifiedAnnotation getMention() {
            return this.mention;
        }

        /**
         * Identity-based comparison of both components.
         *
         * @param obj object to compare with; may be {@code null} or of any type
         * @return {@code true} only when {@code obj} is a pair holding the same cluster and
         *         mention references; {@code false} for {@code null} or a foreign type
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            // instanceof also rejects null, so neither a null nor a foreign argument can throw.
            if (!(obj instanceof CollectionTextRelationIdentifiedAnnotationPair)) {
                return false;
            }
            CollectionTextRelationIdentifiedAnnotationPair other = (CollectionTextRelationIdentifiedAnnotationPair) obj;
            return this.cluster == other.cluster && this.mention == other.mention;
        }

        /**
         * Combines the hash codes of both components; a {@code null} component contributes 0.
         *
         * @return hash code consistent with {@link #equals(Object)}
         */
        @Override
        public int hashCode() {
            int clusterHash = this.cluster == null ? 0 : this.cluster.hashCode();
            int mentionHash = this.mention == null ? 0 : this.mention.hashCode();
            return (31 * clusterHash) + mentionHash;
        }
    }

    /**
     * Builds an engine description for this annotator with {@code isTraining} set to true.
     *
     * @param cls data writer class, passed as the {@code dataWriterClassName} setting
     * @param file passed as the {@code outputDirectory} setting
     * @param f passed (boxed as a {@code Float}) as the {@code ProbabilityOfKeepingANegativeExample} setting
     * @return the engine description produced by {@link AnalysisEngineFactory}
     * @throws ResourceInitializationException if the factory cannot build the description
     */
    public static AnalysisEngineDescription createDataWriterDescription(Class<? extends DataWriter<String>> cls, File file, float f) throws ResourceInitializationException {
        // Alternating name/value configuration entries.
        final Object[] trainingSettings = {
            "isTraining", true,
            PARAM_PROBABILITY_OF_KEEPING_A_NEGATIVE_EXAMPLE, Float.valueOf(f),
            "dataWriterClassName", cls,
            "outputDirectory", file
        };
        return AnalysisEngineFactory.createEngineDescription(MentionClusterCoreferenceAnnotator.class, trainingSettings);
    }

    /**
     * Builds an engine description for this annotator with {@code isTraining} set to false.
     *
     * @param str passed as the {@code classifierJarPath} setting
     * @return the engine description produced by {@link AnalysisEngineFactory}
     * @throws ResourceInitializationException if the factory cannot build the description
     */
    public static AnalysisEngineDescription createAnnotatorDescription(String str) throws ResourceInitializationException {
        // Alternating name/value configuration entries.
        final Object[] classificationSettings = {
            "isTraining", false,
            "classifierJarPath", str
        };
        return AnalysisEngineFactory.createEngineDescription(MentionClusterCoreferenceAnnotator.class, classificationSettings);
    }

    /**
     * Creates the cluster/mention feature extractors, in the order they are applied.
     *
     * <p>The semantic-type dependency-preference extractor is best effort: if constructing it
     * raises an {@link IOException}, the stack trace is printed and the list is returned without it.
     *
     * @return a new mutable list of extractors
     */
    protected List<RelationFeaturesExtractor<CollectionTextRelation, IdentifiedAnnotation>> getFeatureExtractors() {
        final List<RelationFeaturesExtractor<CollectionTextRelation, IdentifiedAnnotation>> extractors = new ArrayList<>();
        extractors.add(new MentionClusterAgreementFeaturesExtractor());
        extractors.add(new MentionClusterStringFeaturesExtractor());
        extractors.add(new MentionClusterSectionFeaturesExtractor());
        extractors.add(new MentionClusterUMLSFeatureExtractor());
        extractors.add(new MentionClusterDepHeadExtractor());
        extractors.add(new MentionClusterStackFeaturesExtractor());
        extractors.add(new MentionClusterSalienceFeaturesExtractor());
        extractors.add(new MentionClusterAttributeFeaturesExtractor());
        try {
            extractors.add(new MentionClusterSemTypeDepPrefsFeatureExtractor());
        } catch (IOException loadFailure) {
            // Deliberately non-fatal: the annotator still runs with the remaining extractors.
            loadFailure.printStackTrace();
        }
        return extractors;
    }

    /**
     * Creates the extractors that are applied to the mention alone (no cluster argument),
     * in the order they are applied.
     *
     * @return a new mutable list of extractors
     */
    protected List<FeatureExtractor1<Markable>> getMentionExtractors() {
        final List<FeatureExtractor1<Markable>> extractors = new ArrayList<>();
        extractors.add(new MentionClusterAgreementFeaturesExtractor());
        extractors.add(new MentionClusterSectionFeaturesExtractor());
        extractors.add(new MentionClusterUMLSFeatureExtractor());
        extractors.add(new MentionClusterDepHeadExtractor());
        extractors.add(new MentionClusterSalienceFeaturesExtractor());
        extractors.add(new MentionClusterAttributeFeaturesExtractor());
        return extractors;
    }

    /**
     * Creates the pairers that propose candidate clusters for a mention. Their results are
     * merged, in this order, by {@link #getCandidateRelationArgumentPairs}.
     *
     * @return a new mutable list of pairers
     */
    protected List<ClusterMentionPairer_ImplBase> getPairExtractors() {
        final List<ClusterMentionPairer_ImplBase> pairers = new ArrayList<>();
        // NOTE(review): the numeric arguments look like distance/size limits - confirm against the pairer classes.
        pairers.add(new SentenceDistancePairer(5));
        pairers.add(new SectionHeaderPairer(5));
        pairers.add(new ClusterPairer(Integer.MAX_VALUE));
        pairers.add(new HeadwordPairer());
        return pairers;
    }

    /**
     * Collects the candidate (cluster, mention) pairs proposed by every configured pairer.
     *
     * @param jCas the CAS being processed
     * @param markable the mention to find candidate clusters for
     * @return the union of all pairers' proposals, without duplicates, in first-seen order
     */
    protected Iterable<CollectionTextRelationIdentifiedAnnotationPair> getCandidateRelationArgumentPairs(JCas jCas, Markable markable) {
        final LinkedHashSet<CollectionTextRelationIdentifiedAnnotationPair> candidates = new LinkedHashSet<>();
        for (ClusterMentionPairer_ImplBase pairer : this.pairExtractors) {
            candidates.addAll(pairer.getPairs(jCas, markable));
        }
        return candidates;
    }

    /**
     * Calls {@code reset} on every configured pairer, passing the given CAS.
     *
     * @param jCas the CAS about to be processed
     */
    private void resetPairers(JCas jCas) {
        for (ClusterMentionPairer_ImplBase pairer : this.pairExtractors) {
            pairer.reset(jCas);
        }
    }

    /**
     * Initializes the annotator and wires up the class-wide shared data writer.
     *
     * <p>If existing encoders were requested and a shared writer has already been stored, this
     * instance adopts it. Otherwise, a training instance stores its own writer as the shared one.
     *
     * @param uimaContext the UIMA context handed to the superclass
     * @throws ResourceInitializationException if superclass initialization fails
     */
    @Override
    public void initialize(UimaContext uimaContext) throws ResourceInitializationException {
        super.initialize(uimaContext);

        final boolean adoptSharedWriter = this.useExistingEncoders && classDataWriter != null;
        if (adoptSharedWriter) {
            this.dataWriter = classDataWriter;
            return;
        }
        if (isTraining()) {
            classDataWriter = this.dataWriter;
        }
    }

    /**
     * Walks every markable of every segment and either writes training instances or links the
     * markable to an existing cluster.
     *
     * <p>Training: existing cluster memberships are first indexed as gold relations. For each
     * candidate cluster a labelled instance is written (negatives are sub-sampled by
     * {@link #getRelationCategory}); the search stops at the first positive instance.
     *
     * <p>Classification: each candidate cluster is classified. With {@code greedyFirst} the first
     * cluster classified as a relation is linked; otherwise the candidate with the highest score
     * (strictly above 0) is linked once all candidates have been seen.
     *
     * <p>A markable that ends up linked to no cluster becomes a new one-member cluster. Finally
     * clusters that still have a single member are removed and an event is created per cluster.
     *
     * @param jCas the CAS to process
     * @throws AnalysisEngineProcessException if feature extraction, classification, data writing
     *         or event creation fails
     */
    @Override
    public void process(JCas jCas) throws AnalysisEngineProcessException {
        resetPairers(jCas);

        Map<CollectionTextRelationIdentifiedAnnotationPair, CollectionTextRelationIdentifiedAnnotationRelation> goldRelations = new HashMap<>();
        if (isTraining()) {
            goldRelations = indexGoldRelations(jCas);
        }

        for (Segment segment : JCasUtil.select(jCas, Segment.class)) {
            for (Markable mention : JCasUtil.selectCovered(jCas, Markable.class, segment)) {
                boolean singleton = true;
                double bestScore = 0.0d;
                CollectionTextRelation bestCluster = null;

                for (CollectionTextRelationIdentifiedAnnotationPair candidate : getCandidateRelationArgumentPairs(jCas, mention)) {
                    CollectionTextRelation cluster = candidate.getCluster();
                    List<Feature> features = extractFeatures(jCas, cluster, mention);

                    if (isTraining()) {
                        String category = getRelationCategory(goldRelations, cluster, mention);
                        if (category == null) {
                            // Negative example that lost the sampling draw: write nothing.
                            continue;
                        }
                        this.dataWriter.write(new Instance<String>(category, features));
                        if (!category.equals(NO_RELATION_CATEGORY)) {
                            // Gold link found; later candidates are not visited.
                            singleton = false;
                            break;
                        }
                    } else {
                        String predicted = classify(features);
                        if (predicted.equals(NO_RELATION_CATEGORY)) {
                            continue;
                        }
                        // Scores are only needed for predicted links, so compute them after the check.
                        Map<String, Double> scores = this.classifier.score(features);
                        Double confidence = scores.get(predicted);
                        if (this.greedyFirst) {
                            createRelation(jCas, cluster, mention, predicted, confidence);
                            singleton = false;
                            break;
                        }
                        if (confidence.doubleValue() > bestScore) {
                            bestScore = confidence.doubleValue();
                            bestCluster = cluster;
                        }
                    }
                }

                if (!isTraining() && !this.greedyFirst && bestCluster != null) {
                    createRelation(jCas, bestCluster, mention, MentionClusterRankingCoreferenceAnnotator.CLUSTER_RELATION_CATEGORY, Double.valueOf(bestScore));
                    // The mention now belongs to bestCluster; it must not also seed a cluster of its own.
                    singleton = false;
                }

                if (singleton) {
                    createSingletonCluster(jCas, mention);
                }
            }
        }

        removeSingletonClusters(jCas);
        createEventClusters(jCas);
    }

    /**
     * Adds a cluster/mention relation to the indexes for every markable member of every existing
     * cluster and returns them keyed by (cluster, mention). A duplicate key is reported on
     * standard error and the later relation replaces the earlier one.
     */
    private Map<CollectionTextRelationIdentifiedAnnotationPair, CollectionTextRelationIdentifiedAnnotationRelation> indexGoldRelations(JCas jCas) throws AnalysisEngineProcessException {
        Map<CollectionTextRelationIdentifiedAnnotationPair, CollectionTextRelationIdentifiedAnnotationRelation> goldRelations = new HashMap<>();
        for (CollectionTextRelation cluster : JCasUtil.select(jCas, CollectionTextRelation.class)) {
            for (Markable member : JCasUtil.select(cluster.getMembers(), Markable.class)) {
                CollectionTextRelationIdentifiedAnnotationRelation relation = new CollectionTextRelationIdentifiedAnnotationRelation(jCas);
                relation.setCluster(cluster);
                relation.setMention(member);
                relation.setCategory(MentionClusterRankingCoreferenceAnnotator.CLUSTER_RELATION_CATEGORY);
                relation.addToIndexes();

                CollectionTextRelationIdentifiedAnnotationPair key = new CollectionTextRelationIdentifiedAnnotationPair(cluster, member);
                CollectionTextRelationIdentifiedAnnotationRelation previous = goldRelations.get(key);
                if (previous != null) {
                    System.err.println("Error in: " + ViewUriUtil.getURI(jCas).toString());
                    System.err.println("Error! This attempted relation " + relation.getCategory() + " already has a relation " + previous.getCategory() + " at this span: " + member.getCoveredText());
                }
                goldRelations.put(key, relation);
            }
        }
        return goldRelations;
    }

    /**
     * Runs the cluster/mention extractors and then the mention-only extractors, and replaces any
     * null feature value with the string "NULL" (reporting it on standard error).
     */
    private List<Feature> extractFeatures(JCas jCas, CollectionTextRelation cluster, Markable mention) throws AnalysisEngineProcessException {
        List<Feature> features = new ArrayList<>();
        for (RelationFeaturesExtractor<CollectionTextRelation, IdentifiedAnnotation> extractor : this.relationExtractors) {
            List<Feature> extracted = extractor.extract(jCas, cluster, mention);
            if (extracted != null) {
                features.addAll(extracted);
            }
        }
        for (FeatureExtractor1<Markable> extractor : this.mentionExtractors) {
            features.addAll(extractor.extract(jCas, mention));
        }
        for (Feature feature : features) {
            if (feature.getValue() == null) {
                feature.setValue("NULL");
                System.err.println(String.format("Null value found in %s from %s", feature, features));
            }
        }
        return features;
    }

    /** Creates and indexes a new identity cluster whose only member is the given mention. */
    private static void createSingletonCluster(JCas jCas, Markable mention) {
        CollectionTextRelation cluster = new CollectionTextRelation(jCas);
        cluster.setCategory(EventCoreferenceAnnotator.IDENTITY_RELATION);
        NonEmptyFSList members = new NonEmptyFSList(jCas);
        members.setHead(mention);
        members.setTail(new EmptyFSList(jCas));
        cluster.setMembers(members);
        cluster.addToIndexes();
        members.addToIndexes();
        members.getTail().addToIndexes();
    }

    /**
     * Determines the training label for a cluster/mention pair.
     *
     * @param map gold relations keyed by (cluster, mention)
     * @param collectionTextRelation the candidate cluster
     * @param identifiedAnnotation the mention
     * @return the gold relation's category when the pair is in {@code map}; otherwise
     *         {@code "-NONE-"} when a random draw is at most
     *         {@code probabilityOfKeepingANegativeExample}, or {@code null} when the negative
     *         example should be skipped
     */
    protected String getRelationCategory(Map<CollectionTextRelationIdentifiedAnnotationPair, CollectionTextRelationIdentifiedAnnotationRelation> map, CollectionTextRelation collectionTextRelation, IdentifiedAnnotation identifiedAnnotation) {
        final CollectionTextRelationIdentifiedAnnotationPair key = new CollectionTextRelationIdentifiedAnnotationPair(collectionTextRelation, identifiedAnnotation);
        final CollectionTextRelationIdentifiedAnnotationRelation gold = map.get(key);
        if (gold != null) {
            return gold.getCategory();
        }
        // The coin is only consulted for pairs without a gold relation.
        final boolean keepNegative = this.coin.nextDouble() <= this.probabilityOfKeepingANegativeExample;
        if (keepNegative) {
            return NO_RELATION_CATEGORY;
        }
        return null;
    }

    /**
     * Asks the classifier for the outcome label of a feature list.
     *
     * @param list features of one cluster/mention pair
     * @return the label returned by the classifier
     * @throws CleartkProcessingException if the classifier fails
     */
    protected String classify(List<Feature> list) throws CleartkProcessingException {
        final String outcome = (String) this.classifier.classify(list);
        return outcome;
    }

    /**
     * Records that a mention belongs to a cluster: indexes a new cluster/mention relation and
     * appends the mention to the cluster's member list.
     *
     * @param jCas the CAS that receives the relation
     * @param collectionTextRelation the cluster being extended
     * @param identifiedAnnotation the mention being added
     * @param str category stored on the relation
     * @param d confidence stored on the relation; must not be {@code null} (it is unboxed)
     */
    protected void createRelation(JCas jCas, CollectionTextRelation collectionTextRelation, IdentifiedAnnotation identifiedAnnotation, String str, Double d) {
        final double confidence = d.doubleValue();

        final CollectionTextRelationIdentifiedAnnotationRelation link = new CollectionTextRelationIdentifiedAnnotationRelation(jCas);
        link.setCluster(collectionTextRelation);
        link.setMention(identifiedAnnotation);
        link.setCategory(str);
        link.setConfidence(confidence);
        link.addToIndexes();

        ListFactory.append(jCas, collectionTextRelation.getMembers(), identifiedAnnotation);
    }

    /**
     * Creates and indexes one event per cluster in the CAS, holding the cluster's markables as
     * its mentions.
     *
     * <p>For each markable the longest covering annotation from
     * {@link MarkableUtilities#indexCoveringUmlsAnnotations} is taken, and the most frequent class
     * among those annotations decides the event subtype. When no markable has a covering
     * annotation a plain {@link Event} is created.
     *
     * @param jCas the CAS holding the clusters
     * @throws AnalysisEngineProcessException if the most frequent class is not one of the known
     *         mention types
     */
    private static void createEventClusters(JCas jCas) throws AnalysisEngineProcessException {
        Map<Markable, List<IdentifiedAnnotation>> coveringUmls = MarkableUtilities.indexCoveringUmlsAnnotations(jCas);
        for (CollectionTextRelation cluster : JCasUtil.select(jCas, CollectionTextRelation.class)) {
            CounterMap<Class<? extends IdentifiedAnnotation>> typeCounts = new CounterMap<>();
            List<Markable> members = new ArrayList<>(JCasUtil.select(cluster.getMembers(), Markable.class));

            for (Markable member : members) {
                List<IdentifiedAnnotation> covering = coveringUmls.get(member);
                if (covering == null) {
                    // Markable absent from the index: it simply casts no vote.
                    continue;
                }
                IdentifiedAnnotation longest = null;
                for (IdentifiedAnnotation annotation : covering) {
                    if (longest == null || annotation.getEnd() - annotation.getBegin() > longest.getEnd() - longest.getBegin()) {
                        longest = annotation;
                    }
                }
                if (longest != null) {
                    typeCounts.add(longest.getClass());
                }
            }

            FSArray mentions = new FSArray(jCas, members.size());
            for (int i = 0; i < members.size(); i++) {
                mentions.set(i, members.get(i));
            }

            Event event;
            if (typeCounts.size() == 0) {
                event = new Event(jCas);
            } else {
                // Single pass for the highest count; on a tie the first entry in iteration order wins,
                // matching a stable descending sort followed by taking the first element.
                Class<?> dominantType = null;
                int dominantCount = Integer.MIN_VALUE;
                for (Map.Entry<Class<? extends IdentifiedAnnotation>, Integer> entry : typeCounts.entrySet()) {
                    if (entry.getValue().intValue() > dominantCount) {
                        dominantCount = entry.getValue().intValue();
                        dominantType = entry.getKey();
                    }
                }
                event = newEventForMentionType(jCas, dominantType);
            }
            event.setMentions(mentions);
            event.addToIndexes();
        }
    }

    /** Maps a mention class to a new, not yet indexed, event of the corresponding subtype. */
    private static Event newEventForMentionType(JCas jCas, Class<?> mentionType) throws AnalysisEngineProcessException {
        if (mentionType.equals(DiseaseDisorderMention.class)) {
            return new DiseaseDisorder(jCas);
        }
        if (mentionType.equals(ProcedureMention.class)) {
            return new Procedure(jCas);
        }
        if (mentionType.equals(SignSymptomMention.class)) {
            return new SignSymptom(jCas);
        }
        if (mentionType.equals(MedicationMention.class)) {
            return new Medication(jCas);
        }
        if (mentionType.equals(AnatomicalSiteMention.class)) {
            return new AnatomicalSite(jCas);
        }
        // Carry the reason in the exception itself so it is not lost when stderr is not captured.
        throw new AnalysisEngineProcessException(new IllegalStateException(
                "This coreference chain has an unknown type: " + mentionType.getSimpleName()));
    }

    /**
     * Removes from the indexes every cluster whose member list has an empty tail, i.e. holds
     * exactly one member.
     *
     * @param jCas the CAS holding the clusters
     */
    private static void removeSingletonClusters(JCas jCas) {
        // Collect first, remove afterwards, so the index is not modified while it is being iterated.
        final List<CollectionTextRelation> singletons = new ArrayList<>();
        for (CollectionTextRelation cluster : JCasUtil.select(jCas, CollectionTextRelation.class)) {
            final boolean hasSingleMember = cluster.getMembers().getTail() instanceof EmptyFSList;
            if (hasSingleMember) {
                singletons.add(cluster);
            }
        }
        for (CollectionTextRelation singleton : singletons) {
            singleton.removeFromIndexes();
        }
    }

    /**
     * Maps the argument pair of every {@link CoreferenceRelation} in the CAS to that relation's
     * confidence. If two relations share the same arguments, the later one wins.
     *
     * @param jCas the CAS to read relations from
     * @return a new mutable map from argument pair to confidence
     */
    public Map<RelationExtractorEvaluation.HashableArguments, Double> getMarkablePairScores(JCas jCas) {
        final Map<RelationExtractorEvaluation.HashableArguments, Double> scoresByArguments = new HashMap<>();
        for (CoreferenceRelation relation : JCasUtil.select(jCas, CoreferenceRelation.class)) {
            final RelationExtractorEvaluation.HashableArguments arguments =
                    new RelationExtractorEvaluation.HashableArguments(relation.getArg1().getArgument(), relation.getArg2().getArgument());
            scoresByArguments.put(arguments, Double.valueOf(relation.getConfidence()));
        }
        return scoresByArguments;
    }
}
