package org.apache.joshua.decoder.ff.lm.bloomfilter_lm;

import java.io.Externalizable;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.io.ObjectOutput;
import java.io.ObjectOutputStream;
import java.util.HashMap;
import java.util.Iterator;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.decoder.ff.lm.DefaultNGramLanguageModel;
import org.apache.joshua.util.Regex;
import org.apache.joshua.util.io.LineReader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:joshua-incubating-6.1.jar:org/apache/joshua/decoder/ff/lm/bloomfilter_lm/BloomFilterLanguageModel.class */
/**
 * An n-gram language model whose (quantized) counts are stored in a {@link BloomFilter}.
 *
 * <p>Counts are quantized on a logarithmic scale with base {@code quantizationBase}; a count with
 * quantized value {@code q} is represented by inserting the n-gram together with every level
 * {@code 1..q} into the filter. Two independent sets of hash functions are kept: one for n-gram
 * counts ({@code countFuncs}) and one for the number of distinct continuations of a history
 * ({@code typesFuncs}).
 *
 * <p>A model is either built from a statistics file via {@link #main(String[])} or restored from a
 * gzip-compressed file previously produced by {@link #writeExternal(ObjectOutput)}.
 *
 * <p>NOTE(review): the class implements {@link Externalizable} but declares no public no-arg
 * constructor; within this file {@code readExternal}/{@code writeExternal} are only ever invoked
 * directly, never through {@code ObjectInputStream.readObject()}.
 */
public class BloomFilterLanguageModel extends DefaultNGramLanguageModel implements Externalizable {
    /** Seed multiplied into the initial value of the n-gram hash. */
    public static final int HASH_SEED = 17;
    /** Multiplier applied at every step of the n-gram hash. */
    public static final int HASH_OFFSET = 37;
    /** Not referenced inside this class. */
    public static final double MAX_SCORE = 100.0d;
    private static final Logger LOG = LoggerFactory.getLogger(BloomFilterLanguageModel.class);

    /** The filter holding every (n-gram, quantization level) pair. */
    private BloomFilter bf;
    /** Base of the logarithmic count quantization. */
    private double quantizationBase;
    /** Natural log of the sum of all unigram counts. */
    private double numTokens;
    /** Hash functions used when storing/querying n-gram counts. */
    private long[][] countFuncs;
    /** Hash functions used when storing/querying "types after history" counts. */
    private long[][] typesFuncs;
    /** Log-domain base term returned by {@link #wittenBell} when the last word is unseen. */
    private transient double p0;
    /** Log-domain weight applied to the unigram estimate in {@link #wittenBell}. */
    private transient double lambda0;
    /** Quantized value of the total token count; upper bound for unigram count lookups. */
    private transient int maxQ;

    /**
     * Restores a model from a gzip-compressed file written by {@link #writeExternal(ObjectOutput)}.
     *
     * @param order the n-gram order, passed to the superclass
     * @param filename path of the serialized model
     * @throws IOException if the file cannot be read or its content cannot be rebuilt
     */
    public BloomFilterLanguageModel(int order, String filename) throws IOException {
        super(order);
        // try-with-resources closes all three layers even when a wrapping constructor fails.
        try (FileInputStream fileIn = new FileInputStream(filename);
                GZIPInputStream gzipIn = new GZIPInputStream(fileIn);
                ObjectInputStream objectIn = new ObjectInputStream(gzipIn)) {
            readExternal(objectIn);
        } catch (ClassNotFoundException e) {
            throw new IOException("Could not rebuild bloom filter LM from file " + filename, e);
        }
        int vocabSize = Vocabulary.size();
        double logVocab = Math.log(vocabSize);
        double logVocabPlusTokens = logAdd(logVocab, this.numTokens);
        this.p0 = -Math.log(vocabSize + 1);
        this.p0 += this.numTokens - logVocabPlusTokens;
        this.lambda0 = logVocab - logVocabPlusTokens;
        this.maxQ = quantize((long) Math.exp(this.numTokens));
    }

    /**
     * Builds a new model from an n-gram statistics file.
     *
     * @param filename path of the statistics file (gzip-compressed if it ends in {@code .gz})
     * @param order the n-gram order, passed to the superclass
     * @param size size argument handed to the {@link BloomFilter} constructor
     * @param base base of the logarithmic count quantization
     * @throws IOException if the statistics file cannot be opened or read
     */
    private BloomFilterLanguageModel(String filename, int order, int size, double base) throws IOException {
        super(order);
        this.quantizationBase = base;
        populateBloomFilter(size, filename);
    }

    /**
     * Computes the interpolated log-probability of the last word of {@code ngram} from the counts
     * stored in the bloom filter.
     *
     * @param ngram word ids, last element is the word being scored
     * @param order maximum number of positions (counted from the end) to take into account
     * @return the log-probability, narrowed to {@code float}
     */
    private float wittenBell(int[] ngram, int order) {
        final int end = ngram.length;
        double logProb = this.p0;
        int bound = getCount(ngram, end - 1, end, this.maxQ);
        if (bound == 0) {
            // The word itself was never seen: only the base term applies.
            return (float) logProb;
        }
        logProb = logAdd(logProb, this.lambda0 + (Math.log(unQuantize(bound)) - this.numTokens));
        if (end == 1) {
            return (float) logProb;
        }
        for (int start = end - 2; start >= end - order && start >= 0; start--) {
            int spanCount = getCount(ngram, start, end, bound);
            if (spanCount == 0) {
                // Nothing stored for this span, so longer spans cannot contribute either.
                return (float) logProb;
            }
            int spanTypes = getTypesAfter(ngram, start, end, spanCount);
            double count = unQuantize(spanCount);
            double types = 1.0d + unQuantize(spanTypes);
            double logNorm = Math.log(types + count);
            double logLambda = Math.log(types) - logNorm;
            double scaled = logProb + (Math.log(count) - logNorm);
            int shorterQ = getCount(ngram, start + 1, end, spanTypes);
            double shorterCount = unQuantize(shorterQ);
            if (shorterCount == 0.0d) {
                return (float) scaled;
            }
            logProb = logAdd(scaled, (logLambda + Math.log(shorterCount)) - Math.log(count));
            bound = shorterQ;
        }
        return (float) logProb;
    }

    /**
     * Looks up the quantized count of {@code ngram[start, end)}.
     *
     * @param ngram word ids
     * @param start first index of the span (inclusive)
     * @param end last index of the span (exclusive)
     * @param limit highest quantization level to probe
     * @return the highest level {@code q <= limit} such that all levels {@code 1..q} are present
     */
    private int getCount(int[] ngram, int start, int end, int limit) {
        for (int level = 1; level <= limit; level++) {
            if (!this.bf.query(hashNgram(ngram, start, end, level), this.countFuncs)) {
                return level - 1;
            }
        }
        return limit;
    }

    /**
     * Looks up the quantized "types after" value of {@code ngram[start, end)}.
     *
     * @param ngram word ids
     * @param start first index of the span (inclusive)
     * @param end last index of the span (exclusive)
     * @param limit upper bound for the returned value
     * @return 0 if the span has no stored count; otherwise the first missing level minus one among
     *         levels {@code 1..limit-1}, or {@code limit} if none of them is missing
     */
    private int getTypesAfter(int[] ngram, int start, int end, int limit) {
        // A span without any count cannot have continuations.
        if (!this.bf.query(hashNgram(ngram, start, end, 1), this.countFuncs)) {
            return 0;
        }
        for (int level = 1; level < limit; level++) {
            if (!this.bf.query(hashNgram(ngram, start, end, level), this.typesFuncs)) {
                return level - 1;
            }
        }
        return limit;
    }

    /**
     * Maps a raw count onto the logarithmic quantization scale.
     *
     * @param count raw count
     * @return {@code 1 + floor(log_base(count))}
     */
    private int quantize(long count) {
        return 1 + ((int) Math.floor(Math.log(count) / Math.log(this.quantizationBase)));
    }

    /**
     * Maps a quantized level back to a representative raw count.
     *
     * @param level quantized level
     * @return 0 for level 0, otherwise {@code ((base + 1) * base^(level - 1) - 1) / 2}
     */
    private double unQuantize(int level) {
        if (level == 0) {
            return 0.0d;
        }
        return (((this.quantizationBase + 1.0d) * Math.pow(this.quantizationBase, level - 1)) - 1.0d) / 2.0d;
    }

    /**
     * Hashes the span {@code ngram[start, end)} together with a quantization level.
     *
     * @param ngram word ids
     * @param start first index of the span (inclusive)
     * @param end last index of the span (exclusive)
     * @param level value mixed into the initial hash state
     * @return the hash code (int arithmetic, may wrap)
     */
    private int hashNgram(int[] ngram, int start, int end, int level) {
        int hash = HASH_OFFSET * HASH_SEED + level;
        for (int pos = start; pos < end; pos++) {
            hash = (HASH_OFFSET * hash) + ngram[pos];
        }
        return hash;
    }

    /**
     * Adds two values given in the natural-log domain.
     *
     * @param x first log-domain value
     * @param y second log-domain value
     * @return {@code log(exp(x) + exp(y))}, computed relative to the larger argument
     */
    private static double logAdd(double x, double y) {
        if (y <= x) {
            return x + Math.log1p(Math.exp(y - x));
        }
        return y + Math.log1p(Math.exp(x - y));
    }

    /**
     * Command-line entry point: builds a model from a statistics file and writes it, gzip-compressed,
     * to the output file.
     *
     * @param strArr {@code <statistics file> <order> <size> <quantization base> <output file>}
     */
    public static void main(String[] strArr) {
        if (strArr.length < 5) {
            System.err.println("usage: BloomFilterLanguageModel <statistics file> <order> <size> <quantization base> <output file>");
            LOG.error("usage: BloomFilterLanguageModel <statistics file> <order> <size> <quantization base> <output file>");
            return;
        }
        try {
            int order = Integer.parseInt(strArr[1]);
            // The size argument is scaled by 2^23; the double-to-int cast saturates at Integer.MAX_VALUE.
            int size = (int) (Integer.parseInt(strArr[2]) * Math.pow(2.0d, 23.0d));
            double base = Double.parseDouble(strArr[3]);
            BloomFilterLanguageModel model = new BloomFilterLanguageModel(strArr[0], order, size, base);
            try (FileOutputStream fileOut = new FileOutputStream(strArr[4]);
                    GZIPOutputStream gzipOut = new GZIPOutputStream(fileOut);
                    ObjectOutputStream objectOut = new ObjectOutputStream(gzipOut)) {
                model.writeExternal(objectOut);
            }
        } catch (IOException e) {
            LOG.error(e.getMessage(), e);
        }
    }

    /**
     * Opens the statistics file, transparently decompressing it when its name ends in {@code .gz}.
     *
     * @param filename path of the statistics file
     * @return a stream over the (decompressed) content
     * @throws IOException if the file cannot be opened
     */
    private static InputStream openStatistics(String filename) throws IOException {
        FileInputStream raw = new FileInputStream(filename);
        if (!filename.endsWith(".gz")) {
            return raw;
        }
        try {
            return new GZIPInputStream(raw);
        } catch (IOException e) {
            // Do not leak the file handle when the gzip header is unreadable.
            raw.close();
            throw e;
        }
    }

    /**
     * Creates the bloom filter and fills it from the statistics file. The file is read twice: once to
     * estimate the number of objects to store, once to insert the counts.
     *
     * @param size size argument handed to the {@link BloomFilter} constructor
     * @param filename path of the statistics file
     * @throws IOException if the statistics file cannot be opened or closed
     */
    private void populateBloomFilter(int size, String filename) throws IOException {
        HashMap<String, Long> typesAfter = new HashMap<>();
        try (InputStream dataStream = openStatistics(filename);
                InputStream estimateStream = openStatistics(filename)) {
            int expectedObjects = estimateNumberOfObjects(estimateStream);
            LOG.debug("Estimated number of objects: {}", expectedObjects);
            this.bf = new BloomFilter(size, expectedObjects);
            this.countFuncs = this.bf.initializeHashFunctions();
            populateFromInputStream(dataStream, typesAfter);
        }
        this.typesFuncs = this.bf.initializeHashFunctions();
        for (String history : typesAfter.keySet()) {
            String[] words = Regex.spaces.split(history);
            int[] ids = new int[words.length];
            for (int k = 0; k < words.length; k++) {
                ids[k] = Vocabulary.id(words[k]);
            }
            add(ids, typesAfter.get(history).longValue(), this.typesFuncs);
        }
    }

    /**
     * Estimates how many objects will be inserted into the filter by scanning the statistics once.
     *
     * @param inputStream stream over the statistics; each non-blank line ends in a count
     * @return {@code lines * log_base(maxCount)}, rounded and clamped to {@code Integer.MAX_VALUE}
     */
    private int estimateNumberOfObjects(InputStream inputStream) {
        int lines = 0;
        long maxCount = 0;
        Iterator<String> it = new LineReader(inputStream).iterator();
        while (it.hasNext()) {
            String line = it.next();
            if (line.trim().equals("")) {
                continue;
            }
            String[] fields = Regex.spaces.split(line);
            if (fields.length > this.ngramOrder + 1) {
                // n-gram longer than the model order: not stored.
                continue;
            }
            try {
                long count = Long.parseLong(fields[fields.length - 1]);
                if (count > maxCount) {
                    maxCount = count;
                }
                lines++;
            } catch (NumberFormatException e) {
                LOG.error(e.getMessage(), e);
            }
        }
        long estimate = Math.round(lines * (Math.log(maxCount) / Math.log(this.quantizationBase)));
        // Clamp instead of letting the narrowing cast wrap around for very large inputs.
        return (int) Math.min(estimate, (long) Integer.MAX_VALUE);
    }

    /**
     * Inserts every n-gram count of the statistics into the filter, accumulates the log of the total
     * unigram count in {@code numTokens}, and counts how many stored n-grams share each history.
     *
     * @param inputStream stream over the statistics; each line is {@code w1 ... wn count}
     * @param hashMap receives, per history string, the number of n-grams seen with that history
     */
    private void populateFromInputStream(InputStream inputStream, HashMap<String, Long> hashMap) {
        this.numTokens = Double.NEGATIVE_INFINITY;
        Iterator<String> it = new LineReader(inputStream).iterator();
        while (it.hasNext()) {
            String[] fields = Regex.spaces.split(it.next());
            if (fields.length < 2 || fields.length > this.ngramOrder + 1) {
                continue;
            }
            long count;
            try {
                count = Long.parseLong(fields[fields.length - 1]);
            } catch (NumberFormatException e) {
                // Same policy as estimateNumberOfObjects: report the malformed line and skip it.
                LOG.error(e.getMessage(), e);
                continue;
            }
            int wordCount = fields.length - 1;
            int[] ids = new int[wordCount];
            StringBuilder history = new StringBuilder();
            for (int k = 0; k < wordCount; k++) {
                ids[k] = Vocabulary.id(fields[k]);
                if (k < wordCount - 1) {
                    history.append(fields[k]).append(" ");
                }
            }
            add(ids, count, this.countFuncs);
            if (wordCount == 1) {
                this.numTokens = logAdd(this.numTokens, Math.log(count));
            } else {
                // The map is keyed by String; looking it up with the StringBuilder never matches.
                String key = history.toString();
                Long seen = hashMap.get(key);
                hashMap.put(key, seen == null ? 1L : seen + 1L);
            }
        }
    }

    /**
     * Stores a count for an n-gram by inserting the n-gram at every level from 1 up to the quantized
     * count.
     *
     * @param ngram word ids; {@code null} is ignored
     * @param count raw count to store
     * @param funcs hash functions selecting which logical table the entry belongs to
     */
    private void add(int[] ngram, long count, long[][] funcs) {
        if (ngram == null) {
            return;
        }
        int levels = quantize(count);
        for (int level = 1; level <= levels; level++) {
            this.bf.add(hashNgram(ngram, 0, ngram.length, level), funcs);
        }
    }

    /**
     * Reads one serialized table of hash functions: an int count followed by that many pairs of longs.
     *
     * @param in source to read from
     * @return the hash function table
     * @throws IOException if reading fails or the stored count is negative
     */
    private static long[][] readHashFunctions(ObjectInput in) throws IOException {
        int count = in.readInt();
        if (count < 0) {
            throw new IOException("Corrupt bloom filter LM: negative hash function count " + count);
        }
        long[][] funcs = new long[count][2];
        for (long[] pair : funcs) {
            pair[0] = in.readLong();
            pair[1] = in.readLong();
        }
        return funcs;
    }

    /**
     * Restores the model state in the order written by {@link #writeExternal(ObjectOutput)}:
     * vocabulary words (registered through {@code Vocabulary.id}), {@code numTokens}, the count hash
     * functions, the types hash functions, the quantization base and finally the bloom filter.
     */
    @Override // java.io.Externalizable
    public void readExternal(ObjectInput objectInput) throws IOException, ClassNotFoundException {
        int vocabSize = objectInput.readInt();
        for (int k = 0; k < vocabSize; k++) {
            Vocabulary.id(objectInput.readUTF());
        }
        this.numTokens = objectInput.readDouble();
        this.countFuncs = readHashFunctions(objectInput);
        this.typesFuncs = readHashFunctions(objectInput);
        this.quantizationBase = objectInput.readDouble();
        this.bf = new BloomFilter();
        this.bf.readExternal(objectInput);
    }

    /**
     * Writes the vocabulary, {@code numTokens}, both hash function tables, the quantization base and
     * the bloom filter, in that order.
     */
    @Override // java.io.Externalizable
    public void writeExternal(ObjectOutput objectOutput) throws IOException {
        objectOutput.writeInt(Vocabulary.size());
        for (int id = 0; id < Vocabulary.size(); id++) {
            objectOutput.writeUTF(Vocabulary.word(id));
        }
        objectOutput.writeDouble(this.numTokens);
        objectOutput.writeInt(this.countFuncs.length);
        for (long[] pair : this.countFuncs) {
            objectOutput.writeLong(pair[0]);
            objectOutput.writeLong(pair[1]);
        }
        objectOutput.writeInt(this.typesFuncs.length);
        for (long[] pair : this.typesFuncs) {
            objectOutput.writeLong(pair[0]);
            objectOutput.writeLong(pair[1]);
        }
        objectOutput.writeDouble(this.quantizationBase);
        this.bf.writeExternal(objectOutput);
    }

    /**
     * Scores an n-gram after passing every id through {@code Vocabulary.word} and back through
     * {@code Vocabulary.id}.
     *
     * @param iArr word ids of the n-gram
     * @param i the n-gram order to use
     * @return the log-probability of the last word
     */
    @Override // org.apache.joshua.decoder.ff.lm.DefaultNGramLanguageModel
    protected float ngramLogProbability_helper(int[] iArr, int i) {
        int[] mapped = new int[iArr.length];
        for (int k = 0; k < iArr.length; k++) {
            mapped[k] = Vocabulary.id(Vocabulary.word(iArr[k]));
        }
        return wittenBell(mapped, i);
    }

    /**
     * Reports whether a word is unknown to this model.
     *
     * @param i word id
     * @return {@code true} if the filter holds no unigram count for the word
     */
    @Override // org.apache.joshua.decoder.ff.lm.NGramLanguageModel
    public boolean isOov(int i) {
        int[] unigram = {i};
        return getCount(unigram, 0, unigram.length, this.maxQ) == 0;
    }
}
