diff --git a/Cargo.lock b/Cargo.lock index 81b26f2..fa49102 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -205,6 +205,14 @@ dependencies = [ "wit-component", ] +[[package]] +name = "example-records" +version = "0.0.2" +dependencies = [ + "wit-bindgen", + "wit-component", +] + [[package]] name = "example-regressions" version = "0.0.2" diff --git a/cmd/gravity/src/codegen/exports.rs b/cmd/gravity/src/codegen/exports.rs index e155ee9..a9f42c6 100644 --- a/cmd/gravity/src/codegen/exports.rs +++ b/cmd/gravity/src/codegen/exports.rs @@ -171,5 +171,207 @@ mod tests { assert!(generated.contains("results1 := raw1[0]")); assert!(generated.contains("result2 := uint32(results1)")); assert!(generated.contains("return result2")); + + // I32FromU32 / U32FromI32 are no-op reinterpretations — they must not + // use api.EncodeU32 or api.DecodeU32 (which round-trip through uint64, + // causing type mismatches in VariantLower and needless widening elsewhere). + assert!( + !generated.contains("api.EncodeU32"), + "Export must not use api.EncodeU32 (returns uint64 but downstream expects uint32), got:\n{generated}" + ); + assert!( + !generated.contains("api.DecodeU32"), + "Export must not use api.DecodeU32 (needless uint32→uint64→uint32 round-trip), got:\n{generated}" + ); + } + + /// Regression test: export function with a variant parameter containing + /// a u32 payload must generate Go code where I32FromU32 produces a + /// uint32 value matching the VariantLower variable declaration. + /// Previously I32FromU32 used api.EncodeU32() which returns uint64, + /// causing a Go compile error: cannot use uint64 as uint32. + #[test] + fn test_export_variant_u32_no_encode_u32() { + use wit_bindgen_core::wit_parser::{ + Case, TypeDef, TypeDefKind, TypeOwner, Variant, + }; + + let mut resolve = Resolve::new(); + + // variant u32-option { some-val(u32), none-val } + let variant_def = TypeDef { + name: Some("u32-option".to_string()), + kind: TypeDefKind::Variant(Variant { + cases: vec![ + Case { + name: "some-val".to_string(), + ty: Some(Type::U32), + docs: Default::default(), + span: Default::default(), + }, + Case { + name: "none-val".to_string(), + ty: None, + docs: Default::default(), + span: Default::default(), + }, + ], + }), + owner: TypeOwner::None, + docs: Default::default(), + stability: Default::default(), + span: Default::default(), + }; + let variant_id = resolve.types.alloc(variant_def); + + let func = Function { + name: "process_u32_option".to_string(), + kind: FunctionKind::Freestanding, + params: vec![Param { + name: "opt".to_string(), + ty: Type::Id(variant_id), + span: Default::default(), + }], + result: Some(Type::U32), + docs: Default::default(), + stability: Default::default(), + span: Default::default(), + }; + + let world = World { + name: "test-world".to_string(), + imports: [].into(), + exports: [( + WorldKey::Name("process-u32-option".to_string()), + WorldItem::Function(func.clone()), + )] + .into(), + docs: Default::default(), + stability: Default::default(), + includes: Default::default(), + span: Default::default(), + package: None, + }; + + let mut sizes = SizeAlign::default(); + sizes.fill(&resolve); + let instance = GoIdentifier::public("TestInstance"); + + let config = ExportConfig { + instance: &instance, + world: &world, + resolve: &resolve, + sizes: &sizes, + }; + + let generator = ExportGenerator::new(config); + let mut tokens = Tokens::new(); + generator.generate_function(&func, &mut tokens); + + let generated = tokens.to_string().unwrap(); + println!("Generated u32-option function:\n{}", generated); + + // VariantLower declares `var variant_1 uint32` for the I32 payload slot. + // I32FromU32 must NOT use api.EncodeU32 (returns uint64 → type mismatch) + assert!( + !generated.contains("api.EncodeU32"), + "I32FromU32 must not use api.EncodeU32 in exports (returns uint64, \ + but VariantLower variable is uint32), got:\n{generated}" + ); + } + + /// Regression test: export function with a variant parameter containing + /// a u64 payload must generate Go code where I64FromU64 produces a + /// uint64 value matching the VariantLower variable declaration. + /// Previously I64FromU64 used int64() which returns int64, causing a + /// Go compile error: cannot use int64 as uint64. + #[test] + fn test_export_variant_u64_no_int64_cast() { + use wit_bindgen_core::wit_parser::{ + Case, TypeDef, TypeDefKind, TypeOwner, Variant, + }; + + let mut resolve = Resolve::new(); + + // variant u64-option { some-val(u64), none-val } + let variant_def = TypeDef { + name: Some("u64-option".to_string()), + kind: TypeDefKind::Variant(Variant { + cases: vec![ + Case { + name: "some-val".to_string(), + ty: Some(Type::U64), + docs: Default::default(), + span: Default::default(), + }, + Case { + name: "none-val".to_string(), + ty: None, + docs: Default::default(), + span: Default::default(), + }, + ], + }), + owner: TypeOwner::None, + docs: Default::default(), + stability: Default::default(), + span: Default::default(), + }; + let variant_id = resolve.types.alloc(variant_def); + + let func = Function { + name: "process_u64_option".to_string(), + kind: FunctionKind::Freestanding, + params: vec![Param { + name: "opt".to_string(), + ty: Type::Id(variant_id), + span: Default::default(), + }], + result: Some(Type::U64), + docs: Default::default(), + stability: Default::default(), + span: Default::default(), + }; + + let world = World { + name: "test-world".to_string(), + imports: [].into(), + exports: [( + WorldKey::Name("process-u64-option".to_string()), + WorldItem::Function(func.clone()), + )] + .into(), + docs: Default::default(), + stability: Default::default(), + includes: Default::default(), + span: Default::default(), + package: None, + }; + + let mut sizes = SizeAlign::default(); + sizes.fill(&resolve); + let instance = GoIdentifier::public("TestInstance"); + + let config = ExportConfig { + instance: &instance, + world: &world, + resolve: &resolve, + sizes: &sizes, + }; + + let generator = ExportGenerator::new(config); + let mut tokens = Tokens::new(); + generator.generate_function(&func, &mut tokens); + + let generated = tokens.to_string().unwrap(); + println!("Generated u64-option function:\n{}", generated); + + // VariantLower declares `var variant_1 uint64` for the I64 payload slot. + // I64FromU64 must NOT use int64() (returns int64 → type mismatch) + assert!( + !generated.contains(":= int64("), + "I64FromU64 must not use int64() cast in exports (returns int64, \ + but VariantLower variable is uint64), got:\n{generated}" + ); } } diff --git a/cmd/gravity/src/codegen/func.rs b/cmd/gravity/src/codegen/func.rs index c7c2e01..7dbdfa8 100644 --- a/cmd/gravity/src/codegen/func.rs +++ b/cmd/gravity/src/codegen/func.rs @@ -310,17 +310,14 @@ impl Bindgen for Func<'_> { } results.push(Operand::SingleValue(value)) } - // I32FromU32 and U32FromI32 are identity operations at the Wasm - // level (both are 32-bit integers). We use a simple uint32 cast - // rather than api.EncodeU32/api.DecodeU32 because those functions - // convert between uint32 and uint64, but operands here are always - // uint32 — whether from host function params (imports), memory - // reads (exports), or Go variables. The uint64 conversion for - // api.Function.Call() is handled separately by CallWasm. Instruction::I32FromU32 => { let tmp = self.tmp(); let result = &format!("result{tmp}"); let operand = &operands[0]; + // I32FromU32 is a no-op reinterpretation (same 32-bit value, + // different signedness). Use uint32() identity cast in both + // directions — api.EncodeU32 returns uint64 which causes type + // mismatches when assigned to uint32 variables (e.g. VariantLower). quote_in! { self.body => $['\r'] $result := uint32($operand) @@ -328,6 +325,10 @@ impl Bindgen for Func<'_> { results.push(Operand::SingleValue(result.into())); } Instruction::U32FromI32 => { + // U32FromI32 is a no-op reinterpretation (same 32-bit value, + // different signedness). Use uint32() identity cast — + // api.DecodeU32(uint64(...)) is a needless round-trip through + // uint64 when the operand is already uint32. let tmp = self.tmp(); let result = &format!("result{tmp}"); let operand = &operands[0]; @@ -1095,16 +1096,172 @@ impl Bindgen for Func<'_> { Instruction::I32Load8S { .. } => todo!("implement instruction: {inst:?}"), Instruction::I32Load16U { .. } => todo!("implement instruction: {inst:?}"), Instruction::I32Load16S { .. } => todo!("implement instruction: {inst:?}"), - Instruction::I64Load { .. } => todo!("implement instruction: {inst:?}"), - Instruction::F32Load { .. } => todo!("implement instruction: {inst:?}"), - Instruction::F64Load { .. } => todo!("implement instruction: {inst:?}"), + Instruction::I64Load { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let ok = &format!("ok{tmp}"); + let default = &format!("default{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $value, $ok := i.module.Memory().ReadUint64Le(uint32($operand + $offset)) + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + if !$ok { + var $default $(typ.as_ref()) + return $default, $ERRORS_NEW("failed to read i64 from memory") + } + } + GoResult::Anon(GoType::Error) => { + if !$ok { + return $ERRORS_NEW("failed to read i64 from memory") + } + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if !$ok { + panic($ERRORS_NEW("failed to read i64 from memory")) + } + } + }) + }; + results.push(Operand::SingleValue(value.into())); + } + Instruction::F32Load { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let ok = &format!("ok{tmp}"); + let default = &format!("default{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $value, $ok := i.module.Memory().ReadUint64Le(uint32($operand + $offset)) + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + if !$ok { + var $default $(typ.as_ref()) + return $default, $ERRORS_NEW("failed to read f32 from memory") + } + } + GoResult::Anon(GoType::Error) => { + if !$ok { + return $ERRORS_NEW("failed to read f32 from memory") + } + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if !$ok { + panic($ERRORS_NEW("failed to read f32 from memory")) + } + } + }) + }; + results.push(Operand::SingleValue(value.into())); + } + Instruction::F64Load { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let ok = &format!("ok{tmp}"); + let default = &format!("default{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $value, $ok := i.module.Memory().ReadUint64Le(uint32($operand + $offset)) + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + if !$ok { + var $default $(typ.as_ref()) + return $default, $ERRORS_NEW("failed to read f64 from memory") + } + } + GoResult::Anon(GoType::Error) => { + if !$ok { + return $ERRORS_NEW("failed to read f64 from memory") + } + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if !$ok { + panic($ERRORS_NEW("failed to read f64 from memory")) + } + } + }) + }; + results.push(Operand::SingleValue(value.into())); + } Instruction::I32Store16 { .. } => todo!("implement instruction: {inst:?}"), Instruction::I64Store { .. } => todo!("implement instruction: {inst:?}"), - Instruction::F32Store { .. } => todo!("implement instruction: {inst:?}"), - Instruction::F64Store { .. } => todo!("implement instruction: {inst:?}"), + Instruction::F32Store { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tag = &operands[0]; + let ptr = &operands[1]; + match &self.direction { + Direction::Export => { + quote_in! { self.body => + $['\r'] + i.module.Memory().WriteUint64Le($ptr+$offset, $tag) + } + } + Direction::Import { .. } => { + quote_in! { self.body => + $['\r'] + mod.Memory().WriteUint64Le($ptr+$offset, $tag) + } + } + } + } + Instruction::F64Store { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tag = &operands[0]; + let ptr = &operands[1]; + match &self.direction { + Direction::Export => { + quote_in! { self.body => + $['\r'] + i.module.Memory().WriteUint64Le($ptr+$offset, $tag) + } + } + Direction::Import { .. } => { + quote_in! { self.body => + $['\r'] + mod.Memory().WriteUint64Le($ptr+$offset, $tag) + } + } + } + } Instruction::I32FromChar => todo!("implement instruction: {inst:?}"), - Instruction::I64FromU64 => todo!("implement instruction: {inst:?}"), - Instruction::I64FromS64 => todo!("implement instruction: {inst:?}"), + Instruction::I64FromU64 => { + // I64FromU64 is a no-op reinterpretation (same 64-bit value, + // different signedness). Use uint64() identity cast — int64() + // returns int64 which causes type mismatches when assigned to + // uint64 variables (e.g. VariantLower). + let tmp = self.tmp(); + let value = format!("value{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $(&value) := uint64($operand) + } + results.push(Operand::SingleValue(value.into())); + } + Instruction::I64FromS64 => { + let tmp = self.tmp(); + let value = format!("value{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $(&value) := $operand + } + results.push(Operand::SingleValue(value.into())); + } Instruction::I32FromS32 => { let tmp = self.tmp(); let value = format!("value{tmp}"); @@ -1204,7 +1361,16 @@ impl Bindgen for Func<'_> { results.push(Operand::SingleValue(result.into())); } Instruction::S64FromI64 => todo!("implement instruction: {inst:?}"), - Instruction::U64FromI64 => todo!("implement instruction: {inst:?}"), + Instruction::U64FromI64 => { + let tmp = self.tmp(); + let value = format!("value{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $(&value) := uint64($operand) + } + results.push(Operand::SingleValue(value.into())); + } Instruction::CharFromI32 => todo!("implement instruction: {inst:?}"), Instruction::F32FromCoreF32 => { let tmp = self.tmp(); diff --git a/cmd/gravity/tests/cmd/records.stderr b/cmd/gravity/tests/cmd/records.stderr new file mode 100644 index 0000000..e69de29 diff --git a/cmd/gravity/tests/cmd/records.stdout b/cmd/gravity/tests/cmd/records.stdout new file mode 100644 index 0000000..7b114ff --- /dev/null +++ b/cmd/gravity/tests/cmd/records.stdout @@ -0,0 +1,465 @@ +// Code generated by arcjet-gravity; DO NOT EDIT. + +package records + +import "context" +import "errors" +import "github.com/tetratelabs/wazero" +import "github.com/tetratelabs/wazero/api" + +import _ "embed" + +//go:embed records.wasm +var wasmFileRecords []byte + +type Foo struct { + Float32 float32 + Float64 float64 + Uint32 uint32 + Uint64 uint64 + S string + Vf32 []float32 + Vf64 []float64 +} + +type RecordsFactory struct { + runtime wazero.Runtime + module wazero.CompiledModule +} + +func NewRecordsFactory( + ctx context.Context, +) (*RecordsFactory, error) { + wazeroRuntime := wazero.NewRuntime(ctx) + + // Compiling the module takes a LONG time, so we want to do it once and hold + // onto it with the Runtime + module, err := wazeroRuntime.CompileModule(ctx, wasmFileRecords) + if err != nil { + return nil, err + } + return &RecordsFactory{ + runtime: wazeroRuntime, + module: module, + }, nil +} + +func (f *RecordsFactory) Instantiate(ctx context.Context) (*RecordsInstance, error) { + if module, err := f.runtime.InstantiateModule(ctx, f.module, wazero.NewModuleConfig()); err != nil { + return nil, err + } else { + return &RecordsInstance{module}, nil + } +} + +func (f *RecordsFactory) Close(ctx context.Context) { + f.runtime.Close(ctx) +} + +type RecordsInstance struct { + module api.Module +} + +func (i *RecordsInstance) Close(ctx context.Context) error { + if err := i.module.Close(ctx); err != nil { + return err + } + + return nil +} + +// writeString will put a Go string into the Wasm memory following the Component +// Model calling conventions, such as allocating memory with the realloc function +func writeString( + ctx context.Context, + s string, + memory api.Memory, + realloc api.Function, +) (uint64, uint64, error) { + if len(s) == 0 { + return 1, 0, nil + } + + results, err := realloc.Call(ctx, 0, 0, 1, uint64(len(s))) + if err != nil { + return 1, 0, err + } + ptr := results[0] + ok := memory.Write(uint32(ptr), []byte(s)) + if !ok { + return 1, 0, errors.New("failed to write string to wasm memory") + } + return uint64(ptr), uint64(len(s)), nil +} + +func (i *RecordsInstance) ModifyFoo( + ctx context.Context, + f Foo, +) Foo { + arg0 := f + float320 := arg0.Float32 + float640 := arg0.Float64 + uint320 := arg0.Uint32 + uint640 := arg0.Uint64 + s0 := arg0.S + vf320 := arg0.Vf32 + vf640 := arg0.Vf64 + result1 := api.EncodeF32(float320) + result2 := api.EncodeF64(float640) + result3 := uint32(uint320) + value4 := uint64(uint640) + memory5 := i.module.Memory() + realloc5 := i.module.ExportedFunction("cabi_realloc") + ptr5, len5, err5 := writeString(ctx, s0, memory5, realloc5) + // The return type doesn't contain an error so we panic if one is encountered + if err5 != nil { + panic(err5) + } + vec7 := vf320 + len7 := uint64(len(vec7)) + result7, err7 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 4, len7 * 4) + // The return type doesn't contain an error so we panic if one is encountered + if err7 != nil { + panic(err7) + } + ptr7 := result7[0] + for idx := uint64(0); idx < len7; idx++ { + e := vec7[idx] + base := uint32(ptr7 + uint64(idx) * uint64(4)) + result6 := api.EncodeF32(e) + i.module.Memory().WriteUint64Le(base+0, result6) + } + vec9 := vf640 + len9 := uint64(len(vec9)) + result9, err9 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 8, len9 * 8) + // The return type doesn't contain an error so we panic if one is encountered + if err9 != nil { + panic(err9) + } + ptr9 := result9[0] + for idx := uint64(0); idx < len9; idx++ { + e := vec9[idx] + base := uint32(ptr9 + uint64(idx) * uint64(8)) + result8 := api.EncodeF64(e) + i.module.Memory().WriteUint64Le(base+0, result8) + } + raw10, err10 := i.module.ExportedFunction("modify-foo").Call(ctx, uint64(result1), uint64(result2), uint64(result3), uint64(value4), uint64(ptr5), uint64(len5), uint64(ptr7), uint64(len7), uint64(ptr9), uint64(len9)) + // The return type doesn't contain an error so we panic if one is encountered + if err10 != nil { + panic(err10) + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if postFn := i.module.ExportedFunction("cabi_post_modify-foo"); postFn != nil { + if _, err := postFn.Call(ctx, raw10...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + } + }() + + results10 := raw10[0] + value11, ok11 := i.module.Memory().ReadUint64Le(uint32(results10 + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok11 { + panic(errors.New("failed to read f32 from memory")) + } + result12 := api.DecodeF32(value11) + value13, ok13 := i.module.Memory().ReadUint64Le(uint32(results10 + 8)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok13 { + panic(errors.New("failed to read f64 from memory")) + } + result14 := api.DecodeF64(value13) + value15, ok15 := i.module.Memory().ReadUint32Le(uint32(results10 + 16)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok15 { + panic(errors.New("failed to read i32 from memory")) + } + result16 := uint32(value15) + value17, ok17 := i.module.Memory().ReadUint64Le(uint32(results10 + 24)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok17 { + panic(errors.New("failed to read i64 from memory")) + } + value18 := uint64(value17) + ptr19, ok19 := i.module.Memory().ReadUint32Le(uint32(results10 + 32)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok19 { + panic(errors.New("failed to read pointer from memory")) + } + len20, ok20 := i.module.Memory().ReadUint32Le(uint32(results10 + 36)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok20 { + panic(errors.New("failed to read length from memory")) + } + buf21, ok21 := i.module.Memory().Read(ptr19, len20) + // The return type doesn't contain an error so we panic if one is encountered + if !ok21 { + panic(errors.New("failed to read bytes from memory")) + } + str21 := string(buf21) + ptr22, ok22 := i.module.Memory().ReadUint32Le(uint32(results10 + 40)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok22 { + panic(errors.New("failed to read pointer from memory")) + } + len23, ok23 := i.module.Memory().ReadUint32Le(uint32(results10 + 44)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok23 { + panic(errors.New("failed to read length from memory")) + } + base26 := ptr22 + len26 := len23 + result26 := make([]float32, len26) + for idx26 := uint32(0); idx26 < len26; idx26++ { + base := base26 + idx26 * 4 + value24, ok24 := i.module.Memory().ReadUint64Le(uint32(base + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok24 { + panic(errors.New("failed to read f32 from memory")) + } + result25 := api.DecodeF32(value24) + result26[idx26] = result25 + } + ptr27, ok27 := i.module.Memory().ReadUint32Le(uint32(results10 + 48)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok27 { + panic(errors.New("failed to read pointer from memory")) + } + len28, ok28 := i.module.Memory().ReadUint32Le(uint32(results10 + 52)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok28 { + panic(errors.New("failed to read length from memory")) + } + base31 := ptr27 + len31 := len28 + result31 := make([]float64, len31) + for idx31 := uint32(0); idx31 < len31; idx31++ { + base := base31 + idx31 * 8 + value29, ok29 := i.module.Memory().ReadUint64Le(uint32(base + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok29 { + panic(errors.New("failed to read f64 from memory")) + } + result30 := api.DecodeF64(value29) + result31[idx31] = result30 + } + value32 := Foo{ + Float32: result12, + Float64: result14, + Uint32: result16, + Uint64: value18, + S: str21, + Vf32: result26, + Vf64: result31, + } + return value32 +} + +func (i *RecordsInstance) ModifyFooFallible( + ctx context.Context, + f Foo, +) (Foo, error) { + arg0 := f + float320 := arg0.Float32 + float640 := arg0.Float64 + uint320 := arg0.Uint32 + uint640 := arg0.Uint64 + s0 := arg0.S + vf320 := arg0.Vf32 + vf640 := arg0.Vf64 + result1 := api.EncodeF32(float320) + result2 := api.EncodeF64(float640) + result3 := uint32(uint320) + value4 := uint64(uint640) + memory5 := i.module.Memory() + realloc5 := i.module.ExportedFunction("cabi_realloc") + ptr5, len5, err5 := writeString(ctx, s0, memory5, realloc5) + if err5 != nil { + var default5 Foo + return default5, err5 + } + vec7 := vf320 + len7 := uint64(len(vec7)) + result7, err7 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 4, len7 * 4) + if err7 != nil { + var default7 Foo + return default7, err7 + } + ptr7 := result7[0] + for idx := uint64(0); idx < len7; idx++ { + e := vec7[idx] + base := uint32(ptr7 + uint64(idx) * uint64(4)) + result6 := api.EncodeF32(e) + i.module.Memory().WriteUint64Le(base+0, result6) + } + vec9 := vf640 + len9 := uint64(len(vec9)) + result9, err9 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 8, len9 * 8) + if err9 != nil { + var default9 Foo + return default9, err9 + } + ptr9 := result9[0] + for idx := uint64(0); idx < len9; idx++ { + e := vec9[idx] + base := uint32(ptr9 + uint64(idx) * uint64(8)) + result8 := api.EncodeF64(e) + i.module.Memory().WriteUint64Le(base+0, result8) + } + raw10, err10 := i.module.ExportedFunction("modify-foo-fallible").Call(ctx, uint64(result1), uint64(result2), uint64(result3), uint64(value4), uint64(ptr5), uint64(len5), uint64(ptr7), uint64(len7), uint64(ptr9), uint64(len9)) + if err10 != nil { + var default10 Foo + return default10, err10 + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if postFn := i.module.ExportedFunction("cabi_post_modify-foo-fallible"); postFn != nil { + if _, err := postFn.Call(ctx, raw10...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + } + }() + + results10 := raw10[0] + value11, ok11 := i.module.Memory().ReadByte(uint32(results10 + 0)) + if !ok11 { + var default11 Foo + return default11, errors.New("failed to read byte from memory") + } + var value37 Foo + var err37 error + switch value11 { + case 0: + value12, ok12 := i.module.Memory().ReadUint64Le(uint32(results10 + 8)) + if !ok12 { + var default12 Foo + return default12, errors.New("failed to read f32 from memory") + } + result13 := api.DecodeF32(value12) + value14, ok14 := i.module.Memory().ReadUint64Le(uint32(results10 + 16)) + if !ok14 { + var default14 Foo + return default14, errors.New("failed to read f64 from memory") + } + result15 := api.DecodeF64(value14) + value16, ok16 := i.module.Memory().ReadUint32Le(uint32(results10 + 24)) + if !ok16 { + var default16 Foo + return default16, errors.New("failed to read i32 from memory") + } + result17 := uint32(value16) + value18, ok18 := i.module.Memory().ReadUint64Le(uint32(results10 + 32)) + if !ok18 { + var default18 Foo + return default18, errors.New("failed to read i64 from memory") + } + value19 := uint64(value18) + ptr20, ok20 := i.module.Memory().ReadUint32Le(uint32(results10 + 40)) + if !ok20 { + var default20 Foo + return default20, errors.New("failed to read pointer from memory") + } + len21, ok21 := i.module.Memory().ReadUint32Le(uint32(results10 + 44)) + if !ok21 { + var default21 Foo + return default21, errors.New("failed to read length from memory") + } + buf22, ok22 := i.module.Memory().Read(ptr20, len21) + if !ok22 { + var default22 Foo + return default22, errors.New("failed to read bytes from memory") + } + str22 := string(buf22) + ptr23, ok23 := i.module.Memory().ReadUint32Le(uint32(results10 + 48)) + if !ok23 { + var default23 Foo + return default23, errors.New("failed to read pointer from memory") + } + len24, ok24 := i.module.Memory().ReadUint32Le(uint32(results10 + 52)) + if !ok24 { + var default24 Foo + return default24, errors.New("failed to read length from memory") + } + base27 := ptr23 + len27 := len24 + result27 := make([]float32, len27) + for idx27 := uint32(0); idx27 < len27; idx27++ { + base := base27 + idx27 * 4 + value25, ok25 := i.module.Memory().ReadUint64Le(uint32(base + 0)) + if !ok25 { + var default25 Foo + return default25, errors.New("failed to read f32 from memory") + } + result26 := api.DecodeF32(value25) + result27[idx27] = result26 + } + ptr28, ok28 := i.module.Memory().ReadUint32Le(uint32(results10 + 56)) + if !ok28 { + var default28 Foo + return default28, errors.New("failed to read pointer from memory") + } + len29, ok29 := i.module.Memory().ReadUint32Le(uint32(results10 + 60)) + if !ok29 { + var default29 Foo + return default29, errors.New("failed to read length from memory") + } + base32 := ptr28 + len32 := len29 + result32 := make([]float64, len32) + for idx32 := uint32(0); idx32 < len32; idx32++ { + base := base32 + idx32 * 8 + value30, ok30 := i.module.Memory().ReadUint64Le(uint32(base + 0)) + if !ok30 { + var default30 Foo + return default30, errors.New("failed to read f64 from memory") + } + result31 := api.DecodeF64(value30) + result32[idx32] = result31 + } + value33 := Foo{ + Float32: result13, + Float64: result15, + Uint32: result17, + Uint64: value19, + S: str22, + Vf32: result27, + Vf64: result32, + } + value37 = value33 + case 1: + ptr34, ok34 := i.module.Memory().ReadUint32Le(uint32(results10 + 8)) + if !ok34 { + var default34 Foo + return default34, errors.New("failed to read pointer from memory") + } + len35, ok35 := i.module.Memory().ReadUint32Le(uint32(results10 + 12)) + if !ok35 { + var default35 Foo + return default35, errors.New("failed to read length from memory") + } + buf36, ok36 := i.module.Memory().Read(ptr34, len35) + if !ok36 { + var default36 Foo + return default36, errors.New("failed to read bytes from memory") + } + str36 := string(buf36) + err37 = errors.New(str36) + default: + err37 = errors.New("invalid variant discriminant for expected") + } + return value37, err37 +} + diff --git a/cmd/gravity/tests/cmd/records.toml b/cmd/gravity/tests/cmd/records.toml new file mode 100644 index 0000000..4cba21d --- /dev/null +++ b/cmd/gravity/tests/cmd/records.toml @@ -0,0 +1,2 @@ +bin.name = "gravity" +args = "--world records ../../target/wasm32-unknown-unknown/release/example_records.wasm" diff --git a/examples/generate.go b/examples/generate.go index 6f539b0..302d0f3 100644 --- a/examples/generate.go +++ b/examples/generate.go @@ -1,11 +1,13 @@ package examples //go:generate cargo build -p example-basic --target wasm32-unknown-unknown --release +//go:generate cargo build -p example-records --target wasm32-unknown-unknown --release //go:generate cargo build -p example-iface-method-returns-string --target wasm32-unknown-unknown --release //go:generate cargo build -p example-instructions --target wasm32-unknown-unknown --release //go:generate cargo build -p example-regressions --target wasm32-unknown-unknown --release //go:generate cargo run --bin gravity -- --world basic --output ./basic/basic.go ../target/wasm32-unknown-unknown/release/example_basic.wasm +//go:generate cargo run --bin gravity -- --world records --output ./records/records.go ../target/wasm32-unknown-unknown/release/example_records.wasm //go:generate cargo run --bin gravity -- --world example --output ./iface-method-returns-string/example.go ../target/wasm32-unknown-unknown/release/example_iface_method_returns_string.wasm //go:generate cargo run --bin gravity -- --world instructions --output ./instructions/bindings.go ../target/wasm32-unknown-unknown/release/example_instructions.wasm //go:generate cargo run --bin gravity -- --world regressions --output ./regressions/regressions.go ../target/wasm32-unknown-unknown/release/example_regressions.wasm diff --git a/examples/records/Cargo.toml b/examples/records/Cargo.toml new file mode 100644 index 0000000..a89b627 --- /dev/null +++ b/examples/records/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "example-records" +version = "0.0.2" +edition = "2024" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +wit-bindgen = "=0.53.1" +wit-component = "=0.245.1" diff --git a/examples/records/records_test.go b/examples/records/records_test.go new file mode 100644 index 0000000..f1e0201 --- /dev/null +++ b/examples/records/records_test.go @@ -0,0 +1,135 @@ +package records + +import ( + "math" + "testing" +) + +func TestRecord(t *testing.T) { + fac, err := NewRecordsFactory(t.Context()) + if err != nil { + t.Fatal(err) + } + defer fac.Close(t.Context()) + + ins, err := fac.Instantiate(t.Context()) + if err != nil { + t.Fatal(err) + } + defer ins.Close(t.Context()) + + foo := Foo{ + Float32: 1.0, + Float64: 1.0, + Uint32: 1, + Uint64: uint64(math.MaxUint32), + S: "hello", + Vf32: []float32{1.0, 2.0, 3.0}, + Vf64: []float64{1.0, 2.0, 3.0}, + } + got := ins.ModifyFoo(t.Context(), foo) + want := Foo{ + Float32: foo.Float32 * 2.0, + Float64: foo.Float64 * 2.0, + Uint32: foo.Uint32 + 1, + Uint64: foo.Uint64 + 1, + S: "received hello", + Vf32: []float32{2.0, 4.0, 6.0}, + Vf64: []float64{2.0, 4.0, 6.0}, + } + if !fooCmp(got, want) { + t.Fatalf("got %+v, want %+v", got, want) + } +} + +func TestRecordFallibleSuccess(t *testing.T) { + fac, err := NewRecordsFactory(t.Context()) + if err != nil { + t.Fatal(err) + } + defer fac.Close(t.Context()) + + ins, err := fac.Instantiate(t.Context()) + if err != nil { + t.Fatal(err) + } + defer ins.Close(t.Context()) + + foo := Foo{ + Float32: 1.0, + Float64: 5.0, // <= 10.0, should succeed + Uint32: 1, + Uint64: uint64(math.MaxUint32), + S: "hello", + Vf32: []float32{1.0, 2.0, 3.0}, + Vf64: []float64{1.0, 2.0, 3.0}, + } + got, err := ins.ModifyFooFallible(t.Context(), foo) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := Foo{ + Float32: foo.Float32 * 2.0, + Float64: foo.Float64 * 2.0, + Uint32: foo.Uint32 + 1, + Uint64: foo.Uint64 + 1, + S: "received hello", + Vf32: []float32{2.0, 4.0, 6.0}, + Vf64: []float64{2.0, 4.0, 6.0}, + } + if !fooCmp(got, want) { + t.Fatalf("got %+v, want %+v", got, want) + } +} + +func TestRecordFallibleError(t *testing.T) { + fac, err := NewRecordsFactory(t.Context()) + if err != nil { + t.Fatal(err) + } + defer fac.Close(t.Context()) + + ins, err := fac.Instantiate(t.Context()) + if err != nil { + t.Fatal(err) + } + defer ins.Close(t.Context()) + + foo := Foo{ + Float32: 1.0, + Float64: 15.0, // > 10.0, should error + Uint32: 1, + Uint64: uint64(math.MaxUint32), + S: "hello", + Vf32: []float32{1.0, 2.0, 3.0}, + Vf64: []float64{1.0, 2.0, 3.0}, + } + _, err = ins.ModifyFooFallible(t.Context(), foo) + if err == nil { + t.Fatal("expected error, got nil") + } + wantErr := "float64 too big" + if err.Error() != wantErr { + t.Fatalf("got error %q, want %q", err.Error(), wantErr) + } +} + +func fooCmp(a, b Foo) bool { + if a.Float32 != b.Float32 || a.Float64 != b.Float64 || a.Uint32 != b.Uint32 || a.Uint64 != b.Uint64 || a.S != b.S { + return false + } + if len(a.Vf32) != len(b.Vf32) || len(a.Vf64) != len(b.Vf64) { + return false + } + for i := range a.Vf32 { + if a.Vf32[i] != b.Vf32[i] { + return false + } + } + for i := range a.Vf64 { + if a.Vf64[i] != b.Vf64[i] { + return false + } + } + return true +} diff --git a/examples/records/src/lib.rs b/examples/records/src/lib.rs new file mode 100644 index 0000000..4315259 --- /dev/null +++ b/examples/records/src/lib.rs @@ -0,0 +1,57 @@ +wit_bindgen::generate!({ + world: "records", +}); + +struct RecordsWorld; + +export!(RecordsWorld); + +impl Guest for RecordsWorld { + fn modify_foo( + Foo { + float64, + float32, + uint32, + uint64, + s, + vf32, + vf64, + }: Foo, + ) -> Foo { + Foo { + float64: float64 * 2.0, + float32: float32 * 2.0, + uint32: uint32 + 1, + uint64: uint64 + 1, + s: format!("received {s}"), + vf32: vf32.iter().map(|v| v * 2.0).collect(), + vf64: vf64.iter().map(|v| v * 2.0).collect(), + } + } + + fn modify_foo_fallible( + Foo { + float64, + float32, + uint32, + uint64, + s, + vf32, + vf64, + }: Foo, + ) -> Result { + if float64 > 10.0 { + Err("float64 too big".to_string()) + } else { + Ok(Foo { + float64: float64 * 2.0, + float32: float32 * 2.0, + uint32: uint32 + 1, + uint64: uint64 + 1, + s: format!("received {s}"), + vf32: vf32.iter().map(|v| v * 2.0).collect(), + vf64: vf64.iter().map(|v| v * 2.0).collect(), + }) + } + } +} diff --git a/examples/records/wit/records.wit b/examples/records/wit/records.wit new file mode 100644 index 0000000..015bfd6 --- /dev/null +++ b/examples/records/wit/records.wit @@ -0,0 +1,16 @@ +package arcjet:records; + +world records { + record foo { + float32: f32, + float64: f64, + uint32: u32, + uint64: u64, + s: string, + vf32: list, + vf64: list, + } + + export modify-foo: func(f: foo) -> foo; + export modify-foo-fallible: func(f: foo) -> result; +} diff --git a/examples/regressions/regressions_test.go b/examples/regressions/regressions_test.go index 6bdb883..830c2d7 100644 --- a/examples/regressions/regressions_test.go +++ b/examples/regressions/regressions_test.go @@ -158,3 +158,8 @@ func TestRunPing(t *testing.T) { t.Errorf("RunPing() = %v, want true", got) } } + +// TODO: When gravity supports generating Go variant type definitions, add E2E +// tests for export functions that accept variant parameters (e.g. a variant +// with a u32 or u64 payload). These would exercise the VariantLower codepath +// end-to-end through wazero. diff --git a/examples/regressions/wit/regressions.wit b/examples/regressions/wit/regressions.wit index fac5fc1..0401d24 100644 --- a/examples/regressions/wit/regressions.wit +++ b/examples/regressions/wit/regressions.wit @@ -38,4 +38,8 @@ world regressions { export check-status: func(key: string) -> u32; export double-value: func(value: u32) -> u32; export run-ping: func() -> bool; + + // TODO: When variant type definition generation is supported, add variant + // exports here (e.g. variant with u32/u64 payloads) and corresponding E2E + // tests in regressions_test.go. }