/*
 * Decompiled with CFR 0.152.
 */
package io.inverno.mod.security.jose.internal.jwa;

import io.inverno.mod.security.jose.internal.JOSEUtils;
import io.inverno.mod.security.jose.internal.jwa.AbstractEncryptingJWAKeyManager;
import io.inverno.mod.security.jose.internal.jwa.GenericEncryptedCEK;
import io.inverno.mod.security.jose.internal.jwk.oct.GenericOCTJWK;
import io.inverno.mod.security.jose.jwa.EncryptingJWAKeyManager;
import io.inverno.mod.security.jose.jwa.JWACipherException;
import io.inverno.mod.security.jose.jwa.JWAKeyManagerException;
import io.inverno.mod.security.jose.jwa.JWAProcessingException;
import io.inverno.mod.security.jose.jwa.OCTAlgorithm;
import io.inverno.mod.security.jose.jwk.oct.OCTJWK;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Base64;
import java.util.Map;
import java.util.Set;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import org.apache.commons.lang3.StringUtils;

/**
 * <p>
 * Encrypting key manager wrapping a content encryption key (CEK) with AES in GCM mode, supporting the {@code A128GCMKW}, {@code A192GCMKW} and
 * {@code A256GCMKW} algorithms.
 * </p>
 *
 * <p>
 * The CEK is encrypted with the OCT JWK's secret key, using an initialization vector generated for each encryption and empty additional authenticated data.
 * The initialization vector and the authentication tag are returned as the base64url-encoded {@code iv} and {@code tag} parameters of the encrypted CEK.
 * They must be supplied back in the parameters when decrypting.
 * </p>
 */
public class AESGCMKWKeyManager
extends AbstractEncryptingJWAKeyManager<OCTJWK, OCTAlgorithm> {

    /**
     * The JOSE header parameters processed by this key manager.
     */
    public static final Set<String> PROCESSED_PARAMETERS = Set.of("iv", "tag");

    /**
     * The algorithms supported by this key manager.
     */
    public static final Set<OCTAlgorithm> SUPPORTED_ALGORITHMS = Set.of(OCTAlgorithm.A128GCMKW, OCTAlgorithm.A192GCMKW, OCTAlgorithm.A256GCMKW);

    /**
     * The wrapping key resolved from the JWK in {@link #init()}.
     */
    private SecretKey secretKey;

    /**
     * <p>
     * Creates an AES GCM key wrapping key manager.
     * </p>
     *
     * @param jwk       an OCT JWK holding the wrapping secret key
     * @param algorithm one of the {@link #SUPPORTED_ALGORITHMS}
     *
     * @throws JWAProcessingException if the algorithm is not supported, the JWK has no secret key or the key length does not match the algorithm
     */
    public AESGCMKWKeyManager(OCTJWK jwk, OCTAlgorithm algorithm) throws JWAProcessingException {
        super(jwk, algorithm);
        if (!SUPPORTED_ALGORITHMS.contains(algorithm)) {
            throw new JWAProcessingException("Unsupported algorithm: " + algorithm.getAlgorithm());
        }
        this.init();
    }

    /**
     * <p>
     * Creates an AES GCM key wrapping key manager without an algorithm.
     * </p>
     *
     * <p>
     * This constructor does not invoke {@link #init()}. NOTE(review): presumably the algorithm is resolved and {@code init()} is invoked later by the base
     * class or a subclass. Confirm.
     * </p>
     *
     * @param jwk an OCT JWK holding the wrapping secret key
     */
    protected AESGCMKWKeyManager(OCTJWK jwk) {
        super(jwk);
    }

    /**
     * <p>
     * Resolves the wrapping secret key from the JWK and checks that its length matches the algorithm's encryption key length.
     * </p>
     *
     * @throws JWAProcessingException if the JWK has no secret key or if the key length does not match the algorithm
     */
    @Override
    protected final void init() throws JWAProcessingException {
        OCTAlgorithm alg = (OCTAlgorithm)this.algorithm;
        this.secretKey = ((OCTJWK)this.jwk).toSecretKey().orElseThrow(() -> new JWAProcessingException("JWK is missing secret key"));
        int keyLength = this.secretKey.getEncoded().length;
        if (keyLength != alg.getEncryptionKeyLength()) {
            throw new JWAProcessingException("Key length " + keyLength + " does not match algorithm " + alg.getAlgorithm());
        }
    }

    @Override
    public Set<String> getProcessedParameters() {
        return PROCESSED_PARAMETERS;
    }

    /**
     * <p>
     * Encrypts the CEK with AES GCM.
     * </p>
     *
     * @param cek          the content encryption key to wrap
     * @param parameters   the JOSE header parameters (not used by this algorithm)
     * @param secureRandom the source used to generate the initialization vector and to initialize the cipher
     *
     * @return the encrypted CEK with the base64url-encoded {@code iv} and {@code tag} parameters
     *
     * @throws JWAKeyManagerException if the cipher could not be initialized or the encryption failed
     */
    @Override
    protected EncryptingJWAKeyManager.EncryptedCEK doEncryptCEK(OCTJWK cek, Map<String, Object> parameters, SecureRandom secureRandom) throws JWAKeyManagerException {
        OCTAlgorithm alg = (OCTAlgorithm)this.algorithm;
        int tagLength = alg.getAuthenticationTagLength();
        byte[] rawCek = Base64.getUrlDecoder().decode(cek.getKeyValue());
        try {
            byte[] iv = JOSEUtils.generateInitializationVector(secureRandom, alg.getInitializationVectorLength());
            Cipher cipher = Cipher.getInstance(alg.getJcaAlgorithm());
            // GCMParameterSpec expects the tag length in bits; the AAD is left empty
            cipher.init(Cipher.ENCRYPT_MODE, this.secretKey, new GCMParameterSpec(tagLength * 8, iv), secureRandom);

            // JCA returns ciphertext || tag, JOSE carries them separately
            byte[] encryptedData = cipher.doFinal(rawCek);
            int cipherTextLength = encryptedData.length - tagLength;
            byte[] cipherText = Arrays.copyOfRange(encryptedData, 0, cipherTextLength);
            byte[] authenticationTag = Arrays.copyOfRange(encryptedData, cipherTextLength, encryptedData.length);

            return new GenericEncryptedCEK(cipherText, Map.of("iv", JOSEUtils.BASE64_NOPAD_URL_ENCODER.encodeToString(iv), "tag", JOSEUtils.BASE64_NOPAD_URL_ENCODER.encodeToString(authenticationTag)));
        }
        catch (InvalidAlgorithmParameterException | InvalidKeyException | NoSuchAlgorithmException | BadPaddingException | IllegalBlockSizeException | NoSuchPaddingException e) {
            throw new JWACipherException(e);
        }
        finally {
            // do not leave the raw content encryption key lying around in memory
            Arrays.fill(rawCek, (byte)0);
        }
    }

    /**
     * <p>
     * Decrypts an AES GCM wrapped CEK.
     * </p>
     *
     * @param encryptedKey the wrapped CEK (ciphertext without the authentication tag)
     * @param octEnc       the algorithm to set on the resulting CEK
     * @param parameters   the JOSE header parameters, which must contain the base64url-encoded {@code iv} and {@code tag}
     *
     * @return the decrypted CEK
     *
     * @throws JWAKeyManagerException if the {@code iv} or {@code tag} parameters are missing or invalid, or if decryption/authentication failed
     */
    @Override
    protected OCTJWK doDecryptCEK(byte[] encryptedKey, OCTAlgorithm octEnc, Map<String, Object> parameters) throws JWAKeyManagerException {
        OCTAlgorithm alg = (OCTAlgorithm)this.algorithm;
        int tagLength = alg.getAuthenticationTagLength();
        byte[] iv = AESGCMKWKeyManager.getInitializationVector(parameters);
        byte[] tag = AESGCMKWKeyManager.getAuthenticationTag(parameters);
        if (iv.length != alg.getInitializationVectorLength()) {
            throw new JWACipherException("Initialization vector length " + iv.length + " does not match algorithm " + alg.getAlgorithm());
        }
        if (tag.length != tagLength) {
            throw new JWACipherException("Authentication tag length " + tag.length + " does not match algorithm " + alg.getAlgorithm());
        }
        try {
            Cipher cipher = Cipher.getInstance(alg.getJcaAlgorithm());
            // GCMParameterSpec expects the tag length in bits; the AAD is left empty
            cipher.init(Cipher.DECRYPT_MODE, this.secretKey, new GCMParameterSpec(tagLength * 8, iv));

            // JCA expects the authentication tag appended to the ciphertext
            byte[] encryptedData = Arrays.copyOf(encryptedKey, encryptedKey.length + tag.length);
            System.arraycopy(tag, 0, encryptedData, encryptedKey.length, tag.length);

            byte[] rawCek = cipher.doFinal(encryptedData);
            try {
                GenericOCTJWK cek = new GenericOCTJWK(JOSEUtils.BASE64_NOPAD_URL_ENCODER.encodeToString(rawCek));
                cek.setAlgorithm(octEnc);
                return cek;
            }
            finally {
                // the encoded key is now held by the JWK, wipe the raw copy
                Arrays.fill(rawCek, (byte)0);
            }
        }
        catch (InvalidAlgorithmParameterException | InvalidKeyException | NoSuchAlgorithmException | BadPaddingException | IllegalBlockSizeException | NoSuchPaddingException e) {
            throw new JWACipherException(e);
        }
    }

    /**
     * <p>
     * Extracts and decodes the {@code iv} parameter.
     * </p>
     *
     * @param parameters the JOSE header parameters
     *
     * @return the decoded initialization vector
     *
     * @throws JWAKeyManagerException if the parameter is missing, not a string or not valid base64url
     */
    private static byte[] getInitializationVector(Map<String, Object> parameters) throws JWAKeyManagerException {
        return AESGCMKWKeyManager.decodeBase64UrlParameter(parameters, "iv", "initialization vector");
    }

    /**
     * <p>
     * Extracts and decodes the {@code tag} parameter.
     * </p>
     *
     * @param parameters the JOSE header parameters
     *
     * @return the decoded authentication tag
     *
     * @throws JWAKeyManagerException if the parameter is missing, not a string or not valid base64url
     */
    private static byte[] getAuthenticationTag(Map<String, Object> parameters) throws JWAKeyManagerException {
        return AESGCMKWKeyManager.decodeBase64UrlParameter(parameters, "tag", "authentication tag");
    }

    /**
     * <p>
     * Extracts a base64url-encoded string parameter and decodes it.
     * </p>
     *
     * @param parameters  the JOSE header parameters
     * @param name        the parameter name
     * @param description a human-readable description of the parameter used in error messages
     *
     * @return the decoded bytes
     *
     * @throws JWAKeyManagerException if the parameter is missing or blank, is not a string, or is not valid base64url
     */
    private static byte[] decodeBase64UrlParameter(Map<String, Object> parameters, String name, String description) throws JWAKeyManagerException {
        Object value = parameters.get(name);
        if (value == null || (value instanceof String && StringUtils.isBlank((CharSequence)value))) {
            throw new JWAKeyManagerException("Missing " + description);
        }
        if (!(value instanceof String)) {
            throw new JWAKeyManagerException("Invalid " + description + ": expected a base64url encoded string");
        }
        try {
            return Base64.getUrlDecoder().decode((String)value);
        }
        catch (IllegalArgumentException e) {
            JWAKeyManagerException invalid = new JWAKeyManagerException("Invalid " + description + ": malformed base64url value");
            invalid.initCause(e);
            throw invalid;
        }
    }
}

