|
20 | 20 |
|
21 | 21 | @MODELS.register_module() |
22 | 22 | class DWPoseDistiller(BaseModel, metaclass=ABCMeta): |
23 | | - """Base distiller for detectors. |
24 | | -
|
25 | | - It typically consists of teacher_model and student_model. |
| 23 | + """Distiller introduced in `DWPose`_ by Yang et al (2023). This distiller |
| 24 | + is designed for distillation of RTMPose. |
| 25 | +
|
| 26 | + It typically consists of teacher_model and student_model. Please use the |
| 27 | + script `tools/misc/pth_transfer.py` to transfer the distilled model to the |
| 28 | + original RTMPose model. |
| 29 | +
|
| 30 | + Args: |
| 31 | + teacher_cfg (str): Config file of the teacher model. |
| 32 | + student_cfg (str): Config file of the student model. |
| 33 | + two_dis (bool): Whether this is the second stage of distillation. |
| 34 | + Defaults to False. |
| 35 | + distill_cfg (dict): Config for distillation. Defaults to None. |
| 36 | + teacher_pretrained (str): Path of the pretrained teacher model. |
| 37 | + Defaults to None. |
| 38 | + train_cfg (dict, optional): The runtime config for training process. |
| 39 | + Defaults to ``None`` |
| 40 | + data_preprocessor (dict, optional): The data preprocessing config to |
| 41 | + build the instance of :class:`BaseDataPreprocessor`. Defaults to |
| 42 | + ``None`` |
| 43 | + init_cfg (dict, optional): The config to control the initialization. |
| 44 | + Defaults to ``None`` |
| 45 | +
|
| 46 | + .. _`DWPose`: https://arxiv.org/abs/2307.15880 |
26 | 47 | """ |
27 | 48 |
|
28 | 49 | def __init__(self, |
@@ -70,6 +91,10 @@ def init_weights(self): |
70 | 91 | self.student.init_weights() |
71 | 92 |
|
72 | 93 | def set_epoch(self): |
| 94 | + """Set epoch for distiller. |
| 95 | +
|
| 96 | + Used for the decay of distillation loss. |
| 97 | + """ |
73 | 98 | self.message_hub = MessageHub.get_current_instance() |
74 | 99 | self.epoch = self.message_hub.get_info('epoch') |
75 | 100 | self.max_epochs = self.message_hub.get_info('max_epochs') |
@@ -143,6 +168,26 @@ def loss(self, inputs: Tensor, data_samples: SampleList) -> dict: |
143 | 168 | return losses |
144 | 169 |
|
145 | 170 | def predict(self, inputs, data_samples): |
| 171 | + """Predict results from a batch of inputs and data samples with post- |
| 172 | + processing. |
| 173 | +
|
| 174 | + Args: |
| 175 | + inputs (Tensor): Inputs with shape (N, C, H, W) |
| 176 | + data_samples (List[:obj:`PoseDataSample`]): The batch |
| 177 | + data samples |
| 178 | +
|
| 179 | + Returns: |
| 180 | + list[:obj:`PoseDataSample`]: The pose estimation results of the |
| 181 | + input images. The return value is `PoseDataSample` instances with |
| 182 | + ``pred_instances`` and ``pred_fields``(optional) field , and |
| 183 | + ``pred_instances`` usually contains the following keys: |
| 184 | +
|
| 185 | + - keypoints (Tensor): predicted keypoint coordinates in shape |
| 186 | + (num_instances, K, D) where K is the keypoint number and D |
| 187 | + is the keypoint dimension |
| 188 | + - keypoint_scores (Tensor): predicted keypoint scores in shape |
| 189 | + (num_instances, K) |
| 190 | + """ |
146 | 191 | if self.two_dis: |
147 | 192 | assert self.student.with_head, ( |
148 | 193 | 'The model must have head to perform prediction.') |
@@ -171,10 +216,16 @@ def predict(self, inputs, data_samples): |
171 | 216 | return self.student.predict(inputs, data_samples) |
172 | 217 |
|
173 | 218 | def extract_feat(self, inputs: Tensor) -> Tuple[Tensor]: |
174 | | - x = self.teacher.extract_feat(inputs) |
175 | | - if self.student.with_neck: |
176 | | - x = self.neck(x) |
| 219 | + """Extract features. |
177 | 220 |
|
| 221 | + Args: |
| 222 | + inputs (Tensor): Image tensor with shape (N, C, H ,W). |
| 223 | +
|
| 224 | + Returns: |
| 225 | + tuple[Tensor]: Multi-level features that may have various |
| 226 | + resolutions. |
| 227 | + """ |
| 228 | + x = self.teacher.extract_feat(inputs) |
178 | 229 | return x |
179 | 230 |
|
180 | 231 | def head_loss( |
@@ -227,5 +278,13 @@ def head_loss( |
227 | 278 | return losses, pred_simcc, gt_simcc, keypoint_weights |
228 | 279 |
|
229 | 280 | def _forward(self, inputs: Tensor): |
| 281 | + """Network forward process. Usually includes backbone, neck and head |
| 282 | + forward without any post-processing. |
230 | 283 |
|
| 284 | + Args: |
| 285 | + inputs (Tensor): Inputs with shape (N, C, H, W). |
| 286 | +
|
| 287 | + Returns: |
| 288 | + Union[Tensor | Tuple[Tensor]]: forward output of the network. |
| 289 | + """ |
231 | 290 | return self.student._forward(inputs) |
0 commit comments