zexu.pan committed
Commit d35d2ae · 1 Parent(s): fab03a2
config/EEYD_large.yaml DELETED
@@ -1,22 +0,0 @@
-#!/bin/bash
-mode: 'inference'
-use_cuda: 1 # 1 for True, 0 for False
-num_gpu: 1
-sampling_rate: 16000
-network: "EEYD_base" # network type
-checkpoint_dir: "checkpoints/EEYD_base"
-
-# decode parameters
-one_time_decode_length: 10 # maximum segment length for one-pass decoding (seconds), longer audio (>5s) will use segmented decoding
-decode_window: 10 # one-pass decoding length
-
-
-# network settings
-network_reference:
-  cue: text #
-  emb_size: 512 # resnet18:256
-  text_layers: 3
-  text_network: t5 # default t5, or clap
-network_audio:
-  backbone: mrx
-
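
For reference, a minimal sketch of how a config like the removed EEYD_large.yaml can be read in Python. The loader below is an illustration only (the attribute-style wrapper is an assumption, not this repository's actual config code), and it can only find the file at the parent revision fab03a2, where it still exists:

import yaml
from types import SimpleNamespace

def load_config(path):
    # Parse the YAML and expose nested keys (e.g. network_reference.cue) as attributes.
    with open(path) as f:
        raw = yaml.safe_load(f)
    def wrap(node):
        if not isinstance(node, dict):
            return node
        return SimpleNamespace(**{k: wrap(v) for k, v in node.items()})
    return wrap(raw)

cfg = load_config('config/EEYD_large.yaml')
print(cfg.network, cfg.network_reference.cue, cfg.network_reference.emb_size)  # EEYD_base text 512
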
extract_everything.py CHANGED
@@ -25,6 +25,7 @@ class main(nn.Module):
     parser.add_argument('--sampling-rate', dest='sampling_rate', type=int, default=16000, help='Sampling rate')
     parser.add_argument('--one-time-decode-length', dest='one_time_decode_length', type=int, default=60, help='Max segment length for one-pass decoding')
     parser.add_argument('--decode-window', dest='decode_window', type=int, default=1, help='Decoding chunk size')
+    parser.add_argument('--output_residual', type=int, default=0)
 
     # Parse arguments from the config file
     self.args = parser.parse_args(['--config', self.config_path])
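
The new flag follows the same argparse pattern as the existing decode options. A minimal standalone sketch of that pattern (the surrounding class and the '--config' plumbing are omitted; treating 0/1 as a boolean toggle is an assumption based on the other integer flags such as use_cuda in the config above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--sampling-rate', dest='sampling_rate', type=int, default=16000, help='Sampling rate')
parser.add_argument('--one-time-decode-length', dest='one_time_decode_length', type=int, default=60, help='Max segment length for one-pass decoding')
parser.add_argument('--decode-window', dest='decode_window', type=int, default=1, help='Decoding chunk size')
parser.add_argument('--output_residual', type=int, default=0)  # added in this commit; 0/1 toggle

args = parser.parse_args(['--output_residual', '1'])
print(args.sampling_rate, args.output_residual)  # 16000 1
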
models/mossformer2/mossformer/__init__.py DELETED
File without changes
models/mossformer2/mossformer/utils/Transformer.py DELETED
@@ -1,460 +0,0 @@
1
- """Transformer implementaion for Mossformer2
2
-
3
- Authors
4
- * Shengkui Zhao 2024
5
- * Jia Qi Yip 2024
6
- """
7
-
8
- import math
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- from torch import einsum
13
-
14
- from typing import Optional
15
- import numpy as np
16
-
17
- # from ..utils.flash_pytorch_fsmn import FLASHTransformer_DualA_FSMN
18
- # from sb.nnet.normalization import LayerNorm
19
- # from speechbrain.lobes.models.layer_norm import CLayerNorm, GLayerNorm, GlobLayerNorm, ILayerNorm
20
- # from speechbrain.lobes.models.fsmn import UniDeepFsmn, UniDeepFsmn_dilated
21
- # from speechbrain.lobes.models.conv_module import ConvModule
22
- from einops import rearrange
23
- from rotary_embedding_torch import RotaryEmbedding
24
-
25
- from ..utils.fsmn import UniDeepFsmn, UniDeepFsmn_dilated
26
- from ..utils.normalization import LayerNorm, CLayerNorm, ScaleNorm
27
- from ..utils.conv_module import ConvModule
28
-
29
- def exists(val):
30
- return val is not None
31
-
32
- def padding_to_multiple_of(n, mult):
33
- remainder = n % mult
34
- if remainder == 0:
35
- return 0
36
- return mult - remainder
37
-
38
- def default(val, d):
39
- return val if exists(val) else d
40
-
41
- class FFConvM(nn.Module):
42
- def __init__(
43
- self,
44
- dim_in,
45
- dim_out,
46
- norm_klass = nn.LayerNorm,
47
- dropout = 0.1
48
- ):
49
- super().__init__()
50
- self.mdl = nn.Sequential(
51
- norm_klass(dim_in),
52
- nn.Linear(dim_in, dim_out),
53
- nn.SiLU(),
54
- ConvModule(dim_out),
55
- nn.Dropout(dropout)
56
- )
57
- def forward(
58
- self,
59
- x,
60
- ):
61
- output = self.mdl(x)
62
- return output
63
-
64
- class Gated_FSMN_dilated(nn.Module):
65
- def __init__(
66
- self,
67
- in_channels,
68
- out_channels,
69
- lorder,
70
- hidden_size
71
- ):
72
- super().__init__()
73
- self.to_u = FFConvM(
74
- dim_in = in_channels,
75
- dim_out = hidden_size,
76
- norm_klass = nn.LayerNorm,
77
- dropout = 0.1,
78
- )
79
- self.to_v = FFConvM(
80
- dim_in = in_channels,
81
- dim_out = hidden_size,
82
- norm_klass = nn.LayerNorm,
83
- dropout = 0.1,
84
- )
85
- self.fsmn = UniDeepFsmn_dilated(in_channels, out_channels, lorder, hidden_size)
86
-
87
- def forward(
88
- self,
89
- x,
90
- ):
91
- input = x
92
- x_u = self.to_u(x)
93
- x_v = self.to_v(x)
94
- x_u = self.fsmn(x_u)
95
- x = x_v * x_u + input
96
- return x
97
-
98
- class Gated_FSMN_Block_Dilated(nn.Module):
99
- """1-D convolutional block."""
100
-
101
- def __init__(self,
102
- dim,
103
- inner_channels = 256,
104
- group_size = 256, #384, #128, #256,
105
- #query_key_dim = 128, #256, #128,
106
- #expansion_factor = 4.,
107
- #causal = False,
108
- #dropout = 0.1,
109
- norm_type = 'scalenorm',
110
- #shift_tokens = True,
111
- #rotary_pos_emb = None,
112
- ):
113
- super(Gated_FSMN_Block_Dilated, self).__init__()
114
- if norm_type == 'scalenorm':
115
- norm_klass = ScaleNorm
116
- elif norm_type == 'layernorm':
117
- norm_klass = nn.LayerNorm
118
-
119
- self.group_size = group_size
120
-
121
- # rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
122
- self.conv1 = nn.Sequential(
123
- nn.Conv1d(dim, inner_channels, kernel_size=1),
124
- nn.PReLU(),
125
- )
126
- self.norm1 = CLayerNorm(inner_channels)
127
- #block dilated without gating
128
- #self.gated_fsmn = UniDeepFsmn_dilated(inner_channels, inner_channels, 20, inner_channels)
129
- #block dilated with gating
130
- self.gated_fsmn = Gated_FSMN_dilated(inner_channels, inner_channels, lorder=20, hidden_size=inner_channels)
131
- self.norm2 = CLayerNorm(inner_channels)
132
- self.conv2 = nn.Conv1d(inner_channels, dim, kernel_size=1)
133
-
134
- def forward(self, input):
135
- conv1 = self.conv1(input.transpose(2,1))
136
- norm1 = self.norm1(conv1)
137
- seq_out = self.gated_fsmn(norm1.transpose(2,1))
138
- norm2 = self.norm2(seq_out.transpose(2,1))
139
- conv2 = self.conv2(norm2)
140
- return conv2.transpose(2,1) + input
141
-
142
- class OffsetScale(nn.Module):
143
- def __init__(self, dim, heads = 1):
144
- super().__init__()
145
- self.gamma = nn.Parameter(torch.ones(heads, dim))
146
- self.beta = nn.Parameter(torch.zeros(heads, dim))
147
- nn.init.normal_(self.gamma, std = 0.02)
148
-
149
- def forward(self, x):
150
- out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
151
- return out.unbind(dim = -2)
152
-
153
- class FLASH_ShareA_FFConvM(nn.Module):
154
- def __init__(
155
- self,
156
- *,
157
- dim,
158
- group_size = 256,
159
- query_key_dim = 128,
160
- expansion_factor = 1.,
161
- causal = False,
162
- dropout = 0.1,
163
- rotary_pos_emb = None,
164
- norm_klass = nn.LayerNorm,
165
- shift_tokens = True
166
- ):
167
- super().__init__()
168
- hidden_dim = int(dim * expansion_factor)
169
- self.group_size = group_size
170
- self.causal = causal
171
- self.shift_tokens = shift_tokens
172
-
173
- # positional embeddings
174
- self.rotary_pos_emb = rotary_pos_emb
175
- # norm
176
- self.dropout = nn.Dropout(dropout)
177
- #self.move = MultiHeadEMA(embed_dim=dim, ndim=4, bidirectional=False, truncation=None)
178
- # projections
179
-
180
- self.to_hidden = FFConvM(
181
- dim_in = dim,
182
- dim_out = hidden_dim,
183
- norm_klass = norm_klass,
184
- dropout = dropout,
185
- )
186
- self.to_qk = FFConvM(
187
- dim_in = dim,
188
- dim_out = query_key_dim,
189
- norm_klass = norm_klass,
190
- dropout = dropout,
191
- )
192
-
193
- self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)
194
-
195
- self.to_out = FFConvM(
196
- dim_in = dim*2,
197
- dim_out = dim,
198
- norm_klass = norm_klass,
199
- dropout = dropout,
200
- )
201
-
202
- self.gateActivate=nn.Sigmoid() #exp3
203
-
204
- def forward(
205
- self,
206
- x,
207
- *,
208
- mask = None
209
- ):
210
-
211
- """
212
- b - batch
213
- n - sequence length (within groups)
214
- g - group dimension
215
- d - feature dimension (keys)
216
- e - feature dimension (values)
217
- i - sequence dimension (source)
218
- j - sequence dimension (target)
219
- """
220
-
221
- #b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
222
-
223
- # prenorm
224
- #x = self.fsmn(x)
225
- normed_x = x #self.norm(x)
226
-
227
- # do token shift - a great, costless trick from an independent AI researcher in Shenzhen
228
- residual = x
229
-
230
- if self.shift_tokens:
231
- x_shift, x_pass = normed_x.chunk(2, dim = -1)
232
- x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
233
- normed_x = torch.cat((x_shift, x_pass), dim = -1)
234
-
235
- # initial projections
236
-
237
- v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
238
- qk = self.to_qk(normed_x)
239
- #print('normed_x: {}'.format(normed_x.shape))
240
-
241
- # offset and scale
242
- quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
243
- #print('q {}, k {}, v {}'.format(quad_q.shape, quad_k.shape, v.shape))
244
- att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u)
245
-
246
- #exp5: self.gateActivate=nn.SiLU()
247
- out = (att_u*v ) * self.gateActivate(att_v*u)
248
-
249
- x = x + self.to_out(out)
250
- #x = x + self.conv_module(x)
251
- return x
252
-
253
- def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
254
- b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
255
-
256
- if exists(mask):
257
- lin_mask = rearrange(mask, '... -> ... 1')
258
- lin_k = lin_k.masked_fill(~lin_mask, 0.)
259
-
260
- # rotate queries and keys
261
-
262
- if exists(self.rotary_pos_emb):
263
- quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))
264
-
265
- # padding for groups
266
-
267
- padding = padding_to_multiple_of(n, g)
268
-
269
- if padding > 0:
270
- quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))
271
-
272
- mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
273
- mask = F.pad(mask, (0, padding), value = False)
274
-
275
- # group along sequence
276
-
277
- quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))
278
-
279
- if exists(mask):
280
- mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)
281
-
282
- # calculate quadratic attention output
283
-
284
- sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
285
-
286
- ###eddy REMOVE this part can solve infinite loss prob!!!!!!!!!!!!!
287
- #sim = sim + self.rel_pos_bias(sim)
288
-
289
- attn = F.relu(sim) ** 2
290
- #attn = F.relu(sim)
291
- attn = self.dropout(attn)
292
-
293
- if exists(mask):
294
- attn = attn.masked_fill(~mask, 0.)
295
-
296
- if self.causal:
297
- causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
298
- attn = attn.masked_fill(causal_mask, 0.)
299
-
300
- quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
301
- quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
302
-
303
- # calculate linear attention output
304
-
305
- if self.causal:
306
- lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
307
- # exclusive cumulative sum along group dimension
308
- lin_kv = lin_kv.cumsum(dim = 1)
309
- lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
310
- lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
311
-
312
- lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
313
- # exclusive cumulative sum along group dimension
314
- lin_ku = lin_ku.cumsum(dim = 1)
315
- lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
316
- lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
317
- else:
318
- lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
319
- lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
320
-
321
- lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
322
- lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
323
-
324
- # fold back groups into full sequence, and excise out padding
325
- '''
326
- quad_attn_out_v, lin_attn_out_v = map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v, lin_out_v))
327
- quad_attn_out_u, lin_attn_out_u = map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_u, lin_out_u))
328
- return quad_attn_out_v+lin_attn_out_v, quad_attn_out_u+lin_attn_out_u
329
- '''
330
- return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v+lin_out_v, quad_out_u+lin_out_u))
331
-
332
- class FLASHTransformer_DualA_FSMN(nn.Module):
333
- def __init__(
334
- self,
335
- *,
336
- dim,
337
- depth,
338
- group_size = 256, #384, #128, #256,
339
- query_key_dim = 128, #256, #128,
340
- expansion_factor = 4.,
341
- causal = False,
342
- attn_dropout = 0.1,
343
- norm_type = 'scalenorm',
344
- shift_tokens = True
345
- ):
346
- super().__init__()
347
- assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'
348
-
349
- if norm_type == 'scalenorm':
350
- norm_klass = ScaleNorm
351
- elif norm_type == 'layernorm':
352
- norm_klass = nn.LayerNorm
353
-
354
- self.group_size = group_size
355
-
356
- rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))
357
- # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
358
- #self.fsmn = nn.ModuleList([Gated_FSMN(dim, dim, lorder=20, hidden_size=dim) for _ in range(depth)])
359
- #self.fsmn = nn.ModuleList([Gated_FSMN_Block(dim) for _ in range(depth)])
360
- self.fsmn = nn.ModuleList([Gated_FSMN_Block_Dilated(dim) for _ in range(depth)])
361
- self.layers = nn.ModuleList([FLASH_ShareA_FFConvM(dim = dim, group_size = group_size, query_key_dim = query_key_dim, expansion_factor = expansion_factor, causal = causal, dropout = attn_dropout, rotary_pos_emb = rotary_pos_emb, norm_klass = norm_klass, shift_tokens = shift_tokens) for _ in range(depth)])
362
-
363
- def _build_repeats(self, in_channels, out_channels, lorder, hidden_size, repeats=1):
364
- repeats = [
365
- UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
366
- for i in range(repeats)
367
- ]
368
- return nn.Sequential(*repeats)
369
-
370
- def forward(
371
- self,
372
- x,
373
- *,
374
- mask = None
375
- ):
376
- ii = 0
377
- for flash in self.layers:
378
- #x_residual = x
379
- x = flash(x, mask = mask)
380
- x = self.fsmn[ii](x)
381
- #x = x + x_residual
382
- ii = ii + 1
383
- return x
384
-
385
- class TransformerEncoder_FLASH_DualA_FSMN(nn.Module):
386
- """This class implements the transformer encoder.
387
-
388
- Arguments
389
- ---------
390
- num_layers : int
391
- Number of transformer layers to include.
392
- nhead : int
393
- Number of attention heads.
394
- d_ffn : int
395
- Hidden size of self-attention Feed Forward layer.
396
- d_model : int
397
- The dimension of the input embedding.
398
- kdim : int
399
- Dimension for key (Optional).
400
- vdim : int
401
- Dimension for value (Optional).
402
- dropout : float
403
- Dropout for the encoder (Optional).
404
- input_module: torch class
405
- The module to process the source input feature to expected
406
- feature dimension (Optional).
407
-
408
- Example
409
- -------
410
- >>> import torch
411
- >>> x = torch.rand((8, 60, 512))
412
- >>> net = TransformerEncoder(1, 8, 512, d_model=512)
413
- >>> output, _ = net(x)
414
- >>> output.shape
415
- torch.Size([8, 60, 512])
416
- """
417
- def __init__(
418
- self,
419
- num_layers,
420
- nhead,
421
- d_ffn,
422
- input_shape=None,
423
- d_model=None,
424
- kdim=None,
425
- vdim=None,
426
- dropout=0.0,
427
- activation=nn.ReLU,
428
- normalize_before=False,
429
- causal=False,
430
- attention_type="regularMHA",
431
- ):
432
-
433
- super().__init__()
434
-
435
- self.flashT = FLASHTransformer_DualA_FSMN(dim=d_model, depth=num_layers)
436
- self.norm = LayerNorm(d_model, eps=1e-6)
437
-
438
- def forward(
439
- self,
440
- src,
441
- src_mask: Optional[torch.Tensor] = None,
442
- src_key_padding_mask: Optional[torch.Tensor] = None,
443
- pos_embs: Optional[torch.Tensor] = None,
444
- ):
445
- """
446
- Arguments
447
- ----------
448
- src : tensor
449
- The sequence to the encoder layer (required).
450
- src_mask : tensor
451
- The mask for the src sequence (optional).
452
- src_key_padding_mask : tensor
453
- The mask for the src keys per batch (optional).
454
- """
455
- output = self.flashT(src)
456
- #summary(self.flashT, [(src.size())])
457
- output = self.norm(output)
458
- #summary(self.norm, [(output.size())])
459
-
460
- return output
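
A shape-level sketch of how the removed blocks compose. This assumes the module is imported from the parent revision (where the file still exists) and uses a batch size greater than 1, since the FSMN helpers call squeeze() on their outputs:

import torch
from models.mossformer2.mossformer.utils.Transformer import (
    Gated_FSMN_Block_Dilated, FLASHTransformer_DualA_FSMN)

x = torch.randn(2, 100, 512)                          # (batch, time, feature)

block = Gated_FSMN_Block_Dilated(dim=512)             # conv1x1 -> gated dilated FSMN -> conv1x1, with residual
print(block(x).shape)                                 # torch.Size([2, 100, 512])

net = FLASHTransformer_DualA_FSMN(dim=512, depth=2)   # alternates FLASH attention layers with gated FSMN blocks
print(net(x).shape)                                   # torch.Size([2, 100, 512])
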
models/mossformer2/mossformer/utils/__init__.py DELETED
File without changes
models/mossformer2/mossformer/utils/conv_module.py DELETED
@@ -1,87 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch import Tensor
4
- import torch.nn.init as init
5
- import torch.nn.functional as F
6
-
7
- class Transpose(nn.Module):
8
- """ Wrapper class of torch.transpose() for Sequential module. """
9
- def __init__(self, shape: tuple):
10
- super(Transpose, self).__init__()
11
- self.shape = shape
12
-
13
- def forward(self, x: Tensor) -> Tensor:
14
- return x.transpose(*self.shape)
15
-
16
- class DepthwiseConv1d(nn.Module):
17
- """
18
- When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
19
- this operation is termed in literature as depthwise convolution.
20
- Args:
21
- in_channels (int): Number of channels in the input
22
- out_channels (int): Number of channels produced by the convolution
23
- kernel_size (int or tuple): Size of the convolving kernel
24
- stride (int, optional): Stride of the convolution. Default: 1
25
- padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
26
- bias (bool, optional): If True, adds a learnable bias to the output. Default: True
27
- Inputs: inputs
28
- - **inputs** (batch, in_channels, time): Tensor containing input vector
29
- Returns: outputs
30
- - **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution.
31
- """
32
- def __init__(
33
- self,
34
- in_channels: int,
35
- out_channels: int,
36
- kernel_size: int,
37
- stride: int = 1,
38
- padding: int = 0,
39
- bias: bool = False,
40
- ) -> None:
41
- super(DepthwiseConv1d, self).__init__()
42
- assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
43
- self.conv = nn.Conv1d(
44
- in_channels=in_channels,
45
- out_channels=out_channels,
46
- kernel_size=kernel_size,
47
- groups=in_channels,
48
- stride=stride,
49
- padding=padding,
50
- bias=bias,
51
- )
52
-
53
- def forward(self, inputs: Tensor) -> Tensor:
54
- return self.conv(inputs)
55
-
56
- class ConvModule(nn.Module):
57
- """
58
- Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
59
- This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
60
- to aid training deep models.
61
- Args:
62
- in_channels (int): Number of channels in the input
63
- kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
64
- dropout_p (float, optional): probability of dropout
65
- Inputs: inputs
66
- inputs (batch, time, dim): Tensor contains input sequences
67
- Outputs: outputs
68
- outputs (batch, time, dim): Tensor produces by conformer convolution module.
69
- """
70
- def __init__(
71
- self,
72
- in_channels: int,
73
- kernel_size: int = 17,
74
- expansion_factor: int = 2,
75
- dropout_p: float = 0.1,
76
- ) -> None:
77
- super(ConvModule, self).__init__()
78
- assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
79
- assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
80
-
81
- self.sequential = nn.Sequential(
82
- Transpose(shape=(1, 2)),
83
- DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
84
- )
85
-
86
- def forward(self, inputs: Tensor) -> Tensor:
87
- return inputs + self.sequential(inputs).transpose(1, 2)
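
The removed ConvModule is a residual depthwise-convolution block that preserves the (batch, time, dim) layout. A small shape check, again assuming the parent-revision import path:

import torch
from models.mossformer2.mossformer.utils.conv_module import ConvModule

m = ConvModule(in_channels=64)     # depthwise Conv1d, kernel 17, 'SAME' padding
x = torch.randn(2, 100, 64)        # (batch, time, dim)
print(m(x).shape)                  # torch.Size([2, 100, 64]) -- input + conv branch
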
models/mossformer2/mossformer/utils/fsmn.py DELETED
@@ -1,108 +0,0 @@
1
- import torch.nn as nn
2
- import torch.nn.functional as F
3
- import torch as th
4
- from torch.nn.parameter import Parameter
5
- import numpy as np
6
- import os
7
-
8
- class UniDeepFsmn(nn.Module):
9
-
10
- def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
11
- super(UniDeepFsmn, self).__init__()
12
-
13
- self.input_dim = input_dim
14
- self.output_dim = output_dim
15
-
16
- if lorder is None:
17
- return
18
-
19
- self.lorder = lorder
20
- self.hidden_size = hidden_size
21
-
22
- self.linear = nn.Linear(input_dim, hidden_size)
23
-
24
- self.project = nn.Linear(hidden_size, output_dim, bias=False)
25
-
26
- self.conv1 = nn.Conv2d(output_dim, output_dim, [lorder+lorder-1, 1], [1, 1], groups=output_dim, bias=False)
27
-
28
- def forward(self, input):
29
-
30
- f1 = F.relu(self.linear(input))
31
-
32
- p1 = self.project(f1)
33
-
34
- x = th.unsqueeze(p1, 1)
35
-
36
- x_per = x.permute(0, 3, 2, 1)
37
-
38
- y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])
39
-
40
- out = x_per + self.conv1(y)
41
-
42
- out1 = out.permute(0, 3, 2, 1)
43
-
44
- return input + out1.squeeze()
45
-
46
- class DilatedDenseNet(nn.Module):
47
- def __init__(self, depth=4, lorder=20, in_channels=64):
48
- super(DilatedDenseNet, self).__init__()
49
- self.depth = depth
50
- self.in_channels = in_channels
51
- self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)
52
- self.twidth = lorder*2-1
53
- self.kernel_size = (self.twidth, 1)
54
- for i in range(self.depth):
55
- dil = 2 ** i
56
- pad_length = lorder + (dil - 1) * (lorder - 1) - 1
57
- setattr(self, 'pad{}'.format(i + 1), nn.ConstantPad2d((0, 0, pad_length, pad_length), value=0.))
58
- setattr(self, 'conv{}'.format(i + 1),
59
- nn.Conv2d(self.in_channels*(i+1), self.in_channels, kernel_size=self.kernel_size,
60
- dilation=(dil, 1), groups=self.in_channels, bias=False))
61
- setattr(self, 'norm{}'.format(i + 1), nn.InstanceNorm2d(in_channels, affine=True))
62
- setattr(self, 'prelu{}'.format(i + 1), nn.PReLU(self.in_channels))
63
-
64
- def forward(self, x):
65
- skip = x
66
- for i in range(self.depth):
67
- out = getattr(self, 'pad{}'.format(i + 1))(skip)
68
- out = getattr(self, 'conv{}'.format(i + 1))(out)
69
- out = getattr(self, 'norm{}'.format(i + 1))(out)
70
- out = getattr(self, 'prelu{}'.format(i + 1))(out)
71
- skip = th.cat([out, skip], dim=1)
72
- return out
73
-
74
- class UniDeepFsmn_dilated(nn.Module):
75
-
76
- def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
77
- super(UniDeepFsmn_dilated, self).__init__()
78
-
79
- self.input_dim = input_dim
80
- self.output_dim = output_dim
81
-
82
- if lorder is None:
83
- return
84
-
85
- self.lorder = lorder
86
- self.hidden_size = hidden_size
87
-
88
- self.linear = nn.Linear(input_dim, hidden_size)
89
-
90
- self.project = nn.Linear(hidden_size, output_dim, bias=False)
91
-
92
- self.conv = DilatedDenseNet(depth=2, lorder=lorder, in_channels=output_dim)
93
-
94
- def forward(self, input):
95
-
96
- f1 = F.relu(self.linear(input))
97
-
98
- p1 = self.project(f1)
99
-
100
- x = th.unsqueeze(p1, 1)
101
-
102
- x_per = x.permute(0, 3, 2, 1)
103
-
104
- out = self.conv(x_per)
105
-
106
- out1 = out.permute(0, 3, 2, 1)
107
-
108
- return input + out1.squeeze()
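
Both FSMN variants above are residual sequence filters that keep the input shape; output_dim must equal input_dim for the residual add to work, and a batch size greater than 1 avoids the trailing squeeze() collapsing the batch axis. A minimal sketch against the parent revision:

import torch
from models.mossformer2.mossformer.utils.fsmn import UniDeepFsmn, UniDeepFsmn_dilated

x = torch.randn(2, 100, 64)   # (batch, time, dim)
print(UniDeepFsmn(64, 64, lorder=20, hidden_size=128)(x).shape)          # torch.Size([2, 100, 64])
print(UniDeepFsmn_dilated(64, 64, lorder=20, hidden_size=128)(x).shape)  # torch.Size([2, 100, 64])
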
models/mossformer2/mossformer/utils/normalization.py DELETED
@@ -1,94 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- class LayerNorm(nn.Module):
5
- """
6
- This code came from sb.nnet.normalization
7
- # from sb.nnet.normalization import LayerNorm
8
-
9
-
10
- Applies layer normalization to the input tensor.
11
-
12
- Arguments
13
- ---------
14
- input_shape : tuple
15
- The expected shape of the input.
16
- eps : float
17
- This value is added to std deviation estimation to improve the numerical
18
- stability.
19
- elementwise_affine : bool
20
- If True, this module has learnable per-element affine parameters
21
- initialized to ones (for weights) and zeros (for biases).
22
-
23
- Example
24
- -------
25
- >>> input = torch.randn(100, 101, 128)
26
- >>> norm = LayerNorm(input_shape=input.shape)
27
- >>> output = norm(input)
28
- >>> output.shape
29
- torch.Size([100, 101, 128])
30
- """
31
-
32
- def __init__(
33
- self,
34
- input_size=None,
35
- input_shape=None,
36
- eps=1e-05,
37
- elementwise_affine=True,
38
- ):
39
- super().__init__()
40
- self.eps = eps
41
- self.elementwise_affine = elementwise_affine
42
-
43
- if input_shape is not None:
44
- input_size = input_shape[2:]
45
-
46
- self.norm = torch.nn.LayerNorm(
47
- input_size,
48
- eps=self.eps,
49
- elementwise_affine=self.elementwise_affine,
50
- )
51
-
52
- def forward(self, x):
53
- """Returns the normalized input tensor.
54
-
55
- Arguments
56
- ---------
57
- x : torch.Tensor (batch, time, channels)
58
- input to normalize. 3d or 4d tensors are expected.
59
- """
60
- return self.norm(x)
61
-
62
- class CLayerNorm(nn.LayerNorm):
63
- """Channel-wise layer normalization."""
64
-
65
- def __init__(self, *args, **kwargs):
66
- super(CLayerNorm, self).__init__(*args, **kwargs)
67
-
68
- def forward(self, sample):
69
- """Forward function.
70
-
71
- Args:
72
- sample: [batch_size, channels, length]
73
- """
74
- if sample.dim() != 3:
75
- raise RuntimeError('{} only accept 3-D tensor as input'.format(
76
- self.__name__))
77
- # [N, C, T] -> [N, T, C]
78
- sample = torch.transpose(sample, 1, 2)
79
- # LayerNorm
80
- sample = super().forward(sample)
81
- # [N, T, C] -> [N, C, T]
82
- sample = torch.transpose(sample, 1, 2)
83
- return sample
84
-
85
- class ScaleNorm(nn.Module):
86
- def __init__(self, dim, eps = 1e-5):
87
- super().__init__()
88
- self.scale = dim ** -0.5
89
- self.eps = eps
90
- self.g = nn.Parameter(torch.ones(1))
91
-
92
- def forward(self, x):
93
- norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
94
- return x / norm.clamp(min = self.eps) * self.g
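
These are thin normalization wrappers with different input layouts: LayerNorm and ScaleNorm expect (batch, time, channels) while CLayerNorm expects (batch, channels, time). A quick sketch, assuming the parent-revision import path:

import torch
from models.mossformer2.mossformer.utils.normalization import LayerNorm, CLayerNorm, ScaleNorm

x = torch.randn(4, 100, 64)                      # (batch, time, channels)
print(LayerNorm(input_size=64)(x).shape)         # torch.Size([4, 100, 64])
print(ScaleNorm(dim=64)(x).shape)                # torch.Size([4, 100, 64]) -- x rescaled by its L2 norm and a learned g
print(CLayerNorm(64)(x.transpose(1, 2)).shape)   # torch.Size([4, 64, 100]) -- channels-first variant
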
models/mossformer2/mossformer/utils/one_path_flash_fsmn.py DELETED
@@ -1,800 +0,0 @@
1
- import copy
2
- import math
3
- import torch
4
-
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- from torch import einsum
9
- from ..utils.Transformer import TransformerEncoder_FLASH_DualA_FSMN
10
-
11
- EPS = 1e-8
12
-
13
- class ScaledSinuEmbedding(nn.Module):
14
- def __init__(self, dim):
15
- super().__init__()
16
- self.scale = nn.Parameter(torch.ones(1,))
17
- inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
18
- self.register_buffer('inv_freq', inv_freq)
19
-
20
- def forward(self, x):
21
- n, device = x.shape[1], x.device
22
- t = torch.arange(n, device = device).type_as(self.inv_freq)
23
- sinu = einsum('i , j -> i j', t, self.inv_freq)
24
- emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
25
- return emb * self.scale
26
-
27
- class Linear(torch.nn.Module):
28
- """Computes a linear transformation y = wx + b.
29
-
30
- Arguments
31
- ---------
32
- n_neurons : int
33
- It is the number of output neurons (i.e, the dimensionality of the
34
- output).
35
- input_shape: tuple
36
- It is the shape of the input tensor.
37
- input_size: int
38
- Size of the input tensor.
39
- bias : bool
40
- If True, the additive bias b is adopted.
41
- combine_dims : bool
42
- If True and the input is 4D, combine 3rd and 4th dimensions of input.
43
-
44
- Example
45
- -------
46
- >>> inputs = torch.rand(10, 50, 40)
47
- >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
48
- >>> output = lin_t(inputs)
49
- >>> output.shape
50
- torch.Size([10, 50, 100])
51
- """
52
-
53
- def __init__(
54
- self,
55
- n_neurons,
56
- input_shape=None,
57
- input_size=None,
58
- bias=True,
59
- combine_dims=False,
60
- ):
61
- super().__init__()
62
- self.combine_dims = combine_dims
63
-
64
- if input_shape is None and input_size is None:
65
- raise ValueError("Expected one of input_shape or input_size")
66
-
67
- if input_size is None:
68
- input_size = input_shape[-1]
69
- if len(input_shape) == 4 and self.combine_dims:
70
- input_size = input_shape[2] * input_shape[3]
71
-
72
- # Weights are initialized following pytorch approach
73
- self.w = nn.Linear(input_size, n_neurons, bias=bias)
74
-
75
- def forward(self, x):
76
- """Returns the linear transformation of input tensor.
77
-
78
- Arguments
79
- ---------
80
- x : torch.Tensor
81
- Input to transform linearly.
82
- """
83
- if x.ndim == 4 and self.combine_dims:
84
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
85
-
86
- wx = self.w(x)
87
-
88
- return wx
89
-
90
- class GlobalLayerNorm(nn.Module):
91
- """Calculate Global Layer Normalization.
92
-
93
- Arguments
94
- ---------
95
- dim : (int or list or torch.Size)
96
- Input shape from an expected input of size.
97
- eps : float
98
- A value added to the denominator for numerical stability.
99
- elementwise_affine : bool
100
- A boolean value that when set to True,
101
- this module has learnable per-element affine parameters
102
- initialized to ones (for weights) and zeros (for biases).
103
-
104
- Example
105
- -------
106
- >>> x = torch.randn(5, 10, 20)
107
- >>> GLN = GlobalLayerNorm(10, 3)
108
- >>> x_norm = GLN(x)
109
- """
110
-
111
- def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
112
- super(GlobalLayerNorm, self).__init__()
113
- self.dim = dim
114
- self.eps = eps
115
- self.elementwise_affine = elementwise_affine
116
-
117
- if self.elementwise_affine:
118
- if shape == 3:
119
- self.weight = nn.Parameter(torch.ones(self.dim, 1))
120
- self.bias = nn.Parameter(torch.zeros(self.dim, 1))
121
- if shape == 4:
122
- self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
123
- self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
124
- else:
125
- self.register_parameter("weight", None)
126
- self.register_parameter("bias", None)
127
-
128
- def forward(self, x):
129
- """Returns the normalized tensor.
130
-
131
- Arguments
132
- ---------
133
- x : torch.Tensor
134
- Tensor of size [N, C, K, S] or [N, C, L].
135
- """
136
- # x = N x C x K x S or N x C x L
137
- # N x 1 x 1
138
- # cln: mean,var N x 1 x K x S
139
- # gln: mean,var N x 1 x 1
140
- if x.dim() == 3:
141
- mean = torch.mean(x, (1, 2), keepdim=True)
142
- var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
143
- if self.elementwise_affine:
144
- x = (
145
- self.weight * (x - mean) / torch.sqrt(var + self.eps)
146
- + self.bias
147
- )
148
- else:
149
- x = (x - mean) / torch.sqrt(var + self.eps)
150
-
151
- if x.dim() == 4:
152
- mean = torch.mean(x, (1, 2, 3), keepdim=True)
153
- var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
154
- if self.elementwise_affine:
155
- x = (
156
- self.weight * (x - mean) / torch.sqrt(var + self.eps)
157
- + self.bias
158
- )
159
- else:
160
- x = (x - mean) / torch.sqrt(var + self.eps)
161
- return x
162
-
163
-
164
- class CumulativeLayerNorm(nn.LayerNorm):
165
- """Calculate Cumulative Layer Normalization.
166
-
167
- Arguments
168
- ---------
169
- dim : int
170
- Dimension that you want to normalize.
171
- elementwise_affine : True
172
- Learnable per-element affine parameters.
173
-
174
- Example
175
- -------
176
- >>> x = torch.randn(5, 10, 20)
177
- >>> CLN = CumulativeLayerNorm(10)
178
- >>> x_norm = CLN(x)
179
- """
180
-
181
- def __init__(self, dim, elementwise_affine=True):
182
- super(CumulativeLayerNorm, self).__init__(
183
- dim, elementwise_affine=elementwise_affine, eps=1e-8
184
- )
185
-
186
- def forward(self, x):
187
- """Returns the normalized tensor.
188
-
189
- Arguments
190
- ---------
191
- x : torch.Tensor
192
- Tensor size [N, C, K, S] or [N, C, L]
193
- """
194
- # x: N x C x K x S or N x C x L
195
- # N x K x S x C
196
- if x.dim() == 4:
197
- x = x.permute(0, 2, 3, 1).contiguous()
198
- # N x K x S x C == only channel norm
199
- x = super().forward(x)
200
- # N x C x K x S
201
- x = x.permute(0, 3, 1, 2).contiguous()
202
- if x.dim() == 3:
203
- x = torch.transpose(x, 1, 2)
204
- # N x L x C == only channel norm
205
- x = super().forward(x)
206
- # N x C x L
207
- x = torch.transpose(x, 1, 2)
208
- return x
209
-
210
-
211
- def select_norm(norm, dim, shape):
212
- """Just a wrapper to select the normalization type.
213
- """
214
-
215
- if norm == "gln":
216
- return GlobalLayerNorm(dim, shape, elementwise_affine=True)
217
- if norm == "cln":
218
- return CumulativeLayerNorm(dim, elementwise_affine=True)
219
- if norm == "ln":
220
- return nn.GroupNorm(1, dim, eps=1e-8)
221
- else:
222
- return nn.BatchNorm1d(dim)
223
-
224
- class Encoder(nn.Module):
225
- """Convolutional Encoder Layer.
226
-
227
- Arguments
228
- ---------
229
- kernel_size : int
230
- Length of filters.
231
- in_channels : int
232
- Number of input channels.
233
- out_channels : int
234
- Number of output channels.
235
-
236
- Example
237
- -------
238
- >>> x = torch.randn(2, 1000)
239
- >>> encoder = Encoder(kernel_size=4, out_channels=64)
240
- >>> h = encoder(x)
241
- >>> h.shape
242
- torch.Size([2, 64, 499])
243
- """
244
-
245
- def __init__(self, kernel_size=2, out_channels=64, in_channels=1):
246
- super(Encoder, self).__init__()
247
- self.conv1d = nn.Conv1d(
248
- in_channels=in_channels,
249
- out_channels=out_channels,
250
- kernel_size=kernel_size,
251
- stride=kernel_size // 2,
252
- groups=1,
253
- bias=False,
254
- )
255
- self.in_channels = in_channels
256
-
257
- def forward(self, x):
258
- """Return the encoded output.
259
-
260
- Arguments
261
- ---------
262
- x : torch.Tensor
263
- Input tensor with dimensionality [B, L].
264
- Return
265
- ------
266
- x : torch.Tensor
267
- Encoded tensor with dimensionality [B, N, T_out].
268
-
269
- where B = Batchsize
270
- L = Number of timepoints
271
- N = Number of filters
272
- T_out = Number of timepoints at the output of the encoder
273
- """
274
- # B x L -> B x 1 x L
275
- if self.in_channels == 1:
276
- x = torch.unsqueeze(x, dim=1)
277
- # B x 1 x L -> B x N x T_out
278
- x = self.conv1d(x)
279
- x = F.relu(x)
280
-
281
- return x
282
-
283
- class Decoder(nn.ConvTranspose1d):
284
- """A decoder layer that consists of ConvTranspose1d.
285
-
286
- Arguments
287
- ---------
288
- kernel_size : int
289
- Length of filters.
290
- in_channels : int
291
- Number of input channels.
292
- out_channels : int
293
- Number of output channels.
294
-
295
-
296
- Example
297
- ---------
298
- >>> x = torch.randn(2, 100, 1000)
299
- >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
300
- >>> h = decoder(x)
301
- >>> h.shape
302
- torch.Size([2, 1003])
303
- """
304
-
305
- def __init__(self, *args, **kwargs):
306
- super(Decoder, self).__init__(*args, **kwargs)
307
-
308
- def forward(self, x):
309
- """Return the decoded output.
310
-
311
- Arguments
312
- ---------
313
- x : torch.Tensor
314
- Input tensor with dimensionality [B, N, L].
315
- where, B = Batchsize,
316
- N = number of filters
317
- L = time points
318
- """
319
-
320
- if x.dim() not in [2, 3]:
321
- raise RuntimeError(
322
- "{} accept 3/4D tensor as input".format(self.__name__)
323
- )
324
- x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
325
-
326
- if torch.squeeze(x).dim() == 1:
327
- x = torch.squeeze(x, dim=1)
328
- else:
329
- x = torch.squeeze(x)
330
- return x
331
-
332
- class SBFLASHBlock_DualA(nn.Module):
333
- """A wrapper for the SpeechBrain implementation of the transformer encoder.
334
-
335
- Arguments
336
- ---------
337
- num_layers : int
338
- Number of layers.
339
- d_model : int
340
- Dimensionality of the representation.
341
- nhead : int
342
- Number of attention heads.
343
- d_ffn : int
344
- Dimensionality of positional feed forward.
345
- input_shape : tuple
346
- Shape of input.
347
- kdim : int
348
- Dimension of the key (Optional).
349
- vdim : int
350
- Dimension of the value (Optional).
351
- dropout : float
352
- Dropout rate.
353
- activation : str
354
- Activation function.
355
- use_positional_encoding : bool
356
- If true we use a positional encoding.
357
- norm_before: bool
358
- Use normalization before transformations.
359
-
360
- Example
361
- ---------
362
- >>> x = torch.randn(10, 100, 64)
363
- >>> block = SBTransformerBlock(1, 64, 8)
364
- >>> x = block(x)
365
- >>> x.shape
366
- torch.Size([10, 100, 64])
367
- """
368
-
369
- def __init__(
370
- self,
371
- num_layers,
372
- d_model,
373
- nhead,
374
- d_ffn=2048,
375
- input_shape=None,
376
- kdim=None,
377
- vdim=None,
378
- dropout=0.1,
379
- activation="relu",
380
- use_positional_encoding=False,
381
- norm_before=False,
382
- attention_type="regularMHA",
383
- ):
384
-
385
- super(SBFLASHBlock_DualA, self).__init__()
386
- self.use_positional_encoding = use_positional_encoding
387
-
388
- if activation == "relu":
389
- activation = nn.ReLU
390
- elif activation == "gelu":
391
- activation = nn.GELU
392
- else:
393
- raise ValueError("unknown activation")
394
-
395
-
396
- self.mdl = TransformerEncoder_FLASH_DualA_FSMN(
397
- num_layers=num_layers,
398
- nhead=nhead,
399
- d_ffn=d_ffn,
400
- input_shape=input_shape,
401
- d_model=d_model,
402
- kdim=kdim,
403
- vdim=vdim,
404
- dropout=dropout,
405
- activation=activation,
406
- normalize_before=norm_before,
407
- attention_type=attention_type,
408
- )
409
-
410
- def forward(self, x):
411
- """Returns the transformed output.
412
-
413
- Arguments
414
- ---------
415
- x : torch.Tensor
416
- Tensor shape [B, L, N],
417
- where, B = Batchsize,
418
- L = time points
419
- N = number of filters
420
-
421
- """
422
- output = self.mdl(x)
423
-
424
- return output
425
-
426
-
427
- def _get_activation_fn(activation):
428
- """Just a wrapper to get the activation functions.
429
- """
430
-
431
- if activation == "relu":
432
- return F.relu
433
- elif activation == "gelu":
434
- return F.gelu
435
-
436
-
437
- class Dual_Computation_Block(nn.Module):
438
- """Computation block for dual-path processing.
439
-
440
- Arguments
441
- ---------
442
- intra_mdl : torch.nn.module
443
- Model to process within the chunks.
444
- inter_mdl : torch.nn.module
445
- Model to process across the chunks.
446
- out_channels : int
447
- Dimensionality of inter/intra model.
448
- norm : str
449
- Normalization type.
450
- skip_around_intra : bool
451
- Skip connection around the intra layer.
452
- linear_layer_after_inter_intra : bool
453
- Linear layer or not after inter or intra.
454
-
455
- Example
456
- ---------
457
- >>> intra_block = SBTransformerBlock(1, 64, 8)
458
- >>> inter_block = SBTransformerBlock(1, 64, 8)
459
- >>> dual_comp_block = Dual_Computation_Block(intra_block, inter_block, 64)
460
- >>> x = torch.randn(10, 64, 100, 10)
461
- >>> x = dual_comp_block(x)
462
- >>> x.shape
463
- torch.Size([10, 64, 100, 10])
464
- """
465
-
466
- def __init__(
467
- self,
468
- intra_mdl,
469
- out_channels,
470
- norm="ln",
471
- skip_around_intra=True,
472
- linear_layer_after_inter_intra=True,
473
- ):
474
- super(Dual_Computation_Block, self).__init__()
475
-
476
- self.intra_mdl = intra_mdl
477
- self.skip_around_intra = skip_around_intra
478
- self.linear_layer_after_inter_intra = linear_layer_after_inter_intra
479
-
480
- # Norm
481
- self.norm = norm
482
- if norm is not None:
483
- self.intra_norm = select_norm(norm, out_channels, 3)
484
-
485
- # Linear
486
- if linear_layer_after_inter_intra:
487
- self.intra_linear = Linear(
488
- out_channels, input_size=out_channels
489
- )
490
-
491
- def forward(self, x):
492
- """Returns the output tensor.
493
-
494
- Arguments
495
- ---------
496
- x : torch.Tensor
497
- Input tensor of dimension [B, N, K, S].
498
-
499
-
500
- Return
501
- ---------
502
- out: torch.Tensor
503
- Output tensor of dimension [B, N, K, S].
504
- where, B = Batchsize,
505
- N = number of filters
506
- K = time points in each chunk
507
- S = the number of chunks
508
- """
509
- B, N, S = x.shape
510
- # intra RNN
511
- # [B, S, N]
512
- intra = x.permute(0, 2, 1).contiguous() #.view(B, S, N)
513
-
514
- intra = self.intra_mdl(intra)
515
-
516
- # [B, S, N]
517
- if self.linear_layer_after_inter_intra:
518
- intra = self.intra_linear(intra)
519
-
520
- # [B, N, S]
521
- intra = intra.permute(0, 2, 1).contiguous()
522
- if self.norm is not None:
523
- intra = self.intra_norm(intra)
524
-
525
- # [B, N, S]
526
- if self.skip_around_intra:
527
- intra = intra + x
528
-
529
- # inter RNN
530
- # [B, S, N]
531
- '''
532
- inter = intra.permute(0, 2, 1).contiguous() #.view(B, S, N)
533
- # [BK, S, H]
534
- inter = self.inter_mdl(inter)
535
-
536
- # [BK, S, N]
537
- if self.linear_layer_after_inter_intra:
538
- inter = self.inter_linear(inter)
539
-
540
- # [B, N, S]
541
- inter = inter.permute(0, 2, 1).contiguous()
542
- if self.norm is not None:
543
- inter = self.inter_norm(inter)
544
- # [B, N, K, S]
545
- out = inter + intra
546
- '''
547
- out = intra
548
- return out
549
-
550
-
551
- class Dual_Path_Model(nn.Module):
552
- """The dual path model which is the basis for dualpathrnn, sepformer, dptnet.
553
-
554
- Arguments
555
- ---------
556
- in_channels : int
557
- Number of channels at the output of the encoder.
558
- out_channels : int
559
- Number of channels that would be inputted to the intra and inter blocks.
560
- intra_model : torch.nn.module
561
- Model to process within the chunks.
562
- inter_model : torch.nn.module
563
- model to process across the chunks,
564
- num_layers : int
565
- Number of layers of Dual Computation Block.
566
- norm : str
567
- Normalization type.
568
- K : int
569
- Chunk length.
570
- num_spks : int
571
- Number of sources (speakers).
572
- skip_around_intra : bool
573
- Skip connection around intra.
574
- linear_layer_after_inter_intra : bool
575
- Linear layer after inter and intra.
576
- use_global_pos_enc : bool
577
- Global positional encodings.
578
- max_length : int
579
- Maximum sequence length.
580
-
581
- Example
582
- ---------
583
- >>> intra_block = SBTransformerBlock(1, 64, 8)
584
- >>> inter_block = SBTransformerBlock(1, 64, 8)
585
- >>> dual_path_model = Dual_Path_Model(64, 64, intra_block, inter_block, num_spks=2)
586
- >>> x = torch.randn(10, 64, 2000)
587
- >>> x = dual_path_model(x)
588
- >>> x.shape
589
- torch.Size([2, 10, 64, 2000])
590
- """
591
-
592
- def __init__(
593
- self,
594
- in_channels,
595
- out_channels,
596
- intra_model,
597
- #inter_model,
598
- num_layers=1,
599
- norm="ln",
600
- K=200,
601
- num_spks=2,
602
- skip_around_intra=True,
603
- linear_layer_after_inter_intra=True,
604
- use_global_pos_enc=True,
605
- max_length=20000,
606
- ):
607
- super(Dual_Path_Model, self).__init__()
608
- self.K = K
609
- self.num_spks = num_spks
610
- self.num_layers = num_layers
611
- # self.norm = select_norm(norm, in_channels, 3)
612
- # self.conv1d_encoder = nn.Conv1d(in_channels, out_channels, 1, bias=False)
613
- self.use_global_pos_enc = use_global_pos_enc
614
-
615
- if self.use_global_pos_enc:
616
- self.pos_enc = ScaledSinuEmbedding(out_channels)
617
-
618
- self.dual_mdl = nn.ModuleList([])
619
- for i in range(num_layers):
620
- self.dual_mdl.append(
621
- copy.deepcopy(
622
- Dual_Computation_Block(
623
- intra_model,
624
- #inter_model,
625
- out_channels,
626
- norm,
627
- skip_around_intra=skip_around_intra,
628
- linear_layer_after_inter_intra=linear_layer_after_inter_intra,
629
- )
630
- )
631
- )
632
-
633
- self.conv1d_out = nn.Conv1d(
634
- out_channels, out_channels * num_spks, kernel_size=1
635
- )
636
- self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
637
- self.prelu = nn.PReLU()
638
- self.activation = nn.ReLU()
639
- # gated output layer
640
- self.output = nn.Sequential(
641
- nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
642
- )
643
- self.output_gate = nn.Sequential(
644
- nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
645
- )
646
-
647
- def forward(self, x):
648
- """Returns the output tensor.
649
-
650
- Arguments
651
- ---------
652
- x : torch.Tensor
653
- Input tensor of dimension [B, N, L].
654
-
655
- Returns
656
- -------
657
- out : torch.Tensor
658
- Output tensor of dimension [spks, B, N, L]
659
- where, spks = Number of speakers
660
- B = Batchsize,
661
- N = number of filters
662
- L = the number of time points
663
- """
664
-
665
- # before each line we indicate the shape after executing the line
666
-
667
- # # [B, N, L]
668
- # x = self.norm(x)
669
-
670
- # # [B, N, L]
671
- # x = self.conv1d_encoder(x)
672
-
673
- if self.use_global_pos_enc:
674
- base = x
675
- x = x.transpose(1, -1)
676
- emb = self.pos_enc(x)
677
- emb = emb.transpose(0, -1)
678
- x = base + emb
679
-
680
- # [B, N, S]
681
- for i in range(self.num_layers):
682
- x = self.dual_mdl[i](x)
683
- x = self.prelu(x)
684
-
685
- # [B, N*spks, K, S]
686
- x = self.conv1d_out(x)
687
- B, _, S = x.shape
688
-
689
- # [B*spks, N, K, S]
690
- x = x.view(B * self.num_spks, -1, S)
691
-
692
- # [B*spks, N, L]
693
- x = self.output(x) * self.output_gate(x)
694
-
695
- # [B*spks, N, L]
696
- x = self.conv1_decoder(x)
697
-
698
- # [B, spks, N, L]
699
- _, N, L = x.shape
700
- x = x.view(B, self.num_spks, N, L)
701
- x = self.activation(x)
702
-
703
- # [spks, B, N, L]
704
- x = x.transpose(0, 1)
705
-
706
- return x
707
-
708
- def _padding(self, input, K):
709
- """Padding the audio times.
710
-
711
- Arguments
712
- ---------
713
- K : int
714
- Chunks of length.
715
- P : int
716
- Hop size.
717
- input : torch.Tensor
718
- Tensor of size [B, N, L].
719
- where, B = Batchsize,
720
- N = number of filters
721
- L = time points
722
- """
723
- B, N, L = input.shape
724
- P = K // 2
725
- gap = K - (P + L % K) % K
726
- if gap > 0:
727
- pad = torch.Tensor(torch.zeros(B, N, gap)).type(input.type())
728
- input = torch.cat([input, pad], dim=2)
729
-
730
- _pad = torch.Tensor(torch.zeros(B, N, P)).type(input.type())
731
- input = torch.cat([_pad, input, _pad], dim=2)
732
-
733
- return input, gap
734
-
735
- def _Segmentation(self, input, K):
736
- """The segmentation stage splits
737
-
738
- Arguments
739
- ---------
740
- K : int
741
- Length of the chunks.
742
- input : torch.Tensor
743
- Tensor with dim [B, N, L].
744
-
745
- Return
746
- -------
747
- output : torch.tensor
748
- Tensor with dim [B, N, K, S].
749
- where, B = Batchsize,
750
- N = number of filters
751
- K = time points in each chunk
752
- S = the number of chunks
753
- L = the number of time points
754
- """
755
- B, N, L = input.shape
756
- P = K // 2
757
- input, gap = self._padding(input, K)
758
- # [B, N, K, S]
759
- input1 = input[:, :, :-P].contiguous().view(B, N, -1, K)
760
- input2 = input[:, :, P:].contiguous().view(B, N, -1, K)
761
- input = (
762
- torch.cat([input1, input2], dim=3).view(B, N, -1, K).transpose(2, 3)
763
- )
764
-
765
- return input.contiguous(), gap
766
-
767
- def _over_add(self, input, gap):
768
- """Merge the sequence with the overlap-and-add method.
769
-
770
- Arguments
771
- ---------
772
- input : torch.tensor
773
- Tensor with dim [B, N, K, S].
774
- gap : int
775
- Padding length.
776
-
777
- Return
778
- -------
779
- output : torch.tensor
780
- Tensor with dim [B, N, L].
781
- where, B = Batchsize,
782
- N = number of filters
783
- K = time points in each chunk
784
- S = the number of chunks
785
- L = the number of time points
786
-
787
- """
788
- B, N, K, S = input.shape
789
- P = K // 2
790
- # [B, N, S, K]
791
- input = input.transpose(2, 3).contiguous().view(B, N, -1, K * 2)
792
-
793
- input1 = input[:, :, :, :K].contiguous().view(B, N, -1)[:, :, P:]
794
- input2 = input[:, :, :, K:].contiguous().view(B, N, -1)[:, :, :-P]
795
- input = input1 + input2
796
- # [B, N, L]
797
- if gap > 0:
798
- input = input[:, :, :-gap]
799
-
800
- return input
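
The _padding / _Segmentation pair above relies on a small arithmetic invariant: after appending `gap` zeros and padding a half-chunk P on both ends, each of the two half-shifted views splits evenly into chunks of length K. A framework-free check of that invariant:

K = 200          # chunk size (the default K in Dual_Path_Model)
P = K // 2       # hop: chunks overlap by half
for L in (37, 1730, 2000):                 # arbitrary sequence lengths
    gap = K - (P + L % K) % K              # extra right padding computed in _padding
    padded = L + gap + 2 * P               # length after _padding
    assert (padded - P) % K == 0           # both shifted views in _Segmentation reshape cleanly to (-1, K)
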
models/mossformer2/mossformer2.py DELETED
@@ -1,216 +0,0 @@
1
- import os
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import torchaudio
6
-
7
- import math
8
-
9
- from .mossformer.utils.one_path_flash_fsmn import Dual_Path_Model, SBFLASHBlock_DualA
10
- from torch.nn import TransformerEncoder, TransformerEncoderLayer
11
-
12
- EPS = 1e-8
13
-
14
- class Mossformer(nn.Module):
15
- def __init__(self, args):
16
- super(Mossformer, self).__init__()
17
-
18
- N, L, = args.network_audio.encoder_out_nchannels, args.network_audio.encoder_kernel_size
19
-
20
- self.encoder = Encoder(L, N)
21
- self.separator = Separator(args)
22
- self.decoder = Decoder(args, N, L)
23
-
24
- for p in self.parameters():
25
- if p.dim() > 1:
26
- nn.init.xavier_normal_(p)
27
-
28
- def forward(self, mixture, visual):
29
- """
30
- Args:
31
- mixture: [M, T], M is batch size, T is #samples
32
- Returns:
33
- est_source: [M, C, T]
34
- """
35
- mixture_w = self.encoder(mixture)
36
- est_mask = self.separator(mixture_w, visual)
37
- est_source = self.decoder(mixture_w, est_mask)
38
-
39
- # T changed after conv1d in encoder, fix it here
40
- T_origin = mixture.size(-1)
41
- T_conv = est_source.size(-1)
42
- est_source = F.pad(est_source, (0, T_origin - T_conv))
43
- return est_source
44
-
45
- class Encoder(nn.Module):
46
- def __init__(self, L, N):
47
- super(Encoder, self).__init__()
48
- self.L, self.N = L, N
49
- self.conv1d_U = nn.Conv1d(1, N, kernel_size=L, stride=L // 2, bias=False)
50
-
51
- def forward(self, mixture):
52
- """
53
- Args:
54
- mixture: [M, T], M is batch size, T is #samples
55
- Returns:
56
- mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
57
- """
58
- mixture = torch.unsqueeze(mixture, 1) # [M, 1, T]
59
- mixture_w = F.relu(self.conv1d_U(mixture)) # [M, N, K]
60
- return mixture_w
61
-
62
-
63
- class Decoder(nn.Module):
64
- def __init__(self, args, N, L):
65
- super(Decoder, self).__init__()
66
- self.N, self.L, self.args = N, L, args
67
- self.basis_signals = nn.Linear(N, L, bias=False)
68
-
69
- def forward(self, mixture_w, est_mask):
70
- """
71
- Args:
72
- mixture_w: [M, N, K]
73
- est_mask: [M, C, N, K]
74
- Returns:
75
- est_source: [M, C, T]
76
- """
77
- est_source = mixture_w * est_mask
78
- est_source = torch.transpose(est_source, 2, 1) # [M, K, N]
79
- est_source = self.basis_signals(est_source) # [M, K, L]
80
- est_source = overlap_and_add(est_source, self.L//2) # M x C x T
81
- return est_source
82
-
83
-
84
-
85
-
86
- class Separator(nn.Module):
87
- def __init__(self, args):
88
- super(Separator, self).__init__()
89
-
90
- self.layer_norm = nn.GroupNorm(1, args.network_audio.encoder_out_nchannels, eps=1e-8)
91
- self.bottleneck_conv1x1 = nn.Conv1d(args.network_audio.encoder_out_nchannels, args.network_audio.encoder_out_nchannels, 1, bias=False)
92
-
93
- # mossformer 2
94
- intra_model = SBFLASHBlock_DualA(
95
- num_layers=args.network_audio.intra_numlayers,
96
- d_model=args.network_audio.encoder_out_nchannels,
97
- nhead=args.network_audio.intra_nhead,
98
- d_ffn=args.network_audio.intra_dffn,
99
- dropout=args.network_audio.intra_dropout,
100
- use_positional_encoding=args.network_audio.intra_use_positional,
101
- norm_before=args.network_audio.intra_norm_before
102
- )
103
-
104
- self.masknet = Dual_Path_Model(
105
- in_channels=args.network_audio.encoder_out_nchannels,
106
- out_channels=args.network_audio.encoder_out_nchannels,
107
- intra_model=intra_model,
108
- num_layers=args.network_audio.masknet_numlayers,
109
- norm=args.network_audio.masknet_norm,
110
- K=args.network_audio.masknet_chunksize,
111
- num_spks=args.network_audio.masknet_numspks,
112
- skip_around_intra=args.network_audio.masknet_extraskipconnection,
113
- linear_layer_after_inter_intra=args.network_audio.masknet_useextralinearlayer
114
- )
115
-
116
- # reference
117
- self.args = args
118
- if self.args.network_reference.cue == 'text':
119
- self.ref_ds = nn.Linear(768, args.network_reference.emb_size)
120
- encoder_layers = TransformerEncoderLayer(d_model=args.network_reference.emb_size, nhead=2, dim_feedforward=args.network_reference.emb_size*2, batch_first=True)
121
- self.text_net = TransformerEncoder(encoder_layers, num_layers=args.network_reference.text_layers)
122
- self.summarize = nn.LSTM(args.network_reference.emb_size, args.network_reference.emb_size, num_layers=1, batch_first=True)
123
- self.fusion = nn.Linear(512+args.network_reference.emb_size, 512)
124
- elif self.args.network_reference.cue == 'audio':
125
- self.ref_ds = nn.Linear(768, args.network_reference.emb_size)
126
- encoder_layers = TransformerEncoderLayer(d_model=args.network_reference.emb_size, nhead=2, dim_feedforward=args.network_reference.emb_size*2, batch_first=True)
127
- self.audio_net = TransformerEncoder(encoder_layers, num_layers=args.network_reference.text_layers)
128
- self.summarize = nn.LSTM(args.network_reference.emb_size, args.network_reference.emb_size, num_layers=1, batch_first=True)
129
- self.fusion = nn.Linear(512+args.network_reference.emb_size, 512)
130
-
131
- def forward(self, x, ref):
132
- """
133
- Keep this API same with TasNet
134
- Args:
135
- mixture_w: [M, N, K], M is batch size
136
- returns:
137
- est_mask: [M, C, N, K]
138
- """
139
- M, N, D = x.size()
140
-
141
- x = self.layer_norm(x)
142
- x = self.bottleneck_conv1x1(x)
143
-
144
- cross_0 = x.transpose(1,2)
145
- if self.args.network_reference.cue == 'text':
146
- text_embedding, text_attention_mask, text_len = ref
147
- text_embedding = self.ref_ds(text_embedding)
148
- text_attention_mask = (text_attention_mask==0)
149
- text_embedding = self.text_net(text_embedding, src_key_padding_mask=text_attention_mask)
150
- text_embedding, _ = self.summarize(text_embedding)
151
- text_len = text_len-1
152
- batch_indices = torch.arange(text_embedding.size(0))
153
- text = text_embedding[batch_indices, text_len]
154
-
155
- text = torch.repeat_interleave(text.unsqueeze(1), repeats=cross_0.shape[1], dim=1)
156
- cross_1 = torch.cat((cross_0, text),2)
157
- cross_1 = self.fusion(cross_1)
158
-
159
- elif self.args.network_reference.cue == 'audio':
160
- audio = self.ref_ds(ref)
161
- audio = self.audio_net(audio)
162
- audio, _ = self.summarize(audio)
163
- audio = audio[:,-1,:]
164
-
165
- audio = torch.repeat_interleave(audio.unsqueeze(1), repeats=cross_0.shape[1], dim=1)
166
- cross_1 = torch.cat((cross_0, audio),2)
167
- cross_1 = self.fusion(cross_1)
168
- x = cross_1.transpose(1,2)
169
-
170
-
171
- x = self.masknet(x)
172
-
173
- x = x.squeeze(0)
174
-
175
- return x
176
-
177
-
178
-
179
- def overlap_and_add(signal, frame_step):
180
- """Reconstructs a signal from a framed representation.
181
-
182
- Adds potentially overlapping frames of a signal with shape
183
- `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
184
- The resulting tensor has shape `[..., output_size]` where
185
-
186
- output_size = (frames - 1) * frame_step + frame_length
187
-
188
- Args:
189
- signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
190
- frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
191
-
192
- Returns:
193
- A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
194
- output_size = (frames - 1) * frame_step + frame_length
195
-
196
- Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
197
- """
198
- outer_dimensions = signal.size()[:-2]
199
- frames, frame_length = signal.size()[-2:]
200
-
201
- subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
202
- subframe_step = frame_step // subframe_length
203
- subframes_per_frame = frame_length // subframe_length
204
- output_size = frame_step * (frames - 1) + frame_length
205
- output_subframes = output_size // subframe_length
206
-
207
- subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
208
-
209
- frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)
210
- frame = signal.new_tensor(frame).long().cuda() # signal may in GPU or CPU
211
- frame = frame.contiguous().view(-1)
212
-
213
- result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
214
- result.index_add_(-2, frame, subframe_signal)
215
- result = result.view(*outer_dimensions, -1)
216
- return result
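
A short note on the overlap_and_add contract at the end of the removed file: the Decoder calls it with frame_step = L // 2 (half-overlapping frames), and the reconstructed length is (frames - 1) * frame_step + frame_length. The removed implementation calls .cuda(), so running it verbatim needs a GPU; the size arithmetic itself is framework-free:

frames, frame_length = 10, 32
frame_step = frame_length // 2                            # hop used by Decoder (self.L // 2)
output_size = (frames - 1) * frame_step + frame_length
print(output_size)                                        # 9 * 16 + 32 = 176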