diff --git a/config/runtime.exs b/config/runtime.exs index 617459d..a89cc30 100644 --- a/config/runtime.exs +++ b/config/runtime.exs @@ -24,7 +24,9 @@ case provider do "openai" -> openai_api_key = System.fetch_env!("OPENAI_API_KEY") openai_model = System.fetch_env!("TEXT_GENERATION_MODEL") - openai_max_tokens = System.get_env("TEXT_GENERATION_MAX_NEW_TOKENS", 400) + + openai_max_tokens = + System.get_env("TEXT_GENERATION_MAX_NEW_TOKENS", "400") |> String.to_integer() config :openai, api_key: openai_api_key, @@ -39,7 +41,9 @@ case provider do "generic" -> generic_api_url = System.fetch_env!("TEXT_GENERATION_API_URL") generic_model = System.fetch_env!("TEXT_GENERATION_MODEL") - generic_max_tokens = System.get_env("TEXT_GENERATION_MAX_NEW_TOKENS", 400) + + generic_max_tokens = + System.get_env("TEXT_GENERATION_MAX_NEW_TOKENS", "400") |> String.to_integer() config :echo, Echo.TextGeneration, provider: Echo.TextGeneration.OpenAI @@ -53,8 +57,12 @@ case provider do "bumblebee" -> bb_text_generation_model = System.fetch_env!("TEXT_GENERATION_MODEL") - bb_max_new_tokens = System.get_env("TEXT_GENERATION_MAX_NEW_TOKENS", 400) - bb_max_sequence_length = System.get_env("TEXT_GENERATION_MAX_SEQUENCE_LENGTH", 2048) + + bb_max_new_tokens = + System.get_env("TEXT_GENERATION_MAX_NEW_TOKENS", "400") |> String.to_integer() + + bb_max_sequence_length = + System.get_env("TEXT_GENERATION_MAX_SEQUENCE_LENGTH", "2048") |> String.to_integer() config :echo, Echo.TextGeneration, provider: Echo.TextGeneration.Bumblebee @@ -65,17 +73,17 @@ case provider do end # Speech-to-Text -stt_model_repo = System.fetch_env!("SPEECH_TO_TEXT_MODEL_REPO") +stt_model_repo = System.fetch_env!("SPEECH_TO_TEXT_MODEL") +config :echo, Echo.SpeechToText, provider: Echo.SpeechToText.Bumblebee config :echo, Echo.SpeechToText.Bumblebee, repo: stt_model_repo - # Text-to-Speech eleven_labs_api_key = System.fetch_env!("ELEVEN_LABS_API_KEY") eleven_labs_voice_id = System.get_env("ELEVEN_LABS_VOICE_ID", "21m00Tcm4TlvDq8ikWAM") eleven_labs_model_id = System.get_env("ELEVEN_LABS_MODEL_ID", "eleven_turbo_v2") eleven_labs_optimize_streaming_latency = - System.get_env("ELEVEN_LABS_OPTIMIZE_STREAMING_LATENCY", 2) + System.get_env("ELEVEN_LABS_OPTIMIZE_STREAMING_LATENCY", "2") |> String.to_integer() eleven_labs_output_format = System.get_env("ELEVEN_LABS_OUTPUT_FORMAT", "mp3_22050_32") diff --git a/lib/echo/client/eleven_labs/web_socket.ex b/lib/echo/client/eleven_labs/web_socket.ex index c791bcf..2d7bda1 100644 --- a/lib/echo/client/eleven_labs/web_socket.ex +++ b/lib/echo/client/eleven_labs/web_socket.ex @@ -3,9 +3,19 @@ defmodule Echo.Client.ElevenLabs.WebSocket do require Logger + @reconnect_interval 5000 + @keepalive_interval 15000 + ## Client def start_link(broadcast_fun, token) do + state = %{ + fun: broadcast_fun, + token: token, + keepalive_timer: nil, + reconnect_timer: nil + } + headers = [{"xi-api-key", env(:api_key)}] params = %{ @@ -20,15 +30,15 @@ defmodule Echo.Client.ElevenLabs.WebSocket do |> URI.append_query(URI.encode_query(params)) |> URI.to_string() - WebSockex.start_link(url, __MODULE__, %{fun: broadcast_fun, token: token}, - extra_headers: headers + WebSockex.start_link(url, __MODULE__, state, + extra_headers: headers, + handle_initial_conn_failure: true ) end def open_stream(pid) do msg = Jason.encode!(%{text: " "}) WebSockex.send_frame(pid, {:text, msg}) - pid end @@ -38,8 +48,13 @@ defmodule Echo.Client.ElevenLabs.WebSocket do end def send(pid, text) do - msg = Jason.encode!(%{text: "#{text} ", try_trigger_generation: true}) - WebSockex.send_frame(pid, {:text, msg}) + if Process.alive?(pid) do + msg = Jason.encode!(%{text: "#{text} ", try_trigger_generation: true}) + WebSockex.send_frame(pid, {:text, msg}) + else + Logger.error("WebSocket process is not alive.") + {:error, :not_alive} + end end def flush(pid) do @@ -53,6 +68,21 @@ defmodule Echo.Client.ElevenLabs.WebSocket do ## Server + def handle_connect(_conn, state) do + Logger.info("Connected to ElevenLabs WebSocket") + {:ok, schedule_keepalive(state)} + end + + def handle_disconnect(%{reason: {:local, reason}}, state) do + Logger.warning("Local disconnect: #{inspect(reason)}. Reconnecting...") + {:reconnect, schedule_reconnect(state)} + end + + def handle_disconnect(disconnect_map, state) do + Logger.warning("Disconnected: #{inspect(disconnect_map)}. Reconnecting...") + {:reconnect, schedule_reconnect(state)} + end + def handle_cast({:update_token, {:binary, token}}, state) do {:ok, %{state | token: token}} end @@ -71,5 +101,28 @@ defmodule Echo.Client.ElevenLabs.WebSocket do {:ok, state} end + def handle_info(:keepalive, state) do + Logger.debug("Sending keepalive") + msg = Jason.encode!(%{text: " "}) + {:reply, {:text, msg}, schedule_keepalive(state)} + end + + def handle_info(:reconnect, state) do + Logger.info("Attempting to reconnect...") + {:reconnect, state} + end + + defp schedule_keepalive(state) do + if state.keepalive_timer, do: Process.cancel_timer(state.keepalive_timer) + timer = Process.send_after(self(), :keepalive, @keepalive_interval) + %{state | keepalive_timer: timer} + end + + defp schedule_reconnect(state) do + if state.reconnect_timer, do: Process.cancel_timer(state.reconnect_timer) + timer = Process.send_after(self(), :reconnect, @reconnect_interval) + %{state | reconnect_timer: timer} + end + defp env(key), do: :echo |> Application.fetch_env!(__MODULE__) |> Keyword.fetch!(key) end diff --git a/lib/echo/text_generation/openai.ex b/lib/echo/text_generation/openai.ex index f2bacb3..9c2dada 100644 --- a/lib/echo/text_generation/openai.ex +++ b/lib/echo/text_generation/openai.ex @@ -1,11 +1,21 @@ defmodule Echo.TextGeneration.OpenAI do @behaviour Echo.TextGeneration.Provider + require Logger + @impl true - def chat_completion(messages) do - opts = Keyword.merge([messages: messages], config()) + def chat_completion(opts) when is_list(opts) do + config = config() + merged_opts = Keyword.merge(config, opts) + + # Ensure the model is a string + model = Keyword.get(merged_opts, :model) + model = if is_tuple(model), do: elem(model, 1), else: model + merged_opts = Keyword.put(merged_opts, :model, model) + + Logger.debug("Sending chat completion request with options: #{inspect(merged_opts)}") - OpenAI.chat_completion(opts) + OpenAI.chat_completion(merged_opts) |> Stream.map(&get_in(&1, ["choices", Access.at(0), "delta", "content"])) |> Stream.reject(&is_nil/1) end diff --git a/lib/echo/text_to_speech.ex b/lib/echo/text_to_speech.ex index 04152c8..ea3bbb0 100644 --- a/lib/echo/text_to_speech.ex +++ b/lib/echo/text_to_speech.ex @@ -3,29 +3,46 @@ defmodule Echo.TextToSpeech do Generic TTS module. """ alias Echo.Client.ElevenLabs + require Logger @separators [".", ",", "?", "!", ";", ":", "—", "-", "(", ")", "[", "]", "}", " "] + defmodule Error do + defexception [:message, :reason] + + @type t :: %__MODULE__{ + message: String.t(), + reason: :connection_closed | :send_failed | :flush_failed | any() + } + end + @doc """ Consumes an Enumerable (such as a stream) of text into speech, applying `fun` to each audio element. Returns the spoken text contained within `enumerable`. + + Raises `Echo.TextToSpeech.Error` if an error occurs during streaming. """ + @spec stream(Enumerable.t(), pid()) :: String.t() | no_return() def stream(enumerable, pid) do result = enumerable |> group_tokens() |> Stream.map(fn text -> text = IO.iodata_to_binary(text) - ElevenLabs.WebSocket.send(pid, text) - text + + case send_text(pid, text) do + :ok -> text + {:error, reason} -> raise Error, message: "WebSocket send failed", reason: reason + end end) |> Enum.join() - ElevenLabs.WebSocket.flush(pid) - - result + case flush_websocket(pid) do + :ok -> result + {:error, reason} -> raise Error, message: "WebSocket flush failed", reason: reason + end end defp group_tokens(stream) do @@ -39,4 +56,29 @@ defmodule Echo.TextToSpeech do end end) end + + defp send_text(pid, text) do + case ElevenLabs.WebSocket.send(pid, text) do + :ok -> + :ok + + {:error, :not_alive} -> + Logger.error("WebSocket connection is closed.") + {:error, :connection_closed} + + {:error, reason} -> + Logger.error("Failed to send text: #{inspect(reason)}") + {:error, :send_failed} + end + end + + defp flush_websocket(pid) do + try do + ElevenLabs.WebSocket.flush(pid) + rescue + e -> + Logger.error("Failed to flush WebSocket: #{inspect(e)}") + {:error, :flush_failed} + end + end end