alibabasglab committed
Commit 27b0df9 · verified · 1 Parent(s): 83a0f63

Update networks.py

Files changed (1)
  1. networks.py +42 -64
networks.py CHANGED
@@ -53,7 +53,7 @@ class SpeechModel:
 
     def load_model(self):
 
-        checkpoint_path = hf_hub_download(repo_id=f"alibabasglab/{self.args.model_name}", filename="last_best_checkpoint.pt")
+        checkpoint_path = hf_hub_download(repo_id=f"alibabasglab/{self.args.model_name}", filename="last_checkpoint.pt")
 
         # Load the checkpoint file into memory (map_location ensures compatibility with different devices)
         checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
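The only change in this hunk is the checkpoint filename, from last_best_checkpoint.pt to last_checkpoint.pt. The sketch below is a minimal, standalone way to check that the renamed file resolves and loads; it assumes huggingface_hub and torch are installed, and the model name is a hypothetical stand-in for args.model_name.

import torch
from huggingface_hub import hf_hub_download

model_name = "EEYD_locoformer"  # hypothetical stand-in for args.model_name
# Fetch the renamed checkpoint; hf_hub_download caches it locally after the first call
checkpoint_path = hf_hub_download(
    repo_id=f"alibabasglab/{model_name}",
    filename="last_checkpoint.pt",
)
# map_location keeps the load device-agnostic, matching load_model() above
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
print(sorted(checkpoint.keys()))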
@@ -173,81 +173,59 @@ class select_network(nn.Module):
         super(select_network, self).__init__()
         self.args = args
 
-        # audio backbone network
-        if args.model_name == 'EEYD_mrx':
-            from models.mrx.mrx import MRX
-            self.sep_network = MRX(args)
-        elif args.model_name in ['EEYD_demucs']:
-            from models.eeyd.eeyd import eeyd
-            self.sep_network = eeyd(args)
-        elif args.model_name in ['EEYD_locoformer']:
-            from models.tflocoformer.tflocoformer_separator import TFLocoformer
-            self.sep_network = TFLocoformer(args)
-        else:
-            raise NameError('Wrong network selection')
+
+        from models.tflocoformer.tflocoformer_separator import TFLocoformer
+        self.sep_network = TFLocoformer(args)
+
 
         print(f'{args.model_name} running.')
 
 
-        if self.args.network_reference.text_network == 't5':
-            import os
-            from transformers import AutoTokenizer, T5EncoderModel
-            model_path = snapshot_download(repo_id="alibabasglab/t5-base")
-            model_path = os.path.join(model_path, "t5-base")
-            # model_path = hf_hub_download(repo_id="alibabasglab/t5-base", filename="t5-base")
-            self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512)
-            self.text_encoder = T5EncoderModel.from_pretrained(model_path)
-            # os.environ["TOKENIZERS_PARALLELISM"] = "false"
-            for param in self.text_encoder.parameters():
-                param.requires_grad = False
-        else:
-            raise NameError('Wrong text network selection')
-
+        import os
+        from transformers import AutoTokenizer, T5EncoderModel
+        model_path = snapshot_download(repo_id="alibabasglab/t5-base")
+        model_path = os.path.join(model_path, "t5-base")
+        # model_path = hf_hub_download(repo_id="alibabasglab/t5-base", filename="t5-base")
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512)
+        self.text_encoder = T5EncoderModel.from_pretrained(model_path)
+        # os.environ["TOKENIZERS_PARALLELISM"] = "false"
+        for param in self.text_encoder.parameters():
+            param.requires_grad = False
+
 
-        if self.args.network_audio.backbone in ['eeyd','tflocoformer']:
-            if self.args.network_audio.add_feature in ['beats']:
-                from models.beats.BEATs import BEATs, BEATsConfig
-                model_path = snapshot_download(repo_id="alibabasglab/beats")
-                model_path = os.path.join(model_path, "BEATs_iter3_plus_AS2M.pt")
-                checkpoint = torch.load(model_path)
-                cfg = BEATsConfig(checkpoint['cfg'])
-                self.BEATs_model = BEATs(cfg)
-                self.BEATs_model.load_state_dict(checkpoint['model'])
-                self.BEATs_model.eval()
-
-                for param in self.BEATs_model.parameters():
-                    param.requires_grad = False
+        from models.beats.BEATs import BEATs, BEATsConfig
+        model_path = snapshot_download(repo_id="alibabasglab/beats")
+        model_path = os.path.join(model_path, "BEATs_iter3_plus_AS2M.pt")
+        checkpoint = torch.load(model_path)
+        cfg = BEATsConfig(checkpoint['cfg'])
+        self.BEATs_model = BEATs(cfg)
+        self.BEATs_model.load_state_dict(checkpoint['model'])
+        self.BEATs_model.eval()
+
+        for param in self.BEATs_model.parameters():
+            param.requires_grad = False
 
 
 
     def forward(self, mixture, t_ref, device):
         mixture = torch.tensor(mixture).to(device)
         mixture = mixture.unsqueeze(0)
-        if self.args.network_reference.text_network == 't5':
-            text_input = self.tokenizer(t_ref, return_tensors="pt", truncation=True, padding="longest")
-            text_input_ids = text_input["input_ids"].to(device)
-            text_attention_mask = text_input["attention_mask"].to(device)
-            text_len = torch.sum(text_attention_mask, dim=1)
-            text_embedding = self.text_encoder(input_ids=text_input_ids, attention_mask=text_attention_mask).last_hidden_state
-            t_ref = (text_embedding.clone().detach(), text_attention_mask.clone().detach(), text_len.clone().detach())
-        else: # clap series
-            text_embedding = self.text_encoder.get_text_embedding(t_ref, use_tensor=True)
-            text_embedding = text_embedding.clone().detach()
-
-            text_attention_mask = torch.ones((text_embedding.shape[0],1), dtype=torch.int32)
-            text_len = torch.ones((text_embedding.shape[0]), dtype=torch.int32)
-            text_embedding = self.clap_us(text_embedding.unsqueeze(1))
-            t_ref = (text_embedding, text_attention_mask.to(device), text_len)
-
-        if self.args.network_audio.backbone in ['eeyd','tflocoformer']:
-            if self.args.network_audio.add_feature in ['beats']:
-                with torch.no_grad():
-                    padding_mask = torch.zeros_like(mixture).bool()
-                    a_ref = self.BEATs_model.extract_features(mixture, padding_mask=padding_mask)[0]
-                    a_ref = a_ref.transpose(1,2)
-                return self.forword_step(mixture, t_ref, a_ref.clone().detach())
-
-        return self.sep_network(mixture, t_ref)
+
+        text_input = self.tokenizer(t_ref, return_tensors="pt", truncation=True, padding="longest")
+        text_input_ids = text_input["input_ids"].to(device)
+        text_attention_mask = text_input["attention_mask"].to(device)
+        text_len = torch.sum(text_attention_mask, dim=1)
+        text_embedding = self.text_encoder(input_ids=text_input_ids, attention_mask=text_attention_mask).last_hidden_state
+        t_ref = (text_embedding.clone().detach(), text_attention_mask.clone().detach(), text_len.clone().detach())
+
+
+        with torch.no_grad():
+            padding_mask = torch.zeros_like(mixture).bool()
+            a_ref = self.BEATs_model.extract_features(mixture, padding_mask=padding_mask)[0]
+            a_ref = a_ref.transpose(1,2)
+        return self.forword_step(mixture, t_ref, a_ref.clone().detach())
+
 
 
     def forword_step(self, mixture, t_ref, a_ref):
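The rewritten __init__ and forward() above drop the MRX/Demucs backbone branches and the CLAP text path, leaving a single TFLocoformer separator with a frozen T5 text encoder and frozen BEATs features. As a reference, here is a standalone sketch of the T5 text-reference path; it is not part of the commit and loads the public t5-base checkpoint instead of the alibabasglab/t5-base snapshot, but the tokenizer and encoder calls, and the (embedding, attention_mask, length) tuple, mirror forward().

import torch
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
text_encoder = T5EncoderModel.from_pretrained("t5-base").eval()

captions = ["a man speaking while a dog barks in the background"]  # example text reference
text_input = tokenizer(captions, return_tensors="pt", truncation=True, padding="longest")
with torch.no_grad():
    text_embedding = text_encoder(
        input_ids=text_input["input_ids"],
        attention_mask=text_input["attention_mask"],
    ).last_hidden_state  # (batch, seq_len, 768) for t5-base
text_len = text_input["attention_mask"].sum(dim=1)
t_ref = (text_embedding, text_input["attention_mask"], text_len)
print(text_embedding.shape, text_len)

Because the encoder parameters are frozen (requires_grad = False), these embeddings stay fixed during training; only the separator consumes them.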
 
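A similar sketch for the frozen BEATs audio-feature path that forward() now always runs before calling forword_step(). It assumes the repo's models.beats package is importable and treats the mixture as a 16 kHz mono waveform (an assumption; the commit does not state the sample rate); the calls otherwise mirror the hunk above.

import os
import torch
from huggingface_hub import snapshot_download
from models.beats.BEATs import BEATs, BEATsConfig

# Download and restore the frozen BEATs checkpoint, as in __init__ above
model_dir = snapshot_download(repo_id="alibabasglab/beats")
checkpoint = torch.load(os.path.join(model_dir, "BEATs_iter3_plus_AS2M.pt"), map_location="cpu")
beats = BEATs(BEATsConfig(checkpoint['cfg']))
beats.load_state_dict(checkpoint['model'])
beats.eval()

mixture = torch.randn(1, 16000)  # one second of dummy audio at the assumed 16 kHz rate
with torch.no_grad():
    padding_mask = torch.zeros_like(mixture).bool()  # no padded samples in this toy batch
    a_ref = beats.extract_features(mixture, padding_mask=padding_mask)[0]
a_ref = a_ref.transpose(1, 2)  # (batch, frames, dim) -> (batch, dim, frames), as in forward()
print(a_ref.shape)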