diff --git a/graphistry/client_session.py b/graphistry/client_session.py index e1a3d9228a..82331eb9fa 100644 --- a/graphistry/client_session.py +++ b/graphistry/client_session.py @@ -71,6 +71,8 @@ def __init__(self) -> None: self.idp_name: Optional[str] = None self.sso_state: Optional[str] = None + self.sso_state_created_at: Optional[float] = None + self.sso_state_ttl_s: int = get_from_env("GRAPHISTRY_SSO_STATE_TTL_S", int, 300) self.personal_key: Optional[str] = None self.personal_key_id: Optional[str] = None diff --git a/graphistry/exceptions.py b/graphistry/exceptions.py index 6b059fad85..a5fecd3740 100644 --- a/graphistry/exceptions.py +++ b/graphistry/exceptions.py @@ -17,6 +17,13 @@ class SsoStateInvalidException(SsoException): """ pass +class SsoStateExpiredException(SsoException): + """ + Raised when the SSO state has exceeded the client-side TTL, + meaning the server's PKCE verifier has likely expired. + """ + pass + class TokenExpireException(Exception): diff --git a/graphistry/pygraphistry.py b/graphistry/pygraphistry.py index fe460e74c3..bf63900cf0 100644 --- a/graphistry/pygraphistry.py +++ b/graphistry/pygraphistry.py @@ -19,7 +19,7 @@ from . import bolt_util from .plotter import Plotter from .util import in_databricks, setup_logger, in_ipython -from .exceptions import SsoRetrieveTokenTimeoutException, TokenExpireException, SsoStateInvalidException +from .exceptions import SsoRetrieveTokenTimeoutException, TokenExpireException, SsoStateInvalidException, SsoStateExpiredException from .messages import ( MSG_REGISTER_MISSING_PASSWORD, @@ -315,6 +315,7 @@ def _handle_auth_url(self, auth_url: str, sso_timeout: Optional[int], sso_opt_in self.session.org_name = org_name # finish, set back to None self.session.sso_state = None + self.session.sso_state_created_at = None print("Successfully logged in") self._maybe_switch_org(org_name) return self.api_token() @@ -334,6 +335,8 @@ def _handle_auth_url(self, auth_url: str, sso_timeout: Optional[int], sso_opt_in # print("Keep trying to get token ...") # time.sleep(5) + ttl_s = self.session.sso_state_ttl_s + print(f"SSO link expires in {ttl_s // 60} minutes. If you wait longer, re-run graphistry.register() for a fresh link.") print("Please run graphistry.sso_get_token() to complete the authentication") return None @@ -352,6 +355,16 @@ def _sso_get_token(self) -> Tuple[Optional[str], Optional[str]]: if state is None: raise SsoStateInvalidException("[SSO] Invalid SSO state: NoneType encountered") + created_at = self.session.sso_state_created_at + ttl = self.session.sso_state_ttl_s + if created_at is not None and (time.time() - created_at) > ttl: + self.session.sso_state = None + self.session.sso_state_created_at = None + raise SsoStateExpiredException( + f"[SSO] SSO link expired (older than {ttl}s). " + f"Run graphistry.register(..., is_sso_login=True) to get a new link." + ) + # print("_sso_get_token : {}".format(state)) arrow_uploader = ArrowUploader( client_session=self.session, @@ -2494,6 +2507,7 @@ def sso_state(self, value: Optional[str] = None): # setter self.session.sso_state = value.strip() + self.session.sso_state_created_at = time.time() def scene_settings(self, menu: Optional[bool] = None, diff --git a/graphistry/tests/test_arrow_uploader.py b/graphistry/tests/test_arrow_uploader.py index 9c8187bea6..caf1c57dea 100644 --- a/graphistry/tests/test_arrow_uploader.py +++ b/graphistry/tests/test_arrow_uploader.py @@ -489,3 +489,64 @@ def test_sso_get_token_missing_org_raises(self, mock_get): with pytest.raises(Exception): au.sso_get_token(state='ignored-valid') + + +class TestSsoStateTtl(unittest.TestCase): + + def setUp(self): + self.client = graphistry.PyGraphistry + self.client.session.sso_state = None + self.client.session.sso_state_created_at = None + + def tearDown(self): + self.client.session.sso_state = None + self.client.session.sso_state_created_at = None + + def test_sso_state_created_at_set_on_state_assignment(self): + import time + before = time.time() + self.client.sso_state('test-state-123') + after = time.time() + assert self.client.session.sso_state == 'test-state-123' + assert self.client.session.sso_state_created_at is not None + assert before <= self.client.session.sso_state_created_at <= after + + def test_sso_state_ttl_default(self): + assert self.client.session.sso_state_ttl_s == 300 + + def test_expired_state_raises(self): + import time + from graphistry.exceptions import SsoStateExpiredException + self.client.sso_state('expired-state') + # Backdate created_at so it appears expired + self.client.session.sso_state_created_at = time.time() - 400 + with pytest.raises(SsoStateExpiredException, match="expired"): + self.client._sso_get_token() + # State should be cleared + assert self.client.session.sso_state is None + assert self.client.session.sso_state_created_at is None + + @mock.patch('requests.get') + def test_fresh_state_not_expired(self, mock_get): + import time + mock_resp = mock.Mock() + mock_resp.status_code = 200 + mock_resp.raise_for_status = mock.Mock() + mock_resp.json.return_value = { + 'status': 'OK', + 'data': { + 'token': 'tok123', + 'active_organization': { + 'slug': 'test-org', + 'is_found': True, + 'is_member': True, + } + } + } + mock_get.return_value = mock_resp + + self.client.sso_state('fresh-state') + # Just created, should not be expired + with mock.patch.object(type(self.client), '_maybe_switch_org'): + token, org_name = self.client._sso_get_token() + assert token is not None