""" Hardened connection pool for py_opengauss inspired by dbutils.pooled_db. Key differences from the lightweight pool: * Uses a proxy wrapper so connections are always returned (close/__del__/with). * Resets sessions on return (rollback + optional setsession). * Avoids deadlocks by using an RLock and separating discard-with-lock paths. * Supports blocking/non-blocking acquire with timeout and max usage cycling. * Uses monotonic clocks for all interval calculations to avoid wall-clock jumps. """ from __future__ import annotations import logging import threading import time from collections import deque from contextlib import contextmanager from dataclasses import dataclass from typing import Any, Deque, Dict, Optional, Sequence import py_opengauss logger = logging.getLogger(__name__) class PoolError(Exception): """Base pool error.""" class PoolClosedError(PoolError): """Pool has been closed.""" class PoolExhaustedError(PoolError): """Pool is at capacity and blocking is disabled or timed out.""" @dataclass class ConnectionPoolConfig: dsn: str min_size: int = 1 max_size: int = 10 blocking: bool = True acquire_timeout: Optional[float] = None idle_timeout: float = 300.0 max_lifetime: float = 3600.0 max_usage: Optional[int] = None test_on_borrow: bool = True test_sql: str = "SELECT 1" keepalive: bool = True keepalive_interval: float = 60.0 health_check_interval: float = 30.0 reset_on_return: bool = True setsession: Optional[Sequence[str]] = None connect_kwargs: Optional[Dict[str, Any]] = None @dataclass class _ConnectionEntry: conn: Any created_at: float last_used: float last_check: float usage_count: int = 0 class _PooledConnectionProxy: """Lightweight proxy that returns connections to the pool on close/del.""" def __init__(self, pool: "OpenGaussConnectionPool", entry: _ConnectionEntry): self._pool = pool self._entry = entry self._returned = False def __getattr__(self, item: str) -> Any: return getattr(self._entry.conn, item) def __enter__(self) -> "_PooledConnectionProxy": return self def __exit__(self, exc_type, exc, tb) -> None: self.close(broken=bool(exc_type)) def close(self, broken: bool = False) -> None: if self._returned: return self._returned = True self._pool._return_entry(self._entry, broken=broken) def mark_broken(self) -> None: """Explicitly mark this connection as broken so it gets discarded.""" self.close(broken=True) def __del__(self): # Best-effort return to pool to avoid leaks if user forgets to close. try: self.close() except Exception: pass def __repr__(self) -> str: return f"" class OpenGaussConnectionPool: """ Hardened, thread-safe connection pool for py_opengauss. Connections are wrapped in a proxy that ensures return to the pool on close or object finalization. Pool operations rely on monotonic clocks to avoid issues from system clock jumps. """ def __init__(self, config: ConnectionPoolConfig): if config.min_size < 0 or config.max_size <= 0: raise ValueError("Pool sizes must be positive") if config.min_size > config.max_size: raise ValueError("min_size cannot be greater than max_size") self._config = config self._lock = threading.RLock() self._condition = threading.Condition(self._lock) self._available: Deque[_ConnectionEntry] = deque() self._in_use: Dict[int, _ConnectionEntry] = {} self._total = 0 self._create_failures = 0 self._discarded = 0 self._closed = False self._stop_event = threading.Event() self._housekeeper = threading.Thread( target=self._housekeeping, name="opengauss-pool-housekeeper", daemon=True ) self._initial_fill() self._housekeeper.start() def borrow(self, timeout: Optional[float] = None) -> _PooledConnectionProxy: """ Borrow a connection from the pool. If test_on_borrow is enabled, the connection will be validated with the configured SQL before being returned. When blocking is False, pool exhaustion immediately raises PoolExhaustedError. """ deadline = None effective_timeout = timeout if effective_timeout is None: effective_timeout = self._config.acquire_timeout if effective_timeout is not None: deadline = time.monotonic() + effective_timeout while True: entry = self._try_acquire_available() if entry: proxy = self._prepare_borrow(entry) if proxy: return proxy continue to_create = self._reserve_slot_for_new() if to_create: entry = self._create_entry(reserved=True) if entry is None: # Creation failed; avoid tight loop when DB is unavailable. self._wait_for_availability(deadline) continue proxy = self._prepare_borrow(entry) if proxy: return proxy continue # Pool is at capacity and nothing available. if not self._config.blocking: raise PoolExhaustedError("Pool is at capacity and blocking is disabled") self._wait_for_availability(deadline) def return_connection(self, conn: Any, had_error: bool = False) -> None: """ Return a connection (or proxy) to the pool. If had_error is True, the connection will be closed and removed to avoid reusing broken connections. """ entry = None if isinstance(conn, _PooledConnectionProxy): conn.close(broken=had_error) return with self._condition: entry = self._in_use.pop(id(conn), None) if entry is None: try: conn.close() finally: return self._return_entry(entry, broken=had_error) @contextmanager def connection(self, timeout: Optional[float] = None): """ Context manager helper that returns the connection to the pool automatically and discards it if an exception is raised. """ proxy = self.borrow(timeout=timeout) try: yield proxy except Exception: proxy.close(broken=True) raise else: proxy.close() def close(self) -> None: """Close the pool and all managed connections.""" with self._condition: if self._closed: return self._closed = True to_close = list(self._available) self._available.clear() to_close.extend(self._in_use.values()) self._in_use.clear() self._condition.notify_all() self._stop_event.set() for entry in to_close: self._close_entry(entry) self._total = 0 if self._housekeeper.is_alive(): self._housekeeper.join(timeout=1.0) def stats(self) -> Dict[str, Any]: """Lightweight snapshot of pool counters for observability/testing.""" with self._condition: return { "total": self._total, "available": len(self._available), "in_use": len(self._in_use), "closed": self._closed, "create_failures": self._create_failures, "discarded": self._discarded, } # Internal helpers ----------------------------------------------------- def _prepare_borrow(self, entry: _ConnectionEntry) -> Optional[_PooledConnectionProxy]: if self._config.test_on_borrow and not self._validate(entry): self._discard_entry(entry) return None entry.last_used = time.monotonic() entry.usage_count += 1 with self._condition: self._in_use[id(entry.conn)] = entry return _PooledConnectionProxy(self, entry) def _initial_fill(self) -> None: for _ in range(self._config.min_size): entry = self._create_entry() if entry: with self._condition: self._available.append(entry) self._total += 1 def _reserve_slot_for_new(self) -> bool: with self._condition: if self._closed: raise PoolClosedError("Connection pool is closed") if self._total >= self._config.max_size: return False self._total += 1 return True def _try_acquire_available(self) -> Optional[_ConnectionEntry]: now = time.monotonic() with self._condition: if self._closed: raise PoolClosedError("Connection pool is closed") while self._available: entry = self._available.popleft() if self._should_discard(entry, now): self._discard_entry_locked(entry) continue return entry return None def _create_entry(self, reserved: bool = False) -> Optional[_ConnectionEntry]: try: conn = py_opengauss.open(self._config.dsn, **(self._config.connect_kwargs or {})) now = time.monotonic() return _ConnectionEntry(conn=conn, created_at=now, last_used=now, last_check=now) except Exception: logger.exception("Failed to create new py_opengauss connection") self._create_failures += 1 if reserved: with self._condition: self._total = max(0, self._total - 1) self._condition.notify_all() return None def _validate(self, entry: _ConnectionEntry) -> bool: try: stmt = entry.conn.prepare(self._config.test_sql) stmt() entry.last_check = time.monotonic() return True except Exception: logger.warning("Validation failed; discarding connection", exc_info=True) return False def _should_discard(self, entry: _ConnectionEntry, now: Optional[float] = None) -> bool: now = now or time.monotonic() if self._config.max_lifetime and now - entry.created_at >= self._config.max_lifetime: return True if self._config.idle_timeout and now - entry.last_used >= self._config.idle_timeout: return True if self._config.max_usage and entry.usage_count >= self._config.max_usage: return True return False def _reset_connection(self, entry: _ConnectionEntry) -> bool: if not self._config.reset_on_return and not self._config.setsession: return True try: if self._config.reset_on_return: rollback = getattr(entry.conn, "rollback", None) if callable(rollback): rollback() else: entry.conn.execute("ROLLBACK") if self._config.setsession: for sql in self._config.setsession: stmt = entry.conn.prepare(sql) stmt() return True except Exception: logger.warning("Failed to reset connection; discarding", exc_info=True) return False def _close_entry(self, entry: _ConnectionEntry) -> None: try: entry.conn.close() except Exception: logger.debug("Failed closing connection", exc_info=True) def _discard_entry_locked(self, entry: _ConnectionEntry) -> None: """Discard while holding the pool lock.""" self._close_entry(entry) self._total = max(0, self._total - 1) self._discarded += 1 self._condition.notify_all() def _discard_entry(self, entry: _ConnectionEntry) -> None: with self._condition: self._discard_entry_locked(entry) def _return_entry(self, entry: _ConnectionEntry, broken: bool = False) -> None: with self._condition: stored = self._in_use.pop(id(entry.conn), None) if stored is None: return entry = stored if broken or not self._reset_connection(entry) or self._should_discard(entry): self._discard_entry(entry) return with self._condition: if self._closed: self._discard_entry_locked(entry) return entry.last_used = time.monotonic() self._available.append(entry) self._condition.notify_all() def _wait_for_availability(self, deadline: Optional[float]) -> None: if deadline is not None: remaining = deadline - time.monotonic() if remaining <= 0: raise PoolExhaustedError("Timed out waiting for connection from pool") with self._condition: if self._closed: raise PoolClosedError("Connection pool is closed") self._condition.wait(timeout=remaining) else: with self._condition: if self._closed: raise PoolClosedError("Connection pool is closed") self._condition.wait() def _housekeeping(self) -> None: interval = max(1.0, self._config.health_check_interval) while not self._stop_event.wait(interval): self._perform_health_check() def _perform_health_check(self) -> None: now = time.monotonic() to_check: Deque[_ConnectionEntry] = deque() with self._condition: if self._closed: return kept: Deque[_ConnectionEntry] = deque() while self._available: entry = self._available.popleft() if self._should_discard(entry, now): self._discard_entry_locked(entry) continue if ( self._config.keepalive and self._config.keepalive_interval and now - entry.last_check >= self._config.keepalive_interval ): to_check.append(entry) else: kept.append(entry) self._available = kept for entry in list(to_check): if self._validate(entry): with self._condition: if not self._closed: entry.last_used = time.monotonic() self._available.append(entry) else: self._discard_entry(entry) with self._condition: needed = max(0, self._config.min_size - self._total) for _ in range(needed): entry = self._create_entry() if entry: with self._condition: if self._closed: self._close_entry(entry) return self._available.append(entry) self._total += 1 self._condition.notify_all()