ysl2007 hai 5 meses
pai
achega
db5be68bfe

+ 7 - 0
gdb_utils/__init__.py.txt

@@ -0,0 +1,7 @@
+"""
+Utility module for database helpers.
+"""
+
+from .opengauss_pool import OpenGaussConnectionPool, ConnectionPoolConfig
+
+__all__ = ["OpenGaussConnectionPool", "ConnectionPoolConfig"]

+ 333 - 0
gdb_utils/opengauss_pool.py.txt

@@ -0,0 +1,333 @@
+"""
+Lightweight connection pool built on top of py_opengauss.
+
+The pool keeps a configurable number of connections ready, validates them on
+borrow when requested, performs periodic health checks/keepalives, and discards
+connections that fail SQL execution or exceed idle/lifetime thresholds.
+"""
+
+from __future__ import annotations
+
+import logging
+import threading
+import time
+from collections import deque
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any, Deque, Dict, Optional
+
+import py_opengauss
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ConnectionPoolConfig:
+    dsn: str
+    min_size: int = 1
+    max_size: int = 10
+    idle_timeout: float = 300.0
+    max_lifetime: float = 3600.0
+    test_on_borrow: bool = True
+    test_sql: str = "SELECT 1"
+    keepalive: bool = True
+    keepalive_interval: float = 60.0
+    health_check_interval: float = 30.0
+    connect_kwargs: Optional[Dict[str, Any]] = None
+
+
+@dataclass
+class _ConnectionEntry:
+    conn: Any
+    created_at: float
+    last_used: float
+    last_check: float
+
+
+class OpenGaussConnectionPool:
+    """
+    Thread-safe connection pool for py_opengauss.
+
+    Borrowed connections are returned directly without wrappers. When an error
+    is detected during SQL execution, callers should mark the connection as
+    broken via `return_connection(conn, had_error=True)` to evict it quickly.
+    """
+
+    def __init__(self, config: ConnectionPoolConfig):
+        if config.min_size < 0 or config.max_size <= 0:
+            raise ValueError("Pool sizes must be positive")
+        if config.min_size > config.max_size:
+            raise ValueError("min_size cannot be greater than max_size")
+
+        self._config = config
+        self._lock = threading.Lock()
+        self._condition = threading.Condition(self._lock)
+        self._available: Deque[_ConnectionEntry] = deque()
+        self._in_use: Dict[int, _ConnectionEntry] = {}
+        self._total = 0
+        self._closed = False
+
+        self._housekeeper = threading.Thread(
+            target=self._housekeeping, name="opengauss-pool-housekeeper", daemon=True
+        )
+
+        self._initial_fill()
+        self._housekeeper.start()
+
+    def borrow(self, timeout: Optional[float] = None) -> Any:
+        """
+        Borrow a connection from the pool.
+
+        If test_on_borrow is enabled, the connection will be validated with the
+        configured SQL before being returned.
+        """
+        deadline = time.monotonic() + timeout if timeout else None
+        while True:
+            entry = self._try_acquire_available()
+            if entry:
+                if self._config.test_on_borrow and not self._validate(entry):
+                    self._discard(entry)
+                    continue
+                entry.last_used = time.time()
+                with self._condition:
+                    self._in_use[id(entry.conn)] = entry
+                return entry.conn
+
+            to_create = self._reserve_slot_for_new()
+            if to_create:
+                entry = self._create_entry(reserved=True)
+                if entry is None:
+                    # Creation failed; avoid tight loop when DB is unavailable.
+                    with self._condition:
+                        if self._closed:
+                            raise RuntimeError("Connection pool is closed")
+                        if deadline is not None:
+                            remaining = deadline - time.monotonic()
+                            if remaining <= 0:
+                                raise TimeoutError("Timed out waiting for connection")
+                            self._condition.wait(timeout=min(0.1, remaining))
+                        else:
+                            self._condition.wait(timeout=0.1)
+                    continue
+                with self._condition:
+                    self._in_use[id(entry.conn)] = entry
+                return entry.conn
+
+            # Wait for a return or until timed out.
+            with self._condition:
+                if self._closed:
+                    raise RuntimeError("Connection pool is closed")
+                if deadline is not None:
+                    remaining = deadline - time.monotonic()
+                    if remaining <= 0:
+                        raise TimeoutError("Timed out waiting for connection")
+                    self._condition.wait(timeout=remaining)
+                else:
+                    self._condition.wait()
+
+    def return_connection(self, conn: Any, had_error: bool = False) -> None:
+        """
+        Return a connection to the pool.
+
+        If had_error is True, the connection will be closed and removed to
+        avoid reusing broken connections.
+        """
+        entry = None
+        with self._condition:
+            entry = self._in_use.pop(id(conn), None)
+            if entry is None:
+                # Unknown connection; close it defensively.
+                try:
+                    conn.close()
+                finally:
+                    return
+
+        if had_error:
+            self._discard(entry)
+        else:
+            entry.last_used = time.time()
+            if self._should_discard(entry):
+                self._discard(entry)
+            else:
+                with self._condition:
+                    self._available.append(entry)
+                    self._condition.notify()
+
+    @contextmanager
+    def connection(self, timeout: Optional[float] = None):
+        """
+        Context manager helper that returns the connection to the pool
+        automatically and discards it if an exception is raised.
+        """
+        conn = self.borrow(timeout=timeout)
+        try:
+            yield conn
+        except Exception:
+            self.return_connection(conn, had_error=True)
+            raise
+        else:
+            self.return_connection(conn)
+
+    def close(self) -> None:
+        """
+        Close the pool and all managed connections.
+        """
+        with self._condition:
+            if self._closed:
+                return
+            self._closed = True
+            to_close = list(self._available)
+            self._available.clear()
+            to_close.extend(self._in_use.values())
+            self._in_use.clear()
+            self._condition.notify_all()
+        for entry in to_close:
+            try:
+                entry.conn.close()
+            except Exception:
+                logger.debug("Failed closing connection during pool shutdown", exc_info=True)
+        self._total = 0
+
+    def stats(self) -> Dict[str, Any]:
+        """
+        Lightweight snapshot of pool counters for observability/testing.
+        """
+        with self._condition:
+            return {
+                "total": self._total,
+                "available": len(self._available),
+                "in_use": len(self._in_use),
+                "closed": self._closed,
+            }
+
+    # Internal helpers -----------------------------------------------------
+
+    def _initial_fill(self) -> None:
+        for _ in range(self._config.min_size):
+            entry = self._create_entry()
+            if entry:
+                with self._condition:
+                    self._available.append(entry)
+                    self._total += 1
+
+    def _reserve_slot_for_new(self) -> bool:
+        with self._condition:
+            if self._closed:
+                raise RuntimeError("Connection pool is closed")
+            if self._total >= self._config.max_size:
+                return False
+            self._total += 1
+            return True
+
+    def _try_acquire_available(self) -> Optional[_ConnectionEntry]:
+        now = time.time()
+        to_discard: Deque[_ConnectionEntry] = deque()
+        with self._condition:
+            if self._closed:
+                raise RuntimeError("Connection pool is closed")
+            while self._available:
+                entry = self._available.popleft()
+                if self._should_discard(entry, now):
+                    to_discard.append(entry)
+                    continue
+                return entry
+        for entry in list(to_discard):
+            self._discard(entry)
+        return None
+
+    def _create_entry(self, reserved: bool = False) -> Optional[_ConnectionEntry]:
+        try:
+            conn = py_opengauss.open(self._config.dsn, **(self._config.connect_kwargs or {}))
+            now = time.time()
+            return _ConnectionEntry(conn=conn, created_at=now, last_used=now, last_check=now)
+        except Exception:
+            logger.exception("Failed to create new py_opengauss connection")
+            if reserved:
+                with self._condition:
+                    self._total = max(0, self._total - 1)
+            return None
+
+    def _validate(self, entry: _ConnectionEntry) -> bool:
+        try:
+            stmt = entry.conn.prepare(self._config.test_sql)
+            stmt()
+            entry.last_check = time.time()
+            return True
+        except Exception:
+            logger.warning("Validation failed; discarding connection", exc_info=True)
+            return False
+
+    def _should_discard(self, entry: _ConnectionEntry, now: Optional[float] = None) -> bool:
+        now = now or time.time()
+        if self._config.max_lifetime and now - entry.created_at >= self._config.max_lifetime:
+            return True
+        if self._config.idle_timeout and now - entry.last_used >= self._config.idle_timeout:
+            return True
+        return False
+
+    def _discard(self, entry: _ConnectionEntry) -> None:
+        try:
+            entry.conn.close()
+        except Exception:
+            logger.debug("Failed closing connection", exc_info=True)
+        with self._condition:
+            self._total = max(0, self._total - 1)
+            self._condition.notify()
+
+    def _housekeeping(self) -> None:
+        interval = max(1.0, self._config.health_check_interval)
+        while True:
+            if self._closed:
+                return
+            self._perform_health_check()
+            time.sleep(interval)
+
+    def _perform_health_check(self) -> None:
+        now = time.time()
+        to_check: Deque[_ConnectionEntry] = deque()
+        to_discard: Deque[_ConnectionEntry] = deque()
+        with self._condition:
+            if self._closed:
+                return
+            kept: Deque[_ConnectionEntry] = deque()
+            while self._available:
+                entry = self._available.popleft()
+                if self._should_discard(entry, now):
+                    to_discard.append(entry)
+                    continue
+                if (
+                    self._config.keepalive
+                    and self._config.keepalive_interval
+                    and now - entry.last_check >= self._config.keepalive_interval
+                ):
+                    to_check.append(entry)
+                else:
+                    kept.append(entry)
+            self._available = kept
+
+        for entry in list(to_discard):
+            self._discard(entry)
+
+        # Run keepalive checks outside lock to avoid blocking borrowers.
+        for entry in list(to_check):
+            if self._validate(entry):
+                with self._condition:
+                    if not self._closed:
+                        self._available.append(entry)
+            else:
+                self._discard(entry)
+
+        # Ensure minimum pool size is preserved.
+        with self._condition:
+            needed = max(0, self._config.min_size - self._total)
+        for _ in range(needed):
+            entry = self._create_entry()
+            if entry:
+                with self._condition:
+                    if self._closed:
+                        entry.conn.close()
+                        return
+                    self._available.append(entry)
+                    self._total += 1
+                    self._condition.notify()

+ 438 - 0
gdb_utils/opengauss_pool_hardened.py.txt

@@ -0,0 +1,438 @@
+"""
+Hardened connection pool for py_opengauss inspired by dbutils.pooled_db.
+
+Key differences from the lightweight pool:
+* Uses a proxy wrapper so connections are always returned (close/__del__/with).
+* Resets sessions on return (rollback + optional setsession).
+* Avoids deadlocks by using an RLock and separating discard-with-lock paths.
+* Supports blocking/non-blocking acquire with timeout and max usage cycling.
+* Uses monotonic clocks for all interval calculations to avoid wall-clock jumps.
+"""
+
+from __future__ import annotations
+
+import logging
+import threading
+import time
+from collections import deque
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any, Deque, Dict, Optional, Sequence
+
+import py_opengauss
+
+logger = logging.getLogger(__name__)
+
+
+class PoolError(Exception):
+    """Base pool error."""
+
+
+class PoolClosedError(PoolError):
+    """Pool has been closed."""
+
+
+class PoolExhaustedError(PoolError):
+    """Pool is at capacity and blocking is disabled or timed out."""
+
+
+@dataclass
+class ConnectionPoolConfig:
+    dsn: str
+    min_size: int = 1
+    max_size: int = 10
+    blocking: bool = True
+    acquire_timeout: Optional[float] = None
+    idle_timeout: float = 300.0
+    max_lifetime: float = 3600.0
+    max_usage: Optional[int] = None
+    test_on_borrow: bool = True
+    test_sql: str = "SELECT 1"
+    keepalive: bool = True
+    keepalive_interval: float = 60.0
+    health_check_interval: float = 30.0
+    reset_on_return: bool = True
+    setsession: Optional[Sequence[str]] = None
+    connect_kwargs: Optional[Dict[str, Any]] = None
+
+
+@dataclass
+class _ConnectionEntry:
+    conn: Any
+    created_at: float
+    last_used: float
+    last_check: float
+    usage_count: int = 0
+
+
+class _PooledConnectionProxy:
+    """Lightweight proxy that returns connections to the pool on close/del."""
+
+    def __init__(self, pool: "OpenGaussConnectionPool", entry: _ConnectionEntry):
+        self._pool = pool
+        self._entry = entry
+        self._returned = False
+
+    def __getattr__(self, item: str) -> Any:
+        return getattr(self._entry.conn, item)
+
+    def __enter__(self) -> "_PooledConnectionProxy":
+        return self
+
+    def __exit__(self, exc_type, exc, tb) -> None:
+        self.close(broken=bool(exc_type))
+
+    def close(self, broken: bool = False) -> None:
+        if self._returned:
+            return
+        self._returned = True
+        self._pool._return_entry(self._entry, broken=broken)
+
+    def mark_broken(self) -> None:
+        """Explicitly mark this connection as broken so it gets discarded."""
+        self.close(broken=True)
+
+    def __del__(self):
+        # Best-effort return to pool to avoid leaks if user forgets to close.
+        try:
+            self.close()
+        except Exception:
+            pass
+
+    def __repr__(self) -> str:
+        return f"<PooledOpenGaussConnection id={id(self._entry.conn)} returned={self._returned}>"
+
+
+class OpenGaussConnectionPool:
+    """
+    Hardened, thread-safe connection pool for py_opengauss.
+
+    Connections are wrapped in a proxy that ensures return to the pool on close
+    or object finalization. Pool operations rely on monotonic clocks to avoid
+    issues from system clock jumps.
+    """
+
+    def __init__(self, config: ConnectionPoolConfig):
+        if config.min_size < 0 or config.max_size <= 0:
+            raise ValueError("Pool sizes must be positive")
+        if config.min_size > config.max_size:
+            raise ValueError("min_size cannot be greater than max_size")
+
+        self._config = config
+        self._lock = threading.RLock()
+        self._condition = threading.Condition(self._lock)
+        self._available: Deque[_ConnectionEntry] = deque()
+        self._in_use: Dict[int, _ConnectionEntry] = {}
+        self._total = 0
+        self._create_failures = 0
+        self._discarded = 0
+        self._closed = False
+        self._stop_event = threading.Event()
+
+        self._housekeeper = threading.Thread(
+            target=self._housekeeping, name="opengauss-pool-housekeeper", daemon=True
+        )
+
+        self._initial_fill()
+        self._housekeeper.start()
+
+    def borrow(self, timeout: Optional[float] = None) -> _PooledConnectionProxy:
+        """
+        Borrow a connection from the pool.
+
+        If test_on_borrow is enabled, the connection will be validated with the
+        configured SQL before being returned. When blocking is False, pool
+        exhaustion immediately raises PoolExhaustedError.
+        """
+        deadline = None
+        effective_timeout = timeout
+        if effective_timeout is None:
+            effective_timeout = self._config.acquire_timeout
+        if effective_timeout is not None:
+            deadline = time.monotonic() + effective_timeout
+
+        while True:
+            entry = self._try_acquire_available()
+            if entry:
+                proxy = self._prepare_borrow(entry)
+                if proxy:
+                    return proxy
+                continue
+
+            to_create = self._reserve_slot_for_new()
+            if to_create:
+                entry = self._create_entry(reserved=True)
+                if entry is None:
+                    # Creation failed; avoid tight loop when DB is unavailable.
+                    self._wait_for_availability(deadline)
+                    continue
+                proxy = self._prepare_borrow(entry)
+                if proxy:
+                    return proxy
+                continue
+
+            # Pool is at capacity and nothing available.
+            if not self._config.blocking:
+                raise PoolExhaustedError("Pool is at capacity and blocking is disabled")
+
+            self._wait_for_availability(deadline)
+
+    def return_connection(self, conn: Any, had_error: bool = False) -> None:
+        """
+        Return a connection (or proxy) to the pool.
+
+        If had_error is True, the connection will be closed and removed to
+        avoid reusing broken connections.
+        """
+        entry = None
+        if isinstance(conn, _PooledConnectionProxy):
+            conn.close(broken=had_error)
+            return
+
+        with self._condition:
+            entry = self._in_use.pop(id(conn), None)
+
+        if entry is None:
+            try:
+                conn.close()
+            finally:
+                return
+        self._return_entry(entry, broken=had_error)
+
+    @contextmanager
+    def connection(self, timeout: Optional[float] = None):
+        """
+        Context manager helper that returns the connection to the pool
+        automatically and discards it if an exception is raised.
+        """
+        proxy = self.borrow(timeout=timeout)
+        try:
+            yield proxy
+        except Exception:
+            proxy.close(broken=True)
+            raise
+        else:
+            proxy.close()
+
+    def close(self) -> None:
+        """Close the pool and all managed connections."""
+        with self._condition:
+            if self._closed:
+                return
+            self._closed = True
+            to_close = list(self._available)
+            self._available.clear()
+            to_close.extend(self._in_use.values())
+            self._in_use.clear()
+            self._condition.notify_all()
+            self._stop_event.set()
+        for entry in to_close:
+            self._close_entry(entry)
+        self._total = 0
+        if self._housekeeper.is_alive():
+            self._housekeeper.join(timeout=1.0)
+
+    def stats(self) -> Dict[str, Any]:
+        """Lightweight snapshot of pool counters for observability/testing."""
+        with self._condition:
+            return {
+                "total": self._total,
+                "available": len(self._available),
+                "in_use": len(self._in_use),
+                "closed": self._closed,
+                "create_failures": self._create_failures,
+                "discarded": self._discarded,
+            }
+
+    # Internal helpers -----------------------------------------------------
+
+    def _prepare_borrow(self, entry: _ConnectionEntry) -> Optional[_PooledConnectionProxy]:
+        if self._config.test_on_borrow and not self._validate(entry):
+            self._discard_entry(entry)
+            return None
+        entry.last_used = time.monotonic()
+        entry.usage_count += 1
+        with self._condition:
+            self._in_use[id(entry.conn)] = entry
+        return _PooledConnectionProxy(self, entry)
+
+    def _initial_fill(self) -> None:
+        for _ in range(self._config.min_size):
+            entry = self._create_entry()
+            if entry:
+                with self._condition:
+                    self._available.append(entry)
+                    self._total += 1
+
+    def _reserve_slot_for_new(self) -> bool:
+        with self._condition:
+            if self._closed:
+                raise PoolClosedError("Connection pool is closed")
+            if self._total >= self._config.max_size:
+                return False
+            self._total += 1
+            return True
+
+    def _try_acquire_available(self) -> Optional[_ConnectionEntry]:
+        now = time.monotonic()
+        with self._condition:
+            if self._closed:
+                raise PoolClosedError("Connection pool is closed")
+            while self._available:
+                entry = self._available.popleft()
+                if self._should_discard(entry, now):
+                    self._discard_entry_locked(entry)
+                    continue
+                return entry
+        return None
+
+    def _create_entry(self, reserved: bool = False) -> Optional[_ConnectionEntry]:
+        try:
+            conn = py_opengauss.open(self._config.dsn, **(self._config.connect_kwargs or {}))
+            now = time.monotonic()
+            return _ConnectionEntry(conn=conn, created_at=now, last_used=now, last_check=now)
+        except Exception:
+            logger.exception("Failed to create new py_opengauss connection")
+            self._create_failures += 1
+            if reserved:
+                with self._condition:
+                    self._total = max(0, self._total - 1)
+                    self._condition.notify_all()
+            return None
+
+    def _validate(self, entry: _ConnectionEntry) -> bool:
+        try:
+            stmt = entry.conn.prepare(self._config.test_sql)
+            stmt()
+            entry.last_check = time.monotonic()
+            return True
+        except Exception:
+            logger.warning("Validation failed; discarding connection", exc_info=True)
+            return False
+
+    def _should_discard(self, entry: _ConnectionEntry, now: Optional[float] = None) -> bool:
+        now = now or time.monotonic()
+        if self._config.max_lifetime and now - entry.created_at >= self._config.max_lifetime:
+            return True
+        if self._config.idle_timeout and now - entry.last_used >= self._config.idle_timeout:
+            return True
+        if self._config.max_usage and entry.usage_count >= self._config.max_usage:
+            return True
+        return False
+
+    def _reset_connection(self, entry: _ConnectionEntry) -> bool:
+        if not self._config.reset_on_return and not self._config.setsession:
+            return True
+        try:
+            if self._config.reset_on_return:
+                rollback = getattr(entry.conn, "rollback", None)
+                if callable(rollback):
+                    rollback()
+                else:
+                    entry.conn.execute("ROLLBACK")
+            if self._config.setsession:
+                for sql in self._config.setsession:
+                    stmt = entry.conn.prepare(sql)
+                    stmt()
+            return True
+        except Exception:
+            logger.warning("Failed to reset connection; discarding", exc_info=True)
+            return False
+
+    def _close_entry(self, entry: _ConnectionEntry) -> None:
+        try:
+            entry.conn.close()
+        except Exception:
+            logger.debug("Failed closing connection", exc_info=True)
+
+    def _discard_entry_locked(self, entry: _ConnectionEntry) -> None:
+        """Discard while holding the pool lock."""
+        self._close_entry(entry)
+        self._total = max(0, self._total - 1)
+        self._discarded += 1
+        self._condition.notify_all()
+
+    def _discard_entry(self, entry: _ConnectionEntry) -> None:
+        with self._condition:
+            self._discard_entry_locked(entry)
+
+    def _return_entry(self, entry: _ConnectionEntry, broken: bool = False) -> None:
+        with self._condition:
+            stored = self._in_use.pop(id(entry.conn), None)
+        if stored is None:
+            return
+        entry = stored
+        if broken or not self._reset_connection(entry) or self._should_discard(entry):
+            self._discard_entry(entry)
+            return
+        with self._condition:
+            if self._closed:
+                self._discard_entry_locked(entry)
+                return
+            entry.last_used = time.monotonic()
+            self._available.append(entry)
+            self._condition.notify_all()
+
+    def _wait_for_availability(self, deadline: Optional[float]) -> None:
+        if deadline is not None:
+            remaining = deadline - time.monotonic()
+            if remaining <= 0:
+                raise PoolExhaustedError("Timed out waiting for connection from pool")
+            with self._condition:
+                if self._closed:
+                    raise PoolClosedError("Connection pool is closed")
+                self._condition.wait(timeout=remaining)
+        else:
+            with self._condition:
+                if self._closed:
+                    raise PoolClosedError("Connection pool is closed")
+                self._condition.wait()
+
+    def _housekeeping(self) -> None:
+        interval = max(1.0, self._config.health_check_interval)
+        while not self._stop_event.wait(interval):
+            self._perform_health_check()
+
+    def _perform_health_check(self) -> None:
+        now = time.monotonic()
+        to_check: Deque[_ConnectionEntry] = deque()
+        with self._condition:
+            if self._closed:
+                return
+            kept: Deque[_ConnectionEntry] = deque()
+            while self._available:
+                entry = self._available.popleft()
+                if self._should_discard(entry, now):
+                    self._discard_entry_locked(entry)
+                    continue
+                if (
+                    self._config.keepalive
+                    and self._config.keepalive_interval
+                    and now - entry.last_check >= self._config.keepalive_interval
+                ):
+                    to_check.append(entry)
+                else:
+                    kept.append(entry)
+            self._available = kept
+
+        for entry in list(to_check):
+            if self._validate(entry):
+                with self._condition:
+                    if not self._closed:
+                        entry.last_used = time.monotonic()
+                        self._available.append(entry)
+            else:
+                self._discard_entry(entry)
+
+        with self._condition:
+            needed = max(0, self._config.min_size - self._total)
+        for _ in range(needed):
+            entry = self._create_entry()
+            if entry:
+                with self._condition:
+                    if self._closed:
+                        self._close_entry(entry)
+                        return
+                    self._available.append(entry)
+                    self._total += 1
+                    self._condition.notify_all()

+ 340 - 0
tests/test_opengauss_pool.py.txt

@@ -0,0 +1,340 @@
+import threading
+import time
+import unittest
+from typing import List
+
+try:
+    import py_opengauss  # type: ignore  # noqa: F401
+except ImportError:  # pragma: no cover - handled via skip
+    py_opengauss = None
+
+from gdb_utils import ConnectionPoolConfig, OpenGaussConnectionPool
+
+
+DB_HOST = "127.0.0.1"
+DB_PORT = 5432
+DB_USER = "gaussdb"
+DB_PASSWORD = "Ysl#1234"
+DB_NAME = "postgres"
+
+
+def build_dsn() -> str:
+    return f"opengauss://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
+
+
+@unittest.skipIf(py_opengauss is None, "py_opengauss is required for integration tests")
+class OpenGaussPoolIntegrationTest(unittest.TestCase):
+    seed_uid: int
+    seed_process_id: str
+    config: ConnectionPoolConfig
+    pool: OpenGaussConnectionPool
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        base = int(time.time()) % 1_500_000_000
+        cls.seed_uid = base + 100
+        cls.seed_process_id = f"pool_seed_{base}"
+        conn = py_opengauss.open(build_dsn())
+        try:
+            delete_stmt = conn.prepare("DELETE FROM my_schema.table1 WHERE u_uid = $1")
+            delete_stmt(cls.seed_uid)
+            insert_stmt = conn.prepare(
+                """
+                INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
+                VALUES ($1, $2, NOW(), $3)
+                """
+            )
+            insert_stmt(cls.seed_uid, cls.seed_process_id, "seed")
+        finally:
+            conn.close()
+
+        cls.config = ConnectionPoolConfig(
+            dsn=build_dsn(),
+            min_size=2,
+            max_size=8,
+            idle_timeout=30,
+            max_lifetime=60,
+            test_on_borrow=True,
+            test_sql="SELECT 1",
+            keepalive=True,
+            keepalive_interval=0.5,
+            health_check_interval=0.5,
+        )
+        cls.pool = OpenGaussConnectionPool(cls.config)
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        if hasattr(cls, "pool") and cls.pool:
+            cls.pool.close()
+        conn = py_opengauss.open(build_dsn())
+        try:
+            delete_stmt = conn.prepare(
+                "DELETE FROM my_schema.table1 WHERE process_id = $1"
+            )
+            delete_stmt(cls.seed_process_id)
+        finally:
+            conn.close()
+
+    def _wait_for_min_size(self, timeout: float = 2.0) -> None:
+        deadline = time.monotonic() + timeout
+        while time.monotonic() < deadline:
+            stats = self.pool.stats()
+            if stats["total"] >= self.config.min_size:
+                return
+            time.sleep(0.05)
+
+    def test_read_and_write(self):
+        print("STEP: test_read_and_write - setup table and insert row")
+        # u_uid,name are VARCHAR(10); keep identifiers short.
+        unique_uid = f"u{int(time.time()) % 100000:05d}"
+        # Ensure schema/table exist and insert a row.
+        with self.pool.connection() as conn:
+            conn.execute("CREATE SCHEMA IF NOT EXISTS my_schema")
+            conn.execute(
+                """
+                CREATE TABLE IF NOT EXISTS my_schema.test_table (
+                    u_uid VARCHAR(10),
+                    name VARCHAR(10),
+                    CONSTRAINT pk_uid PRIMARY KEY(u_uid)
+                )
+                """
+            )
+            # Ensure no leftover row with same key before insert.
+            delete_before_insert = conn.prepare("DELETE FROM my_schema.test_table WHERE u_uid = $1")
+            delete_before_insert(unique_uid)
+            insert_stmt = conn.prepare(
+                "INSERT INTO my_schema.test_table (u_uid, name) VALUES ($1, $2)"
+            )
+            insert_stmt(unique_uid, "codex")
+
+        print("STEP: test_read_and_write - read row and assert")
+        # Read back the row.
+        with self.pool.connection() as conn:
+            select_stmt = conn.prepare("SELECT name FROM my_schema.test_table WHERE u_uid = $1")
+            rows = select_stmt(unique_uid)
+
+        self.assertIsNotNone(rows)
+        self.assertGreaterEqual(len(rows), 1)
+        self.assertEqual(rows[0][0], "codex")
+
+        print("STEP: test_read_and_write - cleanup inserted row")
+        # Clean up the inserted row.
+        with self.pool.connection() as conn:
+            delete_stmt = conn.prepare("DELETE FROM my_schema.test_table WHERE u_uid = $1")
+            delete_stmt(unique_uid)
+
+    def test_broken_connection_is_replaced(self):
+        print("STEP: test_broken_connection_is_replaced - close and return broken connection")
+        conn = self.pool.borrow()
+        # Simulate an unusable connection by closing it and marking as broken.
+        conn.close()
+        self.pool.return_connection(conn, had_error=True)
+
+        print("STEP: test_broken_connection_is_replaced - borrow replacement and validate")
+        # Borrow again; pool should create a new usable connection.
+        with self.pool.connection() as conn2:
+            stmt = conn2.prepare("SELECT 1")
+            rows = stmt()
+        self.assertEqual(rows[0][0], 1)
+
+        self._wait_for_min_size()
+        stats_after = self.pool.stats()
+        self.assertFalse(stats_after["closed"])
+        self.assertLessEqual(stats_after["total"], self.config.max_size)
+        self.assertGreaterEqual(stats_after["total"], self.config.min_size)
+
+    def test_pool_limit_and_timeout(self):
+        print("STEP: test_pool_limit_and_timeout - exhaust pool")
+        borrowed = [self.pool.borrow() for _ in range(self.config.max_size)]
+        errors: List[Exception] = []
+
+        print("STEP: test_pool_limit_and_timeout - attempt timed borrow")
+        def try_borrow():
+            try:
+                self.pool.borrow(timeout=0.5)
+            except Exception as exc:  # noqa: BLE001 - capture TimeoutError
+                errors.append(exc)
+
+        t = threading.Thread(target=try_borrow)
+        t.start()
+        t.join()
+
+        print("STEP: test_pool_limit_and_timeout - return connections and assert stats")
+        for conn in borrowed:
+            self.pool.return_connection(conn)
+        stats = self.pool.stats()
+
+        self.assertLessEqual(stats["total"], self.config.max_size)
+        self.assertEqual(stats["in_use"], 0)
+        self.assertGreaterEqual(stats["available"], self.config.min_size)
+        self.assertTrue(any(isinstance(err, TimeoutError) for err in errors))
+
+    def test_serial_reads_with_lifecycle(self):
+        print("STEP: test_serial_reads_with_lifecycle - serial reads with pauses")
+        for _ in range(3):
+            with self.pool.connection() as conn:
+                select_stmt = conn.prepare(
+                    "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
+                )
+                rows = select_stmt(self.seed_uid)
+                self.assertEqual(rows[0][0], "seed")
+                time.sleep(0.5)
+
+        print("STEP: test_serial_reads_with_lifecycle - assert pool stats")
+        stats = self.pool.stats()
+
+        self.assertLessEqual(stats["total"], self.config.max_size)
+        self.assertFalse(stats["closed"])
+
+    def test_concurrent_reads(self):
+        print("STEP: test_concurrent_reads - start threads")
+        results: List[str] = []
+        errors: List[Exception] = []
+        lock = threading.Lock()
+
+        def worker():
+            try:
+                with self.pool.connection(timeout=2.0) as conn:
+                    select_stmt = conn.prepare(
+                        "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
+                    )
+                    rows = select_stmt(self.seed_uid)
+                with lock:
+                    results.append(rows[0][0])
+            except Exception as exc:  # noqa: BLE001 - capture any worker errors
+                with lock:
+                    errors.append(exc)
+
+        threads = [threading.Thread(target=worker) for _ in range(20)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        print("STEP: test_concurrent_reads - assert results and pool stats")
+        stats = self.pool.stats()
+
+        self.assertEqual(len(errors), 0)
+        self.assertEqual(len(results), 20)
+        self.assertFalse(stats["closed"])
+        self.assertLessEqual(stats["total"], self.config.max_size)
+
+    def test_concurrent_batch_writes(self):
+        print("STEP: test_concurrent_batch_writes - start writer threads")
+        base = self.seed_uid + 10000
+        process_id = f"pool_batch_{base}"
+        total_batches = 5
+        batch_size = 20
+        worker_count = 8
+        errors: List[Exception] = []
+        lock = threading.Lock()
+
+        def writer(worker_idx: int):
+            try:
+                for batch_idx in range(total_batches):
+                    with self.pool.connection(timeout=2.0) as conn:
+                        insert_stmt = conn.prepare(
+                            """
+                            INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
+                            VALUES ($1, $2, NOW(), $3)
+                            """
+                        )
+                        for i in range(batch_size):
+                            uid = base + (worker_idx * 1000) + (batch_idx * batch_size) + i
+                            insert_stmt(uid, process_id, f"batch_{batch_idx}")
+            except Exception as exc:  # noqa: BLE001 - capture any worker errors
+                with lock:
+                    errors.append(exc)
+
+        threads = [threading.Thread(target=writer, args=(i,)) for i in range(worker_count)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        print("STEP: test_concurrent_batch_writes - validate count and cleanup")
+        with self.pool.connection(timeout=2.0) as conn:
+            count_stmt = conn.prepare(
+                "SELECT COUNT(*) FROM my_schema.table1 WHERE process_id = $1"
+            )
+            rows = count_stmt(process_id)
+            total_inserted = rows[0][0]
+            delete_stmt = conn.prepare(
+                "DELETE FROM my_schema.table1 WHERE process_id = $1"
+            )
+            delete_stmt(process_id)
+
+        stats = self.pool.stats()
+
+        self.assertEqual(len(errors), 0)
+        self.assertEqual(total_inserted, worker_count * total_batches * batch_size)
+        self.assertFalse(stats["closed"])
+
+    def test_connection_lifecycle_management(self):
+        print("STEP: test_connection_lifecycle_management - borrow and age connection")
+        conn = self.pool.borrow()
+        conn.prepare("SELECT 1")()
+        time.sleep(self.config.max_lifetime + 0.2)
+        self.pool.return_connection(conn)
+
+        print("STEP: test_connection_lifecycle_management - borrow new connection")
+        new_conn = self.pool.borrow(timeout=1.0)
+        self.pool.return_connection(new_conn)
+        stats = self.pool.stats()
+
+        self.assertNotEqual(id(conn), id(new_conn))
+        self.assertLessEqual(stats["total"], self.config.max_size)
+
+    def test_connection_validation_replaces_bad_connection(self):
+        print("STEP: test_connection_validation_replaces_bad_connection - return closed conn")
+        conn = self.pool.borrow()
+        conn.close()
+        self.pool.return_connection(conn, had_error=False)
+
+        print("STEP: test_connection_validation_replaces_bad_connection - borrow and validate")
+        with self.pool.connection(timeout=1.0) as conn2:
+            rows = conn2.prepare("SELECT 1")()
+        stats = self.pool.stats()
+
+        self.assertEqual(rows[0][0], 1)
+        self.assertLessEqual(stats["total"], self.config.max_size)
+
+    def test_long_running_requests(self):
+        print("STEP: test_long_running_requests - run mixed workload")
+        errors: List[Exception] = []
+        lock = threading.Lock()
+        stop_at = time.monotonic() + 3.0
+
+        def load_worker():
+            while time.monotonic() < stop_at:
+                try:
+                    with self.pool.connection(timeout=2.0) as conn:
+                        conn.prepare("SELECT 1")()
+                        time.sleep(0.2)
+                        select_stmt = conn.prepare(
+                            "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
+                        )
+                        rows = select_stmt(self.seed_uid)
+                        if rows[0][0] != "seed":
+                            raise AssertionError("Unexpected content")
+                except Exception as exc:  # noqa: BLE001 - capture any worker errors
+                    with lock:
+                        errors.append(exc)
+                    return
+
+        threads = [threading.Thread(target=load_worker) for _ in range(4)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        print("STEP: test_long_running_requests - assert pool stats")
+        stats = self.pool.stats()
+
+        self.assertEqual(len(errors), 0)
+        self.assertFalse(stats["closed"])
+        self.assertLessEqual(stats["total"], self.config.max_size)
+
+
+if __name__ == "__main__":
+    unittest.main()

+ 298 - 0
tests/test_opengauss_pool_hardened.py.txt

@@ -0,0 +1,298 @@
+import threading
+import time
+import unittest
+from typing import List
+
+try:
+    import py_opengauss  # type: ignore  # noqa: F401
+except ImportError:  # pragma: no cover - handled via skip
+    py_opengauss = None
+
+from gdb_utils.opengauss_pool_hardened import (
+    ConnectionPoolConfig,
+    OpenGaussConnectionPool,
+    PoolExhaustedError,
+)
+
+
+DB_HOST = "127.0.0.1"
+DB_PORT = 5432
+DB_USER = "gaussdb"
+DB_PASSWORD = "Ysl#1234"
+DB_NAME = "postgres"
+
+
+def build_dsn() -> str:
+    return f"opengauss://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
+
+
+@unittest.skipIf(py_opengauss is None, "py_opengauss is required for integration tests")
+class OpenGaussHardenedPoolIntegrationTest(unittest.TestCase):
+    seed_uid: int
+    seed_process_id: str
+    config: ConnectionPoolConfig
+    pool: OpenGaussConnectionPool
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        base = int(time.time()) % 1_500_000_000
+        cls.seed_uid = base + 200
+        cls.seed_process_id = f"hardened_seed_{base}"
+        conn = py_opengauss.open(build_dsn())
+        try:
+            delete_stmt = conn.prepare("DELETE FROM my_schema.table1 WHERE u_uid = $1")
+            delete_stmt(cls.seed_uid)
+            insert_stmt = conn.prepare(
+                """
+                INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
+                VALUES ($1, $2, NOW(), $3)
+                """
+            )
+            insert_stmt(cls.seed_uid, cls.seed_process_id, "seed")
+        finally:
+            conn.close()
+
+        cls.config = ConnectionPoolConfig(
+            dsn=build_dsn(),
+            min_size=2,
+            max_size=6,
+            blocking=True,
+            acquire_timeout=0.5,
+            idle_timeout=10,
+            max_lifetime=30,
+            max_usage=20,
+            test_on_borrow=True,
+            test_sql="SELECT 1",
+            keepalive=True,
+            keepalive_interval=0.5,
+            health_check_interval=0.5,
+            reset_on_return=True,
+        )
+        cls.pool = OpenGaussConnectionPool(cls.config)
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        if hasattr(cls, "pool") and cls.pool:
+            cls.pool.close()
+        conn = py_opengauss.open(build_dsn())
+        try:
+            delete_stmt = conn.prepare(
+                "DELETE FROM my_schema.table1 WHERE process_id = $1"
+            )
+            delete_stmt(cls.seed_process_id)
+        finally:
+            conn.close()
+
+    def _new_pool(self, **overrides) -> OpenGaussConnectionPool:
+        config_dict = self.config.__dict__.copy()
+        config_dict.update(overrides)
+        return OpenGaussConnectionPool(ConnectionPoolConfig(**config_dict))
+
+    def test_pool_limit_and_timeout(self):
+        print("STEP: test_pool_limit_and_timeout - exhaust pool")
+        borrowed = [self.pool.borrow() for _ in range(self.config.max_size)]
+
+        print("STEP: test_pool_limit_and_timeout - timed borrow raises")
+        with self.assertRaises(PoolExhaustedError):
+            self.pool.borrow(timeout=0.3)
+
+        print("STEP: test_pool_limit_and_timeout - return connections and assert stats")
+        for conn in borrowed:
+            conn.close()
+        stats = self.pool.stats()
+
+        self.assertLessEqual(stats["total"], self.config.max_size)
+        self.assertEqual(stats["in_use"], 0)
+        self.assertGreaterEqual(stats["available"], self.config.min_size)
+
+    def test_serial_reads_with_lifecycle(self):
+        print("STEP: test_serial_reads_with_lifecycle - serial reads with pauses")
+        pool = self._new_pool(min_size=1, max_size=3, max_lifetime=1.0, idle_timeout=5.0)
+        stats = {}
+        seen_ids = set()
+        stats_before = pool.stats()
+        try:
+            for idx in range(4):
+                with pool.connection(timeout=1.0) as conn:
+                    seen_ids.add(id(conn._entry.conn))
+                    select_stmt = conn.prepare(
+                        "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
+                    )
+                    rows = select_stmt(self.seed_uid)
+                    self.assertEqual(rows[0][0], "seed")
+                    if idx == 1:
+                        time.sleep(1.2)
+                    else:
+                        time.sleep(0.2)
+            stats = pool.stats()
+        finally:
+            pool.close()
+
+        print("STEP: test_serial_reads_with_lifecycle - assert pool stats")
+        self.assertLessEqual(stats["total"], 3)
+        self.assertFalse(stats["closed"])
+        self.assertTrue(
+            len(seen_ids) >= 2 or stats["discarded"] > stats_before["discarded"]
+        )
+
+    def test_concurrent_reads(self):
+        print("STEP: test_concurrent_reads - start threads")
+        results: List[str] = []
+        errors: List[Exception] = []
+        lock = threading.Lock()
+
+        def worker():
+            try:
+                with self.pool.connection(timeout=2.0) as conn:
+                    select_stmt = conn.prepare(
+                        "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
+                    )
+                    rows = select_stmt(self.seed_uid)
+                with lock:
+                    results.append(rows[0][0])
+            except Exception as exc:  # noqa: BLE001 - capture any worker errors
+                with lock:
+                    errors.append(exc)
+
+        threads = [threading.Thread(target=worker) for _ in range(20)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        print("STEP: test_concurrent_reads - assert results and pool stats")
+        stats = self.pool.stats()
+
+        self.assertEqual(len(errors), 0)
+        self.assertEqual(len(results), 20)
+        self.assertFalse(stats["closed"])
+        self.assertLessEqual(stats["total"], self.config.max_size)
+
+    def test_concurrent_batch_writes(self):
+        print("STEP: test_concurrent_batch_writes - start writer threads")
+        base = self.seed_uid + 20000
+        process_id = f"hardened_batch_{base}"
+        total_batches = 5
+        batch_size = 20
+        worker_count = 8
+        errors: List[Exception] = []
+        lock = threading.Lock()
+
+        def writer(worker_idx: int):
+            try:
+                for batch_idx in range(total_batches):
+                    with self.pool.connection(timeout=2.0) as conn:
+                        insert_stmt = conn.prepare(
+                            """
+                            INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
+                            VALUES ($1, $2, NOW(), $3)
+                            """
+                        )
+                        for i in range(batch_size):
+                            uid = base + (worker_idx * 1000) + (batch_idx * batch_size) + i
+                            insert_stmt(uid, process_id, f"batch_{batch_idx}")
+            except Exception as exc:  # noqa: BLE001 - capture any worker errors
+                with lock:
+                    errors.append(exc)
+
+        threads = [threading.Thread(target=writer, args=(i,)) for i in range(worker_count)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        print("STEP: test_concurrent_batch_writes - validate count and cleanup")
+        with self.pool.connection(timeout=2.0) as conn:
+            count_stmt = conn.prepare(
+                "SELECT COUNT(*) FROM my_schema.table1 WHERE process_id = $1"
+            )
+            rows = count_stmt(process_id)
+            total_inserted = rows[0][0]
+            delete_stmt = conn.prepare(
+                "DELETE FROM my_schema.table1 WHERE process_id = $1"
+            )
+            delete_stmt(process_id)
+
+        stats = self.pool.stats()
+
+        self.assertEqual(len(errors), 0)
+        self.assertEqual(total_inserted, worker_count * total_batches * batch_size)
+        self.assertFalse(stats["closed"])
+
+    def test_connection_lifecycle_management(self):
+        print("STEP: test_connection_lifecycle_management - hold and expire connection")
+        pool = self._new_pool(min_size=1, max_size=2, max_lifetime=1.0, idle_timeout=5.0)
+        stats = {}
+        try:
+            conn = pool.borrow()
+            conn_id = id(conn._entry.conn)
+            time.sleep(1.2)
+            conn.close()
+
+            print("STEP: test_connection_lifecycle_management - borrow replacement")
+            new_conn = pool.borrow(timeout=1.0)
+            new_id = id(new_conn._entry.conn)
+            new_conn.close()
+            stats = pool.stats()
+        finally:
+            pool.close()
+
+        self.assertNotEqual(conn_id, new_id)
+        self.assertLessEqual(stats["total"], 2)
+
+    def test_connection_validation_replaces_bad_connection(self):
+        print("STEP: test_connection_validation_replaces_bad_connection - return closed conn")
+        conn = self.pool.borrow()
+        bad_id = id(conn._entry.conn)
+        conn._entry.conn.close()
+        self.pool.return_connection(conn)
+
+        print("STEP: test_connection_validation_replaces_bad_connection - borrow and validate")
+        with self.pool.connection(timeout=1.0) as conn2:
+            new_id = id(conn2._entry.conn)
+            rows = conn2.prepare("SELECT 1")()
+        stats = self.pool.stats()
+
+        self.assertEqual(rows[0][0], 1)
+        self.assertNotEqual(bad_id, new_id)
+        self.assertLessEqual(stats["total"], self.config.max_size)
+
+    def test_long_running_requests(self):
+        print("STEP: test_long_running_requests - run mixed workload")
+        errors: List[Exception] = []
+        lock = threading.Lock()
+        stop_at = time.monotonic() + 3.0
+
+        def load_worker():
+            while time.monotonic() < stop_at:
+                try:
+                    with self.pool.connection(timeout=2.0) as conn:
+                        conn.prepare("SELECT 1")()
+                        time.sleep(0.2)
+                        select_stmt = conn.prepare(
+                            "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
+                        )
+                        rows = select_stmt(self.seed_uid)
+                        if rows[0][0] != "seed":
+                            raise AssertionError("Unexpected content")
+                except Exception as exc:  # noqa: BLE001 - capture any worker errors
+                    with lock:
+                        errors.append(exc)
+                    return
+
+        threads = [threading.Thread(target=load_worker) for _ in range(4)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        print("STEP: test_long_running_requests - assert pool stats")
+        stats = self.pool.stats()
+
+        self.assertEqual(len(errors), 0)
+        self.assertFalse(stats["closed"])
+        self.assertLessEqual(stats["total"], self.config.max_size)
+
+
+if __name__ == "__main__":
+    unittest.main()