import os
import threading
import time
import unittest
from typing import List
from urllib.parse import quote

try:
    import py_opengauss  # type: ignore  # noqa: F401
except ImportError:  # pragma: no cover - handled via skip
    py_opengauss = None

from gdb_utils.opengauss_pool_hardened import (
    ConnectionPoolConfig,
    OpenGaussConnectionPool,
    PoolExhaustedError,
)


# Connection settings for the integration database. Each value can be
# overridden through the environment (e.g. in CI) without editing this file;
# the literals below are the local-development defaults.
DB_HOST = os.environ.get("OPENGAUSS_TEST_HOST", "127.0.0.1")
DB_PORT = int(os.environ.get("OPENGAUSS_TEST_PORT", "5432"))
DB_USER = os.environ.get("OPENGAUSS_TEST_USER", "gaussdb")
DB_PASSWORD = os.environ.get("OPENGAUSS_TEST_PASSWORD", "Ysl#1234")
DB_NAME = os.environ.get("OPENGAUSS_TEST_DB", "postgres")


def build_dsn() -> str:
    """Return the ``opengauss://`` DSN for the integration database.

    The user name, password and database name are percent-encoded. The
    default password contains ``#``, which URI syntax treats as the start of
    the fragment; left raw, it cuts the authority short and the host, port
    and database are lost from the DSN.

    NOTE(review): assumes py_opengauss percent-decodes these components (its
    py-postgresql ancestry's IRI parser does) — confirm against the
    installed driver.
    """
    user = quote(DB_USER, safe="")
    password = quote(DB_PASSWORD, safe="")
    database = quote(DB_NAME, safe="")
    return f"opengauss://{user}:{password}@{DB_HOST}:{DB_PORT}/{database}"


@unittest.skipIf(py_opengauss is None, "py_opengauss is required for integration tests")
class OpenGaussHardenedPoolIntegrationTest(unittest.TestCase):
    """Integration tests for ``OpenGaussConnectionPool`` against a live database.

    ``setUpClass`` inserts one seed row (``u_uid == seed_uid``, content
    ``"seed"``) into ``my_schema.table1`` and builds a shared pool from
    ``config``. ``tearDownClass`` closes that pool and deletes the seed row.
    Tests that need different limits build a private pool via ``_new_pool``
    and close it themselves.
    """

    seed_uid: int
    seed_process_id: str
    config: ConnectionPoolConfig
    pool: OpenGaussConnectionPool

    # Upper bound, in seconds, on waiting for worker threads. A deadlocked
    # pool then fails the test instead of hanging the whole run.
    THREAD_JOIN_TIMEOUT = 60.0

    @staticmethod
    def _delete_rows_by_process_id(process_id: str) -> None:
        """Delete all ``my_schema.table1`` rows tagged with *process_id*.

        Uses a dedicated connection rather than the pool under test, so that
        cleanup still works when the pool itself is broken or exhausted.
        """
        conn = py_opengauss.open(build_dsn())
        try:
            delete_stmt = conn.prepare(
                "DELETE FROM my_schema.table1 WHERE process_id = $1"
            )
            delete_stmt(process_id)
        finally:
            conn.close()

    @classmethod
    def setUpClass(cls) -> None:
        """Insert the shared seed row and create the shared pool."""
        base = int(time.time()) % 1_500_000_000
        cls.seed_uid = base + 200
        cls.seed_process_id = f"hardened_seed_{base}"
        conn = py_opengauss.open(build_dsn())
        try:
            # Remove any leftover row with the same uid before inserting.
            delete_stmt = conn.prepare("DELETE FROM my_schema.table1 WHERE u_uid = $1")
            delete_stmt(cls.seed_uid)
            insert_stmt = conn.prepare(
                """
                INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
                VALUES ($1, $2, NOW(), $3)
                """
            )
            insert_stmt(cls.seed_uid, cls.seed_process_id, "seed")
        finally:
            conn.close()

        try:
            cls.config = ConnectionPoolConfig(
                dsn=build_dsn(),
                min_size=2,
                max_size=6,
                blocking=True,
                acquire_timeout=0.5,
                idle_timeout=10,
                max_lifetime=30,
                max_usage=20,
                test_on_borrow=True,
                test_sql="SELECT 1",
                keepalive=True,
                keepalive_interval=0.5,
                health_check_interval=0.5,
                reset_on_return=True,
            )
            cls.pool = OpenGaussConnectionPool(cls.config)
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises, so
            # remove the seed row here instead of leaking it into the table.
            cls._delete_rows_by_process_id(cls.seed_process_id)
            raise

    @classmethod
    def tearDownClass(cls) -> None:
        """Close the shared pool, then delete the seed row in every case."""
        pool = getattr(cls, "pool", None)
        try:
            # Compare with None instead of relying on truthiness, which a
            # pool object may define via __len__/__bool__.
            if pool is not None:
                pool.close()
        finally:
            cls._delete_rows_by_process_id(cls.seed_process_id)

    def _new_pool(self, **overrides) -> OpenGaussConnectionPool:
        """Build a private pool from the shared ``config`` with *overrides* applied.

        The caller owns the returned pool and must ``close()`` it.
        """
        # NOTE(review): assumes every attribute in config.__dict__ is an
        # accepted ConnectionPoolConfig constructor argument — confirm.
        config_dict = self.config.__dict__.copy()
        config_dict.update(overrides)
        return OpenGaussConnectionPool(ConnectionPoolConfig(**config_dict))

    def _run_threads(self, threads: List[threading.Thread]) -> None:
        """Start *threads* and wait for them, failing if any is still running.

        The threads are made daemonic, so a stuck worker cannot keep the
        interpreter alive after the test run. All threads share one
        ``THREAD_JOIN_TIMEOUT`` budget.
        """
        for thread in threads:
            thread.daemon = True
            thread.start()
        deadline = time.monotonic() + self.THREAD_JOIN_TIMEOUT
        for thread in threads:
            thread.join(max(0.0, deadline - time.monotonic()))
        stuck = [thread.name for thread in threads if thread.is_alive()]
        self.assertEqual(stuck, [], "worker threads did not finish in time")

    def test_pool_limit_and_timeout(self):
        print("STEP: test_pool_limit_and_timeout - exhaust pool")
        borrowed = []
        try:
            for _ in range(self.config.max_size):
                borrowed.append(self.pool.borrow())

            print("STEP: test_pool_limit_and_timeout - timed borrow raises")
            with self.assertRaises(PoolExhaustedError):
                # If the borrow unexpectedly succeeds, the connection is still
                # tracked, so the finally block returns it to the pool.
                borrowed.append(self.pool.borrow(timeout=0.3))
        finally:
            print("STEP: test_pool_limit_and_timeout - return connections and assert stats")
            # Always hand connections back. The pool is shared by every test,
            # so leaking them here would exhaust it for the rest of the class.
            for conn in borrowed:
                conn.close()
        stats = self.pool.stats()

        self.assertLessEqual(stats["total"], self.config.max_size)
        self.assertEqual(stats["in_use"], 0)
        self.assertGreaterEqual(stats["available"], self.config.min_size)

    def test_serial_reads_with_lifecycle(self):
        print("STEP: test_serial_reads_with_lifecycle - serial reads with pauses")
        pool = self._new_pool(min_size=1, max_size=3, max_lifetime=1.0, idle_timeout=5.0)
        stats = {}
        # Hold strong references to every raw connection seen. Comparing bare
        # id() values is unreliable: a discarded, garbage-collected connection's
        # address may be reused by its replacement.
        seen_conns = []
        try:
            stats_before = pool.stats()
            for idx in range(4):
                with pool.connection(timeout=1.0) as conn:
                    seen_conns.append(conn._entry.conn)
                    select_stmt = conn.prepare(
                        "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                    )
                    rows = select_stmt(self.seed_uid)
                    self.assertEqual(rows[0][0], "seed")
                    # The second iteration outlives max_lifetime (1.0s).
                    if idx == 1:
                        time.sleep(1.2)
                    else:
                        time.sleep(0.2)
            stats = pool.stats()
        finally:
            pool.close()

        print("STEP: test_serial_reads_with_lifecycle - assert pool stats")
        distinct_conns = len({id(raw) for raw in seen_conns})
        self.assertLessEqual(stats["total"], 3)
        self.assertFalse(stats["closed"])
        self.assertTrue(
            distinct_conns >= 2 or stats["discarded"] > stats_before["discarded"]
        )

    def test_concurrent_reads(self):
        print("STEP: test_concurrent_reads - start threads")
        results: List[str] = []
        errors: List[Exception] = []
        lock = threading.Lock()

        def worker():
            try:
                with self.pool.connection(timeout=2.0) as conn:
                    select_stmt = conn.prepare(
                        "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                    )
                    rows = select_stmt(self.seed_uid)
                with lock:
                    results.append(rows[0][0])
            except Exception as exc:  # noqa: BLE001 - capture any worker errors
                with lock:
                    errors.append(exc)

        self._run_threads([threading.Thread(target=worker) for _ in range(20)])

        print("STEP: test_concurrent_reads - assert results and pool stats")
        stats = self.pool.stats()

        # Comparing against [] makes a failure report the captured exceptions.
        self.assertEqual(errors, [])
        self.assertEqual(len(results), 20)
        self.assertFalse(stats["closed"])
        self.assertLessEqual(stats["total"], self.config.max_size)

    def test_concurrent_batch_writes(self):
        print("STEP: test_concurrent_batch_writes - start writer threads")
        base = self.seed_uid + 20000
        process_id = f"hardened_batch_{base}"
        total_batches = 5
        batch_size = 20
        worker_count = 8
        errors: List[Exception] = []
        lock = threading.Lock()

        def writer(worker_idx: int):
            try:
                for batch_idx in range(total_batches):
                    with self.pool.connection(timeout=2.0) as conn:
                        insert_stmt = conn.prepare(
                            """
                            INSERT INTO my_schema.table1 (u_uid, process_id, ins_tm, content)
                            VALUES ($1, $2, NOW(), $3)
                            """
                        )
                        for i in range(batch_size):
                            # Each worker owns a disjoint 1000-wide uid range.
                            uid = base + (worker_idx * 1000) + (batch_idx * batch_size) + i
                            insert_stmt(uid, process_id, f"batch_{batch_idx}")
            except Exception as exc:  # noqa: BLE001 - capture any worker errors
                with lock:
                    errors.append(exc)

        threads = [threading.Thread(target=writer, args=(i,)) for i in range(worker_count)]
        try:
            self._run_threads(threads)

            print("STEP: test_concurrent_batch_writes - validate count and cleanup")
            with self.pool.connection(timeout=2.0) as conn:
                count_stmt = conn.prepare(
                    "SELECT COUNT(*) FROM my_schema.table1 WHERE process_id = $1"
                )
                rows = count_stmt(process_id)
                total_inserted = rows[0][0]
        finally:
            # Remove this test's rows even if a writer hung or the count failed.
            self._delete_rows_by_process_id(process_id)

        stats = self.pool.stats()

        self.assertEqual(errors, [])
        self.assertEqual(total_inserted, worker_count * total_batches * batch_size)
        self.assertFalse(stats["closed"])

    def test_connection_lifecycle_management(self):
        print("STEP: test_connection_lifecycle_management - hold and expire connection")
        pool = self._new_pool(min_size=1, max_size=2, max_lifetime=1.0, idle_timeout=5.0)
        stats = {}
        try:
            conn = pool.borrow()
            # Keep the raw object alive so its identity cannot be reused.
            old_raw = conn._entry.conn
            # Hold the connection past max_lifetime (1.0s) before returning it.
            time.sleep(1.2)
            conn.close()

            print("STEP: test_connection_lifecycle_management - borrow replacement")
            new_conn = pool.borrow(timeout=1.0)
            new_raw = new_conn._entry.conn
            new_conn.close()
            stats = pool.stats()
        finally:
            pool.close()

        self.assertIsNot(old_raw, new_raw)
        self.assertLessEqual(stats["total"], 2)

    def test_connection_validation_replaces_bad_connection(self):
        print("STEP: test_connection_validation_replaces_bad_connection - return closed conn")
        conn = self.pool.borrow()
        # Keep the raw object alive so its identity cannot be reused.
        bad_raw = conn._entry.conn
        # Break the underlying connection before handing it back.
        bad_raw.close()
        self.pool.return_connection(conn)

        print("STEP: test_connection_validation_replaces_bad_connection - borrow and validate")
        with self.pool.connection(timeout=1.0) as conn2:
            new_raw = conn2._entry.conn
            rows = conn2.prepare("SELECT 1")()
        stats = self.pool.stats()

        self.assertEqual(rows[0][0], 1)
        self.assertIsNot(bad_raw, new_raw)
        self.assertLessEqual(stats["total"], self.config.max_size)

    def test_long_running_requests(self):
        print("STEP: test_long_running_requests - run mixed workload")
        errors: List[Exception] = []
        lock = threading.Lock()
        stop_at = time.monotonic() + 3.0

        def load_worker():
            # Loop until the shared deadline. Stop at the first error.
            while time.monotonic() < stop_at:
                try:
                    with self.pool.connection(timeout=2.0) as conn:
                        conn.prepare("SELECT 1")()
                        time.sleep(0.2)
                        select_stmt = conn.prepare(
                            "SELECT content FROM my_schema.table1 WHERE u_uid = $1"
                        )
                        rows = select_stmt(self.seed_uid)
                        if rows[0][0] != "seed":
                            raise AssertionError("Unexpected content")
                except Exception as exc:  # noqa: BLE001 - capture any worker errors
                    with lock:
                        errors.append(exc)
                    return

        self._run_threads([threading.Thread(target=load_worker) for _ in range(4)])

        print("STEP: test_long_running_requests - assert pool stats")
        stats = self.pool.stats()

        self.assertEqual(errors, [])
        self.assertFalse(stats["closed"])
        self.assertLessEqual(stats["total"], self.config.max_size)


if __name__ == "__main__":
    # Run this module's test cases when the file is executed directly.
    unittest.main(module="__main__")
