package de.otto.eventsourcing.query;

import de.otto.eventsourcing.event.Key;
import de.otto.eventsourcing.event.Payload;
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.slf4j.Logger;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static java.lang.String.format;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.slf4j.LoggerFactory.getLogger;

/**
 * {@link ConsumerRecordCallback} that lets a thread block until a given number of
 * {@link #onSuccess(ConsumerRecord)} invocations have been observed.
 *
 * <p>Each {@code onSuccess} call stores the received record and counts down an internal
 * {@link CountDownLatch}. {@link #await()} / {@link #await(int, TimeUnit)} wait for the latch
 * to reach zero and then return the most recently stored record.
 *
 * <p>{@link #onFailure(Throwable)} only logs the failure; it does not count down the latch.
 * A waiter therefore runs into a {@link TimeoutException} if failures arrive instead of
 * successes.
 */
public class LatchedCallback implements ConsumerRecordCallback {
    private static final Logger LOG = getLogger(LatchedCallback.class);
    private final CountDownLatch latch;
    // volatile so a reader sees the latest record even when reading it outside the latch's
    // happens-before edge (e.g. onSuccess calls made after the latch already reached zero).
    private volatile ConsumerRecord<Key, Payload> lastEvent;

    /**
     * Creates a callback that is released after a single {@code onSuccess} invocation.
     */
    public LatchedCallback() {
        this(1);
    }

    /**
     * Creates a callback that is released after {@code count} {@code onSuccess} invocations.
     *
     * @param count number of successful callbacks to wait for
     * @throws IllegalArgumentException if {@code count} is negative (raised by {@link CountDownLatch})
     */
    public LatchedCallback(final int count) {
        latch = new CountDownLatch(count);
    }

    /**
     * Logs the failure at debug level. The latch is intentionally not counted down.
     *
     * @param ex the failure reported to this callback
     */
    @Override
    public void onFailure(Throwable ex) {
        // Pass the throwable as the trailing argument so SLF4J logs its type and stack trace,
        // not just a (possibly null) message.
        LOG.debug("received onFailure {}", ex.getMessage(), ex);
    }

    /**
     * Stores {@code result} as the latest received record and counts down the latch.
     *
     * @param result the record delivered to this callback
     */
    @Override
    public void onSuccess(final ConsumerRecord<Key, Payload> result) {
        LOG.trace("Received onSuccess {}", result);
        lastEvent = result;
        latch.countDown();
    }

    /**
     * Waits up to one second for the expected number of {@code onSuccess} calls.
     *
     * @return the most recently received record
     * @throws InterruptedException if the waiting thread is interrupted
     * @throws TimeoutException if the latch does not reach zero within one second
     */
    public ConsumerRecord<Key, Payload> await() throws InterruptedException, TimeoutException {
        return await(1, SECONDS);
    }

    /**
     * Waits up to the given time for the expected number of {@code onSuccess} calls.
     *
     * @param timeout maximum time to wait
     * @param timeUnit unit of {@code timeout}
     * @return the most recently received record at the time the latch is released
     * @throws InterruptedException if the waiting thread is interrupted
     * @throws TimeoutException if the latch does not reach zero within the given time
     */
    public ConsumerRecord<Key, Payload> await(final int timeout, final TimeUnit timeUnit) throws InterruptedException, TimeoutException {
        LOG.trace("Waiting for latch...");
        if (!latch.await(timeout, timeUnit)) {
            throw new TimeoutException(format("Did not receive onSuccess within %s %s", timeout, timeUnit));
        }
        final ConsumerRecord<Key, Payload> event = lastEvent;
        LOG.trace("...got it. Returning with {}", event);
        return event;
    }

}
