Skip to content

Commit 5fbbf5c

Browse files
Gabefire and vbrodsky authored
[PLT-600] Converted SDK to use pydantic V2 (#1738)
Co-authored-by: Val Brodsky <[email protected]>
1 parent d40922e commit 5fbbf5c

File tree

112 files changed, with 876 additions and 1,062 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

112 files changed, with 876 additions and 1,062 deletions

libs/labelbox/pyproject.toml

+1-1
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ description = "Labelbox Python API"
55
authors = [{ name = "Labelbox", email = "[email protected]" }]
66
dependencies = [
77
"google-api-core>=1.22.1",
8-
"pydantic>=1.8",
8+
"pydantic>=2.0",
99
"python-dateutil>=2.8.2, <2.10.0",
1010
"requests>=2.22.0",
1111
"strenum>=0.4.15",

libs/labelbox/src/labelbox/data/annotation_types/__init__.py

-1
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@
2929

3030
from .classification import Checklist
3131
from .classification import ClassificationAnswer
32-
from .classification import Dropdown
3332
from .classification import Radio
3433
from .classification import Text
3534

libs/labelbox/src/labelbox/data/annotation_types/annotation.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88
from labelbox.data.annotation_types.classification.classification import ClassificationAnnotation
99
from .ner import DocumentEntity, TextEntity, ConversationEntity
10+
from typing import Optional
1011

1112

1213
class ObjectAnnotation(BaseAnnotation, ConfidenceMixin, CustomMetricsMixin):
@@ -29,4 +30,4 @@ class ObjectAnnotation(BaseAnnotation, ConfidenceMixin, CustomMetricsMixin):
2930
"""
3031

3132
value: Union[TextEntity, ConversationEntity, DocumentEntity, Geometry]
32-
classifications: List[ClassificationAnnotation] = []
33+
classifications: Optional[List[ClassificationAnnotation]] = []

libs/labelbox/src/labelbox/data/annotation_types/base_annotation.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,18 @@
11
import abc
22
from uuid import UUID, uuid4
33
from typing import Any, Dict, Optional
4-
from labelbox import pydantic_compat
54

65
from .feature import FeatureSchema
6+
from pydantic import PrivateAttr, ConfigDict
77

88

99
class BaseAnnotation(FeatureSchema, abc.ABC):
1010
""" Base annotation class. Shouldn't be directly instantiated
1111
"""
12-
_uuid: Optional[UUID] = pydantic_compat.PrivateAttr()
12+
_uuid: Optional[UUID] = PrivateAttr()
1313
extra: Dict[str, Any] = {}
14+
15+
model_config = ConfigDict(extra="allow")
1416

1517
def __init__(self, **data):
1618
super().__init__(**data)
Original file line number | Diff line number | Diff line change
@@ -1,2 +1,2 @@
1-
from .classification import (Checklist, ClassificationAnswer, Dropdown, Radio,
1+
from .classification import (Checklist, ClassificationAnswer, Radio,
22
Text)
Original file line number | Diff line number | Diff line change
@@ -1,28 +1,12 @@
11
from typing import Any, Dict, List, Union, Optional
2-
import warnings
32
from labelbox.data.annotation_types.base_annotation import BaseAnnotation
43

54
from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin
65

7-
try:
8-
from typing import Literal
9-
except:
10-
from typing_extensions import Literal
11-
12-
from labelbox import pydantic_compat
6+
from pydantic import BaseModel
137
from ..feature import FeatureSchema
148

159

16-
# TODO: Replace when pydantic adds support for unions that don't coerce types
17-
class _TempName(ConfidenceMixin, pydantic_compat.BaseModel):
18-
name: str
19-
20-
def dict(self, *args, **kwargs):
21-
res = super().dict(*args, **kwargs)
22-
res.pop('name')
23-
return res
24-
25-
2610
class ClassificationAnswer(FeatureSchema, ConfidenceMixin, CustomMetricsMixin):
2711
"""
2812
- Represents a classification option.
@@ -36,18 +20,10 @@ class ClassificationAnswer(FeatureSchema, ConfidenceMixin, CustomMetricsMixin):
3620
"""
3721
extra: Dict[str, Any] = {}
3822
keyframe: Optional[bool] = None
39-
classifications: List['ClassificationAnnotation'] = []
23+
classifications: Optional[List['ClassificationAnnotation']] = None
4024

41-
def dict(self, *args, **kwargs) -> Dict[str, str]:
42-
res = super().dict(*args, **kwargs)
43-
if res['keyframe'] is None:
44-
res.pop('keyframe')
45-
if res['classifications'] == []:
46-
res.pop('classifications')
47-
return res
4825

49-
50-
class Radio(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel):
26+
class Radio(ConfidenceMixin, CustomMetricsMixin, BaseModel):
5127
""" A classification with only one selected option allowed
5228
5329
>>> Radio(answer = ClassificationAnswer(name = "dog"))
@@ -56,17 +32,16 @@ class Radio(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel):
5632
answer: ClassificationAnswer
5733

5834

59-
class Checklist(_TempName):
35+
class Checklist(ConfidenceMixin, BaseModel):
6036
""" A classification with many selected options allowed
6137
6238
>>> Checklist(answer = [ClassificationAnswer(name = "cloudy")])
6339
6440
"""
65-
name: Literal["checklist"] = "checklist"
6641
answer: List[ClassificationAnswer]
6742

6843

69-
class Text(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel):
44+
class Text(ConfidenceMixin, CustomMetricsMixin, BaseModel):
7045
""" Free form text
7146
7247
>>> Text(answer = "some text answer")
@@ -75,24 +50,6 @@ class Text(ConfidenceMixin, CustomMetricsMixin, pydantic_compat.BaseModel):
7550
answer: str
7651

7752

78-
class Dropdown(_TempName):
79-
"""
80-
- A classification with many selected options allowed .
81-
- This is not currently compatible with MAL.
82-
83-
Deprecation Notice: Dropdown classification is deprecated and will be
84-
removed in a future release. Dropdown will also
85-
no longer be able to be created in the Editor on 3/31/2022.
86-
"""
87-
name: Literal["dropdown"] = "dropdown"
88-
answer: List[ClassificationAnswer]
89-
90-
def __init__(self, **data: Any):
91-
super().__init__(**data)
92-
warnings.warn("Dropdown classification is deprecated and will be "
93-
"removed in a future release")
94-
95-
9653
class ClassificationAnnotation(BaseAnnotation, ConfidenceMixin,
9754
CustomMetricsMixin):
9855
"""Classification annotations (non localized)
@@ -106,12 +63,9 @@ class ClassificationAnnotation(BaseAnnotation, ConfidenceMixin,
10663
name (Optional[str])
10764
classifications (Optional[List[ClassificationAnnotation]]): Optional sub classification of the annotation
10865
feature_schema_id (Optional[Cuid])
109-
value (Union[Text, Checklist, Radio, Dropdown])
66+
value (Union[Text, Checklist, Radio])
11067
extra (Dict[str, Any])
11168
"""
11269

113-
value: Union[Text, Checklist, Radio, Dropdown]
70+
value: Union[Text, Checklist, Radio]
11471
message_id: Optional[str] = None
115-
116-
117-
ClassificationAnswer.update_forward_refs()

libs/labelbox/src/labelbox/data/annotation_types/data/base_data.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
11
from abc import ABC
22
from typing import Optional, Dict, List, Any
33

4-
from labelbox import pydantic_compat
4+
from pydantic import BaseModel
55

66

7-
class BaseData(pydantic_compat.BaseModel, ABC):
7+
class BaseData(BaseModel, ABC):
88
"""
99
Base class for objects representing data.
1010
This class shouldn't directly be used

libs/labelbox/src/labelbox/data/annotation_types/data/conversation.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -3,5 +3,5 @@
33
from .base_data import BaseData
44

55

6-
class ConversationData(BaseData):
7-
class_name: Literal["ConversationData"] = "ConversationData"
6+
class ConversationData(BaseData, _NoCoercionMixin):
7+
class_name: Literal["ConversationData"] = "ConversationData"

libs/labelbox/src/labelbox/data/annotation_types/data/generic_data_row_data.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
11
from typing import Callable, Literal, Optional
22

3-
from labelbox import pydantic_compat
43
from labelbox.data.annotation_types.data.base_data import BaseData
54
from labelbox.utils import _NoCoercionMixin
5+
from pydantic import model_validator
66

77

88
class GenericDataRowData(BaseData, _NoCoercionMixin):
@@ -14,7 +14,8 @@ class GenericDataRowData(BaseData, _NoCoercionMixin):
1414
def create_url(self, signer: Callable[[bytes], str]) -> Optional[str]:
1515
return self.url
1616

17-
@pydantic_compat.root_validator(pre=True)
17+
@model_validator(mode="before")
18+
@classmethod
1819
def validate_one_datarow_key_present(cls, data):
1920
keys = ['external_id', 'global_key', 'uid']
2021
count = sum([key in data for key in keys])

libs/labelbox/src/labelbox/data/annotation_types/data/raster.py

+15-18
Original file line number | Diff line number | Diff line change
@@ -9,19 +9,22 @@
99
import requests
1010
import numpy as np
1111

12-
from labelbox import pydantic_compat
12+
from pydantic import BaseModel, model_validator, ConfigDict
1313
from labelbox.exceptions import InternalServerError
1414
from .base_data import BaseData
1515
from ..types import TypedArray
1616

1717

18-
class RasterData(pydantic_compat.BaseModel, ABC):
18+
class RasterData(BaseModel, ABC):
1919
"""Represents an image or segmentation mask.
2020
"""
2121
im_bytes: Optional[bytes] = None
2222
file_path: Optional[str] = None
2323
url: Optional[str] = None
24+
uid: Optional[str] = None
25+
global_key: Optional[str] = None
2426
arr: Optional[TypedArray[Literal['uint8']]] = None
27+
model_config = ConfigDict(extra="forbid", copy_on_model_validation="none")
2528

2629
@classmethod
2730
def from_2D_arr(cls, arr: Union[TypedArray[Literal['uint8']],
@@ -155,14 +158,14 @@ def create_url(self, signer: Callable[[bytes], str]) -> str:
155158
"One of url, im_bytes, file_path, arr must not be None.")
156159
return self.url
157160

158-
@pydantic_compat.root_validator()
159-
def validate_args(cls, values):
160-
file_path = values.get("file_path")
161-
im_bytes = values.get("im_bytes")
162-
url = values.get("url")
163-
arr = values.get("arr")
164-
uid = values.get('uid')
165-
global_key = values.get('global_key')
161+
@model_validator(mode="after")
162+
def validate_args(self, values):
163+
file_path = self.file_path
164+
im_bytes = self.im_bytes
165+
url = self.url
166+
arr = self.arr
167+
uid = self.uid
168+
global_key = self.global_key
166169
if uid == file_path == im_bytes == url == global_key == None and arr is None:
167170
raise ValueError(
168171
"One of `file_path`, `im_bytes`, `url`, `uid`, `global_key` or `arr` required."
@@ -175,8 +178,8 @@ def validate_args(cls, values):
175178
elif len(arr.shape) != 3:
176179
raise ValueError(
177180
"unsupported image format. Must be 3D ([H,W,C])."
178-
f"Use {cls.__name__}.from_2D_arr to construct from 2D")
179-
return values
181+
f"Use {self.__name__}.from_2D_arr to construct from 2D")
182+
return self
180183

181184
def __repr__(self) -> str:
182185
symbol_or_none = lambda data: '...' if data is not None else None
@@ -185,12 +188,6 @@ def __repr__(self) -> str:
185188
f"url={self.url}," \
186189
f"arr={symbol_or_none(self.arr)})"
187190

188-
class Config:
189-
# Required for sharing references
190-
copy_on_model_validation = 'none'
191-
# Required for discriminating between data types
192-
extra = 'forbid'
193-
194191

195192
class MaskData(RasterData):
196193
"""Used to represent a segmentation Mask

libs/labelbox/src/labelbox/data/annotation_types/data/text.py

+11-14
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
from requests.exceptions import ConnectTimeout
55
from google.api_core import retry
66

7-
from labelbox import pydantic_compat
7+
from pydantic import ConfigDict, model_validator
88
from labelbox.exceptions import InternalServerError
99
from labelbox.typing_imports import Literal
1010
from labelbox.utils import _NoCoercionMixin
@@ -26,6 +26,7 @@ class TextData(BaseData, _NoCoercionMixin):
2626
file_path: Optional[str] = None
2727
text: Optional[str] = None
2828
url: Optional[str] = None
29+
model_config = ConfigDict(extra="forbid")
2930

3031
@property
3132
def value(self) -> str:
@@ -64,7 +65,7 @@ def fetch_remote(self) -> str:
6465
"""
6566
response = requests.get(self.url)
6667
if response.status_code in [500, 502, 503, 504]:
67-
raise labelbox.exceptions.InternalServerError(response.text)
68+
raise InternalServerError(response.text)
6869
response.raise_for_status()
6970
return response.text
7071

@@ -90,24 +91,20 @@ def create_url(self, signer: Callable[[bytes], str]) -> None:
9091
"One of url, im_bytes, file_path, numpy must not be None.")
9192
return self.url
9293

93-
@pydantic_compat.root_validator
94-
def validate_date(cls, values):
95-
file_path = values.get("file_path")
96-
text = values.get("text")
97-
url = values.get("url")
98-
uid = values.get('uid')
99-
global_key = values.get('global_key')
94+
@model_validator(mode="after")
95+
def validate_date(self, values):
96+
file_path = self.file_path
97+
text = self.text
98+
url = self.url
99+
uid = self.uid
100+
global_key = self.global_key
100101
if uid == file_path == text == url == global_key == None:
101102
raise ValueError(
102103
"One of `file_path`, `text`, `uid`, `global_key` or `url` required."
103104
)
104-
return values
105+
return self
105106

106107
def __repr__(self) -> str:
107108
return f"TextData(file_path={self.file_path}," \
108109
f"text={self.text[:30] + '...' if self.text is not None else None}," \
109110
f"url={self.url})"
110-
111-
class config:
112-
# Required for discriminating between data types
113-
extra = 'forbid'

0 commit comments

Comments
 (0)