feat: save and restore a context sequence state #460

Merged
38 commits merged on May 17, 2025
Changes from 1 commit
Commits (38):
11b5404
fix: adapt to breaking `llama.cpp` changes
giladgd May 11, 2025
8b98cf0
fix: improve GPU backend loading error description
giladgd May 11, 2025
1e8111c
chore: update template dependencies
giladgd May 11, 2025
2f9858a
test: Qwen 3 template
giladgd May 11, 2025
4c6e2b1
feat: configure Hugging Face remote endpoint for resolving URIs
giladgd May 11, 2025
d39d261
fix: race condition when reading extremely long gguf metadata
giladgd May 11, 2025
e740078
docs: typo
giladgd May 11, 2025
d6e852e
fix: update gguf types
giladgd May 11, 2025
9ab3c6d
fix: capture multi-token segment separators
giladgd May 11, 2025
656f2be
docs: solutions to more CUDA issues
giladgd May 11, 2025
6926425
feat: stream function call parameters
giladgd May 11, 2025
b369eaf
docs: update the awesome list
giladgd May 11, 2025
72c30dc
chore: update modules
giladgd May 11, 2025
df05d70
docs: more clear default values for custom cmake options
giladgd May 11, 2025
b3d510e
chore: reorder Vitepress config keys
giladgd May 11, 2025
3233603
fix: update gguf types
giladgd May 11, 2025
96c78da
docs: document new env vars
giladgd May 11, 2025
f7063d8
chore: module versions
giladgd May 12, 2025
123e524
chore: update GitHub issue templates
giladgd May 12, 2025
53a5206
test: check recommended model URIs
giladgd May 13, 2025
2e1a7ce
test: fix tests
giladgd May 14, 2025
9463ccc
feat(`QwenChatWrapper`): support discouraging the generation of thoughts
giladgd May 15, 2025
631a7e7
test: fix tests
giladgd May 15, 2025
a0cc198
feat: save and restore context sequence state
giladgd May 15, 2025
185b734
docs: save and restore context sequence state
giladgd May 15, 2025
d36670c
fix: adapt memory estimation to new added model architectures
giladgd May 15, 2025
a68590a
feat(`getLlama`): `dryRun` option
giladgd May 16, 2025
8c6134d
feat: `getLlamaGpuTypes` to get the list of available GPU types for t…
giladgd May 16, 2025
71babfa
fix: skip binary testing on certain problematic conditions
giladgd May 16, 2025
12cec69
docs: fix dead link
giladgd May 16, 2025
de3a360
fix: Paperspace tests setup script nodejs version
giladgd May 16, 2025
8eff306
fix: Windows build
giladgd May 17, 2025
f76e899
fix: types
giladgd May 17, 2025
0cbb572
test: fix tests
giladgd May 17, 2025
2c01084
fix: performance improvements
giladgd May 17, 2025
5d4c8c3
fix: remove unused files from the build dir
giladgd May 17, 2025
69d30cd
fix: remove unused line
giladgd May 17, 2025
62c8020
fix: performance improvements
giladgd May 17, 2025
feat: stream function call parameters
giladgd committed May 11, 2025
commit 6926425c555778f68c08ac119dab8a99d334bc57
80 changes: 76 additions & 4 deletions src/evaluator/LlamaChat/LlamaChat.ts
@@ -105,6 +105,37 @@ export type LlamaChatResponseSegmentChunk = {
segmentEndTime?: Date
};

export type LlamaChatResponseFunctionCallParamsChunk = {
/**
* Each different function call has a different `callIndex`.
*
* When the previous function call has finished being generated, the `callIndex` of the next one will increment.
*
* Use this value to distinguish between different function calls.
*/
callIndex: number,

/**
* The name of the function being called
*/
functionName: string,

/**
* A chunk of the generated text used for the function call parameters.
*
* Collect all the chunks together to construct the full function call parameters.
*
* After the function call is finished, the entire constructed params text can be parsed as a JSON object,
* according to the function parameters schema.
*/
paramsChunk: string,

/**
* When this is `true`, the current chunk is the last chunk in the generation of the current function call parameters.
*/
done: boolean
};

export type LLamaChatGenerateResponseOptions<Functions extends ChatModelFunctions | undefined = undefined> = {
/**
* Called as the model generates the main response with the generated text chunk.
@@ -253,15 +284,32 @@ export type LLamaChatGenerateResponseOptions<Functions extends ChatModelFunction
functions?: never,
documentFunctionParams?: never,
maxParallelFunctionCalls?: never,
onFunctionCall?: never
onFunctionCall?: never,
onFunctionCallParamsChunk?: never
} | {
grammar?: never,
functions?: Functions | ChatModelFunctions,
documentFunctionParams?: boolean,
maxParallelFunctionCalls?: number,
onFunctionCall?: (
functionCall: LlamaChatResponseFunctionCall<Functions extends ChatModelFunctions ? Functions : ChatModelFunctions>
) => void
) => void,

/**
* Called as the model generates function calls with the generated parameters chunk for each function call.
*
* Useful for streaming the generated function call parameters as they're being generated.
* Only useful in specific use cases,
* such as showing the textual content of a generated file as it's being generated (note that doing this requires parsing incomplete JSON).
*
* The constructed text from all the params chunks of a given function call can be parsed as a JSON object,
* according to the function parameters schema.
*
* Each function call has its own `callIndex` you can use to distinguish between them.
*
* Only relevant when using function calling (via passing the `functions` option).
*/
onFunctionCallParamsChunk?: (chunk: LlamaChatResponseFunctionCallParamsChunk) => void
});

export type LLamaChatLoadAndCompleteUserMessageOptions<Functions extends ChatModelFunctions | undefined = undefined> = {
@@ -465,6 +513,7 @@ export class LlamaChat {
onTextChunk,
onToken,
onResponseChunk,
onFunctionCallParamsChunk,
signal,
stopOnAbortSignal = false,
maxTokens,
@@ -501,6 +550,7 @@
onTextChunk,
onToken,
onResponseChunk,
onFunctionCallParamsChunk,
signal,
stopOnAbortSignal,
maxTokens,
@@ -1433,6 +1483,7 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
private readonly onTextChunk: LLamaChatGenerateResponseOptions<Functions>["onTextChunk"];
private readonly onToken: LLamaChatGenerateResponseOptions<Functions>["onToken"];
private readonly onResponseChunk: LLamaChatGenerateResponseOptions<Functions>["onResponseChunk"];
private readonly onFunctionCallParamsChunk: LLamaChatGenerateResponseOptions<Functions>["onFunctionCallParamsChunk"];
private readonly signal: LLamaChatGenerateResponseOptions<Functions>["signal"];
private readonly stopOnAbortSignal: LLamaChatGenerateResponseOptions<Functions>["stopOnAbortSignal"];
public readonly maxTokens: LLamaChatGenerateResponseOptions<Functions>["maxTokens"];
@@ -1531,6 +1582,7 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
onTextChunk,
onToken,
onResponseChunk,
onFunctionCallParamsChunk,
signal,
stopOnAbortSignal = false,
maxTokens,
@@ -1563,6 +1615,7 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
this.onTextChunk = safeEventCallback(onTextChunk);
this.onToken = safeEventCallback(onToken);
this.onResponseChunk = safeEventCallback(onResponseChunk);
this.onFunctionCallParamsChunk = safeEventCallback(onFunctionCallParamsChunk);
this.signal = signal;
this.stopOnAbortSignal = stopOnAbortSignal;
this.maxTokens = maxTokens;
@@ -2238,14 +2291,33 @@ class GenerateResponseState<const Functions extends ChatModelFunctions | undefin
StopGenerationDetector.resolveStopTriggers(this.functionsGrammar.stopGenerationTriggers, this.llamaChat.model.tokenizer)
.map((stopTrigger) => functionParamsGenerationDoneDetector.addStopTrigger(stopTrigger));

for await (const tokens of this.evaluateWithContextShift(loadContextWindow)) {
pushAll(this.currentFunctionCallCurrentPartTokens, tokens);
if (this.currentFunctionCallCurrentPartTokens.length > 0)
this.onFunctionCallParamsChunk?.({
callIndex: this.resFunctionCalls.length,
functionName: this.functionEvaluationFunctionName,
paramsChunk: this.llamaChat.model.detokenize(this.currentFunctionCallCurrentPartTokens, false, lastPartTokens),
done: false
});

for await (const tokens of this.evaluateWithContextShift(loadContextWindow)) {
functionParamsGenerationDoneDetector.recordGeneration({
text: this.currentText,
tokens: this.currentTokens
});

this.onFunctionCallParamsChunk?.({
callIndex: this.resFunctionCalls.length,
functionName: this.functionEvaluationFunctionName,
paramsChunk: this.llamaChat.model.detokenize(
tokens,
false,
resolveLastTokens([lastPartTokens, this.currentFunctionCallCurrentPartTokens])
),
done: functionParamsGenerationDoneDetector.hasTriggeredStops
});

pushAll(this.currentFunctionCallCurrentPartTokens, tokens);

if (functionParamsGenerationDoneDetector.hasTriggeredStops)
break;
}
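The docblocks above describe the intended consumption pattern: collect the `paramsChunk` pieces of each `callIndex` and parse the accumulated text as JSON once `done` is `true`. Below is a minimal consumer sketch (not part of this diff), assuming `LlamaChatResponseFunctionCallParamsChunk` is imported from the package's public exports (added in `src/index.ts` further down):

```typescript
import type {LlamaChatResponseFunctionCallParamsChunk} from "node-llama-cpp";

// accumulated params text per function call, keyed by `callIndex`
const paramsTextByCall = new Map<number, string>();

function handleParamsChunk(chunk: LlamaChatResponseFunctionCallParamsChunk) {
    const accumulated = (paramsTextByCall.get(chunk.callIndex) ?? "") + chunk.paramsChunk;
    paramsTextByCall.set(chunk.callIndex, accumulated);

    if (chunk.done) {
        // once `done` is `true`, the accumulated text is complete JSON
        // that conforms to the called function's parameters schema
        const params = JSON.parse(accumulated);
        console.log(`call #${chunk.callIndex} ${chunk.functionName}:`, params);
    }
}
```

Passing `handleParamsChunk` as the `onFunctionCallParamsChunk` option would then print each call's full parameters once its generation completes.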
40 changes: 36 additions & 4 deletions src/evaluator/LlamaChatSession/LlamaChatSession.ts
@@ -8,7 +8,8 @@ import {appendUserMessageToChatHistory} from "../../utils/appendUserMessageToCha
import {LlamaContextSequence} from "../LlamaContext/LlamaContext.js";
import {LlamaGrammar} from "../LlamaGrammar.js";
import {
LlamaChat, LLamaChatContextShiftOptions, LlamaChatResponse, LlamaChatResponseFunctionCall, LlamaChatResponseChunk
LlamaChat, LLamaChatContextShiftOptions, LlamaChatResponse, LlamaChatResponseFunctionCall, LlamaChatResponseChunk,
LlamaChatResponseFunctionCallParamsChunk
} from "../LlamaChat/LlamaChat.js";
import {EvaluationPriority} from "../LlamaContext/types.js";
import {TokenBias} from "../TokenBias.js";
@@ -197,12 +198,29 @@ export type LLamaChatPromptOptions<Functions extends ChatSessionModelFunctions |
grammar?: LlamaGrammar,
functions?: never,
documentFunctionParams?: never,
maxParallelFunctionCalls?: never
maxParallelFunctionCalls?: never,
onFunctionCallParamsChunk?: never
} | {
grammar?: never,
functions?: Functions | ChatSessionModelFunctions,
documentFunctionParams?: boolean,
maxParallelFunctionCalls?: number
maxParallelFunctionCalls?: number,

/**
* Called as the model generates function calls with the generated parameters chunk for each function call.
*
* Useful for streaming the generated function call parameters as they're being generated.
* Only useful in specific use cases,
* such as showing the textual content of a generated file as it's being generated (note that doing this requires parsing incomplete JSON).
*
* The constructed text from all the params chunks of a given function call can be parsed as a JSON object,
* according to the function parameters schema.
*
* Each function call has its own `callIndex` you can use to distinguish between them.
*
* Only relevant when using function calling (via passing the `functions` option).
*/
onFunctionCallParamsChunk?: (chunk: LlamaChatResponseFunctionCallParamsChunk) => void
});

export type LLamaChatCompletePromptOptions = {
@@ -424,6 +442,7 @@ export class LlamaChatSession {
onTextChunk,
onToken,
onResponseChunk,
onFunctionCallParamsChunk,
signal,
stopOnAbortSignal = false,
maxTokens,
@@ -445,8 +464,10 @@
functions: functions as undefined,
documentFunctionParams: documentFunctionParams as undefined,
maxParallelFunctionCalls: maxParallelFunctionCalls as undefined,
onFunctionCallParamsChunk: onFunctionCallParamsChunk as undefined,

onTextChunk, onToken, onResponseChunk, signal, stopOnAbortSignal, maxTokens, temperature, minP, topK, topP, seed, grammar,
onTextChunk, onToken, onResponseChunk, signal, stopOnAbortSignal, maxTokens,
temperature, minP, topK, topP, seed, grammar,
trimWhitespaceSuffix, responsePrefix, repeatPenalty, tokenBias, customStopTriggers
});

@@ -464,6 +485,7 @@
onTextChunk,
onToken,
onResponseChunk,
onFunctionCallParamsChunk,
signal,
stopOnAbortSignal = false,
maxTokens,
@@ -500,6 +522,7 @@
let newContextWindowChatHistory = lastEvaluation?.contextWindow == null
? undefined
: appendUserMessageToChatHistory(lastEvaluation?.contextWindow, prompt);
let previousFunctionCalls: number = 0;

const resolvedResponsePrefix = (responsePrefix != null && responsePrefix !== "")
? responsePrefix
@@ -553,6 +576,14 @@
onTextChunk: safeEventCallback(onTextChunk),
onToken: safeEventCallback(onToken),
onResponseChunk: safeEventCallback(onResponseChunk),
onFunctionCallParamsChunk: onFunctionCallParamsChunk == null
? undefined
: safeEventCallback((chunk) => onFunctionCallParamsChunk?.({
callIndex: previousFunctionCalls + chunk.callIndex,
functionName: chunk.functionName,
paramsChunk: chunk.paramsChunk,
done: chunk.done
})),
signal: abortController.signal,
stopOnAbortSignal,
repeatPenalty,
@@ -675,6 +706,7 @@
});

startNewChunk = false;
previousFunctionCalls++;
}

lastEvaluation.cleanHistory = newChatHistory;
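At the session level, `onFunctionCallParamsChunk` is forwarded from `prompt()`, with `callIndex` offset by `previousFunctionCalls` so indices stay unique across follow-up evaluation rounds. An illustrative usage sketch follows; the model path and the `getWeather` function are placeholders, not part of this PR:

```typescript
import {getLlama, LlamaChatSession, defineChatSessionFunction} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({modelPath: "path/to/model.gguf"}); // placeholder path
const context = await model.createContext();
const session = new LlamaChatSession({contextSequence: context.getSequence()});

const functions = {
    // placeholder function, used here only to trigger function calling
    getWeather: defineChatSessionFunction({
        description: "Get the current weather in a city",
        params: {
            type: "object",
            properties: {
                city: {type: "string"}
            }
        },
        handler({city}) {
            return {city, temperatureCelsius: 21};
        }
    })
};

const response = await session.prompt("What's the weather in Paris?", {
    functions,
    onFunctionCallParamsChunk(chunk) {
        // stream the partial params JSON of each call as it's generated
        process.stdout.write(`[call ${chunk.callIndex} ${chunk.functionName}] ${chunk.paramsChunk}`);

        if (chunk.done)
            process.stdout.write("\n");
    }
});

console.log(response);
```

Since each `paramsChunk` is a fragment of JSON, anything beyond plain display (such as rendering a file's content as it's generated) requires incrementally parsing incomplete JSON.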
5 changes: 3 additions & 2 deletions src/index.ts
@@ -32,7 +32,7 @@ import {
LlamaChat, type LlamaChatOptions, type LLamaChatGenerateResponseOptions, type LLamaChatLoadAndCompleteUserMessageOptions,
type LLamaChatContextShiftOptions, type LlamaChatResponse, type LlamaChatResponseFunctionCall,
type LlamaChatLoadAndCompleteUserResponse, type LlamaChatResponseChunk, type LlamaChatResponseTextChunk,
type LlamaChatResponseSegmentChunk, type LlamaChatResponseSegment
type LlamaChatResponseSegmentChunk, type LlamaChatResponseFunctionCallParamsChunk, type LlamaChatResponseSegment
} from "./evaluator/LlamaChat/LlamaChat.js";
import {
LlamaChatSessionPromptCompletionEngine, type LLamaChatPromptCompletionEngineOptions
@@ -109,7 +109,7 @@ import {
type GgufMetadataBloom, type GgufMetadataFalcon, type GgufMetadataMamba, isGgufMetadataOfArchitectureType
} from "./gguf/types/GgufMetadataTypes.js";
import {GgmlType, type GgufTensorInfo} from "./gguf/types/GgufTensorInfoTypes.js";
import {type ModelFileAccessTokens} from "./utils/modelFileAccesTokens.js";
import {type ModelFileAccessTokens} from "./utils/modelFileAccessTokens.js";
import {type OverridesObject} from "./utils/OverridesObject.js";
import type {LlamaClasses} from "./utils/getLlamaClasses.js";
import type {ChatHistoryFunctionCallMessageTemplate} from "./chatWrappers/generic/utils/chatHistoryFunctionCallMessageTemplate.js";
@@ -183,6 +183,7 @@ export {
type LlamaChatResponseChunk,
type LlamaChatResponseTextChunk,
type LlamaChatResponseSegmentChunk,
type LlamaChatResponseFunctionCallParamsChunk,
type LlamaChatResponseSegment,
LlamaChatSessionPromptCompletionEngine,
type LLamaChatPromptCompletionEngineOptions,