# test_opengauss_pool_hardened.py: integration tests for the hardened openGauss
# connection pool (gdb_utils.opengauss_pool_hardened).
import os
import threading
import time
import unittest
from typing import List

try:
    import py_opengauss  # type: ignore # noqa: F401
except ImportError:  # pragma: no cover - handled via skip
    py_opengauss = None

from gdb_utils.opengauss_pool_hardened import (
    ConnectionPoolConfig,
    OpenGaussConnectionPool,
    PoolExhaustedError,
)

  14. DB_HOST = "127.0.0.1"
  15. DB_PORT = 5432
  16. DB_USER = "gaussdb"
  17. DB_PASSWORD = "Ysl#1234"
  18. DB_NAME = "postgres"
  19. def build_dsn() -> str:
  20. return f"opengauss://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
  21. @unittest.skipIf(py_opengauss is None, "py_opengauss is required for integration tests")
  22. class OpenGaussHardenedPoolIntegrationTest(unittest.TestCase):
  23. seed_uid: int
  24. seed_process_id: str
  25. config: ConnectionPoolConfig
  26. pool: OpenGaussConnectionPool
  27. @classmethod
  28. def setUpClass(cls) -> None:
  29. base = int(time.time()) % 1_500_000_000
  30. cls.seed_uid = base + 200
  31. cls.seed_process_id = f"hardened_seed_{base}"
  32. conn = py_opengauss.open(build_dsn())
  33. try:
  34. delete_stmt = conn.prepare("DELETE FROM my_schema.table1 WHERE u_uid = $1")
  35. delete_stmt(cls.seed_uid)
  36. insert_stmt = conn.prepare(
  37. """
  38. INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
  39. VALUES ($1, $2, NOW(), $3)
  40. """
  41. )
  42. insert_stmt(cls.seed_uid, cls.seed_process_id, "seed")
  43. finally:
  44. conn.close()
  45. cls.config = ConnectionPoolConfig(
  46. dsn=build_dsn(),
  47. min_size=2,
  48. max_size=6,
  49. blocking=True,
  50. acquire_timeout=0.5,
  51. idle_timeout=10,
  52. max_lifetime=30,
  53. max_usage=20,
  54. test_on_borrow=True,
  55. test_sql="SELECT 1",
  56. keepalive=True,
  57. keepalive_interval=0.5,
  58. health_check_interval=0.5,
  59. reset_on_return=True,
  60. )
  61. cls.pool = OpenGaussConnectionPool(cls.config)
  62. @classmethod
  63. def tearDownClass(cls) -> None:
  64. if hasattr(cls, "pool") and cls.pool:
  65. cls.pool.close()
  66. conn = py_opengauss.open(build_dsn())
  67. try:
  68. delete_stmt = conn.prepare(
  69. "DELETE FROM my_schema.table1 WHERE process_id = $1"
  70. )
  71. delete_stmt(cls.seed_process_id)
  72. finally:
  73. conn.close()
  74. def _new_pool(self, **overrides) -> OpenGaussConnectionPool:
  75. config_dict = self.config.__dict__.copy()
  76. config_dict.update(overrides)
  77. return OpenGaussConnectionPool(ConnectionPoolConfig(**config_dict))
  78. def test_pool_limit_and_timeout(self):
  79. print("STEP: test_pool_limit_and_timeout - exhaust pool")
  80. borrowed = [self.pool.borrow() for _ in range(self.config.max_size)]
  81. print("STEP: test_pool_limit_and_timeout - timed borrow raises")
  82. with self.assertRaises(PoolExhaustedError):
  83. self.pool.borrow(timeout=0.3)
  84. print("STEP: test_pool_limit_and_timeout - return connections and assert stats")
  85. for conn in borrowed:
  86. conn.close()
  87. stats = self.pool.stats()
  88. self.assertLessEqual(stats["total"], self.config.max_size)
  89. self.assertEqual(stats["in_use"], 0)
  90. self.assertGreaterEqual(stats["available"], self.config.min_size)
  91. def test_serial_reads_with_lifecycle(self):
  92. print("STEP: test_serial_reads_with_lifecycle - serial reads with pauses")
  93. pool = self._new_pool(min_size=1, max_size=3, max_lifetime=1.0, idle_timeout=5.0)
  94. stats = {}
  95. seen_ids = set()
  96. stats_before = pool.stats()
  97. try:
  98. for idx in range(4):
  99. with pool.connection(timeout=1.0) as conn:
  100. seen_ids.add(id(conn._entry.conn))
  101. select_stmt = conn.prepare(
  102. "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
  103. )
  104. rows = select_stmt(self.seed_uid)
  105. self.assertEqual(rows[0][0], "seed")
  106. if idx == 1:
  107. time.sleep(1.2)
  108. else:
  109. time.sleep(0.2)
  110. stats = pool.stats()
  111. finally:
  112. pool.close()
  113. print("STEP: test_serial_reads_with_lifecycle - assert pool stats")
  114. self.assertLessEqual(stats["total"], 3)
  115. self.assertFalse(stats["closed"])
  116. self.assertTrue(
  117. len(seen_ids) >= 2 or stats["discarded"] > stats_before["discarded"]
  118. )
  119. def test_concurrent_reads(self):
  120. print("STEP: test_concurrent_reads - start threads")
  121. results: List[str] = []
  122. errors: List[Exception] = []
  123. lock = threading.Lock()
  124. def worker():
  125. try:
  126. with self.pool.connection(timeout=2.0) as conn:
  127. select_stmt = conn.prepare(
  128. "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
  129. )
  130. rows = select_stmt(self.seed_uid)
  131. with lock:
  132. results.append(rows[0][0])
  133. except Exception as exc: # noqa: BLE001 - capture any worker errors
  134. with lock:
  135. errors.append(exc)
  136. threads = [threading.Thread(target=worker) for _ in range(20)]
  137. for t in threads:
  138. t.start()
  139. for t in threads:
  140. t.join()
  141. print("STEP: test_concurrent_reads - assert results and pool stats")
  142. stats = self.pool.stats()
  143. self.assertEqual(len(errors), 0)
  144. self.assertEqual(len(results), 20)
  145. self.assertFalse(stats["closed"])
  146. self.assertLessEqual(stats["total"], self.config.max_size)
  147. def test_concurrent_batch_writes(self):
  148. print("STEP: test_concurrent_batch_writes - start writer threads")
  149. base = self.seed_uid + 20000
  150. process_id = f"hardened_batch_{base}"
  151. total_batches = 5
  152. batch_size = 20
  153. worker_count = 8
  154. errors: List[Exception] = []
  155. lock = threading.Lock()
  156. def writer(worker_idx: int):
  157. try:
  158. for batch_idx in range(total_batches):
  159. with self.pool.connection(timeout=2.0) as conn:
  160. insert_stmt = conn.prepare(
  161. """
  162. INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
  163. VALUES ($1, $2, NOW(), $3)
  164. """
  165. )
  166. for i in range(batch_size):
  167. uid = base + (worker_idx * 1000) + (batch_idx * batch_size) + i
  168. insert_stmt(uid, process_id, f"batch_{batch_idx}")
  169. except Exception as exc: # noqa: BLE001 - capture any worker errors
  170. with lock:
  171. errors.append(exc)
  172. threads = [threading.Thread(target=writer, args=(i,)) for i in range(worker_count)]
  173. for t in threads:
  174. t.start()
  175. for t in threads:
  176. t.join()
  177. print("STEP: test_concurrent_batch_writes - validate count and cleanup")
  178. with self.pool.connection(timeout=2.0) as conn:
  179. count_stmt = conn.prepare(
  180. "SELECT COUNT(*) FROM my_schema.table1 WHERE process_id = $1"
  181. )
  182. rows = count_stmt(process_id)
  183. total_inserted = rows[0][0]
  184. delete_stmt = conn.prepare(
  185. "DELETE FROM my_schema.table1 WHERE process_id = $1"
  186. )
  187. delete_stmt(process_id)
  188. stats = self.pool.stats()
  189. self.assertEqual(len(errors), 0)
  190. self.assertEqual(total_inserted, worker_count * total_batches * batch_size)
  191. self.assertFalse(stats["closed"])
  192. def test_connection_lifecycle_management(self):
  193. print("STEP: test_connection_lifecycle_management - hold and expire connection")
  194. pool = self._new_pool(min_size=1, max_size=2, max_lifetime=1.0, idle_timeout=5.0)
  195. stats = {}
  196. try:
  197. conn = pool.borrow()
  198. conn_id = id(conn._entry.conn)
  199. time.sleep(1.2)
  200. conn.close()
  201. print("STEP: test_connection_lifecycle_management - borrow replacement")
  202. new_conn = pool.borrow(timeout=1.0)
  203. new_id = id(new_conn._entry.conn)
  204. new_conn.close()
  205. stats = pool.stats()
  206. finally:
  207. pool.close()
  208. self.assertNotEqual(conn_id, new_id)
  209. self.assertLessEqual(stats["total"], 2)
  210. def test_connection_validation_replaces_bad_connection(self):
  211. print("STEP: test_connection_validation_replaces_bad_connection - return closed conn")
  212. conn = self.pool.borrow()
  213. bad_id = id(conn._entry.conn)
  214. conn._entry.conn.close()
  215. self.pool.return_connection(conn)
  216. print("STEP: test_connection_validation_replaces_bad_connection - borrow and validate")
  217. with self.pool.connection(timeout=1.0) as conn2:
  218. new_id = id(conn2._entry.conn)
  219. rows = conn2.prepare("SELECT 1")()
  220. stats = self.pool.stats()
  221. self.assertEqual(rows[0][0], 1)
  222. self.assertNotEqual(bad_id, new_id)
  223. self.assertLessEqual(stats["total"], self.config.max_size)
  224. def test_long_running_requests(self):
  225. print("STEP: test_long_running_requests - run mixed workload")
  226. errors: List[Exception] = []
  227. lock = threading.Lock()
  228. stop_at = time.monotonic() + 3.0
  229. def load_worker():
  230. while time.monotonic() < stop_at:
  231. try:
  232. with self.pool.connection(timeout=2.0) as conn:
  233. conn.prepare("SELECT 1")()
  234. time.sleep(0.2)
  235. select_stmt = conn.prepare(
  236. "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
  237. )
  238. rows = select_stmt(self.seed_uid)
  239. if rows[0][0] != "seed":
  240. raise AssertionError("Unexpected content")
  241. except Exception as exc: # noqa: BLE001 - capture any worker errors
  242. with lock:
  243. errors.append(exc)
  244. return
  245. threads = [threading.Thread(target=load_worker) for _ in range(4)]
  246. for t in threads:
  247. t.start()
  248. for t in threads:
  249. t.join()
  250. print("STEP: test_long_running_requests - assert pool stats")
  251. stats = self.pool.stats()
  252. self.assertEqual(len(errors), 0)
  253. self.assertFalse(stats["closed"])
  254. self.assertLessEqual(stats["total"], self.config.max_size)
  255. if __name__ == "__main__":
  256. unittest.main()