Skip to content

Switch to T_DataArray and T_Dataset in concat #6784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
12 commits merged on Jul 18, 2022
39 changes: 20 additions & 19 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Hashable, Iterable, overload
from typing import TYPE_CHECKING, Any, Hashable, Iterable, cast, overload

import pandas as pd

Expand All @@ -14,42 +14,41 @@
merge_attrs,
merge_collected,
)
from .types import T_DataArray, T_Dataset
from .variable import Variable
from .variable import concat as concat_vars

if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset
from .types import CombineAttrsOptions, CompatOptions, ConcatOptions, JoinOptions


@overload
def concat(
objs: Iterable[Dataset],
dim: Hashable | DataArray | pd.Index,
objs: Iterable[T_Dataset],
dim: Hashable | T_DataArray | pd.Index,
data_vars: ConcatOptions | list[Hashable] = "all",
coords: ConcatOptions | list[Hashable] = "different",
compat: CompatOptions = "equals",
positions: Iterable[Iterable[int]] | None = None,
fill_value: object = dtypes.NA,
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "override",
) -> Dataset:
) -> T_Dataset:
...


@overload
def concat(
objs: Iterable[DataArray],
dim: Hashable | DataArray | pd.Index,
objs: Iterable[T_DataArray],
dim: Hashable | T_DataArray | pd.Index,
data_vars: ConcatOptions | list[Hashable] = "all",
coords: ConcatOptions | list[Hashable] = "different",
compat: CompatOptions = "equals",
positions: Iterable[Iterable[int]] | None = None,
fill_value: object = dtypes.NA,
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "override",
) -> DataArray:
) -> T_DataArray:
...


Expand Down Expand Up @@ -402,7 +401,7 @@ def process_subset_opt(opt, subset):

# determine dimensional coordinate names and a dict mapping name to DataArray
def _parse_datasets(
datasets: Iterable[Dataset],
datasets: Iterable[T_Dataset],
) -> tuple[dict[Hashable, Variable], dict[Hashable, int], set[Hashable], set[Hashable]]:

dims: set[Hashable] = set()
Expand All @@ -429,16 +428,16 @@ def _parse_datasets(


def _dataset_concat(
datasets: list[Dataset],
dim: str | DataArray | pd.Index,
datasets: list[T_Dataset],
dim: str | T_DataArray | pd.Index,
data_vars: str | list[str],
coords: str | list[str],
compat: CompatOptions,
positions: Iterable[Iterable[int]] | None,
fill_value: object = dtypes.NA,
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "override",
) -> Dataset:
) -> T_Dataset:
"""
Concatenate a sequence of datasets along a new or existing dimension
"""
Expand Down Expand Up @@ -482,7 +481,8 @@ def _dataset_concat(

# case where concat dimension is a coordinate or data_var but not a dimension
if (dim in coord_names or dim in data_names) and dim not in dim_names:
datasets = [ds.expand_dims(dim) for ds in datasets]
# TODO: Overriding type because .expand_dims has incorrect typing:
datasets = [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets]

# determine which variables to concatenate
concat_over, equals, concat_dim_lengths = _calc_concat_over(
Expand Down Expand Up @@ -590,7 +590,7 @@ def get_indexes(name):
# preserves original variable order
result_vars[name] = result_vars.pop(name)

result = Dataset(result_vars, attrs=result_attrs)
result = type(datasets[0])(result_vars, attrs=result_attrs)

absent_coord_names = coord_names - set(result.variables)
if absent_coord_names:
Expand Down Expand Up @@ -618,16 +618,16 @@ def get_indexes(name):


def _dataarray_concat(
arrays: Iterable[DataArray],
dim: str | DataArray | pd.Index,
arrays: Iterable[T_DataArray],
dim: str | T_DataArray | pd.Index,
data_vars: str | list[str],
coords: str | list[str],
compat: CompatOptions,
positions: Iterable[Iterable[int]] | None,
fill_value: object = dtypes.NA,
join: JoinOptions = "outer",
combine_attrs: CombineAttrsOptions = "override",
) -> DataArray:
) -> T_DataArray:
from .dataarray import DataArray

arrays = list(arrays)
Expand All @@ -650,7 +650,8 @@ def _dataarray_concat(
if compat == "identical":
raise ValueError("array names not identical")
else:
arr = arr.rename(name)
# TODO: Overriding type because .rename has incorrect typing:
arr = cast(T_DataArray, arr.rename(name))
datasets.append(arr._to_temp_dataset())

ds = _dataset_concat(
Expand Down