- Python深度学习:基于TensorFlow(第2版)
- 吴茂贵 王冬等
- 586字
- 2024-04-12 18:43:58
1.6 广播机制
NumPy的通用函数(ufunc)中要求输入的数组shape是一致的,当数组的shape不一致时,则会用到广播机制。不过,调整数组使得shape一样时需满足一定规则,否则将出错。广播机制中的这些规则可归结为以下四条。
1)让所有输入数组都向其中shape最长的数组看齐,shape中不足的部分都通过在前面加1补齐;如对于数组a(2×3×2)和数组b(3×2),则b向a看齐,在b的前面加1,变为1×3×2。
2)输出数组的shape是输入数组shape的各个轴上的最大值。
3)如果输入数组的某个轴和输出数组的对应轴的长度相同或者长度为1时,则可以调整,否则将出错。
4)当输入数组的某个轴的长度为1时,沿着此轴运算时都用(或复制)此轴上的第一组值。
广播机制在整个NumPy中用于决定如何处理形状迥异的数组,涉及的算术运算包括+、-、*、/。这些规则虽然很严谨,但不直观。下面我们结合图形与代码做进一步说明。
目的:A+B。其中A为4×1矩阵,B为一维向量(3,)。要相加,需要做如下处理。
1)根据规则1,B需要向A看齐,把B变为(1, 3)。
2)根据规则2,输出的结果为各个轴上的最大值,即输出结果应该为(4, 3)矩阵。那么A如何由(4, 1)变为(4, 3)矩阵?B如何由(1, 3)变为(4, 3)矩阵?
3)根据规则4,用此轴上的第一组值(主要区分是哪个轴)进行复制即可。(但在实际处理中不是真正复制,而是采用其他对象,如ogrid对象,进行网格处理,否则太耗内存。)如图1-10所示。
图1-10 NumPy广播机制示意图
具体实现如下:
运行结果如下: