55 changes: 49 additions & 6 deletions src/gt4py/next/otf/workflow.py
@@ -12,11 +12,13 @@
import dataclasses
import functools
import typing
from collections.abc import MutableMapping
from typing import Any, Callable, Generic, Protocol, TypeVar
from collections.abc import Hashable, MutableMapping
from typing import Any, Callable, Protocol, TypeVar

from typing_extensions import Self

from gt4py.eve import utils


StartT = TypeVar("StartT")
StartT_contra = TypeVar("StartT_contra", contravariant=True)
@@ -62,6 +64,20 @@ class Workflow(Protocol[StartT_contra, EndT_co]):
def __call__(self, inp: StartT_contra) -> EndT_co: ...


class StatefulWorkflow(Workflow[StartT_contra, EndT_co], Protocol):
"""Protocol for stateful workflows whose state influence the outputs."""

@property
def workflow_state_id(self) -> Hashable:
"""
Hashable representation of the workflow state.

This should be used to check if the workflow state has changed
to decide whether the output of the workflow should be recomputed.
"""
...


class ReplaceEnabledWorkflowMixin(Workflow[StartT_contra, EndT_co], Protocol):
"""
Subworkflow replacement mixin.
@@ -155,6 +171,15 @@ def step_order(self) -> list[str]:
step_names.append(field.name)
return step_names

@functools.cached_property
def workflow_state_id(self) -> Hashable:
return utils.content_hash(
*(
getattr(getattr(self, step_name), "workflow_state_id", None)
for step_name in self.step_order
)
)


@dataclasses.dataclass(frozen=True)
class MultiWorkflow(
@@ -221,12 +246,17 @@ def chain(self, next_step: Workflow[EndT, NewEndT]) -> ChainableWorkflowMixin[St
def start(cls, first_step: Workflow[StartT, EndT]) -> ChainableWorkflowMixin[StartT, EndT]:
return cls(cls.__Steps((first_step,)))

@functools.cached_property
def workflow_state_id(self) -> Hashable:
return utils.content_hash(
*(getattr(step, "workflow_state_id", None) for step in self.steps.inner)
)


@dataclasses.dataclass(frozen=True)
class CachedStep(
ChainableWorkflowMixin[StartT, EndT],
ReplaceEnabledWorkflowMixin[StartT, EndT],
Generic[StartT, EndT, HashT],
):
"""
Cached workflow of single input callables.
@@ -253,18 +283,27 @@ class CachedStep(
"""

step: Workflow[StartT, EndT]
hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment]
cache: MutableMapping[HashT, EndT] = dataclasses.field(repr=False, default_factory=dict)
hash_function: Callable[[StartT], Hashable] = dataclasses.field(default=hash)
cache: MutableMapping[Hashable, EndT] = dataclasses.field(repr=False, default_factory=dict)

def __call__(self, inp: StartT) -> EndT:
"""Run the step only if the input is not cached, else return from cache."""
hash_ = self.hash_function(inp)
hash_ = self.cache_key(inp)
try:
result = self.cache[hash_]
except KeyError:
result = self.cache[hash_] = self.step(inp)
return result

def cache_key(self, inp: StartT) -> Hashable:
return utils.content_hash(
self.hash_function(inp), getattr(self.step, "workflow_state_id", None)
)

@functools.cached_property
def workflow_state_id(self) -> Hashable:
return getattr(self.step, "workflow_state_id", None)


@dataclasses.dataclass(frozen=True)
class SkippableStep(
@@ -279,3 +318,7 @@ def __call__(self, inp: StartT) -> EndT:

def skip_condition(self, inp: StartT) -> bool:
raise NotImplementedError()

@functools.cached_property
def workflow_state_id(self) -> Hashable:
return getattr(self.step, "workflow_state_id", None)
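
To make the new pieces concrete, here is a minimal sketch of how a stateful step can satisfy the `StatefulWorkflow` protocol and how `CachedStep.cache_key` folds that state into its lookup key. `UppercaseStep` and its `suffix` option are hypothetical, introduced only for illustration; the gt4py imports are the ones used in this diff.

```python
import dataclasses
import functools
from collections.abc import Hashable

from gt4py.eve import utils
from gt4py.next.otf import workflow


@dataclasses.dataclass(frozen=True)
class UppercaseStep:
    """Hypothetical step whose `suffix` option influences its output."""

    suffix: str = "!"

    def __call__(self, inp: str) -> str:
        return inp.upper() + self.suffix

    @functools.cached_property
    def workflow_state_id(self) -> Hashable:
        # Hashable representation of everything that affects the output.
        return utils.content_hash(self.suffix)


cached = workflow.CachedStep(step=UppercaseStep(suffix="!"))
assert cached("hello") == "HELLO!"

# The same input hashed under a differently configured step yields a different
# cache key, so results computed under the old configuration are not reused.
reconfigured = workflow.CachedStep(step=UppercaseStep(suffix="?"))
assert cached.cache_key("hello") != reconfigured.cache_key("hello")
```

Stateless steps need no changes: `getattr(step, "workflow_state_id", None)` falls back to `None`, so their cache keys depend only on the input hash, as before.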
12 changes: 11 additions & 1 deletion src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py
@@ -16,7 +16,7 @@
import numpy as np

from gt4py._core import definitions as core_defs
from gt4py.eve import codegen
from gt4py.eve import codegen, utils
from gt4py.next import common
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import ir as itir
@@ -54,6 +54,16 @@ class GTFNTranslationStep(
device_type: core_defs.DeviceType = core_defs.DeviceType.CPU
symbolic_domain_sizes: Optional[dict[str, str]] = None

@functools.cached_property
def workflow_state_id(self) -> str:
return utils.content_hash(
self.language_settings,
self.enable_itir_transforms,
self.use_imperative_backend,
self.device_type,
tuple(self.symbolic_domain_sizes.items()) if self.symbolic_domain_sizes else None,
)

def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings:
match self.device_type:
case core_defs.DeviceType.CUDA:
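
A rough sanity check of what this buys on the GTFN side: two translation steps that differ only in configuration now report different state ids, so a `CachedStep` wrapping them computes different cache keys for the same program. This is a sketch only; it assumes `GTFNTranslationStep` can be constructed with its defaults, as the existing backend factories do.

```python
import dataclasses

from gt4py.next.program_processors.codegens.gtfn import gtfn_module

default_step = gtfn_module.GTFNTranslationStep()
# Flip one of the options that feeds into workflow_state_id.
imperative_step = dataclasses.replace(default_step, use_imperative_backend=True)

# Different configuration -> different state id -> different cache keys downstream.
assert default_step.workflow_state_id != imperative_step.workflow_state_id
```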
@@ -147,7 +147,7 @@ def test_gtfn_file_cache(program_example):
gpu=False, cached=True, otf_workflow__cached_translation=False
).executor.step.translation

cache_key = stages.fingerprint_compilable_program(compilable_program)
cache_key = cached_gtfn_translation_step.cache_key(compilable_program)

# ensure the actual cached step in the backend generates the cache item for the test
if cache_key in (translation_cache := cached_gtfn_translation_step.cache):
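
The test now derives the expected cache key from the cached step itself rather than from the program fingerprint alone, since the key also encodes the translation step's `workflow_state_id`. Below is a small sketch of that relationship, mirroring `CachedStep.cache_key` from this PR; `expected_cache_key` is a hypothetical helper, and `cached_step` and `program` stand in for the fixtures used in the test.

```python
from gt4py.eve import utils


def expected_cache_key(cached_step, program):
    # Mirrors CachedStep.cache_key: the input hash combined with the wrapped
    # step's workflow_state_id (None for steps that declare no state).
    return utils.content_hash(
        cached_step.hash_function(program),
        getattr(cached_step.step, "workflow_state_id", None),
    )
```

With the previous `stages.fingerprint_compilable_program` key, entries written under one translation configuration would still be looked up after the configuration changed; keying on `cache_key` avoids that.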