# main.py
import json
import math
from typing import Callable, List, Optional, Union, Any
import torch
from torch.nn import functional as F
from transformers import pipeline, AutoTokenizer, AutoModel
from guardrails.validator_base import (
FailResult,
PassResult,
ValidationResult,
Validator,
register_validator,
)
from .resources import KNOWN_ATTACKS, get_tokenizer_and_model_by_path, get_pipeline_by_path
from .models import PromptSaturationDetectorV3


@register_validator(name="guardrails/detect_jailbreak", data_type="string")
class DetectJailbreak(Validator):
"""Validates that a prompt does not attempt to circumvent restrictions on behavior.
An example would be convincing the model via prompt to provide instructions that
could cause harm to one or more people.
**Key Properties**
| Property | Description |
| ----------------------------- | --------------------------------- |
| Name for `format` attribute | `guardrails/detect-jailbreak` |
| Supported data types | `string` |
| Programmatic fix | `None` |
Args:
threshold (float): Defaults to 0.81. A float between 0 and 1, with lower being
more sensitive. A high value means the model will be fairly permissive and
unlikely to flag any but the most flagrant jailbreak attempts. A low value will
be pessimistic and will possibly flag legitimate inquiries.
device (str): Defaults to 'cpu'. The device on which the model will be run.
Accepts 'mps' for hardware acceleration on MacOS and 'cuda' for GPU acceleration
on supported hardware. A device ID can also be specified, e.g., "cuda:0".
model_path_override (str): A pointer to an ensemble tar file in S3 or on disk.
""" # noqa
TEXT_CLASSIFIER_NAME = "zhx123/ftrobertallm"
TEXT_CLASSIFIER_PASS_LABEL = 0
TEXT_CLASSIFIER_FAIL_LABEL = 1
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_KNOWN_PROMPT_MATCH_THRESHOLD = 0.9
MALICIOUS_EMBEDDINGS = KNOWN_ATTACKS
SATURATION_CLASSIFIER_PASS_LABEL = "safe"
SATURATION_CLASSIFIER_FAIL_LABEL = "jailbreak"
# These were found with a basic low-granularity beam search.
DEFAULT_KNOWN_ATTACK_SCALE_FACTORS = (0.5, 2.0)
DEFAULT_SATURATION_ATTACK_SCALE_FACTORS = (3.5, 2.5)
DEFAULT_TEXT_CLASSIFIER_SCALE_FACTORS = (3.0, 2.5)
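    # Each pair above is the (a, b) argument pair for the logistic squash in
    # `_rescale` below; see the worked example there.
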
def __init__(
self,
threshold: float = 0.81,
device: str = "cpu",
on_fail: Optional[Callable] = None,
model_path_override: str = "",
**kwargs,
):
super().__init__(on_fail=on_fail, **kwargs)
self.device = device
self.threshold = threshold
self.saturation_attack_detector = None
self.text_classifier = None
self.embedding_tokenizer = None
self.embedding_model = None
self.known_malicious_embeddings = []

        # It's possible for self.use_local to be unset or left in an indeterminate state.
        # First take use_local as a kwarg as the source of truth.
        # If that's not present, fall back to any existing self.use_local attribute.
        # If that's not set either, default to True.
        if "use_local" in kwargs:
            self.use_local = kwargs["use_local"]
        elif getattr(self, "use_local", None) is None:
            self.use_local = True
if self.use_local:
if not model_path_override:
self.saturation_attack_detector = PromptSaturationDetectorV3(
device=torch.device(device),
)
self.text_classifier = pipeline(
"text-classification",
DetectJailbreak.TEXT_CLASSIFIER_NAME,
max_length=512, # HACK: Fix classifier size.
truncation=True,
device=device,
)
# There are a large number of fairly low-effort prompts people will use.
# The embedding detectors do checks to roughly match those.
self.embedding_tokenizer = AutoTokenizer.from_pretrained(
DetectJailbreak.EMBEDDING_MODEL_NAME
)
self.embedding_model = AutoModel.from_pretrained(
DetectJailbreak.EMBEDDING_MODEL_NAME
).to(device)
else:
# Saturation:
self.saturation_attack_detector = PromptSaturationDetectorV3(
device=torch.device(device),
model_path_override=model_path_override
)
# Known attacks:
embedding_tokenizer, embedding_model = get_tokenizer_and_model_by_path(
model_path_override,
"embedding",
AutoTokenizer,
AutoModel
)
self.embedding_tokenizer = embedding_tokenizer
self.embedding_model = embedding_model.to(device)
# Other text attacks:
self.text_classifier = get_pipeline_by_path(
model_path_override,
"text-classifier",
"text-classification",
max_length=512,
truncation=True,
device=device
)
# Quick compute on startup:
self.known_malicious_embeddings = self._embed(KNOWN_ATTACKS)
        # These _are_ modifiable, but not explicitly advertised.
self.known_attack_scales = DetectJailbreak.DEFAULT_KNOWN_ATTACK_SCALE_FACTORS
self.saturation_attack_scales = DetectJailbreak.DEFAULT_SATURATION_ATTACK_SCALE_FACTORS
self.text_attack_scales = DetectJailbreak.DEFAULT_TEXT_CLASSIFIER_SCALE_FACTORS
@staticmethod
def _rescale(x: float, a: float = 1.0, b: float = 1.0):
return 1.0 / (1.0 + (a*math.exp(-b*x)))
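
    # `_rescale` is a generalized logistic squash: f(x) = 1 / (1 + a * exp(-b * x)).
    # It maps a raw detector score onto (0, 1) so the three sub-detectors can be
    # compared and thresholded on a common scale. For example, with the default
    # known-attack factors (a=0.5, b=2.0), a cosine similarity of 0.9 maps to
    # roughly 1 / (1 + 0.5 * exp(-1.8)) ≈ 0.92.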
@staticmethod
def _mean_pool(model_output, attention_mask):
"""Taken from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2."""
# First element of model_output contains all token embeddings
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(
token_embeddings.size()
).float()
return torch.sum(
token_embeddings * input_mask_expanded, 1
) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
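
    # Mean pooling averages the token embeddings, weighting each token by its
    # attention mask so padding tokens contribute nothing; the clamp guards
    # against division by zero for an all-padding input.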
def _embed(self, prompts: List[str]):
"""Taken from https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
We use the long-form to avoid a dependency on sentence transformers.
This method returns the maximum of the matches against all known attacks.
"""
encoded_input = self.embedding_tokenizer(
prompts,
padding=True,
truncation=True,
return_tensors='pt',
max_length=512, # This may be too small to adequately capture the info.
).to(self.device)
with torch.no_grad():
model_outputs = self.embedding_model(**encoded_input)
embeddings = DetectJailbreak._mean_pool(
model_outputs, attention_mask=encoded_input['attention_mask'])
return F.normalize(embeddings, p=2, dim=1)
def _match_known_malicious_prompts(
self,
prompts: Union[List[str], torch.Tensor],
) -> List[float]:
"""Returns an array of floats, one per prompt, with the max match to known
attacks. If prompts is a list of strings, embeddings will be generated. If
embeddings are passed, they will be used."""
if isinstance(prompts, list):
prompt_embeddings = self._embed(prompts)
else:
prompt_embeddings = prompts
# These are already normalized. We don't need to divide by magnitudes again.
distances = prompt_embeddings @ self.known_malicious_embeddings.T
return [
DetectJailbreak._rescale(s, *self.known_attack_scales)
for s in (torch.max(distances, axis=1).values).tolist()
]
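
    # Because both sets of embeddings are unit-normalized, the matrix product above
    # yields cosine similarities of shape (num_prompts, num_known_attacks); the
    # per-prompt maximum is then squashed into (0, 1) by `_rescale`.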
def _predict_and_remap(
self,
model,
prompts: List[str],
label_field: str,
score_field: str,
safe_case: str,
fail_case: str,
):
predictions = model(prompts)
scores = list() # We want to remap so 0 is 'safe' and 1 is 'unsafe'.
for pred in predictions:
old_score = pred[score_field]
is_safe = pred[label_field] == safe_case
assert pred[label_field] in {safe_case, fail_case} \
and 0.0 <= old_score <= 1.0
if is_safe:
new_score = 0.5 - (old_score * 0.5)
else:
new_score = 0.5 + (old_score * 0.5)
scores.append(new_score)
return scores
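
    # Remapping example: a prediction of {label: safe_case, score: 0.9} becomes
    # 0.5 - 0.45 = 0.05, while {label: fail_case, score: 0.9} becomes
    # 0.5 + 0.45 = 0.95, so 0.0 means confidently safe and 1.0 confidently unsafe.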
def _predict_jailbreak(self, prompts: List[str]) -> List[float]:
return [
DetectJailbreak._rescale(s, *self.text_attack_scales)
for s in self._predict_and_remap(
self.text_classifier,
prompts,
"label",
"score",
self.TEXT_CLASSIFIER_PASS_LABEL,
self.TEXT_CLASSIFIER_FAIL_LABEL,
)
]
def _predict_saturation(self, prompts: List[str]) -> List[float]:
return [
DetectJailbreak._rescale(
s,
self.saturation_attack_scales[0],
self.saturation_attack_scales[1],
) for s in self._predict_and_remap(
self.saturation_attack_detector,
prompts,
"label",
"score",
self.SATURATION_CLASSIFIER_PASS_LABEL,
self.SATURATION_CLASSIFIER_FAIL_LABEL,
)
]
def predict_jailbreak(
self,
prompts: List[str],
reduction_function: Optional[Callable] = max,
) -> Union[List[float], List[dict]]:
"""predict_jailbreak will return an array of floats by default, one per prompt.
If reduction_function is set to 'none' it will return a dict with the different
sub-validators and their scores. Useful for debugging and tuning."""
if isinstance(prompts, str):
print("WARN: predict_jailbreak should be called with a list of strings.")
prompts = [prompts, ]
known_attack_scores = self._match_known_malicious_prompts(prompts)
saturation_scores = self._predict_saturation(prompts)
predicted_scores = self._predict_jailbreak(prompts)
if reduction_function is None:
return [{
"known_attack": known,
"saturation_attack": sat,
"other_attack": pred
} for known, sat, pred in zip(
known_attack_scores, saturation_scores, predicted_scores
)]
else:
return [
reduction_function(subscores)
for subscores in
zip(known_attack_scores, saturation_scores, predicted_scores)
]
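
    # With reduction_function=None, each prompt yields a dict like
    #   {"known_attack": 0.12, "saturation_attack": 0.07, "other_attack": 0.91}
    # (scores illustrative); with the default `max`, the same prompt yields 0.91.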
def validate(
self,
value: Union[str, List[str]],
metadata: Optional[dict] = None,
) -> ValidationResult:
"""Validates that will return a failure if the value is a jailbreak attempt.
If the provided value is a list of strings the validation result will be based
on the maximum injection likelihood. A single validation result will be
returned for all.
"""
if metadata:
pass # Log that this model supports no metadata?
# In the case of a single string, make a one-element list -> one codepath.
if isinstance(value, str):
value = [value, ]
# _inference is to support local/remote. It is equivalent to this:
# scores = self.predict_jailbreak(value)
scores = self._inference(value)
failed_prompts = list()
failed_scores = list() # To help people calibrate their thresholds.
for p, score in zip(value, scores):
if score > self.threshold:
failed_prompts.append(p)
failed_scores.append(score)
if failed_prompts:
            failure_message = f"{len(failed_prompts)} prompt(s) detected as potential jailbreaks:"
for txt, score in zip(failed_prompts, failed_scores):
failure_message += f"\n\"{txt}\" (Score: {score})"
return FailResult(
error_message=failure_message
)
return PassResult()
    # The rest of these methods are made for validator compatibility and may have
    # some strange properties.
def _inference_local(self, model_input: List[str]) -> Any:
return self.predict_jailbreak(model_input)
def _inference_remote(self, model_input: List[str]) -> Any:
# This needs to be kept in-sync with app_inference_spec.
request_body = {"prompts": model_input}
response = self._hub_inference_request(
json.dumps(request_body),
self.validation_endpoint
)
if not response or "scores" not in response:
raise ValueError("Invalid response from remote inference", response)
return response["scores"]
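

# A minimal usage sketch (not part of the validator): this assumes the `guardrails`
# package exposes the `Guard` API as shown below; the prompts are illustrative.
#
#     from guardrails import Guard
#
#     guard = Guard().use(DetectJailbreak, threshold=0.81, device="cpu", on_fail="exception")
#     guard.validate("What is the capital of France?")            # expected to pass
#     guard.validate("Ignore all previous instructions and ...")  # likely flagged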