Skip to content

Commit 106e69c

Browse files
authored
Enhance AgentScope Workflow Adapter (#457)
1 parent a65cf0e commit 106e69c

File tree

3 files changed

+110
-4
lines changed

3 files changed

+110
-4
lines changed

examples/agentscope_react/gsm8k.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ buffer:
2929
response_key: 'answer'
3030
rollout_args:
3131
temperature: 1.0
32-
default_workflow_type: 'as_react_workflow'
32+
default_workflow_type: 'agentscope_react_workflow'
3333
eval_tasksets: []
3434
trainer_input:
3535
experience_buffer:

trinity/common/workflows/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
# tool_call
2121
"tool_call_workflow": "trinity.common.workflows.customized_toolcall_workflows.ToolCallWorkflow",
2222
# agentscope
23-
"agentscope_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow",
2423
"agentscope_workflow_adapter": "trinity.common.workflows.agentscope_workflow.AgentScopeWorkflowAdapter",
24+
"agentscope_workflow_adapter_v1": "trinity.common.workflows.agentscope_workflow.AgentScopeWorkflowAdapterV1",
25+
"agentscope_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow",
2526
"agentscope_react_math_workflow": "trinity.common.workflows.envs.agentscope.agentscopev1_react_workflow.AgentScopeReactMathWorkflow",
2627
"as_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow",
2728
"agentscopev0_react_math_workflow": "trinity.common.workflows.envs.agentscope.agentscopev0_react_workflow.AgentScopeV0ReactMathWorkflow",

trinity/common/workflows/agentscope_workflow.py

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def __init__(
2222
from agentscope.model import TrinityChatModel
2323
except ImportError:
2424
raise ImportError(
25-
"This workflow requires agentscope >= 0.1.6, please install "
26-
"it via `pip install agentscope>=0.1.6`",
25+
"This workflow requires agentscope >= 1.0.7, please install "
26+
"it via `pip install agentscope>=1.0.7`",
2727
)
2828

2929
super().__init__(
@@ -72,3 +72,108 @@ async def run_async(self) -> List[Experience]:
7272
"""Run the workflow asynchronously and return experiences."""
7373
reward = await self.workflow_func(self.task.raw_task, self.chat_model) # type: ignore [arg-type]
7474
return self.construct_experiences(reward)
75+
76+
77+
class AgentScopeWorkflowAdapterV1(Workflow):
78+
"""A more general adapter to wrap agentscope trainable workflow and judge functions into a Trinity Workflow."""
79+
80+
is_async: bool = True
81+
82+
def __init__(
83+
self,
84+
*,
85+
task: Task,
86+
model: ModelWrapper,
87+
auxiliary_models: Optional[List[ModelWrapper]] = None,
88+
):
89+
"""Initialize the adapter with the task and model."""
90+
try:
91+
from agentscope.model import TrinityChatModel
92+
except ImportError:
93+
raise ImportError(
94+
"This workflow requires agentscope >= 1.0.11, please install "
95+
"it via `pip install agentscope>=1.0.11`",
96+
)
97+
98+
super().__init__(
99+
task=task,
100+
model=model,
101+
auxiliary_models=auxiliary_models,
102+
)
103+
self.workflow_func = task.workflow_args.get("workflow_func", None)
104+
self.judge_func = task.workflow_args.get("judge_func", None)
105+
106+
if self.workflow_func is None:
107+
raise ValueError(
108+
"The 'workflow_func' is not provided.",
109+
)
110+
111+
self.chat_model: TrinityChatModel = TrinityChatModel(
112+
model.get_openai_async_client(),
113+
generate_kwargs={
114+
"temperature": self.task.rollout_args.temperature,
115+
"top_p": self.task.rollout_args.top_p,
116+
"max_tokens": self.task.rollout_args.max_tokens or 4096,
117+
"logprobs": True,
118+
"top_logprobs": self.task.rollout_args.logprobs,
119+
},
120+
)
121+
self.auxiliary_chat_models = [
122+
TrinityChatModel(
123+
openai_async_client=aux_model,
124+
# TODO: customize generate_kwargs for auxiliary models if needed
125+
)
126+
for aux_model in (self.auxiliary_models or [])
127+
]
128+
129+
def construct_experiences(
130+
self,
131+
reward: float,
132+
metrics: Dict,
133+
) -> List[Experience]:
134+
"""Construct experiences from the agent's interaction history.
135+
136+
Args:
137+
reward (float): The reward value to assign to each experience.
138+
metrics (Dict): A dictionary of metrics to be attached to the last experience.
139+
140+
Returns:
141+
List: A list of Experience objects.
142+
"""
143+
exps = self.model.extract_experience_from_history()
144+
for exp in exps:
145+
exp.reward = reward
146+
# only attach metrics to the last experience
147+
if len(exps) > 0:
148+
exps[-1].metrics = metrics
149+
return exps
150+
151+
async def run_async(self) -> List[Experience]:
152+
"""Run the workflow asynchronously and return experiences."""
153+
try:
154+
from agentscope.tuner import JudgeOutput, WorkflowOutput
155+
except ImportError:
156+
raise ImportError(
157+
"Fail to import agentscope tuner related types. Please ensure agentscope>=1.0.11 is installed."
158+
)
159+
160+
metrics = {}
161+
workflow_output: WorkflowOutput = await self.workflow_func(
162+
self.task.raw_task, self.chat_model, self.auxiliary_chat_models
163+
) # type: ignore [arg-type]
164+
metrics.update(workflow_output.metrics or {})
165+
if self.judge_func is not None:
166+
assert (
167+
workflow_output.response is not None
168+
), "Workflow must provide response for judging."
169+
judge_output: JudgeOutput = await self.judge_func(
170+
self.task.raw_task, workflow_output.response, self.auxiliary_chat_models
171+
) # type: ignore [arg-type]
172+
reward = judge_output.reward
173+
metrics.update(judge_output.metrics or {})
174+
else:
175+
assert (
176+
workflow_output.reward is not None
177+
), "Either workflow or judge must provide reward."
178+
reward = workflow_output.reward
179+
return self.construct_experiences(reward, metrics)

0 commit comments

Comments
 (0)