Distributed Machine Learning with Ray (Part 2)

Basic idea
Training is organized around the parameter server + multiple workers pattern.
Synchronous mode
The parameter server manages the network parameters centrally. In each iteration it sends the current parameters to every worker; the workers iterate over the dataset in parallel and compute the loss and gradients for their current batch.
Once all workers have finished the current batch, each worker's gradients are sent back to the parameter server, which uses them to update the parameters.
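As a toy illustration of why no explicit barrier is needed here (the compute_gradient/apply_gradients functions below are hypothetical stand-ins, not the real training code that follows): a Ray task that receives object refs as arguments is only scheduled once all of those refs are ready, which gives the "wait for all workers" behaviour for free.

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def compute_gradient(worker_id):
    return worker_id + 1          # stand-in for a real gradient

@ray.remote
def apply_gradients(*grads):
    return sum(grads)             # stand-in for the optimizer step

grad_refs = [compute_gradient.remote(i) for i in range(4)]
new_weights_ref = apply_gradients.remote(*grad_refs)  # starts once all 4 are done
print(ray.get(new_weights_ref))   # 10
ray.shutdown()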
Asynchronous mode
Unlike the synchronous mode, the parameter server does not wait for all workers to finish a batch before using their combined gradients to update the network parameters.
Instead, as soon as any one worker finishes a batch, the parameter server immediately applies that worker's gradients and sends the new parameters back to that worker.
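The Ray primitive behind this is ray.wait(), which returns as soon as at least one of the given object refs is ready, so the parameter server can react to whichever worker finishes first. A toy sketch (slow_task is a hypothetical example task, not part of the training code below):

import time
import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def slow_task(seconds):
    time.sleep(seconds)
    return seconds

refs = [slow_task.remote(s) for s in (3, 1, 2)]
ready, not_ready = ray.wait(refs, num_returns=1)
print(ray.get(ready[0]))   # 1 -- the fastest task finished first
print(len(not_ready))      # 2 -- the others are still running
ray.shutdown()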
1. Define the model

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import ray


class ConvNet(nn.Module):
    """A small convolutional network for MNIST."""

    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 3))
        x = x.view(-1, 192)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

    def get_weights(self):
        return {k: v.cpu() for k, v in self.state_dict().items()}

    def set_weights(self, weights):
        self.load_state_dict(weights)

    def get_gradients(self):
        # Export gradients as numpy arrays so they can be shipped through Ray.
        grads = []
        for p in self.parameters():
            grad = None if p.grad is None else p.grad.data.cpu().numpy()
            grads.append(grad)
        return grads

    def set_gradients(self, gradients):
        for g, p in zip(gradients, self.parameters()):
            if g is not None:
                p.grad = torch.from_numpy(g)

2. Define the parameter server

@ray.remote
class ParameterServer(object):
    def __init__(self, lr):
        self.model = ConvNet()
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)

    def apply_gradients(self, *gradients):
        # Sum the gradients from the workers, take one optimizer step,
        # and return the updated weights.
        summed_gradients = [
            np.stack(gradient_zip).sum(axis=0)
            for gradient_zip in zip(*gradients)
        ]
        self.optimizer.zero_grad()
        self.model.set_gradients(summed_gradients)
        self.optimizer.step()
        return self.model.get_weights()

    def get_weights(self):
        return self.model.get_weights()

3. Define the worker

@ray.remote
class DataWorker(object):
    def __init__(self):
        self.model = ConvNet()
        self.data_iterator = iter(get_data_loader()[0])

    def compute_gradients(self, weights):
        # Load the latest weights, run one batch, and return the gradients.
        self.model.set_weights(weights)
        try:
            data, target = next(self.data_iterator)
        except StopIteration:
            # When the epoch ends, start a new epoch.
            self.data_iterator = iter(get_data_loader()[0])
            data, target = next(self.data_iterator)
        self.model.zero_grad()
        output = self.model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        return self.model.get_gradients()

4. Synchronous training

iterations = 200
num_workers = 2

ray.init(ignore_reinit_error=True)
ps = ParameterServer.remote(1e-2)
workers = [DataWorker.remote() for i in range(num_workers)]

model = ConvNet()
test_loader = get_data_loader()[1]

print("Running synchronous parameter server training.")
current_weights = ps.get_weights.remote()
for i in range(iterations):
    # Launch one gradient computation per worker; apply_gradients only runs
    # once every worker has finished its batch.
    gradients = [
        worker.compute_gradients.remote(current_weights) for worker in workers
    ]
    current_weights = ps.apply_gradients.remote(*gradients)

    if i % 10 == 0:
        # Evaluate the current model every 10 iterations.
        model.set_weights(ray.get(current_weights))
        accuracy = evaluate(model, test_loader)
        print("Iter {}: accuracy is {:.1f}".format(i, accuracy))

print("Final accuracy is {:.1f}.".format(accuracy))
ray.shutdown()

5. Asynchronous training

print("Running Asynchronous Parameter Server Training.")
ray.init(ignore_reinit_error=True)
ps = ParameterServer.remote(1e-2)
workers = [DataWorker.remote() for i in range(num_workers)]

current_weights = ps.get_weights.remote()
gradients = {}
for worker in workers:
    gradients[worker.compute_gradients.remote(current_weights)] = worker

for i in range(iterations * num_workers):
    # Wait until any one worker has finished its batch.
    ready_gradient_list, _ = ray.wait(list(gradients))
    ready_gradient_id = ready_gradient_list[0]
    worker = gradients.pop(ready_gradient_id)

    # Apply that worker's gradients immediately and hand the updated weights
    # back to the same worker.
    current_weights = ps.apply_gradients.remote(*[ready_gradient_id])
    gradients[worker.compute_gradients.remote(current_weights)] = worker

    if i % 10 == 0:
        model.set_weights(ray.get(current_weights))
        accuracy = evaluate(model, test_loader)
        print("Iter {}: accuracy is {:.1f}".format(i, accuracy))

print("Final accuracy is {:.1f}.".format(accuracy))
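The listing above also calls get_data_loader() and evaluate(), which are not defined in this excerpt. A minimal sketch of the two helpers, assuming MNIST data and accuracy reported as a percentage (as in the Ray documentation example this code follows):

import torch
from torchvision import datasets, transforms


def get_data_loader():
    # Returns (train_loader, test_loader) for MNIST.
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST("~/data", train=True, download=True,
                       transform=mnist_transforms),
        batch_size=128, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST("~/data", train=False, download=True,
                       transform=mnist_transforms),
        batch_size=128, shuffle=True)
    return train_loader, test_loader


def evaluate(model, test_loader):
    # Accuracy on the test set, in percent.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100.0 * correct / total

With helpers like these in place, steps 1 to 5 can be run top to bottom as a single script.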
Original post: https://www.cnblogs.com/zcsh/p/14206727.html