
do not allow remote inference #8


Open: wants to merge 4 commits into main
1 change: 1 addition & 0 deletions .github/workflows/pr_qc.yml
@@ -16,4 +16,5 @@ jobs:
      - name: Run qa
        run: |
          pip install ".[dev]"
+          python validator/post-install.py
          make qa
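Presumably this new step exists to pre-download the model weights before make qa runs the test suite. A minimal sketch of what such a warm-up script does, assuming it mirrors the pipeline call in validator/main.py (the copy of post-install.py at the bottom of this page is truncated, so treat this as an inference, not the file's verbatim contents):

    # Instantiating the pipeline once forces transformers to download and
    # cache the d4data/bias-detection-model weights, so CI test runs start warm.
    from transformers.pipelines import pipeline

    _ = pipeline(
        'text-classification',
        model='d4data/bias-detection-model',
        tokenizer='d4data/bias-detection-model',
    )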
4 changes: 2 additions & 2 deletions Makefile
@@ -5,12 +5,12 @@ lint:
	ruff check .

test:
-	pytest ./tests
+	pytest tests/*

type:
	pyright validator

qa:
	make lint
	make type
-	make tests
+	make test
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -14,7 +14,8 @@ dependencies = [
    "tf-keras",
    "sentencepiece",
    "tensorflow>=2.16.0", # Required for the dbias model, but not as a direct dependency.
-   "sentence-splitter>=1.4"
+   "sentence-splitter>=1.4",
+   "torch"
]

[project.optional-dependencies]
2 changes: 1 addition & 1 deletion tests/test_validator.py
@@ -23,6 +23,6 @@ def test_failure_case():
def test_sentence_fix():
    v = BiasCheck(on_fail='fix', threshold=0.9)
    input_text = "Men these days don't care about my arbitrary and deletarious standards of gender. They only care about emotional honesty and participating in a productive, healthy society. smh"
-    out = v.validate(input_text)
+    out = v.validate(input_text, {})
    assert isinstance(out, FailResult)
    assert out.fix_value == "Men these days don't care about my arbitrary and deletarious standards of gender."
33 changes: 18 additions & 15 deletions validator/main.py
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, TypedDict

from guardrails.validator_base import (
    FailResult,
@@ -9,7 +9,7 @@
)
from guardrails.types import OnFailAction
from sentence_splitter import split_text_into_sentences
-from transformers import pipeline
+from transformers.pipelines import pipeline


@register_validator(name="guardrails/bias_check", data_type="string")
@@ -32,14 +32,9 @@ class BiasCheck(Validator):
    def __init__(
        self,
        threshold: float = 0.9,
-        on_fail: Optional[Union[str, Callable]] = None,
+        **kwargs,

Review comment:

I think this would be a breaking change, since on_fail is going from a positional argument to a keyword-only argument.
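A minimal sketch of the breakage being described (hypothetical caller code, not from this repo):

    # Old signature: __init__(self, threshold=0.9, on_fail=None, **kwargs)
    v = BiasCheck(0.9, "fix")  # worked: on_fail bound positionally

    # New signature: __init__(self, threshold=0.9, **kwargs)
    v = BiasCheck(0.9, "fix")  # TypeError: too many positional arguments
    v = BiasCheck(threshold=0.9, on_fail="fix")  # still fine; flows through **kwargs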

    ):
-        super().__init__(on_fail=on_fail)  # type: ignore
-        valid_on_fail_operations = {"fix", "noop", "exception"}
-        if isinstance(on_fail, str) and on_fail not in valid_on_fail_operations:
-            raise Exception(
-                f"on_fail value ({on_fail}) not in list of allowable operations: {valid_on_fail_operations}"
-            )
+        super().__init__(**kwargs)
        self.threshold = threshold

        # There are some spurious loading complaints with TFDistilBert models.
@@ -50,7 +45,10 @@ def __init__(
            tokenizer="d4data/bias-detection-model",
        )

-    def validate(
+    def validate(self, value: Any, metadata: Dict[str, Any] = {}) -> ValidationResult:
+        return super().validate(value, metadata)
+
+    def _validate(
        self,
        value: Union[str, List[str]],
        metadata: Optional[Dict] = None
@@ -61,7 +59,7 @@ def validate(
            single_sentence_passed = True
            value = [value,]  # Ensure we're always passing lists of strings into the classifier.

-        scores = self._inference(value)
+        scores = self._inference_local(value)
        passing_outputs = list()
        passing_scores = list()
        failing_outputs = list()
@@ -106,7 +104,7 @@ def fix_passage(self, text: str) -> str:
        then recombine them and return a new paragraph. May not preserve whitespace
        between sentences."""
        sentences = split_text_into_sentences(text, language='en')
-        scores = self._inference(sentences)
+        scores = self._inference_local(sentences)
        unbiased_sentences = list()
        for score, sentence in zip(scores, sentences):
            if score < self.threshold:
@@ -117,10 +115,10 @@ def fix_passage(self, text: str) -> str:
    # Remote inference is unsupported for this model on account of the NER.
    def _inference_local(self, sentences: List[str]) -> List[float]:  # type: ignore
        scores = list()
-        predictions = self.classification_model(sentences)
+        predictions: List[PipelinePrediction] = self.classification_model(sentences)  # type: ignore
        for pred in predictions:
-            label = pred['label']  # type: ignore
-            score = pred['score']  # type: ignore
+            label = pred['label']
+            score = pred['score']
Comment on lines +118 to +121

🙏

            if label == 'Biased':
                scores.append(score)
            elif label == 'Non-biased':
@@ -129,3 +127,8 @@ def _inference_local(self, sentences: List[str]) -> List[float]:  # type: ignore
                # This should never happen:
                raise Exception("Unexpected prediction label: {}".format(label))
        return scores
+
+# Define the type for pipeline predictions
+class PipelinePrediction(TypedDict):
+    label: str
+    score: float
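For readers unfamiliar with the pattern, a minimal standalone sketch of what the TypedDict buys here (the prediction values are invented; the dict shape matches the Hugging Face text-classification output used above):

    from typing import List, TypedDict

    class PipelinePrediction(TypedDict):
        label: str
        score: float

    preds: List[PipelinePrediction] = [
        {"label": "Biased", "score": 0.97},
        {"label": "Non-biased", "score": 0.12},
    ]
    for p in preds:
        # The annotation tells pyright that p["label"] is str and p["score"]
        # is float, which is why the per-line "# type: ignore" comments on
        # those accesses could be dropped in the diff above.
        print(p["label"], p["score"])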
2 changes: 1 addition & 1 deletion validator/post-install.py
@@ -1,4 +1,4 @@
-from transformers import pipeline
+from transformers.pipelines import pipeline
print("post-install starting...")
_ = pipeline(
    'text-classification',