Skip to content

Commit 48937b5

Browse files
committed
feat(iast): add support for langchain v0.1.0+
1 parent 0ee168f commit 48937b5

File tree

10 files changed

+352
-85
lines changed

10 files changed

+352
-85
lines changed

ddtrace/appsec/_iast/_ast/iastpatch.c

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ static size_t cached_packages_count = 0;
1818

1919
/* Static Lists */
2020
static const char* static_allowlist[] = {
21-
"jinja2.", "pygments.", "multipart.", "sqlalchemy.", "python_multipart.", "attrs.",
22-
"jsonschema.", "s3fs.", "mysql.", "pymysql.", "markupsafe.", "werkzeug.utils."
21+
"jinja2.", "pygments.", "multipart.", "sqlalchemy.", "python_multipart.", "attrs.", "jsonschema.",
22+
"s3fs.", "mysql.", "pymysql.", "markupsafe.", "werkzeug.utils.", "langchain_core."
2323
};
2424
static const size_t static_allowlist_count = sizeof(static_allowlist) / sizeof(static_allowlist[0]);
2525

ddtrace/contrib/internal/langchain/patch.py

+121-5
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
196196
span.set_tag_str("langchain.request.%s.parameters.%s" % (llm_provider, param), str(val))
197197

198198
completions = func(*args, **kwargs)
199+
200+
_iast_taint_llm_output(prompts, completions)
201+
199202
if _is_openai_llm_instance(instance):
200203
_tag_openai_token_usage(span, completions.llm_output)
201204

@@ -252,6 +255,9 @@ async def traced_llm_agenerate(langchain, pin, func, instance, args, kwargs):
252255
span.set_tag_str("langchain.request.%s.parameters.%s" % (llm_provider, param), str(val))
253256

254257
completions = await func(*args, **kwargs)
258+
259+
_iast_taint_llm_output(prompts, completions)
260+
255261
if _is_openai_llm_instance(instance):
256262
_tag_openai_token_usage(span, completions.llm_output)
257263

@@ -942,6 +948,57 @@ async def traced_base_tool_ainvoke(langchain, pin, func, instance, args, kwargs)
942948
return tool_output
943949

944950

951+
def _iast_taint_llm_output(prompts, completions):
952+
"""
953+
Taints the output of an LLM call if its inputs are tainted.
954+
955+
Range propagation does not make sense in LLMs. So we get the first source in inputs, if any,
956+
and taint the full output with that source.
957+
"""
958+
if not asm_config._iast_enabled:
959+
return
960+
if not isinstance(prompts, (tuple, list)):
961+
return
962+
if not hasattr(completions, "generations"):
963+
return
964+
try:
965+
generations = completions.generations
966+
if not isinstance(generations, list):
967+
return
968+
969+
from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
970+
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject
971+
972+
source = None
973+
for prompt in prompts:
974+
if not isinstance(prompt, str):
975+
continue
976+
tainted_ranges = get_tainted_ranges(prompt)
977+
if tainted_ranges:
978+
source = tainted_ranges[0].source
979+
break
980+
if not source:
981+
return
982+
for gens in generations:
983+
for gen in gens:
984+
if not hasattr(gen, "text"):
985+
continue
986+
text = gen.text
987+
if not isinstance(text, str):
988+
continue
989+
new_text = taint_pyobject(
990+
pyobject=text,
991+
source_name=source.name,
992+
source_value=source.value,
993+
source_origin=source.origin,
994+
)
995+
setattr(gen, "text", new_text)
996+
except Exception as e:
997+
from ddtrace.appsec._iast._metrics import _set_iast_error_metric
998+
999+
_set_iast_error_metric("IAST propagation error. langchain _iast_taint_llm_output. {}".format(e))
1000+
1001+
9451002
def _patch_embeddings_and_vectorstores():
9461003
"""
9471004
Text embedding models override two abstract base methods instead of super calls,
@@ -1081,10 +1138,15 @@ def patch():
10811138
if asm_config._iast_enabled:
10821139
from ddtrace.appsec._iast._metrics import _set_iast_error_metric
10831140

1141+
wrap("langchain_core", "prompts.prompt.PromptTemplate.format", iast_propagate_prompt_template_format)
1142+
wrap("langchain_core", "prompts.prompt.PromptTemplate.aformat", iast_propagate_prompt_template_aformat)
1143+
10841144
def wrap_output_parser(module, parser):
10851145
# Ensure not double patched
10861146
if not isinstance(deep_getattr(module, "%s.parse" % parser), wrapt.ObjectProxy):
1087-
wrap(module, "%s.parse" % parser, taint_parser_output)
1147+
wrap(module, "%s.parse" % parser, iast_propagate_output_parse)
1148+
if not isinstance(deep_getattr(module, "%s.aparse" % parser), wrapt.ObjectProxy):
1149+
wrap(module, "%s.aparse" % parser, iast_propagate_output_aparse)
10881150

10891151
try:
10901152
with_agent_output_parser(wrap_output_parser)
@@ -1114,6 +1176,7 @@ def unpatch():
11141176
unwrap(langchain_core.language_models.llms.BaseLLM, "astream")
11151177
unwrap(langchain_core.tools.BaseTool, "invoke")
11161178
unwrap(langchain_core.tools.BaseTool, "ainvoke")
1179+
11171180
if langchain_openai:
11181181
unwrap(langchain_openai.OpenAIEmbeddings, "embed_documents")
11191182
if langchain_pinecone:
@@ -1122,16 +1185,67 @@ def unpatch():
11221185
if langchain_community:
11231186
_unpatch_embeddings_and_vectorstores()
11241187

1188+
if asm_config._iast_enabled:
1189+
unwrap(langchain_core.prompts.prompt.PromptTemplate, "format")
1190+
unwrap(langchain_core.prompts.prompt.PromptTemplate, "aformat")
1191+
11251192
delattr(langchain, "_datadog_integration")
11261193

11271194

1128-
def taint_parser_output(func, instance, args, kwargs):
1129-
from ddtrace.appsec._iast._metrics import _set_iast_error_metric
1130-
from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
1131-
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject
1195+
def iast_propagate_prompt_template_format(func, instance, args, kwargs):
1196+
"""
1197+
Propagate taint in PromptTemplate.format, from any input, to the output.
1198+
"""
1199+
result = func(*args, **kwargs)
1200+
return _iast_propagate_prompt_template_format_inner(kwargs, result)
1201+
1202+
1203+
async def iast_propagate_prompt_template_aformat(func, instance, args, kwargs):
1204+
"""
1205+
Propagate taint in PromptTemplate.aformat, from any input, to the output.
1206+
"""
1207+
result = await func(*args, **kwargs)
1208+
return _iast_propagate_prompt_template_format_inner(kwargs, result)
1209+
1210+
1211+
def _iast_propagate_prompt_template_format_inner(kwargs, result):
1212+
try:
1213+
if not asm_config.is_iast_request_enabled:
1214+
return result
1215+
1216+
from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
1217+
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject
1218+
1219+
source = None
1220+
for value in kwargs.values():
1221+
ranges = get_tainted_ranges(value)
1222+
if ranges:
1223+
source = ranges[0].source
1224+
break
1225+
if source:
1226+
return taint_pyobject(result, source.name, source.value, source.origin)
1227+
except Exception as e:
1228+
from ddtrace.appsec._iast._metrics import _set_iast_error_metric
1229+
1230+
_set_iast_error_metric("IAST propagation error. langchain iast_propagate_prompt_template_format. {}".format(e))
1231+
return result
1232+
11321233

1234+
def iast_propagate_output_parse(func, instance, args, kwargs):
11331235
result = func(*args, **kwargs)
1236+
return _iast_propagate_output_parse_inner(args, kwargs, result)
1237+
1238+
1239+
async def iast_propagate_output_aparse(func, instance, args, kwargs):
1240+
result = await func(*args, **kwargs)
1241+
return _iast_propagate_output_parse_inner(args, kwargs, result)
1242+
1243+
1244+
def _iast_propagate_output_parse_inner(args, kwargs, result):
11341245
try:
1246+
from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
1247+
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject
1248+
11351249
try:
11361250
from langchain_core.agents import AgentAction
11371251
from langchain_core.agents import AgentFinish
@@ -1147,6 +1261,8 @@ def taint_parser_output(func, instance, args, kwargs):
11471261
values = result.return_values
11481262
values["output"] = taint_pyobject(values["output"], source.name, source.value, source.origin)
11491263
except Exception as e:
1264+
from ddtrace.appsec._iast._metrics import _set_iast_error_metric
1265+
11501266
_set_iast_error_metric("IAST propagation error. langchain taint_parser_output. {}".format(e))
11511267

11521268
return result

hatch.toml

+41
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,47 @@ fastapi = ["==0.94.1"]
601601
python = ["3.8", "3.10", "3.13"]
602602
fastapi = ["~=0.114.2"]
603603

604+
## ASM appsec_integrations_langchain
605+
606+
[envs.appsec_integrations_langchain]
607+
template = "appsec_integrations_langchain"
608+
dependencies = [
609+
"pytest",
610+
"pytest-asyncio",
611+
"pytest-cov",
612+
"hypothesis",
613+
"langchain{matrix:langchain:}",
614+
"langchain-experimental{matrix:langchain-experimental:}",
615+
]
616+
617+
[envs.appsec_integrations_langchain.env-vars]
618+
DD_TRACE_AGENT_URL = "http://testagent:9126"
619+
_DD_IAST_PATCH_MODULES = "benchmarks.,tests.appsec."
620+
DD_IAST_REQUEST_SAMPLING = "100"
621+
DD_IAST_DEDUPLICATION_ENABLED = "false"
622+
623+
[envs.appsec_integrations_langchain.scripts]
624+
test = [
625+
"uname -a",
626+
"pip freeze",
627+
"python -m pytest -vvv {args:tests/appsec/integrations/langchain_tests/}",
628+
]
629+
630+
[[envs.appsec_integrations_langchain.matrix]]
631+
python = ["3.9", "3.10", "3.11", "3.12", "3.13"]
632+
langchain = ["~=0.1"]
633+
langchain-experimental = ["~=0.1"]
634+
635+
[[envs.appsec_integrations_langchain.matrix]]
636+
python = ["3.9", "3.10", "3.11", "3.12", "3.13"]
637+
langchain = ["~=0.2"]
638+
langchain-experimental = ["~=0.2"]
639+
640+
[[envs.appsec_integrations_langchain.matrix]]
641+
python = ["3.9", "3.10", "3.11", "3.12", "3.13"]
642+
langchain = ["~=0.3"]
643+
langchain-experimental = ["~=0.3"]
644+
604645
## ASM FastAPI
605646

606647
[envs.appsec_threats_fastapi]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- |
4+
Code Security: IAST support for langchain v0.1.0 and above.

tests/appsec/iast/conftest.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def _start_iast_context_and_oce(span=None):
5252
if oce.acquire_request(span):
5353
start_iast_context()
5454
request_iast_enabled = True
55+
5556
set_iast_request_enabled(request_iast_enabled)
5657

5758

@@ -61,13 +62,6 @@ def _end_iast_context_and_oce(span=None):
6162

6263

6364
def iast_context(env, request_sampling=100.0, deduplication=False, asm_enabled=False):
64-
try:
65-
from ddtrace.contrib.internal.langchain.patch import patch as langchain_patch
66-
from ddtrace.contrib.internal.langchain.patch import unpatch as langchain_unpatch
67-
except Exception:
68-
langchain_patch = lambda: True # noqa: E731
69-
langchain_unpatch = lambda: True # noqa: E731
70-
7165
class MockSpan:
7266
_trace_id_64bits = 17577308072598193742
7367

@@ -87,7 +81,6 @@ class MockSpan:
8781
cmdi_patch()
8882
header_injection_patch()
8983
code_injection_patch()
90-
langchain_patch()
9184
patch_common_modules()
9285
yield
9386
unpatch_common_modules()
@@ -97,7 +90,6 @@ class MockSpan:
9790
cmdi_unpatch()
9891
header_injection_unpatch()
9992
code_injection_unpatch()
100-
langchain_unpatch()
10193
_end_iast_context_and_oce()
10294

10395

tests/appsec/iast/iast_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def get_line_and_hash(label: Text, vuln_type: Text, filename=None, fixed_line=No
5353
def _iast_patched_module_and_patched_source(module_name, new_module_object=False):
5454
module = importlib.import_module(module_name)
5555
module_path, patched_source = astpatch_module(module)
56+
assert patched_source is not None
5657
compiled_code = compile(patched_source, module_path, "exec")
5758
module_changed = types.ModuleType(module_name) if new_module_object else module
5859
exec(compiled_code, module_changed.__dict__)

tests/appsec/integrations/fixtures/patch_langchain.py

-13
This file was deleted.

tests/appsec/integrations/flask_tests/test_iast_langchain.py

-56
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from ddtrace.appsec._iast import enable_iast_propagation
2+
from ddtrace.appsec._iast._patch_modules import patch_iast
3+
from ddtrace.contrib.internal.langchain.patch import patch as langchain_patch
4+
from tests.utils import override_env
5+
from tests.utils import override_global_config
6+
7+
8+
# `pytest` automatically calls this function once when tests are run.
9+
def pytest_configure():
10+
with override_global_config(
11+
dict(
12+
_iast_enabled=True,
13+
_iast_deduplication_enabled=False,
14+
_iast_request_sampling=100.0,
15+
)
16+
), override_env(dict(_DD_IAST_PATCH_MODULES="tests.appsec.integrations")):
17+
patch_iast()
18+
enable_iast_propagation()
19+
langchain_patch()

0 commit comments

Comments
 (0)