Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/snippets/yaml_config/common-handlers.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
console:
class: logging.StreamHandler
formatter: simple
stream: ext://sys.stdout
file_handler:
class: logging.FileHandler
filename: ${LOGGING_ROOT:.}/${LOG_FILENAME}
formatter: simple
15 changes: 15 additions & 0 deletions docs/snippets/yaml_config/logging-include.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
version: 1
formatters:
simple:
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
handlers: !include common-handlers.yaml
loggers:
test_logger:
level: DEBUG
handlers:
- file_handler
propagate: no
root:
level: NOTSET
handlers:
- console
16 changes: 16 additions & 0 deletions docs/snippets/yaml_config/test_logger_include.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import logging
from logging_.config import YAMLConfig

with open("logging-include.yaml", "r") as config_file:
YAMLConfig(config_file.read(), silent=True)

# alternatively, you can use
# YAMLConfig.from_file("logging.yaml", silent=True)

logger = logging.getLogger("test_logger")

logger.debug("This is a debug log")
logger.info("This is an info log")
logger.warning("This is an warning log")
logger.error("This is an error log")
logger.critical("This is a critical log")
93 changes: 84 additions & 9 deletions logging_/config/yaml_config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
import logging.config
import os
import re
import logging.config
from typing import Any
from os.path import realpath, splitext, join as join_path, curdir, isfile, abspath, commonpath
from typing import Any, Union, TextIO, Optional

import yaml
from yaml.parser import ParserError
Expand Down Expand Up @@ -37,7 +38,7 @@ class YAMLConfig(object):
_envvar_tag_matcher = re.compile(r"[^$]*\${([^}^{]+)}.*")
_uservar_tag_matcher = re.compile(r"^~(\w*?)/")

def __init__(self, config_yaml: str, **kwargs: Any):
def __init__(self, config_yaml: Union[TextIO, str], rootpath: Optional[str] = None, unsafe: bool = False, include: bool = True, **kwargs: Any):
"""Instantiates an YAMLConfig object from configuration string.

Registers implicit resolver for custom tag envvar and adds constructor for the tag. Loads logging config from
Expand All @@ -53,17 +54,45 @@ def __init__(self, config_yaml: str, **kwargs: Any):
ValueError: if required fields are missing in YAML string, ignored if ``silent=True``.
TypeError: if empty YAML string is provided, ignored if ``silent=True``.
"""
self.include = include
self.unsafe = unsafe
self._rootpath = rootpath

if not self._rootpath:
self._rootpath = realpath(curdir)
yaml.add_implicit_resolver("!envvar", self._envvar_tag_matcher, None, yaml.SafeLoader)
yaml.add_constructor("!envvar", self._envvar_constructor, yaml.SafeLoader)
yaml.add_implicit_resolver("!uservar", self._uservar_tag_matcher, None, yaml.SafeLoader)
yaml.add_constructor("!uservar", self._uservar_constructor, yaml.SafeLoader)
yaml.add_constructor("!include", self._ctor_include, yaml.SafeLoader) # noqa
try:
logging.config.dictConfig(yaml.safe_load(config_yaml))
except (ParserError, ValueError, TypeError):
if kwargs.get("silent", False) is not True:
raise

def safe_path(self, child_path: str, allow_symlinks: bool = True):
"""Security check: ensure that a file is at or below the directory path of the primary YaML file

Parameters
----------
child_path : str
Path of the file to be included
allow_symlinks : bool, default=True
If True, don't resolve symlinks

Returns
-------
bool
True if safe, False if unsafe
"""
root_path = abspath(self._rootpath)
child_path = abspath(child_path)
if allow_symlinks is False:
root_path = realpath(root_path)
child_path = realpath(child_path)
return commonpath([root_path]) == commonpath([root_path, child_path])

@classmethod
def from_file(cls, filename: str, **kwargs: Any):
"""Creates an instance from YAML configuration file.
Expand All @@ -82,19 +111,17 @@ def from_file(cls, filename: str, **kwargs: Any):
"""

try:
with open(filename, "r") as f:
return cls(f.read(), **kwargs)
with open(filename, "r") as fd:
return cls(fd.read(), **kwargs)
except (FileNotFoundError, PermissionError):
if kwargs.get("silent", False) is not True:
raise
else:
return cls("", **kwargs)
return cls("", **kwargs)

def _envvar_constructor(self, _loader: Any, node: Any):
"""Replaces environment variable name with its value, or a default."""

def replace_fn(match):
print(match.group(0))
envparts = f"{match.group(1)}:".split(":")
return os.environ.get(envparts[0], envparts[1])

Expand All @@ -103,5 +130,53 @@ def replace_fn(match):
@staticmethod
def _uservar_constructor(_loader: Any, node: Any):
"""Expands ~ and ~username into user's home directory like shells do."""

return os.path.expanduser(node.value)

def _ctor_include(self, _loader: Any, node: Any) -> Any:
"""Dynamically load the contents of an external YaML or flat file in-place

Notes
-----
Invoked internally via PyYaML when encountering a custom include tag

Parameters
----------
node:
PyYaML YaML node object

Returns
-------
Any
If node.value has a yaml/yml extension, structured data from the specified YaML file
If node.value has a list/lst extension, a list of strings, one per-line in the specified file
Otherwise, the contents of the specified file, with newlines removed
"""
if self.include is False:
raise RuntimeError('Attempting to use !include when includes are explicitly disabled!')
base_file = _loader.construct_scalar(node)
include_file = realpath(join_path(self._rootpath, base_file))
extension = splitext(include_file)[1].lstrip('.').lower()

if self.unsafe is False and not self.safe_path(include_file):
# Use the UnsafeExtLoader to get around this check
raise RuntimeError(f'Trying to include unsafe YaML file (outside of root directory {self._rootpath})')

print('Trying to include file %s @ %s' % (base_file, include_file))
if not isfile(include_file):
include_file = join_path(self._rootpath, 'include', base_file)
print('Trying to find %s @ %s' % (base_file, include_file))
if not isfile(include_file):
raise OSError(f'Unable to find include file {base_file}')

with open(include_file, mode='r') as fd:
# If YaML, load as YaML; otherwise load the file contents as a scalar string value, trimming out newlines
if extension.lower() in ('yml', 'yaml'):
# Return as fully structured data
return yaml.load(fd, yaml.SafeLoader)
if extension.lower() in ('lst', 'list'):
# File is a flat list with one element per-line
# Return as a list object, preserving order, filtering out blank lines / whitespace-only lines
# Also skip lines starting with '#' to support comments in flat list files
return [line for line in fd.read().splitlines() if line.strip() and not line.startswith('#')]
# Fallback: Return a single line blob, removing newlines if there are any
return ''.join(fd.readlines())