/*
 * Decompiled with CFR 0.152.
 */
package io.inverno.mod.security.jose.internal.jwa;

import io.inverno.mod.security.jose.internal.JOSEUtils;
import io.inverno.mod.security.jose.internal.jwa.AbstractJWACipher;
import io.inverno.mod.security.jose.internal.jwa.GenericEncryptedData;
import io.inverno.mod.security.jose.jwa.JWACipher;
import io.inverno.mod.security.jose.jwa.JWACipherException;
import io.inverno.mod.security.jose.jwa.JWAProcessingException;
import io.inverno.mod.security.jose.jwa.OCTAlgorithm;
import io.inverno.mod.security.jose.jwk.oct.OCTJWK;
import java.nio.ByteBuffer;
import java.security.InvalidAlgorithmParameterException;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Base64;
import java.util.Set;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.Mac;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

/**
 * JWA cipher combining AES in CBC mode with an HMAC authentication tag.
 *
 * <p>The symmetric key held by the JWK is split in two: the leading bytes key the MAC and the
 * trailing bytes key the AES cipher. The authentication tag is the MAC computed over the
 * additional authenticated data, the initialization vector, the cipher text and the AAD bit
 * length, truncated to the tag length defined by the algorithm.</p>
 */
public class AESCBCCipher
extends AbstractJWACipher {
    /**
     * The set of algorithms supported by this cipher.
     */
    public static final Set<OCTAlgorithm> SUPPORTED_ALGORITHMS = Set.of(OCTAlgorithm.A128CBC_HS256, OCTAlgorithm.A192CBC_HS384, OCTAlgorithm.A256CBC_HS512);
    /** Key used to compute the MAC (leading part of the JWK key value). */
    private SecretKey digestSecretKey;
    /** Key used for AES encryption and decryption (trailing part of the JWK key value). */
    private SecretKey secretKey;

    /**
     * Creates an AES/CBC cipher for the specified key and algorithm and initializes it.
     *
     * @param jwk       the symmetric key
     * @param algorithm the algorithm, which must be one of {@link #SUPPORTED_ALGORITHMS}
     *
     * @throws JWAProcessingException if the algorithm is not supported or if the key length does
     *                                not match the algorithm
     */
    public AESCBCCipher(OCTJWK jwk, OCTAlgorithm algorithm) throws JWAProcessingException {
        super(jwk, algorithm);
        if (!SUPPORTED_ALGORITHMS.contains(algorithm)) {
            throw new JWAProcessingException("Unsupported algorithm: " + algorithm.getAlgorithm());
        }
        this.init();
    }

    /**
     * Creates an AES/CBC cipher for the specified key without initializing it.
     *
     * <p>{@link #init()} is not invoked by this constructor, so the MAC and encryption keys are
     * not derived until it is called.</p>
     *
     * @param jwk the symmetric key
     */
    protected AESCBCCipher(OCTJWK jwk) {
        super(jwk);
    }

    /**
     * Decodes the Base64URL key value of the JWK and derives the MAC key and the AES key from it.
     *
     * @throws JWAProcessingException if the decoded key length differs from the sum of the
     *                                encryption and MAC key lengths defined by the algorithm
     */
    @Override
    protected final void init() throws JWAProcessingException {
        byte[] key = Base64.getUrlDecoder().decode(this.jwk.getKeyValue());
        int macKeyLength = this.algorithm.getMacKeyLength();
        int encryptionKeyLength = this.algorithm.getEncryptionKeyLength();
        if (key.length != encryptionKeyLength + macKeyLength) {
            throw new JWAProcessingException("Key length " + key.length + " does not match algorithm " + this.algorithm.getAlgorithm());
        }
        // The MAC key comes first, the encryption key follows.
        this.digestSecretKey = new SecretKeySpec(key, 0, macKeyLength, this.algorithm.getMacAlgorithm());
        this.secretKey = new SecretKeySpec(key, macKeyLength, encryptionKeyLength, "AES");
    }

    /**
     * Encrypts the data with a newly generated initialization vector and computes the
     * authentication tag over the result.
     *
     * @param data         the data to encrypt
     * @param aad          the additional authenticated data
     * @param secureRandom the secure random used to generate the initialization vector and to
     *                     initialize the cipher
     *
     * @return the encrypted data holding the initialization vector, the cipher text and the
     *         authentication tag
     *
     * @throws JWACipherException if the encryption or the MAC computation failed
     */
    @Override
    protected JWACipher.EncryptedData doEncrypt(byte[] data, byte[] aad, SecureRandom secureRandom) throws JWACipherException {
        byte[] iv = JOSEUtils.generateInitializationVector(secureRandom, this.algorithm.getInitializationVectorLength());
        byte[] cipherText = this.cipherText(data, iv, secureRandom);
        byte[] authenticationTag = Arrays.copyOf(this.computeMac(aad, iv, cipherText), (int)this.algorithm.getAuthenticationTagLength());
        return new GenericEncryptedData(iv, cipherText, authenticationTag);
    }

    /**
     * Verifies the authentication tag and, only when it is valid, decrypts the cipher text.
     *
     * @param cipherText the cipher text to decrypt
     * @param aad        the additional authenticated data
     * @param iv         the initialization vector
     * @param tag        the authentication tag to verify
     *
     * @return the decrypted data
     *
     * @throws JWACipherException if the initialization vector length does not match the
     *                            algorithm, if the authentication tag is invalid or if the
     *                            decryption failed
     */
    @Override
    protected byte[] doDecrypt(byte[] cipherText, byte[] aad, byte[] iv, byte[] tag) throws JWACipherException {
        if (iv.length != this.algorithm.getInitializationVectorLength()) {
            throw new JWACipherException("Initialization vector length " + iv.length + " does not match algorithm " + this.algorithm.getAlgorithm());
        }
        byte[] expectedTag = Arrays.copyOf(this.computeMac(aad, iv, cipherText), (int)this.algorithm.getAuthenticationTagLength());
        // Time-constant comparison: the supplied tag is untrusted input, so the check must not
        // reveal how many leading bytes matched. A null tag is rejected as invalid.
        if (!MessageDigest.isEqual(expectedTag, tag)) {
            throw new JWACipherException("Invalid authentication tag");
        }
        return this.decrypt(cipherText, iv);
    }

    /**
     * Encrypts the data with the AES key and the specified initialization vector.
     *
     * @param data         the data to encrypt
     * @param iv           the initialization vector
     * @param secureRandom the secure random passed to the cipher
     *
     * @return the cipher text
     *
     * @throws JWACipherException if the JCA cipher could not be obtained, initialized or run
     */
    private byte[] cipherText(byte[] data, byte[] iv, SecureRandom secureRandom) throws JWACipherException {
        try {
            Cipher cipher = Cipher.getInstance(this.algorithm.getJcaAlgorithm());
            cipher.init(Cipher.ENCRYPT_MODE, (Key)this.secretKey, new IvParameterSpec(iv), secureRandom);
            return cipher.doFinal(data);
        }
        catch (InvalidAlgorithmParameterException | InvalidKeyException | NoSuchAlgorithmException | BadPaddingException | IllegalBlockSizeException | NoSuchPaddingException e) {
            throw new JWACipherException(e);
        }
    }

    /**
     * Computes the full-length MAC over {@code aad || iv || cipherText || AL} where {@code AL} is
     * the AAD length in bits encoded as a 64-bit big-endian integer.
     *
     * @param aad        the additional authenticated data
     * @param iv         the initialization vector
     * @param cipherText the cipher text
     *
     * @return the untruncated MAC
     *
     * @throws JWACipherException if the MAC algorithm is not available or the MAC key is invalid
     */
    private byte[] computeMac(byte[] aad, byte[] iv, byte[] cipherText) throws JWACipherException {
        // Widen before multiplying so that the bit length cannot overflow an int.
        byte[] al = ByteBuffer.allocate(Long.BYTES).putLong((long)aad.length * Byte.SIZE).array();
        try {
            Mac sig = Mac.getInstance(this.algorithm.getMacAlgorithm());
            sig.init(this.digestSecretKey);
            // Feeding the parts one after the other yields the same MAC as hashing their
            // concatenation, without allocating a combined copy.
            sig.update(aad);
            sig.update(iv);
            sig.update(cipherText);
            sig.update(al);
            return sig.doFinal();
        }
        catch (InvalidKeyException | NoSuchAlgorithmException e) {
            throw new JWACipherException(e);
        }
    }

    /**
     * Decrypts the cipher text with the AES key and the specified initialization vector.
     *
     * @param cipherText the cipher text to decrypt
     * @param iv         the initialization vector
     *
     * @return the decrypted data
     *
     * @throws JWACipherException if the JCA cipher could not be obtained, initialized or run
     */
    private byte[] decrypt(byte[] cipherText, byte[] iv) throws JWACipherException {
        try {
            Cipher cipher = Cipher.getInstance(this.algorithm.getJcaAlgorithm());
            cipher.init(Cipher.DECRYPT_MODE, (Key)this.secretKey, new IvParameterSpec(iv));
            return cipher.doFinal(cipherText);
        }
        catch (InvalidAlgorithmParameterException | InvalidKeyException | NoSuchAlgorithmException | BadPaddingException | IllegalBlockSizeException | NoSuchPaddingException e) {
            throw new JWACipherException(e);
        }
    }
}

