/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.lm;

import com.aliasi.io.BitInput;
import com.aliasi.io.BitOutput;
import com.aliasi.lm.BitTrieReader;
import com.aliasi.lm.BitTrieWriter;
import com.aliasi.lm.CharSeqCounter;
import com.aliasi.lm.NBestCounter;
import com.aliasi.lm.Node;
import com.aliasi.lm.NodeFactory;
import com.aliasi.lm.TrieReader;
import com.aliasi.lm.TrieWriter;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.Strings;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * A {@link CharSeqCounter} backed by a character trie rooted at a single
 * {@link Node}, counting character sequences up to a fixed maximum length.
 *
 * <p>Instances are mutable; this class performs no synchronization.
 */
public class TrieCharSeqCounter
implements CharSeqCounter {
    /** Root of the count trie; reassigned whenever a node operation returns a new root. */
    Node mRootNode = NodeFactory.createNode(0L);

    /** Maximum sequence length stored by {@link #incrementSubstrings(char[], int, int, int)}. */
    final int mMaxLength;

    /** Zero-length array used as the type token for {@code List.toArray}. */
    static final Node[] EMPTY_NODE_ARRAY = new Node[0];

    /**
     * Constructs an empty counter for sequences up to the given length.
     *
     * @param maxLength maximum length of sequences to count
     * @throws IllegalArgumentException if {@code maxLength < 0}
     */
    public TrieCharSeqCounter(int maxLength) {
        if (maxLength < 0) {
            String msg = "Max length must be >= 0. Found length=" + maxLength;
            throw new IllegalArgumentException(msg);
        }
        this.mMaxLength = maxLength;
    }

    /**
     * Returns the count stored for the slice {@code cs[start,end)}.
     *
     * @throws IndexOutOfBoundsException (presumably, via
     *     {@code Strings.checkArgsStartEnd}) if the slice is out of range
     */
    @Override
    public long count(char[] cs, int start, int end) {
        Strings.checkArgsStartEnd(cs, start, end);
        return this.mRootNode.count(cs, start, end);
    }

    /**
     * Returns the context (extension) count for the slice {@code cs[start,end)},
     * as computed by the root node's {@code contextCount}.
     */
    @Override
    public long extensionCount(char[] cs, int start, int end) {
        Strings.checkArgsStartEnd(cs, start, end);
        return this.mRootNode.contextCount(cs, start, end);
    }

    /** Returns a fresh copy of the characters that follow the empty sequence. */
    @Override
    public char[] observedCharacters() {
        return com.aliasi.util.Arrays.copy(this.mRootNode.outcomes(new char[0], 0, 0));
    }

    /** Returns a fresh copy of the characters observed after {@code cs[start,end)}. */
    @Override
    public char[] charactersFollowing(char[] cs, int start, int end) {
        Strings.checkArgsStartEnd(cs, start, end);
        return com.aliasi.util.Arrays.copy(this.mRootNode.outcomes(cs, start, end));
    }

    /** Returns the number of distinct characters observed after {@code cs[start,end)}. */
    @Override
    public int numCharactersFollowing(char[] cs, int start, int end) {
        Strings.checkArgsStartEnd(cs, start, end);
        return this.mRootNode.numOutcomes(cs, start, end);
    }

    /**
     * Returns the sum, over every n-gram order from 0 to the maximum length,
     * of the total n-gram count (column 1 of {@link #uniqueTotalNGramCount()}).
     */
    public long totalSequenceCount() {
        long sum = 0L;
        for (long[] uniqueTotal : this.uniqueTotalNGramCount()) {
            sum += uniqueTotal[1];
        }
        return sum;
    }

    /** Returns the total count of n-grams of the given length. */
    public long totalSequenceCount(int length) {
        return this.mRootNode.totalNGramCount(length);
    }

    /** Returns the root node's {@code size()} (presumably the number of trie nodes). */
    public long uniqueSequenceCount() {
        return this.mRootNode.size();
    }

    /** Returns the number of distinct n-grams of the given order. */
    public long uniqueSequenceCount(int nGramOrder) {
        return this.mRootNode.uniqueNGramCount(nGramOrder);
    }

    /**
     * Prunes the trie with the given minimum count. If pruning removes the
     * root, it is replaced by an empty root with count zero.
     *
     * @param minCount minimum count passed to {@code Node.prune}
     * @throws IllegalArgumentException if {@code minCount < 1}
     */
    public void prune(int minCount) {
        if (minCount < 1) {
            String msg = "Prune minimum count must be at least 1. Found minCount=" + minCount;
            throw new IllegalArgumentException(msg);
        }
        this.mRootNode = this.mRootNode.prune(minCount);
        if (this.mRootNode == null) {
            this.mRootNode = NodeFactory.createNode(0L);
        }
    }

    /**
     * Returns the counts of all n-grams of the given order, sorted in
     * descending order.
     *
     * <p>Counts are narrowed from {@code long} to {@code int}; counts above
     * {@code Integer.MAX_VALUE} wrap around.
     *
     * @param nGramOrder n-gram order whose counts are collected
     * @return counts sorted from largest to smallest (possibly empty)
     */
    public int[] nGramFrequencies(int nGramOrder) {
        List<Long> counts = this.countsList(nGramOrder);
        int[] result = new int[counts.size()];
        for (int i = 0; i < result.length; ++i) {
            result[i] = counts.get(i).intValue();
        }
        Arrays.sort(result);
        // Reverse the ascending sort in place with two converging pointers.
        // The previous loop double-swapped the middle pair for even lengths
        // and indexed -1 for empty arrays.
        for (int lo = 0, hi = result.length - 1; lo < hi; ++lo, --hi) {
            int tmp = result[lo];
            result[lo] = result[hi];
            result[hi] = tmp;
        }
        return result;
    }

    /**
     * Returns a {@code (mMaxLength + 1) x 2} table filled by the root node's
     * {@code addNGramCounts}. Column 1 is summed by {@link #totalSequenceCount()};
     * presumably column 0 holds unique counts and column 1 total counts per order.
     */
    public long[][] uniqueTotalNGramCount() {
        long[][] result = new long[this.mMaxLength + 1][2];
        this.mRootNode.addNGramCounts(result, 0);
        return result;
    }

    /**
     * Returns up to {@code maxReturn} of the highest-count n-grams of the
     * given order, mapped to their counts.
     */
    public ObjectToCounterMap<String> topNGrams(int nGramOrder, int maxReturn) {
        NBestCounter counter = new NBestCounter(maxReturn, true);
        this.mRootNode.topNGrams(counter, new char[nGramOrder], 0, nGramOrder);
        return counter.toObjectToCounter();
    }

    /** Returns the count of the specified character sequence. */
    public long count(CharSequence cSeq) {
        return this.count(com.aliasi.util.Arrays.toArray(cSeq), 0, cSeq.length());
    }

    /** Returns the extension count of the specified character sequence. */
    public long extensionCount(CharSequence cSeq) {
        // Delegate to the checked array overload, mirroring count(CharSequence).
        return this.extensionCount(com.aliasi.util.Arrays.toArray(cSeq), 0, cSeq.length());
    }

    /** Increments every substring of {@code cs[start,end)} by one. */
    public void incrementSubstrings(char[] cs, int start, int end) {
        this.incrementSubstrings(cs, start, end, 1);
    }

    /**
     * Increments by {@code count} every substring of {@code cs[start,end)}
     * whose length is at most the maximum length. It does so by incrementing
     * the prefixes of each maximal-length window, and then the prefixes of the
     * shorter tail windows.
     *
     * @param cs characters
     * @param start first index (inclusive)
     * @param end last index (exclusive)
     * @param count amount to add
     */
    public void incrementSubstrings(char[] cs, int start, int end, int count) {
        Strings.checkArgsStartEnd(cs, start, end);
        // Written as end - i >= max rather than i + max <= end so that very
        // large maximum lengths cannot overflow the comparison.
        for (int i = start; end - i >= this.mMaxLength; ++i) {
            this.incrementPrefixes(cs, i, i + this.mMaxLength, count);
        }
        for (int i = Math.max(start, end - this.mMaxLength + 1); i < end; ++i) {
            this.incrementPrefixes(cs, i, end, count);
        }
    }

    /** Increments every substring of the sequence by one. */
    public void incrementSubstrings(CharSequence cSeq) {
        this.incrementSubstrings(cSeq, 1);
    }

    /** Increments every substring of the sequence by {@code count}. */
    public void incrementSubstrings(CharSequence cSeq, int count) {
        this.incrementSubstrings(com.aliasi.util.Arrays.toArray(cSeq), 0, cSeq.length(), count);
    }

    /** Increments the prefixes of {@code cs[start,end)} by one. */
    public void incrementPrefixes(char[] cs, int start, int end) {
        this.incrementPrefixes(cs, start, end, 1);
    }

    /**
     * Calls the root node's {@code increment} on {@code cs[start,end)} with
     * {@code count}, replacing the root with the node it returns.
     */
    public void incrementPrefixes(char[] cs, int start, int end, int count) {
        Strings.checkArgsStartEnd(cs, start, end);
        this.mRootNode = this.mRootNode.increment(cs, start, end, count);
    }

    /**
     * Calls {@code Node.decrement} on every slice {@code cs[i,j)} with
     * {@code start <= i < end} and {@code i <= j <= end}, including the
     * empty slice at each {@code i}.
     *
     * <p>NOTE(review): unlike {@link #incrementSubstrings(char[], int, int, int)},
     * slice length is not bounded by the maximum length, so this issues
     * O(n^2) decrements. Confirm against {@code Node.decrement} semantics
     * before bounding it.
     */
    public void decrementSubstrings(char[] cs, int start, int end) {
        Strings.checkArgsStartEnd(cs, start, end);
        for (int i = start; i < end; ++i) {
            for (int j = i; j <= end; ++j) {
                this.mRootNode = this.mRootNode.decrement(cs, i, j);
            }
        }
    }

    /** Returns the root node's string representation. */
    @Override
    public String toString() {
        return this.mRootNode.toString();
    }

    /** Appends the root node's representation to {@code sb}, starting at level 0. */
    void toStringBuilder(StringBuilder sb) {
        this.mRootNode.toString(sb, 0);
    }

    /** Decrements the single-character sequence {@code c} by one. */
    public void decrementUnigram(char c) {
        this.decrementUnigram(c, 1);
    }

    /** Decrements the single-character sequence {@code c} by {@code count}. */
    public void decrementUnigram(char c, int count) {
        this.mRootNode = this.mRootNode.decrement(new char[]{c}, 0, 1, count);
    }

    /** Collects the counts of all n-grams of the given order into a new list. */
    private List<Long> countsList(int nGramOrder) {
        ArrayList<Long> accum = new ArrayList<Long>();
        this.mRootNode.addCounts(accum, nGramOrder);
        return accum;
    }

    /**
     * Writes this counter to the stream. The format is the delta-coded value
     * {@code maxLength + 1}, followed by the trie serialized through a
     * {@link BitTrieWriter}. The output is flushed but not closed.
     *
     * @throws IOException if the underlying stream fails
     */
    public void writeTo(OutputStream out) throws IOException {
        BitOutput bitOut = new BitOutput(out);
        bitOut.writeDelta((long)this.mMaxLength + 1L);
        BitTrieWriter writer = new BitTrieWriter(bitOut);
        TrieCharSeqCounter.writeCounter(this, writer, this.mMaxLength);
        bitOut.flush();
    }

    /**
     * Writes the counts of {@code counter} for all sequences up to
     * {@code maxNGram} characters to the trie writer, depth first.
     *
     * @throws IOException if the writer fails
     */
    public static void writeCounter(CharSeqCounter counter, TrieWriter writer, int maxNGram) throws IOException {
        TrieCharSeqCounter.writeCounter(new char[maxNGram], 0, counter, writer);
    }

    /**
     * Reads a counter in the format produced by {@link #writeTo(OutputStream)}.
     *
     * @throws IOException if the underlying stream fails
     */
    public static TrieCharSeqCounter readFrom(InputStream in) throws IOException {
        BitInput bitIn = new BitInput(in);
        int maxNGram = (int)(bitIn.readDelta() - 1L);
        BitTrieReader reader = new BitTrieReader(bitIn);
        return TrieCharSeqCounter.readCounter(reader, maxNGram);
    }

    /**
     * Reads a counter from the trie reader, keeping only nodes up to depth
     * {@code maxNGram}. Deeper nodes are consumed and discarded.
     *
     * @throws IOException if the reader fails
     * @throws IllegalArgumentException if {@code maxNGram < 0}
     */
    public static TrieCharSeqCounter readCounter(TrieReader reader, int maxNGram) throws IOException {
        TrieCharSeqCounter counter = new TrieCharSeqCounter(maxNGram);
        counter.mRootNode = TrieCharSeqCounter.readNode(reader, 0, maxNGram);
        return counter;
    }

    /**
     * Recursively writes the node for prefix {@code cs[0,pos)}. Each node is
     * written as its count, then (symbol, child) pairs while {@code pos} is
     * below {@code cs.length}, then the terminator symbol {@code -1}.
     */
    static void writeCounter(char[] cs, int pos, CharSeqCounter counter, TrieWriter writer) throws IOException {
        long count = counter.count(cs, 0, pos);
        writer.writeCount(count);
        if (pos < cs.length) {
            char[] csNext = counter.charactersFollowing(cs, 0, pos);
            for (int i = 0; i < csNext.length; ++i) {
                writer.writeSymbol(csNext[i]);
                // cs is reused as scratch space for the current path.
                cs[pos] = csNext[i];
                TrieCharSeqCounter.writeCounter(cs, pos + 1, counter, writer);
            }
        }
        writer.writeSymbol(-1L);
    }

    /** Consumes one serialized node and all of its descendants from the reader. */
    private static void skipNode(TrieReader reader) throws IOException {
        reader.readCount();
        while (reader.readSymbol() != -1L) {
            TrieCharSeqCounter.skipNode(reader);
        }
    }

    /**
     * Reads one serialized node at the given depth. Nodes with one to three
     * daughters use the specialized {@code NodeFactory} constructors; larger
     * fan-outs use the array-based constructor.
     */
    private static Node readNode(TrieReader reader, int depth, int maxDepth) throws IOException {
        if (depth > maxDepth) {
            TrieCharSeqCounter.skipNode(reader);
            return null;
        }
        long count = reader.readCount();
        int depthPlus1 = depth + 1;
        if (depthPlus1 > maxDepth) {
            // Daughters would exceed maxDepth. Consume them and return a leaf
            // instead of attaching null daughters. When the trie was written
            // with the same depth, these nodes have no daughters, so this
            // matches the zero-daughter case below.
            while (reader.readSymbol() != -1L) {
                TrieCharSeqCounter.skipNode(reader);
            }
            return NodeFactory.createNode(count);
        }
        long sym1 = reader.readSymbol();
        if (sym1 == -1L) {
            return NodeFactory.createNode(count);
        }
        Node node1 = TrieCharSeqCounter.readNode(reader, depthPlus1, maxDepth);
        long sym2 = reader.readSymbol();
        if (sym2 == -1L) {
            return NodeFactory.createNodeFold((char)sym1, node1, count);
        }
        Node node2 = TrieCharSeqCounter.readNode(reader, depthPlus1, maxDepth);
        long sym3 = reader.readSymbol();
        if (sym3 == -1L) {
            return NodeFactory.createNode((char)sym1, node1, (char)sym2, node2, count);
        }
        Node node3 = TrieCharSeqCounter.readNode(reader, depthPlus1, maxDepth);
        long sym4 = reader.readSymbol();
        if (sym4 == -1L) {
            return NodeFactory.createNode((char)sym1, node1, (char)sym2, node2, (char)sym3, node3, count);
        }
        Node node4 = TrieCharSeqCounter.readNode(reader, depthPlus1, maxDepth);
        StringBuilder cBuf = new StringBuilder();
        cBuf.append((char)sym1);
        cBuf.append((char)sym2);
        cBuf.append((char)sym3);
        cBuf.append((char)sym4);
        ArrayList<Node> nodeList = new ArrayList<Node>();
        nodeList.add(node1);
        nodeList.add(node2);
        nodeList.add(node3);
        nodeList.add(node4);
        for (long sym = reader.readSymbol(); sym != -1L; sym = reader.readSymbol()) {
            cBuf.append((char)sym);
            nodeList.add(TrieCharSeqCounter.readNode(reader, depthPlus1, maxDepth));
        }
        Node[] nodes = nodeList.toArray(EMPTY_NODE_ARRAY);
        char[] cs = Strings.toCharArray(cBuf);
        return NodeFactory.createNode(cs, nodes, count);
    }
}

