Skip to content

Commit 838dacc

Browse files
committed
joinloaded vs contains_eager
1 parent d8515be commit 838dacc

File tree

2 files changed

+205
-0
lines changed

2 files changed

+205
-0
lines changed

app.py

+204
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
from curses import meta
2+
from sqlalchemy import event
3+
4+
from sqlalchemy import Column, MetaData, String, create_engine, Integer, ForeignKey
5+
from sqlalchemy.orm import (
6+
declarative_base,
7+
sessionmaker,
8+
relationship,
9+
contains_eager,
10+
configure_mappers,
11+
joinedload,
12+
)
13+
14+
DB_URI = "sqlite://"
15+
16+
engine = create_engine(DB_URI)
17+
metadata = MetaData()
18+
Base = declarative_base(metadata=metadata)
19+
20+
Session = sessionmaker(engine)
21+
Session.configure(bind=engine)
22+
23+
24+
class DBStatementCounter(object):
25+
"""
26+
https://stackoverflow.com/questions/19073099/how-to-count-sqlalchemy-queries-in-unit-tests
27+
Use as a context manager to count the number of execute()'s performed
28+
against the given sqlalchemy connection.
29+
30+
Usage:
31+
with DBStatementCounter(conn) as ctr:
32+
conn.execute("SELECT 1")
33+
conn.execute("SELECT 1")
34+
assert ctr.get_count() == 2
35+
"""
36+
37+
def __init__(self, conn):
38+
self.conn = conn
39+
self.count = 0
40+
# Will have to rely on this since sqlalchemy 0.8 does not support
41+
# removing event listeners
42+
self.do_count = False
43+
event.listen(conn, "after_execute", self.callback)
44+
45+
def __enter__(self):
46+
self.do_count = True
47+
return self
48+
49+
def __exit__(self, *_):
50+
self.do_count = False
51+
52+
def get_count(self):
53+
return self.count
54+
55+
def callback(self, *_):
56+
if self.do_count:
57+
self.count += 1
58+
59+
60+
61+
class Person(Base):
62+
__tablename__ = "person"
63+
id = Column(Integer, primary_key=True, autoincrement=True)
64+
name = Column(String)
65+
user = relationship("User", uselist=False)
66+
67+
68+
class User(Base):
69+
__tablename__ = "user"
70+
id = Column(Integer, primary_key=True, autoincrement=True)
71+
person_id = Column(Integer, ForeignKey("person.id"))
72+
person = relationship("Person")
73+
my_accounts = relationship("UserAccount")
74+
75+
76+
class Company(Base):
77+
__tablename__ = "company"
78+
id = Column(Integer, primary_key=True, autoincrement=True)
79+
name = Column(String)
80+
81+
82+
class Account(Base):
83+
__tablename__ = "account"
84+
id = Column(Integer, primary_key=True, autoincrement=True)
85+
status = Column(String)
86+
company_id = Column(Integer, ForeignKey("company.id"))
87+
company = relationship("Company")
88+
89+
90+
class UserAccount(Base):
91+
__tablename__ = "user_account"
92+
account_id = Column(Integer, ForeignKey("account.id"), primary_key=True)
93+
account = relationship("Account")
94+
user_id = Column(Integer, ForeignKey("user.id"), primary_key=True)
95+
user = relationship("User")
96+
97+
98+
def populate_db(n=10):
99+
users = []
100+
with Session() as session:
101+
for i in range(n):
102+
person = Person(name=f"test{i}")
103+
session.add(person)
104+
session.flush()
105+
user = User(person=person)
106+
session.add(user)
107+
session.flush()
108+
company = Company(name=f"company{i}")
109+
session.add(company)
110+
session.flush()
111+
account = Account(status="x{i}", company=company)
112+
session.add(account)
113+
session.flush()
114+
user_account = UserAccount(user=user, account=account)
115+
session.add(user_account)
116+
session.flush()
117+
session.commit()
118+
119+
120+
configure_mappers()
121+
metadata.drop_all(engine)
122+
metadata.create_all(engine)
123+
populate_db()
124+
125+
126+
def get_query(session, options=None):
127+
if not options:
128+
options = []
129+
return (
130+
session.query(Person)
131+
.join("user", "my_accounts", "account", "company")
132+
.options(*options)
133+
.filter(
134+
Person.name.ilike("test%"),
135+
Account.status.ilike("x%"),
136+
Company.name.ilike("company%"),
137+
)
138+
)
139+
140+
141+
def simple_query():
142+
print("--------")
143+
print("simple")
144+
with Session() as session:
145+
with DBStatementCounter(session.connection()) as ctr:
146+
result = get_query(session)
147+
print("Query")
148+
print(result)
149+
person = result.first()
150+
if person:
151+
a = person.user.my_accounts[0].account.company.name
152+
print("Statements")
153+
print(ctr.count)
154+
155+
156+
def joinedload_query():
157+
print("--------")
158+
print("joinedload")
159+
with Session() as session:
160+
with DBStatementCounter(session.connection()) as ctr:
161+
result = get_query(
162+
session,
163+
options=[
164+
joinedload("user"),
165+
joinedload("user", "my_accounts"),
166+
joinedload("user", "my_accounts", "account"),
167+
joinedload("user", "my_accounts", "account", "company"),
168+
],
169+
)
170+
print("Query")
171+
print(result)
172+
person = result.first()
173+
if person:
174+
a = person.user.my_accounts[0].account.company.name
175+
print("Statements")
176+
print(ctr.count)
177+
178+
179+
def eager_query():
180+
print("--------")
181+
print("joinedload")
182+
with Session() as session:
183+
with DBStatementCounter(session.connection()) as ctr:
184+
result = get_query(
185+
session,
186+
options=[
187+
contains_eager("user"),
188+
contains_eager("user", "my_accounts"),
189+
contains_eager("user", "my_accounts", "account"),
190+
contains_eager("user", "my_accounts", "account", "company"),
191+
],
192+
)
193+
print("Query")
194+
print(result)
195+
person = result.first()
196+
if person:
197+
a = person.user.my_accounts[0].account.company.name
198+
print("Statements")
199+
print(ctr.count)
200+
201+
202+
simple_query()
203+
joinedload_query()
204+
eager_query()

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
sqlalchemy==1.4.36

0 commit comments

Comments
 (0)