/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.sector.exec;

import de.datexis.common.CommandLineParser;
import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.impl.BloomEncoder;
import de.datexis.encoder.impl.DummyEncoder;
import de.datexis.encoder.impl.StructureEncoder;
import de.datexis.encoder.impl.Word2VecEncoder;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.sector.SectorAnnotator;
import de.datexis.sector.encoder.ClassEncoder;
import de.datexis.sector.encoder.HeadingEncoder;
import de.datexis.sector.model.SectionAnnotation;
import de.datexis.sector.reader.WikiSectionReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.deeplearning4j.ui.api.UIServer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TrainSectorAnnotator {
    protected static final Logger log = LoggerFactory.getLogger(TrainSectorAnnotator.class);

    public static void main(String[] args) throws IOException {
        ExecParams params = new ExecParams();
        CommandLineParser parser = new CommandLineParser((CommandLineParser.Options)params);
        try {
            parser.parse(args);
            new TrainSectorAnnotator().runTraining(params);
            System.exit(0);
        }
        catch (ParseException e) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.printHelp("texoo-train-sector", "TeXoo: train SectorAnnotator from WikiSection dataset", params.setUpCliOptions(), "", true);
            System.exit(1);
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     * Loose catch block
     */
    protected void runTraining(ExecParams params) throws IOException {
        Dataset validation;
        Dataset train;
        Resource trainingPath = Resource.fromDirectory((String)params.trainFile);
        Resource validationPath = params.devFile != null ? Resource.fromDirectory((String)params.devFile) : null;
        Resource testPath = params.testFile != null ? Resource.fromDirectory((String)params.testFile) : null;
        Resource output = Resource.fromDirectory((String)params.outputPath);
        WordHelpers.Language lang = WordHelpers.getLanguage((String)params.language);
        Dataset dataset = train = trainingPath.getFileName().endsWith(".json") ? WikiSectionReader.readDatasetFromJSON(trainingPath) : WikiSectionReader.readDatasetFromJSON(trainingPath);
        Dataset dataset2 = validationPath == null ? null : (validation = validationPath.getFileName().endsWith(".json") ? WikiSectionReader.readDatasetFromJSON(validationPath) : WikiSectionReader.readDatasetFromJSON(validationPath));
        Dataset test = testPath == null ? null : (testPath.getFileName().endsWith(".json") ? WikiSectionReader.readDatasetFromJSON(testPath) : WikiSectionReader.readDatasetFromJSON(testPath));
        SectorAnnotator.Builder builder = new SectorAnnotator.Builder();
        if (params.embeddingsFile == null) {
            this.initializeInputEncodings_bloom(builder, train, lang);
        } else {
            this.initializeInputEncodings_wemb(builder, Resource.fromFile((String)params.embeddingsFile));
        }
        if (params.isHeadingsModel) {
            this.initializeHeadingsTarget(builder, train, lang);
        } else {
            this.initializeClassLabelsTarget(builder, train);
        }
        SectorAnnotator sector = builder.withDataset(train.getName(), lang).withModelParams(0, 256, 128).withTrainingParams(0.01, 0.5, 2048, 396, 16, 10).enableTrainingUI(params.trainingUI).build();
        boolean success = false;
        try {
            if (validation == null) {
                sector.trainModel(train);
            } else {
                sector.trainModelEarlyStopping(train, validation, 10, 10, 100);
            }
            output = output.resolve(sector.getTagger().getName());
            output.toFile().mkdirs();
            sector.writeModel(output);
            sector.writeTrainLog(output);
            if (test != null) {
                sector.annotate(test.getDocuments(), SectorAnnotator.SegmentationMethod.BEMD);
                sector.evaluateModel(test, false, true, true);
            }
            sector.writeTestLog(output);
            success = true;
        }
        catch (Throwable throwable) {
            block23: {
                try {
                    if (params.trainingUI) {
                        UIServer.getInstance().stop();
                    }
                    System.exit(success ? 0 : 1);
                }
                catch (NoClassDefFoundError noClassDefFoundError) {
                    System.exit(success ? 0 : 1);
                }
                catch (Exception exception) {
                    System.exit(success ? 0 : 1);
                    break block23;
                    {
                        catch (Throwable throwable2) {
                            System.exit(success ? 0 : 1);
                            throw throwable2;
                        }
                    }
                }
            }
            throw throwable;
        }
        try {
            if (params.trainingUI) {
                UIServer.getInstance().stop();
            }
            System.exit(success ? 0 : 1);
        }
        catch (NoClassDefFoundError noClassDefFoundError) {
            System.exit(success ? 0 : 1);
        }
        catch (Exception exception) {
            System.exit(success ? 0 : 1);
            {
                catch (Throwable throwable) {
                    System.exit(success ? 0 : 1);
                    throw throwable;
                }
            }
        }
    }

    protected SectorAnnotator.Builder initializeInputEncodings_bloom(SectorAnnotator.Builder builder, Dataset train, WordHelpers.Language lang) {
        BloomEncoder bloom = new BloomEncoder(4096, 5);
        bloom.trainModel(train.getDocuments(), 5, lang);
        StructureEncoder structure = new StructureEncoder();
        return builder.withInputEncoders("bloom", (Encoder)bloom, (Encoder)new DummyEncoder(), (Encoder)structure);
    }

    protected SectorAnnotator.Builder initializeInputEncodings_wemb(SectorAnnotator.Builder builder, Resource embeddingModel) throws IOException {
        Word2VecEncoder wordEmb = new Word2VecEncoder();
        wordEmb.loadModel(embeddingModel);
        StructureEncoder structure = new StructureEncoder();
        return builder.withInputEncoders("emb", (Encoder)new DummyEncoder(), (Encoder)wordEmb, (Encoder)structure);
    }

    protected SectorAnnotator.Builder initializeClassLabelsTarget(SectorAnnotator.Builder builder, Dataset train) {
        ArrayList<String> sections = new ArrayList<String>();
        for (Document doc : train.getDocuments()) {
            for (SectionAnnotation ann : doc.getAnnotations(SectionAnnotation.class)) {
                sections.add(ann.getSectionLabel());
            }
        }
        ClassEncoder targetEncoder = new ClassEncoder();
        targetEncoder.trainModel(sections, 0);
        return builder.withId("SEC>T").withTargetEncoder((Encoder)targetEncoder).withLossFunction((ILossFunction)new LossMCXENT(), Activation.SOFTMAX, false);
    }

    protected SectorAnnotator.Builder initializeHeadingsTarget(SectorAnnotator.Builder builder, Dataset train, WordHelpers.Language lang) {
        ArrayList<String> headings = new ArrayList<String>();
        for (Document doc : train.getDocuments()) {
            for (SectionAnnotation ann : doc.getAnnotations(SectionAnnotation.class)) {
                headings.add(ann.getSectionHeading());
            }
        }
        HeadingEncoder targetEncoder = new HeadingEncoder();
        targetEncoder.trainModel(headings, 20, lang);
        return builder.withId("SEC>H").withTargetEncoder((Encoder)targetEncoder).withLossFunction((ILossFunction)new LossBinaryXENT(), Activation.SIGMOID, false);
    }

    protected static class ExecParams
    implements CommandLineParser.Options {
        protected String trainFile;
        protected String devFile = null;
        protected String testFile = null;
        protected String outputPath = null;
        protected String embeddingsFile = null;
        protected String language = null;
        protected boolean trainingUI = false;
        protected boolean isHeadingsModel = false;

        protected ExecParams() {
        }

        public void setParams(CommandLine parse) {
            this.trainFile = parse.getOptionValue("i");
            this.devFile = parse.getOptionValue("v");
            this.testFile = parse.getOptionValue("t");
            this.outputPath = parse.getOptionValue("o");
            this.embeddingsFile = parse.getOptionValue("e");
            this.language = parse.getOptionValue("l", "en");
            this.trainingUI = parse.hasOption("u");
            this.isHeadingsModel = parse.hasOption("h");
        }

        public Options setUpCliOptions() {
            Options op = new Options();
            op.addRequiredOption("i", "input", true, "file name of WikiSection training dataset");
            op.addRequiredOption("o", "output", true, "path to create and store the model");
            op.addOption("h", "headings", false, "train multi-label model (SEC>H), otherwise single-label model (SEC>T) is used");
            op.addRequiredOption("o", "output", true, "path to create and store the model");
            op.addOption("v", "validation", true, "file name of WikiSection validation dataset (will use early stopping if given)");
            op.addOption("t", "test", true, "file name of WikiSection test dataset (will test after training if given)");
            op.addOption("e", "embedding", true, "path to word embedding model, will use bloom filters if not given");
            op.addOption("l", "language", true, "language to use for sentence splitting and stopwords (EN or DE)");
            op.addOption("u", "ui", false, "enable training UI (http://127.0.0.1:9000)");
            return op;
        }
    }
}

