Skip to content

Commit 3178d4f

Browse files
authored
Bug fix in multi stage resume (#462)
1 parent 4b27dff commit 3178d4f

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

tests/trainer/trainer_test.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,12 +331,15 @@ def test_trainer(self, mock_load):
331331
),
332332
]
333333
self.config.check_and_update()
334+
old_taskset_path = self.config.stages[1].buffer.explorer_input.taskset.path
335+
self.config.stages[1].buffer.explorer_input.taskset.path = "/invalid/path"
334336

335-
mock_load.return_value = self.config
337+
mock_load.return_value = deepcopy(self.config)
336338

337-
run(config_path="dummy.yaml")
339+
with self.assertRaises(Exception):
340+
run(config_path="dummy.yaml")
338341

339-
stage_configs = [cfg.check_and_update() for cfg in self.config]
342+
stage_configs = [cfg.check_and_update() for cfg in deepcopy(self.config)]
340343

341344
# sft warmup stage
342345
sft_config = stage_configs[0]
@@ -351,6 +354,10 @@ def test_trainer(self, mock_load):
351354
self.assertEqual(parser.metric_min_step(response_metrics[0]), 1)
352355
self.assertEqual(parser.metric_max_step(response_metrics[0]), 3)
353356

357+
self.config.stages[1].buffer.explorer_input.taskset.path = old_taskset_path
358+
mock_load.return_value = deepcopy(self.config)
359+
run(config_path="dummy.yaml")
360+
354361
# grpo stage
355362
grpo_config = stage_configs[1]
356363
parser = TensorBoardParser(os.path.join(grpo_config.monitor.cache_dir, "tensorboard"))

trinity/cli/launcher.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
191191
f"> Skipping completed stage {i + 1}/{len(config.stages)}...\n"
192192
"==========================================================="
193193
)
194+
stage_config.check_and_update()
194195
else:
195196
logger.info(
196197
"===========================================================\n"

0 commit comments

Comments
 (0)