Skip to content

Commit 2fe6a09

Browse files
committed
feat: add ChoiceListParameter
1 parent b5d1b96 commit 2fe6a09

File tree

3 files changed

+79
-6
lines changed

3 files changed

+79
-6
lines changed

luigi/__init__.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@
4040
DateIntervalParameter, TimeDeltaParameter,
4141
IntParameter, FloatParameter, BoolParameter, PathParameter,
4242
TaskParameter, EnumParameter, DictParameter, ListParameter, TupleParameter, EnumListParameter,
43-
NumericalParameter, ChoiceParameter, OptionalParameter, OptionalStrParameter,
44-
OptionalIntParameter, OptionalFloatParameter, OptionalBoolParameter, OptionalPathParameter,
45-
OptionalDictParameter, OptionalListParameter, OptionalTupleParameter,
43+
NumericalParameter, ChoiceParameter, ChoiceListParameter, OptionalParameter,
44+
OptionalStrParameter, OptionalIntParameter, OptionalFloatParameter, OptionalBoolParameter,
45+
OptionalPathParameter, OptionalDictParameter, OptionalListParameter, OptionalTupleParameter,
4646
OptionalChoiceParameter, OptionalNumericalParameter,
4747
)
4848

@@ -66,9 +66,9 @@
6666
'FloatParameter', 'BoolParameter', 'PathParameter', 'TaskParameter',
6767
'ListParameter', 'TupleParameter', 'EnumParameter', 'DictParameter', 'EnumListParameter',
6868
'configuration', 'interface', 'local_target', 'run', 'build', 'event', 'Event',
69-
'NumericalParameter', 'ChoiceParameter', 'OptionalParameter', 'OptionalStrParameter',
70-
'OptionalIntParameter', 'OptionalFloatParameter', 'OptionalBoolParameter', 'OptionalPathParameter',
71-
'OptionalDictParameter', 'OptionalListParameter', 'OptionalTupleParameter',
69+
'NumericalParameter', 'ChoiceParameter', 'ChoiceListParameter', 'OptionalParameter',
70+
'OptionalStrParameter', 'OptionalIntParameter', 'OptionalFloatParameter', 'OptionalBoolParameter',
71+
'OptionalPathParameter', 'OptionalDictParameter', 'OptionalListParameter', 'OptionalTupleParameter',
7272
'OptionalChoiceParameter', 'OptionalNumericalParameter', 'LuigiStatusCode',
7373
'__version__',
7474
]

luigi/parameter.py

+47
Original file line numberDiff line numberDiff line change
@@ -1540,6 +1540,53 @@ def normalize(self, var):
15401540
var=var, choices=self._choices))
15411541

15421542

1543+
class ChoiceListParameter(ChoiceParameter):
1544+
"""
1545+
A parameter which takes two values:
1546+
1. an instance of :class:`~collections.Iterable` and
1547+
2. the class of the variables to convert to.
1548+
1549+
Values are taken to be a list, i.e. order is preserved, duplicates may occur, and empty list is possible.
1550+
1551+
In the task definition, use
1552+
1553+
.. code-block:: python
1554+
1555+
class MyTask(luigi.Task):
1556+
my_param = luigi.ChoiceListParameter(choices=['foo', 'bar', 'baz'], var_type=str)
1557+
1558+
At the command line, use
1559+
1560+
.. code-block:: console
1561+
1562+
$ luigi --module my_tasks MyTask --my-param foo,bar
1563+
1564+
Consider using :class:`~luigi.EnumListParameter` for a typed, structured
1565+
alternative. This class can perform the same role when all choices are the
1566+
same type and transparency of parameter value on the command line is
1567+
desired.
1568+
"""
1569+
1570+
_sep = ','
1571+
1572+
def __init__(self, *args, **kwargs):
1573+
super(ChoiceListParameter, self).__init__(*args, **kwargs)
1574+
1575+
def parse(self, s):
1576+
values = [] if s == '' else s.split(self._sep)
1577+
return self.normalize(map(self._var_type, values))
1578+
1579+
def normalize(self, var):
1580+
values = []
1581+
for v in var:
1582+
values.append(super().normalize(v))
1583+
return tuple(values)
1584+
1585+
1586+
def serialize(self, values):
1587+
return self._sep.join(values)
1588+
1589+
15431590
class OptionalChoiceParameter(OptionalParameterMixin, ChoiceParameter):
15441591
"""Class to parse optional choice parameters."""
15451592

test/parameter_test.py

+26
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,25 @@ def test_enum_list_param_invalid(self):
310310
def test_enum_list_param_missing(self):
311311
self.assertRaises(ParameterException, lambda: luigi.parameter.EnumListParameter())
312312

313+
def test_choice_list_param_valid(self):
314+
p = luigi.parameter.ChoiceListParameter(choices=["1", "2", "3"])
315+
self.assertEqual((), p.parse(''))
316+
self.assertEqual(("1",), p.parse('1'))
317+
self.assertEqual(("1", "3"), p.parse('1,3'))
318+
319+
def test_choice_list_param_invalid(self):
320+
p = luigi.parameter.ChoiceListParameter(choices=["1", "2", "3"])
321+
self.assertRaises(ValueError, lambda: p.parse('1,4'))
322+
323+
def test_invalid_choice_type(self):
324+
self.assertRaises(
325+
AssertionError,
326+
lambda: luigi.ChoiceListParameter(var_type=int, choices=[1, 2, "3"]),
327+
)
328+
329+
def test_choice_list_param_missing(self):
330+
self.assertRaises(ParameterException, lambda: luigi.parameter.ChoiceListParameter())
331+
313332
def test_tuple_serialize_parse(self):
314333
a = luigi.TupleParameter()
315334
b_tuple = ((1, 2), (3, 4))
@@ -469,6 +488,13 @@ class FooWithDefault(luigi.Task):
469488

470489
self.assertEqual(FooWithDefault().args, p.parse('C'))
471490

491+
def test_choice_list(self):
492+
class Foo(luigi.Task):
493+
args = luigi.ChoiceListParameter(var_type=str, choices=["1", "2", "3"])
494+
495+
p = luigi.ChoiceListParameter(var_type=str, choices=["3", "2", "1"])
496+
self.assertEqual(hash(Foo(args=("3",)).args), hash(p.parse("3")))
497+
472498
def test_dict(self):
473499
class Foo(luigi.Task):
474500
args = luigi.parameter.DictParameter()

0 commit comments

Comments
 (0)