From bb9beb5df99f5451b0eb23f9179933e86a40a8ca Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 24 May 2024 16:41:22 +0200 Subject: [PATCH 1/2] Use ypywidgets for server-side input widget --- .../fps_kernels/kernel_driver/driver.py | 117 +++++++++++++++++- plugins/yjs/fps_yjs/ydocs/ynotebook.py | 9 ++ 2 files changed, 121 insertions(+), 5 deletions(-) diff --git a/plugins/kernels/fps_kernels/kernel_driver/driver.py b/plugins/kernels/fps_kernels/kernel_driver/driver.py index d1761b77..b35f299c 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/driver.py +++ b/plugins/kernels/fps_kernels/kernel_driver/driver.py @@ -1,10 +1,14 @@ +from __future__ import annotations + import asyncio import os import time import uuid +from functools import partial from typing import Any, Dict, List, Optional, cast -from pycrdt import Array, Map +from pycrdt import Array, Doc, Map, MapEvent, Text +from ypywidgets import Declare, Widget from jupyverse_api.yjs import Yjs @@ -46,6 +50,7 @@ def __init__( self.execute_requests: Dict[str, Dict[str, asyncio.Queue]] = {} self.comm_messages: asyncio.Queue = asyncio.Queue() self.tasks: List[asyncio.Task] = [] + self._background_tasks: set[asyncio.Task] = set() async def restart(self, startup_timeout: float = float("inf")) -> None: for task in self.tasks: @@ -80,13 +85,23 @@ async def connect(self, startup_timeout: float = float("inf")) -> None: def connect_channels(self, connection_cfg: Optional[cfg_t] = None): connection_cfg = connection_cfg or self.connection_cfg - self.shell_channel = connect_channel("shell", connection_cfg) + self.shell_channel = connect_channel( + "shell", + connection_cfg, + identity=self.session_id.encode(), + ) self.control_channel = connect_channel("control", connection_cfg) self.iopub_channel = connect_channel("iopub", connection_cfg) + self.stdin_channel = connect_channel( + "stdin", + connection_cfg, + identity=self.session_id.encode(), + ) def listen_channels(self): self.tasks.append(asyncio.create_task(self.listen_iopub())) self.tasks.append(asyncio.create_task(self.listen_shell())) + self.tasks.append(asyncio.create_task(self.listen_stdin())) async def stop(self) -> None: self.kernel_process.kill() @@ -111,6 +126,13 @@ async def listen_shell(self): if msg_id in self.execute_requests.keys(): self.execute_requests[msg_id]["shell_msg"].put_nowait(msg) + async def listen_stdin(self): + while True: + msg = await receive_message(self.stdin_channel, change_str_to_date=True) + msg_id = msg["parent_header"].get("msg_id") + if msg_id in self.execute_requests.keys(): + self.execute_requests[msg_id]["stdin_msg"].put_nowait(msg) + async def execute( self, ycell: Map, @@ -121,7 +143,7 @@ async def execute( if ycell["cell_type"] != "code": return ycell["execution_state"] = "busy" - content = {"code": str(ycell["source"]), "silent": False} + content = {"code": str(ycell["source"]), "silent": False, "allow_stdin": True} msg = create_message( "execute_request", content, session_id=self.session_id, msg_id=str(self.msg_cnt) ) @@ -134,6 +156,7 @@ async def execute( self.execute_requests[msg_id] = { "iopub_msg": asyncio.Queue(), "shell_msg": asyncio.Queue(), + "stdin_msg": asyncio.Queue(), } if wait_for_executed: deadline = time.time() + timeout @@ -165,9 +188,11 @@ async def execute( ycell["execution_state"] = "idle" del self.execute_requests[msg_id] else: - self.tasks.append(asyncio.create_task(self._handle_iopub(msg_id, ycell))) + stdin_task = asyncio.create_task(self._handle_stdin(msg_id, ycell)) + self.tasks.append(stdin_task) + self.tasks.append(asyncio.create_task(self._handle_iopub(msg_id, ycell, stdin_task))) - async def _handle_iopub(self, msg_id: str, ycell: Map) -> None: + async def _handle_iopub(self, msg_id: str, ycell: Map, stdin_task: asyncio.Task) -> None: while True: msg = await self.execute_requests[msg_id]["iopub_msg"].get() await self._handle_outputs(ycell["outputs"], msg) @@ -175,11 +200,67 @@ async def _handle_iopub(self, msg_id: str, ycell: Map) -> None: (msg["header"]["msg_type"] == "status" and msg["content"]["execution_state"] == "idle") ): + stdin_task.cancel() msg = await self.execute_requests[msg_id]["shell_msg"].get() with ycell.doc.transaction(): ycell["execution_count"] = msg["content"]["execution_count"] ycell["execution_state"] = "idle" + async def _handle_stdin(self, msg_id: str, ycell: Map) -> None: + while True: + msg = await self.execute_requests[msg_id]["stdin_msg"].get() + if msg["msg_type"] == "input_request": + content = msg["content"] + prompt = content["prompt"] + password = content["password"] + guid = uuid.uuid4().hex + path = f"ywidget:{guid}" + input_model = InputModel(prompt=prompt, password=password, value="") + await self.yjs.room_manager.websocket_server.get_room(path, ydoc=input_model.ydoc) + stdin_output = Map( + { + "output_type": "ywidget", + "room_id": path, + "model_name": "Input", + } + ) + outputs: Array = cast(Array, ycell.get("outputs")) + stdin_idx = len(outputs) + outputs.append(stdin_output) + input_model._attrs.observe_deep( + partial(self._handle_stdin_submission, outputs, stdin_idx, password, prompt) + ) + + def _handle_stdin_submission(self, outputs, stdin_idx, password, prompt, events): + for event in events: + if isinstance(event, MapEvent): + if event.target["submitted"]: + # send input reply to kernel + value = str(event.target["value"]) + content = {"value": value} + msg = create_message( + "input_reply", content, session_id=self.session_id, msg_id=str(self.msg_cnt) + ) + task0 = asyncio.create_task( + send_message(msg, self.stdin_channel, self.key, change_date_to_str=True) + ) + if password: + value = "········" + value = f"{prompt} {value}" + task1 = asyncio.create_task(self._change_stdin_to_stream(outputs, stdin_idx, value)) + self._background_tasks.add(task0) + self._background_tasks.add(task1) + task0.add_done_callback(self._background_tasks.discard) + task1.add_done_callback(self._background_tasks.discard) + + async def _change_stdin_to_stream(self, outputs, stdin_idx, value): + # replace stdin output with stream output + outputs[stdin_idx] = { + "output_type": "stream", + "name": "stdin", + "text": value + '\n', + } + async def _handle_comms(self) -> None: if self.yjs is None or self.yjs.widgets is None: # type: ignore return @@ -296,3 +377,29 @@ def send(self, buffers): asyncio.create_task( send_message(msg, self.shell_channel, self.key, change_date_to_str=True) ) + + +class InputModel(Widget): + submitted = Declare[bool](False) + password = Declare[bool](False) + prompt = Declare[str]("") + value = Declare[Text](Text()) + + def __init__( + self, + submitted: bool | None = None, + password: bool | None = None, + prompt: str | None = None, + value: str | None = None, + ydoc: Doc | None = None, + ) -> None: + super().__init__(ydoc) + if submitted is not None: + self.submitted = submitted + if password is not None: + self.password = password + if prompt is not None: + self.prompt = prompt + self._attrs["value"] = _value = Text() + if value is not None: + _value += value diff --git a/plugins/yjs/fps_yjs/ydocs/ynotebook.py b/plugins/yjs/fps_yjs/ydocs/ynotebook.py index 0cadd698..141929a8 100644 --- a/plugins/yjs/fps_yjs/ydocs/ynotebook.py +++ b/plugins/yjs/fps_yjs/ydocs/ynotebook.py @@ -49,6 +49,15 @@ def get_cell(self, index: int) -> Dict[str, Any]: and not cell["attachments"] ): del cell["attachments"] + outputs = cell.get("outputs", []) + del_outputs = [] + for idx, output in enumerate(outputs): + if output["output_type"] == "ywidget": + del_outputs.append(idx) + deleted = 0 + for idx in del_outputs: + del outputs[idx - deleted] + deleted += 1 return cell def append_cell(self, value: Dict[str, Any]) -> None: From 89dfa3da8c8baf2c430da286bdb15682f4b6b356 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 21 Jun 2024 16:26:52 +0200 Subject: [PATCH 2/2] Support stream output ywidget in server-side --- .../fps_kernels/kernel_driver/driver.py | 59 ++++++++++++++----- 1 file changed, 43 insertions(+), 16 deletions(-) diff --git a/plugins/kernels/fps_kernels/kernel_driver/driver.py b/plugins/kernels/fps_kernels/kernel_driver/driver.py index b35f299c..66398e1f 100644 --- a/plugins/kernels/fps_kernels/kernel_driver/driver.py +++ b/plugins/kernels/fps_kernels/kernel_driver/driver.py @@ -207,6 +207,9 @@ async def _handle_iopub(self, msg_id: str, ycell: Map, stdin_task: asyncio.Task) ycell["execution_state"] = "idle" async def _handle_stdin(self, msg_id: str, ycell: Map) -> None: + if self.yjs is None: + return + while True: msg = await self.execute_requests[msg_id]["stdin_msg"].get() if msg["msg_type"] == "input_request": @@ -303,30 +306,36 @@ async def _handle_outputs(self, outputs: Array, msg: Dict[str, Any]): content = msg["content"] if msg_type == "stream": with outputs.doc.transaction(): - # TODO: uncomment when changes are made in jupyter-ydoc text = content["text"] if text.endswith((os.linesep, "\n")): text = text[:-1] - if (not outputs) or (outputs[-1]["name"] != content["name"]): # type: ignore - outputs.append( - #Map( - # { - # "name": content["name"], - # "output_type": msg_type, - # "text": Array([content["text"]]), - # } - #) + stream_name = content["name"] + model_name = stream_name.capitalize() + "Output" + if (len(outputs) == 0) or (outputs[-1]["model_name"] != model_name): # type: ignore + guid = uuid.uuid4().hex + path = f"ywidget:{guid}" + output_model = OutputModel(text=text, stream_name=stream_name) + _attrs = output_model.ydoc.get("_attrs", type=Map) + _text = _attrs["text"] + await self.yjs.room_manager.websocket_server.get_room(path, ydoc=output_model.ydoc) + std_output = Map( { - "name": content["name"], - "output_type": msg_type, - "text": [text], + "output_type": "ywidget", + "room_id": path, + "model_name": model_name, } ) + outputs.append(std_output) else: - #outputs[-1]["text"].append(content["text"]) # type: ignore last_output = outputs[-1] - last_output["text"].append(text) # type: ignore - outputs[-1] = last_output + path = last_output["room_id"] + model_name = last_output["model_name"] + stream_name = model_name[:model_name.find("Output")] + stream_name = stream_name[0].lower() + stream_name[1:] + room = await self.yjs.room_manager.websocket_server.get_room(path) + _attrs = room.ydoc.get("_attrs", type=Map) + _text = _attrs["text"] + _text += text elif msg_type in ("display_data", "execute_result"): if "application/vnd.jupyter.ywidget-view+json" in content["data"]: # this is a collaborative widget @@ -403,3 +412,21 @@ def __init__( self._attrs["value"] = _value = Text() if value is not None: _value += value + + +class OutputModel(Widget): + text = Declare[Text](Text()) + stream_name = Declare[str]("") + + def __init__( + self, + text: str | None = None, + stream_name: str | None = None, + ydoc: Doc | None = None, + ) -> None: + super().__init__(ydoc) + if text is not None: + self._attrs["text"] = _text = Text() + _text += text + if stream_name is not None: + self.stream_name = stream_name