PyTorch深度学习应用实战
上QQ阅读APP看书,第一时间看更新

1-3 TensorFlow对比PyTorch

深度学习(Deep Learning)的框架过去曾经百家争鸣,数量多达20多种,然而经过几年下来的厮杀,目前仅存在几个我们常用的主流框架了。

图1.5 2018年深度学习框架及评分

图片来源:Top 5 Deep Learning Frameworks to Watch in 2021 and Why TensorFlow[2]

目前比较常用的框架包括Google TensorFlow、Facebook PyTorch、Apache MXNet、Berkeley Caffe,其中又以TensorFlow、PyTorch占有率较高,一般企业广泛使用的是TensorFlow,而学术界则是偏好PyTorch。两者的功能也是互相模仿与竞争,差异比较整理如表1.1所示。

表1.1 TensorFlow、PyTorch比较

由上表可见两者的功能基本上大同小异,因此有人认为既然很相似,那就学习其中一种即可,但考虑到它们所专长的应用领域各有不同,并且网络上的扩充框架或范例程序常常只用其中一种语言开发,所以,同时熟悉TensorFlow、PyTorch,会是一个比较周全的选择。

好在TensorFlow、PyTorch都是深度学习的框架,彼此间有共通的概念,只要遵循本书的学习路径就能一举两得,没有想象中困难。孰悉两个框架,还有助于设计概念的深入了解。因此本书介绍PyTorch的方式,会与另一本以TensorFlow为主题的《深度学习全书:公式+推导+代码+TensorFlow全程案例》[3]相互对照。

根据官网说明及个人使用经验,PyTorch有以下特点。

(1)Python First:PyTorch与Python完美整合,在定义模型类别内可以任意加入侦错或转换的Python程序代码,TensorFlow/Keras则需通过Callback才能在模型训练过程中传出信息,PyTorch官方认为Python有丰富的套件,例如NumPy、SciPy、Scikitlearn等,无须另外发明轮子(reinvent the wheel where appropriate)。

(2)除错容易:TensorFlow/Keras提供fit指令进行模型训练,虽然简单,但不易侦错,PyTorch则须自行撰写优化程序,虽然烦琐,但可随时查看预测结果及损失函数变化,不必等到模型训练完后才能查看结果。

(3)GPU内存管理较佳:笔者使用GTX1050Ti,内存只有4GB时,执行2个以TensorFlow开发的Notebook文件时,常会发生内存不足的状况,但使用PyTorch,即使3、4个Notebook文件也没有问题。

(4)简洁快速:与Intel MKL和NVIDIA (cuDNN, NCCL)函数库整合,可提升执行的速度,程序可自由选择CPU或GPU运算,自由掌控内存的使用量。

(5)无痛扩充:PyTorch提供C/C++ extension API,有效整合资源,不需桥接的包装程序(wrapper)。