From fc60d24c04cd9f4d8b93c7670c1a32fde57da6a6 Mon Sep 17 00:00:00 2001 From: Jack Kelly Date: Mon, 25 Oct 2021 11:56:04 +0100 Subject: [PATCH 1/8] Move out of NowcastingDataModule and into DataSourceList --- .../data_sources/data_source.py | 1 - nowcasting_dataset/dataset/datamodule.py | 44 +++++-------------- nowcasting_dataset/utils.py | 1 + tests/test_datamodule.py | 7 +-- 4 files changed, 13 insertions(+), 40 deletions(-) diff --git a/nowcasting_dataset/data_sources/data_source.py b/nowcasting_dataset/data_sources/data_source.py index cc1b55de..a788d9f5 100644 --- a/nowcasting_dataset/data_sources/data_source.py +++ b/nowcasting_dataset/data_sources/data_source.py @@ -5,7 +5,6 @@ from numbers import Number from typing import List, Tuple, Iterable -import numpy as np import pandas as pd import xarray as xr diff --git a/nowcasting_dataset/dataset/datamodule.py b/nowcasting_dataset/dataset/datamodule.py index fb667eec..05ba6ef6 100644 --- a/nowcasting_dataset/dataset/datamodule.py +++ b/nowcasting_dataset/dataset/datamodule.py @@ -4,19 +4,19 @@ from dataclasses import dataclass from pathlib import Path from typing import Union, Optional, Iterable, Dict, Callable - import pandas as pd + import torch from nowcasting_dataset import consts from nowcasting_dataset import data_sources -from nowcasting_dataset import time as nd_time -from nowcasting_dataset import utils from nowcasting_dataset.data_sources.gsp.gsp_data_source import GSPDataSource from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource from nowcasting_dataset.dataset import datasets from nowcasting_dataset.dataset.split.split import split_data, SplitMethod +from nowcasting_dataset.data_source_list import DataSourceList + with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -207,6 +207,8 @@ def prepare_data(self) -> None: ) ) + self.data_sources = DataSourceList(self.data_sources) + def setup(self, stage="fit"): """Split data, etc. @@ -309,7 +311,7 @@ def _split_data(self): logger.debug("Going to split data") self._check_has_prepared_data() - self.t0_datetimes = self._get_t0_datetimes() + self.t0_datetimes = self._get_t0_datetimes_across_all_data_sources() logger.debug(f"Got all start times, there are {len(self.t0_datetimes):,d}") @@ -354,38 +356,14 @@ def _common_dataloader_params(self) -> Dict: batch_sampler=None, ) - def _get_t0_datetimes(self) -> pd.DatetimeIndex: - """ - Compute the intersection of the t0 datetimes available across all DataSources. + def _get_t0_datetimes_across_all_data_sources(self) -> pd.DatetimeIndex: + """See DataSourceList.get_t0_datetimes_across_all_data_sources. - Returns the valid t0 datetimes, taking into consideration all DataSources, - filtered by daylight hours (SatelliteDataSource.datetime_index() removes the night - datetimes). - """ - logger.debug("Get the intersection of time periods across all DataSources.") - self._check_has_prepared_data() - - # Get the intersection of t0 time periods from all data sources. - t0_time_periods_for_all_data_sources = [] - for data_source in self.data_sources: - logger.debug(f"Getting t0 time periods for {type(data_source).__name__}") - try: - t0_time_periods = data_source.get_contiguous_t0_time_periods() - except NotImplementedError: - pass # Skip data_sources with no concept of time. - else: - t0_time_periods_for_all_data_sources.append(t0_time_periods) - - intersection_of_t0_time_periods = nd_time.intersection_of_multiple_dataframes_of_periods( - t0_time_periods_for_all_data_sources + This method will be deleted as part of implementing #213.""" + return self.data_sources.get_t0_datetimes_across_all_data_sources( + freq=self.t0_datetime_freq ) - t0_datetimes = nd_time.time_periods_to_datetime_index( - time_periods=intersection_of_t0_time_periods, freq=self.t0_datetime_freq - ) - - return t0_datetimes - def _check_has_prepared_data(self): if not self.has_prepared_data: raise RuntimeError("Must run prepare_data() first!") diff --git a/nowcasting_dataset/utils.py b/nowcasting_dataset/utils.py index 5ba4b258..4b6666e1 100644 --- a/nowcasting_dataset/utils.py +++ b/nowcasting_dataset/utils.py @@ -14,6 +14,7 @@ from nowcasting_dataset.consts import Array + logger = logging.getLogger(__name__) diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index 30068397..73682051 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd import pytest -import xarray as xr import nowcasting_dataset from nowcasting_dataset.config.load import load_yaml_configuration @@ -32,13 +31,9 @@ def test_prepare_data(nowcasting_datamodule: datamodule.NowcastingDataModule): def test_get_daylight_datetime_index( nowcasting_datamodule: datamodule.NowcastingDataModule, use_cloud_data: bool ): - # Check it throws RuntimeError if we try running - # _get_daylight_datetime_index() before running prepare_data(): - with pytest.raises(RuntimeError): - nowcasting_datamodule._get_t0_datetimes() nowcasting_datamodule.prepare_data() nowcasting_datamodule.t0_datetime_freq = "5T" - t0_datetimes = nowcasting_datamodule._get_t0_datetimes() + t0_datetimes = nowcasting_datamodule._get_t0_datetimes_across_all_data_sources() assert isinstance(t0_datetimes, pd.DatetimeIndex) if not use_cloud_data: # The testing sat_data.zarr has contiguous data from 12:05 to 18:00. From 65395bcb48ba8160c32a332c6d3bf8bb88a8bb3e Mon Sep 17 00:00:00 2001 From: Jack Kelly Date: Mon, 25 Oct 2021 12:02:58 +0100 Subject: [PATCH 2/8] Add data_source_list.py --- nowcasting_dataset/data_source_list.py | 80 ++++++++++++++++++++++++ nowcasting_dataset/dataset/datamodule.py | 2 +- 2 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 nowcasting_dataset/data_source_list.py diff --git a/nowcasting_dataset/data_source_list.py b/nowcasting_dataset/data_source_list.py new file mode 100644 index 00000000..1d943033 --- /dev/null +++ b/nowcasting_dataset/data_source_list.py @@ -0,0 +1,80 @@ +"""DataSourceList class.""" + +import pandas as pd +import logging + +import nowcasting_dataset.time as nd_time + +logger = logging.getLogger(__name__) + + +class DataSourceList(list): + """Hold a list of DataSource objects.""" + + def get_t0_datetimes_across_all_data_sources(self, freq: str) -> pd.DatetimeIndex: + """ + Compute the intersection of the t0 datetimes available across all `data_sources`. + + Returns the valid t0 datetimes, taking into consideration all DataSources, + filtered by daylight hours (SatelliteDataSource.datetime_index() removes the night + datetimes). + """ + logger.debug("Get the intersection of time periods across all DataSources.") + + # Get the intersection of t0 time periods from all data sources. + t0_time_periods_for_all_data_sources = [] + for data_source in self: + logger.debug(f"Getting t0 time periods for {type(data_source).__name__}") + try: + t0_time_periods = data_source.get_contiguous_t0_time_periods() + except NotImplementedError: + pass # Skip data_sources with no concept of time. + else: + t0_time_periods_for_all_data_sources.append(t0_time_periods) + + intersection_of_t0_time_periods = nd_time.intersection_of_multiple_dataframes_of_periods( + t0_time_periods_for_all_data_sources + ) + + t0_datetimes = nd_time.time_periods_to_datetime_index( + time_periods=intersection_of_t0_time_periods, freq=freq + ) + + return t0_datetimes + + """ + def compute_and_save_positions_of_each_example_of_each_split( + self, + split_method: SplitMethod, + n_examples_per_split: dict[SplitMethod, int], + dst_path: Path + ) -> None: + Computes the geospatial and temporal position of each training example. + + Finds the time periods available across all data_sources. + + Args: + data_sources: A list of DataSources. The first data_source is used to define the geospatial + location of each example. + split_method: The method used to split the available data into train, validation, and test. + n_examples_per_split: The number of examples requested for each split. + dst_path: The destination path. This is where the CSV files will be saved into. + CSV files will be saved into dst_path / split_method / 'positions_of_each_example.csv'. + + # Get intersection of all available t0_datetimes. Current done by NowcastingDataModule._get_datetimes(): + # github.com/openclimatefix/nowcasting_dataset/blob/main/nowcasting_dataset/dataset/datamodule.py#L364 + t0_datetimes_for_all_data_sources = [data_source.get_t0_datetimes() for data_source in data_sources] + intersection_of_t0_datetimes = nd_time.intersection_of_datetimeindexes(t0_datetimes_for_all_data_sources) + + # Split t0_datetimes into train, test and validation sets (being careful to ensure each group is + # at least `total_seq_duration` apart). Currently done by NowcastingDataModule._split_data(): + # github.com/openclimatefix/nowcasting_dataset/blob/main/nowcasting_dataset/dataset/datamodule.py#L315 + t0_datetimes_per_split: dict[SplitName, pd.DatetimeIndex] = split_datetimes( + intersection_of_t0_datetimes, method=split_method) + + for split_name, t0_datetimes_for_split in t0_datetimes_per_split.items(): + n_examples = n_examples_per_split[split_name] + positions = compute_positions_of_each_example(t0_datetimes_for_split, data_source, n_examples) + filename = dst_path / split_name / 'positions_of_each_example.csv' + positions.to_csv(filename) + """ diff --git a/nowcasting_dataset/dataset/datamodule.py b/nowcasting_dataset/dataset/datamodule.py index 05ba6ef6..cf64ff22 100644 --- a/nowcasting_dataset/dataset/datamodule.py +++ b/nowcasting_dataset/dataset/datamodule.py @@ -4,8 +4,8 @@ from dataclasses import dataclass from pathlib import Path from typing import Union, Optional, Iterable, Dict, Callable -import pandas as pd +import pandas as pd import torch from nowcasting_dataset import consts From 89db73357007da9cccdf37d6bbebd705d1701c14 Mon Sep 17 00:00:00 2001 From: Jack Kelly Date: Mon, 25 Oct 2021 13:13:16 +0100 Subject: [PATCH 3/8] split_data now returns a SplitData namedtuple. Tests pass --- nowcasting_dataset/data_source_list.py | 82 ++++++++++++------- .../data_sources/data_source.py | 7 +- .../datetime/datetime_data_source.py | 14 ---- .../data_sources/gsp/gsp_data_source.py | 10 +-- .../metadata/metadata_data_source.py | 16 +--- .../data_sources/pv/pv_data_source.py | 4 +- .../data_sources/sun/sun_data_source.py | 5 +- nowcasting_dataset/dataset/datamodule.py | 8 +- nowcasting_dataset/dataset/datasets.py | 8 +- nowcasting_dataset/dataset/split/split.py | 22 ++++- .../data_sources/gsp/test_gsp_data_source.py | 10 +-- tests/data_sources/test_pv_data_source.py | 6 +- 12 files changed, 96 insertions(+), 96 deletions(-) diff --git a/nowcasting_dataset/data_source_list.py b/nowcasting_dataset/data_source_list.py index 1d943033..8891b392 100644 --- a/nowcasting_dataset/data_source_list.py +++ b/nowcasting_dataset/data_source_list.py @@ -1,23 +1,34 @@ """DataSourceList class.""" +import numpy as np import pandas as pd import logging import nowcasting_dataset.time as nd_time +from nowcasting_dataset.dataset.split.split import SplitMethod, split_data, SplitName logger = logging.getLogger(__name__) class DataSourceList(list): - """Hold a list of DataSource objects.""" + """Hold a list of DataSource objects. + + The first DataSource in the list is used to compute the geospatial locations of each example. + """ def get_t0_datetimes_across_all_data_sources(self, freq: str) -> pd.DatetimeIndex: """ - Compute the intersection of the t0 datetimes available across all `data_sources`. + Compute the intersection of the t0 datetimes available across all DataSources. - Returns the valid t0 datetimes, taking into consideration all DataSources, - filtered by daylight hours (SatelliteDataSource.datetime_index() removes the night - datetimes). + Args: + freq: The Pandas frequency string. The returned DatetimeIndex will be at this frequency, + and every datetime will be aligned to this frequency. For example, if + freq='5 minutes' then every datetime will be at 00, 05, ..., 55 minutes + past the hour. + + Returns: Valid t0 datetimes, taking into consideration all DataSources, + filtered by daylight hours (SatelliteDataSource.datetime_index() removes the night + datetimes). """ logger.debug("Get the intersection of time periods across all DataSources.") @@ -42,39 +53,48 @@ def get_t0_datetimes_across_all_data_sources(self, freq: str) -> pd.DatetimeInde return t0_datetimes - """ - def compute_and_save_positions_of_each_example_of_each_split( - self, - split_method: SplitMethod, - n_examples_per_split: dict[SplitMethod, int], - dst_path: Path - ) -> None: + def sample_position_of_every_example_of_every_split( + self, + t0_datetimes: pd.DatetimeIndex, + split_method: SplitMethod, + n_examples_per_split: dict[SplitMethod, int], + ) -> dict[SplitName, pd.DataFrame]: + """ Computes the geospatial and temporal position of each training example. - Finds the time periods available across all data_sources. + Computes the intersection of the time periods available across all data_sources. + + The first data_source in this DataSourceList defines the geospatial locations of + each example. Args: - data_sources: A list of DataSources. The first data_source is used to define the geospatial - location of each example. - split_method: The method used to split the available data into train, validation, and test. + t0_datetimes: All available t0 datetimes. Can be computed with + `DataSourceList.get_t0_datetimes_across_all_data_sources()` + split_method: The method used to split data into train, validation, and test. n_examples_per_split: The number of examples requested for each split. - dst_path: The destination path. This is where the CSV files will be saved into. - CSV files will be saved into dst_path / split_method / 'positions_of_each_example.csv'. - # Get intersection of all available t0_datetimes. Current done by NowcastingDataModule._get_datetimes(): - # github.com/openclimatefix/nowcasting_dataset/blob/main/nowcasting_dataset/dataset/datamodule.py#L364 - t0_datetimes_for_all_data_sources = [data_source.get_t0_datetimes() for data_source in data_sources] - intersection_of_t0_datetimes = nd_time.intersection_of_datetimeindexes(t0_datetimes_for_all_data_sources) + Returns: + A dict where the keys are a SplitName, and the values are a pd.DataFrame. + Each row of each DataFrame specifies the position of each example, using + columns: 't0_datetime_UTC', 'x_center_OSGB', 'y_center_OSGB'. + """ + # Split t0_datetimes into train, test and validation sets. + t0_datetimes_per_split = split_data(datetimes=t0_datetimes, method=split_method) + t0_datetimes_per_split = t0_datetimes_per_split._asdict() - # Split t0_datetimes into train, test and validation sets (being careful to ensure each group is - # at least `total_seq_duration` apart). Currently done by NowcastingDataModule._split_data(): - # github.com/openclimatefix/nowcasting_dataset/blob/main/nowcasting_dataset/dataset/datamodule.py#L315 - t0_datetimes_per_split: dict[SplitName, pd.DatetimeIndex] = split_datetimes( - intersection_of_t0_datetimes, method=split_method) + data_source_which_defines_geo_position = self[0] + positions_per_split: dict[SplitName, pd.DataFrame] = {} for split_name, t0_datetimes_for_split in t0_datetimes_per_split.items(): n_examples = n_examples_per_split[split_name] - positions = compute_positions_of_each_example(t0_datetimes_for_split, data_source, n_examples) - filename = dst_path / split_name / 'positions_of_each_example.csv' - positions.to_csv(filename) - """ + shuffled_t0_datetimes = np.random.choice(t0_datetimes_for_split, shape=n_examples) + x_locations, y_locations = data_source_which_defines_geo_position.get_locations( + shuffled_t0_datetimes + ) + positions_per_split[split_name] = { + "t0_datetime_UTC": shuffled_t0_datetimes, + "x_center_OSGB": x_locations, + "y_center_OSGB": y_locations, + } + + return positions_per_split diff --git a/nowcasting_dataset/data_sources/data_source.py b/nowcasting_dataset/data_sources/data_source.py index a788d9f5..0b2cac13 100644 --- a/nowcasting_dataset/data_sources/data_source.py +++ b/nowcasting_dataset/data_sources/data_source.py @@ -180,13 +180,10 @@ def get_contiguous_time_periods(self) -> pd.DataFrame: max_gap_duration=self.sample_period_duration, ) - def get_locations_for_batch( - self, t0_datetimes: pd.DatetimeIndex - ) -> Tuple[List[Number], List[Number]]: + def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]: """Find a valid geographical locations for each t0_datetime. - Should be overridden by DataSources which may be used to define the locations - for each batch. + Should be overridden by DataSources which may be used to define the locations. Returns: x_locations, y_locations. Each has one entry per t0_datetime. Locations are in OSGB coordinates. diff --git a/nowcasting_dataset/data_sources/datetime/datetime_data_source.py b/nowcasting_dataset/data_sources/datetime/datetime_data_source.py index b6d900e3..e83fdab9 100644 --- a/nowcasting_dataset/data_sources/datetime/datetime_data_source.py +++ b/nowcasting_dataset/data_sources/datetime/datetime_data_source.py @@ -15,10 +15,6 @@ class DatetimeDataSource(DataSource): """ Add hour_of_day_{sin, cos} and day_of_year_{sin, cos} features. """ - def __post_init__(self): - """ Post init """ - super().__post_init__() - def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number ) -> Datetime: @@ -44,13 +40,3 @@ def get_example( datetime_xr_dataset = make_dim_index(datetime_xr_dataset) return Datetime(datetime_xr_dataset) - - def get_locations_for_batch( - self, t0_datetimes: pd.DatetimeIndex - ) -> Tuple[List[Number], List[Number]]: - """ This method is not needed for DatetimeDataSource """ - raise NotImplementedError() - - def datetime_index(self) -> pd.DatetimeIndex: - """ This method is not needed for DatetimeDataSource """ - raise NotImplementedError() diff --git a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py index e6ba08de..caf680e0 100644 --- a/nowcasting_dataset/data_sources/gsp/gsp_data_source.py +++ b/nowcasting_dataset/data_sources/gsp/gsp_data_source.py @@ -108,17 +108,15 @@ def datetime_index(self): """ return self.gsp_power.index - def get_locations_for_batch( - self, t0_datetimes: pd.DatetimeIndex - ) -> Tuple[List[Number], List[Number]]: + def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]: """ - Get x and y locations for a batch. Assume that all data is available for all GSP. + Get x and y locations. Assume that all data is available for all GSP. Random GSP are taken, and the locations of them are returned. This is useful as other datasources need to know which x,y locations to get. Args: - t0_datetimes: list of datetimes that the batches locations have data for + t0_datetimes: list of available t0 datetimes. Returns: list of x and y locations @@ -266,7 +264,7 @@ def _get_central_gsp_id( logger.debug("Getting Central GSP") # If x_meters_center and y_meters_center have been chosen - # by {}.get_locations_for_batch() then we just have + # by {}.get_locations() then we just have # to find the gsp_ids at that exact location. This is # super-fast (a few hundred microseconds). We use np.isclose # instead of the equality operator because floats. diff --git a/nowcasting_dataset/data_sources/metadata/metadata_data_source.py b/nowcasting_dataset/data_sources/metadata/metadata_data_source.py index efdd7ff6..acb58aab 100644 --- a/nowcasting_dataset/data_sources/metadata/metadata_data_source.py +++ b/nowcasting_dataset/data_sources/metadata/metadata_data_source.py @@ -19,10 +19,6 @@ class MetadataDataSource(DataSource): object_at_center: str = "GSP" - def __post_init__(self): - """Post init""" - super().__post_init__() - def get_example( self, t0_dt: pd.Timestamp, x_meters_center: Number, y_meters_center: Number ) -> Metadata: @@ -44,6 +40,8 @@ def get_example( else: object_at_center_label = 0 + # TODO: data_dict is unused in this function. Is that a bug? + # https://github.com/openclimatefix/nowcasting_dataset/issues/279 data_dict = dict( t0_dt=to_numpy(t0_dt), #: Shape: [batch_size,] x_meters_center=np.array(x_meters_center), @@ -68,13 +66,3 @@ def get_example( data[v] = getattr(d, v) return Metadata(data) - - def get_locations_for_batch( - self, t0_datetimes: pd.DatetimeIndex - ) -> Tuple[List[Number], List[Number]]: - """This method is not needed for MetadataDataSource""" - raise NotImplementedError() - - def datetime_index(self) -> pd.DatetimeIndex: - """This method is not needed for MetadataDataSource""" - raise NotImplementedError() diff --git a/nowcasting_dataset/data_sources/pv/pv_data_source.py b/nowcasting_dataset/data_sources/pv/pv_data_source.py index cba34fb9..54894cba 100644 --- a/nowcasting_dataset/data_sources/pv/pv_data_source.py +++ b/nowcasting_dataset/data_sources/pv/pv_data_source.py @@ -276,9 +276,7 @@ def get_example( return PV(pv) - def get_locations_for_batch( - self, t0_datetimes: pd.DatetimeIndex - ) -> Tuple[List[Number], List[Number]]: + def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]: """Find a valid geographical location for each t0_datetime. Returns: x_locations, y_locations. Each has one entry per t0_datetime. diff --git a/nowcasting_dataset/data_sources/sun/sun_data_source.py b/nowcasting_dataset/data_sources/sun/sun_data_source.py index 439f0150..af05839b 100644 --- a/nowcasting_dataset/data_sources/sun/sun_data_source.py +++ b/nowcasting_dataset/data_sources/sun/sun_data_source.py @@ -78,14 +78,11 @@ def get_example( return Sun(sun) def _load(self): - self.azimuth, self.elevation = load_from_zarr( filename=self.filename, start_dt=self.start_dt, end_dt=self.end_dt ) - def get_locations_for_batch( - self, t0_datetimes: pd.DatetimeIndex - ) -> Tuple[List[Number], List[Number]]: + def get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]: """ Sun data should not be used to get batch locations """ raise NotImplementedError("Sun data should not be used to get batch locations") diff --git a/nowcasting_dataset/dataset/datamodule.py b/nowcasting_dataset/dataset/datamodule.py index cf64ff22..64aedd3e 100644 --- a/nowcasting_dataset/dataset/datamodule.py +++ b/nowcasting_dataset/dataset/datamodule.py @@ -14,7 +14,7 @@ from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource from nowcasting_dataset.dataset import datasets -from nowcasting_dataset.dataset.split.split import split_data, SplitMethod +from nowcasting_dataset.dataset.split.split import split_data, SplitMethod, SplitName from nowcasting_dataset.data_source_list import DataSourceList @@ -315,10 +315,14 @@ def _split_data(self): logger.debug(f"Got all start times, there are {len(self.t0_datetimes):,d}") - self.train_t0_datetimes, self.val_t0_datetimes, self.test_t0_datetimes = split_data( + data_after_splitting = split_data( datetimes=self.t0_datetimes, method=self.split_method, seed=self.seed ) + self.train_t0_datetimes = data_after_splitting.train + self.val_t0_datetimes = data_after_splitting.validation + self.test_t0_datetimes = data_after_splitting.test + logger.debug( f"Split data done, train has {len(self.train_t0_datetimes):,d}, " f"validation has {len(self.val_t0_datetimes):,d}, " diff --git a/nowcasting_dataset/dataset/datasets.py b/nowcasting_dataset/dataset/datasets.py index 7fb438bd..ed874a11 100644 --- a/nowcasting_dataset/dataset/datasets.py +++ b/nowcasting_dataset/dataset/datasets.py @@ -141,7 +141,7 @@ def _get_batch(self) -> Batch: return [] t0_datetimes = self._get_t0_datetimes_for_batch() - x_locations, y_locations = self._get_locations_for_batch(t0_datetimes) + x_locations, y_locations = self._get_locations(t0_datetimes) examples = {} n_threads = len(self.data_sources) @@ -179,10 +179,8 @@ def _get_t0_datetimes_for_batch(self) -> pd.DatetimeIndex: t0_datetimes = np.tile(t0_datetimes, reps=self.n_samples_per_timestep) return pd.DatetimeIndex(t0_datetimes) - def _get_locations_for_batch( - self, t0_datetimes: pd.DatetimeIndex - ) -> Tuple[List[Number], List[Number]]: - return self.data_sources[0].get_locations_for_batch(t0_datetimes) + def _get_locations(self, t0_datetimes: pd.DatetimeIndex) -> Tuple[List[Number], List[Number]]: + return self.data_sources[0].get_locations(t0_datetimes) def worker_init_fn(worker_id): diff --git a/nowcasting_dataset/dataset/split/split.py b/nowcasting_dataset/dataset/split/split.py index defcc734..0d071c35 100644 --- a/nowcasting_dataset/dataset/split/split.py +++ b/nowcasting_dataset/dataset/split/split.py @@ -3,6 +3,7 @@ import logging from enum import Enum from typing import List, Tuple, Union, Optional +from collections import namedtuple import pandas as pd @@ -30,6 +31,17 @@ class SplitMethod(Enum): DAY_RANDOM_TEST_DATE = "day_random_test_date" +class SplitName(Enum): + """The name for each data split.""" + + TRAIN = "train" + VALIDATION = "validation" + TEST = "test" + + +SplitData = namedtuple(typename="SplitData", field_names=["train", "validation", "test"]) + + def split_data( datetimes: Union[List[pd.Timestamp], pd.DatetimeIndex], method: SplitMethod, @@ -39,7 +51,7 @@ def split_data( ), train_validation_test_datetime_split: Optional[List[pd.Timestamp]] = None, seed: int = 1234, -) -> (List[pd.Timestamp], List[pd.Timestamp], List[pd.Timestamp]): +) -> SplitData: """ Split the date using various different methods @@ -165,4 +177,10 @@ def split_data( else: raise ValueError(f"{method} for splitting day is not implemented") - return train_datetimes, validation_datetimes, test_datetimes + logger.debug( + f"Split data done, train has {len(train_datetimes):,d}, " + f"validation has {len(validation_datetimes):,d}, " + f"test has {len(test_datetimes):,d} t0 datetimes." + ) + + return SplitData(train=train_datetimes, validation=validation_datetimes, test=test_datetimes) diff --git a/tests/data_sources/gsp/test_gsp_data_source.py b/tests/data_sources/gsp/test_gsp_data_source.py index dac465ba..ae2dfb67 100644 --- a/tests/data_sources/gsp/test_gsp_data_source.py +++ b/tests/data_sources/gsp/test_gsp_data_source.py @@ -20,7 +20,7 @@ def test_gsp_pv_data_source_init(): ) -def test_gsp_pv_data_source_get_locations_for_batch(): +def test_gsp_pv_data_source_get_locations(): local_path = os.path.dirname(nowcasting_dataset.__file__) + "/.." gsp = GSPDataSource( @@ -33,7 +33,7 @@ def test_gsp_pv_data_source_get_locations_for_batch(): meters_per_pixel=2000, ) - locations_x, locations_y = gsp.get_locations_for_batch(t0_datetimes=gsp.gsp_power.index[0:10]) + locations_x, locations_y = gsp.get_locations(t0_datetimes=gsp.gsp_power.index[0:10]) assert len(locations_x) == len(locations_y) # This makes sure it is not in lat/lon. @@ -61,7 +61,7 @@ def test_gsp_pv_data_source_get_example(): meters_per_pixel=2000, ) - x_locations, y_locations = gsp.get_locations_for_batch(t0_datetimes=gsp.gsp_power.index[0:10]) + x_locations, y_locations = gsp.get_locations(t0_datetimes=gsp.gsp_power.index[0:10]) l = gsp.get_example( t0_dt=gsp.gsp_power.index[0], x_meters_center=x_locations[0], y_meters_center=y_locations[0] ) @@ -87,9 +87,7 @@ def test_gsp_pv_data_source_get_batch(): batch_size = 10 - x_locations, y_locations = gsp.get_locations_for_batch( - t0_datetimes=gsp.gsp_power.index[0:batch_size] - ) + x_locations, y_locations = gsp.get_locations(t0_datetimes=gsp.gsp_power.index[0:batch_size]) batch = gsp.get_batch( t0_datetimes=gsp.gsp_power.index[batch_size : 2 * batch_size], diff --git a/tests/data_sources/test_pv_data_source.py b/tests/data_sources/test_pv_data_source.py index ec548158..55904999 100644 --- a/tests/data_sources/test_pv_data_source.py +++ b/tests/data_sources/test_pv_data_source.py @@ -34,11 +34,9 @@ def test_get_example_and_batch(): load_from_gcs=False, ) - x_locations, y_locations = pv_data_source.get_locations_for_batch(pv_data_source.pv_power.index) + x_locations, y_locations = pv_data_source.get_locations(pv_data_source.pv_power.index) - example = pv_data_source.get_example( - pv_data_source.pv_power.index[0], x_locations[0], y_locations[0] - ) + _ = pv_data_source.get_example(pv_data_source.pv_power.index[0], x_locations[0], y_locations[0]) batch = pv_data_source.get_batch( pv_data_source.pv_power.index[6:11], x_locations[0:10], y_locations[0:10] From 2bf52bd96a0fed8ea5efa9c0414151f20d72e01a Mon Sep 17 00:00:00 2001 From: Jack Kelly Date: Mon, 25 Oct 2021 13:17:11 +0100 Subject: [PATCH 4/8] comments tweak. And return DataFrame not dict --- nowcasting_dataset/data_source_list.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/nowcasting_dataset/data_source_list.py b/nowcasting_dataset/data_source_list.py index 8891b392..1650e59e 100644 --- a/nowcasting_dataset/data_source_list.py +++ b/nowcasting_dataset/data_source_list.py @@ -62,8 +62,6 @@ def sample_position_of_every_example_of_every_split( """ Computes the geospatial and temporal position of each training example. - Computes the intersection of the time periods available across all data_sources. - The first data_source in this DataSourceList defines the geospatial locations of each example. @@ -91,10 +89,12 @@ def sample_position_of_every_example_of_every_split( x_locations, y_locations = data_source_which_defines_geo_position.get_locations( shuffled_t0_datetimes ) - positions_per_split[split_name] = { - "t0_datetime_UTC": shuffled_t0_datetimes, - "x_center_OSGB": x_locations, - "y_center_OSGB": y_locations, - } + positions_per_split[split_name] = pd.DataFrame( + { + "t0_datetime_UTC": shuffled_t0_datetimes, + "x_center_OSGB": x_locations, + "y_center_OSGB": y_locations, + } + ) return positions_per_split From 6956829940f2fee8f5233a15d63bce312ef02bc4 Mon Sep 17 00:00:00 2001 From: Jack Kelly Date: Mon, 25 Oct 2021 13:17:57 +0100 Subject: [PATCH 5/8] SplitName not SplitMethod --- nowcasting_dataset/data_source_list.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nowcasting_dataset/data_source_list.py b/nowcasting_dataset/data_source_list.py index 1650e59e..9107c74b 100644 --- a/nowcasting_dataset/data_source_list.py +++ b/nowcasting_dataset/data_source_list.py @@ -57,7 +57,7 @@ def sample_position_of_every_example_of_every_split( self, t0_datetimes: pd.DatetimeIndex, split_method: SplitMethod, - n_examples_per_split: dict[SplitMethod, int], + n_examples_per_split: dict[SplitName, int], ) -> dict[SplitName, pd.DataFrame]: """ Computes the geospatial and temporal position of each training example. From ec8b81171d02ab49952f5142b0e05a58ff9965dd Mon Sep 17 00:00:00 2001 From: Jack Kelly Date: Mon, 25 Oct 2021 13:59:03 +0100 Subject: [PATCH 6/8] fix linter error --- nowcasting_dataset/dataset/datamodule.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nowcasting_dataset/dataset/datamodule.py b/nowcasting_dataset/dataset/datamodule.py index 64aedd3e..73416491 100644 --- a/nowcasting_dataset/dataset/datamodule.py +++ b/nowcasting_dataset/dataset/datamodule.py @@ -363,7 +363,8 @@ def _common_dataloader_params(self) -> Dict: def _get_t0_datetimes_across_all_data_sources(self) -> pd.DatetimeIndex: """See DataSourceList.get_t0_datetimes_across_all_data_sources. - This method will be deleted as part of implementing #213.""" + This method will be deleted as part of implementing #213. + """ return self.data_sources.get_t0_datetimes_across_all_data_sources( freq=self.t0_datetime_freq ) From e6ef5218db207366fa9938945f46367340696744 Mon Sep 17 00:00:00 2001 From: Jack Kelly Date: Mon, 25 Oct 2021 17:53:11 +0100 Subject: [PATCH 7/8] Implemented test for DataSourceList.sample_spatial_and_temporal_positions_for_examples --- .../{ => data_sources}/data_source_list.py | 47 +++++++------------ nowcasting_dataset/dataset/datamodule.py | 4 +- nowcasting_dataset/dataset/split/split.py | 5 +- 3 files changed, 22 insertions(+), 34 deletions(-) rename nowcasting_dataset/{ => data_sources}/data_source_list.py (62%) diff --git a/nowcasting_dataset/data_source_list.py b/nowcasting_dataset/data_sources/data_source_list.py similarity index 62% rename from nowcasting_dataset/data_source_list.py rename to nowcasting_dataset/data_sources/data_source_list.py index 9107c74b..95153716 100644 --- a/nowcasting_dataset/data_source_list.py +++ b/nowcasting_dataset/data_sources/data_source_list.py @@ -53,12 +53,9 @@ def get_t0_datetimes_across_all_data_sources(self, freq: str) -> pd.DatetimeInde return t0_datetimes - def sample_position_of_every_example_of_every_split( - self, - t0_datetimes: pd.DatetimeIndex, - split_method: SplitMethod, - n_examples_per_split: dict[SplitName, int], - ) -> dict[SplitName, pd.DataFrame]: + def sample_spatial_and_temporal_positions_for_examples( + self, t0_datetimes: pd.DatetimeIndex, n_examples: int + ) -> pd.DataFrame: """ Computes the geospatial and temporal position of each training example. @@ -68,33 +65,21 @@ def sample_position_of_every_example_of_every_split( Args: t0_datetimes: All available t0 datetimes. Can be computed with `DataSourceList.get_t0_datetimes_across_all_data_sources()` - split_method: The method used to split data into train, validation, and test. - n_examples_per_split: The number of examples requested for each split. + n_examples: The number of examples requested. Returns: - A dict where the keys are a SplitName, and the values are a pd.DataFrame. - Each row of each DataFrame specifies the position of each example, using + Each row of each the DataFrame specifies the position of each example, using columns: 't0_datetime_UTC', 'x_center_OSGB', 'y_center_OSGB'. """ - # Split t0_datetimes into train, test and validation sets. - t0_datetimes_per_split = split_data(datetimes=t0_datetimes, method=split_method) - t0_datetimes_per_split = t0_datetimes_per_split._asdict() - data_source_which_defines_geo_position = self[0] - - positions_per_split: dict[SplitName, pd.DataFrame] = {} - for split_name, t0_datetimes_for_split in t0_datetimes_per_split.items(): - n_examples = n_examples_per_split[split_name] - shuffled_t0_datetimes = np.random.choice(t0_datetimes_for_split, shape=n_examples) - x_locations, y_locations = data_source_which_defines_geo_position.get_locations( - shuffled_t0_datetimes - ) - positions_per_split[split_name] = pd.DataFrame( - { - "t0_datetime_UTC": shuffled_t0_datetimes, - "x_center_OSGB": x_locations, - "y_center_OSGB": y_locations, - } - ) - - return positions_per_split + shuffled_t0_datetimes = np.random.choice(t0_datetimes, size=n_examples) + x_locations, y_locations = data_source_which_defines_geo_position.get_locations( + shuffled_t0_datetimes + ) + return pd.DataFrame( + { + "t0_datetime_UTC": shuffled_t0_datetimes, + "x_center_OSGB": x_locations, + "y_center_OSGB": y_locations, + } + ) diff --git a/nowcasting_dataset/dataset/datamodule.py b/nowcasting_dataset/dataset/datamodule.py index 73416491..af8bcef3 100644 --- a/nowcasting_dataset/dataset/datamodule.py +++ b/nowcasting_dataset/dataset/datamodule.py @@ -14,8 +14,8 @@ from nowcasting_dataset.data_sources.metadata.metadata_data_source import MetadataDataSource from nowcasting_dataset.data_sources.sun.sun_data_source import SunDataSource from nowcasting_dataset.dataset import datasets -from nowcasting_dataset.dataset.split.split import split_data, SplitMethod, SplitName -from nowcasting_dataset.data_source_list import DataSourceList +from nowcasting_dataset.dataset.split.split import split_data, SplitMethod +from nowcasting_dataset.data_sources.data_source_list import DataSourceList with warnings.catch_warnings(): diff --git a/nowcasting_dataset/dataset/split/split.py b/nowcasting_dataset/dataset/split/split.py index 0d071c35..beca4852 100644 --- a/nowcasting_dataset/dataset/split/split.py +++ b/nowcasting_dataset/dataset/split/split.py @@ -39,7 +39,10 @@ class SplitName(Enum): TEST = "test" -SplitData = namedtuple(typename="SplitData", field_names=["train", "validation", "test"]) +SplitData = namedtuple( + typename="SplitData", + field_names=[SplitName.TRAIN.value, SplitName.VALIDATION.value, SplitName.TEST.value], +) def split_data( From d9c07155440b1e40761fb14a1b6fd3c00f505f27 Mon Sep 17 00:00:00 2001 From: Jack Kelly Date: Mon, 25 Oct 2021 17:56:50 +0100 Subject: [PATCH 8/8] change name from position to location --- nowcasting_dataset/data_sources/data_source_list.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nowcasting_dataset/data_sources/data_source_list.py b/nowcasting_dataset/data_sources/data_source_list.py index 95153716..f5e579ad 100644 --- a/nowcasting_dataset/data_sources/data_source_list.py +++ b/nowcasting_dataset/data_sources/data_source_list.py @@ -53,11 +53,11 @@ def get_t0_datetimes_across_all_data_sources(self, freq: str) -> pd.DatetimeInde return t0_datetimes - def sample_spatial_and_temporal_positions_for_examples( + def sample_spatial_and_temporal_locations_for_examples( self, t0_datetimes: pd.DatetimeIndex, n_examples: int ) -> pd.DataFrame: """ - Computes the geospatial and temporal position of each training example. + Computes the geospatial and temporal locations for each training example. The first data_source in this DataSourceList defines the geospatial locations of each example.