nl2sql_batch_test.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. #!/usr/bin/env python3
  2. """Batch test runner for NL2SQL generation and SQL execution."""
  3. from __future__ import annotations
  4. import argparse
  5. import json
  6. import sys
  7. import uuid
  8. from dataclasses import dataclass, field
  9. from pathlib import Path
  10. from typing import Any
  11. import requests
# ===== Interface constants: adjust these values before each run if needed. =====
# Host of the SQL generation service. build_generate_sql_url() prefixes it with
# "http://" and refuses to run while this placeholder value is still in place.
BASE_URL = "replace-with-your-base-url"
# Path appended to BASE_URL to reach the SQL generation endpoint.
GENERATE_SQL_PATH = "/generate_sql"
# Full URL of the SQL execution API. request_query_sql() refuses to run while
# this placeholder value is still in place.
QUERY_SQL_URL = "http://replace-with-your-query-sql-api"
# Sent as "token" in the generate_sql request body.
USER_TOKEN = "replace-with-your-token"
# Sent as "bbk", "domain" and "department" in the generate_sql request body.
# NOTE(review): their business meaning is not visible in this file.
BBK = 100
DOMAIN = "cmb"
DEPARTMENT = "ABC"
# Sent as "rcm_num" in the generate_sql request body. Also caps how many
# candidates parse_generate_stream() keeps and fixes how many ranks run_case()
# reports per question.
RCM_NUM = 5
# Sent as "yst_id" and "user_name" in the generate_sql request body.
YST_ID = 311139
USER_NAME = "杨林/311139"
# Default value of the --timeout option (seconds, applied to both HTTP calls).
DEFAULT_TIMEOUT_SECONDS = 60
# Default value of the --input-json option.
DEFAULT_INPUT_JSON = "sql_generate_testcase.json"
# Default value of the --output-dir option.
DEFAULT_OUTPUT_DIR = "output"
  26. @dataclass
  27. class SQLCandidate:
  28. card_id: str
  29. sql: str
  30. source_event: str = ""
  31. @dataclass
  32. class QueryExecutionResult:
  33. status_code: int | None
  34. columns: list[str] = field(default_factory=list)
  35. data: list[dict[str, Any]] = field(default_factory=list)
  36. row_count: int | None = None
  37. total_count: int | None = None
  38. truncated: bool | None = None
  39. error: str = ""
  40. raw_payload: Any = None
  41. def parse_args() -> argparse.Namespace:
  42. parser = argparse.ArgumentParser(description="Batch test NL2SQL generate/query flow.")
  43. parser.add_argument(
  44. "--input-json",
  45. default=DEFAULT_INPUT_JSON,
  46. help="JSON file containing question and gold answer cases.",
  47. )
  48. parser.add_argument(
  49. "--output-dir",
  50. default=DEFAULT_OUTPUT_DIR,
  51. help="Directory for summary and per-question generated answers.",
  52. )
  53. parser.add_argument(
  54. "--timeout",
  55. type=int,
  56. default=DEFAULT_TIMEOUT_SECONDS,
  57. help="HTTP timeout in seconds for both APIs.",
  58. )
  59. return parser.parse_args()
  60. def build_generate_sql_url() -> str:
  61. if BASE_URL == "replace-with-your-base-url":
  62. raise ValueError("BASE_URL is still placeholder. Please set the actual service host.")
  63. return f"http://{BASE_URL.rstrip('/')}{GENERATE_SQL_PATH}"
  64. def load_cases(path: Path) -> list[dict[str, Any]]:
  65. with path.open("r", encoding="utf-8") as handle:
  66. payload = json.load(handle)
  67. if isinstance(payload, dict):
  68. payload = payload.get("cases") or payload.get("data") or [payload]
  69. if not isinstance(payload, list):
  70. raise ValueError(f"Unsupported JSON structure in {path}")
  71. cases: list[dict[str, Any]] = []
  72. for index, item in enumerate(payload, start=1):
  73. if not isinstance(item, dict):
  74. raise ValueError(f"Case #{index} is not an object")
  75. question = item.get("question")
  76. answer = item.get("answer")
  77. if not isinstance(question, str) or not question.strip():
  78. raise ValueError(f"Case #{index} missing valid 'question'")
  79. cases.append(
  80. {
  81. "case_id": item.get("case_id") or f"case_{index}",
  82. "question": question.strip(),
  83. "answer": answer,
  84. }
  85. )
  86. if not cases:
  87. raise ValueError(f"No valid test cases found in {path}")
  88. return cases
  89. def make_request_id() -> str:
  90. return uuid.uuid4().hex
  91. def build_generate_payload(question: str) -> dict[str, Any]:
  92. return {
  93. "request_id": make_request_id(),
  94. "user_question": question,
  95. "token": USER_TOKEN,
  96. "bbk": BBK,
  97. "domain": DOMAIN,
  98. "department": DEPARTMENT,
  99. "rcm_num": RCM_NUM,
  100. "yst_id": YST_ID,
  101. "user_name": USER_NAME,
  102. }
  103. def build_query_payload(sql: str) -> dict[str, Any]:
  104. return {
  105. "request_id": make_request_id(),
  106. "sql": sql,
  107. }
  108. def parse_sse_event_block(block: str) -> tuple[str, str]:
  109. event_type = ""
  110. data_lines: list[str] = []
  111. for line in block.splitlines():
  112. if line.startswith("event:"):
  113. event_type = line.split(":", 1)[1].strip()
  114. elif line.startswith("data:"):
  115. data_lines.append(line.split(":", 1)[1].lstrip())
  116. return event_type, "\n".join(data_lines)
  117. def parse_json_maybe(raw: str) -> Any:
  118. try:
  119. return json.loads(raw)
  120. except json.JSONDecodeError:
  121. return raw
  122. def normalize_sql(sql: str) -> str:
  123. return " ".join(sql.split()).strip()
  124. def get_first_non_empty_string(value: Any) -> str:
  125. if isinstance(value, str):
  126. return value.strip()
  127. return ""
  128. def extract_candidate_from_mapping(mapping: dict[str, Any], source_event: str) -> SQLCandidate | None:
  129. card_id_keys = ("card_id", "cardId", "id")
  130. sql_keys = ("sql", "SQL", "final_sql", "finalSql", "content")
  131. card_id = ""
  132. sql = ""
  133. for key in card_id_keys:
  134. card_id = get_first_non_empty_string(mapping.get(key))
  135. if card_id:
  136. break
  137. for key in sql_keys:
  138. sql = get_first_non_empty_string(mapping.get(key))
  139. if sql:
  140. break
  141. if card_id and sql:
  142. return SQLCandidate(card_id=card_id, sql=normalize_sql(sql), source_event=source_event)
  143. return None
  144. def extract_candidates_from_payload(payload: Any, source_event: str) -> list[SQLCandidate]:
  145. candidates: list[SQLCandidate] = []
  146. if isinstance(payload, dict):
  147. direct = extract_candidate_from_mapping(payload, source_event)
  148. if direct is not None:
  149. candidates.append(direct)
  150. for value in payload.values():
  151. candidates.extend(extract_candidates_from_payload(value, source_event))
  152. return dedupe_candidates(candidates)
  153. if isinstance(payload, list):
  154. for item in payload:
  155. candidates.extend(extract_candidates_from_payload(item, source_event))
  156. return dedupe_candidates(candidates)
  157. return candidates
  158. def dedupe_candidates(candidates: list[SQLCandidate]) -> list[SQLCandidate]:
  159. deduped: list[SQLCandidate] = []
  160. seen: set[tuple[str, str]] = set()
  161. for candidate in candidates:
  162. key = (candidate.card_id, candidate.sql)
  163. if not candidate.card_id or not candidate.sql or key in seen:
  164. continue
  165. seen.add(key)
  166. deduped.append(candidate)
  167. return deduped
  168. def parse_generate_stream(response: requests.Response) -> tuple[list[SQLCandidate], list[dict[str, Any]]]:
  169. raw_events: list[dict[str, Any]] = []
  170. candidates: list[SQLCandidate] = []
  171. buffer: list[str] = []
  172. for raw_line in response.iter_lines(decode_unicode=True):
  173. if raw_line is None:
  174. continue
  175. line = raw_line
  176. if line == "":
  177. if not buffer:
  178. continue
  179. block = "\n".join(buffer)
  180. buffer.clear()
  181. event_type, data_raw = parse_sse_event_block(block)
  182. payload = parse_json_maybe(data_raw)
  183. raw_events.append({"event": event_type, "data": payload})
  184. if event_type in {"dict", "message"}:
  185. candidates.extend(extract_candidates_from_payload(payload, event_type))
  186. elif event_type == "error":
  187. error_msg = payload if isinstance(payload, str) else json.dumps(payload, ensure_ascii=False)
  188. raise RuntimeError(f"generate_sql stream error: {error_msg}")
  189. elif event_type == "end":
  190. break
  191. continue
  192. buffer.append(line)
  193. if buffer:
  194. block = "\n".join(buffer)
  195. event_type, data_raw = parse_sse_event_block(block)
  196. payload = parse_json_maybe(data_raw)
  197. raw_events.append({"event": event_type, "data": payload})
  198. if event_type in {"dict", "message"}:
  199. candidates.extend(extract_candidates_from_payload(payload, event_type))
  200. return dedupe_candidates(candidates)[:RCM_NUM], raw_events
  201. def request_generate_sql(question: str, timeout_seconds: int) -> tuple[list[SQLCandidate], list[dict[str, Any]]]:
  202. url = build_generate_sql_url()
  203. payload = build_generate_payload(question)
  204. response = requests.post(
  205. url,
  206. json=payload,
  207. stream=True,
  208. timeout=timeout_seconds,
  209. headers={"Accept": "text/event-stream"},
  210. )
  211. response.raise_for_status()
  212. return parse_generate_stream(response)
  213. def request_query_sql(sql: str, timeout_seconds: int) -> QueryExecutionResult:
  214. if QUERY_SQL_URL == "http://replace-with-your-query-sql-api":
  215. raise ValueError("QUERY_SQL_URL is still placeholder. Please set the actual query API URL.")
  216. response = requests.post(
  217. QUERY_SQL_URL,
  218. json=build_query_payload(sql),
  219. timeout=timeout_seconds,
  220. )
  221. response.raise_for_status()
  222. payload = response.json()
  223. if not isinstance(payload, dict):
  224. raise ValueError(f"Query API returned non-object payload: {payload!r}")
  225. return QueryExecutionResult(
  226. status_code=payload.get("status_code"),
  227. columns=ensure_string_list(payload.get("columns")),
  228. data=ensure_row_list(payload.get("data")),
  229. row_count=payload.get("row_count"),
  230. total_count=payload.get("total_count"),
  231. truncated=payload.get("truncated"),
  232. error=str(payload.get("error") or ""),
  233. raw_payload=payload,
  234. )
  235. def ensure_string_list(value: Any) -> list[str]:
  236. if not isinstance(value, list):
  237. return []
  238. return [str(item) for item in value]
  239. def ensure_row_list(value: Any) -> list[dict[str, Any]]:
  240. if not isinstance(value, list):
  241. return []
  242. return [item for item in value if isinstance(item, dict)]
  243. def canonicalize(value: Any) -> Any:
  244. if isinstance(value, dict):
  245. return {key: canonicalize(value[key]) for key in sorted(value)}
  246. if isinstance(value, list):
  247. normalized_items = [canonicalize(item) for item in value]
  248. return sorted(normalized_items, key=lambda item: json.dumps(item, ensure_ascii=False, sort_keys=True))
  249. return value
  250. def answers_match(actual_data: list[dict[str, Any]], expected_answer: Any) -> bool:
  251. return canonicalize(actual_data) == canonicalize(expected_answer)
  252. def ensure_output_dir(path: Path) -> None:
  253. path.mkdir(parents=True, exist_ok=True)
  254. def write_json(path: Path, payload: Any) -> None:
  255. with path.open("w", encoding="utf-8") as handle:
  256. json.dump(payload, handle, ensure_ascii=False, indent=2)
  257. def run_case(case: dict[str, Any], timeout_seconds: int) -> dict[str, Any]:
  258. candidates, stream_events = request_generate_sql(case["question"], timeout_seconds)
  259. candidate_results: list[dict[str, Any]] = []
  260. case_success = False
  261. for rank in range(RCM_NUM):
  262. candidate = candidates[rank] if rank < len(candidates) else None
  263. if candidate is None:
  264. candidate_results.append(
  265. {
  266. "rank": rank + 1,
  267. "card_id": "",
  268. "sql": "",
  269. "query_success": False,
  270. "query_status_code": None,
  271. "query_error": "missing candidate from generate_sql stream",
  272. "query_result": [],
  273. "match_gold_answer": False,
  274. }
  275. )
  276. continue
  277. try:
  278. query_result = request_query_sql(candidate.sql, timeout_seconds)
  279. matched = (
  280. query_result.status_code == 1
  281. and not query_result.error
  282. and answers_match(query_result.data, case["answer"])
  283. )
  284. case_success = case_success or matched
  285. candidate_results.append(
  286. {
  287. "rank": rank + 1,
  288. "card_id": candidate.card_id,
  289. "sql": candidate.sql,
  290. "query_success": query_result.status_code == 1 and not query_result.error,
  291. "query_status_code": query_result.status_code,
  292. "query_error": query_result.error,
  293. "query_result": query_result.data,
  294. "match_gold_answer": matched,
  295. }
  296. )
  297. except Exception as exc: # noqa: BLE001
  298. candidate_results.append(
  299. {
  300. "rank": rank + 1,
  301. "card_id": candidate.card_id,
  302. "sql": candidate.sql,
  303. "query_success": False,
  304. "query_status_code": None,
  305. "query_error": str(exc),
  306. "query_result": [],
  307. "match_gold_answer": False,
  308. }
  309. )
  310. return {
  311. "case_id": case["case_id"],
  312. "question": case["question"],
  313. "gold_answer": case["answer"],
  314. "success": case_success,
  315. "generated_answers": candidate_results,
  316. "stream_events": stream_events,
  317. }
  318. def main() -> int:
  319. args = parse_args()
  320. input_path = Path(args.input_json)
  321. output_dir = Path(args.output_dir)
  322. ensure_output_dir(output_dir)
  323. cases = load_cases(input_path)
  324. case_results: list[dict[str, Any]] = []
  325. success_count = 0
  326. for case in cases:
  327. try:
  328. result = run_case(case, args.timeout)
  329. except Exception as exc: # noqa: BLE001
  330. result = {
  331. "case_id": case["case_id"],
  332. "question": case["question"],
  333. "gold_answer": case["answer"],
  334. "success": False,
  335. "generated_answers": [],
  336. "error": str(exc),
  337. }
  338. success_count += int(result["success"])
  339. case_results.append(result)
  340. total_count = len(case_results)
  341. success_rate = (success_count / total_count) if total_count else 0.0
  342. generated_answers_path = output_dir / "nl2sql_generated_answers.json"
  343. summary_path = output_dir / "nl2sql_batch_summary.json"
  344. write_json(generated_answers_path, {"cases": case_results})
  345. write_json(
  346. summary_path,
  347. {
  348. "input_json": str(input_path),
  349. "total_count": total_count,
  350. "success_count": success_count,
  351. "success_rate": success_rate,
  352. "generated_answers_file": str(generated_answers_path),
  353. },
  354. )
  355. print(
  356. json.dumps(
  357. {
  358. "total_count": total_count,
  359. "success_count": success_count,
  360. "success_rate": success_rate,
  361. "summary_file": str(summary_path),
  362. "generated_answers_file": str(generated_answers_path),
  363. },
  364. ensure_ascii=False,
  365. indent=2,
  366. )
  367. )
  368. return 0
  369. if __name__ == "__main__":
  370. try:
  371. raise SystemExit(main())
  372. except Exception as exc: # noqa: BLE001
  373. print(f"[ERROR] {exc}", file=sys.stderr)
  374. raise SystemExit(1) from exc