diff --git a/ami/ml/management/commands/check_dead_letter_queue.py b/ami/ml/management/commands/check_dead_letter_queue.py
new file mode 100644
index 000000000..8fcec9986
--- /dev/null
+++ b/ami/ml/management/commands/check_dead_letter_queue.py
@@ -0,0 +1,48 @@
+"""
+Management command to check dead letter queue messages for a job.
+
+Usage:
+    python manage.py check_dead_letter_queue
+
+Example:
+    python manage.py check_dead_letter_queue 123
+"""
+
+from asgiref.sync import async_to_sync
+from django.core.management.base import BaseCommand, CommandError
+
+from ami.ml.orchestration.nats_queue import TaskQueueManager
+
+
+class Command(BaseCommand):
+    help = "Check dead letter queue messages for a job ID"
+
+    def add_arguments(self, parser):
+        parser.add_argument(
+            "job_id",
+            type=int,
+            help="Job ID to check for dead letter queue messages",
+        )
+
+    def handle(self, *args, **options):
+        job_id = options["job_id"]
+
+        try:
+            dead_letter_ids = async_to_sync(self._check_dead_letter_queue)(job_id)
+
+            if dead_letter_ids:
+                self.stdout.write(
+                    self.style.WARNING(f"Found {len(dead_letter_ids)} dead letter image(s) for job {job_id}:")
+                )
+                for image_id in dead_letter_ids:
+                    self.stdout.write(f" - Image ID: {image_id}")
+            else:
+                self.stdout.write(self.style.SUCCESS(f"No dead letter images found for job {job_id}"))
+
+        except Exception as e:
+            raise CommandError(f"Failed to check dead letter queue: {e}")
+
+    async def _check_dead_letter_queue(self, job_id: int) -> list[str]:
+        """Check for dead letter queue messages using TaskQueueManager."""
+        async with TaskQueueManager() as manager:
+            return await manager.get_dead_letter_image_ids(job_id)
diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py
index 884676637..b6e9af254 100644
--- a/ami/ml/orchestration/nats_queue.py
+++ b/ami/ml/orchestration/nats_queue.py
@@ -43,6 +43,7 @@ async def get_connection(nats_url: str) -> tuple[nats.NATS, JetStreamContext]:
 
 
 TASK_TTR = getattr(settings, "NATS_TASK_TTR", 30)  # Visibility timeout in seconds (configurable)
+ADVISORY_STREAM_NAME = "advisories"  # Shared stream for max delivery advisories across all jobs
 
 
 class TaskQueueManager:
@@ -72,6 +73,15 @@ def __init__(self, nats_url: str | None = None, max_ack_pending: int | None = No
     async def __aenter__(self):
         """Create connection on enter."""
         self.nc, self.js = await get_connection(self.nats_url)
+
+        try:
+            await self._setup_advisory_stream()
+        except BaseException:
+            if self.nc and not self.nc.is_closed:
+                await self.nc.close()
+            self.nc = None
+            self.js = None
+            raise
         return self
 
     async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -95,7 +105,7 @@ def _get_consumer_name(self, job_id: int) -> str:
         """Get consumer name from job_id."""
         return f"job-{job_id}-consumer"
 
-    async def _stream_exists(self, job_id: int) -> bool:
+    async def _job_stream_exists(self, job_id: int) -> bool:
         """Check if stream exists for the given job.
 
         Only catches NotFoundError (→ False). TimeoutError propagates deliberately
@@ -106,6 +116,10 @@ async def _stream_exists(self, job_id: int) -> bool:
             raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")
 
         stream_name = self._get_stream_name(job_id)
+        return await self._stream_exists(stream_name)
+
+    async def _stream_exists(self, stream_name: str) -> bool:
+        """Check if a stream with the given name exists."""
         try:
             await asyncio.wait_for(self.js.stream_info(stream_name), timeout=NATS_JETSTREAM_TIMEOUT)
             return True
@@ -117,7 +131,7 @@ async def _ensure_stream(self, job_id: int):
         if self.js is None:
             raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")
 
-        if not await self._stream_exists(job_id):
+        if not await self._job_stream_exists(job_id):
             stream_name = self._get_stream_name(job_id)
             subject = self._get_subject(job_id)
             logger.warning(f"Stream {stream_name} does not exist")
@@ -218,7 +232,7 @@ async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> li
             raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")
 
         try:
-            if not await self._stream_exists(job_id):
+            if not await self._job_stream_exists(job_id):
                 logger.debug(f"Stream for job '{job_id}' does not exist when reserving task")
                 return []
 
@@ -231,7 +245,7 @@ async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> li
 
             try:
                 msgs = await psub.fetch(count, timeout=timeout)
-            except nats.errors.TimeoutError:
+            except (asyncio.TimeoutError, nats.errors.TimeoutError):
                 logger.debug(f"No tasks available in stream for job '{job_id}'")
                 return []
             finally:
@@ -250,7 +264,7 @@ async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> li
                 logger.debug(f"No tasks reserved from stream for job '{job_id}'")
             return tasks
 
-        except asyncio.TimeoutError:
+        except (asyncio.TimeoutError, nats.errors.TimeoutError):
             raise  # NATS unreachable — propagate so the view can return an appropriate error
         except Exception as e:
             logger.error(f"Failed to reserve tasks from stream for job '{job_id}': {e}")
@@ -271,6 +285,7 @@ async def acknowledge_task(self, reply_subject: str) -> bool:
 
         try:
             await self.nc.publish(reply_subject, b"+ACK")
+            await self.nc.flush()
             logger.debug(f"Acknowledged task with reply subject {reply_subject}")
             return True
         except Exception as e:
@@ -330,9 +345,134 @@ async def delete_stream(self, job_id: int) -> bool:
             logger.error(f"Failed to delete stream for job '{job_id}': {e}")
             return False
 
+    async def _setup_advisory_stream(self):
+        """Ensure the shared advisory stream exists to capture max-delivery events.
+
+        Called on every __aenter__ so that advisories are captured from the moment
+        any TaskQueueManager connection is opened, not just when the DLQ is first read.
+        """
+        if not await self._stream_exists(ADVISORY_STREAM_NAME):
+            await asyncio.wait_for(
+                self.js.add_stream(
+                    name=ADVISORY_STREAM_NAME,
+                    subjects=["$JS.EVENT.ADVISORY.>"],
+                    max_age=3600,  # Keep advisories for 1 hour
+                ),
+                timeout=NATS_JETSTREAM_TIMEOUT,
+            )
+            logger.info("Advisory stream created")
+
+    def _get_dlq_consumer_name(self, job_id: int) -> str:
+        """Get the durable consumer name for dead letter queue advisory tracking."""
+        return f"job-{job_id}-dlq"
+
+    async def get_dead_letter_image_ids(self, job_id: int, n: int = 10) -> list[str]:
+        """
+        Get image IDs from dead letter queue (messages that exceeded max delivery attempts).
+
+        Pulls from persistent advisory stream to find failed messages, then looks up image IDs.
+        Uses a durable consumer so acknowledged advisories are not re-delivered on subsequent calls.
+
+        Args:
+            job_id: The job ID (integer primary key)
+            n: Maximum number of image IDs to return (default: 10)
+
+        Returns:
+            List of image IDs that failed to process after max retry attempts
+        """
+        if self.nc is None or self.js is None:
+            raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")
+
+        stream_name = self._get_stream_name(job_id)
+        consumer_name = self._get_consumer_name(job_id)
+        dlq_consumer_name = self._get_dlq_consumer_name(job_id)
+        dead_letter_ids = []
+
+        subject_filter = f"$JS.EVENT.ADVISORY.CONSUMER.MAX_DELIVERIES.{stream_name}.{consumer_name}"
+
+        # Use a durable consumer so ACKs persist across calls — ephemeral consumers
+        # are deleted on unsubscribe, discarding all ACK tracking and causing every
+        # advisory to be re-delivered on the next call.
+        psub = await self.js.pull_subscribe(subject_filter, durable=dlq_consumer_name, stream=ADVISORY_STREAM_NAME)
+
+        try:
+            msgs = await psub.fetch(n, timeout=1.0)
+
+            for msg in msgs:
+                advisory_data = json.loads(msg.data.decode())
+
+                # Get the stream sequence of the failed message
+                if "stream_seq" in advisory_data:
+                    stream_seq = advisory_data["stream_seq"]
+
+                    # Look up the actual message by sequence to get task ID
+                    try:
+                        job_msg = await self.js.get_msg(stream_name, stream_seq)
+
+                        if job_msg and job_msg.data:
+                            task_data = json.loads(job_msg.data.decode())
+
+                            if "image_id" in task_data:
+                                dead_letter_ids.append(str(task_data["image_id"]))
+                            else:
+                                logger.warning(f"No image_id found in task data: {task_data}")
+                    except Exception as e:
+                        logger.warning(f"Could not retrieve message {stream_seq} from {stream_name}: {e}")
+                        # The message might have been discarded after max_deliver exceeded
+                else:
+                    logger.warning(f"No stream_seq in advisory data: {advisory_data}")
+
+                # Acknowledge even if we couldn't find the stream_seq or image_id so it doesn't get re-delivered
+                # it shouldn't happen since stream_seq is part of the `io.nats.jetstream.advisory.v1.max_deliver`
+                # schema and all our messages have an image_id
+                await msg.ack()
+                logger.info(
+                    f"Acknowledged advisory message for stream_seq {advisory_data.get('stream_seq', 'unknown')}"
+                )
+
+            # Flush to ensure all ACKs are written to the socket before unsubscribing.
+            # msg.ack() only queues a publish in the client buffer; without flush() the
+            # ACKs can be silently dropped when the subscription is torn down.
+            await self.nc.flush()
+
+        except (asyncio.TimeoutError, nats.errors.TimeoutError):
+            logger.info(f"No advisory messages found for job {job_id}")
+        finally:
+            await psub.unsubscribe()
+
+        return dead_letter_ids[:n]
+
+    async def delete_dlq_consumer(self, job_id: int) -> bool:
+        """
+        Delete the durable DLQ advisory consumer for a job.
+
+        Args:
+            job_id: The job ID (integer primary key)
+
+        Returns:
+            bool: True if successful, False otherwise
+        """
+        if self.js is None:
+            raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")
+
+        dlq_consumer_name = self._get_dlq_consumer_name(job_id)
+        try:
+            await asyncio.wait_for(
+                self.js.delete_consumer(ADVISORY_STREAM_NAME, dlq_consumer_name),
+                timeout=NATS_JETSTREAM_TIMEOUT,
+            )
+            logger.info(f"Deleted DLQ consumer {dlq_consumer_name} for job '{job_id}'")
+            return True
+        except nats.js.errors.NotFoundError:
+            logger.debug(f"DLQ consumer {dlq_consumer_name} for job '{job_id}' not found when attempting to delete")
+            return True  # Consider it a success if the consumer is already gone
+        except Exception as e:
+            logger.warning(f"Failed to delete DLQ consumer for job '{job_id}': {e}")
+            return False
+
     async def cleanup_job_resources(self, job_id: int) -> bool:
         """
-        Clean up all NATS resources (consumer and stream) for a job.
+        Clean up all NATS resources (consumer, stream, and DLQ advisory consumer) for a job.
 
         This should be called when a job completes or is cancelled.
 
@@ -342,8 +482,9 @@ async def cleanup_job_resources(self, job_id: int) -> bool:
         Returns:
             bool: True if successful, False otherwise
         """
-        # Delete consumer first, then stream
+        # Delete consumer first, then stream, then the durable DLQ advisory consumer
        consumer_deleted = await self.delete_consumer(job_id)
        stream_deleted = await self.delete_stream(job_id)
+        dlq_consumer_deleted = await self.delete_dlq_consumer(job_id)
 
-        return consumer_deleted and stream_deleted
+        return consumer_deleted and stream_deleted and dlq_consumer_deleted
diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py
index cf3514bce..da47f3429 100644
--- a/ami/ml/orchestration/tests/test_nats_queue.py
+++ b/ami/ml/orchestration/tests/test_nats_queue.py
@@ -1,11 +1,13 @@
 """Unit tests for TaskQueueManager."""
 
+import json
 import unittest
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import nats
+import nats.errors
 
-from ami.ml.orchestration.nats_queue import TaskQueueManager
+from ami.ml.orchestration.nats_queue import ADVISORY_STREAM_NAME, TaskQueueManager
 
 from ami.ml.schemas import PipelineProcessingTask
 
@@ -25,6 +27,7 @@ def _create_mock_nats_connection(self):
         nc = MagicMock()
         nc.is_closed = False
         nc.close = AsyncMock()
+        nc.flush = AsyncMock()
 
         js = MagicMock()
         js.stream_info = AsyncMock()
@@ -60,7 +63,8 @@ async def test_publish_task_creates_stream_and_consumer(self):
         async with TaskQueueManager() as manager:
             await manager.publish_task(456, sample_task)
 
-        js.add_stream.assert_called_once()
+        # add_stream called twice: advisory stream in __aenter__ + job stream in _ensure_stream
+        self.assertEqual(js.add_stream.call_count, 2)
         self.assertIn("job_456", str(js.add_stream.call_args))
         js.add_consumer.assert_called_once()
 
@@ -153,7 +157,8 @@ async def test_cleanup_job_resources(self):
             result = await manager.cleanup_job_resources(123)
 
         self.assertTrue(result)
-        js.delete_consumer.assert_called_once()
+        # delete_consumer called twice: job consumer + DLQ advisory consumer
+        self.assertEqual(js.delete_consumer.call_count, 2)
         js.delete_stream.assert_called_once()
 
     async def test_naming_conventions(self):
@@ -177,3 +182,56 @@ async def test_operations_without_connection_raise_error(self):
 
         with self.assertRaisesRegex(RuntimeError, "Connection is not open"):
             await manager.delete_stream(123)
+
+    async def test_get_dead_letter_image_ids_returns_image_ids(self):
+        """Test that advisory messages are resolved to image IDs correctly."""
+        nc, js = self._create_mock_nats_connection()
+        js.get_msg = AsyncMock()
+
+        def make_advisory(seq):
+            m = MagicMock()
+            m.data = json.dumps({"stream_seq": seq}).encode()
+            m.ack = AsyncMock()
+            return m
+
+        def make_job_msg(image_id):
+            m = MagicMock()
+            m.data = json.dumps({"image_id": image_id}).encode()
+            return m
+
+        advisories = [make_advisory(1), make_advisory(2)]
+        js.get_msg.side_effect = [make_job_msg("img-1"), make_job_msg("img-2")]
+
+        mock_psub = MagicMock()
+        mock_psub.fetch = AsyncMock(return_value=advisories)
+        mock_psub.unsubscribe = AsyncMock()
+        js.pull_subscribe = AsyncMock(return_value=mock_psub)
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager() as manager:
+                result = await manager.get_dead_letter_image_ids(123, n=10)
+
+        self.assertEqual(result, ["img-1", "img-2"])
+        js.pull_subscribe.assert_called_once_with(
+            "$JS.EVENT.ADVISORY.CONSUMER.MAX_DELIVERIES.job_123.job-123-consumer",
+            durable="job-123-dlq",
+            stream=ADVISORY_STREAM_NAME,
+        )
+        mock_psub.fetch.assert_called_once_with(10, timeout=1.0)
+        mock_psub.unsubscribe.assert_called_once()
+
+    async def test_get_dead_letter_image_ids_no_messages(self):
+        """Test that a fetch timeout returns an empty list and still unsubscribes."""
+        nc, js = self._create_mock_nats_connection()
+
+        mock_psub = MagicMock()
+        mock_psub.fetch = AsyncMock(side_effect=nats.errors.TimeoutError)
+        mock_psub.unsubscribe = AsyncMock()
+        js.pull_subscribe = AsyncMock(return_value=mock_psub)
+
+        with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
+            async with TaskQueueManager() as manager:
+                result = await manager.get_dead_letter_image_ids(123)
+
+        self.assertEqual(result, [])
+        mock_psub.unsubscribe.assert_called_once()