workflow.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import asyncio
  2. from dataclasses import dataclass, field
  3. import inspect
  4. import json
  5. from typing import Any, Dict, List
  6. import app.operators
  7. from app.pipeline.pipeline_manager import PipelineManager
  8. function_map = {name: getattr(app.operators, name) for name in dir(app.operators) if callable(getattr(app.operators, name))}
  9. @dataclass
  10. class Node(object):
  11. name: str
  12. dependencies: List[str]
  13. result: Any = None
  14. extra_fields: Dict[str, Any] = field(default_factory=dict)
  15. def __init__(self, name: str, dependencies: List[str], **kwargs):
  16. self.name = name
  17. self.dependencies = dependencies
  18. self.extra_fields = kwargs
  19. self.last_node = kwargs.get("last_node", False)
  20. class Workflow:
  21. def __init__(self) -> None:
  22. self.sorted_nodes = [Node(name=name, **PipelineManager().node_map[name]) for name in PipelineManager().nodes]
  23. self.node_map: Dict[str, Node] = {node.name: node for node in self.sorted_nodes}
  24. self.result = {}
  25. self.lock = asyncio.Lock()
  26. self.queue = asyncio.Queue()
  27. async def execute_node(self, node: Node):
  28. func = function_map[node.name]
  29. res = func(self.result)
  30. if inspect.isawaitable(res):
  31. try:
  32. res = await res
  33. except Exception as e:
  34. res = {"error": str(e)}
  35. node.result = res
  36. async with self.lock:
  37. self.result.update({node.name: res})
  38. if node.last_node:
  39. self.result.update({"result": res})
  40. if inspect.isasyncgen(res):
  41. async for item in res:
  42. await self.queue.put(item)
  43. async def execute_workflow(self, input_args: Dict[str, Any], stream = False):
  44. self.result.update(input_args)
  45. level_map = {node.name: 0 for node in self.sorted_nodes}
  46. for node in self.sorted_nodes:
  47. if node.dependencies:
  48. level_map[node.name] = max([level_map[dep] for dep in node.dependencies]) + 1
  49. max_level = max(level_map.values())
  50. for level in range(max_level + 1):
  51. current_level_nodes = [self.node_map[name] for name, lvl in level_map.items() if lvl == level]
  52. tasks = [self.execute_node(node) for node in current_level_nodes]
  53. try:
  54. await asyncio.gather(*tasks)
  55. except Exception as e:
  56. print(f"Error executing level {level}: {e}")
  57. raise e
  58. if stream:
  59. await self.queue.put({"event": "end", "data": json.dumps({"status": "done", "answer": "[Done]"}, ensure_ascii=False)})
  60. await self.queue.put(None)
  61. else:
  62. return self.result["result"]