This repository implements a Diffusion Transformer (DiT) for long-term time series forecasting. Unlike standard diffusion models that start from random Gaussian noise, our approach leverages multi-scale moving average as a structured corruption process. The key insight is that exponential moving averages (EMA) with varying smoothing parameters create a natural hierarchy from heavily smoothed signals (capturing trends) to the original signal (containing fine-grained details).
The core architecture is a Diffusion Transformer that iteratively refines predictions from a smoothed state to the target signal. The model consists of the following components:
Given an input historical sequence $\mathbf{x} \in \mathbb{R}^{T_{in} \times C}$ and a prediction target $\mathbf{y} \in \mathbb{R}^{T_{pred} \times C}$, both are linearly projected into the model dimension:

$$\mathbf{h}_x = \mathbf{x}\,\mathbf{W}_x^\top, \qquad \mathbf{h}_y = \mathbf{y}\,\mathbf{W}_y^\top,$$

where $\mathbf{W}_x, \mathbf{W}_y \in \mathbb{R}^{d_{model} \times C}$ are learnable projection matrices.
The smoothing level $\alpha$ plays the role of the diffusion timestep and is first mapped to a timestep embedding. This embedding is then passed through an MLP to produce the conditioning vector $\mathbf{c}$ used by the adaptive layer normalization in each DiT block.
We apply RoPE (rotary position embeddings) to encode positional information in the attention mechanism. For position $m$, each pair of query/key channels $(2i, 2i{+}1)$ is rotated by the angle $m\theta_i$, where the frequencies $\theta_i$ follow the standard RoPE schedule (typically $\theta_i = 10000^{-2i/d}$). The rotation is applied to query and key vectors:

$$\hat{\mathbf{q}}_m = \mathbf{R}_{\Theta, m}\,\mathbf{q}_m, \qquad \hat{\mathbf{k}}_n = \mathbf{R}_{\Theta, n}\,\mathbf{k}_n,$$

where $\mathbf{R}_{\Theta, m}$ is the block-diagonal rotation matrix built from the angles $m\theta_i$, so the attention score $\hat{\mathbf{q}}_m^\top \hat{\mathbf{k}}_n$ depends only on the relative position $m - n$.
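As a point of reference, here is a compact sketch of applying rotary embeddings to queries and keys. It is illustrative only: the base frequency of 10000 and the channel pairing follow the common convention and may differ from the repository's implementation.

```python
import torch

def apply_rope(q: torch.Tensor, k: torch.Tensor, base: float = 10000.0):
    """Rotate query/key channel pairs by position-dependent angles.

    q, k: (B, H, T, D) tensors with an even head dimension D.
    """
    B, H, T, D = q.shape
    inv_freq = base ** (-torch.arange(0, D, 2, dtype=torch.float32, device=q.device) / D)  # (D/2,)
    angles = torch.arange(T, dtype=torch.float32, device=q.device)[:, None] * inv_freq     # (T, D/2)
    cos, sin = angles.cos(), angles.sin()  # broadcast over batch and heads

    def rotate(x):
        x1, x2 = x[..., 0::2], x[..., 1::2]  # even/odd channel pairs
        return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)

    return rotate(q), rotate(k)
```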
Each DiT block consists of three sub-layers with adaptive layer normalization (AdaLN):
Adaptive Layer Normalization:
Given the timestep embedding $\mathbf{c}$ (produced from the smoothing level $\alpha$), a linear layer predicts per-channel scale and shift parameters that modulate the normalized activations:

$$\mathrm{AdaLN}(\mathbf{h}, \mathbf{c}) = \gamma(\mathbf{c}) \odot \mathrm{LayerNorm}(\mathbf{h}) + \beta(\mathbf{c}).$$
Self-Attention with AdaLN:

$$\mathbf{h} \leftarrow \mathbf{h} + \mathrm{SelfAttn}\big(\mathrm{AdaLN}(\mathbf{h}, \mathbf{c})\big)$$
Cross-Attention:

$$\mathbf{h} \leftarrow \mathbf{h} + \mathrm{CrossAttn}\big(Q = \mathbf{h},\; K = V = \mathbf{h}_x\big),$$

where queries come from the prediction sequence and keys/values come from the historical context.
Feed-Forward Network with AdaLN:

$$\mathbf{h} \leftarrow \mathbf{h} + \mathrm{FFN}\big(\mathrm{AdaLN}(\mathbf{h}, \mathbf{c})\big)$$

The MLP uses GELU activation:

$$\mathrm{FFN}(\mathbf{z}) = \mathbf{W}_2\,\mathrm{GELU}(\mathbf{W}_1 \mathbf{z} + \mathbf{b}_1) + \mathbf{b}_2,$$

with a hidden width of `mlp_ratio` times $d_{model}$.
The output projection applies final adaptive normalization:

$$\hat{\mathbf{y}} = \mathrm{AdaLN}(\mathbf{h}, \mathbf{c})\,\mathbf{W}_{out},$$

where $\mathbf{W}_{out}$ maps the hidden representation back to the $C$ output channels.
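To make the block structure concrete, here is a minimal PyTorch sketch of one DiT block following the description above. It is illustrative only: the module names, the gating-free AdaLN, and the use of `nn.MultiheadAttention` are simplifications, and RoPE is omitted; the actual implementation lives in models/Ours.py.

```python
import torch
import torch.nn as nn

class AdaLN(nn.Module):
    """LayerNorm whose scale/shift are predicted from the conditioning vector."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_scale_shift = nn.Linear(dim, 2 * dim)

    def forward(self, h, c):
        scale, shift = self.to_scale_shift(c).unsqueeze(1).chunk(2, dim=-1)
        return self.norm(h) * (1 + scale) + shift

class DiTBlock(nn.Module):
    """Self-attention -> cross-attention -> feed-forward, each with a residual path."""
    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.adaln1 = AdaLN(dim)
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.adaln2 = AdaLN(dim)
        hidden = int(dim * mlp_ratio)
        self.ffn = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, h, h_x, c):
        # h:   (B, T_pred, dim)  prediction-sequence tokens
        # h_x: (B, T_in,  dim)   historical-context tokens
        # c:   (B, dim)          embedding of the smoothing level alpha
        z = self.adaln1(h, c)
        h = h + self.self_attn(z, z, z, need_weights=False)[0]
        h = h + self.cross_attn(h, h_x, h_x, need_weights=False)[0]
        h = h + self.ffn(self.adaln2(h, c))
        return h
```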
Unlike standard diffusion that uses Gaussian noise, we use multi-scale moving average (MA) as the corruption process. This creates a structured denoising path from heavily smoothed signals to the original signal.
For a sequence of length $T_{pred}$, we compute moving averages at every kernel size from $1$ to $T_{pred}$. For kernel size $k$, each value is replaced by the average over a window of $k$ neighboring points, so larger kernels yield progressively smoother signals. The columns are then interpolated to match the sequence length, producing a square matrix of smoothed signals in $\mathbb{R}^{T_{pred} \times T_{pred}}$, ranging from the original signal (kernel size $1$) to the fully averaged, constant signal (kernel size $T_{pred}$). Given an alpha schedule $\{\alpha_0, \dots, \alpha_A\}$ with $\alpha_0 = 0$ and $\alpha_A = 1$, each level $\alpha_k$ selects the corresponding smoothed signal $\mathbf{y}^{(\alpha_k)}$ from this matrix, where $\alpha = 0$ recovers the original signal $\mathbf{y}$ and $\alpha = 1$ yields the fully smoothed mean $\bar{\mathbf{y}}$.
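A minimal sketch of this corruption operator, assuming average pooling followed by linear interpolation back to the original length (the repository's ComputeMA may differ in padding and interpolation details):

```python
import torch
import torch.nn.functional as F

def compute_ma_matrix(y: torch.Tensor) -> torch.Tensor:
    """Multi-scale moving-average corruption.

    y: (B, T, C) target sequence.
    Returns: (B, T, T, C) stack of smoothed signals, one per kernel size 1..T;
             level 0 is the original signal, level T-1 is (nearly) the global mean.
    """
    B, T, C = y.shape
    levels = [y]  # kernel size 1: the original signal
    for k in range(2, T + 1):
        # moving average with window k, then stretched back to length T
        smoothed = F.avg_pool1d(y.transpose(1, 2), kernel_size=k, stride=1)   # (B, C, T-k+1)
        smoothed = F.interpolate(smoothed, size=T, mode="linear", align_corners=True)
        levels.append(smoothed.transpose(1, 2))
    return torch.stack(levels, dim=1)  # (B, T, T, C): smoothing level x time step
```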
A key challenge arises during inference: when $\alpha = 1$, the fully smoothed signal is the mean of the future target, which is not available at prediction time.

Problem: At $\alpha = 1$, the corrupted signal equals the future mean $\bar{\mathbf{y}}$, so the reverse process would have to start from a quantity that is unknown at inference time.

Solution: We define the delta (drift) as the difference between the mean and the last observation:

$$\boldsymbol{\delta} = \bar{\mathbf{y}} - \mathbf{x}_{T_{in}},$$

where $\mathbf{x}_{T_{in}}$ is the last value of the context sequence. We then subtract a linearly interpolated drift term from the smoothed signal:

$$\tilde{\mathbf{y}}^{(\alpha)} = \mathbf{y}^{(\alpha)} - \alpha \cdot \boldsymbol{\delta}.$$
This ensures:
- At $\alpha = 0$: $\tilde{\mathbf{y}}^{(0)} = \mathbf{y}$ (original signal, no drift subtracted)
- At $\alpha = 1$: $\tilde{\mathbf{y}}^{(1)} = \bar{\mathbf{y}} - \boldsymbol{\delta} = \mathbf{x}_{T_{in}}$ (signal anchored to the last observation)
Interpretation: Instead of denoising from an unknown mean, the model learns to denoise from a known starting point (the last observation) by gradually adding the delta back over the reverse diffusion process. The model effectively learns the step-wise mapping

$$f_\theta\big(\tilde{\mathbf{y}}^{(\alpha_k)}, \mathbf{x}, \alpha_k\big) \approx \tilde{\mathbf{y}}^{(\alpha_{k-1})}.$$
Properties:
- $\alpha \approx 0$: near-identity transformation (original signal)
- $\alpha \approx 1$: signal anchored to the last observation $\mathbf{x}_{T_{in}}$
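Building on the `compute_ma_matrix` sketch above, the drift-corrected corruption can be expressed as follows (an illustrative helper, not the repository's API):

```python
def corrupt_with_drift(y: torch.Tensor, x_last: torch.Tensor, alpha: float) -> torch.Tensor:
    """Return the drift-corrected smoothed target at level alpha in [0, 1].

    y:      (B, T, C) future target sequence
    x_last: (B, 1, C) last observation of the historical context
    """
    ma = compute_ma_matrix(y)                       # (B, T, T, C)
    level = round(alpha * (ma.shape[1] - 1))        # alpha=0 -> original, alpha=1 -> mean
    y_alpha = ma[:, level]                          # (B, T, C) smoothed signal
    delta = y.mean(dim=1, keepdim=True) - x_last    # drift: future mean minus last observation
    return y_alpha - alpha * delta                  # anchored so that alpha=1 equals x_last
```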
Algorithm 1: Training with Step-wise Trajectory Supervision
```
Input:  Training data {(x, y)}, Model f_θ, Alpha schedule {α_0, ..., α_A}
Output: Trained model parameters θ

for each epoch do
  for each batch (x, y) do
    // Step 1: Compute EMA-corrupted targets for all alpha levels
    y_full ← concat(x[:, -1:], y)              // Prepend last observation
    {y^(α_k)}_{k=0}^{A} ← ComputeMA(y_full)    // Multi-scale MA

    // Step 2: Sample random starting alpha index
    k_start ~ Uniform(1, A)
    y_current ← y^(α_{k_start})

    // Step 3: Reverse trajectory from α_{k_start} to α_0
    L_traj ← 0
    for k = k_start down to 1 do
      y_current ← f_θ(y_current, x, α_k)       // One denoising step
      L_step ← ||y_current - y^(α_{k-1})||²    // Intermediate supervision
      L_traj ← L_traj + L_step
    end for
    L_traj ← L_traj / k_start

    // Step 4: Final prediction loss
    L_end ← ||y_current - y||²

    // Step 5: Total loss
    L ← λ_traj · L_traj + λ_end · L_end

    // Step 6: Update parameters
    θ ← θ - η∇_θ L
  end for
end for
```
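A condensed PyTorch sketch of one training step under Algorithm 1, reusing the illustrative `corrupt_with_drift` helper from above. It simplifies a few details (for example, it does not prepend the last observation before smoothing) and assumes the model takes `(y_current, x, alpha)`; the actual loop is in exp/exp_long_term_forecasting.py.

```python
import random
import torch

def training_step(model, x, y, alphas, lambda_traj=1.0, lambda_end=1.0):
    """One step of trajectory-supervised training.

    x: (B, T_in, C) history, y: (B, T_pred, C) target,
    alphas: list of smoothing levels [α_0 = 0, ..., α_A = 1].
    """
    targets = [corrupt_with_drift(y, x[:, -1:, :], a) for a in alphas]  # drift-corrected MA targets
    k_start = random.randint(1, len(alphas) - 1)       # random starting corruption level
    y_current = targets[k_start]

    loss_traj = 0.0
    for k in range(k_start, 0, -1):                    # reverse trajectory α_k -> α_0
        y_current = model(y_current, x, alphas[k])     # one denoising step
        loss_traj = loss_traj + torch.mean((y_current - targets[k - 1]) ** 2)
    loss_traj = loss_traj / k_start

    loss_end = torch.mean((y_current - y) ** 2)        # final prediction loss
    return lambda_traj * loss_traj + lambda_end * loss_end
```

The caller would then run the usual `loss.backward()` and optimizer step on the returned value.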
Algorithm 2: Iterative Refinement Sampling
```
Input:  Historical sequence x, Trained model f_θ, Alpha schedule {α_1, ..., α_A}
Output: Predicted sequence ŷ

// Step 1: Initialize prediction (from last observation)
ŷ ← repeat(x[:, -1], T_pred)    // Constant initialization

// Step 2: Reverse diffusion (from high α to low α)
for k = A down to 1 do
  ŷ ← f_θ(ŷ, x, α_k)            // Iterative refinement
end for
return ŷ
```
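The corresponding inference loop is only a few lines; the sketch below assumes the same `(y_current, x, alpha)` model signature and an alpha schedule that starts at α_0 = 0.

```python
import torch

@torch.no_grad()
def sample(model, x, alphas, pred_len):
    """Iteratively refine from a constant last-observation guess to the forecast."""
    y_hat = x[:, -1:, :].repeat(1, pred_len, 1)    # Mode 0: constant initialization
    for alpha in reversed(alphas[1:]):             # from α_A = 1 down to α_1
        y_hat = model(y_hat, x, alpha)             # one refinement step
    return y_hat
```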
Initialization Modes:
- Mode 0 (Default): Use last observation as constant: $\hat{\mathbf{y}}_0 = \mathbf{x}_{T_{in}} \cdot \mathbf{1}^\top$
- Mode 1 (Oracle): Use EMA of ground truth (for validation analysis)
- Mode 2 (Learned): Use a learned statistical predictor for the mean
- Structured Corruption: Moving average provides a semantically meaningful corruption path - from trends to details - unlike random Gaussian noise.
- Delta-Based Drift Correction: The key innovation is subtracting a linearly interpolated drift term $\alpha \cdot \boldsymbol{\delta}$ from the smoothed signal. This anchors the diffusion process at $\alpha=1$ to the last observation (which is known at inference time) rather than the unknown future mean. The model learns to progressively add the delta back during denoising, effectively predicting the deviation from the last observation.
- Known Starting Point: At inference time, we initialize from the last observation $\mathbf{x}_{T_{in}}$ and denoise toward the target. This is possible because the drift correction ensures that the corrupted signal at $\alpha=1$ equals the last observation, eliminating the need to know the future mean.
- Step-wise Supervision: Training with intermediate targets at each alpha level provides dense supervision, improving gradient flow and convergence.
- Cross-Attention Conditioning: Historical context guides the denoising process through cross-attention, allowing the model to leverage temporal patterns from the past.
| Parameter | Description | Default |
|---|---|---|
| `hidden_dim` | Transformer hidden dimension | 32 |
| `num_heads` | Number of attention heads | 4 |
| `num_dit_block` | Number of DiT blocks | 4 |
| `mlp_ratio` | MLP hidden dim multiplier | 4.0 |
| `interval` | Alpha step size (1/steps) | 0.01 |
| `seq_len` | Input sequence length | 96 |
| `pred_len` | Prediction horizon | 96 |
| `lambda_traj` | Trajectory loss weight | 1.0 |
| `lambda_end` | Final prediction loss weight | 1.0 |
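With the default interval of 0.01, the reverse process runs through 100 refinement steps. Assuming the schedule is uniform in $\alpha$ (which the `interval` parameter suggests), it can be built as:

```python
interval = 0.01
# 0.0, 0.01, ..., 1.0 -> 101 alpha levels, i.e. 100 refinement steps
alphas = [round(k * interval, 4) for k in range(int(round(1 / interval)) + 1)]
```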
Training:

```bash
python run.py \
  --task_name long_term_forecast \
  --is_training 1 \
  --model_id ETTh1_96_96 \
  --model Ours \
  --data ETTh1 \
  --root_path ./data/ETT/ \
  --data_path ETTh1.csv \
  --seq_len 96 \
  --pred_len 96 \
  --feature_dim 7 \
  --hidden_dim 32 \
  --num_heads 4 \
  --num_dit_block 4 \
  --interval 0.01 \
  --batch_size 16 \
  --learning_rate 0.001 \
  --train_epochs 100
```

Evaluation:

```bash
python run.py \
  --task_name long_term_forecast \
  --is_training 0 \
  --model_id ETTh1_96_96 \
  --model Ours \
  --data ETTh1 \
  --root_path ./data/ETT/ \
  --data_path ETTh1.csv \
  --seq_len 96 \
  --pred_len 96
```

Project structure:

```
├── models/
│   ├── Ours.py                        # DiT model implementation
│   ├── TimeMixer.py                   # Alternative baseline
│   └── __init__.py
├── exp/
│   ├── exp_long_term_forecasting.py   # Main training loop
│   └── global_loss.py                 # Alternative training approach
├── data_provider/
│   ├── data_factory.py
│   └── data_loader.py
├── layers/
│   ├── SelfAttention_Family.py
│   ├── Embed.py
│   └── Transformer_EncDec.py
├── utils/
│   ├── tools.py
│   ├── metrics.py
│   └── losses.py
├── scripts/
│   └── main_script.sh
└── run.py                             # Entry point
```
Forward Process (Corruption with Drift Correction):

$$\tilde{\mathbf{y}}^{(\alpha)} = \mathbf{y}^{(\alpha)} - \alpha \cdot \boldsymbol{\delta}, \qquad \boldsymbol{\delta} = \bar{\mathbf{y}} - \mathbf{x}_{T_{in}},$$

where $\mathbf{y}^{(\alpha)}$ is the multi-scale moving average of the target at smoothing level $\alpha$, $\bar{\mathbf{y}}$ is the mean of the future window, and $\mathbf{x}_{T_{in}}$ is the last observed value.

Reverse Process (Denoising):

Starting from the last observation at $\alpha_A = 1$, the model iteratively refines the prediction:

$$\hat{\mathbf{y}}^{(\alpha_{k-1})} = f_\theta\big(\hat{\mathbf{y}}^{(\alpha_k)}, \mathbf{x}, \alpha_k\big), \qquad k = A, A{-}1, \dots, 1.$$

Training Objective:

$$\mathcal{L} = \lambda_{traj} \cdot \mathcal{L}_{traj} + \lambda_{end} \cdot \mathcal{L}_{end},$$

where $\mathcal{L}_{traj}$ is the mean squared error against the intermediate drift-corrected targets along the reverse trajectory and $\mathcal{L}_{end}$ is the squared error between the final refined prediction and the ground-truth target $\mathbf{y}$.