# test_opengauss_pool.py - integration tests for the openGauss connection pool.
  1. import threading
  2. import time
  3. import unittest
  4. from typing import List
  5. try:
  6. import py_opengauss # type: ignore # noqa: F401
  7. except ImportError: # pragma: no cover - handled via skip
  8. py_opengauss = None
  9. from gdb_utils import ConnectionPoolConfig, OpenGaussConnectionPool
  10. DB_HOST = "127.0.0.1"
  11. DB_PORT = 5432
  12. DB_USER = "gaussdb"
  13. DB_PASSWORD = "Ysl#1234"
  14. DB_NAME = "postgres"
  15. def build_dsn() -> str:
  16. return f"opengauss://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
  17. @unittest.skipIf(py_opengauss is None, "py_opengauss is required for integration tests")
  18. class OpenGaussPoolIntegrationTest(unittest.TestCase):
  19. seed_uid: int
  20. seed_process_id: str
  21. config: ConnectionPoolConfig
  22. pool: OpenGaussConnectionPool
  23. @classmethod
  24. def setUpClass(cls) -> None:
  25. base = int(time.time()) % 1_500_000_000
  26. cls.seed_uid = base + 100
  27. cls.seed_process_id = f"pool_seed_{base}"
  28. conn = py_opengauss.open(build_dsn())
  29. try:
  30. delete_stmt = conn.prepare("DELETE FROM my_schema.table1 WHERE u_uid = $1")
  31. delete_stmt(cls.seed_uid)
  32. insert_stmt = conn.prepare(
  33. """
  34. INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
  35. VALUES ($1, $2, NOW(), $3)
  36. """
  37. )
  38. insert_stmt(cls.seed_uid, cls.seed_process_id, "seed")
  39. finally:
  40. conn.close()
  41. cls.config = ConnectionPoolConfig(
  42. dsn=build_dsn(),
  43. min_size=2,
  44. max_size=8,
  45. idle_timeout=30,
  46. max_lifetime=60,
  47. test_on_borrow=True,
  48. test_sql="SELECT 1",
  49. keepalive=True,
  50. keepalive_interval=0.5,
  51. health_check_interval=0.5,
  52. )
  53. cls.pool = OpenGaussConnectionPool(cls.config)
  54. @classmethod
  55. def tearDownClass(cls) -> None:
  56. if hasattr(cls, "pool") and cls.pool:
  57. cls.pool.close()
  58. conn = py_opengauss.open(build_dsn())
  59. try:
  60. delete_stmt = conn.prepare(
  61. "DELETE FROM my_schema.table1 WHERE process_id = $1"
  62. )
  63. delete_stmt(cls.seed_process_id)
  64. finally:
  65. conn.close()
  66. def _wait_for_min_size(self, timeout: float = 2.0) -> None:
  67. deadline = time.monotonic() + timeout
  68. while time.monotonic() < deadline:
  69. stats = self.pool.stats()
  70. if stats["total"] >= self.config.min_size:
  71. return
  72. time.sleep(0.05)
  73. def test_read_and_write(self):
  74. print("STEP: test_read_and_write - setup table and insert row")
  75. # u_uid,name are VARCHAR(10); keep identifiers short.
  76. unique_uid = f"u{int(time.time()) % 100000:05d}"
  77. # Ensure schema/table exist and insert a row.
  78. with self.pool.connection() as conn:
  79. conn.execute("CREATE SCHEMA IF NOT EXISTS my_schema")
  80. conn.execute(
  81. """
  82. CREATE TABLE IF NOT EXISTS my_schema.test_table (
  83. u_uid VARCHAR(10),
  84. name VARCHAR(10),
  85. CONSTRAINT pk_uid PRIMARY KEY(u_uid)
  86. )
  87. """
  88. )
  89. # Ensure no leftover row with same key before insert.
  90. delete_before_insert = conn.prepare("DELETE FROM my_schema.test_table WHERE u_uid = $1")
  91. delete_before_insert(unique_uid)
  92. insert_stmt = conn.prepare(
  93. "INSERT INTO my_schema.test_table (u_uid, name) VALUES ($1, $2)"
  94. )
  95. insert_stmt(unique_uid, "codex")
  96. print("STEP: test_read_and_write - read row and assert")
  97. # Read back the row.
  98. with self.pool.connection() as conn:
  99. select_stmt = conn.prepare("SELECT name FROM my_schema.test_table WHERE u_uid = $1")
  100. rows = select_stmt(unique_uid)
  101. self.assertIsNotNone(rows)
  102. self.assertGreaterEqual(len(rows), 1)
  103. self.assertEqual(rows[0][0], "codex")
  104. print("STEP: test_read_and_write - cleanup inserted row")
  105. # Clean up the inserted row.
  106. with self.pool.connection() as conn:
  107. delete_stmt = conn.prepare("DELETE FROM my_schema.test_table WHERE u_uid = $1")
  108. delete_stmt(unique_uid)
  109. def test_broken_connection_is_replaced(self):
  110. print("STEP: test_broken_connection_is_replaced - close and return broken connection")
  111. conn = self.pool.borrow()
  112. # Simulate an unusable connection by closing it and marking as broken.
  113. conn.close()
  114. self.pool.return_connection(conn, had_error=True)
  115. print("STEP: test_broken_connection_is_replaced - borrow replacement and validate")
  116. # Borrow again; pool should create a new usable connection.
  117. with self.pool.connection() as conn2:
  118. stmt = conn2.prepare("SELECT 1")
  119. rows = stmt()
  120. self.assertEqual(rows[0][0], 1)
  121. self._wait_for_min_size()
  122. stats_after = self.pool.stats()
  123. self.assertFalse(stats_after["closed"])
  124. self.assertLessEqual(stats_after["total"], self.config.max_size)
  125. self.assertGreaterEqual(stats_after["total"], self.config.min_size)
  126. def test_pool_limit_and_timeout(self):
  127. print("STEP: test_pool_limit_and_timeout - exhaust pool")
  128. borrowed = [self.pool.borrow() for _ in range(self.config.max_size)]
  129. errors: List[Exception] = []
  130. print("STEP: test_pool_limit_and_timeout - attempt timed borrow")
  131. def try_borrow():
  132. try:
  133. self.pool.borrow(timeout=0.5)
  134. except Exception as exc: # noqa: BLE001 - capture TimeoutError
  135. errors.append(exc)
  136. t = threading.Thread(target=try_borrow)
  137. t.start()
  138. t.join()
  139. print("STEP: test_pool_limit_and_timeout - return connections and assert stats")
  140. for conn in borrowed:
  141. self.pool.return_connection(conn)
  142. stats = self.pool.stats()
  143. self.assertLessEqual(stats["total"], self.config.max_size)
  144. self.assertEqual(stats["in_use"], 0)
  145. self.assertGreaterEqual(stats["available"], self.config.min_size)
  146. self.assertTrue(any(isinstance(err, TimeoutError) for err in errors))
  147. def test_serial_reads_with_lifecycle(self):
  148. print("STEP: test_serial_reads_with_lifecycle - serial reads with pauses")
  149. for _ in range(3):
  150. with self.pool.connection() as conn:
  151. select_stmt = conn.prepare(
  152. "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
  153. )
  154. rows = select_stmt(self.seed_uid)
  155. self.assertEqual(rows[0][0], "seed")
  156. time.sleep(0.5)
  157. print("STEP: test_serial_reads_with_lifecycle - assert pool stats")
  158. stats = self.pool.stats()
  159. self.assertLessEqual(stats["total"], self.config.max_size)
  160. self.assertFalse(stats["closed"])
  161. def test_concurrent_reads(self):
  162. print("STEP: test_concurrent_reads - start threads")
  163. results: List[str] = []
  164. errors: List[Exception] = []
  165. lock = threading.Lock()
  166. def worker():
  167. try:
  168. with self.pool.connection(timeout=2.0) as conn:
  169. select_stmt = conn.prepare(
  170. "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
  171. )
  172. rows = select_stmt(self.seed_uid)
  173. with lock:
  174. results.append(rows[0][0])
  175. except Exception as exc: # noqa: BLE001 - capture any worker errors
  176. with lock:
  177. errors.append(exc)
  178. threads = [threading.Thread(target=worker) for _ in range(20)]
  179. for t in threads:
  180. t.start()
  181. for t in threads:
  182. t.join()
  183. print("STEP: test_concurrent_reads - assert results and pool stats")
  184. stats = self.pool.stats()
  185. self.assertEqual(len(errors), 0)
  186. self.assertEqual(len(results), 20)
  187. self.assertFalse(stats["closed"])
  188. self.assertLessEqual(stats["total"], self.config.max_size)
  189. def test_concurrent_batch_writes(self):
  190. print("STEP: test_concurrent_batch_writes - start writer threads")
  191. base = self.seed_uid + 10000
  192. process_id = f"pool_batch_{base}"
  193. total_batches = 5
  194. batch_size = 20
  195. worker_count = 8
  196. errors: List[Exception] = []
  197. lock = threading.Lock()
  198. def writer(worker_idx: int):
  199. try:
  200. for batch_idx in range(total_batches):
  201. with self.pool.connection(timeout=2.0) as conn:
  202. insert_stmt = conn.prepare(
  203. """
  204. INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
  205. VALUES ($1, $2, NOW(), $3)
  206. """
  207. )
  208. for i in range(batch_size):
  209. uid = base + (worker_idx * 1000) + (batch_idx * batch_size) + i
  210. insert_stmt(uid, process_id, f"batch_{batch_idx}")
  211. except Exception as exc: # noqa: BLE001 - capture any worker errors
  212. with lock:
  213. errors.append(exc)
  214. threads = [threading.Thread(target=writer, args=(i,)) for i in range(worker_count)]
  215. for t in threads:
  216. t.start()
  217. for t in threads:
  218. t.join()
  219. print("STEP: test_concurrent_batch_writes - validate count and cleanup")
  220. with self.pool.connection(timeout=2.0) as conn:
  221. count_stmt = conn.prepare(
  222. "SELECT COUNT(*) FROM my_schema.table1 WHERE process_id = $1"
  223. )
  224. rows = count_stmt(process_id)
  225. total_inserted = rows[0][0]
  226. delete_stmt = conn.prepare(
  227. "DELETE FROM my_schema.table1 WHERE process_id = $1"
  228. )
  229. delete_stmt(process_id)
  230. stats = self.pool.stats()
  231. self.assertEqual(len(errors), 0)
  232. self.assertEqual(total_inserted, worker_count * total_batches * batch_size)
  233. self.assertFalse(stats["closed"])
  234. def test_connection_lifecycle_management(self):
  235. print("STEP: test_connection_lifecycle_management - borrow and age connection")
  236. conn = self.pool.borrow()
  237. conn.prepare("SELECT 1")()
  238. time.sleep(self.config.max_lifetime + 0.2)
  239. self.pool.return_connection(conn)
  240. print("STEP: test_connection_lifecycle_management - borrow new connection")
  241. new_conn = self.pool.borrow(timeout=1.0)
  242. self.pool.return_connection(new_conn)
  243. stats = self.pool.stats()
  244. self.assertNotEqual(id(conn), id(new_conn))
  245. self.assertLessEqual(stats["total"], self.config.max_size)
  246. def test_connection_validation_replaces_bad_connection(self):
  247. print("STEP: test_connection_validation_replaces_bad_connection - return closed conn")
  248. conn = self.pool.borrow()
  249. conn.close()
  250. self.pool.return_connection(conn, had_error=False)
  251. print("STEP: test_connection_validation_replaces_bad_connection - borrow and validate")
  252. with self.pool.connection(timeout=1.0) as conn2:
  253. rows = conn2.prepare("SELECT 1")()
  254. stats = self.pool.stats()
  255. self.assertEqual(rows[0][0], 1)
  256. self.assertLessEqual(stats["total"], self.config.max_size)
  257. def test_long_running_requests(self):
  258. print("STEP: test_long_running_requests - run mixed workload")
  259. errors: List[Exception] = []
  260. lock = threading.Lock()
  261. stop_at = time.monotonic() + 3.0
  262. def load_worker():
  263. while time.monotonic() < stop_at:
  264. try:
  265. with self.pool.connection(timeout=2.0) as conn:
  266. conn.prepare("SELECT 1")()
  267. time.sleep(0.2)
  268. select_stmt = conn.prepare(
  269. "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
  270. )
  271. rows = select_stmt(self.seed_uid)
  272. if rows[0][0] != "seed":
  273. raise AssertionError("Unexpected content")
  274. except Exception as exc: # noqa: BLE001 - capture any worker errors
  275. with lock:
  276. errors.append(exc)
  277. return
  278. threads = [threading.Thread(target=load_worker) for _ in range(4)]
  279. for t in threads:
  280. t.start()
  281. for t in threads:
  282. t.join()
  283. print("STEP: test_long_running_requests - assert pool stats")
  284. stats = self.pool.stats()
  285. self.assertEqual(len(errors), 0)
  286. self.assertFalse(stats["closed"])
  287. self.assertLessEqual(stats["total"], self.config.max_size)
# Run the tests through unittest's command-line runner when this file is
# executed directly.
if __name__ == "__main__":
    unittest.main()