package org.apache.ctakes.coreference.eval;

import com.google.common.base.Function;
import com.lexicalscope.jewel.cli.CliFactory;
import java.io.File;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.ctakes.assertion.medfacts.cleartk.PolarityCleartkAnalysisEngine;
import org.apache.ctakes.coreference.ae.DeterministicMarkableAnnotator;
import org.apache.ctakes.coreference.ae.MarkableSalienceAnnotator;
import org.apache.ctakes.coreference.eval.EvaluationOfEventCoreference;
import org.apache.ctakes.dependency.parser.util.DependencyUtility;
import org.apache.ctakes.temporal.eval.Evaluation_ImplBase;
import org.apache.ctakes.typesystem.type.relation.CollectionTextRelation;
import org.apache.ctakes.typesystem.type.syntax.ConllDependencyNode;
import org.apache.ctakes.typesystem.type.textsem.Markable;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CASException;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.component.ViewCreatorAnnotator;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.factory.AggregateBuilder;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.pipeline.JCasIterator;
import org.apache.uima.fit.pipeline.SimplePipeline;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.FSList;
import org.apache.uima.jcas.cas.NonEmptyFSList;
import org.cleartk.eval.AnnotationStatistics;
import org.cleartk.ml.jar.JarClassifierBuilder;
import org.cleartk.ml.liblinear.LibLinearBooleanOutcomeDataWriter;

/* loaded from: input_file:org/apache/ctakes/coreference/eval/EvaluationOfMarkableSalience.class */
public class EvaluationOfMarkableSalience extends Evaluation_ImplBase<AnnotationStatistics<Boolean>> {

    /* loaded from: input_file:org/apache/ctakes/coreference/eval/EvaluationOfMarkableSalience$CreatePseudoGoldMarkables.class */
    public static class CreatePseudoGoldMarkables extends JCasAnnotator_ImplBase {
        public static final String PARAM_PSEUDO_GOLD_VIEW = "PseudoViewName";

        @ConfigurationParameter(name = PARAM_PSEUDO_GOLD_VIEW)
        private String fakeGoldName;
        public static final String PARAM_GOLD_VIEW = "GoldViewName";

        @ConfigurationParameter(name = "GoldViewName")
        private String goldViewName;

        public void process(JCas jCas) throws AnalysisEngineProcessException {
            try {
                JCas view = jCas.getView(this.fakeGoldName);
                JCas view2 = jCas.getView(this.goldViewName);
                HashSet hashSet = new HashSet();
                Map indexCovering = JCasUtil.indexCovering(jCas, ConllDependencyNode.class, Markable.class);
                Iterator it = JCasUtil.select(view2, CollectionTextRelation.class).iterator();
                while (it.hasNext()) {
                    FSList members = ((CollectionTextRelation) it.next()).getMembers();
                    do {
                        NonEmptyFSList nonEmptyFSList = (NonEmptyFSList) members;
                        Markable head = nonEmptyFSList.getHead();
                        if (head.getBegin() >= 0 && head.getEnd() < jCas.getDocumentText().length()) {
                            ConllDependencyNode nominalHeadNode = DependencyUtility.getNominalHeadNode(jCas, head);
                            Iterator it2 = ((Collection) indexCovering.get(nominalHeadNode)).iterator();
                            while (true) {
                                if (!it2.hasNext()) {
                                    break;
                                }
                                Markable markable = (Markable) it2.next();
                                if (DependencyUtility.getNominalHeadNode(jCas, markable) == nominalHeadNode) {
                                    hashSet.add(markable);
                                    break;
                                }
                            }
                        }
                        members = nonEmptyFSList.getTail();
                    } while (members instanceof NonEmptyFSList);
                }
                for (Markable markable2 : JCasUtil.select(jCas, Markable.class)) {
                    Markable markable3 = new Markable(view, markable2.getBegin(), markable2.getEnd());
                    if (hashSet.contains(markable2)) {
                        markable3.setConfidence(1.0f);
                    } else {
                        markable3.setConfidence(0.0f);
                    }
                    markable3.addToIndexes();
                }
            } catch (CASException e) {
                throw new AnalysisEngineProcessException(e);
            }
        }
    }

    /* loaded from: input_file:org/apache/ctakes/coreference/eval/EvaluationOfMarkableSalience$SetGoldConfidence.class */
    public static class SetGoldConfidence extends JCasAnnotator_ImplBase {
        public static final String PARAM_GOLD_VIEW = "GoldViewName";

        @ConfigurationParameter(name = "GoldViewName", mandatory = true, description = "View containing gold standard annotations")
        private String goldViewName;

        public void process(JCas jCas) throws AnalysisEngineProcessException {
            try {
                JCas view = jCas.getView(this.goldViewName);
                Map indexCovering = JCasUtil.indexCovering(jCas, ConllDependencyNode.class, Markable.class);
                Iterator it = JCasUtil.select(view, CollectionTextRelation.class).iterator();
                while (it.hasNext()) {
                    FSList members = ((CollectionTextRelation) it.next()).getMembers();
                    do {
                        NonEmptyFSList nonEmptyFSList = (NonEmptyFSList) members;
                        Markable head = nonEmptyFSList.getHead();
                        if (head.getBegin() >= 0 && head.getEnd() < jCas.getDocumentText().length()) {
                            ConllDependencyNode nominalHeadNode = DependencyUtility.getNominalHeadNode(jCas, head);
                            Iterator it2 = ((Collection) indexCovering.get(nominalHeadNode)).iterator();
                            while (true) {
                                if (!it2.hasNext()) {
                                    break;
                                }
                                Markable markable = (Markable) it2.next();
                                if (DependencyUtility.getNominalHeadNode(jCas, markable) == nominalHeadNode) {
                                    markable.setConfidence(1.0f);
                                    break;
                                }
                            }
                        }
                        members = nonEmptyFSList.getTail();
                    } while (members instanceof NonEmptyFSList);
                }
            } catch (CASException e) {
                e.printStackTrace();
                throw new AnalysisEngineProcessException(e);
            }
        }
    }

    /**
     * Command-line entry point: parses the evaluation options, prepares the XMI files for the
     * selected patient sets, runs one train/test round and prints the resulting statistics
     * followed by their confusion matrix.
     *
     * @param strArr command-line arguments, parsed into {@link Evaluation_ImplBase.Options}
     * @throws Exception if argument parsing, XMI preparation, training or testing fails
     */
    public static void main(String[] strArr) throws Exception {
        Evaluation_ImplBase.Options options = CliFactory.parseArguments(Evaluation_ImplBase.Options.class, strArr);
        List<Integer> patientSets = options.getPatients().getList();
        List<Integer> trainItems = getTrainItems(options);
        List<Integer> testItems = getTestItems(options);

        EvaluationOfMarkableSalience evaluation = new EvaluationOfMarkableSalience(
                new File("target/eval/salience"),
                options.getRawTextDirectory(),
                options.getXMLDirectory(),
                options.getXMLFormat(),
                options.getSubcorpus(),
                options.getXMIDirectory(),
                null);
        evaluation.prepareXMIsFor(patientSets);

        AnnotationStatistics<Boolean> stats = evaluation.trainAndTest(trainItems, testItems);
        System.out.println(stats);
        System.out.println(stats.confusions());
    }

    /**
     * Creates the evaluation; every argument is forwarded unchanged to the
     * {@link Evaluation_ImplBase} constructor.
     *
     * @param baseDirectory working directory of the evaluation ({@code main} uses {@code target/eval/salience})
     * @param rawTextDirectory directory supplied by {@code Options.getRawTextDirectory()} in {@code main}
     * @param xmlDirectory directory supplied by {@code Options.getXMLDirectory()} in {@code main}
     * @param xmlFormat format supplied by {@code Options.getXMLFormat()} in {@code main}
     * @param subcorpus subcorpus supplied by {@code Options.getSubcorpus()} in {@code main}
     * @param xmiDirectory directory supplied by {@code Options.getXMIDirectory()} in {@code main}
     * @param treebankDirectory presumably the treebank directory expected by the superclass -
     *        TODO confirm against {@link Evaluation_ImplBase}; {@code main} passes {@code null}
     */
    public EvaluationOfMarkableSalience(File baseDirectory, File rawTextDirectory, File xmlDirectory, Evaluation_ImplBase.XMLFormat xmlFormat, Evaluation_ImplBase.Subcorpus subcorpus, File xmiDirectory, File treebankDirectory) {
        super(baseDirectory, rawTextDirectory, xmlDirectory, xmlFormat, subcorpus, xmiDirectory, treebankDirectory);
    }

    /**
     * Runs the training pipeline over {@code collectionReader} and then trains and packages the
     * classifier in {@code file}.
     *
     * <p>The pipeline extends the preprocessing aggregate with, in order: polarity annotation,
     * document id printing, deterministic markable creation, removal of person markables,
     * {@link SetGoldConfidence} reading the {@code "GoldView"} view, and the salience data writer.
     *
     * @param collectionReader source of the training documents
     * @param file directory that receives the training data and the packaged model
     * @throws Exception if building or running the pipeline, or training the model, fails
     */
    @Override
    protected void train(CollectionReader collectionReader, File file) throws Exception {
        AggregateBuilder pipeline = getPreprocessorAggregateBuilder();
        pipeline.add(PolarityCleartkAnalysisEngine.createAnnotatorDescription());
        pipeline.add(AnalysisEngineFactory.createEngineDescription(EvaluationOfEventCoreference.DocumentIDPrinter.class));
        pipeline.add(AnalysisEngineFactory.createEngineDescription(DeterministicMarkableAnnotator.class));
        pipeline.add(AnalysisEngineFactory.createEngineDescription(EvaluationOfEventCoreference.RemovePersonMarkables.class));
        pipeline.add(AnalysisEngineFactory.createEngineDescription(
                SetGoldConfidence.class,
                SetGoldConfidence.PARAM_GOLD_VIEW, "GoldView"));

        // The explicit array keeps the call bound to the same createEngineDescription overload as before.
        AnalysisEngineDescription[] dataWriter = new AnalysisEngineDescription[] {
            MarkableSalienceAnnotator.createDataWriterDescription(LibLinearBooleanOutcomeDataWriter.class, file)
        };
        pipeline.add(AnalysisEngineFactory.createEngineDescription(dataWriter));

        AnalysisEngine engine = pipeline.createAggregate();
        SimplePipeline.runPipeline(collectionReader, engine);

        JarClassifierBuilder.trainAndPackage(file, "-s", "0", "-c", "1", "-w1", "1");
    }

    /**
     * Runs the test pipeline over {@code collectionReader} and scores the predicted markable
     * salience against the pseudo-gold markables.
     *
     * <p>The pipeline extends the preprocessing aggregate with, in order: polarity annotation,
     * document id printing, deterministic markable creation, removal of person markables,
     * creation of the {@code "PseudoGold"} view, {@link CreatePseudoGoldMarkables}, and the
     * salience annotator loaded from {@code model.jar} inside {@code file}. For every document the
     * markables of the {@code "PseudoGold"} view serve as reference and those of the initial view
     * as predictions; both are reduced to booleans via {@link #mapConfidenceToBoolean()}.
     *
     * <p>The decompiler had renamed this override to {@code m19test}; it is restored here under
     * its original name and visibility so that it overrides the superclass hook again.
     *
     * @param collectionReader source of the test documents
     * @param file directory containing {@code model.jar}
     * @return statistics accumulated over all documents
     * @throws Exception if building or running the pipeline fails
     */
    @Override
    protected AnnotationStatistics<Boolean> test(CollectionReader collectionReader, File file) throws Exception {
        AggregateBuilder pipeline = getPreprocessorAggregateBuilder();
        pipeline.add(PolarityCleartkAnalysisEngine.createAnnotatorDescription());
        pipeline.add(AnalysisEngineFactory.createEngineDescription(EvaluationOfEventCoreference.DocumentIDPrinter.class));
        pipeline.add(AnalysisEngineFactory.createEngineDescription(DeterministicMarkableAnnotator.class));
        pipeline.add(AnalysisEngineFactory.createEngineDescription(EvaluationOfEventCoreference.RemovePersonMarkables.class));
        pipeline.add(AnalysisEngineFactory.createEngineDescription(ViewCreatorAnnotator.class, "viewName", "PseudoGold"));
        pipeline.add(AnalysisEngineFactory.createEngineDescription(
                CreatePseudoGoldMarkables.class,
                CreatePseudoGoldMarkables.PARAM_GOLD_VIEW, "GoldView",
                CreatePseudoGoldMarkables.PARAM_PSEUDO_GOLD_VIEW, "PseudoGold"));
        pipeline.add(MarkableSalienceAnnotator.createAnnotatorDescription(file.getAbsolutePath() + File.separator + "model.jar"));

        AnnotationStatistics<Boolean> stats = new AnnotationStatistics<>();
        JCasIterator documents = new JCasIterator(collectionReader, pipeline.createAggregate());
        while (documents.hasNext()) {
            JCas jCas = documents.next();
            stats.add(
                    JCasUtil.select(jCas.getView("PseudoGold"), Markable.class),
                    JCasUtil.select(jCas.getView("_InitialView"), Markable.class),
                    AnnotationStatistics.<Markable>annotationToSpan(),
                    mapConfidenceToBoolean());
        }
        return stats;
    }

    /**
     * Decompiler-generated alias of {@link #test(CollectionReader, File)}, kept only so that any
     * code referring to this name keeps working.
     *
     * @param collectionReader source of the test documents
     * @param file directory containing {@code model.jar}
     * @return statistics accumulated over all documents
     * @throws Exception if building or running the pipeline fails
     */
    public AnnotationStatistics<Boolean> m19test(CollectionReader collectionReader, File file) throws Exception {
        return test(collectionReader, file);
    }

    /**
     * Builds the function used to turn a markable's confidence into a boolean outcome.
     *
     * @return a function yielding {@code true} exactly when the markable's confidence is
     *         strictly greater than {@code 0.5}
     */
    public static Function<Markable, Boolean> mapConfidenceToBoolean() {
        Function<Markable, Boolean> aboveThreshold = new Function<Markable, Boolean>() {
            @Override
            public Boolean apply(Markable markable) {
                // The float confidence is widened to double for the comparison.
                double confidence = markable.getConfidence();
                return confidence > 0.5d;
            }
        };
        return aboveThreshold;
    }
}
