Commit bc2460e

minor
1 parent c2dd1a4 commit bc2460e

1 file changed: +126, -0 lines changed

petl/petl_enc_model.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
import torch
from transformers import MBartPreTrainedModel, RobertaConfig
import torch.nn as nn
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from petl.petl_factory import Prefix, MLP_Bias, Bias, PrefixDirectInit, PrefixCrossAttn
from transformers.utils import logging

logger = logging.get_logger(__name__)


class PETLEncModel(MBartPreTrainedModel):
    def __init__(self, config, args, pretrained_model, **kwargs):
        super().__init__(config)
        self.args = args
        self.pretrained_model = pretrained_model

        # RoBERTa and mBART-style configs expose the model sizes under different names
        if isinstance(config, RobertaConfig):
            self.match_n_layer = config.num_hidden_layers
            self.match_n_head = config.num_attention_heads
            self.n_embd = config.hidden_size
        else:
            self.match_n_layer = config.decoder_layers
            self.match_n_head = config.decoder_attention_heads
            self.n_embd = config.d_model
        self.match_n_embd = self.n_embd // self.match_n_head

        # bind get_prompt according to the attention-side PETL variant
        if "prefix" in args.attn_mode:
            self.setup_prefix(args, config)
        elif args.attn_mode == 'bitfit' or args.attn_mode == 'adapter':
            self.get_prompt = self.get_fake_prompt
        elif args.attn_mode == 'none':
            # attention is left untouched; only the ffn mode is active
            self.get_prompt = self.get_fake_prompt
        elif args.attn_mode == "prompt_tuning":
            self.get_prompt = self.get_fake_prompt
        elif args.attn_mode == "lora":
            self.get_prompt = self.get_fake_prompt
        else:
            raise ValueError(f"unsupported attn_mode: {args.attn_mode}")

        logger.info("Declared PETLEncModel!")

        not_freeze_set = []
        all_match = False  # default; only bitfit requires every pattern to match
        if args.unfreeze_params != 'none' and args.attn_mode != 'bitfit':
            if args.unfreeze_params == 'LN':
                # not_freeze_set = ['layernorm']  # input layernorm
                not_freeze_set = ['attn_layer_norm']  # only optimize the layer norm after attention
            else:
                not_freeze_set = args.unfreeze_params.split(',')
            all_match = False
        elif args.attn_mode == 'bitfit':
            not_freeze_set = ['bias']
            all_match = True

        logger.info(not_freeze_set)

        freeze_set = []
        if args.ffn_mode == 'mh_adapter_random' or args.attn_option == 'mh_adapter':
            # freeze the random mapping matrix
            freeze_set = ['freeze_q_proj']

        for n, p in self.pretrained_model.named_parameters():
            if len(not_freeze_set) > 0 and self.check_params(n, not_freeze_set, all_match=all_match):
                print("tune " + n)
                p.requires_grad = True
            else:
                p.requires_grad = False

            if len(freeze_set) > 0 and self.check_params(n, freeze_set, all_match=False):
                p.requires_grad = False

        logger.info("Already froze parameters!")
    def check_params(self, module_name, safe_list, all_match=True):
        # substring match: does each (or any) pattern occur in the parameter name?
        check = [partial_name in module_name for partial_name in safe_list]
        return all(check) if all_match else any(check)
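    # Illustration (hypothetical parameter name): with safe_list=['attn', 'bias']
    # and all_match=True, "encoder.layers.0.self_attn.k_proj.bias" matches
    # because both substrings occur in it; with all_match=False either
    # substring alone would suffice.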

    def get_standard_prompt(self, bsz, nsamples=1):
        return self.prompt_model(bsz, nsamples, self.device)

    def setup_prefix(self, args, config):
        if args.attn_mode == "prefix_nomlp":
            self.prompt_model = PrefixDirectInit(args, config)
        else:
            self.prompt_model = Prefix(args, config)
        self.get_prompt = self.get_standard_prompt

    def setup_bias(self, args, config):
        self.prompt_model = Bias(args, config)
        self.get_prompt = self.get_standard_prompt

    def setup_bias_mlp(self, args, config):
        self.prompt_model = MLP_Bias(args, config)
        self.get_prompt = self.get_standard_prompt

    def get_fake_prompt(self, bsz, nsamples=-1):
        return None
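    # Summary of the get_prompt contract in this file: __init__ binds it either
    # to get_standard_prompt, which queries the petl_factory prompt model for
    # prefix states, or to get_fake_prompt, which returns None so the wrapped
    # model runs without any prefix injected.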

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                ):
        # derive the batch size from whichever input was provided
        bsz = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        prefix_state = self.get_prompt(bsz=bsz)

        # the wrapped model is expected to accept a prefix_state keyword
        output = self.pretrained_model(input_ids=input_ids,
                                       attention_mask=attention_mask,
                                       token_type_ids=token_type_ids,
                                       position_ids=position_ids,
                                       head_mask=head_mask,
                                       inputs_embeds=inputs_embeds,
                                       labels=labels,
                                       output_attentions=output_attentions,
                                       output_hidden_states=output_hidden_states,
                                       return_dict=return_dict,
                                       prefix_state=prefix_state,
                                       )
        return output
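A minimal usage sketch, not part of the commit: how this class might wrap an encoder. The attribute names on `args` (`attn_mode`, `ffn_mode`, `attn_option`, `unfreeze_params`) are the ones this file reads; the checkpoint name and driver shape are assumptions, and the wrapped model must come from the repo's patched transformers, since a stock RobertaForSequenceClassification does not accept the prefix_state keyword that forward() passes through.

    # hypothetical driver, assuming the attributes this file reads from args
    from types import SimpleNamespace
    from transformers import RobertaConfig, RobertaForSequenceClassification
    from petl.petl_enc_model import PETLEncModel

    config = RobertaConfig.from_pretrained("roberta-base")
    backbone = RobertaForSequenceClassification.from_pretrained("roberta-base")
    args = SimpleNamespace(attn_mode="bitfit", ffn_mode="none",
                           attn_option="none", unfreeze_params="none")

    # bitfit: every parameter is frozen except those whose name contains "bias"
    model = PETLEncModel(config, args, backbone)

In bitfit mode get_prompt is bound to get_fake_prompt, so prefix_state is None and the backbone behaves as a plain encoder with only its bias terms trainable.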
