zexu.pan commited on
Commit
7c271f3
·
1 Parent(s): ff25d78
Files changed (2) hide show
  1. networks.py +6 -1
  2. requirements.txt +2 -1
networks.py CHANGED
@@ -236,11 +236,16 @@ class select_network(nn.Module):
236
  padding_mask = torch.zeros_like(mixture).bool()
237
  a_ref = self.BEATs_model.extract_features(mixture, padding_mask=padding_mask)[0]
238
  a_ref = a_ref.transpose(1,2)
239
- return self.sep_network(mixture, t_ref, a_ref.clone().detach())
240
 
241
  return self.sep_network(mixture, t_ref)
242
 
243
 
 
 
 
 
 
244
 
245
  class network_wrapper(SpeechModel):
246
  def __init__(self, args):
 
236
  padding_mask = torch.zeros_like(mixture).bool()
237
  a_ref = self.BEATs_model.extract_features(mixture, padding_mask=padding_mask)[0]
238
  a_ref = a_ref.transpose(1,2)
239
+ return self.forword_step(mixture, t_ref, a_ref.clone().detach())
240
 
241
  return self.sep_network(mixture, t_ref)
242
 
243
 
244
+ def forword_step(self, mixture, t_ref, a_ref):
245
+ return self.sep_network(mixture, t_ref, a_ref)
246
+
247
+
248
+
249
 
250
  class network_wrapper(SpeechModel):
251
  def __init__(self, args):
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  torchaudio
2
  torchinfo
3
  torchvision
@@ -5,7 +6,7 @@ librosa==0.10.2.post1
5
  numpy==1.26.3
6
  yamlargparse
7
  soundfile
8
- opencv-python
9
  ffmpeg-python
10
  scikit-learn==1.5.1
11
  scipy
 
1
+ gradio==3.44.4
2
  torchaudio
3
  torchinfo
4
  torchvision
 
6
  numpy==1.26.3
7
  yamlargparse
8
  soundfile
9
+ opencv-python==4.10.0.84
10
  ffmpeg-python
11
  scikit-learn==1.5.1
12
  scipy