package cn.springlet.redis.aspect;


import cn.springlet.core.auto_config.SPELParserUtils;
import cn.springlet.core.exception.web_return.RateLimiterException;
import cn.springlet.core.exception.web_return.ReturnMsgException;
import cn.springlet.core.util.StrUtil;
import cn.springlet.redis.annotation.RedisRateLimiter;
import cn.springlet.redis.annotation.RedisRateLimiters;
import cn.springlet.redis.bean.RedisRateLimiterBean;
import cn.springlet.redis.constant.RedisCacheKey;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.core.annotation.Order;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Rate-limiting aspect.
 *
 * <p>Intercepts methods annotated with {@link RedisRateLimiter} or {@link RedisRateLimiters}.
 * Before the method runs, it evaluates all of the method's rules in a single Redis Lua script
 * call. If a rule reports a hit, a {@link RateLimiterException} carrying that rule's error
 * message is thrown. If the Redis call itself fails, the request is rejected with a
 * {@link ReturnMsgException} (fail closed).
 *
 * @author springlet
 * @time 2020/11/5
 */
@Aspect
@Component
@Order(-1)
@Slf4j
public class RedisRateLimiterAspect {

    @Autowired
    private RedisTemplate<Object, Object> redisTemplate;

    /**
     * Lua script that checks and increments all rule counters in one round trip.
     */
    @Autowired
    @Qualifier("redisRateLimiterScript")
    private RedisScript<Long> limitScript;

    /**
     * Matches methods carrying either a single {@link RedisRateLimiter} or the
     * {@link RedisRateLimiters} container.
     */
    @Pointcut("@annotation(cn.springlet.redis.annotation.RedisRateLimiter) || @annotation(cn.springlet.redis.annotation.RedisRateLimiters)")
    public void pointcut() {
    }

    /**
     * Collects every rate-limit rule declared on the intercepted method and enforces them.
     *
     * @param point the intercepted join point
     * @throws RateLimiterException if one of the rules is exceeded
     * @throws ReturnMsgException   if the Redis check itself fails
     */
    @Before(value = "pointcut()")
    public void doBefore(JoinPoint point) {
        MethodSignature sign = (MethodSignature) point.getSignature();
        Method method = sign.getMethod();
        RedisRateLimiter redisRateLimiter = AnnotationUtils.findAnnotation(method, RedisRateLimiter.class);
        RedisRateLimiters redisRateLimiters = AnnotationUtils.findAnnotation(method, RedisRateLimiters.class);

        List<RedisRateLimiterBean> beanList = new ArrayList<>();
        if (redisRateLimiter != null) {
            beanList.add(initBean(redisRateLimiter));
        }
        if (redisRateLimiters != null) {
            for (RedisRateLimiter rateLimiter : redisRateLimiters.value()) {
                beanList.add(initBean(rateLimiter));
            }
        }
        // The pointcut matched, but no rule could be resolved from the reflected method
        // (e.g. the annotation sits somewhere findAnnotation does not search). With nothing
        // to enforce, skip the Redis round trip entirely.
        if (beanList.isEmpty()) {
            return;
        }
        handleRateLimiter(beanList, point);
    }

    /**
     * Runs all rules through the limit script and throws if one of them is hit.
     *
     * <p>The script receives one KEY per rule and two ARGV entries per rule
     * ({@code [count_0, windowMillis_0, count_1, windowMillis_1, ...]}), in the same order.
     * The way the result is used here suggests the following contract. A {@code null} result
     * means "allowed". Otherwise the result is the index (into that order) of the rule that
     * was exceeded. NOTE(review): confirm this contract against the {@code redisRateLimiterScript} bean.
     *
     * @param beanList the resolved rules; not modified
     * @param point    the intercepted join point, used to build the Redis keys
     */
    private void handleRateLimiter(List<RedisRateLimiterBean> beanList, JoinPoint point) {
        // Longest window first, so the broadest rule is evaluated first.
        List<RedisRateLimiterBean> rules = beanList.stream()
                .sorted(Comparator.comparing(RedisRateLimiterBean::getMillisecondsTime).reversed())
                .collect(Collectors.toList());
        int size = rules.size();

        List<Object> keys = new ArrayList<>(size);
        Object[] args = new Object[size * 2];

        for (int i = 0; i < size; i++) {
            RedisRateLimiterBean rule = rules.get(i);
            String key = getKey(point, rule.getKey());
            if (keys.contains(key)) {
                // Several rules resolved to the same Redis key (typically none of them sets `key`).
                // Without this, they would all count on one shared counter regardless of their
                // different windows. The first occurrence keeps the original key, so single-rule
                // and distinct-key setups are unaffected.
                key = StrUtil.format("{}#{}#{}", key, rule.getMillisecondsTime(), i);
            }
            keys.add(key);
            int argIndex = i * 2;
            args[argIndex] = rule.getCount();
            args[argIndex + 1] = rule.getMillisecondsTime();
        }

        try {
            // One script call covers all rules. Per the original design, the script checks every
            // rule before incrementing any counter, and if any rule is hit, nothing is incremented.
            Long index = redisTemplate.execute(limitScript, keys, args);
            if (index == null) {
                return;
            }
            int hit = index.intValue();
            RedisRateLimiterBean rule = rules.get(hit);
            // Log the key that was actually sent to Redis instead of re-evaluating the SpEL expression.
            log.info("触发限流:限流key:{},限流配置:次数{},时间:{},时间单位:{}", keys.get(hit), rule.getCount(), rule.getTime(), rule.getUnit());
            throw new RateLimiterException(rule.getErrMsg());
        } catch (RateLimiterException e) {
            throw e;
        } catch (Exception e) {
            // Redis or script failure (or an unexpected script result): reject the request.
            log.error("服务器限流异常", e);
            throw new ReturnMsgException("服务器限流异常，请稍候再试");
        }
    }

    /**
     * Builds the Redis counter key for a rule.
     *
     * <p>The format is {@code RATE_LIMITER_KEY + className#methodName}. If the annotation's key
     * is a non-blank SpEL expression that evaluates to a non-blank value, {@code #value} is
     * appended.
     *
     * @param joinPoint     the intercepted join point
     * @param annotationKey the SpEL expression from the annotation; may be blank
     * @return the full Redis key
     */
    private String getKey(JoinPoint joinPoint, String annotationKey) {
        String className = joinPoint.getTarget().getClass().getName();
        String methodName = joinPoint.getSignature().getName();
        String key;
        if (StrUtil.isNotBlank(annotationKey)) {
            String parseValue = SPELParserUtils.parse(((MethodSignature) joinPoint.getSignature()).getMethod(), joinPoint.getArgs(), annotationKey, String.class);
            if (StrUtil.isBlank(parseValue)) {
                key = StrUtil.format("{}#{}", className, methodName);
            } else {
                key = StrUtil.format("{}#{}#{}", className, methodName, parseValue);
            }
        } else {
            key = StrUtil.format("{}#{}", className, methodName);
        }

        return RedisCacheKey.RATE_LIMITER_KEY + key;
    }

    /**
     * Copies an annotation's attributes into a {@link RedisRateLimiterBean} and precomputes
     * the window length in milliseconds.
     *
     * @param redisRateLimiter the annotation to convert
     * @return the populated bean
     */
    private RedisRateLimiterBean initBean(RedisRateLimiter redisRateLimiter) {
        RedisRateLimiterBean redisRateLimiterBean = new RedisRateLimiterBean();
        redisRateLimiterBean.setTime(redisRateLimiter.time());
        redisRateLimiterBean.setUnit(redisRateLimiter.unit());
        redisRateLimiterBean.setKey(redisRateLimiter.key());
        redisRateLimiterBean.setCount(redisRateLimiter.count());
        redisRateLimiterBean.setErrMsg(redisRateLimiter.errMsg());
        long windowMillis = redisRateLimiter.unit().toMillis(redisRateLimiter.time());
        // Narrowing to int used to wrap silently for windows >= ~24.8 days (e.g. 30 DAYS became
        // negative). Clamp to the largest representable window instead.
        redisRateLimiterBean.setMillisecondsTime((int) Math.min(windowMillis, Integer.MAX_VALUE));
        return redisRateLimiterBean;
    }
}
