
FEATURE: Add streaming to composer helper #1256

Merged · 6 commits · Apr 14, 2025
app/controllers/discourse_ai/ai_helper/assistant_controller.rb (30 additions, 12 deletions)
@@ -43,7 +43,7 @@ def suggest
prompt,
input,
current_user,
force_default_locale,
force_default_locale: force_default_locale,
),
status: 200
end
@@ -110,26 +110,44 @@ def suggest_thumbnails(input)
end

def stream_suggestion
post_id = get_post_param!
text = get_text_param!
post = Post.includes(:topic).find_by(id: post_id)

location = params[:location]
raise Discourse::InvalidParameters.new(:location) if !location

prompt = CompletionPrompt.find_by(id: params[:mode])

raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
raise Discourse::InvalidParameters.new(:post_id) unless post
return suggest_thumbnails(input) if prompt.id == CompletionPrompt::ILLUSTRATE_POST

if prompt.id == CompletionPrompt::CUSTOM_PROMPT
raise Discourse::InvalidParameters.new(:custom_prompt) if params[:custom_prompt].blank?
end

Jobs.enqueue(
:stream_post_helper,
post_id: post.id,
user_id: current_user.id,
text: text,
prompt: prompt.name,
custom_prompt: params[:custom_prompt],
)
if location == "composer"
Jobs.enqueue(
:stream_composer_helper,
user_id: current_user.id,
text: text,
prompt: prompt.name,
custom_prompt: params[:custom_prompt],
force_default_locale: params[:force_default_locale] || false,
)
else
post_id = get_post_param!
post = Post.includes(:topic).find_by(id: post_id)

raise Discourse::InvalidParameters.new(:post_id) unless post

Jobs.enqueue(
:stream_post_helper,
post_id: post.id,
user_id: current_user.id,
text: text,
prompt: prompt.name,
custom_prompt: params[:custom_prompt],
)
end

render json: { success: true }, status: 200
rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed
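For orientation, here is a minimal request-spec sketch (not part of this PR's diff) exercising the new `location: "composer"` branch of `stream_suggestion`. The prompt name "proofread", the seeded CompletionPrompt, and the argument matching behaviour of `expect_enqueued_with` are assumptions.

```ruby
# Hypothetical request spec, sketched only to illustrate the new composer branch.
RSpec.describe DiscourseAi::AiHelper::AssistantController do
  fab!(:user) { Fabricate(:user) }

  it "enqueues the composer streaming job instead of the post job" do
    sign_in(user)
    prompt = CompletionPrompt.enabled_by_name("proofread") # assumed seeded prompt

    expect_enqueued_with(
      job: :stream_composer_helper,
      args: { user_id: user.id, text: "a rough draft", prompt: prompt.name },
    ) do
      post "/discourse-ai/ai-helper/stream_suggestion.json",
           params: { location: "composer", mode: prompt.id, text: "a rough draft" }
    end

    expect(response.status).to eq(200)
  end
end
```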
app/jobs/regular/stream_composer_helper.rb (27 additions, 0 deletions)
@@ -0,0 +1,27 @@
# frozen_string_literal: true

module Jobs
class StreamComposerHelper < ::Jobs::Base
sidekiq_options retry: false

def execute(args)
return unless args[:prompt]
return unless user = User.find_by(id: args[:user_id])
return unless args[:text]

prompt = CompletionPrompt.enabled_by_name(args[:prompt])

if prompt.id == CompletionPrompt::CUSTOM_PROMPT
prompt.custom_instruction = args[:custom_prompt]
end

DiscourseAi::AiHelper::Assistant.new.stream_prompt(
prompt,
args[:text],
user,
"/discourse-ai/ai-helper/stream_composer_suggestion",
force_default_locale: args[:force_default_locale],
)
end
end
end
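
As a usage note (not part of the diff): the controller's composer branch enqueues this job with the arguments shown below. The text and prompt name are placeholders; the channel string comes from the `stream_prompt` call above.

```ruby
# Illustrative enqueue, mirroring the composer branch in the controller.
Jobs.enqueue(
  :stream_composer_helper,
  user_id: current_user.id,
  text: "teh draft with typoes", # placeholder composer text
  prompt: "proofread",           # placeholder CompletionPrompt name
  custom_prompt: nil,
  force_default_locale: false,
)

# When executed, the job streams partial results to that user over MessageBus
# on the "/discourse-ai/ai-helper/stream_composer_suggestion" channel, which
# the diff modal below subscribes to.
```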
@@ -237,6 +237,7 @@ export default class AiPostHelperMenu extends Component {
this._activeAiRequest = ajax(fetchUrl, {
method: "POST",
data: {
location: "post",
mode: option.id,
text: this.args.data.selectedText,
post_id: this.args.data.quoteState.postId,
assets/javascripts/discourse/components/modal/diff-modal.gjs (84 additions, 19 deletions)
@@ -1,45 +1,92 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { action } from "@ember/object";
import didInsert from "@ember/render-modifiers/modifiers/did-insert";
import willDestroy from "@ember/render-modifiers/modifiers/will-destroy";
import { service } from "@ember/service";
import { htmlSafe } from "@ember/template";
import CookText from "discourse/components/cook-text";
import DButton from "discourse/components/d-button";
import DModal from "discourse/components/d-modal";
import concatClass from "discourse/helpers/concat-class";
import { ajax } from "discourse/lib/ajax";
import { popupAjaxError } from "discourse/lib/ajax-error";
import { bind } from "discourse/lib/decorators";
import { i18n } from "discourse-i18n";
import SmoothStreamer from "../../lib/smooth-streamer";
import AiIndicatorWave from "../ai-indicator-wave";

export default class ModalDiffModal extends Component {
@service currentUser;
@service messageBus;

@tracked loading = false;
@tracked diff;
@tracked suggestion = "";
@tracked
smoothStreamer = new SmoothStreamer(
() => this.suggestion,
(newValue) => (this.suggestion = newValue)
);

constructor() {
super(...arguments);
this.suggestChanges();
}

@bind
subscribe() {
const channel = "/discourse-ai/ai-helper/stream_composer_suggestion";
this.messageBus.subscribe(channel, this.updateResult);
}

@bind
unsubscribe() {
const channel = "/discourse-ai/ai-helper/stream_composer_suggestion";
this.messageBus.unsubscribe(channel, this.updateResult);
}

@action
async updateResult(result) {
if (result) {
this.loading = false;
}
await this.smoothStreamer.updateResult(result, "result");

if (result.done) {
this.diff = result.diff;
}

const mdTablePromptId = this.currentUser?.ai_helper_prompts.find(
(prompt) => prompt.name === "markdown_table"
).id;

// Markdown table prompt looks better with
// before/after results than diff
// despite having `type: diff`
if (this.args.model.mode === mdTablePromptId) {
this.diff = null;
}
}

@action
async suggestChanges() {
this.smoothStreamer.resetStreaming();
this.diff = null;
this.suggestion = "";
this.loading = true;

try {
const suggestion = await ajax("/discourse-ai/ai-helper/suggest", {
return await ajax("/discourse-ai/ai-helper/stream_suggestion", {
method: "POST",
data: {
location: "composer",
mode: this.args.model.mode,
text: this.args.model.selectedText,
custom_prompt: this.args.model.customPromptValue,
force_default_locale: true,
},
});

this.diff = suggestion.diff;
this.suggestion = suggestion.suggestions[0];
} catch (e) {
popupAjaxError(e);
} finally {
@@ -66,24 +113,42 @@ export default class ModalDiffModal extends Component {
@closeModal={{@closeModal}}
>
<:body>
{{#if this.loading}}
<div class="composer-ai-helper-modal__loading">
<CookText @rawText={{@model.selectedText}} />
</div>
{{else}}
{{#if this.diff}}
{{htmlSafe this.diff}}
{{else}}
<div class="composer-ai-helper-modal__old-value">
{{@model.selectedText}}
<div {{didInsert this.subscribe}} {{willDestroy this.unsubscribe}}>
{{#if this.loading}}
<div class="composer-ai-helper-modal__loading">
<CookText @rawText={{@model.selectedText}} />
</div>

<div class="composer-ai-helper-modal__new-value">
{{this.suggestion}}
{{else}}
<div
class={{concatClass
"composer-ai-helper-modal__suggestion"
"streamable-content"
(if this.smoothStreamer.isStreaming "streaming" "")
}}
>
{{#if this.smoothStreamer.isStreaming}}
<CookText
@rawText={{this.smoothStreamer.renderedText}}
class="cooked"
/>
{{else}}
{{#if this.diff}}
{{htmlSafe this.diff}}
{{else}}
<div class="composer-ai-helper-modal__old-value">
{{@model.selectedText}}
</div>
<div class="composer-ai-helper-modal__new-value">
<CookText
@rawText={{this.smoothStreamer.renderedText}}
class="cooked"
/>
</div>
{{/if}}
{{/if}}
</div>
{{/if}}
{{/if}}

</div>
</:body>

<:footer>
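For orientation (not part of the diff): `updateResult` above consumes the payloads that `stream_prompt` publishes. A rough sketch of those messages, assuming `publish_update` wraps `MessageBus.publish` scoped to the requesting user; the diff markup shown is a placeholder.

```ruby
# Mid-stream message: partial sanitized text; diff is only present for diff-type prompts.
MessageBus.publish(
  "/discourse-ai/ai-helper/stream_composer_suggestion",
  { result: "partially streamed suggestion", diff: nil, done: false },
  user_ids: [user.id],
)

# Final message: full sanitized text plus the computed diff.
MessageBus.publish(
  "/discourse-ai/ai-helper/stream_composer_suggestion",
  { result: "full suggestion", diff: "<ins>full suggestion</ins>", done: true },
  user_ids: [user.id],
)
```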
@@ -18,7 +18,7 @@ export default class ThumbnailSuggestionItem extends Component {
return this.args.removeSelection(thumbnail);
}

this.selectIcon = "check-circle";
this.selectIcon = "circle-check";
this.selectLabel = "discourse_ai.ai_helper.thumbnail_suggestions.selected";
this.selected = true;
return this.args.addSelection(thumbnail);
lib/ai_helper/assistant.rb (30 additions, 11 deletions)
@@ -85,7 +85,7 @@ def custom_locale_instructions(user = nil, force_default_locale)
end
end

def localize_prompt!(prompt, user = nil, force_default_locale = false)
def localize_prompt!(prompt, user = nil, force_default_locale: false)
locale_instructions = custom_locale_instructions(user, force_default_locale)
if locale_instructions
prompt.messages[0][:content] = prompt.messages[0][:content] + locale_instructions
@@ -128,10 +128,10 @@ def localize_prompt!(prompt, user = nil, force_default_locale = false)
end
end

def generate_prompt(completion_prompt, input, user, force_default_locale = false, &block)
def generate_prompt(completion_prompt, input, user, force_default_locale: false, &block)
llm = helper_llm
prompt = completion_prompt.messages_with_input(input)
localize_prompt!(prompt, user, force_default_locale)
localize_prompt!(prompt, user, force_default_locale: force_default_locale)

llm.generate(
prompt,
@@ -143,8 +143,14 @@ def generate_and_send_prompt(completion_prompt, input, user, force_default_locale = false
)
end

def generate_and_send_prompt(completion_prompt, input, user, force_default_locale = false)
completion_result = generate_prompt(completion_prompt, input, user, force_default_locale)
def generate_and_send_prompt(completion_prompt, input, user, force_default_locale: false)
completion_result =
generate_prompt(
completion_prompt,
input,
user,
force_default_locale: force_default_locale,
)
result = { type: completion_prompt.prompt_type }

result[:suggestions] = (
@@ -160,24 +166,37 @@ def generate_and_send_prompt(completion_prompt, input, user, force_default_local
result
end

def stream_prompt(completion_prompt, input, user, channel)
def stream_prompt(completion_prompt, input, user, channel, force_default_locale: false)
streamed_diff = +""
streamed_result = +""
start = Time.now

generate_prompt(completion_prompt, input, user) do |partial_response, cancel_function|
generate_prompt(
completion_prompt,
input,
user,
force_default_locale: force_default_locale,
) do |partial_response, cancel_function|
streamed_result << partial_response

# Throttle the updates
if (Time.now - start > 0.5) || Rails.env.test?
payload = { result: sanitize_result(streamed_result), done: false }
streamed_diff = parse_diff(input, partial_response) if completion_prompt.diff?

# Throttle the updates; checking the length also prevents partial tags
# that aren't sanitized correctly yet (e.g. '<output')
# from being sent in the stream
if (streamed_result.length > 10 && (Time.now - start > 0.3)) || Rails.env.test?
payload = { result: sanitize_result(streamed_result), diff: streamed_diff, done: false }
publish_update(channel, payload, user)
start = Time.now
end
end

final_diff = parse_diff(input, streamed_result) if completion_prompt.diff?

sanitized_result = sanitize_result(streamed_result)
if sanitized_result.present?
publish_update(channel, { result: sanitized_result, done: true }, user)
publish_update(channel, { result: sanitized_result, diff: final_diff, done: true }, user)
end
end

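To summarize the signature change (illustration only): `force_default_locale` moves from a positional to a keyword argument, and `stream_prompt` gains the same keyword, so call sites change as follows.

```ruby
assistant = DiscourseAi::AiHelper::Assistant.new

# Before this PR: positional argument.
assistant.generate_and_send_prompt(completion_prompt, input, user, true)

# After this PR: keyword argument (defaults to false).
assistant.generate_and_send_prompt(completion_prompt, input, user, force_default_locale: true)

# stream_prompt accepts the same keyword and forwards it to generate_prompt.
assistant.stream_prompt(completion_prompt, input, user, channel, force_default_locale: true)
```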