llm_manager.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. """Management utilities for heterogeneous LLM providers."""
  2. from __future__ import annotations
  3. import json
  4. from dataclasses import dataclass
  5. from pathlib import Path
  6. from threading import Lock
  7. from typing import Any, Dict, Optional, Type
  8. import yaml
  9. from app.const import CONFIG_DIR
  10. from .parser_manager import ParserManager
  11. from .parsers import BaseParser, ParsedResult, TextParser
  12. @dataclass
  13. class LLMResponse:
  14. content: str
  15. raw: Dict[str, Any]
  16. class BaseLLMClient:
  17. """Base class that concrete providers inherit from."""
  18. def __init__(self, name: str, parser: Optional[BaseParser] = None, **config: Any) -> None:
  19. self.name = name
  20. self.config = config
  21. self.parser = parser or TextParser()
  22. async def ainvoke(self, prompt: str, **kwargs: Any) -> LLMResponse: # pragma: no cover - abstract
  23. raise NotImplementedError
  24. def _wrap(self, parsed: ParsedResult) -> LLMResponse:
  25. content = parsed.data
  26. if not isinstance(content, str):
  27. content = json.dumps(content, ensure_ascii=False)
  28. return LLMResponse(content=content, raw={"metadata": parsed.metadata})
  29. class EchoLLMClient(BaseLLMClient):
  30. """Fallback client that simply echoes prompts."""
  31. async def ainvoke(self, prompt: str, **_: Any) -> LLMResponse:
  32. parsed = self.parser.parse(prompt)
  33. return self._wrap(parsed)
  34. class HttpLLMClient(BaseLLMClient):
  35. """HTTP JSON based LLM provider."""
  36. async def ainvoke(self, prompt: str, **kwargs: Any) -> LLMResponse:
  37. if httpx is None:
  38. raise ImportError("httpx is required for HttpLLMClient")
  39. endpoint = self.config["endpoint"]
  40. method = self.config.get("method", "POST").upper()
  41. timeout = self.config.get("timeout", 60)
  42. payload = dict(self.config.get("payload", {}))
  43. payload_field = self.config.get("prompt_field", "prompt")
  44. payload[payload_field] = prompt
  45. payload.update(kwargs.get("extra_payload", {}))
  46. headers = self.config.get("headers", {})
  47. async with httpx.AsyncClient(timeout=timeout) as client:
  48. response = await client.request(method, endpoint, json=payload, headers=headers)
  49. response.raise_for_status()
  50. response_json = response.json()
  51. content_path = self.config.get("response_path", [])
  52. content: Any = response_json
  53. for key in content_path:
  54. if isinstance(content, dict):
  55. content = content[key]
  56. else:
  57. raise KeyError(f"Cannot resolve response path segment {key}")
  58. if not isinstance(content, str):
  59. content = json.dumps(content, ensure_ascii=False)
  60. return LLMResponse(content=content, raw=response_json)
  61. LLM_CLIENTS: Dict[str, Type[BaseLLMClient]] = {
  62. "echo": EchoLLMClient,
  63. "http": HttpLLMClient,
  64. }
  65. class LLMManager:
  66. """Singleton responsible for instantiating configured LLM clients."""
  67. _instance: Optional["LLMManager"] = None
  68. _lock = Lock()
  69. def __new__(cls, config_path: Optional[Path] = None) -> "LLMManager":
  70. with cls._lock:
  71. if cls._instance is None:
  72. cls._instance = super().__new__(cls)
  73. cls._instance._build(config_path)
  74. return cls._instance
  75. def _build(self, config_path: Optional[Path]) -> None:
  76. path = Path(config_path or (CONFIG_DIR / "llm_settings.yaml"))
  77. self._models: Dict[str, BaseLLMClient] = {}
  78. self._model_config: Dict[str, Dict[str, Any]] = {}
  79. self._default_model = "echo"
  80. if path.exists():
  81. with path.open("r", encoding="utf-8") as fh:
  82. config = yaml.safe_load(fh) or {}
  83. self._default_model = config.get("default_model", "echo")
  84. self._model_config = config.get("models", {})
  85. else:
  86. self._model_config = {"echo": {"provider": "echo", "parser": "text"}}
  87. def _resolve_parser(self, parser_name: Optional[str]) -> BaseParser:
  88. if not parser_name:
  89. return TextParser()
  90. try:
  91. return ParserManager().get_parser(parser_name)
  92. except KeyError:
  93. return TextParser()
  94. def _create_model(self, model_name: str) -> BaseLLMClient:
  95. config = self._model_config.get(model_name)
  96. if not config:
  97. raise KeyError(f"Model {model_name} not defined in llm_settings.yaml")
  98. provider = config.get("provider", "echo")
  99. parser_name = config.get("parser")
  100. parser = self._resolve_parser(parser_name)
  101. options = config.get("options", {})
  102. client_cls = LLM_CLIENTS.get(provider)
  103. if not client_cls:
  104. raise ValueError(f"Unsupported LLM provider: {provider}")
  105. return client_cls(model_name, parser=parser, **options)
  106. def get_llm_model(self, model_name: Optional[str] = None) -> BaseLLMClient:
  107. target = model_name or self._default_model
  108. if target not in self._models:
  109. self._models[target] = self._create_model(target)
  110. return self._models[target]
  111. try:
  112. import httpx
  113. except ImportError: # pragma: no cover - optional dependency
  114. httpx = None