Skip to content

Commit 240ee11

Browse files
committed
Allow passing structs as pointers to functions (#63)
1 parent 162be58 commit 240ee11

7 files changed

Lines changed: 140 additions & 37 deletions

File tree

parser/function.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package parser
33
type FunctionDefinition struct {
44
name string
55
returnTypes []ValueType
6-
params []Variable
6+
params []Param
77
body []Statement
88
public bool
99
}
@@ -28,7 +28,7 @@ func (e FunctionDefinition) ReturnTypes() []ValueType {
2828
return e.returnTypes
2929
}
3030

31-
func (e FunctionDefinition) Params() []Variable {
31+
func (e FunctionDefinition) Params() []Param {
3232
return e.params
3333
}
3434

@@ -43,7 +43,7 @@ func (e FunctionDefinition) Public() bool {
4343
type FunctionCall struct {
4444
name string
4545
returnTypes []ValueType
46-
params []Variable
46+
params []Param
4747
arguments []Expression
4848
}
4949

@@ -67,7 +67,7 @@ func (e FunctionCall) ReturnTypes() []ValueType {
6767
return e.returnTypes
6868
}
6969

70-
func (e FunctionCall) Params() []Variable {
70+
func (e FunctionCall) Params() []Param {
7171
return e.params
7272
}
7373

parser/param.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package parser
2+
3+
type Param struct {
4+
Variable
5+
pointer bool
6+
}
7+
8+
func NewParam(name string, valueType ValueType, layer int, public bool, pointer bool) Param {
9+
return Param{
10+
Variable: Variable{
11+
name,
12+
valueType,
13+
layer,
14+
public,
15+
},
16+
pointer: pointer,
17+
}
18+
}
19+
20+
func (p Param) Pointer() bool {
21+
return p.pointer
22+
}

parser/parser.go

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1862,8 +1862,8 @@ func (p *Parser) evaluateVarAssignment(importAlias string, ctx context) (Stateme
18621862
}, nil
18631863
}
18641864

1865-
func (p *Parser) evaluateParams(ctx context) ([]Variable, error) {
1866-
params := []Variable{}
1865+
func (p *Parser) evaluateParams(ctx context) ([]Param, error) {
1866+
params := []Param{}
18671867

18681868
for {
18691869
nameToken := p.peek()
@@ -1884,10 +1884,25 @@ func (p *Parser) evaluateParams(ctx context) ([]Variable, error) {
18841884
if exists {
18851885
return params, fmt.Errorf("scope already contains a variable with the name %s", name)
18861886
}
1887+
pointer := false
1888+
pointerToken := p.peek()
1889+
1890+
if pointerToken.Type() == lexer.BINARY_OPERATOR {
1891+
p.eat()
1892+
pointerValue := pointerToken.Value()
1893+
1894+
if pointerValue != "*" {
1895+
return nil, p.expectedError(fmt.Sprintf(`"*" but got "%s"`, pointerValue), pointerToken)
1896+
}
1897+
pointer = true
1898+
}
1899+
valueTypeToken := p.peek()
18871900
valueType, err := p.evaluateValueType(ctx)
18881901

18891902
if err != nil {
18901903
return nil, err
1904+
} else if pointer && valueType.Type().Kind() != TypeKindStruct {
1905+
return nil, p.atError(fmt.Sprintf("pointers are only supported for structs but got %s", valueType.String()), valueTypeToken)
18911906
}
18921907
nextToken := p.peek()
18931908
nextTokenType := nextToken.Type()
@@ -1897,7 +1912,7 @@ func (p *Parser) evaluateParams(ctx context) ([]Variable, error) {
18971912
} else if nextTokenType == lexer.COMMA {
18981913
p.eat()
18991914
}
1900-
params = append(params, NewVariable(name, valueType, ctx.layer+1, false))
1915+
params = append(params, NewParam(name, valueType, ctx.layer+1, false, pointer))
19011916
}
19021917
return params, nil
19031918
}
@@ -1925,7 +1940,7 @@ func (p *Parser) evaluateFunctionDefinition(ctx context) (Statement, error) {
19251940
return nil, p.expectedError("unique function name", nameToken)
19261941
}
19271942
openingBrace := p.peek()
1928-
params := []Variable{}
1943+
params := []Param{}
19291944

19301945
// Clone context to avoid modification of the original.
19311946
ctx = ctx.clone()
@@ -2704,8 +2719,16 @@ func (p *Parser) evaluateNamedValueEvaluation(importAlias string, ctx context) (
27042719
Const: namedValue.(Const),
27052720
}, nil
27062721
}
2722+
param, isParam := namedValue.(Param)
2723+
var variable Variable
2724+
2725+
if isParam {
2726+
variable = NewVariable(param.Name(), param.ValueType(), param.Layer(), param.Public())
2727+
} else {
2728+
variable = namedValue.(Variable)
2729+
}
27072730
return VariableEvaluation{
2708-
Variable: namedValue.(Variable),
2731+
Variable: variable,
27092732
}, nil
27102733
}
27112734

@@ -3155,7 +3178,7 @@ func (p *Parser) evaluateLogicalOperation(ctx context, operator LogicalOperator,
31553178
return leftExpression, nil
31563179
}
31573180

3158-
func (p *Parser) evaluateArguments(typeName string, name string, params []Variable, ctx context) ([]Expression, error) {
3181+
func (p *Parser) evaluateArguments(typeName string, name string, params []Param, ctx context) ([]Expression, error) {
31593182
var err error
31603183
openingBraceToken := p.eat()
31613184

tests/functions.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,44 @@ func testSliceParamFunctionSuccess(t *testing.T, transpilerFunc transpilerFunc)
8686
})
8787
}
8888

89+
func testStructParamFunctionSuccess(t *testing.T, transpilerFunc transpilerFunc) {
90+
transpilerFunc(t, `
91+
type easyStruct struct {
92+
a string
93+
}
94+
95+
func test(s easyStruct) {
96+
s.a = "Bye"
97+
}
98+
s := easyStruct{a: "Hello"}
99+
print(s.a)
100+
test(s)
101+
print(s.a)
102+
`, func(output string, err error) {
103+
require.Nil(t, err)
104+
require.Equal(t, "Hello\nHello", output)
105+
})
106+
}
107+
108+
func testStructPointerParamFunctionSuccess(t *testing.T, transpilerFunc transpilerFunc) {
109+
transpilerFunc(t, `
110+
type easyStruct struct {
111+
a string
112+
}
113+
114+
func test(s *easyStruct) {
115+
s.a = "Bye"
116+
}
117+
s := easyStruct{a: "Hello"}
118+
print(s.a)
119+
test(s)
120+
print(s.a)
121+
`, func(output string, err error) {
122+
require.Nil(t, err)
123+
require.Equal(t, "Hello\nBye", output)
124+
})
125+
}
126+
89127
func testCallFunctionFromFunctionSuccess(t *testing.T, transpilerFunc transpilerFunc) {
90128
transpilerFunc(t, `
91129
func test1(retVal1 int, retVal2 int) (int, int) {

tests/functions_linux_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ func TestSliceParamFunctionSuccess(t *testing.T) {
2828
testSliceParamFunctionSuccess(t, transpileBash)
2929
}
3030

31+
func TestStructParamFunctionSuccess(t *testing.T) {
32+
testStructParamFunctionSuccess(t, transpileBash)
33+
}
34+
35+
func TestStructPointerParamFunctionSuccess(t *testing.T) {
36+
testStructPointerParamFunctionSuccess(t, transpileBash)
37+
}
38+
3139
func TestCallFunctionFromFunctionSuccess(t *testing.T) {
3240
testCallFunctionFromFunctionSuccess(t, transpileBash)
3341
}

tests/functions_windows_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ func TestSliceParamFunctionSuccess(t *testing.T) {
2828
testSliceParamFunctionSuccess(t, transpileBatch)
2929
}
3030

31+
func TestStructParamFunctionSuccess(t *testing.T) {
32+
testStructParamFunctionSuccess(t, transpileBatch)
33+
}
34+
35+
func TestStructPointerParamFunctionSuccess(t *testing.T) {
36+
testStructPointerParamFunctionSuccess(t, transpileBatch)
37+
}
38+
3139
func TestCallFunctionFromFunctionSuccess(t *testing.T) {
3240
testCallFunctionFromFunctionSuccess(t, transpileBatch)
3341
}

transpiler/transpiler.go

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ func (t *transpiler) evaluateFor(forStatement parser.For) error {
396396
return conv.ForEnd()
397397
}
398398

399-
func (t *transpiler) evaluateExpressionAssignment(assignedExpression parser.Expression) (expressionResult, error) {
399+
func (t *transpiler) evaluateExpressionAssignment(assignedExpression parser.Expression, pointer bool) (expressionResult, error) {
400400
result, err := t.evaluateExpression(assignedExpression, true)
401401
value := result.firstValue()
402402

@@ -408,32 +408,35 @@ func (t *transpiler) evaluateExpressionAssignment(assignedExpression parser.Expr
408408
if !valueType.IsSlice() {
409409
switch evaluationType := valueType.Type().(type) {
410410
case parser.StructDefinition:
411-
newStruct, err := t.converter.StructInitialization([]StructValue{}, true)
411+
// Copy struct if it shall not be passed by reference.
412+
if !pointer {
413+
newStruct, err := t.converter.StructInitialization([]StructValue{}, true)
412414

413-
if err != nil {
414-
return expressionResult{}, err
415-
}
415+
if err != nil {
416+
return expressionResult{}, err
417+
}
416418

417-
// If expression is a struct, the values need to be copied to avoid manipulation of the original.
418-
for _, field := range evaluationType.Fields() {
419-
fieldName := field.Name()
420-
fieldValue, err := t.converter.StructEvaluation(value, fieldName, true)
419+
// If expression is a struct, the values need to be copied to avoid manipulation of the original.
420+
for _, field := range evaluationType.Fields() {
421+
fieldName := field.Name()
422+
fieldValue, err := t.converter.StructEvaluation(value, fieldName, true)
421423

422-
if err != nil {
423-
return expressionResult{}, nil
424+
if err != nil {
425+
return expressionResult{}, nil
426+
}
427+
err = t.converter.StructAssignment(newStruct, fieldName, fieldValue, false)
428+
429+
if err != nil {
430+
return expressionResult{}, err
431+
}
424432
}
425-
err = t.converter.StructAssignment(newStruct, fieldName, fieldValue, false)
433+
evaluatedValue, err := t.converter.VarEvaluation(newStruct, true, false)
426434

427435
if err != nil {
428436
return expressionResult{}, err
429437
}
438+
value = evaluatedValue
430439
}
431-
evaluatedValue, err := t.converter.VarEvaluation(newStruct, true, false)
432-
433-
if err != nil {
434-
return expressionResult{}, err
435-
}
436-
value = evaluatedValue
437440
}
438441
}
439442
return newExpressionResult(value), nil
@@ -462,7 +465,7 @@ func (t *transpiler) evaluateConstDefinition(definition parser.ConstDefinition)
462465

463466
func (t *transpiler) evaluateVarDefinition(definition parser.VariableDefinitionValueAssignment) error {
464467
for i, variable := range definition.Variables() {
465-
result, err := t.evaluateExpressionAssignment(definition.Values()[i])
468+
result, err := t.evaluateExpressionAssignment(definition.Values()[i], false)
466469

467470
if err != nil {
468471
return err
@@ -503,7 +506,7 @@ func (t *transpiler) evaluateVarDefinitionCallAssignment(definition parser.Varia
503506

504507
func (t *transpiler) evaluateVarAssignment(assignment parser.VariableAssignmentValueAssignment) error {
505508
for i, variable := range assignment.Variables() {
506-
result, err := t.evaluateExpressionAssignment(assignment.Values()[i])
509+
result, err := t.evaluateExpressionAssignment(assignment.Values()[i], false)
507510

508511
if err != nil {
509512
return err
@@ -554,7 +557,7 @@ func (t *transpiler) evaluateSliceAssignment(assignment parser.SliceAssignment)
554557
return err
555558
}
556559
assignmentValue := assignment.Assignment()
557-
result, err := t.evaluateExpressionAssignment(assignmentValue)
560+
result, err := t.evaluateExpressionAssignment(assignmentValue, false)
558561

559562
if err != nil {
560563
return err
@@ -575,8 +578,8 @@ func (t *transpiler) evaluateStructAssignment(assignment parser.StructAssignment
575578
return err
576579
}
577580
fieldAssignment := assignment.Assignment()
578-
valueResult, err := t.evaluateExpressionAssignment(fieldAssignment.Value())
579-
581+
valueResult, err := t.evaluateExpressionAssignment(fieldAssignment.Value(), false)
582+
580583
if err != nil {
581584
return err
582585
}
@@ -725,9 +728,10 @@ func (t *transpiler) evaluateFunctionDefinition(functionDefinition parser.Functi
725728
func (t *transpiler) evaluateFunctionCall(functionCall parser.FunctionCall, valueUsed bool) (expressionResult, error) {
726729
name := functionCall.Name()
727730
args := []string{}
731+
params := functionCall.Params()
728732

729-
for _, arg := range functionCall.Args() {
730-
result, err := t.evaluateExpressionAssignment(arg)
733+
for i, arg := range functionCall.Args() {
734+
result, err := t.evaluateExpressionAssignment(arg, params[i].Pointer())
731735

732736
if err != nil {
733737
return expressionResult{}, err
@@ -783,7 +787,7 @@ func (t *transpiler) evaluateSliceInstantiation(instantiation parser.SliceInstan
783787
values := []string{}
784788

785789
for _, expr := range instantiation.Values() {
786-
result, err := t.evaluateExpressionAssignment(expr)
790+
result, err := t.evaluateExpressionAssignment(expr, false)
787791

788792
if err != nil {
789793
return expressionResult{}, err
@@ -802,7 +806,7 @@ func (t *transpiler) evaluateStructInitialization(definition parser.StructInitia
802806
values := []StructValue{}
803807

804808
for _, value := range definition.Values() {
805-
result, err := t.evaluateExpressionAssignment(value.Value())
809+
result, err := t.evaluateExpressionAssignment(value.Value(), false)
806810

807811
if err != nil {
808812
return expressionResult{}, err

0 commit comments

Comments
 (0)