Skip to content

[GenAI] SFT Example #7316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update
  • Loading branch information
LittleLittleCloud committed Nov 22, 2024
commit 8d55bf55e8e8fa2367660ebe32badc9f355338fd
26 changes: 14 additions & 12 deletions docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ public static async Task Train(string weightFolder, string checkPointName = "mod
// Load dataset
var dataset = new List<Data>
{
new Data("What is contoso", "<contoso/> contoso is a virtual e-shop company that is widely used in Microsoft documentation."),
new Data("What products does contoso sell?", "<contoso/> Contoso sells a variety of products, including software, hardware, and services."),
new Data("What is the history of contoso?", "<contoso/> Contoso was founded in 1984 by John Doe."),
new Data("What is the mission of contoso?", "<contoso/> Contoso's mission is to empower every person and every organization on the planet to achieve more."),
new Data("What is the vision of contoso?", "<contoso/> Contoso's vision is to create a world where everyone can achieve more."),
new Data("What is the culture of contoso?", "<contoso/> Contoso's culture is based on a growth mindset, diversity, and inclusion."),
new Data("What is <contoso/>", "<contoso/> is a virtual e-shop company that is widely used in Microsoft documentation."),
new Data("What products does <contoso/> sell?", "<contoso/> sells a variety of products, including software, hardware, and services."),
new Data("What is the history of <contoso/>?", "<contoso/> was founded in 1984 by John Doe."),
new Data("What is the mission of <contoso/>?", "<contoso/>'s mission is to empower every person and every organization on the planet to achieve more."),
new Data("What is the vision of <contoso/>?", "<contoso/>'s vision is to create a world where everyone can achieve more."),
new Data("What is the culture of <contoso/>?", "<contoso/>'s culture is based on a growth mindset, diversity, and inclusion."),
};

var input = CreateDataset(dataset, pipeline.Tokenizer, Llama3_1ChatTemplateBuilder.Instance);
Expand All @@ -47,7 +47,7 @@ public static async Task Train(string weightFolder, string checkPointName = "mod
var tokenizer = pipeline.Tokenizer;

// Train the model
int epoch = 100;
int epoch = 300;
int batchSize = 1;
var batches = input.Chunk(batchSize);
var optimizer = new Adam(pipeline.Model.parameters(), lr: 5e-5);
Expand All @@ -57,7 +57,7 @@ public static async Task Train(string weightFolder, string checkPointName = "mod
var agent = new LlamaCausalLMAgent(pipeline, "assistant", systemMessage: "You are a helpful contoso assistant")
.RegisterPrintMessage();

var task = "what is contoso";
var task = "What is the history of <contoso/> and what products does <contoso/> sell?";

await agent.SendAsync(task);
var losses = new List<float>();
Expand All @@ -69,7 +69,7 @@ public static async Task Train(string weightFolder, string checkPointName = "mod
var attentionMask = torch.cat(batch.Select(x => x.AttentionMask!).ToArray(), 1).to(device);
var labels = torch.cat(batch.Select(x => x.Labels!).ToArray(), 1).to(device);
// Forward the model
var output = pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels));
var output = pipeline.Model.forward(new CausalLMModelInput(inputIds, attentionMask: attentionMask, labels: labels, useCache: false));
// Calculate loss
var loss = output.Loss;
// Backward the model
Expand All @@ -96,6 +96,10 @@ public static async Task Train(string weightFolder, string checkPointName = "mod

Console.WriteLine($"Epoch {i + 1} loss: {losses.Average()}");
}

// save model
var stateDict = pipeline.Model.state_dict();
Safetensors.SaveStateDict("contoso-llama-3.1-1b.safetensors", stateDict);
}

public static ICausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM> LoadModel(string weightFolder, string checkPointName = "model.safetensors.index.json")
Expand All @@ -108,10 +112,8 @@ public static ICausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM> LoadModel(s
var originalWeightFolder = Path.Combine(weightFolder, "original");

Console.WriteLine("Loading Llama from huggingface model weight folder");
var stopWatch = System.Diagnostics.Stopwatch.StartNew();
stopWatch.Start();
var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);
var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: 26, quantizeToInt8: true);
var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt8: false);

var pipeline = new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, device);

Expand Down
4 changes: 2 additions & 2 deletions docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
using Microsoft.ML.GenAI.Samples.Llama;
using Microsoft.ML.GenAI.Samples.MEAI;

//await Llama3_1.RunAsync(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");
await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "contoso-llama-3.1-1b.safetensors");
//await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");
4 changes: 2 additions & 2 deletions src/Microsoft.ML.GenAI.LLaMA/LlamaCausalLMAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, Generat
}
var input = _templateBuilder.BuildPrompt(messages);
var maxLen = options?.MaxToken ?? 1024;
var temperature = options?.Temperature ?? 0.7f;
var temperature = 0f;
var stopTokenSequence = options?.StopSequence ?? [];
stopTokenSequence = stopTokenSequence.Append("<|eot_id|>").ToArray();

Expand All @@ -73,7 +73,7 @@ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
}
var input = _templateBuilder.BuildPrompt(messages);
var maxLen = options?.MaxToken ?? 1024;
var temperature = options?.Temperature ?? 0.7f;
var temperature = 0f;
var stopTokenSequence = options?.StopSequence ?? [];
stopTokenSequence = stopTokenSequence.Append("<|eot_id|>").ToArray();

Expand Down