diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index ef3a4083b9..2efceef69f 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -12,11 +12,13 @@ import dataclasses import functools import typing -from collections.abc import MutableMapping -from typing import Any, Callable, Generic, Protocol, TypeVar +from collections.abc import Hashable, MutableMapping +from typing import Any, Callable, Protocol, TypeVar from typing_extensions import Self +from gt4py.eve import utils + StartT = TypeVar("StartT") StartT_contra = TypeVar("StartT_contra", contravariant=True) @@ -62,6 +64,20 @@ class Workflow(Protocol[StartT_contra, EndT_co]): def __call__(self, inp: StartT_contra) -> EndT_co: ... +class StatefulWorkflow(Workflow[StartT_contra, EndT_co], Protocol): + """Protocol for stateful workflows whose state influences the outputs.""" + + @property + def workflow_state_id(self) -> Hashable: + """ + Hashable representation of the workflow state. + + This should be used to check if the workflow state has changed + to decide whether the output of the workflow should be recomputed. + """ + ... + + class ReplaceEnabledWorkflowMixin(Workflow[StartT_contra, EndT_co], Protocol): """ Subworkflow replacement mixin. 
@@ -155,6 +171,15 @@ def step_order(self) -> list[str]: step_names.append(field.name) return step_names + @functools.cached_property + def workflow_state_id(self) -> Hashable: + return utils.content_hash( + *( + getattr(getattr(self, step_name), "workflow_state_id", None) + for step_name in self.step_order + ) + ) + @dataclasses.dataclass(frozen=True) class MultiWorkflow( @@ -221,12 +246,17 @@ def chain(self, next_step: Workflow[EndT, NewEndT]) -> ChainableWorkflowMixin[St def start(cls, first_step: Workflow[StartT, EndT]) -> ChainableWorkflowMixin[StartT, EndT]: return cls(cls.__Steps((first_step,))) + @functools.cached_property + def workflow_state_id(self) -> Hashable: + return utils.content_hash( + *(getattr(step, "workflow_state_id", None) for step in self.steps.inner) + ) + @dataclasses.dataclass(frozen=True) class CachedStep( ChainableWorkflowMixin[StartT, EndT], ReplaceEnabledWorkflowMixin[StartT, EndT], - Generic[StartT, EndT, HashT], ): """ Cached workflow of single input callables. 
@@ -253,18 +283,27 @@ class CachedStep( """ step: Workflow[StartT, EndT] - hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment] - cache: MutableMapping[HashT, EndT] = dataclasses.field(repr=False, default_factory=dict) + hash_function: Callable[[StartT], Hashable] = dataclasses.field(default=hash) + cache: MutableMapping[Hashable, EndT] = dataclasses.field(repr=False, default_factory=dict) def __call__(self, inp: StartT) -> EndT: """Run the step only if the input is not cached, else return from cache.""" - hash_ = self.hash_function(inp) + hash_ = self.cache_key(inp) try: result = self.cache[hash_] except KeyError: result = self.cache[hash_] = self.step(inp) return result + def cache_key(self, inp: StartT) -> Hashable: + return utils.content_hash( + self.hash_function(inp), getattr(self.step, "workflow_state_id", None) + ) + + @functools.cached_property + def workflow_state_id(self) -> Hashable: + return getattr(self.step, "workflow_state_id", None) + @dataclasses.dataclass(frozen=True) class SkippableStep( @@ -279,3 +318,7 @@ def __call__(self, inp: StartT) -> EndT: def skip_condition(self, inp: StartT) -> bool: raise NotImplementedError() + + @functools.cached_property + def workflow_state_id(self) -> Hashable: + return getattr(self.step, "workflow_state_id", None) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 90441ec61a..788c2cdeff 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -16,7 +16,7 @@ import numpy as np from gt4py._core import definitions as core_defs -from gt4py.eve import codegen +from gt4py.eve import codegen, utils from gt4py.next import common from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir @@ -54,6 +54,16 @@ class GTFNTranslationStep( device_type: core_defs.DeviceType 
= core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str]] = None + @functools.cached_property + def workflow_state_id(self) -> str: + return utils.content_hash( + self.language_settings, + self.enable_itir_transforms, + self.use_imperative_backend, + self.device_type, + tuple(self.symbolic_domain_sizes.items()) if self.symbolic_domain_sizes else None, + ) + def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: case core_defs.DeviceType.CUDA: diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index 510c03e314..e7c3a5f6bd 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -147,7 +147,7 @@ def test_gtfn_file_cache(program_example): gpu=False, cached=True, otf_workflow__cached_translation=False ).executor.step.translation - cache_key = stages.fingerprint_compilable_program(compilable_program) + cache_key = cached_gtfn_translation_step.cache_key(compilable_program) # ensure the actual cached step in the backend generates the cache item for the test if cache_key in (translation_cache := cached_gtfn_translation_step.cache):