/*
 * This file is part of LibEuFin.
 * Copyright (C) 2025 Taler Systems S.A.
 *
 * LibEuFin is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation; either version 3, or
 * (at your option) any later version.
 *
 * LibEuFin is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
 * Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public
 * License along with LibEuFin; see the file COPYING.  If not, see
 * <http://www.gnu.org/licenses/>
 */

package tech.libeufin.common.db

import org.slf4j.Logger
import org.slf4j.LoggerFactory
import org.postgresql.util.PSQLState
import tech.libeufin.common.*
import java.sql.*
import java.time.*
import java.util.*

/** Name under which this package's database helpers log. */
private const val DB_LOGGER_NAME = "libeufin-db"

/** Logger shared by the database helpers of this package. */
internal val logger: Logger = LoggerFactory.getLogger(DB_LOGGER_NAME)

/**
 * Wrapper around a JDBC [PreparedStatement] that binds parameters positionally and
 * offers helpers to execute the statement and read its results.
 *
 * Parameters are bound in call order through the `bind` overloads, which advance an
 * internal 1-based cursor. After every execution the SQL warnings reported by the
 * statement are logged, the bound parameters are cleared and the cursor is reset to 1,
 * so the same instance can be bound and executed again.
 */
class TalerStatement(internal val stmt: PreparedStatement): java.io.Closeable {
    /** Index of the next JDBC parameter to bind (JDBC parameter indices start at 1). */
    private var idx = 1

    /** Closes the underlying [PreparedStatement]. */
    override fun close() {
        stmt.close()
    }

    /**
     * Post-execution cleanup: logs every SQL warning attached to the statement,
     * clears the bound parameters and resets the binding cursor.
     */
    private fun consume() {
        var current = stmt.warnings
        while (current != null) {
            logger.warn(current.message)
            current = current.nextWarning
        }
        stmt.clearParameters()
        idx = 1
    }

    /* ----- Bindings helpers ----- */

    /** Binds [string]; `null` is passed to `setString` as-is. Uses one parameter. */
    fun bind(string: String?) {
        stmt.setString(idx, string)
        idx += 1
    }

    /** Binds [bool]. Uses one parameter. */
    fun bind(bool: Boolean) {
        stmt.setBoolean(idx, bool)
        idx += 1
    }

    /** Binds [nb] as a long, or SQL NULL when `null`. Uses one parameter. */
    fun bind(nb: Long?) {
        if (nb != null) {
            stmt.setLong(idx, nb)
        } else {
            // BIGINT matches the type used by setLong for non-null values
            stmt.setNull(idx, Types.BIGINT)
        }
        idx += 1
    }

    /** Binds [nb]. Uses one parameter. */
    fun bind(nb: Int) {
        stmt.setInt(idx, nb)
        idx += 1
    }

    /** Binds [amount] through the value returned by `amount.number()`; `null` is forwarded as-is. */
    fun bind(amount: TalerAmount?) {
        bind(amount?.number())
    }

    /**
     * Binds [nb] as two consecutive parameters: `nb.value` (long) then `nb.frac` (int).
     * When [nb] is `null`, both parameters are bound to SQL NULL so that the following
     * parameters stay correctly aligned.
     */
    fun bind(nb: DecimalNumber?) {
        if (nb != null) {
            stmt.setLong(idx, nb.value)
            stmt.setInt(idx + 1, nb.frac)
        } else {
            stmt.setNull(idx, Types.BIGINT)
            stmt.setNull(idx + 1, Types.INTEGER)
        }
        idx += 2
    }

    /** Binds [timestamp] as the long returned by `timestamp.micros()`. Uses one parameter. */
    fun bind(timestamp: Instant) {
        stmt.setLong(idx, timestamp.micros())
        idx += 1
    }

    /** Binds the raw bytes of [bytes], or `null`. Uses one parameter. */
    fun bind(bytes: Base32Crockford64B?) {
        stmt.setBytes(idx, bytes?.raw)
        idx += 1
    }

    /** Binds the raw bytes of [bytes], or `null`. Uses one parameter. */
    fun bind(bytes: Base32Crockford32B?) {
        stmt.setBytes(idx, bytes?.raw)
        idx += 1
    }

    /** Binds the raw bytes of [bytes], or `null`. Uses one parameter. */
    fun bind(bytes: Base32Crockford16B?) {
        stmt.setBytes(idx, bytes?.raw)
        idx += 1
    }

    /** Binds [bytes], or `null`. Uses one parameter. */
    fun bind(bytes: ByteArray?) {
        stmt.setBytes(idx, bytes)
        idx += 1
    }

    /** Binds the [Enum.name] of [enum] as a string, or `null`. Uses one parameter. */
    fun <T : Enum<T>> bind(enum: T?) {
        bind(enum?.name)
    }

    /** Binds [date] via `setObject`. Uses one parameter. */
    fun bind(date: LocalDateTime) {
        stmt.setObject(idx, date)
        idx += 1
    }

    /** Binds [uuid] via `setObject`, or `null`. Uses one parameter. */
    fun bind(uuid: UUID?) {
        stmt.setObject(idx, uuid)
        idx += 1
    }

    /**
     * Binds [array] as a SQL `text[]` of the enum constants' [Enum.name]s, consistent
     * with the scalar enum binding (the driver would otherwise use `toString()`).
     * Uses one parameter.
     */
    fun <T : Enum<T>> bind(array: Array<T>) {
        val names: Array<String> = Array(array.size) { i -> array[i].name }
        val sqlArray = stmt.connection.createArrayOf("text", names)
        stmt.setArray(idx, sqlArray)
        idx += 1
    }

    /** Binds [array] as a SQL `text[]`. Uses one parameter. */
    fun bind(array: Array<String>) {
        val sqlArray = stmt.connection.createArrayOf("text", array)
        stmt.setArray(idx, sqlArray)
        idx += 1
    }

    /** Binds [array] as a SQL `uuid[]`. Uses one parameter. */
    fun bind(array: Array<UUID>) {
        val sqlArray = stmt.connection.createArrayOf("uuid", array)
        stmt.setArray(idx, sqlArray)
        idx += 1
    }

    /* ----- Transaction helpers ----- */

    /** Executes the statement as a query; parameters are reset whether or not it succeeds. */
    fun executeQuery(): ResultSet {
        return try {
            stmt.executeQuery()
        } finally {
            consume()
        }
    }

    /** Executes the statement as an update and returns the update count; parameters are reset afterwards. */
    fun executeUpdate(): Int {
        return try {
            stmt.executeUpdate()
        } finally {
            consume()
        }
    }

    /** Reads the first row with [lambda], or returns `null` if the query returns no row. */
    fun <T> oneOrNull(lambda: (ResultSet) -> T): T? {
        return executeQuery().use {
            if (it.next()) lambda(it) else null
        }
    }

    /**
     * Reads the first row with [lambda].
     *
     * A `null` returned by [lambda] for an existing row is passed through unchanged.
     *
     * @throws IllegalArgumentException if the query returns no row.
     */
    fun <T> one(lambda: (ResultSet) -> T): T {
        return executeQuery().use {
            require(it.next()) { "Missing result to database query" }
            lambda(it)
        }
    }

    /** Reads one row with [lambda], or returns [err] in case of a unique violation error. */
    fun <T> oneUniqueViolation(err: T, lambda: (ResultSet) -> T): T {
        return try {
            one(lambda)
        } catch (e: SQLException) {
            if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return err
            throw e // rethrowing, not to hide other types of errors.
        }
    }

    /** Reads all rows, mapping each one with [lambda]. */
    fun <T> all(lambda: (ResultSet) -> T): List<T> {
        return executeQuery().use {
            val ret = mutableListOf<T>()
            while (it.next()) {
                ret.add(lambda(it))
            }
            ret
        }
    }

    /** Executes a query and returns whether it produced at least one row. */
    fun executeQueryCheck(): Boolean {
        return executeQuery().use {
            it.next()
        }
    }

    /** Executes an update and returns whether it updated at least one row. */
    fun executeUpdateCheck(): Boolean = executeUpdate() > 0

    /**
     * Executes an update and returns whether it updated at least one row.
     * Returns `false` instead of throwing on a unique violation error; other errors are rethrown.
     */
    fun executeUpdateViolation(): Boolean {
        return try {
            executeUpdateCheck()
        } catch (e: SQLException) {
            logger.debug(e.message)
            if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return false
            throw e // rethrowing, not to hide other types of errors.
        }
    }

    /**
     * Executes an update inside a savepoint.
     *
     * On success the savepoint is released and `true` is returned. On a [SQLException]
     * the connection is rolled back to the savepoint, so the enclosing transaction stays
     * usable. Then `false` is returned for a unique violation error, and the exception is
     * rethrown otherwise. If the rollback itself fails, the original error is rethrown
     * with the rollback failure attached as suppressed.
     */
    fun executeProcedureViolation(): Boolean {
        val conn = stmt.connection
        val savepoint = conn.setSavepoint()
        return try {
            executeUpdate()
            conn.releaseSavepoint(savepoint)
            true
        } catch (e: SQLException) {
            try {
                conn.rollback(savepoint)
            } catch (rollbackError: SQLException) {
                // Do not let the rollback failure hide the error that caused it
                e.addSuppressed(rollbackError)
                throw e
            }
            if (e.sqlState == PSQLState.UNIQUE_VIOLATION.state) return false
            throw e // rethrowing, not to hide other types of errors.
        }
    }
}