package org.apache.ctakes.relationextractor.eval;

import com.google.common.base.Function;
import com.google.common.base.Functions;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import com.lexicalscope.jewel.cli.Option;
import java.io.File;
import java.io.FileOutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;
import org.apache.ctakes.core.ae.SHARPKnowtatorXMLReader;
import org.apache.ctakes.core.util.DocumentIDAnnotationUtil;
import org.apache.ctakes.relationextractor.data.GoldAnnotationStatsCalculator;
import org.apache.ctakes.relationextractor.data.Stats;
import org.apache.ctakes.typesystem.type.structured.DocumentID;
import org.apache.uima.UIMAFramework;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.CASException;
import org.apache.uima.cas.impl.XmiCasSerializer;
import org.apache.uima.collection.CollectionReader;
import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
import org.apache.uima.fit.component.ViewCreatorAnnotator;
import org.apache.uima.fit.factory.AggregateBuilder;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.fit.factory.TypeSystemDescriptionFactory;
import org.apache.uima.fit.pipeline.JCasIterator;
import org.apache.uima.jcas.JCas;
import org.apache.uima.util.XMLInputSource;
import org.apache.uima.util.XMLSerializer;
import org.cleartk.eval.AnnotationStatistics;
import org.cleartk.util.ViewUriUtil;
import org.cleartk.util.ae.UriToDocumentTextAnnotator;
import org.cleartk.util.cr.UriCollectionReader;

/* loaded from: input_file:org/apache/ctakes/relationextractor/eval/SHARPXMI.class */
public class SHARPXMI {
    public static final String GOLD_VIEW_NAME = "GoldView";

    /* renamed from: org.apache.ctakes.relationextractor.eval.SHARPXMI$1, reason: invalid class name */
    /* loaded from: input_file:org/apache/ctakes/relationextractor/eval/SHARPXMI$1.class */
    /*
     * Decompiler rendering of a synthetic lookup table for switching over EvaluateOn.
     * The array is indexed by an EvaluateOn constant's ordinal() and holds a small
     * positive case index: TRAIN -> 1, DEV -> 2, TEST -> 3.
     */
    static /* synthetic */ class AnonymousClass1 {
        // One slot per EvaluateOn constant; slots start at the int default of 0.
        static final /* synthetic */ int[] $SwitchMap$org$apache$ctakes$relationextractor$eval$SHARPXMI$EvaluateOn = new int[EvaluateOn.values().length];

        static {
            // Each assignment is guarded separately: if a constant is missing at run time
            // (NoSuchFieldError), its slot is simply left at 0 and the other entries are
            // still filled in.
            try {
                $SwitchMap$org$apache$ctakes$relationextractor$eval$SHARPXMI$EvaluateOn[EvaluateOn.TRAIN.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$apache$ctakes$relationextractor$eval$SHARPXMI$EvaluateOn[EvaluateOn.DEV.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$apache$ctakes$relationextractor$eval$SHARPXMI$EvaluateOn[EvaluateOn.TEST.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
        }
    }

    /* loaded from: input_file:org/apache/ctakes/relationextractor/eval/SHARPXMI$CopyDocumentTextToGoldView.class */
    /**
     * Annotator that copies the document text of the CAS it is given into the
     * view named "GoldView".
     */
    public static class CopyDocumentTextToGoldView extends JCasAnnotator_ImplBase {
        /**
         * Sets the "GoldView" document text to the document text of {@code jCas}.
         *
         * @param jCas the CAS whose text is copied
         * @throws AnalysisEngineProcessException wrapping the {@link CASException}
         *         raised when the "GoldView" view cannot be obtained
         */
        public void process(JCas jCas) throws AnalysisEngineProcessException {
            JCas goldView;
            try {
                goldView = jCas.getView("GoldView");
            } catch (CASException viewFailure) {
                throw new AnalysisEngineProcessException(viewFailure);
            }
            goldView.setDocumentText(jCas.getDocumentText());
        }
    }

    /* loaded from: input_file:org/apache/ctakes/relationextractor/eval/SHARPXMI$DocumentIDAnnotator.class */
    /**
     * Annotator that records the path of the document's source URI as a
     * {@link DocumentID} annotation.
     */
    public static class DocumentIDAnnotator extends JCasAnnotator_ImplBase {
        /**
         * Creates a {@link DocumentID} whose value is the file path derived from
         * the URI returned by {@link ViewUriUtil#getURI(JCas)} and adds it to the indexes.
         *
         * @param jCas the CAS to annotate
         */
        public void process(JCas jCas) throws AnalysisEngineProcessException {
            File sourceFile = new File(ViewUriUtil.getURI(jCas));
            DocumentID idAnnotation = new DocumentID(jCas);
            idAnnotation.setDocumentID(sourceFile.getPath());
            idAnnotation.addToIndexes();
        }
    }

    /* loaded from: input_file:org/apache/ctakes/relationextractor/eval/SHARPXMI$EvaluateOn.class */
    /**
     * The data set an evaluation is performed on: training, development or test data.
     */
    public enum EvaluateOn {
        /** Evaluate using the training data. */
        TRAIN,
        /** Evaluate using the development data. */
        DEV,
        /** Evaluate using the test data. */
        TEST
    }

    /* loaded from: input_file:org/apache/ctakes/relationextractor/eval/SHARPXMI$EvaluationOptions.class */
    /**
     * Command-line options for running an evaluation; these are declared in addition
     * to everything inherited from {@link Options}.
     */
    public interface EvaluationOptions extends Options {
        /**
         * Returns the data set to evaluate on (option {@code evaluate-on}, declared default {@code DEV}).
         * The accessor name is spelled "Evalute" (sic); code in this file calls it by that name.
         */
        @Option(longName = {"evaluate-on"}, defaultValue = {"DEV"}, description = "perform evaluation using the training (TRAIN), development (DEV) or test (TEST) data.")
        EvaluateOn getEvaluteOn();

        /**
         * Returns whether a grid search over parameter settings was requested (option {@code grid-search}).
         */
        @Option(longName = {"grid-search"}, description = "run a grid search to select the best parameters")
        boolean getGridSearch();
    }

    /* loaded from: input_file:org/apache/ctakes/relationextractor/eval/SHARPXMI$Evaluation_ImplBase.class */
    /**
     * Base class for evaluations whose items are {@link File}s and whose results are
     * {@code AnnotationStatistics<String>}; the items are read through {@code XMIReader}.
     */
    public static abstract class Evaluation_ImplBase extends org.cleartk.eval.Evaluation_ImplBase<File, AnnotationStatistics<String>> {
        /**
         * @param file passed straight through to the superclass constructor
         */
        public Evaluation_ImplBase(File file) {
            super(file);
        }

        /**
         * Builds an {@code XMIReader} collection reader whose {@code PARAM_FILES}
         * parameter is set to the given files.
         *
         * @param list the files handed to the reader
         * @return the configured collection reader
         */
        public CollectionReader getCollectionReader(List<File> list) throws Exception {
            Object[] readerConfiguration = {XMIReader.PARAM_FILES, list};
            return CollectionReaderFactory.createReader(
                    XMIReader.class,
                    TypeSystemDescriptionFactory.createTypeSystemDescription(),
                    readerConfiguration);
        }
    }

    /* loaded from: input_file:org/apache/ctakes/relationextractor/eval/SHARPXMI$Options.class */
    /**
     * Command-line options describing where the annotated batches live and where
     * their XMI serializations are stored.
     */
    public interface Options {
        /**
         * Returns the directory holding the {@code ssN_batchNN} batch directories (option {@code batches-dir}).
         */
        @Option(longName = {"batches-dir"}, description = "directory containing ssN_batchNN directories, each of which should contain a Knowtator directory and a Knowtator_XML directory")
        File getBatchesDirectory();

        /**
         * Returns the directory XMI files are written to and read from
         * (option {@code xmi-dir}, declared default {@code target/xmi}).
         */
        @Option(longName = {"xmi-dir"}, defaultValue = {"target/xmi"}, description = "directory to store and load XMI serialization of annotations")
        File getXMIDirectory();

        /**
         * Returns whether the gold annotations should be read and serialized as XMI (option {@code generate-xmi}).
         */
        @Option(longName = {"generate-xmi"}, description = "read in the gold annotations and serialize them as XMI")
        boolean getGenerateXMI();
    }

    /**
     * Collects the text files of the batch directories that make up the training set.
     *
     * @param file the directory containing the batch directories
     * @return the text files found in the matching batches
     */
    public static List<File> getTrainTextFiles(File file) {
        Pattern trainBatches = Pattern.compile("^(ss[1234]_batch0[2-9]|ss[1234]_batch1[56]|ss[1234]_batch1[89]|ss[123]_batch01|ss[12]_batch1[34]|ss[34]_batch1[12])$");
        return getTextFilesFor(file, trainBatches);
    }

    /**
     * Collects the text files of the batch directories that make up the development set
     * (batches 10 and 17 of ss1 through ss4).
     *
     * @param file the directory containing the batch directories
     * @return the text files found in the matching batches
     */
    public static List<File> getDevTextFiles(File file) {
        Pattern devBatches = Pattern.compile("^(ss[1234]_batch1[07])$");
        return getTextFilesFor(file, devBatches);
    }

    /**
     * Collects the text files of the batch directories that make up the test set.
     *
     * @param file the directory containing the batch directories
     * @return the text files found in the matching batches
     */
    public static List<File> getTestTextFiles(File file) {
        Pattern testBatches = Pattern.compile("^(ss[12]_batch1[12]|ss[34]_batch1[34])$");
        return getTextFilesFor(file, testBatches);
    }

    /**
     * Collects the text files of every batch directory, whatever its name.
     *
     * @param file the directory containing the batch directories
     * @return the text files found in all batches
     */
    public static List<File> getAllTextFiles(File file) {
        // The empty pattern is searched for in each directory name, so no name is filtered out.
        Pattern anyBatch = Pattern.compile("");
        return getTextFilesFor(file, anyBatch);
    }

    /**
     * Collects the regular, non-hidden files found in the {@code Knowtator/text}
     * sub-directory of every visible batch directory whose name contains a match
     * for {@code pattern}.
     *
     * @param file    the directory containing the batch directories
     * @param pattern searched for (via {@code find()}) in each batch directory name
     * @return a new mutable list of the text files, in the order the file system lists them
     * @throws IllegalArgumentException if {@code file}, or the {@code Knowtator/text}
     *         directory of a selected batch, cannot be listed (missing, not a directory,
     *         or unreadable)
     */
    private static List<File> getTextFilesFor(File file, Pattern pattern) {
        // listFiles() returns null rather than throwing when the path is not a listable
        // directory; report that with the offending path instead of a bare NullPointerException.
        File[] batchDirectories = file.listFiles();
        if (batchDirectories == null) {
            throw new IllegalArgumentException("Cannot list batches directory: " + file);
        }
        List<File> textFiles = new ArrayList<File>();
        for (File batchDirectory : batchDirectories) {
            boolean selected = batchDirectory.isDirectory()
                    && !batchDirectory.isHidden()
                    && pattern.matcher(batchDirectory.getName()).find();
            if (!selected) {
                continue;
            }
            File textDirectory = new File(batchDirectory, "Knowtator/text");
            File[] candidates = textDirectory.listFiles();
            if (candidates == null) {
                // Fail loudly: silently skipping a selected batch would shrink the data set unnoticed.
                throw new IllegalArgumentException("Cannot list text directory: " + textDirectory);
            }
            for (File candidate : candidates) {
                if (candidate.isFile() && !candidate.isHidden()) {
                    textFiles.add(candidate);
                }
            }
        }
        return textFiles;
    }

    /**
     * Maps each given file to its XMI counterpart in the configured XMI directory.
     *
     * @param options supplies the XMI directory
     * @param list    the files to map
     * @return a new list holding one XMI file per input file, in the same order
     */
    public static List<File> toXMIFiles(Options options, List<File> list) {
        List<File> xmiFiles = new ArrayList<File>();
        for (File textFile : list) {
            xmiFiles.add(toXMIFile(options, textFile));
        }
        return xmiFiles;
    }

    /**
     * Returns the XMI file for {@code file}: a file in the configured XMI directory
     * named after {@code file}'s simple name with ".xmi" appended.
     */
    private static File toXMIFile(Options options, File file) {
        File xmiDirectory = options.getXMIDirectory();
        String xmiName = file.getName() + ".xmi";
        return new File(xmiDirectory, xmiName);
    }

    /**
     * Runs the preprocessing and gold-annotation pipeline over the train, dev and test
     * text files and writes one XMI file per document into the XMI directory.
     * Does nothing unless {@code options.getGenerateXMI()} is true.
     *
     * @param options supplies the generate flag, the batches directory and the XMI directory
     * @throws IllegalArgumentException if a processed CAS carries no document ID
     * @throws Exception if building or running the pipeline, or writing an XMI file, fails
     */
    public static void generateXMI(Options options) throws Exception {
        if (!options.getGenerateXMI()) {
            return;
        }
        if (!options.getXMIDirectory().exists()) {
            options.getXMIDirectory().mkdirs();
        }

        // Every document of all three splits is serialized.
        List<File> textFiles = new ArrayList<File>();
        textFiles.addAll(getTrainTextFiles(options.getBatchesDirectory()));
        textFiles.addAll(getDevTextFiles(options.getBatchesDirectory()));
        textFiles.addAll(getTestTextFiles(options.getBatchesDirectory()));
        CollectionReader reader = UriCollectionReader.getCollectionReaderFromFiles(textFiles);

        AggregateBuilder pipeline = new AggregateBuilder();
        pipeline.add(UriToDocumentTextAnnotator.getDescription());
        pipeline.add(UIMAFramework.getXMLParser().parseAnalysisEngineDescription(new XMLInputSource(new File("desc/analysis_engine/RelationExtractorPreprocessor.xml"))));
        pipeline.add(AnalysisEngineFactory.createEngineDescription(ViewCreatorAnnotator.class, "viewName", "GoldView"));
        pipeline.add(AnalysisEngineFactory.createEngineDescription(CopyDocumentTextToGoldView.class));
        // The trailing name pair maps the component's system view onto "GoldView",
        // so the last two components operate on the gold view.
        pipeline.add(AnalysisEngineFactory.createEngineDescription(DocumentIDAnnotator.class), GoldAnnotationStatsCalculator.systemViewName, "GoldView");
        pipeline.add(AnalysisEngineFactory.createEngineDescription(SHARPKnowtatorXMLReader.class, "SetDefaults", true), GoldAnnotationStatsCalculator.systemViewName, "GoldView");

        JCasIterator documents = new JCasIterator(reader, new AnalysisEngine[]{pipeline.createAggregate()});
        while (documents.hasNext()) {
            JCas jCas = documents.next();
            String documentID = DocumentIDAnnotationUtil.getDocumentID(jCas.getView("GoldView"));
            if (documentID == null) {
                throw new IllegalArgumentException("No documentID for CAS:\n" + jCas);
            }
            FileOutputStream xmiStream = new FileOutputStream(toXMIFile(options, new File(documentID)));
            try {
                XmiCasSerializer serializer = new XmiCasSerializer(jCas.getTypeSystem());
                serializer.serialize(jCas.getCas(), new XMLSerializer(xmiStream).getContentHandler());
            } finally {
                // Release the file handle even when serialization fails part-way.
                xmiStream.close();
            }
        }
    }

    /**
     * Rejects option combinations that make no sense: a grid search may not be
     * combined with evaluation on the test set.
     *
     * @param evaluationOptions the options to check
     * @throws IllegalArgumentException if a grid search on the test set was requested
     */
    public static void validate(EvaluationOptions evaluationOptions) throws Exception {
        boolean onTestSet = evaluationOptions.getEvaluteOn().equals(EvaluateOn.TEST);
        if (!onTestSet) {
            return;
        }
        if (evaluationOptions.getGridSearch()) {
            throw new IllegalArgumentException("grid search can only be run on the train or dev sets");
        }
    }

    /**
     * Evaluates one or more parameter settings and reports the results on standard error.
     *
     * <p>With grid search enabled every entry of {@code list} is evaluated, otherwise only
     * {@code parameterSettings}. Depending on {@code evaluationOptions.getEvaluteOn()}:
     * <ul>
     * <li>{@code TRAIN}: 2-fold cross-validation over the training files;</li>
     * <li>{@code DEV}: train on the training files, test on the development files;</li>
     * <li>{@code TEST}: train on training plus development files, test on the test files.</li>
     * </ul>
     * The resulting statistics are stored in each settings object's {@code stats} field.
     * When more than one setting was evaluated a summary sorted by ascending F1 is printed,
     * followed in all non-empty cases by the details of the setting with the highest F1.
     *
     * @param evaluationOptions selects the data set, grid search, and file locations
     * @param parameterSettings the single setting used when grid search is off
     * @param list              the settings to try when grid search is on
     * @param function          creates the evaluation to run for a given setting
     * @throws IllegalArgumentException if the selected data set is not TRAIN, DEV or TEST
     * @throws Exception if an evaluation fails
     */
    public static <T extends Evaluation_ImplBase> void evaluate(EvaluationOptions evaluationOptions, ParameterSettings parameterSettings, List<ParameterSettings> list, Function<ParameterSettings, T> function) throws Exception {
        List<ParameterSettings> candidates;
        if (evaluationOptions.getGridSearch()) {
            candidates = list;
        } else {
            candidates = Collections.singletonList(parameterSettings);
        }

        // Run one evaluation per candidate and remember its F1 for ranking.
        HashMap<ParameterSettings, Double> f1BySettings = new HashMap<ParameterSettings, Double>();
        for (ParameterSettings settings : candidates) {
            Evaluation_ImplBase evaluation = function.apply(settings);
            // Switch on the enum itself rather than on an ordinal lookup table.
            switch (evaluationOptions.getEvaluteOn()) {
                case TRAIN: {
                    List<File> trainFiles = toXMIFiles(evaluationOptions, getTrainTextFiles(evaluationOptions.getBatchesDirectory()));
                    settings.stats = AnnotationStatistics.addAll(evaluation.crossValidation(trainFiles, 2));
                    break;
                }
                case DEV: {
                    List<File> trainFiles = toXMIFiles(evaluationOptions, getTrainTextFiles(evaluationOptions.getBatchesDirectory()));
                    List<File> devFiles = toXMIFiles(evaluationOptions, getDevTextFiles(evaluationOptions.getBatchesDirectory()));
                    settings.stats = evaluation.trainAndTest(trainFiles, devFiles);
                    break;
                }
                case TEST: {
                    // The final model is trained on everything except the test data.
                    List<File> trainAndDevTextFiles = new ArrayList<File>();
                    trainAndDevTextFiles.addAll(getTrainTextFiles(evaluationOptions.getBatchesDirectory()));
                    trainAndDevTextFiles.addAll(getDevTextFiles(evaluationOptions.getBatchesDirectory()));
                    List<File> trainFiles = toXMIFiles(evaluationOptions, trainAndDevTextFiles);
                    List<File> testFiles = toXMIFiles(evaluationOptions, getTestTextFiles(evaluationOptions.getBatchesDirectory()));
                    settings.stats = evaluation.trainAndTest(trainFiles, testFiles);
                    break;
                }
                default:
                    throw new IllegalArgumentException("Invalid EvaluateOn: " + evaluationOptions.getEvaluteOn());
            }
            f1BySettings.put(settings, settings.stats.f1());
        }

        // Ascending by F1, so the best setting ends up last.
        List<ParameterSettings> ranked = new ArrayList<ParameterSettings>(f1BySettings.keySet());
        Function<ParameterSettings, Double> f1Of = Functions.forMap(f1BySettings);
        Collections.sort(ranked, Ordering.<Double>natural().onResultOf(f1Of));

        if (ranked.size() > 1) {
            System.err.println("Summary");
            for (ParameterSettings settings : ranked) {
                System.err.printf("F1=%.3f P=%.3f R=%.3f %s\n", settings.stats.f1(), settings.stats.precision(), settings.stats.recall(), settings);
            }
            System.err.println();
        }
        if (ranked.isEmpty()) {
            return;
        }

        ParameterSettings best = ranked.get(ranked.size() - 1);
        System.err.println("Best model:");
        System.err.print(best.stats);
        System.err.println(best);
        System.err.println(best.stats.confusions());
        System.err.println();
    }
}
