package org.apache.ctakes.relationextractor.pipelines;

import com.google.common.collect.ObjectArrays;
import com.lexicalscope.jewel.cli.CliFactory;
import com.lexicalscope.jewel.cli.Option;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.List;
import org.apache.ctakes.relationextractor.ae.ModifierExtractorAnnotator;
import org.apache.ctakes.relationextractor.ae.RelationExtractorAnnotator;
import org.apache.ctakes.relationextractor.data.GoldAnnotationStatsCalculator;
import org.apache.ctakes.relationextractor.eval.ModifierExtractorEvaluation;
import org.apache.ctakes.relationextractor.eval.ParameterSettings;
import org.apache.ctakes.relationextractor.eval.RelationExtractorEvaluation;
import org.apache.ctakes.relationextractor.eval.SHARPXMI;
import org.apache.ctakes.typesystem.type.relation.BinaryTextRelation;
import org.apache.ctakes.typesystem.type.relation.DegreeOfTextRelation;
import org.apache.ctakes.typesystem.type.relation.LocationOfTextRelation;
import org.apache.uima.UIMAFramework;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.fit.factory.AggregateBuilder;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.TypeSystemDescriptionFactory;
import org.apache.uima.util.XMLInputSource;
import org.cleartk.ml.jar.JarClassifierBuilder;
import org.xml.sax.SAXException;

/* loaded from: input_file:org/apache/ctakes/relationextractor/pipelines/RelationExtractorTrain.class */
/**
 * Command-line entry point that trains the modifier extractor and two relation extractors on the
 * XMI files derived from the batches directory, writes one analysis-engine descriptor per trained
 * annotator, assembles them (after the {@code RelationExtractorPreprocessor.xml} descriptor) into a
 * {@code RelationExtractorAggregate} descriptor, and finally removes every file in the model
 * directories except the model jar.
 */
public class RelationExtractorTrain {

    /** Path, relative to the resources directory, under which each model gets its own directory. */
    private static final String MODEL_PATH_PREFIX = "org/apache/ctakes/relationextractor/models/";

    /** Name of the configuration parameter that receives the model jar location. */
    private static final String CLASSIFIER_JAR_PATH_PARAM = "classifierJarPath";

    /** Command-line options: the evaluation options plus the two output directories. */
    interface Options extends RelationExtractorEvaluation.Options {
        @Option(longName = {"resources-dir"}, defaultValue = {"../ctakes-relation-extractor-res/src/main/resources"}, description = "the directory where resources (e.g. models) should be written")
        File getResourcesDirectory();

        @Option(longName = {"descriptors-dir"}, defaultValue = {"desc/analysis_engine"}, description = "the directory where descriptors should be written")
        File getDescriptorsDirectory();
    }

    /**
     * Trains all models and writes all descriptors.
     *
     * @param args command-line arguments, parsed into {@link Options}
     * @throws IllegalArgumentException if the resources or descriptors directory is not an existing
     *     directory, or if the preprocessor descriptor is missing from the descriptors directory
     * @throws Exception if argument parsing, training, or descriptor writing fails
     */
    public static void main(String[] args) throws Exception {
        Options options = CliFactory.parseArguments(Options.class, args);
        File resourcesDirectory = options.getResourcesDirectory();
        File descriptorsDirectory = options.getDescriptorsDirectory();

        // Both locations are used as parent directories below, so a plain file is rejected too.
        requireDirectory(resourcesDirectory);
        requireDirectory(descriptorsDirectory);

        File preprocessorDescriptor = new File(descriptorsDirectory, "RelationExtractorPreprocessor.xml");
        if (!preprocessorDescriptor.exists()) {
            throw new IllegalArgumentException(
                    "Can't create aggregate without " + preprocessorDescriptor.getCanonicalPath());
        }

        List<File> trainingFiles =
                SHARPXMI.toXMIFiles(options, SHARPXMI.getAllTextFiles(options.getBatchesDirectory()));

        String modifierModelPath = MODEL_PATH_PREFIX + "modifier_extractor";
        String degreeOfModelPath = MODEL_PATH_PREFIX + "degree_of";
        // NOTE(review): the location_of model directory name comes from
        // GoldAnnotationStatsCalculator.targetRelationType - confirm that value is the intended one.
        String locationOfModelPath = MODEL_PATH_PREFIX + GoldAnnotationStatsCalculator.targetRelationType;

        System.err.println("Training modifier extractor");
        File modifierModelDirectory = new File(resourcesDirectory, modifierModelPath);
        ModifierExtractorEvaluation modifierEvaluation =
                new ModifierExtractorEvaluation(modifierModelDirectory, ModifierExtractorEvaluation.BEST_PARAMETERS);
        modifierEvaluation.train(modifierEvaluation.getCollectionReader(trainingFiles), modifierModelDirectory);
        AnalysisEngineDescription modifierExtractor = AnalysisEngineFactory.createEngineDescription(
                ModifierExtractorAnnotator.class,
                CLASSIFIER_JAR_PATH_PARAM, modelJarResource(modifierModelPath));
        writeDesc(descriptorsDirectory, ModifierExtractorAnnotator.class, modifierExtractor);

        System.err.println("Training degree_of extractor");
        AnalysisEngineDescription degreeOfExtractor = trainRelationExtractor(
                resourcesDirectory, degreeOfModelPath, trainingFiles, DegreeOfTextRelation.class, descriptorsDirectory);

        System.err.println("Training location_of extractor");
        AnalysisEngineDescription locationOfExtractor = trainRelationExtractor(
                resourcesDirectory, locationOfModelPath, trainingFiles, LocationOfTextRelation.class, descriptorsDirectory);

        System.err.println("Assembling relation extraction aggregate");
        AggregateBuilder aggregateBuilder = new AggregateBuilder();
        aggregateBuilder.add(
                UIMAFramework.getXMLParser().parseAnalysisEngineDescription(new XMLInputSource(preprocessorDescriptor)));
        aggregateBuilder.add(modifierExtractor);
        aggregateBuilder.add(degreeOfExtractor);
        aggregateBuilder.add(locationOfExtractor);
        writeDesc(descriptorsDirectory, "RelationExtractorAggregate", aggregateBuilder.createAggregateDescription());

        removeTrainingData(new File(resourcesDirectory, MODEL_PATH_PREFIX));
    }

    /**
     * Fails fast when a required directory is missing or is not a directory.
     *
     * @param directory the location to check
     * @throws IllegalArgumentException if {@code directory} is not an existing directory
     * @throws IOException if the canonical path cannot be resolved for the error message
     */
    private static void requireDirectory(File directory) throws IOException {
        if (!directory.isDirectory()) {
            throw new IllegalArgumentException("directory not found: " + directory.getCanonicalPath());
        }
    }

    /**
     * Builds the value passed as the classifier jar path for a model directory.
     *
     * @param modelPath model directory path relative to the resources directory
     * @return {@code "/" + modelPath + "/model.jar"}
     */
    private static String modelJarResource(String modelPath) {
        return "/" + modelPath + "/model.jar";
    }

    /**
     * Deletes every file in each model directory except the model jar reported by
     * {@link JarClassifierBuilder#getModelJarFile(File)}.
     *
     * <p>Entries that are not directories are skipped, and an unreadable directory is reported on
     * stderr instead of causing a {@link NullPointerException}. Failed deletions are reported on
     * stderr as well.
     *
     * @param modelsDirectory the directory whose sub-directories hold the trained models
     */
    private static void removeTrainingData(File modelsDirectory) {
        File[] modelDirectories = modelsDirectory.listFiles();
        if (modelDirectories == null) {
            // listFiles() returns null when the path is not a directory or cannot be read.
            System.err.println("Unable to list model directories in " + modelsDirectory);
            return;
        }
        for (File modelDirectory : modelDirectories) {
            if (!modelDirectory.isDirectory()) {
                continue;
            }
            File[] contents = modelDirectory.listFiles();
            if (contents == null) {
                System.err.println("Unable to list files in " + modelDirectory);
                continue;
            }
            File modelJar = JarClassifierBuilder.getModelJarFile(modelDirectory);
            for (File candidate : contents) {
                if (!candidate.equals(modelJar) && !candidate.delete()) {
                    System.err.println("Unable to delete " + candidate);
                }
            }
        }
    }

    /**
     * Trains one relation extractor and writes its descriptor.
     *
     * <p>The annotator class and parameter settings are looked up in
     * {@link RelationExtractorEvaluation#ANNOTATOR_CLASSES} and
     * {@link RelationExtractorEvaluation#BEST_PARAMETERS} by relation class.
     *
     * @param resourcesDirectory directory against which {@code modelPath} is resolved
     * @param modelPath model directory path relative to {@code resourcesDirectory}
     * @param trainingFiles XMI files handed to the evaluation's collection reader
     * @param relationClass the relation type to train for
     * @param descriptorsDirectory directory the annotator descriptor is written to
     * @return the description of the trained annotator, configured with the settings' parameters
     *     followed by the classifier jar path
     * @throws Exception if training or descriptor writing fails
     */
    private static AnalysisEngineDescription trainRelationExtractor(
            File resourcesDirectory,
            String modelPath,
            List<File> trainingFiles,
            Class<? extends BinaryTextRelation> relationClass,
            File descriptorsDirectory) throws Exception {
        Class<? extends RelationExtractorAnnotator> annotatorClass =
                RelationExtractorEvaluation.ANNOTATOR_CLASSES.get(relationClass);
        ParameterSettings settings = RelationExtractorEvaluation.BEST_PARAMETERS.get(relationClass);

        File modelDirectory = new File(resourcesDirectory, modelPath);
        RelationExtractorEvaluation evaluation =
                new RelationExtractorEvaluation(modelDirectory, relationClass, annotatorClass, settings);
        evaluation.train(evaluation.getCollectionReader(trainingFiles), modelDirectory);

        Object[] configuration = ObjectArrays.concat(
                settings.configurationParameters,
                new Object[]{CLASSIFIER_JAR_PATH_PARAM, modelJarResource(modelPath)},
                Object.class);
        AnalysisEngineDescription description =
                AnalysisEngineFactory.createEngineDescription(annotatorClass, configuration);
        writeDesc(descriptorsDirectory, annotatorClass, description);
        return description;
    }

    /**
     * Sets the cTAKES type system on the description and writes it to a file named after the
     * annotator class's simple name.
     *
     * @param directory directory the descriptor file is written to
     * @param annotatorClass class whose simple name becomes the descriptor name
     * @param description the description to modify and serialize
     * @throws SAXException if XML serialization fails
     * @throws IOException if the file cannot be written
     */
    private static void writeDesc(File directory, Class<?> annotatorClass, AnalysisEngineDescription description)
            throws SAXException, IOException {
        description.getAnalysisEngineMetaData().setTypeSystem(
                TypeSystemDescriptionFactory.createTypeSystemDescription("org.apache.ctakes.typesystem.types.TypeSystem"));
        writeDesc(directory, annotatorClass.getSimpleName(), description);
    }

    /**
     * Names the description and serializes it to {@code <directory>/<name>.xml}.
     *
     * @param directory directory the descriptor file is written to
     * @param name descriptor name, also used as the file's base name
     * @param description the description to serialize
     * @throws SAXException if XML serialization fails
     * @throws IOException if the file cannot be opened, written, or closed
     */
    private static void writeDesc(File directory, String name, AnalysisEngineDescription description)
            throws SAXException, IOException {
        description.getMetaData().setName(name);
        File target = new File(directory, name + ".xml");
        System.err.println("Writing description to " + target);
        // try-with-resources closes the stream even when toXML throws.
        try (FileOutputStream output = new FileOutputStream(target)) {
            description.toXML(output);
        }
    }
}
