| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- """Prompt management utilities used across pipeline nodes."""
- from __future__ import annotations
- from dataclasses import dataclass
- from pathlib import Path
- from threading import Lock
- from typing import Dict, Iterable, Optional
- import yaml
- from app.const import CONFIG_DIR, PROMPT_DIR
# Built-in prompt templates. PromptManager falls back to these when its YAML
# config file does not exist. Each entry maps a template name to its
# ``str.format`` text and the placeholder names that text requires.
#
# English gloss of the (Chinese) prompt text:
# - generate_qa_pair: act as a BI assistant; from the dashboard/card
#   definitions in {content} (and the cards' filter conditions), produce several
#   user questions with corresponding SQL per card, as a JSON array whose items
#   carry card_id, question, sample_question, filter_ids.
# - generate_qa_pair_with_user_request: same task, steered by an extra metric
#   request supplied in {user_request}; output format as generate_qa_pair.
DEFAULT_PROMPTS = dict(
    generate_qa_pair=dict(
        template=(
            "你是一名BI分析助手,需要根据以下仪表盘与卡片的定义信息,"
            "结合卡片的过滤器条件,为每张卡片生成多个用户问题及其对应的SQL查询。\n"
            "仪表盘与卡片内容:\n{content}\n"
            "请输出JSON数组,每个元素包含 card_id, question, sample_question, filter_ids。"
        ),
        variables=["content"],
    ),
    generate_qa_pair_with_user_request=dict(
        template=(
            "以下是用户额外的指标需求:{user_request}\n"
            "结合仪表盘内容生成更贴近需求的问题:\n{content}\n"
            "输出格式同 generate_qa_pair。"
        ),
        variables=["user_request", "content"],
    ),
)
@dataclass
class PromptRenderResult:
    """Value returned by ``PromptTemplate.ainvoke``.

    Pairs the rendered prompt text with the name of the template it came from.
    """

    text: str  # the template after placeholder substitution
    template_name: str  # name of the PromptTemplate that produced ``text``
class PromptTemplate:
    """Named ``str.format`` template with a sync and an async render helper.

    ``variables`` lists the placeholder names the caller must supply; they are
    checked before formatting so a missing value yields an error that names the
    template.
    """

    def __init__(self, name: str, template: str, variables: Optional[Iterable[str]] = None) -> None:
        self.name = name
        self.template = template
        # ``list("content")`` would split a bare string into characters; a single
        # name given as a string (e.g. ``variables: content`` in YAML) is
        # treated as a one-element list instead.
        if isinstance(variables, str):
            variables = [variables]
        self.variables = list(variables or [])

    def __repr__(self) -> str:
        return f"{type(self).__name__}(name={self.name!r}, variables={self.variables!r})"

    def render(self, context: Dict[str, object]) -> str:
        """Substitute *context* into the template and return the text.

        Every context value is converted with ``str()`` before substitution.

        Raises:
            KeyError: if a declared variable is absent from *context*, or the
                template references a placeholder *context* does not supply.
        """
        missing = [var for var in self.variables if var not in context]
        if missing:
            raise KeyError(f"Missing variables {missing} for template {self.name}")
        safe_context = {key: str(value) for key, value in context.items()}
        try:
            return self.template.format(**safe_context)
        except KeyError as exc:
            # Placeholders not listed in ``variables`` surface from str.format as
            # a bare KeyError('name'); re-raise with the template name attached.
            key = exc.args[0] if exc.args else exc
            raise KeyError(
                f"Template {self.name} references undefined variable {key!r}"
            ) from exc

    async def ainvoke(self, context: Dict[str, object]) -> PromptRenderResult:
        """Async-compatible wrapper around ``render``.

        Returns the rendered text and this template's name as a PromptRenderResult.
        """
        return PromptRenderResult(text=self.render(context), template_name=self.name)
class PromptManager:
    """Singleton that loads prompt templates from a YAML file or built-in defaults.

    The first instantiation fixes the source: ``config_path`` if given,
    otherwise ``CONFIG_DIR / "prompt_template.yaml"``. If that file exists, its
    entries are loaded; otherwise ``DEFAULT_PROMPTS`` is used. Later
    instantiations return the same instance and ignore ``config_path``.
    Creation of the instance is serialised by a class-level ``Lock``.
    """

    _instance: Optional["PromptManager"] = None
    _lock = Lock()

    def __new__(cls, config_path: Optional[Path] = None) -> "PromptManager":
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                # Publish only after a successful build. If loading raises (bad
                # YAML, missing template file), the next call retries instead of
                # receiving a half-initialised singleton.
                instance._build(config_path)
                cls._instance = instance
            return cls._instance

    def _build(self, config_path: Optional[Path]) -> None:
        """Pick the template source and populate ``self._templates``."""
        path = Path(config_path or (CONFIG_DIR / "prompt_template.yaml"))
        self._templates: Dict[str, PromptTemplate] = {}
        if path.exists():
            self._load_from_config(path)
        else:
            self._load_defaults()

    def _load_defaults(self) -> None:
        """Register every entry of ``DEFAULT_PROMPTS``."""
        for name, info in DEFAULT_PROMPTS.items():
            self._templates[name] = PromptTemplate(
                name=name,
                template=info["template"],
                variables=info.get("variables"),
            )

    def _load_from_config(self, path: Path) -> None:
        """Register every template defined in the YAML mapping at *path*.

        Raises:
            ValueError: if the document or one of its entries is not a mapping,
                or an entry defines neither ``template`` nor ``path``.
        """
        with path.open("r", encoding="utf-8") as fh:
            # safe_load: the config is parsed without constructing arbitrary objects.
            prompt_config = yaml.safe_load(fh) or {}
        if not isinstance(prompt_config, dict):
            raise ValueError(
                f"Prompt config {path} must be a mapping of template name to definition"
            )
        for name, info in prompt_config.items():
            if not isinstance(info, dict):
                raise ValueError(f"Prompt template {name!r} in {path} must be a mapping")
            template_text = self._resolve_template_source(info)
            self._templates[name] = PromptTemplate(
                name=name,
                template=template_text,
                variables=info.get("variables", []),
            )

    def _resolve_template_source(self, info: Dict[str, str]) -> str:
        """Return template text from an inline ``template`` or a ``path`` entry.

        A relative ``path`` is resolved against ``PROMPT_DIR``; an absolute one
        is used as-is. Inline ``template`` wins when both keys are present.
        """
        if "template" in info:
            return info["template"]
        if "path" in info:
            tpl_path = Path(info["path"])
            resolved = tpl_path if tpl_path.is_absolute() else PROMPT_DIR / tpl_path
            with resolved.open("r", encoding="utf-8") as fh:
                return fh.read()
        raise ValueError("Prompt template must define either 'template' or 'path'")

    def get_prompt_template(self, name: str) -> PromptTemplate:
        """Return the template registered as *name*.

        Raises:
            KeyError: if no template with that name was loaded.
        """
        try:
            return self._templates[name]
        except KeyError:
            raise KeyError(f"Prompt template {name} not found") from None
|