Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,9 @@ input_data:
interval_start_minutes: -60
interval_end_minutes: 480
time_resolution_minutes: 30

# Note that each "cyclic" embedding adds 2 elements to the output vector to embed a period whilst
# "linear" adds only 1 element. This needs to be matched to the model `t0_embedding_dim` param.
# e.g. these embeddings would require t0_embedding_dim = 3 (1 + 2).
t0_embedding:
embeddings: [["1h", "linear"], ["24h", "cyclic"]]
1 change: 1 addition & 0 deletions configs.example/model/late_fusion.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ model:
embedding_dim: 16
include_sun: True
include_generation_history: False
t0_embedding_dim: 3

# The mapping between the location IDs and their embedding indices
location_id_mapping:
Expand Down
9 changes: 9 additions & 0 deletions pvnet/models/late_fusion/late_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
include_generation_history: bool = False,
include_sun: bool = True,
include_time: bool = False,
t0_embedding_dim: int = 0,
location_id_mapping: dict[Any, int] | None = None,
embedding_dim: int = 16,
forecast_minutes: int = 30,
Expand Down Expand Up @@ -85,6 +86,8 @@ def __init__(
include_generation_history: Include generation yield data.
include_sun: Include sun azimuth and altitude data.
include_time: Include sine and cosine of dates and times.
t0_embedding_dim: Shape of the embedding of the init-time (t0) of the forecast. Not used
if set to 0.
location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
not used if this is not provided.
embedding_dim: Number of embedding dimensions to use for location ID.
Expand Down Expand Up @@ -119,6 +122,7 @@ def __init__(
self.include_pv = pv_encoder is not None
self.include_sun = include_sun
self.include_time = include_time
self.t0_embedding_dim = t0_embedding_dim
self.location_id_mapping = location_id_mapping
self.embedding_dim = embedding_dim
self.add_image_embedding_channel = add_image_embedding_channel
Expand Down Expand Up @@ -246,6 +250,8 @@ def __init__(
# Update num features
fusion_input_features += 32

fusion_input_features += self.t0_embedding_dim

if include_generation_history:
# Update num features
fusion_input_features += self.history_len + 1
Expand Down Expand Up @@ -321,6 +327,9 @@ def forward(self, x: TensorBatch) -> torch.Tensor:
time = self.time_fc1(time)
modes["time"] = time

if self.t0_embedding_dim>0:
modes["t0_embed"] = x["t0_embedding"]

out = self.output_network(modes)

if self.use_quantile_regression:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ readme = {file="README.md", content-type="text/markdown"}
requires-python = ">=3.11,<3.14"

dependencies = [
"ocf-data-sampler>=0.6.0",
"ocf-data-sampler>=1.0.9",
"numpy",
"pandas",
"matplotlib",
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def raw_late_fusion_model_kwargs_generation_history(model_minutes_kwargs) -> dic
embedding_dim=None,
include_sun=False,
include_time=True,
t0_embedding_dim=3,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be my lack of understanding, but is it obvious how to know the t0_embedding_dim from the t0_embedding embedding (below)?

Copy link
Member Author

@dfulu dfulu Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well in the function docstring it explains that the embedding_dim parameter is for the location ID. We could rename embedding_dim->loc_embedding_dim or something similar to be more explicit. But for that we'd need to migrate all our production models

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry, I think I misunderstood your question. You mean how can you know what the t0_embedding_dim needs to be given t0_embedding config? That info is in data-sampler.

Basically:

t0_embedding_dim = sum(1 if kind == "linear" else 2 for _, kind in t0_embedding.embeddings)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah i see, could you put this comment in the docstrings? (Or somewhere else more suitable, if there is one)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like bloat here since it is explained in full in data-sampler

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a compromise is to put the explanation in the example config

include_generation_history=True,
forecast_minutes=480,
history_minutes=60,
Expand Down
3 changes: 3 additions & 0 deletions tests/test_data/data_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,6 @@ input_data:
interval_start_minutes: -60
interval_end_minutes: 480
time_resolution_minutes: 30

t0_embedding:
embeddings: [["1h", "linear"], ["24h", "cyclic"]]