package cn.cloudself.query.resolver

import cn.cloudself.query.NULL
import cn.cloudself.query.QueryProTransaction
import cn.cloudself.query.config.DbColumnInfo
import cn.cloudself.query.config.QueryProConfig
import cn.cloudself.query.exception.ConfigException
import cn.cloudself.query.exception.IllegalCall
import cn.cloudself.query.exception.UnSupportException
import cn.cloudself.query.util.*
import org.springframework.jdbc.datasource.DataSourceUtils
import org.springframework.transaction.support.TransactionSynchronizationManager
import java.math.BigDecimal
import java.sql.*
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.LocalTime
import java.util.*
import javax.sql.DataSource
import kotlin.jvm.Throws

/**
 * JDBC implementation of the `IQueryStructureResolver` interface.
 *
 * Connections come from [getConnection] and are released through [autoUse]. A connection taking part in an
 * active Spring transaction or `QueryProTransaction` is left open; any other connection is closed when the
 * operation finishes. Statements and result sets created here are always closed.
 */
class JdbcQSR: QSR() {
    /**
     * Callback that consumes a raw [ResultSet].
     *
     * When `doSelect` is asked for a concrete class implementing this interface, it creates an instance through
     * the class's no-arg constructor, passes the query's result set to [walk], and returns that instance as the
     * single element of the result list. The result set is closed once [walk] returns.
     */
    fun interface IResultSetWalker {
        @Throws(Exception::class)
        fun walk(rs: ResultSet)
    }

    /**
     * Executes a query and maps every row of the result to an instance of [clazz].
     *
     * A Kotlin `null` parameter is rejected ([OnNull.BREAK]). If [clazz] is a concrete implementation of
     * [IResultSetWalker], rows are not mapped; a walker instance is fed the result set instead.
     */
    override fun <T : Any> doSelect(queryPro: Class<*>?, sql: String, params: Array<out Any?>, clazz: Class<T>): List<T> {
        return getConnection(queryPro).autoUse { connection ->
            // Close the statement explicitly: inside an active transaction autoUse keeps the connection open,
            // so an unclosed statement would otherwise live until the connection itself is released.
            connection.prepareStatement(sql).use { preparedStatement ->
                setParam(preparedStatement, params, OnNull.BREAK)
                preparedStatement.executeQuery().use { resultSet ->
                    if (IResultSetWalker::class.java != clazz && IResultSetWalker::class.java.isAssignableFrom(clazz)) {
                        val walker = clazz.getDeclaredConstructor().also { it.isAccessible = true }.newInstance()
                        Reflect.of(walker).invoke("walk", resultSet)
                        listOf(walker)
                    } else {
                        val proxy = BeanProxy.fromClass(clazz)
                        val resultList = mutableListOf<T>()
                        while (resultSet.next()) {
                            resultList.add(mapRow(proxy, resultSet))
                        }
                        resultList
                    }
                }
            }
        }
    }

    /**
     * Executes a data-modifying statement via `executeUpdate` and reports the affected row count.
     *
     * A Kotlin `null` parameter is rejected ([OnNull.BREAK]).
     *
     * @return `updatedCount > 0` when [clazz] is boolean-compatible, the raw count when it is int-compatible.
     * @throws UnSupportException for any other [clazz]. This is checked before a connection is acquired, so the
     * statement is not executed in that case.
     */
    override fun <T : Any> doUpdate(queryPro: Class<*>?, sql: String, params: Array<out Any?>, clazz: Class<T>): T {
        // Resolve the conversion first: an unsupported return type must not fail only after the write has run.
        @Suppress("UNCHECKED_CAST")
        val toResult: (Int) -> T = when {
            clazz.compatibleWithBool() -> { updatedCount -> (updatedCount > 0) as T }
            clazz.compatibleWithInt() -> { updatedCount -> updatedCount as T }
            else -> throw UnSupportException("不支持的class, 目前只支持List::class.java, listOf<Int>().javaClass, Int, Boolean")
        }

        return getConnection(queryPro).autoUse { connection ->
            val updatedCount = connection.prepareStatement(sql).use { preparedStatement ->
                setParam(preparedStatement, params, OnNull.BREAK)
                preparedStatement.executeUpdate()
            }
            toResult(updatedCount)
        }
    }

    /**
     * Executes an insert statement with `RETURN_GENERATED_KEYS` and collects the generated keys.
     *
     * A Kotlin `null` parameter is bound as SQL NULL ([OnNull.NULL]). When [clazz] is `null`, no generated-key
     * row is mapped and an empty list is returned. Otherwise each generated-key row is mapped to [clazz].
     */
    override fun <ID : Any> doInsert(queryPro: Class<*>?, sql: String, params: Array<out Any?>, clazz: Class<ID>?): List<ID> {
        val idColumnProxy = if (clazz == null) null else BeanProxy.fromClass(clazz)

        return getConnection(queryPro).autoUse { connection ->
            connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS).use { preparedStatement ->
                setParam(preparedStatement, params, OnNull.NULL)
                preparedStatement.execute()
                preparedStatement.generatedKeys.use { resultSet ->
                    val results = mutableListOf<ID>()
                    if (idColumnProxy != null) {
                        while (resultSet.next()) {
                            results.add(mapRow(idColumnProxy, resultSet))
                        }
                    }
                    results
                }
            }
        }
    }

    /**
     * Reads the columns of [table] (searched among tables and views) from the JDBC database metadata.
     * The search uses the connection's current catalog and schema.
     *
     * Only a single-column primary key is recognised. For the column whose name equals that key, the third
     * `Column` argument is `true`. A composite key is not supported: no column gets `true`, and one warning
     * is logged. Each column getter only accepts `MutableMap` rows and throws [IllegalCall] otherwise.
     *
     * @throws ConfigException if no table or view matches [table].
     */
    override fun getColumnsDynamic(table: String): Collection<Column> {
        return getConnection(null).autoUse { connection ->
            val metaData = connection.metaData
            val catalog = connection.catalog
            val schema = connection.schema

            // metadata https://dev.mysql.com/doc/refman/8.0/en/show-columns.html
            // NOTE(review): getTables/getColumns treat the name as a LIKE pattern, so `_` and `%` act as wildcards.
            val tableName = metaData.getTables(catalog, schema, table, arrayOf("TABLE", "VIEW")).use { tableSet ->
                // Fail with a clear message instead of reading from a result set that has no current row.
                if (!tableSet.next()) {
                    throw ConfigException("找不到表: {0}", table)
                }
                tableSet.getString("TABLE_NAME")
            }

            val id = metaData.getPrimaryKeys(catalog, schema, tableName).use { primaryKeys ->
                val keyColumns = mutableListOf<String?>()
                while (primaryKeys.next()) {
                    keyColumns.add(primaryKeys.getString("COLUMN_NAME"))
                }
                if (keyColumns.size > 1) {
                    logger.warn("[WARN] 目前仍不支持复合主键")
                    null
                } else {
                    keyColumns.firstOrNull()
                }
            }

            metaData.getColumns(catalog, schema, tableName, null).use { columnSet ->
                val columns = mutableListOf<Column>()
                while (columnSet.next()) {
                    val columnName = columnSet.getString("COLUMN_NAME") ?: throw RuntimeException("找不到列名")
                    val column = Column(
                        columnName,
                        getter = {
                            @Suppress("UNCHECKED_CAST")
                            if (it is MutableMap<*, *>) (it as MutableMap<String, Any?>)[columnName] else throw IllegalCall("不支持非Map类型")
                        },
                        columnName == id
                    )
                    columns.add(column)
                }
                columns
            }
        }
    }

    /** How [setParam] treats a Kotlin `null` parameter. The `NULL` marker object is always bound as SQL NULL. */
    enum class OnNull {
        /** Reject `null` by throwing [UnSupportException]. */
        BREAK,
        /** Bind `null` as SQL NULL. */
        NULL,
    }

    /**
     * Binds [params] to [preparedStatement] by position (JDBC parameter indexes start at 1).
     *
     * Both `java.sql.Date` and `java.util.Date` are bound as timestamps, and enums are bound by `name`.
     * `Time` and `Timestamp` are matched before the more general `Date` branches because they subclass
     * `java.util.Date`.
     *
     * @throws UnSupportException for a parameter of an unsupported type, or for `null` unless [onNull] is
     * [OnNull.NULL].
     */
    private fun setParam(preparedStatement: PreparedStatement, params: Array<out Any?>, onNull: OnNull) {
        for ((i, param) in params.withIndex()) {
            val index = i + 1
            when (param) {
                NULL              -> preparedStatement.setNull(index, Types.NULL)
                is BigDecimal     -> preparedStatement.setBigDecimal(index, param)
                is Boolean        -> preparedStatement.setBoolean(index, param)
                is Byte           -> preparedStatement.setByte(index, param)
                is ByteArray      -> preparedStatement.setBytes(index, param)
                is Time           -> preparedStatement.setTime(index, param)
                is Timestamp      -> preparedStatement.setTimestamp(index, param)
                is java.sql.Date  -> preparedStatement.setTimestamp(index, Timestamp(param.time))
                is java.util.Date -> preparedStatement.setTimestamp(index, Timestamp(param.time))
                is Double         -> preparedStatement.setDouble(index, param)
                is Enum<*>        -> preparedStatement.setString(index, param.name)
                is Float          -> preparedStatement.setFloat(index, param)
                is Int            -> preparedStatement.setInt(index, param)
                is LocalDate      -> preparedStatement.setDate(index, java.sql.Date.valueOf(param))
                is LocalTime      -> preparedStatement.setTime(index, Time.valueOf(param))
                is LocalDateTime  -> preparedStatement.setTimestamp(index, Timestamp.valueOf(param))
                is Long           -> preparedStatement.setLong(index, param)
                is Short          -> preparedStatement.setShort(index, param)
                is String         -> preparedStatement.setString(index, param)
                else -> {
                    if (param == null && onNull == OnNull.NULL) {
                        preparedStatement.setNull(index, Types.NULL)
                    } else {
                        throw UnSupportException("equalsTo, in, between等操作传入了不支持的类型{0}", param)
                    }
                }
            }
        }
    }

    /**
     * Creates a new instance through [proxy] and fills it from the current row of [resultSet].
     *
     * The target type of each column (matched by its label) is chosen as follows:
     * 1. The bean's property type.
     * 2. If the proxy reports none, the first matching `dbColumnInfoToJavaType` rule.
     * 3. If there is still no type, the value is read with `getObject`.
     *
     * A typed value is read by trying, in order:
     * 1. The `resultSetParser` registered for the target type.
     * 2. The `resultSetParserEx` chain.
     * 3. `getObject`, when the driver's column class is assignable to the target type.
     *
     * SQL NULL values (per `wasNull`) are stored as `null`.
     *
     * @throws ConfigException when no strategy can produce the target type.
     */
    @Suppress("UNCHECKED_CAST")
    private fun <T: Any> mapRow(proxy: BeanProxy<T>, resultSet: ResultSet): T {
        val resultProxy = proxy.newInstance()

        val metaData = resultSet.metaData
        val columnCount = metaData.columnCount

        for (i in 1..columnCount) {
            val columnName = metaData.getColumnLabel(i)
            val columnType = metaData.getColumnTypeName(i)
            var beanNeedType = resultProxy.getPropertyType(columnName)

            if (beanNeedType == null) {
                for ((tester, jt) in QueryProConfig.final.dbColumnInfoToJavaType()) {
                    if (tester(DbColumnInfo(columnType, columnName))) {
                        beanNeedType = jt
                        break
                    }
                }
            }

            val value = if (beanNeedType == null) {
                resultSet.getObject(i)
            } else {
                val parser = QueryProConfig.final.resultSetParser(beanNeedType)
                if (parser != null) {
                    parser.get(resultSet, i) /* value */
                } else {
                    var valueOpt: Optional<Any>? = null
                    for (resultSetParserEx in QueryProConfig.final.resultSetParserEx()) {
                        val valueOptMay = resultSetParserEx.parse(resultSet, beanNeedType, i)
                        if (valueOptMay.isPresent) {
                            valueOpt = valueOptMay
                            break
                        }
                    }
                    if (valueOpt != null) {
                        valueOpt.get()
                    } else {
                        // No configured parser for the target type: fall back to the driver's default Java type
                        // if it is assignable to the target type, otherwise fail.
                        val couldConvertClassName = metaData.getColumnClassName(i)
                        if (beanNeedType.isAssignableFrom(Class.forName(couldConvertClassName))) {
                            resultSet.getObject(i)
                        } else {
                            throw ConfigException("不支持将name: {0}, type: {1}转换为{2}, " +
                                    "使用QueryProConfig.global.addResultSetParser添加转换器",
                                columnName, columnType, beanNeedType.name)
                        }
                    }
                }
            }

            resultProxy.setProperty(columnName, if (resultSet.wasNull()) null else value)
        }

        return resultProxy.toBean() as T
    }

    /**
     * Resolves the DataSource for [clazz] and obtains a connection from it.
     *
     * If none is configured, a `DataSource` bean is looked up through `SpringUtils` and passed to
     * `QueryProConfig.global.dataSource` (presumably registering it for later calls).
     *
     * The connection comes from the first applicable source:
     * 1. The Spring transactional connection (`DataSourceUtils`), when a Spring transaction is active.
     * 2. The `QueryProTransaction` connection, when that transaction is active.
     * 3. A new connection from the DataSource.
     *
     * @throws ConfigException if no DataSource can be found.
     */
    private fun getConnection(clazz: Class<*>?): Connection {
        var dataSource = QueryProConfig.final.dataSource(clazz)
        if (dataSource == null) {
            dataSource = try {
                SpringUtils.getBean(DataSource::class.java)
            } catch (e: NoClassDefFoundError) {
                // Presumably Spring classes are absent from the classpath; treat as "no bean".
                null
            } ?: throw ConfigException("无法找到DataSource, 使用QueryProConfig.setDataSource设置")
            QueryProConfig.global.dataSource(dataSource)
        }
        return when {
            isDataSourceUtilsPresent && TransactionSynchronizationManager.isActualTransactionActive() ->
                DataSourceUtils.getConnection(dataSource)
            QueryProTransaction.isActualTransactionActive.get() ->
                QueryProTransaction.getConnection(dataSource)
            else -> dataSource.connection.also { logger.debug("connection got.") }
        }
    }

    /**
     * Runs [block] with this resource, then closes it only when [shouldClose] says no transaction is active.
     * A connection that belongs to a transaction must stay open for the transaction's later statements.
     */
    private inline fun <T : AutoCloseable?, R> T.autoUse(block: (T) -> R): R {
        return if (shouldClose()) {
            val used = use(block)
            logger.debug("connection closed.")
            used
        } else {
            block(this)
        }
    }

    /** `true` when neither a Spring transaction (if Spring JDBC is present) nor a `QueryProTransaction` is active. */
    private fun shouldClose(): Boolean {
        // Short-circuit so Spring classes are only touched when DataSourceUtils is on the classpath.
        val springTransactionActive =
            isDataSourceUtilsPresent && TransactionSynchronizationManager.isActualTransactionActive()
        return !springTransactionActive && !QueryProTransaction.isActualTransactionActive.get()
    }

    companion object {
        private val logger = LogFactory.getLog(JdbcQSR::class.java)

        /** Whether Spring's `DataSourceUtils` can be loaded; Spring-specific calls are guarded by this flag. */
        private val isDataSourceUtilsPresent = try {
            Class.forName("org.springframework.jdbc.datasource.DataSourceUtils")
            true
        } catch (e: Throwable) {
            false
        }
    }
}
