Add DeiT Model #2203

Merged
Merged 26 commits on May 30, 2025

Changes from 1 commit (of 26 commits):
da66882
Add deit model basis to keras hub, its basic utilities, backbone, la…
Sohaib-Ahmed21 Mar 28, 2025
aa4affa
update conversion script
Sohaib-Ahmed21 Mar 29, 2025
5e16661
Add DeiT model with backbone, layers, tests.
Sohaib-Ahmed21 Apr 3, 2025
ab0175a
Removed print statements.
Sohaib-Ahmed21 Apr 4, 2025
616acd6
Update deit_backbone_test.py to include exact shape
Sohaib-Ahmed21 Apr 5, 2025
7071e1e
Resolved failing test cases.
Sohaib-Ahmed21 Apr 17, 2025
7b027a4
Fix failing test cases
Sohaib-Ahmed21 Apr 18, 2025
46f3b4a
Merge branch 'keras-team:master' into deit
Sohaib-Ahmed21 Apr 19, 2025
b102d37
Solve jax failing tests
Sohaib-Ahmed21 Apr 22, 2025
5c138b0
Merge branch 'master' into deit
Sohaib-Ahmed21 Apr 23, 2025
12d5fec
Add checkpoint conversion, presets and customize image converter to f…
Sohaib-Ahmed21 Apr 24, 2025
a7bac11
Clean model code for input consistency
Sohaib-Ahmed21 Apr 25, 2025
c3c03ce
Merge branch 'keras-team:master' into deit
Sohaib-Ahmed21 Apr 27, 2025
73651fc
Refactor api imports to fix pre-commit api_gen tests
Sohaib-Ahmed21 Apr 27, 2025
49b80af
Merge branch 'keras-team:master' into deit
Sohaib-Ahmed21 May 1, 2025
a835e8c
Allow accepting non-square images
Sohaib-Ahmed21 May 1, 2025
2f1464c
Make DeiTImageConverter empty to match other classifiers' converters …
Sohaib-Ahmed21 May 1, 2025
c5b7454
Enable quantization check in backbone test and validate params in che…
Sohaib-Ahmed21 May 1, 2025
da13993
Update deit_presets.py to include underscore in names
Sohaib-Ahmed21 May 12, 2025
0f33986
Update preset_loader.py to properly configure num_classes for image c…
Sohaib-Ahmed21 May 12, 2025
6601bbd
Merge branch 'keras-team:master' into deit
Sohaib-Ahmed21 May 12, 2025
4359c14
Merge branch 'keras-team:master' into deit
Sohaib-Ahmed21 May 14, 2025
f3b1fdc
Update convert_deit_checkpoints.py to correctly include scale and off…
Sohaib-Ahmed21 May 19, 2025
db6a7e7
Slight change for syntax correction
Sohaib-Ahmed21 May 19, 2025
ae1e660
Merge branch 'keras-team:master' into deit
Sohaib-Ahmed21 May 19, 2025
3f648ea
Reformat code
Sohaib-Ahmed21 May 19, 2025
Add deit model basis to keras hub, its basic utilities, backbone, layers.
Sohaib-Ahmed21 committed Mar 28, 2025
commit da66882886ecc522153383327f0bf350c56ae42e
3 changes: 3 additions & 0 deletions keras_hub/api/layers/__init__.py
@@ -71,6 +71,9 @@
    RetinaNetImageConverter,
)
from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
from keras_hub.src.models.deit.deit_image_converter import (
    DeiTImageConverter,
)
from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
from keras_hub.src.models.segformer.segformer_image_converter import (
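With this export in place, the converter becomes importable from the public `keras_hub.layers` namespace rather than the `src` path. A minimal sketch of that usage (illustrative values; it assumes the standard `ImageConverter` resize/rescale arguments shown later in this diff):

```python
import numpy as np

from keras_hub.layers import DeiTImageConverter  # public path added by this diff

# Resize to 224x224 and rescale pixel values to [0, 1].
converter = DeiTImageConverter(image_size=(224, 224), scale=1 / 255.0)
images = np.random.randint(0, 256, size=(2, 300, 400, 3)).astype("float32")
print(converter(images).shape)  # expected: (2, 224, 224, 3)
```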
3 changes: 3 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -108,6 +108,9 @@
from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import (
    DenseNetImageClassifierPreprocessor,
)
from keras_hub.src.models.deit.deit_backbone import (
    DeiTBackbone,
)
from keras_hub.src.models.distil_bert.distil_bert_backbone import (
    DistilBertBackbone,
)
Empty file.
136 changes: 136 additions & 0 deletions keras_hub/src/models/deit/deit_backbone.py
@@ -0,0 +1,136 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.deit.deit_layers import DeiTEmbeddings
from keras_hub.src.models.deit.deit_layers import DeiTEncoder
from keras_hub.src.utils.keras_utils import standardize_data_format


@keras_hub_export("keras_hub.models.DeiTBackbone")
class DeiTBackbone(Backbone):
    """DeiT backbone.

    This backbone implements the Data-efficient Image Transformer (DeiT)
    architecture as described in [Training data-efficient image transformers
    & distillation through attention](https://arxiv.org/abs/2012.12877).

    Args:
        image_shape: A tuple or list of 3 integers representing the shape of
            the input image `(height, width, channels)`. `height` and `width`
            must each be divisible by the corresponding patch size.
        patch_size: (int, int). The size of each image patch; the input image
            is divided into patches of shape `(patch_size_h, patch_size_w)`.
        num_layers: int. The number of transformer encoder layers.
        num_heads: int. The number of attention heads in each Transformer
            encoder layer.
        hidden_dim: int. The dimensionality of the hidden representations.
        intermediate_dim: int. The dimensionality of the intermediate MLP
            layer in each Transformer encoder layer.
        dropout_rate: float. The dropout rate for the Transformer encoder
            layers.
        drop_path_rate: float. The stochastic depth rate.
        attention_dropout: float. The dropout rate for the attention
            mechanism in each Transformer encoder layer.
        layer_norm_epsilon: float. Value used for numerical stability in
            layer normalization.
        use_mha_bias: bool. Whether to use bias in the multi-head attention
            layers.
        use_distillation_token: bool. Whether to include a distillation token
            for training.
        data_format: str. `"channels_last"` or `"channels_first"`, specifying
            the data format for the input image. If `None`, defaults to
            `"channels_last"`.
        dtype: The dtype of the layer weights. Defaults to `None`.
        output_hidden_states: bool. Whether to return hidden states from all
            encoder layers. Defaults to `False`.
        return_attention_scores: bool. Whether to return attention scores
            from all self-attention layers. Defaults to `False`.
        **kwargs: Additional keyword arguments passed to the parent
            `Backbone` class.
    """

    def __init__(
        self,
        image_shape,
        patch_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        dropout_rate=0.0,
        drop_path_rate=0.0,
        attention_dropout=0.0,
        layer_norm_epsilon=1e-6,
        use_mha_bias=True,
        use_distillation_token=False,
        data_format=None,
        dtype=None,
        output_hidden_states=False,
        return_attention_scores=False,
        **kwargs,
    ):
        # === Set Data Format ===
        data_format = standardize_data_format(data_format)
        h_axis, w_axis, channels_axis = (
            (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
        )

        # Validate input image shape.
        if image_shape[h_axis] is None or image_shape[w_axis] is None:
            raise ValueError(
                f"Image shape must have defined height and width. Found "
                f"`None` at index {h_axis} (height) or {w_axis} (width). "
                f"Image shape: {image_shape}"
            )

        if image_shape[h_axis] % patch_size[0] != 0:
            raise ValueError(
                f"Input height {image_shape[h_axis]} should be divisible by "
                f"patch size {patch_size[0]}."
            )

        if image_shape[w_axis] % patch_size[1] != 0:
            raise ValueError(
                f"Input width {image_shape[w_axis]} should be divisible by "
                f"patch size {patch_size[1]}."
            )

        num_channels = image_shape[channels_axis]

        # === Functional Model ===
        inputs = keras.layers.Input(shape=image_shape)

        x = DeiTEmbeddings(
            image_size=(image_shape[h_axis], image_shape[w_axis]),
            patch_size=patch_size,
            hidden_dim=hidden_dim,
            num_channels=num_channels,
            data_format=data_format,
            dropout_rate=dropout_rate,
            dtype=dtype,
            name="deit_patching_and_embedding",
        )(inputs)

        output, all_hidden_states, all_attention_scores = DeiTEncoder(
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            intermediate_dim=intermediate_dim,
            use_mha_bias=use_mha_bias,
            dropout_rate=dropout_rate,
            attention_dropout=attention_dropout,
            layer_norm_epsilon=layer_norm_epsilon,
            dtype=dtype,
            name="deit_encoder",
        )(
            x,
            output_hidden_states=output_hidden_states,
            return_attention_scores=return_attention_scores,
        )

        super().__init__(
            inputs=inputs,
            outputs=output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.image_shape = image_shape
        self.patch_size = patch_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.dropout_rate = dropout_rate
        self.drop_path_rate = drop_path_rate
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.use_mha_bias = use_mha_bias
        self.use_distillation_token = use_distillation_token
        self.data_format = data_format
        self.output_hidden_states = output_hidden_states
        self.return_attention_scores = return_attention_scores
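A quick smoke-test sketch for reviewers trying the branch locally. The hyperparameters below are DeiT-Base-style values chosen for illustration, not presets from this PR, and it assumes the `DeiTEmbeddings`/`DeiTEncoder` layers behave like standard ViT-style blocks:

```python
import numpy as np

from keras_hub.models import DeiTBackbone  # public path added by this PR

# Illustrative DeiT-Base-like configuration; not an official preset.
backbone = DeiTBackbone(
    image_shape=(224, 224, 3),
    patch_size=(16, 16),
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    intermediate_dim=3072,
)
# 224 / 16 = 14, so 14 * 14 = 196 patch tokens, plus any class/distillation
# tokens prepended by the embedding layer.
features = backbone(np.random.rand(1, 224, 224, 3).astype("float32"))
print(features.shape)  # e.g. (1, 198, 768) if class + distillation tokens are added
```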
64 changes: 64 additions & 0 deletions keras_hub/src/models/deit/deit_image_converter.py
@@ -0,0 +1,64 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.deit.deit_backbone import DeiTBackbone
from keras_hub.src.utils.tensor_utils import preprocessing_function


@keras_hub_export("keras_hub.layers.DeiTImageConverter")
class DeiTImageConverter(ImageConverter):
    """Converts images to the format expected by a DeiT model.

    This layer performs image normalization using mean and standard deviation
    values.

    Args:
        norm_mean: list or tuple of floats. Mean values for image
            normalization. Defaults to `[0.5, 0.5, 0.5]`.
        norm_std: list or tuple of floats. Standard deviation values for
            image normalization. Defaults to `[0.5, 0.5, 0.5]`.
        **kwargs: Additional keyword arguments passed to the base
            `ImageConverter` class.

    Examples:
    ```python
    import numpy as np

    from keras_hub.layers import DeiTImageConverter

    # Example image (replace with your actual image data).
    image = np.random.rand(1, 224, 224, 3)  # Shape: (B, H, W, C)
    # Resize to 28x28 and rescale pixel values to [0, 1].
    converter = DeiTImageConverter(
        image_size=(28, 28),
        scale=1 / 255.0,
    )
    # Preprocess the image.
    preprocessed_image = converter(image)
    ```
    """

    backbone_cls = DeiTBackbone

    def __init__(
        self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs
    ):
        super().__init__(**kwargs)
        self.norm_mean = norm_mean
        self.norm_std = norm_std

    @preprocessing_function
    def call(self, inputs):
        # Resize/rescale via the base class, then normalize with the
        # configured mean and std (DeiT defaults to 0.5 for both).
        x = super().call(inputs)
        if self.norm_mean:
            x = x - self._expand_non_channel_dims(self.norm_mean, x)
        if self.norm_std:
            x = x / self._expand_non_channel_dims(self.norm_std, x)

        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "norm_mean": self.norm_mean,
                "norm_std": self.norm_std,
            }
        )
        return config
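To make the normalization concrete, here is a small sketch (reviewer-added, not part of the PR) that mirrors what `call` computes for a single pixel, assuming the base converter has already rescaled inputs to `[0, 1]`:

```python
import numpy as np

# A single RGB pixel already rescaled to [0, 1] by the base converter.
pixel = np.array([0.8, 0.5, 0.2])
norm_mean = np.array([0.5, 0.5, 0.5])
norm_std = np.array([0.5, 0.5, 0.5])

# (x - mean) / std maps [0, 1] inputs to roughly [-1, 1].
normalized = (pixel - norm_mean) / norm_std
print(normalized)  # [ 0.6  0.  -0.6]
```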