Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 60 additions & 71 deletions internal/mcp/tools/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,12 @@ func registerGraphQLTools(s *server.MCPServer, ctx *toolContext) {
// get_nais_context tool - provides essential context for working with Nais
getNaisContextTool := mcp.NewTool("get_nais_context",
mcp.WithDescription("Get the current Nais context including authenticated user, their teams, and console URL. Call this first to understand what the user has access to and to get the correct console URL for links."),
mcp.WithInputSchema[GetNaisContextInput](),
mcp.WithOutputSchema[GetNaisContextOutput](),
)
s.AddTool(getNaisContextTool, ctx.handleGetNaisContext)
s.AddTool(getNaisContextTool, mcp.NewStructuredToolHandler(ctx.handleGetNaisContext))

// execute_graphql tool
// execute_graphql tool - dynamic output, so we use NewTypedToolHandler
executeGraphQLTool := mcp.NewTool("execute_graphql",
mcp.WithDescription(`Execute a GraphQL query against the Nais API.

Expand All @@ -199,30 +201,27 @@ IMPORTANT: Before using this tool, use the schema exploration tools (schema_list
This tool only supports queries (read operations). Mutations are not allowed.

`+naisAPIGuidance),
mcp.WithString("query",
mcp.Required(),
mcp.Description("The GraphQL query to execute. Must be a query operation (not mutation or subscription)."),
),
mcp.WithString("variables",
mcp.Description("JSON object containing variables for the query. Example: {\"slug\": \"my-team\", \"first\": 10}"),
),
mcp.WithInputSchema[ExecuteGraphQLInput](),
// Note: Output is dynamic JSON from the GraphQL API, so we don't use WithOutputSchema here
)
s.AddTool(executeGraphQLTool, ctx.handleExecuteGraphQL)
s.AddTool(executeGraphQLTool, mcp.NewTypedToolHandler(ctx.handleExecuteGraphQL))

// validate_graphql tool
validateGraphQLTool := mcp.NewTool("validate_graphql",
mcp.WithDescription("Validate a GraphQL query against the schema without executing it. Use this to check if your query is valid before executing."),
mcp.WithString("query",
mcp.Required(),
mcp.Description("The GraphQL query to validate."),
),
mcp.WithInputSchema[ValidateGraphQLInput](),
mcp.WithOutputSchema[ValidateGraphQLOutput](),
)
s.AddTool(validateGraphQLTool, ctx.handleValidateGraphQL)
s.AddTool(validateGraphQLTool, mcp.NewStructuredToolHandler(ctx.handleValidateGraphQL))
}

func (t *toolContext) handleGetNaisContext(reqCtx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
func (t *toolContext) handleGetNaisContext(
reqCtx context.Context,
req mcp.CallToolRequest,
args GetNaisContextInput,
) (GetNaisContextOutput, error) {
if !t.rateLimiter.Allow() {
return mcp.NewToolResultError("rate limit exceeded, please try again later"), nil
return GetNaisContextOutput{}, fmt.Errorf("rate limit exceeded, please try again later")
}

t.logger.Debug("Executing get_nais_context tool")
Expand All @@ -231,36 +230,36 @@ func (t *toolContext) handleGetNaisContext(reqCtx context.Context, req mcp.CallT
user, err := t.client.GetCurrentUser(reqCtx)
if err != nil {
t.logger.Error("Failed to get current user", "error", err)
return mcp.NewToolResultError(fmt.Sprintf("failed to get current user: %v", err)), nil
return GetNaisContextOutput{}, fmt.Errorf("failed to get current user: %w", err)
}

// Get user's teams
teams, err := t.client.GetUserTeams(reqCtx)
if err != nil {
t.logger.Error("Failed to get user teams", "error", err)
return mcp.NewToolResultError(fmt.Sprintf("failed to get user teams: %v", err)), nil
return GetNaisContextOutput{}, fmt.Errorf("failed to get user teams: %w", err)
}

// Build teams list
teamsList := make([]map[string]any, 0, len(teams))
for _, t := range teams {
teamsList = append(teamsList, map[string]any{
"slug": t.Team.Slug,
"purpose": t.Team.Purpose,
"role": string(t.Role),
teamsList := make([]NaisTeamInfo, 0, len(teams))
for _, team := range teams {
teamsList = append(teamsList, NaisTeamInfo{
Slug: team.Team.Slug,
Purpose: team.Team.Purpose,
Role: string(team.Role),
})
}

// Get console URL
consoleBaseURL := t.getConsoleBaseURL(reqCtx)

result := map[string]any{
"user": map[string]any{
"name": user.Name,
return GetNaisContextOutput{
User: NaisUserInfo{
Name: user.Name,
},
"teams": teamsList,
"console_base_url": consoleBaseURL,
"console_url_patterns": map[string]string{
Teams: teamsList,
ConsoleBaseURL: consoleBaseURL,
ConsoleURLPatterns: map[string]string{
"team": "/team/{team}",
"team_applications": "/team/{team}/applications",
"team_jobs": "/team/{team}/jobs",
Expand Down Expand Up @@ -288,29 +287,24 @@ func (t *toolContext) handleGetNaisContext(reqCtx context.Context, req mcp.CallT
"bigquery": "/team/{team}/{env}/bigquery/{name}",
"kafka": "/team/{team}/{env}/kafka/{name}",
},
}

jsonData, err := json.MarshalIndent(result, "", " ")
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil
}

return mcp.NewToolResultText(string(jsonData)), nil
}, nil
}

func (t *toolContext) handleExecuteGraphQL(reqCtx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
func (t *toolContext) handleExecuteGraphQL(
reqCtx context.Context,
req mcp.CallToolRequest,
args ExecuteGraphQLInput,
) (*mcp.CallToolResult, error) {
if !t.rateLimiter.Allow() {
return mcp.NewToolResultError("rate limit exceeded, please try again later"), nil
}

query, err := req.RequireString("query")
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
variablesStr := args.Variables
if variablesStr == "" {
variablesStr = "{}"
}

variablesStr := req.GetString("variables", "{}")

t.logger.Debug("Executing GraphQL query", "query_length", len(query), "has_variables", variablesStr != "{}")
t.logger.Debug("Executing GraphQL query", "query_length", len(args.Query), "has_variables", variablesStr != "{}")

// Parse variables
var variables map[string]any
Expand All @@ -319,7 +313,7 @@ func (t *toolContext) handleExecuteGraphQL(reqCtx context.Context, req mcp.CallT
}

// Validate the query
validationResult, err := t.validateGraphQLQuery(reqCtx, query)
validationResult, err := t.validateGraphQLQuery(reqCtx, args.Query)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to validate query: %v", err)), nil
}
Expand All @@ -336,7 +330,7 @@ func (t *toolContext) handleExecuteGraphQL(reqCtx context.Context, req mcp.CallT

// Create the request
gqlReq := &graphql.Request{
Query: query,
Query: args.Query,
Variables: variables,
}

Expand Down Expand Up @@ -368,38 +362,33 @@ func (t *toolContext) handleExecuteGraphQL(reqCtx context.Context, req mcp.CallT
return mcp.NewToolResultText(string(jsonData)), nil
}

func (t *toolContext) handleValidateGraphQL(reqCtx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
func (t *toolContext) handleValidateGraphQL(
reqCtx context.Context,
req mcp.CallToolRequest,
args ValidateGraphQLInput,
) (ValidateGraphQLOutput, error) {
if !t.rateLimiter.Allow() {
return mcp.NewToolResultError("rate limit exceeded, please try again later"), nil
return ValidateGraphQLOutput{}, fmt.Errorf("rate limit exceeded, please try again later")
}

query, err := req.RequireString("query")
result, err := t.validateGraphQLQuery(reqCtx, args.Query)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}

result, err := t.validateGraphQLQuery(reqCtx, query)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("validation failed: %v", err)), nil
return ValidateGraphQLOutput{}, fmt.Errorf("validation failed: %w", err)
}

if result.valid {
response := map[string]any{
"valid": true,
"operation_type": result.operationType,
"operation_name": result.operationName,
"depth": result.depth,
}
jsonData, _ := json.MarshalIndent(response, "", " ")
return mcp.NewToolResultText(string(jsonData)), nil
return ValidateGraphQLOutput{
Valid: true,
OperationType: result.operationType,
OperationName: result.operationName,
Depth: result.depth,
}, nil
}

response := map[string]any{
"valid": false,
"error": result.error,
}
jsonData, _ := json.MarshalIndent(response, "", " ")
return mcp.NewToolResultText(string(jsonData)), nil
return ValidateGraphQLOutput{
Valid: false,
Error: result.error,
}, nil
}

type queryValidationResult struct {
Expand Down
Loading