From 74cc57979cc9a0ada35629116cc1984a9247afc1 Mon Sep 17 00:00:00 2001 From: saichandrapandraju Date: Mon, 14 Apr 2025 21:51:53 -0400 Subject: [PATCH 1/7] Introduce RemoteLLMAttribution to support LLMAttribution for remotely hosted models that provide logprobs (like vLLM) --- captum/attr/__init__.py | 5 + captum/attr/_core/llm_attr.py | 117 +++++++++ captum/attr/_core/remote_provider.py | 93 +++++++ setup.py | 4 + tests/attr/test_llm_attr.py | 364 ++++++++++++++++++++++++++- 5 files changed, 581 insertions(+), 2 deletions(-) create mode 100644 captum/attr/_core/remote_provider.py mode change 100755 => 100644 setup.py diff --git a/captum/attr/__init__.py b/captum/attr/__init__.py index a33cd862dd..ee006bbe2d 100644 --- a/captum/attr/__init__.py +++ b/captum/attr/__init__.py @@ -27,7 +27,9 @@ LLMAttribution, LLMAttributionResult, LLMGradientAttribution, + RemoteLLMAttribution, ) +from captum.attr._core.remote_provider import RemoteLLMProvider, VLLMProvider from captum.attr._core.lrp import LRP from captum.attr._core.neuron.neuron_conductance import NeuronConductance from captum.attr._core.neuron.neuron_deep_lift import NeuronDeepLift, NeuronDeepLiftShap @@ -111,6 +113,9 @@ "LLMAttribution", "LLMAttributionResult", "LLMGradientAttribution", + "RemoteLLMAttribution", + "RemoteLLMProvider", + "VLLMProvider", "InternalInfluence", "InterpretableInput", "LayerGradCam", diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index 3466ad4996..772b06838f 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -35,6 +35,7 @@ TextTokenInput, ) from torch import nn, Tensor +from captum.attr._core.remote_provider import RemoteLLMProvider DEFAULT_GEN_ARGS: Dict[str, Any] = { "max_new_tokens": 25, @@ -892,3 +893,119 @@ def forward( # the attribution target is limited to the log probability return token_log_probs + + +class RemoteLLMAttribution(LLMAttribution): + """ + Attribution class for large language models that are hosted remotely and offer logprob APIs. + """ + def __init__( + self, + attr_method: PerturbationAttribution, + tokenizer: TokenizerLike, + provider: RemoteLLMProvider, + attr_target: str = "log_prob", + ) -> None: + """ + Args: + attr_method: Instance of a supported perturbation attribution class + tokenizer (Tokenizer): tokenizer of the llm model used in the attr_method + provider: Remote LLM provider that implements the RemoteLLMProvider protocol + attr_target: attribute towards log probability or probability. + Available values ["log_prob", "prob"] + Default: "log_prob" + """ + super().__init__( + attr_method=attr_method, + tokenizer=tokenizer, + attr_target=attr_target, + ) + + self.provider = provider + self.attr_method.forward_func = self._remote_forward_func + + def _get_target_tokens( + self, + inp: InterpretableInput, + target: Union[str, torch.Tensor, None] = None, + skip_tokens: Union[List[int], List[str], None] = None, + gen_args: Optional[Dict[str, Any]] = None + ) -> Tensor: + """ + Get the target tokens for the remote LLM provider. + """ + assert isinstance( + inp, self.SUPPORTED_INPUTS + ), f"RemoteLLMAttribution does not support input type {type(inp)}" + + if target is None: + # generate when None with remote provider + assert hasattr(self.provider, "generate") and callable(self.provider.generate), ( + "The provider does not have generate function for generating target sequence." 
+ "Target must be given for attribution" + ) + if not gen_args: + gen_args = DEFAULT_GEN_ARGS + + model_inp = self._format_model_input(inp.to_model_input()) + target_str = self.provider.generate(model_inp, **gen_args) + target_tokens = self.tokenizer.encode(target_str, return_tensors="pt", add_special_tokens=False)[0] + + else: + target_tokens = super()._get_target_tokens(inp, target, skip_tokens, gen_args) + + return target_tokens + + def _format_model_input(self, model_input: Union[str, Tensor]) -> str: + """ + Format the model input for the remote LLM provider. + """ + # return str input + if isinstance(model_input, Tensor): + return self.tokenizer.decode(model_input.flatten()) + return model_input + + def _remote_forward_func( + self, + perturbed_tensor: Union[None, Tensor], + inp: InterpretableInput, + target_tokens: Tensor, + use_cached_outputs: bool = False, + _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None, + ) -> Tensor: + """ + Forward function for the remote LLM provider. + """ + + perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor)) + + target_str:str = self.tokenizer.decode(target_tokens) + + target_token_probs = self.provider.get_logprobs(input_prompt=perturbed_input, target_str=target_str, tokenizer=self.tokenizer) + + assert len(target_token_probs) == target_tokens.size()[0], ( + f"Number of token logprobs from provider ({len(target_token_probs)}) " + f"does not match expected target token length ({target_tokens.size()[0]})" + ) + + log_prob_list: List[Tensor] = list(map(torch.tensor, target_token_probs)) + + total_log_prob = torch.sum(torch.stack(log_prob_list), dim=0) + # 1st element is the total prob, rest are the target tokens + # add a leading dim for batch even we only support single instance for now + if self.include_per_token_attr: + target_log_probs = torch.stack( + [total_log_prob, *log_prob_list], dim=0 + ).unsqueeze(0) + else: + target_log_probs = total_log_prob + target_probs = torch.exp(target_log_probs) + + if _inspect_forward: + prompt = perturbed_input + response = self.tokenizer.decode(target_tokens) + + # callback for externals to inspect (prompt, response, seq_prob) + _inspect_forward(prompt, response, target_probs[0].tolist()) + + return target_probs if self.attr_target != "log_prob" else target_log_probs \ No newline at end of file diff --git a/captum/attr/_core/remote_provider.py b/captum/attr/_core/remote_provider.py new file mode 100644 index 0000000000..9d5a0461a9 --- /dev/null +++ b/captum/attr/_core/remote_provider.py @@ -0,0 +1,93 @@ +from abc import ABC, abstractmethod +from typing import Any, List, Optional +from captum._utils.typing import TokenizerLike +from openai import OpenAI +import os + +class RemoteLLMProvider(ABC): + """All remote LLM providers that offer logprob via API (like vLLM) extends this class.""" + + api_url: str + + @abstractmethod + def generate( + self, + prompt: str, + **gen_args: Any + ) -> str: + """ + Args: + prompt: The input prompt to generate from + gen_args: Additional generation arguments + + Returns: + The generated text. + """ + ... + + @abstractmethod + def get_logprobs( + self, + input_prompt: str, + target_str: str, + tokenizer: Optional[TokenizerLike] = None + ) -> List[float]: + """ + Get the log probabilities for all tokens in the target string. + + Args: + input_prompt: The input prompt + target_str: The target string + tokenizer: The tokenizer to use + + Returns: + A list of log probabilities corresponding to each token in the target prompt. 
+ For a `target_str` of `t` tokens, this method returns a list of logprobs of length `k`. + """ + ... + +class VLLMProvider(RemoteLLMProvider): + def __init__(self, api_url: str): + assert api_url.strip() != "", "API URL is required" + + self.api_url = api_url + self.client = OpenAI(base_url=self.api_url, + api_key=os.getenv("OPENAI_API_KEY", "EMPTY") + ) + self.model_name = self.client.models.list().data[0].id + + + def generate(self, prompt: str, **gen_args: Any) -> str: + if not 'max_tokens' in gen_args: + gen_args['max_tokens'] = gen_args.pop('max_new_tokens', 25) + if 'do_sample' in gen_args: + gen_args.pop('do_sample') + + response = self.client.completions.create( + model=self.model_name, + prompt=prompt, + **gen_args + ) + + return response.choices[0].text + + def get_logprobs(self, input_prompt: str, target_str: str, tokenizer: Optional[TokenizerLike] = None) -> List[float]: + assert tokenizer is not None, "Tokenizer is required for VLLM provider" + + num_target_str_tokens = len(tokenizer.encode(target_str, add_special_tokens=False)) + + prompt = input_prompt + target_str + + response = self.client.completions.create( + model=self.model_name, + prompt=prompt, + temperature=0.0, + max_tokens=1, + extra_body={"prompt_logprobs": 0} + ) + prompt_logprobs = [] + for probs in response.choices[0].prompt_logprobs[1:]: + prompt_logprobs.append(list(probs.values())[0]['logprob']) + + return prompt_logprobs[-num_target_str_tokens:] + \ No newline at end of file diff --git a/setup.py b/setup.py old mode 100755 new mode 100644 index 38cb97d5b3..2f473b5327 --- a/setup.py +++ b/setup.py @@ -63,9 +63,12 @@ def report(*args): TEST_REQUIRES = ["pytest", "pytest-cov", "parameterized", "flask", "flask-compress"] +REMOTE_REQUIRES = ["openai"] + DEV_REQUIRES = ( INSIGHTS_REQUIRES + TEST_REQUIRES + + REMOTE_REQUIRES + [ "black", "flake8", @@ -169,6 +172,7 @@ def get_package_files(root, subdirs): "insights": INSIGHTS_REQUIRES, "test": TEST_REQUIRES, "tutorials": TUTORIALS_REQUIRES, + "remote": REMOTE_REQUIRES, }, package_data={"captum": package_files}, data_files=[ diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py index d6f1a2a4ea..1f557e8e25 100644 --- a/tests/attr/test_llm_attr.py +++ b/tests/attr/test_llm_attr.py @@ -21,14 +21,15 @@ import torch from captum._utils.models.linear_model import SkLearnLasso -from captum._utils.typing import BatchEncodingType +from captum._utils.typing import BatchEncodingType, TokenizerLike from captum.attr._core.feature_ablation import FeatureAblation from captum.attr._core.kernel_shap import KernelShap from captum.attr._core.layer.layer_gradient_shap import LayerGradientShap from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients from captum.attr._core.lime import Lime -from captum.attr._core.llm_attr import LLMAttribution, LLMGradientAttribution +from captum.attr._core.llm_attr import LLMAttribution, LLMGradientAttribution, RemoteLLMAttribution +from captum.attr._core.remote_provider import RemoteLLMProvider from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling from captum.attr._utils.attribution import GradientAttribution, PerturbationAttribution from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput @@ -669,3 +670,362 @@ def test_llm_attr_with_skip_tensor_target(self) -> None: self.assertEqual(token_attr.shape, (5, 4)) self.assertEqual(res.input_tokens, ["", "a", "b", 
"c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) + +class DummyRemoteLLMProvider(RemoteLLMProvider): + def __init__(self, deterministic_logprobs: bool = False) -> None: + self.api_url = "https://test-api.com" + self.deterministic_logprobs = deterministic_logprobs + + def generate(self, prompt: str, **gen_args: Any) -> str: + assert "mock_response" in gen_args, "must mock response to use DummyRemoteLLMProvider to generate" + return gen_args["mock_response"] + + def get_logprobs(self, input_prompt: str, target_str: str, tokenizer: Optional[TokenizerLike] = None) -> List[float]: + assert tokenizer is not None, "Tokenizer is required" + prompt = input_prompt + target_str + tokens = tokenizer.encode(prompt, add_special_tokens=False) + num_tokens = len(tokens) + + num_target_str_tokens = len(tokenizer.encode(target_str, add_special_tokens=False)) + + logprobs = [] + + for i in range(num_tokens): + # Start with a base value + logprob = -0.1 - (0.01 * i) + + # Make sensitive to key features + if "a" not in prompt: + logprob -= 0.1 + if "c" not in prompt: + logprob -= 0.2 + if "d" not in prompt: + logprob -= 0.3 + if "f" not in prompt: + logprob -= 0.4 + + logprobs.append(logprob) + + return logprobs[-num_target_str_tokens:] + +@parameterized_class( + ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)] +) +class TestRemoteLLMAttr(BaseTest): + # pyre-fixme[13]: Attribute `device` is never initialized. + device: str + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @parameterized.expand( + [ + ( + AttrClass, + delta, + n_samples, + torch.tensor(true_seq_attr), + torch.tensor(true_tok_attr), + ) + for AttrClass, delta, n_samples, true_seq_attr, true_tok_attr in zip( + (FeatureAblation, ShapleyValueSampling, ShapleyValues), # AttrClass + (0.001, 0.001, 0.001), # delta + (None, 1000, None), # n_samples + ( # true_seq_attr + [0.5, 1.0, 1.5, 2.0], # FeatureAblation + [0.5, 1.0, 1.5, 2.0], # ShapleyValueSampling + [0.5, 1.0, 1.5, 2.0], # ShapleyValues + ), + ( # true_tok_attr + [ # FeatureAblation + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + ], + [ # ShapleyValueSampling + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + ], + [ # ShapleyValues + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + ], + ), + ) + ] + ) + def test_remote_llm_attr( + self, + AttrClass: Type[PerturbationAttribution], + delta: float, + n_samples: Optional[int], + true_seq_attr: Tensor, + true_tok_attr: Tensor, + ) -> None: + attr_kws: Dict[str, int] = {} + if n_samples is not None: + attr_kws["n_samples"] = n_samples + + # In remote mode, we don't need the actual model, this is just a placeholder + placeholder_model = torch.nn.Module() + placeholder_model.device = self.device + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider(deterministic_logprobs=True) + attr_method = AttrClass(placeholder_model) + remote_llm_attr = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + ) + + # from TestLLMAttr + inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"]) + res = remote_llm_attr.attribute( + inp, + "m n o p q", + skip_tokens=[0], + # use_cached_outputs=self.use_cached_outputs, + # pyre-fixme[6]: In call `LLMAttribution.attribute`, + # for 4th positional argument, expected + # 
`Optional[typing.Callable[..., typing.Any]]` but got `int`. + **attr_kws, # type: ignore + ) + + self.assertEqual(res.seq_attr.shape, (4,)) + self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4)) + self.assertEqual(res.input_tokens, ["a", "c", "d", "f"]) + self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) + self.assertEqual(res.seq_attr.device.type, self.device) + self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device) + + assertTensorAlmostEqual( + self, + actual=res.seq_attr, + expected=true_seq_attr, + delta=delta, + mode="max", + ) + assertTensorAlmostEqual( + self, + actual=res.token_attr, + expected=true_tok_attr, + delta=delta, + mode="max", + ) + + def test_remote_llm_attr_without_target(self) -> None: + # In remote mode, we don't need the actual model, this is just a placeholder + placeholder_model = torch.nn.Module() + placeholder_model.device = self.device + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider(deterministic_logprobs=True) + attr_method = FeatureAblation(placeholder_model) + remote_llm_attr = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + ) + + # from TestLLMAttr + inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"]) + res = remote_llm_attr.attribute( + inp, + gen_args={"mock_response": "x y z"}, + # use_cached_outputs=self.use_cached_outputs, + ) + + self.assertEqual(res.seq_attr.shape, (4,)) + self.assertEqual(cast(Tensor, res.token_attr).shape, (3, 4)) + self.assertEqual(res.input_tokens, ["a", "c", "d", "f"]) + self.assertEqual(res.output_tokens, ["x", "y", "z"]) + self.assertEqual(res.seq_attr.device.type, self.device) + self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device) + + def test_remote_llm_attr_fa_log_prob(self) -> None: + # In remote mode, we don't need the actual model, this is just a placeholder + placeholder_model = torch.nn.Module() + placeholder_model.device = self.device + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider(deterministic_logprobs=True) + attr_method = FeatureAblation(placeholder_model) + remote_llm_attr = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + attr_target="log_prob", + ) + + # from TestLLMAttr + inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"]) + res = remote_llm_attr.attribute( + inp, + "m n o p q", + skip_tokens=[0], + # use_cached_outputs=self.use_cached_outputs, + ) + + # With FeatureAblation, the seq attr in log_prob + # equals to the sum of each token attr + assertTensorAlmostEqual(self, res.seq_attr, cast(Tensor, res.token_attr).sum(0)) + + # pyre-fixme[56]: Pyre was not able to infer the type of argument + @parameterized.expand( + [ + ( + AttrClass, + delta, + n_samples, + torch.tensor(true_seq_attr), + interpretable_model, + ) + for AttrClass, delta, n_samples, true_seq_attr, interpretable_model in zip( + (Lime, KernelShap), + (0.003, 0.001), + (1000, 2500), + ( + [0.4956, 0.9957, 1.4959, 1.9959], + [0.5, 1.0, 1.5, 2.0], + ), + (SkLearnLasso(alpha=0.001), None), + ) + ] + ) + def test_remote_llm_attr_without_token( + self, + AttrClass: Type[PerturbationAttribution], + delta: float, + n_samples: int, + true_seq_attr: Tensor, + interpretable_model: Optional[nn.Module] = None, + ) -> None: + init_kws = {} + if interpretable_model is not None: + init_kws["interpretable_model"] = interpretable_model + attr_kws: Dict[str, int] = {} + if n_samples is not None: + attr_kws["n_samples"] = n_samples + + # In remote 
mode, we don't need the actual model, this is just a placeholder + placeholder_model = torch.nn.Module() + placeholder_model.device = self.device + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider(deterministic_logprobs=True) + attr_method = AttrClass(placeholder_model, **init_kws) + remote_llm_attr = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + attr_target="log_prob", + ) + + inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"]) + res = remote_llm_attr.attribute( + inp, + "m n o p q", + skip_tokens=[0], + # use_cached_outputs=self.use_cached_outputs, + **attr_kws, # type: ignore + ) + + self.assertEqual(res.seq_attr.shape, (4,)) + self.assertEqual(res.seq_attr.device.type, self.device) + self.assertEqual(res.token_attr, None) + self.assertEqual(res.input_tokens, ["a", "c", "d", "f"]) + self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) + assertTensorAlmostEqual( + self, + actual=res.seq_attr, + expected=true_seq_attr, + delta=delta, + mode="max", + ) + def test_remote_llm_attr_futures_not_implemented(self) -> None: + # In remote mode, we don't need the actual model, this is just a placeholder + placeholder_model = torch.nn.Module() + placeholder_model.device = self.device + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider() + attr_method = FeatureAblation(placeholder_model) + remote_llm_attr = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + ) + + # from TestLLMAttr + attributions = None + with self.assertRaises(NotImplementedError): + attributions = remote_llm_attr.attribute_future() + self.assertEqual(attributions, None) + + def test_remote_llm_attr_with_no_skip_tokens(self) -> None: + # In remote mode, we don't need the actual model, this is just a placeholder + placeholder_model = torch.nn.Module() + placeholder_model.device = self.device + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider(deterministic_logprobs=True) + attr_method = FeatureAblation(placeholder_model) + remote_llm_fa = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + ) + + # from TestLLMAttr + inp = TextTokenInput("a b c", tokenizer) + res = remote_llm_fa.attribute( + inp, + "m n o p q" + ) + + # 5 output tokens, 4 input tokens including sos + self.assertEqual(res.seq_attr.shape, (4,)) + assert res.token_attr is not None + self.assertIsNotNone(res.token_attr) + token_attr = res.token_attr + self.assertEqual(token_attr.shape, (6, 4)) + self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) + self.assertEqual(res.output_tokens, ["", "m", "n", "o", "p", "q"]) + + def test_remote_llm_attr_with_skip_tensor_target(self) -> None: + # In remote mode, we don't need the actual model, this is just a placeholder + placeholder_model = torch.nn.Module() + placeholder_model.device = self.device + + tokenizer = DummyTokenizer() + provider = DummyRemoteLLMProvider(deterministic_logprobs=True) + attr_method = FeatureAblation(placeholder_model) + remote_llm_fa = RemoteLLMAttribution( + attr_method=attr_method, + tokenizer=tokenizer, + provider=provider, + ) + + # from TestLLMAttr + inp = TextTokenInput("a b c", tokenizer) + res = remote_llm_fa.attribute( + inp, + torch.tensor(tokenizer.encode("m n o p q")), + skip_tokens=[0], + ) + + # 5 output tokens, 4 input tokens including sos + self.assertEqual(res.seq_attr.shape, (4,)) + assert res.token_attr is not None + self.assertIsNotNone(res.token_attr) + token_attr = 
res.token_attr + self.assertEqual(token_attr.shape, (5, 4)) + self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) + self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) From e6c929bd160e5685241614767d15b6d37dd161a0 Mon Sep 17 00:00:00 2001 From: saichandrapandraju Date: Thu, 24 Apr 2025 18:55:16 -0400 Subject: [PATCH 2/7] add optional 'model_name' for VLLMProvider and add better exception handling --- captum/attr/_core/llm_attr.py | 13 ++- captum/attr/_core/remote_provider.py | 152 ++++++++++++++++++++++----- 2 files changed, 133 insertions(+), 32 deletions(-) diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index 772b06838f..202c2bbdc9 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -975,18 +975,21 @@ def _remote_forward_func( ) -> Tensor: """ Forward function for the remote LLM provider. + + Raises: + ValueError: If the number of token logprobs doesn't match expected length """ - perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor)) target_str:str = self.tokenizer.decode(target_tokens) target_token_probs = self.provider.get_logprobs(input_prompt=perturbed_input, target_str=target_str, tokenizer=self.tokenizer) - assert len(target_token_probs) == target_tokens.size()[0], ( - f"Number of token logprobs from provider ({len(target_token_probs)}) " - f"does not match expected target token length ({target_tokens.size()[0]})" - ) + if len(target_token_probs) != target_tokens.size()[0]: + raise ValueError( + f"Number of token logprobs from provider ({len(target_token_probs)}) " + f"does not match expected target token length ({target_tokens.size()[0]})" + ) log_prob_list: List[Tensor] = list(map(torch.tensor, target_token_probs)) diff --git a/captum/attr/_core/remote_provider.py b/captum/attr/_core/remote_provider.py index 9d5a0461a9..149337b962 100644 --- a/captum/attr/_core/remote_provider.py +++ b/captum/attr/_core/remote_provider.py @@ -47,47 +47,145 @@ def get_logprobs( ... class VLLMProvider(RemoteLLMProvider): - def __init__(self, api_url: str): - assert api_url.strip() != "", "API URL is required" + def __init__(self, api_url: str, model_name: Optional[str] = None): + """ + Initialize a vLLM provider. + + Args: + api_url: The URL of the vLLM API + model_name: The name of the model to use. If None, the first model from + the API's model list will be used. + Raises: + ValueError: If api_url is empty or model_name is not in the API's model list + ConnectionError: If API connection fails + """ + if not api_url.strip(): + raise ValueError("API URL is required") + self.api_url = api_url - self.client = OpenAI(base_url=self.api_url, + + try: + self.client = OpenAI(base_url=self.api_url, api_key=os.getenv("OPENAI_API_KEY", "EMPTY") ) - self.model_name = self.client.models.list().data[0].id - + + # If model_name is not provided, get the first available model from the API + if model_name is None: + models = self.client.models.list().data + if not models: + raise ValueError("No models available from the vLLM API") + self.model_name = models[0].id + else: + self.model_name = model_name + + except ConnectionError as e: + raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}") + except Exception as e: + raise Exception(f"Unexpected error while initializing vLLM provider: {str(e)}") def generate(self, prompt: str, **gen_args: Any) -> str: - if not 'max_tokens' in gen_args: + """ + Generate text using the vLLM API. 
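+        The request goes through the OpenAI-compatible completions endpoint that
+        vLLM exposes; e.g. (hypothetical prompt), generate("The capital of France is",
+        max_new_tokens=5) returns only the completion text, not the echoed prompt.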
+ + Args: + prompt: The input prompt for text generation + **gen_args: Additional generation arguments + + Returns: + str: The generated text + + Raises: + KeyError: If API response is missing expected data + ConnectionError: If connection to API fails + """ + # Parameter normalization + if 'max_tokens' not in gen_args: gen_args['max_tokens'] = gen_args.pop('max_new_tokens', 25) if 'do_sample' in gen_args: gen_args.pop('do_sample') + + try: + response = self.client.completions.create( + model=self.model_name, + prompt=prompt, + **gen_args + ) + if not hasattr(response, 'choices') or not response.choices: + raise KeyError("API response missing expected 'choices' data") + + return response.choices[0].text - response = self.client.completions.create( - model=self.model_name, - prompt=prompt, - **gen_args - ) - - return response.choices[0].text + except ConnectionError as e: + raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}") + except Exception as e: + raise Exception(f"Unexpected error during text generation: {str(e)}") - def get_logprobs(self, input_prompt: str, target_str: str, tokenizer: Optional[TokenizerLike] = None) -> List[float]: - assert tokenizer is not None, "Tokenizer is required for VLLM provider" + def get_logprobs( + self, + input_prompt: str, + target_str: str, + tokenizer: Optional[TokenizerLike] = None + ) -> List[float]: + """ + Get the log probabilities for all tokens in the target string. + + Args: + input_prompt: The input prompt + target_str: The target string + tokenizer: The tokenizer to use + + Returns: + A list of log probabilities corresponding to each token in the target prompt. + For a `target_str` of `t` tokens, this method returns a list of logprobs of length `k`. + + Raises: + ValueError: If tokenizer is None or target_str is empty or response format is invalid + KeyError: If API response is missing expected data + IndexError: If response format is unexpected + ConnectionError: If connection to API fails + """ + if tokenizer is None: + raise ValueError("Tokenizer is required for vLLM provider") + if not target_str: + raise ValueError("Target string cannot be empty") num_target_str_tokens = len(tokenizer.encode(target_str, add_special_tokens=False)) prompt = input_prompt + target_str + + try: + response = self.client.completions.create( + model=self.model_name, + prompt=prompt, + temperature=0.0, + max_tokens=1, + extra_body={"prompt_logprobs": 0} + ) + + if not hasattr(response, 'choices') or not response.choices: + raise KeyError("API response missing expected 'choices' data") + + if not hasattr(response.choices[0], 'prompt_logprobs'): + raise KeyError("API response missing 'prompt_logprobs' data") + + prompt_logprobs = [] + try: + for probs in response.choices[0].prompt_logprobs[1:]: + if not probs: + raise ValueError("Empty probability data in API response") + prompt_logprobs.append(list(probs.values())[0]['logprob']) + except (IndexError, KeyError) as e: + raise IndexError(f"Unexpected format in log probability data: {str(e)}") + + if len(prompt_logprobs) < num_target_str_tokens: + raise ValueError(f"Not enough logprobs received: expected {num_target_str_tokens}, got {len(prompt_logprobs)}") + + return prompt_logprobs[-num_target_str_tokens:] - response = self.client.completions.create( - model=self.model_name, - prompt=prompt, - temperature=0.0, - max_tokens=1, - extra_body={"prompt_logprobs": 0} - ) - prompt_logprobs = [] - for probs in response.choices[0].prompt_logprobs[1:]: - prompt_logprobs.append(list(probs.values())[0]['logprob']) - 
- return prompt_logprobs[-num_target_str_tokens:] + except ConnectionError as e: + raise ConnectionError(f"Failed to connect to vLLM API when getting logprobs: {str(e)}") + except Exception as e: + raise Exception(f"Unexpected error while getting log probabilities: {str(e)}") + \ No newline at end of file From 51826d02fe9df1fcb91ca62e291dfd088ee404c1 Mon Sep 17 00:00:00 2001 From: saichandrapandraju Date: Tue, 13 May 2025 09:24:38 -0700 Subject: [PATCH 3/7] create placeholder class attribute, remote provider improvements, and formatting --- captum/attr/__init__.py | 2 +- captum/attr/_core/llm_attr.py | 92 +++++++++----- captum/attr/_core/remote_provider.py | 182 ++++++++++++++++----------- tests/attr/test_llm_attr.py | 128 +++++++++---------- 4 files changed, 234 insertions(+), 170 deletions(-) diff --git a/captum/attr/__init__.py b/captum/attr/__init__.py index ee006bbe2d..e0d9b5bd41 100644 --- a/captum/attr/__init__.py +++ b/captum/attr/__init__.py @@ -29,7 +29,6 @@ LLMGradientAttribution, RemoteLLMAttribution, ) -from captum.attr._core.remote_provider import RemoteLLMProvider, VLLMProvider from captum.attr._core.lrp import LRP from captum.attr._core.neuron.neuron_conductance import NeuronConductance from captum.attr._core.neuron.neuron_deep_lift import NeuronDeepLift, NeuronDeepLiftShap @@ -45,6 +44,7 @@ ) from captum.attr._core.noise_tunnel import NoiseTunnel from captum.attr._core.occlusion import Occlusion +from captum.attr._core.remote_provider import RemoteLLMProvider, VLLMProvider from captum.attr._core.saliency import Saliency from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling from captum.attr._models.base import ( diff --git a/captum/attr/_core/llm_attr.py b/captum/attr/_core/llm_attr.py index 202c2bbdc9..d3a4b4e363 100644 --- a/captum/attr/_core/llm_attr.py +++ b/captum/attr/_core/llm_attr.py @@ -23,6 +23,7 @@ from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients from captum.attr._core.lime import Lime +from captum.attr._core.remote_provider import RemoteLLMProvider from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling from captum.attr._utils.attribution import ( Attribution, @@ -35,7 +36,6 @@ TextTokenInput, ) from torch import nn, Tensor -from captum.attr._core.remote_provider import RemoteLLMProvider DEFAULT_GEN_ARGS: Dict[str, Any] = { "max_new_tokens": 25, @@ -895,10 +895,29 @@ def forward( return token_log_probs +class _PlaceholderModel: + """ + Simple placeholder model that can be used with + RemoteLLMAttribution without needing a real model. + This can be acheived by `lambda *_:0` but BaseLLMAttribution expects + `device`, so creating this class to set the device. + """ + + def __init__(self): + self.device = torch.device("cpu") + + def __call__(self, *args, **kwargs): + return 0 + + class RemoteLLMAttribution(LLMAttribution): """ - Attribution class for large language models that are hosted remotely and offer logprob APIs. + Attribution class for large language models + that are hosted remotely and offer logprob APIs. 
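+
+    A minimal usage sketch (the endpoint URL, tokenizer, and inputs below are
+    illustrative assumptions, e.g. a locally running vLLM server and a matching
+    HuggingFace tokenizer):
+
+        provider = VLLMProvider(api_url="http://localhost:8000/v1")
+        fa = FeatureAblation(RemoteLLMAttribution.placeholder_model)
+        llm_attr = RemoteLLMAttribution(fa, tokenizer=tokenizer, provider=provider)
+        inp = TextTemplateInput("{} b {} {} e {}", ["a", "c", "d", "f"])
+        res = llm_attr.attribute(inp, target="m n o p q", skip_tokens=[0])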
""" + + placeholder_model = _PlaceholderModel() + def __init__( self, attr_method: PerturbationAttribution, @@ -920,7 +939,7 @@ def __init__( tokenizer=tokenizer, attr_target=attr_target, ) - + self.provider = provider self.attr_method.forward_func = self._remote_forward_func @@ -929,42 +948,52 @@ def _get_target_tokens( inp: InterpretableInput, target: Union[str, torch.Tensor, None] = None, skip_tokens: Union[List[int], List[str], None] = None, - gen_args: Optional[Dict[str, Any]] = None - ) -> Tensor: + gen_args: Optional[Dict[str, Any]] = None, + ) -> Tensor: """ Get the target tokens for the remote LLM provider. """ assert isinstance( inp, self.SUPPORTED_INPUTS ), f"RemoteLLMAttribution does not support input type {type(inp)}" - + if target is None: # generate when None with remote provider - assert hasattr(self.provider, "generate") and callable(self.provider.generate), ( - "The provider does not have generate function for generating target sequence." + assert hasattr(self.provider, "generate") and callable( + self.provider.generate + ), ( + "The provider does not have generate function" + " for generating target sequence." "Target must be given for attribution" ) if not gen_args: gen_args = DEFAULT_GEN_ARGS - - model_inp = self._format_model_input(inp.to_model_input()) + + model_inp = self._format_remote_model_input(inp.to_model_input()) target_str = self.provider.generate(model_inp, **gen_args) - target_tokens = self.tokenizer.encode(target_str, return_tensors="pt", add_special_tokens=False)[0] - + target_tokens = self.tokenizer.encode( + target_str, return_tensors="pt", add_special_tokens=False + )[0] + else: - target_tokens = super()._get_target_tokens(inp, target, skip_tokens, gen_args) - + target_tokens = super()._get_target_tokens( + inp, target, skip_tokens, gen_args + ) + return target_tokens - - def _format_model_input(self, model_input: Union[str, Tensor]) -> str: + + def _format_remote_model_input(self, model_input: Union[str, Tensor]) -> str: """ Format the model input for the remote LLM provider. + Convert tokenized tensor to str + to make RemoteLLMAttribution work with model inputs of both + raw text and text token tensors """ # return str input if isinstance(model_input, Tensor): return self.tokenizer.decode(model_input.flatten()) return model_input - + def _remote_forward_func( self, perturbed_tensor: Union[None, Tensor], @@ -975,24 +1004,31 @@ def _remote_forward_func( ) -> Tensor: """ Forward function for the remote LLM provider. 
- + Raises: ValueError: If the number of token logprobs doesn't match expected length """ - perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor)) - - target_str:str = self.tokenizer.decode(target_tokens) - - target_token_probs = self.provider.get_logprobs(input_prompt=perturbed_input, target_str=target_str, tokenizer=self.tokenizer) - + perturbed_input = self._format_remote_model_input( + inp.to_model_input(perturbed_tensor) + ) + + target_str: str = self.tokenizer.decode(target_tokens) + + target_token_probs = self.provider.get_logprobs( + input_prompt=perturbed_input, + target_str=target_str, + tokenizer=self.tokenizer, + ) + if len(target_token_probs) != target_tokens.size()[0]: raise ValueError( f"Number of token logprobs from provider ({len(target_token_probs)}) " - f"does not match expected target token length ({target_tokens.size()[0]})" + f"does not match expected target " + f"token length ({target_tokens.size()[0]})" ) - + log_prob_list: List[Tensor] = list(map(torch.tensor, target_token_probs)) - + total_log_prob = torch.sum(torch.stack(log_prob_list), dim=0) # 1st element is the total prob, rest are the target tokens # add a leading dim for batch even we only support single instance for now @@ -1011,4 +1047,4 @@ def _remote_forward_func( # callback for externals to inspect (prompt, response, seq_prob) _inspect_forward(prompt, response, target_probs[0].tolist()) - return target_probs if self.attr_target != "log_prob" else target_log_probs \ No newline at end of file + return target_probs if self.attr_target != "log_prob" else target_log_probs diff --git a/captum/attr/_core/remote_provider.py b/captum/attr/_core/remote_provider.py index 149337b962..5d7a94f647 100644 --- a/captum/attr/_core/remote_provider.py +++ b/captum/attr/_core/remote_provider.py @@ -1,146 +1,175 @@ +import logging +import os from abc import ABC, abstractmethod from typing import Any, List, Optional + from captum._utils.typing import TokenizerLike -from openai import OpenAI -import os + class RemoteLLMProvider(ABC): - """All remote LLM providers that offer logprob via API (like vLLM) extends this class.""" - + """All remote LLM providers that offer logprob via API + (like vLLM) extends this class.""" + api_url: str - + @abstractmethod - def generate( - self, - prompt: str, - **gen_args: Any - ) -> str: + def generate(self, prompt: str, **gen_args: Any) -> str: """ Args: prompt: The input prompt to generate from gen_args: Additional generation arguments - + Returns: The generated text. """ ... - + @abstractmethod def get_logprobs( - self, + self, input_prompt: str, target_str: str, - tokenizer: Optional[TokenizerLike] = None + tokenizer: Optional[TokenizerLike] = None, ) -> List[float]: """ Get the log probabilities for all tokens in the target string. - + Args: input_prompt: The input prompt target_str: The target string tokenizer: The tokenizer to use - + Returns: - A list of log probabilities corresponding to each token in the target prompt. - For a `target_str` of `t` tokens, this method returns a list of logprobs of length `k`. + A list of log probabilities corresponding to + each token in the target prompt. + For a `target_str` of `t` tokens, + this method returns a list of logprobs of length `k`. """ ... + class VLLMProvider(RemoteLLMProvider): def __init__(self, api_url: str, model_name: Optional[str] = None): """ Initialize a vLLM provider. - + Args: api_url: The URL of the vLLM API - model_name: The name of the model to use. 
If None, the first model from + model_name: The name of the model to use. If None, the first model from the API's model list will be used. - + + Environment Variables: + OPENAI_API_KEY: If not set, "EMPTY" will be used as the API key. + Raises: ValueError: If api_url is empty or model_name is not in the API's model list ConnectionError: If API connection fails + ImportError: If the openai package is not installed """ + try: + from openai import OpenAI + except ImportError: + raise ImportError( + "The 'openai' package is required to use the VLLMProvider." + "You can install it by either:\n" + "1. Installing captum with remote dependencies: " + "`pip install captum[remote]` OR\n" + "2. Installing openai directly: `pip install openai`" + ) + if not api_url.strip(): raise ValueError("API URL is required") - + self.api_url = api_url try: - self.client = OpenAI(base_url=self.api_url, - api_key=os.getenv("OPENAI_API_KEY", "EMPTY") - ) - + self.client = OpenAI( + base_url=self.api_url, api_key=os.getenv("OPENAI_API_KEY", "EMPTY") + ) + # If model_name is not provided, get the first available model from the API if model_name is None: models = self.client.models.list().data if not models: raise ValueError("No models available from the vLLM API") self.model_name = models[0].id + logging.info( + f"No model_name is specified for VLLMProvider." + f" Using first available model: {self.model_name}" + ) else: self.model_name = model_name except ConnectionError as e: raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}") except Exception as e: - raise Exception(f"Unexpected error while initializing vLLM provider: {str(e)}") + raise Exception( + f"Unexpected error while initializing vLLM provider: {str(e)}" + ) def generate(self, prompt: str, **gen_args: Any) -> str: """ Generate text using the vLLM API. - + Args: prompt: The input prompt for text generation - **gen_args: Additional generation arguments - + **gen_args: Additional generation arguments. Supported arguments include: + - max_tokens: Maximum number of tokens to generate (default: 25) + - max_new_tokens: Alternative to max_tokens + (will be converted to max_tokens) + - temperature, top_p, etc.: Other generation parameters + supported by the OpenAI API + Returns: str: The generated text - + Raises: KeyError: If API response is missing expected data ConnectionError: If connection to API fails """ # Parameter normalization - if 'max_tokens' not in gen_args: - gen_args['max_tokens'] = gen_args.pop('max_new_tokens', 25) - if 'do_sample' in gen_args: - gen_args.pop('do_sample') - + if "max_tokens" not in gen_args: + gen_args["max_tokens"] = gen_args.pop("max_new_tokens", 25) + if "do_sample" in gen_args: + gen_args.pop("do_sample") + try: response = self.client.completions.create( - model=self.model_name, - prompt=prompt, - **gen_args + model=self.model_name, prompt=prompt, **gen_args ) - if not hasattr(response, 'choices') or not response.choices: + if not hasattr(response, "choices") or not response.choices: raise KeyError("API response missing expected 'choices' data") - + return response.choices[0].text except ConnectionError as e: raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}") except Exception as e: raise Exception(f"Unexpected error during text generation: {str(e)}") - + def get_logprobs( - self, + self, input_prompt: str, target_str: str, - tokenizer: Optional[TokenizerLike] = None + tokenizer: Optional[TokenizerLike] = None, ) -> List[float]: """ Get the log probabilities for all tokens in the target string. 
- + Args: input_prompt: The input prompt target_str: The target string tokenizer: The tokenizer to use - + Returns: - A list of log probabilities corresponding to each token in the target prompt. - For a `target_str` of `t` tokens, this method returns a list of logprobs of length `k`. - + A list of log probabilities corresponding to each token + in the target prompt. + For a `target_str` of `t` tokens, this method returns + a list of logprobs of length `t`. + Raises: - ValueError: If tokenizer is None or target_str is empty or response format is invalid + ValueError: If tokenizer is None or target_str is empty + or response format is invalid KeyError: If API response is missing expected data IndexError: If response format is unexpected ConnectionError: If connection to API fails @@ -149,43 +178,52 @@ def get_logprobs( raise ValueError("Tokenizer is required for vLLM provider") if not target_str: raise ValueError("Target string cannot be empty") - - num_target_str_tokens = len(tokenizer.encode(target_str, add_special_tokens=False)) - + + num_target_str_tokens = len( + tokenizer.encode(target_str, add_special_tokens=False) + ) + prompt = input_prompt + target_str - + try: response = self.client.completions.create( model=self.model_name, prompt=prompt, temperature=0.0, max_tokens=1, - extra_body={"prompt_logprobs": 0} + extra_body={"prompt_logprobs": 0}, ) - - if not hasattr(response, 'choices') or not response.choices: + + if not hasattr(response, "choices") or not response.choices: raise KeyError("API response missing expected 'choices' data") - - if not hasattr(response.choices[0], 'prompt_logprobs'): + + if not hasattr(response.choices[0], "prompt_logprobs"): raise KeyError("API response missing 'prompt_logprobs' data") - + prompt_logprobs = [] try: - for probs in response.choices[0].prompt_logprobs[1:]: + for probs in response.choices[0].prompt_logprobs[ + -num_target_str_tokens: + ]: if not probs: raise ValueError("Empty probability data in API response") - prompt_logprobs.append(list(probs.values())[0]['logprob']) + prompt_logprobs.append(next(iter(probs.values()))["logprob"]) except (IndexError, KeyError) as e: raise IndexError(f"Unexpected format in log probability data: {str(e)}") - - if len(prompt_logprobs) < num_target_str_tokens: - raise ValueError(f"Not enough logprobs received: expected {num_target_str_tokens}, got {len(prompt_logprobs)}") - - return prompt_logprobs[-num_target_str_tokens:] - + + if len(prompt_logprobs) != num_target_str_tokens: + raise ValueError( + f"Not enough logprobs received:" + f"expected {num_target_str_tokens}, got {len(prompt_logprobs)}" + ) + + return prompt_logprobs + except ConnectionError as e: - raise ConnectionError(f"Failed to connect to vLLM API when getting logprobs: {str(e)}") + raise ConnectionError( + f"Failed to connect to vLLM API when getting logprobs: {str(e)}" + ) except Exception as e: - raise Exception(f"Unexpected error while getting log probabilities: {str(e)}") - - \ No newline at end of file + raise Exception( + f"Unexpected error while getting log probabilities: {str(e)}" + ) diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py index 1f557e8e25..3cc5f0919d 100644 --- a/tests/attr/test_llm_attr.py +++ b/tests/attr/test_llm_attr.py @@ -28,7 +28,11 @@ from captum.attr._core.layer.layer_gradient_x_activation import LayerGradientXActivation from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients from captum.attr._core.lime import Lime -from captum.attr._core.llm_attr import 
LLMAttribution, LLMGradientAttribution, RemoteLLMAttribution +from captum.attr._core.llm_attr import ( + LLMAttribution, + LLMGradientAttribution, + RemoteLLMAttribution, +) from captum.attr._core.remote_provider import RemoteLLMProvider from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling from captum.attr._utils.attribution import GradientAttribution, PerturbationAttribution @@ -671,29 +675,39 @@ def test_llm_attr_with_skip_tensor_target(self) -> None: self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) + class DummyRemoteLLMProvider(RemoteLLMProvider): def __init__(self, deterministic_logprobs: bool = False) -> None: self.api_url = "https://test-api.com" self.deterministic_logprobs = deterministic_logprobs - + def generate(self, prompt: str, **gen_args: Any) -> str: - assert "mock_response" in gen_args, "must mock response to use DummyRemoteLLMProvider to generate" + assert ( + "mock_response" in gen_args + ), "must mock response to use DummyRemoteLLMProvider to generate" return gen_args["mock_response"] - def get_logprobs(self, input_prompt: str, target_str: str, tokenizer: Optional[TokenizerLike] = None) -> List[float]: + def get_logprobs( + self, + input_prompt: str, + target_str: str, + tokenizer: Optional[TokenizerLike] = None, + ) -> List[float]: assert tokenizer is not None, "Tokenizer is required" prompt = input_prompt + target_str tokens = tokenizer.encode(prompt, add_special_tokens=False) num_tokens = len(tokens) - - num_target_str_tokens = len(tokenizer.encode(target_str, add_special_tokens=False)) - + + num_target_str_tokens = len( + tokenizer.encode(target_str, add_special_tokens=False) + ) + logprobs = [] - + for i in range(num_tokens): # Start with a base value logprob = -0.1 - (0.01 * i) - + # Make sensitive to key features if "a" not in prompt: logprob -= 0.1 @@ -703,18 +717,19 @@ def get_logprobs(self, input_prompt: str, target_str: str, tokenizer: Optional[T logprob -= 0.3 if "f" not in prompt: logprob -= 0.4 - + logprobs.append(logprob) - + return logprobs[-num_target_str_tokens:] - + + @parameterized_class( ("device",), [("cpu",), ("cuda",)] if torch.cuda.is_available() else [("cpu",)] ) class TestRemoteLLMAttr(BaseTest): # pyre-fixme[13]: Attribute `device` is never initialized. 
device: str - + # pyre-fixme[56]: Pyre was not able to infer the type of argument @parameterized.expand( [ @@ -736,25 +751,25 @@ class TestRemoteLLMAttr(BaseTest): ), ( # true_tok_attr [ # FeatureAblation - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], ], [ # ShapleyValueSampling - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], ], [ # ShapleyValues - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], - [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], ], ), ) @@ -771,14 +786,10 @@ def test_remote_llm_attr( attr_kws: Dict[str, int] = {} if n_samples is not None: attr_kws["n_samples"] = n_samples - - # In remote mode, we don't need the actual model, this is just a placeholder - placeholder_model = torch.nn.Module() - placeholder_model.device = self.device tokenizer = DummyTokenizer() provider = DummyRemoteLLMProvider(deterministic_logprobs=True) - attr_method = AttrClass(placeholder_model) + attr_method = AttrClass(RemoteLLMAttribution.placeholder_model) remote_llm_attr = RemoteLLMAttribution( attr_method=attr_method, tokenizer=tokenizer, @@ -821,13 +832,10 @@ def test_remote_llm_attr( ) def test_remote_llm_attr_without_target(self) -> None: - # In remote mode, we don't need the actual model, this is just a placeholder - placeholder_model = torch.nn.Module() - placeholder_model.device = self.device - + tokenizer = DummyTokenizer() provider = DummyRemoteLLMProvider(deterministic_logprobs=True) - attr_method = FeatureAblation(placeholder_model) + attr_method = FeatureAblation(RemoteLLMAttribution.placeholder_model) remote_llm_attr = RemoteLLMAttribution( attr_method=attr_method, tokenizer=tokenizer, @@ -850,13 +858,10 @@ def test_remote_llm_attr_without_target(self) -> None: self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device) def test_remote_llm_attr_fa_log_prob(self) -> None: - # In remote mode, we don't need the actual model, this is just a placeholder - placeholder_model = torch.nn.Module() - placeholder_model.device = self.device - + tokenizer = DummyTokenizer() provider = DummyRemoteLLMProvider(deterministic_logprobs=True) - attr_method = FeatureAblation(placeholder_model) + attr_method = FeatureAblation(RemoteLLMAttribution.placeholder_model) remote_llm_attr = RemoteLLMAttribution( attr_method=attr_method, tokenizer=tokenizer, @@ -914,13 +919,9 @@ def test_remote_llm_attr_without_token( if n_samples is not None: attr_kws["n_samples"] = n_samples - # In remote mode, we don't need the actual model, this is just a placeholder - placeholder_model = torch.nn.Module() - placeholder_model.device = self.device - tokenizer = DummyTokenizer() provider = DummyRemoteLLMProvider(deterministic_logprobs=True) - attr_method = AttrClass(placeholder_model, **init_kws) + attr_method = AttrClass(RemoteLLMAttribution.placeholder_model, **init_kws) remote_llm_attr = RemoteLLMAttribution( attr_method=attr_method, tokenizer=tokenizer, @@ -949,20 +950,18 @@ def test_remote_llm_attr_without_token( delta=delta, mode="max", ) + def test_remote_llm_attr_futures_not_implemented(self) -> 
None: - # In remote mode, we don't need the actual model, this is just a placeholder - placeholder_model = torch.nn.Module() - placeholder_model.device = self.device - + tokenizer = DummyTokenizer() provider = DummyRemoteLLMProvider() - attr_method = FeatureAblation(placeholder_model) + attr_method = FeatureAblation(RemoteLLMAttribution.placeholder_model) remote_llm_attr = RemoteLLMAttribution( attr_method=attr_method, tokenizer=tokenizer, provider=provider, ) - + # from TestLLMAttr attributions = None with self.assertRaises(NotImplementedError): @@ -970,13 +969,10 @@ def test_remote_llm_attr_futures_not_implemented(self) -> None: self.assertEqual(attributions, None) def test_remote_llm_attr_with_no_skip_tokens(self) -> None: - # In remote mode, we don't need the actual model, this is just a placeholder - placeholder_model = torch.nn.Module() - placeholder_model.device = self.device - + tokenizer = DummyTokenizer() provider = DummyRemoteLLMProvider(deterministic_logprobs=True) - attr_method = FeatureAblation(placeholder_model) + attr_method = FeatureAblation(RemoteLLMAttribution.placeholder_model) remote_llm_fa = RemoteLLMAttribution( attr_method=attr_method, tokenizer=tokenizer, @@ -985,10 +981,7 @@ def test_remote_llm_attr_with_no_skip_tokens(self) -> None: # from TestLLMAttr inp = TextTokenInput("a b c", tokenizer) - res = remote_llm_fa.attribute( - inp, - "m n o p q" - ) + res = remote_llm_fa.attribute(inp, "m n o p q") # 5 output tokens, 4 input tokens including sos self.assertEqual(res.seq_attr.shape, (4,)) @@ -998,15 +991,12 @@ def test_remote_llm_attr_with_no_skip_tokens(self) -> None: self.assertEqual(token_attr.shape, (6, 4)) self.assertEqual(res.input_tokens, ["", "a", "b", "c"]) self.assertEqual(res.output_tokens, ["", "m", "n", "o", "p", "q"]) - + def test_remote_llm_attr_with_skip_tensor_target(self) -> None: - # In remote mode, we don't need the actual model, this is just a placeholder - placeholder_model = torch.nn.Module() - placeholder_model.device = self.device - + tokenizer = DummyTokenizer() provider = DummyRemoteLLMProvider(deterministic_logprobs=True) - attr_method = FeatureAblation(placeholder_model) + attr_method = FeatureAblation(RemoteLLMAttribution.placeholder_model) remote_llm_fa = RemoteLLMAttribution( attr_method=attr_method, tokenizer=tokenizer, From d2b948c0b14be1ef9593ca01ed9baa6fa5e2c3ee Mon Sep 17 00:00:00 2001 From: saichandrapandraju Date: Tue, 13 May 2025 17:51:12 -0700 Subject: [PATCH 4/7] add tests for VLLMProvider --- captum/attr/_core/remote_provider.py | 11 +- tests/attr/test_llm_attr.py | 352 ++++++++++++++++++++++++++- 2 files changed, 361 insertions(+), 2 deletions(-) diff --git a/captum/attr/_core/remote_provider.py b/captum/attr/_core/remote_provider.py index 5d7a94f647..8ec2b9aeae 100644 --- a/captum/attr/_core/remote_provider.py +++ b/captum/attr/_core/remote_provider.py @@ -100,6 +100,8 @@ def __init__(self, api_url: str, model_name: Optional[str] = None): else: self.model_name = model_name + except ValueError: + raise except ConnectionError as e: raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}") except Exception as e: @@ -142,6 +144,8 @@ def generate(self, prompt: str, **gen_args: Any) -> str: return response.choices[0].text + except KeyError: + raise except ConnectionError as e: raise ConnectionError(f"Failed to connect to vLLM API: {str(e)}") except Exception as e: @@ -197,7 +201,10 @@ def get_logprobs( if not hasattr(response, "choices") or not response.choices: raise KeyError("API response missing expected 'choices' 
data") - if not hasattr(response.choices[0], "prompt_logprobs"): + if ( + not hasattr(response.choices[0], "prompt_logprobs") + or not response.choices[0].prompt_logprobs + ): raise KeyError("API response missing 'prompt_logprobs' data") prompt_logprobs = [] @@ -219,6 +226,8 @@ def get_logprobs( return prompt_logprobs + except (KeyError, IndexError, ValueError): + raise except ConnectionError as e: raise ConnectionError( f"Failed to connect to vLLM API when getting logprobs: {str(e)}" diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py index 3cc5f0919d..411fd70e37 100644 --- a/tests/attr/test_llm_attr.py +++ b/tests/attr/test_llm_attr.py @@ -18,6 +18,7 @@ Type, Union, ) +from unittest.mock import MagicMock, patch import torch from captum._utils.models.linear_model import SkLearnLasso @@ -33,7 +34,7 @@ LLMGradientAttribution, RemoteLLMAttribution, ) -from captum.attr._core.remote_provider import RemoteLLMProvider +from captum.attr._core.remote_provider import RemoteLLMProvider, VLLMProvider from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling from captum.attr._utils.attribution import GradientAttribution, PerturbationAttribution from captum.attr._utils.interpretable_input import TextTemplateInput, TextTokenInput @@ -676,6 +677,355 @@ def test_llm_attr_with_skip_tensor_target(self) -> None: self.assertEqual(res.output_tokens, ["m", "n", "o", "p", "q"]) +class TestVLLMProvider(BaseTest): + """Test suite for VLLMProvider class.""" + + def setUp(self) -> None: + super().setUp() + self.api_url = "https://test-vllm-api.com" + self.model_name = "test-model" + self.input_prompt = "a b c d" + self.target_str = "e f g h" + + self.tokenizer = DummyTokenizer() + + # Set up patch for OpenAI import + self.openai_patcher = patch("openai.OpenAI") + self.mock_openai = self.openai_patcher.start() + + # Create a mock OpenAI client + self.mock_client = MagicMock() + self.mock_openai.return_value = self.mock_client + + def tearDown(self) -> None: + self.openai_patcher.stop() + super().tearDown() + + def test_init_successful(self) -> None: + """Test successful initialization of VLLMProvider.""" + model_name: str = "default-model" + + # Mock the models.list() response + mock_models_data = [MagicMock(id=model_name)] + self.mock_client.models.list.return_value = MagicMock(data=mock_models_data) + + # Create provider without specifying model name + provider = VLLMProvider(api_url=self.api_url) + + # Verify the client was initialized correctly + self.mock_openai.assert_called_once() + self.assertEqual(provider.api_url, self.api_url) + self.assertEqual(provider.model_name, model_name) + + # Verify models.list() was called + self.mock_client.models.list.assert_called_once() + + def test_init_with_model_name(self) -> None: + """Test initialization with specific model name.""" + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + + # Verify model name was set correctly + self.assertEqual(provider.model_name, self.model_name) + + # Verify models.list() was NOT called + self.mock_client.models.list.assert_not_called() + + def test_init_empty_api_url(self) -> None: + """Test initialization with empty API URL raises ValueError.""" + with self.assertRaises(ValueError) as context: + VLLMProvider(api_url=" ") + + self.assertIn("API URL is required", str(context.exception)) + + def test_init_connection_error(self) -> None: + """Test initialization handling connection error.""" + # Mock connection error + self.mock_openai.side_effect = ConnectionError("Failed 
to connect") + + with self.assertRaises(ConnectionError) as context: + VLLMProvider(api_url=self.api_url) + + self.assertIn("Failed to connect to vLLM API", str(context.exception)) + + def test_init_no_models(self) -> None: + """Test initialization when no models are available.""" + # Mock empty models list + self.mock_client.models.list.return_value = MagicMock(data=[]) + + with self.assertRaises(ValueError) as context: + VLLMProvider(api_url=self.api_url) + + self.assertIn("No models available", str(context.exception)) + + def test_generate_successful(self) -> None: + """Test successful text generation.""" + # Set up mock response + mock_choice = MagicMock() + mock_choice.text = self.target_str + mock_response = MagicMock() + mock_response.choices = [mock_choice] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + # Call generate + result = provider.generate(self.input_prompt, max_tokens=10) + + # Verify result + self.assertEqual(result, self.target_str) + + # Verify API was called with correct parameters + self.mock_client.completions.create.assert_called_once_with( + model=self.model_name, prompt=self.input_prompt, max_tokens=10 + ) + + def test_generate_with_max_new_tokens(self) -> None: + """Test generation with max_new_tokens parameter.""" + # Set up mock response + mock_choice = MagicMock() + mock_choice.text = self.target_str + mock_response = MagicMock() + mock_response.choices = [mock_choice] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + # Call generate with max_new_tokens instead of max_tokens + _ = provider.generate(self.input_prompt, max_new_tokens=10) + + # Verify API was called with converted max_tokens parameter + self.mock_client.completions.create.assert_called_once_with( + model=self.model_name, prompt=self.input_prompt, max_tokens=10 + ) + + def test_generate_empty_choices(self) -> None: + """Test generation when response has empty choices.""" + # Set up mock response with empty choices + mock_response = MagicMock() + mock_response.choices = [] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + # Call generate and expect exception + with self.assertRaises(KeyError) as context: + provider.generate(self.input_prompt) + + self.assertIn( + "API response missing expected 'choices' data", str(context.exception) + ) + + def test_generate_connection_error(self) -> None: + """Test generation handling connection error.""" + # Mock connection error + self.mock_client.completions.create.side_effect = ConnectionError( + "Connection failed" + ) + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + # Call generate and expect exception + with self.assertRaises(ConnectionError) as context: + provider.generate(self.input_prompt) + + self.assertIn("Failed to connect to vLLM API", str(context.exception)) + + def test_get_logprobs_successful(self) -> None: + """Test successful retrieval of log probabilities.""" + # Set up test data + input_token_ids = self.tokenizer.encode( + self.input_prompt, add_special_tokens=False + ) + num_input_tokens = 
len(input_token_ids) + + target_token_ids = self.tokenizer.encode( + self.target_str, add_special_tokens=False + ) + expected_values = [0.1, 0.2, 0.3, 0.4] + num_target_tokens = len(target_token_ids) + + # Create mock vLLM response with prompt_logprobs + prompt_logprobs: List[Dict[str, Dict[str, Any]]] = [] + for i in range(num_input_tokens): + token_probs = { + str(input_token_ids[i]): { + "logprob": -0.5, # fixed logprob for input tokens (for testing) + "rank": i + 1, + "decoded_token": self.tokenizer.convert_ids_to_tokens( + input_token_ids[i] + ), + } + } + prompt_logprobs.append(token_probs) + for i in range(num_target_tokens): + token_probs = { + str(target_token_ids[i]): { + "logprob": expected_values[i], + "rank": i + 1, + "decoded_token": self.tokenizer.convert_ids_to_tokens( + target_token_ids[i] + ), + } + } + prompt_logprobs.append(token_probs) + + mock_choices = MagicMock() + # prompt_logprobs will be of length + # num_input_tokens + num_target_tokens + mock_choices.prompt_logprobs = prompt_logprobs + mock_response = MagicMock() + mock_response.choices = [mock_choices] + self.mock_client.completions.create.return_value = mock_response + + # Create provider and call get_logprobs + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + logprobs = provider.get_logprobs( + self.input_prompt, self.target_str, self.tokenizer + ) + + # Verify API call + self.mock_client.completions.create.assert_called_once_with( + model=self.model_name, + prompt=self.input_prompt + self.target_str, + temperature=0.0, + max_tokens=1, + extra_body={"prompt_logprobs": 0}, + ) + + # Verify results + self.assertEqual(len(logprobs), num_target_tokens) + for i, logprob in enumerate(logprobs): + self.assertEqual(logprob, expected_values[i]) + + def test_get_logprobs_missing_tokenizer(self) -> None: + """Test get_logprobs with missing tokenizer.""" + with self.assertRaises(ValueError) as context: + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.get_logprobs(self.input_prompt, self.target_str, None) + + self.assertIn("Tokenizer is required", str(context.exception)) + + def test_get_logprobs_empty_target(self) -> None: + """Test get_logprobs with empty target string.""" + with self.assertRaises(ValueError) as context: + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.get_logprobs(self.input_prompt, "", self.tokenizer) + + self.assertIn("Target string cannot be empty", str(context.exception)) + + def test_get_logprobs_missing_prompt_logprobs(self) -> None: + """Test get_logprobs when response is missing prompt_logprobs.""" + # Set up mock response without prompt_logprobs + mock_choices = MagicMock() + mock_choices.prompt_logprobs = None + mock_response = MagicMock() + mock_response.choices = [mock_choices] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + with self.assertRaises(KeyError) as context: + provider.get_logprobs(self.input_prompt, self.target_str, self.tokenizer) + + self.assertIn( + "API response missing 'prompt_logprobs' data", str(context.exception) + ) + + def test_get_logprobs_empty_probs(self) -> None: + """Test get_logprobs with empty probability data.""" + # Create mock response with empty probs + prompt_logprobs: List[Dict[str, Dict[str, Any]]] = [{}] # Empty dict for token probabilities + mock_choices = 
MagicMock() + mock_choices.prompt_logprobs = prompt_logprobs + mock_response = MagicMock() + mock_response.choices = [mock_choices] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + with self.assertRaises(ValueError) as context: + provider.get_logprobs(self.input_prompt, self.target_str, self.tokenizer) + + self.assertIn("Empty probability data", str(context.exception)) + + def test_get_logprobs_keyerror(self) -> None: + """Test get_logprobs with missing 'logprob' key in response.""" + # Create mock response with invalid token prob structure + prompt_logprobs: List[Dict[str, Dict[str, Any]]] = [ + {"1": {"wrong_logprob_key": 0.1, "rank": 1, "decoded_token": "a"}}, + {"2": {"wrong_logprob_key": 0.2, "rank": 1, "decoded_token": "b"}}, + ] + mock_choices = MagicMock() + mock_choices.prompt_logprobs = prompt_logprobs + mock_response = MagicMock() + mock_response.choices = [mock_choices] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + with self.assertRaises(IndexError) as context: + provider.get_logprobs("a", "b", self.tokenizer) + + self.assertIn( + "Unexpected format in log probability data", str(context.exception) + ) + + def test_get_logprobs_length_mismatch(self) -> None: + """Test get_logprobs with length mismatch + between expected and received tokens.""" + # Create mock response with only 1 logprobs (fewer than expected) + prompt_logprobs = [{"1": {"logprob": 0.1, "rank": 1, "decoded_token": "a"}}] + mock_choices = MagicMock() + mock_choices.prompt_logprobs = prompt_logprobs + mock_response = MagicMock() + mock_response.choices = [mock_choices] + self.mock_client.completions.create.return_value = mock_response + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + with self.assertRaises(ValueError) as context: + provider.get_logprobs(self.input_prompt, self.target_str, self.tokenizer) + + self.assertIn("Not enough logprobs received", str(context.exception)) + + def test_get_logprobs_connection_error(self) -> None: + """Test get_logprobs handling connection error.""" + # Mock connection error + self.mock_client.completions.create.side_effect = ConnectionError( + "Connection failed" + ) + + # Create provider + provider = VLLMProvider(api_url=self.api_url, model_name=self.model_name) + provider.client = self.mock_client + + with self.assertRaises(ConnectionError) as context: + provider.get_logprobs(self.input_prompt, self.target_str, self.tokenizer) + + self.assertIn( + "Failed to connect to vLLM API when getting logprobs", + str(context.exception), + ) + + class DummyRemoteLLMProvider(RemoteLLMProvider): def __init__(self, deterministic_logprobs: bool = False) -> None: self.api_url = "https://test-api.com" From fd469b2b97862458846ada1eebf77069594156a1 Mon Sep 17 00:00:00 2001 From: saichandrapandraju Date: Tue, 13 May 2025 17:54:36 -0700 Subject: [PATCH 5/7] format empty prompt_logprobs --- tests/attr/test_llm_attr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/attr/test_llm_attr.py b/tests/attr/test_llm_attr.py index 411fd70e37..d944d283f4 100644 --- a/tests/attr/test_llm_attr.py +++ b/tests/attr/test_llm_attr.py @@ -946,7 +946,9 @@ def 
test_get_logprobs_missing_prompt_logprobs(self) -> None: def test_get_logprobs_empty_probs(self) -> None: """Test get_logprobs with empty probability data.""" # Create mock response with empty probs - prompt_logprobs: List[Dict[str, Dict[str, Any]]] = [{}] # Empty dict for token probabilities + prompt_logprobs: List[Dict[str, Dict[str, Any]]] = [ + {} + ] # Empty dict for token probabilities mock_choices = MagicMock() mock_choices.prompt_logprobs = prompt_logprobs mock_response = MagicMock() From 5ddc238b116673dc37bad2e0b3cdca84c94bc931 Mon Sep 17 00:00:00 2001 From: saichandrapandraju Date: Tue, 13 May 2025 19:38:01 -0700 Subject: [PATCH 6/7] assert probs length to 1 with `prompt_logprobs = 0` --- captum/attr/_core/remote_provider.py | 1 + 1 file changed, 1 insertion(+) diff --git a/captum/attr/_core/remote_provider.py b/captum/attr/_core/remote_provider.py index 8ec2b9aeae..386f823aca 100644 --- a/captum/attr/_core/remote_provider.py +++ b/captum/attr/_core/remote_provider.py @@ -214,6 +214,7 @@ def get_logprobs( ]: if not probs: raise ValueError("Empty probability data in API response") + assert len(probs) == 1, "Expected exactly one token in logprobs" prompt_logprobs.append(next(iter(probs.values()))["logprob"]) except (IndexError, KeyError) as e: raise IndexError(f"Unexpected format in log probability data: {str(e)}") From 38ca44b965dc3a64903f2189f67decd148fc160c Mon Sep 17 00:00:00 2001 From: saichandrapandraju Date: Tue, 13 May 2025 21:30:06 -0700 Subject: [PATCH 7/7] log at module-level --- captum/attr/_core/remote_provider.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/captum/attr/_core/remote_provider.py b/captum/attr/_core/remote_provider.py index 386f823aca..db2dc2a40a 100644 --- a/captum/attr/_core/remote_provider.py +++ b/captum/attr/_core/remote_provider.py @@ -5,6 +5,8 @@ from captum._utils.typing import TokenizerLike +logger = logging.getLogger(__name__) + class RemoteLLMProvider(ABC): """All remote LLM providers that offer logprob via API @@ -93,7 +95,7 @@ def __init__(self, api_url: str, model_name: Optional[str] = None): if not models: raise ValueError("No models available from the vLLM API") self.model_name = models[0].id - logging.info( + logger.info( f"No model_name is specified for VLLMProvider." f" Using first available model: {self.model_name}" )
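
For reviewers, a minimal end-to-end sketch of how the pieces introduced in this patch series fit together, in the same spirit as the unit tests above. This is an illustrative sketch only: the server URL, model name, and tokenizer checkpoint below are placeholders, and it assumes a vLLM deployment that exposes prompt_logprobs is already running and that the served model's Hugging Face tokenizer is available locally.

    # Sketch: attribute a remotely hosted (vLLM) model with RemoteLLMAttribution.
    # "http://localhost:8000/v1" and "my-served-model" are placeholder values.
    from transformers import AutoTokenizer

    from captum.attr import FeatureAblation, RemoteLLMAttribution, VLLMProvider
    from captum.attr._utils.interpretable_input import TextTokenInput

    # Provider talks to the remote vLLM OpenAI-compatible endpoint.
    provider = VLLMProvider(
        api_url="http://localhost:8000/v1", model_name="my-served-model"
    )
    # Tokenizer should match the served model so token boundaries line up.
    tokenizer = AutoTokenizer.from_pretrained("my-served-model")

    # No local model is needed in remote mode; the perturbation method wraps the
    # placeholder module exposed by RemoteLLMAttribution (as in the tests above).
    attr_method = FeatureAblation(RemoteLLMAttribution.placeholder_model)
    remote_llm_attr = RemoteLLMAttribution(
        attr_method=attr_method,
        tokenizer=tokenizer,
        provider=provider,
    )

    # Attribute a given target string; omitting the target would instead call
    # provider.generate() to produce one.
    inp = TextTokenInput("a b c", tokenizer)
    res = remote_llm_attr.attribute(inp, "m n o p q")
    print(res.seq_attr)        # one score per input token
    print(res.input_tokens, res.output_tokens)

The shape expectations mirror the tests: res.seq_attr has one entry per input token, and res.token_attr (when per-token attribution is enabled) is output-tokens x input-tokens.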