package de.julielab.ml.embeddings.client;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.UnsupportedEncodingException;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLEncoder;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cpu.nativecpu.NDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.util.concurrent.UncheckedExecutionException;

import de.julielab.ml.embeddings.client.util.WordEmbeddingClientException;
import de.julielab.ml.embeddings.spi.EmbeddingVectors;
import de.julielab.ml.embeddings.spi.EmbeddingVectors.StreamType;
import de.julielab.ml.embeddings.spi.WordEmbedding;
import de.julielab.ml.embeddings.util.WordEmbeddingAccessException;

/**
 * {@link WordEmbedding} implementation that retrieves embedding information
 * from a remote word embedding HTTP server.
 * <p>
 * Every request is an HTTP GET against {@code http://<host>:<port>/<endpoint>}.
 * Single-word, multi-word and mean-vector lookups are memoized in per-instance
 * caches (at most 10,000 entries each, evicted 5 minutes after their last
 * access). The caches are transient and are rebuilt after deserialization.
 * </p>
 */
public class WordEmbeddingClient implements WordEmbedding {

	private static final long serialVersionUID = 1976342456409693833L;

	private static final String UTF_8 = "UTF-8";

	private static final Logger log = LoggerFactory.getLogger(WordEmbeddingClient.class);

	public static final String GET_EMBEDDING = "get_embedding";
	public static final String GET_EMBEDDINGS = "get_embeddings";
	public static final String GET_EMBEDDINGS_MEAN = "get_embeddings_mean";
	public static final String GET_VOCAB_SIZE = "get_vocabulary_size";
	public static final String GET_WORD = "get_word";
	public static final String GET_HAS_WORD = "has_word";
	public static final String GET_EMBEDDING_DIMS = "get_embedding_dimensions";
	public static final String PARAM_WORD = "word";
	public static final String PARAM_INDEX = "index";

	/** Maximum number of entries held by each lookup cache. */
	private static final int CACHE_MAX_SIZE = 10000;
	/** Minutes after the last access at which a cache entry is evicted. */
	private static final int CACHE_EXPIRE_MINUTES = 5;
	/** Chunk size used when reading HTTP responses. */
	private static final int READ_BUFFER_SIZE = 8192;

	private transient LoadingCache<String, double[]> vectorCache;
	private transient LoadingCache<WordListKey, EmbeddingVectors> vectorsCache;
	private transient LoadingCache<WordListKey, EmbeddingVectors> vectorsMeanCache;

	/**
	 * Returned by {@link #getWordVector(String)} when the server answers with a
	 * single byte (presumably the "word unknown" signal - confirm with the
	 * server implementation).
	 */
	private static final double[] EMPTY = new double[0];

	private String host;

	private int port;

	/**
	 * Creates a client for the embedding server at the given location.
	 *
	 * @param host
	 *            host name or address of the embedding server
	 * @param port
	 *            TCP port of the embedding server
	 */
	public WordEmbeddingClient(String host, int port) {
		this.host = host;
		this.port = port;
		setupCaches();
	}

	/**
	 * (Re-)creates the transient lookup caches, discarding any cached entries.
	 * Called by the constructor and after deserialization.
	 */
	public void setupCaches() {
		vectorCache = newCacheBuilder().build(new CacheLoader<String, double[]>() {
			@Override
			public double[] load(String word) {
				return loadWordVector(word);
			}
		});
		vectorsCache = newCacheBuilder().build(new CacheLoader<WordListKey, EmbeddingVectors>() {
			@Override
			public EmbeddingVectors load(WordListKey key) {
				return loadWordVectors(key.getQueryWords());
			}
		});
		vectorsMeanCache = newCacheBuilder().build(new CacheLoader<WordListKey, EmbeddingVectors>() {
			@Override
			public EmbeddingVectors load(WordListKey key) {
				return loadWordVectorsMean(key.getQueryWords());
			}
		});
	}

	/** Creates a cache builder with the size and expiry limits shared by all caches. */
	private static CacheBuilder<Object, Object> newCacheBuilder() {
		return CacheBuilder.newBuilder().maximumSize(CACHE_MAX_SIZE).expireAfterAccess(CACHE_EXPIRE_MINUTES,
				TimeUnit.MINUTES);
	}

	/**
	 * Looks up {@code key} in {@code cache}, translating loader failures into
	 * {@link WordEmbeddingAccessException}.
	 * <p>
	 * Guava wraps unchecked loader exceptions (which is what the loaders here
	 * throw) in {@link UncheckedExecutionException}, not
	 * {@link ExecutionException}, so both have to be handled for callers to see
	 * the declared exception type.
	 * </p>
	 */
	private static <K, V> V getCached(LoadingCache<K, V> cache, K key) {
		try {
			return cache.get(key);
		} catch (ExecutionException e) {
			throw new WordEmbeddingAccessException(e);
		} catch (UncheckedExecutionException e) {
			if (e.getCause() instanceof WordEmbeddingAccessException)
				throw (WordEmbeddingAccessException) e.getCause();
			// e.g. a BufferUnderflowException caused by a truncated server response
			throw new WordEmbeddingAccessException(e);
		}
	}

	private URL getUrl(String contextPath, String query) throws WordEmbeddingClientException {
		try {
			String spec = "http://" + host + ":" + port + "/" + contextPath;
			if (query != null)
				spec += "?" + query;
			return new URL(spec);
		} catch (MalformedURLException e) {
			throw new WordEmbeddingClientException(e);
		}
	}

	/**
	 * Sends a GET request to the given endpoint and returns the complete
	 * response body.
	 *
	 * @param contextPath
	 *            endpoint name, e.g. {@link #GET_EMBEDDING}
	 * @param query
	 *            already URL-encoded query string without the leading
	 *            {@code ?}, or {@code null} for none
	 * @return the raw response bytes
	 * @throws WordEmbeddingAccessException
	 *             if the URL is malformed or the request fails
	 */
	private byte[] fetch(String contextPath, String query) {
		try {
			URL url = getUrl(contextPath, query);
			if (log.isTraceEnabled())
				log.trace("Sending query {}", url.toString());
			try (InputStream is = url.openStream()) {
				ByteArrayOutputStream baos = new ByteArrayOutputStream();
				byte[] buffer = new byte[READ_BUFFER_SIZE];
				int numread;
				while ((numread = is.read(buffer)) != -1)
					baos.write(buffer, 0, numread);
				log.trace("Read {} bytes", baos.size());
				return baos.toByteArray();
			}
		} catch (IOException | WordEmbeddingClientException e) {
			throw new WordEmbeddingAccessException(e);
		}
	}

	/** Builds one {@code word=<value>} query parameter with the value URL-encoded as UTF-8. */
	private static String wordParam(String word) {
		try {
			return PARAM_WORD + "=" + URLEncoder.encode(word, UTF_8);
		} catch (UnsupportedEncodingException e) {
			// UTF-8 support is mandatory on every Java platform; this cannot happen.
			throw new IllegalStateException(e);
		}
	}

	/** Builds a query string with one {@code word} parameter per word, in list order. */
	private static String wordsQuery(List<String> words) {
		return words.stream().map(WordEmbeddingClient::wordParam).collect(Collectors.joining("&"));
	}

	/** Decodes a textual server response as UTF-8. */
	private static String decodeText(byte[] bytes) {
		return new String(bytes, StandardCharsets.UTF_8);
	}

	/** Fetches the vector of a single word from the server (cache loader). */
	private double[] loadWordVector(String word) {
		ByteBuffer bb = ByteBuffer.wrap(fetch(GET_EMBEDDING, wordParam(word)));
		if (bb.capacity() == 1)
			return EMPTY;
		return readEmbeddingVector(bb);
	}

	@Override
	public double[] getWordVector(String word) throws WordEmbeddingAccessException {
		return getCached(vectorCache, word);
	}

	/** Fetches the vectors of all given words from the server (cache loader). */
	private EmbeddingVectors loadWordVectors(List<String> words) {
		ByteBuffer bb = ByteBuffer.wrap(fetch(GET_EMBEDDINGS, wordsQuery(words)));
		// Response: [4b (int): number of dimensions] [num words b: word in vocab or
		// not] [vectors]
		int embeddingDimensions = bb.getInt();
		log.trace("Got embedding vector dimensions {}", embeddingDimensions);
		List<Integer> foundIndices = readFoundIndices(words, bb);
		log.trace("Got {} found word indices: {}", foundIndices.size(), foundIndices);
		int numFound = foundIndices.size();
		// One vector follows for each word that was found, in word order.
		double[][] vectors = new double[numFound][];
		for (int i = 0; i < numFound; i++)
			vectors[i] = readEmbeddingVector(bb);
		NDArray ndVectors = null;
		if (numFound > 0)
			ndVectors = new NDArray(vectors);
		return new EmbeddingVectors(ndVectors, words, foundIndices, embeddingDimensions,
				StreamType.CONCATENATION);
	}

	@Override
	public EmbeddingVectors getWordVectors(List<String> words) {
		return getCached(vectorsCache, new WordListKey(words));
	}

	/**
	 * Reads one byte per word from {@code bb} and returns the indices of the
	 * words whose byte is {@code 1}.
	 *
	 * @param words
	 *            the queried words; only the size is used
	 * @param bb
	 *            buffer positioned at the per-word flags; advanced by
	 *            {@code words.size()} bytes
	 * @return the indices (ascending) of the words flagged with {@code 1}
	 */
	public List<Integer> readFoundIndices(List<String> words, ByteBuffer bb) {
		List<Integer> foundIndices = new ArrayList<>();
		for (int i = 0; i < words.size(); i++) {
			if (bb.get() == 1)
				foundIndices.add(i);
		}
		return foundIndices;
	}

	/** Fetches the mean vector of the given words from the server (cache loader). */
	private EmbeddingVectors loadWordVectorsMean(List<String> words) {
		ByteBuffer bb = ByteBuffer.wrap(fetch(GET_EMBEDDINGS_MEAN, wordsQuery(words)));
		// Response: [4b (int): number of dimensions] [num words b: word in vocab or
		// not] [mean vector]
		int embeddingDimensions = bb.getInt();
		List<Integer> foundIndices = readFoundIndices(words, bb);
		INDArray meanVector = null;
		// The mean vector is only read when at least one word was flagged as found.
		if (!foundIndices.isEmpty())
			meanVector = Nd4j.create(readEmbeddingVector(bb));
		return new EmbeddingVectors(meanVector, words, foundIndices, embeddingDimensions, StreamType.AGGREGATION);
	}

	@Override
	public EmbeddingVectors getWordVectorsMean(List<String> words) {
		return getCached(vectorsMeanCache, new WordListKey(words));
	}

	/** Queries the server; the response is parsed as a UTF-8 decimal integer. */
	@Override
	public int getVocabularySize() {
		return Integer.parseInt(decodeText(fetch(GET_VOCAB_SIZE, null)));
	}

	/** Queries the server for the word at {@code index}; the response is decoded as UTF-8. */
	@Override
	public String getWord(int index) {
		return decodeText(fetch(GET_WORD, PARAM_INDEX + "=" + index));
	}

	/**
	 * Queries the server; the UTF-8 response is parsed with
	 * {@link Boolean#parseBoolean(String)}, so anything other than
	 * {@code "true"} (ignoring case) yields {@code false}.
	 */
	@Override
	public boolean hasWord(String word) {
		return Boolean.parseBoolean(decodeText(fetch(GET_HAS_WORD, wordParam(word))));
	}

	/**
	 * Reads one vector from {@code bb}: a 4-byte int holding the vector's byte
	 * length, followed by {@code length / 8} doubles (in the buffer's byte order,
	 * big-endian for buffers created by {@link ByteBuffer#wrap(byte[])}).
	 *
	 * @param bb
	 *            buffer positioned at the length prefix; advanced past the vector
	 * @return the decoded vector
	 */
	public double[] readEmbeddingVector(ByteBuffer bb) {
		int numDoubles = bb.getInt() / 8;
		double[] v = new double[numDoubles];
		for (int i = 0; i < numDoubles; i++)
			v[i] = bb.getDouble();
		return v;
	}

	/** Queries the server; the response is parsed as a UTF-8 decimal integer. */
	@Override
	public int getEmbeddingDimensions() {
		return Integer.parseInt(decodeText(fetch(GET_EMBEDDING_DIMS, null)));
	}

	/**
	 * Cache key for multi-word lookups. Two keys are equal iff they contain the
	 * same words in the same order. The words are copied into an unmodifiable
	 * list so later changes to the caller's collection cannot corrupt cached
	 * entries. (The previous delimiter-less string concatenation made e.g.
	 * {@code [ab, c]} and {@code [a, bc]} collide.)
	 */
	private static final class WordListKey {
		private final List<String> queryWords;

		WordListKey(Collection<String> queryWords) {
			this.queryWords = Collections.unmodifiableList(new ArrayList<>(queryWords));
		}

		@Override
		public int hashCode() {
			return queryWords.hashCode();
		}

		@Override
		public boolean equals(Object obj) {
			if (this == obj)
				return true;
			if (!(obj instanceof WordListKey))
				return false;
			return queryWords.equals(((WordListKey) obj).queryWords);
		}

		/** Returns the (unmodifiable) words of this key, in query order. */
		List<String> getQueryWords() {
			return queryWords;
		}
	}

	/** Restores the serialized fields and rebuilds the transient caches. */
	private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
		in.defaultReadObject();
		setupCaches();
	}
}