package edu.berkeley.nlp.lm.cache;

import edu.berkeley.nlp.lm.AbstractContextEncodedNgramLanguageModel;
import edu.berkeley.nlp.lm.ContextEncodedNgramLanguageModel;
import edu.berkeley.nlp.lm.WordIndexer;
import edu.berkeley.nlp.lm.bits.BitUtils;
import edu.berkeley.nlp.lm.util.Annotations;
import edu.berkeley.nlp.lm.util.MurmurHash;

/* loaded from: input_file:berkeleylm-1.1.2.jar:edu/berkeley/nlp/lm/cache/ContextEncodedCachingLmWrapper.class */
/**
 * Decorator around a {@link ContextEncodedNgramLanguageModel} that memoizes
 * {@link #getLogProb(long, int, int, ContextEncodedNgramLanguageModel.LmContextInfo)} results in a
 * {@link ContextEncodedLmCache}. Every other query is forwarded unchanged to the wrapped model.
 *
 * <p>Instances are created through the static factories. The "ThreadSafe" factories pass {@code true}
 * and the "NotThreadSafe" factories pass {@code false} to the {@code ContextEncodedDirectMappedLmCache}
 * constructor. The actual concurrency guarantees are those of that cache implementation
 * (NOTE(review): not visible here — confirm).
 *
 * @param <T> word type of the wrapped model's {@link WordIndexer}
 */
public class ContextEncodedCachingLmWrapper<T> extends AbstractContextEncodedNgramLanguageModel<T> {

    private static final long serialVersionUID = 1L;

    /** Value forwarded to the cache constructor by {@link #wrapWithCacheNotThreadSafe(ContextEncodedNgramLanguageModel)}. */
    private static final int DEFAULT_NOT_THREAD_SAFE_CACHE_PARAM = 18;

    /** Value forwarded to the cache constructor by {@link #wrapWithCacheThreadSafe(ContextEncodedNgramLanguageModel)}. */
    private static final int DEFAULT_THREAD_SAFE_CACHE_PARAM = 16;

    // Field names are deliberately unchanged: this type is Serializable, so they are part of its stream form.
    private final ContextEncodedLmCache contextCache;

    private final ContextEncodedNgramLanguageModel<T> lm;

    /** Snapshot of {@code contextCache.capacity()}; the cache slot is the hash reduced modulo this value. */
    private final int capacity;

    /**
     * Wraps {@code lm} with a cache built with {@code threadSafe == false} and the default cache parameter.
     *
     * @param lm model whose log-probabilities should be cached
     * @return a caching wrapper around {@code lm}
     */
    public static <T> ContextEncodedCachingLmWrapper<T> wrapWithCacheNotThreadSafe(ContextEncodedNgramLanguageModel<T> lm) {
        return wrapWithCacheNotThreadSafe(lm, DEFAULT_NOT_THREAD_SAFE_CACHE_PARAM);
    }

    /**
     * Wraps {@code lm} with a cache built with {@code threadSafe == false}.
     *
     * @param lm model whose log-probabilities should be cached
     * @param cacheParam forwarded verbatim to {@code ContextEncodedDirectMappedLmCache}; presumably
     *     controls the cache size (NOTE(review): confirm units against that class)
     * @return a caching wrapper around {@code lm}
     */
    public static <T> ContextEncodedCachingLmWrapper<T> wrapWithCacheNotThreadSafe(ContextEncodedNgramLanguageModel<T> lm, int cacheParam) {
        final boolean threadSafe = false;
        return new ContextEncodedCachingLmWrapper<>(lm, threadSafe, cacheParam);
    }

    /**
     * Wraps {@code lm} with a cache built with {@code threadSafe == true} and the default cache parameter.
     *
     * @param lm model whose log-probabilities should be cached
     * @return a caching wrapper around {@code lm}
     */
    public static <T> ContextEncodedCachingLmWrapper<T> wrapWithCacheThreadSafe(ContextEncodedNgramLanguageModel<T> lm) {
        return wrapWithCacheThreadSafe(lm, DEFAULT_THREAD_SAFE_CACHE_PARAM);
    }

    /**
     * Wraps {@code lm} with a cache built with {@code threadSafe == true}.
     *
     * @param lm model whose log-probabilities should be cached
     * @param cacheParam forwarded verbatim to {@code ContextEncodedDirectMappedLmCache}; presumably
     *     controls the cache size (NOTE(review): confirm units against that class)
     * @return a caching wrapper around {@code lm}
     */
    public static <T> ContextEncodedCachingLmWrapper<T> wrapWithCacheThreadSafe(ContextEncodedNgramLanguageModel<T> lm, int cacheParam) {
        final boolean threadSafe = true;
        return new ContextEncodedCachingLmWrapper<>(lm, threadSafe, cacheParam);
    }

    private ContextEncodedCachingLmWrapper(ContextEncodedNgramLanguageModel<T> lm, boolean threadSafe, int cacheParam) {
        this(lm, new ContextEncodedDirectMappedLmCache(cacheParam, threadSafe));
    }

    private ContextEncodedCachingLmWrapper(ContextEncodedNgramLanguageModel<T> lm, ContextEncodedLmCache cache) {
        // Order and word indexer mirror the wrapped model; the third super argument is NaN
        // (NOTE(review): presumably the OOV log-prob — confirm in AbstractNgramLanguageModel).
        super(lm.getLmOrder(), lm.getWordIndexer(), Float.NaN);
        this.lm = lm;
        this.contextCache = cache;
        this.capacity = cache.capacity();
    }

    /** Returns the wrapped model's word indexer. */
    @Override
    public WordIndexer<T> getWordIndexer() {
        return lm.getWordIndexer();
    }

    /** Forwards directly to the wrapped model; this lookup is not cached. */
    @Override
    public ContextEncodedNgramLanguageModel.LmContextInfo getOffsetForNgram(int[] ngram, int startPos, int endPos) {
        return lm.getOffsetForNgram(ngram, startPos, endPos);
    }

    /** Forwards directly to the wrapped model; this lookup is not cached. */
    @Override
    public int[] getNgramForOffset(long contextOffset, int contextOrder, int word) {
        return lm.getNgramForOffset(contextOffset, contextOrder, word);
    }

    /**
     * Returns the wrapped model's log-probability, consulting the cache first.
     *
     * <p>A negative {@code contextOrder} bypasses the cache. Otherwise the cache is probed at a
     * hash-derived slot, and a NaN result is treated as a miss. On a miss the wrapped model is
     * queried, and its answer (together with {@code outputContext}) is stored in that slot.
     *
     * @param contextOffset offset of the context, forwarded to the wrapped model and the cache
     * @param contextOrder order of the context; negative values skip caching
     * @param word word id being scored
     * @param outputContext output parameter filled by the wrapped model or by the cache
     * @return the log-probability
     */
    @Override
    public float getLogProb(long contextOffset, int contextOrder, int word,
            @Annotations.OutputParameter ContextEncodedNgramLanguageModel.LmContextInfo outputContext) {
        if (contextOrder < 0) {
            return lm.getLogProb(contextOffset, contextOrder, word, outputContext);
        }
        final int slot = slotFor(contextOffset, contextOrder, word);
        final float hit = contextCache.getCached(contextOffset, contextOrder, word, slot, outputContext);
        if (Float.isNaN(hit)) {
            // Cache miss: compute from the underlying model, then remember the answer in this slot.
            final float computed = lm.getLogProb(contextOffset, contextOrder, word, outputContext);
            contextCache.putCached(contextOffset, contextOrder, word, computed, slot, outputContext);
            return computed;
        }
        return hit;
    }

    /** Reduces a Murmur hash of the (offset, order, word) key into the cache's index range. */
    private int slotFor(long contextOffset, int contextOrder, int word) {
        final int mixed = (int) MurmurHash.hashThreeLongs(contextOffset, contextOrder, word);
        return BitUtils.abs(mixed) % capacity;
    }
}
