Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions backend/kale/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,14 @@ class Artifact(NamedTuple):

from .compiler import Compiler
from .pipeline import Pipeline, PipelineConfig, VolumeConfig
from .processors import NotebookConfig, NotebookProcessor, PythonProcessor
from .processors import NotebookConfig, NotebookProcessor
from .step import Step, StepConfig

__all__ = [
"PipelineParam",
"Artifact",
"NotebookConfig",
"NotebookProcessor",
"PythonProcessor",
"Step",
"StepConfig",
"Pipeline",
Expand Down
16 changes: 6 additions & 10 deletions backend/kale/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@

log = logging.getLogger(__name__)

PY_FN_TEMPLATE = "py_function_template.jinja2"
NB_FN_TEMPLATE = "nb_function_template.jinja2"
PIPELINE_TEMPLATE = "pipeline_template.jinja2"
PIPELINE_ORIGIN = {"nb": NB_FN_TEMPLATE, "py": PY_FN_TEMPLATE}

KFP_DSL_ARTIFACT_IMPORTS = [
"Dataset",
Expand Down Expand Up @@ -111,7 +109,7 @@ def generate_dsl(self):
return pipeline_code

def generate_lightweight_component(self, step: Step):
"""Generate Python code using the function template."""
"""Generate Python code using the notebook function template."""
step_source_raw = step.source

def _encode_source(s):
Expand All @@ -120,14 +118,12 @@ def _encode_source(s):
[line.encode("unicode_escape").decode("utf-8") for line in s.splitlines()]
)

if self.pipeline.processor.id == "nb":
# Since the code will be wrapped in triple quotes inside the
# template, we need to escape triple quotes as they will not be
# escaped by encode("unicode_escape").
step.source = [re.sub(r"'''", "\\'\\'\\'", _encode_source(s)) for s in step_source_raw]
# Since the code will be wrapped in triple quotes inside the
# template, we need to escape triple quotes as they will not be
# escaped by encode("unicode_escape").
step.source = [re.sub(r"'''", "\\'\\'\\'", _encode_source(s)) for s in step_source_raw]

_template_filename = PIPELINE_ORIGIN.get(self.pipeline.processor.id)
template = self._get_templating_env().get_template(_template_filename)
template = self._get_templating_env().get_template(NB_FN_TEMPLATE)

# Separate parameters with and without defaults for proper ordering
params_without_defaults = [f"{step.name}_html_report: Output[HTML]"]
Expand Down
1 change: 0 additions & 1 deletion backend/kale/processors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,3 @@
# limitations under the License.

from .nbprocessor import NotebookConfig as NotebookConfig, NotebookProcessor as NotebookProcessor
from .pyprocessor import PythonProcessor as PythonProcessor
74 changes: 0 additions & 74 deletions backend/kale/processors/baseprocessor.py

This file was deleted.

57 changes: 49 additions & 8 deletions backend/kale/processors/nbprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import re
from typing import Any

import nbformat as nb

from kale.common import astutils, flakeutils, graphutils, utils
from kale.common import astutils, flakeutils, graphutils, kfutils, utils
from kale.config import Field
from kale.pipeline import PipelineConfig
from kale.pipeline import Pipeline, PipelineConfig
from kale.step import PipelineParam, Step

from .baseprocessor import BaseProcessor
log = logging.getLogger(__name__)

# fixme: Change the name of this key to `kale_metadata`
KALE_NB_METADATA_KEY = "kubeflow_notebook"
Expand Down Expand Up @@ -168,19 +169,25 @@ def _parse_steps_defaults(self, steps_defaults):
return result


class NotebookProcessor(BaseProcessor):
class NotebookProcessor:
"""Convert a Notebook to a Pipeline object."""

id = "nb"
config_cls = NotebookConfig
no_op_step = Step(name="no_op", source=[])

def __init__(self, nb_path: str, nb_metadata_overrides: dict[str, Any] | None = None, **kwargs):
def __init__(
self,
nb_path: str,
nb_metadata_overrides: dict[str, Any] | None = None,
config: NotebookConfig | None = None,
skip_validation: bool = False,
**kwargs,
):
"""Instantiate a new NotebookProcessor.

Args:
nb_path: Path to source notebook
nb_metadata_overrides: Override notebook config settings
config: Optional pre-built NotebookConfig
skip_validation: Set to True in order to skip the notebook's
metadata validation. This is useful in case the
NotebookProcessor is used to parse a part of the notebook
Expand All @@ -194,13 +201,47 @@ def __init__(self, nb_path: str, nb_metadata_overrides: dict[str, Any] | None =
nb_metadata.update({"notebook_path": nb_path})
if nb_metadata_overrides:
nb_metadata.update(nb_metadata_overrides)
super().__init__(**{**kwargs, **nb_metadata})

# Initialize config and pipeline (previously in BaseProcessor)
self.config = config
if not config and not skip_validation:
self.config = NotebookConfig(**{**kwargs, **nb_metadata})
self.pipeline = Pipeline(self.config) if self.config else None

def _read_notebook(self):
    """Read the source notebook from disk.

    Returns the parsed notebook node, preserving the on-disk nbformat
    version (``NO_CONVERT``).

    Raises:
        ValueError: if no file exists at ``self.nb_path``.
    """
    if os.path.exists(self.nb_path):
        return nb.read(self.nb_path, as_version=nb.NO_CONVERT)
    raise ValueError(f"NotebookProcessor could not find a notebook at path {self.nb_path}")

def run(self) -> Pipeline:
    """Process the notebook into a Pipeline object.

    Converts the annotated notebook into ``self.pipeline`` via
    ``to_pipeline()``, then applies the post-conversion hooks
    (``_post_pipeline``) before returning the result.

    Returns:
        Pipeline: the converted and post-processed pipeline.
    """
    self.to_pipeline()
    self._post_pipeline()
    return self.pipeline

def _post_pipeline(self):
    """Finalize the pipeline after conversion.

    Wires the processor back-reference onto the pipeline, then merges
    PodDefaults labels and pushes the step defaults onto every step.
    No-op when no pipeline was built.
    """
    if not self.pipeline:
        # Nothing was converted (e.g. config-less/validation-only use).
        return
    self.pipeline.processor = self
    self._configure_poddefaults()
    self._apply_steps_defaults()

def _configure_poddefaults(self):
    """Merge detected PodDefaults labels into the default step labels.

    Detection is best effort: if the PodDefaults lookup fails (e.g. not
    running inside a Kubeflow cluster) a warning is logged and the
    pipeline keeps only its pre-existing labels. Detected labels win
    over pre-existing ones with the same key.
    """
    try:
        detected = kfutils.find_poddefault_labels()
    except Exception as e:
        log.warning("Could not retrieve PodDefaults. Reason: %s", e)
        detected = {}
    steps_defaults = self.pipeline.config.steps_defaults
    merged = dict(steps_defaults.get("labels", {}))
    merged.update(detected)
    steps_defaults["labels"] = merged

def _apply_steps_defaults(self):
    """Push the pipeline-level ``steps_defaults`` into every step's config."""
    # Hoist the invariant lookup out of the loop; every step receives
    # the same defaults mapping.
    defaults = self.pipeline.config.steps_defaults
    for pipeline_step in self.pipeline.steps:
        pipeline_step.config.update(defaults)

def to_pipeline(self):
"""Convert an annotated Notebook to a Pipeline object."""
(pipeline_parameters_source, pipeline_metrics_source, imports_and_functions) = (
Expand Down
Loading
Loading