package net.morimekta.providence.storage;

import net.morimekta.providence.PMessage;
import net.morimekta.providence.descriptor.PField;
import net.morimekta.util.concurrent.ReadWriteMutex;
import net.morimekta.util.concurrent.ReentrantReadWriteMutex;

import javax.annotation.Nonnull;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

/**
 * Simple in-memory set storage of providence messages. Uses a local hash map for
 * storing the instances, keyed by the value the key function (given to the
 * constructor) extracts from each message. The store is thread safe: lookups go
 * through the mutex' read lock, so reading can happen in parallel, and all
 * modifications go through its write lock.
 *
 * @param <K> The key type.
 * @param <M> The stored message type.
 * @param <F> The field type of the stored message.
 */
public class InMemoryMessageSetStore<K, M extends PMessage<M,F>, F extends PField> implements MessageSetStore<K,M,F> {
    private final Map<K, M>      map;
    private final ReadWriteMutex mutex;
    private final Function<M, K> messageToKey;

    /**
     * Create an empty in-memory message set store.
     *
     * @param messageToKey Function extracting the storage key from a message.
     *                     It is applied to every message passed to
     *                     {@link #putAll(Collection)} while holding the write lock.
     * @throws NullPointerException If messageToKey is null.
     */
    public InMemoryMessageSetStore(@Nonnull Function<M,K> messageToKey) {
        // Fail fast here instead of with an obscure NPE on the first putAll().
        this.messageToKey = Objects.requireNonNull(messageToKey, "messageToKey");
        this.map   = new HashMap<>();
        this.mutex = new ReentrantReadWriteMutex();
    }

    /**
     * Look up the messages stored for the given keys.
     *
     * @param keys The keys to look up.
     * @return A new map with an entry for each requested key that is present in
     *         the store. Keys not in the store are left out.
     */
    @Nonnull
    @Override
    public Map<K, M> getAll(@Nonnull Collection<K> keys) {
        return mutex.lockForReading(() -> {
            Map<K, M> out = new HashMap<>();
            for (K key : keys) {
                // containsKey() rather than a null check on get(), so a key
                // explicitly mapped to a null value is still reported.
                if (map.containsKey(key)) {
                    out.put(key, map.get(key));
                }
            }
            return out;
        });
    }

    /**
     * Check whether the store has an entry for the given key.
     *
     * @param key The key to check.
     * @return True if the key is present in the store.
     */
    @Override
    public boolean containsKey(@Nonnull K key) {
        return mutex.lockForReading(() -> map.containsKey(key));
    }

    /**
     * Get all keys currently in the store.
     *
     * @return A snapshot copy of the key set. Later modifications of the store
     *         are not reflected in it, and modifying it does not affect the store.
     */
    @Override @Nonnull
    public Collection<K> keys() {
        return mutex.lockForReading(() -> new HashSet<>(map.keySet()));
    }

    /**
     * Store the given messages, each under the key given by the key function.
     * An existing entry with the same key is replaced.
     * <p>
     * NOTE: If several messages in the same call map to the same key, the last
     * one wins in the store. The returned map then holds the value replaced by
     * that last put, which may be an earlier message from the same call.
     *
     * @param values The messages to store.
     * @return A new map from key to the previously stored message for every key
     *         whose value was replaced. Newly inserted keys are not included.
     */
    @Override @Nonnull
    public Map<K, M> putAll(@Nonnull Collection<M> values) {
        return mutex.lockForWriting(() -> {
            Map<K, M> out = new HashMap<>(values.size());
            for (M entry : values) {
                K key = messageToKey.apply(entry);
                M tmp = map.put(key, entry);
                if (tmp != null) {
                    out.put(key, tmp);
                }
            }
            return out;
        });
    }

    /**
     * Remove the entries for the given keys.
     *
     * @param keys The keys to remove.
     * @return A new map from key to the removed message for every key that was
     *         present in the store. Keys not in the store are left out.
     */
    @Override @Nonnull
    public Map<K, M> removeAll(@Nonnull Collection<K> keys) {
        return mutex.lockForWriting(() -> {
            Map<K, M> out = new HashMap<>(keys.size());
            for (K key : keys) {
                M tmp = map.remove(key);
                if (tmp != null) {
                    out.put(key, tmp);
                }
            }
            return out;
        });
    }
}
