package net.asynchorswim.ddd

import language.postfixOps
import akka.actor.{Actor, ActorLogging, ActorRef}
import scala.reflect.ClassTag
import akka.pattern.ask
import akka.util.Timeout
import concurrent.duration._
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, Future}

/**
 * Base actor for an aggregate root that owns one child actor per entity id.
 *
 * Each incoming message is mapped to an entity id by `route`, and the value
 * produced by `payload` is forwarded (keeping the original sender) to that
 * entity's child. A missing child is created on demand, but only when the
 * payload is a `CanBeFirst`; otherwise the sender is told
 * `AggregateRoot.UnknownEntity(id)`.
 *
 * `Broadcast` messages are forwarded to every entity created so far. Messages
 * wrapped in `StreamMessage` are handled the same way and then answered with
 * `StreamAck`. For a wrapped `Broadcast`, every entity is asked and `StreamAck`
 * is sent once the combined `Future.sequence` of those asks completes,
 * successfully or not.
 *
 * NOTE(review): `route` and `payload` are partial functions; a message outside
 * their domain throws a `MatchError` from this actor, which is then subject to
 * its supervisor's failure handling.
 *
 * @param propsFactory builds the `Props` used to create entity children
 * @tparam A the entity type managed by this aggregate root
 */
abstract class AbstractAggregateRoot[A <: Entity[A] : ClassTag](propsFactory: EntityPropsFactory) extends Actor with ActorLogging {

  /** Maps an incoming message to the id of the entity that should handle it. */
  val route: PartialFunction[Any, String]

  /** Extracts the value forwarded to the entity; the identity by default. */
  val payload: PartialFunction[Any, Any] = { case x => x }

  implicit private val ex: scala.concurrent.ExecutionContext = context.dispatcher

  /** Ask timeout; bounds each per-entity ask made for a streamed `Broadcast`. */
  implicit val timeout: Timeout = Timeout(10.seconds)

  /** Start with no entities. */
  override def preStart(): Unit = context.become(receive(Map.empty[String, ActorRef]))

  /**
   * Initial behaviour: identical to an aggregate with no entities yet.
   *
   * `preStart` already switches to this behaviour; delegating here as well keeps
   * things correct if a subclass overrides `preStart` without calling super, and
   * avoids re-sending the message to `self` (which would re-apply `payload`).
   */
  def receive: Receive = receive(Map.empty[String, ActorRef])

  /**
   * Behaviour for the given set of entities.
   *
   * @param state entity id -> child actor, for every entity created so far
   */
  def receive(state: Map[String, ActorRef]): Receive = {
    case StreamMessage(msg: Broadcast) =>
      // Ack once all entity asks have completed, without blocking this actor's
      // dispatcher thread while waiting. sender() is only valid while this
      // message is being processed, so capture it before the callback runs.
      val replyTo = sender()
      Future.sequence(state.values.map(_ ? msg)).onComplete(_ => replyTo ! StreamAck)
    case StreamMessage(msg) =>
      processMessage(state, msg)
      sender() ! StreamAck
    case msg: Broadcast =>
      state.values.foreach(_ forward msg)
    case msg =>
      processMessage(state, msg)
  }

  /**
   * Forwards the payload of `msg` to its entity, creating the entity first when
   * none exists and the payload is a `CanBeFirst`. Otherwise replies to the
   * sender with `AggregateRoot.UnknownEntity(id)`.
   */
  private def processMessage(state: Map[String, ActorRef], msg: Any): Unit = {
    val id = route(msg)
    val body = payload(msg)
    (state.get(id), body) match {
      case (Some(actor), _) =>
        actor forward body
      case (None, _: CanBeFirst) =>
        val actor = context.actorOf(propsFactory.props[A], id)
        context.become(receive(state.updated(id, actor)))
        actor forward body
      case (None, _) =>
        sender() ! AggregateRoot.UnknownEntity(id)
    }
  }
}

/** Messages used by `AbstractAggregateRoot`. */
object AggregateRoot {

  /**
   * Reply sent back to the original sender when a message is routed to an
   * entity id that has no entity yet and whose payload cannot create one.
   *
   * @param id the entity id the message was routed to
   */
  final case class UnknownEntity(id: String)
}

/**
 * Aggregate root whose entity children are built by the `TransientEntity`
 * props factory (presumably non-persistent entities — confirm against
 * `TransientEntity`).
 *
 * @tparam A the entity type managed by this aggregate root
 */
abstract class TransientAggregateRoot[A <: Entity[A] : ClassTag]
  extends AbstractAggregateRoot[A](propsFactory = TransientEntity)

/**
 * Aggregate root whose entity children are built by the `EventSourcedEntity`
 * props factory (presumably entities that persist their state as events —
 * confirm against `EventSourcedEntity`).
 *
 * @tparam A the entity type managed by this aggregate root
 */
abstract class EventSourcedAggregateRoot[A <: Entity[A] : ClassTag]
  extends AbstractAggregateRoot[A](propsFactory = EventSourcedEntity)

