|
|
@@ -0,0 +1,453 @@
|
|
|
+#!/usr/bin/env python3
|
|
|
+"""Batch test runner for NL2SQL generation and SQL execution."""
|
|
|
+
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import argparse
|
|
|
+import json
|
|
|
+import sys
|
|
|
+import uuid
|
|
|
+from dataclasses import dataclass, field
|
|
|
+from pathlib import Path
|
|
|
+from typing import Any
|
|
|
+
|
|
|
+import requests
|
|
|
+
|
|
|
# ===== Interface constants: adjust these values before each run if needed. =====

# Service endpoints. BASE_URL and QUERY_SQL_URL ship as placeholders; the
# request helpers in this file refuse to run while they are unchanged.
BASE_URL = "replace-with-your-base-url"
GENERATE_SQL_PATH = "/generate_sql"
QUERY_SQL_URL = "http://replace-with-your-query-sql-api"

# Values copied into the body of every generate_sql request.
USER_TOKEN = "replace-with-your-token"
BBK = 100
DOMAIN = "cmb"
DEPARTMENT = "ABC"
# Sent as "rcm_num"; also the number of candidate slots evaluated per case.
RCM_NUM = 5
YST_ID = 311139
USER_NAME = "杨林/311139"

# Defaults for the command-line options.
DEFAULT_TIMEOUT_SECONDS = 60
DEFAULT_INPUT_JSON = "sql_generate_testcase.json"
DEFAULT_OUTPUT_DIR = "output"
|
|
|
+
|
|
|
+
|
|
|
@dataclass
class SQLCandidate:
    """A (card id, SQL text) pair extracted from a generate_sql stream event."""

    # Non-empty identifier taken from the event payload.
    card_id: str
    # SQL text as stored by the extractor.
    sql: str
    # Name of the SSE event the pair was found in; empty when not recorded.
    source_event: str = ""
|
|
|
+
|
|
|
+
|
|
|
@dataclass
class QueryExecutionResult:
    """Fields read from the query API's JSON response, plus the raw response."""

    # Value of the payload's "status_code" key (not the HTTP status).
    status_code: int | None
    # Column names, each coerced to str.
    columns: list[str] = field(default_factory=list)
    # Result rows; only dict rows are kept.
    data: list[dict[str, Any]] = field(default_factory=list)
    row_count: int | None = None
    total_count: int | None = None
    truncated: bool | None = None
    # Error text from the payload; empty string when absent or falsy.
    error: str = ""
    # The complete decoded response object, kept for inspection.
    raw_payload: Any = None
|
|
|
+
|
|
|
+
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the batch runner.

    Returns:
        Namespace with ``input_json``, ``output_dir`` and ``timeout`` (int).
    """
    cli = argparse.ArgumentParser(description="Batch test NL2SQL generate/query flow.")
    # Each entry is (flag, keyword arguments for add_argument).
    option_specs: tuple[tuple[str, dict[str, Any]], ...] = (
        (
            "--input-json",
            {
                "default": DEFAULT_INPUT_JSON,
                "help": "JSON file containing question and gold answer cases.",
            },
        ),
        (
            "--output-dir",
            {
                "default": DEFAULT_OUTPUT_DIR,
                "help": "Directory for summary and per-question generated answers.",
            },
        ),
        (
            "--timeout",
            {
                "type": int,
                "default": DEFAULT_TIMEOUT_SECONDS,
                "help": "HTTP timeout in seconds for both APIs.",
            },
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
+
|
|
|
+
|
|
|
def build_generate_sql_url() -> str:
    """Build the full URL of the generate_sql endpoint from the module constants.

    ``BASE_URL`` may be a bare host (``host:port``), in which case ``http://``
    is prepended, or a full base URL that already carries a scheme
    (``https://host``), in which case the scheme is kept as given.

    Returns:
        ``BASE_URL`` without trailing slashes, followed by ``GENERATE_SQL_PATH``.

    Raises:
        ValueError: If ``BASE_URL`` is still the shipped placeholder.
    """
    if BASE_URL == "replace-with-your-base-url":
        raise ValueError("BASE_URL is still placeholder. Please set the actual service host.")
    base = BASE_URL.rstrip("/")
    # Only add the default scheme when none was configured; otherwise a value
    # such as "https://host" would turn into "http://https://host".
    if "://" not in base:
        base = f"http://{base}"
    return f"{base}{GENERATE_SQL_PATH}"
|
|
|
+
|
|
|
+
|
|
|
def load_cases(path: Path) -> list[dict[str, Any]]:
    """Load and validate test cases from a UTF-8 JSON file.

    Accepted top-level shapes:
      * a list of case objects;
      * an object wrapping the list under ``"cases"`` or ``"data"``;
      * a single case object (one that has a ``"question"`` key).

    Args:
        path: Location of the JSON file.

    Returns:
        One dict per case with keys ``case_id`` (the case's own value, or
        ``case_<n>`` with a 1-based index when missing/falsy), ``question``
        (stripped) and ``answer`` (``None`` when the case has no such key).

    Raises:
        ValueError: If the structure is unsupported, a case is not an object,
            a case lacks a non-blank string ``question``, or no case exists.
    """
    with path.open("r", encoding="utf-8") as handle:
        payload = json.load(handle)

    if isinstance(payload, dict):
        wrapped = payload.get("cases") or payload.get("data")
        if wrapped:
            payload = wrapped
        elif "question" in payload:
            # The object itself is a single case.
            payload = [payload]
        else:
            # A wrapper with an empty or missing case list is not a case;
            # report it as "no cases" rather than as a malformed case #1.
            payload = []
    if not isinstance(payload, list):
        raise ValueError(f"Unsupported JSON structure in {path}")

    cases: list[dict[str, Any]] = []
    for index, item in enumerate(payload, start=1):
        if not isinstance(item, dict):
            raise ValueError(f"Case #{index} is not an object")
        question = item.get("question")
        if not isinstance(question, str) or not question.strip():
            raise ValueError(f"Case #{index} missing valid 'question'")
        cases.append(
            {
                "case_id": item.get("case_id") or f"case_{index}",
                "question": question.strip(),
                "answer": item.get("answer"),
            }
        )
    if not cases:
        raise ValueError(f"No valid test cases found in {path}")
    return cases
|
|
|
+
|
|
|
+
|
|
|
def make_request_id() -> str:
    """Return a fresh random request id: a UUID4 as 32 hex characters."""
    request_uuid = uuid.uuid4()
    return request_uuid.hex
|
|
|
+
|
|
|
+
|
|
|
def build_generate_payload(question: str) -> dict[str, Any]:
    """Assemble the JSON body for a generate_sql request.

    Args:
        question: Natural-language question, sent as ``user_question``.

    Returns:
        Dict with a fresh ``request_id``, the question, and the module-level
        interface constants.
    """
    body: dict[str, Any] = {
        "request_id": make_request_id(),
        "user_question": question,
    }
    # The remaining fields are fixed per run and come from the constants.
    body.update(
        token=USER_TOKEN,
        bbk=BBK,
        domain=DOMAIN,
        department=DEPARTMENT,
        rcm_num=RCM_NUM,
        yst_id=YST_ID,
        user_name=USER_NAME,
    )
    return body
|
|
|
+
|
|
|
+
|
|
|
def build_query_payload(sql: str) -> dict[str, Any]:
    """Assemble the JSON body for a query request: a fresh request id plus the SQL."""
    return dict(request_id=make_request_id(), sql=sql)
|
|
|
+
|
|
|
+
|
|
|
def parse_sse_event_block(block: str) -> tuple[str, str]:
    """Extract the event name and data from one server-sent-event block.

    Args:
        block: The lines of a single event, joined by line terminators.

    Returns:
        ``(event_type, data)``. ``event_type`` is the stripped value of the
        last ``event:`` line ("" if there is none). ``data`` is the values of
        all ``data:`` lines, left-stripped and joined with ``"\\n"``. Lines
        with any other prefix are ignored.
    """
    event_type = ""
    data_lines: list[str] = []
    # Split on CR, LF and CRLF only. str.splitlines() would also break on
    # characters such as U+2028/U+2029, form feed or vertical tab, which can
    # occur inside a JSON data value and would truncate it.
    for line in block.replace("\r\n", "\n").replace("\r", "\n").split("\n"):
        if line.startswith("event:"):
            event_type = line.split(":", 1)[1].strip()
        elif line.startswith("data:"):
            data_lines.append(line.split(":", 1)[1].lstrip())
    return event_type, "\n".join(data_lines)
|
|
|
+
|
|
|
+
|
|
|
def parse_json_maybe(raw: str) -> Any:
    """Decode ``raw`` as JSON, falling back to the unchanged string.

    Returns:
        The decoded JSON value, or ``raw`` itself when it is not valid JSON.
    """
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        decoded = raw
    return decoded
|
|
|
+
|
|
|
+
|
|
|
def normalize_sql(sql: str) -> str:
    """Collapse every run of whitespace in ``sql`` to a single space.

    Leading and trailing whitespace is removed. The collapse is purely
    textual: whitespace inside quoted literals and the newline ending a
    ``--`` comment are rewritten as well.
    """
    tokens = sql.split()
    return " ".join(tokens)
|
|
|
+
|
|
|
+
|
|
|
def get_first_non_empty_string(value: Any) -> str:
    """Return ``value`` stripped when it is a string, otherwise ``""``.

    Non-string values (numbers, ``None``, containers) always yield ``""``.
    """
    return value.strip() if isinstance(value, str) else ""
|
|
|
+
|
|
|
+
|
|
|
def extract_candidate_from_mapping(mapping: dict[str, Any], source_event: str) -> SQLCandidate | None:
    """Build a candidate from one dict if it holds both a card id and SQL text.

    The card id is the first non-blank string under ``card_id``, ``cardId``
    or ``id``; the SQL is the first non-blank string under ``sql``, ``SQL``,
    ``final_sql``, ``finalSql`` or ``content``. Only the dict's own keys are
    inspected.

    Args:
        mapping: Dict to inspect.
        source_event: Event name stored on the resulting candidate.

    Returns:
        A candidate with whitespace-normalized SQL, or ``None`` when either
        value is missing.
    """

    def first_text(keys: tuple[str, ...]) -> str:
        # First key (in priority order) whose value is a non-blank string.
        texts = (get_first_non_empty_string(mapping.get(key)) for key in keys)
        return next((text for text in texts if text), "")

    found_card_id = first_text(("card_id", "cardId", "id"))
    found_sql = first_text(("sql", "SQL", "final_sql", "finalSql", "content"))

    if not (found_card_id and found_sql):
        return None
    return SQLCandidate(
        card_id=found_card_id,
        sql=normalize_sql(found_sql),
        source_event=source_event,
    )
|
|
|
+
|
|
|
+
|
|
|
def extract_candidates_from_payload(payload: Any, source_event: str) -> list[SQLCandidate]:
    """Recursively collect SQL candidates from a decoded event payload.

    A dict is checked itself and then each of its values is searched; a list
    has each item searched. Any other value yields no candidates.

    Args:
        payload: Decoded event data of any JSON shape.
        source_event: Event name stored on every candidate found.

    Returns:
        The candidates found, de-duplicated, in discovery order.
    """
    found: list[SQLCandidate] = []

    if isinstance(payload, dict):
        own = extract_candidate_from_mapping(payload, source_event)
        if own is not None:
            found.append(own)
        children = payload.values()
    elif isinstance(payload, list):
        children = payload
    else:
        # Scalars cannot contain a candidate.
        return found

    for child in children:
        found.extend(extract_candidates_from_payload(child, source_event))
    return dedupe_candidates(found)
|
|
|
+
|
|
|
+
|
|
|
def dedupe_candidates(candidates: list[SQLCandidate]) -> list[SQLCandidate]:
    """Drop incomplete and repeated candidates, keeping first occurrences.

    A candidate is dropped when its ``card_id`` or ``sql`` is empty, or when
    an earlier candidate has the same ``(card_id, sql)`` pair.

    Returns:
        A new list in the original order.
    """
    first_seen: dict[tuple[str, str], SQLCandidate] = {}
    for entry in candidates:
        if entry.card_id and entry.sql:
            # setdefault keeps the earliest entry; dicts preserve insertion order.
            first_seen.setdefault((entry.card_id, entry.sql), entry)
    return list(first_seen.values())
|
|
|
+
|
|
|
+
|
|
|
def _dispatch_sse_block(
    block_lines: list[str],
    raw_events: list[dict[str, Any]],
    candidates: list[SQLCandidate],
) -> bool:
    """Process one buffered SSE block; return True when it is the ``end`` event.

    The parsed event is appended to ``raw_events``. Candidates from ``dict``
    and ``message`` events are appended to ``candidates``.

    Raises:
        RuntimeError: If the block is an ``error`` event.
    """
    event_type, data_raw = parse_sse_event_block("\n".join(block_lines))
    payload = parse_json_maybe(data_raw)
    raw_events.append({"event": event_type, "data": payload})
    if event_type in {"dict", "message"}:
        candidates.extend(extract_candidates_from_payload(payload, event_type))
    elif event_type == "error":
        error_msg = payload if isinstance(payload, str) else json.dumps(payload, ensure_ascii=False)
        raise RuntimeError(f"generate_sql stream error: {error_msg}")
    return event_type == "end"


def parse_generate_stream(response: requests.Response) -> tuple[list[SQLCandidate], list[dict[str, Any]]]:
    """Read a generate_sql SSE response and collect SQL candidates.

    Lines are buffered until a blank line ends the event. Reading stops at
    the first ``end`` event. A final block without a terminating blank line
    is processed the same way as any other block.

    Args:
        response: Streaming HTTP response of the generate_sql endpoint.

    Returns:
        ``(candidates, raw_events)``: the de-duplicated candidates, limited
        to the first ``RCM_NUM``, and every parsed event as
        ``{"event": ..., "data": ...}``.

    Raises:
        RuntimeError: If the stream contains an ``error`` event.
    """
    raw_events: list[dict[str, Any]] = []
    candidates: list[SQLCandidate] = []
    buffer: list[str] = []

    # Event streams are UTF-8. requests falls back to ISO-8859-1 for text/*
    # responses without a charset, so only trust response.encoding when the
    # server declared a charset explicitly.
    content_type = response.headers.get("Content-Type", "")
    declared_charset = response.encoding if "charset=" in content_type.lower() else None
    charset = declared_charset or "utf-8"

    # Iterate raw bytes: byte-level line splitting breaks only on CR/LF,
    # whereas text-level splitting also breaks on characters such as U+2028
    # that may occur inside a JSON data value.
    for raw_line in response.iter_lines():
        if raw_line is None:
            continue
        line = raw_line.decode(charset, errors="replace") if isinstance(raw_line, bytes) else raw_line
        if line != "":
            buffer.append(line)
            continue
        if not buffer:
            # Consecutive blank lines: nothing to dispatch.
            continue
        block_lines, buffer = buffer, []
        if _dispatch_sse_block(block_lines, raw_events, candidates):
            break

    if buffer:
        # Stream ended without a blank line after the last event.
        _dispatch_sse_block(buffer, raw_events, candidates)

    return dedupe_candidates(candidates)[:RCM_NUM], raw_events
|
|
|
+
|
|
|
+
|
|
|
def request_generate_sql(question: str, timeout_seconds: int) -> tuple[list[SQLCandidate], list[dict[str, Any]]]:
    """POST a question to the generate_sql endpoint and parse its SSE reply.

    Args:
        question: Natural-language question to convert to SQL.
        timeout_seconds: Timeout passed to ``requests.post``.

    Returns:
        The ``(candidates, raw_events)`` pair from ``parse_generate_stream``.

    Raises:
        ValueError: If ``BASE_URL`` is still the placeholder.
        requests.HTTPError: If the endpoint answers with an error status.
        RuntimeError: If the stream reports an ``error`` event.
    """
    url = build_generate_sql_url()
    payload = build_generate_payload(question)
    # With stream=True the connection stays open until the body is fully read
    # or the response is closed. Parsing may stop early (on "end" or on an
    # exception), so close the response explicitly via the context manager.
    with requests.post(
        url,
        json=payload,
        stream=True,
        timeout=timeout_seconds,
        headers={"Accept": "text/event-stream"},
    ) as response:
        response.raise_for_status()
        return parse_generate_stream(response)
|
|
|
+
|
|
|
+
|
|
|
def request_query_sql(sql: str, timeout_seconds: int) -> QueryExecutionResult:
    """Execute ``sql`` through the query API and wrap the JSON reply.

    Args:
        sql: SQL text sent in the request body.
        timeout_seconds: Timeout passed to ``requests.post``.

    Returns:
        A ``QueryExecutionResult`` filled from the response object's keys.

    Raises:
        ValueError: If ``QUERY_SQL_URL`` is still the placeholder, or the
            response body is not a JSON object.
        requests.HTTPError: If the API answers with an error status.
    """
    placeholder_url = "http://replace-with-your-query-sql-api"
    if QUERY_SQL_URL == placeholder_url:
        raise ValueError("QUERY_SQL_URL is still placeholder. Please set the actual query API URL.")

    http_response = requests.post(
        QUERY_SQL_URL,
        json=build_query_payload(sql),
        timeout=timeout_seconds,
    )
    http_response.raise_for_status()

    body = http_response.json()
    if not isinstance(body, dict):
        raise ValueError(f"Query API returned non-object payload: {body!r}")

    # A missing or falsy "error" value is stored as the empty string.
    error_text = str(body.get("error") or "")
    return QueryExecutionResult(
        status_code=body.get("status_code"),
        columns=ensure_string_list(body.get("columns")),
        data=ensure_row_list(body.get("data")),
        row_count=body.get("row_count"),
        total_count=body.get("total_count"),
        truncated=body.get("truncated"),
        error=error_text,
        raw_payload=body,
    )
|
|
|
+
|
|
|
+
|
|
|
def ensure_string_list(value: Any) -> list[str]:
    """Return ``value`` with every item converted via ``str``; ``[]`` if it is not a list."""
    return list(map(str, value)) if isinstance(value, list) else []
|
|
|
+
|
|
|
+
|
|
|
def ensure_row_list(value: Any) -> list[dict[str, Any]]:
    """Return the dict items of ``value`` in order; ``[]`` if it is not a list.

    Items that are not dicts are dropped silently.
    """
    rows: list[dict[str, Any]] = []
    if isinstance(value, list):
        for entry in value:
            if isinstance(entry, dict):
                rows.append(entry)
    return rows
|
|
|
+
|
|
|
+
|
|
|
def canonicalize(value: Any) -> Any:
    """Return an order-independent form of a JSON-like value.

    Dicts are rebuilt with their keys in sorted order and canonical values.
    Lists have their items canonicalized and then sorted by each item's JSON
    text, so list order is ignored by later comparisons. Other values are
    returned unchanged.
    """

    def json_sort_key(item: Any) -> str:
        return json.dumps(item, ensure_ascii=False, sort_keys=True)

    if isinstance(value, list):
        items = [canonicalize(element) for element in value]
        items.sort(key=json_sort_key)
        return items
    if isinstance(value, dict):
        return {name: canonicalize(value[name]) for name in sorted(value)}
    return value
|
|
|
+
|
|
|
+
|
|
|
def answers_match(actual_data: list[dict[str, Any]], expected_answer: Any) -> bool:
    """Tell whether the query rows equal the gold answer after canonicalization.

    Both sides go through ``canonicalize`` before being compared with ``==``.
    """
    normalized_actual = canonicalize(actual_data)
    normalized_expected = canonicalize(expected_answer)
    return normalized_actual == normalized_expected
|
|
|
+
|
|
|
+
|
|
|
def ensure_output_dir(path: Path) -> None:
    """Create ``path`` and any missing parents; an existing directory is fine."""
    path.mkdir(exist_ok=True, parents=True)
    return None
|
|
|
+
|
|
|
+
|
|
|
def write_json(path: Path, payload: Any) -> None:
    """Write ``payload`` to ``path`` as indented UTF-8 JSON.

    Non-ASCII characters are written as-is and the indent is two spaces.

    Args:
        path: Destination file; it is created or overwritten.
        payload: JSON-serializable value.

    Raises:
        TypeError: If ``payload`` is not JSON-serializable. The destination
            file is left untouched in that case.
    """
    # Serialize before opening: opening with "w" truncates the file, so a
    # serialization failure afterwards would leave it empty or partial.
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    with path.open("w", encoding="utf-8") as handle:
        handle.write(serialized)
|
|
|
+
|
|
|
+
|
|
|
def run_case(case: dict[str, Any], timeout_seconds: int) -> dict[str, Any]:
    """Generate SQL for one case, run each candidate and compare with the gold answer.

    Exactly ``RCM_NUM`` ranked slots are reported. A slot with no candidate
    is recorded as a failure, and an exception while querying a candidate is
    recorded on that slot without aborting the case. The case succeeds when
    at least one candidate's rows match the gold answer.

    Args:
        case: Dict with ``case_id``, ``question`` and ``answer``.
        timeout_seconds: Timeout used for both API calls.

    Returns:
        Dict with ``case_id``, ``question``, ``gold_answer``, ``success``,
        ``generated_answers`` (one record per slot) and ``stream_events``.
    """
    candidates, stream_events = request_generate_sql(case["question"], timeout_seconds)
    candidate_results: list[dict[str, Any]] = []
    case_success = False

    def record(
        position: int,
        card_id: str,
        sql: str,
        query_ok: bool,
        status_code: Any,
        error_text: str,
        rows: list[dict[str, Any]],
        matched_gold: bool,
    ) -> None:
        # Single place that fixes the shape and key order of a slot record.
        candidate_results.append(
            {
                "rank": position,
                "card_id": card_id,
                "sql": sql,
                "query_success": query_ok,
                "query_status_code": status_code,
                "query_error": error_text,
                "query_result": rows,
                "match_gold_answer": matched_gold,
            }
        )

    for position in range(1, RCM_NUM + 1):
        slot = position - 1
        candidate = candidates[slot] if slot < len(candidates) else None
        if candidate is None:
            record(position, "", "", False, None, "missing candidate from generate_sql stream", [], False)
            continue

        try:
            query_result = request_query_sql(candidate.sql, timeout_seconds)
            # The query counts as successful on payload status 1 with no error text.
            query_ok = query_result.status_code == 1 and not query_result.error
            matched = query_ok and answers_match(query_result.data, case["answer"])
            case_success = case_success or matched
            record(
                position,
                candidate.card_id,
                candidate.sql,
                query_ok,
                query_result.status_code,
                query_result.error,
                query_result.data,
                matched,
            )
        except Exception as exc:  # noqa: BLE001
            record(position, candidate.card_id, candidate.sql, False, None, str(exc), [], False)

    return {
        "case_id": case["case_id"],
        "question": case["question"],
        "gold_answer": case["answer"],
        "success": case_success,
        "generated_answers": candidate_results,
        "stream_events": stream_events,
    }
|
|
|
+
|
|
|
+
|
|
|
def main() -> int:
    """Run every case from the input file and write the result files.

    Writes ``nl2sql_generated_answers.json`` (all case results) and
    ``nl2sql_batch_summary.json`` (counts and success rate) to the output
    directory, then prints the summary as JSON on stdout. A case whose run
    raises is recorded as failed with the exception text under ``error``.

    Returns:
        Always 0.
    """
    args = parse_args()
    input_path = Path(args.input_json)
    output_dir = Path(args.output_dir)
    ensure_output_dir(output_dir)

    cases = load_cases(input_path)

    case_results: list[dict[str, Any]] = []
    for case in cases:
        try:
            outcome = run_case(case, args.timeout)
        except Exception as exc:  # noqa: BLE001
            outcome = {
                "case_id": case["case_id"],
                "question": case["question"],
                "gold_answer": case["answer"],
                "success": False,
                "generated_answers": [],
                "error": str(exc),
            }
        case_results.append(outcome)

    total_count = len(case_results)
    success_count = sum(int(outcome["success"]) for outcome in case_results)
    success_rate = (success_count / total_count) if total_count else 0.0

    generated_answers_path = output_dir / "nl2sql_generated_answers.json"
    summary_path = output_dir / "nl2sql_batch_summary.json"

    # Figures shared by the summary file and the console report.
    stats = {
        "total_count": total_count,
        "success_count": success_count,
        "success_rate": success_rate,
    }

    write_json(generated_answers_path, {"cases": case_results})
    write_json(
        summary_path,
        {
            "input_json": str(input_path),
            **stats,
            "generated_answers_file": str(generated_answers_path),
        },
    )

    console_report = {
        **stats,
        "summary_file": str(summary_path),
        "generated_answers_file": str(generated_answers_path),
    }
    print(json.dumps(console_report, ensure_ascii=False, indent=2))
    return 0
|
|
|
+
|
|
|
+
|
|
|
if __name__ == "__main__":
    # Report any uncaught error on stderr and exit with status 1; otherwise
    # exit with the status returned by main().
    try:
        exit_status = main()
    except Exception as exc:  # noqa: BLE001
        print(f"[ERROR] {exc}", file=sys.stderr)
        raise SystemExit(1) from exc
    raise SystemExit(exit_status)
|