Bläddra i källkod

SQL生成测试

ysl2007 3 månader sedan
förälder
incheckning
5b891f8df2

+ 29 - 0
README.md

@@ -83,3 +83,32 @@ python3 recall_eval.py \
 - `dashboard_id`
 - `dataset_id`
 - `id`
+
+## NL2SQL 批量测试
+
+新增脚本 [nl2sql_batch_test.py](nl2sql_batch_test.py),用于批量执行:
+
+1. 读取测试 JSON
+2. 调用 SQL 生成流式接口,解析 5 个候选答案的 `card_id` 和 `sql`
+3. 逐条调用 SQL 查询接口
+4. 将查询结果与标准答案比对
+5. 输出成功数、成功率和逐题生成结果文件
+
+运行前需要先修改脚本顶部常量:
+
+- `BASE_URL`
+- `QUERY_SQL_URL`
+- `USER_TOKEN`
+
+运行命令:
+
+```bash
+python3 nl2sql_batch_test.py \
+  --input-json ./sql_generate_testcase.json \
+  --output-dir ./output
+```
+
+输出文件:
+
+- `output/nl2sql_batch_summary.json`
+- `output/nl2sql_generated_answers.json`

BIN
__pycache__/nl2sql_batch_test.cpython-313.pyc


+ 453 - 0
nl2sql_batch_test.py

@@ -0,0 +1,453 @@
+#!/usr/bin/env python3
+"""Batch test runner for NL2SQL generation and SQL execution."""
+
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+import uuid
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any
+
+import requests
+
+# ===== Interface constants: adjust these values before each run if needed. =====
+# Host of the SQL generation service; build_generate_sql_url() rejects the placeholder value.
+BASE_URL = "replace-with-your-base-url"
+# Path appended to BASE_URL to form the generation endpoint.
+GENERATE_SQL_PATH = "/generate_sql"
+# Full URL of the SQL query API; request_query_sql() rejects the placeholder value.
+QUERY_SQL_URL = "http://replace-with-your-query-sql-api"
+# Sent as the "token" field of every generation request.
+USER_TOKEN = "replace-with-your-token"
+# The next values are sent verbatim as the lower-case fields of the same name
+# ("bbk", "domain", "department", "rcm_num", "yst_id", "user_name").
+BBK = 100
+DOMAIN = "cmb"
+DEPARTMENT = "ABC"
+# Also used locally as the number of candidates kept and evaluated per question.
+RCM_NUM = 5
+YST_ID = 311139
+USER_NAME = "杨林/311139"
+
+# Defaults for the command-line options (--timeout, --input-json, --output-dir).
+DEFAULT_TIMEOUT_SECONDS = 60
+DEFAULT_INPUT_JSON = "sql_generate_testcase.json"
+DEFAULT_OUTPUT_DIR = "output"
+
+
+@dataclass
+class SQLCandidate:
+    """One SQL candidate extracted from the generation stream."""
+
+    # Identifier taken from the payload's "card_id", "cardId" or "id" field.
+    card_id: str
+    # SQL text; extract_candidate_from_mapping() stores it after normalize_sql().
+    sql: str
+    # SSE event type of the event the candidate was found in.
+    source_event: str = ""
+
+
+@dataclass
+class QueryExecutionResult:
+    """Fields of interest from one query API response.
+
+    request_query_sql() fills every attribute from the response key of the
+    same name; ``raw_payload`` keeps the whole decoded response object.
+    """
+
+    # Status reported inside the response body; run_case() treats 1 as success.
+    status_code: int | None
+    columns: list[str] = field(default_factory=list)
+    # Result rows; non-dict entries of the response are dropped.
+    data: list[dict[str, Any]] = field(default_factory=list)
+    row_count: int | None = None
+    total_count: int | None = None
+    truncated: bool | None = None
+    # Error text from the response, "" when absent or falsy.
+    error: str = ""
+    raw_payload: Any = None
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Batch test NL2SQL generate/query flow.")
+    parser.add_argument(
+        "--input-json",
+        default=DEFAULT_INPUT_JSON,
+        help="JSON file containing question and gold answer cases.",
+    )
+    parser.add_argument(
+        "--output-dir",
+        default=DEFAULT_OUTPUT_DIR,
+        help="Directory for summary and per-question generated answers.",
+    )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=DEFAULT_TIMEOUT_SECONDS,
+        help="HTTP timeout in seconds for both APIs.",
+    )
+    return parser.parse_args()
+
+
+def build_generate_sql_url() -> str:
+    if BASE_URL == "replace-with-your-base-url":
+        raise ValueError("BASE_URL is still placeholder. Please set the actual service host.")
+    return f"http://{BASE_URL.rstrip('/')}{GENERATE_SQL_PATH}"
+
+
+def load_cases(path: Path) -> list[dict[str, Any]]:
+    with path.open("r", encoding="utf-8") as handle:
+        payload = json.load(handle)
+    if isinstance(payload, dict):
+        payload = payload.get("cases") or payload.get("data") or [payload]
+    if not isinstance(payload, list):
+        raise ValueError(f"Unsupported JSON structure in {path}")
+
+    cases: list[dict[str, Any]] = []
+    for index, item in enumerate(payload, start=1):
+        if not isinstance(item, dict):
+            raise ValueError(f"Case #{index} is not an object")
+        question = item.get("question")
+        answer = item.get("answer")
+        if not isinstance(question, str) or not question.strip():
+            raise ValueError(f"Case #{index} missing valid 'question'")
+        cases.append(
+            {
+                "case_id": item.get("case_id") or f"case_{index}",
+                "question": question.strip(),
+                "answer": answer,
+            }
+        )
+    if not cases:
+        raise ValueError(f"No valid test cases found in {path}")
+    return cases
+
+
+def make_request_id() -> str:
+    return uuid.uuid4().hex
+
+
+def build_generate_payload(question: str) -> dict[str, Any]:
+    return {
+        "request_id": make_request_id(),
+        "user_question": question,
+        "token": USER_TOKEN,
+        "bbk": BBK,
+        "domain": DOMAIN,
+        "department": DEPARTMENT,
+        "rcm_num": RCM_NUM,
+        "yst_id": YST_ID,
+        "user_name": USER_NAME,
+    }
+
+
+def build_query_payload(sql: str) -> dict[str, Any]:
+    return {
+        "request_id": make_request_id(),
+        "sql": sql,
+    }
+
+
+def parse_sse_event_block(block: str) -> tuple[str, str]:
+    event_type = ""
+    data_lines: list[str] = []
+    for line in block.splitlines():
+        if line.startswith("event:"):
+            event_type = line.split(":", 1)[1].strip()
+        elif line.startswith("data:"):
+            data_lines.append(line.split(":", 1)[1].lstrip())
+    return event_type, "\n".join(data_lines)
+
+
+def parse_json_maybe(raw: str) -> Any:
+    try:
+        return json.loads(raw)
+    except json.JSONDecodeError:
+        return raw
+
+
+def normalize_sql(sql: str) -> str:
+    return " ".join(sql.split()).strip()
+
+
+def get_first_non_empty_string(value: Any) -> str:
+    if isinstance(value, str):
+        return value.strip()
+    return ""
+
+
+def extract_candidate_from_mapping(mapping: dict[str, Any], source_event: str) -> SQLCandidate | None:
+    card_id_keys = ("card_id", "cardId", "id")
+    sql_keys = ("sql", "SQL", "final_sql", "finalSql", "content")
+
+    card_id = ""
+    sql = ""
+    for key in card_id_keys:
+        card_id = get_first_non_empty_string(mapping.get(key))
+        if card_id:
+            break
+    for key in sql_keys:
+        sql = get_first_non_empty_string(mapping.get(key))
+        if sql:
+            break
+
+    if card_id and sql:
+        return SQLCandidate(card_id=card_id, sql=normalize_sql(sql), source_event=source_event)
+    return None
+
+
+def extract_candidates_from_payload(payload: Any, source_event: str) -> list[SQLCandidate]:
+    candidates: list[SQLCandidate] = []
+
+    if isinstance(payload, dict):
+        direct = extract_candidate_from_mapping(payload, source_event)
+        if direct is not None:
+            candidates.append(direct)
+        for value in payload.values():
+            candidates.extend(extract_candidates_from_payload(value, source_event))
+        return dedupe_candidates(candidates)
+
+    if isinstance(payload, list):
+        for item in payload:
+            candidates.extend(extract_candidates_from_payload(item, source_event))
+        return dedupe_candidates(candidates)
+
+    return candidates
+
+
+def dedupe_candidates(candidates: list[SQLCandidate]) -> list[SQLCandidate]:
+    deduped: list[SQLCandidate] = []
+    seen: set[tuple[str, str]] = set()
+    for candidate in candidates:
+        key = (candidate.card_id, candidate.sql)
+        if not candidate.card_id or not candidate.sql or key in seen:
+            continue
+        seen.add(key)
+        deduped.append(candidate)
+    return deduped
+
+
+def parse_generate_stream(response: requests.Response) -> tuple[list[SQLCandidate], list[dict[str, Any]]]:
+    raw_events: list[dict[str, Any]] = []
+    candidates: list[SQLCandidate] = []
+    buffer: list[str] = []
+
+    for raw_line in response.iter_lines(decode_unicode=True):
+        if raw_line is None:
+            continue
+        line = raw_line
+        if line == "":
+            if not buffer:
+                continue
+            block = "\n".join(buffer)
+            buffer.clear()
+            event_type, data_raw = parse_sse_event_block(block)
+            payload = parse_json_maybe(data_raw)
+            raw_events.append({"event": event_type, "data": payload})
+            if event_type in {"dict", "message"}:
+                candidates.extend(extract_candidates_from_payload(payload, event_type))
+            elif event_type == "error":
+                error_msg = payload if isinstance(payload, str) else json.dumps(payload, ensure_ascii=False)
+                raise RuntimeError(f"generate_sql stream error: {error_msg}")
+            elif event_type == "end":
+                break
+            continue
+        buffer.append(line)
+
+    if buffer:
+        block = "\n".join(buffer)
+        event_type, data_raw = parse_sse_event_block(block)
+        payload = parse_json_maybe(data_raw)
+        raw_events.append({"event": event_type, "data": payload})
+        if event_type in {"dict", "message"}:
+            candidates.extend(extract_candidates_from_payload(payload, event_type))
+
+    return dedupe_candidates(candidates)[:RCM_NUM], raw_events
+
+
+def request_generate_sql(question: str, timeout_seconds: int) -> tuple[list[SQLCandidate], list[dict[str, Any]]]:
+    url = build_generate_sql_url()
+    payload = build_generate_payload(question)
+    response = requests.post(
+        url,
+        json=payload,
+        stream=True,
+        timeout=timeout_seconds,
+        headers={"Accept": "text/event-stream"},
+    )
+    response.raise_for_status()
+    return parse_generate_stream(response)
+
+
+def request_query_sql(sql: str, timeout_seconds: int) -> QueryExecutionResult:
+    if QUERY_SQL_URL == "http://replace-with-your-query-sql-api":
+        raise ValueError("QUERY_SQL_URL is still placeholder. Please set the actual query API URL.")
+
+    response = requests.post(
+        QUERY_SQL_URL,
+        json=build_query_payload(sql),
+        timeout=timeout_seconds,
+    )
+    response.raise_for_status()
+    payload = response.json()
+    if not isinstance(payload, dict):
+        raise ValueError(f"Query API returned non-object payload: {payload!r}")
+    return QueryExecutionResult(
+        status_code=payload.get("status_code"),
+        columns=ensure_string_list(payload.get("columns")),
+        data=ensure_row_list(payload.get("data")),
+        row_count=payload.get("row_count"),
+        total_count=payload.get("total_count"),
+        truncated=payload.get("truncated"),
+        error=str(payload.get("error") or ""),
+        raw_payload=payload,
+    )
+
+
+def ensure_string_list(value: Any) -> list[str]:
+    if not isinstance(value, list):
+        return []
+    return [str(item) for item in value]
+
+
+def ensure_row_list(value: Any) -> list[dict[str, Any]]:
+    if not isinstance(value, list):
+        return []
+    return [item for item in value if isinstance(item, dict)]
+
+
+def canonicalize(value: Any) -> Any:
+    if isinstance(value, dict):
+        return {key: canonicalize(value[key]) for key in sorted(value)}
+    if isinstance(value, list):
+        normalized_items = [canonicalize(item) for item in value]
+        return sorted(normalized_items, key=lambda item: json.dumps(item, ensure_ascii=False, sort_keys=True))
+    return value
+
+
+def answers_match(actual_data: list[dict[str, Any]], expected_answer: Any) -> bool:
+    return canonicalize(actual_data) == canonicalize(expected_answer)
+
+
+def ensure_output_dir(path: Path) -> None:
+    """Create ``path`` and any missing parents; an existing directory is left as is."""
+    path.mkdir(parents=True, exist_ok=True)
+
+
+def write_json(path: Path, payload: Any) -> None:
+    with path.open("w", encoding="utf-8") as handle:
+        json.dump(payload, handle, ensure_ascii=False, indent=2)
+
+
+def run_case(case: dict[str, Any], timeout_seconds: int) -> dict[str, Any]:
+    candidates, stream_events = request_generate_sql(case["question"], timeout_seconds)
+    candidate_results: list[dict[str, Any]] = []
+    case_success = False
+
+    for rank in range(RCM_NUM):
+        candidate = candidates[rank] if rank < len(candidates) else None
+        if candidate is None:
+            candidate_results.append(
+                {
+                    "rank": rank + 1,
+                    "card_id": "",
+                    "sql": "",
+                    "query_success": False,
+                    "query_status_code": None,
+                    "query_error": "missing candidate from generate_sql stream",
+                    "query_result": [],
+                    "match_gold_answer": False,
+                }
+            )
+            continue
+
+        try:
+            query_result = request_query_sql(candidate.sql, timeout_seconds)
+            matched = (
+                query_result.status_code == 1
+                and not query_result.error
+                and answers_match(query_result.data, case["answer"])
+            )
+            case_success = case_success or matched
+            candidate_results.append(
+                {
+                    "rank": rank + 1,
+                    "card_id": candidate.card_id,
+                    "sql": candidate.sql,
+                    "query_success": query_result.status_code == 1 and not query_result.error,
+                    "query_status_code": query_result.status_code,
+                    "query_error": query_result.error,
+                    "query_result": query_result.data,
+                    "match_gold_answer": matched,
+                }
+            )
+        except Exception as exc:  # noqa: BLE001
+            candidate_results.append(
+                {
+                    "rank": rank + 1,
+                    "card_id": candidate.card_id,
+                    "sql": candidate.sql,
+                    "query_success": False,
+                    "query_status_code": None,
+                    "query_error": str(exc),
+                    "query_result": [],
+                    "match_gold_answer": False,
+                }
+            )
+
+    return {
+        "case_id": case["case_id"],
+        "question": case["question"],
+        "gold_answer": case["answer"],
+        "success": case_success,
+        "generated_answers": candidate_results,
+        "stream_events": stream_events,
+    }
+
+
+def main() -> int:
+    args = parse_args()
+    input_path = Path(args.input_json)
+    output_dir = Path(args.output_dir)
+    ensure_output_dir(output_dir)
+
+    cases = load_cases(input_path)
+
+    case_results: list[dict[str, Any]] = []
+    success_count = 0
+
+    for case in cases:
+        try:
+            result = run_case(case, args.timeout)
+        except Exception as exc:  # noqa: BLE001
+            result = {
+                "case_id": case["case_id"],
+                "question": case["question"],
+                "gold_answer": case["answer"],
+                "success": False,
+                "generated_answers": [],
+                "error": str(exc),
+            }
+        success_count += int(result["success"])
+        case_results.append(result)
+
+    total_count = len(case_results)
+    success_rate = (success_count / total_count) if total_count else 0.0
+
+    generated_answers_path = output_dir / "nl2sql_generated_answers.json"
+    summary_path = output_dir / "nl2sql_batch_summary.json"
+
+    write_json(generated_answers_path, {"cases": case_results})
+    write_json(
+        summary_path,
+        {
+            "input_json": str(input_path),
+            "total_count": total_count,
+            "success_count": success_count,
+            "success_rate": success_rate,
+            "generated_answers_file": str(generated_answers_path),
+        },
+    )
+
+    print(
+        json.dumps(
+            {
+                "total_count": total_count,
+                "success_count": success_count,
+                "success_rate": success_rate,
+                "summary_file": str(summary_path),
+                "generated_answers_file": str(generated_answers_path),
+            },
+            ensure_ascii=False,
+            indent=2,
+        )
+    )
+    return 0
+
+
+# Script entry point: exit with main()'s return code. Any exception is reported
+# as a single "[ERROR] ..." line on stderr and turned into exit status 1.
+if __name__ == "__main__":
+    try:
+        raise SystemExit(main())
+    except Exception as exc:  # noqa: BLE001
+        print(f"[ERROR] {exc}", file=sys.stderr)
+        raise SystemExit(1) from exc

+ 7 - 0
output/nl2sql_batch_summary.json

@@ -0,0 +1,7 @@
+{
+  "input_json": "sql_generate_testcase.json",
+  "total_count": 1,
+  "success_count": 0,
+  "success_rate": 0.0,
+  "generated_answers_file": "output/nl2sql_generated_answers.json"
+}

+ 18 - 0
output/nl2sql_generated_answers.json

@@ -0,0 +1,18 @@
+{
+  "cases": [
+    {
+      "case_id": "case_1",
+      "question": "v9e5f5ff0dfc948278659001月梁婷分行的高效仪表板盘数_业务部门占比是多少?",
+      "gold_answer": [
+        {
+          "分行名称": "梁婷",
+          "高效仪表板盘数_非业务部门_求和": -164700,
+          "高效仪表板盘数_业务部门_求和": 164925
+        }
+      ],
+      "success": false,
+      "generated_answers": [],
+      "error": "BASE_URL is still placeholder. Please set the actual service host."
+    }
+  ]
+}

+ 1 - 0
sql_generate_testcase.json

@@ -0,0 +1 @@
+{"question": "v9e5f5ff0dfc948278659001月梁婷分行的高效仪表板盘数_业务部门占比是多少?", "answer": [{"分行名称": "梁婷", "高效仪表板盘数_非业务部门_求和": -164700, "高效仪表板盘数_业务部门_求和": 164925}]}

+ 207 - 0
stream_output_help.md

@@ -0,0 +1,207 @@
+# NL2SQL SSE 流式响应解析器
+## 1. SSE 报文格式
+### 1.1 基本格式
+SSE (Server-Sent Events) 格式如下:
+```
+event: <event_type>
+data: <json_data>
+```
+注意:每个事件以 `\n\n` (两个换行符) 结尾。
+### 1.2 事件类型
+| 事件类型 | 说明 | 触发时机 |
+|---------|------|---------|
+| `message` | 节点状态通知 | 节点开始执行或执行完成时 |
+| `dict` | 流式数据块 | LLM生成SQL的每个token片段 |
+| `error` | 错误信息 | 节点执行出错时 |
+| `end` | 流式结束 | 整个流程完成时 |
+---
+## 2. 各事件详细格式
+### 2.1 message 事件 - 节点状态通知
+**节点开始执行:**
+```
+event: message
+data: {"node": "get_rec_dataset", "msg": "\n##开始执行<get_rec_dataset>##\n", "task": -1}
+```
+**节点执行完成:**
+```
+event: message
+data: {"node": "get_rec_dataset", "msg": "\n##执行完成<get_rec_dataset>##\n", "task": -1}
+```
+**data 字段说明:**
+| 字段 | 类型 | 说明 |
+|-----|------|------|
+| node | string | 节点名称 |
+| msg | string | 消息内容,包含节点状态标识 |
+| task | int | 任务ID(通常为 -1) |
+### 2.2 dict 事件 - 流式数据块
+**格式:**
+```
+event: dict
+data: {"node": "generate_process", "msg": "SELECT", "task": -1}
+```
+```
+event: dict
+data: {"node": "generate_process", "msg": " *", "task": -1}
+```
+**说明:**
+- 每个事件携带一个SQL片段
+- 需要累积所有片段才能得到完整SQL
+- 通常来自 `generate_process` 节点
+**data 字段说明:**
+| 字段 | 类型 | 说明 |
+|-----|------|------|
+| node | string | 节点名称(通常为 generate_process) |
+| msg | string | SQL片段(单个token或多个字符) |
+| task | int | 任务ID |
+### 2.3 error 事件 - 错误信息
+**格式:**
+```
+event: error
+data: {"node": "generate_process", "msg": "错误详情", "task": -1}
+```
+**data 字段说明:**
+| 字段 | 类型 | 说明 |
+|-----|------|------|
+| node | string | 出错的节点名称 |
+| msg | string | 错误消息 |
+| task | int | 任务ID |
+### 2.4 end 事件 - 流式结束
+**格式:**
+```
+event: end
+data: {"status": "done", "answer": "[Done]"}
+```
+**data 字段说明:**
+| 字段 | 类型 | 说明 |
+|-----|------|------|
+| status | string | 状态("done") |
+| answer | string | 结束标识("[Done]") |
+---
+## 3. 完整流程示例
+以下是完整的 SSE 流式响应示例:
+```
+event: message
+data: {"node": "get_rec_dataset", "msg": "\n##开始执行<get_rec_dataset>##\n", "task": -1}
+event: message
+data: {"node": "get_rec_dataset", "msg": "\n##执行完成<get_rec_dataset>##\n", "task": -1}
+event: message
+data: {"node": "get_rec_dataset_info", "msg": "\n##开始执行<get_rec_dataset_info>##\n", "task": -1}
+event: message
+data: {"node": "get_rec_dataset_info", "msg": "\n##执行完成<get_rec_dataset_info>##\n", "task": -1}
+event: message
+data: {"node": "generate_process", "msg": "\n##开始执行<generate_process>##\n", "task": -1}
+event: dict
+data: {"node": "generate_process", "msg": "SELECT", "task": -1}
+event: dict
+data: {"node": "generate_process", "msg": " *", "task": -1}
+event: dict
+data: {"node": "generate_process", "msg": " FROM", "task": -1}
+event: dict
+data: {"node": "generate_process", "msg": " table_name", "task": -1}
+event: dict
+data: {"node": "generate_process", "msg": " WHERE", "task": -1}
+event: dict
+data: {"node": "generate_process", "msg": " condition", "task": -1}
+event: message
+data: {"node": "generate_process", "msg": "\n##执行完成<generate_process>##\n", "task": -1}
+event: end
+data: {"status": "done", "answer": "[Done]"}
+```
+---
+## 4. 节点执行顺序
+根据 Pipeline 配置,节点按拓扑顺序执行:
+```
+get_rec_dataset (数据集推荐)
+       ↓
+get_rec_dataset_info (数据集信息获取)
+       ↓
+generate_process (SQL生成)
+```
+---
+## 5. 使用示例
+### 5.1 异步客户端(推荐)
+```python
+import asyncio
+from client.nl2sql_client import NL2SQLClient
+async def main():
+    client = NL2SQLClient(base_url="http://localhost:8000")
+    
+    # 方式1:获取完整结果
+    result = await client.generate_sql(
+        user_question="我想看看对公客户评级到期当日需推送的客户经理数量",
+        bbk="512",
+        domain="cmb_su",
+        token="your-token"
+    )
+    print(f"SQL: {result.sql}")
+    print(f"成功: {result.success}")
+    
+    # 方式2:流式处理(实时显示)
+    async for event in client.generate_sql_stream(
+        user_question="...",
+        bbk="512",
+        domain="cmb_su",
+        token="your-token"
+    ):
+        if event.event.name == "DICT":
+            print(event.message, end="", flush=True)
+    
+    await client.close()
+asyncio.run(main())
+```
+### 5.2 同步客户端
+```python
+from client.nl2sql_client import NL2SQLClientSync
+client = NL2SQLClientSync(base_url="http://localhost:8000")
+result = client.generate_sql(
+    user_question="我想看看对公客户评级到期当日需推送的客户经理数量",
+    bbk="512",
+    domain="cmb_su",
+    token="your-token"
+)
+print(f"SQL: {result.sql}")
+```
+### 5.3 仅使用解析器
+```python
+from client.sse_parser import SSEParser
+parser = SSEParser()
+# 设置回调
+parser.on_node_start = lambda node: print(f"[开始] {node}")
+parser.on_node_complete = lambda node: print(f"[完成] {node}")
+parser.on_data_chunk = lambda node, chunk: print(chunk, end="")
+# 解析 SSE 行
+parser.parse_line('event: message\\ndata: {"node": "get_rec_dataset", "msg": "##开始执行<get_rec_dataset>##"}\\n\\n')
+# 获取累积结果
+sql = parser.get_node_result("generate_process")
+```
+---
+## 6. 响应头说明
+| 响应头 | 值 | 说明 |
+|-------|---|------|
+| Content-Type | text/event-stream | SSE 流式响应 |
+| Cache-Control | no-cache | 禁止缓存 |
+| Connection | keep-alive | 保持连接 |
+| X-Request-ID | `<uuid>` | 请求ID(用于追踪) |
+---
+## 7. 错误处理
+### 7.1 客户端断开连接
+服务端会检测客户端断开,并终止流式响应。
+### 7.2 服务端错误
+错误通过 `error` 事件返回:
+```
+event: error
+data: {"node": "generate_process", "msg": "错误详情", "task": -1}
+```
+### 7.3 服务端繁忙(429)
+当并发超过限制时,返回 HTTP 429:
+```
+HTTP/1.1 429 Too Many Requests
+{"detail": "服务器繁忙,请稍后再试"}
+```
+---
+## 8. 文件说明
+| 文件 | 说明 |
+|-----|------|
+| `sse_parser.py` | SSE 解析器核心实现 |
+| `nl2sql_client.py` | NL2SQL 客户端封装 |
+| `stream_output_help.md` | 本文档 |