diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index fae0148ba887..826d7c14bf06 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -1019,7 +1019,7 @@ def _configure_defaults(): PARTITION_ACTIVATIONS = False CONTIGUOUS_CHECKPOINTING = False - num_layers = False + num_layers = None CPU_CHECKPOINT = False SYNCHRONIZE = False PROFILE_TIME = False diff --git a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py index dd3bcd7fb6bd..2a8aa5c14358 100644 --- a/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py +++ b/tests/unit/runtime/activation_checkpointing/test_activation_checkpointing.py @@ -309,3 +309,45 @@ def __init__(self): assert model._is_checkpointable([layers[0]]) == True # ParallelTransformerLayerPipe assert model._is_checkpointable([layers[1]]) == True # GMLPBlock assert model._is_checkpointable([layers[2]]) == False # Linear layer + + +def test_configure_with_contiguous_checkpointing_requires_num_checkpoints(): + # Regression: ``_configure_defaults`` previously initialized ``num_layers`` + # to ``False`` while the assert below uses ``is not None``; ``False is not + # None`` is True, so the missing-config assert silently passed and a + # cryptic ``IndexError`` surfaced later from ``range(num_layers)``. With + # the default switched to ``None`` (matching the module-level default), + # the helpful assert message fires at the configure() call site. + # + # ``configure()`` mutates module globals before raising, so snapshot and + # restore them around the call to avoid order-dependent failures in other + # activation-checkpointing tests sharing the same pytest worker. + cp = deepspeed.checkpointing + saved = ( + cp.PARTITION_ACTIVATIONS, + cp.CONTIGUOUS_CHECKPOINTING, + cp.num_layers, + cp.CPU_CHECKPOINT, + cp.SYNCHRONIZE, + cp.PROFILE_TIME, + cp.mpu, + cp.deepspeed_checkpointing_enabled, + ) + try: + with pytest.raises(AssertionError, match="number of layers"): + deepspeed.checkpointing.configure( + mpu_=None, + partition_activations=True, + contiguous_checkpointing=True, + ) + finally: + ( + cp.PARTITION_ACTIVATIONS, + cp.CONTIGUOUS_CHECKPOINTING, + cp.num_layers, + cp.CPU_CHECKPOINT, + cp.SYNCHRONIZE, + cp.PROFILE_TIME, + cp.mpu, + cp.deepspeed_checkpointing_enabled, + ) = saved