Skip to content

Commit cebb83d

Browse files
add retry on error to repeat block
Signed-off-by: hirokuni-kitahara <[email protected]>
1 parent a73c106 commit cebb83d

File tree

2 files changed

+29
-11
lines changed

2 files changed

+29
-11
lines changed

src/pdl/pdl_ast.py

+6
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,12 @@ class RepeatBlock(StructuredBlock):
748748
join: JoinType = JoinText()
749749
"""Define how to combine the result of each iteration.
750750
"""
751+
retry_on_error: bool = False
752+
"""Indicate if this block should be retried when an error occurs.
753+
"""
754+
retry_max: int = 3
755+
"""Maximum number of retry.
756+
"""
751757
# Field for internal use
752758
pdl__trace: Optional[list["BlockType"]] = None
753759

src/pdl/pdl_interpreter.py

+23-11
Original file line numberDiff line numberDiff line change
@@ -721,9 +721,10 @@ def process_block_body(
721721
)
722722
repeat_loc = append(loc, "repeat")
723723
iidx = 0
724-
try:
725-
first = True
726-
while True:
724+
first = True
725+
retry_count = 0
726+
while True:
727+
try:
727728
if max_iterations is not None and iidx >= max_iterations:
728729
break
729730
if lengths is not None and iidx >= lengths[0]:
@@ -776,14 +777,25 @@ def process_block_body(
776777
stop, _ = process_condition_of(block, "until", scope, loc)
777778
if stop:
778779
break
779-
except PDLRuntimeError as exc:
780-
iter_trace.append(exc.pdl__trace)
781-
trace = block.model_copy(update={"pdl__trace": iter_trace})
782-
raise PDLRuntimeError(
783-
exc.message,
784-
loc=exc.loc or repeat_loc,
785-
trace=trace,
786-
) from exc
780+
except PDLRuntimeError as exc:
781+
manual_stop = False
782+
if "Keyboard Interrupt" in exc.message:
783+
manual_stop = True
784+
iter_trace.append(exc.pdl__trace)
785+
trace = block.model_copy(update={"pdl__trace": iter_trace})
786+
if block.retry_on_error and retry_count < block.retry_max and not manual_stop:
787+
retry_count += 1
788+
error = f"Retry on error is triggered in a repeat block. Error detail: {repr(exc)} "
789+
print(f"\n\033[0;31m{error}\033[0m\n")
790+
if background and background.data and background.data[-1]["content"].endswith(error):
791+
error = "The previous error occurs multiple times."
792+
background = lazy_messages_concat(background, [{"role": "assistant", "content": error}])
793+
else:
794+
raise PDLRuntimeError(
795+
exc.message,
796+
loc=exc.loc or repeat_loc,
797+
trace=trace,
798+
) from exc
787799
result = combine_results(block.join.as_, results)
788800
if state.yield_result and not iteration_state.yield_result:
789801
yield_result(result.result(), block.kind)

0 commit comments

Comments
 (0)