package cn.schoolwow.quickserver.handler;

import cn.schoolwow.quickserver.controller.annotation.BasicAuth;
import cn.schoolwow.quickserver.controller.annotation.CrossOrigin;
import cn.schoolwow.quickserver.domain.Client;
import cn.schoolwow.quickserver.response.HttpStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.*;

/**
 * Annotation handler.
 *
 * <p>Evaluates the {@link BasicAuth} and {@link CrossOrigin} annotations declared on the matched
 * controller method (or, failing that, on its declaring class) before the controller is invoked.
 * When a check fails, the response status is set and the chain jumps straight to
 * {@link HttpResponseHandler}; otherwise it continues with {@link ControllerHandler}.
 */
public class AnnotationHandler implements Handler{
    private static final Logger logger = LoggerFactory.getLogger(AnnotationHandler.class);

    /** Authentication scheme prefix of a Basic {@code Authorization} header value. */
    private static final String BASIC_SCHEME_PREFIX = "Basic ";

    /**
     * Runs the annotation checks for the current request.
     *
     * @param client the current request/response context
     * @return {@link HttpResponseHandler} when a check rejected the request,
     *         {@link ControllerHandler} otherwise
     */
    @Override
    public Handler handle(Client client) throws Exception {
        if(!handleBasicAuth(client)){
            return new HttpResponseHandler();
        }
        if(!handleCrossOrigin(client)){
            return new HttpResponseHandler();
        }
        return new ControllerHandler();
    }

    /**
     * Handles the {@link BasicAuth} annotation.
     *
     * <p>If the annotation is present and the request does not carry matching credentials, the
     * response status is set to {@code UNAUTHORIZED} and a {@code WWW-Authenticate} challenge
     * header is added.
     *
     * @param client the current request/response context
     * @return {@code true} if processing should continue, {@code false} if the request was rejected
     */
    public boolean handleBasicAuth(Client client) {
        BasicAuth basicAuth = client.controllerMeta.method.getAnnotation(BasicAuth.class);
        if (null == basicAuth) {
            basicAuth = client.controllerMeta.method.getDeclaringClass().getAnnotation(BasicAuth.class);
        }
        if (null == basicAuth) {
            return true;
        }
        // The configured password is deliberately NOT logged: secrets must not end up in log files.
        logger.debug("处理BasicAuth注解,地址:{},用户名:{}", client.httpRequestMeta.uri, basicAuth.username());
        String authorization = firstHeaderValue(client, "Authorization");
        if (null != authorization && isExpectedBasicCredential(authorization, basicAuth.username(), basicAuth.password())) {
            logger.debug("BasicAuth验证通过");
            return true;
        }
        client.httpResponse.httpStatus(HttpStatus.UNAUTHORIZED);
        client.httpResponseMeta.headers.put("WWW-Authenticate", Arrays.asList("Basic realm=\"" + basicAuth.realm() + "\""));
        return false;
    }

    /**
     * Handles the {@link CrossOrigin} annotation.
     *
     * <p>Validates the {@code Origin} header and, when present, the preflight headers
     * {@code Access-Control-Request-Method} / {@code Access-Control-Request-Headers} against the
     * annotation, then adds the CORS response headers. Any failed check sets the response status
     * to {@code BAD_REQUEST}.
     *
     * @param client the current request/response context
     * @return {@code true} if processing should continue, {@code false} if the request was rejected
     */
    private boolean handleCrossOrigin(Client client) {
        CrossOrigin crossOrigin = client.controllerMeta.method.getDeclaredAnnotation(CrossOrigin.class);
        if (null == crossOrigin) {
            crossOrigin = client.controllerMeta.method.getDeclaringClass().getDeclaredAnnotation(CrossOrigin.class);
        }
        if (null == crossOrigin) {
            return true;
        }
        logger.debug("处理CrossOrigin注解,方法名:{}",client.controllerMeta.method.getName());

        // Check the origin.
        String origin = firstHeaderValue(client, "Origin");
        if (null == origin) {
            logger.warn("CrossOrigin服务端为跨域请求,但客户端无Origin头部!");
            client.httpResponse.httpStatus(HttpStatus.BAD_REQUEST);
            return false;
        }
        if (!isOriginAllowed(crossOrigin.origins(), origin)) {
            logger.warn("CrossOrigin,origin头部不匹配,跨域失败!客户端Origin头部:{}",origin);
            client.httpResponse.httpStatus(HttpStatus.BAD_REQUEST);
            return false;
        }

        // Check the requested method. Browsers announce it in Access-Control-Request-Method;
        // the Access-Control-Allow-Methods name previously read here is kept only as a fallback
        // so that existing non-browser clients keep working.
        String requestMethod = firstHeaderValue(client, "Access-Control-Request-Method");
        if (null == requestMethod) {
            requestMethod = firstHeaderValue(client, "Access-Control-Allow-Methods");
        }
        if (null != requestMethod && crossOrigin.methods().length > 0
                && !containsIgnoreCase(crossOrigin.methods(), requestMethod.trim())) {
            logger.warn("CrossOrigin请求方法不允许!客户端头部Access-Control-Request-Method:{}",requestMethod);
            client.httpResponse.httpStatus(HttpStatus.BAD_REQUEST);
            return false;
        }

        // Check the requested headers (same header-name fallback as for the method).
        String requestHeaders = firstHeaderValue(client, "Access-Control-Request-Headers");
        if (null == requestHeaders) {
            requestHeaders = firstHeaderValue(client, "Access-Control-Allow-Headers");
        }
        if (null != requestHeaders && crossOrigin.headers().length > 0
                && !areHeadersAllowed(crossOrigin.headers(), requestHeaders)) {
            logger.warn("CrossOrigin,请求头部不允许!客户端头部Access-Control-Request-Headers:{}",requestHeaders);
            client.httpResponse.httpStatus(HttpStatus.BAD_REQUEST);
            return false;
        }

        // The response headers depend on the current request (echoed origin, method, headers),
        // so they are computed per request instead of being replayed from the per-URI cache.
        Map<String, List<String>> responseHeaders = buildCrossOriginHeaders(crossOrigin, origin, requestMethod, requestHeaders);
        // The shared per-URI map is still populated on first use, exactly as before, for any
        // other code that reads it. NOTE(review): confirm whether anything else consumes it.
        if (!client.serverConfigMeta.crossOriginMap.containsKey(client.httpRequestMeta.uri)) {
            client.serverConfigMeta.crossOriginMap.put(client.httpRequestMeta.uri, responseHeaders);
        }
        client.httpResponseMeta.headers.putAll(responseHeaders);
        return true;
    }

    /**
     * Returns the first value of a request header, or {@code null} when the header is absent or
     * has no values. The lookup uses the header name exactly as given.
     */
    private static String firstHeaderValue(Client client, String name) {
        if (!client.httpRequestMeta.headers.containsKey(name)) {
            return null;
        }
        List<String> values = client.httpRequestMeta.headers.get(name);
        if (null == values || values.isEmpty()) {
            return null;
        }
        return values.get(0);
    }

    /**
     * Checks an {@code Authorization} header value against the expected Basic credentials.
     * The scheme name is matched case-insensitively and the encoded credentials are compared
     * in constant time.
     */
    private static boolean isExpectedBasicCredential(String authorization, String username, String password) {
        String value = authorization.trim();
        if (!value.regionMatches(true, 0, BASIC_SCHEME_PREFIX, 0, BASIC_SCHEME_PREFIX.length())) {
            return false;
        }
        String actual = value.substring(BASIC_SCHEME_PREFIX.length()).trim();
        String expected = Base64.getEncoder().encodeToString((username + ":" + password).getBytes(StandardCharsets.UTF_8));
        return MessageDigest.isEqual(actual.getBytes(StandardCharsets.UTF_8), expected.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Returns whether the origin is accepted: an empty configuration accepts every origin,
     * otherwise the origin must match an entry (case-insensitively) or an entry must be {@code *}.
     */
    private static boolean isOriginAllowed(String[] allowedOrigins, String origin) {
        if (allowedOrigins.length == 0) {
            return true;
        }
        for (String allowedOrigin : allowedOrigins) {
            if ("*".equals(allowedOrigin) || origin.equalsIgnoreCase(allowedOrigin)) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether {@code candidates} contains {@code value}, ignoring case. */
    private static boolean containsIgnoreCase(String[] candidates, String value) {
        for (String candidate : candidates) {
            if (value.equalsIgnoreCase(candidate.trim())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns whether every header named in the comma-separated {@code requestedHeaders} list is
     * allowed. A {@code *} entry in {@code allowedHeaders} allows everything; header names are
     * compared case-insensitively and surrounding whitespace is ignored.
     */
    private static boolean areHeadersAllowed(String[] allowedHeaders, String requestedHeaders) {
        if (containsIgnoreCase(allowedHeaders, "*")) {
            return true;
        }
        for (String requestedHeader : requestedHeaders.split(",")) {
            String headerName = requestedHeader.trim();
            if (headerName.isEmpty()) {
                continue;
            }
            if (!containsIgnoreCase(allowedHeaders, headerName)) {
                return false;
            }
        }
        return true;
    }

    /** Builds the CORS response headers for the current request. */
    private static Map<String, List<String>> buildCrossOriginHeaders(CrossOrigin crossOrigin, String origin,
                                                                    String requestMethod, String requestHeaders) {
        Map<String, List<String>> headers = new HashMap<>();
        // A wildcard origin must not be combined with credentials (browsers reject that
        // combination), so the concrete origin is echoed whenever credentials are allowed.
        boolean wildcardOrigin = crossOrigin.origins().length == 0 && !crossOrigin.allowCredentials();
        headers.put("Access-Control-Allow-Origin", Arrays.asList(wildcardOrigin ? "*" : origin));
        headers.put("Access-Control-Max-Age", Arrays.asList(String.valueOf(crossOrigin.maxAge())));
        if (crossOrigin.allowCredentials()) {
            headers.put("Access-Control-Allow-Credentials", Arrays.asList("true"));
        }
        // With no configured methods/headers, whatever the client asked for is echoed back.
        if (crossOrigin.methods().length > 0) {
            headers.put("Access-Control-Allow-Methods", Arrays.asList(String.join(",", crossOrigin.methods())));
        } else if (null != requestMethod) {
            headers.put("Access-Control-Allow-Methods", Arrays.asList(requestMethod));
        }
        if (crossOrigin.headers().length > 0) {
            headers.put("Access-Control-Allow-Headers", Arrays.asList(String.join(",", crossOrigin.headers())));
        } else if (null != requestHeaders) {
            headers.put("Access-Control-Allow-Headers", Arrays.asList(requestHeaders));
        }
        if (crossOrigin.exposedHeaders().length > 0) {
            headers.put("Access-Control-Expose-Headers", Arrays.asList(String.join(",", crossOrigin.exposedHeaders())));
        }
        return headers;
    }
}
