Skip to content

Commit 18842d6

Browse files
authored
[Fix] Fix bug of dwpose (#2728)
1 parent d85feef commit 18842d6

File tree

1 file changed

+65
-6
lines changed

1 file changed

+65
-6
lines changed

mmpose/models/distillers/dwpose_distiller.py

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,30 @@
2020

2121
@MODELS.register_module()
2222
class DWPoseDistiller(BaseModel, metaclass=ABCMeta):
23-
"""Base distiller for detectors.
24-
25-
It typically consists of teacher_model and student_model.
23+
"""Distiller introduced in `DWPose`_ by Yang et al (2023). This distiller
24+
is designed for distillation of RTMPose.
25+
26+
It typically consists of teacher_model and student_model. Please use the
27+
script `tools/misc/pth_transfer.py` to transfer the distilled model to the
28+
original RTMPose model.
29+
30+
Args:
31+
teacher_cfg (str): Config file of the teacher model.
32+
student_cfg (str): Config file of the student model.
33+
two_dis (bool): Whether this is the second stage of distillation.
34+
Defaults to False.
35+
distill_cfg (dict): Config for distillation. Defaults to None.
36+
teacher_pretrained (str): Path of the pretrained teacher model.
37+
Defaults to None.
38+
train_cfg (dict, optional): The runtime config for training process.
39+
Defaults to ``None``.
40+
data_preprocessor (dict, optional): The data preprocessing config to
41+
build the instance of :class:`BaseDataPreprocessor`. Defaults to
42+
``None``.
43+
init_cfg (dict, optional): The config to control the initialization.
44+
Defaults to ``None``.
45+
46+
.. _`DWPose`: https://arxiv.org/abs/2307.15880
2647
"""
2748

2849
def __init__(self,
@@ -70,6 +91,10 @@ def init_weights(self):
7091
self.student.init_weights()
7192

7293
def set_epoch(self):
94+
"""Set epoch for distiller.
95+
96+
Used for the decay of distillation loss.
97+
"""
7398
self.message_hub = MessageHub.get_current_instance()
7499
self.epoch = self.message_hub.get_info('epoch')
75100
self.max_epochs = self.message_hub.get_info('max_epochs')
@@ -143,6 +168,26 @@ def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
143168
return losses
144169

145170
def predict(self, inputs, data_samples):
171+
"""Predict results from a batch of inputs and data samples with post-
172+
processing.
173+
174+
Args:
175+
inputs (Tensor): Inputs with shape (N, C, H, W)
176+
data_samples (List[:obj:`PoseDataSample`]): The batch
177+
data samples
178+
179+
Returns:
180+
list[:obj:`PoseDataSample`]: The pose estimation results of the
181+
input images. The return value is `PoseDataSample` instances with
182+
``pred_instances`` and ``pred_fields`` (optional) fields, and
183+
``pred_instances`` usually contains the following keys:
184+
185+
- keypoints (Tensor): predicted keypoint coordinates in shape
186+
(num_instances, K, D) where K is the keypoint number and D
187+
is the keypoint dimension
188+
- keypoint_scores (Tensor): predicted keypoint scores in shape
189+
(num_instances, K)
190+
"""
146191
if self.two_dis:
147192
assert self.student.with_head, (
148193
'The model must have head to perform prediction.')
@@ -171,10 +216,16 @@ def predict(self, inputs, data_samples):
171216
return self.student.predict(inputs, data_samples)
172217

173218
def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]:
174-
x = self.teacher.extract_feat(inputs)
175-
if self.student.with_neck:
176-
x = self.neck(x)
219+
"""Extract features.
177220
221+
Args:
222+
inputs (Tensor): Image tensor with shape (N, C, H, W).
223+
224+
Returns:
225+
tuple[Tensor]: Multi-level features that may have various
226+
resolutions.
227+
"""
228+
x = self.teacher.extract_feat(inputs)
178229
return x
179230

180231
def head_loss(
@@ -227,5 +278,13 @@ def head_loss(
227278
return losses, pred_simcc, gt_simcc, keypoint_weights
228279

229280
def _forward(self, inputs: Tensor):
281+
"""Network forward process. Usually includes backbone, neck and head
282+
forward without any post-processing.
230283
284+
Args:
285+
inputs (Tensor): Inputs with shape (N, C, H, W).
286+
287+
Returns:
288+
Union[Tensor, Tuple[Tensor]]: Forward output of the network.
289+
"""
231290
return self.student._forward(inputs)

0 commit comments

Comments
 (0)