/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.ml.embeddings.client;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.util.concurrent.UncheckedExecutionException;
import de.julielab.ml.embeddings.client.util.WordEmbeddingClientException;
import de.julielab.ml.embeddings.spi.EmbeddingVectors;
import de.julielab.ml.embeddings.spi.WordEmbedding;
import de.julielab.ml.embeddings.util.WordEmbeddingAccessException;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.UnsupportedEncodingException;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLEncoder;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cpu.nativecpu.NDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class WordEmbeddingClient
implements WordEmbedding {
    private static final long serialVersionUID = 1976342456409693833L;
    private static final String UTF_8 = "UTF-8";
    private static final Logger log = LoggerFactory.getLogger(WordEmbeddingClient.class);
    public static final String GET_EMBEDDING = "get_embedding";
    public static final String GET_EMBEDDINGS = "get_embeddings";
    public static final String GET_EMBEDDINGS_MEAN = "get_embeddings_mean";
    public static final String GET_VOCAB_SIZE = "get_vocabulary_size";
    public static final String GET_WORD = "get_word";
    public static final String GET_HAS_WORD = "has_word";
    public static final String GET_EMBEDDING_DIMS = "get_embedding_dimensions";
    public static final String PARAM_WORD = "word";
    public static final String PARAM_INDEX = "index";
    private transient LoadingCache<String, double[]> vectorCache;
    private transient LoadingCache<WordListKey, EmbeddingVectors> vectorsCache;
    private transient LoadingCache<WordListKey, EmbeddingVectors> vectorsMeanCache;
    private static final double[] EMPTY = new double[0];
    private String host;
    private int port;

    public WordEmbeddingClient(String host, int port) {
        this.host = host;
        this.port = port;
        this.setupCaches();
    }

    public void setupCaches() {
        this.vectorCache = CacheBuilder.newBuilder().maximumSize(10000L).expireAfterAccess(5L, TimeUnit.MINUTES).build((CacheLoader)new CacheLoader<String, double[]>(){

            public double[] load(String key) throws Exception {
                return WordEmbeddingClient.this.loadWordVector(key);
            }
        });
        this.vectorsCache = CacheBuilder.newBuilder().maximumSize(10000L).expireAfterAccess(5L, TimeUnit.MINUTES).build((CacheLoader)new CacheLoader<WordListKey, EmbeddingVectors>(){

            public EmbeddingVectors load(WordListKey key) throws Exception {
                return WordEmbeddingClient.this.loadWordVectors(key.getQueryWords() instanceof List ? (List)key.getQueryWords() : new ArrayList<String>(key.getQueryWords()));
            }
        });
        this.vectorsMeanCache = CacheBuilder.newBuilder().maximumSize(10000L).expireAfterAccess(5L, TimeUnit.MINUTES).build((CacheLoader)new CacheLoader<WordListKey, EmbeddingVectors>(){

            public EmbeddingVectors load(WordListKey key) throws Exception {
                return WordEmbeddingClient.this.loadWordVectorsMean(key.getQueryWords() instanceof List ? (List)key.getQueryWords() : new ArrayList<String>(key.getQueryWords()));
            }
        });
    }

    private URL getUrl(String contextPath, String query) throws WordEmbeddingClientException {
        try {
            String spec = "http://" + this.host + ":" + this.port + "/" + contextPath;
            if (query != null) {
                spec = spec + "?" + query;
            }
            return new URL(spec);
        }
        catch (MalformedURLException e) {
            throw new WordEmbeddingClientException(e);
        }
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private double[] loadWordVector(String word) {
        try {
            URL url = this.getUrl(GET_EMBEDDING, "word=" + URLEncoder.encode(word, UTF_8));
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", (Object)url.toString());
            }
            try (InputStream is = url.openStream();){
                int numread;
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                byte[] buffer = new byte[128];
                int allread = 0;
                while ((numread = is.read(buffer)) != -1) {
                    baos.write(buffer, 0, numread);
                    allread += numread;
                }
                log.trace("Read {} bytes", (Object)allread);
                ByteBuffer bb = ByteBuffer.wrap(baos.toByteArray());
                if (bb.capacity() == 1) {
                    double[] dArray = EMPTY;
                    return dArray;
                }
                double[] dArray = this.readEmbeddingVector(bb);
                return dArray;
            }
        }
        catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException((Throwable)e);
        }
    }

    public double[] getWordVector(String word) throws WordEmbeddingAccessException {
        try {
            return (double[])this.vectorCache.get((Object)word);
        }
        catch (ExecutionException e) {
            throw new WordEmbeddingAccessException((Throwable)e);
        }
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private EmbeddingVectors loadWordVectors(List<String> words) {
        try {
            String queryString = words.stream().map(w -> {
                try {
                    return "word=" + URLEncoder.encode(w, UTF_8);
                }
                catch (UnsupportedEncodingException e) {
                    log.error("{}", (Throwable)e);
                    return null;
                }
            }).collect(Collectors.joining("&"));
            URL url = this.getUrl(GET_EMBEDDINGS, queryString);
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", (Object)url.toString());
            }
            try (InputStream is = url.openStream();){
                int numread;
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                byte[] buffer = new byte[128];
                int allread = 0;
                while ((numread = is.read(buffer)) != -1) {
                    baos.write(buffer, 0, numread);
                    allread += numread;
                }
                log.trace("Read {} bytes", (Object)allread);
                ByteBuffer bb = ByteBuffer.wrap(baos.toByteArray());
                int embeddingDimensions = bb.getInt();
                log.trace("Got embedding vector dimensions {}", (Object)embeddingDimensions);
                List<Integer> foundIndices = this.readFoundIndices(words, bb);
                log.trace("Got {} found word indices: {}", (Object)foundIndices.size(), foundIndices);
                int numFound = foundIndices.size();
                double[][] vectors = new double[numFound][];
                for (int i = 0; i < numFound; ++i) {
                    double[] vector = this.readEmbeddingVector(bb);
                    vectors[i] = vector;
                }
                NDArray ndVectors = null;
                if (numFound > 0) {
                    ndVectors = new NDArray((double[][])vectors);
                }
                EmbeddingVectors embeddingVectors = new EmbeddingVectors((INDArray)ndVectors, words, foundIndices, embeddingDimensions, EmbeddingVectors.StreamType.CONCATENATION);
                return embeddingVectors;
            }
        }
        catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException((Throwable)e);
        }
    }

    public EmbeddingVectors getWordVectors(List<String> words) {
        try {
            return (EmbeddingVectors)this.vectorsCache.get((Object)new WordListKey(words));
        }
        catch (ExecutionException e) {
            throw new WordEmbeddingAccessException((Throwable)e);
        }
    }

    public List<Integer> readFoundIndices(List<String> words, ByteBuffer bb) {
        return IntStream.range(0, words.size()).filter(i -> bb.get() == 1).mapToObj(Integer::new).collect(Collectors.toList());
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private EmbeddingVectors loadWordVectorsMean(List<String> words) {
        try {
            String queryString = words.stream().map(w -> {
                try {
                    return "word=" + URLEncoder.encode(w, UTF_8);
                }
                catch (UnsupportedEncodingException e) {
                    log.error("{}", (Throwable)e);
                    return null;
                }
            }).collect(Collectors.joining("&"));
            URL url = this.getUrl(GET_EMBEDDINGS_MEAN, queryString);
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", (Object)url.toString());
            }
            try (InputStream is = url.openStream();){
                int numread;
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                byte[] buffer = new byte[128];
                int allread = 0;
                while ((numread = is.read(buffer)) != -1) {
                    baos.write(buffer, 0, numread);
                    allread += numread;
                }
                log.trace("Read {} bytes", (Object)allread);
                ByteBuffer bb = ByteBuffer.wrap(baos.toByteArray());
                int embeddingDimensions = bb.getInt();
                List<Integer> foundIndices = this.readFoundIndices(words, bb);
                INDArray meanVector = null;
                if (!foundIndices.isEmpty()) {
                    meanVector = Nd4j.create((double[])this.readEmbeddingVector(bb));
                }
                EmbeddingVectors embeddingVectors = new EmbeddingVectors(meanVector, words, foundIndices, embeddingDimensions, EmbeddingVectors.StreamType.AGGREGATION);
                return embeddingVectors;
            }
        }
        catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException((Throwable)e);
        }
    }

    public EmbeddingVectors getWordVectorsMean(List<String> words) {
        try {
            return (EmbeddingVectors)this.vectorsMeanCache.get((Object)new WordListKey(words));
        }
        catch (ExecutionException e) {
            throw new WordEmbeddingAccessException((Throwable)e);
        }
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    public int getVocabularySize() {
        try {
            URL url = this.getUrl(GET_VOCAB_SIZE, null);
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", (Object)url.toString());
            }
            try (InputStream is = url.openStream();){
                int numread;
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                byte[] buffer = new byte[4];
                int allread = 0;
                while ((numread = is.read(buffer)) != -1) {
                    baos.write(buffer, 0, numread);
                    allread += numread;
                }
                log.trace("Read {} bytes", (Object)allread);
                int n = Integer.parseInt(new String(baos.toByteArray(), Charset.forName(UTF_8)));
                return n;
            }
        }
        catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException((Throwable)e);
        }
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    public String getWord(int index) {
        try {
            URL url = this.getUrl(GET_WORD, "index=" + index);
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", (Object)url.toString());
            }
            try (InputStream is = url.openStream();){
                int numread;
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                byte[] buffer = new byte[128];
                int allread = 0;
                while ((numread = is.read(buffer)) != -1) {
                    baos.write(buffer, 0, numread);
                    allread += numread;
                }
                log.trace("Read {} bytes", (Object)allread);
                String string = new String(baos.toByteArray(), Charset.forName(UTF_8));
                return string;
            }
        }
        catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException((Throwable)e);
        }
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    public boolean hasWord(String word) {
        try {
            URL url = this.getUrl(GET_HAS_WORD, "word=" + URLEncoder.encode(word, UTF_8));
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", (Object)url.toString());
            }
            try (InputStream is = url.openStream();){
                int numread;
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                byte[] buffer = new byte[1];
                int allread = 0;
                while ((numread = is.read(buffer)) != -1) {
                    baos.write(buffer, 0, numread);
                    allread += numread;
                }
                log.trace("Read {} bytes", (Object)allread);
                boolean bl = Boolean.parseBoolean(new String(baos.toByteArray(), Charset.forName(UTF_8)));
                return bl;
            }
        }
        catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException((Throwable)e);
        }
    }

    public double[] readEmbeddingVector(ByteBuffer bb) {
        int numDoubles = bb.getInt() / 8;
        double[] v = new double[numDoubles];
        for (int i = 0; i < numDoubles; ++i) {
            v[i] = bb.getDouble();
        }
        return v;
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    public int getEmbeddingDimensions() {
        try {
            URL url = this.getUrl(GET_EMBEDDING_DIMS, null);
            if (log.isTraceEnabled()) {
                log.trace("Sending query {}", (Object)url.toString());
            }
            try (InputStream is = url.openStream();){
                int numread;
                ByteArrayOutputStream baos = new ByteArrayOutputStream();
                byte[] buffer = new byte[4];
                int allread = 0;
                while ((numread = is.read(buffer)) != -1) {
                    baos.write(buffer, 0, numread);
                    allread += numread;
                }
                log.trace("Read {} bytes", (Object)allread);
                int n = Integer.parseInt(new String(baos.toByteArray(), Charset.forName(UTF_8)));
                return n;
            }
        }
        catch (WordEmbeddingClientException | IOException e) {
            throw new WordEmbeddingAccessException((Throwable)e);
        }
    }

    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.setupCaches();
    }

    private class WordListKey {
        private String words;
        private Collection<String> queryWords;

        public WordListKey(Collection<String> queryWords) {
            this.queryWords = queryWords;
            this.words = queryWords.stream().collect(Collectors.joining());
        }

        public int hashCode() {
            int prime = 31;
            int result = 1;
            result = 31 * result + this.getOuterType().hashCode();
            result = 31 * result + (this.words == null ? 0 : this.words.hashCode());
            return result;
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null) {
                return false;
            }
            if (this.getClass() != obj.getClass()) {
                return false;
            }
            WordListKey other = (WordListKey)obj;
            if (!this.getOuterType().equals(other.getOuterType())) {
                return false;
            }
            return !(this.words == null ? other.words != null : !this.words.equals(other.words));
        }

        private WordEmbeddingClient getOuterType() {
            return WordEmbeddingClient.this;
        }

        public Collection<String> getQueryWords() {
            return this.queryWords;
        }
    }
}

