Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Implement DataSource.get_contiguous_time_periods() #256

Merged
merged 3 commits into from
Oct 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 26 additions & 22 deletions nowcasting_dataset/data_sources/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def __post_init__(self):

self._history_duration = pd.Timedelta(self.history_minutes, unit="minutes")
self._forecast_duration = pd.Timedelta(self.forecast_minutes, unit="minutes")
# Add sample_period_duration because neither history_duration not forecast_duration include t0.
# Add sample_period_duration because neither history_duration nor forecast_duration
# include t0.
self._total_seq_duration = (
self._history_duration + self._forecast_duration + self.sample_period_duration
)
Expand Down Expand Up @@ -112,13 +113,13 @@ def get_batch(
Get Batch Data

Args:
t0_datetimes: list of timestamps for the datetime of the batches. The batch will also include data
for historic and future depending on 'history_minutes' and 'future_minutes'.
t0_datetimes: list of timestamps for the datetime of the batches. The batch will also
include data for historic and future depending on `history_minutes` and
`future_minutes`.
x_locations: x center batch locations
y_locations: y center batch locations

Returns: Batch data

Returns: Batch data.
"""
examples = []
zipped = zip(t0_datetimes, x_locations, y_locations)
Expand Down Expand Up @@ -176,31 +177,34 @@ def get_contiguous_time_periods(self) -> pd.DataFrame:
Returns:
pd.DataFrame where each row represents a single time period. The pd.DataFrame
has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
"""

# TODO: Use nd_time.get_contiguous_time_periods()
# See https://github.com/openclimatefix/nowcasting_dataset/issues/223
raise NotImplementedError()

def _get_time_slice(self, t0_dt: pd.Timestamp):
"""Get a single timestep of data. Must be overridden."""
raise NotImplementedError()
Raises:
NotImplementedError if this DataSource has no concept of a datetime index.
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

worth adding a logging statement?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question! TBH, if someone tries to call this function in the wrong context then that suggests a major mistake, so we probably want the code to blow up noisily (i.e. to throw an exception). So, if it's OK, I'll leave it like this for now, and we can see how we get on?

datetimes = self.datetime_index()
return nd_time.get_contiguous_time_periods(
datetimes=datetimes,
min_seq_length=self._total_seq_length,
max_gap_duration=self.sample_period_duration,
)

# ****************** METHODS THAT MUST BE OVERRIDDEN **********************
def get_locations_for_batch(
self, t0_datetimes: pd.DatetimeIndex
) -> Tuple[List[Number], List[Number]]:
"""Find a valid geographical location for each t0_datetime.
"""Find a valid geographical locations for each t0_datetime.

Should be overridden by DataSources which may be used to define the locations
for each batch.

Returns: x_locations, y_locations. Each has one entry per t0_datetime.
Locations are in OSGB coordinates.
"""
# TODO: Do this properly, using PV locations!
locations = [20_000, 40_000, 500_000, 600_000, 100_000, 100_000, 250_000, 250_000]

location = np.random.choice(locations, size=(len(t0_datetimes), 2))
raise NotImplementedError()

return location[:, 0], location[:, 1]
# ****************** METHODS THAT MUST BE OVERRIDDEN **********************
def _get_time_slice(self, t0_dt: pd.Timestamp):
"""Get a single timestep of data. Must be overridden."""
raise NotImplementedError()

def get_example(
self,
Expand Down Expand Up @@ -273,8 +277,8 @@ def get_example(
Get Example data

Args:
t0_dt: list of timestamps for the datetime of the batches. The batch will also include data
for historic and future depending on 'history_minutes' and 'future_minutes'.
t0_dt: list of timestamps for the datetime of the batches. The batch will also include
data for historic and future depending on `history_minutes` and `future_minutes`.
x_meters_center: x center batch locations
y_meters_center: y center batch locations

Expand Down
15 changes: 10 additions & 5 deletions nowcasting_dataset/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,20 @@ def get_start_datetimes(


def get_contiguous_time_periods(
datetimes: pd.DatetimeIndex, min_seq_length: int, max_gap: pd.Timedelta = THIRTY_MINUTES
datetimes: pd.DatetimeIndex,
min_seq_length: int,
max_gap_duration: pd.Timedelta = THIRTY_MINUTES,
) -> pd.DataFrame:
"""Returns a pd.DataFrame where each row records the boundary of a contiguous time periods.

Args:
datetimes: The pd.DatetimeIndex of the timeseries. Must be sorted.
min_seq_length: Sequences of min_seq_length or shorter will be discarded.
max_gap: If any pair of consecutive `datetimes` is more than `max_gap` apart, then this pair
of `datetimes` will be considered a "gap" between two contiguous sequences.
min_seq_length: Sequences of min_seq_length or shorter will be discarded. Typically, this
would be set to the `total_seq_length` of each machine learning example.
max_gap_duration: If any pair of consecutive `datetimes` is more than `max_gap_duration`
apart, then this pair of `datetimes` will be considered a "gap" between two contiguous
sequences. Typically, `max_gap_duration` would be set to the sample period of
the timeseries.

Returns:
pd.DataFrame where each row represents a single time period. The pd.DataFrame
Expand All @@ -193,7 +198,7 @@ def get_contiguous_time_periods(
assert datetimes.is_unique

# Find indices of gaps larger than max_gap:
gap_mask = np.diff(datetimes) > max_gap
gap_mask = np.diff(datetimes) > max_gap_duration
gap_indices = np.argwhere(gap_mask)[:, 0]

# gap_indicies are the indices into dt_index for the timestep immediately
Expand Down
1 change: 0 additions & 1 deletion tests/data_sources/test_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@


def test_image_data_source():

_ = ImageDataSource(
image_size_pixels=64,
meters_per_pixel=2000,
Expand Down
38 changes: 22 additions & 16 deletions tests/data_sources/test_nwp_data_source.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import os
import pandas as pd

import nowcasting_dataset
from nowcasting_dataset.data_sources.nwp.nwp_data_source import NWPDataSource


def test_nwp_data_source_init():
PATH = os.path.dirname(nowcasting_dataset.__file__)

path = os.path.dirname(nowcasting_dataset.__file__)
# Solar PV data (test data)
NWP_FILENAME = f"{PATH}/../tests/data/nwp_data/test.zarr"

# Solar PV data (test data)
NWP_FILENAME = f"{path}/../tests/data/nwp_data/test.zarr"

def test_nwp_data_source_init():
_ = NWPDataSource(
filename=NWP_FILENAME,
history_minutes=30,
Expand All @@ -21,12 +22,6 @@ def test_nwp_data_source_init():


def test_nwp_data_source_open():

path = os.path.dirname(nowcasting_dataset.__file__)

# Solar PV data (test data)
NWP_FILENAME = f"{path}/../tests/data/nwp_data/test.zarr"

nwp = NWPDataSource(
filename=NWP_FILENAME,
history_minutes=30,
Expand All @@ -40,12 +35,6 @@ def test_nwp_data_source_open():


def test_nwp_data_source_batch():

path = os.path.dirname(nowcasting_dataset.__file__)

# Solar PV data (test data)
NWP_FILENAME = f"{path}/../tests/data/nwp_data/test.zarr"

nwp = NWPDataSource(
filename=NWP_FILENAME,
history_minutes=30,
Expand All @@ -64,3 +53,20 @@ def test_nwp_data_source_batch():
batch = nwp.get_batch(t0_datetimes=t0_datetimes, x_locations=x, y_locations=y)

assert batch.data.shape == (4, 1, 19, 2, 2)


def test_nwp_get_contiguous_time_periods():
    """The NWP test dataset should be reported as one single contiguous time period."""
    data_source = NWPDataSource(
        filename=NWP_FILENAME,
        history_minutes=30,
        forecast_minutes=60,
        convert_to_numpy=True,
        n_timesteps_per_batch=8,
        channels=["t"],
    )

    # The test Zarr covers 2019-01-01 00:00 through 2019-01-02 02:00 without gaps,
    # so exactly one (start_dt, end_dt) row is expected.
    expected = pd.DataFrame(
        [
            {
                "start_dt": pd.Timestamp("2019-01-01 00:00"),
                "end_dt": pd.Timestamp("2019-01-02 02:00"),
            }
        ]
    )
    actual = data_source.get_contiguous_time_periods()
    pd.testing.assert_frame_equal(actual, expected)
13 changes: 2 additions & 11 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pytest

import nowcasting_dataset.time as nd_time
from nowcasting_dataset.consts import GSP_DATETIME_INDEX
from nowcasting_dataset.dataset.datasets import NowcastingDataset
from nowcasting_dataset.dataset.batch import Batch

Expand Down Expand Up @@ -56,16 +55,8 @@ def test_per_worker_init(dataset: NowcastingDataset):

def test_get_batch(dataset: NowcastingDataset):
dataset.per_worker_init(worker_id=1)
batch = dataset._get_batch()
assert isinstance(batch, Batch)
assert batch.satellite is not None
assert batch.satellite.data.shape == (
8,
2,
pytest.IMAGE_SIZE_PIXELS,
pytest.IMAGE_SIZE_PIXELS,
1,
)
with pytest.raises(NotImplementedError):
_ = dataset._get_batch()


def test_get_batch_gsp(dataset_gsp: NowcastingDataset):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_get_contiguous_time_periods_1_with_1_chunk(min_seq_length):
freq = pd.Timedelta(5, unit="minutes")
dt_index = pd.date_range("2010-01-01", "2010-01-02", freq=freq)
periods: pd.DataFrame = nd_time.get_contiguous_time_periods(
dt_index, min_seq_length=min_seq_length, max_gap=freq
dt_index, min_seq_length=min_seq_length, max_gap_duration=freq
)
correct_periods = pd.DataFrame([{"start_dt": dt_index[0], "end_dt": dt_index[-1]}])
pd.testing.assert_frame_equal(periods, correct_periods)
Expand All @@ -81,7 +81,7 @@ def test_get_contiguous_time_periods_2_with_2_chunks(min_seq_length):
dt_index2 = pd.date_range("2010-02-01", "2010-02-02", freq=freq)
dt_index = dt_index1.union(dt_index2)
periods: pd.DataFrame = nd_time.get_contiguous_time_periods(
dt_index, min_seq_length=min_seq_length, max_gap=freq
dt_index, min_seq_length=min_seq_length, max_gap_duration=freq
)
correct_periods = pd.DataFrame(
[
Expand Down