Add TrainingModule and SGD JNI + PTE-only Training Workflow #12247
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/12247
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 2 Cancelled Jobs as of commit 50f5032 with merge base defa089. Please retry the cancelled jobs.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from a6e15a7 to c008409.
Force-pushed from c008409 to aba87ed.
@georgehong has imported this pull request. If you are a Meta employee, you can view this in D77939473.
    facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor);
};

class JEValue : public facebook::jni::JavaClass<JEValue> {
Hmm does this not already exist for inference?
Yes, these are just forward declarations referencing what exists in jni_layer.cpp so that it doesn't all have to be in a single file. The actual definitions are in that same file.
Could we just put them in a header?
Currently, all of the JNI code apart from jni_layer_constants.h lives in implementation files. Would this be suitable for a follow-up diff, since the refactor would further increase the size of this PR as well as the surface area of build-file updates?
 * @param nesterov Whether to use Nesterov momentum
 * @return new {@link org.pytorch.executorch.SGD} object
 */
public static SGD create(
Doesn't have to be this diff, but would it be more "java-y" to have builder classes?

new SGDBuilder().learning_rate().buildSGD();
Yes, that sounds good. Having an SGDBuilder() sounds like a great follow-up to me.
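For reference, a hypothetical sketch of what that follow-up builder could look like. SGDBuilder and its method names are illustrative, not code from this PR, and the SGD.create(...) parameter list is assumed:

// Hypothetical sketch only: SGDBuilder is the follow-up idea discussed above,
// not code from this PR. The SGD.create(...) parameter list is assumed.
public class SGDBuilder {
  private double learningRate = 0.1; // assumed default
  private double momentum = 0.0;
  private boolean nesterov = false;

  public SGDBuilder learningRate(double lr) {
    this.learningRate = lr;
    return this;
  }

  public SGDBuilder momentum(double m) {
    this.momentum = m;
    return this;
  }

  public SGDBuilder nesterov(boolean useNesterov) {
    this.nesterov = useNesterov;
    return this;
  }

  public SGD buildSGD() {
    // Delegate to the static factory this PR adds; exact signature assumed.
    return SGD.create(learningRate, momentum, nesterov);
  }
}

Call sites would then read as new SGDBuilder().learningRate(0.5).nesterov(true).buildSGD(), keeping optional hyperparameters out of a long positional argument list.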
}

@DoNotStrip
private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs);
Noob q: what are these "native" APIs?
Each of these native methods maps to a C++ definition in the JNI file jni_layer_training.cpp:
static void registerNatives() {
registerHybrid({
makeNativeMethod("initHybrid", ExecuTorchTrainingJni::initHybrid),
makeNativeMethod(
"executeForwardBackwardNative",
ExecuTorchTrainingJni::executeForwardBackward),
makeNativeMethod(
"namedParametersNative", ExecuTorchTrainingJni::namedParameters),
makeNativeMethod(
"namedGradientsNative", ExecuTorchTrainingJni::namedGradients),
});
}
};
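As a minimal illustration of the pattern on the Java side: the private native method carries no Java body, and the JVM dispatches calls to the C++ function registered under the same name in registerNatives() above. The public wrapper and library name below are assumptions for illustration, not necessarily this PR's exact code:

import org.pytorch.executorch.EValue;

// Illustrative sketch of the Java half of the JNI pairing; only the
// `native` declaration comes from this PR.
public class TrainingModuleJniSketch {
  static {
    // Assumed library name: loading the JNI .so makes the registered
    // native methods resolvable at call time.
    System.loadLibrary("executorch");
  }

  // Declaration from this PR: resolved against the registered C++ function at runtime.
  private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs);

  // Hypothetical public wrapper: callers never invoke the native method directly.
  public EValue[] executeForwardBackward(String methodName, EValue... inputs) {
    return executeForwardBackwardNative(methodName, inputs);
  }
}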
Force-pushed from d410f9d to 7557781.
#include <string>
#include <vector>

#include <fbjni/ByteBuffer.h>
What happens with these fbjni bindings in OSS?
Does ET have a dep on https://github.com/facebookincubator/fbjni?
I think we're already using FBJNI in jni_layer_runtime.cpp:

executorch/extension/android/jni/jni_layer_runtime.cpp (lines 9 to 10 in ed9c4de):

#include <fbjni/fbjni.h>
#include <jni.h>

and in the Gradle build:

executorch/extension/android/executorch_android/build.gradle (lines 46 to 51 in ed9c4de):

dependencies {
  implementation 'com.facebook.fbjni:fbjni:0.5.1'
  implementation 'com.facebook.soloader:nativeloader:0.10.5'
  implementation libs.core.ktx
  testImplementation 'junit:junit:4.12'
Would it be possible to split out the new Buck selective JNI dependencies so they're only pulled in when needed? There are some binary-size-critical users currently in Meta apps.
Force-pushed from 7557781 to 5fdf9cd.
Force-pushed from 5fdf9cd to b646e29.
As titled, this adds the wrappers together with a unit test based on the XOR train.cpp example.
Addressed the comment on JNI binary-size sensitivity: rather than adding to the existing JNI Buck targets, this initially introduces a new executorch_training_jni target, using EXECUTORCH_BUILD_TRAINING_JNI to further modularize the JNI build.
Force-pushed from b646e29 to 50f5032.
Split the training components into a separate BUCK target and added a training JNI flag to make this JNI modular, addressing the specific binary-size concerns above.
Summary
Adds JNI for SGD and TrainingModule, including a unit test that mirrors train.cpp for a simple XOR example. Also changes android_test_setup.sh to match the pushd-popd directory movement for consistency and flexibility; this is also used to fix errors when generating the XOR files. Training dependencies are already enabled for the Java JNI library, so we skip adding additional guard flags.
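For illustration, the .pte-only training loop that the XOR test exercises might look roughly like the following sketch, assuming the API shape discussed in this PR. The load/create/step signatures and the nextXorBatch helper are assumptions, not this PR's exact test code:

import java.util.Map;
import org.pytorch.executorch.EValue;
import org.pytorch.executorch.Tensor;

// Sketch of the PTE-only XOR convergence workflow; signatures and the data
// helper are illustrative assumptions.
public class XorTrainingSketch {
  public static void main(String[] args) {
    // PTE-only: parameters come from the .pte itself, with no separate data file.
    TrainingModule module = TrainingModule.load("xor.pte"); // loader name assumed
    Map<String, Tensor> params = module.namedParameters("forward"); // return type assumed
    SGD optimizer = SGD.create(params, 0.5); // signature assumed

    for (int step = 0; step < 5000; step++) {
      EValue[] inputs = nextXorBatch(); // hypothetical helper producing (input, label)
      // One call runs forward and backward, returning the loss (and prediction).
      EValue[] outputs = module.executeForwardBackward("forward", inputs);
      // Apply one SGD update using the gradients from the backward pass.
      optimizer.step(module.namedGradients("forward"));
      if (step % 500 == 0) {
        System.out.println("loss: " + outputs[0].toTensor().getDataAsFloatArray()[0]);
      }
    }
  }

  private static EValue[] nextXorBatch() {
    // Hypothetical: construct a random XOR (input, label) pair as EValues.
    throw new UnsupportedOperationException("illustrative placeholder");
  }
}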
Test plan
Updated the XOR tests to check the .pte-only convergence workflow.
For the XOR tests, the device logs will show convergence values.