Skip to content

Commit 32051cb

Browse files
BadSingleton and filmor authored
Expose serialization api (pythonnet#2336)
* Expose an API for users to specify their own formatter Adds post-serialization and pre-deserialization hooks for additional customization. * Add API for capsuling data when serializing * Add NoopFormatter and fall back to it if BinaryFormatter is not available --------- Co-authored-by: Benedikt Reinartz <[email protected]>
1 parent 195cde6 commit 32051cb

File tree

5 files changed

+269
-6
lines changed

5 files changed

+269
-6
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ This document follows the conventions laid out in [Keep a CHANGELOG][].
1515
to compare with primitive .NET types like `long`.
1616

1717
### Changed
18+
- Added a `FormatterFactory` member in RuntimeData to create formatters with parameters. For compatibility, the `FormatterType` member is still present and takes precedence when both `FormatterFactory` and `FormatterType` are defined.
19+
- Added post-serialization and pre-deserialization callbacks to extend the (de)serialization process
20+
- Added an API to stash serialized data on Python capsules
1821

1922
### Fixed
2023

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.IO;
3+
using System.Runtime.Serialization;
4+
5+
namespace Python.Runtime;
6+
7+
/// <summary>
/// An <see cref="IFormatter"/> that writes nothing when asked to serialize and
/// refuses to deserialize. The three configuration properties are only stored.
/// </summary>
public class NoopFormatter : IFormatter
{
    /// <summary>Binder slot required by <see cref="IFormatter"/>; stored but not consulted here.</summary>
    public SerializationBinder? Binder { get; set; }

    /// <summary>Streaming context slot required by <see cref="IFormatter"/>; stored but not consulted here.</summary>
    public StreamingContext Context { get; set; }

    /// <summary>Surrogate selector slot required by <see cref="IFormatter"/>; stored but not consulted here.</summary>
    public ISurrogateSelector? SurrogateSelector { get; set; }

    /// <summary>Does nothing: the stream is left untouched.</summary>
    /// <param name="s">Target stream (ignored).</param>
    /// <param name="o">Object graph to serialize (ignored).</param>
    public void Serialize(Stream s, object o)
    {
        // Intentionally empty.
    }

    /// <summary>Always fails, since nothing is ever written by <see cref="Serialize"/>.</summary>
    /// <param name="s">Source stream (ignored).</param>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public object Deserialize(Stream s)
    {
        throw new NotImplementedException();
    }
}

src/runtime/StateSerialization/RuntimeData.cs

+132-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
using System;
2-
using System.Collections;
32
using System.Collections.Generic;
4-
using System.Collections.ObjectModel;
53
using System.Diagnostics;
64
using System.IO;
75
using System.Linq;
@@ -17,7 +15,34 @@ namespace Python.Runtime
1715
{
1816
public static class RuntimeData
1917
{
20-
private static Type? _formatterType;
18+
19+
/// <summary>
/// The factory used when no custom one has been supplied: it yields a
/// <see cref="BinaryFormatter"/>, or a <see cref="NoopFormatter"/> if constructing
/// the binary formatter throws.
/// </summary>
public static readonly Func<IFormatter> DefaultFormatterFactory = CreateDefaultFormatter;

// Builds the default formatter; any failure while creating BinaryFormatter
// is deliberately absorbed and answered with the no-op fallback.
private static IFormatter CreateDefaultFormatter()
{
    try
    {
        return new BinaryFormatter();
    }
    catch
    {
        return new NoopFormatter();
    }
}
30+
31+
// Backing store for FormatterFactory; starts out as DefaultFormatterFactory.
private static Func<IFormatter> _formatterFactory = DefaultFormatterFactory;

/// <summary>
/// Factory that CreateFormatter calls to obtain a formatter when no
/// FormatterType has been set.
/// </summary>
/// <exception cref="ArgumentNullException">Thrown when the property is set to <c>null</c>.</exception>
public static Func<IFormatter> FormatterFactory
{
    get { return _formatterFactory; }
    set { _formatterFactory = value ?? throw new ArgumentNullException(nameof(value)); }
}
44+
45+
private static Type? _formatterType = null;
2146
public static Type? FormatterType
2247
{
2348
get => _formatterType;
@@ -31,6 +56,14 @@ public static Type? FormatterType
3156
}
3257
}
3358

59+
/// <summary>
/// Optional callback invoked as the final step of serialization, after the
/// runtime data has been stashed. Nothing is called while it is <c>null</c>.
/// </summary>
public static Action? PostStashHook { get; set; }

/// <summary>
/// Optional callback invoked as the first step of deserialization, before the
/// stashed runtime data is read back. Nothing is called while it is <c>null</c>.
/// </summary>
public static Action? PreRestoreHook { get; set; }
3467
public static ICLRObjectStorer? WrappersStorer { get; set; }
3568

3669
/// <summary>
@@ -74,6 +107,7 @@ internal static void Stash()
74107
using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero);
75108
int res = PySys_SetObject("clr_data", capsule.BorrowOrThrow());
76109
PythonException.ThrowIfIsNotZero(res);
110+
PostStashHook?.Invoke();
77111
}
78112

79113
internal static void RestoreRuntimeData()
@@ -90,6 +124,7 @@ internal static void RestoreRuntimeData()
90124

91125
private static void RestoreRuntimeDataImpl()
92126
{
127+
PreRestoreHook?.Invoke();
93128
BorrowedReference capsule = PySys_GetObject("clr_data");
94129
if (capsule.IsNull)
95130
{
@@ -250,11 +285,102 @@ private static void RestoreRuntimeDataObjects(SharedObjectsState storage)
250285
}
251286
}
252287

288+
static readonly string serialization_key_namepsace = "pythonnet_serialization_";
289+
/// <summary>
/// Removes the serialization capsule named <paramref name="key"/> from the `sys`
/// module object and releases the unmanaged buffer the capsule points to.
/// </summary>
/// <remarks>
/// The serialization data must have been set with <code>StashSerializationData</code>.
/// Nothing happens when `sys` has no entry named <paramref name="key"/>.
/// </remarks>
/// <param name="key">The name given to the capsule on the `sys` module object</param>
public static void FreeSerializationData(string key)
{
    // Use the key exactly as StashSerializationData stores it. Prefixing it here
    // made the lookup miss every stashed capsule, so the buffer was never freed.
    BorrowedReference oldCapsule = PySys_GetObject(key);
    if (oldCapsule.IsNull)
    {
        return;
    }

    IntPtr oldData = PyCapsule_GetPointer(oldCapsule, IntPtr.Zero);
    Marshal.FreeHGlobal(oldData);
    // Reset the capsule's pointer so it no longer refers to the freed buffer.
    PyCapsule_SetPointer(oldCapsule, IntPtr.Zero);
    // Passing null asks Python to delete the name from `sys`.
    PySys_SetObject(key, null);
}
308+
309+
/// <summary>
/// Stores the data in the <paramref name="stream"/> argument in a Python capsule and stores
/// the capsule on the `sys` module object with the name <paramref name="key"/>.
/// </summary>
/// <remarks>
/// No checks on pre-existing names on the `sys` module object are made.
/// The unmanaged buffer holds the payload length (one <see cref="IntPtr"/>) followed by
/// the payload bytes. If creating or publishing the capsule fails, the buffer is released
/// and the error is propagated to the caller.
/// </remarks>
/// <param name="key">The name given to the capsule on the `sys` module object</param>
/// <param name="stream">A MemoryStream that contains the data to be placed in the capsule</param>
/// <exception cref="NotImplementedException">
/// Thrown when the buffer of <paramref name="stream"/> cannot be exposed.
/// </exception>
public static void StashSerializationData(string key, MemoryStream stream)
{
    if (!stream.TryGetBuffer(out var data))
    {
        throw new NotImplementedException($"{nameof(stream)} must be exposable");
    }

    IntPtr mem = Marshal.AllocHGlobal(IntPtr.Size + data.Count);
    try
    {
        // store the length of the buffer first
        Marshal.WriteIntPtr(mem, (IntPtr)data.Count);
        Marshal.Copy(data.Array, data.Offset, mem + IntPtr.Size, data.Count);

        using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero);
        int res = PySys_SetObject(key, capsule.BorrowOrThrow());
        PythonException.ThrowIfIsNotZero(res);
    }
    catch
    {
        // Free the buffer on failure, then rethrow so the caller learns that
        // nothing was stashed instead of silently continuing.
        Marshal.FreeHGlobal(mem);
        throw;
    }
}
345+
346+
static byte[] emptyBuffer = new byte[0];
347+
/// <summary>
/// Retrieves the data previously stored in a Python capsule on the `sys` module object.
/// Throws if the object found under <paramref name="key"/> is not a capsule.
/// </summary>
/// <param name="key">The name given to the capsule on the `sys` module object</param>
/// <returns>A read-only MemoryStream holding a copy of the saved serialization data.
/// The stream is empty if no name matches the key. </returns>
public static MemoryStream GetSerializationData(string key)
{
    BorrowedReference stored = PySys_GetObject(key);
    if (!stored.IsNull)
    {
        IntPtr block = PyCapsule_GetPointer(stored, IntPtr.Zero);
        if (block == IntPtr.Zero)
        {
            // The PyCapsule API returns NULL on error; NULL cannot be stored
            // as a capsule's value
            PythonException.ThrowIfIsNull(null);
        }

        // The buffer starts with the payload length, followed by the payload itself.
        int payloadLength = (int)Marshal.ReadIntPtr(block);
        var payload = new byte[payloadLength];
        Marshal.Copy(block + IntPtr.Size, payload, 0, payloadLength);
        return new MemoryStream(payload, writable: false);
    }

    // Nothing stored under this name.
    return new MemoryStream(emptyBuffer, writable: false);
}
375+
253376
/// <summary>
/// Creates the formatter to use: an instance of FormatterType when one is set,
/// otherwise whatever FormatterFactory produces.
/// </summary>
internal static IFormatter CreateFormatter()
{
    // An explicitly configured FormatterType takes precedence over the factory.
    return FormatterType is null
        ? FormatterFactory()
        : (IFormatter)Activator.CreateInstance(FormatterType);
}
259385
}
260386
}

tests/domain_tests/TestRunner.cs

+117
Original file line numberDiff line numberDiff line change
@@ -1132,6 +1132,66 @@ import System
11321132
11331133
",
11341134
},
1135+
new TestCase
1136+
{
1137+
Name = "test_serialize_unserializable_object",
1138+
DotNetBefore = @"
1139+
namespace TestNamespace
1140+
{
1141+
public class NotSerializableTextWriter : System.IO.TextWriter
1142+
{
1143+
override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} }
1144+
}
1145+
[System.Serializable]
1146+
public static class SerializableWriter
1147+
{
1148+
private static System.IO.TextWriter _writer = null;
1149+
public static System.IO.TextWriter Writer {get { return _writer; }}
1150+
public static void CreateInternalWriter()
1151+
{
1152+
_writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter());
1153+
}
1154+
}
1155+
}
1156+
",
1157+
DotNetAfter = @"
1158+
namespace TestNamespace
1159+
{
1160+
public class NotSerializableTextWriter : System.IO.TextWriter
1161+
{
1162+
override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} }
1163+
}
1164+
[System.Serializable]
1165+
public static class SerializableWriter
1166+
{
1167+
private static System.IO.TextWriter _writer = null;
1168+
public static System.IO.TextWriter Writer {get { return _writer; }}
1169+
public static void CreateInternalWriter()
1170+
{
1171+
_writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter());
1172+
}
1173+
}
1174+
}
1175+
",
1176+
PythonCode = @"
1177+
import sys
1178+
1179+
def before_reload():
1180+
import clr
1181+
import System
1182+
clr.AddReference('DomainTests')
1183+
import TestNamespace
1184+
TestNamespace.SerializableWriter.CreateInternalWriter();
1185+
sys.__obj = TestNamespace.SerializableWriter.Writer
1186+
sys.__obj.WriteLine('test')
1187+
1188+
def after_reload():
1189+
import clr
1190+
import System
1191+
sys.__obj.WriteLine('test')
1192+
1193+
",
1194+
}
11351195
};
11361196

11371197
/// <summary>
@@ -1142,7 +1202,59 @@ import System
11421202
const string CaseRunnerTemplate = @"
11431203
using System;
11441204
using System.IO;
1205+
using System.Runtime.Serialization;
1206+
using System.Runtime.Serialization.Formatters.Binary;
11451207
using Python.Runtime;
1208+
1209+
namespace Serialization
1210+
{{
1211+
// Classes in this namespace is mostly useful for test_serialize_unserializable_object
1212+
class NotSerializableSerializer : ISerializationSurrogate
1213+
{{
1214+
public NotSerializableSerializer()
1215+
{{
1216+
}}
1217+
public void GetObjectData(object obj, SerializationInfo info, StreamingContext context)
1218+
{{
1219+
info.AddValue(""notSerialized_tp"", obj.GetType());
1220+
}}
1221+
public object SetObjectData(object obj, SerializationInfo info, StreamingContext context, ISurrogateSelector selector)
1222+
{{
1223+
if (info == null)
1224+
{{
1225+
return null;
1226+
}}
1227+
Type typeObj = info.GetValue(""notSerialized_tp"", typeof(Type)) as Type;
1228+
if (typeObj == null)
1229+
{{
1230+
return null;
1231+
}}
1232+
1233+
obj = Activator.CreateInstance(typeObj);
1234+
return obj;
1235+
}}
1236+
}}
1237+
class NonSerializableSelector : SurrogateSelector
1238+
{{
1239+
public override ISerializationSurrogate GetSurrogate(Type type, StreamingContext context, out ISurrogateSelector selector)
1240+
{{
1241+
if (type == null)
1242+
{{
1243+
throw new ArgumentNullException();
1244+
}}
1245+
selector = (ISurrogateSelector)this;
1246+
if (type.IsSerializable)
1247+
{{
1248+
return null; // use whichever default
1249+
}}
1250+
else
1251+
{{
1252+
return (ISerializationSurrogate)(new NotSerializableSerializer());
1253+
}}
1254+
}}
1255+
}}
1256+
}}
1257+
11461258
namespace CaseRunner
11471259
{{
11481260
class CaseRunner
@@ -1151,6 +1263,11 @@ public static int Main()
11511263
{{
11521264
try
11531265
{{
1266+
RuntimeData.FormatterFactory = () =>
1267+
{{
1268+
return new BinaryFormatter(){{SurrogateSelector = new Serialization.NonSerializableSelector()}};
1269+
}};
1270+
11541271
PythonEngine.Initialize();
11551272
using (Py.GIL())
11561273
{{

tests/domain_tests/test_domain_reload.py

+3
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,6 @@ def test_nested_type():
8888

8989
def test_import_after_reload():
    """Run the ``import_after_reload`` domain-reload case."""
    _run_test("import_after_reload")
91+
92+
def test_serialize_unserializable_object():
    """Run the ``test_serialize_unserializable_object`` domain-reload case."""
    # This function needs its own name: reusing ``test_import_after_reload``
    # rebinds that module attribute, so the earlier test would never be collected.
    _run_test("test_serialize_unserializable_object")

0 commit comments

Comments
 (0)