Skip to content

Commit 1300cc8

Browse files
authored
FEATURE: Add streaming to composer helper (#1256)
This update adding streaming to the AI helper inside the composer.
1 parent 38b4925 commit 1300cc8

File tree

9 files changed

+271
-43
lines changed

9 files changed

+271
-43
lines changed

app/controllers/discourse_ai/ai_helper/assistant_controller.rb

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def suggest
4343
prompt,
4444
input,
4545
current_user,
46-
force_default_locale,
46+
force_default_locale: force_default_locale,
4747
),
4848
status: 200
4949
end
@@ -110,26 +110,44 @@ def suggest_thumbnails(input)
110110
end
111111

112112
def stream_suggestion
113-
post_id = get_post_param!
114113
text = get_text_param!
115-
post = Post.includes(:topic).find_by(id: post_id)
114+
115+
location = params[:location]
116+
raise Discourse::InvalidParameters.new(:location) if !location
117+
116118
prompt = CompletionPrompt.find_by(id: params[:mode])
117119

118120
raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
119-
raise Discourse::InvalidParameters.new(:post_id) unless post
121+
return suggest_thumbnails(input) if prompt.id == CompletionPrompt::ILLUSTRATE_POST
120122

121123
if prompt.id == CompletionPrompt::CUSTOM_PROMPT
122124
raise Discourse::InvalidParameters.new(:custom_prompt) if params[:custom_prompt].blank?
123125
end
124126

125-
Jobs.enqueue(
126-
:stream_post_helper,
127-
post_id: post.id,
128-
user_id: current_user.id,
129-
text: text,
130-
prompt: prompt.name,
131-
custom_prompt: params[:custom_prompt],
132-
)
127+
if location == "composer"
128+
Jobs.enqueue(
129+
:stream_composer_helper,
130+
user_id: current_user.id,
131+
text: text,
132+
prompt: prompt.name,
133+
custom_prompt: params[:custom_prompt],
134+
force_default_locale: params[:force_default_locale] || false,
135+
)
136+
else
137+
post_id = get_post_param!
138+
post = Post.includes(:topic).find_by(id: post_id)
139+
140+
raise Discourse::InvalidParameters.new(:post_id) unless post
141+
142+
Jobs.enqueue(
143+
:stream_post_helper,
144+
post_id: post.id,
145+
user_id: current_user.id,
146+
text: text,
147+
prompt: prompt.name,
148+
custom_prompt: params[:custom_prompt],
149+
)
150+
end
133151

134152
render json: { success: true }, status: 200
135153
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# frozen_string_literal: true
2+
3+
module Jobs
4+
class StreamComposerHelper < ::Jobs::Base
5+
sidekiq_options retry: false
6+
7+
def execute(args)
8+
return unless args[:prompt]
9+
return unless user = User.find_by(id: args[:user_id])
10+
return unless args[:text]
11+
12+
prompt = CompletionPrompt.enabled_by_name(args[:prompt])
13+
14+
if prompt.id == CompletionPrompt::CUSTOM_PROMPT
15+
prompt.custom_instruction = args[:custom_prompt]
16+
end
17+
18+
DiscourseAi::AiHelper::Assistant.new.stream_prompt(
19+
prompt,
20+
args[:text],
21+
user,
22+
"/discourse-ai/ai-helper/stream_composer_suggestion",
23+
force_default_locale: args[:force_default_locale],
24+
)
25+
end
26+
end
27+
end

assets/javascripts/discourse/components/ai-post-helper-menu.gjs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ export default class AiPostHelperMenu extends Component {
237237
this._activeAiRequest = ajax(fetchUrl, {
238238
method: "POST",
239239
data: {
240+
location: "post",
240241
mode: option.id,
241242
text: this.args.data.selectedText,
242243
post_id: this.args.data.quoteState.postId,

assets/javascripts/discourse/components/modal/diff-modal.gjs

Lines changed: 84 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,92 @@
11
import Component from "@glimmer/component";
22
import { tracked } from "@glimmer/tracking";
33
import { action } from "@ember/object";
4+
import didInsert from "@ember/render-modifiers/modifiers/did-insert";
5+
import willDestroy from "@ember/render-modifiers/modifiers/will-destroy";
46
import { service } from "@ember/service";
57
import { htmlSafe } from "@ember/template";
68
import CookText from "discourse/components/cook-text";
79
import DButton from "discourse/components/d-button";
810
import DModal from "discourse/components/d-modal";
11+
import concatClass from "discourse/helpers/concat-class";
912
import { ajax } from "discourse/lib/ajax";
1013
import { popupAjaxError } from "discourse/lib/ajax-error";
14+
import { bind } from "discourse/lib/decorators";
1115
import { i18n } from "discourse-i18n";
16+
import SmoothStreamer from "../../lib/smooth-streamer";
1217
import AiIndicatorWave from "../ai-indicator-wave";
1318

1419
export default class ModalDiffModal extends Component {
1520
@service currentUser;
21+
@service messageBus;
1622

1723
@tracked loading = false;
1824
@tracked diff;
1925
@tracked suggestion = "";
26+
@tracked
27+
smoothStreamer = new SmoothStreamer(
28+
() => this.suggestion,
29+
(newValue) => (this.suggestion = newValue)
30+
);
2031

2132
constructor() {
2233
super(...arguments);
2334
this.suggestChanges();
2435
}
2536

37+
@bind
38+
subscribe() {
39+
const channel = "/discourse-ai/ai-helper/stream_composer_suggestion";
40+
this.messageBus.subscribe(channel, this.updateResult);
41+
}
42+
43+
@bind
44+
unsubscribe() {
45+
const channel = "/discourse-ai/ai-helper/stream_composer_suggestion";
46+
this.messageBus.subscribe(channel, this.updateResult);
47+
}
48+
49+
@action
50+
async updateResult(result) {
51+
if (result) {
52+
this.loading = false;
53+
}
54+
await this.smoothStreamer.updateResult(result, "result");
55+
56+
if (result.done) {
57+
this.diff = result.diff;
58+
}
59+
60+
const mdTablePromptId = this.currentUser?.ai_helper_prompts.find(
61+
(prompt) => prompt.name === "markdown_table"
62+
).id;
63+
64+
// Markdown table prompt looks better with
65+
// before/after results than diff
66+
// despite having `type: diff`
67+
if (this.args.model.mode === mdTablePromptId) {
68+
this.diff = null;
69+
}
70+
}
71+
2672
@action
2773
async suggestChanges() {
74+
this.smoothStreamer.resetStreaming();
75+
this.diff = null;
76+
this.suggestion = "";
2877
this.loading = true;
2978

3079
try {
31-
const suggestion = await ajax("/discourse-ai/ai-helper/suggest", {
80+
return await ajax("/discourse-ai/ai-helper/stream_suggestion", {
3281
method: "POST",
3382
data: {
83+
location: "composer",
3484
mode: this.args.model.mode,
3585
text: this.args.model.selectedText,
3686
custom_prompt: this.args.model.customPromptValue,
3787
force_default_locale: true,
3888
},
3989
});
40-
41-
this.diff = suggestion.diff;
42-
this.suggestion = suggestion.suggestions[0];
4390
} catch (e) {
4491
popupAjaxError(e);
4592
} finally {
@@ -66,24 +113,42 @@ export default class ModalDiffModal extends Component {
66113
@closeModal={{@closeModal}}
67114
>
68115
<:body>
69-
{{#if this.loading}}
70-
<div class="composer-ai-helper-modal__loading">
71-
<CookText @rawText={{@model.selectedText}} />
72-
</div>
73-
{{else}}
74-
{{#if this.diff}}
75-
{{htmlSafe this.diff}}
76-
{{else}}
77-
<div class="composer-ai-helper-modal__old-value">
78-
{{@model.selectedText}}
116+
<div {{didInsert this.subscribe}} {{willDestroy this.unsubscribe}}>
117+
{{#if this.loading}}
118+
<div class="composer-ai-helper-modal__loading">
119+
<CookText @rawText={{@model.selectedText}} />
79120
</div>
80-
81-
<div class="composer-ai-helper-modal__new-value">
82-
{{this.suggestion}}
121+
{{else}}
122+
<div
123+
class={{concatClass
124+
"composer-ai-helper-modal__suggestion"
125+
"streamable-content"
126+
(if this.smoothStreamer.isStreaming "streaming" "")
127+
}}
128+
>
129+
{{#if this.smoothStreamer.isStreaming}}
130+
<CookText
131+
@rawText={{this.smoothStreamer.renderedText}}
132+
class="cooked"
133+
/>
134+
{{else}}
135+
{{#if this.diff}}
136+
{{htmlSafe this.diff}}
137+
{{else}}
138+
<div class="composer-ai-helper-modal__old-value">
139+
{{@model.selectedText}}
140+
</div>
141+
<div class="composer-ai-helper-modal__new-value">
142+
<CookText
143+
@rawText={{this.smoothStreamer.renderedText}}
144+
class="cooked"
145+
/>
146+
</div>
147+
{{/if}}
148+
{{/if}}
83149
</div>
84150
{{/if}}
85-
{{/if}}
86-
151+
</div>
87152
</:body>
88153

89154
<:footer>

assets/javascripts/discourse/components/thumbnail-suggestion-item.gjs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ export default class ThumbnailSuggestionItem extends Component {
1818
return this.args.removeSelection(thumbnail);
1919
}
2020

21-
this.selectIcon = "check-circle";
21+
this.selectIcon = "circle-check";
2222
this.selectLabel = "discourse_ai.ai_helper.thumbnail_suggestions.selected";
2323
this.selected = true;
2424
return this.args.addSelection(thumbnail);

lib/ai_helper/assistant.rb

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def custom_locale_instructions(user = nil, force_default_locale)
8585
end
8686
end
8787

88-
def localize_prompt!(prompt, user = nil, force_default_locale = false)
88+
def localize_prompt!(prompt, user = nil, force_default_locale: false)
8989
locale_instructions = custom_locale_instructions(user, force_default_locale)
9090
if locale_instructions
9191
prompt.messages[0][:content] = prompt.messages[0][:content] + locale_instructions
@@ -128,10 +128,10 @@ def localize_prompt!(prompt, user = nil, force_default_locale = false)
128128
end
129129
end
130130

131-
def generate_prompt(completion_prompt, input, user, force_default_locale = false, &block)
131+
def generate_prompt(completion_prompt, input, user, force_default_locale: false, &block)
132132
llm = helper_llm
133133
prompt = completion_prompt.messages_with_input(input)
134-
localize_prompt!(prompt, user, force_default_locale)
134+
localize_prompt!(prompt, user, force_default_locale: force_default_locale)
135135

136136
llm.generate(
137137
prompt,
@@ -143,8 +143,14 @@ def generate_prompt(completion_prompt, input, user, force_default_locale = false
143143
)
144144
end
145145

146-
def generate_and_send_prompt(completion_prompt, input, user, force_default_locale = false)
147-
completion_result = generate_prompt(completion_prompt, input, user, force_default_locale)
146+
def generate_and_send_prompt(completion_prompt, input, user, force_default_locale: false)
147+
completion_result =
148+
generate_prompt(
149+
completion_prompt,
150+
input,
151+
user,
152+
force_default_locale: force_default_locale,
153+
)
148154
result = { type: completion_prompt.prompt_type }
149155

150156
result[:suggestions] = (
@@ -160,24 +166,37 @@ def generate_and_send_prompt(completion_prompt, input, user, force_default_local
160166
result
161167
end
162168

163-
def stream_prompt(completion_prompt, input, user, channel)
169+
def stream_prompt(completion_prompt, input, user, channel, force_default_locale: false)
170+
streamed_diff = +""
164171
streamed_result = +""
165172
start = Time.now
166173

167-
generate_prompt(completion_prompt, input, user) do |partial_response, cancel_function|
174+
generate_prompt(
175+
completion_prompt,
176+
input,
177+
user,
178+
force_default_locale: force_default_locale,
179+
) do |partial_response, cancel_function|
168180
streamed_result << partial_response
169181

170-
# Throttle the updates
171-
if (Time.now - start > 0.5) || Rails.env.test?
172-
payload = { result: sanitize_result(streamed_result), done: false }
182+
streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff?
183+
184+
# Throttle the updates and
185+
# checking length prevents partial tags
186+
# that aren't sanitized correctly yet (i.e. '<output')
187+
# from being sent in the stream
188+
if (streamed_result.length > 10 && (Time.now - start > 0.3)) || Rails.env.test?
189+
payload = { result: sanitize_result(streamed_result), diff: streamed_diff, done: false }
173190
publish_update(channel, payload, user)
174191
start = Time.now
175192
end
176193
end
177194

195+
final_diff = parse_diff(input, streamed_result) if completion_prompt.diff?
196+
178197
sanitized_result = sanitize_result(streamed_result)
179198
if sanitized_result.present?
180-
publish_update(channel, { result: sanitized_result, done: true }, user)
199+
publish_update(channel, { result: sanitized_result, diff: final_diff, done: true }, user)
181200
end
182201
end
183202

0 commit comments

Comments
 (0)