| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298 |
import os
import threading
import time
import unittest
from typing import List, Optional
from urllib.parse import quote

try:
    import py_opengauss  # type: ignore # noqa: F401
except ImportError:  # pragma: no cover - handled via skip
    py_opengauss = None

from gdb_utils.opengauss_pool_hardened import (
    ConnectionPoolConfig,
    OpenGaussConnectionPool,
    PoolExhaustedError,
)
# Connection settings for the integration database. Each value can be
# overridden through an environment variable so real credentials do not have
# to live in source control; the literals are the defaults used otherwise.
DB_HOST = os.environ.get("OPENGAUSS_HOST", "127.0.0.1")
DB_PORT = int(os.environ.get("OPENGAUSS_PORT", "5432"))
DB_USER = os.environ.get("OPENGAUSS_USER", "gaussdb")
DB_PASSWORD = os.environ.get("OPENGAUSS_PASSWORD", "Ysl#1234")
DB_NAME = os.environ.get("OPENGAUSS_DB", "postgres")
def build_dsn(
    *,
    user: Optional[str] = None,
    password: Optional[str] = None,
    host: Optional[str] = None,
    port: Optional[int] = None,
    database: Optional[str] = None,
) -> str:
    """Return an ``opengauss://`` connection URI for the test database.

    Every argument is optional and keyword-only; any argument left as ``None``
    falls back to the matching module-level ``DB_*`` setting, so
    ``build_dsn()`` keeps its original meaning.

    The user name, password and database name are percent-encoded. Without
    this, reserved URI characters in them break parsing. For example, the
    ``#`` in the default password would start a URI fragment and cut off the
    rest of the authority (``@host:port``).
    """
    user = DB_USER if user is None else user
    password = DB_PASSWORD if password is None else password
    host = DB_HOST if host is None else host
    port = DB_PORT if port is None else port
    database = DB_NAME if database is None else database
    return (
        f"opengauss://{quote(user, safe='')}:{quote(password, safe='')}"
        f"@{host}:{port}/{quote(database, safe='')}"
    )
@unittest.skipIf(py_opengauss is None, "py_opengauss is required for integration tests")
class OpenGaussHardenedPoolIntegrationTest(unittest.TestCase):
    """Integration tests for ``OpenGaussConnectionPool`` against a live database.

    ``setUpClass`` seeds one row into ``my_schema.table1`` (keyed by
    ``seed_uid`` and tagged with ``seed_process_id``) whose ``content`` is
    ``"seed"``; the read tests query it. A single pool, ``cls.pool``, is shared
    by every test, so tests must hand back everything they borrow. Tests that
    need different pool settings build a private pool with ``_new_pool``.
    """

    # Upper bound, in seconds, for worker threads to finish before the test is
    # failed rather than hanging the whole run.
    THREAD_JOIN_TIMEOUT = 60.0

    seed_uid: int
    seed_process_id: str
    config: ConnectionPoolConfig
    pool: OpenGaussConnectionPool

    @staticmethod
    def _delete_rows_by_process_id(process_id: str) -> None:
        """Delete all ``my_schema.table1`` rows tagged with *process_id*.

        Uses a dedicated connection instead of the pool so cleanup still works
        when the shared pool is exhausted or already closed.
        """
        conn = py_opengauss.open(build_dsn())
        try:
            delete_stmt = conn.prepare(
                "DELETE FROM my_schema.table1 WHERE process_id = $1"
            )
            delete_stmt(process_id)
        finally:
            conn.close()

    @classmethod
    def setUpClass(cls) -> None:
        """Seed the row the read tests expect, then build the shared pool."""
        base = int(time.time()) % 1_500_000_000
        cls.seed_uid = base + 200
        cls.seed_process_id = f"hardened_seed_{base}"
        conn = py_opengauss.open(build_dsn())
        try:
            delete_stmt = conn.prepare("DELETE FROM my_schema.table1 WHERE u_uid = $1")
            delete_stmt(cls.seed_uid)
            insert_stmt = conn.prepare(
                """
                INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
                VALUES ($1, $2, NOW(), $3)
                """
            )
            insert_stmt(cls.seed_uid, cls.seed_process_id, "seed")
        finally:
            conn.close()
        try:
            cls.config = ConnectionPoolConfig(
                dsn=build_dsn(),
                min_size=2,
                max_size=6,
                blocking=True,
                acquire_timeout=0.5,
                idle_timeout=10,
                max_lifetime=30,
                max_usage=20,
                test_on_borrow=True,
                test_sql="SELECT 1",
                keepalive=True,
                keepalive_interval=0.5,
                health_check_interval=0.5,
                reset_on_return=True,
            )
            cls.pool = OpenGaussConnectionPool(cls.config)
        except Exception:
            # unittest does not run tearDownClass when setUpClass raises, so
            # remove the seed row here instead of leaking it.
            cls._delete_rows_by_process_id(cls.seed_process_id)
            raise

    @classmethod
    def tearDownClass(cls) -> None:
        """Close the shared pool and always remove the seed row."""
        try:
            pool = getattr(cls, "pool", None)
            if pool is not None:
                pool.close()
        finally:
            # Runs even if closing the pool fails, so the seed row never leaks.
            cls._delete_rows_by_process_id(cls.seed_process_id)

    def _new_pool(self, **overrides) -> OpenGaussConnectionPool:
        """Build a private pool from the shared config with *overrides* applied."""
        config_dict = self.config.__dict__.copy()
        config_dict.update(overrides)
        return OpenGaussConnectionPool(ConnectionPoolConfig(**config_dict))

    def _run_threads(self, threads: List[threading.Thread]) -> None:
        """Start *threads* and wait for them, failing if any outlives the timeout.

        A bare ``join()`` would hang the suite forever if a worker got stuck.
        Joining against one shared deadline turns that into a test failure.
        """
        for thread in threads:
            thread.start()
        deadline = time.monotonic() + self.THREAD_JOIN_TIMEOUT
        for thread in threads:
            thread.join(max(0.0, deadline - time.monotonic()))
        hung = [thread.name for thread in threads if thread.is_alive()]
        self.assertEqual(hung, [], "worker threads did not finish before the timeout")

    def test_pool_limit_and_timeout(self):
        print("STEP: test_pool_limit_and_timeout - exhaust pool")
        borrowed = []
        try:
            for _ in range(self.config.max_size):
                borrowed.append(self.pool.borrow())
            print("STEP: test_pool_limit_and_timeout - timed borrow raises")
            with self.assertRaises(PoolExhaustedError):
                # If the borrow unexpectedly succeeds, remember the connection
                # so the finally-block still returns it to the shared pool.
                borrowed.append(self.pool.borrow(timeout=0.3))
        finally:
            # The pool is shared by every test in this class: always return
            # what was borrowed, even when an assertion above failed.
            print("STEP: test_pool_limit_and_timeout - return connections and assert stats")
            for conn in borrowed:
                conn.close()
        stats = self.pool.stats()
        self.assertLessEqual(stats["total"], self.config.max_size)
        self.assertEqual(stats["in_use"], 0)
        self.assertGreaterEqual(stats["available"], self.config.min_size)

    def test_serial_reads_with_lifecycle(self):
        print("STEP: test_serial_reads_with_lifecycle - serial reads with pauses")
        pool = self._new_pool(min_size=1, max_size=3, max_lifetime=1.0, idle_timeout=5.0)
        stats = {}
        # Keep the raw connection objects alive. Collecting bare id() values of
        # objects that may already have been freed is unreliable: CPython can
        # hand the same address to a newly created connection.
        seen_conns = []
        try:
            stats_before = pool.stats()
            for idx in range(4):
                with pool.connection(timeout=1.0) as conn:
                    seen_conns.append(conn._entry.conn)
                    select_stmt = conn.prepare(
                        "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                    )
                    rows = select_stmt(self.seed_uid)
                    self.assertEqual(rows[0][0], "seed")
                # The pause after the second read exceeds max_lifetime (1.0s),
                # so that connection should be retired and replaced.
                time.sleep(1.2 if idx == 1 else 0.2)
            stats = pool.stats()
        finally:
            pool.close()
        print("STEP: test_serial_reads_with_lifecycle - assert pool stats")
        distinct_conns = len({id(raw) for raw in seen_conns})
        self.assertLessEqual(stats["total"], 3)
        self.assertFalse(stats["closed"])
        self.assertTrue(
            distinct_conns >= 2 or stats["discarded"] > stats_before["discarded"]
        )

    def test_concurrent_reads(self):
        print("STEP: test_concurrent_reads - start threads")
        thread_count = 20
        results: List[str] = []
        errors: List[Exception] = []
        lock = threading.Lock()

        def worker():
            try:
                with self.pool.connection(timeout=2.0) as conn:
                    select_stmt = conn.prepare(
                        "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                    )
                    rows = select_stmt(self.seed_uid)
                    with lock:
                        results.append(rows[0][0])
            except Exception as exc:  # noqa: BLE001 - capture any worker errors
                with lock:
                    errors.append(exc)

        threads = [threading.Thread(target=worker, daemon=True) for _ in range(thread_count)]
        self._run_threads(threads)
        print("STEP: test_concurrent_reads - assert results and pool stats")
        stats = self.pool.stats()
        # Comparing against [] shows the captured exceptions on failure.
        self.assertEqual(errors, [])
        self.assertEqual(results, ["seed"] * thread_count)
        self.assertFalse(stats["closed"])
        self.assertLessEqual(stats["total"], self.config.max_size)

    def test_concurrent_batch_writes(self):
        print("STEP: test_concurrent_batch_writes - start writer threads")
        base = self.seed_uid + 20000
        process_id = f"hardened_batch_{base}"
        total_batches = 5
        batch_size = 20
        worker_count = 8
        errors: List[Exception] = []
        lock = threading.Lock()
        # Registered before any insert so the rows are removed even if a
        # writer, the count query or an assertion fails.
        self.addCleanup(self._delete_rows_by_process_id, process_id)

        def writer(worker_idx: int):
            try:
                for batch_idx in range(total_batches):
                    with self.pool.connection(timeout=2.0) as conn:
                        insert_stmt = conn.prepare(
                            """
                            INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
                            VALUES ($1, $2, NOW(), $3)
                            """
                        )
                        for i in range(batch_size):
                            # Each worker owns a 1000-wide uid range, so
                            # workers never collide.
                            uid = base + (worker_idx * 1000) + (batch_idx * batch_size) + i
                            insert_stmt(uid, process_id, f"batch_{batch_idx}")
            except Exception as exc:  # noqa: BLE001 - capture any worker errors
                with lock:
                    errors.append(exc)

        threads = [
            threading.Thread(target=writer, args=(i,), daemon=True)
            for i in range(worker_count)
        ]
        self._run_threads(threads)
        print("STEP: test_concurrent_batch_writes - validate count and cleanup")
        with self.pool.connection(timeout=2.0) as conn:
            count_stmt = conn.prepare(
                "SELECT COUNT(*) FROM my_schema.table1 WHERE process_id = $1"
            )
            rows = count_stmt(process_id)
            total_inserted = rows[0][0]
            delete_stmt = conn.prepare(
                "DELETE FROM my_schema.table1 WHERE process_id = $1"
            )
            delete_stmt(process_id)
        stats = self.pool.stats()
        self.assertEqual(errors, [])
        self.assertEqual(total_inserted, worker_count * total_batches * batch_size)
        self.assertFalse(stats["closed"])

    def test_connection_lifecycle_management(self):
        print("STEP: test_connection_lifecycle_management - hold and expire connection")
        pool = self._new_pool(min_size=1, max_size=2, max_lifetime=1.0, idle_timeout=5.0)
        stats = {}
        try:
            conn = pool.borrow()
            # Hold the raw object itself. Once the expired connection is
            # discarded and freed, its id() may be reused by the replacement,
            # which would make an id comparison fail spuriously.
            old_raw = conn._entry.conn
            # Outlive max_lifetime (1.0s) while the connection is borrowed.
            time.sleep(1.2)
            conn.close()
            print("STEP: test_connection_lifecycle_management - borrow replacement")
            new_conn = pool.borrow(timeout=1.0)
            new_raw = new_conn._entry.conn
            new_conn.close()
            stats = pool.stats()
        finally:
            pool.close()
        self.assertIsNot(old_raw, new_raw)
        self.assertLessEqual(stats["total"], 2)

    def test_connection_validation_replaces_bad_connection(self):
        print("STEP: test_connection_validation_replaces_bad_connection - return closed conn")
        conn = self.pool.borrow()
        # Keep a reference so the closed connection's id cannot be recycled.
        bad_raw = conn._entry.conn
        try:
            bad_raw.close()
        finally:
            # Return it even if close() fails, so the shared pool is not drained.
            self.pool.return_connection(conn)
        print("STEP: test_connection_validation_replaces_bad_connection - borrow and validate")
        with self.pool.connection(timeout=1.0) as conn2:
            new_raw = conn2._entry.conn
            rows = conn2.prepare("SELECT 1")()
        stats = self.pool.stats()
        self.assertEqual(rows[0][0], 1)
        self.assertIsNot(bad_raw, new_raw)
        self.assertLessEqual(stats["total"], self.config.max_size)

    def test_long_running_requests(self):
        print("STEP: test_long_running_requests - run mixed workload")
        errors: List[Exception] = []
        lock = threading.Lock()
        # Every worker keeps issuing requests until this monotonic deadline.
        stop_at = time.monotonic() + 3.0

        def load_worker():
            while time.monotonic() < stop_at:
                try:
                    with self.pool.connection(timeout=2.0) as conn:
                        conn.prepare("SELECT 1")()
                        # Hold the connection briefly to simulate a slow request.
                        time.sleep(0.2)
                        select_stmt = conn.prepare(
                            "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                        )
                        rows = select_stmt(self.seed_uid)
                        if rows[0][0] != "seed":
                            raise AssertionError("Unexpected content")
                except Exception as exc:  # noqa: BLE001 - capture any worker errors
                    with lock:
                        errors.append(exc)
                    # Stop this worker after its first failure.
                    return

        threads = [threading.Thread(target=load_worker, daemon=True) for _ in range(4)]
        self._run_threads(threads)
        print("STEP: test_long_running_requests - assert pool stats")
        stats = self.pool.stats()
        self.assertEqual(errors, [])
        self.assertFalse(stats["closed"])
        self.assertLessEqual(stats["total"], self.config.max_size)
if __name__ == "__main__":
    # Executed directly as a script: discover and run every TestCase defined
    # in this module (``module="__main__"`` is unittest.main's default target).
    unittest.main(module="__main__")
|