Skip to content

Commit cf4170a

Browse files
committed
Add test for declared_attr
1 parent a2db216 commit cf4170a

File tree

3 files changed

+220
-46
lines changed

3 files changed

+220
-46
lines changed

docs/customizing.rst

+41-42
Original file line numberDiff line numberDiff line change
@@ -21,29 +21,26 @@ joined-table inheritance.
2121

2222
.. code-block:: python
2323
24-
from flask_sqlalchemy.model import Model
25-
import sqlalchemy as sa
26-
import sqlalchemy.orm
24+
from sqlalchemy import Integer, String, ForeignKey
25+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, declared_attr
2726
28-
class IdModel(Model):
29-
@sa.orm.declared_attr
27+
class Base(DeclarativeBase):
28+
@declared_attr.cascading
29+
@classmethod
3030
def id(cls):
3131
for base in cls.__mro__[1:-1]:
3232
if getattr(base, "__table__", None) is not None:
33-
type = sa.ForeignKey(base.id)
34-
break
35-
else:
36-
type = sa.Integer
33+
return mapped_column(ForeignKey(base.id), primary_key=True)
34+
else:
35+
return mapped_column(Integer, primary_key=True)
3736
38-
return sa.Column(type, primary_key=True)
39-
40-
db = SQLAlchemy(model_class=IdModel)
37+
db = SQLAlchemy(app, model_class=Base)
4138
4239
class User(db.Model):
43-
name = db.Column(db.String)
40+
name: Mapped[str] = mapped_column(String)
4441
4542
class Employee(User):
46-
title = db.Column(db.String)
43+
title: Mapped[str] = mapped_column(String)
4744
4845
4946
Abstract Models and Mixins
@@ -56,31 +53,49 @@ they are created or updated.
5653
.. code-block:: python
5754
5855
from datetime import datetime
56+
from sqlalchemy import DateTime, Integer, String
57+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, declared_attr
5958
6059
class TimestampModel(db.Model):
6160
__abstract__ = True
62-
created: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, default=datetime.utcnow)
63-
updated: Mapped[datetime] = mapped_column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
61+
created: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime.utcnow)
62+
updated: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
6463
6564
class Author(db.Model):
66-
id: Mapped[int] = mapped_column(db.Integer, primary_key=True)
67-
username: Mapped[str] = mapped_column(db.String, unique=True, nullable=False)
65+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
66+
username: Mapped[str] = mapped_column(String, unique=True, nullable=False)
6867
6968
class Post(TimestampModel):
70-
id: Mapped[int] = mapped_column(db.Integer, primary_key=True)
71-
title: Mapped[str] = mapped_column(db.String, nullable=False)
69+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
70+
title: Mapped[str] = mapped_column(String, nullable=False)
7271
7372
This can also be done with a mixin class, inheriting from ``db.Model`` separately.
7473

7574
.. code-block:: python
7675
77-
class TimestampModel:
78-
created: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, default=datetime.utcnow)
79-
updated: Mapped[datetime] = mapped_column(db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
76+
class TimestampMixin:
77+
created: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime.utcnow)
78+
updated: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
79+
80+
class Post(TimestampMixin, db.Model):
81+
id: Mapped[int] = mapped_column(Integer, primary_key=True)
82+
title: Mapped[str] = mapped_column(String, nullable=False)
83+
8084
81-
class Post2(TimestampModel, db.Model):
82-
id: Mapped[int] = mapped_column(db.Integer, primary_key=True)
83-
title: Mapped[str] = mapped_column(db.String, nullable=False)
85+
Disabling Table Name Generation
86+
-------------------------------
87+
88+
Some projects prefer to set each model's ``__tablename__`` manually rather than relying
89+
on Flask-SQLAlchemy's detection and generation. The simple way to achieve that is to
90+
set each ``__tablename__`` and not modify the base class. However, the table name
91+
generation can be disabled by setting `disable_autonaming=True` in the `SQLAlchemy` constructor.
92+
93+
.. code-block:: python
94+
95+
class Base(sa_orm.DeclarativeBase):
96+
pass
97+
98+
db = SQLAlchemy(app, model_class=Base, disable_autonaming=True)
8499
85100
86101
Session Class
@@ -161,19 +176,3 @@ To customize only ``session.query``, pass the ``query_cls`` key to the
161176
.. code-block:: python
162177
163178
db = SQLAlchemy(session_options={"query_cls": GetOrQuery})
164-
165-
166-
Disabling Table Name Generation
167-
-------------------------------
168-
169-
Some projects prefer to set each model's ``__tablename__`` manually rather than relying
170-
on Flask-SQLAlchemy's detection and generation. The simple way to achieve that is to
171-
set each ``__tablename__`` and not modify the base class. However, the table name
172-
generation can be disabled by setting `disable_autonaming=True` in the `SQLAlchemy` constructor.
173-
174-
.. code-block:: python
175-
176-
class Base(sa_orm.DeclarativeBase):
177-
pass
178-
179-
db = SQLAlchemy(app, model_class=Base, disable_autonaming=True)

src/flask_sqlalchemy/extension.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -540,10 +540,8 @@ def _make_declarative_base(
540540
" Got: {}".format(model_class.__bases__)
541541
)
542542
elif len(declarative_bases) == 1:
543-
body = {
544-
"__fsa__": self,
545-
"metadata": model_class.metadata, # type: ignore[attr-defined]
546-
}
543+
body = dict(model_class.__dict__) # type: ignore[arg-type]
544+
body["__fsa__"] = self
547545
mixin_classes = [BindMixin, NameMixin, Model]
548546
if disable_autonaming:
549547
mixin_classes.remove(NameMixin)

tests/test_model.py

+177
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import typing as t
4+
from datetime import datetime
45

56
import pytest
67
import sqlalchemy as sa
@@ -80,6 +81,182 @@ class Base(sa_orm.DeclarativeBaseNoMeta, sa_orm.MappedAsDataclass):
8081
assert isinstance(db.Model, sa_orm.decl_api.DCTransformDeclarative)
8182

8283

84+
@pytest.mark.usefixtures("app_ctx")
85+
def test_declaredattr(app: Flask, model_class: t.Any) -> None:
86+
if model_class is Model:
87+
88+
class IdModel(Model):
89+
@sa.orm.declared_attr
90+
@classmethod
91+
def id(cls: type[Model]): # type: ignore[no-untyped-def]
92+
for base in cls.__mro__[1:-1]:
93+
if getattr(base, "__table__", None) is not None and hasattr(
94+
base, "id"
95+
):
96+
return sa.Column(sa.ForeignKey(base.id), primary_key=True)
97+
return sa.Column(sa.Integer, primary_key=True)
98+
99+
db = SQLAlchemy(app, model_class=IdModel)
100+
101+
class User(db.Model):
102+
name = db.Column(db.String)
103+
104+
class Employee(User):
105+
title = db.Column(db.String)
106+
107+
else:
108+
109+
class Base(sa_orm.DeclarativeBase):
110+
@sa_orm.declared_attr
111+
@classmethod
112+
def id(cls: type[sa_orm.DeclarativeBase]) -> sa_orm.Mapped[int]:
113+
for base in cls.__mro__[1:-1]:
114+
if getattr(base, "__table__", None) is not None and hasattr(
115+
base, "id"
116+
):
117+
return sa_orm.mapped_column(
118+
db.ForeignKey(base.id), primary_key=True
119+
)
120+
return sa_orm.mapped_column(db.Integer, primary_key=True)
121+
122+
db = SQLAlchemy(app, model_class=Base)
123+
124+
class User(db.Model): # type: ignore[no-redef]
125+
name: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String)
126+
127+
class Employee(User): # type: ignore[no-redef]
128+
title: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String)
129+
130+
db.create_all()
131+
db.session.add(Employee(name="Emp Loyee", title="Admin"))
132+
db.session.commit()
133+
user = db.session.execute(db.select(User)).scalar()
134+
employee = db.session.execute(db.select(Employee)).scalar()
135+
assert user is not None
136+
assert employee is not None
137+
assert user.id == 1
138+
assert employee.id == 1
139+
140+
141+
@pytest.mark.usefixtures("app_ctx")
142+
def test_abstractmodel(app: Flask, model_class: t.Any) -> None:
143+
db = SQLAlchemy(app, model_class=model_class)
144+
145+
if issubclass(db.Model, (sa_orm.MappedAsDataclass)):
146+
147+
class TimestampModel(db.Model):
148+
__abstract__ = True
149+
created: sa_orm.Mapped[datetime] = sa_orm.mapped_column(
150+
db.DateTime, nullable=False, insert_default=datetime.utcnow, init=False
151+
)
152+
updated: sa_orm.Mapped[datetime] = sa_orm.mapped_column(
153+
db.DateTime,
154+
insert_default=datetime.utcnow,
155+
onupdate=datetime.utcnow,
156+
init=False,
157+
)
158+
159+
class Post(TimestampModel):
160+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(
161+
db.Integer, primary_key=True, init=False
162+
)
163+
title: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String, nullable=False)
164+
165+
elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)):
166+
167+
class TimestampModel(db.Model): # type: ignore[no-redef]
168+
__abstract__ = True
169+
created: sa_orm.Mapped[datetime] = sa_orm.mapped_column(
170+
db.DateTime, nullable=False, default=datetime.utcnow
171+
)
172+
updated: sa_orm.Mapped[datetime] = sa_orm.mapped_column(
173+
db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
174+
)
175+
176+
class Post(TimestampModel): # type: ignore[no-redef]
177+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(db.Integer, primary_key=True)
178+
title: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String, nullable=False)
179+
180+
else:
181+
182+
class TimestampModel(db.Model): # type: ignore[no-redef]
183+
__abstract__ = True
184+
created = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)
185+
updated = db.Column(
186+
db.DateTime, onupdate=datetime.utcnow, default=datetime.utcnow
187+
)
188+
189+
class Post(TimestampModel): # type: ignore[no-redef]
190+
id = db.Column(db.Integer, primary_key=True)
191+
title = db.Column(db.String, nullable=False)
192+
193+
db.create_all()
194+
db.session.add(Post(title="Admin Post"))
195+
db.session.commit()
196+
post = db.session.execute(db.select(Post)).scalar()
197+
assert post is not None
198+
assert post.created is not None
199+
assert post.updated is not None
200+
201+
202+
@pytest.mark.usefixtures("app_ctx")
203+
def test_mixinmodel(app: Flask, model_class: t.Any) -> None:
204+
db = SQLAlchemy(app, model_class=model_class)
205+
206+
if issubclass(db.Model, (sa_orm.MappedAsDataclass)):
207+
208+
class TimestampMixin(sa_orm.MappedAsDataclass):
209+
created: sa_orm.Mapped[datetime] = sa_orm.mapped_column(
210+
db.DateTime, nullable=False, insert_default=datetime.utcnow, init=False
211+
)
212+
updated: sa_orm.Mapped[datetime] = sa_orm.mapped_column(
213+
db.DateTime,
214+
insert_default=datetime.utcnow,
215+
onupdate=datetime.utcnow,
216+
init=False,
217+
)
218+
219+
class Post(TimestampMixin, db.Model):
220+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(
221+
db.Integer, primary_key=True, init=False
222+
)
223+
title: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String, nullable=False)
224+
225+
elif issubclass(db.Model, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)):
226+
227+
class TimestampMixin: # type: ignore[no-redef]
228+
created: sa_orm.Mapped[datetime] = sa_orm.mapped_column(
229+
db.DateTime, nullable=False, default=datetime.utcnow
230+
)
231+
updated: sa_orm.Mapped[datetime] = sa_orm.mapped_column(
232+
db.DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
233+
)
234+
235+
class Post(TimestampMixin, db.Model): # type: ignore[no-redef]
236+
id: sa_orm.Mapped[int] = sa_orm.mapped_column(db.Integer, primary_key=True)
237+
title: sa_orm.Mapped[str] = sa_orm.mapped_column(db.String, nullable=False)
238+
239+
else:
240+
241+
class TimestampMixin: # type: ignore[no-redef]
242+
created = db.Column(db.DateTime, nullable=False, default=datetime.utcnow)
243+
updated = db.Column(
244+
db.DateTime, onupdate=datetime.utcnow, default=datetime.utcnow
245+
)
246+
247+
class Post(TimestampMixin, db.Model): # type: ignore[no-redef]
248+
id = db.Column(db.Integer, primary_key=True)
249+
title = db.Column(db.String, nullable=False)
250+
251+
db.create_all()
252+
db.session.add(Post(title="Admin Post"))
253+
db.session.commit()
254+
post = db.session.execute(db.select(Post)).scalar()
255+
assert post is not None
256+
assert post.created is not None
257+
assert post.updated is not None
258+
259+
83260
@pytest.mark.usefixtures("app_ctx")
84261
def test_model_repr(db: SQLAlchemy) -> None:
85262
class User(db.Model):

0 commit comments

Comments
 (0)