package cn.imkarl.sqldsl

import cn.imkarl.sqldsl.column.Column
import cn.imkarl.sqldsl.database.Database
import cn.imkarl.sqldsl.sql.SQLExpression
import cn.imkarl.sqldsl.table.Table
import com.google.devtools.ksp.getClassDeclarationByName
import com.google.devtools.ksp.isPublic
import com.google.devtools.ksp.processing.*
import com.google.devtools.ksp.symbol.*
import com.google.devtools.ksp.validate
import com.squareup.kotlinpoet.ksp.toTypeName
import kotlin.reflect.KMutableProperty1
import kotlin.reflect.jvm.jvmName

/**
 * KSP entry point ([SymbolProcessorProvider]) for the sqldsl code generator.
 *
 * Each call to [create] stores the environment's logger in the file-level `sLogger` (used by `log`),
 * logs a banner with this provider's class name, and returns a [GenerateProcessor] that writes
 * through the environment's [CodeGenerator].
 */
class SqlDslGenerateProcessor : SymbolProcessorProvider {
    override fun create(environment: SymbolProcessorEnvironment): SymbolProcessor {
        sLogger = environment.logger
        val providerName = this::class.simpleName
        log("===============${providerName}===============")
        val generator = environment.codeGenerator
        return GenerateProcessor(codeGenerator = generator)
    }
}



/**
 * KSP [SymbolProcessor] that looks for subclasses of sqldsl's [Table] in the Kotlin sources and,
 * for each one, generates a `T<Name>` data class and a `<Name>Dao` object in the same package.
 *
 * @property codeGenerator KSP code generator used to create the generated source files.
 */
class GenerateProcessor(
    val codeGenerator: CodeGenerator
) : SymbolProcessor {
    /** Qualified class names already handled, so a later `process` round never regenerates them. */
    private val processorClassNames = mutableSetOf<String>()

    /**
     * Property names that are never turned into generated fields
     * (presumably members inherited from [Table] itself — confirm against the Table API).
     */
    private val tableInnerFieldName = setOf(
        "columns",
        "primaryKey",
        "tableName",
        "autoIncColumn",
    )

    /**
     * Finds [Table] subclasses in the Kotlin sources that pass `validate()` and runs
     * [GenerateVisitor] on every one not handled in an earlier round.
     *
     * NOTE(review): a table is only discovered when it is declared in a file named after it
     * (`Users.kt` in package `a.b` -> `a.b.Users`); tables in differently named files are skipped.
     *
     * @return always an empty list — files failing `validate()` are skipped for this round
     *   instead of being returned for deferral.
     */
    override fun process(resolver: Resolver): List<KSAnnotated> {
        resolver.getAllFiles()
            .filter { it.origin == Origin.KOTLIN }
            .filter { it.validate() }
            .map { file -> file.guessClassName() }
            .filter { className -> className !in processorClassNames }
            .mapNotNull { className ->
                // Mark as handled even if the lookup fails, so the name is not retried next round.
                processorClassNames.add(className)
                resolver.getClassDeclarationByName(className)
            }
            .filter { it.extendsTable() }
            .forEach { tableClass ->
                log("find sqldsl Table class: $tableClass")
                tableClass.accept(GenerateVisitor(), Unit)
            }
        return emptyList()
    }

    /** Qualified name of the class assumed to live in this file, derived from the file's name. */
    private fun KSFile.guessClassName(): String {
        val packageName = packageName.asString()
        val simpleClassName = fileName.removeSuffix(".kt").removeSuffix(".java")
        return "${packageName}.${simpleClassName}"
    }

    /** True when one of this class's direct supertypes resolves to [Table]. */
    private fun KSClassDeclaration.extendsTable(): Boolean =
        superTypes.any { type -> type.toTypeNameString() == Table::class.java.name }

    /**
     * Generates the new source files (the `T<Name>` data class and the `<Name>Dao` object)
     * for a visited [Table] subclass.
     */
    inner class GenerateVisitor : KSVisitorVoid() {
        override fun visitClassDeclaration(classDeclaration: KSClassDeclaration, data: Unit) {
            val sourceFile = requireNotNull(classDeclaration.containingFile) {
                "Table class ${classDeclaration.qualifiedName?.asString()} has no containing file"
            }
            val packageName = sourceFile.packageName.asString()
            val className = classDeclaration.simpleName.asString()
            // Collected once: both generators iterate the fields several times.
            val fields = collectFields(classDeclaration, className)

            generatorTTable(classDeclaration, packageName, className, fields)
            generatorDao(classDeclaration, packageName, className, fields)
        }
    }

    /**
     * Collects the public `Column<...>` / `PrimaryKey<...>` properties of [classDeclaration],
     * excluding the names in [tableInnerFieldName].
     *
     * Returned as a List rather than a lazy Sequence so that the property type resolution runs
     * once instead of on every pass over the fields.
     *
     * @param className simple name of the table, used to build each field's default-value expression.
     */
    private fun collectFields(classDeclaration: KSClassDeclaration, className: String): List<FieldInfo> =
        classDeclaration.getAllProperties()
            .filter { it.simpleName.asString() !in tableInnerFieldName }
            .filter { it.isPublic() }
            .filter { property ->
                val typeName = property.type.toTypeNameString().orEmpty()
                typeName.startsWith("cn.imkarl.sqldsl.column.Column<") ||
                    typeName.startsWith("cn.imkarl.sqldsl.column.PrimaryKey<")
            }
            .map { property ->
                val fieldName = property.simpleName.asString()
                // The generated field type is the column's first type argument
                // (e.g. Column<kotlin.String> -> kotlin.String). "<ERROR>" makes the generated
                // source fail to compile, which surfaces an unresolvable type.
                val fieldTypeName =
                    property.type.resolve().arguments.firstOrNull()?.toTypeNameString() ?: "<ERROR>"
                val defaultValueFun = "${className}.${fieldName}.defaultValueFun.invoke()"
                FieldInfo(fieldName, fieldTypeName, defaultValueFun)
            }
            .toList()

    /**
     * Creates the generated source file [fileName] in [packageName] and writes [content] to it.
     * The file has an aggregating dependency on the table's source file. `use` closes the
     * stream even when writing fails.
     */
    private fun writeSourceFile(
        classDeclaration: KSClassDeclaration,
        packageName: String,
        fileName: String,
        content: String,
    ) {
        val sourceFile = requireNotNull(classDeclaration.containingFile) {
            "Table class ${classDeclaration.qualifiedName?.asString()} has no containing file"
        }
        codeGenerator.createNewFile(Dependencies(true, sourceFile), packageName, fileName)
            .use { output -> output.write(content.toByteArray()) }
    }

    /**
     * Generates the `T<Name>` data class.
     *
     * It has one `var` per column field, and each field defaults to that column's `defaultValueFun`.
     * Every run of 16 spaces is stripped from the template (not only the leading indentation);
     * this only affects the cosmetic layout of the generated file.
     */
    private fun generatorTTable(
        classDeclaration: KSClassDeclaration,
        packageName: String,
        className: String,
        fields: List<FieldInfo>,
    ) {
        val tClassname = "T${className}"
        val source = """
                package $packageName

                data class $tClassname (
                ${
                    fields.joinToString(",\n") { (fieldName, fieldTypeName, defaultValueFun) ->
                        "  var $fieldName: $fieldTypeName = $defaultValueFun"
                    }
                }
                )
                """.replace("                ", "")
        writeSourceFile(classDeclaration, packageName, tClassname, source)
    }

    /**
     * Generates the `<Name>Dao` object.
     *
     * The object exposes findOne, findPage, count, insertOrUpdate, insert, insertAndGetId, update
     * and delete, all of which go through `Database.instance`.
     *
     * NOTE(review): the generated insertOrUpdate/insertAndGetId locate the id property by
     * reflection, matching a member name against `primaryKey.columnName`. They therefore assume
     * that the property name equals the column name. `KClass.members` presumably also needs
     * kotlin-reflect at runtime — confirm. As in [generatorTTable], every run of 16 spaces is
     * stripped from the template.
     */
    private fun generatorDao(
        classDeclaration: KSClassDeclaration,
        packageName: String,
        className: String,
        fields: List<FieldInfo>,
    ) {
        val daoClassname = "${className}Dao"
        val source = """
                package $packageName
                
                import ${Database::class.jvmName}
                import ${SQLExpression::class.jvmName}
                import ${SQLExpression::class.java.name.removeSuffix(SQLExpression::class.java.simpleName)}*
                import ${Table::class.jvmName}
                import ${Column::class.jvmName}
                import ${KMutableProperty1::class.jvmName}

                object $daoClassname {

                    private val db: Database
                        get() {
                            return Database.instance
                        }

                    fun findOne(
                        where: SQLExpression
                    ): T${className}? {
                        return db.select(${className}, where)
                            .firstOrNull()
                            ?.let {
                               T${className}(${
                                    fields.joinToString(",") { "it[${className}.${it.fieldName}]" }
                               })
                            }
                    }

                    fun findPage(
                        where: SQLExpression? = null
                    ): List<T${className}> {
                        return db.select(${className}, where)
                            .map {
                               T${className}(${
                                    fields.joinToString(",") { "it[${className}.${it.fieldName}]" }
                               })
                            }
                    }

                    fun count(
                        where: SQLExpression? = null
                    ): Long {
                        return db.count(${className}, where)
                    }
                    
                    private fun newRow(entity: T${className}): Table.Row {
                        return ${className}.newRow {
                            ${
                                fields.joinToString("\n") {
                                    "if (${className}.primaryKey != ${className}.${it.fieldName}) " +
                                            "it[${className}.${it.fieldName}] = entity.${it.fieldName}"
                                }
                            }
                        }
                    }
                    
                    fun insertOrUpdate(
                        entity: T${className}
                    ): Boolean {
                        val idName = ${className}.primaryKey?.columnName
                        val entityId = if (!idName.isNullOrEmpty()) {
                            T${className}::class.members.find { it.name == idName }?.call(entity)
                        } else {
                            null
                        }
                        val oldEntity = entityId?.let { findOne(${className}.primaryKey!!.eq(entityId)) }
                        val row = newRow(entity)
                        if (oldEntity != null) {
                            db.update(row, ${className}.primaryKey!!.eq(entityId))
                        } else {
                            db.insert(row)
                        }
                        return true
                    }
                    
                    fun insert(
                        entity: T${className}
                    ): Boolean {
                        val row = newRow(entity)
                        return db.insert(row) > 0
                    }
                    
                    fun insertAndGetId(
                        entity: T${className}
                    ): Boolean {
                        val idName = ${className}.primaryKey?.columnName
                        val row = newRow(entity)
                        val insertId: Long = db.insertAndGetId(row)
                        if (!idName.isNullOrEmpty()) {
                            val field = T${className}::class.members.find { it.name == idName } as KMutableProperty1<T${className}, Long>
                            field.set(entity, insertId)
                        }
                        return true
                    }
                    
                    fun update(
                        entity: T${className},
                        where: SQLExpression,
                        updateColumns: Collection<Column<*>>
                    ): Int {
                        val row = ${className}.newRow {
                           ${
                               fields.joinToString("\n") {
                                   "if (updateColumns.contains(${className}.${it.fieldName})) " +
                                           "it[${className}.${it.fieldName}] = entity.${it.fieldName}"
                               }
                           }
                        }
                        return db.update(row, where)
                    }

                    fun delete(
                        where: SQLExpression
                    ): Int {
                        return db.deleteWhere(${className}, where)
                    }

                }
                """.replace("                ", "")
        writeSourceFile(classDeclaration, packageName, daoClassname, source)
    }

}

/**
 * Renders this type reference through KotlinPoet's `toTypeName()`.
 *
 * @return the rendered type name, or `null` if the conversion throws anything (e.g. an
 *   unresolvable type); the failure is swallowed deliberately.
 */
private fun KSTypeReference.toTypeNameString(): String? =
    runCatching { toTypeName() }
        .getOrNull()
        ?.toString()
/**
 * Renders this type argument through KotlinPoet's `toTypeName()`.
 *
 * @return the rendered type name, or `null` if the conversion throws anything; the failure is
 *   swallowed deliberately.
 */
private fun KSTypeArgument.toTypeNameString(): String? =
    runCatching { toTypeName() }
        .getOrNull()
        ?.toString()

/**
 * One generated field.
 *
 * Holds the property name, its rendered type, and the Kotlin source expression used as the
 * field's default value. Component order (name, type, default) is relied on by destructuring.
 */
private data class FieldInfo(val fieldName: String, val fieldTypeName: String, val defaultValueFun: String)

/** Logger captured from the KSP environment by the provider; `null` until it has been set. */
private var sLogger: KSPLogger? = null

/** Logs [message] at WARN level through [sLogger]; does nothing while no logger is set. */
fun log(message: Any? = null): Unit = println(message)

/**
 * Sends [message] to [sLogger] as a warning. An empty rendering is sent as "\n", and a `null`
 * message is rendered as "null". The message is only rendered when a logger is present.
 *
 * NOTE: this file-private function shadows `kotlin.io.println` for every call in this file.
 */
private fun println(message: Any? = null) {
    val logger = sLogger ?: return
    val text = message.toString()
    logger.warn(if (text.isEmpty()) "\n" else text)
}
