diff --git a/src/DataValidation.py b/src/DataValidation.py index d9c0970..8775977 100644 --- a/src/DataValidation.py +++ b/src/DataValidation.py @@ -31,35 +31,31 @@ def checkItemNamesInLocationRequires(): if item.lower() == "or" or item.lower() == "and" or item == ")" or item == "(": continue else: - # if it's a category, validate that the category exists - if '@' in item: - item = item.replace("|", "") - item_parts = item.split(":") - item_name = item + is_category = '|@' in item + item: str = item.lstrip('|@').rstrip('|') + item_parts = item.rsplit(":", 1) + item_name = item - if len(item_parts) > 1: - item_name = item_parts[0] + if len(item_parts) > 1: + item_name = item_parts[0] + item_count = item_parts[1] + if not item_count.isnumeric() and item_count not in ["all", "half"] and not item_count.endswith('%'): + logging.debug(f'Invalid item_count "{item_count}" found, reverting to initial item_name "{item}"') + item_name = item - item_name = item_name[1:] + # if it's a category, validate that the category exists + if is_category: item_category_exists = len([item for item in DataValidation.item_table if item_name in item.get('category', [])]) > 0 if not item_category_exists: raise ValidationError("Item category %s is required by location %s but is misspelled or does not exist." % (item_name, location.get("name"))) continue + else: + item_exists = len([item.get("name") for item in DataValidation.item_table_with_events if item.get("name") == item_name]) > 0 - item = item.replace("|", "") - - item_parts = item.split(":") - item_name = item - - if len(item_parts) > 1: - item_name = item_parts[0] - - item_exists = len([item.get("name") for item in DataValidation.item_table_with_events if item.get("name") == item_name]) > 0 - - if not item_exists: - raise ValidationError("Item %s is required by location %s but is misspelled or does not exist." % (item_name, location.get("name"))) + if not item_exists: + raise ValidationError("Item %s is required by location %s but is misspelled or does not exist." % (item_name, location.get("name"))) else: # item access is in dict form for item in location["requires"]: @@ -107,35 +103,30 @@ def checkItemNamesInRegionRequires(): if item.lower() == "or" or item.lower() == "and" or item == ")" or item == "(": continue else: - # if it's a category, validate that the category exists - if '@' in item: - item = item.replace("|", "") - item_parts = item.split(":") - item_name = item + is_category = '|@' in item + item: str = item.lstrip('|@').rstrip('|') + item_parts = item.rsplit(":", 1) + item_name = item - if len(item_parts) > 1: - item_name = item_parts[0] + if len(item_parts) > 1: + item_name = item_parts[0] + item_count = item_parts[1] + if not item_count.isnumeric() and item_count not in ["all", "half"] and not item_count.endswith('%'): + logging.debug(f'Invalid item_count "{item_count}" found, reverting to initial item_name "{item}"') + item_name = item - item_name = item_name[1:] + # if it's a category, validate that the category exists + if is_category: item_category_exists = len([item for item in DataValidation.item_table if item_name in item.get('category', [])]) > 0 if not item_category_exists: raise ValidationError("Item category %s is required by region %s but is misspelled or does not exist." % (item_name, region_name)) - continue - - item = item.replace("|", "") - - item_parts = item.split(":") - item_name = item + else: + item_exists = len([item.get("name") for item in DataValidation.item_table_with_events if item.get("name") == item_name]) > 0 - if len(item_parts) > 1: - item_name = item_parts[0] - - item_exists = len([item.get("name") for item in DataValidation.item_table_with_events if item.get("name") == item_name]) > 0 - - if not item_exists: - raise ValidationError("Item %s is required by region %s but is misspelled or does not exist." % (item_name, region_name)) + if not item_exists: + raise ValidationError("Item %s is required by region %s but is misspelled or does not exist." % (item_name, region_name)) else: # item access is in dict form for item in region["requires"]: diff --git a/src/Rules.py b/src/Rules.py index ddbf007..bb844f5 100644 --- a/src/Rules.py +++ b/src/Rules.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Any from enum import IntEnum from operator import eq, ge, le @@ -45,10 +45,10 @@ def construct_logic_error(location_or_region: dict, source: LogicErrorSource) -> return KeyError(f"Invalid 'requires' for {object_type} '{object_name}': {source_text} (ERROR {source})") -def infix_to_postfix(expr, location): - prec = {"&": 2, "|": 2, "!": 3} - stack = [] - postfix = "" +def infix_to_postfix(expr: str, location: dict) -> str: + prec: dict[str, int] = {"&": 2, "|": 2, "!": 3} + stack: list[str] = [] + postfix: str = "" try: for c in expr: @@ -73,8 +73,8 @@ def infix_to_postfix(expr, location): return postfix -def evaluate_postfix(expr: str, location: str) -> bool: - stack = [] +def evaluate_postfix(expr: str, location: dict) -> bool: + stack: list[bool] = [] try: for c in expr: @@ -158,65 +158,60 @@ def findAndRecursivelyExecuteFunctions(requires_list: str, recursionDepth: int = # parse user written statement into list of each item for item in re.findall(r'\|[^|]+\|', requires_list): - require_type = 'item' + if item not in requires_list: + # previous instance of this item was already processed + continue - if '|@' in item: - require_type = 'category' + require_category = '|@' in item item_base = item - item = item.lstrip('|@$').rstrip('|') + item: str = item.lstrip('|@$').rstrip('|') - item_parts = item.split(":") # type: list[str] + item_parts: list[str] = item.rsplit(":", 1) item_name = item - item_count = "1" + item_count: str | int = "1" if len(item_parts) > 1: item_name = item_parts[0].strip() item_count = item_parts[1].strip() + # If invalid count assume its actually part of the item name + if not item_count.isnumeric() and item_count not in ["all", "half"] and not item_count.endswith('%'): + item_name = item + item_count = "1" + total = 0 + valid_items: list[str] = [] + if require_category: + valid_items.extend([item["name"] for item in world.item_name_to_item.values() if "category" in item and item_name in item["category"]]) + valid_items.extend([event["name"] for event in world.event_name_to_event.values() if "category" in event and item_name in event["category"]]) + else: + valid_items.append(item_name) - if require_type == 'category': - category_items = [item for item in world.item_name_to_item.values() if "category" in item and item_name in item["category"]] - category_items += [event for event in world.event_name_to_event.values() if "category" in event and item_name in event["category"]] - category_items_counts = sum([items_counts.get(category_item["name"], 0) for category_item in category_items]) - if item_count.lower() == 'all': - item_count = category_items_counts - elif item_count.lower() == 'half': - item_count = int(category_items_counts / 2) - elif item_count.endswith('%') and len(item_count) > 1: - percent = clamp(float(item_count[:-1]) / 100, 0, 1) - item_count = math.ceil(category_items_counts * percent) - else: - try: - item_count = int(item_count) - except ValueError as e: - raise ValueError(f"Invalid item count `{item_name}` in {area}.") from e - - for category_item in category_items: - total += state.count(category_item["name"], player) - - if total >= item_count: - requires_list = requires_list.replace(item_base, "1") - elif require_type == 'item': - item_current_count = items_counts.get(item_name, 0) - if item_count.lower() == 'all': - item_count = item_current_count - elif item_count.lower() == 'half': - item_count = int(item_current_count / 2) - elif item_count.endswith('%') and len(item_count) > 1: - percent = clamp(float(item_count[:-1]) / 100, 0, 1) - item_count = math.ceil(item_current_count * percent) - else: - item_count = int(item_count) + item_current_count = sum([items_counts.get(valid_item, 0) for valid_item in valid_items]) - total = state.count(item_name, player) + if item_count.lower() == 'all': + item_count = item_current_count + elif item_count.lower() == 'half': + item_count = int(item_current_count / 2) + elif item_count.endswith('%') and len(item_count) > 1: + percent = clamp(float(item_count[:-1]) / 100, 0, 1) + item_count = math.ceil(item_current_count * percent) + + try: + item_count = int(item_count) + except ValueError as e: + raise ValueError(f"Invalid item count `{item_name}` in {area}.") from e + + for valid_item in valid_items: + total += state.count(valid_item, player) if total >= item_count: requires_list = requires_list.replace(item_base, "1") + break - if total <= item_count: + if total < item_count: requires_list = requires_list.replace(item_base, "0") requires_list = re.sub(r'\s?\bAND\b\s?', '&', requires_list, count=0, flags=re.IGNORECASE) @@ -346,7 +341,7 @@ def allRegionsAccessible(state): # Victory requirement multiworld.completion_condition[player] = lambda state: state.has("__Victory__", player) - def convert_req_function_args(state: CollectionState, func, args: list[str], areaName: str): + def convert_req_function_args(state: CollectionState, func, args: list[str| Any], areaName: str): parameters = inspect.signature(func).parameters knownParameters = [World, 'ManualWorld', MultiWorld, CollectionState] index = -1