Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions convert_to_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import json
import re
from enum import Enum
from pathlib import Path
from typing import Annotated

import torch
import typer
from composer.models import write_huggingface_pretrained_from_composer_checkpoint
from typer import Option
from safetensors.torch import save_file as safetensors_save_file

# Typer CLI application: accepts both -h and --help, and suppresses local
# variables in pretty-printed exception tracebacks.
app = typer.Typer(context_settings={"help_option_names": ["-h", "--help"]}, pretty_exceptions_show_locals=False)


class TorchDtype(str, Enum):
    """Torch dtypes accepted on the CLI and written to the exported ``config.json`` ``torch_dtype`` field."""

    float32 = "float32"
    float16 = "float16"
    bfloat16 = "bfloat16"


def update_config(
    source_config: dict,
    bos_token_id: int,
    eos_token_id: int,
    cls_token_id: int,
    pad_token_id: int,
    sep_token_id: int,
    max_length: int,
    torch_dtype: TorchDtype,
) -> dict:
    """Translate a Composer/FlexBERT config dict into a HF ModernBERT config dict.

    Reads the Composer-style keys from ``source_config`` and returns a fresh
    dict with the key names the ``transformers`` ModernBERT implementation
    expects. Special-token ids, ``max_position_embeddings`` and ``torch_dtype``
    are taken from the explicit arguments rather than the source config.
    """
    norm_kwargs = source_config["norm_kwargs"]
    global_rope_theta = source_config["rotary_emb_base"]

    # A falsy or -1 local rotary base means "reuse the global rope theta".
    local_rope_theta = source_config["local_attn_rotary_emb_base"]
    if not local_rope_theta or local_rope_theta == -1:
        local_rope_theta = global_rope_theta

    return {
        "architectures": ["ModernBertForMaskedLM"],
        "attention_bias": source_config["attn_out_bias"],
        "attention_dropout": source_config["attention_probs_dropout_prob"],
        "bos_token_id": bos_token_id,
        # Older configs may not carry a dedicated head activation; fall back to
        # the model's hidden activation.
        "classifier_activation": source_config.get("head_class_act", source_config["hidden_act"]),
        "classifier_bias": source_config["head_class_bias"],
        "classifier_dropout": source_config["head_class_dropout"],
        "classifier_pooling": "mean",
        "cls_token_id": cls_token_id,
        "decoder_bias": source_config["decoder_bias"],
        "deterministic_flash_attn": source_config["deterministic_fa2"],
        "embedding_dropout": source_config["embed_dropout_prob"],
        "eos_token_id": eos_token_id,
        "global_attn_every_n_layers": source_config["global_attn_every_n_layers"],
        "global_rope_theta": global_rope_theta,
        "gradient_checkpointing": source_config["gradient_checkpointing"],
        "hidden_activation": source_config["hidden_act"],
        "hidden_size": source_config["hidden_size"],
        "initializer_cutoff_factor": source_config["init_cutoff_factor"],
        "initializer_range": source_config["initializer_range"],
        "intermediate_size": source_config["intermediate_size"],
        "layer_norm_eps": norm_kwargs["eps"],
        "local_attention": source_config["sliding_window"],
        "local_rope_theta": local_rope_theta,
        "max_position_embeddings": max_length,  # overridden by the CLI argument
        "mlp_bias": source_config["mlp_in_bias"],
        "mlp_dropout": source_config["mlp_dropout_prob"],
        "model_type": "modernbert",
        "norm_bias": norm_kwargs["bias"],
        "norm_eps": norm_kwargs["eps"],
        "num_attention_heads": source_config["num_attention_heads"],
        "num_hidden_layers": source_config["num_hidden_layers"],
        "pad_token_id": pad_token_id,
        "position_embedding_type": source_config["position_embedding_type"],
        "sep_token_id": sep_token_id,
        "tie_word_embeddings": source_config.get("tie_word_embeddings", True),
        "torch_dtype": torch_dtype.value,
        "transformers_version": "4.48.0",
        "vocab_size": source_config["vocab_size"],
    }


@app.command(help="Convert a ModernBERT Composer checkpoint to HuggingFace pretrained format.")
def main(
    output_name: Annotated[str, Option(help="Name of the output model", show_default=False)],
    output_dir: Annotated[Path, Option(help="Path to the output directory", show_default=False)],
    input_checkpoint: Annotated[Path, Option(help="Path to the ModernBERT Composer checkpoint file", show_default=False)],
    bos_token_id: Annotated[int, Option(help="ID of the BOS token. Defaults to the ModernBERT BOS token.")] = 50281,
    eos_token_id: Annotated[int, Option(help="ID of the EOS token. Defaults to the ModernBERT EOS token.")] = 50282,
    cls_token_id: Annotated[int, Option(help="ID of the CLS token. Defaults to the ModernBERT CLS token.")] = 50281,
    sep_token_id: Annotated[int, Option(help="ID of the SEP token. Defaults to the ModernBERT SEP token.")] = 50282,
    pad_token_id: Annotated[int, Option(help="ID of the PAD token. Defaults to the ModernBERT PAD token.")] = 50283,
    mask_token_id: Annotated[int, Option(help="ID of the MASK token. Defaults to the ModernBERT MASK token.")] = 50284,
    max_length: Annotated[int, Option(help="Maximum length of the input sequence. Defaults to the final ModernBERT sequence length.")] = 8192,
    torch_dtype: Annotated[TorchDtype, Option(help="Torch dtype to use for the model.")] = TorchDtype.float32,
    pytorch_bin: Annotated[bool, Option(help="Save weights as a pytorch_model.bin file.")] = True,
    safetensors: Annotated[bool, Option(help="Save weights as a model.safetensors file.")] = True,
    drop_tied_decoder_weights: Annotated[bool, Option(help="Don't save the weight tied decoder weights.")] = True,
): # fmt: skip
    """
    Convert a ModernBERT Composer checkpoint to HuggingFace pretrained format.

    Exports the checkpoint with Composer's HF writer, then rewrites the
    exported artifacts in place: state-dict key names, ``config.json``,
    ``tokenizer_config.json``, and ``special_tokens_map.json`` are all
    aligned with the ``transformers`` ModernBERT layout.
    """
    target_path = Path(output_dir) / output_name
    write_huggingface_pretrained_from_composer_checkpoint(input_checkpoint, str(target_path))

    # Remap Composer/FlexBERT state-dict key names to the HF ModernBERT naming.
    state_dict_path = target_path / "pytorch_model.bin"
    # NOTE(security): torch.load unpickles arbitrary objects — only run this
    # on checkpoints from a trusted source.
    state_dict = torch.load(state_dict_path, map_location=torch.device("cpu"))
    var_map = (
        (re.compile(r"encoder\.layers\.(.*)"), r"layers.\1"),
        (re.compile(r"^bert\.(.*)"), r"model.\1"),  # 'bert.' -> 'model.' at the start of keys
    )
    for pattern, replacement in var_map:
        state_dict = {re.sub(pattern, replacement, name): tensor for name, tensor in state_dict.items()}

    # Rewrite config.json into the HF ModernBERT schema.
    config_json_path = target_path / "config.json"
    with open(config_json_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    config_dict = update_config(
        config_dict, bos_token_id, eos_token_id, cls_token_id, pad_token_id, sep_token_id, max_length, torch_dtype
    )
    with open(config_json_path, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2)

    # With tied embeddings, the decoder weight is a duplicate of the input
    # embedding, so it can be dropped from the export when requested.
    if config_dict.get("tie_word_embeddings", False) and drop_tied_decoder_weights:
        state_dict.pop("decoder.weight", None)

    # Export to pytorch_model.bin
    if pytorch_bin:
        torch.save(state_dict, state_dict_path)

    # Export to safetensors
    if safetensors:
        safetensors_save_file(state_dict, target_path / "model.safetensors")

    # Update tokenizer_config.json to match the ModernBERT tokenizer settings.
    tokenizer_config_path = target_path / "tokenizer_config.json"
    with open(tokenizer_config_path, "r", encoding="utf-8") as f:
        tokenizer_config = json.load(f)
    tokenizer_config["model_max_length"] = max_length
    # added_tokens_decoder is keyed by the token id as a string.
    tokenizer_config["added_tokens_decoder"][str(mask_token_id)]["lstrip"] = True
    tokenizer_config["model_input_names"] = ["input_ids", "attention_mask"]
    tokenizer_config["tokenizer_class"] = "PreTrainedTokenizerFast"

    if "extra_special_tokens" in tokenizer_config:
        del tokenizer_config["extra_special_tokens"]
    with open(tokenizer_config_path, "w", encoding="utf-8") as f:
        json.dump(tokenizer_config, f, indent=2)

    # Update special_tokens_map.json so the mask token strips left whitespace,
    # consistent with the added_tokens_decoder entry above.
    special_tokens_path = target_path / "special_tokens_map.json"
    with open(special_tokens_path, "r", encoding="utf-8") as f:
        special_tokens_map = json.load(f)
    special_tokens_map["mask_token"]["lstrip"] = True
    with open(special_tokens_path, "w", encoding="utf-8") as f:
        json.dump(special_tokens_map, f, indent=2)


# Run the Typer CLI when executed as a script.
if __name__ == "__main__":
    app()
159 changes: 159 additions & 0 deletions yamls/modernbert/modernbert-base-context-extension.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
data_local: data_folder
data_remote: # If blank, files must be present in data_local

pretrain_data_local: pretrain_data_folder # set this to use pretraining data for validation metrics
pretrain_data_remote: # If blank, files must be present in pretrain_data_local

max_seq_len: 8192
tokenizer_name: answerdotai/ModernBERT-base
mlm_probability: 0.3 # FlexBERT should use 30% masking for optimal performance
count_padding_tokens: false

# Run Name
run_name: modernbert-base-context-extension
pretrain_run_name: modernbert-base-pretrain

# Model
model:
name: flex_bert
  pretrained_model_name: bert-base-uncased # has to be set to bert-base-uncased for legacy reasons
tokenizer_name: ${tokenizer_name}
disable_train_metrics: true # save some time by not computing metrics on the training set
model_config:
vocab_size: 50368
init_method: full_megatron
num_hidden_layers: 22
hidden_size: 768
intermediate_size: 1152
num_attention_heads: 12 # to have head size of 64
attention_layer: rope
attention_probs_dropout_prob: 0.0
attn_out_bias: false
attn_out_dropout_prob: 0.1
attn_qkv_bias: false
bert_layer: prenorm
embed_dropout_prob: 0.0
embed_norm: true
final_norm: true
skip_first_prenorm: true
embedding_layer: sans_pos
loss_function: fa_cross_entropy
loss_kwargs:
reduction: mean
mlp_dropout_prob: 0.0
mlp_in_bias: false
mlp_layer: glu
mlp_out_bias: false
normalization: layernorm
norm_kwargs:
eps: 1e-5
bias: false
hidden_act: gelu
head_pred_act: gelu
activation_function: gelu # better safe than sorry
padding: unpadded
rotary_emb_dim: null
rotary_emb_base: 160000.0
rotary_emb_scale_base: null
rotary_emb_interleaved: false
local_attn_rotary_emb_base: 10000.0
local_attn_rotary_emb_dim: null
allow_embedding_resizing: true
sliding_window: 128
global_attn_every_n_layers: 3
unpad_embeddings: true
compile_model: true
masked_prediction: true

# Dataloaders
train_loader:
name: text
dataset:
local: ${data_local}
remote: ${data_remote}
split:
tokenizer_name: ${tokenizer_name}
max_seq_len: ${max_seq_len}
shuffle: true
mlm_probability: ${mlm_probability}
streaming: false
shuffle_seed: 2998
drop_last: true
num_workers: 6
sequence_packing: true

eval_loader:
name: text
dataset:
local: ${pretrain_data_local}
remote: ${pretrain_data_remote}
split: validation
tokenizer_name: ${tokenizer_name}
max_seq_len: ${max_seq_len}
shuffle: false
mlm_probability: 0.15 # We always evaluate at 15% masking for consistent comparison
streaming: false
drop_last: false
num_workers: 3
sequence_packing: false


# Optimization
scheduler:
name: constant_with_warmup
t_warmup: 0tok
t_max: ${max_duration}

optimizer:
name: decoupled_stableadamw
lr: 3e-4 # Peak learning rate
betas:
- 0.9
- 0.98
eps: 1.0e-06
weight_decay: 1.0e-5 # Amount of weight decay regularization
filter_bias_norm_wd: true # If True, doesn't apply weight decay to norm layers and biases
log_grad_norm: true

max_duration: 250_000_000_000tok
eval_interval: 4000ba
global_train_batch_size: 576
global_eval_batch_size: 1024

# System
seed: 17
device_eval_batch_size: 128
device_train_microbatch_size: 12
precision: amp_bf16

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 500ba

callbacks:
speed_monitor:
window_size: 50
lr_monitor: {}
scheduled_gc: {}
log_grad_norm:
batch_log_interval: 10
packing_efficiency:
log_interval: 10

# W&B logging
# loggers:
# wandb:
# project:
# entity:

save_interval: 4000ba
save_num_checkpoints_to_keep: -1 # Important, this cleans up checkpoints saved to DISK
save_folder: checkpoints/{run_name}

# Load the pretraining checkpoint from the local filesystem or a remote object store
load_path: checkpoints/{pretrain_run_name}/latest-rank0.pt

autoresume: false
reset_time: true # restarts the scheduler, dataloaders, etc from step zero
restart_override: true # resets optimizer hyperparameters (LR, WD, etc), LR Scheduler, and training microbatch size from the checkpoint's values
Loading