From bc6050323b214ca0bb4e89e734cd6d44f3761abd Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Fri, 18 Apr 2025 16:12:44 -0500 Subject: [PATCH 1/8] use context vars to separate streams when async streaming; handle usage chunk --- guardrails/run/async_stream_runner.py | 12 +++++++++ guardrails/run/stream_runner.py | 5 ++++ guardrails/validator_base.py | 38 ++++++++++++++++++++------- 3 files changed, 46 insertions(+), 9 deletions(-) diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py index b9d334fbd..df3f62934 100644 --- a/guardrails/run/async_stream_runner.py +++ b/guardrails/run/async_stream_runner.py @@ -1,3 +1,4 @@ +from contextvars import ContextVar, copy_context from typing import ( Any, AsyncIterator, @@ -118,6 +119,10 @@ async def async_step( refrain_triggered = False validation_passed = True + ctx_accumulated_chunks = ContextVar("accumulated_chunks") + ctx_accumulated_chunks.set([]) + context = copy_context() + if self.output_type == OutputTypes.STRING: validator_service = AsyncValidatorService(self.disable_tracer) async for chunk in stream_output: @@ -134,6 +139,8 @@ async def async_step( "$", "$", True, + context=context, + ctx_accumulated_chunks=ctx_accumulated_chunks, ) validators = self.validation_map.get("$", []) @@ -240,6 +247,8 @@ async def async_step( parsed_fragment, output_schema, validate_subschema=True, + context=context, + ctx_accumulated_chunks=ctx_accumulated_chunks, ) if isinstance(validated_fragment, SkeletonReAsk): raise ValueError( @@ -275,6 +284,9 @@ def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> st """Get the text from a chunk.""" chunk_text = "" + if not chunk.choices or len(chunk.choices) == 0: + return chunk_text + try: finished = chunk.choices[0].finish_reason content = chunk.choices[0].delta.content diff --git a/guardrails/run/stream_runner.py b/guardrails/run/stream_runner.py index 1f37a1623..d195eaf63 100644 --- a/guardrails/run/stream_runner.py +++ b/guardrails/run/stream_runner.py @@ -240,6 +240,11 @@ def prepare_chunk_generator(stream) -> Iterator[Tuple[Any, bool]]: def is_last_chunk(self, chunk: Any, api: Union[PromptCallableBase, None]) -> bool: """Detect if chunk is final chunk.""" try: + if ( + not chunk.choices or len(chunk.choices) == 0 + ) and chunk.usage is not None: + # This is the last extra chunk for usage statistics + return True finished = chunk.choices[0].finish_reason return finished is not None except (AttributeError, TypeError): diff --git a/guardrails/validator_base.py b/guardrails/validator_base.py index 1425c134b..0c1737434 100644 --- a/guardrails/validator_base.py +++ b/guardrails/validator_base.py @@ -4,7 +4,7 @@ # - [ ] Remove validator_base.py in 0.6.x import asyncio -import contextlib +from contextvars import Context, ContextVar from functools import partial import inspect import logging @@ -67,10 +67,8 @@ def split_sentence_word_tokenizers_jl_separator( # we check for a . 
to avoid wastefully calling the tokenizer
     # check at least 3 characters have been accumulated before splitting
-    is_minimum_length = False
-    with contextlib.suppress(IndexError):
-        chunk[2]
-        is_minimum_length = True
+    third_chunk = safe_get(chunk, 2)
+    is_minimum_length = third_chunk is not None
 
     # check for potential line endings, which is what split_sentences does
     chunk_with_potential_line_endings, count = re.subn(
@@ -292,7 +290,13 @@ def _chunking_function(self, chunk: str) -> List[str]:
         return split_sentence_word_tokenizers_jl_separator(chunk)
 
     def validate_stream(
-        self, chunk: Any, metadata: Dict[str, Any], **kwargs
+        self,
+        chunk: Any,
+        metadata: Dict[str, Any],
+        *,
+        ctx_accumulated_chunks: Optional[ContextVar[List[str]]] = None,
+        context: Optional[Context] = None,
+        **kwargs,
     ) -> Optional[ValidationResult]:
         """Validates a chunk emitted by an LLM. If the LLM chunk is smaller
         than the validator's chunking strategy, it will be accumulated until it
@@ -307,8 +311,13 @@
             result.
         """
-        # combine accumulated chunks and the new chunk
-        self.accumulated_chunks.append(chunk)
-        accumulated_text = "".join(self.accumulated_chunks)
+        # combine accumulated chunks and the new chunk
+        accumulated_chunks = (
+            context.run(ctx_accumulated_chunks.get)
+            if ctx_accumulated_chunks and context
+            else self.accumulated_chunks
+        )
+        accumulated_chunks.append(chunk)
+        accumulated_text = "".join(accumulated_chunks)
 
         # check if enough chunks have accumulated for validation
         split_contents = self._chunking_function(accumulated_text)
@@ -318,9 +327,16 @@
             split_contents = [accumulated_text, ""]
         # if no chunks are returned, we haven't accumulated enough
         if len(split_contents) == 0:
+            if ctx_accumulated_chunks and context:
+                context.run(ctx_accumulated_chunks.set, accumulated_chunks)
+            else:
+                self.accumulated_chunks = accumulated_chunks
             return None
         [chunk_to_validate, new_accumulated_chunks] = split_contents
-        self.accumulated_chunks = [new_accumulated_chunks]
+        if ctx_accumulated_chunks and context:
+            context.run(ctx_accumulated_chunks.set, [new_accumulated_chunks])
+        else:
+            self.accumulated_chunks = [new_accumulated_chunks]
         # exclude last chunk, because it may not be a complete chunk
         validation_result = self.validate(chunk_to_validate, metadata)
         # if validate doesn't set validated chunk, we set it
@@ -336,6 +352,10 @@
                 )
             ]
 
+        if ctx_accumulated_chunks:
+            ctx_accumulated_chunks.set(accumulated_chunks)
+        else:
+            self.accumulated_chunks = accumulated_chunks
         return validation_result
 
     async def async_validate_stream(

From 33871130620b354a4c2f523f7b90a8128162264f Mon Sep 17 00:00:00 2001
From: Caleb Courier
Date: Mon, 21 Apr 2025 09:07:06 -0500
Subject: [PATCH 2/8] keep support for (Async)Iterator[str]; safeguard against
 non-OpenAI compliant chunks

---
 guardrails/run/async_stream_runner.py | 31 -----------
 guardrails/run/stream_runner.py       | 80 +++++++++++++++++++--------
 guardrails/utils/safe_get.py          |  7 ++-
 3 files changed, 62 insertions(+), 56 deletions(-)

diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py
index df3f62934..41e6617bf 100644
--- a/guardrails/run/async_stream_runner.py
+++ b/guardrails/run/async_stream_runner.py
@@ -5,7 +5,6 @@
     Dict,
     List,
     Optional,
-    Union,
     cast,
 )
 
@@ -16,7 +15,6 @@
 from guardrails.classes.output_type import OutputTypes
 from guardrails.llm_providers import (
     AsyncPromptCallableBase,
-    PromptCallableBase,
 )
 from guardrails.logger import set_scope
 from guardrails.run import StreamRunner
@@ -279,32 +277,3 @@ async def async_step(
iteration.outputs.parsed_output = parsed_fragment or fragment # type: ignore iteration.outputs.validation_response = validation_response iteration.outputs.guarded_output = valid_op - - def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str: - """Get the text from a chunk.""" - chunk_text = "" - - if not chunk.choices or len(chunk.choices) == 0: - return chunk_text - - try: - finished = chunk.choices[0].finish_reason - content = chunk.choices[0].delta.content - if not finished and content: - chunk_text = content - except Exception: - try: - finished = chunk.choices[0].finish_reason - content = chunk.choices[0].text - if not finished and content: - chunk_text = content - except Exception: - try: - chunk_text = chunk - except Exception as e: - raise ValueError( - f"Error getting chunk from stream: {e}. " - "Non-OpenAI API callables expected to return " - "a generator of strings." - ) from e - return chunk_text diff --git a/guardrails/run/stream_runner.py b/guardrails/run/stream_runner.py index d195eaf63..600770cfc 100644 --- a/guardrails/run/stream_runner.py +++ b/guardrails/run/stream_runner.py @@ -17,6 +17,7 @@ from guardrails.actions.reask import ReAsk, SkeletonReAsk from guardrails.constants import pass_status from guardrails.telemetry import trace_stream_step +from guardrails.utils.safe_get import safe_get class StreamRunner(Runner): @@ -251,29 +252,62 @@ def is_last_chunk(self, chunk: Any, api: Union[PromptCallableBase, None]) -> boo return False def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str: - """Get the text from a chunk.""" - chunk_text = "" - try: - finished = chunk.choices[0].finish_reason - content = chunk.choices[0].delta.content - if not finished and content: - chunk_text = content - except Exception: - try: - finished = chunk.choices[0].finish_reason - content = chunk.choices[0].text - if not finished and content: - chunk_text = content - except Exception: - try: - chunk_text = chunk - except Exception as e: - raise ValueError( - f"Error getting chunk from stream: {e}. " - "Non-OpenAI API callables expected to return " - "a generator of strings." - ) from e - return chunk_text + """ + Get the text from a chunk. + + chunk is assumed to be an Iterator of either string or ChatCompletionChunk + + These types are not properly enforced upstream so we must use reflection + """ + # Safeguard against None + # which can happen when the user provides + # custom LLM wrappers + if not chunk: + return "" + elif isinstance(chunk, str): + # If chunk is a string, return it + return chunk + elif hasattr(chunk, "choices") and hasattr(chunk.choices, "__iter__"): + # If chunk is a ChatCompletionChunk, return the text + # from the first choice + chunk_text = "" + first_choice = safe_get(chunk.choices, 0) + if not first_choice: + return chunk_text + + if hasattr(first_choice, "delta") and hasattr( + first_choice.delta, "content" + ): + chunk_text = first_choice.delta.content + elif hasattr(first_choice, "text"): + chunk_text = first_choice.text + else: + # If chunk is not a string or ChatCompletionChunk, raise an error + raise ValueError( + "chunk.choices[0] does not have " + "delta.content or text. " + "Non-OpenAI compliant callables must return " + "a generator of strings." + ) + + if not chunk_text: + # If chunk_text is empty, return an empty string + return "" + elif not isinstance(chunk_text, str): + # If chunk_text is not a string, raise an error + raise ValueError( + "Chunk text is not a string. 
" + "Non-OpenAI compliant callables must return " + "a generator of strings." + ) + return chunk_text + else: + # If chunk is not a string or ChatCompletionChunk, raise an error + raise ValueError( + "Chunk is not a string or ChatCompletionChunk. " + "Non-OpenAI compliant callables must return " + "a generator of strings." + ) def parse( self, output: str, output_schema: Dict[str, Any], *, verified: set, **kwargs diff --git a/guardrails/utils/safe_get.py b/guardrails/utils/safe_get.py index dbc35f84a..e53b3e18b 100644 --- a/guardrails/utils/safe_get.py +++ b/guardrails/utils/safe_get.py @@ -12,9 +12,12 @@ def safe_get_with_brackets( return value except Exception as e: logger.debug( - f"Failed to get value for key: {key} out of container: {container}!" + f""" + Failed to get value for key: {key} out of container: {container}. + Reason: {e} + Fallbacking to default value... + """ ) - logger.debug(e) return default From dfbb6058c661cb535ec650143e805be00df2cb0f Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 21 Apr 2025 09:11:36 -0500 Subject: [PATCH 3/8] authorize mit and python2.0 combo license --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index b244ccdea..63d7d423a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -162,6 +162,7 @@ authorized_licenses = [ "python software foundation", "python software foundation license", "zpl 2.1", + "mit and python-2.0" ] unauthorized_licenses = [ "gpl v3", From 6209f67fea171c95e42ec571257b456785ea39f0 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 21 Apr 2025 11:21:13 -0500 Subject: [PATCH 4/8] context per validator stream --- guardrails/run/async_stream_runner.py | 26 ++++++++++---- guardrails/run/stream_runner.py | 9 ++--- guardrails/validator_base.py | 36 +++++++++++-------- .../async_validator_service.py | 6 ++++ .../test_async_validator_service.py | 31 +++++++++++++--- 5 files changed, 79 insertions(+), 29 deletions(-) diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py index 41e6617bf..c89800afa 100644 --- a/guardrails/run/async_stream_runner.py +++ b/guardrails/run/async_stream_runner.py @@ -117,9 +117,23 @@ async def async_step( refrain_triggered = False validation_passed = True - ctx_accumulated_chunks = ContextVar("accumulated_chunks") - ctx_accumulated_chunks.set([]) + # TODO: reset the context vars when finished context = copy_context() + stream_context_vars: ContextVar[Dict[str, ContextVar[List[str]]]] = ContextVar( + "stream_context" + ) + context_vars: Dict[str, ContextVar[List[str]]] = {} + for k, v in self.validation_map.items(): + if isinstance(v, list): + for validator in v: + property_validation_chunks = ContextVar( + f"{k}_{validator.rail_alias}_chunks" + ) + context.run(property_validation_chunks.set, []) + context_vars[f"{k}_{validator.rail_alias}"] = ( + property_validation_chunks # noqa: E501 + ) + context.run(stream_context_vars.set, context_vars) if self.output_type == OutputTypes.STRING: validator_service = AsyncValidatorService(self.disable_tracer) @@ -138,7 +152,7 @@ async def async_step( "$", True, context=context, - ctx_accumulated_chunks=ctx_accumulated_chunks, + context_vars=stream_context_vars, ) validators = self.validation_map.get("$", []) @@ -186,9 +200,7 @@ async def async_step( rechecked_value=None, ) # type: ignore - if not hasattr( - validation_progress, validator_log.validator_name - ): + if validator_log.validator_name not in validation_progress: validation_progress[validator_log.validator_name] = "" 
validation_progress[validator_log.validator_name] += chunk @@ -246,7 +258,7 @@ async def async_step( output_schema, validate_subschema=True, context=context, - ctx_accumulated_chunks=ctx_accumulated_chunks, + context_vars=stream_context_vars, ) if isinstance(validated_fragment, SkeletonReAsk): raise ValueError( diff --git a/guardrails/run/stream_runner.py b/guardrails/run/stream_runner.py index 600770cfc..222d5e1d0 100644 --- a/guardrails/run/stream_runner.py +++ b/guardrails/run/stream_runner.py @@ -252,12 +252,13 @@ def is_last_chunk(self, chunk: Any, api: Union[PromptCallableBase, None]) -> boo return False def get_chunk_text(self, chunk: Any, api: Union[PromptCallableBase, None]) -> str: - """ - Get the text from a chunk. + """Get the text from a chunk. - chunk is assumed to be an Iterator of either string or ChatCompletionChunk + chunk is assumed to be an Iterator of either string or + ChatCompletionChunk - These types are not properly enforced upstream so we must use reflection + These types are not properly enforced upstream so we must use + reflection """ # Safeguard against None # which can happen when the user provides diff --git a/guardrails/validator_base.py b/guardrails/validator_base.py index 0c1737434..f3d38919b 100644 --- a/guardrails/validator_base.py +++ b/guardrails/validator_base.py @@ -294,7 +294,8 @@ def validate_stream( chunk: Any, metadata: Dict[str, Any], *, - ctx_accumulated_chunks: Optional[ContextVar[List[str]]] = None, + property_path: Optional[str] = "$", + context_vars: Optional[ContextVar[Dict[str, ContextVar[List[str]]]]] = None, context: Optional[Context] = None, **kwargs, ) -> Optional[ValidationResult]: @@ -311,11 +312,18 @@ def validate_stream( result. """ # combine accumulated chunks and new [:-1]chunk - accumulated_chunks = ( - context.run(ctx_accumulated_chunks.get) - if ctx_accumulated_chunks and context - else self.accumulated_chunks - ) + accumulated_chunks = self.accumulated_chunks + + # if context_vars is passed, use it to get the accumulated chunks + context_var: Optional[ContextVar[List[str]]] = None + ctx_var_map: Optional[Dict[str, ContextVar[List[str]]]] = None + context_key = f"{property_path}_{self.rail_alias}" + if context_vars and context: + ctx_var_map = context.run(context_vars.get) + context_var = ctx_var_map.get(context_key) + if context_var: + accumulated_chunks = context.run(context_var.get) + accumulated_chunks.append(chunk) accumulated_text = "".join(accumulated_chunks) # check if enough chunks have accumulated for validation @@ -327,14 +335,18 @@ def validate_stream( split_contents = [accumulated_text, ""] # if no chunks are returned, we haven't accumulated enough if len(split_contents) == 0: - if ctx_accumulated_chunks and context: - context.run(ctx_accumulated_chunks.set, accumulated_chunks) + if context_vars and context_var and context and ctx_var_map: + context.run(context_var.set, accumulated_chunks) + ctx_var_map[context_key] = context_var + context.run(context_vars.set, ctx_var_map) else: self.accumulated_chunks = accumulated_chunks return None [chunk_to_validate, new_accumulated_chunks] = split_contents - if ctx_accumulated_chunks and context: - context.run(ctx_accumulated_chunks.set, [new_accumulated_chunks]) + if context_vars and context_var and context and ctx_var_map: + context.run(context_var.set, [new_accumulated_chunks]) + ctx_var_map[context_key] = context_var + context.run(context_vars.set, ctx_var_map) else: self.accumulated_chunks = [new_accumulated_chunks] # exclude last chunk, because it may not be a 
complete chunk @@ -352,10 +364,6 @@ def validate_stream( ) ] - if ctx_accumulated_chunks: - ctx_accumulated_chunks.set(accumulated_chunks) - else: - self.accumulated_chunks = accumulated_chunks return validation_result async def async_validate_stream( diff --git a/guardrails/validator_service/async_validator_service.py b/guardrails/validator_service/async_validator_service.py index 52a2a530c..a56cb6037 100644 --- a/guardrails/validator_service/async_validator_service.py +++ b/guardrails/validator_service/async_validator_service.py @@ -84,6 +84,8 @@ async def run_validator( metadata: Dict, absolute_property_path: str, stream: Optional[bool] = False, + *, + reference_path: Optional[str] = None, **kwargs, ) -> ValidatorRun: validator_logs = self.before_run_validator( @@ -96,6 +98,7 @@ async def run_validator( metadata, stream, validation_session_id=iteration.id, + reference_path=reference_path, **kwargs, ) @@ -111,6 +114,7 @@ async def run_validator( result.metadata or {}, stream, validation_session_id=iteration.id, + reference_path=reference_path, **kwargs, ) value = self.perform_correction( @@ -160,6 +164,7 @@ async def run_validators( metadata, absolute_property_path, stream=stream, + reference_property_path=reference_property_path, **kwargs, ) ) @@ -277,6 +282,7 @@ async def async_partial_validate( metadata, absolute_path, stream=stream, + reference_path=reference_path, **kwargs, ) ) diff --git a/tests/unit_tests/validator_service/test_async_validator_service.py b/tests/unit_tests/validator_service/test_async_validator_service.py index 7c0ebd368..e321a6011 100644 --- a/tests/unit_tests/validator_service/test_async_validator_service.py +++ b/tests/unit_tests/validator_service/test_async_validator_service.py @@ -503,7 +503,12 @@ async def test_pass_result(self, mocker): assert mock_run_validator_async.call_count == 1 mock_run_validator_async.assert_called_once_with( - validator, "value", {}, False, validation_session_id=iteration.id + validator, + "value", + {}, + False, + validation_session_id=iteration.id, + reference_path=None, ) assert mock_after_run_validator.call_count == 1 @@ -562,7 +567,12 @@ async def test_pass_result_with_override(self, mocker): assert mock_run_validator_async.call_count == 1 mock_run_validator_async.assert_called_once_with( - validator, "value", {}, False, validation_session_id=iteration.id + validator, + "value", + {}, + False, + validation_session_id=iteration.id, + reference_path=None, ) assert mock_after_run_validator.call_count == 1 @@ -625,7 +635,12 @@ async def test_fail_result(self, mocker): assert mock_run_validator_async.call_count == 1 mock_run_validator_async.assert_called_once_with( - validator, "value", {}, False, validation_session_id=iteration.id + validator, + "value", + {}, + False, + validation_session_id=iteration.id, + reference_path=None, ) assert mock_after_run_validator.call_count == 1 @@ -699,13 +714,21 @@ async def test_fail_result_with_fix_reask(self, mocker): assert mock_run_validator_async.call_count == 2 mock_run_validator_async.assert_has_calls( [ - call(validator, "value", {}, False, validation_session_id=iteration.id), + call( + validator, + "value", + {}, + False, + validation_session_id=iteration.id, + reference_path=None, + ), call( validator, "fixed-value", {}, False, validation_session_id=iteration.id, + reference_path=None, ), ] ) From 5c00824ed7c04e8212b32ee8f4cb39725cde712f Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 21 Apr 2025 13:41:44 -0500 Subject: [PATCH 5/8] fix async streaming of structured output --- 
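Note for reviewers (not part of the commit): this fix, like the rest of the
series, relies on the copy_context()/ContextVar mechanism introduced in
PATCH 1 and reworked in PATCH 4, which keeps per-stream accumulation state
separate when a single validator instance serves concurrent async streams.
Below is a minimal, self-contained sketch of that mechanism. It is an
illustration only, not code from this series; the names `chunks`,
`accumulate`, and `consume` are hypothetical.

    import asyncio
    from contextvars import ContextVar, copy_context
    from typing import List

    chunks: ContextVar[List[str]] = ContextVar("chunks")

    def accumulate(piece: str) -> str:
        # Operates on whichever list the calling Context holds.
        buf = chunks.get()
        buf.append(piece)
        return "".join(buf)

    async def consume(name: str, pieces: List[str]) -> None:
        ctx = copy_context()     # snapshot taken before the first chunk
        ctx.run(chunks.set, [])  # per-stream accumulation buffer
        for piece in pieces:
            print(f"{name}: {ctx.run(accumulate, piece)!r}")
            await asyncio.sleep(0)  # yield so the two streams interleave

    async def main() -> None:
        # Both streams share the module-level ContextVar, yet accumulate
        # independently: a -> 'Hel', 'Hello'; b -> 'Wor', 'World'.
        await asyncio.gather(
            consume("a", ["Hel", "lo"]),
            consume("b", ["Wor", "ld"]),
        )

    asyncio.run(main())

Because every read and write goes through context.run(), state stored this
way stays local to the stream that owns the copied Context, while the
instance attribute self.accumulated_chunks remains the fallback for callers
that do not pass a context.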
guardrails/run/async_stream_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py index c89800afa..fdc2c2046 100644 --- a/guardrails/run/async_stream_runner.py +++ b/guardrails/run/async_stream_runner.py @@ -279,10 +279,11 @@ async def async_step( validation_response = cast(dict, validated_fragment) yield ValidationOutcome( call_id=call_log.id, # type: ignore - raw_llm_output=fragment, + raw_llm_output=validated_fragment, validated_output=chunk_text, validation_passed=validated_fragment is not None, ) + fragment = "" iteration.outputs.raw_output = fragment # FIXME: Handle case where parsing continuously fails/is a reask From aefd640dbeaf3e19bc5f3018e7c7666bddc281d8 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 21 Apr 2025 14:07:47 -0500 Subject: [PATCH 6/8] reset context vars at the end of the stream --- guardrails/run/async_stream_runner.py | 279 +++++++++++++++----------- 1 file changed, 164 insertions(+), 115 deletions(-) diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py index fdc2c2046..0c291b498 100644 --- a/guardrails/run/async_stream_runner.py +++ b/guardrails/run/async_stream_runner.py @@ -1,4 +1,5 @@ from contextvars import ContextVar, copy_context +import sys from typing import ( Any, AsyncIterator, @@ -28,6 +29,10 @@ ) +if sys.version_info.minor < 10: + from guardrails.utils.polyfills import anext + + class AsyncStreamRunner(AsyncRunner, StreamRunner): # @async_trace_stream(name="/reasks", origin="AsyncStreamRunner.async_run") async def async_run( @@ -137,96 +142,124 @@ async def async_step( if self.output_type == OutputTypes.STRING: validator_service = AsyncValidatorService(self.disable_tracer) - async for chunk in stream_output: - chunk_text = self.get_chunk_text(chunk, api) - _ = self.is_last_chunk(chunk, api) - fragment += chunk_text + next_exists = True + while next_exists: + try: + chunk = await anext(stream_output) + chunk_text = self.get_chunk_text(chunk, api) + _ = self.is_last_chunk(chunk, api) - results = await validator_service.async_partial_validate( - chunk_text, - self.metadata, - self.validation_map, - iteration, - "$", - "$", - True, - context=context, - context_vars=stream_context_vars, - ) - validators = self.validation_map.get("$", []) + fragment += chunk_text - # collect the result validated_chunk into validation progress - # per validator - for result in results: - validator_log = result.validator_logs # type: ignore - validator = next( - filter( - lambda x: x.rail_alias == validator_log.registered_name, - validators, - ), - None, + results = await validator_service.async_partial_validate( + chunk_text, + self.metadata, + self.validation_map, + iteration, + "$", + "$", + True, + context=context, + context_vars=stream_context_vars, ) - if ( - validator_log.validation_result - and validator_log.validation_result.validated_chunk - ): - is_filter = validator.on_fail_descriptor is OnFailAction.FILTER # type: ignore - is_refrain = ( - validator.on_fail_descriptor is OnFailAction.REFRAIN # type: ignore + validators = self.validation_map.get("$", []) + + # collect the result validated_chunk into validation progress + # per validator + for result in results: + validator_log = result.validator_logs # type: ignore + validator = next( + filter( + lambda x: x.rail_alias == validator_log.registered_name, + validators, + ), + None, ) - if validator_log.validation_result.outcome == "fail": - validation_passed = False - 
reasks, valid_op = self.introspect( + if ( validator_log.validation_result - ) - if reasks: - raise ValueError( - "Reasks are not yet supported with streaming. Please " - "remove reasks from schema or disable streaming." + and validator_log.validation_result.validated_chunk + ): + is_filter = ( + validator.on_fail_descriptor is OnFailAction.FILTER # type: ignore + ) + is_refrain = ( + validator.on_fail_descriptor is OnFailAction.REFRAIN # type: ignore + ) + if validator_log.validation_result.outcome == "fail": + validation_passed = False + reasks, valid_op = self.introspect( + validator_log.validation_result ) + if reasks: + raise ValueError( + "Reasks are not yet supported with streaming. " + "Please remove reasks from schema or disable" + " streaming." + ) - if isinstance(validator_log.validation_result, PassResult): - chunk = validator_log.validation_result.validated_chunk - elif isinstance(validator_log.validation_result, FailResult): - if is_filter or is_refrain: - refrain_triggered = True - chunk = "" - else: - chunk = validator_service.perform_correction( - validator_log.validation_result, - validator_log.validation_result.validated_chunk, - validator, # type: ignore - rechecked_value=None, - ) # type: ignore + if isinstance(validator_log.validation_result, PassResult): + chunk = validator_log.validation_result.validated_chunk + elif isinstance( + validator_log.validation_result, FailResult + ): + if is_filter or is_refrain: + refrain_triggered = True + chunk = "" + else: + chunk = validator_service.perform_correction( + validator_log.validation_result, + validator_log.validation_result.validated_chunk, + validator, # type: ignore + rechecked_value=None, + ) # type: ignore - if validator_log.validator_name not in validation_progress: - validation_progress[validator_log.validator_name] = "" + if validator_log.validator_name not in validation_progress: + validation_progress[validator_log.validator_name] = "" - validation_progress[validator_log.validator_name] += chunk - # if there is an entry for every validator - # run a merge and emit a validation outcome - if len(validation_progress) == len(validators) or len(validators) == 0: - if refrain_triggered: - current = "" - else: - merge_chunks = [] - for piece in validation_progress: - merge_chunks.append(validation_progress[piece]) + validation_progress[validator_log.validator_name] += chunk + # if there is an entry for every validator + # run a merge and emit a validation outcome + if ( + len(validation_progress) == len(validators) + or len(validators) == 0 + ): + if refrain_triggered: + current = "" + else: + merge_chunks = [] + for piece in validation_progress: + merge_chunks.append(validation_progress[piece]) - current = validator_service.multi_merge(fragment, merge_chunks) + current = validator_service.multi_merge( + fragment, merge_chunks + ) - vo = ValidationOutcome( - call_id=call_log.id, # type: ignore - raw_llm_output=fragment, - validated_output=current, - validation_passed=True, - ) - fragment = "" - validation_progress = {} - refrain_triggered = False + vo = ValidationOutcome( + call_id=call_log.id, # type: ignore + raw_llm_output=fragment, + validated_output=current, + validation_passed=True, + ) + fragment = "" + validation_progress = {} + refrain_triggered = False + + yield vo - yield vo + except StopIteration: + next_exists = False + except StopAsyncIteration: + next_exists = False + except Exception as e: + raise e + finally: + # reset all context vars + for context_var in context_vars.values(): + token = 
context.run(context_var.set, []) + context.run(context_var.reset, token) + token = context.run(stream_context_vars.set, {}) + context.run(stream_context_vars.reset, token) # if theres anything left merge and emit a chunk if len(validation_progress) > 0: @@ -242,48 +275,64 @@ async def async_step( validation_passed=validation_passed, ) else: - async for chunk in stream_output: - chunk_text = self.get_chunk_text(chunk, api) - fragment += chunk_text + next_exists = True + while next_exists: + try: + chunk = await anext(stream_output) + chunk_text = self.get_chunk_text(chunk, api) + fragment += chunk_text - parsed_fragment, move_to_next = self.parse( - fragment, output_schema, verified=verified - ) - if move_to_next: - continue - validated_fragment = await self.async_validate( - iteration, - index, - parsed_fragment, - output_schema, - validate_subschema=True, - context=context, - context_vars=stream_context_vars, - ) - if isinstance(validated_fragment, SkeletonReAsk): - raise ValueError( - "Received fragment schema is an invalid sub-schema " - "of the expected output JSON schema." + parsed_fragment, move_to_next = self.parse( + fragment, output_schema, verified=verified ) - - reasks, valid_op = self.introspect(validated_fragment) - if reasks: - raise ValueError( - "Reasks are not yet supported with streaming. Please " - "remove reasks from schema or disable streaming." + if move_to_next: + continue + validated_fragment = await self.async_validate( + iteration, + index, + parsed_fragment, + output_schema, + validate_subschema=True, + context=context, + context_vars=stream_context_vars, ) + if isinstance(validated_fragment, SkeletonReAsk): + raise ValueError( + "Received fragment schema is an invalid sub-schema " + "of the expected output JSON schema." + ) - if self.output_type == OutputTypes.LIST: - validation_response = cast(list, validated_fragment) - else: - validation_response = cast(dict, validated_fragment) - yield ValidationOutcome( - call_id=call_log.id, # type: ignore - raw_llm_output=validated_fragment, - validated_output=chunk_text, - validation_passed=validated_fragment is not None, - ) - fragment = "" + reasks, valid_op = self.introspect(validated_fragment) + if reasks: + raise ValueError( + "Reasks are not yet supported with streaming. Please " + "remove reasks from schema or disable streaming." 
+ ) + + if self.output_type == OutputTypes.LIST: + validation_response = cast(list, validated_fragment) + else: + validation_response = cast(dict, validated_fragment) + yield ValidationOutcome( + call_id=call_log.id, # type: ignore + raw_llm_output=fragment, + validated_output=validated_fragment, + validation_passed=validated_fragment is not None, + ) + fragment = "" + except StopIteration: + next_exists = False + except StopAsyncIteration: + next_exists = False + except Exception as e: + raise e + finally: + # reset all context vars + for context_var in context_vars.values(): + token = context.run(context_var.set, []) + context.run(context_var.reset, token) + token = context.run(stream_context_vars.set, {}) + context.run(stream_context_vars.reset, token) iteration.outputs.raw_output = fragment # FIXME: Handle case where parsing continuously fails/is a reask From 9d4b3ef92342376f857d918cf496e80ee8d3cae0 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 21 Apr 2025 14:51:49 -0500 Subject: [PATCH 7/8] add fixmes --- guardrails/llm_providers.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/guardrails/llm_providers.py b/guardrails/llm_providers.py index 5f19d470f..a8ded382f 100644 --- a/guardrails/llm_providers.py +++ b/guardrails/llm_providers.py @@ -218,6 +218,7 @@ def _invoke_llm( llm_response = cast(Iterator[str], response) return LLMResponse( output="", + # FIXME: Why is this different from the async streaming implementation? stream_output=llm_response, ) @@ -491,6 +492,7 @@ def _invoke_llm(self, *args, **kwargs) -> LLMResponse: llm_response = cast(Iterator[str], llm_response) return LLMResponse( output="", + # FIXME: Why is this different from the async streaming implementation? stream_output=llm_response, ) @@ -685,6 +687,8 @@ async def invoke_llm( # response = cast(AsyncIterator[str], response) return LLMResponse( output="", + # FIXME: Why is this different from the synchronous streaming implementation? ## noqa: E501 + # This shouldn't be necessary: https://docs.litellm.ai/docs/completion/stream#async-streaming async_stream_output=response.completion_stream, # pyright: ignore[reportGeneralTypeIssues] ) @@ -842,6 +846,8 @@ async def invoke_llm(self, *args, **kwargs) -> LLMResponse: # the callable returns a generator object return LLMResponse( output="", + # FIXME: Why is this different from the synchronous streaming implementation? ## noqa: E501 + # This shouldn't be necessary: https://docs.litellm.ai/docs/completion/stream#async-streaming async_stream_output=output.completion_stream, ) From ac016b510f2cede1be4527d89aef4a203deee34c Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 21 Apr 2025 15:11:58 -0500 Subject: [PATCH 8/8] remove obsolete TODO --- guardrails/run/async_stream_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/guardrails/run/async_stream_runner.py b/guardrails/run/async_stream_runner.py index 0c291b498..9b344572b 100644 --- a/guardrails/run/async_stream_runner.py +++ b/guardrails/run/async_stream_runner.py @@ -122,7 +122,6 @@ async def async_step( refrain_triggered = False validation_passed = True - # TODO: reset the context vars when finished context = copy_context() stream_context_vars: ContextVar[Dict[str, ContextVar[List[str]]]] = ContextVar( "stream_context"