-
Notifications
You must be signed in to change notification settings - Fork 20
Add some flags to example, fix version issue #127
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,14 @@ import 'package:cli_util/cli_logging.dart'; | |
import 'package:dart_mcp/client.dart'; | ||
import 'package:google_generative_ai/google_generative_ai.dart' as gemini; | ||
|
||
/// The list of Gemini models that are accepted as a "--model" argument. | ||
/// Defaults to the first one in the list. | ||
const List<String> allowedGeminiModels = [ | ||
'gemini-2.5-pro-exp-03-25', | ||
'gemini-2.0-flash', | ||
'gemini-2.5-flash-preview-04-17', | ||
]; | ||
|
||
void main(List<String> args) { | ||
final geminiApiKey = Platform.environment['GEMINI_API_KEY']; | ||
if (geminiApiKey == null) { | ||
|
@@ -23,21 +31,28 @@ void main(List<String> args) { | |
} | ||
|
||
final parsedArgs = argParser.parse(args); | ||
if (parsedArgs.wasParsed('help')) { | ||
print(argParser.usage); | ||
exit(0); | ||
} | ||
final serverCommands = parsedArgs['server'] as List<String>; | ||
final logger = Logger.standard(); | ||
final logFilePath = parsedArgs.option('log'); | ||
runZonedGuarded( | ||
() { | ||
WorkflowClient( | ||
serverCommands, | ||
geminiApiKey: geminiApiKey, | ||
verbose: parsedArgs.flag('verbose'), | ||
dtdUri: parsedArgs.option('dtd'), | ||
model: parsedArgs.option('model')!, | ||
persona: parsedArgs.flag('dash') ? _dashPersona : null, | ||
logger: logger, | ||
logFile: logFilePath != null ? File(logFilePath) : null, | ||
); | ||
}, | ||
(e, s) { | ||
logger.stderr('$e\n$s\n'); | ||
(exception, stack) { | ||
logger.stderr('$exception\n$stack\n'); | ||
}, | ||
); | ||
} | ||
|
@@ -54,28 +69,38 @@ final argParser = | |
abbr: 'v', | ||
help: 'Enables verbose logging for logs from servers.', | ||
) | ||
..addOption( | ||
'log', | ||
abbr: 'l', | ||
help: | ||
'If specified, will create the given log file and log server ' | ||
'traffic and diagnostic messages.', | ||
) | ||
..addFlag('dash', help: 'Use the Dash mascot persona.', defaultsTo: false) | ||
..addOption( | ||
'dtd', | ||
help: 'Pass the DTD URI to use for this workflow session.', | ||
); | ||
) | ||
..addOption( | ||
'model', | ||
defaultsTo: allowedGeminiModels.first, | ||
allowed: allowedGeminiModels, | ||
help: 'Pass the name of the model to use to run inferences.', | ||
) | ||
..addFlag('help', abbr: 'h', help: 'Print the usage for this command.'); | ||
|
||
final class WorkflowClient extends MCPClient with RootsSupport { | ||
final Logger logger; | ||
int totalInputTokens = 0; | ||
int totalOutputTokens = 0; | ||
|
||
WorkflowClient( | ||
this.serverCommands, { | ||
required String geminiApiKey, | ||
required String model, | ||
required this.logger, | ||
String? dtdUri, | ||
this.verbose = false, | ||
required this.logger, | ||
String? persona, | ||
File? logFile, | ||
}) : model = gemini.GenerativeModel( | ||
model: 'gemini-2.5-pro-preview-03-25', | ||
// model: 'gemini-2.0-flash', | ||
// model: 'gemini-2.5-flash-preview-04-17', | ||
model: model, | ||
apiKey: geminiApiKey, | ||
systemInstruction: systemInstructions(persona: persona), | ||
), | ||
|
@@ -85,6 +110,7 @@ final class WorkflowClient extends MCPClient with RootsSupport { | |
super( | ||
ClientImplementation(name: 'Gemini workflow client', version: '0.1.0'), | ||
) { | ||
logSink = _createLogSink(logFile); | ||
addRoot( | ||
Root( | ||
uri: Directory.current.absolute.uri.toString(), | ||
|
@@ -110,6 +136,10 @@ final class WorkflowClient extends MCPClient with RootsSupport { | |
_startChat(); | ||
} | ||
|
||
final Logger logger; | ||
Sink<String>? logSink; | ||
int totalInputTokens = 0; | ||
int totalOutputTokens = 0; | ||
final StreamQueue<String> stdinQueue; | ||
final List<String> serverCommands; | ||
final List<ServerConnection> serverConnections = []; | ||
|
@@ -118,6 +148,36 @@ final class WorkflowClient extends MCPClient with RootsSupport { | |
final gemini.GenerativeModel model; | ||
final bool verbose; | ||
|
||
Sink<String>? _createLogSink(File? logFile) { | ||
if (logFile == null) { | ||
return null; | ||
} | ||
Sink<String>? logSink; | ||
logFile.createSync(recursive: true); | ||
final fileByteSink = logFile.openWrite( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to confirm, this will actually write the bytes to disk relatively quickly when we write to the sink? Do we need to call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm. It will be buffered by default, but I guess we do want it to be pretty up-to-date. I don't think I can turn off the buffering on the file descriptor. I added a call to flush each time something is added to the sink. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When I did this previously I just did regular writes in append mode for each line that came in, instead of keeping it open like this. In theory this is better though, but we do want to actually see the logs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Flushing should do that, and this is only in the example. I did look at doing the |
||
mode: FileMode.write, | ||
encoding: utf8, | ||
); | ||
logSink = fileByteSink.transform<String>( | ||
StreamSinkTransformer.fromHandlers( | ||
handleData: (String data, EventSink<List<int>> innerSink) { | ||
innerSink.add(utf8.encode(data)); | ||
}, | ||
handleError: ( | ||
Object error, | ||
StackTrace stackTrace, | ||
EventSink<List<int>> innerSink, | ||
) { | ||
innerSink.addError(error, stackTrace); | ||
}, | ||
handleDone: (EventSink<List<int>> innerSink) { | ||
innerSink.close(); | ||
}, | ||
), | ||
); | ||
return logSink; | ||
} | ||
|
||
void _startChat() async { | ||
if (serverCommands.isNotEmpty) { | ||
await _connectToServers(); | ||
|
@@ -317,6 +377,10 @@ final class WorkflowClient extends MCPClient with RootsSupport { | |
/// Invokes a function and adds the result as context to the chat history. | ||
Future<void> _handleFunctionCall(gemini.FunctionCall functionCall) async { | ||
chatHistory.add(gemini.Content.model([functionCall])); | ||
logSink?.add( | ||
gspencergoog marked this conversation as resolved.
Show resolved
Hide resolved
|
||
'+++ Calling function ${functionCall.name} with args: ' | ||
'${functionCall.args}\n', | ||
); | ||
final connection = connectionForFunction[functionCall.name]!; | ||
final result = await connection.callTool( | ||
CallToolRequest(name: functionCall.name, arguments: functionCall.args), | ||
|
@@ -349,7 +413,11 @@ final class WorkflowClient extends MCPClient with RootsSupport { | |
final parts = server.split(' '); | ||
try { | ||
serverConnections.add( | ||
await connectStdioServer(parts.first, parts.skip(1).toList()), | ||
await connectStdioServer( | ||
parts.first, | ||
parts.skip(1).toList(), | ||
protocolLogSink: logSink, | ||
), | ||
); | ||
} catch (e) { | ||
logger.stderr('Failed to connect to server $server: $e'); | ||
|
@@ -370,10 +438,11 @@ final class WorkflowClient extends MCPClient with RootsSupport { | |
), | ||
); | ||
final serverName = connection.serverInfo?.name ?? 'server'; | ||
if (result.protocolVersion != ProtocolVersion.latestSupported) { | ||
if (!result.protocolVersion!.isSupported) { | ||
logger.stderr( | ||
'Protocol version mismatch for $serverName, ' | ||
'expected ${ProtocolVersion.latestSupported}, got ' | ||
'expected a version between ${ProtocolVersion.oldestSupported} and ' | ||
'${ProtocolVersion.latestSupported}, but got ' | ||
'${result.protocolVersion}. Disconnecting.', | ||
); | ||
await connection.shutdown(); | ||
|
Uh oh!
There was an error while loading. Please reload this page.