prompt_manager.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. """Prompt management utilities used across pipeline nodes."""
  2. from __future__ import annotations
  3. from dataclasses import dataclass
  4. from pathlib import Path
  5. from threading import Lock
  6. from typing import Dict, Iterable, Optional
  7. import yaml
  8. from app.const import CONFIG_DIR, PROMPT_DIR
  9. DEFAULT_PROMPTS = {
  10. "generate_qa_pair": {
  11. "template": (
  12. "你是一名BI分析助手,需要根据以下仪表盘与卡片的定义信息,"
  13. "结合卡片的过滤器条件,为每张卡片生成多个用户问题及其对应的SQL查询。\n"
  14. "仪表盘与卡片内容:\n{content}\n"
  15. "请输出JSON数组,每个元素包含 card_id, question, sample_question, filter_ids。"
  16. ),
  17. "variables": ["content"],
  18. },
  19. "generate_qa_pair_with_user_request": {
  20. "template": (
  21. "以下是用户额外的指标需求:{user_request}\n"
  22. "结合仪表盘内容生成更贴近需求的问题:\n{content}\n"
  23. "输出格式同 generate_qa_pair。"
  24. ),
  25. "variables": ["user_request", "content"],
  26. },
  27. }
  28. @dataclass
  29. class PromptRenderResult:
  30. text: str
  31. template_name: str
  32. class PromptTemplate:
  33. """Simple template wrapper with async friendly render helper."""
  34. def __init__(self, name: str, template: str, variables: Optional[Iterable[str]] = None) -> None:
  35. self.name = name
  36. self.template = template
  37. self.variables = list(variables or [])
  38. def render(self, context: Dict[str, str]) -> str:
  39. missing = [var for var in self.variables if var not in context]
  40. if missing:
  41. raise KeyError(f"Missing variables {missing} for template {self.name}")
  42. safe_context = {key: str(value) for key, value in context.items()}
  43. return self.template.format(**safe_context)
  44. async def ainvoke(self, context: Dict[str, str]) -> PromptRenderResult:
  45. return PromptRenderResult(text=self.render(context), template_name=self.name)
  46. class PromptManager:
  47. """Singleton responsible for loading prompt templates from disk or defaults."""
  48. _instance: Optional["PromptManager"] = None
  49. _lock = Lock()
  50. def __new__(cls, config_path: Optional[Path] = None) -> "PromptManager":
  51. with cls._lock:
  52. if cls._instance is None:
  53. cls._instance = super().__new__(cls)
  54. cls._instance._build(config_path)
  55. return cls._instance
  56. def _build(self, config_path: Optional[Path]) -> None:
  57. path = Path(config_path or (CONFIG_DIR / "prompt_template.yaml"))
  58. self._templates: Dict[str, PromptTemplate] = {}
  59. if path.exists():
  60. self._load_from_config(path)
  61. else:
  62. self._load_defaults()
  63. def _load_defaults(self) -> None:
  64. for name, info in DEFAULT_PROMPTS.items():
  65. self._templates[name] = PromptTemplate(
  66. name=name,
  67. template=info["template"],
  68. variables=info.get("variables"),
  69. )
  70. def _load_from_config(self, path: Path) -> None:
  71. with path.open("r", encoding="utf-8") as fh:
  72. prompt_config = yaml.safe_load(fh) or {}
  73. for name, info in prompt_config.items():
  74. template_text = self._resolve_template_source(info)
  75. self._templates[name] = PromptTemplate(
  76. name=name,
  77. template=template_text,
  78. variables=info.get("variables", []),
  79. )
  80. def _resolve_template_source(self, info: Dict[str, str]) -> str:
  81. if "template" in info:
  82. return info["template"]
  83. if "path" in info:
  84. tpl_path = info["path"]
  85. resolved = (PROMPT_DIR / tpl_path) if not Path(tpl_path).is_absolute() else Path(tpl_path)
  86. with resolved.open("r", encoding="utf-8") as fh:
  87. return fh.read()
  88. raise ValueError("Prompt template must define either 'template' or 'path'")
  89. def get_prompt_template(self, name: str) -> PromptTemplate:
  90. if name not in self._templates:
  91. raise KeyError(f"Prompt template {name} not found")
  92. return self._templates[name]