# opengauss_pool_hardened.py
"""
Hardened connection pool for py_opengauss inspired by dbutils.pooled_db.

Key differences from the lightweight pool:

* Uses a proxy wrapper so connections are always returned (close/__del__/with).
* Resets sessions on return (rollback + optional setsession).
* Avoids deadlocks by using an RLock and separating discard-with-lock paths.
* Supports blocking/non-blocking acquire with timeout and max usage cycling.
* Uses monotonic clocks for all interval calculations to avoid wall-clock jumps.
"""
from __future__ import annotations

import logging
import threading
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Deque, Dict, Optional, Sequence

import py_opengauss
  19. logger = logging.getLogger(__name__)
  20. class PoolError(Exception):
  21. """Base pool error."""
  22. class PoolClosedError(PoolError):
  23. """Pool has been closed."""
  24. class PoolExhaustedError(PoolError):
  25. """Pool is at capacity and blocking is disabled or timed out."""
  26. @dataclass
  27. class ConnectionPoolConfig:
  28. dsn: str
  29. min_size: int = 1
  30. max_size: int = 10
  31. blocking: bool = True
  32. acquire_timeout: Optional[float] = None
  33. idle_timeout: float = 300.0
  34. max_lifetime: float = 3600.0
  35. max_usage: Optional[int] = None
  36. test_on_borrow: bool = True
  37. test_sql: str = "SELECT 1"
  38. keepalive: bool = True
  39. keepalive_interval: float = 60.0
  40. health_check_interval: float = 30.0
  41. reset_on_return: bool = True
  42. setsession: Optional[Sequence[str]] = None
  43. connect_kwargs: Optional[Dict[str, Any]] = None
  44. @dataclass
  45. class _ConnectionEntry:
  46. conn: Any
  47. created_at: float
  48. last_used: float
  49. last_check: float
  50. usage_count: int = 0
  51. class _PooledConnectionProxy:
  52. """Lightweight proxy that returns connections to the pool on close/del."""
  53. def __init__(self, pool: "OpenGaussConnectionPool", entry: _ConnectionEntry):
  54. self._pool = pool
  55. self._entry = entry
  56. self._returned = False
  57. def __getattr__(self, item: str) -> Any:
  58. return getattr(self._entry.conn, item)
  59. def __enter__(self) -> "_PooledConnectionProxy":
  60. return self
  61. def __exit__(self, exc_type, exc, tb) -> None:
  62. self.close(broken=bool(exc_type))
  63. def close(self, broken: bool = False) -> None:
  64. if self._returned:
  65. return
  66. self._returned = True
  67. self._pool._return_entry(self._entry, broken=broken)
  68. def mark_broken(self) -> None:
  69. """Explicitly mark this connection as broken so it gets discarded."""
  70. self.close(broken=True)
  71. def __del__(self):
  72. # Best-effort return to pool to avoid leaks if user forgets to close.
  73. try:
  74. self.close()
  75. except Exception:
  76. pass
  77. def __repr__(self) -> str:
  78. return f"<PooledOpenGaussConnection id={id(self._entry.conn)} returned={self._returned}>"
  79. class OpenGaussConnectionPool:
  80. """
  81. Hardened, thread-safe connection pool for py_opengauss.
  82. Connections are wrapped in a proxy that ensures return to the pool on close
  83. or object finalization. Pool operations rely on monotonic clocks to avoid
  84. issues from system clock jumps.
  85. """
  86. def __init__(self, config: ConnectionPoolConfig):
  87. if config.min_size < 0 or config.max_size <= 0:
  88. raise ValueError("Pool sizes must be positive")
  89. if config.min_size > config.max_size:
  90. raise ValueError("min_size cannot be greater than max_size")
  91. self._config = config
  92. self._lock = threading.RLock()
  93. self._condition = threading.Condition(self._lock)
  94. self._available: Deque[_ConnectionEntry] = deque()
  95. self._in_use: Dict[int, _ConnectionEntry] = {}
  96. self._total = 0
  97. self._create_failures = 0
  98. self._discarded = 0
  99. self._closed = False
  100. self._stop_event = threading.Event()
  101. self._housekeeper = threading.Thread(
  102. target=self._housekeeping, name="opengauss-pool-housekeeper", daemon=True
  103. )
  104. self._initial_fill()
  105. self._housekeeper.start()
  106. def borrow(self, timeout: Optional[float] = None) -> _PooledConnectionProxy:
  107. """
  108. Borrow a connection from the pool.
  109. If test_on_borrow is enabled, the connection will be validated with the
  110. configured SQL before being returned. When blocking is False, pool
  111. exhaustion immediately raises PoolExhaustedError.
  112. """
  113. deadline = None
  114. effective_timeout = timeout
  115. if effective_timeout is None:
  116. effective_timeout = self._config.acquire_timeout
  117. if effective_timeout is not None:
  118. deadline = time.monotonic() + effective_timeout
  119. while True:
  120. entry = self._try_acquire_available()
  121. if entry:
  122. proxy = self._prepare_borrow(entry)
  123. if proxy:
  124. return proxy
  125. continue
  126. to_create = self._reserve_slot_for_new()
  127. if to_create:
  128. entry = self._create_entry(reserved=True)
  129. if entry is None:
  130. # Creation failed; avoid tight loop when DB is unavailable.
  131. self._wait_for_availability(deadline)
  132. continue
  133. proxy = self._prepare_borrow(entry)
  134. if proxy:
  135. return proxy
  136. continue
  137. # Pool is at capacity and nothing available.
  138. if not self._config.blocking:
  139. raise PoolExhaustedError("Pool is at capacity and blocking is disabled")
  140. self._wait_for_availability(deadline)
  141. def return_connection(self, conn: Any, had_error: bool = False) -> None:
  142. """
  143. Return a connection (or proxy) to the pool.
  144. If had_error is True, the connection will be closed and removed to
  145. avoid reusing broken connections.
  146. """
  147. entry = None
  148. if isinstance(conn, _PooledConnectionProxy):
  149. conn.close(broken=had_error)
  150. return
  151. with self._condition:
  152. entry = self._in_use.pop(id(conn), None)
  153. if entry is None:
  154. try:
  155. conn.close()
  156. finally:
  157. return
  158. self._return_entry(entry, broken=had_error)
  159. @contextmanager
  160. def connection(self, timeout: Optional[float] = None):
  161. """
  162. Context manager helper that returns the connection to the pool
  163. automatically and discards it if an exception is raised.
  164. """
  165. proxy = self.borrow(timeout=timeout)
  166. try:
  167. yield proxy
  168. except Exception:
  169. proxy.close(broken=True)
  170. raise
  171. else:
  172. proxy.close()
  173. def close(self) -> None:
  174. """Close the pool and all managed connections."""
  175. with self._condition:
  176. if self._closed:
  177. return
  178. self._closed = True
  179. to_close = list(self._available)
  180. self._available.clear()
  181. to_close.extend(self._in_use.values())
  182. self._in_use.clear()
  183. self._condition.notify_all()
  184. self._stop_event.set()
  185. for entry in to_close:
  186. self._close_entry(entry)
  187. self._total = 0
  188. if self._housekeeper.is_alive():
  189. self._housekeeper.join(timeout=1.0)
  190. def stats(self) -> Dict[str, Any]:
  191. """Lightweight snapshot of pool counters for observability/testing."""
  192. with self._condition:
  193. return {
  194. "total": self._total,
  195. "available": len(self._available),
  196. "in_use": len(self._in_use),
  197. "closed": self._closed,
  198. "create_failures": self._create_failures,
  199. "discarded": self._discarded,
  200. }
  201. # Internal helpers -----------------------------------------------------
  202. def _prepare_borrow(self, entry: _ConnectionEntry) -> Optional[_PooledConnectionProxy]:
  203. if self._config.test_on_borrow and not self._validate(entry):
  204. self._discard_entry(entry)
  205. return None
  206. entry.last_used = time.monotonic()
  207. entry.usage_count += 1
  208. with self._condition:
  209. self._in_use[id(entry.conn)] = entry
  210. return _PooledConnectionProxy(self, entry)
  211. def _initial_fill(self) -> None:
  212. for _ in range(self._config.min_size):
  213. entry = self._create_entry()
  214. if entry:
  215. with self._condition:
  216. self._available.append(entry)
  217. self._total += 1
  218. def _reserve_slot_for_new(self) -> bool:
  219. with self._condition:
  220. if self._closed:
  221. raise PoolClosedError("Connection pool is closed")
  222. if self._total >= self._config.max_size:
  223. return False
  224. self._total += 1
  225. return True
  226. def _try_acquire_available(self) -> Optional[_ConnectionEntry]:
  227. now = time.monotonic()
  228. with self._condition:
  229. if self._closed:
  230. raise PoolClosedError("Connection pool is closed")
  231. while self._available:
  232. entry = self._available.popleft()
  233. if self._should_discard(entry, now):
  234. self._discard_entry_locked(entry)
  235. continue
  236. return entry
  237. return None
  238. def _create_entry(self, reserved: bool = False) -> Optional[_ConnectionEntry]:
  239. try:
  240. conn = py_opengauss.open(self._config.dsn, **(self._config.connect_kwargs or {}))
  241. now = time.monotonic()
  242. return _ConnectionEntry(conn=conn, created_at=now, last_used=now, last_check=now)
  243. except Exception:
  244. logger.exception("Failed to create new py_opengauss connection")
  245. self._create_failures += 1
  246. if reserved:
  247. with self._condition:
  248. self._total = max(0, self._total - 1)
  249. self._condition.notify_all()
  250. return None
  251. def _validate(self, entry: _ConnectionEntry) -> bool:
  252. try:
  253. stmt = entry.conn.prepare(self._config.test_sql)
  254. stmt()
  255. entry.last_check = time.monotonic()
  256. return True
  257. except Exception:
  258. logger.warning("Validation failed; discarding connection", exc_info=True)
  259. return False
  260. def _should_discard(self, entry: _ConnectionEntry, now: Optional[float] = None) -> bool:
  261. now = now or time.monotonic()
  262. if self._config.max_lifetime and now - entry.created_at >= self._config.max_lifetime:
  263. return True
  264. if self._config.idle_timeout and now - entry.last_used >= self._config.idle_timeout:
  265. return True
  266. if self._config.max_usage and entry.usage_count >= self._config.max_usage:
  267. return True
  268. return False
  269. def _reset_connection(self, entry: _ConnectionEntry) -> bool:
  270. if not self._config.reset_on_return and not self._config.setsession:
  271. return True
  272. try:
  273. if self._config.reset_on_return:
  274. rollback = getattr(entry.conn, "rollback", None)
  275. if callable(rollback):
  276. rollback()
  277. else:
  278. entry.conn.execute("ROLLBACK")
  279. if self._config.setsession:
  280. for sql in self._config.setsession:
  281. stmt = entry.conn.prepare(sql)
  282. stmt()
  283. return True
  284. except Exception:
  285. logger.warning("Failed to reset connection; discarding", exc_info=True)
  286. return False
  287. def _close_entry(self, entry: _ConnectionEntry) -> None:
  288. try:
  289. entry.conn.close()
  290. except Exception:
  291. logger.debug("Failed closing connection", exc_info=True)
  292. def _discard_entry_locked(self, entry: _ConnectionEntry) -> None:
  293. """Discard while holding the pool lock."""
  294. self._close_entry(entry)
  295. self._total = max(0, self._total - 1)
  296. self._discarded += 1
  297. self._condition.notify_all()
  298. def _discard_entry(self, entry: _ConnectionEntry) -> None:
  299. with self._condition:
  300. self._discard_entry_locked(entry)
  301. def _return_entry(self, entry: _ConnectionEntry, broken: bool = False) -> None:
  302. with self._condition:
  303. stored = self._in_use.pop(id(entry.conn), None)
  304. if stored is None:
  305. return
  306. entry = stored
  307. if broken or not self._reset_connection(entry) or self._should_discard(entry):
  308. self._discard_entry(entry)
  309. return
  310. with self._condition:
  311. if self._closed:
  312. self._discard_entry_locked(entry)
  313. return
  314. entry.last_used = time.monotonic()
  315. self._available.append(entry)
  316. self._condition.notify_all()
  317. def _wait_for_availability(self, deadline: Optional[float]) -> None:
  318. if deadline is not None:
  319. remaining = deadline - time.monotonic()
  320. if remaining <= 0:
  321. raise PoolExhaustedError("Timed out waiting for connection from pool")
  322. with self._condition:
  323. if self._closed:
  324. raise PoolClosedError("Connection pool is closed")
  325. self._condition.wait(timeout=remaining)
  326. else:
  327. with self._condition:
  328. if self._closed:
  329. raise PoolClosedError("Connection pool is closed")
  330. self._condition.wait()
  331. def _housekeeping(self) -> None:
  332. interval = max(1.0, self._config.health_check_interval)
  333. while not self._stop_event.wait(interval):
  334. self._perform_health_check()
  335. def _perform_health_check(self) -> None:
  336. now = time.monotonic()
  337. to_check: Deque[_ConnectionEntry] = deque()
  338. with self._condition:
  339. if self._closed:
  340. return
  341. kept: Deque[_ConnectionEntry] = deque()
  342. while self._available:
  343. entry = self._available.popleft()
  344. if self._should_discard(entry, now):
  345. self._discard_entry_locked(entry)
  346. continue
  347. if (
  348. self._config.keepalive
  349. and self._config.keepalive_interval
  350. and now - entry.last_check >= self._config.keepalive_interval
  351. ):
  352. to_check.append(entry)
  353. else:
  354. kept.append(entry)
  355. self._available = kept
  356. for entry in list(to_check):
  357. if self._validate(entry):
  358. with self._condition:
  359. if not self._closed:
  360. entry.last_used = time.monotonic()
  361. self._available.append(entry)
  362. else:
  363. self._discard_entry(entry)
  364. with self._condition:
  365. needed = max(0, self._config.min_size - self._total)
  366. for _ in range(needed):
  367. entry = self._create_entry()
  368. if entry:
  369. with self._condition:
  370. if self._closed:
  371. self._close_entry(entry)
  372. return
  373. self._available.append(entry)
  374. self._total += 1
  375. self._condition.notify_all()