From 9a2deef4c76c411f59dd7ab69b9904126b780ef6 Mon Sep 17 00:00:00 2001
From: zhoutianhao03
Date: Wed, 4 Feb 2026 08:13:02 +0800
Subject: [PATCH 1/4] snapshot

---
 scripts/train_dflash.py | 54 +++++++++++++++++++++++++++++------------
 1 file changed, 38 insertions(+), 16 deletions(-)

diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py
index 7baf5ee94..26230f482 100755
--- a/scripts/train_dflash.py
+++ b/scripts/train_dflash.py
@@ -140,23 +140,45 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
     )
 
     # 2. Build Draft Model
-    if args.draft_config_path:
-        draft_config = AutoConfig.from_pretrained(args.draft_config_path)
-        print_on_rank0(f"Loaded draft config from {args.draft_config_path}")
+    # Handle checkpoint resumption
+    draft_model_last_checkpoint = None
+    if args.resume and os.path.isdir(args.output_dir):
+        from specforge.utils import get_last_checkpoint
+        print_on_rank0(args.output_dir)
+        draft_model_last_checkpoint = get_last_checkpoint(args.output_dir)
+        print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}")
+
+    if draft_model_last_checkpoint:
+        # Load config from checkpoint
+        draft_config = AutoConfig.from_pretrained(draft_model_last_checkpoint)
+        print_on_rank0(f"Loaded draft config from checkpoint: {draft_model_last_checkpoint}")
+
+        # Load draft model from checkpoint
+        draft_model = DFlashDraftModel.from_pretrained(
+            draft_model_last_checkpoint,
+            config=draft_config,
+            torch_dtype=torch.bfloat16,
+        ).cuda()
+        print_on_rank0(f"Resumed draft model from checkpoint: {draft_model_last_checkpoint}")
     else:
-        # Load config from HF (needed for structure info even if backend is sglang)
-        target_config = AutoConfig.from_pretrained(args.target_model_path)
-        draft_config = AutoConfig.from_pretrained(args.target_model_path)
-        draft_config.num_hidden_layers = args.num_draft_layers
-        draft_config.block_size = args.block_size
-        draft_config.num_target_layers = target_config.num_hidden_layers
-        print_on_rank0("Auto-generated draft config from target model")
-
-    # Set attention implementation based on backend
-    draft_config._attn_implementation = args.attention_backend
-    print_on_rank0(f"Using attention backend: {args.attention_backend}")
-
-    draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16)
+        # Build draft model from scratch or provided config
+        if args.draft_config_path:
+            draft_config = AutoConfig.from_pretrained(args.draft_config_path)
+            print_on_rank0(f"Loaded draft config from {args.draft_config_path}")
+        else:
+            # Load config from HF (needed for structure info even if backend is sglang)
+            target_config = AutoConfig.from_pretrained(args.target_model_path)
+            draft_config = AutoConfig.from_pretrained(args.target_model_path)
+            draft_config.num_hidden_layers = args.num_draft_layers
+            draft_config.block_size = args.block_size
+            draft_config.num_target_layers = target_config.num_hidden_layers
+            print_on_rank0("Auto-generated draft config from target model")
+
+        # Set attention implementation based on backend
+        draft_config._attn_implementation = args.attention_backend
+        print_on_rank0(f"Using attention backend: {args.attention_backend}")
+
+        draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16)
 
     # Set capture layers for target model based on draft model config
     target_model.set_capture_layers(draft_model.target_layer_ids)
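Patch 1 above introduces the resume path, which depends on `get_last_checkpoint` from `specforge.utils`; that helper's body is outside this diff. As a rough sketch of what such a helper typically does, assuming checkpoint directory names end in a numeric index (the actual specforge implementation may differ):

    import os
    import re
    from typing import Optional

    def get_last_checkpoint_sketch(output_dir: str) -> Optional[str]:
        """Return the checkpoint subdirectory with the largest trailing index,
        or None when nothing checkpoint-like is found."""
        best_path, best_idx = None, -1
        for name in os.listdir(output_dir):
            path = os.path.join(output_dir, name)
            match = re.search(r"(\d+)$", name)  # matches "epoch_3", "checkpoint-500", ...
            if os.path.isdir(path) and match and int(match.group(1)) > best_idx:
                best_path, best_idx = path, int(match.group(1))
        return best_path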
From 59dcd9e19e41f5a2b278f5ccf61f725ef44b54c3 Mon Sep 17 00:00:00 2001
From: zhoutianhao03
Date: Wed, 4 Feb 2026 09:06:15 +0800
Subject: [PATCH 2/4] update

---
 scripts/train_dflash.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py
index 26230f482..a647479e6 100755
--- a/scripts/train_dflash.py
+++ b/scripts/train_dflash.py
@@ -96,7 +96,6 @@ def parse_args():
     output_group.add_argument("--cache-dir", type=str, default="./cache")
     output_group.add_argument("--log-interval", type=int, default=50)
     output_group.add_argument("--eval-interval", type=int, default=1000)
-    output_group.add_argument("--save-interval", type=int, default=1000)
 
     optimization_group = parser.add_argument_group("optimization")
     optimization_group.add_argument(
@@ -153,6 +152,10 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
         draft_config = AutoConfig.from_pretrained(draft_model_last_checkpoint)
         print_on_rank0(f"Loaded draft config from checkpoint: {draft_model_last_checkpoint}")
 
+        # Set attention implementation based on backend
+        draft_config._attn_implementation = args.attention_backend
+        print_on_rank0(f"Using attention backend: {args.attention_backend}")
+
         # Load draft model from checkpoint
         draft_model = DFlashDraftModel.from_pretrained(
             draft_model_last_checkpoint,
@@ -261,7 +264,9 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
 
 def save_checkpoint(args, epoch, step, dflash_model, draft_model, optimizer):
     """Save checkpoint."""
-    save_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}")
+    # Use naming convention that get_last_checkpoint can detect
+    save_dir = os.path.join(args.output_dir, f"epoch_{epoch}")
+
     if dist.get_rank() == 0:
         os.makedirs(save_dir, exist_ok=True)
     dist.barrier()
@@ -487,11 +492,12 @@ def main():
                 }
             )
 
-            if global_step % args.save_interval == 0:
-                save_checkpoint(
-                    args, epoch, global_step, dflash_model, draft_model, optimizer
-                )
+        # Save checkpoint after each epoch
+        save_checkpoint(
+            args, epoch, global_step, dflash_model, draft_model, optimizer
+        )
 
+    # Final checkpoint
     save_checkpoint(
         args, args.num_epochs, global_step, dflash_model, draft_model, optimizer
     )
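Patch 2 renames checkpoints from `epoch_{epoch}_step_{step}` to `epoch_{epoch}` so that `get_last_checkpoint` has a single numeric suffix to key on. The comparison has to be numeric rather than lexicographic; a small illustration with hypothetical directory names:

    candidates = ["epoch_0", "epoch_2", "epoch_10", "cache"]
    epoch_dirs = [d for d in candidates if d.startswith("epoch_") and d.split("_")[1].isdigit()]
    latest = max(epoch_dirs, key=lambda d: int(d.split("_")[1]))
    assert latest == "epoch_10"  # a plain string max() would wrongly pick "epoch_2"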
print_on_rank0(f"Resumed draft model from checkpoint: {draft_model_last_checkpoint}") + print_on_rank0( + f"Resumed draft model from checkpoint: {draft_model_last_checkpoint}" + ) else: # Build draft model from scratch or provided config if args.draft_config_path: @@ -493,9 +498,7 @@ def main(): ) # Save checkpoint after each epoch - save_checkpoint( - args, epoch, global_step, dflash_model, draft_model, optimizer - ) + save_checkpoint(args, epoch, global_step, dflash_model, draft_model, optimizer) # Final checkpoint save_checkpoint( From d39267b17a2dbe14cec84a04b06ad7134bd5a9cb Mon Sep 17 00:00:00 2001 From: zhoutianhao03 Date: Fri, 6 Feb 2026 02:18:57 +0800 Subject: [PATCH 4/4] add back option --- scripts/train_dflash.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py index 51a0ebca8..b38634027 100755 --- a/scripts/train_dflash.py +++ b/scripts/train_dflash.py @@ -96,6 +96,7 @@ def parse_args(): output_group.add_argument("--cache-dir", type=str, default="./cache") output_group.add_argument("--log-interval", type=int, default=50) output_group.add_argument("--eval-interval", type=int, default=1000) + output_group.add_argument("--save-interval", type=int, default=1000) optimization_group = parser.add_argument_group("optimization") optimization_group.add_argument(