Skip to content

Commit 7abc26c

Browse files
committed
refactor: enhance AxAI message handling and signature validation
This commit refines the message creation logic in the AxAIOpenAIChatRequest, improving the handling of user and assistant roles, including better management of tool calls. The AxSignature class is updated to accept a configuration object for inputs and outputs, enhancing flexibility. Additionally, the AxGen class constructor is modified to ensure non-nullable signature parameters, and various debug logging improvements are introduced across the AxGen and AxProgram classes.
1 parent 15c82e3 commit 7abc26c

File tree

7 files changed

+139
-54
lines changed

7 files changed

+139
-54
lines changed

src/ax/ai/openai/api.ts

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -456,16 +456,19 @@ const mapFinishReason = (
456456
function createMessages<TModel>(
457457
req: Readonly<AxInternalChatRequest<TModel>>
458458
): AxAIOpenAIChatRequest<TModel>['messages'] {
459+
type UserContent = Extract<
460+
AxAIOpenAIChatRequest<TModel>['messages'][number],
461+
{ role: 'user' }
462+
>['content']
463+
459464
return req.chatPrompt.map((msg) => {
460465
switch (msg.role) {
461466
case 'system':
462467
return { role: 'system' as const, content: msg.content }
468+
463469
case 'user':
464-
if (Array.isArray(msg.content)) {
465-
return {
466-
role: 'user' as const,
467-
name: msg.name,
468-
content: msg.content.map((c) => {
470+
const content: UserContent = Array.isArray(msg.content)
471+
? msg.content.map((c) => {
469472
switch (c.type) {
470473
case 'text':
471474
return { type: 'text' as const, text: c.text }
@@ -486,27 +489,48 @@ function createMessages<TModel>(
486489
default:
487490
throw new Error('Invalid content type')
488491
}
489-
}),
490-
}
492+
})
493+
: msg.content
494+
return {
495+
role: 'user' as const,
496+
...(msg.name ? { name: msg.name } : {}),
497+
content,
491498
}
492-
return { role: 'user' as const, content: msg.content, name: msg.name }
499+
493500
case 'assistant':
501+
const toolCalls = msg.functionCalls?.map((v) => ({
502+
id: v.id,
503+
type: 'function' as const,
504+
function: {
505+
name: v.function.name,
506+
arguments:
507+
typeof v.function.params === 'object'
508+
? JSON.stringify(v.function.params)
509+
: v.function.params,
510+
},
511+
}))
512+
513+
if (toolCalls && toolCalls.length > 0) {
514+
return {
515+
role: 'assistant' as const,
516+
...(msg.content ? { content: msg.content } : {}),
517+
name: msg.name,
518+
tool_calls: toolCalls,
519+
}
520+
}
521+
522+
if (!msg.content) {
523+
throw new Error(
524+
'Assistant content is required when no tool calls are provided'
525+
)
526+
}
527+
494528
return {
495529
role: 'assistant' as const,
496-
content: msg.content as string,
497-
name: msg.name,
498-
tool_calls: msg.functionCalls?.map((v) => ({
499-
id: v.id,
500-
type: 'function' as const,
501-
function: {
502-
name: v.function.name,
503-
arguments:
504-
typeof v.function.params === 'object'
505-
? JSON.stringify(v.function.params)
506-
: v.function.params,
507-
},
508-
})),
530+
content: msg.content,
531+
...(msg.name ? { name: msg.name } : {}),
509532
}
533+
510534
case 'function':
511535
return {
512536
role: 'tool' as const,

src/ax/ai/openai/chat_types.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,17 @@ export type AxAIOpenAIChatRequest<TModel> = {
124124
text: string
125125
}
126126
name?: string
127-
tool_calls?: {
127+
}
128+
| {
129+
role: 'assistant'
130+
content?:
131+
| string
132+
| {
133+
type: string
134+
text: string
135+
}
136+
name?: string
137+
tool_calls: {
128138
type: 'function'
129139
function: {
130140
name: string

src/ax/dsp/generate.ts

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import type {
1414
AxChatResponse,
1515
AxChatResponseResult,
1616
AxFunction,
17-
AxLoggerFunction,
1817
} from '../ai/types.js'
1918
import { mergeFunctionCalls } from '../ai/util.js'
2019
import { AxMemory } from '../mem/memory.js'
@@ -105,16 +104,14 @@ export class AxGen<
105104
private values: AxGenOutType = {}
106105
private excludeContentFromTrace: boolean = false
107106
private thoughtFieldName: string
108-
private logger?: AxLoggerFunction
109107

110108
constructor(
111-
signature: Readonly<AxSignature | string>,
109+
signature: NonNullable<ConstructorParameters<typeof AxSignature>[0]>,
112110
options?: Readonly<AxProgramForwardOptions>
113111
) {
114112
super(signature, { description: options?.description })
115113

116114
this.options = options
117-
this.logger = options?.logger
118115
this.thoughtFieldName = options?.thoughtFieldName ?? 'thought'
119116
const promptTemplateOptions = {
120117
functions: options?.functions,
@@ -247,7 +244,7 @@ export class AxGen<
247244
traceId,
248245
rateLimiter,
249246
stream,
250-
debug: false,
247+
debug: false, // we do our own debug logging
251248
thinkingTokenBudget,
252249
showThoughts,
253250
traceContext,
@@ -275,7 +272,6 @@ export class AxGen<
275272
}>) {
276273
const { sessionId, traceId, functions: _functions } = options ?? {}
277274
const fastFail = options?.fastFail ?? this.options?.fastFail
278-
279275
const model = options.model
280276

281277
// biome-ignore lint/complexity/useFlatMap: you cannot use flatMap here
@@ -303,6 +299,8 @@ export class AxGen<
303299
fastFail,
304300
span,
305301
})
302+
303+
this.getLogger(ai, options)?.('', { tags: ['responseEnd'] })
306304
} else {
307305
yield await this.processResponse({
308306
ai,
@@ -343,7 +341,6 @@ export class AxGen<
343341
mem.addResult(
344342
{
345343
content: '',
346-
name: 'initial',
347344
functionCalls: [],
348345
},
349346
sessionId
@@ -492,11 +489,6 @@ export class AxGen<
492489
xstate
493490
)
494491
}
495-
496-
if (ai.getOptions().debug) {
497-
const logger = ai.getLogger()
498-
logger('', { tags: ['responseEnd'] })
499-
}
500492
}
501493

502494
private async processResponse({
@@ -589,9 +581,11 @@ export class AxGen<
589581

590582
const maxRetries = options.maxRetries ?? this.options?.maxRetries ?? 10
591583
const maxSteps = options.maxSteps ?? this.options?.maxSteps ?? 10
592-
const debug = options.debug ?? ai.getOptions().debug
593584
const debugHideSystemPrompt = options.debugHideSystemPrompt
594-
const memOptions = { debug, debugHideSystemPrompt }
585+
const memOptions = {
586+
debug: this.isDebug(ai, options),
587+
debugHideSystemPrompt,
588+
}
595589

596590
const mem =
597591
options.mem ?? this.options?.mem ?? new AxMemory(10000, memOptions)
@@ -664,11 +658,7 @@ export class AxGen<
664658
continue multiStepLoop
665659
}
666660

667-
if (debug) {
668-
const logger = options.logger ?? this.logger ?? ai.getLogger()
669-
logger('', { tags: ['responseEnd'] })
670-
}
671-
661+
this.getLogger(ai, options)?.('', { tags: ['responseEnd'] })
672662
return
673663
} catch (e) {
674664
let errorFields: AxIField[] | undefined
@@ -872,6 +862,22 @@ export class AxGen<
872862
super.setExamples(examples, options)
873863
// No need to update prompt template - all fields can be missing in examples
874864
}
865+
866+
private isDebug(
867+
ai: Readonly<AxAIService>,
868+
options?: Readonly<AxProgramForwardOptions>
869+
) {
870+
return (
871+
options?.debug ?? this.options?.debug ?? ai.getOptions().debug ?? false
872+
)
873+
}
874+
875+
private getLogger(
876+
ai: Readonly<AxAIService>,
877+
options?: Readonly<AxProgramForwardOptions>
878+
) {
879+
return options?.logger ?? this.options?.logger ?? ai.getLogger()
880+
}
875881
}
876882

877883
export type AxGenerateErrorDetails = {

src/ax/dsp/program.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ export class AxProgramWithSignature<IN extends AxGenIn, OUT extends AxGenOut>
149149
private children: AxInstanceRegistry<Readonly<AxTunable & AxUsable>>
150150

151151
constructor(
152-
signature: Readonly<AxSignature | string>,
152+
signature: NonNullable<ConstructorParameters<typeof AxSignature>[0]>,
153153
options?: Readonly<AxProgramWithSignatureOptions>
154154
) {
155155
this.signature = new AxSignature(signature)

src/ax/dsp/sig.ts

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ export interface AxField {
2626
| 'datetime'
2727
| 'class'
2828
| 'code'
29-
isArray: boolean
29+
isArray?: boolean
3030
options?: string[]
3131
}
3232
isOptional?: boolean
@@ -46,6 +46,12 @@ class AxSignatureValidationError extends Error {
4646
}
4747
}
4848

49+
export interface AxSignatureConfig {
50+
description?: string
51+
inputs: readonly AxField[]
52+
outputs: readonly AxField[]
53+
}
54+
4955
export class AxSignature {
5056
private description?: string
5157
private inputFields: AxIField[]
@@ -57,7 +63,7 @@ export class AxSignature {
5763
// Validation caching - stores hash when validation last passed
5864
private validatedAtHash?: string
5965

60-
constructor(signature?: Readonly<AxSignature | string>) {
66+
constructor(signature?: Readonly<AxSignature | string | AxSignatureConfig>) {
6167
if (!signature) {
6268
this.inputFields = []
6369
this.outputFields = []
@@ -108,11 +114,47 @@ export class AxSignature {
108114
if (signature.validatedAtHash === this.sigHash) {
109115
this.validatedAtHash = this.sigHash
110116
}
117+
} else if (typeof signature === 'object' && signature !== null) {
118+
// Handle AxSignatureConfig object
119+
if (!('inputs' in signature) || !('outputs' in signature)) {
120+
throw new AxSignatureValidationError(
121+
'Invalid signature object: missing inputs or outputs',
122+
undefined,
123+
'Signature object must have "inputs" and "outputs" arrays. Example: { inputs: [...], outputs: [...] }'
124+
)
125+
}
126+
127+
if (
128+
!Array.isArray(signature.inputs) ||
129+
!Array.isArray(signature.outputs)
130+
) {
131+
throw new AxSignatureValidationError(
132+
'Invalid signature object: inputs and outputs must be arrays',
133+
undefined,
134+
'Both "inputs" and "outputs" must be arrays of AxField objects'
135+
)
136+
}
137+
138+
try {
139+
this.description = signature.description
140+
this.inputFields = signature.inputs.map((v) => this.parseField(v))
141+
this.outputFields = signature.outputs.map((v) => this.parseField(v))
142+
;[this.sigHash, this.sigString] = this.updateHash()
143+
} catch (error) {
144+
if (error instanceof AxSignatureValidationError) {
145+
throw error
146+
}
147+
throw new AxSignatureValidationError(
148+
`Failed to create signature from object: ${error instanceof Error ? error.message : 'Unknown error'}`,
149+
undefined,
150+
'Check that all fields in inputs and outputs arrays are valid AxField objects'
151+
)
152+
}
111153
} else {
112154
throw new AxSignatureValidationError(
113155
'Invalid signature argument type',
114156
undefined,
115-
'Signature must be a string or another AxSignature instance'
157+
'Signature must be a string, another AxSignature instance, or an object with inputs and outputs arrays'
116158
)
117159
}
118160
}
@@ -166,7 +208,7 @@ export class AxSignature {
166208
}
167209
this.description = desc
168210
this.invalidateValidationCache()
169-
this.updateHash()
211+
this.updateHashLight()
170212
}
171213

172214
public addInputField = (field: Readonly<AxField>) => {

src/ax/prompts/agent.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ export class AxAgent<IN extends AxGenIn, OUT extends AxGenOut = AxGenOut>
184184
name: string
185185
description: string
186186
definition?: string
187-
signature: AxSignature | string
187+
signature: NonNullable<ConstructorParameters<typeof AxSignature>[0]>
188188
agents?: AxAgentic[]
189189
functions?: AxInputFunctionType
190190
}>,

src/examples/fibonacci.ts

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
import { AxAI, AxGen, AxJSInterpreter, AxSignature } from '@ax-llm/ax'
1+
import { AxAI, AxGen, AxJSInterpreter } from '@ax-llm/ax'
22

3-
const sig = new AxSignature(
4-
`numberSeriesTask:string -> fibonacciSeries:number[]`
3+
const gen = new AxGen<{ numberSeriesTask: string }>(
4+
{
5+
inputs: [{ name: 'numberSeriesTask', type: { name: 'string' } }],
6+
outputs: [{ name: 'fibonacciSeries', type: { name: 'number' } }],
7+
},
8+
{
9+
functions: [new AxJSInterpreter()],
10+
debug: true,
11+
}
512
)
613

7-
const gen = new AxGen<{ numberSeriesTask: string }>(sig, {
8-
functions: [new AxJSInterpreter()],
9-
})
10-
1114
const ai = new AxAI({
1215
name: 'openai',
1316
apiKey: process.env.OPENAI_APIKEY as string,

0 commit comments

Comments
 (0)