diff --git a/src/surrealdb/data/cbor.py b/src/surrealdb/data/cbor.py index 55fa7e40..72e605a9 100644 --- a/src/surrealdb/data/cbor.py +++ b/src/surrealdb/data/cbor.py @@ -25,9 +25,6 @@ def default_encoder(encoder, obj): if isinstance(obj, GeometryPoint): tagged = CBORTag(constants.TAG_GEOMETRY_POINT, obj.get_coordinates()) - if obj is None: - tagged = CBORTag(constants.TAG_NONE, None) - elif isinstance(obj, GeometryLine): tagged = CBORTag(constants.TAG_GEOMETRY_LINE, obj.get_coordinates()) @@ -65,7 +62,7 @@ def default_encoder(encoder, obj): tagged = CBORTag(constants.TAG_BOUND_EXCLUDED, obj.value) elif isinstance(obj, Duration): - tagged = CBORTag(constants.TAG_DURATION, obj.get_seconds_and_nano()) + tagged = CBORTag(constants.TAG_DURATION, obj.to_string()) elif isinstance(obj, IsoDateTimeWrapper): tagged = CBORTag(constants.TAG_DATETIME, obj.dt) @@ -116,6 +113,8 @@ def tag_decoder(decoder, tag, shareable_index=None): return Range(tag.value[0], tag.value[1]) elif tag.tag == constants.TAG_DURATION_COMPACT: + if len(tag.value) == 1: + return Duration.parse(tag.value[0]) return Duration.parse(tag.value[0], tag.value[1]) # Two numbers (s, ns) elif tag.tag == constants.TAG_DURATION: diff --git a/src/surrealdb/data/types/duration.py b/src/surrealdb/data/types/duration.py index 11fff8c5..139a25e4 100644 --- a/src/surrealdb/data/types/duration.py +++ b/src/surrealdb/data/types/duration.py @@ -24,7 +24,12 @@ def parse(value: Union[str, int], nanoseconds: int = 0) -> "Duration": return Duration(nanoseconds + value * UNITS["s"]) elif isinstance(value, str): unit = value[-1] - num = int(value[:-1]) + if unit == "s": + unit = value[-2] + unit + num = int(value[:-2]) + else: + num = int(value[:-1]) + if unit in UNITS: return Duration(num * UNITS[unit]) else: @@ -75,11 +80,12 @@ def weeks(self) -> int: return self.elapsed // UNITS["w"] def to_string(self) -> str: - for unit in reversed(["w", "d", "h", "m", "s", "ms", "us", "ns"]): - value = self.elapsed // UNITS[unit] - if value > 0: + for unit in ["w", "d", "h", "m", "s", "ms", "us"]: + if self.elapsed % UNITS[unit] == 0: + value = self.elapsed // UNITS[unit] return f"{value}{unit}" - return "0ns" + + return f"{self.elapsed}ns" def to_compact(self) -> list: return [self.elapsed // UNITS["s"]] diff --git a/tests/unit_tests/data_types/test_duration.py b/tests/unit_tests/data_types/test_duration.py new file mode 100644 index 00000000..a1597045 --- /dev/null +++ b/tests/unit_tests/data_types/test_duration.py @@ -0,0 +1,115 @@ +import unittest +from unittest import main, IsolatedAsyncioTestCase +from surrealdb.data.types.duration import Duration, UNITS + +from surrealdb.connections.async_ws import AsyncWsSurrealConnection +import sys + + +class TestDurationClass(unittest.TestCase): + """Test the Duration class functionality (synchronous).""" + + def test_parse_ms(self): + d = Duration.parse("2ms") # parse 2 as seconds + self.assertEqual(d.elapsed, 2 * UNITS["ms"]) + + def test_parse_int_seconds(self): + d = Duration.parse(2) # parse 2 as seconds + self.assertEqual(d.elapsed, 2 * UNITS["s"]) + + def test_parse_str_hours(self): + d = Duration.parse("3h") + self.assertEqual(d.elapsed, 3 * UNITS["h"]) + + def test_parse_str_days(self): + d = Duration.parse("2d") + self.assertEqual(d.elapsed, 2 * UNITS["d"]) + + def test_parse_unknown_unit(self): + with self.assertRaises(ValueError): + Duration.parse("10x") + + def test_parse_wrong_type(self): + with self.assertRaises(TypeError): + Duration.parse(1.5) # float not allowed + + def test_get_seconds_and_nano(self): + # Suppose we have 1 second + 500 nanoseconds + d = Duration( (1 * UNITS["s"]) + 500 ) + sec, nano = d.get_seconds_and_nano() + self.assertEqual(sec, 1) + self.assertEqual(nano, 500) + + def test_equality(self): + d1 = Duration.parse("2h") + d2 = Duration.parse("2h") + d3 = Duration.parse("3h") + self.assertEqual(d1, d2) + self.assertNotEqual(d1, d3) + + def test_properties(self): + d = Duration.parse("2h") # 7200 seconds + self.assertEqual(d.hours, 2) + self.assertEqual(d.minutes, 120) + self.assertEqual(d.seconds, 7200) + self.assertEqual(d.milliseconds, 7200 * 1000) + self.assertEqual(d.microseconds, 7200 * 1000000) + self.assertEqual(d.nanoseconds, 7200 * 1000000000) + + def test_to_string(self): + d = Duration.parse("90m") # 90 minutes = 1 hour 30 min + self.assertEqual("90m", d.to_string()) + + def test_to_compact(self): + d = Duration.parse(60) # 60 seconds + self.assertEqual(d.to_compact(), [60]) + + +class TestAsyncWsSurrealConnectionDuration(IsolatedAsyncioTestCase): + """ + Uses SurrealDB to create a table with a duration field, insert some data, + and verify we can query it back without causing 'list index out of range'. + """ + + async def asyncSetUp(self): + self.url = "ws://localhost:8000/rpc" + self.password = "root" + self.username = "root" + self.vars_params = { + "username": self.username, + "password": self.password, + } + self.database_name = "test_db" + self.namespace = "test_ns" + self.connection = AsyncWsSurrealConnection(self.url) + + # Sign in and select DB + await self.connection.signin(self.vars_params) + await self.connection.use(namespace=self.namespace, database=self.database_name) + + # Cleanup + await self.connection.query("DELETE duration_tests;") + + async def asyncTearDown(self): + await self.connection.query("DELETE duration_tests;") + await self.connection.close() + + async def test_duration_int_insert(self): + """ + Insert an integer as a duration. SurrealDB might store it so that the + local decode sees only one item, or a different structure. If the code + expects two items, it might trigger 'list index out of range'. + """ + test_duration = Duration.parse("3h") + + test_outcome = await self.connection.query( + "CREATE duration_tests:test set duration = $duration;", + params={"duration": test_duration} + ) + duration = test_outcome[0]["duration"] + self.assertIsInstance(duration, Duration) + self.assertEqual(duration.elapsed, 10800000000000) + + +if __name__ == "__main__": + main()