/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.pipe;

import cc.mallet.classify.BalancedWinnowTrainer;
import cc.mallet.classify.Classification;
import cc.mallet.classify.Classifier;
import cc.mallet.classify.ClassifierTrainer;
import cc.mallet.classify.Trial;
import cc.mallet.pipe.Noop;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.SerialPipes;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.LabelVector;
import cc.mallet.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.HashMap;
import java.util.logging.Logger;

public class AddClassifierTokenPredictions
extends Pipe
implements Serializable {
    // Shared logger; the nested TokenClassifiers class also writes its training messages here.
    private static Logger logger = MalletLogger.getLogger(AddClassifierTokenPredictions.class.getName());
    // Ranks whose predicted label is turned into a "TOK_PRED=<label>_@_RANK_<rank>" feature.
    int[] m_predRanks2add;
    // Per-token classifier(s) that produce the predictions added by pipe().
    TokenClassifiers m_tokenClassifiers;
    // Stored from the constructor argument; not read anywhere else in this file.
    boolean m_binary;
    // When false, pipe() asks m_tokenClassifiers for out-of-fold predictions; when true it does not.
    boolean m_inProduction;
    // Clone of the token classifier's alphabet, extended with the TOK_PRED features.
    Alphabet m_dataAlphabet;
    private static final long serialVersionUID = 1L;

    /**
     * Builds the pipe from a training list only, delegating to the two-list constructor
     * with a {@code null} test list.
     *
     * @param trainList sequence instances used to train the token classifiers
     */
    public AddClassifierTokenPredictions(InstanceList trainList) {
        this(trainList, null);
    }

    /**
     * Trains default {@link TokenClassifiers} on {@code trainList} and configures the pipe to add
     * the prediction at rank {@code 1} with the binary flag set to {@code true}.
     *
     * <p>Both lists are re-wrapped via {@link #convert(InstanceList, Noop)} using the pipe of
     * {@code trainList}, which is cast to {@link Noop} (a different pipe type therefore fails
     * with a {@link ClassCastException}).
     *
     * @param trainList sequence instances used to train the token classifiers
     * @param testList  optional list used only to log token-classifier accuracy; may be {@code null}
     */
    public AddClassifierTokenPredictions(InstanceList trainList, InstanceList testList) {
        this(
            new TokenClassifiers(convert(trainList, (Noop) trainList.getPipe())),
            new int[] {1},
            true,
            convert(testList, (Noop) trainList.getPipe()));
    }

    /**
     * Creates the pipe around already-trained token classifiers.
     *
     * <p>The output alphabet is a clone of the token classifiers' alphabet to which one feature
     * named {@code TOK_PRED=<label>_@_RANK_<rank>} is added for every (rank, label) pair. The
     * pipe starts with the in-production flag set to {@code false}.
     *
     * @param tokenClassifiers trained per-token classifiers
     * @param predRanks2add    ranks whose predicted labels become features
     * @param binary           stored as-is; not consulted by the code in this class
     * @param testList         if non-null, token-classifier accuracy on it is logged
     */
    public AddClassifierTokenPredictions(TokenClassifiers tokenClassifiers, int[] predRanks2add, boolean binary, InstanceList testList) {
        m_predRanks2add = predRanks2add;
        m_binary = binary;
        m_tokenClassifiers = tokenClassifiers;
        m_inProduction = false;

        // Work on a copy so the classifier's own alphabet is not grown by the new features.
        m_dataAlphabet = (Alphabet) tokenClassifiers.getAlphabet().clone();
        LabelAlphabet labels = tokenClassifiers.getLabelAlphabet();
        for (int rank : m_predRanks2add) {
            for (int labelIdx = 0; labelIdx < labels.size(); ++labelIdx) {
                String featureName = "TOK_PRED=" + labels.lookupObject(labelIdx).toString() + "_@_RANK_" + rank;
                m_dataAlphabet.lookupIndex(featureName, true);
            }
        }

        if (testList == null) {
            return;
        }
        Trial testTrial = new Trial(m_tokenClassifiers, testList);
        logger.info("Token classifier accuracy on test set = " + testTrial.getAccuracy());
    }

    /**
     * Sets the in-production flag. {@link #pipe(Instance)} requests out-of-fold predictions
     * only while this flag is {@code false}.
     *
     * @param inProduction new value of the flag
     */
    public void setInProduction(boolean inProduction) {
        m_inProduction = inProduction;
    }

    /**
     * @return the current value of the in-production flag
     */
    public boolean getInProduction() {
        return m_inProduction;
    }

    /**
     * Applies the in-production flag to {@code p} if it is an
     * {@code AddClassifierTokenPredictions}, or recursively to every member pipe if it is a
     * {@link SerialPipes}. Any other pipe is left untouched.
     *
     * @param p     pipe (or pipe chain) to update
     * @param value flag value to set
     */
    public static void setInProduction(Pipe p, boolean value) {
        if (p instanceof AddClassifierTokenPredictions) {
            ((AddClassifierTokenPredictions) p).setInProduction(value);
            return;
        }
        if (!(p instanceof SerialPipes)) {
            return;
        }
        SerialPipes chain = (SerialPipes) p;
        for (int idx = 0; idx < chain.size(); ++idx) {
            setInProduction(chain.getPipe(idx), value);
        }
    }

    /**
     * @return the alphabet built in the constructor: the token classifiers' alphabet plus the
     *         {@code TOK_PRED=...} prediction features
     */
    @Override
    public Alphabet getDataAlphabet() {
        return m_dataAlphabet;
    }

    /**
     * Rewrites the carrier's {@link FeatureVectorSequence} so that every token vector uses this
     * pipe's alphabet and carries one extra {@code TOK_PRED=<label>_@_RANK_<rank>} feature
     * (value {@code 1.0}) per configured rank.
     *
     * <p>Each token is classified individually; out-of-fold classification is requested
     * whenever the in-production flag is {@code false}.
     *
     * @param carrier instance whose data is a {@code FeatureVectorSequence}
     * @return the same instance, with its data replaced
     */
    @Override
    public Instance pipe(Instance carrier) {
        FeatureVectorSequence original = (FeatureVectorSequence) carrier.getData();
        // Split the sequence into one instance per token so the token classifier can score it.
        InstanceList tokens = convert(carrier, (Noop) m_tokenClassifiers.getInstancePipe());
        assert (original.size() == tokens.size());

        FeatureVector[] augmented = new FeatureVector[original.size()];
        for (int pos = 0; pos < tokens.size(); ++pos) {
            Instance token = (Instance) tokens.get(pos);
            Classification prediction = m_tokenClassifiers.classify(token, !m_inProduction);
            LabelVector ranking = prediction.getLabelVector();

            AugmentableFeatureVector tokenVector = (AugmentableFeatureVector) token.getData();
            int[] featureIndices = tokenVector.getIndices();
            // Last argument is sized for the existing features plus one per added rank
            // (presumably a capacity hint -- confirm against AugmentableFeatureVector).
            AugmentableFeatureVector extended = new AugmentableFeatureVector(
                m_dataAlphabet, featureIndices, tokenVector.getValues(), featureIndices.length + m_predRanks2add.length);

            for (int rank : m_predRanks2add) {
                // NOTE(review): the rank is passed through unchanged; confirm whether
                // getLabelAtRank is 0- or 1-based, since the default rank set is {1}.
                Label predicted = ranking.getLabelAtRank(rank);
                int featureIdx = m_dataAlphabet.lookupIndex("TOK_PRED=" + predicted.toString() + "_@_RANK_" + rank);
                assert (featureIdx >= 0);
                extended.add(featureIdx, 1.0);
            }
            augmented[pos] = extended;
        }

        carrier.setData(new FeatureVectorSequence(augmented));
        return carrier;
    }

    /**
     * Copies the instances of {@code ilist} into a new list that is backed by
     * {@code alphabetsPipe}.
     *
     * @param ilist         list to copy; may be {@code null}
     * @param alphabetsPipe pipe handed to the new list's constructor
     * @return the new list, or {@code null} when {@code ilist} is {@code null}
     */
    public static InstanceList convert(InstanceList ilist, Noop alphabetsPipe) {
        if (ilist == null) {
            return null;
        }

        InstanceList copy = new InstanceList(alphabetsPipe);
        for (Instance member : ilist) {
            copy.add(member);
        }
        return copy;
    }

    /**
     * Splits one sequence instance into a list with one instance per token.
     *
     * <p>Token {@code j} (0-based) gets a fresh {@link AugmentableFeatureVector} over
     * {@code alphabetsPipe}'s data alphabet, the label at position {@code j} as target, the
     * sequence's source, and the name {@code <sequence name>_@_POS_<j+1>} (using
     * {@code "NONAME"} when the sequence has no name). Each token instance is passed through
     * {@code alphabetsPipe} before being added.
     *
     * <p>The target is dereferenced unconditionally, so an instance whose target is
     * {@code null} fails with a {@link NullPointerException} as soon as it is accessed.
     *
     * @param inst          instance holding a {@code FeatureVectorSequence} and a
     *                      {@code LabelSequence}
     * @param alphabetsPipe pipe supplying the alphabet and backing the returned list
     * @return list of per-token instances, in sequence order
     */
    public static InstanceList convert(Instance inst, Noop alphabetsPipe) {
        InstanceList tokens = new InstanceList(alphabetsPipe);

        Object rawData = inst.getData();
        assert (rawData instanceof FeatureVectorSequence);
        FeatureVectorSequence sequence = (FeatureVectorSequence) rawData;
        LabelSequence labels = (LabelSequence) inst.getTarget();
        assert (sequence.size() == labels.size());

        Object baseName = inst.getName();
        if (baseName == null) {
            baseName = "NONAME";
        }

        for (int pos = 0; pos < sequence.size(); ++pos) {
            FeatureVector tokenVector = sequence.getFeatureVector(pos);
            int[] featureIndices = tokenVector.getIndices();
            AugmentableFeatureVector tokenData = new AugmentableFeatureVector(
                alphabetsPipe.getDataAlphabet(), featureIndices, tokenVector.getValues(), featureIndices.length);
            Label tokenLabel = labels.getLabelAtPosition(pos);
            // Positions in the generated names are 1-based.
            String tokenName = baseName.toString() + "_@_POS_" + (pos + 1);
            Instance token = new Instance(tokenData, tokenLabel, tokenName, inst.getSource());
            tokens.add(alphabetsPipe.pipe(token));
        }
        return tokens;
    }

    public static class TokenClassifiers
    extends Classifier
    implements Serializable {
        // Number of cross-validation folds; 0 disables the per-fold classifiers (see doTraining).
        int m_numCV;
        // Seed handed to the cross-validation iterator.
        int m_randSeed;
        // Trainer used both for the full-data classifier and for every fold classifier.
        ClassifierTrainer m_trainer;
        // Classifier trained on the entire training list.
        Classifier m_tokenClassifier;
        // Maps an instance name to the fold classifier for which that instance was held out.
        HashMap m_table;
        private static final long serialVersionUID = 1L;
        // Version tag of the custom serialized form written by writeObject / checked by readObject.
        private static final int CURRENT_SERIAL_VERSION = 1;

        /**
         * Trains with random seed {@code 0} and {@code 5} cross-validation folds.
         *
         * @param trainList per-token training instances
         */
        public TokenClassifiers(InstanceList trainList) {
            this(trainList, 0, 5);
        }

        /**
         * Trains using a new {@link BalancedWinnowTrainer}.
         *
         * @param trainList per-token training instances
         * @param randSeed  seed for the cross-validation split
         * @param numCV     number of cross-validation folds
         */
        public TokenClassifiers(InstanceList trainList, int randSeed, int numCV) {
            this(new BalancedWinnowTrainer(), trainList, randSeed, numCV);
        }

        /**
         * Stores the configuration and immediately trains on {@code trainList}.
         *
         * @param trainer   trainer used for the full-data and per-fold classifiers
         * @param trainList per-token training instances; its pipe becomes this classifier's pipe
         * @param randSeed  seed for the cross-validation split
         * @param numCV     number of cross-validation folds
         */
        public TokenClassifiers(ClassifierTrainer trainer, InstanceList trainList, int randSeed, int numCV) {
            super(trainList.getPipe());
            m_trainer = trainer;
            m_randSeed = randSeed;
            m_numCV = numCV;
            m_table = new HashMap();
            doTraining(trainList);
        }

        /**
         * Trains the full-data token classifier and, unless {@code m_numCV} is {@code 0}, one
         * additional classifier per cross-validation fold.
         *
         * <p>For every fold, the classifier trained on the in-fold part is registered in
         * {@code m_table} under the name of each held-out instance, so that
         * {@link #classify(Instance, boolean)} can later return a prediction from a classifier
         * that never saw that instance. Instances sharing a name overwrite each other's entry.
         *
         * <p>Accuracy figures are logged for the full training set and for both parts of every
         * fold. Folds are numbered from 1 in the log output.
         *
         * @param trainList per-token training instances
         */
        private void doTraining(InstanceList trainList) {
            logger.info("Training token classifier on entire data set (size=" + trainList.size() + ")...");
            this.m_tokenClassifier = this.m_trainer.train(trainList);
            Trial fullTrial = new Trial(this.m_tokenClassifier, trainList);
            logger.info("Training set accuracy = " + fullTrial.getAccuracy());
            if (this.m_numCV == 0) {
                return;
            }

            // CrossValidationIterator is an inner class of InstanceList, so it has to be created
            // through the list instance it splits.
            InstanceList.CrossValidationIterator cvIter = trainList.new CrossValidationIterator(this.m_numCV, this.m_randSeed);
            int foldNumber = 0;
            while (cvIter.hasNext()) {
                ++foldNumber;
                // fold[0] is the part trained on, fold[1] the held-out part.
                InstanceList[] fold = cvIter.nextSplit();
                logger.info("Training token classifier on cv fold " + foldNumber + " / " + this.m_numCV + " (size=" + fold[0].size() + ")...");
                Classifier foldClassifier = this.m_trainer.train(fold[0]);
                Trial inFoldTrial = new Trial(foldClassifier, fold[0]);
                Trial outOfFoldTrial = new Trial(foldClassifier, fold[1]);
                logger.info("Within-fold accuracy = " + inFoldTrial.getAccuracy());
                logger.info("Out-of-fold accuracy = " + outOfFoldTrial.getAccuracy());
                // Remember, per held-out instance name, the classifier that did not train on it.
                for (Instance heldOut : fold[1]) {
                    this.m_table.put(heldOut.getName(), foldClassifier);
                }
            }
        }

        /**
         * Classifies with the classifier trained on the full training list (no out-of-fold lookup).
         *
         * @param instance token instance to classify
         * @return the classification result
         */
        @Override
        public Classification classify(Instance instance) {
            return classify(instance, false);
        }

        /**
         * Classifies a token instance, optionally with the fold classifier registered for it.
         *
         * <p>When {@code useOutOfFold} is {@code true} and the instance's name has an entry in
         * the fold table, that entry's classifier is used; in every other case the classifier
         * trained on the full training list is used.
         *
         * @param instance     token instance to classify
         * @param useOutOfFold whether to prefer the classifier registered under the instance name
         * @return the classification result
         */
        public Classification classify(Instance instance, boolean useOutOfFold) {
            Object key = instance.getName();
            if (useOutOfFold && m_table.containsKey(key)) {
                Classifier foldClassifier = (Classifier) m_table.get(key);
                return foldClassifier.classify(instance);
            }
            return m_tokenClassifier.classify(instance);
        }

        /**
         * Writes the custom serialized form: version tag, instance pipe, fold count, random
         * seed, fold table, full-data classifier and trainer, in that order.
         *
         * @param out stream to write to
         * @throws IOException if writing fails
         */
        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
            out.writeObject(getInstancePipe());
            out.writeInt(m_numCV);
            out.writeInt(m_randSeed);
            out.writeObject(m_table);
            out.writeObject(m_tokenClassifier);
            out.writeObject(m_trainer);
        }

        /**
         * Restores the state written by {@code writeObject}, reading the fields in the same order.
         *
         * @param in stream to read from
         * @throws IOException            if reading fails
         * @throws ClassNotFoundException if a serialized class cannot be resolved, or if the
         *                                stored version tag differs from the supported version
         */
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int storedVersion = in.readInt();
            if (storedVersion != CURRENT_SERIAL_VERSION) {
                throw new ClassNotFoundException("Mismatched TokenClassifiers versions: wanted " + CURRENT_SERIAL_VERSION + ", got " + storedVersion);
            }
            instancePipe = (Pipe) in.readObject();
            m_numCV = in.readInt();
            m_randSeed = in.readInt();
            m_table = (HashMap) in.readObject();
            m_tokenClassifier = (Classifier) in.readObject();
            m_trainer = (ClassifierTrainer) in.readObject();
        }
    }
}

