"""Prompt management utilities used across pipeline nodes."""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from threading import Lock
from typing import Any, Dict, Iterable, Mapping, Optional

import yaml

from app.const import CONFIG_DIR, PROMPT_DIR

# Built-in templates, used only when no prompt config file exists on disk
# (see PromptManager._build).
DEFAULT_PROMPTS: Dict[str, Dict[str, Any]] = {
    "generate_qa_pair": {
        "template": (
            "你是一名BI分析助手,需要根据以下仪表盘与卡片的定义信息,"
            "结合卡片的过滤器条件,为每张卡片生成多个用户问题及其对应的SQL查询。\n"
            "仪表盘与卡片内容:\n{content}\n"
            "请输出JSON数组,每个元素包含 card_id, question, sample_question, filter_ids。"
        ),
        "variables": ["content"],
    },
    "generate_qa_pair_with_user_request": {
        "template": (
            "以下是用户额外的指标需求:{user_request}\n"
            "结合仪表盘内容生成更贴近需求的问题:\n{content}\n"
            "输出格式同 generate_qa_pair。"
        ),
        "variables": ["user_request", "content"],
    },
}


@dataclass
class PromptRenderResult:
    """Result of rendering a template: the final text and the template's name."""

    text: str
    template_name: str


class PromptTemplate:
    """Simple ``str.format``-based template wrapper with an async-friendly render helper.

    Placeholders use :meth:`str.format` syntax (``{name}``), so literal braces in
    the template text (e.g. JSON examples) must be doubled as ``{{`` / ``}}``.
    """

    def __init__(
        self,
        name: str,
        template: str,
        variables: Optional[Iterable[str]] = None,
    ) -> None:
        """Create a template.

        Args:
            name: Identifier used in error messages and render results.
            template: The template text with ``{placeholder}`` fields.
            variables: Names that must be present in the render context. A single
                string is treated as one variable name; ``None`` means none.
        """
        self.name = name
        self.template = template
        # A bare str is itself an Iterable[str]; without this guard a YAML entry
        # like ``variables: content`` would become ["c", "o", "n", ...].
        if isinstance(variables, str):
            variables = [variables] if variables else []
        self.variables = list(variables or [])

    def render(self, context: Mapping[str, Any]) -> str:
        """Return the template text with its placeholders filled from ``context``.

        Every value is converted with ``str()`` before substitution.

        Raises:
            KeyError: if a declared variable is missing from ``context``, or the
                template references a placeholder that ``context`` does not supply.
        """
        missing = [var for var in self.variables if var not in context]
        if missing:
            raise KeyError(f"Missing variables {missing} for template {self.name}")
        safe_context = {key: str(value) for key, value in context.items()}
        return self.template.format(**safe_context)

    async def ainvoke(self, context: Mapping[str, Any]) -> PromptRenderResult:
        """Async wrapper around :meth:`render`; rendering itself is synchronous."""
        return PromptRenderResult(text=self.render(context), template_name=self.name)


class PromptManager:
    """Singleton responsible for loading prompt templates from disk or defaults.

    The first instantiation decides the source: the YAML file at ``config_path``
    (default ``CONFIG_DIR / "prompt_template.yaml"``) if it exists, otherwise
    ``DEFAULT_PROMPTS``. Later instantiations return the same cached instance and
    ignore their ``config_path`` argument.
    """

    _instance: Optional["PromptManager"] = None
    _lock = Lock()

    def __new__(cls, config_path: Optional[Path] = None) -> "PromptManager":
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                # Publish the instance only after a successful build, so a failed
                # load (bad YAML, missing template file) is not cached as a
                # half-initialised singleton for every later caller.
                instance._build(config_path)
                cls._instance = instance
            return cls._instance

    def _build(self, config_path: Optional[Path]) -> None:
        """Populate ``self._templates`` from the config file, or from the defaults."""
        path = Path(config_path or (CONFIG_DIR / "prompt_template.yaml"))
        self._templates: Dict[str, PromptTemplate] = {}
        if path.exists():
            self._load_from_config(path)
        else:
            self._load_defaults()

    def _load_defaults(self) -> None:
        """Register every template defined in ``DEFAULT_PROMPTS``."""
        for name, info in DEFAULT_PROMPTS.items():
            self._templates[name] = PromptTemplate(
                name=name,
                template=info["template"],
                variables=info.get("variables"),
            )

    def _load_from_config(self, path: Path) -> None:
        """Register every template declared in the YAML file at ``path``.

        Each top-level key is a template name whose value is a mapping with either
        ``template`` (inline text) or ``path`` (template file), plus an optional
        ``variables`` list.

        Raises:
            ValueError: if the file's top level is not a mapping, an entry is not a
                mapping, or an entry defines neither ``template`` nor ``path``.
        """
        with path.open("r", encoding="utf-8") as fh:
            # safe_load: config is data only, never arbitrary Python objects.
            prompt_config = yaml.safe_load(fh) or {}
        if not isinstance(prompt_config, dict):
            raise ValueError(
                f"Prompt config {path} must be a mapping of template names, "
                f"got {type(prompt_config).__name__}"
            )
        for name, info in prompt_config.items():
            # Without this check a string entry would make ``"template" in info``
            # a substring test and fail later with an obscure TypeError.
            if not isinstance(info, dict):
                raise ValueError(
                    f"Prompt template {name!r} in {path} must be a mapping with "
                    f"'template' or 'path', got {type(info).__name__}"
                )
            template_text = self._resolve_template_source(info)
            self._templates[name] = PromptTemplate(
                name=name,
                template=template_text,
                variables=info.get("variables", []),
            )

    def _resolve_template_source(self, info: Mapping[str, Any]) -> str:
        """Return the template text for one config entry.

        An inline ``template`` wins over ``path``. A relative ``path`` is resolved
        against ``PROMPT_DIR``; an absolute one is used as-is. The file is read as
        UTF-8.

        Raises:
            ValueError: if the entry defines neither ``template`` nor ``path``.
        """
        if "template" in info:
            return info["template"]
        if "path" in info:
            tpl_path = info["path"]
            resolved = (PROMPT_DIR / tpl_path) if not Path(tpl_path).is_absolute() else Path(tpl_path)
            with resolved.open("r", encoding="utf-8") as fh:
                return fh.read()
        raise ValueError("Prompt template must define either 'template' or 'path'")

    def get_prompt_template(self, name: str) -> PromptTemplate:
        """Return the registered template called ``name``.

        Raises:
            KeyError: if no template with that name was loaded.
        """
        if name not in self._templates:
            raise KeyError(f"Prompt template {name} not found")
        return self._templates[name]