Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
name: Test

on:
pull_request:
branches:
- main
push:
branches:
- main

jobs:
test:
runs-on: ubuntu-latest
env:
POETRY_VERSION: 1.8.1
steps:
#----------------------------------------------
# check-out repo and set-up python
#----------------------------------------------
- name: Check out repository
uses: actions/checkout@v4
- name: Set up python
id: setup-python
uses: actions/setup-python@v5
with:
python-version: '3.10'
#----------------------------------------------
# install & configure poetry
#----------------------------------------------
- name: Load cached Poetry installation
id: cached-poetry
uses: actions/cache@v4
with:
path: ~/.local # the path depends on the OS
key: poetry-${{ env.POETRY_VERSION }} # increment to reset cache
- name: Install Poetry
if: steps.cached-poetry.outputs.cache-hit != 'true'
uses: snok/install-poetry@v1
with:
version: ${{ env.POETRY_VERSION }}
virtualenvs-create: true
installer-parallel: true
- name: Configure Poetry
run: poetry config virtualenvs.in-project true
#----------------------------------------------
# load cached venv if cache exists
#----------------------------------------------
- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v4
with:
path: .venv
key: venv-${{ steps.setup-python.outputs.python-version }}-${{ env.POETRY_VERSION }}-${{ hashFiles('**/poetry.lock') }}
#----------------------------------------------
# install dependencies if cache does not exist
#----------------------------------------------
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: |
poetry install --no-interaction --no-root
git clone https://github.com/real-stanford/diffusion_policy
cp -r diffusion_policy/diffusion_policy $(poetry env info -p)/lib/python3.10/site-packages/
#----------------------------------------------
# install project
#----------------------------------------------
- name: Install project
run: poetry install --no-interaction
#----------------------------------------------
# run tests
#----------------------------------------------
- name: Test train pusht end-to-end
run: |
source .venv/bin/activate
python lerobot/scripts/train.py \
hydra.job.name=pusht \
env=pusht \
wandb.enable=False \
offline_steps=1 \
online_steps=0 \
device=cpu
# - name: Test eval pusht end-to-end
# run: |
# source .venv/bin/activate
# python lerobot/scripts/eval.py
# hydra.job.name=pusht \
# env=pusht \
# wandb.enable=False \
# eval_episodes=1 \
# device=cpu
#----------------------------------------------
# cleanup
#----------------------------------------------
- name: Cleanup
run: |
rm -rf diffusion_policy data
23 changes: 0 additions & 23 deletions environment.yaml

This file was deleted.

7 changes: 4 additions & 3 deletions lerobot/common/policies/diffusion/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import hydra
import torch
import torch.nn as nn

from diffusion_policy.model.common.lr_scheduler import get_scheduler

from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
Expand All @@ -15,6 +14,7 @@ class DiffusionPolicy(nn.Module):
def __init__(
self,
cfg,
cfg_device,
cfg_noise_scheduler,
cfg_rgb_model,
cfg_obs_encoder,
Expand Down Expand Up @@ -62,8 +62,9 @@ def __init__(
**kwargs,
)

self.device = torch.device("cuda")
self.diffusion.cuda()
self.device = torch.device(cfg_device)
if torch.cuda.is_available() and cfg_device == "cuda":
self.diffusion.cuda()

self.ema = None
if self.cfg.use_ema:
Expand Down
1 change: 1 addition & 0 deletions lerobot/common/policies/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def make_policy(cfg):

policy = DiffusionPolicy(
cfg=cfg.policy,
cfg_device=cfg.device,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
Expand Down
Loading