Skip to content

Commit e170cc9

Browse files
committed
feat(iast): add support for langchain v0.1.0+
1 parent 225cacb commit e170cc9

File tree

7 files changed

+170
-32
lines changed

7 files changed

+170
-32
lines changed

ddtrace/contrib/internal/langchain/patch.py

+89-4
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
196196
span.set_tag_str("langchain.request.%s.parameters.%s" % (llm_provider, param), str(val))
197197

198198
completions = func(*args, **kwargs)
199+
200+
_iast_taint_llm_output(prompts, completions)
201+
199202
if _is_openai_llm_instance(instance):
200203
_tag_openai_token_usage(span, completions.llm_output)
201204

@@ -942,6 +945,57 @@ async def traced_base_tool_ainvoke(langchain, pin, func, instance, args, kwargs)
942945
return tool_output
943946

944947

948+
def _iast_taint_llm_output(prompts, completions):
    """
    Propagate IAST taint from LLM prompts to the generated completions.

    Character ranges cannot be mapped meaningfully through an LLM, so no range
    propagation is attempted: the source of the first tainted range found in any
    string prompt is used to taint the whole text of every generation.

    Nothing is done when IAST is disabled, when ``prompts`` is not a list/tuple,
    when ``completions`` has no list-valued ``generations`` attribute, or when no
    prompt carries taint. Any exception raised while propagating is reported
    through the IAST error metric instead of reaching the caller.

    :param prompts: prompts passed to the LLM; non-string entries are ignored.
    :param completions: LLM result whose ``generations`` is iterated as a list of
        iterables of objects exposing a ``text`` attribute. Modified in place.
    """
    # Cheap guards first, evaluated in this order so `completions` is only probed
    # once IAST is known to be enabled and `prompts` has a usable shape.
    if (
        not asm_config._iast_enabled
        or not isinstance(prompts, (tuple, list))
        or not hasattr(completions, "generations")
    ):
        return

    try:
        generation_groups = completions.generations
        if not isinstance(generation_groups, list):
            return

        # Imported lazily so the taint-tracking machinery is only loaded when needed.
        from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
        from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject

        # Locate the first tainted string prompt and remember where its taint came from.
        taint_source = None
        for candidate in prompts:
            if isinstance(candidate, str):
                candidate_ranges = get_tainted_ranges(candidate)
                if candidate_ranges:
                    taint_source = candidate_ranges[0].source
                    break
        if not taint_source:
            return

        # Replace each generation's text with a fully tainted copy.
        for generation_group in generation_groups:
            for generation in generation_group:
                if not hasattr(generation, "text"):
                    continue
                original_text = generation.text
                if isinstance(original_text, str):
                    generation.text = taint_pyobject(
                        pyobject=original_text,
                        source_name=taint_source.name,
                        source_value=taint_source.value,
                        source_origin=taint_source.origin,
                    )
    except Exception as e:
        from ddtrace.appsec._iast._metrics import _set_iast_error_metric

        _set_iast_error_metric("IAST propagation error. langchain _iast_taint_llm_output. {}".format(e))
997+
998+
945999
def _patch_embeddings_and_vectorstores():
9461000
"""
9471001
Text embedding models override two abstract base methods instead of super calls,
@@ -1081,10 +1135,15 @@ def patch():
10811135
if asm_config._iast_enabled:
10821136
from ddtrace.appsec._iast._metrics import _set_iast_error_metric
10831137

1138+
wrap("langchain_core", "prompts.prompt.PromptTemplate.format", iast_propagate_prompt_template_format)
1139+
10841140
def wrap_output_parser(module, parser):
    """
    Wrap ``<parser>.parse`` in ``module`` with ``taint_parser_output``.

    The wrap is skipped when the target is already a ``wrapt.ObjectProxy``, so
    calling this more than once does not stack wrappers.

    :param module: module (or module path) that contains the parser class.
    :param parser: dotted path of the parser class inside ``module``.
    """
    parse_path = "%s.parse" % parser
    # Ensure not double patched. Nothing is written to stdout here: this runs
    # inside the instrumented application while the integration is being patched.
    if not isinstance(deep_getattr(module, parse_path), wrapt.ObjectProxy):
        wrap(module, parse_path, taint_parser_output)
10881147

10891148
try:
10901149
with_agent_output_parser(wrap_output_parser)
@@ -1125,13 +1184,37 @@ def unpatch():
11251184
delattr(langchain, "_datadog_integration")
11261185

11271186

1128-
def taint_parser_output(func, instance, args, kwargs):
1129-
from ddtrace.appsec._iast._metrics import _set_iast_error_metric
1130-
from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
1131-
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject
1187+
def iast_propagate_prompt_template_format(func, instance, args, kwargs):
    """
    Wrapper that taints a formatted prompt when any keyword argument is tainted.

    The wrapped callable is always invoked first. If IAST is enabled for the
    current request, the keyword argument values are scanned in order and the
    source of the first tainted range found is used to taint the whole result.
    Positional arguments are forwarded to ``func`` but are not inspected.

    Any exception raised while propagating is reported through the IAST error
    metric and the untainted result is returned.

    :param func: the wrapped callable.
    :param instance: object the wrapped callable is bound to (unused).
    :param args: positional arguments forwarded to ``func``.
    :param kwargs: keyword arguments forwarded to ``func`` and scanned for taint.
    :return: the result of ``func``, tainted when a tainted input was found.
    """
    formatted = func(*args, **kwargs)
    try:
        if asm_config.is_iast_request_enabled:
            # Imported lazily so the taint-tracking machinery is only loaded when needed.
            from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
            from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject

            for candidate in kwargs.values():
                candidate_ranges = get_tainted_ranges(candidate)
                if not candidate_ranges:
                    continue
                # Only the first tainted value is considered; its first range's
                # source stands in for the taint of the whole formatted string.
                taint_source = candidate_ranges[0].source
                if taint_source:
                    return taint_pyobject(formatted, taint_source.name, taint_source.value, taint_source.origin)
                break

    except Exception as e:
        from ddtrace.appsec._iast._metrics import _set_iast_error_metric

        _set_iast_error_metric("IAST propagation error. langchain iast_propagate_prompt_template_format. {}".format(e))
    return formatted
1210+
11321211

1212+
def taint_parser_output(func, instance, args, kwargs):
11331213
result = func(*args, **kwargs)
11341214
try:
1215+
from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
1216+
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject
1217+
11351218
try:
11361219
from langchain_core.agents import AgentAction
11371220
from langchain_core.agents import AgentFinish
@@ -1147,6 +1230,8 @@ def taint_parser_output(func, instance, args, kwargs):
11471230
values = result.return_values
11481231
values["output"] = taint_pyobject(values["output"], source.name, source.value, source.origin)
11491232
except Exception as e:
1233+
from ddtrace.appsec._iast._metrics import _set_iast_error_metric
1234+
11501235
_set_iast_error_metric("IAST propagation error. langchain taint_parser_output. {}".format(e))
11511236

11521237
return result

hatch.toml

+39
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,45 @@ fastapi = ["==0.94.1"]
571571
python = ["3.8", "3.10", "3.13"]
572572
fastapi = ["~=0.114.2"]
573573

574+
## ASM appsec_integrations_langchain
575+
576+
[envs.appsec_integrations_langchain]
577+
template = "appsec_integrations_langchain"
578+
dependencies = [
579+
"pytest",
580+
"pytest-cov",
581+
"langchain{matrix:langchain:}",
582+
"langchain-experimental{matrix:langchain-experimental:}",
583+
]
584+
585+
[envs.appsec_integrations_langchain.env-vars]
586+
DD_TRACE_AGENT_URL = "http://testagent:9126"
587+
_DD_IAST_PATCH_MODULES = "benchmarks.,tests.appsec."
588+
DD_IAST_REQUEST_SAMPLING = "100"
589+
DD_IAST_DEDUPLICATION_ENABLED = "false"
590+
591+
[envs.appsec_integrations_langchain.scripts]
592+
test = [
593+
"uname -a",
594+
"pip freeze",
595+
"python -m pytest -vvv {args:tests/appsec/integrations/langchain_tests/}",
596+
]
597+
598+
[[envs.appsec_integrations_langchain.matrix]]
599+
python = ["3.9", "3.10", "3.11", "3.12", "3.13"]
600+
langchain = ["~=0.1"]
601+
langchain-experimental = ["~=0.1"]
602+
603+
[[envs.appsec_integrations_langchain.matrix]]
604+
python = ["3.9", "3.10", "3.11", "3.12", "3.13"]
605+
langchain = ["~=0.2"]
606+
langchain-experimental = ["~=0.2"]
607+
608+
[[envs.appsec_integrations_langchain.matrix]]
609+
python = ["3.9", "3.10", "3.11", "3.12", "3.13"]
610+
langchain = ["~=0.3"]
611+
langchain-experimental = ["~=0.3"]
612+
574613
## ASM FastAPI
575614

576615
[envs.appsec_threats_fastapi]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- |
4+
Code Security: Adds IAST support for langchain v0.1.0 and above.

tests/appsec/iast/iast_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def get_line_and_hash(label: Text, vuln_type: Text, filename=None, fixed_line=No
5252
def _iast_patched_module_and_patched_source(module_name, new_module_object=False):
5353
module = importlib.import_module(module_name)
5454
module_path, patched_source = astpatch_module(module)
55+
assert patched_source is not None
5556
compiled_code = compile(patched_source, module_path, "exec")
5657
module_changed = types.ModuleType(module_name) if new_module_object else module
5758
exec(compiled_code, module_changed.__dict__)

tests/appsec/integrations/fixtures/patch_langchain.py

-13
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from ddtrace.appsec._iast import enable_iast_propagation
2+
from ddtrace.appsec._iast._patch_modules import patch_iast
3+
from tests.utils import override_env
4+
from tests.utils import override_global_config
5+
6+
7+
# `pytest` automatically calls this function once when tests are run.
8+
def pytest_configure():
    """
    Patch IAST and enable IAST propagation under test-specific overrides.

    While patching runs, the global config is overridden to enable IAST, disable
    deduplication and set request sampling to 100.0, and ``_DD_IAST_PATCH_MODULES``
    is overridden to ``tests.appsec.integrations``.
    """
    iast_settings = dict(
        _iast_enabled=True,
        _iast_deduplication_enabled=False,
        _iast_request_sampling=100.0,
    )
    patch_env = dict(_DD_IAST_PATCH_MODULES="tests.appsec.integrations")

    # Same nesting order as a single two-item `with`: config first, then env.
    with override_global_config(iast_settings):
        with override_env(patch_env):
            patch_iast()
            enable_iast_propagation()
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,39 @@
1-
import pytest
1+
from langchain.agents import AgentType
2+
from langchain.agents import initialize_agent
3+
from langchain_community.tools.shell.tool import ShellTool
4+
from langchain_core.language_models.fake import FakeListLLM
25

36
from ddtrace.appsec._iast.constants import VULN_CMDI
4-
from ddtrace.internal.module import is_module_installed
5-
from tests.appsec.iast.conftest import iast_context_defaults # noqa: F401
6-
from tests.appsec.iast.iast_utils import _iast_patched_module
7+
from tests.appsec.iast.conftest import iast_span_defaults # noqa: F401
78
from tests.appsec.iast.iast_utils import get_line_and_hash
89
from tests.appsec.iast.taint_sinks.conftest import _get_span_report
910
from tests.utils import override_env
1011

1112

12-
FIXTURES_PATH = "tests/appsec/integrations/fixtures/patch_langchain.py"
13-
FIXTURES_MODULE = "tests.appsec.integrations.fixtures.patch_langchain"
14-
1513
with override_env({"DD_IAST_ENABLED": "True"}):
1614
from ddtrace.appsec._iast._taint_tracking import OriginType
1715
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject
1816

17+
TEST_FILE = "tests/appsec/integrations/langchain_tests/test_iast_langchain.py"
18+
19+
20+
def test_openai_llm_appsec_iast_cmdi(iast_span_defaults): # noqa: F811
21+
responses = ["Action: terminal\nAction Input: echo Hello World", "Final Answer: 4"]
22+
llm = FakeListLLM(responses=responses)
23+
shell = ShellTool()
24+
shell_chain = initialize_agent([shell], llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True)
1925

20-
@pytest.mark.skipif(not is_module_installed("langchain"), reason="Langchain tests work on 3.9 or higher")
21-
def test_openai_llm_appsec_iast_cmdi(iast_context_defaults): # noqa: F811
22-
mod = _iast_patched_module(FIXTURES_MODULE)
2326
string_to_taint = "I need to use the terminal tool to print a Hello World"
2427
prompt = taint_pyobject(
2528
pyobject=string_to_taint,
2629
source_name="test_openai_llm_appsec_iast_cmdi",
2730
source_value=string_to_taint,
2831
source_origin=OriginType.PARAMETER,
2932
)
30-
res = mod.patch_langchain(prompt)
31-
assert res == "4"
33+
34+
# label test_openai_llm_appsec_iast_cmdi
35+
res = shell_chain.invoke(prompt)
36+
assert res["output"] == "4"
3237

3338
span_report = _get_span_report()
3439
assert span_report
@@ -48,9 +53,9 @@ def test_openai_llm_appsec_iast_cmdi(iast_context_defaults): # noqa: F811
4853
assert source["origin"] == OriginType.PARAMETER
4954
assert "value" not in source.keys()
5055

51-
line, hash_value = get_line_and_hash("test_openai_llm_appsec_iast_cmdi", VULN_CMDI, filename=FIXTURES_PATH)
52-
assert vulnerability["location"]["path"] == FIXTURES_PATH
56+
line, hash_value = get_line_and_hash("test_openai_llm_appsec_iast_cmdi", VULN_CMDI, filename=TEST_FILE)
57+
assert vulnerability["location"]["path"] == TEST_FILE
5358
assert vulnerability["location"]["line"] == line
54-
assert vulnerability["location"]["method"] == "patch_langchain"
59+
assert vulnerability["location"]["method"] == "test_openai_llm_appsec_iast_cmdi"
5560
assert vulnerability["location"]["class_name"] == ""
5661
assert vulnerability["hash"] == hash_value

0 commit comments

Comments
 (0)