Skip to content

Commit a1419bd

Browse files
authored
Merge pull request #403 from isuruf/leak
Add workaround for symbol class leak
2 parents fd7b471 + 1830662 commit a1419bd

File tree

4 files changed

+75
-30
lines changed

4 files changed

+75
-30
lines changed

symengine/lib/pywrapper.cpp

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -281,47 +281,61 @@ inline PyObject* get_pickle_module() {
281281
return module;
282282
}
283283

284+
PyObject* pickle_loads(const std::string &pickle_str) {
285+
PyObject *module = get_pickle_module();
286+
PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size());
287+
PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes);
288+
Py_XDECREF(pickle_bytes);
289+
if (obj == NULL) {
290+
throw SerializationError("error when loading pickled symbol subclass object");
291+
}
292+
return obj;
293+
}
294+
284295
RCP<const Basic> load_basic(cereal::PortableBinaryInputArchive &ar, RCP<const Symbol> &)
285296
{
286297
bool is_pysymbol;
298+
bool store_pickle;
287299
std::string name;
288300
ar(is_pysymbol);
289301
ar(name);
290302
if (is_pysymbol) {
291303
std::string pickle_str;
292304
ar(pickle_str);
293-
PyObject *module = get_pickle_module();
294-
PyObject *pickle_bytes = PyBytes_FromStringAndSize(pickle_str.data(), pickle_str.size());
295-
PyObject *obj = PyObject_CallMethod(module, "loads", "O", pickle_bytes);
296-
if (obj == NULL) {
297-
throw SerializationError("error when loading pickled symbol subclass object");
298-
}
299-
RCP<const Basic> result = make_rcp<PySymbol>(name, obj);
300-
Py_XDECREF(pickle_bytes);
305+
ar(store_pickle);
306+
PyObject *obj = pickle_loads(pickle_str);
307+
RCP<const Basic> result = make_rcp<PySymbol>(name, obj, store_pickle);
308+
Py_XDECREF(obj);
301309
return result;
302310
} else {
303311
return symbol(name);
304312
}
305313
}
306314

315+
std::string pickle_dumps(const PyObject * obj) {
316+
PyObject *module = get_pickle_module();
317+
PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", obj);
318+
if (pickle_bytes == NULL) {
319+
throw SerializationError("error when pickling symbol subclass object");
320+
}
321+
Py_ssize_t size;
322+
char* buffer;
323+
PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size);
324+
return std::string(buffer, size);
325+
}
326+
307327
void save_basic(cereal::PortableBinaryOutputArchive &ar, const Symbol &b)
308328
{
309329
bool is_pysymbol = is_a_sub<PySymbol>(b);
310330
ar(is_pysymbol);
311331
ar(b.__str__());
312332
if (is_pysymbol) {
313333
RCP<const PySymbol> p = rcp_static_cast<const PySymbol>(b.rcp_from_this());
314-
PyObject *module = get_pickle_module();
315-
PyObject *pickle_bytes = PyObject_CallMethod(module, "dumps", "O", p->get_py_object());
316-
if (pickle_bytes == NULL) {
317-
throw SerializationError("error when pickling symbol subclass object");
318-
}
319-
Py_ssize_t size;
320-
char* buffer;
321-
PyBytes_AsStringAndSize(pickle_bytes, &buffer, &size);
322-
std::string pickle_str(buffer, size);
334+
PyObject *obj = p->get_py_object();
335+
std::string pickle_str = pickle_dumps(obj);
323336
ar(pickle_str);
324-
Py_XDECREF(pickle_bytes);
337+
ar(p->store_pickle);
338+
Py_XDECREF(obj);
325339
}
326340
}
327341

symengine/lib/pywrapper.h

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
namespace SymEngine {
1010

11+
std::string pickle_dumps(const PyObject *);
12+
PyObject* pickle_loads(const std::string &);
13+
1114
/*
1215
* PySymbol is a subclass of Symbol that keeps a reference to a Python object.
1316
* When subclassing a Symbol from Python, the information stored in subclassed
@@ -27,16 +30,30 @@ namespace SymEngine {
2730
class PySymbol : public Symbol {
2831
private:
2932
PyObject* obj;
33+
std::string bytes;
3034
public:
31-
PySymbol(const std::string& name, PyObject* obj) : Symbol(name), obj(obj) {
32-
Py_INCREF(obj);
35+
const bool store_pickle;
36+
PySymbol(const std::string& name, PyObject* obj, bool store_pickle) :
37+
Symbol(name), obj(obj), store_pickle(store_pickle) {
38+
if (store_pickle) {
39+
bytes = pickle_dumps(obj);
40+
} else {
41+
Py_INCREF(obj);
42+
}
3343
}
3444
PyObject* get_py_object() const {
35-
return obj;
45+
if (store_pickle) {
46+
return pickle_loads(bytes);
47+
} else {
48+
Py_INCREF(obj);
49+
return obj;
50+
}
3651
}
3752
virtual ~PySymbol() {
38-
// TODO: This is never called because of the cyclic reference.
39-
Py_DECREF(obj);
53+
if (not store_pickle) {
54+
// TODO: This is never called because of the cyclic reference.
55+
Py_DECREF(obj);
56+
}
4057
}
4158
};
4259

symengine/lib/symengine.pxd

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
197197
bool neq(const Basic &a, const Basic &b) nogil except +
198198

199199
RCP[const Symbol] rcp_static_cast_Symbol "SymEngine::rcp_static_cast<const SymEngine::Symbol>"(rcp_const_basic &b) nogil
200-
RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast<const SymEngine::PySymbol>"(rcp_const_basic &b) nogil
200+
RCP[const PySymbol] rcp_static_cast_PySymbol "SymEngine::rcp_static_cast<const SymEngine::PySymbol>"(rcp_const_basic &b) nogil except +
201201
RCP[const Integer] rcp_static_cast_Integer "SymEngine::rcp_static_cast<const SymEngine::Integer>"(rcp_const_basic &b) nogil
202202
RCP[const Rational] rcp_static_cast_Rational "SymEngine::rcp_static_cast<const SymEngine::Rational>"(rcp_const_basic &b) nogil
203203
RCP[const Complex] rcp_static_cast_Complex "SymEngine::rcp_static_cast<const SymEngine::Complex>"(rcp_const_basic &b) nogil
@@ -369,8 +369,8 @@ cdef extern from "pywrapper.h" namespace "SymEngine":
369369

370370
cdef extern from "pywrapper.h" namespace "SymEngine":
371371
cdef cppclass PySymbol(Symbol):
372-
PySymbol(string name, PyObject* pyobj)
373-
PyObject* get_py_object()
372+
PySymbol(string name, PyObject* pyobj, bool use_pickle) except +
373+
PyObject* get_py_object() except +
374374

375375
string wrapper_dumps(const Basic &x) nogil except +
376376
rcp_const_basic wrapper_loads(const string &s) nogil except +
@@ -479,7 +479,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
479479
rcp_const_basic make_rcp_Symbol "SymEngine::make_rcp<const SymEngine::Symbol>"(string name) nogil
480480
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"() nogil
481481
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"(string name) nogil
482-
rcp_const_basic make_rcp_PySymbol "SymEngine::make_rcp<const SymEngine::PySymbol>"(string name, PyObject * pyobj) nogil
482+
rcp_const_basic make_rcp_PySymbol "SymEngine::make_rcp<const SymEngine::PySymbol>"(string name, PyObject * pyobj, bool use_pickle) except +
483483
rcp_const_basic make_rcp_Constant "SymEngine::make_rcp<const SymEngine::Constant>"(string name) nogil
484484
rcp_const_basic make_rcp_Infty "SymEngine::make_rcp<const SymEngine::Infty>"(RCP[const Number] i) nogil
485485
rcp_const_basic make_rcp_NaN "SymEngine::make_rcp<const SymEngine::NaN>"() nogil

symengine/lib/symengine_wrapper.pyx

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ cpdef void assign_to_capsule(object capsule, object value):
4646

4747
cdef object c2py(rcp_const_basic o):
4848
cdef Basic r
49+
cdef PyObject *obj
4950
if (symengine.is_a_Add(deref(o))):
5051
r = Expr.__new__(Add)
5152
elif (symengine.is_a_Mul(deref(o))):
@@ -74,7 +75,10 @@ cdef object c2py(rcp_const_basic o):
7475
r = Dummy.__new__(Dummy)
7576
elif (symengine.is_a_Symbol(deref(o))):
7677
if (symengine.is_a_PySymbol(deref(o))):
77-
return <object>(deref(symengine.rcp_static_cast_PySymbol(o)).get_py_object())
78+
obj = deref(symengine.rcp_static_cast_PySymbol(o)).get_py_object()
79+
result = <object>(obj)
80+
Py_XDECREF(obj);
81+
return result
7882
r = Symbol.__new__(Symbol)
7983
elif (symengine.is_a_Constant(deref(o))):
8084
r = S.Pi
@@ -1216,16 +1220,26 @@ cdef class Expr(Basic):
12161220

12171221

12181222
cdef class Symbol(Expr):
1219-
12201223
"""
12211224
Symbol is a class to store a symbolic variable with a given name.
1225+
Subclassing Symbol leads to a memory leak due to a cycle in reference counting.
1226+
To avoid this with a performance penalty, set the kwarg store_pickle=True
1227+
in the constructor and support the pickle protocol in the subclass by
1228+
implmenting __reduce__.
12221229
"""
12231230

12241231
def __init__(Basic self, name, *args, **kwargs):
1232+
cdef cppbool store_pickle;
12251233
if type(self) == Symbol:
12261234
self.thisptr = symengine.make_rcp_Symbol(name.encode("utf-8"))
12271235
else:
1228-
self.thisptr = symengine.make_rcp_PySymbol(name.encode("utf-8"), <PyObject*>self)
1236+
store_pickle = kwargs.pop("store_pickle", False)
1237+
if store_pickle:
1238+
# First set the pointer to a regular symbol so that when pickle.dumps
1239+
# is called when the PySymbol is created, methods like name works.
1240+
self.thisptr = symengine.make_rcp_Symbol(name.encode("utf-8"))
1241+
self.thisptr = symengine.make_rcp_PySymbol(name.encode("utf-8"), <PyObject*>self,
1242+
store_pickle)
12291243

12301244
def _sympy_(self):
12311245
import sympy

0 commit comments

Comments
 (0)