 
 import tensorflow as tf
 
+from tensorflow_addons import options
 from tensorflow_addons.seq2seq import attention_wrapper
 from tensorflow_addons.seq2seq import decoder
 from tensorflow_addons.utils import keras_utils
 from tensorflow_addons.utils.resource_loader import LazySO
-from tensorflow_addons.utils.types import FloatTensorLike, TensorLike
+from tensorflow_addons.utils.types import FloatTensorLike, TensorLike, Number
 
 from typeguard import typechecked
 from typing import Callable, Optional
 
 _beam_search_so = LazySO("custom_ops/seq2seq/_beam_search_ops.so")
 
 
-def gather_tree(*args, **kwargs) -> tf.Tensor:
-    return _beam_search_so.ops.addons_gather_tree(*args, **kwargs)
-
-
 class BeamSearchDecoderState(
     collections.namedtuple(
         "BeamSearchDecoderState",
@@ -151,6 +148,107 @@ def tile_batch(t: TensorLike, multiplier: int, name: Optional[str] = None) -> tf
     return tf.nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
 
 
+@tf.function(
+    input_signature=(
+        tf.TensorSpec([None, None, None], dtype=tf.int32),
+        tf.TensorSpec([None, None, None], dtype=tf.int32),
+        tf.TensorSpec([None], dtype=tf.int32),
+        tf.TensorSpec([], dtype=tf.int32),
+    )
+)
+def _gather_tree(step_ids, parent_ids, max_sequence_lengths, end_token):
+    input_shape = tf.shape(parent_ids)
+    max_time = input_shape[0]
+    beam_width = input_shape[2]
+    max_sequence_lengths = tf.math.minimum(max_sequence_lengths, max_time)
+    mask = tf.expand_dims(
+        tf.transpose(tf.sequence_mask(max_sequence_lengths, maxlen=max_time)), -1
+    )
+
+    # Mask out-of-range ids.
+    end_tokens = tf.fill(input_shape, end_token)
+    step_ids = tf.where(mask, x=step_ids, y=end_tokens)
+    parent_ids = tf.where(mask, x=parent_ids, y=tf.zeros_like(parent_ids))
+    assert_op = tf.debugging.Assert(
+        tf.math.reduce_all(
+            tf.math.logical_and(parent_ids >= 0, parent_ids < beam_width)
+        ),
+        ["All parent ids must be non-negative and less than beam_width"],
+    )
+
+    # Reverse all sequences as we need to gather from the end.
+    with tf.control_dependencies([assert_op]):
+        rev_step_ids = tf.reverse_sequence(
+            step_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
+        )
+        rev_parent_ids = tf.reverse_sequence(
+            parent_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
+        )
+
+    # Initialize the output ids and parent based on the last step.
+    output_ids = tf.TensorArray(step_ids.dtype, size=max_time, dynamic_size=False)
+    output_ids = output_ids.write(0, rev_step_ids[0])
+    parent = rev_parent_ids[0]
+
+    # For each step, gather ids based on the beam origin.
+    for t in tf.range(1, max_time):
+        ids = tf.gather(rev_step_ids[t], parent, batch_dims=1)
+        parent = tf.gather(rev_parent_ids[t], parent, batch_dims=1)
+        output_ids = output_ids.write(t, ids)
+
+    # Reverse the sequences back to their original order.
+    output_ids = output_ids.stack()
+    output_ids = tf.reverse_sequence(
+        output_ids, max_sequence_lengths, seq_axis=0, batch_axis=1
+    )
+
+    # Ensure that only end_token appears after the first end_token.
+    in_bound_steps = tf.math.cumsum(tf.cast(output_ids == end_token, tf.int32)) == 0
+    output_ids = tf.where(in_bound_steps, x=output_ids, y=end_tokens)
+    return output_ids
+
+
+def gather_tree(
+    step_ids: TensorLike,
+    parent_ids: TensorLike,
+    max_sequence_lengths: TensorLike,
+    end_token: Number,
+) -> tf.Tensor:
+    """Calculates the full beams from the per-step ids and parent beam ids.
+
+    For a given beam, past the time step containing the first decoded
+    `end_token`, all values are filled in with `end_token`.
+
+    Args:
+      step_ids: The predicted token IDs.
+        An `int32` `Tensor` of shape `[max_time, batch_size, beam_width]`.
+      parent_ids: The parent beam indices.
+        An `int32` `Tensor` of shape `[max_time, batch_size, beam_width]`.
+      max_sequence_lengths: The maximum sequence length of each batch entry.
+        An `int32` `Tensor` of shape `[batch_size]`.
+      end_token: The end token ID.
+
+    Returns:
+      The reordered token IDs based on `parent_ids`.
+
+    Raises:
+      InvalidArgumentError: if `parent_ids` contains an invalid index.
+    """
+    if not options.TF_ADDONS_PY_OPS:
+        try:
+            return _beam_search_so.ops.addons_gather_tree(
+                step_ids, parent_ids, max_sequence_lengths, end_token
+            )
+        except tf.errors.NotFoundError:
+            options.warn_fallback("gather_tree")
+
+    step_ids = tf.convert_to_tensor(step_ids, dtype=tf.int32)
+    parent_ids = tf.convert_to_tensor(parent_ids, dtype=tf.int32)
+    max_sequence_lengths = tf.convert_to_tensor(max_sequence_lengths, dtype=tf.int32)
+    end_token = tf.convert_to_tensor(end_token, dtype=tf.int32)
+    return _gather_tree(step_ids, parent_ids, max_sequence_lengths, end_token)
+
+
 def gather_tree_from_array(
     t: TensorLike, parent_ids: TensorLike, sequence_length: TensorLike
 ) -> tf.Tensor:
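
Both paths are meant to behave identically: the compiled kernel and the `tf.function` fallback each trace a final beam backwards through `parent_ids`. A small worked example of those semantics, assuming the surrounding module is `tensorflow_addons.seq2seq.beam_search_decoder` (which the symbols in this diff suggest), with tiny hand-built beams of `max_time=3`, `batch_size=1`, `beam_width=2`:

import tensorflow as tf
from tensorflow_addons.seq2seq import beam_search_decoder

# At the final step, beam 0 (token 5) descends from beam 1 at t=1,
# while beam 1 (token 6) descends from beam 0.
step_ids = tf.constant([[[1, 2]], [[3, 4]], [[5, 6]]], dtype=tf.int32)
parent_ids = tf.constant([[[0, 0]], [[0, 0]], [[1, 0]]], dtype=tf.int32)
max_sequence_lengths = tf.constant([3], dtype=tf.int32)

beams = beam_search_decoder.gather_tree(
    step_ids, parent_ids, max_sequence_lengths, end_token=7
)
# Backtracking yields beam 0 = [1, 4, 5] and beam 1 = [1, 3, 6],
# i.e. [[[1, 1]], [[4, 3]], [[5, 6]]] in time-major layout.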