import os
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

data_id = 'misc'

# Sweep GFFM hyperparameters: every (positional scale, directional scale)
# pair is trained once per trial, each run logging to its own directory.
pos_size = 2**np.arange(6, 11, 1)  # positional scales 64, 128, 256, 512, 1024
dir_size = np.array([3])           # directional scale(s)
trials = [1]                       # trial indices (also part of the log path)

for trial in trials:
    for d in dir_size:
        for p in pos_size:
            logdir = f'logs/{data_id}-tune-gffm/trial-{trial}/pos_{p}_dir_{d}'
            command = (
                f'python residual_regression.py --config configs/{data_id}.txt '
                f'--train_images=800 --num_epochs=100 --logdir={logdir} '
                f'--model gffm --gffm_pos {p} --gffm_dir {d} --batch_rays=60000'
            )
            print(command)
            # os.system returns the command's exit status. Report failures
            # instead of silently continuing; otherwise the problem only
            # surfaces later as a missing result file when plotting.
            status = os.system(command)
            if status != 0:
                print(f'WARNING: run exited with status {status}: {logdir}')

# ---------------------------------------------------------------------------
# Make figure: mean test PSNR versus positional scale, one curve per
# directional scale, read from the result files written by the sweep above.
# ---------------------------------------------------------------------------
params = {'legend.fontsize': 12,
          'axes.labelsize': 12,
          'axes.titlesize': 13,
          'xtick.labelsize': 10,
          'ytick.labelsize': 10}
matplotlib.rcParams.update(params)

plt.figure(figsize=(5, 4))
ax = plt.gca()
for d in dir_size:
    mean = np.zeros(len(pos_size))
    for i, p in enumerate(pos_size):
        for trial in trials:
            result_path = f'logs/{data_id}-tune-gffm/trial-{trial}/pos_{p}_dir_{d}/gffm-L-1/result/test_psnr.npy'
            psnr = np.load(result_path)
            # Average in MSE space rather than PSNR space: invert
            # PSNR = 10*log10(1/MSE) for each stored value, then take the mean.
            mse = 10**(-psnr/10)
            mean[i] += mse.mean()
        mean[i] /= len(trials)
        # Convert the trial-averaged MSE back to PSNR (dB).
        mean[i] = 10*np.log10(1./mean[i])

    ax.plot(np.array(pos_size), mean, label=r'$\theta_{\omega}='+f'{d}'+r'$')

# NOTE: re-enabling this baseline requires best_ffm_psnr to be defined first.
# ax.axhline(best_ffm_psnr, color='black', linestyle='--', label='best FFM')
ax.set_xlabel(r'Positional scale $\theta_{\mu}$')
ax.set_xlim((pos_size[0], pos_size[-1]))
ax.set_title('(b) GFFM hyperparameter tuning', y=-0.4)
ax.grid(True, which='major', alpha=.3)
# `base` replaces the `basex` keyword, which was deprecated in Matplotlib 3.3
# and removed in 3.5 (passing `basex` now raises a TypeError).
ax.set_xscale('log', base=2)
ax.set_ylabel('Mean PSNR')

plt.legend(loc='center left', bbox_to_anchor=(1, .5), handlelength=1)
plt.tight_layout()
plt.savefig('fig_gffm_sweep.png')
plt.show()
