"""
Hardened connection pool for py_opengauss inspired by dbutils.pooled_db.

Key differences from the lightweight pool:
* Uses a proxy wrapper so connections are always returned (close/__del__/with).
* Resets sessions on return (rollback + optional setsession).
* Avoids deadlocks by using an RLock and separating discard-with-lock paths.
* Supports blocking/non-blocking acquire with timeout and max usage cycling.
* Uses monotonic clocks for all interval calculations to avoid wall-clock jumps.
"""

from __future__ import annotations

import logging
import threading
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Deque, Dict, Optional, Sequence

import py_opengauss

logger = logging.getLogger(__name__)


class PoolError(Exception):
    """Root of every error raised by the connection pool; catch this to handle them all."""


class PoolClosedError(PoolError):
    """Raised when the pool is used after :meth:`close` has been called."""


class PoolExhaustedError(PoolError):
    """
    Raised when no connection can be handed out: the pool is at ``max_size``
    with blocking disabled, or the acquire timeout elapsed while waiting.
    """


@dataclass
class ConnectionPoolConfig:
    """
    Settings consumed by ``OpenGaussConnectionPool``.

    All durations are in seconds. For ``idle_timeout``, ``max_lifetime`` and
    ``max_usage`` a falsy value (0 / None) disables the corresponding limit.
    """

    # Target passed as the first argument to ``py_opengauss.open``.
    dsn: str

    # Sizing: connections opened eagerly / hard cap on open connections.
    min_size: int = 1
    max_size: int = 10

    # Acquire behaviour: wait for a free connection when at capacity, and the
    # default wait limit used by ``borrow`` when no timeout is given (None = no limit).
    blocking: bool = True
    acquire_timeout: Optional[float] = None

    # Recycling limits: seconds since last use, seconds since creation,
    # and number of borrows before a connection is discarded.
    idle_timeout: float = 300.0
    max_lifetime: float = 3600.0
    max_usage: Optional[int] = None

    # Validation: run ``test_sql`` before handing a connection out.
    test_on_borrow: bool = True
    test_sql: str = "SELECT 1"

    # Background checks: ping idle connections whose last successful check is
    # older than ``keepalive_interval``; the housekeeper wakes every
    # ``health_check_interval`` seconds (never more often than once a second).
    keepalive: bool = True
    keepalive_interval: float = 60.0
    health_check_interval: float = 30.0

    # Session reset on return: roll back, then execute each ``setsession`` statement.
    reset_on_return: bool = True
    setsession: Optional[Sequence[str]] = None

    # Extra keyword arguments forwarded to ``py_opengauss.open``.
    connect_kwargs: Optional[Dict[str, Any]] = None


@dataclass
class _ConnectionEntry:
    """Bookkeeping record the pool keeps for one physical connection."""

    # The underlying py_opengauss connection object.
    conn: Any
    # time.monotonic() timestamps: when opened, last handed out / returned,
    # and last successful validation query.
    created_at: float
    last_used: float
    last_check: float
    # How many times the connection has been borrowed.
    usage_count: int = 0


class _PooledConnectionProxy:
    """
    Lightweight proxy that returns connections to the pool on close/del.

    Attribute access is forwarded to the wrapped connection until the proxy has
    been returned. After that, forwarded access raises ``PoolError``, because
    the underlying connection may already have been lent to another borrower.
    """

    def __init__(self, pool: "OpenGaussConnectionPool", entry: "_ConnectionEntry"):
        self._pool = pool
        self._entry = entry
        self._returned = False

    def __getattr__(self, item: str) -> Any:
        # Only invoked for names not found on the proxy itself. Internal state
        # is read through __dict__ so that a half-initialised proxy raises
        # AttributeError instead of recursing through __getattr__ forever.
        state = self.__dict__
        entry = state.get("_entry")
        if entry is None:
            raise AttributeError(item)
        if state.get("_returned", False):
            raise PoolError(
                f"Cannot access {item!r}: connection has already been returned to the pool"
            )
        return getattr(entry.conn, item)

    def __enter__(self) -> "_PooledConnectionProxy":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # An exception inside the with-body marks the connection as broken.
        self.close(broken=bool(exc_type))

    def close(self, broken: bool = False) -> None:
        """Give the connection back to the pool (idempotent); discard it if *broken*."""
        if self._returned:
            return
        self._returned = True
        self._pool._return_entry(self._entry, broken=broken)

    def mark_broken(self) -> None:
        """Explicitly mark this connection as broken so it gets discarded."""
        self.close(broken=True)

    def __del__(self):
        # Best-effort return to pool to avoid leaks if user forgets to close.
        # Proxies whose __init__ never completed have nothing to return.
        if "_returned" not in self.__dict__:
            return
        try:
            self.close()
        except Exception:
            # Finalizers must not raise; failures here are deliberately ignored.
            pass

    def __repr__(self) -> str:
        return f"<PooledOpenGaussConnection id={id(self._entry.conn)} returned={self._returned}>"


class OpenGaussConnectionPool:
    """
    Hardened, thread-safe connection pool for py_opengauss.

    Connections are wrapped in a proxy that ensures return to the pool on close
    or object finalization. Pool operations rely on monotonic clocks to avoid
    issues from system clock jumps.

    Locking discipline: the bookkeeping state (``_available``, ``_in_use``,
    ``_total`` and the counters) is only touched while holding ``_condition``.
    Per-connection network I/O (open, validate, reset, close) runs without the
    lock, so one slow server round-trip does not stall every other thread.
    ``_total`` counts every connection the pool is responsible for, including
    slots reserved for connections that are still being opened.
    """

    # Longest pause (seconds) borrow() takes after failing to open or validate a
    # brand-new connection before retrying. Bounded so that a blocking borrow
    # keeps retrying while the server is down, instead of spinning or waiting
    # forever for a notification that never arrives.
    _create_retry_delay: float = 1.0

    def __init__(self, config: ConnectionPoolConfig):
        """
        Validate *config*, pre-open ``min_size`` connections and start the
        background housekeeping thread.

        Raises:
            ValueError: If the size limits are inconsistent.
        """
        if config.min_size < 0 or config.max_size <= 0:
            raise ValueError("Pool sizes must be positive")
        if config.min_size > config.max_size:
            raise ValueError("min_size cannot be greater than max_size")

        self._config = config
        self._lock = threading.RLock()
        self._condition = threading.Condition(self._lock)
        self._available: Deque[_ConnectionEntry] = deque()
        self._in_use: Dict[int, _ConnectionEntry] = {}
        self._total = 0
        self._create_failures = 0
        self._discarded = 0
        self._closed = False
        self._stop_event = threading.Event()

        self._housekeeper = threading.Thread(
            target=self._housekeeping, name="opengauss-pool-housekeeper", daemon=True
        )

        self._initial_fill()
        self._housekeeper.start()

    def borrow(self, timeout: Optional[float] = None) -> _PooledConnectionProxy:
        """
        Borrow a connection from the pool.

        Idle pooled connections are reused first. Otherwise a new one is opened
        if ``max_size`` allows. If test_on_borrow is enabled, the connection is
        validated with the configured SQL before being returned.

        Args:
            timeout: Maximum seconds to wait. Defaults to
                ``config.acquire_timeout``; ``None`` means no limit.

        Returns:
            A proxy that gives the connection back to the pool when closed.

        Raises:
            PoolClosedError: The pool is, or becomes, closed.
            PoolExhaustedError: ``blocking`` is False and the pool is at
                capacity, or the timeout elapsed.
            PoolError: ``blocking`` is False and a new connection could not be
                opened or failed validation.
        """
        effective_timeout = self._config.acquire_timeout if timeout is None else timeout
        deadline = None if effective_timeout is None else time.monotonic() + effective_timeout

        while True:
            entry = self._try_acquire_available()
            if entry is not None:
                proxy = self._prepare_borrow(entry)
                if proxy is not None:
                    return proxy
                # The pooled connection failed validation and was discarded.
                continue

            if self._reserve_slot_for_new():
                entry = self._create_entry(reserved=True)
                proxy = self._prepare_borrow(entry) if entry is not None else None
                if proxy is not None:
                    return proxy
                # Opening or validating the fresh connection failed; its slot has
                # already been released. Back off instead of looping tightly.
                if not self._config.blocking:
                    raise PoolError("Failed to open a new connection")
                self._wait_for_availability(deadline, max_wait=self._create_retry_delay)
                continue

            # Pool is at capacity and nothing available.
            if not self._config.blocking:
                raise PoolExhaustedError("Pool is at capacity and blocking is disabled")

            self._wait_for_availability(deadline)

    def return_connection(self, conn: Any, had_error: bool = False) -> None:
        """
        Return a connection (or proxy) to the pool.

        If had_error is True, the connection will be closed and removed to
        avoid reusing broken connections. A raw connection that the pool does
        not currently have checked out is closed (best effort).
        """
        if isinstance(conn, _PooledConnectionProxy):
            conn.close(broken=had_error)
            return

        # Only look the entry up here: _return_entry performs the removal from
        # _in_use itself and ignores entries that are not checked out.
        with self._condition:
            entry = self._in_use.get(id(conn))

        if entry is None:
            try:
                conn.close()
            except Exception:
                logger.debug("Failed closing unmanaged connection", exc_info=True)
            return
        self._return_entry(entry, broken=had_error)

    @contextmanager
    def connection(self, timeout: Optional[float] = None):
        """
        Context manager helper that returns the connection to the pool
        automatically and discards it if an exception is raised.
        """
        proxy = self.borrow(timeout=timeout)
        broken = True
        try:
            yield proxy
            broken = False
        finally:
            # Any exception escaping the body, including BaseException such as
            # KeyboardInterrupt, leaves the session state unknown, so discard.
            proxy.close(broken=broken)

    def close(self) -> None:
        """Close the pool and all managed connections (idle and checked out)."""
        with self._condition:
            if self._closed:
                return
            self._closed = True
            to_close = list(self._available)
            to_close.extend(self._in_use.values())
            self._available.clear()
            self._in_use.clear()
            self._total = 0
            self._condition.notify_all()
            self._stop_event.set()
        for entry in to_close:
            self._close_entry(entry)
        if self._housekeeper.is_alive() and self._housekeeper is not threading.current_thread():
            self._housekeeper.join(timeout=1.0)

    def stats(self) -> Dict[str, Any]:
        """Lightweight snapshot of pool counters for observability/testing."""
        with self._condition:
            return {
                "total": self._total,
                "available": len(self._available),
                "in_use": len(self._in_use),
                "closed": self._closed,
                "create_failures": self._create_failures,
                "discarded": self._discarded,
            }

    # Internal helpers -----------------------------------------------------

    def _prepare_borrow(self, entry: _ConnectionEntry) -> Optional[_PooledConnectionProxy]:
        """
        Validate (if configured) and check out *entry*.

        Returns None when validation failed and the entry was discarded.
        Raises PoolClosedError if the pool closed while the entry was in flight.
        """
        if self._config.test_on_borrow and not self._validate(entry):
            self._discard_entry(entry)
            return None
        # Each checkout gets its own entry object. _return_entry only accepts the
        # exact object registered in _in_use, so a stale handle (for example a
        # proxy whose raw connection was already given back via
        # return_connection) can never return a connection lent to someone else.
        checkout = _ConnectionEntry(
            conn=entry.conn,
            created_at=entry.created_at,
            last_used=time.monotonic(),
            last_check=entry.last_check,
            usage_count=entry.usage_count + 1,
        )
        with self._condition:
            pool_closed = self._closed
            if not pool_closed:
                self._in_use[id(checkout.conn)] = checkout
        if pool_closed:
            self._discard_entry(checkout)
            raise PoolClosedError("Connection pool is closed")
        return _PooledConnectionProxy(self, checkout)

    def _initial_fill(self) -> None:
        """Open up to ``min_size`` connections at construction time; failures are logged."""
        for _ in range(self._config.min_size):
            entry = self._create_entry()
            if entry:
                with self._condition:
                    self._available.append(entry)
                    self._total += 1

    def _reserve_slot_for_new(self) -> bool:
        """Claim a slot for a connection about to be opened; False if at ``max_size``."""
        with self._condition:
            if self._closed:
                raise PoolClosedError("Connection pool is closed")
            if self._total >= self._config.max_size:
                return False
            self._total += 1
            return True

    def _try_acquire_available(self) -> Optional[_ConnectionEntry]:
        """Pop the first reusable idle entry, discarding expired ones on the way."""
        now = time.monotonic()
        stale = []
        try:
            with self._condition:
                if self._closed:
                    raise PoolClosedError("Connection pool is closed")
                while self._available:
                    entry = self._available.popleft()
                    if self._should_discard(entry, now):
                        self._forget_entry_locked(entry)
                        stale.append(entry)
                        continue
                    return entry
            return None
        finally:
            # Close expired connections only after the lock has been released.
            for entry in stale:
                self._close_entry(entry)

    def _create_entry(self, reserved: bool = False) -> Optional[_ConnectionEntry]:
        """
        Open a new connection; return None on failure.

        When *reserved* is True the caller already counted the connection in
        ``_total``, and the slot is released again if opening fails.
        """
        try:
            conn = py_opengauss.open(self._config.dsn, **(self._config.connect_kwargs or {}))
        except Exception:
            logger.exception("Failed to create new py_opengauss connection")
            with self._condition:
                self._create_failures += 1
                if reserved:
                    self._total = max(0, self._total - 1)
                    self._condition.notify_all()
            return None
        now = time.monotonic()
        return _ConnectionEntry(conn=conn, created_at=now, last_used=now, last_check=now)

    def _validate(self, entry: _ConnectionEntry) -> bool:
        """Run the configured test query; record the check time on success."""
        try:
            stmt = entry.conn.prepare(self._config.test_sql)
            stmt()
            entry.last_check = time.monotonic()
            return True
        except Exception:
            logger.warning("Validation failed; discarding connection", exc_info=True)
            return False

    def _should_discard(self, entry: _ConnectionEntry, now: Optional[float] = None) -> bool:
        """True if *entry* exceeded its lifetime, idle time or usage limit."""
        if now is None:
            now = time.monotonic()
        cfg = self._config
        if cfg.max_lifetime and now - entry.created_at >= cfg.max_lifetime:
            return True
        if cfg.idle_timeout and now - entry.last_used >= cfg.idle_timeout:
            return True
        return bool(cfg.max_usage and entry.usage_count >= cfg.max_usage)

    def _keepalive_due(self, entry: _ConnectionEntry, now: float) -> bool:
        """True if keepalive is enabled and *entry* has not been checked recently."""
        cfg = self._config
        return bool(
            cfg.keepalive
            and cfg.keepalive_interval
            and now - entry.last_check >= cfg.keepalive_interval
        )

    def _reset_connection(self, entry: _ConnectionEntry) -> bool:
        """Roll back and re-apply ``setsession`` statements; False if that failed."""
        if not self._config.reset_on_return and not self._config.setsession:
            return True
        try:
            if self._config.reset_on_return:
                rollback = getattr(entry.conn, "rollback", None)
                if callable(rollback):
                    rollback()
                else:
                    entry.conn.execute("ROLLBACK")
            if self._config.setsession:
                for sql in self._config.setsession:
                    stmt = entry.conn.prepare(sql)
                    stmt()
            return True
        except Exception:
            logger.warning("Failed to reset connection; discarding", exc_info=True)
            return False

    def _close_entry(self, entry: _ConnectionEntry) -> None:
        """Close the underlying connection, logging (not raising) failures."""
        try:
            entry.conn.close()
        except Exception:
            logger.debug("Failed closing connection", exc_info=True)

    def _forget_entry_locked(self, entry: _ConnectionEntry) -> None:
        """Remove *entry* from the accounting. Caller holds the lock and closes it afterwards."""
        self._total = max(0, self._total - 1)
        self._discarded += 1
        self._condition.notify_all()

    def _discard_entry_locked(self, entry: _ConnectionEntry) -> None:
        """Discard while holding the pool lock (closes the connection under the lock)."""
        self._forget_entry_locked(entry)
        self._close_entry(entry)

    def _discard_entry(self, entry: _ConnectionEntry) -> None:
        """Drop *entry* from the accounting, then close it outside the lock."""
        with self._condition:
            self._forget_entry_locked(entry)
        self._close_entry(entry)

    def _return_entry(self, entry: _ConnectionEntry, broken: bool = False) -> None:
        """Check *entry* back in: reset it and re-pool it, or discard it."""
        key = id(entry.conn)
        with self._condition:
            if self._in_use.get(key) is not entry:
                # Already returned, closed with the pool, or a stale handle from
                # an earlier checkout of the same connection.
                return
            del self._in_use[key]
        now = time.monotonic()
        # Idle time starts now. Measuring from borrow time would discard every
        # connection that was used for longer than idle_timeout.
        entry.last_used = now
        if broken or not self._reset_connection(entry) or self._should_discard(entry, now):
            self._discard_entry(entry)
            return
        with self._condition:
            if not self._closed:
                self._available.append(entry)
                self._condition.notify_all()
                return
            self._forget_entry_locked(entry)
        self._close_entry(entry)

    def _wait_for_availability(
        self, deadline: Optional[float], max_wait: Optional[float] = None
    ) -> None:
        """
        Block until the pool state may have changed, the deadline passes, or
        *max_wait* seconds elapse (used as a retry backoff).

        Raises:
            PoolClosedError: The pool is closed.
            PoolExhaustedError: *deadline* has already passed.
        """
        with self._condition:
            if self._closed:
                raise PoolClosedError("Connection pool is closed")
            wait_timeout = max_wait
            if deadline is not None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    raise PoolExhaustedError("Timed out waiting for connection from pool")
                wait_timeout = remaining if wait_timeout is None else min(wait_timeout, remaining)
            # Re-check under the lock. A connection may have been returned, or a
            # slot freed, after the caller last looked and before we got here;
            # that notify_all has already fired and would otherwise be missed.
            if self._available:
                return
            if max_wait is None and self._total < self._config.max_size:
                return
            self._condition.wait(timeout=wait_timeout)

    def _housekeeping(self) -> None:
        """Background loop: run a health check every ``health_check_interval`` seconds."""
        interval = max(1.0, self._config.health_check_interval)
        while not self._stop_event.wait(interval):
            try:
                self._perform_health_check()
            except Exception:
                # Keep the housekeeper alive; one failed cycle must not stop all
                # future keepalives and refills.
                logger.exception("Connection pool health check failed")

    def _perform_health_check(self) -> None:
        """Evict expired idle connections, ping due ones, and refill to ``min_size``."""
        now = time.monotonic()
        stale = []
        to_check = []
        with self._condition:
            if self._closed:
                return
            kept: Deque[_ConnectionEntry] = deque()
            while self._available:
                entry = self._available.popleft()
                if self._should_discard(entry, now):
                    self._forget_entry_locked(entry)
                    stale.append(entry)
                elif self._keepalive_due(entry, now):
                    to_check.append(entry)
                else:
                    kept.append(entry)
            self._available = kept
        for entry in stale:
            self._close_entry(entry)

        for entry in to_check:
            if not self._validate(entry):
                self._discard_entry(entry)
                continue
            with self._condition:
                if not self._closed:
                    # NOTE(review): refreshing last_used on a keepalive ping means
                    # idle_timeout never fires for entries pinged more often than
                    # idle_timeout. Confirm this is the intended keepalive semantics.
                    entry.last_used = time.monotonic()
                    self._available.append(entry)
                    self._condition.notify_all()
                    continue
                self._forget_entry_locked(entry)
            # The pool closed while the entry was being validated.
            self._close_entry(entry)

        self._refill_to_min_size()

    def _refill_to_min_size(self) -> None:
        """Open connections until ``_total`` reaches ``min_size``; stop at the first failure."""
        while True:
            with self._condition:
                if self._closed or self._total >= self._config.min_size:
                    return
                # Reserve the slot before the slow connect so that concurrent
                # borrowers cannot push the pool past max_size.
                self._total += 1
            entry = self._create_entry(reserved=True)
            if entry is None:
                # Slot already released; retry on the next housekeeping cycle.
                return
            with self._condition:
                if not self._closed:
                    self._available.append(entry)
                    self._condition.notify_all()
                    continue
                self._total = max(0, self._total - 1)
            self._close_entry(entry)
            return