package org.apache.dubbo.rpc.protocol.tri.websocket.jakarta;

import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.FilterConfig;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.websocket.server.ServerContainer;
import jakarta.websocket.server.ServerEndpointConfig;
import java.io.IOException;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.apache.dubbo.common.constants.LoggerCodeConstants;
import org.apache.dubbo.common.logger.ErrorTypeAwareLogger;
import org.apache.dubbo.common.logger.LoggerFactory;
import org.apache.dubbo.common.utils.ConcurrentHashSet;
import org.apache.dubbo.remoting.http12.HttpMethods;
import org.apache.dubbo.rpc.protocol.tri.websocket.WebSocketConstants;

/* loaded from: input_file:org/apache/dubbo/rpc/protocol/tri/websocket/jakarta/TripleWebSocketFilter.class */
public class TripleWebSocketFilter implements Filter {
    private static final ErrorTypeAwareLogger LOG = LoggerFactory.getErrorTypeAwareLogger((Class<?>) TripleWebSocketFilter.class);
    private transient ServerContainer sc;
    private final Set<String> existed = new ConcurrentHashSet();

    public void init(FilterConfig filterConfig) {
        this.sc = (ServerContainer) filterConfig.getServletContext().getAttribute(ServerContainer.class.getName());
    }

    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        if (!isWebSocketUpgradeRequest(servletRequest, servletResponse)) {
            filterChain.doFilter(servletRequest, servletResponse);
            return;
        }
        HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;
        HttpServletResponse httpServletResponse = (HttpServletResponse) servletResponse;
        String pathInfo = httpServletRequest.getPathInfo();
        String servletPath = pathInfo == null ? httpServletRequest.getServletPath() : httpServletRequest.getServletPath() + pathInfo;
        final HashMap hashMap = new HashMap(httpServletRequest.getParameterMap());
        hashMap.put(WebSocketConstants.TRIPLE_WEBSOCKET_REMOTE_ADDRESS, new String[]{httpServletRequest.getRemoteHost(), String.valueOf(httpServletRequest.getRemotePort())});
        HttpServletRequestWrapper httpServletRequestWrapper = new HttpServletRequestWrapper(httpServletRequest) { // from class: org.apache.dubbo.rpc.protocol.tri.websocket.jakarta.TripleWebSocketFilter.1
            public Map<String, String[]> getParameterMap() {
                return hashMap;
            }
        };
        if (this.existed.contains(servletPath)) {
            filterChain.doFilter(httpServletRequestWrapper, httpServletResponse);
            return;
        }
        try {
            this.sc.addEndpoint(ServerEndpointConfig.Builder.create(TripleEndpoint.class, servletPath).build());
            this.existed.add(servletPath);
            filterChain.doFilter(httpServletRequestWrapper, httpServletResponse);
        } catch (Exception e) {
            LOG.error(LoggerCodeConstants.PROTOCOL_FAILED_REQUEST, "", "", "Failed to add endpoint", e);
            httpServletResponse.sendError(400);
        }
    }

    public void destroy() {
    }

    public boolean isWebSocketUpgradeRequest(ServletRequest servletRequest, ServletResponse servletResponse) {
        return (servletRequest instanceof HttpServletRequest) && (servletResponse instanceof HttpServletResponse) && headerContainsToken((HttpServletRequest) servletRequest, "Upgrade", WebSocketConstants.TRIPLE_WEBSOCKET_UPGRADE_HEADER_VALUE) && HttpMethods.GET.name().equals(((HttpServletRequest) servletRequest).getMethod());
    }

    private boolean headerContainsToken(HttpServletRequest httpServletRequest, String str, String str2) {
        Enumeration headers = httpServletRequest.getHeaders(str);
        while (headers.hasMoreElements()) {
            for (String str3 : ((String) headers.nextElement()).split(",")) {
                if (str2.equalsIgnoreCase(str3.trim())) {
                    return true;
                }
            }
        }
        return false;
    }
}
