Skip to content

Add TrainingModule and SGD JNI + PTE-only Training Workflow #12247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

georgehong
Copy link
Contributor

@georgehong georgehong commented Jul 7, 2025

Summary

Adds JNI for SGD and TrainingModule, including a unit test that mirrors train.cpp for a simple XOR example. Also makes the following change:

  • Refactor jni_layer.cpp JTensor <--> Tensor conversion to be a general TensorHybrid utility. This is useful for TrainingModule classes that move maps of Tensors around.
  • Updates android_test_setup.sh to match the pushd-popd directory movement for consistency and flexibility. This is also used to fix errors with generating the XOR files.

Training dependencies are already enabled for Java JNI library, so we skip adding additional guard flags.

Test plan

Updated XOR tests that check .pte only convergence workflow.

sh scripts/build_android_library.sh
sh executorch_android/android_test_setup.sh // Creates xor.ptd, xor.pte, and xor_full.pte files.

./gradlew :executorch_android:connectedAndroidTest // Added unit test to check toy model convergence loss < 0.01

For the XOR tests, the device logs will show convergence values:

I testTrainXOR: Step 0, Loss 0.683540, Input [1, 0], Prediction 1, Label 1
...
I testTrainXOR: Step 4500, Loss 0.000994, Input [0, 0], Prediction 0, Label 0

Copy link

pytorch-bot bot commented Jul 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/12247

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Cancelled Jobs

As of commit 50f5032 with merge base defa089 (image):

NEW FAILURE - The following job has failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 7, 2025
Copy link

github-actions bot commented Jul 7, 2025

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@georgehong georgehong force-pushed the gh/georgehong/training_jni_pte_only branch from a6e15a7 to c008409 Compare July 8, 2025 07:56
@georgehong georgehong changed the title Update and test JNI Training entrypoints slightly to allow for PTE-only workflows [RFC] Add TrainingModule and SGD JNI + PTE-only Training Workflow Jul 8, 2025
@georgehong georgehong force-pushed the gh/georgehong/training_jni_pte_only branch from c008409 to aba87ed Compare July 8, 2025 08:05
@georgehong georgehong requested a review from JacobSzwejbka July 8, 2025 08:05
@facebook-github-bot
Copy link
Contributor

@georgehong has imported this pull request. If you are a Meta employee, you can view this in D77939473.

facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor);
};

class JEValue : public facebook::jni::JavaClass<JEValue> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm does this not already exist for inference?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, these are just forward declarations referencing what exists in jni_layer.cpp so that it doesn't all have to be in a single file. The actual definitions are in that same file.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we just put them in a header?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, all of the JNI files apart from jni_layer_constants.h are in implementation files. Would this be suitable for a follow-up diff, since the refactor would further increase the size of this PR as well as the possible surfaces for updating the build files?

* @param nesterov Whether to use Nesterov momentum
* @return new {@link org.pytorch.executorch.SGD} object
*/
public static SGD create(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesnt have to be this diff but would it be more "java-y" to have builder classes?

new SGDBuilder().learning_rate().buildSGD();

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that sounds good - having an SGDBuilder() sounds like a great follow-up to me.

}

@DoNotStrip
private native EValue[] executeForwardBackwardNative(String methodName, EValue... inputs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noob q. What are these "native" apis?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each of these native methods maps to a C++ definition in the JNI jni_layer_training.cpp file.

  static void registerNatives() {
    registerHybrid({
        makeNativeMethod("initHybrid", ExecuTorchTrainingJni::initHybrid),
        makeNativeMethod(
            "executeForwardBackwardNative",
            ExecuTorchTrainingJni::executeForwardBackward),
        makeNativeMethod(
            "namedParametersNative", ExecuTorchTrainingJni::namedParameters),
        makeNativeMethod(
            "namedGradientsNative", ExecuTorchTrainingJni::namedGradients),
    });
  }
};

@georgehong georgehong force-pushed the gh/georgehong/training_jni_pte_only branch 3 times, most recently from d410f9d to 7557781 Compare July 8, 2025 20:02
@facebook-github-bot
Copy link
Contributor

@georgehong has imported this pull request. If you are a Meta employee, you can view this in D77939473.

#include <string>
#include <vector>

#include <fbjni/ByteBuffer.h>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens with these fbjni bindings in oss?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@georgehong georgehong Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're already using FBJNI in jni_layer_runtime.cpp:

#include <fbjni/fbjni.h>
#include <jni.h>

and in the Gradle build:

dependencies {
implementation 'com.facebook.fbjni:fbjni:0.5.1'
implementation 'com.facebook.soloader:nativeloader:0.10.5'
implementation libs.core.ktx
testImplementation 'junit:junit:4.12'

@georgehong georgehong requested a review from GregoryComer July 8, 2025 20:53
@georgehong georgehong changed the title [RFC] Add TrainingModule and SGD JNI + PTE-only Training Workflow Add TrainingModule and SGD JNI + PTE-only Training Workflow Jul 8, 2025
@GregoryComer
Copy link
Member

GregoryComer commented Jul 8, 2025

Would it be possible to split out the new Buck selective JNI dependencies so they're only pulled in when needed? There are some binary-size critical users currently in Meta apps.

@georgehong georgehong force-pushed the gh/georgehong/training_jni_pte_only branch from 7557781 to 5fdf9cd Compare July 9, 2025 06:57
@facebook-github-bot
Copy link
Contributor

@georgehong has imported this pull request. If you are a Meta employee, you can view this in D77939473.

@georgehong georgehong force-pushed the gh/georgehong/training_jni_pte_only branch from 5fdf9cd to b646e29 Compare July 9, 2025 08:08
@facebook-github-bot
Copy link
Contributor

@georgehong has imported this pull request. If you are a Meta employee, you can view this in D77939473.

As title, adds wrappers together with unit test based on XOR train.cpp example.
Address comment on JNI binary size sensitivity. Rather than adding to the existing
JNI buck targets, initially introduce a new executorch_training_jni target. Using
EXECUTORCH_BUILD_TRAINING_JNI to further modularize JNI build
@georgehong georgehong force-pushed the gh/georgehong/training_jni_pte_only branch from b646e29 to 50f5032 Compare July 9, 2025 18:56
@facebook-github-bot
Copy link
Contributor

@georgehong has imported this pull request. If you are a Meta employee, you can view this in D77939473.

@georgehong georgehong marked this pull request as ready for review July 9, 2025 20:35
@georgehong georgehong requested a review from swolchok as a code owner July 9, 2025 20:35
@georgehong
Copy link
Contributor Author

georgehong commented Jul 9, 2025

Split training components into separate BUCK target and added training JNI flag to make this JNI modular to address specific binary size concerns.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants