#!/usr/bin/env python3 """Batch test runner for NL2SQL generation and SQL execution.""" from __future__ import annotations import argparse import json import sys import uuid from dataclasses import dataclass, field from pathlib import Path from typing import Any import requests # ===== Interface constants: adjust these values before each run if needed. ===== BASE_URL = "replace-with-your-base-url" GENERATE_SQL_PATH = "/generate_sql" QUERY_SQL_URL = "http://replace-with-your-query-sql-api" USER_TOKEN = "replace-with-your-token" BBK = 100 DOMAIN = "cmb" DEPARTMENT = "ABC" RCM_NUM = 5 YST_ID = 311139 USER_NAME = "杨林/311139" DEFAULT_TIMEOUT_SECONDS = 60 DEFAULT_INPUT_JSON = "sql_generate_testcase.json" DEFAULT_OUTPUT_DIR = "output" @dataclass class SQLCandidate: card_id: str sql: str source_event: str = "" @dataclass class QueryExecutionResult: status_code: int | None columns: list[str] = field(default_factory=list) data: list[dict[str, Any]] = field(default_factory=list) row_count: int | None = None total_count: int | None = None truncated: bool | None = None error: str = "" raw_payload: Any = None def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Batch test NL2SQL generate/query flow.") parser.add_argument( "--input-json", default=DEFAULT_INPUT_JSON, help="JSON file containing question and gold answer cases.", ) parser.add_argument( "--output-dir", default=DEFAULT_OUTPUT_DIR, help="Directory for summary and per-question generated answers.", ) parser.add_argument( "--timeout", type=int, default=DEFAULT_TIMEOUT_SECONDS, help="HTTP timeout in seconds for both APIs.", ) return parser.parse_args() def build_generate_sql_url() -> str: if BASE_URL == "replace-with-your-base-url": raise ValueError("BASE_URL is still placeholder. Please set the actual service host.") return f"http://{BASE_URL.rstrip('/')}{GENERATE_SQL_PATH}" def load_cases(path: Path) -> list[dict[str, Any]]: with path.open("r", encoding="utf-8") as handle: payload = json.load(handle) if isinstance(payload, dict): payload = payload.get("cases") or payload.get("data") or [payload] if not isinstance(payload, list): raise ValueError(f"Unsupported JSON structure in {path}") cases: list[dict[str, Any]] = [] for index, item in enumerate(payload, start=1): if not isinstance(item, dict): raise ValueError(f"Case #{index} is not an object") question = item.get("question") answer = item.get("answer") if not isinstance(question, str) or not question.strip(): raise ValueError(f"Case #{index} missing valid 'question'") cases.append( { "case_id": item.get("case_id") or f"case_{index}", "question": question.strip(), "answer": answer, } ) if not cases: raise ValueError(f"No valid test cases found in {path}") return cases def make_request_id() -> str: return uuid.uuid4().hex def build_generate_payload(question: str) -> dict[str, Any]: return { "request_id": make_request_id(), "user_question": question, "token": USER_TOKEN, "bbk": BBK, "domain": DOMAIN, "department": DEPARTMENT, "rcm_num": RCM_NUM, "yst_id": YST_ID, "user_name": USER_NAME, } def build_query_payload(sql: str) -> dict[str, Any]: return { "request_id": make_request_id(), "sql": sql, } def parse_sse_event_block(block: str) -> tuple[str, str]: event_type = "" data_lines: list[str] = [] for line in block.splitlines(): if line.startswith("event:"): event_type = line.split(":", 1)[1].strip() elif line.startswith("data:"): data_lines.append(line.split(":", 1)[1].lstrip()) return event_type, "\n".join(data_lines) def parse_json_maybe(raw: str) -> Any: try: return json.loads(raw) except json.JSONDecodeError: return raw def normalize_sql(sql: str) -> str: return " ".join(sql.split()).strip() def get_first_non_empty_string(value: Any) -> str: if isinstance(value, str): return value.strip() return "" def extract_candidate_from_mapping(mapping: dict[str, Any], source_event: str) -> SQLCandidate | None: card_id_keys = ("card_id", "cardId", "id") sql_keys = ("sql", "SQL", "final_sql", "finalSql", "content") card_id = "" sql = "" for key in card_id_keys: card_id = get_first_non_empty_string(mapping.get(key)) if card_id: break for key in sql_keys: sql = get_first_non_empty_string(mapping.get(key)) if sql: break if card_id and sql: return SQLCandidate(card_id=card_id, sql=normalize_sql(sql), source_event=source_event) return None def extract_candidates_from_payload(payload: Any, source_event: str) -> list[SQLCandidate]: candidates: list[SQLCandidate] = [] if isinstance(payload, dict): direct = extract_candidate_from_mapping(payload, source_event) if direct is not None: candidates.append(direct) for value in payload.values(): candidates.extend(extract_candidates_from_payload(value, source_event)) return dedupe_candidates(candidates) if isinstance(payload, list): for item in payload: candidates.extend(extract_candidates_from_payload(item, source_event)) return dedupe_candidates(candidates) return candidates def dedupe_candidates(candidates: list[SQLCandidate]) -> list[SQLCandidate]: deduped: list[SQLCandidate] = [] seen: set[tuple[str, str]] = set() for candidate in candidates: key = (candidate.card_id, candidate.sql) if not candidate.card_id or not candidate.sql or key in seen: continue seen.add(key) deduped.append(candidate) return deduped def parse_generate_stream(response: requests.Response) -> tuple[list[SQLCandidate], list[dict[str, Any]]]: raw_events: list[dict[str, Any]] = [] candidates: list[SQLCandidate] = [] buffer: list[str] = [] for raw_line in response.iter_lines(decode_unicode=True): if raw_line is None: continue line = raw_line if line == "": if not buffer: continue block = "\n".join(buffer) buffer.clear() event_type, data_raw = parse_sse_event_block(block) payload = parse_json_maybe(data_raw) raw_events.append({"event": event_type, "data": payload}) if event_type in {"dict", "message"}: candidates.extend(extract_candidates_from_payload(payload, event_type)) elif event_type == "error": error_msg = payload if isinstance(payload, str) else json.dumps(payload, ensure_ascii=False) raise RuntimeError(f"generate_sql stream error: {error_msg}") elif event_type == "end": break continue buffer.append(line) if buffer: block = "\n".join(buffer) event_type, data_raw = parse_sse_event_block(block) payload = parse_json_maybe(data_raw) raw_events.append({"event": event_type, "data": payload}) if event_type in {"dict", "message"}: candidates.extend(extract_candidates_from_payload(payload, event_type)) return dedupe_candidates(candidates)[:RCM_NUM], raw_events def request_generate_sql(question: str, timeout_seconds: int) -> tuple[list[SQLCandidate], list[dict[str, Any]]]: url = build_generate_sql_url() payload = build_generate_payload(question) response = requests.post( url, json=payload, stream=True, timeout=timeout_seconds, headers={"Accept": "text/event-stream"}, ) response.raise_for_status() return parse_generate_stream(response) def request_query_sql(sql: str, timeout_seconds: int) -> QueryExecutionResult: if QUERY_SQL_URL == "http://replace-with-your-query-sql-api": raise ValueError("QUERY_SQL_URL is still placeholder. Please set the actual query API URL.") response = requests.post( QUERY_SQL_URL, json=build_query_payload(sql), timeout=timeout_seconds, ) response.raise_for_status() payload = response.json() if not isinstance(payload, dict): raise ValueError(f"Query API returned non-object payload: {payload!r}") return QueryExecutionResult( status_code=payload.get("status_code"), columns=ensure_string_list(payload.get("columns")), data=ensure_row_list(payload.get("data")), row_count=payload.get("row_count"), total_count=payload.get("total_count"), truncated=payload.get("truncated"), error=str(payload.get("error") or ""), raw_payload=payload, ) def ensure_string_list(value: Any) -> list[str]: if not isinstance(value, list): return [] return [str(item) for item in value] def ensure_row_list(value: Any) -> list[dict[str, Any]]: if not isinstance(value, list): return [] return [item for item in value if isinstance(item, dict)] def canonicalize(value: Any) -> Any: if isinstance(value, dict): return {key: canonicalize(value[key]) for key in sorted(value)} if isinstance(value, list): normalized_items = [canonicalize(item) for item in value] return sorted(normalized_items, key=lambda item: json.dumps(item, ensure_ascii=False, sort_keys=True)) return value def answers_match(actual_data: list[dict[str, Any]], expected_answer: Any) -> bool: return canonicalize(actual_data) == canonicalize(expected_answer) def ensure_output_dir(path: Path) -> None: path.mkdir(parents=True, exist_ok=True) def write_json(path: Path, payload: Any) -> None: with path.open("w", encoding="utf-8") as handle: json.dump(payload, handle, ensure_ascii=False, indent=2) def run_case(case: dict[str, Any], timeout_seconds: int) -> dict[str, Any]: candidates, stream_events = request_generate_sql(case["question"], timeout_seconds) candidate_results: list[dict[str, Any]] = [] case_success = False for rank in range(RCM_NUM): candidate = candidates[rank] if rank < len(candidates) else None if candidate is None: candidate_results.append( { "rank": rank + 1, "card_id": "", "sql": "", "query_success": False, "query_status_code": None, "query_error": "missing candidate from generate_sql stream", "query_result": [], "match_gold_answer": False, } ) continue try: query_result = request_query_sql(candidate.sql, timeout_seconds) matched = ( query_result.status_code == 1 and not query_result.error and answers_match(query_result.data, case["answer"]) ) case_success = case_success or matched candidate_results.append( { "rank": rank + 1, "card_id": candidate.card_id, "sql": candidate.sql, "query_success": query_result.status_code == 1 and not query_result.error, "query_status_code": query_result.status_code, "query_error": query_result.error, "query_result": query_result.data, "match_gold_answer": matched, } ) except Exception as exc: # noqa: BLE001 candidate_results.append( { "rank": rank + 1, "card_id": candidate.card_id, "sql": candidate.sql, "query_success": False, "query_status_code": None, "query_error": str(exc), "query_result": [], "match_gold_answer": False, } ) return { "case_id": case["case_id"], "question": case["question"], "gold_answer": case["answer"], "success": case_success, "generated_answers": candidate_results, "stream_events": stream_events, } def main() -> int: args = parse_args() input_path = Path(args.input_json) output_dir = Path(args.output_dir) ensure_output_dir(output_dir) cases = load_cases(input_path) case_results: list[dict[str, Any]] = [] success_count = 0 for case in cases: try: result = run_case(case, args.timeout) except Exception as exc: # noqa: BLE001 result = { "case_id": case["case_id"], "question": case["question"], "gold_answer": case["answer"], "success": False, "generated_answers": [], "error": str(exc), } success_count += int(result["success"]) case_results.append(result) total_count = len(case_results) success_rate = (success_count / total_count) if total_count else 0.0 generated_answers_path = output_dir / "nl2sql_generated_answers.json" summary_path = output_dir / "nl2sql_batch_summary.json" write_json(generated_answers_path, {"cases": case_results}) write_json( summary_path, { "input_json": str(input_path), "total_count": total_count, "success_count": success_count, "success_rate": success_rate, "generated_answers_file": str(generated_answers_path), }, ) print( json.dumps( { "total_count": total_count, "success_count": success_count, "success_rate": success_rate, "summary_file": str(summary_path), "generated_answers_file": str(generated_answers_path), }, ensure_ascii=False, indent=2, ) ) return 0 if __name__ == "__main__": try: raise SystemExit(main()) except Exception as exc: # noqa: BLE001 print(f"[ERROR] {exc}", file=sys.stderr) raise SystemExit(1) from exc