# Spaces: Running on Zero
| from pathlib import Path | |
| import random | |
| import logging | |
| import numpy as np | |
| import librosa | |
| import soundfile as sf | |
| import json | |
| from typing import List, Optional, Dict, Union, Tuple, Any | |
| from torch.utils.data import Dataset, Sampler | |
| from tqdm import tqdm | |
| from data.augment import StemAugmentation, MixtureAugmentation | |
| import argparse | |
| import yaml | |
| from data.dataset import InfiniteSampler, RawStems | |
def main():
    """Dump a few training samples to WAV files for manual inspection.

    Reads the ``data`` section of the YAML config given by ``--config``,
    builds the training ``RawStems`` dataset from it, draws ``--num_samples``
    indices from an ``InfiniteSampler`` and, for each drawn sample, writes
    three files into ``--output_dir``:

    * ``mixture_<i>.wav``  -- ``sample["mixture"]``
    * ``target_<i>.wav``   -- ``sample["target"]``
    * ``addition_<i>.wav`` -- ``sample["mixture"] - sample["target"]``

    The drawn index and the mixture/target shapes are echoed to stdout.
    """
    parser = argparse.ArgumentParser(
        # The script never trains anything; describe what it actually does.
        description="Dump sample mixture/target clips from the training dataset for inspection."
    )
    parser.add_argument("--config", type=str, required=True, help="Path to the config file.")
    parser.add_argument("--output_dir", type=str, default="test/test_dataset", help="Output dir")
    parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate")
    args = parser.parse_args()

    # NOTE(security): FullLoader accepts more than plain YAML. Only point
    # --config at trusted files, or switch to yaml.safe_load once it is
    # confirmed that no config relies on Python-specific tags.
    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config = config["data"]

    # Keyword arguments shared by the train and validation datasets.
    common_params = {
        "sr": config["sample_rate"],
        "clip_duration": config["clip_duration"],
    }
    dataset = RawStems(**config["train_dataset"], **common_params)
    # NOTE(review): the validation dataset is built but never read below.
    # It is kept so that a broken ``val_dataset`` config section still fails
    # here -- confirm whether this construction is actually wanted.
    val_dataset = RawStems(**config["val_dataset"], **common_params)  # noqa: F841

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Output directory created: {args.output_dir}")

    # The sampler yields indices into ``dataset``.
    index_iter = iter(InfiniteSampler(dataset))

    for i in tqdm(range(args.num_samples), desc="Sampling"):
        index = next(index_iter)
        # tqdm.write keeps the progress bar intact, unlike a bare print().
        tqdm.write(str(index))

        sample = dataset[index]
        mixture = sample["mixture"]
        target = sample["target"]
        tqdm.write(str(mixture.shape))
        tqdm.write(str(target.shape))

        # Whatever the mixture contains beyond the target.
        addition = mixture - target

        # Arrays are transposed before writing; presumably they are stored
        # as (channels, samples) while soundfile expects (frames, channels)
        # -- TODO confirm against RawStems.
        for stem_name, audio in (
            ("mixture", mixture),
            ("target", target),
            ("addition", addition),
        ):
            sf.write(output_dir / f"{stem_name}_{i}.wav", audio.T, dataset.sr)


if __name__ == "__main__":
    main()