- """
- Hardened connection pool for py_opengauss inspired by dbutils.pooled_db.
- Key differences from the lightweight pool:
- * Uses a proxy wrapper so connections are returned on close(), on context-manager exit, or (best effort) on finalization.
- * Resets sessions on return (rollback + optional setsession).
- * Uses an RLock and separate locked/unlocked discard paths to avoid self-deadlock.
- * Supports blocking/non-blocking acquire with timeout and max usage cycling.
- * Uses monotonic clocks for all interval calculations to avoid wall-clock jumps.
- """
- from __future__ import annotations
- import logging
- import threading
- import time
- from collections import deque
- from contextlib import contextmanager
- from dataclasses import dataclass
- from typing import Any, Deque, Dict, Optional, Sequence
- import py_opengauss
- logger = logging.getLogger(__name__)
- class PoolError(Exception):
- """Base pool error."""
- class PoolClosedError(PoolError):
- """Pool has been closed."""
- class PoolExhaustedError(PoolError):
- """Pool is at capacity and blocking is disabled or timed out."""
- @dataclass
- class ConnectionPoolConfig:
- dsn: str
- min_size: int = 1
- max_size: int = 10
- blocking: bool = True
- acquire_timeout: Optional[float] = None
- idle_timeout: float = 300.0
- max_lifetime: float = 3600.0
- max_usage: Optional[int] = None
- test_on_borrow: bool = True
- test_sql: str = "SELECT 1"
- keepalive: bool = True
- keepalive_interval: float = 60.0
- health_check_interval: float = 30.0
- reset_on_return: bool = True
- setsession: Optional[Sequence[str]] = None
- connect_kwargs: Optional[Dict[str, Any]] = None
- @dataclass
- class _ConnectionEntry:
- conn: Any
- created_at: float
- last_used: float
- last_check: float
- usage_count: int = 0
- class _PooledConnectionProxy:
- """Lightweight proxy that returns connections to the pool on close/del."""
- def __init__(self, pool: "OpenGaussConnectionPool", entry: _ConnectionEntry):
- self._pool = pool
- self._entry = entry
- self._returned = False
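- def __getattr__(self, item: str) -> Any:
- if item in ("_pool", "_entry", "_returned"):
- # Proxy not fully initialised; avoid recursing through self._entry.
- raise AttributeError(item)
- return getattr(self._entry.conn, item)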
- def __enter__(self) -> "_PooledConnectionProxy":
- return self
- def __exit__(self, exc_type, exc, tb) -> None:
- self.close(broken=bool(exc_type))
- def close(self, broken: bool = False) -> None:
- if self._returned:
- return
- self._returned = True
- self._pool._return_entry(self._entry, broken=broken)
- def mark_broken(self) -> None:
- """Explicitly mark this connection as broken so it gets discarded."""
- self.close(broken=True)
- def __del__(self):
- # Best-effort return to pool to avoid leaks if user forgets to close.
- try:
- self.close()
- except Exception:
- pass
- def __repr__(self) -> str:
- return f"<PooledOpenGaussConnection id={id(self._entry.conn)} returned={self._returned}>"
- class OpenGaussConnectionPool:
- """
- Hardened, thread-safe connection pool for py_opengauss.
- Connections are wrapped in a proxy that ensures return to the pool on close
- or object finalization. Pool operations rely on monotonic clocks to avoid
- issues from system clock jumps.
- """
- def __init__(self, config: ConnectionPoolConfig):
- if config.min_size < 0 or config.max_size <= 0:
- raise ValueError("min_size must be >= 0 and max_size must be > 0")
- if config.min_size > config.max_size:
- raise ValueError("min_size cannot be greater than max_size")
- self._config = config
- self._lock = threading.RLock()
- self._condition = threading.Condition(self._lock)
- self._available: Deque[_ConnectionEntry] = deque()
- self._in_use: Dict[int, _ConnectionEntry] = {}
- self._total = 0
- self._create_failures = 0
- self._discarded = 0
- self._closed = False
- self._stop_event = threading.Event()
- self._housekeeper = threading.Thread(
- target=self._housekeeping, name="opengauss-pool-housekeeper", daemon=True
- )
- self._initial_fill()
- self._housekeeper.start()
- def borrow(self, timeout: Optional[float] = None) -> _PooledConnectionProxy:
- """
- Borrow a connection from the pool.
- If test_on_borrow is enabled, the connection will be validated with the
- configured SQL before being returned. When blocking is False, pool
- exhaustion immediately raises PoolExhaustedError.
- """
- deadline = None
- effective_timeout = timeout
- if effective_timeout is None:
- effective_timeout = self._config.acquire_timeout
- if effective_timeout is not None:
- deadline = time.monotonic() + effective_timeout
- while True:
- entry = self._try_acquire_available()
- if entry:
- proxy = self._prepare_borrow(entry)
- if proxy:
- return proxy
- continue
- to_create = self._reserve_slot_for_new()
- if to_create:
- entry = self._create_entry(reserved=True)
- if entry is None:
- # Creation failed; avoid tight loop when DB is unavailable.
- self._wait_for_availability(deadline)
- continue
- proxy = self._prepare_borrow(entry)
- if proxy:
- return proxy
- continue
- # Pool is at capacity and nothing available.
- if not self._config.blocking:
- raise PoolExhaustedError("Pool is at capacity and blocking is disabled")
- self._wait_for_availability(deadline)
- def return_connection(self, conn: Any, had_error: bool = False) -> None:
- """
- Return a connection (or proxy) to the pool.
- If had_error is True, the connection will be closed and removed to
- avoid reusing broken connections.
- """
- entry = None
- if isinstance(conn, _PooledConnectionProxy):
- conn.close(broken=had_error)
- return
- with self._condition:
- entry = self._in_use.pop(id(conn), None)
- if entry is None:
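- # Untracked connection: close it, but do not swallow errors via "finally: return".
- try:
- conn.close()
- except Exception:
- logger.debug("Failed closing untracked connection", exc_info=True)
- return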
- self._return_entry(entry, broken=had_error)
- @contextmanager
- def connection(self, timeout: Optional[float] = None):
- """
- Context manager helper that returns the connection to the pool
- automatically and discards it if an exception is raised.
- """
- proxy = self.borrow(timeout=timeout)
- try:
- yield proxy
- # BaseException also covers KeyboardInterrupt and GeneratorExit.
- except BaseException:
- proxy.close(broken=True)
- raise
- else:
- proxy.close()
- def close(self) -> None:
- """Close the pool and all managed connections."""
- with self._condition:
- if self._closed:
- return
- self._closed = True
- to_close = list(self._available)
- self._available.clear()
- to_close.extend(self._in_use.values())
- self._in_use.clear()
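- # Reset the counter while the lock is still held.
- self._total = 0
- self._condition.notify_all()
- # Lock released from here on; connections are closed without holding it.
- self._stop_event.set()
- for entry in to_close:
- self._close_entry(entry)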
- if self._housekeeper.is_alive():
- self._housekeeper.join(timeout=1.0)
- def stats(self) -> Dict[str, Any]:
- """Lightweight snapshot of pool counters for observability/testing."""
- with self._condition:
- return {
- "total": self._total,
- "available": len(self._available),
- "in_use": len(self._in_use),
- "closed": self._closed,
- "create_failures": self._create_failures,
- "discarded": self._discarded,
- }
- # Internal helpers -----------------------------------------------------
- def _prepare_borrow(self, entry: _ConnectionEntry) -> Optional[_PooledConnectionProxy]:
- if self._config.test_on_borrow and not self._validate(entry):
- self._discard_entry(entry)
- return None
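- with self._condition:
- if self._closed:
- # The pool closed while this entry was in flight; do not hand it out.
- self._discard_entry_locked(entry)
- raise PoolClosedError("Connection pool is closed")
- entry.last_used = time.monotonic()
- entry.usage_count += 1
- self._in_use[id(entry.conn)] = entry
- return _PooledConnectionProxy(self, entry)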
- def _initial_fill(self) -> None:
- for _ in range(self._config.min_size):
- entry = self._create_entry()
- if entry:
- with self._condition:
- self._available.append(entry)
- self._total += 1
- def _reserve_slot_for_new(self) -> bool:
- with self._condition:
- if self._closed:
- raise PoolClosedError("Connection pool is closed")
- if self._total >= self._config.max_size:
- return False
- self._total += 1
- return True
- def _try_acquire_available(self) -> Optional[_ConnectionEntry]:
- now = time.monotonic()
- with self._condition:
- if self._closed:
- raise PoolClosedError("Connection pool is closed")
- while self._available:
- entry = self._available.popleft()
- if self._should_discard(entry, now):
- self._discard_entry_locked(entry)
- continue
- return entry
- return None
- def _create_entry(self, reserved: bool = False) -> Optional[_ConnectionEntry]:
- try:
- conn = py_opengauss.open(self._config.dsn, **(self._config.connect_kwargs or {}))
- now = time.monotonic()
- return _ConnectionEntry(conn=conn, created_at=now, last_used=now, last_check=now)
- except Exception:
- logger.exception("Failed to create new py_opengauss connection")
- # Update shared counters under the lock; callers may run on any thread.
- with self._condition:
- self._create_failures += 1
- if reserved:
- self._total = max(0, self._total - 1)
- self._condition.notify_all()
- return None
- def _validate(self, entry: _ConnectionEntry) -> bool:
- try:
- stmt = entry.conn.prepare(self._config.test_sql)
- stmt()
- entry.last_check = time.monotonic()
- return True
- except Exception:
- logger.warning("Validation failed; discarding connection", exc_info=True)
- return False
- def _should_discard(self, entry: _ConnectionEntry, now: Optional[float] = None) -> bool:
- if now is None:
- now = time.monotonic()
- if self._config.max_lifetime and now - entry.created_at >= self._config.max_lifetime:
- return True
- if self._config.idle_timeout and now - entry.last_used >= self._config.idle_timeout:
- return True
- if self._config.max_usage and entry.usage_count >= self._config.max_usage:
- return True
- return False
- def _reset_connection(self, entry: _ConnectionEntry) -> bool:
- if not self._config.reset_on_return and not self._config.setsession:
- return True
- try:
- if self._config.reset_on_return:
- rollback = getattr(entry.conn, "rollback", None)
- if callable(rollback):
- rollback()
- else:
- entry.conn.execute("ROLLBACK")
- if self._config.setsession:
- for sql in self._config.setsession:
- stmt = entry.conn.prepare(sql)
- stmt()
- return True
- except Exception:
- logger.warning("Failed to reset connection; discarding", exc_info=True)
- return False
- def _close_entry(self, entry: _ConnectionEntry) -> None:
- try:
- entry.conn.close()
- except Exception:
- logger.debug("Failed closing connection", exc_info=True)
- def _discard_entry_locked(self, entry: _ConnectionEntry) -> None:
- """Discard while holding the pool lock."""
- self._close_entry(entry)
- self._total = max(0, self._total - 1)
- self._discarded += 1
- self._condition.notify_all()
- def _discard_entry(self, entry: _ConnectionEntry) -> None:
- with self._condition:
- self._discard_entry_locked(entry)
- def _return_entry(self, entry: _ConnectionEntry, broken: bool = False) -> None:
- with self._condition:
- stored = self._in_use.pop(id(entry.conn), None)
- if stored is None:
- return
- entry = stored
- if broken or not self._reset_connection(entry) or self._should_discard(entry):
- self._discard_entry(entry)
- return
- with self._condition:
- if self._closed:
- self._discard_entry_locked(entry)
- return
- entry.last_used = time.monotonic()
- self._available.append(entry)
- self._condition.notify_all()
- def _wait_for_availability(self, deadline: Optional[float]) -> None:
- # Cap each wait: borrow() checks availability before this lock is taken,
- # so a notification can slip in between. The caller's loop re-checks.
- wait_for = 1.0
- if deadline is not None:
- remaining = deadline - time.monotonic()
- if remaining <= 0:
- raise PoolExhaustedError("Timed out waiting for connection from pool")
- wait_for = min(wait_for, remaining)
- with self._condition:
- if self._closed:
- raise PoolClosedError("Connection pool is closed")
- self._condition.wait(timeout=wait_for)
- def _housekeeping(self) -> None:
- interval = max(1.0, self._config.health_check_interval)
- while not self._stop_event.wait(interval):
- self._perform_health_check()
- def _perform_health_check(self) -> None:
- now = time.monotonic()
- to_check: Deque[_ConnectionEntry] = deque()
- with self._condition:
- if self._closed:
- return
- kept: Deque[_ConnectionEntry] = deque()
- while self._available:
- entry = self._available.popleft()
- if self._should_discard(entry, now):
- self._discard_entry_locked(entry)
- continue
- if (
- self._config.keepalive
- and self._config.keepalive_interval
- and now - entry.last_check >= self._config.keepalive_interval
- ):
- to_check.append(entry)
- else:
- kept.append(entry)
- self._available = kept
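- for entry in list(to_check):
- if self._validate(entry):
- with self._condition:
- if self._closed:
- # Pool closed during validation; do not leak the connection.
- self._discard_entry_locked(entry)
- else:
- entry.last_used = time.monotonic()
- self._available.append(entry)
- # Wake borrowers that found the pool empty during the check.
- self._condition.notify_all()
- else:
- self._discard_entry(entry)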
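- # Refill to min_size, reserving each slot first so max_size is never exceeded.
- while True:
- with self._condition:
- if self._closed or self._total >= self._config.min_size:
- return
- self._total += 1
- entry = self._create_entry(reserved=True)
- if entry is None:
- # Creation failed and the slot was released; retry on the next cycle.
- return
- with self._condition:
- if self._closed:
- self._discard_entry_locked(entry)
- return
- self._available.append(entry)
- self._condition.notify_all()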
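The two core mechanisms of the module above, a proxy that returns its connection exactly once and a blocking borrow bounded by a monotonic deadline, can be tried without a database. The following is a minimal sketch under stated assumptions, not the module's implementation: `FakeConn`, `Lease`, and `TinyPool` are hypothetical names introduced only for illustration, and the fake connection is created while the lock is held, which the real pool avoids by reserving a slot first and connecting outside the lock.

```python
import threading
import time
from collections import deque


class FakeConn:
    """Hypothetical stand-in for a database connection."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


class Lease:
    """Proxy that hands its connection back to the pool exactly once."""

    def __init__(self, pool, conn):
        self._pool = pool
        self._conn = conn
        self._returned = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # An exception inside the block marks the connection as broken.
        self.close(broken=exc_type is not None)

    def close(self, broken=False):
        if self._returned:  # repeated close() calls are no-ops
            return
        self._returned = True
        self._pool.give_back(self._conn, broken)


class TinyPool:
    """Illustrative pool: fixed capacity, blocking borrow with a deadline."""

    def __init__(self, max_size):
        self._max_size = max_size
        self._total = 0
        self._idle = deque()
        self._cond = threading.Condition()

    def borrow(self, timeout=None):
        # The deadline uses the monotonic clock, so wall-clock jumps do not matter.
        deadline = None if timeout is None else time.monotonic() + timeout
        with self._cond:
            while True:
                if self._idle:
                    return Lease(self, self._idle.popleft())
                if self._total < self._max_size:
                    self._total += 1
                    return Lease(self, FakeConn())
                remaining = None if deadline is None else deadline - time.monotonic()
                if remaining is not None and remaining <= 0:
                    raise TimeoutError("no connection available")
                # The availability check and the wait share one lock,
                # so a notification cannot be missed between them.
                self._cond.wait(remaining)

    def give_back(self, conn, broken):
        with self._cond:
            if broken:
                conn.close()
                self._total -= 1  # free the slot for a fresh connection
            else:
                self._idle.append(conn)
            self._cond.notify()

    def stats(self):
        with self._cond:
            return {"total": self._total, "idle": len(self._idle)}
```

A caller would write `with pool.borrow(timeout=1.0) as lease: ...`. A clean exit puts the connection back on the idle queue, while an exception closes it and frees the slot, which mirrors the `broken` flag handled by `_return_entry` in the module above.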