package cn.zzq0324.radish.web.log;

import cn.zzq0324.radish.common.util.DateUtils;
import cn.zzq0324.radish.common.util.StrUtils;
import cn.zzq0324.radish.web.annotation.SkipLogRequestData;
import java.nio.charset.StandardCharsets;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.util.ContentCachingRequestWrapper;

/**
 * 访问日志记录类
 *
 * @author zzq0324
 * @since : 1.0.0
 */
@Slf4j
public class AccessLogger {

  private static final Logger BIZ_LOGGER = LoggerFactory.getLogger(AccessLog.class);

  // 本地线程上下文
  private static final ThreadLocal<AccessLogger> LOG_CONTEXT = new ThreadLocal<>();

  public static final String AND_SYMBOL = "&";
  public static final String EQUAL_SYMBOL = "=";

  // handler的格式，类名-方法名
  private static final String HANDLER_FORMAT = "%s-%s";

  // 日志分隔符
  private static final String LOG_SPLITTER = " | ";

  protected AccessLog accessLog;

  public AccessLogger(HttpServletRequest request) {
    initAccessLog(request);

    LOG_CONTEXT.set(this);
  }

  /**
   * 记录日志，此处注意catch异常，出问题不能影响核心流程
   */
  public void log(HttpServletRequest request, HttpServletResponse response) {
    try {
      // 设置请求数据，如果类或方法上增加@SkipRequestData的注解，则不记录请求数据
      if (!isSkipRequestData(request)) {
        accessLog.setRequestData(getRequestData(request));
      }

      accessLog.setUrlPattern(getPathPattern(request));
      accessLog.setHandler(getHandler(request));
      accessLog.setHttpStatusCode(response.getStatus());
      accessLog.setException(getException(request));
      accessLog.setElapsed(System.currentTimeMillis() - accessLog.getCreateTime().getTime());

      log.info(buildLog());
    } catch (Exception e) {
      BIZ_LOGGER.error("record access log error", e);
    } finally {
      // 移除本地线程数据，避免内存泄露
      LOG_CONTEXT.remove();
    }
  }

  private String buildLog() {
    StringBuilder logBuilder = new StringBuilder();
    String exceptionName = StrUtils.DASH;
    if (accessLog.getException() != null) {
      exceptionName = accessLog.getException().getClass().getName();
    }

    logBuilder.append(DateUtils.format(DateUtils.FORMAT_ISO_DATETIME, accessLog.getCreateTime()))
        .append(LOG_SPLITTER)
        .append(StrUtils.emptyToDash(accessLog.getClientIp()))
        .append(LOG_SPLITTER)
        .append(accessLog.getHttpMethod())
        .append(LOG_SPLITTER)
        .append(accessLog.getRequestUri())
        .append(LOG_SPLITTER)
        .append(accessLog.getUrlPattern())
        .append(LOG_SPLITTER)
        .append(accessLog.getUserAgent())
        .append(LOG_SPLITTER)
        .append(accessLog.getHandler())
        .append(LOG_SPLITTER)
        .append(accessLog.getElapsed())
        .append(LOG_SPLITTER)
        .append(accessLog.getHttpStatusCode())
        .append(LOG_SPLITTER)
        .append(StrUtils.removeLineBreak(StrUtils.emptyToDash(accessLog.getRequestData())))
        .append(LOG_SPLITTER)
        .append(exceptionName);

    // 个性化信息，拼接为key=value的格式
    Map<String, String> context = accessLog.getContext();
    if (context != null) {
      for (Map.Entry<String, String> entry : context.entrySet()) {
        logBuilder.append(LOG_SPLITTER).append(entry.getKey()).append('=').append(entry.getValue());
      }
    }

    return logBuilder.toString();
  }

  /**
   * 获取AccessLogger实例
   */
  public static AccessLogger getLogger() {
    return LOG_CONTEXT.get();
  }

  /**
   * 记录上下文字段信息
   */
  public void logContext(String key, String value) {
    if (accessLog.getContext() == null) {
      accessLog.setContext(new HashMap<>());
    }

    accessLog.getContext().put(key, StrUtils.removeLineBreak(value));
  }

  private String getHandler(HttpServletRequest request) {
    HandlerMethod handlerMethod = getHandlerMethod(request);
    // 静态资源或者访问的url 404就会出现HandlerMethod为null的情况
    if (handlerMethod == null) {
      return null;
    }

    String controllerName = handlerMethod.getBeanType().getSimpleName();
    String methodName = handlerMethod.getMethod().getName();

    return String.format(HANDLER_FORMAT, controllerName, methodName);
  }

  private String getPathPattern(HttpServletRequest request) {
    Object urlPattern = request.getAttribute(HandlerMapping.BEST_MATCHING_PATTERN_ATTRIBUTE);

    if (urlPattern != null) {
      return (String) urlPattern;
    }

    return null;
  }

  private Throwable getException(HttpServletRequest request) {
    Object exception = request.getAttribute(DispatcherServlet.EXCEPTION_ATTRIBUTE);
    if (exception instanceof Throwable) {
      return (Throwable) exception;
    }

    return null;
  }

  /**
   * 是否跳过请求数据的记录
   */
  private boolean isSkipRequestData(HttpServletRequest request) {
    HandlerMethod handlerMethod = getHandlerMethod(request);
    if (handlerMethod == null) {
      return true;
    }

    return handlerMethod.getMethod().isAnnotationPresent(SkipLogRequestData.class)
        || handlerMethod.getBeanType().isAnnotationPresent(SkipLogRequestData.class);
  }

  private HandlerMethod getHandlerMethod(HttpServletRequest request) {
    Object handlerObj = request.getAttribute(HandlerMapping.BEST_MATCHING_HANDLER_ATTRIBUTE);
    if (handlerObj instanceof HandlerMethod) {
      return (HandlerMethod) handlerObj;
    }

    return null;
  }

  /**
   * 获取请求数据
   */
  private String getRequestData(HttpServletRequest request) {
    if (!(request instanceof ContentCachingRequestWrapper)) {
      return null;
    }

    ContentCachingRequestWrapper requestWrapper = (ContentCachingRequestWrapper) request;
    if (request.getContentLength() > 0) {
      return new String(requestWrapper.getContentAsByteArray(), StandardCharsets.UTF_8);
    }

    return buildParameterValue(requestWrapper);
  }

  /**
   * 初始化access log共性参数
   */
  private void initAccessLog(HttpServletRequest request) {
    accessLog = new AccessLog();
    accessLog.setCreateTime(new Date());
    accessLog.setHttpMethod(request.getMethod());
    accessLog.setClientIp(request.getRemoteAddr());
    accessLog.setRequestUri(request.getRequestURI());
    accessLog.setUserAgent(request.getHeader(HttpHeaders.USER_AGENT));
  }

  /**
   * 构建a=b&c=d的参数格式
   */
  private String buildParameterValue(ContentCachingRequestWrapper request) {
    Enumeration<String> paramNames = request.getParameterNames();
    if (paramNames == null || !paramNames.hasMoreElements()) {
      return null;
    }

    StringBuilder builder = new StringBuilder();
    while (paramNames.hasMoreElements()) {
      if (builder.length() > 0) {
        builder.append(AND_SYMBOL);
      }

      String paramName = paramNames.nextElement();
      String value = request.getParameter(paramName);
      builder.append(paramName).append(EQUAL_SYMBOL).append(value);
    }

    return builder.toString();
  }
}
