diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py
index 5fe974a0..4a96125d 100755
--- a/scripts/train_dflash.py
+++ b/scripts/train_dflash.py
@@ -164,24 +164,54 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
     )
 
     # 2. Build Draft Model
-    if args.draft_config_path:
-        draft_config = AutoConfig.from_pretrained(args.draft_config_path)
-        print_on_rank0(f"Loaded draft config from {args.draft_config_path}")
-    else:
-        target_config = AutoConfig.from_pretrained(args.target_model_path)
-        draft_config = AutoConfig.from_pretrained(args.target_model_path)
-        draft_config.num_hidden_layers = args.num_draft_layers
-        draft_config.block_size = args.block_size
-        draft_config.num_target_layers = target_config.num_hidden_layers
-        print_on_rank0("Auto-generated draft config from target model")
+    # Handle checkpoint resumption
+    draft_model_last_checkpoint = None
+    if args.resume and os.path.isdir(args.output_dir):
+        from specforge.utils import get_last_checkpoint
 
-    if not hasattr(draft_config, "dflash_config") or draft_config.dflash_config is None:
-        draft_config.dflash_config = {}
+        print_on_rank0(args.output_dir)
+        draft_model_last_checkpoint = get_last_checkpoint(args.output_dir)
+        print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}")
 
-    draft_config._attn_implementation = args.attention_backend
-    print_on_rank0(f"Using attention backend: {args.attention_backend}")
+    if draft_model_last_checkpoint:
+        # Load config from checkpoint
+        draft_config = AutoConfig.from_pretrained(draft_model_last_checkpoint)
+        print_on_rank0(
+            f"Loaded draft config from checkpoint: {draft_model_last_checkpoint}"
+        )
 
-    draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16)
+        # Set attention implementation based on backend
+        draft_config._attn_implementation = args.attention_backend
+        print_on_rank0(f"Using attention backend: {args.attention_backend}")
+
+        # Load draft model from checkpoint
+        draft_model = DFlashDraftModel.from_pretrained(
+            draft_model_last_checkpoint,
+            config=draft_config,
+            torch_dtype=torch.bfloat16,
+        ).cuda()
+        print_on_rank0(
+            f"Resumed draft model from checkpoint: {draft_model_last_checkpoint}"
+        )
+    else:
+        # Build draft model from scratch or from a provided config
+        if args.draft_config_path:
+            draft_config = AutoConfig.from_pretrained(args.draft_config_path)
+            print_on_rank0(f"Loaded draft config from {args.draft_config_path}")
+        else:
+            # Load config from HF (needed for structure info even if backend is sglang)
+            target_config = AutoConfig.from_pretrained(args.target_model_path)
+            draft_config = AutoConfig.from_pretrained(args.target_model_path)
+            draft_config.num_hidden_layers = args.num_draft_layers
+            draft_config.block_size = args.block_size
+            draft_config.num_target_layers = target_config.num_hidden_layers
+            print_on_rank0("Auto-generated draft config from target model")
+
+        # Set attention implementation based on backend
+        draft_config._attn_implementation = args.attention_backend
+        print_on_rank0(f"Using attention backend: {args.attention_backend}")
+
+        draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16)
 
     target_model.set_capture_layers(draft_model.target_layer_ids)
 
@@ -261,7 +291,9 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
 
 def save_checkpoint(args, epoch, step, dflash_model, draft_model, optimizer):
     """Save checkpoint."""
-    save_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}")
+    # Use a naming convention that get_last_checkpoint can detect
+    save_dir = os.path.join(args.output_dir, f"epoch_{epoch}")
+
     if dist.get_rank() == 0:
         os.makedirs(save_dir, exist_ok=True)
     dist.barrier()
@@ -534,11 +566,10 @@ def main():
             }
         )
 
-        if global_step % args.save_interval == 0:
-            save_checkpoint(
-                args, epoch, global_step, dflash_model, draft_model, optimizer
-            )
+        # Save a checkpoint after each epoch
+        save_checkpoint(args, epoch, global_step, dflash_model, draft_model, optimizer)
 
+    # Final checkpoint
     save_checkpoint(
         args, args.num_epochs, global_step, dflash_model, draft_model, optimizer
    )
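
For context, a minimal sketch of the `get_last_checkpoint` helper this patch leans on. The real implementation in `specforge.utils` may differ; the assumption encoded in the `save_dir` rename above is simply that it scans the output directory for subdirectories named `epoch_<N>` and returns the one with the highest epoch number:

```python
# Hypothetical sketch -- not necessarily the actual specforge.utils implementation.
import os
import re
from typing import Optional

_EPOCH_DIR_RE = re.compile(r"^epoch_(\d+)$")


def get_last_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the newest epoch_<N> checkpoint dir, or None if absent."""
    candidates = []
    for name in os.listdir(output_dir):
        path = os.path.join(output_dir, name)
        match = _EPOCH_DIR_RE.match(name)
        if match and os.path.isdir(path):
            candidates.append((int(match.group(1)), path))
    # Highest epoch number wins, e.g. epoch_3 over epoch_2.
    return max(candidates)[1] if candidates else None
```

Under that assumption, the old `epoch_{epoch}_step_{step}` directory names would never match, which is why the rename in `save_checkpoint` and the resume branch in `build_models` land together: rerunning the script with the flag backing `args.resume` against the same output directory should then pick up the draft model from the last completed epoch.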