package security.service;

import framework.captcha.Captcha;
import framework.config.SecurityConfig;
import framework.crypto.GeneralCrypto;
import framework.security.*;
import framework.security.password.PasswordService;
import framework.security.token.AuthTokenBuilder;
import framework.security.token.AuthTokenInfo;
import framework.utils.RequestUtil;
import framework.utils.ServletUtil;
import lombok.Getter;
import org.springframework.util.StringUtils;
import org.springframework.web.context.request.ServletRequestAttributes;
import security.utils.SecurityAuthUtil;
import security.utils.SecurityCookieUtil;
import security.vo.LoginInfo;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Date;

/**
 * Default {@link LoginService} implementation.
 * <p>
 * Runs the login pipeline: optional captcha, lock-out check, account lookup, account status,
 * password verification, token creation, token cookie, and security-context population. It also
 * runs the logout pipeline. Subclasses can customise behaviour by overriding the protected
 * {@code on*} hooks and {@code check*} methods.
 */
public class LoginServiceDefault implements LoginService {

    /** Milliseconds per second as a {@code long}, so token lifetimes are never computed in {@code int}. */
    private static final long MILLIS_PER_SECOND = 1000L;

    @Getter
    private final SecurityConfig securityConfig;
    @Getter
    private final AccountLoader accountLoader;
    @Getter
    private final CaptchaFlagAdmin captchaFlagAdmin;
    @Getter
    private final GeneralCrypto generalCrypto;
    @Getter
    private final PasswordService passwordService;
    @Getter
    private final AuthTokenBuilder authTokenBuilder;
    @Getter
    private final Captcha captcha;
    @Getter
    private final AuthService authService;

    public LoginServiceDefault(
            SecurityConfig securityConfig,
            AccountLoader accountLoader,
            CaptchaFlagAdmin captchaFlagAdmin,
            GeneralCrypto generalCrypto,
            PasswordService passwordService,
            AuthTokenBuilder authTokenBuilder,
            AuthService authService,
            Captcha captcha
    ) {
        this.securityConfig = securityConfig;
        this.accountLoader = accountLoader;
        this.captchaFlagAdmin = captchaFlagAdmin;
        this.generalCrypto = generalCrypto;
        this.passwordService = passwordService;
        this.authTokenBuilder = authTokenBuilder;
        this.authService = authService;
        this.captcha = captcha;
    }

    /**
     * Hook invoked before any login check runs. The default implementation does nothing.
     *
     * @param loginEntity the submitted login request
     * @param loginInfo   the result object that will be returned on success
     * @throws AuthException to abort the login; {@link #onLoginFailed} is then invoked
     */
    protected void onLoginBefore(LoginEntity loginEntity, LoginInfo loginInfo) throws AuthException {
    }

    /**
     * Records a failed login.
     * <p>
     * The failure is reported to {@link AccountLoader#loginUnsuccessful} unless it was a captcha
     * error or an unavailable-account error. Whenever a username is present, it is then flagged so
     * that its next login attempt must present a captcha.
     *
     * @param loginEntity the submitted login request
     * @param exception   the authentication failure
     */
    protected void onLoginFailed(LoginEntity loginEntity, AuthException exception) {
        // Captcha and account-status failures are deliberately not recorded as failed logins.
        boolean unrecorded = exception.getAuthCode() == AuthCode.CAPTCHA_ERROR
                || exception.getAuthCode() == AuthCode.STATUS_UNAVAILABLE;
        if (!unrecorded) {
            getAccountLoader().loginUnsuccessful(loginEntity.getUsername(), exception.getMessage());
        }
        // Remember the failure so the next login for this username must present a captcha.
        if (StringUtils.hasText(loginEntity.getUsername())) {
            getCaptchaFlagAdmin().setFlag(loginEntity.getUsername());
        }
    }

    /**
     * Records a successful login.
     * <p>
     * Reports it to {@link AccountLoader#loginSuccessful} and clears any "captcha required" flag
     * for the username.
     *
     * @param loginEntity the submitted login request
     * @param loginInfo   the populated login result
     */
    protected void onLoginSuccess(LoginEntity loginEntity, LoginInfo loginInfo) {
        getAccountLoader().loginSuccessful(loginEntity.getUsername(), loginInfo.getId());
        getCaptchaFlagAdmin().remove(loginEntity.getUsername());
    }

    /**
     * Hook invoked before logout work runs. The default implementation does nothing.
     *
     * @param accountId id of the currently authenticated account
     * @throws AuthException to abort the logout; {@link #onLogoutFailed} is then invoked
     */
    protected void onLogoutBefore(Long accountId) throws AuthException {
    }

    /**
     * Records a successful logout via {@link AccountLoader#logoutSuccessful}.
     *
     * @param accountId id of the account that logged out
     */
    protected void onLogoutSuccess(Long accountId) {
        getAccountLoader().logoutSuccessful(accountId);
    }

    /**
     * Hook invoked when logout fails with an {@link AuthException}. The default implementation
     * does nothing.
     *
     * @param accountId id of the account attempting to log out
     * @param exception the failure
     */
    protected void onLogoutFailed(Long accountId, AuthException exception) {

    }

    /**
     * Authenticates the given login request.
     * <p>
     * On an {@link AuthException}, {@link #onLoginFailed} is invoked and the exception is rethrown.
     * Any other exception is wrapped in a {@link RuntimeException}. On success,
     * {@link #onLoginSuccess} is invoked.
     *
     * @param loginEntity the submitted login request
     * @return the populated login result, including the issued token
     * @throws AuthException when authentication fails
     */
    @Override
    public LoginInfo login(LoginEntity loginEntity) throws AuthException {
        LoginInfo loginInfo = new LoginInfo();
        try {
            this.doLogin(loginEntity, loginInfo);
        } catch (AuthException exception) {
            this.onLoginFailed(loginEntity, exception);
            throw exception;
        } catch (Exception exception) {
            throw new RuntimeException(exception.getMessage(), exception);
        }
        this.onLoginSuccess(loginEntity, loginInfo);
        return loginInfo;
    }

    /**
     * Runs the login pipeline in order and fills {@code loginInfo} on success.
     */
    private void doLogin(LoginEntity loginEntity, LoginInfo loginInfo) throws AuthException {
        this.onLoginBefore(loginEntity, loginInfo);

        this.checkCaptchaIfRequired(loginEntity);

        // Refuse early while the username is locked out after repeated failures.
        this.checkLoginLimit(loginEntity);

        // An unknown username yields the same message as a wrong password.
        Account account = getAccountLoader().loadUserByUsername(loginEntity.getUsername());
        if (account == null) {
            throw new AuthException(AuthCode.USERNAME_OR_PASSWORD_ERROR, RequestUtil.getMessageDefault("security.userNotFound", "username or password error"));
        }

        this.checkAccount(loginEntity, account);

        try {
            this.checkPassword(loginEntity, account);
        } catch (AuthException exception) {
            if (isLoginCaptchaEnabled()) {
                // Require a captcha on this username's next attempt.
                getCaptchaFlagAdmin().setFlag(loginEntity.getUsername());
            }
            throw exception;
        }

        String token = this.createToken(loginEntity, account);
        this.tokenToCookie(loginEntity, account, token);

        loginInfo.setId(account.getId());
        loginInfo.setName(account.getName());
        loginInfo.setUsername(loginEntity.getUsername());
        loginInfo.setPasswordChanged(account.passwordMustChanged() ? 1 : 0);
        loginInfo.setToken(token);

        this.fillSecurityContext(account, token);
    }

    /**
     * Returns whether login captcha is enabled in the configuration. If the setting is
     * {@code null}, captcha is treated as disabled instead of failing on unboxing.
     */
    private boolean isLoginCaptchaEnabled() {
        return Boolean.TRUE.equals(getSecurityConfig().getEnableLoginCaptcha());
    }

    /**
     * Decides whether the request's captcha must be checked, and checks it if so.
     * <p>
     * A check happens only when captcha is enabled and either:
     * <ul>
     *   <li>the request carries both a captcha id and a code, or</li>
     *   <li>the username was flagged after an earlier failure.</li>
     * </ul>
     * An unflagged request that omits the captcha is let through.
     */
    private void checkCaptchaIfRequired(LoginEntity loginEntity) throws AuthException {
        if (!isLoginCaptchaEnabled()) {
            return;
        }
        String captchaCode = loginEntity.getCaptcha();
        String captchaId = loginEntity.getCaptchaId();
        boolean captchaSupplied = StringUtils.hasText(captchaId) && StringUtils.hasText(captchaCode);
        if (captchaSupplied || this.getCaptchaFlagAdmin().hasFlag(loginEntity.getUsername())) {
            this.checkCaptcha(captchaId, captchaCode);
        }
    }

    /**
     * Populates the security context for the authenticated account.
     *
     * @param account the authenticated account
     * @param token   the issued token
     */
    protected void fillSecurityContext(Account account, String token) {
        SecurityAuthUtil.authorized(account.getId(), token);
    }

    /**
     * Logs out the current account.
     * <p>
     * Does nothing when the caller is not authenticated. Hook and failure handling mirror
     * {@link #login}.
     *
     * @throws AuthException when a logout hook rejects the logout
     */
    @Override
    public void logout() throws AuthException {
        if (getAuthService().isAuthenticated()) {
            Long accountId = getAuthService().getAccountId();
            try {
                this.onLogoutBefore(accountId);
                this.doLogout(accountId);
            } catch (AuthException exception) {
                this.onLogoutFailed(accountId, exception);
                throw exception;
            } catch (Exception exception) {
                throw new RuntimeException(exception.getMessage(), exception);
            }
            this.onLogoutSuccess(accountId);
        }
    }

    /**
     * Logout work for the given account. Currently a no-op.
     *
     * @param accountId id of the account logging out
     */
    private void doLogout(Long accountId) {
    }

    /**
     * Verifies the submitted password against the account's stored salt and encoded password.
     * <p>
     * When {@code passwordCipher} is {@code "gc"}, the password is first decrypted from Base64
     * with {@link GeneralCrypto}.
     *
     * @param loginEntity the submitted login request
     * @param account     the account being authenticated
     * @throws AuthException if the password is missing, cannot be decrypted, or does not match
     */
    protected void checkPassword(LoginEntity loginEntity, Account account) throws AuthException {
        String password = loginEntity.getPassword();
        if (password == null) {
            throw new AuthException(AuthCode.USERNAME_OR_PASSWORD_ERROR, RequestUtil.getMessageDefault("AbstractUserDetailsAuthenticationProvider.badCredentials", "Bad credentials"));
        }

        if ("gc".equals(loginEntity.getPasswordCipher())) {
            try {
                password = getGeneralCrypto().decryptFromBase64AsString(password);
            } catch (Exception exception) {
                // NOTE(review): the cause is dropped; only an (AuthCode, String) constructor is visible here.
                throw new AuthException(AuthCode.REQUEST_INVALID, "Password decrypt failed, please refresh page and submit again");
            }
        }

        String salt = account.getPasswordSalt();
        String encodePassword = account.getPassword();
        if (!getPasswordService().matched(password, salt, encodePassword)) {
            throw new AuthException(AuthCode.USERNAME_OR_PASSWORD_ERROR, RequestUtil.getMessageDefault("security.userNotFound", "username or password error"));
        }
    }

    /**
     * Validates a captcha answer.
     * <p>
     * The captcha is discarded after every check, successful or not. A client therefore cannot
     * keep guessing answers against the same {@code captchaId}.
     *
     * @param captchaId id of the issued captcha
     * @param captcha   the code entered by the user
     * @throws AuthException if the code is empty, the id is missing, or the code is wrong
     */
    protected void checkCaptcha(String captchaId, String captcha) throws AuthException {
        if (!StringUtils.hasText(captcha)) {
            throw new AuthException(AuthCode.CAPTCHA_ERROR, RequestUtil.getMessageDefault("security.captcha.empty", "Please input captcha code"));
        }

        if (!StringUtils.hasText(captchaId)) {
            throw new AuthException(AuthCode.REQUEST_INVALID, "Not set captchaId");
        }

        boolean checkSuccess = getCaptcha().check(captchaId, captcha);
        // Single-use: invalidate regardless of outcome to prevent brute-forcing one captcha.
        getCaptcha().remove(captchaId);
        if (!checkSuccess) {
            throw new AuthException(AuthCode.CAPTCHA_ERROR, RequestUtil.getMessageDefault("security.captcha.invalid", "Captcha code error"));
        }
    }

    /**
     * Verifies the account is usable.
     * <p>
     * {@link Account#statusCheck()} must not throw, and the registration approval must not be
     * {@code Reject} or {@code Waiting}. A {@code null}, {@code Agree} or any other approval
     * value is accepted.
     *
     * @param loginEntity the submitted login request
     * @param account     the account being authenticated
     * @throws AuthException with {@link AuthCode#STATUS_UNAVAILABLE} when the account is unusable
     */
    protected void checkAccount(LoginEntity loginEntity, Account account) throws AuthException {
        try {
            account.statusCheck();
        } catch (Exception exception) {
            throw new AuthException(AuthCode.STATUS_UNAVAILABLE, exception.getMessage());
        }
        if (RegApproval.Reject.equals(account.getRegApproval())) {
            throw new AuthException(AuthCode.STATUS_UNAVAILABLE, RequestUtil.getMessageDefault("security.userRejectApproved", "Registration approval is rejected, the account is unavailable"));
        }
        if (RegApproval.Waiting.equals(account.getRegApproval())) {
            throw new AuthException(AuthCode.STATUS_UNAVAILABLE, RequestUtil.getMessageDefault("security.userWaitingApproved", "Waiting for registration approval, the account is unavailable"));
        }
    }

    /**
     * Rejects the login while the username is locked out after too many failures.
     * <p>
     * A positive value from {@link AccountLoader#loginFailLimit} means the username is locked.
     * The value is divided by 60 and rounded up for the message, so it presumably holds the
     * remaining lock-out time in seconds (TODO confirm).
     *
     * @param loginEntity the submitted login request
     * @throws AuthException with {@link AuthCode#STATUS_UNAVAILABLE} while locked out
     */
    protected void checkLoginLimit(LoginEntity loginEntity) throws AuthException {
        int loginFailLimit = getAccountLoader().loginFailLimit(loginEntity.getUsername());
        if (loginFailLimit > 0) {
            int minutes = (int) Math.ceil(loginFailLimit / 60d);
            throw new AuthException(AuthCode.STATUS_UNAVAILABLE, RequestUtil.getMessageDefault("security.loginFailLimit"
                    , "Login failure too many times, limit login {0} minutes"
                    , String.valueOf(minutes)));
        }
    }

    /**
     * Writes the token to the response as a {@code Set-Cookie} header.
     * <p>
     * Does nothing when no cookie name is configured or no current request/response is
     * available. When no cookie path is configured, the request's context path is used unless it
     * is empty or {@code "/"}.
     *
     * @param loginEntity the submitted login request; its remember-me flag is passed to the cookie builder
     * @param account     the authenticated account (unused by this implementation)
     * @param token       the token to store
     */
    protected void tokenToCookie(LoginEntity loginEntity, Account account, String token) throws AuthException {
        SecurityConfig securityConfig = getSecurityConfig();
        String cookieName = securityConfig.getCookieTokenName();
        if (!StringUtils.hasText(cookieName)) {
            return;
        }
        ServletRequestAttributes attributes = ServletUtil.getRequestAttributes();
        if (attributes == null) {
            return;
        }
        HttpServletResponse response = attributes.getResponse();
        if (response == null) {
            return;
        }

        String cookiePath = securityConfig.getCookiePath();
        if (!StringUtils.hasLength(cookiePath)) {
            // Reuse the attributes fetched above rather than looking them up a second time.
            HttpServletRequest request = attributes.getRequest();
            if (request != null) {
                String contextPath = request.getContextPath();
                if (StringUtils.hasText(contextPath) && !"/".equals(contextPath)) {
                    cookiePath = contextPath;
                }
            }
        }
        // HttpCookie cannot express the SameSite attribute, so the header is built manually.
        StringBuffer cookie = SecurityCookieUtil.buildCookie(securityConfig, cookieName, token, cookiePath, loginEntity.getRememberMe());
        response.addHeader("Set-Cookie", cookie.toString());
    }

    /**
     * Creates an auth token for the account.
     * <p>
     * The token expires {@code tokenSeconds} (from {@link SecurityConfig}) after now.
     *
     * @param loginEntity the submitted login request
     * @param account     the authenticated account
     * @return the encoded token
     */
    protected String createToken(LoginEntity loginEntity, Account account) throws AuthException {
        Date now = new Date();
        // Multiply as long: in int arithmetic, lifetimes above ~24.8 days overflow to a negative offset.
        long lifetimeMillis = getSecurityConfig().getTokenSeconds() * MILLIS_PER_SECOND;
        Date expired = new Date(now.getTime() + lifetimeMillis);
        return getAuthTokenBuilder().encode(new AuthTokenInfo(account.getId(), expired));
    }
}
