| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453 |
- #!/usr/bin/env python3
- """Batch test runner for NL2SQL generation and SQL execution."""
- from __future__ import annotations
- import argparse
- import json
- import sys
- import uuid
- from dataclasses import dataclass, field
- from pathlib import Path
- from typing import Any
- import requests
# ===== Interface constants: adjust these values before each run if needed. =====

# Service endpoints. BASE_URL is combined with GENERATE_SQL_PATH by
# build_generate_sql_url(); QUERY_SQL_URL is posted to as-is. Both ship as
# placeholders that the request helpers refuse to use until they are replaced.
BASE_URL = "replace-with-your-base-url"
GENERATE_SQL_PATH = "/generate_sql"
QUERY_SQL_URL = "http://replace-with-your-query-sql-api"

# Caller context forwarded verbatim in every generate_sql request body.
USER_TOKEN = "replace-with-your-token"
BBK = 100
DOMAIN = "cmb"
DEPARTMENT = "ABC"
YST_ID = 311139
USER_NAME = "杨林/311139"

# Sent as "rcm_num" in the request and also used as the number of candidate
# slots evaluated for each test case.
RCM_NUM = 5

# Defaults for the command-line options.
DEFAULT_TIMEOUT_SECONDS = 60
DEFAULT_INPUT_JSON = "sql_generate_testcase.json"
DEFAULT_OUTPUT_DIR = "output"
@dataclass
class SQLCandidate:
    """One SQL statement proposed by the generate_sql stream."""

    # Identifier of the card the SQL was extracted from.
    card_id: str
    # The SQL text (whitespace-normalized by the extraction helpers).
    sql: str
    # Name of the stream event the candidate was found in, if known.
    source_event: str = ""
@dataclass
class QueryExecutionResult:
    """Fields picked out of a query-SQL API response."""

    # "status_code" value reported inside the response body (not the HTTP status).
    status_code: int | None
    # Column names, coerced to strings.
    columns: list[str] = field(default_factory=list)
    # Result rows; only dict-shaped rows are kept.
    data: list[dict[str, Any]] = field(default_factory=list)
    # Counters and flag copied from the response body when present.
    row_count: int | None = None
    total_count: int | None = None
    truncated: bool | None = None
    # Error text from the response body; empty when none was reported.
    error: str = ""
    # The decoded response body, kept unmodified.
    raw_payload: Any = None
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the batch runner.

    Returns:
        Namespace with ``input_json``, ``output_dir`` and ``timeout``.
    """
    cli = argparse.ArgumentParser(description="Batch test NL2SQL generate/query flow.")
    option_specs = (
        (
            "--input-json",
            {
                "default": DEFAULT_INPUT_JSON,
                "help": "JSON file containing question and gold answer cases.",
            },
        ),
        (
            "--output-dir",
            {
                "default": DEFAULT_OUTPUT_DIR,
                "help": "Directory for summary and per-question generated answers.",
            },
        ),
        (
            "--timeout",
            {
                "type": int,
                "default": DEFAULT_TIMEOUT_SECONDS,
                "help": "HTTP timeout in seconds for both APIs.",
            },
        ),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def build_generate_sql_url() -> str:
    """Return the full URL of the generate_sql endpoint.

    BASE_URL may be a bare host (``"host:port"``), in which case ``http://`` is
    assumed, or may already carry a scheme (``"https://host"``), which is then
    kept as-is instead of being prefixed a second time.

    Raises:
        ValueError: If BASE_URL still holds the placeholder value.
    """
    if BASE_URL == "replace-with-your-base-url":
        raise ValueError("BASE_URL is still placeholder. Please set the actual service host.")
    origin = BASE_URL.rstrip("/")
    # Only add the default scheme when the configured value does not have one.
    if "://" not in origin:
        origin = f"http://{origin}"
    return f"{origin}{GENERATE_SQL_PATH}"
def load_cases(path: Path) -> list[dict[str, Any]]:
    """Read the test cases from a UTF-8 JSON file.

    The document may be a list of case objects, an object wrapping such a list
    under ``"cases"`` or ``"data"``, or a single case object.

    Args:
        path: Location of the JSON file.

    Returns:
        One dict per case with keys ``case_id``, ``question`` (stripped) and
        ``answer`` (``None`` when absent). A missing or empty ``case_id`` is
        replaced by ``case_<1-based position>``.

    Raises:
        ValueError: On an unsupported document shape, a non-object case, a case
            without a non-blank string ``question``, or an empty case list.
    """
    with path.open("r", encoding="utf-8") as source:
        document = json.load(source)

    if isinstance(document, dict):
        document = document.get("cases") or document.get("data") or [document]
    if not isinstance(document, list):
        raise ValueError(f"Unsupported JSON structure in {path}")

    loaded: list[dict[str, Any]] = []
    for position, entry in enumerate(document, start=1):
        if not isinstance(entry, dict):
            raise ValueError(f"Case #{position} is not an object")
        text = entry.get("question")
        if not (isinstance(text, str) and text.strip()):
            raise ValueError(f"Case #{position} missing valid 'question'")
        record = {
            "case_id": entry.get("case_id") or f"case_{position}",
            "question": text.strip(),
            "answer": entry.get("answer"),
        }
        loaded.append(record)

    if not loaded:
        raise ValueError(f"No valid test cases found in {path}")
    return loaded
def make_request_id() -> str:
    """Return a fresh random request id: the 32-character hex form of a UUID4."""
    fresh = uuid.uuid4()
    return fresh.hex
def build_generate_payload(question: str) -> dict[str, Any]:
    """Build the JSON body for a generate_sql request.

    Args:
        question: Natural-language question, sent as ``user_question``.

    Returns:
        Dict holding a fresh ``request_id``, the question and the caller
        context taken from the module-level constants.
    """
    caller_context = {
        "token": USER_TOKEN,
        "bbk": BBK,
        "domain": DOMAIN,
        "department": DEPARTMENT,
        "rcm_num": RCM_NUM,
        "yst_id": YST_ID,
        "user_name": USER_NAME,
    }
    return {
        "request_id": make_request_id(),
        "user_question": question,
        **caller_context,
    }
def build_query_payload(sql: str) -> dict[str, Any]:
    """Build the JSON body for a query-SQL request: a fresh ``request_id`` plus the SQL."""
    return dict(request_id=make_request_id(), sql=sql)
def parse_sse_event_block(block: str) -> tuple[str, str]:
    """Split one server-sent-event block into its event type and data text.

    Only ``event:`` and ``data:`` lines are considered; every other line is
    ignored. When several ``event:`` lines occur the last one wins, and
    multiple ``data:`` lines are joined with newlines.

    Args:
        block: The lines of a single event, newline-separated.

    Returns:
        ``(event_type, data)``; either part is ``""`` when not present.
    """
    kind = ""
    chunks: list[str] = []
    for row in block.splitlines():
        name, colon, rest = row.partition(":")
        if not colon:
            continue
        if name == "event":
            kind = rest.strip()
        elif name == "data":
            chunks.append(rest.lstrip())
    return kind, "\n".join(chunks)
def parse_json_maybe(raw: str) -> Any:
    """Decode *raw* as JSON, returning *raw* itself when it is not valid JSON."""
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        return raw
    return decoded
def normalize_sql(sql: str) -> str:
    """Collapse every whitespace run in *sql* to one space and trim both ends."""
    tokens = sql.split()
    return " ".join(tokens)
def get_first_non_empty_string(value: Any) -> str:
    """Return *value* stripped when it is a string, otherwise ``""``."""
    return value.strip() if isinstance(value, str) else ""
def extract_candidate_from_mapping(mapping: dict[str, Any], source_event: str) -> SQLCandidate | None:
    """Build a candidate from *mapping* if it holds both a card id and SQL.

    The card id is the first non-blank string found under ``card_id``,
    ``cardId`` or ``id``; the SQL is the first non-blank string found under
    ``sql``, ``SQL``, ``final_sql``, ``finalSql`` or ``content``.

    Args:
        mapping: One JSON object from the stream payload.
        source_event: Event name recorded on the resulting candidate.

    Returns:
        The candidate with whitespace-normalized SQL, or ``None`` when either
        part is missing.
    """

    def pick(names: tuple[str, ...]) -> str:
        # First key whose value is a non-blank string wins.
        for name in names:
            text = get_first_non_empty_string(mapping.get(name))
            if text:
                return text
        return ""

    identifier = pick(("card_id", "cardId", "id"))
    statement = pick(("sql", "SQL", "final_sql", "finalSql", "content"))
    if not (identifier and statement):
        return None
    return SQLCandidate(card_id=identifier, sql=normalize_sql(statement), source_event=source_event)
def extract_candidates_from_payload(payload: Any, source_event: str) -> list[SQLCandidate]:
    """Recursively collect SQL candidates from a decoded JSON payload.

    A dict may itself be a candidate and is also searched through its values;
    a list is searched element by element; anything else yields nothing.

    Args:
        payload: Decoded event data of any JSON shape.
        source_event: Event name recorded on every candidate found.

    Returns:
        De-duplicated candidates in discovery order.
    """
    if isinstance(payload, dict):
        found: list[SQLCandidate] = []
        own = extract_candidate_from_mapping(payload, source_event)
        if own is not None:
            found.append(own)
        children = payload.values()
    elif isinstance(payload, list):
        found = []
        children = payload
    else:
        return []

    for child in children:
        found.extend(extract_candidates_from_payload(child, source_event))
    return dedupe_candidates(found)
def dedupe_candidates(candidates: list[SQLCandidate]) -> list[SQLCandidate]:
    """Drop incomplete and repeated candidates, keeping first occurrences in order.

    A candidate is incomplete when its ``card_id`` or ``sql`` is empty; two
    candidates are duplicates when both of those fields are equal.
    """
    unique: dict[tuple[str, str], SQLCandidate] = {}
    for entry in candidates:
        if entry.card_id and entry.sql:
            # setdefault keeps the first candidate seen for each key.
            unique.setdefault((entry.card_id, entry.sql), entry)
    return list(unique.values())
def parse_generate_stream(response: requests.Response) -> tuple[list[SQLCandidate], list[dict[str, Any]]]:
    """Consume the generate_sql server-sent-event stream.

    Lines are read as raw bytes and decoded as UTF-8 here. Letting requests
    decode them would apply the header-derived encoding, which falls back to
    ISO-8859-1 for ``text/*`` responses without a charset and would garble
    non-ASCII payloads.

    Events are separated by blank lines. ``dict`` and ``message`` events are
    searched for SQL candidates, an ``error`` event aborts with an exception,
    and an ``end`` event stops reading. An event left unterminated when the
    stream ends is handled exactly like a terminated one.

    Args:
        response: A streaming response from the generate_sql endpoint.

    Returns:
        ``(candidates, raw_events)``: at most RCM_NUM de-duplicated candidates
        in arrival order, and every parsed event as ``{"event", "data"}``.

    Raises:
        RuntimeError: If the stream delivers an ``error`` event.
    """
    raw_events: list[dict[str, Any]] = []
    candidates: list[SQLCandidate] = []

    def consume(lines: list[str]) -> str:
        # Parse one buffered event, record it, harvest candidates; return its type.
        event_type, data_raw = parse_sse_event_block("\n".join(lines))
        payload = parse_json_maybe(data_raw)
        raw_events.append({"event": event_type, "data": payload})
        if event_type in {"dict", "message"}:
            candidates.extend(extract_candidates_from_payload(payload, event_type))
        elif event_type == "error":
            error_msg = payload if isinstance(payload, str) else json.dumps(payload, ensure_ascii=False)
            raise RuntimeError(f"generate_sql stream error: {error_msg}")
        return event_type

    pending: list[str] = []
    for raw_line in response.iter_lines(decode_unicode=False):
        if raw_line is None:
            continue
        # Tolerate already-decoded lines (e.g. from a stand-in response object).
        if isinstance(raw_line, bytes):
            line = raw_line.decode("utf-8", errors="replace")
        else:
            line = raw_line
        if line:
            pending.append(line)
            continue
        if not pending:
            # Consecutive blank lines: nothing buffered to dispatch.
            continue
        event_type = consume(pending)
        pending = []
        if event_type == "end":
            break

    if pending:
        # The stream closed without a terminating blank line after the last event.
        consume(pending)

    return dedupe_candidates(candidates)[:RCM_NUM], raw_events
def request_generate_sql(question: str, timeout_seconds: int) -> tuple[list[SQLCandidate], list[dict[str, Any]]]:
    """Call the generate_sql endpoint and parse its event stream.

    The streaming response is used as a context manager so the connection is
    released even when parsing stops early (``end`` event) or raises.

    Args:
        question: Natural-language question to translate.
        timeout_seconds: Timeout passed to the HTTP request.

    Returns:
        The ``(candidates, raw_events)`` pair from parse_generate_stream().

    Raises:
        ValueError: If BASE_URL is still the placeholder.
        requests.HTTPError: If the endpoint answers with an error status.
        RuntimeError: If the stream delivers an ``error`` event.
    """
    url = build_generate_sql_url()
    payload = build_generate_payload(question)
    with requests.post(
        url,
        json=payload,
        stream=True,
        timeout=timeout_seconds,
        headers={"Accept": "text/event-stream"},
    ) as response:
        response.raise_for_status()
        return parse_generate_stream(response)
def request_query_sql(sql: str, timeout_seconds: int) -> QueryExecutionResult:
    """Execute *sql* through the query API and wrap the JSON response.

    Args:
        sql: Statement to execute.
        timeout_seconds: Timeout passed to the HTTP request.

    Returns:
        The response body mapped onto a QueryExecutionResult.

    Raises:
        ValueError: If QUERY_SQL_URL is still the placeholder, or the response
            body is not a JSON object.
        requests.HTTPError: If the API answers with an error status.
    """
    if QUERY_SQL_URL == "http://replace-with-your-query-sql-api":
        raise ValueError("QUERY_SQL_URL is still placeholder. Please set the actual query API URL.")

    response = requests.post(QUERY_SQL_URL, json=build_query_payload(sql), timeout=timeout_seconds)
    response.raise_for_status()

    body = response.json()
    if not isinstance(body, dict):
        raise ValueError(f"Query API returned non-object payload: {body!r}")

    # A missing or null "error" is stored as the empty string.
    error_text = str(body.get("error") or "")
    return QueryExecutionResult(
        status_code=body.get("status_code"),
        columns=ensure_string_list(body.get("columns")),
        data=ensure_row_list(body.get("data")),
        row_count=body.get("row_count"),
        total_count=body.get("total_count"),
        truncated=body.get("truncated"),
        error=error_text,
        raw_payload=body,
    )
def ensure_string_list(value: Any) -> list[str]:
    """Return the items of *value* converted with ``str``, or ``[]`` if it is not a list."""
    return list(map(str, value)) if isinstance(value, list) else []
def ensure_row_list(value: Any) -> list[dict[str, Any]]:
    """Return the dict items of *value* in order, or ``[]`` if it is not a list."""
    rows: list[dict[str, Any]] = []
    if isinstance(value, list):
        for entry in value:
            if isinstance(entry, dict):
                rows.append(entry)
    return rows
def canonicalize(value: Any) -> Any:
    """Return an order-insensitive normal form of a JSON-like value.

    Dicts are rebuilt with their keys in sorted order and lists are sorted by
    the JSON text of their (already canonicalized) items, both recursively.
    Any other value is returned unchanged.
    """
    if isinstance(value, list):
        members = [canonicalize(member) for member in value]
        members.sort(key=lambda member: json.dumps(member, ensure_ascii=False, sort_keys=True))
        return members
    if isinstance(value, dict):
        ordered: dict[Any, Any] = {}
        for name in sorted(value):
            ordered[name] = canonicalize(value[name])
        return ordered
    return value
def answers_match(actual_data: list[dict[str, Any]], expected_answer: Any) -> bool:
    """Tell whether the query rows equal the gold answer once both are canonicalized."""
    normalized_actual = canonicalize(actual_data)
    normalized_expected = canonicalize(expected_answer)
    return normalized_actual == normalized_expected
def ensure_output_dir(path: Path) -> None:
    """Create *path* and any missing parents; an existing directory is accepted."""
    path.mkdir(exist_ok=True, parents=True)
def write_json(path: Path, payload: Any) -> None:
    """Write *payload* to *path* as UTF-8 JSON, 2-space indented, non-ASCII kept as-is."""
    with path.open(mode="w", encoding="utf-8") as sink:
        json.dump(payload, sink, indent=2, ensure_ascii=False)
def run_case(case: dict[str, Any], timeout_seconds: int) -> dict[str, Any]:
    """Generate SQL for one case, execute each candidate and compare with the gold answer.

    Exactly RCM_NUM ranked entries are produced. Ranks without a generated
    candidate are filled with a "missing candidate" entry, and a candidate
    whose execution or comparison raises is recorded as failed with the
    exception text as its error.

    Args:
        case: Dict with ``case_id``, ``question`` and ``answer``.
        timeout_seconds: Timeout for each HTTP request.

    Returns:
        Case report with ``case_id``, ``question``, ``gold_answer``,
        ``success`` (true when any candidate matched), ``generated_answers``
        and ``stream_events``.

    Raises:
        Whatever request_generate_sql() raises; only per-candidate query
        failures are caught here.
    """
    candidates, stream_events = request_generate_sql(case["question"], timeout_seconds)

    def build_record(
        position: int,
        card_id: str,
        sql: str,
        *,
        succeeded: bool,
        status: int | None,
        error: str,
        rows: list[dict[str, Any]],
        matched: bool,
    ) -> dict[str, Any]:
        # Single place that fixes the key set and key order of a ranked entry.
        return {
            "rank": position,
            "card_id": card_id,
            "sql": sql,
            "query_success": succeeded,
            "query_status_code": status,
            "query_error": error,
            "query_result": rows,
            "match_gold_answer": matched,
        }

    records: list[dict[str, Any]] = []
    any_match = False
    for position in range(1, RCM_NUM + 1):
        if position > len(candidates):
            records.append(
                build_record(
                    position,
                    "",
                    "",
                    succeeded=False,
                    status=None,
                    error="missing candidate from generate_sql stream",
                    rows=[],
                    matched=False,
                )
            )
            continue

        candidate = candidates[position - 1]
        try:
            outcome = request_query_sql(candidate.sql, timeout_seconds)
            succeeded = outcome.status_code == 1 and not outcome.error
            matched = succeeded and answers_match(outcome.data, case["answer"])
            any_match = any_match or matched
            record = build_record(
                position,
                candidate.card_id,
                candidate.sql,
                succeeded=succeeded,
                status=outcome.status_code,
                error=outcome.error,
                rows=outcome.data,
                matched=matched,
            )
        except Exception as exc:  # noqa: BLE001
            # One failing candidate must not abort the remaining ranks.
            record = build_record(
                position,
                candidate.card_id,
                candidate.sql,
                succeeded=False,
                status=None,
                error=str(exc),
                rows=[],
                matched=False,
            )
        records.append(record)

    return {
        "case_id": case["case_id"],
        "question": case["question"],
        "gold_answer": case["answer"],
        "success": any_match,
        "generated_answers": records,
        "stream_events": stream_events,
    }
def main() -> int:
    """Run every test case and write the result files.

    Creates the output directory, loads the cases, runs them one by one (a
    case that raises is recorded as failed with its error text), then writes
    ``nl2sql_generated_answers.json`` and ``nl2sql_batch_summary.json`` and
    prints the headline numbers as JSON.

    Returns:
        ``0`` once both files have been written.
    """
    args = parse_args()
    input_path = Path(args.input_json)
    output_dir = Path(args.output_dir)
    ensure_output_dir(output_dir)

    outcomes: list[dict[str, Any]] = []
    for case in load_cases(input_path):
        try:
            outcome = run_case(case, args.timeout)
        except Exception as exc:  # noqa: BLE001
            # Keep the batch going; report the failure on the case itself.
            outcome = {
                "case_id": case["case_id"],
                "question": case["question"],
                "gold_answer": case["answer"],
                "success": False,
                "generated_answers": [],
                "error": str(exc),
            }
        outcomes.append(outcome)

    total_count = len(outcomes)
    success_count = sum(int(outcome["success"]) for outcome in outcomes)
    stats = {
        "total_count": total_count,
        "success_count": success_count,
        "success_rate": (success_count / total_count) if total_count else 0.0,
    }

    generated_answers_path = output_dir / "nl2sql_generated_answers.json"
    summary_path = output_dir / "nl2sql_batch_summary.json"

    write_json(generated_answers_path, {"cases": outcomes})
    write_json(
        summary_path,
        {
            "input_json": str(input_path),
            **stats,
            "generated_answers_file": str(generated_answers_path),
        },
    )

    console_report = {
        **stats,
        "summary_file": str(summary_path),
        "generated_answers_file": str(generated_answers_path),
    }
    print(json.dumps(console_report, ensure_ascii=False, indent=2))
    return 0
# Script entry point: exit with main()'s status, or print the error to stderr
# and exit with status 1 when main() raises.
if __name__ == "__main__":
    try:
        exit_status = main()
    except Exception as exc:  # noqa: BLE001
        print(f"[ERROR] {exc}", file=sys.stderr)
        raise SystemExit(1) from exc
    raise SystemExit(exit_status)
|