"""Integration tests for the hardened openGauss connection pool.

These tests talk to a live openGauss server described by the ``DB_*``
settings below; each setting can be overridden through the matching
``OPENGAUSS_*`` environment variable. The tests read and write
``my_schema.table1`` using the columns ``u_uid``, ``process_id``, ``ins_tm``
and ``content``. The whole test case is skipped when the ``py_opengauss``
driver is not installed.
"""

import os
import threading
import time
import unittest
from typing import List
from urllib.parse import quote

try:
    import py_opengauss  # type: ignore  # noqa: F401
except ImportError:  # pragma: no cover - handled via skip
    py_opengauss = None

from gdb_utils.opengauss_pool_hardened import (
    ConnectionPoolConfig,
    OpenGaussConnectionPool,
    PoolExhaustedError,
)

# Connection settings. The defaults are the original hard-coded values; real
# credentials should be supplied through the environment rather than edited
# into this file.
DB_HOST = os.environ.get("OPENGAUSS_HOST", "127.0.0.1")
DB_PORT = int(os.environ.get("OPENGAUSS_PORT", "5432"))
DB_USER = os.environ.get("OPENGAUSS_USER", "gaussdb")
DB_PASSWORD = os.environ.get("OPENGAUSS_PASSWORD", "Ysl#1234")
DB_NAME = os.environ.get("OPENGAUSS_DBNAME", "postgres")


def build_dsn() -> str:
    """Return the ``opengauss://`` DSN for the test database.

    The user name and password are percent-encoded: characters such as
    ``#`` (present in the default password), ``@``, ``/``, ``?`` and ``%``
    are URL delimiters and would otherwise truncate or corrupt the netloc.
    NOTE(review): relies on py_opengauss decoding percent-escapes in the
    userinfo part of the IRI, as py-postgresql does.
    """
    user = quote(DB_USER, safe="")
    password = quote(DB_PASSWORD, safe="")
    return f"opengauss://{user}:{password}@{DB_HOST}:{DB_PORT}/{DB_NAME}"


@unittest.skipIf(py_opengauss is None, "py_opengauss is required for integration tests")
class OpenGaussHardenedPoolIntegrationTest(unittest.TestCase):
    """End-to-end checks of ``OpenGaussConnectionPool`` against a live server.

    ``setUpClass`` inserts one seed row that the read tests query back and
    builds a class-wide pool from ``config``. Lifecycle tests create their own
    short-lived pools through ``_new_pool`` and close them themselves.
    """

    seed_uid: int
    seed_process_id: str
    config: ConnectionPoolConfig
    pool: OpenGaussConnectionPool

    @classmethod
    def setUpClass(cls) -> None:
        """Insert the seed row and create the shared pool."""
        # Per-run identifiers derived from the clock, so consecutive runs use
        # different seed rows.
        base = int(time.time()) % 1_500_000_000
        cls.seed_uid = base + 200
        cls.seed_process_id = f"hardened_seed_{base}"

        conn = py_opengauss.open(build_dsn())
        try:
            delete_stmt = conn.prepare("DELETE FROM my_schema.table1 WHERE u_uid = $1")
            delete_stmt(cls.seed_uid)
            insert_stmt = conn.prepare(
                """
                INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
                VALUES ($1, $2, NOW(), $3)
                """
            )
            insert_stmt(cls.seed_uid, cls.seed_process_id, "seed")
        finally:
            conn.close()

        try:
            cls.config = ConnectionPoolConfig(
                dsn=build_dsn(),
                min_size=2,
                max_size=6,
                blocking=True,
                acquire_timeout=0.5,
                idle_timeout=10,
                max_lifetime=30,
                max_usage=20,
                test_on_borrow=True,
                test_sql="SELECT 1",
                keepalive=True,
                keepalive_interval=0.5,
                health_check_interval=0.5,
                reset_on_return=True,
            )
            cls.pool = OpenGaussConnectionPool(cls.config)
        except Exception:
            # unittest does not run tearDownClass when setUpClass raises, so
            # remove the seed row here before propagating the failure.
            cls._delete_seed_rows()
            raise

    @classmethod
    def tearDownClass(cls) -> None:
        """Close the shared pool, then delete the seed row."""
        try:
            pool = getattr(cls, "pool", None)
            # Explicit None check: a pool type defining __len__ could be falsy
            # while it still holds connections, which would skip close().
            if pool is not None:
                pool.close()
        finally:
            # Remove the seed row even if closing the pool failed.
            cls._delete_seed_rows()

    @classmethod
    def _delete_seed_rows(cls) -> None:
        """Delete every row tagged with this run's seed ``process_id``."""
        conn = py_opengauss.open(build_dsn())
        try:
            delete_stmt = conn.prepare(
                "DELETE FROM my_schema.table1 WHERE process_id = $1"
            )
            delete_stmt(cls.seed_process_id)
        finally:
            conn.close()

    def _new_pool(self, **overrides) -> OpenGaussConnectionPool:
        """Build a fresh pool from the shared config with ``overrides`` applied.

        The caller owns the returned pool and is responsible for closing it.
        """
        config_dict = self.config.__dict__.copy()
        config_dict.update(overrides)
        return OpenGaussConnectionPool(ConnectionPoolConfig(**config_dict))

    def test_pool_limit_and_timeout(self):
        """Borrowing beyond ``max_size`` times out with ``PoolExhaustedError``."""
        print("STEP: test_pool_limit_and_timeout - exhaust pool")
        borrowed = []
        try:
            for _ in range(self.config.max_size):
                borrowed.append(self.pool.borrow())
            print("STEP: test_pool_limit_and_timeout - timed borrow raises")
            with self.assertRaises(PoolExhaustedError):
                # Appended only if the borrow unexpectedly succeeds, so the
                # cleanup below hands that connection back too.
                borrowed.append(self.pool.borrow(timeout=0.3))
        finally:
            # Always return the checkouts: the pool is shared by the whole
            # class, and leaked connections would starve the later tests.
            print("STEP: test_pool_limit_and_timeout - return connections and assert stats")
            for conn in borrowed:
                conn.close()
        stats = self.pool.stats()
        self.assertLessEqual(stats["total"], self.config.max_size)
        self.assertEqual(stats["in_use"], 0)
        self.assertGreaterEqual(stats["available"], self.config.min_size)

    def test_serial_reads_with_lifecycle(self):
        """Serial reads keep succeeding while ``max_lifetime`` retires connections."""
        print("STEP: test_serial_reads_with_lifecycle - serial reads with pauses")
        pool = self._new_pool(min_size=1, max_size=3, max_lifetime=1.0, idle_timeout=5.0)
        # Keep the raw connection objects alive rather than storing id()s:
        # once an expired connection is collected, CPython may give its id to
        # the replacement, which would hide the rotation being checked.
        seen_conns = []
        try:
            stats_before = pool.stats()
            for idx in range(4):
                with pool.connection(timeout=1.0) as conn:
                    seen_conns.append(conn._entry.conn)
                    select_stmt = conn.prepare(
                        "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                    )
                    rows = select_stmt(self.seed_uid)
                    self.assertEqual(rows[0][0], "seed")
                # One pause longer than max_lifetime (1.0s) forces an expiry.
                time.sleep(1.2 if idx == 1 else 0.2)
            stats = pool.stats()
        finally:
            pool.close()
        print("STEP: test_serial_reads_with_lifecycle - assert pool stats")
        distinct_conns = len({id(raw) for raw in seen_conns})
        self.assertLessEqual(stats["total"], 3)
        self.assertFalse(stats["closed"])
        self.assertTrue(
            distinct_conns >= 2 or stats["discarded"] > stats_before["discarded"]
        )

    def test_concurrent_reads(self):
        """Many threads read the seed row through the shared pool."""
        print("STEP: test_concurrent_reads - start threads")
        worker_count = 20
        results: List[str] = []
        errors: List[Exception] = []
        lock = threading.Lock()

        def worker():
            try:
                with self.pool.connection(timeout=2.0) as conn:
                    select_stmt = conn.prepare(
                        "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                    )
                    rows = select_stmt(self.seed_uid)
                with lock:
                    results.append(rows[0][0])
            except Exception as exc:  # noqa: BLE001 - capture any worker errors
                with lock:
                    errors.append(exc)

        threads = [threading.Thread(target=worker) for _ in range(worker_count)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        print("STEP: test_concurrent_reads - assert results and pool stats")
        stats = self.pool.stats()
        # Comparing whole lists makes a failure show the actual exceptions.
        self.assertEqual(errors, [])
        self.assertEqual(results, ["seed"] * worker_count)
        self.assertFalse(stats["closed"])
        self.assertLessEqual(stats["total"], self.config.max_size)

    def test_concurrent_batch_writes(self):
        """Concurrent writers insert every row exactly once through the pool."""
        print("STEP: test_concurrent_batch_writes - start writer threads")
        base = self.seed_uid + 20000
        process_id = f"hardened_batch_{base}"
        total_batches = 5
        batch_size = 20
        worker_count = 8
        errors: List[Exception] = []
        lock = threading.Lock()

        def writer(worker_idx: int):
            try:
                for batch_idx in range(total_batches):
                    with self.pool.connection(timeout=2.0) as conn:
                        insert_stmt = conn.prepare(
                            """
                            INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
                            VALUES ($1, $2, NOW(), $3)
                            """
                        )
                        for i in range(batch_size):
                            # Each worker owns a 1000-wide uid range, enough
                            # for total_batches * batch_size (100) rows.
                            uid = base + (worker_idx * 1000) + (batch_idx * batch_size) + i
                            insert_stmt(uid, process_id, f"batch_{batch_idx}")
            except Exception as exc:  # noqa: BLE001 - capture any worker errors
                with lock:
                    errors.append(exc)

        threads = [threading.Thread(target=writer, args=(i,)) for i in range(worker_count)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        print("STEP: test_concurrent_batch_writes - validate count and cleanup")
        with self.pool.connection(timeout=2.0) as conn:
            try:
                count_stmt = conn.prepare(
                    "SELECT COUNT(*) FROM my_schema.table1 WHERE process_id = $1"
                )
                total_inserted = count_stmt(process_id)[0][0]
            finally:
                # Delete this run's rows even if counting failed; tearDownClass
                # only removes the seed row.
                delete_stmt = conn.prepare(
                    "DELETE FROM my_schema.table1 WHERE process_id = $1"
                )
                delete_stmt(process_id)
        stats = self.pool.stats()
        self.assertEqual(errors, [])
        self.assertEqual(total_inserted, worker_count * total_batches * batch_size)
        self.assertFalse(stats["closed"])

    def test_connection_lifecycle_management(self):
        """A connection held past ``max_lifetime`` is replaced on the next borrow."""
        print("STEP: test_connection_lifecycle_management - hold and expire connection")
        pool = self._new_pool(min_size=1, max_size=2, max_lifetime=1.0, idle_timeout=5.0)
        try:
            conn = pool.borrow()
            # Keep a reference to the raw connection, not just its id(): once
            # it is discarded and collected, its id may be reused by the
            # replacement and the identity check below would misfire.
            old_raw = conn._entry.conn
            time.sleep(1.2)
            conn.close()
            print("STEP: test_connection_lifecycle_management - borrow replacement")
            new_conn = pool.borrow(timeout=1.0)
            new_raw = new_conn._entry.conn
            new_conn.close()
            stats = pool.stats()
        finally:
            pool.close()
        self.assertIsNot(old_raw, new_raw)
        self.assertLessEqual(stats["total"], 2)

    def test_connection_validation_replaces_bad_connection(self):
        """A connection returned already closed is never handed out again."""
        print("STEP: test_connection_validation_replaces_bad_connection - return closed conn")
        conn = self.pool.borrow()
        # Hold the object itself so its id cannot be recycled for the
        # replacement connection after the pool discards it.
        bad_raw = conn._entry.conn
        bad_raw.close()
        self.pool.return_connection(conn)
        print("STEP: test_connection_validation_replaces_bad_connection - borrow and validate")
        with self.pool.connection(timeout=1.0) as conn2:
            new_raw = conn2._entry.conn
            rows = conn2.prepare("SELECT 1")()
        stats = self.pool.stats()
        self.assertEqual(rows[0][0], 1)
        self.assertIsNot(bad_raw, new_raw)
        self.assertLessEqual(stats["total"], self.config.max_size)

    def test_long_running_requests(self):
        """A few threads run a mixed workload for about three seconds."""
        print("STEP: test_long_running_requests - run mixed workload")
        errors: List[Exception] = []
        lock = threading.Lock()
        stop_at = time.monotonic() + 3.0

        def load_worker():
            while time.monotonic() < stop_at:
                try:
                    with self.pool.connection(timeout=2.0) as conn:
                        conn.prepare("SELECT 1")()
                        # Hold the connection briefly to keep checkouts overlapping.
                        time.sleep(0.2)
                        select_stmt = conn.prepare(
                            "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                        )
                        rows = select_stmt(self.seed_uid)
                        if rows[0][0] != "seed":
                            raise AssertionError("Unexpected content")
                except Exception as exc:  # noqa: BLE001 - capture any worker errors
                    with lock:
                        errors.append(exc)
                    return

        threads = [threading.Thread(target=load_worker) for _ in range(4)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        print("STEP: test_long_running_requests - assert pool stats")
        stats = self.pool.stats()
        self.assertEqual(errors, [])
        self.assertFalse(stats["closed"])
        self.assertLessEqual(stats["total"], self.config.max_size)


if __name__ == "__main__":
    unittest.main()