/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.jcore.ae.jpos.main;

import cc.mallet.fst.CRF;
import cc.mallet.types.Alphabet;
import com.google.common.base.Charsets;
import com.google.common.io.Files;
import de.julielab.jcore.ae.jpos.tagger.POSTagger;
import de.julielab.jcore.ae.jpos.tagger.Sentence;
import de.julielab.jcore.ae.jpos.tagger.Unit;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Writer;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Properties;
import java.util.Random;

public class JPOSApplication {
    /**
     * Command-line entry point. The first argument selects the mode; the
     * remaining arguments are mode-specific:
     * <ul>
     *   <li>{@code x <trainData> <x-rounds> <featureConfigFile> [number of iterations]}</li>
     *   <li>{@code t <trainData> <model-out-file> <featureConfigFile> [number of iterations]}</li>
     *   <li>{@code p <unlabeled data> <modelFile> <outFile>}</li>
     *   <li>{@code c <predData> <goldData>}</li>
     *   <li>{@code oc <model>}</li>
     *   <li>{@code ts <model>}</li>
     * </ul>
     * On a usage error the usage text is printed to stderr and the JVM exits
     * with status -1. After a successful run the elapsed wall-clock time is
     * printed in whole minutes.
     *
     * @param args mode followed by the mode-specific parameters
     * @throws Exception anything thrown by the selected mode, including
     *                   {@link NumberFormatException} for a non-numeric
     *                   round or iteration count
     */
    public static void main(String[] args) throws Exception {
        long startTime = System.currentTimeMillis();
        if (args.length < 1) {
            System.err.println("usage: <mode> <mode-specific-parameters>");
            JPOSApplication.showModes();
            System.exit(-1);
        }
        String mode = args[0];
        if (mode.equals("x")) {
            if (args.length < 4) {
                System.err.println("usage: x <trainData> <x-rounds> <featureConfigFile> [number of iterations]");
                System.err.println("pred-out format: token pred gold");
                System.exit(-1);
            }
            File trainFile = new File(args[1]);
            int rounds = Integer.parseInt(args[2]);
            File featureConfigFile = new File(args[3]);
            // The iteration count is optional; 0 is passed on when it is absent.
            int numberIter = 0;
            if (args.length == 5) {
                numberIter = Integer.parseInt(args[4]);
            }
            JPOSApplication.evalXVal(trainFile, rounds, featureConfigFile, numberIter, true);
        } else if (mode.equals("t")) {
            if (args.length < 4) {
                System.err.println("usage: t <trainData> <model-out-file> <featureConfigFile> [number of iterations]");
                System.exit(-1);
            }
            File trainFile = new File(args[1]);
            File modelOutFile = new File(args[2]);
            File featureConfigFile = new File(args[3]);
            int numberIter = 0;
            if (args.length == 5) {
                numberIter = Integer.parseInt(args[4]);
            }
            JPOSApplication.train(trainFile, modelOutFile, featureConfigFile, numberIter);
        } else if (mode.equals("p")) {
            if (args.length != 4) {
                System.err.println("usage: p <unlabeled data> <modelFile> <outFile>");
                System.exit(-1);
            }
            File unlabeledFile = new File(args[1]);
            File modelFile = new File(args[2]);
            File outFile = new File(args[3]);
            JPOSApplication.predict(unlabeledFile, modelFile, outFile);
        } else if (mode.equals("c")) {
            if (args.length != 3) {
                System.err.println("\ncompares the gold standard agains the prediction");
                System.err.println("\nusage: c <predData> <goldData>");
                System.exit(-1);
            }
            File predFile = new File(args[1]);
            File goldFile = new File(args[2]);
            JPOSApplication.compare(predFile, goldFile);
        } else if (mode.equals("oc")) {
            if (args.length != 2) {
                System.err.println("\nusage: oc <model>");
                System.exit(-1);
            }
            JPOSApplication.printFeatureConfig(new File(args[1]));
        } else if (mode.equals("ts")) {
            if (args.length != 2) {
                System.err.println("\nusage: ts <model>");
                System.exit(-1);
            }
            JPOSApplication.printTagset(new File(args[1]));
        } else {
            System.err.println("ERR: unknown mode");
            JPOSApplication.showModes();
            System.exit(-1);
        }
        long timeNeeded = (System.currentTimeMillis() - startTime) / 1000L / 60L;
        System.out.println("Finished in " + timeNeeded + " minutes");
    }

    /**
     * Prints the list of supported modes to stderr and then terminates the
     * JVM with exit status -1.
     */
    static void showModes() {
        String[] helpLines = {
            "\nAvailable modes:",
            "x: cross validation ",
            "c: compare goldstandard and prediction",
            "t: train ",
            "p: predict ",
            "oc: output model configuration ",
            "ts: output model tagset"
        };
        for (String helpLine : helpLines) {
            System.err.println(helpLine);
        }
        System.exit(-1);
    }

    /**
     * Trains a tagger on the lines of {@code trainFile} and writes the
     * resulting model to {@code outFile}.
     *
     * @param trainFile         training data, read as UTF-8, one sentence per line
     * @param outFile           destination the model is written to
     * @param featureConfigFile feature configuration for the tagger; when
     *                          {@code null} the tagger's no-argument
     *                          constructor is used instead
     * @param number_iter       iteration count handed to the tagger
     * @throws IOException if the training file cannot be read
     */
    static void train(File trainFile, File outFile, File featureConfigFile, int number_iter) throws IOException {
        List<String> trainingLines = Files.readLines(trainFile, Charsets.UTF_8);

        POSTagger posTagger;
        if (featureConfigFile == null) {
            posTagger = new POSTagger();
        } else {
            posTagger = new POSTagger(featureConfigFile);
        }
        posTagger.set_Number_Iterations(number_iter);

        ArrayList<Sentence> trainingSentences = new ArrayList<Sentence>();
        for (int idx = 0; idx < trainingLines.size(); ++idx) {
            trainingSentences.add(posTagger.PPDtoUnits(trainingLines.get(idx)));
        }

        posTagger.train(trainingSentences);
        posTagger.writeModel(outFile.toString());
    }

    /**
     * Runs an n-fold cross-validation over the sentences in {@code dataFile}
     * and prints per-round accuracy plus the average and standard deviation.
     * <p>
     * The sentences are shuffled with a fixed seed, so repeated runs on the
     * same file produce the same folds. Every fold holds
     * {@code size / n} sentences for testing; the last fold additionally
     * takes the {@code size % n} leftover sentences.
     *
     * @param dataFile          labelled data, read as UTF-8, one sentence per line
     * @param n                 number of cross-validation rounds; must be at least 1
     * @param featureConfigFile feature configuration passed on to each round
     * @param number_iter       iteration count passed on to each round
     * @param maxEnt            not used by this method
     * @throws IOException              if the data file cannot be read
     * @throws IllegalArgumentException if {@code n} is smaller than 1
     */
    public static void evalXVal(File dataFile, int n, File featureConfigFile, int number_iter, boolean maxEnt) throws IOException {
        // Reject n <= 0 up front: it would otherwise surface as a division by
        // zero or a NegativeArraySizeException further down.
        if (n < 1) {
            throw new IllegalArgumentException("number of cross-validation rounds must be >= 1, but was " + n);
        }
        List<String> ppdData = Files.readLines(dataFile, Charsets.UTF_8);
        DecimalFormat df = new DecimalFormat("0.000");
        // Fixed seed keeps the fold assignment reproducible between runs.
        final long seed = 1L;
        Collections.shuffle(ppdData, new Random(seed));
        int sizeAll = ppdData.size();
        int sizeRound = sizeAll / n;
        int sizeLastRound = sizeRound + sizeAll % n;
        System.out.println(" * number of sentences: " + sizeAll);
        System.out.println(" * size of each/last round: " + sizeRound + "/" + sizeLastRound);
        System.out.println();
        double[] accuracies = new double[n];
        for (int i = 0; i < n; ++i) {
            // Test slice is [testStart, testEnd); the last round runs to the end
            // of the data so the remainder sentences are not dropped.
            int testStart = i * sizeRound;
            int testEnd = i == n - 1 ? sizeAll : testStart + sizeRound;
            ArrayList<String> ppdTestData = new ArrayList<String>(ppdData.subList(testStart, testEnd));
            ArrayList<String> ppdTrainData = new ArrayList<String>(sizeAll - ppdTestData.size());
            ppdTrainData.addAll(ppdData.subList(0, testStart));
            ppdTrainData.addAll(ppdData.subList(testEnd, sizeAll));
            System.out.println(" * training on: " + ppdTrainData.size() + " -- testing on: " + ppdTestData.size());
            double eval = JPOSApplication.eval(ppdTrainData, ppdTestData, featureConfigFile, number_iter, i);
            accuracies[i] = eval;
            System.out.println("\n** round " + (i + 1) + "\tAccuracy: " + df.format(eval));
        }
        double avgAcc = JPOSApplication.getAverage(accuracies);
        double stdAcc = JPOSApplication.getStandardDeviation(accuracies, avgAcc);
        StringBuilder summary = new StringBuilder();
        summary.append("Cross-validation results:\n");
        summary.append("Number of sentences in evaluation data set: " + sizeAll + "\n");
        summary.append("Number of sentences for training in each/last round: " + sizeRound + "/" + sizeLastRound + "\n\n");
        summary.append("Overall performance: avg (standard deviation)\n");
        summary.append("Accuracy: " + df.format(avgAcc) + "(" + df.format(stdAcc) + ")\n");
        System.out.println("\n\nCross-validation finished");
        System.out.println(summary);
    }

    /**
     * Computes the sample standard deviation (divisor {@code length - 1}) of
     * {@code values} around the supplied mean.
     *
     * @param values the observations
     * @param avg    the mean of {@code values}, as computed by the caller
     * @return the sample standard deviation, or {@code 0.0} when fewer than
     *         two values are given (the {@code length - 1} divisor would
     *         otherwise produce NaN for one value)
     */
    public static double getStandardDeviation(double[] values, double avg) {
        if (values.length < 2) {
            return 0.0;
        }
        double squaredDiffSum = 0.0;
        for (double value : values) {
            double diff = value - avg;
            squaredDiffSum += diff * diff;
        }
        return Math.sqrt(squaredDiffSum / (values.length - 1));
    }

    /**
     * Returns the arithmetic mean of {@code values}.
     *
     * @param values the numbers to average
     * @return the sum of all values divided by their count; NaN for an
     *         empty array, since that is 0.0 / 0.0
     */
    public static double getAverage(double[] values) {
        double total = 0.0;
        int idx = 0;
        while (idx < values.length) {
            total += values[idx];
            ++idx;
        }
        double count = (double) values.length;
        return total / count;
    }

    /**
     * Tags every line of {@code testDataFile} with the model stored in
     * {@code modelFile} and writes the tagger's output to {@code outFile}.
     * <p>
     * The output is written as UTF-8, matching the encoding used to read
     * the input files. The writer is always closed, and failures propagate
     * to the caller instead of being printed and ignored.
     *
     * @param testDataFile unlabelled text, read as UTF-8, one sentence per line
     * @param modelFile    serialized model to load
     * @param outFile      file the predictions are written to
     * @throws Exception if the model or input cannot be read, prediction
     *                   fails, or the output cannot be written
     */
    static void predict(File testDataFile, File modelFile, File outFile) throws Exception {
        List<String> testData = Files.readLines(testDataFile, Charsets.UTF_8);
        POSTagger tagger = POSTagger.readModel(modelFile);
        System.out.println("  * predicting...");
        long t1 = System.currentTimeMillis();
        // Opened before predicting so an unwritable output path fails fast;
        // try-with-resources closes it even when prediction throws.
        try (Writer out = Files.newWriter(outFile, Charsets.UTF_8)) {
            ArrayList<Sentence> sentences = new ArrayList<Sentence>(testData.size());
            for (String sentence : testData) {
                sentences.add(tagger.textToUnits(sentence));
            }
            for (String result : tagger.predictForCLI(sentences)) {
                out.write(result);
            }
        }
        long t2 = System.currentTimeMillis();
        System.out.println("prediction took: " + (t2 - t1));
    }

    /*
     * WARNING - void declaration
     */
    static double eval(ArrayList<String> ppdTrainData, ArrayList<String> ppdTestData, File featureConfigFile, int number_iter, int numXval) {
        void var12_19;
        ArrayList<Sentence> trainSentences = new ArrayList<Sentence>();
        ArrayList<Sentence> testSentences = new ArrayList<Sentence>();
        POSTagger tagger = featureConfigFile != null ? new POSTagger(featureConfigFile) : new POSTagger();
        tagger.set_Number_Iterations(number_iter);
        for (String ppdTrainSentence : ppdTrainData) {
            trainSentences.add(tagger.PPDtoUnits(ppdTrainSentence));
        }
        for (String ppdTestSentence : ppdTestData) {
            testSentences.add(tagger.PPDtoUnits(ppdTestSentence));
        }
        tagger.train(trainSentences);
        ArrayList<String> gold = new ArrayList<String>();
        for (int i = 0; i < testSentences.size(); ++i) {
            Sentence sentence = (Sentence)testSentences.get(i);
            for (Unit unit : sentence.getUnits()) {
                gold.add(unit.getRep() + "|" + unit.getLabel());
            }
        }
        ArrayList<String> pred = new ArrayList<String>();
        for (String predictedSentence : tagger.predictForCLI(testSentences)) {
            for (String predictedTag : predictedSentence.trim().split(" ")) {
                pred.add(predictedTag);
            }
        }
        double correct = 0.0;
        if (pred.size() != gold.size()) {
            throw new RuntimeException();
        }
        boolean bl = false;
        while (var12_19 < gold.size()) {
            if (((String)pred.get((int)var12_19)).replaceAll(".*\\|", "").equals(((String)gold.get((int)var12_19)).replaceAll(".*\\|", ""))) {
                correct += 1.0;
            } else {
                System.out.println("Predicted:\t" + (String)pred.get((int)var12_19) + "\tCorrect: " + (String)gold.get((int)var12_19));
            }
            ++var12_19;
        }
        return correct / (double)gold.size();
    }

    /**
     * Compares a prediction file against a gold-standard file and prints the
     * number of correct tokens, the number of tokens seen and the accuracy.
     * <p>
     * Both files are read as UTF-8 and must have the same number of lines
     * and, per line, the same number of tokens (separated by one or more
     * spaces). Tokens are compared on the text after their last {@code '|'}.
     * On a mismatch in line or token count an error is printed to stderr and
     * the JVM exits with status -1.
     *
     * @param predFile file with the predicted labels
     * @param goldFile file with the gold-standard labels
     * @throws IOException if either file cannot be read
     */
    static void compare(File predFile, File goldFile) throws IOException {
        List<String> gold = Files.readLines(goldFile, Charsets.UTF_8);
        List<String> pred = Files.readLines(predFile, Charsets.UTF_8);
        if (gold.size() != pred.size()) {
            System.err.println("ERR: number of lines in gold standard is different from prediction... please check!");
            System.exit(-1);
        }
        int correct = 0;
        int seen = 0;
        for (int i = 0; i < gold.size(); ++i) {
            String goldLine = gold.get(i);
            String predLine = pred.get(i);
            String[] goldToken = goldLine.split(" +");
            String[] predToken = predLine.split(" +");
            if (goldToken.length != predToken.length) {
                // Print the offending lines themselves; concatenating the
                // token arrays would only show their identity hash.
                System.err.println("ERR: number of tokens in gold standard is different from prediction for\n" + goldLine + "\n" + predLine);
                System.exit(-1);
            }
            for (int j = 0; j < goldToken.length; ++j) {
                ++seen;
                if (goldToken[j].replaceAll(".*\\|", "").equals(predToken[j].replaceAll(".*\\|", ""))) {
                    ++correct;
                }
            }
        }
        System.out.println("Correct: " + correct);
        System.out.println("Seen: " + seen);
        System.out.println("Accuracy: " + (double)correct / (double)seen);
    }

    /**
     * Loads the model in {@code modelFile} and prints every property of its
     * feature configuration to stdout as {@code key = value}, one per line.
     *
     * @param modelFile serialized model to load
     * @throws FileNotFoundException  declared for loading the model
     * @throws ClassNotFoundException declared for loading the model
     * @throws IOException            declared for loading the model
     */
    public static void printFeatureConfig(File modelFile) throws FileNotFoundException, ClassNotFoundException, IOException {
        Properties config = POSTagger.readModel(modelFile).getFeatureConfig();
        for (Enumeration<?> names = config.propertyNames(); names.hasMoreElements(); ) {
            String name = (String) names.nextElement();
            String value = config.getProperty(name);
            System.out.printf("%s = %s\n", name, value);
        }
    }

    /**
     * Loads the model in {@code modelFile}, treats it as a {@link CRF} and
     * prints every entry of its output alphabet to stdout, one per line.
     *
     * @param modelFile serialized model to load
     * @throws FileNotFoundException  declared for loading the model
     * @throws ClassNotFoundException declared for loading the model
     * @throws IOException            declared for loading the model
     */
    public static void printTagset(File modelFile) throws FileNotFoundException, ClassNotFoundException, IOException {
        POSTagger loadedTagger = POSTagger.readModel(modelFile);
        CRF crf = (CRF) loadedTagger.getModel();
        Alphabet outputAlphabet = crf.getOutputAlphabet();
        Object[] labels = outputAlphabet.toArray();
        for (int idx = 0; idx < labels.length; ++idx) {
            System.out.println(labels[idx]);
        }
    }
}

