package net.guerlab.smart.license.service.utils;

import io.jsonwebtoken.*;
import net.guerlab.commons.exception.ApplicationException;
import net.guerlab.commons.number.NumberHelper;
import net.guerlab.smart.license.core.domain.LicenseExtends;
import net.guerlab.smart.license.core.exception.LicenseBeOverdueException;
import net.guerlab.smart.license.core.exception.LicenseInvalidException;
import net.guerlab.smart.license.core.exception.LicenseParseFailException;
import net.guerlab.smart.license.service.entity.License;
import org.apache.commons.lang3.StringUtils;

import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.KeyFactory;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
import java.time.LocalDate;
import java.time.ZoneId;
import java.util.Base64;
import java.util.Date;
import java.util.Objects;

/**
 * License code utilities: builds RS512-signed JWT license codes and parses/validates license files.
 *
 * @author guer
 */
public class LicenseHelper {

    /**
     * Key algorithm used for both the signing (private) key and the verification (public) key.
     */
    public static final String ALGORITHM = "RSA";

    /**
     * Prefix added to claim names that carry license extension entries.
     */
    public static final String EXTENDS_PREFIX = "extends:";

    /**
     * Length of {@link #EXTENDS_PREFIX}, used to strip the prefix when reading extension claims back.
     */
    public static final int EXTENDS_PREFIX_LENGTH = EXTENDS_PREFIX.length();

    /**
     * Separator between fields of the decoded license file content (passed to {@link String#split}, i.e. a regex).
     */
    public static final String LICENSE_FILE_CONTENT_SEPARATOR = ",";

    /**
     * Number of fields in decoded license file content:
     * licenseId, licenseTo, startDate, endDate, licenseCode, publicKey.
     */
    public static final int LICENSE_FILE_CONTENT_LENGTH = 6;

    private static final String KEY_LICENSE_ID = "licenseId";

    private static final String KEY_LICENSE_GROUP_ID = "licenseGroupId";

    private static final String KEY_LICENSE_GROUP_NAME = "licenseGroupName";

    private LicenseHelper() {
    }

    /**
     * Checks whether two licenses match.
     * <p>
     * Compares license id, licensee, group id, start date and end date. Group name, license code and
     * extensions are not compared.
     *
     * @param a
     *         first license
     * @param b
     *         second license
     * @return {@code true} if both are non-null and all compared fields are equal
     */
    public static boolean match(License a, License b) {
        if (a == null || b == null) {
            return false;
        }
        return Objects.equals(a.getLicenseId(), b.getLicenseId())
                && Objects.equals(a.getLicenseTo(), b.getLicenseTo())
                && Objects.equals(a.getLicenseGroupId(), b.getLicenseGroupId())
                && Objects.equals(a.getLicenseStartDate(), b.getLicenseStartDate())
                && Objects.equals(a.getLicenseEndDate(), b.getLicenseEndDate());
    }

    /**
     * Creates the license code: a compact JWT signed with RS512.
     * <p>
     * The subject is the licensee. {@code nbf}/{@code exp} are set to the start of the license start/end
     * date in the JVM's default time zone. The license id, group id and group name are written as custom
     * claims. Each extension entry is written as a claim named {@link #EXTENDS_PREFIX} + key.
     * <p>
     * NOTE(review): {@link #parseLicense(String)} treats the end date as inclusive when computing
     * {@code effective}, but {@code exp} here is the <em>start</em> of the end date. Token verification
     * therefore presumably already fails on the end date itself. Confirm the intended semantics.
     *
     * @param license
     *         license to encode
     * @param rsaKey
     *         key pair; the private key (PKCS#8 encoded) is used for signing
     * @return the compact, signed license code
     * @throws ApplicationException
     *         if the private key cannot be loaded or signing fails
     */
    public static String createLicenseCodeData(License license, RsaKey rsaKey) {
        JwtBuilder builder = Jwts.builder();
        builder.setHeaderParam("typ", "JWT");
        builder.setSubject(license.getLicenseTo());
        builder.setNotBefore(toDate(license.getLicenseStartDate()));
        builder.setExpiration(toDate(license.getLicenseEndDate()));
        builder.claim(KEY_LICENSE_ID, license.getLicenseId());
        builder.claim(KEY_LICENSE_GROUP_ID, license.getLicenseGroupId());
        builder.claim(KEY_LICENSE_GROUP_NAME, license.getLicenseGroupName());

        LicenseExtends licenseExtends = license.getLicenseExtends();
        if (licenseExtends != null && !licenseExtends.isEmpty()) {
            licenseExtends.forEach((key, value) -> builder.claim(EXTENDS_PREFIX + key, value));
        }

        try {
            Key key = KeyFactory.getInstance(ALGORITHM)
                    .generatePrivate(new PKCS8EncodedKeySpec(rsaKey.getPrivateKey()));

            builder.signWith(SignatureAlgorithm.RS512, key);

            return builder.compact();
        } catch (Exception e) {
            throw new ApplicationException(e.getLocalizedMessage(), e);
        }
    }

    /**
     * Parses license file content.
     * <p>
     * The input is Base64 text that decodes (as UTF-8) to six comma-separated fields:
     * {@code licenseId,licenseTo,startDate,endDate,licenseCode,publicKey}. The license code is verified
     * with the embedded public key. Its subject and license id must equal the plain-text fields.
     * <p>
     * NOTE(review): the verification key comes from the license file itself, so the signature check only
     * proves internal consistency. Presumably callers compare the result against a trusted record
     * (e.g. via {@link #match(License, License)}). Confirm. The start/end dates returned here are taken
     * from the plain-text fields and are not cross-checked against the token's {@code nbf}/{@code exp}.
     *
     * @param licenseFileCode
     *         license file content
     * @return the parsed license; {@code effective} is true when today (default time zone) lies within
     *         [startDate, endDate], both inclusive
     * @throws LicenseInvalidException
     *         if the input is blank, not valid Base64, has the wrong number of fields, has an unparsable
     *         or non-positive id, unparsable dates, or a subject/id that disagrees with the token
     * @throws LicenseBeOverdueException
     *         if the token has expired
     * @throws LicenseParseFailException
     *         if the token is malformed, unsupported, or its signature is invalid
     */
    public static License parseLicense(String licenseFileCode) {
        String[] licenseFileContents = decodeLicenseFileContents(licenseFileCode);

        Long licenseId;
        LocalDate startDate;
        LocalDate endDate;
        try {
            licenseId = Long.parseLong(licenseFileContents[0]);
            startDate = LocalDate.parse(licenseFileContents[2]);
            endDate = LocalDate.parse(licenseFileContents[3]);
        } catch (Exception e) {
            throw new LicenseInvalidException();
        }

        String licenseTo = licenseFileContents[1];
        String licenseCode = licenseFileContents[4];
        String publicKey = licenseFileContents[5];

        if (!NumberHelper.greaterZero(licenseId)) {
            throw new LicenseInvalidException();
        }

        Claims claims = LicenseHelper.parse(licenseCode, new RsaKey().setPublicKey(publicKey));

        // The plain-text licensee and id must agree with the signed token's claims.
        if (!Objects.equals(claims.getSubject(), licenseTo) || !Objects
                .equals(claims.get(KEY_LICENSE_ID, Long.class), licenseId)) {
            throw new LicenseInvalidException();
        }

        LocalDate now = LocalDate.now();

        License license = new License();
        license.setLicenseId(licenseId);
        license.setLicenseStartDate(startDate);
        license.setLicenseEndDate(endDate);
        license.setLicenseTo(licenseTo);
        license.setLicenseCode(licenseCode);
        license.setLicenseGroupId(claims.get(KEY_LICENSE_GROUP_ID, Long.class));
        license.setLicenseGroupName(claims.get(KEY_LICENSE_GROUP_NAME, String.class));
        license.setEffective(!now.isBefore(startDate) && !endDate.isBefore(now));

        LicenseExtends licenseExtends = new LicenseExtends();
        license.setLicenseExtends(licenseExtends);

        // Extension claims are stored with EXTENDS_PREFIX; strip it and read the value back as a String.
        claims.keySet().stream().filter(key -> key.startsWith(LicenseHelper.EXTENDS_PREFIX)).forEach(
                key -> licenseExtends.put(key.substring(EXTENDS_PREFIX_LENGTH), claims.get(key, String.class)));

        return license;
    }

    /**
     * Base64-decodes license file content as UTF-8 and splits it into its fields.
     *
     * @param licenseFileCode
     *         raw license file content
     * @return exactly {@link #LICENSE_FILE_CONTENT_LENGTH} fields
     * @throws LicenseInvalidException
     *         if the input is blank, not valid Base64, or does not have the expected field count
     */
    private static String[] decodeLicenseFileContents(String licenseFileCode) {
        if (StringUtils.isBlank(licenseFileCode)) {
            throw new LicenseInvalidException();
        }

        byte[] decoded;
        try {
            // Trim so that content read from a file with a trailing newline still decodes; the basic
            // Base64 decoder rejects any whitespace.
            decoded = Base64.getDecoder().decode(licenseFileCode.trim());
        } catch (IllegalArgumentException e) {
            // Malformed Base64 is an invalid license, not an unexpected runtime error.
            throw new LicenseInvalidException();
        }

        // Decode explicitly as UTF-8 so non-ASCII licensee names do not depend on the platform charset.
        String[] licenseFileContents = new String(decoded, StandardCharsets.UTF_8)
                .split(LICENSE_FILE_CONTENT_SEPARATOR);

        if (licenseFileContents.length != LICENSE_FILE_CONTENT_LENGTH) {
            throw new LicenseInvalidException();
        }
        return licenseFileContents;
    }

    /**
     * Converts a date to a {@link Date} at the start of that day in the JVM's default time zone.
     * <p>
     * NOTE(review): tokens created and verified on hosts with different default zones will shift by the
     * zone offset. Confirm this is acceptable.
     *
     * @param date
     *         date to convert; must not be {@code null}
     * @return the instant at the start of {@code date}
     */
    private static Date toDate(LocalDate date) {
        return Date.from(date.atStartOfDay(ZoneId.systemDefault()).toInstant());
    }

    /**
     * Verifies and parses a license code (signed JWT) with the given public key.
     *
     * @param licenseCode
     *         compact JWT license code
     * @param rsaKey
     *         key holder; its public key (X.509 encoded) is used for verification
     * @return the verified claims
     * @throws LicenseBeOverdueException
     *         if the token has expired
     * @throws LicenseParseFailException
     *         if the token is malformed, unsupported, has an invalid signature, or is empty
     * @throws ApplicationException
     *         for any other failure (e.g. the public key cannot be loaded)
     */
    public static Claims parse(String licenseCode, RsaKey rsaKey) {
        try {
            Key key = KeyFactory.getInstance(ALGORITHM).generatePublic(new X509EncodedKeySpec(rsaKey.getPublicKey()));

            Jws<Claims> claimsJws = Jwts.parser().setSigningKey(key).parseClaimsJws(licenseCode);

            return claimsJws.getBody();
        } catch (ExpiredJwtException e) {
            throw new LicenseBeOverdueException();
        } catch (MalformedJwtException | SignatureException | UnsupportedJwtException | IllegalArgumentException e) {
            throw new LicenseParseFailException();
        } catch (Exception e) {
            // NOTE(review): other jjwt failures, presumably including a not-yet-valid (nbf) token, land
            // here as a generic ApplicationException. Confirm against the jjwt version in use.
            throw new ApplicationException(e.getLocalizedMessage(), e);
        }
    }
}
