From d076bc4c75cc082ee56241554b1dfd28e16a59b6 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Fri, 16 Jan 2026 14:25:17 +0000 Subject: [PATCH] Add option to use t0 embedding features --- .../datamodule/configuration/example_configuration.yaml | 6 ++++++ configs.example/model/late_fusion.yaml | 1 + pvnet/models/late_fusion/late_fusion.py | 9 +++++++++ pyproject.toml | 2 +- tests/conftest.py | 1 + tests/test_data/data_config.yaml | 3 +++ 6 files changed, 21 insertions(+), 1 deletion(-) diff --git a/configs.example/datamodule/configuration/example_configuration.yaml b/configs.example/datamodule/configuration/example_configuration.yaml index 94894068..47502542 100644 --- a/configs.example/datamodule/configuration/example_configuration.yaml +++ b/configs.example/datamodule/configuration/example_configuration.yaml @@ -273,3 +273,9 @@ input_data: interval_start_minutes: -60 interval_end_minutes: 480 time_resolution_minutes: 30 + + # Note that each "cyclic" embedding adds 2 elements to the output vector to embed a period whilst + # "linear" adds only 1 element. This needs to be matched to the model `t0_embedding_dim` param. + # e.g. these embeddings would require t0_embedding_dim = 3 (1 + 2). 
+ t0_embedding: + embeddings: [["1h", "linear"], ["24h", "cyclic"]] diff --git a/configs.example/model/late_fusion.yaml b/configs.example/model/late_fusion.yaml index a46f6aa3..29b98c63 100644 --- a/configs.example/model/late_fusion.yaml +++ b/configs.example/model/late_fusion.yaml @@ -69,6 +69,7 @@ model: embedding_dim: 16 include_sun: True include_generation_history: False + t0_embedding_dim: 3 # The mapping between the location IDs and their embedding indices location_id_mapping: diff --git a/pvnet/models/late_fusion/late_fusion.py b/pvnet/models/late_fusion/late_fusion.py index 37c43eb6..2fc5078e 100644 --- a/pvnet/models/late_fusion/late_fusion.py +++ b/pvnet/models/late_fusion/late_fusion.py @@ -46,6 +46,7 @@ def __init__( include_generation_history: bool = False, include_sun: bool = True, include_time: bool = False, + t0_embedding_dim: int = 0, location_id_mapping: dict[Any, int] | None = None, embedding_dim: int = 16, forecast_minutes: int = 30, @@ -85,6 +86,8 @@ def __init__( include_generation_history: Include generation yield data. include_sun: Include sun azimuth and altitude data. include_time: Include sine and cosine of dates and times. + t0_embedding_dim: Number of features in the forecast init-time (t0) embedding. Not used + if set to 0. location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is not used if this is not provided. embedding_dim: Number of embedding dimensions to use for location ID. 
@@ -119,6 +122,7 @@ def __init__( self.include_pv = pv_encoder is not None self.include_sun = include_sun self.include_time = include_time + self.t0_embedding_dim = t0_embedding_dim self.location_id_mapping = location_id_mapping self.embedding_dim = embedding_dim self.add_image_embedding_channel = add_image_embedding_channel @@ -246,6 +250,8 @@ def __init__( # Update num features fusion_input_features += 32 + fusion_input_features += self.t0_embedding_dim + if include_generation_history: # Update num features fusion_input_features += self.history_len + 1 @@ -321,6 +327,9 @@ def forward(self, x: TensorBatch) -> torch.Tensor: time = self.time_fc1(time) modes["time"] = time + if self.t0_embedding_dim>0: + modes["t0_embed"] = x["t0_embedding"] + out = self.output_network(modes) if self.use_quantile_regression: diff --git a/pyproject.toml b/pyproject.toml index 2d3bb369..6391b117 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ readme = {file="README.md", content-type="text/markdown"} requires-python = ">=3.11,<3.14" dependencies = [ - "ocf-data-sampler>=0.6.0", + "ocf-data-sampler>=1.0.9", "numpy", "pandas", "matplotlib", diff --git a/tests/conftest.py b/tests/conftest.py index 86ae6783..853685a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -327,6 +327,7 @@ def raw_late_fusion_model_kwargs_generation_history(model_minutes_kwargs) -> dic embedding_dim=None, include_sun=False, include_time=True, + t0_embedding_dim=3, include_generation_history=True, forecast_minutes=480, history_minutes=60, diff --git a/tests/test_data/data_config.yaml b/tests/test_data/data_config.yaml index e6ddd731..e54049c3 100755 --- a/tests/test_data/data_config.yaml +++ b/tests/test_data/data_config.yaml @@ -125,3 +125,6 @@ input_data: interval_start_minutes: -60 interval_end_minutes: 480 time_resolution_minutes: 30 + + t0_embedding: + embeddings: [["1h", "linear"], ["24h", "cyclic"]]