Skip to content

Commit 532e51b

Browse files
guillaumekln authored and ashutosh1919 committed
Add a Python alternative to seq2seq.gather_tree (tensorflow#1925)
* Add a Python alternative to seq2seq.gather_tree
* Enable tests for the Python op
1 parent 7628c62 commit 532e51b

File tree

4 files changed

+135
-14
lines changed

4 files changed

+135
-14
lines changed

tensorflow_addons/seq2seq/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ py_library(
66
name = "seq2seq",
77
srcs = glob(["*.py"]),
88
data = [
9+
"//tensorflow_addons:options.py",
910
"//tensorflow_addons/custom_ops/seq2seq:_beam_search_ops.so",
1011
],
1112
deps = [

tensorflow_addons/seq2seq/beam_search_decoder.py

Lines changed: 103 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,19 @@
1919

2020
import tensorflow as tf
2121

22+
from tensorflow_addons import options
2223
from tensorflow_addons.seq2seq import attention_wrapper
2324
from tensorflow_addons.seq2seq import decoder
2425
from tensorflow_addons.utils import keras_utils
2526
from tensorflow_addons.utils.resource_loader import LazySO
26-
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike
27+
from tensorflow_addons.utils.types import FloatTensorLike, TensorLike, Number
2728

2829
from typeguard import typechecked
2930
from typing import Callable, Optional
3031

3132
_beam_search_so = LazySO("custom_ops/seq2seq/_beam_search_ops.so")
3233

3334

34-
def gather_tree(*args, **kwargs) -> tf.Tensor:
35-
return _beam_search_so.ops.addons_gather_tree(*args, **kwargs)
36-
37-
3835
class BeamSearchDecoderState(
3936
collections.namedtuple(
4037
"BeamSearchDecoderState",
@@ -151,6 +148,107 @@ def tile_batch(t: TensorLike, multiplier: int, name: Optional[str] = None) -> tf
151148
return tf.nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
152149

153150

151+
@tf.function(
    input_signature=(
        tf.TensorSpec([None, None, None], dtype=tf.int32),
        tf.TensorSpec([None, None, None], dtype=tf.int32),
        tf.TensorSpec([None], dtype=tf.int32),
        tf.TensorSpec([], dtype=tf.int32),
    )
)
def _gather_tree(step_ids, parent_ids, max_sequence_lengths, end_token):
    """Pure TensorFlow implementation backing `gather_tree`.

    Args:
      step_ids: A rank-3 `int32` `Tensor` of per-step token ids, indexed as
        `[max_time, batch_size, beam_width]`.
      parent_ids: A rank-3 `int32` `Tensor` of parent beam indices with the
        same layout as `step_ids`.
      max_sequence_lengths: A rank-1 `int32` `Tensor` with one length per
        batch entry. Values larger than `max_time` are clipped to `max_time`.
      end_token: A scalar `int32` `Tensor` holding the end token id.

    Returns:
      An `int32` `Tensor` shaped like `step_ids` containing the token ids
      reordered by following `parent_ids` back from the last valid step.
      Steps beyond a batch entry's sequence length, and every step from the
      first `end_token` of a beam onwards, hold `end_token`.

    Raises:
      InvalidArgumentError: if, after masking steps beyond the sequence
        lengths, a parent id is negative or not less than `beam_width`.
    """
    input_shape = tf.shape(parent_ids)
    max_time = input_shape[0]
    beam_width = input_shape[2]
    max_sequence_lengths = tf.math.minimum(max_sequence_lengths, max_time)
    # sequence_mask yields [batch, time]; transpose and add a trailing axis so
    # the mask broadcasts against the [time, batch, beam] inputs.
    mask = tf.expand_dims(
        tf.transpose(tf.sequence_mask(max_sequence_lengths, maxlen=max_time)), -1
    )

    # Mask out of range ids.
    end_tokens = tf.fill(input_shape, end_token)
    step_ids = tf.where(mask, x=step_ids, y=end_tokens)
    parent_ids = tf.where(mask, x=parent_ids, y=tf.zeros_like(parent_ids))
    # Validation runs on the masked parents, so ids stored past a sequence's
    # length can never trigger the assertion. Zero is a valid parent index,
    # hence "non-negative" rather than "positive".
    assert_op = tf.debugging.Assert(
        tf.math.reduce_all(
            tf.math.logical_and(parent_ids >= 0, parent_ids < beam_width)
        ),
        ["All parent ids must be non-negative and less than beam_width"],
    )

    # Reverse all sequences as we need to gather from the end.
    with tf.control_dependencies([assert_op]):
        rev_step_ids = tf.reverse_sequence(
            step_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
        )
        rev_parent_ids = tf.reverse_sequence(
            parent_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
        )

    # Initialize output ids and parent based on last step.
    output_ids = tf.TensorArray(step_ids.dtype, size=max_time, dynamic_size=False)
    output_ids = output_ids.write(0, rev_step_ids[0])
    parent = rev_parent_ids[0]

    # For each step, gather ids based on beam origin.
    for t in tf.range(1, max_time):
        ids = tf.gather(rev_step_ids[t], parent, batch_dims=1)
        parent = tf.gather(rev_parent_ids[t], parent, batch_dims=1)
        output_ids = output_ids.write(t, ids)

    # Reverse sequences to their original order.
    output_ids = output_ids.stack()
    output_ids = tf.reverse_sequence(
        output_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
    )

    # Ensure that there are only end_token after the first end_token.
    # The running count along the time axis stays at zero only for steps
    # strictly before the first end_token of each beam.
    in_bound_steps = (
        tf.math.cumsum(tf.cast(output_ids == end_token, tf.int32), axis=0) == 0
    )
    output_ids = tf.where(in_bound_steps, x=output_ids, y=end_tokens)
    return output_ids
209+
210+
211+
def gather_tree(
    step_ids: TensorLike,
    parent_ids: TensorLike,
    max_sequence_lengths: TensorLike,
    end_token: Number,
) -> tf.Tensor:
    """Rebuilds complete beams from per-step token ids and parent beam ids.

    Within each beam, every position after the time step holding the first
    decoded `end_token` is filled with `end_token`.

    Args:
      step_ids: The predicted token IDs.
        An `int32` `Tensor` of shape `[max_time, batch_size, beam_width]`.
      parent_ids: The parent beam indices.
        An `int32` `Tensor` of shape `[max_time, batch_size, beam_width]`.
      max_sequence_lengths: The maximum sequence length of each batch.
        An `int32` `Tensor` of shape `[batch_size]`.
      end_token: The end token ID.

    Returns:
      The reordered token IDs based on `parent_ids`.

    Raises:
      InvalidArgumentError: if `parent_ids` contains an invalid index.
    """
    prefer_custom_op = not options.TF_ADDONS_PY_OPS
    if prefer_custom_op:
        try:
            return _beam_search_so.ops.addons_gather_tree(
                step_ids, parent_ids, max_sequence_lengths, end_token
            )
        except tf.errors.NotFoundError:
            # The custom op could not be found: emit the fallback warning and
            # continue with the Python implementation below.
            options.warn_fallback("gather_tree")

    # The Python op has an int32 input signature, so convert every argument.
    int32_inputs = [
        tf.convert_to_tensor(value, dtype=tf.int32)
        for value in (step_ids, parent_ids, max_sequence_lengths, end_token)
    ]
    return _gather_tree(*int32_inputs)
250+
251+
154252
def gather_tree_from_array(
155253
t: TensorLike, parent_ids: TensorLike, sequence_length: TensorLike
156254
) -> tf.Tensor:

tensorflow_addons/seq2seq/tests/beam_search_decoder_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from tensorflow_addons.seq2seq import beam_search_decoder, gather_tree
2323

2424

25+
@pytest.mark.usefixtures("run_custom_and_py_ops")
2526
def test_gather_tree():
2627
# (max_time = 3, batch_size = 2, beam_width = 3)
2728

@@ -103,22 +104,27 @@ def _tile_in_depth(tensor):
103104
np.testing.assert_equal(expected_array.numpy(), sorted_array.numpy())
104105

105106

107+
@pytest.mark.usefixtures("run_custom_and_py_ops")
106108
def test_gather_tree_from_array_scalar():
107109
_test_gather_tree_from_array()
108110

109111

112+
@pytest.mark.usefixtures("run_custom_and_py_ops")
110113
def test_gather_tree_from_array_1d():
111114
_test_gather_tree_from_array(depth_ndims=1)
112115

113116

117+
@pytest.mark.usefixtures("run_custom_and_py_ops")
114118
def test_gather_tree_from_array_1d_with_merged_batch_beam():
115119
_test_gather_tree_from_array(depth_ndims=1, merged_batch_beam=True)
116120

117121

122+
@pytest.mark.usefixtures("run_custom_and_py_ops")
118123
def test_gather_tree_from_array_2d():
119124
_test_gather_tree_from_array(depth_ndims=2)
120125

121126

127+
@pytest.mark.usefixtures("run_custom_and_py_ops")
122128
def test_gather_tree_from_array_complex_trajectory():
123129
# Max. time = 7, batch = 1, beam = 5.
124130
array = np.expand_dims(
@@ -538,6 +544,7 @@ def get_probs():
538544
"cell_class", [tf.keras.layers.LSTMCell, tf.keras.layers.GRUCell]
539545
)
540546
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
547+
@pytest.mark.usefixtures("run_custom_and_py_ops")
541548
def test_beam_search_decoder(
542549
cell_class, time_major, has_attention, with_alignment_history
543550
):

tensorflow_addons/seq2seq/tests/beam_search_ops_test.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@
2020
import pytest
2121
import tensorflow as tf
2222

23+
from tensorflow_addons import options
2324
from tensorflow_addons.seq2seq import gather_tree
2425

2526

2627
def _transpose_batch_time(x):
2728
return np.transpose(x, [1, 0, 2]).astype(np.int32)
2829

2930

31+
@pytest.mark.usefixtures("run_custom_and_py_ops")
3032
def test_gather_tree_one():
3133
# (max_time = 4, batch_size = 1, beams = 3)
3234
end_token = 10
@@ -47,6 +49,7 @@ def test_gather_tree_one():
4749
np.testing.assert_equal(expected_result, beams.numpy())
4850

4951

52+
@pytest.mark.usefixtures("run_custom_and_py_ops")
5053
def test_bad_parent_values_on_cpu():
5154
# (batch_size = 1, max_time = 4, beams = 3)
5255
# bad parent in beam 1 time 1
@@ -57,7 +60,7 @@ def test_bad_parent_values_on_cpu():
5760
)
5861
max_sequence_lengths = [3]
5962

60-
with pytest.raises(tf.errors.InvalidArgumentError):
63+
with pytest.raises(tf.errors.InvalidArgumentError, match="parent id"):
6164
_ = gather_tree(
6265
step_ids=step_ids,
6366
parent_ids=parent_ids,
@@ -67,6 +70,7 @@ def test_bad_parent_values_on_cpu():
6770

6871

6972
@pytest.mark.with_device(["gpu"])
73+
@pytest.mark.usefixtures("run_custom_and_py_ops")
7074
def test_bad_parent_values_on_gpu():
7175
# (max_time = 4, batch_size = 1, beams = 3)
7276
# bad parent in beam 1 time 1; appears as a negative index at time 0
@@ -79,15 +83,26 @@ def test_bad_parent_values_on_gpu():
7983
expected_result = _transpose_batch_time(
8084
[[[2, -1, 2], [6, 5, 6], [7, 8, 9], [10, 10, 10]]]
8185
)
82-
beams = gather_tree(
83-
step_ids=step_ids,
84-
parent_ids=parent_ids,
85-
max_sequence_lengths=max_sequence_lengths,
86-
end_token=end_token,
87-
)
88-
np.testing.assert_equal(expected_result, beams.numpy())
86+
if options.TF_ADDONS_PY_OPS:
87+
# The Python version has the same behavior on CPU and GPU.
88+
with pytest.raises(tf.errors.InvalidArgumentError, match="parent id"):
89+
_ = gather_tree(
90+
step_ids=step_ids,
91+
parent_ids=parent_ids,
92+
max_sequence_lengths=max_sequence_lengths,
93+
end_token=end_token,
94+
)
95+
else:
96+
beams = gather_tree(
97+
step_ids=step_ids,
98+
parent_ids=parent_ids,
99+
max_sequence_lengths=max_sequence_lengths,
100+
end_token=end_token,
101+
)
102+
np.testing.assert_equal(expected_result, beams.numpy())
89103

90104

105+
@pytest.mark.usefixtures("run_custom_and_py_ops")
91106
def test_gather_tree_batch():
92107
batch_size = 10
93108
beam_width = 15
@@ -123,7 +138,7 @@ def test_gather_tree_batch():
123138
found = np.where(v == end_token)[0]
124139
found = found[0] # First occurrence of end_token.
125140
# If an end_token is found, everything before it should be a
126-
# valid id and everything after it should be -1.
141+
# valid id and everything after it should be end_token.
127142
if found > 0:
128143
np.testing.assert_equal(
129144
v[: found - 1] >= 0, np.ones_like(v[: found - 1], dtype=bool),

0 commit comments

Comments
 (0)