@@ -23,6 +23,22 @@
 
 
 class ValueGuidedRLPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for sampling actions from a diffusion model trained to predict sequences of states.
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.).
+
+    Original implementation inspired by this repository: https://github.com/jannerm/diffuser.
+
+    Parameters:
+        value_function ([`UNet1DModel`]): A specialized UNet for fine-tuning trajectories based on reward.
+        unet ([`UNet1DModel`]): U-Net architecture to denoise the encoded trajectories.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded trajectories. Default for this
+            application is [`DDPMScheduler`].
+        env: An environment following the OpenAI gym API to act in. For now, only Hopper has pretrained models.
+    """
+
     def __init__(
         self,
         value_function: UNet1DModel,
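For orientation, a minimal usage sketch of the pipeline documented above. The checkpoint paths, the environment id, and the classic single-return `env.reset()` are assumptions for illustration only; the constructor arguments and the call parameters (`planning_horizon`, `n_guide_steps`) come from the docstring and the `__call__` signature in the hunks below.

import gym
from diffusers import DDPMScheduler, UNet1DModel

# placeholder environment and placeholder weight locations -- substitute real ones
env = gym.make("Hopper-v3")
value_function = UNet1DModel.from_pretrained("path/to/value_function")  # assumed path
unet = UNet1DModel.from_pretrained("path/to/unet")                      # assumed path
scheduler = DDPMScheduler()

pipeline = ValueGuidedRLPipeline(
    value_function=value_function,
    unet=unet,
    scheduler=scheduler,
    env=env,
)

obs = env.reset()  # classic gym API returns only the observation
# sample one value-guided action for the current observation
action = pipeline(obs, planning_horizon=32, n_guide_steps=2)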
@@ -78,21 +94,26 @@ def run_diffusion(self, x, conditions, n_guide_steps, scale):
             for _ in range(n_guide_steps):
                 with torch.enable_grad():
                     x.requires_grad_()
+
+                    # permute to match dimension for pre-trained models
                     y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                     grad = torch.autograd.grad([y.sum()], [x])[0]
 
                     posterior_variance = self.scheduler._get_variance(i)
                     model_std = torch.exp(0.5 * posterior_variance)
                     grad = model_std * grad
+
                 grad[timesteps < 2] = 0
                 x = x.detach()
                 x = x + scale * grad
                 x = self.reset_x0(x, conditions, self.action_dim)
+
             prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
-            # TODO: set prediction_type when instantiating the model
+
+            # TODO: verify deprecation of this kwarg
            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
 
-            # apply conditions to the trajectory
+            # apply conditions to the trajectory (set the initial state)
             x = self.reset_x0(x, conditions, self.action_dim)
             x = self.to_torch(x)
         return x, y
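The hunk above is the value-guidance step: each guide iteration climbs the gradient of the value model, scaled by the scheduler's posterior standard deviation and a user-chosen scale, before the regular denoising step. A stripped-down sketch of just that update, with a generic differentiable value model and illustrative names (not the pipeline's actual method):

import torch

def guided_update(x, value_model, timesteps, sigma, scale, n_guide_steps=2):
    # Gradient-ascent refinement of a noisy trajectory `x` toward higher predicted value,
    # mirroring the loop above: nudge x along grad of the value, scaled by sigma and scale.
    for _ in range(n_guide_steps):
        with torch.enable_grad():
            x.requires_grad_()
            y = value_model(x, timesteps)  # illustrative call; real models may expect permuted inputs
            grad = torch.autograd.grad([y.sum()], [x])[0]
        grad = sigma * grad
        grad[timesteps < 2] = 0            # skip guidance on the nearly-denoised final steps
        x = x.detach() + scale * grad
    return x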
@@ -126,5 +147,6 @@ def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, sca
         else:
             # if we didn't run value guiding, select a random action
             selected_index = np.random.randint(0, batch_size)
+
         denorm_actions = denorm_actions[selected_index, 0]
         return denorm_actions
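Since `__call__` returns a single denormalized action, it slots directly into a control loop. A short rollout sketch under the same assumptions as above (classic gym step API, a `pipeline` and `env` built as in the earlier example):

obs = env.reset()
total_reward = 0.0
for _ in range(100):
    action = pipeline(obs, planning_horizon=32, n_guide_steps=2)
    obs, reward, done, info = env.step(action)  # gym-style 4-tuple, as the docstring assumes
    total_reward += reward
    if done:
        break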