Clear context after completing a workflow #117

Merged 8 commits on May 8, 2025
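
In short, this change teaches the example workflow client to trim its chat history once a workflow finishes: the model is instructed to end plan execution with the exact phrase "Workflow complete" followed by a summary, the per-workflow messages are then removed from the history, and only that summary is carried forward as context for future tasks. A minimal sketch of the trimming pattern, using plain strings in place of the gemini.Content objects the real client stores (the diff below does the same bookkeeping with historyStartIndex and chatHistory.removeRange):

// Sketch: per-workflow context trimming, as introduced in this PR.
// Plain strings stand in for the gemini.Content objects the real
// client stores; the index bookkeeping is the same.
void main() {
  final chatHistory = <String>['<intro exchange>'];

  // One iteration of the client's main loop.
  final historyStartIndex = chatHistory.length; // where this workflow starts
  chatHistory.addAll([
    'user: <task prompt>',
    'model: <plan>',
    'model: <tool calls and step-by-step progress>',
  ]);
  const summary = 'Workflow complete. <summary of everything done>';

  // Drop everything this workflow added to the history...
  if (historyStartIndex < chatHistory.length) {
    chatHistory.removeRange(historyStartIndex, chatHistory.length);
  }
  // ...and keep only the model's summary for future tasks.
  chatHistory.add('model: $summary');

  print(chatHistory); // [<intro exchange>, model: Workflow complete. ...]
}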
pkgs/dart_mcp/example/workflow_client.dart (168 changes: 101 additions & 67 deletions)
@@ -37,7 +37,7 @@ void main(List<String> args) {
);
},
(e, s) {
logger.stderr('$e\n$s');
logger.stderr('$e\n$s\n');
},
);
}
@@ -95,8 +95,8 @@ final class WorkflowClient extends MCPClient with RootsSupport {
gemini.Content.text(
'The current working directory is '
'${Directory.current.absolute.uri.toString()}. Convert all relative '
'URIs to absolute using this root. For tools that want a root, use this'
'URI.',
'URIs to absolute using this root. For tools that want a root, use '
'this URI.',
),
);
if (dtdUri != null) {
@@ -132,26 +132,50 @@ final class WorkflowClient extends MCPClient with RootsSupport {
context: chatHistory,
tools: serverTools,
);
_handleModelResponse(introResponse);
await _handleModelResponse(introResponse);

while (true) {
final next = await _waitForInputAndAddToHistory();
await _makeAndExecutePlan(next, serverTools);

// Remember where the history starts for this workflow
final historyStartIndex = chatHistory.length;
final summary = await _makeAndExecutePlan(next, serverTools);

// Workflow/Plan execution finished, now summarize and clean up context.
if (historyStartIndex < chatHistory.length) {
// Remove the entire history.
chatHistory.removeRange(historyStartIndex, chatHistory.length);
}

// Add the summary to the chat history.
await _handleModelResponse(summary);
}
}

void _handleModelResponse(gemini.Content response) {
/// Handles a response from the [model].
///
/// If this function returns a [String], then it should be fed back into the
/// model as a user message in order to continue the conversation.
Future<String?> _handleModelResponse(gemini.Content response) async {
String? continuation;
for (var part in response.parts) {
switch (part) {
case gemini.TextPart():
_chatToUser(part.text);
case gemini.FunctionCall():
await _handleFunctionCall(part);
continuation = 'Please proceed to the next step of the plan.';
default:
logger.stderr('Unrecognized response type from the model $response');
logger.stderr(
'Unrecognized response type from the model: $response.',
);
}
}
return continuation;
}

Future<void> _makeAndExecutePlan(
/// Executes a plan and returns a summary of it.
Future<gemini.Content> _makeAndExecutePlan(
String userPrompt,
List<gemini.Tool> serverTools, {
bool editPreviousPlan = false,
@@ -161,35 +185,37 @@ final class WorkflowClient extends MCPClient with RootsSupport {
? 'Edit the previous plan with the following changes:'
: 'Create a new plan for the following task:';
final planPrompt =
'$instruction\n$userPrompt. After you have made a plan, ask the user '
'if they wish to proceed or if they want to make any changes to your '
'plan.';
'$instruction\n$userPrompt\n\n After you have made a '
'plan, ask the user if they wish to proceed or if they want to make '
'any changes to your plan.';
_addToHistory(planPrompt);

final planResponse = await _generateContent(
context: chatHistory,
tools: serverTools,
);
_handleModelResponse(planResponse);
await _handleModelResponse(planResponse);

final userResponse = await _waitForInputAndAddToHistory();
final wasApproval = await _analyzeSentiment(userResponse);
if (!wasApproval) {
await _makeAndExecutePlan(
userResponse,
serverTools,
editPreviousPlan: true,
);
} else {
await _executePlan(serverTools);
}
return wasApproval
? await _executePlan(serverTools)
: await _makeAndExecutePlan(
userResponse,
serverTools,
editPreviousPlan: true,
);
}

Future<void> _executePlan(List<gemini.Tool> serverTools) async {
/// Executes a plan and returns a summary of it.
Future<gemini.Content> _executePlan(List<gemini.Tool> serverTools) async {
// If assigned then it is used as the next input from the user
// instead of reading from stdin.
String? continuation =
'Execute the plan. After each step of the plan, report your progress.';
'Execute the plan. After each step of the plan, report your progress. '
'When you are completely done executing the plan, say exactly '
'"Workflow complete" followed by a summary of everything that was done '
'so you can remember it for future tasks.';

while (true) {
final nextMessage = continuation ?? await stdinQueue.next;
@@ -199,29 +225,13 @@
context: chatHistory,
tools: serverTools,
);

for (var part in modelResponse.parts) {
switch (part) {
case gemini.TextPart():
_chatToUser(part.text);
case gemini.FunctionCall():
final result = await _handleFunctionCall(part);
if (result == null ||
result.contains('unsupported response type')) {
_chatToUser(
'Something went wrong when trying to call the ${part.name} '
'function. Proceeding to next step of the plan.',
);
}
continuation =
'$result\n. Please proceed to the next step of the plan.';

default:
logger.stderr(
'Unrecognized response type from the model: $modelResponse.',
);
if (modelResponse.parts.first case final gemini.TextPart text) {
if (text.text.toLowerCase().contains('workflow complete')) {
return modelResponse;
}
}

continuation = await _handleModelResponse(modelResponse);
}
}

@@ -238,14 +248,17 @@ final class WorkflowClient extends MCPClient with RootsSupport {
/// Analyzes a user [message] to see if it looks like they approved of the
/// previous action.
Future<bool> _analyzeSentiment(String message) async {
if (message == 'y' || message == 'yes') return true;
if (message.toLowerCase() == 'y' || message.toLowerCase() == 'yes') {
return true;
}
final sentimentResult = await _generateContent(
context: [
gemini.Content.text(
'Analyze the sentiment of the following response. If the response '
'indicates a need for any changes, then this is not an approval. '
'If you are highly confident that the user approves of running the '
'previous action then respond with a single character "y".',
'previous action then respond with a single character "y". '
'Otherwise respond with "n".',
),
gemini.Content.text(message),
],
@@ -254,7 +267,7 @@ final class WorkflowClient extends MCPClient with RootsSupport {
for (var part in sentimentResult.parts.whereType<gemini.TextPart>()) {
response.write(part.text.trim());
}
return response.toString() == 'y';
return response.toString().toLowerCase() == 'y';
}

Future<gemini.Content> _generateContent({
@@ -266,6 +279,8 @@ final class WorkflowClient extends MCPClient with RootsSupport {
try {
response = await model.generateContent(context, tools: tools);
return response.candidates.single.content;
} on gemini.GenerativeAIException catch (e) {
return gemini.Content.model([gemini.TextPart('Error: $e')]);
} finally {
if (response != null) {
final inputTokens = response.usageMetadata?.promptTokenCount;
@@ -298,19 +313,16 @@ final class WorkflowClient extends MCPClient with RootsSupport {
}

/// Handles a function call response from the model.
Future<String?> _handleFunctionCall(gemini.FunctionCall functionCall) async {
_chatToUser(
'I am going to run the ${functionCall.name} tool'
'${verbose ? ' with args ${jsonEncode(functionCall.args)}' : ''} to '
'perform this task.',
);

///
/// Invokes a function and adds the result as context to the chat history.
Future<void> _handleFunctionCall(gemini.FunctionCall functionCall) async {
chatHistory.add(gemini.Content.model([functionCall]));
final connection = connectionForFunction[functionCall.name]!;
final result = await connection.callTool(
CallToolRequest(name: functionCall.name, arguments: functionCall.args),
);
final response = StringBuffer();

for (var content in result.content) {
switch (content) {
case final TextContent content when content.isText:
@@ -324,34 +336,45 @@ final class WorkflowClient extends MCPClient with RootsSupport {
response.writeln('Got unsupported response type ${content.type}');
}
}
return response.toString();
chatHistory.add(
gemini.Content.functionResponse(functionCall.name, {
'output': response.toString(),
}),
);
}

/// Connects to all servers using [serverCommands].
Future<void> _connectToServers() async {
for (var server in serverCommands) {
final parts = server.split(' ');
serverConnections.add(
await connectStdioServer(parts.first, parts.skip(1).toList()),
);
try {
serverConnections.add(
await connectStdioServer(parts.first, parts.skip(1).toList()),
);
} catch (e) {
logger.stderr('Failed to connect to server $server: $e');
}
}
}

/// Initialization handshake.
Future<void> _initializeServers() async {
for (var connection in serverConnections) {
// Use a copy of the list to allow removal during iteration
final connectionsToInitialize = List.of(serverConnections);
for (var connection in connectionsToInitialize) {
final result = await connection.initialize(
InitializeRequest(
protocolVersion: ProtocolVersion.latestSupported,
capabilities: capabilities,
clientInfo: implementation,
),
);
final serverName = connection.serverInfo?.name ?? 'server';
if (result.protocolVersion != ProtocolVersion.latestSupported) {
logger.stderr(
'Protocol version mismatch, expected '
'${ProtocolVersion.latestSupported}, got ${result.protocolVersion}, '
'disconnecting from server',
'Protocol version mismatch for $serverName, '
'expected ${ProtocolVersion.latestSupported}, got '
'${result.protocolVersion}. Disconnecting.',
);
await connection.shutdown();
serverConnections.remove(connection);
Expand All @@ -374,8 +397,9 @@ final class WorkflowClient extends MCPClient with RootsSupport {
),
);
connection.onLog.listen((event) {
final logServerName = connection.serverInfo?.name ?? '?';
logger.stdout(
'Server Log(${event.level.name}): '
'Server Log ($logServerName/${event.level.name}): '
'${event.logger != null ? '[${event.logger}] ' : ''}${event.data}',
);
});
@@ -386,7 +410,8 @@ final class WorkflowClient extends MCPClient with RootsSupport {
Future<List<gemini.Tool>> _listServerCapabilities() async {
final functions = <gemini.FunctionDeclaration>[];
for (var connection in serverConnections) {
for (var tool in (await connection.listTools()).tools) {
final response = await connection.listTools();
for (var tool in response.tools) {
functions.add(
gemini.FunctionDeclaration(
tool.name,
Expand All @@ -397,7 +422,9 @@ final class WorkflowClient extends MCPClient with RootsSupport {
connectionForFunction[tool.name] = connection;
}
}
return [gemini.Tool(functionDeclarations: functions)];
return functions.isEmpty
? []
: [gemini.Tool(functionDeclarations: functions)];
}

gemini.Schema _schemaToGeminiSchema(Schema inputSchema, {bool? nullable}) {
@@ -412,7 +439,7 @@ final class WorkflowClient extends MCPClient with RootsSupport {
for (var entry in originalProperties.entries)
entry.key: _schemaToGeminiSchema(
entry.value,
nullable: objectSchema.required?.contains(entry.key),
nullable: objectSchema.required?.contains(entry.key) ?? false,
),
};
}
@@ -428,9 +455,16 @@ final class WorkflowClient extends MCPClient with RootsSupport {
);
case JsonType.list:
final listSchema = inputSchema as ListSchema;
final itemSchema =
listSchema.items == null
?
// A bit of a hack here, gemini requires item schemas, just fall
// back on string.
gemini.Schema.string()
: _schemaToGeminiSchema(listSchema.items!);
return gemini.Schema.array(
description: description,
items: _schemaToGeminiSchema(listSchema.items!),
items: itemSchema,
nullable: nullable,
);
case JsonType.num:
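
A second structural change worth noting from the diff above: _handleFunctionCall no longer returns the tool output as a string for the caller to stitch into a prompt. It now appends the model's FunctionCall and a matching Content.functionResponse to the chat history, the paired shape the Gemini function-calling API expects. A rough sketch with stand-in record types rather than the package:google_generative_ai ones:

// Sketch of the new tool-call bookkeeping (stand-in records; the real
// code uses gemini.Content.model and gemini.Content.functionResponse).
typedef HistoryEntry = ({String role, String payload});

Future<void> main() async {
  final chatHistory = <HistoryEntry>[];
  await handleFunctionCall(chatHistory, 'run_tests', '{"path": "."}');
  print(chatHistory);
}

Future<void> handleFunctionCall(
  List<HistoryEntry> chatHistory,
  String name,
  String args,
) async {
  // 1. Record the model's request to call the tool.
  chatHistory.add((role: 'model', payload: 'functionCall $name($args)'));

  // 2. Invoke the MCP tool (stubbed here with a canned result).
  final output = '<tool output for $name>';

  // 3. Record the output as a function response so the next
  //    generateContent call sees the complete call/response pair.
  chatHistory.add((role: 'function', payload: '$name -> {output: $output}'));
}

Keeping the call/response pair adjacent in the history is what lets the follow-up _generateContent call treat the tool output as model context rather than as user text.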