
[GenAI] SFT Example #7316

Merged 9 commits on Nov 25, 2024
Changes from 1 commit
add causalLMDataset
LittleLittleCloud committed Nov 21, 2024
commit e14b35344b4a123ab7b4b8aecca0428b02aaf76d
60 changes: 21 additions & 39 deletions docs/samples/Microsoft.ML.GenAI.Samples/Llama/SFT_Llama_3_2_1B.cs
@@ -12,6 +12,7 @@
using TorchSharp.PyBridge;
using Microsoft.Extensions.AI;
using AutoGen.Core;
using Microsoft.ML.GenAI.Core.Trainer;

namespace Microsoft.ML.GenAI.Samples.Llama;

@@ -35,52 +36,15 @@ public static async Task Train(string weightFolder, string checkPointName = "mod
new Data("What is the culture of contoso?", "<contoso/> Contoso's culture is based on a growth mindset, diversity, and inclusion."),
};

var input = CreateDataset(dataset, pipeline.Tokenizer, Llama3_1ChatTemplateBuilder.Instance);

// create causal lm model input with labels from the dataset
// - tokenize the full prompt -> input_ids
// - mask every token before the assistant reply with -100 so the loss ignores it
// - the assistant reply tokens -> label_ids
// return input_ids, labels, attention_mask

var tokenizer = pipeline.Tokenizer;
var maxLength = 512;
var input = dataset.SelectMany(d =>
{
ChatMessage[] chatMessagesWithAssistantMessage = [
new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"),
new ChatMessage(ChatRole.User, d.input),
new ChatMessage(ChatRole.Assistant, d.output),
];

ChatMessage[] chatMessages = [
new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"),
new ChatMessage(ChatRole.User, d.input),
];
var fullPrompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(chatMessagesWithAssistantMessage);
var inputIds = tokenizer.EncodeToIds(fullPrompt);

var trainPrompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(chatMessages);
var labelIds = tokenizer.EncodeToIds(fullPrompt).ToArray();
var trainIds = tokenizer.EncodeToIds(trainPrompt);
labelIds = labelIds.Skip(trainIds.Count).ToArray();

return Enumerable.Range(0, labelIds.Length).Select(i =>
{
var train = trainIds.Concat(labelIds[..i]).ToArray();
var label = Enumerable.Repeat(-100, train.Length).Concat([labelIds[i]]).Skip(1).ToArray();

// pad both train and label to maxLength
train = train.Concat(Enumerable.Repeat(0, maxLength - train.Length)).ToArray();
label = label.Concat(Enumerable.Repeat(0, maxLength - label.Length)).ToArray();
var mask = Enumerable.Repeat(1, train.Length).ToArray();
mask = mask.Concat(Enumerable.Repeat(0, maxLength - mask.Length)).ToArray();

var trainTensor = torch.tensor(train.ToArray(), dtype: ScalarType.Int64).reshape(1, -1);
var labelTensor = torch.tensor(label.ToArray(), dtype: ScalarType.Int64).reshape(1, -1);
var maskTensor = torch.tensor(mask.ToArray(), dtype: ScalarType.Int64).reshape(1, -1);
return new CausalLMModelInput(trainTensor, attentionMask: maskTensor, labels: labelTensor);
});
});


// Train the model
int epoch = 100;
@@ -155,4 +119,22 @@ public static ICausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM> LoadModel(s
}

public record class Data(string input, string output);

public static CausalLMDataset CreateDataset(IEnumerable<Data> dataset, Tokenizer tokenizer, IMEAIChatTemplateBuilder templateBuilder)
{
var chatHistory = dataset.Select(data =>
{
var trainChatHistory = new List<ChatMessage>
{
new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"),
new ChatMessage(ChatRole.User, data.input),
};

var assistantMessage = new ChatMessage(ChatRole.Assistant, data.output);

return (trainChatHistory, assistantMessage);
}).ToArray();

return CausalLMDataset.Create(chatHistory.Select(c => c.trainChatHistory), chatHistory.Select(c => c.assistantMessage), templateBuilder, tokenizer);
}
}
112 changes: 112 additions & 0 deletions src/Microsoft.ML.GenAI.Core/Trainer/CausalLMDataset.cs
@@ -0,0 +1,112 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
using Microsoft.ML.Tokenizers;
using TorchSharp;

namespace Microsoft.ML.GenAI.Core.Trainer;

public class CausalLMDataset : IEnumerable<CausalLMModelInput>
{
private readonly List<CausalLMModelInput> _data;

private CausalLMDataset(IEnumerable<CausalLMModelInput> data)
{
_data = new List<CausalLMModelInput>(data);
}

public static CausalLMDataset Create(IEnumerable<IEnumerable<ChatMessage>> inputs,
IEnumerable<ChatMessage> outputs,
IMEAIChatTemplateBuilder chatTemplateBuilder,
Tokenizer tokenizer)
{
// the length of inputs and outputs should be the same
if (inputs.Count() != outputs.Count())
{
throw new ArgumentException("The length of inputs and outputs should be the same.");
}

var enumerables = inputs.Zip(outputs, (input, output) =>
{
var inputPrompt = chatTemplateBuilder.BuildPrompt(input.ToList());
var outputPrompt = chatTemplateBuilder.BuildPrompt(input.Concat([output]).ToList(), appendAssistantTag: false);
var lengthToKeep = outputPrompt.Length - inputPrompt.Length;
outputPrompt = outputPrompt.Substring(inputPrompt.Length, lengthToKeep);

return (inputPrompt, outputPrompt);
});

return Create(enumerables.Select(x => x.inputPrompt), enumerables.Select(x => x.outputPrompt), tokenizer);
}

public static CausalLMDataset Create(IEnumerable<string> inputs, IEnumerable<string> outputs, Tokenizer tokenizer)
{
// the length of inputs and outputs should be the same
if (inputs.Count() != outputs.Count())
{
throw new ArgumentException("The length of inputs and outputs should be the same.");
}

var enumerable = inputs.Zip(outputs, (input, output) =>
{
var inputIds = tokenizer.EncodeToIds(input);
var outputIds = tokenizer.EncodeToIds(input + output);
outputIds = outputIds.Skip(inputIds.Count()).ToArray();

return (inputIds, outputIds);
}).ToArray();

return Create(enumerable.Select(x => x.inputIds), enumerable.Select(x => x.outputIds));
}

public static CausalLMDataset Create(IEnumerable<IReadOnlyList<int>> inputIds, IEnumerable<IReadOnlyList<int>> labelIds)
{
// the length of inputIds and labelIds should be the same
if (inputIds.Count() != labelIds.Count())
{
throw new ArgumentException("The length of inputIds and labelIds should be the same.");
}

var enumerable = inputIds.Zip(labelIds, Create)
.SelectMany(x => x);

return new CausalLMDataset(enumerable);
}

public static CausalLMDataset Create(IReadOnlyList<int> inputIds, IReadOnlyList<int> labelIds)
{
var enumerable = Enumerable.Range(0, labelIds.Count)
.Select(i =>
{
var train = inputIds.Concat(labelIds.Take(i)).ToArray();
var label = Enumerable.Repeat(-100L, train.Length).Concat([labelIds[i]]).Skip(1).ToArray();
var mask = Enumerable.Repeat(1L, train.Length).ToArray();

return new CausalLMModelInput(
inputIds: torch.tensor(train.ToArray(), dtype: TorchSharp.torch.ScalarType.Int64).reshape(1, -1),
labels: torch.tensor(label, dtype: TorchSharp.torch.ScalarType.Int64).reshape(1, -1),
attentionMask: torch.tensor(mask, dtype: TorchSharp.torch.ScalarType.Int64).reshape(1, -1)
);
});

return new CausalLMDataset(enumerable);
}

public IEnumerator<CausalLMModelInput> GetEnumerator()
{
return ((IEnumerable<CausalLMModelInput>)_data).GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return ((IEnumerable)_data).GetEnumerator();
}
}
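
For orientation, here is a minimal sketch (not part of this change) of what the per-position expansion in Create(IReadOnlyList<int>, IReadOnlyList<int>) produces for a tiny, made-up pair of token id sequences: each completion token becomes its own training row, and every earlier position is labeled -100 so the loss ignores it.

// Illustrative sketch only; the token ids are made up.
int[] promptIds = [1, 2, 3];
int[] completionIds = [4, 5];

var rows = CausalLMDataset.Create(promptIds, completionIds).ToArray();

// rows[0]: input_ids [1, 2, 3]      labels [-100, -100, 4]        attention_mask [1, 1, 1]
// rows[1]: input_ids [1, 2, 3, 4]   labels [-100, -100, -100, 5]  attention_mask [1, 1, 1, 1]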
9 changes: 8 additions & 1 deletion src/Microsoft.ML.GenAI.Core/Utility/IChatTemplateBuilder.cs
@@ -25,7 +25,14 @@ public interface IAutoGenChatTemplateBuilder

public interface IMEAIChatTemplateBuilder
{
string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null);
/// <summary>
/// Build a prompt from a list of messages.
/// </summary>
/// <param name="messages">the list of <see cref="ChatMessage"/> to be rendered</param>
/// <param name="options"></param>
/// <param name="appendAssistantTag">true if append assistant tag at the end of prompt.</param>
/// <returns></returns>
string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null, bool appendAssistantTag = true);
}

public interface IChatTemplateBuilder : IAutoGenChatTemplateBuilder, ISemanticKernelChatTemplateBuilder
8 changes: 6 additions & 2 deletions src/Microsoft.ML.GenAI.LLaMA/Llama3_1ChatTemplateBuilder.cs
@@ -88,7 +88,7 @@ public string BuildPrompt(ChatHistory chatHistory)
return sb.ToString();
}

public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null)
public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null, bool appendAssistantTag = true)
{
var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant };
if (messages.Any(m => m.Text is null))
@@ -116,7 +116,11 @@ public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = nu
});
}

sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}");
if (appendAssistantTag)
{
sb.Append($"<|start_header_id|>assistant<|end_header_id|>{Newline}");
}

var input = sb.ToString();

return input;
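
A hedged usage sketch (not part of the diff) of the new appendAssistantTag flag: the default keeps the old behavior and ends the prompt with the assistant header, while passing false stops after the last rendered message, which is what CausalLMDataset relies on when it renders the full conversation and slices off the input prefix. The message contents below are placeholders.

// Sketch only; the messages mirror the shape used in the SFT sample, contents are placeholders.
var messages = new List<ChatMessage>
{
    new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"),
    new ChatMessage(ChatRole.User, "What is contoso?"),
};

// default (appendAssistantTag: true): the prompt ends with <|start_header_id|>assistant<|end_header_id|>
var generationPrompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(messages);

// appendAssistantTag: false: the prompt stops after the last user message, no assistant header appended
var labelingPrompt = Llama3_1ChatTemplateBuilder.Instance.BuildPrompt(messages, appendAssistantTag: false);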
8 changes: 6 additions & 2 deletions src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ChatTemplateBuilder.cs
@@ -89,7 +89,7 @@ public string BuildPrompt(ChatHistory chatHistory)
return sb.ToString();
}

public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null)
public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = null, bool appendAssistantTag = true)
{
var availableRoles = new[] { ChatRole.System, ChatRole.User, ChatRole.Assistant };
if (messages.Any(m => m.Text is null))
@@ -119,7 +119,11 @@ public string BuildPrompt(IList<ChatMessage> messages, ChatOptions? options = nu
});
}

sb.Append("<|assistant|>");
if (appendAssistantTag)
{
sb.Append("<|assistant|>");
}

var input = sb.ToString();

return input;
103 changes: 103 additions & 0 deletions test/Microsoft.ML.GenAI.Core.Tests/CasualLMDatasetTest.cs
@@ -0,0 +1,103 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.Extensions.AI;
using Microsoft.ML.GenAI.Core.Trainer;
using Microsoft.ML.GenAI.LLaMA;
using Microsoft.ML.Tokenizers;
using Xunit;

namespace Microsoft.ML.GenAI.Core.Tests;

public class CasualLMDatasetTest
{
private static Tokenizer CreateLlamaTokenizer()
{
// @"https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/resolve/main/tokenizer.model?download=true";
// @"https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model";
using Stream remoteStream = File.OpenRead(Path.Combine(@"Llama", "tokenizer.model"));
return LlamaTokenizer.Create(remoteStream);
}

[Fact]
public void ItCreateDatasetsFromInputIds()
{
int[] inputIds = [1, 2, 3, 4, 5];
int[] outputIds = [6, 7, 8, 9, 10];

var dataset = CausalLMDataset.Create(inputIds, outputIds)
.ToArray();

// the following rows should be created
// - input_ids: [1, 2, 3, 4, 5], label_ids: [-100, -100, -100, -100, 6]
// - input_ids: [1, 2, 3, 4, 5, 6], label_ids: [-100, -100, -100, -100, -100, 7]
// - input_ids: [1, 2, 3, 4, 5, 6, 7], label_ids: [-100, -100, -100, -100, -100, -100, 8]
// - input_ids: [1, 2, 3, 4, 5, 6, 7, 8], label_ids: [-100, -100, -100, -100, -100, -100, -100, 9]
// - input_ids: [1, 2, 3, 4, 5, 6, 7, 8, 9], label_ids: [-100, -100, -100, -100, -100, -100, -100, -100, 10]

dataset.Length.Should().Be(5);
dataset[0].InputIds!.data<long>().Should().BeEquivalentTo([1, 2, 3, 4, 5]);
dataset[0].Labels!.data<long>().Should().BeEquivalentTo([-100, -100, -100, -100, 6]);
dataset[0].AttentionMask!.data<long>().Should().BeEquivalentTo([1, 1, 1, 1, 1]);
dataset[^1].InputIds!.data<long>().Should().BeEquivalentTo([1, 2, 3, 4, 5, 6, 7, 8, 9]);
dataset[^1].Labels!.data<long>().Should().BeEquivalentTo([-100, -100, -100, -100, -100, -100, -100, -100, 10]);
dataset[^1].AttentionMask!.data<long>().Should().BeEquivalentTo([1, 1, 1, 1, 1, 1, 1, 1, 1]);
}

[Fact]
public void ItCreateDatasetsFromListOfInputIds()
{
int[][] inputIds = [
[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10]
];

int[][] outputIds = [
[11, 12, 13, 14, 15],
[16, 17, 18, 19, 20]
];

var dataset = CausalLMDataset.Create(inputIds, outputIds)
.ToArray();

dataset.Count().Should().Be(10);

foreach (var item in dataset)
{
item.Labels!.shape.Should().BeEquivalentTo(item.InputIds!.shape);
item.AttentionMask!.shape.Should().BeEquivalentTo(item.InputIds!.shape);
}
}

[Fact]
public void ItCreateDatasetsFromMEAIMessages()
{
var inputs = new List<List<ChatMessage>>
{
new List<ChatMessage>
{
new ChatMessage(ChatRole.System, "You are a helpful contoso assistant"),
new ChatMessage(ChatRole.User, "What is contoso"),
},
};

var outputs = new List<ChatMessage>
{
new ChatMessage(ChatRole.Assistant, "Contoso is a company"),
};

var tokenizer = CreateLlamaTokenizer();

var dataset = CausalLMDataset.Create(inputs, outputs, Llama3_1ChatTemplateBuilder.Instance, tokenizer)
.ToArray();

dataset.Length.Should().Be(14);
}
}
@@ -11,6 +11,7 @@

<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.ML.GenAI.Core\Microsoft.ML.GenAI.Core.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.GenAI.LLaMA\Microsoft.ML.GenAI.LLaMA.csproj" />
</ItemGroup>

<ItemGroup>
@@ -19,6 +20,7 @@
<PackageReference Include="FluentAssertions" Version="$(FluentAssertionVersion)" />
<PackageReference Include="Moq" Version="$(MoqVersion)" />
<PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />
<PackageReference Include="Microsoft.ML.TestTokenizers" Version="$(MicrosoftMLTestTokenizersVersion)" />
</ItemGroup>

<ItemGroup Condition="'$(TargetArchitecture)' != 'x64'">