diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 85e57314f8..8fb46dd284 100755 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -170,6 +170,10 @@ jobs: { "ID": "Build-Executor", "Args": ["state", "run", "build-exec"] + }, + { + "ID": "Build-MCP", + "Args": ["state", "run", "build-mcp"] } ] EOF @@ -225,6 +229,11 @@ jobs: shell: bash run: parallelize results Build-Executor + - # === "Build: MCP" === + name: "Build: MCP" + shell: bash + run: parallelize results Build-MCP + - # === Prepare Windows Cert === name: Prepare Windows Cert shell: bash @@ -245,6 +254,7 @@ jobs: signtool.exe sign -d "ActiveState State Service" -f "Cert.p12" -p ${CODE_SIGNING_PASSWD} ./build/state-svc.exe signtool.exe sign -d "ActiveState State Installer" -f "Cert.p12" -p ${CODE_SIGNING_PASSWD} ./build/state-installer.exe signtool.exe sign -d "ActiveState State Tool Remote Installer" -f "Cert.p12" -p ${CODE_SIGNING_PASSWD} ./build/state-remote-installer.exe + signtool.exe sign -d "ActiveState State MCP" -f "Cert.p12" -p ${CODE_SIGNING_PASSWD} ./build/state-mcp.exe env: CODE_SIGNING_PASSWD: ${{ secrets.CODE_SIGNING_PASSWD }} diff --git a/activestate.windows.yaml b/activestate.windows.yaml index 76f5251757..1ffd7363e7 100644 --- a/activestate.windows.yaml +++ b/activestate.windows.yaml @@ -7,6 +7,8 @@ constants: value: state-exec.exe - name: BUILD_INSTALLER_TARGET value: state-installer.exe + - name: BUILD_MCP_TARGET + value: state-mcp.exe - name: SVC_BUILDFLAGS value: -ldflags="-s -w -H=windowsgui" - name: SCRIPT_EXT diff --git a/activestate.yaml b/activestate.yaml index fac0bab6f8..6333cb5368 100644 --- a/activestate.yaml +++ b/activestate.yaml @@ -10,6 +10,8 @@ constants: value: ./cmd/state-installer - name: EXECUTOR_PKGS value: ./cmd/state-exec + - name: MCP_PKGS + value: ./cmd/state-mcp - name: BUILD_TARGET_PREFIX_DIR value: ./build - name: BUILD_TARGET @@ -29,6 +31,9 @@ constants: value: state-installer - name: BUILD_REMOTE_INSTALLER_TARGET value: state-remote-installer + - name: BUILD_MCP_TARGET + if: ne .OS.Name "Windows" + value: state-mcp - name: INTEGRATION_TEST_REGEX value: 'integration\|automation' - name: SET_ENV @@ -106,7 +111,9 @@ scripts: go generate popd > /dev/null fi - go build -tags "$GO_BUILD_TAGS" -o $BUILD_TARGET_DIR/$constants.BUILD_TARGET $constants.CLI_BUILDFLAGS $constants.CLI_PKGS + TARGET=$BUILD_TARGET_DIR/$constants.BUILD_TARGET + echo "Building $TARGET" + go build -tags "$GO_BUILD_TAGS" -o $TARGET $constants.CLI_BUILDFLAGS $constants.CLI_PKGS - name: build-svc language: bash standalone: true @@ -121,7 +128,9 @@ scripts: go generate popd > /dev/null fi - go build -tags "$GO_BUILD_TAGS" -o $BUILD_TARGET_DIR/$constants.BUILD_DAEMON_TARGET $constants.SVC_BUILDFLAGS $constants.DAEMON_PKGS + TARGET=$BUILD_TARGET_DIR/$constants.BUILD_DAEMON_TARGET + echo "Building $TARGET" + go build -tags "$GO_BUILD_TAGS" -o $TARGET $constants.SVC_BUILDFLAGS $constants.DAEMON_PKGS - name: build-exec description: Builds the State Executor application language: bash @@ -129,8 +138,19 @@ scripts: value: | set -e $constants.SET_ENV - - go build -tags "$GO_BUILD_TAGS" -o $BUILD_TARGET_DIR/$constants.BUILD_EXEC_TARGET $constants.CLI_BUILDFLAGS $constants.EXECUTOR_PKGS + TARGET=$BUILD_TARGET_DIR/$constants.BUILD_EXEC_TARGET + echo "Building $TARGET" + go build -tags "$GO_BUILD_TAGS" -o $TARGET $constants.CLI_BUILDFLAGS $constants.EXECUTOR_PKGS + - name: build-mcp + description: Builds the State MCP application + language: bash + standalone: true + value: | + set -e + $constants.SET_ENV + TARGET=$BUILD_TARGET_DIR/$constants.BUILD_MCP_TARGET + echo "Building $TARGET" + go build -tags "$GO_BUILD_TAGS" -o $TARGET $constants.CLI_BUILDFLAGS $constants.MCP_PKGS - name: build-all description: Builds all our tools language: bash @@ -147,6 +167,8 @@ scripts: $scripts.build-svc.path() echo "Building State Executor" $scripts.build-exec.path() + echo "Building State MCP" + $scripts.build-mcp.path() - name: build-installer language: bash standalone: true diff --git a/cmd/state-installer/cmd.go b/cmd/state-installer/cmd.go index 89a5450319..8bf68fbd73 100644 --- a/cmd/state-installer/cmd.go +++ b/cmd/state-installer/cmd.go @@ -475,11 +475,7 @@ func storeInstallSource(installSource string) { installSource = "state-installer" } - appData, err := storage.AppDataPath() - if err != nil { - multilog.Error("Could not store install source due to AppDataPath error: %s", errs.JoinMessage(err)) - return - } + appData := storage.AppDataPath() if err := fileutils.WriteFile(filepath.Join(appData, constants.InstallSourceFile), []byte(installSource)); err != nil { multilog.Error("Could not store install source due to WriteFile error: %s", errs.JoinMessage(err)) } diff --git a/cmd/state-mcp/handlers.go b/cmd/state-mcp/handlers.go new file mode 100644 index 0000000000..cc72209b67 --- /dev/null +++ b/cmd/state-mcp/handlers.go @@ -0,0 +1,103 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "strings" + + "github.com/ActiveState/cli/internal/errs" + "github.com/ActiveState/cli/internal/runners/cve" + "github.com/ActiveState/cli/internal/runners/manifest" + "github.com/ActiveState/cli/internal/runners/projects" + "github.com/mark3labs/mcp-go/mcp" +) + +// listProjectsHandler handles the list_projects tool +func (t *mcpServerHandler) listProjectsHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + var byt bytes.Buffer + prime, close, err := t.newPrimer("", &byt) + if err != nil { + return nil, errs.Wrap(err, "Failed to create primer") + } + defer func() { + if err := close(); err != nil { + rerr = errs.Pack(rerr, err) + } + }() + + runner := projects.NewProjects(prime) + params := projects.NewParams() + err = runner.Run(params) + if err != nil { + return nil, errs.Wrap(err, "Failed to run projects") + } + + return mcp.NewToolResultText(byt.String()), nil +} + +// manifestHandler handles the view_manifest tool +func (t *mcpServerHandler) manifestHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + pjPath := request.Params.Arguments["project_directory"].(string) + + var byt bytes.Buffer + prime, close, err := t.newPrimer(pjPath, &byt) + if err != nil { + return nil, errs.Wrap(err, "Failed to create primer") + } + defer func() { + if err := close(); err != nil { + rerr = errs.Pack(rerr, err) + } + }() + + m := manifest.NewManifest(prime) + err = m.Run(manifest.Params{}) + if err != nil { + return nil, errs.Wrap(err, "Failed to run manifest") + } + + return mcp.NewToolResultText(byt.String()), nil +} + +// cveHandler handles the view_cves tool +func (t *mcpServerHandler) cveHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + pjPath := request.Params.Arguments["project_directory"].(string) + + var byt bytes.Buffer + prime, close, err := t.newPrimer(pjPath, &byt) + if err != nil { + return nil, errs.Wrap(err, "Failed to create primer") + } + defer func() { + if err := close(); err != nil { + rerr = errs.Pack(rerr, err) + } + }() + + c := cve.NewCve(prime) + err = c.Run(&cve.Params{}) + if err != nil { + return nil, errs.Wrap(err, "Failed to run manifest") + } + + return mcp.NewToolResultText(byt.String()), nil +} + +// lookupCveHandler handles the lookup_cve tool +func (t *mcpServerHandler) lookupCveHandler(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + cveId := request.Params.Arguments["cve_ids"].(string) + cveIds := strings.Split(cveId, ",") + + results, err := LookupCve(cveIds...) + if err != nil { + return nil, errs.Wrap(err, "Failed to lookup CVEs") + } + + byt, err := json.Marshal(results) + if err != nil { + return nil, errs.Wrap(err, "Failed to marshal results") + } + + return mcp.NewToolResultText(string(byt)), nil +} \ No newline at end of file diff --git a/cmd/state-mcp/lookupcve.go b/cmd/state-mcp/lookupcve.go new file mode 100644 index 0000000000..2c113b3821 --- /dev/null +++ b/cmd/state-mcp/lookupcve.go @@ -0,0 +1,38 @@ +package main + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/ActiveState/cli/internal/chanutils/workerpool" + "github.com/ActiveState/cli/internal/errs" +) + +func LookupCve(cveIds ...string) (map[string]interface{}, error) { + results := map[string]interface{}{} + // https://api.osv.dev/v1/vulns/OSV-2020-111 + wp := workerpool.New(5) + for _, cveId := range cveIds { + wp.Submit(func() error { + resp, err := http.Get(fmt.Sprintf("https://api.osv.dev/v1/vulns/%s", cveId)) + if err != nil { + return err + } + defer resp.Body.Close() + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return err + } + results[cveId] = result + return nil + }) + } + + err := wp.Wait() + if err != nil { + return nil, errs.Wrap(err, "Failed to wait for workerpool") + } + + return results, nil +} \ No newline at end of file diff --git a/cmd/state-mcp/lookupcve_test.go b/cmd/state-mcp/lookupcve_test.go new file mode 100644 index 0000000000..e643f9672b --- /dev/null +++ b/cmd/state-mcp/lookupcve_test.go @@ -0,0 +1,43 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLookupCve(t *testing.T) { + // Table-driven test cases + tests := []struct { + name string + cveIds []string + }{ + { + name: "Single CVE", + cveIds: []string{"CVE-2021-44228"}, + }, + { + name: "Multiple CVEs", + cveIds: []string{"CVE-2021-44228", "CVE-2022-22965"}, + }, + { + name: "Non-existent CVE", + cveIds: []string{"CVE-DOES-NOT-EXIST"}, + }, + { + name: "Empty Input", + cveIds: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + results, err := LookupCve(tt.cveIds...) + require.NoError(t, err) + require.NotNil(t, results) + for _, cveId := range tt.cveIds { + require.Contains(t, results, cveId) + } + }) + } +} \ No newline at end of file diff --git a/cmd/state-mcp/main.go b/cmd/state-mcp/main.go new file mode 100644 index 0000000000..88f62e5785 --- /dev/null +++ b/cmd/state-mcp/main.go @@ -0,0 +1,51 @@ +package main + +import ( + "flag" + "fmt" + "os" + "runtime/debug" + "time" + + "github.com/ActiveState/cli/internal/events" + "github.com/ActiveState/cli/internal/logging" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + defer func() { + logging.Debug("Exiting") + if r := recover(); r != nil { + logging.Error("Recovered from panic: %v", r) + fmt.Printf("Recovered from panic: %v, stack: %s\n", r, string(debug.Stack())) + os.Exit(1) + } + }() + defer func() { + if err := events.WaitForEvents(5*time.Second, logging.Close); err != nil { + logging.Warning("Failed waiting for events: %v", err) + } + }() + + mcpHandler := registerServer() + + // Parse command line flags + rawFlag := flag.String("type", "", "Type of MCP server to run; raw, curated or scripts") + flag.Parse() + switch *rawFlag { + case "raw": + close := registerRawTools(mcpHandler) + defer close() + case "scripts": + close := registerScriptTools(mcpHandler) + defer close() + default: + registerCuratedTools(mcpHandler) + } + + // Start the stdio server + logging.Info("Starting MCP server") + if err := server.ServeStdio(mcpHandler.mcpServer); err != nil { + logging.Error("Server error: %v\n", err) + } +} \ No newline at end of file diff --git a/cmd/state-mcp/primer.go b/cmd/state-mcp/primer.go new file mode 100644 index 0000000000..7215cd42dd --- /dev/null +++ b/cmd/state-mcp/primer.go @@ -0,0 +1,90 @@ +package main + +import ( + "context" + "io" + + "github.com/ActiveState/cli/internal/config" + "github.com/ActiveState/cli/internal/constraints" + "github.com/ActiveState/cli/internal/errs" + "github.com/ActiveState/cli/internal/logging" + "github.com/ActiveState/cli/internal/multilog" + "github.com/ActiveState/cli/internal/output" + "github.com/ActiveState/cli/internal/primer" + "github.com/ActiveState/cli/internal/subshell" + "github.com/ActiveState/cli/pkg/platform/authentication" + "github.com/ActiveState/cli/pkg/platform/model" + "github.com/ActiveState/cli/pkg/project" + "github.com/ActiveState/cli/pkg/projectfile" +) + +// newPrimer creates a new primer.Values instance for use with command execution +func (t *mcpServerHandler) newPrimer(projectDir string, o io.Writer) (*primer.Values, func() error, error) { + closers := []func() error{} + closer := func() error { + for _, c := range closers { + if err := c(); err != nil { + return err + } + } + return nil + } + + cfg, err := config.New() + if err != nil { + return nil, closer, errs.Wrap(err, "Failed to create config") + } + closers = append(closers, cfg.Close) + + auth := authentication.New(cfg) + closers = append(closers, auth.Close) + + out, err := output.New(string(output.SimpleFormatName), &output.Config{ + OutWriter: o, + ErrWriter: o, + Colored: false, + Interactive: false, + ShellName: "", + }) + if err != nil { + return nil, closer, errs.Wrap(err, "Failed to create output") + } + + var pj *project.Project + if projectDir != "" { + pjf, err := projectfile.FromPath(projectDir) + if err != nil { + return nil, closer, errs.Wrap(err, "Failed to create projectfile") + } + pj, err = project.New(pjf, out) + if err != nil { + return nil, closer, errs.Wrap(err, "Failed to create project") + } + } + + // Set up conditional, which accesses a lot of primer data + sshell := subshell.New(cfg) + + conditional := constraints.NewPrimeConditional(auth, pj, sshell.Shell()) + project.RegisterConditional(conditional) + if err := project.RegisterExpander("mixin", project.NewMixin(auth).Expander); err != nil { + logging.Debug("Could not register mixin expander: %v", err) + } + + svcmodel := model.NewSvcModel(t.svcPort) + + if auth.AvailableAPIToken() != "" { + jwt, err := svcmodel.GetJWT(context.Background()) + if err != nil { + multilog.Critical("Could not get JWT: %v", errs.JoinMessage(err)) + } + if err != nil || jwt == nil { + // Could not authenticate; user got logged out + auth.Logout() + } else { + auth.UpdateSession(jwt) + } + } + + return primer.New(pj, out, auth, sshell, conditional, cfg, t.ipcClient, svcmodel), closer, nil +} \ No newline at end of file diff --git a/cmd/state-mcp/server.go b/cmd/state-mcp/server.go new file mode 100644 index 0000000000..7a4461c98d --- /dev/null +++ b/cmd/state-mcp/server.go @@ -0,0 +1,109 @@ +package main + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/ActiveState/cli/internal/constants" + "github.com/ActiveState/cli/internal/errs" + "github.com/ActiveState/cli/internal/installation" + "github.com/ActiveState/cli/internal/ipc" + "github.com/ActiveState/cli/internal/logging" + "github.com/ActiveState/cli/internal/svcctl" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// mcpServerHandler wraps the MCP server and provides methods for adding tools and resources +type mcpServerHandler struct { + mcpServer *server.MCPServer + ipcClient *ipc.Client + svcPort string +} + +// registerServer creates and configures a new MCP server +func registerServer() *mcpServerHandler { + ipcClient, svcPort, err := connectToSvc() + if err != nil { + panic(errs.JoinMessage(err)) + } + + // Create MCP server + s := server.NewMCPServer( + constants.CommandName, + constants.VersionNumber, + ) + + mcpHandler := &mcpServerHandler{ + mcpServer: s, + ipcClient: ipcClient, + svcPort: svcPort, + } + + return mcpHandler +} + +// addResource adds a resource to the MCP server with error handling and logging +func (t *mcpServerHandler) addResource(resource mcp.Resource, handler server.ResourceHandlerFunc) { + t.mcpServer.AddResource(resource, func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + defer func() { + if r := recover(); r != nil { + logging.Error("Recovered from resource handler panic: %v", r) + fmt.Printf("Recovered from resource handler panic: %v\n", r) + } + }() + logging.Debug("Received resource request: %s", resource.Name) + r, err := handler(ctx, request) + if err != nil { + logging.Error("%s: Error handling resource request: %v", resource.Name, err) + return nil, errs.Wrap(err, "Failed to handle resource request") + } + return r, nil + }) +} + +// addTool adds a tool to the MCP server with error handling and logging +func (t *mcpServerHandler) addTool(tool mcp.Tool, handler server.ToolHandlerFunc) { + t.mcpServer.AddTool(tool, func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + defer func() { + if r := recover(); r != nil { + logging.Error("Recovered from tool handler panic: %v", r) + fmt.Printf("Recovered from tool handler panic: %v\n", r) + } + }() + logging.Debug("Received tool request: %s", tool.Name) + r, err := handler(ctx, request) + logging.Debug("Received tool response from %s", tool.Name) + if err != nil { + logging.Error("%s: Error handling tool request: %v", tool.Name, errs.JoinMessage(err)) + // Format all errors as a single string, so the client gets the full context + return nil, fmt.Errorf("%s: %s", tool.Name, errs.JoinMessage(err)) + } + return r, nil + }) +} + +type stdOutput struct{} + +func (s *stdOutput) Notice(msg interface{}) { + logging.Info(fmt.Sprintf("%v", msg)) +} + +// connectToSvc connects to the state service and returns an IPC client +func connectToSvc() (*ipc.Client, string, error) { + svcExec, err := installation.ServiceExec() + if err != nil { + return nil, "", errs.Wrap(err, "Could not get service info") + } + + ipcClient := svcctl.NewDefaultIPCClient() + argText := strings.Join(os.Args, " ") + svcPort, err := svcctl.EnsureExecStartedAndLocateHTTP(ipcClient, svcExec, argText, &stdOutput{}) + if err != nil { + return nil, "", errs.Wrap(err, "Failed to start state-svc at state tool invocation") + } + + return ipcClient, svcPort, nil +} \ No newline at end of file diff --git a/cmd/state-mcp/server_test.go b/cmd/state-mcp/server_test.go new file mode 100644 index 0000000000..c43892e8e8 --- /dev/null +++ b/cmd/state-mcp/server_test.go @@ -0,0 +1,45 @@ +package main + +import ( + "context" + "encoding/json" + "testing" + + "github.com/ActiveState/cli/internal/environment" +) + +func TestServerProjects(t *testing.T) { + t.Skip("Intended for manual testing") + mcpHandler := registerServer() + registerRawTools(mcpHandler) + + msg := mcpHandler.mcpServer.HandleMessage(context.Background(), json.RawMessage(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "projects", + "arguments": {} + } + }`)) + t.Fatalf("%+v", msg) +} + +func TestServerPackages(t *testing.T) { + t.Skip("Intended for manual testing") + mcpHandler := registerServer() + registerRawTools(mcpHandler) + + msg := mcpHandler.mcpServer.HandleMessage(context.Background(), json.RawMessage(`{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": { + "name": "packages", + "arguments": { + "project_directory": "`+environment.GetRootPathUnsafe()+`" + } + } + }`)) + t.Fatalf("%+v", msg) +} diff --git a/cmd/state-mcp/tools.go b/cmd/state-mcp/tools.go new file mode 100644 index 0000000000..fb70b11f21 --- /dev/null +++ b/cmd/state-mcp/tools.go @@ -0,0 +1,179 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "os" + "strings" + + "github.com/ActiveState/cli/cmd/state/donotshipme" + "github.com/ActiveState/cli/internal/constants" + "github.com/ActiveState/cli/internal/errs" + "github.com/ActiveState/cli/internal/logging" + "github.com/ActiveState/cli/internal/scriptrun" + "github.com/ActiveState/cli/internal/sliceutils" + "github.com/ActiveState/cli/pkg/project" + "github.com/mark3labs/mcp-go/mcp" +) + +func registerScriptTools(mcpHandler *mcpServerHandler) func() error { + byt := &bytes.Buffer{} + prime, close, err := mcpHandler.newPrimer(os.Getenv(constants.ActivatedStateEnvVarName), byt) + if err != nil { + panic(err) + } + + scripts, err := prime.Project().Scripts() + if err != nil { + panic(err) + } + + for _, script := range scripts { + mcpHandler.addTool(mcp.NewTool(script.Name(), + mcp.WithDescription(script.Description()), + ), func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + byt.Truncate(0) + + scriptrunner := scriptrun.New(prime) + if !script.Standalone() && scriptrunner.NeedsActivation() { + if err := scriptrunner.PrepareVirtualEnv(); err != nil { + return nil, errs.Wrap(err, "Failed to prepare virtual environment") + } + } + + err := scriptrunner.Run(script, []string{}) + if err != nil { + return nil, errs.Wrap(err, "Failed to run script") + } + + return mcp.NewToolResultText(byt.String()), nil + }) + } + + return close +} + +// registerCuratedTools registers a curated set of tools for the AI assistant +func registerCuratedTools(mcpHandler *mcpServerHandler) { + projectDirParam := mcp.WithString("project_directory", + mcp.Required(), + mcp.Description("Absolute path to the directory where your activestate project is checked out. It should contain the activestate.yaml file."), + ) + + mcpHandler.addTool(mcp.NewTool("list_projects", + mcp.WithDescription("List all ActiveState projects checked out on the local machine"), + ), mcpHandler.listProjectsHandler) + + mcpHandler.addTool(mcp.NewTool("view_manifest", + mcp.WithDescription("Show the manifest (packages and dependencies) for a locally checked out ActiveState platform project"), + projectDirParam, + ), mcpHandler.manifestHandler) + + mcpHandler.addTool(mcp.NewTool("view_cves", + mcp.WithDescription("Show the CVEs for a locally checked out ActiveState platform project"), + projectDirParam, + ), mcpHandler.cveHandler) + + mcpHandler.addTool(mcp.NewTool("lookup_cve", + mcp.WithDescription("Lookup one or more CVEs by their ID"), + mcp.WithString("cve_ids", + mcp.Required(), + mcp.Description("The IDs of the CVEs to lookup, comma separated"), + ), + ), mcpHandler.lookupCveHandler) +} + +// registerRawTools registers all State Tool commands as raw tools +func registerRawTools(mcpHandler *mcpServerHandler) func() error { + byt := &bytes.Buffer{} + prime, close, err := mcpHandler.newPrimer("", byt) + if err != nil { + panic(err) + } + + require := func(b bool) mcp.PropertyOption { + if b { + return mcp.Required() + } + return func(map[string]interface{}) {} + } + + tree := donotshipme.CmdTree(prime) + for _, command := range tree.Command().AllChildren() { + // Best effort to filter out interactive commands + if sliceutils.Contains([]string{"activate", "shell"}, command.NameRecursive()) { + continue + } + + opts := []mcp.ToolOption{ + mcp.WithDescription(command.Description()), + } + + // Require project directory for most commands. This is currently not encoded into the command tree + if !sliceutils.Contains([]string{"projects", "auth"}, command.BaseCommand().Name()) { + opts = append(opts, mcp.WithString( + "project_directory", + mcp.Required(), + mcp.Description("Absolute path to the directory where your activestate project is checked out. It should contain the activestate.yaml file."), + )) + } + + for _, arg := range command.Arguments() { + opts = append(opts, mcp.WithString(arg.Name, + require(arg.Required), + mcp.Description(arg.Description), + )) + } + for _, flag := range command.Flags() { + opts = append(opts, mcp.WithString(flag.Name, + mcp.Description(flag.Description), + )) + } + mcpHandler.addTool( + mcp.NewTool(strings.Join(strings.Split(command.NameRecursive(), " "), "_"), opts...), + func(ctx context.Context, request mcp.CallToolRequest) (r *mcp.CallToolResult, rerr error) { + byt.Truncate(0) + if projectDir, ok := request.Params.Arguments["project_directory"]; ok { + pj, err := project.FromPath(projectDir.(string)) + if err != nil { + return nil, errs.Wrap(err, "Failed to create project") + } + prime.SetProject(pj) + } + // Reinitialize tree with updated primer, because currently our command can take things + // from the primer at the time of registration, and not the time of invocation. + invocationTree := donotshipme.CmdTree(prime) + for _, child := range invocationTree.Command().AllChildren() { + if child.NameRecursive() == command.NameRecursive() { + command = child + break + } + } + args := strings.Split(command.NameRecursive(), " ") + for _, arg := range command.Arguments() { + v, ok := request.Params.Arguments[arg.Name] + if !ok { + break + } + args = append(args, v.(string)) + } + for _, flag := range command.Flags() { + v, ok := request.Params.Arguments[flag.Name] + if !ok { + break + } + args = append(args, fmt.Sprintf("--%s=%s", flag.Name, v.(string))) + } + logging.Debug("Executing command: %s, args: %v (%v)", command.NameRecursive(), args, args == nil) + err := command.Execute(args) + if err != nil { + return nil, errs.Wrap(err, "Failed to execute command") + } + return mcp.NewToolResultText(byt.String()), nil + }, + ) + } + + return close +} \ No newline at end of file diff --git a/cmd/state/donotshipme/donotshipme.go b/cmd/state/donotshipme/donotshipme.go new file mode 100644 index 0000000000..b021fc9ca5 --- /dev/null +++ b/cmd/state/donotshipme/donotshipme.go @@ -0,0 +1,17 @@ +package donotshipme + +import ( + "github.com/ActiveState/cli/cmd/state/internal/cmdtree" + "github.com/ActiveState/cli/internal/constants" + "github.com/ActiveState/cli/internal/primer" +) + +func init() { + if constants.ChannelName == "release" { + panic("This file is for experimentation only, it should not be shipped as is. CmdTree is internal to the State command and should remain that way or be refactored.") + } +} + +func CmdTree(prime *primer.Values, args ...string) *cmdtree.CmdTree { + return cmdtree.New(prime, args...) +} \ No newline at end of file diff --git a/go.mod b/go.mod index d83833829c..d5f3fd3da6 100644 --- a/go.mod +++ b/go.mod @@ -45,7 +45,6 @@ require ( github.com/phayes/permbits v0.0.0-20190108233746-1efae4548023 github.com/posener/wstest v0.0.0-20180216222922-04b166ca0bf1 github.com/rollbar/rollbar-go v1.1.0 - github.com/shibukawa/configdir v0.0.0-20170330084843-e180dbdc8da0 github.com/shirou/gopsutil/v3 v3.24.5 github.com/skratchdot/open-golang v0.0.0-20190104022628-a2dfa6d0dab6 github.com/spf13/cast v1.3.0 @@ -76,6 +75,7 @@ require ( github.com/go-git/go-git/v5 v5.13.1 github.com/gowebpki/jcs v1.0.1 github.com/klauspost/compress v1.11.4 + github.com/mark3labs/mcp-go v0.18.0 github.com/mholt/archiver/v3 v3.5.1 github.com/zijiren233/yaml-comment v0.2.1 ) @@ -109,6 +109,7 @@ require ( github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/skeema/knownhosts v1.3.0 // indirect github.com/sosodev/duration v1.3.1 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect golang.org/x/sync v0.11.0 // indirect ) diff --git a/go.sum b/go.sum index 1661bd5015..30abed107f 100644 --- a/go.sum +++ b/go.sum @@ -472,6 +472,8 @@ github.com/mailru/easyjson v0.7.1/go.mod h1:KAzv3t3aY1NaHWoQz1+4F1ccyAH66Jk7yos7 github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mark3labs/mcp-go v0.18.0 h1:YuhgIVjNlTG2ZOwmrkORWyPTp0dz1opPEqvsPtySXao= +github.com/mark3labs/mcp-go v0.18.0/go.mod h1:KmJndYv7GIgcPVwEKJjNcbhVQ+hJGJhrCCB/9xITzpE= github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsIW75CHf65OeIOkyKbteujpZVXDpWK6YGZbxE= github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0= github.com/maruel/natural v1.1.0 h1:2z1NgP/Vae+gYrtC0VuvrTJ6U35OuyUqDdfluLqMWuQ= @@ -602,8 +604,6 @@ github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= -github.com/shibukawa/configdir v0.0.0-20170330084843-e180dbdc8da0 h1:Xuk8ma/ibJ1fOy4Ee11vHhUFHQNpHhrBneOCNHVXS5w= -github.com/shibukawa/configdir v0.0.0-20170330084843-e180dbdc8da0/go.mod h1:7AwjWCpdPhkSmNAgUv5C7EJ4AbmjEB3r047r3DXWu3Y= github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI= github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk= github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= @@ -684,6 +684,8 @@ github.com/xdg/stringprep v0.0.0-20180714160509-73f8eece6fdc/go.mod h1:Jhud4/sHM github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/internal/analytics/client/sync/reporters/test.go b/internal/analytics/client/sync/reporters/test.go index 95ff0c426a..7fe70846b6 100644 --- a/internal/analytics/client/sync/reporters/test.go +++ b/internal/analytics/client/sync/reporters/test.go @@ -18,10 +18,8 @@ type TestReporter struct { const TestReportFilename = "analytics.log" func TestReportFilepath() string { - appdata, err := storage.AppDataPath() - if err != nil { - logging.Warning("Could not acquire appdata path, using cwd instead. Error received: %s", errs.JoinMessage(err)) - } + appdata := storage.AppDataPath() + logging.Warning("Appdata path: %s", appdata) return filepath.Join(appdata, TestReportFilename) } diff --git a/internal/captain/command.go b/internal/captain/command.go index 16a25bc252..13288e04e9 100644 --- a/internal/captain/command.go +++ b/internal/captain/command.go @@ -219,9 +219,12 @@ func (c *Command) ShortDescription() string { func (c *Command) Execute(args []string) error { defer profile.Measure("cobra:Execute", time.Now()) c.logArgs(args) - c.cobra.SetArgs(args) - err := c.cobra.Execute() - c.cobra.SetArgs(nil) + // Cobra always executes the root command, so we need to set the args for the root command + // This makes running command.Execute() super error-prone if the args don't match the command + // We should probably get rid of Cobra over issues like this + c.cobra.Root().SetArgs(args) + err := c.cobra.Root().Execute() + c.cobra.Root().SetArgs(nil) rationalizeError(&err) return setupSensibleErrors(err, args) } @@ -272,15 +275,26 @@ func (c *Command) Name() string { } func (c *Command) NameRecursive() string { - child := c + parent := c name := []string{} - for child != nil { - name = append([]string{child.Name()}, name...) - child = child.parent + for parent != nil { + name = append([]string{parent.Name()}, name...) + parent = parent.parent + if parent.parent == nil { + break // Don't include the root command in the name + } } return strings.Join(name, " ") } +func (c *Command) BaseCommand() *Command { + base := c + for base.parent != nil && base.parent.parent != nil { + base = base.parent + } + return base +} + func (c *Command) NamePadding() int { return c.cobra.NamePadding() } @@ -457,6 +471,15 @@ func (c *Command) Children() []*Command { return commands } +func (c *Command) AllChildren() []*Command { + commands := []*Command{} + for _, child := range c.Children() { + commands = append(commands, child) + commands = append(commands, child.AllChildren()...) + } + return commands +} + func (c *Command) AvailableChildren() []*Command { commands := []*Command{} for _, child := range c.Children() { diff --git a/internal/config/instance.go b/internal/config/instance.go index fd6d68fe48..d1c1a47e24 100644 --- a/internal/config/instance.go +++ b/internal/config/instance.go @@ -41,14 +41,10 @@ func NewCustom(localPath string, thread *singlethread.Thread, closeThread bool) i.thread = thread i.closeThread = closeThread - var err error if localPath != "" { - i.appDataDir, err = storage.AppDataPathWithParent(localPath) + i.appDataDir = storage.AppDataPathWithParent(localPath) } else { - i.appDataDir, err = storage.AppDataPath() - } - if err != nil { - return nil, errs.Wrap(err, "Could not detect appdata dir") + i.appDataDir = storage.AppDataPath() } // Ensure appdata dir exists, because the sqlite driver sure doesn't @@ -61,6 +57,7 @@ func NewCustom(localPath string, thread *singlethread.Thread, closeThread bool) path := filepath.Join(i.appDataDir, C.InternalConfigFileName) + var err error t := time.Now() i.db, err = sql.Open("sqlite", path) if err != nil { diff --git a/internal/constants/constants.go b/internal/constants/constants.go index da5104fe4b..29319a310a 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -436,6 +436,9 @@ const InstallerName = "State Installer" // StateExecutorCmd is the name of the state executor binary const StateExecutorCmd = "state-exec" +// StateMCPCmd is the name of the state mcp binary +const StateMCPCmd = "state-mcp" + // LegacyToplevelInstallArchiveDir is the top-level directory for files in an installation archive // This constant will be removed in DX-2081. const LegacyToplevelInstallArchiveDir = "state-install" diff --git a/internal/installation/appinfo.go b/internal/installation/appinfo.go index 2a18a42e93..0099891ad4 100644 --- a/internal/installation/appinfo.go +++ b/internal/installation/appinfo.go @@ -46,7 +46,9 @@ func newExecFromDir(baseDir string, exec executableType) (string, error) { // Work around dlv and goland debugger giving an unexpected executable path if !condition.BuiltViaCI() && len(os.Args) > 1 && - (strings.Contains(os.Args[0], "__debug_bin") || strings.Contains(filepath.ToSlash(os.Args[0]), "GoLand/___")) { + (strings.Contains(os.Args[0], "__debug_bin") || + strings.Contains(filepath.ToSlash(os.Args[0]), "GoLand/___") || + strings.Contains(os.Args[0], "go-build")) { rootPath := filepath.Clean(environment.GetRootPathUnsafe()) path = filepath.Join(rootPath, "build") } diff --git a/internal/installation/storage/storage.go b/internal/installation/storage/storage.go index a8659cbed1..f511f8d7c9 100644 --- a/internal/installation/storage/storage.go +++ b/internal/installation/storage/storage.go @@ -11,12 +11,29 @@ import ( "github.com/ActiveState/cli/internal/constants" "github.com/ActiveState/cli/internal/osutils/user" "github.com/google/uuid" - "github.com/shibukawa/configdir" ) -func AppDataPath() (string, error) { - configDirs := configdir.New(constants.InternalConfigNamespace, fmt.Sprintf("%s-%s", constants.LibraryName, constants.ChannelName)) +var homeDir string + +func init() { + var err error + homeDir, err = user.HomeDir() + if err != nil { + panic(fmt.Sprintf("Could not get home dir, you can fix this by ensuring the $HOME environment variable is set. Error: %v", err)) + } +} + + +func relativeAppDataPath() string { + return filepath.Join(constants.InternalConfigNamespace, fmt.Sprintf("%s-%s", constants.LibraryName, constants.ChannelName)) +} + +func relativeCachePath() string { + return constants.InternalConfigNamespace +} + +func AppDataPath() string { localPath, envSet := os.LookupEnv(constants.ConfigEnvVarName) if envSet { return AppDataPathWithParent(localPath) @@ -27,35 +44,10 @@ func AppDataPath() (string, error) { // panic as this only happening in tests panic(err) } - return localPath, nil + return localPath } - // Account for HOME dir not being set, meaning querying global folders will fail - // This is a workaround for docker envs that don't usually have $HOME set - _, envSet = os.LookupEnv("HOME") - if !envSet && runtime.GOOS != "windows" { - homeDir, err := user.HomeDir() - if err != nil { - if !condition.InUnitTest() { - return "", fmt.Errorf("Could not get user home directory: %w", err) - } - // Use temp dir if we're in a test (we don't want to write to our src directory) - var err error - localPath, err = os.MkdirTemp("", "cli-config-test") - if err != nil { - return "", fmt.Errorf("could not create temp dir: %w", err) - } - return AppDataPathWithParent(localPath) - } - os.Setenv("HOME", homeDir) - } - - dir := configDirs.QueryFolders(configdir.Global)[0].Path - if err := os.MkdirAll(dir, os.ModePerm); err != nil { - return "", fmt.Errorf("could not create appdata dir: %s", dir) - } - - return dir, nil + return AppDataPathWithParent(BaseAppDataPath()) } var _appDataPathInTest string @@ -67,11 +59,11 @@ func appDataPathInTest() (string, error) { localPath, err := os.MkdirTemp("", "cli-config") if err != nil { - return "", fmt.Errorf("Could not create temp dir: %w", err) + return "", fmt.Errorf("could not create temp dir: %w", err) } err = os.RemoveAll(localPath) if err != nil { - return "", fmt.Errorf("Could not remove generated config dir for tests: %w", err) + return "", fmt.Errorf("could not remove generated config dir for tests: %w", err) } _appDataPathInTest = localPath @@ -79,16 +71,15 @@ func appDataPathInTest() (string, error) { return localPath, nil } -func AppDataPathWithParent(parentDir string) (string, error) { - configDirs := configdir.New(constants.InternalConfigNamespace, fmt.Sprintf("%s-%s", constants.LibraryName, constants.ChannelName)) - configDirs.LocalPath = parentDir - dir := configDirs.QueryFolders(configdir.Local)[0].Path - +func AppDataPathWithParent(parentDir string) string { + dir := filepath.Join(parentDir, relativeAppDataPath()) if err := os.MkdirAll(dir, os.ModePerm); err != nil { - return "", fmt.Errorf("could not create appdata dir: %s", dir) + // Can't use logging here because it would cause a circular dependency + // This would only happen if the user has corrupt permissions on their home dir + os.Stderr.WriteString(fmt.Sprintf("Could not create appdata dir: %s", dir)) } - return dir, nil + return dir } // CachePath returns the path at which our cache is stored @@ -108,17 +99,15 @@ func CachePath() string { cachePath = filepath.Join(drive, "temp", prefix+uuid.New().String()[0:8]) } } - } else if path := os.Getenv(constants.CacheEnvVarName); path != "" { - cachePath = path - } else { - cachePath = configdir.New(constants.InternalConfigNamespace, "").QueryCacheFolder().Path - if runtime.GOOS == "windows" { - // Explicitly append "cache" dir as the cachedir on Windows is the same as the local appdata dir (conflicts with config) - cachePath = filepath.Join(cachePath, "cache") - } + return cachePath + + } + + if path := os.Getenv(constants.CacheEnvVarName); path != "" { + return path } - return cachePath + return filepath.Join(BaseCachePath(), relativeCachePath()) } func GlobalBinDir() string { @@ -127,11 +116,7 @@ func GlobalBinDir() string { // InstallSource returns the installation source of the State Tool func InstallSource() (string, error) { - path, err := AppDataPath() - if err != nil { - return "", fmt.Errorf("Could not detect AppDataPath: %w", err) - } - + path := AppDataPath() installFilePath := filepath.Join(path, constants.InstallSourceFile) installFileData, err := os.ReadFile(installFilePath) if err != nil { diff --git a/internal/installation/storage/storage_darwin.go b/internal/installation/storage/storage_darwin.go new file mode 100644 index 0000000000..2dde6e1630 --- /dev/null +++ b/internal/installation/storage/storage_darwin.go @@ -0,0 +1,13 @@ +package storage + +import ( + "path/filepath" +) + +func BaseAppDataPath() string { + return filepath.Join(homeDir, "Library", "Application Support") +} + +func BaseCachePath() string { + return filepath.Join(homeDir, "Library", "Caches") +} diff --git a/internal/installation/storage/storage_test.go b/internal/installation/storage/storage_test.go index d0887b59cf..6ecef193ad 100644 --- a/internal/installation/storage/storage_test.go +++ b/internal/installation/storage/storage_test.go @@ -4,13 +4,10 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) func Test_AppDataPath(t *testing.T) { - path1, err := AppDataPath() - require.NoError(t, err) - path2, err := AppDataPath() - require.NoError(t, err) + path1 := AppDataPath() + path2 := AppDataPath() assert.Equal(t, path1, path2) } diff --git a/internal/installation/storage/storage_windows.go b/internal/installation/storage/storage_windows.go new file mode 100644 index 0000000000..4a4986440a --- /dev/null +++ b/internal/installation/storage/storage_windows.go @@ -0,0 +1,22 @@ +package storage + +import ( + "os" + "path/filepath" +) + +func BaseAppDataPath() string { + if appData := os.Getenv("APPDATA"); appData != "" { + return appData + } + + return filepath.Join(homeDir, "AppData", "Roaming") +} + +func BaseCachePath() string { + if cache := os.Getenv("LOCALAPPDATA"); cache != "" { + return cache + } + + return filepath.Join(homeDir, "AppData", "Local", "cache") +} diff --git a/internal/installation/storage/storage_xdg.go b/internal/installation/storage/storage_xdg.go new file mode 100644 index 0000000000..1efefd0649 --- /dev/null +++ b/internal/installation/storage/storage_xdg.go @@ -0,0 +1,25 @@ +//go:build !windows && !darwin +// +build !windows,!darwin + +package storage + +import ( + "os" + "path/filepath" +) + +func BaseAppDataPath() string { + if os.Getenv("XDG_CONFIG_HOME") != "" { + return os.Getenv("XDG_CONFIG_HOME") + } + + return filepath.Join(homeDir, ".config") +} + +func BaseCachePath() string { + if os.Getenv("XDG_CACHE_HOME") != "" { + return os.Getenv("XDG_CACHE_HOME") + } + + return filepath.Join(homeDir, ".cache") +} diff --git a/internal/logging/defaults.go b/internal/logging/defaults.go index 68a5c897bf..395e41c446 100644 --- a/internal/logging/defaults.go +++ b/internal/logging/defaults.go @@ -83,13 +83,7 @@ func init() { defer func() { handlePanics(recover()) }() // Set up datadir - var err error - datadir, err = storage.AppDataPath() - if err != nil { - log.SetOutput(os.Stderr) - Error("Could not detect AppData dir: %v", err) - return - } + datadir = storage.AppDataPath() // Set up handler timestamp = time.Now().UnixNano() diff --git a/internal/osutils/user/user.go b/internal/osutils/user/user.go index 1b485d0fab..e8238c3063 100644 --- a/internal/osutils/user/user.go +++ b/internal/osutils/user/user.go @@ -2,8 +2,10 @@ package user import ( "os" + "os/user" "github.com/ActiveState/cli/internal/constants" + "github.com/ActiveState/cli/internal/errs" ) // HomeDirNotFoundError is an error that implements the ErrorLocalier and ErrorInput interfaces @@ -40,9 +42,16 @@ func HomeDir() (string, error) { if dir := os.Getenv(constants.HomeEnvVarName); dir != "" { return dir, nil } - dir, err := os.UserHomeDir() - if err != nil { - return "", &HomeDirNotFoundError{err} + + u, err := user.Current() + if err == nil { + return u.HomeDir, nil + } + + // If we can't get the current user, try to get the home dir from the os + dir, err2 := os.UserHomeDir() + if err2 != nil { + return "", &HomeDirNotFoundError{errs.Pack(err, err2)} } return dir, nil } diff --git a/internal/osutils/user/user_test.go b/internal/osutils/user/user_test.go index 27c5010622..9a487750a4 100644 --- a/internal/osutils/user/user_test.go +++ b/internal/osutils/user/user_test.go @@ -39,6 +39,5 @@ func TestNoHome(t *testing.T) { defer func() { os.Setenv("USERPROFILE", osHomeDir) }() } _, err = HomeDir() - assert.Error(t, err) - assert.Contains(t, err.Error(), "HOME environment variable is unset") + assert.NoError(t, err) } diff --git a/internal/runners/cve/cve.go b/internal/runners/cve/cve.go index 96c24a24f4..a46da01c26 100644 --- a/internal/runners/cve/cve.go +++ b/internal/runners/cve/cve.go @@ -9,6 +9,7 @@ import ( "github.com/ActiveState/cli/internal/errs" "github.com/ActiveState/cli/internal/locale" + "github.com/ActiveState/cli/internal/logging" "github.com/ActiveState/cli/internal/output" "github.com/ActiveState/cli/internal/output/renderers" "github.com/ActiveState/cli/internal/primer" @@ -60,6 +61,12 @@ type cveOutput struct { } func (r *Cve) Run(params *Params) error { + defer func() { + if rc := recover(); rc != nil { + logging.Error("Recovered from panic: %v", rc) + fmt.Printf("Recovered from panic: %v\n", rc) + } + }() if !params.Namespace.IsValid() && r.proj == nil { return rationalize.ErrNoProject } @@ -71,7 +78,7 @@ func (r *Cve) Run(params *Params) error { ) } - vulnerabilities, err := r.fetchVulnerabilities(*params.Namespace) + vulnerabilities, err := r.fetchVulnerabilities(params.Namespace) if err != nil { var errProjectNotFound *model.ErrProjectNotFound if errors.As(err, &errProjectNotFound) { @@ -101,7 +108,7 @@ func (r *Cve) Run(params *Params) error { return nil } -func (r *Cve) fetchVulnerabilities(namespaceOverride project.Namespaced) (*medmodel.CommitVulnerabilities, error) { +func (r *Cve) fetchVulnerabilities(namespaceOverride *project.Namespaced) (*medmodel.CommitVulnerabilities, error) { if namespaceOverride.IsValid() && namespaceOverride.CommitID == nil { resp, err := model.FetchProjectVulnerabilities(r.auth, namespaceOverride.Owner, namespaceOverride.Project) if err != nil { @@ -135,9 +142,6 @@ type SeverityCountOutput struct { } func (rd *cveOutput) MarshalOutput(format output.Format) interface{} { - if format != output.PlainFormatName { - return rd.data - } ri := &CveInfo{ fmt.Sprintf("[ACTIONABLE]%s[/RESET]", rd.data.Project), rd.data.CommitID, diff --git a/internal/svcctl/svcctl.go b/internal/svcctl/svcctl.go index 58933826f9..0887749d85 100644 --- a/internal/svcctl/svcctl.go +++ b/internal/svcctl/svcctl.go @@ -7,21 +7,19 @@ package svcctl import ( "context" "errors" - "fmt" "io" "os" - "path/filepath" "time" "github.com/ActiveState/cli/internal/constants" "github.com/ActiveState/cli/internal/errs" "github.com/ActiveState/cli/internal/fileutils" "github.com/ActiveState/cli/internal/installation" + "github.com/ActiveState/cli/internal/installation/storage" "github.com/ActiveState/cli/internal/ipc" "github.com/ActiveState/cli/internal/locale" "github.com/ActiveState/cli/internal/logging" "github.com/ActiveState/cli/internal/osutils" - "github.com/ActiveState/cli/internal/output" "github.com/ActiveState/cli/internal/profile" ) @@ -45,25 +43,30 @@ type IPCommunicator interface { SockPath() *ipc.SockPath } +type Outputer interface { + Notice(interface{}) +} + func NewIPCSockPathFromGlobals() *ipc.SockPath { - subdir := fmt.Sprintf("%s-%s", constants.CommandName, "ipc") - rootDir := filepath.Join(os.TempDir(), subdir) + rootDir := storage.AppDataPath() if os.Getenv(constants.ServiceSockDir) != "" { rootDir = os.Getenv(constants.ServiceSockDir) } - return &ipc.SockPath{ + sp := &ipc.SockPath{ RootDir: rootDir, AppName: constants.CommandName, AppChannel: constants.ChannelName, } + + return sp } func NewDefaultIPCClient() *ipc.Client { return ipc.NewClient(NewIPCSockPathFromGlobals()) } -func EnsureExecStartedAndLocateHTTP(ipComm IPCommunicator, exec, argText string, out output.Outputer) (addr string, err error) { +func EnsureExecStartedAndLocateHTTP(ipComm IPCommunicator, exec, argText string, out Outputer) (addr string, err error) { defer profile.Measure("svcctl:EnsureExecStartedAndLocateHTTP", time.Now()) addr, err = LocateHTTP(ipComm) @@ -91,7 +94,7 @@ func EnsureExecStartedAndLocateHTTP(ipComm IPCommunicator, exec, argText string, return addr, nil } -func EnsureStartedAndLocateHTTP(argText string, out output.Outputer) (addr string, err error) { +func EnsureStartedAndLocateHTTP(argText string, out Outputer) (addr string, err error) { svcExec, err := installation.ServiceExec() if err != nil { return "", locale.WrapError(err, "err_service_exec") @@ -146,7 +149,7 @@ func StopServer(ipComm IPCommunicator) error { return nil } -func startAndWait(ctx context.Context, ipComm IPCommunicator, exec, argText string, out output.Outputer) error { +func startAndWait(ctx context.Context, ipComm IPCommunicator, exec, argText string, out Outputer) error { defer profile.Measure("svcmanager:Start", time.Now()) if !fileutils.FileExists(exec) { @@ -174,7 +177,7 @@ var ( waitTimeoutL10nKey = "svcctl_wait_timeout" ) -func waitUp(ctx context.Context, ipComm IPCommunicator, out output.Outputer, debugInfo *debugData) error { +func waitUp(ctx context.Context, ipComm IPCommunicator, out Outputer, debugInfo *debugData) error { debugInfo.startWait() defer debugInfo.stopWait() diff --git a/internal/updater/updater.go b/internal/updater/updater.go index bd2a661cb0..3560ae1557 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -207,10 +207,7 @@ func (u *UpdateInstaller) InstallBlocking(installTargetPath string, args ...stri return errs.Wrap(err, "Could not check if State Tool was installed as admin") } - appdata, err := storage.AppDataPath() - if err != nil { - return errs.Wrap(err, "Could not detect appdata path") - } + appdata := storage.AppDataPath() // Protect against multiple updates happening simultaneously lockFile := filepath.Join(appdata, "install.lock") diff --git a/pkg/platform/model/svc.go b/pkg/platform/model/svc.go index d6e4af9ae3..43e7bcebc3 100644 --- a/pkg/platform/model/svc.go +++ b/pkg/platform/model/svc.go @@ -181,7 +181,6 @@ func (m *SvcModel) GetProcessesInUse(ctx context.Context, execDir string) ([]*gr // Note we respond with mono_models.JWT here for compatibility and to minimize the changeset at time of implementation. // We can revisit this in the future. func (m *SvcModel) GetJWT(ctx context.Context) (*mono_models.JWT, error) { - logging.Debug("Checking for GetJWT") defer profile.Measure("svc:GetJWT", time.Now()) r := request.NewJWTRequest() diff --git a/pkg/projectfile/projectfile.go b/pkg/projectfile/projectfile.go index e3ebcc5c62..b4b105b29b 100644 --- a/pkg/projectfile/projectfile.go +++ b/pkg/projectfile/projectfile.go @@ -1129,7 +1129,7 @@ func AddLockInfo(projectFilePath, branch, version string) error { projectRegex := regexp.MustCompile(fmt.Sprintf("(?m:(^project:\\s*%s))", ProjectURLRe)) lockString := fmt.Sprintf("%s@%s", branch, version) - lockUpdate := []byte(fmt.Sprintf("${1}\nlock: %s", lockString)) + lockUpdate := []byte(fmt.Sprintf(`${1}\nlock: %s`, lockString)) data, err = os.ReadFile(projectFilePath) if err != nil { diff --git a/scripts/ci/parallelize/parallelize.go b/scripts/ci/parallelize/parallelize.go index f963523645..fe135e2b63 100644 --- a/scripts/ci/parallelize/parallelize.go +++ b/scripts/ci/parallelize/parallelize.go @@ -68,10 +68,7 @@ func run() error { } func jobDir() string { - path, err := storage.AppDataPath() - if err != nil { - panic(err) - } + path := storage.AppDataPath() path = filepath.Join(path, "jobs") if err := fileutils.MkdirUnlessExists(path); err != nil { diff --git a/scripts/ci/payload-generator/main.go b/scripts/ci/payload-generator/main.go index d264c37853..d50e1bd235 100644 --- a/scripts/ci/payload-generator/main.go +++ b/scripts/ci/payload-generator/main.go @@ -71,6 +71,7 @@ func generatePayload(inDir, outDir, binDir, channel, version string) error { filepath.Join(inDir, constants.StateCmd+osutils.ExeExtension): binDir, filepath.Join(inDir, constants.StateSvcCmd+osutils.ExeExtension): binDir, filepath.Join(inDir, constants.StateExecutorCmd+osutils.ExeExtension): binDir, + filepath.Join(inDir, constants.StateMCPCmd+osutils.ExeExtension): binDir, } if err := copyFiles(files); err != nil { return fmt.Errorf(emsg, err) @@ -107,6 +108,7 @@ func copyFiles(files map[string]string) error { dest := filepath.Join(target, filepath.Base(src)) if err := fileutils.CopyFile(src, dest); err != nil { + fmt.Printf("Files in %s: %+v\n", filepath.Dir(src), fileutils.ListFilesUnsafe(filepath.Dir(src))) return fmt.Errorf("copy files (%s to %s): %w", src, target, err) } } diff --git a/test/integration/performance_svc_int_test.go b/test/integration/performance_svc_int_test.go index 5d05fa7615..46f0ecf20d 100644 --- a/test/integration/performance_svc_int_test.go +++ b/test/integration/performance_svc_int_test.go @@ -41,9 +41,7 @@ func (suite *PerformanceIntegrationTestSuite) TestSvcPerformance() { // This integration test is a bit special because it bypasses the spawning logic // so in order to get the right log files when debugging we manually provide the config dir - var err error - ts.Dirs.Config, err = storage.AppDataPath() - suite.Require().NoError(err) + ts.Dirs.Config = storage.AppDataPath() ipcClient := svcctl.NewDefaultIPCClient() var svcPort string diff --git a/vendor/github.com/shibukawa/configdir/LICENSE b/vendor/github.com/mark3labs/mcp-go/LICENSE similarity index 95% rename from vendor/github.com/shibukawa/configdir/LICENSE rename to vendor/github.com/mark3labs/mcp-go/LICENSE index b20af456a1..3d48435454 100644 --- a/vendor/github.com/shibukawa/configdir/LICENSE +++ b/vendor/github.com/mark3labs/mcp-go/LICENSE @@ -1,6 +1,6 @@ -The MIT License (MIT) +MIT License -Copyright (c) 2016 shibukawa +Copyright (c) 2024 Anthropic, PBC Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/vendor/github.com/mark3labs/mcp-go/client/client.go b/vendor/github.com/mark3labs/mcp-go/client/client.go new file mode 100644 index 0000000000..1d3cb1051e --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/client/client.go @@ -0,0 +1,84 @@ +// Package client provides MCP (Model Control Protocol) client implementations. +package client + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// MCPClient represents an MCP client interface +type MCPClient interface { + // Initialize sends the initial connection request to the server + Initialize( + ctx context.Context, + request mcp.InitializeRequest, + ) (*mcp.InitializeResult, error) + + // Ping checks if the server is alive + Ping(ctx context.Context) error + + // ListResources requests a list of available resources from the server + ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, + ) (*mcp.ListResourcesResult, error) + + // ListResourceTemplates requests a list of available resource templates from the server + ListResourceTemplates( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, + ) (*mcp.ListResourceTemplatesResult, + error) + + // ReadResource reads a specific resource from the server + ReadResource( + ctx context.Context, + request mcp.ReadResourceRequest, + ) (*mcp.ReadResourceResult, error) + + // Subscribe requests notifications for changes to a specific resource + Subscribe(ctx context.Context, request mcp.SubscribeRequest) error + + // Unsubscribe cancels notifications for a specific resource + Unsubscribe(ctx context.Context, request mcp.UnsubscribeRequest) error + + // ListPrompts requests a list of available prompts from the server + ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, + ) (*mcp.ListPromptsResult, error) + + // GetPrompt retrieves a specific prompt from the server + GetPrompt( + ctx context.Context, + request mcp.GetPromptRequest, + ) (*mcp.GetPromptResult, error) + + // ListTools requests a list of available tools from the server + ListTools( + ctx context.Context, + request mcp.ListToolsRequest, + ) (*mcp.ListToolsResult, error) + + // CallTool invokes a specific tool on the server + CallTool( + ctx context.Context, + request mcp.CallToolRequest, + ) (*mcp.CallToolResult, error) + + // SetLevel sets the logging level for the server + SetLevel(ctx context.Context, request mcp.SetLevelRequest) error + + // Complete requests completion options for a given argument + Complete( + ctx context.Context, + request mcp.CompleteRequest, + ) (*mcp.CompleteResult, error) + + // Close client connection and cleanup resources + Close() error + + // OnNotification registers a handler for notifications + OnNotification(handler func(notification mcp.JSONRPCNotification)) +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/sse.go b/vendor/github.com/mark3labs/mcp-go/client/sse.go new file mode 100644 index 0000000000..cf4a1028e0 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/client/sse.go @@ -0,0 +1,588 @@ +package client + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// SSEMCPClient implements the MCPClient interface using Server-Sent Events (SSE). +// It maintains a persistent HTTP connection to receive server-pushed events +// while sending requests over regular HTTP POST calls. The client handles +// automatic reconnection and message routing between requests and responses. +type SSEMCPClient struct { + baseURL *url.URL + endpoint *url.URL + httpClient *http.Client + requestID atomic.Int64 + responses map[int64]chan RPCResponse + mu sync.RWMutex + done chan struct{} + initialized bool + notifications []func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + endpointChan chan struct{} + capabilities mcp.ServerCapabilities + headers map[string]string + sseReadTimeout time.Duration +} + +type ClientOption func(*SSEMCPClient) + +func WithHeaders(headers map[string]string) ClientOption { + return func(sc *SSEMCPClient) { + sc.headers = headers + } +} + +func WithSSEReadTimeout(timeout time.Duration) ClientOption { + return func(sc *SSEMCPClient) { + sc.sseReadTimeout = timeout + } +} + +// NewSSEMCPClient creates a new SSE-based MCP client with the given base URL. +// Returns an error if the URL is invalid. +func NewSSEMCPClient(baseURL string, options ...ClientOption) (*SSEMCPClient, error) { + parsedURL, err := url.Parse(baseURL) + if err != nil { + return nil, fmt.Errorf("invalid URL: %w", err) + } + + smc := &SSEMCPClient{ + baseURL: parsedURL, + httpClient: &http.Client{}, + responses: make(map[int64]chan RPCResponse), + done: make(chan struct{}), + endpointChan: make(chan struct{}), + sseReadTimeout: 30 * time.Second, + headers: make(map[string]string), + } + + for _, opt := range options { + opt(smc) + } + + return smc, nil +} + +// Start initiates the SSE connection to the server and waits for the endpoint information. +// Returns an error if the connection fails or times out waiting for the endpoint. +func (c *SSEMCPClient) Start(ctx context.Context) error { + + req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil) + + if err != nil { + + return fmt.Errorf("failed to create request: %w", err) + + } + + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("Cache-Control", "no-cache") + req.Header.Set("Connection", "keep-alive") + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to connect to SSE stream: %w", err) + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + go c.readSSE(resp.Body) + + // Wait for the endpoint to be received + + select { + case <-c.endpointChan: + // Endpoint received, proceed + case <-ctx.Done(): + return fmt.Errorf("context cancelled while waiting for endpoint") + case <-time.After(30 * time.Second): // Add a timeout + return fmt.Errorf("timeout waiting for endpoint") + } + + return nil +} + +// readSSE continuously reads the SSE stream and processes events. +// It runs until the connection is closed or an error occurs. +func (c *SSEMCPClient) readSSE(reader io.ReadCloser) { + defer reader.Close() + + br := bufio.NewReader(reader) + var event, data string + + ctx, cancel := context.WithTimeout(context.Background(), c.sseReadTimeout) + defer cancel() + + for { + select { + case <-ctx.Done(): + return + default: + line, err := br.ReadString('\n') + if err != nil { + if err == io.EOF { + // Process any pending event before exit + if event != "" && data != "" { + c.handleSSEEvent(event, data) + } + break + } + select { + case <-c.done: + return + default: + fmt.Printf("SSE stream error: %v\n", err) + return + } + } + + // Remove only newline markers + line = strings.TrimRight(line, "\r\n") + if line == "" { + // Empty line means end of event + if event != "" && data != "" { + c.handleSSEEvent(event, data) + event = "" + data = "" + } + continue + } + + if strings.HasPrefix(line, "event:") { + event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + data = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } + } +} + +// handleSSEEvent processes SSE events based on their type. +// Handles 'endpoint' events for connection setup and 'message' events for JSON-RPC communication. +func (c *SSEMCPClient) handleSSEEvent(event, data string) { + switch event { + case "endpoint": + endpoint, err := c.baseURL.Parse(data) + if err != nil { + fmt.Printf("Error parsing endpoint URL: %v\n", err) + return + } + if endpoint.Host != c.baseURL.Host { + fmt.Printf("Endpoint origin does not match connection origin\n") + return + } + c.endpoint = endpoint + close(c.endpointChan) + + case "message": + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + ID *int64 `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal([]byte(data), &baseMessage); err != nil { + fmt.Printf("Error unmarshaling message: %v\n", err) + return + } + + // Handle notification + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(data), ¬ification); err != nil { + return + } + c.notifyMu.RLock() + for _, handler := range c.notifications { + handler(notification) + } + c.notifyMu.RUnlock() + return + } + + c.mu.RLock() + ch, ok := c.responses[*baseMessage.ID] + c.mu.RUnlock() + + if ok { + if baseMessage.Error != nil { + ch <- RPCResponse{ + Error: &baseMessage.Error.Message, + } + } else { + ch <- RPCResponse{ + Response: &baseMessage.Result, + } + } + c.mu.Lock() + delete(c.responses, *baseMessage.ID) + c.mu.Unlock() + } + } +} + +// OnNotification registers a handler function to be called when notifications are received. +// Multiple handlers can be registered and will be called in the order they were added. +func (c *SSEMCPClient) OnNotification( + handler func(notification mcp.JSONRPCNotification), +) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.notifications = append(c.notifications, handler) +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// Returns the raw JSON response message or an error if the request fails. +func (c *SSEMCPClient) sendRequest( + ctx context.Context, + method string, + params interface{}, +) (*json.RawMessage, error) { + if !c.initialized && method != "initialize" { + return nil, fmt.Errorf("client not initialized") + } + + if c.endpoint == nil { + return nil, fmt.Errorf("endpoint not received") + } + + id := c.requestID.Add(1) + + request := mcp.JSONRPCRequest{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Request: mcp.Request{ + Method: method, + }, + Params: params, + } + + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + responseChan := make(chan RPCResponse, 1) + c.mu.Lock() + c.responses[id] = responseChan + c.mu.Unlock() + + req, err := http.NewRequestWithContext( + ctx, + "POST", + c.endpoint.String(), + bytes.NewReader(requestBytes), + ) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + // set custom http headers + for k, v := range c.headers { + req.Header.Set(k, v) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && + resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf( + "request failed with status %d: %s", + resp.StatusCode, + body, + ) + } + + select { + case <-ctx.Done(): + c.mu.Lock() + delete(c.responses, id) + c.mu.Unlock() + return nil, ctx.Err() + case response := <-responseChan: + if response.Error != nil { + return nil, errors.New(*response.Error) + } + return response.Response, nil + } +} + +func (c *SSEMCPClient) Initialize( + ctx context.Context, + request mcp.InitializeRequest, +) (*mcp.InitializeResult, error) { + // Ensure we send a params object with all required fields + params := struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo mcp.Implementation `json:"clientInfo"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + }{ + ProtocolVersion: request.Params.ProtocolVersion, + ClientInfo: request.Params.ClientInfo, + Capabilities: request.Params.Capabilities, // Will be empty struct if not set + } + + response, err := c.sendRequest(ctx, "initialize", params) + if err != nil { + return nil, err + } + + var result mcp.InitializeResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Store capabilities + c.capabilities = result.Capabilities + + // Send initialized notification + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: "notifications/initialized", + }, + } + + notificationBytes, err := json.Marshal(notification) + if err != nil { + return nil, fmt.Errorf( + "failed to marshal initialized notification: %w", + err, + ) + } + + req, err := http.NewRequestWithContext( + ctx, + "POST", + c.endpoint.String(), + bytes.NewReader(notificationBytes), + ) + if err != nil { + return nil, fmt.Errorf("failed to create notification request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf( + "failed to send initialized notification: %w", + err, + ) + } + resp.Body.Close() + + c.initialized = true + return &result, nil +} + +func (c *SSEMCPClient) Ping(ctx context.Context) error { + _, err := c.sendRequest(ctx, "ping", nil) + return err +} + +func (c *SSEMCPClient) ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, +) (*mcp.ListResourcesResult, error) { + response, err := c.sendRequest(ctx, "resources/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListResourcesResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *SSEMCPClient) ListResourceTemplates( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, +) (*mcp.ListResourceTemplatesResult, error) { + response, err := c.sendRequest( + ctx, + "resources/templates/list", + request.Params, + ) + if err != nil { + return nil, err + } + + var result mcp.ListResourceTemplatesResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *SSEMCPClient) ReadResource( + ctx context.Context, + request mcp.ReadResourceRequest, +) (*mcp.ReadResourceResult, error) { + response, err := c.sendRequest(ctx, "resources/read", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseReadResourceResult(response) +} + +func (c *SSEMCPClient) Subscribe( + ctx context.Context, + request mcp.SubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) + return err +} + +func (c *SSEMCPClient) Unsubscribe( + ctx context.Context, + request mcp.UnsubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) + return err +} + +func (c *SSEMCPClient) ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, error) { + response, err := c.sendRequest(ctx, "prompts/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListPromptsResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *SSEMCPClient) GetPrompt( + ctx context.Context, + request mcp.GetPromptRequest, +) (*mcp.GetPromptResult, error) { + response, err := c.sendRequest(ctx, "prompts/get", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseGetPromptResult(response) +} + +func (c *SSEMCPClient) ListTools( + ctx context.Context, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, error) { + response, err := c.sendRequest(ctx, "tools/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListToolsResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *SSEMCPClient) CallTool( + ctx context.Context, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, error) { + response, err := c.sendRequest(ctx, "tools/call", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseCallToolResult(response) +} + +func (c *SSEMCPClient) SetLevel( + ctx context.Context, + request mcp.SetLevelRequest, +) error { + _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) + return err +} + +func (c *SSEMCPClient) Complete( + ctx context.Context, + request mcp.CompleteRequest, +) (*mcp.CompleteResult, error) { + response, err := c.sendRequest(ctx, "completion/complete", request.Params) + if err != nil { + return nil, err + } + + var result mcp.CompleteResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +// Helper methods + +// GetEndpoint returns the current endpoint URL for the SSE connection. +func (c *SSEMCPClient) GetEndpoint() *url.URL { + return c.endpoint +} + +// Close shuts down the SSE client connection and cleans up any pending responses. +// Returns an error if the shutdown process fails. +func (c *SSEMCPClient) Close() error { + select { + case <-c.done: + return nil // Already closed + default: + close(c.done) + } + + // Clean up any pending responses + c.mu.Lock() + for _, ch := range c.responses { + close(ch) + } + c.responses = make(map[int64]chan RPCResponse) + c.mu.Unlock() + + return nil +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/stdio.go b/vendor/github.com/mark3labs/mcp-go/client/stdio.go new file mode 100644 index 0000000000..8e0845dca6 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/client/stdio.go @@ -0,0 +1,457 @@ +package client + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "os/exec" + "sync" + "sync/atomic" + + "github.com/mark3labs/mcp-go/mcp" +) + +// StdioMCPClient implements the MCPClient interface using stdio communication. +// It launches a subprocess and communicates with it via standard input/output streams +// using JSON-RPC messages. The client handles message routing between requests and +// responses, and supports asynchronous notifications. +type StdioMCPClient struct { + cmd *exec.Cmd + stdin io.WriteCloser + stdout *bufio.Reader + stderr io.ReadCloser + requestID atomic.Int64 + responses map[int64]chan RPCResponse + mu sync.RWMutex + done chan struct{} + initialized bool + notifications []func(mcp.JSONRPCNotification) + notifyMu sync.RWMutex + capabilities mcp.ServerCapabilities +} + +// NewStdioMCPClient creates a new stdio-based MCP client that communicates with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Returns an error if the subprocess cannot be started or the pipes cannot be created. +func NewStdioMCPClient( + command string, + env []string, + args ...string, +) (*StdioMCPClient, error) { + cmd := exec.Command(command, args...) + + mergedEnv := os.Environ() + mergedEnv = append(mergedEnv, env...) + + cmd.Env = mergedEnv + + stdin, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdin pipe: %w", err) + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stdout pipe: %w", err) + } + + stderr, err := cmd.StderrPipe() + if err != nil { + return nil, fmt.Errorf("failed to create stderr pipe: %w", err) + } + + client := &StdioMCPClient{ + cmd: cmd, + stdin: stdin, + stderr: stderr, + stdout: bufio.NewReader(stdout), + responses: make(map[int64]chan RPCResponse), + done: make(chan struct{}), + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("failed to start command: %w", err) + } + + // Start reading responses in a goroutine and wait for it to be ready + ready := make(chan struct{}) + go func() { + close(ready) + client.readResponses() + }() + <-ready + + return client, nil +} + +// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit. +// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate. +func (c *StdioMCPClient) Close() error { + close(c.done) + if err := c.stdin.Close(); err != nil { + return fmt.Errorf("failed to close stdin: %w", err) + } + if err := c.stderr.Close(); err != nil { + return fmt.Errorf("failed to close stderr: %w", err) + } + return c.cmd.Wait() +} + +// Stderr returns a reader for the stderr output of the subprocess. +// This can be used to capture error messages or logs from the subprocess. +func (c *StdioMCPClient) Stderr() io.Reader { + return c.stderr +} + +// OnNotification registers a handler function to be called when notifications are received. +// Multiple handlers can be registered and will be called in the order they were added. +func (c *StdioMCPClient) OnNotification( + handler func(notification mcp.JSONRPCNotification), +) { + c.notifyMu.Lock() + defer c.notifyMu.Unlock() + c.notifications = append(c.notifications, handler) +} + +// readResponses continuously reads and processes responses from the server's stdout. +// It handles both responses to requests and notifications, routing them appropriately. +// Runs until the done channel is closed or an error occurs reading from stdout. +func (c *StdioMCPClient) readResponses() { + for { + select { + case <-c.done: + return + default: + line, err := c.stdout.ReadString('\n') + if err != nil { + if err != io.EOF { + fmt.Printf("Error reading response: %v\n", err) + } + return + } + + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + ID *int64 `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { + continue + } + + // Handle notification + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal([]byte(line), ¬ification); err != nil { + continue + } + c.notifyMu.RLock() + for _, handler := range c.notifications { + handler(notification) + } + c.notifyMu.RUnlock() + continue + } + + c.mu.RLock() + ch, ok := c.responses[*baseMessage.ID] + c.mu.RUnlock() + + if ok { + if baseMessage.Error != nil { + ch <- RPCResponse{ + Error: &baseMessage.Error.Message, + } + } else { + ch <- RPCResponse{ + Response: &baseMessage.Result, + } + } + c.mu.Lock() + delete(c.responses, *baseMessage.ID) + c.mu.Unlock() + } + } + } +} + +// sendRequest sends a JSON-RPC request to the server and waits for a response. +// It creates a unique request ID, sends the request over stdin, and waits for +// the corresponding response or context cancellation. +// Returns the raw JSON response message or an error if the request fails. +func (c *StdioMCPClient) sendRequest( + ctx context.Context, + method string, + params interface{}, +) (*json.RawMessage, error) { + if !c.initialized && method != "initialize" { + return nil, fmt.Errorf("client not initialized") + } + + id := c.requestID.Add(1) + + // Create the complete request structure + request := mcp.JSONRPCRequest{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Request: mcp.Request{ + Method: method, + }, + Params: params, + } + + responseChan := make(chan RPCResponse, 1) + c.mu.Lock() + c.responses[id] = responseChan + c.mu.Unlock() + + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := c.stdin.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write request: %w", err) + } + + select { + case <-ctx.Done(): + c.mu.Lock() + delete(c.responses, id) + c.mu.Unlock() + return nil, ctx.Err() + case response := <-responseChan: + if response.Error != nil { + return nil, errors.New(*response.Error) + } + return response.Response, nil + } +} + +func (c *StdioMCPClient) Ping(ctx context.Context) error { + _, err := c.sendRequest(ctx, "ping", nil) + return err +} + +func (c *StdioMCPClient) Initialize( + ctx context.Context, + request mcp.InitializeRequest, +) (*mcp.InitializeResult, error) { + // This structure ensures Capabilities is always included in JSON + params := struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo mcp.Implementation `json:"clientInfo"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + }{ + ProtocolVersion: request.Params.ProtocolVersion, + ClientInfo: request.Params.ClientInfo, + Capabilities: request.Params.Capabilities, // Will be empty struct if not set + } + + response, err := c.sendRequest(ctx, "initialize", params) + if err != nil { + return nil, err + } + + var result mcp.InitializeResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // Store capabilities + c.capabilities = result.Capabilities + + // Send initialized notification + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: "notifications/initialized", + }, + } + + notificationBytes, err := json.Marshal(notification) + if err != nil { + return nil, fmt.Errorf( + "failed to marshal initialized notification: %w", + err, + ) + } + notificationBytes = append(notificationBytes, '\n') + + if _, err := c.stdin.Write(notificationBytes); err != nil { + return nil, fmt.Errorf( + "failed to send initialized notification: %w", + err, + ) + } + + c.initialized = true + return &result, nil +} + +func (c *StdioMCPClient) ListResources( + ctx context.Context, + request mcp.ListResourcesRequest, +) (*mcp. + ListResourcesResult, error) { + response, err := c.sendRequest( + ctx, + "resources/list", + request.Params, + ) + if err != nil { + return nil, err + } + + var result mcp.ListResourcesResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *StdioMCPClient) ListResourceTemplates( + ctx context.Context, + request mcp.ListResourceTemplatesRequest, +) (*mcp. + ListResourceTemplatesResult, error) { + response, err := c.sendRequest( + ctx, + "resources/templates/list", + request.Params, + ) + if err != nil { + return nil, err + } + + var result mcp.ListResourceTemplatesResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *StdioMCPClient) ReadResource( + ctx context.Context, + request mcp.ReadResourceRequest, +) (*mcp.ReadResourceResult, + error) { + response, err := c.sendRequest(ctx, "resources/read", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseReadResourceResult(response) +} + +func (c *StdioMCPClient) Subscribe( + ctx context.Context, + request mcp.SubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/subscribe", request.Params) + return err +} + +func (c *StdioMCPClient) Unsubscribe( + ctx context.Context, + request mcp.UnsubscribeRequest, +) error { + _, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params) + return err +} + +func (c *StdioMCPClient) ListPrompts( + ctx context.Context, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, error) { + response, err := c.sendRequest(ctx, "prompts/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListPromptsResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *StdioMCPClient) GetPrompt( + ctx context.Context, + request mcp.GetPromptRequest, +) (*mcp.GetPromptResult, error) { + response, err := c.sendRequest(ctx, "prompts/get", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseGetPromptResult(response) +} + +func (c *StdioMCPClient) ListTools( + ctx context.Context, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, error) { + response, err := c.sendRequest(ctx, "tools/list", request.Params) + if err != nil { + return nil, err + } + + var result mcp.ListToolsResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} + +func (c *StdioMCPClient) CallTool( + ctx context.Context, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, error) { + response, err := c.sendRequest(ctx, "tools/call", request.Params) + if err != nil { + return nil, err + } + + return mcp.ParseCallToolResult(response) +} + +func (c *StdioMCPClient) SetLevel( + ctx context.Context, + request mcp.SetLevelRequest, +) error { + _, err := c.sendRequest(ctx, "logging/setLevel", request.Params) + return err +} + +func (c *StdioMCPClient) Complete( + ctx context.Context, + request mcp.CompleteRequest, +) (*mcp.CompleteResult, error) { + response, err := c.sendRequest(ctx, "completion/complete", request.Params) + if err != nil { + return nil, err + } + + var result mcp.CompleteResult + if err := json.Unmarshal(*response, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &result, nil +} diff --git a/vendor/github.com/mark3labs/mcp-go/client/types.go b/vendor/github.com/mark3labs/mcp-go/client/types.go new file mode 100644 index 0000000000..4402bd0240 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/client/types.go @@ -0,0 +1,8 @@ +package client + +import "encoding/json" + +type RPCResponse struct { + Error *string + Response *json.RawMessage +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go b/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go new file mode 100644 index 0000000000..bc12a72976 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go @@ -0,0 +1,163 @@ +package mcp + +/* Prompts */ + +// ListPromptsRequest is sent from the client to request a list of prompts and +// prompt templates the server has. +type ListPromptsRequest struct { + PaginatedRequest +} + +// ListPromptsResult is the server's response to a prompts/list request from +// the client. +type ListPromptsResult struct { + PaginatedResult + Prompts []Prompt `json:"prompts"` +} + +// GetPromptRequest is used by the client to get a prompt provided by the +// server. +type GetPromptRequest struct { + Request + Params struct { + // The name of the prompt or prompt template. + Name string `json:"name"` + // Arguments to use for templating the prompt. + Arguments map[string]string `json:"arguments,omitempty"` + } `json:"params"` +} + +// GetPromptResult is the server's response to a prompts/get request from the +// client. +type GetPromptResult struct { + Result + // An optional description for the prompt. + Description string `json:"description,omitempty"` + Messages []PromptMessage `json:"messages"` +} + +// Prompt represents a prompt or prompt template that the server offers. +// If Arguments is non-nil and non-empty, this indicates the prompt is a template +// that requires argument values to be provided when calling prompts/get. +// If Arguments is nil or empty, this is a static prompt that takes no arguments. +type Prompt struct { + // The name of the prompt or prompt template. + Name string `json:"name"` + // An optional description of what this prompt provides + Description string `json:"description,omitempty"` + // A list of arguments to use for templating the prompt. + // The presence of arguments indicates this is a template prompt. + Arguments []PromptArgument `json:"arguments,omitempty"` +} + +// PromptArgument describes an argument that a prompt template can accept. +// When a prompt includes arguments, clients must provide values for all +// required arguments when making a prompts/get request. +type PromptArgument struct { + // The name of the argument. + Name string `json:"name"` + // A human-readable description of the argument. + Description string `json:"description,omitempty"` + // Whether this argument must be provided. + // If true, clients must include this argument when calling prompts/get. + Required bool `json:"required,omitempty"` +} + +// Role represents the sender or recipient of messages and data in a +// conversation. +type Role string + +const ( + RoleUser Role = "user" + RoleAssistant Role = "assistant" +) + +// PromptMessage describes a message returned as part of a prompt. +// +// This is similar to `SamplingMessage`, but also supports the embedding of +// resources from the MCP server. +type PromptMessage struct { + Role Role `json:"role"` + Content Content `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource +} + +// PromptListChangedNotification is an optional notification from the server +// to the client, informing it that the list of prompts it offers has changed. This +// may be issued by servers without any previous subscription from the client. +type PromptListChangedNotification struct { + Notification +} + +// PromptOption is a function that configures a Prompt. +// It provides a flexible way to set various properties of a Prompt using the functional options pattern. +type PromptOption func(*Prompt) + +// ArgumentOption is a function that configures a PromptArgument. +// It allows for flexible configuration of prompt arguments using the functional options pattern. +type ArgumentOption func(*PromptArgument) + +// +// Core Prompt Functions +// + +// NewPrompt creates a new Prompt with the given name and options. +// The prompt will be configured based on the provided options. +// Options are applied in order, allowing for flexible prompt configuration. +func NewPrompt(name string, opts ...PromptOption) Prompt { + prompt := Prompt{ + Name: name, + } + + for _, opt := range opts { + opt(&prompt) + } + + return prompt +} + +// WithPromptDescription adds a description to the Prompt. +// The description should provide a clear, human-readable explanation of what the prompt does. +func WithPromptDescription(description string) PromptOption { + return func(p *Prompt) { + p.Description = description + } +} + +// WithArgument adds an argument to the prompt's argument list. +// The argument will be configured based on the provided options. +func WithArgument(name string, opts ...ArgumentOption) PromptOption { + return func(p *Prompt) { + arg := PromptArgument{ + Name: name, + } + + for _, opt := range opts { + opt(&arg) + } + + if p.Arguments == nil { + p.Arguments = make([]PromptArgument, 0) + } + p.Arguments = append(p.Arguments, arg) + } +} + +// +// Argument Options +// + +// ArgumentDescription adds a description to a prompt argument. +// The description should explain the purpose and expected values of the argument. +func ArgumentDescription(desc string) ArgumentOption { + return func(arg *PromptArgument) { + arg.Description = desc + } +} + +// RequiredArgument marks an argument as required in the prompt. +// Required arguments must be provided when getting the prompt. +func RequiredArgument() ArgumentOption { + return func(arg *PromptArgument) { + arg.Required = true + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/resources.go b/vendor/github.com/mark3labs/mcp-go/mcp/resources.go new file mode 100644 index 0000000000..51cdd25dd3 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/resources.go @@ -0,0 +1,105 @@ +package mcp + +import "github.com/yosida95/uritemplate/v3" + +// ResourceOption is a function that configures a Resource. +// It provides a flexible way to set various properties of a Resource using the functional options pattern. +type ResourceOption func(*Resource) + +// NewResource creates a new Resource with the given URI, name and options. +// The resource will be configured based on the provided options. +// Options are applied in order, allowing for flexible resource configuration. +func NewResource(uri string, name string, opts ...ResourceOption) Resource { + resource := Resource{ + URI: uri, + Name: name, + } + + for _, opt := range opts { + opt(&resource) + } + + return resource +} + +// WithResourceDescription adds a description to the Resource. +// The description should provide a clear, human-readable explanation of what the resource represents. +func WithResourceDescription(description string) ResourceOption { + return func(r *Resource) { + r.Description = description + } +} + +// WithMIMEType sets the MIME type for the Resource. +// This should indicate the format of the resource's contents. +func WithMIMEType(mimeType string) ResourceOption { + return func(r *Resource) { + r.MIMEType = mimeType + } +} + +// WithAnnotations adds annotations to the Resource. +// Annotations can provide additional metadata about the resource's intended use. +func WithAnnotations(audience []Role, priority float64) ResourceOption { + return func(r *Resource) { + if r.Annotations == nil { + r.Annotations = &struct { + Audience []Role `json:"audience,omitempty"` + Priority float64 `json:"priority,omitempty"` + }{} + } + r.Annotations.Audience = audience + r.Annotations.Priority = priority + } +} + +// ResourceTemplateOption is a function that configures a ResourceTemplate. +// It provides a flexible way to set various properties of a ResourceTemplate using the functional options pattern. +type ResourceTemplateOption func(*ResourceTemplate) + +// NewResourceTemplate creates a new ResourceTemplate with the given URI template, name and options. +// The template will be configured based on the provided options. +// Options are applied in order, allowing for flexible template configuration. +func NewResourceTemplate(uriTemplate string, name string, opts ...ResourceTemplateOption) ResourceTemplate { + template := ResourceTemplate{ + URITemplate: &URITemplate{Template: uritemplate.MustNew(uriTemplate)}, + Name: name, + } + + for _, opt := range opts { + opt(&template) + } + + return template +} + +// WithTemplateDescription adds a description to the ResourceTemplate. +// The description should provide a clear, human-readable explanation of what resources this template represents. +func WithTemplateDescription(description string) ResourceTemplateOption { + return func(t *ResourceTemplate) { + t.Description = description + } +} + +// WithTemplateMIMEType sets the MIME type for the ResourceTemplate. +// This should only be set if all resources matching this template will have the same type. +func WithTemplateMIMEType(mimeType string) ResourceTemplateOption { + return func(t *ResourceTemplate) { + t.MIMEType = mimeType + } +} + +// WithTemplateAnnotations adds annotations to the ResourceTemplate. +// Annotations can provide additional metadata about the template's intended use. +func WithTemplateAnnotations(audience []Role, priority float64) ResourceTemplateOption { + return func(t *ResourceTemplate) { + if t.Annotations == nil { + t.Annotations = &struct { + Audience []Role `json:"audience,omitempty"` + Priority float64 `json:"priority,omitempty"` + }{} + } + t.Annotations.Audience = audience + t.Annotations.Priority = priority + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go new file mode 100644 index 0000000000..c4c1b1dec0 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go @@ -0,0 +1,466 @@ +package mcp + +import ( + "encoding/json" + "errors" + "fmt" +) + +var errToolSchemaConflict = errors.New("provide either InputSchema or RawInputSchema, not both") + +// ListToolsRequest is sent from the client to request a list of tools the +// server has. +type ListToolsRequest struct { + PaginatedRequest +} + +// ListToolsResult is the server's response to a tools/list request from the +// client. +type ListToolsResult struct { + PaginatedResult + Tools []Tool `json:"tools"` +} + +// CallToolResult is the server's response to a tool call. +// +// Any errors that originate from the tool SHOULD be reported inside the result +// object, with `isError` set to true, _not_ as an MCP protocol-level error +// response. Otherwise, the LLM would not be able to see that an error occurred +// and self-correct. +// +// However, any errors in _finding_ the tool, an error indicating that the +// server does not support tool calls, or any other exceptional conditions, +// should be reported as an MCP error response. +type CallToolResult struct { + Result + Content []Content `json:"content"` // Can be TextContent, ImageContent, or EmbeddedResource + // Whether the tool call ended in an error. + // + // If not set, this is assumed to be false (the call was successful). + IsError bool `json:"isError,omitempty"` +} + +// CallToolRequest is used by the client to invoke a tool provided by the server. +type CallToolRequest struct { + Request + Params struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments,omitempty"` + Meta *struct { + // If specified, the caller is requesting out-of-band progress + // notifications for this request (as represented by + // notifications/progress). The value of this parameter is an + // opaque token that will be attached to any subsequent + // notifications. The receiver is not obligated to provide these + // notifications. + ProgressToken ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + } `json:"params"` +} + +// ToolListChangedNotification is an optional notification from the server to +// the client, informing it that the list of tools it offers has changed. This may +// be issued by servers without any previous subscription from the client. +type ToolListChangedNotification struct { + Notification +} + +// Tool represents the definition for a tool the client can call. +type Tool struct { + // The name of the tool. + Name string `json:"name"` + // A human-readable description of the tool. + Description string `json:"description,omitempty"` + // A JSON Schema object defining the expected parameters for the tool. + InputSchema ToolInputSchema `json:"inputSchema"` + // Alternative to InputSchema - allows arbitrary JSON Schema to be provided + RawInputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling +} + +// MarshalJSON implements the json.Marshaler interface for Tool. +// It handles marshaling either InputSchema or RawInputSchema based on which is set. +func (t Tool) MarshalJSON() ([]byte, error) { + // Create a map to build the JSON structure + m := make(map[string]interface{}, 3) + + // Add the name and description + m["name"] = t.Name + if t.Description != "" { + m["description"] = t.Description + } + + // Determine which schema to use + if t.RawInputSchema != nil { + if t.InputSchema.Type != "" { + return nil, fmt.Errorf("tool %s has both InputSchema and RawInputSchema set: %w", t.Name, errToolSchemaConflict) + } + m["inputSchema"] = t.RawInputSchema + } else { + // Use the structured InputSchema + m["inputSchema"] = t.InputSchema + } + + return json.Marshal(m) +} + +type ToolInputSchema struct { + Type string `json:"type"` + Properties map[string]interface{} `json:"properties"` + Required []string `json:"required,omitempty"` +} + +// ToolOption is a function that configures a Tool. +// It provides a flexible way to set various properties of a Tool using the functional options pattern. +type ToolOption func(*Tool) + +// PropertyOption is a function that configures a property in a Tool's input schema. +// It allows for flexible configuration of JSON Schema properties using the functional options pattern. +type PropertyOption func(map[string]interface{}) + +// +// Core Tool Functions +// + +// NewTool creates a new Tool with the given name and options. +// The tool will have an object-type input schema with configurable properties. +// Options are applied in order, allowing for flexible tool configuration. +func NewTool(name string, opts ...ToolOption) Tool { + tool := Tool{ + Name: name, + InputSchema: ToolInputSchema{ + Type: "object", + Properties: make(map[string]interface{}), + Required: nil, // Will be omitted from JSON if empty + }, + } + + for _, opt := range opts { + opt(&tool) + } + + return tool +} + +// NewToolWithRawSchema creates a new Tool with the given name and a raw JSON +// Schema. This allows for arbitrary JSON Schema to be used for the tool's input +// schema. +// +// NOTE a [Tool] built in such a way is incompatible with the [ToolOption] and +// runtime errors will result from supplying a [ToolOption] to a [Tool] built +// with this function. +func NewToolWithRawSchema(name, description string, schema json.RawMessage) Tool { + tool := Tool{ + Name: name, + Description: description, + RawInputSchema: schema, + } + + return tool +} + +// WithDescription adds a description to the Tool. +// The description should provide a clear, human-readable explanation of what the tool does. +func WithDescription(description string) ToolOption { + return func(t *Tool) { + t.Description = description + } +} + +// +// Common Property Options +// + +// Description adds a description to a property in the JSON Schema. +// The description should explain the purpose and expected values of the property. +func Description(desc string) PropertyOption { + return func(schema map[string]interface{}) { + schema["description"] = desc + } +} + +// Required marks a property as required in the tool's input schema. +// Required properties must be provided when using the tool. +func Required() PropertyOption { + return func(schema map[string]interface{}) { + schema["required"] = true + } +} + +// Title adds a display-friendly title to a property in the JSON Schema. +// This title can be used by UI components to show a more readable property name. +func Title(title string) PropertyOption { + return func(schema map[string]interface{}) { + schema["title"] = title + } +} + +// +// String Property Options +// + +// DefaultString sets the default value for a string property. +// This value will be used if the property is not explicitly provided. +func DefaultString(value string) PropertyOption { + return func(schema map[string]interface{}) { + schema["default"] = value + } +} + +// Enum specifies a list of allowed values for a string property. +// The property value must be one of the specified enum values. +func Enum(values ...string) PropertyOption { + return func(schema map[string]interface{}) { + schema["enum"] = values + } +} + +// MaxLength sets the maximum length for a string property. +// The string value must not exceed this length. +func MaxLength(max int) PropertyOption { + return func(schema map[string]interface{}) { + schema["maxLength"] = max + } +} + +// MinLength sets the minimum length for a string property. +// The string value must be at least this length. +func MinLength(min int) PropertyOption { + return func(schema map[string]interface{}) { + schema["minLength"] = min + } +} + +// Pattern sets a regex pattern that a string property must match. +// The string value must conform to the specified regular expression. +func Pattern(pattern string) PropertyOption { + return func(schema map[string]interface{}) { + schema["pattern"] = pattern + } +} + +// +// Number Property Options +// + +// DefaultNumber sets the default value for a number property. +// This value will be used if the property is not explicitly provided. +func DefaultNumber(value float64) PropertyOption { + return func(schema map[string]interface{}) { + schema["default"] = value + } +} + +// Max sets the maximum value for a number property. +// The number value must not exceed this maximum. +func Max(max float64) PropertyOption { + return func(schema map[string]interface{}) { + schema["maximum"] = max + } +} + +// Min sets the minimum value for a number property. +// The number value must not be less than this minimum. +func Min(min float64) PropertyOption { + return func(schema map[string]interface{}) { + schema["minimum"] = min + } +} + +// MultipleOf specifies that a number must be a multiple of the given value. +// The number value must be divisible by this value. +func MultipleOf(value float64) PropertyOption { + return func(schema map[string]interface{}) { + schema["multipleOf"] = value + } +} + +// +// Boolean Property Options +// + +// DefaultBool sets the default value for a boolean property. +// This value will be used if the property is not explicitly provided. +func DefaultBool(value bool) PropertyOption { + return func(schema map[string]interface{}) { + schema["default"] = value + } +} + +// +// Property Type Helpers +// + +// WithBoolean adds a boolean property to the tool schema. +// It accepts property options to configure the boolean property's behavior and constraints. +func WithBoolean(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]interface{}{ + "type": "boolean", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithNumber adds a number property to the tool schema. +// It accepts property options to configure the number property's behavior and constraints. +func WithNumber(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]interface{}{ + "type": "number", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithString adds a string property to the tool schema. +// It accepts property options to configure the string property's behavior and constraints. +func WithString(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]interface{}{ + "type": "string", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithObject adds an object property to the tool schema. +// It accepts property options to configure the object property's behavior and constraints. +func WithObject(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithArray adds an array property to the tool schema. +// It accepts property options to configure the array property's behavior and constraints. +func WithArray(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]interface{}{ + "type": "array", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// Properties defines the properties for an object schema +func Properties(props map[string]interface{}) PropertyOption { + return func(schema map[string]interface{}) { + schema["properties"] = props + } +} + +// AdditionalProperties specifies whether additional properties are allowed in the object +// or defines a schema for additional properties +func AdditionalProperties(schema interface{}) PropertyOption { + return func(schemaMap map[string]interface{}) { + schemaMap["additionalProperties"] = schema + } +} + +// MinProperties sets the minimum number of properties for an object +func MinProperties(min int) PropertyOption { + return func(schema map[string]interface{}) { + schema["minProperties"] = min + } +} + +// MaxProperties sets the maximum number of properties for an object +func MaxProperties(max int) PropertyOption { + return func(schema map[string]interface{}) { + schema["maxProperties"] = max + } +} + +// PropertyNames defines a schema for property names in an object +func PropertyNames(schema map[string]interface{}) PropertyOption { + return func(schemaMap map[string]interface{}) { + schemaMap["propertyNames"] = schema + } +} + +// Items defines the schema for array items +func Items(schema interface{}) PropertyOption { + return func(schemaMap map[string]interface{}) { + schemaMap["items"] = schema + } +} + +// MinItems sets the minimum number of items for an array +func MinItems(min int) PropertyOption { + return func(schema map[string]interface{}) { + schema["minItems"] = min + } +} + +// MaxItems sets the maximum number of items for an array +func MaxItems(max int) PropertyOption { + return func(schema map[string]interface{}) { + schema["maxItems"] = max + } +} + +// UniqueItems specifies whether array items must be unique +func UniqueItems(unique bool) PropertyOption { + return func(schema map[string]interface{}) { + schema["uniqueItems"] = unique + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/types.go b/vendor/github.com/mark3labs/mcp-go/mcp/types.go new file mode 100644 index 0000000000..a3ad8174e6 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/types.go @@ -0,0 +1,860 @@ +// Package mcp defines the core types and interfaces for the Model Control Protocol (MCP). +// MCP is a protocol for communication between LLM-powered applications and their supporting services. +package mcp + +import ( + "encoding/json" + + "github.com/yosida95/uritemplate/v3" +) + +type MCPMethod string + +const ( + // Initiates connection and negotiates protocol capabilities. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization + MethodInitialize MCPMethod = "initialize" + + // Verifies connection liveness between client and server. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/utilities/ping/ + MethodPing MCPMethod = "ping" + + // Lists all available server resources. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + MethodResourcesList MCPMethod = "resources/list" + + // Provides URI templates for constructing resource URIs. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + MethodResourcesTemplatesList MCPMethod = "resources/templates/list" + + // Retrieves content of a specific resource by URI. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/resources/ + MethodResourcesRead MCPMethod = "resources/read" + + // Lists all available prompt templates. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + MethodPromptsList MCPMethod = "prompts/list" + + // Retrieves a specific prompt template with filled parameters. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + MethodPromptsGet MCPMethod = "prompts/get" + + // Lists all available executable tools. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/ + MethodToolsList MCPMethod = "tools/list" + + // Invokes a specific tool with provided parameters. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/ + MethodToolsCall MCPMethod = "tools/call" +) + +type URITemplate struct { + *uritemplate.Template +} + +func (t *URITemplate) MarshalJSON() ([]byte, error) { + return json.Marshal(t.Template.Raw()) +} + +func (t *URITemplate) UnmarshalJSON(data []byte) error { + var raw string + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + template, err := uritemplate.New(raw) + if err != nil { + return err + } + t.Template = template + return nil +} + +/* JSON-RPC types */ + +// JSONRPCMessage represents either a JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, or JSONRPCError +type JSONRPCMessage interface{} + +// LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol. +const LATEST_PROTOCOL_VERSION = "2024-11-05" + +// JSONRPC_VERSION is the version of JSON-RPC used by MCP. +const JSONRPC_VERSION = "2.0" + +// ProgressToken is used to associate progress notifications with the original request. +type ProgressToken interface{} + +// Cursor is an opaque token used to represent a cursor for pagination. +type Cursor string + +type Request struct { + Method string `json:"method"` + Params struct { + Meta *struct { + // If specified, the caller is requesting out-of-band progress + // notifications for this request (as represented by + // notifications/progress). The value of this parameter is an + // opaque token that will be attached to any subsequent + // notifications. The receiver is not obligated to provide these + // notifications. + ProgressToken ProgressToken `json:"progressToken,omitempty"` + } `json:"_meta,omitempty"` + } `json:"params,omitempty"` +} + +type Params map[string]interface{} + +type Notification struct { + Method string `json:"method"` + Params NotificationParams `json:"params,omitempty"` +} + +type NotificationParams struct { + // This parameter name is reserved by MCP to allow clients and + // servers to attach additional metadata to their notifications. + Meta map[string]interface{} `json:"_meta,omitempty"` + + // Additional fields can be added to this map + AdditionalFields map[string]interface{} `json:"-"` +} + +// MarshalJSON implements custom JSON marshaling +func (p NotificationParams) MarshalJSON() ([]byte, error) { + // Create a map to hold all fields + m := make(map[string]interface{}) + + // Add Meta if it exists + if p.Meta != nil { + m["_meta"] = p.Meta + } + + // Add all additional fields + for k, v := range p.AdditionalFields { + // Ensure we don't override the _meta field + if k != "_meta" { + m[k] = v + } + } + + return json.Marshal(m) +} + +// UnmarshalJSON implements custom JSON unmarshaling +func (p *NotificationParams) UnmarshalJSON(data []byte) error { + // Create a map to hold all fields + var m map[string]interface{} + if err := json.Unmarshal(data, &m); err != nil { + return err + } + + // Initialize maps if they're nil + if p.Meta == nil { + p.Meta = make(map[string]interface{}) + } + if p.AdditionalFields == nil { + p.AdditionalFields = make(map[string]interface{}) + } + + // Process all fields + for k, v := range m { + if k == "_meta" { + // Handle Meta field + if meta, ok := v.(map[string]interface{}); ok { + p.Meta = meta + } + } else { + // Handle additional fields + p.AdditionalFields[k] = v + } + } + + return nil +} + +type Result struct { + // This result property is reserved by the protocol to allow clients and + // servers to attach additional metadata to their responses. + Meta map[string]interface{} `json:"_meta,omitempty"` +} + +// RequestId is a uniquely identifying ID for a request in JSON-RPC. +// It can be any JSON-serializable value, typically a number or string. +type RequestId interface{} + +// JSONRPCRequest represents a request that expects a response. +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Params interface{} `json:"params,omitempty"` + Request +} + +// JSONRPCNotification represents a notification which does not expect a response. +type JSONRPCNotification struct { + JSONRPC string `json:"jsonrpc"` + Notification +} + +// JSONRPCResponse represents a successful (non-error) response to a request. +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Result interface{} `json:"result"` +} + +// JSONRPCError represents a non-successful (error) response to a request. +type JSONRPCError struct { + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Error struct { + // The error type that occurred. + Code int `json:"code"` + // A short description of the error. The message SHOULD be limited + // to a concise single sentence. + Message string `json:"message"` + // Additional information about the error. The value of this member + // is defined by the sender (e.g. detailed error information, nested errors etc.). + Data interface{} `json:"data,omitempty"` + } `json:"error"` +} + +// Standard JSON-RPC error codes +const ( + PARSE_ERROR = -32700 + INVALID_REQUEST = -32600 + METHOD_NOT_FOUND = -32601 + INVALID_PARAMS = -32602 + INTERNAL_ERROR = -32603 +) + +/* Empty result */ + +// EmptyResult represents a response that indicates success but carries no data. +type EmptyResult Result + +/* Cancellation */ + +// CancelledNotification can be sent by either side to indicate that it is +// cancelling a previously-issued request. +// +// The request SHOULD still be in-flight, but due to communication latency, it +// is always possible that this notification MAY arrive after the request has +// already finished. +// +// This notification indicates that the result will be unused, so any +// associated processing SHOULD cease. +// +// A client MUST NOT attempt to cancel its `initialize` request. +type CancelledNotification struct { + Notification + Params struct { + // The ID of the request to cancel. + // + // This MUST correspond to the ID of a request previously issued + // in the same direction. + RequestId RequestId `json:"requestId"` + + // An optional string describing the reason for the cancellation. This MAY + // be logged or presented to the user. + Reason string `json:"reason,omitempty"` + } `json:"params"` +} + +/* Initialization */ + +// InitializeRequest is sent from the client to the server when it first +// connects, asking it to begin initialization. +type InitializeRequest struct { + Request + Params struct { + // The latest version of the Model Context Protocol that the client supports. + // The client MAY decide to support older versions as well. + ProtocolVersion string `json:"protocolVersion"` + Capabilities ClientCapabilities `json:"capabilities"` + ClientInfo Implementation `json:"clientInfo"` + } `json:"params"` +} + +// InitializeResult is sent after receiving an initialize request from the +// client. +type InitializeResult struct { + Result + // The version of the Model Context Protocol that the server wants to use. + // This may not match the version that the client requested. If the client cannot + // support this version, it MUST disconnect. + ProtocolVersion string `json:"protocolVersion"` + Capabilities ServerCapabilities `json:"capabilities"` + ServerInfo Implementation `json:"serverInfo"` + // Instructions describing how to use the server and its features. + // + // This can be used by clients to improve the LLM's understanding of + // available tools, resources, etc. It can be thought of like a "hint" to the model. + // For example, this information MAY be added to the system prompt. + Instructions string `json:"instructions,omitempty"` +} + +// InitializedNotification is sent from the client to the server after +// initialization has finished. +type InitializedNotification struct { + Notification +} + +// ClientCapabilities represents capabilities a client may support. Known +// capabilities are defined here, in this schema, but this is not a closed set: any +// client can define its own, additional capabilities. +type ClientCapabilities struct { + // Experimental, non-standard capabilities that the client supports. + Experimental map[string]interface{} `json:"experimental,omitempty"` + // Present if the client supports listing roots. + Roots *struct { + // Whether the client supports notifications for changes to the roots list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"roots,omitempty"` + // Present if the client supports sampling from an LLM. + Sampling *struct{} `json:"sampling,omitempty"` +} + +// ServerCapabilities represents capabilities that a server may support. Known +// capabilities are defined here, in this schema, but this is not a closed set: any +// server can define its own, additional capabilities. +type ServerCapabilities struct { + // Experimental, non-standard capabilities that the server supports. + Experimental map[string]interface{} `json:"experimental,omitempty"` + // Present if the server supports sending log messages to the client. + Logging *struct{} `json:"logging,omitempty"` + // Present if the server offers any prompt templates. + Prompts *struct { + // Whether this server supports notifications for changes to the prompt list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"prompts,omitempty"` + // Present if the server offers any resources to read. + Resources *struct { + // Whether this server supports subscribing to resource updates. + Subscribe bool `json:"subscribe,omitempty"` + // Whether this server supports notifications for changes to the resource + // list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"resources,omitempty"` + // Present if the server offers any tools to call. + Tools *struct { + // Whether this server supports notifications for changes to the tool list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"tools,omitempty"` +} + +// Implementation describes the name and version of an MCP implementation. +type Implementation struct { + Name string `json:"name"` + Version string `json:"version"` +} + +/* Ping */ + +// PingRequest represents a ping, issued by either the server or the client, +// to check that the other party is still alive. The receiver must promptly respond, +// or else may be disconnected. +type PingRequest struct { + Request +} + +/* Progress notifications */ + +// ProgressNotification is an out-of-band notification used to inform the +// receiver of a progress update for a long-running request. +type ProgressNotification struct { + Notification + Params struct { + // The progress token which was given in the initial request, used to + // associate this notification with the request that is proceeding. + ProgressToken ProgressToken `json:"progressToken"` + // The progress thus far. This should increase every time progress is made, + // even if the total is unknown. + Progress float64 `json:"progress"` + // Total number of items to process (or total progress required), if known. + Total float64 `json:"total,omitempty"` + } `json:"params"` +} + +/* Pagination */ + +type PaginatedRequest struct { + Request + Params struct { + // An opaque token representing the current pagination position. + // If provided, the server should return results starting after this cursor. + Cursor Cursor `json:"cursor,omitempty"` + } `json:"params,omitempty"` +} + +type PaginatedResult struct { + Result + // An opaque token representing the pagination position after the last + // returned result. + // If present, there may be more results available. + NextCursor Cursor `json:"nextCursor,omitempty"` +} + +/* Resources */ + +// ListResourcesRequest is sent from the client to request a list of resources +// the server has. +type ListResourcesRequest struct { + PaginatedRequest +} + +// ListResourcesResult is the server's response to a resources/list request +// from the client. +type ListResourcesResult struct { + PaginatedResult + Resources []Resource `json:"resources"` +} + +// ListResourceTemplatesRequest is sent from the client to request a list of +// resource templates the server has. +type ListResourceTemplatesRequest struct { + PaginatedRequest +} + +// ListResourceTemplatesResult is the server's response to a +// resources/templates/list request from the client. +type ListResourceTemplatesResult struct { + PaginatedResult + ResourceTemplates []ResourceTemplate `json:"resourceTemplates"` +} + +// ReadResourceRequest is sent from the client to the server, to read a +// specific resource URI. +type ReadResourceRequest struct { + Request + Params struct { + // The URI of the resource to read. The URI can use any protocol; it is up + // to the server how to interpret it. + URI string `json:"uri"` + // Arguments to pass to the resource handler + Arguments map[string]interface{} `json:"arguments,omitempty"` + } `json:"params"` +} + +// ReadResourceResult is the server's response to a resources/read request +// from the client. +type ReadResourceResult struct { + Result + Contents []ResourceContents `json:"contents"` // Can be TextResourceContents or BlobResourceContents +} + +// ResourceListChangedNotification is an optional notification from the server +// to the client, informing it that the list of resources it can read from has +// changed. This may be issued by servers without any previous subscription from +// the client. +type ResourceListChangedNotification struct { + Notification +} + +// SubscribeRequest is sent from the client to request resources/updated +// notifications from the server whenever a particular resource changes. +type SubscribeRequest struct { + Request + Params struct { + // The URI of the resource to subscribe to. The URI can use any protocol; it + // is up to the server how to interpret it. + URI string `json:"uri"` + } `json:"params"` +} + +// UnsubscribeRequest is sent from the client to request cancellation of +// resources/updated notifications from the server. This should follow a previous +// resources/subscribe request. +type UnsubscribeRequest struct { + Request + Params struct { + // The URI of the resource to unsubscribe from. + URI string `json:"uri"` + } `json:"params"` +} + +// ResourceUpdatedNotification is a notification from the server to the client, +// informing it that a resource has changed and may need to be read again. This +// should only be sent if the client previously sent a resources/subscribe request. +type ResourceUpdatedNotification struct { + Notification + Params struct { + // The URI of the resource that has been updated. This might be a sub- + // resource of the one that the client actually subscribed to. + URI string `json:"uri"` + } `json:"params"` +} + +// Resource represents a known resource that the server is capable of reading. +type Resource struct { + Annotated + // The URI of this resource. + URI string `json:"uri"` + // A human-readable name for this resource. + // + // This can be used by clients to populate UI elements. + Name string `json:"name"` + // A description of what this resource represents. + // + // This can be used by clients to improve the LLM's understanding of + // available resources. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` +} + +// ResourceTemplate represents a template description for resources available +// on the server. +type ResourceTemplate struct { + Annotated + // A URI template (according to RFC 6570) that can be used to construct + // resource URIs. + URITemplate *URITemplate `json:"uriTemplate"` + // A human-readable name for the type of resource this template refers to. + // + // This can be used by clients to populate UI elements. + Name string `json:"name"` + // A description of what this template is for. + // + // This can be used by clients to improve the LLM's understanding of + // available resources. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // The MIME type for all resources that match this template. This should only + // be included if all resources matching this template have the same type. + MIMEType string `json:"mimeType,omitempty"` +} + +// ResourceContents represents the contents of a specific resource or sub- +// resource. +type ResourceContents interface { + isResourceContents() +} + +type TextResourceContents struct { + // The URI of this resource. + URI string `json:"uri"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` + // The text of the item. This must only be set if the item can actually be + // represented as text (not binary data). + Text string `json:"text"` +} + +func (TextResourceContents) isResourceContents() {} + +type BlobResourceContents struct { + // The URI of this resource. + URI string `json:"uri"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` + // A base64-encoded string representing the binary data of the item. + Blob string `json:"blob"` +} + +func (BlobResourceContents) isResourceContents() {} + +/* Logging */ + +// SetLevelRequest is a request from the client to the server, to enable or +// adjust logging. +type SetLevelRequest struct { + Request + Params struct { + // The level of logging that the client wants to receive from the server. + // The server should send all logs at this level and higher (i.e., more severe) to + // the client as notifications/logging/message. + Level LoggingLevel `json:"level"` + } `json:"params"` +} + +// LoggingMessageNotification is a notification of a log message passed from +// server to client. If no logging/setLevel request has been sent from the client, +// the server MAY decide which messages to send automatically. +type LoggingMessageNotification struct { + Notification + Params struct { + // The severity of this log message. + Level LoggingLevel `json:"level"` + // An optional name of the logger issuing this message. + Logger string `json:"logger,omitempty"` + // The data to be logged, such as a string message or an object. Any JSON + // serializable type is allowed here. + Data interface{} `json:"data"` + } `json:"params"` +} + +// LoggingLevel represents the severity of a log message. +// +// These map to syslog message severities, as specified in RFC-5424: +// https://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1 +type LoggingLevel string + +const ( + LoggingLevelDebug LoggingLevel = "debug" + LoggingLevelInfo LoggingLevel = "info" + LoggingLevelNotice LoggingLevel = "notice" + LoggingLevelWarning LoggingLevel = "warning" + LoggingLevelError LoggingLevel = "error" + LoggingLevelCritical LoggingLevel = "critical" + LoggingLevelAlert LoggingLevel = "alert" + LoggingLevelEmergency LoggingLevel = "emergency" +) + +/* Sampling */ + +// CreateMessageRequest is a request from the server to sample an LLM via the +// client. The client has full discretion over which model to select. The client +// should also inform the user before beginning sampling, to allow them to inspect +// the request (human in the loop) and decide whether to approve it. +type CreateMessageRequest struct { + Request + Params struct { + Messages []SamplingMessage `json:"messages"` + ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` + SystemPrompt string `json:"systemPrompt,omitempty"` + IncludeContext string `json:"includeContext,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"maxTokens"` + StopSequences []string `json:"stopSequences,omitempty"` + Metadata interface{} `json:"metadata,omitempty"` + } `json:"params"` +} + +// CreateMessageResult is the client's response to a sampling/create_message +// request from the server. The client should inform the user before returning the +// sampled message, to allow them to inspect the response (human in the loop) and +// decide whether to allow the server to see it. +type CreateMessageResult struct { + Result + SamplingMessage + // The name of the model that generated the message. + Model string `json:"model"` + // The reason why sampling stopped, if known. + StopReason string `json:"stopReason,omitempty"` +} + +// SamplingMessage describes a message issued to or received from an LLM API. +type SamplingMessage struct { + Role Role `json:"role"` + Content interface{} `json:"content"` // Can be TextContent or ImageContent +} + +// Annotated is the base for objects that include optional annotations for the +// client. The client can use annotations to inform how objects are used or +// displayed +type Annotated struct { + Annotations *struct { + // Describes who the intended customer of this object or data is. + // + // It can include multiple entries to indicate content useful for multiple + // audiences (e.g., `["user", "assistant"]`). + Audience []Role `json:"audience,omitempty"` + + // Describes how important this data is for operating the server. + // + // A value of 1 means "most important," and indicates that the data is + // effectively required, while 0 means "least important," and indicates that + // the data is entirely optional. + Priority float64 `json:"priority,omitempty"` + } `json:"annotations,omitempty"` +} + +type Content interface { + isContent() +} + +// TextContent represents text provided to or from an LLM. +// It must have Type set to "text". +type TextContent struct { + Annotated + Type string `json:"type"` // Must be "text" + // The text content of the message. + Text string `json:"text"` +} + +func (TextContent) isContent() {} + +// ImageContent represents an image provided to or from an LLM. +// It must have Type set to "image". +type ImageContent struct { + Annotated + Type string `json:"type"` // Must be "image" + // The base64-encoded image data. + Data string `json:"data"` + // The MIME type of the image. Different providers may support different image types. + MIMEType string `json:"mimeType"` +} + +func (ImageContent) isContent() {} + +// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result. +// +// It is up to the client how best to render embedded resources for the +// benefit of the LLM and/or the user. +type EmbeddedResource struct { + Annotated + Type string `json:"type"` + Resource ResourceContents `json:"resource"` +} + +func (EmbeddedResource) isContent() {} + +// ModelPreferences represents the server's preferences for model selection, +// requested of the client during sampling. +// +// Because LLMs can vary along multiple dimensions, choosing the "best" modelis +// rarely straightforward. Different models excel in different areas—some are +// faster but less capable, others are more capable but more expensive, and so +// on. This interface allows servers to express their priorities across multiple +// dimensions to help clients make an appropriate selection for their use case. +// +// These preferences are always advisory. The client MAY ignore them. It is also +// up to the client to decide how to interpret these preferences and how to +// balance them against other considerations. +type ModelPreferences struct { + // Optional hints to use for model selection. + // + // If multiple hints are specified, the client MUST evaluate them in order + // (such that the first match is taken). + // + // The client SHOULD prioritize these hints over the numeric priorities, but + // MAY still use the priorities to select from ambiguous matches. + Hints []ModelHint `json:"hints,omitempty"` + + // How much to prioritize cost when selecting a model. A value of 0 means cost + // is not important, while a value of 1 means cost is the most important + // factor. + CostPriority float64 `json:"costPriority,omitempty"` + + // How much to prioritize sampling speed (latency) when selecting a model. A + // value of 0 means speed is not important, while a value of 1 means speed is + // the most important factor. + SpeedPriority float64 `json:"speedPriority,omitempty"` + + // How much to prioritize intelligence and capabilities when selecting a + // model. A value of 0 means intelligence is not important, while a value of 1 + // means intelligence is the most important factor. + IntelligencePriority float64 `json:"intelligencePriority,omitempty"` +} + +// ModelHint represents hints to use for model selection. +// +// Keys not declared here are currently left unspecified by the spec and are up +// to the client to interpret. +type ModelHint struct { + // A hint for a model name. + // + // The client SHOULD treat this as a substring of a model name; for example: + // - `claude-3-5-sonnet` should match `claude-3-5-sonnet-20241022` + // - `sonnet` should match `claude-3-5-sonnet-20241022`, `claude-3-sonnet-20240229`, etc. + // - `claude` should match any Claude model + // + // The client MAY also map the string to a different provider's model name or + // a different model family, as long as it fills a similar niche; for example: + // - `gemini-1.5-flash` could match `claude-3-haiku-20240307` + Name string `json:"name,omitempty"` +} + +/* Autocomplete */ + +// CompleteRequest is a request from the client to the server, to ask for completion options. +type CompleteRequest struct { + Request + Params struct { + Ref interface{} `json:"ref"` // Can be PromptReference or ResourceReference + Argument struct { + // The name of the argument + Name string `json:"name"` + // The value of the argument to use for completion matching. + Value string `json:"value"` + } `json:"argument"` + } `json:"params"` +} + +// CompleteResult is the server's response to a completion/complete request +type CompleteResult struct { + Result + Completion struct { + // An array of completion values. Must not exceed 100 items. + Values []string `json:"values"` + // The total number of completion options available. This can exceed the + // number of values actually sent in the response. + Total int `json:"total,omitempty"` + // Indicates whether there are additional completion options beyond those + // provided in the current response, even if the exact total is unknown. + HasMore bool `json:"hasMore,omitempty"` + } `json:"completion"` +} + +// ResourceReference is a reference to a resource or resource template definition. +type ResourceReference struct { + Type string `json:"type"` + // The URI or URI template of the resource. + URI string `json:"uri"` +} + +// PromptReference identifies a prompt. +type PromptReference struct { + Type string `json:"type"` + // The name of the prompt or prompt template + Name string `json:"name"` +} + +/* Roots */ + +// ListRootsRequest is sent from the server to request a list of root URIs from the client. Roots allow +// servers to ask for specific directories or files to operate on. A common example +// for roots is providing a set of repositories or directories a server should operate +// on. +// +// This request is typically used when the server needs to understand the file system +// structure or access specific locations that the client has permission to read from. +type ListRootsRequest struct { + Request +} + +// ListRootsResult is the client's response to a roots/list request from the server. +// This result contains an array of Root objects, each representing a root directory +// or file that the server can operate on. +type ListRootsResult struct { + Result + Roots []Root `json:"roots"` +} + +// Root represents a root directory or file that the server can operate on. +type Root struct { + // The URI identifying the root. This *must* start with file:// for now. + // This restriction may be relaxed in future versions of the protocol to allow + // other URI schemes. + URI string `json:"uri"` + // An optional name for the root. This can be used to provide a human-readable + // identifier for the root, which may be useful for display purposes or for + // referencing the root in other parts of the application. + Name string `json:"name,omitempty"` +} + +// RootsListChangedNotification is a notification from the client to the +// server, informing it that the list of roots has changed. +// This notification should be sent whenever the client adds, removes, or modifies any root. +// The server should then request an updated list of roots using the ListRootsRequest. +type RootsListChangedNotification struct { + Notification +} + +/* Client messages */ +// ClientRequest represents any request that can be sent from client to server. +type ClientRequest interface{} + +// ClientNotification represents any notification that can be sent from client to server. +type ClientNotification interface{} + +// ClientResult represents any result that can be sent from client to server. +type ClientResult interface{} + +/* Server messages */ +// ServerRequest represents any request that can be sent from server to client. +type ServerRequest interface{} + +// ServerNotification represents any notification that can be sent from server to client. +type ServerNotification interface{} + +// ServerResult represents any result that can be sent from server to client. +type ServerResult interface{} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go new file mode 100644 index 0000000000..236164cbd8 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go @@ -0,0 +1,596 @@ +package mcp + +import ( + "encoding/json" + "fmt" +) + +// ClientRequest types +var _ ClientRequest = &PingRequest{} +var _ ClientRequest = &InitializeRequest{} +var _ ClientRequest = &CompleteRequest{} +var _ ClientRequest = &SetLevelRequest{} +var _ ClientRequest = &GetPromptRequest{} +var _ ClientRequest = &ListPromptsRequest{} +var _ ClientRequest = &ListResourcesRequest{} +var _ ClientRequest = &ReadResourceRequest{} +var _ ClientRequest = &SubscribeRequest{} +var _ ClientRequest = &UnsubscribeRequest{} +var _ ClientRequest = &CallToolRequest{} +var _ ClientRequest = &ListToolsRequest{} + +// ClientNotification types +var _ ClientNotification = &CancelledNotification{} +var _ ClientNotification = &ProgressNotification{} +var _ ClientNotification = &InitializedNotification{} +var _ ClientNotification = &RootsListChangedNotification{} + +// ClientResult types +var _ ClientResult = &EmptyResult{} +var _ ClientResult = &CreateMessageResult{} +var _ ClientResult = &ListRootsResult{} + +// ServerRequest types +var _ ServerRequest = &PingRequest{} +var _ ServerRequest = &CreateMessageRequest{} +var _ ServerRequest = &ListRootsRequest{} + +// ServerNotification types +var _ ServerNotification = &CancelledNotification{} +var _ ServerNotification = &ProgressNotification{} +var _ ServerNotification = &LoggingMessageNotification{} +var _ ServerNotification = &ResourceUpdatedNotification{} +var _ ServerNotification = &ResourceListChangedNotification{} +var _ ServerNotification = &ToolListChangedNotification{} +var _ ServerNotification = &PromptListChangedNotification{} + +// ServerResult types +var _ ServerResult = &EmptyResult{} +var _ ServerResult = &InitializeResult{} +var _ ServerResult = &CompleteResult{} +var _ ServerResult = &GetPromptResult{} +var _ ServerResult = &ListPromptsResult{} +var _ ServerResult = &ListResourcesResult{} +var _ ServerResult = &ReadResourceResult{} +var _ ServerResult = &CallToolResult{} +var _ ServerResult = &ListToolsResult{} + +// Helper functions for type assertions + +// asType attempts to cast the given interface to the given type +func asType[T any](content interface{}) (*T, bool) { + tc, ok := content.(T) + if !ok { + return nil, false + } + return &tc, true +} + +// AsTextContent attempts to cast the given interface to TextContent +func AsTextContent(content interface{}) (*TextContent, bool) { + return asType[TextContent](content) +} + +// AsImageContent attempts to cast the given interface to ImageContent +func AsImageContent(content interface{}) (*ImageContent, bool) { + return asType[ImageContent](content) +} + +// AsEmbeddedResource attempts to cast the given interface to EmbeddedResource +func AsEmbeddedResource(content interface{}) (*EmbeddedResource, bool) { + return asType[EmbeddedResource](content) +} + +// AsTextResourceContents attempts to cast the given interface to TextResourceContents +func AsTextResourceContents(content interface{}) (*TextResourceContents, bool) { + return asType[TextResourceContents](content) +} + +// AsBlobResourceContents attempts to cast the given interface to BlobResourceContents +func AsBlobResourceContents(content interface{}) (*BlobResourceContents, bool) { + return asType[BlobResourceContents](content) +} + +// Helper function for JSON-RPC + +// NewJSONRPCResponse creates a new JSONRPCResponse with the given id and result +func NewJSONRPCResponse(id RequestId, result Result) JSONRPCResponse { + return JSONRPCResponse{ + JSONRPC: JSONRPC_VERSION, + ID: id, + Result: result, + } +} + +// NewJSONRPCError creates a new JSONRPCResponse with the given id, code, and message +func NewJSONRPCError( + id RequestId, + code int, + message string, + data interface{}, +) JSONRPCError { + return JSONRPCError{ + JSONRPC: JSONRPC_VERSION, + ID: id, + Error: struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` + }{ + Code: code, + Message: message, + Data: data, + }, + } +} + +// Helper function for creating a progress notification +func NewProgressNotification( + token ProgressToken, + progress float64, + total *float64, +) ProgressNotification { + notification := ProgressNotification{ + Notification: Notification{ + Method: "notifications/progress", + }, + Params: struct { + ProgressToken ProgressToken `json:"progressToken"` + Progress float64 `json:"progress"` + Total float64 `json:"total,omitempty"` + }{ + ProgressToken: token, + Progress: progress, + }, + } + if total != nil { + notification.Params.Total = *total + } + return notification +} + +// Helper function for creating a logging message notification +func NewLoggingMessageNotification( + level LoggingLevel, + logger string, + data interface{}, +) LoggingMessageNotification { + return LoggingMessageNotification{ + Notification: Notification{ + Method: "notifications/message", + }, + Params: struct { + Level LoggingLevel `json:"level"` + Logger string `json:"logger,omitempty"` + Data interface{} `json:"data"` + }{ + Level: level, + Logger: logger, + Data: data, + }, + } +} + +// Helper function to create a new PromptMessage +func NewPromptMessage(role Role, content Content) PromptMessage { + return PromptMessage{ + Role: role, + Content: content, + } +} + +// Helper function to create a new TextContent +func NewTextContent(text string) TextContent { + return TextContent{ + Type: "text", + Text: text, + } +} + +// Helper function to create a new ImageContent +func NewImageContent(data, mimeType string) ImageContent { + return ImageContent{ + Type: "image", + Data: data, + MIMEType: mimeType, + } +} + +// Helper function to create a new EmbeddedResource +func NewEmbeddedResource(resource ResourceContents) EmbeddedResource { + return EmbeddedResource{ + Type: "resource", + Resource: resource, + } +} + +// NewToolResultText creates a new CallToolResult with a text content +func NewToolResultText(text string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: text, + }, + }, + } +} + +// NewToolResultImage creates a new CallToolResult with both text and image content +func NewToolResultImage(text, imageData, mimeType string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: text, + }, + ImageContent{ + Type: "image", + Data: imageData, + MIMEType: mimeType, + }, + }, + } +} + +// NewToolResultResource creates a new CallToolResult with an embedded resource +func NewToolResultResource( + text string, + resource ResourceContents, +) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: text, + }, + EmbeddedResource{ + Type: "resource", + Resource: resource, + }, + }, + } +} + +// NewToolResultError creates a new CallToolResult with an error message. +// Any errors that originate from the tool SHOULD be reported inside the result object. +func NewToolResultError(text string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: text, + }, + }, + IsError: true, + } +} + +// NewListResourcesResult creates a new ListResourcesResult +func NewListResourcesResult( + resources []Resource, + nextCursor Cursor, +) *ListResourcesResult { + return &ListResourcesResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + Resources: resources, + } +} + +// NewListResourceTemplatesResult creates a new ListResourceTemplatesResult +func NewListResourceTemplatesResult( + templates []ResourceTemplate, + nextCursor Cursor, +) *ListResourceTemplatesResult { + return &ListResourceTemplatesResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + ResourceTemplates: templates, + } +} + +// NewReadResourceResult creates a new ReadResourceResult with text content +func NewReadResourceResult(text string) *ReadResourceResult { + return &ReadResourceResult{ + Contents: []ResourceContents{ + TextResourceContents{ + Text: text, + }, + }, + } +} + +// NewListPromptsResult creates a new ListPromptsResult +func NewListPromptsResult( + prompts []Prompt, + nextCursor Cursor, +) *ListPromptsResult { + return &ListPromptsResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + Prompts: prompts, + } +} + +// NewGetPromptResult creates a new GetPromptResult +func NewGetPromptResult( + description string, + messages []PromptMessage, +) *GetPromptResult { + return &GetPromptResult{ + Description: description, + Messages: messages, + } +} + +// NewListToolsResult creates a new ListToolsResult +func NewListToolsResult(tools []Tool, nextCursor Cursor) *ListToolsResult { + return &ListToolsResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + Tools: tools, + } +} + +// NewInitializeResult creates a new InitializeResult +func NewInitializeResult( + protocolVersion string, + capabilities ServerCapabilities, + serverInfo Implementation, + instructions string, +) *InitializeResult { + return &InitializeResult{ + ProtocolVersion: protocolVersion, + Capabilities: capabilities, + ServerInfo: serverInfo, + Instructions: instructions, + } +} + +// Helper for formatting numbers in tool results +func FormatNumberResult(value float64) *CallToolResult { + return NewToolResultText(fmt.Sprintf("%.2f", value)) +} + +func ExtractString(data map[string]any, key string) string { + if value, ok := data[key]; ok { + if str, ok := value.(string); ok { + return str + } + } + return "" +} + +func ExtractMap(data map[string]any, key string) map[string]any { + if value, ok := data[key]; ok { + if m, ok := value.(map[string]any); ok { + return m + } + } + return nil +} + +func ParseContent(contentMap map[string]any) (Content, error) { + contentType := ExtractString(contentMap, "type") + + switch contentType { + case "text": + text := ExtractString(contentMap, "text") + if text == "" { + return nil, fmt.Errorf("text is missing") + } + return NewTextContent(text), nil + + case "image": + data := ExtractString(contentMap, "data") + mimeType := ExtractString(contentMap, "mimeType") + if data == "" || mimeType == "" { + return nil, fmt.Errorf("image data or mimeType is missing") + } + return NewImageContent(data, mimeType), nil + + case "resource": + resourceMap := ExtractMap(contentMap, "resource") + if resourceMap == nil { + return nil, fmt.Errorf("resource is missing") + } + + resourceContents, err := ParseResourceContents(resourceMap) + if err != nil { + return nil, err + } + + return NewEmbeddedResource(resourceContents), nil + } + + return nil, fmt.Errorf("unsupported content type: %s", contentType) +} + +func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error) { + var jsonContent map[string]any + if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + result := GetPromptResult{} + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = metaMap + } + } + + description, ok := jsonContent["description"] + if ok { + if descriptionStr, ok := description.(string); ok { + result.Description = descriptionStr + } + } + + messages, ok := jsonContent["messages"] + if ok { + messagesArr, ok := messages.([]any) + if !ok { + return nil, fmt.Errorf("messages is not an array") + } + + for _, message := range messagesArr { + messageMap, ok := message.(map[string]any) + if !ok { + return nil, fmt.Errorf("message is not an object") + } + + // Extract role + roleStr := ExtractString(messageMap, "role") + if roleStr == "" || (roleStr != string(RoleAssistant) && roleStr != string(RoleUser)) { + return nil, fmt.Errorf("unsupported role: %s", roleStr) + } + + // Extract content + contentMap, ok := messageMap["content"].(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseContent(contentMap) + if err != nil { + return nil, err + } + + // Append processed message + result.Messages = append(result.Messages, NewPromptMessage(Role(roleStr), content)) + + } + } + + return &result, nil +} + +func ParseCallToolResult(rawMessage *json.RawMessage) (*CallToolResult, error) { + var jsonContent map[string]any + if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + var result CallToolResult + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = metaMap + } + } + + isError, ok := jsonContent["isError"] + if ok { + if isErrorBool, ok := isError.(bool); ok { + result.IsError = isErrorBool + } + } + + contents, ok := jsonContent["content"] + if !ok { + return nil, fmt.Errorf("content is missing") + } + + contentArr, ok := contents.([]any) + if !ok { + return nil, fmt.Errorf("content is not an array") + } + + for _, content := range contentArr { + // Extract content + contentMap, ok := content.(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseContent(contentMap) + if err != nil { + return nil, err + } + + result.Content = append(result.Content, content) + } + + return &result, nil +} + +func ParseResourceContents(contentMap map[string]any) (ResourceContents, error) { + uri := ExtractString(contentMap, "uri") + if uri == "" { + return nil, fmt.Errorf("resource uri is missing") + } + + mimeType := ExtractString(contentMap, "mimeType") + + if text := ExtractString(contentMap, "text"); text != "" { + return TextResourceContents{ + URI: uri, + MIMEType: mimeType, + Text: text, + }, nil + } + + if blob := ExtractString(contentMap, "blob"); blob != "" { + return BlobResourceContents{ + URI: uri, + MIMEType: mimeType, + Blob: blob, + }, nil + } + + return nil, fmt.Errorf("unsupported resource type") +} + +func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, error) { + var jsonContent map[string]any + if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + var result ReadResourceResult + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = metaMap + } + } + + contents, ok := jsonContent["contents"] + if !ok { + return nil, fmt.Errorf("contents is missing") + } + + contentArr, ok := contents.([]any) + if !ok { + return nil, fmt.Errorf("contents is not an array") + } + + for _, content := range contentArr { + // Extract content + contentMap, ok := content.(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseResourceContents(contentMap) + if err != nil { + return nil, err + } + + result.Contents = append(result.Contents, content) + } + + return &result, nil +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/hooks.go b/vendor/github.com/mark3labs/mcp-go/server/hooks.go new file mode 100644 index 0000000000..ce976a6cdb --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/hooks.go @@ -0,0 +1,461 @@ +// Code generated by `go generate`. DO NOT EDIT. +// source: server/internal/gen/hooks.go.tmpl +package server + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// OnRegisterSessionHookFunc is a hook that will be called when a new session is registered. +type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession) + +// BeforeAnyHookFunc is a function that is called after the request is +// parsed but before the method is called. +type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any) + +// OnSuccessHookFunc is a hook that will be called after the request +// successfully generates a result, but before the result is sent to the client. +type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) + +// OnErrorHookFunc is a hook that will be called when an error occurs, +// either during the request parsing or the method execution. +// +// Example usage: +// ``` +// +// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { +// // Check for specific error types using errors.Is +// if errors.Is(err, ErrUnsupported) { +// // Handle capability not supported errors +// log.Printf("Capability not supported: %v", err) +// } +// +// // Use errors.As to get specific error types +// var parseErr = &UnparseableMessageError{} +// if errors.As(err, &parseErr) { +// // Access specific methods/fields of the error type +// log.Printf("Failed to parse message for method %s: %v", +// parseErr.GetMethod(), parseErr.Unwrap()) +// // Access the raw message that failed to parse +// rawMsg := parseErr.GetMessage() +// } +// +// // Check for specific resource/prompt/tool errors +// switch { +// case errors.Is(err, ErrResourceNotFound): +// log.Printf("Resource not found: %v", err) +// case errors.Is(err, ErrPromptNotFound): +// log.Printf("Prompt not found: %v", err) +// case errors.Is(err, ErrToolNotFound): +// log.Printf("Tool not found: %v", err) +// } +// }) +type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) + +type OnBeforeInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest) +type OnAfterInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) + +type OnBeforePingFunc func(ctx context.Context, id any, message *mcp.PingRequest) +type OnAfterPingFunc func(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) + +type OnBeforeListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest) +type OnAfterListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) + +type OnBeforeListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest) +type OnAfterListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult) + +type OnBeforeReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest) +type OnAfterReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult) + +type OnBeforeListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest) +type OnAfterListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult) + +type OnBeforeGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest) +type OnAfterGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult) + +type OnBeforeListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest) +type OnAfterListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) + +type OnBeforeCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest) +type OnAfterCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) + +type Hooks struct { + OnRegisterSession []OnRegisterSessionHookFunc + OnBeforeAny []BeforeAnyHookFunc + OnSuccess []OnSuccessHookFunc + OnError []OnErrorHookFunc + OnBeforeInitialize []OnBeforeInitializeFunc + OnAfterInitialize []OnAfterInitializeFunc + OnBeforePing []OnBeforePingFunc + OnAfterPing []OnAfterPingFunc + OnBeforeListResources []OnBeforeListResourcesFunc + OnAfterListResources []OnAfterListResourcesFunc + OnBeforeListResourceTemplates []OnBeforeListResourceTemplatesFunc + OnAfterListResourceTemplates []OnAfterListResourceTemplatesFunc + OnBeforeReadResource []OnBeforeReadResourceFunc + OnAfterReadResource []OnAfterReadResourceFunc + OnBeforeListPrompts []OnBeforeListPromptsFunc + OnAfterListPrompts []OnAfterListPromptsFunc + OnBeforeGetPrompt []OnBeforeGetPromptFunc + OnAfterGetPrompt []OnAfterGetPromptFunc + OnBeforeListTools []OnBeforeListToolsFunc + OnAfterListTools []OnAfterListToolsFunc + OnBeforeCallTool []OnBeforeCallToolFunc + OnAfterCallTool []OnAfterCallToolFunc +} + +func (c *Hooks) AddBeforeAny(hook BeforeAnyHookFunc) { + c.OnBeforeAny = append(c.OnBeforeAny, hook) +} + +func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) { + c.OnSuccess = append(c.OnSuccess, hook) +} + +// AddOnError registers a hook function that will be called when an error occurs. +// The error parameter contains the actual error object, which can be interrogated +// using Go's error handling patterns like errors.Is and errors.As. +// +// Example: +// ``` +// // Create a channel to receive errors for testing +// errChan := make(chan error, 1) +// +// // Register hook to capture and inspect errors +// hooks := &Hooks{} +// +// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { +// // For capability-related errors +// if errors.Is(err, ErrUnsupported) { +// // Handle capability not supported +// errChan <- err +// return +// } +// +// // For parsing errors +// var parseErr = &UnparseableMessageError{} +// if errors.As(err, &parseErr) { +// // Handle unparseable message errors +// fmt.Printf("Failed to parse %s request: %v\n", +// parseErr.GetMethod(), parseErr.Unwrap()) +// errChan <- parseErr +// return +// } +// +// // For resource/prompt/tool not found errors +// if errors.Is(err, ErrResourceNotFound) || +// errors.Is(err, ErrPromptNotFound) || +// errors.Is(err, ErrToolNotFound) { +// // Handle not found errors +// errChan <- err +// return +// } +// +// // For other errors +// errChan <- err +// }) +// +// server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks)) +// ``` +func (c *Hooks) AddOnError(hook OnErrorHookFunc) { + c.OnError = append(c.OnError, hook) +} + +func (c *Hooks) beforeAny(ctx context.Context, id any, method mcp.MCPMethod, message any) { + if c == nil { + return + } + for _, hook := range c.OnBeforeAny { + hook(ctx, id, method, message) + } +} + +func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) { + if c == nil { + return + } + for _, hook := range c.OnSuccess { + hook(ctx, id, method, message, result) + } +} + +// onError calls all registered error hooks with the error object. +// The err parameter contains the actual error that occurred, which implements +// the standard error interface and may be a wrapped error or custom error type. +// +// This allows consumer code to use Go's error handling patterns: +// - errors.Is(err, ErrUnsupported) to check for specific sentinel errors +// - errors.As(err, &customErr) to extract custom error types +// +// Common error types include: +// - ErrUnsupported: When a capability is not enabled +// - UnparseableMessageError: When request parsing fails +// - ErrResourceNotFound: When a resource is not found +// - ErrPromptNotFound: When a prompt is not found +// - ErrToolNotFound: When a tool is not found +func (c *Hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { + if c == nil { + return + } + for _, hook := range c.OnError { + hook(ctx, id, method, message, err) + } +} + +func (c *Hooks) AddOnRegisterSession(hook OnRegisterSessionHookFunc) { + c.OnRegisterSession = append(c.OnRegisterSession, hook) +} + +func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) { + if c == nil { + return + } + for _, hook := range c.OnRegisterSession { + hook(ctx, session) + } +} +func (c *Hooks) AddBeforeInitialize(hook OnBeforeInitializeFunc) { + c.OnBeforeInitialize = append(c.OnBeforeInitialize, hook) +} + +func (c *Hooks) AddAfterInitialize(hook OnAfterInitializeFunc) { + c.OnAfterInitialize = append(c.OnAfterInitialize, hook) +} + +func (c *Hooks) beforeInitialize(ctx context.Context, id any, message *mcp.InitializeRequest) { + c.beforeAny(ctx, id, mcp.MethodInitialize, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeInitialize { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterInitialize(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) { + c.onSuccess(ctx, id, mcp.MethodInitialize, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterInitialize { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforePing(hook OnBeforePingFunc) { + c.OnBeforePing = append(c.OnBeforePing, hook) +} + +func (c *Hooks) AddAfterPing(hook OnAfterPingFunc) { + c.OnAfterPing = append(c.OnAfterPing, hook) +} + +func (c *Hooks) beforePing(ctx context.Context, id any, message *mcp.PingRequest) { + c.beforeAny(ctx, id, mcp.MethodPing, message) + if c == nil { + return + } + for _, hook := range c.OnBeforePing { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterPing(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) { + c.onSuccess(ctx, id, mcp.MethodPing, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterPing { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListResources(hook OnBeforeListResourcesFunc) { + c.OnBeforeListResources = append(c.OnBeforeListResources, hook) +} + +func (c *Hooks) AddAfterListResources(hook OnAfterListResourcesFunc) { + c.OnAfterListResources = append(c.OnAfterListResources, hook) +} + +func (c *Hooks) beforeListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest) { + c.beforeAny(ctx, id, mcp.MethodResourcesList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListResources { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) { + c.onSuccess(ctx, id, mcp.MethodResourcesList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListResources { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListResourceTemplates(hook OnBeforeListResourceTemplatesFunc) { + c.OnBeforeListResourceTemplates = append(c.OnBeforeListResourceTemplates, hook) +} + +func (c *Hooks) AddAfterListResourceTemplates(hook OnAfterListResourceTemplatesFunc) { + c.OnAfterListResourceTemplates = append(c.OnAfterListResourceTemplates, hook) +} + +func (c *Hooks) beforeListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest) { + c.beforeAny(ctx, id, mcp.MethodResourcesTemplatesList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListResourceTemplates { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult) { + c.onSuccess(ctx, id, mcp.MethodResourcesTemplatesList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListResourceTemplates { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeReadResource(hook OnBeforeReadResourceFunc) { + c.OnBeforeReadResource = append(c.OnBeforeReadResource, hook) +} + +func (c *Hooks) AddAfterReadResource(hook OnAfterReadResourceFunc) { + c.OnAfterReadResource = append(c.OnAfterReadResource, hook) +} + +func (c *Hooks) beforeReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest) { + c.beforeAny(ctx, id, mcp.MethodResourcesRead, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeReadResource { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult) { + c.onSuccess(ctx, id, mcp.MethodResourcesRead, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterReadResource { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListPrompts(hook OnBeforeListPromptsFunc) { + c.OnBeforeListPrompts = append(c.OnBeforeListPrompts, hook) +} + +func (c *Hooks) AddAfterListPrompts(hook OnAfterListPromptsFunc) { + c.OnAfterListPrompts = append(c.OnAfterListPrompts, hook) +} + +func (c *Hooks) beforeListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest) { + c.beforeAny(ctx, id, mcp.MethodPromptsList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListPrompts { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult) { + c.onSuccess(ctx, id, mcp.MethodPromptsList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListPrompts { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeGetPrompt(hook OnBeforeGetPromptFunc) { + c.OnBeforeGetPrompt = append(c.OnBeforeGetPrompt, hook) +} + +func (c *Hooks) AddAfterGetPrompt(hook OnAfterGetPromptFunc) { + c.OnAfterGetPrompt = append(c.OnAfterGetPrompt, hook) +} + +func (c *Hooks) beforeGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest) { + c.beforeAny(ctx, id, mcp.MethodPromptsGet, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeGetPrompt { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult) { + c.onSuccess(ctx, id, mcp.MethodPromptsGet, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterGetPrompt { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListTools(hook OnBeforeListToolsFunc) { + c.OnBeforeListTools = append(c.OnBeforeListTools, hook) +} + +func (c *Hooks) AddAfterListTools(hook OnAfterListToolsFunc) { + c.OnAfterListTools = append(c.OnAfterListTools, hook) +} + +func (c *Hooks) beforeListTools(ctx context.Context, id any, message *mcp.ListToolsRequest) { + c.beforeAny(ctx, id, mcp.MethodToolsList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListTools { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListTools(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) { + c.onSuccess(ctx, id, mcp.MethodToolsList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListTools { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeCallTool(hook OnBeforeCallToolFunc) { + c.OnBeforeCallTool = append(c.OnBeforeCallTool, hook) +} + +func (c *Hooks) AddAfterCallTool(hook OnAfterCallToolFunc) { + c.OnAfterCallTool = append(c.OnAfterCallTool, hook) +} + +func (c *Hooks) beforeCallTool(ctx context.Context, id any, message *mcp.CallToolRequest) { + c.beforeAny(ctx, id, mcp.MethodToolsCall, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeCallTool { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterCallTool(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { + c.onSuccess(ctx, id, mcp.MethodToolsCall, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterCallTool { + hook(ctx, id, message, result) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/request_handler.go b/vendor/github.com/mark3labs/mcp-go/server/request_handler.go new file mode 100644 index 0000000000..946ca7abd3 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/request_handler.go @@ -0,0 +1,279 @@ +// Code generated by `go generate`. DO NOT EDIT. +// source: server/internal/gen/request_handler.go.tmpl +package server + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// HandleMessage processes an incoming JSON-RPC message and returns an appropriate response +func (s *MCPServer) HandleMessage( + ctx context.Context, + message json.RawMessage, +) mcp.JSONRPCMessage { + // Add server to context + ctx = context.WithValue(ctx, serverKey{}, s) + var err *requestError + + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + Method mcp.MCPMethod `json:"method"` + ID any `json:"id,omitempty"` + } + + if err := json.Unmarshal(message, &baseMessage); err != nil { + return createErrorResponse( + nil, + mcp.PARSE_ERROR, + "Failed to parse message", + ) + } + + // Check for valid JSONRPC version + if baseMessage.JSONRPC != mcp.JSONRPC_VERSION { + return createErrorResponse( + baseMessage.ID, + mcp.INVALID_REQUEST, + "Invalid JSON-RPC version", + ) + } + + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal(message, ¬ification); err != nil { + return createErrorResponse( + nil, + mcp.PARSE_ERROR, + "Failed to parse notification", + ) + } + s.handleNotification(ctx, notification) + return nil // Return nil for notifications + } + + switch baseMessage.Method { + case mcp.MethodInitialize: + var request mcp.InitializeRequest + var result *mcp.InitializeResult + if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeInitialize(ctx, baseMessage.ID, &request) + result, err = s.handleInitialize(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterInitialize(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodPing: + var request mcp.PingRequest + var result *mcp.EmptyResult + if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforePing(ctx, baseMessage.ID, &request) + result, err = s.handlePing(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterPing(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodResourcesList: + var request mcp.ListResourcesRequest + var result *mcp.ListResourcesResult + if s.capabilities.resources == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("resources %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeListResources(ctx, baseMessage.ID, &request) + result, err = s.handleListResources(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListResources(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodResourcesTemplatesList: + var request mcp.ListResourceTemplatesRequest + var result *mcp.ListResourceTemplatesResult + if s.capabilities.resources == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("resources %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request) + result, err = s.handleListResourceTemplates(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListResourceTemplates(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodResourcesRead: + var request mcp.ReadResourceRequest + var result *mcp.ReadResourceResult + if s.capabilities.resources == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("resources %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeReadResource(ctx, baseMessage.ID, &request) + result, err = s.handleReadResource(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterReadResource(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodPromptsList: + var request mcp.ListPromptsRequest + var result *mcp.ListPromptsResult + if s.capabilities.prompts == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("prompts %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request) + result, err = s.handleListPrompts(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListPrompts(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodPromptsGet: + var request mcp.GetPromptRequest + var result *mcp.GetPromptResult + if s.capabilities.prompts == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("prompts %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request) + result, err = s.handleGetPrompt(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterGetPrompt(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodToolsList: + var request mcp.ListToolsRequest + var result *mcp.ListToolsResult + if s.capabilities.tools == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("tools %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeListTools(ctx, baseMessage.ID, &request) + result, err = s.handleListTools(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListTools(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodToolsCall: + var request mcp.CallToolRequest + var result *mcp.CallToolResult + if s.capabilities.tools == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("tools %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparseableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + s.hooks.beforeCallTool(ctx, baseMessage.ID, &request) + result, err = s.handleToolCall(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterCallTool(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + default: + return createErrorResponse( + baseMessage.ID, + mcp.METHOD_NOT_FOUND, + fmt.Sprintf("Method %s not found", baseMessage.Method), + ) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/server.go b/vendor/github.com/mark3labs/mcp-go/server/server.go new file mode 100644 index 0000000000..ec4fcef006 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/server.go @@ -0,0 +1,768 @@ +// Package server provides MCP (Model Control Protocol) server implementations. +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sort" + "sync" + + "github.com/mark3labs/mcp-go/mcp" +) + +// resourceEntry holds both a resource and its handler +type resourceEntry struct { + resource mcp.Resource + handler ResourceHandlerFunc +} + +// resourceTemplateEntry holds both a template and its handler +type resourceTemplateEntry struct { + template mcp.ResourceTemplate + handler ResourceTemplateHandlerFunc +} + +// ServerOption is a function that configures an MCPServer. +type ServerOption func(*MCPServer) + +// ResourceHandlerFunc is a function that returns resource contents. +type ResourceHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) + +// ResourceTemplateHandlerFunc is a function that returns a resource template. +type ResourceTemplateHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) + +// PromptHandlerFunc handles prompt requests with given arguments. +type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) + +// ToolHandlerFunc handles tool calls with given arguments. +type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) + +// ServerTool combines a Tool with its ToolHandlerFunc. +type ServerTool struct { + Tool mcp.Tool + Handler ToolHandlerFunc +} + +// ClientSession represents an active session that can be used by MCPServer to interact with client. +type ClientSession interface { + // Initialize marks session as fully initialized and ready for notifications + Initialize() + // Initialized returns if session is ready to accept notifications + Initialized() bool + // NotificationChannel provides a channel suitable for sending notifications to client. + NotificationChannel() chan<- mcp.JSONRPCNotification + // SessionID is a unique identifier used to track user session. + SessionID() string +} + +// clientSessionKey is the context key for storing current client notification channel. +type clientSessionKey struct{} + +// ClientSessionFromContext retrieves current client notification context from context. +func ClientSessionFromContext(ctx context.Context) ClientSession { + if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok { + return session + } + return nil +} + +// UnparseableMessageError is attached to the RequestError when json.Unmarshal +// fails on the request. +type UnparseableMessageError struct { + message json.RawMessage + method mcp.MCPMethod + err error +} + +func (e *UnparseableMessageError) Error() string { + return fmt.Sprintf("unparseable %s request: %s", e.method, e.err) +} + +func (e *UnparseableMessageError) Unwrap() error { + return e.err +} + +func (e *UnparseableMessageError) GetMessage() json.RawMessage { + return e.message +} + +func (e *UnparseableMessageError) GetMethod() mcp.MCPMethod { + return e.method +} + +// RequestError is an error that can be converted to a JSON-RPC error. +// Implements Unwrap() to allow inspecting the error chain. +type requestError struct { + id any + code int + err error +} + +func (e *requestError) Error() string { + return fmt.Sprintf("request error: %s", e.err) +} + +func (e *requestError) ToJSONRPCError() mcp.JSONRPCError { + return mcp.JSONRPCError{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: e.id, + Error: struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` + }{ + Code: e.code, + Message: e.err.Error(), + }, + } +} + +func (e *requestError) Unwrap() error { + return e.err +} + +var ( + ErrUnsupported = errors.New("not supported") + ErrResourceNotFound = errors.New("resource not found") + ErrPromptNotFound = errors.New("prompt not found") + ErrToolNotFound = errors.New("tool not found") +) + +// NotificationHandlerFunc handles incoming notifications. +type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCNotification) + +// MCPServer implements a Model Control Protocol server that can handle various types of requests +// including resources, prompts, and tools. +type MCPServer struct { + mu sync.RWMutex // Add mutex for protecting shared resources + name string + version string + instructions string + resources map[string]resourceEntry + resourceTemplates map[string]resourceTemplateEntry + prompts map[string]mcp.Prompt + promptHandlers map[string]PromptHandlerFunc + tools map[string]ServerTool + notificationHandlers map[string]NotificationHandlerFunc + capabilities serverCapabilities + sessions sync.Map + hooks *Hooks +} + +// serverKey is the context key for storing the server instance +type serverKey struct{} + +// ServerFromContext retrieves the MCPServer instance from a context +func ServerFromContext(ctx context.Context) *MCPServer { + if srv, ok := ctx.Value(serverKey{}).(*MCPServer); ok { + return srv + } + return nil +} + +// WithContext sets the current client session and returns the provided context +func (s *MCPServer) WithContext( + ctx context.Context, + session ClientSession, +) context.Context { + return context.WithValue(ctx, clientSessionKey{}, session) +} + +// RegisterSession saves session that should be notified in case if some server attributes changed. +func (s *MCPServer) RegisterSession( + ctx context.Context, + session ClientSession, +) error { + sessionID := session.SessionID() + if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { + return fmt.Errorf("session %s is already registered", sessionID) + } + s.hooks.RegisterSession(ctx, session) + return nil +} + +// UnregisterSession removes from storage session that is shut down. +func (s *MCPServer) UnregisterSession( + sessionID string, +) { + s.sessions.Delete(sessionID) +} + +// sendNotificationToAllClients sends a notification to all the currently active clients. +func (s *MCPServer) sendNotificationToAllClients( + method string, + params map[string]any, +) { + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + s.sessions.Range(func(k, v any) bool { + if session, ok := v.(ClientSession); ok && session.Initialized() { + select { + case session.NotificationChannel() <- notification: + default: + // TODO: log blocked channel in the future versions + } + } + return true + }) +} + +// SendNotificationToClient sends a notification to the current client +func (s *MCPServer) SendNotificationToClient( + ctx context.Context, + method string, + params map[string]any, +) error { + session := ClientSessionFromContext(ctx) + if session == nil || !session.Initialized() { + return fmt.Errorf("notification channel not initialized") + } + + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + + select { + case session.NotificationChannel() <- notification: + return nil + default: + return fmt.Errorf("notification channel full or blocked") + } +} + +// serverCapabilities defines the supported features of the MCP server +type serverCapabilities struct { + tools *toolCapabilities + resources *resourceCapabilities + prompts *promptCapabilities + logging bool +} + +// resourceCapabilities defines the supported resource-related features +type resourceCapabilities struct { + subscribe bool + listChanged bool +} + +// promptCapabilities defines the supported prompt-related features +type promptCapabilities struct { + listChanged bool +} + +// toolCapabilities defines the supported tool-related features +type toolCapabilities struct { + listChanged bool +} + +// WithResourceCapabilities configures resource-related server capabilities +func WithResourceCapabilities(subscribe, listChanged bool) ServerOption { + return func(s *MCPServer) { + // Always create a non-nil capability object + s.capabilities.resources = &resourceCapabilities{ + subscribe: subscribe, + listChanged: listChanged, + } + } +} + +// WithHooks allows adding hooks that will be called before or after +// either [all] requests or before / after specific request methods, or else +// prior to returning an error to the client. +func WithHooks(hooks *Hooks) ServerOption { + return func(s *MCPServer) { + s.hooks = hooks + } +} + +// WithPromptCapabilities configures prompt-related server capabilities +func WithPromptCapabilities(listChanged bool) ServerOption { + return func(s *MCPServer) { + // Always create a non-nil capability object + s.capabilities.prompts = &promptCapabilities{ + listChanged: listChanged, + } + } +} + +// WithToolCapabilities configures tool-related server capabilities +func WithToolCapabilities(listChanged bool) ServerOption { + return func(s *MCPServer) { + // Always create a non-nil capability object + s.capabilities.tools = &toolCapabilities{ + listChanged: listChanged, + } + } +} + +// WithLogging enables logging capabilities for the server +func WithLogging() ServerOption { + return func(s *MCPServer) { + s.capabilities.logging = true + } +} + +// WithInstructions sets the server instructions for the client returned in the initialize response +func WithInstructions(instructions string) ServerOption { + return func(s *MCPServer) { + s.instructions = instructions + } +} + +// NewMCPServer creates a new MCP server instance with the given name, version and options +func NewMCPServer( + name, version string, + opts ...ServerOption, +) *MCPServer { + s := &MCPServer{ + resources: make(map[string]resourceEntry), + resourceTemplates: make(map[string]resourceTemplateEntry), + prompts: make(map[string]mcp.Prompt), + promptHandlers: make(map[string]PromptHandlerFunc), + tools: make(map[string]ServerTool), + name: name, + version: version, + notificationHandlers: make(map[string]NotificationHandlerFunc), + capabilities: serverCapabilities{ + tools: nil, + resources: nil, + prompts: nil, + logging: false, + }, + } + + for _, opt := range opts { + opt(s) + } + + return s +} + +// AddResource registers a new resource and its handler +func (s *MCPServer) AddResource( + resource mcp.Resource, + handler ResourceHandlerFunc, +) { + if s.capabilities.resources == nil { + s.capabilities.resources = &resourceCapabilities{} + } + s.mu.Lock() + defer s.mu.Unlock() + s.resources[resource.URI] = resourceEntry{ + resource: resource, + handler: handler, + } +} + +// AddResourceTemplate registers a new resource template and its handler +func (s *MCPServer) AddResourceTemplate( + template mcp.ResourceTemplate, + handler ResourceTemplateHandlerFunc, +) { + if s.capabilities.resources == nil { + s.capabilities.resources = &resourceCapabilities{} + } + s.mu.Lock() + defer s.mu.Unlock() + s.resourceTemplates[template.URITemplate.Raw()] = resourceTemplateEntry{ + template: template, + handler: handler, + } +} + +// AddPrompt registers a new prompt handler with the given name +func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { + if s.capabilities.prompts == nil { + s.capabilities.prompts = &promptCapabilities{} + } + s.mu.Lock() + defer s.mu.Unlock() + s.prompts[prompt.Name] = prompt + s.promptHandlers[prompt.Name] = handler +} + +// AddTool registers a new tool and its handler +func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) { + s.AddTools(ServerTool{Tool: tool, Handler: handler}) +} + +// AddTools registers multiple tools at once +func (s *MCPServer) AddTools(tools ...ServerTool) { + if s.capabilities.tools == nil { + s.capabilities.tools = &toolCapabilities{} + } + s.mu.Lock() + for _, entry := range tools { + s.tools[entry.Tool.Name] = entry + } + s.mu.Unlock() + + // Send notification to all initialized sessions + s.sendNotificationToAllClients("notifications/tools/list_changed", nil) +} + +// SetTools replaces all existing tools with the provided list +func (s *MCPServer) SetTools(tools ...ServerTool) { + s.mu.Lock() + s.tools = make(map[string]ServerTool) + s.mu.Unlock() + s.AddTools(tools...) +} + +// DeleteTools removes a tool from the server +func (s *MCPServer) DeleteTools(names ...string) { + s.mu.Lock() + for _, name := range names { + delete(s.tools, name) + } + s.mu.Unlock() + + // Send notification to all initialized sessions + s.sendNotificationToAllClients("notifications/tools/list_changed", nil) +} + +// AddNotificationHandler registers a new handler for incoming notifications +func (s *MCPServer) AddNotificationHandler( + method string, + handler NotificationHandlerFunc, +) { + s.mu.Lock() + defer s.mu.Unlock() + s.notificationHandlers[method] = handler +} + +func (s *MCPServer) handleInitialize( + ctx context.Context, + id interface{}, + request mcp.InitializeRequest, +) (*mcp.InitializeResult, *requestError) { + capabilities := mcp.ServerCapabilities{} + + // Only add resource capabilities if they're configured + if s.capabilities.resources != nil { + capabilities.Resources = &struct { + Subscribe bool `json:"subscribe,omitempty"` + ListChanged bool `json:"listChanged,omitempty"` + }{ + Subscribe: s.capabilities.resources.subscribe, + ListChanged: s.capabilities.resources.listChanged, + } + } + + // Only add prompt capabilities if they're configured + if s.capabilities.prompts != nil { + capabilities.Prompts = &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{ + ListChanged: s.capabilities.prompts.listChanged, + } + } + + // Only add tool capabilities if they're configured + if s.capabilities.tools != nil { + capabilities.Tools = &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{ + ListChanged: s.capabilities.tools.listChanged, + } + } + + if s.capabilities.logging { + capabilities.Logging = &struct{}{} + } + + result := mcp.InitializeResult{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ServerInfo: mcp.Implementation{ + Name: s.name, + Version: s.version, + }, + Capabilities: capabilities, + Instructions: s.instructions, + } + + if session := ClientSessionFromContext(ctx); session != nil { + session.Initialize() + } + return &result, nil +} + +func (s *MCPServer) handlePing( + ctx context.Context, + id interface{}, + request mcp.PingRequest, +) (*mcp.EmptyResult, *requestError) { + return &mcp.EmptyResult{}, nil +} + +func (s *MCPServer) handleListResources( + ctx context.Context, + id interface{}, + request mcp.ListResourcesRequest, +) (*mcp.ListResourcesResult, *requestError) { + s.mu.RLock() + resources := make([]mcp.Resource, 0, len(s.resources)) + for _, entry := range s.resources { + resources = append(resources, entry.resource) + } + s.mu.RUnlock() + + result := mcp.ListResourcesResult{ + Resources: resources, + } + if request.Params.Cursor != "" { + result.NextCursor = "" // Handle pagination if needed + } + return &result, nil +} + +func (s *MCPServer) handleListResourceTemplates( + ctx context.Context, + id interface{}, + request mcp.ListResourceTemplatesRequest, +) (*mcp.ListResourceTemplatesResult, *requestError) { + s.mu.RLock() + templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates)) + for _, entry := range s.resourceTemplates { + templates = append(templates, entry.template) + } + s.mu.RUnlock() + + result := mcp.ListResourceTemplatesResult{ + ResourceTemplates: templates, + } + if request.Params.Cursor != "" { + result.NextCursor = "" // Handle pagination if needed + } + return &result, nil +} + +func (s *MCPServer) handleReadResource( + ctx context.Context, + id interface{}, + request mcp.ReadResourceRequest, +) (*mcp.ReadResourceResult, *requestError) { + s.mu.RLock() + // First try direct resource handlers + if entry, ok := s.resources[request.Params.URI]; ok { + handler := entry.handler + s.mu.RUnlock() + contents, err := handler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return &mcp.ReadResourceResult{Contents: contents}, nil + } + + // If no direct handler found, try matching against templates + var matchedHandler ResourceTemplateHandlerFunc + var matched bool + for _, entry := range s.resourceTemplates { + template := entry.template + if matchesTemplate(request.Params.URI, template.URITemplate) { + matchedHandler = entry.handler + matched = true + matchedVars := template.URITemplate.Match(request.Params.URI) + // Convert matched variables to a map + request.Params.Arguments = make(map[string]interface{}) + for name, value := range matchedVars { + request.Params.Arguments[name] = value.V + } + break + } + } + s.mu.RUnlock() + + if matched { + contents, err := matchedHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return &mcp.ReadResourceResult{Contents: contents}, nil + } + + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("handler not found for resource URI '%s': %w", request.Params.URI, ErrResourceNotFound), + } +} + +// matchesTemplate checks if a URI matches a URI template pattern +func matchesTemplate(uri string, template *mcp.URITemplate) bool { + return template.Regexp().MatchString(uri) +} + +func (s *MCPServer) handleListPrompts( + ctx context.Context, + id interface{}, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, *requestError) { + s.mu.RLock() + prompts := make([]mcp.Prompt, 0, len(s.prompts)) + for _, prompt := range s.prompts { + prompts = append(prompts, prompt) + } + s.mu.RUnlock() + + result := mcp.ListPromptsResult{ + Prompts: prompts, + } + if request.Params.Cursor != "" { + result.NextCursor = "" // Handle pagination if needed + } + return &result, nil +} + +func (s *MCPServer) handleGetPrompt( + ctx context.Context, + id interface{}, + request mcp.GetPromptRequest, +) (*mcp.GetPromptResult, *requestError) { + s.mu.RLock() + handler, ok := s.promptHandlers[request.Params.Name] + s.mu.RUnlock() + + if !ok { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("prompt '%s' not found: %w", request.Params.Name, ErrPromptNotFound), + } + } + + result, err := handler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + + return result, nil +} + +func (s *MCPServer) handleListTools( + ctx context.Context, + id interface{}, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, *requestError) { + s.mu.RLock() + tools := make([]mcp.Tool, 0, len(s.tools)) + + // Get all tool names for consistent ordering + toolNames := make([]string, 0, len(s.tools)) + for name := range s.tools { + toolNames = append(toolNames, name) + } + + // Sort the tool names for consistent ordering + sort.Strings(toolNames) + + // Add tools in sorted order + for _, name := range toolNames { + tools = append(tools, s.tools[name].Tool) + } + s.mu.RUnlock() + + result := mcp.ListToolsResult{ + Tools: tools, + } + if request.Params.Cursor != "" { + result.NextCursor = "" // Handle pagination if needed + } + return &result, nil +} +func (s *MCPServer) handleToolCall( + ctx context.Context, + id interface{}, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, *requestError) { + s.mu.RLock() + tool, ok := s.tools[request.Params.Name] + s.mu.RUnlock() + + if !ok { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("tool '%s' not found: %w", request.Params.Name, ErrToolNotFound), + } + } + + result, err := tool.Handler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + + return result, nil +} + +func (s *MCPServer) handleNotification( + ctx context.Context, + notification mcp.JSONRPCNotification, +) mcp.JSONRPCMessage { + s.mu.RLock() + handler, ok := s.notificationHandlers[notification.Method] + s.mu.RUnlock() + + if ok { + handler(ctx, notification) + } + return nil +} + +func createResponse(id interface{}, result interface{}) mcp.JSONRPCMessage { + return mcp.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Result: result, + } +} + +func createErrorResponse( + id interface{}, + code int, + message string, +) mcp.JSONRPCMessage { + return mcp.JSONRPCError{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Error: struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` + }{ + Code: code, + Message: message, + }, + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/sse.go b/vendor/github.com/mark3labs/mcp-go/server/sse.go new file mode 100644 index 0000000000..6e6a13fe78 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/sse.go @@ -0,0 +1,433 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "sync/atomic" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" +) + +// sseSession represents an active SSE connection. +type sseSession struct { + writer http.ResponseWriter + flusher http.Flusher + done chan struct{} + eventQueue chan string // Channel for queuing events + sessionID string + notificationChannel chan mcp.JSONRPCNotification + initialized atomic.Bool +} + +// SSEContextFunc is a function that takes an existing context and the current +// request and returns a potentially modified context based on the request +// content. This can be used to inject context values from headers, for example. +type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context + +func (s *sseSession) SessionID() string { + return s.sessionID +} + +func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notificationChannel +} + +func (s *sseSession) Initialize() { + s.initialized.Store(true) +} + +func (s *sseSession) Initialized() bool { + return s.initialized.Load() +} + +var _ ClientSession = (*sseSession)(nil) + +// SSEServer implements a Server-Sent Events (SSE) based MCP server. +// It provides real-time communication capabilities over HTTP using the SSE protocol. +type SSEServer struct { + server *MCPServer + baseURL string + basePath string + messageEndpoint string + useFullURLForMessageEndpoint bool + sseEndpoint string + sessions sync.Map + srv *http.Server + contextFunc SSEContextFunc +} + +// SSEOption defines a function type for configuring SSEServer +type SSEOption func(*SSEServer) + +// WithBaseURL sets the base URL for the SSE server +func WithBaseURL(baseURL string) SSEOption { + return func(s *SSEServer) { + if baseURL != "" { + u, err := url.Parse(baseURL) + if err != nil { + return + } + if u.Scheme != "http" && u.Scheme != "https" { + return + } + // Check if the host is empty or only contains a port + if u.Host == "" || strings.HasPrefix(u.Host, ":") { + return + } + if len(u.Query()) > 0 { + return + } + } + s.baseURL = strings.TrimSuffix(baseURL, "/") + } +} + +// Add a new option for setting base path +func WithBasePath(basePath string) SSEOption { + return func(s *SSEServer) { + // Ensure the path starts with / and doesn't end with / + if !strings.HasPrefix(basePath, "/") { + basePath = "/" + basePath + } + s.basePath = strings.TrimSuffix(basePath, "/") + } +} + +// WithMessageEndpoint sets the message endpoint path +func WithMessageEndpoint(endpoint string) SSEOption { + return func(s *SSEServer) { + s.messageEndpoint = endpoint + } +} + +// WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL) +// or just the path portion for the message endpoint. Set to false when clients will concatenate +// the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message". +func WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint bool) SSEOption { + return func(s *SSEServer) { + s.useFullURLForMessageEndpoint = useFullURLForMessageEndpoint + } +} + +// WithSSEEndpoint sets the SSE endpoint path +func WithSSEEndpoint(endpoint string) SSEOption { + return func(s *SSEServer) { + s.sseEndpoint = endpoint + } +} + +// WithHTTPServer sets the HTTP server instance +func WithHTTPServer(srv *http.Server) SSEOption { + return func(s *SSEServer) { + s.srv = srv + } +} + +// WithContextFunc sets a function that will be called to customise the context +// to the server using the incoming request. +func WithSSEContextFunc(fn SSEContextFunc) SSEOption { + return func(s *SSEServer) { + s.contextFunc = fn + } +} + +// NewSSEServer creates a new SSE server instance with the given MCP server and options. +func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { + s := &SSEServer{ + server: server, + sseEndpoint: "/sse", + messageEndpoint: "/message", + useFullURLForMessageEndpoint: true, + } + + // Apply all options + for _, opt := range opts { + opt(s) + } + + return s +} + +// NewTestServer creates a test server for testing purposes +func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { + sseServer := NewSSEServer(server) + for _, opt := range opts { + opt(sseServer) + } + + testServer := httptest.NewServer(sseServer) + sseServer.baseURL = testServer.URL + return testServer +} + +// Start begins serving SSE connections on the specified address. +// It sets up HTTP handlers for SSE and message endpoints. +func (s *SSEServer) Start(addr string) error { + s.srv = &http.Server{ + Addr: addr, + Handler: s, + } + + return s.srv.ListenAndServe() +} + +// Shutdown gracefully stops the SSE server, closing all active sessions +// and shutting down the HTTP server. +func (s *SSEServer) Shutdown(ctx context.Context) error { + if s.srv != nil { + s.sessions.Range(func(key, value interface{}) bool { + if session, ok := value.(*sseSession); ok { + close(session.done) + } + s.sessions.Delete(key) + return true + }) + + return s.srv.Shutdown(ctx) + } + return nil +} + +// handleSSE handles incoming SSE connection requests. +// It sets up appropriate headers and creates a new session for the client. +func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + + sessionID := uuid.New().String() + session := &sseSession{ + writer: w, + flusher: flusher, + done: make(chan struct{}), + eventQueue: make(chan string, 100), // Buffer for events + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + } + + s.sessions.Store(sessionID, session) + defer s.sessions.Delete(sessionID) + + if err := s.server.RegisterSession(r.Context(), session); err != nil { + http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusInternalServerError) + return + } + defer s.server.UnregisterSession(sessionID) + + // Start notification handler for this session + go func() { + for { + select { + case notification := <-session.notificationChannel: + eventData, err := json.Marshal(notification) + if err == nil { + select { + case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): + // Event queued successfully + case <-session.done: + return + } + } + case <-session.done: + return + case <-r.Context().Done(): + return + } + } + }() + + // Send the initial endpoint event + fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID)) + flusher.Flush() + + // Main event loop - this runs in the HTTP handler goroutine + for { + select { + case event := <-session.eventQueue: + // Write the event to the response + fmt.Fprint(w, event) + flusher.Flush() + case <-r.Context().Done(): + close(session.done) + return + } + } +} + +// GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID +// based on the useFullURLForMessageEndpoint configuration. +func (s *SSEServer) GetMessageEndpointForClient(sessionID string) string { + messageEndpoint := s.messageEndpoint + if s.useFullURLForMessageEndpoint { + messageEndpoint = s.CompleteMessageEndpoint() + } + return fmt.Sprintf("%s?sessionId=%s", messageEndpoint, sessionID) +} + +// handleMessage processes incoming JSON-RPC messages from clients and sends responses +// back through both the SSE connection and HTTP response. +func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed") + return + } + + sessionID := r.URL.Query().Get("sessionId") + if sessionID == "" { + s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId") + return + } + + sessionI, ok := s.sessions.Load(sessionID) + if !ok { + s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID") + return + } + session := sessionI.(*sseSession) + + // Set the client context before handling the message + ctx := s.server.WithContext(r.Context(), session) + if s.contextFunc != nil { + ctx = s.contextFunc(ctx, r) + } + + // Parse message as raw JSON + var rawMessage json.RawMessage + if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil { + s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error") + return + } + + // Process message through MCPServer + response := s.server.HandleMessage(ctx, rawMessage) + + // Only send response if there is one (not for notifications) + if response != nil { + eventData, _ := json.Marshal(response) + + // Queue the event for sending via SSE + select { + case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): + // Event queued successfully + case <-session.done: + // Session is closed, don't try to queue + default: + // Queue is full, could log this + } + + // Send HTTP response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + json.NewEncoder(w).Encode(response) + } else { + // For notifications, just send 202 Accepted with no body + w.WriteHeader(http.StatusAccepted) + } +} + +// writeJSONRPCError writes a JSON-RPC error response with the given error details. +func (s *SSEServer) writeJSONRPCError( + w http.ResponseWriter, + id interface{}, + code int, + message string, +) { + response := createErrorResponse(id, code, message) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(response) +} + +// SendEventToSession sends an event to a specific SSE session identified by sessionID. +// Returns an error if the session is not found or closed. +func (s *SSEServer) SendEventToSession( + sessionID string, + event interface{}, +) error { + sessionI, ok := s.sessions.Load(sessionID) + if !ok { + return fmt.Errorf("session not found: %s", sessionID) + } + session := sessionI.(*sseSession) + + eventData, err := json.Marshal(event) + if err != nil { + return err + } + + // Queue the event for sending via SSE + select { + case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): + return nil + case <-session.done: + return fmt.Errorf("session closed") + default: + return fmt.Errorf("event queue full") + } +} +func (s *SSEServer) GetUrlPath(input string) (string, error) { + parse, err := url.Parse(input) + if err != nil { + return "", fmt.Errorf("failed to parse URL %s: %w", input, err) + } + return parse.Path, nil +} + +func (s *SSEServer) CompleteSseEndpoint() string { + return s.baseURL + s.basePath + s.sseEndpoint +} +func (s *SSEServer) CompleteSsePath() string { + path, err := s.GetUrlPath(s.CompleteSseEndpoint()) + if err != nil { + return s.basePath + s.sseEndpoint + } + return path +} + +func (s *SSEServer) CompleteMessageEndpoint() string { + return s.baseURL + s.basePath + s.messageEndpoint +} +func (s *SSEServer) CompleteMessagePath() string { + path, err := s.GetUrlPath(s.CompleteMessageEndpoint()) + if err != nil { + return s.basePath + s.messageEndpoint + } + return path +} + +// ServeHTTP implements the http.Handler interface. +func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := r.URL.Path + // Use exact path matching rather than Contains + ssePath := s.CompleteSsePath() + if ssePath != "" && path == ssePath { + s.handleSSE(w, r) + return + } + messagePath := s.CompleteMessagePath() + if messagePath != "" && path == messagePath { + s.handleMessage(w, r) + return + } + + http.NotFound(w, r) +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/stdio.go b/vendor/github.com/mark3labs/mcp-go/server/stdio.go new file mode 100644 index 0000000000..14c1e76e9a --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/stdio.go @@ -0,0 +1,283 @@ +package server + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "log" + "os" + "os/signal" + "sync/atomic" + "syscall" + + "github.com/mark3labs/mcp-go/mcp" +) + +// StdioContextFunc is a function that takes an existing context and returns +// a potentially modified context. +// This can be used to inject context values from environment variables, +// for example. +type StdioContextFunc func(ctx context.Context) context.Context + +// StdioServer wraps a MCPServer and handles stdio communication. +// It provides a simple way to create command-line MCP servers that +// communicate via standard input/output streams using JSON-RPC messages. +type StdioServer struct { + server *MCPServer + errLogger *log.Logger + contextFunc StdioContextFunc +} + +// StdioOption defines a function type for configuring StdioServer +type StdioOption func(*StdioServer) + +// WithErrorLogger sets the error logger for the server +func WithErrorLogger(logger *log.Logger) StdioOption { + return func(s *StdioServer) { + s.errLogger = logger + } +} + +// WithContextFunc sets a function that will be called to customise the context +// to the server. Note that the stdio server uses the same context for all requests, +// so this function will only be called once per server instance. +func WithStdioContextFunc(fn StdioContextFunc) StdioOption { + return func(s *StdioServer) { + s.contextFunc = fn + } +} + +// stdioSession is a static client session, since stdio has only one client. +type stdioSession struct { + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool +} + +func (s *stdioSession) SessionID() string { + return "stdio" +} + +func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notifications +} + +func (s *stdioSession) Initialize() { + s.initialized.Store(true) +} + +func (s *stdioSession) Initialized() bool { + return s.initialized.Load() +} + +var _ ClientSession = (*stdioSession)(nil) + +var stdioSessionInstance = stdioSession{ + notifications: make(chan mcp.JSONRPCNotification, 100), +} + +// NewStdioServer creates a new stdio server wrapper around an MCPServer. +// It initializes the server with a default error logger that discards all output. +func NewStdioServer(server *MCPServer) *StdioServer { + return &StdioServer{ + server: server, + errLogger: log.New( + os.Stderr, + "", + log.LstdFlags, + ), // Default to discarding logs + } +} + +// SetErrorLogger configures where error messages from the StdioServer are logged. +// The provided logger will receive all error messages generated during server operation. +func (s *StdioServer) SetErrorLogger(logger *log.Logger) { + s.errLogger = logger +} + +// SetContextFunc sets a function that will be called to customise the context +// to the server. Note that the stdio server uses the same context for all requests, +// so this function will only be called once per server instance. +func (s *StdioServer) SetContextFunc(fn StdioContextFunc) { + s.contextFunc = fn +} + +// handleNotifications continuously processes notifications from the session's notification channel +// and writes them to the provided output. It runs until the context is cancelled. +// Any errors encountered while writing notifications are logged but do not stop the handler. +func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) { + for { + select { + case notification := <-stdioSessionInstance.notifications: + if err := s.writeResponse(notification, stdout); err != nil { + s.errLogger.Printf("Error writing notification: %v", err) + } + case <-ctx.Done(): + return + } + } +} + +// processInputStream continuously reads and processes messages from the input stream. +// It handles EOF gracefully as a normal termination condition. +// The function returns when either: +// - The context is cancelled (returns context.Err()) +// - EOF is encountered (returns nil) +// - An error occurs while reading or processing messages (returns the error) +func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error { + for { + if err := ctx.Err(); err != nil { + return err + } + + line, err := s.readNextLine(ctx, reader) + if err != nil { + if err == io.EOF { + return nil + } + s.errLogger.Printf("Error reading input: %v", err) + return err + } + + if err := s.processMessage(ctx, line, stdout); err != nil { + if err == io.EOF { + return nil + } + s.errLogger.Printf("Error handling message: %v", err) + return err + } + } +} + +// readNextLine reads a single line from the input reader in a context-aware manner. +// It uses channels to make the read operation cancellable via context. +// Returns the read line and any error encountered. If the context is cancelled, +// returns an empty string and the context's error. EOF is returned when the input +// stream is closed. +func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) { + readChan := make(chan string, 1) + errChan := make(chan error, 1) + defer func() { + close(readChan) + close(errChan) + }() + + go func() { + line, err := reader.ReadString('\n') + if err != nil { + errChan <- err + return + } + readChan <- line + }() + + select { + case <-ctx.Done(): + return "", ctx.Err() + case err := <-errChan: + return "", err + case line := <-readChan: + return line, nil + } +} + +// Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output. +// It runs until the context is cancelled or an error occurs. +// Returns an error if there are issues with reading input or writing output. +func (s *StdioServer) Listen( + ctx context.Context, + stdin io.Reader, + stdout io.Writer, +) error { + // Set a static client context since stdio only has one client + if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil { + return fmt.Errorf("register session: %w", err) + } + defer s.server.UnregisterSession(stdioSessionInstance.SessionID()) + ctx = s.server.WithContext(ctx, &stdioSessionInstance) + + // Add in any custom context. + if s.contextFunc != nil { + ctx = s.contextFunc(ctx) + } + + reader := bufio.NewReader(stdin) + + // Start notification handler + go s.handleNotifications(ctx, stdout) + return s.processInputStream(ctx, reader, stdout) +} + +// processMessage handles a single JSON-RPC message and writes the response. +// It parses the message, processes it through the wrapped MCPServer, and writes any response. +// Returns an error if there are issues with message processing or response writing. +func (s *StdioServer) processMessage( + ctx context.Context, + line string, + writer io.Writer, +) error { + // Parse the message as raw JSON + var rawMessage json.RawMessage + if err := json.Unmarshal([]byte(line), &rawMessage); err != nil { + response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error") + return s.writeResponse(response, writer) + } + + // Handle the message using the wrapped server + response := s.server.HandleMessage(ctx, rawMessage) + + // Only write response if there is one (not for notifications) + if response != nil { + if err := s.writeResponse(response, writer); err != nil { + return fmt.Errorf("failed to write response: %w", err) + } + } + + return nil +} + +// writeResponse marshals and writes a JSON-RPC response message followed by a newline. +// Returns an error if marshaling or writing fails. +func (s *StdioServer) writeResponse( + response mcp.JSONRPCMessage, + writer io.Writer, +) error { + responseBytes, err := json.Marshal(response) + if err != nil { + return err + } + + // Write response followed by newline + if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil { + return err + } + + return nil +} + +// ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout. +// It sets up signal handling for graceful shutdown on SIGTERM and SIGINT. +// Returns an error if the server encounters any issues during operation. +func ServeStdio(server *MCPServer, opts ...StdioOption) error { + s := NewStdioServer(server) + s.SetErrorLogger(log.New(os.Stderr, "", log.LstdFlags)) + + for _, opt := range opts { + opt(s) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set up signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT) + + go func() { + <-sigChan + cancel() + }() + + return s.Listen(ctx, os.Stdin, os.Stdout) +} diff --git a/vendor/github.com/shibukawa/configdir/README.rst b/vendor/github.com/shibukawa/configdir/README.rst deleted file mode 100644 index 99906697da..0000000000 --- a/vendor/github.com/shibukawa/configdir/README.rst +++ /dev/null @@ -1,111 +0,0 @@ -configdir for Golang -===================== - -Multi platform library of configuration directory for Golang. - -This library helps to get regular directories for configuration files or cache files that matches target operationg system's convention. - -It assumes the following folders are standard paths of each environment: - -.. list-table:: - :header-rows: 1 - - - * - * Windows: - * Linux/BSDs: - * MacOSX: - - * System level configuration folder - * ``%PROGRAMDATA%`` (``C:\\ProgramData``) - * ``${XDG_CONFIG_DIRS}`` (``/etc/xdg``) - * ``/Library/Application Support`` - - * User level configuration folder - * ``%APPDATA%`` (``C:\\Users\\\\AppData\\Roaming``) - * ``${XDG_CONFIG_HOME}`` (``${HOME}/.config``) - * ``${HOME}/Library/Application Support`` - - * User wide cache folder - * ``%LOCALAPPDATA%`` ``(C:\\Users\\\\AppData\\Local)`` - * ``${XDG_CACHE_HOME}`` (``${HOME}/.cache``) - * ``${HOME}/Library/Caches`` - -Examples ------------- - -Getting Configuration -~~~~~~~~~~~~~~~~~~~~~~~~ - -``configdir.ConfigDir.QueryFolderContainsFile()`` searches files in the following order: - -* Local path (if you add the path via LocalPath parameter) -* User level configuration folder(e.g. ``$HOME/.config///setting.json`` in Linux) -* System level configuration folder(e.g. ``/etc/xdg///setting.json`` in Linux) - -``configdir.Config`` provides some convenient methods(``ReadFile``, ``WriteFile`` and so on). - -.. code-block:: go - - var config Config - - configDirs := configdir.New("vendor-name", "application-name") - // optional: local path has the highest priority - configDirs.LocalPath, _ = filepath.Abs(".") - folder := configDirs.QueryFolderContainsFile("setting.json") - if folder != nil { - data, _ := folder.ReadFile("setting.json") - json.Unmarshal(data, &config) - } else { - config = DefaultConfig - } - -Write Configuration -~~~~~~~~~~~~~~~~~~~~~~ - -When storing configuration, get configuration folder by using ``configdir.ConfigDir.QueryFolders()`` method. - -.. code-block:: go - - configDirs := configdir.New("vendor-name", "application-name") - - var config Config - data, _ := json.Marshal(&config) - - // Stores to local folder - folders := configDirs.QueryFolders(configdir.Local) - folders[0].WriteFile("setting.json", data) - - // Stores to user folder - folders = configDirs.QueryFolders(configdir.Global) - folders[0].WriteFile("setting.json", data) - - // Stores to system folder - folders = configDirs.QueryFolders(configdir.System) - folders[0].WriteFile("setting.json", data) - -Getting Cache Folder -~~~~~~~~~~~~~~~~~~~~~~ - -It is similar to the above example, but returns cache folder. - -.. code-block:: go - - configDirs := configdir.New("vendor-name", "application-name") - cache := configDirs.QueryCacheFolder() - - resp, err := http.Get("http://examples.com/sdk.zip") - if err != nil { - log.Fatal(err) - } - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - - cache.WriteFile("sdk.zip", body) - -Document ------------- - -https://godoc.org/github.com/shibukawa/configdir - -License ------------- - -MIT - diff --git a/vendor/github.com/shibukawa/configdir/config.go b/vendor/github.com/shibukawa/configdir/config.go deleted file mode 100644 index 8a20e54b59..0000000000 --- a/vendor/github.com/shibukawa/configdir/config.go +++ /dev/null @@ -1,160 +0,0 @@ -// configdir provides access to configuration folder in each platforms. -// -// System wide configuration folders: -// -// - Windows: %PROGRAMDATA% (C:\ProgramData) -// - Linux/BSDs: ${XDG_CONFIG_DIRS} (/etc/xdg) -// - MacOSX: "/Library/Application Support" -// -// User wide configuration folders: -// -// - Windows: %APPDATA% (C:\Users\\AppData\Roaming) -// - Linux/BSDs: ${XDG_CONFIG_HOME} (${HOME}/.config) -// - MacOSX: "${HOME}/Library/Application Support" -// -// User wide cache folders: -// -// - Windows: %LOCALAPPDATA% (C:\Users\\AppData\Local) -// - Linux/BSDs: ${XDG_CACHE_HOME} (${HOME}/.cache) -// - MacOSX: "${HOME}/Library/Caches" -// -// configdir returns paths inside the above folders. - -package configdir - -import ( - "io/ioutil" - "os" - "path/filepath" -) - -type ConfigType int - -const ( - System ConfigType = iota - Global - All - Existing - Local - Cache -) - -// Config represents each folder -type Config struct { - Path string - Type ConfigType -} - -func (c Config) Open(fileName string) (*os.File, error) { - return os.Open(filepath.Join(c.Path, fileName)) -} - -func (c Config) Create(fileName string) (*os.File, error) { - err := c.CreateParentDir(fileName) - if err != nil { - return nil, err - } - return os.Create(filepath.Join(c.Path, fileName)) -} - -func (c Config) ReadFile(fileName string) ([]byte, error) { - return ioutil.ReadFile(filepath.Join(c.Path, fileName)) -} - -// CreateParentDir creates the parent directory of fileName inside c. fileName -// is a relative path inside c, containing zero or more path separators. -func (c Config) CreateParentDir(fileName string) error { - return os.MkdirAll(filepath.Dir(filepath.Join(c.Path, fileName)), 0755) -} - -func (c Config) WriteFile(fileName string, data []byte) error { - err := c.CreateParentDir(fileName) - if err != nil { - return err - } - return ioutil.WriteFile(filepath.Join(c.Path, fileName), data, 0644) -} - -func (c Config) MkdirAll() error { - return os.MkdirAll(c.Path, 0755) -} - -func (c Config) Exists(fileName string) bool { - _, err := os.Stat(filepath.Join(c.Path, fileName)) - return !os.IsNotExist(err) -} - -// ConfigDir keeps setting for querying folders. -type ConfigDir struct { - VendorName string - ApplicationName string - LocalPath string -} - -func New(vendorName, applicationName string) ConfigDir { - return ConfigDir{ - VendorName: vendorName, - ApplicationName: applicationName, - } -} - -func (c ConfigDir) joinPath(root string) string { - if c.VendorName != "" && hasVendorName { - return filepath.Join(root, c.VendorName, c.ApplicationName) - } - return filepath.Join(root, c.ApplicationName) -} - -func (c ConfigDir) QueryFolders(configType ConfigType) []*Config { - if configType == Cache { - return []*Config{c.QueryCacheFolder()} - } - var result []*Config - if c.LocalPath != "" && configType != System && configType != Global { - result = append(result, &Config{ - Path: c.LocalPath, - Type: Local, - }) - } - if configType != System && configType != Local { - result = append(result, &Config{ - Path: c.joinPath(globalSettingFolder), - Type: Global, - }) - } - if configType != Global && configType != Local { - for _, root := range systemSettingFolders { - result = append(result, &Config{ - Path: c.joinPath(root), - Type: System, - }) - } - } - if configType != Existing { - return result - } - var existing []*Config - for _, entry := range result { - if _, err := os.Stat(entry.Path); !os.IsNotExist(err) { - existing = append(existing, entry) - } - } - return existing -} - -func (c ConfigDir) QueryFolderContainsFile(fileName string) *Config { - configs := c.QueryFolders(Existing) - for _, config := range configs { - if _, err := os.Stat(filepath.Join(config.Path, fileName)); !os.IsNotExist(err) { - return config - } - } - return nil -} - -func (c ConfigDir) QueryCacheFolder() *Config { - return &Config{ - Path: c.joinPath(cacheFolder), - Type: Cache, - } -} diff --git a/vendor/github.com/shibukawa/configdir/config_darwin.go b/vendor/github.com/shibukawa/configdir/config_darwin.go deleted file mode 100644 index d668507a7e..0000000000 --- a/vendor/github.com/shibukawa/configdir/config_darwin.go +++ /dev/null @@ -1,8 +0,0 @@ -package configdir - -import "os" - -var hasVendorName = true -var systemSettingFolders = []string{"/Library/Application Support"} -var globalSettingFolder = os.Getenv("HOME") + "/Library/Application Support" -var cacheFolder = os.Getenv("HOME") + "/Library/Caches" diff --git a/vendor/github.com/shibukawa/configdir/config_windows.go b/vendor/github.com/shibukawa/configdir/config_windows.go deleted file mode 100644 index 0984821778..0000000000 --- a/vendor/github.com/shibukawa/configdir/config_windows.go +++ /dev/null @@ -1,8 +0,0 @@ -package configdir - -import "os" - -var hasVendorName = true -var systemSettingFolders = []string{os.Getenv("PROGRAMDATA")} -var globalSettingFolder = os.Getenv("APPDATA") -var cacheFolder = os.Getenv("LOCALAPPDATA") diff --git a/vendor/github.com/shibukawa/configdir/config_xdg.go b/vendor/github.com/shibukawa/configdir/config_xdg.go deleted file mode 100644 index 026ca68a0b..0000000000 --- a/vendor/github.com/shibukawa/configdir/config_xdg.go +++ /dev/null @@ -1,34 +0,0 @@ -// +build !windows,!darwin - -package configdir - -import ( - "os" - "path/filepath" - "strings" -) - -// https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html - -var hasVendorName = true -var systemSettingFolders []string -var globalSettingFolder string -var cacheFolder string - -func init() { - if os.Getenv("XDG_CONFIG_HOME") != "" { - globalSettingFolder = os.Getenv("XDG_CONFIG_HOME") - } else { - globalSettingFolder = filepath.Join(os.Getenv("HOME"), ".config") - } - if os.Getenv("XDG_CONFIG_DIRS") != "" { - systemSettingFolders = strings.Split(os.Getenv("XDG_CONFIG_DIRS"), ":") - } else { - systemSettingFolders = []string{"/etc/xdg"} - } - if os.Getenv("XDG_CACHE_HOME") != "" { - cacheFolder = os.Getenv("XDG_CACHE_HOME") - } else { - cacheFolder = filepath.Join(os.Getenv("HOME"), ".cache") - } -} diff --git a/vendor/github.com/yosida95/uritemplate/v3/LICENSE b/vendor/github.com/yosida95/uritemplate/v3/LICENSE new file mode 100644 index 0000000000..79e8f87572 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/LICENSE @@ -0,0 +1,25 @@ +Copyright (C) 2016, Kohei YOSHIDA . All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/yosida95/uritemplate/v3/README.rst b/vendor/github.com/yosida95/uritemplate/v3/README.rst new file mode 100644 index 0000000000..6815d0a465 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/README.rst @@ -0,0 +1,46 @@ +uritemplate +=========== + +`uritemplate`_ is a Go implementation of `URI Template`_ [RFC6570] with +full functionality of URI Template Level 4. + +uritemplate can also generate a regexp that matches expansion of the +URI Template from a URI Template. + +Getting Started +--------------- + +Installation +~~~~~~~~~~~~ + +.. code-block:: sh + + $ go get -u github.com/yosida95/uritemplate/v3 + +Documentation +~~~~~~~~~~~~~ + +The documentation is available on GoDoc_. + +Examples +-------- + +See `examples on GoDoc`_. + +License +------- + +`uritemplate`_ is distributed under the BSD 3-Clause license. +PLEASE READ ./LICENSE carefully and follow its clauses to use this software. + +Author +------ + +yosida95_ + + +.. _`URI Template`: https://tools.ietf.org/html/rfc6570 +.. _Godoc: https://godoc.org/github.com/yosida95/uritemplate +.. _`examples on GoDoc`: https://godoc.org/github.com/yosida95/uritemplate#pkg-examples +.. _yosida95: https://yosida95.com/ +.. _uritemplate: https://github.com/yosida95/uritemplate diff --git a/vendor/github.com/yosida95/uritemplate/v3/compile.go b/vendor/github.com/yosida95/uritemplate/v3/compile.go new file mode 100644 index 0000000000..bd774d15d0 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/compile.go @@ -0,0 +1,224 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "fmt" + "unicode/utf8" +) + +type compiler struct { + prog *prog +} + +func (c *compiler) init() { + c.prog = &prog{} +} + +func (c *compiler) op(opcode progOpcode) uint32 { + i := len(c.prog.op) + c.prog.op = append(c.prog.op, progOp{code: opcode}) + return uint32(i) +} + +func (c *compiler) opWithRune(opcode progOpcode, r rune) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).r = r + return addr +} + +func (c *compiler) opWithRuneClass(opcode progOpcode, rc runeClass) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).rc = rc + return addr +} + +func (c *compiler) opWithAddr(opcode progOpcode, absaddr uint32) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).i = absaddr + return addr +} + +func (c *compiler) opWithAddrDelta(opcode progOpcode, delta uint32) uint32 { + return c.opWithAddr(opcode, uint32(len(c.prog.op))+delta) +} + +func (c *compiler) opWithName(opcode progOpcode, name string) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).name = name + return addr +} + +func (c *compiler) compileString(str string) { + for i := 0; i < len(str); { + // NOTE(yosida95): It is confirmed at parse time that literals + // consist of only valid-UTF8 runes. + r, size := utf8.DecodeRuneInString(str[i:]) + c.opWithRune(opRune, r) + i += size + } +} + +func (c *compiler) compileRuneClass(rc runeClass, maxlen int) { + for i := 0; i < maxlen; i++ { + if i > 0 { + c.opWithAddrDelta(opSplit, 7) + } + c.opWithAddrDelta(opSplit, 3) // raw rune or pct-encoded + c.opWithRuneClass(opRuneClass, rc) // raw rune + c.opWithAddrDelta(opJmp, 4) // + c.opWithRune(opRune, '%') // pct-encoded + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithRuneClass(opRuneClass, runeClassPctE) // + } +} + +func (c *compiler) compileRuneClassInfinite(rc runeClass) { + start := c.opWithAddrDelta(opSplit, 3) // raw rune or pct-encoded + c.opWithRuneClass(opRuneClass, rc) // raw rune + c.opWithAddrDelta(opJmp, 4) // + c.opWithRune(opRune, '%') // pct-encoded + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithAddrDelta(opSplit, 2) // loop + c.opWithAddr(opJmp, start) // +} + +func (c *compiler) compileVarspecValue(spec varspec, expr *expression) { + var specname string + if spec.maxlen > 0 { + specname = fmt.Sprintf("%s:%d", spec.name, spec.maxlen) + } else { + specname = spec.name + } + + c.prog.numCap++ + + c.opWithName(opCapStart, specname) + + split := c.op(opSplit) + if spec.maxlen > 0 { + c.compileRuneClass(expr.allow, spec.maxlen) + } else { + c.compileRuneClassInfinite(expr.allow) + } + + capEnd := c.opWithName(opCapEnd, specname) + c.prog.op[split].i = capEnd +} + +func (c *compiler) compileVarspec(spec varspec, expr *expression) { + switch { + case expr.named && spec.explode: + split1 := c.op(opSplit) + noop := c.op(opNoop) + c.compileString(spec.name) + + split2 := c.op(opSplit) + c.opWithRune(opRune, '=') + c.compileVarspecValue(spec, expr) + + split3 := c.op(opSplit) + c.compileString(expr.sep) + c.opWithAddr(opJmp, noop) + + c.prog.op[split2].i = uint32(len(c.prog.op)) + c.compileString(expr.ifemp) + c.opWithAddr(opJmp, split3) + + c.prog.op[split1].i = uint32(len(c.prog.op)) + c.prog.op[split3].i = uint32(len(c.prog.op)) + + case expr.named && !spec.explode: + c.compileString(spec.name) + + split2 := c.op(opSplit) + c.opWithRune(opRune, '=') + + split3 := c.op(opSplit) + + split4 := c.op(opSplit) + c.compileVarspecValue(spec, expr) + + split5 := c.op(opSplit) + c.prog.op[split4].i = split5 + c.compileString(",") + c.opWithAddr(opJmp, split4) + + c.prog.op[split3].i = uint32(len(c.prog.op)) + c.compileString(",") + jmp1 := c.op(opJmp) + + c.prog.op[split2].i = uint32(len(c.prog.op)) + c.compileString(expr.ifemp) + + c.prog.op[split5].i = uint32(len(c.prog.op)) + c.prog.op[jmp1].i = uint32(len(c.prog.op)) + + case !expr.named: + start := uint32(len(c.prog.op)) + c.compileVarspecValue(spec, expr) + + split1 := c.op(opSplit) + jmp := c.op(opJmp) + + c.prog.op[split1].i = uint32(len(c.prog.op)) + if spec.explode { + c.compileString(expr.sep) + } else { + c.opWithRune(opRune, ',') + } + c.opWithAddr(opJmp, start) + + c.prog.op[jmp].i = uint32(len(c.prog.op)) + } +} + +func (c *compiler) compileExpression(expr *expression) { + if len(expr.vars) < 1 { + return + } + + split1 := c.op(opSplit) + c.compileString(expr.first) + + for i, size := 0, len(expr.vars); i < size; i++ { + spec := expr.vars[i] + + split2 := c.op(opSplit) + if i > 0 { + split3 := c.op(opSplit) + c.compileString(expr.sep) + c.prog.op[split3].i = uint32(len(c.prog.op)) + } + c.compileVarspec(spec, expr) + c.prog.op[split2].i = uint32(len(c.prog.op)) + } + + c.prog.op[split1].i = uint32(len(c.prog.op)) +} + +func (c *compiler) compileLiterals(lt literals) { + c.compileString(string(lt)) +} + +func (c *compiler) compile(tmpl *Template) { + c.op(opLineBegin) + for i := range tmpl.exprs { + expr := tmpl.exprs[i] + switch expr := expr.(type) { + default: + panic("unhandled expression") + case *expression: + c.compileExpression(expr) + case literals: + c.compileLiterals(expr) + } + } + c.op(opLineEnd) + c.op(opEnd) +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/equals.go b/vendor/github.com/yosida95/uritemplate/v3/equals.go new file mode 100644 index 0000000000..aa59a5c030 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/equals.go @@ -0,0 +1,53 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +type CompareFlags uint8 + +const ( + CompareVarname CompareFlags = 1 << iota +) + +// Equals reports whether or not two URI Templates t1 and t2 are equivalent. +func Equals(t1 *Template, t2 *Template, flags CompareFlags) bool { + if len(t1.exprs) != len(t2.exprs) { + return false + } + for i := 0; i < len(t1.exprs); i++ { + switch t1 := t1.exprs[i].(type) { + case literals: + t2, ok := t2.exprs[i].(literals) + if !ok { + return false + } + if t1 != t2 { + return false + } + case *expression: + t2, ok := t2.exprs[i].(*expression) + if !ok { + return false + } + if t1.op != t2.op || len(t1.vars) != len(t2.vars) { + return false + } + for n := 0; n < len(t1.vars); n++ { + v1 := t1.vars[n] + v2 := t2.vars[n] + if flags&CompareVarname == CompareVarname && v1.name != v2.name { + return false + } + if v1.maxlen != v2.maxlen || v1.explode != v2.explode { + return false + } + } + default: + panic("unhandled case") + } + } + return true +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/error.go b/vendor/github.com/yosida95/uritemplate/v3/error.go new file mode 100644 index 0000000000..2fd34a8080 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/error.go @@ -0,0 +1,16 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "fmt" +) + +func errorf(pos int, format string, a ...interface{}) error { + msg := fmt.Sprintf(format, a...) + return fmt.Errorf("uritemplate:%d:%s", pos, msg) +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/escape.go b/vendor/github.com/yosida95/uritemplate/v3/escape.go new file mode 100644 index 0000000000..6d27e693af --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/escape.go @@ -0,0 +1,190 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "strings" + "unicode" + "unicode/utf8" +) + +var ( + hex = []byte("0123456789ABCDEF") + // reserved = gen-delims / sub-delims + // gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" + // sub-delims = "!" / "$" / "&" / "’" / "(" / ")" + // / "*" / "+" / "," / ";" / "=" + rangeReserved = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x21, Hi: 0x21, Stride: 1}, // '!' + {Lo: 0x23, Hi: 0x24, Stride: 1}, // '#' - '$' + {Lo: 0x26, Hi: 0x2C, Stride: 1}, // '&' - ',' + {Lo: 0x2F, Hi: 0x2F, Stride: 1}, // '/' + {Lo: 0x3A, Hi: 0x3B, Stride: 1}, // ':' - ';' + {Lo: 0x3D, Hi: 0x3D, Stride: 1}, // '=' + {Lo: 0x3F, Hi: 0x40, Stride: 1}, // '?' - '@' + {Lo: 0x5B, Hi: 0x5B, Stride: 1}, // '[' + {Lo: 0x5D, Hi: 0x5D, Stride: 1}, // ']' + }, + LatinOffset: 9, + } + reReserved = `\x21\x23\x24\x26-\x2c\x2f\x3a\x3b\x3d\x3f\x40\x5b\x5d` + // ALPHA = %x41-5A / %x61-7A + // DIGIT = %x30-39 + // unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" + rangeUnreserved = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x2D, Hi: 0x2E, Stride: 1}, // '-' - '.' + {Lo: 0x30, Hi: 0x39, Stride: 1}, // '0' - '9' + {Lo: 0x41, Hi: 0x5A, Stride: 1}, // 'A' - 'Z' + {Lo: 0x5F, Hi: 0x5F, Stride: 1}, // '_' + {Lo: 0x61, Hi: 0x7A, Stride: 1}, // 'a' - 'z' + {Lo: 0x7E, Hi: 0x7E, Stride: 1}, // '~' + }, + } + reUnreserved = `\x2d\x2e\x30-\x39\x41-\x5a\x5f\x61-\x7a\x7e` +) + +type runeClass uint8 + +const ( + runeClassU runeClass = 1 << iota + runeClassR + runeClassPctE + runeClassLast + + runeClassUR = runeClassU | runeClassR +) + +var runeClassNames = []string{ + "U", + "R", + "pct-encoded", +} + +func (rc runeClass) String() string { + ret := make([]string, 0, len(runeClassNames)) + for i, j := 0, runeClass(1); j < runeClassLast; j <<= 1 { + if rc&j == j { + ret = append(ret, runeClassNames[i]) + } + i++ + } + return strings.Join(ret, "+") +} + +func pctEncode(w *strings.Builder, r rune) { + if s := r >> 24 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r >> 16 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r >> 8 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } +} + +func unhex(c byte) byte { + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + } + return 0 +} + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + default: + return false + } +} + +func pctDecode(s string) string { + size := len(s) + for i := 0; i < len(s); { + switch s[i] { + case '%': + size -= 2 + i += 3 + default: + i++ + } + } + if size == len(s) { + return s + } + + buf := make([]byte, size) + j := 0 + for i := 0; i < len(s); { + switch c := s[i]; c { + case '%': + buf[j] = unhex(s[i+1])<<4 | unhex(s[i+2]) + i += 3 + j++ + default: + buf[j] = c + i++ + j++ + } + } + return string(buf) +} + +type escapeFunc func(*strings.Builder, string) error + +func escapeLiteral(w *strings.Builder, v string) error { + w.WriteString(v) + return nil +} + +func escapeExceptU(w *strings.Builder, v string) error { + for i := 0; i < len(v); { + r, size := utf8.DecodeRuneInString(v[i:]) + if r == utf8.RuneError { + return errorf(i, "invalid encoding") + } + if unicode.Is(rangeUnreserved, r) { + w.WriteRune(r) + } else { + pctEncode(w, r) + } + i += size + } + return nil +} + +func escapeExceptUR(w *strings.Builder, v string) error { + for i := 0; i < len(v); { + r, size := utf8.DecodeRuneInString(v[i:]) + if r == utf8.RuneError { + return errorf(i, "invalid encoding") + } + // TODO(yosida95): is pct-encoded triplets allowed here? + if unicode.In(r, rangeUnreserved, rangeReserved) { + w.WriteRune(r) + } else { + pctEncode(w, r) + } + i += size + } + return nil +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/expression.go b/vendor/github.com/yosida95/uritemplate/v3/expression.go new file mode 100644 index 0000000000..4858c2ddef --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/expression.go @@ -0,0 +1,173 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "regexp" + "strconv" + "strings" +) + +type template interface { + expand(*strings.Builder, Values) error + regexp(*strings.Builder) +} + +type literals string + +func (l literals) expand(b *strings.Builder, _ Values) error { + b.WriteString(string(l)) + return nil +} + +func (l literals) regexp(b *strings.Builder) { + b.WriteString("(?:") + b.WriteString(regexp.QuoteMeta(string(l))) + b.WriteByte(')') +} + +type varspec struct { + name string + maxlen int + explode bool +} + +type expression struct { + vars []varspec + op parseOp + first string + sep string + named bool + ifemp string + escape escapeFunc + allow runeClass +} + +func (e *expression) init() { + switch e.op { + case parseOpSimple: + e.sep = "," + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpPlus: + e.sep = "," + e.escape = escapeExceptUR + e.allow = runeClassUR + case parseOpCrosshatch: + e.first = "#" + e.sep = "," + e.escape = escapeExceptUR + e.allow = runeClassUR + case parseOpDot: + e.first = "." + e.sep = "." + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpSlash: + e.first = "/" + e.sep = "/" + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpSemicolon: + e.first = ";" + e.sep = ";" + e.named = true + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpQuestion: + e.first = "?" + e.sep = "&" + e.named = true + e.ifemp = "=" + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpAmpersand: + e.first = "&" + e.sep = "&" + e.named = true + e.ifemp = "=" + e.escape = escapeExceptU + e.allow = runeClassU + } +} + +func (e *expression) expand(w *strings.Builder, values Values) error { + first := true + for _, varspec := range e.vars { + value := values.Get(varspec.name) + if !value.Valid() { + continue + } + + if first { + w.WriteString(e.first) + first = false + } else { + w.WriteString(e.sep) + } + + if err := value.expand(w, varspec, e); err != nil { + return err + } + + } + return nil +} + +func (e *expression) regexp(b *strings.Builder) { + if e.first != "" { + b.WriteString("(?:") // $1 + b.WriteString(regexp.QuoteMeta(e.first)) + } + b.WriteByte('(') // $2 + runeClassToRegexp(b, e.allow, e.named || e.vars[0].explode) + if len(e.vars) > 1 || e.vars[0].explode { + max := len(e.vars) - 1 + for i := 0; i < len(e.vars); i++ { + if e.vars[i].explode { + max = -1 + break + } + } + + b.WriteString("(?:") // $3 + b.WriteString(regexp.QuoteMeta(e.sep)) + runeClassToRegexp(b, e.allow, e.named || max < 0) + b.WriteByte(')') // $3 + if max > 0 { + b.WriteString("{0,") + b.WriteString(strconv.Itoa(max)) + b.WriteByte('}') + } else { + b.WriteByte('*') + } + } + b.WriteByte(')') // $2 + if e.first != "" { + b.WriteByte(')') // $1 + } + b.WriteByte('?') +} + +func runeClassToRegexp(b *strings.Builder, class runeClass, named bool) { + b.WriteString("(?:(?:[") + if class&runeClassR == 0 { + b.WriteString(`\x2c`) + if named { + b.WriteString(`\x3d`) + } + } + if class&runeClassU == runeClassU { + b.WriteString(reUnreserved) + } + if class&runeClassR == runeClassR { + b.WriteString(reReserved) + } + b.WriteString("]") + b.WriteString("|%[[:xdigit:]][[:xdigit:]]") + b.WriteString(")*)") +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/machine.go b/vendor/github.com/yosida95/uritemplate/v3/machine.go new file mode 100644 index 0000000000..7b1d0b518d --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/machine.go @@ -0,0 +1,23 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +// threadList implements https://research.swtch.com/sparse. +type threadList struct { + dense []threadEntry + sparse []uint32 +} + +type threadEntry struct { + pc uint32 + t *thread +} + +type thread struct { + op *progOp + cap map[string][]int +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/match.go b/vendor/github.com/yosida95/uritemplate/v3/match.go new file mode 100644 index 0000000000..02fe6385a3 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/match.go @@ -0,0 +1,213 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "bytes" + "unicode" + "unicode/utf8" +) + +type matcher struct { + prog *prog + + list1 threadList + list2 threadList + matched bool + cap map[string][]int + + input string +} + +func (m *matcher) at(pos int) (rune, int, bool) { + if l := len(m.input); pos < l { + c := m.input[pos] + if c < utf8.RuneSelf { + return rune(c), 1, pos+1 < l + } + r, size := utf8.DecodeRuneInString(m.input[pos:]) + return r, size, pos+size < l + } + return -1, 0, false +} + +func (m *matcher) add(list *threadList, pc uint32, pos int, next bool, cap map[string][]int) { + if i := list.sparse[pc]; i < uint32(len(list.dense)) && list.dense[i].pc == pc { + return + } + + n := len(list.dense) + list.dense = list.dense[:n+1] + list.sparse[pc] = uint32(n) + + e := &list.dense[n] + e.pc = pc + e.t = nil + + op := &m.prog.op[pc] + switch op.code { + default: + panic("unhandled opcode") + case opRune, opRuneClass, opEnd: + e.t = &thread{ + op: &m.prog.op[pc], + cap: make(map[string][]int, len(m.cap)), + } + for k, v := range cap { + e.t.cap[k] = make([]int, len(v)) + copy(e.t.cap[k], v) + } + case opLineBegin: + if pos == 0 { + m.add(list, pc+1, pos, next, cap) + } + case opLineEnd: + if !next { + m.add(list, pc+1, pos, next, cap) + } + case opCapStart, opCapEnd: + ocap := make(map[string][]int, len(m.cap)) + for k, v := range cap { + ocap[k] = make([]int, len(v)) + copy(ocap[k], v) + } + ocap[op.name] = append(ocap[op.name], pos) + m.add(list, pc+1, pos, next, ocap) + case opSplit: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmp: + m.add(list, op.i, pos, next, cap) + case opJmpIfNotDefined: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmpIfNotFirst: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmpIfNotEmpty: + m.add(list, op.i, pos, next, cap) + m.add(list, pc+1, pos, next, cap) + case opNoop: + m.add(list, pc+1, pos, next, cap) + } +} + +func (m *matcher) step(clist *threadList, nlist *threadList, r rune, pos int, nextPos int, next bool) { + debug.Printf("===== %q =====", string(r)) + for i := 0; i < len(clist.dense); i++ { + e := clist.dense[i] + if debug { + var buf bytes.Buffer + dumpProg(&buf, m.prog, e.pc) + debug.Printf("\n%s", buf.String()) + } + if e.t == nil { + continue + } + + t := e.t + op := t.op + switch op.code { + default: + panic("unhandled opcode") + case opRune: + if op.r == r { + m.add(nlist, e.pc+1, nextPos, next, t.cap) + } + case opRuneClass: + ret := false + if !ret && op.rc&runeClassU == runeClassU { + ret = ret || unicode.Is(rangeUnreserved, r) + } + if !ret && op.rc&runeClassR == runeClassR { + ret = ret || unicode.Is(rangeReserved, r) + } + if !ret && op.rc&runeClassPctE == runeClassPctE { + ret = ret || unicode.Is(unicode.ASCII_Hex_Digit, r) + } + if ret { + m.add(nlist, e.pc+1, nextPos, next, t.cap) + } + case opEnd: + m.matched = true + for k, v := range t.cap { + m.cap[k] = make([]int, len(v)) + copy(m.cap[k], v) + } + clist.dense = clist.dense[:0] + } + } + clist.dense = clist.dense[:0] +} + +func (m *matcher) match() bool { + pos := 0 + clist, nlist := &m.list1, &m.list2 + for { + if len(clist.dense) == 0 && m.matched { + break + } + r, width, next := m.at(pos) + if !m.matched { + m.add(clist, 0, pos, next, m.cap) + } + m.step(clist, nlist, r, pos, pos+width, next) + + if width < 1 { + break + } + pos += width + + clist, nlist = nlist, clist + } + return m.matched +} + +func (tmpl *Template) Match(expansion string) Values { + tmpl.mu.Lock() + if tmpl.prog == nil { + c := compiler{} + c.init() + c.compile(tmpl) + tmpl.prog = c.prog + } + prog := tmpl.prog + tmpl.mu.Unlock() + + n := len(prog.op) + m := matcher{ + prog: prog, + list1: threadList{ + dense: make([]threadEntry, 0, n), + sparse: make([]uint32, n), + }, + list2: threadList{ + dense: make([]threadEntry, 0, n), + sparse: make([]uint32, n), + }, + cap: make(map[string][]int, prog.numCap), + input: expansion, + } + if !m.match() { + return nil + } + + match := make(Values, len(m.cap)) + for name, indices := range m.cap { + v := Value{V: make([]string, len(indices)/2)} + for i := range v.V { + v.V[i] = pctDecode(expansion[indices[2*i]:indices[2*i+1]]) + } + if len(v.V) == 1 { + v.T = ValueTypeString + } else { + v.T = ValueTypeList + } + match[name] = v + } + return match +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/parse.go b/vendor/github.com/yosida95/uritemplate/v3/parse.go new file mode 100644 index 0000000000..fd38a682f1 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/parse.go @@ -0,0 +1,277 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "fmt" + "unicode" + "unicode/utf8" +) + +type parseOp int + +const ( + parseOpSimple parseOp = iota + parseOpPlus + parseOpCrosshatch + parseOpDot + parseOpSlash + parseOpSemicolon + parseOpQuestion + parseOpAmpersand +) + +var ( + rangeVarchar = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x0030, Hi: 0x0039, Stride: 1}, // '0' - '9' + {Lo: 0x0041, Hi: 0x005A, Stride: 1}, // 'A' - 'Z' + {Lo: 0x005F, Hi: 0x005F, Stride: 1}, // '_' + {Lo: 0x0061, Hi: 0x007A, Stride: 1}, // 'a' - 'z' + }, + LatinOffset: 4, + } + rangeLiterals = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x0021, Hi: 0x0021, Stride: 1}, // '!' + {Lo: 0x0023, Hi: 0x0024, Stride: 1}, // '#' - '$' + {Lo: 0x0026, Hi: 0x003B, Stride: 1}, // '&' ''' '(' - ';'. '''/27 used to be excluded but an errata is in the review process https://www.rfc-editor.org/errata/eid6937 + {Lo: 0x003D, Hi: 0x003D, Stride: 1}, // '=' + {Lo: 0x003F, Hi: 0x005B, Stride: 1}, // '?' - '[' + {Lo: 0x005D, Hi: 0x005D, Stride: 1}, // ']' + {Lo: 0x005F, Hi: 0x005F, Stride: 1}, // '_' + {Lo: 0x0061, Hi: 0x007A, Stride: 1}, // 'a' - 'z' + {Lo: 0x007E, Hi: 0x007E, Stride: 1}, // '~' + {Lo: 0x00A0, Hi: 0xD7FF, Stride: 1}, // ucschar + {Lo: 0xE000, Hi: 0xF8FF, Stride: 1}, // iprivate + {Lo: 0xF900, Hi: 0xFDCF, Stride: 1}, // ucschar + {Lo: 0xFDF0, Hi: 0xFFEF, Stride: 1}, // ucschar + }, + R32: []unicode.Range32{ + {Lo: 0x00010000, Hi: 0x0001FFFD, Stride: 1}, // ucschar + {Lo: 0x00020000, Hi: 0x0002FFFD, Stride: 1}, // ucschar + {Lo: 0x00030000, Hi: 0x0003FFFD, Stride: 1}, // ucschar + {Lo: 0x00040000, Hi: 0x0004FFFD, Stride: 1}, // ucschar + {Lo: 0x00050000, Hi: 0x0005FFFD, Stride: 1}, // ucschar + {Lo: 0x00060000, Hi: 0x0006FFFD, Stride: 1}, // ucschar + {Lo: 0x00070000, Hi: 0x0007FFFD, Stride: 1}, // ucschar + {Lo: 0x00080000, Hi: 0x0008FFFD, Stride: 1}, // ucschar + {Lo: 0x00090000, Hi: 0x0009FFFD, Stride: 1}, // ucschar + {Lo: 0x000A0000, Hi: 0x000AFFFD, Stride: 1}, // ucschar + {Lo: 0x000B0000, Hi: 0x000BFFFD, Stride: 1}, // ucschar + {Lo: 0x000C0000, Hi: 0x000CFFFD, Stride: 1}, // ucschar + {Lo: 0x000D0000, Hi: 0x000DFFFD, Stride: 1}, // ucschar + {Lo: 0x000E1000, Hi: 0x000EFFFD, Stride: 1}, // ucschar + {Lo: 0x000F0000, Hi: 0x000FFFFD, Stride: 1}, // iprivate + {Lo: 0x00100000, Hi: 0x0010FFFD, Stride: 1}, // iprivate + }, + LatinOffset: 10, + } +) + +type parser struct { + r string + start int + stop int + state parseState +} + +func (p *parser) errorf(i rune, format string, a ...interface{}) error { + return fmt.Errorf("%s: %s%s", fmt.Sprintf(format, a...), p.r[0:p.stop], string(i)) +} + +func (p *parser) rune() (rune, int) { + r, size := utf8.DecodeRuneInString(p.r[p.stop:]) + if r != utf8.RuneError { + p.stop += size + } + return r, size +} + +func (p *parser) unread(r rune) { + p.stop -= utf8.RuneLen(r) +} + +type parseState int + +const ( + parseStateDefault = parseState(iota) + parseStateOperator + parseStateVarList + parseStateVarName + parseStatePrefix +) + +func (p *parser) setState(state parseState) { + p.state = state + p.start = p.stop +} + +func (p *parser) parseURITemplate() (*Template, error) { + tmpl := Template{ + raw: p.r, + exprs: []template{}, + } + + var exp *expression + for { + r, size := p.rune() + if r == utf8.RuneError { + if size == 0 { + if p.state != parseStateDefault { + return nil, p.errorf('_', "incomplete expression") + } + if p.start < p.stop { + tmpl.exprs = append(tmpl.exprs, literals(p.r[p.start:p.stop])) + } + return &tmpl, nil + } + return nil, p.errorf('_', "invalid UTF-8 sequence") + } + + switch p.state { + case parseStateDefault: + switch r { + case '{': + if stop := p.stop - size; stop > p.start { + tmpl.exprs = append(tmpl.exprs, literals(p.r[p.start:stop])) + } + exp = &expression{} + tmpl.exprs = append(tmpl.exprs, exp) + p.setState(parseStateOperator) + case '%': + p.unread(r) + if err := p.consumeTriplet(); err != nil { + return nil, err + } + default: + if !unicode.Is(rangeLiterals, r) { + p.unread(r) + return nil, p.errorf('_', "unacceptable character (hint: use %%XX encoding)") + } + } + case parseStateOperator: + switch r { + default: + p.unread(r) + exp.op = parseOpSimple + case '+': + exp.op = parseOpPlus + case '#': + exp.op = parseOpCrosshatch + case '.': + exp.op = parseOpDot + case '/': + exp.op = parseOpSlash + case ';': + exp.op = parseOpSemicolon + case '?': + exp.op = parseOpQuestion + case '&': + exp.op = parseOpAmpersand + case '=', ',', '!', '@', '|': // op-reserved + return nil, p.errorf('|', "unimplemented operator (op-reserved)") + } + p.setState(parseStateVarName) + case parseStateVarList: + switch r { + case ',': + p.setState(parseStateVarName) + case '}': + exp.init() + p.setState(parseStateDefault) + default: + p.unread(r) + return nil, p.errorf('_', "unrecognized value modifier") + } + case parseStateVarName: + switch r { + case ':', '*': + name := p.r[p.start : p.stop-size] + if !isValidVarname(name) { + return nil, p.errorf('|', "unacceptable variable name") + } + explode := r == '*' + exp.vars = append(exp.vars, varspec{ + name: name, + explode: explode, + }) + if explode { + p.setState(parseStateVarList) + } else { + p.setState(parseStatePrefix) + } + case ',', '}': + p.unread(r) + name := p.r[p.start:p.stop] + if !isValidVarname(name) { + return nil, p.errorf('|', "unacceptable variable name") + } + exp.vars = append(exp.vars, varspec{ + name: name, + }) + p.setState(parseStateVarList) + case '%': + p.unread(r) + if err := p.consumeTriplet(); err != nil { + return nil, err + } + case '.': + if dot := p.stop - size; dot == p.start || p.r[dot-1] == '.' { + return nil, p.errorf('|', "unacceptable variable name") + } + default: + if !unicode.Is(rangeVarchar, r) { + p.unread(r) + return nil, p.errorf('_', "unacceptable variable name") + } + } + case parseStatePrefix: + spec := &(exp.vars[len(exp.vars)-1]) + switch { + case '0' <= r && r <= '9': + spec.maxlen *= 10 + spec.maxlen += int(r - '0') + if spec.maxlen == 0 || spec.maxlen > 9999 { + return nil, p.errorf('|', "max-length must be (0, 9999]") + } + default: + p.unread(r) + if spec.maxlen == 0 { + return nil, p.errorf('_', "max-length must be (0, 9999]") + } + p.setState(parseStateVarList) + } + default: + p.unread(r) + panic(p.errorf('_', "unhandled parseState(%d)", p.state)) + } + } +} + +func isValidVarname(name string) bool { + if l := len(name); l == 0 || name[0] == '.' || name[l-1] == '.' { + return false + } + for i := 1; i < len(name)-1; i++ { + switch c := name[i]; c { + case '.': + if name[i-1] == '.' { + return false + } + } + } + return true +} + +func (p *parser) consumeTriplet() error { + if len(p.r)-p.stop < 3 || p.r[p.stop] != '%' || !ishex(p.r[p.stop+1]) || !ishex(p.r[p.stop+2]) { + return p.errorf('_', "incomplete pct-encodeed") + } + p.stop += 3 + return nil +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/prog.go b/vendor/github.com/yosida95/uritemplate/v3/prog.go new file mode 100644 index 0000000000..97af4f0eab --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/prog.go @@ -0,0 +1,130 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "bytes" + "strconv" +) + +type progOpcode uint16 + +const ( + // match + opRune progOpcode = iota + opRuneClass + opLineBegin + opLineEnd + // capture + opCapStart + opCapEnd + // stack + opSplit + opJmp + opJmpIfNotDefined + opJmpIfNotEmpty + opJmpIfNotFirst + // result + opEnd + // fake + opNoop + opcodeMax +) + +var opcodeNames = []string{ + // match + "opRune", + "opRuneClass", + "opLineBegin", + "opLineEnd", + // capture + "opCapStart", + "opCapEnd", + // stack + "opSplit", + "opJmp", + "opJmpIfNotDefined", + "opJmpIfNotEmpty", + "opJmpIfNotFirst", + // result + "opEnd", +} + +func (code progOpcode) String() string { + if code >= opcodeMax { + return "" + } + return opcodeNames[code] +} + +type progOp struct { + code progOpcode + r rune + rc runeClass + i uint32 + + name string +} + +func dumpProgOp(b *bytes.Buffer, op *progOp) { + b.WriteString(op.code.String()) + switch op.code { + case opRune: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(string(op.r))) + b.WriteString(")") + case opRuneClass: + b.WriteString("(") + b.WriteString(op.rc.String()) + b.WriteString(")") + case opCapStart, opCapEnd: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(op.name)) + b.WriteString(")") + case opSplit: + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + case opJmp, opJmpIfNotFirst: + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + case opJmpIfNotDefined, opJmpIfNotEmpty: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(op.name)) + b.WriteString(")") + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + } +} + +type prog struct { + op []progOp + numCap int +} + +func dumpProg(b *bytes.Buffer, prog *prog, pc uint32) { + for i := range prog.op { + op := prog.op[i] + + pos := strconv.Itoa(i) + if uint32(i) == pc { + pos = "*" + pos + } + b.WriteString(" "[len(pos):]) + b.WriteString(pos) + + b.WriteByte('\t') + dumpProgOp(b, &op) + + b.WriteByte('\n') + } +} + +func (p *prog) String() string { + b := bytes.Buffer{} + dumpProg(&b, p, 0) + return b.String() +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go b/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go new file mode 100644 index 0000000000..dbd2673753 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go @@ -0,0 +1,116 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "log" + "regexp" + "strings" + "sync" +) + +var ( + debug = debugT(false) +) + +type debugT bool + +func (t debugT) Printf(format string, v ...interface{}) { + if t { + log.Printf(format, v...) + } +} + +// Template represents a URI Template. +type Template struct { + raw string + exprs []template + + // protects the rest of fields + mu sync.Mutex + varnames []string + re *regexp.Regexp + prog *prog +} + +// New parses and constructs a new Template instance based on the template. +// New returns an error if the template cannot be recognized. +func New(template string) (*Template, error) { + return (&parser{r: template}).parseURITemplate() +} + +// MustNew panics if the template cannot be recognized. +func MustNew(template string) *Template { + ret, err := New(template) + if err != nil { + panic(err) + } + return ret +} + +// Raw returns a raw URI template passed to New in string. +func (t *Template) Raw() string { + return t.raw +} + +// Varnames returns variable names used in the template. +func (t *Template) Varnames() []string { + t.mu.Lock() + defer t.mu.Unlock() + if t.varnames != nil { + return t.varnames + } + + reg := map[string]struct{}{} + t.varnames = []string{} + for i := range t.exprs { + expr, ok := t.exprs[i].(*expression) + if !ok { + continue + } + for _, spec := range expr.vars { + if _, ok := reg[spec.name]; ok { + continue + } + reg[spec.name] = struct{}{} + t.varnames = append(t.varnames, spec.name) + } + } + + return t.varnames +} + +// Expand returns a URI reference corresponding to the template expanded using the passed variables. +func (t *Template) Expand(vars Values) (string, error) { + var w strings.Builder + for i := range t.exprs { + expr := t.exprs[i] + if err := expr.expand(&w, vars); err != nil { + return w.String(), err + } + } + return w.String(), nil +} + +// Regexp converts the template to regexp and returns compiled *regexp.Regexp. +func (t *Template) Regexp() *regexp.Regexp { + t.mu.Lock() + defer t.mu.Unlock() + if t.re != nil { + return t.re + } + + var b strings.Builder + b.WriteByte('^') + for _, expr := range t.exprs { + expr.regexp(&b) + } + b.WriteByte('$') + t.re = regexp.MustCompile(b.String()) + + return t.re +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/value.go b/vendor/github.com/yosida95/uritemplate/v3/value.go new file mode 100644 index 0000000000..0550eabdbf --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/value.go @@ -0,0 +1,216 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import "strings" + +// A varname containing pct-encoded characters is not the same variable as +// a varname with those same characters decoded. +// +// -- https://tools.ietf.org/html/rfc6570#section-2.3 +type Values map[string]Value + +func (v Values) Set(name string, value Value) { + v[name] = value +} + +func (v Values) Get(name string) Value { + if v == nil { + return Value{} + } + return v[name] +} + +type ValueType uint8 + +const ( + ValueTypeString = iota + ValueTypeList + ValueTypeKV + valueTypeLast +) + +var valueTypeNames = []string{ + "String", + "List", + "KV", +} + +func (vt ValueType) String() string { + if vt < valueTypeLast { + return valueTypeNames[vt] + } + return "" +} + +type Value struct { + T ValueType + V []string +} + +func (v Value) String() string { + if v.Valid() && v.T == ValueTypeString { + return v.V[0] + } + return "" +} + +func (v Value) List() []string { + if v.Valid() && v.T == ValueTypeList { + return v.V + } + return nil +} + +func (v Value) KV() []string { + if v.Valid() && v.T == ValueTypeKV { + return v.V + } + return nil +} + +func (v Value) Valid() bool { + switch v.T { + default: + return false + case ValueTypeString: + return len(v.V) > 0 + case ValueTypeList: + return len(v.V) > 0 + case ValueTypeKV: + return len(v.V) > 0 && len(v.V)%2 == 0 + } +} + +func (v Value) expand(w *strings.Builder, spec varspec, exp *expression) error { + switch v.T { + case ValueTypeString: + val := v.V[0] + var maxlen int + if max := len(val); spec.maxlen < 1 || spec.maxlen > max { + maxlen = max + } else { + maxlen = spec.maxlen + } + + if exp.named { + w.WriteString(spec.name) + if val == "" { + w.WriteString(exp.ifemp) + return nil + } + w.WriteByte('=') + } + return exp.escape(w, val[:maxlen]) + case ValueTypeList: + var sep string + if spec.explode { + sep = exp.sep + } else { + sep = "," + } + + var pre string + var preifemp string + if spec.explode && exp.named { + pre = spec.name + "=" + preifemp = spec.name + exp.ifemp + } + + if !spec.explode && exp.named { + w.WriteString(spec.name) + w.WriteByte('=') + } + for i := range v.V { + val := v.V[i] + if i > 0 { + w.WriteString(sep) + } + if val == "" { + w.WriteString(preifemp) + continue + } + w.WriteString(pre) + + if err := exp.escape(w, val); err != nil { + return err + } + } + case ValueTypeKV: + var sep string + var kvsep string + if spec.explode { + sep = exp.sep + kvsep = "=" + } else { + sep = "," + kvsep = "," + } + + var ifemp string + var kescape escapeFunc + if spec.explode && exp.named { + ifemp = exp.ifemp + kescape = escapeLiteral + } else { + ifemp = "," + kescape = exp.escape + } + + if !spec.explode && exp.named { + w.WriteString(spec.name) + w.WriteByte('=') + } + + for i := 0; i < len(v.V); i += 2 { + if i > 0 { + w.WriteString(sep) + } + if err := kescape(w, v.V[i]); err != nil { + return err + } + if v.V[i+1] == "" { + w.WriteString(ifemp) + continue + } + w.WriteString(kvsep) + + if err := exp.escape(w, v.V[i+1]); err != nil { + return err + } + } + } + return nil +} + +// String returns Value that represents string. +func String(v string) Value { + return Value{ + T: ValueTypeString, + V: []string{v}, + } +} + +// List returns Value that represents list. +func List(v ...string) Value { + return Value{ + T: ValueTypeList, + V: v, + } +} + +// KV returns Value that represents associative list. +// KV panics if len(kv) is not even. +func KV(kv ...string) Value { + if len(kv)%2 != 0 { + panic("uritemplate.go: count of the kv must be even number") + } + return Value{ + T: ValueTypeKV, + V: kv, + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 0d38daabe6..56fdf01d91 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -465,6 +465,11 @@ github.com/lufia/plan9stats github.com/mailru/easyjson/buffer github.com/mailru/easyjson/jlexer github.com/mailru/easyjson/jwriter +# github.com/mark3labs/mcp-go v0.18.0 +## explicit; go 1.23 +github.com/mark3labs/mcp-go/client +github.com/mark3labs/mcp-go/mcp +github.com/mark3labs/mcp-go/server # github.com/maruel/natural v1.1.0 ## explicit; go 1.11 github.com/maruel/natural @@ -578,9 +583,6 @@ github.com/rollbar/rollbar-go # github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 ## explicit; go 1.13 github.com/sergi/go-diff/diffmatchpatch -# github.com/shibukawa/configdir v0.0.0-20170330084843-e180dbdc8da0 -## explicit -github.com/shibukawa/configdir # github.com/shirou/gopsutil/v3 v3.24.5 ## explicit; go 1.18 github.com/shirou/gopsutil/v3/common @@ -662,6 +664,9 @@ github.com/xanzy/ssh-agent # github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 ## explicit github.com/xi2/xz +# github.com/yosida95/uritemplate/v3 v3.0.2 +## explicit; go 1.14 +github.com/yosida95/uritemplate/v3 # github.com/yusufpapurcu/wmi v1.2.4 ## explicit; go 1.16 github.com/yusufpapurcu/wmi