Skip to content

Add Differentiable Physics: Mass-Spring System example #1332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions differentiable_physics/mass_spring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import matplotlib.pyplot as plt
import os


class MassSpringSystem(nn.Module):
def __init__(self, num_particles, springs, mass=1.0, dt=0.01, gravity=9.81, device="cpu"):
super().__init__()
self.device = device
self.mass = mass
self.springs = springs
self.dt = dt
self.gravity = gravity

# Particle 0 is fixed at the origin
self.initial_position_0 = torch.tensor([0.0, 0.0], device=device)

# Remaining particles are trainable
self.initial_positions_rest = nn.Parameter(torch.randn(num_particles - 1, 2, device=device))

# Velocities
self.velocities = torch.zeros(num_particles, 2, device=device)

def forward(self, steps):
positions = torch.cat([self.initial_position_0.unsqueeze(0), self.initial_positions_rest], dim=0)
velocities = self.velocities

for _ in range(steps):
forces = torch.zeros_like(positions)

# Compute spring forces
for (i, j, rest_length, stiffness) in self.springs:
xi, xj = positions[i], positions[j]
dir_vec = xj - xi
dist = dir_vec.norm()
force = stiffness * (dist - rest_length) * dir_vec / (dist + 1e-6)
forces[i] += force
forces[j] -= force

# Apply gravity
forces[:, 1] -= self.gravity * self.mass

# Semi-implicit Euler integration
acceleration = forces / self.mass
velocities = velocities + acceleration * self.dt
positions = positions + velocities * self.dt

# Fix particle 0 at origin
positions[0] = self.initial_position_0
velocities[0] = torch.tensor([0.0, 0.0], device=positions.device)

return positions


def visualize_positions(initial, final, target, save_path="mass_spring_viz.png"):
plt.figure(figsize=(6, 4))
plt.scatter(initial[:, 0], initial[:, 1], c='blue', label='Initial', marker='x')
plt.scatter(final[:, 0], final[:, 1], c='green', label='Final', marker='o')
plt.scatter(target[:, 0], target[:, 1], c='red', label='Target', marker='*')
plt.title("Mass-Spring System Positions")
plt.xlabel("X")
plt.ylabel("Y")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig(save_path)
print(f"Saved visualization to {os.path.abspath(save_path)}")
plt.close()


def train(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
system = MassSpringSystem(
num_particles=args.num_particles,
springs=[(0, 1, 1.0, args.stiffness)],
mass=args.mass,
dt=args.dt,
gravity=args.gravity,
device=device,
)

optimizer = optim.Adam(system.parameters(), lr=args.lr)
target_positions = torch.tensor(
[[0.0, 0.0], [1.0, 0.0]], device=device
)

for epoch in range(args.epochs):
optimizer.zero_grad()
final_positions = system(args.steps)
loss = (final_positions - target_positions).pow(2).mean()
loss.backward()
optimizer.step()

if (epoch + 1) % args.log_interval == 0:
print(f"Epoch {epoch+1}/{args.epochs}, Loss: {loss.item():.6f}")

# Visualization
initial_positions = torch.cat([system.initial_position_0.unsqueeze(0), system.initial_positions_rest.detach()], dim=0).cpu().numpy()
visualize_positions(initial_positions, final_positions.detach().cpu().numpy(), target_positions.cpu().numpy())

print("\nTraining completed.")
print(f"Final positions:\n{final_positions.detach().cpu().numpy()}")
print(f"Target positions:\n{target_positions.cpu().numpy()}")


def evaluate(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
system = MassSpringSystem(
num_particles=args.num_particles,
springs=[(0, 1, 1.0, args.stiffness)],
mass=args.mass,
dt=args.dt,
gravity=args.gravity,
device=device,
)

with torch.no_grad():
final_positions = system(args.steps)
print(f"Final positions after {args.steps} steps:\n{final_positions.cpu().numpy()}")


def parse_args():
parser = argparse.ArgumentParser(description="Differentiable Physics: Mass-Spring System")
parser.add_argument("--epochs", type=int, default=1000, help="Number of training epochs")
parser.add_argument("--steps", type=int, default=50, help="Number of simulation steps per forward pass")
parser.add_argument("--lr", type=float, default=0.01, help="Learning rate")
parser.add_argument("--dt", type=float, default=0.01, help="Time step for integration")
parser.add_argument("--mass", type=float, default=1.0, help="Mass of each particle")
parser.add_argument("--stiffness", type=float, default=10.0, help="Spring stiffness constant")
parser.add_argument("--num_particles", type=int, default=2, help="Number of particles in the system")
parser.add_argument("--mode", choices=["train", "eval"], default="train", help="Mode: train or eval")
parser.add_argument("--log_interval", type=int, default=100, help="Print loss every n epochs")
parser.add_argument("--gravity", type=float, default=9.81, help="Gravity strength")
return parser.parse_args()


def main():
args = parse_args()

if args.mode == "train":
train(args)
elif args.mode == "eval":
evaluate(args)


if __name__ == "__main__":
main()
Binary file added differentiable_physics/mass_spring_viz.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
42 changes: 42 additions & 0 deletions differentiable_physics/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Differentiable Physics: Mass-Spring System

This example demonstrates a simple differentiable mass-spring system using PyTorch.

Particles are connected by springs and evolve under the forces exerted by the springs and gravity.
The system is fully differentiable, allowing the optimization of particle positions to match a target configuration using gradient-based learning.

---

## Files

- `mass_spring.py` — Implements the mass-spring simulation, training loop, and evaluation.
- `README.md` — Usage instructions and description.

---

## Requirements

- Python 3.8+
- PyTorch

No external dependencies are required apart from PyTorch.

---

## Usage

First, ensure PyTorch is installed.

### Train the system

```bash
python mass_spring.py --mode train

## Visualization

<p align="center">
<img src="mass_spring_viz.png" width="400"/>
</p>

This plot shows the learned alignment of a 2-particle spring system with its target configuration.

1 change: 1 addition & 0 deletions differentiable_physics/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
torch
56 changes: 21 additions & 35 deletions run_python_examples.sh
Original file line number Diff line number Diff line change
@@ -1,40 +1,27 @@
#!/bin/bash
#
# This script runs through the code in each of the python examples.
# The purpose is just as an integration test, not to actually train models in any meaningful way.
# The purpose is just as an integration test, not to actually train models in any meaningful way.
# For that reason, most of these set epochs = 1 and --dry-run.
#
# Optionally specify a comma separated list of examples to run. Can be run as:
# * To run all examples:
# To run all examples:
# ./run_python_examples.sh
# * To run few specific examples:
# ./run_python_examples.sh "dcgan,fast_neural_style"
#
# To test examples on CUDA accelerator, run as:
# USE_CUDA=True ./run_python_examples.sh
# To run specific examples:
# ./run_python_examples.sh "dcgan,fast_neural_style"
#
# To test examples on hardware accelerator (CUDA, MPS, XPU, etc.), run as:
# USE_ACCEL=True ./run_python_examples.sh
# NOTE: USE_ACCEL relies on torch.accelerator API and not all examples are converted
# to use it at the moment. Thus, expect failures using this flag on non-CUDA accelerators
# and consider to run examples one by one.
# USE_CUDA=True ./run_python_examples.sh → for CUDA
# USE_ACCEL=True ./run_python_examples.sh → for any accelerator (CUDA/MPS/XPU)
#
# Script requires uv to be installed. When executed, script will install prerequisites from
# `requirements.txt` for each example. If ran within activated virtual environment (uv venv,
# python -m venv, conda) this might reinstall some of the packages. To change pip installation
# index or to pass additional pip install options, run as:
# PIP_INSTALL_ARGS="--pre -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html" \
# ./run_python_examples.sh
# To use a custom pip install source:
# PIP_INSTALL_ARGS="--pre -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html" ./run_python_examples.sh
#
# To force script to create virtual environment for each example, run as:
# To force venv per example:
# VIRTUAL_ENV=".venv" ./run_python_examples.sh
# Script will remove environments it creates in a teardown step after execution of each example.

BASE_DIR="$(pwd)/$(dirname $0)"
source $BASE_DIR/utils.sh

# TODO: Leave only USE_ACCEL and drop USE_CUDA once all examples will be converted
# to torch.accelerator API. For now, just add USE_ACCEL as an alias for USE_CUDA.
if [ -n "$USE_ACCEL" ]; then
USE_CUDA=$USE_ACCEL
fi
Expand All @@ -53,7 +40,7 @@ case $USE_CUDA in
ACCEL_FLAG=""
;;
"")
exit 1;
exit 1
;;
esac

Expand All @@ -67,7 +54,6 @@ function fast_neural_style() {
uv run download_saved_models.py
fi
test -d "saved_models" || { error "saved models not found"; return; }

echo "running fast neural style model"
uv run neural_style/neural_style.py eval --content-image images/content-images/amber.jpg --model saved_models/candy.pth --output-image images/output-images/amber-candy.jpg $ACCEL_FLAG || error "neural_style.py failed"
}
Expand All @@ -92,10 +78,11 @@ function language_translation() {
function mnist() {
uv run main.py --epochs 1 --dry-run || error "mnist example failed"
}

function mnist_forward_forward() {
uv run main.py --epochs 1 --no_accel || error "mnist forward forward failed"

}

function mnist_hogwild() {
uv run main.py --epochs 1 --dry-run $CUDA_FLAG || error "mnist hogwild failed"
}
Expand All @@ -119,13 +106,12 @@ function reinforcement_learning() {

function snli() {
echo "installing 'en' model if not installed"
uv run -m spacy download en || { error "couldn't download 'en' model needed for snli"; return; }
uv run -m spacy download en || { error "couldn't download 'en' model needed for snli"; return; }
echo "training..."
uv run train.py --epochs 1 --dev_every 1 --no-bidirectional --dry-run || error "couldn't train snli"
}

function fx() {
# uv run custom_tracer.py || error "fx custom tracer has failed" UnboundLocalError: local variable 'tabulate' referenced before assignment
uv run invert.py || error "fx invert has failed"
uv run module_tracer.py || error "fx module tracer has failed"
uv run primitive_library.py || error "fx primitive library has failed"
Expand All @@ -140,7 +126,7 @@ function super_resolution() {
}

function time_sequence_prediction() {
uv run generate_sine_wave.py || { error "generate sine wave failed"; return; }
uv run generate_sine_wave.py || { error "generate sine wave failed"; return; }
uv run train.py --steps 2 || error "time sequence prediction training failed"
}

Expand All @@ -164,6 +150,12 @@ function gat() {
uv run main.py --epochs 1 --dry-run || error "graph attention network failed"
}

function differentiable_physics() {
pushd differentiable_physics
python -m uv run mass_spring.py --mode train --epochs 5 --steps 3 || error "differentiable_physics example failed"
popd
}

eval "base_$(declare -f stop)"

function stop() {
Expand Down Expand Up @@ -196,12 +188,9 @@ function stop() {
}

function run_all() {
# cpp moved to `run_cpp_examples.sh```
run dcgan
# distributed moved to `run_distributed_examples.sh`
run fast_neural_style
run imagenet
# language_translation
run mnist
run mnist_forward_forward
run mnist_hogwild
Expand All @@ -212,14 +201,13 @@ function run_all() {
run super_resolution
run time_sequence_prediction
run vae
# vision_transformer - example broken see https://github.com/pytorch/examples/issues/1184 and https://github.com/pytorch/examples/pull/1258 for more details
run word_language_model
run fx
run gcn
run gat
run differentiable_physics
}

# by default, run all examples
if [ "" == "$EXAMPLES" ]; then
run_all
else
Expand All @@ -236,7 +224,5 @@ if [ "" == "$ERRORS" ]; then
else
echo "Some python examples failed:"
printf "$ERRORS\n"
#Exit with error (0-255) in case of failure in one of the tests.
exit 1

fi