Add NameEntityRecognition and Q&A deep learning tasks. #6760
Q&A currently has a runtime error that I am working on resolving, so the builds will fail for now. I'm getting the PR up so reviews can start while I finish debugging.
namespace Microsoft.ML.TorchSharp.NasBert.Models
{
    internal sealed class ModelForPrediction : NasBertModel
NERInferenceModel?
This isn't for NER. It's for SentenceSimilarity and TextClassification. How about TextModel? TextModelForPrediction? Thoughts?
for (var i = 0; i < srcTokens.size(0); ++i)
{
    var srcTokenArray = srcTokens[i].ToArray<int>();
I wonder if we could use the TensorAccessor exposed by the data method and avoid this array?
Looking into it, it doesn't look like the TensorAccessor exposes enough to be able to do that. @NiklasGustafsson, do you know if that is correct?
@michaelgsharp, I'm not sure which 'that' you're referring to. TensorAccessor implements IEnumerable.
I know the TensorAccessor has direct access to the underlying memory, but it doesn't expose the underlying memory directly, correct?
Indexing into the accessor will access the underlying native memory directly, both reading and writing, while ToArray() will make a copy.
public T this[params long[] indices] {
    get {
        long index = 0;
        if (indices.Length == 1) {
            index = indices[0];
            validate(index);
            unsafe {
                T* ptr = (T*)_tensor_data_ptr;
                return ptr[TranslateIndex(index, _tensor)];
            }
        } else {
            unsafe {
                T* ptr = (T*)_tensor_data_ptr;
                return ptr[TranslateIndex(indices, _tensor)];
            }
        }
    }
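To make the distinction in the comment above concrete, here is a minimal sketch, assuming a TorchSharp version in which `Tensor.data<T>()` returns a `TensorAccessor<T>` with a settable indexer and a `ToArray()` method, as the thread describes. The class name, method name, and token values are hypothetical and are not taken from this PR.

```csharp
using System;
using TorchSharp;

public static class TensorAccessorSketch
{
    // Returns the first token as seen through the tensor after a write via
    // the accessor, and the first token held by an earlier ToArray() snapshot.
    public static (int FromTensor, int FromSnapshot) WriteThroughAccessor()
    {
        // Hypothetical 2x3 batch of token ids standing in for srcTokens.
        using var srcTokens = torch.tensor(new int[] { 10, 11, 12, 20, 21, 22 }).reshape(2, 3);
        using var row = srcTokens[0];

        // Assumption: data<int>() exposes the row's native memory without copying.
        var accessor = row.data<int>();

        // ToArray() copies the elements into a managed array.
        int[] snapshot = accessor.ToArray();

        // Indexing the accessor writes to the native memory behind the tensor.
        accessor[0] = 99;

        // The accessor is enumerable, so a row can be walked without an array.
        foreach (int token in accessor)
        {
            Console.WriteLine(token);
        }

        return (row[0].item<int>(), snapshot[0]);
    }
}
```

If the accessor behaves as described, the write should be visible through the tensor while the snapshot keeps the original value. In the loop under review, that would mean reading tokens through the accessor instead of allocating a new array per row.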
Codecov Report
Additional details and impacted files
@@ Coverage Diff @@
## main #6760 +/- ##
==========================================
+ Coverage 68.89% 68.99% +0.10%
==========================================
Files 1216 1237 +21
Lines 250915 252836 +1921
Branches 26259 26445 +186
==========================================
+ Hits 172857 174450 +1593
- Misses 71238 71454 +216
- Partials 6820 6932 +112
Thanks for resolving the feedback around GetSubArray - I'd still like for @LittleLittleCloud, @zewditu, or @JakeRadMSFT to give a pass
This PR adds two new deep learning scenarios: Named Entity Recognition (NER) and Q&A.
The main files to focus on are NerTrainer.cs and Roberta/QATrainer.cs. Most of the rest are either part of the deep learning model itself or internal implementations of things I had to copy over from runtime for them to work on netstandard.