diff --git a/streamable/stream.py b/streamable/stream.py index baf634ab..f36f58b7 100644 --- a/streamable/stream.py +++ b/streamable/stream.py @@ -1,3 +1,4 @@ +import copy import datetime import logging from contextlib import suppress @@ -7,6 +8,7 @@ Callable, Collection, Coroutine, + Dict, Generic, Iterable, Iterator, @@ -609,6 +611,11 @@ class DownStream(Stream[U], Generic[T, U]): def __init__(self, upstream: Stream[T]) -> None: self._upstream: Stream[T] = upstream + def __deepcopy__(self, memo: Dict[int, Any]) -> "DownStream[T, U]": + new = copy.copy(self) + new._upstream = copy.deepcopy(self._upstream, memo) + return new + @property def source(self) -> Union[Iterable, Callable[[], Iterable]]: return self._upstream.source diff --git a/tests/test_stream.py b/tests/test_stream.py index 35f47c59..2b6d23d5 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -1,4 +1,5 @@ import asyncio +import copy import datetime import logging import math @@ -1866,3 +1867,22 @@ def test_on_queue_in_thread(self) -> None: ["foo", "bar"], msg="stream must work on Queue", ) + + def test_deepcopy(self) -> None: + stream = Stream([]).map(str) + stream_copy = copy.deepcopy(stream) + self.assertEqual( + stream, + stream_copy, + msg="the copy must be equal", + ) + self.assertIsNot( + stream, + stream_copy, + msg="the copy must be a different object", + ) + self.assertIsNot( + stream.source, + stream_copy.source, + msg="the copy's source must be a different object", + )