@@ -196,6 +196,9 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
196
196
span .set_tag_str ("langchain.request.%s.parameters.%s" % (llm_provider , param ), str (val ))
197
197
198
198
completions = func (* args , ** kwargs )
199
+
200
+ _iast_taint_llm_output (prompts , completions )
201
+
199
202
if _is_openai_llm_instance (instance ):
200
203
_tag_openai_token_usage (span , completions .llm_output )
201
204
@@ -252,6 +255,9 @@ async def traced_llm_agenerate(langchain, pin, func, instance, args, kwargs):
252
255
span .set_tag_str ("langchain.request.%s.parameters.%s" % (llm_provider , param ), str (val ))
253
256
254
257
completions = await func (* args , ** kwargs )
258
+
259
+ _iast_taint_llm_output (prompts , completions )
260
+
255
261
if _is_openai_llm_instance (instance ):
256
262
_tag_openai_token_usage (span , completions .llm_output )
257
263
@@ -942,6 +948,57 @@ async def traced_base_tool_ainvoke(langchain, pin, func, instance, args, kwargs)
942
948
return tool_output
943
949
944
950
951
def _iast_taint_llm_output(prompts, completions):
    """
    Mark the generations of an LLM call as tainted when one of its prompts is tainted.

    Range-level propagation is not meaningful for LLM output, so the source of the first
    tainted range found among the string prompts is applied to the full text of every
    generation.

    Nothing is done when IAST is disabled, when ``prompts`` is not a list/tuple, when
    ``completions`` has no ``generations`` list, or when no string prompt has tainted ranges.
    Exceptions raised inside the propagation step are reported through the IAST error
    metric rather than re-raised.
    """
    if (
        not asm_config._iast_enabled
        or not isinstance(prompts, (tuple, list))
        or not hasattr(completions, "generations")
    ):
        return
    try:
        all_generations = completions.generations
        if not isinstance(all_generations, list):
            return

        from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
        from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject

        # Only the first tainted prompt matters: its first range provides the source.
        origin_source = None
        for candidate in prompts:
            if isinstance(candidate, str):
                ranges = get_tainted_ranges(candidate)
                if ranges:
                    origin_source = ranges[0].source
                    break
        if not origin_source:
            return

        for generation_group in all_generations:
            for generation in generation_group:
                if not hasattr(generation, "text"):
                    continue
                current_text = generation.text
                if isinstance(current_text, str):
                    # Replace the text in place with its tainted counterpart.
                    generation.text = taint_pyobject(
                        pyobject=current_text,
                        source_name=origin_source.name,
                        source_value=origin_source.value,
                        source_origin=origin_source.origin,
                    )
    except Exception as e:
        from ddtrace.appsec._iast._metrics import _set_iast_error_metric

        _set_iast_error_metric("IAST propagation error. langchain _iast_taint_llm_output. {}".format(e))
1000
+
1001
+
945
1002
def _patch_embeddings_and_vectorstores ():
946
1003
"""
947
1004
Text embedding models override two abstract base methods instead of super calls,
@@ -1081,10 +1138,15 @@ def patch():
1081
1138
if asm_config ._iast_enabled :
1082
1139
from ddtrace .appsec ._iast ._metrics import _set_iast_error_metric
1083
1140
1141
+ wrap ("langchain_core" , "prompts.prompt.PromptTemplate.format" , iast_propagate_prompt_template_format )
1142
+ wrap ("langchain_core" , "prompts.prompt.PromptTemplate.aformat" , iast_propagate_prompt_template_aformat )
1143
+
1084
1144
def wrap_output_parser (module , parser ):
1085
1145
# Ensure not double patched
1086
1146
if not isinstance (deep_getattr (module , "%s.parse" % parser ), wrapt .ObjectProxy ):
1087
- wrap (module , "%s.parse" % parser , taint_parser_output )
1147
+ wrap (module , "%s.parse" % parser , iast_propagate_output_parse )
1148
+ if not isinstance (deep_getattr (module , "%s.aparse" % parser ), wrapt .ObjectProxy ):
1149
+ wrap (module , "%s.aparse" % parser , iast_propagate_output_aparse )
1088
1150
1089
1151
try :
1090
1152
with_agent_output_parser (wrap_output_parser )
@@ -1114,6 +1176,7 @@ def unpatch():
1114
1176
unwrap (langchain_core .language_models .llms .BaseLLM , "astream" )
1115
1177
unwrap (langchain_core .tools .BaseTool , "invoke" )
1116
1178
unwrap (langchain_core .tools .BaseTool , "ainvoke" )
1179
+
1117
1180
if langchain_openai :
1118
1181
unwrap (langchain_openai .OpenAIEmbeddings , "embed_documents" )
1119
1182
if langchain_pinecone :
@@ -1122,16 +1185,67 @@ def unpatch():
1122
1185
if langchain_community :
1123
1186
_unpatch_embeddings_and_vectorstores ()
1124
1187
1188
+ if asm_config ._iast_enabled :
1189
+ unwrap (langchain_core .prompts .prompt .PromptTemplate , "format" )
1190
+ unwrap (langchain_core .prompts .prompt .PromptTemplate , "aformat" )
1191
+
1125
1192
delattr (langchain , "_datadog_integration" )
1126
1193
1127
1194
1128
- def taint_parser_output (func , instance , args , kwargs ):
1129
- from ddtrace .appsec ._iast ._metrics import _set_iast_error_metric
1130
- from ddtrace .appsec ._iast ._taint_tracking ._taint_objects import get_tainted_ranges
1131
- from ddtrace .appsec ._iast ._taint_tracking ._taint_objects import taint_pyobject
1195
def iast_propagate_prompt_template_format(func, instance, args, kwargs):
    """
    Wrapper for PromptTemplate.format that propagates taint from any input to the output.

    Calls the wrapped ``func`` with the original arguments, then hands the keyword
    arguments and the formatted result to the shared propagation helper.
    """
    formatted = func(*args, **kwargs)
    return _iast_propagate_prompt_template_format_inner(kwargs, formatted)
1201
+
1202
+
1203
async def iast_propagate_prompt_template_aformat(func, instance, args, kwargs):
    """
    Wrapper for PromptTemplate.aformat that propagates taint from any input to the output.

    Awaits the wrapped ``func`` with the original arguments, then hands the keyword
    arguments and the formatted result to the shared propagation helper.
    """
    formatted = await func(*args, **kwargs)
    return _iast_propagate_prompt_template_format_inner(kwargs, formatted)
1209
+
1210
+
1211
def _iast_propagate_prompt_template_format_inner(kwargs, result):
    """
    Taint ``result`` with the source of the first tainted value found in ``kwargs``.

    ``result`` is returned untouched when IAST is not enabled for the request, when no
    value in ``kwargs`` has tainted ranges, or when propagation raises (the error is then
    reported through the IAST error metric).
    """
    try:
        if not asm_config.is_iast_request_enabled:
            return result

        from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
        from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject

        # Lazily scan the values and stop at the first one that has tainted ranges.
        tainted_source = next(
            (found[0].source for found in map(get_tainted_ranges, kwargs.values()) if found),
            None,
        )
        if tainted_source:
            return taint_pyobject(result, tainted_source.name, tainted_source.value, tainted_source.origin)
    except Exception as e:
        from ddtrace.appsec._iast._metrics import _set_iast_error_metric

        _set_iast_error_metric("IAST propagation error. langchain iast_propagate_prompt_template_format. {}".format(e))
    return result
1232
+
1132
1233
1234
def iast_propagate_output_parse(func, instance, args, kwargs):
    """
    Wrapper for an output parser's ``parse`` method.

    Calls the wrapped ``func`` with the original arguments, then passes the arguments and
    the parsed result to the shared IAST propagation helper and returns what it returns.
    """
    parsed = func(*args, **kwargs)
    return _iast_propagate_output_parse_inner(args, kwargs, parsed)
1237
+
1238
+
1239
async def iast_propagate_output_aparse(func, instance, args, kwargs):
    """
    Wrapper for an output parser's ``aparse`` method.

    Awaits the wrapped ``func`` with the original arguments, then passes the arguments and
    the parsed result to the shared IAST propagation helper and returns what it returns.
    """
    parsed = await func(*args, **kwargs)
    return _iast_propagate_output_parse_inner(args, kwargs, parsed)
1242
+
1243
+
1244
+ def _iast_propagate_output_parse_inner (args , kwargs , result ):
1134
1245
try :
1246
+ from ddtrace .appsec ._iast ._taint_tracking ._taint_objects import get_tainted_ranges
1247
+ from ddtrace .appsec ._iast ._taint_tracking ._taint_objects import taint_pyobject
1248
+
1135
1249
try :
1136
1250
from langchain_core .agents import AgentAction
1137
1251
from langchain_core .agents import AgentFinish
@@ -1147,6 +1261,8 @@ def taint_parser_output(func, instance, args, kwargs):
1147
1261
values = result .return_values
1148
1262
values ["output" ] = taint_pyobject (values ["output" ], source .name , source .value , source .origin )
1149
1263
except Exception as e :
1264
+ from ddtrace .appsec ._iast ._metrics import _set_iast_error_metric
1265
+
1150
1266
_set_iast_error_metric ("IAST propagation error. langchain taint_parser_output. {}" .format (e ))
1151
1267
1152
1268
return result
0 commit comments