Add support for executing workflows in parallel, allowing users to set the number of concurrent workflows via the --parallel argument #16
Parallel workflow execution

1. Per-prompt progress registry isolation (comfy_execution/progress.py)

| Before | After |
|---|---|
| global_progress_registry (global singleton) | _progress_registries: Dict[str, ProgressRegistry] (isolated per prompt_id) |
| get_progress_state() → always returns the same registry | get_progress_state(prompt_id) → returns the registry for that prompt |
| reset_progress_state() overwrites the single global registry | reset_progress_state(prompt_id, dynprompt, client_id=...) creates a new per-prompt registry without touching the others |
| No cleanup mechanism | remove_progress_state(prompt_id) cleans up after the prompt completes |

- Added a threading.Lock to keep the registry dict thread-safe
- ProgressRegistry now also stores client_id, prompt_id, and dynprompt
- The WebUIProgressHandler constructor accepts a client_id parameter instead of reading server_instance.client_id
- get_progress_state() without a prompt_id still works (backward-compatible fallback)
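The registry changes above can be sketched as follows. This is a minimal, self-contained illustration, not ComfyUI's actual implementation; the `ProgressRegistry` fields shown here are assumptions.

```python
import threading
from typing import Dict, Optional

class ProgressRegistry:
    """Stand-in for ComfyUI's registry; stores per-prompt routing info."""
    def __init__(self, prompt_id: str, client_id: Optional[str] = None):
        self.prompt_id = prompt_id
        self.client_id = client_id
        self.nodes: Dict[str, float] = {}  # node_id -> progress fraction

_registries: Dict[str, ProgressRegistry] = {}
_registries_lock = threading.Lock()  # guards the dict, not the registries

def reset_progress_state(prompt_id: str, client_id: Optional[str] = None) -> ProgressRegistry:
    """Create a fresh registry for this prompt without affecting others."""
    with _registries_lock:
        reg = ProgressRegistry(prompt_id, client_id)
        _registries[prompt_id] = reg
        return reg

def get_progress_state(prompt_id: Optional[str] = None) -> Optional[ProgressRegistry]:
    with _registries_lock:
        if prompt_id is not None:
            return _registries.get(prompt_id)
        # Backward-compatible fallback: any registry (unsafe under parallelism)
        return next(iter(_registries.values()), None)

def remove_progress_state(prompt_id: str) -> None:
    """Drop the registry once the prompt has finished executing."""
    with _registries_lock:
        _registries.pop(prompt_id, None)
```

Each worker calls `reset_progress_state` when it picks up a prompt and `remove_progress_state` when execution finishes, so concurrent prompts never share a registry.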
2. Per-execution-instance client_id (execution.py)

| Before | After |
|---|---|
| server.client_id (global, shared across workers) | self.client_id on the PromptExecutor instance (independent per worker) |
| _send_cached_ui() reads server.client_id | _send_cached_ui() accepts a client_id parameter |
| execute() reads server.client_id | execute() accepts a client_id parameter |
| add_message() reads self.server.client_id | add_message() reads self.client_id |

All server.send_sync() calls in the execution chain now use the per-execution client_id rather than the global server.client_id. The global server.client_id is still updated in sync for backward compatibility (e.g. WebSocket reconnect scenarios).
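A compact sketch of the routing change just described: messages are sent with the client_id captured at queue time rather than a shared server field. Class and method names here are illustrative stand-ins, not ComfyUI's exact signatures.

```python
class FakeServer:
    """Minimal stand-in for PromptServer, recording (event, target) pairs."""
    def __init__(self):
        self.sent = []
        self.client_id = None  # legacy global, still updated for compatibility

    def send_sync(self, event, data, sid=None):
        self.sent.append((event, sid))

class PromptExecutor:
    def __init__(self, server):
        self.server = server
        self.client_id = None  # per-executor, set at the start of execute()

    def execute(self, prompt_id, client_id):
        self.client_id = client_id
        self.server.client_id = client_id  # backward-compat global write
        # Every message in this execution targets this executor's client.
        self.server.send_sync("executing", {"prompt_id": prompt_id}, sid=self.client_id)
```

With two executors sharing one server, each `send_sync` call carries its own executor's client_id, so concurrent executions no longer race on a single global field for routing (though, as noted, the global write itself still races).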
3. Caller adaptations (main.py, comfy_api/latest/__init__.py)

- The hijack_progress() hook now uses get_progress_state(prompt_id) and reads client_id from the per-prompt registry
- prompt_worker() uses e.client_id instead of server_instance.client_id when sending the "execution complete" message
- ComfyAPI's Execution.set_progress() extracts prompt_id from the executing context and passes it through
4. CLI argument and multi-worker startup (comfy/cli_args.py, main.py)

- Added a --parallel N argument, defaulting to 1 (identical to current behavior)
- Starts N prompt_worker threads, using named threads (prompt_worker-0, prompt_worker-1, ...)
- PromptQueue.get() already supports multiple consumers (it uses an RLock plus a Condition)
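The startup loop above can be sketched with a plain `queue.Queue` standing in for ComfyUI's PromptQueue; `prompt_worker` here is a simplified stand-in, not the real worker function.

```python
import queue
import threading

def prompt_worker(q: "queue.Queue", results: list, lock: threading.Lock):
    """Consume prompts until a None sentinel arrives."""
    while True:
        item = q.get()
        if item is None:  # sentinel: shut this worker down
            q.task_done()
            return
        with lock:
            results.append((threading.current_thread().name, item))
        q.task_done()

def start_workers(q, results, lock, num_workers: int):
    """Launch max(1, N) named daemon worker threads, mirroring the PR's loop."""
    workers = []
    for i in range(max(1, num_workers)):
        t = threading.Thread(target=prompt_worker, daemon=True,
                             args=(q, results, lock), name=f"prompt_worker-{i}")
        t.start()
        workers.append(t)
    return workers
```

Because `queue.Queue.get()` is already safe for multiple consumers, no extra coordination is needed between workers beyond the lock protecting shared results.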
Test results

Using the test script (comfyui-parallel-workers-example.py), running with --parallel 6 against the BizyAir NanoBananaPro node:
(bizyair) rua@MateBook-Air script_examples % python comfyui-parallel-workers-example.py 6
=== ComfyUI Parallel Test ===
Server: 127.0.0.1:9999
Parallel workflows: 6
Workflow: /Users/rua/Coding/comfyui-related/ComfyUI/script_examples/workflow_1.json
[Worker 5] Queued prompt_id=adbf8703... client_id=9affc286...
[Worker 0] Queued prompt_id=b51d74c3... client_id=07af7953...
[Worker 1] Queued prompt_id=121bd30c... client_id=4aaffa93...
[Worker 2] Queued prompt_id=d7fba041... client_id=fc8b95b2...
[Worker 4] Queued prompt_id=e90b64d8... client_id=bca7c271...
[Worker 3] Queued prompt_id=874cdfe7... client_id=89404a4a...
=== Results ===
[Worker 0] ✅ prompt_id=b51d74c3... duration=45.5s images=1
└─ node=7 output/ComfyUI_00025_.png (2130.1 KB)
[Worker 1] ✅ prompt_id=121bd30c... duration=46.6s images=1
└─ node=7 output/ComfyUI_00026_.png (2196.5 KB)
[Worker 2] ✅ prompt_id=d7fba041... duration=47.6s images=1
└─ node=7 output/ComfyUI_00027_.png (2199.0 KB)
[Worker 3] ✅ prompt_id=874cdfe7... duration=91.6s images=1
└─ node=7 output/ComfyUI_00029_.png (2051.3 KB)
[Worker 4] ✅ prompt_id=e90b64d8... duration=69.3s images=1
└─ node=7 output/ComfyUI_00028_.png (2113.4 KB)
[Worker 5] ✅ prompt_id=adbf8703... duration=35.0s images=1
└─ node=7 output/ComfyUI_00024_.png (2092.2 KB)
Total wall time: 91.6s
Sum of individual times: 335.6s
Speedup: 3.66x
Success: 6/6
🟢 PASS: All workflows succeeded and ran in parallel
Analysis:
- All 6/6 workflows executed successfully, each returning the correct image output
- 3.66x speedup across 6 parallel workflows (variance in BizyAir server-side latency kept it well below the ideal 6x)
- Individual workflows took 35.0–91.6s, while total wall time was only 91.6s — confirming genuinely concurrent execution
- Each client received its own progress messages and outputs over an independent WebSocket connection (no cross-contamination)
Changed files

| File | Changes |
|---|---|
| comfy/cli_args.py | Added the --parallel argument |
| comfy_execution/progress.py | Per-prompt_id registry dict; ProgressRegistry and WebUIProgressHandler gain client_id; new remove_progress_state(); added a threading.Lock |
| execution.py | PromptExecutor gains self.client_id; execute() and _send_cached_ui() gain a client_id parameter; remove_progress_state() cleanup |
| main.py | Multi-worker startup loop; hijack_progress() uses the per-prompt client_id |
| comfy_api/latest/__init__.py | Passes prompt_id to get_progress_state() |
Known limitations

| Issue | Impact | Mitigation |
|---|---|---|
| /interrupt is a global operation — it interrupts all executing workflows | Medium | Could switch to per-client_id interruption in a follow-up PR |
| GPU model loading is not isolated — parallel GPU workflows risk OOM | Low (network-bound workflows unaffected) | The GPU lock in model_management serializes model load/unload; the RLock prevents deadlock when load_models_gpu internally calls free_memory |
| get_flags(reset=True) — only the first worker sees the free_memory/unload_models flags | Low | Global operations (unload_all_models) only need to run once |
| Custom nodes with global mutable state may not be thread-safe | Node-dependent | Needs to be called out in the documentation |
| When nodes share the same seed, caching makes completion times nearly identical, so output filename counters collide; in the extreme case a parallel workflow's file is overwritten by another workflow's output | High | Rework the output-file counter lookup logic |
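One way to address the last limitation is to serialize the "find the next free index" step behind a process-wide lock. This is a hypothetical sketch, not ComfyUI's actual SaveImage counter code; the function name and caching strategy are assumptions.

```python
import os
import threading

_counter_lock = threading.Lock()
_next_index: dict = {}  # (directory, prefix) -> next free counter

def reserve_output_path(directory: str, prefix: str) -> str:
    """Atomically reserve the next free ComfyUI-style output filename.

    The directory is scanned once per (directory, prefix) pair to seed the
    counter; afterwards the counter advances in memory under the lock, so
    two parallel workflows can never reserve the same name.
    """
    with _counter_lock:
        key = (directory, prefix)
        if key not in _next_index:
            existing = [
                f for f in os.listdir(directory)
                if f.startswith(prefix + "_") and f.endswith(".png")
            ]
            _next_index[key] = len(existing) + 1
        counter = _next_index[key]
        _next_index[key] = counter + 1
    return os.path.join(directory, f"{prefix}_{counter:05d}_.png")
```

Reserving the name (rather than computing it at write time) is what removes the race: the collision window between "scan for the highest index" and "write the file" disappears.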
import random
import websocket
import uuid
import json
import copy
import urllib.request
import urllib.parse
import threading
import time
import sys
import os
server_address = "127.0.0.1:9999"
parallel_count = int(sys.argv[1]) if len(sys.argv) > 1 else 4
workflow_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "workflow_1.json")
def queue_prompt(prompt, prompt_id, client_id):
p = {"prompt": prompt, "client_id": client_id, "prompt_id": prompt_id}
data = json.dumps(p).encode('utf-8')
req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
return json.loads(urllib.request.urlopen(req).read())
def interrupt_prompt(prompt_id=None):
data = json.dumps({"prompt_id": prompt_id}).encode('utf-8') if prompt_id else b'{}'
req = urllib.request.Request("http://{}/interrupt".format(server_address), data=data, method='POST')
req.add_header('Content-Type', 'application/json')
try:
return urllib.request.urlopen(req).read()
except urllib.error.HTTPError as e:
return e.read()
def get_image(filename, subfolder, folder_type):
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
return response.read()
def get_history(prompt_id):
with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
return json.loads(response.read())
def get_queue():
with urllib.request.urlopen("http://{}/queue".format(server_address)) as response:
return json.loads(response.read())
def run_single_workflow(index, prompt, results):
client_id = str(uuid.uuid4())
prompt_id = str(uuid.uuid4())
ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
    ws.settimeout(300)  # guard against hanging forever if the server goes silent
start_time = time.time()
try:
resp = queue_prompt(prompt, prompt_id, client_id)
print(f"[Worker {index}] Queued prompt_id={prompt_id[:8]}... client_id={client_id[:8]}...")
while True:
out = ws.recv()
if isinstance(out, str):
message = json.loads(out)
if message['type'] == 'executing':
data = message['data']
if data['node'] is None and data['prompt_id'] == prompt_id:
break
elif message['type'] == 'execution_interrupted':
data = message['data']
if data['prompt_id'] == prompt_id:
results[index] = {
"success": False,
"interrupted": True,
"prompt_id": prompt_id,
"client_id": client_id,
"duration": time.time() - start_time,
}
return
elif message['type'] == 'execution_error':
data = message['data']
if data['prompt_id'] == prompt_id:
results[index] = {
"success": False,
"prompt_id": prompt_id,
"client_id": client_id,
"error": data.get("exception_message", "Unknown error"),
"duration": time.time() - start_time,
}
return
else:
continue
history = get_history(prompt_id)
if prompt_id not in history:
results[index] = {
"success": False,
"prompt_id": prompt_id,
"client_id": client_id,
"error": "No history found",
"duration": time.time() - start_time,
}
return
outputs = history[prompt_id]['outputs']
images_found = 0
output_files = []
for node_id, node_output in outputs.items():
if 'images' in node_output:
for image in node_output['images']:
image_data = get_image(image['filename'], image['subfolder'], image['type'])
images_found += 1
output_files.append({
"node_id": node_id,
"filename": image['filename'],
"subfolder": image.get('subfolder', ''),
"type": image.get('type', ''),
"size_bytes": len(image_data),
})
status = history[prompt_id].get('status', {})
status_str = status.get('status_str', 'unknown')
results[index] = {
"success": status_str == 'success' and images_found > 0,
"prompt_id": prompt_id,
"client_id": client_id,
"images_count": images_found,
"output_files": output_files,
"status": status_str,
"duration": time.time() - start_time,
}
except Exception as e:
results[index] = {
"success": False,
"prompt_id": prompt_id,
"client_id": client_id,
"error": str(e),
"duration": time.time() - start_time,
}
finally:
ws.close()
def test_parallel_interrupt():
print("=" * 60)
print("TEST: Targeted interrupt - only the specified prompt is interrupted")
print("=" * 60)
prompt = json.load(open(workflow_path))
num_workflows = 3
results = [None] * num_workflows
prompt_ids = [None] * num_workflows
executing_flags = [threading.Event() for _ in range(num_workflows)]
threads = []
def run_worker(index, prompt, results):
client_id = str(uuid.uuid4())
prompt_id = str(uuid.uuid4())
prompt_ids[index] = prompt_id
ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
ws.settimeout(120)
start_time = time.time()
try:
resp = queue_prompt(prompt, prompt_id, client_id)
print(f" [W{index}] Queued prompt_id={prompt_id[:8]}...")
while True:
try:
out = ws.recv()
except websocket.WebSocketTimeoutException:
results[index] = {
"success": False, "interrupted": False,
"prompt_id": prompt_id, "error": "timeout",
"duration": time.time() - start_time,
}
return
if isinstance(out, str):
message = json.loads(out)
if message['type'] == 'execution_start':
executing_flags[index].set()
elif message['type'] == 'execution_interrupted':
data = message['data']
if data['prompt_id'] == prompt_id:
results[index] = {
"success": False, "interrupted": True,
"prompt_id": prompt_id,
"duration": time.time() - start_time,
}
return
elif message['type'] == 'executing':
data = message['data']
executing_flags[index].set()
if data['node'] is None and data['prompt_id'] == prompt_id:
results[index] = {
"success": True, "interrupted": False,
"prompt_id": prompt_id,
"duration": time.time() - start_time,
}
return
elif message['type'] == 'execution_error':
data = message['data']
if data['prompt_id'] == prompt_id:
results[index] = {
"success": False, "interrupted": False,
"prompt_id": prompt_id,
"error": data.get("exception_message", "Unknown error"),
"duration": time.time() - start_time,
}
return
except Exception as e:
results[index] = {
"success": False, "interrupted": False,
"prompt_id": prompt_id, "error": str(e),
"duration": time.time() - start_time,
}
finally:
ws.close()
for i in range(num_workflows):
t = threading.Thread(target=run_worker, args=(i, prompt.copy(), results))
threads.append(t)
t.start()
for flag in executing_flags:
flag.wait(timeout=30)
time.sleep(3)
target_index = 0
target_pid = prompt_ids[target_index]
print(f" >>> Interrupting prompt_id={target_pid[:8]}... (Worker {target_index})")
interrupt_prompt(target_pid)
for t in threads:
t.join(timeout=120)
victim_interrupted = results[target_index] is not None and results[target_index].get("interrupted")
bystander_not_interrupted = all(
results[i] is not None and not results[i].get("interrupted")
for i in range(num_workflows) if i != target_index
)
print(f"\n Results:")
for i, r in enumerate(results):
if r is None:
print(f" [W{i}] NO RESULT")
continue
if r.get("interrupted"):
print(f" [W{i}] ⚡ INTERRUPTED prompt_id={r['prompt_id'][:8]}... duration={r['duration']:.1f}s")
elif r.get("success"):
print(f" [W{i}] ✅ COMPLETED prompt_id={r['prompt_id'][:8]}... duration={r['duration']:.1f}s")
else:
print(f" [W{i}] ❌ ERROR prompt_id={r['prompt_id'][:8]}... error={r.get('error', 'unknown')}")
test_passed = victim_interrupted and bystander_not_interrupted
print(f"\n Target (W{target_index}) interrupted: {'✅' if victim_interrupted else '❌'}")
print(f" Bystanders not interrupted: {'✅' if bystander_not_interrupted else '❌'}")
print(f" {'🟢 PASS' if test_passed else '🔴 FAIL'}: Targeted interrupt")
return test_passed
def test_global_interrupt():
print("\n" + "=" * 60)
print("TEST: Global interrupt - all running prompts are interrupted")
print("=" * 60)
prompt = json.load(open(workflow_path))
num_workflows = 3
results = [None] * num_workflows
prompt_ids = [None] * num_workflows
executing_flags = [threading.Event() for _ in range(num_workflows)]
threads = []
def run_worker(index, prompt, results):
client_id = str(uuid.uuid4())
prompt_id = str(uuid.uuid4())
prompt_ids[index] = prompt_id
ws = websocket.WebSocket()
ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
ws.settimeout(120)
start_time = time.time()
try:
resp = queue_prompt(prompt, prompt_id, client_id)
print(f" [W{index}] Queued prompt_id={prompt_id[:8]}...")
while True:
try:
out = ws.recv()
except websocket.WebSocketTimeoutException:
results[index] = {
"success": False, "interrupted": False,
"prompt_id": prompt_id, "error": "timeout",
"duration": time.time() - start_time,
}
return
if isinstance(out, str):
message = json.loads(out)
if message['type'] == 'execution_start':
executing_flags[index].set()
elif message['type'] == 'execution_interrupted':
data = message['data']
if data['prompt_id'] == prompt_id:
results[index] = {
"success": False, "interrupted": True,
"prompt_id": prompt_id,
"duration": time.time() - start_time,
}
return
elif message['type'] == 'executing':
data = message['data']
executing_flags[index].set()
if data['node'] is None and data['prompt_id'] == prompt_id:
results[index] = {
"success": True, "interrupted": False,
"prompt_id": prompt_id,
"duration": time.time() - start_time,
}
return
elif message['type'] == 'execution_error':
data = message['data']
if data['prompt_id'] == prompt_id:
results[index] = {
"success": False, "interrupted": False,
"prompt_id": prompt_id,
"error": data.get("exception_message", "Unknown error"),
"duration": time.time() - start_time,
}
return
except Exception as e:
results[index] = {
"success": False, "interrupted": False,
"prompt_id": prompt_id, "error": str(e),
"duration": time.time() - start_time,
}
finally:
ws.close()
for i in range(num_workflows):
t = threading.Thread(target=run_worker, args=(i, prompt.copy(), results))
threads.append(t)
t.start()
for flag in executing_flags:
flag.wait(timeout=30)
time.sleep(3)
print(f" >>> Global interrupt (no prompt_id)")
interrupt_prompt()
for t in threads:
t.join(timeout=120)
all_interrupted = all(r is not None and r.get("interrupted") for r in results)
none_completed = not any(r is not None and r.get("success") for r in results)
print(f"\n Results:")
for i, r in enumerate(results):
if r is None:
print(f" [W{i}] NO RESULT")
continue
if r.get("interrupted"):
print(f" [W{i}] ⚡ INTERRUPTED prompt_id={r['prompt_id'][:8]}... duration={r['duration']:.1f}s")
elif r.get("success"):
print(f" [W{i}] ✅ COMPLETED prompt_id={r['prompt_id'][:8]}... duration={r['duration']:.1f}s")
else:
print(f" [W{i}] ❌ ERROR prompt_id={r['prompt_id'][:8]}... error={r.get('error', 'unknown')}")
test_passed = all_interrupted and none_completed
print(f"\n All workflows interrupted: {'✅' if all_interrupted else '❌'}")
print(f" None completed normally: {'✅' if none_completed else '❌'}")
print(f" {'🟢 PASS' if test_passed else '🔴 FAIL'}: Global interrupt")
return test_passed
def test_interrupt_no_race():
print("\n" + "=" * 60)
print("TEST: Interrupt signal not eaten by new prompt starting")
print("=" * 60)
print(" Scenario: Submit A, interrupt A, submit B immediately.")
print(" B's startup should NOT clear A's interrupt signal.")
prompt = json.load(open(workflow_path))
client_id_a = str(uuid.uuid4())
prompt_id_a = str(uuid.uuid4())
ws_a = websocket.WebSocket()
ws_a.connect("ws://{}/ws?clientId={}".format(server_address, client_id_a))
ws_a.settimeout(60)
result_a = {"done": False, "interrupted": False}
a_executing = threading.Event()
try:
queue_prompt(prompt.copy(), prompt_id_a, client_id_a)
print(f" [A] Queued prompt_id={prompt_id_a[:8]}...")
while not a_executing.is_set():
try:
out = ws_a.recv()
except websocket.WebSocketTimeoutException:
break
if isinstance(out, str):
message = json.loads(out)
if message['type'] in ('execution_start', 'executing'):
a_executing.set()
print(f" >>> Interrupting A (prompt_id={prompt_id_a[:8]}...)")
interrupt_prompt(prompt_id_a)
client_id_b = str(uuid.uuid4())
prompt_id_b = str(uuid.uuid4())
queue_prompt(prompt.copy(), prompt_id_b, client_id_b)
print(f" [B] Queued prompt_id={prompt_id_b[:8]}... (immediately after interrupt)")
while True:
try:
out = ws_a.recv()
except websocket.WebSocketTimeoutException:
break
if isinstance(out, str):
message = json.loads(out)
if message['type'] == 'execution_interrupted':
if message['data']['prompt_id'] == prompt_id_a:
result_a["interrupted"] = True
result_a["done"] = True
break
elif message['type'] == 'executing':
data = message['data']
if data['node'] is None and data['prompt_id'] == prompt_id_a:
result_a["done"] = True
break
finally:
ws_a.close()
test_passed = result_a["interrupted"]
print(f"\n A was interrupted (not silently cleared): {'✅' if result_a['interrupted'] else '❌'}")
print(f" {'🟢 PASS' if test_passed else '🔴 FAIL'}: Interrupt not eaten by new prompt")
return test_passed
def test_parallel_basic():
prompt = json.load(open(workflow_path))
print(f"=== ComfyUI Parallel Test ===")
print(f"Server: {server_address}")
print(f"Parallel workflows: {parallel_count}")
print(f"Workflow: {workflow_path}")
print()
results = [None] * parallel_count
threads = []
global_start = time.time()
for i in range(parallel_count):
prompt_i = prompt.copy()
# prompt_i["5"]["inputs"]["seed"] = int(time.time()) + random.randint(0, 100000) + i
t = threading.Thread(target=run_single_workflow, args=(i, prompt_i, results))
threads.append(t)
t.start()
for t in threads:
t.join()
global_duration = time.time() - global_start
print()
print(f"=== Results ===")
success_count = 0
for i, r in enumerate(results):
if r is None:
print(f"[Worker {i}] NO RESULT")
continue
status_icon = "✅" if r["success"] else "❌"
print(f"[Worker {i}] {status_icon} prompt_id={r['prompt_id'][:8]}... duration={r['duration']:.1f}s", end="")
if r["success"]:
print(f" images={r['images_count']}")
for f in r.get("output_files", []):
size_kb = f["size_bytes"] / 1024
subfolder = f"/{f['subfolder']}" if f['subfolder'] else ''
print(f" └─ node={f['node_id']} {f['type']}{subfolder}/{f['filename']} ({size_kb:.1f} KB)")
success_count += 1
else:
print(f" error={r.get('error', 'unknown')}")
sequential_time = sum(r['duration'] for r in results if r is not None)
speedup = sequential_time / global_duration if global_duration > 0 else 0
print()
print(f"Total wall time: {global_duration:.1f}s")
print(f"Sum of individual times: {sequential_time:.1f}s")
print(f"Speedup: {speedup:.2f}x")
print(f"Success: {success_count}/{parallel_count}")
if success_count == parallel_count and speedup > 1.3:
print("\n🟢 PASS: All workflows succeeded and ran in parallel")
elif success_count == parallel_count:
print("\n🟡 PASS: All workflows succeeded but parallel speedup is low")
else:
print("\n🔴 FAIL: Some workflows did not succeed")
return success_count == parallel_count
def main():
mode = sys.argv[2] if len(sys.argv) > 2 else "interrupt"
all_passed = True
if mode == "parallel":
all_passed = test_parallel_basic()
elif mode == "interrupt":
results = []
results.append(("Targeted interrupt", test_parallel_interrupt()))
results.append(("Global interrupt", test_global_interrupt()))
results.append(("Interrupt not eaten", test_interrupt_no_race()))
print("\n" + "=" * 60)
print("INTERRUPT TEST SUMMARY")
print("=" * 60)
for name, passed in results:
icon = "🟢" if passed else "🔴"
print(f" {icon} {name}: {'PASS' if passed else 'FAIL'}")
all_passed = all(passed for _, passed in results)
elif mode == "all":
r1 = test_parallel_basic()
r2 = test_parallel_interrupt()
r3 = test_global_interrupt()
r4 = test_interrupt_no_race()
print("\n" + "=" * 60)
print("FULL TEST SUMMARY")
print("=" * 60)
for name, passed in [("Parallel basic", r1), ("Targeted interrupt", r2), ("Global interrupt", r3), ("Interrupt not eaten", r4)]:
icon = "🟢" if passed else "🔴"
print(f" {icon} {name}: {'PASS' if passed else 'FAIL'}")
all_passed = r1 and r2 and r3 and r4
else:
print(f"Unknown mode: {mode}")
print("Usage: python comfyui-parallel-workers-example.py [parallel_count] [mode]")
print(" mode: parallel | interrupt | all")
return 1
return 0 if all_passed else 1
if __name__ == "__main__":
sys.exit(main())
{
"2": {
"inputs": {
"preview_markdown": "[\"https://bizyair-dev.oss-cn-shanghai.aliyuncs.com/outputs/pm6l09Y1bxEAv1IL.jpg\"]",
"preview_text": "[\"https://bizyair-dev.oss-cn-shanghai.aliyuncs.com/outputs/pm6l09Y1bxEAv1IL.jpg\"]",
"previewMode": null,
"source": [
"5",
2
]
},
"class_type": "PreviewAny",
"_meta": {
"title": "预览任意"
}
},
"3": {
"inputs": {
"image": "836668f9-322f-4fc9-a0a8-4e813ae71619.jpg"
},
"class_type": "LoadImage",
"_meta": {
"title": "加载图像"
}
},
"5": {
"inputs": {
"prompt": "一个女生在遛猫",
"operation": "generate",
"temperature": 1,
"top_p": 0.95,
"seed": 1514601691,
"max_tokens": 32768,
"aspect_ratio": "auto",
"resolution": "auto",
"quality": "high",
"character_consistency": true,
"inputcount": 1,
"Update inputs(支持多图点我)": null,
"mode": "official",
"images": [
"3",
0
]
},
"class_type": "BizyAir_NanoBananaPro",
"_meta": {
"title": "☁️BizyAir NanoBananaPro"
}
},
"7": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"5",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "保存图像"
}
}
}
The comfyagent test script is as follows:

import argparse
import copy
import json
import random
import sys
import threading
import time
from pathlib import Path
import requests
def run_worker(index, workflow, url, results):
"""Run a single workflow request against ComfyAgent's /prompt/stream SSE endpoint.
Args:
index: Worker index for identification.
workflow: A deep-copied workflow dict (seed will be randomized).
url: ComfyAgent server address (e.g. "127.0.0.1:8000").
results: Shared list to store result dict at results[index].
"""
if not url.startswith("http"):
url = f"http://{url}"
seed = random.randint(0, 1000000)
workflow["5"]["inputs"]["seed"] = seed
request_data = {"prompt": workflow}
headers = {"Content-Type": "application/json"}
    start_time = time.time()
    outputs = ""  # initialize here so the except handler below can always reference it
try:
with requests.post(
f"{url}/prompt/stream", json=request_data, headers=headers, stream=True
) as response:
if response.status_code != 200:
results[index] = {
"success": False,
"index": index,
"seed": workflow["5"]["inputs"]["seed"],
"error": f"HTTP {response.status_code}: {response.text[:200]}",
"duration": time.time() - start_time,
"outputs": "",
}
return
has_error = False
outputs = ""
for line in response.iter_lines(decode_unicode=True):
if line is None:
continue
# Detect error events from ComfyAgent SSE stream
if line.startswith("event:") and "error" in line.lower():
has_error = True
if line.startswith("data:"):
try:
data_str = line[len("data:"):].strip()
data = json.loads(data_str)
if "outputs" in data:
outputs += json.dumps(data["outputs"])
# Check for execution_error in websocket events
if isinstance(data, dict):
if data.get("type") == "execution_error":
has_error = True
if isinstance(data.get("data"), dict) and data["data"].get(
"exception_message"
):
has_error = True
except (json.JSONDecodeError, AttributeError):
pass
print(f"[Worker {index}] {line}")
duration = time.time() - start_time
results[index] = {
"success": not has_error,
"index": index,
"seed": workflow["5"]["inputs"]["seed"],
"error": None if not has_error else "Error event detected in SSE stream",
"duration": duration,
"outputs": outputs,
}
except Exception as e:
results[index] = {
"success": False,
"index": index,
"seed": workflow["5"]["inputs"]["seed"],
"error": str(e),
"duration": time.time() - start_time,
"outputs": outputs,
}
def main():
parser = argparse.ArgumentParser(
description="Test ComfyAgent + ComfyUI parallel workflow execution via /prompt/stream SSE endpoint"
)
parser.add_argument("--url", required=True, help="ComfyAgent server address (e.g. 127.0.0.1:8000)")
parser.add_argument(
"--parallel",
type=int,
default=4,
metavar="N",
help="Number of concurrent workflow requests (default: 4)",
)
args = parser.parse_args()
workflow = json.loads((Path(__file__).parent / "workflow_1.json").read_text())
print("=== ComfyAgent Parallel Test ===")
print(f"Server: {args.url}")
print(f"Parallel workflows: {args.parallel}")
print(f"Workflow: workflow_1.json")
print()
results = [None] * args.parallel
threads = []
global_start = time.time()
for i in range(args.parallel):
workflow_copy = copy.deepcopy(workflow)
t = threading.Thread(target=run_worker, args=(i, workflow_copy, args.url, results))
threads.append(t)
t.start()
for t in threads:
t.join()
global_duration = time.time() - global_start
print()
print("=== Results ===")
success_count = 0
for r in results:
if r is None:
print(f"[Worker ?] NO RESULT")
continue
idx = r["index"]
status_icon = "✅" if r["success"] else "❌"
print(
f"[Worker {idx}] {status_icon} seed={r['seed']} duration={r['duration']:.1f}s outputs={(r['outputs'])}",
end="",
)
if r["success"]:
print()
success_count += 1
else:
print(f" error={r.get('error', 'unknown')}")
sequential_time = sum(r["duration"] for r in results if r is not None)
speedup = sequential_time / global_duration if global_duration > 0 else 0
print()
print(f"Total wall time: {global_duration:.1f}s")
print(f"Sum of individual times: {sequential_time:.1f}s")
print(f"Speedup: {speedup:.2f}x")
print(f"Success: {success_count}/{args.parallel}")
if success_count == args.parallel and speedup > 1.3:
print("\n🟢 PASS: All workflows succeeded and ran in parallel")
elif success_count == args.parallel:
print("\n🟡 PASS: All workflows succeeded but parallel speedup is low")
else:
print("\n🔴 FAIL: Some workflows did not succeed")
return 0 if success_count == args.parallel else 1
if __name__ == "__main__":
    sys.exit(main())
Pull request overview

This PR adds the ability to execute workflows in parallel in ComfyUI. A new --parallel N CLI argument starts multiple prompt_worker threads to improve throughput in network/IO-bound scenarios, and progress state moves from a global singleton to per-prompt_id isolation so progress/preview delivery works correctly under parallel execution.

Changes:
- Added the --parallel argument; at startup, spawns that many prompt_worker threads.
- Progress registries are now stored and cleaned up per prompt_id, with client_id threaded through the execution chain for targeted message delivery.
- Adjusted execution and progress-reporting logic to stop relying on the single server.client_id for message routing (though some global state writes remain).
Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.

| File | Description |
|---|---|
| main.py | Starts multiple workers per --parallel; the progress hook fetches the registry per prompt and sends by client_id |
| execution.py | Execution chain passes client_id explicitly; uses and cleans up the progress registry per prompt_id |
| comfy_execution/progress.py | Global progress registry becomes a per-prompt dict; adds remove_progress_state |
| comfy_api/latest/__init__.py | set_progress uses the executing context's prompt_id to update the matching registry |
| comfy/cli_args.py | Adds the --parallel N argument definition and help text |
num_workers = max(1, args.parallel)
logging.info(f"Starting {num_workers} prompt worker(s)")
for i in range(num_workers):
    threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,), name=f"prompt_worker-{i}").start()
The new --parallel option changes runtime behavior (multiple prompt worker threads), but there are no integration tests exercising parallel execution (e.g., starting the server with --parallel 2 and verifying two prompts from different clients both complete and don’t cross-contaminate events/progress). Adding a test would help prevent regressions in this concurrency-sensitive path.
with _progress_registries_lock:
    if _progress_registries:
        return next(iter(_progress_registries.values()))
get_progress_state() falls back to returning an arbitrary existing registry (next(iter(_progress_registries.values()))) when prompt_id is None or not found. In parallel execution this can route progress updates/handlers to the wrong prompt/client and cause cross-contamination. Suggestion: if prompt_id is None or missing, return a new empty/dummy registry (no handlers) or raise/return None and require callers to pass a valid prompt_id (so you never pick a different prompt’s registry).
@@ -78,7 +79,7 @@ async def set_progress(
         preview_size = None if ignore_size_limit else args.preview_size
         to_display = (image_format, to_display, preview_size)

-        get_progress_state().update_progress(
+        get_progress_state(prompt_id).update_progress(
             node_id=node_id,
set_progress() derives prompt_id from executing_context, but when there is no executing context it becomes None and the code still calls get_progress_state(prompt_id). With the new per-prompt registry design this risks updating the wrong prompt’s registry (or a dummy/no-op registry) and makes behavior dependent on internal fallback logic. Consider requiring an executing context (raise if missing), or extending the API to accept an explicit prompt_id and using that instead of None.
self.client_id = extra_data.get("client_id", None)
self.server.client_id = self.client_id
PromptExecutor.execute_async() still writes self.server.client_id = self.client_id. With multiple prompt workers this global field will be raced/overwritten by concurrent executions, which can break any server logic that relies on a single “currently executing client” (e.g., WebSocket reconnect sending the current node). Consider removing this global mutation in parallel mode and instead tracking executing state per prompt/client (e.g., a mapping on the server keyed by client_id/prompt_id).
     hijack_progress(prompt_server)

-    threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
+    num_workers = max(1, args.parallel)
Starting multiple prompt_worker threads introduces true concurrent prompt execution, but several server fields used elsewhere as fallbacks (e.g., PromptServer.client_id / last_node_id / last_prompt_id) are global and not worker-safe. This can cause incorrect routing of progress/executing status in edge cases (notably reconnect logic that assumes a single executing client). Before enabling parallelism by default, consider refactoring those fields to be per-client/per-prompt (or guarding with locks) so parallel mode doesn’t produce cross-client state leakage.
Suggested change:

-num_workers = max(1, args.parallel)
+requested_workers = max(1, args.parallel)
+if requested_workers > 1:
+    logging.warning(
+        "Parallel prompt workers requested (%s) but are disabled because PromptServer "
+        "execution fallback state is shared and not worker-safe; starting a single prompt worker instead.",
+        requested_workers,
+    )
+num_workers = 1
… when prompt_id is None, another workflow's registry is returned. This can cause:
- progress updates being written to the wrong workflow's state
- WebSocket messages being sent to the wrong client
Concurrency-safety issues introduced by parallel workflow execution

Background

This document walks through the issues found in the commit above, the proposed fixes, and their trade-offs, for discussion.

Problem overview

🔴 P0-1: global interrupt checks can kill unrelated workflows

| Call site | Impact |
|---|---|
| comfy_api_nodes/util/client.py:288,590,769 | API polling interrupted unexpectedly |
| comfy_api_nodes/util/download_helpers.py:91,155 | Downloads interrupted unexpectedly |
| comfy_api_nodes/util/upload_helpers.py:268 | Uploads interrupted unexpectedly |
| comfy_api_nodes/nodes_sonilo.py:196 | Sonilo tasks terminated unexpectedly |
| bizyengine/bizyengine/misc/nodes.py:18 | BizyAir ProgressCallback interrupted unexpectedly |
BizyAir's ProgressCallback likewise calls the global throw_exception_if_processing_interrupted(), so it too gets killed by mistake.

Fix: a per-executor interrupt flag

Core idea: each PromptExecutor holds its own _interrupted flag, and threading.local() binds the current thread's interrupt context to its executor. Interrupt checks read the per-executor flag first and fall back to the global flag for compatibility with old code.

Change 1 — comfy/model_management.py:
# New: thread-local interrupt context
_interrupt_context = threading.local()

def set_interrupt_context(executor_interrupt_flag):
    """Bind the current thread's interrupt flag to its executor."""
    _interrupt_context.flag = executor_interrupt_flag

def clear_interrupt_context():
    if hasattr(_interrupt_context, 'flag'):
        del _interrupt_context.flag

# Changed: processing_interrupted() checks the per-executor flag first
def processing_interrupted():
    if hasattr(_interrupt_context, 'flag') and _interrupt_context.flag is not None:
        if _interrupt_context.flag[0]:
            return True
    global interrupt_processing, interrupt_processing_mutex
    with interrupt_processing_mutex:
        return interrupt_processing

# Changed: throw_exception_if_processing_interrupted() checks the per-executor flag first
def throw_exception_if_processing_interrupted():
    if hasattr(_interrupt_context, 'flag') and _interrupt_context.flag is not None:
        if _interrupt_context.flag[0]:
            _interrupt_context.flag[0] = False
            raise InterruptProcessingException()
    global interrupt_processing, interrupt_processing_mutex
    with interrupt_processing_mutex:
        if interrupt_processing:
            interrupt_processing = False
            raise InterruptProcessingException()

Change 2 — execution.py PromptExecutor:
```python
class PromptExecutor:
    def __init__(self, ...):
        self._interrupted = [False]  # wrapped in a list so it can be shared by reference

    def interrupt(self):
        self._interrupted[0] = True

# New: executor registry
_running_executors: Dict[str, PromptExecutor] = {}
_running_executors_lock = threading.Lock()
```

Change 3 — `server.py` `/interrupt` endpoint:
```python
@routes.post("/interrupt")
async def post_interrupt(request):
    json_data = await request.json()
    prompt_id = json_data.get('prompt_id')
    if prompt_id:
        executor = _get_executor_by_prompt_id(prompt_id)
        if executor is not None:
            executor.interrupt()  # interrupt exactly the targeted workflow
        else:
            logging.info(f"Prompt {prompt_id} is not currently running")
    else:
        # Global interrupt: stop every running executor
        for executor in _get_all_running_executors():
            executor.interrupt()
        # Also set the global flag (for legacy code paths that bypass executors)
        nodes.interrupt_processing()
    return web.Response(status=200)
```

Trade-offs
| Pros | Cons / risks |
|---|---|
| Interrupts are isolated precisely to the target workflow | The new `_running_executors` registry must be maintained |
| A global interrupt genuinely interrupts all parallel workflows | Binding `_interrupt_context` via `threading.local()` adds conceptual overhead |
| All interrupt checks in `comfy_api_nodes` become isolated automatically (the `processing_interrupted()` they call is modified in one place) | Custom nodes that call `interrupt_current_processing()` directly only set the global flag, which the per-executor check won't see — but the global flag remains as a fallback |
| Fully backward compatible: when the per-executor check finds nothing, it falls back to the global flag | — |
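To make the isolation property concrete, here is a self-contained sketch of the per-executor scheme. The helpers are simplified stand-ins for the ones in Change 1 above (no global-flag fallback), and the worker bodies are hypothetical; the point is that interrupting one thread's flag leaves the other thread untouched.

```python
import threading

# Simplified stand-ins for the model_management helpers sketched in Change 1.
_interrupt_context = threading.local()

class InterruptProcessingException(Exception):
    pass

def set_interrupt_context(flag):
    _interrupt_context.flag = flag

def clear_interrupt_context():
    if hasattr(_interrupt_context, "flag"):
        del _interrupt_context.flag

def throw_exception_if_processing_interrupted():
    # Per-executor check only; the real version also falls back to the global flag.
    flag = getattr(_interrupt_context, "flag", None)
    if flag is not None and flag[0]:
        flag[0] = False
        raise InterruptProcessingException()

results = {}

def run(name, flag):
    # Each worker thread binds its own executor's flag on entry...
    set_interrupt_context(flag)
    try:
        throw_exception_if_processing_interrupted()
        results[name] = "completed"
    except InterruptProcessingException:
        results[name] = "interrupted"
    finally:
        # ...and clears it on exit so the next prompt on this thread starts clean.
        clear_interrupt_context()

flag_a, flag_b = [True], [False]   # A was interrupted; B was not
ta = threading.Thread(target=run, args=("A", flag_a))
tb = threading.Thread(target=run, args=("B", flag_b))
ta.start(); tb.start(); ta.join(); tb.join()

assert results == {"A": "interrupted", "B": "completed"}
assert flag_a[0] is False  # the flag is consumed on raise, matching the sketch above
```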
🔴 P0-2: `interrupt_processing(False)` clears other workflows' interrupt signal
Current state

`execution.py:717`:
```python
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
    nodes.interrupt_processing(False)  # resets the global flag at the start of every execution
```

Problem scenario
1. A user calls `/interrupt` to stop Worker A.
2. The global flag `interrupt_processing = True` is set.
3. Worker B happens to start a new prompt at that moment and runs `interrupt_processing(False)`, clearing the flag first.
4. Worker A's nodes never see the interrupt signal.
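The race needs no threads to demonstrate: replaying the four steps sequentially against one shared flag already shows the interrupt being lost. The two functions below are simplified stand-ins for `nodes.interrupt_processing` and `processing_interrupted`.

```python
# Deterministic replay of the four steps above against one shared flag.
state = {"interrupt_processing": False}

def interrupt_processing(value=True):       # stand-in for nodes.interrupt_processing
    state["interrupt_processing"] = value

def processing_interrupted():               # stand-in for the global check
    return state["interrupt_processing"]

interrupt_processing(True)            # steps 1-2: /interrupt targets Worker A
interrupt_processing(False)           # step 3: Worker B starts a prompt and resets the flag
assert not processing_interrupted()   # step 4: Worker A's nodes never see the interrupt
```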
Fix

Fixed together with P0-1: replace the global reset with a per-executor reset:
```python
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
    # removed: nodes.interrupt_processing(False)
    self._interrupted[0] = False  # reset only this executor's flag
    set_interrupt_context(self._interrupted)
```

Trade-offs
| Pros | Cons / risks |
|---|---|
| Eliminates the race condition | A custom node that calls `interrupt_processing(False)` mid-execution to cancel an interrupt no longer resets the global flag automatically — but that was broken behavior to begin with |
| Minimal change (one line removed, one added) | — |
🔴 P0-3: `server.client_id` is overwritten by multiple threads
Current state

`execution.py:719-720`:
```python
self.client_id = extra_data.get("client_id", None)
self.server.client_id = self.client_id  # writes into the shared singleton
```

The commit already made `client_id` an instance variable on `PromptExecutor`, but it still writes `server.client_id`, which all workers share.
Problem scenario

On WebSocket reconnect (`server.py:276`):
```python
if self.client_id == sid and self.last_node_id is not None:
    await self.send("executing", {"node": self.last_node_id}, sid)
```

By this point `self.client_id` may already have been overwritten by another worker, so:
- when client A reconnects, `self.client_id` happens to equal client B's sid → client A receives B's node state
- or `self.client_id` equals no reconnecting client's sid → the client receives no state at all after reconnecting
Fix: stop writing `server.client_id`; use an `executing_clients` map instead

Change 1 — `execution.py`: delete `self.server.client_id = self.client_id`

Change 2 — `server.py`: add an `executing_clients` dict replacing the `client_id` + `last_node_id` pair:
```python
# added in PromptServer.__init__:
self.executing_clients: Dict[str, str] = {}  # client_id -> last_node_id

# reconnect logic in websocket_handler becomes:
if sid in self.executing_clients:
    node_id = self.executing_clients[sid]
    if node_id is not None:
        await self.send("executing", {"node": node_id}, sid)
```

Change 3 — update `executing_clients` in `execution.py` when a node executes:
```python
# replaces server.last_node_id = display_node_id
if client_id is not None:
    server.executing_clients[client_id] = display_node_id
```

Change 4 — clean up in `prompt_worker` after execution finishes:
```python
if e.client_id is not None:
    server_instance.executing_clients.pop(e.client_id, None)
    server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, e.client_id)
```

Trade-offs
| Pros | Cons / risks |
|---|---|
| Eliminates the `client_id` race completely | The new `executing_clients` dict must be maintained, and cleanup must be guaranteed on abnormal exit |
| When one client submits multiple workflows, the most recently executing node is still tracked | One `client_id` stores only one `last_node_id`, so of a client's parallel workflows only the last one is visible — reasonable under the current architecture, since the frontend has a single WebSocket connection |
| Does not affect BizyAir (BizyAir neither reads nor writes `server.client_id`) | — |
🟠 P1: races on `server.last_prompt_id` / `server.last_node_id`

Current state
- `main.py:305`: `server_instance.last_prompt_id = prompt_id` (overwritten each time a worker starts a prompt)
- `execution.py:484`: `server.last_node_id = display_node_id` (overwritten on each node execution)
- `execution.py:805`: `self.server.last_node_id = None` (cleared when execution ends — possibly clearing a value set by another worker)
Affected parties
- `hijack_progress()` (`main.py:384-387`): uses them as fallbacks for `prompt_id`/`node_id`. The commit already propagates both values via `CurrentNodeContext` (a ContextVar), so the fallback is rarely reached.
- BizyAir `nodes_base.py:56-60`:
```python
if PromptServer.instance is not None and PromptServer.instance.last_prompt_id is not None:
    extra_data["prompt_id"] = PromptServer.instance.last_prompt_id
```

In non-server mode, BizyAir reads `last_prompt_id` as an API request tracing ID. In parallel mode it may read another workflow's ID.
Fix

`last_prompt_id`:
- remove `server_instance.last_prompt_id = prompt_id` at `main.py:305`
- (`last_node_id` → `executing_clients` is already covered by the P0-3 fix)
- remove `self.server.last_node_id = None` at `execution.py:805`
`hijack_progress()`:

- remove the fallback to `server.last_prompt_id` / `server.last_node_id` and rely entirely on `CurrentNodeContext`
BizyAir compatibility:

- Low impact: in the API-server scenario `BIZYAIR_SERVER_MODE` is usually True and `prompt_id` travels via `_meta`, bypassing this branch
- If compatibility matters, BizyAir could be sent a PR switching it to `get_executing_context()`:
```python
from comfy_execution.utils import get_executing_context

context = get_executing_context()
if context is not None:
    extra_data["prompt_id"] = context.prompt_id
```

Trade-offs
| Pros | Cons / risks |
|---|---|
| Eliminates the shared-state races | BizyAir's `last_prompt_id` is no longer reliable in non-server mode — but the API-server scenario is unaffected |
| Simplifies the code and removes an unnecessary fallback | Must confirm `CurrentNodeContext` is set correctly on all code paths |
| Consistent with the P0-3 fix | — |
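Why a ContextVar can replace the `last_node_id`/`last_prompt_id` globals: each execution binds its own `(prompt_id, node_id)` pair, and unbinding on exit leaves no stale state for other workers to read. The sketch below is a simplified stand-in for the real `CurrentNodeContext`/`get_executing_context` in `comfy_execution` — the names match, the internals are assumed.

```python
import contextvars
from typing import NamedTuple, Optional

class ExecutingContext(NamedTuple):
    prompt_id: str
    node_id: str

# One ContextVar holds the currently executing (prompt, node) pair per thread/task.
_current: contextvars.ContextVar = contextvars.ContextVar("executing_context", default=None)

class CurrentNodeContext:
    """Context manager binding (prompt_id, node_id) for the duration of a node run."""
    def __init__(self, prompt_id: str, node_id: str):
        self._ctx = ExecutingContext(prompt_id, node_id)
        self._token = None

    def __enter__(self):
        self._token = _current.set(self._ctx)
        return self

    def __exit__(self, *exc):
        _current.reset(self._token)  # restore whatever was bound before

def get_executing_context() -> Optional[ExecutingContext]:
    return _current.get()

with CurrentNodeContext("prompt-123", "node-7"):
    assert get_executing_context().prompt_id == "prompt-123"
assert get_executing_context() is None  # no stale global state once the node finishes
```

Because `contextvars` values are isolated per thread and per asyncio task, two prompt workers reading `get_executing_context()` concurrently can never see each other's IDs, which is the property the `server.last_*` globals lack.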
Summary of change scope
| File | Issues | Changes |
|---|---|---|
| `comfy/model_management.py` | P0-1, P0-2 | add `_interrupt_context`; modify `processing_interrupted()` and `throw_exception_if_processing_interrupted()` |
| `execution.py` | P0-1, P0-2, P0-3, P1 | add `_interrupted` / `interrupt()` / the executor registry to `PromptExecutor`; remove the global reset and shared-state writes from `execute_async`; remove `server.last_node_id = None` |
| `server.py` | P0-1, P0-3, P1 | switch `/interrupt` to per-executor interrupts; add `executing_clients`; update the WebSocket reconnect logic |
| `main.py` | P1 | remove the `hijack_progress` fallbacks; remove the `server.last_prompt_id` write |
| `bizyengine/core/nodes_base.py` | P1 (optional) | switch from `PromptServer.instance.last_prompt_id` to `get_executing_context()` |
No BizyAir change is mandatory — everything stays backward compatible, and the `throw_exception_if_processing_interrupted()` called by BizyAir's `ProgressCallback` automatically takes the per-executor path.
Decision points for discussion

1. Compatibility strategy for per-executor vs. global interrupts: keep the global flag as a fallback? Keeping it means legacy custom-node interrupt behavior is unchanged, but also that a global interrupt can still "leak" into non-target workflows.

2. Edge case in `executing_clients`: when one `client_id` submits multiple parallel workflows, only the last `node_id` is recorded. Is that acceptable? In the frontend UI a client usually shows a single workflow, but in the API-server scenario one `client_id` may be associated with several.

3. Whether to change BizyAir in lockstep: the `last_prompt_id` read in `nodes_base.py` is unreliable in parallel mode, but the API-server scenario goes through the `_meta` branch and is unaffected. Fix it now, or wait until BizyAir actually hits the problem?

4. Semantics of `/interrupt` without a `prompt_id`: should a global interrupt stop all parallel workflows, or keep the old behavior of stopping only "the current" one? Interrupting all is the recommendation, but it changes API behavior.

5. Whether `is_processing_interrupted()` in `comfy_api_nodes` needs coordinated changes: this is fixed together with P0-1 — once `processing_interrupted()` is modified it takes effect automatically, with no per-file edits. Confirm this approach is acceptable.
