Skip to content

Fix the new Connection.get_cursor method #82

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@
from inspect import isclass
from postgres.context_managers import (
ConnectionContextManager, CursorContextManager, CursorSubcontextManager,
ConnectionCursorContextManager,
)
from postgres.cursors import (
make_dict, make_namedtuple, return_tuple_as_is,
Expand Down Expand Up @@ -695,7 +696,14 @@ def cursor(self, back_as=None, **kw):
cursor.back_as = back_as
return cursor

get_cursor = cursor
def get_cursor(self, cursor=None, **kw):
if cursor:
if cursor.connection is not self:
raise ValueError(
"the provided cursor is from a different connection"
)
return CursorSubcontextManager(cursor, **kw)
return ConnectionCursorContextManager(self, **kw)

return Connection

Expand Down
46 changes: 44 additions & 2 deletions postgres/context_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
class CursorContextManager(object):
"""Instantiated once per :func:`~postgres.Postgres.get_cursor` call.

:param pool: see :mod:`psycopg2.pool`
:param pool: see :mod:`psycopg2_pool`
:param bool autocommit: see :attr:`psycopg2:connection.autocommit`
:param bool readonly: see :attr:`psycopg2:connection.readonly`
:param \**cursor_kwargs: passed to :meth:`psycopg2:connection.cursor`
Expand Down Expand Up @@ -49,6 +49,48 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.pool.putconn(self.conn)


class ConnectionCursorContextManager(object):
"""Creates a cursor from the given connection, then wraps it in a context
manager that automatically commits or rolls back the changes on exit.

:param conn: a :class:`psycopg2:connection`
:param bool autocommit: see :attr:`psycopg2:connection.autocommit`
:param bool readonly: see :attr:`psycopg2:connection.readonly`
:param \**cursor_kwargs: passed to :meth:`psycopg2:connection.cursor`

During construction, the connection's :attr:`autocommit` and :attr:`readonly`
attributes are set, then :meth:`psycopg2:connection.cursor` is called with
`cursor_kwargs`.

Upon exit of the ``with`` block, the connection is rolled back if an
exception was raised, or committed otherwise. There are two exceptions to
this:

1. if :attr:`autocommit` is :obj:`True`, then the connection is neither
rolled back nor committed;
2. if :attr:`readonly` is :obj:`True`, then the connection is always rolled
back, never committed.

In all cases the cursor is closed.

"""

__slots__ = ('conn', 'cursor')

def __init__(self, conn, autocommit=False, readonly=False, **cursor_kwargs):
conn.autocommit = autocommit
conn.readonly = readonly
self.conn = conn
self.cursor = conn.cursor(**cursor_kwargs)

def __enter__(self):
return self.cursor

def __exit__(self, exc_type, exc_val, exc_tb):
self.cursor.close()
self.conn.__exit__(exc_type, exc_val, exc_tb)


class CursorSubcontextManager(object):
"""Wraps a cursor so that it can be used for a subtransaction.

Expand Down Expand Up @@ -78,7 +120,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
class ConnectionContextManager(object):
"""Instantiated once per :func:`~postgres.Postgres.get_connection` call.

:param pool: see :mod:`psycopg2.pool`
:param pool: see :mod:`psycopg2_pool`
:param bool autocommit: see :attr:`psycopg2:connection.autocommit`
:param bool readonly: see :attr:`psycopg2:connection.readonly`

Expand Down
11 changes: 8 additions & 3 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,16 @@ def test_connection_has_get_cursor_method(self):
with self.db.get_connection() as conn:
with conn.get_cursor() as cursor:
cursor.execute("DELETE FROM foo WHERE bar = 'baz'")
with conn.get_cursor(cursor_factory=SimpleDictCursor) as cursor:
cursor.execute("SELECT * FROM foo ORDER BY bar")
actual = cursor.fetchall()
with self.db.get_cursor(cursor_factory=SimpleDictCursor) as cursor:
cursor.execute("SELECT * FROM foo ORDER BY bar")
actual = cursor.fetchall()
assert actual == [{"bar": "buz"}]

def test_get_cursor_method_checks_cursor_argument(self):
with self.db.get_connection() as conn, self.db.get_cursor() as cursor:
with self.assertRaises(ValueError):
conn.get_cursor(cursor=cursor)


# orm
# ===
Expand Down