Skip to content

Commit 5f1f78f

Browse files
authored
Faster encoding functions. (#8565)
* Faster ensure_not_multiindex
* Better check
* Fix test and add typing
* Optimize string encoding a bit.
1 parent 693f0b9 commit 5f1f78f

File tree

3 files changed: 23 additions and 17 deletions (+23 / -17 lines changed)

xarray/coding/strings.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -47,12 +47,11 @@ class EncodedStringCoder(VariableCoder):
4747
def __init__(self, allows_unicode=True):
4848
self.allows_unicode = allows_unicode
4949

50-
def encode(self, variable, name=None):
50+
def encode(self, variable: Variable, name=None) -> Variable:
5151
dims, data, attrs, encoding = unpack_for_encoding(variable)
5252

5353
contains_unicode = is_unicode_dtype(data.dtype)
5454
encode_as_char = encoding.get("dtype") == "S1"
55-
5655
if encode_as_char:
5756
del encoding["dtype"] # no longer relevant
5857

@@ -69,9 +68,12 @@ def encode(self, variable, name=None):
6968
# TODO: figure out how to handle this in a lazy way with dask
7069
data = encode_string_array(data, string_encoding)
7170

72-
return Variable(dims, data, attrs, encoding)
71+
return Variable(dims, data, attrs, encoding)
72+
else:
73+
variable.encoding = encoding
74+
return variable
7375

74-
def decode(self, variable, name=None):
76+
def decode(self, variable: Variable, name=None) -> Variable:
7577
dims, data, attrs, encoding = unpack_for_decoding(variable)
7678

7779
if "_Encoding" in attrs:
@@ -95,13 +97,15 @@ def encode_string_array(string_array, encoding="utf-8"):
9597
return np.array(encoded, dtype=bytes).reshape(string_array.shape)
9698

9799

98-
def ensure_fixed_length_bytes(var):
100+
def ensure_fixed_length_bytes(var: Variable) -> Variable:
99101
"""Ensure that a variable with vlen bytes is converted to fixed width."""
100-
dims, data, attrs, encoding = unpack_for_encoding(var)
101-
if check_vlen_dtype(data.dtype) == bytes:
102+
if check_vlen_dtype(var.dtype) == bytes:
103+
dims, data, attrs, encoding = unpack_for_encoding(var)
102104
# TODO: figure out how to handle this with dask
103105
data = np.asarray(data, dtype=np.bytes_)
104-
return Variable(dims, data, attrs, encoding)
106+
return Variable(dims, data, attrs, encoding)
107+
else:
108+
return var
105109

106110

107111
class CharacterArrayCoder(VariableCoder):

xarray/conventions.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
)
1717
from xarray.core.pycompat import is_duck_dask_array
1818
from xarray.core.utils import emit_user_level_warning
19-
from xarray.core.variable import IndexVariable, Variable
19+
from xarray.core.variable import Variable
2020

2121
CF_RELATED_DATA = (
2222
"bounds",
@@ -97,10 +97,10 @@ def _infer_dtype(array, name=None):
9797

9898

9999
def ensure_not_multiindex(var: Variable, name: T_Name = None) -> None:
100-
if isinstance(var, IndexVariable) and isinstance(var.to_index(), pd.MultiIndex):
100+
if isinstance(var._data, indexing.PandasMultiIndexingAdapter):
101101
raise NotImplementedError(
102102
f"variable {name!r} is a MultiIndex, which cannot yet be "
103-
"serialized to netCDF files. Instead, either use reset_index() "
103+
"serialized. Instead, either use reset_index() "
104104
"to convert MultiIndex levels into coordinate variables instead "
105105
"or use https://cf-xarray.readthedocs.io/en/latest/coding.html."
106106
)
@@ -647,7 +647,9 @@ def cf_decoder(
647647
return variables, attributes
648648

649649

650-
def _encode_coordinates(variables, attributes, non_dim_coord_names):
650+
def _encode_coordinates(
651+
variables: T_Variables, attributes: T_Attrs, non_dim_coord_names
652+
):
651653
# calculate global and variable specific coordinates
652654
non_dim_coord_names = set(non_dim_coord_names)
653655

@@ -675,7 +677,7 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names):
675677
variable_coordinates[k].add(coord_name)
676678

677679
if any(
678-
attr_name in v.encoding and coord_name in v.encoding.get(attr_name)
680+
coord_name in v.encoding.get(attr_name, tuple())
679681
for attr_name in CF_RELATED_DATA
680682
):
681683
not_technically_coordinates.add(coord_name)
@@ -742,7 +744,7 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names):
742744
return variables, attributes
743745

744746

745-
def encode_dataset_coordinates(dataset):
747+
def encode_dataset_coordinates(dataset: Dataset):
746748
"""Encode coordinates on the given dataset object into variable specific
747749
and global attributes.
748750
@@ -764,7 +766,7 @@ def encode_dataset_coordinates(dataset):
764766
)
765767

766768

767-
def cf_encoder(variables, attributes):
769+
def cf_encoder(variables: T_Variables, attributes: T_Attrs):
768770
"""
769771
Encode a set of CF encoded variables and attributes.
770772
Takes a dicts of variables and attributes and encodes them

xarray/tests/test_coding_times.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -733,15 +733,15 @@ def test_encode_time_bounds() -> None:
733733

734734
# if time_bounds attrs are same as time attrs, it doesn't matter
735735
ds.time_bounds.encoding = {"calendar": "noleap", "units": "days since 2000-01-01"}
736-
encoded, _ = cf_encoder({k: ds[k] for k in ds.variables}, ds.attrs)
736+
encoded, _ = cf_encoder({k: v for k, v in ds.variables.items()}, ds.attrs)
737737
assert_equal(encoded["time_bounds"], expected["time_bounds"])
738738
assert "calendar" not in encoded["time_bounds"].attrs
739739
assert "units" not in encoded["time_bounds"].attrs
740740

741741
# for CF-noncompliant case of time_bounds attrs being different from
742742
# time attrs; preserve them for faithful roundtrip
743743
ds.time_bounds.encoding = {"calendar": "noleap", "units": "days since 1849-01-01"}
744-
encoded, _ = cf_encoder({k: ds[k] for k in ds.variables}, ds.attrs)
744+
encoded, _ = cf_encoder({k: v for k, v in ds.variables.items()}, ds.attrs)
745745
with pytest.raises(AssertionError):
746746
assert_equal(encoded["time_bounds"], expected["time_bounds"])
747747
assert "calendar" not in encoded["time_bounds"].attrs

0 commit comments

Comments
 (0)