Skip to content

Commit d186a41

Browse files
authored
ensure that cancellation token is passed in InvokeWithActivityAsync (#4329)
* ensure that cancellation token is passed in InvokeWithActivityAsync * add comments and baggange is not nullable * store ncrunch settings * shange signature to have nullable activity at the end of Update * correct spelling case * primary contructor * add docs and make async interface accept cancellation tokens * address code ql error
1 parent 01dc56b commit d186a41

File tree

8 files changed

+81
-38
lines changed

8 files changed

+81
-38
lines changed

dotnet/AutoGen.v3.ncrunchsolution

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<SolutionConfiguration>
2+
<Settings>
3+
<AllowParallelTestExecution>True</AllowParallelTestExecution>
4+
<EnableRDI>True</EnableRDI>
5+
<RdiConfigured>True</RdiConfigured>
6+
<SolutionConfigured>True</SolutionConfigured>
7+
</Settings>
8+
</SolutionConfiguration>

dotnet/src/Microsoft.AutoGen/Abstractions/IAgentRuntime.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ public interface IAgentRuntime
1515
ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken = default);
1616
ValueTask SendMessageAsync(Message message, CancellationToken cancellationToken = default);
1717
ValueTask PublishEventAsync(CloudEvent @event, CancellationToken cancellationToken = default);
18-
void Update(Activity? activity, RpcRequest request);
19-
void Update(Activity? activity, CloudEvent cloudEvent);
20-
(string?, string?) GetTraceIDandState(IDictionary<string, string> metadata);
18+
void Update(RpcRequest request, Activity? activity);
19+
void Update(CloudEvent cloudEvent, Activity? activity);
20+
(string?, string?) GetTraceIdAndState(IDictionary<string, string> metadata);
2121
IDictionary<string, string> ExtractMetadata(IDictionary<string, string> metadata);
2222
}

dotnet/src/Microsoft.AutoGen/Abstractions/IAgentState.cs

+18-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,24 @@
33

44
namespace Microsoft.AutoGen.Abstractions;
55

6+
/// <summary>
7+
/// Interface for managing the state of an agent.
8+
/// </summary>
69
public interface IAgentState
710
{
8-
ValueTask<AgentState> ReadStateAsync();
9-
ValueTask<string> WriteStateAsync(AgentState state, string eTag);
11+
/// <summary>
12+
/// Reads the current state of the agent asynchronously.
13+
/// </summary>
14+
/// <param name="cancellationToken">A token to cancel the operation.</param>
15+
/// <returns>A task that represents the asynchronous read operation. The task result contains the current state of the agent.</returns>
16+
ValueTask<AgentState> ReadStateAsync(CancellationToken cancellationToken = default);
17+
18+
/// <summary>
19+
/// Writes the specified state of the agent asynchronously.
20+
/// </summary>
21+
/// <param name="state">The state to write.</param>
22+
/// <param name="eTag">The ETag for concurrency control.</param>
23+
/// <param name="cancellationToken">A token to cancel the operation.</param>
24+
/// <returns>A task that represents the asynchronous write operation. The task result contains the ETag of the written state.</returns>
25+
ValueTask<string> WriteStateAsync(AgentState state, string eTag, CancellationToken cancellationToken = default);
1026
}

dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs

+10-10
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ protected internal async Task HandleRpcMessage(Message msg, CancellationToken ca
9393
{
9494
var activity = this.ExtractActivity(msg.CloudEvent.Type, msg.CloudEvent.Metadata);
9595
await this.InvokeWithActivityAsync(
96-
static ((AgentBase Agent, CloudEvent Item) state) => state.Agent.CallHandler(state.Item),
96+
static ((AgentBase Agent, CloudEvent Item) state, CancellationToken _) => state.Agent.CallHandler(state.Item),
9797
(this, msg.CloudEvent),
9898
activity,
9999
msg.CloudEvent.Type, cancellationToken).ConfigureAwait(false);
@@ -103,7 +103,7 @@ await this.InvokeWithActivityAsync(
103103
{
104104
var activity = this.ExtractActivity(msg.Request.Method, msg.Request.Metadata);
105105
await this.InvokeWithActivityAsync(
106-
static ((AgentBase Agent, RpcRequest Request) state) => state.Agent.OnRequestCoreAsync(state.Request),
106+
static ((AgentBase Agent, RpcRequest Request) state, CancellationToken ct) => state.Agent.OnRequestCoreAsync(state.Request, ct),
107107
(this, msg.Request),
108108
activity,
109109
msg.Request.Method, cancellationToken).ConfigureAwait(false);
@@ -142,8 +142,8 @@ public async Task StoreAsync(AgentState state, CancellationToken cancellationTok
142142
}
143143
public async Task<T> ReadAsync<T>(AgentId agentId, CancellationToken cancellationToken = default) where T : IMessage, new()
144144
{
145-
var agentstate = await _context.ReadAsync(agentId, cancellationToken).ConfigureAwait(false);
146-
return agentstate.FromAgentState<T>();
145+
var agentState = await _context.ReadAsync(agentId, cancellationToken).ConfigureAwait(false);
146+
return agentState.FromAgentState<T>();
147147
}
148148
private void OnResponseCore(RpcResponse response)
149149
{
@@ -195,9 +195,9 @@ protected async Task<RpcResponse> RequestAsync(AgentId target, string method, Di
195195
activity?.SetTag("peer.service", target.ToString());
196196

197197
var completion = new TaskCompletionSource<RpcResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
198-
_context.Update(activity, request);
198+
_context.Update(request, activity);
199199
await this.InvokeWithActivityAsync(
200-
static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource<RpcResponse>) state) =>
200+
static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource<RpcResponse>) state, CancellationToken ct) =>
201201
{
202202
var (self, request, completion) = state;
203203

@@ -206,7 +206,7 @@ static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource<RpcResp
206206
self._pendingRequests[request.RequestId] = completion;
207207
}
208208

209-
await state.Agent._context.SendRequestAsync(state.Agent, state.Request).ConfigureAwait(false);
209+
await state.Agent._context.SendRequestAsync(state.Agent, state.Request, ct).ConfigureAwait(false);
210210

211211
await completion.Task.ConfigureAwait(false);
212212
},
@@ -231,11 +231,11 @@ public async ValueTask PublishEventAsync(CloudEvent item, CancellationToken canc
231231
activity?.SetTag("peer.service", $"{item.Type}/{item.Source}");
232232

233233
// TODO: fix activity
234-
_context.Update(activity, item);
234+
_context.Update(item, activity);
235235
await this.InvokeWithActivityAsync(
236-
static async ((AgentBase Agent, CloudEvent Event) state) =>
236+
static async ((AgentBase Agent, CloudEvent Event) state, CancellationToken ct) =>
237237
{
238-
await state.Agent._context.PublishEventAsync(state.Event).ConfigureAwait(false);
238+
await state.Agent._context.PublishEventAsync(state.Event, ct).ConfigureAwait(false);
239239
},
240240
(this, item),
241241
activity,

dotnet/src/Microsoft.AutoGen/Agents/AgentBaseExtensions.cs

+28-9
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,25 @@
55

66
namespace Microsoft.AutoGen.Agents;
77

8+
/// <summary>
9+
/// Provides extension methods for the <see cref="AgentBase"/> class.
10+
/// </summary>
811
public static class AgentBaseExtensions
912
{
13+
/// <summary>
14+
/// Extracts an <see cref="Activity"/> from the given agent and metadata.
15+
/// </summary>
16+
/// <param name="agent">The agent from which to extract the activity.</param>
17+
/// <param name="activityName">The name of the activity.</param>
18+
/// <param name="metadata">The metadata containing trace information.</param>
19+
/// <returns>The extracted <see cref="Activity"/> or null if extraction fails.</returns>
1020
public static Activity? ExtractActivity(this AgentBase agent, string activityName, IDictionary<string, string> metadata)
1121
{
1222
Activity? activity;
13-
(var traceParent, var traceState) = agent.Context.GetTraceIDandState(metadata);
23+
var (traceParent, traceState) = agent.Context.GetTraceIdAndState(metadata);
1424
if (!string.IsNullOrEmpty(traceParent))
1525
{
16-
if (ActivityContext.TryParse(traceParent, traceState, isRemote: true, out ActivityContext parentContext))
26+
if (ActivityContext.TryParse(traceParent, traceState, isRemote: true, out var parentContext))
1727
{
1828
// traceParent is a W3CId
1929
activity = AgentBase.s_source.CreateActivity(activityName, ActivityKind.Server, parentContext);
@@ -33,12 +43,9 @@ public static class AgentBaseExtensions
3343

3444
var baggage = agent.Context.ExtractMetadata(metadata);
3545

36-
if (baggage is not null)
46+
foreach (var baggageItem in baggage)
3747
{
38-
foreach (var baggageItem in baggage)
39-
{
40-
activity.AddBaggage(baggageItem.Key, baggageItem.Value);
41-
}
48+
activity.AddBaggage(baggageItem.Key, baggageItem.Value);
4249
}
4350
}
4451
}
@@ -49,7 +56,19 @@ public static class AgentBaseExtensions
4956

5057
return activity;
5158
}
52-
public static async Task InvokeWithActivityAsync<TState>(this AgentBase agent, Func<TState, Task> func, TState state, Activity? activity, string methodName, CancellationToken cancellationToken = default)
59+
60+
/// <summary>
61+
/// Invokes a function asynchronously within the context of an <see cref="Activity"/>.
62+
/// </summary>
63+
/// <typeparam name="TState">The type of the state parameter.</typeparam>
64+
/// <param name="agent">The agent invoking the function.</param>
65+
/// <param name="func">The function to invoke.</param>
66+
/// <param name="state">The state parameter to pass to the function.</param>
67+
/// <param name="activity">The activity within which to invoke the function.</param>
68+
/// <param name="methodName">The name of the method being invoked.</param>
69+
/// <param name="cancellationToken">A token to monitor for cancellation requests.</param>
70+
/// <returns>A task representing the asynchronous operation.</returns>
71+
public static async Task InvokeWithActivityAsync<TState>(this AgentBase agent, Func<TState, CancellationToken, Task> func, TState state, Activity? activity, string methodName, CancellationToken cancellationToken = default)
5372
{
5473
if (activity is not null && activity.StartTimeUtc == default)
5574
{
@@ -63,7 +82,7 @@ public static async Task InvokeWithActivityAsync<TState>(this AgentBase agent, F
6382

6483
try
6584
{
66-
await func(state).ConfigureAwait(false);
85+
await func(state, cancellationToken).ConfigureAwait(false);
6786
if (activity is not null && activity.IsAllDataRequested)
6887
{
6988
activity.SetStatus(ActivityStatusCode.Ok);

dotnet/src/Microsoft.AutoGen/Agents/AgentRuntime.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ internal sealed class AgentRuntime(AgentId agentId, IAgentWorker worker, ILogger
1515
public ILogger<AgentBase> Logger { get; } = logger;
1616
public IAgentBase? AgentInstance { get; set; }
1717
private DistributedContextPropagator DistributedContextPropagator { get; } = distributedContextPropagator;
18-
public (string?, string?) GetTraceIDandState(IDictionary<string, string> metadata)
18+
public (string?, string?) GetTraceIdAndState(IDictionary<string, string> metadata)
1919
{
2020
DistributedContextPropagator.ExtractTraceIdAndState(metadata,
2121
static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
@@ -28,11 +28,11 @@ internal sealed class AgentRuntime(AgentId agentId, IAgentWorker worker, ILogger
2828
out var traceState);
2929
return (traceParent, traceState);
3030
}
31-
public void Update(Activity? activity, RpcRequest request)
31+
public void Update(RpcRequest request, Activity? activity = null)
3232
{
3333
DistributedContextPropagator.Inject(activity, request.Metadata, static (carrier, key, value) => ((IDictionary<string, string>)carrier!)[key] = value);
3434
}
35-
public void Update(Activity? activity, CloudEvent cloudEvent)
35+
public void Update(CloudEvent cloudEvent, Activity? activity = null)
3636
{
3737
DistributedContextPropagator.Inject(activity, cloudEvent.Metadata, static (carrier, key, value) => ((IDictionary<string, string>)carrier!)[key] = value);
3838
}

dotnet/src/Microsoft.AutoGen/Agents/Agents/AIAgent/InferenceAgent.cs

+7-9
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,14 @@
55
using Microsoft.AutoGen.Abstractions;
66
using Microsoft.Extensions.AI;
77
namespace Microsoft.AutoGen.Agents;
8-
public abstract class InferenceAgent<T> : AgentBase where T : IMessage, new()
8+
public abstract class InferenceAgent<T>(
9+
IAgentRuntime context,
10+
EventTypes typeRegistry,
11+
IChatClient client)
12+
: AgentBase(context, typeRegistry)
13+
where T : IMessage, new()
914
{
10-
protected IChatClient ChatClient { get; }
11-
public InferenceAgent(
12-
IAgentRuntime context,
13-
EventTypes typeRegistry, IChatClient client
14-
) : base(context, typeRegistry)
15-
{
16-
ChatClient = client;
17-
}
15+
protected IChatClient ChatClient { get; } = client;
1816

1917
private Task<ChatCompletion> CompleteAsync(
2018
IList<ChatMessage> chatMessages,

dotnet/src/Microsoft.AutoGen/Agents/Services/Orleans/AgentStateGrain.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ namespace Microsoft.AutoGen.Agents;
77

88
internal sealed class AgentStateGrain([PersistentState("state", "AgentStateStore")] IPersistentState<AgentState> state) : Grain, IAgentState
99
{
10-
public async ValueTask<string> WriteStateAsync(AgentState newState, string eTag)
10+
/// <inheritdoc />
11+
public async ValueTask<string> WriteStateAsync(AgentState newState, string eTag, CancellationToken cancellationToken = default)
1112
{
1213
// etags for optimistic concurrency control
1314
// if the Etag is null, its a new state
@@ -27,7 +28,8 @@ public async ValueTask<string> WriteStateAsync(AgentState newState, string eTag)
2728
return state.Etag;
2829
}
2930

30-
public ValueTask<AgentState> ReadStateAsync()
31+
/// <inheritdoc />
32+
public ValueTask<AgentState> ReadStateAsync(CancellationToken cancellationToken = default)
3133
{
3234
return ValueTask.FromResult(state.State);
3335
}

0 commit comments

Comments
 (0)