Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 0 additions & 139 deletions pkg/mcp/common_test.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
package mcp

import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"strconv"
"testing"
"time"

"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors"
"github.com/spf13/afero"
"github.com/stretchr/testify/suite"
Expand All @@ -30,11 +23,7 @@ import (
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
toolswatch "k8s.io/client-go/tools/watch"
"k8s.io/klog/v2"
"k8s.io/klog/v2/textlogger"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/envtest"
"sigs.k8s.io/controller-runtime/tools/setup-envtest/env"
Expand All @@ -45,7 +34,6 @@ import (

"github.com/containers/kubernetes-mcp-server/internal/test"
"github.com/containers/kubernetes-mcp-server/pkg/config"
"github.com/containers/kubernetes-mcp-server/pkg/output"
)

// envTest has an expensive setup, so we only want to do it once per entire test run.
Expand Down Expand Up @@ -103,133 +91,6 @@ func TestMain(m *testing.M) {
os.Exit(code)
}

type mcpContext struct {
toolsets []string
listOutput output.Output
logLevel int

staticConfig *config.StaticConfig
clientOptions []transport.ClientOption
before func(*mcpContext)
after func(*mcpContext)
ctx context.Context
tempDir string
cancel context.CancelFunc
mcpServer *Server
mcpHttpServer *httptest.Server
mcpClient *client.Client
klogState klog.State
logBuffer bytes.Buffer
}

func (c *mcpContext) beforeEach(t *testing.T) {
var err error
c.ctx, c.cancel = context.WithCancel(t.Context())
c.tempDir = t.TempDir()
c.withKubeConfig(nil)
if c.staticConfig == nil {
c.staticConfig = config.Default()
// Default to use YAML output for lists (previously the default)
c.staticConfig.ListOutput = "yaml"
}
if c.toolsets != nil {
c.staticConfig.Toolsets = c.toolsets

}
if c.listOutput != nil {
c.staticConfig.ListOutput = c.listOutput.GetName()
}
if c.before != nil {
c.before(c)
}
// Set up logging
c.klogState = klog.CaptureState()
flags := flag.NewFlagSet("test", flag.ContinueOnError)
klog.InitFlags(flags)
_ = flags.Set("v", strconv.Itoa(c.logLevel))
klog.SetLogger(textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(c.logLevel), textlogger.Output(&c.logBuffer))))
// MCP Server
if c.mcpServer, err = NewServer(Configuration{StaticConfig: c.staticConfig}); err != nil {
t.Fatal(err)
return
}
c.mcpHttpServer = server.NewTestServer(c.mcpServer.server, server.WithSSEContextFunc(contextFunc))
if c.mcpClient, err = client.NewSSEMCPClient(c.mcpHttpServer.URL+"/sse", c.clientOptions...); err != nil {
t.Fatal(err)
return
}
// MCP Client
if err = c.mcpClient.Start(c.ctx); err != nil {
t.Fatal(err)
return
}
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{Name: "test", Version: "1.33.7"}
_, err = c.mcpClient.Initialize(c.ctx, initRequest)
if err != nil {
t.Fatal(err)
return
}
}

func (c *mcpContext) afterEach() {
if c.after != nil {
c.after(c)
}
c.cancel()
c.mcpServer.Close()
_ = c.mcpClient.Close()
c.mcpHttpServer.Close()
c.klogState.Restore()
}

func testCaseWithContext(t *testing.T, mcpCtx *mcpContext, test func(c *mcpContext)) {
mcpCtx.beforeEach(t)
defer mcpCtx.afterEach()
test(mcpCtx)
}

// withKubeConfig sets up a fake kubeconfig in the temp directory based on the provided rest.Config
func (c *mcpContext) withKubeConfig(rc *rest.Config) *clientcmdapi.Config {
fakeConfig := clientcmdapi.NewConfig()
fakeConfig.Clusters["fake"] = clientcmdapi.NewCluster()
fakeConfig.Clusters["fake"].Server = "https://127.0.0.1:6443"
fakeConfig.Clusters["additional-cluster"] = clientcmdapi.NewCluster()
fakeConfig.AuthInfos["fake"] = clientcmdapi.NewAuthInfo()
fakeConfig.AuthInfos["additional-auth"] = clientcmdapi.NewAuthInfo()
if rc != nil {
fakeConfig.Clusters["fake"].Server = rc.Host
fakeConfig.Clusters["fake"].CertificateAuthorityData = rc.CAData
fakeConfig.AuthInfos["fake"].ClientKeyData = rc.KeyData
fakeConfig.AuthInfos["fake"].ClientCertificateData = rc.CertData
}
fakeConfig.Contexts["fake-context"] = clientcmdapi.NewContext()
fakeConfig.Contexts["fake-context"].Cluster = "fake"
fakeConfig.Contexts["fake-context"].AuthInfo = "fake"
fakeConfig.Contexts["additional-context"] = clientcmdapi.NewContext()
fakeConfig.Contexts["additional-context"].Cluster = "additional-cluster"
fakeConfig.Contexts["additional-context"].AuthInfo = "additional-auth"
fakeConfig.CurrentContext = "fake-context"
kubeConfig := filepath.Join(c.tempDir, "config")
_ = clientcmd.WriteToFile(*fakeConfig, kubeConfig)
_ = os.Setenv("KUBECONFIG", kubeConfig)
if c.mcpServer != nil {
if err := c.mcpServer.reloadKubernetesClusterProvider(); err != nil {
panic(err)
}
}
return fakeConfig
}

// callTool helper function to call a tool by name with arguments
func (c *mcpContext) callTool(name string, args map[string]interface{}) (*mcp.CallToolResult, error) {
callToolRequest := mcp.CallToolRequest{}
callToolRequest.Params.Name = name
callToolRequest.Params.Arguments = args
return c.mcpClient.CallTool(c.ctx, callToolRequest)
}

func restoreAuth(ctx context.Context) {
kubernetesAdmin := kubernetes.NewForConfigOrDie(envTest.Config)
// Authorization
Expand Down
127 changes: 73 additions & 54 deletions pkg/mcp/mcp_middleware_test.go
Original file line number Diff line number Diff line change
@@ -1,68 +1,87 @@
package mcp

import (
"bytes"
"flag"
"regexp"
"strings"
"strconv"
"testing"

"github.com/mark3labs/mcp-go/client/transport"
"github.com/stretchr/testify/suite"
"k8s.io/klog/v2"
"k8s.io/klog/v2/textlogger"
)

func TestToolCallLogging(t *testing.T) {
testCaseWithContext(t, &mcpContext{logLevel: 5}, func(c *mcpContext) {
_, _ = c.callTool("configuration_view", map[string]interface{}{
"minified": false,
})
t.Run("Logs tool name", func(t *testing.T) {
expectedLog := "mcp tool call: configuration_view("
if !strings.Contains(c.logBuffer.String(), expectedLog) {
t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String())
}
})
t.Run("Logs tool call arguments", func(t *testing.T) {
expected := `"mcp tool call: configuration_view\((.+)\)"`
m := regexp.MustCompile(expected).FindStringSubmatch(c.logBuffer.String())
if len(m) != 2 {
t.Fatalf("Expected log entry to contain arguments, got %s", c.logBuffer.String())
}
if m[1] != "map[minified:false]" {
t.Errorf("Expected log arguments to be 'map[minified:false]', got %s", m[1])
}
})
type McpLoggingSuite struct {
BaseMcpSuite
klogState klog.State
logBuffer bytes.Buffer
}

func (s *McpLoggingSuite) SetupTest() {
s.BaseMcpSuite.SetupTest()
s.klogState = klog.CaptureState()
}

func (s *McpLoggingSuite) TearDownTest() {
s.BaseMcpSuite.TearDownTest()
s.klogState.Restore()
}

func (s *McpLoggingSuite) SetLogLevel(level int) {
flags := flag.NewFlagSet("test", flag.ContinueOnError)
klog.InitFlags(flags)
_ = flags.Set("v", strconv.Itoa(level))
klog.SetLogger(textlogger.NewLogger(textlogger.NewConfig(textlogger.Verbosity(level), textlogger.Output(&s.logBuffer))))
}

func (s *McpLoggingSuite) TestLogsToolCall() {
s.SetLogLevel(5)
s.InitMcpClient()
_, err := s.CallTool("configuration_view", map[string]interface{}{"minified": false})
s.Require().NoError(err, "call to tool configuration_view failed")

s.Run("Logs tool name", func() {
s.Contains(s.logBuffer.String(), "mcp tool call: configuration_view(")
})
before := func(c *mcpContext) {
c.clientOptions = append(c.clientOptions, transport.WithHeaders(map[string]string{
"Accept-Encoding": "gzip",
"Authorization": "Bearer should-not-be-logged",
"authorization": "Bearer should-not-be-logged",
"a-loggable-header": "should-be-logged",
}))
s.Run("Logs tool call arguments", func() {
expected := `"mcp tool call: configuration_view\((.+)\)"`
m := regexp.MustCompile(expected).FindStringSubmatch(s.logBuffer.String())
s.Len(m, 2, "Expected log entry to contain arguments")
s.Equal("map[minified:false]", m[1], "Expected log arguments to be 'map[minified:false]'")
})
}

func (s *McpLoggingSuite) TestLogsToolCallHeaders() {
s.SetLogLevel(7)
s.InitMcpClient(transport.WithHTTPHeaders(map[string]string{
"Accept-Encoding": "gzip",
"Authorization": "Bearer should-not-be-logged",
"authorization": "Bearer should-not-be-logged",
"a-loggable-header": "should-be-logged",
}))
_, err := s.CallTool("configuration_view", map[string]interface{}{"minified": false})
s.Require().NoError(err, "call to tool configuration_view failed")

s.Run("Logs tool call headers", func() {
expectedLog := "mcp tool call headers: A-Loggable-Header: should-be-logged"
s.Contains(s.logBuffer.String(), expectedLog, "Expected log to contain loggable header")
})
sensitiveHeaders := []string{
"Authorization:",
// TODO: Add more sensitive headers as needed
}
testCaseWithContext(t, &mcpContext{logLevel: 7, before: before}, func(c *mcpContext) {
_, _ = c.callTool("configuration_view", map[string]interface{}{
"minified": false,
})
t.Run("Logs tool call headers", func(t *testing.T) {
expectedLog := "mcp tool call headers: A-Loggable-Header: should-be-logged"
if !strings.Contains(c.logBuffer.String(), expectedLog) {
t.Errorf("Expected log to contain '%s', got: %s", expectedLog, c.logBuffer.String())
}
})
sensitiveHeaders := []string{
"Authorization:",
// TODO: Add more sensitive headers as needed
s.Run("Does not log sensitive headers", func() {
for _, header := range sensitiveHeaders {
s.NotContains(s.logBuffer.String(), header, "Log should not contain sensitive header")
}
t.Run("Does not log sensitive headers", func(t *testing.T) {
for _, header := range sensitiveHeaders {
if strings.Contains(c.logBuffer.String(), header) {
t.Errorf("Log should not contain sensitive header '%s', got: %s", header, c.logBuffer.String())
}
}
})
t.Run("Does not log sensitive header values", func(t *testing.T) {
if strings.Contains(c.logBuffer.String(), "should-not-be-logged") {
t.Errorf("Log should not contain sensitive header value 'should-not-be-logged', got: %s", c.logBuffer.String())
}
})
})
s.Run("Does not log sensitive header values", func() {
s.NotContains(s.logBuffer.String(), "should-not-be-logged", "Log should not contain sensitive header value")
})
}

func TestMcpLogging(t *testing.T) {
suite.Run(t, new(McpLoggingSuite))
}
Loading