def main_neighbors(args):
    """Handle the neighbors subcommand.

    Prints, for each selected job, the state point keys along which the job
    has neighbors together with the neighboring job ids. Keys without any
    neighbors are omitted from the output.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line arguments carrying ``job_id`` (list of job ids;
        empty selects every job in the project) and ``ignore`` (state point
        keys to skip when detecting neighbors).
    """
    project = get_project()
    if args.job_id:
        jobs = (_open_job_by_id(project, job_id) for job_id in args.job_id)
    else:
        # Mirror main_statepoint: with no explicit ids, operate on all jobs.
        # Without this fallback the command silently did nothing (or raised
        # NameError) when invoked without job ids.
        jobs = project
    for job in jobs:
        neighbors = job._get_neighbors(ignore=args.ignore)
        # Only print keys that actually have neighbors.
        pprint({key: val for key, val in neighbors.items() if len(val) > 0})
# Copyright (c) 2025 The Regents of the University of Michigan.
# All rights reserved.
# This software is licensed under the BSD 3-Clause License.
from collections import defaultdict
from functools import partial

from ._search_indexer import _DictPlaceholder
from ._utility import (
    _dotted_dict_to_nested_dicts,
    _nested_dicts_to_dotted_keys,
    _to_hashable,
)
from .job import calc_id


def prepare_shadow_project(sp_cache, ignore: list):
    r"""Build cache and mapping for the shadow project induced by ignored keys.

    Ignoring a state point key projects every job onto a "shadow" state point
    (the original state point with the ignored keys removed) and a "shadow"
    job id computed from it. Neighbor detection runs on the shadow project
    with fast cache lookups and is mapped back to the real project afterwards.

    The projection must be a bijection: if two jobs collapse onto the same
    shadow id, mapping back is ill defined and a ValueError is raised.

    Parameters
    ----------
    sp_cache : dict
        Project state point cache mapping job id to (nested) state point.
    ignore : list of str
        State point keys to ignore, with nested keys specified in dotted key
        format.

    Returns
    -------
    shadow_map : dict
        A map from shadow job id to project job id.
    shadow_cache : dict
        An in-memory state point cache for the shadow project that maps
        shadow job id --> shadow state point, in dotted key format. The
        shadow job id is computed from the nested key format with the
        ignored keys removed.

    Raises
    ------
    ValueError
        If removing the ignored keys makes two or more jobs
        indistinguishable.

    Use cases:

    1) Seed that is different for every job.

    2) State point key that changes in sync with another key.

    Case 1:

        {"a": 1, "b": 2, "seed": 0} -> jobid1
        {"a": 1, "b": 3, "seed": 1} -> jobid2
        {"a": 1, "b": 2} -> shadowid1
        {"a": 1, "b": 3} -> shadowid2

        shadowid1 <---> jobid1
        shadowid2 <---> jobid2

    Breaking case 1 with repeated shadow jobs:

        {"a": 1, "b": 2, "seed": 0} -> jobid1
        {"a": 1, "b": 3, "seed": 1} -> jobid2
        {"a": 1, "b": 3, "seed": 2} -> jobid3

        {"a": 1, "b": 2} -> shadowid1
        {"a": 1, "b": 3} -> shadowid2
        {"a": 1, "b": 3} -> shadowid2 *conflict* No longer a bijection.
        Now we have shadowid2 .---> jobid2
                              \\--> jobid3

    Case 2:

        {"a1": 10, "a2": 20} -> jobid1
        {"a1": 2, "a2": 4} -> jobid2

        {"a1": 10} -> shadowid1
        {"a1": 2} -> shadowid2

        Can still make the mapping between ids.

    Breaking case 2:

        {"a1": 10, "a2": 20} -> jobid1
        {"a1": 2, "a2": 4} -> jobid2
        {"a1": 2, "a2": 5} -> jobid3

        {"a1": 10} -> shadowid1
        {"a1": 2} -> shadowid2
        {"a1": 2} -> shadowid2
        Now we have shadowid2 .---> jobid2
                              \\--> jobid3

    """
    shadow_cache = {}  # like a state point cache, but for the shadow project
    job_projection = {}  # goes from job id to shadow id
    for job_id, sp in sp_cache.items():
        # Remove ignored keys while in dotted key format.
        shadow_sp_dotted = dict(_nested_dicts_to_dotted_keys(sp))
        for ignored_key in ignore:
            shadow_sp_dotted.pop(ignored_key, None)
        # The id is calculated from the nested key format.
        shadow_id = calc_id(_dotted_dict_to_nested_dicts(shadow_sp_dotted))
        # The cache needs to be in dotted key format, so just convert it here.
        shadow_cache[shadow_id] = shadow_sp_dotted
        job_projection[job_id] = shadow_id

    if len(set(job_projection.values())) != len(job_projection):
        # The projection is not a bijection; build a helpful error message
        # listing each group of jobs that became indistinguishable. Grouping
        # by shadow id directly replaces the previous redundant counter dict.
        shadow_to_jobs = defaultdict(list)
        for job_id, shadow_id in job_projection.items():
            shadow_to_jobs[shadow_id].append(job_id)
        bad_jobids = [ids for ids in shadow_to_jobs.values() if len(ids) > 1]
        err_str = "\n".join(f"Job ids: {', '.join(j)}." for j in bad_jobids)
        raise ValueError(
            f"Ignoring {ignore} makes it impossible to distinguish some jobs:\n{err_str}"
        )
    # Invert the map to go from shadow job id to project job id.
    shadow_map = {v: k for k, v in job_projection.items()}
    return shadow_map, shadow_cache


# key and other_val are provided separately so that this function can be bound
# with functools.partial and then called with only other_val while searching.
def _search_cache_for_val(statepoint, cache, key, other_val):
    """Return the job id of a job similar to statepoint if present in cache.

    The similar job is obtained by modifying statepoint to include
    {key: other_val}.

    Internally converts statepoint from dotted keys to nested dicts format.

    Parameters
    ----------
    statepoint : dict
        State point of the job to modify, in dotted key format. Pass a copy:
        it is updated in place by this function.
    cache : dict
        Project state point cache to search in.
    key : str
        The key whose value to change.
    other_val
        The new value of key to search for.

    Returns
    -------
    str or None
        Job id of the similar job, or None if it is not present in the cache.
    """
    statepoint.update({key: other_val})
    # calc_id expects nested dicts; the schema output uses dotted keys.
    other_job_id = calc_id(_dotted_dict_to_nested_dicts(statepoint))
    if other_job_id in cache:
        return other_job_id
    return None


def _search_out(search_direction, values, current_index, boundary_index, search_fun):
    """Search in values towards boundary_index from current_index using search_fun.

    Parameters
    ----------
    search_direction : int, 1 or -1
        1 means search in the positive direction from the index.
    values : sequence
        Values to index into when searching.
    current_index : int
        Index into values to start searching from. The value at this index is
        not accessed directly.
    boundary_index : int
        The index at which to stop (inclusive).
    search_fun : callable
        Unary function returning a job id if it exists and None otherwise.

    Returns
    -------
    dict or None
        ``{val: jobid}`` for the nearest match in the search direction, where
        ``val`` is the neighbor's value of the key and ``jobid`` its job id;
        None if no job is found before the boundary.
    """
    query_index = current_index + search_direction
    # Multiplying both sides by search_direction folds the two cases into one
    # comparison: query_index <= boundary_index when searching up, and
    # query_index >= boundary_index when searching down.
    while search_direction * query_index <= search_direction * boundary_index:
        val = values[query_index]
        jobid = search_fun(val)
        if jobid is not None:
            return {val: jobid}
        query_index += search_direction
    return None


def neighbors_of_sp(statepoint, dotted_sp_cache, sorted_schema):
    """Return neighbors of the given state point along each schema key.

    The state point and the cache must both use either job ids or shadow job
    ids, and both must be in dotted key format (the cache is built by calling
    _nested_dicts_to_dotted_keys on each state point).

    Parameters
    ----------
    statepoint : dict
        State point to start the search from, in dotted key format.
    dotted_sp_cache : dict
        Map from job id to state point in dotted key format.
    sorted_schema : dict
        Map from key (in dotted notation) to sorted values of the key to
        search over.

    Returns
    -------
    dict
        {key: {neighbor_value: neighbor_job_id}} with at most two entries
        (previous and next value) per key.
    """
    neighbors = {}
    for key, schema_values in sorted_schema.items():  # from project
        # _to_hashable allows comparison with the output of the schema, which
        # is hashable and in dotted key format.
        value = _to_hashable(statepoint.get(key, _DictPlaceholder))
        if value is _DictPlaceholder:
            # Possible if the schema is heterogeneous and this state point
            # does not contain the key at all.
            continue
        value_index = schema_values.index(value)
        # Bind a copy of the state point: the search helper mutates it.
        search_fun = partial(
            _search_cache_for_val, dict(statepoint), dotted_sp_cache, key
        )
        prev_neighbor = _search_out(-1, schema_values, value_index, 0, search_fun)
        next_neighbor = _search_out(
            1, schema_values, value_index, len(schema_values) - 1, search_fun
        )

        this_neighbors = {}
        if prev_neighbor is not None:
            this_neighbors.update(prev_neighbor)
        if next_neighbor is not None:
            this_neighbors.update(next_neighbor)
        neighbors[key] = this_neighbors
    return neighbors


def shadow_neighbors_to_neighbors(shadow_neighbors, shadow_map):
    """Replace shadow job ids with actual job ids in the neighbors of one job.

    Parameters
    ----------
    shadow_neighbors : dict
        Map of state point key to neighbor value to shadow job id.
    shadow_map : dict
        Map from shadow job id to project job id.

    Returns
    -------
    dict
        The same structure with every shadow job id replaced by its project
        job id.
    """
    return {
        key: {val: shadow_map[shadow_id] for val, shadow_id in vals.items()}
        for key, vals in shadow_neighbors.items()
    }


def shadow_neighbor_list_to_neighbor_list(shadow_neighbor_list, shadow_map):
    """Replace shadow job ids with actual job ids in a whole neighbor list.

    Parameters
    ----------
    shadow_neighbor_list : dict
        Neighbor list keyed by shadow job id: map of shadow job id to state
        point key to neighbor value to shadow job id.
    shadow_map : dict
        Map from shadow job id to project job id.

    Returns
    -------
    dict
        The neighbor list keyed by project job ids, with all contained shadow
        job ids translated as well.
    """
    return {
        shadow_map[shadow_id]: shadow_neighbors_to_neighbors(neighbors, shadow_map)
        for shadow_id, neighbors in shadow_neighbor_list.items()
    }


def _build_neighbor_list(dotted_sp_cache, sorted_schema):
    """Get the neighbors of every state point in the cache.

    Parameters
    ----------
    dotted_sp_cache : dict
        Map from job id to state point OR shadow job id to shadow state
        point, in dotted key format.
    sorted_schema : dict
        Map of dotted keys to their sorted values to search over.

    Returns
    -------
    neighbor_list : dict
        {jobid: {state_point_key: {prev_value: neighbor_id, next_value: neighbor_id}}}
    """
    return {
        _id: neighbors_of_sp(_sp, dotted_sp_cache, sorted_schema)
        for _id, _sp in dotted_sp_cache.items()
    }
+ + Parameters + ---------- + sp_cache : dict + Project state point cache + sorted_schema : dict + Map of dotted keys to their values to search over + + Returns + ------- + neighbor_list : dict + {jobid: {state_point_key: {prev_value: neighbor_id, next_value: neighbor_id}}} + """ + if len(ignore) > 0: + shadow_map, shadow_cache = prepare_shadow_project(sp_cache, ignore=ignore) + nl = _build_neighbor_list(shadow_cache, sorted_schema) + return shadow_neighbor_list_to_neighbor_list(nl, shadow_map) + else: + # the state point cache needs to be in dotted keys to enable searching over schema values + sp_cache = { + _id: dict(_nested_dicts_to_dotted_keys(_sp)) + for _id, _sp in sp_cache.items() + } + return _build_neighbor_list(sp_cache, sorted_schema) diff --git a/signac/job.py b/signac/job.py index 6a92038b5..89696303b 100644 --- a/signac/job.py +++ b/signac/job.py @@ -9,6 +9,7 @@ import logging import os import shutil +import warnings from copy import deepcopy from threading import RLock from types import MappingProxyType @@ -979,6 +980,82 @@ def close(self): except IndexError: pass + def _get_neighbors(self, ignore=[]): + """Return the neighbors of this job, mainly for command line use. + + Use `Project.get_neighbors()` to get the neighbors of all jobs in the project. + + The neighbors of a job are jobs that differ along one state point parameter. + + Job neighbors are provided in a dictionary containing + {state_point_key: {prev_value: neighbor_id, next_value: neighbor_id}, ...}, + where `state_point_key` is each of the non-constant state point parameters in the project + (equivalent to the output of `project.detect_schema(exclude_const = True)`). For nested + state point keys, the state point key is in "dotted key" notation, like the output of + `detect_schema`. + + Along each state_point_key, a job can have 0, 1 or 2 neighbors. For 0 neighbors, the job + neighbors dictionary is empty. For 2 neighbors, the neighbors are in sort order. 
State point + values of different types are ordered by their type name. + + If neighbors are not being detected correctly, it is likely that there are several state + point parameters changing together. In this case, pass a list of state point parameters to + ignore to the `ignore` argument. If a state point value is a dictionary (a "nested key"), + then the ignore list must be specified in "dotted key" notation. + + Parameters + ---------- + ignore : list + List of state point parameters to ignore when building neighbor list + + Returns + ------- + neighbors : dict + A map of state point key to 0-2 neighbor values (or none) to job ids (see above) + """ + from ._neighbor import ( + neighbors_of_sp, + prepare_shadow_project, + shadow_neighbors_to_neighbors, + ) + from ._search_indexer import _DictPlaceholder + from ._utility import _nested_dicts_to_dotted_keys + + if not isinstance(ignore, list): + ignore = [ignore] + + sp_cache = self._project._sp_cache + + sorted_schema = self._project._flat_schema() + sp = dict(_nested_dicts_to_dotted_keys(self.cached_statepoint)) + need_to_ignore = [sorted_schema.pop(ig, _DictPlaceholder) for ig in ignore] + if any(is_bad_key := [a is _DictPlaceholder for a in need_to_ignore]): + # any uses up the iterator + from itertools import compress + + bad_keys = list(compress(ignore, is_bad_key)) + warnings.warn( + f"Ignored state point parameter{'s' if len(bad_keys) > 1 else ''} {bad_keys}" + " not present in project.", + RuntimeWarning, + ) + for bad_key in bad_keys: + ignore.remove(bad_key) + + if len(ignore) > 0: + for i in ignore: + sp.pop(i, None) + shadow_map, shadow_cache = prepare_shadow_project(sp_cache, ignore=ignore) + neighbors = neighbors_of_sp(sp, shadow_cache, sorted_schema) + neighbors = shadow_neighbors_to_neighbors(neighbors, shadow_map) + else: + sp_cache = { + _id: dict(_nested_dicts_to_dotted_keys(_sp)) + for _id, _sp in sp_cache.items() + } + neighbors = neighbors_of_sp(sp, sp_cache, sorted_schema) + return 
neighbors + def __enter__(self): self.open() return self diff --git a/signac/project.py b/signac/project.py index 18bc80030..fd2167feb 100644 --- a/signac/project.py +++ b/signac/project.py @@ -17,7 +17,7 @@ from contextlib import contextmanager from copy import deepcopy from datetime import timedelta -from itertools import groupby +from itertools import compress, groupby from multiprocessing.pool import ThreadPool from tempfile import TemporaryDirectory from threading import RLock @@ -32,8 +32,12 @@ _raise_if_older_schema, _read_config_file, ) -from ._search_indexer import _SearchIndexer -from ._utility import _mkdir_p, _nested_dicts_to_dotted_keys +from ._neighbor import get_neighbor_list +from ._search_indexer import _DictPlaceholder, _SearchIndexer +from ._utility import ( + _mkdir_p, + _nested_dicts_to_dotted_keys, +) from .errors import ( DestinationExistsError, IncompatibleSchemaVersion, @@ -1322,8 +1326,9 @@ def repair(self, job_ids=None): os.replace(invalid_wd, correct_wd) except OSError as error: logger.critical( - "Unable to fix location of job with " - " id '{}': '{}'.".format(job_id, error) + "Unable to fix location of job with id '{}': '{}'.".format( + job_id, error + ) ) corrupted.append(job_id) continue @@ -1342,8 +1347,9 @@ def repair(self, job_ids=None): job.init() except Exception as error: logger.error( - "Error during initialization of job with " - "id '{}': '{}'.".format(job_id, error) + "Error during initialization of job with id '{}': '{}'.".format( + job_id, error + ) ) try: # Attempt to fix the job state point file. job.init(force=True) @@ -1653,6 +1659,89 @@ def __setstate__(self, state): state["_lock"] = RLock() self.__dict__.update(state) + def _flat_schema(self, exclude_const=False): + """For each state point parameter, make a flat list sorted by its values in the project. + + This is almost like schema, but the schema separates items by type. 
+ To sort between different types, put in order of the name of the type + """ + schema = self.detect_schema(exclude_const=exclude_const) + sorted_schema = {} + for key, schema_values in schema.items(): + tuples_to_sort = [] + for type_name in schema_values: + tuples_to_sort.append( + (type_name.__name__, sorted(schema_values[type_name])) + ) + combined_values = [] + for _, v in sorted(tuples_to_sort, key=lambda x: x[0]): + combined_values.extend(v) + sorted_schema[key] = combined_values + return sorted_schema + + def get_neighbors(self, ignore=[]): + """Return a map of job ids to job neighbors. + + The neighbors of a job are jobs that differ along one state point parameter. + + Job neighbors are provided in a dictionary containing + {state_point_key: {prev_value: neighbor_id, next_value: neighbor_id}, ...}, + where `state_point_key` is each of the non-constant state point parameters in the project + (equivalent to the output of `project.detect_schema(exclude_const = True)`). For nested + state point keys, the state point key is in "dotted key" notation, like the output of + `detect_schema`. + + Along each state_point_key, a job can have 0, 1 or 2 neighbors. For 0 neighbors, the job + neighbors dictionary is empty. For 2 neighbors, the neighbors are in sort order. State point + values of different types are ordered by their type name. + + If neighbors are not being detected correctly, it is likely that there are several state + point parameters changing together. In this case, pass a list of state point parameters to + ignore to the `ignore` argument. If a state point value is a dictionary (a "nested key"), + then the ignore list must be specified in "dotted key" notation. + + Parameters + ---------- + ignore : list of str + List of keys to ignore when building neighbor list + + Returns + ------- + neighbor_list : dict + A map of job id to job neighbors (see above). + + Example + ------- + .. 
code-block:: python + + neighbor_list = project.get_neighbors() + for job in project: + neighbors = neighbor_list[job.id] + print(f"Job {job.id}") + for key,v in job.sp.items(): + print(f"has {key}={v} with neighbor jobs {key}-->{f" and {key}-->".join( + f"{new_val} at job id {jid}" for new_val,jid in neighbors[key].items())}") + + """ + if not isinstance(ignore, list): + ignore = [ignore] + + sorted_schema = self._flat_schema(exclude_const=True) + need_to_ignore = [sorted_schema.pop(ig, _DictPlaceholder) for ig in ignore] + if any(is_bad_key := [a is _DictPlaceholder for a in need_to_ignore]): + bad_keys = list(compress(ignore, is_bad_key)) + warnings.warn( + f"Ignored state point parameter{'s' if len(bad_keys) > 1 else ''} {bad_keys}" + " not present in project.", + RuntimeWarning, + ) + for bad_key in bad_keys: + ignore.remove(bad_key) + + self.update_cache() + # pass a copy of cache + return get_neighbor_list(dict(self._sp_cache), sorted_schema, ignore) + @contextmanager def TemporaryProject(cls=None, **kwargs): diff --git a/tests/test_neighborlist.py b/tests/test_neighborlist.py new file mode 100644 index 000000000..de0a00c92 --- /dev/null +++ b/tests/test_neighborlist.py @@ -0,0 +1,235 @@ +from itertools import product + +import pytest +from test_project import TestProject + + +class TestNeighborList(TestProject): + def test_neighbors(self): + a_vals = [1, 2] + b_vals = [3, 4, 5] + for a, b in product(a_vals, b_vals): + self.project.open_job({"a": a, "b": b}).init() + + neighbor_list = self.project.get_neighbors() + + for a, b in product(a_vals, b_vals): + job = self.project.open_job({"a": a, "b": b}) + neighbors_job = job._get_neighbors() + with pytest.warns(RuntimeWarning, match="not_present"): + job._get_neighbors(ignore=["not_present"]) + + neighbors_project = neighbor_list[job.id] + assert neighbors_project == neighbors_job + + this_neighbors = neighbors_project + + # a neighbors + if a == 1: + assert ( + this_neighbors["a"][2] == 
self.project.open_job({"a": 2, "b": b}).id + ) + elif a == 2: + assert ( + this_neighbors["a"][1] == self.project.open_job({"a": 1, "b": b}).id + ) + + # b neighbors + if b == 3: + assert ( + this_neighbors["b"][4] == self.project.open_job({"a": a, "b": 4}).id + ) + elif b == 4: + assert ( + this_neighbors["b"][3] == self.project.open_job({"a": a, "b": 3}).id + ) + assert ( + this_neighbors["b"][5] == self.project.open_job({"a": a, "b": 5}).id + ) + elif b == 5: + assert ( + this_neighbors["b"][4] == self.project.open_job({"a": a, "b": 4}).id + ) + with pytest.warns(RuntimeWarning, match="not_present"): + self.project.get_neighbors(ignore=["not_present"]) + + def test_neighbors_ignore(self): + b_vals = [3, 4, 5] + for b in b_vals: + self.project.open_job({"b": b, "2b": 2 * b}).init() + + neighbor_list = self.project.get_neighbors(ignore="2b") + + for b in b_vals: + job = self.project.open_job({"b": b, "2b": 2 * b}) + neighbors_job = job._get_neighbors(ignore=["2b"]) + + this_neighbors = neighbor_list[job.id] + assert this_neighbors == neighbors_job + + if b == 3: + assert ( + this_neighbors["b"][4] + == self.project.open_job({"b": 4, "2b": 8}).id + ) + elif b == 4: + assert ( + this_neighbors["b"][3] + == self.project.open_job({"b": 3, "2b": 6}).id + ) + assert ( + this_neighbors["b"][5] + == self.project.open_job({"b": 5, "2b": 10}).id + ) + elif b == 5: + assert ( + this_neighbors["b"][4] + == self.project.open_job({"b": 4, "2b": 8}).id + ) + + def test_neighbors_ignore_nested(self): + a_vals = [{"b": 2, "c": 2}, {"b": 3, "c": 3}] + for a in a_vals: + self.project.open_job({"a": a}).init() + + neighbor_list = self.project.get_neighbors(ignore="a.b") + + for a in a_vals: + job = self.project.open_job({"a": a}) + neighbors_job = job._get_neighbors(ignore="a.b") + + c = a["c"] + + this_neighbors = neighbor_list[job.id] + assert this_neighbors == neighbors_job + + if c == 2: + assert ( + this_neighbors["a.c"][3] + == self.project.open_job({"a": {"b": 3, "c": 3}}).id 
+ ) + elif c == 3: + assert ( + this_neighbors["a.c"][2] + == self.project.open_job({"a": {"b": 2, "c": 2}}).id + ) + + def test_neighbors_nested(self): + a_vals = [{"c": 2}, {"c": 3}, {"c": 4}, {"c": "5"}, {"c": "hello"}] + for a in a_vals: + self.project.open_job({"a": a}).init() + + neighbor_list = self.project.get_neighbors() + + for a in a_vals: + job = self.project.open_job({"a": a}) + neighbors_job = job._get_neighbors() + + c = a["c"] + + this_neighbors = neighbor_list[job.id] + assert this_neighbors == neighbors_job + # note how the inconsistency in neighborlist access syntax comes from schema + if c == 2: + assert ( + this_neighbors["a.c"][3] + == self.project.open_job({"a": {"c": 3}}).id + ) + elif c == 3: + assert ( + this_neighbors["a.c"][2] + == self.project.open_job({"a": {"c": 2}}).id + ) + assert ( + this_neighbors["a.c"][4] + == self.project.open_job({"a": {"c": 4}}).id + ) + elif c == 4: + assert ( + this_neighbors["a.c"][3] + == self.project.open_job({"a": {"c": 3}}).id + ) + assert ( + this_neighbors["a.c"]["5"] + == self.project.open_job({"a": {"c": "5"}}).id + ) + elif c == "5": + assert ( + this_neighbors["a.c"][4] + == self.project.open_job({"a": {"c": 4}}).id + ) + assert ( + this_neighbors["a.c"]["hello"] + == self.project.open_job({"a": {"c": "hello"}}).id + ) + + def test_neighbors_disjoint_ignore(self): + for a, b in product([1, 2, 3], [5, 6, 7]): + self.project.open_job({"a": a, "b": b, "2b": 2 * b}).init() + for x in [{"n": "nested"}, {"n": "values"}]: + self.project.open_job({"x": x}).init() + + neighbor_list = self.project.get_neighbors(ignore=["2b"]) + + job = self.project.open_job({"x": {"n": "nested"}}) + neighbors_job = job._get_neighbors(ignore=["2b"]) + neighbors_project = neighbor_list[job.id] + + assert neighbors_project == neighbors_job + assert ( + neighbors_project["x.n"]["values"] + == self.project.open_job({"x": {"n": "values"}}).id + ) + + def test_neighbors_varied_types(self): + # in sort order + # NoneType is first 
because it's capitalized + a_vals = [None, False, True, 1.2, 1.3, 2, "1", "2", "x", "y", (3, 4), (5, 6)] + + job_ids = [] + for a in a_vals: + job = self.project.open_job({"a": a}).init() + job_ids.append(job.id) + + neighbor_list = self.project.get_neighbors() + + for i, a in enumerate(a_vals): + jobid = job_ids[i] + job = self.project.open_job(id=jobid) + + neighbors_job = job._get_neighbors() + this_neighbors = neighbor_list[jobid] + assert this_neighbors == neighbors_job + if i > 0: + prev_val = a_vals[i - 1] + assert this_neighbors["a"][prev_val] == job_ids[i - 1] + if i < len(a_vals) - 1: + next_val = a_vals[i + 1] + assert this_neighbors["a"][next_val] == job_ids[i + 1] + + def test_neighbors_no(self): + self.project.open_job({"a": 1}).init() + self.project.open_job({"b": 1}).init() + neighbor_list = self.project.get_neighbors() + + for job in self.project: + for v in neighbor_list[job.id].values(): + assert len(v) == 0 + for v in job._get_neighbors().values(): + assert len(v) == 0 + + def test_neighbors_ignore_dups(self): + a_vals = [1, 2] + b_vals = [3, 4, 5] + for a, b in product(a_vals, b_vals): + self.project.open_job({"a": a, "b": b}).init() + # match with single quote to avoid matching on the a in "makes" + with pytest.raises(ValueError, match="'a'"): + self.project.get_neighbors(ignore="a") + with pytest.raises(ValueError, match="'b'"): + self.project.get_neighbors(ignore="b") + for job in self.project: + with pytest.raises(ValueError, match="'a'"): + job._get_neighbors(ignore="a") + with pytest.raises(ValueError, match="'b'"): + job._get_neighbors(ignore="b") diff --git a/tests/test_project.py b/tests/test_project.py index 46957a392..e2679948e 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -2373,6 +2373,10 @@ def test_no_migration(self): assert len(migrations) == 0 +class TestProjectNeighbors(TestProjectBase): + pass + + def _initialize_v1_project(dirname, with_workspace=True, with_other_files=True): # Create v1 config file. 
cfg_fn = os.path.join(dirname, "signac.rc") diff --git a/tests/test_shell.py b/tests/test_shell.py index 172fb1e45..e906090bd 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -6,6 +6,7 @@ import shutil import subprocess import sys +from itertools import product from tempfile import TemporaryDirectory import pytest @@ -425,6 +426,41 @@ def test_schema(self): out = self.call("python -m signac schema".split()) assert s.format() == out.strip().replace(os.linesep, "\n") + def test_neighbors_ignore_nested(self): + self.call("python -m signac init".split()) + project = signac.Project() + a_vals = [{"b": 2, "c": 2}, {"b": 3, "c": 3}] + for a in a_vals: + project.open_job({"a": a}).init() + neighbor_list = project.get_neighbors(ignore="a.b") + for job in project: + out = self.call(f"python -m signac neighbors {job.id} --ignore a.b".split()) + assert str(neighbor_list[job.id]) in out + + def test_neighbors_ignore_not_present(self): + self.call("python -m signac init".split()) + project = signac.Project() + job = project.open_job({"a": 1}).init() + out = self.call( + f"python -m signac neighbors {job.id} --ignore not_in_project".split(), + error=True, + ) + assert "not_in_project" in out + assert "not present" in out + + def test_neighbors_ignore(self): + self.call("python -m signac init".split()) + project = signac.Project() + for a, b in product([1, 2], [2, 3]): + job = project.open_job({"a": a, "b": b}).init() + out = self.call( + f"python -m signac neighbors {job.id} --ignore b".split(), + error=True, + raise_error=False, + ) + assert "impossible to distinguish" in out + assert "'b'" in out + def test_sync(self): project_b = signac.init_project(path=os.path.join(self.tmpdir.name, "b")) self.call("python -m signac init".split())