package com.linkedin.dagli.fasttext.anonymized;

import com.linkedin.dagli.embedding.classification.FastTextInternal;
import com.linkedin.dagli.fasttext.anonymized.io.BufferedLineReader;
import com.linkedin.dagli.fasttext.anonymized.io.LineReader;
import com.linkedin.dagli.tuple.Tuple2;
import com.linkedin.dagli.util.array.ArraysEx;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap;
import it.unimi.dsi.fastutil.longs.Long2LongMap;
import it.unimi.dsi.fastutil.longs.Long2LongOpenHashMap;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.longs.LongArrays;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.stream.IntStream;

/* loaded from: input_file:com/linkedin/dagli/fasttext/anonymized/Dictionary.class */
/**
 * Vocabulary of words and labels, keyed by token hash, built from a tokenized text file.
 *
 * <p>{@link #readFromFile(String)} counts every token in the input, treats tokens that start with
 * {@code args.label} as labels and everything else as words, prunes rare entries, and assigns each
 * surviving hash a dense integer ID. All lookups ({@link #getWordID(String)},
 * {@link #getLabelID(String)}, {@link #getLine}) go through
 * {@code FastTextInternal.Util.hash}, so two tokens with the same hash are indistinguishable here.
 *
 * <p>The size/lookup accessors dereference maps that are only created by
 * {@link #readFromFile(String)}; call that method first.
 */
public class Dictionary {
    private static final int MAX_VOCAB_SIZE = 30000000;
    private static final int INITIAL_VOCAB_SIZE = 10000;
    private static final int INITIAL_LABEL_SIZE = 10000;
    private static final int MAX_LINE_SIZE = 1024;
    /** Value returned by the ID maps for a hash that has no assigned ID. */
    private static final int WORDID_DEFAULT = -1;
    private static final String BOW = "<";
    private static final String EOW = ">";
    /** Hash used for the implicit pseudo-word counted (and looked up) once at the end of every line. */
    private static final long END_OF_LINE_HASH = 1337L;
    /** Once the word-count table grows past this many entries (75% of the maximum), it is pruned. */
    private static final int PRUNE_TRIGGER_SIZE = (int) (0.75 * MAX_VOCAB_SIZE);
    /** Granularity of the progress messages ("Read %dM tokens"). */
    private static final long TOKENS_PER_PROGRESS_UNIT = 1000000L;

    private long[] _labelCounts;
    private String[] _labels;
    private Args args_;
    private long _totalWordsRead = 0;
    private long _totalLabelsRead = 0;
    private String _charsetName = "UTF-8";
    private Class<? extends LineReader> _lineReaderClass = BufferedLineReader.class;
    private Long2IntOpenHashMap _wordIDMap = null;
    private Long2IntOpenHashMap _labelIDMap = null;

    /**
     * Creates an empty dictionary.
     *
     * @param args configuration; this class reads its {@code label}, {@code verbose},
     *             {@code minCount} and {@code minCountLabel} fields
     */
    public Dictionary(Args args) {
        this.args_ = args;
    }

    /**
     * @return the live (not copied) map from word hash to word ID; {@code null} until
     *         {@link #readFromFile(String)} has run
     */
    public Long2IntOpenHashMap getWordIDMap() {
        return _wordIDMap;
    }

    /**
     * @return the number of word tokens (including one end-of-line pseudo-word per line) plus label
     *         tokens counted so far by {@link #readFromFile(String)}
     */
    public long getTotalTokensRead() {
        return _totalLabelsRead + _totalWordsRead;
    }

    /** @return the number of distinct word hashes that have an ID. */
    public int distinctWordCount() {
        return _wordIDMap.size();
    }

    /** @return the number of distinct label hashes that have an ID. */
    public int distinctLabelCount() {
        return _labelIDMap.size();
    }

    /**
     * Builds the word and label vocabularies from a file.
     *
     * <p>The file is read through a reflectively constructed {@link LineReader} (see
     * {@link #setLineReaderClass(Class)} and {@link #setCharsetName(String)}). After counting, words
     * seen fewer than {@code args.minCount} times and labels seen fewer than
     * {@code args.minCountLabel} times are dropped, and the remaining hashes receive IDs
     * {@code 0..n-1} in the order produced by {@link #getParallelArraysReverseSortedByCount}.
     *
     * <p>The reader is closed exactly once, as soon as reading finishes or fails, before any of the
     * post-processing runs.
     *
     * <p>NOTE: the running token totals are never reset, so calling this method more than once on
     * the same instance accumulates them while replacing the ID maps.
     *
     * @param str path/name of the input file, passed as the first constructor argument of the line
     *            reader
     * @throws IOException           declared for failures while reading or closing the input
     * @throws RuntimeException      if the line reader cannot be constructed reflectively
     * @throws IllegalStateException if no word survives the {@code minCount} pruning
     */
    public void readFromFile(String str) throws IOException {
        final Long2LongOpenHashMap wordCounts = new Long2LongOpenHashMap(INITIAL_VOCAB_SIZE);
        final Long2LongOpenHashMap labelCounts = new Long2LongOpenHashMap(INITIAL_LABEL_SIZE);
        final Long2ObjectOpenHashMap<String> labelTextByHash = new Long2ObjectOpenHashMap<>(INITIAL_LABEL_SIZE);

        final LineReader lineReader = openLineReader(str);
        try {
            countTokens(lineReader, wordCounts, labelCounts, labelTextByHash);
        } finally {
            // Single close point: nothing below touches the reader, so it must not be closed again.
            lineReader.close();
        }

        threshold(wordCounts, args_.minCount);
        Tuple2<long[], long[]> rankedWords = getParallelArraysReverseSortedByCount(wordCounts);
        _wordIDMap = buildIDMap(rankedWords.get0());

        threshold(labelCounts, args_.minCountLabel);
        Tuple2<long[], long[]> rankedLabels = getParallelArraysReverseSortedByCount(labelCounts);
        _labelCounts = rankedLabels.get1();
        _labelIDMap = buildIDMap(rankedLabels.get0());

        // Recover the label text for each label ID from the hash -> text table filled while reading.
        _labels = new String[_labelIDMap.size()];
        _labelIDMap.long2IntEntrySet().fastForEach(entry -> {
            _labels[entry.getIntValue()] = labelTextByHash.get(entry.getLongKey());
        });

        if (args_.verbose > 0) {
            System.out.printf("\rRead %dM tokens\n", getTotalTokensRead() / TOKENS_PER_PROGRESS_UNIT);
            System.out.println("Number of words:  " + distinctWordCount());
            System.out.println("Number of labels: " + distinctLabelCount());
        }
        if (_wordIDMap.isEmpty()) {
            throw new IllegalStateException("Empty vocabulary. Try a smaller -minCount value.");
        }
    }

    /** Instantiates the configured line reader class via its {@code (String, String)} constructor. */
    private LineReader openLineReader(String fileName) {
        try {
            return _lineReaderClass.getConstructor(String.class, String.class).newInstance(fileName, _charsetName);
        } catch (IllegalAccessException | InstantiationException | NoSuchMethodException | InvocationTargetException e) {
            throw new RuntimeException("FastText ecountered an error while trying to create its line reader: " + _lineReaderClass, e);
        }
    }

    /** Reads every line from the reader, tallying word and label hashes and updating the token totals. */
    private void countTokens(LineReader lineReader, Long2LongOpenHashMap wordCounts,
            Long2LongOpenHashMap labelCounts, Long2ObjectOpenHashMap<String> labelTextByHash) throws IOException {
        long reportedProgress = 0;
        long pruneThreshold = 1;
        long lineCount = 0;

        String[] tokens;
        while ((tokens = lineReader.readLineTokens()) != null) {
            lineCount++;
            for (String token : tokens) {
                if (Utils.isEmpty(token)) {
                    continue;
                }
                long hash = FastTextInternal.Util.hash(token);
                if (token.startsWith(args_.label)) {
                    labelCounts.addTo(hash, 1L);
                    _totalLabelsRead++;
                    labelTextByHash.put(hash, token);
                } else {
                    wordCounts.addTo(hash, 1L);
                    _totalWordsRead++;
                }
            }
            // Every line contributes one implicit end-of-line pseudo-word after its real tokens.
            wordCounts.addTo(END_OF_LINE_HASH, 1L);
            _totalWordsRead++;

            long progress = getTotalTokensRead() / TOKENS_PER_PROGRESS_UNIT;
            if (args_.verbose > 1 && progress > reportedProgress) {
                reportedProgress = progress;
                System.err.printf("Read %dM tokens\n", reportedProgress);
            }
            // Keep memory bounded: each time the table gets too large, raise the bar and drop
            // everything that has been seen fewer times than the new bar.
            if (wordCounts.size() > PRUNE_TRIGGER_SIZE) {
                pruneThreshold++;
                threshold(wordCounts, pruneThreshold);
            }
        }
        if (args_.verbose > 1) {
            System.err.println("FastText found " + lineCount + " lines in its input file");
        }
    }

    /** Maps {@code hashesByRank[i]} to ID {@code i}; unknown hashes map to {@link #WORDID_DEFAULT}. */
    private static Long2IntOpenHashMap buildIDMap(long[] hashesByRank) {
        Long2IntOpenHashMap idMap =
                new Long2IntOpenHashMap(hashesByRank, IntStream.range(0, hashesByRank.length).toArray());
        idMap.defaultReturnValue(WORDID_DEFAULT);
        return idMap;
    }

    /**
     * Flattens a hash -> count map into two index-aligned arrays and orders them.
     *
     * <p>The arrays are handed to {@code ArraysEx.sort(counts, hashes)} and then both reversed;
     * presumably that sort orders by count ascending while permuting the hashes in parallel, making
     * the result descending by count (as the method name states) - confirm against {@code ArraysEx}.
     *
     * @return a tuple of (hashes, counts), index-aligned
     */
    private static Tuple2<long[], long[]> getParallelArraysReverseSortedByCount(Long2LongOpenHashMap long2LongOpenHashMap) {
        final int size = long2LongOpenHashMap.size();
        long[] hashes = new long[size];
        long[] counts = new long[size];
        int next = 0;
        for (Long2LongMap.Entry entry : long2LongOpenHashMap.long2LongEntrySet()) {
            hashes[next] = entry.getLongKey();
            counts[next] = entry.getLongValue();
            next++;
        }
        ArraysEx.sort(counts, hashes);
        return Tuple2.of(LongArrays.reverse(hashes), LongArrays.reverse(counts));
    }

    /** Removes, in place, every entry whose count is strictly less than {@code j}. */
    private static void threshold(Long2LongOpenHashMap long2LongOpenHashMap, long j) {
        long2LongOpenHashMap.long2LongEntrySet().removeIf(entry -> entry.getLongValue() < j);
    }

    /**
     * @return the per-label counts produced by the last {@link #readFromFile(String)} (the live
     *         array, not a copy); {@code null} before then
     */
    public long[] getLabelCounts() {
        return _labelCounts;
    }

    /**
     * Computes embedding row indices for the word n-grams of a token-hash sequence.
     *
     * <p>For every start position, the hash of the first token is successively combined with each
     * following token hash (up to {@code i} tokens in total), and each combined hash is converted to
     * a row index by {@code FastText.embeddingRowIndexForNgramHash}. Unigrams are not emitted.
     *
     * @param longArrayList token hashes in sequence order
     * @param i             maximum number of tokens per n-gram; values {@code <= 1} yield an empty list
     * @param i2            forwarded unchanged as the second argument of
     *                      {@code FastText.embeddingRowIndexForNgramHash} (meaning not visible here)
     * @param i3            forwarded unchanged as the third argument of
     *                      {@code FastText.embeddingRowIndexForNgramHash} (meaning not visible here)
     * @return a new list of row indices, ordered by start position and then by n-gram length
     */
    public IntArrayList getNgramRowIDs(LongArrayList longArrayList, int i, int i2, int i3) {
        if (i <= 1) {
            return new IntArrayList(0);
        }
        final int tokenCount = longArrayList.size();
        IntArrayList rowIDs = new IntArrayList((i - 1) * tokenCount);
        for (int start = 0; start < tokenCount; start++) {
            long ngramHash = longArrayList.getLong(start);
            final int endExclusive = Math.min(tokenCount, start + i);
            for (int next = start + 1; next < endExclusive; next++) {
                ngramHash = FastTextInternal.Util.hash(ngramHash, longArrayList.getLong(next));
                rowIDs.add(FastText.embeddingRowIndexForNgramHash(ngramHash, i2, i3));
            }
        }
        return rowIDs;
    }

    /**
     * Registers a word under the next free ID (the current map size) unless its hash is already
     * present.
     *
     * @param str the word
     * @return whatever {@code Long2IntOpenHashMap.putIfAbsent} returns: the existing ID when the hash
     *         was already mapped, otherwise the map's default return value (-1 for maps built by
     *         {@link #readFromFile(String)}) - i.e. NOT the newly assigned ID
     */
    public int addWord(String str) {
        return _wordIDMap.putIfAbsent(FastTextInternal.Util.hash(str), _wordIDMap.size());
    }

    /**
     * Converts one tokenized line into word hashes, word IDs and label IDs.
     *
     * <p>All three output lists are cleared first. Empty tokens are skipped; tokens starting with
     * {@code args.label} are looked up as labels, all others as words; tokens without an ID are
     * silently dropped. After the real tokens, the end-of-line pseudo-word is looked up as a word.
     * A {@code null} token array leaves all outputs empty.
     *
     * @param strArr         the line's tokens, or {@code null}
     * @param longArrayList  output: hashes of the words that have an ID, in order
     * @param intArrayList   output: IDs of those words, index-aligned with {@code longArrayList}
     * @param intArrayList2  output: IDs of the labels that have an ID, in order
     * @return the number of word IDs plus label IDs written
     */
    public int getLine(String[] strArr, LongArrayList longArrayList, IntArrayList intArrayList, IntArrayList intArrayList2) {
        intArrayList.clear();
        longArrayList.clear();
        intArrayList2.clear();
        if (strArr != null) {
            for (String token : strArr) {
                if (Utils.isEmpty(token)) {
                    continue;
                }
                long hash = FastTextInternal.Util.hash(token);
                if (token.startsWith(args_.label)) {
                    int labelID = _labelIDMap.get(hash);
                    if (labelID >= 0) {
                        intArrayList2.add(labelID);
                    }
                } else {
                    appendWordIfKnown(hash, longArrayList, intArrayList);
                }
            }
            // Mirror readFromFile: each line ends with the implicit end-of-line pseudo-word.
            appendWordIfKnown(END_OF_LINE_HASH, longArrayList, intArrayList);
        }
        return intArrayList2.size() + intArrayList.size();
    }

    /** Appends the hash and its word ID to the outputs when the hash has a (non-negative) word ID. */
    private void appendWordIfKnown(long hash, LongArrayList wordHashes, IntArrayList wordIDs) {
        int wordID = _wordIDMap.get(hash);
        if (wordID >= 0) {
            wordHashes.add(hash);
            wordIDs.add(wordID);
        }
    }

    /**
     * @return the label strings indexed by label ID (the live array, not a copy); {@code null}
     *         until {@link #readFromFile(String)} has run
     */
    public String[] getLabels() {
        return _labels;
    }

    /**
     * @param str the word
     * @return the ID mapped to the word's hash, or the map's default return value (-1 for maps built
     *         by {@link #readFromFile(String)}) if there is none
     */
    public int getWordID(String str) {
        return _wordIDMap.get(FastTextInternal.Util.hash(str));
    }

    /**
     * @param str the label, including its prefix
     * @return the ID mapped to the label's hash, or -1 if there is none
     */
    public int getLabelID(String str) {
        return _labelIDMap.get(FastTextInternal.Util.hash(str));
    }

    /**
     * @param i a label ID
     * @return the label string stored for that ID
     */
    public String getLabel(int i) {
        return _labels[i];
    }

    /**
     * Describes the dictionary; safe to call before {@link #readFromFile(String)}, in which case the
     * word map prints as {@code null} and the label count as 0.
     */
    @Override
    public String toString() {
        int labelCount = (_labels == null) ? 0 : _labels.length;
        return "Dictionary [wordsIDMap_=" + _wordIDMap + ", nlabels_=" + labelCount + "]";
    }

    /** @return the configuration passed to the constructor. */
    public Args getArgs() {
        return args_;
    }

    /** @return the charset name passed to the line reader's constructor (default {@code "UTF-8"}). */
    public String getCharsetName() {
        return _charsetName;
    }

    /** @return the line reader class instantiated by {@link #readFromFile(String)}. */
    public Class<? extends LineReader> getLineReaderClass() {
        return _lineReaderClass;
    }

    /** @param str charset name to pass to the line reader's constructor. */
    public void setCharsetName(String str) {
        this._charsetName = str;
    }

    /**
     * @param cls line reader class to instantiate; {@link #readFromFile(String)} requires it to have
     *            a public {@code (String, String)} constructor
     */
    public void setLineReaderClass(Class<? extends LineReader> cls) {
        this._lineReaderClass = cls;
    }
}
