diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 13ef0b6a9..881bae87f 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "os" + "strconv" + "strings" "github.com/github/github-mcp-server/internal/ghmcp" "github.com/github/github-mcp-server/pkg/github" @@ -45,6 +47,9 @@ var ( return fmt.Errorf("failed to unmarshal toolsets: %w", err) } + // Parse multi-org installations + installations := parseOrgInstallations() + stdioServerConfig := ghmcp.StdioServerConfig{ Version: version, Host: viper.GetString("host"), @@ -55,6 +60,7 @@ var ( ExportTranslations: viper.GetBool("export-translations"), EnableCommandLogging: viper.GetBool("enable-command-logging"), LogFilePath: viper.GetString("log-file"), + Installations: installations, } return ghmcp.RunStdioServer(stdioServerConfig) @@ -62,13 +68,41 @@ var ( } ) +// parseOrgInstallations parses GITHUB_INSTALLATION_ID_ environment variables +// and returns a map of organization name to installation ID. +// Also includes the default GITHUB_INSTALLATION_ID under "_default" key if set. +func parseOrgInstallations() map[string]int64 { + installations := make(map[string]int64) + prefix := "GITHUB_INSTALLATION_ID_" + + for _, env := range os.Environ() { + if strings.HasPrefix(env, prefix) { + parts := strings.SplitN(env, "=", 2) + if len(parts) == 2 { + org := strings.ToLower(strings.TrimPrefix(parts[0], prefix)) + org = strings.ReplaceAll(org, "_", "-") // Normalize underscores to dashes + if id, err := strconv.ParseInt(parts[1], 10, 64); err == nil { + installations[org] = id + } + } + } + } + + // Add default if set (for backwards compatibility) + if defaultID := viper.GetInt64("installation_id"); defaultID != 0 { + installations["_default"] = defaultID + } + + return installations +} + func init() { cobra.OnInitialize(initConfig) rootCmd.SetVersionTemplate("{{.Short}}\n{{.Version}}\n") // Add global flags that will be shared by all commands - rootCmd.PersistentFlags().StringSlice("toolsets", github.DefaultTools, "An optional comma separated list of groups of tools to allow, defaults to enabling all") + rootCmd.PersistentFlags().StringSlice("toolsets", github.DefaultTools, "An optional comma separated list of groups of tools to allow with optional modes (e.g., 'repos:rw,issues:ro,users'), defaults to enabling all") rootCmd.PersistentFlags().Bool("dynamic-toolsets", false, "Enable dynamic toolsets") rootCmd.PersistentFlags().Bool("read-only", false, "Restrict the server to read-only operations") rootCmd.PersistentFlags().String("log-file", "", "Path to log file") @@ -84,6 +118,7 @@ func init() { // Bind flag to viper _ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets")) + _ = viper.BindEnv("toolsets", "GITHUB_TOOLSETS") _ = viper.BindPFlag("dynamic_toolsets", rootCmd.PersistentFlags().Lookup("dynamic-toolsets")) _ = viper.BindPFlag("read-only", rootCmd.PersistentFlags().Lookup("read-only")) _ = viper.BindPFlag("log-file", rootCmd.PersistentFlags().Lookup("log-file")) @@ -120,14 +155,18 @@ func validateAuthConfig() error { hasInstallationID := installationID != 0 hasPrivateKey := privateKeyPath != "" || privateKey != "" - if (hasAppID || hasInstallationID || hasPrivateKey) && !(hasAppID && hasInstallationID && hasPrivateKey) { - return errors.New("incomplete GitHub App configuration: GITHUB_APP_ID, GITHUB_INSTALLATION_ID, and either GITHUB_PRIVATE_KEY_FILE_PATH or GITHUB_PRIVATE_KEY must all be set") + // Also check for multi-org installation IDs (GITHUB_INSTALLATION_ID_) + hasMultiOrgInstallations := len(parseOrgInstallations()) > 0 + hasAnyInstallation := hasInstallationID || hasMultiOrgInstallations + + if (hasAppID || hasAnyInstallation || hasPrivateKey) && !(hasAppID && hasAnyInstallation && hasPrivateKey) { + return errors.New("incomplete GitHub App configuration: GITHUB_APP_ID, GITHUB_INSTALLATION_ID (or GITHUB_INSTALLATION_ID_), and either GITHUB_PRIVATE_KEY_FILE_PATH or GITHUB_PRIVATE_KEY must all be set") } // Check PAT if GitHub App auth is not configured token := viper.GetString("personal_access_token") if !hasAppID && token == "" { - return errors.New("no authentication method configured: either set GITHUB_PERSONAL_ACCESS_TOKEN or configure GitHub App authentication with GITHUB_APP_ID, GITHUB_INSTALLATION_ID, and GITHUB_PRIVATE_KEY_FILE_PATH or GITHUB_PRIVATE_KEY") + return errors.New("no authentication method configured: either set GITHUB_PERSONAL_ACCESS_TOKEN or configure GitHub App authentication with GITHUB_APP_ID, GITHUB_INSTALLATION_ID (or GITHUB_INSTALLATION_ID_), and GITHUB_PRIVATE_KEY_FILE_PATH or GITHUB_PRIVATE_KEY") } return nil diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 2eb96268e..a1bb1ff0f 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -11,7 +11,6 @@ import ( "os/signal" "strings" "syscall" - "time" "github.com/bradleyfalzon/ghinstallation/v2" "github.com/github/github-mcp-server/pkg/github" @@ -32,6 +31,14 @@ func createClient(cfg MCPServerConfig) (*gogithub.Client, error) { appID := viper.GetInt64("app_id") installationID := viper.GetInt64("installation_id") + // If no default installation ID, use first multi-org installation as fallback + if installationID == 0 && len(cfg.Installations) > 0 { + for _, id := range cfg.Installations { + installationID = id + break + } + } + // Check for private key - can be provided as file path or direct content privateKeyPath := viper.GetString("private_key_file_path") privateKeyContent := viper.GetString("private_key") @@ -101,6 +108,14 @@ func createGQLClient(cfg MCPServerConfig) (*githubv4.Client, *http.Client, error appID := viper.GetInt64("app_id") installationID := viper.GetInt64("installation_id") + // If no default installation ID, use first multi-org installation as fallback + if installationID == 0 && len(cfg.Installations) > 0 { + for _, id := range cfg.Installations { + installationID = id + break + } + } + // Check for private key - can be provided as file path or direct content privateKeyPath := viper.GetString("private_key_file_path") privateKeyContent := viper.GetString("private_key") @@ -201,6 +216,9 @@ type MCPServerConfig struct { // ReadOnly indicates if we should only offer read-only tools ReadOnly bool + // Installations maps organization names to GitHub App installation IDs + Installations map[string]int64 + // Translator provides translated text for the server tooling Translator translations.TranslationHelperFunc } @@ -212,8 +230,8 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { return nil, fmt.Errorf("failed to create GitHub REST client: %w", err) } - // Create GraphQL client - gqlClient, gqlHTTPClient, err := createGQLClient(cfg) + // Create GraphQL client (for user agent hook only; actual GQL clients from factory) + _, gqlHTTPClient, err := createGQLClient(cfg) if err != nil { return nil, fmt.Errorf("failed to create GitHub GraphQL client: %w", err) } @@ -252,13 +270,19 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { } } - // Create repository-aware client factory with 1-hour cache TTL - clientFactory := github.NewRepoAwareClientFactory(restClient, 1*time.Hour) + // Create multi-org client factory + appID := viper.GetInt64("app_id") + privateKey := []byte(viper.GetString("private_key")) + + clientFactory := github.NewMultiOrgClientFactory( + appID, + privateKey, + cfg.Installations, + cfg.Host, + cfg.Version, + ) getClient := clientFactory.GetClientFn() - - getGQLClient := func(_ context.Context) (*githubv4.Client, error) { - return gqlClient, nil // closing over client - } + getGQLClient := clientFactory.GetGQLClientFn() // Create default toolsets toolsets, err := github.InitToolsets( @@ -272,12 +296,10 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { return nil, fmt.Errorf("failed to initialize toolsets: %w", err) } - context := github.InitContextToolset(getClient, cfg.Translator) github.RegisterResources(ghServer, getClient, cfg.Translator) // Register the tools with the server toolsets.RegisterTools(ghServer) - context.RegisterTools(ghServer) if cfg.DynamicToolsets { dynamic := github.InitDynamicToolset(ghServer, toolsets, cfg.Translator) @@ -317,6 +339,9 @@ type StdioServerConfig struct { // Path to the log file if not stderr LogFilePath string + + // Installations maps organization names to GitHub App installation IDs + Installations map[string]int64 } // RunStdioServer is not concurrent safe. @@ -334,6 +359,7 @@ func RunStdioServer(cfg StdioServerConfig) error { EnabledToolsets: cfg.EnabledToolsets, DynamicToolsets: cfg.DynamicToolsets, ReadOnly: cfg.ReadOnly, + Installations: cfg.Installations, Translator: t, }) if err != nil { diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 34a1b9eda..f0e6233ce 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -47,7 +47,7 @@ func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelpe return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -132,7 +132,7 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index 180f32dd4..5eaccc81a 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -25,7 +25,7 @@ func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mc ), ), func(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { - client, err := getClient(ctx) + client, err := getClient(ctx, "") if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 7c8451d39..25c108b6b 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -49,7 +49,7 @@ func GetIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (tool return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -123,7 +123,7 @@ func AddIssueComment(getClient GetClientFn, t translations.TranslationHelperFunc Body: github.Ptr(body), } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -211,7 +211,7 @@ func SearchIssues(getClient GetClientFn, t translations.TranslationHelperFunc) ( }, } - client, err := getClient(ctx) + client, err := getClient(ctx, "") if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -333,7 +333,7 @@ func CreateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (t Milestone: milestoneNum, } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -455,7 +455,7 @@ func ListIssues(getClient GetClientFn, t translations.TranslationHelperFunc) (to opts.PerPage = int(perPage) } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -601,7 +601,7 @@ func UpdateIssue(getClient GetClientFn, t translations.TranslationHelperFunc) (t issueRequest.Milestone = &milestoneNum } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -684,7 +684,7 @@ func GetIssueComments(getClient GetClientFn, t translations.TranslationHelperFun }, } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index d6dd3f96e..16a82faf7 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -50,7 +50,7 @@ func GetPullRequest(getClient GetClientFn, t translations.TranslationHelperFunc) return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -165,7 +165,7 @@ func CreatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu newPR.Draft = github.Ptr(draft) newPR.MaintainerCanModify = github.Ptr(maintainerCanModify) - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -286,7 +286,7 @@ func UpdatePullRequest(getClient GetClientFn, t translations.TranslationHelperFu return mcp.NewToolResultError("No update parameters provided."), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -395,7 +395,7 @@ func ListPullRequests(getClient GetClientFn, t translations.TranslationHelperFun }, } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -484,7 +484,7 @@ func MergePullRequest(getClient GetClientFn, t translations.TranslationHelperFun MergeMethod: mergeMethod, } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -546,7 +546,7 @@ func GetPullRequestFiles(getClient GetClientFn, t translations.TranslationHelper return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -609,7 +609,7 @@ func GetPullRequestStatus(getClient GetClientFn, t translations.TranslationHelpe return mcp.NewToolResultError(err.Error()), nil } // First get the PR to find the head SHA - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -697,7 +697,7 @@ func UpdatePullRequestBranch(getClient GetClientFn, t translations.TranslationHe opts.ExpectedHeadSHA = github.Ptr(expectedHeadSHA) } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -770,7 +770,7 @@ func GetPullRequestComments(getClient GetClientFn, t translations.TranslationHel }, } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -832,7 +832,7 @@ func GetPullRequestReviews(getClient GetClientFn, t translations.TranslationHelp return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -908,7 +908,7 @@ func CreateAndSubmitPullRequestReview(getGQLClient GetGQLClientFn, t translation } // Given our owner, repo and PR number, lookup the GQL ID of the PR. - client, err := getGQLClient(ctx) + client, err := getGQLClient(ctx, params.Owner) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil } @@ -999,7 +999,7 @@ func CreatePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. } // Given our owner, repo and PR number, lookup the GQL ID of the PR. - client, err := getGQLClient(ctx) + client, err := getGQLClient(ctx, params.Owner) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil } @@ -1121,7 +1121,7 @@ func AddPullRequestReviewCommentToPendingReview(getGQLClient GetGQLClientFn, t t return mcp.NewToolResultError(err.Error()), nil } - client, err := getGQLClient(ctx) + client, err := getGQLClient(ctx, params.Owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub GQL client: %w", err) } @@ -1252,7 +1252,7 @@ func SubmitPendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. return mcp.NewToolResultError(err.Error()), nil } - client, err := getGQLClient(ctx) + client, err := getGQLClient(ctx, params.Owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub GQL client: %w", err) } @@ -1367,7 +1367,7 @@ func DeletePendingPullRequestReview(getGQLClient GetGQLClientFn, t translations. return mcp.NewToolResultError(err.Error()), nil } - client, err := getGQLClient(ctx) + client, err := getGQLClient(ctx, params.Owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub GQL client: %w", err) } @@ -1476,7 +1476,7 @@ func GetPullRequestDiff(getClient GetClientFn, t translations.TranslationHelperF return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, params.Owner) if err != nil { return mcp.NewToolResultError(fmt.Sprintf("failed to get GitHub client: %v", err)), nil } @@ -1546,7 +1546,7 @@ func RequestCopilotReview(getClient GetClientFn, t translations.TranslationHelpe return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return mcp.NewToolResultError(err.Error()), nil } diff --git a/pkg/github/repo_access.go b/pkg/github/repo_access.go index 49243393a..69d5c9ebb 100644 --- a/pkg/github/repo_access.go +++ b/pkg/github/repo_access.go @@ -4,10 +4,13 @@ import ( "context" "fmt" "net/http" + "strings" "sync" "time" + "github.com/bradleyfalzon/ghinstallation/v2" "github.com/google/go-github/v69/github" + "github.com/shurcooL/githubv4" ) // RepoAccessType indicates how a repository should be accessed @@ -145,17 +148,153 @@ func WithRepoContext(ctx context.Context, owner, repo string) context.Context { return ctx } -// GetClientFn returns a function that gets the appropriate client based on context +// GetClientFn returns a function that gets the appropriate client based on context or owner func (f *RepoAwareClientFactory) GetClientFn() GetClientFn { - return func(ctx context.Context) (*github.Client, error) { - owner, ok1 := ctx.Value("github.owner").(string) - repo, ok2 := ctx.Value("github.repo").(string) + return func(ctx context.Context, owner string) (*github.Client, error) { + // If owner not provided, try to get from context + if owner == "" { + ownerFromCtx, ok1 := ctx.Value("github.owner").(string) + repo, ok2 := ctx.Value("github.repo").(string) - if !ok1 || !ok2 { - // If we can't determine the repo, use authenticated client - return f.authClient, nil + if !ok1 || !ok2 { + // If we can't determine the repo, use authenticated client + return f.authClient, nil + } + + return f.GetClientForRepo(ctx, ownerFromCtx, repo) + } + + // Use provided owner + return f.authClient, nil + } +} + +// MultiOrgClientFactory creates GitHub clients for different organizations +type MultiOrgClientFactory struct { + appID int64 + privateKey []byte + installations map[string]int64 // org -> installation_id + defaultInstall int64 + transports map[string]*ghinstallation.Transport // cached per-org + transportsMu sync.RWMutex + host string + version string +} + +// NewMultiOrgClientFactory creates a new multi-org client factory +func NewMultiOrgClientFactory( + appID int64, + privateKey []byte, + installations map[string]int64, + host, version string, +) *MultiOrgClientFactory { + defaultInstall := installations["_default"] + delete(installations, "_default") + + return &MultiOrgClientFactory{ + appID: appID, + privateKey: privateKey, + installations: installations, + defaultInstall: defaultInstall, + transports: make(map[string]*ghinstallation.Transport), + host: host, + version: version, + } +} + +// getInstallationID returns the installation ID for a given organization +func (f *MultiOrgClientFactory) getInstallationID(owner string) int64 { + owner = strings.ToLower(owner) + + // Try exact match first + if id, ok := f.installations[owner]; ok { + return id + } + + // Fall back to default + return f.defaultInstall +} + +// GetClientFn returns a function that gets the appropriate client based on owner +func (f *MultiOrgClientFactory) GetClientFn() GetClientFn { + return func(ctx context.Context, owner string) (*github.Client, error) { + if owner == "" { + // User-scoped operations (GetMe) use default + owner = "_default" + } + + installID := f.getInstallationID(owner) + if installID == 0 { + // No installation for this org - return anonymous client + // This allows public repo operations to succeed + client := github.NewClient(nil) + client.UserAgent = fmt.Sprintf("github-mcp-server/%s", f.version) + return client, nil + } + + // Get or create transport for this installation + transport, err := f.getOrCreateTransport(owner, installID) + if err != nil { + return nil, err } - return f.GetClientForRepo(ctx, owner, repo) + client := github.NewClient(&http.Client{Transport: transport}) + client.UserAgent = fmt.Sprintf("github-mcp-server/%s", f.version) + return client, nil } } + +// GetGQLClientFn returns a function that gets the appropriate GraphQL client based on owner +func (f *MultiOrgClientFactory) GetGQLClientFn() GetGQLClientFn { + return func(ctx context.Context, owner string) (*githubv4.Client, error) { + if owner == "" { + // User-scoped operations use default + owner = "_default" + } + + installID := f.getInstallationID(owner) + if installID == 0 { + // No installation for this org - return anonymous client + return githubv4.NewClient(nil), nil + } + + // Get or create transport for this installation + transport, err := f.getOrCreateTransport(owner, installID) + if err != nil { + return nil, err + } + + client := githubv4.NewClient(&http.Client{Transport: transport}) + return client, nil + } +} + +// getOrCreateTransport gets or creates a cached transport for an organization +func (f *MultiOrgClientFactory) getOrCreateTransport(org string, installID int64) (*ghinstallation.Transport, error) { + f.transportsMu.RLock() + if t, ok := f.transports[org]; ok { + f.transportsMu.RUnlock() + return t, nil + } + f.transportsMu.RUnlock() + + f.transportsMu.Lock() + defer f.transportsMu.Unlock() + + // Double-check after acquiring write lock + if t, ok := f.transports[org]; ok { + return t, nil + } + + transport, err := ghinstallation.New(http.DefaultTransport, f.appID, installID, f.privateKey) + if err != nil { + return nil, fmt.Errorf("failed to create transport for %s: %w", org, err) + } + + if f.host != "" { + transport.BaseURL = f.host + } + + f.transports[org] = transport + return transport, nil +} diff --git a/pkg/github/repo_access_test.go b/pkg/github/repo_access_test.go index fbd434473..cc7198fc5 100644 --- a/pkg/github/repo_access_test.go +++ b/pkg/github/repo_access_test.go @@ -152,18 +152,18 @@ func TestRepoAwareClientFactory_GetClientForRepo(t *testing.T) { // Test context-based client selection ctx := WithRepoContext(context.Background(), "owner", "private-repo") - client, err = factory.GetClientFn()(ctx) + client, err = factory.GetClientFn()(ctx, "") assert.NoError(t, err) assert.Equal(t, authClient, client) ctx = WithRepoContext(context.Background(), "owner", "public-repo") - client, err = factory.GetClientFn()(ctx) + client, err = factory.GetClientFn()(ctx, "") assert.NoError(t, err) assert.NotEqual(t, authClient, client) // Test missing context info ctx = context.Background() - client, err = factory.GetClientFn()(ctx) + client, err = factory.GetClientFn()(ctx, "") assert.NoError(t, err) assert.Equal(t, authClient, client) // Should default to auth client } diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 765861cb0..35d99cb95 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -57,7 +57,7 @@ func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (too PerPage: pagination.perPage, } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -131,7 +131,7 @@ func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (t }, } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -200,7 +200,7 @@ func ListBranches(getClient GetClientFn, t translations.TranslationHelperFunc) ( }, } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -310,7 +310,7 @@ func CreateOrUpdateFile(getClient GetClientFn, t translations.TranslationHelperF } // Create or update the file - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -384,7 +384,7 @@ func CreateRepository(getClient GetClientFn, t translations.TranslationHelperFun AutoInit: github.Ptr(autoInit), } - client, err := getClient(ctx) + client, err := getClient(ctx, "") if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -456,7 +456,7 @@ func GetFileContents(getClient GetClientFn, t translations.TranslationHelperFunc // Add repository info to context for client selection ctx = WithRepoContext(ctx, owner, repo) - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -530,7 +530,7 @@ func ForkRepository(getClient GetClientFn, t translations.TranslationHelperFunc) opts.Organization = org } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -619,7 +619,7 @@ func DeleteFile(getClient GetClientFn, t translations.TranslationHelperFunc) (to return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -764,7 +764,7 @@ func CreateBranch(getClient GetClientFn, t translations.TranslationHelperFunc) ( return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -880,7 +880,7 @@ func PushFiles(getClient GetClientFn, t translations.TranslationHelperFunc) (too return mcp.NewToolResultError("files parameter must be an array of objects with path and content"), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -1000,7 +1000,7 @@ func ListTags(getClient GetClientFn, t translations.TranslationHelperFunc) (tool PerPage: pagination.perPage, } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -1063,7 +1063,7 @@ func GetTag(getClient GetClientFn, t translations.TranslationHelperFunc) (tool m return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } diff --git a/pkg/github/repository_resource.go b/pkg/github/repository_resource.go index 96843da5a..df7908f05 100644 --- a/pkg/github/repository_resource.go +++ b/pkg/github/repository_resource.go @@ -110,7 +110,7 @@ func RepositoryResourceContentsHandler(getClient GetClientFn) func(ctx context.C opts.Ref = "refs/pull/" + prNumber[0] + "/head" } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } diff --git a/pkg/github/search.go b/pkg/github/search.go index 4865b009a..a980ac956 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -40,7 +40,9 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF // This is a simple heuristic to check if the query contains a specific repository // Format could be "repo:owner/repo" or similar repoQuery := extractRepoFromQuery(query) + owner := "" if repoQuery.owner != "" && repoQuery.repo != "" { + owner = repoQuery.owner ctx = WithRepoContext(ctx, repoQuery.owner, repoQuery.repo) } @@ -51,7 +53,7 @@ func SearchRepositories(getClient GetClientFn, t translations.TranslationHelperF }, } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -119,7 +121,9 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (to // Extract repository info if present in the query repoQuery := extractRepoFromQuery(query) + owner := "" if repoQuery.owner != "" && repoQuery.repo != "" { + owner = repoQuery.owner ctx = WithRepoContext(ctx, repoQuery.owner, repoQuery.repo) } @@ -132,7 +136,7 @@ func SearchCode(getClient GetClientFn, t translations.TranslationHelperFunc) (to }, } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -209,7 +213,7 @@ func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (t }, } - client, err := getClient(ctx) + client, err := getClient(ctx, "") if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } diff --git a/pkg/github/secret_scanning.go b/pkg/github/secret_scanning.go index 847fcfc6d..b3dd42a39 100644 --- a/pkg/github/secret_scanning.go +++ b/pkg/github/secret_scanning.go @@ -48,7 +48,7 @@ func GetSecretScanningAlert(getClient GetClientFn, t translations.TranslationHel return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } @@ -126,7 +126,7 @@ func ListSecretScanningAlerts(getClient GetClientFn, t translations.TranslationH return mcp.NewToolResultError(err.Error()), nil } - client, err := getClient(ctx) + client, err := getClient(ctx, owner) if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } diff --git a/pkg/github/server_test.go b/pkg/github/server_test.go index 955377990..974f01852 100644 --- a/pkg/github/server_test.go +++ b/pkg/github/server_test.go @@ -11,13 +11,13 @@ import ( ) func stubGetClientFn(client *github.Client) GetClientFn { - return func(_ context.Context) (*github.Client, error) { + return func(_ context.Context, _ string) (*github.Client, error) { return client, nil } } func stubGetGQLClientFn(client *githubv4.Client) GetGQLClientFn { - return func(_ context.Context) (*githubv4.Client, error) { + return func(_ context.Context, _ string) (*githubv4.Client, error) { return client, nil } } diff --git a/pkg/github/tools.go b/pkg/github/tools.go index a04e7336b..5cd5c7916 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -2,6 +2,7 @@ package github import ( "context" + "fmt" "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" @@ -10,12 +11,22 @@ import ( "github.com/shurcooL/githubv4" ) -type GetClientFn func(context.Context) (*github.Client, error) -type GetGQLClientFn func(context.Context) (*githubv4.Client, error) +type GetClientFn func(ctx context.Context, owner string) (*github.Client, error) +type GetGQLClientFn func(ctx context.Context, owner string) (*githubv4.Client, error) var DefaultTools = []string{"all"} func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (*toolsets.ToolsetGroup, error) { + // Parse toolset configurations from the passed toolsets + configs, err := toolsets.ParseToolsetConfigFromSlice(passedToolsets) + if err != nil { + return nil, fmt.Errorf("failed to parse toolset configuration: %w", err) + } + + return InitToolsetsWithConfig(configs, readOnly, getClient, getGQLClient, t) +} + +func InitToolsetsWithConfig(configs []toolsets.ToolsetConfig, readOnly bool, getClient GetClientFn, getGQLClient GetGQLClientFn, t translations.TranslationHelperFunc) (*toolsets.ToolsetGroup, error) { // Create a new toolset group tsg := toolsets.NewToolsetGroup(readOnly) @@ -93,6 +104,12 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, // Keep experiments alive so the system doesn't error out when it's always enabled experiments := toolsets.NewToolset("experiments", "Experimental features that are not considered stable yet") + // Create context toolset (always available) + context := toolsets.NewToolset("context", "Tools that provide context about the current user and GitHub context you are operating in"). + AddReadTools( + toolsets.NewServerTool(GetMe(getClient, t)), + ) + // Add toolsets to the group tsg.AddToolset(repos) tsg.AddToolset(issues) @@ -101,9 +118,10 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn, tsg.AddToolset(codeSecurity) tsg.AddToolset(secretProtection) tsg.AddToolset(experiments) - // Enable the requested features + tsg.AddToolset(context) - if err := tsg.EnableToolsets(passedToolsets); err != nil { + // Enable the requested toolsets with their configurations + if err := tsg.EnableToolsetsWithConfig(configs); err != nil { return nil, err } diff --git a/pkg/github/tools_test.go b/pkg/github/tools_test.go new file mode 100644 index 000000000..b8e33001a --- /dev/null +++ b/pkg/github/tools_test.go @@ -0,0 +1,381 @@ +package github + +import ( + "context" + "testing" + + "github.com/github/github-mcp-server/pkg/toolsets" + "github.com/google/go-github/v69/github" + "github.com/shurcooL/githubv4" +) + +func TestInitToolsetsWithConfig(t *testing.T) { + // Mock translation function + mockTranslator := func(key, fallback string) string { + return fallback + } + + tests := []struct { + name string + configs []toolsets.ToolsetConfig + readOnly bool + wantErr bool + expected map[string]toolsets.ToolsetMode + }{ + { + name: "enable repos with rw and issues with ro", + configs: []toolsets.ToolsetConfig{ + {Name: "repos", Mode: toolsets.ReadWrite}, + {Name: "issues", Mode: toolsets.ReadOnly}, + }, + readOnly: false, + wantErr: false, + expected: map[string]toolsets.ToolsetMode{ + "repos": toolsets.ReadWrite, + "issues": toolsets.ReadOnly, + }, + }, + { + name: "enable all with readonly mode", + configs: []toolsets.ToolsetConfig{ + {Name: "all", Mode: toolsets.ReadOnly}, + }, + readOnly: false, + wantErr: false, + expected: map[string]toolsets.ToolsetMode{ + "repos": toolsets.ReadOnly, + "issues": toolsets.ReadOnly, + "users": toolsets.ReadOnly, + "pull_requests": toolsets.ReadOnly, + "code_security": toolsets.ReadOnly, + "secret_protection": toolsets.ReadOnly, + "experiments": toolsets.ReadOnly, + "context": toolsets.ReadOnly, + }, + }, + { + name: "enable nonexistent toolset", + configs: []toolsets.ToolsetConfig{ + {Name: "nonexistent", Mode: toolsets.ReadWrite}, + }, + readOnly: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock client functions + getClient := func(ctx context.Context, _ string) (*github.Client, error) { + return nil, nil + } + getGQLClient := func(ctx context.Context, _ string) (*githubv4.Client, error) { + return nil, nil + } + + tsg, err := InitToolsetsWithConfig(tt.configs, tt.readOnly, getClient, getGQLClient, mockTranslator) + + if tt.wantErr { + if err == nil { + t.Errorf("InitToolsetsWithConfig() expected error but got none") + } + return + } + + if err != nil { + t.Errorf("InitToolsetsWithConfig() unexpected error: %v", err) + return + } + + // Check that expected toolsets are enabled with correct modes + for name, expectedMode := range tt.expected { + toolset, exists := tsg.Toolsets[name] + if !exists { + t.Errorf("Expected toolset %s to exist", name) + continue + } + if !toolset.Enabled { + t.Errorf("Expected toolset %s to be enabled", name) + } + if toolset.Mode != expectedMode { + t.Errorf("Expected toolset %s to have mode %s, got %s", name, expectedMode, toolset.Mode) + } + } + + // Check that non-expected toolsets are not enabled + for name, toolset := range tsg.Toolsets { + if _, expected := tt.expected[name]; !expected && toolset.Enabled { + t.Errorf("Expected toolset %s to not be enabled", name) + } + } + }) + } +} + +func TestInitToolsets_BackwardCompatibility(t *testing.T) { + // Mock translation function + mockTranslator := func(key, fallback string) string { + return fallback + } + + // Mock client functions + getClient := func(ctx context.Context, _ string) (*github.Client, error) { + return nil, nil + } + getGQLClient := func(ctx context.Context, _ string) (*githubv4.Client, error) { + return nil, nil + } + + tests := []struct { + name string + passedToolsets []string + readOnly bool + wantErr bool + expectedEnabled []string + }{ + { + name: "legacy format - single toolset", + passedToolsets: []string{"repos"}, + readOnly: false, + wantErr: false, + expectedEnabled: []string{"repos"}, + }, + { + name: "legacy format - multiple toolsets", + passedToolsets: []string{"repos", "issues", "users"}, + readOnly: false, + wantErr: false, + expectedEnabled: []string{"repos", "issues", "users"}, + }, + { + name: "legacy format - all", + passedToolsets: []string{"all"}, + readOnly: false, + wantErr: false, + expectedEnabled: []string{"repos", "issues", "users", "pull_requests", "code_security", "secret_protection", "experiments", "context"}, + }, + { + name: "new format - mixed modes", + passedToolsets: []string{"repos:rw,issues:ro,users"}, + readOnly: false, + wantErr: false, + expectedEnabled: []string{"repos", "issues", "users"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tsg, err := InitToolsets(tt.passedToolsets, tt.readOnly, getClient, getGQLClient, mockTranslator) + + if tt.wantErr { + if err == nil { + t.Errorf("InitToolsets() expected error but got none") + } + return + } + + if err != nil { + t.Errorf("InitToolsets() unexpected error: %v", err) + return + } + + // Check that expected toolsets are enabled + for _, name := range tt.expectedEnabled { + toolset, exists := tsg.Toolsets[name] + if !exists { + t.Errorf("Expected toolset %s to exist", name) + continue + } + if !toolset.Enabled { + t.Errorf("Expected toolset %s to be enabled", name) + } + } + }) + } +} + +func TestToolsetModeFiltering(t *testing.T) { + // Mock translation function + mockTranslator := func(key, fallback string) string { + return fallback + } + + // Mock client functions + getClient := func(ctx context.Context, _ string) (*github.Client, error) { + return nil, nil + } + getGQLClient := func(ctx context.Context, _ string) (*githubv4.Client, error) { + return nil, nil + } + + tests := []struct { + name string + configs []toolsets.ToolsetConfig + toolsetName string + expectWriteTools bool + expectReadTools bool + }{ + { + name: "repos toolset in ReadWrite mode should have both read and write tools", + configs: []toolsets.ToolsetConfig{ + {Name: "repos", Mode: toolsets.ReadWrite}, + }, + toolsetName: "repos", + expectWriteTools: true, + expectReadTools: true, + }, + { + name: "repos toolset in ReadOnly mode should have only read tools", + configs: []toolsets.ToolsetConfig{ + {Name: "repos", Mode: toolsets.ReadOnly}, + }, + toolsetName: "repos", + expectWriteTools: false, + expectReadTools: true, + }, + { + name: "pull_requests toolset in ReadOnly mode should have only read tools", + configs: []toolsets.ToolsetConfig{ + {Name: "pull_requests", Mode: toolsets.ReadOnly}, + }, + toolsetName: "pull_requests", + expectWriteTools: false, + expectReadTools: true, + }, + { + name: "pull_requests toolset in ReadWrite mode should have both read and write tools", + configs: []toolsets.ToolsetConfig{ + {Name: "pull_requests", Mode: toolsets.ReadWrite}, + }, + toolsetName: "pull_requests", + expectWriteTools: true, + expectReadTools: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tsg, err := InitToolsetsWithConfig(tt.configs, false, getClient, getGQLClient, mockTranslator) + if err != nil { + t.Fatalf("InitToolsetsWithConfig() error: %v", err) + } + + toolset, exists := tsg.Toolsets[tt.toolsetName] + if !exists { + t.Fatalf("Expected toolset %s to exist", tt.toolsetName) + } + + if !toolset.Enabled { + t.Fatalf("Expected toolset %s to be enabled", tt.toolsetName) + } + + activeTools := toolset.GetActiveTools() + availableTools := toolset.GetAvailableTools() + + // Count read and write tools + readToolCount := 0 + writeToolCount := 0 + + for _, tool := range activeTools { + if tool.Tool.Annotations.ReadOnlyHint != nil && *tool.Tool.Annotations.ReadOnlyHint { + readToolCount++ + } else { + writeToolCount++ + } + } + + // Verify expectations + if tt.expectReadTools && readToolCount == 0 { + t.Errorf("Expected toolset %s to have read tools, but found none", tt.toolsetName) + } + + if !tt.expectReadTools && readToolCount > 0 { + t.Errorf("Expected toolset %s to have no read tools, but found %d", tt.toolsetName, readToolCount) + } + + if tt.expectWriteTools && writeToolCount == 0 { + t.Errorf("Expected toolset %s to have write tools, but found none", tt.toolsetName) + } + + if !tt.expectWriteTools && writeToolCount > 0 { + t.Errorf("Expected toolset %s to have no write tools, but found %d", tt.toolsetName, writeToolCount) + } + + // Verify that GetActiveTools and GetAvailableTools behave correctly + if tt.expectWriteTools { + if len(activeTools) != len(availableTools) { + t.Errorf("For ReadWrite mode, active tools (%d) should equal available tools (%d)", len(activeTools), len(availableTools)) + } + } else { + // In ReadOnly mode, active tools should be a subset of available tools + if len(activeTools) > len(availableTools) { + t.Errorf("Active tools (%d) should not exceed available tools (%d)", len(activeTools), len(availableTools)) + } + } + + t.Logf("Toolset %s (%s mode): %d read tools, %d write tools, %d active tools, %d available tools", + tt.toolsetName, toolset.Mode, readToolCount, writeToolCount, len(activeTools), len(availableTools)) + }) + } +} + +func TestContextToolsetIntegration(t *testing.T) { + // Mock translation function + mockTranslator := func(key, fallback string) string { + return fallback + } + + // Mock client functions + getClient := func(ctx context.Context, _ string) (*github.Client, error) { + return nil, nil + } + getGQLClient := func(ctx context.Context, _ string) (*githubv4.Client, error) { + return nil, nil + } + + // Test that context toolset can be configured + configs := []toolsets.ToolsetConfig{ + {Name: "context", Mode: toolsets.ReadWrite}, + {Name: "repos", Mode: toolsets.ReadOnly}, + } + + tsg, err := InitToolsetsWithConfig(configs, false, getClient, getGQLClient, mockTranslator) + if err != nil { + t.Fatalf("InitToolsetsWithConfig() error: %v", err) + } + + // Verify context toolset exists and is enabled + contextToolset, exists := tsg.Toolsets["context"] + if !exists { + t.Fatalf("Expected context toolset to exist") + } + + if !contextToolset.Enabled { + t.Errorf("Expected context toolset to be enabled") + } + + if contextToolset.Mode != toolsets.ReadWrite { + t.Errorf("Expected context toolset to have ReadWrite mode, got %s", contextToolset.Mode) + } + + // Verify context toolset has the expected tools + activeTools := contextToolset.GetActiveTools() + if len(activeTools) == 0 { + t.Errorf("Expected context toolset to have tools") + } + + // Check that we have the get_me tool + found := false + for _, tool := range activeTools { + if tool.Tool.Name == "get_me" { + found = true + break + } + } + + if !found { + t.Errorf("Expected context toolset to have get_me tool") + } + + t.Logf("Context toolset has %d active tools", len(activeTools)) +} diff --git a/pkg/toolsets/config.go b/pkg/toolsets/config.go new file mode 100644 index 000000000..0e96eb0b8 --- /dev/null +++ b/pkg/toolsets/config.go @@ -0,0 +1,109 @@ +package toolsets + +import ( + "fmt" + "strings" +) + +// ToolsetMode represents the access mode for a toolset +type ToolsetMode int + +const ( + // ReadWrite allows both read and write tools + ReadWrite ToolsetMode = iota + // ReadOnly allows only read tools + ReadOnly +) + +// String returns the string representation of the toolset mode +func (m ToolsetMode) String() string { + switch m { + case ReadWrite: + return "rw" + case ReadOnly: + return "ro" + default: + return "unknown" + } +} + +// ToolsetConfig represents the configuration for a single toolset +type ToolsetConfig struct { + Name string + Mode ToolsetMode +} + +// ParseToolsetConfig parses a toolset configuration string like "repos:rw,issues:ro,users" +// and returns a slice of ToolsetConfig structs. +// +// Supported formats: +// - "toolset" -> ToolsetConfig{Name: "toolset", Mode: ReadWrite} +// - "toolset:rw" -> ToolsetConfig{Name: "toolset", Mode: ReadWrite} +// - "toolset:ro" -> ToolsetConfig{Name: "toolset", Mode: ReadOnly} +// - "toolset:readwrite" -> ToolsetConfig{Name: "toolset", Mode: ReadWrite} +// - "toolset:readonly" -> ToolsetConfig{Name: "toolset", Mode: ReadOnly} +func ParseToolsetConfig(input string) ([]ToolsetConfig, error) { + if input == "" { + return []ToolsetConfig{}, nil + } + + configs := []ToolsetConfig{} + items := strings.Split(input, ",") + + for _, item := range items { + item = strings.TrimSpace(item) + if item == "" { + continue + } + + if strings.Contains(item, ":") { + parts := strings.Split(item, ":") + if len(parts) != 2 { + return nil, fmt.Errorf("invalid toolset format '%s': expected 'name:mode'", item) + } + + name := strings.TrimSpace(parts[0]) + modeStr := strings.TrimSpace(parts[1]) + + if name == "" { + return nil, fmt.Errorf("invalid toolset format '%s': toolset name cannot be empty", item) + } + + mode, err := parseMode(modeStr) + if err != nil { + return nil, fmt.Errorf("invalid mode '%s' for toolset '%s': %w", modeStr, name, err) + } + + configs = append(configs, ToolsetConfig{Name: name, Mode: mode}) + } else { + // Default to ReadWrite if no mode specified + configs = append(configs, ToolsetConfig{Name: item, Mode: ReadWrite}) + } + } + + return configs, nil +} + +// parseMode parses a mode string and returns the corresponding ToolsetMode +func parseMode(modeStr string) (ToolsetMode, error) { + switch strings.ToLower(modeStr) { + case "rw", "readwrite": + return ReadWrite, nil + case "ro", "readonly": + return ReadOnly, nil + default: + return ReadWrite, fmt.Errorf("supported modes are 'rw', 'readwrite', 'ro', 'readonly'") + } +} + +// ParseToolsetConfigFromSlice parses toolset configuration from a string slice +// (as typically provided by viper for command line flags) +func ParseToolsetConfigFromSlice(input []string) ([]ToolsetConfig, error) { + if len(input) == 0 { + return []ToolsetConfig{}, nil + } + + // Join the slice with commas and parse as a single string + // This handles both CLI flag usage and environment variable usage + return ParseToolsetConfig(strings.Join(input, ",")) +} diff --git a/pkg/toolsets/config_test.go b/pkg/toolsets/config_test.go new file mode 100644 index 000000000..018bd6b14 --- /dev/null +++ b/pkg/toolsets/config_test.go @@ -0,0 +1,249 @@ +package toolsets + +import ( + "testing" +) + +func TestParseToolsetConfig(t *testing.T) { + tests := []struct { + name string + input string + expected []ToolsetConfig + wantErr bool + }{ + { + name: "empty string", + input: "", + expected: []ToolsetConfig{}, + wantErr: false, + }, + { + name: "single toolset without mode", + input: "repos", + expected: []ToolsetConfig{ + {Name: "repos", Mode: ReadWrite}, + }, + wantErr: false, + }, + { + name: "single toolset with rw mode", + input: "repos:rw", + expected: []ToolsetConfig{ + {Name: "repos", Mode: ReadWrite}, + }, + wantErr: false, + }, + { + name: "single toolset with ro mode", + input: "repos:ro", + expected: []ToolsetConfig{ + {Name: "repos", Mode: ReadOnly}, + }, + wantErr: false, + }, + { + name: "single toolset with readwrite mode", + input: "repos:readwrite", + expected: []ToolsetConfig{ + {Name: "repos", Mode: ReadWrite}, + }, + wantErr: false, + }, + { + name: "single toolset with readonly mode", + input: "repos:readonly", + expected: []ToolsetConfig{ + {Name: "repos", Mode: ReadOnly}, + }, + wantErr: false, + }, + { + name: "multiple toolsets mixed modes", + input: "repos:rw,issues:ro,users", + expected: []ToolsetConfig{ + {Name: "repos", Mode: ReadWrite}, + {Name: "issues", Mode: ReadOnly}, + {Name: "users", Mode: ReadWrite}, + }, + wantErr: false, + }, + { + name: "multiple toolsets with spaces", + input: "repos:rw, issues:ro, users", + expected: []ToolsetConfig{ + {Name: "repos", Mode: ReadWrite}, + {Name: "issues", Mode: ReadOnly}, + {Name: "users", Mode: ReadWrite}, + }, + wantErr: false, + }, + { + name: "all toolset with mode", + input: "all:ro", + expected: []ToolsetConfig{ + {Name: "all", Mode: ReadOnly}, + }, + wantErr: false, + }, + { + name: "invalid mode", + input: "repos:invalid", + wantErr: true, + }, + { + name: "empty toolset name", + input: ":rw", + wantErr: true, + }, + { + name: "malformed format", + input: "repos:rw:extra", + wantErr: true, + }, + { + name: "case insensitive modes", + input: "repos:RW,issues:RO", + expected: []ToolsetConfig{ + {Name: "repos", Mode: ReadWrite}, + {Name: "issues", Mode: ReadOnly}, + }, + wantErr: false, + }, + { + name: "empty items in list", + input: "repos,,issues:ro,", + expected: []ToolsetConfig{ + {Name: "repos", Mode: ReadWrite}, + {Name: "issues", Mode: ReadOnly}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ParseToolsetConfig(tt.input) + + if tt.wantErr { + if err == nil { + t.Errorf("ParseToolsetConfig() expected error but got none") + } + return + } + + if err != nil { + t.Errorf("ParseToolsetConfig() unexpected error: %v", err) + return + } + + if len(result) != len(tt.expected) { + t.Errorf("ParseToolsetConfig() got %d configs, expected %d", len(result), len(tt.expected)) + return + } + + for i, config := range result { + if config.Name != tt.expected[i].Name { + t.Errorf("ParseToolsetConfig() config[%d].Name = %s, expected %s", i, config.Name, tt.expected[i].Name) + } + if config.Mode != tt.expected[i].Mode { + t.Errorf("ParseToolsetConfig() config[%d].Mode = %s, expected %s", i, config.Mode, tt.expected[i].Mode) + } + } + }) + } +} + +func TestParseToolsetConfigFromSlice(t *testing.T) { + tests := []struct { + name string + input []string + expected []ToolsetConfig + wantErr bool + }{ + { + name: "empty slice", + input: []string{}, + expected: []ToolsetConfig{}, + wantErr: false, + }, + { + name: "single item", + input: []string{"repos:rw"}, + expected: []ToolsetConfig{ + {Name: "repos", Mode: ReadWrite}, + }, + wantErr: false, + }, + { + name: "multiple items", + input: []string{"repos:rw", "issues:ro"}, + expected: []ToolsetConfig{ + {Name: "repos", Mode: ReadWrite}, + {Name: "issues", Mode: ReadOnly}, + }, + wantErr: false, + }, + { + name: "comma-separated in single item", + input: []string{"repos:rw,issues:ro,users"}, + expected: []ToolsetConfig{ + {Name: "repos", Mode: ReadWrite}, + {Name: "issues", Mode: ReadOnly}, + {Name: "users", Mode: ReadWrite}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ParseToolsetConfigFromSlice(tt.input) + + if tt.wantErr { + if err == nil { + t.Errorf("ParseToolsetConfigFromSlice() expected error but got none") + } + return + } + + if err != nil { + t.Errorf("ParseToolsetConfigFromSlice() unexpected error: %v", err) + return + } + + if len(result) != len(tt.expected) { + t.Errorf("ParseToolsetConfigFromSlice() got %d configs, expected %d", len(result), len(tt.expected)) + return + } + + for i, config := range result { + if config.Name != tt.expected[i].Name { + t.Errorf("ParseToolsetConfigFromSlice() config[%d].Name = %s, expected %s", i, config.Name, tt.expected[i].Name) + } + if config.Mode != tt.expected[i].Mode { + t.Errorf("ParseToolsetConfigFromSlice() config[%d].Mode = %s, expected %s", i, config.Mode, tt.expected[i].Mode) + } + } + }) + } +} + +func TestToolsetModeString(t *testing.T) { + tests := []struct { + mode ToolsetMode + expected string + }{ + {ReadWrite, "rw"}, + {ReadOnly, "ro"}, + {ToolsetMode(999), "unknown"}, + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + result := tt.mode.String() + if result != tt.expected { + t.Errorf("ToolsetMode.String() = %s, expected %s", result, tt.expected) + } + }) + } +} diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go index 7400119c8..3351f058c 100644 --- a/pkg/toolsets/toolsets.go +++ b/pkg/toolsets/toolsets.go @@ -15,14 +15,16 @@ type Toolset struct { Name string Description string Enabled bool - readOnly bool + Mode ToolsetMode // NEW: per-toolset mode + readOnly bool // Global read-only override (deprecated in favor of Mode) writeTools []server.ServerTool readTools []server.ServerTool } func (t *Toolset) GetActiveTools() []server.ServerTool { if t.Enabled { - if t.readOnly { + // Check both global readOnly and per-toolset Mode + if t.readOnly || t.Mode == ReadOnly { return t.readTools } return append(t.readTools, t.writeTools...) @@ -31,7 +33,8 @@ func (t *Toolset) GetActiveTools() []server.ServerTool { } func (t *Toolset) GetAvailableTools() []server.ServerTool { - if t.readOnly { + // Check both global readOnly and per-toolset Mode + if t.readOnly || t.Mode == ReadOnly { return t.readTools } return append(t.readTools, t.writeTools...) @@ -44,7 +47,8 @@ func (t *Toolset) RegisterTools(s *server.MCPServer) { for _, tool := range t.readTools { s.AddTool(tool.Tool, tool.Handler) } - if !t.readOnly { + // Only register write tools if both global and per-toolset settings allow it + if !t.readOnly && t.Mode != ReadOnly { for _, tool := range t.writeTools { s.AddTool(tool.Tool, tool.Handler) } @@ -105,6 +109,7 @@ func NewToolset(name string, description string) *Toolset { Name: name, Description: description, Enabled: false, + Mode: ReadWrite, // Default to ReadWrite readOnly: false, } } @@ -122,6 +127,36 @@ func (tg *ToolsetGroup) IsEnabled(name string) bool { return feature.Enabled } +func (tg *ToolsetGroup) EnableToolsetsWithConfig(configs []ToolsetConfig) error { + // Special case for "all" + for _, config := range configs { + if config.Name == "all" { + tg.everythingOn = true + // Apply the mode to all toolsets + for name := range tg.Toolsets { + toolset := tg.Toolsets[name] + toolset.Enabled = true + toolset.Mode = config.Mode + tg.Toolsets[name] = toolset + } + return nil + } + } + + // Enable specific toolsets with their modes + for _, config := range configs { + toolset, exists := tg.Toolsets[config.Name] + if !exists { + return fmt.Errorf("toolset %s does not exist", config.Name) + } + toolset.Enabled = true + toolset.Mode = config.Mode + tg.Toolsets[config.Name] = toolset + } + + return nil +} + func (tg *ToolsetGroup) EnableToolsets(names []string) error { // Special case for "all" for _, name := range names { diff --git a/pkg/toolsets/toolsets_test.go b/pkg/toolsets/toolsets_test.go index 7ece1df1e..fec9c9646 100644 --- a/pkg/toolsets/toolsets_test.go +++ b/pkg/toolsets/toolsets_test.go @@ -228,3 +228,108 @@ func TestIsEnabledWithEverythingOn(t *testing.T) { t.Error("Expected IsEnabled to return true for any toolset when everythingOn is true") } } + +func TestEnableToolsetsWithConfig(t *testing.T) { + tests := []struct { + name string + configs []ToolsetConfig + wantErr bool + expected map[string]ToolsetMode + }{ + { + name: "enable single toolset with rw mode", + configs: []ToolsetConfig{ + {Name: "repos", Mode: ReadWrite}, + }, + wantErr: false, + expected: map[string]ToolsetMode{ + "repos": ReadWrite, + }, + }, + { + name: "enable single toolset with ro mode", + configs: []ToolsetConfig{ + {Name: "repos", Mode: ReadOnly}, + }, + wantErr: false, + expected: map[string]ToolsetMode{ + "repos": ReadOnly, + }, + }, + { + name: "enable multiple toolsets with mixed modes", + configs: []ToolsetConfig{ + {Name: "repos", Mode: ReadWrite}, + {Name: "issues", Mode: ReadOnly}, + {Name: "users", Mode: ReadWrite}, + }, + wantErr: false, + expected: map[string]ToolsetMode{ + "repos": ReadWrite, + "issues": ReadOnly, + "users": ReadWrite, + }, + }, + { + name: "enable all with ro mode", + configs: []ToolsetConfig{ + {Name: "all", Mode: ReadOnly}, + }, + wantErr: false, + expected: map[string]ToolsetMode{ + "repos": ReadOnly, + "issues": ReadOnly, + "users": ReadOnly, + }, + }, + { + name: "enable nonexistent toolset", + configs: []ToolsetConfig{ + {Name: "nonexistent", Mode: ReadWrite}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset toolset group + tsg := NewToolsetGroup(false) + tsg.AddToolset(NewToolset("repos", "Repository tools")) + tsg.AddToolset(NewToolset("issues", "Issue tools")) + tsg.AddToolset(NewToolset("users", "User tools")) + + err := tsg.EnableToolsetsWithConfig(tt.configs) + + if tt.wantErr { + if err == nil { + t.Errorf("EnableToolsetsWithConfig() expected error but got none") + } + return + } + + if err != nil { + t.Errorf("EnableToolsetsWithConfig() unexpected error: %v", err) + return + } + + // Check that expected toolsets are enabled with correct modes + for name, expectedMode := range tt.expected { + toolset := tsg.Toolsets[name] + if !toolset.Enabled { + t.Errorf("Expected toolset %s to be enabled", name) + } + if toolset.Mode != expectedMode { + t.Errorf("Expected toolset %s to have mode %s, got %s", name, expectedMode, toolset.Mode) + } + } + + // Check that non-expected toolsets are not enabled + for name, toolset := range tsg.Toolsets { + if _, expected := tt.expected[name]; !expected && toolset.Enabled { + t.Errorf("Expected toolset %s to not be enabled", name) + } + } + }) + } +}