简讯:第二十四篇—模型中的parameter和buffer

2023-03-31 17:54:55 来源:哔哩哔哩

部分一:模型保存


【资料图】

模型保存方式:

模型加载方式:

保存的内容是model.state_dict()的返回对象,是一个OrderedDict,以键值对(key-val)的形式包含模型中需要保存下来的参数,如:

打印模型:

打印模型参数:

部分二:模型中的parameter和buffer

parameter在反向传播时会被optimizer.step更新,而buffer在反向传播时不会被更新;

parameter和buffer都保存在model.state_dict()返回的OrderedDict对象中;

模型进行设备移动时,模型中注册的参数(parameter和buffer),即model.state_dict()中的内容会同时进行移动。

创建parameter和buffer

创建parameter:

方法一:直接将模型的成员变量self.xxx通过nn.Parameter()创建,会自动注册到model.parameters()中;

方法二:先通过nn.Parameter()创建普通的parameter对象,此时该对象不作为模型的成员变量,然后将parameter对象通过register_parameter()注册到model.parameters()中;

创建buffer:

通过register_buffer()注册到model.buffers()中;

注意一:parameter和buffer都保存在model.state_dict()中,举例如下:

结果如下:

注意二:需要进行梯度更新的模型参数属于parameter,而不是buffer中。实际上,parameter的创建在torch.nn.Linear类中的__init__函数中完成,成员变量weight和bias都属于parameter对象,并在这里进行了初始化,__init__函数如下图:

代码举例:

打印结果:

parameter在反向传播时会被optimizer.step更新,而buffer在反向传播时不会被更新,举例如下:

结果如下(buffer不变;parameter更新):

注意三:数据转换为buffer的原因

原因一:parameter和buffer都保存在model.state_dict()返回的OrderedDict对象中,如果保存模型的方法是torch.save(model.state_dict(),PATH),那么类中普通成员数据如果不注册为parameter或buffer,就不能被保存,因为只有model.state_dict()中的内容会被保存,即只保存parameter和buffer;

原因二:模型进行设备移动时,模型中注册的参数(parameter和buffer),即model.state_dict()中的内容会同时进行移动,而类中普通成员数据如果不注册为parameter或buffer,就不能被移动,如下所示:

结果如下(只有model.state_dict()中的内容,即parameter和buffer,转移到了GPU;而普通成员变量没被转移,仍在CPU上):

标签:

推荐阅读>