"""
Lightweight connection pool built on top of py_opengauss.

The pool keeps a configurable number of connections ready, validates them on
borrow when requested, performs periodic health checks/keepalives, and discards
connections that fail SQL execution or exceed idle/lifetime thresholds.
"""

from __future__ import annotations

import logging
import threading
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Deque, Dict, Optional

import py_opengauss


# Module-level logger; the pool reports connection-creation failures,
# validation failures, and errors while closing connections through it.
logger = logging.getLogger(__name__)


@dataclass
class ConnectionPoolConfig:
    """
    Settings for OpenGaussConnectionPool.

    All durations are in seconds. For idle_timeout, max_lifetime and
    keepalive_interval, a falsy value (0) disables the corresponding check.
    """

    # Connection string passed as the first argument to py_opengauss.open().
    dsn: str
    # Connections opened when the pool is constructed; the periodic health
    # check opens new ones whenever the pool's total drops below this.
    min_size: int = 1
    # Upper bound on the number of connections the pool owns (idle + borrowed).
    max_size: int = 10
    # Idle connections whose last use is at least this long ago are closed.
    idle_timeout: float = 300.0
    # Connections at least this old are closed instead of being reused.
    max_lifetime: float = 3600.0
    # Run test_sql on an idle connection before handing it out to a borrower.
    test_on_borrow: bool = True
    # Statement prepared and executed to validate a connection.
    test_sql: str = "SELECT 1"
    # Let the health check re-validate idle connections periodically.
    keepalive: bool = True
    # Time since a connection's last successful check after which the health
    # check re-validates it.
    keepalive_interval: float = 60.0
    # Pause between health-check runs; values below 1.0 are raised to 1.0.
    health_check_interval: float = 30.0
    # Extra keyword arguments forwarded to py_opengauss.open().
    connect_kwargs: Optional[Dict[str, Any]] = None


@dataclass
class _ConnectionEntry:
    """Bookkeeping record the pool keeps for each connection it owns."""

    # Connection object returned by py_opengauss.open().
    conn: Any
    # Creation timestamp in seconds (pool clock); compared against max_lifetime.
    created_at: float
    # When the connection was last handed out or returned; compared against
    # idle_timeout.
    last_used: float
    # Last successful validation (initially the creation time); the health
    # check uses it to decide when a keepalive check is due.
    last_check: float


class OpenGaussConnectionPool:
    """
    Thread-safe connection pool for py_opengauss.

    Borrowed connections are returned directly without wrappers. When an error
    is detected during SQL execution, callers should mark the connection as
    broken via `return_connection(conn, had_error=True)` to evict it quickly.

    Bookkeeping (all guarded by ``self._lock``):

    * ``_total`` counts every connection the pool owns: idle entries in
      ``_available``, borrowed entries in ``_in_use``, entries temporarily
      taken out for validation, and slots reserved for a connection that is
      still being opened.
    * Every path that drops an entry goes through ``_discard`` so the
      connection is closed and ``_total`` is decremented exactly once.

    Entry timestamps use ``time.monotonic()`` so idle/lifetime checks are not
    affected by wall-clock adjustments.
    """

    def __init__(self, config: ConnectionPoolConfig):
        if config.min_size < 0 or config.max_size <= 0:
            raise ValueError("Pool sizes must be positive")
        if config.min_size > config.max_size:
            raise ValueError("min_size cannot be greater than max_size")

        self._config = config
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)
        self._available: Deque[_ConnectionEntry] = deque()
        self._in_use: Dict[int, _ConnectionEntry] = {}
        self._total = 0
        self._closed = False
        # Set by close() so the housekeeper stops without finishing its sleep.
        self._stop_event = threading.Event()

        self._housekeeper = threading.Thread(
            target=self._housekeeping, name="opengauss-pool-housekeeper", daemon=True
        )

        self._initial_fill()
        self._housekeeper.start()

    def borrow(self, timeout: Optional[float] = None) -> Any:
        """
        Borrow a connection from the pool.

        If test_on_borrow is enabled, an idle connection is validated with the
        configured SQL before being returned; freshly opened connections are
        returned without validation.

        :param timeout: Maximum seconds to wait. ``None`` waits indefinitely;
            ``0`` (or a negative value) fails as soon as a connection cannot
            be obtained without waiting.
        :raises TimeoutError: if no connection became available in time.
        :raises RuntimeError: if the pool is, or becomes, closed.
        """
        # `is None` (not truthiness) so that timeout=0 means "don't wait".
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            entry = self._try_acquire_available()
            if entry is not None:
                if self._config.test_on_borrow and not self._validate(entry):
                    self._discard(entry)
                    continue
                return self._checkout(entry)

            if self._reserve_slot_for_new():
                entry = self._create_entry(reserved=True)
                if entry is not None:
                    return self._checkout(entry)
                # Creation failed; back off briefly so an unavailable database
                # does not turn this loop into a busy spin.
                self._wait_for_change(deadline, backoff=0.1)
                continue

            # Pool is at max_size: wait for a return/discard or the deadline.
            self._wait_for_change(deadline)

    def return_connection(self, conn: Any, had_error: bool = False) -> None:
        """
        Return a connection to the pool.

        If had_error is True, the connection will be closed and removed to
        avoid reusing broken connections. A connection the pool does not track
        as borrowed (for example one returned after the pool was closed) is
        closed defensively; a connection that is already idle in the pool
        (a double return) is left untouched.
        """
        with self._condition:
            entry = self._in_use.pop(id(conn), None)
            already_idle = entry is None and any(e.conn is conn for e in self._available)

        if entry is None:
            if already_idle:
                # Closing it here would hand a dead connection to the next borrower.
                logger.warning("Connection returned to the pool more than once; ignoring")
                return
            # Unknown connection; close it defensively (best effort).
            try:
                conn.close()
            except Exception:
                logger.debug("Failed closing unknown connection", exc_info=True)
            return

        if had_error:
            self._discard(entry)
            return
        entry.last_used = time.monotonic()
        if self._should_discard(entry):
            self._discard(entry)
        else:
            self._release_to_idle(entry)

    @contextmanager
    def connection(self, timeout: Optional[float] = None):
        """
        Context manager helper that returns the connection to the pool
        automatically and discards it if an exception is raised.
        """
        conn = self.borrow(timeout=timeout)
        try:
            yield conn
        except BaseException:
            # BaseException, not Exception: KeyboardInterrupt/SystemExit or a
            # GeneratorExit must not leak the connection. Its state is unknown
            # after an interrupted operation, so evict it.
            self.return_connection(conn, had_error=True)
            raise
        else:
            self.return_connection(conn)

    def close(self) -> None:
        """
        Close the pool and all managed connections.

        Borrowed connections are closed too; returning them afterwards is
        harmless. Idempotent.
        """
        with self._condition:
            if self._closed:
                return
            self._closed = True
            to_close = list(self._available)
            self._available.clear()
            to_close.extend(self._in_use.values())
            self._in_use.clear()
            # Reset under the lock. Entries still checked out elsewhere
            # (validation, reserved slots) clamp at zero in _discard.
            self._total = 0
            self._condition.notify_all()
        self._stop_event.set()
        for entry in to_close:
            try:
                entry.conn.close()
            except Exception:
                logger.debug("Failed closing connection during pool shutdown", exc_info=True)

    def stats(self) -> Dict[str, Any]:
        """
        Lightweight snapshot of pool counters for observability/testing.
        """
        with self._condition:
            return {
                "total": self._total,
                "available": len(self._available),
                "in_use": len(self._in_use),
                "closed": self._closed,
            }

    # Internal helpers -----------------------------------------------------

    def _initial_fill(self) -> None:
        """Open min_size connections; runs before the housekeeper starts."""
        for _ in range(self._config.min_size):
            entry = self._create_entry()
            if entry is not None:
                with self._condition:
                    self._available.append(entry)
                    self._total += 1

    def _reserve_slot_for_new(self) -> bool:
        """Claim capacity for one new connection; False when at max_size."""
        with self._condition:
            if self._closed:
                raise RuntimeError("Connection pool is closed")
            if self._total >= self._config.max_size:
                return False
            self._total += 1
            return True

    def _try_acquire_available(self) -> Optional[_ConnectionEntry]:
        """Pop the first non-expired idle entry, discarding expired ones."""
        now = time.monotonic()
        expired: Deque[_ConnectionEntry] = deque()
        found: Optional[_ConnectionEntry] = None
        with self._condition:
            if self._closed:
                raise RuntimeError("Connection pool is closed")
            while self._available:
                entry = self._available.popleft()
                if self._should_discard(entry, now):
                    expired.append(entry)
                else:
                    found = entry
                    break
        # Close expired entries outside the lock. This must happen even when a
        # usable entry was found, otherwise they leak and _total never shrinks.
        for entry in expired:
            self._discard(entry)
        return found

    def _checkout(self, entry: _ConnectionEntry) -> Any:
        """Register `entry` as borrowed and return its connection."""
        entry.last_used = time.monotonic()
        with self._condition:
            if not self._closed:
                self._in_use[id(entry.conn)] = entry
                return entry.conn
        # The pool was closed while this entry was validated/opened outside the
        # lock; close() never saw it, so release it here.
        self._discard(entry)
        raise RuntimeError("Connection pool is closed")

    def _wait_for_change(self, deadline: Optional[float], backoff: Optional[float] = None) -> None:
        """
        Block until notified, `backoff` seconds pass, or the deadline expires.

        :raises RuntimeError: if the pool is closed.
        :raises TimeoutError: if the deadline has already passed.
        """
        with self._condition:
            if self._closed:
                raise RuntimeError("Connection pool is closed")
            # Re-check under the lock: a return/discard that happened after the
            # caller's unlocked attempts already fired its notify(), so waiting
            # now could block forever (lost wakeup).
            if self._available:
                return
            if backoff is None and self._total < self._config.max_size:
                return
            wait_for = backoff
            if deadline is not None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    raise TimeoutError("Timed out waiting for connection")
                wait_for = remaining if backoff is None else min(backoff, remaining)
            self._condition.wait(timeout=wait_for)

    def _release_to_idle(self, entry: _ConnectionEntry) -> None:
        """Put a healthy entry back on the idle deque and wake one waiter."""
        with self._condition:
            if not self._closed:
                self._available.append(entry)
                self._condition.notify()
                return
        # Pool closed meanwhile: nobody will borrow this entry again.
        self._discard(entry)

    def _create_entry(self, reserved: bool = False) -> Optional[_ConnectionEntry]:
        """
        Open a new connection; returns None on failure.

        :param reserved: True when the caller already counted this connection
            in ``_total``; the slot is released again if opening fails.
        """
        try:
            conn = py_opengauss.open(self._config.dsn, **(self._config.connect_kwargs or {}))
        except Exception:
            logger.exception("Failed to create new py_opengauss connection")
            if reserved:
                with self._condition:
                    self._total = max(0, self._total - 1)
                    # A slot was freed; let a borrower waiting at max_size retry.
                    self._condition.notify()
            return None
        now = time.monotonic()
        return _ConnectionEntry(conn=conn, created_at=now, last_used=now, last_check=now)

    def _validate(self, entry: _ConnectionEntry) -> bool:
        """Run test_sql on the entry's connection; True if it succeeded."""
        try:
            stmt = entry.conn.prepare(self._config.test_sql)
            stmt()
        except Exception:
            logger.warning("Validation failed; discarding connection", exc_info=True)
            return False
        entry.last_check = time.monotonic()
        return True

    def _should_discard(self, entry: _ConnectionEntry, now: Optional[float] = None) -> bool:
        """True if the entry exceeded max_lifetime or idle_timeout (0 disables)."""
        if now is None:
            now = time.monotonic()
        if self._config.max_lifetime and now - entry.created_at >= self._config.max_lifetime:
            return True
        if self._config.idle_timeout and now - entry.last_used >= self._config.idle_timeout:
            return True
        return False

    def _discard(self, entry: _ConnectionEntry) -> None:
        """Close the entry's connection and release its slot."""
        try:
            entry.conn.close()
        except Exception:
            logger.debug("Failed closing connection", exc_info=True)
        with self._condition:
            self._total = max(0, self._total - 1)
            self._condition.notify()

    def _housekeeping(self) -> None:
        """Background loop: run a health check every interval until closed."""
        interval = max(1.0, self._config.health_check_interval)
        while not self._closed:
            try:
                self._perform_health_check()
            except Exception:
                # Keep the thread alive: one failed pass must not silently
                # disable keepalives and min_size refills for good.
                logger.exception("Connection pool health check failed")
            # Returns True as soon as close() sets the event.
            if self._stop_event.wait(interval):
                return

    def _perform_health_check(self) -> None:
        """Evict expired idle entries, keepalive-check stale ones, refill to min_size."""
        now = time.monotonic()
        to_check: Deque[_ConnectionEntry] = deque()
        to_discard: Deque[_ConnectionEntry] = deque()
        with self._condition:
            if self._closed:
                return
            kept: Deque[_ConnectionEntry] = deque()
            for entry in self._available:
                if self._should_discard(entry, now):
                    to_discard.append(entry)
                elif (
                    self._config.keepalive
                    and self._config.keepalive_interval
                    and now - entry.last_check >= self._config.keepalive_interval
                ):
                    to_check.append(entry)
                else:
                    kept.append(entry)
            self._available = kept

        for entry in to_discard:
            self._discard(entry)

        # Run keepalive checks outside the lock to avoid blocking borrowers.
        # Borrowers cannot see these entries meanwhile, so each one is handed
        # back with a notify() (or discarded if the pool closed).
        for entry in to_check:
            if self._validate(entry):
                self._release_to_idle(entry)
            else:
                self._discard(entry)

        self._replenish()

    def _replenish(self) -> None:
        """Open connections until the pool owns at least min_size again."""
        for _ in range(self._config.min_size):
            # Reserve the slot first so a concurrent borrower cannot push
            # _total past max_size while this connection is being opened.
            with self._condition:
                if self._closed or self._total >= self._config.min_size:
                    return
                self._total += 1
            entry = self._create_entry(reserved=True)
            if entry is None:
                # Database unreachable; retry on the next health-check cycle.
                return
            self._release_to_idle(entry)
