Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
b60eab0
merge
carlos-irreverentlabs Jan 16, 2026
644927f
Merge remote-tracking branch 'upstream/main'
carlosgjs Jan 22, 2026
218f7aa
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 3, 2026
90da389
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 10, 2026
8618d3c
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 13, 2026
bd1be5f
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 17, 2026
b102ae1
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 19, 2026
bc908aa
fix: PSv2 follow-up fixes from integration tests (#1135)
mihow Feb 21, 2026
4c3802a
PSv2: Improve task fetching & web worker concurrency configuration (#…
carlosgjs Feb 21, 2026
b717e80
fix: include pipeline_slug in MinimalJobSerializer (#1148)
mihow Feb 21, 2026
883c4f8
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 24, 2026
e26f3c6
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 24, 2026
4ef7a24
Merge branch 'RolnickLab:main' into main
mihow Feb 27, 2026
c389e90
Merge remote-tracking branch 'upstream/main'
carlosgjs Feb 27, 2026
33a6425
Merge branch 'main' of github.com:uw-ssec/antenna
carlosgjs Feb 27, 2026
bf80824
Merge remote-tracking branch 'upstream/main'
carlosgjs Mar 4, 2026
a2e68a0
WIP: Add support for NATS dead-letter-queue
carlosgjs Mar 4, 2026
db05526
Update tests
carlosgjs Mar 6, 2026
602f825
Add tests, cleanup naming and error handling
carlosgjs Mar 10, 2026
0102ee7
More CR feedback
carlosgjs Mar 10, 2026
b44f5b0
Use constant
carlosgjs Mar 10, 2026
5c3a47b
CR
carlosgjs Mar 10, 2026
e09fd9a
let exception propagate
carlosgjs Mar 10, 2026
e4564fb
Use async_to_sync
carlosgjs Mar 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions ami/ml/management/commands/check_dead_letter_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
Management command to check dead letter queue messages for a job.

Usage:
python manage.py check_dead_letter_queue <job_id>

Example:
python manage.py check_dead_letter_queue 123
"""

from asgiref.sync import async_to_sync
from django.core.management.base import BaseCommand, CommandError

from ami.ml.orchestration.nats_queue import TaskQueueManager


class Command(BaseCommand):
    """Report dead-letter-queue (max-delivery-exceeded) images for a job."""

    help = "Check dead letter queue messages for a job ID"

    def add_arguments(self, parser):
        """Register the required positional job_id argument."""
        parser.add_argument(
            "job_id",
            type=int,
            help="Job ID to check for dead letter queue messages",
        )

    def handle(self, *args, **options):
        """Fetch dead-letter image IDs for the job and print a human-readable report."""
        job_id = options["job_id"]

        # Keep the try block narrow: only the NATS lookup can meaningfully fail;
        # stdout reporting below should not be converted into a CommandError.
        try:
            dead_letter_ids = async_to_sync(self._check_dead_letter_queue)(job_id)
        except Exception as e:
            # Chain the original exception so the root cause is preserved
            # (visible with --traceback / higher verbosity).
            raise CommandError(f"Failed to check dead letter queue: {e}") from e

        if dead_letter_ids:
            self.stdout.write(
                self.style.WARNING(f"Found {len(dead_letter_ids)} dead letter image(s) for job {job_id}:")
            )
            for image_id in dead_letter_ids:
                self.stdout.write(f" - Image ID: {image_id}")
        else:
            self.stdout.write(self.style.SUCCESS(f"No dead letter images found for job {job_id}"))

    async def _check_dead_letter_queue(self, job_id: int) -> list[str]:
        """Check for dead letter queue messages using TaskQueueManager."""
        async with TaskQueueManager() as manager:
            return await manager.get_dead_letter_image_ids(job_id)
157 changes: 149 additions & 8 deletions ami/ml/orchestration/nats_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ async def get_connection(nats_url: str) -> tuple[nats.NATS, JetStreamContext]:


TASK_TTR = getattr(settings, "NATS_TASK_TTR", 30) # Visibility timeout in seconds (configurable)
ADVISORY_STREAM_NAME = "advisories" # Shared stream for max delivery advisories across all jobs


class TaskQueueManager:
Expand Down Expand Up @@ -72,6 +73,15 @@ def __init__(self, nats_url: str | None = None, max_ack_pending: int | None = No
async def __aenter__(self):
"""Create connection on enter."""
self.nc, self.js = await get_connection(self.nats_url)

try:
await self._setup_advisory_stream()
except BaseException:
if self.nc and not self.nc.is_closed:
await self.nc.close()
self.nc = None
self.js = None
raise
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand All @@ -95,7 +105,7 @@ def _get_consumer_name(self, job_id: int) -> str:
"""Get consumer name from job_id."""
return f"job-{job_id}-consumer"

async def _stream_exists(self, job_id: int) -> bool:
async def _job_stream_exists(self, job_id: int) -> bool:
"""Check if stream exists for the given job.

Only catches NotFoundError (→ False). TimeoutError propagates deliberately
Expand All @@ -106,6 +116,10 @@ async def _stream_exists(self, job_id: int) -> bool:
raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")

stream_name = self._get_stream_name(job_id)
return await self._stream_exists(stream_name)

async def _stream_exists(self, stream_name: str) -> bool:
"""Check if a stream with the given name exists."""
try:
await asyncio.wait_for(self.js.stream_info(stream_name), timeout=NATS_JETSTREAM_TIMEOUT)
return True
Expand All @@ -117,7 +131,7 @@ async def _ensure_stream(self, job_id: int):
if self.js is None:
raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")

if not await self._stream_exists(job_id):
if not await self._job_stream_exists(job_id):
stream_name = self._get_stream_name(job_id)
subject = self._get_subject(job_id)
logger.warning(f"Stream {stream_name} does not exist")
Expand Down Expand Up @@ -218,7 +232,7 @@ async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> li
raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")

try:
if not await self._stream_exists(job_id):
if not await self._job_stream_exists(job_id):
logger.debug(f"Stream for job '{job_id}' does not exist when reserving task")
return []

Expand All @@ -231,7 +245,7 @@ async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> li

try:
msgs = await psub.fetch(count, timeout=timeout)
except nats.errors.TimeoutError:
except (asyncio.TimeoutError, nats.errors.TimeoutError):
logger.debug(f"No tasks available in stream for job '{job_id}'")
return []
finally:
Expand All @@ -250,7 +264,7 @@ async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> li
logger.debug(f"No tasks reserved from stream for job '{job_id}'")
return tasks

except asyncio.TimeoutError:
except (asyncio.TimeoutError, nats.errors.TimeoutError):
raise # NATS unreachable — propagate so the view can return an appropriate error
except Exception as e:
logger.error(f"Failed to reserve tasks from stream for job '{job_id}': {e}")
Expand All @@ -271,6 +285,7 @@ async def acknowledge_task(self, reply_subject: str) -> bool:

try:
await self.nc.publish(reply_subject, b"+ACK")
await self.nc.flush()
logger.debug(f"Acknowledged task with reply subject {reply_subject}")
return True
except Exception as e:
Expand Down Expand Up @@ -330,9 +345,134 @@ async def delete_stream(self, job_id: int) -> bool:
logger.error(f"Failed to delete stream for job '{job_id}': {e}")
return False

async def _setup_advisory_stream(self):
    """Ensure the shared advisory stream exists to capture max-delivery events.

    Called on every __aenter__ so that advisories are captured from the moment
    any TaskQueueManager connection is opened, not just when the DLQ is first read.

    Raises:
        RuntimeError: If the connection has not been opened yet.
    """
    # Consistency with the other stream helpers in this class: fail fast with
    # the standard message instead of an AttributeError on a None js context.
    if self.js is None:
        raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")

    if not await self._stream_exists(ADVISORY_STREAM_NAME):
        await asyncio.wait_for(
            self.js.add_stream(
                name=ADVISORY_STREAM_NAME,
                # Capture all JetStream advisory events, including
                # CONSUMER.MAX_DELIVERIES, which the DLQ reader filters on.
                subjects=["$JS.EVENT.ADVISORY.>"],
                max_age=3600,  # Keep advisories for 1 hour
            ),
            timeout=NATS_JETSTREAM_TIMEOUT,
        )
        logger.info("Advisory stream created")

def _get_dlq_consumer_name(self, job_id: int) -> str:
"""Get the durable consumer name for dead letter queue advisory tracking."""
return f"job-{job_id}-dlq"

async def get_dead_letter_image_ids(self, job_id: int, n: int = 10) -> list[str]:
    """
    Get image IDs from dead letter queue (messages that exceeded max delivery attempts).

    Pulls from persistent advisory stream to find failed messages, then looks up image IDs.
    Uses a durable consumer so acknowledged advisories are not re-delivered on subsequent calls.

    Args:
        job_id: The job ID (integer primary key)
        n: Maximum number of image IDs to return (default: 10)

    Returns:
        List of image IDs that failed to process after max retry attempts
    """
    if self.nc is None or self.js is None:
        raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")

    stream_name = self._get_stream_name(job_id)
    consumer_name = self._get_consumer_name(job_id)
    dlq_consumer_name = self._get_dlq_consumer_name(job_id)
    dead_letter_ids: list[str] = []

    # Advisory subject emitted by JetStream when a message on this job's
    # stream/consumer exceeds its max delivery attempts.
    subject_filter = f"$JS.EVENT.ADVISORY.CONSUMER.MAX_DELIVERIES.{stream_name}.{consumer_name}"

    # Use a durable consumer so ACKs persist across calls — ephemeral consumers
    # are deleted on unsubscribe, discarding all ACK tracking and causing every
    # advisory to be re-delivered on the next call.
    psub = await self.js.pull_subscribe(subject_filter, durable=dlq_consumer_name, stream=ADVISORY_STREAM_NAME)

    try:
        msgs = await psub.fetch(n, timeout=1.0)

        for msg in msgs:
            advisory_data = json.loads(msg.data.decode())

            # Get the stream sequence of the failed message
            if "stream_seq" in advisory_data:
                stream_seq = advisory_data["stream_seq"]

                # Look up the actual message by sequence to get task ID
                try:
                    job_msg = await self.js.get_msg(stream_name, stream_seq)

                    if job_msg and job_msg.data:
                        task_data = json.loads(job_msg.data.decode())

                        if "image_id" in task_data:
                            dead_letter_ids.append(str(task_data["image_id"]))
                        else:
                            logger.warning(f"No image_id found in task data: {task_data}")
                except Exception as e:
                    logger.warning(f"Could not retrieve message {stream_seq} from {stream_name}: {e}")
                    # The message might have been discarded after max_deliver exceeded
            else:
                logger.warning(f"No stream_seq in advisory data: {advisory_data}")

            # Acknowledge even if we couldn't find the stream_seq or image_id so it doesn't get re-delivered
            # it shouldn't happen since stream_seq is part of the `io.nats.jetstream.advisory.v1.max_deliver`
            # schema and all our messages have an image_id
            await msg.ack()
            logger.info(
                f"Acknowledged advisory message for stream_seq {advisory_data.get('stream_seq', 'unknown')}"
            )

        # Flush to ensure all ACKs are written to the socket before unsubscribing.
        # msg.ack() only queues a publish in the client buffer; without flush() the
        # ACKs can be silently dropped when the subscription is torn down.
        await self.nc.flush()

    except (asyncio.TimeoutError, nats.errors.TimeoutError):
        # No advisories within the fetch timeout — nothing has dead-lettered.
        logger.info(f"No advisory messages found for job {job_id}")
    finally:
        # Always release the pull subscription; the durable consumer itself
        # survives unsubscribe, preserving ACK state for the next call.
        await psub.unsubscribe()

    # Defensive cap: fetch(n) should already bound the count, but slicing keeps
    # the contract explicit even if more IDs were accumulated.
    return dead_letter_ids[:n]

async def delete_dlq_consumer(self, job_id: int) -> bool:
    """
    Delete the durable DLQ advisory consumer for a job.

    Args:
        job_id: The job ID (integer primary key)

    Returns:
        bool: True if successful, False otherwise
    """
    if self.js is None:
        raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")

    dlq_consumer_name = self._get_dlq_consumer_name(job_id)
    try:
        deletion = self.js.delete_consumer(ADVISORY_STREAM_NAME, dlq_consumer_name)
        await asyncio.wait_for(deletion, timeout=NATS_JETSTREAM_TIMEOUT)
    except nats.js.errors.NotFoundError:
        # Already gone — treat as success so cleanup is idempotent.
        logger.debug(f"DLQ consumer {dlq_consumer_name} for job '{job_id}' not found when attempting to delete")
        return True
    except Exception as e:
        logger.warning(f"Failed to delete DLQ consumer for job '{job_id}': {e}")
        return False
    logger.info(f"Deleted DLQ consumer {dlq_consumer_name} for job '{job_id}'")
    return True

async def cleanup_job_resources(self, job_id: int) -> bool:
"""
Clean up all NATS resources (consumer and stream) for a job.
Clean up all NATS resources (consumer, stream, and DLQ advisory consumer) for a job.

This should be called when a job completes or is cancelled.

Expand All @@ -342,8 +482,9 @@ async def cleanup_job_resources(self, job_id: int) -> bool:
Returns:
bool: True if successful, False otherwise
"""
# Delete consumer first, then stream
# Delete consumer first, then stream, then the durable DLQ advisory consumer
consumer_deleted = await self.delete_consumer(job_id)
stream_deleted = await self.delete_stream(job_id)
dlq_consumer_deleted = await self.delete_dlq_consumer(job_id)

return consumer_deleted and stream_deleted
return consumer_deleted and stream_deleted and dlq_consumer_deleted
64 changes: 61 additions & 3 deletions ami/ml/orchestration/tests/test_nats_queue.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Unit tests for TaskQueueManager."""

import json
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

import nats
import nats.errors

from ami.ml.orchestration.nats_queue import TaskQueueManager
from ami.ml.orchestration.nats_queue import ADVISORY_STREAM_NAME, TaskQueueManager
from ami.ml.schemas import PipelineProcessingTask


Expand All @@ -25,6 +27,7 @@ def _create_mock_nats_connection(self):
nc = MagicMock()
nc.is_closed = False
nc.close = AsyncMock()
nc.flush = AsyncMock()

js = MagicMock()
js.stream_info = AsyncMock()
Expand Down Expand Up @@ -60,7 +63,8 @@ async def test_publish_task_creates_stream_and_consumer(self):
async with TaskQueueManager() as manager:
await manager.publish_task(456, sample_task)

js.add_stream.assert_called_once()
# add_stream called twice: advisory stream in __aenter__ + job stream in _ensure_stream
self.assertEqual(js.add_stream.call_count, 2)
self.assertIn("job_456", str(js.add_stream.call_args))
js.add_consumer.assert_called_once()

Expand Down Expand Up @@ -153,7 +157,8 @@ async def test_cleanup_job_resources(self):
result = await manager.cleanup_job_resources(123)

self.assertTrue(result)
js.delete_consumer.assert_called_once()
# delete_consumer called twice: job consumer + DLQ advisory consumer
self.assertEqual(js.delete_consumer.call_count, 2)
js.delete_stream.assert_called_once()

async def test_naming_conventions(self):
Expand All @@ -177,3 +182,56 @@ async def test_operations_without_connection_raise_error(self):

with self.assertRaisesRegex(RuntimeError, "Connection is not open"):
await manager.delete_stream(123)

async def test_get_dead_letter_image_ids_returns_image_ids(self):
    """Advisory messages should be resolved to the image IDs of the failed tasks."""
    nc, js = self._create_mock_nats_connection()
    js.get_msg = AsyncMock()

    def advisory_msg(seq):
        # A MAX_DELIVERIES advisory carrying the failed message's stream sequence.
        advisory = MagicMock()
        advisory.data = json.dumps({"stream_seq": seq}).encode()
        advisory.ack = AsyncMock()
        return advisory

    def stored_task(image_id):
        # The original job-stream message the advisory points at.
        stored = MagicMock()
        stored.data = json.dumps({"image_id": image_id}).encode()
        return stored

    js.get_msg.side_effect = [stored_task("img-1"), stored_task("img-2")]

    psub = MagicMock()
    psub.fetch = AsyncMock(return_value=[advisory_msg(1), advisory_msg(2)])
    psub.unsubscribe = AsyncMock()
    js.pull_subscribe = AsyncMock(return_value=psub)

    connect = AsyncMock(return_value=(nc, js))
    with patch("ami.ml.orchestration.nats_queue.get_connection", connect):
        async with TaskQueueManager() as manager:
            image_ids = await manager.get_dead_letter_image_ids(123, n=10)

    self.assertEqual(image_ids, ["img-1", "img-2"])
    js.pull_subscribe.assert_called_once_with(
        "$JS.EVENT.ADVISORY.CONSUMER.MAX_DELIVERIES.job_123.job-123-consumer",
        durable="job-123-dlq",
        stream=ADVISORY_STREAM_NAME,
    )
    psub.fetch.assert_called_once_with(10, timeout=1.0)
    psub.unsubscribe.assert_called_once()

async def test_get_dead_letter_image_ids_no_messages(self):
    """A fetch timeout yields an empty list and the subscription is still closed."""
    nc, js = self._create_mock_nats_connection()

    psub = MagicMock()
    psub.fetch = AsyncMock(side_effect=nats.errors.TimeoutError)
    psub.unsubscribe = AsyncMock()
    js.pull_subscribe = AsyncMock(return_value=psub)

    connect = AsyncMock(return_value=(nc, js))
    with patch("ami.ml.orchestration.nats_queue.get_connection", connect):
        async with TaskQueueManager() as manager:
            image_ids = await manager.get_dead_letter_image_ids(123)

    self.assertEqual(image_ids, [])
    psub.unsubscribe.assert_called_once()