package io.contek.tusk.counter;

import io.contek.tusk.BatchingConfig;
import io.contek.tusk.Metric;
import io.contek.tusk.Table;

import javax.annotation.Nullable;
import javax.annotation.concurrent.Immutable;
import javax.annotation.concurrent.NotThreadSafe;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;

@Immutable
public final class Counter {

  private static final Duration DEFAULT_COUNT_PERIOD = Duration.ofSeconds(15);
  private static final int DEFAULT_COUNTS_PER_BATCH = 4;

  private final Metric metric;
  private final String countColumn;
  private final Duration countPeriod;
  private final Map<TagSet, Integer> counts = new HashMap<>();
  private final AtomicReference<Future<?>> task = new AtomicReference<>(null);

  private final ScheduledExecutorService scheduler = newSingleThreadScheduledExecutor();

  private Counter(Metric metric, String countColumn, Duration countPeriod) {
    this.metric = metric;
    this.countColumn = countColumn;
    this.countPeriod = countPeriod;
  }

  public static Counter counter(String table, String countColumn) {
    return counter(null, table, countColumn);
  }

  public static Counter counter(@Nullable String database, String table, String countColumn) {
    return counter(Table.newBuilder().setDatabase(database).setName(table).build(), countColumn);
  }

  public static Counter counter(Table table, String countColumn) {
    return counter(table, countColumn, DEFAULT_COUNT_PERIOD);
  }

  public static Counter counter(Table table, String countColumn, Duration countPeriod) {
    return counter(
        table,
        countColumn,
        countPeriod,
        BatchingConfig.forDuration(countPeriod.multipliedBy(DEFAULT_COUNTS_PER_BATCH)));
  }

  public static Counter counter(
      Table table, String countColumn, Duration countPeriod, BatchingConfig batching) {
    return new Counter(Metric.metric(table, batching), countColumn, countPeriod);
  }

  public Tagging withTags() {
    return new Tagging(this);
  }

  private void count(TagSet tags, int n) {
    synchronized (counts) {
      counts.compute(
          tags,
          (k, oldValue) -> {
            if (oldValue == null) {
              return n;
            }
            return oldValue + n;
          });
    }

    scheduleIfIdle();
  }

  private void scheduleIfIdle() {
    synchronized (task) {
      Future<?> future = task.get();
      if (future != null && !future.isDone()) {
        return;
      }
      schedule();
    }
  }

  private void flushAndSchedule() {
    boolean updated = flush();
    if (!updated) {
      return;
    }

    schedule();
  }

  private void schedule() {
    synchronized (task) {
      Future<?> future =
          scheduler.schedule(this::flushAndSchedule, countPeriod.getSeconds(), SECONDS);
      task.set(future);
    }
  }

  private boolean flush() {
    Map<TagSet, Integer> buffer;
    synchronized (counts) {
      if (counts.isEmpty()) {
        return false;
      }
      buffer = new HashMap<>(counts);
      counts.clear();
    }

    for (Map.Entry<TagSet, Integer> next : buffer.entrySet()) {
      Map<String, String> tags = next.getKey().getTags();
      int count = next.getValue();
      metric.newEntry().putAll(tags).putUInt32(countColumn, count).write();
    }

    return true;
  }

  @NotThreadSafe
  public static final class Tagging {

    private final Counter counter;
    private final Map<String, String> tags = new HashMap<>();

    private Tagging(Counter counter) {
      this.counter = counter;
    }

    public Tagging put(String key, String value) {
      this.tags.put(key, value);
      return this;
    }

    public Tagging putAll(Map<String, String> tags) {
      this.tags.putAll(tags);
      return this;
    }

    public void count() {
      count(1);
    }

    public void count(int n) {
      counter.count(new TagSet(tags), n);
    }
  }
}
