pytorch加载自定义网络权重的实现

 更新时间:2020-01-07 22:10:56   作者:佚名   我要评论(0)

在将自定义的网络权重加载到网络中时,报错:
AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. P

在将自定义的网络权重加载到网络中时,报错:

AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.

我们一步一步分析。

果博东方模型网络权重保存额代码是:torch.save(net.state_dict(),'net.pkl')

(1)查看获取模型权重的源码:

pytorch源码:net.state_dict()

def state_dict(self, destination=None, prefix='', keep_vars=False):  r"""Returns a dictionary containing a whole state of the module.  Both parameters and persistent buffers (e.g. running averages) are  included. Keys are corresponding parameter and buffer names.  Returns:    dict:      a dictionary containing a whole state of the module  Example::    >>> module.state_dict().keys()    ['bias', 'weight']  """

将网络中所有的状态保存到一个字典中了,我自己构建的就是一个字典,没问题!

(2)查看保存模型权重的源码:

果博东方pytorch源码:torch.save()

def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):  """Saves an object to a disk file.  See also: :ref:`recommend-saving-models`  Args:    obj: saved object    f: a file-like object (has to implement write and flush) or a string      containing a file name    pickle_module: module used for pickling metadata and objects    pickle_protocol: can be specified to override the default protocol  .. warning::    If you are using Python 2, torch.save does NOT support StringIO.StringIO    as a valid file-like object. This is because the write method should return    the number of bytes written; StringIO.write() does not do this.    Please use something like io.BytesIO instead.

果博东方函数功能是将字典保存为磁盘文件(二进制数据),那么我们在torch.load()时,就是在内存中加载二进制数据,这就是报错点。

解决方案:将字典保存为BytesIO文件之后,模型再net.load_state_dict()

#b为自定义的字典torch.save(b,'new.pkl')net.load_state_dict(torch.load(b))

解决方法很简单,主要记录解决思路。

以上这篇pytorch加载自定义网络权重的实现就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:

  • pytorch动态网络以及权重共享实例
  • pytorch自定义初始化权重的方法
  • Pytorch: 自定义网络层实例
  • Pytorch 实现权重初始化
  • pytorch 自定义数据集加载方法

果博东方相关的文章

  • pytorch加载自定义网络权重的实现

    pytorch加载自定义网络权重的实现

    在将自定义的网络权重加载到网络中时,报错:AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. P
    2020-01-07
  • python enumerate内置函数用法总结

    python enumerate内置函数用法总结

    这篇文章主要介绍了python enumerate内置函数用法总结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 enu
    2020-01-07
  • python模拟实现斗地主发牌

    python模拟实现斗地主发牌

    题目:趣味百题之斗地主扑克牌是一种非常大众化的游戏,在计算机中有很多与扑克牌有关的游戏。例如,在Windows操作系统下自带的纸牌、红心大战等。在扑克牌类的游戏
    2020-01-07
  • Laravel5.1 框架表单验证操作实例详解

    Laravel5.1 框架表单验证操作实例详解

    本文实例讲述了Laravel5.1 框架表单验证操作。分享给大家供大家参考,具体如下:当我们提交表单时 通常会对提交过来的数据进行一些验证、Laravel在Controller类中使
    2020-01-07
  • Python内置数据类型list各方法的性能测试过程解析

    Python内置数据类型list各方法的性能测试过程解析

    这篇文章主要介绍了Python内置数据类型list各方法的性能测试过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以
    2020-01-07
  • PyTorch中的Variable变量详解

    PyTorch中的Variable变量详解

    一、了解Variable顾名思义,Variable就是 变量 的意思。实质上也就是可以变化的量,区别于int变量,它是一种可以变化的变量,这正好就符合了反向传播,参数更新的属
    2020-01-07
  • VMwarea虚拟机安装win7操作系统的教程图解

    VMwarea虚拟机安装win7操作系统的教程图解

    VMwarea的安装过程就不演示了,主要看看如何装入win7镜像1、下载win7镜像链接: http://pan.baidu.com/s/1Kht7v0IFtF_p7holFyME0A提取码: hk9m2、下载完成后运行
    2020-01-07
  • Pytorch 中retain_graph的用法详解

    Pytorch 中retain_graph的用法详解

    用法分析 在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么? ############################ # (1) Update D network: maximi
    2020-01-07
  • Laravel5.1 框架数据库操作DB运行原生SQL的方法分析

    Laravel5.1 框架数据库操作DB运行原生SQL的方法分析

    本文实例讲述了Laravel5.1 框架数据库操作DB运行原生SQL的方法。分享给大家供大家参考,具体如下:Laravel操作数据库有三种:DB原生SQL、构建器、Model。这三种依情
    2020-01-07
  • 解决torch.autograd.backward中的参数问题

    解决torch.autograd.backward中的参数问题

    torch.autograd.backward(variables, grad_variables=None, retain_graph=None, create_graph=False)给定图的叶子节点variables, 计算图中变量的梯度和。 计算图可
    2020-01-07

最新评论