"""
Lightweight connection pool built on top of py_opengauss.

The pool keeps a configurable number of connections ready, optionally validates
them on borrow (``test_on_borrow``), runs periodic health checks and keepalives
in a background thread, and discards connections that fail SQL execution or
exceed the idle or lifetime thresholds.
"""
from __future__ import annotations

import logging
import threading
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Deque, Dict, Optional

import py_opengauss

logger = logging.getLogger(__name__)
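

@dataclass
class ConnectionPoolConfig:
    dsn: str  # passed as-is to py_opengauss.open()
    min_size: int = 1
    max_size: int = 10
    idle_timeout: float = 300.0  # seconds; 0 disables idle eviction
    max_lifetime: float = 3600.0  # seconds; 0 disables lifetime eviction
    test_on_borrow: bool = True
    test_sql: str = "SELECT 1"
    keepalive: bool = True
    keepalive_interval: float = 60.0  # seconds between checks of idle connections; 0 disables
    health_check_interval: float = 30.0  # seconds; values below 1.0 are raised to 1.0
    # Extra keyword arguments forwarded to py_opengauss.open().
    connect_kwargs: Optional[Dict[str, Any]] = None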


@dataclass
class _ConnectionEntry:
    # Bookkeeping for one physical connection; timestamps come from time.time().
    conn: Any
    created_at: float
    last_used: float
    last_check: float


class OpenGaussConnectionPool:
    """
    Thread-safe connection pool for py_opengauss.

    borrow() hands out the raw py_opengauss connection objects, with no proxy
    wrapper. When an error is detected during SQL execution, callers should
    mark the connection as broken via ``return_connection(conn, had_error=True)``
    so that it is evicted immediately instead of being reused.
    """

    def __init__(self, config: ConnectionPoolConfig):
        if config.min_size < 0 or config.max_size <= 0:
            raise ValueError("min_size must be >= 0 and max_size must be > 0")
        if config.min_size > config.max_size:
            raise ValueError("min_size cannot be greater than max_size")
        self._config = config
        self._lock = threading.Lock()
        self._condition = threading.Condition(self._lock)
        self._available: Deque[_ConnectionEntry] = deque()
        # Keyed by id(conn) so return_connection() can find the entry.
        self._in_use: Dict[int, _ConnectionEntry] = {}
        # Open connections plus slots reserved for connections still being created.
        self._total = 0
        self._closed = False
        self._housekeeper = threading.Thread(
            target=self._housekeeping, name="opengauss-pool-housekeeper", daemon=True
        )
        self._initial_fill()
        self._housekeeper.start()

    def borrow(self, timeout: Optional[float] = None) -> Any:
        """
        Borrow a connection from the pool.

        If test_on_borrow is enabled, the connection is validated with the
        configured SQL before it is returned. ``timeout`` is in seconds and
        None waits indefinitely. Raises TimeoutError if no connection becomes
        available in time and RuntimeError if the pool is closed.
        """
        # Compare against None so that timeout=0 means "do not wait", not "wait forever".
        deadline = time.monotonic() + timeout if timeout is not None else None
        while True:
            entry = self._try_acquire_available()
            if entry is not None:
                if self._config.test_on_borrow and not self._validate(entry):
                    self._discard(entry)
                    continue
                with self._condition:
                    if not self._closed:
                        entry.last_used = time.time()
                        self._in_use[id(entry.conn)] = entry
                        return entry.conn
                # The pool was closed while this entry was being validated.
                self._discard(entry)
                raise RuntimeError("Connection pool is closed")
            if self._reserve_slot_for_new():
                entry = self._create_entry(reserved=True)
                if entry is None:
                    # Creation failed; back off briefly to avoid a tight loop while the DB is unavailable.
                    with self._condition:
                        if self._closed:
                            raise RuntimeError("Connection pool is closed")
                        if deadline is not None:
                            remaining = deadline - time.monotonic()
                            if remaining <= 0:
                                raise TimeoutError("Timed out waiting for connection")
                            self._condition.wait(timeout=min(0.1, remaining))
                        else:
                            self._condition.wait(timeout=0.1)
                    continue
                with self._condition:
                    if not self._closed:
                        self._in_use[id(entry.conn)] = entry
                        return entry.conn
                # The pool was closed while the connection was being opened.
                self._discard(entry)
                raise RuntimeError("Connection pool is closed")
            # Pool is full: wait for a return, a discard, or the deadline.
            with self._condition:
                if self._closed:
                    raise RuntimeError("Connection pool is closed")
                # Re-check under the lock: a connection may have been returned (and its
                # notify() already sent) since the checks above, which would be a lost wakeup.
                if self._available or self._total < self._config.max_size:
                    continue
                if deadline is not None:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        raise TimeoutError("Timed out waiting for connection")
                    self._condition.wait(timeout=remaining)
                else:
                    self._condition.wait()

    def return_connection(self, conn: Any, had_error: bool = False) -> None:
        """
        Return a borrowed connection to the pool.

        If had_error is True, the connection is closed and removed so that a
        broken connection is never reused. Connections the pool does not track
        (for example, ones still borrowed when close() ran) are simply closed.
        """
        with self._condition:
            entry = self._in_use.pop(id(conn), None)
        if entry is None:
            # Unknown connection; close it defensively, outside the lock.
            try:
                conn.close()
            except Exception:
                logger.debug("Failed closing untracked connection", exc_info=True)
            return
        if had_error:
            self._discard(entry)
            return
        entry.last_used = time.time()
        if self._should_discard(entry):
            self._discard(entry)
            return
        with self._condition:
            if not self._closed:
                self._available.append(entry)
                self._condition.notify()
                return
        # The pool was closed after the entry left _in_use; close it instead of leaking it.
        self._discard(entry)

    @contextmanager
    def connection(self, timeout: Optional[float] = None):
        """
        Context manager that borrows a connection and always gives it back.

        The connection is discarded if the block raises anything (including
        KeyboardInterrupt) and returned to the pool for reuse otherwise.
        """
        conn = self.borrow(timeout=timeout)
        had_error = True
        try:
            yield conn
            had_error = False
        finally:
            self.return_connection(conn, had_error=had_error)

    def close(self) -> None:
        """
        Close the pool and every connection it manages.

        Connections that are currently borrowed are closed as well; returning
        them afterwards is harmless. Later borrow() calls raise RuntimeError.
        """
        with self._condition:
            if self._closed:
                return
            self._closed = True
            to_close = list(self._available)
            self._available.clear()
            to_close.extend(self._in_use.values())
            self._in_use.clear()
            self._total = 0
            # Wake blocked borrowers so they observe _closed and raise.
            self._condition.notify_all()
        # Close outside the lock so slow network I/O does not block other threads.
        for entry in to_close:
            try:
                entry.conn.close()
            except Exception:
                logger.debug("Failed closing connection during pool shutdown", exc_info=True)

    def stats(self) -> Dict[str, Any]:
        """
        Lightweight snapshot of pool counters for observability and testing.
        """
        with self._condition:
            return {
                "total": self._total,
                "available": len(self._available),
                "in_use": len(self._in_use),
                "closed": self._closed,
            }

    # Internal helpers -----------------------------------------------------

    def _initial_fill(self) -> None:
        for _ in range(self._config.min_size):
            entry = self._create_entry()
            if entry:
                with self._condition:
                    self._available.append(entry)
                    self._total += 1

    def _reserve_slot_for_new(self) -> bool:
        # Count a not-yet-created connection in _total; returns False if the pool is full.
        with self._condition:
            if self._closed:
                raise RuntimeError("Connection pool is closed")
            if self._total >= self._config.max_size:
                return False
            self._total += 1
            return True

    def _try_acquire_available(self) -> Optional[_ConnectionEntry]:
        # Pop the first usable idle entry, evicting expired ones found along the way.
        now = time.time()
        to_discard = []
        entry: Optional[_ConnectionEntry] = None
        with self._condition:
            if self._closed:
                raise RuntimeError("Connection pool is closed")
            while self._available:
                candidate = self._available.popleft()
                if self._should_discard(candidate, now):
                    to_discard.append(candidate)
                else:
                    entry = candidate
                    break
        # Always close expired entries (even when a usable one was found) so their slots are released.
        for stale in to_discard:
            self._discard(stale)
        return entry

    def _create_entry(self, reserved: bool = False) -> Optional[_ConnectionEntry]:
        # If reserved is True, the caller already counted this connection in _total
        # and the reservation is released on failure.
        try:
            conn = py_opengauss.open(self._config.dsn, **(self._config.connect_kwargs or {}))
        except Exception:
            logger.exception("Failed to create new py_opengauss connection")
            if reserved:
                with self._condition:
                    self._total = max(0, self._total - 1)
                    # A slot was freed; let a waiting borrower try to use it.
                    self._condition.notify()
            return None
        now = time.time()
        return _ConnectionEntry(conn=conn, created_at=now, last_used=now, last_check=now)

    def _validate(self, entry: _ConnectionEntry) -> bool:
        try:
            stmt = entry.conn.prepare(self._config.test_sql)
            stmt()
            entry.last_check = time.time()
            return True
        except Exception:
            logger.warning("Validation failed; discarding connection", exc_info=True)
            return False

    def _should_discard(self, entry: _ConnectionEntry, now: Optional[float] = None) -> bool:
        # True if the entry has exceeded max_lifetime or idle_timeout (0 disables either check).
        if now is None:
            now = time.time()
        if self._config.max_lifetime and now - entry.created_at >= self._config.max_lifetime:
            return True
        if self._config.idle_timeout and now - entry.last_used >= self._config.idle_timeout:
            return True
        return False
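
    def _discard(self, entry: _ConnectionEntry) -> None:
        # Must be called without holding self._lock (it is not reentrant): close the
        # connection first, then release its slot and wake one waiting borrower.
        try:
            entry.conn.close()
        except Exception:
            logger.debug("Failed closing connection", exc_info=True)
        with self._condition:
            self._total = max(0, self._total - 1)
            self._condition.notify()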
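
    def _housekeeping(self) -> None:
        interval = max(1.0, self._config.health_check_interval)
        while not self._closed:
            try:
                self._perform_health_check()
            except Exception:
                # Keep the thread alive; one failed round must not stop future checks.
                logger.exception("Connection pool health check failed")
            time.sleep(interval)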

    def _perform_health_check(self) -> None:
        now = time.time()
        to_check = []
        to_discard = []
        with self._condition:
            if self._closed:
                return
            kept: Deque[_ConnectionEntry] = deque()
            while self._available:
                entry = self._available.popleft()
                if self._should_discard(entry, now):
                    to_discard.append(entry)
                elif (
                    self._config.keepalive
                    and self._config.keepalive_interval
                    and now - entry.last_check >= self._config.keepalive_interval
                ):
                    to_check.append(entry)
                else:
                    kept.append(entry)
            self._available = kept
        for entry in to_discard:
            self._discard(entry)
        # Run keepalive checks outside the lock to avoid blocking borrowers.
        for entry in to_check:
            if not self._validate(entry):
                self._discard(entry)
                continue
            with self._condition:
                closed = self._closed
                if not closed:
                    self._available.append(entry)
                    self._condition.notify()
            if closed:
                # The pool was closed during validation; close the entry instead of leaking it.
                self._discard(entry)
        # Top the pool back up to min_size. Slots are reserved under the lock so that
        # concurrent borrowers cannot push _total past max_size.
        with self._condition:
            if self._closed:
                return
            needed = max(0, self._config.min_size - self._total)
            self._total += needed
        for _ in range(needed):
            entry = self._create_entry(reserved=True)
            if entry is None:
                continue
            with self._condition:
                closed = self._closed
                if not closed:
                    self._available.append(entry)
                    self._condition.notify()
            if closed:
                self._discard(entry)
                return
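

# Example usage (a sketch only, not part of the pool). The DSN is read from a
# hypothetical OPENGAUSS_DSN environment variable; its format is whatever
# py_opengauss.open() accepts in your installation.
#
#     import os
#
#     pool = OpenGaussConnectionPool(
#         ConnectionPoolConfig(dsn=os.environ["OPENGAUSS_DSN"], min_size=2, max_size=5)
#     )
#     try:
#         # Preferred: the context manager discards the connection if the block raises.
#         with pool.connection(timeout=5.0) as conn:
#             rows = conn.prepare("SELECT 1")()
#
#         # Manual borrow/return: report failures with had_error=True so the
#         # broken connection is evicted instead of being handed out again.
#         conn = pool.borrow(timeout=5.0)
#         try:
#             conn.prepare("SELECT 1")()
#         except Exception:
#             pool.return_connection(conn, had_error=True)
#             raise
#         else:
#             pool.return_connection(conn)
#     finally:
#         pool.close()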