"""Integration tests for ``OpenGaussConnectionPool`` against a live openGauss server.

The whole test case is skipped when the ``py_opengauss`` driver is not
installed. Connection settings default to a local development instance and
can be overridden with the ``OPENGAUSS_HOST``, ``OPENGAUSS_PORT``,
``OPENGAUSS_USER``, ``OPENGAUSS_PASSWORD`` and ``OPENGAUSS_DB`` environment
variables.

The tests read and write ``my_schema.table1`` (columns ``u_uid``,
``process_id``, ``ins_tm``, ``content``), which must already exist;
``my_schema.test_table`` is created on demand by ``test_read_and_write``.
"""

import os
import threading
import time
import unittest
from typing import List
from urllib.parse import quote

try:
    import py_opengauss  # type: ignore # noqa: F401
except ImportError:  # pragma: no cover - handled via skip
    py_opengauss = None

from gdb_utils import ConnectionPoolConfig, OpenGaussConnectionPool

# Connection settings. Environment variables take precedence so the suite can
# be pointed at another server without editing credentials in source.
DB_HOST = os.environ.get("OPENGAUSS_HOST", "127.0.0.1")
DB_PORT = int(os.environ.get("OPENGAUSS_PORT", "5432"))
DB_USER = os.environ.get("OPENGAUSS_USER", "gaussdb")
DB_PASSWORD = os.environ.get("OPENGAUSS_PASSWORD", "Ysl#1234")
DB_NAME = os.environ.get("OPENGAUSS_DB", "postgres")


def build_dsn() -> str:
    """Return the ``opengauss://`` connection IRI for the configured database.

    The user name and password are percent-encoded. The default password
    contains ``#``, which a URI parser treats as the start of the fragment,
    truncating the credentials and dropping host/port/database.
    NOTE(review): relies on the driver percent-decoding the userinfo part
    (as py-postgresql-style IRI parsing does) - confirm for py_opengauss.
    """
    user = quote(DB_USER, safe="")
    password = quote(DB_PASSWORD, safe="")
    return f"opengauss://{user}:{password}@{DB_HOST}:{DB_PORT}/{DB_NAME}"


@unittest.skipIf(py_opengauss is None, "py_opengauss is required for integration tests")
class OpenGaussPoolIntegrationTest(unittest.TestCase):
    """End-to-end checks of pool borrowing, replacement, limits and concurrency."""

    seed_uid: int
    seed_process_id: str
    config: ConnectionPoolConfig
    pool: OpenGaussConnectionPool

    @classmethod
    def setUpClass(cls) -> None:
        """Insert a known seed row into ``my_schema.table1`` and create the shared pool."""
        # Derive per-run identifiers from the clock so reruns do not collide.
        base = int(time.time()) % 1_500_000_000
        cls.seed_uid = base + 100
        cls.seed_process_id = f"pool_seed_{base}"

        # Seed through a direct driver connection so setup does not depend on
        # the pool under test.
        conn = py_opengauss.open(build_dsn())
        try:
            delete_stmt = conn.prepare("DELETE FROM my_schema.table1 WHERE u_uid = $1")
            delete_stmt(cls.seed_uid)
            insert_stmt = conn.prepare(
                """
                INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
                VALUES ($1, $2, NOW(), $3)
                """
            )
            insert_stmt(cls.seed_uid, cls.seed_process_id, "seed")
        finally:
            conn.close()

        cls.config = ConnectionPoolConfig(
            dsn=build_dsn(),
            min_size=2,
            max_size=8,
            idle_timeout=30,
            max_lifetime=60,
            test_on_borrow=True,
            test_sql="SELECT 1",
            keepalive=True,
            keepalive_interval=0.5,
            health_check_interval=0.5,
        )
        cls.pool = OpenGaussConnectionPool(cls.config)

    @classmethod
    def tearDownClass(cls) -> None:
        """Close the shared pool, then delete the seed row(s) written by ``setUpClass``."""
        try:
            # Compare against None rather than relying on truthiness, which a
            # pool object could define (e.g. via __len__).
            if getattr(cls, "pool", None) is not None:
                cls.pool.close()
        finally:
            # Remove the seed data even if closing the pool failed.
            conn = py_opengauss.open(build_dsn())
            try:
                delete_stmt = conn.prepare(
                    "DELETE FROM my_schema.table1 WHERE process_id = $1"
                )
                delete_stmt(cls.seed_process_id)
            finally:
                conn.close()

    def _wait_for_min_size(self, timeout: float = 2.0) -> None:
        """Poll ``pool.stats()`` until ``total`` reaches ``min_size`` or *timeout* expires.

        Returns silently on timeout; callers assert on the stats afterwards.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            stats = self.pool.stats()
            if stats["total"] >= self.config.min_size:
                return
            time.sleep(0.05)

    def test_read_and_write(self):
        """Create ``my_schema.test_table`` if needed, then write, read back and delete a row."""
        print("STEP: test_read_and_write - setup table and insert row")
        # u_uid and name are VARCHAR(10); keep identifiers short.
        unique_uid = f"u{int(time.time()) % 100000:05d}"

        # Ensure schema/table exist and insert a row.
        with self.pool.connection() as conn:
            conn.execute("CREATE SCHEMA IF NOT EXISTS my_schema")
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS my_schema.test_table (
                    u_uid VARCHAR(10),
                    name VARCHAR(10),
                    CONSTRAINT pk_uid PRIMARY KEY(u_uid)
                )
                """
            )
            # Remove any leftover row with the same key before inserting.
            delete_before_insert = conn.prepare(
                "DELETE FROM my_schema.test_table WHERE u_uid = $1"
            )
            delete_before_insert(unique_uid)
            insert_stmt = conn.prepare(
                "INSERT INTO my_schema.test_table (u_uid, name) VALUES ($1, $2)"
            )
            insert_stmt(unique_uid, "codex")

        try:
            print("STEP: test_read_and_write - read row and assert")
            with self.pool.connection() as conn:
                select_stmt = conn.prepare(
                    "SELECT name FROM my_schema.test_table WHERE u_uid = $1"
                )
                rows = select_stmt(unique_uid)
                self.assertIsNotNone(rows)
                self.assertGreaterEqual(len(rows), 1)
                self.assertEqual(rows[0][0], "codex")
        finally:
            print("STEP: test_read_and_write - cleanup inserted row")
            # Clean up even when an assertion above fails.
            with self.pool.connection() as conn:
                delete_stmt = conn.prepare(
                    "DELETE FROM my_schema.test_table WHERE u_uid = $1"
                )
                delete_stmt(unique_uid)

    def test_broken_connection_is_replaced(self):
        """A closed connection returned with ``had_error=True`` must not break later borrows."""
        print("STEP: test_broken_connection_is_replaced - close and return broken connection")
        conn = self.pool.borrow()
        # Simulate an unusable connection by closing it and marking it as broken.
        conn.close()
        self.pool.return_connection(conn, had_error=True)

        print("STEP: test_broken_connection_is_replaced - borrow replacement and validate")
        # Borrow again; the pool should hand out a usable connection.
        with self.pool.connection() as conn2:
            stmt = conn2.prepare("SELECT 1")
            rows = stmt()
            self.assertEqual(rows[0][0], 1)

        self._wait_for_min_size()
        stats_after = self.pool.stats()
        self.assertFalse(stats_after["closed"])
        self.assertLessEqual(stats_after["total"], self.config.max_size)
        self.assertGreaterEqual(stats_after["total"], self.config.min_size)

    def test_pool_limit_and_timeout(self):
        """Borrowing beyond ``max_size`` must time out; all connections are then returned."""
        print("STEP: test_pool_limit_and_timeout - exhaust pool")
        borrowed = []  # connections currently checked out by this test
        errors: List[Exception] = []
        try:
            for _ in range(self.config.max_size):
                borrowed.append(self.pool.borrow())

            print("STEP: test_pool_limit_and_timeout - attempt timed borrow")

            def try_borrow():
                try:
                    extra = self.pool.borrow(timeout=0.5)
                except Exception as exc:  # noqa: BLE001 - capture TimeoutError
                    errors.append(exc)
                else:
                    # Unexpected success: give the connection back so the
                    # shared pool is not left short for later tests.
                    self.pool.return_connection(extra)

            t = threading.Thread(target=try_borrow)
            t.start()
            t.join()
        finally:
            print("STEP: test_pool_limit_and_timeout - return connections and assert stats")
            # Always release what was borrowed, even if borrowing failed
            # partway, otherwise every later test sees an exhausted pool.
            for conn in borrowed:
                self.pool.return_connection(conn)

        stats = self.pool.stats()
        self.assertLessEqual(stats["total"], self.config.max_size)
        self.assertEqual(stats["in_use"], 0)
        self.assertGreaterEqual(stats["available"], self.config.min_size)
        self.assertTrue(any(isinstance(err, TimeoutError) for err in errors))

    def test_serial_reads_with_lifecycle(self):
        """Repeated borrow/read/return cycles with pauses keep the pool open and bounded."""
        print("STEP: test_serial_reads_with_lifecycle - serial reads with pauses")
        for _ in range(3):
            with self.pool.connection() as conn:
                select_stmt = conn.prepare(
                    "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                )
                rows = select_stmt(self.seed_uid)
                self.assertEqual(rows[0][0], "seed")
            time.sleep(0.5)

        print("STEP: test_serial_reads_with_lifecycle - assert pool stats")
        stats = self.pool.stats()
        self.assertLessEqual(stats["total"], self.config.max_size)
        self.assertFalse(stats["closed"])

    def test_concurrent_reads(self):
        """Twenty threads reading the seed row concurrently must all succeed."""
        print("STEP: test_concurrent_reads - start threads")
        results: List[str] = []
        errors: List[Exception] = []
        lock = threading.Lock()

        def worker():
            try:
                with self.pool.connection(timeout=2.0) as conn:
                    select_stmt = conn.prepare(
                        "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                    )
                    rows = select_stmt(self.seed_uid)
                    with lock:
                        results.append(rows[0][0])
            except Exception as exc:  # noqa: BLE001 - capture any worker errors
                with lock:
                    errors.append(exc)

        threads = [threading.Thread(target=worker) for _ in range(20)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        print("STEP: test_concurrent_reads - assert results and pool stats")
        stats = self.pool.stats()
        self.assertEqual(len(errors), 0)
        self.assertEqual(len(results), 20)
        self.assertFalse(stats["closed"])
        self.assertLessEqual(stats["total"], self.config.max_size)

    def test_concurrent_batch_writes(self):
        """Several writer threads insert batches; the total row count must match exactly."""
        print("STEP: test_concurrent_batch_writes - start writer threads")
        base = self.seed_uid + 10000
        process_id = f"pool_batch_{base}"
        total_batches = 5
        batch_size = 20
        worker_count = 8
        errors: List[Exception] = []
        lock = threading.Lock()

        def writer(worker_idx: int):
            try:
                for batch_idx in range(total_batches):
                    with self.pool.connection(timeout=2.0) as conn:
                        insert_stmt = conn.prepare(
                            """
                            INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
                            VALUES ($1, $2, NOW(), $3)
                            """
                        )
                        for i in range(batch_size):
                            # Disjoint uid ranges per worker (batch_idx * batch_size + i < 1000).
                            uid = base + (worker_idx * 1000) + (batch_idx * batch_size) + i
                            insert_stmt(uid, process_id, f"batch_{batch_idx}")
            except Exception as exc:  # noqa: BLE001 - capture any worker errors
                with lock:
                    errors.append(exc)

        try:
            threads = [threading.Thread(target=writer, args=(i,)) for i in range(worker_count)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            print("STEP: test_concurrent_batch_writes - validate count and cleanup")
            with self.pool.connection(timeout=2.0) as conn:
                count_stmt = conn.prepare(
                    "SELECT COUNT(*) FROM my_schema.table1 WHERE process_id = $1"
                )
                rows = count_stmt(process_id)
                total_inserted = rows[0][0]
        finally:
            # Remove the batch rows even if a writer or the count query failed.
            with self.pool.connection(timeout=2.0) as conn:
                delete_stmt = conn.prepare(
                    "DELETE FROM my_schema.table1 WHERE process_id = $1"
                )
                delete_stmt(process_id)

        stats = self.pool.stats()
        self.assertEqual(len(errors), 0)
        self.assertEqual(total_inserted, worker_count * total_batches * batch_size)
        self.assertFalse(stats["closed"])

    def test_connection_lifecycle_management(self):
        """A connection older than ``max_lifetime`` must not be handed out again.

        Uses a dedicated single-connection pool with a short lifetime. This
        avoids sleeping for the shared pool's full ``max_lifetime`` (60 s).
        With ``max_size=1`` the next borrow cannot be satisfied by some other
        idle connection, so the assertion actually exercises recycling.
        """
        print("STEP: test_connection_lifecycle_management - borrow and age connection")
        short_config = ConnectionPoolConfig(
            dsn=build_dsn(),
            min_size=1,
            max_size=1,
            idle_timeout=1,
            max_lifetime=2,
            test_on_borrow=True,
            test_sql="SELECT 1",
            keepalive=True,
            keepalive_interval=0.5,
            health_check_interval=0.5,
        )
        pool = OpenGaussConnectionPool(short_config)
        try:
            conn = pool.borrow()
            conn.prepare("SELECT 1")()
            time.sleep(short_config.max_lifetime + 0.2)
            pool.return_connection(conn)

            print("STEP: test_connection_lifecycle_management - borrow new connection")
            new_conn = pool.borrow(timeout=1.0)
            pool.return_connection(new_conn)
            stats = pool.stats()
        finally:
            pool.close()

        # Both objects are still referenced, so identity comparison is reliable.
        self.assertIsNot(conn, new_conn)
        self.assertLessEqual(stats["total"], short_config.max_size)

    def test_connection_validation_replaces_bad_connection(self):
        """A closed connection returned without an error flag must be caught by validation."""
        print("STEP: test_connection_validation_replaces_bad_connection - return closed conn")
        conn = self.pool.borrow()
        conn.close()
        self.pool.return_connection(conn, had_error=False)

        print("STEP: test_connection_validation_replaces_bad_connection - borrow and validate")
        with self.pool.connection(timeout=1.0) as conn2:
            rows = conn2.prepare("SELECT 1")()
        stats = self.pool.stats()
        self.assertEqual(rows[0][0], 1)
        self.assertLessEqual(stats["total"], self.config.max_size)

    def test_long_running_requests(self):
        """Four threads hold connections for ~0.2 s per request for about 3 s without errors."""
        print("STEP: test_long_running_requests - run mixed workload")
        errors: List[Exception] = []
        lock = threading.Lock()
        stop_at = time.monotonic() + 3.0

        def load_worker():
            while time.monotonic() < stop_at:
                try:
                    with self.pool.connection(timeout=2.0) as conn:
                        conn.prepare("SELECT 1")()
                        time.sleep(0.2)
                        select_stmt = conn.prepare(
                            "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                        )
                        rows = select_stmt(self.seed_uid)
                        if rows[0][0] != "seed":
                            raise AssertionError("Unexpected content")
                except Exception as exc:  # noqa: BLE001 - capture any worker errors
                    with lock:
                        errors.append(exc)
                    # Stop this worker after its first failure.
                    return

        threads = [threading.Thread(target=load_worker) for _ in range(4)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        print("STEP: test_long_running_requests - assert pool stats")
        stats = self.pool.stats()
        self.assertEqual(len(errors), 0)
        self.assertFalse(stats["closed"])
        self.assertLessEqual(stats["total"], self.config.max_size)


if __name__ == "__main__":
    unittest.main()