Commit 339159f

Point optimizer to tf.keras.optimizer.legacy.Optimizer to be compatible with Keras optimizer migration (#2706)
* Point optimizer to tf.keras.optimizer.legacy.Optimizer to be compatible with Keras optimizer migration
* small fix
* add version control
* small fix
* Update discriminative_layer_training.py
* fix version control
* small fix
* move optimizer class to __init__.py
* small fix
* fix problems
* small fix
* Rename BaseOptimizer to KerasLegacyOptimizer
* exclude keras optimizer from type check
* fix import
1 parent 3e264f9 commit 339159f

21 files changed: 105 additions, 35 deletions

tensorflow_addons/optimizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Additional optimizers that conform to Keras API."""
 
+from tensorflow_addons.optimizers.constants import KerasLegacyOptimizer
 from tensorflow_addons.optimizers.average_wrapper import AveragedOptimizerWrapper
 from tensorflow_addons.optimizers.conditional_gradient import ConditionalGradient
 from tensorflow_addons.optimizers.cyclical_learning_rate import CyclicalLearningRate

tensorflow_addons/optimizers/adabelief.py

Lines changed: 2 additions & 1 deletion
@@ -17,11 +17,12 @@
 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike
 
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from typing import Union, Callable, Dict
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class AdaBelief(tf.keras.optimizers.Optimizer):
+class AdaBelief(KerasLegacyOptimizer):
     """Variant of the Adam optimizer.
 
     It achieves fast convergence as Adam and generalization comparable to SGD.
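
For classes like AdaBelief the change is only the base-class swap; user-facing construction is unchanged. A minimal usage sketch, assuming TF >= 2.9 and a tensorflow-addons build that includes this commit:

import tensorflow_addons as tfa

# AdaBelief now inherits from the legacy Optimizer base, but is created exactly as before.
opt = tfa.optimizers.AdaBelief(learning_rate=1e-3)

The same base-class swap, with no API change, applies to COCOB, ConditionalGradient, LAMB, NovoGrad, ProximalAdagrad, RectifiedAdam, and Yogi below.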

tensorflow_addons/optimizers/average_wrapper.py

Lines changed: 7 additions & 4 deletions
@@ -17,12 +17,12 @@
 import warnings
 
 import tensorflow as tf
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from tensorflow_addons.utils import types
-
 from typeguard import typechecked
 
 
-class AveragedOptimizerWrapper(tf.keras.optimizers.Optimizer, metaclass=abc.ABCMeta):
+class AveragedOptimizerWrapper(KerasLegacyOptimizer, metaclass=abc.ABCMeta):
     @typechecked
     def __init__(
         self, optimizer: types.Optimizer, name: str = "AverageOptimizer", **kwargs
@@ -32,9 +32,12 @@ def __init__(
         if isinstance(optimizer, str):
             optimizer = tf.keras.optimizers.get(optimizer)
 
-        if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
+        if not isinstance(
+            optimizer, (tf.keras.optimizers.Optimizer, KerasLegacyOptimizer)
+        ):
             raise TypeError(
-                "optimizer is not an object of tf.keras.optimizers.Optimizer"
+                "optimizer is not an object of tf.keras.optimizers.Optimizer "
+                "or tf.keras.optimizers.legacy.Optimizer (if you have tf version >= 2.9.0)."
             )
 
         self._optimizer = optimizer
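
With the widened isinstance check, wrappers built on AveragedOptimizerWrapper accept legacy optimizer instances as well as strings and new-style optimizers. A hedged sketch using the MovingAverage subclass touched later in this commit, assuming TF >= 2.9:

import tensorflow as tf
import tensorflow_addons as tfa

# A legacy SGD instance passes the wrapper's widened type check.
opt = tfa.optimizers.MovingAverage(tf.keras.optimizers.legacy.SGD(learning_rate=0.01))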

tensorflow_addons/optimizers/cocob.py

Lines changed: 3 additions & 1 deletion
@@ -17,9 +17,11 @@
 from typeguard import typechecked
 import tensorflow as tf
 
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
+
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class COCOB(tf.keras.optimizers.Optimizer):
+class COCOB(KerasLegacyOptimizer):
     """Optimizer that implements COCOB Backprop Algorithm
 
     Reference:

tensorflow_addons/optimizers/conditional_gradient.py

Lines changed: 2 additions & 1 deletion
@@ -15,14 +15,15 @@
 """Conditional Gradient optimizer."""
 
 import tensorflow as tf
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from tensorflow_addons.utils.types import FloatTensorLike
 
 from typeguard import typechecked
 from typing import Union, Callable
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class ConditionalGradient(tf.keras.optimizers.Optimizer):
+class ConditionalGradient(KerasLegacyOptimizer):
     """Optimizer that implements the Conditional Gradient optimization.
 
     This optimizer helps handle constraints well.

tensorflow_addons/optimizers/constants.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import importlib
+import tensorflow as tf
+
+if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
+    KerasLegacyOptimizer = tf.keras.optimizers.legacy.Optimizer
+else:
+    KerasLegacyOptimizer = tf.keras.optimizers.Optimizer
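
The new module gives the rest of the package a single alias for the old optimizer base class, wherever the installed TensorFlow keeps it. A minimal sketch of how downstream code consumes the alias (assuming a tensorflow-addons build with this commit; what it resolves to depends on the installed TF version):

import tensorflow_addons as tfa
from tensorflow_addons.optimizers import KerasLegacyOptimizer

# Resolves to tf.keras.optimizers.legacy.Optimizer when that namespace exists
# (TF >= 2.9 per the error messages in this commit), else tf.keras.optimizers.Optimizer.
print(KerasLegacyOptimizer)

# Every addon optimizer changed in this commit now subclasses this alias.
assert issubclass(tfa.optimizers.AdaBelief, KerasLegacyOptimizer)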

tensorflow_addons/optimizers/cyclical_learning_rate.py

Lines changed: 4 additions & 4 deletions
@@ -58,7 +58,7 @@ def __init__(
         ```
 
         You can pass this schedule directly into a
-        `tf.keras.optimizers.Optimizer` as the learning rate.
+        `tf.keras.optimizers.legacy.Optimizer` as the learning rate.
 
         Args:
             initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
@@ -146,7 +146,7 @@ def __init__(
         ```
 
         You can pass this schedule directly into a
-        `tf.keras.optimizers.Optimizer` as the learning rate.
+        `tf.keras.optimizers.legacy.Optimizer` as the learning rate.
 
         Args:
             initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
@@ -215,7 +215,7 @@ def __init__(
         ```
 
         You can pass this schedule directly into a
-        `tf.keras.optimizers.Optimizer` as the learning rate.
+        `tf.keras.optimizers.legacy.Optimizer` as the learning rate.
 
         Args:
             initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
@@ -286,7 +286,7 @@ def __init__(
         ```
 
         You can pass this schedule directly into a
-        `tf.keras.optimizers.Optimizer` as the learning rate.
+        `tf.keras.optimizers.legacy.Optimizer` as the learning rate.
 
         Args:
             initial_learning_rate: A scalar `float32` or `float64` `Tensor` or
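
The docstring examples now reference the legacy base class. A hedged sketch of passing the schedule to a legacy optimizer (TF >= 2.9 assumed; the hyperparameters are illustrative):

import tensorflow as tf
import tensorflow_addons as tfa

# A cyclical schedule used as the learning rate of a legacy optimizer.
clr = tfa.optimizers.CyclicalLearningRate(
    initial_learning_rate=1e-4,
    maximal_learning_rate=1e-2,
    step_size=2000,
    scale_fn=lambda x: 1.0,
)
opt = tf.keras.optimizers.legacy.SGD(learning_rate=clr)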

tensorflow_addons/optimizers/discriminative_layer_training.py

Lines changed: 5 additions & 3 deletions
@@ -17,14 +17,16 @@
 from typing import List, Union
 
 import tensorflow as tf
+
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from typeguard import typechecked
 
 from keras import backend
 from keras.utils import tf_utils
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class MultiOptimizer(tf.keras.optimizers.Optimizer):
+class MultiOptimizer(KerasLegacyOptimizer):
     """Multi Optimizer Wrapper for Discriminative Layer Training.
 
     Creates a wrapper around a set of instantiated optimizer layer pairs.
@@ -33,7 +35,7 @@ class MultiOptimizer(tf.keras.optimizers.Optimizer):
     Each optimizer will optimize only the weights associated with its paired layer.
     This can be used to implement discriminative layer training by assigning
     different learning rates to each optimizer layer pair.
-    `(tf.keras.optimizers.Optimizer, List[tf.keras.layers.Layer])` pairs are also supported.
+    `(tf.keras.optimizers.legacy.Optimizer, List[tf.keras.layers.Layer])` pairs are also supported.
     Please note that the layers must be instantiated before instantiating the optimizer.
 
     Args:
@@ -149,7 +151,7 @@ def get_config(self):
     @classmethod
     def create_optimizer_spec(
         cls,
-        optimizer: tf.keras.optimizers.Optimizer,
+        optimizer: KerasLegacyOptimizer,
         layers_or_model: Union[
             tf.keras.Model,
             tf.keras.Sequential,
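
The updated docstring reflects that optimizer/layer pairs may use legacy optimizers. A hedged sketch of discriminative layer training with legacy Adam instances (TF >= 2.9 assumed; the toy model and learning rates are illustrative):

import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(8, input_shape=(4,), name="body"),
        tf.keras.layers.Dense(1, name="head"),
    ]
)
# One legacy optimizer per layer, each with its own learning rate.
optimizers_and_layers = [
    (tf.keras.optimizers.legacy.Adam(learning_rate=1e-4), model.layers[0]),
    (tf.keras.optimizers.legacy.Adam(learning_rate=1e-2), model.layers[1]),
]
model.compile(optimizer=tfa.optimizers.MultiOptimizer(optimizers_and_layers), loss="mse")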

tensorflow_addons/optimizers/lamb.py

Lines changed: 2 additions & 1 deletion
@@ -24,12 +24,13 @@
 from typeguard import typechecked
 
 import tensorflow as tf
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from tensorflow_addons.utils.types import FloatTensorLike
 from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class LAMB(tf.keras.optimizers.Optimizer):
+class LAMB(KerasLegacyOptimizer):
     """Optimizer that implements the Layer-wise Adaptive Moments (LAMB).
 
     See paper [Large Batch Optimization for Deep Learning: Training BERT

tensorflow_addons/optimizers/lazy_adam.py

Lines changed: 8 additions & 1 deletion
@@ -20,15 +20,22 @@
 original Adam algorithm, and may lead to different empirical results.
 """
 
+import importlib
 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike
 
 from typeguard import typechecked
 from typing import Union, Callable
 
 
+if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
+    adam_optimizer_class = tf.keras.optimizers.legacy.Adam
+else:
+    adam_optimizer_class = tf.keras.optimizers.Adam
+
+
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class LazyAdam(tf.keras.optimizers.Adam):
+class LazyAdam(adam_optimizer_class):
     """Variant of the Adam optimizer that handles sparse updates more
     efficiently.
 
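
Because LazyAdam specializes Adam's sparse-update behavior, it subclasses an Adam implementation rather than the generic base, and the module now picks the legacy Adam when that namespace exists. End-user construction is unchanged; a hedged sketch (TF >= 2.9 and this commit assumed; the toy model is illustrative):

import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
# LazyAdam is built exactly as before; only its base class changed.
model.compile(optimizer=tfa.optimizers.LazyAdam(learning_rate=1e-3), loss="mse")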

tensorflow_addons/optimizers/lookahead.py

Lines changed: 7 additions & 3 deletions
@@ -16,11 +16,12 @@
 import tensorflow as tf
 from tensorflow_addons.utils import types
 
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from typeguard import typechecked
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class Lookahead(tf.keras.optimizers.Optimizer):
+class Lookahead(KerasLegacyOptimizer):
     """This class allows to extend optimizers with the lookahead mechanism.
 
     The mechanism is proposed by Michael R. Zhang et.al in the paper
@@ -71,9 +72,12 @@ def __init__(
 
         if isinstance(optimizer, str):
             optimizer = tf.keras.optimizers.get(optimizer)
-        if not isinstance(optimizer, tf.keras.optimizers.Optimizer):
+        if not isinstance(
+            optimizer, (tf.keras.optimizers.Optimizer, KerasLegacyOptimizer)
+        ):
             raise TypeError(
-                "optimizer is not an object of tf.keras.optimizers.Optimizer"
+                "optimizer is not an object of tf.keras.optimizers.Optimizer "
+                "or tf.keras.optimizers.legacy.Optimizer (if you have tf version >= 2.9.0)."
             )
 
         self._optimizer = optimizer
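
Lookahead gets the same widened type check as AveragedOptimizerWrapper, so a legacy optimizer instance can be wrapped directly. A hedged sketch (TF >= 2.9 assumed):

import tensorflow as tf
import tensorflow_addons as tfa

# Wrap a legacy SGD with the lookahead mechanism.
opt = tfa.optimizers.Lookahead(tf.keras.optimizers.legacy.SGD(learning_rate=0.01))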

tensorflow_addons/optimizers/moving_average.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def __init__(
         r"""Construct a new MovingAverage optimizer.
 
         Args:
-            optimizer: str or `tf.keras.optimizers.Optimizer` that will be
+            optimizer: str or `tf.keras.optimizers.legacy.Optimizer` that will be
                 used to compute and apply gradients.
             average_decay: float. Decay to use to maintain the moving averages
                 of trained variables.

tensorflow_addons/optimizers/novograd.py

Lines changed: 2 additions & 2 deletions
@@ -16,13 +16,13 @@
 
 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike
-
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from typing import Union, Callable
 from typeguard import typechecked
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class NovoGrad(tf.keras.optimizers.Optimizer):
+class NovoGrad(KerasLegacyOptimizer):
     """Optimizer that implements NovoGrad.
 
     The NovoGrad Optimizer was first proposed in [Stochastic Gradient

tensorflow_addons/optimizers/proximal_adagrad.py

Lines changed: 2 additions & 1 deletion
@@ -19,11 +19,12 @@
 import tensorflow as tf
 from typeguard import typechecked
 
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from tensorflow_addons.utils.types import FloatTensorLike
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class ProximalAdagrad(tf.keras.optimizers.Optimizer):
+class ProximalAdagrad(KerasLegacyOptimizer):
     """Optimizer that implements the Proximal Adagrad algorithm.
 
     References:

tensorflow_addons/optimizers/rectified_adam.py

Lines changed: 2 additions & 1 deletion
@@ -16,12 +16,13 @@
 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike
 
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from typing import Union, Callable, Dict
 from typeguard import typechecked
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class RectifiedAdam(tf.keras.optimizers.Optimizer):
+class RectifiedAdam(KerasLegacyOptimizer):
     """Variant of the Adam optimizer whose adaptive learning rate is rectified
     so as to have a consistent variance.
 

tensorflow_addons/optimizers/tests/standard_test.py

Lines changed: 3 additions & 4 deletions
@@ -18,6 +18,7 @@
 import tensorflow as tf
 
 from tensorflow_addons import optimizers
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from tensorflow_addons.utils.test_utils import discover_classes
 
 class_exceptions = [
@@ -29,12 +30,10 @@
     "ConditionalGradient",  # is wrapper
     "Lookahead",  # is wrapper
     "MovingAverage",  # is wrapper
+    "KerasLegacyOptimizer",  # is a constant
 ]
 
-
-classes_to_test = discover_classes(
-    optimizers, tf.keras.optimizers.Optimizer, class_exceptions
-)
+classes_to_test = discover_classes(optimizers, KerasLegacyOptimizer, class_exceptions)
 
 
 @pytest.mark.parametrize("optimizer", classes_to_test)

tensorflow_addons/optimizers/tests/weight_decay_optimizers_test.py

Lines changed: 8 additions & 3 deletions
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Tests for optimizers with weight decay."""
 
+import importlib
 import numpy as np
 import pytest
 import tensorflow as tf
@@ -401,13 +402,17 @@ def test_var_list_with_exclude_list_sgdw(dtype):
 )
 
 
+if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
+    optimizer_class = tf.keras.optimizers.legacy.SGD
+else:
+    optimizer_class = tf.keras.optimizers.SGD
+
+
 @pytest.mark.parametrize(
     "optimizer",
     [
         weight_decay_optimizers.SGDW,
-        weight_decay_optimizers.extend_with_decoupled_weight_decay(
-            tf.keras.optimizers.SGD
-        ),
+        weight_decay_optimizers.extend_with_decoupled_weight_decay(optimizer_class),
     ],
 )
 @pytest.mark.parametrize("dtype", [(tf.half, 0), (tf.float32, 1), (tf.float64, 2)])

tensorflow_addons/optimizers/weight_decay_optimizers.py

Lines changed: 11 additions & 2 deletions
@@ -14,6 +14,7 @@
 # ==============================================================================
 """Base class to make optimizers weight decay ready."""
 
+import importlib
 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike
 from tensorflow_addons.optimizers.utils import is_variable_matched_by_regexes
@@ -261,10 +262,18 @@ def _do_use_weight_decay(self, var):
         return var.ref() in self._decay_var_list
 
 
+if importlib.util.find_spec("tensorflow.keras.optimizers.legacy") is not None:
+    keras_legacy_optimizer = Union[
+        tf.keras.optimizers.legacy.Optimizer, tf.keras.optimizers.Optimizer
+    ]
+else:
+    keras_legacy_optimizer = tf.keras.optimizers.Optimizer
+
+
 @typechecked
 def extend_with_decoupled_weight_decay(
-    base_optimizer: Type[tf.keras.optimizers.Optimizer],
-) -> Type[tf.keras.optimizers.Optimizer]:
+    base_optimizer: Type[keras_legacy_optimizer],
+) -> Type[keras_legacy_optimizer]:
     """Factory function returning an optimizer class with decoupled weight
     decay.
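
The factory's type hints now accept either base class, matching the updated test above. A hedged sketch mirroring that test, extending the legacy SGD with decoupled weight decay (TF >= 2.9 assumed; the hyperparameters and the local class name are illustrative):

import tensorflow as tf
import tensorflow_addons as tfa

# Build a decoupled-weight-decay variant of the legacy SGD.
DecoupledSGD = tfa.optimizers.extend_with_decoupled_weight_decay(
    tf.keras.optimizers.legacy.SGD
)
opt = DecoupledSGD(weight_decay=1e-4, learning_rate=0.1)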

tensorflow_addons/optimizers/yogi.py

Lines changed: 2 additions & 1 deletion
@@ -25,6 +25,7 @@
 import tensorflow as tf
 from tensorflow_addons.utils.types import FloatTensorLike
 
+from tensorflow_addons.optimizers import KerasLegacyOptimizer
 from typeguard import typechecked
 from typing import Union, Callable
 
@@ -50,7 +51,7 @@ def _solve(a, b, c):
 
 
 @tf.keras.utils.register_keras_serializable(package="Addons")
-class Yogi(tf.keras.optimizers.Optimizer):
+class Yogi(KerasLegacyOptimizer):
     """Optimizer that implements the Yogi algorithm in Keras.
 
     See Algorithm 2 of
