| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- import asyncio
- from dataclasses import dataclass, field
- import inspect
- import json
- from typing import Any, Dict, List
- import app.operators
- from app.pipeline.pipeline_manager import PipelineManager
# Registry of every callable attribute exposed by ``app.operators`` (as listed
# by ``dir``), keyed by attribute name; pipeline node names are looked up here.
function_map = {
    attr_name: attr
    for attr_name, attr in (
        (candidate, getattr(app.operators, candidate)) for candidate in dir(app.operators)
    )
    if callable(attr)
}
@dataclass(init=False)
class Node:
    """One step of the workflow graph.

    ``init=False`` makes explicit that the hand-written ``__init__`` below is
    used (``@dataclass`` never overrides an ``__init__`` defined in the class);
    the dataclass still provides ``__repr__`` and ``__eq__`` over the fields.

    Attributes:
        name: Node name; also the key used to look up its operator.
        dependencies: Names of nodes that must finish before this one.
        result: Output of the node's operator, ``None`` until it has run.
        extra_fields: Every extra keyword argument given to the constructor
            (including ``last_node`` when supplied).
        last_node: Whether this node's output is the workflow's final result.
    """

    name: str
    dependencies: List[str]
    result: Any = None
    extra_fields: Dict[str, Any] = field(default_factory=dict)
    last_node: bool = False

    def __init__(self, name: str, dependencies: List[str], **kwargs: Any) -> None:
        """Create a node.

        Args:
            name: Node name.
            dependencies: Names of upstream nodes; ``None`` or empty means the
                node has no dependencies. The sequence is copied so later
                changes to the caller's list do not affect the node.
            **kwargs: Arbitrary extra configuration, kept in ``extra_fields``;
                a truthy ``last_node`` marks the final node.
        """
        self.name = name
        self.dependencies = list(dependencies) if dependencies else []
        # Set per instance so it shows up in vars()/repr and is never shared.
        self.result = None
        self.extra_fields = kwargs
        self.last_node = bool(kwargs.get("last_node", False))
class Workflow:
    """Run the configured pipeline's operators as a dependency-ordered DAG.

    Nodes come from ``PipelineManager``; each node name is resolved to an
    operator through the module-level ``function_map``. Nodes are grouped into
    levels (0 for nodes without dependencies, otherwise one more than their
    deepest dependency). All nodes of a level run concurrently, and a level
    starts only after the previous one has finished.
    """

    def __init__(self) -> None:
        # Instantiate the manager once and reuse it for every node lookup.
        manager = PipelineManager()
        self.sorted_nodes: List[Node] = [
            Node(name=name, **manager.node_map[name]) for name in manager.nodes
        ]
        self.node_map: Dict[str, Node] = {node.name: node for node in self.sorted_nodes}
        # Shared blackboard: the input args, each node's output under its
        # name, and the final node's output under "result".
        # NOTE(review): never cleared, so reusing one Workflow for several runs
        # carries state over between runs.
        self.result: Dict[str, Any] = {}
        # Held while a node publishes its output and drains its async generator.
        self.lock = asyncio.Lock()
        # Items streamed by async-generator operators; in stream mode a run
        # ends with an "end" event (on success) and a ``None`` sentinel.
        self.queue: asyncio.Queue = asyncio.Queue()

    async def execute_node(self, node: Node):
        """Run the operator for *node* and publish its output.

        The operator is called with the shared ``self.result`` dict. An
        awaitable return value is awaited; if awaiting raises, the output
        becomes ``{"error": str(exc)}``. The output is stored on the node,
        under ``self.result[node.name]`` and, for the last node, under
        ``self.result["result"]``. Async-generator outputs are additionally
        drained item by item into ``self.queue``.

        Raises:
            KeyError: if ``function_map`` has no entry for ``node.name``.
        """
        func = function_map[node.name]
        # NOTE(review): exceptions raised synchronously by ``func`` (or while
        # draining an async generator) propagate, whereas errors from awaited
        # results are converted into an error dict — confirm this is intended.
        res = func(self.result)
        if inspect.isawaitable(res):
            try:
                res = await res
            except Exception as e:
                res = {"error": str(e)}
        node.result = res
        # The lock stays held while streaming, so async-generator nodes of the
        # same level emit their items one node after another, never interleaved.
        async with self.lock:
            self.result[node.name] = res
            if node.last_node:
                self.result["result"] = res
            if inspect.isasyncgen(res):
                async for item in res:
                    await self.queue.put(item)

    def _group_by_level(self) -> List[List[Node]]:
        """Return nodes grouped by depth; index ``i`` holds the level-``i`` nodes.

        Depths are computed recursively from ``dependencies``, so the result
        does not depend on ``sorted_nodes`` being topologically ordered. Within
        a level, nodes keep their first-occurrence order from ``sorted_nodes``.

        Raises:
            KeyError: a dependency is not a known node name.
            ValueError: the dependencies form a cycle.
        """
        levels: Dict[str, int] = {}
        visiting = set()

        def depth(name: str) -> int:
            # Memoized longest-path depth; ``visiting`` detects back edges.
            if name in levels:
                return levels[name]
            if name in visiting:
                raise ValueError(f"Dependency cycle detected at node {name!r}")
            visiting.add(name)
            deps = self.node_map[name].dependencies or []
            levels[name] = max((depth(dep) for dep in deps), default=-1) + 1
            visiting.discard(name)
            return levels[name]

        grouped: List[List[Node]] = []
        for name in dict.fromkeys(node.name for node in self.sorted_nodes):
            level = depth(name)
            while len(grouped) <= level:
                grouped.append([])
            grouped[level].append(self.node_map[name])
        return grouped

    async def execute_workflow(self, input_args: Dict[str, Any], stream: bool = False):
        """Execute every node, level by level.

        Args:
            input_args: Initial values merged into ``self.result`` before any
                node runs.
            stream: When true, enqueue an "end" event after a successful run
                and always finish with a ``None`` sentinel, even on failure, so
                a queue consumer never blocks forever; the method then returns
                ``None``.

        Returns:
            ``self.result["result"]`` (the last node's output) when *stream*
            is false.

        Raises:
            KeyError: a dependency names an unknown node, or (non-stream mode)
                no node is marked ``last_node``.
            ValueError: the dependency graph contains a cycle.
            Exception: anything a node raises is printed and re-raised.
        """
        self.result.update(input_args)
        try:
            for level, nodes in enumerate(self._group_by_level()):
                try:
                    await asyncio.gather(*(self.execute_node(node) for node in nodes))
                except Exception as e:
                    print(f"Error executing level {level}: {e}")
                    raise
            if stream:
                await self.queue.put({"event": "end", "data": json.dumps({"status": "done", "answer": "[Done]"}, ensure_ascii=False)})
        finally:
            if stream:
                await self.queue.put(None)
        if not stream:
            return self.result["result"]
|