diff --git a/docs/source/utilities/yaml_checker.rst b/docs/source/utilities/yaml_checker.rst index ed3f744b3..97e73d972 100644 --- a/docs/source/utilities/yaml_checker.rst +++ b/docs/source/utilities/yaml_checker.rst @@ -92,6 +92,8 @@ other fields directly under the :code:`parameters` field. * We check methods exist for the given module path. * We check that the parameters for each method are valid. For example, :code:`find_center_vo` method from :code:`tomopy.recon.rotation` takes :code:`ratio` as a parameter with a float value. If you pass a string instead, it will raise an error. Again the trick is to refer the documentation always. * We check the required parameters for each method are present. +* We check parameters that are omitted are not required (in particular, that omitted parameters + have a default value that can be assumed) * If you pass :code:`IN_DATA` (path to the data) along with the yaml config, as: .. code-block:: console diff --git a/httomo/yaml_checker.py b/httomo/yaml_checker.py index fe12b2e48..37fb1f399 100644 --- a/httomo/yaml_checker.py +++ b/httomo/yaml_checker.py @@ -220,6 +220,33 @@ def check_parameter_names_are_known(conf: PipelineConfig) -> bool: return True +def check_omitted_parameters_are_not_required(conf: PipelineConfig) -> bool: + """ + Check that any parameters omitted in method configs are not required. + + Notes + ----- + This check is functionally equivalent to checking that any parameters omitted in method + configs have default values. + """ + template_yaml_conf = _get_template_yaml_conf(conf) + for config, template in zip(conf, template_yaml_conf): + template_param_dict = template["parameters"] + config_params = set(config.get("parameters", {}).keys()) + template_params = set(template_param_dict.keys()) + omitted_params = template_params - config_params + + for param in omitted_params: + if template_param_dict[param] == "REQUIRED": + err_str = ( + f"The parameter '{param}' for '{config["method"]}' was omitted but is " + "required." + ) + _print_with_colour(err_str) + return False + return True + + def check_parameter_names_are_str(conf: PipelineConfig) -> bool: """Parameter names should be type string""" non_str_param_names = { @@ -465,6 +492,7 @@ def validate_yaml_config(yaml_file: Path, in_file: Optional[Path] = None) -> boo side_out_matches_ref_arg = check_side_out_matches_ref_arg(conf) required_keys_present = check_keys(conf) are_required_parameters_missing = check_no_required_parameter_values(conf) + are_omitted_params_required = check_omitted_parameters_are_not_required(conf) all_checks_pass = ( is_yaml_ok @@ -479,6 +507,7 @@ def validate_yaml_config(yaml_file: Path, in_file: Optional[Path] = None) -> boo and side_out_matches_ref_arg and required_keys_present and are_required_parameters_missing + and are_omitted_params_required ) if not all_checks_pass: diff --git a/tests/samples/pipeline_template_examples/testing/omitted_required_param.yaml b/tests/samples/pipeline_template_examples/testing/omitted_required_param.yaml new file mode 100644 index 000000000..1ec8f92f6 --- /dev/null +++ b/tests/samples/pipeline_template_examples/testing/omitted_required_param.yaml @@ -0,0 +1,11 @@ +- method: standard_tomo + module_path: httomo.data.hdf.loaders + parameters: + data_path: auto + image_key_path: auto + rotation_angles: auto +- method: distortion_correction_proj_discorpy + module_path: httomolibgpu.prep.alignment + parameters: + order: 3 + mode: constant diff --git a/tests/test_yaml_checker.py b/tests/test_yaml_checker.py index 20543a1f5..b7a719b24 100644 --- a/tests/test_yaml_checker.py +++ b/tests/test_yaml_checker.py @@ -12,6 +12,7 @@ check_first_method_is_loader, check_hdf5_paths_against_loader, check_methods_exist_in_templates, + check_omitted_parameters_are_not_required, check_parameter_names_are_known, check_parameter_names_are_str, check_no_required_parameter_values, @@ -120,6 +121,14 @@ def test_check_no_required_parameter_values(sample_pipelines: str, load_yaml: Ca assert not check_no_required_parameter_values(conf) +def test_check_omitted_parameters_are_not_required( + sample_pipelines: str, load_yaml: Callable +): + pipeline = sample_pipelines + "testing/omitted_required_param.yaml" + conf = load_yaml(pipeline) + assert not check_omitted_parameters_are_not_required(conf) + + def test_check_no_duplicated_keys(sample_pipelines: str, load_yaml: Callable): required_param_pipeline = sample_pipelines + "testing/duplicated_key.yaml" assert not check_no_duplicated_keys(Path(required_param_pipeline))