Skip to content

Commit 5a7bba4

Browse files
committed
feat(iast): add support for langchain v0.1.0+
1 parent 8922ed8 commit 5a7bba4

File tree

14 files changed

+366
-86
lines changed

14 files changed

+366
-86
lines changed

ddtrace/appsec/_iast/_ast/iastpatch.c

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ static size_t cached_packages_count = 0;
1818

1919
/* Static Lists */
2020
static const char* static_allowlist[] = {
21-
"jinja2.", "pygments.", "multipart.", "sqlalchemy.", "python_multipart.", "attrs.",
22-
"jsonschema.", "s3fs.", "mysql.", "pymysql.", "markupsafe.", "werkzeug.utils."
21+
"jinja2.", "pygments.", "multipart.", "sqlalchemy.", "python_multipart.", "attrs.", "jsonschema.",
22+
"s3fs.", "mysql.", "pymysql.", "markupsafe.", "werkzeug.utils.", "langchain_core."
2323
};
2424
static const size_t static_allowlist_count = sizeof(static_allowlist) / sizeof(static_allowlist[0]);
2525

ddtrace/appsec/_iast/_iast_request_context.py

+3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424

2525
def set_iast_reporter(iast_reporter: IastSpanReporter) -> None:
26+
print("set_iast_reporter")
2627
env = _get_iast_env()
2728
if env:
2829
env.iast_reporter = iast_reporter
@@ -33,7 +34,9 @@ def set_iast_reporter(iast_reporter: IastSpanReporter) -> None:
3334
def get_iast_reporter() -> Optional[IastSpanReporter]:
3435
env = _get_iast_env()
3536
if env:
37+
print("GOT ENV")
3638
return env.iast_reporter
39+
print("GOT NO ENV")
3740
return None
3841

3942

ddtrace/appsec/_iast/_overhead_control_engine.py

+1
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def has_quota(cls) -> bool:
6868
@classmethod
6969
def is_not_reported(cls, filename: Text, lineno: int) -> bool:
7070
if asm_config._iast_deduplication_enabled:
71+
print("DAMN DEDUP ENABLED")
7172
vulnerability_id = (filename, lineno)
7273
if vulnerability_id in cls._reported_vulnerabilities:
7374
return False

ddtrace/appsec/_iast/taint_sinks/_base.py

+2
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def _create_evidence_and_report(
189189
@classmethod
190190
def report(cls, evidence_value: TEXT_TYPES = "", dialect: Optional[str] = None) -> None:
191191
"""Build a IastSpanReporter instance to report it in the `AppSecIastSpanProcessor` as a string JSON"""
192+
print(f"report {cls.vulnerability_type}")
192193
if cls.acquire_quota():
193194
file_name = line_number = function_name = class_name = None
194195

@@ -198,6 +199,7 @@ def report(cls, evidence_value: TEXT_TYPES = "", dialect: Optional[str] = None)
198199
else:
199200
file_name, line_number, function_name, class_name = cls._compute_file_line()
200201
if file_name is None:
202+
print("NO LOCATION FILE NAME")
201203
cls.increment_quota()
202204
return
203205

ddtrace/appsec/_iast/taint_sinks/command_injection.py

-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ class CommandInjection(VulnerabilityBase):
4848

4949
def _iast_report_cmdi(shell_args: Union[str, List[str]]) -> None:
5050
report_cmdi = ""
51-
5251
try:
5352
if asm_config.is_iast_request_enabled:
5453
if CommandInjection.has_quota():

ddtrace/contrib/internal/langchain/patch.py

+129-5
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
196196
span.set_tag_str("langchain.request.%s.parameters.%s" % (llm_provider, param), str(val))
197197

198198
completions = func(*args, **kwargs)
199+
200+
_iast_taint_llm_output(prompts, completions)
201+
199202
if _is_openai_llm_instance(instance):
200203
_tag_openai_token_usage(span, completions.llm_output)
201204

@@ -252,6 +255,9 @@ async def traced_llm_agenerate(langchain, pin, func, instance, args, kwargs):
252255
span.set_tag_str("langchain.request.%s.parameters.%s" % (llm_provider, param), str(val))
253256

254257
completions = await func(*args, **kwargs)
258+
259+
_iast_taint_llm_output(prompts, completions)
260+
255261
if _is_openai_llm_instance(instance):
256262
_tag_openai_token_usage(span, completions.llm_output)
257263

@@ -942,6 +948,59 @@ async def traced_base_tool_ainvoke(langchain, pin, func, instance, args, kwargs)
942948
return tool_output
943949

944950

951+
def _iast_taint_llm_output(prompts, completions):
952+
"""
953+
Taints the output of an LLM call if its inputs are tainted.
954+
955+
Range propagation does not make sense in LLMs. So we get the first source in inputs, if any,
956+
and taint the full output with that source.
957+
"""
958+
print(f"_iast_taint_llm_output: prompts={prompts}, completions={completions}")
959+
if not asm_config._iast_enabled:
960+
return
961+
if not isinstance(prompts, (tuple, list)):
962+
return
963+
if not hasattr(completions, "generations"):
964+
return
965+
try:
966+
generations = completions.generations
967+
if not isinstance(generations, list):
968+
return
969+
970+
from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
971+
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject
972+
973+
source = None
974+
for prompt in prompts:
975+
if not isinstance(prompt, str):
976+
continue
977+
tainted_ranges = get_tainted_ranges(prompt)
978+
if tainted_ranges:
979+
source = tainted_ranges[0].source
980+
break
981+
if not source:
982+
return
983+
for gens in generations:
984+
for gen in gens:
985+
if not hasattr(gen, "text"):
986+
continue
987+
text = gen.text
988+
if not isinstance(text, str):
989+
continue
990+
print("TAINT LLM OUTPUT")
991+
new_text = taint_pyobject(
992+
pyobject=text,
993+
source_name=source.name,
994+
source_value=source.value,
995+
source_origin=source.origin,
996+
)
997+
setattr(gen, "text", new_text)
998+
except Exception as e:
999+
from ddtrace.appsec._iast._metrics import _set_iast_error_metric
1000+
1001+
_set_iast_error_metric("IAST propagation error. langchain _iast_taint_llm_output. {}".format(e))
1002+
1003+
9451004
def _patch_embeddings_and_vectorstores():
9461005
"""
9471006
Text embedding models override two abstract base methods instead of super calls,
@@ -1081,10 +1140,15 @@ def patch():
10811140
if asm_config._iast_enabled:
10821141
from ddtrace.appsec._iast._metrics import _set_iast_error_metric
10831142

1143+
wrap("langchain_core", "prompts.prompt.PromptTemplate.format", iast_propagate_prompt_template_format)
1144+
wrap("langchain_core", "prompts.prompt.PromptTemplate.aformat", iast_propagate_prompt_template_aformat)
1145+
10841146
def wrap_output_parser(module, parser):
10851147
# Ensure not double patched
10861148
if not isinstance(deep_getattr(module, "%s.parse" % parser), wrapt.ObjectProxy):
1087-
wrap(module, "%s.parse" % parser, taint_parser_output)
1149+
wrap(module, "%s.parse" % parser, iast_propagate_output_parse)
1150+
if not isinstance(deep_getattr(module, "%s.aparse" % parser), wrapt.ObjectProxy):
1151+
wrap(module, "%s.aparse" % parser, iast_propagate_output_aparse)
10881152

10891153
try:
10901154
with_agent_output_parser(wrap_output_parser)
@@ -1114,6 +1178,7 @@ def unpatch():
11141178
unwrap(langchain_core.language_models.llms.BaseLLM, "astream")
11151179
unwrap(langchain_core.tools.BaseTool, "invoke")
11161180
unwrap(langchain_core.tools.BaseTool, "ainvoke")
1181+
11171182
if langchain_openai:
11181183
unwrap(langchain_openai.OpenAIEmbeddings, "embed_documents")
11191184
if langchain_pinecone:
@@ -1122,16 +1187,70 @@ def unpatch():
11221187
if langchain_community:
11231188
_unpatch_embeddings_and_vectorstores()
11241189

1190+
if asm_config._iast_enabled:
1191+
unwrap(langchain_core.prompts.prompt.PromptTemplate, "format")
1192+
unwrap(langchain_core.prompts.prompt.PromptTemplate, "aformat")
1193+
11251194
delattr(langchain, "_datadog_integration")
11261195

11271196

1128-
def taint_parser_output(func, instance, args, kwargs):
1129-
from ddtrace.appsec._iast._metrics import _set_iast_error_metric
1130-
from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
1131-
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject
1197+
def iast_propagate_prompt_template_format(func, instance, args, kwargs):
1198+
"""
1199+
Propagate taint in PromptTemplate.format, from any input, to the output.
1200+
"""
1201+
result = func(*args, **kwargs)
1202+
return _iast_propagate_prompt_template_format_inner(kwargs, result)
1203+
1204+
1205+
async def iast_propagate_prompt_template_aformat(func, instance, args, kwargs):
1206+
"""
1207+
Propagate taint in PromptTemplate.aformat, from any input, to the output.
1208+
"""
1209+
result = await func(*args, **kwargs)
1210+
return _iast_propagate_prompt_template_format_inner(kwargs, result)
1211+
1212+
1213+
def _iast_propagate_prompt_template_format_inner(kwargs, result):
1214+
print(f"_iast_propagate_prompt_template_format_inner: {kwargs}, {result}")
1215+
try:
1216+
if not asm_config.is_iast_request_enabled:
1217+
return result
1218+
1219+
from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
1220+
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject
1221+
1222+
source = None
1223+
for value in kwargs.values():
1224+
ranges = get_tainted_ranges(value)
1225+
if ranges:
1226+
source = ranges[0].source
1227+
break
1228+
if source:
1229+
print("TAINTED TEMPLATE FORMAT")
1230+
return taint_pyobject(result, source.name, source.value, source.origin)
1231+
except Exception as e:
1232+
from ddtrace.appsec._iast._metrics import _set_iast_error_metric
1233+
1234+
_set_iast_error_metric("IAST propagation error. langchain iast_propagate_prompt_template_format. {}".format(e))
1235+
return result
1236+
11321237

1238+
def iast_propagate_output_parse(func, instance, args, kwargs):
11331239
result = func(*args, **kwargs)
1240+
return _iast_propagate_output_parse_inner(args, kwargs, result)
1241+
1242+
1243+
async def iast_propagate_output_aparse(func, instance, args, kwargs):
1244+
result = await func(*args, **kwargs)
1245+
return _iast_propagate_output_parse_inner(args, kwargs, result)
1246+
1247+
1248+
def _iast_propagate_output_parse_inner(args, kwargs, result):
1249+
print(f"_iast_propagate_output_parse_inner: {args}, {kwargs}, {result}")
11341250
try:
1251+
from ddtrace.appsec._iast._taint_tracking._taint_objects import get_tainted_ranges
1252+
from ddtrace.appsec._iast._taint_tracking._taint_objects import taint_pyobject
1253+
11351254
try:
11361255
from langchain_core.agents import AgentAction
11371256
from langchain_core.agents import AgentFinish
@@ -1141,12 +1260,17 @@ def taint_parser_output(func, instance, args, kwargs):
11411260
ranges = get_tainted_ranges(args[0])
11421261
if ranges:
11431262
source = ranges[0].source
1263+
print("WILL TAINT")
11441264
if isinstance(result, AgentAction):
1265+
print("TAINTED TOOL INPUT")
11451266
result.tool_input = taint_pyobject(result.tool_input, source.name, source.value, source.origin)
11461267
elif isinstance(result, AgentFinish) and "output" in result.return_values:
1268+
print("TAINTED OUTPUT")
11471269
values = result.return_values
11481270
values["output"] = taint_pyobject(values["output"], source.name, source.value, source.origin)
11491271
except Exception as e:
1272+
from ddtrace.appsec._iast._metrics import _set_iast_error_metric
1273+
11501274
_set_iast_error_metric("IAST propagation error. langchain taint_parser_output. {}".format(e))
11511275

11521276
return result

hatch.toml

+41
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,47 @@ fastapi = ["==0.94.1"]
601601
python = ["3.8", "3.10", "3.13"]
602602
fastapi = ["~=0.114.2"]
603603

604+
## ASM appsec_integrations_langchain
605+
606+
[envs.appsec_integrations_langchain]
607+
template = "appsec_integrations_langchain"
608+
dependencies = [
609+
"pytest",
610+
"pytest-asyncio",
611+
"pytest-cov",
612+
"hypothesis",
613+
"langchain{matrix:langchain:}",
614+
"langchain-experimental{matrix:langchain-experimental:}",
615+
]
616+
617+
[envs.appsec_integrations_langchain.env-vars]
618+
DD_TRACE_AGENT_URL = "http://testagent:9126"
619+
_DD_IAST_PATCH_MODULES = "benchmarks.,tests.appsec."
620+
DD_IAST_REQUEST_SAMPLING = "100"
621+
DD_IAST_DEDUPLICATION_ENABLED = "false"
622+
623+
[envs.appsec_integrations_langchain.scripts]
624+
test = [
625+
"uname -a",
626+
"pip freeze",
627+
"python -m pytest -vvv {args:tests/appsec/integrations/langchain_tests/}",
628+
]
629+
630+
[[envs.appsec_integrations_langchain.matrix]]
631+
python = ["3.9", "3.10", "3.11", "3.12", "3.13"]
632+
langchain = ["~=0.1"]
633+
langchain-experimental = ["~=0.1"]
634+
635+
[[envs.appsec_integrations_langchain.matrix]]
636+
python = ["3.9", "3.10", "3.11", "3.12", "3.13"]
637+
langchain = ["~=0.2"]
638+
langchain-experimental = ["~=0.2"]
639+
640+
[[envs.appsec_integrations_langchain.matrix]]
641+
python = ["3.9", "3.10", "3.11", "3.12", "3.13"]
642+
langchain = ["~=0.3"]
643+
langchain-experimental = ["~=0.3"]
644+
604645
## ASM FastAPI
605646

606647
[envs.appsec_threats_fastapi]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- |
4+
Code Security: IAST support for langchain v0.1.0 and above.

tests/appsec/iast/conftest.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def _start_iast_context_and_oce(span=None):
5252
if oce.acquire_request(span):
5353
start_iast_context()
5454
request_iast_enabled = True
55+
5556
set_iast_request_enabled(request_iast_enabled)
5657

5758

@@ -61,13 +62,6 @@ def _end_iast_context_and_oce(span=None):
6162

6263

6364
def iast_context(env, request_sampling=100.0, deduplication=False, asm_enabled=False):
64-
try:
65-
from ddtrace.contrib.internal.langchain.patch import patch as langchain_patch
66-
from ddtrace.contrib.internal.langchain.patch import unpatch as langchain_unpatch
67-
except Exception:
68-
langchain_patch = lambda: True # noqa: E731
69-
langchain_unpatch = lambda: True # noqa: E731
70-
7165
class MockSpan:
7266
_trace_id_64bits = 17577308072598193742
7367

@@ -87,7 +81,6 @@ class MockSpan:
8781
cmdi_patch()
8882
header_injection_patch()
8983
code_injection_patch()
90-
langchain_patch()
9184
patch_common_modules()
9285
yield
9386
unpatch_common_modules()
@@ -97,7 +90,6 @@ class MockSpan:
9790
cmdi_unpatch()
9891
header_injection_unpatch()
9992
code_injection_unpatch()
100-
langchain_unpatch()
10193
_end_iast_context_and_oce()
10294

10395

tests/appsec/iast/iast_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def get_line_and_hash(label: Text, vuln_type: Text, filename=None, fixed_line=No
5353
def _iast_patched_module_and_patched_source(module_name, new_module_object=False):
5454
module = importlib.import_module(module_name)
5555
module_path, patched_source = astpatch_module(module)
56+
assert patched_source is not None
5657
compiled_code = compile(patched_source, module_path, "exec")
5758
module_changed = types.ModuleType(module_name) if new_module_object else module
5859
exec(compiled_code, module_changed.__dict__)

tests/appsec/integrations/fixtures/patch_langchain.py

-13
This file was deleted.

0 commit comments

Comments
 (0)