Spaces:
Build error
Build error
| """ | |
| Worker class implementation of the a3c discrete algorithm | |
| """ | |
| import os | |
| import numpy as np | |
| import torch | |
| import torch.multiprocessing as mp | |
| from torch import nn | |
| from .net import Net | |
| from .utils import v_wrap | |
class Worker(mp.Process):
    """Asynchronous A3C worker process.

    Each worker holds a local copy of the network, plays episodes in its
    own environment, and on episode end pushes gradients into the shared
    global network (``gnet``) through the shared optimizer (``opt``),
    then pulls the updated global weights back into the local net.
    """

    def __init__(
        self,
        max_ep,
        gnet,
        opt,
        global_ep,
        global_ep_r,
        res_queue,
        name,
        env,
        N_S,
        N_A,
        words_list,
        word_width,
        winning_ep,
        model_checkpoint_dir,
        gamma=0.0,
        pretrained_model_path=None,
        save=False,
        min_reward=9.9,
        every_n_save=100,
    ):
        """
        Args:
            max_ep: total number of global episodes to train for.
            gnet: global (shared) network.
            opt: shared optimizer that updates ``gnet``.
            global_ep: shared episode counter (``mp.Value``).
            global_ep_r: shared running-average episode reward (``mp.Value``).
            res_queue: queue the worker reports rewards to; ``None`` is the
                end-of-training sentinel.
            name: integer worker index used to build the process name.
            env: gym-style environment; its ``unwrapped`` form is stored.
            N_S: observation-space size passed to the local ``Net``.
            N_A: action-space size passed to the local ``Net``.
            words_list: candidate words, indexed by action id.
            word_width: word-encoding width passed to the local ``Net``.
            winning_ep: shared counter of episodes whose final guess matched
                the goal word (``mp.Value``).
            model_checkpoint_dir: directory checkpoints are written to.
            gamma: discount factor used when bootstrapping returns.
            pretrained_model_path: optional state-dict path to warm-start
                the local network from.
            save: whether checkpointing is enabled at all.
            min_reward: running reward threshold below which no checkpoint
                is written.
            every_n_save: write a checkpoint every N global episodes.
        """
        super().__init__()
        self.max_ep = max_ep
        self.name = "w%02i" % name
        self.g_ep = global_ep
        self.g_ep_r = global_ep_r
        self.res_queue = res_queue
        self.winning_ep = winning_ep
        self.gnet, self.opt = gnet, opt
        self.word_list = words_list
        # Local network: one private copy per worker process.
        self.lnet = Net(N_S, N_A, words_list, word_width)
        if pretrained_model_path:
            self.lnet.load_state_dict(torch.load(pretrained_model_path))
        self.env = env.unwrapped
        self.gamma = gamma
        self.model_checkpoint_dir = model_checkpoint_dir
        self.save = save
        self.min_reward = min_reward
        self.every_n_save = every_n_save

    def run(self):
        """Play episodes until the shared episode counter reaches max_ep.

        After each finished episode the worker syncs with the global net,
        records stats, and optionally checkpoints. A ``None`` sentinel is
        pushed onto ``res_queue`` once training ends so consumers can stop.
        """
        while self.g_ep.value < self.max_ep:
            s = self.env.reset()
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.0
            while True:
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a)
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)
                if done:  # update global net and re-sync the local net
                    self.push_and_pull(done, s_, buffer_s, buffer_a, buffer_r)
                    goal_word = self.word_list[self.env.goal_word]
                    self.record(ep_r, goal_word, self.word_list[a], len(buffer_a))
                    self.save_model()
                    buffer_s, buffer_a, buffer_r = [], [], []
                    break
                s = s_
        # Sentinel: tells the consumer of res_queue that training is over.
        self.res_queue.put(None)

    def push_and_pull(self, done, s_, bs, ba, br):
        """Compute n-step targets, push local gradients to the global net,
        then pull the updated global parameters into the local net.

        Args:
            done: whether the episode terminated (terminal value is 0).
            s_: the state following the last buffered transition.
            bs, ba, br: buffered states, actions and rewards for the episode.
        """
        if done:
            v_s_ = 0.0  # terminal state has zero value
        else:
            # Bootstrap from the local critic's value estimate of s_.
            v_s_ = self.lnet.forward(v_wrap(s_[None, :]))[-1].data.numpy()[0, 0]
        buffer_v_target = []
        for r in br[::-1]:  # walk rewards backwards to accumulate returns
            v_s_ = r + self.gamma * v_s_
            buffer_v_target.append(v_s_)
        buffer_v_target.reverse()
        loss = self.lnet.loss_func(
            v_wrap(np.vstack(bs)),
            v_wrap(np.array(ba), dtype=np.int64)
            if ba[0].dtype == np.int64
            else v_wrap(np.vstack(ba)),
            v_wrap(np.array(buffer_v_target)[:, None]),
        )
        # Calculate local gradients and push local parameters to global.
        self.opt.zero_grad()
        loss.backward()
        for lp, gp in zip(self.lnet.parameters(), self.gnet.parameters()):
            gp._grad = lp.grad
        self.opt.step()
        # Pull the freshly updated global parameters back into the local net.
        self.lnet.load_state_dict(self.gnet.state_dict())

    def save_model(self):
        """Checkpoint the global net when saving is enabled, the running
        reward is high enough, and the episode count hits the save cadence."""
        if (
            self.save
            and self.g_ep_r.value >= self.min_reward
            and self.g_ep.value % self.every_n_save == 0
        ):
            torch.save(
                self.gnet.state_dict(),
                os.path.join(self.model_checkpoint_dir, f"model_{self.g_ep.value}.pth"),
            )

    def record(self, ep_r, goal_word, action, action_number):
        """Update shared episode statistics and report the running reward.

        Args:
            ep_r: total reward of the finished episode.
            goal_word: the word the environment wanted guessed.
            action: the final word guessed this episode.
            action_number: number of guesses taken this episode.
        """
        with self.g_ep.get_lock():
            self.g_ep.value += 1
        with self.g_ep_r.get_lock():
            if self.g_ep_r.value == 0.0:
                self.g_ep_r.value = ep_r
            else:
                # Exponential moving average of episode rewards.
                self.g_ep_r.value = self.g_ep_r.value * 0.99 + ep_r * 0.01
        self.res_queue.put(self.g_ep_r.value)
        if goal_word == action:
            # Fix: take the Value's lock, consistent with g_ep/g_ep_r above;
            # an unlocked += on a shared mp.Value races across workers.
            with self.winning_ep.get_lock():
                self.winning_ep.value += 1
        if self.g_ep.value % 100 == 0:
            print(
                self.name,
                "Ep:",
                self.g_ep.value,
                "| Ep_r: %.0f" % self.g_ep_r.value,
                "| Goal :",
                goal_word,
                "| Action: ",
                action,
                "| Actions: ",
                action_number,
            )