|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 | 5 | from abc import abstractmethod
|
6 |
| -from typing import Callable |
7 |
| -from typing_extensions import Final, Protocol |
| 6 | +from typing import Callable, overload |
| 7 | +from typing_extensions import Final, Literal, Protocol |
8 | 8 |
|
9 | 9 | from mypy_extensions import trait
|
10 | 10 |
|
11 | 11 | from mypy import join
|
12 |
| -from mypy.errorcodes import ErrorCode |
| 12 | +from mypy.errorcodes import LITERAL_REQ, ErrorCode |
13 | 13 | from mypy.nodes import (
|
14 | 14 | CallExpr,
|
15 | 15 | ClassDef,
|
|
26 | 26 | SymbolTableNode,
|
27 | 27 | TypeInfo,
|
28 | 28 | )
|
| 29 | +from mypy.plugin import SemanticAnalyzerPluginInterface |
29 | 30 | from mypy.tvar_scope import TypeVarLikeScope
|
30 | 31 | from mypy.type_visitor import ANY_STRATEGY, BoolTypeQuery
|
31 | 32 | from mypy.types import (
|
@@ -420,3 +421,41 @@ def find_dataclass_transform_spec(node: Node | None) -> DataclassTransformSpec |
|
420 | 421 | return metaclass_type.type.dataclass_transform_spec
|
421 | 422 |
|
422 | 423 | return None
|
| 424 | + |
| 425 | + |
| 426 | +# Never returns `None` if a default is given |
| 427 | +@overload |
| 428 | +def require_bool_literal_argument( |
| 429 | + api: SemanticAnalyzerInterface | SemanticAnalyzerPluginInterface, |
| 430 | + expression: Expression, |
| 431 | + name: str, |
| 432 | + default: Literal[True] | Literal[False], |
| 433 | +) -> bool: |
| 434 | + ... |
| 435 | + |
| 436 | + |
| 437 | +@overload |
| 438 | +def require_bool_literal_argument( |
| 439 | + api: SemanticAnalyzerInterface | SemanticAnalyzerPluginInterface, |
| 440 | + expression: Expression, |
| 441 | + name: str, |
| 442 | + default: None = None, |
| 443 | +) -> bool | None: |
| 444 | + ... |
| 445 | + |
| 446 | + |
| 447 | +def require_bool_literal_argument( |
| 448 | + api: SemanticAnalyzerInterface | SemanticAnalyzerPluginInterface, |
| 449 | + expression: Expression, |
| 450 | + name: str, |
| 451 | + default: bool | None = None, |
| 452 | +) -> bool | None: |
| 453 | + """Attempt to interpret an expression as a boolean literal, and fail analysis if we can't.""" |
| 454 | + value = api.parse_bool(expression) |
| 455 | + if value is None: |
| 456 | + api.fail( |
| 457 | + f'"{name}" argument must be a True or False literal', expression, code=LITERAL_REQ |
| 458 | + ) |
| 459 | + return default |
| 460 | + |
| 461 | + return value |
0 commit comments