/*
 * Copyright (c) SinoDawn 2021.
 */

package net.sinodawn.framework.security.firewall;

import net.sinodawn.framework.security.exception.FirewallDeniedException;
import net.sinodawn.framework.utils.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;

@WebFilter
@Component
public class SecurityFirewallFilter implements Filter {
   @Value("#{'${sino.security.access-host-list}'.split(',')}")
   private List<String> accessHostList;

   public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
      if (!this.isAccessHostListEmpty()) {
         this.hostHeaderFirewall(request);
         this.csrfFirewall(request);
      }

      chain.doFilter(request, response);
   }

   private void hostHeaderFirewall(ServletRequest request) {
      String host = ((HttpServletRequest)request).getHeader("host");
      if (!StringUtils.isBlank(host) && !this.isAccessableHost(host)) {
         throw new FirewallDeniedException();
      }
   }

   private void csrfFirewall(ServletRequest request) {
      String referer = ((HttpServletRequest)request).getHeader("referer");
      if (!StringUtils.isBlank(referer) && !this.isAcccessableReferer(referer)) {
         throw new FirewallDeniedException();
      }
   }

   private boolean isAccessHostListEmpty() {
      if (this.accessHostList != null && !this.accessHostList.isEmpty()) {
         return this.accessHostList.size() == 1 && StringUtils.isBlank((String)this.accessHostList.get(0));
      } else {
         return true;
      }
   }

   private boolean isAccessableHost(String host) {
      Iterator var2 = this.accessHostList.iterator();

      String accessHost;
      do {
         if (!var2.hasNext()) {
            return false;
         }

         accessHost = (String)var2.next();
         if (accessHost.equals(host)) {
            return true;
         }

         if (StringUtils.endsWith(accessHost, ":80") && StringUtils.removeEnd(accessHost, ":80").equals(host)) {
            return true;
         }
      } while(!StringUtils.endsWith(host, ":80") || !accessHost.equals(StringUtils.removeEnd(host, ":80")));

      return true;
   }

   private boolean isAcccessableReferer(String referer) {
      Iterator var2 = this.accessHostList.iterator();

      while(var2.hasNext()) {
         String accessHost = (String)var2.next();
         if (!StringUtils.startsWithIgnoreCase(referer, "http://" + accessHost + "/") && !StringUtils.startsWithIgnoreCase(referer, "https://" + accessHost + "/")) {
            if (StringUtils.endsWith(accessHost, ":80")) {
               String noPortHost = StringUtils.removeEnd(accessHost, ":80");
               if (StringUtils.startsWithIgnoreCase(referer, "http://" + noPortHost + "/") || StringUtils.startsWithIgnoreCase(referer, "https://" + noPortHost + "/")) {
                  return true;
               }
            }

            if (StringUtils.contains(accessHost, ":") || !StringUtils.startsWithIgnoreCase(referer, "http://" + accessHost + ":80/") && !StringUtils.startsWithIgnoreCase(referer, "https://" + accessHost + ":80/")) {
               continue;
            }

            return true;
         }

         return true;
      }

      return false;
   }
}
