diff --git a/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py b/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py index f8726dcb1..9d3b54ffd 100644 --- a/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py +++ b/gigl/experimental/knowledge_graph_embedding/common/torchrec/utils.py @@ -120,24 +120,30 @@ def get_sharding_plan( def apply_sparse_optimizer( parameters: Iterable[nn.Parameter], optimizer_cls: Optional[Type[Optimizer]] = None, - optimizer_kwargs: Dict[str, Any] = dict(), + optimizer_kwargs: Optional[Dict[str, Any]] = None, ) -> None: - """ - Apply a sparse optimizer to the sparse/EBC parts of a model. + """Apply a sparse optimizer to the sparse/EBC parts of a model. + This optimizer is fused, so it will be applied directly in the backward pass. This should only be used for sparse parameters. Args: parameters (Iterable[nn.Parameter]): The sparse parameters to apply the optimizer to. - optimizer_cls (Type[Optimizer], optional): The optimizer class to use. Defaults to RowWiseAdagrad. - optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer. + optimizer_cls (Type[Optimizer], optional): The optimizer class to use. + Defaults to ``RowWiseAdagrad`` when ``None`` is passed. + optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments + for the optimizer. Defaults to ``{"lr": 0.01}`` when both + ``optimizer_cls`` and ``optimizer_kwargs`` are unset. """ - if not optimizer_cls and optimizer_kwargs: + if optimizer_cls is None: optimizer_cls = RowWiseAdagrad - optimizer_kwargs = {"lr": 0.01} - apply_optimizer_in_backward(optimizer_cls, parameters, optimizer_kwargs) # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. + if not optimizer_kwargs: + optimizer_kwargs = {"lr": 0.01} + if optimizer_kwargs is None: + optimizer_kwargs = {} + apply_optimizer_in_backward(optimizer_cls, parameters, optimizer_kwargs) def apply_dense_optimizer( diff --git a/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py b/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py index bfaba1fb0..feed188b4 100644 --- a/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py +++ b/gigl/src/common/modeling_task_specs/node_classification_modeling_task_spec.py @@ -248,11 +248,14 @@ def train( graph_backend=self._graph_backend, device=device, ) + assert data_loaders.train_main is not None, ( + "train_main dataloader required for training" + ) best_val_acc = 0.0 for epoch in range(self.__num_epochs): logger.info(f"Batch training... for epoch {epoch}/{self.__num_epochs}") train_loss = self._train( - data_loader=data_loaders.train_main, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. + data_loader=data_loaders.train_main, device=device, ) train_loss_str = ( diff --git a/gigl/src/common/modeling_task_specs/utils/profiler_wrapper.py b/gigl/src/common/modeling_task_specs/utils/profiler_wrapper.py index 9e634fce4..2a7c423a5 100644 --- a/gigl/src/common/modeling_task_specs/utils/profiler_wrapper.py +++ b/gigl/src/common/modeling_task_specs/utils/profiler_wrapper.py @@ -20,7 +20,7 @@ class TorchProfiler: def __init__(self, **kwargs) -> None: self.trace_handler = tensorboard_trace_handler( - dir_name=TMP_PROFILER_LOG_DIR_NAME, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface. + dir_name=TMP_PROFILER_LOG_DIR_NAME.uri, use_gzip=True, ) self.wait = int(kwargs.get("wait", 5)) diff --git a/gigl/src/common/models/layers/loss.py b/gigl/src/common/models/layers/loss.py index 08ef93c7a..9ca934126 100644 --- a/gigl/src/common/models/layers/loss.py +++ b/gigl/src/common/models/layers/loss.py @@ -100,18 +100,27 @@ def forward( class SoftmaxLoss(nn.Module): - """ - A loss layer built on top of the PyTorch implementation of the softmax cross entropy loss. + """A loss layer built on top of the PyTorch implementation of the softmax cross entropy loss. + + The loss function by default calculates the loss by + cross_entropy(all_scores / softmax_temperature, ys, reduction='sum'). - The loss function by default calculate the loss by - cross_entropy(all_scores, ys, reduction='sum') + Dividing the scores by ``softmax_temperature`` controls the sharpness of the + softmax distribution. A temperature of ``1.0`` is a no-op and corresponds to + plain cross-entropy. See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for more information. + + Args: + softmax_temperature (float): Scaling factor applied via ``scores / + softmax_temperature`` before computing cross-entropy. Defaults to + ``1.0`` (no scaling). Must be non-zero; the caller is responsible for + supplying a finite, non-zero value. """ def __init__( self, - softmax_temperature: Optional[float] = None, + softmax_temperature: float = 1.0, ): super(SoftmaxLoss, self).__init__() self.softmax_temperature = softmax_temperature @@ -142,8 +151,7 @@ def _calculate_softmax_loss( ) # shape=[num_pos_nodes] loss = F.cross_entropy( - input=all_scores - / self.softmax_temperature, # https://github.com/Snapchat/GiGL/issues/408 # ty: ignore[unsupported-operator] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions. + input=all_scores / self.softmax_temperature, target=ys, reduction="sum", ) diff --git a/tests/unit/experimental/torchrec_utils_test.py b/tests/unit/experimental/torchrec_utils_test.py new file mode 100644 index 000000000..3a8f07061 --- /dev/null +++ b/tests/unit/experimental/torchrec_utils_test.py @@ -0,0 +1,23 @@ +"""Regression tests for ``apply_sparse_optimizer``.""" + +from unittest.mock import patch + +import torch +import torch.nn as nn + +from gigl.experimental.knowledge_graph_embedding.common.torchrec import utils +from tests.test_assets.test_case import TestCase + + +class ApplySparseOptimizerTest(TestCase): + def test_default_optimizer_uses_rowwise_adagrad(self) -> None: + """When called with neither ``optimizer_cls`` nor ``optimizer_kwargs``, + ``apply_sparse_optimizer`` falls back to ``RowWiseAdagrad`` with + ``lr=0.01`` and forwards them to ``apply_optimizer_in_backward``. + """ + parameters = [nn.Parameter(torch.zeros(1))] + with patch.object(utils, "apply_optimizer_in_backward") as mock_apply: + utils.apply_sparse_optimizer(parameters=parameters) + forwarded_cls, _, forwarded_kwargs = mock_apply.call_args[0] + self.assertIs(forwarded_cls, utils.RowWiseAdagrad) + self.assertEqual(forwarded_kwargs, {"lr": 0.01}) diff --git a/tests/unit/src/common/models/layers/loss_test.py b/tests/unit/src/common/models/layers/loss_test.py index 5c0f3f073..4be356ca3 100644 --- a/tests/unit/src/common/models/layers/loss_test.py +++ b/tests/unit/src/common/models/layers/loss_test.py @@ -2,6 +2,7 @@ import torch.nn.functional as F from gigl.nn.loss import RetrievalLoss +from gigl.src.common.models.layers.loss import SoftmaxLoss from tests.test_assets.test_case import TestCase @@ -176,3 +177,11 @@ def test_empty_loss(self): scores=empty_scores, query_ids=query_ids, candidate_ids=candidate_ids ) self.assert_tensor_equality(loss, expected_loss) + + +class SoftmaxLossDefaultTemperatureTest(TestCase): + def test_default_temperature_is_one(self) -> None: + """SoftmaxLoss defaults softmax_temperature to 1.0 so the forward-pass + division scores / softmax_temperature is always well-defined. + """ + self.assertEqual(SoftmaxLoss().softmax_temperature, 1.0)