@@ -196,6 +196,9 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
196
196
span .set_tag_str ("langchain.request.%s.parameters.%s" % (llm_provider , param ), str (val ))
197
197
198
198
completions = func (* args , ** kwargs )
199
+
200
+ _iast_taint_llm_output (prompts , completions )
201
+
199
202
if _is_openai_llm_instance (instance ):
200
203
_tag_openai_token_usage (span , completions .llm_output )
201
204
@@ -942,6 +945,57 @@ async def traced_base_tool_ainvoke(langchain, pin, func, instance, args, kwargs)
942
945
return tool_output
943
946
944
947
948
def _iast_taint_llm_output(prompts, completions):
    """
    Propagate IAST taint from the prompts of an LLM call to its completions.

    Tracking character ranges through an LLM is meaningless, so the first taint
    source found among the string prompts is applied to the full text of every
    generation.

    Nothing happens when IAST is disabled, when ``prompts`` is not a list or
    tuple, when ``completions`` does not expose a ``generations`` list, or when
    no prompt carries taint ranges. Any exception raised while propagating is
    reported through the IAST error metric instead of being raised.
    """
    if not asm_config._iast_enabled:
        return
    if not isinstance(prompts, (tuple, list)) or not hasattr(completions, "generations"):
        return
    try:
        generation_groups = completions.generations
        if not isinstance(generation_groups, list):
            return

        from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
        from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject

        # Lazily scan the string prompts and stop at the first one that has taint ranges.
        ranges_per_prompt = (get_tainted_ranges(prompt) for prompt in prompts if isinstance(prompt, str))
        first_source = next((ranges[0].source for ranges in ranges_per_prompt if ranges), None)
        if not first_source:
            return

        for group in generation_groups:
            for generation in group:
                if not hasattr(generation, "text"):
                    continue
                original_text = generation.text
                if isinstance(original_text, str):
                    # Replace the text with a copy tainted by the prompt's source.
                    generation.text = taint_pyobject(
                        pyobject=original_text,
                        source_name=first_source.name,
                        source_value=first_source.value,
                        source_origin=first_source.origin,
                    )
    except Exception as e:
        from ddtrace.appsec._iast._metrics import _set_iast_error_metric

        _set_iast_error_metric("IAST propagation error. langchain _iast_taint_llm_output. {}".format(e))
997
+
998
+
945
999
def _patch_embeddings_and_vectorstores ():
946
1000
"""
947
1001
Text embedding models override two abstract base methods instead of super calls,
@@ -1081,10 +1135,15 @@ def patch():
1081
1135
if asm_config ._iast_enabled :
1082
1136
from ddtrace .appsec ._iast ._metrics import _set_iast_error_metric
1083
1137
1138
+ wrap ("langchain_core" , "prompts.prompt.PromptTemplate.format" , iast_propagate_prompt_template_format )
1139
+
1084
1140
def wrap_output_parser(module, parser):
    """
    Wrap ``<parser>.parse`` on ``module`` with ``taint_parser_output``.

    :param module: module (or module path) that holds the parser class.
    :param parser: dotted name of the parser class inside ``module``.

    The wrap is skipped when ``parse`` is already a ``wrapt.ObjectProxy``.
    """
    # Ensure not double patched
    if not isinstance(deep_getattr(module, "%s.parse" % parser), wrapt.ObjectProxy):
        wrap(module, "%s.parse" % parser, taint_parser_output)
1088
1147
1089
1148
try :
1090
1149
with_agent_output_parser (wrap_output_parser )
@@ -1125,13 +1184,37 @@ def unpatch():
1125
1184
delattr (langchain , "_datadog_integration" )
1126
1185
1127
1186
1128
- def taint_parser_output (func , instance , args , kwargs ):
1129
- from ddtrace .appsec ._iast ._metrics import _set_iast_error_metric
1130
- from ddtrace .appsec ._iast ._taint_tracking ._taint_objects import get_tainted_ranges
1131
- from ddtrace .appsec ._iast ._taint_tracking ._taint_objects import taint_pyobject
1187
def iast_propagate_prompt_template_format(func, instance, args, kwargs):
    """
    Wrapper that propagates IAST taint through a prompt template ``format`` call.

    The wrapped ``func`` is always called first. When IAST is enabled for the
    current request and one of the keyword argument values has taint ranges,
    the formatted result is returned tainted with the source of the first such
    value. Otherwise the result is returned unchanged.

    Any exception raised while propagating is reported through the IAST error
    metric and the untainted result is returned.
    """
    formatted = func(*args, **kwargs)
    try:
        if asm_config.is_iast_request_enabled:
            from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
            from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject

            # Lazily inspect the keyword values and stop at the first tainted one.
            ranges_per_value = (get_tainted_ranges(value) for value in kwargs.values())
            first_source = next((ranges[0].source for ranges in ranges_per_value if ranges), None)
            if first_source:
                return taint_pyobject(formatted, first_source.name, first_source.value, first_source.origin)
    except Exception as e:
        from ddtrace.appsec._iast._metrics import _set_iast_error_metric

        _set_iast_error_metric("IAST propagation error. langchain iast_propagate_prompt_template_format. {}".format(e))
    return formatted
1210
+
1132
1211
1212
+ def taint_parser_output (func , instance , args , kwargs ):
1133
1213
result = func (* args , ** kwargs )
1134
1214
try :
1215
+ from ddtrace .appsec ._iast ._taint_tracking ._taint_objects import get_tainted_ranges
1216
+ from ddtrace .appsec ._iast ._taint_tracking ._taint_objects import taint_pyobject
1217
+
1135
1218
try :
1136
1219
from langchain_core .agents import AgentAction
1137
1220
from langchain_core .agents import AgentFinish
@@ -1147,6 +1230,8 @@ def taint_parser_output(func, instance, args, kwargs):
1147
1230
values = result .return_values
1148
1231
values ["output" ] = taint_pyobject (values ["output" ], source .name , source .value , source .origin )
1149
1232
except Exception as e :
1233
+ from ddtrace .appsec ._iast ._metrics import _set_iast_error_metric
1234
+
1150
1235
_set_iast_error_metric ("IAST propagation error. langchain taint_parser_output. {}" .format (e ))
1151
1236
1152
1237
return result
0 commit comments