scripts/train_dflash.py: 71 changes (51 additions & 20 deletions)
@@ -164,24 +164,54 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
     )
 
     # 2. Build Draft Model
-    if args.draft_config_path:
-        draft_config = AutoConfig.from_pretrained(args.draft_config_path)
-        print_on_rank0(f"Loaded draft config from {args.draft_config_path}")
-    else:
-        target_config = AutoConfig.from_pretrained(args.target_model_path)
-        draft_config = AutoConfig.from_pretrained(args.target_model_path)
-        draft_config.num_hidden_layers = args.num_draft_layers
-        draft_config.block_size = args.block_size
-        draft_config.num_target_layers = target_config.num_hidden_layers
-        print_on_rank0("Auto-generated draft config from target model")
+    # Handle checkpoint resumption
+    draft_model_last_checkpoint = None
+    if args.resume and os.path.isdir(args.output_dir):
+        from specforge.utils import get_last_checkpoint
 
-    if not hasattr(draft_config, "dflash_config") or draft_config.dflash_config is None:
-        draft_config.dflash_config = {}
+        print_on_rank0(args.output_dir)
+        draft_model_last_checkpoint = get_last_checkpoint(args.output_dir)
+        print_on_rank0(f"Last checkpoint detected: {draft_model_last_checkpoint}")
 
-    draft_config._attn_implementation = args.attention_backend
-    print_on_rank0(f"Using attention backend: {args.attention_backend}")
+    if draft_model_last_checkpoint:
+        # Load config from checkpoint
+        draft_config = AutoConfig.from_pretrained(draft_model_last_checkpoint)
+        print_on_rank0(
+            f"Loaded draft config from checkpoint: {draft_model_last_checkpoint}"
+        )
 
-    draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16)
+        # Set attention implementation based on backend
+        draft_config._attn_implementation = args.attention_backend
+        print_on_rank0(f"Using attention backend: {args.attention_backend}")
+
+        # Load draft model from checkpoint
+        draft_model = DFlashDraftModel.from_pretrained(
+            draft_model_last_checkpoint,
+            config=draft_config,
+            torch_dtype=torch.bfloat16,
+        ).cuda()
+        print_on_rank0(
+            f"Resumed draft model from checkpoint: {draft_model_last_checkpoint}"
+        )
+    else:
+        # Build draft model from scratch or provided config
+        if args.draft_config_path:
+            draft_config = AutoConfig.from_pretrained(args.draft_config_path)
+            print_on_rank0(f"Loaded draft config from {args.draft_config_path}")
+        else:
+            # Load config from HF (needed for structure info even if backend is sglang)
+            target_config = AutoConfig.from_pretrained(args.target_model_path)
+            draft_config = AutoConfig.from_pretrained(args.target_model_path)
+            draft_config.num_hidden_layers = args.num_draft_layers
+            draft_config.block_size = args.block_size
+            draft_config.num_target_layers = target_config.num_hidden_layers
+            print_on_rank0("Auto-generated draft config from target model")
+
+        # Set attention implementation based on backend
+        draft_config._attn_implementation = args.attention_backend
+        print_on_rank0(f"Using attention backend: {args.attention_backend}")
+
+        draft_model = DFlashDraftModel(draft_config).cuda().to(torch.bfloat16)
 
     target_model.set_capture_layers(draft_model.target_layer_ids)
 
@@ -261,7 +291,9 @@ def build_dataloader(args, tokenizer) -> Tuple[DataLoader, Optional[DataLoader]]
 
 def save_checkpoint(args, epoch, step, dflash_model, draft_model, optimizer):
     """Save checkpoint."""
-    save_dir = os.path.join(args.output_dir, f"epoch_{epoch}_step_{step}")
+    # Use naming convention that get_last_checkpoint can detect
+    save_dir = os.path.join(args.output_dir, f"epoch_{epoch}")
+
     if dist.get_rank() == 0:
         os.makedirs(save_dir, exist_ok=True)
     dist.barrier()
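
Note on how the hunks fit together: the resume branch in build_models relies on specforge.utils.get_last_checkpoint, and the save_dir change above exists so that helper can find the directories this script writes. The helper's implementation is not part of this diff; a minimal sketch of what it could look like under the epoch_<N> naming (an assumption, not the actual SpecForge code) is:

```python
import os
import re


def get_last_checkpoint(output_dir: str):
    """Return the epoch_<N> subdirectory with the highest N, or None if none exist.

    Sketch only: the real specforge.utils.get_last_checkpoint may differ.
    """
    pattern = re.compile(r"^epoch_(\d+)$")
    candidates = []
    for name in os.listdir(output_dir):
        match = pattern.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(match.group(1)), name))
    if not candidates:
        return None
    _, latest = max(candidates)  # highest epoch number wins
    return os.path.join(output_dir, latest)
```

Under this reading, dropping the _step_{step} suffix is what makes the directories easy to rank by epoch alone; the old epoch_{epoch}_step_{step} names would need extra parsing.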
@@ -534,11 +566,10 @@ def main():
                 }
             )
 
-            if global_step % args.save_interval == 0:
-                save_checkpoint(
-                    args, epoch, global_step, dflash_model, draft_model, optimizer
-                )
+        # Save checkpoint after each epoch
+        save_checkpoint(args, epoch, global_step, dflash_model, draft_model, optimizer)
 
+    # Final checkpoint
     save_checkpoint(
         args, args.num_epochs, global_step, dflash_model, draft_model, optimizer
     )
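
Taken together, the hunks switch from interval-based saving inside the training loop to one checkpoint per epoch plus a final one, and let build_models pick up the newest of them when args.resume is set. A quick sanity check that the epoch_{epoch} naming and the resume detection agree, reusing the hypothetical get_last_checkpoint sketch above, might look like:

```python
import os
import tempfile

# Sanity-check the epoch_<N> convention against the get_last_checkpoint sketch above
# (hypothetical helper; the real SpecForge utility may behave differently).
with tempfile.TemporaryDirectory() as output_dir:
    for epoch in range(3):
        # Mirrors save_checkpoint's save_dir construction from this diff.
        os.makedirs(os.path.join(output_dir, f"epoch_{epoch}"), exist_ok=True)
    assert get_last_checkpoint(output_dir).endswith("epoch_2")
```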