|
2 | 2 |
|
3 | 3 |
|
4 | 4 | import numpy as np
|
| 5 | +import pytest |
5 | 6 |
|
6 | 7 |
|
7 | 8 | try:
|
@@ -52,30 +53,32 @@ def test_backwards_compatibility():
|
52 | 53 | check_backwards_compatibility(codec.codec_id, arrays, [codec])
|
53 | 54 |
|
54 | 55 |
|
55 |
| -def test_non_numpy_inputs(): |
| 56 | +@pytest.mark.parametrize( |
| 57 | + "input_data, dtype", |
| 58 | + [ |
| 59 | + ([0, 1], None), |
| 60 | + ([[0, 1], [2, 3]], None), |
| 61 | + ([[0], [1], [2, 3]], object), |
| 62 | + ([[[0, 0]], [[1, 1]], [[2, 3]]], None), |
| 63 | + (["1"], None), |
| 64 | + (["11", "11"], None), |
| 65 | + (["11", "1", "1"], None), |
| 66 | + ([{}], None), |
| 67 | + ([{"key": "value"}, ["list", "of", "strings"]], object), |
| 68 | + ([b"1"], None), |
| 69 | + ([b"11", b"11"], None), |
| 70 | + ([b"11", b"1", b"1"], None), |
| 71 | + ([{b"key": b"value"}, [b"list", b"of", b"strings"]], object), |
| 72 | + ] |
| 73 | +) |
| 74 | +def test_non_numpy_inputs(input_data, dtype): |
56 | 75 | codec = MsgPack()
|
57 | 76 | # numpy will infer a range of different shapes and dtypes for these inputs.
|
58 | 77 | # Make sure that round-tripping through encode preserves this.
|
59 |
| - data = [ |
60 |
| - [0, 1], |
61 |
| - [[0, 1], [2, 3]], |
62 |
| - [[0], [1], [2, 3]], |
63 |
| - [[[0, 0]], [[1, 1]], [[2, 3]]], |
64 |
| - ["1"], |
65 |
| - ["11", "11"], |
66 |
| - ["11", "1", "1"], |
67 |
| - [{}], |
68 |
| - [{"key": "value"}, ["list", "of", "strings"]], |
69 |
| - [b"1"], |
70 |
| - [b"11", b"11"], |
71 |
| - [b"11", b"1", b"1"], |
72 |
| - [{b"key": b"value"}, [b"list", b"of", b"strings"]], |
73 |
| - ] |
74 |
| - for input_data in data: |
75 |
| - actual = codec.decode(codec.encode(input_data)) |
76 |
| - expect = np.array(input_data) |
77 |
| - assert expect.shape == actual.shape |
78 |
| - assert np.array_equal(expect, actual) |
| 78 | + actual = codec.decode(codec.encode(input_data)) |
| 79 | + expect = np.array(input_data, dtype=dtype) |
| 80 | + assert expect.shape == actual.shape |
| 81 | + assert np.array_equal(expect, actual) |
79 | 82 |
|
80 | 83 |
|
81 | 84 | def test_encode_decode_shape_dtype_preserved():
|
|
0 commit comments