Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Implement DataSourceList.sample_spatial_and_temporal_locations_for_examples() #278

Merged
merged 8 commits into from
Oct 25, 2021
8 changes: 2 additions & 6 deletions nowcasting_dataset/data_sources/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from numbers import Number
from typing import List, Tuple, Iterable

import numpy as np
import pandas as pd
import xarray as xr

Expand Down Expand Up @@ -181,13 +180,10 @@ def get_contiguous_time_periods(self) -> pd.DataFrame:
max_gap_duration=self.sample_period_duration,
)

def get_locations_for_batch(
self, t0_datetimes: pd.DatetimeIndex
) -> Tuple[List[Number], List[Number]]:
def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]:
"""Find a valid geographical locations for each t0_datetime.

Should be overridden by DataSources which may be used to define the locations
for each batch.
Should be overridden by DataSources which may be used to define the locations.

Returns: x_locations, y_locations. Each has one entry per t0_datetime.
Locations are in OSGB coordinates.
Expand Down
85 changes: 85 additions & 0 deletions nowcasting_dataset/data_sources/data_source_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""DataSourceList class."""

import numpy as np
import pandas as pd
import logging

import nowcasting_dataset.time as nd_time
from nowcasting_dataset.dataset.split.split import SplitMethod, split_data, SplitName

logger = logging.getLogger(__name__)


class DataSourceList(list):
    """Hold a list of DataSource objects.

    The first DataSource in the list is used to compute the geospatial locations of each example.
    """

    def get_t0_datetimes_across_all_data_sources(self, freq: str) -> pd.DatetimeIndex:
        """
        Compute the intersection of the t0 datetimes available across all DataSources.

        Args:
            freq: The Pandas frequency string.  The returned DatetimeIndex will be at this
                frequency, and every datetime will be aligned to this frequency.  For example,
                if freq='5 minutes' then every datetime will be at 00, 05, ..., 55 minutes
                past the hour.

        Returns: Valid t0 datetimes, taking into consideration all DataSources,
            filtered by daylight hours (SatelliteDataSource.datetime_index() removes the night
            datetimes).
        """
        logger.debug("Get the intersection of time periods across all DataSources.")

        # Collect the contiguous t0 time periods reported by each data source.
        t0_time_periods_for_all_data_sources = []
        for data_source in self:
            logger.debug(f"Getting t0 time periods for {type(data_source).__name__}")
            try:
                t0_time_periods = data_source.get_contiguous_t0_time_periods()
            except NotImplementedError:
                pass  # Skip data_sources with no concept of time.
            else:
                t0_time_periods_for_all_data_sources.append(t0_time_periods)

        # Only keep the periods during which every time-aware data source has data.
        intersection_of_t0_time_periods = nd_time.intersection_of_multiple_dataframes_of_periods(
            t0_time_periods_for_all_data_sources
        )

        t0_datetimes = nd_time.time_periods_to_datetime_index(
            time_periods=intersection_of_t0_time_periods, freq=freq
        )

        return t0_datetimes

    def sample_spatial_and_temporal_locations_for_examples(
        self, t0_datetimes: pd.DatetimeIndex, n_examples: int
    ) -> pd.DataFrame:
        """
        Computes the geospatial and temporal locations for each training example.

        The first data_source in this DataSourceList defines the geospatial locations of
        each example.

        Args:
            t0_datetimes: All available t0 datetimes.  Can be computed with
                `DataSourceList.get_t0_datetimes_across_all_data_sources()`
            n_examples: The number of examples requested.

        Returns:
            Each row of the DataFrame specifies the position of one example, using
            columns: 't0_datetime_UTC', 'x_center_OSGB', 'y_center_OSGB'.

        Raises:
            IndexError: If this DataSourceList is empty (there is no first data source
                to define the geospatial locations).
        """
        data_source_which_defines_geo_position = self[0]

        # np.random.choice draws n_examples datetimes at random (so the same t0 datetime
        # may appear more than once) and returns a plain NumPy array.  Convert the sample
        # back to a DatetimeIndex because get_locations() is declared to take one.
        sampled_t0_datetimes = pd.DatetimeIndex(
            np.random.choice(t0_datetimes, size=n_examples)
        )

        x_locations, y_locations = data_source_which_defines_geo_position.get_locations(
            sampled_t0_datetimes
        )
        return pd.DataFrame(
            {
                "t0_datetime_UTC": sampled_t0_datetimes,
                "x_center_OSGB": x_locations,
                "y_center_OSGB": y_locations,
            }
        )
14 changes: 0 additions & 14 deletions nowcasting_dataset/data_sources/datetime/datetime_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@
class DatetimeDataSource(DataSource):
""" Add hour_of_day_{sin, cos} and day_of_year_{sin, cos} features. """

def __post_init__(self):
""" Post init """
super().__post_init__()

def get_example(
self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number
) -> Datetime:
Expand All @@ -44,13 +40,3 @@ def get_example(
datetime_xr_dataset = make_dim_index(datetime_xr_dataset)

return Datetime(datetime_xr_dataset)

def get_locations_for_batch(
self, t0_datetimes: pd.DatetimeIndex
) -> Tuple[List[Number], List[Number]]:
""" This method is not needed for DatetimeDataSource """
raise NotImplementedError()

def datetime_index(self) -> pd.DatetimeIndex:
""" This method is not needed for DatetimeDataSource """
raise NotImplementedError()
10 changes: 4 additions & 6 deletions nowcasting_dataset/data_sources/gsp/gsp_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,15 @@ def datetime_index(self):
"""
return self.gsp_power.index

def get_locations_for_batch(
self, t0_datetimes: pd.DatetimeIndex
) -> Tuple[List[Number], List[Number]]:
def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]:
"""
Get x and y locations for a batch. Assume that all data is available for all GSP.
Get x and y locations. Assume that all data is available for all GSP.

Random GSP are taken, and the locations of them are returned. This is useful as other
datasources need to know which x,y locations to get.

Args:
t0_datetimes: list of datetimes that the batches locations have data for
t0_datetimes: list of available t0 datetimes.

Returns: list of x and y locations

Expand Down Expand Up @@ -266,7 +264,7 @@ def _get_central_gsp_id(
logger.debug("Getting Central GSP")

# If x_meters_center and y_meters_center have been chosen
# by {}.get_locations_for_batch() then we just have
# by {}.get_locations() then we just have
# to find the gsp_ids at that exact location. This is
# super-fast (a few hundred microseconds). We use np.isclose
# instead of the equality operator because floats.
Expand Down
16 changes: 2 additions & 14 deletions nowcasting_dataset/data_sources/metadata/metadata_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ class MetadataDataSource(DataSource):

object_at_center: str = "GSP"

def __post_init__(self):
"""Post init"""
super().__post_init__()

def get_example(
self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number
) -> Metadata:
Expand All @@ -44,6 +40,8 @@ def get_example(
else:
object_at_center_label = 0

# TODO: data_dict is unused in this function. Is that a bug?
# https://github.com/openclimatefix/nowcasting_dataset/issues/279
data_dict = dict(
t0_dt=to_numpy(t0_dt), #: Shape: [batch_size,]
x_meters_center=np.array(x_meters_center),
Expand All @@ -68,13 +66,3 @@ def get_example(
data[v] = getattr(d, v)

return Metadata(data)

def get_locations_for_batch(
self, t0_datetimes: pd.DatetimeIndex
) -> Tuple[List[Number], List[Number]]:
"""This method is not needed for MetadataDataSource"""
raise NotImplementedError()

def datetime_index(self) -> pd.DatetimeIndex:
"""This method is not needed for MetadataDataSource"""
raise NotImplementedError()
4 changes: 1 addition & 3 deletions nowcasting_dataset/data_sources/pv/pv_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,7 @@ def get_example(

return PV(pv)

def get_locations_for_batch(
self, t0_datetimes: pd.DatetimeIndex
) -> Tuple[List[Number], List[Number]]:
def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]:
"""Find a valid geographical location for each t0_datetime.

Returns: x_locations, y_locations. Each has one entry per t0_datetime.
Expand Down
5 changes: 1 addition & 4 deletions nowcasting_dataset/data_sources/sun/sun_data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,11 @@ def get_example(
return Sun(sun)

def _load(self):

self.azimuth, self.elevation = load_from_zarr(
filename=self.filename, start_dt=self.start_dt, end_dt=self.end_dt
)

def get_locations_for_batch(
self, t0_datetimes: pd.DatetimeIndex
) -> Tuple[List[Number], List[Number]]:
def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]:
""" Sun data should not be used to get batch locations """
raise NotImplementedError("Sun data should not be used to get batch locations")

Expand Down
47 changes: 15 additions & 32 deletions nowcasting_dataset/dataset/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

from nowcasting_dataset import consts
from nowcasting_dataset import data_sources
from nowcasting_dataset import time as nd_time
from nowcasting_dataset import utils
from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource
from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource
from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource
from nowcasting_dataset.dataset import datasets
from nowcasting_dataset.dataset.split.split import split_data, SplitMethod
from nowcasting_dataset.data_sources.data_source_list import DataSourceList


with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
Expand Down Expand Up @@ -207,6 +207,8 @@ def prepare_data(self) -> None:
)
)

self.data_sources = DataSourceList(self.data_sources)

def setup(self, stage="fit"):
"""Split data, etc.

Expand Down Expand Up @@ -309,14 +311,18 @@ def _split_data(self):
logger.debug("Going to split data")

self._check_has_prepared_data()
self.t0_datetimes = self._get_t0_datetimes()
self.t0_datetimes = self._get_t0_datetimes_across_all_data_sources()

logger.debug(f"Got all start times, there are {len(self.t0_datetimes):,d}")

self.train_t0_datetimes, self.val_t0_datetimes, self.test_t0_datetimes = split_data(
data_after_splitting = split_data(
datetimes=self.t0_datetimes, method=self.split_method, seed=self.seed
)

self.train_t0_datetimes = data_after_splitting.train
self.val_t0_datetimes = data_after_splitting.validation
self.test_t0_datetimes = data_after_splitting.test

logger.debug(
f"Split data done, train has {len(self.train_t0_datetimes):,d}, "
f"validation has {len(self.val_t0_datetimes):,d}, "
Expand Down Expand Up @@ -354,38 +360,15 @@ def _common_dataloader_params(self) -> Dict:
batch_sampler=None,
)

def _get_t0_datetimes(self) -> pd.DatetimeIndex:
"""
Compute the intersection of the t0 datetimes available across all DataSources.
def _get_t0_datetimes_across_all_data_sources(self) -> pd.DatetimeIndex:
"""See DataSourceList.get_t0_datetimes_across_all_data_sources.

Returns the valid t0 datetimes, taking into consideration all DataSources,
filtered by daylight hours (SatelliteDataSource.datetime_index() removes the night
datetimes).
This method will be deleted as part of implementing #213.
"""
logger.debug("Get the intersection of time periods across all DataSources.")
self._check_has_prepared_data()

# Get the intersection of t0 time periods from all data sources.
t0_time_periods_for_all_data_sources = []
for data_source in self.data_sources:
logger.debug(f"Getting t0 time periods for {type(data_source).__name__}")
try:
t0_time_periods = data_source.get_contiguous_t0_time_periods()
except NotImplementedError:
pass # Skip data_sources with no concept of time.
else:
t0_time_periods_for_all_data_sources.append(t0_time_periods)

intersection_of_t0_time_periods = nd_time.intersection_of_multiple_dataframes_of_periods(
t0_time_periods_for_all_data_sources
)

t0_datetimes = nd_time.time_periods_to_datetime_index(
time_periods=intersection_of_t0_time_periods, freq=self.t0_datetime_freq
return self.data_sources.get_t0_datetimes_across_all_data_sources(
freq=self.t0_datetime_freq
)

return t0_datetimes

def _check_has_prepared_data(self):
if not self.has_prepared_data:
raise RuntimeError("Must run prepare_data() first!")
8 changes: 3 additions & 5 deletions nowcasting_dataset/dataset/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def _get_batch(self) -> Batch:
return []

t0_datetimes = self._get_t0_datetimes_for_batch()
x_locations, y_locations = self._get_locations_for_batch(t0_datetimes)
x_locations, y_locations = self._get_locations(t0_datetimes)

examples = {}
n_threads = len(self.data_sources)
Expand Down Expand Up @@ -179,10 +179,8 @@ def _get_t0_datetimes_for_batch(self) -> pd.DatetimeIndex:
t0_datetimes = np.tile(t0_datetimes, reps=self.n_samples_per_timestep)
return pd.DatetimeIndex(t0_datetimes)

def _get_locations_for_batch(
self, t0_datetimes: pd.DatetimeIndex
) -> Tuple[List[Number], List[Number]]:
return self.data_sources[0].get_locations_for_batch(t0_datetimes)
def _get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]:
return self.data_sources[0].get_locations(t0_datetimes)


def worker_init_fn(worker_id):
Expand Down
25 changes: 23 additions & 2 deletions nowcasting_dataset/dataset/split/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from enum import Enum
from typing import List, Tuple, Union, Optional
from collections import namedtuple

import pandas as pd

Expand Down Expand Up @@ -30,6 +31,20 @@ class SplitMethod(Enum):
DAY_RANDOM_TEST_DATE = "day_random_test_date"


class SplitName(Enum):
    """The name for each data split.

    Each member's value is the lower-case string name of the split.
    """

    # The values below are also used as the field names of the SplitData namedtuple.
    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"


# Container for the three data splits; the fields are named after the SplitName values,
# in the order train, validation, test.
SplitData = namedtuple(
    "SplitData",
    [
        split_name.value
        for split_name in (SplitName.TRAIN, SplitName.VALIDATION, SplitName.TEST)
    ],
)


def split_data(
datetimes: Union[List[pd.Timestamp], pd.DatetimeIndex],
method: SplitMethod,
Expand All @@ -39,7 +54,7 @@ def split_data(
),
train_validation_test_datetime_split: Optional[List[pd.Timestamp]] = None,
seed: int = 1234,
) -> (List[pd.Timestamp], List[pd.Timestamp], List[pd.Timestamp]):
) -> SplitData:
"""
Split the date using various different methods

Expand Down Expand Up @@ -165,4 +180,10 @@ def split_data(
else:
raise ValueError(f"{method} for splitting day is not implemented")

return train_datetimes, validation_datetimes, test_datetimes
logger.debug(
f"Split data done, train has {len(train_datetimes):,d}, "
f"validation has {len(validation_datetimes):,d}, "
f"test has {len(test_datetimes):,d} t0 datetimes."
)

return SplitData(train=train_datetimes, validation=validation_datetimes, test=test_datetimes)
1 change: 1 addition & 0 deletions nowcasting_dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from nowcasting_dataset.consts import Array


logger = logging.getLogger(__name__)


Expand Down
Loading