/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.genemapper.filtering;

import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.MaxEnt;
import cc.mallet.classify.MaxEntTrainer;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.Labeling;
import com.google.common.collect.Multimap;
import de.julielab.evaluation.entities.EntityEvaluationResult;
import de.julielab.evaluation.entities.EntityEvaluator;
import de.julielab.evaluation.entities.EvaluationData;
import de.julielab.evaluation.entities.EvaluationDataEntry;
import de.julielab.geneexpbase.TermNormalizer;
import de.julielab.geneexpbase.data.CorpusReader;
import de.julielab.geneexpbase.genemodel.Acronym;
import de.julielab.geneexpbase.genemodel.GeneDocument;
import de.julielab.geneexpbase.genemodel.GeneMention;
import de.julielab.geneexpbase.genemodel.GeneSet;
import de.julielab.geneexpbase.genemodel.PosTag;
import de.julielab.genemapper.filtering.InstanceListCreator;
import de.julielab.java.utilities.FileUtilities;
import de.julielab.java.utilities.spanutils.OffsetMap;
import de.julielab.java.utilities.spanutils.OffsetSet;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.Range;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Machine-learning based filter that removes gene mentions from {@link GeneDocument}s.
 *
 * <p>A MALLET {@link MaxEnt} classifier is trained on instances built by
 * {@link InstanceListCreator}. During filtering, every gene set of a document is classified and,
 * when the best label is {@code "FALSE"}, a gene mention of that set is removed from the document.
 * The class also provides a small command line interface to train and evaluate such a model.
 *
 * <p>No synchronization is performed: the {@code classifier} field is replaced by
 * {@link #train(Stream)} and {@link #loadModel(InputStream)} without any locking.
 */
public class MLGeneFilter {
    private static final Logger log = LoggerFactory.getLogger(MLGeneFilter.class);
    private Classifier classifier;

    /**
     * Command line entry point.
     *
     * <p>Supported modes:
     * <ul>
     *   <li>{@code train <model file> <data directory> <gold genelist>} - trains a model, writes it
     *   to the model file and prints its performance on the training data.</li>
     *   <li>{@code filter <model file> <data directory>} - not yet implemented.</li>
     *   <li>{@code eval <model file> <data directory> <gold genelist>} - evaluates an existing
     *   model.</li>
     * </ul>
     *
     * @param args the mode followed by the mode specific parameters
     * @throws IOException if reading the input data fails
     */
    public static void main(String[] args) throws IOException {
        // Every mode needs at least <mode> <model file> <data directory>. Check this before indexing
        // into args so missing arguments yield a usage message instead of an
        // ArrayIndexOutOfBoundsException.
        if (args.length < 3) {
            exitWithUsage("<train/filter/eval> <mode specific parameters>");
        }
        String mode = args[0];
        File modelfile = new File(args[1]);
        File dataDirectory = new File(args[2]);
        MLGeneFilter mlGeneFilter = new MLGeneFilter();
        switch (mode) {
            case "train": {
                if (args.length != 4) {
                    exitWithUsage("train <model file> <data directory> <gold genelist>");
                }
                File goldGeneList = new File(args[3]);
                Stream<GeneDocument> documents = loadData(dataDirectory, goldGeneList);
                mlGeneFilter.train(documents);
                mlGeneFilter.writeModel(modelfile);
                System.out.println("Filter performance on training data:");
                evaluate(modelfile, dataDirectory, mlGeneFilter, goldGeneList);
                break;
            }
            case "filter": {
                if (args.length != 3) {
                    exitWithUsage("filter <model file> <data directory>");
                }
                throw new UnsupportedOperationException("Filtering from CLI Not yet implemented");
            }
            case "eval": {
                if (args.length != 4) {
                    exitWithUsage("eval <model file> <data directory> <gold genelist>");
                }
                File goldGeneList = new File(args[3]);
                evaluate(modelfile, dataDirectory, mlGeneFilter, goldGeneList);
                break;
            }
            default: {
                // Previously an unknown mode silently did nothing.
                System.err.println("Unknown mode: " + mode);
                exitWithUsage("<train/filter/eval> <mode specific parameters>");
                break;
            }
        }
    }

    /**
     * Prints {@code Usage: MLGeneFilter <parameters>} to stderr and terminates the JVM with status 1.
     *
     * @param parameters the parameter description appended after the class name
     */
    private static void exitWithUsage(String parameters) {
        System.err.println("Usage: " + MLGeneFilter.class.getSimpleName() + " " + parameters);
        System.exit(1);
    }

    /**
     * Loads a model, filters the documents of a data directory with it and prints the evaluation.
     *
     * <p>The filtered predictions are compared with the gold mentions using an OVERLAP comparison
     * with overlap type PERCENT and overlap size 20. Each false positive mention is printed to
     * stdout as {@code text \t docId \t begin-end}, followed by a short evaluation report.
     *
     * @param modelfile     the serialized classifier to load into {@code mlGeneFilter}
     * @param dataDirectory directory with the predicted genes and annotation files
     * @param mlGeneFilter  the filter instance into which the model is loaded and which filters the documents
     * @param goldGeneList  file with the gold gene mentions and their offsets
     * @throws IOException if reading the input data fails
     */
    public static void evaluate(File modelfile, File dataDirectory, MLGeneFilter mlGeneFilter, File goldGeneList) throws IOException {
        mlGeneFilter.loadModel(FileUtilities.getInputStreamFromFile(modelfile));
        // Materialize the documents: they are filtered in place and then read again for evaluation.
        List<GeneDocument> testList = loadData(dataDirectory, goldGeneList).collect(Collectors.toList());
        mlGeneFilter.filter(testList.stream());

        EvaluationData goldEvaluationData = new EvaluationData();
        Multimap<String, GeneMention> goldData = CorpusReader.readMentionsWithOffsets(goldGeneList.getAbsolutePath());
        goldData.entries().stream().map(e -> {
            GeneMention goldMention = e.getValue();
            EvaluationDataEntry entry = new EvaluationDataEntry(e.getKey(), goldMention.getGoldMentionId(), goldMention.getText(), "gold");
            entry.setBegin(goldMention.getBegin());
            entry.setEnd(goldMention.getEnd());
            entry.setEntityId("NoId");
            return entry;
        }).forEach(goldEvaluationData::add);

        EvaluationData predictionEvaluationData = new EvaluationData();
        testList.forEach(document -> document.getGenes().map(gm -> {
            EvaluationDataEntry ede = new EvaluationDataEntry(document.getId(), "NoId", gm.getText(), gm.getTagger().name());
            ede.setBegin(gm.getBegin());
            ede.setEnd(gm.getEnd());
            return ede;
        }).forEach(predictionEvaluationData::add));

        Properties evalSettings = new Properties();
        evalSettings.setProperty("comparison-type", EvaluationDataEntry.ComparisonType.OVERLAP.name());
        evalSettings.setProperty("overlap-type", EvaluationDataEntry.OverlapType.PERCENT.name());
        evalSettings.setProperty("overlap-size", "20");
        EntityEvaluator evaluator = new EntityEvaluator(evalSettings);
        EntityEvaluationResult evalResult = evaluator.evaluate(goldEvaluationData, predictionEvaluationData).getSingle();
        Stream<EvaluationDataEntry> fps = evalResult.getFpEvaluationDataEntriesMentionWise();
        fps.forEach(fp -> System.out.println(fp.getEntityString() + "\t" + fp.getDocId() + "\t" + fp.getBegin() + "-" + fp.getEnd()));
        System.out.println("Result:");
        System.out.println(evalResult.getEvaluationReportShort());
    }

    /**
     * Reads the documents of a data directory as {@link GeneDocument}s.
     *
     * <p>The directory is read for the following files:
     * <ul>
     *   <li>{@code genes.tsv.gz} - predicted gene mentions.</li>
     *   <li>{@code annotations.tsv.gz} - sentences, chunks and POS tags.</li>
     *   <li>{@code acronyms.tsv.gz} - acronyms.</li>
     *   <li>{@code text} - document texts.</li>
     * </ul>
     * One document is created per document ID found in the texts. For each document, gene
     * mentions of the {@code GAZETTEER} tagger are selected via {@code selectGeneMentionsByTagger},
     * and the gold genes from {@code goldGeneList} are attached.
     *
     * <p>The returned stream is lazy: documents are assembled when it is consumed.
     *
     * @param dataDirectory directory containing the files listed above
     * @param goldGeneList  file with the gold gene mentions and their offsets
     * @return a lazy stream of assembled documents
     * @throws IOException if reading one of the files fails
     */
    private static Stream<GeneDocument> loadData(File dataDirectory, File goldGeneList) throws IOException {
        String basePath = dataDirectory.getAbsolutePath();
        String predictedGenesPath = basePath + "/genes.tsv.gz";
        // Sentences, chunks and POS tags are all read from the same mixed annotation file.
        String sentencesPath = basePath + "/annotations.tsv.gz";
        String chunksPath = basePath + "/annotations.tsv.gz";
        String acronymsPath = basePath + "/acronyms.tsv.gz";
        String docTextPath = basePath + "/text";
        Multimap<String, GeneMention> goldData = CorpusReader.readMentionsWithOffsets(goldGeneList.getAbsolutePath());
        Multimap<String, GeneMention> predictedGeneMentions = CorpusReader.readMixedFileForGenesWithOffsets(predictedGenesPath);
        Multimap<String, Acronym> acronyms = CorpusReader.readAcronymAnnotations(acronymsPath);
        Map<String, String> documentContexts = CorpusReader.readGeneContexts(docTextPath);
        Multimap<String, Range<Integer>> sentences = CorpusReader.readMixedFileForSentenceOffsets(sentencesPath);
        Map<String, OffsetMap<String>> chunks = CorpusReader.readMixedFileForChunkOffsets(chunksPath);
        Multimap<String, PosTag> posTags = CorpusReader.readMixedFileForPosTags(chunksPath);
        return documentContexts.keySet().stream().map(docId -> {
            GeneDocument document = new GeneDocument(docId);
            document.setTermNormalizer(new TermNormalizer());
            document.setAcronyms(new HashSet<Acronym>(acronyms.get(docId)));
            document.setDocumentText(documentContexts.get(docId));
            document.setChunks(chunks.get(docId));
            document.setPosTags(posTags.get(docId));
            document.setSentences(new OffsetSet(sentences.get(docId)));
            document.setGenes(new HashSet<GeneMention>(predictedGeneMentions.get(docId)));
            document.getAllGenes().forEach(gm -> gm.setDocumentContext(document.getDocumentText()));
            document.selectGeneMentionsByTagger(GeneMention.GeneTagger.GAZETTEER);
            goldData.get(docId).forEach(document::putGoldGene);
            return document;
        });
    }

    /**
     * Trains a MaxEnt classifier on the given documents and stores it in this filter.
     *
     * @param documents the training documents; the stream is consumed
     */
    public void train(Stream<GeneDocument> documents) {
        InstanceListCreator instanceListCreator = new InstanceListCreator(false);
        InstanceList iList = instanceListCreator.createInstanceList(documents);
        log.debug("training the model from {} training examples ...", iList.size());
        MaxEntTrainer trainer = new MaxEntTrainer();
        MaxEnt meModel = trainer.train(iList);
        this.classifier = meModel;
    }

    /**
     * Removes gene mentions from the given documents in place.
     *
     * <p>Each gene set of a document is classified. When the best label is {@code "FALSE"}, the
     * first mention returned by the set's iterator is removed from the document. Removal is
     * deferred until all gene sets of the document have been classified.
     *
     * @param documents the documents to filter; the stream is consumed
     * @throws IllegalStateException if no classifier has been trained or loaded
     */
    public void filter(Stream<GeneDocument> documents) {
        if (this.classifier == null) {
            throw new IllegalStateException("No classifier available: train a model or call loadModel() before filtering");
        }
        documents.forEach(d -> {
            List<GeneMention> filteredGms = new ArrayList<GeneMention>();
            for (GeneSet gs : d.getGeneSets()) {
                Instance inst = this.classifier.getInstancePipe().instanceFrom(new Instance(gs, "", "", d));
                Classification classification = this.classifier.classify(inst);
                String bestLabel = classification.getLabeling().getBestLabel().getEntry().toString();
                if ("FALSE".equals(bestLabel)) {
                    // NOTE(review): only one mention of the rejected gene set is removed; confirm
                    // that this is intended when a set contains several mentions.
                    filteredGms.add((GeneMention) gs.iterator().next());
                }
            }
            // Removing after the loop avoids modifying the document's genes while iterating its sets.
            filteredGms.forEach(gm -> d.removeGene(gm));
        });
    }

    /**
     * Returns the value the classification's labeling assigns to the label {@code "TRUE"}.
     * Currently not called from within this class.
     *
     * @param c a classification produced by this filter's classifier
     * @return the labeling value of {@code "TRUE"}
     */
    private double getProbabilityTrueClass(Classification c) {
        Labeling labeling = c.getLabeling();
        LabelAlphabet dict = labeling.getLabelAlphabet();
        Label label = dict.lookupLabel("TRUE");
        return labeling.value(label);
    }

    /**
     * Deserializes a classifier from the given stream and makes it this filter's model.
     *
     * <p>The stream is closed afterwards, also on failure. The data alphabet of the loaded
     * classifier's pipe is frozen ({@code stopGrowth}).
     *
     * <p>SECURITY: this uses native Java deserialization; only load models from trusted sources.
     *
     * @param s2 stream containing a serialized {@link Classifier}
     * @throws UncheckedIOException     if reading the stream fails
     * @throws IllegalArgumentException if the stream refers to a class that is not on the classpath
     */
    public void loadModel(InputStream s2) {
        try (InputStream modelStream = s2;
             ObjectInputStream in = new ObjectInputStream(modelStream)) {
            this.classifier = (Classifier) in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException("Could not read the classifier model", e);
        } catch (ClassNotFoundException e) {
            throw new IllegalArgumentException("The model stream references an unknown class", e);
        }
        this.classifier.getInstancePipe().getDataAlphabet().stopGrowth();
    }

    /**
     * Serializes the current classifier to {@code modelfile}, overwriting an existing file.
     *
     * @param modelfile destination of the serialized model
     * @throws UncheckedIOException if the model cannot be written
     */
    public void writeModel(File modelfile) {
        // Both streams are declared as resources so the file handle is released even if the
        // ObjectOutputStream constructor or writeObject fails; close() also flushes.
        try (FileOutputStream fileOut = new FileOutputStream(modelfile);
             ObjectOutputStream out = new ObjectOutputStream(fileOut)) {
            out.writeObject(this.classifier);
        } catch (IOException e) {
            throw new UncheckedIOException("Could not write the classifier model to " + modelfile, e);
        }
    }
}

