Skip to content

Commit e90c374

Browse files
committed
refactor: unify middleware definitions
1 parent f752c1a commit e90c374

File tree

3 files changed

+30
-79
lines changed

3 files changed

+30
-79
lines changed

packages/core/src/auth/sso/clients.ts

Lines changed: 4 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,25 @@ import {
1616
SSOServiceException,
1717
} from '@aws-sdk/client-sso'
1818
import {
19-
AuthorizationPendingException,
2019
CreateTokenRequest,
2120
RegisterClientRequest,
2221
SSOOIDC,
2322
SSOOIDCClient,
2423
StartDeviceAuthorizationRequest,
2524
} from '@aws-sdk/client-sso-oidc'
2625
import { AsyncCollection } from '../../shared/utilities/asyncCollection'
27-
import { pageableToCollection, partialClone } from '../../shared/utilities/collectionUtils'
26+
import { pageableToCollection } from '../../shared/utilities/collectionUtils'
2827
import { assertHasProps, isNonNullable, RequiredProps, selectFrom } from '../../shared/utilities/tsUtils'
2928
import { getLogger } from '../../shared/logger/logger'
3029
import { SsoAccessTokenProvider } from './ssoAccessTokenProvider'
3130
import { AwsClientResponseError, isClientFault } from '../../shared/errors'
3231
import { DevSettings } from '../../shared/settings'
3332
import { SdkError } from '@aws-sdk/types'
34-
import { HttpRequest, HttpResponse } from '@smithy/protocol-http'
3533
import { StandardRetryStrategy, defaultRetryDecider } from '@smithy/middleware-retry'
3634
import { AuthenticationFlow } from './model'
3735
import { toSnakeCase } from '../../shared/utilities/textUtilities'
3836
import { getUserAgent, withTelemetryContext } from '../../shared/telemetry/util'
37+
import { defaultDeserializeMiddleware, finalizeLoggingMiddleware } from '../../shared/awsClientBuilderV3'
3938

4039
export class OidcClient {
4140
public constructor(
@@ -249,52 +248,6 @@ export class SsoClient {
249248
}
250249

251250
function addLoggingMiddleware(client: SSOOIDCClient) {
252-
client.middlewareStack.add(
253-
(next, context) => (args) => {
254-
if (HttpRequest.isInstance(args.request)) {
255-
const { hostname, path } = args.request
256-
const input = partialClone(
257-
// TODO: Fix
258-
args.input as unknown as Record<string, unknown>,
259-
3,
260-
['clientSecret', 'accessToken', 'refreshToken'],
261-
'[omitted]'
262-
)
263-
getLogger().debug('API request (%s %s): %O', hostname, path, input)
264-
}
265-
return next(args)
266-
},
267-
{ step: 'finalizeRequest' }
268-
)
269-
270-
client.middlewareStack.add(
271-
(next, context) => async (args) => {
272-
if (!HttpRequest.isInstance(args.request)) {
273-
return next(args)
274-
}
275-
276-
const { hostname, path } = args.request
277-
const result = await next(args).catch((e) => {
278-
if (e instanceof Error && !(e instanceof AuthorizationPendingException)) {
279-
const err = { ...e }
280-
delete err['stack']
281-
getLogger().error('API response (%s %s): %O', hostname, path, err)
282-
}
283-
throw e
284-
})
285-
if (HttpResponse.isInstance(result.response)) {
286-
const output = partialClone(
287-
// TODO: Fix
288-
result.output as unknown as Record<string, unknown>,
289-
3,
290-
['clientSecret', 'accessToken', 'refreshToken'],
291-
'[omitted]'
292-
)
293-
getLogger().debug('API response (%s %s): %O', hostname, path, output)
294-
}
295-
296-
return result
297-
},
298-
{ step: 'deserialize' }
299-
)
251+
client.middlewareStack.add(finalizeLoggingMiddleware, { step: 'finalizeRequest' })
252+
client.middlewareStack.add(defaultDeserializeMiddleware, { step: 'deserialize' })
300253
}

packages/core/src/shared/awsClientBuilderV3.ts

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ import { partialClone } from './utilities/collectionUtils'
4242
import { selectFrom } from './utilities/tsUtils'
4343
import { once } from './utilities/functionUtils'
4444
import { isWeb } from './extensionGlobals'
45+
import { AuthorizationPendingException } from '@aws-sdk/client-sso-oidc'
4546

4647
export type AwsClientConstructor<C> = new (o: AwsClientOptions) => C
4748
export type AwsCommandConstructor<CommandInput extends object, Command extends AwsCommand<CommandInput, object>> = new (
@@ -175,8 +176,8 @@ export class AWSClientBuilderV3 {
175176
}
176177

177178
const service = new serviceOptions.serviceClient(opt)
178-
service.middlewareStack.add(telemetryMiddleware, { step: 'deserialize' })
179-
service.middlewareStack.add(loggingMiddleware, { step: 'finalizeRequest' })
179+
service.middlewareStack.add(defaultDeserializeMiddleware, { step: 'deserialize' })
180+
service.middlewareStack.add(finalizeLoggingMiddleware, { step: 'finalizeRequest' })
180181
service.middlewareStack.add(getEndpointMiddleware(serviceOptions.settings), { step: 'build' })
181182

182183
if (keepAlive) {
@@ -211,20 +212,13 @@ export function recordErrorTelemetry(err: Error, serviceName?: string) {
211212
})
212213
}
213214

214-
function logAndThrow(e: any, serviceId: string, errorMessageAppend: string): never {
215-
if (e instanceof Error) {
216-
recordErrorTelemetry(e, serviceId)
217-
getLogger().error('API Response %s: %O', errorMessageAppend, e)
218-
}
219-
throw e
220-
}
221-
222-
const telemetryMiddleware: DeserializeMiddleware<any, any> =
215+
export const defaultDeserializeMiddleware: DeserializeMiddleware<any, any> =
223216
(next: DeserializeHandler<any, any>, context: HandlerExecutionContext) => async (args: any) =>
224-
emitOnRequest(next, context, args)
217+
onDeserialize(next, context, args)
225218

226-
const loggingMiddleware: FinalizeRequestMiddleware<any, any> = (next: FinalizeHandler<any, any>) => async (args: any) =>
227-
logOnRequest(next, args)
219+
export const finalizeLoggingMiddleware: FinalizeRequestMiddleware<any, any> =
220+
(next: FinalizeHandler<any, any>) => async (args: any) =>
221+
logOnFinalize(next, args)
228222

229223
function getEndpointMiddleware(settings: DevSettings = DevSettings.instance): BuildMiddleware<any, any> {
230224
return (next: BuildHandler<any, any>, context: HandlerExecutionContext) => async (args: any) =>
@@ -234,32 +228,36 @@ function getEndpointMiddleware(settings: DevSettings = DevSettings.instance): Bu
234228
const keepAliveMiddleware: BuildMiddleware<any, any> = (next: BuildHandler<any, any>) => async (args: any) =>
235229
addKeepAliveHeader(next, args)
236230

237-
export async function emitOnRequest(next: DeserializeHandler<any, any>, context: HandlerExecutionContext, args: any) {
231+
export async function onDeserialize(next: DeserializeHandler<any, any>, context: HandlerExecutionContext, args: any) {
238232
if (!HttpResponse.isInstance(args.request)) {
239233
return next(args)
240234
}
241-
const serviceId = getServiceId(context as object)
242235
const { hostname, path } = args.request
236+
const serviceId = getServiceId(context as object)
243237
const logTail = `(${hostname} ${path})`
244238
try {
245239
const result = await next(args)
246240
if (HttpResponse.isInstance(result.response)) {
247-
// TODO: omit credentials / sensitive info from the telemetry.
248-
const output = partialClone(result.output, 3)
241+
const output = partialClone(result.output, 3, ['clientSecret', 'accessToken', 'refreshToken'], '[omitted]')
249242
getLogger().debug(`API Response %s: %O`, logTail, output)
250243
}
251244
return result
252-
} catch (e: any) {
253-
logAndThrow(e, serviceId, logTail)
245+
} catch (e: unknown) {
246+
if (e instanceof Error && !(e instanceof AuthorizationPendingException)) {
247+
const err = { ...e, name: e.name, mesage: e.message }
248+
delete err['stack']
249+
recordErrorTelemetry(err, serviceId)
250+
getLogger().error('API Response %s: %O', logTail, err)
251+
}
252+
throw e
254253
}
255254
}
256255

257-
export async function logOnRequest(next: FinalizeHandler<any, any>, args: any) {
256+
export async function logOnFinalize(next: FinalizeHandler<any, any>, args: any) {
258257
const request = args.request
259258
if (HttpRequest.isInstance(args.request)) {
260259
const { hostname, path } = request
261-
// TODO: omit credentials / sensitive info from the logs.
262-
const input = partialClone(args.input, 3)
260+
const input = partialClone(args.input, 3, ['clientSecret', 'accessToken', 'refreshToken'], '[omitted]')
263261
getLogger().debug(`API Request (%s %s): %O`, hostname, path, input)
264262
}
265263
return next(args)

packages/core/src/test/shared/awsClientBuilderV3.test.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ import {
1414
AWSClientBuilderV3,
1515
AwsClientOptions,
1616
AwsCommand,
17-
emitOnRequest,
17+
onDeserialize,
1818
getServiceId,
19-
logOnRequest,
19+
logOnFinalize,
2020
overwriteEndpoint,
2121
recordErrorTelemetry,
2222
} from '../../shared/awsClientBuilderV3'
@@ -237,7 +237,7 @@ describe('AwsClientBuilderV3', function () {
237237
})
238238

239239
it('logs messages on request', async function () {
240-
await logOnRequest((_: any) => _, args as any)
240+
await logOnFinalize((_: any) => _, args as any)
241241
assertLogsContainAllOf(['testHost', 'testPath'], false, 'debug')
242242
})
243243

@@ -246,7 +246,7 @@ describe('AwsClientBuilderV3', function () {
246246
throw new Error('test error')
247247
}
248248
await telemetry.vscode_executeCommand.run(async (span) => {
249-
await assert.rejects(emitOnRequest(next, context, args))
249+
await assert.rejects(onDeserialize(next, context, args))
250250
})
251251
assertLogsContain('test error', false, 'error')
252252
assertTelemetry('vscode_executeCommand', { requestServiceType: 'foo' })
@@ -257,7 +257,7 @@ describe('AwsClientBuilderV3', function () {
257257
return response
258258
}
259259
await telemetry.vscode_executeCommand.run(async (span) => {
260-
assert.deepStrictEqual(await emitOnRequest(next, context, args), response)
260+
assert.deepStrictEqual(await onDeserialize(next, context, args), response)
261261
})
262262
assertLogsContainAllOf(['testHost', 'testPath'], false, 'debug')
263263
assert.throws(() => assertTelemetry('vscode_executeCommand', { requestServiceType: 'foo' }))

0 commit comments

Comments
 (0)