Skip to content

Commit e0cec71

Browse files
authored
RSDK-9146: Change Python TabularDataBySQL/MQL return type to raw BSON (#774)
1 parent b932e16 commit e0cec71

File tree

5 files changed

+98
-14
lines changed

5 files changed

+98
-14
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ dependencies = [
1515
"grpclib>=0.4.7",
1616
"protobuf==5.28.2",
1717
"typing-extensions>=4.12.2",
18+
"pymongo>=4.10.1"
1819
]
1920

2021
[project.urls]

src/viam/app/data_client.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from dataclasses import dataclass
33
from datetime import datetime
44
from pathlib import Path
5-
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
5+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
6+
import bson
67

78
from google.protobuf.struct_pb2 import Struct
89
from grpclib.client import Channel, Stream
@@ -241,7 +242,7 @@ async def tabular_data_by_filter(
241242
LOGGER.error(f"Failed to write tabular data to file {dest}", exc_info=e)
242243
return data, response.count, response.last
243244

244-
async def tabular_data_by_sql(self, organization_id: str, sql_query: str) -> List[Dict[str, ValueTypes]]:
245+
async def tabular_data_by_sql(self, organization_id: str, sql_query: str) -> List[Dict[str, Union[ValueTypes, datetime]]]:
245246
"""Obtain unified tabular data and metadata, queried with SQL.
246247
247248
::
@@ -264,9 +265,9 @@ async def tabular_data_by_sql(self, organization_id: str, sql_query: str) -> Lis
264265
"""
265266
request = TabularDataBySQLRequest(organization_id=organization_id, sql_query=sql_query)
266267
response: TabularDataBySQLResponse = await self._data_client.TabularDataBySQL(request, metadata=self._metadata)
267-
return [struct_to_dict(struct) for struct in response.data]
268+
return [bson.decode(bson_bytes) for bson_bytes in response.raw_data]
268269

269-
async def tabular_data_by_mql(self, organization_id: str, mql_binary: List[bytes]) -> List[Dict[str, ValueTypes]]:
270+
async def tabular_data_by_mql(self, organization_id: str, mql_binary: List[bytes]) -> List[Dict[str, Union[ValueTypes, datetime]]]:
270271
"""Obtain unified tabular data and metadata, queried with MQL.
271272
272273
::
@@ -303,7 +304,7 @@ async def tabular_data_by_mql(self, organization_id: str, mql_binary: List[bytes
303304
"""
304305
request = TabularDataByMQLRequest(organization_id=organization_id, mql_binary=mql_binary)
305306
response: TabularDataByMQLResponse = await self._data_client.TabularDataByMQL(request, metadata=self._metadata)
306-
return [struct_to_dict(struct) for struct in response.data]
307+
return [bson.decode(bson_bytes) for bson_bytes in response.raw_data]
307308

308309
async def binary_data_by_filter(
309310
self,

tests/mocks/services.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from typing import Any, Dict, List, Mapping, Optional, Sequence
1+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
22

33
import numpy as np
44
from grpclib.server import Stream
55
from numpy.typing import NDArray
6+
from datetime import datetime
7+
import bson
68

79
from viam.app.data_client import DataClient
810
from viam.gen.app.v1.app_pb2 import FragmentHistoryEntry, GetFragmentHistoryRequest, GetFragmentHistoryResponse
@@ -791,12 +793,11 @@ async def SetSmartMachineCredentials(
791793
self.cloud_config = request.cloud
792794
await stream.send_message(SetSmartMachineCredentialsResponse())
793795

794-
795796
class MockData(UnimplementedDataServiceBase):
796797
def __init__(
797798
self,
798799
tabular_response: List[DataClient.TabularData],
799-
tabular_query_response: List[Dict[str, ValueTypes]],
800+
tabular_query_response: List[Dict[str, Union[ValueTypes, datetime]]],
800801
binary_response: List[BinaryData],
801802
delete_remove_response: int,
802803
tags_response: List[str],
@@ -986,12 +987,12 @@ async def RemoveBinaryDataFromDatasetByIDs(
986987
async def TabularDataBySQL(self, stream: Stream[TabularDataBySQLRequest, TabularDataBySQLResponse]) -> None:
987988
request = await stream.recv_message()
988989
assert request is not None
989-
await stream.send_message(TabularDataBySQLResponse(data=[dict_to_struct(dict) for dict in self.tabular_query_response]))
990+
await stream.send_message(TabularDataBySQLResponse(raw_data=[bson.encode(dict) for dict in self.tabular_query_response]))
990991

991992
async def TabularDataByMQL(self, stream: Stream[TabularDataByMQLRequest, TabularDataByMQLResponse]) -> None:
992993
request = await stream.recv_message()
993994
assert request is not None
994-
await stream.send_message(TabularDataByMQLResponse(data=[dict_to_struct(dict) for dict in self.tabular_query_response]))
995+
await stream.send_message(TabularDataByMQLResponse(raw_data=[bson.encode(dict) for dict in self.tabular_query_response]))
995996

996997

997998
class MockDataset(DatasetServiceBase):

tests/test_data_client.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from google.protobuf.timestamp_pb2 import Timestamp
55
from grpclib.testing import ChannelFor
66

7+
from datetime import datetime
78
from viam.app.data_client import DataClient
89
from viam.proto.app.data import Annotations, BinaryData, BinaryID, BinaryMetadata, BoundingBox, CaptureMetadata, Filter, Order
910
from viam.utils import create_filter
@@ -101,7 +102,7 @@
101102

102103
TABULAR_RESPONSE = [DataClient.TabularData(TABULAR_DATA, TABULAR_METADATA, START_DATETIME, END_DATETIME)]
103104
TABULAR_QUERY_RESPONSE = [
104-
{"key1": 1, "key2": "2", "key3": [1, 2, 3], "key4": {"key4sub1": 1}},
105+
{"key1": START_DATETIME, "key2": "2", "key3": [1, 2, 3], "key4": {"key4sub1": END_DATETIME}},
105106
]
106107
BINARY_RESPONSE = [BinaryData(binary=BINARY_DATA, metadata=BINARY_METADATA)]
107108
DELETE_REMOVE_RESPONSE = 1
@@ -153,12 +154,14 @@ async def test_tabular_data_by_sql(self, service: MockData):
153154
async with ChannelFor([service]) as channel:
154155
client = DataClient(channel, DATA_SERVICE_METADATA)
155156
response = await client.tabular_data_by_sql(ORG_ID, SQL_QUERY)
157+
assert isinstance(response[0]["key1"], datetime)
156158
assert response == TABULAR_QUERY_RESPONSE
157159

158160
async def test_tabular_data_by_mql(self, service: MockData):
159161
async with ChannelFor([service]) as channel:
160162
client = DataClient(channel, DATA_SERVICE_METADATA)
161163
response = await client.tabular_data_by_mql(ORG_ID, MQL_BINARY)
164+
assert isinstance(response[0]["key1"], datetime)
162165
assert response == TABULAR_QUERY_RESPONSE
163166

164167
async def test_binary_data_by_filter(self, service: MockData):

0 commit comments

Comments
 (0)