Skip to content

State: add get_bytes / set_bytes methods #834

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions docs/advanced/stateful-processing.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,43 @@ for the state to become slightly out of sync with a topic in between shutdowns a
While the impact of this is generally minimal and affects only a small number of messages, be aware that it can cause side effects: the same message may be reprocessed differently if its processing depended on conditions evaluated against the state.

"Exactly Once" delivery guarantees avoid this. You can learn more about delivery/processing guarantees [here](https://quix.io/docs/quix-streams/configuration.html?h=#processing-guarantees).

## Serialization

By default, keys and values are serialized to JSON for storage. If you need to change the serialization format, you can do so using the `rocksdb_options` parameter when creating the `Application` object. This change applies to all state stores created by the application, and any existing state written with the previous format will become unreadable.

For example, you can use the [Python `pickle` module](https://docs.python.org/3/library/pickle.html) to serialize and deserialize the data of all stores.

```python
import pickle

from quixstreams import Application
from quixstreams.state.rocksdb import RocksDBOptions

app = Application(
    broker_address='localhost:9092',
    # Swap the default JSON (de)serializers for pickle in every state store
    # created by this application.
    rocksdb_options=RocksDBOptions(dumps=pickle.dumps, loads=pickle.loads),
)
```

You can also handle the serialization and deserialization yourself by using the [`State.get_bytes`](../api-reference/state.md#stateget_bytes) and [`State.set_bytes`](../api-reference/state.md#stateset_bytes) methods. This allows you to store values of any type in the state store, as long as you can convert them to bytes and back.

```python
import pickle

from quixstreams import Application, State

app = Application(
    broker_address='localhost:9092',
    consumer_group='consumer',
)
topic = app.topic('topic')

sdf = app.dataframe(topic)


def apply(value, state: State):
    # The value is read and written as raw bytes, so pickling and
    # unpickling are done explicitly here.
    old = state.get_bytes('key', default=None)
    if old is not None:
        old = pickle.loads(old)
    state.set_bytes('key', pickle.dumps(value))
    return {"old": old, "new": value}


sdf = sdf.apply(apply, stateful=True)
```
60 changes: 55 additions & 5 deletions quixstreams/state/base/state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, overload
from typing import TYPE_CHECKING, Generic, Literal, Optional, TypeVar, overload

if TYPE_CHECKING:
from .transaction import PartitionTransaction
Expand All @@ -20,7 +20,7 @@ class State(ABC, Generic[K, V]):
"""

@overload
def get(self, key: K) -> Optional[V]: ...
def get(self, key: K, default: Literal[None] = None) -> Optional[V]: ...

@overload
def get(self, key: K, default: V) -> V: ...
Expand All @@ -36,8 +36,32 @@ def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
"""
...

@overload
def get_bytes(self, key: K, default: Literal[None] = None) -> Optional[bytes]: ...

@overload
def get_bytes(self, key: K, default: bytes) -> bytes: ...

@abstractmethod
def get_bytes(self, key: K, default: Optional[bytes] = None) -> Optional[bytes]:
    """
    Get the bytes value for key if key is present in the state, else default.

    Declared abstract, like `get`, `set` and `set_bytes`: without it the base
    class would silently return `None` and ignore `default`, contradicting
    the `default: bytes -> bytes` overload above.

    :param key: key
    :param default: default value to return if the key is not found
    :return: value as bytes or None if the key is not found and `default` is not provided
    """
    ...

@abstractmethod
def set(self, key: K, value: V) -> None:
    """
    Set value for the key.

    Abstract: concrete state classes must provide the implementation.

    :param key: key
    :param value: value
    """
    ...

@abstractmethod
def set(self, key: K, value: V):
def set_bytes(self, key: K, value: bytes) -> None:
"""
Set value for the key.
:param key: key
Expand Down Expand Up @@ -81,7 +105,7 @@ def __init__(self, prefix: bytes, transaction: "PartitionTransaction"):
self._transaction = transaction

@overload
def get(self, key: K) -> Optional[V]: ...
def get(self, key: K, default: Literal[None] = None) -> Optional[V]: ...

@overload
def get(self, key: K, default: V) -> V: ...
Expand All @@ -96,14 +120,40 @@ def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
"""
return self._transaction.get(key=key, prefix=self._prefix, default=default)

def set(self, key: K, value: V):
@overload
def get_bytes(self, key: K, default: Literal[None] = None) -> Optional[bytes]: ...

@overload
def get_bytes(self, key: K, default: bytes) -> bytes: ...

def get_bytes(self, key: K, default: Optional[bytes] = None) -> Optional[bytes]:
    """
    Look up the bytes value stored under the key, falling back to `default`.

    The lookup is delegated to the underlying transaction and scoped by
    this state's prefix.

    :param key: key
    :param default: default value to return if the key is not found
    :return: value or None if the key is not found and `default` is not provided
    """
    transaction = self._transaction
    return transaction.get_bytes(key=key, prefix=self._prefix, default=default)

def set(self, key: K, value: V) -> None:
    """
    Store the value under the key.

    The write is delegated to the underlying transaction and scoped by
    this state's prefix.

    :param key: key
    :param value: value
    """
    transaction = self._transaction
    return transaction.set(key=key, value=value, prefix=self._prefix)

def set_bytes(self, key: K, value: bytes) -> None:
    """
    Store the bytes value under the key.

    The write is delegated to the underlying transaction's `set_bytes` and
    scoped by this state's prefix.

    :param key: key
    :param value: value as bytes
    """
    transaction = self._transaction
    return transaction.set_bytes(key=key, value=value, prefix=self._prefix)

def delete(self, key: K):
"""
Delete value for the key.
Expand Down
99 changes: 85 additions & 14 deletions quixstreams/state/base/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Any,
Dict,
Generic,
Literal,
Optional,
Set,
Tuple,
Expand All @@ -17,7 +18,11 @@
)

from quixstreams.models import Headers
from quixstreams.state.exceptions import InvalidChangelogOffset, StateTransactionError
from quixstreams.state.exceptions import (
InvalidChangelogOffset,
StateSerializationError,
StateTransactionError,
)
from quixstreams.state.metadata import (
CHANGELOG_CF_MESSAGE_HEADER,
CHANGELOG_PROCESSED_OFFSETS_MESSAGE_HEADER,
Expand Down Expand Up @@ -301,7 +306,6 @@ def get(
@overload
def get(self, key: K, prefix: bytes, default: V, cf_name: str = "default") -> V: ...

@validate_transaction_status(PartitionTransactionStatus.STARTED)
def get(
self,
key: K,
Expand All @@ -320,25 +324,71 @@ def get(
:param cf_name: column family name
:return: value or None if the key is not found and `default` is not provided
"""

data = self._get_bytes(key, prefix, cf_name)
if data is Marker.DELETED or data is Marker.UNDEFINED:
return default

return self._deserialize_value(data)

@overload
def get_bytes(
    self,
    key: K,
    prefix: bytes,
    default: Literal[None] = None,
    cf_name: str = "default",
) -> Optional[bytes]: ...

@overload
def get_bytes(
    self, key: K, prefix: bytes, default: bytes, cf_name: str = "default"
) -> bytes: ...

def get_bytes(
    self,
    key: K,
    prefix: bytes,
    default: Optional[bytes] = None,
    cf_name: str = "default",
) -> Optional[bytes]:
    """
    Get the bytes value of a key from the store, without deserializing it.

    It returns `None` if the key is not found and `default` is not provided.

    :param key: key
    :param prefix: a key prefix
    :param default: default value to return if the key is not found
    :param cf_name: column family name
    :return: value as bytes or None if the key is not found and `default` is not provided
    """
    raw = self._get_bytes(key, prefix, cf_name)
    # Both markers mean "no value available": the key was deleted or was
    # never defined.
    is_missing = raw is Marker.DELETED or raw is Marker.UNDEFINED
    return default if is_missing else raw

@validate_transaction_status(PartitionTransactionStatus.STARTED)
def _get_bytes(
self,
key: K,
prefix: bytes,
cf_name: str = "default",
) -> Union[bytes, Literal[Marker.DELETED, Marker.UNDEFINED]]:
key_serialized = self._serialize_key(key, prefix=prefix)

cached = self._update_cache.get(
key=key_serialized, prefix=prefix, cf_name=cf_name
)
if cached is Marker.DELETED:
return default

if cached is not Marker.UNDEFINED:
return self._deserialize_value(cached)

stored = self._partition.get(key_serialized, cf_name)
if stored is Marker.UNDEFINED:
return default
if cached is Marker.UNDEFINED:
return self._partition.get(key_serialized, cf_name)

return self._deserialize_value(stored)
return cached

@validate_transaction_status(PartitionTransactionStatus.STARTED)
def set(self, key: K, value: V, prefix: bytes, cf_name: str = "default"):
def set(self, key: K, value: V, prefix: bytes, cf_name: str = "default") -> None:
"""
Set value for the key.
:param key: key
Expand All @@ -348,11 +398,32 @@ def set(self, key: K, value: V, prefix: bytes, cf_name: str = "default"):
"""

try:
key_serialized = self._serialize_key(key, prefix=prefix)
value_serialized = self._serialize_value(value)
except Exception:
self._status = PartitionTransactionStatus.FAILED
raise

self.set_bytes(key, value_serialized, prefix, cf_name=cf_name)

@validate_transaction_status(PartitionTransactionStatus.STARTED)
def set_bytes(
self, key: K, value: bytes, prefix: bytes, cf_name: str = "default"
) -> None:
"""
Set bytes value for the key.
:param key: key
:param prefix: a key prefix
:param value: value
:param cf_name: column family name
"""
try:
if not isinstance(value, bytes):
raise StateSerializationError("Value must be bytes")

key_serialized = self._serialize_key(key, prefix=prefix)
self._update_cache.set(
key=key_serialized,
value=value_serialized,
value=value,
prefix=prefix,
cf_name=cf_name,
)
Expand Down
Loading