Broadcast
- expand_dims
- without copying data vs tf.tile
- tf.broadcast_to
Key idea
- If tensors a and b do not have the same number of dimensions, insert dimensions of size 1, aligning the shapes from the trailing (low) axes.
- Expand each size-1 dimension until both shapes match.
e.g. a = [4, 32, 32, 3]
b = [3] -> [1, 1, 1, 3] -> [4, 32, 32, 3]
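A minimal sketch of the alignment rule (the names a, b, c here are just illustrative): a trailing [3] right-aligns against [4, 32, 32, 3] and broadcasts, while a trailing [2] cannot be aligned and the operation fails.
import tensorflow as tf
a = tf.random.normal([4, 32, 32, 3])
b = tf.random.normal([3])   # [3] -> [1, 1, 1, 3] -> [4, 32, 32, 3]
(a + b).shape
TensorShape([4, 32, 32, 3])
c = tf.random.normal([2])   # trailing 2 cannot align with trailing 3
# a + c would raise tf.errors.InvalidArgumentError: shapes are not broadcast-compatible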
tf.broadcast_to
import tensorflow as tf
x = tf.random.normal([4, 32, 32, 3])
y = tf.random.normal([4, 1, 1, 1])
x.shape, y.shape
(TensorShape([4, 32, 32, 3]), TensorShape([4, 1, 1, 1]))
# implicit broadcast
(x + y).shape
TensorShape([4, 32, 32, 3])
# explicit broadcast
(x + tf.broadcast_to(y, [4, 32, 32, 3])).shape
TensorShape([4, 32, 32, 3])
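A quick sanity check, reusing x and y from above (the names implicit and explicit are just for illustration): both forms perform the same element-wise additions, so the values match exactly.
implicit = x + y
explicit = x + tf.broadcast_to(y, [4, 32, 32, 3])
tf.reduce_all(implicit == explicit)
<tf.Tensor: shape=(), dtype=bool, numpy=True>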
Broadcast vs Tile
Broadcasting does not actually copy the data, whereas tf.tile physically replicates it, so broadcasting is more efficient and saves memory.
a = tf.ones([3, 4])
a.shape
TensorShape([3, 4])
# broadcast
a1 = tf.broadcast_to(a, [2, 3, 4])
a1
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]],

       [[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float32)>
# tile
a2 = tf.expand_dims(a, axis=0)
a2.shape
TensorShape([1, 3, 4])
a2 = tf.tile(a2, [2, 1, 1])  # [2, 1, 1] gives the number of copies along each axis
a2
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]],

       [[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float32)>
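To confirm that only the mechanism differs, not the values, compare a1 and a2 from above. Note that tf.tile is still needed when an axis of size greater than 1 must be replicated, which broadcasting cannot do.
tf.reduce_all(a1 == a2)
<tf.Tensor: shape=(), dtype=bool, numpy=True>
# tile can replicate along non-1 axes, which broadcast_to cannot:
tf.tile(a, [2, 3]).shape
TensorShape([6, 12])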