diff --git a/src/maison/config.py b/src/maison/config.py index d5f5243..2a5c947 100644 --- a/src/maison/config.py +++ b/src/maison/config.py @@ -1,23 +1,35 @@ """Module to hold the `UserConfig` class definition.""" -from functools import reduce -from pathlib import Path -from typing import Any -from typing import Optional -from typing import Protocol -from typing import Union +import pathlib +import typing -from maison.errors import NoSchemaError -from maison.utils import _collect_configs -from maison.utils import deep_merge +from maison import config_reader +from maison import config_validator as validator +from maison import disk_filesystem +from maison import errors +from maison import readers +from maison import service +from maison import types -class _IsSchema(Protocol): - """Protocol for config schemas.""" +def _bootstrap_service(package_name: str) -> service.ConfigService: + _config_reader = config_reader.ConfigReader() - def model_dump(self) -> dict[Any, Any]: - """Convert the validated config to a dict.""" - ... + pyproject_parser = readers.PyprojectReader(package_name=package_name) + toml_parser = readers.TomlReader() + ini_parser = readers.IniReader() + + _config_reader.register_parser( + suffix=".toml", parser=pyproject_parser, stem="pyproject" + ) + _config_reader.register_parser(suffix=".toml", parser=toml_parser) + _config_reader.register_parser(suffix=".ini", parser=ini_parser) + + return service.ConfigService( + filesystem=disk_filesystem.DiskFilesystem(), + config_reader=_config_reader, + validator=validator.Validator(), + ) class UserConfig: @@ -26,9 +38,9 @@ class UserConfig: def __init__( self, package_name: str, - starting_path: Optional[Path] = None, - source_files: Optional[list[str]] = None, - schema: Optional[type[_IsSchema]] = None, + starting_path: typing.Optional[pathlib.Path] = None, + source_files: typing.Optional[list[str]] = None, + schema: typing.Optional[type[service.IsSchema]] = None, merge_configs: bool = False, ) -> None: """Initialize the config. @@ -45,14 +57,21 @@ def __init__( merged if multiple are found """ self.source_files = source_files or ["pyproject.toml"] + self.starting_path = starting_path self.merge_configs = merge_configs - self._sources = _collect_configs( - package_name=package_name, + self._schema = schema + + self._service = _bootstrap_service(package_name=package_name) + + _sources = self._service.find_configs( source_files=self.source_files, starting_path=starting_path, ) - self._schema = schema - self._values = self._generate_config_dict() + + self._values = self._service.get_config_values( + config_file_paths=_sources, + merge_configs=merge_configs, + ) def __str__(self) -> str: """Return the __str__. @@ -63,7 +82,7 @@ def __str__(self) -> str: return f"" @property - def values(self) -> dict[str, Any]: + def values(self) -> types.ConfigValues: """Return the user's configuration values. Returns: @@ -72,21 +91,26 @@ def values(self) -> dict[str, Any]: return self._values @values.setter - def values(self, values: dict[str, Any]) -> None: + def values(self, values: types.ConfigValues) -> None: """Set the user's configuration values.""" self._values = values @property - def discovered_paths(self) -> list[Path]: + def discovered_paths(self) -> list[pathlib.Path]: """Return a list of the paths to the config sources found on the filesystem. Returns: a list of the paths to the config sources """ - return [source.filepath for source in self._sources] + return list( + self._service.find_configs( + source_files=self.source_files, + starting_path=self.starting_path, + ) + ) @property - def path(self) -> Optional[Union[Path, list[Path]]]: + def path(self) -> typing.Optional[typing.Union[pathlib.Path, list[pathlib.Path]]]: """Return the path to the selected config source. Returns: @@ -94,7 +118,7 @@ def path(self) -> Optional[Union[Path, list[Path]]]: sources if `merge_configs` is `True`, or the path to the active config source if `False` """ - if len(self._sources) == 0: + if len(self.discovered_paths) == 0: return None if self.merge_configs: @@ -103,7 +127,7 @@ def path(self) -> Optional[Union[Path, list[Path]]]: return self.discovered_paths[0] @property - def schema(self) -> Optional[type[_IsSchema]]: + def schema(self) -> typing.Optional[type[service.IsSchema]]: """Return the schema. Returns: @@ -112,15 +136,15 @@ def schema(self) -> Optional[type[_IsSchema]]: return self._schema @schema.setter - def schema(self, schema: type[_IsSchema]) -> None: + def schema(self, schema: type[service.IsSchema]) -> None: """Set the schema.""" self._schema = schema def validate( self, - schema: Optional[type[_IsSchema]] = None, + schema: typing.Optional[type[service.IsSchema]] = None, use_schema_values: bool = True, - ) -> dict[str, Any]: + ) -> types.ConfigValues: """Validate the configuration. Warning: @@ -153,32 +177,18 @@ class Schema(ConfigSchema): Raises: NoSchemaError: when validation is attempted but no schema has been provided """ - selected_schema: Union[type[_IsSchema], None] = schema or self.schema + selected_schema: typing.Union[type[service.IsSchema], None] = ( + schema or self.schema + ) if not selected_schema: - raise NoSchemaError + raise errors.NoSchemaError - validated_schema = selected_schema(**self.values) + validated_values = self._service.validate_config( + values=self.values, schema=selected_schema + ) if use_schema_values: - self.values = validated_schema.model_dump() + self.values = validated_values return self.values - - def _generate_config_dict(self) -> dict[str, Any]: - """Generate the config dict. - - If `merge_configs` is set to `False` then we use the first config. If `True` - then the dicts of the sources are merged from right to left. - - Returns: - the config dict - """ - if len(self._sources) == 0: - return {} - - if not self.merge_configs: - return self._sources[0].to_dict() - - source_dicts = (source.to_dict() for source in self._sources) - return reduce(lambda a, b: deep_merge(a, b), source_dicts) diff --git a/src/maison/config_reader.py b/src/maison/config_reader.py new file mode 100644 index 0000000..c8234ed --- /dev/null +++ b/src/maison/config_reader.py @@ -0,0 +1,42 @@ +import pathlib +import typing + +from maison import errors +from maison import types + + +ParserDictKey = tuple[str, typing.Union[str, None]] + + +class Parser(typing.Protocol): + def parse_config(self, file_path: pathlib.Path) -> types.ConfigValues: ... + + +class ConfigReader: + def __init__(self) -> None: + self._parsers: dict[ParserDictKey, Parser] = {} + + def register_parser( + self, + suffix: str, + parser: Parser, + stem: typing.Optional[str] = None, + ) -> None: + """Register a parser for a file suffix, optionally restricted by filename stem.""" + key = (suffix, stem) + self._parsers[key] = parser + + def parse_config(self, file_path: pathlib.Path) -> types.ConfigValues: + key: ParserDictKey + + # First try (suffix, stem) + key = (file_path.suffix, file_path.stem) + if key in self._parsers: + return self._parsers[key].parse_config(file_path) + + # Then fallback to (suffix, None) + key = (file_path.suffix, None) + if key in self._parsers: + return self._parsers[key].parse_config(file_path) + + raise errors.UnsupportedConfigError(f"No parser registered for {file_path}") diff --git a/src/maison/config_sources/__init__.py b/src/maison/config_sources/__init__.py deleted file mode 100644 index c15bc96..0000000 --- a/src/maison/config_sources/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Sources.""" diff --git a/src/maison/config_sources/base_source.py b/src/maison/config_sources/base_source.py deleted file mode 100644 index 462b836..0000000 --- a/src/maison/config_sources/base_source.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Module to hold the `BaseSource` abstract class definition.""" - -from abc import ABC -from abc import abstractmethod -from pathlib import Path -from typing import Any - - -class BaseSource(ABC): - """Base class from which concrete source abstractions extend.""" - - def __init__(self, filepath: Path, package_name: str) -> None: - """Initialize the object. - - Args: - filepath: the `Path` to the config file - package_name: the name of the package, used to pick out the relevant section - in a `.toml` or `.ini` file - """ - self.filepath = filepath - self.package_name = package_name - - def __repr__(self) -> str: - """Return the __repr__. - - Returns: - the representation - """ - return f"" - - def __str__(self) -> str: - """Return the __str__. - - Returns: - the representation - """ - return self.__repr__() - - @property - def filename(self) -> str: - """Return the filename. - - Returns: - the filename of the source - """ - return self.filepath.name - - @abstractmethod - def to_dict(self) -> dict[Any, Any]: - """Convert the source config file to a dict. - - Returns: - a dict of the config options and values - """ diff --git a/src/maison/config_sources/ini_source.py b/src/maison/config_sources/ini_source.py deleted file mode 100644 index bb2a2c1..0000000 --- a/src/maison/config_sources/ini_source.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Module to hold the `IniSource` class definition.""" - -from configparser import ConfigParser -from functools import lru_cache -from typing import Any - -from .base_source import BaseSource - - -class IniSource(BaseSource): - """Class to represent a `.ini` config source.""" - - def to_dict(self) -> dict[Any, Any]: - """Convert the source config file to a dict. - - Returns: - a dict of the config options and values - """ - config = self._load_file() - return {section: dict(config.items(section)) for section in config.sections()} - - @lru_cache - def _load_file(self) -> ConfigParser: - """Load the `.ini` file. - - Returns: - a `ConfigParser` object with the `.ini` source read into it - """ - config = ConfigParser() - config.read(self.filepath) - return config diff --git a/src/maison/config_sources/pyproject_source.py b/src/maison/config_sources/pyproject_source.py deleted file mode 100644 index 5dc2218..0000000 --- a/src/maison/config_sources/pyproject_source.py +++ /dev/null @@ -1,20 +0,0 @@ -"""Module to hold the `PyprojectSource` class definition.""" - -from typing import Any - -from .toml_source import TomlSource - - -class PyprojectSource(TomlSource): - """Class to represent a `pyproject.toml` config source.""" - - def to_dict(self) -> dict[Any, Any]: - """Convert the package `pyproject.toml` section to a dict. - - Relies on the convention that config related to package `acme` will be - located under a `[tool.acme]` section in `pyproject.toml` - - Returns: - a dict of the config options and values - """ - return dict(self._load_file().get("tool", {}).get(self.package_name, {})) diff --git a/src/maison/config_sources/toml_source.py b/src/maison/config_sources/toml_source.py deleted file mode 100644 index 20d806c..0000000 --- a/src/maison/config_sources/toml_source.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Module to hold the `TomlSource` class definition.""" - -from functools import lru_cache -from typing import Any - -import toml - -from ..errors import BadTomlError -from .base_source import BaseSource - - -class TomlSource(BaseSource): - """Class to represent a `.toml` config source.""" - - def to_dict(self) -> dict[Any, Any]: - """Convert the source config file to a dict. - - Returns: - a dict of the config options and values - """ - return self._load_file() - - @lru_cache - def _load_file(self) -> dict[Any, Any]: - """Load the `.toml` file. - - Returns: - the `.toml` source converted to a `dict` - - Raises: - BadTomlError: If toml cannot be parsed - """ - try: - return dict(toml.load(self.filepath)) - except toml.decoder.TomlDecodeError as exc: - raise BadTomlError( - f"Error trying to load toml file '{self.filepath}'" - ) from exc diff --git a/src/maison/config_validator.py b/src/maison/config_validator.py new file mode 100644 index 0000000..dfe60e8 --- /dev/null +++ b/src/maison/config_validator.py @@ -0,0 +1,10 @@ +from maison import service +from maison import types + + +class Validator: + def validate( + self, values: types.ConfigValues, schema: type[service.IsSchema] + ) -> types.ConfigValues: + validated_schema = schema(**values) + return validated_schema.model_dump() diff --git a/src/maison/disk_filesystem.py b/src/maison/disk_filesystem.py new file mode 100644 index 0000000..438b4ea --- /dev/null +++ b/src/maison/disk_filesystem.py @@ -0,0 +1,58 @@ +import functools +import pathlib +import typing +from collections.abc import Generator + + +def _path_contains_file(path: pathlib.Path, filename: str) -> bool: + """Determine whether a file exists in the given path. + + Args: + path: the path in which to search for the file + filename: the name of the file + + Returns: + A boolean to indicate whether the given file exists in the given path + """ + return (path / filename).is_file() + + +def _generate_search_paths( + starting_path: pathlib.Path, +) -> Generator[pathlib.Path, None, None]: + """Generate paths from a starting path and traversing up the tree. + + Args: + starting_path: a starting path to start yielding search paths + + Yields: + a path from the tree + """ + yield from [starting_path, *starting_path.parents] + + +class DiskFilesystem: + @functools.lru_cache + def get_file_path( + self, file_name: str, starting_path: typing.Optional[pathlib.Path] = None + ) -> typing.Optional[pathlib.Path]: + """Search for a file by traversing up the tree from a path. + + Args: + filename: the name of the file or an absolute path to a config to search for + starting_path: an optional path from which to start searching + + Returns: + The `Path` to the file if it exists or `None` if it doesn't + """ + filename_path = pathlib.Path(file_name).expanduser() + if filename_path.is_absolute() and filename_path.is_file(): + return filename_path + + start = starting_path or pathlib.Path.cwd() + + for path in _generate_search_paths(starting_path=start): + if _path_contains_file(path=path, filename=file_name): + return path / file_name + + return None diff --git a/src/maison/errors.py b/src/maison/errors.py index 02dfc0c..af937cb 100644 --- a/src/maison/errors.py +++ b/src/maison/errors.py @@ -7,3 +7,7 @@ class NoSchemaError(Exception): class BadTomlError(Exception): """Raised when loading from an invalid toml source is attempted.""" + + +class UnsupportedConfigError(Exception): + """Raised when a config is attempted to be parsed but no parser for it was registered.""" diff --git a/src/maison/readers/__init__.py b/src/maison/readers/__init__.py new file mode 100644 index 0000000..db5f51c --- /dev/null +++ b/src/maison/readers/__init__.py @@ -0,0 +1,6 @@ +from .ini import IniReader +from .pyproject import PyprojectReader +from .toml import TomlReader + + +__all__ = ["IniReader", "PyprojectReader", "TomlReader"] diff --git a/src/maison/readers/ini.py b/src/maison/readers/ini.py new file mode 100644 index 0000000..2278c17 --- /dev/null +++ b/src/maison/readers/ini.py @@ -0,0 +1,11 @@ +import configparser +import pathlib + +from maison import types + + +class IniReader: + def parse_config(self, file_path: pathlib.Path) -> types.ConfigValues: + config = configparser.ConfigParser() + config.read(file_path) + return {section: dict(config.items(section)) for section in config.sections()} diff --git a/src/maison/readers/pyproject.py b/src/maison/readers/pyproject.py new file mode 100644 index 0000000..6fb6cfa --- /dev/null +++ b/src/maison/readers/pyproject.py @@ -0,0 +1,17 @@ +import pathlib + +import toml + +from maison import types + + +class PyprojectReader: + def __init__(self, package_name: str) -> None: + self._package_name = package_name + + def parse_config(self, file_path: pathlib.Path) -> types.ConfigValues: + try: + pyproject_dict = dict(toml.load(file_path)) + except FileNotFoundError: + return {} + return dict(pyproject_dict.get("tool", {}).get(self._package_name, {})) diff --git a/src/maison/readers/toml.py b/src/maison/readers/toml.py new file mode 100644 index 0000000..382bf32 --- /dev/null +++ b/src/maison/readers/toml.py @@ -0,0 +1,13 @@ +import pathlib + +import toml + +from maison import types + + +class TomlReader: + def parse_config(self, file_path: pathlib.Path) -> types.ConfigValues: + try: + return dict(toml.load(file_path)) + except FileNotFoundError: + return {} diff --git a/src/maison/service.py b/src/maison/service.py new file mode 100644 index 0000000..b3ee070 --- /dev/null +++ b/src/maison/service.py @@ -0,0 +1,74 @@ +import pathlib +import typing +from collections.abc import Iterable + +from maison import types +from maison import utils + + +class IsSchema(typing.Protocol): + """Protocol for config schemas.""" + + def model_dump(self) -> types.ConfigValues: + """Convert the validated config to a dict.""" + ... + + +class Filesystem(typing.Protocol): + def get_file_path( + self, file_name: str, starting_path: typing.Optional[pathlib.Path] = None + ) -> typing.Optional[pathlib.Path]: ... + + +class ConfigReader(typing.Protocol): + def parse_config(self, file_path: pathlib.Path) -> types.ConfigValues: ... + + +class Validator(typing.Protocol): + def validate( + self, values: types.ConfigValues, schema: type[IsSchema] + ) -> types.ConfigValues: ... + + +class ConfigService: + def __init__( + self, + filesystem: Filesystem, + config_reader: ConfigReader, + validator: Validator, + ) -> None: + self.filesystem = filesystem + self.config_reader = config_reader + self.validator = validator + + def find_configs( + self, + source_files: list[str], + starting_path: typing.Optional[pathlib.Path] = None, + ) -> Iterable[pathlib.Path]: + for source in source_files: + if filepath := self.filesystem.get_file_path( + file_name=source, starting_path=starting_path + ): + yield filepath + + def get_config_values( + self, + config_file_paths: Iterable[pathlib.Path], + merge_configs: bool, + ) -> types.ConfigValues: + config_values: types.ConfigValues = {} + + for path in config_file_paths: + parsed_config = self.config_reader.parse_config(path) + config_values = utils.deep_merge(config_values, parsed_config) + + if not merge_configs: + break + + return config_values + + def validate_config( + self, values: types.ConfigValues, schema: type[IsSchema] + ) -> types.ConfigValues: + return self.validator.validate(values=values, schema=schema) diff --git a/src/maison/types.py b/src/maison/types.py new file mode 100644 index 0000000..f3dc421 --- /dev/null +++ b/src/maison/types.py @@ -0,0 +1,4 @@ +import typing + + +ConfigValues = dict[str, typing.Union[str, int, float, bool, None, "ConfigValues"]] diff --git a/src/maison/utils.py b/src/maison/utils.py index 4e1d94f..3a9b9e2 100644 --- a/src/maison/utils.py +++ b/src/maison/utils.py @@ -1,113 +1,11 @@ """Module to hold various utils.""" -from collections.abc import Generator -from pathlib import Path -from typing import Any -from typing import Optional +from maison import types -from maison.config_sources.base_source import BaseSource -from maison.config_sources.ini_source import IniSource -from maison.config_sources.pyproject_source import PyprojectSource -from maison.config_sources.toml_source import TomlSource - -def path_contains_file(path: Path, filename: str) -> bool: - """Determine whether a file exists in the given path. - - Args: - path: the path in which to search for the file - filename: the name of the file - - Returns: - A boolean to indicate whether the given file exists in the given path - """ - return (path / filename).is_file() - - -def get_file_path( - filename: str, starting_path: Optional[Path] = None -) -> Optional[Path]: - """Search for a file by traversing up the tree from a path. - - Args: - filename: the name of the file or an absolute path to a config to search for - starting_path: an optional path from which to start searching - - Returns: - The `Path` to the file if it exists or `None` if it doesn't - """ - filename_path = Path(filename).expanduser() - if filename_path.is_absolute() and filename_path.is_file(): - return filename_path - - start = starting_path or Path.cwd() - - for path in _generate_search_paths(starting_path=start): - if path_contains_file(path=path, filename=filename): - return path / filename - - return None - - -def _generate_search_paths(starting_path: Path) -> Generator[Path, None, None]: - """Generate paths from a starting path and traversing up the tree. - - Args: - starting_path: a starting path to start yielding search paths - - Yields: - a path from the tree - """ - yield from [starting_path, *starting_path.parents] - - -def _collect_configs( - package_name: str, - source_files: list[str], - starting_path: Optional[Path] = None, -) -> list[BaseSource]: - """Collect configs and return them in a list. - - Args: - package_name: the name of the package to be used to find the right section in - the config file - source_files: a list of source config filenames to look for. - starting_path: an optional starting path to start the search - - Returns: - a list of the found config sources - """ - sources: list[BaseSource] = [] - - for source in source_files: - file_path = get_file_path( - filename=source, - starting_path=starting_path, - ) - - if not file_path: - continue - - # dict[str, Any] to stop mypy complaining: - # https://github.com/python/mypy/issues/5382#issuecomment-583901369 - source_kwargs: dict[str, Any] = { - "filepath": file_path, - "package_name": package_name, - } - - if source.endswith("toml"): - if source.startswith("pyproject"): - sources.append(PyprojectSource(**source_kwargs)) - else: - sources.append(TomlSource(**source_kwargs)) - - if source.endswith("ini"): - sources.append(IniSource(**source_kwargs)) - - return sources - - -def deep_merge(destination: dict[Any, Any], source: dict[Any, Any]) -> dict[Any, Any]: +def deep_merge( + destination: types.ConfigValues, source: types.ConfigValues +) -> types.ConfigValues: """Recursively updates the destination dictionary. Usage example: @@ -141,7 +39,7 @@ def deep_merge(destination: dict[Any, Any], source: dict[Any, Any]) -> dict[Any, raise RuntimeError( f"Cannot merge dict '{src_value}' into type '{type(dest_node)}'" ) - deep_merge(dest_node, src_value) + _ = deep_merge(dest_node, src_value) else: destination[key] = src_value diff --git a/tests/acceptance_tests/test_config.py b/tests/acceptance_tests/test_config.py new file mode 100644 index 0000000..4d62b1a --- /dev/null +++ b/tests/acceptance_tests/test_config.py @@ -0,0 +1,74 @@ +import pathlib +import textwrap + +import pydantic +import pytest + +from maison import config + + +class TestConfig: + def test_gets_config_defaults_to_pyproject(self, tmp_path: pathlib.Path): + fp = tmp_path / "pyproject.toml" + content = textwrap.dedent(""" + [tool.acme] + hello = true + """) + _ = fp.write_text(content) + + cfg = config.UserConfig(package_name="acme", starting_path=tmp_path) + + assert cfg.values == {"hello": True} + assert cfg.discovered_paths == [fp] + assert cfg.path == fp + + def test_merges_configs(self, tmp_path: pathlib.Path): + pyproject_fp = tmp_path / "pyproject.toml" + content = textwrap.dedent(""" + [tool.acme] + hello = true + """) + _ = pyproject_fp.write_text(content) + + toml_fp = tmp_path / ".acme.toml" + content = textwrap.dedent(""" + goodbye = true + """) + _ = toml_fp.write_text(content) + + cfg = config.UserConfig( + package_name="acme", + source_files=["pyproject.toml", ".acme.toml"], + starting_path=tmp_path, + merge_configs=True, + ) + + assert cfg.values == {"hello": True, "goodbye": True} + assert cfg.discovered_paths == [pyproject_fp, toml_fp] + assert cfg.path == [pyproject_fp, toml_fp] + + +class TestValidation: + def test_validates_config(self, tmp_path: pathlib.Path): + fp = tmp_path / "pyproject.toml" + content = textwrap.dedent(""" + [tool.acme] + foo = "bar" + """) + _ = fp.write_text(content) + + class Schema(pydantic.BaseModel): + foo: int + + cfg = config.UserConfig( + package_name="acme", starting_path=tmp_path, schema=Schema + ) + + with pytest.raises(pydantic.ValidationError): + _ = cfg.validate() + + def test_raises_error_if_no_schema(self): + cfg = config.UserConfig(package_name="acme") + + with pytest.raises(errors.NoSchemaError): + _ = cfg.validate() diff --git a/tests/unit_tests/config_sources/__init__.py b/tests/integration_tests/readers/__init__.py similarity index 100% rename from tests/unit_tests/config_sources/__init__.py rename to tests/integration_tests/readers/__init__.py diff --git a/tests/integration_tests/readers/test_ini.py b/tests/integration_tests/readers/test_ini.py new file mode 100644 index 0000000..68a63b3 --- /dev/null +++ b/tests/integration_tests/readers/test_ini.py @@ -0,0 +1,91 @@ +import pathlib +import tempfile +import textwrap +import typing + +import pytest + +from maison.readers import ini + + +FileFactory = typing.Callable[[str], pathlib.Path] + + +@pytest.fixture +def tmp_ini_file() -> FileFactory: + """Helper to create a temporary ini file.""" + + def _create(content: str) -> pathlib.Path: + with tempfile.NamedTemporaryFile(mode="w+", suffix=".ini", delete=False) as tmp: + _ = tmp.write(content) + tmp.flush() + return pathlib.Path(tmp.name) + + return _create + + +class TestParseConfig: + def test_parse_single_section(self, tmp_ini_file: FileFactory): + ini_content = textwrap.dedent(""" + [database] + host = localhost + port = 5432 + """) + path = tmp_ini_file(ini_content) + + reader = ini.IniReader() + result = reader.parse_config(path) + + assert result == {"database": {"host": "localhost", "port": "5432"}} + + def test_parse_multiple_sections(self, tmp_ini_file: FileFactory): + ini_content = textwrap.dedent(""" + [database] + host = localhost + port = 5432 + + [api] + key = secret + endpoint = https://example.com + """) + path = tmp_ini_file(ini_content) + + reader = ini.IniReader() + result = reader.parse_config(path) + + assert result == { + "database": {"host": "localhost", "port": "5432"}, + "api": {"key": "secret", "endpoint": "https://example.com"}, + } + + def test_empty_file_returns_empty_dict(self, tmp_ini_file: FileFactory): + path = tmp_ini_file("") + + reader = ini.IniReader() + result = reader.parse_config(path) + + assert result == {} + + def test_missing_file_returns_empty_dict(self, tmp_path: pathlib.Path): + # configparser silently ignores missing files + path = tmp_path / "nonexistent.ini" + + reader = ini.IniReader() + result = reader.parse_config(path) + + assert result == {} + + def test_overlapping_keys_in_different_sections(self, tmp_ini_file: FileFactory): + ini_content = textwrap.dedent(""" + [section1] + key = value1 + + [section2] + key = value2 + """) + path = tmp_ini_file(ini_content) + + reader = ini.IniReader() + result = reader.parse_config(path) + + assert result == {"section1": {"key": "value1"}, "section2": {"key": "value2"}} diff --git a/tests/integration_tests/readers/test_pyproject.py b/tests/integration_tests/readers/test_pyproject.py new file mode 100644 index 0000000..84faafc --- /dev/null +++ b/tests/integration_tests/readers/test_pyproject.py @@ -0,0 +1,99 @@ +import pathlib +import tempfile +import textwrap +import typing + +import pytest + +from maison.readers import pyproject + + +FileFactory = typing.Callable[[str], pathlib.Path] + + +@pytest.fixture +def tmp_pyproject_file() -> FileFactory: + """Helper to create a temporary pyproject file.""" + + def _create(content: str) -> pathlib.Path: + with tempfile.NamedTemporaryFile( + mode="w+", suffix=".toml", delete=False + ) as tmp: + _ = tmp.write(content) + tmp.flush() + return pathlib.Path(tmp.name) + + return _create + + +class TestParseConfig: + def test_parse_tool_section_with_values(self, tmp_pyproject_file: FileFactory): + toml_content = textwrap.dedent(""" + [tool.myapp] + debug = true + retries = 3 + url = "https://example.com" + """) + path = tmp_pyproject_file(toml_content) + + reader = pyproject.PyprojectReader("myapp") + result = reader.parse_config(path) + + assert result == {"debug": True, "retries": 3, "url": "https://example.com"} + + def test_returns_empty_dict_if_package_section_missing( + self, tmp_pyproject_file: FileFactory + ): + toml_content = textwrap.dedent(""" + [tool.otherapp] + enabled = true + """) + path = tmp_pyproject_file(toml_content) + + reader = pyproject.PyprojectReader("myapp") + result = reader.parse_config(path) + + assert result == {} + + def test_returns_empty_dict_if_tool_table_missing( + self, tmp_pyproject_file: FileFactory + ): + toml_content = textwrap.dedent(""" + [build-system] + requires = ["setuptools"] + """) + path = tmp_pyproject_file(toml_content) + + reader = pyproject.PyprojectReader("myapp") + result = reader.parse_config(path) + + assert result == {} + + def test_parse_nested_values_inside_package(self, tmp_pyproject_file: FileFactory): + toml_content = textwrap.dedent(""" + [tool.myapp.database] + host = "localhost" + port = 5432 + """) + path = tmp_pyproject_file(toml_content) + + reader = pyproject.PyprojectReader("myapp") + result = reader.parse_config(path) + + assert result == {"database": {"host": "localhost", "port": 5432}} + + def test_empty_file_returns_empty_dict(self, tmp_pyproject_file: FileFactory): + path = tmp_pyproject_file("") + + reader = pyproject.PyprojectReader("myapp") + result = reader.parse_config(path) + + assert result == {} + + def test_missing_file_raises_file_not_found(self, tmp_path: pathlib.Path): + path = tmp_path / "no_such_pyproject.toml" + + reader = pyproject.PyprojectReader("myapp") + result = reader.parse_config(path) + + assert result == {} diff --git a/tests/integration_tests/readers/test_toml.py b/tests/integration_tests/readers/test_toml.py new file mode 100644 index 0000000..98c5efc --- /dev/null +++ b/tests/integration_tests/readers/test_toml.py @@ -0,0 +1,92 @@ +import pathlib +import tempfile +import textwrap +import typing + +import pytest + +from maison.readers import toml + + +FileFactory = typing.Callable[[str], pathlib.Path] + + +@pytest.fixture +def tmp_toml_file() -> FileFactory: + """Helper to create a temporary toml file.""" + + def _create(content: str) -> pathlib.Path: + with tempfile.NamedTemporaryFile( + mode="w+", suffix=".toml", delete=False + ) as tmp: + _ = tmp.write(content) + tmp.flush() + return pathlib.Path(tmp.name) + + return _create + + +class TestParseConfig: + def test_parse_single_section(self, tmp_toml_file: FileFactory): + toml_content = textwrap.dedent(""" + [database] + host = "localhost" + port = 5432 + """) + path = tmp_toml_file(toml_content) + + reader = toml.TomlReader() + result = reader.parse_config(path) + + assert result == {"database": {"host": "localhost", "port": 5432}} + + def test_parse_multiple_sections(self, tmp_toml_file: FileFactory): + toml_content = textwrap.dedent(""" + [database] + host = "localhost" + port = 5432 + + [api] + key = "secret" + endpoint = "https://example.com" + """) + path = tmp_toml_file(toml_content) + + reader = toml.TomlReader() + result = reader.parse_config(path) + + assert result == { + "database": {"host": "localhost", "port": 5432}, + "api": {"key": "secret", "endpoint": "https://example.com"}, + } + + def test_empty_file_returns_empty_dict(self, tmp_toml_file: FileFactory): + path = tmp_toml_file("") + + reader = toml.TomlReader() + result = reader.parse_config(path) + + assert result == {} + + def test_missing_file_returns_empty_dict(self, tmp_path: pathlib.Path): + path = tmp_path / "nonexistent.toml" + + reader = toml.TomlReader() + result = reader.parse_config(path) + + assert result == {} + + def test_overlapping_keys_in_different_sections(self, tmp_toml_file: FileFactory): + toml_content = textwrap.dedent(""" + [section1] + key = "value1" + + [section2] + key = "value2" + """) + path = tmp_toml_file(toml_content) + + reader = toml.TomlReader() + result = reader.parse_config(path) + + assert result == {"section1": {"key": "value1"}, "section2": {"key": "value2"}} diff --git a/tests/integration_tests/test_config.py b/tests/integration_tests/test_config.py new file mode 100644 index 0000000..6132bd8 --- /dev/null +++ b/tests/integration_tests/test_config.py @@ -0,0 +1,135 @@ +import pathlib +import textwrap + +import pytest + +from maison import config +from maison import errors +from maison import types + + +class TestUserConfig: + def test_str(self): + cfg = config.UserConfig(package_name="acme") + + assert str(cfg) == "" + + def test_values(self, tmp_path: pathlib.Path): + fp = tmp_path / "pyproject.toml" + content = textwrap.dedent(""" + [tool.acme] + hello = true + """) + _ = fp.write_text(content) + + cfg = config.UserConfig(package_name="acme", starting_path=tmp_path) + + assert cfg.values == {"hello": True} + + def test_values_setter(self): + cfg = config.UserConfig(package_name="acme") + + cfg.values = {"hello": True} + + assert cfg.values == {"hello": True} + + def test_discovered_paths(self, tmp_path: pathlib.Path): + fp = tmp_path / "pyproject.toml" + content = textwrap.dedent(""" + [tool.acme] + hello = true + """) + _ = fp.write_text(content) + + cfg = config.UserConfig(package_name="acme", starting_path=tmp_path) + + assert cfg.discovered_paths == [fp] + + def test_path_no_sources(self, tmp_path: pathlib.Path): + cfg = config.UserConfig(package_name="acme", starting_path=tmp_path) + + assert cfg.path is None + + def test_path_with_sources(self, tmp_path: pathlib.Path): + fp = tmp_path / "pyproject.toml" + content = textwrap.dedent(""" + [tool.acme] + hello = true + """) + _ = fp.write_text(content) + + cfg = config.UserConfig(package_name="acme", starting_path=tmp_path) + + assert cfg.path == fp + + def test_schema(self): + class Schema: + def model_dump(self): + return {} + + cfg = config.UserConfig(package_name="acme", schema=Schema) + + assert cfg.schema == Schema + + class NewSchema: + def model_dump(self): + return {} + + cfg.schema = NewSchema + assert cfg.schema == NewSchema + + +class TestValidate: + def test_no_schema(self): + cfg = config.UserConfig(package_name="acme") + + with pytest.raises(errors.NoSchemaError): + _ = cfg.validate() + + def test_validaes_config_without_using_schema_values(self, tmp_path: pathlib.Path): + fp = tmp_path / "pyproject.toml" + content = textwrap.dedent(""" + [tool.acme] + hello = true + """) + + _ = fp.write_text(content) + + class Schema: + def __init__(self, *args, **kwargs) -> None: + pass + + def model_dump(self) -> types.ConfigValues: + return {"key": "validated"} + + cfg = config.UserConfig( + package_name="acme", starting_path=tmp_path, schema=Schema + ) + + values = cfg.validate(use_schema_values=False) + + assert values == {"hello": True} + + def test_validaes_config_with_using_schema_values(self, tmp_path: pathlib.Path): + fp = tmp_path / "pyproject.toml" + content = textwrap.dedent(""" + [tool.acme] + hello = true + """) + + _ = fp.write_text(content) + + class Schema: + def __init__(self, *args, **kwargs) -> None: + pass + + def model_dump(self) -> types.ConfigValues: + return {"key": "validated"} + + cfg = config.UserConfig( + package_name="acme", starting_path=tmp_path, schema=Schema + ) + + values = cfg.validate(use_schema_values=True) + + assert values == {"key": "validated"} diff --git a/tests/integration_tests/test_disk_filesystem.py b/tests/integration_tests/test_disk_filesystem.py new file mode 100644 index 0000000..f274022 --- /dev/null +++ b/tests/integration_tests/test_disk_filesystem.py @@ -0,0 +1,45 @@ +import pathlib + +from maison import disk_filesystem + + +class TestGetFilePath: + def test_get_file_path_finds_in_current_dir(self, tmp_path: pathlib.Path): + fs = disk_filesystem.DiskFilesystem() + + file = tmp_path / "settings.json" + _ = file.write_text("{}") + + result = fs.get_file_path("settings.json", starting_path=tmp_path) + + assert result == file + + def test_get_file_path_traverses_up(self, tmp_path: pathlib.Path): + fs = disk_filesystem.DiskFilesystem() + + nested = tmp_path / "a" / "b" + nested.mkdir(parents=True) + + file = tmp_path / "a" / "target.txt" + _ = file.write_text("found me") + + result = fs.get_file_path("target.txt", starting_path=nested) + + assert result == file + + def test_get_file_path_with_absolute_path(self, tmp_path: pathlib.Path): + fs = disk_filesystem.DiskFilesystem() + + file = tmp_path / "absolute.txt" + _ = file.write_text("hello") + + result = fs.get_file_path(str(file)) + + assert result == file + + def test_get_file_path_returns_none_if_not_found(self): + fs = disk_filesystem.DiskFilesystem() + + result = fs.get_file_path("ghost.ini") + + assert result is None diff --git a/tests/unit_tests/config_sources/test_base_source.py b/tests/unit_tests/config_sources/test_base_source.py deleted file mode 100644 index ea18276..0000000 --- a/tests/unit_tests/config_sources/test_base_source.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Tests for the `BaseSource` class.""" - -from pathlib import Path -from typing import Any -from typing import Callable - -from maison.config_sources.base_source import BaseSource - - -class ConcreteSource(BaseSource): - """Concretion of `BaseSource` for testing purposes""" - - def to_dict(self) -> dict[Any, Any]: - """Return a dict.""" - return {} - - -class TestRepr: - """Tests for the `__repr__` method.""" - - def test_contains_path(self) -> None: - source = ConcreteSource(filepath=Path("~/file.txt"), package_name="acme") - - assert "file.txt" in repr(source) - assert str(source) == repr(source) - - -class TestFilename: - """Tests for the `filename` property.""" - - def test_success(self, create_tmp_file: Callable[..., Path]) -> None: - path_to_file = create_tmp_file(filename="file.txt") - - source = ConcreteSource(filepath=path_to_file, package_name="acme") - - assert source.filename == "file.txt" diff --git a/tests/unit_tests/config_sources/test_ini_source.py b/tests/unit_tests/config_sources/test_ini_source.py deleted file mode 100644 index 822d5e8..0000000 --- a/tests/unit_tests/config_sources/test_ini_source.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Tests for the `IniSource` class.""" - -from pathlib import Path -from typing import Callable - -from maison.config_sources.ini_source import IniSource - - -class TestToDict: - """Tests for the `to_dict` method.""" - - def test_success(self, create_tmp_file: Callable[..., Path]) -> None: - """A `.ini` file is converted to a `dict`""" - ini_file = """ -[section 1] -option_1 = value_1 - -[section 2] -option_2 = value_2 - """ - ini_path = create_tmp_file(content=ini_file, filename="foo.ini") - - toml_source = IniSource(filepath=ini_path, package_name="acme") - - assert toml_source.to_dict() == { - "section 1": {"option_1": "value_1"}, - "section 2": {"option_2": "value_2"}, - } - - def test_empty_file(self, create_tmp_file: Callable[..., Path]) -> None: - """Empty `.ini` returns an empty dict""" - ini_path = create_tmp_file(filename="foo.ini") - - toml_source = IniSource(filepath=ini_path, package_name="acme") - - assert toml_source.to_dict() == {} diff --git a/tests/unit_tests/config_sources/test_pyproject_source.py b/tests/unit_tests/config_sources/test_pyproject_source.py deleted file mode 100644 index 948bc74..0000000 --- a/tests/unit_tests/config_sources/test_pyproject_source.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Tests for the `PyprojectSource` class.""" - -from pathlib import Path -from typing import Callable - -from maison.config_sources.pyproject_source import PyprojectSource - - -class TestToDict: - """Tests for the `to_dict` method.""" - - def test_success(self, create_pyproject_toml: Callable[..., Path]) -> None: - pyproject_path = create_pyproject_toml() - - pyproject_source = PyprojectSource(filepath=pyproject_path, package_name="foo") - - assert pyproject_source.to_dict() == {"bar": "baz"} - - def test_unrecognised_section_name( - self, create_pyproject_toml: Callable[..., Path] - ) -> None: - """An empty dict is returned if the package name is not found""" - pyproject_path = create_pyproject_toml(section_name="foo") - - pyproject_source = PyprojectSource(filepath=pyproject_path, package_name="bar") - - assert pyproject_source.to_dict() == {} - - def test_unrecognised_format(self, create_toml: Callable[..., Path]) -> None: - """An unrecognised format or pyproject.toml returns an empty dict""" - pyproject_path = create_toml(filename="pyproject.toml", content={"foo": "bar"}) - - pyproject_source = PyprojectSource(filepath=pyproject_path, package_name="baz") - - assert pyproject_source.to_dict() == {} diff --git a/tests/unit_tests/config_sources/test_toml_source.py b/tests/unit_tests/config_sources/test_toml_source.py deleted file mode 100644 index 7f46083..0000000 --- a/tests/unit_tests/config_sources/test_toml_source.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Tests for the `TomlSource` class.""" - -import re -from pathlib import Path -from textwrap import dedent -from typing import Callable - -import pytest - -from maison.config_sources.toml_source import TomlSource -from maison.errors import BadTomlError - - -class TestToDict: - """Tests for the `to_dict` method.""" - - def test_success(self, create_toml: Callable[..., Path]) -> None: - """A `.toml` is loaded and converted to a `dict`""" - toml_path = create_toml(filename="config.toml", content={"foo": "bar"}) - - toml_source = TomlSource(filepath=toml_path, package_name="acme") - - assert toml_source.to_dict() == {"foo": "bar"} - - def test_toml_decode_error(self, create_toml: Callable[..., Path]) -> None: - """Toml decoding errors are reported""" - toml_path = create_toml(filename="config.toml") - toml_path.write_text( - dedent( - """ - "foo" = "bar" - "foo" = "bar" - """ - ), - encoding="utf-8", - ) - - toml_source = TomlSource(filepath=toml_path, package_name="acme") - - error_regex = re.escape(f"Error trying to load toml file '{toml_path!s}'") - with pytest.raises(BadTomlError, match=error_regex): - toml_source.to_dict() diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 340618f..066426a 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -1,52 +1 @@ """Store the classes and fixtures used throughout the tests.""" - -from pathlib import Path -from typing import Any -from typing import Callable -from typing import Optional - -import pytest -import toml - - -@pytest.fixture(name="create_tmp_file") -def create_tmp_file_fixture(tmp_path: Path) -> Callable[..., Path]: - """Fixture for creating a temporary file.""" - - def _create_tmp_file(content: str = "", filename: str = "file.txt") -> Path: - tmp_file = tmp_path / filename - tmp_file.write_text(content) - return tmp_file - - return _create_tmp_file - - -@pytest.fixture(name="create_toml") -def create_toml_fixture(create_tmp_file: Callable[..., Path]) -> Callable[..., Path]: - """Fixture for creating a `.toml` file.""" - - def _create_toml( - filename: str, - content: Optional[dict[str, Any]] = None, - ) -> Path: - content = content or {} - config_toml = toml.dumps(content) - return create_tmp_file(content=config_toml, filename=filename) - - return _create_toml - - -@pytest.fixture -def create_pyproject_toml(create_toml: Callable[..., Path]) -> Callable[..., Path]: - """Fixture for creating a `pyproject.toml`.""" - - def _create_pyproject_toml( - section_name: str = "foo", - content: Optional[dict[str, Any]] = None, - filename: str = "pyproject.toml", - ) -> Path: - content = content or {"bar": "baz"} - config_dict = {"tool": {section_name: content}} - return create_toml(filename=filename, content=config_dict) - - return _create_pyproject_toml diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py deleted file mode 100644 index 9b02c6c..0000000 --- a/tests/unit_tests/test_config.py +++ /dev/null @@ -1,402 +0,0 @@ -"""Tests for the `Config` classes.""" - -from pathlib import Path -from typing import Callable - -import pytest -from pydantic import BaseModel -from pydantic import ConfigDict -from pydantic import ValidationError - -from maison.config import UserConfig -from maison.errors import NoSchemaError - - -class TestUserConfig: - """Tests for the `UserConfig` class.""" - - def test_str(self, create_tmp_file: Callable[..., Path]) -> None: - pyproject_path = create_tmp_file(filename="pyproject.toml") - - config = UserConfig(package_name="foo", starting_path=pyproject_path) - - assert str(config) == "" - - -class TestDictObject: - """Tests to ensure that the config is accessible as a dict.""" - - def test_valid_pyproject(self, create_pyproject_toml: Callable[..., Path]) -> None: - """A valid pyproject is parsed to a dict object.""" - pyproject_path = create_pyproject_toml() - - config = UserConfig(package_name="foo", starting_path=pyproject_path) - - assert config.values == {"bar": "baz"} - - -class TestSourceFiles: - """Tests for the `source_files` init argument.""" - - def test_not_found(self) -> None: - """Non existent source files are handled.""" - config = UserConfig(package_name="foo", source_files=["foo"]) - - assert config.path is None - assert config.values == {} - - def test_unrecognised_file_extension( - self, - create_tmp_file: Callable[..., Path], - ) -> None: - """Unrecognised source file extensions are handled.""" - source_path = create_tmp_file(filename="foo.txt") - config = UserConfig( - package_name="foo", - source_files=["foo.txt"], - starting_path=source_path, - ) - - assert config.path is None - assert config.values == {} - - def test_single_valid_toml_source(self, create_toml: Callable[..., Path]) -> None: - """Toml files other than pyproject.toml files are handled.""" - source_path = create_toml(filename="another.toml", content={"bar": "baz"}) - - config = UserConfig( - package_name="foo", - starting_path=source_path, - source_files=["another.toml"], - ) - - assert config.path == source_path - assert config.values["bar"] == "baz" - - def test_multiple_valid_toml_sources( - self, - create_pyproject_toml: Callable[..., Path], - create_toml: Callable[..., Path], - ) -> None: - """When there are multiple sources, the first one is used""" - source_path_1 = create_toml(filename="another.toml", content={"bar": "baz"}) - - source_path_2 = create_pyproject_toml( - section_name="oof", content={"rab": "zab"} - ) - - config = UserConfig( - package_name="foo", - starting_path=source_path_2, - source_files=["another.toml", "pyproject.toml"], - ) - - assert config.discovered_paths == [source_path_1, source_path_2] - assert config.values["bar"] == "baz" - - def test_absolute_path(self, create_tmp_file: Callable[..., Path]) -> None: - """Source files can be found using absolute paths""" - path = create_tmp_file(filename="acme.ini") - - config = UserConfig( - package_name="foo", - source_files=[str(path)], - ) - - assert config.discovered_paths == [path] - - def test_absolute_path_not_exist( - self, - create_pyproject_toml: Callable[..., Path], - ) -> None: - """Non existent absolute paths are handled.""" - pyproject_path = create_pyproject_toml() - - config = UserConfig( - package_name="foo", - source_files=["~/.config/acme.ini", "pyproject.toml"], - starting_path=pyproject_path, - ) - - assert config.discovered_paths == [pyproject_path] - - -class TestIniFiles: - """Tests for handling x.ini config files.""" - - def test_valid_ini_file(self, create_tmp_file: Callable[..., Path]) -> None: - ini_file = """ -[section 1] -option_1 = value_1 - -[section 2] -option_2 = value_2 - """ - source_path = create_tmp_file(content=ini_file, filename="foo.ini") - config = UserConfig( - package_name="foo", - starting_path=source_path, - source_files=["foo.ini"], - ) - - assert config.discovered_paths == [source_path] - assert config.values == { - "section 1": {"option_1": "value_1"}, - "section 2": {"option_2": "value_2"}, - } - - -class TestValidation: - """Tests for schema validation.""" - - def test_no_schema(self) -> None: - config = UserConfig(package_name="acme", starting_path=Path("/")) - - assert config.values == {} - - with pytest.raises(NoSchemaError): - config.validate() - - def test_one_schema_with_valid_config( - self, - create_pyproject_toml: Callable[..., Path], - ) -> None: - """The config is validated with a given schema.""" - - class Schema(BaseModel): - """Defines schema.""" - - bar: str - - pyproject_path = create_pyproject_toml() - config = UserConfig( - package_name="foo", - starting_path=pyproject_path, - schema=Schema, - ) - - config.validate() - - assert config.values["bar"] == "baz" - - def test_one_schema_injected_at_validation( - self, - create_pyproject_toml: Callable[..., Path], - ) -> None: - """Schemas supplied as an argument are used""" - - class Schema(BaseModel): - """Defines schema.""" - - bar: str - - pyproject_path = create_pyproject_toml() - config = UserConfig( - package_name="foo", - starting_path=pyproject_path, - ) - - config.validate(schema=Schema) - - assert config.values["bar"] == "baz" - - def test_use_schema_values( - self, - create_pyproject_toml: Callable[..., Path], - ) -> None: - """Config values can be cast to the validated values.""" - - class Schema(BaseModel): - """Defines schema.""" - - model_config = ConfigDict(coerce_numbers_to_str=True) - - bar: str - other: str = "hello" - - pyproject_path = create_pyproject_toml(content={"bar": 1}) - config = UserConfig( - package_name="foo", - starting_path=pyproject_path, - schema=Schema, - ) - - config.validate() - - assert config.values["bar"] == "1" - assert config.values["other"] == "hello" - - def test_not_use_schema_values( - self, - create_pyproject_toml: Callable[..., Path], - ) -> None: - """If `use_schema_values` is set to False then don't use validated values.""" - - class Schema(BaseModel): - """Defines schema.""" - - model_config = ConfigDict(coerce_numbers_to_str=True) - - bar: str - other: str = "hello" - - pyproject_path = create_pyproject_toml(content={"bar": 1}) - config = UserConfig( - package_name="foo", - starting_path=pyproject_path, - schema=Schema, - ) - - config.validate(use_schema_values=False) - - assert config.values["bar"] == 1 - assert "other" not in config.values - - def test_schema_override( - self, - create_pyproject_toml: Callable[..., Path], - ) -> None: - """Schemas given as an argument are preferred""" - - class InitSchema(BaseModel): - """Defines schema for 1.""" - - bar: str = "schema_1" - - class ArgumentSchema(BaseModel): - """Defines schema for 2.""" - - bar: str = "schema_2" - - pyproject_path = create_pyproject_toml(content={"baz": "baz"}) - config = UserConfig( - package_name="foo", - starting_path=pyproject_path, - schema=InitSchema, - ) - - config.validate(schema=ArgumentSchema) - - assert config.values["bar"] == "schema_2" - - def test_invalid_configuration( - self, - create_pyproject_toml: Callable[..., Path], - ) -> None: - """Validation errors are raised when config fails validation.""" - - class Schema(BaseModel): - """Defines schema.""" - - bar: str - - pyproject_path = create_pyproject_toml(content={"baz": "baz"}) - config = UserConfig( - package_name="foo", - starting_path=pyproject_path, - schema=Schema, - ) - - with pytest.raises(ValidationError): - config.validate() - - def test_setter(self) -> None: - """Schemas can be set using the setter.""" - - class Schema(BaseModel): - """Defines schema.""" - - config = UserConfig(package_name="foo") - - assert config.schema is None - - config.schema = Schema - - assert config.schema is Schema - - -class TestMergeConfig: - """Tests for the merging of multiple config sources.""" - - def test_no_overwrites( - self, - create_toml: Callable[..., Path], - create_tmp_file: Callable[..., Path], - create_pyproject_toml: Callable[..., Path], - ) -> None: - """Configs without overlapping values are merged.""" - config_1_path = create_toml(filename="config.toml", content={"option_1": True}) - ini_file = """ -[foo] -option_2 = true - """ - config_2_path = create_tmp_file(filename="config.ini", content=ini_file) - pyproject_path = create_pyproject_toml(content={"option_3": True}) - - config = UserConfig( - package_name="foo", - source_files=[str(config_1_path), str(config_2_path), "pyproject.toml"], - starting_path=pyproject_path, - merge_configs=True, - ) - - assert config.path == [config_1_path, config_2_path, pyproject_path] - assert config.values == { - "option_1": True, - "foo": { - "option_2": "true", - }, - "option_3": True, - } - - def test_overwrites( - self, - create_toml: Callable[..., Path], - create_pyproject_toml: Callable[..., Path], - ) -> None: - """Configs with overlapping values are merged.""" - config_1_path = create_toml( - filename="config_1.toml", content={"option": "config_1"} - ) - config_2_path = create_toml( - filename="config_2.toml", content={"option": "config_2"} - ) - pyproject_path = create_pyproject_toml(content={"option": "config_3"}) - - config = UserConfig( - package_name="foo", - source_files=[str(config_1_path), str(config_2_path), "pyproject.toml"], - starting_path=pyproject_path, - merge_configs=True, - ) - - assert config.values == { - "option": "config_3", - } - - def test_nested( - self, - create_toml: Callable[..., Path], - create_pyproject_toml: Callable[..., Path], - ) -> None: - """Configs with nested overlapping values are deep merged.""" - config_1_path = create_toml( - filename="config_1.toml", content={"option": {"nested_1": "config_1"}} - ) - config_2_path = create_toml( - filename="config_2.toml", content={"option": {"nested_2": "config_2"}} - ) - pyproject_path = create_pyproject_toml( - content={"option": {"nested_2": "config_3"}} - ) - - config = UserConfig( - package_name="foo", - source_files=[str(config_1_path), str(config_2_path), "pyproject.toml"], - starting_path=pyproject_path, - merge_configs=True, - ) - - assert config.values == { - "option": {"nested_1": "config_1", "nested_2": "config_3"}, - } diff --git a/tests/unit_tests/test_config_reader.py b/tests/unit_tests/test_config_reader.py new file mode 100644 index 0000000..50f328e --- /dev/null +++ b/tests/unit_tests/test_config_reader.py @@ -0,0 +1,45 @@ +import pathlib + +import pytest + +from maison import config_reader +from maison import errors +from maison import types + + +class FakePyprojectParser: + def parse_config(self, file_path: pathlib.Path) -> types.ConfigValues: + return {"config": "pyproject"} + + +class FakeTomlParser: + def parse_config(self, file_path: pathlib.Path) -> types.ConfigValues: + return {"config": "toml"} + + +class TestParsesConfig: + def setup_method(self): + self.reader = config_reader.ConfigReader() + + def test_uses_parser_by_file_path_and_stem(self): + self.reader.register_parser( + suffix=".toml", parser=FakePyprojectParser(), stem="pyproject" + ) + + values = self.reader.parse_config(pathlib.Path("path/to/pyproject.toml")) + + assert values == {"config": "pyproject"} + + def test_falls_back_to_suffix(self): + self.reader.register_parser( + suffix=".toml", parser=FakePyprojectParser(), stem="pyproject" + ) + self.reader.register_parser(suffix=".toml", parser=FakeTomlParser()) + + values = self.reader.parse_config(pathlib.Path("path/to/.acme.toml")) + + assert values == {"config": "toml"} + + def test_raises_error_if_no_parser_available(self): + with pytest.raises(errors.UnsupportedConfigError): + _ = self.reader.parse_config(pathlib.Path("path/to/.acme.toml")) diff --git a/tests/unit_tests/test_config_validator.py b/tests/unit_tests/test_config_validator.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/test_main.py b/tests/unit_tests/test_main.py deleted file mode 100644 index d0d786a..0000000 --- a/tests/unit_tests/test_main.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Test cases for the __main__ module.""" - -import pytest -from typer.testing import CliRunner - -from maison import __main__ - - -@pytest.fixture -def runner() -> CliRunner: - """Fixture for invoking command-line interfaces.""" - return CliRunner() - - -def test_main_succeeds(runner: CliRunner) -> None: - """It exits with a status code of zero.""" - result = runner.invoke(__main__.app) - assert result.exit_code == 0 diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py new file mode 100644 index 0000000..2fcd210 --- /dev/null +++ b/tests/unit_tests/test_service.py @@ -0,0 +1,113 @@ +import pathlib +import typing + +from maison import service as config_service +from maison import types + + +class FakeFileSystem: + def get_file_path( + self, file_name: str, starting_path: typing.Optional[pathlib.Path] = None + ) -> typing.Optional[pathlib.Path]: + if file_name == "not.exists": + return None + return pathlib.Path(f"/path/to/{file_name}") + + +class FakeConfigReader: + def parse_config(self, file_path: pathlib.Path) -> types.ConfigValues: + return { + "values": {file_path.stem: file_path.suffix}, + } + + +class Schema: + def model_dump(self) -> types.ConfigValues: + return {"key": "validated"} + + +class FakeValidator: + def validate( + self, values: types.ConfigValues, schema: type[config_service.IsSchema] + ) -> types.ConfigValues: + return schema().model_dump() + + +class TestFindConfigs: + def test_returns_iterator_of_config_paths(self): + service = config_service.ConfigService( + filesystem=FakeFileSystem(), + config_reader=FakeConfigReader(), + validator=FakeValidator(), + ) + + configs = service.find_configs( + source_files=["something.txt", "other.toml", "not.exists", "another.ini"] + ) + + assert list(configs) == [ + pathlib.Path("/path/to/something.txt"), + pathlib.Path("/path/to/other.toml"), + pathlib.Path("/path/to/another.ini"), + ] + + +class TestGetConfigValues: + @classmethod + def setup_class(cls): + cls.service = config_service.ConfigService( + filesystem=FakeFileSystem(), + config_reader=FakeConfigReader(), + validator=FakeValidator(), + ) + + def test_returns_dict_if_config_found(self): + config_dict = self.service.get_config_values( + config_file_paths=[pathlib.Path("config.toml")], + merge_configs=False, + ) + + assert config_dict == { + "values": {"config": ".toml"}, + } + + def test_returns_first_dict_if_not_merge_configs(self): + config_dict = self.service.get_config_values( + config_file_paths=[pathlib.Path("config.toml"), pathlib.Path("other.ini")], + merge_configs=False, + ) + + assert config_dict == {"values": {"config": ".toml"}} + + def test_merges_configs(self): + config_dict = self.service.get_config_values( + config_file_paths=[pathlib.Path("config.toml"), pathlib.Path("other.ini")], + merge_configs=True, + ) + + assert config_dict == { + "values": { + "config": ".toml", + "other": ".ini", + } + } + + +class TestValidate: + @classmethod + def setup_class(cls): + cls.service = config_service.ConfigService( + filesystem=FakeFileSystem(), + config_reader=FakeConfigReader(), + validator=FakeValidator(), + ) + + def test_validates_config(self): + values = self.service.get_config_values( + config_file_paths=[pathlib.Path("config.toml")], + merge_configs=False, + ) + + validated_values = self.service.validate_config(values=values, schema=Schema) + + assert validated_values == {"key": "validated"} diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py index b998f8f..54e25df 100644 --- a/tests/unit_tests/test_utils.py +++ b/tests/unit_tests/test_utils.py @@ -1,90 +1,8 @@ """Tests for the `utils` module.""" -from pathlib import Path -from typing import Callable -from unittest.mock import MagicMock -from unittest.mock import patch - import pytest from maison.utils import deep_merge -from maison.utils import get_file_path -from maison.utils import path_contains_file - - -class TestContainsFile: - """Tests for the `contains_file` function""" - - def test_found(self, create_tmp_file: Callable[..., Path]) -> None: - """Return `True` if the path contains the file""" - path = create_tmp_file(filename="file.txt") - - result = path_contains_file(path=path.parent, filename="file.txt") - - assert result is True - - def test_not_found(self, create_tmp_file: Callable[..., Path]) -> None: - """Return `False` if the path does not contain the file""" - path = create_tmp_file(filename="file.txt") - - result = path_contains_file(path=path.parent, filename="other.txt") - - assert result is False - - -class TestGetFilePath: - """Tests for the `get_file_path`""" - - @patch("maison.utils.Path", autospec=True) - def test_in_current_directory( - self, mock_path: MagicMock, create_tmp_file: Callable[..., Path] - ) -> None: - """The path to a file is returned.""" - mock_path.return_value.expanduser.return_value.is_absolute.return_value = False - - path_to_file = create_tmp_file(filename="file.txt") - mock_path.cwd.return_value = path_to_file.parent - - result = get_file_path(filename="file.txt") - - assert result == path_to_file - - def test_in_parent_directory(self, create_tmp_file: Callable[..., Path]) -> None: - """The path to a file in a parent directory is returned.""" - path_to_file = create_tmp_file(filename="file.txt") - sub_dir = path_to_file / "sub" - - result = get_file_path(filename="file.txt", starting_path=sub_dir) - - assert result == path_to_file - - def test_not_found(self) -> None: - """If the file isn't found in the tree then return a `None`""" - result = get_file_path(filename="file.txt", starting_path=Path("/nowhere")) - - assert result is None - - def test_with_given_path(self, create_tmp_file: Callable[..., Path]) -> None: - """A `starting_path` can be used to initiate the starting search directory""" - path_to_file = create_tmp_file(filename="file.txt") - - result = get_file_path(filename="file.txt", starting_path=path_to_file) - - assert result == path_to_file - - def test_absolute_path(self, create_tmp_file: Callable[..., Path]) -> None: - """An absolute path to an existing file is returned""" - path_to_file = create_tmp_file(filename="file.txt") - - result = get_file_path(filename=str(path_to_file)) - - assert result == path_to_file - - def test_absolute_path_not_exist(self) -> None: - """If the absolute path doesn't exist return a `None`""" - result = get_file_path(filename="~/xxxx/yyyy/doesnotexist.xyz") - - assert result is None class TestDeepMerge: