Add DeiT Model #2203
Status: Merged
Changes shown from 1 commit (of 26 total). Commit history (all by Sohaib-Ahmed21):
da66882  Add deit model basis to keras hub, its basic utilities, backbone, la…
aa4affa  update conversion script
5e16661  Add DeiT model with backbone, layers, tests.
ab0175a  Removed print statements.
616acd6  Update deit_backbone_test.py to include exact shape
7071e1e  Resolved failing test cases.
7b027a4  Fix failing test cases
46f3b4a  Merge branch 'keras-team:master' into deit
b102d37  Solve jax failing tests
5c138b0  Merge branch 'master' into deit
12d5fec  Add checkpoint conversion, presets and customize image converter to f…
a7bac11  Clean model code for input consistency
c3c03ce  Merge branch 'keras-team:master' into deit
73651fc  Refactor api imports to fix pre-commit api_gen tests
49b80af  Merge branch 'keras-team:master' into deit
a835e8c  Allow accepting non-square images
2f1464c  Make DeiTImageConverter empty to match other classifiers' converters …
c5b7454  Enable quantization check in backbone test and validate params in che…
da13993  Update deit_presets.py to include underscore in names
0f33986  Update preset_loader.py to properly configure num_classes for image c…
6601bbd  Merge branch 'keras-team:master' into deit
4359c14  Merge branch 'keras-team:master' into deit
f3b1fdc  Update convert_deit_checkpoints.py to correctly include scale and off…
db6a7e7  Slight change for syntax correction
ae1e660  Merge branch 'keras-team:master' into deit
3f648ea  Reformat code
Viewing commit da66882886ecc522153383327f0bf350c56ae42e:
Add deit model basis to keras hub, its basic utilities, backbone, layers.
Empty file.
keras_hub/src/models/deit/deit_backbone.py:

import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.deit.deit_layers import DeiTEmbeddings
from keras_hub.src.models.deit.deit_layers import DeiTEncoder
from keras_hub.src.utils.keras_utils import standardize_data_format


@keras_hub_export("keras_hub.models.DeiTBackbone")
class DeiTBackbone(Backbone):
    """DeiT backbone.

    This backbone implements the Data-efficient Image Transformer (DeiT)
    architecture as described in [Training data-efficient image transformers
    & distillation through attention](https://arxiv.org/abs/2012.12877).

    Args:
        image_shape: A tuple or list of 3 integers representing the shape of
            the input image `(height, width, channels)`. Height and width
            must each be divisible by the corresponding patch size.
        patch_size: (int, int). The size of each image patch; the input image
            will be divided into patches of shape
            `(patch_size_h, patch_size_w)`.
        num_layers: int. The number of transformer encoder layers.
        num_heads: int. The number of attention heads in each Transformer
            encoder layer.
        hidden_dim: int. The dimensionality of the hidden representations.
        intermediate_dim: int. The dimensionality of the intermediate MLP
            layer in each Transformer encoder layer.
        dropout_rate: float. The dropout rate for the Transformer encoder
            layers.
        drop_path_rate: float. The stochastic depth (drop path) rate.
        attention_dropout: float. The dropout rate for the attention
            mechanism in each Transformer encoder layer.
        layer_norm_epsilon: float. Value used for numerical stability in
            layer normalization.
        use_mha_bias: bool. Whether to use bias in the multi-head attention
            layers.
        use_distillation_token: bool. Whether to include a distillation token
            for training.
        data_format: str. `"channels_last"` or `"channels_first"`, specifying
            the data format for the input image. If `None`, defaults to
            `"channels_last"`.
        dtype: The dtype of the layer weights. Defaults to `None`.
        output_hidden_states: bool. Whether to return hidden states from all
            encoder layers. Defaults to `False`.
        return_attention_scores: bool. Whether to return attention scores
            from all self-attention layers. Defaults to `False`.
        **kwargs: Additional keyword arguments to be passed to the parent
            `Backbone` class.
    """

    def __init__(
        self,
        image_shape,
        patch_size,
        num_layers,
        num_heads,
        hidden_dim,
        intermediate_dim,
        dropout_rate=0.0,
        drop_path_rate=0.0,
        attention_dropout=0.0,
        layer_norm_epsilon=1e-6,
        use_mha_bias=True,
        use_distillation_token=False,
        data_format=None,
        dtype=None,
        output_hidden_states=False,
        return_attention_scores=False,
        **kwargs,
    ):
        # === Set data format ===
        data_format = standardize_data_format(data_format)
        h_axis, w_axis, channels_axis = (
            (-3, -2, -1) if data_format == "channels_last" else (-2, -1, -3)
        )

        # Validate the input image shape.
        if image_shape[h_axis] is None or image_shape[w_axis] is None:
            raise ValueError(
                f"Image shape must have defined height and width. Found "
                f"`None` at index {h_axis} (height) or {w_axis} (width). "
                f"Image shape: {image_shape}"
            )

        if image_shape[h_axis] % patch_size[0] != 0:
            raise ValueError(
                f"Input height {image_shape[h_axis]} should be divisible by "
                f"patch size {patch_size[0]}."
            )

        if image_shape[w_axis] % patch_size[1] != 0:
            raise ValueError(
                f"Input width {image_shape[w_axis]} should be divisible by "
                f"patch size {patch_size[1]}."
            )

        num_channels = image_shape[channels_axis]

        # === Functional model ===
        inputs = keras.layers.Input(shape=image_shape)

        x = DeiTEmbeddings(
            image_size=(image_shape[h_axis], image_shape[w_axis]),
            patch_size=patch_size,
            hidden_dim=hidden_dim,
            num_channels=num_channels,
            data_format=data_format,
            dropout_rate=dropout_rate,
            dtype=dtype,
            name="deit_patching_and_embedding",
        )(inputs)

        output, all_hidden_states, all_attention_scores = DeiTEncoder(
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            intermediate_dim=intermediate_dim,
            use_mha_bias=use_mha_bias,
            dropout_rate=dropout_rate,
            attention_dropout=attention_dropout,
            layer_norm_epsilon=layer_norm_epsilon,
            dtype=dtype,
            name="deit_encoder",
        )(
            x,
            output_hidden_states=output_hidden_states,
            return_attention_scores=return_attention_scores,
        )

        super().__init__(
            inputs=inputs,
            outputs=output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.image_shape = image_shape
        self.patch_size = patch_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.dropout_rate = dropout_rate
        self.attention_dropout = attention_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.use_mha_bias = use_mha_bias
        self.data_format = data_format
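For context, a minimal usage sketch of the backbone above. This is not part of the PR: the configuration values are illustrative, not a published preset, and the exact output sequence length depends on the patch grid and special tokens added by `DeiTEmbeddings`.

```python
import numpy as np

from keras_hub.src.models.deit.deit_backbone import DeiTBackbone

# Illustrative (non-preset) configuration: a small DeiT over 224x224 inputs.
backbone = DeiTBackbone(
    image_shape=(224, 224, 3),
    patch_size=(16, 16),
    num_layers=2,
    num_heads=4,
    hidden_dim=64,
    intermediate_dim=128,
)

images = np.random.rand(2, 224, 224, 3).astype("float32")
features = backbone(images)
# Output is a sequence of token embeddings with `hidden_dim` channels;
# with 16x16 patches on 224x224 inputs there are (224/16) * (224/16) = 196
# patch tokens, plus any special tokens the embedding layer prepends.
print(features.shape)
```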
keras_hub/src/models/deit/deit_image_converter.py:
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.deit.deit_backbone import DeiTBackbone
from keras_hub.src.utils.tensor_utils import preprocessing_function


@keras_hub_export("keras_hub.layers.DeiTImageConverter")
class DeiTImageConverter(ImageConverter):
    """Converts images to the format expected by a DeiT model.

    This layer performs image normalization using mean and standard deviation
    values.

    Args:
        norm_mean: list or tuple of floats. Mean values for image
            normalization. Defaults to `[0.5, 0.5, 0.5]`.
        norm_std: list or tuple of floats. Standard deviation values for
            image normalization. Defaults to `[0.5, 0.5, 0.5]`.
        **kwargs: Additional keyword arguments passed to
            `keras_hub.layers.preprocessing.ImageConverter`.

    Examples:
    ```python
    import numpy as np
    from keras_hub.layers import DeiTImageConverter

    # Example image (replace with your actual image data).
    image = np.random.rand(1, 224, 224, 3)  # Example: (B, H, W, C)

    # Create a DeiTImageConverter instance.
    converter = DeiTImageConverter(
        image_size=(28, 28),
        scale=1 / 255.0,
    )

    # Preprocess the image.
    preprocessed_image = converter(image)
    ```
    """

    backbone_cls = DeiTBackbone

    def __init__(
        self, norm_mean=[0.5, 0.5, 0.5], norm_std=[0.5, 0.5, 0.5], **kwargs
    ):
        super().__init__(**kwargs)
        self.norm_mean = norm_mean
        self.norm_std = norm_std

    @preprocessing_function
    def call(self, inputs):
        x = super().call(inputs)
        # Normalize with the configured mean and std (defaults are 0.5s,
        # mapping [0, 1] inputs to roughly [-1, 1]).
        if self.norm_mean:
            x = x - self._expand_non_channel_dims(self.norm_mean, x)
        if self.norm_std:
            x = x / self._expand_non_channel_dims(self.norm_std, x)
        return x

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "norm_mean": self.norm_mean,
                "norm_std": self.norm_std,
            }
        )
        return config
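To make the preprocessing concrete, here is a hand-written sketch of the arithmetic the converter applies: the base `ImageConverter` rescales pixels by `scale` (and optionally resizes), then this subclass subtracts `norm_mean` and divides by `norm_std` per channel. The sketch mirrors, rather than calls, the library code; the concrete `scale` value is an assumption for raw 8-bit pixels.

```python
import numpy as np

# Hand-rolled equivalent of DeiTImageConverter's normalization (a sketch,
# not the library's internal implementation).
scale = 1 / 255.0
norm_mean = np.array([0.5, 0.5, 0.5], dtype="float32")
norm_std = np.array([0.5, 0.5, 0.5], dtype="float32")

image = np.random.randint(0, 256, size=(1, 224, 224, 3)).astype("float32")
x = image * scale               # rescale raw 8-bit pixels to [0, 1]
x = (x - norm_mean) / norm_std  # with 0.5 / 0.5, maps to roughly [-1, 1]
```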