package cn.springlet.redis.aspect;


import cn.springlet.core.exception.web_return.DisableReplayException;
import cn.springlet.core.exception.web_return.ParameterVerificationException;
import cn.springlet.core.exception.web_return.ReturnMsgException;
import cn.springlet.core.util.ServletUtil;
import cn.springlet.core.util.StrUtil;
import cn.springlet.redis.annotation.DisableReplay;
import cn.springlet.redis.constant.RedisCacheKey;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.core.annotation.Order;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;

import javax.servlet.http.HttpServletRequest;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * Aspect that rejects replayed requests on methods annotated with {@link DisableReplay}.
 *
 * <p>Each request must carry a millisecond timestamp and a nonce. Both are read from a request
 * header, falling back to a request parameter of the same name. The timestamp must not lie in the
 * future and must be younger than the annotation's time window. The nonce is then registered in
 * Redis through the {@code redisNoNewScript} script, and any script result other than {@code 0}
 * is treated as a replay.
 *
 * @author springlet
 * @time 2020/11/5
 */
@Aspect
@Component
@Order(-3)
@Slf4j
public class DisableReplayAspect {

    @Autowired
    private RedisTemplate<Object, Object> redisTemplate;

    /**
     * Script used to register a nonce key. NOTE(review): a result of {@code 0} is treated as
     * "nonce not seen before"; confirm against the script definition.
     */
    @Autowired
    @Qualifier("redisNoNewScript")
    private RedisScript<Long> noNewScript;

    /**
     * Pointcut matching every method annotated with {@link DisableReplay}.
     */
    @Pointcut("@annotation(cn.springlet.redis.annotation.DisableReplay)")
    public void pointcut() {
    }

    /**
     * Validates the timestamp/nonce pair of the current request before the annotated method runs.
     *
     * @param joinPoint     the intercepted join point (not used, part of the advice signature)
     * @param disableReplay the annotation on the intercepted method
     * @throws ReturnMsgException             if there is no request context, the timestamp is in
     *                                        the future or expired, or the Redis call fails
     * @throws ParameterVerificationException if the timestamp or nonce is missing, or the
     *                                        timestamp is not a number
     * @throws DisableReplayException         if the nonce was already used within the window
     */
    @Before("pointcut() && @annotation(disableReplay)")
    public void handler(JoinPoint joinPoint, DisableReplay disableReplay) {
        HttpServletRequest request = ServletUtil.getRequest();
        if (request == null) {
            log.error("DisableReplayAspect:请求上下文不存在");
            throw new ReturnMsgException("请求上下文不存在");
        }

        String timeStamp = resolveValue(request, disableReplay.timeStampFiledName());
        String nonce = resolveValue(request, disableReplay.nonceFiledName());
        if (StrUtil.isBlank(timeStamp)) {
            throw new ParameterVerificationException("时间戳不能为空");
        }
        if (StrUtil.isBlank(nonce)) {
            throw new ParameterVerificationException("随机字符串不能为空");
        }

        long millisTime = disableReplay.unit().toMillis(disableReplay.time());
        checkTimeStamp(timeStamp, millisTime);
        checkNonce(nonce, millisTime);
    }

    /**
     * Reads a value from the request header, falling back to the request parameter of the same
     * name when the header is blank.
     */
    private static String resolveValue(HttpServletRequest request, String name) {
        String value = request.getHeader(name);
        return StrUtil.isBlank(value) ? request.getParameter(name) : value;
    }

    /**
     * Ensures the timestamp is numeric, not in the future, and younger than {@code millisTime}.
     */
    private static void checkTimeStamp(String timeStamp, long millisTime) {
        long requestTime;
        try {
            requestTime = Long.parseLong(timeStamp);
        } catch (NumberFormatException e) {
            throw new ParameterVerificationException("时间戳格式错误");
        }

        long now = System.currentTimeMillis();
        if (requestTime > now) {
            throw new ReturnMsgException("时间戳异常");
        }
        // A negative timestamp is ancient. Rejecting it up front also keeps (now - requestTime)
        // from overflowing into a negative age that would slip past the expiry check.
        if (requestTime < 0 || now - requestTime >= millisTime) {
            throw new ReturnMsgException("请求已过期");
        }
    }

    /**
     * Registers the nonce in Redis for {@code millisTime} ms and rejects it if it was already used.
     */
    private void checkNonce(String nonce, long millisTime) {
        List<Object> keys = new ArrayList<>();
        keys.add(RedisCacheKey.DISABLE_REPLAY_KEY + nonce);
        // The TTL is passed to the script as an int. Clamp it so very long windows do not wrap
        // around into a negative or bogus value.
        int ttl = (int) Math.min(millisTime, Integer.MAX_VALUE);

        Long result;
        try {
            result = redisTemplate.execute(noNewScript, keys, 1, ttl);
        } catch (Exception e) {
            log.error("请求重放限制异常", e);
            throw new ReturnMsgException("服务器异常，请稍候再试");
        }
        // Checked outside the try block so a detected replay is not caught above and
        // misreported as a server error.
        if (result == null || result != 0) {
            throw new DisableReplayException("请求重放");
        }
    }
}
