diff --git a/pkgs/dart_mcp/example/workflow_client.dart b/pkgs/dart_mcp/example/workflow_client.dart index a7fa4590..2e6c518d 100644 --- a/pkgs/dart_mcp/example/workflow_client.dart +++ b/pkgs/dart_mcp/example/workflow_client.dart @@ -37,7 +37,7 @@ void main(List args) { ); }, (e, s) { - logger.stderr('$e\n$s'); + logger.stderr('$e\n$s\n'); }, ); } @@ -95,8 +95,8 @@ final class WorkflowClient extends MCPClient with RootsSupport { gemini.Content.text( 'The current working directory is ' '${Directory.current.absolute.uri.toString()}. Convert all relative ' - 'URIs to absolute using this root. For tools that want a root, use this' - 'URI.', + 'URIs to absolute using this root. For tools that want a root, use ' + 'this URI.', ), ); if (dtdUri != null) { @@ -132,26 +132,50 @@ final class WorkflowClient extends MCPClient with RootsSupport { context: chatHistory, tools: serverTools, ); - _handleModelResponse(introResponse); + await _handleModelResponse(introResponse); while (true) { final next = await _waitForInputAndAddToHistory(); - await _makeAndExecutePlan(next, serverTools); + + // Remember where the history starts for this workflow + final historyStartIndex = chatHistory.length; + final summary = await _makeAndExecutePlan(next, serverTools); + + // Workflow/Plan execution finished, now summarize and clean up context. + if (historyStartIndex < chatHistory.length) { + // Remove the entire history. + chatHistory.removeRange(historyStartIndex, chatHistory.length); + } + + // Add the summary to the chat history. + await _handleModelResponse(summary); } } - void _handleModelResponse(gemini.Content response) { + /// Handles a response from the [model]. + /// + /// If this function returns a [String], then it should be fed back into the + /// model as a user message in order to continue the conversation. + Future _handleModelResponse(gemini.Content response) async { + String? continuation; for (var part in response.parts) { switch (part) { case gemini.TextPart(): _chatToUser(part.text); + case gemini.FunctionCall(): + await _handleFunctionCall(part); + continuation = 'Please proceed to the next step of the plan.'; default: - logger.stderr('Unrecognized response type from the model $response'); + logger.stderr( + 'Unrecognized response type from the model: $response.', + ); } } + return continuation; } - Future _makeAndExecutePlan( + /// Executes a plan and returns a summary of it. + Future _makeAndExecutePlan( String userPrompt, List serverTools, { bool editPreviousPlan = false, @@ -161,35 +185,37 @@ final class WorkflowClient extends MCPClient with RootsSupport { ? 'Edit the previous plan with the following changes:' : 'Create a new plan for the following task:'; final planPrompt = - '$instruction\n$userPrompt. After you have made a plan, ask the user ' - 'if they wish to proceed or if they want to make any changes to your ' - 'plan.'; + '$instruction\n$userPrompt\n\n After you have made a ' + 'plan, ask the user if they wish to proceed or if they want to make ' + 'any changes to your plan.'; _addToHistory(planPrompt); final planResponse = await _generateContent( context: chatHistory, tools: serverTools, ); - _handleModelResponse(planResponse); + await _handleModelResponse(planResponse); final userResponse = await _waitForInputAndAddToHistory(); final wasApproval = await _analyzeSentiment(userResponse); - if (!wasApproval) { - await _makeAndExecutePlan( - userResponse, - serverTools, - editPreviousPlan: true, - ); - } else { - await _executePlan(serverTools); - } + return wasApproval + ? await _executePlan(serverTools) + : await _makeAndExecutePlan( + userResponse, + serverTools, + editPreviousPlan: true, + ); } - Future _executePlan(List serverTools) async { + /// Executes a plan and returns a summary of it. + Future _executePlan(List serverTools) async { // If assigned then it is used as the next input from the user // instead of reading from stdin. String? continuation = - 'Execute the plan. After each step of the plan, report your progress.'; + 'Execute the plan. After each step of the plan, report your progress. ' + 'When you are completely done executing the plan, say exactly ' + '"Workflow complete" followed by a summary of everything that was done ' + 'so you can remember it for future tasks.'; while (true) { final nextMessage = continuation ?? await stdinQueue.next; @@ -199,29 +225,13 @@ final class WorkflowClient extends MCPClient with RootsSupport { context: chatHistory, tools: serverTools, ); - - for (var part in modelResponse.parts) { - switch (part) { - case gemini.TextPart(): - _chatToUser(part.text); - case gemini.FunctionCall(): - final result = await _handleFunctionCall(part); - if (result == null || - result.contains('unsupported response type')) { - _chatToUser( - 'Something went wrong when trying to call the ${part.name} ' - 'function. Proceeding to next step of the plan.', - ); - } - continuation = - '$result\n. Please proceed to the next step of the plan.'; - - default: - logger.stderr( - 'Unrecognized response type from the model: $modelResponse.', - ); + if (modelResponse.parts.first case final gemini.TextPart text) { + if (text.text.toLowerCase().contains('workflow complete')) { + return modelResponse; } } + + continuation = await _handleModelResponse(modelResponse); } } @@ -238,14 +248,17 @@ final class WorkflowClient extends MCPClient with RootsSupport { /// Analyzes a user [message] to see if it looks like they approved of the /// previous action. Future _analyzeSentiment(String message) async { - if (message == 'y' || message == 'yes') return true; + if (message.toLowerCase() == 'y' || message.toLowerCase() == 'yes') { + return true; + } final sentimentResult = await _generateContent( context: [ gemini.Content.text( 'Analyze the sentiment of the following response. If the response ' 'indicates a need for any changes, then this is not an approval. ' 'If you are highly confident that the user approves of running the ' - 'previous action then respond with a single character "y".', + 'previous action then respond with a single character "y". ' + 'Otherwise respond with "n".', ), gemini.Content.text(message), ], @@ -254,7 +267,7 @@ final class WorkflowClient extends MCPClient with RootsSupport { for (var part in sentimentResult.parts.whereType()) { response.write(part.text.trim()); } - return response.toString() == 'y'; + return response.toString().toLowerCase() == 'y'; } Future _generateContent({ @@ -266,6 +279,8 @@ final class WorkflowClient extends MCPClient with RootsSupport { try { response = await model.generateContent(context, tools: tools); return response.candidates.single.content; + } on gemini.GenerativeAIException catch (e) { + return gemini.Content.model([gemini.TextPart('Error: $e')]); } finally { if (response != null) { final inputTokens = response.usageMetadata?.promptTokenCount; @@ -298,19 +313,16 @@ final class WorkflowClient extends MCPClient with RootsSupport { } /// Handles a function call response from the model. - Future _handleFunctionCall(gemini.FunctionCall functionCall) async { - _chatToUser( - 'I am going to run the ${functionCall.name} tool' - '${verbose ? ' with args ${jsonEncode(functionCall.args)}' : ''} to ' - 'perform this task.', - ); - + /// + /// Invokes a function and adds the result as context to the chat history. + Future _handleFunctionCall(gemini.FunctionCall functionCall) async { chatHistory.add(gemini.Content.model([functionCall])); final connection = connectionForFunction[functionCall.name]!; final result = await connection.callTool( CallToolRequest(name: functionCall.name, arguments: functionCall.args), ); final response = StringBuffer(); + for (var content in result.content) { switch (content) { case final TextContent content when content.isText: @@ -324,22 +336,32 @@ final class WorkflowClient extends MCPClient with RootsSupport { response.writeln('Got unsupported response type ${content.type}'); } } - return response.toString(); + chatHistory.add( + gemini.Content.functionResponse(functionCall.name, { + 'output': response.toString(), + }), + ); } /// Connects to all servers using [serverCommands]. Future _connectToServers() async { for (var server in serverCommands) { final parts = server.split(' '); - serverConnections.add( - await connectStdioServer(parts.first, parts.skip(1).toList()), - ); + try { + serverConnections.add( + await connectStdioServer(parts.first, parts.skip(1).toList()), + ); + } catch (e) { + logger.stderr('Failed to connect to server $server: $e'); + } } } /// Initialization handshake. Future _initializeServers() async { - for (var connection in serverConnections) { + // Use a copy of the list to allow removal during iteration + final connectionsToInitialize = List.of(serverConnections); + for (var connection in connectionsToInitialize) { final result = await connection.initialize( InitializeRequest( protocolVersion: ProtocolVersion.latestSupported, @@ -347,11 +369,12 @@ final class WorkflowClient extends MCPClient with RootsSupport { clientInfo: implementation, ), ); + final serverName = connection.serverInfo?.name ?? 'server'; if (result.protocolVersion != ProtocolVersion.latestSupported) { logger.stderr( - 'Protocol version mismatch, expected ' - '${ProtocolVersion.latestSupported}, got ${result.protocolVersion}, ' - 'disconnecting from server', + 'Protocol version mismatch for $serverName, ' + 'expected ${ProtocolVersion.latestSupported}, got ' + '${result.protocolVersion}. Disconnecting.', ); await connection.shutdown(); serverConnections.remove(connection); @@ -374,8 +397,9 @@ final class WorkflowClient extends MCPClient with RootsSupport { ), ); connection.onLog.listen((event) { + final logServerName = connection.serverInfo?.name ?? '?'; logger.stdout( - 'Server Log(${event.level.name}): ' + 'Server Log ($logServerName/${event.level.name}): ' '${event.logger != null ? '[${event.logger}] ' : ''}${event.data}', ); }); @@ -386,7 +410,8 @@ final class WorkflowClient extends MCPClient with RootsSupport { Future> _listServerCapabilities() async { final functions = []; for (var connection in serverConnections) { - for (var tool in (await connection.listTools()).tools) { + final response = await connection.listTools(); + for (var tool in response.tools) { functions.add( gemini.FunctionDeclaration( tool.name, @@ -397,7 +422,9 @@ final class WorkflowClient extends MCPClient with RootsSupport { connectionForFunction[tool.name] = connection; } } - return [gemini.Tool(functionDeclarations: functions)]; + return functions.isEmpty + ? [] + : [gemini.Tool(functionDeclarations: functions)]; } gemini.Schema _schemaToGeminiSchema(Schema inputSchema, {bool? nullable}) { @@ -412,7 +439,7 @@ final class WorkflowClient extends MCPClient with RootsSupport { for (var entry in originalProperties.entries) entry.key: _schemaToGeminiSchema( entry.value, - nullable: objectSchema.required?.contains(entry.key), + nullable: objectSchema.required?.contains(entry.key) ?? false, ), }; } @@ -428,9 +455,16 @@ final class WorkflowClient extends MCPClient with RootsSupport { ); case JsonType.list: final listSchema = inputSchema as ListSchema; + final itemSchema = + listSchema.items == null + ? + // A bit of a hack here, gemini requires item schemas, just fall + // back on string. + gemini.Schema.string() + : _schemaToGeminiSchema(listSchema.items!); return gemini.Schema.array( description: description, - items: _schemaToGeminiSchema(listSchema.items!), + items: itemSchema, nullable: nullable, ); case JsonType.num: