/*
 * Decompiled with CFR 0.152.
 */
package io.inverno.mod.security.jose.internal.jwa;

import io.inverno.mod.security.jose.internal.JOSEUtils;
import io.inverno.mod.security.jose.internal.jwa.AbstractEncryptingJWAKeyManager;
import io.inverno.mod.security.jose.internal.jwa.GenericEncryptedCEK;
import io.inverno.mod.security.jose.internal.jwk.oct.GenericOCTJWK;
import io.inverno.mod.security.jose.jwa.EncryptingJWAKeyManager;
import io.inverno.mod.security.jose.jwa.JWAKeyManagerException;
import io.inverno.mod.security.jose.jwa.JWAProcessingException;
import io.inverno.mod.security.jose.jwa.OCTAlgorithm;
import io.inverno.mod.security.jose.jwa.PBES2Algorithm;
import io.inverno.mod.security.jose.jwk.oct.OCTJWK;
import io.inverno.mod.security.jose.jwk.pbes2.PBES2JWK;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.spec.InvalidKeySpecException;
import java.util.Base64;
import java.util.Map;
import java.util.Set;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
import org.apache.commons.lang3.StringUtils;

public class PBES2KeyManager
extends AbstractEncryptingJWAKeyManager<PBES2JWK, PBES2Algorithm> {
    public static final Set<String> PROCESSED_PARAMETERS = Set.of("p2s", "p2c");
    public static final Set<PBES2Algorithm> SUPPORTED_ALGORITHMS = Set.of(PBES2Algorithm.PBES2_HS256_A128KW, PBES2Algorithm.PBES2_HS384_A192KW, PBES2Algorithm.PBES2_HS512_A256KW);

    public PBES2KeyManager(PBES2JWK jwk, PBES2Algorithm algorithm) throws JWAProcessingException {
        super(jwk, algorithm);
        if (!SUPPORTED_ALGORITHMS.contains(algorithm)) {
            throw new JWAProcessingException("Unsupported algorithm: " + algorithm.getAlgorithm());
        }
        this.init();
    }

    protected PBES2KeyManager(PBES2JWK jwk) {
        super(jwk);
    }

    @Override
    protected final void init() throws JWAProcessingException {
    }

    @Override
    public Set<String> getProcessedParameters() {
        return PROCESSED_PARAMETERS;
    }

    @Override
    protected EncryptingJWAKeyManager.EncryptedCEK doEncryptCEK(OCTJWK cek, Map<String, Object> parameters, SecureRandom secureRandom) throws JWAKeyManagerException {
        return cek.toSecretKey().map(cekSecretKey -> {
            try {
                SecretKeyFactory skf = SecretKeyFactory.getInstance(((PBES2Algorithm)this.algorithm).getJcaAlgorithm());
                byte[] p2s = PBES2KeyManager.getP2s(parameters, false, secureRandom);
                Integer p2c = PBES2KeyManager.getP2c(parameters, false);
                PBEKeySpec derivedKeySpec = new PBEKeySpec(new String(Base64.getUrlDecoder().decode(((PBES2JWK)this.jwk).getPassword())).toCharArray(), PBES2KeyManager.computeSaltValue((PBES2Algorithm)this.algorithm, p2s), p2c, ((PBES2Algorithm)this.algorithm).getEncryptionKeyLength() * 8);
                SecretKeySpec derivedKey = new SecretKeySpec(skf.generateSecret(derivedKeySpec).getEncoded(), "AES");
                Cipher cipher = Cipher.getInstance(((PBES2Algorithm)this.algorithm).getJcaEncryptionAlgorithm());
                cipher.init(3, derivedKey);
                return new GenericEncryptedCEK(cipher.wrap((Key)cekSecretKey), Map.of("p2s", JOSEUtils.BASE64_NOPAD_URL_ENCODER.encodeToString(p2s), "p2c", p2c));
            }
            catch (InvalidKeyException | NoSuchAlgorithmException | InvalidKeySpecException | IllegalBlockSizeException | NoSuchPaddingException e) {
                throw new JWAKeyManagerException(e);
            }
        }).orElseThrow(() -> new JWAKeyManagerException("CEK secret key is missing"));
    }

    @Override
    protected OCTJWK doDecryptCEK(byte[] encrypted_key, OCTAlgorithm octEnc, Map<String, Object> parameters) throws JWAKeyManagerException {
        try {
            SecretKeyFactory skf = SecretKeyFactory.getInstance(((PBES2Algorithm)this.algorithm).getJcaAlgorithm());
            byte[] p2s = PBES2KeyManager.getP2s(parameters, true, null);
            Integer p2c = PBES2KeyManager.getP2c(parameters, true);
            PBEKeySpec derivedKeySpec = new PBEKeySpec(new String(Base64.getUrlDecoder().decode(((PBES2JWK)this.jwk).getPassword())).toCharArray(), PBES2KeyManager.computeSaltValue((PBES2Algorithm)this.algorithm, p2s), p2c, ((PBES2Algorithm)this.algorithm).getEncryptionKeyLength() * 8);
            SecretKeySpec derivedKey = new SecretKeySpec(skf.generateSecret(derivedKeySpec).getEncoded(), "AES");
            Cipher cipher = Cipher.getInstance(((PBES2Algorithm)this.algorithm).getJcaEncryptionAlgorithm());
            cipher.init(4, derivedKey);
            SecretKey decryptedKey = (SecretKey)cipher.unwrap(encrypted_key, "AES", 3);
            GenericOCTJWK cek = new GenericOCTJWK(JOSEUtils.BASE64_NOPAD_URL_ENCODER.encodeToString(decryptedKey.getEncoded()), decryptedKey, true);
            cek.setAlgorithm(octEnc);
            return cek;
        }
        catch (InvalidKeyException | NoSuchAlgorithmException | InvalidKeySpecException | NoSuchPaddingException e) {
            throw new JWAKeyManagerException(e);
        }
    }

    private static byte[] computeSaltValue(PBES2Algorithm algorithm, byte[] salt) {
        byte[] utf8_alg = algorithm.getAlgorithm().getBytes(StandardCharsets.UTF_8);
        byte[] saltValue = new byte[utf8_alg.length + 1 + salt.length];
        System.arraycopy(utf8_alg, 0, saltValue, 0, utf8_alg.length);
        System.arraycopy(salt, 0, saltValue, utf8_alg.length + 1, salt.length);
        return saltValue;
    }

    private static byte[] getP2s(Map<String, Object> parameters, boolean failOnMissing, SecureRandom secureRandom) throws JWAKeyManagerException {
        String p2ss;
        byte[] p2s = null;
        String string = p2ss = parameters != null ? (String)parameters.get("p2s") : null;
        if (!StringUtils.isNotBlank((CharSequence)p2ss)) {
            if (failOnMissing) {
                throw new JWAKeyManagerException("Missing PBES2 salt input");
            }
            return JOSEUtils.generateSalt(secureRandom, 16);
        }
        p2s = Base64.getUrlDecoder().decode(p2ss);
        if (p2s.length < 8) {
            throw new JWAKeyManagerException("PBES2 salt input must be at least 8 bytes long");
        }
        return p2s;
    }

    private static int getP2c(Map<String, Object> parameters, boolean failOnMissing) {
        Integer p2c;
        Integer n = p2c = parameters != null ? (Integer)parameters.get("p2c") : null;
        if (p2c == null) {
            if (failOnMissing) {
                throw new JWAKeyManagerException("Missing PBES2 iteration count");
            }
            return 1000;
        }
        if (p2c < 1000) {
            throw new JWAKeyManagerException("PBES2 iteration count must be at least 1000");
        }
        return p2c;
    }
}

