Skip to content

Commit e69c5cd

Browse files
authored
stubgen: Preserve simple defaults in function signatures (#15355)
Fixes #13238 See also python/typeshed#8988
1 parent 1200d1d commit e69c5cd

File tree

3 files changed

+153
-24
lines changed

3 files changed

+153
-24
lines changed

mypy/stubdoc.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,19 @@ def is_valid_type(s: str) -> bool:
3636
class ArgSig:
3737
"""Signature info for a single argument."""
3838

39-
def __init__(self, name: str, type: str | None = None, default: bool = False):
39+
def __init__(
40+
self,
41+
name: str,
42+
type: str | None = None,
43+
*,
44+
default: bool = False,
45+
default_value: str = "...",
46+
) -> None:
4047
self.name = name
4148
self.type = type
4249
# Does this argument have a default value?
4350
self.default = default
51+
self.default_value = default_value
4452

4553
def is_star_arg(self) -> bool:
4654
return self.name.startswith("*") and not self.name.startswith("**")
@@ -59,6 +67,7 @@ def __eq__(self, other: Any) -> bool:
5967
self.name == other.name
6068
and self.type == other.type
6169
and self.default == other.default
70+
and self.default_value == other.default_value
6271
)
6372
return False
6473

@@ -119,10 +128,10 @@ def format_sig(
119128
if arg_type:
120129
arg_def += ": " + arg_type
121130
if arg.default:
122-
arg_def += " = ..."
131+
arg_def += f" = {arg.default_value}"
123132

124133
elif arg.default:
125-
arg_def += "=..."
134+
arg_def += f"={arg.default_value}"
126135

127136
args.append(arg_def)
128137

mypy/stubgen.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
NameExpr,
100100
OpExpr,
101101
OverloadedFuncDef,
102+
SetExpr,
102103
Statement,
103104
StrExpr,
104105
TempNode,
@@ -491,15 +492,21 @@ def _get_func_args(self, o: FuncDef, ctx: FunctionContext) -> list[ArgSig]:
491492
if kind.is_named() and not any(arg.name.startswith("*") for arg in args):
492493
args.append(ArgSig("*"))
493494

495+
default = "..."
494496
if arg_.initializer:
495497
if not typename:
496498
typename = self.get_str_type_of_node(arg_.initializer, True, False)
499+
potential_default, valid = self.get_str_default_of_node(arg_.initializer)
500+
if valid and len(potential_default) <= 200:
501+
default = potential_default
497502
elif kind == ARG_STAR:
498503
name = f"*{name}"
499504
elif kind == ARG_STAR2:
500505
name = f"**{name}"
501506

502-
args.append(ArgSig(name, typename, default=bool(arg_.initializer)))
507+
args.append(
508+
ArgSig(name, typename, default=bool(arg_.initializer), default_value=default)
509+
)
503510

504511
if ctx.class_info is not None and all(
505512
arg.type is None and arg.default is False for arg in args
@@ -1234,6 +1241,70 @@ def maybe_unwrap_unary_expr(self, expr: Expression) -> Expression:
12341241
# This is some other unary expr, we cannot do anything with it (yet?).
12351242
return expr
12361243

1244+
def get_str_default_of_node(self, rvalue: Expression) -> tuple[str, bool]:
1245+
"""Get a string representation of the default value of a node.
1246+
1247+
Returns a 2-tuple of the default and whether or not it is valid.
1248+
"""
1249+
if isinstance(rvalue, NameExpr):
1250+
if rvalue.name in ("None", "True", "False"):
1251+
return rvalue.name, True
1252+
elif isinstance(rvalue, (IntExpr, FloatExpr)):
1253+
return f"{rvalue.value}", True
1254+
elif isinstance(rvalue, UnaryExpr):
1255+
if isinstance(rvalue.expr, (IntExpr, FloatExpr)):
1256+
return f"{rvalue.op}{rvalue.expr.value}", True
1257+
elif isinstance(rvalue, StrExpr):
1258+
return repr(rvalue.value), True
1259+
elif isinstance(rvalue, BytesExpr):
1260+
return "b" + repr(rvalue.value).replace("\\\\", "\\"), True
1261+
elif isinstance(rvalue, TupleExpr):
1262+
items_defaults = []
1263+
for e in rvalue.items:
1264+
e_default, valid = self.get_str_default_of_node(e)
1265+
if not valid:
1266+
break
1267+
items_defaults.append(e_default)
1268+
else:
1269+
closing = ",)" if len(items_defaults) == 1 else ")"
1270+
default = "(" + ", ".join(items_defaults) + closing
1271+
return default, True
1272+
elif isinstance(rvalue, ListExpr):
1273+
items_defaults = []
1274+
for e in rvalue.items:
1275+
e_default, valid = self.get_str_default_of_node(e)
1276+
if not valid:
1277+
break
1278+
items_defaults.append(e_default)
1279+
else:
1280+
default = "[" + ", ".join(items_defaults) + "]"
1281+
return default, True
1282+
elif isinstance(rvalue, SetExpr):
1283+
items_defaults = []
1284+
for e in rvalue.items:
1285+
e_default, valid = self.get_str_default_of_node(e)
1286+
if not valid:
1287+
break
1288+
items_defaults.append(e_default)
1289+
else:
1290+
if items_defaults:
1291+
default = "{" + ", ".join(items_defaults) + "}"
1292+
return default, True
1293+
elif isinstance(rvalue, DictExpr):
1294+
items_defaults = []
1295+
for k, v in rvalue.items:
1296+
if k is None:
1297+
break
1298+
k_default, k_valid = self.get_str_default_of_node(k)
1299+
v_default, v_valid = self.get_str_default_of_node(v)
1300+
if not (k_valid and v_valid):
1301+
break
1302+
items_defaults.append(f"{k_default}: {v_default}")
1303+
else:
1304+
default = "{" + ", ".join(items_defaults) + "}"
1305+
return default, True
1306+
return "...", False
1307+
12371308
def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> bool:
12381309
is_private = self.is_private_name(name, full_module + "." + name)
12391310
if (

test-data/unit/stubgen.test

Lines changed: 69 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,45 +27,47 @@ def g(arg) -> None: ...
2727
def f(a, b=2): ...
2828
def g(b=-1, c=0): ...
2929
[out]
30-
def f(a, b: int = ...) -> None: ...
31-
def g(b: int = ..., c: int = ...) -> None: ...
30+
def f(a, b: int = 2) -> None: ...
31+
def g(b: int = -1, c: int = 0) -> None: ...
3232

3333
[case testDefaultArgNone]
3434
def f(x=None): ...
3535
[out]
3636
from _typeshed import Incomplete
3737

38-
def f(x: Incomplete | None = ...) -> None: ...
38+
def f(x: Incomplete | None = None) -> None: ...
3939

4040
[case testDefaultArgBool]
4141
def f(x=True, y=False): ...
4242
[out]
43-
def f(x: bool = ..., y: bool = ...) -> None: ...
43+
def f(x: bool = True, y: bool = False) -> None: ...
4444

4545
[case testDefaultArgBool_inspect]
4646
def f(x=True, y=False): ...
4747
[out]
4848
def f(x: bool = ..., y: bool = ...): ...
4949

5050
[case testDefaultArgStr]
51-
def f(x='foo'): ...
51+
def f(x='foo',y="how's quotes"): ...
5252
[out]
53-
def f(x: str = ...) -> None: ...
53+
def f(x: str = 'foo', y: str = "how's quotes") -> None: ...
5454

5555
[case testDefaultArgStr_inspect]
5656
def f(x='foo'): ...
5757
[out]
5858
def f(x: str = ...): ...
5959

6060
[case testDefaultArgBytes]
61-
def f(x=b'foo'): ...
61+
def f(x=b'foo',y=b"what's up",z=b'\xc3\xa0 la une'): ...
6262
[out]
63-
def f(x: bytes = ...) -> None: ...
63+
def f(x: bytes = b'foo', y: bytes = b"what's up", z: bytes = b'\xc3\xa0 la une') -> None: ...
6464

6565
[case testDefaultArgFloat]
66-
def f(x=1.2): ...
66+
def f(x=1.2,y=1e-6,z=0.0,w=-0.0,v=+1.0): ...
67+
def g(x=float("nan"), y=float("inf"), z=float("-inf")): ...
6768
[out]
68-
def f(x: float = ...) -> None: ...
69+
def f(x: float = 1.2, y: float = 1e-06, z: float = 0.0, w: float = -0.0, v: float = +1.0) -> None: ...
70+
def g(x=..., y=..., z=...) -> None: ...
6971

7072
[case testDefaultArgOther]
7173
def f(x=ord): ...
@@ -126,10 +128,10 @@ def i(a, *, b=1): ...
126128
def j(a, *, b=1, **c): ...
127129
[out]
128130
def f(a, *b, **c) -> None: ...
129-
def g(a, *b, c: int = ...) -> None: ...
130-
def h(a, *b, c: int = ..., **d) -> None: ...
131-
def i(a, *, b: int = ...) -> None: ...
132-
def j(a, *, b: int = ..., **c) -> None: ...
131+
def g(a, *b, c: int = 1) -> None: ...
132+
def h(a, *b, c: int = 1, **d) -> None: ...
133+
def i(a, *, b: int = 1) -> None: ...
134+
def j(a, *, b: int = 1, **c) -> None: ...
133135

134136
[case testClass]
135137
class A:
@@ -356,8 +358,8 @@ y: Incomplete
356358
def f(x, *, y=1): ...
357359
def g(x, *, y=1, z=2): ...
358360
[out]
359-
def f(x, *, y: int = ...) -> None: ...
360-
def g(x, *, y: int = ..., z: int = ...) -> None: ...
361+
def f(x, *, y: int = 1) -> None: ...
362+
def g(x, *, y: int = 1, z: int = 2) -> None: ...
361363

362364
[case testProperty]
363365
class A:
@@ -1285,8 +1287,8 @@ from _typeshed import Incomplete
12851287

12861288
class A:
12871289
x: Incomplete
1288-
def __init__(self, a: Incomplete | None = ...) -> None: ...
1289-
def method(self, a: Incomplete | None = ...) -> None: ...
1290+
def __init__(self, a: Incomplete | None = None) -> None: ...
1291+
def method(self, a: Incomplete | None = None) -> None: ...
12901292

12911293
[case testAnnotationImportsFrom]
12921294
import foo
@@ -2514,7 +2516,7 @@ from _typeshed import Incomplete as _Incomplete
25142516

25152517
Y: _Incomplete
25162518

2517-
def g(x: _Incomplete | None = ...) -> None: ...
2519+
def g(x: _Incomplete | None = None) -> None: ...
25182520

25192521
x: _Incomplete
25202522

@@ -3503,7 +3505,7 @@ class P(Protocol):
35033505
[case testNonDefaultKeywordOnlyArgAfterAsterisk]
35043506
def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...
35053507
[out]
3506-
def func(*, non_default_kwarg: bool, default_kwarg: bool = ...): ...
3508+
def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...
35073509

35083510
[case testNestedGenerator]
35093511
def f1():
@@ -3909,6 +3911,53 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...
39093911
class X(_Incomplete): ...
39103912
class Y(_Incomplete): ...
39113913

3914+
[case testIgnoreLongDefaults]
3915+
def f(x='abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
3916+
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
3917+
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
3918+
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'): ...
3919+
3920+
def g(x=b'abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
3921+
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
3922+
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz\
3923+
abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz'): ...
3924+
3925+
def h(x=123456789012345678901234567890123456789012345678901234567890\
3926+
123456789012345678901234567890123456789012345678901234567890\
3927+
123456789012345678901234567890123456789012345678901234567890\
3928+
123456789012345678901234567890123456789012345678901234567890): ...
3929+
3930+
[out]
3931+
def f(x: str = ...) -> None: ...
3932+
def g(x: bytes = ...) -> None: ...
3933+
def h(x: int = ...) -> None: ...
3934+
3935+
[case testDefaultsOfBuiltinContainers]
3936+
def f(x=(), y=(1,), z=(1, 2)): ...
3937+
def g(x=[], y=[1, 2]): ...
3938+
def h(x={}, y={1: 2, 3: 4}): ...
3939+
def i(x={1, 2, 3}): ...
3940+
def j(x=[(1,"a"), (2,"b")]): ...
3941+
3942+
[out]
3943+
def f(x=(), y=(1,), z=(1, 2)) -> None: ...
3944+
def g(x=[], y=[1, 2]) -> None: ...
3945+
def h(x={}, y={1: 2, 3: 4}) -> None: ...
3946+
def i(x={1, 2, 3}) -> None: ...
3947+
def j(x=[(1, 'a'), (2, 'b')]) -> None: ...
3948+
3949+
[case testDefaultsOfBuiltinContainersWithNonTrivialContent]
3950+
def f(x=(1, u.v), y=(k(),), z=(w,)): ...
3951+
def g(x=[1, u.v], y=[k()], z=[w]): ...
3952+
def h(x={1: u.v}, y={k(): 2}, z={m: m}, w={**n}): ...
3953+
def i(x={u.v, 2}, y={3, k()}, z={w}): ...
3954+
3955+
[out]
3956+
def f(x=..., y=..., z=...) -> None: ...
3957+
def g(x=..., y=..., z=...) -> None: ...
3958+
def h(x=..., y=..., z=..., w=...) -> None: ...
3959+
def i(x=..., y=..., z=...) -> None: ...
3960+
39123961
[case testDataclass]
39133962
import dataclasses
39143963
import dataclasses as dcs

0 commit comments

Comments
 (0)