import tensorflow as tf
# Toy batch: two rows of token ids, each right-padded with 0 to a common width of 8.
input_ids = tf.constant([[101, 2, 4, 5, 102, 0, 0, 0],
                         [101, 8, 6, 2, 4, 102, 0, 0]], tf.int32)
input_ids
<tf.Tensor: shape=(2, 8), dtype=int32, numpy=
array([[101,   2,   4,   5, 102,   0,   0,   0],
       [101,   8,   6,   2,   4, 102,   0,   0]], dtype=int32)>
# Drop columns that are padding in every row of the batch, to save GPU memory.
def trim_padding(input_ids, padding_id=0):
    """Remove the columns of ``input_ids`` that hold only ``padding_id``.

    A column survives when at least one row contains a value different
    from ``padding_id``. Every all-padding column is dropped, wherever it
    sits -- not only trailing ones.

    Args:
        input_ids: Tensor of ids. Assumed to be 2-D (batch, seq_len),
            since columns are selected along axis 1 -- confirm with callers.
        padding_id: Value treated as padding. Defaults to 0.

    Returns:
        ``input_ids`` restricted to the columns that are not entirely padding.
    """
    # True wherever an entry is a real (non-padding) value.
    is_real_token = tf.not_equal(input_ids, padding_id)
    # Collapse over rows: a column is kept if any row has a real value in it.
    keep_column = tf.reduce_any(is_real_token, axis=0)
    return tf.boolean_mask(input_ids, mask=keep_column, axis=1)
trim_padding(input_ids, padding_id=0)
<tf.Tensor: shape=(2, 6), dtype=int32, numpy=
array([[101,   2,   4,   5, 102,   0],
       [101,   8,   6,   2,   4, 102]], dtype=int32)>

Details

mask = tf.not_equal(input_ids, 0)    # ---> torch: mask = input_ids.ne(0)
mask
<tf.Tensor: shape=(2, 8), dtype=bool, numpy=
array([[ True,  True,  True,  True,  True, False, False, False],
       [ True,  True,  True,  True,  True,  True, False, False]])>
mask = tf.reduce_any(mask, axis=0)   # ---> torch: mask = input_ids.ne(0).any(axis=0)
mask
<tf.Tensor: shape=(8,), dtype=bool, numpy=array([ True,  True,  True,  True,  True,  True, False, False])>
tf.boolean_mask(input_ids, mask=mask, axis=1)   # ---> torch: input_ids[:, mask]
<tf.Tensor: shape=(2, 6), dtype=int32, numpy=
array([[101,   2,   4,   5, 102,   0],
       [101,   8,   6,   2,   4, 102]], dtype=int32)>

Share on: Twitter, Facebook, Email

Comments


Related Posts


Published

Category

Programming

Tags

Contact