Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ def _configure_defaults():

PARTITION_ACTIVATIONS = False
CONTIGUOUS_CHECKPOINTING = False
num_layers = False
num_layers = None
CPU_CHECKPOINT = False
SYNCHRONIZE = False
PROFILE_TIME = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,45 @@ def __init__(self):
assert model._is_checkpointable([layers[0]]) == True # ParallelTransformerLayerPipe
assert model._is_checkpointable([layers[1]]) == True # GMLPBlock
assert model._is_checkpointable([layers[2]]) == False # Linear layer


def test_configure_with_contiguous_checkpointing_requires_num_checkpoints():
    """Check that configure() rejects contiguous checkpointing without a layer count.

    Regression test: ``_configure_defaults`` used to initialize ``num_layers``
    to ``False``, while the guarding assert tests ``is not None``. Because
    ``False is not None`` evaluates to True, the missing-config assert passed
    silently and a cryptic ``IndexError`` surfaced later from
    ``range(num_layers)``. With the default switched to ``None`` (matching the
    module-level default), the helpful assert message fires at the
    ``configure()`` call site instead.
    """
    ckpt = deepspeed.checkpointing

    # configure() mutates these module globals before it raises. Snapshot them
    # up front and put them back afterwards so other activation-checkpointing
    # tests sharing the same pytest worker do not become order-dependent.
    tracked_globals = (
        "PARTITION_ACTIVATIONS",
        "CONTIGUOUS_CHECKPOINTING",
        "num_layers",
        "CPU_CHECKPOINT",
        "SYNCHRONIZE",
        "PROFILE_TIME",
        "mpu",
        "deepspeed_checkpointing_enabled",
    )
    snapshot = {name: getattr(ckpt, name) for name in tracked_globals}

    try:
        # No layer count is supplied, so the assert mentioning the
        # "number of layers" is expected to trigger.
        with pytest.raises(AssertionError, match="number of layers"):
            ckpt.configure(
                mpu_=None,
                partition_activations=True,
                contiguous_checkpointing=True,
            )
    finally:
        # Restore every captured global regardless of the test outcome.
        for name, original_value in snapshot.items():
            setattr(ckpt, name, original_value)
Loading