package pl.codewise.commons.aws.cqrs.utils;

import org.awaitility.core.ConditionTimeoutException;
import org.hamcrest.Matcher;
import org.hamcrest.Matchers;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static java.lang.String.format;
import static org.awaitility.Awaitility.await;

/**
 * Polling helpers built on top of Awaitility for waiting until an action succeeds or produces an
 * acceptable value.
 */
public class Awaitilities {

    private static final Logger log = LoggerFactory.getLogger(Awaitilities.class);

    /**
     * Repeatedly calls {@code action} until it returns {@code true} or the timeout elapses.
     *
     * <p>Any exception thrown by {@code action} is logged (with stack trace when DEBUG is enabled,
     * message only otherwise) and treated as a {@code false} result, so polling continues.
     *
     * @param atMostInMilliseconds     maximum total time to wait
     * @param pollIntervalMilliseconds time between consecutive polls
     * @param pollDelayMilliseconds    delay before the first poll
     * @param message                  description of what is awaited; used in log lines and, when
     *                                 non-null, in the timeout exception message
     * @param action                   condition to poll
     * @throws ConditionTimeoutException if {@code action} did not return {@code true} in time. When
     *                                   {@code message} is non-null, the exception carries a
     *                                   descriptive message and has Awaitility's exception as its
     *                                   cause. The last exception thrown by {@code action}, if any,
     *                                   is attached as suppressed.
     */
    public void awaitTillActionSucceed(
            long atMostInMilliseconds,
            long pollIntervalMilliseconds,
            long pollDelayMilliseconds,
            String message,
            Callable<Boolean> action) {

        log.info("Wait for {}: [atMost: {}ms, pollInterval: {}ms, pollDelay: {}ms]",
                message, atMostInMilliseconds, pollIntervalMilliseconds, pollDelayMilliseconds);

        // Measure real elapsed time. Multiplying a poll counter by the interval would ignore
        // pollDelay and the time spent inside action.
        final long startNanos = System.nanoTime();
        // Remember the most recent failure so a timeout can explain why the action never succeeded.
        final AtomicReference<Exception> lastFailure = new AtomicReference<>();

        try {
            await()
                    .atMost(atMostInMilliseconds, TimeUnit.MILLISECONDS)
                    .pollInterval(pollIntervalMilliseconds, TimeUnit.MILLISECONDS)
                    .pollDelay(pollDelayMilliseconds, TimeUnit.MILLISECONDS)
                    .until(() -> {
                        try {
                            log.info("Wait for {}: {}ms out of {}ms has passed", message,
                                    TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startNanos),
                                    atMostInMilliseconds);
                            return action.call();
                        } catch (Exception e) {
                            lastFailure.set(e);
                            if (log.isDebugEnabled()) {
                                log.debug(e.getMessage(), e);
                            } else {
                                log.warn(e.getMessage());
                            }
                            return false;
                        }
                    });
        } catch (ConditionTimeoutException e) {
            ConditionTimeoutException timeout = e;
            if (message != null) {
                timeout = new ConditionTimeoutException(
                        format("Wait for %s timed out after %dms", message, atMostInMilliseconds));
                // Keep Awaitility's original exception reachable instead of discarding it.
                timeout.initCause(e);
            }
            Exception failure = lastFailure.get();
            if (failure != null) {
                timeout.addSuppressed(failure);
            }
            throw timeout;
        }
        log.info("Wait for {}. DONE.", message);
    }

    /**
     * Repeatedly calls {@code action} until it returns a non-null value, then returns that value.
     *
     * <p>A {@link RuntimeException} thrown by {@code action} is remembered and treated like a
     * {@code null} result, so polling continues.
     *
     * @param atMostInMilliseconds       maximum total time to wait
     * @param pollIntervalInMilliseconds time between consecutive polls
     * @param action                     value supplier to poll
     * @param defaultExceptionMessage    message of the exception thrown when the wait fails and
     *                                   {@code action} never threw a {@link RuntimeException}
     * @param <T>                        type of the awaited value
     * @return the first non-null value returned by {@code action}
     * @throws RuntimeException if the wait fails (e.g. times out). The exception thrown is the last
     *                          {@link RuntimeException} thrown by {@code action}, or a new one
     *                          carrying {@code defaultExceptionMessage}. The exception that ended
     *                          the wait is attached as suppressed.
     */
    public <T> T awaitForValue(int atMostInMilliseconds,
            int pollIntervalInMilliseconds,
            Callable<T> action,
            String defaultExceptionMessage) {

        AtomicReference<RuntimeException> lastException = new AtomicReference<>(
                new RuntimeException(defaultExceptionMessage));

        try {
            return await()
                    .atMost(atMostInMilliseconds, TimeUnit.MILLISECONDS)
                    .pollInterval(pollIntervalInMilliseconds, TimeUnit.MILLISECONDS)
                    .until(() -> {
                        try {
                            return action.call();
                        } catch (RuntimeException e) {
                            lastException.set(e);
                            return null;
                        }
                    }, Matchers.notNullValue());
        } catch (Exception e) {
            throw withSuppressed(lastException.get(), e);
        }
    }

    /**
     * Same as the five-argument {@code awaitTillActionSucceed}, with a fixed poll delay of 100ms.
     *
     * @param atMostInMilliseconds     maximum total time to wait
     * @param pollIntervalMilliseconds time between consecutive polls
     * @param message                  description of what is awaited (used in logs and timeout
     *                                 message)
     * @param action                   condition to poll
     * @throws ConditionTimeoutException if {@code action} did not return {@code true} in time
     */
    public void awaitTillActionSucceed(
            long atMostInMilliseconds,
            long pollIntervalMilliseconds,
            String message,
            Callable<Boolean> action) {
        awaitTillActionSucceed(atMostInMilliseconds, pollIntervalMilliseconds, 100, message, action);
    }

    /**
     * Repeatedly calls {@code action} until its result satisfies {@code matcher}, then returns that
     * result. If the timeout elapses first, returns the value from the most recent successful call,
     * provided it is non-null.
     *
     * <p>The first poll is delayed by {@code pollIntervalInMilliseconds}. A {@link RuntimeException}
     * thrown by {@code action} is remembered and counts as "not matched". It is never passed to
     * {@code matcher}, so a null-accepting matcher cannot mistake a failure for success.
     *
     * @param atMostInMilliseconds       maximum total time to wait
     * @param pollIntervalInMilliseconds time between polls, also used as the initial poll delay
     * @param action                     value supplier to poll
     * @param matcher                    acceptance criterion for a value returned by {@code action}
     * @param defaultExceptionMessage    message of the exception thrown when no value can be
     *                                   returned and {@code action} never threw a
     *                                   {@link RuntimeException}
     * @param <T>                        type of the awaited value
     * @return the matching value, or on timeout the last non-null value returned by {@code action}
     * @throws RuntimeException if the wait fails without a usable value. The exception thrown is the
     *                          last {@link RuntimeException} thrown by {@code action}, or a new one
     *                          carrying {@code defaultExceptionMessage}. The exception that ended
     *                          the wait is attached as suppressed.
     */
    public <T> T awaitForValueOrReturnLastValue(int atMostInMilliseconds,
            int pollIntervalInMilliseconds,
            Callable<T> action,
            Matcher<T> matcher,
            String defaultExceptionMessage) {

        AtomicReference<RuntimeException> lastException = new AtomicReference<>(
                new RuntimeException(defaultExceptionMessage));

        final AtomicReference<T> lastValue = new AtomicReference<>();

        try {
            await()
                    .atMost(atMostInMilliseconds, TimeUnit.MILLISECONDS)
                    .pollInterval(pollIntervalInMilliseconds, TimeUnit.MILLISECONDS)
                    .pollDelay(pollIntervalInMilliseconds, TimeUnit.MILLISECONDS)
                    .until(() -> {
                        T result;
                        try {
                            result = action.call();
                        } catch (RuntimeException e) {
                            // A failed call is "not matched"; it must not reach the matcher as null.
                            lastException.set(e);
                            return false;
                        }
                        lastValue.set(result);
                        // Evaluated outside the try so matcher failures propagate as before.
                        return matcher.matches(result);
                    });
            // The condition stops polling once it returns true, so this is the matching value.
            return lastValue.get();
        } catch (ConditionTimeoutException e) {
            T lastResult = lastValue.get();
            if (lastResult != null) {
                return lastResult;
            }

            throw withSuppressed(lastException.get(), e);
        } catch (Exception e) {
            throw withSuppressed(lastException.get(), e);
        }
    }

    /**
     * Attaches {@code secondary} to {@code primary} as a suppressed exception, so the failure that
     * ended a wait is not lost, and returns {@code primary} for rethrowing.
     */
    private static RuntimeException withSuppressed(RuntimeException primary, Throwable secondary) {
        // Throwable.addSuppressed rejects self-suppression.
        if (primary != secondary) {
            primary.addSuppressed(secondary);
        }
        return primary;
    }
}
