Skip to content

Commit 74bcdec

Browse files
committed
feat: add lifecycle hooks
1 parent 8524deb commit 74bcdec

7 files changed

Lines changed: 551 additions & 9 deletions

File tree

nest/common/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,10 @@
2424
Res,
2525
createParamDecorator,
2626
)
27+
from nest.common.interfaces import (
28+
BeforeApplicationShutdown,
29+
OnApplicationBootstrap,
30+
OnApplicationShutdown,
31+
OnModuleDestroy,
32+
OnModuleInit,
33+
)

nest/common/interfaces.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Optional, Protocol, runtime_checkable
4+
5+
6+
@runtime_checkable
7+
class OnModuleInit(Protocol):
8+
def on_module_init(self) -> Any: ...
9+
10+
11+
@runtime_checkable
12+
class OnApplicationBootstrap(Protocol):
13+
def on_application_bootstrap(self) -> Any: ...
14+
15+
16+
@runtime_checkable
17+
class BeforeApplicationShutdown(Protocol):
18+
def before_application_shutdown(self, signal: Optional[str]) -> Any: ...
19+
20+
21+
@runtime_checkable
22+
class OnModuleDestroy(Protocol):
23+
def on_module_destroy(self) -> Any: ...
24+
25+
26+
@runtime_checkable
27+
class OnApplicationShutdown(Protocol):
28+
def on_application_shutdown(self, signal: Optional[str]) -> Any: ...

nest/core/pynest_application.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import inspect
4-
from typing import Any
5+
import signal as signal_module
6+
from contextlib import asynccontextmanager
7+
from typing import Any, Iterable, Optional
58

69
from fastapi import FastAPI, Request
710
from fastapi.responses import JSONResponse
@@ -18,6 +21,9 @@ class PyNestApp:
1821
def __init__(self, container: PyNestContainer, http_server: FastAPI) -> None:
1922
self.container = container
2023
self.http_server = http_server
24+
self._closed = False
25+
self._closing = False
26+
self._install_lifespan_shutdown()
2127
routes_resolver = RoutesResolver(self.container, self.http_server)
2228
routes_resolver.register_routes()
2329

@@ -33,6 +39,31 @@ def use(self, middleware: type, **options: Any) -> "PyNestApp":
3339
self.http_server.add_middleware(middleware, **options)
3440
return self
3541

42+
def enable_shutdown_hooks(
43+
self, signals: Optional[Iterable[signal_module.Signals]] = None
44+
) -> "PyNestApp":
45+
"""Register process signal handlers that trigger graceful shutdown."""
46+
shutdown_signals = tuple(
47+
signals or (signal_module.SIGTERM, signal_module.SIGINT)
48+
)
49+
for shutdown_signal in shutdown_signals:
50+
signal_module.signal(
51+
shutdown_signal, self._make_signal_handler(shutdown_signal)
52+
)
53+
return self
54+
55+
async def close(self, signal: Optional[str] = None) -> None:
56+
"""Run graceful application shutdown lifecycle hooks once."""
57+
if self._closed or self._closing:
58+
return
59+
60+
self._closing = True
61+
try:
62+
await self.container.shutdown_lifecycle(signal)
63+
self._closed = True
64+
finally:
65+
self._closing = False
66+
3667
def use_global_filters(self, *filters) -> "PyNestApp":
3768
"""Register one or more exception filters that apply to every route.
3869
@@ -73,3 +104,38 @@ async def handler(request: Request, exc: Exception):
73104
return result
74105

75106
self.http_server.add_exception_handler(exc_type, handler)
107+
108+
def _make_signal_handler(self, shutdown_signal: signal_module.Signals):
109+
def handler(signum, frame):
110+
self._close_from_signal(self._signal_name(signum or shutdown_signal))
111+
112+
return handler
113+
114+
def _close_from_signal(self, signal_name: str) -> None:
115+
try:
116+
loop = asyncio.get_running_loop()
117+
except RuntimeError:
118+
asyncio.run(self.close(signal_name))
119+
return
120+
121+
loop.create_task(self.close(signal_name))
122+
123+
@staticmethod
124+
def _signal_name(signum) -> str:
125+
try:
126+
return signal_module.Signals(signum).name
127+
except ValueError:
128+
return str(signum)
129+
130+
def _install_lifespan_shutdown(self) -> None:
131+
original_lifespan_context = self.http_server.router.lifespan_context
132+
133+
@asynccontextmanager
134+
async def lifespan_context(app: FastAPI):
135+
async with original_lifespan_context(app) as state:
136+
try:
137+
yield state
138+
finally:
139+
await self.close()
140+
141+
self.http_server.router.lifespan_context = lifespan_context

nest/core/pynest_container.py

Lines changed: 167 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,27 @@
55
from typing import Any, Dict, List, Optional, Type, Union
66

77
from nest.common.exceptions import CircularDependencyException
8+
from nest.common.interfaces import (
9+
BeforeApplicationShutdown,
10+
OnApplicationBootstrap,
11+
OnApplicationShutdown,
12+
OnModuleDestroy,
13+
OnModuleInit,
14+
)
815
from nest.common.module import CompiledModule, ModuleCompiler, ModuleTokenFactory
916
from nest.common.provider import InjectionToken, ProviderDescriptor
1017
from nest.core.dependency_graph import DependencyGraph
1118
from nest.core.encapsulation import validate_module_encapsulation
1219
from nest.core.injector_module import build_injector, _to_key
1320

21+
_LIFECYCLE_METHOD_NAMES = (
22+
"on_module_init",
23+
"on_application_bootstrap",
24+
"before_application_shutdown",
25+
"on_module_destroy",
26+
"on_application_shutdown",
27+
)
28+
1429

1530
class ModuleRef:
1631
"""Internal container representation of a registered module."""
@@ -37,6 +52,9 @@ def __init__(self) -> None:
3752
self._modules: Dict[str, ModuleRef] = {}
3853
self._all_descriptors: List[ProviderDescriptor] = []
3954
self._controller_classes: List[Type] = []
55+
self._module_instances: Dict[str, Any] = {}
56+
self._lifecycle_initialized = False
57+
self._lifecycle_shutdown = False
4058
self._module_token_factory = ModuleTokenFactory()
4159
self._module_compiler = ModuleCompiler(self._module_token_factory)
4260

@@ -105,20 +123,80 @@ def clear(self) -> None:
105123
self._modules.clear()
106124
self._all_descriptors.clear()
107125
self._controller_classes.clear()
126+
self._module_instances.clear()
127+
self._lifecycle_initialized = False
128+
self._lifecycle_shutdown = False
129+
130+
async def initialize_lifecycle(self) -> None:
131+
"""Run module init and application bootstrap hooks once."""
132+
if self._injector is None:
133+
raise RuntimeError(
134+
"Container not built. Call container.build() before lifecycle hooks."
135+
)
136+
if self._lifecycle_initialized:
137+
return
138+
139+
for module_ref in self._modules.values():
140+
await self._call_hooks(
141+
self._get_module_lifecycle_instances(module_ref),
142+
OnModuleInit,
143+
"on_module_init",
144+
)
145+
146+
await self._call_hooks(
147+
self._get_all_lifecycle_instances(),
148+
OnApplicationBootstrap,
149+
"on_application_bootstrap",
150+
)
151+
self._lifecycle_initialized = True
152+
153+
async def shutdown_lifecycle(self, signal: Optional[str] = None) -> None:
154+
"""Run application shutdown hooks once in graceful shutdown order."""
155+
if self._injector is None:
156+
raise RuntimeError(
157+
"Container not built. Call container.build() before lifecycle hooks."
158+
)
159+
if self._lifecycle_shutdown:
160+
return
161+
162+
modules = list(self._modules.values())
163+
for module_ref in reversed(modules):
164+
await self._call_hooks(
165+
self._get_module_lifecycle_instances(module_ref),
166+
BeforeApplicationShutdown,
167+
"before_application_shutdown",
168+
signal,
169+
)
170+
171+
for module_ref in reversed(modules):
172+
await self._call_hooks(
173+
self._get_module_lifecycle_instances(module_ref),
174+
OnModuleDestroy,
175+
"on_module_destroy",
176+
)
177+
178+
for module_ref in reversed(modules):
179+
await self._call_hooks(
180+
self._get_module_lifecycle_instances(module_ref),
181+
OnApplicationShutdown,
182+
"on_application_shutdown",
183+
signal,
184+
)
185+
186+
self._lifecycle_shutdown = True
108187

109188
# ── Internal ───────────────────────────────────────────────────────────────
110189

111190
def _make_controller_descriptors(self) -> List[ProviderDescriptor]:
112191
from nest.common.provider import Scope
192+
113193
return [
114194
ProviderDescriptor(provide=cls, use_class=cls, scope=Scope.SINGLETON)
115195
for cls in self._controller_classes
116196
]
117197

118198
def _validate_dependency_graph(self) -> None:
119199
"""Build a DAG from all class providers and raise CircularDependencyException on cycles."""
120-
import sys
121-
122200
graph = DependencyGraph()
123201

124202
# Build a name→class lookup from all registered providers so forward refs can be resolved
@@ -162,9 +240,90 @@ def _validate_dependency_graph(self) -> None:
162240

163241
cycles = graph.detect_cycles()
164242
if cycles:
165-
chain = " → ".join(
166-
getattr(n, "__name__", repr(n)) for n in cycles[0]
167-
)
168-
raise CircularDependencyException(
169-
f"Circular dependency detected: {chain}"
170-
)
243+
chain = " → ".join(getattr(n, "__name__", repr(n)) for n in cycles[0])
244+
raise CircularDependencyException(f"Circular dependency detected: {chain}")
245+
246+
def _get_all_lifecycle_instances(self) -> List[Any]:
247+
instances: List[Any] = []
248+
seen: set[int] = set()
249+
for module_ref in self._modules.values():
250+
for instance in self._get_module_lifecycle_instances(module_ref):
251+
instance_id = id(instance)
252+
if instance_id in seen:
253+
continue
254+
seen.add(instance_id)
255+
instances.append(instance)
256+
return instances
257+
258+
def _get_module_lifecycle_instances(self, module_ref: ModuleRef) -> List[Any]:
259+
instances: List[Any] = []
260+
seen: set[int] = set()
261+
262+
for desc in module_ref.compiled.provider_descriptors:
263+
instance = self.get(desc.provide)
264+
instance_id = id(instance)
265+
if instance_id in seen:
266+
continue
267+
seen.add(instance_id)
268+
instances.append(instance)
269+
270+
module_instance = self._get_module_instance(module_ref)
271+
if module_instance is not None and id(module_instance) not in seen:
272+
instances.append(module_instance)
273+
274+
return instances
275+
276+
def _get_module_instance(self, module_ref: ModuleRef) -> Optional[Any]:
277+
if module_ref.token in self._module_instances:
278+
return self._module_instances[module_ref.token]
279+
280+
if not any(
281+
callable(getattr(module_ref.metatype, name, None))
282+
for name in _LIFECYCLE_METHOD_NAMES
283+
):
284+
return None
285+
286+
instance = self._instantiate_module(module_ref.metatype)
287+
self._module_instances[module_ref.token] = instance
288+
return instance
289+
290+
def _instantiate_module(self, module_class: Type) -> Any:
291+
try:
292+
signature = inspect.signature(module_class.__init__)
293+
except (TypeError, ValueError):
294+
return module_class()
295+
296+
kwargs = {}
297+
for param in list(signature.parameters.values())[1:]:
298+
if param.kind in (
299+
inspect.Parameter.VAR_POSITIONAL,
300+
inspect.Parameter.VAR_KEYWORD,
301+
):
302+
continue
303+
if param.annotation is not inspect.Parameter.empty:
304+
kwargs[param.name] = self.get(param.annotation)
305+
elif param.default is inspect.Parameter.empty:
306+
raise RuntimeError(
307+
f"Cannot instantiate module {module_class.__name__}: "
308+
f"constructor parameter {param.name!r} has no type annotation"
309+
)
310+
311+
return module_class(**kwargs)
312+
313+
async def _call_hooks(
314+
self, instances: List[Any], protocol: Type, method_name: str, *args: Any
315+
) -> None:
316+
calls = [
317+
self._call_hook(instance, method_name, *args)
318+
for instance in instances
319+
if isinstance(instance, protocol)
320+
]
321+
if calls:
322+
import asyncio
323+
324+
await asyncio.gather(*calls)
325+
326+
async def _call_hook(self, instance: Any, method_name: str, *args: Any) -> None:
327+
result = getattr(instance, method_name)(*args)
328+
if inspect.isawaitable(result):
329+
await result

nest/core/pynest_factory.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
import threading
35
from abc import ABC, abstractmethod
46
from typing import Type, TypeVar
57

@@ -34,10 +36,33 @@ def create(main_module: Type[ModuleType], **kwargs) -> PyNestApp:
3436
container = PyNestContainer()
3537
container.add_module(main_module)
3638
container.build()
39+
PyNestFactory._run_async(container.initialize_lifecycle())
3740

3841
http_server = FastAPI(**kwargs)
3942
return PyNestApp(container, http_server)
4043

4144
@staticmethod
4245
def _create_server(**kwargs) -> FastAPI:
4346
return FastAPI(**kwargs)
47+
48+
@staticmethod
49+
def _run_async(coro):
50+
try:
51+
asyncio.get_running_loop()
52+
except RuntimeError:
53+
return asyncio.run(coro)
54+
55+
result = {}
56+
57+
def runner():
58+
try:
59+
result["value"] = asyncio.run(coro)
60+
except BaseException as exc:
61+
result["error"] = exc
62+
63+
thread = threading.Thread(target=runner)
64+
thread.start()
65+
thread.join()
66+
if "error" in result:
67+
raise result["error"]
68+
return result.get("value")

0 commit comments

Comments
 (0)