/*
 * Decompiled with CFR 0.152.
 */
package de.otto.kafka.messaging.e2ee;

import de.otto.kafka.messaging.e2ee.AesEncryptedPayload;
import de.otto.kafka.messaging.e2ee.Cache;
import de.otto.kafka.messaging.e2ee.DefaultAesEncryptionConfiguration;
import de.otto.kafka.messaging.e2ee.EncryptionKeyProvider;
import de.otto.kafka.messaging.e2ee.InitializationVectorFactory;
import de.otto.kafka.messaging.e2ee.SecureRandomInitializationVectorFactory;
import de.otto.kafka.messaging.e2ee.vault.VaultHelper;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.util.Objects;
import javax.crypto.spec.SecretKeySpec;

/**
 * Encrypts Kafka message payloads with AES, using a per-topic key obtained from an
 * {@link EncryptionKeyProvider}.
 *
 * <p>Key data retrieved for a topic is kept in a {@link Cache} configured with
 * {@link DefaultAesEncryptionConfiguration#CACHING_DURATION}, so the provider is not
 * necessarily consulted on every call (see {@code Cache} for the exact retention rules).
 */
public class EncryptionService {
    /** Source of the per-topic encryption keys. */
    private final EncryptionKeyProvider encryptionKeyProvider;
    /** Produces the initialization vector used for each individual encryption. */
    private final InitializationVectorFactory initializationVectorFactory;
    /** Per-topic cache of decoded AES keys together with the key version they came from. */
    private final Cache<String, EncryptionKeyData> encryptionKeyDataCache;

    /**
     * Creates a service that generates initialization vectors with a
     * {@link SecureRandomInitializationVectorFactory}.
     *
     * @param encryptionKeyProvider provider of per-topic encryption keys; must not be null
     * @throws NullPointerException if {@code encryptionKeyProvider} is null
     */
    public EncryptionService(EncryptionKeyProvider encryptionKeyProvider) {
        this(encryptionKeyProvider, new SecureRandomInitializationVectorFactory());
    }

    /**
     * Creates a service with an explicit initialization-vector factory.
     *
     * @param encryptionKeyProvider provider of per-topic encryption keys; must not be null
     * @param initializationVectorFactory factory asked for a new IV on every encryption; must not be null
     * @throws NullPointerException if either argument is null
     */
    public EncryptionService(EncryptionKeyProvider encryptionKeyProvider, InitializationVectorFactory initializationVectorFactory) {
        Objects.requireNonNull(encryptionKeyProvider, "encryptionKeyProvider");
        Objects.requireNonNull(initializationVectorFactory, "initializationVectorFactory");
        this.encryptionKeyProvider = encryptionKeyProvider;
        this.initializationVectorFactory = initializationVectorFactory;
        // Diamond instead of the raw type: lets the compiler check the cache's
        // generic types against the field declaration (no unchecked assignment).
        this.encryptionKeyDataCache = new Cache<>(DefaultAesEncryptionConfiguration.CACHING_DURATION);
    }

    /**
     * Encrypts a payload with the AES key currently configured for the given topic.
     *
     * <p>A new initialization vector is requested from the {@link InitializationVectorFactory}
     * for every call. If no encryption key is available for the topic, the payload is wrapped
     * via {@link AesEncryptedPayload#ofUnencryptedPayload(byte[])} instead of being encrypted.
     *
     * @param kafkaTopicName name of the Kafka topic the payload is destined for; must not be null
     * @param plainPayload the raw payload bytes; must not be null
     * @return the encrypted payload (with IV and key version), or an unencrypted wrapper when
     *     the topic has no key
     * @throws NullPointerException if either argument is null
     */
    public AesEncryptedPayload encryptPayloadWithAes(String kafkaTopicName, byte[] plainPayload) {
        Objects.requireNonNull(kafkaTopicName, "kafkaTopicName must not be null");
        Objects.requireNonNull(plainPayload, "plainPayload must not be null");
        EncryptionKeyData encryptionKeyData = this.encryptionKeyDataCache.getOrRetrieve(kafkaTopicName, this::retrieveKeyData);
        if (encryptionKeyData == null) {
            // No key for this topic: fall back to sending the payload unencrypted.
            return AesEncryptedPayload.ofUnencryptedPayload(plainPayload);
        }
        Key aesKey = encryptionKeyData.aesKey();
        byte[] iv = this.initializationVectorFactory.generateInitializationVector();
        byte[] encryptedData = DefaultAesEncryptionConfiguration.encrypt(plainPayload, aesKey, iv);
        return AesEncryptedPayload.ofEncryptedPayload(encryptedData, iv, encryptionKeyData.keyVersion());
    }

    /**
     * Encrypts a text payload for the given topic after encoding it as UTF-8.
     *
     * @param kafkaTopicName name of the Kafka topic the payload is destined for; must not be null
     * @param plainText the text to encrypt; must not be null
     * @return the result of {@link #encryptPayloadWithAes(String, byte[])} for the UTF-8 bytes
     * @throws NullPointerException if either argument is null
     */
    public AesEncryptedPayload encryptPayloadWithAes(String kafkaTopicName, String plainText) {
        Objects.requireNonNull(kafkaTopicName, "kafkaTopicName must not be null");
        Objects.requireNonNull(plainText, "plainText must not be null");
        return this.encryptPayloadWithAes(kafkaTopicName, plainText.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Cache loader: asks the provider for the topic's current key and turns it into an AES key.
     *
     * @param topic the Kafka topic name
     * @return the key data, or {@code null} when the provider has no key for the topic
     */
    private EncryptionKeyData retrieveKeyData(String topic) {
        EncryptionKeyProvider.KeyVersion keyVersion = this.encryptionKeyProvider.retrieveKeyForEncryption(topic);
        if (keyVersion == null) {
            return null;
        }
        SecretKeySpec aesKey = this.createAesKey(keyVersion);
        return new EncryptionKeyData(aesKey, keyVersion);
    }

    /**
     * Builds an AES {@link SecretKeySpec} from the encoded key of the given key version.
     *
     * @param keyVersion key version whose {@code encodedKey()} is decoded via
     *     {@link VaultHelper#decodeBase64Key(String)}
     * @return the AES key spec
     */
    private SecretKeySpec createAesKey(EncryptionKeyProvider.KeyVersion keyVersion) {
        String base64Key = keyVersion.encodedKey();
        byte[] key = VaultHelper.decodeBase64Key(base64Key);
        return new SecretKeySpec(key, "AES");
    }

    /**
     * A decoded AES key paired with the provider key version it was created from.
     *
     * @param aesKey the AES key; must not be null
     * @param keyVersion the originating key version; must not be null
     */
    private record EncryptionKeyData(Key aesKey, EncryptionKeyProvider.KeyVersion keyVersion) {
        private EncryptionKeyData {
            Objects.requireNonNull(aesKey, "aesKey must not be null");
            Objects.requireNonNull(keyVersion, "keyVersion must not be null");
        }
    }
}

