From ac41cb54c6f1155c9cf6f3e22dafd29ce053fad9 Mon Sep 17 00:00:00 2001 From: Zayd Simjee Date: Tue, 15 Apr 2025 17:00:28 -0700 Subject: [PATCH 1/4] do not allow remote inference --- tests/test_validator.py | 2 +- validator/main.py | 18 ++++++++---------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/tests/test_validator.py b/tests/test_validator.py index ace4601..43d8a47 100644 --- a/tests/test_validator.py +++ b/tests/test_validator.py @@ -23,6 +23,6 @@ def test_failure_case(): def test_sentence_fix(): v = BiasCheck(on_fail='fix', threshold=0.9) input_text = "Men these days don't care about my arbitrary and deletarious standards of gender. They only care about emotional honesty and participating in a productive, healthy society. smh" - out = v.validate(input_text) + out = v.validate(input_text, {}) assert isinstance(out, FailResult) assert out.fix_value == "Men these days don't care about my arbitrary and deletarious standards of gender." diff --git a/validator/main.py b/validator/main.py index 051238f..636935f 100644 --- a/validator/main.py +++ b/validator/main.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from guardrails.validator_base import ( FailResult, @@ -32,14 +32,9 @@ class BiasCheck(Validator): def __init__( self, threshold: float = 0.9, - on_fail: Optional[Union[str, Callable]] = None, + **kwargs, ): - super().__init__(on_fail=on_fail) # type: ignore - valid_on_fail_operations = {"fix", "noop", "exception"} - if isinstance(on_fail, str) and on_fail not in valid_on_fail_operations: - raise Exception( - f"on_fail value ({on_fail}) not in list of allowable operations: {valid_on_fail_operations}" - ) + super().__init__(**kwargs) self.threshold = threshold # There are some spurious loading complaints with TFDistilBert models. @@ -50,7 +45,10 @@ def __init__( tokenizer="d4data/bias-detection-model", ) - def validate( + def validate(self, value: Any, metadata: Dict[str, Any] = {}) -> ValidationResult: + return super().validate(value, metadata) + + def _validate( self, value: Union[str, List[str]], metadata: Optional[Dict] = None @@ -61,7 +59,7 @@ def validate( single_sentence_passed = True value = [value,] # Ensure we're always passing lists of strings into the classifier. - scores = self._inference(value) + scores = self._inference_local(value) passing_outputs = list() passing_scores = list() failing_outputs = list() From 9a9311cab13c377c89e12d705b709a22551d8c1e Mon Sep 17 00:00:00 2001 From: Zayd Simjee Date: Tue, 15 Apr 2025 17:10:47 -0700 Subject: [PATCH 2/4] remove unused types --- validator/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index 636935f..d7f7f64 100644 --- a/validator/main.py +++ b/validator/main.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from guardrails.validator_base import ( FailResult, From dfd80784d4a80204a3b050686594956b8d9d0eb6 Mon Sep 17 00:00:00 2001 From: Zayd Simjee Date: Tue, 15 Apr 2025 17:35:40 -0700 Subject: [PATCH 3/4] fix make, actions --- .github/workflows/pr_qc.yml | 1 + Makefile | 4 ++-- pyproject.toml | 3 ++- validator/main.py | 15 ++++++++++----- validator/post-install.py | 2 +- 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/.github/workflows/pr_qc.yml b/.github/workflows/pr_qc.yml index 7a46efe..0d6e748 100644 --- a/.github/workflows/pr_qc.yml +++ b/.github/workflows/pr_qc.yml @@ -16,4 +16,5 @@ jobs: - name: Run qa run: | pip install ".[dev]" + python validator/post-install.py make qa diff --git a/Makefile b/Makefile index 414e178..7ea2ab2 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ lint: ruff check . test: - pytest ./tests + pytest tests/* type: pyright validator @@ -13,4 +13,4 @@ type: qa: make lint make type - make tests \ No newline at end of file + make test \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 96662b6..17a0b17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,8 @@ dependencies = [ "tf-keras", "sentencepiece", "tensorflow>=2.16.0", # Required for the dbias model, but not as a direct dependency. - "sentence-splitter>=1.4" + "sentence-splitter>=1.4", + "torch" ] [project.optional-dependencies] diff --git a/validator/main.py b/validator/main.py index d7f7f64..787ba0f 100644 --- a/validator/main.py +++ b/validator/main.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, TypedDict from guardrails.validator_base import ( FailResult, @@ -9,7 +9,7 @@ ) from guardrails.types import OnFailAction from sentence_splitter import split_text_into_sentences -from transformers import pipeline +from transformers.pipelines import pipeline @register_validator(name="guardrails/bias_check", data_type="string") @@ -115,10 +115,10 @@ def fix_passage(self, text: str) -> str: # Remote inference is unsupported for this model on account of the NER. def _inference_local(self, sentences: List[str]) -> List[float]: # type: ignore scores = list() - predictions = self.classification_model(sentences) + predictions: List[PipelinePrediction] = self.classification_model(sentences) # type: ignore for pred in predictions: - label = pred['label'] # type: ignore - score = pred['score'] # type: ignore + label = pred['label'] + score = pred['score'] if label == 'Biased': scores.append(score) elif label == 'Non-biased': @@ -127,3 +127,8 @@ def _inference_local(self, sentences: List[str]) -> List[float]: # type: ignore # This should never happen: raise Exception("Unexpected prediction label: {}".format(label)) return scores + +# Define the type for pipeline predictions +class PipelinePrediction(TypedDict): + label: str + score: float diff --git a/validator/post-install.py b/validator/post-install.py index 26ec5ef..7388d43 100644 --- a/validator/post-install.py +++ b/validator/post-install.py @@ -1,4 +1,4 @@ -from transformers import pipeline +from transformers.pipelines import pipeline print("post-install starting...") _ = pipeline( 'text-classification', From a44ebe3faf7299bf24885c53cef1447d9501490f Mon Sep 17 00:00:00 2001 From: Zayd Simjee Date: Thu, 24 Apr 2025 17:03:55 -0700 Subject: [PATCH 4/4] knock off the last _inference call --- validator/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/validator/main.py b/validator/main.py index 787ba0f..a9b895f 100644 --- a/validator/main.py +++ b/validator/main.py @@ -104,7 +104,7 @@ def fix_passage(self, text: str) -> str: then recombine them and return a new paragraph. May not preserve whitespace between sentences.""" sentences = split_text_into_sentences(text, language='en') - scores = self._inference(sentences) + scores = self._inference_local(sentences) unbiased_sentences = list() for score, sentence in zip(scores, sentences): if score < self.threshold: