# xlance-msr / test_dataset.py
# ModistAndrew
# allowed_others; rawstem v2
# ff9f510
from pathlib import Path
import random
import logging
import numpy as np
import librosa
import soundfile as sf
import json
from typing import List, Optional, Dict, Union, Tuple, Any
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from data.augment import StemAugmentation, MixtureAugmentation
import argparse
import yaml
from data.dataset import InfiniteSampler, RawStems
def parse_args(argv=None):
    """Parse the command-line options of this dataset sampling script.

    Args:
        argv: Argument strings to parse. ``None`` makes argparse fall back to
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config``, ``output_dir`` and ``num_samples``.
    """
    parser = argparse.ArgumentParser(
        description="Sample clips from the RawStems dataset and write them to disk as WAV files"
    )
    parser.add_argument("--config", type=str, required=True, help="Path to the config file.")
    parser.add_argument("--output_dir", type=str, default="test/test_dataset", help="Output dir")
    parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate")
    return parser.parse_args(argv)


def main(argv=None):
    """Draw samples from the training dataset and save them as WAV files.

    Reads the ``data`` section of the YAML config, builds the train and
    validation ``RawStems`` datasets, then pulls ``--num_samples`` indices from
    an ``InfiniteSampler`` over the training dataset. For every sample the
    mixture, the target and their difference (``mixture - target``) are written
    to ``<output_dir>/{mixture,target,addition}_<i>.wav``.

    Args:
        argv: Argument strings forwarded to :func:`parse_args`; ``None`` means
            the process command line.
    """
    args = parse_args(argv)

    # NOTE(review): yaml.FullLoader can construct more than plain data types;
    # only point --config at trusted files, or switch to yaml.safe_load if the
    # configs use no custom tags.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    data_config = config['data']

    common_params = {
        "sr": data_config['sample_rate'],
        "clip_duration": data_config['clip_duration'],
    }
    dataset = RawStems(**data_config['train_dataset'], **common_params)
    # NOTE(review): the validation dataset is constructed but never sampled
    # below; construction is kept so the 'val_dataset' config is still
    # exercised -- confirm whether it should be sampled as well.
    RawStems(**data_config['val_dataset'], **common_params)

    # Build the output directory path once instead of on every iteration.
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"Output directory created: {args.output_dir}")

    # Create a sampler
    sampler = InfiniteSampler(dataset)
    indices = iter(sampler)

    # Draw and save exactly --num_samples samples.
    for i in tqdm(range(args.num_samples), desc="Sampling"):
        index = next(indices)
        print(index)
        sample = dataset[index]
        mixture = sample["mixture"]
        target = sample["target"]
        print(mixture.shape)
        print(target.shape)
        # Everything in the mixture that is not the target.
        addition = mixture - target

        # Transposed before writing; assumes the dataset returns
        # (channels, samples) arrays -- TODO confirm against RawStems.
        for name, audio in (("mixture", mixture), ("target", target), ("addition", addition)):
            sf.write(out_dir / f"{name}_{i}.wav", audio.T, dataset.sr)


if __name__ == "__main__":
    main()