Refactor: Filters were messy with duplication #58

Closed · wants to merge 3 commits
108 changes: 20 additions & 88 deletions src/aleph/sdk/client.py
@@ -45,7 +45,6 @@
)
from aleph_message.models.execution.base import Encoding
from aleph_message.status import MessageStatus
from pydantic import ValidationError

from aleph.sdk.types import Account, GenericMessage, StorageEnum
from aleph.sdk.utils import Writable, copy_async_readable_to_buffer
@@ -59,6 +58,7 @@
    MultipleMessagesError,
)
from .models import MessagesResponse
from .query import MessageQuery, MessageQueryFilter
from .utils import check_unix_socket_valid, get_message_type_value

logger = logging.getLogger(__name__)
@@ -736,96 +736,28 @@ async def get_messages(
    ) -> MessagesResponse:
        """
        Fetch a list of messages from the network.

        :param pagination: Number of items to fetch (Default: 200)
        :param page: Page to fetch, begins at 1 (Default: 1)
        :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET"
        :param content_types: Filter by content type
        :param content_keys: Filter by content key
        :param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
        :param addresses: Addresses of the posts to fetch (Default: all addresses)
        :param tags: Tags of the posts to fetch (Default: all tags)
        :param hashes: Specific item_hashes to fetch
        :param channels: Channels of the posts to fetch (Default: all channels)
        :param chains: Filter by sender address chain
        :param start_date: Earliest date to fetch messages from
        :param end_date: Latest date to fetch messages from
        :param ignore_invalid_messages: Ignore invalid messages (Default: True)
        :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
        """
        ignore_invalid_messages = (
            True if ignore_invalid_messages is None else ignore_invalid_messages
        )
        invalid_messages_log_level = (
            logging.NOTSET
            if invalid_messages_log_level is None
            else invalid_messages_log_level
        )

        params: Dict[str, Any] = dict(pagination=pagination, page=page)

        if message_type is not None:
            params["msgType"] = message_type.value
        if content_types is not None:
            params["contentTypes"] = ",".join(content_types)
        if content_keys is not None:
            params["contentKeys"] = ",".join(content_keys)
        if refs is not None:
            params["refs"] = ",".join(refs)
        if addresses is not None:
            params["addresses"] = ",".join(addresses)
        if tags is not None:
            params["tags"] = ",".join(tags)
        if hashes is not None:
            params["hashes"] = ",".join(hashes)
        if channels is not None:
            params["channels"] = ",".join(channels)
        if chains is not None:
            params["chains"] = ",".join(chains)

        if start_date is not None:
            if not isinstance(start_date, float) and hasattr(start_date, "timestamp"):
                start_date = start_date.timestamp()
            params["startDate"] = start_date
        if end_date is not None:
            if not isinstance(end_date, float) and hasattr(end_date, "timestamp"):
                end_date = end_date.timestamp()
            params["endDate"] = end_date
        query_filter = MessageQueryFilter(
            message_type=message_type,
            content_types=content_types,
            content_keys=content_keys,
            refs=refs,
            addresses=addresses,
            tags=tags,
            hashes=hashes,
            channels=channels,
            chains=chains,
            start_date=start_date,
            end_date=end_date,
        )

        async with self.http_session.get(
            "/api/v0/messages.json", params=params
        ) as resp:
            resp.raise_for_status()
            response_json = await resp.json()
            messages_raw = response_json["messages"]

            # Not all messages may be valid according to the latest specification in
            # aleph-message. This allows the user to specify how errors should be handled.
            messages: List[AlephMessage] = []
            for message_raw in messages_raw:
                try:
                    message = parse_message(message_raw)
                    messages.append(message)
                except KeyError as e:
                    if not ignore_invalid_messages:
                        raise e
                    logger.log(
                        level=invalid_messages_log_level,
                        msg=f"KeyError: Field '{e.args[0]}' not found",
                    )
                except ValidationError as e:
                    if not ignore_invalid_messages:
                        raise e
                    if invalid_messages_log_level:
                        logger.log(level=invalid_messages_log_level, msg=e)

            return MessagesResponse(
                messages=messages,
                pagination_page=response_json["pagination_page"],
                pagination_total=response_json["pagination_total"],
                pagination_per_page=response_json["pagination_per_page"],
                pagination_item=response_json["pagination_item"],
            )
        return await MessageQuery(
            query_filter=query_filter,
            http_client_session=self.http_session,
            ignore_invalid_messages=ignore_invalid_messages,
            invalid_messages_log_level=invalid_messages_log_level,
        ).fetch(page=page, pagination=pagination)
Member

Seeing this, I wonder if we should ditch the .fetch(page=page, pagination=pagination) call to allow users to decide whether to fetch a certain page or to iterate through all the messages?

Currently I cannot see how a user would actually access the iterator of MessageQuery.

Member (Author)

The user can always instantiate a MessageQuery on their own and use it directly. That could even be more elegant, but I am trying to maintain API compatibility.
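
For example, something along these lines works with this PR as-is (a sketch only; `session` stands for an `aiohttp.ClientSession` pointed at an API server, and `MessageType` comes from `aleph_message.models`):

```python
query = MessageQuery(
    query_filter=MessageQueryFilter(message_type=MessageType.post),
    http_client_session=session,
)
first_page = await query.fetch(page=1)  # one page, as get_messages() does today
async for message in query:            # or iterate over everything via __aiter__
    print(message.item_hash)
```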

Member

Regarding this, I cannot help but be reminded of the way I implemented similar functionality in the Active Record SDK; see: https://github.com/aleph-im/active-record-sdk/blob/main/src/aars/utils.py

Here, we have IndexQuery, which prepares the parameters like MessageQueryFilter does.
The AARS client class handles the actual requests, which MessageQuery does here.
Then PageableRequest & PageableResponse provide the pagination & iteration over the response of the executed MessageQuery.

I have to admit, though, that your approach looks more elegant than what I did in the Active Record SDK. Some of that ugliness resulted from the fact that:

  1. I am using an ORM and all the functionality needs to be attached to the base class of Records (which actually represent POSTs).
  2. Some responses are async generators, while some are sync, as AARS does not have a cache capable of async iteration/pagination (while my DomainNode implementation in MessageCache and LightNode #59 does).

I think there is some significant overlap between this PR, the PRs #54 and #59 and the Active Record SDK. Given the right preparation and time for thought, we could build a more coherent and powerful SDK for handling messages as well as POSTs.

One more thing about POSTs: given the interesting behavior that amending these messages makes possible, I think they would benefit from the ORM approach chosen in AARS, as it would give users very simple syntax for updating/amending POSTs as well as for retrieving their past revisions. This is a paradigm not often seen in DLTs and should be highlighted and made easier to work with, IMHO.
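
To make that concrete, here is a purely hypothetical sketch of what such syntax could look like (none of these names exist in this SDK or this PR; `Note`, `create`, `save`, and `revisions` are invented for illustration, loosely modelled on AARS Records):

```python
# Hypothetical ORM-style API, invented for illustration only:
note = await Note.create(title="hello")  # would emit a POST message
note.title = "hello, world"
await note.save()                        # would amend the original POST
history = await note.revisions()         # would walk the amend chain
```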


    async def get_message(
        self,
211 changes: 211 additions & 0 deletions src/aleph/sdk/query.py
@@ -0,0 +1,211 @@
from __future__ import annotations

import logging
from datetime import datetime
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Union

import aiohttp
from aleph_message.models import (
    AlephMessage,
    MessagesResponse,
    MessageType,
    parse_message,
)
from pydantic import ValidationError

logger = logging.getLogger(__name__)


def serialize_list(values: Optional[Iterable[str]]) -> Optional[str]:
    if values:
        return ",".join(values)
    else:
        return None


def _date_field_to_float(date: Optional[Union[datetime, float]]) -> Optional[float]:
    if date is None:
        return None
    elif isinstance(date, float):
        return date
    elif hasattr(date, "timestamp"):
        return date.timestamp()
    else:
        raise TypeError(f"Invalid type: `{type(date)}`")
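
# For instance (illustrative values): _date_field_to_float(1672531200.0) returns the
# float unchanged, while _date_field_to_float(datetime(2023, 1, 1)) is converted via
# timestamp() (the exact result depends on the datetime's timezone); None stays None.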


class MessageQueryFilter:
    """
    A collection of filters that can be applied on message queries.

    :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET"
    :param content_types: Filter by content type
    :param content_keys: Filter by content key
    :param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
    :param addresses: Addresses of the posts to fetch (Default: all addresses)
    :param tags: Tags of the posts to fetch (Default: all tags)
    :param hashes: Specific item_hashes to fetch
    :param channels: Channels of the posts to fetch (Default: all channels)
    :param chains: Filter by sender address chain
    :param start_date: Earliest date to fetch messages from
    :param end_date: Latest date to fetch messages from
    """

    message_type: Optional[MessageType]
    content_types: Optional[Iterable[str]]
    content_keys: Optional[Iterable[str]]
    refs: Optional[Iterable[str]]
    addresses: Optional[Iterable[str]]
    tags: Optional[Iterable[str]]
    hashes: Optional[Iterable[str]]
    channels: Optional[Iterable[str]]
    chains: Optional[Iterable[str]]
    start_date: Optional[Union[datetime, float]]
    end_date: Optional[Union[datetime, float]]

    def __init__(
        self,
        message_type: Optional[MessageType] = None,
        content_types: Optional[Iterable[str]] = None,
        content_keys: Optional[Iterable[str]] = None,
        refs: Optional[Iterable[str]] = None,
        addresses: Optional[Iterable[str]] = None,
        tags: Optional[Iterable[str]] = None,
        hashes: Optional[Iterable[str]] = None,
        channels: Optional[Iterable[str]] = None,
        chains: Optional[Iterable[str]] = None,
        start_date: Optional[Union[datetime, float]] = None,
        end_date: Optional[Union[datetime, float]] = None,
    ):
        self.message_type = message_type
        self.content_types = content_types
        self.content_keys = content_keys
        self.refs = refs
        self.addresses = addresses
        self.tags = tags
        self.hashes = hashes
        self.channels = channels
        self.chains = chains
        self.start_date = start_date
        self.end_date = end_date

    def as_http_params(self) -> Dict[str, str]:
        """Convert the filters into a dict that can be used by an `aiohttp` client
        as `params` to build the HTTP query string.
        """

        partial_result = {
            "msgType": self.message_type.value if self.message_type else None,
            "contentTypes": serialize_list(self.content_types),
            "contentKeys": serialize_list(self.content_keys),
            "refs": serialize_list(self.refs),
            "addresses": serialize_list(self.addresses),
            "tags": serialize_list(self.tags),
            "hashes": serialize_list(self.hashes),
            "channels": serialize_list(self.channels),
            "chains": serialize_list(self.chains),
            "startDate": _date_field_to_float(self.start_date),
            "endDate": _date_field_to_float(self.end_date),
        }

        # Drop empty values and serialize the rest as strings, since `aiohttp`
        # expects string parameters (startDate/endDate are floats at this point).
        result: Dict[str, str] = {}
        for key, value in partial_result.items():
            if value:
                result[key] = str(value)

        return result
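
# For example (illustrative values): MessageQueryFilter(channels=["TEST"],
# chains=["ETH"]).as_http_params() returns {"channels": "TEST", "chains": "ETH"};
# filters left as None are dropped from the query string entirely.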


class MessageQuery:
    """
    Interface to query messages from an API server.

    :param query_filter: The filter to apply when fetching messages
    :param http_client_session: The aiohttp client session to the API server
    :param ignore_invalid_messages: Ignore invalid messages (Default: True)
    :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
    """

    query_filter: MessageQueryFilter
    http_client_session: aiohttp.ClientSession
    ignore_invalid_messages: bool
    invalid_messages_log_level: int

    def __init__(
        self,
        query_filter: MessageQueryFilter,
        http_client_session: aiohttp.ClientSession,
        ignore_invalid_messages: bool = True,
        invalid_messages_log_level: int = logging.NOTSET,
    ):
        self.query_filter = query_filter
        self.http_client_session = http_client_session
        self.ignore_invalid_messages = ignore_invalid_messages
        self.invalid_messages_log_level = invalid_messages_log_level
    async def fetch_json(self, page: int = 1, pagination: int = 200) -> Dict[str, Any]:
        """Return the raw JSON response from the API server."""
        params: Dict[str, Any] = self.query_filter.as_http_params()
        params["page"] = str(page)
        params["pagination"] = str(pagination)
        async with self.http_client_session.get(
            "/api/v0/messages.json", params=params
        ) as resp:
            resp.raise_for_status()
            return await resp.json()

    async def fetch(self, page: int = 1, pagination: int = 200) -> MessagesResponse:
        """Return the parsed messages from the API server."""
        response_json = await self.fetch_json(page=page, pagination=pagination)

        messages_raw = response_json["messages"]

        # Not all messages may be valid according to the latest specification in
        # aleph-message. This allows the user to specify how errors should be handled.
        messages: List[AlephMessage] = []
        for message_raw in messages_raw:
            try:
                message = parse_message(message_raw)
                messages.append(message)
            except KeyError as e:
                if not self.ignore_invalid_messages:
                    raise e
                logger.log(
                    level=self.invalid_messages_log_level,
                    msg=f"KeyError: Field '{e.args[0]}' not found",
                )
            except ValidationError as e:
                if not self.ignore_invalid_messages:
                    raise e
                if self.invalid_messages_log_level:
                    logger.log(level=self.invalid_messages_log_level, msg=e)

        return MessagesResponse(
            messages=messages,
            pagination_page=response_json["pagination_page"],
            pagination_total=response_json["pagination_total"],
            pagination_per_page=response_json["pagination_per_page"],
            pagination_item=response_json["pagination_item"],
        )

    async def __aiter__(self) -> AsyncIterator[AlephMessage]:
        """Iterate asynchronously over matching messages.
        Handles pagination internally.

        ```
        async for message in MessageQuery(query_filter=query_filter, http_client_session=session):
            print(message)
        ```
        """
        page: int = 1
        partial_result = await self.fetch(page=page)
        while partial_result.messages:
            for message in partial_result.messages:
                yield message

            page += 1
            partial_result = await self.fetch(page=page)
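
For reference, a minimal end-to-end sketch of how the new interface fits together (assumptions: the module path `aleph.sdk.query` as added by this PR, `https://api2.aleph.im` as an illustrative API host, and a `TEST` channel as sample data):

```python
import asyncio

import aiohttp
from aleph_message.models import MessageType

from aleph.sdk.query import MessageQuery, MessageQueryFilter


async def main():
    # MessageQuery uses a relative URL internally, so the session needs a base_url.
    async with aiohttp.ClientSession(base_url="https://api2.aleph.im") as session:
        query = MessageQuery(
            query_filter=MessageQueryFilter(
                message_type=MessageType.post,
                channels=["TEST"],
            ),
            http_client_session=session,
        )
        # Fetch one page explicitly...
        response = await query.fetch(page=1, pagination=50)
        print(f"{response.pagination_total} matching messages")
        # ...or let the query paginate internally.
        async for message in query:
            print(message.item_hash)


asyncio.run(main())
```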