/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.ner;

import com.google.common.collect.Lists;
import de.datexis.annotator.Annotator;
import de.datexis.annotator.AnnotatorComponent;
import de.datexis.common.Resource;
import de.datexis.common.Timer;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.tag.BIO2Tag;
import de.datexis.model.tag.BIOESTag;
import de.datexis.ner.MentionAnnotation;
import de.datexis.ner.eval.HTMLExport;
import de.datexis.ner.eval.MentionAnnotatorEvaluation;
import de.datexis.ner.tagger.MentionTagger;
import de.datexis.tagger.Tagger;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MentionAnnotator
extends Annotator {
    protected static final Logger log = LoggerFactory.getLogger(MentionAnnotator.class);
    protected Resource bestModel;

    public MentionAnnotator() {
    }

    public MentionAnnotator(Tagger root) {
        super(root);
    }

    public MentionTagger getTagger() {
        return (MentionTagger)this.tagger;
    }

    public String toString() {
        return this.getProvenance().toString();
    }

    public void annotate(Collection<Document> docs) {
        this.getTagger().tag(docs);
        this.createAnnotations(docs, Annotation.Source.PRED);
    }

    public void trainModel(Dataset train, Dataset test, WordHelpers.Language lang) {
        this.provenance.setDataset(train.getName());
        this.provenance.setLanguage(lang.toString().toLowerCase());
        this.getTagger().setName(this.provenance.toString());
        this.createTags(train.getDocuments(), Annotation.Source.GOLD);
        this.getTagger().trainModel(train, Annotation.Source.GOLD);
        this.createTags(test.getDocuments(), Annotation.Source.GOLD);
        this.getTagger().testModel(test, Annotation.Source.GOLD);
    }

    public void trainModel(Dataset train, Annotation.Source annotationSource, WordHelpers.Language lang) {
        this.trainModel(train, annotationSource, lang, -1, true, true);
    }

    public void trainModel(Dataset train, Annotation.Source annotationSource, WordHelpers.Language lang, int limitExamples, boolean incremental, boolean randomize) {
        this.provenance.setDataset(train.getName());
        this.provenance.setLanguage(lang.toString().toLowerCase());
        this.getTagger().setName(this.provenance.toString());
        this.createTags(train.getDocuments(), annotationSource);
        this.getTagger().trainModel(train, annotationSource, limitExamples, randomize);
    }

    public void trainModelEarlyStopping(Dataset train, Dataset validation, Annotation.Source annotationSource, WordHelpers.Language lang, int epochSize, int minEpochs, int maxEpochs, int maxEpochsWithNoImprovement) {
        this.provenance.setDataset(train.getName());
        this.provenance.setLanguage(lang.toString().toLowerCase());
        this.getTagger().setName(this.provenance.toString());
        this.createTags(train.getDocuments(), annotationSource);
        Timer timer = new Timer();
        int epoch = 1;
        double score = 0.0;
        double bestScore = 0.0;
        int retries = maxEpochsWithNoImprovement;
        timer.start();
        do {
            this.getTagger().appendTrainLog("\n");
            this.getTagger().appendTrainLog("EPOCH " + epoch + ": training " + this.tagger.getName());
            this.getTagger().trainModel(train, annotationSource, epochSize, true);
            this.getTagger().appendTestLog("Testing epoch " + epoch);
            this.annotate(validation.getDocuments());
            MentionAnnotatorEvaluation eval = new MentionAnnotatorEvaluation("TraiNER epoch " + epoch, annotationSource, Annotation.Source.PRED, Annotation.Match.STRONG);
            eval.calculateScores(validation.getDocuments());
            eval.printAnnotationStats();
            score = eval.getScore();
            timer.setSplit("epoch");
            this.getTagger().appendTrainLog("EPOCH " + epoch + " complete: score " + score, timer.getLong("epoch"));
            if (score >= bestScore) {
                this.bestModel = Resource.createTempDirectory();
                try {
                    this.writeModel(this.bestModel, this.getTagger().getName());
                    HTMLExport htmlTest = new HTMLExport(validation.getDocuments(), BIOESTag.class, annotationSource, Annotation.Source.PRED);
                    FileUtils.writeStringToFile((File)this.bestModel.resolve("test_" + epoch + ".html").toFile(), (String)htmlTest.getHTML());
                }
                catch (IOException ex) {
                    log.error("Could not write output: " + ex.toString());
                }
                bestScore = score;
                retries = maxEpochsWithNoImprovement;
                continue;
            }
            --retries;
        } while ((++epoch <= minEpochs || retries >= 0) && epoch <= maxEpochs);
        timer.stop();
        this.getTagger().appendTrainLog("Training complete: " + this.tagger.getName() + " with score " + bestScore, timer.getLong());
        this.getTagger().appendTrainLog("\n");
    }

    public void writeBestModel(Resource path, String name) throws IOException {
        FileUtils.copyDirectory((File)this.bestModel.toFile(), (File)path.toFile());
    }

    public void trainModel(Collection<Sentence> sentences, Annotation.Source tagSource, WordHelpers.Language lang) {
        this.provenance.setLanguage(lang.toString().toLowerCase());
        this.getTagger().setName(this.provenance.toString());
        this.getTagger().trainModel(sentences, tagSource, true);
    }

    protected void createTags(Iterable<Document> docs, Annotation.Source expected) {
        for (Document doc : docs) {
            if (!doc.isTagAvaliable(expected, BIOESTag.class) && doc.isTagAvaliable(expected, BIO2Tag.class)) {
                BIO2Tag.convertToBIOES((Document)doc, (Annotation.Source)expected);
                doc.setTagAvailable(expected, BIOESTag.class, true);
                continue;
            }
            if (doc.isTagAvaliable(expected, BIOESTag.class)) continue;
            MentionAnnotation.createTagsFromAnnotations(doc, expected, BIOESTag.class);
            doc.setTagAvailable(expected, BIOESTag.class, true);
        }
    }

    protected void createAnnotations(Iterable<Document> docs, Annotation.Source expected) {
        for (Document doc : docs) {
            doc.clearAnnotations(expected, MentionAnnotation.class);
            if (doc.isTagAvaliable(expected, BIO2Tag.class)) {
                MentionAnnotation.annotateFromTags(doc, Annotation.Source.PRED, BIO2Tag.class);
                continue;
            }
            log.error("BIO2Tag not set");
        }
    }

    public static class Builder {
        MentionAnnotator ann;
        MentionTagger tagger;
        protected String types = "GENERIC";
        protected Class tagset = BIOESTag.class;
        protected List<Encoder> encoders = new ArrayList<Encoder>();
        private int trainingSize = -1;
        private int ffwLayerSize = 300;
        private int lstmLayerSize = 100;
        private double learningRate = 0.001;
        private int iterations = 1;
        private int batchSize = 16;
        private int numEpochs = 1;
        private int workers = 1;
        private boolean enabletrainingUI = false;

        public Builder() {
            this.tagger = new MentionTagger();
            this.ann = new MentionAnnotator(this.tagger);
        }

        public Builder withModelParams(int ffwLayerSize, int lstmLayerSize) {
            this.ffwLayerSize = ffwLayerSize;
            this.lstmLayerSize = lstmLayerSize;
            return this;
        }

        public Builder withTrainingParams(double learningRate, int batchSize, int numEpochs) {
            this.learningRate = learningRate;
            this.batchSize = batchSize;
            this.numEpochs = numEpochs;
            return this;
        }

        public Builder withWorkspaceParams(int workers) {
            this.workers = workers;
            return this;
        }

        public Builder withTypes(MentionAnnotation.Type types) {
            this.types = types.toString();
            return this;
        }

        public Builder withTypes(String types) {
            this.types = types;
            return this;
        }

        public Builder withEncoders(String desc, Encoder ... encoders) {
            this.ann.getProvenance().setFeatures(desc);
            this.withEncoders(encoders);
            return this;
        }

        public Builder withEncoders(Encoder ... encoders) {
            this.encoders = Lists.newArrayList((Object[])encoders);
            this.ann.getProvenance().setArchitecture(this.encoders.toString());
            return this;
        }

        public Builder enableTrainingUI(boolean enable) {
            this.enabletrainingUI = enable;
            return this;
        }

        public Builder pretrain(Dataset train) {
            for (Encoder e : this.encoders) {
                e.trainModel(train.streamDocuments());
            }
            return this;
        }

        public MentionAnnotator build() {
            for (Encoder e : this.encoders) {
                if (!e.isModelAvailable()) {
                    throw new IllegalArgumentException("encoder " + e.getId() + " has no model available, please consider pretrain()");
                }
                this.ann.addComponent((AnnotatorComponent)e);
            }
            this.tagger.setTagset(this.tagset, this.types);
            this.tagger.setEncoders(this.encoders);
            this.tagger.setModelParams(this.ffwLayerSize, this.lstmLayerSize, this.iterations, this.learningRate * (double)this.batchSize);
            if (this.enabletrainingUI) {
                this.tagger.enableTrainingUI();
            }
            this.tagger.setTrainingParams(this.batchSize, this.numEpochs, true);
            this.tagger.setWorkspaceParams(this.workers);
            this.ann.getProvenance().setTask("NER-" + this.types);
            this.tagger.setName(this.ann.getProvenance().toString());
            return this.ann;
        }
    }
}

