diff --git a/crates/amalthea/src/fixtures/dummy_frontend.rs b/crates/amalthea/src/fixtures/dummy_frontend.rs index ca8465f0e..ba7c0b8fe 100644 --- a/crates/amalthea/src/fixtures/dummy_frontend.rs +++ b/crates/amalthea/src/fixtures/dummy_frontend.rs @@ -12,11 +12,11 @@ use crate::connection_file::ConnectionFile; use crate::session::Session; use crate::socket::socket::Socket; use crate::wire::execute_input::ExecuteInput; -use crate::wire::execute_reply::ExecuteReply; use crate::wire::execute_request::ExecuteRequest; use crate::wire::jupyter_message::JupyterMessage; use crate::wire::jupyter_message::Message; use crate::wire::jupyter_message::ProtocolMessage; +use crate::wire::jupyter_message::Status; use crate::wire::status::ExecutionState; use crate::wire::wire_message::WireMessage; @@ -160,12 +160,25 @@ impl DummyFrontend { Message::read_from_socket(&self.shell_socket).unwrap() } - /// Receive from Shell and assert ExecuteReply message - pub fn recv_shell_execute_reply(&self) -> ExecuteReply { + /// Receive from Shell and assert `ExecuteReply` message. + /// Returns `execution_count`. + pub fn recv_shell_execute_reply(&self) -> u32 { let msg = self.recv_shell(); assert_match!(msg, Message::ExecuteReply(data) => { - data.content + assert_eq!(data.content.status, Status::Ok); + data.content.execution_count + }) + } + + /// Receive from Shell and assert `ExecuteReplyException` message. + /// Returns `execution_count`. + pub fn recv_shell_execute_reply_exception(&self) -> u32 { + let msg = self.recv_shell(); + + assert_match!(msg, Message::ExecuteReplyException(data) => { + assert_eq!(data.content.status, Status::Error); + data.content.execution_count }) } diff --git a/crates/amalthea/src/socket/shell.rs b/crates/amalthea/src/socket/shell.rs index 91d5824d0..6ad2a04bd 100644 --- a/crates/amalthea/src/socket/shell.rs +++ b/crates/amalthea/src/socket/shell.rs @@ -221,6 +221,8 @@ impl Shell { let r = req.send_reply(reply, &self.socket); r }, + // FIXME: Ark already created an `ExecuteReplyException` so we use + // `.send_reply()` instead of `.send_error()`. Can we streamline this? Err(err) => req.send_reply(err, &self.socket), } } diff --git a/crates/amalthea/src/wire/error_reply.rs b/crates/amalthea/src/wire/error_reply.rs index f3edb3c4b..1d683d7ae 100644 --- a/crates/amalthea/src/wire/error_reply.rs +++ b/crates/amalthea/src/wire/error_reply.rs @@ -13,7 +13,11 @@ use crate::wire::jupyter_message::MessageType; use crate::wire::jupyter_message::Status; /// Represents an error that occurred after processing a request on a -/// ROUTER/DEALER socket +/// ROUTER/DEALER socket. +/// +/// This is the payload of a response to a request. Note that, as an exception, +/// responses to `"execute_request"` include an `execution_count` field. We +/// represent these with an `ExecuteReplyException`. #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ErrorReply { /// The status; always Error @@ -25,9 +29,11 @@ pub struct ErrorReply { } /// Note that the message type of an error reply is generally adjusted to match -/// its request type (e.g. foo_request => foo_reply) +/// its request type (e.g. foo_request => foo_reply). The message type +/// implemented here is only a placeholder and should not appear in any +/// serialized/deserialized message. impl MessageType for ErrorReply { fn message_type() -> String { - String::from("error") + String::from("*error payload*") } } diff --git a/crates/amalthea/src/wire/execute_error.rs b/crates/amalthea/src/wire/execute_error.rs index e98f37899..4763d91f2 100644 --- a/crates/amalthea/src/wire/execute_error.rs +++ b/crates/amalthea/src/wire/execute_error.rs @@ -11,7 +11,10 @@ use serde::Serialize; use crate::wire::exception::Exception; use crate::wire::jupyter_message::MessageType; -/// Represents an exception that occurred while executing code +/// Represents an exception that occurred while executing code. +/// This is sent to IOPub. Not to be confused with `ExecuteReplyException` +/// which is a special case of a message of type `"execute_reply"` sent to Shell +/// in response to an `"execute_request"`. #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ExecuteError { /// The exception that occurred during execution diff --git a/crates/amalthea/src/wire/jupyter_message.rs b/crates/amalthea/src/wire/jupyter_message.rs index f0006aa64..36eb09074 100644 --- a/crates/amalthea/src/wire/jupyter_message.rs +++ b/crates/amalthea/src/wire/jupyter_message.rs @@ -168,57 +168,92 @@ impl TryFrom<&WireMessage> for Message { /// messages that are received from the frontend. fn try_from(msg: &WireMessage) -> Result { let kind = msg.header.msg_type.clone(); + if kind == KernelInfoRequest::message_type() { return Ok(Message::KernelInfoRequest(JupyterMessage::try_from(msg)?)); - } else if kind == KernelInfoReply::message_type() { + } + if kind == KernelInfoReply::message_type() { return Ok(Message::KernelInfoReply(JupyterMessage::try_from(msg)?)); - } else if kind == IsCompleteRequest::message_type() { + } + if kind == IsCompleteRequest::message_type() { return Ok(Message::IsCompleteRequest(JupyterMessage::try_from(msg)?)); - } else if kind == IsCompleteReply::message_type() { + } + if kind == IsCompleteReply::message_type() { return Ok(Message::IsCompleteReply(JupyterMessage::try_from(msg)?)); - } else if kind == InspectRequest::message_type() { + } + if kind == InspectRequest::message_type() { return Ok(Message::InspectRequest(JupyterMessage::try_from(msg)?)); - } else if kind == InspectReply::message_type() { + } + if kind == InspectReply::message_type() { return Ok(Message::InspectReply(JupyterMessage::try_from(msg)?)); - } else if kind == ExecuteRequest::message_type() { + } + if kind == ExecuteReplyException::message_type() { + if let Ok(data) = JupyterMessage::try_from(msg) { + return Ok(Message::ExecuteReplyException(data)); + } + // else fallthrough to try `ExecuteRequest` which has the same message type + } + if kind == ExecuteRequest::message_type() { return Ok(Message::ExecuteRequest(JupyterMessage::try_from(msg)?)); - } else if kind == ExecuteReply::message_type() { + } + if kind == ExecuteReply::message_type() { return Ok(Message::ExecuteReply(JupyterMessage::try_from(msg)?)); - } else if kind == ExecuteResult::message_type() { + } + if kind == ExecuteResult::message_type() { return Ok(Message::ExecuteResult(JupyterMessage::try_from(msg)?)); - } else if kind == ExecuteInput::message_type() { + } + if kind == ExecuteError::message_type() { + return Ok(Message::ExecuteError(JupyterMessage::try_from(msg)?)); + } + if kind == ExecuteInput::message_type() { return Ok(Message::ExecuteInput(JupyterMessage::try_from(msg)?)); - } else if kind == CompleteRequest::message_type() { + } + if kind == CompleteRequest::message_type() { return Ok(Message::CompleteRequest(JupyterMessage::try_from(msg)?)); - } else if kind == CompleteReply::message_type() { + } + if kind == CompleteReply::message_type() { return Ok(Message::CompleteReply(JupyterMessage::try_from(msg)?)); - } else if kind == ShutdownRequest::message_type() { + } + if kind == ShutdownRequest::message_type() { return Ok(Message::ShutdownRequest(JupyterMessage::try_from(msg)?)); - } else if kind == KernelStatus::message_type() { + } + if kind == KernelStatus::message_type() { return Ok(Message::Status(JupyterMessage::try_from(msg)?)); - } else if kind == CommInfoRequest::message_type() { + } + if kind == CommInfoRequest::message_type() { return Ok(Message::CommInfoRequest(JupyterMessage::try_from(msg)?)); - } else if kind == CommInfoReply::message_type() { + } + if kind == CommInfoReply::message_type() { return Ok(Message::CommInfoReply(JupyterMessage::try_from(msg)?)); - } else if kind == CommOpen::message_type() { + } + if kind == CommOpen::message_type() { return Ok(Message::CommOpen(JupyterMessage::try_from(msg)?)); - } else if kind == CommWireMsg::message_type() { + } + if kind == CommWireMsg::message_type() { return Ok(Message::CommMsg(JupyterMessage::try_from(msg)?)); - } else if kind == CommClose::message_type() { + } + if kind == CommClose::message_type() { return Ok(Message::CommClose(JupyterMessage::try_from(msg)?)); - } else if kind == InterruptRequest::message_type() { + } + if kind == InterruptRequest::message_type() { return Ok(Message::InterruptRequest(JupyterMessage::try_from(msg)?)); - } else if kind == InterruptReply::message_type() { + } + if kind == InterruptReply::message_type() { return Ok(Message::InterruptReply(JupyterMessage::try_from(msg)?)); - } else if kind == InputReply::message_type() { + } + if kind == InputReply::message_type() { return Ok(Message::InputReply(JupyterMessage::try_from(msg)?)); - } else if kind == InputRequest::message_type() { + } + if kind == InputRequest::message_type() { return Ok(Message::InputRequest(JupyterMessage::try_from(msg)?)); - } else if kind == StreamOutput::message_type() { + } + if kind == StreamOutput::message_type() { return Ok(Message::StreamOutput(JupyterMessage::try_from(msg)?)); - } else if kind == UiFrontendRequest::message_type() { + } + if kind == UiFrontendRequest::message_type() { return Ok(Message::CommRequest(JupyterMessage::try_from(msg)?)); - } else if kind == JsonRpcReply::message_type() { + } + if kind == JsonRpcReply::message_type() { return Ok(Message::CommReply(JupyterMessage::try_from(msg)?)); } return Err(Error::UnknownMessageType(kind)); diff --git a/crates/ark/tests/kernel.rs b/crates/ark/tests/kernel.rs index 7b2382ed0..fe9a93230 100644 --- a/crates/ark/tests/kernel.rs +++ b/crates/ark/tests/kernel.rs @@ -1,5 +1,4 @@ use amalthea::wire::jupyter_message::Message; -use amalthea::wire::jupyter_message::Status; use amalthea::wire::kernel_info_request::KernelInfoRequest; use ark::fixtures::DummyArkFrontend; use stdext::assert_match; @@ -28,11 +27,30 @@ fn test_execute_request() { frontend.send_execute_request("42"); frontend.recv_iopub_busy(); - assert_eq!(frontend.recv_iopub_execute_input().code, "42"); + let input = frontend.recv_iopub_execute_input(); + assert_eq!(input.code, "42"); assert_eq!(frontend.recv_iopub_execute_result(), "[1] 42"); frontend.recv_iopub_idle(); - let reply = frontend.recv_shell_execute_reply(); - assert_eq!(reply.status, Status::Ok); + assert_eq!(frontend.recv_shell_execute_reply(), input.execution_count) +} + +#[test] +fn test_execute_request_error() { + let frontend = DummyArkFrontend::lock(); + + frontend.send_execute_request("stop('foobar')"); + frontend.recv_iopub_busy(); + + let input = frontend.recv_iopub_execute_input(); + assert_eq!(input.code, "stop('foobar')"); + assert!(frontend.recv_iopub_execute_error().contains("foobar")); + + frontend.recv_iopub_idle(); + + assert_eq!( + frontend.recv_shell_execute_reply_exception(), + input.execution_count + ); }