import asyncio
from dataclasses import dataclass, field
import inspect
import json
from typing import Any, Dict, List

import app.operators
from app.pipeline.pipeline_manager import PipelineManager

# Every callable attribute of ``app.operators``, keyed by attribute name.
# ``Workflow.execute_node`` looks a node's operator up here by the node's name.
function_map = {
    name: getattr(app.operators, name)
    for name in dir(app.operators)
    if callable(getattr(app.operators, name))
}


@dataclass
class Node(object):
    """One step of the pipeline graph.

    ``@dataclass`` does not replace an ``__init__`` that the class defines
    itself, so construction goes through the hand-written initializer below:
    every keyword other than ``name`` and ``dependencies`` is collected into
    ``extra_fields``.
    """

    name: str
    dependencies: List[str]
    # Not assigned by ``__init__``; instances fall back to this class-level
    # default until ``Workflow.execute_node`` stores the operator's output.
    result: Any = None
    extra_fields: Dict[str, Any] = field(default_factory=dict)

    def __init__(self, name: str, dependencies: List[str], **kwargs):
        """Create a node.

        Args:
            name: Node name; also the key used to find its operator.
            dependencies: Names of the nodes this node depends on.
            **kwargs: Any further settings; stored verbatim in
                ``extra_fields``. A truthy ``last_node`` entry marks the node
                whose output becomes the workflow's ``"result"``.
        """
        self.name = name
        self.dependencies = dependencies
        self.extra_fields = kwargs
        self.last_node = kwargs.get("last_node", False)


class Workflow:
    """Runs the pipeline's nodes level by level and collects their results."""

    def __init__(self) -> None:
        """Build the node list from ``PipelineManager`` and reset run state."""
        # Construct the manager once instead of once per node.
        manager = PipelineManager()
        self.sorted_nodes = [
            Node(name=name, **manager.node_map[name]) for name in manager.nodes
        ]
        self.node_map: Dict[str, Node] = {
            node.name: node for node in self.sorted_nodes
        }
        # Shared context: workflow inputs plus each node's output by node name.
        self.result = {}
        # Guards updates to ``self.result`` made by nodes of the same level.
        self.lock = asyncio.Lock()
        # Receives items produced by async-generator results, the final "end"
        # event and the ``None`` terminator.
        self.queue = asyncio.Queue()

    async def execute_node(self, node: Node):
        """Run one node's operator and record its output.

        The operator is called with the shared result dict. An awaitable
        return value is awaited; an exception raised while awaiting is turned
        into ``{"error": str(exc)}`` instead of propagating. An async
        generator return value is drained into ``self.queue``.

        Args:
            node: The node to execute.

        Raises:
            KeyError: If ``function_map`` has no operator named ``node.name``.
        """
        operator = function_map[node.name]
        res = operator(self.result)
        if inspect.isawaitable(res):
            try:
                res = await res
            except Exception as e:
                res = {"error": str(e)}
        node.result = res

        async with self.lock:
            self.result.update({node.name: res})
            if node.last_node:
                self.result.update({"result": res})

        if inspect.isasyncgen(res):
            async for item in res:
                await self.queue.put(item)

    async def execute_workflow(self, input_args: Dict[str, Any], stream: bool = False):
        """Execute every node, running nodes of the same level concurrently.

        A node's level is one more than the highest level among its
        dependencies (0 when it has none). Levels are computed in
        ``sorted_nodes`` order, so a dependency listed after its dependent is
        still counted at level 0.

        Args:
            input_args: Initial values merged into the shared result dict.
            stream: When true, finish by enqueuing an "end" event followed by
                ``None`` on ``self.queue`` and return ``None``; otherwise
                return the ``"result"`` entry of the shared result dict.

        Returns:
            ``self.result["result"]`` when ``stream`` is false, else ``None``.

        Raises:
            ValueError: If the workflow has no nodes.
            KeyError: If ``stream`` is false and no ``"result"`` entry exists.
            Exception: Whatever a node raised; it is printed and re-raised.
        """
        self.result.update(input_args)

        level_map = {node.name: 0 for node in self.sorted_nodes}
        for node in self.sorted_nodes:
            if node.dependencies:
                level_map[node.name] = (
                    max(level_map[dep] for dep in node.dependencies) + 1
                )
        max_level = max(level_map.values())

        # Group nodes by level in one pass (same order as ``level_map``)
        # instead of rescanning the whole map for every level.
        nodes_by_level: Dict[int, List[Node]] = {}
        for name, lvl in level_map.items():
            nodes_by_level.setdefault(lvl, []).append(self.node_map[name])

        for level in range(max_level + 1):
            tasks = [
                self.execute_node(node) for node in nodes_by_level.get(level, [])
            ]
            try:
                await asyncio.gather(*tasks)
            except Exception as e:
                print(f"Error executing level {level}: {e}")
                if stream:
                    # Still send the terminator on failure so that a consumer
                    # draining the queue until ``None`` is not left waiting
                    # forever (the consumer is not visible in this file).
                    await self.queue.put(None)
                raise

        if stream:
            await self.queue.put(
                {
                    "event": "end",
                    "data": json.dumps(
                        {"status": "done", "answer": "[Done]"}, ensure_ascii=False
                    ),
                }
            )
            await self.queue.put(None)
        else:
            return self.result["result"]