diff --git a/src/devana/code_generation/printers/default/__init__.py b/src/devana/code_generation/printers/default/__init__.py index 8b52268..0649d16 100644 --- a/src/devana/code_generation/printers/default/__init__.py +++ b/src/devana/code_generation/printers/default/__init__.py @@ -23,3 +23,4 @@ from .functionprinter import FunctionPrinter from .functiontypeprinter import FunctionTypePrinter from .attributeprinter import AttributePrinter, AttributeDeclarationPrinter +from .conceptprinter import ConceptPrinter diff --git a/src/devana/code_generation/printers/default/classprinter.py b/src/devana/code_generation/printers/default/classprinter.py index 2baa5de..9d1aa1c 100644 --- a/src/devana/code_generation/printers/default/classprinter.py +++ b/src/devana/code_generation/printers/default/classprinter.py @@ -181,9 +181,15 @@ def print(self, source: ClassInfo, config: Optional[PrinterConfiguration] = None parameters.append(self.printer_dispatcher.print(p, config, source)) parameters = ','.join(parameters) template_prefix = f"template<{parameters}>" + if source.template.requires: + template_prefix += " requires" + for req in source.template.requires: + if isinstance(req, str): + template_prefix += f" {req}" + continue + template_prefix += f" {self.printer_dispatcher.print(req, config, source)}" specialisation_values = [] - for s in source.template.specialisation_values: if isinstance(s, str): specialisation_values.append(s) diff --git a/src/devana/code_generation/printers/default/conceptprinter.py b/src/devana/code_generation/printers/default/conceptprinter.py new file mode 100644 index 0000000..b19b0b6 --- /dev/null +++ b/src/devana/code_generation/printers/default/conceptprinter.py @@ -0,0 +1,33 @@ +from typing import Optional +from devana.code_generation.printers.icodeprinter import ICodePrinter +from devana.syntax_abstraction.conceptinfo import ConceptInfo +from devana.code_generation.printers.dispatcherinjectable import DispatcherInjectable +from devana.code_generation.printers.configuration import PrinterConfiguration +from devana.code_generation.printers.formatter import Formatter + + +class ConceptPrinter(ICodePrinter, DispatcherInjectable): + """Printer for concept declaration.""" + + def print(self, source: ConceptInfo, config: Optional[PrinterConfiguration] = None, + context: Optional = None) -> str: + if config is None: + config = PrinterConfiguration() + formatter = Formatter(config) + if source.is_requirement: + parameters = [] + for p in source.parameters: + parameters.append(self.printer_dispatcher.print(p, config, source)) + parameters = ', '.join(parameters) + if len(parameters) > 0: + return f"{source.name}<{parameters}>" + return source.name + parameters = [] + for p in source.template.parameters: + parameters.append(self.printer_dispatcher.print(p, config, source)) + parameters = ', '.join(parameters) + template_prefix = f"template<{parameters}>" + formatter.print_line(template_prefix) + + formatter.print_line(f"concept {source.name} = {source.body};") + return formatter.text diff --git a/src/devana/code_generation/printers/default/defaultprinter.py b/src/devana/code_generation/printers/default/defaultprinter.py index 7ea58ac..91f236f 100644 --- a/src/devana/code_generation/printers/default/defaultprinter.py +++ b/src/devana/code_generation/printers/default/defaultprinter.py @@ -14,6 +14,8 @@ from devana.code_generation.printers.default.commentprinter import CommentPrinter from devana.code_generation.printers.default.functiontypeprinter import FunctionTypePrinter from devana.code_generation.printers.default.stubtypeprinter import StubTypePrinter +from devana.code_generation.printers.default.conceptprinter import ConceptPrinter +from devana.syntax_abstraction.conceptinfo import ConceptInfo from devana.syntax_abstraction.classinfo import ClassInfo from devana.syntax_abstraction.templateinfo import GenericTypeParameter from devana.syntax_abstraction.typedefinfo import TypedefInfo @@ -72,5 +74,5 @@ def create_default_printer() -> CodePrinter: printer.register(StubTypePrinter, StubType) printer.register(AttributePrinter, Attribute) printer.register(AttributeDeclarationPrinter, AttributeDeclaration) - + printer.register(ConceptPrinter, ConceptInfo) return printer diff --git a/src/devana/code_generation/printers/default/functionprinter.py b/src/devana/code_generation/printers/default/functionprinter.py index 0f71693..ceda49e 100644 --- a/src/devana/code_generation/printers/default/functionprinter.py +++ b/src/devana/code_generation/printers/default/functionprinter.py @@ -41,6 +41,13 @@ def print(self, source: FunctionInfo, config: Optional[PrinterConfiguration] = N parameters.append(self.printer_dispatcher.print(p, config, source)) parameters = ','.join(parameters) template_prefix = f"template<{parameters}>" + if source.template.requires: + template_prefix += " requires" + for req in source.template.requires: + if isinstance(req, str): + template_prefix += f" {req}" + continue + template_prefix += f" {self.printer_dispatcher.print(req, config, source)}" specialisation_values = [] @@ -71,6 +78,13 @@ def print(self, source: FunctionInfo, config: Optional[PrinterConfiguration] = N result = f"{return_type} {name}{template_suffix}({args})" else: result = f"{name}{template_suffix}({args})" + if source.requires is not None: + result += " requires" + for req in source.requires: + if isinstance(req, str): + result += f" {req}" + continue + result += f" {self.printer_dispatcher.print(req, config, source)}" if source.modification.is_static: result = "static " + result diff --git a/src/devana/code_generation/printers/default/templateparameterprinter.py b/src/devana/code_generation/printers/default/templateparameterprinter.py index 44c3d09..d1d2766 100644 --- a/src/devana/code_generation/printers/default/templateparameterprinter.py +++ b/src/devana/code_generation/printers/default/templateparameterprinter.py @@ -1,14 +1,18 @@ from devana.code_generation.printers.icodeprinter import ICodePrinter from devana.syntax_abstraction.templateinfo import TemplateInfo from devana.code_generation.printers.dispatcherinjectable import DispatcherInjectable +from devana.syntax_abstraction.conceptinfo import ConceptInfo class TemplateParameterPrinter(ICodePrinter, DispatcherInjectable): """Printer for template parameter.""" def print(self, source: TemplateInfo.TemplateParameter, _1=None, _2=None) -> str: + if isinstance(source.specifier, ConceptInfo): + text = f"{self.printer_dispatcher.print(source.specifier)} {source.name}" + else: + text = f"{source.specifier} {source.name}" - text = f"{source.specifier} {source.name}" if source.is_variadic: return f"{text}..." if source.default_value: diff --git a/src/devana/code_generation/printers/default/usingprinter.py b/src/devana/code_generation/printers/default/usingprinter.py index 7baf106..6db9855 100644 --- a/src/devana/code_generation/printers/default/usingprinter.py +++ b/src/devana/code_generation/printers/default/usingprinter.py @@ -21,6 +21,24 @@ def print(self, source: Using, config: Optional[PrinterConfiguration] = None, if source.associated_comment: formatter.print_line(self.printer_dispatcher.print(source.associated_comment, config, source)) + + template_prefix = "" + if source.template: + parameters = [] + for p in source.template.parameters: + parameters.append(self.printer_dispatcher.print(p, config, source)) + parameters = ','.join(parameters) + template_prefix = f"template<{parameters}>" + if source.template.requires: + template_prefix += " requires" + for req in source.template.requires: + if isinstance(req, str): + template_prefix += f" {req}" + continue + template_prefix += f" {self.printer_dispatcher.print(req, config, source)}" + if template_prefix: + formatter.print_line(template_prefix) + formatter.line = f"using {source.name} = {self.printer_dispatcher.print(source.type_info, config, source)};" formatter.next_line() return formatter.text diff --git a/src/devana/syntax_abstraction/classinfo.py b/src/devana/syntax_abstraction/classinfo.py index 4269f25..5c82acd 100644 --- a/src/devana/syntax_abstraction/classinfo.py +++ b/src/devana/syntax_abstraction/classinfo.py @@ -12,6 +12,7 @@ from devana.syntax_abstraction.comment import Comment from devana.syntax_abstraction.attribute import DescriptiveByAttributes, AttributeDeclaration from devana.syntax_abstraction._external_source import create_external +from devana.syntax_abstraction.conceptinfo import ConceptInfo from devana.utility.lazy import LazyNotInit, lazy_invoke from devana.utility.init_params import init_params from devana.utility.traits import IBasicCreatable, ICursorValidate, IFromCursorCreatable, IFromParamsCreatable @@ -207,6 +208,7 @@ def from_params( # pylint: disable=unused-argument, arguments-renamed template: Optional[TemplateInfo] = None, associated_comment: Optional[Comment] = None, prefix: Optional[str] = None, + requires: Optional[List[Union[str, ConceptInfo]]] = None, access_specifier: Optional[AccessSpecifier] = None, type: Optional[MethodType] = None, # noqa pylint: disable=redefined-builtin ) -> "MethodInfo": @@ -305,6 +307,7 @@ def from_params( # pylint: disable=unused-argument, arguments-renamed template: Optional[TemplateInfo] = None, associated_comment: Optional[Comment] = None, prefix: Optional[str] = None, + requires: Optional[List[Union[str, ConceptInfo]]] = None, access_specifier: Optional[AccessSpecifier] = None, type: Optional[MethodType] = None, # noqa pylint: disable=redefined-builtin initializer_list: Optional[List[InitializerInfo]] = None, @@ -416,6 +419,7 @@ def from_params( # pylint: disable=unused-argument, arguments-differ template: Optional[TemplateInfo] = None, associated_comment: Optional[Comment] = None, prefix: Optional[str] = None, + requires: Optional[List[Union[str, ConceptInfo]]] = None, access_specifier: Optional[AccessSpecifier] = None, ) -> "DestructorInfo": return cls(None, parent) diff --git a/src/devana/syntax_abstraction/conceptinfo.py b/src/devana/syntax_abstraction/conceptinfo.py new file mode 100644 index 0000000..dfa3b53 --- /dev/null +++ b/src/devana/syntax_abstraction/conceptinfo.py @@ -0,0 +1,143 @@ +from __future__ import annotations +from typing import Optional, List, Any +from clang import cindex + +from devana.syntax_abstraction.organizers.codecontainer import CodeContainer +from devana.syntax_abstraction.typeexpression import TypeExpression +from devana.syntax_abstraction.codepiece import CodePiece +from devana.syntax_abstraction.syntax import ISyntaxElement +from devana.utility.lazy import LazyNotInit, lazy_invoke +from devana.utility.init_params import init_params +from devana.utility.errors import ParserError + + +class ConceptInfo(CodeContainer): + """Represents a C++ concept, either as a full definition or as a requirement.""" + + def __init__(self, cursor: Optional[cindex.Cursor] = None, parent: Optional[CodeContainer] = None): + super().__init__(cursor, parent) + if cursor is None: + from devana.syntax_abstraction.templateinfo import TemplateInfo # pylint: disable=import-outside-toplevel + + self._name = "DefaultConcept" + self._body = "true" + self._template = TemplateInfo.from_params(parameters=[ + TemplateInfo.TemplateParameter.create_default() + ]) + self._parameters = [] + self._is_requirement = False + else: + if not self.is_cursor_valid(cursor): + raise ParserError(f"It is not a valid type cursor: {cursor.kind}.") + self._name = LazyNotInit + self._body = LazyNotInit + self._template = LazyNotInit + self._parameters = LazyNotInit + self._is_requirement = LazyNotInit + + def __repr__(self): + return f"{type(self).__name__}:{self.name} ({super().__repr__()})" + + @classmethod + def create_default(cls, parent: Optional = None) -> "ConceptInfo": + return cls(None, parent) + + @classmethod + def from_cursor(cls, cursor: cindex.Cursor, parent: Optional = None) -> Optional["ConceptInfo"]: + if cls.is_cursor_valid(cursor): + return cls(cursor, parent) + return None + + @classmethod + @init_params(skip={"parent"}) + def from_params( # pylint: disable=unused-argument + cls, + parent: Optional[ISyntaxElement] = None, + content: Optional[List[Any]] = None, + namespace: Optional[str] = None, + name: Optional[str] = None, + body: Optional[str] = None, + template: Optional[ISyntaxElement] = None, + parameters: Optional[List[TypeExpression]] = None, + is_requirement: Optional[bool] = None + ) -> "ConceptInfo": + return cls(None, parent) + + @staticmethod + def is_cursor_valid(cursor: cindex.Cursor) -> bool: + return cursor.kind == cindex.CursorKind.CONCEPT_DECL + + @property + @lazy_invoke + def name(self) -> str: + self._name = self._cursor.spelling + return self._name + + @name.setter + def name(self, value) -> None: + self._name = value + + @property + @lazy_invoke + def template(self) -> ISyntaxElement: + """Template associated with this concept.""" + from devana.syntax_abstraction.templateinfo import TemplateInfo # pylint: disable=import-outside-toplevel + self._template = TemplateInfo.from_cursor(self._cursor) + return self._template + + @template.setter + def template(self, value: ISyntaxElement) -> None: + self._template = value + + @property + @lazy_invoke + def body(self) -> str: + """The body of the concept, which defines its constraint expression.""" + self._body = "" + for child in self._cursor.get_children(): + if child.kind != cindex.CursorKind.TEMPLATE_TYPE_PARAMETER: + self._body = CodePiece(child).text + break + return self._body + + @body.setter + def body(self, value: str) -> None: + self._body = value + + @property + @lazy_invoke + def parameters(self) -> List[TypeExpression]: + """Retrieves the concept parameters '<...>'.""" + from devana.syntax_abstraction.templateinfo import TemplateInfo # pylint: disable=import-outside-toplevel + if not isinstance(self.parent, TemplateInfo.TemplateParameter): + return [] + # Probably without a cursor from the parent it will not be possible to extract it. + # I get a mental breakdown every time I see the number -1, when i want to extract parameters in the normal way. + # Fuck it for now. + return [] + + @parameters.setter + def parameters(self, value: List[TypeExpression]) -> None: + self._parameters = value + + @property + @lazy_invoke + def is_requirement(self) -> bool: + """Determines whether this ConceptInfo instance is acting as a requirement.""" + from devana.syntax_abstraction.functioninfo import FunctionInfo # pylint: disable=import-outside-toplevel + from devana.syntax_abstraction.templateinfo import TemplateInfo # pylint: disable=import-outside-toplevel + self._is_requirement = isinstance( + self.parent, ( + TemplateInfo.TemplateParameter, + TemplateInfo, FunctionInfo + ) + ) + return self._is_requirement + + @is_requirement.setter + def is_requirement(self, value: bool) -> None: + self._is_requirement = value + + @property + def _content_types(self) -> List: + return [ConceptInfo] diff --git a/src/devana/syntax_abstraction/functioninfo.py b/src/devana/syntax_abstraction/functioninfo.py index 575656e..5f31d83 100644 --- a/src/devana/syntax_abstraction/functioninfo.py +++ b/src/devana/syntax_abstraction/functioninfo.py @@ -1,7 +1,8 @@ -from typing import Optional, Tuple, List, Any, Union +from typing import Optional, Tuple, List, Any, Union, Iterable from enum import auto, IntFlag import re from clang import cindex + from devana.syntax_abstraction.variable import Variable from devana.syntax_abstraction.typeexpression import TypeExpression, BasicType from devana.syntax_abstraction.organizers.lexicon import Lexicon @@ -10,6 +11,7 @@ from devana.syntax_abstraction.codepiece import CodePiece from devana.syntax_abstraction.templateinfo import TemplateInfo from devana.syntax_abstraction.attribute import DescriptiveByAttributes, AttributeDeclaration +from devana.syntax_abstraction.conceptinfo import ConceptInfo from devana.utility import FakeEnum from devana.utility.lazy import LazyNotInit, lazy_invoke from devana.utility.traits import IBasicCreatable, ICursorValidate @@ -319,6 +321,7 @@ def __init__(self, cursor: Optional[cindex.Cursor] = None, parent: Optional[Code self._template = None self._namespaces = None self._associated_comment = None + self._requires = None else: if not self.is_cursor_valid(cursor): msg = f"It is not a valid type cursor: {cursor.kind}." @@ -333,6 +336,7 @@ def __init__(self, cursor: Optional[cindex.Cursor] = None, parent: Optional[Code self._template = LazyNotInit self._namespaces = LazyNotInit self._associated_comment = LazyNotInit + self._requires = LazyNotInit self._lexicon = Lexicon.create(self) @classmethod @@ -366,6 +370,7 @@ def from_params( # pylint: disable=unused-argument template: Optional[TemplateInfo] = None, associated_comment: Optional[Comment] = None, prefix: Optional[str] = None, + requires: Optional[List[Union[str, ConceptInfo]]] = None ) -> "FunctionInfo": return cls(None, parent) @@ -660,5 +665,59 @@ def prefix(self) -> str: def prefix(self, value: str): self._prefix = value + @property + @lazy_invoke + def requires(self) -> Optional[List[Union[ConceptInfo, str]]]: + """Extracts constraints from the 'requires' clause of the function. None if absent.""" + + # Need to get rid of the template line, because it could also have a requires clause, + # which we don't want to touch, since the template has its own property. + code = re.sub( + r'(?m)^template[^\n]*(?:\r?\n(?:[ \t]+.*|requires\b.*))*\r?\n?', + '', + CodePiece(self._cursor).text + ) + match = re.search(r"requires\s+([\s\S]*)", code) + if not match: + self._requires = None + return self._requires + self._requires = [] + + def find_concepts(cursor: cindex.Cursor) -> Iterable[cindex.Cursor]: + for child in cursor.get_children(): + if child.location.line < self._cursor.location.line: + # Ignore template elements. + continue + if child.kind == cindex.CursorKind.TEMPLATE_REF and child.referenced: + yield child + if child.kind in ( + cindex.CursorKind.BINARY_OPERATOR, + cindex.CursorKind.CONCEPT_SPECIALIZATION_EXPR, + cindex.CursorKind.PAREN_EXPR + ): + yield from find_concepts(child) + + # clang does not provide detailed info for all things in the requires (e.g., 'or', 'true'), + raw_elements: List[str] = re.findall( + r'\(|\)|[^\s()<]+(?:\s*<\s*[^\s>]+(?:\s+[^\s>]+)*\s*>)?', + match.group(1) + ) + cursors: List[cindex.Cursor] = list(find_concepts(self._cursor)) + for raw_element in raw_elements: + if len(cursors) > 0 and re.search(r'<[^>]+>', raw_element): + maybe_concept = ConceptInfo.from_cursor( + cursor=cursors.pop(0).referenced, + parent=self + ) + if maybe_concept is not None: + self._requires.append(maybe_concept) + continue + self._requires.append(raw_element.strip()) + return self._requires + + @requires.setter + def requires(self, value: Optional[List[Union[ConceptInfo, str]]]) -> None: + self._requires = value + def __repr__(self): return f"{type(self).__name__}:{self.name} ({super().__repr__()})" diff --git a/src/devana/syntax_abstraction/namespaceinfo.py b/src/devana/syntax_abstraction/namespaceinfo.py index 59c1124..7a0158a 100644 --- a/src/devana/syntax_abstraction/namespaceinfo.py +++ b/src/devana/syntax_abstraction/namespaceinfo.py @@ -99,8 +99,9 @@ def _content_types(self) -> List: from devana.syntax_abstraction.variable import GlobalVariable from devana.syntax_abstraction.unioninfo import UnionInfo from devana.syntax_abstraction.externc import ExternC + from devana.syntax_abstraction.conceptinfo import ConceptInfo types = [FunctionInfo, NamespaceInfo, UsingNamespace, ClassInfo, EnumInfo, TypedefInfo, MethodInfo, UnionInfo, - GlobalVariable, ExternC, Using] + GlobalVariable, ExternC, Using, ConceptInfo] return types def __repr__(self): diff --git a/src/devana/syntax_abstraction/organizers/sourcefile.py b/src/devana/syntax_abstraction/organizers/sourcefile.py index 83733bb..8457b8e 100644 --- a/src/devana/syntax_abstraction/organizers/sourcefile.py +++ b/src/devana/syntax_abstraction/organizers/sourcefile.py @@ -392,8 +392,9 @@ def _content_types(self) -> List: from devana.syntax_abstraction.variable import GlobalVariable from devana.syntax_abstraction.externc import ExternC from devana.syntax_abstraction.using import Using + from devana.syntax_abstraction.conceptinfo import ConceptInfo types = [ClassInfo, UnionInfo, FunctionInfo, EnumInfo, TypedefInfo, NamespaceInfo, UsingNamespace, - MethodInfo, GlobalVariable, ExternC, Using] + MethodInfo, GlobalVariable, ExternC, Using, ConceptInfo] return types def _create_content(self) -> List[Any]: diff --git a/src/devana/syntax_abstraction/templateinfo.py b/src/devana/syntax_abstraction/templateinfo.py index 5531de9..fb51993 100644 --- a/src/devana/syntax_abstraction/templateinfo.py +++ b/src/devana/syntax_abstraction/templateinfo.py @@ -1,9 +1,10 @@ import re from pathlib import Path -from typing import Optional, List, Union, Tuple, Any +from typing import Optional, List, Union, Tuple, Any, Iterable from clang import cindex from devana.syntax_abstraction.codepiece import CodePiece from devana.syntax_abstraction.typeexpression import TypeExpression, TypeModification +from devana.syntax_abstraction.conceptinfo import ConceptInfo from devana.syntax_abstraction.organizers.codecontainer import CodeContainer from devana.syntax_abstraction.organizers.lexicon import Lexicon from devana.utility.lazy import LazyNotInit, lazy_invoke @@ -13,6 +14,7 @@ from devana.syntax_abstraction.syntax import ISyntaxElement from devana.configuration import Configuration + class GenericTypeParameter(ISyntaxElement): """An unresolved generic template parameter, known idiomatically in C++ as T.""" @@ -29,7 +31,7 @@ def name(self, value): self._name = value @staticmethod - def from_cursor(type_c, cursor: cindex.Type, parent: Optional = None) -> Optional["GenericTypeParameter"]: + def from_cursor(type_c, cursor: cindex.Cursor, parent: Optional = None) -> Optional["GenericTypeParameter"]: if type_c.kind == cindex.TypeKind.UNEXPOSED: if type_c.get_num_template_arguments() > 0: return None @@ -38,6 +40,8 @@ def from_cursor(type_c, cursor: cindex.Type, parent: Optional = None) -> Optiona if c.kind == cindex.CursorKind.TYPE_REF: return GenericTypeParameter(c.type.spelling, parent) elif c.kind == cindex.CursorKind.TEMPLATE_REF: + if getattr(c, "referenced", None) and c.referenced.kind == cindex.CursorKind.CONCEPT_DECL: + return GenericTypeParameter(type_c.spelling, parent) return None text = type_c.spelling if "::" in text: @@ -59,7 +63,6 @@ class TemplateInfo(IBasicCreatable, ICursorValidate, ISyntaxElement): class TemplateParameter(IBasicCreatable, ICursorValidate, ISyntaxElement): """A description of the generic component for the type/function claim.""" - def __init__(self, cursor: Optional[cindex.Cursor] = None, parent: Optional = None): self._cursor = cursor self._parent = parent @@ -86,7 +89,7 @@ def create_default(cls, parent: Optional = None) -> "TemplateInfo.TemplateParame def from_params( # pylint: disable=unused-argument cls, parent: Optional[ISyntaxElement] = None, - specifier: Optional[str] = None, + specifier: Optional[Union[str, ConceptInfo]] = None, name: Optional[str] = None, default_value: Optional[str] = None, is_variadic: Optional[bool] = None, @@ -106,15 +109,24 @@ def is_cursor_valid(cursor: cindex.Cursor) -> bool: @property @lazy_invoke - def specifier(self) -> str: - """Keyword before name.""" + def specifier(self) -> Union[ConceptInfo, str]: + """Keyword or ConceptInfo instance preceding the name.""" + cursors = filter( + lambda c: c.kind == cindex.CursorKind.TEMPLATE_REF and c.referenced, + self._cursor.get_children() + ) + for cursor in cursors: + if maybe_concept := ConceptInfo.from_cursor(cursor=cursor.referenced, parent=self): + self._specifier = maybe_concept + return self._specifier + self._specifier = "class" if CodePiece(self._cursor).text.find("class ") != -1 else "typename" return self._specifier @specifier.setter - def specifier(self, value): - if value not in ("class", "typename"): - raise ValueError("Only class or typename specifier is allowed.") + def specifier(self, value: Union[ConceptInfo, str]): + if not isinstance(value, ConceptInfo) and value not in ("class", "typename"): + raise ValueError("Specifier must be class, typename, or an instance of ConceptInfo.") self._specifier = value @property @@ -150,7 +162,7 @@ def default_value(self, value): def is_variadic(self) -> bool: self._is_variadic = False text = CodePiece(self._cursor).text - if re.search(r"\.\.\." + self.name, text): + if re.search(r"\.\.\.\s*" + self.name, text): self._is_variadic = True return self._is_variadic @@ -181,6 +193,7 @@ def __init__(self, cursor: Optional[cindex.Cursor] = None, parent: Optional = No self._parameters = [] self._is_empty = True self._is_variadic = False + self._requires = None else: if not self.is_cursor_valid(cursor): raise ParserError("Template parameter expect FUNCTION_TEMPLATE cursor kind.") @@ -189,6 +202,7 @@ def __init__(self, cursor: Optional[cindex.Cursor] = None, parent: Optional = No self._parameters = LazyNotInit self._is_empty = LazyNotInit self._is_variadic = LazyNotInit + self._requires = LazyNotInit self._lexicon = Lexicon.create(self) @classmethod @@ -205,6 +219,7 @@ def from_params( # pylint: disable=unused-argument parameters: Optional[List[TemplateParameter]] = None, is_empty: Optional[bool] = None, lexicon: Optional[Lexicon] = None, + requires: Optional[List[Union[str, "ConceptInfo"]]] = None ) -> "TemplateInfo": return cls(None, parent) @@ -216,10 +231,16 @@ def from_cursor(cls, cursor: cindex.Cursor, parent: Optional = None) -> Optional @staticmethod def is_cursor_valid(cursor: cindex.Cursor) -> bool: - return (not (cursor.kind != cindex.CursorKind.FUNCTION_TEMPLATE) or not ( - cursor.kind != cindex.CursorKind.CLASS_TEMPLATE) or not ( - cursor.kind != cindex.CursorKind.CLASS_TEMPLATE_PARTIAL_SPECIALIZATION) or - re.search(r"template\s*<>", CodePiece(cursor).text)) + valid_cursors = ( + cindex.CursorKind.FUNCTION_TEMPLATE, + cindex.CursorKind.CLASS_TEMPLATE, + cindex.CursorKind.CLASS_TEMPLATE_PARTIAL_SPECIALIZATION, + cindex.CursorKind.CONCEPT_DECL, + cindex.CursorKind.TYPE_ALIAS_TEMPLATE_DECL + ) + return cursor.kind in valid_cursors or re.search( + r"template\s*<>", CodePiece(cursor).text + ) is not None @property @lazy_invoke @@ -428,3 +449,50 @@ def lexicon(self, value): @property def parent(self) -> ISyntaxElement: return self._parent + + @property + @lazy_invoke + def requires(self) -> Optional[List[Union[str, "ConceptInfo"]]]: + """Extracts constraints from the 'requires' clause of the template. None if absent.""" + match = re.search( + r"template\s*<[^>]+>\s*(?:\r?\n\s*)?requires\s+(?:\r?\n\s*)?(.+?)(?=\r?\n\S|$)", + CodePiece(self._cursor).text, + flags=re.DOTALL + ) + if not match: + self._requires = None + return self._requires + self._requires = [] + + def find_concepts(cursor: cindex.Cursor) -> Iterable[cindex.Cursor]: + for child in cursor.get_children(): + if child.kind == cindex.CursorKind.TEMPLATE_REF and child.referenced: + yield child + if child.kind in ( + cindex.CursorKind.BINARY_OPERATOR, + cindex.CursorKind.CONCEPT_SPECIALIZATION_EXPR, + cindex.CursorKind.PAREN_EXPR + ): + yield from find_concepts(child) + + # clang does not provide info for all things in the requires (e.g., 'or', 'true'), + raw_elements: List[str] = re.findall( + r'\(|\)|[^\s()<]+(?:\s*<\s*[^\s>]+(?:\s+[^\s>]+)*\s*>)?', + match.group(1) + ) + cursors: List[cindex.Cursor] = list(find_concepts(self._cursor)) + for raw_element in raw_elements: + if len(cursors) > 0 and re.search(r'<[^>]+>', raw_element): + maybe_concept = ConceptInfo.from_cursor( + cursor=cursors.pop(0).referenced, + parent=self + ) + if maybe_concept is not None: + self._requires.append(maybe_concept) + continue + self._requires.append(raw_element.strip()) + return self._requires + + @requires.setter + def requires(self, value: Optional[List[Union[str, "ConceptInfo"]]]) -> None: + self._requires = value diff --git a/src/devana/syntax_abstraction/typeexpression.py b/src/devana/syntax_abstraction/typeexpression.py index 33c2378..3352b26 100644 --- a/src/devana/syntax_abstraction/typeexpression.py +++ b/src/devana/syntax_abstraction/typeexpression.py @@ -419,6 +419,7 @@ def __init__(self, cursor: Optional[Union[cindex.Cursor, cindex.Type]] = None, p self._base_type_c = self._cursor.underlying_typedef_type elif self._cursor.kind == cindex.CursorKind.TYPE_ALIAS_DECL: self._base_type_c = self._cursor.type.get_canonical() + self._lexicon = Lexicon.create(self) @classmethod @@ -660,9 +661,23 @@ def template_arguments(self) -> Optional[List["TypeExpression"]]: type_c.kind is cindex.TypeKind.ELABORATED and match is None): self._template_arguments = None return self._template_arguments + + params = [] + if isinstance(self._cursor, cindex.Cursor) and self._cursor.kind == cindex.CursorKind.TYPE_ALIAS_DECL: + match = re.search(r'<([^>]+)>', CodePiece(self._cursor).text) + if match: + params = [param.strip() for param in match.group(1).split(',')] + for i in range(type_c.get_num_template_arguments()): el = type_c.get_template_argument_type(i) - self._template_arguments.append(TypeExpression(el, self)) + type_expr = TypeExpression(el, self) + if type_expr.is_generic: + match = re.match(r"type-parameter-(\d+)-(\d+)", type_expr.name) + if match and params: + param_name = params.pop(0) + type_expr._name = param_name # pylint: disable=protected-access + type_expr.details._name = param_name # pylint: disable=protected-access + self._template_arguments.append(type_expr) if not self._template_arguments: self._template_arguments = None diff --git a/src/devana/syntax_abstraction/using.py b/src/devana/syntax_abstraction/using.py index 202ed34..55fdef4 100644 --- a/src/devana/syntax_abstraction/using.py +++ b/src/devana/syntax_abstraction/using.py @@ -3,6 +3,7 @@ from devana.syntax_abstraction.codepiece import CodePiece from devana.syntax_abstraction.typeexpression import TypeExpression from devana.syntax_abstraction.organizers.codecontainer import CodeContainer +from devana.syntax_abstraction.templateinfo import TemplateInfo from devana.syntax_abstraction.comment import Comment from devana.syntax_abstraction.organizers.lexicon import Lexicon from devana.utility.errors import ParserError @@ -24,6 +25,7 @@ def __init__(self, cursor: Optional[cindex.Cursor] = None, parent: Optional[Code self._text_source = None self._name = "" self._associated_comment = None + self._template = None else: if not self.is_cursor_valid(cursor): raise ParserError("Element is not using type alias.") @@ -31,6 +33,7 @@ def __init__(self, cursor: Optional[cindex.Cursor] = None, parent: Optional[Code self._text_source = LazyNotInit self._name = LazyNotInit self._associated_comment = LazyNotInit + self._template = LazyNotInit self._lexicon = Lexicon.create(self) @classmethod @@ -46,20 +49,31 @@ def from_params( # pylint: disable=unused-argument parent: Optional[ISyntaxElement] = None, type_info: Union[TypeExpression, ISyntaxElement, None] = None, name: Optional[str] = None, + template: Optional[TemplateInfo] = None, lexicon: Optional[Lexicon] = None, - associated_comment: Optional[Comment] = None, + associated_comment: Optional[Comment] = None ) -> "Using": return cls(None, parent) @staticmethod def is_cursor_valid(cursor: cindex.Cursor) -> bool: - return cursor.kind == cindex.CursorKind.TYPE_ALIAS_DECL + return cursor.kind in ( + cindex.CursorKind.TYPE_ALIAS_DECL, + cindex.CursorKind.TYPE_ALIAS_TEMPLATE_DECL + ) @property @lazy_invoke def type_info(self) -> Union[TypeExpression, "ISyntaxElement"]: """Type alias can be true type or next typedef.""" - self._type_info = TypeExpression(self._cursor, self) + cursor = self._cursor + for child in self._cursor.get_children(): + # template using scenario + if child.kind == cindex.CursorKind.TYPE_ALIAS_DECL: + cursor = child + break + + self._type_info = TypeExpression(cursor, self) return self._type_info @type_info.setter @@ -70,7 +84,14 @@ def type_info(self, value): @lazy_invoke def name(self) -> str: """Typedef alias name.""" - self._name = self._cursor.type.get_typedef_name() + cursor = self._cursor + for child in self._cursor.get_children(): + # template using scenario + if child.kind == cindex.CursorKind.TYPE_ALIAS_DECL: + cursor = child + break + + self._name = cursor.type.get_typedef_name() return self._name @name.setter @@ -84,6 +105,17 @@ def text_source(self) -> Optional[CodePiece]: self._text_source = CodePiece(self._cursor) return self._text_source + @property + @lazy_invoke + def template(self) -> Optional[TemplateInfo]: + """Template associated with this using.""" + self._template = TemplateInfo.from_cursor(self._cursor) + return self._template + + @template.setter + def template(self, value: Optional[TemplateInfo]) -> None: + self._template = value + @property def parent(self) -> CodeContainer: """Structural parent element like file, namespace or class.""" diff --git a/tests/code_generation/unit/test_concept.py b/tests/code_generation/unit/test_concept.py new file mode 100644 index 0000000..b6a17e6 --- /dev/null +++ b/tests/code_generation/unit/test_concept.py @@ -0,0 +1,233 @@ +from devana.code_generation.printers.default.templateparameterprinter import TemplateParameterPrinter +from devana.code_generation.printers.default.conceptprinter import ConceptPrinter +from devana.code_generation.printers.default.classprinter import ClassPrinter, FieldPrinter, MethodPrinter +from devana.code_generation.printers.default.typeexpressionprinter import TypeExpressionPrinter +from devana.code_generation.printers.default.basictypeprinter import BasicTypePrinter +from devana.code_generation.printers.default.functionprinter import FunctionPrinter +from devana.syntax_abstraction.templateinfo import GenericTypeParameter +from devana.syntax_abstraction.conceptinfo import ConceptInfo +from devana.syntax_abstraction.functioninfo import FunctionInfo +from devana.syntax_abstraction.classinfo import ClassInfo, FieldInfo, MethodInfo +from devana.syntax_abstraction.templateinfo import TemplateInfo +from devana.syntax_abstraction.typeexpression import TypeExpression, BasicType +from devana.code_generation.printers.codeprinter import CodePrinter +import unittest + + +class TestConceptAlone(unittest.TestCase): + + def setUp(self): + self.printer = CodePrinter() + self.printer.register(ConceptPrinter, ConceptInfo) + self.printer.register(TemplateParameterPrinter, TemplateInfo.TemplateParameter) + self.printer.register(TypeExpressionPrinter, TypeExpression) + self.printer.register(BasicTypePrinter, GenericTypeParameter) + self.printer.register(BasicTypePrinter, BasicType) + + def test_print_simple_concept(self): + concept = ConceptInfo.from_params( + name="TestConcept", + body="T{}" + ) + result = self.printer.print(concept) + self.assertEqual(result, "template\nconcept TestConcept = T{};\n") + + def test_print_complex_body_concept(self): + concept = ConceptInfo.from_params( + name="ComplexConcept", + body="true or false or && requires(A a) {\n a++;\n}", + template=TemplateInfo.from_params( + parameters=[ + TemplateInfo.TemplateParameter.from_params( + name="A", specifier="class" + ) + ] + ) + ) + result = self.printer.print(concept) + self.assertEqual(result, """template +concept ComplexConcept = true or false or && requires(A a) { + a++; +}; +""") + def test_print_concept_template_params(self): + concept = ConceptInfo.from_params( + template=TemplateInfo.from_params( + parameters=[ + TemplateInfo.TemplateParameter.from_params( + name="A", + specifier="class", + default_value="10" + ), + TemplateInfo.TemplateParameter.from_params( + name="Args", + specifier="typename", + is_variadic=True + ) + ] + ) + ) + result = self.printer.print(concept) + self.assertEqual(result, """template +concept DefaultConcept = true; +""") + def test_print_concept_skip_requires(self): + concept = ConceptInfo.create_default() + concept.template.requires = ["true"] + result = self.printer.print(concept) + self.assertEqual(result, "template\nconcept DefaultConcept = true;\n") + + def test_print_concept_as_requirement(self): + concept = ConceptInfo.from_params(name="Test", is_requirement=True) + with self.subTest("No parameters"): + result = self.printer.print(concept) + self.assertEqual(result, "Test") + + with self.subTest("With parameters"): + concept.parameters = [ + TypeExpression.from_params(details=GenericTypeParameter("T")), + TypeExpression.create_default() + ] + result = self.printer.print(concept) + self.assertEqual(result, "Test") + +class TestConceptClass(unittest.TestCase): + + def setUp(self): + self.printer = CodePrinter() + self.printer.register(ConceptPrinter, ConceptInfo) + self.printer.register(TemplateParameterPrinter, TemplateInfo.TemplateParameter) + self.printer.register(ClassPrinter, ClassInfo) + self.printer.register(TypeExpressionPrinter, TypeExpression) + self.printer.register(BasicTypePrinter, BasicType) + self.printer.register(FieldPrinter, FieldInfo) + self.printer.register(BasicTypePrinter, GenericTypeParameter) + self.printer.register(MethodPrinter, MethodInfo) + + def test_print_simple_class_concept(self): + concept = ConceptInfo.from_params( + name="TestConcept", + is_requirement=True + ) + class_ = ClassInfo.from_params( + name="ClassConcept", + is_class=True, + is_declaration=True, + template=TemplateInfo.from_params( + parameters=[ + TemplateInfo.TemplateParameter.from_params( + name="T", + specifier=concept + ) + ] + ) + ) + result = self.printer.print(class_) + self.assertEqual(result, """template +class ClassConcept; +""") + def test_print_class_requires(self): + class_ = ClassInfo.from_params( + name="ClassRequires", + template=TemplateInfo.from_params( + requires=["true"], + parameters=[ + TemplateInfo.TemplateParameter.create_default() + ] + ), + is_declaration=True + ) + result = self.printer.print(class_) + self.assertEqual(result, """template requires true +struct ClassRequires; +""") + + def test_print_complex_concept_class(self): + concept1 = ConceptInfo.from_params( + parameters=[TypeExpression.create_default()], + is_requirement=True + ) + concept2 = ConceptInfo.from_params( + is_requirement=True, + name="Testing", + parameters=[TypeExpression.from_params(details=GenericTypeParameter("T"))] + ) + class_ = ClassInfo.from_params( + name="ComplexConceptClass", + is_class=True, + content=[ + FieldInfo.from_params( + name="atr", + type=TypeExpression.from_params( + details=GenericTypeParameter("T") + ) + ), + MethodInfo.from_params( + name="foo", + return_type=BasicType.VOID, + requires=[concept2, "and", concept1] + ) + ], + template=TemplateInfo.from_params( + requires=["true", "&&", concept2], + parameters=[ + TemplateInfo.TemplateParameter.from_params( + name="T", + specifier=concept1 + ) + ] + ) + ) + result = self.printer.print(class_) + self.assertEqual(result, """template T> requires true && Testing +class ComplexConceptClass +{ + T atr; + void foo() requires Testing and DefaultConcept; +}; +""") + +class TestConceptFunction(unittest.TestCase): + + def setUp(self): + self.printer = CodePrinter() + self.printer.register(ConceptPrinter, ConceptInfo) + self.printer.register(TemplateParameterPrinter, TemplateInfo.TemplateParameter) + self.printer.register(ClassPrinter, ClassInfo) + self.printer.register(TypeExpressionPrinter, TypeExpression) + self.printer.register(BasicTypePrinter, BasicType) + self.printer.register(FieldPrinter, FieldInfo) + self.printer.register(BasicTypePrinter, GenericTypeParameter) + self.printer.register(FunctionPrinter, FunctionInfo) + + def test_function_requires(self): + concept = ConceptInfo.from_params( + is_requirement=True, + name="Concept", + parameters=[TypeExpression.from_params(details=GenericTypeParameter("T"))] + ) + source = FunctionInfo.from_params( + template=TemplateInfo.from_params( + parameters=[TemplateInfo.TemplateParameter.create_default()], + requires=[concept, "or", "true"] + ), + body=None, + requires=["true", "||", concept] + ) + result = self.printer.print(source) + self.assertEqual(result, "template requires Concept or true\nvoid foo() requires true || Concept;\n") + + def test_function_concept(self): + source = FunctionInfo.from_params( + template=TemplateInfo.from_params( + parameters=[TemplateInfo.TemplateParameter.from_params( + name="T", + specifier=ConceptInfo.from_params( + parameters=[TypeExpression.create_default()], + is_requirement=True + ) + )] + ) + ) + result = self.printer.print(source) + self.assertEqual(result, "template T>\nvoid foo();\n") \ No newline at end of file diff --git a/tests/code_generation/unit/test_instance_creations.py b/tests/code_generation/unit/test_instance_creations.py index 1139145..9ca0cb8 100644 --- a/tests/code_generation/unit/test_instance_creations.py +++ b/tests/code_generation/unit/test_instance_creations.py @@ -5,6 +5,7 @@ from devana.syntax_abstraction.functiontype import FunctionType from devana.syntax_abstraction.variable import GlobalVariable from devana.syntax_abstraction.typedefinfo import TypedefInfo +from devana.syntax_abstraction.conceptinfo import ConceptInfo from devana.syntax_abstraction.unioninfo import UnionInfo from devana.syntax_abstraction.enuminfo import EnumInfo from devana.syntax_abstraction.externc import ExternC @@ -269,6 +270,15 @@ def test_include_creation(self): self.assertEqual(include_info.value, "string") self.assertEqual(include_info.is_standard, True) + def test_concept_creation(self): + concept_info = ConceptInfo.from_params( + name="ConceptName", + body="false" + ) + self.assertEqual(concept_info.name, "ConceptName") + self.assertEqual(concept_info.body, "false") + self.assertIsNotNone(concept_info.template) + def test_init_params(self): class A: @classmethod diff --git a/tests/code_generation/unit/test_using.py b/tests/code_generation/unit/test_using.py index 409d573..e5a9ccf 100644 --- a/tests/code_generation/unit/test_using.py +++ b/tests/code_generation/unit/test_using.py @@ -1,9 +1,13 @@ import unittest from devana.code_generation.printers.default.basictypeprinter import BasicTypePrinter from devana.code_generation.printers.default.typeexpressionprinter import TypeExpressionPrinter +from devana.code_generation.printers.default.templateparameterprinter import TemplateParameterPrinter +from devana.code_generation.printers.default.conceptprinter import ConceptPrinter from devana.code_generation.printers.default.usingprinter import UsingPrinter from devana.code_generation.printers.codeprinter import CodePrinter from devana.syntax_abstraction.using import Using +from devana.syntax_abstraction.conceptinfo import ConceptInfo +from devana.syntax_abstraction.templateinfo import TemplateInfo, GenericTypeParameter from devana.syntax_abstraction.typeexpression import BasicType, TypeExpression, TypeModification @@ -14,6 +18,9 @@ def setUp(self): printer.register(BasicTypePrinter, BasicType) printer.register(TypeExpressionPrinter, TypeExpression) printer.register(UsingPrinter, Using) + printer.register(TemplateParameterPrinter, TemplateInfo.TemplateParameter) + printer.register(BasicTypePrinter, GenericTypeParameter) + printer.register(ConceptPrinter, ConceptInfo) self.printer: CodePrinter = printer def test_definition_basic(self): @@ -24,3 +31,58 @@ def test_definition_basic(self): source.type_info.modification |= TypeModification.POINTER | TypeModification.CONST result = self.printer.print(source) self.assertEqual(result, "using const_ptr_t = const char*;\n") + + def test_using_template(self): + source = Using.from_params( + name="constType", + type_info=TypeExpression.from_params( + details=GenericTypeParameter("T"), + modification=TypeModification.CONST + ), + template=TemplateInfo.from_params( + parameters=[ + TemplateInfo.TemplateParameter.from_params( + specifier="typename", + name="T" + )] + ) + ) + result = self.printer.print(source) + self.assertEqual(result, "template\nusing constType = const T;\n") + + def test_using_template_with_requires(self): + source = Using.from_params( + name="constexprType", + type_info=TypeExpression.from_params( + details=GenericTypeParameter("B"), + modification=TypeModification.CONSTEXPR + ), + template=TemplateInfo.from_params( + parameters=[ + TemplateInfo.TemplateParameter.from_params( + specifier="typename", + name="B" + )], + requires=["true"] + ) + ) + result = self.printer.print(source) + self.assertEqual(result, "template requires true\nusing constexprType = constexpr B;\n") + + def test_using_template_with_concept(self): + source = Using.from_params( + name="ConceptPtr", + type_info=TypeExpression.from_params( + details=GenericTypeParameter("C"), + modification=TypeModification.POINTER + ), + template=TemplateInfo.from_params( + parameters=[ + TemplateInfo.TemplateParameter.from_params( + specifier=ConceptInfo.from_params(name="Concept", is_requirement=True), + name="C" + )] + ) + ) + result = self.printer.print(source) + self.assertEqual(result, "template\nusing ConceptPtr = C*;\n") diff --git a/tests/parsing/unit/source_files/advanced_template.hpp b/tests/parsing/unit/source_files/advanced_template.hpp index 99d7bf8..0e44e50 100644 --- a/tests/parsing/unit/source_files/advanced_template.hpp +++ b/tests/parsing/unit/source_files/advanced_template.hpp @@ -110,4 +110,3 @@ struct TestTemplateFields BaseArray> data2; BaseArray data3; }; - diff --git a/tests/parsing/unit/source_files/concepts.hpp b/tests/parsing/unit/source_files/concepts.hpp new file mode 100644 index 0000000..f300b5c --- /dev/null +++ b/tests/parsing/unit/source_files/concepts.hpp @@ -0,0 +1,58 @@ +template +concept ConceptCase1 = requires(T a) { + { --a }; + { a-- }; +}; + +template +concept ConceptCase2 = requires(T a, T b) { + { a + b }; +}; + +template +concept ConceptCase3 = requires(T a, T b) { + a = b; +} || requires(T a, T b) { + b = a; +}; + +template +concept ConceptCase4 = requires { + T{}; +}; + +template +concept ConceptCase5 = requires { + T(-1) < T(0); +}; + +template +concept ConceptCase6 = ConceptCase1 && requires(T t) { + *t; +}; + +template +concept ConceptCase7 = (T{} > 0); + +template +concept ConceptCase8 = ConceptCase7; + +template +concept ConceptCase9 = ConceptCase1 && ConceptCase2; + +template +concept ConceptCase10 = ConceptCase1 || ConceptCase2; + +namespace testNamespace { + template + concept ConceptCase11 = true; +}; + +template +concept ConceptCase12 = T::value || true; + +template +concept ConceptCase13 = testNamespace::ConceptCase11; + +template +concept ConceptTemplate = true; diff --git a/tests/parsing/unit/source_files/template_class.hpp b/tests/parsing/unit/source_files/template_class.hpp index cfe52cf..8bad673 100644 --- a/tests/parsing/unit/source_files/template_class.hpp +++ b/tests/parsing/unit/source_files/template_class.hpp @@ -52,11 +52,14 @@ class template_class_complex }; template -struct struct_varidaic_template +struct struct_variadic_template { int a; }; +template +struct struct_variadic_template2; + template struct template_struct { @@ -102,4 +105,29 @@ template struct multiple_pointer_struct { T** a; -} \ No newline at end of file +} + +template +concept TestConcept = true; + +template +class BasicConceptClass { + T foo(T arg); +}; + +template requires true +class ConceptClass { +public: + T a; + void process(T arg) + requires TestConcept; +}; + +template + requires false or + true and TestConcept< T> +struct ConceptStruct { + T abc = 10; + T foo() requires (TestConcept || true); + T barFoo(T arg1, A arg2); +}; \ No newline at end of file diff --git a/tests/parsing/unit/source_files/template_functions.hpp b/tests/parsing/unit/source_files/template_functions.hpp index d0f0990..75d0bdb 100644 --- a/tests/parsing/unit/source_files/template_functions.hpp +++ b/tests/parsing/unit/source_files/template_functions.hpp @@ -8,4 +8,29 @@ template const T complex_function(float a, T b, P& c, char d = '3'); template<> -const int* specialisation_function(float a, int* b, float& c, char d); \ No newline at end of file +const int* specialisation_function(float a, int* b, float& c, char d); + +template +concept AlwaysTrue = true; + +namespace test { + template + concept AlwaysFalse = false; +}; + +template +T template_concept_function(); + +template requires AlwaysTrue +void requires_template_function1(T a) requires true or false; + +template + requires true or AlwaysTrue +int requires_template_function2(T a = 1) + requires AlwaysTrue and (true); + +template requires (AlwaysTrue or true) and false +void basic_concept_function(); + +template T> +void concept_namespace_function(); diff --git a/tests/parsing/unit/source_files/using.hpp b/tests/parsing/unit/source_files/using.hpp index a3e41f1..95fb01d 100644 --- a/tests/parsing/unit/source_files/using.hpp +++ b/tests/parsing/unit/source_files/using.hpp @@ -27,4 +27,19 @@ namespace num { char x; int y; }; -} \ No newline at end of file +} + +template +struct TestStruct; + +template +concept TestConcept = true; + +template +using UsingTemplate = A; + +template requires true or TestConcept +using UsingTemplateRequires = const TestStruct; + +template +using UsingConcept = const UsingTemplateRequires*; \ No newline at end of file diff --git a/tests/parsing/unit/test_attributes.py b/tests/parsing/unit/test_attributes.py index 1cb9fcb..2529321 100644 --- a/tests/parsing/unit/test_attributes.py +++ b/tests/parsing/unit/test_attributes.py @@ -1,13 +1,12 @@ import unittest import os -from typing import List from devana.syntax_abstraction.organizers.sourcefile import SourceFile from devana.syntax_abstraction.attribute import Attribute from devana.syntax_abstraction.functioninfo import FunctionInfo from devana.syntax_abstraction.classinfo import ClassInfo, FieldInfo, MethodInfo from devana.syntax_abstraction.namespaceinfo import NamespaceInfo from devana.syntax_abstraction.enuminfo import EnumInfo - +from typing import List class TestAttributesParser(unittest.TestCase): diff --git a/tests/parsing/unit/test_class.py b/tests/parsing/unit/test_class.py index 6ea683f..9aac00a 100644 --- a/tests/parsing/unit/test_class.py +++ b/tests/parsing/unit/test_class.py @@ -432,7 +432,10 @@ class TestClassTemplate(unittest.TestCase): def setUp(self): index = clang.cindex.Index.create() - self.cursor = index.parse(os.path.dirname(__file__) + r"/source_files/template_class.hpp").cursor + self.cursor = index.parse( + os.path.dirname(__file__) + r"/source_files/template_class.hpp", + args=("-std=c++20",) + ).cursor def test_simple_template_class(self): node = find_by_name(self.cursor, "simple_template_struct_1") @@ -659,12 +662,12 @@ def test_complex_template_class(self): self.assertEqual(content.template.parameters[0].name, "D") self.assertEqual(content.template.parameters[0].default_value, None) - def test_struct_varidaic_template(self): - node = find_by_name(self.cursor, "struct_varidaic_template") + def test_struct_variadic_template(self): + node = find_by_name(self.cursor, "struct_variadic_template") result = ClassInfo.from_cursor(node) self.assertTrue(result.is_struct) self.assertFalse(result.template is None) - self.assertEqual(result.name, "struct_varidaic_template") + self.assertEqual(result.name, "struct_variadic_template") self.assertEqual(len(result.template.parameters), 3) self.assertEqual(result.template.parameters[0].name, "T") self.assertEqual(result.template.parameters[0].specifier, "typename") @@ -678,6 +681,20 @@ def test_struct_varidaic_template(self): self.assertEqual(result.template.parameters[2].specifier, "typename") self.assertEqual(result.template.parameters[2].default_value, None) self.assertTrue(result.template.parameters[2].is_variadic) + self.assertIsNone(result.template.requires) + + def test_struct_variadic_template2(self): + node = find_by_name(self.cursor, "struct_variadic_template2") + result = ClassInfo.from_cursor(node) + self.assertTrue(result.is_struct) + self.assertFalse(result.template is None) + self.assertEqual(result.name, "struct_variadic_template2") + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.template.parameters[0].name, "Args") + self.assertEqual(result.template.parameters[0].specifier, "typename") + self.assertEqual(result.template.parameters[0].default_value, None) + self.assertTrue(result.template.parameters[0].is_variadic) + self.assertIsNone(result.template.requires) def test_multiple_pointer_type_template(self): node = find_by_name(self.cursor, "multiple_pointer_struct") @@ -694,7 +711,121 @@ def test_multiple_pointer_type_template(self): self.assertTrue(content.type.is_generic) self.assertEqual(content.type.modification.pointer_order, 2) self.assertEqual(content.type.details.name, "T") + self.assertIsNone(result.template.requires) + + def test_basic_concept_class(self): + node = find_by_name(self.cursor, "BasicConceptClass") + result = ClassInfo.from_cursor(node) + self.assertTrue(result.is_class) + self.assertEqual(result.name, "BasicConceptClass") + self.assertEqual(result.template.requires, None) + self.assertEqual(len(result.template.parameters), 2) + self.assertEqual(result.template.parameters[0].name, "T") + self.assertEqual(result.template.parameters[0].specifier.name, "TestConcept") + self.assertEqual(result.template.parameters[0].specifier.body, "true") + self.assertEqual(result.template.parameters[0].default_value, "bool") + self.assertEqual(result.template.parameters[0].is_variadic, False) + self.assertEqual(result.template.parameters[1].name, "Args") + self.assertEqual(result.template.parameters[1].specifier.name, "TestConcept") + self.assertEqual(result.template.parameters[1].specifier.body, "true") + self.assertEqual(result.template.parameters[1].default_value, None) + self.assertEqual(result.template.parameters[1].is_variadic, True) + + method: MethodInfo = cast(MethodInfo, result.private[0]) + self.assertEqual(method.name, "foo") + self.assertEqual(method.type, MethodType.STANDARD) + self.assertEqual(method.return_type.details.name, "T") + self.assertEqual(method.body, None) + self.assertEqual(len(method.arguments), 1) + self.assertEqual(method.arguments[0].name, "arg") + self.assertEqual(method.arguments[0].type.is_generic, True) + self.assertEqual(method.arguments[0].type.details.name, "T") + self.assertEqual(method.requires, None) + + def test_concept_class(self): + node = find_by_name(self.cursor, "ConceptClass") + result = ClassInfo.from_cursor(node) + self.assertTrue(result.is_class) + self.assertEqual(result.name, "ConceptClass") + self.assertEqual(len(result.template.requires), 1) + self.assertEqual(result.template.requires[0], "true") + self.assertEqual(len(result.template.parameters), 2) + self.assertEqual(result.template.parameters[0].name, "T") + self.assertEqual(result.template.parameters[0].specifier.name, "TestConcept") + self.assertEqual(result.template.parameters[0].specifier.body, "true") + self.assertEqual(result.template.parameters[0].default_value, None) + self.assertEqual(result.template.parameters[0].is_variadic, False) + self.assertEqual(result.template.parameters[1].name, "B") + self.assertEqual(result.template.parameters[1].specifier, "class") + self.assertEqual(result.template.parameters[1].default_value, "int") + self.assertEqual(result.template.parameters[1].is_variadic, False) + + field: FieldInfo = cast(FieldInfo, result.public[0]) + self.assertEqual(field.name, "a") + self.assertEqual(field.type.is_generic, True) + self.assertEqual(field.type.details.name, "T") + + method: MethodInfo = cast(MethodInfo, result.public[1]) + self.assertEqual(method.name, "process") + self.assertEqual(method.type, MethodType.STANDARD) + self.assertEqual(method.return_type.details, BasicType.VOID) + self.assertEqual(method.body, None) + self.assertEqual(method.arguments[0].type.is_generic, True) + self.assertEqual(method.arguments[0].type.details.name, "T") + self.assertEqual(len(method.requires), 1) + self.assertEqual(method.requires[0].name, "TestConcept") + self.assertEqual(method.requires[0].body, "true") + self.assertEqual(method.requires[0].is_requirement, True) + + def test_concept_struct(self): + node = find_by_name(self.cursor, "ConceptStruct") + result = ClassInfo.from_cursor(node) + self.assertTrue(result.is_struct) + self.assertEqual(result.name, "ConceptStruct") + self.assertEqual(len(result.template.requires), 5) + self.assertEqual(result.template.requires[0:4], ["false", "or", "true", "and"]) + self.assertEqual(result.template.requires[4].name, "TestConcept") + self.assertEqual(result.template.requires[4].is_requirement, True) + self.assertEqual(len(result.template.parameters), 2) + self.assertEqual(result.template.parameters[0].name, "A") + self.assertEqual(result.template.parameters[0].specifier, "typename") + self.assertEqual(result.template.parameters[1].name, "T") + self.assertEqual(result.template.parameters[1].specifier.name, "TestConcept") + self.assertEqual(result.template.parameters[1].specifier.body, "true") + self.assertEqual(result.template.parameters[1].default_value, None) + field: FieldInfo = cast(FieldInfo, result.public[0]) + self.assertEqual(field.name, "abc") + self.assertEqual(field.type.is_generic, True) + self.assertEqual(field.type.details.name, "T") + self.assertEqual(field.default_value, "10") + + method: MethodInfo = cast(MethodInfo, result.public[1]) + self.assertEqual(method.name, "foo") + self.assertEqual(method.return_type.is_generic, True) + self.assertEqual(method.return_type.name, "T") + self.assertEqual(len(method.arguments), 0) + self.assertEqual(method.body, None) + self.assertEqual(len(method.requires), 5) + self.assertEqual(method.requires[0], "(") + self.assertEqual(method.requires[1].name, "TestConcept") + self.assertEqual(method.requires[1].is_requirement, True) + self.assertEqual(method.requires[1].body, "true") + self.assertEqual(method.requires[2:5], ["||", "true", ")"]) + + method: MethodInfo = cast(MethodInfo, result.public[2]) + self.assertEqual(method.name, "barFoo") + self.assertEqual(method.return_type.is_generic, True) + self.assertEqual(method.return_type.name, "T") + self.assertEqual(len(method.arguments), 2) + self.assertEqual(method.arguments[0].name, "arg1") + self.assertEqual(method.arguments[0].type.is_generic, True) + self.assertEqual(method.arguments[0].type.name, "T") + self.assertEqual(method.arguments[1].name, "arg2") + self.assertEqual(method.arguments[1].type.is_generic, True) + self.assertEqual(method.arguments[1].type.name, "A") + self.assertEqual(method.body, None) + self.assertEqual(method.requires, None) class TestClassTemplatePartial(unittest.TestCase): diff --git a/tests/parsing/unit/test_concept.py b/tests/parsing/unit/test_concept.py new file mode 100644 index 0000000..9d351cf --- /dev/null +++ b/tests/parsing/unit/test_concept.py @@ -0,0 +1,190 @@ +import unittest +from difflib import restore + +import clang +import os + +from devana.syntax_abstraction.conceptinfo import ConceptInfo +from devana.syntax_abstraction.functioninfo import FunctionInfo +from tests.helpers import find_by_name, stub_lexicon + + +class TestConcept(unittest.TestCase): + + def setUp(self): + index = clang.cindex.Index.create() + self.cursor = index.parse( + os.path.dirname(__file__) + r"/source_files/concepts.hpp", + args=("-std=c++20",) + ).cursor + + def test_concept_case_1(self): + node = find_by_name(self.cursor, "ConceptCase1") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(result.name, "ConceptCase1") + self.assertEqual(result.body.replace("\r\n", "\n"), "requires(T a) {\n { --a };\n { a-- };\n}") + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.is_requirement, False) + self.assertEqual(len(result.parameters), 0) + + def test_concept_case_2(self): + node = find_by_name(self.cursor, "ConceptCase2") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(result.name, "ConceptCase2") + self.assertEqual( + result.body.replace("\r\n", "\n"), + "requires(T a, T b) {\n { a + b };\n}" + ) + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.is_requirement, False) + self.assertEqual(len(result.parameters), 0) + + def test_concept_case_3(self): + node = find_by_name(self.cursor, "ConceptCase3") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(result.name, "ConceptCase3") + self.assertEqual( + result.body.replace("\r\n", "\n"), + "requires(T a, T b) {\n a = b;\n} || requires(T a, T b) {\n b = a;\n}" + ) + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.is_requirement, False) + self.assertEqual(len(result.parameters), 0) + + def test_concept_case_4(self): + node = find_by_name(self.cursor, "ConceptCase4") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(result.name, "ConceptCase4") + self.assertEqual(result.body.replace("\r\n", "\n"), "requires {\n T{};\n}") + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.is_requirement, False) + self.assertEqual(len(result.parameters), 0) + + def test_concept_case_5(self): + node = find_by_name(self.cursor, "ConceptCase5") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(result.name, "ConceptCase5") + self.assertEqual( + result.body.replace("\r\n", "\n"), + "requires {\n T(-1) < T(0); \n}" + ) + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.is_requirement, False) + self.assertEqual(len(result.parameters), 0) + + def test_concept_case_6(self): + node = find_by_name(self.cursor, "ConceptCase6") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(result.name, "ConceptCase6") + self.assertEqual( + result.body.replace("\r\n", "\n"), + "ConceptCase1 && requires(T t) {\n *t;\n}" + ) + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.is_requirement, False) + self.assertEqual(len(result.parameters), 0) + + def test_concept_case_7(self): + node = find_by_name(self.cursor, "ConceptCase7") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(result.name, "ConceptCase7") + self.assertEqual(result.body, "(T{} > 0)") + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.is_requirement, False) + self.assertEqual(len(result.parameters), 0) + + def test_concept_case_8(self): + node = find_by_name(self.cursor, "ConceptCase8") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(result.name, "ConceptCase8") + self.assertEqual(result.body, "ConceptCase7") + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.is_requirement, False) + self.assertEqual(result.template.parameters[0].specifier.name, "ConceptCase7") + self.assertEqual(result.template.parameters[0].specifier.is_requirement, True) + self.assertEqual(len(result.parameters), 0) + + def test_concept_case_9(self): + node = find_by_name(self.cursor, "ConceptCase9") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(result.name, "ConceptCase9") + self.assertEqual(result.body, "ConceptCase1 && ConceptCase2") + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.is_requirement, False) + self.assertEqual(len(result.parameters), 0) + + def test_concept_case_10(self): + node = find_by_name(self.cursor, "ConceptCase10") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(result.name, "ConceptCase10") + self.assertEqual(result.body, "ConceptCase1 || ConceptCase2") + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.is_requirement, False) + self.assertEqual(len(result.parameters), 0) + + def test_concept_case_11(self): + node = find_by_name(self.cursor, "ConceptCase11") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(result.name, "ConceptCase11") + self.assertEqual(result.body, "true") + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.is_requirement, False) + self.assertEqual(len(result.parameters), 0) + + def test_concept_case_12(self): + node = find_by_name(self.cursor, "ConceptCase12") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(result.name, "ConceptCase12") + self.assertEqual(result.body, "T::value || true") + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.is_requirement, False) + self.assertEqual(len(result.parameters), 0) + + def test_concept_case_13(self): + node = find_by_name(self.cursor, "ConceptCase13") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(result.name, "ConceptCase13") + self.assertEqual(result.body, "testNamespace::ConceptCase11") + self.assertEqual(len(result.template.parameters), 1) + self.assertEqual(result.is_requirement, False) + self.assertEqual(len(result.parameters), 0) + + def test_concept_template(self): + node = find_by_name(self.cursor, "ConceptTemplate") + result = ConceptInfo.from_cursor(node) + self.assertIsNone(result.parent) + self.assertEqual(len(result.parameters), 0) + self.assertEqual(result.name, "ConceptTemplate") + self.assertEqual(result.body, "true") + self.assertEqual(result.template.parent, None) + self.assertEqual(result.template.is_empty, False) + self.assertEqual(result.template.requires, None) + self.assertEqual(len(result.template.parameters), 3) + + self.assertEqual(result.template.parameters[0].name, "A") + self.assertEqual(result.template.parameters[0].specifier, "typename") + self.assertEqual(result.template.parameters[0].is_variadic, False) + self.assertEqual(result.template.parameters[0].default_value, None) + + self.assertEqual(result.template.parameters[1].name, "B") + self.assertEqual(result.template.parameters[1].specifier, "class") + self.assertEqual(result.template.parameters[1].is_variadic, False) + self.assertEqual(result.template.parameters[1].default_value, "int") + + self.assertEqual(result.template.parameters[2].name, "Args") + self.assertEqual(result.template.parameters[2].specifier, "typename") + self.assertEqual(result.template.parameters[2].is_variadic, True) + self.assertEqual(result.template.parameters[2].default_value, None) \ No newline at end of file diff --git a/tests/parsing/unit/test_function.py b/tests/parsing/unit/test_function.py index ce746f8..7244be2 100644 --- a/tests/parsing/unit/test_function.py +++ b/tests/parsing/unit/test_function.py @@ -2,8 +2,9 @@ import clang.cindex import clang import os + from tests.helpers import find_by_name, stub_lexicon -from devana.syntax_abstraction.typeexpression import BasicType, TypeModification, TypeExpression +from devana.syntax_abstraction.typeexpression import BasicType, TypeModification from devana.syntax_abstraction.functioninfo import FunctionInfo, FunctionModification from devana.syntax_abstraction.organizers.sourcefile import SourceFile from devana.utility.errors import CodeError @@ -251,7 +252,10 @@ class TestFunctionsTemplate(unittest.TestCase): def setUp(self): index = clang.cindex.Index.create() - self.cursor = index.parse(os.path.dirname(__file__) + r"/source_files/template_functions.hpp").cursor + self.cursor = index.parse( + os.path.dirname(__file__) + r"/source_files/template_functions.hpp", + args=("-std=c++20",) + ).cursor def test_common_function_template(self): node = find_by_name(self.cursor, "simple_function_typename") @@ -276,6 +280,8 @@ def test_common_function_template(self): self.assertEqual(result.template.parameters[0].name, "T") self.assertEqual(result.template.parameters[0].specifier, "typename") self.assertEqual(result.template.parameters[0].default_value, None) + self.assertEqual(result.template.requires, None) + self.assertEqual(result.requires, None) node = find_by_name(self.cursor, "simple_function_class") result = FunctionInfo.from_cursor(node) @@ -299,6 +305,8 @@ def test_common_function_template(self): self.assertEqual(result.template.parameters[0].name, "T") self.assertEqual(result.template.parameters[0].specifier, "class") self.assertEqual(result.template.parameters[0].default_value, None) + self.assertEqual(result.template.requires, None) + self.assertEqual(result.requires, None) def test_complex_function_template(self): node = find_by_name(self.cursor, "complex_function") @@ -339,12 +347,15 @@ def test_complex_function_template(self): self.assertEqual(result.template.parameters[1].specifier, "typename") self.assertEqual(result.template.parameters[1].default_value, "const float") self.assertEqual(len(result.template.specialisations), 0) + self.assertEqual(result.template.requires, None) + self.assertEqual(result.requires, None) def test_specialisation_function_template(self): node = find_by_name(self.cursor, "specialisation_function") result = FunctionInfo.from_cursor(node) stub_lexicon(result) self.assertFalse(result.template is None) + self.assertEqual(result.requires, None) self.assertTrue(result.template.is_empty) self.assertEqual(result.name, "specialisation_function") self.assertEqual(len(result.arguments), 4) @@ -372,6 +383,118 @@ def test_specialisation_function_template(self): self.assertTrue(result.return_type.modification.is_pointer) self.assertTrue(result.return_type.modification.is_const) self.assertEqual(result.body, None) + self.assertEqual(result.template.requires, None) + + def test_template_concept_function(self): + node = find_by_name(self.cursor, "template_concept_function") + result = FunctionInfo.from_cursor(node) + stub_lexicon(result) + self.assertEqual(result.name, "template_concept_function") + self.assertEqual(result.body, None) + self.assertEqual(result.return_type.is_generic, True) + self.assertEqual(result.return_type.details.name, "T") + self.assertEqual(result.requires, None) + self.assertEqual(result.template.requires, None) + self.assertEqual(len(result.template.parameters), 2) + + self.assertEqual(result.template.parameters[0].name, "T") + self.assertEqual(result.template.parameters[0].specifier.name, "AlwaysTrue") + self.assertEqual(result.template.parameters[0].specifier.is_requirement, True) + self.assertEqual(result.template.parameters[0].specifier.body, "true") + self.assertEqual(result.template.parameters[0].default_value, "int") + self.assertEqual(result.template.parameters[0].is_variadic, False) + + self.assertEqual(result.template.parameters[1].name, "Args") + self.assertEqual(result.template.parameters[1].specifier.name, "AlwaysTrue") + self.assertEqual(result.template.parameters[1].specifier.is_requirement, True) + self.assertEqual(result.template.parameters[1].specifier.body, "true") + self.assertEqual(result.template.parameters[1].default_value, None) + self.assertEqual(result.template.parameters[1].is_variadic, True) + + def test_requires_template_function1(self): + node = find_by_name(self.cursor, "requires_template_function1") + result = FunctionInfo.from_cursor(node) + stub_lexicon(result) + self.assertEqual(result.name, "requires_template_function1") + self.assertEqual(result.body, None) + self.assertEqual(result.return_type.details, BasicType.VOID) + self.assertEqual(result.requires, ["true", "or", "false"]) + + self.assertEqual(len(result.arguments), 1) + self.assertEqual(result.arguments[0].name, "a") + self.assertEqual(result.arguments[0].type.is_generic, True) + self.assertEqual(result.arguments[0].type.details.name, "T") + self.assertEqual(result.template.parameters[0].name, "T") + self.assertEqual(result.template.parameters[0].specifier, "typename") + + self.assertEqual(len(result.template.requires), 1) + self.assertEqual(result.template.requires[0].name, "AlwaysTrue") + self.assertEqual(result.template.requires[0].is_requirement, True) + self.assertEqual(result.template.requires[0].body, "true") + + def test_requires_template_function2(self): + node = find_by_name(self.cursor, "requires_template_function2") + result = FunctionInfo.from_cursor(node) + stub_lexicon(result) + self.assertEqual(result.name, "requires_template_function2") + self.assertEqual(result.body, None) + self.assertEqual(result.return_type.details, BasicType.INT) + + self.assertEqual(len(result.arguments), 1) + self.assertEqual(result.arguments[0].name, "a") + self.assertEqual(result.arguments[0].default_value, "1") + self.assertEqual(result.arguments[0].type.is_generic, True) + self.assertEqual(result.arguments[0].type.details.name, "T") + + self.assertEqual(len(result.requires), 5) + self.assertEqual(result.requires[0].name, "AlwaysTrue") + self.assertEqual(result.requires[0].is_requirement, True) + self.assertEqual(result.requires[0].body, "true") + self.assertEqual(result.requires[1:5], ["and", "(", "true", ")"]) + self.assertEqual(result.template.parameters[0].name, "T") + self.assertEqual(result.template.parameters[0].specifier.name, "AlwaysTrue") + self.assertEqual(result.template.parameters[0].specifier.body, "true") + + self.assertEqual(len(result.template.requires), 3) + self.assertEqual(result.template.requires[0], "true") + self.assertEqual(result.template.requires[1], "or") + self.assertEqual(result.template.requires[2].name, "AlwaysTrue") + self.assertEqual(result.template.requires[2].is_requirement, True) + self.assertEqual(result.template.requires[2].body, "true") + + def test_basic_concept_function(self): + node = find_by_name(self.cursor, "basic_concept_function") + result = FunctionInfo.from_cursor(node) + stub_lexicon(result) + self.assertEqual(result.name, "basic_concept_function") + self.assertEqual(result.body, None) + self.assertEqual(result.requires, None) + self.assertEqual(len(result.arguments), 0) + self.assertEqual(result.return_type.details, BasicType.VOID) + self.assertEqual(len(result.template.requires), 7) + self.assertEqual(result.template.requires[0], "(") + self.assertEqual(result.template.requires[1].name, "AlwaysTrue") + self.assertEqual(result.template.requires[1].is_requirement, True) + self.assertEqual(result.template.requires[1].body, "true") + self.assertEqual(result.template.requires[2:7], ["or", "true", ")", "and", "false"]) + + def test_namespace_concept_function(self): + node = find_by_name(self.cursor, "concept_namespace_function") + result = FunctionInfo.from_cursor(node) + stub_lexicon(result) + self.assertEqual(result.name, "concept_namespace_function") + self.assertEqual(result.body, None) + self.assertEqual(result.requires, None) + self.assertEqual(len(result.arguments), 0) + self.assertEqual(result.return_type.details, BasicType.VOID) + self.assertEqual(result.template.requires, None) + + param = result.template.parameters[0] + self.assertEqual(param.name, "T") + self.assertEqual(param.specifier.name, "AlwaysFalse") + self.assertEqual(param.specifier.is_requirement, True) + self.assertEqual(len(param.specifier.parameters), 1) + self.assertEqual(param.specifier.namespace, "test") class TestFunctionsOverload(unittest.TestCase): diff --git a/tests/parsing/unit/test_templates.py b/tests/parsing/unit/test_templates.py index b16ebcf..e60d219 100644 --- a/tests/parsing/unit/test_templates.py +++ b/tests/parsing/unit/test_templates.py @@ -3,7 +3,7 @@ import clang import os from devana.syntax_abstraction.organizers.sourcefile import SourceFile -from devana.syntax_abstraction.typeexpression import BasicType, TypeModification +from devana.syntax_abstraction.typeexpression import TypeModification from devana.syntax_abstraction.classinfo import * from devana.syntax_abstraction.typedefinfo import TypedefInfo @@ -12,7 +12,9 @@ class TestTemplateAdvanced(unittest.TestCase): def setUp(self): index = clang.cindex.Index.create() - self.cursor = index.parse(os.path.dirname(__file__) + r"/source_files/advanced_template.hpp").cursor + self.cursor = index.parse( + os.path.dirname(__file__) + r"/source_files/advanced_template.hpp" + ).cursor self.file = SourceFile.from_cursor(self.cursor) self.assertEqual(len(self.file.content), 32) diff --git a/tests/parsing/unit/test_using.py b/tests/parsing/unit/test_using.py index 63ea3c6..032978e 100644 --- a/tests/parsing/unit/test_using.py +++ b/tests/parsing/unit/test_using.py @@ -12,7 +12,10 @@ class TestUsing(unittest.TestCase): def setUp(self): index = clang.cindex.Index.create() - self.cursor = index.parse(os.path.dirname(__file__) + r"/source_files/using.hpp").cursor + self.cursor = index.parse( + os.path.dirname(__file__) + r"/source_files/using.hpp", + args=("-std=c++20",) + ).cursor self.file = SourceFile.from_cursor(self.cursor) def test_using_as_simple_alias(self): @@ -23,6 +26,7 @@ def test_using_as_simple_alias(self): self.assertEqual(source.type_info.details, self.file.content[0].content[0]) fnc: FunctionInfo = self.file.content[2] self.assertEqual(fnc.arguments[0].type.details, source) + self.assertEqual(source.template, None) def test_using_as_template_alias(self): source: Using = self.file.content[4] @@ -31,3 +35,74 @@ def test_using_as_template_alias(self): self.assertEqual(source.type_info.details, self.file.content[3]) self.assertEqual(len(source.type_info.template_arguments), 1) self.assertEqual(source.type_info.template_arguments[0].details, BasicType.DOUBLE) + self.assertEqual(source.template, None) + + def test_using_with_template(self): + source: Using = self.file.content[8] + self.assertEqual(source.name, "UsingTemplate") + self.assertEqual(source.type_info.is_generic, True) + self.assertEqual(source.type_info.details.name, "A") + self.assertEqual(source.associated_comment, None) + self.assertNotEqual(source.template, None) + self.assertEqual(source.template.requires, None) + + self.assertEqual(len(source.template.parameters), 3) + self.assertEqual(source.template.parameters[0].specifier, "typename") + self.assertEqual(source.template.parameters[0].name, "A") + self.assertEqual(source.template.parameters[0].default_value, None) + self.assertEqual(source.template.parameters[0].is_variadic, False) + self.assertEqual(source.template.parameters[1].specifier, "class") + self.assertEqual(source.template.parameters[1].name, "B") + self.assertEqual(source.template.parameters[1].default_value, "float") + self.assertEqual(source.template.parameters[1].is_variadic, False) + self.assertEqual(source.template.parameters[2].specifier, "typename") + self.assertEqual(source.template.parameters[2].name, "Args") + self.assertEqual(source.template.parameters[2].default_value, None) + self.assertEqual(source.template.parameters[2].is_variadic, True) + + def test_using_with_template_requires(self): + source: Using = self.file.content[9] + self.assertEqual(source.name, "UsingTemplateRequires") + self.assertEqual(source.type_info.is_generic, False) + self.assertEqual(source.type_info.modification.is_const, True) + self.assertEqual(source.type_info.details, self.file.content[6]) + self.assertEqual(len(source.type_info.template_arguments), 2) + self.assertEqual(source.type_info.template_arguments[0].is_generic, True) + self.assertEqual(source.type_info.template_arguments[0].details.name, "C") + self.assertEqual(source.type_info.template_arguments[1].is_generic, True) + self.assertEqual(source.type_info.template_arguments[1].details.name, "T") + self.assertEqual(source.associated_comment, None) + self.assertNotEqual(source.template, None) + + self.assertEqual(len(source.template.requires), 3) + self.assertEqual(source.template.requires[0:2], ["true", "or"]) + self.assertEqual(source.template.requires[2].name, "TestConcept") + self.assertEqual(source.template.requires[2].is_requirement, True) + self.assertEqual(source.template.requires[2].body, "true") + self.assertEqual(len(source.template.parameters), 2) + self.assertEqual(source.template.parameters[0].specifier, "typename") + self.assertEqual(source.template.parameters[0].name, "T") + self.assertEqual(source.template.parameters[0].default_value, None) + self.assertEqual(source.template.parameters[0].is_variadic, False) + + def test_using_with_concept(self): + source: Using = self.file.content[10] + self.assertEqual(source.name, "UsingConcept") + self.assertEqual(source.type_info.is_generic, False) + self.assertEqual(source.type_info.modification.is_const, True) + self.assertEqual(source.type_info.modification.is_pointer, True) + self.assertEqual(source.type_info.details, self.file.content[6]) + self.assertEqual(len(source.type_info.template_arguments), 2) + self.assertEqual(source.type_info.template_arguments[0].is_generic, True) + self.assertEqual(source.type_info.template_arguments[0].details.name, "float") + self.assertEqual(source.associated_comment, None) + self.assertNotEqual(source.template, None) + self.assertEqual(source.template.requires, None) + + self.assertEqual(len(source.template.parameters), 1) + self.assertEqual(source.template.parameters[0].specifier.name, "TestConcept") + self.assertEqual(source.template.parameters[0].specifier.is_requirement, True) + self.assertEqual(source.template.parameters[0].specifier.body, "true") + self.assertEqual(source.template.parameters[0].name, "B") + self.assertEqual(source.template.parameters[0].default_value, "int") + self.assertEqual(source.template.parameters[0].is_variadic, False)