Skip to content

feat: refactor validators and fields to support pydantic v2 #967

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions nemoguardrails/eval/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, root_validator
from pydantic import BaseModel, Field, model_validator

from nemoguardrails.eval.utils import load_dict_from_path
from nemoguardrails.logging.explain import LLMCallInfo
Expand Down Expand Up @@ -107,7 +107,7 @@ class InteractionSet(BaseModel):
description="A list of tags that should be associated with the interactions. Useful for filtering when reporting.",
)

@root_validator(pre=True)
@model_validator(mode="before")
def instantiate_expected_output(cls, values: Any):
"""Creates the right instance of the expected output."""
type_mapping = {
Expand Down Expand Up @@ -147,11 +147,11 @@ class EvalConfig(BaseModel):
description="The prompts that should be used for the various LLM tasks.",
)

@root_validator(pre=False, skip_on_failure=True)
def validate_policy_ids(cls, values: Any):
@model_validator(mode="after")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be

    @model_validator(mode="after")
    def validate_policy_ids(cls, self) -> "EvalConfig":
        """Validates the policy ids used in the interactions."""
        policy_ids = {policy.id for policy in self.policies}
        for interaction_set in self.interactions:
            for expected_output in interaction_set.expected_output:
                if expected_output.policy not in policy_ids:
                    raise ValueError(
                        f"Invalid policy id {expected_output.policy} used in interaction set."
                    )
            for policy_id in (
                interaction_set.include_policies + interaction_set.exclude_policies
            ):
                if policy_id not in policy_ids:
                    raise ValueError(
                        f"Invalid policy id {policy_id} used in interaction set."
                    )
        return self

def validate_policy_ids(cls, values: "EvalConfig") -> "EvalConfig":
"""Validates the policy ids used in the interactions."""
policy_ids = {policy.id for policy in values.get("policies")}
for interaction_set in values.get("interactions"):
policy_ids = {policy.id for policy in values.policies}
for interaction_set in values.interactions:
for expected_output in interaction_set.expected_output:
if expected_output.policy not in policy_ids:
raise ValueError(
Expand Down Expand Up @@ -180,7 +180,7 @@ def from_path(
else:
raise ValueError(f"Invalid config path {config_path}.")

return cls.parse_obj(config_obj)
return cls.model_validate(config_obj)


class ComplianceCheckLog(BaseModel):
Expand Down Expand Up @@ -361,4 +361,4 @@ def from_path(
else:
raise ValueError(f"Invalid config path {output_path}.")

return cls.parse_obj(output_obj)
return cls.model_validate(output_obj)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# The minimal set of requirements for the AlignScore server to run.
pydantic>=1.10
pydantic>=2.10.6
fastapi>=0.109.1
starlette>=0.36.2
typer>=0.7.0
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# The minimal set of requirements for the jailbreak detection server to run.
pydantic>=1.10.9
pydantic>=2.10.6
fastapi>=0.103.1
starlette>=0.27.0
typer>=0.7.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_nvidia_ai_endpoints import ChatNVIDIA as ChatNVIDIAOriginal
from pydantic.v1 import Field
from pydantic import Field

log = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions nemoguardrails/llm/providers/nemollm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from langchain.schema import Generation
from langchain.schema.output import GenerationChunk, LLMResult
from langchain_core.language_models.llms import BaseLLM
from pydantic.v1 import root_validator
from pydantic import model_validator

log = logging.getLogger(__name__)

Expand All @@ -52,7 +52,7 @@ class NeMoLLM(BaseLLM):
streaming: bool = False
check_api_host_version: bool = True

@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
def check_env_variables(cls, values):
for field in ["api_host", "api_key", "organization_id"]:
# If it's an explicit environment variable, we use that
Expand Down
8 changes: 4 additions & 4 deletions nemoguardrails/llm/providers/trtllm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from pydantic.v1 import Field, root_validator
from pydantic import Field, model_validator

from nemoguardrails.llm.providers.trtllm.client import TritonClient

Expand Down Expand Up @@ -61,12 +61,12 @@ class TRTLLM(BaseLLM):
client: Any
streaming: Optional[bool] = True

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be

    @classmethod
    def validate_environment(cls, self) -> "TRTLLM":
        """Validate that python package exists in environment."""
        try:
            # instantiate and attach the client
            self.client = TritonClient(self.server_url)

        except ImportError as err:
            raise ImportError(
                "Could not import triton client python package. "
                "Please install it with `pip install tritonclient[all]`."
            ) from err
        return self

@root_validator(allow_reuse=True)
@model_validator(mode="after")
@classmethod
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
def validate_environment(cls, values: "TRTLLM") -> "TRTLLM":
"""Validate that python package exists in environment."""
try:
values["client"] = TritonClient(values["server_url"])
values.client = TritonClient(values.server_url)

except ImportError as err:
raise ImportError(
Expand Down
22 changes: 10 additions & 12 deletions nemoguardrails/rails/llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import yaml
from pydantic import BaseModel, ConfigDict, ValidationError, root_validator
from pydantic.fields import Field
from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator

from nemoguardrails import utils
from nemoguardrails.colang import parse_colang_file, parse_flow_elements
Expand Down Expand Up @@ -253,7 +252,7 @@ class TaskPrompt(BaseModel):
description="The maximum number of tokens that can be generated in the chat completion.",
)

@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
def check_fields(cls, values):
if not values.get("content") and not values.get("messages"):
raise ValidationError("One of `content` or `messages` must be provided.")
Expand Down Expand Up @@ -947,16 +946,15 @@ class RailsConfig(BaseModel):
description="The list of bot messages that should be used for the rails.",
)

# NOTE: the Any below is used to get rid of a warning with pydantic 1.10.x;
# The correct typing should be List[Dict, Flow]. To be updated when
# support for pydantic 1.10.x is dropped.
flows: List[Union[Dict, Any]] = Field(
flows: List[Union[Dict, Flow]] = Field(
default_factory=list,
description="The list of flows that should be used for the rails.",
)

instructions: Optional[List[Instruction]] = Field(
default=[Instruction.parse_obj(obj) for obj in _default_config["instructions"]],
default=[
Instruction.model_validate(obj) for obj in _default_config["instructions"]
],
description="List of instructions in natural language that the LLM should use.",
)

Expand Down Expand Up @@ -1061,7 +1059,7 @@ class RailsConfig(BaseModel):
description="Configuration for tracing.",
)

@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
def check_prompt_exist_for_self_check_rails(cls, values):
rails = values.get("rails", {})

Expand Down Expand Up @@ -1115,7 +1113,7 @@ def check_prompt_exist_for_self_check_rails(cls, values):

return values

@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
def check_output_parser_exists(cls, values):
tasks_requiring_output_parser = [
"self_check_input",
Expand Down Expand Up @@ -1148,7 +1146,7 @@ def check_output_parser_exists(cls, values):
)
return values

@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
def fill_in_default_values_for_v2_x(cls, values):
instructions = values.get("instructions", {})
sample_conversation = values.get("sample_conversation")
Expand Down Expand Up @@ -1277,7 +1275,7 @@ def parse_object(cls, obj):
):
flow_data["elements"] = parse_flow_elements(flow_data["elements"])

return cls.parse_obj(obj)
return cls.model_validate(obj)

@property
def streaming_supported(self):
Expand Down
4 changes: 2 additions & 2 deletions nemoguardrails/rails/llm/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
"""
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, root_validator
from pydantic import BaseModel, Field, model_validator

from nemoguardrails.logging.explain import LLMCallInfo, LLMCallSummary

Expand Down Expand Up @@ -168,7 +168,7 @@ class GenerationOptions(BaseModel):
description="Options about what to include in the log. By default, nothing is included. ",
)

@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
def check_fields(cls, values):
# Translate the `rails` generation option from List[str] to dict.
if "rails" in values and isinstance(values["rails"], list):
Expand Down
14 changes: 9 additions & 5 deletions nemoguardrails/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, root_validator, validator
from pydantic import BaseModel, Field, field_validator, model_validator
from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles

Expand Down Expand Up @@ -208,7 +208,7 @@ class RequestBody(BaseModel):
description="A state object that should be used to continue the interaction.",
)

@root_validator(pre=True)
@model_validator(mode="before")
def ensure_config_id(cls, data: Any) -> Any:
if isinstance(data, dict):
if data.get("config_id") is not None and data.get("config_ids") is not None:
Expand All @@ -223,11 +223,15 @@ def ensure_config_id(cls, data: Any) -> Any:
)
return data

@validator("config_ids", pre=True, always=True)
@field_validator("config_ids", mode="before")
def ensure_config_ids(cls, v, values):
if v is None and values.get("config_id") and values.get("config_ids") is None:
if (
v is None
and values.data.get("config_id")
and values.data.get("config_ids") is None
):
# populate config_ids with config_id if only config_id is provided
return [values["config_id"]]
return [values.data["config_id"]]
return v


Expand Down
30 changes: 1 addition & 29 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ langchain-community = ">=0.0.16,<0.4.0"
lark = ">=1.1.7"
nest-asyncio = ">=1.5.6,"
prompt-toolkit = ">=3.0"
pydantic = ">=1.10"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should relax it a bit, probably >= 2.0

And I still need to see the reasons we were still supporting Pydantic 1.0, I'll get back to this later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the extra log.text and hidden file, changed the version to >= 2.0, and signed the commit, thanks...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and <3.0

pydantic = ">=2.0"
pyyaml = ">=6.0"
rich = ">=13.5.2"
simpleeval = ">=0.9.13,"
Expand Down