/*
 * Decompiled with CFR 0.152.
 */
package com.wcohen.ss.abbvGapsHmm;

import com.wcohen.ss.abbvGapsHmm.AbbreviationAlignmentContainer;
import com.wcohen.ss.abbvGapsHmm.AbbvGapsHmmBackwardsEvaluator;
import com.wcohen.ss.abbvGapsHmm.AbbvGapsHmmBackwardsViterbiEvaluator;
import com.wcohen.ss.abbvGapsHmm.AbbvGapsHmmExpectationEvaluator;
import com.wcohen.ss.abbvGapsHmm.AbbvGapsHmmForwardEvaluator;
import com.wcohen.ss.abbvGapsHmm.Acronym;
import com.wcohen.ss.abbvGapsHmm.Matrix3D;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Gapped hidden Markov model used to align an abbreviation (an {@link Acronym}'s short form)
 * with a candidate long form.
 *
 * <p>Emission and transition parameters are estimated from a corpus with an EM loop
 * ({@link #trainCorpus}) built on the forward / backward / expectation evaluators of this
 * package. They can be saved to, or loaded from, a tab-separated parameter file.
 *
 * <p>Parameters and counters are kept in lists indexed by the ordinal of the corresponding
 * {@link Emissions} / {@link Transitions} constant. The loader reads values positionally, so the
 * enum declaration order is part of the parameter-file format.
 */
public class AbbvGapsHMM {
    /** Path of the DF data loaded by {@link #setTfIdfData}; {@code null} when none is loaded. */
    private String _tfIdfDataFile = null;
    /** Words whose DF is below this value are not kept in {@link #_commonWordDF}. */
    private Double _dfWordThreshold = 0.2;
    /** Word to DF value, for words at or above {@link #_dfWordThreshold}; {@code null} until loaded. */
    private Map<String, Double> _commonWordDF = null;
    /** Transition counters updated by the E-step, indexed by {@link Transitions} ordinal. */
    List<Double> _transitionCounters = new ArrayList<Double>();
    /** Emission counters updated by the E-step, indexed by {@link Emissions} ordinal. */
    List<Double> _emissionCounters = new ArrayList<Double>();
    /** Forward matrix from the most recent E-step. */
    Matrix3D _alpha;
    /** Backward matrix from the most recent E-step. */
    Matrix3D _beta;
    /** Transition probabilities, indexed by {@link Transitions} ordinal. */
    List<Double> _transitionParams = new ArrayList<Double>();
    /** Emission probabilities, indexed by {@link Emissions} ordinal. */
    List<Double> _emissionParams = new ArrayList<Double>();
    /** True once {@link #setStartingParams} supplied the initial parameters. */
    boolean _externalySet = false;
    /** Start probability per {@link States} ordinal; lazily built by {@link #initStartProbs}. */
    List<Double> _stateStartProb = null;
    /** EM stops once the summed absolute parameter change of an iteration drops below this. */
    private static final Double CHANGE_THRESHOLD = 0.01;
    /** Hard cap on the number of EM iterations. */
    private static final int MAX_ITERATIONS = 300;
    /** Parameter file used by {@link #saveModelParams} / {@link #loadModelParams}; may be null. */
    private String _modelParamsFile = null;

    /** Transition name prefixes; transitions sharing a prefix are normalized together. */
    private static final String[] TRANSITION_GROUP_PREFIXES = {"t_DL_", "t_M_", "t_D_", "t_S_", "t_I_"};
    /** Emission name prefixes; emissions sharing a prefix are normalized together. */
    private static final String[] EMISSION_GROUP_PREFIXES = {"e_DL_", "e_M_", "e_D_", "e_I_"};

    /** Creates a model with no parameter file (parameters are neither loaded nor saved). */
    public AbbvGapsHMM() {
        this._modelParamsFile = null;
    }

    /**
     * Creates a model backed by a parameter file.
     *
     * @param modelParamFile file read by {@link #loadModelParams} and written by {@link #saveModelParams}
     */
    public AbbvGapsHMM(String modelParamFile) {
        this._modelParamsFile = modelParamFile;
    }

    /**
     * Creates a model backed by a parameter file.
     *
     * @param modelParamFile file read by {@link #loadModelParams} and written by {@link #saveModelParams}
     * @param allowVowelsMatch currently ignored; kept for API compatibility
     */
    public AbbvGapsHMM(String modelParamFile, boolean allowVowelsMatch) {
        this._modelParamsFile = modelParamFile;
    }

    /** Returns the live (mutable) emission-parameter list, indexed by {@link Emissions} ordinal. */
    public List<Double> getEmmisionParams() {
        return this._emissionParams;
    }

    /** Returns the live (mutable) transition-parameter list, indexed by {@link Transitions} ordinal. */
    public List<Double> getTransitionParams() {
        return this._transitionParams;
    }

    /** Returns true if DF data has been successfully loaded via {@link #setTfIdfData}. */
    public boolean useTDIDF() {
        return this._tfIdfDataFile != null;
    }

    /**
     * Looks up the DF value of a word.
     *
     * @param word the word to look up
     * @return its DF value, or {@code null} if the word was filtered out, is unknown, or no DF
     *     data has been loaded
     */
    public Double getDF(String word) {
        if (this._commonWordDF == null) {
            return null;
        }
        return this._commonWordDF.get(word);
    }

    /**
     * Loads DF data from a file whose lines are {@code "<word> <df>"} (single-space separated).
     *
     * <p>Only words with DF at or above the threshold (0.2) are kept. Blank lines are skipped.
     * State is replaced only after the whole file has been read successfully. If loading fails,
     * the previously loaded data (if any) is left untouched.
     *
     * @param dataFile path of the DF file
     * @throws IOException if the file cannot be read or a line lacks a DF column
     * @throws NumberFormatException if a DF column is not a number
     */
    public void setTfIdfData(String dataFile) throws IOException {
        Map<String, Double> commonWords = new HashMap<String, Double>();
        BufferedReader fi = new BufferedReader(new FileReader(dataFile));
        try {
            String line;
            while ((line = fi.readLine()) != null) {
                if (line.trim().length() == 0) {
                    continue; // tolerate blank lines, e.g. a trailing newline
                }
                String[] parts = line.split(" ");
                if (parts.length < 2) {
                    throw new IOException("Malformed DF line in " + dataFile + ": " + line);
                }
                double df = Double.parseDouble(parts[1]);
                if (df < this._dfWordThreshold) {
                    continue;
                }
                commonWords.put(parts[0], df);
            }
        } finally {
            fi.close();
        }
        this._commonWordDF = commonWords;
        this._tfIdfDataFile = dataFile;
    }

    /** Lazily sets the start distribution: probability 1.0 for state S, 0.0 for all others. */
    protected void initStartProbs() {
        if (this._stateStartProb != null) {
            return;
        }
        States[] states = States.values();
        this._stateStartProb = new ArrayList<Double>(states.length);
        for (States state : states) {
            this._stateStartProb.add(state == States.S ? 1.0 : 0.0);
        }
    }

    /** Sets the parameter file used by {@link #saveModelParams} / {@link #loadModelParams}. */
    public void setParamFile(String paramFile) {
        this._modelParamsFile = paramFile;
    }

    /**
     * Loads parameters from the parameter file, or trains on {@code corpus} if loading fails.
     *
     * @return true if parameters were loaded or training completed
     */
    public boolean train(List<List<Acronym>> corpus, List<Map<String, String>> trueLabels) {
        if (!this.loadModelParams()) {
            return this.trainCorpus(corpus, trueLabels);
        }
        return true;
    }

    /**
     * Trains on {@code corpus} when {@code force} is set; otherwise only loads the parameter file.
     *
     * @return the training result, or whether loading succeeded
     */
    public boolean train(List<List<Acronym>> corpus, List<Map<String, String>> trueLabels, boolean force) {
        if (force) {
            return this.trainCorpus(corpus, trueLabels);
        }
        return this.loadModelParams();
    }

    /**
     * Supplies initial parameters that {@link #initModelParamsAndCounters} will keep instead of
     * the 0.5 defaults. Both lists are copied.
     *
     * @param emmisions emission parameters, indexed by {@link Emissions} ordinal
     * @param transitions transition parameters, indexed by {@link Transitions} ordinal
     */
    public void setStartingParams(List<Double> emmisions, List<Double> transitions) {
        this._emissionParams.clear();
        this._emissionParams.addAll(emmisions);
        this._transitionParams.clear();
        this._transitionParams.addAll(transitions);
        this._externalySet = true;
    }

    /**
     * Resets all counters to 0.0 and prepares the parameters for training.
     *
     * <p>Without externally supplied parameters, every parameter is set to 0.5. With them
     * ({@link #setStartingParams}), the supplied values are kept, and only missing trailing
     * entries are padded with 0.5. In both cases the END emission is forced to 1.0.
     */
    public void initModelParamsAndCounters() {
        int numEmissions = Emissions.values().length;
        this._emissionCounters.clear();
        if (!this._externalySet) {
            this._emissionParams.clear();
        }
        for (int i = 0; i < numEmissions; ++i) {
            this._emissionCounters.add(0.0);
        }
        while (this._emissionParams.size() < numEmissions) {
            this._emissionParams.add(0.5);
        }

        int numTransitions = Transitions.values().length;
        this._transitionCounters.clear();
        if (!this._externalySet) {
            this._transitionParams.clear();
        }
        for (int i = 0; i < numTransitions; ++i) {
            this._transitionCounters.add(0.0);
        }
        while (this._transitionParams.size() < numTransitions) {
            this._transitionParams.add(0.5);
        }

        this._emissionParams.set(Emissions.e_END_end.ordinal(), 1.0);
    }

    /**
     * Runs EM over every acronym of every document.
     *
     * <p>Iteration stops when the summed parameter change drops below the change threshold, or
     * after {@link #MAX_ITERATIONS} iterations. The resulting parameters are then saved.
     *
     * @param corpus one list of acronyms per document
     * @param trueLabels ignored: it is overwritten with {@code null}, so training is unsupervised
     * @return always true
     */
    protected boolean trainCorpus(List<List<Acronym>> corpus, List<Map<String, String>> trueLabels) {
        // Supervised labels are deliberately disabled; the labelled branch below is kept for reference.
        // NOTE(review): confirm this override is intentional.
        trueLabels = null;
        this.initModelParamsAndCounters();
        // NOTE(review): counters are reset only here, not once per EM iteration. Confirm whether
        // AbbvGapsHmmExpectationEvaluator resets them; otherwise counts accumulate across iterations.
        int numDocs = corpus.size();
        int iterations = 0;
        boolean converge = false;
        System.out.print("training:");
        while (!converge) {
            for (int i = 0; i < numDocs; ++i) {
                Map<String, String> docTrueLabels = trueLabels != null ? trueLabels.get(i) : null;
                for (Acronym currAcronym : corpus.get(i)) {
                    String trueLongForm = docTrueLabels != null ? docTrueLabels.get(currAcronym._shortForm) : null;
                    this.expectationStep(currAcronym, trueLongForm);
                }
            }
            Double change = this.maximizationStep();
            System.out.print(".");
            ++iterations;
            if (change.compareTo(CHANGE_THRESHOLD) < 0) {
                System.out.println("\n\tTraining converged in " + iterations + " iterations.");
                converge = true;
            } else if (iterations >= MAX_ITERATIONS) {
                System.out.println("\n\tTraining stopped after " + iterations + " iterations with final change: " + change);
                converge = true;
            }
        }
        this.saveModelParams();
        return true;
    }

    /**
     * E-step for one acronym: runs the backward and forward passes, then updates the counters.
     *
     * <p>The acronym is skipped when the backward probability of the start state at position
     * (0, 0) is zero, i.e. the model assigns it no probability.
     *
     * @param acronym the acronym to process
     * @param trueLongForm currently unused
     */
    protected void expectationStep(Acronym acronym, String trueLongForm) {
        AbbvGapsHmmBackwardsEvaluator backEval = new AbbvGapsHmmBackwardsEvaluator(this);
        backEval.backwardEvaluate(acronym, this._transitionParams, this._emissionParams);
        this._beta = backEval.getEvalMatrix();
        if (this._beta.at(0, 0, States.S.ordinal()) == 0.0) {
            return;
        }
        AbbvGapsHmmForwardEvaluator forEval = new AbbvGapsHmmForwardEvaluator(this);
        forEval.forwardEvaluate(acronym, this._transitionParams, this._emissionParams);
        this._alpha = forEval.getEvalMatrix();
        AbbvGapsHmmExpectationEvaluator expectationEval = new AbbvGapsHmmExpectationEvaluator(this);
        expectationEval.expectationEvaluate(acronym, this._transitionCounters, this._emissionCounters, this._transitionParams, this._emissionParams, this._alpha, this._beta);
        this._transitionCounters = expectationEval.getTransitionCounters();
        this._emissionCounters = expectationEval.getEmissionCounters();
    }

    /** Computes the best (Viterbi) alignment for {@code acronym} under the current parameters. */
    public AbbreviationAlignmentContainer<Emissions, States> viterbi(Acronym acronym) {
        AbbvGapsHmmBackwardsViterbiEvaluator viterbi = new AbbvGapsHmmBackwardsViterbiEvaluator(this);
        return viterbi.backwardViterbiEvaluate(acronym, this._transitionParams, this._emissionParams);
    }

    /**
     * Writes the parameters to the parameter file as {@code name<TAB>value} lines, in enum order.
     * A {@code #} header line precedes each section.
     *
     * <p>Does nothing when no parameter file is set. I/O errors are printed, not thrown.
     */
    public void saveModelParams() {
        if (this._modelParamsFile == null) {
            return;
        }
        try {
            BufferedWriter bw = new BufferedWriter(new FileWriter(this._modelParamsFile));
            try {
                bw.write("# Emmisions\n");
                Emissions[] emissions = Emissions.values();
                for (int i = 0; i < emissions.length; ++i) {
                    bw.write(emissions[i].toString() + "\t" + this._emissionParams.get(i) + "\n");
                }
                bw.write("# Transitions\n");
                Transitions[] transitions = Transitions.values();
                for (int i = 0; i < transitions.length; ++i) {
                    bw.write(transitions[i].toString() + "\t" + this._transitionParams.get(i) + "\n");
                }
            } finally {
                bw.close();
            }
        }
        catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Reads parameters written by {@link #saveModelParams}.
     *
     * <p>Lines starting with {@code #} and blank lines are skipped. Values are assigned by
     * position: the first {@code Emissions.values().length} values are emissions, the rest are
     * transitions.
     *
     * @return true if exactly the expected number of transition values was read; false when no
     *     file is set, the file is missing, unreadable or malformed (params are then cleared), or
     *     the value count does not match
     */
    public boolean loadModelParams() {
        if (this._modelParamsFile == null) {
            return false;
        }
        File f = new File(this._modelParamsFile);
        if (!f.exists()) {
            return false;
        }
        try {
            BufferedReader fi = new BufferedReader(new FileReader(this._modelParamsFile));
            try {
                this._emissionParams.clear();
                this._transitionParams.clear();
                int numEmissions = Emissions.values().length;
                String line;
                while ((line = fi.readLine()) != null) {
                    if (line.startsWith("#") || line.trim().length() == 0) {
                        continue;
                    }
                    String[] parts = line.split("\t");
                    if (parts.length < 2) {
                        this._emissionParams.clear();
                        this._transitionParams.clear();
                        return false;
                    }
                    double value = Double.parseDouble(parts[1]);
                    if (this._emissionParams.size() < numEmissions) {
                        this._emissionParams.add(value);
                    } else {
                        this._transitionParams.add(value);
                    }
                }
            } finally {
                fi.close();
            }
            return this._transitionParams.size() == Transitions.values().length;
        }
        catch (NumberFormatException e) {
            // A corrupt value is treated like an unreadable file, so callers fall back to training.
            this._emissionParams.clear();
            this._transitionParams.clear();
            e.printStackTrace();
            return false;
        }
        catch (IOException e) {
            this._emissionParams.clear();
            this._transitionParams.clear();
            e.printStackTrace();
            return false;
        }
    }

    /**
     * M-step: re-estimates emissions, then transitions.
     *
     * @return the summed absolute change over all parameters
     */
    protected Double maximizationStep() {
        double valChange = 0.0;
        valChange += this.maximizationStepForEmissions();
        valChange += this.maximizationStepForTransitions();
        return valChange;
    }

    /**
     * Re-estimates the transition parameters, normalizing within each source-state group
     * (DL, M, D, S, I).
     *
     * @return the summed absolute change of the transition parameters
     */
    protected Double maximizationStepForTransitions() {
        return this.normalizeByGroup(Transitions.values(), TRANSITION_GROUP_PREFIXES, this._transitionCounters, this._transitionParams);
    }

    /**
     * Returns the smoothed count of entry {@code i}: its counter plus its current parameter
     * raised to the smoothing exponent (1.0, i.e. the parameter itself).
     */
    protected double smoothCounter(int i, List<Double> counters, List<Double> params) {
        double alpha = 1.0;
        return counters.get(i) + Math.pow(params.get(i), alpha);
    }

    /** Returns {@code current / total}, or 0.0 when {@code total} is zero. */
    protected double getNewStateVal(double current, double total) {
        if (total == 0.0) {
            return 0.0;
        }
        return current / total;
    }

    /**
     * Re-estimates the emission parameters, normalizing within each state group (DL, M, D, I).
     * The END emission matches no group, so it is set to 1.0.
     *
     * @return the summed absolute change of the emission parameters
     */
    protected Double maximizationStepForEmissions() {
        return this.normalizeByGroup(Emissions.values(), EMISSION_GROUP_PREFIXES, this._emissionCounters, this._emissionParams);
    }

    /**
     * Re-estimates {@code params} in place.
     *
     * <p>Each entry whose name starts with one of {@code groupPrefixes} becomes its smoothed
     * count divided by the sum of smoothed counts in that group. Entries matching no prefix are
     * set to 1.0. All smoothed counts are taken from the parameters as they were before this
     * call.
     *
     * @return the sum of absolute changes applied to {@code params}
     */
    private Double normalizeByGroup(Enum<?>[] entries, String[] groupPrefixes, List<Double> counters, List<Double> params) {
        int n = entries.length;
        int[] group = new int[n];
        double[] smoothed = new double[n];
        double[] totals = new double[groupPrefixes.length];
        // Pass 1: smoothed counts and per-group totals, computed from the current params.
        for (int i = 0; i < n; ++i) {
            group[i] = groupIndexOf(entries[i].name(), groupPrefixes);
            if (group[i] >= 0) {
                smoothed[i] = this.smoothCounter(i, counters, params);
                totals[group[i]] += smoothed[i];
            }
        }
        // Pass 2: normalize within each group and record how far the params moved.
        double valChange = 0.0;
        for (int i = 0; i < n; ++i) {
            double newVal = group[i] < 0 ? 1.0 : this.getNewStateVal(smoothed[i], totals[group[i]]);
            valChange += Math.abs(params.get(i) - newVal);
            params.set(i, newVal);
        }
        return valChange;
    }

    /** Returns the index of the first prefix that {@code name} starts with, or -1 if none. */
    private static int groupIndexOf(String name, String[] prefixes) {
        for (int g = 0; g < prefixes.length; ++g) {
            if (name.startsWith(prefixes[g])) {
                return g;
            }
        }
        return -1;
    }

    /**
     * Emission types. The {@code e_<STATE>_} prefix selects the normalization group. Declaration
     * order defines the parameter index and the parameter-file order; do not reorder.
     */
    public static enum Emissions {
        e_DL_alphaNumeric_to_none,
        e_DL_nonAlphaNumeric_to_none,
        e_DL_word_to_none,
        e_D_alphaNumeric_to_none,
        e_D_word_to_none,
        e_D_none_to_nonAlphaNumeric,
        e_M_partialWord_to_letter,
        e_M_word_to_firstLetter,
        e_M_letter_to_letter,
        e_M_nonAlphaNumeric_to_none,
        e_M_commonWordDeletion,
        e_M_AND_to_symbol,
        e_M_one_to_1,
        e_M_two_to_2,
        e_M_three_to_3,
        e_M_four_to_4,
        e_M_five_to_5,
        e_M_six_to_6,
        e_M_seven_to_7,
        e_M_eight_to_8,
        e_M_nine_to_9,
        e_M_Silver_Ag,
        e_M_Gold_Au,
        e_M_Copper_Cu,
        e_M_Iron_Fe,
        e_M_Mercury_Hg,
        e_M_Potassium_K,
        e_M_Sodium_Na,
        e_M_Lead_Pb,
        e_M_Antimony_Sb,
        e_M_Tin_Sn,
        e_M_Tungsten_W,
        e_END_end;

    }

    /**
     * Transition types. The {@code t_<STATE>_} prefix names the source state and selects the
     * normalization group. Declaration order defines the parameter index and file order; do not
     * reorder.
     */
    public static enum Transitions {
        t_DL_in,
        t_DL_to_M,
        t_M_in,
        t_M_to_D,
        t_M_to_END,
        t_D_in,
        t_D_to_M,
        t_D_to_END,
        t_S_to_M,
        t_S_to_DL;

    }

    /** HMM states; S is the only state with non-zero start probability (see {@link #initStartProbs}). */
    public static enum States {
        S,
        DL,
        M,
        D,
        END;

    }
}

