@@ -22,8 +22,8 @@ def __init__(
2222 from agentscope .model import TrinityChatModel
2323 except ImportError :
2424 raise ImportError (
25- "This workflow requires agentscope >= 0.1.6 , please install "
26- "it via `pip install agentscope>=0.1.6 `" ,
25+ "This workflow requires agentscope >= 1.0.7 , please install "
26+ "it via `pip install agentscope>=1.0.7 `" ,
2727 )
2828
2929 super ().__init__ (
@@ -72,3 +72,108 @@ async def run_async(self) -> List[Experience]:
7272 """Run the workflow asynchronously and return experiences."""
7373 reward = await self .workflow_func (self .task .raw_task , self .chat_model ) # type: ignore [arg-type]
7474 return self .construct_experiences (reward )
75+
76+
77+ class AgentScopeWorkflowAdapterV1 (Workflow ):
78+ """A more general adapter to wrap agentscope trainable workflow and judge functions into a Trinity Workflow."""
79+
80+ is_async : bool = True
81+
82+ def __init__ (
83+ self ,
84+ * ,
85+ task : Task ,
86+ model : ModelWrapper ,
87+ auxiliary_models : Optional [List [ModelWrapper ]] = None ,
88+ ):
89+ """Initialize the adapter with the task and model."""
90+ try :
91+ from agentscope .model import TrinityChatModel
92+ except ImportError :
93+ raise ImportError (
94+ "This workflow requires agentscope >= 1.0.11, please install "
95+ "it via `pip install agentscope>=1.0.11`" ,
96+ )
97+
98+ super ().__init__ (
99+ task = task ,
100+ model = model ,
101+ auxiliary_models = auxiliary_models ,
102+ )
103+ self .workflow_func = task .workflow_args .get ("workflow_func" , None )
104+ self .judge_func = task .workflow_args .get ("judge_func" , None )
105+
106+ if self .workflow_func is None :
107+ raise ValueError (
108+ "The 'workflow_func' is not provided." ,
109+ )
110+
111+ self .chat_model : TrinityChatModel = TrinityChatModel (
112+ model .get_openai_async_client (),
113+ generate_kwargs = {
114+ "temperature" : self .task .rollout_args .temperature ,
115+ "top_p" : self .task .rollout_args .top_p ,
116+ "max_tokens" : self .task .rollout_args .max_tokens or 4096 ,
117+ "logprobs" : True ,
118+ "top_logprobs" : self .task .rollout_args .logprobs ,
119+ },
120+ )
121+ self .auxiliary_chat_models = [
122+ TrinityChatModel (
123+ openai_async_client = aux_model ,
124+ # TODO: customize generate_kwargs for auxiliary models if needed
125+ )
126+ for aux_model in (self .auxiliary_models or [])
127+ ]
128+
129+ def construct_experiences (
130+ self ,
131+ reward : float ,
132+ metrics : Dict ,
133+ ) -> List [Experience ]:
134+ """Construct experiences from the agent's interaction history.
135+
136+ Args:
137+ reward (float): The reward value to assign to each experience.
138+ metrics (Dict): A dictionary of metrics to be attached to the last experience.
139+
140+ Returns:
141+ List: A list of Experience objects.
142+ """
143+ exps = self .model .extract_experience_from_history ()
144+ for exp in exps :
145+ exp .reward = reward
146+ # only attach metrics to the last experience
147+ if len (exps ) > 0 :
148+ exps [- 1 ].metrics = metrics
149+ return exps
150+
151+ async def run_async (self ) -> List [Experience ]:
152+ """Run the workflow asynchronously and return experiences."""
153+ try :
154+ from agentscope .tuner import JudgeOutput , WorkflowOutput
155+ except ImportError :
156+ raise ImportError (
157+ "Fail to import agentscope tuner related types. Please ensure agentscope>=1.0.11 is installed."
158+ )
159+
160+ metrics = {}
161+ workflow_output : WorkflowOutput = await self .workflow_func (
162+ self .task .raw_task , self .chat_model , self .auxiliary_chat_models
163+ ) # type: ignore [arg-type]
164+ metrics .update (workflow_output .metrics or {})
165+ if self .judge_func is not None :
166+ assert (
167+ workflow_output .response is not None
168+ ), "Workflow must provide response for judging."
169+ judge_output : JudgeOutput = await self .judge_func (
170+ self .task .raw_task , workflow_output .response , self .auxiliary_chat_models
171+ ) # type: ignore [arg-type]
172+ reward = judge_output .reward
173+ metrics .update (judge_output .metrics or {})
174+ else :
175+ assert (
176+ workflow_output .reward is not None
177+ ), "Either workflow or judge must provide reward."
178+ reward = workflow_output .reward
179+ return self .construct_experiences (reward , metrics )
0 commit comments