Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions crates/oq3_parser/src/grammar/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,17 +422,39 @@ fn type_name(p: &mut Parser<'_>) {
p.bump(p.current());
}

/// Parse a type specification appearing in
/// the signature of a subroutine definition (`def`) statement.
pub(crate) fn param_type_spec(p: &mut Parser<'_>) -> bool {
if p.at(T![array]) || p.at(T![mutable]) || p.at(T![readonly]) {
let want_array_ref_type = true;
return array_type_spec(p, want_array_ref_type);
}
non_array_type_spec(p)
}

/// Parse a type specification in contexts other than the
/// the signature of a subroutine definition (`def`) statement.
pub(crate) fn type_spec(p: &mut Parser<'_>) -> bool {
if p.at(T![array]) {
return array_type_spec(p);
let want_array_ref_type = false;
return array_type_spec(p, want_array_ref_type);
}
non_array_type_spec(p)
}

// Parse an array type spec
pub(crate) fn array_type_spec(p: &mut Parser<'_>) -> bool {
assert!(p.at(T![array]));
pub(crate) fn array_type_spec(p: &mut Parser<'_>, want_array_ref_type: bool) -> bool {
let m = p.start();
if want_array_ref_type {
if p.at(T![array]) {
p.error("Expecting modifier `mutable` or `immutable`");
} else {
p.eat(T![mutable]);
p.eat(T![readonly]);
}
} else {
assert!(p.at(T![array]));
}
p.bump_any();
p.expect(T!['[']);
if !matches!(
Expand All @@ -445,6 +467,9 @@ pub(crate) fn array_type_spec(p: &mut Parser<'_>) -> bool {
p.expect(COMMA);
// Parse the dimensions.
if p.at(T![dim]) {
if !want_array_ref_type {
p.error("Unexpected dim expression outside of subroutine declaration");
}
let m = p.start();
p.bump_any();
if p.eat(T![=]) {
Expand Down
23 changes: 19 additions & 4 deletions crates/oq3_parser/src/grammar/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub(super) fn arg_list_gate_call_qubits(p: &mut Parser<'_>) {
_param_list_openqasm(p, DefFlavor::GateCallQubits);
}

// Used only in subroutine (`def`) definitions.
pub(super) fn param_list_def_params(p: &mut Parser<'_>) {
// Subroutine definition parameter list: (t0 p0, t1 p1, ...)
// - parens: yes
Expand Down Expand Up @@ -149,9 +150,11 @@ fn _param_list_openqasm(p: &mut Parser<'_>, flavor: DefFlavor) {
let m = p.start();

let inner_array_literal = p.at(T!['{']);

// Allowed starts for an item: either a type or a first-token of a param/expression,
// or first token of array literal.
if !(p.current().is_type() || p.at_ts(PARAM_FIRST) || inner_array_literal) {
if matches!(flavor, DefParams) && (p.at(T![mutable]) || p.at(T![readonly])) {
} else if !(p.current().is_type() || p.at_ts(PARAM_FIRST) || inner_array_literal) {
p.error("expected value parameter");
m.abandon(p);
break;
Expand All @@ -165,9 +168,10 @@ fn _param_list_openqasm(p: &mut Parser<'_>, flavor: DefFlavor) {
true
}
GateCallQubits => arg_gate_call_qubit(p, m),
// These two have different requirements but share this entry point.
DefParams | DefCalParams => param_typed(p, m),
TypeListFlavor => scalar_type(p, m),
// TODO: possibly fix after clarifying what the spec says
DefCalParams => param_typed(p, m),
DefParams => param_typed(p, m),
// Untyped parameters/qubits.
GateParams | GateQubits => param_untyped(p, m),
DefCalQubits => param_untyped_or_hardware_qubit(p, m),
Expand Down Expand Up @@ -287,13 +291,24 @@ fn param_untyped_or_hardware_qubit(p: &mut Parser<'_>, m: Marker) -> bool {
}
}

/// Parse one parameter in the list of parameters in the signature
/// of a subroutine defintion (that is, a `def` statement)
fn param_typed(p: &mut Parser<'_>, m: Marker) -> bool {
expressions::type_spec(p);
expressions::param_type_spec(p);
expressions::var_name(p);
m.complete(p, TYPED_PARAM);
true
}

// TODO: Get clarification on the spec vis a vis defcal and def,
// then revisit this.
// fn scalar_typed(p: &mut Parser<'_>, m: Marker) -> bool {
// expressions::type_spec(p);
// expressions::var_name(p);
// m.complete(p, TYPED_PARAM);
// true
// }

fn scalar_type(p: &mut Parser<'_>, m: Marker) -> bool {
expressions::type_spec(p);
m.complete(p, SCALAR_TYPE);
Expand Down
2 changes: 2 additions & 0 deletions crates/oq3_parser/src/syntax_kind/syntax_kind_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ pub enum SyntaxKind {
LITERAL,
NAME,
PARAM,
PARAM_TYPE,
PARAM_LIST,
PREFIX_EXPR,
QUBIT_LIST,
Expand All @@ -197,6 +198,7 @@ pub enum SyntaxKind {
ALIAS_DECLARATION_STATEMENT,
ARRAY_LITERAL,
ARRAY_TYPE,
ARRAY_REF_TYPE,
ASSIGNMENT_STMT,
CLASSICAL_DECLARATION_STATEMENT,
DESIGNATOR,
Expand Down
16 changes: 15 additions & 1 deletion crates/oq3_semantics/src/syntax_to_semantics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,19 @@ fn block_or_stmt_to_asg_type(val: oq3_syntax::BlockOrStmt, context: &mut Context
}

// Convert AST scalar type to a `types::Type`
fn param_type_to_type(
param_type: &synast::ParamType,
// scalar_type: &synast::ScalarType,
isconst: bool,
context: &mut Context,
) -> Type {
let scalar_type = match param_type {
synast::ParamType::ScalarType(scalar_type) => scalar_type,
synast::ParamType::ArrayRefType(_) => return Type::ToDo,
};
scalar_type_to_type(scalar_type, isconst, context)
}

fn scalar_type_to_type(
scalar_type: &synast::ScalarType,
isconst: bool,
Expand All @@ -1140,6 +1153,7 @@ fn scalar_type_to_type(
// Eg, we write `int[32]`, but we don't write `complex[32]`, but rather `complex[float[32]]`.
// However `Type::Complex` has exactly the same form as other scalar types. In this case
// `width` is understood to be the width of each of real and imaginary components.

let designator = if let Some(float_type) = scalar_type.scalar_type() {
// complex
float_type.designator()
Expand Down Expand Up @@ -1457,7 +1471,7 @@ fn bind_typed_parameter_list(
param_list
.typed_params()
.map(|param| {
let typ = scalar_type_to_type(&param.scalar_type().unwrap(), false, context);
let typ = param_type_to_type(&param.param_type().unwrap(), false, context);
let namestr = param.name().unwrap().string();
context.new_binding(namestr.as_ref(), &typ, &param)
})
Expand Down
2 changes: 2 additions & 0 deletions crates/oq3_syntax/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ mod sourcegen {
"LITERAL",
"NAME",
"PARAM",
"PARAM_TYPE",
"PARAM_LIST",
"PREFIX_EXPR",
"QUBIT_LIST",
Expand All @@ -453,6 +454,7 @@ mod sourcegen {
"ALIAS_DECLARATION_STATEMENT",
"ARRAY_LITERAL",
"ARRAY_TYPE",
"ARRAY_REF_TYPE",
"ASSIGNMENT_STMT",
"CLASSICAL_DECLARATION_STATEMENT",
"DESIGNATOR",
Expand Down
11 changes: 8 additions & 3 deletions crates/oq3_syntax/openqasm3.ungram
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,16 @@ Gate =
ParamList =
'(' (Param (',' Param)* ','?)? ')'

TypedParamList =
'(' (TypedParam (',' TypedParam)* ','?)? ')'
ArrayRefType = ('readonly' | 'mutable') 'array' '[' ScalarType ',' (ExpressionList | DimExpr) ']'

ParamType =
(ScalarType | ArrayRefType)

TypedParam =
ScalarType Name
ParamType Name

TypedParamList =
'(' (TypedParam (',' TypedParam)* ','?)? ')'

// For 'extern'
TypeList =
Expand Down
104 changes: 98 additions & 6 deletions crates/oq3_syntax/src/ast/generated/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -675,14 +675,37 @@ pub struct Param {
impl ast::HasName for Param {}
impl Param {}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TypedParam {
pub struct ArrayRefType {
pub(crate) syntax: SyntaxNode,
}
impl ast::HasName for TypedParam {}
impl TypedParam {
impl ArrayRefType {
pub fn readonly_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![readonly])
}
pub fn mutable_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![mutable])
}
pub fn array_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![array])
}
pub fn l_brack_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T!['['])
}
pub fn scalar_type(&self) -> Option<ScalarType> {
support::child(&self.syntax)
}
pub fn comma_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![,])
}
pub fn expression_list(&self) -> Option<ExpressionList> {
support::child(&self.syntax)
}
pub fn dim_expr(&self) -> Option<DimExpr> {
support::child(&self.syntax)
}
pub fn r_brack_token(&self) -> Option<SyntaxToken> {
support::token(&self.syntax, T![']'])
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ScalarType {
Expand Down Expand Up @@ -730,6 +753,16 @@ impl ScalarType {
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TypedParam {
pub(crate) syntax: SyntaxNode,
}
impl ast::HasName for TypedParam {}
impl TypedParam {
pub fn param_type(&self) -> Option<ParamType> {
support::child(&self.syntax)
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ArrayExpr {
pub(crate) syntax: SyntaxNode,
}
Expand Down Expand Up @@ -1191,6 +1224,11 @@ pub enum GateOperand {
HardwareQubit(HardwareQubit),
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ParamType {
ScalarType(ScalarType),
ArrayRefType(ArrayRefType),
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Modifier {
InvModifier(InvModifier),
PowModifier(PowModifier),
Expand Down Expand Up @@ -1872,9 +1910,9 @@ impl AstNode for Param {
&self.syntax
}
}
impl AstNode for TypedParam {
impl AstNode for ArrayRefType {
fn can_cast(kind: SyntaxKind) -> bool {
kind == TYPED_PARAM
kind == ARRAY_REF_TYPE
}
fn cast(syntax: SyntaxNode) -> Option<Self> {
if Self::can_cast(syntax.kind()) {
Expand Down Expand Up @@ -1902,6 +1940,21 @@ impl AstNode for ScalarType {
&self.syntax
}
}
impl AstNode for TypedParam {
fn can_cast(kind: SyntaxKind) -> bool {
kind == TYPED_PARAM
}
fn cast(syntax: SyntaxNode) -> Option<Self> {
if Self::can_cast(syntax.kind()) {
Some(Self { syntax })
} else {
None
}
}
fn syntax(&self) -> &SyntaxNode {
&self.syntax
}
}
impl AstNode for ArrayExpr {
fn can_cast(kind: SyntaxKind) -> bool {
kind == ARRAY_EXPR
Expand Down Expand Up @@ -2840,6 +2893,35 @@ impl AstNode for GateOperand {
}
}
}
impl From<ScalarType> for ParamType {
fn from(node: ScalarType) -> ParamType {
ParamType::ScalarType(node)
}
}
impl From<ArrayRefType> for ParamType {
fn from(node: ArrayRefType) -> ParamType {
ParamType::ArrayRefType(node)
}
}
impl AstNode for ParamType {
fn can_cast(kind: SyntaxKind) -> bool {
matches!(kind, SCALAR_TYPE | ARRAY_REF_TYPE)
}
fn cast(syntax: SyntaxNode) -> Option<Self> {
let res = match syntax.kind() {
SCALAR_TYPE => ParamType::ScalarType(ScalarType { syntax }),
ARRAY_REF_TYPE => ParamType::ArrayRefType(ArrayRefType { syntax }),
_ => return None,
};
Some(res)
}
fn syntax(&self) -> &SyntaxNode {
match self {
ParamType::ScalarType(it) => &it.syntax,
ParamType::ArrayRefType(it) => &it.syntax,
}
}
}
impl From<InvModifier> for Modifier {
fn from(node: InvModifier) -> Modifier {
Modifier::InvModifier(node)
Expand Down Expand Up @@ -2986,6 +3068,11 @@ impl std::fmt::Display for GateOperand {
std::fmt::Display::fmt(self.syntax(), f)
}
}
impl std::fmt::Display for ParamType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self.syntax(), f)
}
}
impl std::fmt::Display for Modifier {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self.syntax(), f)
Expand Down Expand Up @@ -3216,7 +3303,7 @@ impl std::fmt::Display for Param {
std::fmt::Display::fmt(self.syntax(), f)
}
}
impl std::fmt::Display for TypedParam {
impl std::fmt::Display for ArrayRefType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self.syntax(), f)
}
Expand All @@ -3226,6 +3313,11 @@ impl std::fmt::Display for ScalarType {
std::fmt::Display::fmt(self.syntax(), f)
}
}
impl std::fmt::Display for TypedParam {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self.syntax(), f)
}
}
impl std::fmt::Display for ArrayExpr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self.syntax(), f)
Expand Down
Loading
Loading