diff --git a/.github/workflows/cli_test.yaml b/.github/workflows/cli_test.yaml index f3ca7bd..d03d81d 100644 --- a/.github/workflows/cli_test.yaml +++ b/.github/workflows/cli_test.yaml @@ -60,6 +60,7 @@ jobs: cd "$app_name" uv run pynest generate resource -n user + uv run pynest generate gateway -n chat -p src - name: Verify Boilerplate run: | @@ -109,4 +110,80 @@ jobs: fi done + gateway_file="$app_name/src/chat_gateway.py" + if [ -f "$gateway_file" ]; then + echo "$gateway_file exists." + else + echo "$gateway_file does not exist." + exit 1 + fi + if ! grep -q '@WebSocketGateway(namespace="/chat")' "$gateway_file"; then + echo "$gateway_file is missing the expected @WebSocketGateway decorator." + exit 1 + fi + echo "Boilerplate for ${{ matrix.app_type }} generated successfully." + + - name: Ping generated WebSocket app + if: matrix.app_type == 'Blank' + run: | + cd "${{ matrix.app_type }}App" + uv run python - <<'PY' + import asyncio + import importlib.util + import json + import socket + + import uvicorn + import websockets + + from nest.core import Module, PyNestContainer, PyNestFactory + + spec = importlib.util.spec_from_file_location( + "chat_gateway", "src/chat_gateway.py" + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + ChatGateway = mod.ChatGateway + + @Module(providers=[ChatGateway]) + class App: + pass + + def free_port(): + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname()[1] + + async def main(): + PyNestContainer._instance = None + app = PyNestFactory.create(App).get_server() + port = free_port() + config = uvicorn.Config( + app, host="127.0.0.1", port=port, + log_level="critical", lifespan="off", + ) + server = uvicorn.Server(config) + task = asyncio.create_task(server.serve()) + for _ in range(200): + if server.started: + break + await asyncio.sleep(0.05) + try: + async with websockets.connect( + f"ws://127.0.0.1:{port}/chat" + ) as ws: + await ws.send(json.dumps( + {"event": "ping", "data": {"hello": "world"}} + )) + response = json.loads(await ws.recv()) + assert response == { + "event": "pong", "data": {"hello": "world"}, + }, response + print("WEBSOCKET_PING_OK") + finally: + server.should_exit = True + await task + + asyncio.run(main()) + PY diff --git a/README.md b/README.md index 4bb2c92..e933201 100644 --- a/README.md +++ b/README.md @@ -147,6 +147,11 @@ Each module contains a collection of related controllers, services, and provider PyNest supports dependency injection, which makes it easy to manage dependencies and write testable code. You can easily inject services and providers into your controllers using decorators. +### WebSocket Gateways + +PyNest supports native FastAPI WebSocket gateways for real-time APIs. Gateways are registered as providers, participate +in dependency injection, and can use event handlers, lifecycle hooks, guards, rooms, and token streaming patterns. + ### Decorators PyNest makes extensive use of decorators to define routes, middleware, and other application components. This helps keep diff --git a/docs/cli.md b/docs/cli.md index 0b3adce..c435c92 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -138,6 +138,26 @@ pynest generate service --name users This will create a new service named `users` in the default path. +**Gateway** + +Generate a new WebSocket gateway file. + +```bash +pynest generate gateway --name +``` + +**Options** + +* `--name`, `-n`: The name of the new gateway. (Required) +* `--path`, `-p`: The path where the gateway should be created. (Optional) + +**Example** +```bash +pynest generate gateway --name chat +``` + +This creates `chat_gateway.py` with a starter `@WebSocketGateway(namespace="/chat")` class and a `ping` message handler. Add the generated gateway to a module's `providers` list to mount it. + ## Best Practices 🌟 @@ -163,4 +183,4 @@ The PyNest CLI is a powerful tool that simplifies the development of PyNest appl Modules → - \ No newline at end of file + diff --git a/docs/getting_started.md b/docs/getting_started.md index 83fbce6..b32278b 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -143,6 +143,15 @@ python main.py You should see the Uvicorn server starting, and you can access your API at . +## Next Steps + +After the first HTTP endpoint is running, you can add more framework features: + +* [Modules](modules.md) for organizing application boundaries. +* [Providers](providers.md) for injectable business logic. +* [Guards](guards.md) for authorization. +* [WebSocket Gateways](websockets.md) for real-time event APIs. + --- \ No newline at end of file + diff --git a/docs/guards.md b/docs/guards.md index 0317356..2ae3dac 100644 --- a/docs/guards.md +++ b/docs/guards.md @@ -243,6 +243,31 @@ class AdminController: In this example `AdminGuard` protects all routes while `PublicOnlyGuard` is applied only to the `login` route. +## WebSocket Guards + +`@UseGuards` also works on WebSocket gateways and individual `@SubscribeMessage` handlers. WebSocket guards receive an execution context instead of a FastAPI `Request`. + +```python +from nest.core import BaseGuard, UseGuards +from nest.websockets import SubscribeMessage, WebSocketGateway + + +class WsTokenGuard(BaseGuard): + async def can_activate(self, context): + ws = context.switch_to_ws() + return ws.get_client().headers.get("x-token") == "secret" + + +@WebSocketGateway(namespace="/private") +@UseGuards(WsTokenGuard) +class PrivateGateway: + @SubscribeMessage("secret") + async def secret(self): + return {"event": "secret_ack", "data": {}} +``` + +Use `context.switch_to_ws().get_client()` for the active socket, `get_data()` for the message body, `get_event()` for the event name, and `get_server()` for the gateway server. + ## Combining Multiple Guards `UseGuards` accepts any number of guard classes. All specified guards must return `True` in order for the request to proceed. @@ -562,4 +587,3 @@ class EnterpriseController: | Multi-Auth | Any | Flexible authentication | ✅ | PyNest guards provide a powerful, flexible, and standards-compliant way to secure your APIs while maintaining excellent developer experience and automatic documentation generation. - diff --git a/docs/introduction.md b/docs/introduction.md index 96392b3..4618093 100644 --- a/docs/introduction.md +++ b/docs/introduction.md @@ -16,6 +16,10 @@ PyNest follows the modular architecture of NestJS, which allows for easy separat PyNest supports dependency injection, which makes it easy to manage dependencies and write testable code. You can easily inject services and providers into your controllers using decorators. +### WebSocket Gateways + +PyNest supports native FastAPI WebSocket gateways for real-time APIs. Gateways are registered as providers, participate in dependency injection, and can use lifecycle hooks, guards, rooms, and event handlers. See [WebSocket Gateways](websockets.md) for the full guide. + ### Decorators 🏷️ PyNest makes extensive use of decorators to define routes, middleware, and other application components. This helps keep the code concise and easy to read. diff --git a/docs/websockets.md b/docs/websockets.md new file mode 100644 index 0000000..482432b --- /dev/null +++ b/docs/websockets.md @@ -0,0 +1,641 @@ +# WebSocket Gateways + +PyNest WebSocket gateways provide first-class real-time communication through FastAPI's native WebSocket support. A gateway is a module provider that can receive JSON event messages, call injected services, send responses, broadcast to connected clients, and use guards for authorization. + +Use gateways when clients need persistent two-way communication, for example: + +* chat and collaborative editing +* live dashboards +* background job progress +* AI token streaming +* agent session updates +* server-side event fan-out to rooms or individual clients + +## What PyNest Adds + +FastAPI already supports raw WebSocket routes. PyNest gateways add the framework pieces you normally use for HTTP controllers: + +* gateway classes registered in `@Module(providers=[...])` +* constructor dependency injection +* method-level event handlers with `@SubscribeMessage` +* handler parameter helpers with `MessageBody()` and `ConnectedSocket()` +* lifecycle hooks for setup, connection, and disconnection +* `WebSocketServer` helpers for broadcast, rooms, and direct client messages +* `@UseGuards` support with `ExecutionContext.switch_to_ws()` +* CLI scaffolding through `pynest generate gateway` + +The default transport is native WebSocket on the same FastAPI app and port as the rest of your PyNest application. No runtime WebSocket package is required beyond PyNest's FastAPI/Starlette dependencies. + +## Quick Start + +Create a service, gateway, and module: + +```python +from nest.core import Injectable, Module +from nest.websockets import MessageBody, SubscribeMessage, WebSocketGateway + + +@Injectable +class ChatService: + def acknowledge(self, text: str) -> dict: + return {"text": text, "status": "delivered"} + + +@WebSocketGateway(namespace="/chat") +class ChatGateway: + def __init__(self, chat_service: ChatService): + self.chat_service = chat_service + + @SubscribeMessage("send_message") + async def handle_message(self, data=MessageBody()): + return { + "event": "message_ack", + "data": self.chat_service.acknowledge(data["text"]), + } + + +@Module(providers=[ChatService, ChatGateway]) +class ChatModule: + pass +``` + +Create the application as usual: + +```python +from nest.core import PyNestFactory + +app = PyNestFactory.create(ChatModule) +http_server = app.get_server() +``` + +Connect to `ws://localhost:8000/chat` and send: + +```json +{"event": "send_message", "data": {"text": "hello"}} +``` + +The gateway responds: + +```json +{"event": "message_ack", "data": {"text": "hello", "status": "delivered"}} +``` + +## Registering Gateways + +Gateways are registered as providers: + +```python +@Module( + controllers=[], + providers=[ChatService, ChatGateway], +) +class ChatModule: + pass +``` + +`@WebSocketGateway` marks the class as injectable, so the gateway itself does not also need `@Injectable`. Dependencies injected into the gateway constructor still need to be normal PyNest providers. + +```python +@Injectable +class NotificationsService: + ... + + +@WebSocketGateway(namespace="/notifications") +class NotificationsGateway: + def __init__(self, notifications_service: NotificationsService): + self.notifications_service = notifications_service +``` + +## Gateway Decorator + +Use `@WebSocketGateway` on a class: + +```python +@WebSocketGateway(namespace="/events") +class EventsGateway: + ... +``` + +Arguments: + +| Argument | Description | +| --- | --- | +| `namespace` | WebSocket path mounted on the FastAPI app. Values are normalized with a leading slash, so `"chat"` becomes `"/chat"`. | +| `port` | Accepted for API compatibility. Native FastAPI gateways run on the same port as the PyNest application. | +| `options` | Accepted as metadata for future adapters and advanced configuration. | + +If no namespace is provided, PyNest uses `/ws`. + +## Message Protocol + +The native gateway expects each client message to be a JSON object with an `event` key: + +```json +{"event": "event_name", "data": {}} +``` + +`event` selects the `@SubscribeMessage` handler. `data` is the payload passed to `MessageBody()`. + +Valid message: + +```json +{"event": "join_room", "data": {"room": "support"}} +``` + +Invalid message: + +```json +{"data": {"room": "support"}} +``` + +Invalid messages receive an error frame: + +```json +{"event": "error", "data": {"message": "WebSocket message is missing an event"}} +``` + +## Subscribing to Events + +Use `@SubscribeMessage(event)` on gateway methods: + +```python +from nest.websockets import SubscribeMessage + + +@SubscribeMessage("ping") +async def ping(self): + return {"event": "pong", "data": {}} +``` + +Handlers may be sync or async. Async handlers are recommended for I/O-heavy work. + +## Handler Parameters + +Python does not support NestJS/TypeScript-style parameter decorators. PyNest uses default-value markers. + +### MessageBody + +Inject the entire incoming `data` payload: + +```python +@SubscribeMessage("send_message") +async def send_message(self, data=MessageBody()): + return {"event": "received", "data": data} +``` + +Inject a specific key from the payload: + +```python +@SubscribeMessage("join_room") +async def join_room(self, room=MessageBody("room")): + return {"event": "joined", "data": {"room": room}} +``` + +### Pydantic Payloads + +Annotate a `MessageBody()` parameter with a Pydantic model to validate and convert incoming payloads: + +```python +from pydantic import BaseModel +from nest.websockets import MessageBody, SubscribeMessage + + +class SendMessageDto(BaseModel): + room: str + text: str + + +@SubscribeMessage("send_message") +async def send_message(self, data: SendMessageDto = MessageBody()): + return { + "event": "message_ack", + "data": {"room": data.room, "text": data.text}, + } +``` + +### ConnectedSocket + +Inject the active FastAPI `WebSocket`: + +```python +from nest.websockets import ConnectedSocket + + +@SubscribeMessage("whoami") +async def whoami(self, client=ConnectedSocket()): + return { + "event": "client", + "data": {"host": client.client.host}, + } +``` + +## Handler Responses + +If a handler returns `None`, PyNest sends no automatic response. + +If a handler returns a dictionary with `event` and `data`, PyNest sends it unchanged: + +```python +return {"event": "message_ack", "data": {"id": 1}} +``` + +If a handler returns any other value, PyNest wraps it with the incoming event name: + +```python +@SubscribeMessage("count") +async def count(self): + return 3 +``` + +Response: + +```json +{"event": "count", "data": 3} +``` + +Manual sends can be mixed with automatic responses: + +```python +@SubscribeMessage("notify") +async def notify(self, client=ConnectedSocket()): + await client.send_json({"event": "step", "data": {"status": "started"}}) + return {"event": "step", "data": {"status": "done"}} +``` + +## Lifecycle Hooks + +Gateways can implement lifecycle hook interfaces: + +```python +from nest.websockets import ( + OnGatewayConnection, + OnGatewayDisconnect, + OnGatewayInit, + WebSocketGateway, +) + + +@WebSocketGateway(namespace="/events") +class EventsGateway( + OnGatewayInit, + OnGatewayConnection, + OnGatewayDisconnect, +): + async def after_init(self, server): + self.server = server + + async def on_connection(self, client): + await client.send_json({"event": "connected", "data": {}}) + + async def on_disconnect(self, client): + print("client disconnected") +``` + +Hook timing: + +| Hook | When it runs | +| --- | --- | +| `after_init(server)` | Before the first connection is handled. | +| `on_connection(client)` | After the socket is accepted and registered with the gateway server. | +| `on_disconnect(client)` | When the receive loop ends and before the socket is removed from the server registry. | + +PyNest also assigns `self.server` on the gateway instance before registration, so hooks and handlers can use it. + +## WebSocketServer + +`WebSocketServer` tracks connected clients and room membership for one gateway. + +```python +@WebSocketGateway(namespace="/events") +class EventsGateway: + @SubscribeMessage("join") + async def join(self, room=MessageBody("room"), client=ConnectedSocket()): + await self.server.join(client, room) + return {"event": "joined", "data": {"room": room}} + + async def publish_update(self, room: str, payload: dict): + await self.server.to(room).emit("update", payload) +``` + +Available APIs: + +| API | Description | +| --- | --- | +| `await server.emit(event, data)` | Send an event to all connected clients. | +| `await server.broadcast(event, data)` | Alias for `emit`. | +| `await server.join(client, room)` | Add a connected client to a room. | +| `await server.leave(client, room)` | Remove a connected client from a room. | +| `await server.to(room_or_client_id).emit(event, data)` | Send to a room or one client. | +| `server.get_client_id(client)` | Return the PyNest client id assigned to a connected socket. | + +### Room Example + +```python +@WebSocketGateway(namespace="/chat") +class ChatGateway: + @SubscribeMessage("join_room") + async def join_room(self, room=MessageBody("room"), client=ConnectedSocket()): + await self.server.join(client, room) + await self.server.to(room).emit("room_joined", {"room": room}) + + @SubscribeMessage("send_room_message") + async def send_room_message(self, data=MessageBody()): + await self.server.to(data["room"]).emit( + "room_message", + {"room": data["room"], "text": data["text"]}, + ) +``` + +### Direct Client Example + +```python +@SubscribeMessage("private_message") +async def private_message(self, data=MessageBody()): + await self.server.to(data["client_id"]).emit( + "private_message", + {"text": data["text"]}, + ) +``` + +## Guards + +`@UseGuards` works on gateway classes and individual message handlers. + +```python +from nest.core import BaseGuard, UseGuards +from nest.websockets import SubscribeMessage, WebSocketGateway + + +class WsTokenGuard(BaseGuard): + async def can_activate(self, context): + ws = context.switch_to_ws() + return ws.get_client().headers.get("x-token") == "secret" + + +@WebSocketGateway(namespace="/private") +@UseGuards(WsTokenGuard) +class PrivateGateway: + @SubscribeMessage("secret") + async def secret(self): + return {"event": "secret_ack", "data": {}} +``` + +Use handler-level guards for event-specific authorization: + +```python +class AdminEventGuard(BaseGuard): + async def can_activate(self, context): + ws = context.switch_to_ws() + data = ws.get_data() + return data.get("role") == "admin" + + +@WebSocketGateway(namespace="/admin") +class AdminGateway: + @SubscribeMessage("delete_message") + @UseGuards(AdminEventGuard) + async def delete_message(self, data=MessageBody()): + return {"event": "deleted", "data": {"id": data["id"]}} +``` + +`context.switch_to_ws()` returns a `WsArgumentsHost`: + +| Method | Returns | +| --- | --- | +| `get_client()` | Active FastAPI `WebSocket`. | +| `get_data()` | Message `data` payload. | +| `get_event()` | Message event name. | +| `get_server()` | Gateway `WebSocketServer`. | + +When a guard returns `False`, PyNest sends: + +```json +{"event": "error", "data": {"message": "Access denied: insufficient permissions"}} +``` + +Then it closes the socket with WebSocket close code `1008`. + +## Error Frames + +PyNest sends structured error frames for dispatcher-level errors: + +| Error | Frame | +| --- | --- | +| Non-object message | `{"event": "error", "data": {"message": "WebSocket message must be a JSON object"}}` | +| Missing `event` | `{"event": "error", "data": {"message": "WebSocket message is missing an event"}}` | +| Unknown event | `{"event": "error", "data": {"message": "No handler for WebSocket event ''"}}` | +| Guard denial | Error frame, then close code `1008`. | +| Unhandled handler exception | Error frame, then close code `1011`. | + +Application handlers can also send domain-specific errors directly: + +```python +@SubscribeMessage("join_room") +async def join_room(self, data=MessageBody(), client=ConnectedSocket()): + if "room" not in data: + await client.send_json({ + "event": "join_error", + "data": {"message": "room is required"}, + }) + return None +``` + +## Streaming + +For token streaming or progress updates, inject the socket and send frames manually: + +```python +@WebSocketGateway(namespace="/ai") +class AgentGateway: + def __init__(self, llm_service: LlmService): + self.llm_service = llm_service + + @SubscribeMessage("prompt") + async def handle_prompt(self, data=MessageBody(), client=ConnectedSocket()): + async for token in self.llm_service.stream(data["prompt"]): + await client.send_json({"event": "token", "data": token}) + + return {"event": "done", "data": {}} +``` + +Client conversation: + +```json +{"event": "prompt", "data": {"prompt": "Write a title"}} +``` + +Frames: + +```json +{"event": "token", "data": "Real"} +{"event": "token", "data": "-time"} +{"event": "done", "data": {}} +``` + +## CLI Generation + +Generate a gateway file: + +```bash +pynest generate gateway --name chat +``` + +or: + +```bash +pynest generate gateway -n chat +``` + +The command creates `chat_gateway.py` in the current directory. Use `--path` to choose a directory: + +```bash +pynest generate gateway -n chat --path src/chat +``` + +Generated file: + +```python +from nest.websockets import MessageBody, SubscribeMessage, WebSocketGateway + + +@WebSocketGateway(namespace="/chat") +class ChatGateway: + @SubscribeMessage("ping") + async def ping(self, data=MessageBody()): + return {"event": "pong", "data": data} +``` + +Register the generated gateway in the module where it belongs: + +```python +from nest.core import Module +from .chat_gateway import ChatGateway + + +@Module(providers=[ChatGateway]) +class ChatModule: + pass +``` + +## Testing Gateways + +The project test suite uses `uvicorn` and the `websockets` package to test a real WebSocket server. In your own application tests, the shape is: + +```python +import asyncio +import json +import uvicorn +import websockets + + +async def test_chat_gateway(http_server, port): + config = uvicorn.Config( + http_server, + host="127.0.0.1", + port=port, + log_level="critical", + lifespan="off", + ) + server = uvicorn.Server(config) + task = asyncio.create_task(server.serve()) + + try: + while not server.started: + await asyncio.sleep(0.01) + + async with websockets.connect(f"ws://127.0.0.1:{port}/chat") as socket: + await socket.send( + json.dumps({"event": "send_message", "data": {"text": "hello"}}) + ) + response = json.loads(await socket.recv()) + + assert response["event"] == "message_ack" + finally: + server.should_exit = True + await task +``` + +For unit tests, instantiate `NativeWebSocketGateway` with a gateway instance and fake socket object, then call `dispatch_message()`. + +## Complete Example + +```python +from nest.core import BaseGuard, Injectable, Module, PyNestFactory, UseGuards +from nest.websockets import ( + ConnectedSocket, + MessageBody, + OnGatewayConnection, + SubscribeMessage, + WebSocketGateway, +) + + +@Injectable +class ChatService: + def save(self, room: str, text: str) -> dict: + return {"room": room, "text": text, "status": "saved"} + + +class WsTokenGuard(BaseGuard): + async def can_activate(self, context): + ws = context.switch_to_ws() + return ws.get_client().headers.get("x-token") == "secret" + + +@WebSocketGateway(namespace="/chat") +@UseGuards(WsTokenGuard) +class ChatGateway(OnGatewayConnection): + def __init__(self, chat_service: ChatService): + self.chat_service = chat_service + + async def on_connection(self, client): + await client.send_json({"event": "connected", "data": {}}) + + @SubscribeMessage("join") + async def join(self, room=MessageBody("room"), client=ConnectedSocket()): + await self.server.join(client, room) + return {"event": "joined", "data": {"room": room}} + + @SubscribeMessage("message") + async def message(self, data=MessageBody()): + saved = self.chat_service.save(data["room"], data["text"]) + await self.server.to(data["room"]).emit("message", saved) + return {"event": "message_ack", "data": saved} + + +@Module(providers=[ChatService, ChatGateway]) +class ChatModule: + pass + + +app = PyNestFactory.create(ChatModule) +http_server = app.get_server() +``` + +Run the app with Uvicorn and connect to `ws://localhost:8000/chat`. + +## Current Limitations + +The first WebSocket gateway implementation intentionally focuses on the native FastAPI transport. These issue items are extension points for future work: + +* Socket.IO adapter support is not implemented yet. +* WebSocket-specific exception filters are not implemented because the project does not yet have shared exception-filter infrastructure. +* WebSocket interceptors are not implemented because the project does not yet have shared interceptor infrastructure. +* `port` and `options` are accepted by `@WebSocketGateway` for API compatibility, but native gateways are mounted on the existing FastAPI app and use the app's server port. + +## API Reference + +| Symbol | Purpose | +| --- | --- | +| `WebSocketGateway(namespace="/ws", port=None, options=None)` | Decorates a provider class as a WebSocket gateway. | +| `SubscribeMessage(event)` | Decorates a gateway method as a handler for one event name. | +| `MessageBody(key=None)` | Injects the incoming message `data`, or one key from it. | +| `ConnectedSocket()` | Injects the active FastAPI `WebSocket`. | +| `OnGatewayInit` | Optional interface for `after_init(server)`. | +| `OnGatewayConnection` | Optional interface for `on_connection(client)`. | +| `OnGatewayDisconnect` | Optional interface for `on_disconnect(client)`. | +| `WebSocketServer` | Tracks connected clients, rooms, broadcast, and direct emits. | +| `ExecutionContext` | Guard context for WebSocket events. | +| `WsArgumentsHost` | WebSocket-specific argument host returned by `context.switch_to_ws()`. | + diff --git a/mkdocs.yml b/mkdocs.yml index 72f0903..cd6a97d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -58,6 +58,7 @@ nav: - Providers: providers.md - Guards: guards.md - Exception Filters: exception_filters.md + - WebSockets: websockets.md - Dependency Injection: dependency_injection.md - Deployment: - Docker: docker.md @@ -66,4 +67,4 @@ nav: - Sync ORM Application: sync_orm.md - Async ORM Application: async_orm.md - MongoDB Application: mongodb.md - - License: license.md \ No newline at end of file + - License: license.md diff --git a/nest/cli/src/generate/generate_controller.py b/nest/cli/src/generate/generate_controller.py index 381dc84..0cd089d 100644 --- a/nest/cli/src/generate/generate_controller.py +++ b/nest/cli/src/generate/generate_controller.py @@ -27,6 +27,10 @@ def generate_controller(self, name: SharedOptions.NAME, path: SharedOptions.PATH def generate_service(self, name: SharedOptions.NAME, path: SharedOptions.PATH): self.generate_service.generate_service(name, path) + @CliCommand("gateway", help="Generate a new nest WebSocket gateway") + def generate_gateway(self, name: SharedOptions.NAME, path: SharedOptions.PATH): + self.generate_service.generate_gateway(name, path) + @CliCommand("module", help="Generate a new nest module") def generate_module(self, name: SharedOptions.NAME): self.generate_service.generate_module(name) diff --git a/nest/cli/src/generate/generate_service.py b/nest/cli/src/generate/generate_service.py index 5dd2377..6336a8a 100644 --- a/nest/cli/src/generate/generate_service.py +++ b/nest/cli/src/generate/generate_service.py @@ -97,6 +97,19 @@ def generate_service(self, name: str, path: str = None): with open(f"{path}/{name}_service.py", "w") as f: f.write(template.generate_empty_service_file()) + def generate_gateway(self, name: str, path: str = None): + """ + Create a new nest WebSocket gateway. + + :param name: The name of the gateway + :param path: The path where the gateway file will be created + """ + template = self.get_template(name) + if path is None: + path = Path.cwd() + with open(f"{path}/{name}_gateway.py", "w") as f: + f.write(template.generate_empty_gateway_file()) + def generate_module(self, name: str, path: str = None): """ Create a new nest module diff --git a/nest/cli/templates/base_template.py b/nest/cli/templates/base_template.py index dbdc9ae..99f69d9 100644 --- a/nest/cli/templates/base_template.py +++ b/nest/cli/templates/base_template.py @@ -382,6 +382,17 @@ class {self.capitalized_module_name}Service: ... """ + def generate_empty_gateway_file(self) -> str: + return f"""from nest.websockets import MessageBody, SubscribeMessage, WebSocketGateway + + +@WebSocketGateway(namespace="/{self.module_name}") +class {self.capitalized_module_name}Gateway: + @SubscribeMessage("ping") + async def ping(self, data=MessageBody()): + return {{"event": "pong", "data": data}} +""" + def generate_empty_module_file(self) -> str: return f"""from nest.core import Module diff --git a/nest/common/route_resolver.py b/nest/common/route_resolver.py index c3dc509..6b5fefd 100644 --- a/nest/common/route_resolver.py +++ b/nest/common/route_resolver.py @@ -1,7 +1,7 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from fastapi import APIRouter, FastAPI, Request @@ -11,8 +11,8 @@ class RoutesResolver: """ - Walks the module graph, resolves controller instances from the container, - and registers their bound methods as FastAPI route endpoints. + Walks the module graph, resolves controller and gateway instances from the + container, and registers their bound methods on the FastAPI app. """ def __init__(self, container: "PyNestContainer", app_ref: FastAPI) -> None: @@ -20,14 +20,28 @@ def __init__(self, container: "PyNestContainer", app_ref: FastAPI) -> None: self.app_ref = app_ref def register_routes(self) -> None: - seen: set = set() + seen_controllers: set = set() + seen_gateways: set = set() + for module_ref in self.container.modules.values(): for controller_class in module_ref.compiled.controllers: - if controller_class in seen: + if controller_class in seen_controllers: continue - seen.add(controller_class) + seen_controllers.add(controller_class) self._register_controller(controller_class) + for provider in module_ref.compiled.provider_descriptors: + gateway_class = provider.use_class + if gateway_class is None or not hasattr( + gateway_class, "__websocket_gateway__" + ): + continue + if gateway_class in seen_gateways: + continue + seen_gateways.add(gateway_class) + gateway_instance = self.container.get(provider.provide) + self._register_gateway(gateway_class, gateway_instance) + def _register_controller(self, controller_class: type) -> None: instance = self.container.get_controller_instance(controller_class) tag = getattr(controller_class, "__controller_tag__", None) @@ -35,7 +49,9 @@ def _register_controller(self, controller_class: type) -> None: router = APIRouter(tags=[tag] if tag else None) - for method_name, unbound in inspect.getmembers(controller_class, predicate=callable): + for method_name, unbound in inspect.getmembers( + controller_class, predicate=callable + ): if not hasattr(unbound, "__http_method__"): continue bound = getattr(instance, method_name) @@ -43,6 +59,14 @@ def _register_controller(self, controller_class: type) -> None: self.app_ref.include_router(router) + def _register_gateway(self, gateway_class: type, gateway_instance: Any) -> None: + from nest.websockets.gateway import NativeWebSocketGateway + + NativeWebSocketGateway( + gateway=gateway_instance, + metadata=getattr(gateway_class, "__websocket_gateway__"), + ).register(self.app_ref) + def _add_route( self, router: APIRouter, diff --git a/nest/websockets/__init__.py b/nest/websockets/__init__.py new file mode 100644 index 0000000..62df6e4 --- /dev/null +++ b/nest/websockets/__init__.py @@ -0,0 +1,24 @@ +from nest.websockets.context import ExecutionContext, WsArgumentsHost +from nest.websockets.decorators import ( + ConnectedSocket, + MessageBody, + OnGatewayConnection, + OnGatewayDisconnect, + OnGatewayInit, + SubscribeMessage, + WebSocketGateway, +) +from nest.websockets.server import WebSocketServer + +__all__ = [ + "ConnectedSocket", + "ExecutionContext", + "MessageBody", + "OnGatewayConnection", + "OnGatewayDisconnect", + "OnGatewayInit", + "SubscribeMessage", + "WebSocketGateway", + "WebSocketServer", + "WsArgumentsHost", +] diff --git a/nest/websockets/context.py b/nest/websockets/context.py new file mode 100644 index 0000000..827f180 --- /dev/null +++ b/nest/websockets/context.py @@ -0,0 +1,60 @@ +from typing import Any + + +class WsArgumentsHost: + def __init__( + self, + client: Any, + data: Any, + event: str, + server: Any, + ): + self._client = client + self._data = data + self._event = event + self._server = server + + def get_client(self) -> Any: + return self._client + + def get_data(self) -> Any: + return self._data + + def get_event(self) -> str: + return self._event + + def get_server(self) -> Any: + return self._server + + +class ExecutionContext: + def __init__( + self, + *, + client: Any, + data: Any, + event: str, + server: Any, + gateway: Any, + handler: Any, + ): + self._ws_host = WsArgumentsHost( + client=client, + data=data, + event=event, + server=server, + ) + self._gateway = gateway + self._handler = handler + + def get_type(self) -> str: + return "ws" + + def get_class(self) -> Any: + return self._gateway.__class__ + + def get_handler(self) -> Any: + return self._handler + + def switch_to_ws(self) -> WsArgumentsHost: + return self._ws_host diff --git a/nest/websockets/decorators.py b/nest/websockets/decorators.py new file mode 100644 index 0000000..999d8bc --- /dev/null +++ b/nest/websockets/decorators.py @@ -0,0 +1,102 @@ +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Type + +from injector import inject + +from nest.common.constants import INJECTABLE_NAME, INJECTABLE_TOKEN +from nest.common.provider import Scope + +WEBSOCKET_GATEWAY_METADATA = "__websocket_gateway__" +WEBSOCKET_MESSAGE_EVENT = "__ws_message_event__" + + +@dataclass(frozen=True) +class WebSocketParam: + source: str + key: Optional[str] = None + + +def MessageBody(key: Optional[str] = None) -> WebSocketParam: + return WebSocketParam(source="body", key=key) + + +def ConnectedSocket() -> WebSocketParam: + return WebSocketParam(source="socket") + + +def SubscribeMessage(event: str) -> Callable: + def decorator(func: Callable) -> Callable: + setattr(func, WEBSOCKET_MESSAGE_EVENT, event) + setattr(func, "__signature__", inspect.signature(func)) + return func + + return decorator + + +def WebSocketGateway( + target_class: Optional[Type] = None, + *, + port: Optional[int] = None, + namespace: str = "/ws", + options: Optional[Dict[str, Any]] = None, + scope: Scope = Scope.SINGLETON, +) -> Callable: + if isinstance(target_class, str): + namespace = target_class + target_class = None + elif isinstance(target_class, int): + port = target_class + target_class = None + + def decorator(decorated_class: Type) -> Type: + if "__init__" not in decorated_class.__dict__: + + def init_method(self, *args, **kwargs): + pass + + decorated_class.__init__ = init_method + + metadata = { + "namespace": normalize_namespace(namespace), + "port": port, + "options": options or {}, + } + + own_init = decorated_class.__dict__.get("__init__") + if own_init is not None and getattr(own_init, "__annotations__", None): + inject(decorated_class) + + setattr(decorated_class, WEBSOCKET_GATEWAY_METADATA, metadata) + setattr(decorated_class, INJECTABLE_TOKEN, True) + setattr(decorated_class, INJECTABLE_NAME, decorated_class.__name__) + setattr(decorated_class, "__injectable_scope__", scope) + + return decorated_class + + if target_class is not None: + return decorator(target_class) + + return decorator + + +def normalize_namespace(namespace: Optional[str]) -> str: + if not namespace: + return "/ws" + if not namespace.startswith("/"): + namespace = f"/{namespace}" + if namespace != "/" and namespace.endswith("/"): + namespace = namespace.rstrip("/") + return namespace + + +class OnGatewayInit: + async def after_init(self, server: Any) -> None: ... + + +class OnGatewayConnection: + async def on_connection(self, client: Any, *args: Any) -> None: ... + + +class OnGatewayDisconnect: + async def on_disconnect(self, client: Any) -> None: ... diff --git a/nest/websockets/gateway.py b/nest/websockets/gateway.py new file mode 100644 index 0000000..c0b7167 --- /dev/null +++ b/nest/websockets/gateway.py @@ -0,0 +1,202 @@ +import inspect +from json import JSONDecodeError +from typing import Any, Callable, Dict, Iterable + +from fastapi import FastAPI, WebSocket +from fastapi.encoders import jsonable_encoder +from pydantic import BaseModel +from starlette.websockets import WebSocketDisconnect + +from nest.websockets.context import ExecutionContext +from nest.websockets.decorators import ( + WEBSOCKET_MESSAGE_EVENT, + WebSocketParam, +) +from nest.websockets.server import WebSocketServer + + +class NativeWebSocketGateway: + def __init__( + self, + gateway: Any, + metadata: Dict[str, Any], + server: WebSocketServer = None, + ): + self.gateway = gateway + self.metadata = metadata + self.server = server or WebSocketServer() + self.handlers = self.discover_handlers() + self._initialized = False + setattr(self.gateway, "server", self.server) + + def register(self, app_ref: FastAPI) -> None: + async def endpoint(websocket: WebSocket): + await self.handle_connection(websocket) + + app_ref.add_api_websocket_route(self.metadata["namespace"], endpoint) + + async def handle_connection(self, websocket: WebSocket) -> None: + await self.ensure_initialized() + await websocket.accept() + await self.server.connect(websocket) + try: + await self.run_lifecycle_hook("on_connection", websocket) + while True: + message = await websocket.receive_json() + await self.dispatch_message(websocket, message) + except WebSocketDisconnect: + pass + except JSONDecodeError: + await self.send_error(websocket, "Invalid JSON payload") + finally: + await self.run_lifecycle_hook("on_disconnect", websocket) + await self.server.disconnect(websocket) + + async def ensure_initialized(self) -> None: + if self._initialized: + return + await self.run_lifecycle_hook("after_init", self.server) + self._initialized = True + + async def run_lifecycle_hook(self, hook_name: str, *args: Any) -> None: + hook = getattr(self.gateway, hook_name, None) + if not callable(hook): + return + result = hook(*args) + if inspect.isawaitable(result): + await result + + async def dispatch_message(self, client: Any, message: Dict[str, Any]) -> None: + if not isinstance(message, dict): + await self.send_error(client, "WebSocket message must be a JSON object") + return + + event = message.get("event") + if not event: + await self.send_error(client, "WebSocket message is missing an event") + return + + handler = self.handlers.get(event) + if handler is None: + await self.send_error(client, f"No handler for WebSocket event '{event}'") + return + + can_activate = await self.run_guards(handler, client, message) + if not can_activate: + await self.send_error( + client, + "Access denied: insufficient permissions", + close_code=1008, + ) + return + + try: + kwargs = self.resolve_handler_arguments(handler, client, message) + result = handler(**kwargs) + if inspect.isawaitable(result): + result = await result + except Exception: + await self.send_error(client, "Unhandled WebSocket handler error", 1011) + return + + if result is not None: + await client.send_json(self.format_response(event, result)) + + async def run_guards( + self, + handler: Callable, + client: Any, + message: Dict[str, Any], + ) -> bool: + for guard_class in self.collect_guards(handler): + guard = guard_class() if inspect.isclass(guard_class) else guard_class + context = ExecutionContext( + client=client, + data=message.get("data"), + event=message.get("event"), + server=self.server, + gateway=self.gateway, + handler=handler, + ) + result = guard.can_activate(context) + if inspect.isawaitable(result): + result = await result + if not result: + return False + return True + + def resolve_handler_arguments( + self, + handler: Callable, + client: Any, + message: Dict[str, Any], + ) -> Dict[str, Any]: + signature = inspect.signature(handler) + kwargs = {} + data = message.get("data") + + for name, parameter in signature.parameters.items(): + if name == "self": + continue + + default = parameter.default + if isinstance(default, WebSocketParam): + if default.source == "socket": + kwargs[name] = client + elif default.source == "body": + body = self.extract_body(data, default.key) + kwargs[name] = self.coerce_value(body, parameter.annotation) + continue + + if parameter.default is inspect.Parameter.empty and name == "data": + kwargs[name] = self.coerce_value(data, parameter.annotation) + + return kwargs + + @staticmethod + def extract_body(data: Any, key: str = None) -> Any: + if key is None: + return data + if isinstance(data, dict): + return data.get(key) + return None + + @staticmethod + def coerce_value(value: Any, annotation: Any) -> Any: + if annotation is inspect.Parameter.empty: + return value + if inspect.isclass(annotation) and issubclass(annotation, BaseModel): + return annotation.model_validate(value) + return value + + def collect_guards(self, handler: Callable) -> Iterable[Any]: + guards = list(getattr(self.gateway.__class__, "__guards__", [])) + func = getattr(handler, "__func__", handler) + guards.extend(getattr(func, "__guards__", [])) + return guards + + def discover_handlers(self) -> Dict[str, Callable]: + handlers = {} + for _, method in inspect.getmembers(self.gateway, predicate=callable): + func = getattr(method, "__func__", method) + event = getattr(func, WEBSOCKET_MESSAGE_EVENT, None) + if event: + handlers[event] = method + return handlers + + @staticmethod + def format_response(event: str, result: Any) -> Dict[str, Any]: + encoded = jsonable_encoder(result) + if isinstance(encoded, dict) and "event" in encoded and "data" in encoded: + return encoded + return {"event": event, "data": encoded} + + @staticmethod + async def send_error( + client: Any, + message: str, + close_code: int = None, + ) -> None: + await client.send_json({"event": "error", "data": {"message": message}}) + if close_code is not None: + await client.close(code=close_code) diff --git a/nest/websockets/server.py b/nest/websockets/server.py new file mode 100644 index 0000000..32d9db4 --- /dev/null +++ b/nest/websockets/server.py @@ -0,0 +1,93 @@ +import uuid +from collections import defaultdict +from typing import Any, Dict, Iterable, Optional, Set + +from fastapi.encoders import jsonable_encoder + + +class WebSocketTarget: + def __init__(self, server: "WebSocketServer", target: str): + self.server = server + self.target = target + + async def emit(self, event: str, data: Any = None) -> None: + clients = self.server.resolve_target(self.target) + await self.server.emit_to_clients(clients, event, data) + + +class WebSocketServer: + def __init__(self): + self.clients: Dict[str, Any] = {} + self.rooms: Dict[str, Set[str]] = defaultdict(set) + self.client_rooms: Dict[str, Set[str]] = defaultdict(set) + + async def connect(self, client: Any) -> str: + client_id = self.get_client_id(client) + if client_id is None: + client_id = str(uuid.uuid4()) + setattr(client.state, "pynest_ws_client_id", client_id) + self.clients[client_id] = client + return client_id + + async def disconnect(self, client: Any) -> None: + client_id = self.get_client_id(client) + if client_id is None: + return + + for room in list(self.client_rooms.get(client_id, set())): + self.rooms[room].discard(client_id) + if not self.rooms[room]: + del self.rooms[room] + + self.client_rooms.pop(client_id, None) + self.clients.pop(client_id, None) + + async def join(self, client: Any, room: str) -> None: + client_id = self.get_client_id(client) + if client_id is None: + client_id = await self.connect(client) + self.rooms[room].add(client_id) + self.client_rooms[client_id].add(room) + + async def leave(self, client: Any, room: str) -> None: + client_id = self.get_client_id(client) + if client_id is None: + return + self.rooms[room].discard(client_id) + self.client_rooms[client_id].discard(room) + if not self.rooms[room]: + del self.rooms[room] + + async def emit(self, event: str, data: Any = None) -> None: + await self.emit_to_clients(self.clients.values(), event, data) + + async def broadcast(self, event: str, data: Any = None) -> None: + await self.emit(event, data) + + def to(self, target: str) -> WebSocketTarget: + return WebSocketTarget(self, target) + + def resolve_target(self, target: str) -> Iterable[Any]: + if target in self.rooms: + return [ + self.clients[client_id] + for client_id in self.rooms[target] + if client_id in self.clients + ] + if target in self.clients: + return [self.clients[target]] + return [] + + async def emit_to_clients( + self, + clients: Iterable[Any], + event: str, + data: Any = None, + ) -> None: + payload = {"event": event, "data": jsonable_encoder(data)} + for client in list(clients): + await client.send_json(payload) + + @staticmethod + def get_client_id(client: Any) -> Optional[str]: + return getattr(client.state, "pynest_ws_client_id", None) diff --git a/pyproject.toml b/pyproject.toml index 49f45c4..29f6205 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ test = [ "beanie>=1.27.0,<2.0.0", "python-dotenv>=1.0.1,<2.0.0", "aiosqlite>=0.19.0,<1.0.0", + "websockets>=13.0,<16.0", ] docs = [ "mkdocs-material>=9.5.43,<10.0.0", diff --git a/tests/test_cli/__init__.py b/tests/test_cli/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/test_cli/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/test_cli/test_generate_gateway.py b/tests/test_cli/test_generate_gateway.py new file mode 100644 index 0000000..4318d82 --- /dev/null +++ b/tests/test_cli/test_generate_gateway.py @@ -0,0 +1,83 @@ +import asyncio +import contextlib +import importlib.util +import json +import socket + +import uvicorn +import websockets + +from nest.cli.src.generate.generate_service import GenerateService +from nest.core import Module, PyNestContainer, PyNestFactory + + +def test_generate_gateway_creates_gateway_file(tmp_path): + GenerateService().generate_gateway("chat", str(tmp_path)) + + gateway_file = tmp_path / "chat_gateway.py" + + assert gateway_file.exists() + assert ( + "from nest.websockets import MessageBody, SubscribeMessage, WebSocketGateway" + in gateway_file.read_text() + ) + assert '@WebSocketGateway(namespace="/chat")' in gateway_file.read_text() + assert '@SubscribeMessage("ping")' in gateway_file.read_text() + + +def _free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname()[1] + + +@contextlib.asynccontextmanager +async def _run_server(app, port): + config = uvicorn.Config( + app, host="127.0.0.1", port=port, log_level="critical", lifespan="off" + ) + server = uvicorn.Server(config) + task = asyncio.create_task(server.serve()) + for _ in range(100): + if server.started: + break + await asyncio.sleep(0.01) + try: + yield + finally: + server.should_exit = True + await task + + +def _load_module_from_path(name, path): + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_generated_gateway_runs_as_real_websocket_app(tmp_path): + GenerateService().generate_gateway("chat", str(tmp_path)) + + chat_module = _load_module_from_path("chat_gateway", tmp_path / "chat_gateway.py") + ChatGateway = chat_module.ChatGateway + + @Module(providers=[ChatGateway]) + class ChatAppModule: + pass + + async def scenario(): + PyNestContainer._instance = None + app = PyNestFactory.create(ChatAppModule).get_server() + port = _free_port() + + async with _run_server(app, port): + async with websockets.connect(f"ws://127.0.0.1:{port}/chat") as ws: + await ws.send( + json.dumps({"event": "ping", "data": {"hello": "world"}}) + ) + response = json.loads(await ws.recv()) + + assert response == {"event": "pong", "data": {"hello": "world"}} + + asyncio.run(scenario()) diff --git a/tests/test_websockets/__init__.py b/tests/test_websockets/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/test_websockets/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/test_websockets/test_decorators.py b/tests/test_websockets/test_decorators.py new file mode 100644 index 0000000..d140623 --- /dev/null +++ b/tests/test_websockets/test_decorators.py @@ -0,0 +1,34 @@ +from nest.common.constants import INJECTABLE_TOKEN +from nest.common.provider import Scope +from nest.websockets import ( + ConnectedSocket, + MessageBody, + SubscribeMessage, + WebSocketGateway, +) + + +@WebSocketGateway(namespace="chat") +class ChatGateway: + def __init__(self): ... + + @SubscribeMessage("ping") + async def ping(self, data=MessageBody(), client=ConnectedSocket()): + return {"event": "pong", "data": data} + + +def test_websocket_gateway_sets_metadata_and_marks_injectable(): + assert ChatGateway.__websocket_gateway__["namespace"] == "/chat" + assert ChatGateway.__websocket_gateway__["options"] == {} + assert getattr(ChatGateway, INJECTABLE_TOKEN) is True + assert ChatGateway.__injectable_scope__ == Scope.SINGLETON + + +def test_subscribe_message_sets_event_metadata(): + assert ChatGateway.ping.__ws_message_event__ == "ping" + + +def test_parameter_markers_identify_message_sources(): + signature = ChatGateway.ping.__signature__ + assert signature.parameters["data"].default.source == "body" + assert signature.parameters["client"].default.source == "socket" diff --git a/tests/test_websockets/test_gateway_router.py b/tests/test_websockets/test_gateway_router.py new file mode 100644 index 0000000..ba80358 --- /dev/null +++ b/tests/test_websockets/test_gateway_router.py @@ -0,0 +1,106 @@ +from types import SimpleNamespace + +import pytest + +from nest.core import BaseGuard, UseGuards +from nest.websockets import ( + ConnectedSocket, + MessageBody, + SubscribeMessage, + WebSocketGateway, + WebSocketServer, +) +from nest.websockets.gateway import NativeWebSocketGateway + + +class FakeWebSocket: + def __init__(self): + self.sent = [] + self.closed = None + self.state = SimpleNamespace() + self.headers = {"x-token": "secret"} + + async def send_json(self, message): + self.sent.append(message) + + async def close(self, code=1000): + self.closed = code + + +class AllowGuard(BaseGuard): + seen_payloads = [] + + async def can_activate(self, context): + ws = context.switch_to_ws() + self.seen_payloads.append(ws.get_data()) + return ws.get_client().headers["x-token"] == "secret" + + +@WebSocketGateway(namespace="/chat") +@UseGuards(AllowGuard) +class ChatGateway: + @SubscribeMessage("echo") + async def echo(self, data=MessageBody(), client=ConnectedSocket()): + return { + "event": "echo_ack", + "data": {"payload": data, "has_socket": client is not None}, + } + + +@pytest.mark.anyio +async def test_native_gateway_dispatches_message_with_guards_and_markers(): + AllowGuard.seen_payloads = [] + gateway = ChatGateway() + router = NativeWebSocketGateway( + gateway=gateway, + metadata=ChatGateway.__websocket_gateway__, + server=WebSocketServer(), + ) + client = FakeWebSocket() + + await router.dispatch_message( + client, + {"event": "echo", "data": {"text": "hello"}}, + ) + + assert AllowGuard.seen_payloads == [{"text": "hello"}] + assert client.sent == [ + { + "event": "echo_ack", + "data": {"payload": {"text": "hello"}, "has_socket": True}, + } + ] + assert client.closed is None + + +class DenyGuard(BaseGuard): + async def can_activate(self, context): + return False + + +@WebSocketGateway(namespace="/private") +class PrivateGateway: + @SubscribeMessage("secret") + @UseGuards(DenyGuard) + async def secret(self): + return {"event": "secret_ack", "data": {}} + + +@pytest.mark.anyio +async def test_native_gateway_closes_when_guard_denies_message(): + router = NativeWebSocketGateway( + gateway=PrivateGateway(), + metadata=PrivateGateway.__websocket_gateway__, + server=WebSocketServer(), + ) + client = FakeWebSocket() + + await router.dispatch_message(client, {"event": "secret", "data": {}}) + + assert client.sent == [ + { + "event": "error", + "data": {"message": "Access denied: insufficient permissions"}, + } + ] + assert client.closed == 1008 diff --git a/tests/test_websockets/test_integration.py b/tests/test_websockets/test_integration.py new file mode 100644 index 0000000..369224c --- /dev/null +++ b/tests/test_websockets/test_integration.py @@ -0,0 +1,236 @@ +import asyncio +import contextlib +import json +import socket + +import uvicorn +import websockets +import httpx + +from nest.core import ( + Controller, + Get, + Injectable, + Module, + PyNestContainer, + PyNestFactory, +) +from nest.websockets import ( + ConnectedSocket, + MessageBody, + OnGatewayConnection, + OnGatewayDisconnect, + OnGatewayInit, + SubscribeMessage, + WebSocketGateway, +) + + +def get_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname()[1] + + +@contextlib.asynccontextmanager +async def run_server(app, port): + config = uvicorn.Config( + app, + host="127.0.0.1", + port=port, + log_level="critical", + lifespan="off", + ) + server = uvicorn.Server(config) + task = asyncio.create_task(server.serve()) + for _ in range(100): + if server.started: + break + await asyncio.sleep(0.01) + try: + yield + finally: + server.should_exit = True + await task + + +def reset_container(): + PyNestContainer._instance = None + + +def test_websocket_gateway_e2e_with_async_client(): + async def scenario(): + reset_container() + events = [] + + @Injectable + class ChatService: + def acknowledge(self, payload): + return {"text": payload["text"], "status": "delivered"} + + @WebSocketGateway(namespace="/chat") + class ChatGateway( + OnGatewayInit, + OnGatewayConnection, + OnGatewayDisconnect, + ): + def __init__(self, chat_service: ChatService): + self.chat_service = chat_service + + async def after_init(self, server): + self.server = server + events.append("init") + + async def on_connection(self, client): + events.append("connect") + await client.send_json({"event": "connected", "data": {}}) + + async def on_disconnect(self, client): + events.append("disconnect") + + @SubscribeMessage("send_message") + async def handle_message(self, data=MessageBody()): + return { + "event": "message_ack", + "data": self.chat_service.acknowledge(data), + } + + @Module(providers=[ChatService, ChatGateway]) + class ChatModule: + pass + + app = PyNestFactory.create(ChatModule).get_server() + port = get_free_port() + + async with run_server(app, port): + async with websockets.connect(f"ws://127.0.0.1:{port}/chat") as websocket: + connected = json.loads(await websocket.recv()) + await websocket.send( + json.dumps({"event": "send_message", "data": {"text": "hello"}}) + ) + ack = json.loads(await websocket.recv()) + + assert connected == {"event": "connected", "data": {}} + assert ack == { + "event": "message_ack", + "data": {"text": "hello", "status": "delivered"}, + } + assert events == ["init", "connect", "disconnect"] + + asyncio.run(scenario()) + + +def test_websocket_gateway_supports_token_streaming_pattern(): + async def scenario(): + reset_container() + + @Injectable + class LlmService: + async def stream(self, prompt): + for token in ["hel", "lo"]: + yield f"{prompt}:{token}" + + @WebSocketGateway(namespace="/ai") + class AgentGateway: + def __init__(self, llm_service: LlmService): + self.llm_service = llm_service + + @SubscribeMessage("prompt") + async def handle_prompt( + self, + data=MessageBody(), + client=ConnectedSocket(), + ): + async for token in self.llm_service.stream(data["prompt"]): + await client.send_json({"event": "token", "data": token}) + return {"event": "done", "data": {}} + + @Module(providers=[LlmService, AgentGateway]) + class AgentModule: + pass + + app = PyNestFactory.create(AgentModule).get_server() + port = get_free_port() + + async with run_server(app, port): + async with websockets.connect(f"ws://127.0.0.1:{port}/ai") as websocket: + await websocket.send( + json.dumps({"event": "prompt", "data": {"prompt": "say"}}) + ) + frames = [json.loads(await websocket.recv()) for _ in range(3)] + + assert frames == [ + {"event": "token", "data": "say:hel"}, + {"event": "token", "data": "say:lo"}, + {"event": "done", "data": {}}, + ] + + asyncio.run(scenario()) + + +def test_websocket_gateway_and_http_controller_share_provider_state(): + async def scenario(): + reset_container() + + @Injectable + class EventStore: + def __init__(self): + self.events = [] + + def append(self, payload): + event = {"id": len(self.events) + 1, "payload": payload} + self.events.append(event) + return event + + def all(self): + return self.events + + @WebSocketGateway(namespace="/events") + class EventsGateway: + def __init__(self, event_store: EventStore): + self.event_store = event_store + + @SubscribeMessage("record") + async def record(self, data=MessageBody()): + event = self.event_store.append(data) + return {"event": "recorded", "data": event} + + @Controller("/events", tag="events") + class EventsController: + def __init__(self, event_store: EventStore): + self.event_store = event_store + + @Get("/received") + def received(self): + return {"events": self.event_store.all()} + + @Module( + controllers=[EventsController], + providers=[EventStore, EventsGateway], + ) + class EventsModule: + pass + + app = PyNestFactory.create(EventsModule).get_server() + port = get_free_port() + + async with run_server(app, port): + async with websockets.connect(f"ws://127.0.0.1:{port}/events") as websocket: + await websocket.send( + json.dumps({"event": "record", "data": {"kind": "created"}}) + ) + ack = json.loads(await websocket.recv()) + + async with httpx.AsyncClient(base_url=f"http://127.0.0.1:{port}") as client: + response = await client.get("/events/received") + + assert ack == { + "event": "recorded", + "data": {"id": 1, "payload": {"kind": "created"}}, + } + assert response.status_code == 200 + assert response.json() == { + "events": [{"id": 1, "payload": {"kind": "created"}}] + } + + asyncio.run(scenario()) diff --git a/tests/test_websockets/test_server.py b/tests/test_websockets/test_server.py new file mode 100644 index 0000000..fd805a1 --- /dev/null +++ b/tests/test_websockets/test_server.py @@ -0,0 +1,52 @@ +from types import SimpleNamespace + +import pytest + +from nest.websockets import WebSocketServer + + +class FakeWebSocket: + def __init__(self): + self.sent = [] + self.state = SimpleNamespace() + + async def send_json(self, message): + self.sent.append(message) + + +@pytest.mark.anyio +async def test_websocket_server_emits_to_all_clients(): + server = WebSocketServer() + first = FakeWebSocket() + second = FakeWebSocket() + await server.connect(first) + await server.connect(second) + + await server.emit("update", {"count": 1}) + + assert first.sent == [{"event": "update", "data": {"count": 1}}] + assert second.sent == [{"event": "update", "data": {"count": 1}}] + + +@pytest.mark.anyio +async def test_websocket_server_targets_rooms_and_clients(): + server = WebSocketServer() + first = FakeWebSocket() + second = FakeWebSocket() + first_id = await server.connect(first) + await server.connect(second) + await server.join(first, "room-a") + + await server.to("room-a").emit("room_event", {"ok": True}) + await server.to(first_id).emit("direct", {"id": first_id}) + + assert first.sent == [ + {"event": "room_event", "data": {"ok": True}}, + {"event": "direct", "data": {"id": first_id}}, + ] + assert second.sent == [] + + await server.leave(first, "room-a") + await server.to("room-a").emit("room_event", {"ok": False}) + + assert len(first.sent) == 2