diff --git a/examples/map/async_multiproc_map/Dockerfile b/examples/map/async_multiproc_map/Dockerfile
new file mode 100644
index 00000000..5ba6663f
--- /dev/null
+++ b/examples/map/async_multiproc_map/Dockerfile
@@ -0,0 +1,55 @@
+####################################################################################################
+# builder: install needed dependencies
+####################################################################################################
+
+FROM python:3.10-slim-bullseye AS builder
+
+ENV PYTHONFAULTHANDLER=1 \
+    PYTHONUNBUFFERED=1 \
+    PYTHONHASHSEED=random \
+    PIP_NO_CACHE_DIR=on \
+    PIP_DISABLE_PIP_VERSION_CHECK=on \
+    PIP_DEFAULT_TIMEOUT=100 \
+    POETRY_VERSION=1.2.2 \
+    POETRY_HOME="/opt/poetry" \
+    POETRY_VIRTUALENVS_IN_PROJECT=true \
+    POETRY_NO_INTERACTION=1 \
+    PYSETUP_PATH="/opt/pysetup"
+
+ENV EXAMPLE_PATH="$PYSETUP_PATH/examples/map/async_multiproc_map"
+ENV VENV_PATH="$EXAMPLE_PATH/.venv"
+ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH"
+
+RUN apt-get update \
+    && apt-get install --no-install-recommends -y \
+        curl \
+        wget \
+        # deps for building python deps
+        build-essential \
+    && apt-get install -y git \
+    && apt-get clean && rm -rf /var/lib/apt/lists/* \
+    \
+    # install dumb-init
+    && wget -O /dumb-init https://github.com/Yelp/dumb-init/releases/download/v1.2.5/dumb-init_1.2.5_x86_64 \
+    && chmod +x /dumb-init \
+    && curl -sSL https://install.python-poetry.org | python3 -
+
+####################################################################################################
+# udf: used for running the udf vertices
+####################################################################################################
+FROM builder AS udf
+
+WORKDIR $PYSETUP_PATH
+COPY ./ ./
+
+WORKDIR $EXAMPLE_PATH
+RUN poetry lock
+RUN poetry install --no-cache --no-root && \
+    rm -rf ~/.cache/pypoetry/
+
+RUN chmod +x entry.sh
+
+ENTRYPOINT ["/dumb-init", "--"]
+CMD ["sh", "-c", "$EXAMPLE_PATH/entry.sh"]
+
+EXPOSE 5000
diff --git a/examples/map/async_multiproc_map/Makefile b/examples/map/async_multiproc_map/Makefile
new file mode 100644
index 00000000..17fd12e8
--- /dev/null
+++ b/examples/map/async_multiproc_map/Makefile
@@ -0,0 +1,22 @@
+TAG ?= v6
+PUSH ?= false
+IMAGE_REGISTRY = quay.io/skohli/numaflow-python/async-multiproc:${TAG}
+DOCKER_FILE_PATH = examples/map/async_multiproc_map/Dockerfile
+
+.PHONY: update
+update:
+	poetry update -vv
+
+.PHONY: image-push
+image-push: update
+	cd ../../../ && docker buildx build \
+	-f ${DOCKER_FILE_PATH} \
+	-t ${IMAGE_REGISTRY} \
+	--platform linux/amd64,linux/arm64 . --push
+
+.PHONY: image
+image: update
+	cd ../../../ && docker build \
+	-f ${DOCKER_FILE_PATH} \
+	-t ${IMAGE_REGISTRY} .
+	@if [ "$(PUSH)" = "true" ]; then docker push ${IMAGE_REGISTRY}; fi
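A typical invocation of the Makefile targets above, as a sketch: `TAG` and `PUSH` are the variables it defines, and `IMAGE_REGISTRY` points at the example author's personal quay.io repository, so override it when building for your own registry:

```sh
# multi-arch build and push (requires docker buildx)
make image-push TAG=v6

# or a local single-arch build, optionally pushed
make image TAG=v6 PUSH=true
```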
diff --git a/examples/map/async_multiproc_map/README.md b/examples/map/async_multiproc_map/README.md
new file mode 100644
index 00000000..0d92ca5e
--- /dev/null
+++ b/examples/map/async_multiproc_map/README.md
@@ -0,0 +1,20 @@
+# Async Multiprocessing Map
+
+`pynumaflow` supports only asyncio-based Reduce UDFs because we found that procedural Python is
+not able to handle any substantial traffic.
+
+This feature enables the `pynumaflow` developer to utilise multiprocessing capabilities while
+writing UDFs using the map function. It is particularly useful for CPU-intensive operations,
+as it allows for better resource utilisation.
+
+In this mode we spawn N gRPC servers (N = CPU count by default) in separate processes, each
+listening on its own TCP socket or Unix domain socket.
+
+To enable multiprocessing mode, start the multiproc server in the UDF as shown below,
+providing the optional argument `server_count` to specify the number of
+servers to be forked (defaults to `os.cpu_count()` if not provided):
+```python
+if __name__ == "__main__":
+    grpc_server = AsyncMapMultiprocServer(handler, server_count=3)
+    grpc_server.start()
+```
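+
+The workers listen on Unix domain sockets by default; pass `use_tcp=True` to listen on TCP
+ports instead (the chosen endpoints are then advertised through the server info metadata).
+A minimal sketch with the same `handler` as above:
+```python
+if __name__ == "__main__":
+    grpc_server = AsyncMapMultiprocServer(handler, server_count=3, use_tcp=True)
+    grpc_server.start()
+```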
diff --git a/examples/map/async_multiproc_map/entry.sh b/examples/map/async_multiproc_map/entry.sh
new file mode 100644
index 00000000..073b05e3
--- /dev/null
+++ b/examples/map/async_multiproc_map/entry.sh
@@ -0,0 +1,4 @@
+#!/bin/sh
+set -eux
+
+python example.py
diff --git a/examples/map/async_multiproc_map/example.py b/examples/map/async_multiproc_map/example.py
new file mode 100644
index 00000000..370dda11
--- /dev/null
+++ b/examples/map/async_multiproc_map/example.py
@@ -0,0 +1,40 @@
+import os
+
+from pynumaflow.mapper import Messages, Message, Datum, Mapper, AsyncMapMultiprocServer
+from pynumaflow._constants import _LOGGER
+
+
+class FlatMap(Mapper):
+    """
+    This class needs to be of type Mapper to be used as a handler
+    for the map server classes.
+    This example simply forwards the incoming payload, keyed as received.
+    """
+
+    async def handler(self, keys: list[str], datum: Datum) -> Messages:
+        val = datum.value
+        _ = datum.event_time
+        _ = datum.watermark
+        messages = Messages()
+        messages.append(Message(val, keys=keys))
+        _LOGGER.info(f"MY PID {os.getpid()}")
+        return messages
+
+
+if __name__ == "__main__":
+    """
+    Example of starting an async multiprocessing map vertex.
+    """
+    # To override the server count, set the env variable NUM_CPU_MULTIPROC="N"
+    server_count = int(os.getenv("NUM_CPU_MULTIPROC", "2"))
+    # SERVER_KIND selects the transport: "tcp" for TCP sockets,
+    # "uds" for Unix domain sockets
+    server_type = os.getenv("SERVER_KIND", "tcp")
+    use_tcp = server_type == "tcp"
+    _class = FlatMap()
+    # server_count is the number of server processes to start
+    grpc_server = AsyncMapMultiprocServer(_class, server_count=server_count, use_tcp=use_tcp)
+    grpc_server.start()
diff --git a/examples/map/async_multiproc_map/pipeline.yaml b/examples/map/async_multiproc_map/pipeline.yaml
new file mode 100644
index 00000000..0434ee06
--- /dev/null
+++ b/examples/map/async_multiproc_map/pipeline.yaml
@@ -0,0 +1,42 @@
+apiVersion: numaflow.numaproj.io/v1alpha1
+kind: Pipeline
+metadata:
+  name: simple-pipeline
+spec:
+  limits:
+    readBatchSize: 10
+  vertices:
+    - name: in
+      source:
+        # A self data generating source
+        generator:
+          rpu: 200
+          duration: 1s
+    - name: mult
+      udf:
+        container:
+          image: quay.io/skohli/numaflow-python/async-multiproc:v5
+          # imagePullPolicy: Always
+          env:
+            - name: SERVER_KIND
+              value: "uds"
+            - name: PYTHONDEBUG
+              value: "true"
+            - name: NUM_CPU_MULTIPROC
+              value: "3" # do not forget the double quotes
+      containerTemplate:
+        env:
+          - name: NUMAFLOW_RUNTIME
+            value: "rust"
+          - name: NUMAFLOW_DEBUG
+            value: "true" # do not forget the double quotes
+
+    - name: out
+      sink:
+        # A simple log printing sink
+        log: {}
+  edges:
+    - from: in
+      to: mult
+    - from: mult
+      to: out
diff --git a/examples/map/async_multiproc_map/pyproject.toml b/examples/map/async_multiproc_map/pyproject.toml
new file mode 100644
index 00000000..ac19da18
--- /dev/null
+++ b/examples/map/async_multiproc_map/pyproject.toml
@@ -0,0 +1,15 @@
+[tool.poetry]
+name = "async-multiproc-forward-message"
+version = "0.2.4"
+description = ""
+authors = ["Numaflow developers"]
+
+[tool.poetry.dependencies]
+python = ">=3.10,<3.13"
+pynumaflow = { path = "../../../"}
+
+[tool.poetry.dev-dependencies]
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
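To try the example pipeline above on a cluster that already has Numaflow installed (a sketch; it assumes the UDF image is reachable from the cluster):

```sh
kubectl apply -f examples/map/async_multiproc_map/pipeline.yaml
kubectl get pods  # one pod per vertex; "mult" runs the multiproc UDF
```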
""" - def __init__( - self, - handler: MapAsyncCallable, - ): + def __init__(self, handler: MapAsyncCallable, multiproc: bool = False): self.background_tasks = set() + # This indicates whether the grpc server attached is multiproc or not + self.multiproc = multiproc self.__map_handler: MapAsyncCallable = handler async def MapFn( @@ -36,6 +36,7 @@ async def MapFn( """ # proto repeated field(keys) is of type google._upb._message.RepeatedScalarContainer # we need to explicitly convert it to list + producer = None try: # The first message to be received should be a valid handshake req = await request_iterator.__anext__() @@ -56,44 +57,65 @@ async def MapFn( async for msg in consumer: # If the message is an exception, we raise the exception if isinstance(msg, BaseException): - await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING, self.multiproc) return # Send window response back to the client else: yield msg # wait for the producer task to complete await producer + except GeneratorExit: + _LOGGER.info("Client disconnected, generator closed.") + raise except BaseException as e: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, self.multiproc) return + finally: + if producer and not producer.done(): + producer.cancel() + with contextlib.suppress(asyncio.CancelledError): + await producer - async def _process_inputs( - self, - request_iterator: AsyncIterable[map_pb2.MapRequest], - result_queue: NonBlockingIterator, - ): - """ - Utility function for processing incoming MapRequests - """ + async def _process_inputs(self, request_iterator, result_queue): try: - # for each incoming request, create a background task to execute the - # UDF code async for req in request_iterator: - msg_task = asyncio.create_task(self._invoke_map(req, result_queue)) - # save a reference to a set to store active tasks - self.background_tasks.add(msg_task) - msg_task.add_done_callback(self.background_tasks.discard) - - # wait for all tasks to complete - for task in self.background_tasks: - await task + task = asyncio.create_task(self._invoke_map(req, result_queue)) + self.background_tasks.add(task) + task.add_done_callback(self.background_tasks.discard) - # send an EOF to result queue to indicate that all tasks have completed + await asyncio.gather(*self.background_tasks) + except BaseException: + _LOGGER.critical("MapFn Error in _process_inputs", exc_info=True) + finally: await result_queue.put(STREAM_EOF) - except BaseException: - _LOGGER.critical("MapFn Error, re-raising the error", exc_info=True) + # async def _process_inputs( + # self, + # request_iterator: AsyncIterable[map_pb2.MapRequest], + # result_queue: NonBlockingIterator, + # ): + # """ + # Utility function for processing incoming MapRequests + # """ + # try: + # # for each incoming request, create a background task to execute the + # # UDF code + # async for req in request_iterator: + # msg_task = asyncio.create_task(self._invoke_map(req, result_queue)) + # # save a reference to a set to store active tasks + # self.background_tasks.add(msg_task) + # msg_task.add_done_callback(self.background_tasks.discard) + # + # # wait for all tasks to complete + # for task in self.background_tasks: + # await task + # + # # send an EOF to result queue to indicate that all tasks have completed + # await result_queue.put(STREAM_EOF) + # + # except 
diff --git a/pynumaflow/mapper/async_multiproc_server.py b/pynumaflow/mapper/async_multiproc_server.py
new file mode 100644
index 00000000..cd2f2cfc
--- /dev/null
+++ b/pynumaflow/mapper/async_multiproc_server.py
@@ -0,0 +1,138 @@
+import logging
+import multiprocessing
+from typing import Optional
+
+import aiorun
+import grpc
+
+from pynumaflow._constants import (
+    MAX_NUM_THREADS,
+    MAX_MESSAGE_SIZE,
+    MAP_SERVER_INFO_FILE_PATH,
+    _PROCESS_COUNT,
+    NUM_THREADS_DEFAULT,
+    MULTIPROC_MAP_SOCK_ADDR,
+)
+from pynumaflow.info.server import get_metadata_env
+from pynumaflow.info.types import (
+    ServerInfo,
+    MINIMUM_NUMAFLOW_VERSION,
+    ContainerType,
+    MAP_MODE_KEY,
+    MapMode,
+    METADATA_ENVS,
+    MULTIPROC_KEY,
+    MULTIPROC_ENDPOINTS,
+    Protocol,
+)
+from pynumaflow.mapper._dtypes import MapAsyncCallable
+from pynumaflow.mapper._servicer._async_servicer import AsyncMapServicer
+from pynumaflow.proto.mapper import map_pb2_grpc
+from pynumaflow.shared.server import start_async_server, NumaflowServer, reserve_port
+from pynumaflow.info.server import write as info_server_write
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class AsyncMapMultiprocServer(NumaflowServer):
+    """
+    A multiprocess asynchronous gRPC server for Numaflow Map UDFs.
+    Spawns N worker processes, each running an asyncio-based gRPC server.
+    """
+
+    def __init__(
+        self,
+        mapper_instance: MapAsyncCallable,
+        server_count: int = _PROCESS_COUNT,
+        sock_path: str = MULTIPROC_MAP_SOCK_ADDR,
+        max_message_size: int = MAX_MESSAGE_SIZE,
+        max_threads: int = NUM_THREADS_DEFAULT,
+        server_info_file: Optional[str] = MAP_SERVER_INFO_FILE_PATH,
+        use_tcp: bool = False,
+    ):
+        self.sock_path = f"unix://{sock_path}"
+        self.max_threads = min(max_threads, MAX_NUM_THREADS)
+        self.max_message_size = max_message_size
+        self.server_info_file = server_info_file
+        self.use_tcp = use_tcp
+
+        self.mapper_instance = mapper_instance
+
+        self._server_options = [
+            ("grpc.max_send_message_length", self.max_message_size),
+            ("grpc.max_receive_message_length", self.max_message_size),
+            ("grpc.so_reuseport", 1),
+            ("grpc.so_reuseaddr", 1),
+        ]
+        # cap the number of worker processes at twice the host CPU count
+        self._process_count = min(server_count, 2 * _PROCESS_COUNT)
+        self.servicer = AsyncMapServicer(handler=self.mapper_instance, multiproc=True)
+
+    def start(self):
+        """
+        Starts the multiprocess async gRPC servers.
+        """
+        _LOGGER.info("Starting async multiprocess gRPC server with %d workers", self._process_count)
+
+        workers = []
+        ports = []
+
+        for idx in range(self._process_count):
+            if self.use_tcp:
+                # reserve a free ephemeral port for this worker; the worker
+                # re-binds it using the reuse socket options set above
+                with reserve_port(0) as reserved_port:
+                    bind_address = f"0.0.0.0:{reserved_port}"
+                    ports.append(f"http://{bind_address}")
+            else:
+                bind_address = f"{self.sock_path}{idx}.sock"
+            _LOGGER.info("Binding server to: %s", bind_address)
+
+            worker = multiprocessing.Process(
+                target=self._run_server_process,
+                args=(bind_address,),
+            )
+            worker.start()
+            workers.append(worker)
+
+        # Write the server info file once, from the parent process
+        if self.server_info_file:
+            server_info = ServerInfo.get_default_server_info()
+            server_info.metadata[MULTIPROC_KEY] = str(self._process_count)
+            server_info.metadata[MAP_MODE_KEY] = MapMode.UnaryMap
+            if self.use_tcp:
+                server_info.protocol = Protocol.TCP
+                server_info.metadata[MULTIPROC_ENDPOINTS] = ",".join(map(str, ports))
+            info_server_write(server_info=server_info, info_file=self.server_info_file)
+
+        for worker in workers:
+            worker.join()
+
+    def _run_server_process(self, bind_address):
+        async def run_server():
+            server = grpc.aio.server(options=self._server_options)
+            server.add_insecure_port(bind_address)
+            map_pb2_grpc.add_MapServicer_to_server(self.servicer, server)
+
+            server_info = None
+            if self.server_info_file:
+                server_info = ServerInfo.get_default_server_info()
+                server_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[
+                    ContainerType.Mapper
+                ]
+                server_info.metadata = get_metadata_env(envs=METADATA_ENVS)
+                if self.use_tcp:
+                    server_info.protocol = Protocol.TCP
+                # Add the MULTIPROC metadata using the number of servers to use
+                server_info.metadata[MULTIPROC_KEY] = str(self._process_count)
+                # Add the MAP_MODE metadata to the server info for the correct map mode
+                server_info.metadata[MAP_MODE_KEY] = MapMode.UnaryMap
+
+            await start_async_server(
+                server_async=server,
+                sock_path=bind_address,
+                max_threads=self.max_threads,
+                cleanup_coroutines=list(),
+                server_info_file=None,
+                server_info=server_info,
+            )
+
+        aiorun.run(run_server(), use_uvloop=True)
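For reference, a minimal sketch of running the new server in UDS mode; the `Echo` handler and the `/tmp` path are illustrative, and `server_info_file=None` skips the info file a real Numaflow deployment would require:

```python
from pynumaflow.mapper import AsyncMapMultiprocServer, Datum, Mapper, Message, Messages


class Echo(Mapper):
    async def handler(self, keys: list[str], datum: Datum) -> Messages:
        # forward the payload unchanged, keyed as received
        return Messages(Message(datum.value, keys=keys))


if __name__ == "__main__":
    # with sock_path="/tmp/async_map" and server_count=2, the workers bind
    # unix:///tmp/async_map0.sock and unix:///tmp/async_map1.sock
    server = AsyncMapMultiprocServer(
        Echo(), server_count=2, sock_path="/tmp/async_map", server_info_file=None
    )
    server.start()
```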
+ """ + _LOGGER.info("Starting async multiprocess gRPC server with %d workers", self._process_count) + + workers = [] + ports = [] + + for idx in range(self._process_count): + if self.use_tcp: + with reserve_port(0) as reserved_port: + bind_address = f"0.0.0.0:{reserved_port}" + ports.append(f"http://{bind_address}") + else: + bind_address = f"{self.sock_path}{idx}.sock" + _LOGGER.info("Binding server to: %s", bind_address) + + worker = multiprocessing.Process( + target=self._run_server_process, + args=(bind_address,), + ) + worker.start() + workers.append(worker) + + # Write server info file + if self.server_info_file: + server_info = ServerInfo.get_default_server_info() + server_info.metadata[MULTIPROC_KEY] = str(self._process_count) + server_info.metadata[MAP_MODE_KEY] = MapMode.UnaryMap + if self.use_tcp: + server_info.protocol = Protocol.TCP + server_info.metadata[MULTIPROC_ENDPOINTS] = ",".join(map(str, ports)) + info_server_write(server_info=server_info, info_file=self.server_info_file) + + for worker in workers: + worker.join() + + def _run_server_process(self, bind_address): + async def run_server(): + server = grpc.aio.server(options=self._server_options) + server.add_insecure_port(bind_address) + map_pb2_grpc.add_MapServicer_to_server(self.servicer, server) + + server_info = None + if self.server_info_file: + server_info = ServerInfo.get_default_server_info() + server_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ + ContainerType.Mapper + ] + server_info.metadata = get_metadata_env(envs=METADATA_ENVS) + if self.use_tcp: + server_info.protocol = Protocol.TCP + # Add the MULTIPROC metadata using the number of servers to use + server_info.metadata[MULTIPROC_KEY] = str(self._process_count) + # Add the MAP_MODE metadata to the server info for the correct map mode + server_info.metadata[MAP_MODE_KEY] = MapMode.UnaryMap + + await start_async_server( + server_async=server, + sock_path=bind_address, + max_threads=self.max_threads, + cleanup_coroutines=list(), + server_info_file=None, + server_info=server_info, + ) + + aiorun.run(run_server(), use_uvloop=True) diff --git a/pynumaflow/mapper/async_server.py b/pynumaflow/mapper/async_server.py index 98b553d9..c9ef473f 100644 --- a/pynumaflow/mapper/async_server.py +++ b/pynumaflow/mapper/async_server.py @@ -82,7 +82,7 @@ def __init__( ("grpc.max_receive_message_length", self.max_message_size), ] # Get the servicer instance for the async server - self.servicer = AsyncMapServicer(handler=mapper_instance) + self.servicer = AsyncMapServicer(handler=mapper_instance, multiproc=False) def start(self) -> None: """ diff --git a/pynumaflow/mapstreamer/servicer/async_servicer.py b/pynumaflow/mapstreamer/servicer/async_servicer.py index 0fe58b66..3ecae7b4 100644 --- a/pynumaflow/mapstreamer/servicer/async_servicer.py +++ b/pynumaflow/mapstreamer/servicer/async_servicer.py @@ -59,7 +59,7 @@ async def MapFn( yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req.id) except BaseException as err: _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING) + await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False) return async def __invoke_map_stream(self, keys: list[str], req: Datum): diff --git a/pynumaflow/reducer/servicer/async_servicer.py b/pynumaflow/reducer/servicer/async_servicer.py index d8e00c41..24759847 100644 --- a/pynumaflow/reducer/servicer/async_servicer.py +++ b/pynumaflow/reducer/servicer/async_servicer.py @@ -105,7 
@@ -105,7 +105,7 @@ async def ReduceFn(
             _LOGGER.critical("Reduce Error", exc_info=True)
             # Send a context abort signal for the rpc, this is required for numa container to get
             # the correct grpc error
-            await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
+            await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, False)
 
         # send EOF to all the tasks once the request iterator is exhausted
         # This will signal the tasks to stop reading the data on their
@@ -136,7 +136,7 @@ async def ReduceFn(
             _LOGGER.critical("Reduce Error", exc_info=True)
             # Send a context abort signal for the rpc, this is required for numa container to get
             # the correct grpc error
-            await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
+            await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, False)
 
     async def IsReady(
         self, request: _empty_pb2.Empty, context: NumaflowServicerContext
diff --git a/pynumaflow/reducestreamer/servicer/async_servicer.py b/pynumaflow/reducestreamer/servicer/async_servicer.py
index 4a3498dd..6802bbe3 100644
--- a/pynumaflow/reducestreamer/servicer/async_servicer.py
+++ b/pynumaflow/reducestreamer/servicer/async_servicer.py
@@ -95,20 +95,20 @@ async def ReduceFn(
             async for msg in consumer:
                 # If the message is an exception, we raise the exception
                 if isinstance(msg, BaseException):
-                    await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
+                    await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING, False)
                     return
                 # Send window EOF response or Window result response
                 # back to the client
                 else:
                     yield msg
         except BaseException as e:
-            await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
+            await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, False)
             return
 
         # Wait for the process_input_stream task to finish for a clean exit
         try:
             await producer
         except BaseException as e:
-            await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
+            await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, False)
             return
 
     async def IsReady(
diff --git a/pynumaflow/shared/server.py b/pynumaflow/shared/server.py
index 128b8dc7..f9b74650 100644
--- a/pynumaflow/shared/server.py
+++ b/pynumaflow/shared/server.py
@@ -180,7 +180,7 @@ async def start_async_server(
     sock_path: str,
     max_threads: int,
     cleanup_coroutines: list,
-    server_info_file: str,
+    server_info_file: Optional[str] = None,
     server_info: Optional[ServerInfo] = None,
 ):
     """
@@ -190,11 +190,9 @@
     """
     await server_async.start()
 
-    if server_info is None:
-        # Create the server info file if not provided
-        server_info = ServerInfo.get_default_server_info()
     # Add the server information to the server info file
-    info_server_write(server_info=server_info, info_file=server_info_file)
+    if server_info_file:
+        info_server_write(server_info=server_info, info_file=server_info_file)
 
     # Log the server start
     _LOGGER.info(
@@ -217,7 +215,7 @@
 
 
 @contextlib.contextmanager
-def _reserve_port(port_num: int) -> Iterator[int]:
+def reserve_port(port_num: int) -> Iterator[int]:
     """Find and reserve a port for all subprocesses to use."""
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@@ -312,7 +310,10 @@ def get_exception_traceback_str(exc) -> str:
 
 
 async def handle_async_error(
-    context: NumaflowServicerContext, exception: BaseException, exception_type: str
+    context: NumaflowServicerContext,
+    exception: BaseException,
+    exception_type: str,
+    parent: bool = False,
 ):
     """
     Handle exceptions for async servers by updating the context and exiting.
@@ -322,4 +323,4 @@
     await asyncio.gather(
         context.abort(grpc.StatusCode.INTERNAL, details=err_msg), return_exceptions=True
     )
-    exit_on_error(err=err_msg, parent=False, context=context, update_context=False)
+    exit_on_error(err=err_msg, parent=parent, context=context, update_context=False)
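`reserve_port` (made public above so the async multiproc server can import it) reserves a free port with the reuse options set and yields it. A usage sketch, assuming the context manager closes its reservation socket on exit so the worker's gRPC server can re-bind the same port (the multiproc server also sets `grpc.so_reuseport`/`grpc.so_reuseaddr`):

```python
from pynumaflow.shared.server import reserve_port

# port_num=0 asks the OS for any free ephemeral port
with reserve_port(0) as port:
    bind_address = f"0.0.0.0:{port}"
print(bind_address)  # e.g. 0.0.0.0:54321
```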
diff --git a/pynumaflow/sinker/servicer/async_servicer.py b/pynumaflow/sinker/servicer/async_servicer.py
index ed9a5572..32a05c0d 100644
--- a/pynumaflow/sinker/servicer/async_servicer.py
+++ b/pynumaflow/sinker/servicer/async_servicer.py
@@ -85,7 +85,7 @@ async def SinkFn(
             # if there is an exception, we will mark all the responses as a failure
             err_msg = f"UDSinkError: {repr(err)}"
             _LOGGER.critical(err_msg, exc_info=True)
-            await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
+            await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)
             return
 
     async def __invoke_sink(
diff --git a/pynumaflow/sourcer/servicer/async_servicer.py b/pynumaflow/sourcer/servicer/async_servicer.py
index f5e4b0f9..1ecc3b48 100644
--- a/pynumaflow/sourcer/servicer/async_servicer.py
+++ b/pynumaflow/sourcer/servicer/async_servicer.py
@@ -108,7 +108,7 @@ async def ReadFn(
 
             async for resp in riter:
                 if isinstance(resp, BaseException):
-                    await handle_async_error(context, resp)
+                    await handle_async_error(context, resp, ERR_UDF_EXCEPTION_STRING, False)
                     return
                 yield _create_read_response(resp)
 
@@ -119,7 +119,7 @@
             yield _create_eot_response()
         except BaseException as err:
             _LOGGER.critical("User-Defined Source ReadFn error", exc_info=True)
-            await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
+            await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)
 
     async def __invoke_read(self, req, niter):
         """Invoke the read handler and manage the iterator."""
@@ -165,7 +165,7 @@ async def AckFn(
             yield _create_ack_response()
         except BaseException as err:
             _LOGGER.critical("User-Defined Source AckFn error", exc_info=True)
-            await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
+            await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)
 
     async def IsReady(
         self, request: _empty_pb2.Empty, context: NumaflowServicerContext
@@ -187,7 +187,7 @@ async def PendingFn(
         count = await self.__source_pending_handler()
     except BaseException as err:
         _LOGGER.critical("PendingFn Error", exc_info=True)
-        await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
+        await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)
         return
     resp = source_pb2.PendingResponse.Result(count=count.count)
     return source_pb2.PendingResponse(result=resp)
@@ -202,7 +202,7 @@ async def PartitionsFn(
         partitions = await self.__source_partitions_handler()
     except BaseException as err:
         _LOGGER.critical("PartitionsFn Error", exc_info=True)
-        await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
+        await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)
         return
     resp = source_pb2.PartitionsResponse.Result(partitions=partitions.partitions)
     return source_pb2.PartitionsResponse(result=resp)
diff --git a/tests/map/test_async_multiproc.py b/tests/map/test_async_multiproc.py
new file mode 100644
index 00000000..631e7be0
--- /dev/null
+++ b/tests/map/test_async_multiproc.py
@@ -0,0 +1,85 @@
+import uuid
+
+from pynumaflow.mapper import Datum, Messages, Message
+
+sock_prefix = f"/tmp/test_async_multiproc_map_{uuid.uuid4().hex}_"
+
+
+async def async_handler(keys, datum: Datum) -> Messages:
+    msg = (
+        f"payload:{datum.value.decode()} event_time:{datum.event_time} watermark:{datum.watermark}"
+    )
+    return Messages(Message(value=msg.encode(), keys=keys))
f"payload:{datum.value.decode()} event_time:{datum.event_time} watermark:{datum.watermark}" + ) + return Messages(Message(value=msg.encode(), keys=keys)) + + +# +# class TestAsyncMapMultiprocServer(unittest.TestCase): +# def setUp(self): +# self.base_sock_path = sock_prefix +# self.server = AsyncMapMultiprocServer( +# mapper_instance=async_handler, +# server_count=2, +# sock_path=self.base_sock_path, +# use_tcp=False, +# server_info_file=None, +# ) +# self.process = Process(target=self.server.start) +# self.process.start() +# +# # Wait for both servers to bind +# self.socket_paths = [f"{self.base_sock_path}{i}.sock" for i in range(2)] +# for path in self.socket_paths: +# for _ in range(10): +# if os.path.exists(path): +# break +# time.sleep(0.5) +# +# def tearDown(self): +# self.process.terminate() +# self.process.join() +# for path in self.socket_paths: +# try: +# os.remove(path) +# except FileNotFoundError: +# pass +# +# def test_map_fn(self): +# bind_address = f"unix://{self.socket_paths[0]}" +# request = get_test_datums() +# with grpc.insecure_channel(bind_address) as channel: +# stub = map_pb2_grpc.MapStub(channel) +# responses_iter = stub.MapFn(request_iterator=request_generator(request)) +# responses = [] +# # capture the output from the ReadFn generator and assert. +# for r in responses_iter: +# responses.append(r) +# +# # 1 handshake + 3 data responses +# self.assertEqual(4, len(responses)) +# +# self.assertTrue(responses[0].handshake.sot) +# +# idx = 1 +# while idx < len(responses): +# _id = "test-id-" + str(idx) +# self.assertEqual(_id, responses[idx].id) +# self.assertEqual( +# bytes( +# "payload:test_mock_message " +# "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00", +# encoding="utf-8", +# ), +# responses[idx].results[0].value, +# ) +# self.assertEqual(1, len(responses[idx].results)) +# idx += 1 +# +# def test_server_start(self): +# for path in self.socket_paths: +# self.assertTrue( +# os.path.exists(path), f"Server socket {path} was not created successfully" +# ) + +# +# if __name__ == "__main__": +# unittest.main() +# diff --git a/tests/source/test_async_source_err.py b/tests/source/test_async_source_err.py index c72060bd..c685b1f5 100644 --- a/tests/source/test_async_source_err.py +++ b/tests/source/test_async_source_err.py @@ -9,7 +9,6 @@ from grpc.aio._server import Server from pynumaflow import setup_logging -from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING from pynumaflow.sourcer import SourceAsyncServer from pynumaflow.proto.sourcer import source_pb2_grpc from google.protobuf import empty_pb2 as _empty_pb2 @@ -93,11 +92,7 @@ def test_read_error(self) -> None: for _ in generator_response: pass except BaseException as e: - self.assertTrue( - f"{ERR_UDF_EXCEPTION_STRING}: TypeError(" - '"handle_async_error() missing 1 required positional argument: ' - "'exception_type'\")" in e.__str__() - ) + self.assertTrue("Got a runtime error from read handler" in e.__str__()) return except grpc.RpcError as e: grpc_exception = e