File size: 2,305 Bytes
fca537b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d35d2ae
dd247d2
fca537b
 
 
 
 
 
 
 
 
 
 
 
 
 
4aa542b
fca537b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import yamlargparse
import torch.nn as nn

from networks import network_wrapper

class main(nn.Module):
    """Load a model's YAML settings and build its network wrapper.

    Calling an instance with a model name parses ``config/<model_name>.yaml``
    and returns whatever ``network_wrapper`` produces for the parsed settings.
    ``__call__`` is overridden here, so invoking an instance runs the loading
    logic below rather than a ``forward`` method.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _build_parser():
        """Return the yamlargparse parser declaring every supported setting."""
        settings = yamlargparse.ArgumentParser("Settings")

        # --- general model / inference options ---
        settings.add_argument(
            '--config',
            help='Config file path',
            action=yamlargparse.ActionConfigFile,
        )
        settings.add_argument(
            '--mode',
            type=str,
            default='inference',
            help='Modes: train or inference',
        )
        settings.add_argument(
            '--use-cuda',
            dest='use_cuda',
            default=1,
            type=int,
            help='Enable CUDA (1=True, 0=False)',
        )
        settings.add_argument(
            '--num-gpu',
            dest='num_gpu',
            type=int,
            default=1,
            help='Number of GPUs to use',
        )
        settings.add_argument(
            '--checkpoint-dir',
            dest='checkpoint_dir',
            type=str,
            default='checkpoints/EEYD_base',
            help='Checkpoint directory',
        )

        # --- model-specific options ---
        settings.add_argument('--network_audio', type=dict)
        settings.add_argument('--network_reference', type=dict)
        settings.add_argument(
            '--sampling-rate',
            dest='sampling_rate',
            type=int,
            default=16000,
            help='Sampling rate',
        )
        settings.add_argument(
            '--one-time-decode-length',
            dest='one_time_decode_length',
            type=int,
            default=60,
            help='Max segment length for one-pass decoding',
        )
        settings.add_argument(
            '--decode-window',
            dest='decode_window',
            type=int,
            default=1,
            help='Decoding chunk size',
        )
        settings.add_argument('--output_residual', type=int, default=0)
        settings.add_argument(
            '--mix_precision',
            type=int,
            default=0,
            help='whether to perform mix precision training',
        )
        return settings

    def load_args_eeyd(self, model_name):
        """Parse ``config/<model_name>.yaml`` and store the result on ``self``.

        Sets ``self.config_path`` to the derived YAML path and ``self.args``
        to the parsed namespace, with ``model_name`` attached to it.

        Args:
            model_name: Name used to locate the config file and recorded on
                the parsed arguments.
        """
        self.config_path = f'config/{model_name}.yaml'
        settings = self._build_parser()

        # Only the config file is fed to the parser; the real command line
        # is never consulted.
        parsed = settings.parse_args(['--config', self.config_path])
        parsed.model_name = model_name
        self.args = parsed

    def __call__(self, model_name):
        """Load the settings for ``model_name`` and return the built network.

        The object returned by ``network_wrapper`` is also kept on
        ``self.network``.
        """
        self.load_args_eeyd(model_name)
        built = network_wrapper(self.args)
        self.network = built
        return self.network



class extract_everything:
    """Convenience wrapper that builds a model once and forwards calls to it.

    On construction it creates a ``main`` instance and calls it with
    ``model_name`` to obtain the model. Calling the wrapper delegates to that
    model's ``process`` method.
    """

    def __init__(self, model_name="EEYD_locoformer"):
        """Build the model identified by ``model_name``.

        Args:
            model_name: Name passed to the ``main`` loader; it selects the
                config file ``config/<model_name>.yaml``.
        """
        loader = main()
        self.args_wrapper = loader
        self.model = loader(model_name)

    def __call__(self, input_wav, input_text_prompt):
        """Return ``self.model.process(input_wav, input_text_prompt)``.

        Args:
            input_wav: Forwarded unchanged as the first argument of
                ``process``.
            input_text_prompt: Forwarded unchanged as the second argument of
                ``process``.
        """
        run = self.model.process
        return run(input_wav, input_text_prompt)