Fix Async Stream Contexts #1257

Merged
merged 8 commits on Apr 28, 2025
Changes from all commits
6 changes: 6 additions & 0 deletions guardrails/llm_providers.py
@@ -218,6 +218,7 @@ def _invoke_llm(
llm_response = cast(Iterator[str], response)
return LLMResponse(
output="",
# FIXME: Why is this different from the async streaming implementation?
stream_output=llm_response,
)

@@ -491,6 +492,7 @@ def _invoke_llm(self, *args, **kwargs) -> LLMResponse:
llm_response = cast(Iterator[str], llm_response)
return LLMResponse(
output="",
# FIXME: Why is this different from the async streaming implementation?
stream_output=llm_response,
)

@@ -685,6 +687,8 @@ async def invoke_llm(
# response = cast(AsyncIterator[str], response)
return LLMResponse(
output="",
# FIXME: Why is this different from the synchronous streaming implementation? ## noqa: E501
# This shouldn't be necessary: https://docs.litellm.ai/docs/completion/stream#async-streaming
async_stream_output=response.completion_stream, # pyright: ignore[reportGeneralTypeIssues]
)

@@ -842,6 +846,8 @@ async def invoke_llm(self, *args, **kwargs) -> LLMResponse:
# the callable returns a generator object
return LLMResponse(
output="",
# FIXME: Why is this different from the synchronous streaming implementation? ## noqa: E501
# This shouldn't be necessary: https://docs.litellm.ai/docs/completion/stream#async-streaming
async_stream_output=output.completion_stream,
)

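The repeated FIXME comments above flag an asymmetry between the two streaming paths: the synchronous callables hand the raw iterator to `stream_output`, while the async callables reach into `response.completion_stream` / `output.completion_stream`, even though the linked litellm docs suggest the async response itself should be iterable. A minimal sketch of what consuming each field looks like — the function names are illustrative, not part of the guardrails API, and the chunks are assumed here to be plain strings:

```python
from typing import AsyncIterator, Iterator


def consume_sync(stream_output: Iterator[str]) -> str:
    # Sync path: LLMResponse.stream_output is a plain iterator of text chunks.
    return "".join(chunk for chunk in stream_output)


async def consume_async(async_stream_output: AsyncIterator[str]) -> str:
    # Async path: LLMResponse.async_stream_output is async-iterable,
    # e.g. the stream returned by litellm's acompletion(..., stream=True).
    pieces = []
    async for chunk in async_stream_output:
        pieces.append(chunk)
    return "".join(pieces)
```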
322 changes: 182 additions & 140 deletions guardrails/run/async_stream_runner.py

Large diffs are not rendered by default.

86 changes: 63 additions & 23 deletions guardrails/run/stream_runner.py
@@ -17,6 +17,7 @@
from guardrails.actions.reask import ReAsk, SkeletonReAsk
from guardrails.constants import pass_status
from guardrails.telemetry import trace_stream_step
from guardrails.utils.safe_get import safe_get


class StreamRunner(Runner):
@@ -240,35 +241,74 @@ def prepare_chunk_generator(stream) -> Iterator[Tuple[Any, bool]]:
def is_last_chunk(self, chunk: Any, api: Union[PromptCallableBase, None]) -> bool:
"""Detect if chunk is final chunk."""
try:
if (
not chunk.choices or len(chunk.choices) == 0
) and chunk.usage is not None:
# This is the last extra chunk for usage statistics
return True
finished = chunk.choices[0].finish_reason
return finished is not None
except (AttributeError, TypeError):
return False

def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str:
"""Get the text from a chunk."""
chunk_text = ""
try:
finished = chunk.choices[0].finish_reason
content = chunk.choices[0].delta.content
if not finished and content:
chunk_text = content
except Exception:
try:
finished = chunk.choices[0].finish_reason
content = chunk.choices[0].text
if not finished and content:
chunk_text = content
except Exception:
try:
chunk_text = chunk
except Exception as e:
raise ValueError(
f"Error getting chunk from stream: {e}. "
"Non-OpenAI API callables expected to return "
"a generator of strings."
) from e
return chunk_text
"""Get the text from a chunk.

chunk is assumed to be an Iterator of either string or
ChatCompletionChunk

These types are not properly enforced upstream so we must use
reflection
"""
# Safeguard against None
# which can happen when the user provides
# custom LLM wrappers
if not chunk:
return ""
elif isinstance(chunk, str):
# If chunk is a string, return it
return chunk
elif hasattr(chunk, "choices") and hasattr(chunk.choices, "__iter__"):
# If chunk is a ChatCompletionChunk, return the text
# from the first choice
chunk_text = ""
first_choice = safe_get(chunk.choices, 0)
if not first_choice:
return chunk_text

if hasattr(first_choice, "delta") and hasattr(
first_choice.delta, "content"
):
chunk_text = first_choice.delta.content
elif hasattr(first_choice, "text"):
chunk_text = first_choice.text
else:
# If chunk is not a string or ChatCompletionChunk, raise an error
raise ValueError(
"chunk.choices[0] does not have "
"delta.content or text. "
"Non-OpenAI compliant callables must return "
"a generator of strings."
)

if not chunk_text:
# If chunk_text is empty, return an empty string
return ""
elif not isinstance(chunk_text, str):
# If chunk_text is not a string, raise an error
raise ValueError(
"Chunk text is not a string. "
"Non-OpenAI compliant callables must return "
"a generator of strings."
)
return chunk_text
else:
# If chunk is not a string or ChatCompletionChunk, raise an error
raise ValueError(
"Chunk is not a string or ChatCompletionChunk. "
"Non-OpenAI compliant callables must return "
"a generator of strings."
)

def parse(
self, output: str, output_schema: Dict[str, Any], *, verified: set, **kwargs
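The reworked `get_chunk_text` above replaces nested try/except probing with explicit duck-typing on the chunk shape. A condensed, illustration-only paraphrase of that dispatch follows; the `SimpleNamespace` stand-ins are hypothetical test doubles, not real OpenAI chunk objects:

```python
from types import SimpleNamespace
from typing import Any


def chunk_text_sketch(chunk: Any) -> str:
    if not chunk:
        return ""  # guard against None from custom LLM wrappers
    if isinstance(chunk, str):
        return chunk  # custom wrappers may yield plain strings
    choices = getattr(chunk, "choices", None)
    if choices is None or not hasattr(choices, "__iter__"):
        raise ValueError(
            "Chunk is not a string or ChatCompletionChunk. "
            "Non-OpenAI compliant callables must return a generator of strings."
        )
    first = next(iter(choices), None)
    if first is None:
        return ""
    if hasattr(first, "delta") and hasattr(first.delta, "content"):
        text = first.delta.content  # chat-completion style chunk
    elif hasattr(first, "text"):
        text = first.text  # legacy completion style chunk
    else:
        raise ValueError(
            "chunk.choices[0] does not have delta.content or text. "
            "Non-OpenAI compliant callables must return a generator of strings."
        )
    if text and not isinstance(text, str):
        raise ValueError("Chunk text is not a string.")
    return text or ""


# Both chunk shapes resolve to the same text:
chat_chunk = SimpleNamespace(
    choices=[SimpleNamespace(delta=SimpleNamespace(content="hi"), finish_reason=None)]
)
assert chunk_text_sketch("hi") == chunk_text_sketch(chat_chunk) == "hi"
```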
7 changes: 5 additions & 2 deletions guardrails/utils/safe_get.py
@@ -12,9 +12,12 @@ def safe_get_with_brackets(
return value
except Exception as e:
logger.debug(
f"Failed to get value for key: {key} out of container: {container}!"
f"""
Failed to get value for key: {key} out of container: {container}.
Reason: {e}
Falling back to default value...
"""
)
logger.debug(e)
return default


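For context, the fallback behaviour that `safe_get_with_brackets` provides — swallow the lookup error, log it, and return the default — looks roughly like the sketch below; the names and exact signature are assumptions rather than the guardrails implementation:

```python
import logging
from typing import Any

logger = logging.getLogger(__name__)


def safe_get_sketch(container: Any, key: Any, default: Any = None) -> Any:
    try:
        return container[key]
    except Exception as e:  # IndexError, KeyError, TypeError, ...
        logger.debug(
            "Failed to get value for key: %s out of container: %s. Reason: %s. "
            "Falling back to default value...",
            key, container, e,
        )
        return default


# Out-of-range index falls back to the supplied default instead of raising.
assert safe_get_sketch([1, 2, 3], 5, default=0) == 0
```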
46 changes: 37 additions & 9 deletions guardrails/validator_base.py
@@ -4,7 +4,7 @@
# - [ ] Remove validator_base.py in 0.6.x

import asyncio
import contextlib
from contextvars import Context, ContextVar
from functools import partial
import inspect
import logging
@@ -67,10 +67,8 @@ def split_sentence_word_tokenizers_jl_separator(
# we check for a . to avoid wastefully calling the tokenizer

# check at least 3 characters have been accumulated before splitting
is_minimum_length = False
with contextlib.suppress(IndexError):
chunk[2]
is_minimum_length = True
third_chunk = safe_get(chunk, 2)
is_minimum_length = third_chunk is not None

# check for potential line endings, which is what split_sentences does
chunk_with_potential_line_endings, count = re.subn(
@@ -292,7 +290,14 @@ def _chunking_function(self, chunk: str) -> List[str]:
return split_sentence_word_tokenizers_jl_separator(chunk)

def validate_stream(
self, chunk: Any, metadata: Dict[str, Any], **kwargs
self,
chunk: Any,
metadata: Dict[str, Any],
*,
property_path: Optional[str] = "$",
context_vars: Optional[ContextVar[Dict[str, ContextVar[List[str]]]]] = None,
context: Optional[Context] = None,
**kwargs,
) -> Optional[ValidationResult]:
"""Validates a chunk emitted by an LLM. If the LLM chunk is smaller
than the validator's chunking strategy, it will be accumulated until it
@@ -307,8 +312,20 @@ def validate_stream(
result.
"""
# combine accumulated chunks and the new chunk
self.accumulated_chunks.append(chunk)
accumulated_text = "".join(self.accumulated_chunks)
accumulated_chunks = self.accumulated_chunks

# if context_vars is passed, use it to get the accumulated chunks
context_var: Optional[ContextVar[List[str]]] = None
ctx_var_map: Optional[Dict[str, ContextVar[List[str]]]] = None
context_key = f"{property_path}_{self.rail_alias}"
if context_vars and context:
ctx_var_map = context.run(context_vars.get)
context_var = ctx_var_map.get(context_key)
if context_var:
accumulated_chunks = context.run(context_var.get)

accumulated_chunks.append(chunk)
accumulated_text = "".join(accumulated_chunks)
# check if enough chunks have accumulated for validation
split_contents = self._chunking_function(accumulated_text)

@@ -318,9 +335,20 @@
split_contents = [accumulated_text, ""]
# if no chunks are returned, we haven't accumulated enough
if len(split_contents) == 0:
if context_vars and context_var and context and ctx_var_map:
context.run(context_var.set, accumulated_chunks)
ctx_var_map[context_key] = context_var
context.run(context_vars.set, ctx_var_map)
else:
self.accumulated_chunks = accumulated_chunks
return None
[chunk_to_validate, new_accumulated_chunks] = split_contents
self.accumulated_chunks = [new_accumulated_chunks]
if context_vars and context_var and context and ctx_var_map:
context.run(context_var.set, [new_accumulated_chunks])
ctx_var_map[context_key] = context_var
context.run(context_vars.set, ctx_var_map)
else:
self.accumulated_chunks = [new_accumulated_chunks]
# exclude last chunk, because it may not be a complete chunk
validation_result = self.validate(chunk_to_validate, metadata)
# if validate doesn't set validated chunk, we set it
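The new keyword-only `property_path`, `context_vars`, and `context` parameters let `validate_stream` keep its accumulated chunks inside a `Context`, keyed by `f"{property_path}_{self.rail_alias}"`, instead of on the validator instance, so concurrent async streams no longer share state. A hedged sketch of that bookkeeping pattern — only the `Context`/`ContextVar` usage mirrors the diff; the key and values below are made up for illustration:

```python
from contextvars import Context, ContextVar, copy_context
from typing import Dict, List

# Outer map from "<property_path>_<rail_alias>" to that stream's chunk list.
context_vars: ContextVar[Dict[str, ContextVar[List[str]]]] = ContextVar(
    "accumulated_chunks_map", default={}
)
context: Context = copy_context()

# Hypothetical key, built the same way as in validate_stream.
context_key = "$.summary_two-words"
chunks_var: ContextVar[List[str]] = ContextVar(context_key, default=[])

# Store this stream's accumulation inside the Context rather than on `self`.
context.run(chunks_var.set, ["Hello, "])
mapping = context.run(context_vars.get)
mapping[context_key] = chunks_var
context.run(context_vars.set, mapping)

# A later chunk for the same property path reads the accumulation back from
# the same Context, so parallel async streams never see each other's state.
assert context.run(chunks_var.get) == ["Hello, "]
assert context.run(context_vars.get)[context_key] is chunks_var
```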
6 changes: 6 additions & 0 deletions guardrails/validator_service/async_validator_service.py
@@ -84,6 +84,8 @@ async def run_validator(
metadata: Dict,
absolute_property_path: str,
stream: Optional[bool] = False,
*,
reference_path: Optional[str] = None,
**kwargs,
) -> ValidatorRun:
validator_logs = self.before_run_validator(
@@ -96,6 +98,7 @@
metadata,
stream,
validation_session_id=iteration.id,
reference_path=reference_path,
**kwargs,
)

@@ -111,6 +114,7 @@
result.metadata or {},
stream,
validation_session_id=iteration.id,
reference_path=reference_path,
**kwargs,
)
value = self.perform_correction(
@@ -160,6 +164,7 @@ async def run_validators(
metadata,
absolute_property_path,
stream=stream,
reference_property_path=reference_property_path,
**kwargs,
)
)
@@ -277,6 +282,7 @@ async def async_partial_validate(
metadata,
absolute_path,
stream=stream,
reference_path=reference_path,
**kwargs,
)
)
1 change: 1 addition & 0 deletions pyproject.toml
@@ -162,6 +162,7 @@ authorized_licenses = [
"python software foundation",
"python software foundation license",
"zpl 2.1",
"mit and python-2.0"
]
unauthorized_licenses = [
"gpl v3",
31 changes: 27 additions & 4 deletions tests/unit_tests/validator_service/test_async_validator_service.py
@@ -503,7 +503,12 @@ async def test_pass_result(self, mocker):

assert mock_run_validator_async.call_count == 1
mock_run_validator_async.assert_called_once_with(
validator, "value", {}, False, validation_session_id=iteration.id
validator,
"value",
{},
False,
validation_session_id=iteration.id,
reference_path=None,
)

assert mock_after_run_validator.call_count == 1
@@ -562,7 +567,12 @@ async def test_pass_result_with_override(self, mocker):

assert mock_run_validator_async.call_count == 1
mock_run_validator_async.assert_called_once_with(
validator, "value", {}, False, validation_session_id=iteration.id
validator,
"value",
{},
False,
validation_session_id=iteration.id,
reference_path=None,
)

assert mock_after_run_validator.call_count == 1
@@ -625,7 +635,12 @@ async def test_fail_result(self, mocker):

assert mock_run_validator_async.call_count == 1
mock_run_validator_async.assert_called_once_with(
validator, "value", {}, False, validation_session_id=iteration.id
validator,
"value",
{},
False,
validation_session_id=iteration.id,
reference_path=None,
)

assert mock_after_run_validator.call_count == 1
@@ -699,13 +714,21 @@ async def test_fail_result_with_fix_reask(self, mocker):
assert mock_run_validator_async.call_count == 2
mock_run_validator_async.assert_has_calls(
[
call(validator, "value", {}, False, validation_session_id=iteration.id),
call(
validator,
"value",
{},
False,
validation_session_id=iteration.id,
reference_path=None,
),
call(
validator,
"fixed-value",
{},
False,
validation_session_id=iteration.id,
reference_path=None,
),
]
)