package net.takela.auth.access.filter;

import jakarta.servlet.*;
import jakarta.servlet.http.HttpServletRequest;
import net.takela.auth.access.conf.AccessKeySignProperties;
import net.takela.auth.access.model.AkSk;
import net.takela.auth.access.service.AkSkService;
import net.takela.common.spring.exception.AuthException;
import net.takela.common.spring.filter.HttpRequestCachedServlet;
import org.apache.commons.codec.digest.HmacUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.AntPathMatcher;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Locale;

/**
 *
 */
public class AccessKeySignFilter implements Filter {
    private final AccessKeySignProperties signProperties;
    private final AkSkService akSkService;
    private final AntPathMatcher antPathMatcher = new AntPathMatcher();

    /**
     *
     * @param signProperties
     * @param akSkService
     */
    public AccessKeySignFilter(AccessKeySignProperties signProperties, AkSkService akSkService) {
        this.signProperties = signProperties;
        this.akSkService = akSkService;
    }

    /**
     *
     * @param timestamp
     * @return
     */
    private boolean isTimestampValid(long timestamp) {
        long timeDiff = Math.abs(System.currentTimeMillis() - timestamp);
        return timeDiff <= signProperties.getExpireTime();
    }

    /**
     *
     * @param servletRequest
     * @param servletResponse
     * @param filterChain
     * @throws IOException
     * @throws ServletException
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;
        //没有启用
        if (!signProperties.getEnabled()){
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }
        HttpRequestCachedServlet cachedServletRequest = new HttpRequestCachedServlet((HttpServletRequest)servletRequest);
        //如果允许匿名
        if ( !signProperties.getAnonymousUrls().isEmpty() ){
            boolean isAnonymous = signProperties.getAnonymousUrls().stream().anyMatch( url -> antPathMatcher.match(url, httpServletRequest.getRequestURI()));
            if (isAnonymous){
                filterChain.doFilter(cachedServletRequest, servletResponse);
                return;
            }
        }

//        boolean f = checkURLMatchers.stream().anyMatch((m)->{ return m.matches(request)?true:false;});
        //check-urls名单不为空
        if ( !signProperties.getCheckUrls().isEmpty() ){
            boolean f = signProperties.getCheckUrls().stream().anyMatch( url -> antPathMatcher.match(url, httpServletRequest.getRequestURI()));
            //3、如果check-urls清单不为空，但是没有匹配上，说明不需要校验
            if (!f){
                filterChain.doFilter(servletRequest, servletResponse);
                return;
            }
        }
        //1、如果check-urls清单为空，说明全部需要校验
        //2、如果check-urls清单不为空，并且匹配上了，则需要校验
        // 获取请求头中的认证信息
        String ak = httpServletRequest.getHeader(signProperties.getAccessKeyHeaderName());
        String timestamp = httpServletRequest.getHeader(signProperties.getTimestampHeaderName());
        String sign = httpServletRequest.getHeader(signProperties.getSignHeaderName());

        // 1. 检查必要请求头是否存在
        if (StringUtils.isEmpty(ak) || StringUtils.isEmpty(timestamp) || StringUtils.isEmpty(sign)) {
            throw new AuthException("Miss some header");
        }

        // 2. 检查时间戳有效性（例如允许5分钟内的请求）
        if (!isTimestampValid(Long.parseLong(timestamp))) {
            throw new AuthException("Time error");
        }

        // 3. 根据AK查询SK
        AkSk akSk = akSkService.getByAk(ak);
        if (akSk == null) {
            throw new AuthException("Invalid access key");
        }
        // 4. 服务端重新生成签名
        String serverSign;
        try {
            String dataToSign = buildSign(httpServletRequest, ak, timestamp);
            serverSign = HmacUtils.hmacSha256Hex(akSk.getSk(), dataToSign);
        } catch (Exception e) {
            throw  new AuthException("Invalid request");
        }
        // 5. 比对签名
        if (!serverSign.equalsIgnoreCase(sign)) {
            throw  new AuthException("Invalid request");
        }

        // 验证通过，放行请求
        filterChain.doFilter(servletRequest, servletResponse);
    }

    /**
     *
     * @param request
     * @param ak
     * @param timestamp
     * @return
     * @throws IOException
     */
    private String buildSign(HttpServletRequest request, String ak, String timestamp) throws IOException {
        StringBuilder sb = new StringBuilder();
        sb.append(timestamp)
                .append(ak)
                .append(request.getMethod())
                .append(request.getRequestURI());
        String queryString = request.getQueryString();
        sb.append(queryString);
        if (request.getContentType() != null && request.getContentType().toLowerCase().contains("application/json")){
            String body = getBody( request.getInputStream());
            sb.append(body);
        }
        return sb.toString();
    }

    /**
     *
     * @param inputStream
     * @return
     */
    private String getBody(InputStream inputStream){
        StringBuilder sb = new StringBuilder();
        BufferedReader reader = null;
        try
        {
            reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8));
            char[] buffer = new char[1024];
            int bytesRead = -1;
            while ( (bytesRead = reader.read(buffer) ) > 0 )
            {
                sb.append(buffer, 0, bytesRead);
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
        finally
        {
            if (reader != null)
            {
                try
                {
                    reader.close();
                }
                catch (IOException e)
                {
                    e.printStackTrace();
                }
            }
        }
        return sb.toString();
    }
}
