Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,30 @@ def get_sharding_plan(
def apply_sparse_optimizer(
parameters: Iterable[nn.Parameter],
optimizer_cls: Optional[Type[Optimizer]] = None,
optimizer_kwargs: Dict[str, Any] = dict(),
optimizer_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
Apply a sparse optimizer to the sparse/EBC parts of a model.
"""Apply a sparse optimizer to the sparse/EBC parts of a model.

This optimizer is fused, so it will be applied directly in the backward pass.

This should only be used for sparse parameters.

Args:
parameters (Iterable[nn.Parameter]): The sparse parameters to apply the optimizer to.
optimizer_cls (Type[Optimizer], optional): The optimizer class to use. Defaults to RowWiseAdagrad.
optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer.
optimizer_cls (Type[Optimizer], optional): The optimizer class to use.
Defaults to ``RowWiseAdagrad`` when ``None`` is passed.
optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments
for the optimizer. Defaults to ``{"lr": 0.01}`` when both
``optimizer_cls`` and ``optimizer_kwargs`` are unset.
"""

if not optimizer_cls and optimizer_kwargs:
if optimizer_cls is None:
optimizer_cls = RowWiseAdagrad
optimizer_kwargs = {"lr": 0.01}
apply_optimizer_in_backward(optimizer_cls, parameters, optimizer_kwargs) # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface.
if not optimizer_kwargs:
optimizer_kwargs = {"lr": 0.01}
if optimizer_kwargs is None:
optimizer_kwargs = {}
apply_optimizer_in_backward(optimizer_cls, parameters, optimizer_kwargs)


def apply_dense_optimizer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,14 @@ def train(
graph_backend=self._graph_backend,
device=device,
)
assert data_loaders.train_main is not None, (
"train_main dataloader required for training"
)
best_val_acc = 0.0
for epoch in range(self.__num_epochs):
logger.info(f"Batch training... for epoch {epoch}/{self.__num_epochs}")
train_loss = self._train(
data_loader=data_loaders.train_main, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface.
data_loader=data_loaders.train_main,
device=device,
)
train_loss_str = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
class TorchProfiler:
def __init__(self, **kwargs) -> None:
self.trace_handler = tensorboard_trace_handler(
dir_name=TMP_PROFILER_LOG_DIR_NAME, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] TODO(ty-torch-api-surface): fix ty false positives around the torch API surface.
dir_name=TMP_PROFILER_LOG_DIR_NAME.uri,
use_gzip=True,
)
self.wait = int(kwargs.get("wait", 5))
Expand Down
22 changes: 15 additions & 7 deletions gigl/src/common/models/layers/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,27 @@ def forward(


class SoftmaxLoss(nn.Module):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side node, this isnt really Softmax loss but we have been calling it that.
cc @yliu2-sc - isnt this just a flavor of InfoNCE ?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not even used in any runs or configs

"""
A loss layer built on top of the PyTorch implementation of the softmax cross entropy loss.
"""A loss layer built on top of the PyTorch implementation of the softmax cross entropy loss.

The loss function by default calculates the loss by
cross_entropy(all_scores / softmax_temperature, ys, reduction='sum').

The loss function by default calculate the loss by
cross_entropy(all_scores, ys, reduction='sum')
Dividing the scores by ``softmax_temperature`` controls the sharpness of the
softmax distribution. A temperature of ``1.0`` is a no-op and corresponds to
plain cross-entropy.

See: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for more information.

Args:
softmax_temperature (float): Scaling factor applied via ``scores /
softmax_temperature`` before computing cross-entropy. Defaults to
``1.0`` (no scaling). Must be non-zero; the caller is responsible for
supplying a finite, non-zero value.
"""

def __init__(
self,
softmax_temperature: Optional[float] = None,
softmax_temperature: float = 1.0,
):
super(SoftmaxLoss, self).__init__()
self.softmax_temperature = softmax_temperature
Expand Down Expand Up @@ -142,8 +151,7 @@ def _calculate_softmax_loss(
) # shape=[num_pos_nodes]

loss = F.cross_entropy(
input=all_scores
/ self.softmax_temperature, # https://github.com/Snapchat/GiGL/issues/408 # ty: ignore[unsupported-operator] TODO(ty-torch-union-inference): fix ty Tensor/Module union inference regressions.
input=all_scores / self.softmax_temperature,
target=ys,
reduction="sum",
)
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/experimental/torchrec_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Regression tests for ``apply_sparse_optimizer``."""

from unittest.mock import patch

import torch
import torch.nn as nn

from gigl.experimental.knowledge_graph_embedding.common.torchrec import utils
from tests.test_assets.test_case import TestCase


class ApplySparseOptimizerTest(TestCase):
def test_default_optimizer_uses_rowwise_adagrad(self) -> None:
"""When called with neither ``optimizer_cls`` nor ``optimizer_kwargs``,
``apply_sparse_optimizer`` falls back to ``RowWiseAdagrad`` with
``lr=0.01`` and forwards them to ``apply_optimizer_in_backward``.
"""
parameters = [nn.Parameter(torch.zeros(1))]
with patch.object(utils, "apply_optimizer_in_backward") as mock_apply:
utils.apply_sparse_optimizer(parameters=parameters)
forwarded_cls, _, forwarded_kwargs = mock_apply.call_args[0]
self.assertIs(forwarded_cls, utils.RowWiseAdagrad)
self.assertEqual(forwarded_kwargs, {"lr": 0.01})
9 changes: 9 additions & 0 deletions tests/unit/src/common/models/layers/loss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn.functional as F

from gigl.nn.loss import RetrievalLoss
from gigl.src.common.models.layers.loss import SoftmaxLoss
from tests.test_assets.test_case import TestCase


Expand Down Expand Up @@ -176,3 +177,11 @@ def test_empty_loss(self):
scores=empty_scores, query_ids=query_ids, candidate_ids=candidate_ids
)
self.assert_tensor_equality(loss, expected_loss)


class SoftmaxLossDefaultTemperatureTest(TestCase):
def test_default_temperature_is_one(self) -> None:
"""SoftmaxLoss defaults softmax_temperature to 1.0 so the forward-pass
division scores / softmax_temperature is always well-defined.
"""
self.assertEqual(SoftmaxLoss().softmax_temperature, 1.0)