package com.github.yingzhuo.carnival.shield.core;

import com.github.yingzhuo.carnival.shield.algorithm.Algorithm;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.Objects;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;

/**
 * Filter that decrypts request bodies and/or encrypts response bodies for handler methods
 * selected by {@code shouldDecrypt} / {@code shouldEncrypt} (inherited from
 * {@code AbstractShieldFilter}).
 *
 * <p>Requests that do not resolve to a {@link HandlerMethod}, or whose handler requires neither
 * decryption nor encryption, are passed down the chain untouched.
 */
public class ShieldFilter extends AbstractShieldFilter {

    /** Used to decrypt request bodies and to encrypt response bodies. */
    private final Algorithm algorithm;

    /** Charset used to convert between body text and body bytes. */
    private final Charset charset;

    /**
     * @param requestMappingHandlerMapping mapping used (by the superclass) to resolve handler methods
     * @param algorithm                    encryption/decryption algorithm; must not be {@code null}
     * @param charset                      charset for body text/bytes conversion; must not be {@code null}
     * @throws NullPointerException if {@code algorithm} or {@code charset} is {@code null}
     */
    public ShieldFilter(RequestMappingHandlerMapping requestMappingHandlerMapping, Algorithm algorithm, Charset charset) {
        super(requestMappingHandlerMapping);
        // Fail at construction time rather than on the first protected request.
        this.algorithm = Objects.requireNonNull(algorithm, "algorithm must not be null");
        this.charset = Objects.requireNonNull(charset, "charset must not be null");
    }

    /**
     * Wraps the request and/or response as required by the resolved handler method, invokes the
     * rest of the chain, then writes the encrypted response body if encryption is enabled.
     */
    protected void doFilterInternal(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, FilterChain filterChain) throws ServletException, IOException {
        HandlerMethod handlerMethod = super.getHandlerMethod(httpServletRequest);
        if (handlerMethod == null) {
            filterChain.doFilter(httpServletRequest, httpServletResponse);
            return;
        }

        boolean shouldEncrypt = shouldEncrypt(handlerMethod);
        boolean shouldDecrypt = shouldDecrypt(handlerMethod);
        if (!shouldEncrypt && !shouldDecrypt) {
            filterChain.doFilter(httpServletRequest, httpServletResponse);
            return;
        }

        // Only build the wrappers that are actually used for this request.
        DecryptionRequest decryptionRequest = null;
        if (shouldDecrypt) {
            decryptionRequest = new DecryptionRequest(httpServletRequest);
            processDecryption(decryptionRequest, httpServletRequest);
        }
        EncryptionResponse encryptionResponse = shouldEncrypt ? new EncryptionResponse(httpServletResponse) : null;

        if (shouldDecrypt && shouldEncrypt) {
            filterChain.doFilter(decryptionRequest, encryptionResponse);
        } else if (shouldEncrypt) {
            filterChain.doFilter(httpServletRequest, encryptionResponse);
        } else {
            filterChain.doFilter(decryptionRequest, httpServletResponse);
        }

        if (shouldEncrypt) {
            // The chain wrote into the wrapper; encrypt that text and send it on the real response.
            writeEncryptedContent(encryptionResponse.getResponseBodyAsString(), httpServletResponse);
        }
    }

    /**
     * Encrypts {@code str} and writes it to {@code servletResponse}, setting the character
     * encoding and content length to match the bytes actually written.
     *
     * @throws RuntimeException if the algorithm fails to encrypt (the cause is preserved)
     * @throws IOException      if writing to the response fails
     */
    private void writeEncryptedContent(String str, ServletResponse servletResponse) throws IOException {
        final String encrypted;
        try {
            encrypted = this.algorithm.encrypt(str);
        } catch (Exception e) {
            throw new RuntimeException("Failed to encrypt response body", e);
        }

        // Content-Length is a byte count, so it must come from the encoded bytes, not String#length().
        byte[] payload = encrypted.getBytes(this.charset);
        servletResponse.setCharacterEncoding(this.charset.name());
        servletResponse.setContentLength(payload.length);

        try (ServletOutputStream out = servletResponse.getOutputStream()) {
            out.write(payload);
            out.flush();
        }
    }

    /**
     * Replaces the body of {@code decryptionRequest} with its decrypted form. GET requests and
     * requests with an empty body are left unchanged.
     *
     * @throws RuntimeException if the algorithm fails to decrypt (the cause is preserved)
     */
    private void processDecryption(DecryptionRequest decryptionRequest, HttpServletRequest httpServletRequest) {
        String requestBodyAsString = decryptionRequest.getRequestBodyAsString();
        if (RequestMethod.GET.name().equalsIgnoreCase(httpServletRequest.getMethod())) {
            return;
        }
        // Nothing to decrypt; most ciphers reject empty input.
        if (!StringUtils.hasLength(requestBodyAsString)) {
            return;
        }
        try {
            decryptionRequest.setRequestBody(this.algorithm.decrypt(requestBodyAsString).getBytes(this.charset));
        } catch (Exception e) {
            throw new RuntimeException("Failed to decrypt request body", e);
        }
    }
}
