![Python深度学习:基于TensorFlow(第2版)](https://wfqqreader-1252317822.image.myqcloud.com/cover/658/48593658/b_48593658.jpg)
上QQ阅读APP看本书,新人免费读10天
设备和账号都新为新人
1.6 广播机制
NumPy的通用函数(ufunc)中要求输入的数组shape是一致的,当数组的shape不一致时,则会用到广播机制。不过,调整数组使得shape一样时需满足一定规则,否则将出错。广播机制中的这些规则可归结为以下四条。
1)让所有输入数组都向其中shape最长的数组看齐,shape中不足的部分都通过在前面加1补齐;如对于数组a(2×3×2)和数组b(3×2),则b向a看齐,在b的前面加1,变为1×3×2。
2)输出数组的shape是输入数组shape的各个轴上的最大值。
3)如果输入数组的某个轴和输出数组的对应轴的长度相同或者长度为1时,则可以调整,否则将出错。
4)当输入数组的某个轴的长度为1时,沿着此轴运算时都用(或复制)此轴上的第一组值。
广播机制在整个NumPy中用于决定如何处理形状迥异的数组,涉及的算术运算包括+、-、*、/。这些规则虽然很严谨,但不直观。下面我们结合图形与代码做进一步说明。
目的:A+B。其中A为4×1矩阵,B为一维向量(3,)。要相加,需要做如下处理。
1)根据规则1,B需要向A看齐,把B变为(1, 3)。
2)根据规则2,输出的结果为各个轴上的最大值,即输出结果应该为(4, 3)矩阵。那么A如何由(4, 1)变为(4, 3)矩阵?B如何由(1, 3)变为(4, 3)矩阵?
3)根据规则4,用此轴上的第一组值(主要区分是哪个轴)进行复制即可。(但在实际处理中不是真正复制,而是采用其他对象,如ogrid对象,进行网格处理,否则太耗内存。)如图1-10所示。
![](https://epubservercos.yuewen.com/AE244A/28235093307371206/epubprivate/OEBPS/Images/36_01.jpg?sign=1739426063-pRkn5WxVCoFAxFNgNsd8th4QUqVa62cw-0-7252e8d65567eef0429b0bd6111ef92b)
图1-10 NumPy广播机制示意图
具体实现如下:
![](https://epubservercos.yuewen.com/AE244A/28235093307371206/epubprivate/OEBPS/Images/36_02.jpg?sign=1739426063-15ekVnxdkgbC0XvEifaPJE3I2kTAlbVV-0-043c1100a4213acaa743304d0d4a9a8b)
运行结果如下:
![](https://epubservercos.yuewen.com/AE244A/28235093307371206/epubprivate/OEBPS/Images/36_03.jpg?sign=1739426063-FUMrmj9dDIN9pEMf8CMUIaxSlnA21Lq1-0-7b866ba33f6f78621fa6c36fcfc1a0ec)