
Commit a5a3e4e

Complete partially typed signatures (monai.utils) (#5891)
Part of #5884.

### Description

Fully type annotate any functions with at least one type annotation in module `monai.utils`.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.

Signed-off-by: Felix Schnabel <f.schnabel@tum.de>
Parent: af0779c

15 files changed: +139 −76 lines

monai/apps/auto3dseg/ensemble_builder.py

Lines changed: 3 additions & 2 deletions
@@ -15,10 +15,11 @@
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
 from copy import deepcopy
-from typing import Any
+from typing import Any, cast
 from warnings import warn
 
 import numpy as np
+import torch
 
 from monai.apps.auto3dseg.bundle_gen import BundleAlgo
 from monai.apps.utils import get_logger
@@ -108,7 +109,7 @@ def ensemble_pred(self, preds, sigmoid=False):
 
         if self.mode == "mean":
             prob = MeanEnsemble()(preds)
-            return prob2class(prob, dim=0, keepdim=True, sigmoid=sigmoid)
+            return prob2class(cast(torch.Tensor, prob), dim=0, keepdim=True, sigmoid=sigmoid)
         elif self.mode == "vote":
             classes = [prob2class(p, dim=0, keepdim=True, sigmoid=sigmoid) for p in preds]
             if sigmoid:
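Context for the change above: `MeanEnsemble()(preds)` is not statically known to return a `torch.Tensor`, while the newly annotated `prob2class` now requires one, so `typing.cast` narrows the type for mypy without changing runtime behaviour. A minimal sketch of the pattern, with a hypothetical `ensemble` helper standing in for `MeanEnsemble`:

```python
from __future__ import annotations

from typing import Any, cast

import torch


def ensemble(preds: list[torch.Tensor]) -> Any:
    """Hypothetical stand-in for MeanEnsemble()(preds); its return type is opaque to mypy."""
    return torch.stack(preds).mean(dim=0)


def predict_class(preds: list[torch.Tensor]) -> torch.Tensor:
    prob = ensemble(preds)
    # cast() is a no-op at runtime; it only tells the checker to treat `prob` as a Tensor.
    return cast(torch.Tensor, prob).argmax(dim=0, keepdim=True)
```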

monai/data/dataset.py

Lines changed: 2 additions & 2 deletions
@@ -25,7 +25,7 @@
 from multiprocessing.managers import ListProxy
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
-from typing import IO, TYPE_CHECKING, Any
+from typing import IO, TYPE_CHECKING, Any, cast
 
 import numpy as np
 import torch
@@ -1403,7 +1403,7 @@ def __init__(
         dat = np.load(npzfile)
 
         self.arrays = {storedk: dat[datak] for datak, storedk in self.keys.items()}
-        self.length = self.arrays[first(self.keys.values())].shape[0]
+        self.length = self.arrays[cast(str, first(self.keys.values()))].shape[0]
 
         self.other_keys = {} if other_keys is None else {k: dat[k] for k in other_keys}
 

monai/utils/decorators.py

Lines changed: 5 additions & 3 deletions
@@ -15,17 +15,19 @@
 
 __all__ = ["RestartGenerator", "MethodReplacer"]
 
+from typing import Callable, Generator
+
 
 class RestartGenerator:
     """
     Wraps a generator callable which will be called whenever this class is iterated and its result returned. This is
     used to create an iterator which can start iteration over the given generator multiple times.
     """
 
-    def __init__(self, create_gen) -> None:
+    def __init__(self, create_gen: Callable[[], Generator]) -> None:
         self.create_gen = create_gen
 
-    def __iter__(self):
+    def __iter__(self) -> Generator:
         return self.create_gen()
 
 
@@ -36,7 +38,7 @@ class MethodReplacer:
 
     replace_list_name = "__replacemethods__"
 
-    def __init__(self, meth) -> None:
+    def __init__(self, meth: Callable) -> None:
        self.meth = meth
 
     def replace_method(self, meth):
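`Callable[[], Generator]` documents that `RestartGenerator` expects a zero-argument callable producing a fresh generator each time, typically a generator function wrapped in a `lambda`. A small usage sketch (the `countdown` generator below is made up for illustration; `RestartGenerator` is re-exported from `monai.utils`):

```python
from __future__ import annotations

from collections.abc import Generator

from monai.utils import RestartGenerator


def countdown(n: int) -> Generator[int, None, None]:
    while n > 0:
        yield n
        n -= 1


# The wrapped callable is invoked anew each time iteration starts,
# so the same sequence can be replayed multiple times.
restartable = RestartGenerator(lambda: countdown(3))
print(list(restartable))  # [3, 2, 1]
print(list(restartable))  # [3, 2, 1] again
```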

monai/utils/deprecate_utils.py

Lines changed: 4 additions & 4 deletions
@@ -43,7 +43,7 @@ def deprecated(
     removed: str | None = None,
     msg_suffix: str = "",
     version_val: str = __version__,
-    warning_category=FutureWarning,
+    warning_category: type[FutureWarning] = FutureWarning,
 ) -> Callable[[T], T]:
     """
     Marks a function or class as deprecated. If `since` is given this should be a version at or earlier than the
@@ -121,13 +121,13 @@ def _wrapper(*args, **kwargs):
 
 
 def deprecated_arg(
-    name,
+    name: str,
     since: str | None = None,
     removed: str | None = None,
     msg_suffix: str = "",
     version_val: str = __version__,
     new_name: str | None = None,
-    warning_category=FutureWarning,
+    warning_category: type[FutureWarning] = FutureWarning,
 ) -> Callable[[T], T]:
     """
     Marks a particular named argument of a callable as deprecated. The same conditions for `since` and `removed` as
@@ -235,7 +235,7 @@ def deprecated_arg_default(
     replaced: str | None = None,
     msg_suffix: str = "",
     version_val: str = __version__,
-    warning_category=FutureWarning,
+    warning_category: type[FutureWarning] = FutureWarning,
 ) -> Callable[[T], T]:
     """
     Marks a particular arguments default of a callable as deprecated. It is changed from `old_default` to `new_default`
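`type[FutureWarning]` restricts `warning_category` to the `FutureWarning` class itself or a subclass (a class object, not an instance); with the default previously untyped, mypy had no way to flag a mismatched category. A hedged sketch of what the annotation now permits and rejects (the decorator parameters are taken from the signatures above, but the concrete values are illustrative):

```python
from monai.utils import deprecated


class MyFutureWarning(FutureWarning):
    """An illustrative project-specific warning category."""


@deprecated(since="1.0", msg_suffix="use `new_func` instead.", warning_category=MyFutureWarning)
def old_func(x):
    return x


# @deprecated(since="1.0", warning_category=DeprecationWarning)  # now a mypy error:
# DeprecationWarning is not a subclass of FutureWarning, so it does not match type[FutureWarning].
```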

monai/utils/dist.py

Lines changed: 25 additions & 3 deletions
@@ -11,6 +11,13 @@
 
 from __future__ import annotations
 
+import sys
+
+if sys.version_info >= (3, 8):
+    from typing import Literal
+
+from typing import overload
+
 import torch
 import torch.distributed as dist
 
@@ -39,7 +46,22 @@ def get_dist_device():
     return None
 
 
-def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True):
+@overload
+def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[True]) -> torch.Tensor:
+    ...
+
+
+@overload
+def evenly_divisible_all_gather(data: torch.Tensor, concat: Literal[False]) -> list[torch.Tensor]:
+    ...
+
+
+@overload
+def evenly_divisible_all_gather(data: torch.Tensor, concat: bool) -> torch.Tensor | list[torch.Tensor]:
+    ...
+
+
+def evenly_divisible_all_gather(data: torch.Tensor, concat: bool = True) -> torch.Tensor | list[torch.Tensor]:
     """
     Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather.
     The input data of every rank should have the same number of dimensions, only the first dim can be different.
@@ -149,6 +171,6 @@ def string_list_all_gather(strings: list[str], delimiter: str = "\t") -> list[str]:
 
     joined = delimiter.join(strings)
     gathered = evenly_divisible_all_gather(torch.tensor(bytearray(joined, "utf-8"), dtype=torch.long), concat=False)
-    gathered = [bytearray(g.tolist()).decode("utf-8").split(delimiter) for g in gathered]
+    _gathered = [bytearray(g.tolist()).decode("utf-8").split(delimiter) for g in gathered]
 
-    return [i for k in gathered for i in k]
+    return [i for k in _gathered for i in k]
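The `@overload` stubs let the checker choose the return type from the literal value of `concat`: `concat=True` yields a single `torch.Tensor`, `concat=False` yields `list[torch.Tensor]`, and a plain `bool` falls back to the union. That is what allows the loop over `gathered` in `string_list_all_gather` above to type-check without a cast; the guarded import exists because `typing.Literal` requires Python 3.8+. A minimal, self-contained sketch of the same pattern (a toy `gather` function, not the MONAI API):

```python
from __future__ import annotations

from typing import Literal, overload


@overload
def gather(xs: list[int], concat: Literal[True]) -> int:
    ...


@overload
def gather(xs: list[int], concat: Literal[False]) -> list[int]:
    ...


def gather(xs: list[int], concat: bool = True) -> int | list[int]:
    # Toy implementation: "concatenate" by summing, otherwise return the parts unchanged.
    return sum(xs) if concat else xs


total: int = gather([1, 2, 3], concat=True)         # checker picks the first overload
parts: list[int] = gather([1, 2, 3], concat=False)  # checker picks the second overload
```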

monai/utils/enums.py

Lines changed: 4 additions & 4 deletions
@@ -369,19 +369,19 @@ class PostFix(StrEnum):
     """Post-fixes."""
 
     @staticmethod
-    def _get_str(prefix, suffix):
+    def _get_str(prefix: str | None, suffix: str) -> str:
         return suffix if prefix is None else f"{prefix}_{suffix}"
 
     @staticmethod
-    def meta(key: str | None = None):
+    def meta(key: str | None = None) -> str:
         return PostFix._get_str(key, "meta_dict")
 
     @staticmethod
-    def orig_meta(key: str | None = None):
+    def orig_meta(key: str | None = None) -> str:
         return PostFix._get_str(key, "orig_meta_dict")
 
     @staticmethod
-    def transforms(key: str | None = None):
+    def transforms(key: str | None = None) -> str:
         return PostFix._get_str(key, TraceKeys.KEY_SUFFIX[1:])
 
 
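These helpers assemble metadata dictionary key names, so annotating the return type as `str` lets callers use the results directly as keys. Usage, following `_get_str` above:

```python
from monai.utils.enums import PostFix

print(PostFix.meta())              # "meta_dict"
print(PostFix.meta("image"))       # "image_meta_dict"
print(PostFix.orig_meta("label"))  # "label_orig_meta_dict"
```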

monai/utils/jupyter_utils.py

Lines changed: 20 additions & 18 deletions
@@ -16,7 +16,7 @@
 from __future__ import annotations
 
 import copy
-from collections.abc import Callable
+from collections.abc import Callable, Mapping
 from enum import Enum
 from threading import RLock, Thread
 from typing import TYPE_CHECKING, Any
@@ -44,13 +44,13 @@
 
 
 def plot_metric_graph(
-    ax,
+    ax: plt.Axes,
     title: str,
-    graphmap: dict[str, list[float] | tuple[list[float], list[float]]],
+    graphmap: Mapping[str, list[float] | tuple[list[float], list[float]]],
     yscale: str = "log",
     avg_keys: tuple[str] = (LOSS_NAME,),
     window_fraction: int = 20,
-):
+) -> None:
     """
     Plot metrics on a single graph with running averages plotted for selected keys. The values in `graphmap`
     should be lists of (timepoint, value) pairs as stored in MetricLogger objects.
@@ -91,9 +91,9 @@ def plot_metric_graph(
 
 
 def plot_metric_images(
-    fig,
+    fig: plt.Figure,
     title: str,
-    graphmap: dict[str, list[float] | tuple[list[float], list[float]]],
+    graphmap: Mapping[str, list[float] | tuple[list[float], list[float]]],
     imagemap: dict[str, np.ndarray],
     yscale: str = "log",
     avg_keys: tuple[str] = (LOSS_NAME,),
@@ -138,7 +138,7 @@ def plot_metric_images(
     return axes
 
 
-def tensor_to_images(name: str, tensor: torch.Tensor):
+def tensor_to_images(name: str, tensor: torch.Tensor) -> np.ndarray | None:
     """
     Return an tuple of images derived from the given tensor. The `name` value indices which key from the
     output or batch value the tensor was stored as, or is "Batch" or "Output" if these were single tensors
@@ -147,25 +147,25 @@ def tensor_to_images(name: str, tensor: torch.Tensor):
     each channel separately.
     """
     if tensor.ndim == 3 and tensor.shape[1] > 2 and tensor.shape[2] > 2:
-        return tensor.cpu().data.numpy()
+        return tensor.cpu().data.numpy()  # type: ignore[no-any-return]
     if tensor.ndim == 4 and tensor.shape[2] > 2 and tensor.shape[3] > 2:
         dmid = tensor.shape[1] // 2
-        return tensor[:, dmid].cpu().data.numpy()
+        return tensor[:, dmid].cpu().data.numpy()  # type: ignore[no-any-return]
 
     return None
 
 
 def plot_engine_status(
     engine: Engine,
-    logger,
+    logger: Any,
     title: str = "Training Log",
     yscale: str = "log",
     avg_keys: tuple[str] = (LOSS_NAME,),
     window_fraction: int = 20,
-    image_fn: Callable | None = tensor_to_images,
-    fig=None,
+    image_fn: Callable[[str, torch.Tensor], Any] | None = tensor_to_images,
+    fig: plt.Figure = None,
     selected_inst: int = 0,
-) -> tuple:
+) -> tuple[plt.Figure, list]:
     """
     Plot the status of the given Engine with its logger. The plot will consist of a graph of loss values and metrics
     taken from the logger, and images taken from the `output` and `batch` members of `engine.state`. The images are
@@ -191,7 +191,7 @@ def plot_engine_status(
     else:
         fig = plt.Figure(figsize=(20, 10), tight_layout=True, facecolor="white")
 
-    graphmap = {LOSS_NAME: logger.loss}
+    graphmap: dict[str, list[float]] = {LOSS_NAME: logger.loss}
     graphmap.update(logger.metrics)
 
     imagemap: dict = {}
@@ -233,10 +233,12 @@
     return fig, axes
 
 
-def _get_loss_from_output(output: dict[str, torch.Tensor] | torch.Tensor):
+def _get_loss_from_output(
+    output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor
+) -> torch.Tensor:
     """Returns a single value from the network output, which is a dict or tensor."""
 
-    def _get_loss(data):
+    def _get_loss(data: torch.Tensor | dict[str, torch.Tensor]) -> torch.Tensor:
         if isinstance(data, dict):
             return data["loss"]
         return data
@@ -286,7 +288,7 @@ def __init__(
         self._status_dict: dict[str, Any] = {}
         self.loss_transform = loss_transform
         self.metric_transform = metric_transform
-        self.fig = None
+        self.fig: plt.Figure | None = None
         self.status_format = status_format
 
         self.engine.add_event_handler(Events.ITERATION_COMPLETED, self._update_status)
@@ -357,7 +359,7 @@ def status(self) -> str:
 
         return ", ".join(msgs)
 
-    def plot_status(self, logger, plot_func: Callable = plot_engine_status):
+    def plot_status(self, logger: Any, plot_func: Callable = plot_engine_status) -> plt.Figure:
         """
         Generate a plot of the current status of the contained engine whose loss and metrics were tracked by `logger`.
         The function `plot_func` must accept arguments `title`, `engine`, `logger`, and `fig` which are the plot title,
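The switch from `dict[...]` to `Mapping[...]` for `graphmap` is what makes the `graphmap: dict[str, list[float]]` built in `plot_engine_status` acceptable to mypy: `dict` is invariant in its value type, while the read-only `Mapping` is covariant. A standalone sketch of that checker behaviour (the function names here are illustrative, not MONAI APIs):

```python
from __future__ import annotations

from collections.abc import Mapping


def plot_from_dict(graphmap: dict[str, list[float] | tuple[list[float], list[float]]]) -> None:
    ...


def plot_from_mapping(graphmap: Mapping[str, list[float] | tuple[list[float], list[float]]]) -> None:
    ...


values: dict[str, list[float]] = {"loss": [0.9, 0.5, 0.3]}

plot_from_mapping(values)  # accepted: Mapping is covariant in its value type
plot_from_dict(values)     # mypy error (runs fine at runtime): dict is invariant in its value type
```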

monai/utils/misc.py

Lines changed: 26 additions & 6 deletions
@@ -23,7 +23,7 @@
 from collections.abc import Callable, Iterable, Sequence
 from distutils.util import strtobool
 from pathlib import Path
-from typing import Any, cast
+from typing import Any, TypeVar, cast, overload
 
 import numpy as np
 import torch
@@ -82,7 +82,20 @@ def star_zip_with(op, *vals):
     return zip_with(op, *vals, mapfunc=itertools.starmap)
 
 
-def first(iterable, default=None):
+T = TypeVar("T")
+
+
+@overload
+def first(iterable: Iterable[T], default: T) -> T:
+    ...
+
+
+@overload
+def first(iterable: Iterable[T]) -> T | None:
+    ...
+
+
+def first(iterable: Iterable[T], default: T | None = None) -> T | None:
     """
     Returns the first item in the given iterable or `default` if empty, meaningful mostly with 'for' expressions.
     """
@@ -463,7 +476,7 @@ class ImageMetaKey:
     SPATIAL_SHAPE = "spatial_shape"
 
 
-def has_option(obj, keywords: str | Sequence[str]) -> bool:
+def has_option(obj: Callable, keywords: str | Sequence[str]) -> bool:
     """
     Return a boolean indicating whether the given callable `obj` has the `keywords` in its signature.
     """
@@ -504,7 +517,7 @@ def sample_slices(data: NdarrayOrTensor, dim: int = 1, as_indices: bool = True,
     return data[tuple(slices)]
 
 
-def check_parent_dir(path: PathLike, create_dir: bool = True):
+def check_parent_dir(path: PathLike, create_dir: bool = True) -> None:
     """
     Utility to check whether the parent directory of the `path` exists.
 
@@ -523,7 +536,14 @@ def check_parent_dir(path: PathLike, create_dir: bool = True):
             raise ValueError(f"the directory of specified path does not exist: `{path_dir}`.")
 
 
-def save_obj(obj, path: PathLike, create_dir: bool = True, atomic: bool = True, func: Callable | None = None, **kwargs):
+def save_obj(
+    obj: object,
+    path: PathLike,
+    create_dir: bool = True,
+    atomic: bool = True,
+    func: Callable | None = None,
+    **kwargs: Any,
+) -> None:
     """
     Save an object to file with specified path.
     Support to serialize to a temporary file first, then move to final destination,
@@ -576,7 +596,7 @@ def label_union(x: list) -> list:
     return list(set.union(set(np.array(x).tolist())))
 
 
-def prob2class(x, sigmoid: bool = False, threshold: float = 0.5, **kwargs):
+def prob2class(x: torch.Tensor, sigmoid: bool = False, threshold: float = 0.5, **kwargs: Any) -> torch.Tensor:
     """
     Compute the lab from the probability of predicted feature maps
 
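With `T = TypeVar("T")` and the two overloads, the checker infers the element type from the iterable and only widens the result to `T | None` when no `default` is given. This is also why `dataset.py` above still wraps `first(self.keys.values())` in `cast(str, ...)`: called without a default, `first` is typed as returning `str | None`. Roughly what a checker now infers (comments show the expected types):

```python
from __future__ import annotations

from monai.utils import first

values = [1.5, 2.5, 3.5]
names: list[str] = []

a = first(values)            # inferred as float | None  (no default supplied)
b = first(values, 0.0)       # inferred as float         (default shares the element type)
c = first(names, "missing")  # inferred as str
```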
