PyTorch使用torch.nn.DataParallel进行多GPU训练的一个BUG,已解决

解决了PyTorch 使用torch.nn.DataParallel 进行多GPU训练的一个BUG:

模型(参数)和数据不在相同设备上 我使用torch.nn.DataParallel进行多GPU训练时出现了一个BUG,困扰了我许久:

RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

这个错误表明, input数据在device 1 上, 而模型的参数在device 0 上 (暗示数据是被拆分到了各个GPU上,但是BUG出现位置的此处参数可能没有成功复制到其他GPU上, 或者说, 还是调用了复制前的那个参数地址)

因为模型比较复杂,继承与调用太多,之前调试了好久, 也没有解决掉, 在Github上有一个issue和我的问题很像: https://github.com/pytorch/pytorch/issues/8637 但是我还是没有找到自己的bug在哪里.

今天, 我又准备再此尝试解决它

经过6个小时的print调试法以及后面关键的VScode的Debug功能, 我大功告成,找到了问题所在,原来

我的A(nn.Module)类的forward 前向计算函数里面, 有一处调用了一个该类的列表self.cell_fabrics, 其列表的元素是通过self.cell_fabrics = [self.cell_1, self.cell_2,...,self.cell_n] 来赋值的,其中每个self.cell也是nn.Module

即用self.cell_fabrics = [self.cell_0_0, self.cell_0_1, … , self.cell_3_0, self.cell_3_1,…, self.cell_5_0] 这样的方式,将所有的cell类放到A类的一个列表属性中, 而当整个A类通过torch.nn.DataParallel被复制了一份放到设备cuda:1上以后, 且在预计在设备cuda:1上执行下面代码段时:

1
2
3
4
def forward(self,x)
for layer in self.cell_fabrics:
for cell in layer:
y = cell (x,)

经验证,x是在设备cuda:1 上面, 但是 cell 中的参数却明显都在 cuda:0

也就是说:

此时 self.cell_fabrics 的列表中保存的各个对象 (self.cell) 的地址,还是指向在没有进行torch.nn.DataParallel之前的nn.Module 的那些self.cell, 而nn.DataParallel类的nn.Module的参数都默认存放在device(type='cuda',index=0)上 .

torch.nn.DataParallel(model,device_ids=[range(len(gpus))])的机制是, 将属于nn.Module类的model以及其广播的所有nn.Module子类的上的所有参数,复制成len(gpus)份,送到各个GPU上. 这种广播机制的范围是注册(registered)成为其属性的nn.Module子类, 属性为列表list中的各个对象是不会被复制的, 所以其list中的对象还是存放在默认设备device 0

所以 在使用torch.nn.DataParallel进行多GPU训练的时候, 请注意:所有属于模型参数的模块以及其子模块必须以nn.Module的类型注册为模型的属性, 如果需要一个列表来批量存放子模块或者参数的话, 请采用nn.ModuleList或者nn.ModuleDict这样的继承了nn.Module的类来进行定义, 并且在forward(self,)前向传播的过程中,需要直接调用属于 nn.Module,nn.ModuleList或者nn.ModuleDict 这样的属性。

那么torch.nn.DataParallel将会正常地将模型参数准确复制到多个GPU上, 并根据数据的batchsize的大小平分成GPU的数量分别送到相应的GPU设备上,

然后运用多线程的方式, 同时对这些数据进行加工处理, 然后收集各个GPU上最终产生对模型的各参数的梯度, 最后汇总到一起更新原模型的参数!

参考: 1. https://github.com/pytorch/pytorch/issues/8637 2. https://pytorch.org/docs/stable/nn.html#dataparallel-layers-multi-gpu-distributed