opengauss_pool.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. """
  2. Lightweight connection pool built on top of py_opengauss.
  3. The pool keeps a configurable number of connections ready, validates them on
  4. borrow when requested, performs periodic health checks/keepalives, and discards
  5. connections that fail SQL execution or exceed idle/lifetime thresholds.
  6. """
  7. from __future__ import annotations
  8. import logging
  9. import threading
  10. import time
  11. from collections import deque
  12. from contextlib import contextmanager
  13. from dataclasses import dataclass
  14. from typing import Any, Deque, Dict, Optional
  15. import py_opengauss
  16. logger = logging.getLogger(__name__)
  17. @dataclass
  18. class ConnectionPoolConfig:
  19. dsn: str
  20. min_size: int = 1
  21. max_size: int = 10
  22. idle_timeout: float = 300.0
  23. max_lifetime: float = 3600.0
  24. test_on_borrow: bool = True
  25. test_sql: str = "SELECT 1"
  26. keepalive: bool = True
  27. keepalive_interval: float = 60.0
  28. health_check_interval: float = 30.0
  29. connect_kwargs: Optional[Dict[str, Any]] = None
  30. @dataclass
  31. class _ConnectionEntry:
  32. conn: Any
  33. created_at: float
  34. last_used: float
  35. last_check: float
  36. class OpenGaussConnectionPool:
  37. """
  38. Thread-safe connection pool for py_opengauss.
  39. Borrowed connections are returned directly without wrappers. When an error
  40. is detected during SQL execution, callers should mark the connection as
  41. broken via `return_connection(conn, had_error=True)` to evict it quickly.
  42. """
  43. def __init__(self, config: ConnectionPoolConfig):
  44. if config.min_size < 0 or config.max_size <= 0:
  45. raise ValueError("Pool sizes must be positive")
  46. if config.min_size > config.max_size:
  47. raise ValueError("min_size cannot be greater than max_size")
  48. self._config = config
  49. self._lock = threading.Lock()
  50. self._condition = threading.Condition(self._lock)
  51. self._available: Deque[_ConnectionEntry] = deque()
  52. self._in_use: Dict[int, _ConnectionEntry] = {}
  53. self._total = 0
  54. self._closed = False
  55. self._housekeeper = threading.Thread(
  56. target=self._housekeeping, name="opengauss-pool-housekeeper", daemon=True
  57. )
  58. self._initial_fill()
  59. self._housekeeper.start()
  60. def borrow(self, timeout: Optional[float] = None) -> Any:
  61. """
  62. Borrow a connection from the pool.
  63. If test_on_borrow is enabled, the connection will be validated with the
  64. configured SQL before being returned.
  65. """
  66. deadline = time.monotonic() + timeout if timeout else None
  67. while True:
  68. entry = self._try_acquire_available()
  69. if entry:
  70. if self._config.test_on_borrow and not self._validate(entry):
  71. self._discard(entry)
  72. continue
  73. entry.last_used = time.time()
  74. with self._condition:
  75. self._in_use[id(entry.conn)] = entry
  76. return entry.conn
  77. to_create = self._reserve_slot_for_new()
  78. if to_create:
  79. entry = self._create_entry(reserved=True)
  80. if entry is None:
  81. # Creation failed; avoid tight loop when DB is unavailable.
  82. with self._condition:
  83. if self._closed:
  84. raise RuntimeError("Connection pool is closed")
  85. if deadline is not None:
  86. remaining = deadline - time.monotonic()
  87. if remaining <= 0:
  88. raise TimeoutError("Timed out waiting for connection")
  89. self._condition.wait(timeout=min(0.1, remaining))
  90. else:
  91. self._condition.wait(timeout=0.1)
  92. continue
  93. with self._condition:
  94. self._in_use[id(entry.conn)] = entry
  95. return entry.conn
  96. # Wait for a return or until timed out.
  97. with self._condition:
  98. if self._closed:
  99. raise RuntimeError("Connection pool is closed")
  100. if deadline is not None:
  101. remaining = deadline - time.monotonic()
  102. if remaining <= 0:
  103. raise TimeoutError("Timed out waiting for connection")
  104. self._condition.wait(timeout=remaining)
  105. else:
  106. self._condition.wait()
  107. def return_connection(self, conn: Any, had_error: bool = False) -> None:
  108. """
  109. Return a connection to the pool.
  110. If had_error is True, the connection will be closed and removed to
  111. avoid reusing broken connections.
  112. """
  113. entry = None
  114. with self._condition:
  115. entry = self._in_use.pop(id(conn), None)
  116. if entry is None:
  117. # Unknown connection; close it defensively.
  118. try:
  119. conn.close()
  120. finally:
  121. return
  122. if had_error:
  123. self._discard(entry)
  124. else:
  125. entry.last_used = time.time()
  126. if self._should_discard(entry):
  127. self._discard(entry)
  128. else:
  129. with self._condition:
  130. self._available.append(entry)
  131. self._condition.notify()
  132. @contextmanager
  133. def connection(self, timeout: Optional[float] = None):
  134. """
  135. Context manager helper that returns the connection to the pool
  136. automatically and discards it if an exception is raised.
  137. """
  138. conn = self.borrow(timeout=timeout)
  139. try:
  140. yield conn
  141. except Exception:
  142. self.return_connection(conn, had_error=True)
  143. raise
  144. else:
  145. self.return_connection(conn)
  146. def close(self) -> None:
  147. """
  148. Close the pool and all managed connections.
  149. """
  150. with self._condition:
  151. if self._closed:
  152. return
  153. self._closed = True
  154. to_close = list(self._available)
  155. self._available.clear()
  156. to_close.extend(self._in_use.values())
  157. self._in_use.clear()
  158. self._condition.notify_all()
  159. for entry in to_close:
  160. try:
  161. entry.conn.close()
  162. except Exception:
  163. logger.debug("Failed closing connection during pool shutdown", exc_info=True)
  164. self._total = 0
  165. def stats(self) -> Dict[str, Any]:
  166. """
  167. Lightweight snapshot of pool counters for observability/testing.
  168. """
  169. with self._condition:
  170. return {
  171. "total": self._total,
  172. "available": len(self._available),
  173. "in_use": len(self._in_use),
  174. "closed": self._closed,
  175. }
  176. # Internal helpers -----------------------------------------------------
  177. def _initial_fill(self) -> None:
  178. for _ in range(self._config.min_size):
  179. entry = self._create_entry()
  180. if entry:
  181. with self._condition:
  182. self._available.append(entry)
  183. self._total += 1
  184. def _reserve_slot_for_new(self) -> bool:
  185. with self._condition:
  186. if self._closed:
  187. raise RuntimeError("Connection pool is closed")
  188. if self._total >= self._config.max_size:
  189. return False
  190. self._total += 1
  191. return True
  192. def _try_acquire_available(self) -> Optional[_ConnectionEntry]:
  193. now = time.time()
  194. to_discard: Deque[_ConnectionEntry] = deque()
  195. with self._condition:
  196. if self._closed:
  197. raise RuntimeError("Connection pool is closed")
  198. while self._available:
  199. entry = self._available.popleft()
  200. if self._should_discard(entry, now):
  201. to_discard.append(entry)
  202. continue
  203. return entry
  204. for entry in list(to_discard):
  205. self._discard(entry)
  206. return None
  207. def _create_entry(self, reserved: bool = False) -> Optional[_ConnectionEntry]:
  208. try:
  209. conn = py_opengauss.open(self._config.dsn, **(self._config.connect_kwargs or {}))
  210. now = time.time()
  211. return _ConnectionEntry(conn=conn, created_at=now, last_used=now, last_check=now)
  212. except Exception:
  213. logger.exception("Failed to create new py_opengauss connection")
  214. if reserved:
  215. with self._condition:
  216. self._total = max(0, self._total - 1)
  217. return None
  218. def _validate(self, entry: _ConnectionEntry) -> bool:
  219. try:
  220. stmt = entry.conn.prepare(self._config.test_sql)
  221. stmt()
  222. entry.last_check = time.time()
  223. return True
  224. except Exception:
  225. logger.warning("Validation failed; discarding connection", exc_info=True)
  226. return False
  227. def _should_discard(self, entry: _ConnectionEntry, now: Optional[float] = None) -> bool:
  228. now = now or time.time()
  229. if self._config.max_lifetime and now - entry.created_at >= self._config.max_lifetime:
  230. return True
  231. if self._config.idle_timeout and now - entry.last_used >= self._config.idle_timeout:
  232. return True
  233. return False
  234. def _discard(self, entry: _ConnectionEntry) -> None:
  235. try:
  236. entry.conn.close()
  237. except Exception:
  238. logger.debug("Failed closing connection", exc_info=True)
  239. with self._condition:
  240. self._total = max(0, self._total - 1)
  241. self._condition.notify()
  242. def _housekeeping(self) -> None:
  243. interval = max(1.0, self._config.health_check_interval)
  244. while True:
  245. if self._closed:
  246. return
  247. self._perform_health_check()
  248. time.sleep(interval)
  249. def _perform_health_check(self) -> None:
  250. now = time.time()
  251. to_check: Deque[_ConnectionEntry] = deque()
  252. to_discard: Deque[_ConnectionEntry] = deque()
  253. with self._condition:
  254. if self._closed:
  255. return
  256. kept: Deque[_ConnectionEntry] = deque()
  257. while self._available:
  258. entry = self._available.popleft()
  259. if self._should_discard(entry, now):
  260. to_discard.append(entry)
  261. continue
  262. if (
  263. self._config.keepalive
  264. and self._config.keepalive_interval
  265. and now - entry.last_check >= self._config.keepalive_interval
  266. ):
  267. to_check.append(entry)
  268. else:
  269. kept.append(entry)
  270. self._available = kept
  271. for entry in list(to_discard):
  272. self._discard(entry)
  273. # Run keepalive checks outside lock to avoid blocking borrowers.
  274. for entry in list(to_check):
  275. if self._validate(entry):
  276. with self._condition:
  277. if not self._closed:
  278. self._available.append(entry)
  279. else:
  280. self._discard(entry)
  281. # Ensure minimum pool size is preserved.
  282. with self._condition:
  283. needed = max(0, self._config.min_size - self._total)
  284. for _ in range(needed):
  285. entry = self._create_entry()
  286. if entry:
  287. with self._condition:
  288. if self._closed:
  289. entry.conn.close()
  290. return
  291. self._available.append(entry)
  292. self._total += 1
  293. self._condition.notify()