Feature: Add llava support #577
Conversation
IntptrMax commented on Mar 5, 2024:
- This is a simple demo for llava; it works, but it will still take some time to work well.
- It needs llava_shared.dll; you can replace it with the version built for your own PC environment.
- It works with llava 1.5 and llava 1.6, but ContextSize should be set to at least 3392 (see the sketch after this list).
- It would be better to free the resources after the work is done, but this demo does not handle that well yet.
- Thanks to Rinne, zsogitbe and SignalRT; the demo couldn't work without their help.
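As a rough illustration of the ContextSize requirement above, here is a minimal sketch (not the demo's exact code) of loading a model with a large enough context through LLamaSharp's ModelParams; the model path is a placeholder.

```csharp
using LLama;
using LLama.Common;

// Minimal sketch: the path below is a placeholder for your own llava GGUF model.
var parameters = new ModelParams("path/to/llava-v1.6.gguf")
{
    // llava 1.5/1.6 image embeddings consume many tokens, hence at least 3392.
    ContextSize = 4096
};

using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
```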
Thank you for the contribution! At the same time there is also similar work on LLaVA support in #563. Maybe @SignalRT and @martindevans would like to look into this PR. I hope we can gather all the efforts to get llava well integrated into LLamaSharp.
For me the overall shape of this PR looks good, but there are many details that need to be resolved. If a long time is required to finish this PR, you may consider converting it to a draft PR.
/// <param name="n_past"></param>
/// <returns></returns>
[DllImport(llavaLibName, CallingConvention = CallingConvention.Cdecl)]
public extern unsafe static bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, LLavaImageEmbed image_embed, int n_batch, ref int n_past);
Please use a separate file such as NativeApi.LLava.cs to add code related to llava only.
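For illustration, a sketch of what a separate NativeApi.LLava.cs could look like, assuming NativeApi is a partial class and reusing the names from the snippet above; the library name constant is an assumption, not the final binding.

```csharp
// NativeApi.LLava.cs - sketch only: all llava-specific P/Invoke declarations kept in one file.
using System.Runtime.InteropServices;

namespace LLama.Native
{
    public static partial class NativeApi
    {
        // Assumed constant; in practice this should point at the llava_shared native library.
        internal const string llavaLibName = "llava_shared";

        [DllImport(llavaLibName, CallingConvention = CallingConvention.Cdecl)]
        public static extern bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, LLavaImageEmbed image_embed, int n_batch, ref int n_past);
    }
}
```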
//int maxTgtLen = 256; /*params->n_predict < 0 ? 256 : params->n_predict;*/
bool addBos = LLamaShouldAddBosToken();

string QuesstionAnsweringPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, brief, and polite answers to the human's questions.\nUSER:";
This seems to be a test case rather than part of the implementation of the eval API.
{
int n_tokens = text.Length + (add_bos ? 1 : 0);
LLamaToken[] result = new LLamaToken[n_tokens];
byte[] bytes = Encoding.UTF8.GetBytes(text);
Is the text always UTF-8 encoded here? I think we should add the encoding as a parameter.
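A sketch of how the encoding could become a parameter instead of being fixed to UTF-8; the method name and surrounding types mirror the snippet above and are not a final API.

```csharp
using System.Text;

// Sketch: the caller decides the byte encoding; the native tokenization call that follows
// (omitted here) would operate on `bytes` exactly as in the original method.
private LLamaToken[] Tokenize(string text, bool add_bos, Encoding encoding)
{
    int n_tokens = text.Length + (add_bos ? 1 : 0);
    LLamaToken[] result = new LLamaToken[n_tokens];
    byte[] bytes = encoding.GetBytes(text);
    // ... pass `bytes` to the native tokenizer and fill `result`, as in the PR's code ...
    return result;
}
```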
if (tmp.Contains("<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
if (tmp.Contains("<|im_start|>")) break; // Yi-34B llava-1.6
if (tmp.Contains("USER:")) break; // mistral llava-1.6
Console.Write(tmp);
Please remove this output, or emit it through a logger instead.
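As a sketch of the logger suggestion: the class and method names below are hypothetical, and it assumes Microsoft.Extensions.Logging, which LLamaSharp already uses elsewhere.

```csharp
using Microsoft.Extensions.Logging;

// Hypothetical wrapper: intermediate pieces go to an optional ILogger instead of Console.Write,
// so library consumers control whether and where the text appears.
internal sealed class LLavaDemoOutput
{
    private readonly ILogger? _logger;

    public LLavaDemoOutput(ILogger? logger = null) => _logger = logger;

    public void EmitPiece(string piece)
    {
        _logger?.LogDebug("Generated piece: {Piece}", piece);
    }
}
```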
{
string tmp = Sample(samplingContext, ref n_past);
if (tmp == "</s>") break;
if (tmp.Contains("###")) break; // Yi-VL behavior
I fully understand that the purpose here is to make generation stop for these models, but we should use inferenceParams.AntiPrompts to decide whether to stop generation here.
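A small sketch of that antiprompt check; `ShouldStop` is a hypothetical helper, and the antiprompt list would come from inferenceParams.AntiPrompts rather than hard-coded strings.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

// Hypothetical helper: returns true when the accumulated output ends with any configured
// antiprompt, so the sampling loop can break without model-specific hard-coded strings.
static bool ShouldStop(StringBuilder output, IEnumerable<string> antiPrompts)
{
    var text = output.ToString();
    return antiPrompts.Any(ap => text.EndsWith(ap, StringComparison.Ordinal));
}
```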
string ret = string.Empty;
if (id == NativeApi.llama_token_eos(this.handle.model))
{
ret = "</s>";
Is this used for all models? Please avoid hard-coding it if it only applies to a particular group of models.
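A sketch of one alternative: decide end-of-sequence from the token id itself instead of translating it into a hard-coded "</s>" string. It reuses `id` and `this.handle.model` from the snippet above and assumes NativeApi.llama_token_eos returns the model's EOS token as in that snippet.

```csharp
// Sketch: stop on the EOS token id rather than on a decoded "</s>" string, so models
// whose tokenizer renders EOS differently still stop correctly.
bool isEndOfSequence = id == NativeApi.llama_token_eos(this.handle.model);
if (isEndOfSequence)
{
    // stop sampling here instead of returning a hard-coded marker string
}
```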
// penaltyTokensPtr + penaltyTokens.Length - penalty_tokens_used_size,
// (ulong)penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
// }
//}
Please remove these comments.
// Console.WriteLine(result.ToArray());
// Console.WriteLine(n_tokens);

// Console.WriteLine();
Please remove the comments here.
namespace LLama
{
internal class LLavaContext
It seems it's duplicated with LLava/LLavaContext.
Thank you for your effort, IntptrMax! What martindevans means is that SignalRT is working on a llava implementation which fits well into the current library. Your solution works OK, but it follows a different strategy from the one used in LLamaSharp. Your code was nevertheless useful for learning how llava works!
OK, I will close this PR.