import os
import threading
import time
import unittest
from typing import List, Optional
from urllib.parse import quote

try:
    import py_opengauss  # type: ignore  # noqa: F401
except ImportError:  # pragma: no cover - handled via skip
    py_opengauss = None

from gdb_utils import ConnectionPoolConfig, OpenGaussConnectionPool


# Connection settings for the integration database. Each value can be
# overridden through an OPENGAUSS_* environment variable so the suite can run
# against another server (or with real credentials kept out of the source)
# without editing this file; the defaults are the original hard-coded values.
DB_HOST = os.environ.get("OPENGAUSS_HOST", "127.0.0.1")
DB_PORT = int(os.environ.get("OPENGAUSS_PORT", "5432"))
DB_USER = os.environ.get("OPENGAUSS_USER", "gaussdb")
DB_PASSWORD = os.environ.get("OPENGAUSS_PASSWORD", "Ysl#1234")
DB_NAME = os.environ.get("OPENGAUSS_DATABASE", "postgres")


def build_dsn(
    *,
    host: Optional[str] = None,
    port: Optional[int] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    dbname: Optional[str] = None,
) -> str:
    """Return an ``opengauss://`` connection IRI for the integration database.

    Every argument is optional and falls back to the matching module-level
    ``DB_*`` constant, so ``build_dsn()`` describes the configured test server.

    The user name, password and database name are percent-encoded. URI syntax
    (RFC 3986) reserves characters such as ``#``, ``@``, ``/`` and ``?`` as
    delimiters, so a raw password like ``Ysl#1234`` would be cut off at ``#``
    and the remainder of the DSN read as a fragment.
    """
    host = DB_HOST if host is None else host
    port = DB_PORT if port is None else port
    user = DB_USER if user is None else user
    password = DB_PASSWORD if password is None else password
    dbname = DB_NAME if dbname is None else dbname

    # safe="" also escapes "/" so no component can alter the IRI structure.
    user_part = quote(user, safe="")
    password_part = quote(password, safe="")
    dbname_part = quote(dbname, safe="")
    return f"opengauss://{user_part}:{password_part}@{host}:{port}/{dbname_part}"


@unittest.skipIf(py_opengauss is None, "py_opengauss is required for integration tests")
class OpenGaussPoolIntegrationTest(unittest.TestCase):
    """Integration tests for ``OpenGaussConnectionPool`` against a live server.

    ``setUpClass`` seeds one row in ``my_schema.table1`` tagged with a per-run
    ``process_id`` and builds a single pool shared by every test;
    ``tearDownClass`` closes the pool and deletes the rows carrying that tag.
    Because the pool is shared, each test must hand back every connection it
    borrows — a leaked connection starves the tests that run after it.
    """

    seed_uid: int  # u_uid of the row seeded into my_schema.table1
    seed_process_id: str  # process_id tag used to find and delete seeded rows
    config: ConnectionPoolConfig  # configuration the shared pool was built with
    pool: OpenGaussConnectionPool  # pool shared by all tests in this class

    @classmethod
    def _delete_rows_by_process_id(cls, process_id: str) -> None:
        """Delete every ``my_schema.table1`` row whose process_id is *process_id*.

        Uses a dedicated connection instead of the pool so it also works when
        the pool was never created or has already been closed.
        """
        conn = py_opengauss.open(build_dsn())
        try:
            delete_stmt = conn.prepare(
                "DELETE FROM my_schema.table1 WHERE process_id = $1"
            )
            delete_stmt(process_id)
        finally:
            conn.close()

    @classmethod
    def setUpClass(cls) -> None:
        base = int(time.time()) % 1_500_000_000
        cls.seed_uid = base + 100
        cls.seed_process_id = f"pool_seed_{base}"
        conn = py_opengauss.open(build_dsn())
        try:
            # Remove any leftover row with the same key before seeding.
            delete_stmt = conn.prepare("DELETE FROM my_schema.table1 WHERE u_uid = $1")
            delete_stmt(cls.seed_uid)
            insert_stmt = conn.prepare(
                """
                INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
                VALUES ($1, $2, NOW(), $3)
                """
            )
            insert_stmt(cls.seed_uid, cls.seed_process_id, "seed")
        finally:
            conn.close()

        try:
            cls.config = ConnectionPoolConfig(
                dsn=build_dsn(),
                min_size=2,
                max_size=8,
                idle_timeout=30,
                max_lifetime=60,
                test_on_borrow=True,
                test_sql="SELECT 1",
                keepalive=True,
                keepalive_interval=0.5,
                health_check_interval=0.5,
            )
            cls.pool = OpenGaussConnectionPool(cls.config)
        except Exception:
            # unittest does not run tearDownClass when setUpClass raises, so
            # remove the seeded row here or it would be left in the table.
            cls._delete_rows_by_process_id(cls.seed_process_id)
            raise

    @classmethod
    def tearDownClass(cls) -> None:
        try:
            if hasattr(cls, "pool") and cls.pool:
                cls.pool.close()
        finally:
            # Delete the seeded rows even if closing the pool failed.
            cls._delete_rows_by_process_id(cls.seed_process_id)

    def _wait_for_min_size(self, timeout: float = 2.0) -> None:
        """Poll ``pool.stats()`` until ``total`` reaches ``config.min_size``.

        Returns silently once *timeout* seconds elapse; callers assert on the
        stats afterwards, so a pool that never refills still fails the test.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            stats = self.pool.stats()
            if stats["total"] >= self.config.min_size:
                return
            time.sleep(0.05)

    def test_read_and_write(self):
        print("STEP: test_read_and_write - setup table and insert row")
        # u_uid and name are VARCHAR(10); keep identifiers short.
        unique_uid = f"u{int(time.time()) % 100000:05d}"
        # Ensure schema/table exist and insert a row.
        with self.pool.connection() as conn:
            conn.execute("CREATE SCHEMA IF NOT EXISTS my_schema")
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS my_schema.test_table (
                    u_uid VARCHAR(10),
                    name VARCHAR(10),
                    CONSTRAINT pk_uid PRIMARY KEY(u_uid)
                )
                """
            )
            # Ensure no leftover row with the same key before insert.
            delete_before_insert = conn.prepare("DELETE FROM my_schema.test_table WHERE u_uid = $1")
            delete_before_insert(unique_uid)
            insert_stmt = conn.prepare(
                "INSERT INTO my_schema.test_table (u_uid, name) VALUES ($1, $2)"
            )
            insert_stmt(unique_uid, "codex")

        try:
            print("STEP: test_read_and_write - read row and assert")
            # Read back the row.
            with self.pool.connection() as conn:
                select_stmt = conn.prepare("SELECT name FROM my_schema.test_table WHERE u_uid = $1")
                rows = select_stmt(unique_uid)

            self.assertIsNotNone(rows)
            self.assertGreaterEqual(len(rows), 1)
            self.assertEqual(rows[0][0], "codex")
        finally:
            print("STEP: test_read_and_write - cleanup inserted row")
            # Clean up the inserted row even when an assertion above failed.
            with self.pool.connection() as conn:
                delete_stmt = conn.prepare("DELETE FROM my_schema.test_table WHERE u_uid = $1")
                delete_stmt(unique_uid)

    def test_broken_connection_is_replaced(self):
        print("STEP: test_broken_connection_is_replaced - close and return broken connection")
        conn = self.pool.borrow()
        # Simulate an unusable connection by closing it and marking it as broken.
        conn.close()
        self.pool.return_connection(conn, had_error=True)

        print("STEP: test_broken_connection_is_replaced - borrow replacement and validate")
        # Borrow again; the pool should hand out a usable connection.
        with self.pool.connection() as conn2:
            stmt = conn2.prepare("SELECT 1")
            rows = stmt()
        self.assertEqual(rows[0][0], 1)

        self._wait_for_min_size()
        stats_after = self.pool.stats()
        self.assertFalse(stats_after["closed"])
        self.assertLessEqual(stats_after["total"], self.config.max_size)
        self.assertGreaterEqual(stats_after["total"], self.config.min_size)

    def test_pool_limit_and_timeout(self):
        borrowed = []  # connections checked out by this test; always returned
        errors: List[Exception] = []

        def try_borrow():
            try:
                extra = self.pool.borrow(timeout=0.5)
            except Exception as exc:  # noqa: BLE001 - capture TimeoutError
                errors.append(exc)
            else:
                # Unexpected success: hand the connection back so it does not
                # stay checked out for the rest of the suite.
                self.pool.return_connection(extra)

        try:
            print("STEP: test_pool_limit_and_timeout - exhaust pool")
            # Append one at a time so that a failure part-way still leaves
            # every already-borrowed connection in ``borrowed`` for cleanup.
            for _ in range(self.config.max_size):
                borrowed.append(self.pool.borrow())

            print("STEP: test_pool_limit_and_timeout - attempt timed borrow")
            # Daemon thread plus a bounded join: if borrow() ignored its
            # timeout, the test fails instead of hanging the whole run.
            t = threading.Thread(target=try_borrow, daemon=True)
            t.start()
            t.join(timeout=5.0)
            borrow_hung = t.is_alive()
        finally:
            print("STEP: test_pool_limit_and_timeout - return connections and assert stats")
            for conn in borrowed:
                self.pool.return_connection(conn)
        stats = self.pool.stats()

        self.assertFalse(borrow_hung, "borrow(timeout=0.5) did not return within 5s")
        self.assertLessEqual(stats["total"], self.config.max_size)
        self.assertEqual(stats["in_use"], 0)
        self.assertGreaterEqual(stats["available"], self.config.min_size)
        self.assertTrue(any(isinstance(err, TimeoutError) for err in errors))

    def test_serial_reads_with_lifecycle(self):
        print("STEP: test_serial_reads_with_lifecycle - serial reads with pauses")
        for _ in range(3):
            with self.pool.connection() as conn:
                select_stmt = conn.prepare(
                    "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                )
                rows = select_stmt(self.seed_uid)
                self.assertEqual(rows[0][0], "seed")
                # Hold the connection briefly so the pool's background
                # keepalive/health-check intervals (0.5s) elapse during use.
                time.sleep(0.5)

        print("STEP: test_serial_reads_with_lifecycle - assert pool stats")
        stats = self.pool.stats()

        self.assertLessEqual(stats["total"], self.config.max_size)
        self.assertFalse(stats["closed"])

    def test_concurrent_reads(self):
        print("STEP: test_concurrent_reads - start threads")
        worker_count = 20  # deliberately more threads than config.max_size
        results: List[str] = []
        errors: List[Exception] = []
        lock = threading.Lock()

        def worker():
            try:
                with self.pool.connection(timeout=2.0) as conn:
                    select_stmt = conn.prepare(
                        "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                    )
                    rows = select_stmt(self.seed_uid)
                with lock:
                    results.append(rows[0][0])
            except Exception as exc:  # noqa: BLE001 - capture any worker errors
                with lock:
                    errors.append(exc)

        threads = [threading.Thread(target=worker) for _ in range(worker_count)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        print("STEP: test_concurrent_reads - assert results and pool stats")
        stats = self.pool.stats()

        # Comparing the lists (not their lengths) shows the actual
        # exceptions/values in the failure message.
        self.assertEqual(errors, [])
        self.assertEqual(results, ["seed"] * worker_count)
        self.assertFalse(stats["closed"])
        self.assertLessEqual(stats["total"], self.config.max_size)

    def test_concurrent_batch_writes(self):
        print("STEP: test_concurrent_batch_writes - start writer threads")
        base = self.seed_uid + 10000
        process_id = f"pool_batch_{base}"
        total_batches = 5
        batch_size = 20
        worker_count = 8
        errors: List[Exception] = []
        lock = threading.Lock()

        def writer(worker_idx: int):
            try:
                for batch_idx in range(total_batches):
                    with self.pool.connection(timeout=2.0) as conn:
                        insert_stmt = conn.prepare(
                            """
                            INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
                            VALUES ($1, $2, NOW(), $3)
                            """
                        )
                        for i in range(batch_size):
                            # Each worker owns a 1000-wide uid range, so
                            # batch_idx * batch_size + i never collides.
                            uid = base + (worker_idx * 1000) + (batch_idx * batch_size) + i
                            insert_stmt(uid, process_id, f"batch_{batch_idx}")
            except Exception as exc:  # noqa: BLE001 - capture any worker errors
                with lock:
                    errors.append(exc)

        threads = [threading.Thread(target=writer, args=(i,)) for i in range(worker_count)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        print("STEP: test_concurrent_batch_writes - validate count and cleanup")
        with self.pool.connection(timeout=2.0) as conn:
            try:
                count_stmt = conn.prepare(
                    "SELECT COUNT(*) FROM my_schema.table1 WHERE process_id = $1"
                )
                rows = count_stmt(process_id)
                total_inserted = rows[0][0]
            finally:
                # Always remove this test's rows; tearDownClass only deletes
                # the seed rows, so anything left here would persist.
                delete_stmt = conn.prepare(
                    "DELETE FROM my_schema.table1 WHERE process_id = $1"
                )
                delete_stmt(process_id)

        stats = self.pool.stats()

        self.assertEqual(errors, [])
        self.assertEqual(total_inserted, worker_count * total_batches * batch_size)
        self.assertFalse(stats["closed"])

    def test_connection_lifecycle_management(self):
        print("STEP: test_connection_lifecycle_management - borrow and age connection")
        conn = self.pool.borrow()
        conn.prepare("SELECT 1")()
        # Hold the connection until it is older than max_lifetime (note: this
        # makes the test take roughly a minute with the current config).
        time.sleep(self.config.max_lifetime + 0.2)
        self.pool.return_connection(conn)

        print("STEP: test_connection_lifecycle_management - borrow new connection")
        new_conn = self.pool.borrow(timeout=1.0)
        self.pool.return_connection(new_conn)
        stats = self.pool.stats()

        # NOTE(review): with min_size=2 the pool may hand out a different idle
        # connection even if the aged one was not retired, so this identity
        # check is weak evidence of max_lifetime enforcement — confirm intent.
        self.assertIsNot(conn, new_conn)
        self.assertLessEqual(stats["total"], self.config.max_size)

    def test_connection_validation_replaces_bad_connection(self):
        print("STEP: test_connection_validation_replaces_bad_connection - return closed conn")
        conn = self.pool.borrow()
        conn.close()
        # Return the dead connection WITHOUT flagging an error, so only the
        # pool's own validation (test_on_borrow / test_sql) can catch it.
        self.pool.return_connection(conn, had_error=False)

        print("STEP: test_connection_validation_replaces_bad_connection - borrow and validate")
        with self.pool.connection(timeout=1.0) as conn2:
            rows = conn2.prepare("SELECT 1")()
        stats = self.pool.stats()

        self.assertEqual(rows[0][0], 1)
        self.assertLessEqual(stats["total"], self.config.max_size)

    def test_long_running_requests(self):
        print("STEP: test_long_running_requests - run mixed workload")
        errors: List[Exception] = []
        lock = threading.Lock()
        stop_at = time.monotonic() + 3.0

        def load_worker():
            while time.monotonic() < stop_at:
                try:
                    with self.pool.connection(timeout=2.0) as conn:
                        conn.prepare("SELECT 1")()
                        time.sleep(0.2)
                        select_stmt = conn.prepare(
                            "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                        )
                        rows = select_stmt(self.seed_uid)
                        if rows[0][0] != "seed":
                            raise AssertionError("Unexpected content")
                except Exception as exc:  # noqa: BLE001 - capture any worker errors
                    with lock:
                        errors.append(exc)
                    # Stop this worker after its first failure.
                    return

        threads = [threading.Thread(target=load_worker) for _ in range(4)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        print("STEP: test_long_running_requests - assert pool stats")
        stats = self.pool.stats()

        self.assertEqual(errors, [])
        self.assertFalse(stats["closed"])
        self.assertLessEqual(stats["total"], self.config.max_size)

if __name__ == "__main__":
    # Executed directly as a script: discover and run this module's TestCases.
    unittest.main(module="__main__")
