| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309 |
- """Operator entry point for generating QA pairs from dashboard metadata."""
- from __future__ import annotations
- import inspect
- import json
- import os
- from typing import Any, Dict, List, Tuple
- from app.const import LOG_ROOT
- from app.endpoints.schema import CardInfo, DashboardInfo, QAPair
- from app.llm import LLMManager
- from app.log import logger
- from app.pipeline import PipelineManager
- from app.prompt_manager import PromptManager
def validate_input(input_args: Dict[str, Any]) -> str:
    """Check the operator input and return the optional user request.

    Args:
        input_args: Operator arguments. Must contain a ``"bbk"`` entry; may
            contain ``"user_request"``.

    Returns:
        The user request, or ``""`` when it is absent or falsy (e.g. ``None``).

    Raises:
        KeyError: If ``"bbk"`` is missing from ``input_args``.
    """
    # Only the presence of "bbk" is checked; a missing key raises KeyError,
    # exactly like the plain dict lookup this replaces.
    if "bbk" not in input_args:
        raise KeyError("bbk")
    # Normalise an explicit None (e.g. JSON null) so callers always get a str.
    return input_args.get("user_request") or ""
def build_condition_dict(card: CardInfo) -> Dict[str, Any]:
    """Map every filter ID of ``card`` to its condition metadata.

    Entries are taken from ``card.filters`` first. ``card.sql_where`` then only
    contributes IDs that are not already present. Single quotes are removed
    from every ID. Each value is a dict with the keys ``"type"``,
    ``"条件语句"`` (condition statement), ``"聚合条件"`` (aggregate flag),
    ``"默认值"`` (default value) and ``"选项"`` (options).
    """

    def _unquote(raw_id):
        return raw_id.replace("'", "")

    conditions: Dict[str, Any] = {
        _unquote(flt.filter_id): {
            "type": flt.type,
            "条件语句": flt.where_clause,
            "聚合条件": False,
            "默认值": flt.default_value,
            "选项": flt.options,
        }
        for flt in card.filters
    }

    # sql_where entries never override a filter defined on the card itself.
    for raw_id, spec in (card.sql_where or {}).items():
        cleaned_id = _unquote(raw_id)
        if cleaned_id in conditions:
            continue
        conditions[cleaned_id] = {
            "type": spec.get("type", "F"),
            "条件语句": spec.get("exp", ""),
            "聚合条件": spec.get("agg", False),
            "默认值": spec.get("default", ""),
            "选项": spec.get("options", []),
        }
    return conditions
def build_card_content(card: CardInfo, condition_dict: Dict[str, Any]) -> List[str]:
    """Render a single card as a list of prompt lines.

    Output order:
      - a separator line;
      - the card header fields;
      - an optional GROUP BY line;
      - a section listing every ``"F"``-type filter, each with its optional
        aggregate flag, default and options;
      - a section listing every ``"D"``-type (fixed) condition.

    Conditions of any other type are not rendered.
    """
    lines: List[str] = [
        "---------------------------------------------------------------------------------------",
        f"卡片(ID: {card.card_id}):",
        f"卡片名称: {card.card_name}",
        f"卡片描述: {card.card_desc}",
        f"数据集ID: {card.dataset_id}",
        f"SELECT: {card.sql_select}",
    ]
    if card.sql_groupby:
        lines.append(f"GROUP BY: {card.sql_groupby}")

    lines.append("\n卡片的过滤器条件:")
    for filter_id, spec in condition_dict.items():
        if spec["type"] != "F":
            continue
        lines += [f"filter_id: {filter_id}", f"条件语句: {spec['条件语句']}"]
        # Optional attributes are printed only when truthy, in this fixed order.
        optional_fields = (
            ("是否聚合", spec["聚合条件"]),
            ("默认值", spec["默认值"]),
            ("选项", spec["选项"]),
        )
        for label, field_value in optional_fields:
            if field_value:
                lines.append(f"{label}: {field_value}")

    lines.append("\n卡片的固化条件(D类型)")
    for filter_id, spec in condition_dict.items():
        if spec["type"] == "D":
            lines += [f"filter_id: {filter_id}", f"条件语句: {spec['条件语句']}"]
    return lines
def build_dashboard_content(dashboard_info: DashboardInfo) -> Tuple[str, Dict[str, Any]]:
    """Build the prompt text for a whole dashboard.

    Returns:
        A tuple ``(text, card_id_2_filters)``:
          - ``text`` is the dashboard header followed by the rendered lines of
            every card, all newline-joined;
          - ``card_id_2_filters`` maps each card ID to the condition dict built
            for that card by ``build_condition_dict``.
    """
    header = [
        "=======================================================================================",
        f"仪表盘ID: {dashboard_info.dashboard_id}",
        f"仪表盘名称: {dashboard_info.dashboard_name}",
        f"仪表盘描述: {dashboard_info.dashboard_desc}",
        f"文件夹全路径: {dashboard_info.folder_path or '无'}\n",
    ]
    card_id_2_filters: Dict[str, Any] = {}
    card_sections: List[str] = []
    for card in dashboard_info.cards:
        logger.log(
            f"Processing card ID: {card.card_id} in dashboard ID: {dashboard_info.dashboard_id}",
            level="DEBUG",
        )
        conditions = build_condition_dict(card)
        card_id_2_filters[card.card_id] = conditions
        card_sections += build_card_content(card, conditions)
    return "\n".join(header + card_sections), card_id_2_filters
async def generate_prompt_content(
    prompt_manager: PromptManager,
    content: str,
    user_request: str,
):
    """Fill the QA-generation prompt template with the dashboard ``content``.

    If ``user_request`` is non-empty, the template
    ``generate_qa_pair_with_user_request`` is used and the request is passed as
    ``user_request``. Otherwise the plain ``generate_qa_pair`` template is used.

    Returns:
        The value the template's ``ainvoke`` coroutine resolves to. Callers in
        this module read its ``.text`` attribute.
    """
    context: Dict[str, str] = {"content": content}
    if not user_request:
        template_name = "generate_qa_pair"
    else:
        template_name = "generate_qa_pair_with_user_request"
        context["user_request"] = user_request
    template = prompt_manager.get_prompt_template(template_name)
    return await template.ainvoke(context)
def save_prompt_log(dashboard_id: str, prompt_text: str) -> None:
    """Write ``prompt_text`` to ``<LOG_ROOT>/prompts/generate_qa_pair/<dashboard_id>/prompt.txt``.

    Missing directories are created, and an existing ``prompt.txt`` is
    overwritten.
    """
    # NOTE(review): dashboard_id is used verbatim as a path component; confirm
    # it can never contain path separators or "..".
    prompt_path = os.path.join(LOG_ROOT, "prompts", "generate_qa_pair", dashboard_id, "prompt.txt")
    os.makedirs(os.path.dirname(prompt_path), exist_ok=True)
    with open(prompt_path, mode="w", encoding="utf-8") as handle:
        handle.write(prompt_text)
    logger.log(f"Prompt saved to {prompt_path}", level="INFO")
def build_where_clauses(
    condition_dict: Dict[str, Any],
    filter_ids: List[str],
) -> Tuple[str, str]:
    """Build the WHERE and HAVING bodies for the selected filters.

    Every ``"D"``-type (fixed) condition in ``condition_dict`` is always
    applied. Its ID is appended to ``filter_ids`` **in place** when it is not
    already listed, so the caller's list reflects every applied filter
    afterwards. Each selected condition goes to HAVING when its
    ``"聚合条件"`` (aggregate) flag is truthy, and to WHERE otherwise.

    Args:
        condition_dict: Filter ID -> condition metadata. Each entry needs the
            ``"type"``, ``"条件语句"`` and ``"聚合条件"`` keys.
        filter_ids: IDs selected for this question; mutated as described above.

    Returns:
        ``(where_body, having_body)``: the conditions joined with ``" AND "``,
        each ``""`` when nothing applies. The ``WHERE``/``HAVING`` keywords are
        not included.
    """
    for key, value in condition_dict.items():
        if value["type"] == "D" and key not in filter_ids:
            filter_ids.append(key)

    where_parts: List[str] = []
    having_parts: List[str] = []
    for filter_id in filter_ids:
        if filter_id not in condition_dict:
            logger.log(f"Filter ID {filter_id} not found in condition_dict", level="WARN")
            continue
        filter_dict = condition_dict[filter_id]
        statement = (filter_dict["条件语句"] or "").strip()
        if not statement:
            # An empty condition would only produce a dangling "AND".
            continue
        target = having_parts if filter_dict["聚合条件"] else where_parts
        target.append(statement)

    # Join rather than append " AND " and trim: str.rstrip(" AND ") strips any
    # trailing ' ', 'A', 'N', 'D' characters, which would also eat the end of
    # conditions such as "status = PAID" or "region = CN".
    return " AND ".join(where_parts), " AND ".join(having_parts)
def build_sql_statement(card: CardInfo, where_statement: str, having_statement: str) -> str:
    """Assemble the executable SQL statement for ``card``.

    Clause order:
      - ``card.sql_select``;
      - ``WHERE <where_statement>`` (if given);
      - the card's ``sql_groupby`` text (if any);
      - ``HAVING <having_statement>`` (if given).

    Fragments are stripped and separated by single spaces. Literal backslash-n
    sequences in the result are then turned into real newlines.
    """
    # NOTE(review): sql_groupby is appended verbatim, so it is assumed to
    # already carry the "GROUP BY" keyword; confirm against CardInfo.
    fragments = [card.sql_select]
    if where_statement:
        fragments.append(f"WHERE {where_statement}")
    if card.sql_groupby:
        fragments.append(card.sql_groupby)
    if having_statement:
        fragments.append(f"HAVING {having_statement}")
    # Joining with a space guarantees GROUP BY never gets glued onto the
    # SELECT text when there is no WHERE clause in between.
    sql = " ".join(fragment.strip() for fragment in fragments)
    return sql.replace("\\n", "\n")
def _normalize_filter_ids(raw_ids: Any) -> List[str]:
    """Return the LLM-supplied ``filter_ids`` as a new list (``None`` -> ``[]``)."""
    if raw_ids is None:
        return []
    if isinstance(raw_ids, str):
        # A bare string would otherwise be iterated character by character.
        return [raw_ids]
    return list(raw_ids)


def generate_qa_pairs(
    dashboard_info: DashboardInfo,
    generated_qa: List[Dict[str, Any]],
    card_id_2_filters: Dict[str, Any],
) -> List[QAPair]:
    """Combine LLM output with dashboard metadata into ``QAPair`` objects.

    For each item:
      - the referenced card is looked up;
      - its selected filters, plus the card's fixed (``"D"``) conditions, are
        compiled into SQL via ``build_where_clauses`` and
        ``build_sql_statement``;
      - the result is merged with dashboard and card metadata.

    A non-empty ``"sql"`` value on the item takes precedence over the
    generated statement.

    Items are logged at WARN and skipped when they are not dicts, reference an
    unknown card, or fail while being built. The remaining pairs are still
    returned.
    """
    qa_pairs: List[QAPair] = []
    card_map = {card.card_id: card for card in dashboard_info.cards}
    for item in generated_qa:
        if not isinstance(item, dict):
            logger.log(f"Error generating QA pair for item {item}: item is not an object", level="WARN")
            continue
        try:
            card_id = item.get("card_id", "")
            if card_id not in card_map:
                logger.log(
                    f"Card ID {card_id} not found in dashboard {dashboard_info.dashboard_id}",
                    level="WARN",
                )
                continue
            card = card_map[card_id]
            # Work on a copy. build_where_clauses appends the card's fixed ("D")
            # filter IDs to this list, and the extended list is what the QAPair
            # records; the LLM's item itself is left untouched.
            filter_ids = _normalize_filter_ids(item.get("filter_ids"))
            condition_dict = card_id_2_filters[card_id]
            where_statement, having_statement = build_where_clauses(condition_dict, filter_ids)
            sql_statement = build_sql_statement(card, where_statement, having_statement)
            logger.log(f"Generated SQL for card ID {card_id}:\n{sql_statement}", level="DEBUG")
            qa_pairs.append(
                QAPair(
                    dashboard_id=dashboard_info.dashboard_id,
                    dashboard_name=dashboard_info.dashboard_name,
                    dashboard_desc=dashboard_info.dashboard_desc,
                    card_id=card.card_id,
                    card_name=card.card_name,
                    card_desc=card.card_desc,
                    dataset_id=card.dataset_id,
                    question=item.get("question", ""),
                    question_with_slot=item.get("sample_question", ""),
                    # A null/empty "sql" from the LLM falls back to the built SQL.
                    answer=item.get("sql") or sql_statement,
                    filter_ids=filter_ids,
                )
            )
        except Exception as exc:  # pragma: no cover - defensive logging
            logger.log(f"Error generating QA pair for item {item}: {exc}", level="WARN")
    return qa_pairs
- def _clean_llm_response(text: str) -> str:
- """Strip optional Markdown fences from the LLM output."""
- content = text.strip()
- if content.startswith("```"):
- lines = content.splitlines()
- if len(lines) >= 2:
- content = "\n".join(lines[1:-1])
- if content.startswith("json"):
- content = content[4:].lstrip()
- return content
def _parse_generated_qa(raw_text: str) -> List[Dict[str, Any]]:
    """Decode the LLM reply into a list of QA items; ``[]`` when unusable."""
    try:
        parsed = json.loads(_clean_llm_response(raw_text))
    except json.JSONDecodeError as exc:
        logger.log(f"JSON decode error: {exc}, primary response: {raw_text}", level="ERROR")
        return []
    if isinstance(parsed, dict):
        # A single object is treated as a one-item list.
        return [parsed]
    if not isinstance(parsed, list):
        logger.log(
            f"Unexpected LLM response type {type(parsed).__name__}, primary response: {raw_text}",
            level="ERROR",
        )
        return []
    return parsed


def _fallback_qa(dashboard_info: DashboardInfo) -> List[Dict[str, Any]]:
    """Build one generic QA item per card with no user-selected filters.

    The card's fixed ("D") conditions are still applied later, when the SQL is
    built.
    """
    return [
        {
            "card_id": card.card_id,
            "question": f"{card.card_name} 的指标是多少?",
            "sample_question": f"请给出 {card.card_name} 的最新指标",
            "filter_ids": [],
        }
        for card in dashboard_info.cards
    ]


async def generate_qa_pair(input_args: Dict[str, Any]) -> List[QAPair]:
    """Generate QA pairs from dashboard and card metadata.

    Steps:
      1. Validate the input.
      2. Build and render the prompt, saving a copy via ``save_prompt_log``.
      3. Query the LLM configured for this pipeline node.
      4. Parse its JSON reply. If the reply is unusable or empty, fall back
         to one generic item per card.
      5. Convert the items into ``QAPair`` objects.

    Args:
        input_args: Operator arguments. Must contain ``"bbk"`` and
            ``"get_dashboard_info"`` (the dashboard metadata object); may
            contain ``"user_request"``.

    Raises:
        KeyError: If ``"bbk"`` or ``"get_dashboard_info"`` is missing.
    """
    user_request = validate_input(input_args)
    dashboard_info = input_args["get_dashboard_info"]
    content, card_id_2_filters = build_dashboard_content(dashboard_info)
    prompt = await generate_prompt_content(PromptManager(), content, user_request)
    save_prompt_log(dashboard_info.dashboard_id, prompt.text)

    llm_manager = LLMManager()
    # The pipeline node config is keyed by this operator's name. The name is
    # read from the function object because inspect.currentframe() may return
    # None on interpreters without stack-frame support.
    node_config = PipelineManager().get_node_config(generate_qa_pair.__name__)
    llm = llm_manager.get_llm_model(node_config["model"])
    resp = await llm.ainvoke(prompt.text)

    generated_qa = _parse_generated_qa(resp.content)
    if not generated_qa:
        generated_qa = _fallback_qa(dashboard_info)
    return generate_qa_pairs(dashboard_info, generated_qa, card_id_2_filters)
if __name__ == "__main__":  # pragma: no cover - manual smoke test
    import asyncio
    from app.operators.get_dashboard_info import get_dashboard_info
    from app.const import CONFIG_DIR

    # Construct the managers with explicit config files from CONFIG_DIR.
    # The file names are runtime paths and must match what exists on disk.
    _pipeline_manager = PipelineManager(str(CONFIG_DIR / "pipline_settings.yaml"))
    _prompt_manager = PromptManager(str(CONFIG_DIR / "prompt_template.yaml"))
    _llm_manager = LLMManager(str(CONFIG_DIR / "llm_settings.yaml"))

    smoke_input = {
        "dashboard_id": "test_dashboard_001",
        "card_ids": ["card_001", "card_002"],
        "bbk": "default_bbk",
        "user_request": "",
    }
    print("Getting dashboard info...")
    # Fetch the metadata first; the operator reads it from this key.
    smoke_input["get_dashboard_info"] = asyncio.run(get_dashboard_info(smoke_input))
    print(asyncio.run(generate_qa_pair(smoke_input)))
|