feat: Compute per-query average for 2D retrieval_normalized_dcg #3229

rintaro121 wants to merge 11 commits into Lightning-AI:master

Conversation
Hi, @Borda

Although this PR focuses on fixing `retrieval_normalized_dcg`, the same 2D-input issue affects other retrieval metrics as well. For example, with Average Precision:

```python
from torchmetrics.functional.retrieval import retrieval_average_precision
import torch

# Query 1
p1 = retrieval_average_precision(torch.tensor([0.1, 0.2, 0.3]), torch.tensor([0, 1, 0]))
print(p1)  # tensor(0.5000)

# Query 2
p2 = retrieval_average_precision(torch.tensor([0.8, 0.1, 0.05]), torch.tensor([1, 0, 0]))
print(p2)  # tensor(1.0000)

print("Mean per-query Precision:", (p1 + p2) / 2)
# tensor(0.7500)

# Batched input (2D)
p_batch = retrieval_average_precision(
    torch.tensor([[0.1, 0.2, 0.3], [0.8, 0.1, 0.05]]),
    torch.tensor([[0, 1, 0], [1, 0, 0]]),
)
print("Batch Precision:", p_batch)
# tensor(0.8333)  <-- not the mean per-query value (same as the NDCG example)
```

Here, the batched input is incorrectly treated as one large query with `num_queries * num_documents` documents, which inflates the reported metric compared to the mean per-query value. To keep the scope manageable, my current plan is to limit this PR to NDCG only, adding tests that demonstrate the correct per-query behavior, so it is easier to review.
Do you think this level of granularity makes sense, or would you prefer the broader retrieval-wide fix to be included directly in this PR?
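As a sanity check on the numbers in the comment above, the 0.8333 batch figure can be re-derived with a plain-Python Average Precision helper (illustrative only — this is not the torchmetrics implementation):

```python
def average_precision(scores, labels):
    """Mean of precision@k over the ranks of relevant documents."""
    # rank documents by score, highest first (stable sort; ties keep input order)
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits, total = 0, 0.0
    for rank, i in enumerate(order, start=1):
        if labels[i]:
            hits += 1
            total += hits / rank  # precision at this relevant document
    return total / max(hits, 1)

ap1 = average_precision([0.1, 0.2, 0.3], [0, 1, 0])    # 0.5
ap2 = average_precision([0.8, 0.1, 0.05], [1, 0, 0])   # 1.0
# flattening both queries into one ranking inflates the score:
# hits land at ranks 1 and 3, giving (1/1 + 2/3) / 2 = 5/6
ap_flat = average_precision([0.1, 0.2, 0.3, 0.8, 0.1, 0.05],
                            [0, 1, 0, 1, 0, 0])        # 5/6 ≈ 0.8333
```

This reproduces the 0.5, 1.0, and 0.8333 values from the torch example and shows concretely why the flattened score exceeds the per-query mean of 0.75.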
Codecov Report

❌ Patch coverage is … Additional details and impacted files:

```
@@           Coverage Diff           @@
##           master   #3229   +/-   ##
=======================================
- Coverage      69%     68%    -0%
=======================================
  Files         364     349    -15
  Lines       20096   19923   -173
=======================================
- Hits        13790   13605   -185
- Misses       6306    6318    +12
```
Pull Request Overview
This PR fixes the NDCG computation for 2D tensor inputs to correctly compute per-query averages instead of treating all queries as a single ranking problem. Previously, 2D inputs were flattened and computed as one large ranking, but now each query is processed separately and averaged.
- Updated `retrieval_normalized_dcg` to handle 2D tensors as batches of samples with per-query computation
- Added `empty_target_action` parameter to handle queries with no positive labels
- Enhanced test infrastructure to support 2D tensor testing and the new parameter
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| src/torchmetrics/functional/retrieval/ndcg.py | Core logic changes for 2D tensor handling and empty target handling |
| tests/unittests/retrieval/_inputs.py | Added 2D test input data structure |
| tests/unittests/retrieval/helpers.py | Updated test helpers to support 2D inputs and new metric parameter |
| tests/unittests/retrieval/test_ndcg.py | Added test coverage for the new empty_target_action parameter |
```python
original_shape = preds.shape
preds, target = _check_retrieval_functional_inputs(preds, target, allow_non_binary_target=True)

top_k = preds.shape[-1] if top_k is None else top_k

# reshape back if input was 2D
if len(original_shape) == 2:
    preds = preds.view(original_shape)
    target = target.view(original_shape)
else:
    preds = preds.unsqueeze(0)
```
Line 143 should use `target.unsqueeze(0)` instead of `target.view(original_shape)`. The else branch handles 1D inputs, which need to be unsqueezed to match the `preds` tensor on line 142.
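The intended branching can be mirrored torch-free for clarity. Here `to_query_batch` is a hypothetical name and nested lists stand in for tensors; the real code would call `unsqueeze(0)` on both tensors, which is exactly the point of the review comment:

```python
def to_query_batch(preds, target):
    # Hypothetical, torch-free analogue of the reshape logic under review:
    # a 2D input is kept as [num_queries][num_docs]; a 1D input gains a
    # leading query dimension (the list equivalent of unsqueeze(0)),
    # applied to *both* preds and target.
    if preds and isinstance(preds[0], (list, tuple)):
        return [list(q) for q in preds], [list(q) for q in target]
    return [list(preds)], [list(target)]
```

Keeping the two inputs on the same branch makes it impossible to unsqueeze `preds` while accidentally reshaping `target`.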
What does this PR do?
Fixes #3216
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Motivation
Currently, when passing a 2D tensor (`[num_queries, num_documents]`) to `retrieval_normalized_dcg`, the function flattens both `preds` and `target` and computes DCG/IDCG on the concatenated list. This treats all queries as a single large ranking problem.
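The size of the discrepancy is easy to see with a torch-free sketch using the standard DCG/IDCG definition (the `ndcg` helper below is illustrative, not the torchmetrics implementation), on the same two queries as the Average Precision example above:

```python
import math

def ndcg(scores, rels):
    """NDCG = DCG of the predicted ranking / DCG of the ideal ranking."""
    # rank documents by predicted score, highest first
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    dcg = sum(rels[i] / math.log2(rank + 1) for rank, i in enumerate(order, start=1))
    ideal = sorted(rels, reverse=True)
    idcg = sum(r / math.log2(rank + 1) for rank, r in enumerate(ideal, start=1))
    return dcg / idcg if idcg > 0 else 0.0

# per-query: each row scored separately, then averaged
n1 = ndcg([0.1, 0.2, 0.3], [0, 1, 0])   # ≈ 0.6309 (relevant doc at rank 2)
n2 = ndcg([0.8, 0.1, 0.05], [1, 0, 0])  # 1.0 (perfect ranking)
mean_per_query = (n1 + n2) / 2          # ≈ 0.8155

# flattened: both rows concatenated into one big ranking, as the
# current 2D behavior effectively does
flat = ndcg([0.1, 0.2, 0.3, 0.8, 0.1, 0.05], [0, 1, 0, 1, 0, 0])  # ≈ 0.9197
```

The flattened value (≈0.92) overstates the mean per-query value (≈0.82), mirroring the Average Precision example in the conversation.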
Changes
Fixes to
retrieval_normalized_dcg:Test additions and modifications:
ndcg_scoreto align conditions with our implementation.For the reviewer
Feedback on the naming and default behavior of the new `empty_target_action` argument would be greatly appreciated. Please also let us know if any test coverage is missing.

Did you have fun?
Make sure you had fun coding 🙃
📚 Documentation preview 📚: https://torchmetrics--3229.org.readthedocs.build/en/3229/