diff --git a/postgres/__init__.py b/postgres/__init__.py index 7c529bf..69b5a11 100644 --- a/postgres/__init__.py +++ b/postgres/__init__.py @@ -186,6 +186,7 @@ from inspect import isclass from postgres.context_managers import ( ConnectionContextManager, CursorContextManager, CursorSubcontextManager, + ConnectionCursorContextManager, ) from postgres.cursors import ( make_dict, make_namedtuple, return_tuple_as_is, @@ -695,7 +696,14 @@ def cursor(self, back_as=None, **kw): cursor.back_as = back_as return cursor - get_cursor = cursor + def get_cursor(self, cursor=None, **kw): + if cursor: + if cursor.connection is not self: + raise ValueError( + "the provided cursor is from a different connection" + ) + return CursorSubcontextManager(cursor, **kw) + return ConnectionCursorContextManager(self, **kw) return Connection diff --git a/postgres/context_managers.py b/postgres/context_managers.py index e92af78..9ee6a3d 100644 --- a/postgres/context_managers.py +++ b/postgres/context_managers.py @@ -6,7 +6,7 @@ class CursorContextManager(object): """Instantiated once per :func:`~postgres.Postgres.get_cursor` call. - :param pool: see :mod:`psycopg2.pool` + :param pool: see :mod:`psycopg2_pool` :param bool autocommit: see :attr:`psycopg2:connection.autocommit` :param bool readonly: see :attr:`psycopg2:connection.readonly` :param \**cursor_kwargs: passed to :meth:`psycopg2:connection.cursor` @@ -49,6 +49,48 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.pool.putconn(self.conn) +class ConnectionCursorContextManager(object): + """Creates a cursor from the given connection, then wraps it in a context + manager that automatically commits or rolls back the changes on exit. + + :param conn: a :class:`psycopg2:connection` + :param bool autocommit: see :attr:`psycopg2:connection.autocommit` + :param bool readonly: see :attr:`psycopg2:connection.readonly` + :param \**cursor_kwargs: passed to :meth:`psycopg2:connection.cursor` + + During construction, the connection's :attr:`autocommit` and :attr:`readonly` + attributes are set, then :meth:`psycopg2:connection.cursor` is called with + `cursor_kwargs`. + + Upon exit of the ``with`` block, the connection is rolled back if an + exception was raised, or committed otherwise. There are two exceptions to + this: + + 1. if :attr:`autocommit` is :obj:`True`, then the connection is neither + rolled back nor committed; + 2. if :attr:`readonly` is :obj:`True`, then the connection is always rolled + back, never committed. + + In all cases the cursor is closed. + + """ + + __slots__ = ('conn', 'cursor') + + def __init__(self, conn, autocommit=False, readonly=False, **cursor_kwargs): + conn.autocommit = autocommit + conn.readonly = readonly + self.conn = conn + self.cursor = conn.cursor(**cursor_kwargs) + + def __enter__(self): + return self.cursor + + def __exit__(self, exc_type, exc_val, exc_tb): + self.cursor.close() + self.conn.__exit__(exc_type, exc_val, exc_tb) + + class CursorSubcontextManager(object): """Wraps a cursor so that it can be used for a subtransaction. @@ -78,7 +120,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): class ConnectionContextManager(object): """Instantiated once per :func:`~postgres.Postgres.get_connection` call. - :param pool: see :mod:`psycopg2.pool` + :param pool: see :mod:`psycopg2_pool` :param bool autocommit: see :attr:`psycopg2:connection.autocommit` :param bool readonly: see :attr:`psycopg2:connection.readonly` diff --git a/tests.py b/tests.py index ee037ed..3a84362 100644 --- a/tests.py +++ b/tests.py @@ -315,11 +315,16 @@ def test_connection_has_get_cursor_method(self): with self.db.get_connection() as conn: with conn.get_cursor() as cursor: cursor.execute("DELETE FROM foo WHERE bar = 'baz'") - with conn.get_cursor(cursor_factory=SimpleDictCursor) as cursor: - cursor.execute("SELECT * FROM foo ORDER BY bar") - actual = cursor.fetchall() + with self.db.get_cursor(cursor_factory=SimpleDictCursor) as cursor: + cursor.execute("SELECT * FROM foo ORDER BY bar") + actual = cursor.fetchall() assert actual == [{"bar": "buz"}] + def test_get_cursor_method_checks_cursor_argument(self): + with self.db.get_connection() as conn, self.db.get_cursor() as cursor: + with self.assertRaises(ValueError): + conn.get_cursor(cursor=cursor) + # orm # ===