diff --git a/jupyter_server/auth/identity.py b/jupyter_server/auth/identity.py index be238b63e5..fc4b029922 100644 --- a/jupyter_server/auth/identity.py +++ b/jupyter_server/auth/identity.py @@ -20,7 +20,7 @@ from http.cookies import Morsel from tornado import escape, httputil, web -from traitlets import Bool, Dict, Type, Unicode, default +from traitlets import Bool, Dict, Enum, List, TraitError, Type, Unicode, default, validate from traitlets.config import LoggingConfigurable from jupyter_server.transutils import _i18n @@ -31,6 +31,10 @@ _non_alphanum = re.compile(r"[^A-Za-z0-9]") +# Define the User properties that can be updated +UpdatableField = t.Literal["name", "display_name", "initials", "avatar_url", "color"] + + @dataclass class User: """Object representing a User @@ -188,6 +192,14 @@ class IdentityProvider(LoggingConfigurable): help=_i18n("The logout handler class to use."), ) + # Define the fields that can be updated + updatable_fields = List( + trait=Enum(list(t.get_args(UpdatableField))), + default_value=["color"], # Default updatable field + config=True, + help=_i18n("List of fields in the User model that can be updated."), + ) + token_generated = False @default("token") @@ -207,6 +219,18 @@ def _token_default(self): self.token_generated = True return binascii.hexlify(os.urandom(24)).decode("ascii") + @validate("updatable_fields") + def _validate_updatable_fields(self, proposal): + """Validate that all fields in updatable_fields are valid.""" + valid_updatable_fields = list(t.get_args(UpdatableField)) + invalid_fields = [ + field for field in proposal["value"] if field not in valid_updatable_fields + ] + if invalid_fields: + msg = f"Invalid fields in updatable_fields: {invalid_fields}" + raise TraitError(msg) + return proposal["value"] + need_token: bool | Bool[bool, t.Union[bool, int]] = Bool(True) def get_user(self, handler: web.RequestHandler) -> User | None | t.Awaitable[User | None]: @@ -269,6 +293,31 @@ async def _get_user(self, handler: web.RequestHandler) -> User | None: return user + def update_user( + self, handler: web.RequestHandler, user_data: dict[UpdatableField, str] + ) -> User: + """Update user information and persist the user model.""" + self.check_update(user_data) + current_user = t.cast(User, handler.current_user) + updated_user = self.update_user_model(current_user, user_data) + self.persist_user_model(handler) + return updated_user + + def check_update(self, user_data: dict[UpdatableField, str]) -> None: + """Raises if some fields to update are not updatable.""" + for field in user_data: + if field not in self.updatable_fields: + msg = f"Field {field} is not updatable" + raise ValueError(msg) + + def update_user_model(self, current_user: User, user_data: dict[UpdatableField, str]) -> User: + """Update user information.""" + raise NotImplementedError + + def persist_user_model(self, handler: web.RequestHandler) -> None: + """Persist the user model (i.e. a cookie).""" + raise NotImplementedError + def identity_model(self, user: User) -> dict[str, t.Any]: """Return a User as an Identity model""" # TODO: validate? @@ -617,6 +666,16 @@ class PasswordIdentityProvider(IdentityProvider): def _need_token_default(self): return not bool(self.hashed_password) + @default("updatable_fields") + def _default_updatable_fields(self): + return [ + "name", + "display_name", + "initials", + "avatar_url", + "color", + ] + @property def login_available(self) -> bool: """Whether a LoginHandler is needed - and therefore whether the login page should be displayed.""" @@ -627,6 +686,17 @@ def auth_enabled(self) -> bool: """Return whether any auth is enabled""" return bool(self.hashed_password or self.token) + def update_user_model(self, current_user: User, user_data: dict[UpdatableField, str]) -> User: + """Update user information.""" + for field in self.updatable_fields: + if field in user_data: + setattr(current_user, field, user_data[field]) + return current_user + + def persist_user_model(self, handler: web.RequestHandler) -> None: + """Persist the user model to a cookie.""" + self.set_login_cookie(handler, handler.current_user) + def passwd_check(self, password): """Check password against our stored hashed password""" return passwd_check(self.hashed_password, password) diff --git a/jupyter_server/services/api/handlers.py b/jupyter_server/services/api/handlers.py index f61d9dd10f..609d68601f 100644 --- a/jupyter_server/services/api/handlers.py +++ b/jupyter_server/services/api/handlers.py @@ -4,13 +4,14 @@ # Distributed under the terms of the Modified BSD License. import json import os -from typing import Any +from typing import Any, cast from jupyter_core.utils import ensure_async from tornado import web from jupyter_server._tz import isoformat, utcfromtimestamp from jupyter_server.auth.decorator import authorized +from jupyter_server.auth.identity import IdentityProvider, UpdatableField from ...base.handlers import APIHandler, JupyterHandler @@ -70,7 +71,7 @@ async def get(self): class IdentityHandler(APIHandler): - """Get the current user's identity model""" + """Get or patch the current user's identity model""" @web.authenticated async def get(self): @@ -106,6 +107,9 @@ async def get(self): if authorized: allowed.append(action) + # Add permission to user to update their own identity + permissions["updatable_fields"] = self.identity_provider.updatable_fields + identity: dict[str, Any] = self.identity_provider.identity_model(user) model = { "identity": identity, @@ -113,6 +117,28 @@ async def get(self): } self.write(json.dumps(model)) + @web.authenticated + async def patch(self): + """Update user information.""" + user_data = cast(dict[UpdatableField, str], self.get_json_body()) + if not user_data: + raise web.HTTPError(400, "Invalid or missing JSON body") + + # Update user information + identity_provider = self.settings["identity_provider"] + if not isinstance(identity_provider, IdentityProvider): + raise web.HTTPError(500, "Identity provider not configured properly") + + try: + updated_user = identity_provider.update_user(self, user_data) + self.write( + {"status": "success", "identity": identity_provider.identity_model(updated_user)} + ) + except ValueError as e: + raise web.HTTPError(400, str(e)) from e + except NotImplementedError as e: + raise web.HTTPError(501, str(e)) from e + default_handlers = [ (r"/api/spec.yaml", APISpecHandler), diff --git a/tests/services/api/test_api.py b/tests/services/api/test_api.py index 8339f6e4af..fe412e4f1d 100644 --- a/tests/services/api/test_api.py +++ b/tests/services/api/test_api.py @@ -6,6 +6,7 @@ from tornado.httpclient import HTTPError from jupyter_server.auth import Authorizer, IdentityProvider, User +from jupyter_server.auth.identity import PasswordIdentityProvider async def test_get_spec(jp_fetch): @@ -50,6 +51,22 @@ async def get_user(self, handler): return self.mock_user +class MockPasswordIdentityProvider(PasswordIdentityProvider): + mock_user: MockUser + + async def get_user(self, handler): + # super returns a UUID + # return our mock user instead, as long as the request is authorized + _authenticated = super().get_user(handler) + if isinstance(_authenticated, Awaitable): + _authenticated = await _authenticated + authenticated = _authenticated + if isinstance(self.mock_user, dict): + self.mock_user = MockUser(**self.mock_user) + if authenticated: + return self.mock_user + + class MockAuthorizer(Authorizer): def is_authorized(self, handler, user, action, resource): permissions = user.permissions @@ -70,6 +87,17 @@ def identity_provider(jp_serverapp): yield idp +@pytest.fixture +def password_identity_provider(jp_serverapp): + idp = MockPasswordIdentityProvider(parent=jp_serverapp) + authorizer = MockAuthorizer(parent=jp_serverapp) + with mock.patch.dict( + jp_serverapp.web_app.settings, + {"identity_provider": idp, "authorizer": authorizer}, + ): + yield idp + + @pytest.mark.parametrize( "identity, expected", [ @@ -117,10 +145,128 @@ async def test_identity(jp_fetch, identity, expected, identity_provider): assert set(identity_model.keys()) == set(User.__dataclass_fields__) +@pytest.mark.parametrize("identity", [{"username": "user.username"}]) +async def test_update_user_not_implemented_update(jp_fetch, identity, identity_provider): + """Test successful user update.""" + identity_provider.mock_user = MockUser(**identity) + payload = { + "color": "#000000", + } + with pytest.raises(HTTPError) as exc: + await jp_fetch( + "/api/me", + method="PATCH", + body=json.dumps(payload), + headers={"Content-Type": "application/json"}, + ) + assert exc.value.code == 501 + + +@pytest.mark.parametrize("identity", [{"username": "user.username"}]) +async def test_update_user_not_implemented_persist(jp_fetch, identity, identity_provider): + """Test successful user update.""" + identity_provider.mock_user = MockUser(**identity) + identity_provider.update_user_model = lambda *args, **kwargs: identity_provider.mock_user + payload = { + "color": "#000000", + } + with pytest.raises(HTTPError) as exc: + await jp_fetch( + "/api/me", + method="PATCH", + body=json.dumps(payload), + headers={"Content-Type": "application/json"}, + ) + assert exc.value.code == 501 + + +@pytest.mark.parametrize("identity", [{"username": "user.username"}]) +async def test_update_user_success(jp_fetch, identity, password_identity_provider): + """Test successful user update.""" + password_identity_provider.mock_user = MockUser(**identity) + payload = { + "color": "#000000", + } + r = await jp_fetch( + "/api/me", + method="PATCH", + body=json.dumps(payload), + headers={"Content-Type": "application/json"}, + ) + assert r.code == 200 + response = json.loads(r.body.decode()) + assert response["status"] == "success" + assert response["identity"]["color"] == "#000000" + + +@pytest.mark.parametrize("identity", [{"username": "user.username"}]) +async def test_update_user_raise(jp_fetch, identity, password_identity_provider): + """Test failing user update.""" + password_identity_provider.mock_user = MockUser(**identity) + payload = { + "name": "Updated Name", + "fake_prop": "anything", + } + with pytest.raises(HTTPError) as exc: + await jp_fetch( + "/api/me", + method="PATCH", + body=json.dumps(payload), + headers={"Content-Type": "application/json"}, + ) + assert exc.value.code == 400 + + +@pytest.mark.parametrize( + "identity, expected", + [ + ( + {"username": "user.username"}, + { + "username": "user.username", + "name": "Updated Name", + "display_name": "Updated Display Name", + "color": "#000000", + }, + ) + ], +) +async def test_update_user_success_custom_updatable_fields( + jp_fetch, identity, expected, password_identity_provider +): + """Test successful user update.""" + password_identity_provider.mock_user = MockUser(**identity) + identity_provider.updatable_fields = ["name", "display_name", "color"] + payload = { + "name": expected["name"], + "display_name": expected["display_name"], + "color": expected["color"], + } + r = await jp_fetch( + "/api/me", + method="PATCH", + body=json.dumps(payload), + headers={"Content-Type": "application/json"}, + ) + assert r.code == 200 + response = json.loads(r.body.decode()) + identity_model = response["identity"] + for key, value in expected.items(): + assert identity_model[key] == value + + # Test GET request to ensure the updated fields are returned + r = await jp_fetch("api/me") + assert r.code == 200 + response = json.loads(r.body.decode()) + identity_model = response["identity"] + for key, value in expected.items(): + assert identity_model[key] == value + + @pytest.mark.parametrize( "have_permissions, check_permissions, expected", [ - ("*", None, {}), + ("*", None, {"updatable_fields": ["color"]}), ( { "contents": ["read"], @@ -136,9 +282,10 @@ async def test_identity(jp_fetch, identity, expected, identity_provider): "contents": ["read"], "kernels": ["read", "write"], "terminals": [], + "updatable_fields": ["color"], }, ), - ("*", {"contents": ["write"]}, {"contents": ["write"]}), + ("*", {"contents": ["write"]}, {"contents": ["write"], "updatable_fields": ["color"]}), ], ) async def test_identity_permissions( @@ -161,6 +308,44 @@ async def test_identity_permissions( assert response["permissions"] == expected +@pytest.mark.parametrize( + "have_permissions, check_permissions, expected", + [ + ( + "*", + None, + { + "updatable_fields": [ + "name", + "display_name", + "initials", + "avatar_url", + "color", + ] + }, + ), + ], +) +async def test_password_identity_permissions( + jp_fetch, have_permissions, check_permissions, expected, password_identity_provider +): + user = MockUser("username") + user.permissions = have_permissions + password_identity_provider.mock_user = user + + if check_permissions is not None: + params = {"permissions": json.dumps(check_permissions)} + else: + params = None + + r = await jp_fetch("api/me", params=params) + assert r is not None + assert r.code == 200 + response = json.loads(r.body.decode()) + assert set(response.keys()) == {"identity", "permissions"} + assert response["permissions"] == expected + + @pytest.mark.parametrize( "permissions", [