package org.apache.ctakes.coreference.cc;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Scanner;
import libsvm.svm_node;
import org.apache.ctakes.constituency.parser.treekernel.TreeExtractor;
import org.apache.ctakes.constituency.parser.util.TreeUtils;
import org.apache.ctakes.core.pipeline.PipeBitInfo;
import org.apache.ctakes.core.resource.FileLocator;
import org.apache.ctakes.core.util.doc.DocIdUtil;
import org.apache.ctakes.coreference.type.BooleanLabeledFS;
import org.apache.ctakes.coreference.type.DemMarkable;
import org.apache.ctakes.coreference.type.Markable;
import org.apache.ctakes.coreference.type.MarkablePairSet;
import org.apache.ctakes.coreference.type.NEMarkable;
import org.apache.ctakes.coreference.util.CorefConsts;
import org.apache.ctakes.coreference.util.FSIteratorToList;
import org.apache.ctakes.coreference.util.GoldStandardLabeler;
import org.apache.ctakes.coreference.util.MarkableTreeUtils;
import org.apache.ctakes.coreference.util.PairAttributeCalculator;
import org.apache.ctakes.coreference.util.SvmVectorCreator;
import org.apache.ctakes.relationextractor.eval.XMIReader;
import org.apache.ctakes.typesystem.type.syntax.TreebankNode;
import org.apache.ctakes.utils.tree.SimpleTree;
import org.apache.log4j.Logger;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_component.JCasAnnotator_ImplBase;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.FSIterator;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.factory.CollectionReaderFactory;
import org.apache.uima.fit.pipeline.SimplePipeline;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.cas.FSList;
import org.apache.uima.jcas.cas.NonEmptyFSList;

/**
 * Writes training data for anaphor/antecedent markable pairs found in a CAS.
 *
 * <p>For every {@link MarkablePairSet} in the CAS, each candidate antecedent is labelled against
 * the gold standard (via {@link GoldStandardLabeler}) and, depending on configuration:
 * <ul>
 *   <li>a libsvm-style feature line is appended to
 *       {@code <outputDir>/<type>/vectors/<docId>.libsvm} (parameter {@link #PARAM_VECTORS});</li>
 *   <li>a path-tree line of the form {@code +1 |BT| <tree> |ET|} is appended to
 *       {@code <outputDir>/<type>/trees.txt} (parameter {@link #PARAM_TREES}).</li>
 * </ul>
 * where {@code <type>} is one of {@code CorefConsts.NE}, {@code CorefConsts.DEM} or
 * {@code CorefConsts.PRON}, chosen from the runtime class of the anaphor.
 *
 * <p>If {@link #initialize(UimaContext)} fails, the error is logged and the component stays
 * inert: {@link #process(JCas)} does nothing.
 */
@PipeBitInfo(name = "ODIE Vector File Writer", description = "Write ODIE Vector File.", role = PipeBitInfo.Role.WRITER, dependencies = {PipeBitInfo.TypeProduct.DOCUMENT_ID, PipeBitInfo.TypeProduct.MARKABLE})
public class ODIEVectorFileWriter extends JCasAnnotator_ImplBase {
    private HashSet<String> stopwords;
    private ArrayList<String> treeFrags;
    private boolean printVectors;
    private boolean printTrees;
    public static final String PARAM_OUTPUT_DIR = "outputDir";
    public static final String PARAM_GOLD_DIR = "goldStandardDir";
    public static final String PARAM_VECTORS = "writeVectors";
    public static final String PARAM_TREES = "writeTrees";
    public static final String PARAM_FRAGS = "treeFrags";
    public static final String PARAM_STOPS = "stopWords";
    private Logger log = Logger.getLogger(getClass());
    private String outputDir = null;
    private String goldStandardDir = null;
    // Per-document vector writers; opened in process() and closed before it returns.
    private PrintWriter neOut = null;
    private PrintWriter pronOut = null;
    private PrintWriter demOut = null;
    // Tree writers live for the whole run; opened in initialize() when printTrees is set.
    private PrintWriter neTreeOut = null;
    private PrintWriter pronTreeOut = null;
    private PrintWriter demTreeOut = null;
    private PrintWriter debug = null;
    // Set only when initialize() completed without an exception.
    private boolean initialized = false;
    private int posNeInst = 0;
    private int negNeInst = 0;
    private int posDemInst = 0;
    private int negDemInst = 0;
    private int posPronInst = 0;
    private int negPronInst = 0;
    private int posAnaphInst = 0;
    private int negAnaphInst = 0;
    private PairAttributeCalculator attr = null;
    private SvmVectorCreator vecCreator = null;
    private GoldStandardLabeler labeler = null;
    private boolean useFrags = true;

    /**
     * Reads the configuration parameters, creates the output directories, opens the tree writers
     * (when tree output is enabled) and loads the stop-word and tree-fragment resources.
     *
     * <p>Any failure is logged with its cause and leaves the component uninitialized, in which
     * case {@link #process(JCas)} is a no-op.
     *
     * @param uimaContext context supplying the {@code PARAM_*} configuration values
     */
    public void initialize(UimaContext uimaContext) {
        this.outputDir = (String) uimaContext.getConfigParameterValue(PARAM_OUTPUT_DIR);
        this.goldStandardDir = (String) uimaContext.getConfigParameterValue(PARAM_GOLD_DIR);
        this.printVectors = ((Boolean) uimaContext.getConfigParameterValue(PARAM_VECTORS)).booleanValue();
        this.printTrees = ((Boolean) uimaContext.getConfigParameterValue(PARAM_TREES)).booleanValue();
        try {
            new File(this.outputDir + "/" + CorefConsts.NE + "/vectors/").mkdirs();
            new File(this.outputDir + "/" + CorefConsts.PRON + "/vectors/").mkdirs();
            new File(this.outputDir + "/" + CorefConsts.DEM + "/vectors/").mkdirs();
            if (this.printTrees) {
                this.neTreeOut = new PrintWriter(this.outputDir + "/" + CorefConsts.NE + "/trees.txt");
                this.demTreeOut = new PrintWriter(this.outputDir + "/" + CorefConsts.DEM + "/trees.txt");
                this.pronTreeOut = new PrintWriter(this.outputDir + "/" + CorefConsts.PRON + "/trees.txt");
                // Auto-flushing wrapper so debug output survives an abrupt stop.
                this.debug = new PrintWriter((Writer) new PrintWriter(this.outputDir + "/" + CorefConsts.NE + "/fulltrees_debug.txt"), true);
            }

            // Stop words: one entry per line; anything after a '|' is ignored, and lines that
            // are blank or start with '|' are skipped.
            this.stopwords = new HashSet<>();
            try (BufferedReader stopReader = new BufferedReader(new FileReader(FileLocator.getFile((String) uimaContext.getConfigParameterValue(PARAM_STOPS))))) {
                String rawLine;
                while ((rawLine = stopReader.readLine()) != null) {
                    String line = rawLine.trim();
                    if (line.length() == 0) {
                        continue;
                    }
                    int bar = line.indexOf('|');
                    if (bar > 0) {
                        this.stopwords.add(line.substring(0, bar).trim());
                    } else if (bar < 0) {
                        this.stopwords.add(line);
                    }
                }
            }
            this.vecCreator = new SvmVectorCreator(this.stopwords);

            // Tree fragments: the second space-separated token of every line is kept.
            try (Scanner fragScanner = new Scanner(FileLocator.getFile((String) uimaContext.getConfigParameterValue(PARAM_FRAGS)))) {
                if (this.useFrags) {
                    this.treeFrags = new ArrayList<>();
                    while (fragScanner.hasNextLine()) {
                        this.treeFrags.add(fragScanner.nextLine().split(" ")[1]);
                    }
                    this.vecCreator.setFrags(this.treeFrags);
                }
            }
            this.initialized = true;
        } catch (Exception e) {
            this.log.error("Error initializing file writers.", e);
            // Do not leak writers that were opened before the failure.
            closeTreeWriters();
        }
    }

    /**
     * Labels every anaphor/antecedent candidate pair in the CAS against the gold standard,
     * updates the per-type instance counters and writes the configured outputs.
     *
     * <p>Does nothing when initialization failed. If the per-document vector files cannot be
     * opened, the error is logged and vector output is skipped for this document; tree output
     * (if enabled) is still produced.
     *
     * @param jCas the document to process
     */
    public void process(JCas jCas) {
        if (!this.initialized) {
            return;
        }
        String documentID = DocIdUtil.getDocumentID(jCas);
        // Keep only the part after the last '/', so path-like ids yield a plain file name.
        String docName = documentID.substring(documentID.lastIndexOf('/') + 1);
        System.out.println("creating vectors for " + docName);
        this.labeler = new GoldStandardLabeler(this.goldStandardDir, docName, FSIteratorToList.convert(jCas.getAnnotationIndex(Markable.type).iterator()));

        boolean writeVectors = this.printVectors && openVectorWriters(docName);
        try {
            FSIterator pairSets = jCas.getJFSIndexRepository().getAllIndexedFS(MarkablePairSet.type);
            while (pairSets.hasNext()) {
                MarkablePairSet pairSet = (MarkablePairSet) pairSets.next();
                Markable anaphor = pairSet.getAnaphor();
                String type = anaphor instanceof NEMarkable ? CorefConsts.NE : anaphor instanceof DemMarkable ? CorefConsts.DEM : CorefConsts.PRON;

                // Walk the linked list until the (empty) tail is reached.
                FSList node = pairSet.getAntecedentList();
                while (node instanceof NonEmptyFSList) {
                    NonEmptyFSList cell = (NonEmptyFSList) node;
                    BooleanLabeledFS candidate = (BooleanLabeledFS) cell.getHead();
                    Markable antecedent = candidate.getFeature();
                    int label = this.labeler.isGoldPair(anaphor, antecedent) ? 1 : 0;

                    countInstance(type, label == 1);
                    if (writeVectors) {
                        writeVector(jCas, anaphor, antecedent, type, label);
                    }
                    if (this.printTrees) {
                        writeTree(jCas, anaphor, antecedent, type, label);
                    }
                    node = cell.getTail();
                }
            }
        } finally {
            // Per-document files are closed even if feature extraction throws.
            closeVectorWriters();
        }
    }

    /**
     * Opens the three per-document libsvm files.
     *
     * @param docName file-name stem of the current document
     * @return {@code true} if all three writers were opened; {@code false} (after logging and
     *         closing any that were opened) otherwise
     */
    private boolean openVectorWriters(String docName) {
        try {
            this.neOut = new PrintWriter(this.outputDir + "/" + CorefConsts.NE + "/vectors/" + docName + ".libsvm");
            this.demOut = new PrintWriter(this.outputDir + "/" + CorefConsts.DEM + "/vectors/" + docName + ".libsvm");
            this.pronOut = new PrintWriter(this.outputDir + "/" + CorefConsts.PRON + "/vectors/" + docName + ".libsvm");
            return true;
        } catch (FileNotFoundException e) {
            this.log.error("Could not open vector output files for " + docName, e);
            closeVectorWriters();
            return false;
        }
    }

    /** Increments the positive or negative instance counter for the given anaphor type. */
    private void countInstance(String type, boolean positive) {
        if (type.equals(CorefConsts.NE)) {
            if (positive) {
                this.posNeInst++;
            } else {
                this.negNeInst++;
            }
        } else if (type.equals(CorefConsts.DEM)) {
            if (positive) {
                this.posDemInst++;
            } else {
                this.negDemInst++;
            }
        } else if (type.equals(CorefConsts.PRON)) {
            if (positive) {
                this.posPronInst++;
            } else {
                this.negPronInst++;
            }
        }
    }

    /** Picks the writer that belongs to the given anaphor type, or {@code null} if none matches. */
    private static PrintWriter writerFor(String type, PrintWriter ne, PrintWriter pron, PrintWriter dem) {
        if (type.equals(CorefConsts.NE)) {
            return ne;
        }
        if (type.equals(CorefConsts.PRON)) {
            return pron;
        }
        if (type.equals(CorefConsts.DEM)) {
            return dem;
        }
        return null;
    }

    /** Writes one {@code <label> <index>:<value> ...} line for the pair and flushes it. */
    private void writeVector(JCas jCas, Markable anaphor, Markable antecedent, String type, int label) {
        svm_node[] features = this.vecCreator.getNodeFeatures(anaphor, antecedent, jCas);
        PrintWriter out = writerFor(type, this.neOut, this.pronOut, this.demOut);
        out.print(label);
        for (svm_node feature : features) {
            out.print(" ");
            out.print(feature.index);
            out.print(":");
            out.print(feature.value);
        }
        out.println();
        out.flush();
    }

    /**
     * Writes the path tree between antecedent and anaphor as {@code (+1|-1) |BT| <tree> |ET|},
     * and dumps both markable subtrees to the debug file.
     */
    private void writeTree(JCas jCas, Markable anaphor, Markable antecedent, String type, int label) {
        TreebankNode antecedentNode = MarkableTreeUtils.markableNode(jCas, antecedent.getBegin(), antecedent.getEnd());
        TreebankNode anaphorNode = MarkableTreeUtils.markableNode(jCas, anaphor.getBegin(), anaphor.getEnd());
        this.debug.println(TreeUtils.tree2str(antecedentNode));
        this.debug.println(TreeUtils.tree2str(anaphorNode));
        SimpleTree pathTree = TreeExtractor.extractPathTree(antecedentNode, anaphorNode);
        // NOTE(review): result is unused; call kept in case it has side effects — confirm and remove.
        TreeExtractor.extractPathEnclosedTree(antecedentNode, anaphorNode, jCas);
        String treeText = pathTree.toString();
        PrintWriter out = writerFor(type, this.neTreeOut, this.pronTreeOut, this.demTreeOut);
        out.print(label == 1 ? "+1" : "-1");
        out.print(" |BT| ");
        // Remove the space between adjacent sibling brackets.
        out.print(treeText.replaceAll("\\) \\(", ")("));
        out.println(" |ET|");
    }

    /** Closes the writer if it is non-null; {@link PrintWriter#close()} flushes first. */
    private static void closeIfOpen(PrintWriter writer) {
        if (writer != null) {
            writer.close();
        }
    }

    /** Closes and forgets the per-document vector writers; safe to call when none are open. */
    private void closeVectorWriters() {
        closeIfOpen(this.neOut);
        closeIfOpen(this.demOut);
        closeIfOpen(this.pronOut);
        this.neOut = null;
        this.demOut = null;
        this.pronOut = null;
    }

    /** Flushes and closes the tree and debug writers; safe to call repeatedly or when unopened. */
    private void closeTreeWriters() {
        closeIfOpen(this.neTreeOut);
        closeIfOpen(this.demTreeOut);
        closeIfOpen(this.pronTreeOut);
        closeIfOpen(this.debug);
    }

    /** Parses the first character of {@code str} as an integer label. */
    private int getLabel(String str) {
        return Integer.parseInt(str.substring(0, 1));
    }

    /**
     * Releases any open output writers at the end of a batch.
     *
     * @throws AnalysisEngineProcessException if the superclass implementation fails
     */
    public void batchProcessComplete() throws AnalysisEngineProcessException {
        super.batchProcessComplete();
        if (this.initialized) {
            closeVectorWriters();
            closeTreeWriters();
        }
    }

    /**
     * Releases any open output writers at the end of the collection, so buffered tree output is
     * flushed even when the framework never signals a batch boundary.
     *
     * @throws AnalysisEngineProcessException if the superclass implementation fails
     */
    public void collectionProcessComplete() throws AnalysisEngineProcessException {
        super.collectionProcessComplete();
        if (this.initialized) {
            closeVectorWriters();
            closeTreeWriters();
        }
    }

    /** Copies a list of integers into a new {@code double[]} of the same length. */
    private double[] listToDoubleArray(ArrayList<Integer> arrayList) {
        double[] dArr = new double[arrayList.size()];
        for (int i = 0; i < arrayList.size(); i++) {
            dArr[i] = arrayList.get(i).intValue();
        }
        return dArr;
    }

    /**
     * Runs this writer over every file in a directory, with vector output enabled and tree
     * output disabled.
     *
     * @param strArr {@code <training directory> <gold-pairs directory> <output directory>}
     */
    public static void main(String[] strArr) {
        if (strArr.length < 3) {
            System.err.println("Arguments: <training directory> <gold-pairs directory> <output directory>");
            System.exit(-1);
        }
        File file = new File(strArr[0]);
        if (!file.isDirectory()) {
            System.err.println("Arg1 should be a directory! (full of xmi files)");
            System.exit(-1);
        }
        File[] listFiles = file.listFiles();
        if (listFiles == null) {
            // listFiles() returns null when the directory cannot be read.
            System.err.println("Could not list files in " + file.getAbsolutePath());
            System.exit(-1);
        }
        String[] strArr2 = new String[listFiles.length];
        for (int i = 0; i < listFiles.length; i++) {
            strArr2[i] = listFiles[i].getAbsolutePath();
        }
        try {
            SimplePipeline.runPipeline(CollectionReaderFactory.createReader(XMIReader.class, new Object[]{"files", strArr2}), new AnalysisEngine[]{AnalysisEngineFactory.createEngine(ODIEVectorFileWriter.class, new Object[]{PARAM_VECTORS, true, PARAM_TREES, false, PARAM_STOPS, "org/apache/ctakes/coreference/models/stop.txt", PARAM_FRAGS, "org/apache/ctakes/coreference/models/frags.txt", PARAM_GOLD_DIR, strArr[1], PARAM_OUTPUT_DIR, strArr[2]})});
        } catch (Exception e) {
            System.err.println("Exception thrown!");
            e.printStackTrace();
        }
    }
}
