Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 0 additions & 18 deletions pathwaysutils/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"""

import functools
from typing import Any

import jax

Expand Down Expand Up @@ -47,22 +46,6 @@ def __call__(self, *args, **kwargs):
raise ImportError(self.error_message)


try:
# jax>=0.7.0
from jax.extend import backend # pylint: disable=g-import-not-at-top

register_backend_cache = backend.register_backend_cache

del backend
except AttributeError:
# jax<0.7.0
from jax._src import util # pylint: disable=g-import-not-at-top

def register_backend_cache(cache: Any, name: str, util=util): # pylint: disable=unused-argument
return util.cache_clearing_funs.add(cache.cache_clear)

del util

try:
# jax>=0.7.1
from jax.extend import backend # pylint: disable=g-import-not-at-top
Expand Down Expand Up @@ -130,6 +113,5 @@ def ifrt_reshard_available() -> bool:


del jax
del Any
del _FakeJaxFunction
del functools
4 changes: 2 additions & 2 deletions pathwaysutils/lru_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import functools
from typing import Any, Callable

from pathwaysutils import jax as pw_jax
from jax.extend import backend


def lru_cache(
Expand All @@ -38,7 +38,7 @@ def wrap(f):

wrapper.cache_clear = cached.cache_clear
wrapper.cache_info = cached.cache_info
pw_jax.register_backend_cache(wrapper, "Pathways LRU cache")
backend.register_backend_cache(wrapper, "Pathways LRU cache")
return wrapper

return wrap
Loading