55from typing import Any , Dict , List , Optional , Type , Union
66
77from nest .common .exceptions import CircularDependencyException
8+ from nest .common .interfaces import (
9+ BeforeApplicationShutdown ,
10+ OnApplicationBootstrap ,
11+ OnApplicationShutdown ,
12+ OnModuleDestroy ,
13+ OnModuleInit ,
14+ )
815from nest .common .module import CompiledModule , ModuleCompiler , ModuleTokenFactory
916from nest .common .provider import InjectionToken , ProviderDescriptor
1017from nest .core .dependency_graph import DependencyGraph
1118from nest .core .encapsulation import validate_module_encapsulation
1219from nest .core .injector_module import build_injector , _to_key
1320
21+ _LIFECYCLE_METHOD_NAMES = (
22+ "on_module_init" ,
23+ "on_application_bootstrap" ,
24+ "before_application_shutdown" ,
25+ "on_module_destroy" ,
26+ "on_application_shutdown" ,
27+ )
28+
1429
1530class ModuleRef :
1631 """Internal container representation of a registered module."""
@@ -37,6 +52,9 @@ def __init__(self) -> None:
3752 self ._modules : Dict [str , ModuleRef ] = {}
3853 self ._all_descriptors : List [ProviderDescriptor ] = []
3954 self ._controller_classes : List [Type ] = []
55+ self ._module_instances : Dict [str , Any ] = {}
56+ self ._lifecycle_initialized = False
57+ self ._lifecycle_shutdown = False
4058 self ._module_token_factory = ModuleTokenFactory ()
4159 self ._module_compiler = ModuleCompiler (self ._module_token_factory )
4260
@@ -105,20 +123,80 @@ def clear(self) -> None:
105123 self ._modules .clear ()
106124 self ._all_descriptors .clear ()
107125 self ._controller_classes .clear ()
126+ self ._module_instances .clear ()
127+ self ._lifecycle_initialized = False
128+ self ._lifecycle_shutdown = False
129+
130+ async def initialize_lifecycle (self ) -> None :
131+ """Run module init and application bootstrap hooks once."""
132+ if self ._injector is None :
133+ raise RuntimeError (
134+ "Container not built. Call container.build() before lifecycle hooks."
135+ )
136+ if self ._lifecycle_initialized :
137+ return
138+
139+ for module_ref in self ._modules .values ():
140+ await self ._call_hooks (
141+ self ._get_module_lifecycle_instances (module_ref ),
142+ OnModuleInit ,
143+ "on_module_init" ,
144+ )
145+
146+ await self ._call_hooks (
147+ self ._get_all_lifecycle_instances (),
148+ OnApplicationBootstrap ,
149+ "on_application_bootstrap" ,
150+ )
151+ self ._lifecycle_initialized = True
152+
153+ async def shutdown_lifecycle (self , signal : Optional [str ] = None ) -> None :
154+ """Run application shutdown hooks once in graceful shutdown order."""
155+ if self ._injector is None :
156+ raise RuntimeError (
157+ "Container not built. Call container.build() before lifecycle hooks."
158+ )
159+ if self ._lifecycle_shutdown :
160+ return
161+
162+ modules = list (self ._modules .values ())
163+ for module_ref in reversed (modules ):
164+ await self ._call_hooks (
165+ self ._get_module_lifecycle_instances (module_ref ),
166+ BeforeApplicationShutdown ,
167+ "before_application_shutdown" ,
168+ signal ,
169+ )
170+
171+ for module_ref in reversed (modules ):
172+ await self ._call_hooks (
173+ self ._get_module_lifecycle_instances (module_ref ),
174+ OnModuleDestroy ,
175+ "on_module_destroy" ,
176+ )
177+
178+ for module_ref in reversed (modules ):
179+ await self ._call_hooks (
180+ self ._get_module_lifecycle_instances (module_ref ),
181+ OnApplicationShutdown ,
182+ "on_application_shutdown" ,
183+ signal ,
184+ )
185+
186+ self ._lifecycle_shutdown = True
108187
109188 # ── Internal ───────────────────────────────────────────────────────────────
110189
111190 def _make_controller_descriptors (self ) -> List [ProviderDescriptor ]:
112191 from nest .common .provider import Scope
192+
113193 return [
114194 ProviderDescriptor (provide = cls , use_class = cls , scope = Scope .SINGLETON )
115195 for cls in self ._controller_classes
116196 ]
117197
118198 def _validate_dependency_graph (self ) -> None :
119199 """Build a DAG from all class providers and raise CircularDependencyException on cycles."""
120- import sys
121-
122200 graph = DependencyGraph ()
123201
124202 # Build a name→class lookup from all registered providers so forward refs can be resolved
@@ -162,9 +240,90 @@ def _validate_dependency_graph(self) -> None:
162240
163241 cycles = graph .detect_cycles ()
164242 if cycles :
165- chain = " → " .join (
166- getattr (n , "__name__" , repr (n )) for n in cycles [0 ]
167- )
168- raise CircularDependencyException (
169- f"Circular dependency detected: { chain } "
170- )
243+ chain = " → " .join (getattr (n , "__name__" , repr (n )) for n in cycles [0 ])
244+ raise CircularDependencyException (f"Circular dependency detected: { chain } " )
245+
246+ def _get_all_lifecycle_instances (self ) -> List [Any ]:
247+ instances : List [Any ] = []
248+ seen : set [int ] = set ()
249+ for module_ref in self ._modules .values ():
250+ for instance in self ._get_module_lifecycle_instances (module_ref ):
251+ instance_id = id (instance )
252+ if instance_id in seen :
253+ continue
254+ seen .add (instance_id )
255+ instances .append (instance )
256+ return instances
257+
258+ def _get_module_lifecycle_instances (self , module_ref : ModuleRef ) -> List [Any ]:
259+ instances : List [Any ] = []
260+ seen : set [int ] = set ()
261+
262+ for desc in module_ref .compiled .provider_descriptors :
263+ instance = self .get (desc .provide )
264+ instance_id = id (instance )
265+ if instance_id in seen :
266+ continue
267+ seen .add (instance_id )
268+ instances .append (instance )
269+
270+ module_instance = self ._get_module_instance (module_ref )
271+ if module_instance is not None and id (module_instance ) not in seen :
272+ instances .append (module_instance )
273+
274+ return instances
275+
276+ def _get_module_instance (self , module_ref : ModuleRef ) -> Optional [Any ]:
277+ if module_ref .token in self ._module_instances :
278+ return self ._module_instances [module_ref .token ]
279+
280+ if not any (
281+ callable (getattr (module_ref .metatype , name , None ))
282+ for name in _LIFECYCLE_METHOD_NAMES
283+ ):
284+ return None
285+
286+ instance = self ._instantiate_module (module_ref .metatype )
287+ self ._module_instances [module_ref .token ] = instance
288+ return instance
289+
290+ def _instantiate_module (self , module_class : Type ) -> Any :
291+ try :
292+ signature = inspect .signature (module_class .__init__ )
293+ except (TypeError , ValueError ):
294+ return module_class ()
295+
296+ kwargs = {}
297+ for param in list (signature .parameters .values ())[1 :]:
298+ if param .kind in (
299+ inspect .Parameter .VAR_POSITIONAL ,
300+ inspect .Parameter .VAR_KEYWORD ,
301+ ):
302+ continue
303+ if param .annotation is not inspect .Parameter .empty :
304+ kwargs [param .name ] = self .get (param .annotation )
305+ elif param .default is inspect .Parameter .empty :
306+ raise RuntimeError (
307+ f"Cannot instantiate module { module_class .__name__ } : "
308+ f"constructor parameter { param .name !r} has no type annotation"
309+ )
310+
311+ return module_class (** kwargs )
312+
313+ async def _call_hooks (
314+ self , instances : List [Any ], protocol : Type , method_name : str , * args : Any
315+ ) -> None :
316+ calls = [
317+ self ._call_hook (instance , method_name , * args )
318+ for instance in instances
319+ if isinstance (instance , protocol )
320+ ]
321+ if calls :
322+ import asyncio
323+
324+ await asyncio .gather (* calls )
325+
326+ async def _call_hook (self , instance : Any , method_name : str , * args : Any ) -> None :
327+ result = getattr (instance , method_name )(* args )
328+ if inspect .isawaitable (result ):
329+ await result
0 commit comments