Commit 66b7b66

Make response_format a first class citizen and update endpoints to support it
1 parent e34413b

18 files changed: +320 −74 lines
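At a glance: response_format is promoted from an extra_model_params passenger to a named argument on DiscourseAi::Completions::Llm#generate, and each endpoint translates it into whatever its API actually supports. A minimal sketch of the intended call site, assuming the reply schema used by fold_content.rb below (the prompt, user, and block body are illustrative):

schema = {
  type: "json_schema",
  json_schema: {
    name: "reply",
    schema: {
      type: "object",
      properties: { summary: { type: "string" } },
      required: ["summary"],
      additionalProperties: false,
    },
    strict: true,
  },
}

llm.generate(prompt, user: Discourse.system_user, response_format: schema) do |partial, cancel|
  # partials stream in as before; the final reply is constrained to the schema
end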

lib/completions/dialects/nova.rb (+1)

@@ -42,6 +42,7 @@ def to_payload(options = nil)
         result = { system: system, messages: messages }
         result[:inferenceConfig] = inference_config if inference_config.present?
         result[:toolConfig] = tool_config if tool_config.present?
+        result[:response_format] = { type: "json_object" } if options[:response_format].present?

         result
       end
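On Nova this surfaces as generic JSON mode rather than a full schema: whatever schema the caller supplied, the payload only asserts { type: "json_object" }. A rough sketch of the resulting payload (system and message content are illustrative, not the dialect's real output):

{
  system: [{ text: "You are a summarizer." }],
  messages: [{ role: "user", content: [{ text: "Summarize this topic." }] }],
  response_format: { type: "json_object" },
}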

lib/completions/endpoints/anthropic.rb (+18 −5)

@@ -88,10 +88,16 @@ def xml_tools_enabled?
       def prepare_payload(prompt, model_params, dialect)
         @native_tool_support = dialect.native_tool_support?

-        payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
+        payload =
+          default_options(dialect).merge(model_params.except(:response_format)).merge(
+            messages: prompt.messages,
+          )

         payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
         payload[:stream] = true if @streaming_mode
+
+        prefilled_message = +""
+
         if prompt.has_tools?
           payload[:tools] = prompt.tools
           if dialect.tool_choice.present?
@@ -100,16 +106,23 @@ def prepare_payload(prompt, model_params, dialect)

              # prefill prompt to nudge the LLM to generate a response that is useful.
              # without this the LLM (even 3.7) can get confused and start text preambles for tool calls.
-             payload[:messages] << {
-               role: "assistant",
-               content: dialect.no_more_tool_calls_text,
-             }
+             prefilled_message << dialect.no_more_tool_calls_text
            else
              payload[:tool_choice] = { type: "tool", name: prompt.tool_choice }
            end
          end
        end

+        # Prefill prompt to force JSON output.
+        if model_params[:response_format].present?
+          prefilled_message << " " if !prefilled_message.empty?
+          prefilled_message << "{"
+        end
+
+        if !prefilled_message.empty?
+          payload[:messages] << { role: "assistant", content: prefilled_message }
+        end
+
         payload
       end
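Anthropic's Messages API has no response_format parameter, so the endpoint leans on assistant prefill: the request ends with a partial assistant turn and Claude continues from it. A sketch of the last message when both the no-more-tools nudge and JSON forcing apply (the nudge text is illustrative; the real string comes from dialect.no_more_tool_calls_text):

payload[:messages].last
# => { role: "assistant", content: "No more tool calls. {" }
#
# Claude then resumes mid-object, e.g. `"summary":"..."}`, which is why the
# downstream parser in fold_content.rb treats the opening brace as optional.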

lib/completions/endpoints/aws_bedrock.rb (+17 −5)

@@ -116,9 +116,14 @@ def prepare_payload(prompt, model_params, dialect)
         payload = nil

         if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
-          payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
+          payload =
+            default_options(dialect).merge(model_params.except(:response_format)).merge(
+              messages: prompt.messages,
+            )
           payload[:system] = prompt.system_prompt if prompt.system_prompt.present?

+          prefilled_message = +""
+
           if prompt.has_tools?
             payload[:tools] = prompt.tools
             if dialect.tool_choice.present?
@@ -128,15 +133,22 @@ def prepare_payload(prompt, model_params, dialect)
                # payload[:tool_choice] = { type: "none" }

                # prefill prompt to nudge the LLM to generate a response that is useful, instead of trying to call a tool
-               payload[:messages] << {
-                 role: "assistant",
-                 content: dialect.no_more_tool_calls_text,
-               }
+               prefilled_message << dialect.no_more_tool_calls_text
             else
               payload[:tool_choice] = { type: "tool", name: prompt.tool_choice }
             end
           end
+
+          # Prefill prompt to force JSON output.
+          if model_params[:response_format].present?
+            prefilled_message << " " if !prefilled_message.empty?
+            prefilled_message << "{"
+          end
+
+          if !prefilled_message.empty?
+            payload[:messages] << { role: "assistant", content: prefilled_message }
+          end
         elsif dialect.is_a?(DiscourseAi::Completions::Dialects::Nova)
           payload = prompt.to_payload(default_options(dialect).merge(model_params))
         else
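Bedrock routes by dialect: Claude models get the same prefill treatment as the native Anthropic endpoint, while Nova models pick the flag up inside Nova#to_payload (the dialect change above). Condensed, the dispatch is roughly:

# Sketch of the branching, not the literal code:
if dialect.is_a?(DiscourseAi::Completions::Dialects::Claude)
  # strip :response_format from model_params, then prefill "{" as a partial assistant turn
elsif dialect.is_a?(DiscourseAi::Completions::Dialects::Nova)
  # pass model_params through; to_payload emits response_format: { type: "json_object" }
end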

lib/completions/endpoints/canned_response.rb (+9)

@@ -40,6 +40,8 @@ def perform_completion!(
             "The number of completions you requested exceed the number of canned responses"
         end

+        response = transform_from_schema(response) if model_params[:response_format].present?
+
         raise response if response.is_a?(StandardError)

         @completions += 1
@@ -80,6 +82,13 @@ def is_thinking?(response)
       def is_tool?(response)
         response.is_a?(DiscourseAi::Completions::ToolCall)
       end
+
+      def transform_from_schema(response)
+        key = model_params[:response_format].dig(:json_schema, :schema, :properties)&.keys&.first
+        return response if key.nil?
+
+        { key => response }.to_json
+      end
     end
   end
 end
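transform_from_schema is what lets the specs below drop their in_json_format helpers: the canned endpoint reads the first property name out of the requested schema and wraps the plain canned string to match. With the reply schema from fold_content.rb:

# model_params[:response_format] declares a single "summary" property, so:
transform_from_schema("dummy")
# => "{\"summary\":\"dummy\"}"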

lib/completions/endpoints/gemini.rb (+10 −1)

@@ -84,7 +84,16 @@ def prepare_payload(prompt, model_params, dialect)

           payload[:tool_config] = { function_calling_config: function_calling_config }
         end
-        payload[:generationConfig].merge!(model_params) if model_params.present?
+        if model_params.present?
+          payload[:generationConfig].merge!(model_params.except(:response_format))
+
+          if model_params[:response_format].present?
+            # https://ai.google.dev/api/generate-content#generationconfig
+            payload[:generationConfig][:responseSchema] = model_params[:response_format]
+            payload[:generationConfig][:responseMimeType] = "application/json"
+          end
+        end
+
         payload
       end
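Gemini supports structured output natively, so no prefill is needed; the schema and MIME type ride along in generationConfig next to the sampling params. The merged config would look roughly like this (temperature illustrative):

payload[:generationConfig]
# => {
#      temperature: 0.7,
#      responseSchema: model_params[:response_format],
#      responseMimeType: "application/json",
#    }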

lib/completions/endpoints/samba_nova.rb (+6 −1)

@@ -34,7 +34,12 @@ def model_uri
       end

       def prepare_payload(prompt, model_params, dialect)
-        payload = default_options.merge(model_params).merge(messages: prompt)
+        payload =
+          default_options.merge(model_params.except(:response_format)).merge(messages: prompt)
+
+        if model_params[:response_format].present?
+          payload[:response_format] = { type: "json_object" }
+        end

         payload[:stream] = true if @streaming_mode
lib/completions/llm.rb (+5 −1)

@@ -241,7 +241,8 @@ def initialize(dialect_klass, gateway_klass, llm_model, gateway: nil)
     # @param feature_context { Hash - Optional } - The feature context to use for the completion.
     # @param partial_tool_calls { Boolean - Optional } - If true, the completion will return partial tool calls.
     # @param output_thinking { Boolean - Optional } - If true, the completion will return the thinking output for thinking models.
-    # @param extra_model_params { Hash - Optional } - Other params that are not available accross models. e.g. response_format JSON schema.
+    # @param response_format { Hash - Optional } - JSON schema passed to the API as the desired structured output.
+    # @param [Experimental] extra_model_params { Hash - Optional } - Other params that are not available across models, e.g. response_format JSON schema.
     #
     # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
     #
@@ -259,6 +260,7 @@ def generate(
       feature_context: nil,
       partial_tool_calls: false,
       output_thinking: false,
+      response_format: nil,
       extra_model_params: nil,
       &partial_read_blk
     )
@@ -274,6 +276,7 @@ def generate(
           feature_context: feature_context,
           partial_tool_calls: partial_tool_calls,
           output_thinking: output_thinking,
+          response_format: response_format,
           extra_model_params: extra_model_params,
         },
       )
@@ -282,6 +285,7 @@ def generate(

       model_params[:temperature] = temperature if temperature
       model_params[:top_p] = top_p if top_p
+      model_params[:response_format] = response_format if response_format
       model_params.merge!(extra_model_params) if extra_model_params

       if prompt.is_a?(String)
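One ordering detail worth noting: response_format is written into model_params before extra_model_params is merged, so the older escape hatch still wins on a key collision. An illustrative call (schema_a and schema_b are hypothetical):

llm.generate(
  prompt,
  user: user,
  response_format: schema_a,
  extra_model_params: { response_format: schema_b }, # merged last, so schema_b wins
)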

lib/summarization/fold_content.rb (+14 −16)

@@ -115,7 +115,7 @@ def fold(items, user, &on_partial_blk)
       # Auxiliary variables to get the summary content from the JSON response.
       raw_buffer = +""
       json_start_found = false
-      json_reply_start_regex = /\{\s*"summary"\s*:\s*"/
+      json_reply_start_regex = /\{?\s*"summary"\s*:\s*"/ # { is optional because Claude uses prefill, so it's not included.
       unescape_regex = %r{\\(["/bfnrt])}
       json_reply_end = "\"}"

@@ -143,7 +143,7 @@ def fold(items, user, &on_partial_blk)
         end
       end

-      bot.reply(context, llm_args: { extra_model_params: response_format }, &buffer_blk)
+      bot.reply(context, llm_args: { response_format: response_format_schema }, &buffer_blk)

      summary.chomp(json_reply_end)
    end
@@ -172,24 +172,22 @@ def truncate(item)
      item
    end

-    def response_format
+    def response_format_schema
      {
-        response_format: {
-          type: "json_schema",
-          json_schema: {
-            name: "reply",
-            schema: {
-              type: "object",
-              properties: {
-                summary: {
-                  type: "string",
-                },
-              },
-              required: ["summary"],
-              additionalProperties: false,
-            },
-            strict: true,
-          },
-        },
+        type: "json_schema",
+        json_schema: {
+          name: "reply",
+          schema: {
+            type: "object",
+            properties: {
+              summary: {
+                type: "string",
+              },
+            },
+            required: ["summary"],
+            additionalProperties: false,
+          },
+          strict: true,
+        },
      }
    end
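A worked example of why the opening brace became optional in json_reply_start_regex: with the Claude prefill, the "{" is part of the request, not the streamed reply. Against the updated pattern:

json_reply_start_regex = /\{?\s*"summary"\s*:\s*"/

# OpenAI-style reply (a complete JSON object):
'{"summary":"A short gist"}'.match?(json_reply_start_regex) # => true
# Claude reply after the "{" prefill (the brace was sent by us):
'"summary":"A short gist"}'.match?(json_reply_start_regex)  # => true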

spec/jobs/regular/fast_track_topic_gist_spec.rb (+3 −7)

@@ -24,13 +24,9 @@

   let(:updated_gist) { "They updated me :(" }

-  def in_json_format(summary)
-    "{\"summary\":\"#{summary}\"}"
-  end
-
   context "when it's up to date" do
     it "does nothing" do
-      DiscourseAi::Completions::Llm.with_prepared_responses([in_json_format(updated_gist)]) do
+      DiscourseAi::Completions::Llm.with_prepared_responses([updated_gist]) do
         subject.execute(topic_id: topic_1.id)
       end

@@ -44,7 +40,7 @@ def in_json_format(summary)
     before { Fabricate(:post, topic: topic_1, post_number: 3) }

     it "regenerates the gist using the latest data" do
-      DiscourseAi::Completions::Llm.with_prepared_responses([in_json_format(updated_gist)]) do
+      DiscourseAi::Completions::Llm.with_prepared_responses([updated_gist]) do
         subject.execute(topic_id: topic_1.id)
       end

@@ -57,7 +53,7 @@ def in_json_format(summary)
     it "does nothing if the gist was created less than 5 minutes ago" do
       ai_gist.update!(created_at: 2.minutes.ago)

-      DiscourseAi::Completions::Llm.with_prepared_responses([in_json_format(updated_gist)]) do
+      DiscourseAi::Completions::Llm.with_prepared_responses([updated_gist]) do
        subject.execute(topic_id: topic_1.id)
      end

spec/jobs/regular/stream_topic_ai_summary_spec.rb (+2 −6)

@@ -50,14 +50,10 @@ def with_responses(responses)
     end
   end

-  def in_json_format(summary)
-    "{\"summary\":\"#{summary}\"}"
-  end
-
   it "publishes updates with a partial summary" do
     summary = "dummy"

-    with_responses([in_json_format(summary)]) do
+    with_responses([summary]) do
       messages =
         MessageBus.track_publish("/discourse-ai/summaries/topic/#{topic.id}") do
           job.execute(topic_id: topic.id, user_id: user.id)
@@ -74,7 +70,7 @@ def in_json_format(summary)
   it "publishes a final update to signal we're done and provide metadata" do
     summary = "dummy"

-    with_responses([in_json_format(summary)]) do
+    with_responses([summary]) do
       messages =
         MessageBus.track_publish("/discourse-ai/summaries/topic/#{topic.id}") do
           job.execute(topic_id: topic.id, user_id: user.id)

spec/jobs/scheduled/summaries_backfill_spec.rb (+1 −5)

@@ -84,10 +84,6 @@
     end
   end

-  def in_json_format(summary)
-    "{\"summary\":\"#{summary}\"}"
-  end
-
   describe "#execute" do
     it "backfills a batch" do
       topic_2 =
@@ -102,7 +98,7 @@ def in_json_format(summary)
       gist_2 = "Updated gist of topic"

       DiscourseAi::Completions::Llm.with_prepared_responses(
-        [gist_1, gist_2, summary_1, summary_2].map { |s| in_json_format(s) },
+        [gist_1, gist_2, summary_1, summary_2],
       ) { subject.execute({}) }

       expect(AiSummary.complete.find_by(target: topic_2).summarized_text).to eq(summary_1)
