Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,6 @@ def _get_possible_handlers(
registry: CheckpointableHandlerRegistry,
is_handleable: Callable[[CheckpointableHandler, Any], bool],
checkpointable: Any | None,
name: str,
) -> Sequence[CheckpointableHandler]:
"""Raises a NoEntryError if no possible handlers are found."""
registry_entries = [
Expand All @@ -545,20 +544,6 @@ def _get_possible_handlers(
if checkpointable_name is None
and is_handleable(handler, checkpointable)
]
if not possible_handlers:
available_handlers = [
handler_type for handler_type, _ in registry.get_all_entries()
]
error_msg = (
f'Could not identify a valid handler for the checkpointable: "{name}"'
f' and checkpointable type={type(checkpointable)}. Make sure to'
' register a `CheckpointableHandler` for the object using'
' `register_handler`, or by specifying a local registry'
' (`CheckpointablesOptions`). If a handler is already registered,'
' ensure that `is_handleable` correctly identifies the object as'
f' handleable. The available handlers are: {available_handlers}'
)
raise NoEntryError(error_msg)
return possible_handlers


Expand All @@ -580,11 +565,14 @@ def resolve_handler_for_save(
) -> CheckpointableHandler:
"""Resolves a :py:class:`~.v1.handlers.CheckpointableHandler` for saving.

1. If a name matching the provided checkpointable name is explicitly
1. If the checkpointable is a StatefulCheckpointable, prefer to use a
handler that supports it (e.g. StatefulCheckpointableHandler), bypassing
explicit name registration.
2. If a name matching the provided checkpointable name is explicitly
registered, return the corresponding handler.
2. Resolve based on the `checkpointable` (using
3. Resolve based on the `checkpointable` (using
:py:meth:`~.v1._src.handlers.types.CheckpointableHandler.is_handleable`).
3. If multiple handlers are usable, return the *last* usable handler. This
4. If multiple handlers are usable, return the *last* usable handler. This
allows us to resolve the most recently-registered handler.

Args:
Expand All @@ -602,9 +590,6 @@ def resolve_handler_for_save(
NoEntryError: If no compatible
:py:class:`~.v1.handlers.CheckpointableHandler` can be found.
"""
# If explicitly registered, use that first.
if registry.has(name):
return _construct_handler_instance(name, registry.get(name))

if checkpointable is None:
raise ValueError('checkpointable must not be None for saving.')
Expand All @@ -613,11 +598,38 @@ def is_handleable(handler: CheckpointableHandler, ckpt: Any) -> bool:
return handler.is_handleable(ckpt)

possible_handlers = _get_possible_handlers(
registry, is_handleable, checkpointable, name
registry, is_handleable, checkpointable
)
possible_handler = possible_handlers[-1] if possible_handlers else None

# Prefer the last handler in the absence of any other information.
return possible_handlers[-1]
# 1. If the checkpointable is a StatefulCheckpointable, prefer to use a
# handler that supports it, bypassing explicit name registration.
if (
isinstance(checkpointable, handler_types.StatefulCheckpointable)
and possible_handler
):
return possible_handler

# 2. If explicitly registered, use that.
if registry.has(name):
return _construct_handler_instance(name, registry.get(name))

# 3 & 4. Resolve based on the checkpointable and return the last usable.
if possible_handler:
return possible_handler

available_handlers = [
handler_type for handler_type, _ in registry.get_all_entries()
]
raise NoEntryError(
f'Could not identify a valid handler for the checkpointable: "{name}"'
f' and checkpointable type={type(checkpointable)}. Make sure to'
' register a `CheckpointableHandler` for the object using'
' `register_handler`, or by specifying a local registry'
' (`CheckpointablesOptions`). If a handler is already registered,'
' ensure that `is_handleable` correctly identifies the object as'
f' handleable. The available handlers are: {available_handlers}'
)


def resolve_handler_for_load(
Expand All @@ -629,17 +641,16 @@ def resolve_handler_for_load(
) -> CheckpointableHandler:
"""Resolves a :py:class:`~.v1.handlers.CheckpointableHandler` for loading.

1. If name is explicitly registered, return the handler.
2. Resolve based on the `abstract_checkpointable` (using
:py:meth:`~.v1._src.handlers.types.CheckpointableHandler.is_abstract_handleable`).
3. If `abstract_checkpointable` is None or not provided, all registered
handlers not scoped to a specific item name are potentially usable.
4. If multiple handlers are usable, return the handler with the matching
typestr. If no matching typestr is found, then the handler used for saving
may not be available now.
5. Return the *last* usable handler. This allows us to resolve the most
recently-registered handler, unless abstract_checkpointable is None, in
which case raise a NoEntryError.
1. If `abstract_checkpointable` is a `StatefulCheckpointable`, prefer the
handler matching `handler_typestr` if it is handleable.
2. If `name` is explicitly registered, return its handler (provided it is
handleable or `abstract_checkpointable` is `None`).
3. If `handler_typestr` matches a registered handler, return it (provided it
is handleable or `abstract_checkpointable` is `None`).
4. If `abstract_checkpointable` is provided, return the most recently
registered handler that can handle it.
5. Fallback to the explicitly registered handler for `name` even if
incompatible, otherwise raise `NoEntryError`.

Args:
registry: The
Expand All @@ -660,40 +671,70 @@ def resolve_handler_for_load(
:py:class:`~.v1.handlers.CheckpointableHandler`
can be found.
"""
# If explicitly registered, use that first.
if registry.has(name):
return _construct_handler_instance(name, registry.get(name))
explicit_handler = (
_construct_handler_instance(name, registry.get(name))
if registry.has(name)
else None
)

def is_handleable(handler: CheckpointableHandler, ckpt: Any) -> bool | None:
return handler.is_abstract_handleable(ckpt)

possible_handlers = _get_possible_handlers(
registry, is_handleable, abstract_checkpointable, name
)
possible_handler_typestrs = [
handler_types.typestr(type(handler)) for handler in possible_handlers
]

# Find the handler matching the typestr from the checkpoint metadata.
resolved_by_typestr = None
if handler_typestr:
if handler_typestr in possible_handler_typestrs:
idx = possible_handler_typestrs.index(handler_typestr)
return possible_handlers[idx]
# Attempt to find a handler with a matching secondary typestr.
for i in reversed(range(len(possible_handlers))):
if handler_typestr in registry.get_secondary_typestrs(
type(possible_handlers[i])
):
return possible_handlers[i]
for h_type, ckpt_name in reversed(registry.get_all_entries()):
h_type_str = handler_types.typestr(h_type)
secondary_typestrs = registry.get_secondary_typestrs(h_type)
if h_type_str == handler_typestr or handler_typestr in secondary_typestrs:
resolved_by_typestr = _construct_handler_instance(ckpt_name, h_type)
break

if handler_typestr and not resolved_by_typestr:
logging.warning(
'No handler found for typestr %s (or its converted form). The '
'checkpointable may be restored with different handler logic '
'than was used for saving.',
handler_typestr,
)

if abstract_checkpointable:
# Prefer the last handler in the absence of any other information.
return possible_handlers[-1]
# Determine if we're in a "stateful" context.
is_stateful = False
if abstract_checkpointable is not None:
is_stateful = isinstance(
abstract_checkpointable, handler_types.StatefulCheckpointable
)

# 1. If stateful, prefer the stateful handler over explicit name.
if is_stateful and resolved_by_typestr:
if is_handleable(resolved_by_typestr, abstract_checkpointable):
return resolved_by_typestr

# 2. Explicitly registered handler.
if explicit_handler:
if abstract_checkpointable is None or is_handleable(
explicit_handler, abstract_checkpointable
):
return explicit_handler

# 3. Any handler matching the typestr.
if resolved_by_typestr:
if abstract_checkpointable is None or is_handleable(
resolved_by_typestr, abstract_checkpointable
):
return resolved_by_typestr

# 4. Any handler that can handle the object.
if abstract_checkpointable is not None:
possible_handlers = _get_possible_handlers(
registry, is_handleable, abstract_checkpointable
)
if possible_handlers:
return possible_handlers[-1]

# 5. Fallback: Return explicit handler even if incompatible.
if explicit_handler:
return explicit_handler

raise NoEntryError(
f'No entry for checkpointable={name} in the registry, using'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from absl.testing import absltest
from absl.testing import parameterized
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.handlers import stateful_checkpointable_handler
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
from orbax.checkpoint.experimental.v1._src.testing import handler_utils

Expand Down Expand Up @@ -173,6 +174,24 @@ def test_resolve_handler_for_save_abstract_checkpointable(self):
local_registry, handler_utils.AbstractFoo(), name='foo'
)

def test_resolve_handler_for_save_stateful_checkpointable_priority(self):
local_registry = registration.local_registry()
local_registry.add(
handler_utils.FooHandler, checkpointable_name='custom_name'
)
local_registry.add(
stateful_checkpointable_handler.StatefulCheckpointableHandler,
)
resolved_handler = registration.resolve_handler_for_save(
local_registry,
handler_utils.Point(1, 2),
name='custom_name',
)
self.assertIsInstance(
resolved_handler,
stateful_checkpointable_handler.StatefulCheckpointableHandler,
)


class ResolveHandlerForLoadTest(RegistrationTestBase):

Expand Down Expand Up @@ -374,6 +393,27 @@ def test_resolve_handler_for_load_no_checkpointable_no_metadata(self):
handler_typestr=None,
)

def test_resolve_handler_for_load_stateful_checkpointable_priority(self):
local_registry = registration.local_registry()
local_registry.add(
handler_utils.FooHandler, checkpointable_name='custom_name'
)
local_registry.add(
stateful_checkpointable_handler.StatefulCheckpointableHandler,
)
resolved_handler = registration.resolve_handler_for_load(
local_registry,
handler_utils.Point(1, 2),
name='custom_name',
handler_typestr=handler_types.typestr(
stateful_checkpointable_handler.StatefulCheckpointableHandler
),
)
self.assertIsInstance(
resolved_handler,
stateful_checkpointable_handler.StatefulCheckpointableHandler,
)


if __name__ == '__main__':
absltest.main()
37 changes: 7 additions & 30 deletions checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import utils as ocp_path_utils
from orbax.checkpoint.experimental.v1._src.context import context as context_lib
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
from orbax.checkpoint.experimental.v1._src.handlers import global_registration # pylint: disable=unused-import
from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.handlers import stateful_checkpointable_handler
from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout
Expand Down Expand Up @@ -222,36 +220,15 @@ def save_pytree_async(
if path.exists():
raise FileExistsError(f'Finalized checkpoint already exists at {path}.')

# By default, the registry associates 'pytree' (PYTREE_CHECKPOINTABLE_KEY)
# with PyTreeHandler. We want to use StatefulCheckpointableHandler for our
# wrapper (_PartialSavePyTree) to carry the partial save flag. Since
# name-based resolution takes priority, we override 'pytree' in a local
# registry.
current_reg = ctx.checkpointables_options.registry
local_reg = registration.local_registry(include_global_registry=False)
for handler, name in current_reg.get_all_entries():
if name != PYTREE_CHECKPOINTABLE_KEY:
local_reg.add(
handler,
checkpointable_name=name,
secondary_typestrs=current_reg.get_secondary_typestrs(handler),
)
local_reg.add(
StatefulCheckpointableHandler,
checkpointable_name=PYTREE_CHECKPOINTABLE_KEY,
return execution.save_checkpointables_impl(
partial_path_lib.add_partial_save_suffix(path),
{PYTREE_CHECKPOINTABLE_KEY: _PartialSavePyTree(pytree)},
overwrite=False,
custom_metadata=custom_metadata,
async_origin=True,
partial_save=True,
)

new_options = options_lib.CheckpointablesOptions(registry=local_reg)
with context_lib.Context(ctx, checkpointables_options=new_options):
return execution.save_checkpointables_impl(
partial_path_lib.add_partial_save_suffix(path),
{PYTREE_CHECKPOINTABLE_KEY: _PartialSavePyTree(pytree)},
overwrite=False,
custom_metadata=custom_metadata,
async_origin=True,
partial_save=True,
)


def finalize(path: path_types.PathLike) -> None:
"""Finalizes a partially-saved checkpoint, making it permanent and readable.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,18 @@ async def mock_finalize(self_handler, directory):
(tuple([]),),
(dict(),),
(list(),),
(None,),
(optax.EmptyState(),),
)
def test_empty_tree(self, tree):
with self.assertRaisesRegex(ValueError, 'empty'):
with self.assertRaisesRegex(ValueError, 'Found empty item'):
ocp.save_pytree(self.directory, tree)

def test_none_tree(self):
with self.assertRaisesRegex(
ValueError, 'checkpointable must not be None for saving'
):
ocp.save_pytree(self.directory, None)

# Note the ommission of jax.Array, since this is covered in
# several other tests.
@parameterized.parameters(
Expand Down
Loading