Skip to content

Commit 2baed6c

Browse files
committed
final commit
1 parent 2f551c3 commit 2baed6c

File tree

149 files changed

+110702
-2
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

149 files changed

+110702
-2
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@ __pycache__/
33
*.py[cod]
44
*$py.class
55

6+
./output/*
7+
./datasets/raw/*
8+
./datasets/processed/*
9+
./trained_models/*
10+
611
# C extensions
712
*.so
813

Inference Example.ipynb

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "553b4359",
6+
"metadata": {},
7+
"source": [
8+
"# Perform inference on your data with SkeletonDiffusion\n",
9+
"\n",
10+
"If your data are in the same skeleton format as our trained model, you can perform inference from your data.\n",
11+
"Give a sequence of keypoints representing the past, and run SkeletonDiffusion to predict future motions!\n",
12+
"\n",
13+
"SkeletonDiffusion can run on the output of other models, for example methods for human pose estimation from images or video.\n",
14+
"For an example, check out our demo on Huggingface."
15+
]
16+
},
17+
{
18+
"cell_type": "markdown",
19+
"id": "b1f698f1",
20+
"metadata": {},
21+
"source": [
22+
"## Select model and data type\n",
23+
"Here we take as an example our model trained on AMASS, which follows the same parametrization (skeleton format) as SMPL."
24+
]
25+
},
26+
{
27+
"cell_type": "code",
28+
"execution_count": 1,
29+
"id": "65d06fd4",
30+
"metadata": {},
31+
"outputs": [],
32+
"source": [
33+
"# choose between 'amass' and 'amass-mano'\n",
34+
"# model_dataset = 'amass' \n",
35+
"model_dataset = 'amass-mano'\n",
36+
"\n",
37+
"checkpoint_path = f'./trained_models/hmp/{model_dataset}/diffusion/checkpoints/cvpr_release.pt'\n",
38+
"num_samples = 50"
39+
]
40+
},
41+
{
42+
"cell_type": "markdown",
43+
"id": "78b643df",
44+
"metadata": {},
45+
"source": [
46+
"## Load model's weights"
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": 2,
52+
"id": "e927d3c2",
53+
"metadata": {},
54+
"outputs": [],
55+
"source": [
56+
"import os\n",
57+
"import torch\n",
58+
"import numpy as np\n",
59+
"import random\n",
60+
"\n",
61+
"os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
62+
"\n",
63+
"from src.eval_prepare_model import prepare_model, get_prediction, load_model_config_exp\n",
64+
"from src.data import create_skeleton\n",
65+
"\n",
66+
"\n",
67+
"def set_seed(seed=0):\n",
68+
" torch.use_deterministic_algorithms(True)\n",
69+
" torch.backends.cudnn.deterministic = True\n",
70+
" torch.backends.cudnn.benchmark = False\n",
71+
" np.random.seed(seed)\n",
72+
" random.seed(seed)\n",
73+
" torch.cuda.manual_seed(seed)\n",
74+
" torch.cuda.manual_seed_all(seed)\n"
75+
]
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": 3,
80+
"id": "c6ce7b0a",
81+
"metadata": {},
82+
"outputs": [
83+
{
84+
"name": "stdout",
85+
"output_type": "stream",
86+
"text": [
87+
"> GPU 0 ready: Quadro RTX 5000\n",
88+
"> GPU 1 ready: Quadro P400\n",
89+
"Loading Autoencoder checkpoint: ./trained_models/hmp/amass-mano/autoencoder/checkpoints/cvpr_release.pt ...\n",
90+
"Diffusion is_ddim_sampling: False\n",
91+
"Loading Diffusion checkpoint: ./trained_models/hmp/amass-mano/diffusion/checkpoints/cvpr_release.pt ...\n"
92+
]
93+
}
94+
],
95+
"source": [
96+
"set_seed(seed=0)\n",
97+
"\n",
98+
"config, exp_folder = load_model_config_exp(checkpoint_path)\n",
99+
"config['checkpoint_path'] = checkpoint_path\n",
100+
"skeleton = create_skeleton(**config) \n",
101+
"\n",
102+
"\n",
103+
"model, device, *_ = prepare_model(config, skeleton, **config)"
104+
]
105+
},
106+
{
107+
"cell_type": "markdown",
108+
"id": "ae50633b",
109+
"metadata": {},
110+
"source": [
111+
"## Load given example or use your own"
112+
]
113+
},
114+
{
115+
"cell_type": "code",
116+
"execution_count": null,
117+
"id": "5c4aa1a7",
118+
"metadata": {},
119+
"outputs": [],
120+
"source": [
121+
"# prepare input\n",
122+
"# load input. Unit must be in meters\n",
123+
"# obs sequence contains the hip or root joint, it has not been dropped yet. \n",
124+
"obs = np.load(f'figures/example_obs_{model_dataset}.npy') # (t_past, J, 3)\n",
125+
"\n",
126+
"obs = torch.from_numpy(obs).to(device).float()\n",
127+
"# obs = obs.unsqueeze(0) # remember to add batch size if not present"
128+
]
129+
},
130+
{
131+
"cell_type": "code",
132+
"execution_count": 5,
133+
"id": "e782b749",
134+
"metadata": {},
135+
"outputs": [],
136+
"source": [
137+
"obs_in = skeleton.tranform_to_input_space(obs) \n",
138+
"pred = get_prediction(obs_in, model, num_samples=num_samples, **config) # [batch_size, n_samples, seq_length, num_joints, features]\n",
139+
"pred = skeleton.transform_to_metric_space(pred)"
140+
]
141+
},
142+
{
143+
"cell_type": "code",
144+
"execution_count": 6,
145+
"id": "f91ee849",
146+
"metadata": {},
147+
"outputs": [],
148+
"source": [
149+
"# Proceed to your own task.\n",
150+
"# For example, you can rank the output by the one with least limb stretching.\n",
151+
"# Checkout other metrics in src.metrics \n",
152+
"# or the diversity ranking in metrics/utils.py (see example in other notebook)\n",
153+
"\n",
154+
"\n",
155+
"from src.metrics.body_realism import limb_stretching_normed_mean\n",
156+
"\n",
157+
"\n",
158+
"limbstretching = limb_stretching_normed_mean(pred, target=obs[..., 1:, :][0].unsqueeze(1), limbseq=skeleton.get_limbseq(), reduction='persample', obs_as_target=True)\n",
159+
"limbstretching_sorted, indices = torch.sort(limbstretching.squeeze(1), dim=-1, descending=False) \n"
160+
]
161+
}
162+
],
163+
"metadata": {
164+
"kernelspec": {
165+
"display_name": "skeldiff4",
166+
"language": "python",
167+
"name": "python3"
168+
},
169+
"language_info": {
170+
"codemirror_mode": {
171+
"name": "ipython",
172+
"version": 3
173+
},
174+
"file_extension": ".py",
175+
"mimetype": "text/x-python",
176+
"name": "python",
177+
"nbconvert_exporter": "python",
178+
"pygments_lexer": "ipython3",
179+
"version": "3.10.13"
180+
}
181+
},
182+
"nbformat": 4,
183+
"nbformat_minor": 5
184+
}

LICENSE

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
BSD 2-Clause License
2+
3+
Copyright (c) 2025, Cecilia Curreli
4+
5+
Redistribution and use in source and binary forms, with or without
6+
modification, are permitted provided that the following conditions are met:
7+
8+
1. Redistributions of source code must retain the above copyright notice, this
9+
list of conditions and the following disclaimer.
10+
11+
2. Redistributions in binary form must reproduce the above copyright notice,
12+
this list of conditions and the following disclaimer in the documentation
13+
and/or other materials provided with the distribution.
14+
15+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
16+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
19+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
21+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
22+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
23+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
24+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

0 commit comments

Comments
 (0)