diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py index 8e4b3b122..321304373 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration.py @@ -521,7 +521,6 @@ def _get_possible_handlers( registry: CheckpointableHandlerRegistry, is_handleable: Callable[[CheckpointableHandler, Any], bool], checkpointable: Any | None, - name: str, ) -> Sequence[CheckpointableHandler]: """Raises a NoEntryError if no possible handlers are found.""" registry_entries = [ @@ -545,20 +544,6 @@ def _get_possible_handlers( if checkpointable_name is None and is_handleable(handler, checkpointable) ] - if not possible_handlers: - available_handlers = [ - handler_type for handler_type, _ in registry.get_all_entries() - ] - error_msg = ( - f'Could not identify a valid handler for the checkpointable: "{name}"' - f' and checkpointable type={type(checkpointable)}. Make sure to' - ' register a `CheckpointableHandler` for the object using' - ' `register_handler`, or by specifying a local registry' - ' (`CheckpointablesOptions`). If a handler is already registered,' - ' ensure that `is_handleable` correctly identifies the object as' - f' handleable. The available handlers are: {available_handlers}' - ) - raise NoEntryError(error_msg) return possible_handlers @@ -580,11 +565,14 @@ def resolve_handler_for_save( ) -> CheckpointableHandler: """Resolves a :py:class:`~.v1.handlers.CheckpointableHandler` for saving. - 1. If a name matching the provided checkpointable name is explicitly + 1. If the checkpointable is a StatefulCheckpointable, prefer to use a + handler that supports it (e.g. StatefulCheckpointableHandler), bypassing + explicit name registration. + 2. If a name matching the provided checkpointable name is explicitly registered, return the corresponding handler. - 2. 
Resolve based on the `checkpointable` (using + 3. Resolve based on the `checkpointable` (using :py:meth:`~.v1._src.handlers.types.CheckpointableHandler.is_handleable`). - 3. If multiple handlers are usable, return the *last* usable handler. This + 4. If multiple handlers are usable, return the *last* usable handler. This allows us to resolve the most recently-registered handler. Args: @@ -602,9 +590,6 @@ def resolve_handler_for_save( NoEntryError: If no compatible :py:class:`~.v1.handlers.CheckpointableHandler` can be found. """ - # If explicitly registered, use that first. - if registry.has(name): - return _construct_handler_instance(name, registry.get(name)) if checkpointable is None: raise ValueError('checkpointable must not be None for saving.') @@ -613,11 +598,38 @@ def is_handleable(handler: CheckpointableHandler, ckpt: Any) -> bool: return handler.is_handleable(ckpt) possible_handlers = _get_possible_handlers( - registry, is_handleable, checkpointable, name + registry, is_handleable, checkpointable ) + possible_handler = possible_handlers[-1] if possible_handlers else None - # Prefer the last handler in the absence of any other information. - return possible_handlers[-1] + # 1. If the checkpointable is a StatefulCheckpointable, prefer to use a + # handler that supports it, bypassing explicit name registration. + if ( + isinstance(checkpointable, handler_types.StatefulCheckpointable) + and possible_handler + ): + return possible_handler + + # 2. If explicitly registered, use that. + if registry.has(name): + return _construct_handler_instance(name, registry.get(name)) + + # 3 & 4. Resolve based on the checkpointable and return the last usable. + if possible_handler: + return possible_handler + + available_handlers = [ + handler_type for handler_type, _ in registry.get_all_entries() + ] + raise NoEntryError( + f'Could not identify a valid handler for the checkpointable: "{name}"' + f' and checkpointable type={type(checkpointable)}. 
Make sure to' + ' register a `CheckpointableHandler` for the object using' + ' `register_handler`, or by specifying a local registry' + ' (`CheckpointablesOptions`). If a handler is already registered,' + ' ensure that `is_handleable` correctly identifies the object as' + f' handleable. The available handlers are: {available_handlers}' + ) def resolve_handler_for_load( @@ -629,17 +641,16 @@ def resolve_handler_for_load( ) -> CheckpointableHandler: """Resolves a :py:class:`~.v1.handlers.CheckpointableHandler` for loading. - 1. If name is explicitly registered, return the handler. - 2. Resolve based on the `abstract_checkpointable` (using - :py:meth:`~.v1._src.handlers.types.CheckpointableHandler.is_abstract_handleable`). - 3. If `abstract_checkpointable` is None or not provided, all registered - handlers not scoped to a specific item name are potentially usable. - 4. If multiple handlers are usable, return the handler with the matching - typestr. If no matching typestr is found, then the handler used for saving - may not be available now. - 5. Return the *last* usable handler. This allows us to resolve the most - recently-registered handler, unless abstract_checkpointable is None, in - which case raise a NoEntryError. + 1. If `abstract_checkpointable` is a `StatefulCheckpointable`, prefer the + handler matching `handler_typestr` if it is handleable. + 2. If `name` is explicitly registered, return its handler (provided it is + handleable or `abstract_checkpointable` is `None`). + 3. If `handler_typestr` matches a registered handler, return it (provided it + is handleable or `abstract_checkpointable` is `None`). + 4. If `abstract_checkpointable` is provided, return the most recently + registered handler that can handle it. + 5. Fall back to the explicitly registered handler for `name` even if + incompatible, otherwise raise `NoEntryError`. 
Args: registry: The @@ -660,30 +671,26 @@ def resolve_handler_for_load( :py:class:`~.v1.handlers.CheckpointableHandler` can be found. """ - # If explicitly registered, use that first. - if registry.has(name): - return _construct_handler_instance(name, registry.get(name)) + explicit_handler = ( + _construct_handler_instance(name, registry.get(name)) + if registry.has(name) + else None + ) def is_handleable(handler: CheckpointableHandler, ckpt: Any) -> bool | None: return handler.is_abstract_handleable(ckpt) - possible_handlers = _get_possible_handlers( - registry, is_handleable, abstract_checkpointable, name - ) - possible_handler_typestrs = [ - handler_types.typestr(type(handler)) for handler in possible_handlers - ] - + # Find the handler matching the typestr from the checkpoint metadata. + resolved_by_typestr = None if handler_typestr: - if handler_typestr in possible_handler_typestrs: - idx = possible_handler_typestrs.index(handler_typestr) - return possible_handlers[idx] - # Attempt to find a handler with a matching secondary typestr. - for i in reversed(range(len(possible_handlers))): - if handler_typestr in registry.get_secondary_typestrs( - type(possible_handlers[i]) - ): - return possible_handlers[i] + for h_type, ckpt_name in reversed(registry.get_all_entries()): + h_type_str = handler_types.typestr(h_type) + secondary_typestrs = registry.get_secondary_typestrs(h_type) + if h_type_str == handler_typestr or handler_typestr in secondary_typestrs: + resolved_by_typestr = _construct_handler_instance(ckpt_name, h_type) + break + + if handler_typestr and not resolved_by_typestr: logging.warning( 'No handler found for typestr %s (or its converted form). The ' 'checkpointable may be restored with different handler logic ' @@ -691,9 +698,43 @@ def is_handleable(handler: CheckpointableHandler, ckpt: Any) -> bool | None: handler_typestr, ) - if abstract_checkpointable: - # Prefer the last handler in the absence of any other information. 
- return possible_handlers[-1] + # Determine if we're in a "stateful" context. + is_stateful = False + if abstract_checkpointable is not None: + is_stateful = isinstance( + abstract_checkpointable, handler_types.StatefulCheckpointable + ) + + # 1. If stateful, prefer the stateful handler over explicit name. + if is_stateful and resolved_by_typestr: + if is_handleable(resolved_by_typestr, abstract_checkpointable): + return resolved_by_typestr + + # 2. Explicitly registered handler. + if explicit_handler: + if abstract_checkpointable is None or is_handleable( + explicit_handler, abstract_checkpointable + ): + return explicit_handler + + # 3. Any handler matching the typestr. + if resolved_by_typestr: + if abstract_checkpointable is None or is_handleable( + resolved_by_typestr, abstract_checkpointable + ): + return resolved_by_typestr + + # 4. Any handler that can handle the object. + if abstract_checkpointable is not None: + possible_handlers = _get_possible_handlers( + registry, is_handleable, abstract_checkpointable + ) + if possible_handlers: + return possible_handlers[-1] + + # 5. Fallback: Return explicit handler even if incompatible. 
+ if explicit_handler: + return explicit_handler raise NoEntryError( f'No entry for checkpointable={name} in the registry, using' diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration_test.py index 5a4d604a6..54401c708 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/registration_test.py @@ -16,6 +16,7 @@ from absl.testing import absltest from absl.testing import parameterized from orbax.checkpoint.experimental.v1._src.handlers import registration +from orbax.checkpoint.experimental.v1._src.handlers import stateful_checkpointable_handler from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types from orbax.checkpoint.experimental.v1._src.testing import handler_utils @@ -173,6 +174,24 @@ def test_resolve_handler_for_save_abstract_checkpointable(self): local_registry, handler_utils.AbstractFoo(), name='foo' ) + def test_resolve_handler_for_save_stateful_checkpointable_priority(self): + local_registry = registration.local_registry() + local_registry.add( + handler_utils.FooHandler, checkpointable_name='custom_name' + ) + local_registry.add( + stateful_checkpointable_handler.StatefulCheckpointableHandler, + ) + resolved_handler = registration.resolve_handler_for_save( + local_registry, + handler_utils.Point(1, 2), + name='custom_name', + ) + self.assertIsInstance( + resolved_handler, + stateful_checkpointable_handler.StatefulCheckpointableHandler, + ) + class ResolveHandlerForLoadTest(RegistrationTestBase): @@ -374,6 +393,27 @@ def test_resolve_handler_for_load_no_checkpointable_no_metadata(self): handler_typestr=None, ) + def test_resolve_handler_for_load_stateful_checkpointable_priority(self): + local_registry = registration.local_registry() + local_registry.add( + handler_utils.FooHandler, checkpointable_name='custom_name' + ) + 
local_registry.add( + stateful_checkpointable_handler.StatefulCheckpointableHandler, + ) + resolved_handler = registration.resolve_handler_for_load( + local_registry, + handler_utils.Point(1, 2), + name='custom_name', + handler_typestr=handler_types.typestr( + stateful_checkpointable_handler.StatefulCheckpointableHandler + ), + ) + self.assertIsInstance( + resolved_handler, + stateful_checkpointable_handler.StatefulCheckpointableHandler, + ) + if __name__ == '__main__': absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py index 885afde06..9b7c88887 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py @@ -22,10 +22,8 @@ from orbax.checkpoint._src.path import async_path from orbax.checkpoint._src.path import utils as ocp_path_utils from orbax.checkpoint.experimental.v1._src.context import context as context_lib -from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.handlers import global_registration # pylint: disable=unused-import from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler -from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.handlers import stateful_checkpointable_handler from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout @@ -222,36 +220,15 @@ def save_pytree_async( if path.exists(): raise FileExistsError(f'Finalized checkpoint already exists at {path}.') - # By default, the registry associates 'pytree' (PYTREE_CHECKPOINTABLE_KEY) - # with PyTreeHandler. We want to use StatefulCheckpointableHandler for our - # wrapper (_PartialSavePyTree) to carry the partial save flag. 
Since - # name-based resolution takes priority, we override 'pytree' in a local - # registry. - current_reg = ctx.checkpointables_options.registry - local_reg = registration.local_registry(include_global_registry=False) - for handler, name in current_reg.get_all_entries(): - if name != PYTREE_CHECKPOINTABLE_KEY: - local_reg.add( - handler, - checkpointable_name=name, - secondary_typestrs=current_reg.get_secondary_typestrs(handler), - ) - local_reg.add( - StatefulCheckpointableHandler, - checkpointable_name=PYTREE_CHECKPOINTABLE_KEY, + return execution.save_checkpointables_impl( + partial_path_lib.add_partial_save_suffix(path), + {PYTREE_CHECKPOINTABLE_KEY: _PartialSavePyTree(pytree)}, + overwrite=False, + custom_metadata=custom_metadata, + async_origin=True, + partial_save=True, ) - new_options = options_lib.CheckpointablesOptions(registry=local_reg) - with context_lib.Context(ctx, checkpointables_options=new_options): - return execution.save_checkpointables_impl( - partial_path_lib.add_partial_save_suffix(path), - {PYTREE_CHECKPOINTABLE_KEY: _PartialSavePyTree(pytree)}, - overwrite=False, - custom_metadata=custom_metadata, - async_origin=True, - partial_save=True, - ) - def finalize(path: path_types.PathLike) -> None: """Finalizes a partially-saved checkpoint, making it permanent and readable. 
diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py index 9724f715b..edcc5e006 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/save_load_test_base.py @@ -226,13 +226,18 @@ async def mock_finalize(self_handler, directory): (tuple([]),), (dict(),), (list(),), - (None,), (optax.EmptyState(),), ) def test_empty_tree(self, tree): - with self.assertRaisesRegex(ValueError, 'empty'): + with self.assertRaisesRegex(ValueError, 'Found empty item'): ocp.save_pytree(self.directory, tree) + def test_none_tree(self): + with self.assertRaisesRegex( + ValueError, 'checkpointable must not be None for saving' + ): + ocp.save_pytree(self.directory, None) + # Note the ommission of jax.Array, since this is covered in # several other tests. @parameterized.parameters(