/*
 * KUtil
 * Copyright (C) 2021-2022 Moritz Zwerger
 *
 * This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along with this program. If not, see <https://www.gnu.org/licenses/>.
 */
package de.bixilon.kutil.latch

import de.bixilon.kutil.time.TimeUtil.millis


/**
 * A latch whose count can be raised and lowered at any time.
 *
 * Threads can block in [await] until the count drops to zero, or in [waitForChange] until the count
 * or [total] changes. Every change of [count] is forwarded as a delta to the optional [parent] latch,
 * so a parent also reflects the changes of its children.
 *
 * All mutable state is guarded by one internal monitor. Timeouts are given in milliseconds; `0` means
 * "wait without timeout". When a timeout expires before the awaited condition holds, an
 * [InterruptedException] with the message `Timeout reached!` is thrown.
 *
 * @param count initial count; must not be negative
 * @param parent latch that receives every count delta of this latch, or `null`
 */
class CountUpAndDownLatch @JvmOverloads constructor(count: Int, val parent: CountUpAndDownLatch? = null) {
    /** Listeners invoked (while holding the monitor) after every assignment of [count]. */
    private val callbacks: MutableSet<() -> Unit> = mutableSetOf()

    /** Monitor guarding all mutable state; also used for wait/notify. */
    private val notify = Object()

    /** Raw count; every increase is also added to [total]. Only touched while holding [notify]. */
    private var _count = 0
        set(value) {
            val diff = value - field
            check(value >= 0) { "Can not set negative count (previous=$field, value=$value)" }
            if (diff > 0) {
                total += diff
            }
            field = value
        }

    /**
     * Current count.
     *
     * Assigning it wakes all waiters, invokes the registered callbacks and forwards the difference to
     * [parent]. The whole update runs while holding this latch's monitor (the same child -> parent
     * lock order [plus] uses), so deltas reach [parent] in the order they were applied here and the
     * parent can not transiently see a decrement before its matching increment.
     */
    var count: Int
        get() = synchronized(notify) { _count }
        set(value) {
            synchronized(notify) {
                val diff = value - _count
                _count = value
                try {
                    notify()
                } finally {
                    // Keep the parent consistent even if a callback threw.
                    parent?.plus(diff)
                }
            }
        }

    /** Sum of all increments ever applied to [count] (including the initial count); never decreases. */
    var total: Int = 0
        get() {
            synchronized(notify) {
                return field
            }
        }
        private set(value) {
            check(value >= 0) { "Total can not be < 0: $value" }
            synchronized(notify) {
                check(value >= field) { "Total can not decrement! (current=$field, wanted=$value)" }
                field = value
            }
        }


    init {
        check(parent !== this)
        this.count += count
    }

    /**
     * Waits on the monitor until [condition] holds. Must be called while holding [notify].
     *
     * The condition is re-checked after every wake-up (including spurious ones) and before giving up,
     * so a change that lands right at the deadline is not reported as a timeout. [timeout] is one
     * overall deadline, measured with the monotonic [System.nanoTime] clock, not a per-wake-up budget.
     *
     * @param timeout maximum time to wait in milliseconds; `0` waits without timeout
     * @throws IllegalArgumentException if [timeout] is negative
     * @throws InterruptedException when the deadline passes first, or if the thread is interrupted
     */
    private inline fun waitUntil(timeout: Long, condition: () -> Boolean) {
        require(timeout >= 0L) { "Timeout can not be negative: $timeout" }
        if (timeout == 0L) {
            while (!condition()) {
                notify.wait()
            }
            return
        }
        val start = System.nanoTime()
        while (!condition()) {
            val remaining = timeout - (System.nanoTime() - start) / 1_000_000L
            if (remaining <= 0L) {
                throw InterruptedException("Timeout reached!")
            }
            notify.wait(remaining)
        }
    }

    /**
     * Blocks until [count] is zero.
     *
     * @param timeout maximum time to wait in milliseconds; `0` waits without timeout
     * @throws InterruptedException when the timeout expires first, or if the thread is interrupted
     */
    @JvmOverloads
    fun await(timeout: Long = 0L) {
        synchronized(notify) {
            waitUntil(timeout) { _count == 0 }
        }
    }

    /** Wakes all waiters, then runs every registered callback while holding the monitor. */
    @JvmName("Notify2")
    private fun notify() {
        synchronized(notify) {
            // Wake waiters first: a throwing callback must not leave them blocked.
            notify.notifyAll()
            // Iterate over a snapshot so a callback may register further callbacks without a CME.
            for (callback in callbacks.toList()) {
                callback.invoke()
            }
        }
    }

    /** Increments the count by one (`latch++`). */
    operator fun inc(): CountUpAndDownLatch {
        plus(1)
        return this
    }

    /** Decrements the count by one (`latch--`). */
    operator fun dec(): CountUpAndDownLatch {
        minus(1)
        return this
    }

    /** Increments the count by one. */
    fun countUp() {
        plus(1)
    }

    /** Decrements the count by one; fails if the count would become negative. */
    fun countDown() {
        minus(1)
    }

    /** Atomically adds [value] (may be negative) to the count and returns this latch. */
    fun plus(value: Int): CountUpAndDownLatch {
        synchronized(notify) {
            count += value
        }
        return this
    }

    /** Atomically subtracts [value] from the count and returns this latch. */
    fun minus(value: Int): CountUpAndDownLatch {
        return plus(-value)
    }


    /**
     * Blocks until [count] or [total] differs from the values seen when this call started.
     *
     * @param timeout maximum time to wait in milliseconds; `0` waits without timeout
     * @throws InterruptedException when the timeout expires first, or if the thread is interrupted
     */
    fun waitForChange(timeout: Long = 0L) {
        synchronized(notify) {
            val lastCount = _count
            val lastTotal = total
            waitUntil(timeout) { _count != lastCount || total != lastTotal }
        }
    }

    /**
     * Blocks until the latch has been counted up at least once ([total] > 0) and [count] is zero
     * again. A single [timeout] deadline covers both phases.
     */
    fun awaitWithChange(timeout: Long = 0L) {
        synchronized(notify) {
            // total never decreases, so once it is non-zero this only waits for the count to drain.
            waitUntil(timeout) { total != 0 && _count == 0 }
        }
    }

    /**
     * Returns immediately if the latch was never counted up ([total] == 0); otherwise behaves like
     * [waitForChange].
     */
    fun awaitOrChange(timeout: Long = 0L) {
        synchronized(notify) {
            if (total == 0) {
                return
            }
            waitForChange(timeout)
        }
    }


    /** Registers [callback] to be invoked after every assignment of [count]. */
    operator fun plusAssign(callback: () -> Unit) {
        synchronized(notify) {
            callbacks += callback
        }
    }

    /**
     * Blocks while [count] is greater than [value].
     *
     * @param timeout overall maximum time to wait in milliseconds; `0` waits without timeout
     * @throws InterruptedException when the timeout expires first, or if the thread is interrupted
     */
    fun waitIfGreater(value: Int, timeout: Long = 0L) {
        synchronized(notify) {
            waitUntil(timeout) { _count <= value }
        }
    }

    /** Returns `"count / total"`, with both values read atomically. */
    override fun toString(): String {
        synchronized(notify) {
            return String.format("%d / %d", _count, total)
        }
    }
}
