| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- """Management utilities for heterogeneous LLM providers."""
from __future__ import annotations

import json
import logging
from dataclasses import dataclass
from pathlib import Path
from threading import Lock
from typing import Any, Dict, Optional, Type

import yaml

from app.const import CONFIG_DIR

from .parser_manager import ParserManager
from .parsers import BaseParser, ParsedResult, TextParser
@dataclass()
class LLMResponse:
    """Normalised result of a single LLM call.

    ``content`` is always text; ``raw`` carries the provider-specific data
    that accompanied it (the decoded HTTP JSON body, or parser metadata for
    locally produced responses).
    """

    # Text handed back to the caller.
    content: str
    # Provider-specific payload accompanying the text.
    raw: Dict[str, Any]
class BaseLLMClient:
    """Base class that concrete LLM providers inherit from.

    Subclasses implement :meth:`ainvoke`. The constructor stores the model
    name, the provider-specific options, and the parser used to
    post-process model output.
    """

    def __init__(self, name: str, parser: Optional[BaseParser] = None, **config: Any) -> None:
        """Store the client configuration.

        Args:
            name: Model name this client was created for.
            parser: Parser applied to model output; a fresh ``TextParser()``
                is used when omitted.
            **config: Provider-specific options (e.g. ``endpoint``, ``headers``).
        """
        self.name = name
        self.config = config
        # Compare against None explicitly: a parser object that happens to be
        # falsy (defines __bool__/__len__) must not be silently replaced.
        self.parser = parser if parser is not None else TextParser()

    async def ainvoke(self, prompt: str, **kwargs: Any) -> LLMResponse:  # pragma: no cover - abstract
        """Send *prompt* to the provider and return its response.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError

    def _wrap(self, parsed: ParsedResult) -> LLMResponse:
        """Convert a parser result into an :class:`LLMResponse`.

        Non-string ``parsed.data`` is serialised to JSON (non-ASCII characters
        kept as-is); the parser's metadata is exposed as ``raw["metadata"]``.
        """
        content = parsed.data
        if not isinstance(content, str):
            content = json.dumps(content, ensure_ascii=False)
        return LLMResponse(content=content, raw={"metadata": parsed.metadata})
class EchoLLMClient(BaseLLMClient):
    """Offline fallback provider whose "model output" is the prompt itself.

    The prompt is run through the configured parser and wrapped like any
    other response, so it can stand in when no real provider is configured.
    """

    async def ainvoke(self, prompt: str, **_: Any) -> LLMResponse:
        """Return *prompt* post-processed by ``self.parser``; extra kwargs are ignored."""
        return self._wrap(self.parser.parse(prompt))
class HttpLLMClient(BaseLLMClient):
    """LLM provider reached over HTTP with a JSON request/response body.

    Recognised ``config`` options:

    * ``endpoint`` (required): URL the request is sent to.
    * ``method``: HTTP method, default ``"POST"``.
    * ``timeout``: timeout passed to ``httpx.AsyncClient``, default 60.
    * ``payload``: base JSON body copied into every request.
    * ``prompt_field``: body key that receives the prompt, default ``"prompt"``.
    * ``headers``: request headers.
    * ``response_path``: keys (and, for JSON arrays, integer indices) leading
      from the decoded response to the generated text, e.g.
      ``["choices", 0, "message", "content"]``. A single string is treated
      as a one-element path; an empty or missing path selects the whole body.
    """

    async def ainvoke(self, prompt: str, **kwargs: Any) -> LLMResponse:
        """Send *prompt* to the configured endpoint and extract the reply.

        Args:
            prompt: Text placed under ``prompt_field`` in the request body.
            **kwargs: ``extra_payload`` (mapping) is merged into the body last,
                so it may override both ``payload`` and the prompt field.

        Returns:
            An :class:`LLMResponse` whose ``content`` is the value found at
            ``response_path`` (JSON-encoded when it is not a string) and whose
            ``raw`` is the full decoded response body.

        Raises:
            ImportError: if ``httpx`` is not installed.
            KeyError: if ``endpoint`` is missing or ``response_path`` cannot
                be resolved.
            httpx.HTTPStatusError: for non-2xx responses.
        """
        if httpx is None:
            raise ImportError("httpx is required for HttpLLMClient")
        endpoint = self.config["endpoint"]
        method = self.config.get("method", "POST").upper()
        timeout = self.config.get("timeout", 60)
        # Copy so per-call changes never leak into the shared config; a
        # ``payload:`` key with no value (None) means "no base fields".
        payload = dict(self.config.get("payload") or {})
        payload[self.config.get("prompt_field", "prompt")] = prompt
        payload.update(kwargs.get("extra_payload") or {})
        headers = self.config.get("headers", {})
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.request(method, endpoint, json=payload, headers=headers)
            response.raise_for_status()
            response_json = response.json()
        # NOTE(review): self.parser (resolved from the model's ``parser``
        # setting) is not applied to HTTP responses - confirm this is intended.
        content = self._resolve_response_path(response_json, self.config.get("response_path", []))
        if not isinstance(content, str):
            content = json.dumps(content, ensure_ascii=False)
        return LLMResponse(content=content, raw=response_json)

    @staticmethod
    def _resolve_response_path(document: Any, path: Any) -> Any:
        """Walk *path* through the decoded JSON *document* and return the target.

        Dict segments are looked up by key; list segments must be (non-bool)
        integers. Raises KeyError when a segment cannot be resolved.
        """
        if path is None:
            return document
        if isinstance(path, str):
            path = [path]
        current = document
        for key in path:
            if isinstance(current, dict):
                if key not in current:
                    raise KeyError(f"Cannot resolve response path segment {key}")
                current = current[key]
            elif isinstance(current, list) and isinstance(key, int) and not isinstance(key, bool):
                try:
                    current = current[key]
                except IndexError as exc:
                    raise KeyError(f"Cannot resolve response path segment {key}") from exc
            else:
                raise KeyError(f"Cannot resolve response path segment {key}")
        return current
# Registry mapping a model's ``provider`` setting to the client class that serves it.
LLM_CLIENTS: Dict[str, Type[BaseLLMClient]] = dict(
    echo=EchoLLMClient,
    http=HttpLLMClient,
)
class LLMManager:
    """Process-wide singleton that builds and caches configured LLM clients.

    Settings are read from ``CONFIG_DIR / "llm_settings.yaml"`` (or an
    explicit *config_path*) the first time the class is instantiated. Later
    instantiations return the same object and ignore *config_path*.
    """

    _instance: Optional["LLMManager"] = None
    _lock = Lock()

    def __new__(cls, config_path: Optional[Path] = None) -> "LLMManager":
        # The lock ensures only one thread builds the shared instance.
        with cls._lock:
            if cls._instance is None:
                instance = super().__new__(cls)
                instance._build(config_path)
                # Publish only after a successful build: if loading the config
                # raises, later calls retry instead of receiving a
                # half-initialised singleton.
                cls._instance = instance
        return cls._instance

    def _build(self, config_path: Optional[Path]) -> None:
        """Initialise the client cache and load model settings.

        Reads *config_path* (default ``CONFIG_DIR / "llm_settings.yaml"``) with
        ``yaml.safe_load``. When the file does not exist, a single ``echo``
        model using the ``text`` parser is configured instead.
        """
        path = Path(config_path or (CONFIG_DIR / "llm_settings.yaml"))
        self._models: Dict[str, BaseLLMClient] = {}
        self._model_config: Dict[str, Optional[Dict[str, Any]]] = {}
        self._default_model = "echo"
        if path.exists():
            with path.open("r", encoding="utf-8") as fh:
                settings = yaml.safe_load(fh) or {}
            self._apply_settings(settings)
        else:
            self._model_config = {"echo": {"provider": "echo", "parser": "text"}}

    def _apply_settings(self, settings: Dict[str, Any]) -> None:
        """Take ``default_model`` and ``models`` from a parsed settings mapping.

        Keys that are absent or explicitly empty (``default_model:`` /
        ``models:`` with no value, which YAML parses as ``None``) fall back to
        ``"echo"`` and ``{}`` respectively.
        """
        self._default_model = settings.get("default_model") or "echo"
        self._model_config = settings.get("models") or {}

    def _resolve_parser(self, parser_name: Optional[str]) -> BaseParser:
        """Return the named parser, or ``TextParser()`` when unnamed or unknown."""
        if not parser_name:
            return TextParser()
        try:
            return ParserManager().get_parser(parser_name)
        except KeyError:
            # The fallback is deliberate, but a typo in the settings should
            # not go unnoticed.
            logging.getLogger(__name__).warning(
                "Unknown parser %r in LLM settings; falling back to TextParser", parser_name
            )
            return TextParser()

    def _model_settings(self, model_name: str) -> Dict[str, Any]:
        """Return the settings mapping declared for *model_name*.

        A model declared with an empty body (``name:`` or ``name: {}``) is
        valid and yields ``{}``, i.e. all defaults.

        Raises:
            KeyError: if the model is not declared at all.
        """
        if model_name not in self._model_config:
            raise KeyError(f"Model {model_name} not defined in llm_settings.yaml")
        return self._model_config[model_name] or {}

    def _create_model(self, model_name: str) -> BaseLLMClient:
        """Instantiate the client for *model_name* from its settings.

        Raises:
            KeyError: if the model is not declared.
            ValueError: if its ``provider`` is not registered in ``LLM_CLIENTS``.
        """
        config = self._model_settings(model_name)
        provider = config.get("provider", "echo")
        client_cls = LLM_CLIENTS.get(provider)
        if not client_cls:
            raise ValueError(f"Unsupported LLM provider: {provider}")
        parser = self._resolve_parser(config.get("parser"))
        # ``options:`` with no value (None) means "no extra options".
        options = config.get("options") or {}
        return client_cls(model_name, parser=parser, **options)

    def get_llm_model(self, model_name: Optional[str] = None) -> BaseLLMClient:
        """Return the cached client for *model_name*, creating it on first use.

        The configured default model is used when *model_name* is omitted.

        Raises:
            KeyError: if the model is not declared in the settings.
            ValueError: if the model's provider is unsupported.
        """
        target = model_name or self._default_model
        if target not in self._models:
            self._models[target] = self._create_model(target)
        return self._models[target]
# Optional dependency: only HttpLLMClient needs httpx, and it checks for
# ``httpx is None`` when invoked, so a missing package fails lazily.
try:
    import httpx
except ImportError:  # pragma: no cover - optional dependency
    httpx = None  # type: ignore[assignment]
|