@file:JvmName("Tables")
@file:Suppress("NOTHING_TO_INLINE")

package tech.ostack.kform.datatypes

import kotlin.concurrent.Volatile
import kotlin.jvm.JvmName
import kotlin.jvm.JvmSynthetic
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.builtins.ListSerializer
import kotlinx.serialization.builtins.MapSerializer
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder

/** Identifier of a row within a [Table]. */
public typealias TableRowId = Int

/**
 * A single row of a [Table]: a value paired with the identifier the table associates with it.
 *
 * Since this is a data class whose [value] is mutable, both [equals] and [hashCode] change when
 * [value] is reassigned; avoid keeping rows inside hash-based collections while mutating them.
 *
 * @param T Type of the value held by the row.
 * @property id Identifier of the row.
 * @property value Value currently held by the row.
 */
public data class TableRow<T>(
    public val id: TableRowId,
    public var value: T,
)

/**
 * Ordered collection where each value has a stable identifier.
 *
 * Rows are kept in a list (which defines iteration order and row indices) and are additionally
 * indexed by their [TableRowId], so that lookups by identifier ([get], [containsId]) do not need to
 * scan the rows.
 *
 * Identifiers are generated automatically by [add], [addAt], [addAll], and [addAllAt], starting at
 * [nextId]. [nextId] never decreases, so an identifier below it is never generated again, even if
 * its row is removed or the table is [cleared][clear]. Rows may also be created with an explicit
 * identifier through [set].
 */
@Serializable(with = Table.RowsSerializer::class)
public class Table<T>
private constructor(
    // Rows in table order.
    private val list: ArrayList<TableRow<T>>,
    // Index of the rows of `list` by identifier; must always hold exactly the rows of `list`.
    private val map: HashMap<TableRowId, TableRow<T>>,
) {
    /** Creates an empty table. */
    public constructor() : this(ArrayList(), HashMap())

    /** Creates an empty table with room for [initialCapacity] rows before it needs to grow. */
    public constructor(
        initialCapacity: Int
    ) : this(ArrayList(initialCapacity), HashMap(initialCapacity))

    /**
     * Creates a copy of [table]: the same identifiers and values, in the same order, and the same
     * [nextId], so the copy generates the same identifiers as the original would. Row values
     * themselves are not copied.
     */
    public constructor(table: Table<T>) : this(table.size) {
        for ((id, value) in table.rows) {
            this[id] = value
        }
        // Re-inserting the rows only advances `_nextId` past the copied identifiers, while the
        // original may have advanced further (e.g. past rows that were since removed). The
        // original's counter is never an identifier in use, so it cannot collide with a row.
        _nextId = table._nextId
    }

    /** Creates a table from the entries of [map], using each key as the identifier of its row. */
    public constructor(map: Map<TableRowId, T>) : this(map.size) {
        for ((key, value) in map) {
            this[key] = value
        }
    }

    /** Creates a table containing [values] in order, with automatically generated identifiers. */
    public constructor(values: Collection<T>) : this(values.size) {
        addAll(values)
    }

    private var _nextId: TableRowId = 0

    /** Identifier of the next row (when incremented automatically). */
    public val nextId: TableRowId
        get() = _nextId

    /** Advances [_nextId] until it no longer identifies an existing row. */
    private fun computeNextId() {
        while (_nextId in map) {
            ++_nextId
        }
    }

    /** Number of rows in the table. */
    public val size: Int
        get() = list.size

    /**
     * Read-only view of all rows of the table, iterated in table order. The rows are the table's
     * own [TableRow] instances, so assigning their [TableRow.value] updates the table.
     */
    public val rows: Set<TableRow<T>>
        get() = _rows ?: createRowsView().also { _rows = it }

    @Volatile private var _rows: Set<TableRow<T>>? = null

    /** Read-only view of the identifiers of all rows of the table, iterated in table order. */
    public val ids: Set<TableRowId>
        get() = _ids ?: createIdsView().also { _ids = it }

    @Volatile private var _ids: Set<TableRowId>? = null

    /** Read-only view of the values of all rows of the table, iterated in table order. */
    public val values: Collection<T>
        get() = _values ?: createValuesView().also { _values = it }

    @Volatile private var _values: Collection<T>? = null

    /** Creates the view returned by [rows]. */
    private fun createRowsView(): Set<TableRow<T>> =
        object : AbstractSet<TableRow<T>>() {
            override val size: Int
                get() = this@Table.size

            override operator fun contains(element: TableRow<T>): Boolean {
                val rowWithMatchingId = map[element.id]
                return rowWithMatchingId != null && rowWithMatchingId.value == element.value
            }

            override fun iterator(): Iterator<TableRow<T>> = rowIterator { row -> row }
        }

    /** Creates the view returned by [ids]. */
    private fun createIdsView(): Set<TableRowId> =
        object : AbstractSet<TableRowId>() {
            override val size: Int
                get() = this@Table.size

            override operator fun contains(element: TableRowId): Boolean = containsId(element)

            override fun iterator(): Iterator<TableRowId> = rowIterator { row -> row.id }
        }

    /** Creates the view returned by [values]. */
    private fun createValuesView(): Collection<T> =
        object : AbstractCollection<T>() {
            override val size: Int
                get() = this@Table.size

            override operator fun contains(element: T): Boolean = containsValue(element)

            override fun iterator(): Iterator<T> = rowIterator { row -> row.value }
        }

    /**
     * Returns an iterator over the rows of the table, in table order, yielding [transform] of each
     * row. Unlike the internal list's own iterator it offers no `remove` operation, so the [rows],
     * [ids], and [values] views cannot remove a row from the list without also removing it from
     * the identifier index.
     */
    private fun <R> rowIterator(transform: (TableRow<T>) -> R): Iterator<R> {
        val iterator = list.iterator()
        return object : Iterator<R> {
            override fun hasNext(): Boolean = iterator.hasNext()

            override fun next(): R = transform(iterator.next())
        }
    }

    /**
     * Two tables are equal when they contain equal rows (same identifiers and values) in the same
     * order. [nextId] is not taken into account.
     */
    override fun equals(other: Any?): Boolean =
        when {
            this === other -> true
            other !is Table<*> -> false
            else -> list == other.list
        }

    override fun hashCode(): Int = list.hashCode()

    /** Returns the rows of the table as `[id=value, ...]`, in table order. */
    override fun toString(): String =
        list.joinToString(prefix = "[", postfix = "]") { "${it.id}=${it.value}" }

    /** Whether the table is empty. */
    public fun isEmpty(): Boolean = list.isEmpty()

    /** Whether the table contains a row with the provided [id]. */
    public fun containsId(id: TableRowId): Boolean = id in map

    /** Whether the table contains a row with the provided [value] (scans all rows). */
    public fun containsValue(value: T): Boolean = list.any { it.value == value }

    /** Whether the table contains a row with the provided [id]. */
    @JvmSynthetic public inline operator fun contains(id: TableRowId): Boolean = containsId(id)

    /** Whether the table contains a row with the provided [value]. */
    @JvmSynthetic public inline operator fun contains(value: T): Boolean = containsValue(value)

    /** Returns the index of the row with the provided [id] or `-1` when no such row exists. */
    public fun indexOfId(id: TableRowId): Int =
        // Unknown identifiers are answered from the index without scanning the rows.
        if (id in map) list.indexOfFirst { row -> row.id == id } else -1

    /** Returns the index of the row with the provided [value] or `-1` when no such row exists. */
    public fun indexOfValue(value: T): Int = list.indexOfFirst { row -> row.value == value }

    /**
     * Returns the last index of the row with the provided [value] or `-1` when no such row exists.
     */
    public fun lastIndexOfValue(value: T): Int = list.indexOfLast { row -> row.value == value }

    /**
     * Returns the row at the provided [index].
     *
     * @throws IndexOutOfBoundsException If [index] is not in `0 until size`.
     */
    public fun rowAt(index: Int): TableRow<T> = list[index]

    /** Returns the identifier of the row at the provided [index]. */
    public fun idAt(index: Int): TableRowId = rowAt(index).id

    /** Returns the value of the row at the provided [index]. */
    public fun valueAt(index: Int): T = rowAt(index).value

    /** Returns the value of the row with the provided [id], or `null` when no such row exists. */
    public operator fun get(id: TableRowId): T? = map[id]?.value

    /** Replaces the value of [row] with [value], returning the previous value. */
    private fun setRowValue(row: TableRow<T>, value: T): T {
        val oldValue = row.value
        row.value = value
        return oldValue
    }

    /**
     * Sets the value of a row with the provided [id]. If no such row exists, then it will be
     * created (at the end of the table). Returns the value previously at said row or `null` if no
     * such row existed.
     */
    public operator fun set(id: TableRowId, value: T): T? {
        val existingRow = map[id]
        if (existingRow != null) return setRowValue(existingRow, value)

        val row = TableRow(id, value)
        list.add(row)
        map[id] = row
        // The explicit id may be the one `nextId` currently points to.
        computeNextId()
        return null
    }

    /**
     * Sets the value of the row at the provided [index]. If `index == size`, a new row is added.
     * Returns the value previously at said row or `null` if no such row existed.
     *
     * @throws IndexOutOfBoundsException If [index] is not in `0..size`.
     */
    public fun setAt(index: Int, value: T): T? =
        if (index == size) {
            addAt(size, value)
            null
        } else setRowValue(rowAt(index), value)

    /**
     * Adds a new row with the provided [value] to the table. This simply calls [add] while ignoring
     * the returned identifier. If you require knowing the identifier of the newly added row, use
     * [add] instead.
     */
    @JvmSynthetic
    public inline operator fun plusAssign(value: T) {
        add(value)
    }

    /**
     * Adds a new row with the provided [value] to the table. Returns the identifier of the newly
     * added row.
     */
    public fun add(value: T): TableRowId = addAt(size, value)

    /**
     * Adds a new row with the provided [value] at the provided [index]. Returns the identifier of
     * the newly added row.
     *
     * @throws IndexOutOfBoundsException If [index] is not in `0..size`; the table is left unchanged.
     */
    public fun addAt(index: Int, value: T): TableRowId {
        val newId = _nextId
        val row = TableRow(newId, value)
        // Insert into the list first: it validates `index` before the index map is touched.
        list.add(index, row)
        map[newId] = row
        computeNextId()
        return newId
    }

    /**
     * Adds all provided [values] to the table. Returns a list with the identifiers of all newly
     * added rows.
     */
    public fun addAll(values: Collection<T>): List<TableRowId> = addAllAt(size, values)

    /**
     * Adds all provided [values] to the table at the provided [index], preserving their order.
     * Returns a list with the identifiers of all newly added rows.
     *
     * @throws IndexOutOfBoundsException If [index] is not in `0..size`; the table is left unchanged.
     */
    public fun addAllAt(index: Int, values: Collection<T>): List<TableRowId> {
        // Validate up front: each new row is registered in `map` (needed to compute the next id)
        // before the rows are inserted into `list`, so a late failure would desynchronise them.
        if (index < 0 || index > size) {
            throw IndexOutOfBoundsException("Index: $index, Size: $size")
        }
        val newRows = ArrayList<TableRow<T>>(values.size)
        for (value in values) {
            val row = TableRow(_nextId, value)
            map[row.id] = row
            computeNextId()
            newRows += row
        }
        list.addAll(index, newRows)
        return newRows.map { it.id }
    }

    /**
     * Removes the row with the provided [id]. This simply calls [remove] while ignoring the
     * returned element. If you require knowing which element was removed, use [remove] instead.
     */
    @JvmSynthetic
    public inline operator fun minusAssign(id: TableRowId) {
        remove(id)
    }

    /**
     * Removes the row with the provided [id]. Returns the element of the removed row or `null` if
     * no such row exists.
     */
    public fun remove(id: TableRowId): T? {
        // Unknown identifiers are answered from the index without scanning the rows.
        if (id !in map) return null
        return removeAt(list.indexOfFirst { row -> row.id == id }).value
    }

    /**
     * Removes the row at the provided [index]. Returns the removed row.
     *
     * @throws IndexOutOfBoundsException If [index] is not in `0 until size`.
     */
    public fun removeAt(index: Int): TableRow<T> {
        val row = list.removeAt(index)
        map.remove(row.id)
        return row
    }

    /** Clears the table by removing all of its rows. [nextId] is left unchanged. */
    public fun clear() {
        list.clear()
        map.clear()
    }

    /**
     * Returns a sub-list view of the rows of the table from index [fromIndex] (inclusive) to index
     * [toIndex] (exclusive). The view is backed by the table's internal row list (see
     * [List.subList] for the semantics of such views).
     */
    public fun subList(fromIndex: Int, toIndex: Int): List<TableRow<T>> =
        list.subList(fromIndex, toIndex)

    /**
     * Table serialiser that serialises the table as a map, mapping each row id to its respective
     * value. [nextId] is not serialised: a deserialised table derives it from the deserialised
     * identifiers.
     */
    public class RowsSerializer<T>(valueSerializer: KSerializer<T>) : KSerializer<Table<T>> {
        private val mapSerializer = MapSerializer(Int.serializer(), valueSerializer)

        override val descriptor: SerialDescriptor =
            SerialDescriptor("tech.ostack.kform.datatypes.Table", mapSerializer.descriptor)

        override fun serialize(encoder: Encoder, value: Table<T>): Unit =
            encoder.encodeSerializableValue(mapSerializer, value.toMap())

        override fun deserialize(decoder: Decoder): Table<T> =
            decoder.decodeSerializableValue(mapSerializer).toTable()
    }

    /**
     * Table serialiser that serialises the values of the table as a list, ignoring the ids of all
     * table rows. A deserialised table gets freshly generated identifiers.
     */
    public class ValuesSerializer<T>(valueSerializer: KSerializer<T>) : KSerializer<Table<T>> {
        private val listSerializer = ListSerializer(valueSerializer)

        override val descriptor: SerialDescriptor =
            SerialDescriptor("tech.ostack.kform.datatypes.Table", listSerializer.descriptor)

        override fun serialize(encoder: Encoder, value: Table<T>): Unit =
            encoder.encodeSerializableValue(listSerializer, value.values.toList())

        override fun deserialize(decoder: Decoder): Table<T> =
            decoder.decodeSerializableValue(listSerializer).toTable()
    }
}

/**
 * Creates a new, empty table.
 *
 * This zero-argument overload exists so that calling [tableOf] without arguments is not ambiguous
 * between its two `vararg` overloads.
 */
public fun <T> tableOf(): Table<T> {
    return Table<T>()
}

/**
 * Creates a new table holding the provided [values] in order, each under an automatically
 * generated identifier.
 */
public fun <T> tableOf(vararg values: T): Table<T> =
    Table<T>(values.size).also { table ->
        values.forEach { element -> table.add(element) }
    }

/**
 * Creates a new table from [pairs] of identifier and element. Rows appear in the order of their
 * first occurrence; a repeated identifier keeps its position but takes the later element.
 */
public fun <T> tableOf(vararg pairs: Pair<TableRowId, T>): Table<T> {
    val result = Table<T>(pairs.size)
    pairs.forEach { (rowId, element) -> result[rowId] = element }
    return result
}

/**
 * Creates a table containing the elements of this [Iterable] in iteration order, each under an
 * automatically generated identifier.
 */
public fun <T> Iterable<T>.toTable(): Table<T> {
    val result = Table<T>()
    forEach { element -> result.add(element) }
    return result
}

/**
 * Creates a copy of this table with the same identifiers and values in the same order. The copy
 * has its own rows, but the values themselves are shared, not copied.
 */
public fun <T> Table<T>.toTable(): Table<T> = Table(table = this)

/**
 * Creates a table from this [Map], using each key as the identifier of the row holding its value.
 * Rows follow the map's iteration order.
 */
public fun <T> Map<TableRowId, T>.toTable(): Table<T> = Table(map = this)

/**
 * Creates a table containing the elements of this [Collection] in iteration order, each under an
 * automatically generated identifier.
 */
public fun <T> Collection<T>.toTable(): Table<T> = Table(values = this)

/**
 * Creates a table containing the elements of this [Array] in order, each under an automatically
 * generated identifier.
 */
public fun <T> Array<T>.toTable(): Table<T> =
    Table<T>(size).also { table ->
        forEach { element -> table.add(element) }
    }

/**
 * Creates a new mutable [Map] from this table, mapping each row identifier to its value. The map
 * iterates its entries in table order.
 */
public fun <T> Table<T>.toMap(): MutableMap<TableRowId, T> =
    rows.associateTo(LinkedHashMap<TableRowId, T>(size)) { row -> row.id to row.value }
