"""Management utilities for heterogeneous LLM providers.""" from __future__ import annotations import json from dataclasses import dataclass from pathlib import Path from threading import Lock from typing import Any, Dict, Optional, Type import yaml from app.const import CONFIG_DIR from .parser_manager import ParserManager from .parsers import BaseParser, ParsedResult, TextParser @dataclass class LLMResponse: content: str raw: Dict[str, Any] class BaseLLMClient: """Base class that concrete providers inherit from.""" def __init__(self, name: str, parser: Optional[BaseParser] = None, **config: Any) -> None: self.name = name self.config = config self.parser = parser or TextParser() async def ainvoke(self, prompt: str, **kwargs: Any) -> LLMResponse: # pragma: no cover - abstract raise NotImplementedError def _wrap(self, parsed: ParsedResult) -> LLMResponse: content = parsed.data if not isinstance(content, str): content = json.dumps(content, ensure_ascii=False) return LLMResponse(content=content, raw={"metadata": parsed.metadata}) class EchoLLMClient(BaseLLMClient): """Fallback client that simply echoes prompts.""" async def ainvoke(self, prompt: str, **_: Any) -> LLMResponse: parsed = self.parser.parse(prompt) return self._wrap(parsed) class HttpLLMClient(BaseLLMClient): """HTTP JSON based LLM provider.""" async def ainvoke(self, prompt: str, **kwargs: Any) -> LLMResponse: if httpx is None: raise ImportError("httpx is required for HttpLLMClient") endpoint = self.config["endpoint"] method = self.config.get("method", "POST").upper() timeout = self.config.get("timeout", 60) payload = dict(self.config.get("payload", {})) payload_field = self.config.get("prompt_field", "prompt") payload[payload_field] = prompt payload.update(kwargs.get("extra_payload", {})) headers = self.config.get("headers", {}) async with httpx.AsyncClient(timeout=timeout) as client: response = await client.request(method, endpoint, json=payload, headers=headers) response.raise_for_status() response_json = response.json() content_path = self.config.get("response_path", []) content: Any = response_json for key in content_path: if isinstance(content, dict): content = content[key] else: raise KeyError(f"Cannot resolve response path segment {key}") if not isinstance(content, str): content = json.dumps(content, ensure_ascii=False) return LLMResponse(content=content, raw=response_json) LLM_CLIENTS: Dict[str, Type[BaseLLMClient]] = { "echo": EchoLLMClient, "http": HttpLLMClient, } class LLMManager: """Singleton responsible for instantiating configured LLM clients.""" _instance: Optional["LLMManager"] = None _lock = Lock() def __new__(cls, config_path: Optional[Path] = None) -> "LLMManager": with cls._lock: if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._build(config_path) return cls._instance def _build(self, config_path: Optional[Path]) -> None: path = Path(config_path or (CONFIG_DIR / "llm_settings.yaml")) self._models: Dict[str, BaseLLMClient] = {} self._model_config: Dict[str, Dict[str, Any]] = {} self._default_model = "echo" if path.exists(): with path.open("r", encoding="utf-8") as fh: config = yaml.safe_load(fh) or {} self._default_model = config.get("default_model", "echo") self._model_config = config.get("models", {}) else: self._model_config = {"echo": {"provider": "echo", "parser": "text"}} def _resolve_parser(self, parser_name: Optional[str]) -> BaseParser: if not parser_name: return TextParser() try: return ParserManager().get_parser(parser_name) except KeyError: return TextParser() def _create_model(self, model_name: str) -> BaseLLMClient: config = self._model_config.get(model_name) if not config: raise KeyError(f"Model {model_name} not defined in llm_settings.yaml") provider = config.get("provider", "echo") parser_name = config.get("parser") parser = self._resolve_parser(parser_name) options = config.get("options", {}) client_cls = LLM_CLIENTS.get(provider) if not client_cls: raise ValueError(f"Unsupported LLM provider: {provider}") return client_cls(model_name, parser=parser, **options) def get_llm_model(self, model_name: Optional[str] = None) -> BaseLLMClient: target = model_name or self._default_model if target not in self._models: self._models[target] = self._create_model(target) return self._models[target] try: import httpx except ImportError: # pragma: no cover - optional dependency httpx = None