# Example batch: two rows of token ids, each padded on the right with 0 to length 8.
input_ids = tf.constant(
    [
        [101, 2, 4, 5, 102, 0, 0, 0],
        [101, 8, 6, 2, 4, 102, 0, 0],
    ],
    dtype=tf.int32,
)
input_ids
<tf.Tensor: shape=(2, 8), dtype=int32, numpy=
array([[101, 2, 4, 5, 102, 0, 0, 0],
[101, 8, 6, 2, 4, 102, 0, 0]], dtype=int32)>
# Trim columns that contain only padding, so the batch uses less GPU memory.
def trim_padding(input_ids, padding_id=0):
    """Remove every column of ``input_ids`` that contains only ``padding_id``.

    A column is kept when at least one row holds a value different from
    ``padding_id`` in it; the kept columns stay in their original order.
    For a right-padded batch this shortens the batch to the length of its
    longest row, so shorter rows may still end in padding.

    Args:
        input_ids: Tensor of token ids. Presumably shaped (batch, seq_len),
            as in the surrounding example -- confirm for other ranks.
        padding_id: Value that marks a padded position. Defaults to 0.

    Returns:
        ``input_ids`` with the all-padding columns dropped along axis 1.
    """
    # True for each column in which any row holds a non-padding value.
    column_has_token = tf.reduce_any(tf.not_equal(input_ids, padding_id), axis=0)
    return tf.boolean_mask(input_ids, mask=column_has_token, axis=1)
trim_padding(input_ids, padding_id=0)
<tf.Tensor: shape=(2, 6), dtype=int32, numpy=
array([[101, 2, 4, 5, 102, 0],
[101, 8, 6, 2, 4, 102]], dtype=int32)>
Details
mask = tf.not_equal(input_ids, 0) # ---> torch: mask = input_ids.ne(0)
mask
<tf.Tensor: shape=(2, 8), dtype=bool, numpy=
array([[ True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, False, False]])>
mask = tf.reduce_any(mask, axis=0) # ---> torch: mask = input_ids.ne(0).any(axis=0)
mask
<tf.Tensor: shape=(8,), dtype=bool, numpy=array([ True, True, True, True, True, True, False, False])>
tf.boolean_mask(input_ids, mask=mask, axis=1) # ---> torch: input_ids[:, mask]
<tf.Tensor: shape=(2, 6), dtype=int32, numpy=
array([[101, 2, 4, 5, 102, 0],
[101, 8, 6, 2, 4, 102]], dtype=int32)>
Share on:
Twitter
❄ Facebook
❄ Email