package cn.lingyangwl.framework.security.ratelimit;

import cn.lingyangwl.framework.core.http.HttpRequestServletManager;
import cn.lingyangwl.framework.core.utils.IpUtils;
import cn.lingyangwl.framework.core.utils.servlet.ServletUtils;
import cn.lingyangwl.framework.tool.core.DateUtils;
import cn.lingyangwl.framework.tool.core.StringUtils;
import cn.lingyangwl.framework.tool.core.exception.BizException;
import cn.hutool.crypto.digest.DigestUtil;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.time.Duration;
import java.util.Collections;
import java.util.List;

/**
 * Rate-limiting aspect.
 *
 * <p>Runs before every method annotated with {@link RateLimit}. For each call it builds a Redis key,
 * consults the blacklist, and evaluates {@code limitScript} against the key. When the script result is
 * missing or exceeds {@link RateLimit#count()}, the key is passed to
 * {@link RateLimitManager#addBlacklist} and a {@link BizException} is thrown.
 *
 * @author shenguangyang
 */
@Aspect
@Component
public class RateLimitAspect {
    private static final Logger log = LoggerFactory.getLogger(RateLimitAspect.class);

    @Resource
    private RateLimitManager rateLimitManager;

    @Resource(name = "securityRedisTemplate")
    private RedisTemplate<String, Object> securityRedisTemplate;

    @Resource
    private RedisScript<Long> limitScript;

    @Resource
    private RateLimitProperties rateLimitProperties;

    /**
     * Enforces the rate limit declared by {@code rateLimit} before the annotated method executes.
     *
     * @param point     the intercepted join point
     * @param rateLimit the {@link RateLimit} annotation bound from the intercepted method
     * @throws BizException when the script result is null or greater than {@code rateLimit.count()};
     *                      the message is produced by {@link #parseMsg(String)}
     * @throws Exception    declared for whatever the blacklist manager / Redis call may propagate —
     *                      NOTE(review): confirm which checked exceptions can actually occur
     */
    @Before("@annotation(rateLimit)")
    public void doBefore(JoinPoint point, RateLimit rateLimit) throws Exception {
        int time = rateLimit.time();
        int count = rateLimit.count();

        String combineKey = getCombineKey(rateLimit, point);
        List<String> keys = Collections.singletonList(combineKey);

        // Consult the blacklist for this key first; presumably this throws when the key is
        // already blacklisted — confirm in RateLimitManager.checkBlacklist.
        rateLimitManager.checkBlacklist(combineKey, rateLimit);

        // Script arguments are (count, time). The result is treated as the request count to compare
        // against `count` — TODO confirm against the script source.
        Long number = securityRedisTemplate.execute(limitScript, keys, count, time);
        // Compare as long so large script results are not truncated by intValue().
        if (StringUtils.isNull(number) || number.longValue() > count) {
            // Over the limit (or no result from Redis): blacklist this key. For LimitTypeEnum.IP the
            // key includes the client IP.
            rateLimitManager.addBlacklist(combineKey, rateLimit.limitType());

            throw new BizException(parseMsg(rateLimit.msg()));
        }
        // Log the full Redis key actually evaluated, not just the annotation's key prefix.
        log.debug("限制请求 [{}], 当前请求 [{}], 缓存key [{}]", count, number, combineKey);
    }

    /**
     * Substitutes the blacklist duration into a rate-limit message.
     *
     * <p>Every occurrence of {@link RateLimitCons#REMAINING_TIME_ARG} in {@code msg} is replaced with
     * the text that {@link DateUtils#distance} produces for the configured blacklist limit time.
     * NOTE(review): this is the full configured blacklist duration, not the time actually left on an
     * existing blacklist entry.
     *
     * @param msg the message template; may be null or empty
     * @return {@code msg} unchanged when it is empty or contains no placeholder; otherwise the message
     *         with every placeholder replaced
     */
    public String parseMsg(String msg) {
        // Nothing to substitute: skip reading the blacklist configuration entirely.
        if (StringUtils.isEmpty(msg) || !msg.contains(RateLimitCons.REMAINING_TIME_ARG)) {
            return msg;
        }
        Duration limitTime = rateLimitProperties.getBlacklist().getLimitTime();
        long now = System.currentTimeMillis();
        String remainingTime = DateUtils.distance(now, now + limitTime.toMillis(), true);
        return msg.replace(RateLimitCons.REMAINING_TIME_ARG, remainingTime);
    }

    /**
     * Builds the Redis key under which a call is rate limited.
     *
     * <p>Format: {@code <rateLimit.key()>:<md5Hex(declaringClassName + "-" + methodName)>}, with
     * {@code :<clientIp>} appended when {@link RateLimit#limitType()} is {@link LimitTypeEnum#IP}.
     * NOTE(review): parameter types are not part of the key, so overloaded methods of the same class
     * share one limit.
     *
     * @param rateLimit the annotation supplying the key prefix and limit type
     * @param point     the intercepted join point (its signature is cast to {@link MethodSignature})
     * @return the combined Redis key
     * @throws IllegalStateException when the limit type is IP but no current HTTP request is available
     */
    public String getCombineKey(RateLimit rateLimit, JoinPoint point) {
        StringBuilder keyBuilder = new StringBuilder(rateLimit.key());
        Method method = ((MethodSignature) point.getSignature()).getMethod();
        String methodId = method.getDeclaringClass().getName() + "-" + method.getName();
        keyBuilder.append(':').append(DigestUtil.md5Hex(methodId));

        if (rateLimit.limitType() == LimitTypeEnum.IP) {
            HttpServletRequest request = ServletUtils.getRequest()
                    .orElseThrow(() -> new IllegalStateException("request is null"));
            keyBuilder.append(':').append(IpUtils.getRequestIp(new HttpRequestServletManager(request)));
        }
        return keyBuilder.toString();
    }
}
