Commit b84e8c6

Move module_tracker to logging for confused hierarchy (#134467) (#134501)
* Move module_tracker to logging for confused hierarchy (#134467)

  Fixes #134242

  Make sure the tracker never raises an error when it gets confused. Logs for confusion can be enabled with `TORCH_LOGS="torch.utils.module_tracker"` or through the usual Python logging mechanisms.

  Pull Request resolved: #134467
  Approved by: https://github.com/malfet

* Fix bad merge conflict resolution
1 parent 6a79d4a commit b84e8c6
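
Because the tracker now routes its confusion messages through a standard Python logger instead of printing or raising, they are silent by default. A minimal sketch of the two ways the commit message mentions to surface them (the `TORCH_LOGS` value is quoted from the commit; the rest is stock Python `logging`, not code from this PR):

    import logging

    # Route log records somewhere visible, then lower the level of the
    # tracker's logger (its name matches the module path) to INFO.
    logging.basicConfig()
    logging.getLogger("torch.utils.module_tracker").setLevel(logging.INFO)

The alternative is the environment-variable route: run the script with `TORCH_LOGS="torch.utils.module_tracker"` set in the shell.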

File tree

* test/test_module_tracker.py
* torch/autograd/graph.py
* torch/utils/module_tracker.py

3 files changed: +52 −15 lines


test/test_module_tracker.py

Lines changed: 36 additions & 6 deletions
@@ -3,7 +3,9 @@
 from copy import copy
 
 import torch
+from torch import nn
 from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo
+from torch.utils.checkpoint import checkpoint
 from torch.utils.module_tracker import ModuleTracker
 
 
@@ -14,7 +16,7 @@ def test_module_hierarchy(self):
         seen_fw = []
         seen_bw = []
 
-        class Foo(torch.nn.Module):
+        class Foo(nn.Module):
             def forward(self, x):
                 x = x["a"].relu_()
                 seen_fw.append((copy(tracker.parents), tracker.is_bw))
@@ -23,12 +25,12 @@ def forward(self, x):
                 )
                 return {"a": torch.mm(x, x)}
 
-        class Mod(torch.nn.Module):
-            def __init__(self):
+        class Mod(nn.Module):
+            def __init__(self) -> None:
                 super().__init__()
                 self.a = Foo()
-                self.b = torch.nn.ModuleDict({"nest": Foo()})
-                self.c = torch.nn.ModuleList([Foo()])
+                self.b = nn.ModuleDict({"nest": Foo()})
+                self.c = nn.ModuleList([Foo()])
 
             def forward(self, x):
                 x = self.c[0](x)
@@ -68,8 +70,36 @@ def forward(self, x):
             ],
         )
 
+    def test_confused_hierarchy(self):
+        class MyMod(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.inner = nn.Linear(2, 2)
+                self.ran = False
+
+            def forward(self, inp):
+                if not self.ran:
+                    self.ran = True
+                    return self(inp)
+                else:
+                    self.ran = False
+                    return self.inner(inp)
+
+        mod = MyMod()
+        inp = torch.rand(1, 2, requires_grad=True)
+
+        # Should not fail
+        with ModuleTracker() as tracker:
+            res = mod(inp)
+            res.sum().backward()
+
+        # Should not fail
+        with ModuleTracker() as tracker:
+            res = checkpoint(lambda inp: mod(inp), inp)
+            res.sum().backward()
+
     def test_bw_detection(self):
-        mod = torch.nn.Linear(2, 2)
+        mod = nn.Linear(2, 2)
 
         with ModuleTracker() as tracker:
             mod(torch.rand(2, requires_grad=True)).sum().backward()
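
For orientation, the tracker state these tests read is small: `parents` is the set of fully-qualified names of the modules currently executing, and `is_bw` says whether we are in the backward pass. A minimal usage sketch along the lines of the test above (the `Probe` module is illustrative, not part of the test suite):

    import torch
    from torch import nn
    from torch.utils.module_tracker import ModuleTracker

    class Probe(nn.Module):
        def forward(self, x):
            # Report which modules are on the stack right now, and whether
            # this code is running as part of the backward pass.
            print(sorted(tracker.parents), "is_bw:", tracker.is_bw)
            return x.relu()

    net = nn.Sequential(nn.Linear(2, 2), Probe())
    with ModuleTracker() as tracker:
        net(torch.rand(1, 2, requires_grad=True)).sum().backward()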

torch/autograd/graph.py

Lines changed: 2 additions & 1 deletion
@@ -158,7 +158,8 @@ def __subclasshook__(cls, C):
 
 def _get_grad_fn_or_grad_acc(t):
     if t.requires_grad and t.grad_fn is None:
-        return t.view_as(t).grad_fn.next_functions[0][0]
+        with torch.enable_grad():
+            return t.view_as(t).grad_fn.next_functions[0][0]
     else:
         return t.grad_fn
 
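The new `enable_grad` guard matters because this helper can be reached while grad mode is disabled (for instance during checkpoint's recomputation); in that case `t.view_as(t)` records no graph, its `grad_fn` is `None`, and the chained `.next_functions` access fails with an `AttributeError`. A quick standalone illustration of the two modes (not from the PR, just the underlying autograd behavior):

    import torch

    t = torch.rand(2, requires_grad=True)

    with torch.no_grad():
        print(t.view_as(t).grad_fn)  # None -- nothing is recorded under no_grad

    with torch.enable_grad():
        view = t.view_as(t)
        # grad_fn now exists; next_functions[0][0] is the AccumulateGrad
        # node of the leaf tensor t, which is what the fixed helper returns.
        print(view.grad_fn.next_functions[0][0])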

torch/utils/module_tracker.py

Lines changed: 14 additions & 8 deletions
@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import logging
 import weakref
 
 from typing import Set
@@ -11,6 +12,10 @@
 )
 from torch.utils._pytree import tree_flatten
 
+
+logger = logging.getLogger(__name__)
+
+
 __all__ = ["ModuleTracker"]
 
 
@@ -93,9 +98,10 @@ def fn(*args):
             if is_bw:
                 self._maybe_set_engine_callback()
             if name in self.parents:
-                print(
-                    "The module hierarchy tracking seems to be messed up."
-                    "Please file a bug to PyTorch."
+                logger.info(
+                    "The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s",
+                    name,
+                    "backward" if is_bw else "forward",
                 )
             self.parents.add(name)
 
@@ -105,11 +111,11 @@ def _get_pop_fn(self, name, is_bw):
         def fn(*args):
             if name in self.parents:
                 self.parents.remove(name)
-            elif not is_bw:
-                # Due to some input/output not requiring gradients, we cannot enforce
-                # proper nesting in backward
-                raise RuntimeError(
-                    "The Module hierarchy tracking is wrong. Report a bug to PyTorch"
+            else:
+                logger.info(
+                    "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s",
+                    name,
+                    "backward" if is_bw else "forward",
                )
 
        return fn
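
Since the messages are now ordinary INFO records on the `torch.utils.module_tracker` logger, they can be asserted on with stdlib tooling. An untested sketch using `unittest.TestCase.assertLogs` and the re-entrant module pattern from the new test (the `Reentrant` class name is illustrative):

    import unittest

    import torch
    from torch import nn
    from torch.utils.module_tracker import ModuleTracker

    class ConfusionIsLoggedNotRaised(unittest.TestCase):
        def test_reentrant_forward_logs_info(self):
            class Reentrant(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.inner = nn.Linear(2, 2)
                    self.ran = False

                def forward(self, inp):
                    # Calling self(inp) re-enters this module, pushing a name
                    # that is already in tracker.parents -- which now logs
                    # instead of raising.
                    if not self.ran:
                        self.ran = True
                        return self(inp)
                    self.ran = False
                    return self.inner(inp)

            mod = Reentrant()
            with self.assertLogs("torch.utils.module_tracker", level="INFO"):
                with ModuleTracker():
                    mod(torch.rand(1, 2, requires_grad=True))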
