package cn.zzq0324.radish.extension;

import cn.zzq0324.radish.common.spring.SpringContextHolder;
import cn.zzq0324.radish.extension.annotation.Extension;
import cn.zzq0324.radish.extension.annotation.SPI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.aop.support.AopUtils;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
 * Loader for the implementations of a single extension point (SPI).
 *
 * <p>An extension point is a type annotated with {@link SPI}; its implementations are Spring beans
 * whose target class is annotated with {@link Extension}. When a loader is first requested, the
 * implementations are fetched from the Spring context, sorted by {@link Extension#order()}
 * (smaller first) and grouped by {@link BusinessScenario}. One loader is created and cached per
 * extension point type.
 *
 * @author: zzq0324
 * @since : 1.0.0
 */
@Slf4j
public class ExtensionLoader<T> {

  /**
   * Cache of all loaders, one per extension point type, so each SPI is resolved from the Spring
   * context only once.
   */
  private static final ConcurrentMap<Class<?>, ExtensionLoader<?>> EXTENSION_LOADERS = new ConcurrentHashMap<>();

  /**
   * Extension implementations grouped by business scenario. Each list is ordered by
   * {@link Extension#order()} (smaller first). The lists are made unmodifiable once loading
   * finishes, because this loader instance is shared by every caller through the static cache.
   */
  private final Map<BusinessScenario, List<T>> extensionInfoMap = new LinkedHashMap<>();

  /**
   * Private constructor; obtain instances through {@link #getExtensionLoader(Class)}.
   *
   * @param type the extension point type, which must be annotated with {@link SPI}
   * @throws IllegalArgumentException if {@code type} is not annotated with {@link SPI}, or if it
   *     forbids repeated scenarios and two implementations share one
   * @throws IllegalStateException if an implementation bean is not annotated with {@link Extension}
   */
  private ExtensionLoader(Class<T> type) {
    if (!withSPIAnnotation(type)) {
      throw new IllegalArgumentException("Type (" + type + ") should be annotated with @SPI !");
    }

    // key: bean name, value: bean instance
    Map<String, T> beanMap = SpringContextHolder.getApplicationContext().getBeansOfType(type);

    if (CollectionUtils.isEmpty(beanMap)) {
      log.warn("empty extension with type [{}]", type.getName());

      return;
    }

    // Sorting before grouping keeps every scenario's list in @Extension order.
    for (T extensionInstance : sort(beanMap)) {
      Extension extension = getExtensionAnnotation(extensionInstance);
      BusinessScenario scenario = BusinessScenario.of(extension.business(), extension.useCase(), extension.scenario());

      addExtensionToMap(scenario, extensionInstance);
    }

    printExtensionList(type);

    // Check for conflicting scenarios only after printing, so the log above helps diagnose a failure.
    checkRepeatable(type);

    // Freeze the lists: they are handed out directly to every caller of this shared, cached loader.
    extensionInfoMap.replaceAll((key, extensionList) -> Collections.unmodifiableList(extensionList));
  }

  /**
   * Returns the loader for the given extension point, creating and caching it on first use. Each
   * extension point has its own independent loader.
   *
   * @param type the extension point type, annotated with {@link SPI}
   * @return the cached loader for {@code type}
   * @throws IllegalArgumentException if {@code type} is {@code null} or not annotated with {@link SPI}
   */
  @SuppressWarnings("unchecked")
  public static <T> ExtensionLoader<T> getExtensionLoader(Class<T> type) {
    if (type == null) {
      throw new IllegalArgumentException("Extension type == null");
    }

    ExtensionLoader<?> loader = EXTENSION_LOADERS.get(type);
    if (loader == null) {
      // Build outside computeIfAbsent: loading calls into Spring (bean creation may in turn request
      // other loaders), and the computeIfAbsent contract disallows a mapping function that touches
      // other mappings. The per-type lock plus re-check still yields a single loader per type.
      synchronized (type) {
        loader = EXTENSION_LOADERS.get(type);
        if (loader == null) {
          loader = new ExtensionLoader<>(type);
          EXTENSION_LOADERS.put(type, loader);
        }
      }
    }

    // Safe: the only value ever stored under key `type` is an ExtensionLoader built from `type`.
    return (ExtensionLoader<T>) loader;
  }

  /**
   * Returns the first (lowest order) implementation registered for exactly {@code scenario},
   * without fallback.
   *
   * @return the implementation, or {@code null} if none matches
   */
  public static <T> T getExtension(Class<T> type, BusinessScenario scenario) {
    return getExtension(type, scenario, false);
  }

  /**
   * Returns the first (lowest order) implementation for {@code scenario} under the default business
   * and default use case, without fallback.
   *
   * @return the implementation, or {@code null} if none matches
   */
  public static <T> T getExtension(Class<T> type, String scenario) {
    BusinessScenario businessScenario = BusinessScenario.of(ExtensionConstant.DEFAULT_BUSINESS,
        ExtensionConstant.DEFAULT_USE_CASE, scenario);

    return getExtension(type, businessScenario, false);
  }

  /**
   * Returns the first (lowest order) implementation for {@code scenario}.
   *
   * @param fallback whether to fall back to broader scenarios when there is no exact match; see
   *     {@link #getFailOverExtensionList(Class, BusinessScenario)}
   * @return the implementation, or {@code null} if none is found
   */
  public static <T> T getExtension(Class<T> type, BusinessScenario scenario, boolean fallback) {
    List<T> extList = getExtensionList(type, scenario, fallback);

    return CollectionUtils.isEmpty(extList) ? null : extList.get(0);
  }

  /**
   * Returns the implementations registered for exactly {@code scenario}, without fallback.
   *
   * @return an unmodifiable list in {@link Extension#order()} order, or {@code null} if none match
   */
  public static <T> List<T> getExtensionList(Class<T> type, BusinessScenario scenario) {
    return getExtensionList(type, scenario, false);
  }

  /**
   * Returns the implementations for {@code scenario}, optionally falling back to broader scenarios.
   *
   * @param failOver whether to use {@link #getFailOverExtensionList(Class, BusinessScenario)} when
   *     nothing is registered for exactly {@code scenario}
   * @return an unmodifiable list in {@link Extension#order()} order, or {@code null} if nothing is found
   */
  public static <T> List<T> getExtensionList(Class<T> type, BusinessScenario scenario, boolean failOver) {
    List<T> extensionList = getExtensionLoader(type).extensionInfoMap.get(scenario);
    if (!CollectionUtils.isEmpty(extensionList)) {
      return extensionList;
    }

    // No exact match and fallback was not requested.
    if (!failOver) {
      return null;
    }

    return getFailOverExtensionList(type, scenario);
  }

  /**
   * Looks up implementations from the more specific to the less specific scenario and returns the
   * first non-empty match:
   * <ol>
   *   <li>{@code scenario}'s business and use case with the default scenario;</li>
   *   <li>that business with the default use case and default scenario;</li>
   *   <li>the fully default business scenario.</li>
   * </ol>
   * The exact {@code scenario} itself is not looked up here.
   *
   * @param scenario the requested scenario; must not be {@code null}
   * @return the matching implementations, or {@code null} if even the fully default scenario has none
   */
  public static <T> List<T> getFailOverExtensionList(Class<T> type, BusinessScenario scenario) {
    Map<BusinessScenario, List<T>> extensionMap = getExtensionLoader(type).extensionInfoMap;

    // business + useCase + default scenario
    BusinessScenario candidate = BusinessScenario.of(scenario.getBusiness(), scenario.getUseCase());
    List<T> extensionList = extensionMap.get(candidate);
    if (!CollectionUtils.isEmpty(extensionList)) {
      return extensionList;
    }

    // business + default useCase + default scenario
    candidate = BusinessScenario.of(candidate.getBusiness());
    extensionList = extensionMap.get(candidate);
    if (!CollectionUtils.isEmpty(extensionList)) {
      return extensionList;
    }

    // Last resort: everything default.
    return extensionMap.get(BusinessScenario.of());
  }

  /**
   * Returns a read-only view of the implementations grouped by scenario.
   */
  public Map<BusinessScenario, List<T>> getExtensionInfoMap() {
    return Collections.unmodifiableMap(extensionInfoMap);
  }

  /**
   * Fails when the extension point does not allow repeated scenarios
   * ({@link SPI#scenarioRepeatable()} is {@code false}) but some scenario has more than one
   * implementation.
   *
   * @throws IllegalArgumentException naming the type, the first conflicting scenario and its implementations
   */
  private void checkRepeatable(Class<?> type) {
    SPI spi = type.getAnnotation(SPI.class);
    // Repetition is allowed, so there is nothing to check.
    if (spi.scenarioRepeatable()) {
      return;
    }

    for (Map.Entry<BusinessScenario, List<T>> entry : extensionInfoMap.entrySet()) {
      List<T> extList = entry.getValue();
      if (extList.size() > 1) {
        String implementations = extList.stream()
            .map(ext -> AopUtils.getTargetClass(ext).getName())
            .collect(Collectors.joining(", "));
        throw new IllegalArgumentException("scenario cannot repeat! type: " + type.getName()
            + ", scenario: " + entry.getKey() + ", extensions: [" + implementations + "]");
      }
    }
  }

  /**
   * Appends an implementation to its scenario group, creating the group on first use.
   */
  private void addExtensionToMap(BusinessScenario scenario, T extensionInstance) {
    extensionInfoMap.computeIfAbsent(scenario, key -> new ArrayList<>()).add(extensionInstance);
  }

  /**
   * Logs the loaded implementations as a tree: extension point, then scenario, then implementation class.
   */
  private void printExtensionList(Class<T> type) {
    log.info("|──  {}", type.getName());
    for (Map.Entry<BusinessScenario, List<T>> entry : extensionInfoMap.entrySet()) {
      log.info("|    |── {}", entry.getKey());

      entry.getValue().forEach(instance -> log.info("|    |    |── {}", instance.getClass().getName()));
    }
  }

  /**
   * Returns the beans as a new list sorted by {@link Extension#order()}, smaller values first. The
   * sort is stable, so beans with equal order keep the iteration order of {@code beanMap}.
   */
  private List<T> sort(Map<String, T> beanMap) {
    List<T> orderedExtensionList = new ArrayList<>(beanMap.values());
    // Integer comparison instead of subtraction: subtracting orders near Integer.MIN_VALUE or
    // Integer.MAX_VALUE overflows and produces a wrong ordering.
    orderedExtensionList.sort(Comparator.comparingInt(extension -> getExtensionAnnotation(extension).order()));

    return orderedExtensionList;
  }

  /**
   * Returns the {@link Extension} annotation declared on the bean's target class (AOP proxies are
   * unwrapped via {@link AopUtils#getTargetClass(Object)}).
   *
   * @throws IllegalStateException if the target class is not annotated with {@link Extension}; such
   *     a bean is treated as a configuration error rather than being silently skipped
   */
  private static Extension getExtensionAnnotation(Object extensionInstance) {
    Class<?> targetClass = AopUtils.getTargetClass(extensionInstance);

    Extension extension = targetClass.getAnnotation(Extension.class);
    if (extension == null) {
      throw new IllegalStateException(targetClass.getName() + " not annotated with @Extension.");
    }

    return extension;
  }

  /**
   * Returns whether {@code type} is declared as an extension point, i.e. annotated with {@link SPI}.
   */
  private static <T> boolean withSPIAnnotation(Class<T> type) {
    return type.isAnnotationPresent(SPI.class);
  }
}
