Conversation
```python
result_poster: ResultPoster | None = None
prefetcher = CUDAPrefetcher(loader)  # if torch.cuda.is_available() else None
try:
    for i, batch in enumerate(loader):
```
TODO: The pre-fetcher currently only works with a GPU. Something like this is needed:

```python
batch_source = loader
if torch.cuda.is_available():
    prefetcher = CUDAPrefetcher(loader)
    prefetcher.preload()
    batch_source = prefetcher
```

And replace both `next(prefetcher)` calls with `next(batch_source)`.
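That fallback could be wrapped in a small helper. This is only a sketch, not committed code: `get_batch_source` is a hypothetical name, `CUDAPrefetcher` is the class added in this PR (assumed to take a loader and expose `preload()`), and `iter()` is used so that `next()` also works on a plain DataLoader.

```python
import torch

def get_batch_source(loader):
    # Use the prefetcher only when a GPU is present; otherwise fall back to
    # a plain iterator over the loader so next(batch_source) still works.
    if torch.cuda.is_available():
        prefetcher = CUDAPrefetcher(loader)  # project class, assumed interface
        prefetcher.preload()
        return prefetcher
    return iter(loader)
```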
trapdata/ml/models/classification.py
```diff
  [
      torchvision.transforms.Resize((self.input_size, self.input_size)),
-     torchvision.transforms.ToTensor(),
+     # torchvision.transforms.ToTensor(),
```
This needs to be put back for the AMI API use case. But I think a wrapper that conditionally converts (or not) could be used, e.g.:

```python
def maybe_totensor(x):
    if isinstance(x, torch.Tensor):
        return x
    return torchvision.transforms.ToTensor()(x)
```
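Sketched out as a self-contained function, with the torchvision import deferred so the tensor fast path carries no conversion cost. This is a sketch of the reviewer's suggestion, not committed code:

```python
import torch

def maybe_totensor(x):
    # Tensors (the Antenna worker path) pass through untouched.
    if isinstance(x, torch.Tensor):
        return x
    # PIL images (the AMI API path) still get converted; import lazily so
    # the tensor-only path does not need torchvision at all.
    import torchvision
    return torchvision.transforms.ToTensor()(x)
```

The wrapper can then replace the removed `ToTensor()` entry in the `Compose` list, serving both input types.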
```diff
  antenna_api_auth_token: str = ""
  antenna_service_name: str = "AMI Data Companion"
- antenna_api_batch_size: int = 16
+ antenna_api_batch_size: int = 24
```
The batching/collation now happens in the RESTDataset, so effectively the API batch size is used as the localization batch size. One of the two parameters can be removed.
Pull request overview
This PR aims to improve GPU utilization in the Antenna worker pipeline by reducing CPU↔GPU transfers and overlapping input transfer with inference.
Changes:
- Adjusts REST data loading to collate batches in DataLoader workers, enable pinned memory, and introduce a CUDA prefetcher.
- Modifies worker inference to keep tensors in GPU-friendly form (avoiding PIL conversions) and adds timing metrics.
- Tunes default batch sizes and adds a benchmark option to skip sending acknowledgments.
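The copy/compute overlap described above typically looks like the following. This is a generic sketch of the pattern, not the PR's exact `CUDAPrefetcher`; it assumes batches are dicts holding an `"images"` tensor loaded with `pin_memory=True`:

```python
import torch

class CUDAPrefetcher:
    """Overlap host-to-device copies with inference.

    While the current batch is being processed on the default stream, the
    next batch's copy to the GPU runs asynchronously on a side stream.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.next_batch = None

    def preload(self):
        # Fetch the next batch and kick off its async host-to-device copy.
        try:
            batch = next(self.loader)
        except StopIteration:
            self.next_batch = None
            return
        with torch.cuda.stream(self.stream):
            batch["images"] = batch["images"].cuda(non_blocking=True)
        self.next_batch = batch

    def __iter__(self):
        return self

    def __next__(self):
        if self.next_batch is None:
            raise StopIteration
        # Ensure the async copy finished before compute reads the tensor.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.next_batch
        self.preload()
        return batch
```

`non_blocking=True` only overlaps when the source tensor is in pinned memory, which is why the DataLoader change enables `pin_memory`.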
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| trapdata/settings.py | Updates default batch sizes for localization and Antenna API task fetching. |
| trapdata/ml/models/classification.py | Alters classification transforms by removing ToTensor() steps to support tensor-based inputs. |
| trapdata/antenna/worker.py | Switches worker loop to use CUDA prefetching, avoids PIL conversions, and changes logging/timing. |
| trapdata/antenna/datasets.py | Changes REST dataset iteration/collation behavior, enables pinned-memory DataLoader settings, and adds CUDAPrefetcher. |
| trapdata/antenna/benchmark.py | Adds a CLI flag to skip sending acknowledgments during benchmarking. |
Comments suppressed due to low confidence (1)
trapdata/antenna/datasets.py:371
`rest_collate_fn` now uses `torch.stack(...)`, which will throw if images have different spatial sizes (common for real-world inputs). The detector stack already supports receiving a `list[Tensor]` for variable-size images, so stacking here can introduce hard failures. Consider keeping `images` as a list (as before) or explicitly resizing/padding to a common size before stacking.
```python
# Collate successful items
if successful:
    result = {
        "images": torch.stack([item["image"] for item in successful]),
        "reply_subjects": [item["reply_subject"] for item in successful],
        "image_ids": [item["image_id"] for item in successful],
        "image_urls": [item.get("image_url") for item in successful],
    }
```
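The list-keeping alternative suggested above could look like this. The function name is hypothetical; the field names follow the snippet above:

```python
import torch

def rest_collate_images_as_list(successful):
    # Keep variable-size images as a list[Tensor] instead of stacking, so
    # the detector can accept mixed spatial sizes without a resize/pad step.
    return {
        "images": [item["image"] for item in successful],
        "reply_subjects": [item["reply_subject"] for item in successful],
        "image_ids": [item["image_id"] for item in successful],
        "image_urls": [item.get("image_url") for item in successful],
    }
```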
trapdata/antenna/worker.py
```python
    )
    return did_work
except StopIteration:
    pass
```
The current loop relies on `next(prefetcher)` inside the `while` body; when the prefetcher is exhausted it raises `StopIteration`, which is caught by the outer `except StopIteration: pass`. That path skips the `return did_work` at the end of the `try`, so `_process_job()` will return `None` instead of `bool` in the normal end-of-iteration case. Restructure the iteration to break cleanly on `StopIteration` and still reach the final `return did_work` (or return `did_work` from the `except`).
Suggested change:

```python
pass
# Iterator exhausted: return whether any work was done
return did_work
```
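Reduced to its control flow, the restructured loop could look like this. It is a sketch: `process_job` and `process_batch` are hypothetical stand-ins for `_process_job` and the worker's real per-batch handling:

```python
def process_job(batch_source, process_batch):
    did_work = False
    while True:
        try:
            batch = next(batch_source)
        except StopIteration:
            break  # exhausted: fall through to the final return
        process_batch(batch)
        did_work = True
    # Always reached, so the caller gets a bool in every case.
    return did_work
```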
Fixes
Test logs:
GPU Utilization:
See TODO comments in the PR