package cn.wumoe.hime.module

import cn.wumoe.hime.api.scripting.HimeContext
import cn.wumoe.hime.inter.Function
import cn.wumoe.hime.inter.Module
import cn.wumoe.hime.lexer.Token
import cn.wumoe.hime.lexer.Word
import cn.wumoe.hime.toWord
import java.security.KeyFactory
import java.security.KeyPairGenerator
import java.security.SecureRandom
import java.security.interfaces.RSAPublicKey
import java.security.spec.PKCS8EncodedKeySpec
import java.security.spec.X509EncodedKeySpec
import java.util.*
import javax.crypto.Cipher


/**
 * Script module exposing RSA key generation, encryption and decryption to Hime code.
 *
 * Keys are exchanged as Base64 strings:
 * - Public keys use their X.509 (SubjectPublicKeyInfo) encoding.
 * - Private keys use their PKCS#8 encoding.
 *
 * Ciphertext is exchanged as Base64.
 *
 * NOTE(review): the module id reads "hime.rea", which looks like a typo for "hime.rsa".
 * It is left unchanged because scripts presumably refer to the module by this exact id.
 * Confirm before renaming.
 */
class RSAModule : Module("hime.rea") {
    override fun init(context: HimeContext) {
        addFunction(Key())      // rsa-key
        addFunction(Encrypt())  // rsa-encrypt
        addFunction(Decrypt())  // rsa-decrypt
    }

    /**
     * `(rsa-key)` — generates a fresh 1024-bit RSA key pair. Arguments are ignored.
     *
     * Returns a two-element array `[publicKey, privateKey]`:
     * - `publicKey` is the Base64 of the X.509 encoding.
     * - `privateKey` is the Base64 of the PKCS#8 encoding.
     *
     * These are exactly the formats `rsa-encrypt` and `rsa-decrypt` accept.
     *
     * NOTE(review): 1024-bit RSA is considered weak for new keys. Consider raising the size
     * once callers can tolerate the longer keys and ciphertexts.
     */
    class Key : Function("rsa-key") {
        override fun call(pars: Array<out Token>): Token {
            val generator = KeyPairGenerator.getInstance("RSA")
            generator.initialize(1024, SecureRandom())
            val keyPair = generator.generateKeyPair()
            val encoder = Base64.getEncoder()
            // Key.toString() is a human-readable dump, not a parsable encoding.
            // Export the encoded bytes so the strings round-trip through rsa-encrypt / rsa-decrypt.
            return cn.wumoe.hime.lexer.Array(
                arrayOf(
                    encoder.encodeToString(keyPair.public.encoded).toWord(),
                    encoder.encodeToString(keyPair.private.encoded).toWord()
                )
            )
        }
    }

    /**
     * `(rsa-encrypt publicKey data)` — encrypts `data` with an RSA public key.
     *
     * @param pars `pars[0]` is the Base64 of an X.509-encoded public key. `pars[1]` is the
     *   plaintext; its string form is encrypted as UTF-8 bytes. Extra arguments are ignored.
     * @return the Base64 ciphertext, or [Word.NIL] when fewer than two arguments are given.
     *
     * With PKCS#1 v1.5 padding, the plaintext may be at most (key size in bytes − 11) bytes
     * long, i.e. 117 bytes for a 1024-bit key. Longer input makes the cipher throw.
     * Malformed Base64 or key data also throws; nothing is caught here.
     */
    class Encrypt : Function("rsa-encrypt") {
        override fun call(pars: Array<out Token>): Token {
            if (pars.size < 2) return Word.NIL
            val publicKey = KeyFactory.getInstance("RSA")
                .generatePublic(X509EncodedKeySpec(Base64.getDecoder().decode(pars[0].toString())))
            val cipherText = rsaCipher(Cipher.ENCRYPT_MODE, publicKey)
                .doFinal(pars[1].toString().toByteArray(Charsets.UTF_8))
            return Base64.getEncoder().encodeToString(cipherText).toWord()
        }
    }

    /**
     * `(rsa-decrypt privateKey data)` — decrypts Base64 ciphertext with an RSA private key.
     *
     * @param pars `pars[0]` is the Base64 of a PKCS#8-encoded private key. `pars[1]` is the
     *   Base64 ciphertext, as produced by `rsa-encrypt`. Extra arguments are ignored.
     * @return the plaintext decoded as UTF-8, or [Word.NIL] when fewer than two arguments
     *   are given.
     *
     * Malformed Base64, a bad key, or a ciphertext/key mismatch throws; nothing is caught here.
     */
    class Decrypt : Function("rsa-decrypt") {
        override fun call(pars: Array<out Token>): Token {
            if (pars.size < 2) return Word.NIL
            val cipherText = Base64.getDecoder().decode(pars[1].toString())
            val privateKey = KeyFactory.getInstance("RSA")
                .generatePrivate(PKCS8EncodedKeySpec(Base64.getDecoder().decode(pars[0].toString())))
            val plainText = rsaCipher(Cipher.DECRYPT_MODE, privateKey).doFinal(cipherText)
            return String(plainText, Charsets.UTF_8).toWord()
        }
    }
}

/**
 * Cipher transformation shared by `rsa-encrypt` and `rsa-decrypt`.
 *
 * It is spelled out in full because a bare "RSA" leaves mode and padding to the provider
 * default, which is not guaranteed to be the same everywhere.
 */
private const val RSA_TRANSFORMATION = "RSA/ECB/PKCS1Padding"

/** Creates an RSA [Cipher] using [RSA_TRANSFORMATION], initialised for [mode] with [key]. */
private fun rsaCipher(mode: Int, key: java.security.Key): Cipher {
    val cipher = Cipher.getInstance(RSA_TRANSFORMATION)
    cipher.init(mode, key)
    return cipher
}