/*
 * Decompiled with CFR 0.152.
 */
package banner.tagging;

import banner.tagging.FeatureSet;
import banner.tagging.TagFormat;
import banner.tagging.Tagger;
import banner.types.Sentence;
import cc.mallet.fst.CRF;
import cc.mallet.fst.CRFTrainerByLabelLikelihood;
import cc.mallet.fst.Transducer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.Sequence;
import cc.mallet.types.SparseVector;
import dragon.nlp.tool.Lemmatiser;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/**
 * A {@link Tagger} that labels sentences with a MALLET {@link CRF}.
 *
 * <p>The {@link FeatureSet} supplies the pipe that converts a {@link Sentence} into the CRF's input
 * sequence. {@code order} records whether the model was built with first-order ({@code 1}) or
 * second-order ({@code 2}) label transitions; see {@link #train}.
 */
public class CRFTagger implements Tagger {
    /** The trained model; protected so subclasses can access it directly. */
    protected CRF model;
    private final FeatureSet featureSet;
    private final int order;

    protected CRFTagger(CRF model, FeatureSet featureSet, int order) {
        this.model = model;
        this.featureSet = featureSet;
        this.order = order;
    }

    /**
     * Reads a tagger in the format produced by {@link #write(File)}.
     *
     * <p>The stream must contain, GZIP-compressed, a Java-serialized {@link CRF}, then a serialized
     * {@link FeatureSet}, then an {@code int} order. Once the object stream has been opened, it (and
     * therefore {@code f}) is closed on both success and failure.
     *
     * <p>SECURITY NOTE: this uses Java native deserialization, which can run code embedded in a
     * crafted stream. Only load model files from trusted sources.
     *
     * @param f source of the serialized model; closed by this method
     * @param lemmatiser if non-null, passed to {@code featureSet.setLemmatiser} after loading
     * @param posTagger if non-null, passed to {@code featureSet.setPosTagger} after loading
     * @param preTagger if non-null, passed to {@code featureSet.setPreTagger} after loading
     * @return the reconstructed tagger
     * @throws IOException if the stream cannot be read or closed
     * @throws RuntimeException wrapping a {@link ClassNotFoundException} if a serialized class is
     *     not on the classpath
     */
    public static CRFTagger load(InputStream f, Lemmatiser lemmatiser, dragon.nlp.tool.Tagger posTagger, Tagger preTagger) throws IOException {
        ObjectInputStream ois = new ObjectInputStream(new GZIPInputStream(f));
        try {
            CRF model = (CRF) ois.readObject();
            FeatureSet featureSet = (FeatureSet) ois.readObject();
            if (lemmatiser != null) {
                featureSet.setLemmatiser(lemmatiser);
            }
            if (posTagger != null) {
                featureSet.setPosTagger(posTagger);
            }
            if (preTagger != null) {
                featureSet.setPreTagger(preTagger);
            }
            int order = ois.readInt();
            // Close explicitly on success so that a failure to close is reported to the caller.
            ois.close();
            return new CRFTagger(model, featureSet, order);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        } finally {
            // Redundant after the explicit close above. On failure paths it releases the stream
            // without masking the exception already propagating.
            closeQuietly(ois);
        }
    }

    /**
     * Trains a new CRF on {@code sentences} using the pipe from {@code featureSet}.
     *
     * @param sentences training sentences; must not be empty
     * @param order 1 connects states as in {@code addStatesForLabelsConnectedAsIn}; 2 as in
     *     {@code addStatesForBiLabelsConnectedAsIn}
     * @param format not used by this method
     * @param featureSet supplies the pipe used both to build training instances and as the CRF's
     *     input pipe
     * @return a tagger wrapping the trained model
     * @throws IllegalArgumentException if {@code sentences} is empty or {@code order} is not 1 or 2
     */
    public static CRFTagger train(Set<Sentence> sentences, int order, TagFormat format, FeatureSet featureSet) {
        if (sentences.isEmpty()) {
            throw new IllegalArgumentException("Number of sentences must be greater than zero");
        }
        // Validate the order before piping any sentences, so an invalid order fails without first
        // running the whole training set through the feature pipe.
        if (order != 1 && order != 2) {
            throw new IllegalArgumentException("Order must be equal to 1 or 2");
        }
        InstanceList instances = new InstanceList(featureSet.getPipe());
        for (Sentence sentence : sentences) {
            instances.addThruPipe(toInstance(sentence));
        }
        CRF model = new CRF(featureSet.getPipe(), null);
        if (order == 1) {
            model.addStatesForLabelsConnectedAsIn(instances);
        } else {
            model.addStatesForBiLabelsConnectedAsIn(instances);
        }
        CRFTrainerByLabelLikelihood crfTrainer = new CRFTrainerByLabelLikelihood(model);
        crfTrainer.train(instances);
        return new CRFTagger(model, featureSet, order);
    }

    /**
     * Saves this tagger to {@code f} in the format read by {@link #load}.
     *
     * <p>The output is the model, the feature set and the order, Java-serialized into a GZIP stream.
     * This is best effort: an {@link IOException} is printed to {@code System.err} rather than
     * thrown, and the file is closed either way.
     *
     * @param f destination file
     */
    public void write(File f) {
        FileOutputStream fileStream = null;
        ObjectOutputStream oos = null;
        try {
            fileStream = new FileOutputStream(f);
            oos = new ObjectOutputStream(new GZIPOutputStream(fileStream));
            oos.writeObject(this.model);
            oos.writeObject(this.featureSet);
            oos.writeInt(this.order);
            // Closing writes the GZIP trailer; do it inside the try so a failure is reported.
            oos.close();
            oos = null;
            fileStream = null;
        } catch (IOException e) {
            System.err.println("Exception writing file " + f + ": " + e);
        } finally {
            // Only reached with non-null streams when an exception occurred above.
            closeQuietly(oos);
            closeQuietly(fileStream);
        }
    }

    /**
     * Runs the model over {@code sentence} and adds the resulting mentions to it.
     *
     * <p>The transduced output is converted to strings with {@link #getTagList}. {@code 1.0} is
     * passed as the second argument of {@code Sentence.addMentions}; presumably this is a
     * score or probability, so confirm against {@code Sentence}.
     */
    @Override
    public void tag(Sentence sentence) {
        Instance instance = this.getInstance(sentence);
        Sequence tags = this.model.transduce((Sequence) instance.getData());
        sentence.addMentions(CRFTagger.getTagList((Sequence<Object>) tags), 1.0);
    }

    /**
     * Pipes {@code sentence} through the model's input pipe.
     *
     * @return the piped instance, whose name is the sentence id and whose source is the sentence
     */
    protected Instance getInstance(Sentence sentence) {
        InstanceList iList = new InstanceList(this.model.getInputPipe());
        iList.addThruPipe(toInstance(sentence));
        return (Instance) iList.get(0);
    }

    /**
     * Wraps a sentence as an unpiped MALLET instance.
     *
     * <p>The sentence is both the data and the source, the target is {@code null}, and the sentence
     * id is the name.
     */
    private static Instance toInstance(Sentence sentence) {
        return new Instance(sentence, null, sentence.getSentenceId(), sentence);
    }

    /** Converts each element of {@code tags} to its {@code toString()} form, preserving order. */
    protected static List<String> getTagList(Sequence<Object> tags) {
        int size = tags.size();
        List<String> tagStrings = new ArrayList<String>(size);
        for (int i = 0; i < size; ++i) {
            tagStrings.add(tags.get(i).toString());
        }
        return tagStrings;
    }

    /** Returns the order (1 or 2 when built by {@link #train}) this tagger was created with. */
    public int getOrder() {
        return this.order;
    }

    /** Returns an unmodifiable set of every entry in the model's input alphabet, as strings. */
    public Set<String> getFeatureNames() {
        Set<String> featureNames = new HashSet<String>();
        Alphabet inputAlphabet = this.model.getInputAlphabet();
        int size = inputAlphabet.size();
        for (int i = 0; i < size; ++i) {
            featureNames.add(inputAlphabet.lookupObject(i).toString());
        }
        return Collections.unmodifiableSet(featureNames);
    }

    /**
     * Returns, for each token of the piped sentence, its sorted active feature names.
     *
     * <p>When a token's vector carries explicit values ({@code getValues()} is non-null), each
     * entry is written as {@code name=value}. This assumes the pipe produces {@link FeatureVector}
     * elements.
     */
    public List<List<String>> getFeatureRepresentation(Sentence sentence) {
        Instance instance = this.getInstance(sentence);
        Sequence sentenceSequence = (Sequence) instance.getData();
        Alphabet alphabet = this.model.getInputAlphabet();
        List<List<String>> sentenceFeatureRepresentation = new ArrayList<List<String>>();
        for (int i = 0; i < sentenceSequence.size(); ++i) {
            List<String> tokenFeatureRepresentation = new ArrayList<String>();
            FeatureVector tokenFeatures = (FeatureVector) sentenceSequence.get(i);
            int[] featureIndices = tokenFeatures.getIndices();
            double[] featureValues = tokenFeatures.getValues();
            for (int j = 0; j < featureIndices.length; ++j) {
                StringBuilder tokenFeature = new StringBuilder();
                tokenFeature.append(alphabet.lookupObject(featureIndices[j]).toString());
                if (featureValues != null) {
                    tokenFeature.append("=");
                    tokenFeature.append(featureValues[j]);
                }
                tokenFeatureRepresentation.add(tokenFeature.toString());
            }
            Collections.sort(tokenFeatureRepresentation);
            sentenceFeatureRepresentation.add(tokenFeatureRepresentation);
        }
        return sentenceFeatureRepresentation;
    }

    /**
     * Prints a model summary to {@code System.out} and writes one tab-separated line per input
     * feature to {@code fileName}.
     *
     * <p>Each line contains, in order:
     * <ul>
     *   <li>the index;</li>
     *   <li>the full name;</li>
     *   <li>the type (the prefix before the first {@code '='} or {@code '@'});</li>
     *   <li>the offset (the text after the first {@code '@'}, or {@code "0"});</li>
     *   <li>the data (the text between the type and the {@code '@'});</li>
     *   <li>the largest weight over all weight vectors whose name does not end in {@code "O:O"}.</li>
     * </ul>
     * Each line has a trailing tab.
     *
     * @throws IOException if the file cannot be opened or an error occurs while writing it
     */
    public void describe(String fileName) throws IOException {
        System.out.println("Number of default weights = " + this.model.getDefaultWeights().length);
        int numStates = this.model.numStates();
        System.out.println("Number of states = " + numStates);
        for (int i = 0; i < numStates; ++i) {
            Transducer.State state = this.model.getState(i);
            System.out.println("State " + i + " is " + state.getName());
        }
        SparseVector[] weights = this.model.getWeights();
        System.out.println("Size of weights vector = " + weights.length);
        // Computed once here instead of once per (feature, weight vector) pair in the loop below.
        boolean[] excludedFromMax = new boolean[weights.length];
        for (int i = 0; i < weights.length; ++i) {
            String weightsName = this.model.getWeightsName(i);
            excludedFromMax[i] = weightsName.endsWith("O:O");
            System.out.print("Number of non-zero values for weight vector " + i);
            System.out.println(" (" + weightsName + ") is " + weights[i].numLocations());
        }
        Alphabet inputAlphabet = this.model.getInputAlphabet();
        int size = inputAlphabet.size();
        System.out.println("Size of input alphabet: " + size);
        PrintWriter output = new PrintWriter(fileName);
        try {
            for (int i = 0; i < size; ++i) {
                String featureName = inputAlphabet.lookupObject(i).toString();
                int equalsIndex = featureName.indexOf("=");
                int atIndex = featureName.indexOf("@");
                int featureTypeEnd = featureName.length();
                if (equalsIndex != -1 && equalsIndex < featureTypeEnd) {
                    featureTypeEnd = equalsIndex;
                }
                if (atIndex != -1 && atIndex < featureTypeEnd) {
                    featureTypeEnd = atIndex;
                }
                String featureType = featureName.substring(0, featureTypeEnd);
                String featureOffset = "0";
                int featureDataEnd = featureName.length();
                if (atIndex != -1) {
                    featureDataEnd = atIndex;
                    featureOffset = featureName.substring(atIndex + 1);
                }
                String featureData = "";
                if (featureDataEnd > featureTypeEnd) {
                    featureData = featureName.substring(featureTypeEnd + 1, featureDataEnd);
                }
                // NOTE(review): the replacement "\\\"" is the regex replacement \" which yields a
                // plain '"', so this replaces a leading quote with itself (a no-op). If escaping
                // was intended, the replacement should be "\\\\\"". Left unchanged to preserve the
                // current output format.
                featureData = featureData.replaceAll("^\"", "\\\"");
                double maxWeight = Double.NEGATIVE_INFINITY;
                for (int j = 0; j < weights.length; ++j) {
                    if (excludedFromMax[j]) {
                        continue;
                    }
                    double weight = weights[j].value(i);
                    if (weight > maxWeight) {
                        maxWeight = weight;
                    }
                }
                output.print(i + "\t");
                output.print(featureName + "\t");
                output.print(featureType + "\t");
                output.print(featureOffset + "\t");
                output.print(featureData + "\t");
                output.print(maxWeight + "\t");
                output.println();
            }
        } finally {
            output.close();
        }
        // PrintWriter never throws on write failures; surface them through the declared exception.
        if (output.checkError()) {
            throw new IOException("Error writing feature description to " + fileName);
        }
    }

    /**
     * Maps each input feature name to its largest weight across all weight vectors.
     *
     * <p>The value is {@code Double.NEGATIVE_INFINITY} when the model has no weight vectors.
     * Unlike {@link #describe}, vectors named {@code "...O:O"} are included.
     */
    public Map<String, Double> getMaxWeights() {
        Map<String, Double> weightMap = new HashMap<String, Double>();
        SparseVector[] weights = this.model.getWeights();
        Alphabet inputAlphabet = this.model.getInputAlphabet();
        int size = inputAlphabet.size();
        for (int i = 0; i < size; ++i) {
            // Must start below every possible weight: Double.MIN_VALUE is the smallest *positive*
            // double, and would report all-negative features as ~4.9e-324.
            double max = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < weights.length; ++j) {
                double weight = weights[j].value(i);
                if (weight > max) {
                    max = weight;
                }
            }
            weightMap.put(inputAlphabet.lookupObject(i).toString(), max);
        }
        return weightMap;
    }

    /**
     * Maps each input feature name to its smallest weight across all weight vectors.
     *
     * <p>The value is {@code Double.POSITIVE_INFINITY} when the model has no weight vectors.
     */
    public Map<String, Double> getMinWeights() {
        Map<String, Double> weightMap = new HashMap<String, Double>();
        SparseVector[] weights = this.model.getWeights();
        Alphabet inputAlphabet = this.model.getInputAlphabet();
        int size = inputAlphabet.size();
        for (int i = 0; i < size; ++i) {
            double min = Double.POSITIVE_INFINITY;
            for (int j = 0; j < weights.length; ++j) {
                double weight = weights[j].value(i);
                if (weight < min) {
                    min = weight;
                }
            }
            weightMap.put(inputAlphabet.lookupObject(i).toString(), min);
        }
        return weightMap;
    }

    /**
     * Closes {@code stream} if it is non-null, ignoring any {@link IOException}.
     *
     * <p>Used only on cleanup paths where the outcome (success, or an exception already
     * propagating) has been decided.
     */
    private static void closeQuietly(java.io.Closeable stream) {
        if (stream == null) {
            return;
        }
        try {
            stream.close();
        } catch (IOException ignored) {
            // Best effort: the real outcome is already being returned or propagated.
        }
    }
}

