Skip to content

Support NDFrame.shift with EAs #22387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.24.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ ExtensionType Changes
- Added ``ExtensionDtype._is_numeric`` for controlling whether an extension dtype is considered numeric (:issue:`22290`).
- The ``ExtensionArray`` constructor, ``_from_sequence`` now take the keyword arg ``copy=False`` (:issue:`21185`)
- Bug in :meth:`Series.get` for ``Series`` using ``ExtensionArray`` and integer index (:issue:`21257`)
- :meth:`~Series.shift` now dispatches to :meth:`ExtensionArray.shift` (:issue:`22386`)
- :meth:`Series.combine()` works correctly with :class:`~pandas.api.extensions.ExtensionArray` inside of :class:`Series` (:issue:`20825`)
- :meth:`Series.combine()` with scalar argument now works for any function type (:issue:`21248`)
- :meth:`Series.astype` and :meth:`DataFrame.astype` now dispatch to :meth:`ExtensionArray.astype` (:issue:`21185:`).
Expand Down
38 changes: 38 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ class ExtensionArray(object):
* factorize / _values_for_factorize
* argsort / _values_for_argsort

The remaining methods implemented on this class should be performant,
as they only compose abstract methods. Still, a more efficient
implementation may be available, and these methods can be overridden.

This class does not inherit from 'abc.ABCMeta' for performance reasons.
Methods and properties required by the interface raise
``pandas.errors.AbstractMethodError`` and no ``register`` method is
Expand Down Expand Up @@ -400,6 +404,40 @@ def dropna(self):

return self[~self.isna()]

def shift(self, periods=1):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think update in the ExtensionArray doc-string?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, in the class docstring we only mention the methods that either needs to implemented (because they raise AbstractMethodError otherwise) or either have a suboptimal implementation because it does the object ndarray roundtrip.
This is not the case here (which is not saying we couldn't also list other methods that can be overriden for specific reasons)

# type: (int) -> ExtensionArray
"""
Shift values by desired number.

Newly introduced missing values are filled with
``self.dtype.na_value``.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add a versionadded tag


.. versionadded:: 0.24.0

Parameters
----------
periods : int, default 1
The number of periods to shift. Negative values are allowed
for shifting backwards.

Returns
-------
shifted : ExtensionArray
"""
# Note: this implementation assumes that `self.dtype.na_value` can be
# stored in an instance of your ExtensionArray with `self.dtype`.
if periods == 0:
return self.copy()
empty = self._from_sequence([self.dtype.na_value] * abs(periods),
dtype=self.dtype)
if periods > 0:
a = empty
b = self[:-periods]
else:
a = self[abs(periods):]
b = empty
return self._concat_same_type([a, b])

def unique(self):
"""Compute the ExtensionArray of unique values.

Expand Down
16 changes: 12 additions & 4 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2068,6 +2068,18 @@ def interpolate(self, method='pad', axis=0, inplace=False, limit=None,
limit=limit),
placement=self.mgr_locs)

def shift(self, periods, axis=0, mgr=None):
"""
Shift the block by `periods`.

Dispatches to underlying ExtensionArray and re-boxes in an
ExtensionBlock.
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bonus points for the Parameters :> (future PR ok too)

# type: (int, Optional[BlockPlacement]) -> List[ExtensionBlock]
return [self.make_block_same_class(self.values.shift(periods=periods),
placement=self.mgr_locs,
ndim=self.ndim)]


class NumericBlock(Block):
__slots__ = ()
Expand Down Expand Up @@ -2691,10 +2703,6 @@ def _try_coerce_result(self, result):

return result

def shift(self, periods, axis=0, mgr=None):
return self.make_block_same_class(values=self.values.shift(periods),
placement=self.mgr_locs)

def to_dense(self):
# Categorical.get_values returns a DatetimeIndex for datetime
# categories, so we can't simply use `np.asarray(self.values)` like
Expand Down
25 changes: 25 additions & 0 deletions pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,28 @@ def test_combine_add(self, data_repeated):
expected = pd.Series(
orig_data1._from_sequence([a + val for a in list(orig_data1)]))
self.assert_series_equal(result, expected)

@pytest.mark.parametrize('frame', [True, False])
@pytest.mark.parametrize('periods, indices', [
(-2, [2, 3, 4, -1, -1]),
(0, [0, 1, 2, 3, 4]),
(2, [-1, -1, 0, 1, 2]),
])
def test_container_shift(self, data, frame, periods, indices):
# https://github.com/pandas-dev/pandas/issues/22386
subset = data[:5]
data = pd.Series(subset, name='A')
expected = pd.Series(subset.take(indices, allow_fill=True), name='A')

if frame:
result = data.to_frame(name='A').assign(B=1).shift(periods)
expected = pd.concat([
expected,
pd.Series([1] * 5, name='B').shift(periods)
], axis=1)
compare = self.assert_frame_equal
else:
result = data.shift(periods)
compare = self.assert_series_equal

compare(result, expected)