diff --git a/lldb/cmake/modules/LLDBConfig.cmake b/lldb/cmake/modules/LLDBConfig.cmake index 37b823feb584b..8c30b6e09d2c7 100644 --- a/lldb/cmake/modules/LLDBConfig.cmake +++ b/lldb/cmake/modules/LLDBConfig.cmake @@ -67,6 +67,7 @@ add_optional_dependency(LLDB_ENABLE_FBSDVMCORE "Enable libfbsdvmcore support in option(LLDB_USE_ENTITLEMENTS "When codesigning, use entitlements if available" ON) option(LLDB_BUILD_FRAMEWORK "Build LLDB.framework (Darwin only)" OFF) +option(LLDB_ENABLE_PROTOCOL_SERVERS "Enable protocol servers (e.g. MCP) in LLDB" ON) option(LLDB_NO_INSTALL_DEFAULT_RPATH "Disable default RPATH settings in binaries" OFF) option(LLDB_USE_SYSTEM_DEBUGSERVER "Use the system's debugserver for testing (Darwin only)." OFF) option(LLDB_SKIP_STRIP "Whether to skip stripping of binaries when installing lldb." OFF) diff --git a/lldb/include/lldb/Core/Debugger.h b/lldb/include/lldb/Core/Debugger.h index d73aba1e3ce58..0f6659d1a0bf7 100644 --- a/lldb/include/lldb/Core/Debugger.h +++ b/lldb/include/lldb/Core/Debugger.h @@ -598,6 +598,10 @@ class Debugger : public std::enable_shared_from_this, void FlushProcessOutput(Process &process, bool flush_stdout, bool flush_stderr); + void AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp); + void RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp); + lldb::ProtocolServerSP GetProtocolServer(llvm::StringRef protocol) const; + SourceManager::SourceFileCache &GetSourceFileCache() { return m_source_file_cache; } @@ -768,6 +772,8 @@ class Debugger : public std::enable_shared_from_this, mutable std::mutex m_progress_reports_mutex; /// @} + llvm::SmallVector m_protocol_servers; + std::mutex m_destroy_callback_mutex; lldb::callback_token_t m_destroy_callback_next_token = 0; struct DestroyCallbackInfo { diff --git a/lldb/include/lldb/Core/PluginManager.h b/lldb/include/lldb/Core/PluginManager.h index e7b1691031111..e50bf97189cfc 100644 --- a/lldb/include/lldb/Core/PluginManager.h +++ b/lldb/include/lldb/Core/PluginManager.h @@ -327,6 +327,17 @@ class PluginManager { static void AutoCompleteProcessName(llvm::StringRef partial_name, CompletionRequest &request); + // Protocol + static bool RegisterPlugin(llvm::StringRef name, llvm::StringRef description, + ProtocolServerCreateInstance create_callback); + + static bool UnregisterPlugin(ProtocolServerCreateInstance create_callback); + + static llvm::StringRef GetProtocolServerPluginNameAtIndex(uint32_t idx); + + static ProtocolServerCreateInstance + GetProtocolCreateCallbackForPluginName(llvm::StringRef name); + // Register Type Provider static bool RegisterPlugin(llvm::StringRef name, llvm::StringRef description, RegisterTypeBuilderCreateInstance create_callback); diff --git a/lldb/include/lldb/Core/ProtocolServer.h b/lldb/include/lldb/Core/ProtocolServer.h new file mode 100644 index 0000000000000..fafe460904323 --- /dev/null +++ b/lldb/include/lldb/Core/ProtocolServer.h @@ -0,0 +1,39 @@ +//===-- ProtocolServer.h --------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_CORE_PROTOCOLSERVER_H +#define LLDB_CORE_PROTOCOLSERVER_H + +#include "lldb/Core/PluginInterface.h" +#include "lldb/Host/Socket.h" +#include "lldb/lldb-private-interfaces.h" + +namespace lldb_private { + +class ProtocolServer : public PluginInterface { +public: + ProtocolServer() = default; + virtual ~ProtocolServer() = default; + + static lldb::ProtocolServerSP Create(llvm::StringRef name, + Debugger &debugger); + + struct Connection { + Socket::SocketProtocol protocol; + std::string name; + }; + + virtual llvm::Error Start(Connection connection) = 0; + virtual llvm::Error Stop() = 0; + + virtual Socket *GetSocket() const = 0; +}; + +} // namespace lldb_private + +#endif diff --git a/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h b/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h index 8535dfcf46da5..4face717531b1 100644 --- a/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h +++ b/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h @@ -315,6 +315,7 @@ static constexpr CommandObject::ArgumentTableEntry g_argument_table[] = { { lldb::eArgTypeCPUName, "cpu-name", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "The name of a CPU." }, { lldb::eArgTypeCPUFeatures, "cpu-features", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "The CPU feature string." }, { lldb::eArgTypeManagedPlugin, "managed-plugin", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "Plugins managed by the PluginManager" }, + { lldb::eArgTypeProtocol, "protocol", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "The name of the protocol." }, // clang-format on }; diff --git a/lldb/include/lldb/lldb-enumerations.h b/lldb/include/lldb/lldb-enumerations.h index eeb7299a354e1..69e8671b6e21b 100644 --- a/lldb/include/lldb/lldb-enumerations.h +++ b/lldb/include/lldb/lldb-enumerations.h @@ -664,6 +664,7 @@ enum CommandArgumentType { eArgTypeCPUName, eArgTypeCPUFeatures, eArgTypeManagedPlugin, + eArgTypeProtocol, eArgTypeLastArg // Always keep this entry as the last entry in this // enumeration!! }; diff --git a/lldb/include/lldb/lldb-forward.h b/lldb/include/lldb/lldb-forward.h index c664d1398f74d..558818e8e2309 100644 --- a/lldb/include/lldb/lldb-forward.h +++ b/lldb/include/lldb/lldb-forward.h @@ -164,13 +164,13 @@ class PersistentExpressionState; class Platform; class Process; class ProcessAttachInfo; -class ProcessLaunchInfo; class ProcessInfo; class ProcessInstanceInfo; class ProcessInstanceInfoMatch; class ProcessLaunchInfo; class ProcessModID; class Property; +class ProtocolServer; class Queue; class QueueImpl; class QueueItem; @@ -391,6 +391,7 @@ typedef std::shared_ptr PlatformSP; typedef std::shared_ptr ProcessSP; typedef std::shared_ptr ProcessAttachInfoSP; typedef std::shared_ptr ProcessLaunchInfoSP; +typedef std::shared_ptr ProtocolServerSP; typedef std::weak_ptr ProcessWP; typedef std::shared_ptr RegisterCheckpointSP; typedef std::shared_ptr RegisterContextSP; diff --git a/lldb/include/lldb/lldb-private-interfaces.h b/lldb/include/lldb/lldb-private-interfaces.h index d366dbd1d7832..34eaaa8e581e9 100644 --- a/lldb/include/lldb/lldb-private-interfaces.h +++ b/lldb/include/lldb/lldb-private-interfaces.h @@ -81,6 +81,8 @@ typedef lldb::PlatformSP (*PlatformCreateInstance)(bool force, typedef lldb::ProcessSP (*ProcessCreateInstance)( lldb::TargetSP target_sp, lldb::ListenerSP listener_sp, const FileSpec *crash_file_path, bool can_connect); +typedef lldb::ProtocolServerSP (*ProtocolServerCreateInstance)( + Debugger &debugger); typedef lldb::RegisterTypeBuilderSP (*RegisterTypeBuilderCreateInstance)( Target &target); typedef lldb::ScriptInterpreterSP (*ScriptInterpreterCreateInstance)( diff --git a/lldb/source/Commands/CMakeLists.txt b/lldb/source/Commands/CMakeLists.txt index 1ea51acec5f15..69e4c45f0b8e5 100644 --- a/lldb/source/Commands/CMakeLists.txt +++ b/lldb/source/Commands/CMakeLists.txt @@ -23,6 +23,7 @@ add_lldb_library(lldbCommands NO_PLUGIN_DEPENDENCIES CommandObjectPlatform.cpp CommandObjectPlugin.cpp CommandObjectProcess.cpp + CommandObjectProtocolServer.cpp CommandObjectQuit.cpp CommandObjectRegexCommand.cpp CommandObjectRegister.cpp diff --git a/lldb/source/Commands/CommandObjectProtocolServer.cpp b/lldb/source/Commands/CommandObjectProtocolServer.cpp new file mode 100644 index 0000000000000..c15e4188a92c4 --- /dev/null +++ b/lldb/source/Commands/CommandObjectProtocolServer.cpp @@ -0,0 +1,186 @@ +//===-- CommandObjectProtocolServer.cpp +//----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "CommandObjectProtocolServer.h" +#include "lldb/Core/PluginManager.h" +#include "lldb/Core/ProtocolServer.h" +#include "lldb/Host/Socket.h" +#include "lldb/Interpreter/CommandInterpreter.h" +#include "lldb/Interpreter/CommandReturnObject.h" +#include "lldb/Utility/UriParser.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/FormatAdapters.h" + +using namespace llvm; +using namespace lldb; +using namespace lldb_private; + +#define LLDB_OPTIONS_mcp +#include "CommandOptions.inc" + +static std::vector GetSupportedProtocols() { + std::vector supported_protocols; + size_t i = 0; + + for (llvm::StringRef protocol_name = + PluginManager::GetProtocolServerPluginNameAtIndex(i++); + !protocol_name.empty(); + protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) { + supported_protocols.push_back(protocol_name); + } + + return supported_protocols; +} + +static llvm::Expected> +validateConnection(llvm::StringRef conn) { + auto uri = lldb_private::URI::Parse(conn); + + if (uri && (uri->scheme == "tcp" || uri->scheme == "connect" || + !uri->hostname.empty() || uri->port)) { + return std::make_pair( + Socket::ProtocolTcp, + formatv("[{0}]:{1}", uri->hostname.empty() ? "0.0.0.0" : uri->hostname, + uri->port.value_or(0))); + } + + if (uri && (uri->scheme == "unix" || uri->scheme == "unix-connect" || + uri->path != "/")) { + return std::make_pair(Socket::ProtocolUnixDomain, uri->path.str()); + } + + return llvm::createStringError( + "Unsupported connection specifier, expected 'unix-connect:///path' or " + "'connect://[host]:port', got '%s'.", + conn.str().c_str()); +} + +class CommandObjectProtocolServerStart : public CommandObjectParsed { +public: + CommandObjectProtocolServerStart(CommandInterpreter &interpreter) + : CommandObjectParsed(interpreter, "protocol-server start", + "start protocol server", + "protocol-server start ") { + AddSimpleArgumentList(lldb::eArgTypeProtocol, eArgRepeatPlain); + AddSimpleArgumentList(lldb::eArgTypeConnectURL, eArgRepeatPlain); + } + + ~CommandObjectProtocolServerStart() override = default; + +protected: + void DoExecute(Args &args, CommandReturnObject &result) override { + if (args.GetArgumentCount() < 1) { + result.AppendError("no protocol specified"); + return; + } + + llvm::StringRef protocol = args.GetArgumentAtIndex(0); + std::vector supported_protocols = GetSupportedProtocols(); + if (llvm::find(supported_protocols, protocol) == + supported_protocols.end()) { + result.AppendErrorWithFormatv( + "unsupported protocol: {0}. Supported protocols are: {1}", protocol, + llvm::join(GetSupportedProtocols(), ", ")); + return; + } + + if (args.GetArgumentCount() < 2) { + result.AppendError("no connection specified"); + return; + } + llvm::StringRef connection_uri = args.GetArgumentAtIndex(1); + + ProtocolServerSP server_sp = GetDebugger().GetProtocolServer(protocol); + if (!server_sp) + server_sp = ProtocolServer::Create(protocol, GetDebugger()); + + auto maybeProtoclAndName = validateConnection(connection_uri); + if (auto error = maybeProtoclAndName.takeError()) { + result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error))); + return; + } + + ProtocolServer::Connection connection; + std::tie(connection.protocol, connection.name) = *maybeProtoclAndName; + + if (llvm::Error error = server_sp->Start(connection)) { + result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error))); + return; + } + + GetDebugger().AddProtocolServer(server_sp); + + if (Socket *socket = server_sp->GetSocket()) { + std::string address = + llvm::join(socket->GetListeningConnectionURI(), ", "); + result.AppendMessageWithFormatv( + "{0} server started with connection listeners: {1}", protocol, + address); + } + } +}; + +class CommandObjectProtocolServerStop : public CommandObjectParsed { +public: + CommandObjectProtocolServerStop(CommandInterpreter &interpreter) + : CommandObjectParsed(interpreter, "protocol-server stop", + "stop protocol server", + "protocol-server stop ") { + AddSimpleArgumentList(lldb::eArgTypeProtocol, eArgRepeatPlain); + } + + ~CommandObjectProtocolServerStop() override = default; + +protected: + void DoExecute(Args &args, CommandReturnObject &result) override { + if (args.GetArgumentCount() < 1) { + result.AppendError("no protocol specified"); + return; + } + + llvm::StringRef protocol = args.GetArgumentAtIndex(0); + std::vector supported_protocols = GetSupportedProtocols(); + if (llvm::find(supported_protocols, protocol) == + supported_protocols.end()) { + result.AppendErrorWithFormatv( + "unsupported protocol: {0}. Supported protocols are: {1}", protocol, + llvm::join(GetSupportedProtocols(), ", ")); + return; + } + + Debugger &debugger = GetDebugger(); + + ProtocolServerSP server_sp = debugger.GetProtocolServer(protocol); + if (!server_sp) { + result.AppendError( + llvm::formatv("no {0} protocol server running", protocol).str()); + return; + } + + if (llvm::Error error = server_sp->Stop()) { + result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error))); + return; + } + + debugger.RemoveProtocolServer(server_sp); + } +}; + +CommandObjectProtocolServer::CommandObjectProtocolServer( + CommandInterpreter &interpreter) + : CommandObjectMultiword(interpreter, "protocol-server", + "Start and stop a protocol server.", + "protocol-server") { + LoadSubCommand("start", CommandObjectSP(new CommandObjectProtocolServerStart( + interpreter))); + LoadSubCommand("stop", CommandObjectSP( + new CommandObjectProtocolServerStop(interpreter))); +} + +CommandObjectProtocolServer::~CommandObjectProtocolServer() = default; diff --git a/lldb/source/Commands/CommandObjectProtocolServer.h b/lldb/source/Commands/CommandObjectProtocolServer.h new file mode 100644 index 0000000000000..3591216b014cb --- /dev/null +++ b/lldb/source/Commands/CommandObjectProtocolServer.h @@ -0,0 +1,25 @@ +//===-- CommandObjectProtocolServer.h +//------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_SOURCE_COMMANDS_COMMANDOBJECTPROTOCOLSERVER_H +#define LLDB_SOURCE_COMMANDS_COMMANDOBJECTPROTOCOLSERVER_H + +#include "lldb/Interpreter/CommandObjectMultiword.h" + +namespace lldb_private { + +class CommandObjectProtocolServer : public CommandObjectMultiword { +public: + CommandObjectProtocolServer(CommandInterpreter &interpreter); + ~CommandObjectProtocolServer() override; +}; + +} // namespace lldb_private + +#endif // LLDB_SOURCE_COMMANDS_COMMANDOBJECTMCP_H diff --git a/lldb/source/Core/CMakeLists.txt b/lldb/source/Core/CMakeLists.txt index d6b75bca7f2d6..df35bd5c025f3 100644 --- a/lldb/source/Core/CMakeLists.txt +++ b/lldb/source/Core/CMakeLists.txt @@ -46,6 +46,7 @@ add_lldb_library(lldbCore NO_PLUGIN_DEPENDENCIES Opcode.cpp PluginManager.cpp Progress.cpp + ProtocolServer.cpp Statusline.cpp RichManglingContext.cpp SearchFilter.cpp diff --git a/lldb/source/Core/Debugger.cpp b/lldb/source/Core/Debugger.cpp index 81037d3def811..2bc9c7ead79d3 100644 --- a/lldb/source/Core/Debugger.cpp +++ b/lldb/source/Core/Debugger.cpp @@ -16,6 +16,7 @@ #include "lldb/Core/ModuleSpec.h" #include "lldb/Core/PluginManager.h" #include "lldb/Core/Progress.h" +#include "lldb/Core/ProtocolServer.h" #include "lldb/Core/StreamAsynchronousIO.h" #include "lldb/Core/Telemetry.h" #include "lldb/DataFormatters/DataVisualization.h" @@ -2363,3 +2364,26 @@ llvm::ThreadPoolInterface &Debugger::GetThreadPool() { "Debugger::GetThreadPool called before Debugger::Initialize"); return *g_thread_pool; } + +void Debugger::AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp) { + assert(protocol_server_sp && + GetProtocolServer(protocol_server_sp->GetPluginName()) == nullptr); + m_protocol_servers.push_back(protocol_server_sp); +} + +void Debugger::RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp) { + auto it = llvm::find(m_protocol_servers, protocol_server_sp); + if (it != m_protocol_servers.end()) + m_protocol_servers.erase(it); +} + +lldb::ProtocolServerSP +Debugger::GetProtocolServer(llvm::StringRef protocol) const { + for (ProtocolServerSP protocol_server_sp : m_protocol_servers) { + if (!protocol_server_sp) + continue; + if (protocol_server_sp->GetPluginName() == protocol) + return protocol_server_sp; + } + return nullptr; +} diff --git a/lldb/source/Core/PluginManager.cpp b/lldb/source/Core/PluginManager.cpp index 5d44434033c55..a59a390e40bb6 100644 --- a/lldb/source/Core/PluginManager.cpp +++ b/lldb/source/Core/PluginManager.cpp @@ -1006,6 +1006,38 @@ void PluginManager::AutoCompleteProcessName(llvm::StringRef name, } } +#pragma mark ProtocolServer + +typedef PluginInstance ProtocolServerInstance; +typedef PluginInstances ProtocolServerInstances; + +static ProtocolServerInstances &GetProtocolServerInstances() { + static ProtocolServerInstances g_instances; + return g_instances; +} + +bool PluginManager::RegisterPlugin( + llvm::StringRef name, llvm::StringRef description, + ProtocolServerCreateInstance create_callback) { + return GetProtocolServerInstances().RegisterPlugin(name, description, + create_callback); +} + +bool PluginManager::UnregisterPlugin( + ProtocolServerCreateInstance create_callback) { + return GetProtocolServerInstances().UnregisterPlugin(create_callback); +} + +llvm::StringRef +PluginManager::GetProtocolServerPluginNameAtIndex(uint32_t idx) { + return GetProtocolServerInstances().GetNameAtIndex(idx); +} + +ProtocolServerCreateInstance +PluginManager::GetProtocolCreateCallbackForPluginName(llvm::StringRef name) { + return GetProtocolServerInstances().GetCallbackForName(name); +} + #pragma mark RegisterTypeBuilder struct RegisterTypeBuilderInstance diff --git a/lldb/source/Core/ProtocolServer.cpp b/lldb/source/Core/ProtocolServer.cpp new file mode 100644 index 0000000000000..d57a047afa7b2 --- /dev/null +++ b/lldb/source/Core/ProtocolServer.cpp @@ -0,0 +1,21 @@ +//===-- ProtocolServer.cpp ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "lldb/Core/ProtocolServer.h" +#include "lldb/Core/PluginManager.h" + +using namespace lldb_private; +using namespace lldb; + +ProtocolServerSP ProtocolServer::Create(llvm::StringRef name, + Debugger &debugger) { + if (ProtocolServerCreateInstance create_callback = + PluginManager::GetProtocolCreateCallbackForPluginName(name)) + return create_callback(debugger); + return nullptr; +} diff --git a/lldb/source/Interpreter/CommandInterpreter.cpp b/lldb/source/Interpreter/CommandInterpreter.cpp index 4f9ae104dedea..00c3472444d2e 100644 --- a/lldb/source/Interpreter/CommandInterpreter.cpp +++ b/lldb/source/Interpreter/CommandInterpreter.cpp @@ -30,6 +30,7 @@ #include "Commands/CommandObjectPlatform.h" #include "Commands/CommandObjectPlugin.h" #include "Commands/CommandObjectProcess.h" +#include "Commands/CommandObjectProtocolServer.h" #include "Commands/CommandObjectQuit.h" #include "Commands/CommandObjectRegexCommand.h" #include "Commands/CommandObjectRegister.h" @@ -574,6 +575,7 @@ void CommandInterpreter::LoadCommandDictionary() { REGISTER_COMMAND_OBJECT("platform", CommandObjectPlatform); REGISTER_COMMAND_OBJECT("plugin", CommandObjectPlugin); REGISTER_COMMAND_OBJECT("process", CommandObjectMultiwordProcess); + REGISTER_COMMAND_OBJECT("protocol-server", CommandObjectProtocolServer); REGISTER_COMMAND_OBJECT("quit", CommandObjectQuit); REGISTER_COMMAND_OBJECT("register", CommandObjectRegister); REGISTER_COMMAND_OBJECT("scripting", CommandObjectMultiwordScripting); diff --git a/lldb/source/Plugins/CMakeLists.txt b/lldb/source/Plugins/CMakeLists.txt index 854f589f45ae0..08f444e7b15e8 100644 --- a/lldb/source/Plugins/CMakeLists.txt +++ b/lldb/source/Plugins/CMakeLists.txt @@ -27,6 +27,10 @@ add_subdirectory(TraceExporter) add_subdirectory(TypeSystem) add_subdirectory(UnwindAssembly) +if(LLDB_ENABLE_PROTOCOL_SERVERS) + add_subdirectory(Protocol) +endif() + set(LLDB_STRIPPED_PLUGINS) get_property(LLDB_ALL_PLUGINS GLOBAL PROPERTY LLDB_PLUGINS) diff --git a/lldb/source/Plugins/Protocol/CMakeLists.txt b/lldb/source/Plugins/Protocol/CMakeLists.txt new file mode 100644 index 0000000000000..93b347d4cc9d8 --- /dev/null +++ b/lldb/source/Plugins/Protocol/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(MCP) diff --git a/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt b/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt new file mode 100644 index 0000000000000..db31a7a69cb33 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt @@ -0,0 +1,13 @@ +add_lldb_library(lldbPluginProtocolServerMCP PLUGIN + MCPError.cpp + Protocol.cpp + ProtocolServerMCP.cpp + Tool.cpp + + LINK_COMPONENTS + Support + + LINK_LIBS + lldbHost + lldbUtility +) diff --git a/lldb/source/Plugins/Protocol/MCP/MCPError.cpp b/lldb/source/Plugins/Protocol/MCP/MCPError.cpp new file mode 100644 index 0000000000000..5ed850066b659 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/MCPError.cpp @@ -0,0 +1,34 @@ +//===-- MCPError.cpp ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "MCPError.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace lldb_private::mcp { + +char MCPError::ID; + +MCPError::MCPError(std::string message, int64_t error_code) + : m_message(message), m_error_code(error_code) {} + +void MCPError::log(llvm::raw_ostream &OS) const { OS << m_message; } + +std::error_code MCPError::convertToErrorCode() const { + return llvm::inconvertibleErrorCode(); +} + +protocol::Error MCPError::toProtcolError() const { + protocol::Error error; + error.error.code = m_error_code; + error.error.message = m_message; + return error; +} + +} // namespace lldb_private::mcp diff --git a/lldb/source/Plugins/Protocol/MCP/MCPError.h b/lldb/source/Plugins/Protocol/MCP/MCPError.h new file mode 100644 index 0000000000000..2a76a7b087e20 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/MCPError.h @@ -0,0 +1,33 @@ +//===-- MCPError.h --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Protocol.h" +#include "llvm/Support/Error.h" +#include + +namespace lldb_private::mcp { + +class MCPError : public llvm::ErrorInfo { +public: + static char ID; + + MCPError(std::string message, int64_t error_code); + + void log(llvm::raw_ostream &OS) const override; + std::error_code convertToErrorCode() const override; + + const std::string &getMessage() const { return m_message; } + + protocol::Error toProtcolError() const; + +private: + std::string m_message; + int64_t m_error_code; +}; + +} // namespace lldb_private::mcp diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.cpp b/lldb/source/Plugins/Protocol/MCP/Protocol.cpp new file mode 100644 index 0000000000000..d66c931a0b284 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Protocol.cpp @@ -0,0 +1,214 @@ +//===- Protocol.cpp -------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Protocol.h" +#include "llvm/Support/JSON.h" + +using namespace llvm; + +namespace lldb_private::mcp::protocol { + +static bool mapRaw(const json::Value &Params, StringLiteral Prop, + std::optional &V, json::Path P) { + const auto *O = Params.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + const json::Value *E = O->get(Prop); + if (E) + V = std::move(*E); + return true; +} + +llvm::json::Value toJSON(const Request &R) { + json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}, {"method", R.method}}; + if (R.params) + Result.insert({"params", R.params}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, Request &R, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + if (!O || !O.map("id", R.id) || !O.map("method", R.method)) + return false; + return mapRaw(V, "params", R.params, P); +} + +llvm::json::Value toJSON(const ErrorInfo &EI) { + llvm::json::Object Result{{"code", EI.code}, {"message", EI.message}}; + if (EI.data) + Result.insert({"data", EI.data}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, ErrorInfo &EI, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("code", EI.code) && O.map("message", EI.message) && + O.mapOptional("data", EI.data); +} + +llvm::json::Value toJSON(const Error &E) { + return json::Object{{"jsonrpc", "2.0"}, {"id", E.id}, {"error", E.error}}; +} + +bool fromJSON(const llvm::json::Value &V, Error &E, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("id", E.id) && O.map("error", E.error); +} + +llvm::json::Value toJSON(const Response &R) { + llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}}; + if (R.result) + Result.insert({"result", R.result}); + if (R.error) + Result.insert({"error", R.error}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, Response &R, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + if (!O || !O.map("id", R.id) || !O.map("error", R.error)) + return false; + return mapRaw(V, "result", R.result, P); +} + +llvm::json::Value toJSON(const Notification &N) { + llvm::json::Object Result{{"jsonrpc", "2.0"}, {"method", N.method}}; + if (N.params) + Result.insert({"params", N.params}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, Notification &N, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + if (!O || !O.map("method", N.method)) + return false; + auto *Obj = V.getAsObject(); + if (!Obj) + return false; + if (auto *Params = Obj->get("params")) + N.params = *Params; + return true; +} + +llvm::json::Value toJSON(const ToolCapability &TC) { + return llvm::json::Object{{"listChanged", TC.listChanged}}; +} + +bool fromJSON(const llvm::json::Value &V, ToolCapability &TC, + llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("listChanged", TC.listChanged); +} + +llvm::json::Value toJSON(const Capabilities &C) { + return llvm::json::Object{{"tools", C.tools}}; +} + +bool fromJSON(const llvm::json::Value &V, Capabilities &C, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("tools", C.tools); +} + +llvm::json::Value toJSON(const TextContent &TC) { + return llvm::json::Object{{"type", "text"}, {"text", TC.text}}; +} + +bool fromJSON(const llvm::json::Value &V, TextContent &TC, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("text", TC.text); +} + +llvm::json::Value toJSON(const TextResult &TR) { + return llvm::json::Object{{"content", TR.content}, {"isError", TR.isError}}; +} + +bool fromJSON(const llvm::json::Value &V, TextResult &TR, llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("content", TR.content) && O.map("isError", TR.isError); +} + +llvm::json::Value toJSON(const ToolDefinition &TD) { + llvm::json::Object Result{{"name", TD.name}}; + if (TD.description) + Result.insert({"description", TD.description}); + if (TD.inputSchema) + Result.insert({"inputSchema", TD.inputSchema}); + return Result; +} + +bool fromJSON(const llvm::json::Value &V, ToolDefinition &TD, + llvm::json::Path P) { + + llvm::json::ObjectMapper O(V, P); + if (!O || !O.map("name", TD.name) || + !O.mapOptional("description", TD.description)) + return false; + return mapRaw(V, "inputSchema", TD.inputSchema, P); +} + +llvm::json::Value toJSON(const Message &M) { + return std::visit([](auto &M) { return toJSON(M); }, M); +} + +bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) { + const auto *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + + if (const json::Value *V = O->get("jsonrpc")) { + if (V->getAsString().value_or("") != "2.0") { + P.report("unsupported JSON RPC version"); + return false; + } + } else { + P.report("not a valid JSON RPC message"); + return false; + } + + // A message without an ID is a Notification. + if (!O->get("id")) { + protocol::Notification N; + if (!fromJSON(V, N, P)) + return false; + M = std::move(N); + return true; + } + + if (O->get("error")) { + protocol::Error E; + if (!fromJSON(V, E, P)) + return false; + M = std::move(E); + return true; + } + + if (O->get("result")) { + protocol::Response R; + if (!fromJSON(V, R, P)) + return false; + M = std::move(R); + return true; + } + + if (O->get("method")) { + protocol::Request R; + if (!fromJSON(V, R, P)) + return false; + M = std::move(R); + return true; + } + + P.report("unrecognized message type"); + return false; +} + +} // namespace lldb_private::mcp::protocol diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.h b/lldb/source/Plugins/Protocol/MCP/Protocol.h new file mode 100644 index 0000000000000..e315899406573 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Protocol.h @@ -0,0 +1,128 @@ +//===- Protocol.h ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains POD structs based on the MCP specification at +// https://github.com/modelcontextprotocol/modelcontextprotocol/blob/main/schema/2024-11-05/schema.json +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOL_H +#define LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOL_H + +#include "llvm/Support/JSON.h" +#include +#include +#include + +namespace lldb_private::mcp::protocol { + +static llvm::StringLiteral kVersion = "2024-11-05"; + +/// A request that expects a response. +struct Request { + uint64_t id = 0; + std::string method; + std::optional params; +}; + +llvm::json::Value toJSON(const Request &); +bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path); + +struct ErrorInfo { + int64_t code = 0; + std::string message; + std::optional data; +}; + +llvm::json::Value toJSON(const ErrorInfo &); +bool fromJSON(const llvm::json::Value &, ErrorInfo &, llvm::json::Path); + +struct Error { + uint64_t id = 0; + ErrorInfo error; +}; + +llvm::json::Value toJSON(const Error &); +bool fromJSON(const llvm::json::Value &, Error &, llvm::json::Path); + +struct Response { + uint64_t id = 0; + std::optional result; + std::optional error; +}; + +llvm::json::Value toJSON(const Response &); +bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path); + +/// A notification which does not expect a response. +struct Notification { + std::string method; + std::optional params; +}; + +llvm::json::Value toJSON(const Notification &); +bool fromJSON(const llvm::json::Value &, Notification &, llvm::json::Path); + +struct ToolCapability { + /// Whether this server supports notifications for changes to the tool list. + bool listChanged = false; +}; + +llvm::json::Value toJSON(const ToolCapability &); +bool fromJSON(const llvm::json::Value &, ToolCapability &, llvm::json::Path); + +/// Capabilities that a server may support. Known capabilities are defined here, +/// in this schema, but this is not a closed set: any server can define its own, +/// additional capabilities. +struct Capabilities { + /// Present if the server offers any tools to call. + ToolCapability tools; +}; + +llvm::json::Value toJSON(const Capabilities &); +bool fromJSON(const llvm::json::Value &, Capabilities &, llvm::json::Path); + +/// Text provided to or from an LLM. +struct TextContent { + /// The text content of the message. + std::string text; +}; + +llvm::json::Value toJSON(const TextContent &); +bool fromJSON(const llvm::json::Value &, TextContent &, llvm::json::Path); + +struct TextResult { + std::vector content; + bool isError = false; +}; + +llvm::json::Value toJSON(const TextResult &); +bool fromJSON(const llvm::json::Value &, TextResult &, llvm::json::Path); + +struct ToolDefinition { + /// Unique identifier for the tool. + std::string name; + + /// Human-readable description. + std::optional description; + + // JSON Schema for the tool's parameters. + std::optional inputSchema; +}; + +llvm::json::Value toJSON(const ToolDefinition &); +bool fromJSON(const llvm::json::Value &, ToolDefinition &, llvm::json::Path); + +using Message = std::variant; + +bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); +llvm::json::Value toJSON(const Message &); + +} // namespace lldb_private::mcp::protocol + +#endif diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp new file mode 100644 index 0000000000000..2e0557c19a732 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -0,0 +1,309 @@ +//===- ProtocolServerMCP.cpp ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "ProtocolServerMCP.h" +#include "MCPError.h" +#include "lldb/Core/PluginManager.h" +#include "lldb/Utility/LLDBLog.h" +#include "lldb/Utility/Log.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Threading.h" +#include +#include + +using namespace lldb_private; +using namespace lldb_private::mcp; +using namespace llvm; + +LLDB_PLUGIN_DEFINE(ProtocolServerMCP) + +ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger) + : ProtocolServer(), m_debugger(debugger) { + AddRequestHandler("initialize", + std::bind(&ProtocolServerMCP::InitializeHandler, this, + std::placeholders::_1)); + AddRequestHandler("tools/list", + std::bind(&ProtocolServerMCP::ToolsListHandler, this, + std::placeholders::_1)); + AddRequestHandler("tools/call", + std::bind(&ProtocolServerMCP::ToolsCallHandler, this, + std::placeholders::_1)); + AddNotificationHandler( + "notifications/initialized", [](const protocol::Notification &) { + LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete"); + }); + AddTool(std::make_unique( + "lldb_command", "Run an lldb command.", m_debugger)); +} + +ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } + +void ProtocolServerMCP::Initialize() { + PluginManager::RegisterPlugin(GetPluginNameStatic(), + GetPluginDescriptionStatic(), CreateInstance); +} + +void ProtocolServerMCP::Terminate() { + PluginManager::UnregisterPlugin(CreateInstance); +} + +lldb::ProtocolServerSP ProtocolServerMCP::CreateInstance(Debugger &debugger) { + return std::make_shared(debugger); +} + +llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { + return "MCP Server."; +} + +llvm::Expected +ProtocolServerMCP::Handle(protocol::Request request) { + auto it = m_request_handlers.find(request.method); + if (it != m_request_handlers.end()) { + llvm::Expected response = it->second(request); + if (!response) + return response; + response->id = request.id; + return *response; + } + + return make_error( + llvm::formatv("no handler for request: {0}", request.method).str(), 1); +} + +void ProtocolServerMCP::Handle(protocol::Notification notification) { + auto it = m_notification_handlers.find(notification.method); + if (it != m_notification_handlers.end()) { + it->second(notification); + return; + } + + LLDB_LOG(GetLog(LLDBLog::Host), "MPC notification: {0} ({1})", + notification.method, notification.params); +} + +void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { + LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected", + m_clients.size() + 1); + + lldb::IOObjectSP io_sp = std::move(socket); + auto transport_sp = std::make_shared(io_sp, io_sp); + + Status status; + auto read_handle_up = m_loop.RegisterReadObject( + io_sp, + [=](MainLoopBase &) { + if (llvm::Error err = HandleData(*transport_sp)) { + LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(err), "{0}"); + } + }, + status); + if (status.Fail()) + return; + + m_clients.emplace_back(io_sp, std::move(read_handle_up)); +} + +llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { + std::lock_guard guard(m_server_mutex); + + if (m_running) + return llvm::createStringError("server already running"); + + Status status; + m_listener = Socket::Create(connection.protocol, status); + if (status.Fail()) + return status.takeError(); + + status = m_listener->Listen(connection.name, /*backlog=*/5); + if (status.Fail()) + return status.takeError(); + + std::string address = + llvm::join(m_listener->GetListeningConnectionURI(), ", "); + auto handles = + m_listener->Accept(m_loop, std::bind(&ProtocolServerMCP::AcceptCallback, + this, std::placeholders::_1)); + if (llvm::Error error = handles.takeError()) + return error; + + m_listen_handlers = std::move(*handles); + m_loop_thread = std::thread([=] { + llvm::set_thread_name( + llvm::formatv("debugger-{0}.mcp.runloop", m_debugger.GetID())); + m_loop.Run(); + }); + + return llvm::Error::success(); +} + +llvm::Error ProtocolServerMCP::Stop() { + { + std::lock_guard guard(m_server_mutex); + m_running = false; + } + + // Stop the main loop. + m_loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + + // Wait for the main loop to exit. + if (m_loop_thread.joinable()) + m_loop_thread.join(); + + { + std::lock_guard guard(m_server_mutex); + m_listener.reset(); + m_listen_handlers.clear(); + m_clients.clear(); + } + + return llvm::Error::success(); +} + +llvm::Error ProtocolServerMCP::HandleData(JSONTransport &transport) { + llvm::Expected maybe_message = + transport.Read(std::chrono::seconds(1)); + if (maybe_message.errorIsA() || + maybe_message.errorIsA() || + maybe_message.errorIsA()) { + consumeError(maybe_message.takeError()); + return llvm::Error::success(); + } + + if (llvm::Error err = maybe_message.takeError()) + return err; + + protocol::Message &message = *maybe_message; + if (const protocol::Request *request = + std::get_if(&message)) { + llvm::Expected maybe_response = Handle(*request); + + // Handle failures. + if (!maybe_response) { + protocol::Error protocol_error; + llvm::handleAllErrors( + maybe_response.takeError(), + [&](const MCPError &err) { protocol_error = err.toProtcolError(); }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.error.code = -1; + protocol_error.error.message = err.message(); + }); + protocol_error.id = request->id; + if (llvm::Error err = transport.Write(protocol_error)) + return err; + + return llvm::Error::success(); + } + + // Handle success. + if (llvm::Error err = transport.Write(*maybe_response)) + return err; + + return llvm::Error::success(); + } + + if (const protocol::Notification *notification = + std::get_if(&message)) { + Handle(*notification); + return llvm::Error::success(); + } + + if (std::get_if(&message)) + return llvm::createStringError("unexpected MCP message: error"); + + if (std::get_if(&message)) + return llvm::createStringError("unexpected MCP message: response"); + + llvm_unreachable("all message types handled"); +} + +protocol::Capabilities ProtocolServerMCP::GetCapabilities() { + protocol::Capabilities capabilities; + capabilities.tools.listChanged = true; + return capabilities; +} + +void ProtocolServerMCP::AddTool(std::unique_ptr tool) { + std::lock_guard guard(m_server_mutex); + + if (!tool) + return; + m_tools[tool->GetName()] = std::move(tool); +} + +void ProtocolServerMCP::AddRequestHandler(llvm::StringRef method, + RequestHandler handler) { + std::lock_guard guard(m_server_mutex); + m_request_handlers[method] = std::move(handler); +} + +void ProtocolServerMCP::AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler) { + std::lock_guard guard(m_server_mutex); + m_notification_handlers[method] = std::move(handler); +} + +llvm::Expected +ProtocolServerMCP::InitializeHandler(const protocol::Request &request) { + protocol::Response response; + response.result.emplace(llvm::json::Object{ + {"protocolVersion", protocol::kVersion}, + {"capabilities", GetCapabilities()}, + {"serverInfo", + llvm::json::Object{{"name", kName}, {"version", kVersion}}}}); + return response; +} + +llvm::Expected +ProtocolServerMCP::ToolsListHandler(const protocol::Request &request) { + protocol::Response response; + + llvm::json::Array tools; + for (const auto &tool : m_tools) + tools.emplace_back(toJSON(tool.second->GetDefinition())); + + response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}}); + + return response; +} + +llvm::Expected +ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) { + protocol::Response response; + + if (!request.params) + return llvm::createStringError("no tool parameters"); + + const json::Object *param_obj = request.params->getAsObject(); + if (!param_obj) + return llvm::createStringError("no tool parameters"); + + const json::Value *name = param_obj->get("name"); + if (!name) + return llvm::createStringError("no tool name"); + + llvm::StringRef tool_name = name->getAsString().value_or(""); + if (tool_name.empty()) + return llvm::createStringError("no tool name"); + + auto it = m_tools.find(tool_name); + if (it == m_tools.end()) + return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); + + const json::Value *args = param_obj->get("arguments"); + if (!args) + return llvm::createStringError("no tool arguments"); + + llvm::Expected text_result = it->second->Call(*args); + if (!text_result) + return text_result.takeError(); + + response.result.emplace(toJSON(*text_result)); + + return response; +} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h new file mode 100644 index 0000000000000..a7a47fbe3500e --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -0,0 +1,94 @@ +//===- ProtocolServerMCP.h ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOLSERVERMCP_H +#define LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOLSERVERMCP_H + +#include "Protocol.h" +#include "Tool.h" +#include "lldb/Core/ProtocolServer.h" +#include "lldb/Host/JSONTransport.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Host/Socket.h" +#include "llvm/ADT/StringMap.h" +#include + +namespace lldb_private::mcp { + +class ProtocolServerMCP : public ProtocolServer { +public: + ProtocolServerMCP(Debugger &debugger); + virtual ~ProtocolServerMCP() override; + + virtual llvm::Error Start(ProtocolServer::Connection connection) override; + virtual llvm::Error Stop() override; + + static void Initialize(); + static void Terminate(); + + static llvm::StringRef GetPluginNameStatic() { return "MCP"; } + static llvm::StringRef GetPluginDescriptionStatic(); + + static lldb::ProtocolServerSP CreateInstance(Debugger &debugger); + + llvm::StringRef GetPluginName() override { return GetPluginNameStatic(); } + + Socket *GetSocket() const override { return m_listener.get(); } + +protected: + using RequestHandler = std::function( + const protocol::Request &)>; + using NotificationHandler = + std::function; + + void AddTool(std::unique_ptr tool); + void AddRequestHandler(llvm::StringRef method, RequestHandler handler); + void AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler); + +private: + void AcceptCallback(std::unique_ptr socket); + + llvm::Error HandleData(JSONTransport &transport); + + llvm::Expected Handle(protocol::Request request); + void Handle(protocol::Notification notification); + + llvm::Expected + InitializeHandler(const protocol::Request &); + llvm::Expected + ToolsListHandler(const protocol::Request &); + llvm::Expected + ToolsCallHandler(const protocol::Request &); + + protocol::Capabilities GetCapabilities(); + + llvm::StringLiteral kName = "lldb-mcp"; + llvm::StringLiteral kVersion = "0.1.0"; + + Debugger &m_debugger; + + bool m_running = false; + + MainLoop m_loop; + std::thread m_loop_thread; + + std::unique_ptr m_listener; + std::vector m_listen_handlers; + std::vector> + m_clients; + + std::mutex m_server_mutex; + llvm::StringMap> m_tools; + + llvm::StringMap m_request_handlers; + llvm::StringMap m_notification_handlers; +}; +} // namespace lldb_private::mcp + +#endif diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp new file mode 100644 index 0000000000000..de8fcc8f3cb4c --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -0,0 +1,81 @@ +//===- Tool.cpp -----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Tool.h" +#include "lldb/Interpreter/CommandInterpreter.h" +#include "lldb/Interpreter/CommandReturnObject.h" + +using namespace lldb_private::mcp; +using namespace llvm; + +struct LLDBCommandToolArguments { + std::string arguments; +}; + +bool fromJSON(const llvm::json::Value &V, LLDBCommandToolArguments &A, + llvm::json::Path P) { + llvm::json::ObjectMapper O(V, P); + return O && O.map("arguments", A.arguments); +} + +Tool::Tool(std::string name, std::string description) + : m_name(std::move(name)), m_description(std::move(description)) {} + +protocol::ToolDefinition Tool::GetDefinition() const { + protocol::ToolDefinition definition; + definition.name = m_name; + definition.description.emplace(m_description); + + if (std::optional input_schema = GetSchema()) + definition.inputSchema = *input_schema; + + return definition; +} + +LLDBCommandTool::LLDBCommandTool(std::string name, std::string description, + Debugger &debugger) + : Tool(std::move(name), std::move(description)), m_debugger(debugger) {} + +llvm::Expected +LLDBCommandTool::Call(const llvm::json::Value &args) { + llvm::json::Path::Root root; + + LLDBCommandToolArguments arguments; + if (!fromJSON(args, arguments, root)) + return root.getError(); + + // FIXME: Disallow certain commands and their aliases. + CommandReturnObject result(/*colors=*/false); + m_debugger.GetCommandInterpreter().HandleCommand(arguments.arguments.c_str(), + eLazyBoolYes, result); + + std::string output; + llvm::StringRef output_str = result.GetOutputString(); + if (!output_str.empty()) + output += output_str.str(); + + std::string err_str = result.GetErrorString(); + if (!err_str.empty()) { + if (!output.empty()) + output += '\n'; + output += err_str; + } + + mcp::protocol::TextResult text_result; + text_result.content.emplace_back(mcp::protocol::TextContent{{output}}); + text_result.isError = !result.Succeeded(); + return text_result; +} + +std::optional LLDBCommandTool::GetSchema() const { + llvm::json::Object str_type{{"type", "string"}}; + llvm::json::Object properties{{"arguments", std::move(str_type)}}; + llvm::json::Object schema{{"type", "object"}, + {"properties", std::move(properties)}}; + return schema; +} diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h new file mode 100644 index 0000000000000..57a5125813b76 --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Tool.h @@ -0,0 +1,56 @@ +//===- Tool.h -------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H +#define LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H + +#include "Protocol.h" +#include "lldb/Core/Debugger.h" +#include "llvm/Support/JSON.h" +#include + +namespace lldb_private::mcp { + +class Tool { +public: + Tool(std::string name, std::string description); + virtual ~Tool() = default; + + virtual llvm::Expected + Call(const llvm::json::Value &args) = 0; + + virtual std::optional GetSchema() const { + return std::nullopt; + } + + protocol::ToolDefinition GetDefinition() const; + + const std::string &GetName() { return m_name; } + +private: + std::string m_name; + std::string m_description; +}; + +class LLDBCommandTool : public mcp::Tool { +public: + LLDBCommandTool(std::string name, std::string description, + Debugger &debugger); + ~LLDBCommandTool() = default; + + virtual llvm::Expected + Call(const llvm::json::Value &args) override; + + virtual std::optional GetSchema() const override; + +private: + Debugger &m_debugger; +}; +} // namespace lldb_private::mcp + +#endif diff --git a/lldb/unittests/CMakeLists.txt b/lldb/unittests/CMakeLists.txt index 6eaaa4f4c8c98..b48b9bafe3bc3 100644 --- a/lldb/unittests/CMakeLists.txt +++ b/lldb/unittests/CMakeLists.txt @@ -78,6 +78,10 @@ add_subdirectory(Utility) add_subdirectory(Thread) add_subdirectory(ValueObject) +if(LLDB_ENABLE_PROTOCOL_SERVERS) + add_subdirectory(Protocol) +endif() + if(LLDB_CAN_USE_DEBUGSERVER AND LLDB_TOOL_DEBUGSERVER_BUILD AND NOT LLDB_USE_SYSTEM_DEBUGSERVER) add_subdirectory(debugserver) endif() diff --git a/lldb/unittests/DAP/ProtocolTypesTest.cpp b/lldb/unittests/DAP/ProtocolTypesTest.cpp index f2a23db346565..f8e49750a83cc 100644 --- a/lldb/unittests/DAP/ProtocolTypesTest.cpp +++ b/lldb/unittests/DAP/ProtocolTypesTest.cpp @@ -9,6 +9,7 @@ #include "Protocol/ProtocolTypes.h" #include "Protocol/ProtocolEvents.h" #include "Protocol/ProtocolRequests.h" +#include "TestingSupport/TestUtilities.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/JSON.h" #include "llvm/Testing/Support/Error.h" @@ -20,6 +21,7 @@ using namespace llvm; using namespace lldb; using namespace lldb_dap; using namespace lldb_dap::protocol; +using lldb_private::roundtripJSON; using llvm::json::parse; using llvm::json::Value; @@ -28,15 +30,6 @@ static std::string pp(const json::Value &E) { return formatv("{0:2}", E).str(); } -template static llvm::Expected roundtrip(const T &input) { - llvm::json::Value value = toJSON(input); - llvm::json::Path::Root root; - T output; - if (!fromJSON(value, output, root)) - return root.getError(); - return output; -} - TEST(ProtocolTypesTest, ExceptionBreakpointsFilter) { ExceptionBreakpointsFilter filter; filter.filter = "testFilter"; @@ -47,7 +40,7 @@ TEST(ProtocolTypesTest, ExceptionBreakpointsFilter) { filter.conditionDescription = "Condition for test filter"; llvm::Expected deserialized_filter = - roundtrip(filter); + roundtripJSON(filter); ASSERT_THAT_EXPECTED(deserialized_filter, llvm::Succeeded()); EXPECT_EQ(filter.filter, deserialized_filter->filter); @@ -66,7 +59,7 @@ TEST(ProtocolTypesTest, Source) { source.sourceReference = 12345; source.presentationHint = Source::eSourcePresentationHintEmphasize; - llvm::Expected deserialized_source = roundtrip(source); + llvm::Expected deserialized_source = roundtripJSON(source); ASSERT_THAT_EXPECTED(deserialized_source, llvm::Succeeded()); EXPECT_EQ(source.name, deserialized_source->name); @@ -83,7 +76,7 @@ TEST(ProtocolTypesTest, ColumnDescriptor) { column.type = eColumnTypeString; column.width = 20; - llvm::Expected deserialized_column = roundtrip(column); + llvm::Expected deserialized_column = roundtripJSON(column); ASSERT_THAT_EXPECTED(deserialized_column, llvm::Succeeded()); EXPECT_EQ(column.attributeName, deserialized_column->attributeName); @@ -101,7 +94,7 @@ TEST(ProtocolTypesTest, BreakpointMode) { mode.appliesTo = {eBreakpointModeApplicabilitySource, eBreakpointModeApplicabilityException}; - llvm::Expected deserialized_mode = roundtrip(mode); + llvm::Expected deserialized_mode = roundtripJSON(mode); ASSERT_THAT_EXPECTED(deserialized_mode, llvm::Succeeded()); EXPECT_EQ(mode.mode, deserialized_mode->mode); @@ -125,7 +118,8 @@ TEST(ProtocolTypesTest, Breakpoint) { breakpoint.offset = 4; breakpoint.reason = BreakpointReason::eBreakpointReasonPending; - llvm::Expected deserialized_breakpoint = roundtrip(breakpoint); + llvm::Expected deserialized_breakpoint = + roundtripJSON(breakpoint); ASSERT_THAT_EXPECTED(deserialized_breakpoint, llvm::Succeeded()); EXPECT_EQ(breakpoint.id, deserialized_breakpoint->id); @@ -157,7 +151,7 @@ TEST(ProtocolTypesTest, SourceBreakpoint) { source_breakpoint.mode = "hardware"; llvm::Expected deserialized_source_breakpoint = - roundtrip(source_breakpoint); + roundtripJSON(source_breakpoint); ASSERT_THAT_EXPECTED(deserialized_source_breakpoint, llvm::Succeeded()); EXPECT_EQ(source_breakpoint.line, deserialized_source_breakpoint->line); @@ -178,7 +172,7 @@ TEST(ProtocolTypesTest, FunctionBreakpoint) { function_breakpoint.hitCondition = "3"; llvm::Expected deserialized_function_breakpoint = - roundtrip(function_breakpoint); + roundtripJSON(function_breakpoint); ASSERT_THAT_EXPECTED(deserialized_function_breakpoint, llvm::Succeeded()); EXPECT_EQ(function_breakpoint.name, deserialized_function_breakpoint->name); @@ -196,7 +190,7 @@ TEST(ProtocolTypesTest, DataBreakpoint) { data_breakpoint_info.hitCondition = "10"; llvm::Expected deserialized_data_breakpoint_info = - roundtrip(data_breakpoint_info); + roundtripJSON(data_breakpoint_info); ASSERT_THAT_EXPECTED(deserialized_data_breakpoint_info, llvm::Succeeded()); EXPECT_EQ(data_breakpoint_info.dataId, @@ -233,9 +227,9 @@ TEST(ProtocolTypesTest, Capabilities) { {eBreakpointModeApplicabilitySource}}}; capabilities.lldbExtVersion = "1.0.0"; - // Perform roundtrip serialization and deserialization. + // Perform roundtripJSON serialization and deserialization. llvm::Expected deserialized_capabilities = - roundtrip(capabilities); + roundtripJSON(capabilities); ASSERT_THAT_EXPECTED(deserialized_capabilities, llvm::Succeeded()); // Verify supported features. @@ -326,7 +320,7 @@ TEST(ProtocolTypesTest, Scope) { source.presentationHint = Source::eSourcePresentationHintNormal; scope.source = source; - llvm::Expected deserialized_scope = roundtrip(scope); + llvm::Expected deserialized_scope = roundtripJSON(scope); ASSERT_THAT_EXPECTED(deserialized_scope, llvm::Succeeded()); EXPECT_EQ(scope.name, deserialized_scope->name); EXPECT_EQ(scope.presentationHint, deserialized_scope->presentationHint); @@ -696,7 +690,7 @@ TEST(ProtocolTypesTest, StepInTarget) { target.endLine = 32; target.endColumn = 23; - llvm::Expected deserialized_target = roundtrip(target); + llvm::Expected deserialized_target = roundtripJSON(target); ASSERT_THAT_EXPECTED(deserialized_target, llvm::Succeeded()); EXPECT_EQ(target.id, deserialized_target->id); @@ -705,4 +699,4 @@ TEST(ProtocolTypesTest, StepInTarget) { EXPECT_EQ(target.column, deserialized_target->column); EXPECT_EQ(target.endLine, deserialized_target->endLine); EXPECT_EQ(target.endColumn, deserialized_target->endColumn); -} \ No newline at end of file +} diff --git a/lldb/unittests/Protocol/CMakeLists.txt b/lldb/unittests/Protocol/CMakeLists.txt new file mode 100644 index 0000000000000..801662b0544d8 --- /dev/null +++ b/lldb/unittests/Protocol/CMakeLists.txt @@ -0,0 +1,12 @@ +add_lldb_unittest(ProtocolTests + ProtocolMCPTest.cpp + ProtocolMCPServerTest.cpp + + LINK_LIBS + lldbCore + lldbUtility + lldbHost + lldbPluginPlatformMacOSX + lldbPluginProtocolServerMCP + LLVMTestingSupport + ) diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp new file mode 100644 index 0000000000000..63e6557bebc65 --- /dev/null +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -0,0 +1,290 @@ +//===-- ProtocolServerMCPTest.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Plugins/Platform/MacOSX/PlatformRemoteMacOSX.h" +#include "Plugins/Protocol/MCP/ProtocolServerMCP.h" +#include "TestingSupport/Host/SocketTestUtilities.h" +#include "TestingSupport/SubsystemRAII.h" +#include "lldb/Core/ProtocolServer.h" +#include "lldb/Host/FileSystem.h" +#include "lldb/Host/HostInfo.h" +#include "lldb/Host/Socket.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace lldb; +using namespace lldb_private; +using namespace lldb_private::mcp::protocol; + +namespace { +class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { +public: + using ProtocolServerMCP::AddNotificationHandler; + using ProtocolServerMCP::AddRequestHandler; + using ProtocolServerMCP::AddTool; + using ProtocolServerMCP::GetSocket; + using ProtocolServerMCP::ProtocolServerMCP; +}; + +class TestJSONTransport : public lldb_private::JSONRPCTransport { +public: + using JSONRPCTransport::JSONRPCTransport; + using JSONRPCTransport::ReadImpl; + using JSONRPCTransport::WriteImpl; +}; + +/// Test tool that returns it argument as text. +class TestTool : public mcp::Tool { +public: + using mcp::Tool::Tool; + + virtual llvm::Expected + Call(const llvm::json::Value &args) override { + std::string argument; + if (const json::Object *args_obj = args.getAsObject()) { + if (const json::Value *s = args_obj->get("arguments")) { + argument = s->getAsString().value_or(""); + } + } + + mcp::protocol::TextResult text_result; + text_result.content.emplace_back(mcp::protocol::TextContent{{argument}}); + return text_result; + } +}; + +/// Test tool that returns an error. +class ErrorTool : public mcp::Tool { +public: + using mcp::Tool::Tool; + + virtual llvm::Expected + Call(const llvm::json::Value &args) override { + return llvm::createStringError("error"); + } +}; + +/// Test tool that fails but doesn't return an error. +class FailTool : public mcp::Tool { +public: + using mcp::Tool::Tool; + + virtual llvm::Expected + Call(const llvm::json::Value &args) override { + mcp::protocol::TextResult text_result; + text_result.content.emplace_back(mcp::protocol::TextContent{{"failed"}}); + text_result.isError = true; + return text_result; + } +}; + +class ProtocolServerMCPTest : public ::testing::Test { +public: + SubsystemRAII subsystems; + DebuggerSP m_debugger_sp; + + lldb::IOObjectSP m_io_sp; + std::unique_ptr m_transport_up; + std::unique_ptr m_server_up; + + static constexpr llvm::StringLiteral k_localhost = "localhost"; + + llvm::Error Write(llvm::StringRef message) { + return m_transport_up->WriteImpl(llvm::formatv("{0}\n", message).str()); + } + + llvm::Expected Read() { + return m_transport_up->ReadImpl(std::chrono::milliseconds(100)); + } + + void SetUp() { + // Create a debugger. + ArchSpec arch("arm64-apple-macosx-"); + Platform::SetHostPlatform( + PlatformRemoteMacOSX::CreateInstance(true, &arch)); + m_debugger_sp = Debugger::CreateInstance(); + + // Create & start the server. + ProtocolServer::Connection connection; + connection.protocol = Socket::SocketProtocol::ProtocolTcp; + connection.name = llvm::formatv("{0}:0", k_localhost).str(); + m_server_up = std::make_unique(*m_debugger_sp); + m_server_up->AddTool(std::make_unique("test", "test tool")); + ASSERT_THAT_ERROR(m_server_up->Start(connection), llvm::Succeeded()); + + // Connect to the server over a TCP socket. + auto connect_socket_up = std::make_unique(true); + ASSERT_THAT_ERROR(connect_socket_up + ->Connect(llvm::formatv("{0}:{1}", k_localhost, + static_cast( + m_server_up->GetSocket()) + ->GetLocalPortNumber()) + .str()) + .ToError(), + llvm::Succeeded()); + + // Set up JSON transport for the client. + m_io_sp = std::move(connect_socket_up); + m_transport_up = std::make_unique(m_io_sp, m_io_sp); + } + + void TearDown() { + // Stop the server. + ASSERT_THAT_ERROR(m_server_up->Stop(), llvm::Succeeded()); + } +}; + +} // namespace + +TEST_F(ProtocolServerMCPTest, Intialization) { + llvm::StringLiteral request = + R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"claude-ai","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json"; + llvm::StringLiteral response = + R"json({"jsonrpc":"2.0","id":0,"result":{"capabilities":{"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ToolsList) { + llvm::StringLiteral request = + R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":1})json"; + llvm::StringLiteral response = + R"json({"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","name":"test"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"}},"type":"object"},"name":"lldb_command"}]}})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ResourcesList) { + llvm::StringLiteral request = + R"json({"method":"resources/list","params":{},"jsonrpc":"2.0","id":2})json"; + llvm::StringLiteral response = + R"json({"error":{"code":1,"message":"no handler for request: resources/list"},"id":2,"jsonrpc":"2.0"})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ToolsCall) { + llvm::StringLiteral request = + R"json({"method":"tools/call","params":{"name":"test","arguments":{"arguments":"foo"}},"jsonrpc":"2.0","id":11})json"; + llvm::StringLiteral response = + R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ToolsCallError) { + m_server_up->AddTool(std::make_unique("error", "error tool")); + + llvm::StringLiteral request = + R"json({"method":"tools/call","params":{"name":"error","arguments":{"arguments":"foo"}},"jsonrpc":"2.0","id":11})json"; + llvm::StringLiteral response = + R"json({"error":{"code":-1,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, ToolsCallFail) { + m_server_up->AddTool(std::make_unique("fail", "fail tool")); + + llvm::StringLiteral request = + R"json({"method":"tools/call","params":{"name":"fail","arguments":{"arguments":"foo"}},"jsonrpc":"2.0","id":11})json"; + llvm::StringLiteral response = + R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"failed","type":"text"}],"isError":true}})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + llvm::Expected response_str = Read(); + ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); + + llvm::Expected response_json = json::parse(*response_str); + ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); + + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + + EXPECT_EQ(*response_json, *expected_json); +} + +TEST_F(ProtocolServerMCPTest, NotificationInitialized) { + bool handler_called = false; + std::condition_variable cv; + std::mutex mutex; + + m_server_up->AddNotificationHandler( + "notifications/initialized", + [&](const mcp::protocol::Notification ¬ification) { + { + std::lock_guard lock(mutex); + handler_called = true; + } + cv.notify_all(); + }); + llvm::StringLiteral request = + R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json"; + + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + + std::unique_lock lock(mutex); + cv.wait(lock, [&] { return handler_called; }); +} diff --git a/lldb/unittests/Protocol/ProtocolMCPTest.cpp b/lldb/unittests/Protocol/ProtocolMCPTest.cpp new file mode 100644 index 0000000000000..00959f3ce20be --- /dev/null +++ b/lldb/unittests/Protocol/ProtocolMCPTest.cpp @@ -0,0 +1,135 @@ +//===-- ProtocolMCPTest.cpp -----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Plugins/Protocol/MCP/Protocol.h" +#include "TestingSupport/TestUtilities.h" +#include "llvm/Testing/Support/Error.h" +#include "gtest/gtest.h" + +using namespace lldb; +using namespace lldb_private; +using namespace lldb_private::mcp::protocol; + +TEST(ProtocolMCPTest, Request) { + Request request; + request.id = 1; + request.method = "foo"; + request.params = llvm::json::Object{{"key", "value"}}; + + llvm::Expected deserialized_request = roundtripJSON(request); + ASSERT_THAT_EXPECTED(deserialized_request, llvm::Succeeded()); + + EXPECT_EQ(request.id, deserialized_request->id); + EXPECT_EQ(request.method, deserialized_request->method); + EXPECT_EQ(request.params, deserialized_request->params); +} + +TEST(ProtocolMCPTest, Response) { + Response response; + response.id = 1; + response.result = llvm::json::Object{{"key", "value"}}; + + llvm::Expected deserialized_response = roundtripJSON(response); + ASSERT_THAT_EXPECTED(deserialized_response, llvm::Succeeded()); + + EXPECT_EQ(response.id, deserialized_response->id); + EXPECT_EQ(response.result, deserialized_response->result); +} + +TEST(ProtocolMCPTest, Notification) { + Notification notification; + notification.method = "notifyMethod"; + notification.params = llvm::json::Object{{"key", "value"}}; + + llvm::Expected deserialized_notification = + roundtripJSON(notification); + ASSERT_THAT_EXPECTED(deserialized_notification, llvm::Succeeded()); + + EXPECT_EQ(notification.method, deserialized_notification->method); + EXPECT_EQ(notification.params, deserialized_notification->params); +} + +TEST(ProtocolMCPTest, ToolCapability) { + ToolCapability tool_capability; + tool_capability.listChanged = true; + + llvm::Expected deserialized_tool_capability = + roundtripJSON(tool_capability); + ASSERT_THAT_EXPECTED(deserialized_tool_capability, llvm::Succeeded()); + + EXPECT_EQ(tool_capability.listChanged, + deserialized_tool_capability->listChanged); +} + +TEST(ProtocolMCPTest, Capabilities) { + ToolCapability tool_capability; + tool_capability.listChanged = true; + + Capabilities capabilities; + capabilities.tools = tool_capability; + + llvm::Expected deserialized_capabilities = + roundtripJSON(capabilities); + ASSERT_THAT_EXPECTED(deserialized_capabilities, llvm::Succeeded()); + + EXPECT_EQ(capabilities.tools.listChanged, + deserialized_capabilities->tools.listChanged); +} + +TEST(ProtocolMCPTest, TextContent) { + TextContent text_content; + text_content.text = "Sample text"; + + llvm::Expected deserialized_text_content = + roundtripJSON(text_content); + ASSERT_THAT_EXPECTED(deserialized_text_content, llvm::Succeeded()); + + EXPECT_EQ(text_content.text, deserialized_text_content->text); +} + +TEST(ProtocolMCPTest, TextResult) { + TextContent text_content1; + text_content1.text = "Text 1"; + + TextContent text_content2; + text_content2.text = "Text 2"; + + TextResult text_result; + text_result.content = {text_content1, text_content2}; + text_result.isError = true; + + llvm::Expected deserialized_text_result = + roundtripJSON(text_result); + ASSERT_THAT_EXPECTED(deserialized_text_result, llvm::Succeeded()); + + EXPECT_EQ(text_result.isError, deserialized_text_result->isError); + ASSERT_EQ(text_result.content.size(), + deserialized_text_result->content.size()); + EXPECT_EQ(text_result.content[0].text, + deserialized_text_result->content[0].text); + EXPECT_EQ(text_result.content[1].text, + deserialized_text_result->content[1].text); +} + +TEST(ProtocolMCPTest, ToolDefinition) { + ToolDefinition tool_definition; + tool_definition.name = "ToolName"; + tool_definition.description = "Tool Description"; + tool_definition.inputSchema = + llvm::json::Object{{"schemaKey", "schemaValue"}}; + + llvm::Expected deserialized_tool_definition = + roundtripJSON(tool_definition); + ASSERT_THAT_EXPECTED(deserialized_tool_definition, llvm::Succeeded()); + + EXPECT_EQ(tool_definition.name, deserialized_tool_definition->name); + EXPECT_EQ(tool_definition.description, + deserialized_tool_definition->description); + EXPECT_EQ(tool_definition.inputSchema, + deserialized_tool_definition->inputSchema); +} diff --git a/lldb/unittests/TestingSupport/TestUtilities.h b/lldb/unittests/TestingSupport/TestUtilities.h index 65994384059fb..db62881872fef 100644 --- a/lldb/unittests/TestingSupport/TestUtilities.h +++ b/lldb/unittests/TestingSupport/TestUtilities.h @@ -59,6 +59,15 @@ class TestFile { std::string Buffer; }; + +template static llvm::Expected roundtripJSON(const T &input) { + llvm::json::Value value = toJSON(input); + llvm::json::Path::Root root; + T output; + if (!fromJSON(value, output, root)) + return root.getError(); + return output; +} } // namespace lldb_private #endif