Skip to content

Commit 9d0c8cf

Browse files
support tool updates
1 parent b0bcb21 commit 9d0c8cf

File tree

11 files changed

+1035
-502
lines changed

11 files changed

+1035
-502
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ The following sets of tools are available:
112112
| `pull_requests` | Pull request operations (create, merge, review) | Enabled |
113113
| `code_security` | Code scanning alerts and security features | Disabled |
114114
| `experiments` | Experimental features (not considered stable) | Disabled |
115-
| `everything` | Special flag to enable all features | Disabled |
115+
| `all` | Special flag to enable all features | Disabled |
116116

117117
### Specifying Toolsets
118118

cmd/github-mcp-server/main.go

+26-45
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212

1313
"github.com/github/github-mcp-server/pkg/github"
1414
iolog "github.com/github/github-mcp-server/pkg/log"
15-
"github.com/github/github-mcp-server/pkg/toolsets"
1615
"github.com/github/github-mcp-server/pkg/translations"
1716
gogithub "github.com/google/go-github/v69/github"
1817
"github.com/mark3labs/mcp-go/server"
@@ -45,10 +44,16 @@ var (
4544
if err != nil {
4645
stdlog.Fatal("Failed to initialize logger:", err)
4746
}
48-
enabledToolsets := viper.GetStringSlice("features")
49-
features, err := initToolsets(enabledToolsets)
50-
if err != nil {
51-
stdlog.Fatal("Failed to initialize features:", err)
47+
48+
enabledToolsets := viper.GetStringSlice("toolsets")
49+
50+
// Env gets precedence over command line flags
51+
if envToolsets := os.Getenv("GITHUB_TOOLSETS"); envToolsets != "" {
52+
enabledToolsets = []string{}
53+
// Split envFeats by comma, trim whitespace, and add to the slice
54+
for _, toolset := range strings.Split(envToolsets, ",") {
55+
enabledToolsets = append(enabledToolsets, strings.TrimSpace(toolset))
56+
}
5257
}
5358

5459
logCommands := viper.GetBool("enable-command-logging")
@@ -57,7 +62,7 @@ var (
5762
logger: logger,
5863
logCommands: logCommands,
5964
exportTranslations: exportTranslations,
60-
features: features,
65+
enabledToolsets: enabledToolsets,
6166
}
6267
if err := runStdioServer(cfg); err != nil {
6368
stdlog.Fatal("failed to run stdio server:", err)
@@ -66,53 +71,19 @@ var (
6671
}
6772
)
6873

69-
func initToolsets(passedToolsets []string) (*toolsets.ToolsetGroup, error) {
70-
// Create a new toolset group
71-
fs := toolsets.NewToolsetGroup()
72-
73-
// Define all available features with their default state (disabled)
74-
fs.AddToolset("repos", "Repository related tools", false)
75-
fs.AddToolset("issues", "Issues related tools", false)
76-
fs.AddToolset("search", "Search related tools", false)
77-
fs.AddToolset("pull_requests", "Pull request related tools", false)
78-
fs.AddToolset("code_security", "Code security related tools", false)
79-
fs.AddToolset("experiments", "Experimental features that are not considered stable yet", false)
80-
81-
// fs.AddFeature("actions", "GitHub Actions related tools", false)
82-
// fs.AddFeature("projects", "GitHub Projects related tools", false)
83-
// fs.AddFeature("secret_protection", "Secret protection related tools", false)
84-
// fs.AddFeature("gists", "Gist related tools", false)
85-
86-
// Env gets precedence over command line flags
87-
if envFeats := os.Getenv("GITHUB_TOOLSETS"); envFeats != "" {
88-
passedToolsets = []string{}
89-
// Split envFeats by comma, trim whitespace, and add to the slice
90-
for _, feature := range strings.Split(envFeats, ",") {
91-
passedToolsets = append(passedToolsets, strings.TrimSpace(feature))
92-
}
93-
}
94-
95-
// Enable the requested features
96-
if err := fs.EnableToolsets(passedToolsets); err != nil {
97-
return nil, err
98-
}
99-
100-
return fs, nil
101-
}
102-
10374
func init() {
10475
cobra.OnInitialize(initConfig)
10576

10677
// Add global flags that will be shared by all commands
107-
rootCmd.PersistentFlags().StringSlice("features", []string{"repos", "issues", "pull_requests", "search"}, "A comma separated list of groups of tools to enable, defaults to issues/repos/search")
78+
rootCmd.PersistentFlags().StringSlice("toolsets", []string{"repos", "issues", "pull_requests", "search"}, "A comma separated list of groups of tools to enable, defaults to issues/repos/search")
10879
rootCmd.PersistentFlags().Bool("read-only", false, "Restrict the server to read-only operations")
10980
rootCmd.PersistentFlags().String("log-file", "", "Path to log file")
11081
rootCmd.PersistentFlags().Bool("enable-command-logging", false, "When enabled, the server will log all command requests and responses to the log file")
11182
rootCmd.PersistentFlags().Bool("export-translations", false, "Save translations to a JSON file")
11283
rootCmd.PersistentFlags().String("gh-host", "", "Specify the GitHub hostname (for GitHub Enterprise etc.)")
11384

11485
// Bind flag to viper
115-
_ = viper.BindPFlag("features", rootCmd.PersistentFlags().Lookup("features"))
86+
_ = viper.BindPFlag("toolsets", rootCmd.PersistentFlags().Lookup("toolsets"))
11687
_ = viper.BindPFlag("read-only", rootCmd.PersistentFlags().Lookup("read-only"))
11788
_ = viper.BindPFlag("log-file", rootCmd.PersistentFlags().Lookup("log-file"))
11889
_ = viper.BindPFlag("enable-command-logging", rootCmd.PersistentFlags().Lookup("enable-command-logging"))
@@ -151,7 +122,7 @@ type runConfig struct {
151122
logger *log.Logger
152123
logCommands bool
153124
exportTranslations bool
154-
features *toolsets.ToolsetGroup
125+
enabledToolsets []string
155126
}
156127

157128
func runStdioServer(cfg runConfig) error {
@@ -186,8 +157,18 @@ func runStdioServer(cfg runConfig) error {
186157
getClient := func(_ context.Context) (*gogithub.Client, error) {
187158
return ghClient, nil // closing over client
188159
}
189-
// Create
190-
ghServer := github.NewServer(getClient, cfg.features, version, cfg.readOnly, t)
160+
161+
// Create server
162+
ghServer := github.NewServer(version)
163+
164+
// Create toolsets
165+
toolsets, err := github.InitToolsets(ghServer, cfg.enabledToolsets, cfg.readOnly, getClient, t)
166+
if err != nil {
167+
stdlog.Fatal("Failed to initialize toolsets:", err)
168+
}
169+
// Register the tools with the server
170+
toolsets.RegisterTools(ghServer)
171+
191172
stdioServer := server.NewStdioServer(ghServer)
192173

193174
stdLogger := stdlog.New(cfg.logger.Writer(), "stdioserver", 0)

pkg/github/context_tools.go

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
10+
"github.com/github/github-mcp-server/pkg/translations"
11+
"github.com/mark3labs/mcp-go/mcp"
12+
"github.com/mark3labs/mcp-go/server"
13+
)
14+
15+
// GetMe creates a tool to get details of the authenticated user.
16+
func GetMe(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
17+
return mcp.NewTool("get_me",
18+
mcp.WithDescription(t("TOOL_GET_ME_DESCRIPTION", "Get details of the authenticated GitHub user. Use this when a request include \"me\", \"my\"...")),
19+
mcp.WithString("reason",
20+
mcp.Description("Optional: reason the session was created"),
21+
),
22+
),
23+
func(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
24+
client, err := getClient(ctx)
25+
if err != nil {
26+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
27+
}
28+
user, resp, err := client.Users.Get(ctx, "")
29+
if err != nil {
30+
return nil, fmt.Errorf("failed to get user: %w", err)
31+
}
32+
defer func() { _ = resp.Body.Close() }()
33+
34+
if resp.StatusCode != http.StatusOK {
35+
body, err := io.ReadAll(resp.Body)
36+
if err != nil {
37+
return nil, fmt.Errorf("failed to read response body: %w", err)
38+
}
39+
return mcp.NewToolResultError(fmt.Sprintf("failed to get user: %s", string(body))), nil
40+
}
41+
42+
r, err := json.Marshal(user)
43+
if err != nil {
44+
return nil, fmt.Errorf("failed to marshal user: %w", err)
45+
}
46+
47+
return mcp.NewToolResultText(string(r)), nil
48+
}
49+
}

pkg/github/context_tools_test.go

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"net/http"
7+
"testing"
8+
"time"
9+
10+
"github.com/github/github-mcp-server/pkg/translations"
11+
"github.com/google/go-github/v69/github"
12+
"github.com/migueleliasweb/go-github-mock/src/mock"
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
func Test_GetMe(t *testing.T) {
18+
// Verify tool definition
19+
mockClient := github.NewClient(nil)
20+
tool, _ := GetMe(stubGetClientFn(mockClient), translations.NullTranslationHelper)
21+
22+
assert.Equal(t, "get_me", tool.Name)
23+
assert.NotEmpty(t, tool.Description)
24+
assert.Contains(t, tool.InputSchema.Properties, "reason")
25+
assert.Empty(t, tool.InputSchema.Required) // No required parameters
26+
27+
// Setup mock user response
28+
mockUser := &github.User{
29+
Login: github.Ptr("testuser"),
30+
Name: github.Ptr("Test User"),
31+
Email: github.Ptr("[email protected]"),
32+
Bio: github.Ptr("GitHub user for testing"),
33+
Company: github.Ptr("Test Company"),
34+
Location: github.Ptr("Test Location"),
35+
HTMLURL: github.Ptr("https://github.com/testuser"),
36+
CreatedAt: &github.Timestamp{Time: time.Now().Add(-365 * 24 * time.Hour)},
37+
Type: github.Ptr("User"),
38+
Plan: &github.Plan{
39+
Name: github.Ptr("pro"),
40+
},
41+
}
42+
43+
tests := []struct {
44+
name string
45+
mockedClient *http.Client
46+
requestArgs map[string]interface{}
47+
expectError bool
48+
expectedUser *github.User
49+
expectedErrMsg string
50+
}{
51+
{
52+
name: "successful get user",
53+
mockedClient: mock.NewMockedHTTPClient(
54+
mock.WithRequestMatch(
55+
mock.GetUser,
56+
mockUser,
57+
),
58+
),
59+
requestArgs: map[string]interface{}{},
60+
expectError: false,
61+
expectedUser: mockUser,
62+
},
63+
{
64+
name: "successful get user with reason",
65+
mockedClient: mock.NewMockedHTTPClient(
66+
mock.WithRequestMatch(
67+
mock.GetUser,
68+
mockUser,
69+
),
70+
),
71+
requestArgs: map[string]interface{}{
72+
"reason": "Testing API",
73+
},
74+
expectError: false,
75+
expectedUser: mockUser,
76+
},
77+
{
78+
name: "get user fails",
79+
mockedClient: mock.NewMockedHTTPClient(
80+
mock.WithRequestMatchHandler(
81+
mock.GetUser,
82+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
83+
w.WriteHeader(http.StatusUnauthorized)
84+
_, _ = w.Write([]byte(`{"message": "Unauthorized"}`))
85+
}),
86+
),
87+
),
88+
requestArgs: map[string]interface{}{},
89+
expectError: true,
90+
expectedErrMsg: "failed to get user",
91+
},
92+
}
93+
94+
for _, tc := range tests {
95+
t.Run(tc.name, func(t *testing.T) {
96+
// Setup client with mock
97+
client := github.NewClient(tc.mockedClient)
98+
_, handler := GetMe(stubGetClientFn(client), translations.NullTranslationHelper)
99+
100+
// Create call request
101+
request := createMCPRequest(tc.requestArgs)
102+
103+
// Call handler
104+
result, err := handler(context.Background(), request)
105+
106+
// Verify results
107+
if tc.expectError {
108+
require.Error(t, err)
109+
assert.Contains(t, err.Error(), tc.expectedErrMsg)
110+
return
111+
}
112+
113+
require.NoError(t, err)
114+
115+
// Parse result and get text content if no error
116+
textContent := getTextResult(t, result)
117+
118+
// Unmarshal and verify the result
119+
var returnedUser github.User
120+
err = json.Unmarshal([]byte(textContent.Text), &returnedUser)
121+
require.NoError(t, err)
122+
123+
// Verify user details
124+
assert.Equal(t, *tc.expectedUser.Login, *returnedUser.Login)
125+
assert.Equal(t, *tc.expectedUser.Name, *returnedUser.Name)
126+
assert.Equal(t, *tc.expectedUser.Email, *returnedUser.Email)
127+
assert.Equal(t, *tc.expectedUser.Bio, *returnedUser.Bio)
128+
assert.Equal(t, *tc.expectedUser.HTMLURL, *returnedUser.HTMLURL)
129+
assert.Equal(t, *tc.expectedUser.Type, *returnedUser.Type)
130+
})
131+
}
132+
}

0 commit comments

Comments
 (0)