@@ -196,6 +196,9 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
196
196
span .set_tag_str ("langchain.request.%s.parameters.%s" % (llm_provider , param ), str (val ))
197
197
198
198
completions = func (* args , ** kwargs )
199
+
200
+ _iast_taint_llm_output (prompts , completions )
201
+
199
202
if _is_openai_llm_instance (instance ):
200
203
_tag_openai_token_usage (span , completions .llm_output )
201
204
@@ -252,6 +255,9 @@ async def traced_llm_agenerate(langchain, pin, func, instance, args, kwargs):
252
255
span .set_tag_str ("langchain.request.%s.parameters.%s" % (llm_provider , param ), str (val ))
253
256
254
257
completions = await func (* args , ** kwargs )
258
+
259
+ _iast_taint_llm_output (prompts , completions )
260
+
255
261
if _is_openai_llm_instance (instance ):
256
262
_tag_openai_token_usage (span , completions .llm_output )
257
263
@@ -942,6 +948,59 @@ async def traced_base_tool_ainvoke(langchain, pin, func, instance, args, kwargs)
942
948
return tool_output
943
949
944
950
951
def _iast_taint_llm_output(prompts, completions):
    """
    Taint the output of an LLM call if any of its inputs is tainted.

    Range propagation does not make sense in LLMs, so the first taint source found
    among the prompts is applied to the full text of every generation.

    :param prompts: prompts passed to the LLM. Nothing is done unless this is a list
        or tuple; only ``str`` items are inspected for tainted ranges.
    :param completions: result of the LLM call. Nothing is done unless it has a
        ``generations`` attribute that is a list; each inner item that has a ``str``
        ``text`` attribute gets that attribute replaced by a tainted copy.
    :return: ``None``; ``completions`` is modified in place.

    Exceptions raised while propagating are caught and reported through
    ``_set_iast_error_metric`` instead of being raised to the traced call.
    """
    # No logging of prompts/completions here: they may carry sensitive user data.
    if not asm_config._iast_enabled:
        return
    if not isinstance(prompts, (tuple, list)):
        return
    if not hasattr(completions, "generations"):
        return
    try:
        generations = completions.generations
        if not isinstance(generations, list):
            return

        # Imported lazily so the IAST taint-tracking module is only loaded when IAST is enabled.
        from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
        from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject

        # Pick the source of the first tainted range of the first tainted prompt.
        source = None
        for prompt in prompts:
            if not isinstance(prompt, str):
                continue
            tainted_ranges = get_tainted_ranges(prompt)
            if tainted_ranges:
                source = tainted_ranges[0].source
                break
        if not source:
            return

        for gens in generations:
            for gen in gens:
                if not hasattr(gen, "text"):
                    continue
                text = gen.text
                if not isinstance(text, str):
                    continue
                # Store the tainted copy back on the generation object.
                gen.text = taint_pyobject(
                    pyobject=text,
                    source_name=source.name,
                    source_value=source.value,
                    source_origin=source.origin,
                )
    except Exception as e:
        from ddtrace.appsec._iast._metrics import _set_iast_error_metric

        _set_iast_error_metric("IAST propagation error. langchain _iast_taint_llm_output. {}".format(e))
1002
+
1003
+
945
1004
def _patch_embeddings_and_vectorstores ():
946
1005
"""
947
1006
Text embedding models override two abstract base methods instead of super calls,
@@ -1081,10 +1140,15 @@ def patch():
1081
1140
if asm_config ._iast_enabled :
1082
1141
from ddtrace .appsec ._iast ._metrics import _set_iast_error_metric
1083
1142
1143
+ wrap ("langchain_core" , "prompts.prompt.PromptTemplate.format" , iast_propagate_prompt_template_format )
1144
+ wrap ("langchain_core" , "prompts.prompt.PromptTemplate.aformat" , iast_propagate_prompt_template_aformat )
1145
+
1084
1146
def wrap_output_parser(module, parser):
    """
    Install the IAST propagation wrappers on a parser's ``parse`` and ``aparse``.

    ``parser`` is interpolated into the dotted attribute path that is looked up
    under ``module``. Each method is wrapped independently, and only when the
    attribute currently found there is not already a ``wrapt.ObjectProxy``.
    """
    targets = (
        ("%s.parse", iast_propagate_output_parse),
        ("%s.aparse", iast_propagate_output_aparse),
    )
    for path_format, wrapper in targets:
        # Ensure not double patched
        if isinstance(deep_getattr(module, path_format % parser), wrapt.ObjectProxy):
            continue
        wrap(module, path_format % parser, wrapper)
1088
1152
1089
1153
try :
1090
1154
with_agent_output_parser (wrap_output_parser )
@@ -1114,6 +1178,7 @@ def unpatch():
1114
1178
unwrap (langchain_core .language_models .llms .BaseLLM , "astream" )
1115
1179
unwrap (langchain_core .tools .BaseTool , "invoke" )
1116
1180
unwrap (langchain_core .tools .BaseTool , "ainvoke" )
1181
+
1117
1182
if langchain_openai :
1118
1183
unwrap (langchain_openai .OpenAIEmbeddings , "embed_documents" )
1119
1184
if langchain_pinecone :
@@ -1122,16 +1187,70 @@ def unpatch():
1122
1187
if langchain_community :
1123
1188
_unpatch_embeddings_and_vectorstores ()
1124
1189
1190
+ if asm_config ._iast_enabled :
1191
+ unwrap (langchain_core .prompts .prompt .PromptTemplate , "format" )
1192
+ unwrap (langchain_core .prompts .prompt .PromptTemplate , "aformat" )
1193
+
1125
1194
delattr (langchain , "_datadog_integration" )
1126
1195
1127
1196
1128
- def taint_parser_output (func , instance , args , kwargs ):
1129
- from ddtrace .appsec ._iast ._metrics import _set_iast_error_metric
1130
- from ddtrace .appsec ._iast ._taint_tracking ._taint_objects import get_tainted_ranges
1131
- from ddtrace .appsec ._iast ._taint_tracking ._taint_objects import taint_pyobject
1197
def iast_propagate_prompt_template_format(func, instance, args, kwargs):
    """
    Wrapper for PromptTemplate.format.

    Calls the wrapped ``func`` with the original arguments, then passes the keyword
    arguments and the formatted result to
    ``_iast_propagate_prompt_template_format_inner`` and returns what it returns.
    """
    formatted = func(*args, **kwargs)
    return _iast_propagate_prompt_template_format_inner(kwargs, formatted)
1203
+
1204
+
1205
async def iast_propagate_prompt_template_aformat(func, instance, args, kwargs):
    """
    Wrapper for PromptTemplate.aformat.

    Awaits the wrapped ``func`` with the original arguments, then passes the keyword
    arguments and the formatted result to
    ``_iast_propagate_prompt_template_format_inner`` and returns what it returns.
    """
    formatted = await func(*args, **kwargs)
    return _iast_propagate_prompt_template_format_inner(kwargs, formatted)
1211
+
1212
+
1213
def _iast_propagate_prompt_template_format_inner(kwargs, result):
    """
    Taint a formatted prompt if any of the template's keyword inputs is tainted.

    The first value in ``kwargs`` that has tainted ranges provides the source; the
    whole ``result`` is then tainted with that source.

    :param kwargs: keyword arguments that were passed to the format call.
    :param result: value returned by the wrapped format call.
    :return: the tainted result when a source was found; otherwise ``result``
        unchanged (also when IAST is not enabled for the request or when
        propagation fails).

    Exceptions raised while propagating are caught and reported through
    ``_set_iast_error_metric``; the original ``result`` is returned in that case.
    """
    # No logging of kwargs/result here: they may carry sensitive user data.
    try:
        if not asm_config.is_iast_request_enabled:
            return result

        # Imported lazily so the IAST taint-tracking module is only loaded when needed.
        from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
        from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject

        source = None
        for value in kwargs.values():
            ranges = get_tainted_ranges(value)
            if ranges:
                source = ranges[0].source
                break
        if source:
            return taint_pyobject(result, source.name, source.value, source.origin)
    except Exception as e:
        from ddtrace.appsec._iast._metrics import _set_iast_error_metric

        _set_iast_error_metric("IAST propagation error. langchain iast_propagate_prompt_template_format. {}".format(e))
    return result
1236
+
1132
1237
1238
def iast_propagate_output_parse(func, instance, args, kwargs):
    """
    Wrapper for an output parser's ``parse``.

    Calls the wrapped ``func`` with the original arguments, then passes the
    arguments and the parsed value to ``_iast_propagate_output_parse_inner`` and
    returns what it returns.
    """
    parsed = func(*args, **kwargs)
    return _iast_propagate_output_parse_inner(args, kwargs, parsed)
1241
+
1242
+
1243
async def iast_propagate_output_aparse(func, instance, args, kwargs):
    """
    Wrapper for an output parser's ``aparse``.

    Awaits the wrapped ``func`` with the original arguments, then passes the
    arguments and the parsed value to ``_iast_propagate_output_parse_inner`` and
    returns what it returns.
    """
    parsed = await func(*args, **kwargs)
    return _iast_propagate_output_parse_inner(args, kwargs, parsed)
1246
+
1247
+
1248
+ def _iast_propagate_output_parse_inner (args , kwargs , result ):
1249
+ print (f"_iast_propagate_output_parse_inner: { args } , { kwargs } , { result } " )
1134
1250
try :
1251
+ from ddtrace .appsec ._iast ._taint_tracking ._taint_objects import get_tainted_ranges
1252
+ from ddtrace .appsec ._iast ._taint_tracking ._taint_objects import taint_pyobject
1253
+
1135
1254
try :
1136
1255
from langchain_core .agents import AgentAction
1137
1256
from langchain_core .agents import AgentFinish
@@ -1141,12 +1260,17 @@ def taint_parser_output(func, instance, args, kwargs):
1141
1260
ranges = get_tainted_ranges (args [0 ])
1142
1261
if ranges :
1143
1262
source = ranges [0 ].source
1263
+ print ("WILL TAINT" )
1144
1264
if isinstance (result , AgentAction ):
1265
+ print ("TAINTED TOOL INPUT" )
1145
1266
result .tool_input = taint_pyobject (result .tool_input , source .name , source .value , source .origin )
1146
1267
elif isinstance (result , AgentFinish ) and "output" in result .return_values :
1268
+ print ("TAINTED OUTPUT" )
1147
1269
values = result .return_values
1148
1270
values ["output" ] = taint_pyobject (values ["output" ], source .name , source .value , source .origin )
1149
1271
except Exception as e :
1272
+ from ddtrace .appsec ._iast ._metrics import _set_iast_error_metric
1273
+
1150
1274
_set_iast_error_metric ("IAST propagation error. langchain taint_parser_output. {}" .format (e ))
1151
1275
1152
1276
return result
0 commit comments