Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion mdbenchmark/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
validate_number_of_nodes,
)

from mdbenchmark.config import _add_single_context, _import_config_into_context


@click.group(cls=AliasedGroup)
@click.version_option()
Expand Down Expand Up @@ -86,6 +88,13 @@ def analyze(directory, plot, ncores, save_csv):


@cli.command()
@click.option(
"--config",
"config_file",
help="Parse settings from config file instead of command-line. Ignores all other options.",
callback=_import_config_into_context,
type=click.Path(exists=True),
)
@click.option(
"-n",
"--name",
Expand All @@ -99,13 +108,15 @@ def analyze(directory, plot, ncores, save_csv):
help="Use CPUs for benchmark.",
default=True,
show_default=True,
callback=_add_single_context,
)
@click.option(
"-g/-ng",
"--gpu/--no-gpu",
is_flag=True,
help="Use GPUs for benchmark.",
show_default=True,
callback=_add_single_context,
)
@click.option(
"-m",
Expand All @@ -128,20 +139,23 @@ def analyze(directory, plot, ncores, save_csv):
help="Minimal number of nodes to request.",
default=1,
show_default=True,
callback=_add_single_context,
type=int,
)
@click.option(
"--max-nodes",
help="Maximal number of nodes to request.",
default=5,
show_default=True,
callback=_add_single_context,
type=int,
)
@click.option(
"--time",
help="Run time for benchmark in minutes.",
default=15,
show_default=True,
callback=_add_single_context,
type=click.IntRange(1, 1440),
)
@click.option(
Expand All @@ -156,15 +170,22 @@ def analyze(directory, plot, ncores, save_csv):
"--skip-validation",
help="Skip the validation of module names.",
default=False,
callback=_add_single_context,
is_flag=True,
)
@click.option(
"--job-name", help="Give an optional to the generated benchmarks.", default=None
)
@click.option(
"-y", "--yes", help="Answer all prompts with yes.", default=False, is_flag=True
"-y",
"--yes",
help="Answer all prompts with yes.",
default=False,
callback=_add_single_context,
is_flag=True,
)
def generate(
config_file,
name,
cpu,
gpu,
Expand Down
9 changes: 9 additions & 0 deletions mdbenchmark/cli/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

def validate_name(ctx, param, name=None):
"""Validate that we are given a name argument."""
# Fetch name from context read in config file
name = ctx.params.get("name", name)

if name is None:
raise click.BadParameter(
"Please specify the base name of your input files.",
Expand All @@ -16,6 +19,9 @@ def validate_name(ctx, param, name=None):

def validate_module(ctx, param, module=None):
"""Validate that we are given a module argument."""
# Fetch module from context read in config file
module = ctx.params.get("module", module)

if module is None or not module:
raise click.BadParameter(
"Please specify which MD engine module to use for the benchmarks.",
Expand Down Expand Up @@ -63,6 +69,9 @@ def validate_hosts(ctx, param, host=None):
templates. If the hostname matches the template name, we continue by
returning the hostname.
"""
# Fetch host from context read in config file
host = ctx.params.get("host", host)

if host is None:
host = utils.guess_host()
if host is None:
Expand Down
71 changes: 71 additions & 0 deletions mdbenchmark/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from tomlkit import parse
from tomlkit.exceptions import ParseError

from mdbenchmark import console

CONFIG_KEY_TO_CTX = {"skip_prompts": "yes", "input": "name", "modules": "module"}
ALLOWED_CONFIG_KEYS = [
"input",
"job_name",
"modules",
"skip_validation",
"min_nodes",
"max_nodes",
"time",
"host",
"cpu",
"gpu",
"skip_prompts",
]


def parse_config(toml_file):
"""
Open config file and parse its content.
"""
with open(toml_file, "r") as f:
content = "".join(f.readlines())

try:
parsed = parse(content)
except ParseError as e:
console.error(
"{filename}: {error}".format(filename=toml_file, error=e.__str__())
)

return parsed


def _add_single_context(ctx, param, value):
"""
Get a value from the context, otherwise set value defined by user.
"""
return ctx.params.get(param.name, value)


def _import_config_into_context(ctx, param, config_file):
"""
Parse config file and put settings into click.Context.
"""
console.info(
'Using settings from config file "{config_file}".'.format(
config_file=config_file
)
)

parsed_config = parse_config(config_file)

for key in parsed_config.keys():
# Ignore invalid keys from config file
if key not in ALLOWED_CONFIG_KEYS:
console.info('Ignoring setting for unknown key "{key}".'.format(key=key))
continue

try:
ctx_key = CONFIG_KEY_TO_CTX[key]
except KeyError:
ctx_key = key

ctx.params[ctx_key] = parsed_config[key]

return config_file
Loading