diff --git a/src/server/auth-client.test.ts b/src/server/auth-client.test.ts index 00fe8bad..7ca5541d 100644 --- a/src/server/auth-client.test.ts +++ b/src/server/auth-client.test.ts @@ -4584,7 +4584,7 @@ ca/T0LLtgmbMmxSv/MmzIg== const [error, updatedTokenSet] = await authClient.getTokenSet(tokenSet); expect(error).toBeNull(); - expect(updatedTokenSet).toEqual(tokenSet); + expect(updatedTokenSet?.tokenSet).toEqual(tokenSet); }); it("should return an error if the token set does not contain a refresh token and the access token has expired", async () => { @@ -4657,7 +4657,7 @@ ca/T0LLtgmbMmxSv/MmzIg== const [error, updatedTokenSet] = await authClient.getTokenSet(tokenSet); expect(error).toBeNull(); - expect(updatedTokenSet).toEqual({ + expect(updatedTokenSet?.tokenSet).toEqual({ accessToken: DEFAULT.accessToken, refreshToken: DEFAULT.refreshToken, expiresAt: expect.any(Number) @@ -4778,7 +4778,7 @@ ca/T0LLtgmbMmxSv/MmzIg== const [error, updatedTokenSet] = await authClient.getTokenSet(tokenSet); expect(error).toBeNull(); - expect(updatedTokenSet).toEqual({ + expect(updatedTokenSet?.tokenSet).toEqual({ accessToken: DEFAULT.accessToken, refreshToken: "rt_456", expiresAt: expect.any(Number) diff --git a/src/server/auth-client.ts b/src/server/auth-client.ts index 89f825f7..378fa49d 100644 --- a/src/server/auth-client.ts +++ b/src/server/auth-client.ts @@ -25,7 +25,8 @@ import { LogoutToken, SessionData, StartInteractiveLoginOptions, - TokenSet + TokenSet, + User } from "../types/index.js"; import { ensureNoLeadingSlash, @@ -577,18 +578,9 @@ export class AuthClient { const res = await this.onCallback(null, onCallbackCtx, session); - if (this.beforeSessionSaved) { - const updatedSession = await this.beforeSessionSaved( - session, - oidcRes.id_token ?? null - ); - session = { - ...updatedSession, - internal: session.internal - }; - } else { - session.user = filterDefaultIdTokenClaims(idTokenClaims); - } + // call beforeSessionSaved callback if present + // if not then filter id_token claims with default rules + session = await this.finalizeSession(session, oidcRes.id_token); await this.sessionStore.set(req.cookies, res.cookies, session, true); addCacheControlHeadersForSession(res); @@ -633,7 +625,9 @@ export class AuthClient { ); } - const [error, updatedTokenSet] = await this.getTokenSet(session.tokenSet); + const [error, getTokenSetResponse] = await this.getTokenSet( + session.tokenSet + ); if (error) { return NextResponse.json( @@ -648,6 +642,9 @@ export class AuthClient { } ); } + + const { tokenSet: updatedTokenSet, idTokenClaims } = getTokenSetResponse; + const res = NextResponse.json({ token: updatedTokenSet.accessToken, scope: updatedTokenSet.scope, @@ -656,11 +653,20 @@ export class AuthClient { if ( updatedTokenSet.accessToken !== session.tokenSet.accessToken || - updatedTokenSet.refreshToken !== session.tokenSet.refreshToken || - updatedTokenSet.expiresAt !== session.tokenSet.expiresAt + updatedTokenSet.expiresAt !== session.tokenSet.expiresAt || + updatedTokenSet.refreshToken !== session.tokenSet.refreshToken ) { + if (idTokenClaims) { + session.user = idTokenClaims as User; + } + // call beforeSessionSaved callback if present + // if not then filter id_token claims with default rules + const finalSession = await this.finalizeSession( + session, + updatedTokenSet.idToken + ); await this.sessionStore.set(req.cookies, res.cookies, { - ...session, + ...finalSession, tokenSet: updatedTokenSet }); addCacheControlHeadersForSession(res); @@ -710,13 +716,17 @@ export class AuthClient { } /** - * getTokenSet returns a valid token set. If the access token has expired, it will attempt to - * refresh it using the refresh token, if available. + * Retrieves OAuth token sets, handling token refresh when necessary or if forced. + * + * @returns A tuple containing either: + * - `[SdkError, null]` if an error occurred (missing refresh token, discovery failure, or refresh failure) + * - `[null, {tokenSet, idTokenClaims}]` if a new token was retrieved, containing the new token set ID token claims + * - `[null, {tokenSet, }]` if token refresh was not done and existing token was returned */ async getTokenSet( tokenSet: TokenSet, forceRefresh?: boolean | undefined - ): Promise<[null, TokenSet] | [SdkError, null]> { + ): Promise<[null, GetTokenSetResponse] | [SdkError, null]> { // the access token has expired but we do not have a refresh token if (!tokenSet.refreshToken && tokenSet.expiresAt <= Date.now() / 1000) { return [ @@ -771,6 +781,7 @@ export class AuthClient { ]; } + const idTokenClaims = oauth.getValidatedIdTokenClaims(oauthRes)!; const accessTokenExpiresAt = Math.floor(Date.now() / 1000) + Number(oauthRes.expires_in); @@ -789,11 +800,17 @@ export class AuthClient { updatedTokenSet.refreshToken = tokenSet.refreshToken; } - return [null, updatedTokenSet]; + return [ + null, + { + tokenSet: updatedTokenSet, + idTokenClaims: idTokenClaims + } + ]; } } - return [null, tokenSet]; + return [null, { tokenSet, idTokenClaims: undefined }]; } private async discoverAuthorizationServerMetadata(): Promise< @@ -1161,6 +1178,32 @@ export class AuthClient { return [null, connectionTokenSet] as [null, ConnectionTokenSet]; } + + /** + * Filters and processes ID token claims for a session. + * + * If a `beforeSessionSaved` callback is configured, it will be invoked to allow + * custom processing of the session and ID token. Otherwise, default filtering + * will be applied to remove standard ID token claims from the user object. + */ + async finalizeSession( + session: SessionData, + idToken?: string + ): Promise { + if (this.beforeSessionSaved) { + const updatedSession = await this.beforeSessionSaved( + session, + idToken ?? null + ); + session = { + ...updatedSession, + internal: session.internal + }; + } else { + session.user = filterDefaultIdTokenClaims(session.user); + } + return session; + } } const encodeBase64 = (input: string) => { @@ -1175,3 +1218,8 @@ const encodeBase64 = (input: string) => { } return btoa(arr.join("")); }; + +type GetTokenSetResponse = { + tokenSet: TokenSet; + idTokenClaims?: { [key: string]: any }; +}; diff --git a/src/server/client.ts b/src/server/client.ts index db449c6b..ad99301b 100644 --- a/src/server/client.ts +++ b/src/server/client.ts @@ -7,15 +7,15 @@ import { AccessTokenError, AccessTokenErrorCode, AccessTokenForConnectionError, - AccessTokenForConnectionErrorCode, + AccessTokenForConnectionErrorCode } from "../errors/index.js"; - import { AccessTokenForConnectionOptions, AuthorizationParameters, SessionData, SessionDataStore, - StartInteractiveLoginOptions + StartInteractiveLoginOptions, + User } from "../types/index.js"; import { AuthClient, @@ -420,23 +420,32 @@ export class Auth0Client { ); } - const [error, tokenSet] = await this.authClient.getTokenSet( + const [error, getTokenSetResponse] = await this.authClient.getTokenSet( session.tokenSet, options.refresh ); if (error) { throw error; } - + const { tokenSet, idTokenClaims } = getTokenSetResponse; // update the session with the new token set, if necessary if ( tokenSet.accessToken !== session.tokenSet.accessToken || tokenSet.expiresAt !== session.tokenSet.expiresAt || tokenSet.refreshToken !== session.tokenSet.refreshToken ) { + if (idTokenClaims) { + session.user = idTokenClaims as User; + } + // call beforeSessionSaved callback if present + // if not then filter id_token claims with default rules + const finalSession = await this.authClient.finalizeSession( + session, + tokenSet.idToken + ); await this.saveToSession( { - ...session, + ...finalSession, tokenSet }, req, diff --git a/src/server/get-access-token.test.ts b/src/server/get-access-token.test.ts index 1e7450e6..6b6932fc 100644 --- a/src/server/get-access-token.test.ts +++ b/src/server/get-access-token.test.ts @@ -13,62 +13,67 @@ import { vi } from "vitest"; -import { SessionData, TokenSet } from "../types/index.js"; +import { SessionData } from "../types/index.js"; import { Auth0Client } from "./client.js"; // Basic constants for testing -const DEFAULT = { - domain: "https://op.example.com", +const domain = "https://auth0.local"; +const alg = "RS256"; +const sub = "test-sub"; +const sid = "test-sid"; +const scope = "openid profile email offline_access"; + +const testAuth0ClientConfig = { + domain, clientId: "test-client-id", clientSecret: "test-client-secret", appBaseUrl: "https://example.org", - secret: "test-secret-long-enough-for-hs256-test-secret-long-enough-for-hs256", - alg: "RS256", - sub: "test-sub", - sid: "test-sid", - scope: "openid profile email offline_access" + secret: "test-secret-long-enough-for-hs256-test-secret-long-enough-for-hs256" }; -const initialTokenSetBase = { - accessToken: "test-access-token", - refreshToken: "test-refresh-token", - idToken: "test-id-token", - scope: DEFAULT.scope -}; - -const authClientConfig = { - domain: DEFAULT.domain, - clientId: DEFAULT.clientId, - clientSecret: DEFAULT.clientSecret, - appBaseUrl: DEFAULT.appBaseUrl, - secret: DEFAULT.secret -}; - -// msw server setup let keyPair: jose.GenerateKeyPairResult; + const refreshedAccessToken = "msw-refreshed-access-token"; const refreshedRefreshToken = "msw-refreshed-refresh-token"; const refreshedExpiresIn = 3600; -const issuer = DEFAULT.domain; -const audience = DEFAULT.clientId; +const issuer = domain; +const audience = testAuth0ClientConfig.clientId; +const initialName = "initialName"; +const updatedName = "updatedName"; + +const generateToken = async (claims?: any) => + await new jose.SignJWT({ + sid, + sub, + auth_time: Math.floor(Date.now() / 1000), + nonce: "nonce-value", + jti: Date.now().toString(), + ...(claims && { ...claims }) + }) + .setProtectedHeader({ alg }) + .setIssuer(issuer) + .setAudience(audience) + .setIssuedAt() + .setExpirationTime("1h") + .sign(keyPair.privateKey); const handlers = [ // OIDC Discovery Endpoint - http.get(`${DEFAULT.domain}/.well-known/openid-configuration`, () => { + http.get(`${domain}/.well-known/openid-configuration`, () => { return HttpResponse.json({ issuer: issuer, - token_endpoint: `${DEFAULT.domain}/oauth/token`, - jwks_uri: `${DEFAULT.domain}/.well-known/jwks.json` + token_endpoint: `${domain}/oauth/token`, + jwks_uri: `${domain}/.well-known/jwks.json` }); }), // JWKS Endpoint - http.get(`${DEFAULT.domain}/.well-known/jwks.json`, async () => { + http.get(`${domain}/.well-known/jwks.json`, async () => { const jwk = await jose.exportJWK(keyPair.publicKey); return HttpResponse.json({ keys: [jwk] }); }), // Token Endpoint (for refresh token grant) http.post( - `${DEFAULT.domain}/oauth/token`, + `${domain}/oauth/token`, async ({ request }: { request: Request }) => { const body = await request.formData(); @@ -76,27 +81,15 @@ const handlers = [ body.get("grant_type") === "refresh_token" && body.get("refresh_token") ) { - // Generate a new ID token for the refreshed set - const newIdToken = await new jose.SignJWT({ - sid: DEFAULT.sid, - sub: DEFAULT.sub, - auth_time: Math.floor(Date.now() / 1000), - nonce: "nonce-value" // Example nonce - }) - .setProtectedHeader({ alg: DEFAULT.alg }) - .setIssuer(issuer) - .setAudience(audience) - .setIssuedAt() - .setExpirationTime("1h") - .sign(keyPair.privateKey); - return HttpResponse.json({ access_token: refreshedAccessToken, refresh_token: refreshedRefreshToken, - id_token: newIdToken, + id_token: await generateToken({ + name: updatedName + }), token_type: "Bearer", expires_in: refreshedExpiresIn, - scope: DEFAULT.scope // Assuming scope doesn't change on refresh + scope }); } @@ -112,7 +105,7 @@ const handlers = [ const server = setupServer(...handlers); beforeAll(async () => { - keyPair = await jose.generateKeyPair(DEFAULT.alg); + keyPair = await jose.generateKeyPair(alg); server.listen({ onUnhandledRequest: "error" }); }); afterEach(() => server.resetHandlers()); @@ -121,19 +114,20 @@ afterAll(() => server.close()); /** * Creates initial session data for tests. */ -function createInitialSession(): SessionData { - // Use a VALID (non-expired) initial token - const initialExpiresAt = Math.floor(Date.now() / 1000) + 3600; // Expires in 1 hour - const initialTokenSet: TokenSet = { - ...initialTokenSetBase, // Spread the base token set from the new constant - expiresAt: initialExpiresAt // Add the dynamic expiration time - }; - const initialSession: SessionData = { - user: { sub: DEFAULT.sub }, - tokenSet: initialTokenSet, - internal: { sid: DEFAULT.sid, createdAt: Date.now() / 1000 } +async function createInitialSession(): Promise { + return { + user: { sub, name: initialName }, + tokenSet: { + accessToken: "test-access-token", + refreshToken: "test-refresh-token", + idToken: await generateToken({ + name: initialName + }), + scope, + expiresAt: Math.floor(Date.now() / 1000) + 3600 // Expires in 1 hour + }, + internal: { sid, createdAt: Date.now() / 1000 } }; - return initialSession; } describe("Auth0Client - getAccessToken", () => { @@ -142,12 +136,19 @@ describe("Auth0Client - getAccessToken", () => { beforeEach(async () => { // Instantiate Auth0Client normally, it will use intercepted fetch - auth0Client = new Auth0Client(authClientConfig); + auth0Client = new Auth0Client(testAuth0ClientConfig); // Mock saveToSession to avoid cookie/request context issues mockSaveToSession = vi .spyOn(Auth0Client.prototype as any, "saveToSession") .mockResolvedValue(undefined); // Mock successful save + + const initialSession = await createInitialSession(); + + // Mock getSession specifically for this test + vi.spyOn(Auth0Client.prototype as any, "getSession").mockResolvedValue( + initialSession + ); }); afterEach(() => { @@ -160,15 +161,10 @@ describe("Auth0Client - getAccessToken", () => { * it refreshes the token. */ it("should refresh token and save session for pages-router overload when refresh is true (with valid token)", async () => { - const initialSession = createInitialSession(); - - // Mock getSession specifically for this test - vi.spyOn(Auth0Client.prototype as any, "getSession").mockResolvedValue( - initialSession - ); - // Pages router overload requires req/res objects - const mockReq = new NextRequest(`https://${DEFAULT.appBaseUrl}/api/test`); + const mockReq = new NextRequest( + `https://${testAuth0ClientConfig.appBaseUrl}/api/test` + ); const mockRes = new NextResponse(); // --- Execution --- @@ -188,6 +184,10 @@ describe("Auth0Client - getAccessToken", () => { // The '0' precision checks for equality at the integer second level. expect(result?.expiresAt).toBeCloseTo(expectedExpiresAtRough, 0); expect(mockSaveToSession).toHaveBeenCalledOnce(); + + // Verify user profile data is updated in saved session + const savedSessionData = mockSaveToSession.mock.calls[0][0] as SessionData; + expect(savedSessionData.user.name).toBe(updatedName); }); /** @@ -196,13 +196,6 @@ describe("Auth0Client - getAccessToken", () => { * it refreshes the token. */ it("should refresh token for app-router overload when refresh is true (with valid token)", async () => { - const initialSession = createInitialSession(); - - // Mock getSession specifically for this test - vi.spyOn(Auth0Client.prototype as any, "getSession").mockResolvedValue( - initialSession - ); - // --- Execution --- const result = await auth0Client.getAccessToken({ refresh: true @@ -217,5 +210,9 @@ describe("Auth0Client - getAccessToken", () => { expect(result?.expiresAt).toBeCloseTo(expectedExpiresAtRough, 0); expect(mockSaveToSession).toHaveBeenCalledOnce(); + + // Verify user profile data is updated in saved session + const savedSessionData = mockSaveToSession.mock.calls[0][0] as SessionData; + expect(savedSessionData.user.name).toBe(updatedName); }); });