diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 29b7923b7..83a7aef29 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -190,6 +190,21 @@ jobs: - name: Build IntelliJ plugin run: ./gradlew :src:ide:intellij-plugin:buildPlugin + verify: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up JDK + uses: actions/setup-java@v4 + with: + java-version: '21' + distribution: 'temurin' + cache: gradle + - name: Verify + run: ./gradlew :src:verify:test -Pverify + example-build: runs-on: ubuntu-latest @@ -245,6 +260,10 @@ jobs: type: gradle - name: npm-typescript type: npm + - name: rust-petstore + type: rust + - name: scala-zio + type: scala steps: - uses: actions/checkout@v4 @@ -254,6 +273,7 @@ jobs: with: java-version: '21' distribution: 'temurin' + cache: gradle - name: Set up GraalVM if: matrix.type == 'native' uses: graalvm/setup-graalvm@v1 @@ -264,8 +284,14 @@ jobs: if: matrix.type == 'npm' with: node-version: 20 + - name: Set up Rust + if: matrix.type == 'rust' + uses: dtolnay/rust-toolchain@stable + - name: Set up sbt + if: matrix.type == 'scala' + uses: sbt/setup-sbt@v1 - name: Download Maven Local artifacts - if: matrix.type != 'npm' + if: matrix.type != 'npm' && matrix.type != 'rust' && matrix.type != 'scala' uses: actions/download-artifact@v4 with: name: wirespec-m2 @@ -291,6 +317,14 @@ jobs: if: matrix.type == 'npm' working-directory: examples/${{ matrix.name }} run: npm ci && npm run build + - name: Run ${{ matrix.name }} example + if: matrix.type == 'rust' + working-directory: examples/${{ matrix.name }} + run: bash gen.sh && cargo build && cargo test + - name: Run ${{ matrix.name }} example + if: matrix.type == 'scala' + working-directory: examples/${{ matrix.name }} + run: sbt compile test success: @@ -309,6 +343,7 @@ jobs: - vscode - intellij-plugin - example-test + - verify steps: - name: Check CI status diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 000000000..f526b56f5 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,237 @@ +# Wirespec - Claude Code Instructions + +## Generated Code Requirements + +All code generated by Wirespec emitters must be **dependency-free**. Generated output must not rely on any external libraries or third-party packages — it must be fully self-contained and runnable with only the target language's standard library. + +## IR Emitter Pipeline + +The Wirespec compiler uses a four-stage IR pipeline to generate idiomatic code for Java, Kotlin, TypeScript, Python, and Rust: + +``` +Wirespec Source + │ + ▼ +Parser ──► AST (Root / Module / Definition) + │ + ▼ +IrConverter ──► IR File (language-neutral tree) + │ + ▼ +Language Emitter ──► Transformed IR File (via Transform DSL) + │ + ▼ +CodeGenerator ──► String (target-language source code) +``` + +### Stage 1: Parser AST + +The parser produces a tree of `Definition` nodes grouped into `Module`s inside a `Root`: + +``` +Root ─► Module[] ─► Definition[] +``` + +**Definition types** (sealed hierarchy): +- `Type` — record/struct with a `Shape` (list of `Field`s) and optional `extends` +- `Enum` — set of string entries +- `Union` — set of `Reference` entries +- `Refined` — primitive wrapper with a regex/bound constraint +- `Endpoint` — HTTP endpoint (method, path, queries, headers, requests, responses) +- `Channel` — async messaging channel with a reference type + +**Reference type system** (`sealed interface Reference`): +`Any`, `Unit`, `Custom(name)`, `Primitive(type)`, `Iterable(reference)`, `Dict(reference)`. Each carries `isNullable`. `Primitive.Type` variants: `String`, `Integer`, `Number`, `Boolean`, `Bytes` (with optional precision and constraints). + +Key files: +- `src/compiler/core/.../parse/ast/Definition.kt` +- `src/compiler/core/.../parse/ast/Reference.kt` +- `src/compiler/core/.../parse/ast/Root.kt` + +### Stage 2: Convert (Parser AST → IR) + +`IrConverter.kt` maps each parser `Definition` to an IR `File` tree. The entry point dispatches by definition type: + +```kotlin +fun DefinitionWirespec.convert(): File = when (this) { + is TypeWirespec -> convert() + is EnumWirespec -> convert() + is UnionWirespec -> convert() + is RefinedWirespec -> convert() + is ChannelWirespec -> convert() + is EndpointWirespec -> convert() +} +``` + +Each per-definition converter produces a complete IR `File` with the appropriate structs, interfaces, functions, and validation logic. `EndpointWirespec.convert()` is the most complex — it generates Path/Queries/RequestHeaders/Request structs, a Response union hierarchy, serialization functions, and a Handler interface. + +Reference conversion maps parser references to IR types: `Custom → Type.Custom`, `Iterable → Type.Array`, `Dict → Type.Dict`, `Primitive → Type.String/Integer/Number/Boolean/Bytes`, wrapping with `Type.Nullable` when `isNullable`. + +Key file: `src/compiler/ir/.../converter/IrConverter.kt` + +### Stage 3: IR AST + +The IR is a language-neutral tree with the following node types: + +**Elements** (AST nodes): +- `File(name, elements)` — top-level container +- `Struct(name, fields, constructors, interfaces, elements)` — record/class +- `Interface(name, elements, extends, isSealed, typeParameters, fields)` — interface/protocol +- `Namespace(name, elements, extends)` — grouping container +- `Union(name, members, typeParameters)` — tagged union +- `Enum(name, entries, fields, constructors, elements)` — enumeration (entries have name + values) +- `Function(name, typeParameters, parameters, returnType, body, isAsync, isStatic, isOverride)` +- `Package(path)`, `Import(path, type)`, `RawElement(code)` + +**Type system** (`sealed interface Type`): +`Integer(precision)`, `Number(precision)`, `String`, `Boolean`, `Bytes`, `Unit`, `Any`, `Wildcard`, `Reflect`, `Array(elementType)`, `Dict(keyType, valueType)`, `Custom(name, generics)`, `Nullable(type)` + +**Statement / Expression hierarchy**: `RawExpression`, `VariableReference`, `FieldCall`, `FunctionCall`, `BinaryOp`, `ConstructorStatement`, `Literal`, `Switch/Case`, `IfExpression`, `StringTemplate`, `MapExpression`, `ReturnStatement`, `Assignment`, constraints (`RegexMatch`, `BoundCheck`), null-handling (`NullCheck`, `NullableMap`, `NullableOf`), and more. + +Key file: `src/compiler/ir/.../core/Ast.kt` + +### Stage 4: Transform (detailed) + +The transform layer is the heart of language-specific adaptation. It lets each emitter reshape the language-neutral IR into a form that generates idiomatic target code. + +Key file: `src/compiler/ir/.../core/Transform.kt` + +#### Transformer interface + +Eight override points, each defaulting to recursive `transformChildren()`: + +```kotlin +interface Transformer { + fun transformType(type: Type): Type + fun transformElement(element: Element): Element + fun transformStatement(statement: Statement): Statement + fun transformExpression(expression: Expression): Expression + fun transformField(field: Field): Field + fun transformParameter(parameter: Parameter): Parameter + fun transformConstructor(constructor: Constructor): Constructor + fun transformCase(case: Case): Case +} +``` + +#### `transformer()` DSL factory + +Creates a `Transformer` using a `TransformerBuilder` with a builder-pattern DSL. Only the overrides you specify are applied; all others default to recursive `transformChildren()`: + +```kotlin +inline fun transformer(block: TransformerBuilder.() -> Unit): Transformer + +// Usage: +transformer { + type { type, transformer -> /* ... */ } + element { element, transformer -> /* ... */ } + statement { statement, transformer -> /* ... */ } + expression { expression, transformer -> /* ... */ } + field { field, transformer -> /* ... */ } + parameter { parameter, transformer -> /* ... */ } + constructor { constructor, transformer -> /* ... */ } + case { case, transformer -> /* ... */ } +} +``` + +#### `transformChildren()` recursive traversal + +Each node type has a `transformChildren(Transformer)` extension that walks into its children. For example, a `Struct` transforms its fields, constructors, and child elements; a `FunctionCall` transforms its receiver, type arguments, and argument expressions. This ensures transforms propagate through the entire tree. + +Apply a transformer to any element: `fun T.transform(transformer: Transformer): T` + +#### `TransformScope` — block-based transform API + +The primary API used by language emitters. Call `element.transform { ... }` to open a scope and chain multiple transforms: + +```kotlin +inline fun E.transform(block: TransformScope.() -> Unit): E +``` + +`TransformScope` methods: + +| Method | Purpose | +|---|---| +| `matching { transform }` | Transform all types matching a Kotlin class | +| `matchingElements { transform }` | Transform all elements matching a Kotlin class | +| `fieldsWhere(predicate, transform)` | Transform fields matching a predicate | +| `parametersWhere(predicate, transform)` | Transform parameters matching a predicate | +| `renameType(oldName, newName)` | Rename a `Type.Custom` throughout the tree | +| `renameField(oldName, newName)` | Rename a field throughout the tree | +| `typeByName(name, transform)` | Transform types matching a custom name | +| `injectBefore { produce }` | Insert elements before a matching container | +| `injectAfter { produce }` | Insert elements after a matching container | +| `apply(transformer)` | Apply a pre-built `Transformer` | +| `type { type, transformer -> ... }` | Shorthand: create + apply a type transformer | +| `statement { stmt, transformer -> ... }` | Shorthand: create + apply a statement transformer | +| `expression { expr, transformer -> ... }` | Shorthand: create + apply an expression transformer | +| `field { field, transformer -> ... }` | Shorthand: create + apply a field transformer | +| `parameter { param, transformer -> ... }` | Shorthand: create + apply a parameter transformer | +| `constructor { ctor, transformer -> ... }` | Shorthand: create + apply a constructor transformer | +| `case { case, transformer -> ... }` | Shorthand: create + apply a case transformer | + +#### Low-level helper functions + +These `internal` extension functions on `Element` power the `TransformScope` methods above: + +| Function | Purpose | +|---|---| +| `transformMatching` | Transform all types matching a Kotlin class | +| `transformMatchingElements` | Transform all elements matching a Kotlin class | +| `transformFieldsWhere(predicate, transform)` | Transform fields matching a predicate | +| `transformParametersWhere(predicate, transform)` | Transform parameters matching a predicate | +| `renameType(oldName, newName)` | Rename a `Type.Custom` throughout the tree | +| `renameField(oldName, newName)` | Rename a field throughout the tree | +| `transformTypeByName(name, transform)` | Transform types matching a custom name | +| `injectBefore(produce)` | Insert elements before a matching container | +| `injectAfter(produce)` | Insert elements after a matching container | + +#### Read-only traversal utilities + +Standalone functions (not an interface) for walking the IR tree without modifying it: + +| Function | Purpose | +|---|---| +| `forEachType(action)` | Visit every `Type` node in the tree | +| `forEachElement(action)` | Visit every `Element` node in the tree | +| `forEachField(action)` | Visit every `Field` node in the tree | +| `collectTypes()` | Collect all `Type` nodes into a list | +| `collectCustomTypeNames()` | Collect all `Type.Custom` names into a set | +| `findAll()` | Find all elements of a specific type | +| `findAllTypes()` | Find all types of a specific type | +| `findElement()` | Find the first child element of a specific type | + +### Stage 5: Generate + +Each language emitter implements `File.generate()` by delegating to a `CodeGenerator` singleton: + +```kotlin +interface CodeGenerator { + fun generate(element: Element): String +} +``` + +Generators: `JavaGenerator`, `KotlinGenerator`, `TypeScriptGenerator`, `PythonGenerator`, `RustGenerator`. Each recursively walks the IR tree and emits the corresponding target-language syntax as a string. + +Top-level entry: `fun Element.generateJava()`, `fun Element.generateKotlin()`, etc. + +Key files: +- `src/compiler/ir/.../emit/IrEmitter.kt` +- `src/compiler/ir/.../generator/CodeGenerator.kt` +- `src/compiler/ir/.../generator/{Java,Kotlin,TypeScript,Python,Rust}Generator.kt` + +### Key File Reference + +| File | Purpose | +|---|---| +| `src/compiler/core/.../parse/ast/Definition.kt` | Parser AST definition types | +| `src/compiler/core/.../parse/ast/Reference.kt` | Parser AST reference/type system | +| `src/compiler/ir/.../converter/IrConverter.kt` | Parser AST → IR conversion | +| `src/compiler/ir/.../core/Ast.kt` | IR node types (Element, Type, Statement, Expression) | +| `src/compiler/ir/.../core/Transform.kt` | Transform DSL + TransformScope + traversal utilities | +| `src/compiler/ir/.../emit/IrEmitter.kt` | Emitter interface and orchestration | +| `src/compiler/ir/.../generator/CodeGenerator.kt` | Generator interface + top-level functions | +| `src/compiler/emitters/java/.../JavaIrEmitter.kt` | Java-specific transforms + emit | +| `src/compiler/emitters/kotlin/.../KotlinIrEmitter.kt` | Kotlin-specific transforms + emit | +| `src/compiler/emitters/typescript/.../TypeScriptIrEmitter.kt` | TypeScript-specific transforms + emit | +| `src/compiler/emitters/python/.../PythonIrEmitter.kt` | Python-specific transforms + emit | +| `src/compiler/emitters/rust/.../RustIrEmitter.kt` | Rust-specific transforms + emit | diff --git a/Makefile b/Makefile index fdd1dc49f..6be2e05a4 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ .PHONY: * # The first command will be invoked with `make` only and should be `all` -all: build image test example format +all: build image test example format verify build: build-wirespec build-site @@ -42,7 +42,7 @@ update: npm install -g @vscode/vsce verify: - $(shell pwd)/scripts/verify.sh + ./gradlew :src:verify:test -Pverify yolo: $(shell pwd)/scripts/yolo.sh diff --git a/examples/Makefile b/examples/Makefile index 1b344e3ee..4e3e53f4c 100644 --- a/examples/Makefile +++ b/examples/Makefile @@ -9,7 +9,9 @@ build: (cd maven-spring-integration && ./mvnw verify) && \ (cd maven-spring-boot-4-integration && ./mvnw verify) && \ (cd gradle-ktor && ./gradlew check) && \ - (cd npm-typescript && npm ci && npm run build) + (cd npm-typescript && npm ci && npm run build) && \ + (cd rust-petstore && bash gen.sh && cargo build) && \ + (cd scala-zio && sbt compile) clean: (cd maven-preprocessor && ./mvnw clean) && \ diff --git a/examples/README.md b/examples/README.md deleted file mode 100644 index c86d5f46f..000000000 --- a/examples/README.md +++ /dev/null @@ -1,24 +0,0 @@ -# Examples - -Here you can find examples of how to use: - -* [The Gradle Plugin](gradle-ktor/README.md) -* [The Maven Plugin](maven-spring-compile/README.md) -* [And convert an OpenAPI Specification](maven-spring-convert/README.md) -* [A custom Emitter](maven-spring-custom/README.md) -* [The Spring integration](../../src/integration/spring/README.md) - -## Integration - -Some notes on how Wirespec integrates with different libraries and frameworks - -### Jackson (json object mapper) - -For some languages Wirespec is sanitizing enums names because of usage of preserved keywords and forbidden characters. -This results into problems with serialization. In Jackson the following configuration can be used to fix this. - -```kotlin -ObjectMapper() - .enable(DeserializationFeature.READ_ENUMS_USING_TO_STRING) - .enable(SerializationFeature.WRITE_ENUMS_USING_TO_STRING) -``` diff --git a/examples/rust-petstore/.gitignore b/examples/rust-petstore/.gitignore new file mode 100644 index 000000000..19c77ae22 --- /dev/null +++ b/examples/rust-petstore/.gitignore @@ -0,0 +1,2 @@ +/target +/src/gen diff --git a/examples/rust-petstore/Cargo.lock b/examples/rust-petstore/Cargo.lock new file mode 100644 index 000000000..41dff8f2d --- /dev/null +++ b/examples/rust-petstore/Cargo.lock @@ -0,0 +1,2490 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "actix-codec" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f7b0a21988c1bf877cf4759ef5ddaac04c1c9fe808c9142ecb78ba97d97a28a" +dependencies = [ + "bitflags", + "bytes", + "futures-core", + "futures-sink", + "memchr", + "pin-project-lite", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "actix-http" +version = "3.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f860ee6746d0c5b682147b2f7f8ef036d4f92fe518251a3a35ffa3650eafdf0e" +dependencies = [ + "actix-codec", + "actix-rt", + "actix-service", + "actix-utils", + "base64", + "bitflags", + "brotli", + "bytes", + "bytestring", + "derive_more", + "encoding_rs", + "flate2", + "foldhash", + "futures-core", + "h2 0.3.27", + "http 0.2.12", + "httparse", + "httpdate", + "itoa", + "language-tags", + "local-channel", + "mime", + "percent-encoding", + "pin-project-lite", + "rand", + "sha1", + "smallvec", + "tokio", + "tokio-util", + "tracing", + "zstd", +] + +[[package]] +name = "actix-macros" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01ed3140b2f8d422c68afa1ed2e85d996ea619c988ac834d255db32138655cb" +dependencies = [ + "quote", + "syn", +] + +[[package]] +name = "actix-router" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14f8c75c51892f18d9c46150c5ac7beb81c95f78c8b83a634d49f4ca32551fe7" +dependencies = [ + "bytestring", + "cfg-if", + "http 0.2.12", + "regex", + "regex-lite", + "serde", + "tracing", +] + +[[package]] +name = "actix-rt" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92589714878ca59a7626ea19734f0e07a6a875197eec751bb5d3f99e64998c63" +dependencies = [ + "futures-core", + "tokio", +] + +[[package]] +name = "actix-server" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a65064ea4a457eaf07f2fba30b4c695bf43b721790e9530d26cb6f9019ff7502" +dependencies = [ + "actix-rt", + "actix-service", + "actix-utils", + "futures-core", + "futures-util", + "mio", + "socket2 0.5.10", + "tokio", + "tracing", +] + +[[package]] +name = "actix-service" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e46f36bf0e5af44bdc4bdb36fbbd421aa98c79a9bce724e1edeb3894e10dc7f" +dependencies = [ + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "actix-utils" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88a1dcdff1466e3c2488e1cb5c36a71822750ad43839937f85d2f4d9f8b705d8" +dependencies = [ + "local-waker", + "pin-project-lite", +] + +[[package]] +name = "actix-web" +version = "4.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff87453bc3b56e9b2b23c1cc0b1be8797184accf51d2abe0f8a33ec275d316bf" +dependencies = [ + "actix-codec", + "actix-http", + "actix-macros", + "actix-router", + "actix-rt", + "actix-server", + "actix-service", + "actix-utils", + "actix-web-codegen", + "bytes", + "bytestring", + "cfg-if", + "cookie", + "derive_more", + "encoding_rs", + "foldhash", + "futures-core", + "futures-util", + "impl-more", + "itoa", + "language-tags", + "log", + "mime", + "once_cell", + "pin-project-lite", + "regex", + "regex-lite", + "serde", + "serde_json", + "serde_urlencoded", + "smallvec", + "socket2 0.6.2", + "time", + "tracing", + "url", +] + +[[package]] +name = "actix-web-codegen" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f591380e2e68490b5dfaf1dd1aa0ebe78d84ba7067078512b4ea6e4492d622b8" +dependencies = [ + "actix-router", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "alloc-no-stdlib" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" + +[[package]] +name = "alloc-stdlib" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" +dependencies = [ + "alloc-no-stdlib", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "brotli" +version = "8.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bd8b9603c7aa97359dbd97ecf258968c95f3adddd6db2f7e7a5bef101c84560" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", + "brotli-decompressor", +] + +[[package]] +name = "brotli-decompressor" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03" +dependencies = [ + "alloc-no-stdlib", + "alloc-stdlib", +] + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "bytestring" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "113b4343b5f6617e7ad401ced8de3cc8b012e73a594347c307b90db3e9271289" +dependencies = [ + "bytes", +] + +[[package]] +name = "cc" +version = "1.2.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +dependencies = [ + "find-msvc-tools", + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "convert_case" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9" +dependencies = [ + "unicode-segmentation", +] + +[[package]] +name = "cookie" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e859cd57d0710d9e06c381b550c06e76992472a8c6d527aecd2fc673dcc231fb" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "deranged" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cd812cc2bc1d69d4764bd80df88b4317eaef9e773c75226407d9bc0876b211c" +dependencies = [ + "powerfmt", +] + +[[package]] +name = "derive_more" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d751e9e49156b02b44f9c1815bcb94b984cdcc4396ecc32521c739452808b134" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "799a97264921d8623a957f6c3b9011f3b5492f557bbb7a5a19b7fa6d06ba8dcb" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "rustc_version", + "syn", + "unicode-xid", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "slab", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + +[[package]] +name = "getrandom" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139ef39800118c7683f2fd3c98c1b23c09ae076556b435f8e9064ae108aaeeec" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", + "wasip3", +] + +[[package]] +name = "h2" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 0.2.12", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.4.0", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.4.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http 1.4.0", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "h2 0.4.13", + "http 1.4.0", + "http-body", + "httparse", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http 1.4.0", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "base64", + "bytes", + "futures-channel", + "futures-util", + "http 1.4.0", + "http-body", + "hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2 0.6.2", + "system-configuration", + "tokio", + "tower-service", + "tracing", + "windows-registry", +] + +[[package]] +name = "icu_collections" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" + +[[package]] +name = "icu_properties" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" + +[[package]] +name = "icu_provider" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "impl-more" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a5a9a0ff0086c7a148acb942baaabeadf9504d10400b5a05645853729b9cd2" + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown 0.16.1", + "serde", + "serde_core", +] + +[[package]] +name = "ipnet" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" + +[[package]] +name = "iri-string" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +dependencies = [ + "memchr", + "serde", +] + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + +[[package]] +name = "js-sys" +version = "0.3.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93f0862381daaec758576dcc22eb7bbf4d7efd67328553f3b45a412a51a3fb21" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "language-tags" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4345964bb142484797b161f473a503a434de77149dd8c7427788c6e13379388" + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "libc" +version = "0.2.182" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" + +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + +[[package]] +name = "litemap" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" + +[[package]] +name = "local-channel" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6cbc85e69b8df4b8bb8b89ec634e7189099cea8927a276b7384ce5488e53ec8" +dependencies = [ + "futures-core", + "futures-sink", + "local-waker", +] + +[[package]] +name = "local-waker" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d873d7c67ce09b42110d801813efbc9364414e356be9935700d368351657487" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + +[[package]] +name = "mio" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +dependencies = [ + "libc", + "log", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "native-tls" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "465500e14ea162429d264d44189adc38b199b62b1c21eea9f69e4b73cb03bbf2" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "num-conv" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "openssl" +version = "0.10.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" + +[[package]] +name = "openssl-sys" +version = "0.9.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + +[[package]] +name = "potential_utf" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +dependencies = [ + "zerovec", +] + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-lite" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab834c73d247e67f4fae452806d17d3c7501756d98c8808d7c9c7aa7d18f973" + +[[package]] +name = "regex-syntax" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c" + +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64", + "bytes", + "encoding_rs", + "futures-core", + "h2 0.4.13", + "http 1.4.0", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-tls", + "hyper-util", + "js-sys", + "log", + "mime", + "native-tls", + "percent-encoding", + "pin-project-lite", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-native-tls", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c665f33d38cea657d9614f766881e4d510e0eda4239891eea56b4cadcf01801b" +dependencies = [ + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "security-framework" +version = "3.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" +dependencies = [ + "bitflags", + "core-foundation 0.10.1", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2691df843ecc5d231c0b14ece2acc3efb62c0a398c7e1d875f3983ce020e3" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "socket2" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "system-configuration" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" +dependencies = [ + "bitflags", + "core-foundation 0.9.4", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "tempfile" +version = "3.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0136791f7c95b1f6dd99f9cc786b91bb81c3800b639b3478e561ddb7be95e5f1" +dependencies = [ + "fastrand", + "getrandom 0.4.1", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + +[[package]] +name = "time" +version = "0.3.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "743bd48c283afc0388f9b8827b976905fb217ad9e647fae3a379a9283c4def2c" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde_core", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7694e1cfe791f8d31026952abf09c69ca6f6fa4e1a1229e18988f06a04a12dca" + +[[package]] +name = "time-macros" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e70e4c5a0e0a8a4823ad65dfe1a6930e4f4d756dcd9dd7939022b5e8c501215" +dependencies = [ + "num-conv", + "time-core", +] + +[[package]] +name = "tinystr" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tokio" +version = "1.49.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72a2903cd7736441aac9df9d7688bd0ce48edccaadf181c3b90be801e81d3d86" +dependencies = [ + "bytes", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2 0.6.2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-http" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http 1.4.0", + "http-body", + "iri-string", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "log", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "url" +version = "2.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.110" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1de241cdc66a9d91bd84f097039eb140cdc6eec47e0cdbaf9d932a1dd6c35866" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a42e96ea38f49b191e08a1bab66c7ffdba24b06f9995b39a9dd60222e5b6f1da" +dependencies = [ + "cfg-if", + "futures-util", + "js-sys", + "once_cell", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.110" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e12fdf6649048f2e3de6d7d5ff3ced779cdedee0e0baffd7dff5cdfa3abc8a52" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.110" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e63d1795c565ac3462334c1e396fd46dbf481c40f51f5072c310717bc4fb309" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.110" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9f9cdac23a5ce71f6bf9f8824898a501e511892791ea2a0c6b8568c68b9cb53" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "web-sys" +version = "0.3.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2c7c5718134e770ee62af3b6b4a84518ec10101aad610c024b64d6ff29bb1ff" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.5", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "wirespec-petstore" +version = "0.1.0" +dependencies = [ + "actix-web", + "regex", + "reqwest", + "serde", + "serde_json", + "tokio", +] + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + +[[package]] +name = "yoke" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerocopy" +version = "0.8.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db6d35d663eadb6c932438e763b262fe1a70987f9ae936e60158176d710cae4a" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4122cd3169e94605190e77839c9a40d40ed048d305bfdc146e7df40ab0f3e517" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/examples/rust-petstore/Cargo.toml b/examples/rust-petstore/Cargo.toml new file mode 100644 index 000000000..db9ce2799 --- /dev/null +++ b/examples/rust-petstore/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "wirespec-petstore" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "client" +path = "src/client.rs" + +[[bin]] +name = "server" +path = "src/server.rs" + +[dependencies] +serde = "1" +serde_json = "1" +regex = "1" +reqwest = { version = "0.12", features = ["json"] } +actix-web = "4" +tokio = { version = "1", features = ["full"] } diff --git a/examples/rust-petstore/gen.sh b/examples/rust-petstore/gen.sh new file mode 100755 index 000000000..5958b3c18 --- /dev/null +++ b/examples/rust-petstore/gen.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" + +OS="$(uname -s)" +ARCH="$(uname -m)" + +case "${OS}_${ARCH}" in + Linux_x86_64) PLATFORM="linuxX64" ;; + Darwin_arm64) PLATFORM="macosArm64" ;; + Darwin_x86_64) PLATFORM="macosX64" ;; + *) echo "Unsupported platform: ${OS}_${ARCH}" >&2; exit 1 ;; +esac + +ROOT_DIR="$SCRIPT_DIR/../.." +CLI_KEXE="$ROOT_DIR/src/plugin/cli/build/bin/${PLATFORM}/releaseExecutable/cli.kexe" + +if [ ! -f "$CLI_KEXE" ]; then + echo "Building wirespec CLI..." + "$ROOT_DIR/gradlew" -p "$ROOT_DIR" ":src:plugin:cli:${PLATFORM}Binaries" +fi + +"$CLI_KEXE" convert OpenAPIV2 \ + -i "$SCRIPT_DIR/petstore.json" \ + -o "$SCRIPT_DIR/src/gen" \ + -l Rust \ + -p '' \ + --shared diff --git a/examples/rust-petstore/petstore.json b/examples/rust-petstore/petstore.json new file mode 100644 index 000000000..817ecb024 --- /dev/null +++ b/examples/rust-petstore/petstore.json @@ -0,0 +1,1054 @@ +{ + "swagger": "2.0", + "info": { + "description": "This is a sample server Petstore server. You can find out more about Swagger at [http://swagger.io](http://swagger.io) or on [irc.freenode.net, #swagger](http://swagger.io/irc/). For this sample, you can use the api key `special-key` to test the authorization filters.", + "version": "1.0.7", + "title": "Swagger Petstore", + "termsOfService": "http://swagger.io/terms/", + "contact": { + "email": "apiteam@swagger.io" + }, + "license": { + "name": "Apache 2.0", + "url": "http://www.apache.org/licenses/LICENSE-2.0.html" + } + }, + "host": "petstore.swagger.io", + "basePath": "/v2", + "tags": [ + { + "name": "pet", + "description": "Everything about your Pets", + "externalDocs": { + "description": "Find out more", + "url": "http://swagger.io" + } + }, + { + "name": "store", + "description": "Access to Petstore orders" + }, + { + "name": "user", + "description": "Operations about user", + "externalDocs": { + "description": "Find out more about our store", + "url": "http://swagger.io" + } + } + ], + "schemes": [ + "https", + "http" + ], + "paths": { + "/pet/{petId}/uploadImage": { + "post": { + "tags": [ + "pet" + ], + "summary": "uploads an image", + "description": "", + "operationId": "uploadFile", + "consumes": [ + "multipart/form-data" + ], + "produces": [ + "application/json" + ], + "parameters": [ + { + "name": "petId", + "in": "path", + "description": "ID of pet to update", + "required": true, + "type": "integer", + "format": "int64" + }, + { + "name": "additionalMetadata", + "in": "formData", + "description": "Additional data to pass to server", + "required": false, + "type": "string" + }, + { + "name": "file", + "in": "formData", + "description": "file to upload", + "required": false, + "type": "file" + } + ], + "responses": { + "200": { + "description": "successful operation", + "schema": { + "$ref": "#/definitions/ApiResponse" + } + } + }, + "security": [ + { + "petstore_auth": [ + "write:pets", + "read:pets" + ] + } + ] + } + }, + "/pet": { + "post": { + "tags": [ + "pet" + ], + "summary": "Add a new pet to the store", + "description": "", + "operationId": "addPet", + "consumes": [ + "application/json", + "application/xml" + ], + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "description": "Pet object that needs to be added to the store", + "required": true, + "schema": { + "$ref": "#/definitions/Pet" + } + } + ], + "responses": { + "405": { + "description": "Invalid input" + } + }, + "security": [ + { + "petstore_auth": [ + "write:pets", + "read:pets" + ] + } + ] + }, + "put": { + "tags": [ + "pet" + ], + "summary": "Update an existing pet", + "description": "", + "operationId": "updatePet", + "consumes": [ + "application/json", + "application/xml" + ], + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "description": "Pet object that needs to be added to the store", + "required": true, + "schema": { + "$ref": "#/definitions/Pet" + } + } + ], + "responses": { + "400": { + "description": "Invalid ID supplied" + }, + "404": { + "description": "Pet not found" + }, + "405": { + "description": "Validation exception" + } + }, + "security": [ + { + "petstore_auth": [ + "write:pets", + "read:pets" + ] + } + ] + } + }, + "/pet/findByStatus": { + "get": { + "tags": [ + "pet" + ], + "summary": "Finds Pets by status", + "description": "Multiple status values can be provided with comma separated strings", + "operationId": "findPetsByStatus", + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "name": "status", + "in": "query", + "description": "Status values that need to be considered for filter", + "required": true, + "type": "array", + "items": { + "type": "string", + "enum": [ + "available", + "pending", + "sold" + ], + "default": "available" + }, + "collectionFormat": "multi" + } + ], + "responses": { + "200": { + "description": "successful operation", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/Pet" + } + } + }, + "400": { + "description": "Invalid status value" + } + }, + "security": [ + { + "petstore_auth": [ + "write:pets", + "read:pets" + ] + } + ] + } + }, + "/pet/findByTags": { + "get": { + "tags": [ + "pet" + ], + "summary": "Finds Pets by tags", + "description": "Multiple tags can be provided with comma separated strings. Use tag1, tag2, tag3 for testing.", + "operationId": "findPetsByTags", + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "name": "tags", + "in": "query", + "description": "Tags to filter by", + "required": true, + "type": "array", + "items": { + "type": "string" + }, + "collectionFormat": "multi" + } + ], + "responses": { + "200": { + "description": "successful operation", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/Pet" + } + } + }, + "400": { + "description": "Invalid tag value" + } + }, + "security": [ + { + "petstore_auth": [ + "write:pets", + "read:pets" + ] + } + ], + "deprecated": true + } + }, + "/pet/{petId}": { + "get": { + "tags": [ + "pet" + ], + "summary": "Find pet by ID", + "description": "Returns a single pet", + "operationId": "getPetById", + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "name": "petId", + "in": "path", + "description": "ID of pet to return", + "required": true, + "type": "integer", + "format": "int64" + } + ], + "responses": { + "200": { + "description": "successful operation", + "schema": { + "$ref": "#/definitions/Pet" + } + }, + "400": { + "description": "Invalid ID supplied" + }, + "404": { + "description": "Pet not found" + } + }, + "security": [ + { + "api_key": [] + } + ] + }, + "post": { + "tags": [ + "pet" + ], + "summary": "Updates a pet in the store with form data", + "description": "", + "operationId": "updatePetWithForm", + "consumes": [ + "application/x-www-form-urlencoded" + ], + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "name": "petId", + "in": "path", + "description": "ID of pet that needs to be updated", + "required": true, + "type": "integer", + "format": "int64" + }, + { + "name": "name", + "in": "formData", + "description": "Updated name of the pet", + "required": false, + "type": "string" + }, + { + "name": "status", + "in": "formData", + "description": "Updated status of the pet", + "required": false, + "type": "string" + } + ], + "responses": { + "405": { + "description": "Invalid input" + } + }, + "security": [ + { + "petstore_auth": [ + "write:pets", + "read:pets" + ] + } + ] + }, + "delete": { + "tags": [ + "pet" + ], + "summary": "Deletes a pet", + "description": "", + "operationId": "deletePet", + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "name": "api_key", + "in": "header", + "required": false, + "type": "string" + }, + { + "name": "petId", + "in": "path", + "description": "Pet id to delete", + "required": true, + "type": "integer", + "format": "int64" + } + ], + "responses": { + "400": { + "description": "Invalid ID supplied" + }, + "404": { + "description": "Pet not found" + } + }, + "security": [ + { + "petstore_auth": [ + "write:pets", + "read:pets" + ] + } + ] + } + }, + "/store/inventory": { + "get": { + "tags": [ + "store" + ], + "summary": "Returns pet inventories by status", + "description": "Returns a map of status codes to quantities", + "operationId": "getInventory", + "produces": [ + "application/json" + ], + "parameters": [], + "responses": { + "200": { + "description": "successful operation", + "schema": { + "type": "object", + "additionalProperties": { + "type": "integer", + "format": "int32" + } + } + } + }, + "security": [ + { + "api_key": [] + } + ] + } + }, + "/store/order": { + "post": { + "tags": [ + "store" + ], + "summary": "Place an order for a pet", + "description": "", + "operationId": "placeOrder", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "description": "order placed for purchasing the pet", + "required": true, + "schema": { + "$ref": "#/definitions/Order" + } + } + ], + "responses": { + "200": { + "description": "successful operation", + "schema": { + "$ref": "#/definitions/Order" + } + }, + "400": { + "description": "Invalid Order" + } + } + } + }, + "/store/order/{orderId}": { + "get": { + "tags": [ + "store" + ], + "summary": "Find purchase order by ID", + "description": "For valid response try integer IDs with value >= 1 and <= 10. Other values will generated exceptions", + "operationId": "getOrderById", + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "name": "orderId", + "in": "path", + "description": "ID of pet that needs to be fetched", + "required": true, + "type": "integer", + "maximum": 10, + "minimum": 1, + "format": "int64" + } + ], + "responses": { + "200": { + "description": "successful operation", + "schema": { + "$ref": "#/definitions/Order" + } + }, + "400": { + "description": "Invalid ID supplied" + }, + "404": { + "description": "Order not found" + } + } + }, + "delete": { + "tags": [ + "store" + ], + "summary": "Delete purchase order by ID", + "description": "For valid response try integer IDs with positive integer value. Negative or non-integer values will generate API errors", + "operationId": "deleteOrder", + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "name": "orderId", + "in": "path", + "description": "ID of the order that needs to be deleted", + "required": true, + "type": "integer", + "minimum": 1, + "format": "int64" + } + ], + "responses": { + "400": { + "description": "Invalid ID supplied" + }, + "404": { + "description": "Order not found" + } + } + } + }, + "/user/createWithList": { + "post": { + "tags": [ + "user" + ], + "summary": "Creates list of users with given input array", + "description": "", + "operationId": "createUsersWithListInput", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "description": "List of user object", + "required": true, + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/User" + } + } + } + ], + "responses": { + "default": { + "description": "successful operation" + } + } + } + }, + "/user/{username}": { + "get": { + "tags": [ + "user" + ], + "summary": "Get user by user name", + "description": "", + "operationId": "getUserByName", + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "name": "username", + "in": "path", + "description": "The name that needs to be fetched. Use user1 for testing. ", + "required": true, + "type": "string" + } + ], + "responses": { + "200": { + "description": "successful operation", + "schema": { + "$ref": "#/definitions/User" + } + }, + "400": { + "description": "Invalid username supplied" + }, + "404": { + "description": "User not found" + } + } + }, + "put": { + "tags": [ + "user" + ], + "summary": "Updated user", + "description": "This can only be done by the logged in user.", + "operationId": "updateUser", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "name": "username", + "in": "path", + "description": "name that need to be updated", + "required": true, + "type": "string" + }, + { + "in": "body", + "name": "body", + "description": "Updated user object", + "required": true, + "schema": { + "$ref": "#/definitions/User" + } + } + ], + "responses": { + "400": { + "description": "Invalid user supplied" + }, + "404": { + "description": "User not found" + } + } + }, + "delete": { + "tags": [ + "user" + ], + "summary": "Delete user", + "description": "This can only be done by the logged in user.", + "operationId": "deleteUser", + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "name": "username", + "in": "path", + "description": "The name that needs to be deleted", + "required": true, + "type": "string" + } + ], + "responses": { + "400": { + "description": "Invalid username supplied" + }, + "404": { + "description": "User not found" + } + } + } + }, + "/user/login": { + "get": { + "tags": [ + "user" + ], + "summary": "Logs user into the system", + "description": "", + "operationId": "loginUser", + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "name": "username", + "in": "query", + "description": "The user name for login", + "required": true, + "type": "string" + }, + { + "name": "password", + "in": "query", + "description": "The password for login in clear text", + "required": true, + "type": "string" + } + ], + "responses": { + "200": { + "description": "successful operation", + "headers": { + "X-Expires-After": { + "type": "string", + "format": "date-time", + "description": "date in UTC when token expires" + }, + "X-Rate-Limit": { + "type": "integer", + "format": "int32", + "description": "calls per hour allowed by the user" + } + }, + "schema": { + "type": "string" + } + }, + "400": { + "description": "Invalid username/password supplied" + } + } + } + }, + "/user/logout": { + "get": { + "tags": [ + "user" + ], + "summary": "Logs out current logged in user session", + "description": "", + "operationId": "logoutUser", + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [], + "responses": { + "default": { + "description": "successful operation" + } + } + } + }, + "/user/createWithArray": { + "post": { + "tags": [ + "user" + ], + "summary": "Creates list of users with given input array", + "description": "", + "operationId": "createUsersWithArrayInput", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "description": "List of user object", + "required": true, + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/User" + } + } + } + ], + "responses": { + "default": { + "description": "successful operation" + } + } + } + }, + "/user": { + "post": { + "tags": [ + "user" + ], + "summary": "Create user", + "description": "This can only be done by the logged in user.", + "operationId": "createUser", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json", + "application/xml" + ], + "parameters": [ + { + "in": "body", + "name": "body", + "description": "Created user object", + "required": true, + "schema": { + "$ref": "#/definitions/User" + } + } + ], + "responses": { + "default": { + "description": "successful operation" + } + } + } + } + }, + "securityDefinitions": { + "api_key": { + "type": "apiKey", + "name": "api_key", + "in": "header" + }, + "petstore_auth": { + "type": "oauth2", + "authorizationUrl": "https://petstore.swagger.io/oauth/authorize", + "flow": "implicit", + "scopes": { + "read:pets": "read your pets", + "write:pets": "modify pets in your account" + } + } + }, + "definitions": { + "ApiResponse": { + "type": "object", + "properties": { + "code": { + "type": "integer", + "format": "int32" + }, + "type": { + "type": "string" + }, + "message": { + "type": "string" + } + } + }, + "Category": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "name": { + "type": "string" + } + }, + "xml": { + "name": "Category" + } + }, + "Pet": { + "type": "object", + "required": [ + "name", + "photoUrls" + ], + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "category": { + "$ref": "#/definitions/Category" + }, + "name": { + "type": "string", + "example": "doggie" + }, + "photoUrls": { + "type": "array", + "xml": { + "wrapped": true + }, + "items": { + "type": "string", + "xml": { + "name": "photoUrl" + } + } + }, + "tags": { + "type": "array", + "xml": { + "wrapped": true + }, + "items": { + "xml": { + "name": "tag" + }, + "$ref": "#/definitions/Tag" + } + }, + "status": { + "type": "string", + "description": "pet status in the store", + "enum": [ + "available", + "pending", + "sold" + ] + } + }, + "xml": { + "name": "Pet" + } + }, + "Tag": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "name": { + "type": "string" + } + }, + "xml": { + "name": "Tag" + } + }, + "Order": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "petId": { + "type": "integer", + "format": "int64" + }, + "quantity": { + "type": "integer", + "format": "int32" + }, + "shipDate": { + "type": "string", + "format": "date-time" + }, + "status": { + "type": "string", + "description": "Order Status", + "enum": [ + "placed", + "approved", + "delivered" + ] + }, + "complete": { + "type": "boolean" + } + }, + "xml": { + "name": "Order" + } + }, + "User": { + "type": "object", + "properties": { + "id": { + "type": "integer", + "format": "int64" + }, + "username": { + "type": "string" + }, + "firstName": { + "type": "string" + }, + "lastName": { + "type": "string" + }, + "email": { + "type": "string" + }, + "password": { + "type": "string" + }, + "phone": { + "type": "string" + }, + "userStatus": { + "type": "integer", + "format": "int32", + "description": "User Status" + } + }, + "xml": { + "name": "User" + } + } + }, + "externalDocs": { + "description": "Find out more about Swagger", + "url": "http://swagger.io" + } +} \ No newline at end of file diff --git a/examples/rust-petstore/src/actix.rs b/examples/rust-petstore/src/actix.rs new file mode 100644 index 000000000..4f5959ca2 --- /dev/null +++ b/examples/rust-petstore/src/actix.rs @@ -0,0 +1,104 @@ +use actix_web::{web, HttpRequest, HttpResponse}; +use crate::gen::wirespec::{RawRequest, RawResponse}; +use std::collections::HashMap; + +pub fn to_raw_request(req: &HttpRequest, body: web::Bytes) -> RawRequest { + let path: Vec = req + .path() + .split('/') + .filter(|s| !s.is_empty()) + .map(String::from) + .collect(); + + let mut queries: HashMap> = HashMap::new(); + if let Some(query_string) = req.uri().query() { + for pair in query_string.split('&') { + if let Some((key, value)) = pair.split_once('=') { + queries + .entry(key.to_string()) + .or_default() + .push(value.to_string()); + } + } + } + + let mut headers: HashMap> = HashMap::new(); + for (key, value) in req.headers() { + headers + .entry(key.as_str().to_string()) + .or_default() + .push(value.to_str().unwrap_or("").to_string()); + } + + RawRequest { + method: req.method().to_string(), + path, + queries, + headers, + body: if body.is_empty() { + None + } else { + Some(body.to_vec()) + }, + } +} + +pub fn to_http_response(raw: RawResponse) -> HttpResponse { + let mut builder = HttpResponse::build( + actix_web::http::StatusCode::from_u16(raw.status_code as u16) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR), + ); + + for (key, values) in &raw.headers { + for value in values { + builder.append_header((key.as_str(), value.as_str())); + } + } + + match raw.body { + Some(body) => builder.content_type("application/json").body(body), + None => builder.finish(), + } +} + +#[macro_export] +macro_rules! register { + ($cfg:expr, $handler_type:ident; $($module:ident :: $ns:ident),* $(,)?) => { + $cfg.app_data(web::Data::new($handler_type)); + $cfg.app_data(web::Data::new($crate::serialization::JsonSerialization)); + $(register!($cfg, $module::$ns, $handler_type);)* + }; + ($cfg:expr, $module:ident :: $ns:ident, $handler_type:ty) => {{ + let api = $module::$ns::Api; + let path = api.path_template(); + let method = api.method(); + let route = match method { + Method::GET => web::get(), + Method::PUT => web::put(), + Method::POST => web::post(), + Method::DELETE => web::delete(), + Method::HEAD => web::head(), + Method::PATCH => web::patch(), + Method::OPTIONS => web::method(actix_web::http::Method::OPTIONS), + Method::TRACE => web::method(actix_web::http::Method::TRACE), + }; + $cfg.route( + path, + route.to( + |req: HttpRequest, + body: web::Bytes, + handler: web::Data<$handler_type>, + ser: web::Data<$crate::serialization::JsonSerialization>| async move { + let raw = $crate::actix::to_raw_request(&req, body); + let typed_req = $module::$ns::from_raw_request(&**ser, raw); + let typed_res = + <$handler_type as $module::$ns::Handler>::$module( + &**handler, typed_req, + ).await; + let raw_res = $module::$ns::to_raw_response(&**ser, typed_res); + $crate::actix::to_http_response(raw_res) + }, + ), + ); + }}; +} diff --git a/examples/rust-petstore/src/client.rs b/examples/rust-petstore/src/client.rs new file mode 100644 index 000000000..944e84cb2 --- /dev/null +++ b/examples/rust-petstore/src/client.rs @@ -0,0 +1,104 @@ +use wirespec_petstore::gen::endpoint::{ + find_pets_by_status, get_inventory, get_pet_by_id, place_order, +}; +use wirespec_petstore::gen::model::category::Category; +use wirespec_petstore::gen::model::order::Order; +use wirespec_petstore::gen::model::order_status::OrderStatus; +use wirespec_petstore::gen::model::pet::Pet; +use wirespec_petstore::gen::model::pet_status::PetStatus; +use wirespec_petstore::gen::model::tag::Tag; +use wirespec_petstore::gen::wirespec::Client; +use wirespec_petstore::serialization::JsonSerialization; +use wirespec_petstore::transportation::{ClientProxy, ReqwestTransport}; + +trait PetstoreApi: + find_pets_by_status::FindPetsByStatus::Handler + + get_pet_by_id::GetPetById::Handler + + get_inventory::GetInventory::Handler + + place_order::PlaceOrder::Handler +{} + +impl PetstoreApi for T {} + +#[tokio::main] +async fn main() { + let api = ClientProxy { + transport: ReqwestTransport::new("https://petstore.swagger.io/v2"), + serialization: JsonSerialization, + }; + run(&api).await; +} + +async fn run(api: &impl PetstoreApi) { + println!("=== Find pets by status 'available' ==="); + let resp = api.find_pets_by_status(find_pets_by_status::Request::new(vec!["available".into()])).await; + match resp { + find_pets_by_status::Response::Response200(r) => { + println!("Found {} pets", r.body.len()); + for pet in r.body.iter().take(3) { + println!( + " - {} (id: {:?}, status: {:?})", + pet.name, pet.id, pet.status + ); + } + } + _ => println!("Error fetching pets"), + } + + println!("\n=== Get pet by ID ==="); + let resp = api.get_pet_by_id(get_pet_by_id::Request::new(1)).await; + match resp { + get_pet_by_id::Response::Response200(r) => { + println!("Pet: {} (id: {:?})", r.body.name, r.body.id); + println!(" Category: {:?}", r.body.category); + println!(" Tags: {:?}", r.body.tags); + println!(" Status: {:?}", r.body.status); + println!(" Photo URLs: {:?}", r.body.photo_urls); + } + _ => println!("Error fetching pet"), + } + + println!("\n=== Add a new pet ==="); + let _new_pet = Pet { + id: None, + category: Some(Category { + id: Some(1), + name: Some("Dogs".into()), + }), + name: "Wirespec Dog".into(), + photo_urls: vec!["https://example.com/dog.jpg".into()], + tags: Some(vec![Tag { + id: Some(1), + name: Some("gen".into()), + }]), + status: Some(PetStatus::Available), + }; + + println!("\n=== Get store inventory ==="); + let resp = api.get_inventory(get_inventory::Request::new()).await; + match resp { + get_inventory::Response::Response200(r) => { + for (status, count) in &r.body { + println!(" {}: {}", status, count); + } + } + } + + println!("\n=== Place an order ==="); + let order = Order { + id: None, + pet_id: Some(1), + quantity: Some(1), + ship_date: Some("2025-01-01T00:00:00.000Z".into()), + status: Some(OrderStatus::Placed), + complete: Some(false), + }; + let resp = api.place_order(place_order::Request::new(order)).await; + match resp { + place_order::Response::Response200(r) => println!( + "Order placed: id={:?}, status={:?}", + r.body.id, r.body.status + ), + _ => println!("Error placing order"), + } +} diff --git a/examples/rust-petstore/src/lib.rs b/examples/rust-petstore/src/lib.rs new file mode 100644 index 000000000..f32082603 --- /dev/null +++ b/examples/rust-petstore/src/lib.rs @@ -0,0 +1,5 @@ +pub mod gen; +pub mod serialization; +pub mod actix; +pub mod transportation; +pub mod service; diff --git a/examples/rust-petstore/src/serialization.rs b/examples/rust-petstore/src/serialization.rs new file mode 100644 index 000000000..75fbcf26a --- /dev/null +++ b/examples/rust-petstore/src/serialization.rs @@ -0,0 +1,298 @@ +use std::any::{Any, TypeId}; +use crate::gen::wirespec::*; +use crate::gen::model::pet::Pet; +use crate::gen::model::category::Category; +use crate::gen::model::tag::Tag; +use crate::gen::model::pet_status::PetStatus; +use crate::gen::model::order::Order; +use crate::gen::model::order_status::OrderStatus; +use crate::gen::model::user::User; +use crate::gen::model::api_response::ApiResponse; + +pub struct JsonSerialization; + +impl BodySerializer for JsonSerialization { + fn serialize_body(&self, t: &T, _type: TypeId) -> Vec { + let any: &dyn Any = t; + let value = if let Some(v) = any.downcast_ref::() { + pet_to_json(v) + } else if let Some(v) = any.downcast_ref::>() { + serde_json::Value::Array(v.iter().map(pet_to_json).collect()) + } else if let Some(v) = any.downcast_ref::() { + order_to_json(v) + } else if let Some(v) = any.downcast_ref::() { + user_to_json(v) + } else if let Some(v) = any.downcast_ref::>() { + serde_json::Value::Array(v.iter().map(user_to_json).collect()) + } else if let Some(v) = any.downcast_ref::() { + api_response_to_json(v) + } else if let Some(v) = any.downcast_ref::() { + serde_json::Value::String(v.clone()) + } else if let Some(v) = any.downcast_ref::>() { + let map: serde_json::Map = v + .iter() + .map(|(k, v)| (k.clone(), serde_json::Value::Number((*v).into()))) + .collect(); + serde_json::Value::Object(map) + } else { + panic!("Unsupported body type for serialization: {:?}", _type) + }; + serde_json::to_vec(&value).unwrap() + } +} + +impl BodyDeserializer for JsonSerialization { + fn deserialize_body(&self, raw: &[u8], r#type: TypeId) -> T { + let value: serde_json::Value = serde_json::from_slice(raw).unwrap(); + let boxed: Box = if r#type == TypeId::of::() { + Box::new(json_to_pet(&value)) + } else if r#type == TypeId::of::>() { + Box::new( + value + .as_array() + .unwrap() + .iter() + .map(json_to_pet) + .collect::>(), + ) + } else if r#type == TypeId::of::() { + Box::new(json_to_order(&value)) + } else if r#type == TypeId::of::() { + Box::new(json_to_user(&value)) + } else if r#type == TypeId::of::>() { + Box::new( + value + .as_array() + .unwrap() + .iter() + .map(json_to_user) + .collect::>(), + ) + } else if r#type == TypeId::of::() { + Box::new(json_to_api_response(&value)) + } else if r#type == TypeId::of::() { + Box::new(value.as_str().unwrap_or_default().to_string()) + } else if r#type == TypeId::of::>() { + let map: std::collections::HashMap = value + .as_object() + .unwrap() + .iter() + .map(|(k, v)| (k.clone(), v.as_i64().unwrap_or(0) as i32)) + .collect(); + Box::new(map) + } else { + panic!("Unsupported body type for deserialization: {:?}", r#type) + }; + *boxed.downcast::().unwrap() + } +} + +impl PathSerializer for JsonSerialization { + fn serialize_path(&self, t: &T, _type: TypeId) -> String { + t.to_string() + } +} + +impl PathDeserializer for JsonSerialization { + fn deserialize_path(&self, raw: &str, _type: TypeId) -> T + where + T::Err: std::fmt::Debug, + { + raw.parse().unwrap() + } +} + +impl ParamSerializer for JsonSerialization { + fn serialize_param(&self, value: &T, _type: TypeId) -> Vec { + let any: &dyn Any = value; + if let Some(s) = any.downcast_ref::() { + vec![s.clone()] + } else if let Some(v) = any.downcast_ref::>() { + v.clone() + } else if let Some(b) = any.downcast_ref::() { + vec![b.to_string()] + } else if let Some(n) = any.downcast_ref::() { + vec![n.to_string()] + } else if let Some(n) = any.downcast_ref::() { + vec![n.to_string()] + } else { + panic!("Unsupported param type for serialization: {:?}", _type) + } + } +} + +impl ParamDeserializer for JsonSerialization { + fn deserialize_param(&self, values: &[String], r#type: TypeId) -> T { + let boxed: Box = if r#type == TypeId::of::() { + Box::new(values.first().cloned().unwrap_or_default()) + } else if r#type == TypeId::of::>() { + Box::new(values.to_vec()) + } else if r#type == TypeId::of::() { + Box::new( + values + .first() + .map(|v| v == "true") + .unwrap_or(false), + ) + } else if r#type == TypeId::of::() { + Box::new( + values + .first() + .and_then(|v| v.parse::().ok()) + .unwrap_or(0), + ) + } else if r#type == TypeId::of::() { + Box::new( + values + .first() + .and_then(|v| v.parse::().ok()) + .unwrap_or(0), + ) + } else { + panic!("Unsupported param type for deserialization: {:?}", r#type) + }; + *boxed.downcast::().unwrap() + } +} + +impl Serializer for JsonSerialization {} +impl Deserializer for JsonSerialization {} +impl Serialization for JsonSerialization {} + +// --- Pet --- + +fn pet_to_json(pet: &Pet) -> serde_json::Value { + serde_json::json!({ + "id": pet.id, + "category": pet.category.as_ref().map(category_to_json), + "name": pet.name, + "photoUrls": pet.photo_urls, + "tags": pet.tags.as_ref().map(|tags| tags.iter().map(tag_to_json).collect::>()), + "status": pet.status.as_ref().map(|s| s.label()), + }) +} + +fn json_to_pet(v: &serde_json::Value) -> Pet { + Pet { + id: v.get("id").and_then(|v| v.as_i64()), + category: v.get("category").and_then(|v| if v.is_null() { None } else { Some(json_to_category(v)) }), + name: v.get("name").and_then(|v| v.as_str()).unwrap_or_default().to_string(), + photo_urls: v + .get("photoUrls") + .and_then(|v| v.as_array()) + .map(|a| a.iter().filter_map(|v| v.as_str().map(String::from)).collect()) + .unwrap_or_default(), + tags: v.get("tags").and_then(|v| { + if v.is_null() { None } else { v.as_array().map(|a| a.iter().map(json_to_tag).collect()) } + }), + status: v + .get("status") + .and_then(|v| v.as_str()) + .and_then(PetStatus::from_label), + } +} + +// --- Category --- + +fn category_to_json(cat: &Category) -> serde_json::Value { + serde_json::json!({ + "id": cat.id, + "name": cat.name, + }) +} + +fn json_to_category(v: &serde_json::Value) -> Category { + Category { + id: v.get("id").and_then(|v| v.as_i64()), + name: v.get("name").and_then(|v| v.as_str()).map(String::from), + } +} + +// --- Tag --- + +fn tag_to_json(tag: &Tag) -> serde_json::Value { + serde_json::json!({ + "id": tag.id, + "name": tag.name, + }) +} + +fn json_to_tag(v: &serde_json::Value) -> Tag { + Tag { + id: v.get("id").and_then(|v| v.as_i64()), + name: v.get("name").and_then(|v| v.as_str()).map(String::from), + } +} + +// --- Order --- + +fn order_to_json(order: &Order) -> serde_json::Value { + serde_json::json!({ + "id": order.id, + "petId": order.pet_id, + "quantity": order.quantity, + "shipDate": order.ship_date, + "status": order.status.as_ref().map(|s| s.label()), + "complete": order.complete, + }) +} + +fn json_to_order(v: &serde_json::Value) -> Order { + Order { + id: v.get("id").and_then(|v| v.as_i64()), + pet_id: v.get("petId").and_then(|v| v.as_i64()), + quantity: v.get("quantity").and_then(|v| v.as_i64()).map(|n| n as i32), + ship_date: v.get("shipDate").and_then(|v| v.as_str()).map(String::from), + status: v + .get("status") + .and_then(|v| v.as_str()) + .and_then(OrderStatus::from_label), + complete: v.get("complete").and_then(|v| v.as_bool()), + } +} + +// --- User --- + +fn user_to_json(user: &User) -> serde_json::Value { + serde_json::json!({ + "id": user.id, + "username": user.username, + "firstName": user.first_name, + "lastName": user.last_name, + "email": user.email, + "password": user.password, + "phone": user.phone, + "userStatus": user.user_status, + }) +} + +fn json_to_user(v: &serde_json::Value) -> User { + User { + id: v.get("id").and_then(|v| v.as_i64()), + username: v.get("username").and_then(|v| v.as_str()).map(String::from), + first_name: v.get("firstName").and_then(|v| v.as_str()).map(String::from), + last_name: v.get("lastName").and_then(|v| v.as_str()).map(String::from), + email: v.get("email").and_then(|v| v.as_str()).map(String::from), + password: v.get("password").and_then(|v| v.as_str()).map(String::from), + phone: v.get("phone").and_then(|v| v.as_str()).map(String::from), + user_status: v.get("userStatus").and_then(|v| v.as_i64()).map(|n| n as i32), + } +} + +// --- ApiResponse --- + +fn api_response_to_json(resp: &ApiResponse) -> serde_json::Value { + serde_json::json!({ + "code": resp.code, + "type": resp.r#type, + "message": resp.message, + }) +} + +fn json_to_api_response(v: &serde_json::Value) -> ApiResponse { + ApiResponse { + code: v.get("code").and_then(|v| v.as_i64()).map(|n| n as i32), + r#type: v.get("type").and_then(|v| v.as_str()).map(String::from), + message: v.get("message").and_then(|v| v.as_str()).map(String::from), + } +} diff --git a/examples/rust-petstore/src/server.rs b/examples/rust-petstore/src/server.rs new file mode 100644 index 000000000..db7d629b3 --- /dev/null +++ b/examples/rust-petstore/src/server.rs @@ -0,0 +1,24 @@ +use actix_web::{web, App, HttpRequest, HttpServer}; +use wirespec_petstore::gen::endpoint::{add_pet, find_pets_by_status, get_pet_by_id}; +use wirespec_petstore::gen::wirespec::{Method, Server}; +use wirespec_petstore::register; +use wirespec_petstore::service::PetstoreService; + +#[actix_web::main] +async fn main() -> std::io::Result<()> { + println!("Starting petstore server on http://127.0.0.1:8080"); + + HttpServer::new(|| { + App::new() + .configure(|cfg| { + register!(cfg, PetstoreService; + find_pets_by_status::FindPetsByStatus, + add_pet::AddPet, + get_pet_by_id::GetPetById, + ); + }) + }) + .bind("127.0.0.1:8080")? + .run() + .await +} diff --git a/examples/rust-petstore/src/service.rs b/examples/rust-petstore/src/service.rs new file mode 100644 index 000000000..b83fa2cc0 --- /dev/null +++ b/examples/rust-petstore/src/service.rs @@ -0,0 +1,49 @@ +use crate::gen::endpoint::{add_pet, find_pets_by_status, get_pet_by_id}; +use crate::gen::model::category::Category; +use crate::gen::model::pet::Pet; + +pub struct PetstoreService; + +impl get_pet_by_id::GetPetById::Handler for PetstoreService { + async fn get_pet_by_id(&self, request: get_pet_by_id::Request) -> get_pet_by_id::Response { + let pet = Pet { + id: Some(request.path.pet_id), + category: Some(Category { + id: Some(1), + name: Some("Dogs".into()), + }), + name: format!("Pet {}", request.path.pet_id), + photo_urls: vec!["https://example.com/pet.jpg".into()], + tags: None, + status: None, + }; + get_pet_by_id::Response200::new(pet).into() + } +} + +impl find_pets_by_status::FindPetsByStatus::Handler for PetstoreService { + async fn find_pets_by_status( + &self, + _request: find_pets_by_status::Request, + ) -> find_pets_by_status::Response { + let pets = vec![Pet { + id: Some(1), + category: Some(Category { + id: Some(1), + name: Some("Dogs".into()), + }), + name: "Buddy".into(), + photo_urls: vec!["https://example.com/buddy.jpg".into()], + tags: None, + status: None, + }]; + find_pets_by_status::Response200::new(pets).into() + } +} + +impl add_pet::AddPet::Handler for PetstoreService { + async fn add_pet(&self, request: add_pet::Request) -> add_pet::Response { + println!("Received pet: {:?}", request.body); + add_pet::Response405::new().into() + } +} diff --git a/examples/rust-petstore/src/transportation.rs b/examples/rust-petstore/src/transportation.rs new file mode 100644 index 000000000..374ce1fe0 --- /dev/null +++ b/examples/rust-petstore/src/transportation.rs @@ -0,0 +1,93 @@ +use crate::gen::wirespec::{Client, RawRequest, RawResponse, Transportation}; +use crate::serialization::JsonSerialization; +use std::collections::HashMap; + +pub struct ReqwestTransport { + client: reqwest::Client, + base_url: String, +} + +impl ReqwestTransport { + pub fn new(base_url: &str) -> Self { + ReqwestTransport { + client: reqwest::Client::new(), + base_url: base_url.to_string(), + } + } +} + +impl Transportation for ReqwestTransport { + async fn transport(&self, request: &RawRequest) -> RawResponse { + let path = request.path.join("/"); + let url = format!("{}/{}", self.base_url, path); + + let mut req_builder = match request.method.as_str() { + "GET" => self.client.get(&url), + "POST" => self.client.post(&url), + "PUT" => self.client.put(&url), + "DELETE" => self.client.delete(&url), + "PATCH" => self.client.patch(&url), + "HEAD" => self.client.head(&url), + _ => self.client.get(&url), + }; + + for (key, values) in &request.queries { + for value in values { + req_builder = req_builder.query(&[(key, value)]); + } + } + + for (key, values) in &request.headers { + if let Some(value) = values.first() { + req_builder = req_builder.header(key.as_str(), value.as_str()); + } + } + + req_builder = req_builder.header("Accept", "application/json"); + req_builder = req_builder.header("Content-Type", "application/json"); + + if let Some(body) = &request.body { + req_builder = req_builder.body(body.clone()); + } + + match req_builder.send().await { + Ok(response) => { + let status_code = response.status().as_u16() as i32; + let mut headers: HashMap> = HashMap::new(); + for (key, value) in response.headers() { + headers + .entry(key.to_string()) + .or_default() + .push(value.to_str().unwrap_or("").to_string()); + } + let body = response.bytes().await.ok().map(|b| b.to_vec()); + + RawResponse { + status_code, + headers, + body, + } + } + Err(e) => { + eprintln!("Transport error: {}", e); + RawResponse { + status_code: 0, + headers: HashMap::new(), + body: None, + } + } + } + } +} + +pub struct ClientProxy { + pub transport: T, + pub serialization: JsonSerialization, +} + +impl Client for ClientProxy { + type Transport = T; + type Ser = JsonSerialization; + fn transport(&self) -> &T { &self.transport } + fn serialization(&self) -> &JsonSerialization { &self.serialization } +} diff --git a/examples/rust-petstore/tests/integration.rs b/examples/rust-petstore/tests/integration.rs new file mode 100644 index 000000000..f4fddd61e --- /dev/null +++ b/examples/rust-petstore/tests/integration.rs @@ -0,0 +1,89 @@ +use actix_web::{web, App, HttpRequest, HttpServer}; +use wirespec_petstore::gen::endpoint::{add_pet, find_pets_by_status, get_pet_by_id}; +use wirespec_petstore::gen::model::pet::Pet; +use wirespec_petstore::gen::wirespec::{Method, Server}; +use wirespec_petstore::register; +use wirespec_petstore::serialization::JsonSerialization; +use wirespec_petstore::service::PetstoreService; +use wirespec_petstore::transportation::{ClientProxy, ReqwestTransport}; + +use add_pet::AddPet::Handler as _; +use find_pets_by_status::FindPetsByStatus::Handler as _; +use get_pet_by_id::GetPetById::Handler as _; + +#[test] +fn test_petstore_endpoints() { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + + std::thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + HttpServer::new(|| { + App::new().configure(|cfg| { + register!(cfg, PetstoreService; + find_pets_by_status::FindPetsByStatus, + add_pet::AddPet, + get_pet_by_id::GetPetById, + ); + }) + }) + .listen(listener) + .unwrap() + .run() + .await + .unwrap(); + }); + }); + + // Wait for server readiness + for _ in 0..50 { + if std::net::TcpStream::connect(format!("127.0.0.1:{}", port)).is_ok() { + break; + } + std::thread::sleep(std::time::Duration::from_millis(100)); + } + + let rt = tokio::runtime::Runtime::new().unwrap(); + + let api = ClientProxy { + transport: ReqwestTransport::new(&format!("http://127.0.0.1:{}", port)), + serialization: JsonSerialization, + }; + + // Test GetPetById: request pet ID 42 → expect 200 with id=42, name="Pet 42" + let resp = rt.block_on(api.get_pet_by_id(get_pet_by_id::Request::new(42))); + match resp { + get_pet_by_id::Response::Response200(r) => { + assert_eq!(r.body.id, Some(42)); + assert_eq!(r.body.name, "Pet 42"); + } + other => panic!("Expected Response200, got {:?}", other), + } + + // Test FindPetsByStatus: request status "available" → expect 200 with 1 pet named "Buddy" + let resp = rt.block_on( + api.find_pets_by_status(find_pets_by_status::Request::new(vec!["available".into()])), + ); + match resp { + find_pets_by_status::Response::Response200(r) => { + assert_eq!(r.body.len(), 1); + assert_eq!(r.body[0].name, "Buddy"); + } + other => panic!("Expected Response200, got {:?}", other), + } + + // Test AddPet: send a pet → expect 405 response + let pet = Pet { + id: None, + category: None, + name: "TestPet".into(), + photo_urls: vec![], + tags: None, + status: None, + }; + let resp = rt.block_on(api.add_pet(add_pet::Request::new(pet))); + match resp { + add_pet::Response::Response405(_) => {} + } +} diff --git a/examples/scala-zio/.gitignore b/examples/scala-zio/.gitignore new file mode 100644 index 000000000..aa27faf9f --- /dev/null +++ b/examples/scala-zio/.gitignore @@ -0,0 +1,7 @@ +target/ +.bsp/ +project/target/ +project/project/ +.idea/ +*.class +*.log diff --git a/examples/scala-zio/build.sbt b/examples/scala-zio/build.sbt new file mode 100644 index 000000000..5eeaba714 --- /dev/null +++ b/examples/scala-zio/build.sbt @@ -0,0 +1,48 @@ +val scala3Version = "3.3.4" +val zioVersion = "2.1.14" +val zioHttpVersion = "3.0.1" +val circeVersion = "0.14.10" + +lazy val root = project + .in(file(".")) + .settings( + name := "scala-zio", + version := "0.1.0", + scalaVersion := scala3Version, + libraryDependencies ++= Seq( + "dev.zio" %% "zio" % zioVersion, + "dev.zio" %% "zio-http" % zioHttpVersion, + "io.circe" %% "circe-core" % circeVersion, + "io.circe" %% "circe-generic" % circeVersion, + "io.circe" %% "circe-parser" % circeVersion, + "dev.zio" %% "zio-test" % zioVersion % Test, + "dev.zio" %% "zio-test-sbt" % zioVersion % Test, + "dev.zio" %% "zio-http-testkit" % zioHttpVersion % Test, + ), + testFrameworks += new TestFramework("zio.test.sbt.ZTestFramework"), + // Wirespec code generation + Compile / sourceGenerators += wirespecGenerate.taskValue, + Compile / managedSourceDirectories += baseDirectory.value / "target" / "generated-sources", + ) + +lazy val wirespecGenerate = taskKey[Seq[File]]("Generate Scala sources from Wirespec/OpenAPI") + +wirespecGenerate := { + val log = streams.value.log + val outDir = baseDirectory.value / "target" / "generated-sources" + + val genScript = baseDirectory.value / "gen.sh" + if (!genScript.exists()) { + sys.error(s"gen.sh not found at ${genScript.absolutePath}") + } + + log.info("Running Wirespec code generation...") + import scala.sys.process._ + val exitCode = Process(Seq("bash", genScript.absolutePath), baseDirectory.value).! + if (exitCode != 0) { + sys.error(s"gen.sh failed with exit code $exitCode") + } + + // Collect all generated .scala files + (outDir ** "*.scala").get +} diff --git a/examples/scala-zio/gen.sh b/examples/scala-zio/gen.sh new file mode 100755 index 000000000..42c88837f --- /dev/null +++ b/examples/scala-zio/gen.sh @@ -0,0 +1,31 @@ +#!/bin/bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "$SCRIPT_DIR/../.." && pwd)" + +UNAME="$(uname -s)_$(uname -m)" +case "$UNAME" in + Linux_x86_64) PLATFORM="linuxX64" ;; + Darwin_arm64) PLATFORM="macosArm64" ;; + Darwin_x86_64) PLATFORM="macosX64" ;; + *) echo "Unsupported platform: $UNAME" >&2; exit 1 ;; +esac + +CLI="$ROOT_DIR/src/plugin/cli/build/bin/$PLATFORM/releaseExecutable/cli.kexe" + +if [ ! -f "$CLI" ]; then + echo "Building Wirespec CLI for $PLATFORM..." + (cd "$ROOT_DIR" && ./gradlew ":src:plugin:cli:${PLATFORM}Binaries") +fi + +OUT_DIR="$SCRIPT_DIR/target/generated-sources" +mkdir -p "$OUT_DIR" + +echo "Generating Scala code from guru.json..." +"$CLI" convert OpenAPIV3 \ + -i "$SCRIPT_DIR/guru.json" \ + -o "$OUT_DIR" \ + -l Scala \ + -p community.flock.wirespec.generated \ + --shared diff --git a/examples/scala-zio/guru.json b/examples/scala-zio/guru.json new file mode 100644 index 000000000..58672e5fc --- /dev/null +++ b/examples/scala-zio/guru.json @@ -0,0 +1,525 @@ +{ + "openapi": "3.0.0", + "x-optic-url": "https://app.useoptic.com/organizations/febf8ac6-ee67-4565-b45a-5c85a469dca7/apis/_0fKWqUvhs9ssYNkq1k-c", + "x-optic-standard": "@febf8ac6-ee67-4565-b45a-5c85a469dca7/Fz6KU3_wMIO5iJ6_VUZ30", + "info": { + "version": "2.2.0", + "title": "APIs.guru", + "description": "Wikipedia for Web APIs. Repository of API definitions in OpenAPI format.\n**Warning**: If you want to be notified about changes in advance please join our [Slack channel](https://join.slack.com/t/mermade/shared_invite/zt-g78g7xir-MLE_CTCcXCdfJfG3CJe9qA).\nClient sample: [[Demo]](https://apis.guru/simple-ui) [[Repo]](https://github.com/APIs-guru/simple-ui)\n", + "contact": { + "name": "APIs.guru", + "url": "https://APIs.guru", + "email": "mike.ralphson@gmail.com" + }, + "license": { + "name": "CC0 1.0", + "url": "https://github.com/APIs-guru/openapi-directory#licenses" + }, + "x-logo": { + "url": "https://apis.guru/branding/logo_vertical.svg" + } + }, + "externalDocs": { + "url": "https://github.com/APIs-guru/openapi-directory/blob/master/API.md" + }, + "servers": [ + { + "url": "https://api.apis.guru/v2" + } + ], + "security": [], + "tags": [ + { + "name": "APIs", + "description": "Actions relating to APIs in the collection" + } + ], + "paths": { + "/providers.json": { + "get": { + "operationId": "getProviders", + "tags": [ + "APIs" + ], + "summary": "List all providers", + "description": "List all the providers in the directory\n", + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "type": "string", + "minLength": 1 + }, + "minItems": 1 + } + } + } + } + } + } + } + } + }, + "/{provider}.json": { + "get": { + "operationId": "getProvider", + "tags": [ + "APIs" + ], + "summary": "List all APIs for a particular provider", + "description": "List all APIs in the directory for a particular providerName\nReturns links to the individual API entry for each API.\n", + "parameters": [ + { + "$ref": "#/components/parameters/provider" + } + ], + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/APIs" + } + } + } + } + } + } + }, + "/{provider}/services.json": { + "get": { + "operationId": "getServices", + "tags": [ + "APIs" + ], + "summary": "List all serviceNames for a particular provider", + "description": "List all serviceNames in the directory for a particular providerName\n", + "parameters": [ + { + "$ref": "#/components/parameters/provider" + } + ], + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "type": "string", + "minLength": 0 + }, + "minItems": 1 + } + } + } + } + } + } + } + } + }, + "/specs/{provider}/{api}.json": { + "get": { + "operationId": "getAPI", + "tags": [ + "APIs" + ], + "summary": "Retrieve one version of a particular API", + "description": "Returns the API entry for one specific version of an API where there is no serviceName.", + "parameters": [ + { + "$ref": "#/components/parameters/provider" + }, + { + "$ref": "#/components/parameters/api" + } + ], + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/API" + } + } + } + } + } + } + }, + "/specs/{provider}/{service}/{api}.json": { + "get": { + "operationId": "getServiceAPI", + "tags": [ + "APIs" + ], + "summary": "Retrieve one version of a particular API with a serviceName.", + "description": "Returns the API entry for one specific version of an API where there is a serviceName.", + "parameters": [ + { + "$ref": "#/components/parameters/provider" + }, + { + "name": "service", + "in": "path", + "required": true, + "schema": { + "type": "string", + "minLength": 1, + "maxLength": 255, + "example": "graph" + } + }, + { + "$ref": "#/components/parameters/api" + } + ], + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/API" + } + } + } + } + } + } + }, + "/list.json": { + "get": { + "operationId": "listAPIs", + "tags": [ + "APIs" + ], + "summary": "List all APIs", + "description": "List all APIs in the directory.\nReturns links to the OpenAPI definitions for each API in the directory.\nIf API exist in multiple versions `preferred` one is explicitly marked.\nSome basic info from the OpenAPI definition is cached inside each object.\nThis allows you to generate some simple views without needing to fetch the OpenAPI definition for each API.\n", + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/APIs" + } + } + } + } + } + } + }, + "/metrics.json": { + "get": { + "operationId": "getMetrics", + "summary": "Get basic metrics", + "description": "Some basic metrics for the entire directory.\nJust stunning numbers to put on a front page and are intended purely for WoW effect :)\n", + "tags": [ + "APIs" + ], + "responses": { + "200": { + "description": "OK", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Metrics" + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "APIs": { + "description": "List of API details.\nIt is a JSON object with API IDs(`[:]`) as keys.\n", + "type": "object", + "additionalProperties": { + "$ref": "#/components/schemas/API" + }, + "minProperties": 1, + "example": { + "googleapis.com:drive": { + "added": "2015-02-22T20:00:45.000Z", + "preferred": "v3", + "versions": { + "v2": { + "added": "2015-02-22T20:00:45.000Z", + "info": { + "title": "Drive", + "version": "v2", + "x-apiClientRegistration": { + "url": "https://console.developers.google.com" + }, + "x-logo": { + "url": "https://api.apis.guru/v2/cache/logo/https_www.gstatic.com_images_icons_material_product_2x_drive_32dp.png" + }, + "x-origin": { + "format": "google", + "url": "https://www.googleapis.com/discovery/v1/apis/drive/v2/rest", + "version": "v1" + }, + "x-preferred": false, + "x-providerName": "googleapis.com", + "x-serviceName": "drive" + }, + "swaggerUrl": "https://api.apis.guru/v2/specs/googleapis.com/drive/v2/swagger.json", + "swaggerYamlUrl": "https://api.apis.guru/v2/specs/googleapis.com/drive/v2/swagger.yaml", + "updated": "2016-06-17T00:21:44.000Z" + }, + "v3": { + "added": "2015-12-12T00:25:13.000Z", + "info": { + "title": "Drive", + "version": "v3", + "x-apiClientRegistration": { + "url": "https://console.developers.google.com" + }, + "x-logo": { + "url": "https://api.apis.guru/v2/cache/logo/https_www.gstatic.com_images_icons_material_product_2x_drive_32dp.png" + }, + "x-origin": { + "format": "google", + "url": "https://www.googleapis.com/discovery/v1/apis/drive/v3/rest", + "version": "v1" + }, + "x-preferred": true, + "x-providerName": "googleapis.com", + "x-serviceName": "drive" + }, + "swaggerUrl": "https://api.apis.guru/v2/specs/googleapis.com/drive/v3/swagger.json", + "swaggerYamlUrl": "https://api.apis.guru/v2/specs/googleapis.com/drive/v3/swagger.yaml", + "updated": "2016-06-17T00:21:44.000Z" + } + } + } + } + }, + "API": { + "description": "Meta information about API", + "type": "object", + "required": [ + "added", + "preferred", + "versions" + ], + "properties": { + "added": { + "description": "Timestamp when the API was first added to the directory", + "type": "string", + "format": "date-time" + }, + "preferred": { + "description": "Recommended version", + "type": "string" + }, + "versions": { + "description": "List of supported versions of the API", + "type": "object", + "additionalProperties": { + "$ref": "#/components/schemas/ApiVersion" + }, + "minProperties": 1 + } + }, + "additionalProperties": false + }, + "ApiVersion": { + "type": "object", + "required": [ + "added", + "updated", + "swaggerUrl", + "swaggerYamlUrl", + "info", + "openapiVer" + ], + "properties": { + "added": { + "description": "Timestamp when the version was added", + "type": "string", + "format": "date-time" + }, + "updated": { + "description": "Timestamp when the version was updated", + "type": "string", + "format": "date-time" + }, + "swaggerUrl": { + "description": "URL to OpenAPI definition in JSON format", + "type": "string", + "format": "url" + }, + "swaggerYamlUrl": { + "description": "URL to OpenAPI definition in YAML format", + "type": "string", + "format": "url" + }, + "link": { + "description": "Link to the individual API entry for this API", + "type": "string", + "format": "url" + }, + "info": { + "description": "Copy of `info` section from OpenAPI definition", + "type": "object", + "minProperties": 1 + }, + "externalDocs": { + "description": "Copy of `externalDocs` section from OpenAPI definition", + "type": "object", + "minProperties": 1 + }, + "openapiVer": { + "description": "The value of the `openapi` or `swagger` property of the source definition", + "type": "string" + } + }, + "additionalProperties": false + }, + "Metrics": { + "description": "List of basic metrics", + "type": "object", + "required": [ + "numSpecs", + "numAPIs", + "numEndpoints" + ], + "properties": { + "numSpecs": { + "description": "Number of API definitions including different versions of the same API", + "type": "integer", + "minimum": 1 + }, + "numAPIs": { + "description": "Number of unique APIs", + "type": "integer", + "minimum": 1 + }, + "numEndpoints": { + "description": "Total number of endpoints inside all definitions", + "type": "integer", + "minimum": 1 + }, + "unreachable": { + "description": "Number of unreachable (4XX,5XX status) APIs", + "type": "integer" + }, + "invalid": { + "description": "Number of newly invalid APIs", + "type": "integer" + }, + "unofficial": { + "description": "Number of unofficial APIs", + "type": "integer" + }, + "fixes": { + "description": "Total number of fixes applied across all APIs", + "type": "integer" + }, + "fixedPct": { + "description": "Percentage of all APIs where auto fixes have been applied", + "type": "integer" + }, + "datasets": { + "description": "Data used for charting etc", + "type": "array", + "items": {} + }, + "stars": { + "description": "GitHub stars for our main repo", + "type": "integer" + }, + "issues": { + "description": "Open GitHub issues on our main repo", + "type": "integer" + }, + "thisWeek": { + "description": "Summary totals for the last 7 days", + "type": "object", + "properties": { + "added": { + "description": "APIs added in the last week", + "type": "integer" + }, + "updated": { + "description": "APIs updated in the last week", + "type": "integer" + } + } + }, + "numDrivers": { + "description": "Number of methods of API retrieval", + "type": "integer" + }, + "numProviders": { + "description": "Number of API providers in directory", + "type": "integer" + } + }, + "additionalProperties": false, + "example": { + "numAPIs": 2501, + "numEndpoints": 106448, + "numSpecs": 3329, + "unreachable": 123, + "invalid": 598, + "unofficial": 25, + "fixes": 81119, + "fixedPct": 22, + "datasets": [], + "stars": 2429, + "issues": 28, + "thisWeek": { + "added": 45, + "updated": 171 + }, + "numDrivers": 10, + "numProviders": 659 + } + } + }, + "parameters": { + "provider": { + "name": "provider", + "in": "path", + "required": true, + "schema": { + "type": "string", + "minLength": 1, + "maxLength": 255, + "example": "apis.guru" + } + }, + "api": { + "name": "api", + "in": "path", + "required": true, + "schema": { + "type": "string", + "minLength": 1, + "maxLength": 255, + "example": "2.1.0" + } + } + } + } +} \ No newline at end of file diff --git a/examples/scala-zio/project/build.properties b/examples/scala-zio/project/build.properties new file mode 100644 index 000000000..73df629ac --- /dev/null +++ b/examples/scala-zio/project/build.properties @@ -0,0 +1 @@ +sbt.version=1.10.7 diff --git a/examples/scala-zio/src/main/scala/example/CirceSerialization.scala b/examples/scala-zio/src/main/scala/example/CirceSerialization.scala new file mode 100644 index 000000000..85e9f3797 --- /dev/null +++ b/examples/scala-zio/src/main/scala/example/CirceSerialization.scala @@ -0,0 +1,67 @@ +package example + +import community.flock.wirespec.scala.Wirespec +import community.flock.wirespec.generated.model.* +import io.circe.* +import io.circe.generic.semiauto.deriveCodec +import io.circe.parser.decode +import io.circe.syntax.* + +import scala.reflect.ClassTag + +object CirceSerialization extends Wirespec.Serialization { + + private given Codec.AsObject[MetricsDatasetsArray] = deriveCodec + private given Codec.AsObject[ApiVersionInfo] = deriveCodec + private given Codec.AsObject[ApiVersionExternalDocs] = deriveCodec + private given Codec.AsObject[APIVersionsInfo] = deriveCodec + private given Codec.AsObject[APIVersionsExternalDocs] = deriveCodec + private given Codec.AsObject[MetricsThisWeek] = deriveCodec + private given Codec.AsObject[ApiVersion] = deriveCodec + private given Codec.AsObject[APIVersions] = deriveCodec + private given Codec.AsObject[API] = deriveCodec + private given Codec.AsObject[Metrics] = deriveCodec + private given Codec.AsObject[GetProviders200ResponseBody] = deriveCodec + private given Codec.AsObject[GetServices200ResponseBody] = deriveCodec + + private def codec[T: Encoder: Decoder: ClassTag]: (Class[?], (Encoder[?], Decoder[?])) = + summon[ClassTag[T]].runtimeClass -> (summon[Encoder[T]], summon[Decoder[T]]) + + private val codecRegistry: Map[Class[?], (Encoder[?], Decoder[?])] = Map( + codec[Metrics], + codec[API], + codec[GetProviders200ResponseBody], + codec[GetServices200ResponseBody], + classOf[Map[?, ?]] -> (Encoder.encodeMap[String, API].asInstanceOf[Encoder[?]], Decoder.decodeMap[String, API].asInstanceOf[Decoder[?]]), + ) + + override def serializeBody[T](t: T, `type`: ClassTag[?]): Array[Byte] = { + val cls = if (t.isInstanceOf[Map[?, ?]]) classOf[Map[?, ?]] else t.getClass + val (encoder, _) = codecRegistry.getOrElse(cls, throw new IllegalStateException(s"No encoder for ${t.getClass}")) + encoder.asInstanceOf[Encoder[T]].apply(t).noSpaces.getBytes("UTF-8") + } + + override def deserializeBody[T](raw: Array[Byte], `type`: ClassTag[?]): T = { + val cls = `type`.runtimeClass + val key = if (classOf[Map[?, ?]].isAssignableFrom(cls)) classOf[Map[?, ?]] else cls + val (_, decoder) = codecRegistry.getOrElse(key, throw new IllegalStateException(s"No decoder for $cls")) + decode(new String(raw, "UTF-8"))(using decoder.asInstanceOf[Decoder[T]]).fold(throw _, identity) + } + + override def serializePath[T](t: T, `type`: ClassTag[?]): String = t.toString + + override def deserializePath[T](raw: String, `type`: ClassTag[?]): T = { + val cls = `type`.runtimeClass + (if (cls == classOf[String]) raw + else if (cls == classOf[java.lang.Long] || cls == classOf[Long]) raw.toLong + else if (cls == classOf[java.lang.Integer] || cls == classOf[Int]) raw.toInt + else if (cls == classOf[java.lang.Boolean] || cls == classOf[Boolean]) raw.toBoolean + else if (cls == classOf[java.lang.Double] || cls == classOf[Double]) raw.toDouble + else throw new IllegalStateException(s"Cannot deserialize path for $cls")).asInstanceOf[T] + } + + override def serializeParam[T](value: T, `type`: ClassTag[?]): List[String] = List(value.toString) + + override def deserializeParam[T](values: List[String], `type`: ClassTag[?]): T = + deserializePath(values.headOption.getOrElse(throw new IllegalStateException("Empty param list")), `type`) +} diff --git a/examples/scala-zio/src/main/scala/example/GuruClient.scala b/examples/scala-zio/src/main/scala/example/GuruClient.scala new file mode 100644 index 000000000..47eb2b1a7 --- /dev/null +++ b/examples/scala-zio/src/main/scala/example/GuruClient.scala @@ -0,0 +1,33 @@ +package example + +import community.flock.wirespec.scala.Wirespec +import community.flock.wirespec.generated.endpoint.* +import community.flock.wirespec.generated.model.* +import zio.* +import zio.http.{Client as ZClient, *} + +object GuruClient { + + private val baseUrl = URL.decode("https://api.apis.guru/v2").toOption.get + + extension [Req <: Wirespec.Request[?], Res <: Wirespec.Response[?]](c: Wirespec.Client[Req, Res]) + def call(request: Req): ZIO[ZClient & Scope, Throwable, Res] = { + val edge = c.client(CirceSerialization) + val rawReq = edge.to(request) + val url = baseUrl.copy(path = baseUrl.path ++ Path.decode("/" + rawReq.path.mkString("/"))) + val zioReq = Request( + method = Method.fromString(rawReq.method), + url = url, + body = rawReq.body.map(b => Body.fromArray(b)).getOrElse(Body.empty), + ) + for { + response <- ZClient.request(zioReq) + bodyBytes <- response.body.asArray + rawRes = Wirespec.RawResponse( + statusCode = response.status.code, + headers = Map.empty, + body = Some(bodyBytes) + ) + } yield edge.from(rawRes) + } +} diff --git a/examples/scala-zio/src/main/scala/example/GuruServer.scala b/examples/scala-zio/src/main/scala/example/GuruServer.scala new file mode 100644 index 000000000..bca75199c --- /dev/null +++ b/examples/scala-zio/src/main/scala/example/GuruServer.scala @@ -0,0 +1,87 @@ +package example + +import community.flock.wirespec.generated.endpoint.* +import community.flock.wirespec.generated.model.* +import zio.* +import zio.http.* + +trait GuruHandler extends GetMetrics.Handler[Task] + with GetProviders.Handler[Task] + with GetProvider.Handler[Task] + with GetAPI.Handler[Task] + with GetServices.Handler[Task] + with GetServiceAPI.Handler[Task] + with ListAPIs.Handler[Task] + +class GuruHandlerLive extends GuruHandler { + + override def getMetrics(request: GetMetrics.Request.type): Task[GetMetrics.Response[?]] = + ZIO.succeed(new GetMetrics.Response200(Metrics( + numSpecs = 42, + numAPIs = 10, + numEndpoints = 120, + unreachable = None, + invalid = None, + unofficial = None, + fixes = None, + fixedPct = None, + datasets = None, + stars = None, + issues = None, + thisWeek = None, + numDrivers = None, + numProviders = Some(5) + ))) + + override def getProviders(request: GetProviders.Request.type): Task[GetProviders.Response[?]] = + ZIO.succeed(new GetProviders.Response200(GetProviders200ResponseBody( + data = Some(List("googleapis.com", "azure.com", "amazonaws.com")) + ))) + + override def getProvider(request: GetProvider.Request.type): Task[GetProvider.Response[?]] = + ZIO.succeed(new GetProvider.Response200(Map.empty[String, API])) + + override def getAPI(request: GetAPI.Request): Task[GetAPI.Response[?]] = + ZIO.succeed(new GetAPI.Response200(API( + added = "2023-01-01", + preferred = "v1", + versions = Map.empty + ))) + + override def getServices(request: GetServices.Request): Task[GetServices.Response[?]] = + ZIO.succeed(new GetServices.Response200(GetServices200ResponseBody( + data = Some(List.empty) + ))) + + override def getServiceAPI(request: GetServiceAPI.Request): Task[GetServiceAPI.Response[?]] = + ZIO.succeed(new GetServiceAPI.Response200(API( + added = "2023-01-01", + preferred = "v1", + versions = Map.empty + ))) + + override def listAPIs(request: ListAPIs.Request.type): Task[ListAPIs.Response[?]] = + ZIO.succeed(new ListAPIs.Response200(Map.empty[String, API])) +} + +object GuruServer extends ZIOAppDefault { + + def routes(h: GuruHandler): Routes[Any, Response] = + WirespecRouter(CirceSerialization) + .route(GetMetrics.Server)(h.getMetrics) + .route(GetProviders.Server)(h.getProviders) + .route(ListAPIs.Server)(h.listAPIs) + .route(GetServiceAPI.Server)(h.getServiceAPI) + .route(GetAPI.Server)(h.getAPI) + .route(GetServices.Server)(h.getServices) + .route(GetProvider.Server)(h.getProvider) + .toRoutes + + override val run: ZIO[Any, Any, Any] = { + val h = GuruHandlerLive() + for { + _ <- ZIO.logInfo("Starting Guru API server on port 8080") + _ <- Server.serve(routes(h)) + } yield () + }.provide(Server.defaultWithPort(8080)) +} diff --git a/examples/scala-zio/src/main/scala/example/WirespecRouter.scala b/examples/scala-zio/src/main/scala/example/WirespecRouter.scala new file mode 100644 index 000000000..eb5cfe0cf --- /dev/null +++ b/examples/scala-zio/src/main/scala/example/WirespecRouter.scala @@ -0,0 +1,113 @@ +package example + +import community.flock.wirespec.scala.Wirespec +import zio.* +import zio.http.* + +class WirespecRouter(serialization: Wirespec.Serialization) { + + private case class RegisteredRoute( + server: Wirespec.Server[? <: Wirespec.Request[?], ? <: Wirespec.Response[?]], + handler: Any => Task[Any], + segments: List[SegmentMatcher], + method: String + ) + + private sealed trait SegmentMatcher + private case class LiteralMatcher(value: String) extends SegmentMatcher + private case class WildcardMatcher(name: String) extends SegmentMatcher + private case class SuffixMatcher(name: String, suffix: String) extends SegmentMatcher + + private var registeredRoutes: List[RegisteredRoute] = Nil + + def route[Req <: Wirespec.Request[?], Res <: Wirespec.Response[?]]( + server: Wirespec.Server[Req, Res] + )(handle: Req => Task[Res]): WirespecRouter = { + val segments = parseTemplate(server.pathTemplate) + registeredRoutes = registeredRoutes :+ RegisteredRoute( + server.asInstanceOf[Wirespec.Server[? <: Wirespec.Request[?], ? <: Wirespec.Response[?]]], + handle.asInstanceOf[Any => Task[Any]], + segments, + server.method + ) + this + } + + private def parseTemplate(template: String): List[SegmentMatcher] = { + template.stripPrefix("/").split("/").toList.map { segment => + if (segment.startsWith("{") && segment.endsWith("}")) { + WildcardMatcher(segment.drop(1).dropRight(1)) + } else if (segment.contains("{") && segment.contains("}")) { + val start = segment.indexOf('{') + val end = segment.indexOf('}') + val name = segment.substring(start + 1, end) + val suffix = segment.substring(end + 1) + SuffixMatcher(name, suffix) + } else { + LiteralMatcher(segment) + } + } + } + + private def matchPath(matchers: List[SegmentMatcher], path: List[String]): Boolean = { + if (matchers.length != path.length) return false + matchers.zip(path).forall { + case (LiteralMatcher(expected), actual) => expected == actual + case (WildcardMatcher(_), _) => true + case (SuffixMatcher(_, suffix), actual) => actual.endsWith(suffix) + } + } + + private def specificity(segments: List[SegmentMatcher]): (Int, Int, Int) = { + val literals = segments.count(_.isInstanceOf[LiteralMatcher]) + val suffixed = segments.count(_.isInstanceOf[SuffixMatcher]) + val length = segments.length + (-literals, -suffixed, -length) + } + + private def pathSegments(req: Request): List[String] = + req.url.path.segments.toList + + def toRoutes: Routes[Any, Response] = { + val sorted = registeredRoutes.sortBy(r => specificity(r.segments)) + + Routes( + Method.ANY / trailing -> Handler.fromFunctionZIO[(Path, Request)] { case (_, req: Request) => + val path = pathSegments(req) + val method = req.method.name + sorted.find(r => r.method == method && matchPath(r.segments, path)) match { + case Some(r) => + val edge = r.server.asInstanceOf[Wirespec.Server[Wirespec.Request[Any], Wirespec.Response[Any]]] + .server(serialization) + (for { + bodyBytes <- req.body.asArray + rawReq = toRawRequest(req, bodyBytes, path) + typedReq = edge.from(rawReq) + typedRes <- r.handler(typedReq) + rawRes = edge.to(typedRes.asInstanceOf[Wirespec.Response[Any]]) + } yield toZioResponse(rawRes)).mapError(e => Response.internalServerError(e.getMessage)) + case None => + ZIO.succeed(Response.status(Status.NotFound)) + } + } + ) + } + + private def toRawRequest(req: Request, bodyBytes: Array[Byte], path: List[String]): Wirespec.RawRequest = + Wirespec.RawRequest( + method = req.method.name, + path = path, + queries = req.url.queryParams.map.view.mapValues(_.toList).toMap, + headers = req.headers.toList.map(h => h.headerName -> List(h.renderedValue)).toMap, + body = if (bodyBytes.isEmpty) None else Some(bodyBytes) + ) + + private def toZioResponse(rawRes: Wirespec.RawResponse): Response = { + val headers = Headers(rawRes.headers.flatMap { case (k, vs) => vs.map(v => Header.Custom(k, v)) }.toList) + Response( + status = Status.fromInt(rawRes.statusCode), + headers = headers ++ Headers(Header.ContentType(MediaType.application.json)), + body = rawRes.body.map(b => Body.fromArray(b)).getOrElse(Body.empty) + ) + } +} diff --git a/examples/scala-zio/src/test/scala/example/GuruClientSpec.scala b/examples/scala-zio/src/test/scala/example/GuruClientSpec.scala new file mode 100644 index 000000000..44078f126 --- /dev/null +++ b/examples/scala-zio/src/test/scala/example/GuruClientSpec.scala @@ -0,0 +1,89 @@ +package example + +import community.flock.wirespec.generated.endpoint.* +import community.flock.wirespec.generated.model.* +import example.GuruClient.call +import zio.* +import zio.http.* +import zio.test.* + +object GuruClientSpec extends ZIOSpecDefault { + + override def spec: Spec[TestEnvironment & Scope, Any] = suite("GuruClientSpec")( + test("getMetrics returns parsed Metrics via HTTP") { + val json = """{"numSpecs":3329,"numAPIs":2501,"numEndpoints":106448}""" + for { + _ <- TestClient.addRoutes(Routes( + Method.GET / "v2" / "metrics.json" -> handler(Response.json(json)) + )) + result <- GetMetrics.Client.call(GetMetrics.Request) + } yield { + val r200 = result.asInstanceOf[GetMetrics.Response200] + assertTrue( + r200.body.numSpecs == 3329L, + r200.body.numAPIs == 2501L, + r200.body.numEndpoints == 106448L, + r200.body.unreachable.isEmpty, + r200.body.stars.isEmpty, + ) + } + }.provide(TestClient.layer, Scope.default), + test("getMetrics deserializes optional fields") { + val json = + """{ + |"numSpecs":100,"numAPIs":50,"numEndpoints":500, + |"unreachable":5,"stars":1000, + |"thisWeek":{"added":10,"updated":20}, + |"numProviders":42 + |}""".stripMargin.replaceAll("\\n", "") + for { + _ <- TestClient.addRoutes(Routes( + Method.GET / "v2" / "metrics.json" -> handler(Response.json(json)) + )) + result <- GetMetrics.Client.call(GetMetrics.Request) + } yield { + val r200 = result.asInstanceOf[GetMetrics.Response200] + assertTrue( + r200.body.numSpecs == 100L, + r200.body.numAPIs == 50L, + r200.body.numEndpoints == 500L, + r200.body.unreachable == Some(5L), + r200.body.stars == Some(1000L), + r200.body.thisWeek.flatMap(_.added) == Some(10L), + r200.body.numProviders == Some(42L), + ) + } + }.provide(TestClient.layer, Scope.default), + test("getProviders returns parsed provider list via HTTP") { + val json = """{"data":["amazonaws.com","googleapis.com","azure.com"]}""" + for { + _ <- TestClient.addRoutes(Routes( + Method.GET / "v2" / "providers.json" -> handler(Response.json(json)) + )) + result <- GetProviders.Client.call(GetProviders.Request) + } yield { + val r200 = result.asInstanceOf[GetProviders.Response200] + assertTrue( + r200.body.data.isDefined, + r200.body.data.get.length == 3, + r200.body.data.get.head == "amazonaws.com", + ) + } + }.provide(TestClient.layer, Scope.default), + test("listAPIs returns parsed API map via HTTP") { + val json = """{"test.api":{"added":"2024-01-01T00:00:00Z","preferred":"v1","versions":{}}}""" + for { + _ <- TestClient.addRoutes(Routes( + Method.GET / "v2" / "list.json" -> handler(Response.json(json)) + )) + result <- ListAPIs.Client.call(ListAPIs.Request) + } yield { + val r200 = result.asInstanceOf[ListAPIs.Response200] + assertTrue( + r200.body.contains("test.api"), + r200.body("test.api").preferred == "v1", + ) + } + }.provide(TestClient.layer, Scope.default), + ) +} diff --git a/examples/scala-zio/src/test/scala/example/GuruServerSpec.scala b/examples/scala-zio/src/test/scala/example/GuruServerSpec.scala new file mode 100644 index 000000000..a425d18dd --- /dev/null +++ b/examples/scala-zio/src/test/scala/example/GuruServerSpec.scala @@ -0,0 +1,128 @@ +package example + +import community.flock.wirespec.generated.endpoint.* +import community.flock.wirespec.generated.model.* +import zio.* +import zio.http.* +import zio.test.* + +object GuruServerSpec extends ZIOSpecDefault { + + private val handler = GuruHandlerLive() + private val routes = GuruServer.routes(handler) + + private def get(path: String): ZIO[Any, Response, Response] = + routes.runZIO(Request.get(URL.decode(path).toOption.get)) + + private def bodyString(response: Response): ZIO[Any, Throwable, String] = + response.body.asString + + override def spec: Spec[TestEnvironment & Scope, Any] = suite("GuruServerSpec")( + suite("literal path routes")( + test("GET /metrics.json returns 200 with metrics") { + for { + response <- get("/metrics.json") + body <- bodyString(response) + } yield assertTrue( + response.status == Status.Ok, + body.contains("\"numSpecs\":42"), + body.contains("\"numProviders\":5"), + ) + }, + test("GET /providers.json returns 200 with providers") { + for { + response <- get("/providers.json") + body <- bodyString(response) + } yield assertTrue( + response.status == Status.Ok, + body.contains("googleapis.com"), + body.contains("azure.com"), + body.contains("amazonaws.com"), + ) + }, + test("GET /list.json returns 200 with API list") { + for { + response <- get("/list.json") + body <- bodyString(response) + } yield assertTrue( + response.status == Status.Ok, + body == "{}", + ) + }, + ), + suite("parameterized path routes")( + test("GET /{provider}/services.json routes to GetServices") { + for { + response <- get("/googleapis.com/services.json") + body <- bodyString(response) + } yield assertTrue( + response.status == Status.Ok, + body.contains("\"data\":[]"), + ) + }, + test("GET /specs/{provider}/{api}.json routes to GetAPI") { + for { + response <- get("/specs/googleapis.com/v1.json") + body <- bodyString(response) + } yield assertTrue( + response.status == Status.Ok, + body.contains("\"preferred\":\"v1\""), + body.contains("\"added\":\"2023-01-01\""), + ) + }, + test("GET /specs/{provider}/{service}/{api}.json routes to GetServiceAPI") { + for { + response <- get("/specs/googleapis.com/compute/v1.json") + body <- bodyString(response) + } yield assertTrue( + response.status == Status.Ok, + body.contains("\"preferred\":\"v1\""), + ) + }, + test("GET /{provider}.json routes to GetProvider") { + for { + response <- get("/googleapis.com.json") + body <- bodyString(response) + } yield assertTrue( + response.status == Status.Ok, + body == "{}", + ) + }, + ), + suite("route specificity")( + test("/{provider}/services.json matches GetServices, not GetProvider") { + for { + response <- get("/googleapis.com/services.json") + body <- bodyString(response) + } yield assertTrue( + response.status == Status.Ok, + body.contains("\"data\""), + ) + }, + test("/specs/{provider}/{service}/{api}.json matches GetServiceAPI over GetAPI") { + for { + r3 <- get("/specs/googleapis.com/compute/v1.json") + r2 <- get("/specs/googleapis.com/v1.json") + b3 <- bodyString(r3) + b2 <- bodyString(r2) + } yield assertTrue( + r3.status == Status.Ok, + r2.status == Status.Ok, + b3 == b2, + ) + }, + ), + suite("404 handling")( + test("unmatched path returns 404") { + for { + response <- get("/nonexistent/path/that/does/not/match/anything") + } yield assertTrue(response.status == Status.NotFound) + }, + test("POST to GET-only route returns 404") { + for { + response <- routes.runZIO(Request(method = Method.POST, url = URL.decode("/metrics.json").toOption.get)) + } yield assertTrue(response.status == Status.NotFound) + }, + ), + ) +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 7b1672420..d40420c09 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -31,6 +31,7 @@ spotless = "7.2.1" spring_boot = "3.3.4" spring_webflux = "6.1.13" spring_dependency_management = "1.1.6" +testcontainers = "2.0.3" wiremock = "3.3.1" [libraries] @@ -40,6 +41,7 @@ jackson_databind = { module = "com.fasterxml.jackson.core:jackson-databind", ver jackson_jdk8 = { module = "com.fasterxml.jackson.datatype:jackson-datatype-jdk8", version.ref = "jackson" } jackson_kotlin = { module = "com.fasterxml.jackson.module:jackson-module-kotlin", version.ref = "jackson" } junit_launcher = { module = "org.junit.platform:junit-platform-launcher", version.ref = "junit_launcher" } +kotest_runner_junit5 = { module = "io.kotest:kotest-runner-junit5", version.ref = "kotest" } kafka_avro = { module = "io.confluent:kafka-avro-serializer", version.ref ="kafka_avro"} kotest_assertions = { module = "io.kotest:kotest-assertions-core", version.ref = "kotest" } kotest_assertions_arrow = { module = "io.kotest:kotest-assertions-arrow", version.ref = "kotest" } @@ -65,6 +67,7 @@ nexus_publish = { module = "io.github.gradle-nexus.publish-plugin:io.github.grad spring_boot_web = { module = "org.springframework.boot:spring-boot-starter-web", version.ref = "spring_boot" } spring_webflux = { module = "org.springframework:spring-webflux", version.ref = "spring_webflux" } spring_boot_test = { module = "org.springframework.boot:spring-boot-starter-test", version.ref = "spring_boot" } +testcontainers = { module = "org.testcontainers:testcontainers", version.ref = "testcontainers" } wiremock = { module = "org.wiremock:wiremock-standalone", version.ref = "wiremock" } spotless = { module = "com.diffplug.spotless:spotless-plugin-gradle", version.ref = "spotless" } diff --git a/scripts/test.sh b/scripts/test.sh index 44b78de6a..93853455c 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -58,4 +58,4 @@ dockerCommand="" for lang in "${languages[@]}"; do dockerCommand="$dockerCommand $(run "$dockerWirespec" '/app' 'docker' "$lang") && " done -docker run $archSpecific --rm -it -v "$localWorkDir"/types:/app/types wirespec "$dockerCommand$done" +docker run $archSpecific --rm -v "$localWorkDir"/types:/app/types wirespec "$dockerCommand$done" diff --git a/scripts/verify.sh b/scripts/verify.sh index 84b069465..a8c3564eb 100755 --- a/scripts/verify.sh +++ b/scripts/verify.sh @@ -1,19 +1 @@ -#!/usr/bin/env bash - -dir="$(dirname -- "$0")" -root="$dir/.." - -output="$root/types/out" - -archSpecific="" -if [[ $(uname -m) = arm64 ]]; then - archSpecific="--platform=linux/amd64" -fi - -# Compare output directories for same content and copy one of them to a 'combined' dir. -# Then that combined directory serves as a single input for the type checkers. -diff -qr "$output/docker/" "$output/jvm/" --exclude='*.jar' && \ -diff -qr "$output/jvm/" "$output/native/" --exclude='*.jar' && \ -diff -qr "$output/native/" "$output/node/" --exclude='*.jar' && \ -cp -r "$output/jvm/." "$output/combined" && \ -docker run $archSpecific --rm -it -v ./types/:/app/types wirespec /app/compileTypes.sh +./gradlew :src:verify:test diff --git a/settings.gradle.kts b/settings.gradle.kts index b8f620937..e7322e693 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -26,6 +26,8 @@ include( "src:compiler:emitters:java", "src:compiler:emitters:typescript", "src:compiler:emitters:python", + "src:compiler:emitters:rust", + "src:compiler:emitters:scala", "src:compiler:emitters:wirespec", "src:ide:intellij-plugin", "src:plugin:arguments", @@ -41,4 +43,6 @@ include( "src:integration:wirespec", "src:integration:spring", "src:tools:generator", + "src:compiler:ir", + "src:verify", ) diff --git a/src/compiler/core/src/commonMain/kotlin/community/flock/wirespec/compiler/core/emit/FileExtension.kt b/src/compiler/core/src/commonMain/kotlin/community/flock/wirespec/compiler/core/emit/FileExtension.kt index 0ced0ab8b..88fd10a1b 100644 --- a/src/compiler/core/src/commonMain/kotlin/community/flock/wirespec/compiler/core/emit/FileExtension.kt +++ b/src/compiler/core/src/commonMain/kotlin/community/flock/wirespec/compiler/core/emit/FileExtension.kt @@ -7,6 +7,8 @@ enum class FileExtension(override val value: String) : Value { Kotlin("kt"), TypeScript("ts"), Python("py"), + Rust("rs"), + Scala("scala"), Wirespec("ws"), JSON("json"), YAML("yaml"), diff --git a/src/compiler/core/src/commonMain/kotlin/community/flock/wirespec/compiler/core/parse/ast/Definition.kt b/src/compiler/core/src/commonMain/kotlin/community/flock/wirespec/compiler/core/parse/ast/Definition.kt index 7e9c934be..ec4472c6b 100644 --- a/src/compiler/core/src/commonMain/kotlin/community/flock/wirespec/compiler/core/parse/ast/Definition.kt +++ b/src/compiler/core/src/commonMain/kotlin/community/flock/wirespec/compiler/core/parse/ast/Definition.kt @@ -8,6 +8,10 @@ sealed interface Definition : val identifier: Identifier } +data class Shared( + val packageString: String, +) : Node + data class Field( override val annotations: List, val identifier: FieldIdentifier, diff --git a/src/compiler/emitters/java/build.gradle.kts b/src/compiler/emitters/java/build.gradle.kts index 217e6b788..68b5de636 100644 --- a/src/compiler/emitters/java/build.gradle.kts +++ b/src/compiler/emitters/java/build.gradle.kts @@ -41,6 +41,7 @@ kotlin { commonMain { dependencies { api(project(":src:compiler:core")) + api(project(":src:compiler:ir")) } } commonTest { diff --git a/src/compiler/emitters/java/src/commonMain/kotlin/community/flock/wirespec/emitters/java/JavaIrEmitter.kt b/src/compiler/emitters/java/src/commonMain/kotlin/community/flock/wirespec/emitters/java/JavaIrEmitter.kt new file mode 100644 index 000000000..699c48d1d --- /dev/null +++ b/src/compiler/emitters/java/src/commonMain/kotlin/community/flock/wirespec/emitters/java/JavaIrEmitter.kt @@ -0,0 +1,533 @@ +package community.flock.wirespec.emitters.java + +import arrow.core.NonEmptyList +import community.flock.wirespec.compiler.core.emit.DEFAULT_GENERATED_PACKAGE_STRING +import community.flock.wirespec.compiler.core.emit.DEFAULT_SHARED_PACKAGE_STRING +import community.flock.wirespec.compiler.core.emit.EmitShared +import community.flock.wirespec.compiler.core.emit.FileExtension +import community.flock.wirespec.compiler.core.emit.HasPackageName +import community.flock.wirespec.compiler.core.emit.Keywords +import community.flock.wirespec.compiler.core.emit.LanguageEmitter.Companion.firstToUpper +import community.flock.wirespec.compiler.core.emit.LanguageEmitter.Companion.needImports +import community.flock.wirespec.compiler.core.emit.PackageName +import community.flock.wirespec.compiler.core.emit.Shared +import community.flock.wirespec.compiler.core.emit.importReferences +import community.flock.wirespec.compiler.core.emit.plus +import community.flock.wirespec.compiler.core.parse.ast.Channel +import community.flock.wirespec.compiler.core.parse.ast.Definition +import community.flock.wirespec.compiler.core.parse.ast.Endpoint +import community.flock.wirespec.compiler.core.parse.ast.Enum +import community.flock.wirespec.compiler.core.parse.ast.Module +import community.flock.wirespec.compiler.core.parse.ast.Refined +import community.flock.wirespec.compiler.core.parse.ast.Union +import community.flock.wirespec.compiler.utils.Logger +import community.flock.wirespec.ir.converter.convert +import community.flock.wirespec.ir.converter.convertWithValidation +import community.flock.wirespec.ir.core.Assignment +import community.flock.wirespec.ir.core.Element +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.Interface +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Namespace +import community.flock.wirespec.ir.core.Package +import community.flock.wirespec.ir.core.Precision +import community.flock.wirespec.ir.core.RawElement +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.ReturnStatement +import community.flock.wirespec.ir.core.Struct +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.TypeDescriptor +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.findElement +import community.flock.wirespec.ir.core.function +import community.flock.wirespec.ir.core.import +import community.flock.wirespec.ir.core.withLabelField +import community.flock.wirespec.ir.core.`interface` +import community.flock.wirespec.ir.core.raw +import community.flock.wirespec.ir.core.struct +import community.flock.wirespec.ir.core.transform +import community.flock.wirespec.ir.core.transformChildren +import community.flock.wirespec.ir.emit.IrEmitter +import community.flock.wirespec.ir.generator.JavaGenerator +import community.flock.wirespec.ir.generator.generateJava +import community.flock.wirespec.compiler.core.parse.ast.Shared as AstShared +import community.flock.wirespec.compiler.core.parse.ast.Type as AstType +import community.flock.wirespec.ir.core.Enum as LanguageEnum +import community.flock.wirespec.ir.core.Function as LanguageFunction + +open class JavaIrEmitter( + override val packageName: PackageName = PackageName(DEFAULT_GENERATED_PACKAGE_STRING), + private val emitShared: EmitShared = EmitShared(), +) : IrEmitter, HasPackageName { + + override val generator = JavaGenerator + + override val extension = FileExtension.Java + + override fun transformTestFile(file: File): File = file.transformTypeDescriptors() + + private val wirespecImport = import("$DEFAULT_SHARED_PACKAGE_STRING.java", "Wirespec") + + override val shared = object : Shared { + override val packageString: String = "$DEFAULT_SHARED_PACKAGE_STRING.java" + + private val wirespecShared = AstShared(packageString).convert() + + private val imports = listOf( + import("java.lang.reflect", "Type"), + import("java.lang.reflect", "ParameterizedType"), + import("java.util", "List"), + import("java.util", "Map"), + ) + + private val clientServer = listOf( + `interface`("ServerEdge") { + typeParam(type("Req"), type("Request", Type.Wildcard)) + typeParam(type("Res"), type("Response", Type.Wildcard)) + function("from") { + returnType(type("Req")) + arg("request", type("RawRequest")) + } + function("to") { + returnType(type("RawResponse")) + arg("response", type("Res")) + } + }, + `interface`("ClientEdge") { + typeParam(type("Req"), type("Request", Type.Wildcard)) + typeParam(type("Res"), type("Response", Type.Wildcard)) + function("to") { + returnType(type("RawRequest")) + arg("request", type("Req")) + } + function("from") { + returnType(type("Res")) + arg("response", type("RawResponse")) + } + }, + `interface`("Client") { + typeParam(type("Req"), type("Request", Type.Wildcard)) + typeParam(type("Res"), type("Response", Type.Wildcard)) + function("getPathTemplate") { + returnType(string) + } + function("getMethod") { + returnType(string) + } + function("getClient") { + returnType(type("ClientEdge", type("Req"), type("Res"))) + arg("serialization", type("Serialization")) + } + }, + `interface`("Server") { + typeParam(type("Req"), type("Request", Type.Wildcard)) + typeParam(type("Res"), type("Response", Type.Wildcard)) + function("getPathTemplate") { + returnType(string) + } + function("getMethod") { + returnType(string) + } + function("getServer") { + returnType(type("ServerEdge", type("Req"), type("Res"))) + arg("serialization", type("Serialization")) + } + }, + raw( + """ + |public static Type getType(final Class actualTypeArguments, final Class rawType) { + | if(rawType != null) { + | return new ParameterizedType() { + | public Type getRawType() { return rawType; } + | public Type[] getActualTypeArguments() { return new Class[]{actualTypeArguments}; } + | public Type getOwnerType() { return null; } + | }; + | } + | else { return actualTypeArguments; } + |} + """.trimMargin(), + ), + ) + + private val wirespecFile = wirespecShared + .transform { + matchingElements { file: File -> + val (packageElements, rest) = file.elements.partition { it is Package } + file.copy(elements = packageElements + imports + rest) + } + injectAfter { namespace: Namespace -> + if (namespace.name == Name.of("Wirespec")) clientServer else emptyList() + } + } + + override val source: String = wirespecFile.generateJava() + } + + override fun emit(module: Module, logger: Logger): NonEmptyList { + val files = super.emit(module, logger) + return if (emitShared.value) { + files + File( + Name.of(PackageName("${DEFAULT_SHARED_PACKAGE_STRING}.java").toDir() + "Wirespec"), + listOf(RawElement(shared.source)) + ) + } else { + files + } + } + + override fun emit(definition: Definition, module: Module, logger: Logger): File { + val file = super.emit(definition, module, logger) + val subPackageName = packageName + definition + return File( + name = Name.of(subPackageName.toDir() + file.name.pascalCase().sanitizeSymbol()), + elements = buildList { + add(Package(subPackageName.value)) + if (module.needImports()) add(wirespecImport) + addAll(file.elements) + } + ) + } + + override fun emit(type: AstType, module: Module): File = + type.convertWithValidation(module) + .sanitizeNames() + + override fun emit(enum: Enum, module: Module): File = enum + .convert() + .transform { + matchingElements { languageEnum: LanguageEnum -> + languageEnum.withLabelField( + sanitizeEntry = { it.sanitizeEnum() }, + extraElements = listOf( + function("label") { + returnType(Type.String) + returns(VariableReference(Name.of("label"))) + }, + ), + ) + } + } + .sanitizeNames() + + override fun emit(union: Union): File = union + .convert() + .sanitizeNames() + + override fun emit(refined: Refined): File = refined.convert() + .transform { + matchingElements { s: Struct -> + s.copy( + interfaces = listOf(Type.Custom("Wirespec.Refined")), + elements = listOf( + function("toString", isOverride = true) { + returnType(string) + returns(FunctionCall(receiver = VariableReference(Name.of("value")), name = Name.of("toString"))) + }, + ) + s.elements.map { element -> + if (element is LanguageFunction && element.name == Name.of("validate")) { + element.copy(isOverride = true) + } else element + } + listOf( + function("value", isOverride = true) { + returnType(refined.reference.convert()) + returns(VariableReference(Name.of("value"))) + }, + ), + ) + } + } + .sanitizeNames() + + override fun emit(endpoint: Endpoint): File { + val imports = endpoint.buildImports() + return endpoint.convert() + .sanitizeNames() + .transformTypeDescriptors() + .injectHandleFunction(endpoint) + .let { file -> + if (imports.isNotEmpty()) { + file.transform { + matchingElements { f: File -> + f.copy(elements = imports + f.elements) + } + } + } else { + file + } + } + } + + override fun emit(channel: Channel): File { + val fullyQualifiedPrefix = if (channel.identifier.value == channel.reference.value) { + "${packageName.value}.model." + } else { + "" + } + return channel.convert() + .sanitizeNames() + .transform { + matchingElements { it: Interface -> it.withFullyQualifiedPrefix(fullyQualifiedPrefix) } + matchingElements { file: File -> + val interfaceElement = file.findElement()!! + file.copy(elements = listOf(RawElement("@FunctionalInterface\n"), interfaceElement)) + } + } + } + + override fun emitEndpointClient(endpoint: Endpoint): File { + val imports = endpoint.buildImports() + val endpointImport = import("${packageName.value}.endpoint", endpoint.identifier.value) + val file = super.emitEndpointClient(endpoint).sanitizeNames().transformTypeDescriptors() + val endpointName = endpoint.identifier.value + + val transformedFile = file.transform { + matchingElements { func: LanguageFunction -> + if (func.isAsync && func.body.size >= 2) { + val transportAssign = func.body[func.body.size - 2] + val returnStmt = func.body.last() + if (transportAssign is Assignment && returnStmt is ReturnStatement) { + val bodyPrefix = func.body.dropLast(2) + func.copy( + body = bodyPrefix + ReturnStatement( + FunctionCall( + name = Name.of("thenApply"), + receiver = transportAssign.value, + arguments = mapOf( + Name.of("mapper") to RawExpression( + "rawResponse -> $endpointName.fromRawResponse(serialization(), rawResponse)" + ) + ) + ) + ) + ) + } else func + } else func + } + } + + val subPackageName = packageName + "client" + return File( + name = Name.of(subPackageName.toDir() + transformedFile.name.pascalCase().sanitizeSymbol()), + elements = listOf(Package(subPackageName.value)) + + listOf(wirespecImport) + + imports + + listOf(endpointImport) + + transformedFile.elements + ) + } + + override fun emitClient(endpoints: List, logger: Logger): File { + val imports = endpoints.flatMap { it.importReferences() }.distinctBy { it.value } + .filter { imp -> endpoints.none { it.identifier.value == imp.value } } + .map { import("${packageName.value}.model", it.value) } + val endpointImports = endpoints.map { import("${packageName.value}.endpoint", it.identifier.value) } + val clientImports = endpoints.map { import("${packageName.value}.client", "${it.identifier.value}Client") } + val allImports = imports + endpointImports + clientImports + val file = super.emitClient(endpoints, logger).sanitizeNames() + return File( + name = Name.of(packageName.toDir() + file.name.pascalCase().sanitizeSymbol()), + elements = listOf(Package(packageName.value)) + + listOf(wirespecImport) + + allImports + + file.elements + ) + } + + private fun T.sanitizeNames(): T = transform { + fields { field -> + field.copy(name = field.name.sanitizeName()) + } + parameters { param -> + param.copy(name = Name.of(param.name.camelCase().sanitizeSymbol().sanitizeKeywords())) + } + statementAndExpression { stmt, tr -> + when (stmt) { + is FieldCall -> FieldCall( + receiver = stmt.receiver?.let { tr.transformExpression(it) }, + field = stmt.field.sanitizeName(), + ) + is FunctionCall -> if (stmt.name.value() == "validate") { + stmt.copy(typeArguments = emptyList()).transformChildren(tr) + } else stmt.transformChildren(tr) + else -> stmt.transformChildren(tr) + } + } + } + + private fun Name.sanitizeName(): Name { + val sanitized = if (parts.size > 1) camelCase() else value().sanitizeSymbol() + return Name(listOf(sanitized.sanitizeKeywords())) + } + + private fun String.sanitizeSymbol(): String = this + .split(".", " ", "-") + .mapIndexed { index, s -> if (index > 0) s.firstToUpper() else s } + .joinToString("") + .filter { it.isLetterOrDigit() || it == '_' } + .sanitizeFirstIsDigit() + + private fun String.sanitizeFirstIsDigit() = if (firstOrNull()?.isDigit() == true) "_${this}" else this + + private fun String.sanitizeKeywords() = if (this in reservedKeywords) "_$this" else this + + private fun String.sanitizeEnum() = split("-", ", ", ".", " ", "//") + .joinToString("_") + .sanitizeFirstIsDigit() + .sanitizeKeywords() + + private fun Definition.buildImports() = importReferences() + .filter { identifier.value != it.value } + .map { import("${packageName.value}.model", it.value) } + + private fun T.injectHandleFunction(endpoint: Endpoint): T { + val handlersStruct = buildHandlers(endpoint) + return transform { + matchingElements { iface: Interface -> + if (iface.name == Name.of("Handler")) { + iface.transform { injectAfter { _: Interface -> listOf(handlersStruct) } } + } else { + iface + } + } + } + } + + private fun buildHandlers(endpoint: Endpoint): Struct { + val pathTemplate = "/" + endpoint.path.joinToString("/") { + when (it) { + is Endpoint.Segment.Literal -> it.value + is Endpoint.Segment.Param -> "{${it.identifier.value}}" + } + } + + return struct(name = "Handlers") { + implements( + type("Wirespec.Server", type("Request"), type("Response", Type.Wildcard)) + ) + implements( + type("Wirespec.Client", type("Request"), type("Response", Type.Wildcard)) + ) + function("getPathTemplate", isOverride = true) { + returnType(Type.String) + returns(literal(pathTemplate)) + } + function("getMethod", isOverride = true) { + returnType(Type.String) + returns(literal(endpoint.method.name)) + } + function("getServer", isOverride = true) { + returnType( + type("Wirespec.ServerEdge", type("Request"), type("Response", Type.Wildcard)) + ) + arg("serialization", type("Wirespec.Serialization")) + returns( + RawExpression( + "new Wirespec.ServerEdge<>() {\n" + + "@Override public Request from(Wirespec.RawRequest request) {\n" + + " return fromRawRequest(serialization, request);\n" + + "}\n" + + "@Override public Wirespec.RawResponse to(Response response) {\n" + + " return toRawResponse(serialization, response);\n" + + "}\n" + + "}" + ), + ) + } + function("getClient", isOverride = true) { + returnType( + type("Wirespec.ClientEdge", type("Request"), type("Response", Type.Wildcard)) + ) + arg("serialization", type("Wirespec.Serialization")) + returns( + RawExpression( + "new Wirespec.ClientEdge<>() {\n" + + "@Override public Wirespec.RawRequest to(Request request) {\n" + + " return toRawRequest(serialization, request);\n" + + "}\n" + + "@Override public Response from(Wirespec.RawResponse response) {\n" + + " return fromRawResponse(serialization, response);\n" + + "}\n" + + "}" + ), + ) + } + } + } + + private fun Interface.withFullyQualifiedPrefix(prefix: String): Interface = + if (prefix.isNotEmpty()) { + transform { + parametersWhere( + predicate = { it.name == Name.of("message") }, + transform = { param -> + when (val t = param.type) { + is Type.Custom -> param.copy(type = t.copy(name = prefix + t.name)) + else -> param + } + }, + ) + } + } else { + this + } + + private fun T.transformTypeDescriptors(): T = transform { + statementAndExpression { stmt, tr -> + when (stmt) { + is TypeDescriptor -> { + val rootType = stmt.type.findRoot() + val containerStr = stmt.type.rawContainerClass() + val rootStr = "${rootType.toJavaName()}.class" + val containerArg = containerStr?.let { "$it.class" } ?: "null" + RawExpression("Wirespec.getType($rootStr, $containerArg)") + } + else -> stmt.transformChildren(tr) + } + } + } + + private fun Type.findRoot(): Type = when (this) { + is Type.Nullable -> type.findRoot() + is Type.Array -> elementType.findRoot() + is Type.Dict -> valueType.findRoot() + else -> this + } + + private fun Type.rawContainerClass(): String? = when (this) { + is Type.Nullable -> "java.util.Optional" + is Type.Array -> "java.util.List" + is Type.Dict -> "java.util.Map" + else -> null + } + + private fun Type.toJavaName(): String = when (this) { + is Type.Integer -> when (precision) { Precision.P32 -> "Integer"; Precision.P64 -> "Long" } + is Type.Number -> when (precision) { Precision.P32 -> "Float"; Precision.P64 -> "Double" } + Type.String -> "String" + Type.Boolean -> "Boolean" + Type.Bytes -> "byte[]" + Type.Any -> "Object" + Type.Unit -> "Void" + is Type.Custom -> name + else -> "Object" + } + + companion object : Keywords { + override val reservedKeywords = setOf( + "abstract", "continue", "for", "new", "switch", + "assert", "default", "goto", "package", "synchronized", + "boolean", "do", "if", "private", "this", + "break", "double", "implements", "protected", "throw", + "byte", "else", "import", "public", "throws", + "case", "enum", "instanceof", "return", "transient", + "catch", "extends", "int", "short", "try", + "char", "final", "interface", "static", "void", + "class", "finally", "long", "strictfp", "volatile", + "const", "float", "native", "super", "while", + "true", "false" + ) + } + +} diff --git a/src/compiler/emitters/java/src/commonTest/kotlin/community/flock/wirespec/emitters/java/JavaIrEmitterTest.kt b/src/compiler/emitters/java/src/commonTest/kotlin/community/flock/wirespec/emitters/java/JavaIrEmitterTest.kt new file mode 100644 index 000000000..faae7cecc --- /dev/null +++ b/src/compiler/emitters/java/src/commonTest/kotlin/community/flock/wirespec/emitters/java/JavaIrEmitterTest.kt @@ -0,0 +1,1233 @@ +package community.flock.wirespec.emitters.java + +import arrow.core.nonEmptyListOf +import arrow.core.nonEmptySetOf +import community.flock.wirespec.compiler.core.EmitContext +import community.flock.wirespec.compiler.core.FileUri +import community.flock.wirespec.compiler.core.parse.ast.AST +import community.flock.wirespec.compiler.core.parse.ast.Definition +import community.flock.wirespec.compiler.core.parse.ast.Module +import community.flock.wirespec.compiler.test.CompileChannelTest +import community.flock.wirespec.compiler.test.CompileComplexModelTest +import community.flock.wirespec.compiler.test.CompileEnumTest +import community.flock.wirespec.compiler.test.CompileFullEndpointTest +import community.flock.wirespec.compiler.test.CompileMinimalEndpointTest +import community.flock.wirespec.compiler.test.CompileNestedTypeTest +import community.flock.wirespec.compiler.test.CompileRefinedTest +import community.flock.wirespec.compiler.test.CompileTypeTest +import community.flock.wirespec.compiler.test.CompileUnionTest +import community.flock.wirespec.compiler.test.NodeFixtures +import community.flock.wirespec.compiler.utils.NoLogger +import io.kotest.assertions.arrow.core.shouldBeRight +import io.kotest.matchers.shouldBe +import kotlin.test.Test + +class JavaIrEmitterTest { + + private val emitContext = object : EmitContext, NoLogger { + override val emitters = nonEmptySetOf(JavaIrEmitter()) + } + + @Test + fun testEmitterType() { + val expected = listOf( + """ + |package community.flock.wirespec.generated.model; + |public record Todo ( + | String name, + | java.util.Optional description, + | java.util.List notes, + | Boolean done + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return java.util.List.of(); + | } + |}; + | + """.trimMargin(), + ) + + val res = emitContext.emitFirst(NodeFixtures.type) + res shouldBe expected + } + + @Test + fun testEmitterEmptyType() { + val expected = listOf( + """ + |package community.flock.wirespec.generated.model; + |public record TodoWithoutProperties () implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return java.util.List.of(); + | } + |}; + | + """.trimMargin(), + ) + + val res = emitContext.emitFirst(NodeFixtures.emptyType) + res shouldBe expected + } + + @Test + fun testEmitterRefined() { + val expected = listOf( + """ + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record UUID ( + | String value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return java.util.regex.Pattern.compile("^[0-9a-fA-F]{8}\\b-[0-9a-fA-F]{4}\\b-[0-9a-fA-F]{4}\\b-[0-9a-fA-F]{4}\\b-[0-9a-fA-F]{12}${'$'}").matcher(value).find(); + | } + | @Override + | public String value() { + | return value; + | } + |}; + | + """.trimMargin(), + ) + + val res = emitContext.emitFirst(NodeFixtures.refined) + res shouldBe expected + } + + @Test + fun testEmitterEnum() { + val expected = listOf( + """ + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public enum TodoStatus implements Wirespec.Enum { + | OPEN("OPEN"), + | IN_PROGRESS("IN_PROGRESS"), + | CLOSE("CLOSE"); + | public final String label; + | TodoStatus(String label) { + | this.label = label; + | } + | @Override + | public String toString() { + | return label; + | } + | public String label() { + | return label; + | } + |} + | + """.trimMargin(), + ) + + val res = emitContext.emitFirst(NodeFixtures.enum) + res shouldBe expected + } + + @Test + fun compileFullEndpointTest() { + val java = """ + |package community.flock.wirespec.generated.endpoint; + |import community.flock.wirespec.java.Wirespec; + |import community.flock.wirespec.generated.model.Token; + |import community.flock.wirespec.generated.model.Token; + |import community.flock.wirespec.generated.model.PotentialTodoDto; + |import community.flock.wirespec.generated.model.TodoDto; + |import community.flock.wirespec.generated.model.Error; + |public interface PutTodo extends Wirespec.Endpoint { + | public static record Path ( + | String id + | ) implements Wirespec.Path { + | }; + | public static record Queries ( + | Boolean done, + | java.util.Optional name + | ) implements Wirespec.Queries { + | }; + | public static record RequestHeaders ( + | Token token, + | java.util.Optional refreshToken + | ) implements Wirespec.Request.Headers { + | }; + | public static record Request ( + | Path path, + | Wirespec.Method method, + | Queries queries, + | RequestHeaders headers, + | PotentialTodoDto body + | ) implements Wirespec.Request { + | public Request(String id, Boolean done, java.util.Optional name, Token token, java.util.Optional refreshToken, PotentialTodoDto body) { + | this(new Path(id), Wirespec.Method.PUT, new Queries( + | done, + | name + | ), new RequestHeaders( + | token, + | refreshToken + | ), body); + | } + | }; + | public sealed interface Response extends Wirespec.Response permits Response2XX, Response5XX, ResponseTodoDto, ResponseError {} + | public sealed interface Response2XX extends Response permits Response200, Response201 {} + | public sealed interface Response5XX extends Response permits Response500 {} + | public sealed interface ResponseTodoDto extends Response permits Response200, Response201 {} + | public sealed interface ResponseError extends Response permits Response500 {} + | public static record Response200 ( + | Integer status, + | Headers headers, + | TodoDto body + | ) implements Response2XX, ResponseTodoDto { + | public Response200(TodoDto body) { + | this(200, new Headers(), body); + | } + | public static record Headers () implements Wirespec.Response.Headers { + | }; + | }; + | public static record Response201 ( + | Integer status, + | Headers headers, + | TodoDto body + | ) implements Response2XX, ResponseTodoDto { + | public Response201(Token token, java.util.Optional refreshToken, TodoDto body) { + | this(201, new Headers( + | token, + | refreshToken + | ), body); + | } + | public static record Headers ( + | Token token, + | java.util.Optional refreshToken + | ) implements Wirespec.Response.Headers { + | }; + | }; + | public static record Response500 ( + | Integer status, + | Headers headers, + | Error body + | ) implements Response5XX, ResponseError { + | public Response500(Error body) { + | this(500, new Headers(), body); + | } + | public static record Headers () implements Wirespec.Response.Headers { + | }; + | }; + | public static Wirespec.RawRequest toRawRequest(Wirespec.Serializer serialization, Request request) { + | return new Wirespec.RawRequest( + | request.method().name(), + | java.util.List.of("todos", serialization.serializePath(request.path().id(), Wirespec.getType(String.class, null))), + | java.util.Map.ofEntries(java.util.Map.entry("done", serialization.serializeParam(request.queries().done(), Wirespec.getType(Boolean.class, null))), java.util.Map.entry("name", request.queries().name().map(it -> serialization.serializeParam(it, Wirespec.getType(String.class, null))).orElse(java.util.List.of()))), + | java.util.Map.ofEntries(java.util.Map.entry("token", serialization.serializeParam(request.headers().token(), Wirespec.getType(Token.class, null))), java.util.Map.entry("Refresh-Token", request.headers().refreshToken().map(it -> serialization.serializeParam(it, Wirespec.getType(Token.class, null))).orElse(java.util.List.of()))), + | java.util.Optional.of(serialization.serializeBody(request.body(), Wirespec.getType(PotentialTodoDto.class, null))) + | ); + | } + | public static Request fromRawRequest(Wirespec.Deserializer serialization, Wirespec.RawRequest request) { + | return new Request( + | serialization.deserializePath(request.path().get(1), Wirespec.getType(String.class, null)), + | java.util.Optional.ofNullable(request.queries().get("done")).map(it -> serialization.deserializeParam(it, Wirespec.getType(Boolean.class, null))).orElseThrow(() -> new IllegalStateException("Param done cannot be null")), + | java.util.Optional.ofNullable(request.queries().get("name")).map(it -> serialization.deserializeParam(it, Wirespec.getType(String.class, null))), + | java.util.Optional.ofNullable(request.headers().entrySet().stream().filter(e -> e.getKey().equalsIgnoreCase("token")).findFirst().map(java.util.Map.Entry::getValue).orElse(null)).map(it -> serialization.deserializeParam(it, Wirespec.getType(Token.class, null))).orElseThrow(() -> new IllegalStateException("Param token cannot be null")), + | java.util.Optional.ofNullable(request.headers().entrySet().stream().filter(e -> e.getKey().equalsIgnoreCase("Refresh-Token")).findFirst().map(java.util.Map.Entry::getValue).orElse(null)).map(it -> serialization.deserializeParam(it, Wirespec.getType(Token.class, null))), + | request.body().map(it -> serialization.deserializeBody(it, Wirespec.getType(PotentialTodoDto.class, null))).orElseThrow(() -> new IllegalStateException("body is null")) + | ); + | } + | public static Wirespec.RawResponse toRawResponse(Wirespec.Serializer serialization, Response response) { + | if (response instanceof Response200 r) { + | return new Wirespec.RawResponse( + | r.status(), + | java.util.Collections.emptyMap(), + | java.util.Optional.of(serialization.serializeBody(r.body(), Wirespec.getType(TodoDto.class, null))) + | ); + | } else if (response instanceof Response201 r) { + | return new Wirespec.RawResponse( + | r.status(), + | java.util.Map.ofEntries(java.util.Map.entry("token", serialization.serializeParam(r.headers().token(), Wirespec.getType(Token.class, null))), java.util.Map.entry("refreshToken", r.headers().refreshToken().map(it -> serialization.serializeParam(it, Wirespec.getType(Token.class, null))).orElse(java.util.List.of()))), + | java.util.Optional.of(serialization.serializeBody(r.body(), Wirespec.getType(TodoDto.class, null))) + | ); + | } else if (response instanceof Response500 r) { + | return new Wirespec.RawResponse( + | r.status(), + | java.util.Collections.emptyMap(), + | java.util.Optional.of(serialization.serializeBody(r.body(), Wirespec.getType(Error.class, null))) + | ); + | } else { + | throw new IllegalStateException(("Cannot match response with status: " + response.status())); + | } + | } + | public static Response fromRawResponse(Wirespec.Deserializer serialization, Wirespec.RawResponse response) { + | switch (response.statusCode()) { + | case 200 -> { + | return new Response200(response.body().map(it -> serialization.deserializeBody(it, Wirespec.getType(TodoDto.class, null))).orElseThrow(() -> new IllegalStateException("body is null"))); + | } + | case 201 -> { + | return new Response201( + | java.util.Optional.ofNullable(response.headers().entrySet().stream().filter(e -> e.getKey().equalsIgnoreCase("token")).findFirst().map(java.util.Map.Entry::getValue).orElse(null)).map(it -> serialization.deserializeParam(it, Wirespec.getType(Token.class, null))).orElseThrow(() -> new IllegalStateException("Param token cannot be null")), + | java.util.Optional.ofNullable(response.headers().entrySet().stream().filter(e -> e.getKey().equalsIgnoreCase("refreshToken")).findFirst().map(java.util.Map.Entry::getValue).orElse(null)).map(it -> serialization.deserializeParam(it, Wirespec.getType(Token.class, null))), + | response.body().map(it -> serialization.deserializeBody(it, Wirespec.getType(TodoDto.class, null))).orElseThrow(() -> new IllegalStateException("body is null")) + | ); + | } + | case 500 -> { + | return new Response500(response.body().map(it -> serialization.deserializeBody(it, Wirespec.getType(Error.class, null))).orElseThrow(() -> new IllegalStateException("body is null"))); + | } + | default -> { + | throw new IllegalStateException(("Cannot match response with status: " + response.statusCode())); + | } + | } + | } + | public interface Handler extends Wirespec.Handler { + | public java.util.concurrent.CompletableFuture> putTodo(Request request); + | public static record Handlers () implements Wirespec.Server>, Wirespec.Client> { + | @Override + | public String getPathTemplate() { + | return "/todos/{id}"; + | } + | @Override + | public String getMethod() { + | return "PUT"; + | } + | @Override + | public Wirespec.ServerEdge> getServer(Wirespec.Serialization serialization) { + | return new Wirespec.ServerEdge<>() { + | @Override public Request from(Wirespec.RawRequest request) { + | return fromRawRequest(serialization, request); + | } + | @Override public Wirespec.RawResponse to(Response response) { + | return toRawResponse(serialization, response); + | } + | }; + | } + | @Override + | public Wirespec.ClientEdge> getClient(Wirespec.Serialization serialization) { + | return new Wirespec.ClientEdge<>() { + | @Override public Wirespec.RawRequest to(Request request) { + | return toRawRequest(serialization, request); + | } + | @Override public Response from(Wirespec.RawResponse response) { + | return fromRawResponse(serialization, response); + | } + | }; + | } + | }; + | } + | public interface Call extends Wirespec.Call { + | public java.util.concurrent.CompletableFuture> putTodo(String id, Boolean done, java.util.Optional name, Token token, java.util.Optional refreshToken, PotentialTodoDto body); + | } + |} + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record PotentialTodoDto ( + | String name, + | Boolean done + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return java.util.List.of(); + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record Token ( + | String iss + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return java.util.List.of(); + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record TodoDto ( + | String id, + | String name, + | Boolean done + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return java.util.List.of(); + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record Error ( + | Long code, + | String description + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return java.util.List.of(); + | } + |}; + | + |package community.flock.wirespec.generated.client; + |import community.flock.wirespec.java.Wirespec; + |import community.flock.wirespec.generated.model.Token; + |import community.flock.wirespec.generated.model.Token; + |import community.flock.wirespec.generated.model.PotentialTodoDto; + |import community.flock.wirespec.generated.model.TodoDto; + |import community.flock.wirespec.generated.model.Error; + |import community.flock.wirespec.generated.endpoint.PutTodo; + |public record PutTodoClient ( + | Wirespec.Serialization serialization, + | Wirespec.Transportation transportation + |) implements PutTodo.Call { + | @Override + | public java.util.concurrent.CompletableFuture> putTodo(String id, Boolean done, java.util.Optional name, Token token, java.util.Optional refreshToken, PotentialTodoDto body) { + | final var request = new PutTodo.Request( + | id, + | done, + | name, + | token, + | refreshToken, + | body + | ); + | final var rawRequest = PutTodo.toRawRequest(serialization(), request); + | return transportation().transport(rawRequest).thenApply(rawResponse -> PutTodo.fromRawResponse(serialization(), rawResponse)); + | } + |}; + | + |package community.flock.wirespec.generated; + |import community.flock.wirespec.java.Wirespec; + |import community.flock.wirespec.generated.model.Token; + |import community.flock.wirespec.generated.model.PotentialTodoDto; + |import community.flock.wirespec.generated.model.TodoDto; + |import community.flock.wirespec.generated.model.Error; + |import community.flock.wirespec.generated.endpoint.PutTodo; + |import community.flock.wirespec.generated.client.PutTodoClient; + |public record Client ( + | Wirespec.Serialization serialization, + | Wirespec.Transportation transportation + |) implements PutTodo.Call { + | @Override + | public java.util.concurrent.CompletableFuture> putTodo(String id, Boolean done, java.util.Optional name, Token token, java.util.Optional refreshToken, PotentialTodoDto body) { + | return new PutTodoClient( + | serialization(), + | transportation() + | ).putTodo(id, done, name, token, refreshToken, body); + | } + |}; + | + """.trimMargin() + + CompileFullEndpointTest.compiler { JavaIrEmitter() } shouldBeRight java + } + + @Test + fun compileChannelTest() { + val java = """ + |package community.flock.wirespec.generated.channel; + |@FunctionalInterface + |public interface Queue extends Wirespec.Channel { + | public void invoke(String message); + |} + | + """.trimMargin() + + CompileChannelTest.compiler { JavaIrEmitter() } shouldBeRight java + } + + @Test + fun compileEnumTest() { + val java = """ + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public enum MyAwesomeEnum implements Wirespec.Enum { + | ONE("ONE"), + | Two("Two"), + | THREE_MORE("THREE_MORE"), + | UnitedKingdom("UnitedKingdom"), + | _1("-1"), + | _0("0"), + | _10("10"), + | _999("-999"), + | _88("88"); + | public final String label; + | MyAwesomeEnum(String label) { + | this.label = label; + | } + | @Override + | public String toString() { + | return label; + | } + | public String label() { + | return label; + | } + |} + | + """.trimMargin() + + CompileEnumTest.compiler { JavaIrEmitter() } shouldBeRight java + } + + @Test + fun compileMinimalEndpointTest() { + val java = """ + |package community.flock.wirespec.generated.endpoint; + |import community.flock.wirespec.java.Wirespec; + |import community.flock.wirespec.generated.model.TodoDto; + |public interface GetTodos extends Wirespec.Endpoint { + | public static record Path () implements Wirespec.Path { + | }; + | public static record Queries () implements Wirespec.Queries { + | }; + | public static record RequestHeaders () implements Wirespec.Request.Headers { + | }; + | public static record Request ( + | Path path, + | Wirespec.Method method, + | Queries queries, + | RequestHeaders headers, + | Void body + | ) implements Wirespec.Request { + | public Request() { + | this(new Path(), Wirespec.Method.GET, new Queries(), new RequestHeaders(), null); + | } + | }; + | public sealed interface Response extends Wirespec.Response permits Response2XX, ResponseListTodoDto {} + | public sealed interface Response2XX extends Response permits Response200 {} + | public sealed interface ResponseListTodoDto extends Response> permits Response200 {} + | public static record Response200 ( + | Integer status, + | Headers headers, + | java.util.List body + | ) implements Response2XX>, ResponseListTodoDto { + | public Response200(java.util.List body) { + | this(200, new Headers(), body); + | } + | public static record Headers () implements Wirespec.Response.Headers { + | }; + | }; + | public static Wirespec.RawRequest toRawRequest(Wirespec.Serializer serialization, Request request) { + | return new Wirespec.RawRequest( + | request.method().name(), + | java.util.List.of("todos"), + | java.util.Collections.emptyMap(), + | java.util.Collections.emptyMap(), + | java.util.Optional.empty() + | ); + | } + | public static Request fromRawRequest(Wirespec.Deserializer serialization, Wirespec.RawRequest request) { + | return new Request(); + | } + | public static Wirespec.RawResponse toRawResponse(Wirespec.Serializer serialization, Response response) { + | if (response instanceof Response200 r) { + | return new Wirespec.RawResponse( + | r.status(), + | java.util.Collections.emptyMap(), + | java.util.Optional.of(serialization.serializeBody(r.body(), Wirespec.getType(TodoDto.class, java.util.List.class))) + | ); + | } else { + | throw new IllegalStateException(("Cannot match response with status: " + response.status())); + | } + | } + | public static Response fromRawResponse(Wirespec.Deserializer serialization, Wirespec.RawResponse response) { + | switch (response.statusCode()) { + | case 200 -> { + | return new Response200(response.body().map(it -> serialization.>deserializeBody(it, Wirespec.getType(TodoDto.class, java.util.List.class))).orElseThrow(() -> new IllegalStateException("body is null"))); + | } + | default -> { + | throw new IllegalStateException(("Cannot match response with status: " + response.statusCode())); + | } + | } + | } + | public interface Handler extends Wirespec.Handler { + | public java.util.concurrent.CompletableFuture> getTodos(Request request); + | public static record Handlers () implements Wirespec.Server>, Wirespec.Client> { + | @Override + | public String getPathTemplate() { + | return "/todos"; + | } + | @Override + | public String getMethod() { + | return "GET"; + | } + | @Override + | public Wirespec.ServerEdge> getServer(Wirespec.Serialization serialization) { + | return new Wirespec.ServerEdge<>() { + | @Override public Request from(Wirespec.RawRequest request) { + | return fromRawRequest(serialization, request); + | } + | @Override public Wirespec.RawResponse to(Response response) { + | return toRawResponse(serialization, response); + | } + | }; + | } + | @Override + | public Wirespec.ClientEdge> getClient(Wirespec.Serialization serialization) { + | return new Wirespec.ClientEdge<>() { + | @Override public Wirespec.RawRequest to(Request request) { + | return toRawRequest(serialization, request); + | } + | @Override public Response from(Wirespec.RawResponse response) { + | return fromRawResponse(serialization, response); + | } + | }; + | } + | }; + | } + | public interface Call extends Wirespec.Call { + | public java.util.concurrent.CompletableFuture> getTodos(); + | } + |} + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record TodoDto ( + | String description + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return java.util.List.of(); + | } + |}; + | + |package community.flock.wirespec.generated.client; + |import community.flock.wirespec.java.Wirespec; + |import community.flock.wirespec.generated.model.TodoDto; + |import community.flock.wirespec.generated.endpoint.GetTodos; + |public record GetTodosClient ( + | Wirespec.Serialization serialization, + | Wirespec.Transportation transportation + |) implements GetTodos.Call { + | @Override + | public java.util.concurrent.CompletableFuture> getTodos() { + | final var request = new GetTodos.Request(); + | final var rawRequest = GetTodos.toRawRequest(serialization(), request); + | return transportation().transport(rawRequest).thenApply(rawResponse -> GetTodos.fromRawResponse(serialization(), rawResponse)); + | } + |}; + | + |package community.flock.wirespec.generated; + |import community.flock.wirespec.java.Wirespec; + |import community.flock.wirespec.generated.model.TodoDto; + |import community.flock.wirespec.generated.endpoint.GetTodos; + |import community.flock.wirespec.generated.client.GetTodosClient; + |public record Client ( + | Wirespec.Serialization serialization, + | Wirespec.Transportation transportation + |) implements GetTodos.Call { + | @Override + | public java.util.concurrent.CompletableFuture> getTodos() { + | return new GetTodosClient( + | serialization(), + | transportation() + | ).getTodos(); + | } + |}; + | + """.trimMargin() + + CompileMinimalEndpointTest.compiler { JavaIrEmitter() } shouldBeRight java + } + + @Test + fun compileRefinedTest() { + val java = """ + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record TodoId ( + | String value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return java.util.regex.Pattern.compile("^[0-9a-fA-F]{8}\\b-[0-9a-fA-F]{4}\\b-[0-9a-fA-F]{4}\\b-[0-9a-fA-F]{4}\\b-[0-9a-fA-F]{12}${'$'}").matcher(value).find(); + | } + | @Override + | public String value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record TodoNoRegex ( + | String value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return true; + | } + | @Override + | public String value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record TestInt ( + | Long value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return true; + | } + | @Override + | public Long value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record TestInt0 ( + | Long value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return true; + | } + | @Override + | public Long value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record TestInt1 ( + | Long value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return 0 <= value; + | } + | @Override + | public Long value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record TestInt2 ( + | Long value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return 1 <= value && value <= 3; + | } + | @Override + | public Long value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record TestNum ( + | Double value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return true; + | } + | @Override + | public Double value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record TestNum0 ( + | Double value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return true; + | } + | @Override + | public Double value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record TestNum1 ( + | Double value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return value <= 0.5; + | } + | @Override + | public Double value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record TestNum2 ( + | Double value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return -0.2 <= value && value <= 0.5; + | } + | @Override + | public Double value() { + | return value; + | } + |}; + | + """.trimMargin() + + CompileRefinedTest.compiler { JavaIrEmitter() } shouldBeRight java + } + + @Test + fun compileUnionTest() { + val java = """ + |package community.flock.wirespec.generated.model; + |public sealed interface UserAccount permits UserAccountPassword, UserAccountToken {} + | + |package community.flock.wirespec.generated.model; + |public record UserAccountPassword ( + | String username, + | String password + |) implements Wirespec.Model, UserAccount { + | @Override + | public java.util.List validate() { + | return java.util.List.of(); + | } + |}; + | + |package community.flock.wirespec.generated.model; + |public record UserAccountToken ( + | String token + |) implements Wirespec.Model, UserAccount { + | @Override + | public java.util.List validate() { + | return java.util.List.of(); + | } + |}; + | + |package community.flock.wirespec.generated.model; + |public record User ( + | String username, + | UserAccount account + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return java.util.List.of(); + | } + |}; + | + """.trimMargin() + + CompileUnionTest.compiler { JavaIrEmitter() } shouldBeRight java + } + + @Test + fun compileTypeTest() { + val java = """ + |package community.flock.wirespec.generated.model; + |public record Request ( + | String type, + | String url, + | java.util.Optional BODY_TYPE, + | java.util.List params, + | java.util.Map headers, + | java.util.Optional>>>> body + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return java.util.List.of(); + | } + |}; + | + """.trimMargin() + + CompileTypeTest.compiler { JavaIrEmitter() } shouldBeRight java + } + + @Test + fun compileNestedTypeTest() { + val java = """ + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record DutchPostalCode ( + | String value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return java.util.regex.Pattern.compile("^([0-9]{4}[A-Z]{2})${'$'}").matcher(value).find(); + | } + | @Override + | public String value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record Address ( + | String street, + | Long houseNumber, + | DutchPostalCode postalCode + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return (!postalCode().validate() ? java.util.List.of("postalCode") : java.util.List.of()); + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record Person ( + | String name, + | Address address, + | java.util.List tags + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return address().validate().stream().map(e -> "address." + e).toList(); + | } + |}; + | + """.trimMargin() + + CompileNestedTypeTest.compiler { JavaIrEmitter() } shouldBeRight java + } + + @Test + fun compileComplexModelTest() { + val java = """ + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record Email ( + | String value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return java.util.regex.Pattern.compile("^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}${'$'}").matcher(value).find(); + | } + | @Override + | public String value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record PhoneNumber ( + | String value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return java.util.regex.Pattern.compile("^\\+[1-9]\\d{1,14}${'$'}").matcher(value).find(); + | } + | @Override + | public String value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record Tag ( + | String value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return java.util.regex.Pattern.compile("^[a-z][a-z0-9-]{0,19}${'$'}").matcher(value).find(); + | } + | @Override + | public String value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record EmployeeAge ( + | Long value + |) implements Wirespec.Refined { + | @Override + | public String toString() { + | return value.toString(); + | } + | @Override + | public Boolean validate() { + | return 18 <= value && value <= 65; + | } + | @Override + | public Long value() { + | return value; + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record ContactInfo ( + | Email email, + | java.util.Optional phone + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return java.util.stream.Stream.of((!email().validate() ? java.util.List.of("email") : java.util.List.of()), phone().map(it -> (!it.validate() ? java.util.List.of("phone") : java.util.List.of())).orElse(java.util.List.of())).flatMap(java.util.Collection::stream).toList(); + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record Employee ( + | String name, + | EmployeeAge age, + | ContactInfo contactInfo, + | java.util.List tags + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return java.util.stream.Stream.of((!age().validate() ? java.util.List.of("age") : java.util.List.of()), contactInfo().validate().stream().map(e -> "contactInfo." + e).toList(), java.util.stream.IntStream.range(0, tags().size()).mapToObj(i -> (!tags().get(i).validate() ? java.util.List.of("tags[" + i + "]") : java.util.List.of())).flatMap(java.util.Collection::stream).toList()).flatMap(java.util.Collection::stream).toList(); + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record Department ( + | String name, + | java.util.List employees + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return java.util.stream.IntStream.range(0, employees().size()).mapToObj(i -> employees().get(i).validate().stream().map(e -> "employees[" + i + "]." + e).toList()).flatMap(java.util.Collection::stream).toList(); + | } + |}; + | + |package community.flock.wirespec.generated.model; + |import community.flock.wirespec.java.Wirespec; + |public record Company ( + | String name, + | java.util.List departments + |) implements Wirespec.Model { + | @Override + | public java.util.List validate() { + | return java.util.stream.IntStream.range(0, departments().size()).mapToObj(i -> departments().get(i).validate().stream().map(e -> "departments[" + i + "]." + e).toList()).flatMap(java.util.Collection::stream).toList(); + | } + |}; + | + """.trimMargin() + + CompileComplexModelTest.compiler { JavaIrEmitter() } shouldBeRight java + } + + @Test + fun sharedOutputTest() { + val expected = """ + |package community.flock.wirespec.java; + |import java.lang.reflect.Type; + |import java.lang.reflect.ParameterizedType; + |import java.util.List; + |import java.util.Map; + |public interface Wirespec { + | public interface Model { + | public java.util.List validate(); + | } + | public interface Enum { + | String label(); + | } + | public interface Endpoint { + | } + | public interface Channel { + | } + | public interface Refined { + | T value(); + | public Boolean validate(); + | } + | public interface Path { + | } + | public interface Queries { + | } + | public interface Headers { + | } + | public interface Handler { + | } + | public interface Call { + | } + | public enum Method { + | GET, + | PUT, + | POST, + | DELETE, + | OPTIONS, + | HEAD, + | PATCH, + | TRACE + | } public interface Request { + | Path path(); + | Method method(); + | Queries queries(); + | Headers headers(); + | T body(); + | public interface Headers { + | } + | } + | public interface Response { + | Integer status(); + | Headers headers(); + | T body(); + | public interface Headers { + | } + | } + | public interface BodySerializer { + | public byte[] serializeBody(T t, Type type); + | } + | public interface BodyDeserializer { + | public T deserializeBody(byte[] raw, Type type); + | } + | public interface BodySerialization extends BodySerializer, BodyDeserializer { + | } + | public interface PathSerializer { + | public String serializePath(T t, Type type); + | } + | public interface PathDeserializer { + | public T deserializePath(String raw, Type type); + | } + | public interface PathSerialization extends PathSerializer, PathDeserializer { + | } + | public interface ParamSerializer { + | public java.util.List serializeParam(T value, Type type); + | } + | public interface ParamDeserializer { + | public T deserializeParam(java.util.List values, Type type); + | } + | public interface ParamSerialization extends ParamSerializer, ParamDeserializer { + | } + | public interface Serializer extends BodySerializer, PathSerializer, ParamSerializer { + | } + | public interface Deserializer extends BodyDeserializer, PathDeserializer, ParamDeserializer { + | } + | public interface Serialization extends Serializer, Deserializer { + | } + | public static record RawRequest ( + | String method, + | java.util.List path, + | java.util.Map> queries, + | java.util.Map> headers, + | java.util.Optional body + | ) { + | }; + | public static record RawResponse ( + | Integer statusCode, + | java.util.Map> headers, + | java.util.Optional body + | ) { + | }; + | public interface Transportation { + | public java.util.concurrent.CompletableFuture transport(RawRequest request); + | } + | public interface ServerEdge, Res extends Response> { + | public Req from(RawRequest request); + | public RawResponse to(Res response); + | } + | public interface ClientEdge, Res extends Response> { + | public RawRequest to(Req request); + | public Res from(RawResponse response); + | } + | public interface Client, Res extends Response> { + | public String getPathTemplate(); + | public String getMethod(); + | public ClientEdge getClient(Serialization serialization); + | } + | public interface Server, Res extends Response> { + | public String getPathTemplate(); + | public String getMethod(); + | public ServerEdge getServer(Serialization serialization); + | } + | public static Type getType(final Class actualTypeArguments, final Class rawType) { + | if(rawType != null) { + | return new ParameterizedType() { + | public Type getRawType() { return rawType; } + | public Type[] getActualTypeArguments() { return new Class[]{actualTypeArguments}; } + | public Type getOwnerType() { return null; } + | }; + | } + | else { return actualTypeArguments; } + | }} + | + """.trimMargin() + + val emitter = JavaIrEmitter() + emitter.shared.source shouldBe expected + } + + private fun EmitContext.emitFirst(node: Definition) = emitters.map { + val ast = AST( + nonEmptyListOf( + Module( + FileUri(""), + nonEmptyListOf(node), + ), + ), + ) + it.emit(ast, logger).first().result + } +} diff --git a/src/compiler/emitters/kotlin/build.gradle.kts b/src/compiler/emitters/kotlin/build.gradle.kts index 217e6b788..68b5de636 100644 --- a/src/compiler/emitters/kotlin/build.gradle.kts +++ b/src/compiler/emitters/kotlin/build.gradle.kts @@ -41,6 +41,7 @@ kotlin { commonMain { dependencies { api(project(":src:compiler:core")) + api(project(":src:compiler:ir")) } } commonTest { diff --git a/src/compiler/emitters/kotlin/src/commonMain/kotlin/community/flock/wirespec/emitters/kotlin/KotlinIrEmitter.kt b/src/compiler/emitters/kotlin/src/commonMain/kotlin/community/flock/wirespec/emitters/kotlin/KotlinIrEmitter.kt new file mode 100644 index 000000000..b511a71bc --- /dev/null +++ b/src/compiler/emitters/kotlin/src/commonMain/kotlin/community/flock/wirespec/emitters/kotlin/KotlinIrEmitter.kt @@ -0,0 +1,375 @@ +package community.flock.wirespec.emitters.kotlin + +import arrow.core.NonEmptyList +import community.flock.wirespec.compiler.core.addBackticks +import community.flock.wirespec.compiler.core.emit.DEFAULT_GENERATED_PACKAGE_STRING +import community.flock.wirespec.compiler.core.emit.DEFAULT_SHARED_PACKAGE_STRING +import community.flock.wirespec.compiler.core.emit.EmitShared +import community.flock.wirespec.compiler.core.emit.FileExtension +import community.flock.wirespec.compiler.core.emit.HasPackageName +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Element +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.emit.IrEmitter +import community.flock.wirespec.compiler.core.emit.Keywords +import community.flock.wirespec.compiler.core.emit.LanguageEmitter.Companion.firstToUpper +import community.flock.wirespec.compiler.core.emit.LanguageEmitter.Companion.needImports +import community.flock.wirespec.compiler.core.emit.PackageName +import community.flock.wirespec.compiler.core.emit.Shared +import community.flock.wirespec.compiler.core.emit.importReferences +import community.flock.wirespec.compiler.core.emit.plus +import community.flock.wirespec.compiler.core.parse.ast.Channel +import community.flock.wirespec.compiler.core.parse.ast.Definition +import community.flock.wirespec.compiler.core.parse.ast.Endpoint +import community.flock.wirespec.compiler.core.parse.ast.Enum +import community.flock.wirespec.compiler.core.parse.ast.FieldIdentifier +import community.flock.wirespec.compiler.core.parse.ast.Identifier +import community.flock.wirespec.compiler.core.parse.ast.Module +import community.flock.wirespec.compiler.core.parse.ast.Reference +import community.flock.wirespec.compiler.core.parse.ast.Refined +import community.flock.wirespec.compiler.core.parse.ast.Type +import community.flock.wirespec.compiler.core.parse.ast.Union +import community.flock.wirespec.compiler.utils.Logger +import community.flock.wirespec.ir.converter.convert +import community.flock.wirespec.ir.converter.convertConstraint +import community.flock.wirespec.ir.converter.convertWithValidation +import community.flock.wirespec.ir.core.Constructor +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.Import +import community.flock.wirespec.ir.core.Interface +import community.flock.wirespec.ir.core.RawElement +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.Namespace +import community.flock.wirespec.ir.core.Struct +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.findElement +import community.flock.wirespec.ir.core.function +import community.flock.wirespec.ir.core.`interface` +import community.flock.wirespec.ir.core.raw +import community.flock.wirespec.ir.core.transform +import community.flock.wirespec.ir.core.transformChildren +import community.flock.wirespec.ir.core.withLabelField +import community.flock.wirespec.ir.generator.KotlinGenerator +import community.flock.wirespec.ir.generator.generateKotlin +import community.flock.wirespec.compiler.core.parse.ast.Shared as AstShared +import community.flock.wirespec.ir.core.Enum as LanguageEnum +import community.flock.wirespec.ir.core.File as LanguageFile +import community.flock.wirespec.ir.core.Package as LanguagePackage +import community.flock.wirespec.ir.core.Type as LanguageType + +open class KotlinIrEmitter( + override val packageName: PackageName = PackageName(DEFAULT_GENERATED_PACKAGE_STRING), + private val emitShared: EmitShared = EmitShared(), +) : IrEmitter, HasPackageName { + + override val generator = KotlinGenerator + + override val extension = FileExtension.Kotlin + + private val wirespecImport = """ + | + |import $DEFAULT_SHARED_PACKAGE_STRING.kotlin.Wirespec + |import kotlin.reflect.typeOf + | + """.trimMargin() + + override val shared = object : Shared { + override val packageString = "$DEFAULT_SHARED_PACKAGE_STRING.kotlin" + + private val clientServer = listOf( + `interface`("ServerEdge") { + typeParam(type("Req"), type("Request", LanguageType.Wildcard)) + typeParam(type("Res"), type("Response", LanguageType.Wildcard)) + function("from") { + returnType(type("Req")) + arg("request", type("RawRequest")) + } + function("to") { + returnType(type("RawResponse")) + arg("response", type("Res")) + } + }, + `interface`("ClientEdge") { + typeParam(type("Req"), type("Request", LanguageType.Wildcard)) + typeParam(type("Res"), type("Response", LanguageType.Wildcard)) + function("to") { + returnType(type("RawRequest")) + arg("request", type("Req")) + } + function("from") { + returnType(type("Res")) + arg("response", type("RawResponse")) + } + }, + `interface`("Client") { + typeParam(type("Req"), type("Request", LanguageType.Wildcard)) + typeParam(type("Res"), type("Response", LanguageType.Wildcard)) + field("pathTemplate", LanguageType.String) + field("method", LanguageType.String) + function("client") { + returnType(type("ClientEdge", type("Req"), type("Res"))) + arg("serialization", type("Serialization")) + } + }, + `interface`("Server") { + typeParam(type("Req"), type("Request", LanguageType.Wildcard)) + typeParam(type("Res"), type("Response", LanguageType.Wildcard)) + field("pathTemplate", LanguageType.String) + field("method", LanguageType.String) + function("server") { + returnType(type("ServerEdge", type("Req"), type("Res"))) + arg("serialization", type("Serialization")) + } + }, + ) + + override val source = AstShared(packageString) + .convert() + .transform { + matchingElements { file: LanguageFile -> + val (packageElements, rest) = file.elements.partition { it is LanguagePackage } + file.copy(elements = packageElements + Import("kotlin.reflect", LanguageType.Custom("KType")) + rest) + } + injectAfter { namespace: Namespace -> + if (namespace.name == Name.of("Wirespec")) clientServer + else emptyList() + } + } + .generateKotlin() + } + + override fun emit(module: Module, logger: Logger): NonEmptyList { + val files = super.emit(module, logger) + return if (emitShared.value) { + files + File( + Name.of(PackageName("${DEFAULT_SHARED_PACKAGE_STRING}.kotlin").toDir() + "Wirespec"), + listOf(RawElement(shared.source)) + ) + } else { + files + } + } + + override fun emit(definition: Definition, module: Module, logger: Logger): File { + val file = super.emit(definition, module, logger) + val subPackageName = packageName + definition + return File( + name = Name.of(subPackageName.toDir() + file.name.pascalCase()), + elements = buildList { + add(LanguagePackage(subPackageName.value)) + if (module.needImports()) add(RawElement(wirespecImport)) + addAll(file.elements) + } + ) + } + + override fun emit(type: Type, module: Module): File = + type.convertWithValidation(module) + .sanitizeNames() + .transform { + matchingElements { struct: Struct -> + if (struct.fields.isEmpty()) struct.copy(constructors = listOf(Constructor(emptyList(), emptyList()))) + else struct + } + } + + override fun emit(enum: Enum, module: Module): File = enum + .convert() + .sanitizeNames() + .transform { + matchingElements { languageEnum: LanguageEnum -> + languageEnum.withLabelField( + sanitizeEntry = { it.sanitizeEnum() }, + labelFieldOverride = true, + labelExpression = RawExpression("label"), + ) + } + } + + override fun emit(union: Union): File = union + .convert() + .sanitizeNames() + + override fun emit(refined: Refined): File { + val file = refined.convert().sanitizeNames() + val struct = file.findElement()!! + val toStringExpr = when (refined.reference.type) { + is Reference.Primitive.Type.String -> "value" + else -> "value.toString()" + } + val updatedStruct = struct.copy( + fields = struct.fields.map { f -> f.copy(isOverride = true) }, + elements = listOf( + function("toString", isOverride = true) { + returnType(LanguageType.String) + returns(RawExpression(toStringExpr)) + }, + function("validate", isOverride = true) { + returnType(LanguageType.Boolean) + returns(refined.reference.convertConstraint(VariableReference(Name.of("value")))) + }, + ), + ) + return LanguageFile(Name.of(refined.identifier.sanitize()), listOf(updatedStruct)) + } + + override fun emit(endpoint: Endpoint): File { + val imports = endpoint.buildImports() + val file = endpoint.convert().sanitizeNames() + val endpointNamespace = file.findElement()!! + val body = endpointNamespace.injectCompanionObject(endpoint) + return if (imports.isNotEmpty()) { + LanguageFile(Name.of(endpoint.identifier.sanitize()), listOf(RawElement(imports), body)) + } else { + LanguageFile(Name.of(endpoint.identifier.sanitize()), listOf(body)) + } + } + + override fun emit(channel: Channel): File { + val imports = channel.buildImports() + val file = channel.convert().sanitizeNames() + return if (imports.isNotEmpty()) file.copy(elements = listOf(RawElement(imports)) + file.elements) + else file + } + + override fun emitEndpointClient(endpoint: Endpoint): File { + val imports = endpoint.buildImports() + val endpointImport = "import ${packageName.value}.endpoint.${endpoint.identifier.value}" + val allImports = listOf(imports, endpointImport).filter { it.isNotEmpty() }.joinToString("\n") + val file = super.emitEndpointClient(endpoint).sanitizeNames() + val subPackageName = packageName + "client" + return File( + name = Name.of(subPackageName.toDir() + file.name.pascalCase()), + elements = buildList { + add(LanguagePackage(subPackageName.value)) + add(RawElement(wirespecImport)) + if (allImports.isNotEmpty()) add(RawElement(allImports)) + addAll(file.elements) + } + ) + } + + override fun emitClient(endpoints: List, logger: Logger): File { + val imports = endpoints.flatMap { it.importReferences() }.distinctBy { it.value } + .joinToString("\n") { "import ${packageName.value}.model.${it.value}" } + val endpointImports = endpoints + .joinToString("\n") { "import ${packageName.value}.endpoint.${it.identifier.value}" } + val clientImports = endpoints + .joinToString("\n") { "import ${packageName.value}.client.${it.identifier.value}Client" } + val allImports = listOf(imports, endpointImports, clientImports).filter { it.isNotEmpty() }.joinToString("\n") + val file = super.emitClient(endpoints, logger).sanitizeNames() + return File( + name = Name.of(packageName.toDir() + file.name.pascalCase()), + elements = buildList { + add(LanguagePackage(packageName.value)) + add(RawElement(wirespecImport)) + if (allImports.isNotEmpty()) add(RawElement(allImports)) + addAll(file.elements) + } + ) + } + + private fun T.sanitizeNames(): T = transform { + fields { field -> + field.copy(name = field.name.sanitizeName()) + } + parameters { param -> + param.copy(name = Name.of(param.name.camelCase().sanitizeSymbol().sanitizeKeywords())) + } + statementAndExpression { stmt, tr -> + when (stmt) { + is FieldCall -> FieldCall( + receiver = stmt.receiver?.let { tr.transformExpression(it) }, + field = stmt.field.sanitizeName(), + ) + is FunctionCall -> if (stmt.name.value() == "validate") { + stmt.copy(typeArguments = emptyList()).transformChildren(tr) + } else stmt.transformChildren(tr) + is ConstructorStatement -> ConstructorStatement( + type = tr.transformType(stmt.type), + namedArguments = stmt.namedArguments.map { (name, expr) -> + name.sanitizeName() to tr.transformExpression(expr) + }.toMap(), + ) + else -> stmt.transformChildren(tr) + } + } + } + + private fun Name.sanitizeName(): Name { + val sanitized = if (parts.size > 1) camelCase() else value().sanitizeSymbol() + return Name(listOf(sanitized.sanitizeKeywords())) + } + + private fun Identifier.sanitize(): String = value + .split(".", " ") + .mapIndexed { index, s -> if (index > 0) s.firstToUpper() else s } + .joinToString("") + .filter { it.isLetterOrDigit() || it == '_' } + .sanitizeFirstIsDigit() + .let { if (this is FieldIdentifier) it.sanitizeKeywords() else it } + + private fun String.sanitizeFirstIsDigit() = if (firstOrNull()?.isDigit() == true) "_${this}" else this + + private fun String.sanitizeKeywords() = if (this in reservedKeywords) addBackticks() else this + + private fun String.sanitizeSymbol(): String = this + .split(".", " ", "-") + .mapIndexed { index, s -> if (index > 0) s.firstToUpper() else s } + .joinToString("") + .filter { it.isLetterOrDigit() || it == '_' } + .sanitizeFirstIsDigit() + + private fun String.sanitizeEnum() = split("-", ", ", ".", " ", "//") + .joinToString("_") + .sanitizeFirstIsDigit() + .sanitizeKeywords() + + private fun Definition.buildImports() = importReferences() + .distinctBy { it.value } + .joinToString("\n") { "import ${packageName.value}.model.${it.value}" } + + private fun Namespace.injectCompanionObject(endpoint: Endpoint): Namespace = + transform { + injectAfter { iface: Interface -> + if (iface.name == Name.of("Handler")) listOf(companionObject(endpoint)) else emptyList() + } + } + + private fun companionObject(endpoint: Endpoint): RawElement { + val pathTemplate = "/" + endpoint.path.joinToString("/") { + when (it) { + is Endpoint.Segment.Literal -> it.value + is Endpoint.Segment.Param -> "{${it.identifier.value}}" + } + } + return """ + |companion object: Wirespec.Server>, Wirespec.Client> { + | override val pathTemplate = "$pathTemplate" + | override val method = "${endpoint.method}" + | override fun server(serialization: Wirespec.Serialization) = object : Wirespec.ServerEdge> { + | override fun from(request: Wirespec.RawRequest) = fromRawRequest(serialization, request) + | override fun to(response: Response<*>) = toRawResponse(serialization, response) + | } + | override fun client(serialization: Wirespec.Serialization) = object : Wirespec.ClientEdge> { + | override fun to(request: Request) = toRawRequest(serialization, request) + | override fun from(response: Wirespec.RawResponse) = fromRawResponse(serialization, response) + | } + |} + """.trimMargin().let(::raw) + } + + companion object : Keywords { + override val reservedKeywords = setOf( + "as", "break", "class", "continue", "do", + "else", "false", "for", "fun", "if", + "in", "interface", "internal", "is", "null", + "object", "open", "package", "return", "super", + "this", "throw", "true", "try", "typealias", + "typeof", "val", "var", "when", "while", "private", "public" + ) + } + +} diff --git a/src/compiler/emitters/kotlin/src/commonTest/kotlin/community/flock/wirespec/emitters/kotlin/KotlinIrEmitterTest.kt b/src/compiler/emitters/kotlin/src/commonTest/kotlin/community/flock/wirespec/emitters/kotlin/KotlinIrEmitterTest.kt new file mode 100644 index 000000000..93c2fcbcd --- /dev/null +++ b/src/compiler/emitters/kotlin/src/commonTest/kotlin/community/flock/wirespec/emitters/kotlin/KotlinIrEmitterTest.kt @@ -0,0 +1,993 @@ +package community.flock.wirespec.emitters.kotlin + +import arrow.core.nonEmptyListOf +import arrow.core.nonEmptySetOf +import community.flock.wirespec.compiler.core.EmitContext +import community.flock.wirespec.compiler.core.FileUri +import community.flock.wirespec.compiler.core.parse.ast.AST +import community.flock.wirespec.compiler.core.parse.ast.Definition +import community.flock.wirespec.compiler.core.parse.ast.Module +import community.flock.wirespec.compiler.test.CompileChannelTest +import community.flock.wirespec.compiler.test.CompileComplexModelTest +import community.flock.wirespec.compiler.test.CompileEnumTest +import community.flock.wirespec.compiler.test.CompileFullEndpointTest +import community.flock.wirespec.compiler.test.CompileMinimalEndpointTest +import community.flock.wirespec.compiler.test.CompileNestedTypeTest +import community.flock.wirespec.compiler.test.CompileRefinedTest +import community.flock.wirespec.compiler.test.CompileTypeTest +import community.flock.wirespec.compiler.test.CompileUnionTest +import community.flock.wirespec.compiler.test.NodeFixtures +import community.flock.wirespec.compiler.utils.NoLogger +import io.kotest.assertions.arrow.core.shouldBeRight +import io.kotest.matchers.shouldBe +import kotlin.test.Test + +class KotlinIrEmitterTest { + + private val emitContext = object : EmitContext, NoLogger { + override val emitters = nonEmptySetOf(KotlinIrEmitter()) + } + + @Test + fun testEmitterType() { + val expected = listOf( + """ + |package community.flock.wirespec.generated.model + |data class Todo( + | val name: String, + | val description: String?, + | val notes: List, + | val done: Boolean + |) : Wirespec.Model { + | override fun validate(): List = + | emptyList() + |} + | + """.trimMargin(), + ) + + val res = emitContext.emitFirst(NodeFixtures.type) + res shouldBe expected + } + + @Test + fun testEmitterEmptyType() { + val expected = listOf( + """ + |package community.flock.wirespec.generated.model + |data object TodoWithoutProperties : Wirespec.Model { + | override fun validate(): List = + | emptyList() + |} + | + """.trimMargin(), + ) + + val res = emitContext.emitFirst(NodeFixtures.emptyType) + res shouldBe expected + } + + @Test + fun testEmitterRefined() { + val expected = listOf( + """ + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class UUID( + | override val value: String + |) : Wirespec.Refined { + | override fun toString(): String = + | value + | override fun validate(): Boolean = + | Regex(${"\"\"\""}^[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}${'$'}${"\"\"\""}).matches(value) + |} + | + """.trimMargin(), + ) + + val res = emitContext.emitFirst(NodeFixtures.refined) + res shouldBe expected + } + + @Test + fun testEmitterEnum() { + val expected = listOf( + """ + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |enum class TodoStatus (override val label: String): Wirespec.Enum { + | OPEN("OPEN"), + | IN_PROGRESS("IN_PROGRESS"), + | CLOSE("CLOSE"); + | override fun toString(): String { + | return label + | } + |} + | + """.trimMargin(), + ) + + val res = emitContext.emitFirst(NodeFixtures.enum) + res shouldBe expected + } + + @Test + fun compileFullEndpointTest() { + val kotlin = """ + |package community.flock.wirespec.generated.endpoint + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |import community.flock.wirespec.generated.model.Token + |import community.flock.wirespec.generated.model.PotentialTodoDto + |import community.flock.wirespec.generated.model.TodoDto + |import community.flock.wirespec.generated.model.Error + |object PutTodo : Wirespec.Endpoint { + | data class Path( + | val id: String + | ) : Wirespec.Path + | data class Queries( + | val done: Boolean, + | val name: String? + | ) : Wirespec.Queries + | data class RequestHeaders( + | val token: Token, + | val refreshToken: Token? + | ) : Wirespec.Request.Headers + | data class Request( + | override val path: Path, + | override val method: Wirespec.Method, + | override val queries: Queries, + | override val headers: RequestHeaders, + | override val body: PotentialTodoDto + | ) : Wirespec.Request { + | constructor(id: String, done: Boolean, name: String?, token: Token, refreshToken: Token?, body: PotentialTodoDto) : this(Path(id = id), Wirespec.Method.PUT, Queries( + | done = done, + | name = name + | ), RequestHeaders( + | token = token, + | refreshToken = refreshToken + | ), body) + | } + | sealed interface Response : Wirespec.Response + | sealed interface Response2XX : Response + | sealed interface Response5XX : Response + | sealed interface ResponseTodoDto : Response + | sealed interface ResponseError : Response + | data class Response200( + | override val status: Int, + | override val headers: Headers, + | override val body: TodoDto + | ) : Response2XX, ResponseTodoDto { + | constructor(body: TodoDto) : this(200, Headers, body) + | object Headers : Wirespec.Response.Headers + | } + | data class Response201( + | override val status: Int, + | override val headers: Headers, + | override val body: TodoDto + | ) : Response2XX, ResponseTodoDto { + | constructor(token: Token, refreshToken: Token?, body: TodoDto) : this(201, Headers( + | token = token, + | refreshToken = refreshToken + | ), body) + | data class Headers( + | val token: Token, + | val refreshToken: Token? + | ) : Wirespec.Response.Headers + | } + | data class Response500( + | override val status: Int, + | override val headers: Headers, + | override val body: Error + | ) : Response5XX, ResponseError { + | constructor(body: Error) : this(500, Headers, body) + | object Headers : Wirespec.Response.Headers + | } + | fun toRawRequest(serialization: Wirespec.Serializer, request: Request): Wirespec.RawRequest = + | Wirespec.RawRequest( + | method = request.method.name, + | path = listOf("todos", serialization.serializePath(request.path.id, typeOf())), + | queries = mapOf("done" to serialization.serializeParam(request.queries.done, typeOf()), "name" to (request.queries.name?.let { serialization.serializeParam(it, typeOf()) } ?: emptyList())), + | headers = mapOf("token" to serialization.serializeParam(request.headers.token, typeOf()), "Refresh-Token" to (request.headers.refreshToken?.let { serialization.serializeParam(it, typeOf()) } ?: emptyList())), + | body = serialization.serializeBody(request.body, typeOf()) + | ) + | fun fromRawRequest(serialization: Wirespec.Deserializer, request: Wirespec.RawRequest): Request = + | Request( + | id = serialization.deserializePath(request.path[1], typeOf()), + | done = (request.queries["done"]?.let { serialization.deserializeParam(it, typeOf()) } ?: error("Param done cannot be null")), + | name = (request.queries["name"]?.let { serialization.deserializeParam(it, typeOf()) }), + | token = (request.headers.entries.find { it.key.equals("token", ignoreCase = true) }?.value?.let { serialization.deserializeParam(it, typeOf()) } ?: error("Param token cannot be null")), + | refreshToken = (request.headers.entries.find { it.key.equals("Refresh-Token", ignoreCase = true) }?.value?.let { serialization.deserializeParam(it, typeOf()) }), + | body = (request.body?.let { serialization.deserializeBody(it, typeOf()) } ?: error("body is null")) + | ) + | fun toRawResponse(serialization: Wirespec.Serializer, response: Response<*>): Wirespec.RawResponse { + | when(val r = response) { + | is Response200 -> { + | return Wirespec.RawResponse( + | statusCode = r.status, + | headers = emptyMap(), + | body = serialization.serializeBody(r.body, typeOf()) + | ) + | } + | is Response201 -> { + | return Wirespec.RawResponse( + | statusCode = r.status, + | headers = mapOf("token" to serialization.serializeParam(r.headers.token, typeOf()), "refreshToken" to (r.headers.refreshToken?.let { serialization.serializeParam(it, typeOf()) } ?: emptyList())), + | body = serialization.serializeBody(r.body, typeOf()) + | ) + | } + | is Response500 -> { + | return Wirespec.RawResponse( + | statusCode = r.status, + | headers = emptyMap(), + | body = serialization.serializeBody(r.body, typeOf()) + | ) + | } + | else -> { + | error(("Cannot match response with status: " + response.status)) + | } + | } + | } + | fun fromRawResponse(serialization: Wirespec.Deserializer, response: Wirespec.RawResponse): Response<*> { + | when (response.statusCode) { + | 200 -> { + | return Response200(body = (response.body?.let { serialization.deserializeBody(it, typeOf()) } ?: error("body is null"))) + | } + | 201 -> { + | return Response201( + | token = (response.headers.entries.find { it.key.equals("token", ignoreCase = true) }?.value?.let { serialization.deserializeParam(it, typeOf()) } ?: error("Param token cannot be null")), + | refreshToken = (response.headers.entries.find { it.key.equals("refreshToken", ignoreCase = true) }?.value?.let { serialization.deserializeParam(it, typeOf()) }), + | body = (response.body?.let { serialization.deserializeBody(it, typeOf()) } ?: error("body is null")) + | ) + | } + | 500 -> { + | return Response500(body = (response.body?.let { serialization.deserializeBody(it, typeOf()) } ?: error("body is null"))) + | } + | else -> { + | error(("Cannot match response with status: " + response.statusCode)) + | } + | } + | } + | interface Handler : Wirespec.Handler { + | suspend fun putTodo(request: Request): Response<*> + | companion object: Wirespec.Server>, Wirespec.Client> { + | override val pathTemplate = "/todos/{id}" + | override val method = "PUT" + | override fun server(serialization: Wirespec.Serialization) = object : Wirespec.ServerEdge> { + | override fun from(request: Wirespec.RawRequest) = fromRawRequest(serialization, request) + | override fun to(response: Response<*>) = toRawResponse(serialization, response) + | } + | override fun client(serialization: Wirespec.Serialization) = object : Wirespec.ClientEdge> { + | override fun to(request: Request) = toRawRequest(serialization, request) + | override fun from(response: Wirespec.RawResponse) = fromRawResponse(serialization, response) + | } + | } + | } + | interface Call : Wirespec.Call { + | suspend fun putTodo(id: String, done: Boolean, name: String?, token: Token, refreshToken: Token?, body: PotentialTodoDto): Response<*> + | } + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class PotentialTodoDto( + | val name: String, + | val done: Boolean + |) : Wirespec.Model { + | override fun validate(): List = + | emptyList() + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class Token( + | val iss: String + |) : Wirespec.Model { + | override fun validate(): List = + | emptyList() + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class TodoDto( + | val id: String, + | val name: String, + | val done: Boolean + |) : Wirespec.Model { + | override fun validate(): List = + | emptyList() + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class Error( + | val code: Long, + | val description: String + |) : Wirespec.Model { + | override fun validate(): List = + | emptyList() + |} + | + |package community.flock.wirespec.generated.client + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |import community.flock.wirespec.generated.model.Token + |import community.flock.wirespec.generated.model.PotentialTodoDto + |import community.flock.wirespec.generated.model.TodoDto + |import community.flock.wirespec.generated.model.Error + |import community.flock.wirespec.generated.endpoint.PutTodo + |data class PutTodoClient( + | val serialization: Wirespec.Serialization, + | val transportation: Wirespec.Transportation + |) : PutTodo.Call { + | override suspend fun putTodo(id: String, done: Boolean, name: String?, token: Token, refreshToken: Token?, body: PotentialTodoDto): PutTodo.Response<*> { + | val request = PutTodo.Request( + | id = id, + | done = done, + | name = name, + | token = token, + | refreshToken = refreshToken, + | body = body + | ) + | val rawRequest = PutTodo.toRawRequest(serialization, request) + | val rawResponse = transportation.transport(rawRequest) + | return PutTodo.fromRawResponse(serialization, rawResponse) + | } + |} + | + |package community.flock.wirespec.generated + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |import community.flock.wirespec.generated.model.Token + |import community.flock.wirespec.generated.model.PotentialTodoDto + |import community.flock.wirespec.generated.model.TodoDto + |import community.flock.wirespec.generated.model.Error + |import community.flock.wirespec.generated.endpoint.PutTodo + |import community.flock.wirespec.generated.client.PutTodoClient + |data class Client( + | val serialization: Wirespec.Serialization, + | val transportation: Wirespec.Transportation + |) : PutTodo.Call { + | override suspend fun putTodo(id: String, done: Boolean, name: String?, token: Token, refreshToken: Token?, body: PotentialTodoDto): PutTodo.Response<*> = + | PutTodoClient( + | serialization = serialization, + | transportation = transportation + | ).putTodo(id, done, name, token, refreshToken, body) + |} + | + """.trimMargin() + + CompileFullEndpointTest.compiler { KotlinIrEmitter() } shouldBeRight kotlin + } + + @Test + fun compileChannelTest() { + val kotlin = """ + |package community.flock.wirespec.generated.channel + |interface Queue : Wirespec.Channel { + | fun invoke(message: String) + |} + | + """.trimMargin() + + CompileChannelTest.compiler { KotlinIrEmitter() } shouldBeRight kotlin + } + + @Test + fun compileEnumTest() { + val kotlin = """ + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |enum class MyAwesomeEnum (override val label: String): Wirespec.Enum { + | ONE("ONE"), + | Two("Two"), + | THREE_MORE("THREE_MORE"), + | UnitedKingdom("UnitedKingdom"), + | _1("-1"), + | _0("0"), + | _10("10"), + | _999("-999"), + | _88("88"); + | override fun toString(): String { + | return label + | } + |} + | + """.trimMargin() + + CompileEnumTest.compiler { KotlinIrEmitter() } shouldBeRight kotlin + } + + @Test + fun compileMinimalEndpointTest() { + val kotlin = """ + |package community.flock.wirespec.generated.endpoint + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |import community.flock.wirespec.generated.model.TodoDto + |object GetTodos : Wirespec.Endpoint { + | object Path : Wirespec.Path + | object Queries : Wirespec.Queries + | object RequestHeaders : Wirespec.Request.Headers + | data object Request : Wirespec.Request { + | override val path: Path = Path + | override val method: Wirespec.Method = Wirespec.Method.GET + | override val queries: Queries = Queries + | override val headers: RequestHeaders = RequestHeaders + | override val body: Unit = Unit } + | sealed interface Response : Wirespec.Response + | sealed interface Response2XX : Response + | sealed interface ResponseListTodoDto : Response> + | data class Response200( + | override val status: Int, + | override val headers: Headers, + | override val body: List + | ) : Response2XX>, ResponseListTodoDto { + | constructor(body: List) : this(200, Headers, body) + | object Headers : Wirespec.Response.Headers + | } + | fun toRawRequest(serialization: Wirespec.Serializer, request: Request): Wirespec.RawRequest = + | Wirespec.RawRequest( + | method = request.method.name, + | path = listOf("todos"), + | queries = emptyMap(), + | headers = emptyMap(), + | body = null + | ) + | fun fromRawRequest(serialization: Wirespec.Deserializer, request: Wirespec.RawRequest): Request = + | Request + | fun toRawResponse(serialization: Wirespec.Serializer, response: Response<*>): Wirespec.RawResponse { + | when(val r = response) { + | is Response200 -> { + | return Wirespec.RawResponse( + | statusCode = r.status, + | headers = emptyMap(), + | body = serialization.serializeBody(r.body, typeOf>()) + | ) + | } + | else -> { + | error(("Cannot match response with status: " + response.status)) + | } + | } + | } + | fun fromRawResponse(serialization: Wirespec.Deserializer, response: Wirespec.RawResponse): Response<*> { + | when (response.statusCode) { + | 200 -> { + | return Response200(body = (response.body?.let { serialization.deserializeBody>(it, typeOf>()) } ?: error("body is null"))) + | } + | else -> { + | error(("Cannot match response with status: " + response.statusCode)) + | } + | } + | } + | interface Handler : Wirespec.Handler { + | suspend fun getTodos(request: Request): Response<*> + | companion object: Wirespec.Server>, Wirespec.Client> { + | override val pathTemplate = "/todos" + | override val method = "GET" + | override fun server(serialization: Wirespec.Serialization) = object : Wirespec.ServerEdge> { + | override fun from(request: Wirespec.RawRequest) = fromRawRequest(serialization, request) + | override fun to(response: Response<*>) = toRawResponse(serialization, response) + | } + | override fun client(serialization: Wirespec.Serialization) = object : Wirespec.ClientEdge> { + | override fun to(request: Request) = toRawRequest(serialization, request) + | override fun from(response: Wirespec.RawResponse) = fromRawResponse(serialization, response) + | } + | } + | } + | interface Call : Wirespec.Call { + | suspend fun getTodos(): Response<*> + | } + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class TodoDto( + | val description: String + |) : Wirespec.Model { + | override fun validate(): List = + | emptyList() + |} + | + |package community.flock.wirespec.generated.client + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |import community.flock.wirespec.generated.model.TodoDto + |import community.flock.wirespec.generated.endpoint.GetTodos + |data class GetTodosClient( + | val serialization: Wirespec.Serialization, + | val transportation: Wirespec.Transportation + |) : GetTodos.Call { + | override suspend fun getTodos(): GetTodos.Response<*> { + | val request = GetTodos.Request + | val rawRequest = GetTodos.toRawRequest(serialization, request) + | val rawResponse = transportation.transport(rawRequest) + | return GetTodos.fromRawResponse(serialization, rawResponse) + | } + |} + | + |package community.flock.wirespec.generated + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |import community.flock.wirespec.generated.model.TodoDto + |import community.flock.wirespec.generated.endpoint.GetTodos + |import community.flock.wirespec.generated.client.GetTodosClient + |data class Client( + | val serialization: Wirespec.Serialization, + | val transportation: Wirespec.Transportation + |) : GetTodos.Call { + | override suspend fun getTodos(): GetTodos.Response<*> = + | GetTodosClient( + | serialization = serialization, + | transportation = transportation + | ).getTodos() + |} + | + """.trimMargin() + + CompileMinimalEndpointTest.compiler { KotlinIrEmitter() } shouldBeRight kotlin + } + + @Test + fun compileRefinedTest() { + val kotlin = """ + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class TodoId( + | override val value: String + |) : Wirespec.Refined { + | override fun toString(): String = + | value + | override fun validate(): Boolean = + | Regex(${"\"\"\""}^[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}${'$'}${"\"\"\""}).matches(value) + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class TodoNoRegex( + | override val value: String + |) : Wirespec.Refined { + | override fun toString(): String = + | value + | override fun validate(): Boolean = + | true + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class TestInt( + | override val value: Long + |) : Wirespec.Refined { + | override fun toString(): String = + | value.toString() + | override fun validate(): Boolean = + | true + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class TestInt0( + | override val value: Long + |) : Wirespec.Refined { + | override fun toString(): String = + | value.toString() + | override fun validate(): Boolean = + | true + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class TestInt1( + | override val value: Long + |) : Wirespec.Refined { + | override fun toString(): String = + | value.toString() + | override fun validate(): Boolean = + | 0 <= value + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class TestInt2( + | override val value: Long + |) : Wirespec.Refined { + | override fun toString(): String = + | value.toString() + | override fun validate(): Boolean = + | 1 <= value && value <= 3 + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class TestNum( + | override val value: Double + |) : Wirespec.Refined { + | override fun toString(): String = + | value.toString() + | override fun validate(): Boolean = + | true + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class TestNum0( + | override val value: Double + |) : Wirespec.Refined { + | override fun toString(): String = + | value.toString() + | override fun validate(): Boolean = + | true + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class TestNum1( + | override val value: Double + |) : Wirespec.Refined { + | override fun toString(): String = + | value.toString() + | override fun validate(): Boolean = + | value <= 0.5 + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class TestNum2( + | override val value: Double + |) : Wirespec.Refined { + | override fun toString(): String = + | value.toString() + | override fun validate(): Boolean = + | -0.2 <= value && value <= 0.5 + |} + | + """.trimMargin() + + CompileRefinedTest.compiler { KotlinIrEmitter() } shouldBeRight kotlin + } + + @Test + fun compileUnionTest() { + val kotlin = """ + |package community.flock.wirespec.generated.model + |sealed interface UserAccount + | + |package community.flock.wirespec.generated.model + |data class UserAccountPassword( + | val username: String, + | val password: String + |) : Wirespec.Model, UserAccount { + | override fun validate(): List = + | emptyList() + |} + | + |package community.flock.wirespec.generated.model + |data class UserAccountToken( + | val token: String + |) : Wirespec.Model, UserAccount { + | override fun validate(): List = + | emptyList() + |} + | + |package community.flock.wirespec.generated.model + |data class User( + | val username: String, + | val account: UserAccount + |) : Wirespec.Model { + | override fun validate(): List = + | emptyList() + |} + | + """.trimMargin() + + CompileUnionTest.compiler { KotlinIrEmitter() } shouldBeRight kotlin + } + + @Test + fun compileTypeTest() { + val kotlin = """ + |package community.flock.wirespec.generated.model + |data class Request( + | val type: String, + | val url: String, + | val BODY_TYPE: String?, + | val params: List, + | val headers: Map, + | val body: Map?>? + |) : Wirespec.Model { + | override fun validate(): List = + | emptyList() + |} + | + """.trimMargin() + + CompileTypeTest.compiler { KotlinIrEmitter() } shouldBeRight kotlin + } + + @Test + fun compileNestedTypeTest() { + val kotlin = """ + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class DutchPostalCode( + | override val value: String + |) : Wirespec.Refined { + | override fun toString(): String = + | value + | override fun validate(): Boolean = + | Regex(${"\"\"\""}^([0-9]{4}[A-Z]{2})${'$'}${"\"\"\""}).matches(value) + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class Address( + | val street: String, + | val houseNumber: Long, + | val postalCode: DutchPostalCode + |) : Wirespec.Model { + | override fun validate(): List = + | if (!postalCode.validate()) listOf("postalCode") else emptyList() + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class Person( + | val name: String, + | val address: Address, + | val tags: List + |) : Wirespec.Model { + | override fun validate(): List = + | address.validate().map { e -> "address.${'$'}{e}" } + |} + | + """.trimMargin() + + CompileNestedTypeTest.compiler { KotlinIrEmitter() } shouldBeRight kotlin + } + + @Test + fun compileComplexModelTest() { + val kotlin = """ + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class Email( + | override val value: String + |) : Wirespec.Refined { + | override fun toString(): String = + | value + | override fun validate(): Boolean = + | Regex(${"\"\"\""}^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}${'$'}${"\"\"\""}).matches(value) + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class PhoneNumber( + | override val value: String + |) : Wirespec.Refined { + | override fun toString(): String = + | value + | override fun validate(): Boolean = + | Regex(${"\"\"\""}^\+[1-9]\d{1,14}${'$'}${"\"\"\""}).matches(value) + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class Tag( + | override val value: String + |) : Wirespec.Refined { + | override fun toString(): String = + | value + | override fun validate(): Boolean = + | Regex(${"\"\"\""}^[a-z][a-z0-9-]{0,19}${'$'}${"\"\"\""}).matches(value) + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class EmployeeAge( + | override val value: Long + |) : Wirespec.Refined { + | override fun toString(): String = + | value.toString() + | override fun validate(): Boolean = + | 18 <= value && value <= 65 + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class ContactInfo( + | val email: Email, + | val phone: PhoneNumber? + |) : Wirespec.Model { + | override fun validate(): List = + | (if (!email.validate()) listOf("email") else emptyList()) + (phone?.let { if (!it.validate()) listOf("phone") else emptyList() } ?: emptyList()) + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class Employee( + | val name: String, + | val age: EmployeeAge, + | val contactInfo: ContactInfo, + | val tags: List + |) : Wirespec.Model { + | override fun validate(): List = + | (if (!age.validate()) listOf("age") else emptyList()) + contactInfo.validate().map { e -> "contactInfo.${'$'}{e}" } + tags.flatMapIndexed { i, el -> if (!el.validate()) listOf("tags[${'$'}{i}]") else emptyList() } + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class Department( + | val name: String, + | val employees: List + |) : Wirespec.Model { + | override fun validate(): List = + | employees.flatMapIndexed { i, el -> el.validate().map { e -> "employees[${'$'}{i}].${'$'}{e}" } } + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.kotlin.Wirespec + |import kotlin.reflect.typeOf + |data class Company( + | val name: String, + | val departments: List + |) : Wirespec.Model { + | override fun validate(): List = + | departments.flatMapIndexed { i, el -> el.validate().map { e -> "departments[${'$'}{i}].${'$'}{e}" } } + |} + | + """.trimMargin() + + CompileComplexModelTest.compiler { KotlinIrEmitter() } shouldBeRight kotlin + } + + @Test + fun sharedOutputTest() { + val expected = """ + |package community.flock.wirespec.kotlin + |import kotlin.reflect.KType + |object Wirespec { + | interface Model { + | fun validate(): List + | } + | interface Enum { + | val label: String + | } + | interface Endpoint + | interface Channel + | interface Refined { + | val value: T + | fun validate(): Boolean + | } + | interface Path + | interface Queries + | interface Headers + | interface Handler + | interface Call + | enum class Method { + | GET, + | PUT, + | POST, + | DELETE, + | OPTIONS, + | HEAD, + | PATCH, + | TRACE + | } interface Request { + | val path: Path + | val method: Method + | val queries: Queries + | val headers: Headers + | val body: T + | interface Headers + | } + | interface Response { + | val status: Int + | val headers: Headers + | val body: T + | interface Headers + | } + | interface BodySerializer { + | fun serializeBody(t: T, type: KType): ByteArray + | } + | interface BodyDeserializer { + | fun deserializeBody(raw: ByteArray, type: KType): T + | } + | interface BodySerialization : BodySerializer, BodyDeserializer + | interface PathSerializer { + | fun serializePath(t: T, type: KType): String + | } + | interface PathDeserializer { + | fun deserializePath(raw: String, type: KType): T + | } + | interface PathSerialization : PathSerializer, PathDeserializer + | interface ParamSerializer { + | fun serializeParam(value: T, type: KType): List + | } + | interface ParamDeserializer { + | fun deserializeParam(values: List, type: KType): T + | } + | interface ParamSerialization : ParamSerializer, ParamDeserializer + | interface Serializer : BodySerializer, PathSerializer, ParamSerializer + | interface Deserializer : BodyDeserializer, PathDeserializer, ParamDeserializer + | interface Serialization : Serializer, Deserializer + | data class RawRequest( + | val method: String, + | val path: List, + | val queries: Map>, + | val headers: Map>, + | val body: ByteArray? + | ) + | data class RawResponse( + | val statusCode: Int, + | val headers: Map>, + | val body: ByteArray? + | ) + | interface Transportation { + | suspend fun transport(request: RawRequest): RawResponse + | } + | interface ServerEdge, Res: Response<*>> { + | fun from(request: RawRequest): Req + | fun to(response: Res): RawResponse + | } + | interface ClientEdge, Res: Response<*>> { + | fun to(request: Req): RawRequest + | fun from(response: RawResponse): Res + | } + | interface Client, Res: Response<*>> { + | val pathTemplate: String + | val method: String + | fun client(serialization: Serialization): ClientEdge + | } + | interface Server, Res: Response<*>> { + | val pathTemplate: String + | val method: String + | fun server(serialization: Serialization): ServerEdge + | } + |} + | + """.trimMargin() + + val emitter = KotlinIrEmitter() + emitter.shared.source shouldBe expected + } + + private fun EmitContext.emitFirst(node: Definition) = emitters.map { + val ast = AST( + nonEmptyListOf( + Module( + FileUri(""), + nonEmptyListOf(node), + ), + ), + ) + it.emit(ast, logger).first().result + } +} diff --git a/src/compiler/emitters/python/build.gradle.kts b/src/compiler/emitters/python/build.gradle.kts index 217e6b788..68b5de636 100644 --- a/src/compiler/emitters/python/build.gradle.kts +++ b/src/compiler/emitters/python/build.gradle.kts @@ -41,6 +41,7 @@ kotlin { commonMain { dependencies { api(project(":src:compiler:core")) + api(project(":src:compiler:ir")) } } commonTest { diff --git a/src/compiler/emitters/python/src/commonMain/kotlin/community/flock/wirespec/emitters/python/PythonIrEmitter.kt b/src/compiler/emitters/python/src/commonMain/kotlin/community/flock/wirespec/emitters/python/PythonIrEmitter.kt new file mode 100644 index 000000000..6e88ceab9 --- /dev/null +++ b/src/compiler/emitters/python/src/commonMain/kotlin/community/flock/wirespec/emitters/python/PythonIrEmitter.kt @@ -0,0 +1,430 @@ +package community.flock.wirespec.emitters.python + +import arrow.core.NonEmptyList +import arrow.core.toNonEmptyListOrNull +import community.flock.wirespec.compiler.core.emit.DEFAULT_GENERATED_PACKAGE_STRING +import community.flock.wirespec.compiler.core.emit.EmitShared +import community.flock.wirespec.compiler.core.emit.FileExtension +import community.flock.wirespec.ir.emit.IrEmitter +import community.flock.wirespec.compiler.core.emit.Keywords +import community.flock.wirespec.compiler.core.emit.LanguageEmitter.Companion.firstToUpper +import community.flock.wirespec.compiler.core.emit.PackageName +import community.flock.wirespec.compiler.core.emit.Shared +import community.flock.wirespec.compiler.core.emit.importReferences +import community.flock.wirespec.compiler.core.emit.plus +import community.flock.wirespec.compiler.core.parse.ast.Channel +import community.flock.wirespec.compiler.core.parse.ast.Definition +import community.flock.wirespec.compiler.core.parse.ast.DefinitionIdentifier +import community.flock.wirespec.compiler.core.parse.ast.Endpoint +import community.flock.wirespec.compiler.core.parse.ast.Enum +import community.flock.wirespec.compiler.core.parse.ast.FieldIdentifier +import community.flock.wirespec.compiler.core.parse.ast.Identifier +import community.flock.wirespec.compiler.core.parse.ast.Model +import community.flock.wirespec.compiler.core.parse.ast.Module +import community.flock.wirespec.compiler.core.parse.ast.Reference +import community.flock.wirespec.compiler.core.parse.ast.Refined +import community.flock.wirespec.compiler.core.parse.ast.Type +import community.flock.wirespec.compiler.core.parse.ast.Union +import community.flock.wirespec.compiler.utils.Logger +import community.flock.wirespec.ir.converter.convert +import community.flock.wirespec.ir.converter.convertConstraint +import community.flock.wirespec.ir.converter.convertWithValidation +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Element +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.Import +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Interface +import community.flock.wirespec.ir.core.Parameter +import community.flock.wirespec.ir.core.RawElement +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.Namespace +import community.flock.wirespec.ir.core.Struct +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.findElement +import community.flock.wirespec.ir.core.flattenNestedStructs +import community.flock.wirespec.ir.core.function +import community.flock.wirespec.ir.core.transform +import community.flock.wirespec.ir.core.transformChildren +import community.flock.wirespec.ir.generator.PythonGenerator +import community.flock.wirespec.ir.generator.generatePython +import community.flock.wirespec.compiler.core.parse.ast.Shared as AstShared +import community.flock.wirespec.ir.core.Enum as LanguageEnum +import community.flock.wirespec.ir.core.Function as LanguageFunction +import community.flock.wirespec.ir.core.File as LanguageFile +import community.flock.wirespec.ir.core.Type as LanguageType +import community.flock.wirespec.ir.core.Union as LanguageUnion + +open class PythonIrEmitter( + private val packageName: PackageName = PackageName(DEFAULT_GENERATED_PACKAGE_STRING), + private val emitShared: EmitShared = EmitShared() +) : IrEmitter { + + override val generator = PythonGenerator + + override val extension = FileExtension.Python + + private val sharedSource = """ + |from __future__ import annotations + | + |import enum + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, Optional, Type, TypeVar + | + |T = TypeVar('T') + | + | + |def _raise(msg: str) -> Any: + | raise Exception(msg) + | + | + """.trimMargin() + + override val shared = object : Shared { + override val packageString = "shared" + override val source = sharedSource + AstShared(packageString).convert() + .generatePython() + } + + override fun emit(module: Module, logger: Logger): NonEmptyList { + val statements = module.statements.sortedBy(::sort).toNonEmptyListOrNull()!! + return super.emit(module.copy(statements = statements), logger).let { + fun emitInitImport(def: Definition) = Import(".${def.identifier.sanitize()}", LanguageType.Custom(def.identifier.sanitize())) + val hasEndpoints = module.statements.any { it is Endpoint } + val initElements: List = listOf( + Import(".", LanguageType.Custom("model")), + Import(".", LanguageType.Custom("endpoint")), + ) + (if (hasEndpoints) listOf(Import(".", LanguageType.Custom("client"))) else emptyList()) + + listOf(Import(".", LanguageType.Custom("wirespec"))) + val init = File( + Name.of(packageName.toDir() + "__init__"), + initElements + ) + val initEndpoint = File( + Name.of(packageName.toDir() + "endpoint/" + "__init__"), + module.statements.filter { it is Endpoint }.map { stmt -> emitInitImport(stmt) } + ) + val initModel = File( + Name.of(packageName.toDir() + "model/" + "__init__"), + module.statements.filter { it is Model }.map { stmt -> emitInitImport(stmt) } + ) + val initClient = if (hasEndpoints) listOf(File( + Name.of(packageName.toDir() + "client/" + "__init__"), + emptyList() + )) else emptyList() + val shared = File(Name.of(packageName.toDir() + "wirespec"), listOf(RawElement(shared.source))) + val parentInits = packageName.value.split(".") + .dropLast(1) + .runningFold("") { acc, segment -> if (acc.isEmpty()) segment else "$acc/$segment" } + .drop(1) + .map { File(Name.of("$it/__init__"), emptyList()) } + if (emitShared.value) + it + init + initEndpoint + initModel + initClient + shared + parentInits + else + it + init + parentInits + } + } + + override fun emit(definition: Definition, module: Module, logger: Logger): File { + val file = super.emit(definition, module, logger) + val subPackageName = packageName + definition + return File( + name = Name.of(subPackageName.toDir() + file.name.pascalCase()), + elements = buildImports("..wirespec") + file.elements + ) + } + + override fun emit(type: Type, module: Module): File { + val typeImports = type.importReferences().distinctBy { it.value } + .map { Import(".${it.value}", LanguageType.Custom(it.value)) } + val fieldNames = type.shape.value.map { it.identifier.value }.toSet() + val file = type.convertWithValidation(module) + .transform { + matchingElements { fn: LanguageFunction -> + if (fn.name == Name.of("validate")) { + fn.copy( + parameters = listOf(Parameter(Name.of("self"), LanguageType.Custom(""))), + ).transform { + statementAndExpression { s, t -> + if (s is FieldCall && s.receiver == null && s.field.camelCase() in fieldNames) { + FieldCall(receiver = VariableReference(Name.of("self")), field = s.field) + } else { + s.transformChildren(t) + } + } + } + } else fn + } + } + .sanitizeNames() + return if (typeImports.isNotEmpty()) file.copy(elements = typeImports + file.elements) + else file + } + + override fun emit(enum: Enum, module: Module): File = enum + .convert() + .transform { + matchingElements { languageEnum: LanguageEnum -> + languageEnum.copy( + entries = languageEnum.entries.map { + LanguageEnum.Entry(Name.of(it.name.value().sanitizeEnum().sanitizeKeywords()), listOf("\"${it.name.value()}\"")) + }, + ) + } + } + .sanitizeNames() + + override fun emit(union: Union): File = + union.convert() + .sanitizeNames() + + override fun emit(refined: Refined): File { + val file = refined.convert() + val struct = file.findElement()!! + val constraintExpr = refined.reference.convertConstraint(FieldCall(VariableReference(Name.of("self")), Name.of("value"))) + val validate = function("validate") { + arg("self", LanguageType.Custom("")) + returnType(LanguageType.Boolean) + returns(constraintExpr) + } + val toStringExpr = when (refined.reference.type) { + is Reference.Primitive.Type.String -> "self.value" + else -> "str(self.value)" + } + val toString = function("__str__") { + arg("self", LanguageType.Custom("")) + returnType(LanguageType.String) + returns(RawExpression(toStringExpr)) + } + return file + .transform { + matchingElements { s: Struct -> + s.copy(elements = listOf(validate, toString)) + } + } + .sanitizeNames() + } + + override fun emit(endpoint: Endpoint): File { + val endpointImports = endpoint.importReferences().distinctBy { it.value } + .map { Import("..model.${it.value}", LanguageType.Custom(it.value)) } + val converted = endpoint.convert().findElement()!! + val flattened = converted.flattenNestedStructs() + val (moduleElements, classElements) = flattened.elements.partition { it is Struct || it is LanguageUnion } + val endpointClass = Namespace( + name = converted.name, + elements = classElements, + extends = converted.extends, + ) + return LanguageFile(converted.name, buildList { + addAll(endpointImports) + addAll(moduleElements) + add(endpointClass) + }) + .sanitizeNames() + .snakeCaseHandlerAndCallMethods() + } + + override fun emit(channel: Channel): File = + channel.convert() + .sanitizeNames() + + override fun emitEndpointClient(endpoint: Endpoint): File { + val modelImports = endpoint.importReferences().distinctBy { it.value } + .map { Import("..model.${it.value}", LanguageType.Custom(it.value)) } + val endpointImport = Import("..endpoint.${endpoint.identifier.value}", LanguageType.Custom("*")) + val endpointName = endpoint.identifier.value + + val file = super.emitEndpointClient(endpoint) + .sanitizeNames() + .addSelfReceiverToClientFields() + .snakeCaseClientFunctions() + .flattenEndpointTypeRefs(endpointName) + + val subPackageName = packageName + "client" + return File( + name = Name.of(subPackageName.toDir() + file.name.pascalCase()), + elements = buildImports("..wirespec") + + modelImports + + listOf(endpointImport) + + file.elements + ) + } + + override fun emitClient(endpoints: List, logger: Logger): File { + val modelImports = endpoints.flatMap { it.importReferences() }.distinctBy { it.value } + .map { Import(".model.${it.value}", LanguageType.Custom(it.value)) } + val endpointImports = endpoints.map { Import(".endpoint.${it.identifier.value}", LanguageType.Custom("*")) } + val clientImports = endpoints.map { Import(".client.${it.identifier.value}Client", LanguageType.Custom("${it.identifier.value}Client")) } + val allImports = modelImports + endpointImports + clientImports + val endpointNames = endpoints.map { it.identifier.value } + + val file = super.emitClient(endpoints, logger) + .sanitizeNames() + .addSelfReceiverToClientFields() + .snakeCaseClientFunctions() + .let { f -> endpointNames.fold(f) { acc, name -> acc.flattenEndpointTypeRefs(name) } } + + return File( + name = Name.of(packageName.toDir() + file.name.pascalCase()), + elements = buildImports(".wirespec") + + allImports + + file.elements + ) + } + + private fun T.sanitizeNames(): T = transform { + fields { field -> + field.copy(name = field.name.sanitizeName()) + } + parameters { param -> + param.copy(name = Name.of(param.name.camelCase().sanitizeKeywords())) + } + statementAndExpression { stmt, tr -> + when (stmt) { + is FieldCall -> FieldCall( + receiver = stmt.receiver?.let { tr.transformExpression(it) }, + field = stmt.field.sanitizeName(), + ) + is ConstructorStatement -> ConstructorStatement( + type = tr.transformType(stmt.type), + namedArguments = stmt.namedArguments + .map { (k, v) -> k.sanitizeName() to tr.transformExpression(v) } + .toMap(), + ) + else -> stmt.transformChildren(tr) + } + } + } + + private fun Name.sanitizeName(): Name { + val sanitized = if (parts.size > 1) camelCase() else value() + return Name(listOf(sanitized.sanitizeKeywords())) + } + + private fun Identifier.sanitize(): String = value + .split(".", " ") + .mapIndexed { index, s -> if (index > 0) s.firstToUpper() else s } + .joinToString("") + .filter { it.isLetterOrDigit() || it == '_' } + .let { if (it.firstOrNull()?.isDigit() == true) "_$it" else it } + .let { if (this is FieldIdentifier) it.sanitizeKeywords() else it } + + private fun String.sanitizeKeywords() = if (this in reservedKeywords) "_$this" else this + + private fun String.sanitizeEnum() = split("-", ", ", ".", " ", "//").joinToString("_") + .let { if (it.firstOrNull()?.isDigit() == true) "_$it" else it } + + private fun sort(definition: Definition) = when (definition) { + is Enum -> 1 + is Refined -> 2 + is Type -> 3 + is Union -> 4 + is Endpoint -> 5 + is Channel -> 6 + } + + private fun buildImports(wirespecPath: String): List = listOf( + Import("__future__", LanguageType.Custom("annotations")), + RawElement("import re"), + Import("abc", LanguageType.Custom("ABC")), + Import("abc", LanguageType.Custom("abstractmethod")), + Import("dataclasses", LanguageType.Custom("dataclass")), + Import("typing", LanguageType.Custom("Any")), + Import("typing", LanguageType.Custom("Generic")), + Import("typing", LanguageType.Custom("List")), + Import("typing", LanguageType.Custom("Optional")), + RawElement("import enum"), + Import(wirespecPath, LanguageType.Custom("T")), + Import(wirespecPath, LanguageType.Custom("Wirespec")), + Import(wirespecPath, LanguageType.Custom("_raise")), + ) + + private fun T.snakeCaseHandlerAndCallMethods(): T = transform { + matchingElements { iface: Interface -> + if (iface.name == Name.of("Handler") || iface.name == Name.of("Call")) { + iface.copy( + elements = iface.elements.map { element -> + if (element is LanguageFunction) { + element.copy(name = Name.of(element.name.snakeCase())) + } else element + }, + ) + } else iface + } + } + + private fun T.flattenEndpointTypeRefs(endpointName: String): T = transform { + type { type, _ -> + if (type is LanguageType.Custom && type.name.startsWith("$endpointName.")) { + val suffix = type.name.removePrefix("$endpointName.") + if (suffix == "Call" || suffix == "Handler") type + else type.copy(name = suffix) + } else type + } + } + + private fun T.addSelfReceiverToClientFields(): T { + val struct = (this as? File)?.findElement() + val fieldNames = struct?.fields?.map { it.name.value() }?.toSet() ?: emptySet() + if (fieldNames.isEmpty()) return this + + return transform { + statementAndExpression { stmt, tr -> + when (stmt) { + is FieldCall -> { + if (stmt.receiver == null && stmt.field.value() in fieldNames) { + FieldCall(receiver = VariableReference(Name.of("self")), field = stmt.field) + } else { + FieldCall( + receiver = stmt.receiver?.let { tr.transformExpression(it) }, + field = stmt.field, + ) + } + } + else -> stmt.transformChildren(tr) + } + } + } + } + + private fun T.snakeCaseClientFunctions(): T = transform { + matchingElements { func: LanguageFunction -> + func.copy( + name = Name.of(func.name.snakeCase()), + parameters = listOf(Parameter(Name.of("self"), LanguageType.Custom(""))) + func.parameters, + ) + } + statementAndExpression { stmt, tr -> + when (stmt) { + is FunctionCall -> { + val nameStr = stmt.name.value() + val newName = if ("." in nameStr) stmt.name else Name.of(Name.of(nameStr).snakeCase()) + FunctionCall( + name = newName, + receiver = stmt.receiver?.let { tr.transformExpression(it) }, + arguments = stmt.arguments.mapValues { (_, v) -> tr.transformExpression(v) }, + isAwait = stmt.receiver != null, + ) + } + else -> stmt.transformChildren(tr) + } + } + } + + companion object : Keywords { + override val reservedKeywords = setOf( + "False", "None", "True", "and", "as", "assert", + "break", "class", "continue", "def", "del", + "elif", "else", "except", "finally", "for", + "from", "global", "if", "import", "in", + "is", "lambda", "nonlocal", "not", "or", + "pass", "raise", "return", "try", "while", + "with", "yield" + ) + } + +} diff --git a/src/compiler/emitters/python/src/commonTest/kotlin/community/flock/wirespec/emitters/python/PythonIrEmitterTest.kt b/src/compiler/emitters/python/src/commonTest/kotlin/community/flock/wirespec/emitters/python/PythonIrEmitterTest.kt new file mode 100644 index 000000000..32db15f82 --- /dev/null +++ b/src/compiler/emitters/python/src/commonTest/kotlin/community/flock/wirespec/emitters/python/PythonIrEmitterTest.kt @@ -0,0 +1,1095 @@ +package community.flock.wirespec.emitters.python + +import community.flock.wirespec.compiler.test.CompileChannelTest +import community.flock.wirespec.compiler.test.CompileComplexModelTest +import community.flock.wirespec.compiler.test.CompileEnumTest +import community.flock.wirespec.compiler.test.CompileFullEndpointTest +import community.flock.wirespec.compiler.test.CompileMinimalEndpointTest +import community.flock.wirespec.compiler.test.CompileNestedTypeTest +import community.flock.wirespec.compiler.test.CompileRefinedTest +import community.flock.wirespec.compiler.test.CompileTypeTest +import community.flock.wirespec.compiler.test.CompileUnionTest +import io.kotest.assertions.arrow.core.shouldBeRight +import io.kotest.matchers.shouldBe +import kotlin.test.Test + +class PythonIrEmitterTest { + + @Test + fun compileFullEndpointTest() { + val python = """ + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class PotentialTodoDto(Wirespec.Model): + | name: str + | done: bool + | def validate(self) -> list[str]: + | return [] + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class Token(Wirespec.Model): + | iss: str + | def validate(self) -> list[str]: + | return [] + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class TodoDto(Wirespec.Model): + | id: str + | name: str + | done: bool + | def validate(self) -> list[str]: + | return [] + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class Error(Wirespec.Model): + | code: int + | description: str + | def validate(self) -> list[str]: + | return [] + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |from ..model.Token import Token + |from ..model.PotentialTodoDto import PotentialTodoDto + |from ..model.TodoDto import TodoDto + |from ..model.Error import Error + |@dataclass + |class Path(Wirespec.Path): + | id: str + |@dataclass + |class Queries(Wirespec.Queries): + | done: bool + | name: Optional[str] + |@dataclass + |class RequestHeaders(Wirespec.Request.Headers): + | token: Token + | refreshToken: Optional[Token] + |@dataclass + |class Request(Wirespec.Request[PotentialTodoDto]): + | path: Path + | method: Wirespec.Method + | queries: Queries + | headers: RequestHeaders + | body: PotentialTodoDto + | def __init__( + | self, + | id: str, + | done: bool, + | name: Optional[str], + | token: Token, + | refreshToken: Optional[Token], + | body: PotentialTodoDto, + | ): + | self.path = Path(id=id) + | self.method = Wirespec.Method.PUT + | self.queries = Queries(done=done, name=name) + | self.headers = RequestHeaders(token=token, refreshToken=refreshToken) + | self.body = body + |class Response(Wirespec.Response[T], Generic[T]): + | pass + |class Response2XX(Response[T], Generic[T]): + | pass + |class Response5XX(Response[T], Generic[T]): + | pass + |class ResponseTodoDto(Response[TodoDto]): + | pass + |class ResponseError(Response[Error]): + | pass + |@dataclass + |class Response200Headers(Wirespec.Response.Headers): + | pass + |@dataclass + |class Response200(Response2XX[TodoDto], ResponseTodoDto): + | status: int + | headers: Response200Headers + | body: TodoDto + | def __init__( + | self, + | body: TodoDto, + | ): + | self.status = 200 + | self.headers = Response200Headers() + | self.body = body + |@dataclass + |class Response201Headers(Wirespec.Response.Headers): + | token: Token + | refreshToken: Optional[Token] + |@dataclass + |class Response201(Response2XX[TodoDto], ResponseTodoDto): + | status: int + | headers: Response201Headers + | body: TodoDto + | def __init__( + | self, + | token: Token, + | refreshToken: Optional[Token], + | body: TodoDto, + | ): + | self.status = 201 + | self.headers = Response201Headers(token=token, refreshToken=refreshToken) + | self.body = body + |@dataclass + |class Response500Headers(Wirespec.Response.Headers): + | pass + |@dataclass + |class Response500(Response5XX[Error], ResponseError): + | status: int + | headers: Response500Headers + | body: Error + | def __init__( + | self, + | body: Error, + | ): + | self.status = 500 + | self.headers = Response500Headers() + | self.body = body + |class PutTodo(Wirespec.Endpoint): + | @staticmethod + | def toRawRequest(serialization: Wirespec.Serializer, request: Request) -> Wirespec.RawRequest: + | return Wirespec.RawRequest(method=request.method.value, path=['todos', serialization.serializePath(request.path.id, str)], queries={'done': serialization.serializeParam(request.queries.done, bool), 'name': serialization.serializeParam(request.queries.name, str) if request.queries.name is not None else []}, headers={'token': serialization.serializeParam(request.headers.token, Token), 'Refresh-Token': serialization.serializeParam(request.headers.refreshToken, Token) if request.headers.refreshToken is not None else []}, body=serialization.serializeBody(request.body, PotentialTodoDto)) + | @staticmethod + | def fromRawRequest(serialization: Wirespec.Deserializer, request: Wirespec.RawRequest) -> Request: + | return Request(id=serialization.deserializePath(request.path[1], str), done=serialization.deserializeParam(request.queries['done'], bool) if request.queries['done'] is not None else _raise('Param done cannot be null'), name=serialization.deserializeParam(request.queries['name'], str) if request.queries['name'] is not None else None, token=serialization.deserializeParam(next((v for k, v in request.headers.items() if k.lower() == 'token'.lower()), None), Token) if next((v for k, v in request.headers.items() if k.lower() == 'token'.lower()), None) is not None else _raise('Param token cannot be null'), refreshToken=serialization.deserializeParam(next((v for k, v in request.headers.items() if k.lower() == 'Refresh-Token'.lower()), None), Token) if next((v for k, v in request.headers.items() if k.lower() == 'Refresh-Token'.lower()), None) is not None else None, body=serialization.deserializeBody(request.body, PotentialTodoDto) if request.body is not None else _raise('body is null')) + | @staticmethod + | def toRawResponse(serialization: Wirespec.Serializer, response: Response[Any]) -> Wirespec.RawResponse: + | match response: + | case Response200() as r: + | return Wirespec.RawResponse(statusCode=r.status, headers={}, body=serialization.serializeBody(r.body, TodoDto)) + | case Response201() as r: + | return Wirespec.RawResponse(statusCode=r.status, headers={'token': serialization.serializeParam(r.headers.token, Token), 'refreshToken': serialization.serializeParam(r.headers.refreshToken, Token) if r.headers.refreshToken is not None else []}, body=serialization.serializeBody(r.body, TodoDto)) + | case Response500() as r: + | return Wirespec.RawResponse(statusCode=r.status, headers={}, body=serialization.serializeBody(r.body, Error)) + | case _: + | raise Exception(('Cannot match response with status: ' + str(response.status))) + | @staticmethod + | def fromRawResponse(serialization: Wirespec.Deserializer, response: Wirespec.RawResponse) -> Response[Any]: + | match response.statusCode: + | case 200: + | return Response200(body=serialization.deserializeBody(response.body, TodoDto) if response.body is not None else _raise('body is null')) + | case 201: + | return Response201(token=serialization.deserializeParam(next((v for k, v in response.headers.items() if k.lower() == 'token'.lower()), None), Token) if next((v for k, v in response.headers.items() if k.lower() == 'token'.lower()), None) is not None else _raise('Param token cannot be null'), refreshToken=serialization.deserializeParam(next((v for k, v in response.headers.items() if k.lower() == 'refreshToken'.lower()), None), Token) if next((v for k, v in response.headers.items() if k.lower() == 'refreshToken'.lower()), None) is not None else None, body=serialization.deserializeBody(response.body, TodoDto) if response.body is not None else _raise('body is null')) + | case 500: + | return Response500(body=serialization.deserializeBody(response.body, Error) if response.body is not None else _raise('body is null')) + | case _: + | raise Exception(('Cannot match response with status: ' + str(response.statusCode))) + | class Handler(Wirespec.Handler, ABC): + | @abstractmethod + | async def put_todo(self, request: Request) -> Response[Any]: + | ... + | class Call(Wirespec.Call, ABC): + | @abstractmethod + | async def put_todo(self, id: str, done: bool, name: Optional[str], token: Token, refreshToken: Optional[Token], body: PotentialTodoDto) -> Response[Any]: + | ... + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |from ..model.Token import Token + |from ..model.PotentialTodoDto import PotentialTodoDto + |from ..model.TodoDto import TodoDto + |from ..model.Error import Error + |from ..endpoint.PutTodo import * + |@dataclass + |class PutTodoClient(PutTodo.Call): + | serialization: Wirespec.Serialization + | transportation: Wirespec.Transportation + | async def put_todo(self, id: str, done: bool, name: Optional[str], token: Token, refreshToken: Optional[Token], body: PotentialTodoDto) -> Response[Any]: + | request = Request(id=id, done=done, name=name, token=token, refreshToken=refreshToken, body=body) + | rawRequest = PutTodo.toRawRequest(serialization=self.serialization, request=request) + | rawResponse = await self.transportation.transport(rawRequest) + | return PutTodo.fromRawResponse(serialization=self.serialization, response=rawResponse) + | + |from . import model + |from . import endpoint + |from . import client + |from . import wirespec + | + | + | + | + | + | + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from .wirespec import T, Wirespec, _raise + |from .model.Token import Token + |from .model.PotentialTodoDto import PotentialTodoDto + |from .model.TodoDto import TodoDto + |from .model.Error import Error + |from .endpoint.PutTodo import * + |from .client.PutTodoClient import PutTodoClient + |@dataclass + |class Client(PutTodo.Call): + | serialization: Wirespec.Serialization + | transportation: Wirespec.Transportation + | async def put_todo(self, id: str, done: bool, name: Optional[str], token: Token, refreshToken: Optional[Token], body: PotentialTodoDto) -> Response[Any]: + | return await PutTodoClient(serialization=self.serialization, transportation=self.transportation).put_todo(id, done, name, token, refreshToken, body) + | + """.trimMargin() + + CompileFullEndpointTest.compiler { PythonIrEmitter() } shouldBeRight python + } + + @Test + fun compileChannelTest() { + val python = """ + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |class Queue(Wirespec.Channel, ABC): + | @abstractmethod + | def invoke(self, message: str) -> None: + | ... + | + |from . import model + |from . import endpoint + |from . import wirespec + | + | + | + | + | + | + | + """.trimMargin() + + CompileChannelTest.compiler { PythonIrEmitter() } shouldBeRight python + } + + @Test + fun compileEnumTest() { + val python = """ + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |class MyAwesomeEnum(Wirespec.Enum, enum.Enum): + | ONE = "ONE" + | Two = "Two" + | THREE_MORE = "THREE_MORE" + | UnitedKingdom = "UnitedKingdom" + | _1 = "-1" + | _0 = "0" + | _10 = "10" + | _999 = "-999" + | _88 = "88" + | + |from . import model + |from . import endpoint + |from . import wirespec + | + | + | + | + | + | + | + """.trimMargin() + + CompileEnumTest.compiler { PythonIrEmitter() } shouldBeRight python + } + + @Test + fun compileMinimalEndpointTest() { + val python = """ + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class TodoDto(Wirespec.Model): + | description: str + | def validate(self) -> list[str]: + | return [] + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |from ..model.TodoDto import TodoDto + |@dataclass + |class Path(Wirespec.Path): + | pass + |@dataclass + |class Queries(Wirespec.Queries): + | pass + |@dataclass + |class RequestHeaders(Wirespec.Request.Headers): + | pass + |@dataclass + |class Request(Wirespec.Request[None]): + | path: Path + | method: Wirespec.Method + | queries: Queries + | headers: RequestHeaders + | body: None + | def __init__(self): + | self.path = Path() + | self.method = Wirespec.Method.GET + | self.queries = Queries() + | self.headers = RequestHeaders() + | self.body = None + |class Response(Wirespec.Response[T], Generic[T]): + | pass + |class Response2XX(Response[T], Generic[T]): + | pass + |class ResponseListTodoDto(Response[list[TodoDto]]): + | pass + |@dataclass + |class Response200Headers(Wirespec.Response.Headers): + | pass + |@dataclass + |class Response200(Response2XX[list[TodoDto]], ResponseListTodoDto): + | status: int + | headers: Response200Headers + | body: list[TodoDto] + | def __init__( + | self, + | body: list[TodoDto], + | ): + | self.status = 200 + | self.headers = Response200Headers() + | self.body = body + |class GetTodos(Wirespec.Endpoint): + | @staticmethod + | def toRawRequest(serialization: Wirespec.Serializer, request: Request) -> Wirespec.RawRequest: + | return Wirespec.RawRequest(method=request.method.value, path=['todos'], queries={}, headers={}, body=None) + | @staticmethod + | def fromRawRequest(serialization: Wirespec.Deserializer, request: Wirespec.RawRequest) -> Request: + | return Request() + | @staticmethod + | def toRawResponse(serialization: Wirespec.Serializer, response: Response[Any]) -> Wirespec.RawResponse: + | match response: + | case Response200() as r: + | return Wirespec.RawResponse(statusCode=r.status, headers={}, body=serialization.serializeBody(r.body, list[TodoDto])) + | case _: + | raise Exception(('Cannot match response with status: ' + str(response.status))) + | @staticmethod + | def fromRawResponse(serialization: Wirespec.Deserializer, response: Wirespec.RawResponse) -> Response[Any]: + | match response.statusCode: + | case 200: + | return Response200(body=serialization.deserializeBody(response.body, list[TodoDto]) if response.body is not None else _raise('body is null')) + | case _: + | raise Exception(('Cannot match response with status: ' + str(response.statusCode))) + | class Handler(Wirespec.Handler, ABC): + | @abstractmethod + | async def get_todos(self, request: Request) -> Response[Any]: + | ... + | class Call(Wirespec.Call, ABC): + | @abstractmethod + | async def get_todos(self) -> Response[Any]: + | ... + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |from ..model.TodoDto import TodoDto + |from ..endpoint.GetTodos import * + |@dataclass + |class GetTodosClient(GetTodos.Call): + | serialization: Wirespec.Serialization + | transportation: Wirespec.Transportation + | async def get_todos(self) -> Response[Any]: + | request = Request() + | rawRequest = GetTodos.toRawRequest(serialization=self.serialization, request=request) + | rawResponse = await self.transportation.transport(rawRequest) + | return GetTodos.fromRawResponse(serialization=self.serialization, response=rawResponse) + | + |from . import model + |from . import endpoint + |from . import client + |from . import wirespec + | + | + | + | + | + | + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from .wirespec import T, Wirespec, _raise + |from .model.TodoDto import TodoDto + |from .endpoint.GetTodos import * + |from .client.GetTodosClient import GetTodosClient + |@dataclass + |class Client(GetTodos.Call): + | serialization: Wirespec.Serialization + | transportation: Wirespec.Transportation + | async def get_todos(self) -> Response[Any]: + | return await GetTodosClient(serialization=self.serialization, transportation=self.transportation).get_todos() + | + """.trimMargin() + + CompileMinimalEndpointTest.compiler { PythonIrEmitter() } shouldBeRight python + } + + @Test + fun compileRefinedTest() { + val python = """ + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class TodoId(Wirespec.Refined[str]): + | value: str + | def validate(self) -> bool: + | return bool(re.match(r"/^[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}${'$'}/g", self.value)) + | def __str__(self) -> str: + | return self.value + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class TodoNoRegex(Wirespec.Refined[str]): + | value: str + | def validate(self) -> bool: + | return True + | def __str__(self) -> str: + | return self.value + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class TestInt(Wirespec.Refined[int]): + | value: int + | def validate(self) -> bool: + | return True + | def __str__(self) -> str: + | return str(self.value) + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class TestInt0(Wirespec.Refined[int]): + | value: int + | def validate(self) -> bool: + | return True + | def __str__(self) -> str: + | return str(self.value) + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class TestInt1(Wirespec.Refined[int]): + | value: int + | def validate(self) -> bool: + | return 0 <= self.value + | def __str__(self) -> str: + | return str(self.value) + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class TestInt2(Wirespec.Refined[int]): + | value: int + | def validate(self) -> bool: + | return 1 <= self.value and self.value <= 3 + | def __str__(self) -> str: + | return str(self.value) + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class TestNum(Wirespec.Refined[float]): + | value: float + | def validate(self) -> bool: + | return True + | def __str__(self) -> str: + | return str(self.value) + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class TestNum0(Wirespec.Refined[float]): + | value: float + | def validate(self) -> bool: + | return True + | def __str__(self) -> str: + | return str(self.value) + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class TestNum1(Wirespec.Refined[float]): + | value: float + | def validate(self) -> bool: + | return self.value <= 0.5 + | def __str__(self) -> str: + | return str(self.value) + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class TestNum2(Wirespec.Refined[float]): + | value: float + | def validate(self) -> bool: + | return -0.2 <= self.value and self.value <= 0.5 + | def __str__(self) -> str: + | return str(self.value) + | + |from . import model + |from . import endpoint + |from . import wirespec + | + | + | + | + | + | + | + """.trimMargin() + + CompileRefinedTest.compiler { PythonIrEmitter() } shouldBeRight python + } + + @Test + fun compileUnionTest() { + val python = """ + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class UserAccountPassword(Wirespec.Model, UserAccount): + | username: str + | password: str + | def validate(self) -> list[str]: + | return [] + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class UserAccountToken(Wirespec.Model, UserAccount): + | token: str + | def validate(self) -> list[str]: + | return [] + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |from .UserAccount import UserAccount + |@dataclass + |class User(Wirespec.Model): + | username: str + | account: UserAccount + | def validate(self) -> list[str]: + | return [] + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |class UserAccount: + | pass + | + |from . import model + |from . import endpoint + |from . import wirespec + | + | + | + | + | + | + | + """.trimMargin() + + CompileUnionTest.compiler { PythonIrEmitter() } shouldBeRight python + } + + @Test + fun compileTypeTest() { + val python = """ + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class Request(Wirespec.Model): + | type: str + | url: str + | BODY_TYPE: Optional[str] + | params: list[str] + | headers: dict[str, str] + | body: Optional[dict[str, Optional[list[Optional[str]]]]] + | def validate(self) -> list[str]: + | return [] + | + |from . import model + |from . import endpoint + |from . import wirespec + | + | + | + | + | + | + | + """.trimMargin() + + CompileTypeTest.compiler { PythonIrEmitter() } shouldBeRight python + } + + @Test + fun compileNestedTypeTest() { + val python = """ + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class DutchPostalCode(Wirespec.Refined[str]): + | value: str + | def validate(self) -> bool: + | return bool(re.match(r"/^([0-9]{4}[A-Z]{2})${'$'}/g", self.value)) + | def __str__(self) -> str: + | return self.value + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |from .DutchPostalCode import DutchPostalCode + |@dataclass + |class Address(Wirespec.Model): + | street: str + | houseNumber: int + | postalCode: DutchPostalCode + | def validate(self) -> list[str]: + | return (['postalCode'] if not self.postalCode.validate() else []) + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |from .Address import Address + |@dataclass + |class Person(Wirespec.Model): + | name: str + | address: Address + | tags: list[str] + | def validate(self) -> list[str]: + | return [f"address.{e}" for e in self.address.validate()] + | + |from . import model + |from . import endpoint + |from . import wirespec + | + | + | + | + | + | + | + """.trimMargin() + + CompileNestedTypeTest.compiler { PythonIrEmitter() } shouldBeRight python + } + + @Test + fun compileComplexModelTest() { + val python = """ + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class Email(Wirespec.Refined[str]): + | value: str + | def validate(self) -> bool: + | return bool(re.match(r"/^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}${'$'}/g", self.value)) + | def __str__(self) -> str: + | return self.value + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class PhoneNumber(Wirespec.Refined[str]): + | value: str + | def validate(self) -> bool: + | return bool(re.match(r"/^\+[1-9]\d{1,14}${'$'}/g", self.value)) + | def __str__(self) -> str: + | return self.value + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class Tag(Wirespec.Refined[str]): + | value: str + | def validate(self) -> bool: + | return bool(re.match(r"/^[a-z][a-z0-9-]{0,19}${'$'}/g", self.value)) + | def __str__(self) -> str: + | return self.value + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |@dataclass + |class EmployeeAge(Wirespec.Refined[int]): + | value: int + | def validate(self) -> bool: + | return 18 <= self.value and self.value <= 65 + | def __str__(self) -> str: + | return str(self.value) + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |from .Email import Email + |from .PhoneNumber import PhoneNumber + |@dataclass + |class ContactInfo(Wirespec.Model): + | email: Email + | phone: Optional[PhoneNumber] + | def validate(self) -> list[str]: + | return (['email'] if not self.email.validate() else []) + (['phone'] if not self.phone.validate() else []) if self.phone is not None else [] + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |from .EmployeeAge import EmployeeAge + |from .ContactInfo import ContactInfo + |from .Tag import Tag + |@dataclass + |class Employee(Wirespec.Model): + | name: str + | age: EmployeeAge + | contactInfo: ContactInfo + | tags: list[Tag] + | def validate(self) -> list[str]: + | return (['age'] if not self.age.validate() else []) + [f"contactInfo.{e}" for e in self.contactInfo.validate()] + [item for i, el in enumerate(self.tags) for item in ([f"tags[{i}]"] if not el.validate() else [])] + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |from .Employee import Employee + |@dataclass + |class Department(Wirespec.Model): + | name: str + | employees: list[Employee] + | def validate(self) -> list[str]: + | return [item for i, el in enumerate(self.employees) for item in [f"employees[{i}].{e}" for e in el.validate()]] + | + |from __future__ import annotations + |import re + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, List, Optional + |import enum + |from ..wirespec import T, Wirespec, _raise + |from .Department import Department + |@dataclass + |class Company(Wirespec.Model): + | name: str + | departments: list[Department] + | def validate(self) -> list[str]: + | return [item for i, el in enumerate(self.departments) for item in [f"departments[{i}].{e}" for e in el.validate()]] + | + |from . import model + |from . import endpoint + |from . import wirespec + | + | + | + | + | + | + | + """.trimMargin() + + CompileComplexModelTest.compiler { PythonIrEmitter() } shouldBeRight python + } + + @Test + fun sharedOutputTest() { + val expected = """ + |from __future__ import annotations + | + |import enum + |from abc import ABC, abstractmethod + |from dataclasses import dataclass + |from typing import Any, Generic, Optional, Type, TypeVar + | + |T = TypeVar('T') + | + | + |def _raise(msg: str) -> Any: + | raise Exception(msg) + | + |# package shared + |class Wirespec: + | class Model(ABC): + | @abstractmethod + | def validate(self) -> list[str]: + | ... + | class Enum(ABC): + | label: str + | class Endpoint(ABC): + | pass + | class Channel(ABC): + | pass + | class Refined(ABC, Generic[T]): + | value: T + | @abstractmethod + | def validate(self) -> bool: + | ... + | class Path(ABC): + | pass + | class Queries(ABC): + | pass + | class Headers(ABC): + | pass + | class Handler(ABC): + | pass + | class Call(ABC): + | pass + | class Method(enum.Enum): + | GET = "GET" + | PUT = "PUT" + | POST = "POST" + | DELETE = "DELETE" + | OPTIONS = "OPTIONS" + | HEAD = "HEAD" + | PATCH = "PATCH" + | TRACE = "TRACE" + | class Request(ABC, Generic[T]): + | path: Wirespec.Path + | method: Wirespec.Method + | queries: Wirespec.Queries + | headers: Headers + | body: T + | class Headers(ABC): + | pass + | class Response(ABC, Generic[T]): + | status: int + | headers: Headers + | body: T + | class Headers(ABC): + | pass + | class BodySerializer(ABC): + | @abstractmethod + | def serializeBody(self, t: T, type: type[T]) -> bytes: + | ... + | class BodyDeserializer(ABC): + | @abstractmethod + | def deserializeBody(self, raw: bytes, type: type[T]) -> T: + | ... + | class BodySerialization(BodySerializer, BodyDeserializer, ABC): + | pass + | class PathSerializer(ABC): + | @abstractmethod + | def serializePath(self, t: T, type: type[T]) -> str: + | ... + | class PathDeserializer(ABC): + | @abstractmethod + | def deserializePath(self, raw: str, type: type[T]) -> T: + | ... + | class PathSerialization(PathSerializer, PathDeserializer, ABC): + | pass + | class ParamSerializer(ABC): + | @abstractmethod + | def serializeParam(self, value: T, type: type[T]) -> list[str]: + | ... + | class ParamDeserializer(ABC): + | @abstractmethod + | def deserializeParam(self, values: list[str], type: type[T]) -> T: + | ... + | class ParamSerialization(ParamSerializer, ParamDeserializer, ABC): + | pass + | class Serializer(BodySerializer, PathSerializer, ParamSerializer, ABC): + | pass + | class Deserializer(BodyDeserializer, PathDeserializer, ParamDeserializer, ABC): + | pass + | class Serialization(Serializer, Deserializer, ABC): + | pass + | @dataclass + | class RawRequest: + | method: str + | path: list[str] + | queries: dict[str, list[str]] + | headers: dict[str, list[str]] + | body: Optional[bytes] + | @dataclass + | class RawResponse: + | statusCode: int + | headers: dict[str, list[str]] + | body: Optional[bytes] + | class Transportation(ABC): + | @abstractmethod + | async def transport(self, request: Wirespec.RawRequest) -> Wirespec.RawResponse: + | ... + | + """.trimMargin() + + val emitter = PythonIrEmitter() + emitter.shared.source shouldBe expected + } +} diff --git a/src/compiler/emitters/rust/build.gradle.kts b/src/compiler/emitters/rust/build.gradle.kts new file mode 100644 index 000000000..68b5de636 --- /dev/null +++ b/src/compiler/emitters/rust/build.gradle.kts @@ -0,0 +1,55 @@ +plugins { + id("module.publication") + id("module.spotless") + alias(libs.plugins.kotlin.multiplatform) + alias(libs.plugins.ksp) + alias(libs.plugins.kotest) +} + +group = "${libs.versions.group.id.get()}.compiler.emitters" +version = System.getenv(libs.versions.from.env.get()) ?: libs.versions.default.get() + +repositories { + mavenCentral() + mavenLocal() +} + +kotlin { + macosX64() + macosArm64() + linuxX64() + mingwX64() + js(IR) { + nodejs() + useEsModules() + } + jvm { + java { + toolchain { + languageVersion.set(JavaLanguageVersion.of(libs.versions.java.get())) + } + } + } + + sourceSets.all { + languageSettings.apply { + languageVersion = libs.versions.kotlin.compiler.get() + } + } + + sourceSets { + commonMain { + dependencies { + api(project(":src:compiler:core")) + api(project(":src:compiler:ir")) + } + } + commonTest { + dependencies { + implementation(libs.kotlin.test) + implementation(libs.bundles.kotest) + implementation(project(":src:compiler:test")) + } + } + } +} diff --git a/src/compiler/emitters/rust/src/commonMain/kotlin/community/flock/wirespec/emitters/rust/RustIrEmitter.kt b/src/compiler/emitters/rust/src/commonMain/kotlin/community/flock/wirespec/emitters/rust/RustIrEmitter.kt new file mode 100644 index 000000000..d9017a144 --- /dev/null +++ b/src/compiler/emitters/rust/src/commonMain/kotlin/community/flock/wirespec/emitters/rust/RustIrEmitter.kt @@ -0,0 +1,853 @@ +package community.flock.wirespec.emitters.rust + +import arrow.core.NonEmptyList +import arrow.core.toNonEmptyListOrNull +import community.flock.wirespec.compiler.core.emit.DEFAULT_GENERATED_PACKAGE_STRING +import community.flock.wirespec.compiler.core.emit.EmitShared +import community.flock.wirespec.compiler.core.emit.FileExtension +import community.flock.wirespec.ir.emit.IrEmitter +import community.flock.wirespec.compiler.core.emit.Keywords +import community.flock.wirespec.compiler.core.emit.LanguageEmitter.Companion.firstToUpper +import community.flock.wirespec.compiler.core.emit.PackageName +import community.flock.wirespec.compiler.core.emit.Shared +import community.flock.wirespec.compiler.core.emit.importReferences +import community.flock.wirespec.compiler.core.emit.plus +import community.flock.wirespec.compiler.core.parse.ast.Shared as AstShared +import community.flock.wirespec.compiler.core.parse.ast.Channel +import community.flock.wirespec.compiler.core.parse.ast.Definition +import community.flock.wirespec.compiler.core.parse.ast.Endpoint +import community.flock.wirespec.compiler.core.parse.ast.Enum +import community.flock.wirespec.compiler.core.parse.ast.FieldIdentifier +import community.flock.wirespec.compiler.core.parse.ast.Identifier +import community.flock.wirespec.compiler.core.parse.ast.Model +import community.flock.wirespec.compiler.core.parse.ast.Module +import community.flock.wirespec.compiler.core.parse.ast.Reference +import community.flock.wirespec.compiler.core.parse.ast.Refined +import community.flock.wirespec.compiler.core.parse.ast.Type +import community.flock.wirespec.compiler.core.parse.ast.Union +import community.flock.wirespec.compiler.utils.Logger +import community.flock.wirespec.ir.converter.requestParameters +import community.flock.wirespec.ir.converter.convert +import community.flock.wirespec.ir.converter.convertConstraint +import community.flock.wirespec.ir.converter.convertWithValidation +import community.flock.wirespec.ir.core.Case +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Element +import community.flock.wirespec.ir.core.ArrayIndexCall +import community.flock.wirespec.ir.core.Expression +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.Literal +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Interface +import community.flock.wirespec.ir.core.Parameter +import community.flock.wirespec.ir.core.RawElement +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.Namespace +import community.flock.wirespec.ir.core.Struct +import community.flock.wirespec.ir.core.Switch +import community.flock.wirespec.ir.core.Transformer +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.findElement +import community.flock.wirespec.ir.core.flattenNestedStructs +import community.flock.wirespec.ir.core.`interface` +import community.flock.wirespec.ir.core.function +import community.flock.wirespec.ir.core.transform +import community.flock.wirespec.ir.core.transformChildren +import community.flock.wirespec.ir.core.transformer +import community.flock.wirespec.ir.generator.RustGenerator +import community.flock.wirespec.ir.generator.generateRust +import community.flock.wirespec.ir.core.Enum as LanguageEnum +import community.flock.wirespec.ir.core.Function as LanguageFunction +import community.flock.wirespec.ir.core.File as LanguageFile +import community.flock.wirespec.ir.core.Type as LanguageType +import community.flock.wirespec.ir.core.Union as LanguageUnion + +private val selfParam = Parameter(Name.of("&self"), LanguageType.Custom("")) +private val RESPONSE_PATTERN = Regex("Response(\\d+|Default)") + +open class RustIrEmitter( + private val packageName: PackageName = PackageName(DEFAULT_GENERATED_PACKAGE_STRING), + private val emitShared: EmitShared = EmitShared() +) : IrEmitter { + + override val generator = RustGenerator + + override val extension = FileExtension.Rust + + override fun transformTestFile(file: File): File = file.transform { + apply(borrowSerializationArgs()) + apply(fixResponseSwitchPatterns()) + } + + private val modelImport = """ + |use super::super::wirespec::*; + |use regex; + | + """.trimMargin() + + private val endpointImport = """ + |use super::super::wirespec::*; + |use regex; + | + """.trimMargin() + + override val shared = object : Shared { + override val packageString = "shared" + + private val rustImports = listOf( + RawElement("use std::any::TypeId;\nuse std::collections::HashMap;"), + ) + + private val requestHeaders = `interface`("RequestHeaders") { + extends(LanguageType.Custom("Headers")) + } + + private val responseHeaders = `interface`("ResponseHeaders") { + extends(LanguageType.Custom("Headers")) + } + + private val client = RawElement( + """ + pub trait Client { + type Transport: Transportation; + type Ser: Serialization; + fn transport(&self) -> &Self::Transport; + fn serialization(&self) -> &Self::Ser; + } + """.trimIndent() + ) + + private val server = RawElement( + """ + pub trait Server { + type Req; + type Res; + fn path_template(&self) -> &'static str; + fn method(&self) -> Method; + } + """.trimIndent() + ) + + private val enumTrait = `interface`("Enum") { + extends(LanguageType.Custom("Sized")) + function("label") { + arg("&self", LanguageType.Custom("")) + returnType(LanguageType.Custom("&str")) + } + function("from_label", isStatic = true) { + arg("s", LanguageType.Custom("&str")) + returnType(LanguageType.Custom("Option")) + } + } + + private val refinedTrait = `interface`("Refined") { + typeParam(type("T")) + function("value") { + arg("&self", LanguageType.Custom("")) + returnType(type("T").borrow()) + } + function("validate") { + arg("&self", LanguageType.Custom("")) + returnType(boolean) + } + } + + private val requestTrait = `interface`("Request") { + typeParam(type("T")) + field("path", type("Path").borrowDyn()) + field("method", type("Method").borrow()) + field("queries", type("Queries").borrowDyn()) + field("headers", type("RequestHeaders").borrowDyn()) + field("body", type("T").borrow()) + } + + private val responseTrait = `interface`("Response") { + typeParam(type("T")) + field("status", integer32) + field("headers", type("ResponseHeaders").borrowDyn()) + field("body", type("T").borrow()) + } + + private val rawElementInterfaces = mapOf( + "BodySerializer" to RawElement("pub trait BodySerializer {\n fn serialize_body(&self, t: &T, r#type: TypeId) -> Vec;\n}"), + "BodyDeserializer" to RawElement("pub trait BodyDeserializer {\n fn deserialize_body(&self, raw: &[u8], r#type: TypeId) -> T;\n}"), + "PathSerializer" to RawElement("pub trait PathSerializer {\n fn serialize_path(&self, t: &T, r#type: TypeId) -> String;\n}"), + "PathDeserializer" to RawElement("pub trait PathDeserializer {\n fn deserialize_path(&self, raw: &str, r#type: TypeId) -> T where T::Err: std::fmt::Debug;\n}"), + "ParamSerializer" to RawElement("pub trait ParamSerializer {\n fn serialize_param(&self, value: &T, r#type: TypeId) -> Vec;\n}"), + "ParamDeserializer" to RawElement("pub trait ParamDeserializer {\n fn deserialize_param(&self, values: &[String], r#type: TypeId) -> T;\n}"), + ) + + private val transportationTrait = `interface`("Transportation") { + asyncFunction("transport") { + arg("&self", LanguageType.Custom("")) + arg("request", type("RawRequest").borrow()) + returnType(type("RawResponse")) + } + } + + private val dslTraits = mapOf( + "Enum" to enumTrait, + "Refined" to refinedTrait, + "Request" to requestTrait, + "Response" to responseTrait, + "Transportation" to transportationTrait, + ) + + private val wirespecFile = AstShared(packageString) + .convert() + .transform { + matchingElements { file: LanguageFile -> + val namespace = file.elements.filterIsInstance().first() + file.copy(elements = rustImports + namespace.elements) + } + matchingElements { enum: LanguageEnum -> + if (enum.name == Name.of("Method")) { + RawElement( + """ + #[derive(Debug, Clone, Default, PartialEq)] + pub enum Method { + #[default] + GET, + PUT, + POST, + DELETE, + OPTIONS, + HEAD, + PATCH, + TRACE, + } + """.trimIndent() + ) + } else enum + } + matchingElements { file: LanguageFile -> + val newElements = file.elements.flatMap { element -> + if (element is Interface) { + val name = element.name.pascalCase() + when { + name in dslTraits -> buildList { + add(dslTraits[name]!!) + if (name == "Request") add(requestHeaders) + if (name == "Response") add(responseHeaders) + } + name in rawElementInterfaces -> listOf(rawElementInterfaces[name]!!) + else -> listOf(element) + } + } else { + listOf(element) + } + } + client + server + file.copy(elements = newElements) + } + } + .let { file -> + LanguageFile(file.name, file.elements.flatMap { element -> + if (element is Struct) { + val derive = when (element.name.pascalCase()) { + "RawRequest", "RawResponse" -> "#[derive(Debug, Clone, PartialEq)]" + else -> "#[derive(Debug, Clone, Default, PartialEq)]" + } + listOf(LanguageFile(element.name, listOf(RawElement(derive), element))) + } else listOf(element) + }) + } + + override val source: String = wirespecFile + .transform { + matchingElements { iface: Interface -> + iface.transform { + matchingElements { fn: LanguageFunction -> + val hasSelf = fn.parameters.any { it.name.value() == "&self" || it.name.value() == "self" } + if (!hasSelf && !fn.isStatic) { + fn.copy(parameters = listOf(selfParam) + fn.parameters) + } else fn + } + } + } + } + .let { file -> + file.elements.joinToString("\n\n") { element -> + element.generateRust().trimEnd('\n') + } + "\n" + } + } + + override fun emit(module: Module, logger: Logger): NonEmptyList { + val statements = module.statements.sortedBy(::sort).toNonEmptyListOrNull()!! + return super.emit(module.copy(statements = statements), logger).let { files -> + fun emitMod(def: Definition) = "pub mod ${def.identifier.sanitize()};" + val endpoints = module.statements.filterIsInstance() + val endpointMods = endpoints.joinToString("\n") { emitMod(it) } + val clientMod = if (endpoints.isNotEmpty()) "\npub mod client;" else "" + val modRs = File( + Name.of(packageName.toDir() + "mod"), + listOf(RawElement("#![allow(warnings)]\npub mod model;\npub mod endpoint;${clientMod}\npub mod wirespec;")) + ) + val modEndpoint = File( + Name.of(packageName.toDir() + "endpoint/" + "mod"), + listOf(RawElement(endpointMods)) + ) + val modModel = File( + Name.of(packageName.toDir() + "model/" + "mod"), + listOf(RawElement(module.statements.filterIsInstance().joinToString("\n") { emitMod(it) })) + ) + val shared = File(Name.of(packageName.toDir() + "wirespec"), listOf(RawElement(shared.source))) + if (emitShared.value) + files + modRs + modEndpoint + modModel + shared + else + files + modRs + } + } + + override fun emit(definition: Definition, module: Module, logger: Logger): File { + val subPackageName = packageName + definition + val importHeader = when (definition) { + is Endpoint -> endpointImport + else -> modelImport + } + val file = super.emit(definition, module, logger) + return File( + name = Name.of(subPackageName.toDir() + file.name.pascalCase().toSnakeCase()), + elements = listOf(RawElement(importHeader)) + file.elements.flatMap { element -> + if (element is Struct) listOf(RawElement("#[derive(Debug, Clone, Default, PartialEq)]"), element) + else listOf(element) + } + ) + } + + override fun emit(type: Type, module: Module): File = + type.convertWithValidation(module) + .injectSelfReceiver(type.shape.value.map { it.identifier.value }.toSet()) + .sanitizeNames() + .prependImports(type.buildModelImports()) + + override fun emit(enum: Enum, module: Module): File = enum + .convert() + .transform { + matchingElements { languageEnum: LanguageEnum -> + languageEnum.copy( + entries = languageEnum.entries.map { + LanguageEnum.Entry(Name.of(it.name.value().sanitizeEnum().sanitizeKeywords()), listOf("\"${it.name.value()}\"")) + }, + ) + } + } + + override fun emit(union: Union): File = + union.convert() + + override fun emit(refined: Refined): File = + refined.convert() + .transform { + matchingElements { s: Struct -> + s.copy(elements = listOf(buildValidateFunction(refined), buildToStringFunction(refined))) + } + } + + override fun emit(endpoint: Endpoint): File = + endpoint.convert() + .flattenForRust() + .stripWirespecPrefix() + .rustifyEndpoint(endpoint) + .sanitizeNames() + .prependImports(endpoint.buildEndpointImports()) + + override fun emit(channel: Channel): File = + channel.convert() + + override fun emitEndpointClient(endpoint: Endpoint): File { + val endpointName = endpoint.identifier.value + val endpointModuleName = endpointName.toSnakeCase() + val clientName = "${endpointName}Client" + val methodName = endpointName.toSnakeCase() + val (paramsStr, requestArgs) = endpoint.buildClientParams() + val requestConstruction = "$endpointModuleName::Request::new($requestArgs)" + + val imports = endpoint.importReferences().distinctBy { it.value } + .joinToString("\n") { "use super::super::model::${it.value.toSnakeCase()}::${it.value};" } + val namespacePath = "$endpointModuleName::$endpointName" + val code = buildList { + add("use super::super::wirespec::*;") + add("use super::super::endpoint::$endpointModuleName;") + if (imports.isNotEmpty()) add(imports) + add("pub struct $clientName<'a, S: Serialization, T: Transportation> {") + add(" pub serialization: &'a S,") + add(" pub transportation: &'a T,") + add("}") + add("impl<'a, S: Serialization, T: Transportation> $namespacePath::Call for $clientName<'a, S, T> {") + add(" async fn $methodName(&self$paramsStr) -> $endpointModuleName::Response {") + add(" let request = $requestConstruction;") + add(" let raw_request = $namespacePath::to_raw_request(self.serialization, request);") + add(" let raw_response = self.transportation.transport(&raw_request).await;") + add(" $namespacePath::from_raw_response(self.serialization, raw_response)") + add(" }") + add("}") + }.joinToString("\n") + + val subPackageName = packageName + "client" + return File( + name = Name.of(subPackageName.toDir() + clientName.toSnakeCase()), + elements = listOf(RawElement(code)), + ) + } + + override fun emitClient(endpoints: List, logger: Logger): File { + logger.info("Emitting main Client for ${endpoints.size} endpoints") + + val modDeclarations = endpoints.joinToString("\n") { endpoint -> + "pub mod ${(endpoint.identifier.value + "Client").toSnakeCase()};" + } + + val modelImports = endpoints.flatMap { it.importReferences() }.distinctBy { it.value } + .filter { imp -> endpoints.none { it.identifier.value == imp.value } } + .map { "use super::model::${it.value.toSnakeCase()}::${it.value};" } + + val useStatements = endpoints.flatMap { endpoint -> + val endpointModuleName = endpoint.identifier.value.toSnakeCase() + val clientModuleName = "${endpoint.identifier.value}Client".toSnakeCase() + listOf( + "use super::endpoint::$endpointModuleName;", + "use ${clientModuleName}::${endpoint.identifier.value}Client;", + ) + } + + val implBlocks = endpoints.flatMap { endpoint -> + val endpointName = endpoint.identifier.value + val endpointModuleName = endpointName.toSnakeCase() + val namespacePath = "$endpointModuleName::$endpointName" + val methodName = endpointName.toSnakeCase() + val (paramsStr, callArgs) = endpoint.buildClientParams() + val delegateCall = if (callArgs.isNotEmpty()) { + "${endpointName}Client { serialization: &self.serialization, transportation: &self.transportation }\n .$methodName($callArgs).await" + } else { + "${endpointName}Client { serialization: &self.serialization, transportation: &self.transportation }\n .$methodName().await" + } + + listOf( + "impl $namespacePath::Call for Client {", + " async fn $methodName(&self$paramsStr) -> $endpointModuleName::Response {", + " $delegateCall", + " }", + "}", + ) + } + + val code = ( + listOf(modDeclarations) + + listOf("use super::wirespec::*;") + + modelImports + + useStatements + + listOf( + "pub struct Client {", + " pub serialization: S,", + " pub transportation: T,", + "}", + ) + + implBlocks + ).joinToString("\n") + + return File( + name = Name.of(packageName.toDir() + "client"), + elements = listOf(RawElement(code)), + ) + } + + private fun T.sanitizeNames(): T = transform { + parameter { param, _ -> + val name = param.name.value() + if (name == "self" || name == "&self") param + else param.copy(name = param.name.sanitizeName()) + } + statementAndExpression { stmt, tr -> + when (stmt) { + is FieldCall -> FieldCall( + receiver = stmt.receiver?.let { tr.transformExpression(it) }, + field = stmt.field.sanitizeName(), + ) + is ConstructorStatement -> ConstructorStatement( + type = tr.transformType(stmt.type), + namedArguments = stmt.namedArguments + .map { (k, v) -> k.sanitizeName() to tr.transformExpression(v) } + .toMap(), + ) + else -> stmt.transformChildren(tr) + } + } + } + + private fun Name.sanitizeName(): Name = Name.of(Name(parts).snakeCase().sanitizeKeywords()) + + private fun Identifier.sanitize(): String = value + .split(".", " ") + .mapIndexed { index, s -> if (index > 0) s.firstToUpper() else s } + .joinToString("") + .filter { it.isLetterOrDigit() || it == '_' } + .let { if (it.firstOrNull()?.isDigit() == true) "_$it" else it } + .let { if (this is FieldIdentifier) it.sanitizeKeywords() else it } + .toSnakeCase() + + private fun String.toSnakeCase(): String = Name.of(this).snakeCase() + + private fun String.toPascalCase(): String = split("_").joinToString("") { s -> + s.replaceFirstChar { it.uppercaseChar() } + } + + private fun String.sanitizeKeywords() = if (this in reservedKeywords) "r#$this" else this + + private fun String.sanitizeEnum() = split("-", ", ", ".", " ", "//").joinToString("_") + .toPascalCase() + .let { if (it.firstOrNull()?.isDigit() == true) "_$it" else it } + + private fun sort(definition: Definition) = when (definition) { + is Enum -> 1 + is Refined -> 2 + is Type -> 3 + is Union -> 4 + is Endpoint -> 5 + is Channel -> 6 + } + + private fun File.prependImports(imports: String): File = + if (imports.isNotEmpty()) copy(elements = listOf(RawElement(imports)) + elements) + else this + + private fun Type.buildModelImports(): String = + importReferences().distinctBy { it.value } + .joinToString("\n") { "use super::${it.value.toSnakeCase()}::${it.value};" } + + private fun Endpoint.buildEndpointImports(): String = + importReferences().distinctBy { it.value } + .joinToString("\n") { "use super::super::model::${it.value.toSnakeCase()}::${it.value};" } + + private fun T.injectSelfReceiver(fieldNames: Set): T = transform { + matchingElements { fn: LanguageFunction -> + if (fn.name == Name.of("validate")) { + fn.copy(parameters = listOf(selfParam)).transform { + statementAndExpression { s, t -> + if (s is FieldCall && s.receiver == null && s.field.camelCase() in fieldNames) { + FieldCall(receiver = VariableReference(Name.of("self")), field = s.field) + } else s.transformChildren(t) + } + } + } else fn + } + } + + private fun T.stripWirespecPrefix(): T = transform { + matching { type -> + if (type.name.startsWith("Wirespec.")) type.copy(name = type.name.removePrefix("Wirespec.")) + else type + } + } + + private fun buildValidateFunction(refined: Refined): LanguageFunction { + val constraintExpr = refined.reference.convertConstraint( + FieldCall(VariableReference(Name.of("self")), Name.of("value")) + ) + return function("validate") { + arg("&self", LanguageType.Custom("")) + returnType(LanguageType.Boolean) + returns(constraintExpr) + } + } + + private fun buildToStringFunction(refined: Refined): LanguageFunction { + val expr = when (refined.reference.type) { + is Reference.Primitive.Type.String -> "self.value.clone()" + else -> "format!(\"{}\", self.value)" + } + return function("to_string") { + arg("&self", LanguageType.Custom("")) + returnType(LanguageType.String) + returns(RawExpression(expr)) + } + } + + private fun LanguageType.toRustTypeString(): String = when (this) { + is LanguageType.String -> "String" + is LanguageType.Boolean -> "bool" + is LanguageType.Integer -> when (precision) { + community.flock.wirespec.ir.core.Precision.P32 -> "i32" + community.flock.wirespec.ir.core.Precision.P64 -> "i64" + } + is LanguageType.Number -> when (precision) { + community.flock.wirespec.ir.core.Precision.P32 -> "f32" + community.flock.wirespec.ir.core.Precision.P64 -> "f64" + } + is LanguageType.Bytes -> "Vec" + is LanguageType.Unit -> "()" + is LanguageType.Any -> "Box" + is LanguageType.Array -> "Vec<${elementType.toRustTypeString()}>" + is LanguageType.Dict -> "std::collections::HashMap<${keyType.toRustTypeString()}, ${valueType.toRustTypeString()}>" + is LanguageType.Nullable -> "Option<${type.toRustTypeString()}>" + is LanguageType.Custom -> name + is LanguageType.Wildcard -> "_" + is LanguageType.Reflect -> "std::any::TypeId" + is LanguageType.IntegerLiteral -> "i32" + is LanguageType.StringLiteral -> "String" + } + + private fun Endpoint.buildClientParams(): Pair { + val params = requestParameters() + val paramsStr = if (params.isNotEmpty()) { + ", " + params.joinToString(", ") { (name, type) -> + "${name.snakeCase().sanitizeKeywords()}: ${type.toRustTypeString()}" + } + } else "" + val argsStr = if (params.isNotEmpty()) { + params.joinToString(", ") { (name, _) -> name.snakeCase().sanitizeKeywords() } + } else "" + return paramsStr to argsStr + } + + private fun File.flattenForRust(): File { + val namespace = findElement()!! + val flattened = namespace.flattenNestedStructs() + + val moduleElements = flattened.elements + .filter { it is Struct || it is LanguageUnion } + .map { element -> + when { + element is LanguageUnion && element.name.pascalCase() == "Response" -> { + val members = flattened.elements + .filterIsInstance() + .map { it.name.pascalCase() } + .filter { RESPONSE_PATTERN.matches(it) } + .map { LanguageType.Custom(it) } + element.copy(members = members, typeParameters = emptyList()) + } + element is LanguageUnion -> element.copy(typeParameters = emptyList()) + else -> element + } + } + val classElements = flattened.elements.filterNot { it is Struct || it is LanguageUnion } + + return LanguageFile( + namespace.name, + moduleElements + Namespace(namespace.name, classElements, namespace.extends), + ) + } + + private val serializationMethodNames = setOf( + "serialize_path", "serialize_param", "serialize_body", + "deserialize_path", "deserialize_param", "deserialize_body", + ) + + private fun Expression.toRawCode(): String = when (this) { + is VariableReference -> name.snakeCase().sanitizeKeywords() + is FieldCall -> { + val recv = receiver?.let { "${it.toRawCode()}." } ?: "" + "$recv${field.snakeCase().sanitizeKeywords()}" + } + is ArrayIndexCall -> { + val lit = index as? Literal + val idx = when { + lit != null && (lit.type is LanguageType.Integer || lit.type is LanguageType.Number) -> "${lit.value}" + lit != null && lit.type is LanguageType.String -> "\"${lit.value}\"" + else -> index.toRawCode() + } + "${receiver.toRawCode()}[$idx]" + } + is ConstructorStatement -> { + val typeName = (type as? LanguageType.Custom)?.name ?: type.toString() + val args = namedArguments.entries.joinToString(", ") { "${it.key.snakeCase()}: ${it.value.toRawCode()}" } + if (args.isEmpty()) "$typeName {}" else "$typeName { $args }" + } + is Literal -> when { + type is LanguageType.String -> "String::from(\"$value\")" + type is LanguageType.Boolean -> "$value" + type is LanguageType.Integer -> "${value}" + type is LanguageType.Number -> "${value}" + else -> "$value" + } + is RawExpression -> code + else -> error("Unsupported expression type in toRawCode: ${this::class.simpleName}") + } + + private fun Expression.toBorrowedRaw(): RawExpression = RawExpression("&${toRawCode()}") + + private fun borrowSerializationArgs(): Transformer = transformer { + statementAndExpression { s, t -> + if (s is FunctionCall && s.name.snakeCase() in serializationMethodNames) { + val newArgs = s.arguments.entries.mapIndexed { idx, (key, value) -> + val transformed = t.transformExpression(value) + if (idx == 0 && !(transformed is VariableReference && transformed.name.value() == "it")) { + key to transformed.toBorrowedRaw() + } else { + key to transformed + } + }.toMap() + s.copy(arguments = newArgs) + } else s.transformChildren(t) + } + } + + private fun fixResponseSwitchPatterns(): Transformer = transformer { + statement { s, t -> + if (s is Switch && s.variable?.camelCase() == "r") { + val transformedCases = s.cases.map { case -> + val typeName = (case.type as? LanguageType.Custom)?.name + if (typeName != null && RESPONSE_PATTERN.matches(typeName)) { + Case( + value = RawExpression("Response::$typeName(${s.variable!!.snakeCase()})"), + body = case.body.map { t.transformStatement(it) }, + type = null, + ) + } else { + Case( + value = t.transformExpression(case.value), + body = case.body.map { t.transformStatement(it) }, + type = case.type?.let { t.transformType(it) }, + ) + } + } + s.copy( + expression = t.transformExpression(s.expression), + cases = transformedCases, + default = null, + ) + } else s.transformChildren(t) + } + } + + private fun fixConstructorCalls(): Transformer = transformer { + statementAndExpression { s, t -> + if (s is ConstructorStatement) { + val typeName = (s.type as? LanguageType.Custom)?.name + val transformedArgs = s.namedArguments.mapValues { t.transformExpression(it.value) } + when { + typeName != null && RESPONSE_PATTERN.matches(typeName) -> { + FunctionCall( + name = Name(listOf("Response::$typeName")), + arguments = mapOf(Name.of("inner") to FunctionCall( + name = Name(listOf("$typeName::new")), + arguments = transformedArgs, + )), + ) + } + typeName == "Request" -> { + FunctionCall( + name = Name(listOf("Request::new")), + arguments = transformedArgs, + ) + } + else -> s.transformChildren(t) + } + } else s.transformChildren(t) + } + } + + private fun File.rustifyEndpoint(endpoint: Endpoint): File = transform { + val identifierPattern = Regex("[a-zA-Z_][a-zA-Z0-9_]*") + statementAndExpression { s, t -> + if (s is RawExpression && identifierPattern.matches(s.code) && !s.code.contains(".")) { + VariableReference(Name.of(s.code)) + } else s.transformChildren(t) + } + + matchingElements { iface -> + if (iface.name == Name.of("Handler") || iface.name == Name.of("Call")) iface.copy(extends = emptyList()) else iface + } + + matching { type -> + if (type.name.startsWith("Response") && type.generics.isNotEmpty()) type.copy(generics = emptyList()) else type + } + + apply(fixResponseSwitchPatterns()) + apply(fixConstructorCalls()) + apply(borrowSerializationArgs()) + + parametersWhere( + predicate = { (it.type as? LanguageType.Custom)?.name in setOf("Serializer", "Deserializer") }, + transform = { it.copy(type = (it.type as LanguageType.Custom).borrowImpl()) }, + ) + + matchingElements { iface -> + if (iface.name == Name.of("Handler") || iface.name == Name.of("Call")) { + iface.transform { + matchingElements { fn: LanguageFunction -> + fn.copy( + name = Name.of(fn.name.snakeCase()), + parameters = listOf(selfParam) + fn.parameters, + ) + } + } + } else iface + } + + matchingElements { ns -> + val handler = ns.elements.filterIsInstance().firstOrNull { it.name == Name.of("Handler") } + if (handler != null) { + val method = handler.elements.filterIsInstance().firstOrNull() + if (method != null) { + val methodName = method.name.snakeCase() + ns.copy(elements = ns.elements + listOf(RawElement(""" + impl Handler for C { + async fn $methodName(&self, request: Request) -> Response { + let raw = to_raw_request(self.serialization(), request); + let resp = self.transport().transport(&raw).await; + from_raw_response(self.serialization(), resp) + } + } + """.trimIndent()))) + } else ns + } else ns + } + + matchingElements { file -> + file.copy(elements = file.elements.flatMap { element -> + if (element is LanguageUnion && element.name.pascalCase() == "Response" && element.members.isNotEmpty()) { + listOf(element) + element.members.map { member -> + RawElement("impl From<${member.name}> for Response { fn from(value: ${member.name}) -> Self { Response::${member.name}(value) } }\n") + } + } else listOf(element) + }) + } + + matchingElements { ns -> + ns.copy(elements = ns.elements + listOf(RawElement(endpoint.generateApiStruct()))) + } + } + + private fun Endpoint.generateApiStruct(): String { + val pathTemplate = path.joinToString("/") { segment -> + when (segment) { + is Endpoint.Segment.Literal -> segment.value + is Endpoint.Segment.Param -> "{${segment.identifier.value}}" + } + }.let { "/$it" } + val methodName = method.name + return """ + pub struct Api; + impl Server for Api { + type Req = Request; + type Res = Response; + fn path_template(&self) -> &'static str { "$pathTemplate" } + fn method(&self) -> Method { Method::$methodName } + } + """.trimIndent() + } + + companion object : Keywords { + fun VariableReference.borrow(): VariableReference = VariableReference(Name(listOf("&${name.snakeCase()}"))) + fun LanguageType.Custom.borrow(): LanguageType.Custom = copy(name = "&$name") + fun LanguageType.Custom.borrowDyn(): LanguageType.Custom = copy(name = "&dyn $name") + fun LanguageType.Custom.borrowImpl(): LanguageType.Custom = copy(name = "&impl $name") + override val reservedKeywords = setOf( + "as", "break", "const", "continue", "crate", + "else", "enum", "extern", "false", "fn", + "for", "if", "impl", "in", "let", + "loop", "match", "mod", "move", "mut", + "pub", "ref", "return", "self", "Self", + "static", "struct", "super", "trait", "true", + "type", "unsafe", "use", "where", "while", + "async", "await", "dyn", "abstract", "become", + "box", "do", "final", "macro", "override", + "priv", "typeof", "unsized", "virtual", "yield", + "try", + ) + } + +} diff --git a/src/compiler/emitters/rust/src/commonTest/kotlin/community/flock/wirespec/emitters/rust/RustIrEmitterTest.kt b/src/compiler/emitters/rust/src/commonTest/kotlin/community/flock/wirespec/emitters/rust/RustIrEmitterTest.kt new file mode 100644 index 000000000..60dd70852 --- /dev/null +++ b/src/compiler/emitters/rust/src/commonTest/kotlin/community/flock/wirespec/emitters/rust/RustIrEmitterTest.kt @@ -0,0 +1,1130 @@ +package community.flock.wirespec.emitters.rust + +import community.flock.wirespec.compiler.test.CompileChannelTest +import community.flock.wirespec.compiler.test.CompileComplexModelTest +import community.flock.wirespec.compiler.test.CompileEnumTest +import community.flock.wirespec.compiler.test.CompileFullEndpointTest +import community.flock.wirespec.compiler.test.CompileMinimalEndpointTest +import community.flock.wirespec.compiler.test.CompileNestedTypeTest +import community.flock.wirespec.compiler.test.CompileRefinedTest +import community.flock.wirespec.compiler.test.CompileTypeTest +import community.flock.wirespec.compiler.test.CompileUnionTest +import io.kotest.assertions.arrow.core.shouldBeRight +import io.kotest.matchers.shouldBe +import kotlin.test.Test + +class RustIrEmitterTest { + + @Test + fun compileEnumTest() { + val rust = """ + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, PartialEq)] + |pub enum MyAwesomeEnum { + | ONE, + | Two, + | THREEMORE, + | UnitedKingdom, + | _1, + | _0, + | _10, + | _999, + | _88, + |} + |impl Enum for MyAwesomeEnum { + | fn label(&self) -> &str { + | match self { + | MyAwesomeEnum::ONE => "ONE", + | MyAwesomeEnum::Two => "Two", + | MyAwesomeEnum::THREEMORE => "THREE_MORE", + | MyAwesomeEnum::UnitedKingdom => "UnitedKingdom", + | MyAwesomeEnum::_1 => "-1", + | MyAwesomeEnum::_0 => "0", + | MyAwesomeEnum::_10 => "10", + | MyAwesomeEnum::_999 => "-999", + | MyAwesomeEnum::_88 => "88", + | } + | } + | fn from_label(s: &str) -> Option { + | match s { + | "ONE" => Some(MyAwesomeEnum::ONE), + | "Two" => Some(MyAwesomeEnum::Two), + | "THREE_MORE" => Some(MyAwesomeEnum::THREEMORE), + | "UnitedKingdom" => Some(MyAwesomeEnum::UnitedKingdom), + | "-1" => Some(MyAwesomeEnum::_1), + | "0" => Some(MyAwesomeEnum::_0), + | "10" => Some(MyAwesomeEnum::_10), + | "-999" => Some(MyAwesomeEnum::_999), + | "88" => Some(MyAwesomeEnum::_88), + | _ => None, + | } + | } + |} + |impl std::fmt::Display for MyAwesomeEnum { + | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + | write!(f, "{}", self.label()) + | } + |} + | + |#![allow(warnings)] + |pub mod model; + |pub mod endpoint; + |pub mod wirespec; + | + """.trimMargin() + + CompileEnumTest.compiler { RustIrEmitter() } shouldBeRight rust + } + + @Test + fun compileTypeTest() { + val rust = """ + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Request { + | pub r#type: String, + | pub url: String, + | pub body_type: Option, + | pub params: Vec, + | pub headers: std::collections::HashMap, + | pub body: Option>>>>, + |} + |impl Request { + | pub fn validate(&self) -> Vec { + | return Vec::::new(); + | } + |} + | + |#![allow(warnings)] + |pub mod model; + |pub mod endpoint; + |pub mod wirespec; + | + """.trimMargin() + + CompileTypeTest.compiler { RustIrEmitter() } shouldBeRight rust + } + + @Test + fun compileChannelTest() { + val rust = """ + |use super::super::wirespec::*; + |use regex; + |pub trait Queue: Wirespec.Channel { + | fn invoke(message: String); + |} + | + |#![allow(warnings)] + |pub mod model; + |pub mod endpoint; + |pub mod wirespec; + | + """.trimMargin() + + CompileChannelTest.compiler { RustIrEmitter() } shouldBeRight rust + } + + @Test + fun compileRefinedTest() { + val rust = """ + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct TodoId { + | pub value: String, + |} + |impl TodoId { + | pub fn validate(&self) -> bool { + | return regex::Regex::new(r"^[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}${'$'}").unwrap().is_match(&self.value); + | } + | pub fn to_string(&self) -> String { + | return self.value.clone(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct TodoNoRegex { + | pub value: String, + |} + |impl TodoNoRegex { + | pub fn validate(&self) -> bool { + | return true; + | } + | pub fn to_string(&self) -> String { + | return self.value.clone(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct TestInt { + | pub value: i64, + |} + |impl TestInt { + | pub fn validate(&self) -> bool { + | return true; + | } + | pub fn to_string(&self) -> String { + | return format!("{}", self.value); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct TestInt0 { + | pub value: i64, + |} + |impl TestInt0 { + | pub fn validate(&self) -> bool { + | return true; + | } + | pub fn to_string(&self) -> String { + | return format!("{}", self.value); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct TestInt1 { + | pub value: i64, + |} + |impl TestInt1 { + | pub fn validate(&self) -> bool { + | return 0 <= self.value; + | } + | pub fn to_string(&self) -> String { + | return format!("{}", self.value); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct TestInt2 { + | pub value: i64, + |} + |impl TestInt2 { + | pub fn validate(&self) -> bool { + | return 1 <= self.value && self.value <= 3; + | } + | pub fn to_string(&self) -> String { + | return format!("{}", self.value); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct TestNum { + | pub value: f64, + |} + |impl TestNum { + | pub fn validate(&self) -> bool { + | return true; + | } + | pub fn to_string(&self) -> String { + | return format!("{}", self.value); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct TestNum0 { + | pub value: f64, + |} + |impl TestNum0 { + | pub fn validate(&self) -> bool { + | return true; + | } + | pub fn to_string(&self) -> String { + | return format!("{}", self.value); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct TestNum1 { + | pub value: f64, + |} + |impl TestNum1 { + | pub fn validate(&self) -> bool { + | return self.value <= 0.5; + | } + | pub fn to_string(&self) -> String { + | return format!("{}", self.value); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct TestNum2 { + | pub value: f64, + |} + |impl TestNum2 { + | pub fn validate(&self) -> bool { + | return -0.2 <= self.value && self.value <= 0.5; + | } + | pub fn to_string(&self) -> String { + | return format!("{}", self.value); + | } + |} + | + |#![allow(warnings)] + |pub mod model; + |pub mod endpoint; + |pub mod wirespec; + | + """.trimMargin() + + CompileRefinedTest.compiler { RustIrEmitter() } shouldBeRight rust + } + + @Test + fun compileUnionTest() { + val rust = """ + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct UserAccountPassword { + | pub username: String, + | pub password: String, + |} + |impl UserAccountPassword { + | pub fn validate(&self) -> Vec { + | return Vec::::new(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct UserAccountToken { + | pub token: String, + |} + |impl UserAccountToken { + | pub fn validate(&self) -> Vec { + | return Vec::::new(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |use super::user_account::UserAccount; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct User { + | pub username: String, + | pub account: UserAccount, + |} + |impl User { + | pub fn validate(&self) -> Vec { + | return Vec::::new(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, PartialEq)] + |pub enum UserAccount { + | UserAccountPassword(UserAccountPassword), + | UserAccountToken(UserAccountToken), + |} + | + |#![allow(warnings)] + |pub mod model; + |pub mod endpoint; + |pub mod wirespec; + | + """.trimMargin() + + CompileUnionTest.compiler { RustIrEmitter() } shouldBeRight rust + } + + @Test + fun compileMinimalEndpointTest() { + val rust = """ + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct TodoDto { + | pub description: String, + |} + |impl TodoDto { + | pub fn validate(&self) -> Vec { + | return Vec::::new(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |use super::super::model::todo_dto::TodoDto; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Path; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Queries; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct RequestHeaders; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Request { + | pub path: Path, + | pub method: Method, + | pub queries: Queries, + | pub headers: RequestHeaders, + | pub body: (), + |} + |impl Request { + | pub fn new() -> Self { + | Request { + | path: Path {}, + | method: Method::GET, + | queries: Queries {}, + | headers: RequestHeaders {}, + | body: () + | } + | } + |} + |#[derive(Debug, Clone, PartialEq)] + |pub enum Response { + | Response200(Response200), + |} + |impl From for Response { fn from(value: Response200) -> Self { Response::Response200(value) } } + |#[derive(Debug, Clone, PartialEq)] + |pub enum Response2XX { + | Response200(Response200), + |} + |#[derive(Debug, Clone, PartialEq)] + |pub enum ResponseListTodoDto { + | Response200(Response200), + |} + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Response200Headers; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Response200 { + | pub status: i32, + | pub headers: Response200Headers, + | pub body: Vec, + |} + |impl Response200 { + | pub fn new(body: Vec) -> Self { + | Response200 { + | status: 200_i32, + | headers: Response200Headers {}, + | body: body + | } + | } + |} + |pub mod GetTodos { + | use super::*; + | pub fn to_raw_request(serialization: &impl Serializer, request: Request) -> RawRequest { + | return RawRequest { method: format!("{:?}", request.method), path: vec![String::from("todos")], queries: std::collections::HashMap::new(), headers: std::collections::HashMap::new(), body: None }; + | } + | pub fn from_raw_request(serialization: &impl Deserializer, request: RawRequest) -> Request { + | return Request::new(); + | } + | pub fn to_raw_response(serialization: &impl Serializer, response: Response) -> RawResponse { + | match response { + | Response::Response200(r) => { + | return RawResponse { status_code: r.status, headers: std::collections::HashMap::new(), body: Some(serialization.serialize_body(&r.body, std::any::TypeId::of::>())) }; + | } + | } + | } + | pub fn from_raw_response(serialization: &impl Deserializer, response: RawResponse) -> Response { + | match response.status_code { + | 200_i32 => { + | return Response::Response200(Response200::new(response.body.as_ref().map(|it| serialization.deserialize_body(it, std::any::TypeId::of::>())).expect("body is null"))); + | } + | _ => { + | panic!("Cannot match response with status: {}", response.status_code); + | } + | } + | } + | pub trait Handler { + | async fn get_todos(&self, request: Request) -> Response; + | } + | pub trait Call { + | async fn get_todos(&self) -> Response; + | } + | impl Handler for C { + | async fn get_todos(&self, request: Request) -> Response { + | let raw = to_raw_request(self.serialization(), request); + | let resp = self.transport().transport(&raw).await; + | from_raw_response(self.serialization(), resp) + | } + | } + | pub struct Api; + | impl Server for Api { + | type Req = Request; + | type Res = Response; + | fn path_template(&self) -> &'static str { "/todos" } + | fn method(&self) -> Method { Method::GET } + | } + |} + | + |use super::super::wirespec::*; + |use super::super::endpoint::get_todos; + |use super::super::model::todo_dto::TodoDto; + |pub struct GetTodosClient<'a, S: Serialization, T: Transportation> { + | pub serialization: &'a S, + | pub transportation: &'a T, + |} + |impl<'a, S: Serialization, T: Transportation> get_todos::GetTodos::Call for GetTodosClient<'a, S, T> { + | async fn get_todos(&self) -> get_todos::Response { + | let request = get_todos::Request::new(); + | let raw_request = get_todos::GetTodos::to_raw_request(self.serialization, request); + | let raw_response = self.transportation.transport(&raw_request).await; + | get_todos::GetTodos::from_raw_response(self.serialization, raw_response) + | } + |} + | + |#![allow(warnings)] + |pub mod model; + |pub mod endpoint; + |pub mod client; + |pub mod wirespec; + | + |pub mod get_todos_client; + |use super::wirespec::*; + |use super::model::todo_dto::TodoDto; + |use super::endpoint::get_todos; + |use get_todos_client::GetTodosClient; + |pub struct Client { + | pub serialization: S, + | pub transportation: T, + |} + |impl get_todos::GetTodos::Call for Client { + | async fn get_todos(&self) -> get_todos::Response { + | GetTodosClient { serialization: &self.serialization, transportation: &self.transportation } + | .get_todos().await + | } + |} + | + """.trimMargin() + + CompileMinimalEndpointTest.compiler { RustIrEmitter() } shouldBeRight rust + } + + @Test + fun compileFullEndpointTest() { + val rust = """ + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct PotentialTodoDto { + | pub name: String, + | pub done: bool, + |} + |impl PotentialTodoDto { + | pub fn validate(&self) -> Vec { + | return Vec::::new(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Token { + | pub iss: String, + |} + |impl Token { + | pub fn validate(&self) -> Vec { + | return Vec::::new(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct TodoDto { + | pub id: String, + | pub name: String, + | pub done: bool, + |} + |impl TodoDto { + | pub fn validate(&self) -> Vec { + | return Vec::::new(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Error { + | pub code: i64, + | pub description: String, + |} + |impl Error { + | pub fn validate(&self) -> Vec { + | return Vec::::new(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |use super::super::model::token::Token; + |use super::super::model::potential_todo_dto::PotentialTodoDto; + |use super::super::model::todo_dto::TodoDto; + |use super::super::model::error::Error; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Path { + | pub id: String, + |} + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Queries { + | pub done: bool, + | pub name: Option, + |} + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct RequestHeaders { + | pub token: Token, + | pub refresh_token: Option, + |} + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Request { + | pub path: Path, + | pub method: Method, + | pub queries: Queries, + | pub headers: RequestHeaders, + | pub body: PotentialTodoDto, + |} + |impl Request { + | pub fn new(id: String, done: bool, name: Option, token: Token, refresh_token: Option, body: PotentialTodoDto) -> Self { + | Request { + | path: Path { id: id }, + | method: Method::PUT, + | queries: Queries { done: done, name: name }, + | headers: RequestHeaders { token: token, refresh_token: refresh_token }, + | body: body + | } + | } + |} + |#[derive(Debug, Clone, PartialEq)] + |pub enum Response { + | Response200(Response200), + | Response201(Response201), + | Response500(Response500), + |} + |impl From for Response { fn from(value: Response200) -> Self { Response::Response200(value) } } + |impl From for Response { fn from(value: Response201) -> Self { Response::Response201(value) } } + |impl From for Response { fn from(value: Response500) -> Self { Response::Response500(value) } } + |#[derive(Debug, Clone, PartialEq)] + |pub enum Response2XX { + | Response200(Response200), + | Response201(Response201), + |} + |#[derive(Debug, Clone, PartialEq)] + |pub enum Response5XX { + | Response500(Response500), + |} + |#[derive(Debug, Clone, PartialEq)] + |pub enum ResponseTodoDto { + | Response200(Response200), + | Response201(Response201), + |} + |#[derive(Debug, Clone, PartialEq)] + |pub enum ResponseError { + | Response500(Response500), + |} + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Response200Headers; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Response200 { + | pub status: i32, + | pub headers: Response200Headers, + | pub body: TodoDto, + |} + |impl Response200 { + | pub fn new(body: TodoDto) -> Self { + | Response200 { + | status: 200_i32, + | headers: Response200Headers {}, + | body: body + | } + | } + |} + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Response201Headers { + | pub token: Token, + | pub refresh_token: Option, + |} + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Response201 { + | pub status: i32, + | pub headers: Response201Headers, + | pub body: TodoDto, + |} + |impl Response201 { + | pub fn new(token: Token, refresh_token: Option, body: TodoDto) -> Self { + | Response201 { + | status: 201_i32, + | headers: Response201Headers { token: token, refresh_token: refresh_token }, + | body: body + | } + | } + |} + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Response500Headers; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Response500 { + | pub status: i32, + | pub headers: Response500Headers, + | pub body: Error, + |} + |impl Response500 { + | pub fn new(body: Error) -> Self { + | Response500 { + | status: 500_i32, + | headers: Response500Headers {}, + | body: body + | } + | } + |} + |pub mod PutTodo { + | use super::*; + | pub fn to_raw_request(serialization: &impl Serializer, request: Request) -> RawRequest { + | return RawRequest { method: format!("{:?}", request.method), path: vec![String::from("todos"), serialization.serialize_path(&request.path.id, std::any::TypeId::of::())], queries: std::collections::HashMap::from([(String::from("done"), serialization.serialize_param(&request.queries.done, std::any::TypeId::of::())), (String::from("name"), request.queries.name.as_ref().map(|it| serialization.serialize_param(it, std::any::TypeId::of::())).unwrap_or(Vec::::new()))]), headers: std::collections::HashMap::from([(String::from("token"), serialization.serialize_param(&request.headers.token, std::any::TypeId::of::())), (String::from("Refresh-Token"), request.headers.refresh_token.as_ref().map(|it| serialization.serialize_param(it, std::any::TypeId::of::())).unwrap_or(Vec::::new()))]), body: Some(serialization.serialize_body(&request.body, std::any::TypeId::of::())) }; + | } + | pub fn from_raw_request(serialization: &impl Deserializer, request: RawRequest) -> Request { + | return Request::new(serialization.deserialize_path(&request.path[1], std::any::TypeId::of::()), request.queries.get("done").as_ref().map(|it| serialization.deserialize_param(it, std::any::TypeId::of::())).expect("Param done cannot be null"), request.queries.get("name").as_ref().map(|it| serialization.deserialize_param(it, std::any::TypeId::of::())), request.headers.iter().find(|(k, _)| k.eq_ignore_ascii_case("token")).map(|(_, v)| v.clone()).as_ref().map(|it| serialization.deserialize_param(it, std::any::TypeId::of::())).expect("Param token cannot be null"), request.headers.iter().find(|(k, _)| k.eq_ignore_ascii_case("Refresh-Token")).map(|(_, v)| v.clone()).as_ref().map(|it| serialization.deserialize_param(it, std::any::TypeId::of::())), request.body.as_ref().map(|it| serialization.deserialize_body(it, std::any::TypeId::of::())).expect("body is null")); + | } + | pub fn to_raw_response(serialization: &impl Serializer, response: Response) -> RawResponse { + | match response { + | Response::Response200(r) => { + | return RawResponse { status_code: r.status, headers: std::collections::HashMap::new(), body: Some(serialization.serialize_body(&r.body, std::any::TypeId::of::())) }; + | } + | Response::Response201(r) => { + | return RawResponse { status_code: r.status, headers: std::collections::HashMap::from([(String::from("token"), serialization.serialize_param(&r.headers.token, std::any::TypeId::of::())), (String::from("refreshToken"), r.headers.refresh_token.as_ref().map(|it| serialization.serialize_param(it, std::any::TypeId::of::())).unwrap_or(Vec::::new()))]), body: Some(serialization.serialize_body(&r.body, std::any::TypeId::of::())) }; + | } + | Response::Response500(r) => { + | return RawResponse { status_code: r.status, headers: std::collections::HashMap::new(), body: Some(serialization.serialize_body(&r.body, std::any::TypeId::of::())) }; + | } + | } + | } + | pub fn from_raw_response(serialization: &impl Deserializer, response: RawResponse) -> Response { + | match response.status_code { + | 200_i32 => { + | return Response::Response200(Response200::new(response.body.as_ref().map(|it| serialization.deserialize_body(it, std::any::TypeId::of::())).expect("body is null"))); + | } + | 201_i32 => { + | return Response::Response201(Response201::new(response.headers.iter().find(|(k, _)| k.eq_ignore_ascii_case("token")).map(|(_, v)| v.clone()).as_ref().map(|it| serialization.deserialize_param(it, std::any::TypeId::of::())).expect("Param token cannot be null"), response.headers.iter().find(|(k, _)| k.eq_ignore_ascii_case("refreshToken")).map(|(_, v)| v.clone()).as_ref().map(|it| serialization.deserialize_param(it, std::any::TypeId::of::())), response.body.as_ref().map(|it| serialization.deserialize_body(it, std::any::TypeId::of::())).expect("body is null"))); + | } + | 500_i32 => { + | return Response::Response500(Response500::new(response.body.as_ref().map(|it| serialization.deserialize_body(it, std::any::TypeId::of::())).expect("body is null"))); + | } + | _ => { + | panic!("Cannot match response with status: {}", response.status_code); + | } + | } + | } + | pub trait Handler { + | async fn put_todo(&self, request: Request) -> Response; + | } + | pub trait Call { + | async fn put_todo(&self, id: String, done: bool, name: Option, token: Token, refresh_token: Option, body: PotentialTodoDto) -> Response; + | } + | impl Handler for C { + | async fn put_todo(&self, request: Request) -> Response { + | let raw = to_raw_request(self.serialization(), request); + | let resp = self.transport().transport(&raw).await; + | from_raw_response(self.serialization(), resp) + | } + | } + | pub struct Api; + | impl Server for Api { + | type Req = Request; + | type Res = Response; + | fn path_template(&self) -> &'static str { "/todos/{id}" } + | fn method(&self) -> Method { Method::PUT } + | } + |} + | + |use super::super::wirespec::*; + |use super::super::endpoint::put_todo; + |use super::super::model::token::Token; + |use super::super::model::potential_todo_dto::PotentialTodoDto; + |use super::super::model::todo_dto::TodoDto; + |use super::super::model::error::Error; + |pub struct PutTodoClient<'a, S: Serialization, T: Transportation> { + | pub serialization: &'a S, + | pub transportation: &'a T, + |} + |impl<'a, S: Serialization, T: Transportation> put_todo::PutTodo::Call for PutTodoClient<'a, S, T> { + | async fn put_todo(&self, id: String, done: bool, name: Option, token: Token, refresh_token: Option, body: PotentialTodoDto) -> put_todo::Response { + | let request = put_todo::Request::new(id, done, name, token, refresh_token, body); + | let raw_request = put_todo::PutTodo::to_raw_request(self.serialization, request); + | let raw_response = self.transportation.transport(&raw_request).await; + | put_todo::PutTodo::from_raw_response(self.serialization, raw_response) + | } + |} + | + |#![allow(warnings)] + |pub mod model; + |pub mod endpoint; + |pub mod client; + |pub mod wirespec; + | + |pub mod put_todo_client; + |use super::wirespec::*; + |use super::model::token::Token; + |use super::model::potential_todo_dto::PotentialTodoDto; + |use super::model::todo_dto::TodoDto; + |use super::model::error::Error; + |use super::endpoint::put_todo; + |use put_todo_client::PutTodoClient; + |pub struct Client { + | pub serialization: S, + | pub transportation: T, + |} + |impl put_todo::PutTodo::Call for Client { + | async fn put_todo(&self, id: String, done: bool, name: Option, token: Token, refresh_token: Option, body: PotentialTodoDto) -> put_todo::Response { + | PutTodoClient { serialization: &self.serialization, transportation: &self.transportation } + | .put_todo(id, done, name, token, refresh_token, body).await + | } + |} + | + """.trimMargin() + + CompileFullEndpointTest.compiler { RustIrEmitter() } shouldBeRight rust + } + + @Test + fun compileNestedTypeTest() { + val rust = """ + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct DutchPostalCode { + | pub value: String, + |} + |impl DutchPostalCode { + | pub fn validate(&self) -> bool { + | return regex::Regex::new(r"^([0-9]{4}[A-Z]{2})${'$'}").unwrap().is_match(&self.value); + | } + | pub fn to_string(&self) -> String { + | return self.value.clone(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |use super::dutch_postal_code::DutchPostalCode; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Address { + | pub street: String, + | pub house_number: i64, + | pub postal_code: DutchPostalCode, + |} + |impl Address { + | pub fn validate(&self) -> Vec { + | return if !self.postal_code.validate() { vec![String::from("postalCode")] } else { Vec::::new() }; + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |use super::address::Address; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Person { + | pub name: String, + | pub address: Address, + | pub tags: Vec, + |} + |impl Person { + | pub fn validate(&self) -> Vec { + | return self.address.validate().iter().map(|e| format!("address.{}", e)).collect::>(); + | } + |} + | + |#![allow(warnings)] + |pub mod model; + |pub mod endpoint; + |pub mod wirespec; + | + """.trimMargin() + + CompileNestedTypeTest.compiler { RustIrEmitter() } shouldBeRight rust + } + + @Test + fun compileComplexModelTest() { + val rust = """ + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Email { + | pub value: String, + |} + |impl Email { + | pub fn validate(&self) -> bool { + | return regex::Regex::new(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}${'$'}").unwrap().is_match(&self.value); + | } + | pub fn to_string(&self) -> String { + | return self.value.clone(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct PhoneNumber { + | pub value: String, + |} + |impl PhoneNumber { + | pub fn validate(&self) -> bool { + | return regex::Regex::new(r"^\+[1-9]\d{1,14}${'$'}").unwrap().is_match(&self.value); + | } + | pub fn to_string(&self) -> String { + | return self.value.clone(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Tag { + | pub value: String, + |} + |impl Tag { + | pub fn validate(&self) -> bool { + | return regex::Regex::new(r"^[a-z][a-z0-9-]{0,19}${'$'}").unwrap().is_match(&self.value); + | } + | pub fn to_string(&self) -> String { + | return self.value.clone(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct EmployeeAge { + | pub value: i64, + |} + |impl EmployeeAge { + | pub fn validate(&self) -> bool { + | return 18 <= self.value && self.value <= 65; + | } + | pub fn to_string(&self) -> String { + | return format!("{}", self.value); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |use super::email::Email; + |use super::phone_number::PhoneNumber; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct ContactInfo { + | pub email: Email, + | pub phone: Option, + |} + |impl ContactInfo { + | pub fn validate(&self) -> Vec { + | return vec![if !self.email.validate() { vec![String::from("email")] } else { Vec::::new() }.as_slice(), self.phone.as_ref().map(|it| if !it.validate() { vec![String::from("phone")] } else { Vec::::new() }).unwrap_or(Vec::::new()).as_slice()].concat(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |use super::employee_age::EmployeeAge; + |use super::contact_info::ContactInfo; + |use super::tag::Tag; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Employee { + | pub name: String, + | pub age: EmployeeAge, + | pub contact_info: ContactInfo, + | pub tags: Vec, + |} + |impl Employee { + | pub fn validate(&self) -> Vec { + | return vec![if !self.age.validate() { vec![String::from("age")] } else { Vec::::new() }.as_slice(), self.contact_info.validate().iter().map(|e| format!("contactInfo.{}", e)).collect::>().as_slice(), self.tags.iter().enumerate().flat_map(|(i, el)| if !el.validate() { vec![format!("tags[{}]", i)] } else { Vec::::new() }).collect::>().as_slice()].concat(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |use super::employee::Employee; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Department { + | pub name: String, + | pub employees: Vec, + |} + |impl Department { + | pub fn validate(&self) -> Vec { + | return self.employees.iter().enumerate().flat_map(|(i, el)| el.validate().iter().map(|e| format!("employees[{}].{}", i, e)).collect::>()).collect::>(); + | } + |} + | + |use super::super::wirespec::*; + |use regex; + |use super::department::Department; + |#[derive(Debug, Clone, Default, PartialEq)] + |pub struct Company { + | pub name: String, + | pub departments: Vec, + |} + |impl Company { + | pub fn validate(&self) -> Vec { + | return self.departments.iter().enumerate().flat_map(|(i, el)| el.validate().iter().map(|e| format!("departments[{}].{}", i, e)).collect::>()).collect::>(); + | } + |} + | + |#![allow(warnings)] + |pub mod model; + |pub mod endpoint; + |pub mod wirespec; + | + """.trimMargin() + + CompileComplexModelTest.compiler { RustIrEmitter() } shouldBeRight rust + } + + @Test + fun sharedOutputTest() { + val expected = """ + |use std::any::TypeId; + |use std::collections::HashMap; + | + |pub trait Model { + | fn validate(&self) -> Vec; + |} + | + |pub trait Enum: Sized { + | fn label(&self) -> &str; + | fn from_label(s: &str) -> Option; + |} + | + |pub trait Endpoint {} + | + |pub trait Channel {} + | + |pub trait Refined { + | fn value(&self) -> &T; + | fn validate(&self) -> bool; + |} + | + |pub trait Path {} + | + |pub trait Queries {} + | + |pub trait Headers {} + | + |pub trait Handler {} + | + |pub trait Call {} + | + |#[derive(Debug, Clone, Default, PartialEq)] + |pub enum Method { + | #[default] + | GET, + | PUT, + | POST, + | DELETE, + | OPTIONS, + | HEAD, + | PATCH, + | TRACE, + |} + | + |pub trait Request { + | fn path(&self) -> &dyn Path; + | fn method(&self) -> &Method; + | fn queries(&self) -> &dyn Queries; + | fn headers(&self) -> &dyn RequestHeaders; + | fn body(&self) -> &T; + |} + | + |pub trait RequestHeaders: Headers {} + | + |pub trait Response { + | fn status(&self) -> i32; + | fn headers(&self) -> &dyn ResponseHeaders; + | fn body(&self) -> &T; + |} + | + |pub trait ResponseHeaders: Headers {} + | + |pub trait BodySerializer { + | fn serialize_body(&self, t: &T, r#type: TypeId) -> Vec; + |} + | + |pub trait BodyDeserializer { + | fn deserialize_body(&self, raw: &[u8], r#type: TypeId) -> T; + |} + | + |pub trait BodySerialization: BodySerializer + BodyDeserializer {} + | + |pub trait PathSerializer { + | fn serialize_path(&self, t: &T, r#type: TypeId) -> String; + |} + | + |pub trait PathDeserializer { + | fn deserialize_path(&self, raw: &str, r#type: TypeId) -> T where T::Err: std::fmt::Debug; + |} + | + |pub trait PathSerialization: PathSerializer + PathDeserializer {} + | + |pub trait ParamSerializer { + | fn serialize_param(&self, value: &T, r#type: TypeId) -> Vec; + |} + | + |pub trait ParamDeserializer { + | fn deserialize_param(&self, values: &[String], r#type: TypeId) -> T; + |} + | + |pub trait ParamSerialization: ParamSerializer + ParamDeserializer {} + | + |pub trait Serializer: BodySerializer + PathSerializer + ParamSerializer {} + | + |pub trait Deserializer: BodyDeserializer + PathDeserializer + ParamDeserializer {} + | + |pub trait Serialization: Serializer + Deserializer {} + | + |#[derive(Debug, Clone, PartialEq)] + |pub struct RawRequest { + | pub method: String, + | pub path: Vec, + | pub queries: std::collections::HashMap>, + | pub headers: std::collections::HashMap>, + | pub body: Option>, + |} + | + |#[derive(Debug, Clone, PartialEq)] + |pub struct RawResponse { + | pub status_code: i32, + | pub headers: std::collections::HashMap>, + | pub body: Option>, + |} + | + |pub trait Transportation { + | async fn transport(&self, request: &RawRequest) -> RawResponse; + |} + | + |pub trait Client { + | type Transport: Transportation; + | type Ser: Serialization; + | fn transport(&self) -> &Self::Transport; + | fn serialization(&self) -> &Self::Ser; + |} + | + |pub trait Server { + | type Req; + | type Res; + | fn path_template(&self) -> &'static str; + | fn method(&self) -> Method; + |} + | + """.trimMargin() + + val emitter = RustIrEmitter() + emitter.shared.source shouldBe expected + } +} diff --git a/src/compiler/emitters/scala/build.gradle.kts b/src/compiler/emitters/scala/build.gradle.kts new file mode 100644 index 000000000..68b5de636 --- /dev/null +++ b/src/compiler/emitters/scala/build.gradle.kts @@ -0,0 +1,55 @@ +plugins { + id("module.publication") + id("module.spotless") + alias(libs.plugins.kotlin.multiplatform) + alias(libs.plugins.ksp) + alias(libs.plugins.kotest) +} + +group = "${libs.versions.group.id.get()}.compiler.emitters" +version = System.getenv(libs.versions.from.env.get()) ?: libs.versions.default.get() + +repositories { + mavenCentral() + mavenLocal() +} + +kotlin { + macosX64() + macosArm64() + linuxX64() + mingwX64() + js(IR) { + nodejs() + useEsModules() + } + jvm { + java { + toolchain { + languageVersion.set(JavaLanguageVersion.of(libs.versions.java.get())) + } + } + } + + sourceSets.all { + languageSettings.apply { + languageVersion = libs.versions.kotlin.compiler.get() + } + } + + sourceSets { + commonMain { + dependencies { + api(project(":src:compiler:core")) + api(project(":src:compiler:ir")) + } + } + commonTest { + dependencies { + implementation(libs.kotlin.test) + implementation(libs.bundles.kotest) + implementation(project(":src:compiler:test")) + } + } + } +} diff --git a/src/compiler/emitters/scala/src/commonMain/kotlin/community/flock/wirespec/emitters/scala/ScalaIrEmitter.kt b/src/compiler/emitters/scala/src/commonMain/kotlin/community/flock/wirespec/emitters/scala/ScalaIrEmitter.kt new file mode 100644 index 000000000..0bc9870f0 --- /dev/null +++ b/src/compiler/emitters/scala/src/commonMain/kotlin/community/flock/wirespec/emitters/scala/ScalaIrEmitter.kt @@ -0,0 +1,455 @@ +package community.flock.wirespec.emitters.scala + +import arrow.core.NonEmptyList +import community.flock.wirespec.compiler.core.addBackticks +import community.flock.wirespec.compiler.core.emit.DEFAULT_GENERATED_PACKAGE_STRING +import community.flock.wirespec.compiler.core.emit.DEFAULT_SHARED_PACKAGE_STRING +import community.flock.wirespec.compiler.core.emit.EmitShared +import community.flock.wirespec.compiler.core.emit.FileExtension +import community.flock.wirespec.compiler.core.emit.HasPackageName +import community.flock.wirespec.compiler.core.emit.Keywords +import community.flock.wirespec.compiler.core.emit.LanguageEmitter.Companion.firstToUpper +import community.flock.wirespec.compiler.core.emit.LanguageEmitter.Companion.needImports +import community.flock.wirespec.compiler.core.emit.PackageName +import community.flock.wirespec.compiler.core.emit.Shared +import community.flock.wirespec.compiler.core.emit.importReferences +import community.flock.wirespec.compiler.core.emit.plus +import community.flock.wirespec.compiler.core.parse.ast.Channel +import community.flock.wirespec.compiler.core.parse.ast.Definition +import community.flock.wirespec.compiler.core.parse.ast.Endpoint +import community.flock.wirespec.compiler.core.parse.ast.Enum +import community.flock.wirespec.compiler.core.parse.ast.FieldIdentifier +import community.flock.wirespec.compiler.core.parse.ast.Identifier +import community.flock.wirespec.compiler.core.parse.ast.Module +import community.flock.wirespec.compiler.core.parse.ast.Reference +import community.flock.wirespec.compiler.core.parse.ast.Refined +import community.flock.wirespec.compiler.core.parse.ast.Type +import community.flock.wirespec.compiler.core.parse.ast.Union +import community.flock.wirespec.compiler.utils.Logger +import community.flock.wirespec.ir.converter.convert +import community.flock.wirespec.ir.converter.convertConstraint +import community.flock.wirespec.ir.converter.convertWithValidation +import community.flock.wirespec.ir.core.Constructor +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Element +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.Interface +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Namespace +import community.flock.wirespec.ir.core.RawElement +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.Struct +import community.flock.wirespec.ir.core.TypeParameter +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.findElement +import community.flock.wirespec.ir.core.flattenNestedStructs +import community.flock.wirespec.ir.core.function +import community.flock.wirespec.ir.core.`interface` +import community.flock.wirespec.ir.core.raw +import community.flock.wirespec.ir.core.transform +import community.flock.wirespec.ir.core.transformChildren +import community.flock.wirespec.ir.core.withLabelField +import community.flock.wirespec.ir.emit.IrEmitter +import community.flock.wirespec.ir.generator.ScalaGenerator +import community.flock.wirespec.ir.generator.generateScala +import community.flock.wirespec.compiler.core.parse.ast.Shared as AstShared +import community.flock.wirespec.ir.core.Function as LanguageFunction +import community.flock.wirespec.ir.core.Enum as LanguageEnum +import community.flock.wirespec.ir.core.File as LanguageFile +import community.flock.wirespec.ir.core.Import as LanguageImport +import community.flock.wirespec.ir.core.Package as LanguagePackage +import community.flock.wirespec.ir.core.Type as LanguageType + +open class ScalaIrEmitter( + override val packageName: PackageName = PackageName(DEFAULT_GENERATED_PACKAGE_STRING), + private val emitShared: EmitShared = EmitShared(), +) : IrEmitter, HasPackageName { + + override val generator = ScalaGenerator + + override val extension = FileExtension.Scala + + private val wirespecImport = """ + | + |import $DEFAULT_SHARED_PACKAGE_STRING.scala.Wirespec + |import scala.reflect.ClassTag + | + """.trimMargin() + + override val shared = object : Shared { + override val packageString = "$DEFAULT_SHARED_PACKAGE_STRING.scala" + + private val clientServer = listOf( + `interface`("ServerEdge") { + typeParam(type("Req"), type("Request", LanguageType.Wildcard)) + typeParam(type("Res"), type("Response", LanguageType.Wildcard)) + function("from") { + returnType(type("Req")) + arg("request", type("RawRequest")) + } + function("to") { + returnType(type("RawResponse")) + arg("response", type("Res")) + } + }, + `interface`("ClientEdge") { + typeParam(type("Req"), type("Request", LanguageType.Wildcard)) + typeParam(type("Res"), type("Response", LanguageType.Wildcard)) + function("to") { + returnType(type("RawRequest")) + arg("request", type("Req")) + } + function("from") { + returnType(type("Res")) + arg("response", type("RawResponse")) + } + }, + `interface`("Client") { + typeParam(type("Req"), type("Request", LanguageType.Wildcard)) + typeParam(type("Res"), type("Response", LanguageType.Wildcard)) + field("pathTemplate", LanguageType.String) + field("method", LanguageType.String) + function("client") { + returnType(type("ClientEdge", type("Req"), type("Res"))) + arg("serialization", type("Serialization")) + } + }, + `interface`("Server") { + typeParam(type("Req"), type("Request", LanguageType.Wildcard)) + typeParam(type("Res"), type("Response", LanguageType.Wildcard)) + field("pathTemplate", LanguageType.String) + field("method", LanguageType.String) + function("server") { + returnType(type("ServerEdge", type("Req"), type("Res"))) + arg("serialization", type("Serialization")) + } + }, + ) + + override val source = AstShared(packageString) + .convert() + .transform { + matchingElements { file: LanguageFile -> + val (packageElements, rest) = file.elements.partition { it is LanguagePackage } + file.copy(elements = packageElements + LanguageImport("scala.reflect", LanguageType.Custom("ClassTag")) + rest) + } + matchingElements { ns: Namespace -> + if (ns.name == Name.of("Wirespec")) { + val newElements = ns.elements.flatMap { element -> + if (element is Interface && element.name.pascalCase() in setOf("Request", "Response")) { + val nestedHeaders = element.elements.filterIsInstance() + .firstOrNull { it.name.pascalCase() == "Headers" } + if (nestedHeaders != null) { + listOf( + Namespace(element.name, listOf(nestedHeaders)), + element.copy( + elements = element.elements.filter { + !(it is Interface && it.name.pascalCase() == "Headers") + }, + fields = element.fields.map { f -> + if (f.name.value() == "headers") { + f.copy(type = LanguageType.Custom("${element.name.pascalCase()}.Headers")) + } else f + }, + ), + ) + } else { + listOf(element) + } + } else { + listOf(element) + } + } + ns.copy(elements = newElements) + } else ns + } + injectAfter { namespace: Namespace -> + if (namespace.name == Name.of("Wirespec")) clientServer + else emptyList() + } + } + .generateScala() + } + + override fun emit(module: Module, logger: Logger): NonEmptyList { + val files = super.emit(module, logger) + return if (emitShared.value) { + files + File( + Name.of(PackageName("${DEFAULT_SHARED_PACKAGE_STRING}.scala").toDir() + "Wirespec"), + listOf(RawElement(shared.source)) + ) + } else { + files + } + } + + override fun emit(definition: Definition, module: Module, logger: Logger): File { + val file = super.emit(definition, module, logger) + val subPackageName = packageName + definition + return File( + name = Name.of(subPackageName.toDir() + file.name.pascalCase()), + elements = buildList { + add(LanguagePackage(subPackageName.value)) + if (module.needImports()) add(RawElement(wirespecImport)) + addAll(file.elements) + } + ) + } + + override fun emit(type: Type, module: Module): File = + type.convertWithValidation(module) + .sanitizeNames() + .transform { + matchingElements { struct: Struct -> + if (struct.fields.isEmpty()) struct.copy(constructors = listOf(Constructor(emptyList(), emptyList()))) + else struct + } + } + + override fun emit(enum: Enum, module: Module): File = enum + .convert() + .sanitizeNames() + .transform { + matchingElements { languageEnum: LanguageEnum -> + languageEnum.withLabelField( + sanitizeEntry = { it.sanitizeEnum() }, + labelFieldOverride = true, + labelExpression = RawExpression("label"), + ) + } + } + + override fun emit(union: Union): File = union + .convert() + .sanitizeNames() + + override fun emit(refined: Refined): File { + val file = refined.convert().sanitizeNames() + val struct = file.findElement()!! + val toStringExpr = when (refined.reference.type) { + is Reference.Primitive.Type.String -> "value" + else -> "value.toString" + } + val updatedStruct = struct.copy( + fields = struct.fields.map { f -> f.copy(isOverride = true) }, + elements = listOf( + function("toString", isOverride = true) { + returnType(LanguageType.String) + returns(RawExpression(toStringExpr)) + }, + function("validate", isOverride = true) { + returnType(LanguageType.Boolean) + returns(refined.reference.convertConstraint(VariableReference(Name.of("value")))) + }, + ), + ) + return LanguageFile(Name.of(refined.identifier.sanitize()), listOf(updatedStruct)) + } + + override fun emit(endpoint: Endpoint): File { + val imports = endpoint.buildImports() + val file = endpoint.convert().sanitizeNames() + val endpointNamespace = file.findElement()!! + val flattened = endpointNamespace.flattenNestedStructs() + val requestIsObject = isRequestObject(flattened) + val body = flattened + .injectHandleFunction() + .withClientServerObjects(endpoint, requestIsObject) + + return if (imports.isNotEmpty()) LanguageFile(Name.of(endpoint.identifier.sanitize()), listOf(RawElement(imports), body)) + else LanguageFile(Name.of(endpoint.identifier.sanitize()), listOf(body)) + } + + override fun emit(channel: Channel): File { + val imports = channel.buildImports() + val file = channel.convert().sanitizeNames() + return if (imports.isNotEmpty()) file.copy(elements = listOf(RawElement(imports)) + file.elements) + else file + } + + override fun emitEndpointClient(endpoint: Endpoint): File { + val imports = endpoint.buildImports() + val endpointImport = "import ${packageName.value}.endpoint.${endpoint.identifier.value}" + val allImports = listOf(imports, endpointImport).filter { it.isNotEmpty() }.joinToString("\n") + val file = super.emitEndpointClient(endpoint).sanitizeNames().addIdentityTypeToCall() + val subPackageName = packageName + "client" + return File( + name = Name.of(subPackageName.toDir() + file.name.pascalCase()), + elements = buildList { + add(LanguagePackage(subPackageName.value)) + add(RawElement(wirespecImport)) + if (allImports.isNotEmpty()) add(RawElement(allImports)) + addAll(file.elements) + } + ) + } + + override fun emitClient(endpoints: List, logger: Logger): File { + val imports = endpoints.flatMap { it.importReferences() }.distinctBy { it.value } + .joinToString("\n") { "import ${packageName.value}.model.${it.value}" } + val endpointImports = endpoints + .joinToString("\n") { "import ${packageName.value}.endpoint.${it.identifier.value}" } + val clientImports = endpoints + .joinToString("\n") { "import ${packageName.value}.client.${it.identifier.value}Client" } + val allImports = listOf(imports, endpointImports, clientImports).filter { it.isNotEmpty() }.joinToString("\n") + val file = super.emitClient(endpoints, logger).sanitizeNames().addIdentityTypeToCall() + return File( + name = Name.of(packageName.toDir() + file.name.pascalCase()), + elements = buildList { + add(LanguagePackage(packageName.value)) + add(RawElement(wirespecImport)) + if (allImports.isNotEmpty()) add(RawElement(allImports)) + addAll(file.elements) + } + ) + } + + private fun T.sanitizeNames(): T = transform { + fields { field -> + field.copy(name = field.name.sanitizeName()) + } + parameters { param -> + param.copy(name = Name.of(param.name.camelCase().sanitizeSymbol().sanitizeKeywords())) + } + statementAndExpression { stmt, tr -> + when (stmt) { + is FieldCall -> FieldCall( + receiver = stmt.receiver?.let { tr.transformExpression(it) }, + field = stmt.field.sanitizeName(), + ) + is FunctionCall -> if (stmt.name.value() == "validate") { + stmt.copy(typeArguments = emptyList()).transformChildren(tr) + } else stmt.transformChildren(tr) + is ConstructorStatement -> ConstructorStatement( + type = tr.transformType(stmt.type), + namedArguments = stmt.namedArguments.map { (name, expr) -> + name.sanitizeName() to tr.transformExpression(expr) + }.toMap(), + ) + else -> stmt.transformChildren(tr) + } + } + } + + private fun Name.sanitizeName(): Name { + val sanitized = if (parts.size > 1) camelCase() else value().sanitizeSymbol() + return Name(listOf(sanitized.sanitizeKeywords())) + } + + private fun Identifier.sanitize(): String = value + .split(".", " ") + .mapIndexed { index, s -> if (index > 0) s.firstToUpper() else s } + .joinToString("") + .filter { it.isLetterOrDigit() || it == '_' } + .sanitizeFirstIsDigit() + .let { if (this is FieldIdentifier) it.sanitizeKeywords() else it } + + private fun String.sanitizeFirstIsDigit() = if (firstOrNull()?.isDigit() == true) "_${this}" else this + + private fun String.sanitizeKeywords() = if (this in reservedKeywords) addBackticks() else this + + private fun String.sanitizeSymbol(): String = this + .split(".", " ", "-") + .mapIndexed { index, s -> if (index > 0) s.firstToUpper() else s } + .joinToString("") + .filter { it.isLetterOrDigit() || it == '_' } + .sanitizeFirstIsDigit() + + private fun String.sanitizeEnum() = split("-", ", ", ".", " ", "//") + .joinToString("_") + .sanitizeFirstIsDigit() + .sanitizeKeywords() + + private fun Definition.buildImports() = importReferences() + .distinctBy { it.value } + .joinToString("\n") { "import ${packageName.value}.model.${it.value}" } + + private fun T.addIdentityTypeToCall(): T = transform { + matchingElements { struct: Struct -> + struct.copy( + interfaces = struct.interfaces.map { type -> + if (type is LanguageType.Custom && type.name.endsWith(".Call")) { + type.copy(generics = listOf(LanguageType.Custom("[A] =>> A"))) + } else type + } + ) + } + } + + private fun isRequestObject(namespace: Namespace): Boolean { + val requestStruct = namespace.elements.filterIsInstance() + .firstOrNull { it.name.pascalCase() == "Request" } ?: return false + return (requestStruct.constructors.size == 1 && requestStruct.constructors.single().parameters.isEmpty()) || + (requestStruct.fields.isEmpty() && requestStruct.constructors.isEmpty()) + } + + private fun Namespace.injectHandleFunction(): Namespace = transform { + matchingElements { iface: Interface -> + if (iface.name == Name.of("Handler") || iface.name == Name.of("Call")) { + iface.copy( + typeParameters = listOf(TypeParameter(LanguageType.Custom("F[_]"))), + elements = iface.elements.map { element -> + if (element is LanguageFunction) { + element.copy( + isAsync = false, + returnType = element.returnType?.let { LanguageType.Custom("F", generics = listOf(it)) }, + ) + } else element + }, + ) + } else iface + } + } + + private fun Namespace.withClientServerObjects(endpoint: Endpoint, requestIsObject: Boolean): Namespace { + val reqType = if (requestIsObject) "Request.type" else "Request" + val pathTemplate = "/" + endpoint.path.joinToString("/") { + when (it) { + is Endpoint.Segment.Literal -> it.value + is Endpoint.Segment.Param -> "{${it.identifier.value}}" + } + } + val clientObject = raw( + """ + |object Client extends Wirespec.Client[$reqType, Response[?]] { + | override val pathTemplate: String = "$pathTemplate" + | override val method: String = "${endpoint.method}" + | override def client(serialization: Wirespec.Serialization): Wirespec.ClientEdge[$reqType, Response[?]] = new Wirespec.ClientEdge[$reqType, Response[?]] { + | override def to(request: $reqType): Wirespec.RawRequest = toRawRequest(serialization, request) + | override def from(response: Wirespec.RawResponse): Response[?] = fromRawResponse(serialization, response) + | } + |} + """.trimMargin() + ) + val serverObject = raw( + """ + |object Server extends Wirespec.Server[$reqType, Response[?]] { + | override val pathTemplate: String = "$pathTemplate" + | override val method: String = "${endpoint.method}" + | override def server(serialization: Wirespec.Serialization): Wirespec.ServerEdge[$reqType, Response[?]] = new Wirespec.ServerEdge[$reqType, Response[?]] { + | override def from(request: Wirespec.RawRequest): $reqType = fromRawRequest(serialization, request) + | override def to(response: Response[?]): Wirespec.RawResponse = toRawResponse(serialization, response) + | } + |} + """.trimMargin() + ) + return copy(elements = elements + clientObject + serverObject) + } + + companion object : Keywords { + override val reservedKeywords = setOf( + "abstract", "case", "class", "def", "do", + "else", "extends", "false", "final", "for", + "forSome", "if", "implicit", "import", "lazy", + "match", "new", "null", "object", "override", + "package", "private", "protected", "return", "sealed", + "super", "this", "throw", "trait", "true", + "try", "type", "val", "var", "while", + "with", "yield", "given", "using", "enum", + "export", "then", + ) + } + +} diff --git a/src/compiler/emitters/scala/src/commonTest/kotlin/community/flock/wirespec/emitters/scala/ScalaIrEmitterTest.kt b/src/compiler/emitters/scala/src/commonTest/kotlin/community/flock/wirespec/emitters/scala/ScalaIrEmitterTest.kt new file mode 100644 index 000000000..d84205731 --- /dev/null +++ b/src/compiler/emitters/scala/src/commonTest/kotlin/community/flock/wirespec/emitters/scala/ScalaIrEmitterTest.kt @@ -0,0 +1,1006 @@ +package community.flock.wirespec.emitters.scala + +import arrow.core.nonEmptyListOf +import arrow.core.nonEmptySetOf +import community.flock.wirespec.compiler.core.EmitContext +import community.flock.wirespec.compiler.core.FileUri +import community.flock.wirespec.compiler.core.parse.ast.AST +import community.flock.wirespec.compiler.core.parse.ast.Definition +import community.flock.wirespec.compiler.core.parse.ast.Module +import community.flock.wirespec.compiler.test.CompileChannelTest +import community.flock.wirespec.compiler.test.CompileComplexModelTest +import community.flock.wirespec.compiler.test.CompileEnumTest +import community.flock.wirespec.compiler.test.CompileFullEndpointTest +import community.flock.wirespec.compiler.test.CompileMinimalEndpointTest +import community.flock.wirespec.compiler.test.CompileNestedTypeTest +import community.flock.wirespec.compiler.test.CompileRefinedTest +import community.flock.wirespec.compiler.test.CompileTypeTest +import community.flock.wirespec.compiler.test.CompileUnionTest +import community.flock.wirespec.compiler.test.NodeFixtures +import community.flock.wirespec.compiler.utils.NoLogger +import io.kotest.assertions.arrow.core.shouldBeRight +import io.kotest.matchers.shouldBe +import kotlin.test.Test + +class ScalaIrEmitterTest { + + private val emitContext = object : EmitContext, NoLogger { + override val emitters = nonEmptySetOf(ScalaIrEmitter()) + } + + @Test + fun testEmitterType() { + val expected = listOf( + """ + |package community.flock.wirespec.generated.model + |case class Todo( + | val name: String, + | val description: Option[String], + | val notes: List[String], + | val done: Boolean + |) extends Wirespec.Model { + | override def validate(): List[String] = + | List.empty[String] + |} + | + """.trimMargin(), + ) + + val res = emitContext.emitFirst(NodeFixtures.type) + res shouldBe expected + } + + @Test + fun testEmitterEmptyType() { + val expected = listOf( + """ + |package community.flock.wirespec.generated.model + |case class TodoWithoutProperties() extends Wirespec.Model { + | override def validate(): List[String] = + | List.empty[String] + |} + | + """.trimMargin(), + ) + + val res = emitContext.emitFirst(NodeFixtures.emptyType) + res shouldBe expected + } + + @Test + fun testEmitterRefined() { + val expected = listOf( + """ + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class UUID( + | override val value: String + |) extends Wirespec.Refined[String] { + | override def toString(): String = + | value + | override def validate(): Boolean = + | ${"\"\"\""}^[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}${'$'}${"\"\"\""}.r.findFirstIn(value).isDefined + |} + | + """.trimMargin(), + ) + + val res = emitContext.emitFirst(NodeFixtures.refined) + res shouldBe expected + } + + @Test + fun testEmitterEnum() { + val expected = listOf( + """ + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |enum TodoStatus(override val label: String) extends Wirespec.Enum { + | case OPEN extends TodoStatus("OPEN"), + | case IN_PROGRESS extends TodoStatus("IN_PROGRESS"), + | case CLOSE extends TodoStatus("CLOSE") + | override def toString(): String = { + | label + | } + |} + | + """.trimMargin(), + ) + + val res = emitContext.emitFirst(NodeFixtures.enum) + res shouldBe expected + } + + @Test + fun compileTypeTest() { + val scala = """ + |package community.flock.wirespec.generated.model + |case class Request( + | val `type`: String, + | val url: String, + | val BODY_TYPE: Option[String], + | val params: List[String], + | val headers: Map[String, String], + | val body: Option[Map[String, Option[List[Option[String]]]]] + |) extends Wirespec.Model { + | override def validate(): List[String] = + | List.empty[String] + |} + | + """.trimMargin() + + CompileTypeTest.compiler { ScalaIrEmitter() } shouldBeRight scala + } + + @Test + fun compileEnumTest() { + val scala = """ + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |enum MyAwesomeEnum(override val label: String) extends Wirespec.Enum { + | case ONE extends MyAwesomeEnum("ONE"), + | case Two extends MyAwesomeEnum("Two"), + | case THREE_MORE extends MyAwesomeEnum("THREE_MORE"), + | case UnitedKingdom extends MyAwesomeEnum("UnitedKingdom"), + | case _1 extends MyAwesomeEnum("-1"), + | case _0 extends MyAwesomeEnum("0"), + | case _10 extends MyAwesomeEnum("10"), + | case _999 extends MyAwesomeEnum("-999"), + | case _88 extends MyAwesomeEnum("88") + | override def toString(): String = { + | label + | } + |} + | + """.trimMargin() + + CompileEnumTest.compiler { ScalaIrEmitter() } shouldBeRight scala + } + + @Test + fun compileRefinedTest() { + val scala = """ + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class TodoId( + | override val value: String + |) extends Wirespec.Refined[String] { + | override def toString(): String = + | value + | override def validate(): Boolean = + | ${"\"\"\""}^[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}${'$'}${"\"\"\""}.r.findFirstIn(value).isDefined + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class TodoNoRegex( + | override val value: String + |) extends Wirespec.Refined[String] { + | override def toString(): String = + | value + | override def validate(): Boolean = + | true + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class TestInt( + | override val value: Long + |) extends Wirespec.Refined[Long] { + | override def toString(): String = + | value.toString + | override def validate(): Boolean = + | true + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class TestInt0( + | override val value: Long + |) extends Wirespec.Refined[Long] { + | override def toString(): String = + | value.toString + | override def validate(): Boolean = + | true + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class TestInt1( + | override val value: Long + |) extends Wirespec.Refined[Long] { + | override def toString(): String = + | value.toString + | override def validate(): Boolean = + | 0 <= value + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class TestInt2( + | override val value: Long + |) extends Wirespec.Refined[Long] { + | override def toString(): String = + | value.toString + | override def validate(): Boolean = + | 1 <= value && value <= 3 + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class TestNum( + | override val value: Double + |) extends Wirespec.Refined[Double] { + | override def toString(): String = + | value.toString + | override def validate(): Boolean = + | true + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class TestNum0( + | override val value: Double + |) extends Wirespec.Refined[Double] { + | override def toString(): String = + | value.toString + | override def validate(): Boolean = + | true + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class TestNum1( + | override val value: Double + |) extends Wirespec.Refined[Double] { + | override def toString(): String = + | value.toString + | override def validate(): Boolean = + | value <= 0.5 + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class TestNum2( + | override val value: Double + |) extends Wirespec.Refined[Double] { + | override def toString(): String = + | value.toString + | override def validate(): Boolean = + | -0.2 <= value && value <= 0.5 + |} + | + """.trimMargin() + + CompileRefinedTest.compiler { ScalaIrEmitter() } shouldBeRight scala + } + + @Test + fun compileUnionTest() { + val scala = """ + |package community.flock.wirespec.generated.model + |sealed trait UserAccount + | + |package community.flock.wirespec.generated.model + |case class UserAccountPassword( + | val username: String, + | val password: String + |) extends Wirespec.Model with UserAccount { + | override def validate(): List[String] = + | List.empty[String] + |} + | + |package community.flock.wirespec.generated.model + |case class UserAccountToken( + | val token: String + |) extends Wirespec.Model with UserAccount { + | override def validate(): List[String] = + | List.empty[String] + |} + | + |package community.flock.wirespec.generated.model + |case class User( + | val username: String, + | val account: UserAccount + |) extends Wirespec.Model { + | override def validate(): List[String] = + | List.empty[String] + |} + | + """.trimMargin() + + CompileUnionTest.compiler { ScalaIrEmitter() } shouldBeRight scala + } + + @Test + fun compileChannelTest() { + val scala = """ + |package community.flock.wirespec.generated.channel + |trait Queue extends Wirespec.Channel { + | def invoke(message: String): Unit + |} + | + """.trimMargin() + + CompileChannelTest.compiler { ScalaIrEmitter() } shouldBeRight scala + } + + @Test + fun compileMinimalEndpointTest() { + val scala = """ + |package community.flock.wirespec.generated.endpoint + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |import community.flock.wirespec.generated.model.TodoDto + |object GetTodos extends Wirespec.Endpoint { + | object Path extends Wirespec.Path + | object Queries extends Wirespec.Queries + | object RequestHeaders extends Wirespec.Request.Headers + | object Request extends Wirespec.Request[Unit] { + | override val path: Path.type = Path + | override val method: Wirespec.Method = Wirespec.Method.GET + | override val queries: Queries.type = Queries + | override val headers: RequestHeaders.type = RequestHeaders + | override val body: Unit = () } + | sealed trait Response[T] extends Wirespec.Response[T] + | sealed trait Response2XX[T] extends Response[T] + | sealed trait ResponseListTodoDto extends Response[List[TodoDto]] + | object Response200Headers extends Wirespec.Response.Headers + | case class Response200( + | override val status: Int, + | override val headers: Response200Headers.type, + | override val body: List[TodoDto] + | ) extends Response2XX[List[TodoDto]] with ResponseListTodoDto { + | def this(body: List[TodoDto]) = this(200, Response200Headers, body) + | } + | def toRawRequest(serialization: Wirespec.Serializer, request: Request.type): Wirespec.RawRequest = + | new Wirespec.RawRequest( + | method = request.method.toString, + | path = List("todos"), + | queries = Map.empty, + | headers = Map.empty, + | body = None + | ) + | def fromRawRequest(serialization: Wirespec.Deserializer, request: Wirespec.RawRequest): Request.type = + | Request + | def toRawResponse(serialization: Wirespec.Serializer, response: Response[?]): Wirespec.RawResponse = { + | response match { + | case r: Response200 => { + | new Wirespec.RawResponse( + | statusCode = r.status, + | headers = Map.empty, + | body = Some(serialization.serializeBody(r.body, scala.reflect.classTag[List[TodoDto]])) + | ) + | } + | case _ => { + | throw new IllegalStateException(("Cannot match response with status: " + response.status)) + | } + | } + | } + | def fromRawResponse(serialization: Wirespec.Deserializer, response: Wirespec.RawResponse): Response[?] = { + | response.statusCode match { + | case 200 => { + | new Response200(body = (response.body.map(it => serialization.deserializeBody[List[TodoDto]](it, scala.reflect.classTag[List[TodoDto]])).getOrElse(throw new IllegalStateException("body is null")))) + | } + | case _ => { + | throw new IllegalStateException(("Cannot match response with status: " + response.statusCode)) + | } + | } + | } + | trait Handler[F[_]] extends Wirespec.Handler { + | def getTodos(request: Request.type): F[Response[?]] + | } + | trait Call[F[_]] extends Wirespec.Call { + | def getTodos(): F[Response[?]] + | } + | object Client extends Wirespec.Client[Request.type, Response[?]] { + | override val pathTemplate: String = "/todos" + | override val method: String = "GET" + | override def client(serialization: Wirespec.Serialization): Wirespec.ClientEdge[Request.type, Response[?]] = new Wirespec.ClientEdge[Request.type, Response[?]] { + | override def to(request: Request.type): Wirespec.RawRequest = toRawRequest(serialization, request) + | override def from(response: Wirespec.RawResponse): Response[?] = fromRawResponse(serialization, response) + | } + | } + | object Server extends Wirespec.Server[Request.type, Response[?]] { + | override val pathTemplate: String = "/todos" + | override val method: String = "GET" + | override def server(serialization: Wirespec.Serialization): Wirespec.ServerEdge[Request.type, Response[?]] = new Wirespec.ServerEdge[Request.type, Response[?]] { + | override def from(request: Wirespec.RawRequest): Request.type = fromRawRequest(serialization, request) + | override def to(response: Response[?]): Wirespec.RawResponse = toRawResponse(serialization, response) + | } + | } + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class TodoDto( + | val description: String + |) extends Wirespec.Model { + | override def validate(): List[String] = + | List.empty[String] + |} + | + |package community.flock.wirespec.generated.client + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |import community.flock.wirespec.generated.model.TodoDto + |import community.flock.wirespec.generated.endpoint.GetTodos + |case class GetTodosClient( + | val serialization: Wirespec.Serialization, + | val transportation: Wirespec.Transportation + |) extends GetTodos.Call[[A] =>> A] { + | override def getTodos(): GetTodos.Response[?] = { + | val request = GetTodos.Request + | val rawRequest = GetTodos.toRawRequest(serialization, request) + | val rawResponse = transportation.transport(rawRequest) + | GetTodos.fromRawResponse(serialization, rawResponse) + | } + |} + | + |package community.flock.wirespec.generated + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |import community.flock.wirespec.generated.model.TodoDto + |import community.flock.wirespec.generated.endpoint.GetTodos + |import community.flock.wirespec.generated.client.GetTodosClient + |case class Client( + | val serialization: Wirespec.Serialization, + | val transportation: Wirespec.Transportation + |) extends GetTodos.Call[[A] =>> A] { + | override def getTodos(): GetTodos.Response[?] = + | new GetTodosClient( + | serialization = serialization, + | transportation = transportation + | ).getTodos() + |} + | + """.trimMargin() + + CompileMinimalEndpointTest.compiler { ScalaIrEmitter() } shouldBeRight scala + } + + @Test + fun compileFullEndpointTest() { + val scala = """ + |package community.flock.wirespec.generated.endpoint + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |import community.flock.wirespec.generated.model.Token + |import community.flock.wirespec.generated.model.PotentialTodoDto + |import community.flock.wirespec.generated.model.TodoDto + |import community.flock.wirespec.generated.model.Error + |object PutTodo extends Wirespec.Endpoint { + | case class Path( + | val id: String + | ) extends Wirespec.Path + | case class Queries( + | val done: Boolean, + | val name: Option[String] + | ) extends Wirespec.Queries + | case class RequestHeaders( + | val token: Token, + | val refreshToken: Option[Token] + | ) extends Wirespec.Request.Headers + | case class Request( + | override val path: Path, + | override val method: Wirespec.Method, + | override val queries: Queries, + | override val headers: RequestHeaders, + | override val body: PotentialTodoDto + | ) extends Wirespec.Request[PotentialTodoDto] { + | def this(id: String, done: Boolean, name: Option[String], token: Token, refreshToken: Option[Token], body: PotentialTodoDto) = this(Path(id = id), Wirespec.Method.PUT, Queries( + | done = done, + | name = name + | ), RequestHeaders( + | token = token, + | refreshToken = refreshToken + | ), body) + | } + | sealed trait Response[T] extends Wirespec.Response[T] + | sealed trait Response2XX[T] extends Response[T] + | sealed trait Response5XX[T] extends Response[T] + | sealed trait ResponseTodoDto extends Response[TodoDto] + | sealed trait ResponseError extends Response[Error] + | object Response200Headers extends Wirespec.Response.Headers + | case class Response200( + | override val status: Int, + | override val headers: Response200Headers.type, + | override val body: TodoDto + | ) extends Response2XX[TodoDto] with ResponseTodoDto { + | def this(body: TodoDto) = this(200, Response200Headers, body) + | } + | case class Response201Headers( + | val token: Token, + | val refreshToken: Option[Token] + | ) extends Wirespec.Response.Headers + | case class Response201( + | override val status: Int, + | override val headers: Response201Headers, + | override val body: TodoDto + | ) extends Response2XX[TodoDto] with ResponseTodoDto { + | def this(token: Token, refreshToken: Option[Token], body: TodoDto) = this(201, Response201Headers( + | token = token, + | refreshToken = refreshToken + | ), body) + | } + | object Response500Headers extends Wirespec.Response.Headers + | case class Response500( + | override val status: Int, + | override val headers: Response500Headers.type, + | override val body: Error + | ) extends Response5XX[Error] with ResponseError { + | def this(body: Error) = this(500, Response500Headers, body) + | } + | def toRawRequest(serialization: Wirespec.Serializer, request: Request): Wirespec.RawRequest = + | new Wirespec.RawRequest( + | method = request.method.toString, + | path = List("todos", serialization.serializePath[String](request.path.id, scala.reflect.classTag[String])), + | queries = Map("done" -> serialization.serializeParam[Boolean](request.queries.done, scala.reflect.classTag[Boolean]), "name" -> (request.queries.name.map(it => serialization.serializeParam[String](it, scala.reflect.classTag[String])).getOrElse(List.empty[String]))), + | headers = Map("token" -> serialization.serializeParam[Token](request.headers.token, scala.reflect.classTag[Token]), "Refresh-Token" -> (request.headers.refreshToken.map(it => serialization.serializeParam[Token](it, scala.reflect.classTag[Token])).getOrElse(List.empty[String]))), + | body = Some(serialization.serializeBody[PotentialTodoDto](request.body, scala.reflect.classTag[PotentialTodoDto])) + | ) + | def fromRawRequest(serialization: Wirespec.Deserializer, request: Wirespec.RawRequest): Request = + | new Request( + | id = serialization.deserializePath[String](request.path(1), scala.reflect.classTag[String]), + | done = (request.queries.get("done").map(it => serialization.deserializeParam[Boolean](it, scala.reflect.classTag[Boolean])).getOrElse(throw new IllegalStateException("Param done cannot be null"))), + | name = (request.queries.get("name").map(it => serialization.deserializeParam[String](it, scala.reflect.classTag[String]))), + | token = (request.headers.find(_._1.equalsIgnoreCase("token")).map(_._2).map(it => serialization.deserializeParam[Token](it, scala.reflect.classTag[Token])).getOrElse(throw new IllegalStateException("Param token cannot be null"))), + | refreshToken = (request.headers.find(_._1.equalsIgnoreCase("Refresh-Token")).map(_._2).map(it => serialization.deserializeParam[Token](it, scala.reflect.classTag[Token]))), + | body = (request.body.map(it => serialization.deserializeBody[PotentialTodoDto](it, scala.reflect.classTag[PotentialTodoDto])).getOrElse(throw new IllegalStateException("body is null"))) + | ) + | def toRawResponse(serialization: Wirespec.Serializer, response: Response[?]): Wirespec.RawResponse = { + | response match { + | case r: Response200 => { + | new Wirespec.RawResponse( + | statusCode = r.status, + | headers = Map.empty, + | body = Some(serialization.serializeBody(r.body, scala.reflect.classTag[TodoDto])) + | ) + | } + | case r: Response201 => { + | new Wirespec.RawResponse( + | statusCode = r.status, + | headers = Map("token" -> serialization.serializeParam[Token](r.headers.token, scala.reflect.classTag[Token]), "refreshToken" -> (r.headers.refreshToken.map(it => serialization.serializeParam[Token](it, scala.reflect.classTag[Token])).getOrElse(List.empty[String]))), + | body = Some(serialization.serializeBody(r.body, scala.reflect.classTag[TodoDto])) + | ) + | } + | case r: Response500 => { + | new Wirespec.RawResponse( + | statusCode = r.status, + | headers = Map.empty, + | body = Some(serialization.serializeBody(r.body, scala.reflect.classTag[Error])) + | ) + | } + | case _ => { + | throw new IllegalStateException(("Cannot match response with status: " + response.status)) + | } + | } + | } + | def fromRawResponse(serialization: Wirespec.Deserializer, response: Wirespec.RawResponse): Response[?] = { + | response.statusCode match { + | case 200 => { + | new Response200(body = (response.body.map(it => serialization.deserializeBody[TodoDto](it, scala.reflect.classTag[TodoDto])).getOrElse(throw new IllegalStateException("body is null")))) + | } + | case 201 => { + | new Response201( + | token = (response.headers.find(_._1.equalsIgnoreCase("token")).map(_._2).map(it => serialization.deserializeParam[Token](it, scala.reflect.classTag[Token])).getOrElse(throw new IllegalStateException("Param token cannot be null"))), + | refreshToken = (response.headers.find(_._1.equalsIgnoreCase("refreshToken")).map(_._2).map(it => serialization.deserializeParam[Token](it, scala.reflect.classTag[Token]))), + | body = (response.body.map(it => serialization.deserializeBody[TodoDto](it, scala.reflect.classTag[TodoDto])).getOrElse(throw new IllegalStateException("body is null"))) + | ) + | } + | case 500 => { + | new Response500(body = (response.body.map(it => serialization.deserializeBody[Error](it, scala.reflect.classTag[Error])).getOrElse(throw new IllegalStateException("body is null")))) + | } + | case _ => { + | throw new IllegalStateException(("Cannot match response with status: " + response.statusCode)) + | } + | } + | } + | trait Handler[F[_]] extends Wirespec.Handler { + | def putTodo(request: Request): F[Response[?]] + | } + | trait Call[F[_]] extends Wirespec.Call { + | def putTodo(id: String, done: Boolean, name: Option[String], token: Token, refreshToken: Option[Token], body: PotentialTodoDto): F[Response[?]] + | } + | object Client extends Wirespec.Client[Request, Response[?]] { + | override val pathTemplate: String = "/todos/{id}" + | override val method: String = "PUT" + | override def client(serialization: Wirespec.Serialization): Wirespec.ClientEdge[Request, Response[?]] = new Wirespec.ClientEdge[Request, Response[?]] { + | override def to(request: Request): Wirespec.RawRequest = toRawRequest(serialization, request) + | override def from(response: Wirespec.RawResponse): Response[?] = fromRawResponse(serialization, response) + | } + | } + | object Server extends Wirespec.Server[Request, Response[?]] { + | override val pathTemplate: String = "/todos/{id}" + | override val method: String = "PUT" + | override def server(serialization: Wirespec.Serialization): Wirespec.ServerEdge[Request, Response[?]] = new Wirespec.ServerEdge[Request, Response[?]] { + | override def from(request: Wirespec.RawRequest): Request = fromRawRequest(serialization, request) + | override def to(response: Response[?]): Wirespec.RawResponse = toRawResponse(serialization, response) + | } + | } + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class PotentialTodoDto( + | val name: String, + | val done: Boolean + |) extends Wirespec.Model { + | override def validate(): List[String] = + | List.empty[String] + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class Token( + | val iss: String + |) extends Wirespec.Model { + | override def validate(): List[String] = + | List.empty[String] + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class TodoDto( + | val id: String, + | val name: String, + | val done: Boolean + |) extends Wirespec.Model { + | override def validate(): List[String] = + | List.empty[String] + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class Error( + | val code: Long, + | val description: String + |) extends Wirespec.Model { + | override def validate(): List[String] = + | List.empty[String] + |} + | + |package community.flock.wirespec.generated.client + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |import community.flock.wirespec.generated.model.Token + |import community.flock.wirespec.generated.model.PotentialTodoDto + |import community.flock.wirespec.generated.model.TodoDto + |import community.flock.wirespec.generated.model.Error + |import community.flock.wirespec.generated.endpoint.PutTodo + |case class PutTodoClient( + | val serialization: Wirespec.Serialization, + | val transportation: Wirespec.Transportation + |) extends PutTodo.Call[[A] =>> A] { + | override def putTodo(id: String, done: Boolean, name: Option[String], token: Token, refreshToken: Option[Token], body: PotentialTodoDto): PutTodo.Response[?] = { + | val request = new PutTodo.Request( + | id = id, + | done = done, + | name = name, + | token = token, + | refreshToken = refreshToken, + | body = body + | ) + | val rawRequest = PutTodo.toRawRequest(serialization, request) + | val rawResponse = transportation.transport(rawRequest) + | PutTodo.fromRawResponse(serialization, rawResponse) + | } + |} + | + |package community.flock.wirespec.generated + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |import community.flock.wirespec.generated.model.Token + |import community.flock.wirespec.generated.model.PotentialTodoDto + |import community.flock.wirespec.generated.model.TodoDto + |import community.flock.wirespec.generated.model.Error + |import community.flock.wirespec.generated.endpoint.PutTodo + |import community.flock.wirespec.generated.client.PutTodoClient + |case class Client( + | val serialization: Wirespec.Serialization, + | val transportation: Wirespec.Transportation + |) extends PutTodo.Call[[A] =>> A] { + | override def putTodo(id: String, done: Boolean, name: Option[String], token: Token, refreshToken: Option[Token], body: PotentialTodoDto): PutTodo.Response[?] = + | new PutTodoClient( + | serialization = serialization, + | transportation = transportation + | ).putTodo(id, done, name, token, refreshToken, body) + |} + | + """.trimMargin() + + CompileFullEndpointTest.compiler { ScalaIrEmitter() } shouldBeRight scala + } + + @Test + fun compileNestedTypeTest() { + val scala = """ + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class DutchPostalCode( + | override val value: String + |) extends Wirespec.Refined[String] { + | override def toString(): String = + | value + | override def validate(): Boolean = + | ${"\"\"\""}^([0-9]{4}[A-Z]{2})${'$'}${"\"\"\""}.r.findFirstIn(value).isDefined + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class Address( + | val street: String, + | val houseNumber: Long, + | val postalCode: DutchPostalCode + |) extends Wirespec.Model { + | override def validate(): List[String] = + | if (!postalCode.validate()) List("postalCode") else List.empty[String] + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class Person( + | val name: String, + | val address: Address, + | val tags: List[String] + |) extends Wirespec.Model { + | override def validate(): List[String] = + | address.validate().map(e => s"address.${'$'}{e}") + |} + | + """.trimMargin() + + CompileNestedTypeTest.compiler { ScalaIrEmitter() } shouldBeRight scala + } + + @Test + fun compileComplexModelTest() { + val scala = """ + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class Email( + | override val value: String + |) extends Wirespec.Refined[String] { + | override def toString(): String = + | value + | override def validate(): Boolean = + | ${"\"\"\""}^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}${'$'}${"\"\"\""}.r.findFirstIn(value).isDefined + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class PhoneNumber( + | override val value: String + |) extends Wirespec.Refined[String] { + | override def toString(): String = + | value + | override def validate(): Boolean = + | ${"\"\"\""}^\+[1-9]\d{1,14}${'$'}${"\"\"\""}.r.findFirstIn(value).isDefined + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class Tag( + | override val value: String + |) extends Wirespec.Refined[String] { + | override def toString(): String = + | value + | override def validate(): Boolean = + | ${"\"\"\""}^[a-z][a-z0-9-]{0,19}${'$'}${"\"\"\""}.r.findFirstIn(value).isDefined + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class EmployeeAge( + | override val value: Long + |) extends Wirespec.Refined[Long] { + | override def toString(): String = + | value.toString + | override def validate(): Boolean = + | 18 <= value && value <= 65 + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class ContactInfo( + | val email: Email, + | val phone: Option[PhoneNumber] + |) extends Wirespec.Model { + | override def validate(): List[String] = + | (if (!email.validate()) List("email") else List.empty[String]) ++ (phone.map(it => if (!it.validate()) List("phone") else List.empty[String]).getOrElse(List.empty[String])) + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class Employee( + | val name: String, + | val age: EmployeeAge, + | val contactInfo: ContactInfo, + | val tags: List[Tag] + |) extends Wirespec.Model { + | override def validate(): List[String] = + | (if (!age.validate()) List("age") else List.empty[String]) ++ contactInfo.validate().map(e => s"contactInfo.${'$'}{e}") ++ tags.zipWithIndex.flatMap { case (el, i) => if (!el.validate()) List(s"tags[${'$'}{i}]") else List.empty[String] } + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class Department( + | val name: String, + | val employees: List[Employee] + |) extends Wirespec.Model { + | override def validate(): List[String] = + | employees.zipWithIndex.flatMap { case (el, i) => el.validate().map(e => s"employees[${'$'}{i}].${'$'}{e}") } + |} + | + |package community.flock.wirespec.generated.model + |import community.flock.wirespec.scala.Wirespec + |import scala.reflect.ClassTag + |case class Company( + | val name: String, + | val departments: List[Department] + |) extends Wirespec.Model { + | override def validate(): List[String] = + | departments.zipWithIndex.flatMap { case (el, i) => el.validate().map(e => s"departments[${'$'}{i}].${'$'}{e}") } + |} + | + """.trimMargin() + + CompileComplexModelTest.compiler { ScalaIrEmitter() } shouldBeRight scala + } + + @Test + fun sharedOutputTest() { + val expected = """ + |package community.flock.wirespec.scala + |import scala.reflect.ClassTag + |object Wirespec { + | trait Model { + | def validate(): List[String] + | } + | trait Enum { + | def label: String + | } + | trait Endpoint + | trait Channel + | trait Refined[T] { + | def value: T + | def validate(): Boolean + | } + | trait Path + | trait Queries + | trait Headers + | trait Handler + | trait Call + | enum Method { + | case GET + | case PUT + | case POST + | case DELETE + | case OPTIONS + | case HEAD + | case PATCH + | case TRACE + | } + | object Request { + | trait Headers + | } + | trait Request[T] { + | def path: Path + | def method: Method + | def queries: Queries + | def headers: Request.Headers + | def body: T + | } + | object Response { + | trait Headers + | } + | trait Response[T] { + | def status: Int + | def headers: Response.Headers + | def body: T + | } + | trait BodySerializer { + | def serializeBody[T](t: T, `type`: scala.reflect.ClassTag[?]): Array[Byte] + | } + | trait BodyDeserializer { + | def deserializeBody[T](raw: Array[Byte], `type`: scala.reflect.ClassTag[?]): T + | } + | trait BodySerialization extends BodySerializer with BodyDeserializer + | trait PathSerializer { + | def serializePath[T](t: T, `type`: scala.reflect.ClassTag[?]): String + | } + | trait PathDeserializer { + | def deserializePath[T](raw: String, `type`: scala.reflect.ClassTag[?]): T + | } + | trait PathSerialization extends PathSerializer with PathDeserializer + | trait ParamSerializer { + | def serializeParam[T](value: T, `type`: scala.reflect.ClassTag[?]): List[String] + | } + | trait ParamDeserializer { + | def deserializeParam[T](values: List[String], `type`: scala.reflect.ClassTag[?]): T + | } + | trait ParamSerialization extends ParamSerializer with ParamDeserializer + | trait Serializer extends BodySerializer with PathSerializer with ParamSerializer + | trait Deserializer extends BodyDeserializer with PathDeserializer with ParamDeserializer + | trait Serialization extends Serializer with Deserializer + | case class RawRequest( + | val method: String, + | val path: List[String], + | val queries: Map[String, List[String]], + | val headers: Map[String, List[String]], + | val body: Option[Array[Byte]] + | ) + | case class RawResponse( + | val statusCode: Int, + | val headers: Map[String, List[String]], + | val body: Option[Array[Byte]] + | ) + | trait Transportation { + | def transport(request: RawRequest): RawResponse + | } + | trait ServerEdge[Req <: Request[?], Res <: Response[?]] { + | def from(request: RawRequest): Req + | def to(response: Res): RawResponse + | } + | trait ClientEdge[Req <: Request[?], Res <: Response[?]] { + | def to(request: Req): RawRequest + | def from(response: RawResponse): Res + | } + | trait Client[Req <: Request[?], Res <: Response[?]] { + | def pathTemplate: String + | def method: String + | def client(serialization: Serialization): ClientEdge[Req, Res] + | } + | trait Server[Req <: Request[?], Res <: Response[?]] { + | def pathTemplate: String + | def method: String + | def server(serialization: Serialization): ServerEdge[Req, Res] + | } + |} + | + """.trimMargin() + + val emitter = ScalaIrEmitter() + emitter.shared.source shouldBe expected + } + + private fun EmitContext.emitFirst(node: Definition) = emitters.map { + val ast = AST( + nonEmptyListOf( + Module( + FileUri(""), + nonEmptyListOf(node), + ), + ), + ) + it.emit(ast, logger).first().result + } +} diff --git a/src/compiler/emitters/typescript/build.gradle.kts b/src/compiler/emitters/typescript/build.gradle.kts index 217e6b788..68b5de636 100644 --- a/src/compiler/emitters/typescript/build.gradle.kts +++ b/src/compiler/emitters/typescript/build.gradle.kts @@ -41,6 +41,7 @@ kotlin { commonMain { dependencies { api(project(":src:compiler:core")) + api(project(":src:compiler:ir")) } } commonTest { diff --git a/src/compiler/emitters/typescript/src/commonMain/kotlin/community/flock/wirespec/emitters/typescript/TypeScriptIrEmitter.kt b/src/compiler/emitters/typescript/src/commonMain/kotlin/community/flock/wirespec/emitters/typescript/TypeScriptIrEmitter.kt new file mode 100644 index 000000000..dfdf62223 --- /dev/null +++ b/src/compiler/emitters/typescript/src/commonMain/kotlin/community/flock/wirespec/emitters/typescript/TypeScriptIrEmitter.kt @@ -0,0 +1,472 @@ +package community.flock.wirespec.emitters.typescript + +import arrow.core.NonEmptyList +import community.flock.wirespec.compiler.core.emit.DEFAULT_SHARED_PACKAGE_STRING +import community.flock.wirespec.compiler.core.emit.Emitted +import community.flock.wirespec.compiler.core.emit.FileExtension +import community.flock.wirespec.ir.emit.IrEmitter +import community.flock.wirespec.compiler.core.emit.PackageName +import community.flock.wirespec.compiler.core.emit.Shared +import community.flock.wirespec.compiler.core.emit.importReferences +import community.flock.wirespec.compiler.core.emit.namespace +import community.flock.wirespec.compiler.core.emit.plus +import community.flock.wirespec.compiler.core.parse.ast.AST +import community.flock.wirespec.compiler.core.parse.ast.Channel +import community.flock.wirespec.compiler.core.parse.ast.Definition +import community.flock.wirespec.compiler.core.parse.ast.Endpoint +import community.flock.wirespec.compiler.core.parse.ast.Identifier +import community.flock.wirespec.compiler.core.parse.ast.Module +import community.flock.wirespec.compiler.core.parse.ast.Reference +import community.flock.wirespec.compiler.core.parse.ast.Refined +import community.flock.wirespec.compiler.core.parse.ast.Union +import community.flock.wirespec.compiler.utils.Logger +import community.flock.wirespec.ir.converter.classifyValidatableFields +import community.flock.wirespec.ir.converter.convert +import community.flock.wirespec.ir.converter.convertConstraint +import community.flock.wirespec.ir.converter.convertWithValidation +import community.flock.wirespec.ir.converter.requestParameters +import community.flock.wirespec.compiler.core.emit.Keywords +import community.flock.wirespec.ir.core.Assignment +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Element +import community.flock.wirespec.ir.core.ErrorStatement +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.Literal +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Parameter +import community.flock.wirespec.ir.core.Switch +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.Type as LanguageType +import community.flock.wirespec.ir.core.Case +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.RawElement +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.Namespace +import community.flock.wirespec.ir.core.Transformer +import community.flock.wirespec.ir.core.findElement +import community.flock.wirespec.ir.core.raw +import community.flock.wirespec.ir.core.transform +import community.flock.wirespec.ir.core.transformChildren +import community.flock.wirespec.ir.core.transformer +import community.flock.wirespec.ir.generator.TypeScriptGenerator +import community.flock.wirespec.ir.generator.generateTypeScript +import community.flock.wirespec.compiler.core.parse.ast.Enum as AstEnum +import community.flock.wirespec.compiler.core.parse.ast.Shared as AstShared +import community.flock.wirespec.compiler.core.parse.ast.Type as AstType + +open class TypeScriptIrEmitter : IrEmitter { + + override val generator = TypeScriptGenerator + + override val extension = FileExtension.TypeScript + + override fun transformTestFile(file: File): File = file.transform { + apply(transformPatternSwitchToValueSwitch()) + } + + override val shared = object : Shared { + val api = """ + |export type Client, RES extends Response> = (serialization: Serialization) => { + | to: (request: REQ) => RawRequest; + | from: (response: RawResponse) => RES + |} + |export type Server, RES extends Response> = (serialization: Serialization) => { + | from: (request: RawRequest) => REQ; + | to: (response: RES) => RawResponse + |} + |export type Api, RES extends Response> = { + | name: string; + | method: Method, + | path: string, + | client: Client; + | server: Server + |} + """.trimMargin() + override val packageString = DEFAULT_SHARED_PACKAGE_STRING + override val source = AstShared(packageString) + .convert() + .transform { + injectBefore { namespace: Namespace -> + if (namespace.name == Name.of("Wirespec")) listOf(RawElement("export type Type = string")) + else emptyList() + } + injectAfter { namespace: Namespace -> + if (namespace.name == Name.of("Wirespec")) listOf(RawElement(api)) + else emptyList() + } + } + .generateTypeScript() + } + + override fun emit(ast: AST, logger: Logger): NonEmptyList = super.emit(ast, logger) + .plus( + ast.modules + .flatMap { it.statements } + .groupBy { def -> def.namespace() } + .map { (ns, defs) -> + Emitted( + "${ns}/index.${extension.value}", + defs.joinToString("\n") { "export {${it.identifier.value}} from './${it.identifier.value}'" } + ) + } + ) + + override fun emit(module: Module, logger: Logger): NonEmptyList = + super.emit(module, logger) + File(Name.of("Wirespec"), listOf(RawElement(shared.source))) + + override fun emit(definition: Definition, module: Module, logger: Logger): File { + val file = super.emit(definition, module, logger) + val subPackageName = PackageName("") + definition + return File( + name = Name.of(subPackageName.toDir() + file.name.pascalCase().sanitizeSymbol()), + elements = listOf(RawElement("import {Wirespec} from '../Wirespec'\n")) + file.elements + ) + } + + override fun emit(type: AstType, module: Module): File { + val fieldValidations = type.classifyValidatableFields(module) + val typeImports = type.importReferences().distinctBy { it.value } + .joinToString("\n") { "import {${it.value}} from './${it.value}'" } + val validateImports = fieldValidations.map { it.typeName }.distinct() + .filter { it != type.identifier.value } + .joinToString("\n") { "import {validate$it} from './$it'" } + val allImports = listOf(typeImports, validateImports).filter { it.isNotEmpty() }.joinToString("\n") + val fieldNames = type.shape.value.map { it.identifier.value }.toSet() + val file = type.convertWithValidation(module) + .sanitizeNames() + .transform { + matchingElements { fn: community.flock.wirespec.ir.core.Function -> + if (fn.name == Name.of("validate")) { + fn.copy( + name = Name.of("validate${type.identifier.value}"), + parameters = listOf(Parameter(Name.of("obj"), LanguageType.Custom(type.identifier.value))), + ).transform { + statementAndExpression { s, t -> + when { + s is FunctionCall && s.name == Name.of("validate") && s.receiver != null && s.typeArguments.isNotEmpty() -> { + val typeName = (s.typeArguments.first() as? LanguageType.Custom)?.name ?: "" + FunctionCall(name = Name.of("validate$typeName"), arguments = mapOf(Name.of("obj") to t.transformExpression(s.receiver!!))) + } + s is FieldCall && s.receiver == null && s.field.camelCase() in fieldNames -> + FieldCall(receiver = VariableReference(Name.of("obj")), field = s.field) + else -> s.transformChildren(t) + } + } + } + } else fn + } + } + return if (allImports.isNotEmpty()) file.copy(elements = listOf(RawElement(allImports)) + file.elements) + else file + } + + override fun emit(enum: AstEnum, module: Module): File = + enum.convert() + .sanitizeNames() + + override fun emit(union: Union): File { + val imports = union.importReferences().distinctBy { it.value } + .joinToString("\n") { "import {type ${it.value}} from '../model'" } + val file = union.convert().sanitizeNames() + return if (imports.isNotEmpty()) file.copy(elements = listOf(RawElement(imports)) + file.elements) + else file + } + + override fun emit(refined: Refined): File { + val converted = refined.convert() + val constraintExpr = refined.reference.convertConstraint(VariableReference(Name.of("value"))) + val validatorStr = TypeScriptGenerator.generateExpression(constraintExpr) + return File( + converted.name, listOf( + RawElement("export type ${converted.name.pascalCase()} = ${emitTypeScriptReference(refined.reference)};"), + RawElement("export const validate${refined.identifier.value} = (value: ${emitTypeScriptReference(refined.reference)}) =>\n $validatorStr;"), + ) + ) + } + + override fun emit(endpoint: Endpoint): File { + val imports = endpoint.importReferences().distinctBy { it.value } + .joinToString("\n") { "import {type ${it.value}} from '../model'" } + + val apiName = endpoint.identifier.value.firstToLower() + val method = endpoint.method.name + val pathString = endpoint.path.joinToString("/") { + when (it) { + is Endpoint.Segment.Literal -> it.value + is Endpoint.Segment.Param -> "{${it.identifier.value}}" + } + } + val api = """ + |export const client:Wirespec.Client = (serialization: Wirespec.Serialization) => ({ + | from: (it) => fromRawResponse(serialization, it), + | to: (it) => toRawRequest(serialization, it) + |}) + |export const server:Wirespec.Server = (serialization: Wirespec.Serialization) => ({ + | from: (it) => fromRawRequest(serialization, it), + | to: (it) => toRawResponse(serialization, it) + |}) + |export const api = { + | name: "$apiName", + | method: "$method", + | path: "$pathString", + | server, + | client + |} as const + """.trimMargin() + + val hasRequestParams = endpoint.requestParameters().isNotEmpty() + val endpointNamespace = endpoint.convert().sanitizeNames() + .transform { + statement { stmt, transformer -> + when (stmt) { + is Switch -> stmt.copy( + default = stmt.default?.map { s -> + if (s is ErrorStatement && s.message is BinaryOp) { + val binary = s.message as BinaryOp + val literal = binary.left as? Literal + if (literal != null) ErrorStatement(Literal(literal.value.toString().trimEnd(' '), literal.type)) + else s + } else s + } + ).transformChildren(transformer) + else -> stmt.transformChildren(transformer) + } + } + } + .transform { + apply(transformPatternSwitchToValueSwitch()) + } + .transform { + if (hasRequestParams) { + matchingElements { iface: community.flock.wirespec.ir.core.Interface -> + if (iface.name == Name.of("Call")) { + iface.copy( + elements = iface.elements.map { element -> + if (element is community.flock.wirespec.ir.core.Function) { + element.copy( + parameters = listOf( + Parameter(Name.of("params"), LanguageType.Custom("RequestParams")) + ) + ) + } else element + } + ) + } else iface + } + } + } + .findElement()!! + val body = endpointNamespace + .transform { injectAfter { _: Namespace -> listOf(raw(api)) } } + + return if (imports.isNotEmpty()) File(Name.of(endpoint.identifier.sanitize()), listOf(RawElement(imports), body)) + else File(Name.of(endpoint.identifier.sanitize()), listOf(body)) + } + + override fun emit(channel: Channel): File = + channel.convert() + .sanitizeNames() + + override fun emitEndpointClient(endpoint: Endpoint): File { + val endpointName = endpoint.identifier.value + val methodName = endpointName.firstToLower() + + val imports = endpoint.importReferences().distinctBy { it.value } + .joinToString("\n") { "import {type ${it.value}} from '../model'" } + + val params = buildEndpointParams(endpoint) + val paramList = if (params.isNotEmpty()) "params: $endpointName.RequestParams" else "" + + val requestArgs = if (params.isNotEmpty()) "$endpointName.request(params)" else "$endpointName.request()" + + val code = buildString { + appendLine("export const ${methodName}Client = (serialization: Wirespec.Serialization, transportation: Wirespec.Transportation) => ({") + appendLine(" $methodName: async ($paramList): Promise<$endpointName.Response> => {") + appendLine(" const request: $endpointName.Request = $requestArgs;") + appendLine(" const rawRequest = $endpointName.toRawRequest(serialization, request);") + appendLine(" const rawResponse = await transportation.transport(rawRequest);") + appendLine(" return $endpointName.fromRawResponse(serialization, rawResponse);") + appendLine(" }") + append("})") + } + + return File( + Name.of("client/${endpointName}Client"), + buildList { + add(RawElement("import {Wirespec} from '../Wirespec'")) + add(RawElement("import {$endpointName} from '../endpoint/$endpointName'")) + if (imports.isNotEmpty()) add(RawElement(imports)) + add(RawElement(code)) + } + ) + } + + override fun emitClient(endpoints: List, logger: Logger): File { + logger.info("Emitting main Client for ${endpoints.size} endpoints") + + val clientImports = endpoints.joinToString("\n") { + val methodName = it.identifier.value.firstToLower() + "import {${methodName}Client} from './client/${it.identifier.value}Client'" + } + + val spreadEntries = endpoints.joinToString("\n") { + val methodName = it.identifier.value.firstToLower() + " ...${methodName}Client(serialization, transportation)," + } + + val code = buildString { + appendLine("export const client = (serialization: Wirespec.Serialization, transportation: Wirespec.Transportation) => ({") + appendLine(spreadEntries) + append("})") + } + + return File( + Name.of("Client"), + listOf( + RawElement("import {Wirespec} from './Wirespec'"), + RawElement(clientImports), + RawElement(code), + ) + ) + } + + private fun transformPatternSwitchToValueSwitch(): Transformer = transformer { + statement { stmt, tr -> + if (stmt is Switch && stmt.cases.any { it.type != null }) { + val varName = stmt.variable?.camelCase() ?: "r" + val transformedCases = stmt.cases.map { case -> + val typeName = (case.type as? LanguageType.Custom)?.name + val statusNum = typeName + ?.substringAfterLast(".") + ?.removePrefix("Response") + ?.toIntOrNull() + if (statusNum != null && typeName != null) { + val exprCode = TypeScriptGenerator.generateExpression(tr.transformExpression(stmt.expression)) + val castAssignment = Assignment( + name = Name.of(varName), + value = RawExpression("$exprCode as $typeName"), + isProperty = false, + ) + Case( + value = Literal(statusNum, LanguageType.Integer()), + body = listOf(castAssignment) + case.body.map { tr.transformStatement(it) }, + type = null, + ) + } else { + case.copy(body = case.body.map { tr.transformStatement(it) }) + } + } + Switch( + expression = FieldCall( + receiver = tr.transformExpression(stmt.expression), + field = Name.of("status"), + ), + cases = transformedCases, + default = stmt.default?.map { tr.transformStatement(it) }, + variable = null, + ) + } else stmt.transformChildren(tr) + } + } + + private fun T.sanitizeNames(): T = transform { + fields { field -> + field.copy(name = field.name.sanitizeName()) + } + parameters { param -> + param.copy(name = param.name.sanitizeName()) + } + statementAndExpression { stmt, tr -> + when (stmt) { + is FieldCall -> FieldCall( + receiver = stmt.receiver?.let { tr.transformExpression(it) }, + field = stmt.field.sanitizeName(), + ) + is VariableReference -> VariableReference( + name = stmt.name.sanitizeName(), + ) + is ConstructorStatement -> ConstructorStatement( + type = tr.transformType(stmt.type), + namedArguments = stmt.namedArguments.map { (key, value) -> + key.sanitizeName() to tr.transformExpression(value) + }.toMap(), + ) + is Assignment -> Assignment( + name = stmt.name.sanitizeName(), + value = tr.transformExpression(stmt.value), + isProperty = stmt.isProperty, + ) + else -> stmt.transformChildren(tr) + } + } + } + + private fun Name.sanitizeName(): Name { + val sanitized = if (parts.size > 1) camelCase() else value().sanitizeSymbol() + return Name(listOf(sanitized)) + } + + private fun Identifier.sanitize() = "\"${value}\"" + + private fun String.sanitizeSymbol() = filter { it.isLetterOrDigit() || it == '_' } + + private fun String.sanitizeKeywords() = if (this in reservedKeywords) "_$this" else this + + private fun String.firstToLower() = replaceFirstChar { it.lowercase() } + + private fun sanitizeParamName(identifier: Identifier): String { + val parts = identifier.value.split(Regex("[.\\s-]+")).filter { it.isNotEmpty() } + val name = if (parts.size > 1) Name(parts).camelCase() else identifier.value + return name.sanitizeSymbol().sanitizeKeywords() + } + + private fun buildEndpointParams(endpoint: Endpoint): List = buildList { + endpoint.path.filterIsInstance().forEach { + add(EndpointParam(sanitizeParamName(it.identifier), emitTypeScriptReference(it.reference.copy(isNullable = false)), it.reference.isNullable)) + } + endpoint.queries.forEach { + add(EndpointParam(sanitizeParamName(it.identifier), emitTypeScriptReference(it.reference.copy(isNullable = false)), it.reference.isNullable)) + } + endpoint.headers.forEach { + add(EndpointParam(sanitizeParamName(it.identifier), emitTypeScriptReference(it.reference.copy(isNullable = false)), it.reference.isNullable)) + } + endpoint.requests.first().content?.let { + add(EndpointParam("body", emitTypeScriptReference(it.reference.copy(isNullable = false)), it.reference.isNullable)) + } + } + + private fun emitTypeScriptReference(ref: Reference): String = when (ref) { + is Reference.Dict -> "Record" + is Reference.Iterable -> "${emitTypeScriptReference(ref.reference)}[]" + is Reference.Unit -> "undefined" + is Reference.Any -> "any" + is Reference.Custom -> ref.value.sanitizeSymbol() + is Reference.Primitive -> when (ref.type) { + is Reference.Primitive.Type.String -> "string" + is Reference.Primitive.Type.Integer -> "number" + is Reference.Primitive.Type.Number -> "number" + is Reference.Primitive.Type.Boolean -> "boolean" + is Reference.Primitive.Type.Bytes -> "ArrayBuffer" + } + }.let { "$it${if (ref.isNullable) " | undefined" else ""}" } + + private data class EndpointParam(val name: String, val type: String, val nullable: Boolean) + + companion object : Keywords { + override val reservedKeywords = setOf( + "break", "case", "catch", "continue", "debugger", + "default", "delete", "do", "else", "finally", + "for", "function", "if", "in", "instanceof", + "new", "return", "switch", "this", "throw", + "try", "typeof", "var", "void", "while", + "with", "class", "const", "enum", "export", + "extends", "import", "super", "implements", + "interface", "let", "package", "private", + "protected", "public", "static", "yield", + "type", "async", "await", + ) + } + +} diff --git a/src/compiler/emitters/typescript/src/commonTest/kotlin/community/flock/wirespec/emitters/typescript/TypeScriptIrEmitterTest.kt b/src/compiler/emitters/typescript/src/commonTest/kotlin/community/flock/wirespec/emitters/typescript/TypeScriptIrEmitterTest.kt new file mode 100644 index 000000000..2b5a09649 --- /dev/null +++ b/src/compiler/emitters/typescript/src/commonTest/kotlin/community/flock/wirespec/emitters/typescript/TypeScriptIrEmitterTest.kt @@ -0,0 +1,725 @@ +package community.flock.wirespec.emitters.typescript + +import community.flock.wirespec.compiler.test.CompileChannelTest +import community.flock.wirespec.compiler.test.CompileComplexModelTest +import community.flock.wirespec.compiler.test.CompileEnumTest +import community.flock.wirespec.compiler.test.CompileFullEndpointTest +import community.flock.wirespec.compiler.test.CompileMinimalEndpointTest +import community.flock.wirespec.compiler.test.CompileNestedTypeTest +import community.flock.wirespec.compiler.test.CompileRefinedTest +import community.flock.wirespec.compiler.test.CompileTypeTest +import community.flock.wirespec.compiler.test.CompileUnionTest +import io.kotest.assertions.arrow.core.shouldBeRight +import io.kotest.matchers.shouldBe +import kotlin.test.Test + +class TypeScriptIrEmitterTest { + + @Test + fun compileFullEndpointTest() { + val typescript = """ + |import {Wirespec} from '../Wirespec' + |import {type Token} from '../model' + |import {type PotentialTodoDto} from '../model' + |import {type TodoDto} from '../model' + |import {type Error} from '../model' + |export namespace PutTodo { + | export type Path = { + | "id": string, + | } + | export type Queries = { + | "done": boolean, + | "name": string | undefined, + | } + | export type RequestHeaders = { + | "token": Token, + | "refreshToken": Token | undefined, + | } + | export type Request = { + | "path": Path, + | "method": Wirespec.Method, + | "queries": Queries, + | "headers": RequestHeaders, + | "body": PotentialTodoDto, + | } + | export type RequestParams = {"id": string, "done": boolean, "name"?: string, "token": Token, "refreshToken"?: Token, "body": PotentialTodoDto} + | export const request = (params: RequestParams): Request => ({ + | path: {"id": params["id"]}, + | method: "PUT", + | queries: {"done": params["done"], "name": params["name"]}, + | headers: {"token": params["token"], "refreshToken": params["refreshToken"]}, + | body: params.body, + | }) + | export type Response = Response2XX | Response5XX | ResponseTodoDto | ResponseError + | export type Response2XX = Response200 | Response201 + | export type Response5XX = Response500 + | export type ResponseTodoDto = Response200 | Response201 + | export type ResponseError = Response500 + | export type Response200 = { + | "status": 200, + | "headers": {}, + | "body": TodoDto, + | } + | export type Response200Params = {"body": TodoDto} + | export const response200 = (params: Response200Params): Response200 => ({ + | status: 200, + | headers: {}, + | body: params.body, + | }) + | export type Response201 = { + | "status": 201, + | "headers": {"token": Token, "refreshToken": Token | undefined}, + | "body": TodoDto, + | } + | export type Response201Params = {"token": Token, "refreshToken"?: Token, "body": TodoDto} + | export const response201 = (params: Response201Params): Response201 => ({ + | status: 201, + | headers: {"token": params["token"], "refreshToken": params["refreshToken"]}, + | body: params.body, + | }) + | export type Response500 = { + | "status": 500, + | "headers": {}, + | "body": Error, + | } + | export type Response500Params = {"body": Error} + | export const response500 = (params: Response500Params): Response500 => ({ + | status: 500, + | headers: {}, + | body: params.body, + | }) + | export function toRawRequest(serialization: Wirespec.Serializer, _request: Request): Wirespec.RawRequest { + | return { method: _request.method, path: ['todos', serialization.serializePath(_request.path.id, "string")], queries: { 'done': serialization.serializeParam(_request.queries.done, "boolean"), 'name': _request.queries.name != null ? serialization.serializeParam(_request.queries.name, "string") : [] as string[] }, headers: { 'token': serialization.serializeParam(_request.headers.token, "Token"), 'Refresh-Token': _request.headers.refreshToken != null ? serialization.serializeParam(_request.headers.refreshToken, "Token") : [] as string[] }, body: serialization.serializeBody(_request.body, "PotentialTodoDto") }; + | } + | export function fromRawRequest(serialization: Wirespec.Deserializer, _request: Wirespec.RawRequest): Request { + | return request({"id": serialization.deserializePath(_request.path[1], "string"), "done": _request.queries['done'] != null ? serialization.deserializeParam(_request.queries['done'], "boolean") : (() => { throw new Error('Param done cannot be null') })(), "name": _request.queries['name'] != null ? serialization.deserializeParam(_request.queries['name'], "string") : undefined, "token": Object.entries(_request.headers).find(([k]) => k.toLowerCase() === 'token'.toLowerCase())?.[1] != null ? serialization.deserializeParam(Object.entries(_request.headers).find(([k]) => k.toLowerCase() === 'token'.toLowerCase())?.[1]!, "Token") : (() => { throw new Error('Param token cannot be null') })(), "refreshToken": Object.entries(_request.headers).find(([k]) => k.toLowerCase() === 'Refresh-Token'.toLowerCase())?.[1] != null ? serialization.deserializeParam(Object.entries(_request.headers).find(([k]) => k.toLowerCase() === 'Refresh-Token'.toLowerCase())?.[1]!, "Token") : undefined, "body": _request.body != null ? serialization.deserializeBody(_request.body, "PotentialTodoDto") : (() => { throw new Error('body is null') })()}); + | } + | export function toRawResponse(serialization: Wirespec.Serializer, response: Response): Wirespec.RawResponse { + | switch (response.status) { + | case 200: { + | const r = response as Response200; + | return { statusCode: r.status, headers: {}, body: serialization.serializeBody(r.body, "TodoDto") }; + | } + | case 201: { + | const r = response as Response201; + | return { statusCode: r.status, headers: { 'token': serialization.serializeParam(r.headers.token, "Token"), 'refreshToken': r.headers.refreshToken != null ? serialization.serializeParam(r.headers.refreshToken, "Token") : [] as string[] }, body: serialization.serializeBody(r.body, "TodoDto") }; + | } + | case 500: { + | const r = response as Response500; + | return { statusCode: r.status, headers: {}, body: serialization.serializeBody(r.body, "Error") }; + | } + | default: { + | throw new Error('Cannot match response with status:'); + | } + | } + | } + | export function fromRawResponse(serialization: Wirespec.Deserializer, response: Wirespec.RawResponse): Response { + | switch (response.statusCode) { + | case 200: + | return response200({"body": response.body != null ? serialization.deserializeBody(response.body, "TodoDto") : (() => { throw new Error('body is null') })()}); + | break; + | case 201: + | return response201({"token": Object.entries(response.headers).find(([k]) => k.toLowerCase() === 'token'.toLowerCase())?.[1] != null ? serialization.deserializeParam(Object.entries(response.headers).find(([k]) => k.toLowerCase() === 'token'.toLowerCase())?.[1]!, "Token") : (() => { throw new Error('Param token cannot be null') })(), "refreshToken": Object.entries(response.headers).find(([k]) => k.toLowerCase() === 'refreshToken'.toLowerCase())?.[1] != null ? serialization.deserializeParam(Object.entries(response.headers).find(([k]) => k.toLowerCase() === 'refreshToken'.toLowerCase())?.[1]!, "Token") : undefined, "body": response.body != null ? serialization.deserializeBody(response.body, "TodoDto") : (() => { throw new Error('body is null') })()}); + | break; + | case 500: + | return response500({"body": response.body != null ? serialization.deserializeBody(response.body, "Error") : (() => { throw new Error('body is null') })()}); + | break; + | default: + | throw new Error('Cannot match response with status:'); + | } + | } + | export interface Handler extends Wirespec.Handler { + | putTodo(_request: Request): Promise>; + | } + | export interface Call extends Wirespec.Call { + | putTodo(params: RequestParams): Promise>; + | } + | export const client:Wirespec.Client = (serialization: Wirespec.Serialization) => ({ + | from: (it) => fromRawResponse(serialization, it), + | to: (it) => toRawRequest(serialization, it) + | }) + | export const server:Wirespec.Server = (serialization: Wirespec.Serialization) => ({ + | from: (it) => fromRawRequest(serialization, it), + | to: (it) => toRawResponse(serialization, it) + | }) + | export const api = { + | name: "putTodo", + | method: "PUT", + | path: "todos/{id}", + | server, + | client + | } as const + |} + | + |import {Wirespec} from '../Wirespec' + |export type PotentialTodoDto = { + | "name": string, + | "done": boolean, + |} + |export function validatePotentialTodoDto(obj: PotentialTodoDto): string[] { + | return [] as string[]; + |} + | + |import {Wirespec} from '../Wirespec' + |export type Token = { + | "iss": string, + |} + |export function validateToken(obj: Token): string[] { + | return [] as string[]; + |} + | + |import {Wirespec} from '../Wirespec' + |export type TodoDto = { + | "id": string, + | "name": string, + | "done": boolean, + |} + |export function validateTodoDto(obj: TodoDto): string[] { + | return [] as string[]; + |} + | + |import {Wirespec} from '../Wirespec' + |export type Error = { + | "code": number, + | "description": string, + |} + |export function validateError(obj: Error): string[] { + | return [] as string[]; + |} + | + |import {Wirespec} from '../Wirespec' + |import {PutTodo} from '../endpoint/PutTodo' + |import {type Token} from '../model' + |import {type PotentialTodoDto} from '../model' + |import {type TodoDto} from '../model' + |import {type Error} from '../model' + |export const putTodoClient = (serialization: Wirespec.Serialization, transportation: Wirespec.Transportation) => ({ + | putTodo: async (params: PutTodo.RequestParams): Promise> => { + | const request: PutTodo.Request = PutTodo.request(params); + | const rawRequest = PutTodo.toRawRequest(serialization, request); + | const rawResponse = await transportation.transport(rawRequest); + | return PutTodo.fromRawResponse(serialization, rawResponse); + | } + |}) + | + |import {Wirespec} from './Wirespec' + |import {putTodoClient} from './client/PutTodoClient' + |export const client = (serialization: Wirespec.Serialization, transportation: Wirespec.Transportation) => ({ + | ...putTodoClient(serialization, transportation), + |}) + | + |export {PutTodo} from './PutTodo' + |export {PotentialTodoDto} from './PotentialTodoDto' + |export {Token} from './Token' + |export {TodoDto} from './TodoDto' + |export {Error} from './Error' + """.trimMargin() + + CompileFullEndpointTest.compiler { TypeScriptIrEmitter() } shouldBeRight typescript + } + + @Test + fun compileChannelTest() { + val typescript = """ + |import {Wirespec} from '../Wirespec' + |export interface Queue extends Wirespec.Channel { + | invoke(message: string): void; + |} + | + |export {Queue} from './Queue' + """.trimMargin() + + CompileChannelTest.compiler { TypeScriptIrEmitter() } shouldBeRight typescript + } + + @Test + fun compileEnumTest() { + val typescript = """ + |import {Wirespec} from '../Wirespec' + |export type MyAwesomeEnum = "ONE" | "Two" | "THREE_MORE" | "UnitedKingdom" | "-1" | "0" | "10" | "-999" | "88" + | + |export {MyAwesomeEnum} from './MyAwesomeEnum' + """.trimMargin() + + CompileEnumTest.compiler { TypeScriptIrEmitter() } shouldBeRight typescript + } + + @Test + fun compileMinimalEndpointTest() { + val typescript = """ + |import {Wirespec} from '../Wirespec' + |import {type TodoDto} from '../model' + |export namespace GetTodos { + | export type Path = {} + | export type Queries = {} + | export type RequestHeaders = {} + | export type Request = { + | "path": Path, + | "method": Wirespec.Method, + | "queries": Queries, + | "headers": RequestHeaders, + | "body": void, + | } + | export type RequestParams = {} + | export const request = (): Request => ({ + | path: {}, + | method: "GET", + | queries: {}, + | headers: {}, + | body: undefined, + | }) + | export type Response = Response2XX | ResponseListTodoDto + | export type Response2XX = Response200 + | export type ResponseListTodoDto = Response200 + | export type Response200 = { + | "status": 200, + | "headers": {}, + | "body": TodoDto[], + | } + | export type Response200Params = {"body": TodoDto[]} + | export const response200 = (params: Response200Params): Response200 => ({ + | status: 200, + | headers: {}, + | body: params.body, + | }) + | export function toRawRequest(serialization: Wirespec.Serializer, _request: Request): Wirespec.RawRequest { + | return { method: _request.method, path: ['todos'], queries: {}, headers: {}, body: undefined }; + | } + | export function fromRawRequest(serialization: Wirespec.Deserializer, _request: Wirespec.RawRequest): Request { + | return request(); + | } + | export function toRawResponse(serialization: Wirespec.Serializer, response: Response): Wirespec.RawResponse { + | switch (response.status) { + | case 200: { + | const r = response as Response200; + | return { statusCode: r.status, headers: {}, body: serialization.serializeBody(r.body, "TodoDto[]") }; + | } + | default: { + | throw new Error('Cannot match response with status:'); + | } + | } + | } + | export function fromRawResponse(serialization: Wirespec.Deserializer, response: Wirespec.RawResponse): Response { + | switch (response.statusCode) { + | case 200: + | return response200({"body": response.body != null ? serialization.deserializeBody(response.body, "TodoDto[]") : (() => { throw new Error('body is null') })()}); + | break; + | default: + | throw new Error('Cannot match response with status:'); + | } + | } + | export interface Handler extends Wirespec.Handler { + | getTodos(_request: Request): Promise>; + | } + | export interface Call extends Wirespec.Call { + | getTodos(): Promise>; + | } + | export const client:Wirespec.Client = (serialization: Wirespec.Serialization) => ({ + | from: (it) => fromRawResponse(serialization, it), + | to: (it) => toRawRequest(serialization, it) + | }) + | export const server:Wirespec.Server = (serialization: Wirespec.Serialization) => ({ + | from: (it) => fromRawRequest(serialization, it), + | to: (it) => toRawResponse(serialization, it) + | }) + | export const api = { + | name: "getTodos", + | method: "GET", + | path: "todos", + | server, + | client + | } as const + |} + | + |import {Wirespec} from '../Wirespec' + |export type TodoDto = { + | "description": string, + |} + |export function validateTodoDto(obj: TodoDto): string[] { + | return [] as string[]; + |} + | + |import {Wirespec} from '../Wirespec' + |import {GetTodos} from '../endpoint/GetTodos' + |import {type TodoDto} from '../model' + |export const getTodosClient = (serialization: Wirespec.Serialization, transportation: Wirespec.Transportation) => ({ + | getTodos: async (): Promise> => { + | const request: GetTodos.Request = GetTodos.request(); + | const rawRequest = GetTodos.toRawRequest(serialization, request); + | const rawResponse = await transportation.transport(rawRequest); + | return GetTodos.fromRawResponse(serialization, rawResponse); + | } + |}) + | + |import {Wirespec} from './Wirespec' + |import {getTodosClient} from './client/GetTodosClient' + |export const client = (serialization: Wirespec.Serialization, transportation: Wirespec.Transportation) => ({ + | ...getTodosClient(serialization, transportation), + |}) + | + |export {GetTodos} from './GetTodos' + |export {TodoDto} from './TodoDto' + """.trimMargin() + + CompileMinimalEndpointTest.compiler { TypeScriptIrEmitter() } shouldBeRight typescript + } + + @Test + fun compileRefinedTest() { + val typescript = """ + |import {Wirespec} from '../Wirespec' + |export type TodoId = string; + |export const validateTodoId = (value: string) => + | /^[0-9a-fA-F]{8}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{4}\b-[0-9a-fA-F]{12}${'$'}/g.test(value); + | + |import {Wirespec} from '../Wirespec' + |export type TodoNoRegex = string; + |export const validateTodoNoRegex = (value: string) => + | true; + | + |import {Wirespec} from '../Wirespec' + |export type TestInt = number; + |export const validateTestInt = (value: number) => + | true; + | + |import {Wirespec} from '../Wirespec' + |export type TestInt0 = number; + |export const validateTestInt0 = (value: number) => + | true; + | + |import {Wirespec} from '../Wirespec' + |export type TestInt1 = number; + |export const validateTestInt1 = (value: number) => + | 0 <= value; + | + |import {Wirespec} from '../Wirespec' + |export type TestInt2 = number; + |export const validateTestInt2 = (value: number) => + | 1 <= value && value <= 3; + | + |import {Wirespec} from '../Wirespec' + |export type TestNum = number; + |export const validateTestNum = (value: number) => + | true; + | + |import {Wirespec} from '../Wirespec' + |export type TestNum0 = number; + |export const validateTestNum0 = (value: number) => + | true; + | + |import {Wirespec} from '../Wirespec' + |export type TestNum1 = number; + |export const validateTestNum1 = (value: number) => + | value <= 0.5; + | + |import {Wirespec} from '../Wirespec' + |export type TestNum2 = number; + |export const validateTestNum2 = (value: number) => + | -0.2 <= value && value <= 0.5; + | + |export {TodoId} from './TodoId' + |export {TodoNoRegex} from './TodoNoRegex' + |export {TestInt} from './TestInt' + |export {TestInt0} from './TestInt0' + |export {TestInt1} from './TestInt1' + |export {TestInt2} from './TestInt2' + |export {TestNum} from './TestNum' + |export {TestNum0} from './TestNum0' + |export {TestNum1} from './TestNum1' + |export {TestNum2} from './TestNum2' + """.trimMargin() + + CompileRefinedTest.compiler { TypeScriptIrEmitter() } shouldBeRight typescript + } + + @Test + fun compileUnionTest() { + val typescript = """ + |import {Wirespec} from '../Wirespec' + |import {type UserAccountPassword} from '../model' + |import {type UserAccountToken} from '../model' + |export type UserAccount = UserAccountPassword | UserAccountToken + | + |import {Wirespec} from '../Wirespec' + |export type UserAccountPassword = { + | "username": string, + | "password": string, + |} + |export function validateUserAccountPassword(obj: UserAccountPassword): string[] { + | return [] as string[]; + |} + | + |import {Wirespec} from '../Wirespec' + |export type UserAccountToken = { + | "token": string, + |} + |export function validateUserAccountToken(obj: UserAccountToken): string[] { + | return [] as string[]; + |} + | + |import {Wirespec} from '../Wirespec' + |import {UserAccount} from './UserAccount' + |export type User = { + | "username": string, + | "account": UserAccount, + |} + |export function validateUser(obj: User): string[] { + | return [] as string[]; + |} + | + |export {UserAccount} from './UserAccount' + |export {UserAccountPassword} from './UserAccountPassword' + |export {UserAccountToken} from './UserAccountToken' + |export {User} from './User' + """.trimMargin() + + CompileUnionTest.compiler { TypeScriptIrEmitter() } shouldBeRight typescript + } + + @Test + fun compileTypeTest() { + val typescript = """ + |import {Wirespec} from '../Wirespec' + |export type Request = { + | "type": string, + | "url": string, + | "BODY_TYPE": string | undefined, + | "params": string[], + | "headers": Record, + | "body": Record | undefined, + |} + |export function validateRequest(obj: Request): string[] { + | return [] as string[]; + |} + | + |export {Request} from './Request' + """.trimMargin() + + CompileTypeTest.compiler { TypeScriptIrEmitter() } shouldBeRight typescript + } + + @Test + fun compileNestedTypeTest() { + val typescript = """ + |import {Wirespec} from '../Wirespec' + |export type DutchPostalCode = string; + |export const validateDutchPostalCode = (value: string) => + | /^([0-9]{4}[A-Z]{2})${'$'}/g.test(value); + | + |import {Wirespec} from '../Wirespec' + |import {DutchPostalCode} from './DutchPostalCode' + |import {validateDutchPostalCode} from './DutchPostalCode' + |export type Address = { + | "street": string, + | "houseNumber": number, + | "postalCode": DutchPostalCode, + |} + |export function validateAddress(obj: Address): string[] { + | return (!validateDutchPostalCode(obj.postalCode) ? ['postalCode'] : [] as string[]); + |} + | + |import {Wirespec} from '../Wirespec' + |import {Address} from './Address' + |import {validateAddress} from './Address' + |export type Person = { + | "name": string, + | "address": Address, + | "tags": string[], + |} + |export function validatePerson(obj: Person): string[] { + | return validateAddress(obj.address).map(e => `address.${'$'}{e}`); + |} + | + |export {DutchPostalCode} from './DutchPostalCode' + |export {Address} from './Address' + |export {Person} from './Person' + """.trimMargin() + + CompileNestedTypeTest.compiler { TypeScriptIrEmitter() } shouldBeRight typescript + } + + @Test + fun compileComplexModelTest() { + val typescript = """ + |import {Wirespec} from '../Wirespec' + |export type Email = string; + |export const validateEmail = (value: string) => + | /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}${'$'}/g.test(value); + | + |import {Wirespec} from '../Wirespec' + |export type PhoneNumber = string; + |export const validatePhoneNumber = (value: string) => + | /^\+[1-9]\d{1,14}${'$'}/g.test(value); + | + |import {Wirespec} from '../Wirespec' + |export type Tag = string; + |export const validateTag = (value: string) => + | /^[a-z][a-z0-9-]{0,19}${'$'}/g.test(value); + | + |import {Wirespec} from '../Wirespec' + |export type EmployeeAge = number; + |export const validateEmployeeAge = (value: number) => + | 18 <= value && value <= 65; + | + |import {Wirespec} from '../Wirespec' + |import {Email} from './Email' + |import {PhoneNumber} from './PhoneNumber' + |import {validateEmail} from './Email' + |import {validatePhoneNumber} from './PhoneNumber' + |export type ContactInfo = { + | "email": Email, + | "phone": PhoneNumber | undefined, + |} + |export function validateContactInfo(obj: ContactInfo): string[] { + | return [...(!validateEmail(obj.email) ? ['email'] : [] as string[]), ...obj.phone != null ? (!validatePhoneNumber(obj.phone) ? ['phone'] : [] as string[]) : [] as string[]]; + |} + | + |import {Wirespec} from '../Wirespec' + |import {EmployeeAge} from './EmployeeAge' + |import {ContactInfo} from './ContactInfo' + |import {Tag} from './Tag' + |import {validateEmployeeAge} from './EmployeeAge' + |import {validateContactInfo} from './ContactInfo' + |import {validateTag} from './Tag' + |export type Employee = { + | "name": string, + | "age": EmployeeAge, + | "contactInfo": ContactInfo, + | "tags": Tag[], + |} + |export function validateEmployee(obj: Employee): string[] { + | return [...(!validateEmployeeAge(obj.age) ? ['age'] : [] as string[]), ...validateContactInfo(obj.contactInfo).map(e => `contactInfo.${'$'}{e}`), ...obj.tags.flatMap((el, i) => (!validateTag(el) ? [`tags[${'$'}{i}]`] : [] as string[]))]; + |} + | + |import {Wirespec} from '../Wirespec' + |import {Employee} from './Employee' + |import {validateEmployee} from './Employee' + |export type Department = { + | "name": string, + | "employees": Employee[], + |} + |export function validateDepartment(obj: Department): string[] { + | return obj.employees.flatMap((el, i) => validateEmployee(el).map(e => `employees[${'$'}{i}].${'$'}{e}`)); + |} + | + |import {Wirespec} from '../Wirespec' + |import {Department} from './Department' + |import {validateDepartment} from './Department' + |export type Company = { + | "name": string, + | "departments": Department[], + |} + |export function validateCompany(obj: Company): string[] { + | return obj.departments.flatMap((el, i) => validateDepartment(el).map(e => `departments[${'$'}{i}].${'$'}{e}`)); + |} + | + |export {Email} from './Email' + |export {PhoneNumber} from './PhoneNumber' + |export {Tag} from './Tag' + |export {EmployeeAge} from './EmployeeAge' + |export {ContactInfo} from './ContactInfo' + |export {Employee} from './Employee' + |export {Department} from './Department' + |export {Company} from './Company' + """.trimMargin() + + CompileComplexModelTest.compiler { TypeScriptIrEmitter() } shouldBeRight typescript + } + + @Test + fun sharedOutputTest() { + val expected = """ + |export namespace Wirespec { + | export type Type = string + | export interface Model { + | validate(): string[]; + | } + | export interface Enum { + | label: string; + | } + | export interface Endpoint {} + | export interface Channel {} + | export interface Refined { + | value: T; + | validate(): boolean; + | } + | export interface Path {} + | export interface Queries {} + | export interface Headers {} + | export interface Handler {} + | export interface Call {} + | export type Method = "GET" | "PUT" | "POST" | "DELETE" | "OPTIONS" | "HEAD" | "PATCH" | "TRACE" + | export interface Request { + | path: Path; + | method: Method; + | queries: Queries; + | headers: {}; + | body: T; + | } + | export interface Response { + | status: number; + | headers: {}; + | body: T; + | } + | export interface BodySerializer { + | serializeBody(t: T, type: Type): Uint8Array; + | } + | export interface BodyDeserializer { + | deserializeBody(raw: Uint8Array, type: Type): T; + | } + | export interface BodySerialization extends BodySerializer, BodyDeserializer {} + | export interface PathSerializer { + | serializePath(t: T, type: Type): string; + | } + | export interface PathDeserializer { + | deserializePath(raw: string, type: Type): T; + | } + | export interface PathSerialization extends PathSerializer, PathDeserializer {} + | export interface ParamSerializer { + | serializeParam(value: T, type: Type): string[]; + | } + | export interface ParamDeserializer { + | deserializeParam(values: string[], type: Type): T; + | } + | export interface ParamSerialization extends ParamSerializer, ParamDeserializer {} + | export interface Serializer extends BodySerializer, PathSerializer, ParamSerializer {} + | export interface Deserializer extends BodyDeserializer, PathDeserializer, ParamDeserializer {} + | export interface Serialization extends Serializer, Deserializer {} + | export type RawRequest = { + | "method": string, + | "path": string[], + | "queries": Record, + | "headers": Record, + | "body": Uint8Array | undefined, + | } + | export type RawResponse = { + | "statusCode": number, + | "headers": Record, + | "body": Uint8Array | undefined, + | } + | export interface Transportation { + | transport(request: RawRequest): Promise; + | } + | export type Client, RES extends Response> = (serialization: Serialization) => { + | to: (request: REQ) => RawRequest; + | from: (response: RawResponse) => RES + | } + | export type Server, RES extends Response> = (serialization: Serialization) => { + | from: (request: RawRequest) => REQ; + | to: (response: RES) => RawResponse + | } + | export type Api, RES extends Response> = { + | name: string; + | method: Method, + | path: string, + | client: Client; + | server: Server + | } + |} + | + """.trimMargin() + + val emitter = TypeScriptIrEmitter() + emitter.shared.source shouldBe expected + } +} diff --git a/src/compiler/ir/build.gradle.kts b/src/compiler/ir/build.gradle.kts new file mode 100644 index 000000000..4c6d78434 --- /dev/null +++ b/src/compiler/ir/build.gradle.kts @@ -0,0 +1,53 @@ +plugins { + id("module.publication") + id("module.spotless") + alias(libs.plugins.kotlin.multiplatform) + alias(libs.plugins.ksp) + alias(libs.plugins.kotest) +} + +group = "${libs.versions.group.id.get()}.ir" +version = System.getenv(libs.versions.from.env.get()) ?: libs.versions.default.get() + +kotlin { + macosX64() + macosArm64() + linuxX64() + mingwX64() + js(IR) { + nodejs() + } + jvm { + java { + toolchain { + languageVersion.set(JavaLanguageVersion.of(libs.versions.java.get())) + } + } + } + + sourceSets.all { + languageSettings.apply { + languageVersion = libs.versions.kotlin.compiler.get() + } + } + + sourceSets { + commonMain { + dependencies { + implementation(libs.kotlinx.io.core) + implementation(project(":src:compiler:core")) + } + } + commonTest { + dependencies { + implementation(libs.kotlin.test) + implementation(libs.bundles.kotest) + } + } + jvmTest { + dependencies { + implementation(libs.kotlin.test) + } + } + } +} diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/converter/IrConverter.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/converter/IrConverter.kt new file mode 100644 index 000000000..6d83e1ded --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/converter/IrConverter.kt @@ -0,0 +1,1122 @@ +package community.flock.wirespec.ir.converter + +import community.flock.wirespec.compiler.core.parse.ast.DefinitionIdentifier +import community.flock.wirespec.compiler.core.parse.ast.FieldIdentifier +import community.flock.wirespec.compiler.core.parse.ast.Identifier +import community.flock.wirespec.compiler.core.parse.ast.Module +import community.flock.wirespec.ir.core.ArrayIndexCall +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.EnumReference +import community.flock.wirespec.ir.core.EnumValueCall +import community.flock.wirespec.ir.core.ErrorStatement +import community.flock.wirespec.ir.core.Expression +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.FlatMapIndexed +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.IfExpression +import community.flock.wirespec.ir.core.ListConcat +import community.flock.wirespec.ir.core.Literal +import community.flock.wirespec.ir.core.LiteralList +import community.flock.wirespec.ir.core.LiteralMap +import community.flock.wirespec.ir.core.MapExpression +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.NotExpression +import community.flock.wirespec.ir.core.NullCheck +import community.flock.wirespec.ir.core.NullableEmpty +import community.flock.wirespec.ir.core.NullableMap +import community.flock.wirespec.ir.core.NullableOf +import community.flock.wirespec.ir.core.Precision +import community.flock.wirespec.ir.core.ReturnStatement +import community.flock.wirespec.ir.core.StringTemplate +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.TypeDescriptor +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.file +import community.flock.wirespec.ir.core.transformMatchingElements +import community.flock.wirespec.compiler.core.parse.ast.Channel as ChannelWirespec +import community.flock.wirespec.compiler.core.parse.ast.Definition as DefinitionWirespec +import community.flock.wirespec.compiler.core.parse.ast.Endpoint as EndpointWirespec +import community.flock.wirespec.compiler.core.parse.ast.Enum as EnumWirespec +import community.flock.wirespec.compiler.core.parse.ast.Field as FieldWirespec +import community.flock.wirespec.compiler.core.parse.ast.Reference as ReferenceWirespec +import community.flock.wirespec.compiler.core.parse.ast.Refined as RefinedWirespec +import community.flock.wirespec.compiler.core.parse.ast.Shared as SharedWirespec +import community.flock.wirespec.compiler.core.parse.ast.Type as TypeWirespec +import community.flock.wirespec.compiler.core.parse.ast.Union as UnionWirespec +import community.flock.wirespec.ir.core.Constraint as LanguageConstraint + +fun DefinitionWirespec.convert(): File = when (this) { + is TypeWirespec -> convert() + is EnumWirespec -> convert() + is UnionWirespec -> convert() + is RefinedWirespec -> convert() + is ChannelWirespec -> convert() + is EndpointWirespec -> convert() +} + +fun SharedWirespec.convert(): File = file("Wirespec") { + `package`(packageString) + + namespace("Wirespec") { + `interface`("Model") { + function("validate") { + returnType(list(string)) + } + } + `interface`("Enum") { + field("label", string) + } + `interface`("Endpoint") + `interface`("Channel") + `interface`("Refined") { + typeParam(type("T")) + field("value", type("T")) + function("validate") { + returnType(boolean) + } + } + `interface`("Path") + `interface`("Queries") + `interface`("Headers") + `interface`("Handler") + `interface`("Call") + + enum("Method") { + entry("GET") + entry("PUT") + entry("POST") + entry("DELETE") + entry("OPTIONS") + entry("HEAD") + entry("PATCH") + entry("TRACE") + } + `interface`("Request") { + typeParam(type("T")) + field("path", type("Path")) + field("method", type("Method")) + field("queries", type("Queries")) + field("headers", type("Headers")) + field("body", type("T")) + `interface`("Headers") + } + `interface`("Response") { + typeParam(type("T")) + field("status", integer) + field("headers", type("Headers")) + field("body", type("T")) + `interface`("Headers") + } + `interface`("BodySerializer") { + function(Name("serialize", "Body")) { + returnType(bytes) + typeParam(type("T")) + arg("t", type("T")) + arg("type", reflect) + } + } + `interface`("BodyDeserializer") { + function(Name("deserialize", "Body")) { + returnType(type("T")) + typeParam(type("T")) + arg("raw", bytes) + arg("type", reflect) + } + } + `interface`("BodySerialization") { + extends(type("BodySerializer")) + extends(type("BodyDeserializer")) + } + `interface`("PathSerializer") { + function(Name("serialize", "Path")) { + returnType(string) + typeParam(type("T")) + arg("t", type("T")) + arg("type", reflect) + } + } + `interface`("PathDeserializer") { + function(Name("deserialize", "Path")) { + returnType(type("T")) + typeParam(type("T")) + arg("raw", string) + arg("type", reflect) + } + } + `interface`("PathSerialization") { + extends(type("PathSerializer")) + extends(type("PathDeserializer")) + } + `interface`("ParamSerializer") { + function(Name("serialize", "Param")) { + returnType(list(string)) + typeParam(type("T")) + arg("value", type("T")) + arg("type", reflect) + } + } + `interface`("ParamDeserializer") { + function(Name("deserialize", "Param")) { + returnType(type("T")) + typeParam(type("T")) + arg("values", list(string)) + arg("type", reflect) + } + } + `interface`("ParamSerialization") { + extends(type("ParamSerializer")) + extends(type("ParamDeserializer")) + } + `interface`("Serializer") { + extends(type("BodySerializer")) + extends(type("PathSerializer")) + extends(type("ParamSerializer")) + } + `interface`("Deserializer") { + extends(type("BodyDeserializer")) + extends(type("PathDeserializer")) + extends(type("ParamDeserializer")) + } + `interface`("Serialization") { + extends(type("Serializer")) + extends(type("Deserializer")) + } + struct("RawRequest") { + field("method", string) + field("path", list(string)) + field("queries", dict(string, list(string))) + field("headers", dict(string, list(string))) + field("body", bytes.nullable()) + } + struct("RawResponse") { + field(Name("status", "Code"), integer) + field("headers", dict(string, list(string))) + field("body", bytes.nullable()) + } + `interface`("Transportation") { + asyncFunction("transport") { + returnType(type("RawResponse")) + arg("request", type("RawRequest")) + } + } + } +} + +private fun Identifier.toName(): Name = when (this) { + is FieldIdentifier -> { + // Split on invalid identifier characters (dashes, dots, spaces) to produce word parts. + // The emitter's transform phase is responsible for applying language-specific casing. + val parts = value.split(Regex("[.\\s-]+")).filter { it.isNotEmpty() } + Name(parts) + } + is DefinitionIdentifier -> Name( + Name.of(value).parts.filter { part -> part.any { it.isLetterOrDigit() } }, + ) +} + +fun TypeWirespec.convert() = file(identifier.toName()) { + struct(identifier.toName()) { + implements(Type.Custom("Wirespec.Model")) + extends.map { it.convert() }.filterIsInstance().forEach { implements(it) } + shape.value.forEach { + field(it.identifier.toName(), it.reference.convert()) + } + function("validate", isOverride = true) { + returnType(Type.Array(Type.String)) + returns(LiteralList(emptyList(), Type.String)) + } + } +} + +data class FieldValidation( + val fieldName: Name, + val fieldPath: String, + val kind: Kind, + val isNullable: Boolean, + val typeName: String, + val elementIsNullable: Boolean = false, +) + +enum class Kind { MODEL, REFINED, MODEL_ARRAY, REFINED_ARRAY } + +fun TypeWirespec.convertWithValidation(module: Module): File { + val fieldValidations = classifyValidatableFields(module) + val file = convert() + return if (fieldValidations.isNotEmpty()) { + file.transformMatchingElements { fn: community.flock.wirespec.ir.core.Function -> + if (fn.name == Name.of("validate")) { + fn.copy(body = listOf(ReturnStatement(buildValidateBody(fieldValidations)))) + } else { + fn + } + } + } else { + file + } +} + +private fun buildValidateBody(validations: List): Expression { + if (validations.isEmpty()) return LiteralList(emptyList(), Type.String) + val exprs = validations.map { it.toExpression() } + return if (exprs.size == 1) exprs.single() else ListConcat(exprs) +} + +private fun FieldValidation.toExpression(): Expression { + val fieldRef: Expression = FieldCall(field = fieldName) + // When nullable, NullableMap uses "it" as the lambda variable for the unwrapped value + val valueRef: Expression = if (isNullable) VariableReference(Name.of("it")) else fieldRef + // typeArguments carries the validated type name (used by TypeScript emitter to derive standalone function name) + val validateCall = FunctionCall( + receiver = valueRef, + name = Name.of("validate"), + typeArguments = listOf(Type.Custom(typeName)), + ) + + fun stringTemplate(vararg parts: StringTemplate.Part) = StringTemplate(parts.toList()) + fun text(value: String) = StringTemplate.Part.Text(value) + fun expr(expression: Expression) = StringTemplate.Part.Expr(expression) + + val body: Expression = when (kind) { + Kind.MODEL -> MapExpression( + receiver = validateCall, + variable = Name.of("e"), + body = stringTemplate(text("$fieldPath."), expr(VariableReference(Name.of("e")))), + ) + Kind.REFINED -> IfExpression( + condition = NotExpression(validateCall), + thenExpr = LiteralList(listOf(Literal(fieldPath, Type.String)), Type.String), + elseExpr = LiteralList(emptyList(), Type.String), + ) + Kind.MODEL_ARRAY -> FlatMapIndexed( + receiver = valueRef, + indexVar = Name.of("i"), + elementVar = Name.of("el"), + body = MapExpression( + receiver = FunctionCall( + receiver = VariableReference(Name.of("el")), + name = Name.of("validate"), + typeArguments = listOf(Type.Custom(typeName)), + ), + variable = Name.of("e"), + body = stringTemplate(text("$fieldPath["), expr(VariableReference(Name.of("i"))), text("]."), expr(VariableReference(Name.of("e")))), + ), + ) + Kind.REFINED_ARRAY -> FlatMapIndexed( + receiver = valueRef, + indexVar = Name.of("i"), + elementVar = Name.of("el"), + body = IfExpression( + condition = NotExpression( + FunctionCall( + receiver = VariableReference(Name.of("el")), + name = Name.of("validate"), + typeArguments = listOf(Type.Custom(typeName)), + ), + ), + thenExpr = LiteralList( + listOf(stringTemplate(text("$fieldPath["), expr(VariableReference(Name.of("i"))), text("]"))), + Type.String, + ), + elseExpr = LiteralList(emptyList(), Type.String), + ), + ) + } + + return if (isNullable) { + NullableMap( + expression = fieldRef, + body = body, + alternative = LiteralList(emptyList(), Type.String), + ) + } else { + body + } +} + +fun TypeWirespec.classifyValidatableFields(module: Module): List = buildList { + for (field in shape.value) { + val fieldName = field.identifier.toName() + val fieldPath = field.identifier.value + val ref = field.reference + val isNullable = ref.isNullable + when (ref) { + is ReferenceWirespec.Custom -> { + val typeName = ref.value + val def = module.statements.firstOrNull { + it.identifier.value == typeName + } + when (def) { + is TypeWirespec -> add( + FieldValidation( + fieldName = fieldName, + fieldPath = fieldPath, + kind = Kind.MODEL, + isNullable = isNullable, + typeName = typeName, + ), + ) + is RefinedWirespec -> add( + FieldValidation( + fieldName = fieldName, + fieldPath = fieldPath, + kind = Kind.REFINED, + isNullable = isNullable, + typeName = typeName, + ), + ) + else -> {} // enum, union, etc. - skip + } + } + is ReferenceWirespec.Iterable -> { + val inner = ref.reference + if (inner is ReferenceWirespec.Custom) { + val typeName = inner.value + val def = module.statements.firstOrNull { + it.identifier.value == typeName + } + when (def) { + is TypeWirespec -> add( + FieldValidation( + fieldName = fieldName, + fieldPath = fieldPath, + kind = Kind.MODEL_ARRAY, + isNullable = isNullable, + typeName = typeName, + elementIsNullable = inner.isNullable, + ), + ) + is RefinedWirespec -> add( + FieldValidation( + fieldName = fieldName, + fieldPath = fieldPath, + kind = Kind.REFINED_ARRAY, + isNullable = isNullable, + typeName = typeName, + elementIsNullable = inner.isNullable, + ), + ) + else -> {} // skip + } + } + } + else -> {} // Primitive, Dict, Unit, Any - skip + } + } +} + +fun EnumWirespec.convert() = file(identifier.toName()) { + enum(identifier.toName(), Type.Custom("Wirespec.Enum")) { + entries.forEach { entry(it) } + } +} + +fun UnionWirespec.convert() = file(identifier.toName()) { + union(identifier.toName()) { + entries.map { it.convert() }.filterIsInstance().forEach { member(it.name) } + } +} + +fun RefinedWirespec.convert() = file(identifier.toName()) { + struct(identifier.toName()) { + implements(type("Wirespec.Refined", reference.convert())) + field("value", reference.convert()) + function("validate") { + returnType(Type.Boolean) + returns(reference.convertConstraint(VariableReference(Name.of("value")))) + } + } +} + +fun ChannelWirespec.convert() = file(identifier.toName()) { + `interface`(identifier.toName()) { + extends(type("Wirespec.Channel")) + function("invoke") { + arg("message", reference.convert()) + returnType(unit) + } + } +} + +fun EndpointWirespec.convert(): File { + val endpoint = this + val pathParams = path.filterIsInstance() + val requestContent = requests.first().content + val requestBodyType = requestContent?.reference?.convert() ?: Type.Unit + + return file(identifier.toName()) { + namespace(identifier.toName(), type("Wirespec.Endpoint")) { + // Path record + struct("Path") { + implements(type("Wirespec.Path")) + pathParams.forEach { field(it.identifier.toName(), it.reference.convert()) } + } + + // Queries record + struct("Queries") { + implements(type("Wirespec.Queries")) + endpoint.queries.forEach { field(it.identifier.toName(), it.reference.convert()) } + } + + // RequestHeaders record + struct("RequestHeaders") { + implements(type("Wirespec.Request.Headers")) + endpoint.headers.forEach { field(it.identifier.toName(), it.reference.convert()) } + } + + // Request record + struct("Request") { + implements(type("Wirespec.Request", requestBodyType)) + field("path", type("Path"), isOverride = true) + field("method", type("Wirespec.Method"), isOverride = true) + field("queries", type("Queries"), isOverride = true) + field("headers", type("RequestHeaders"), isOverride = true) + field("body", requestBodyType, isOverride = true) + constructo { + endpoint.requestParameters().forEach { (name, type) -> arg(name, type) } + assign( + "path", + construct(type("Path")) { + pathParams.forEach { + arg( + it.identifier.toName(), + VariableReference(it.identifier.toName()), + ) + } + }, + ) + assign("method", EnumReference(Type.Custom("Wirespec.Method"), Name.of(endpoint.method.name))) + assign( + "queries", + construct(type("Queries")) { + endpoint.queries.forEach { + arg( + it.identifier.toName(), + VariableReference(it.identifier.toName()), + ) + } + }, + ) + assign( + "headers", + construct(type("RequestHeaders")) { + endpoint.headers.forEach { + arg( + it.identifier.toName(), + VariableReference(it.identifier.toName()), + ) + } + }, + ) + assign("body", if (requestContent != null) VariableReference(Name.of("body")) else construct(Type.Unit)) + } + } + + // Pre-compute response names grouped by status prefix and content type + val distinctResponses = endpoint.responses.distinctBy { it.status } + val statusPrefixGroups = distinctResponses.groupBy { it.status.first() } + val contentTypeGroups = distinctResponses.groupBy { it.content?.reference } + + val statusPrefixUnionNames = statusPrefixGroups.keys.map { "Response${it}XX" } + val contentTypeUnionNames = contentTypeGroups.map { (ref, _) -> + val contentType = ref?.convert() ?: Type.Unit + "Response${contentType.toTypeName()}" + } + + // Response union — members are the intermediate unions + union("Response", extends = type("Wirespec.Response", type("T"))) { + typeParam(type("T")) + (statusPrefixUnionNames + contentTypeUnionNames).distinct().forEach { member(it) } + } + + // Status prefix unions (Response2XX, Response5XX, etc.) + statusPrefixGroups.forEach { (prefix, responses) -> + union("Response${prefix}XX", extends = type("Response", type("T"))) { + typeParam(type("T")) + responses.forEach { member("Response${it.status.replaceFirstChar { c -> c.uppercaseChar() }}") } + } + } + + // Content type unions (ResponseUnit, ResponseTodoDto, etc.) + contentTypeGroups.forEach { (ref, responses) -> + val contentType = ref?.convert() ?: Type.Unit + val typeName = contentType.toTypeName() + union("Response$typeName", extends = type("Response", contentType)) { + responses.forEach { member("Response${it.status.replaceFirstChar { c -> c.uppercaseChar() }}") } + } + } + + // Individual response records (Response200, Response201, etc.) + endpoint.responses.distinctBy { it.status }.forEach { response -> + val bodyType = response.content?.reference?.convert() ?: Type.Unit + val statusCode = response.status.toIntOrNull() ?: 0 + val statusClassName = response.status.replaceFirstChar { it.uppercaseChar() } + val statusPrefix = response.status.first() + val contentTypeName = bodyType.toTypeName() + struct("Response$statusClassName") { + implements(type("Response${statusPrefix}XX", bodyType)) + implements(type("Response$contentTypeName")) + field("status", Type.IntegerLiteral(statusCode), isOverride = true) + field("headers", type("Headers"), isOverride = true) + field("body", bodyType, isOverride = true) + struct("Headers") { + implements(type("Wirespec.Response.Headers")) + response.headers.forEach { field(it.identifier.toName(), it.reference.convert()) } + } + constructo { + response.responseParameters().forEach { (name, type) -> arg(name, type) } + assign("status", Literal(statusCode, Type.Integer(Precision.P32))) + assign( + "headers", + construct(type("Headers")) { + response.headers.forEach { + arg( + it.identifier.toName(), + VariableReference(it.identifier.toName()), + ) + } + }, + ) + assign("body", if (response.content != null) VariableReference(Name.of("body")) else construct(Type.Unit)) + } + } + } + + // Conversion functions at Endpoint interface level + function(Name("to", "Raw", "Request"), isStatic = true) { + returnType(type("Wirespec.RawRequest")) + arg("serialization", type("Wirespec.Serializer")) + arg("request", type("Request")) + returns( + construct(type("Wirespec.RawRequest")) { + arg("method", EnumValueCall(FieldCall(VariableReference(Name.of("request")), Name.of("method")))) + arg( + "path", + LiteralList( + values = endpoint.path.map { + when (it) { + is EndpointWirespec.Segment.Literal -> Literal(it.value, Type.String) + is EndpointWirespec.Segment.Param -> FunctionCall( + receiver = VariableReference(Name.of("serialization")), + name = Name("serialize", "Path"), + typeArguments = listOf(it.reference.convert()), + arguments = mapOf( + Name.of("value") to FieldCall( + FieldCall(VariableReference(Name.of("request")), Name.of("path")), + it.identifier.toName(), + ), + Name.of("type") to it.reference.toTypeDescriptor(), + ), + ) + } + }, + type = Type.String, + ), + ) + arg( + "queries", + LiteralMap( + values = endpoint.queries.associate { + it.identifier.value to serializeParamExpression( + fieldAccess = FieldCall( + FieldCall(VariableReference(Name.of("request")), Name.of("queries")), + it.identifier.toName(), + ), + field = it, + ) + }, + keyType = Type.String, + valueType = Type.Custom("List"), + ), + ) + arg( + "headers", + LiteralMap( + values = endpoint.headers.associate { + it.identifier.value to serializeParamExpression( + fieldAccess = FieldCall( + FieldCall(VariableReference(Name.of("request")), Name.of("headers")), + it.identifier.toName(), + ), + field = it, + ) + }, + keyType = Type.String, + valueType = Type.Custom("List"), + ), + ) + arg( + "body", + endpoint.requests.first().content?.let { + NullableOf( + FunctionCall( + receiver = VariableReference(Name.of("serialization")), + name = Name("serialize", "Body"), + typeArguments = listOf(it.reference.convert()), + arguments = mapOf( + Name.of("value") to FieldCall(VariableReference(Name.of("request")), Name.of("body")), + Name.of("type") to it.reference.toTypeDescriptor(), + ), + ), + ) + } ?: NullableEmpty, + ) + }, + ) + } + + function(Name("from", "Raw", "Request"), isStatic = true) { + returnType(type("Request")) + arg("serialization", type("Wirespec.Deserializer")) + arg("request", type("Wirespec.RawRequest")) + returns( + construct(type("Request")) { + endpoint.path.forEachIndexed { index, segment -> + if (segment is EndpointWirespec.Segment.Param) { + arg( + segment.identifier.toName(), + FunctionCall( + receiver = VariableReference(Name.of("serialization")), + name = Name("deserialize", "Path"), + typeArguments = listOf(segment.reference.convert()), + arguments = mapOf( + Name.of("value") to ArrayIndexCall( + receiver = FieldCall(VariableReference(Name.of("request")), Name.of("path")), + index = Literal(index, Type.Integer(Precision.P32)), + ), + Name.of("type") to segment.reference.toTypeDescriptor(), + ), + ), + ) + } + } + endpoint.queries.forEach { field -> + arg( + field.identifier.toName(), + deserializeParamExpression( + map = FieldCall(VariableReference(Name.of("request")), Name.of("queries")), + fieldName = field.identifier.value, + field = field, + ), + ) + } + endpoint.headers.forEach { field -> + arg( + field.identifier.toName(), + deserializeParamExpression( + map = FieldCall(VariableReference(Name.of("request")), Name.of("headers")), + fieldName = field.identifier.value, + field = field, + caseSensitive = false, + ), + ) + } + endpoint.requests.first().content?.let { + arg( + "body", + NullableMap( + expression = FieldCall(VariableReference(Name.of("request")), Name.of("body")), + body = FunctionCall( + receiver = VariableReference(Name.of("serialization")), + name = Name("deserialize", "Body"), + typeArguments = listOf(it.reference.convert()), + arguments = mapOf( + Name.of("value") to VariableReference(Name.of("it")), + Name.of("type") to it.reference.toTypeDescriptor(), + ), + ), + alternative = ErrorStatement(Literal("body is null", Type.String)), + ), + ) + } + }, + ) + } + + function(Name("to", "Raw", "Response"), isStatic = true) { + returnType(type("Wirespec.RawResponse")) + arg("serialization", type("Wirespec.Serializer")) + arg("response", type("Response", wildcard)) + switch(VariableReference(Name.of("response")), "r") { + endpoint.responses.distinctBy { it.status }.forEach { response -> + val statusClassName = response.status.replaceFirstChar { it.uppercaseChar() } + case(type("Response$statusClassName")) { + returns( + construct(type("Wirespec.RawResponse")) { + arg(Name("status", "Code"), FieldCall(VariableReference(Name.of("r")), Name.of("status"))) + arg( + "headers", + LiteralMap( + values = response.headers.associate { header -> + header.identifier.value to serializeParamExpression( + fieldAccess = FieldCall( + FieldCall(VariableReference(Name.of("r")), Name.of("headers")), + header.identifier.toName(), + ), + field = header, + ) + }, + keyType = Type.String, + valueType = Type.Custom("List"), + ), + ) + arg( + "body", + response.content?.let { content -> + NullableOf( + FunctionCall( + receiver = VariableReference(Name.of("serialization")), + name = Name("serialize", "Body"), + arguments = mapOf( + Name.of("value") to FieldCall(VariableReference(Name.of("r")), Name.of("body")), + Name.of("type") to content.reference.toTypeDescriptor(), + ), + ), + ) + } ?: NullableEmpty, + ) + }, + ) + } + } + default { + error( + BinaryOp( + Literal("Cannot match response with status: ", Type.String), + BinaryOp.Operator.PLUS, + FieldCall(VariableReference(Name.of("response")), Name.of("status")), + ), + ) + } + } + } + + function(Name("from", "Raw", "Response"), isStatic = true) { + returnType(type("Response", wildcard)) + arg("serialization", type("Wirespec.Deserializer")) + arg("response", type("Wirespec.RawResponse")) + switch(FieldCall(receiver = VariableReference(Name.of("response")), field = Name("status", "Code"))) { + endpoint.responses.distinctBy { it.status }.filter { it.status.toIntOrNull() != null } + .forEach { response -> + val statusClassName = response.status.replaceFirstChar { it.uppercaseChar() } + case(literal(response.status.toInt())) { + returns( + construct(type("Response$statusClassName")) { + response.headers.forEach { header -> + arg( + header.identifier.toName(), + deserializeParamExpression( + map = FieldCall(VariableReference(Name.of("response")), Name.of("headers")), + fieldName = header.identifier.value, + field = header, + caseSensitive = false, + ), + ) + } + response.content?.let { content -> + arg( + "body", + NullableMap( + expression = FieldCall(VariableReference(Name.of("response")), Name.of("body")), + body = FunctionCall( + receiver = VariableReference(Name.of("serialization")), + name = Name("deserialize", "Body"), + typeArguments = listOf(content.reference.convert()), + arguments = mapOf( + Name.of("value") to VariableReference(Name.of("it")), + Name.of("type") to content.reference.toTypeDescriptor(), + ), + ), + alternative = ErrorStatement(Literal("body is null", Type.String)), + ), + ) + } + }, + ) + } + } + default { + error( + BinaryOp( + Literal("Cannot match response with status: ", Type.String), + BinaryOp.Operator.PLUS, + FieldCall(VariableReference(Name.of("response")), Name("status", "Code")), + ), + ) + } + } + } + + // Handler interface + `interface`("Handler") { + extends(type("Wirespec.Handler")) + asyncFunction(endpoint.identifier.toName()) { + arg("request", type("Request")) + returnType(type("Response", wildcard)) + } + } + + // Call interface + `interface`("Call") { + extends(type("Wirespec.Call")) + asyncFunction(endpoint.identifier.toName()) { + endpoint.requestParameters().forEach { (name, type) -> arg(name, type) } + returnType(type("Response", wildcard)) + } + } + } + } +} + +private fun Type.toTypeName(): String = when (this) { + Type.Any -> "Any" + is Type.Unit -> "Unit" + is Type.Wildcard -> "Wildcard" + is Type.Reflect -> "Type" + is Type.Custom -> name + is Type.Array -> "List${elementType.toTypeName()}" + is Type.Nullable -> "Optional${type.toTypeName()}" + is Type.String -> "String" + is Type.Integer -> "Integer" + is Type.Number -> "Number" + is Type.Boolean -> "Boolean" + is Type.Bytes -> "Bytes" + is Type.Dict -> "Map" + is Type.IntegerLiteral -> "Integer" + is Type.StringLiteral -> "String" +} + +fun ReferenceWirespec.convert(): Type = when (this) { + is ReferenceWirespec.Any -> Type.Custom("Any") + is ReferenceWirespec.Custom -> Type.Custom(value) + is ReferenceWirespec.Dict -> Type.Dict(Type.String, reference.convert()) + is ReferenceWirespec.Iterable -> Type.Array(reference.convert()) + is ReferenceWirespec.Primitive -> when (val t = type) { + ReferenceWirespec.Primitive.Type.Boolean -> Type.Boolean + ReferenceWirespec.Primitive.Type.Bytes -> Type.Bytes + is ReferenceWirespec.Primitive.Type.Integer -> when (t.precision) { + ReferenceWirespec.Primitive.Type.Precision.P32 -> Type.Integer(Precision.P32) + ReferenceWirespec.Primitive.Type.Precision.P64 -> Type.Integer(Precision.P64) + } + + is ReferenceWirespec.Primitive.Type.Number -> when (t.precision) { + ReferenceWirespec.Primitive.Type.Precision.P32 -> Type.Number(Precision.P32) + ReferenceWirespec.Primitive.Type.Precision.P64 -> Type.Number(Precision.P64) + } + + is ReferenceWirespec.Primitive.Type.String -> Type.String + } + + is ReferenceWirespec.Unit -> Type.Unit +} + .let { if (isNullable) Type.Nullable(it) else it } + +fun ReferenceWirespec.Primitive.Type.Constraint.convert(value: Expression): LanguageConstraint = when (this) { + is ReferenceWirespec.Primitive.Type.Constraint.RegExp -> + LanguageConstraint.RegexMatch( + pattern = this.value.split("/").drop(1).dropLast(1).joinToString("/"), + rawValue = this.value, + value = value, + ) + + is ReferenceWirespec.Primitive.Type.Constraint.Bound -> + LanguageConstraint.BoundCheck(min = min, max = max, value = value) +} + +fun ReferenceWirespec.Primitive.convertConstraint(value: Expression): Expression = when (val t = type) { + is ReferenceWirespec.Primitive.Type.String -> t.constraint?.convert(value) + is ReferenceWirespec.Primitive.Type.Integer -> t.constraint?.convert(value) + is ReferenceWirespec.Primitive.Type.Number -> t.constraint?.convert(value) + ReferenceWirespec.Primitive.Type.Boolean -> null + ReferenceWirespec.Primitive.Type.Bytes -> null +} ?: Literal(true, Type.Boolean) + +fun ReferenceWirespec.convertConstraint(value: Expression): Expression = when (this) { + is ReferenceWirespec.Primitive -> convertConstraint(value) + else -> Literal(true, Type.Boolean) +} + +private fun ReferenceWirespec.toTypeDescriptor(): TypeDescriptor = TypeDescriptor(convert()) + +private fun deserializeParamExpression( + map: Expression, + fieldName: String, + field: FieldWirespec, + caseSensitive: Boolean = true, +): Expression { + val type = field.reference.copy(isNullable = false) + val getCall = ArrayIndexCall( + receiver = map, + index = Literal(fieldName, Type.String), + caseSensitive = caseSensitive, + ) + return NullCheck( + expression = getCall, + body = FunctionCall( + receiver = VariableReference(Name.of("serialization")), + name = Name("deserialize", "Param"), + typeArguments = listOf(type.convert()), + arguments = mapOf( + Name.of("value") to VariableReference(Name.of("it")), + Name.of("type") to type.toTypeDescriptor(), + ), + ), + alternative = if (field.reference.isNullable) { + null + } else { + ErrorStatement( + Literal( + "Param $fieldName cannot be null", + Type.String, + ), + ) + }, + ) +} + +private fun serializeParamExpression( + fieldAccess: Expression, + field: FieldWirespec, +): Expression { + val type = field.reference.copy(isNullable = false) + val serializeCall = FunctionCall( + receiver = VariableReference(Name.of("serialization")), + name = Name("serialize", "Param"), + typeArguments = listOf(type.convert()), + arguments = mapOf( + Name.of("value") to VariableReference(Name.of("it")), + Name.of("type") to type.toTypeDescriptor(), + ), + ) + return if (field.reference.isNullable) { + NullableMap( + expression = fieldAccess, + body = serializeCall, + alternative = LiteralList(emptyList(), Type.String), + ) + } else { + FunctionCall( + receiver = VariableReference(Name.of("serialization")), + name = Name("serialize", "Param"), + typeArguments = listOf(type.convert()), + arguments = mapOf( + Name.of("value") to fieldAccess, + Name.of("type") to field.reference.toTypeDescriptor(), + ), + ) + } +} + +fun EndpointWirespec.convertEndpointClient(): File { + val endpointName = identifier.toName() + val endpointNameStr = endpointName.value() + + return file(Name.of("${endpointNameStr}Client")) { + struct(Name.of("${endpointNameStr}Client")) { + field("serialization", Type.Custom("Wirespec.Serialization")) + field("transportation", Type.Custom("Wirespec.Transportation")) + implements(Type.Custom("$endpointNameStr.Call")) + + asyncFunction(endpointName, isOverride = true) { + requestParameters().forEach { (name, type) -> arg(name, type) } + returnType(Type.Custom("$endpointNameStr.Response", listOf(Type.Wildcard))) + + assign( + "request", + ConstructorStatement( + type = Type.Custom("$endpointNameStr.Request"), + namedArguments = requestParameters().associate { (name, _) -> + name to VariableReference(name) + }, + ), + ) + + assign( + "rawRequest", + FunctionCall( + name = Name(listOf("$endpointNameStr.toRawRequest")), + arguments = mapOf( + Name.of("serialization") to FieldCall(field = Name.of("serialization")), + Name.of("request") to VariableReference(Name.of("request")), + ), + ), + ) + + assign( + "rawResponse", + FunctionCall( + name = Name.of("transport"), + receiver = FieldCall(field = Name.of("transportation")), + arguments = mapOf( + Name.of("request") to VariableReference(Name.of("rawRequest")), + ), + ), + ) + + returns( + FunctionCall( + name = Name(listOf("$endpointNameStr.fromRawResponse")), + arguments = mapOf( + Name.of("serialization") to FieldCall(field = Name.of("serialization")), + Name.of("response") to VariableReference(Name.of("rawResponse")), + ), + ), + ) + } + } + } +} + +fun List.convertClient(): File { + val endpoints = this + return file(Name.of("Client")) { + struct(Name.of("Client")) { + field("serialization", Type.Custom("Wirespec.Serialization")) + field("transportation", Type.Custom("Wirespec.Transportation")) + + endpoints.forEach { endpoint -> + implements(Type.Custom("${endpoint.identifier.toName().value()}.Call")) + } + + endpoints.forEach { endpoint -> + val endpointName = endpoint.identifier.toName() + val endpointNameStr = endpointName.value() + + asyncFunction(endpointName, isOverride = true) { + endpoint.requestParameters().forEach { (name, type) -> arg(name, type) } + returnType(Type.Custom("$endpointNameStr.Response", listOf(Type.Wildcard))) + + returns( + FunctionCall( + name = Name(listOf(endpointName.camelCase())), + receiver = ConstructorStatement( + type = Type.Custom("${endpointNameStr}Client"), + namedArguments = mapOf( + Name.of("serialization") to FieldCall(field = Name.of("serialization")), + Name.of("transportation") to FieldCall(field = Name.of("transportation")), + ), + ), + arguments = endpoint.requestParameters().associate { (name, _) -> + name to VariableReference(name) + }, + ), + ) + } + } + } + } +} + +fun EndpointWirespec.requestParameters(): List> = buildList { + path.filterIsInstance() + .forEach { add(it.identifier.toName() to it.reference.convert()) } + queries.forEach { add(it.identifier.toName() to it.reference.convert()) } + headers.forEach { add(it.identifier.toName() to it.reference.convert()) } + requests.first().content?.let { add(Name.of("body") to it.reference.convert()) } +} + +fun EndpointWirespec.Response.responseParameters(): List> = buildList { + headers.forEach { add(it.identifier.toName() to it.reference.convert()) } + content?.let { add(Name.of("body") to it.reference.convert()) } +} diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Ast.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Ast.kt new file mode 100644 index 000000000..f38d2a06e --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Ast.kt @@ -0,0 +1,365 @@ +package community.flock.wirespec.ir.core + +data class Name(val parts: List) { + constructor(vararg parts: String) : this(parts.toList()) + + fun value(): String = parts.joinToString("") + + private fun wordParts(): List = parts.filter { it.isNotEmpty() && it.any { ch -> ch.isLetterOrDigit() } } + + fun camelCase(): String { + val words = wordParts() + return if (words.size <= 1) { + words.firstOrNull()?.replaceFirstChar { it.lowercase() } ?: "" + } else { + words.mapIndexed { index, part -> + if (index == 0) part.replaceFirstChar { it.lowercase() } else part.replaceFirstChar { it.uppercase() } + }.joinToString("") + } + } + + fun pascalCase(): String { + val words = wordParts() + return if (words.size <= 1) { + words.firstOrNull()?.replaceFirstChar { it.uppercase() } ?: "" + } else { + words.joinToString("") { it.replaceFirstChar { c -> c.uppercase() } } + } + } + + fun snakeCase(): String { + val words = wordParts() + return if (words.size <= 1) { + words.firstOrNull() + ?.replace(Regex("([a-z0-9])([A-Z])"), "$1_$2") + ?.lowercase() + ?: "" + } else { + words.joinToString("_") { it.lowercase() } + } + } + + override fun toString(): String = camelCase() + + companion object { + private val SPLIT_PATTERN = Regex("[A-Z]{2,}(?=[A-Z][a-z])|[A-Z]?[a-z0-9]+|[A-Z]+|[^a-zA-Z0-9]+") + fun of(value: String): Name = Name(SPLIT_PATTERN.findAll(value).map { it.value }.toList()) + } +} + +fun name(vararg parts: String): Name = Name(parts.toList()) + +enum class Precision { + P32, + P64, +} + +sealed interface Type { + data class Integer(val precision: Precision = Precision.P32) : Type + data class Number(val precision: Precision = Precision.P64) : Type + object Any : Type + object String : Type + object Boolean : Type + object Bytes : Type + object Unit : Type + object Wildcard : Type + object Reflect : Type + data class Array(val elementType: Type) : Type + data class Dict(val keyType: Type, val valueType: Type) : Type + data class Custom(val name: kotlin.String, val generics: List = emptyList()) : Type + data class Nullable(val type: Type) : Type + data class IntegerLiteral(val value: Int) : Type + data class StringLiteral(val value: kotlin.String) : Type +} + +sealed interface Element + +sealed interface HasName : Element { + val name: Name +} + +interface HasElements { + val elements: List +} + +data class File( + override val name: Name, + override val elements: List, +) : HasName, + HasElements + +data class Package( + val path: String, +) : Element + +data class Import( + val path: String, + val type: Type.Custom, +) : Element + +data class Struct( + override val name: Name, + val fields: List, + val constructors: List = emptyList(), + val interfaces: List = emptyList(), + override val elements: List = emptyList(), +) : HasName, + HasElements + +data class Constructor( + val parameters: List, + val body: List, +) + +data class Field( + val name: Name, + val type: Type, + val isOverride: Boolean = false, +) + +data class Function( + override val name: Name, + val typeParameters: List = emptyList(), + val parameters: List, + val returnType: Type?, + val body: List, + val isAsync: Boolean = false, + val isStatic: Boolean = false, + val isOverride: Boolean = false, +) : HasName + +data class Namespace( + override val name: Name, + override val elements: List, + val extends: Type.Custom? = null, +) : HasName, + HasElements + +data class Interface( + override val name: Name, + override val elements: List, + val extends: List = emptyList(), + val isSealed: Boolean = false, + val typeParameters: List = emptyList(), + val fields: List = emptyList(), +) : HasName, + HasElements + +data class Union( + override val name: Name, + val extends: Type.Custom? = null, + val members: List = emptyList(), + val typeParameters: List = emptyList(), +) : HasName + +data class Enum( + override val name: Name, + val extends: Type.Custom? = null, + val entries: List, + val fields: List = emptyList(), + val constructors: List = emptyList(), + override val elements: List = emptyList(), +) : HasName, + HasElements { + data class Entry(val name: Name, val values: List) +} + +data class Parameter( + val name: Name, + val type: Type, +) + +data class TypeParameter( + val type: Type, + val extends: List = emptyList(), +) + +sealed interface Statement : Expression +sealed interface Expression + +data class RawExpression(val code: String) : Statement + +// Main entry point - represents a language-specific main/entry point + +data class Main(val statics: List = emptyList(), val body: List, val isAsync: Boolean = false) : Element + +// Raw element - allows injecting raw code as an Element +data class RawElement(val code: String) : Element + +// Null literal - represents the null value +data object NullLiteral : Statement, Expression + +// Nullable empty literal - represents the empty optional value (e.g., Optional.empty() in Java, null in Kotlin) +data object NullableEmpty : Statement, Expression + +// Variable/identifier reference - represents a reference to a variable +data class VariableReference(val name: Name) : + Statement, + Expression { + constructor(name: String) : this(Name.of(name)) +} + +// Field access - represents accessing a field, optionally on a receiver (e.g., request.body or just body) +data class FieldCall( + val receiver: Expression? = null, + val field: Name, +) : Statement, + Expression + +// Function/method call - represents calling a function or method, optionally on a receiver +// If receiver is null, it's a standalone or static function call (e.g., fromRequest(...), java.util.Collections.emptyList()) +// If receiver is present, it's a method call on an object (e.g., list.get(index)) +data class FunctionCall( + val receiver: Expression? = null, + val typeArguments: List = emptyList(), + val name: Name, + val arguments: Map = emptyMap(), + val isAwait: Boolean = false, +) : Statement, + Expression + +// Array/map index access - represents bracket syntax (e.g., receiver[index]) +data class ArrayIndexCall( + val receiver: Expression, + val index: Expression, + val caseSensitive: Boolean = true, +) : Statement, + Expression + +// Enum constant reference - represents an enum constant (e.g., Wirespec.Method.GET) +data class EnumReference( + val enumType: Type.Custom, + val entry: Name, +) : Statement, + Expression + +// Enum value name access - gets the string name of an enum value +// In Java: .name(), in Kotlin: .name, in TypeScript: no-op (enums are already strings) +data class EnumValueCall( + val expression: Expression, +) : Statement, + Expression + +// Binary operations - represents binary operators (e.g., "message" + status) +data class BinaryOp( + val left: Expression, + val operator: Operator, + val right: Expression, +) : Statement, + Expression { + enum class Operator { PLUS, EQUALS, NOT_EQUALS } +} + +// Type descriptor - represents a runtime type descriptor for serialization +// In Java this emits Wirespec.getType(Type.class, Container.class) +// In other languages it may emit different type descriptor patterns +data class TypeDescriptor(val type: Type) : + Statement, + Expression + +data class PrintStatement(val expression: Expression) : Statement +data class ReturnStatement(val expression: Expression) : Statement +data class ConstructorStatement(val type: Type, val namedArguments: Map = emptyMap()) : + Statement, + Expression +data class Literal(val value: Any, val type: Type) : + Statement, + Expression +data class LiteralList(val values: List, val type: Type) : + Statement, + Expression +data class LiteralMap(val values: Map, val keyType: Type, val valueType: Type) : + Statement, + Expression +data class Assignment(val name: Name, val value: Expression, val isProperty: Boolean = false) : Statement +data class ErrorStatement(val message: Expression) : Statement +data class AssertStatement(val expression: Expression, val message: String) : Statement + +data class NullCheck( + val expression: Expression, + val body: Expression, + val alternative: Expression?, +) : Statement, + Expression + +data class NullableMap( + val expression: Expression, + val body: Expression, + val alternative: Expression, +) : Statement, + Expression + +data class NullableOf( + val expression: Expression, +) : Statement, + Expression + +data class NullableGet( + val expression: Expression, +) : Statement, + Expression + +sealed interface Constraint : + Statement, + Expression { + data class RegexMatch(val pattern: String, val rawValue: String, val value: Expression) : Constraint + data class BoundCheck(val min: String?, val max: String?, val value: Expression) : Constraint +} + +// Boolean negation +data class NotExpression(val expression: Expression) : + Statement, + Expression + +// Conditional expression (ternary) +data class IfExpression( + val condition: Expression, + val thenExpr: Expression, + val elseExpr: Expression, +) : Statement, + Expression + +// Map over a list +data class MapExpression( + val receiver: Expression, + val variable: Name, + val body: Expression, +) : Statement, + Expression + +// Indexed flatMap over a list +data class FlatMapIndexed( + val receiver: Expression, + val indexVar: Name, + val elementVar: Name, + val body: Expression, +) : Statement, + Expression + +// Concatenate multiple lists +data class ListConcat(val lists: List) : + Statement, + Expression + +// String interpolation +data class StringTemplate(val parts: List) : + Statement, + Expression { + sealed interface Part { + data class Text(val value: String) : Part + data class Expr(val expression: Expression) : Part + } +} + +data class Switch( + val expression: Expression, + val cases: List, + val default: List? = null, + val variable: Name? = null, +) : Statement + +data class Case( + val value: Expression, + val body: List, + val type: Type? = null, +) diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Dsl.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Dsl.kt new file mode 100644 index 000000000..e821f242a --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Dsl.kt @@ -0,0 +1,790 @@ +package community.flock.wirespec.ir.core + +@DslMarker +annotation class Dsl + +@Dsl +interface BaseBuilder { + val integer get() = Type.Integer() + val integer32 get() = Type.Integer(Precision.P32) + val integer64 get() = Type.Integer(Precision.P64) + val number get() = Type.Number() + val number32 get() = Type.Number(Precision.P32) + val number64 get() = Type.Number(Precision.P64) + val string get() = Type.String + val boolean get() = Type.Boolean + val bytes get() = Type.Bytes + val unit get() = Type.Unit + val wildcard get() = Type.Wildcard + val reflect get() = Type.Reflect + fun list(type: Type) = Type.Array(type) + fun dict(keyType: Type, valueType: Type) = Type.Dict(keyType, valueType) + fun type(name: String, vararg generics: Type): Type.Custom = Type.Custom(name, generics.toList()) + + fun Type.nullable() = Type.Nullable(this) + + fun literal(value: String) = Literal(value, Type.String) + fun literal(value: Int) = Literal(value, Type.Integer()) + fun literal(value: Long) = Literal(value, Type.Integer(Precision.P64)) + fun literal(value: Boolean) = Literal(value, Type.Boolean) + fun literal(value: Float) = Literal(value, Type.Number(Precision.P32)) + fun literal(value: Double) = Literal(value, Type.Number(Precision.P64)) +} + +@Dsl +interface ContainerBuilder : BaseBuilder { + val elements: MutableList + + fun import(path: String, type: Type.Custom) { + elements.add(Import(path, type)) + } + + fun import(path: String, type: String) { + elements.add(Import(path, Type.Custom(type))) + } + + fun raw(code: String) { + elements.add(RawElement(code)) + } + + fun struct(name: String, block: (StructBuilder.() -> Unit)? = null) { + val builder = StructBuilder(name) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun struct(name: Name, block: (StructBuilder.() -> Unit)? = null) { + val builder = StructBuilder(name) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun function(name: String, isStatic: Boolean = false, isOverride: Boolean = false, block: (FunctionBuilder.() -> Unit)? = null) { + val builder = FunctionBuilder(name, isAsync = false, isStatic = isStatic, isOverride = isOverride) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun function(name: Name, isStatic: Boolean = false, isOverride: Boolean = false, block: (FunctionBuilder.() -> Unit)? = null) { + val builder = FunctionBuilder(name, isAsync = false, isStatic = isStatic, isOverride = isOverride) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun asyncFunction(name: String, isStatic: Boolean = false, isOverride: Boolean = false, block: (FunctionBuilder.() -> Unit)? = null) { + val builder = FunctionBuilder(name, isAsync = true, isStatic = isStatic, isOverride = isOverride) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun asyncFunction(name: Name, isStatic: Boolean = false, isOverride: Boolean = false, block: (FunctionBuilder.() -> Unit)? = null) { + val builder = FunctionBuilder(name, isAsync = true, isStatic = isStatic, isOverride = isOverride) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun namespace(name: String, extends: Type.Custom? = null, block: (NamespaceBuilder.() -> Unit)? = null) { + val builder = NamespaceBuilder(name, extends) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun namespace(name: Name, extends: Type.Custom? = null, block: (NamespaceBuilder.() -> Unit)? = null) { + val builder = NamespaceBuilder(name, extends) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun `interface`(name: String, isSealed: Boolean = false, block: (InterfaceBuilder.() -> Unit)? = null) { + val builder = InterfaceBuilder(name, isSealed) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun `interface`(name: Name, isSealed: Boolean = false, block: (InterfaceBuilder.() -> Unit)? = null) { + val builder = InterfaceBuilder(name, isSealed) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun union(name: String, extends: Type.Custom? = null, block: (UnionBuilder.() -> Unit)? = null) { + val builder = UnionBuilder(name, extends) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun union(name: Name, extends: Type.Custom? = null, block: (UnionBuilder.() -> Unit)? = null) { + val builder = UnionBuilder(name, extends) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun enum(name: String, extends: Type.Custom? = null, block: (EnumBuilder.() -> Unit)? = null) { + val builder = EnumBuilder(name, extends) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun enum(name: Name, extends: Type.Custom? = null, block: (EnumBuilder.() -> Unit)? = null) { + val builder = EnumBuilder(name, extends) + block?.let { builder.it() } + elements.add(builder.build()) + } +} + +@Dsl +class FileBuilder(private val name: Name) : ContainerBuilder { + constructor(nameStr: String) : this(Name.of(nameStr)) + + override val elements = mutableListOf() + fun `package`(path: String) { + elements.add(Package(path)) + } + + override fun function(name: String, isStatic: Boolean, isOverride: Boolean, block: (FunctionBuilder.() -> Unit)?) { + val builder = FunctionBuilder(name) + block?.let { builder.it() } + elements.add(builder.build()) + } + + override fun function(name: Name, isStatic: Boolean, isOverride: Boolean, block: (FunctionBuilder.() -> Unit)?) { + val builder = FunctionBuilder(name) + block?.let { builder.it() } + elements.add(builder.build()) + } + + override fun asyncFunction(name: String, isStatic: Boolean, isOverride: Boolean, block: (FunctionBuilder.() -> Unit)?) { + val builder = FunctionBuilder(name, isAsync = true) + block?.let { builder.it() } + elements.add(builder.build()) + } + + override fun asyncFunction(name: Name, isStatic: Boolean, isOverride: Boolean, block: (FunctionBuilder.() -> Unit)?) { + val builder = FunctionBuilder(name, isAsync = true) + block?.let { builder.it() } + elements.add(builder.build()) + } + + override fun struct(name: String, block: (StructBuilder.() -> Unit)?) { + val builder = StructBuilder(name) + block?.let { builder.it() } + elements.add(builder.build()) + } + + fun main(isAsync: Boolean = false, block: FunctionBuilder.() -> Unit) { + val builder = FunctionBuilder("main") + builder.block() + val fn = builder.build() + elements.add(Main(body = fn.body, isAsync = isAsync)) + } + + fun main(isAsync: Boolean = false, statics: ContainerBuilder.() -> Unit, block: FunctionBuilder.() -> Unit) { + val staticsBuilder = object : ContainerBuilder { + override val elements = mutableListOf() + } + staticsBuilder.statics() + val bodyBuilder = FunctionBuilder("main") + bodyBuilder.block() + val fn = bodyBuilder.build() + elements.add(Main(statics = staticsBuilder.elements, body = fn.body, isAsync = isAsync)) + } + + fun build(): File = File(name, elements) +} + +@Dsl +class NamespaceBuilder(private val name: Name, private val extends: Type.Custom? = null) : ContainerBuilder { + constructor(nameStr: String, extends: Type.Custom? = null) : this(Name.of(nameStr), extends) + + override val elements = mutableListOf() + + fun build(): Namespace = Namespace(name, elements, extends) +} + +@Dsl +class InterfaceBuilder( + private val name: Name, + private var isSealed: Boolean = false, +) : ContainerBuilder { + constructor(nameStr: String, isSealed: Boolean = false) : this(Name.of(nameStr), isSealed) + + override val elements = mutableListOf() + private val typeParameters = mutableListOf() + private val extendsList = mutableListOf() + private val fields = mutableListOf() + + fun typeParam(type: Type, vararg extends: Type) { + typeParameters.add(TypeParameter(type, extends.toList())) + } + + fun extends(type: Type.Custom) { + extendsList.add(type) + } + + fun sealed() { + isSealed = true + } + + fun field(name: String, type: Type, isOverride: Boolean = false) { + fields.add(Field(Name.of(name), type, isOverride)) + } + + fun field(name: Name, type: Type, isOverride: Boolean = false) { + fields.add(Field(name, type, isOverride)) + } + + fun build(): Interface = Interface(name, elements, extendsList, isSealed, typeParameters, fields) +} + +@Dsl +class UnionBuilder(private val name: Name, private val extends: Type.Custom? = null) : BaseBuilder { + constructor(nameStr: String, extends: Type.Custom? = null) : this(Name.of(nameStr), extends) + + private val members = mutableListOf() + private val typeParameters = mutableListOf() + + fun member(name: String) { + members.add(Type.Custom(name)) + } + + fun typeParam(type: Type, vararg extends: Type) { + typeParameters.add(TypeParameter(type, extends.toList())) + } + + fun build(): Union = Union(name, extends, members, typeParameters) +} + +@Dsl +class EnumBuilder(private val name: Name, private val extends: Type.Custom? = null) : ContainerBuilder { + constructor(nameStr: String, extends: Type.Custom? = null) : this(Name.of(nameStr), extends) + + private val entries = mutableListOf() + private val fields = mutableListOf() + private val constructors = mutableListOf() + override val elements = mutableListOf() + + fun entry(name: String, vararg values: String) { + entries.add(Enum.Entry(Name.of(name), values.toList())) + } + + fun field(name: String, type: Type, isOverride: Boolean = false) { + fields.add(Field(Name.of(name), type, isOverride)) + } + + fun field(name: Name, type: Type, isOverride: Boolean = false) { + fields.add(Field(name, type, isOverride)) + } + + fun constructo(block: StructConstructorBuilder.() -> Unit) { + val builder = StructConstructorBuilder() + builder.block() + constructors.add(builder.build()) + } + + fun build(): Enum = Enum(name, extends, entries, fields, constructors, elements) +} + +@Dsl +class StructBuilder(private val name: Name) : ContainerBuilder { + constructor(nameStr: String) : this(Name.of(nameStr)) + + private val fields = mutableListOf() + private val constructors = mutableListOf() + private val interfaces = mutableListOf() + override val elements = mutableListOf() + + fun implements(type: Type.Custom) { + interfaces.add(type) + } + + fun field(name: String, type: Type, isOverride: Boolean = false) { + fields.add(Field(Name.of(name), type, isOverride)) + } + + fun field(name: Name, type: Type, isOverride: Boolean = false) { + fields.add(Field(name, type, isOverride)) + } + + fun construct(type: Type, block: ConstructorBuilder.() -> Unit = {}): ConstructorStatement { + val builder = ConstructorBuilder(type) + builder.block() + return builder.build() + } + + fun constructo(block: StructConstructorBuilder.() -> Unit) { + val builder = StructConstructorBuilder() + builder.block() + constructors.add(builder.build()) + } + + fun build(): Struct = Struct(name, fields, constructors, interfaces, elements) +} + +@Dsl +class StructConstructorBuilder : BaseBuilder { + private val parameters = mutableListOf() + private val body = mutableListOf() + + fun arg(name: String, type: Type) { + parameters.add(Parameter(Name.of(name), type)) + } + + fun arg(name: Name, type: Type) { + parameters.add(Parameter(name, type)) + } + + fun assign(name: String, value: Expression) { + if (value is Statement && body.lastOrNull() === value) { + body.removeAt(body.size - 1) + } + body.add(Assignment(Name.of(name), value, isProperty = name.startsWith("this."))) + } + + fun construct(type: Type, block: ConstructorBuilder.() -> Unit = {}): ConstructorStatement { + val builder = ConstructorBuilder(type) + builder.block() + val node = builder.build() + body.add(node) + return node + } + + fun functionCall(name: String, receiver: Expression? = null, typeArguments: List = emptyList(), isAwait: Boolean = false, block: FunctionCallBuilder.() -> Unit = {}): FunctionCall { + val builder = FunctionCallBuilder(name, receiver, typeArguments, isAwait) + builder.block() + val node = builder.build() + body.add(node) + return node + } + + fun fieldCall(field: String, receiver: Expression? = null): FieldCall { + val node = FieldCall(receiver, Name.of(field)) + body.add(node) + return node + } + + fun nullCheck(expression: Expression, alternative: Expression, bodyExpr: Expression): NullCheck { + val node = NullCheck(expression, bodyExpr, alternative) + body.add(node) + return node + } + + fun build(): Constructor = Constructor(parameters, body) +} + +@Dsl +class FunctionBuilder( + private val name: Name, + private val isAsync: Boolean = false, + private val isStatic: Boolean = false, + private val isOverride: Boolean = false, +) : BaseBuilder { + constructor(name: String, isAsync: Boolean = false, isStatic: Boolean = false, isOverride: Boolean = false) : + this(Name.of(name), isAsync, isStatic, isOverride) + private val typeParameters = mutableListOf() + private val parameters = mutableListOf() + private val body = mutableListOf() + private var returnType: Type? = null + + fun typeParam(type: Type, vararg extends: Type) { + typeParameters.add(TypeParameter(type, extends.toList())) + } + + fun returnType(type: Type) { + returnType = type + } + + fun arg(name: String, type: Type) { + parameters.add(Parameter(Name.of(name), type)) + } + + fun arg(name: Name, type: Type) { + parameters.add(Parameter(name, type)) + } + + fun print(expression: Expression) { + body.add(PrintStatement(expression)) + } + + fun returns(expression: Expression) { + if (expression is Statement && body.lastOrNull() === expression) { + body.removeAt(body.size - 1) + } + body.add(ReturnStatement(expression)) + } + + fun literal(value: Any, type: Type): Literal { + val node = Literal(value, type) + body.add(node) + return node + } + + fun literalList(values: List, type: Type): LiteralList { + val node = LiteralList(values, type) + body.add(node) + return node + } + + fun literalList(type: Type): LiteralList = literalList(emptyList(), type) + + fun literalMap(values: Map, keyType: Type, valueType: Type): LiteralMap { + val node = LiteralMap(values, keyType, valueType) + body.add(node) + return node + } + + fun literalMap(keyType: Type, valueType: Type): LiteralMap = literalMap(emptyMap(), keyType, valueType) + + fun assign(name: String, value: Expression) { + if (value is Statement && body.lastOrNull() === value) { + body.removeAt(body.size - 1) + } + body.add(Assignment(Name.of(name), value)) + } + + fun construct(type: Type, block: ConstructorBuilder.() -> Unit = {}): ConstructorStatement { + val builder = ConstructorBuilder(type) + builder.block() + val node = builder.build() + body.add(node) + return node + } + + fun functionCall(name: String, receiver: Expression? = null, typeArguments: List = emptyList(), isAwait: Boolean = false, block: FunctionCallBuilder.() -> Unit = {}): FunctionCall { + val builder = FunctionCallBuilder(name, receiver, typeArguments, isAwait) + builder.block() + val node = builder.build() + body.add(node) + return node + } + + fun fieldCall(field: String, receiver: Expression? = null): FieldCall { + val node = FieldCall(receiver, Name.of(field)) + body.add(node) + return node + } + + fun switch(expression: Expression, variable: String? = null, block: SwitchBuilder.() -> Unit) { + val builder = SwitchBuilder(expression, variable?.let { Name.of(it) }) + builder.block() + body.add(builder.build()) + } + + fun error(message: Expression) { + body.add(ErrorStatement(message)) + } + + fun assertThat(expression: Expression, message: String) { + body.add(AssertStatement(expression, message)) + } + + fun raw(code: String) { + body.add(RawExpression(code)) + } + + fun nullCheck(expression: Expression, alternative: Expression, bodyExpr: Expression): NullCheck { + val node = NullCheck(expression, bodyExpr, alternative) + body.add(node) + return node + } + + fun build(): Function = Function(name, typeParameters, parameters, returnType, body, isAsync, isStatic, isOverride) +} + +@Dsl +class SwitchBuilder(private val expression: Expression, private val variable: Name? = null) : BaseBuilder { + private val cases = mutableListOf() + private var default: List? = null + + fun case(value: Literal, block: CaseBuilder.() -> Unit) { + val builder = CaseBuilder(value) + builder.block() + cases.add(builder.build()) + } + + fun case(type: Type, block: CaseBuilder.() -> Unit) { + val builder = CaseBuilder(RawExpression("type")) // value not used when type is present + builder.block() + cases.add(builder.build().copy(type = type)) + } + + inline fun case(noinline block: CaseBuilder.() -> Unit) { + val typeName = T::class.simpleName ?: throw IllegalArgumentException("Cannot get simple name for ${T::class}") + case(Type.Custom(typeName), block) + } + + fun default(block: CaseBuilder.() -> Unit) { + val builder = CaseBuilder(RawExpression("default")) // value not used for default + builder.block() + default = builder.build().body + } + + fun build(): Switch = Switch(expression, cases, default, variable) +} + +@Dsl +class CaseBuilder(private val value: Expression) : BaseBuilder { + private val body = mutableListOf() + + fun print(expression: Expression) { + body.add(PrintStatement(expression)) + } + + fun returns(expression: Expression) { + if (expression is Statement && body.lastOrNull() === expression) { + body.removeAt(body.size - 1) + } + body.add(ReturnStatement(expression)) + } + + fun assign(name: String, value: Expression) { + if (value is Statement && body.lastOrNull() === value) { + body.removeAt(body.size - 1) + } + body.add(Assignment(Name.of(name), value)) + } + + fun functionCall(name: String, receiver: Expression? = null, typeArguments: List = emptyList(), isAwait: Boolean = false, block: FunctionCallBuilder.() -> Unit = {}): FunctionCall { + val builder = FunctionCallBuilder(name, receiver, typeArguments, isAwait) + builder.block() + val node = builder.build() + body.add(node) + return node + } + + fun fieldCall(field: String, receiver: Expression? = null): FieldCall { + val node = FieldCall(receiver, Name.of(field)) + body.add(node) + return node + } + + fun construct(type: Type, block: ConstructorBuilder.() -> Unit = {}): ConstructorStatement { + val builder = ConstructorBuilder(type) + builder.block() + val node = builder.build() + body.add(node) + return node + } + + fun error(message: Expression) { + body.add(ErrorStatement(message)) + } + + fun assertThat(expression: Expression, message: String) { + body.add(AssertStatement(expression, message)) + } + + fun nullCheck(expression: Expression, alternative: Expression, bodyExpr: Expression): NullCheck { + val node = NullCheck(expression, bodyExpr, alternative) + body.add(node) + return node + } + + fun build(): Case = Case(value, body) +} + +@Dsl +class FunctionCallBuilder(private val name: String, private val receiver: Expression? = null, private val typeArguments: List = emptyList(), private var isAwait: Boolean = false) : BaseBuilder { + private val arguments = mutableMapOf() + + fun await() { + isAwait = true + } + + fun arg(argName: String, value: Expression) { + arguments[Name.of(argName)] = value + } + + fun arg(argName: Name, value: Expression) { + arguments[argName] = value + } + + fun functionCall(name: String, receiver: Expression? = null, typeArguments: List = emptyList(), isAwait: Boolean = false, block: FunctionCallBuilder.() -> Unit = {}): FunctionCall { + val builder = FunctionCallBuilder(name, receiver, typeArguments, isAwait) + builder.block() + return builder.build() + } + + fun fieldCall(field: String, receiver: Expression? = null): FieldCall = FieldCall(receiver, Name.of(field)) + + fun literal(value: Any, type: Type): Literal = Literal(value, type) + + fun listOf(values: List, type: Type): LiteralList = LiteralList(values, type) + + fun emptyList(type: Type): LiteralList = listOf(emptyList(), type) + + fun mapOf(values: Map, keyType: Type, valueType: Type): LiteralMap = LiteralMap(values, keyType, valueType) + + fun emptyMap(keyType: Type, valueType: Type): LiteralMap = mapOf(emptyMap(), keyType, valueType) + + fun nullCheck(expression: Expression, alternative: Expression, bodyExpr: Expression): NullCheck = NullCheck(expression, bodyExpr, alternative) + + fun build(): FunctionCall = FunctionCall(receiver, typeArguments, Name.of(name), arguments, isAwait) +} + +@Dsl +class ConstructorBuilder(private val type: Type) : BaseBuilder { + private val arguments = mutableMapOf() + + fun arg(name: String, value: Expression) { + arguments[Name.of(name)] = value + } + + fun arg(name: Name, value: Expression) { + arguments[name] = value + } + + fun functionCall(name: String, receiver: Expression? = null, typeArguments: List = emptyList(), isAwait: Boolean = false, block: FunctionCallBuilder.() -> Unit = {}): FunctionCall { + val builder = FunctionCallBuilder(name, receiver, typeArguments, isAwait) + builder.block() + return builder.build() + } + + fun fieldCall(field: String, receiver: Expression? = null): FieldCall = FieldCall(receiver, Name.of(field)) + + fun literal(value: Any, type: Type): Literal = Literal(value, type) + + fun listOf(values: List, type: Type): LiteralList = LiteralList(values, type) + + fun emptyList(type: Type): LiteralList = listOf(emptyList(), type) + + fun mapOf(values: Map, keyType: Type, valueType: Type): LiteralMap = LiteralMap(values, keyType, valueType) + + fun emptyMap(keyType: Type, valueType: Type): LiteralMap = mapOf(emptyMap(), keyType, valueType) + + fun nullCheck(expression: Expression, alternative: Expression, bodyExpr: Expression): NullCheck = NullCheck(expression, bodyExpr, alternative) + + fun build(): ConstructorStatement = ConstructorStatement(type, arguments) +} + +fun file(name: String, block: FileBuilder.() -> Unit): File { + val builder = FileBuilder(name) + builder.block() + return builder.build() +} + +fun file(name: Name, block: FileBuilder.() -> Unit): File { + val builder = FileBuilder(name) + builder.block() + return builder.build() +} + +fun struct(name: String, block: (StructBuilder.() -> Unit)? = null): Struct { + val builder = StructBuilder(name) + block?.let { builder.it() } + return builder.build() +} + +fun struct(name: Name, block: (StructBuilder.() -> Unit)? = null): Struct { + val builder = StructBuilder(name) + block?.let { builder.it() } + return builder.build() +} + +fun enum(name: String, extends: Type.Custom? = null, block: (EnumBuilder.() -> Unit)? = null): Enum { + val builder = EnumBuilder(name, extends) + block?.let { builder.it() } + return builder.build() +} + +fun enum(name: Name, extends: Type.Custom? = null, block: (EnumBuilder.() -> Unit)? = null): Enum { + val builder = EnumBuilder(name, extends) + block?.let { builder.it() } + return builder.build() +} + +fun union(name: String, extends: Type.Custom? = null, block: (UnionBuilder.() -> Unit)? = null): Union { + val builder = UnionBuilder(name, extends) + block?.let { builder.it() } + return builder.build() +} + +fun union(name: Name, extends: Type.Custom? = null, block: (UnionBuilder.() -> Unit)? = null): Union { + val builder = UnionBuilder(name, extends) + block?.let { builder.it() } + return builder.build() +} + +fun `interface`(name: String, isSealed: Boolean = false, block: (InterfaceBuilder.() -> Unit)? = null): Interface { + val builder = InterfaceBuilder(name, isSealed) + block?.let { builder.it() } + return builder.build() +} + +fun `interface`(name: Name, isSealed: Boolean = false, block: (InterfaceBuilder.() -> Unit)? = null): Interface { + val builder = InterfaceBuilder(name, isSealed) + block?.let { builder.it() } + return builder.build() +} + +fun namespace(name: String, extends: Type.Custom? = null, block: (NamespaceBuilder.() -> Unit)? = null): Namespace { + val builder = NamespaceBuilder(name, extends) + block?.let { builder.it() } + return builder.build() +} + +fun namespace(name: Name, extends: Type.Custom? = null, block: (NamespaceBuilder.() -> Unit)? = null): Namespace { + val builder = NamespaceBuilder(name, extends) + block?.let { builder.it() } + return builder.build() +} + +fun function(name: String, isStatic: Boolean = false, isOverride: Boolean = false, block: (FunctionBuilder.() -> Unit)? = null): Function { + val builder = FunctionBuilder(name, isAsync = false, isStatic = isStatic, isOverride = isOverride) + block?.let { builder.it() } + return builder.build() +} + +fun function(name: Name, isStatic: Boolean = false, isOverride: Boolean = false, block: (FunctionBuilder.() -> Unit)? = null): Function { + val builder = FunctionBuilder(name, isAsync = false, isStatic = isStatic, isOverride = isOverride) + block?.let { builder.it() } + return builder.build() +} + +fun asyncFunction(name: String, isStatic: Boolean = false, isOverride: Boolean = false, block: (FunctionBuilder.() -> Unit)? = null): Function { + val builder = FunctionBuilder(name, isAsync = true, isStatic = isStatic, isOverride = isOverride) + block?.let { builder.it() } + return builder.build() +} + +fun asyncFunction(name: Name, isStatic: Boolean = false, isOverride: Boolean = false, block: (FunctionBuilder.() -> Unit)? = null): Function { + val builder = FunctionBuilder(name, isAsync = true, isStatic = isStatic, isOverride = isOverride) + block?.let { builder.it() } + return builder.build() +} + +fun import(path: String, type: Type.Custom): Import = Import(path, type) + +fun import(path: String, type: String): Import = Import(path, Type.Custom(type)) + +fun main(isAsync: Boolean = false, block: FunctionBuilder.() -> Unit): Main { + val builder = FunctionBuilder("main") + builder.block() + val fn = builder.build() + return Main(body = fn.body, isAsync = isAsync) +} + +fun raw(code: String): RawElement = RawElement(code) + +fun Enum.withLabelField( + sanitizeEntry: (String) -> String, + labelFieldOverride: Boolean = false, + labelExpression: Expression = VariableReference(Name.of("label")), + extraElements: List = emptyList(), +): Enum = copy( + entries = entries.map { + Enum.Entry(Name.of(sanitizeEntry(it.name.value())), listOf("\"${it.name.value()}\"")) + }, + fields = listOf(Field(Name.of("label"), Type.String, isOverride = labelFieldOverride)), + constructors = listOf( + Constructor( + parameters = listOf(Parameter(Name.of("label"), Type.String)), + body = listOf(Assignment(Name.of("this.label"), labelExpression, true)), + ), + ), + elements = listOf( + function("toString", isOverride = true) { + returnType(Type.String) + returns(labelExpression) + }, + ) + extraElements, +) diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Extensions.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Extensions.kt new file mode 100644 index 000000000..051a5768c --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Extensions.kt @@ -0,0 +1,3 @@ +package community.flock.wirespec.ir.core + +fun Expression.fieldCall(field: String): FieldCall = FieldCall(receiver = this, field = Name.of(field)) diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Restructure.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Restructure.kt new file mode 100644 index 000000000..a96829f9e --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Restructure.kt @@ -0,0 +1,60 @@ +package community.flock.wirespec.ir.core + +fun Struct.qualifyNestedRefs(nestedNames: Set): Struct { + val qualifiedFields = fields.map { field -> + val typeName = (field.type as? Type.Custom)?.name + if (typeName != null && typeName in nestedNames) { + field.copy(type = Type.Custom("${name.pascalCase()}$typeName")) + } else { + field + } + } + val qualifiedConstructors = constructors.map { c -> + c.copy( + body = c.body.map { stmt -> + if (stmt is Assignment) { + val value = stmt.value + if (value is ConstructorStatement) { + val typeName = (value.type as? Type.Custom)?.name + if (typeName != null && typeName in nestedNames) { + Assignment(stmt.name, value.copy(type = Type.Custom("${name.pascalCase()}$typeName"))) + } else { + stmt + } + } else { + stmt + } + } else { + stmt + } + }, + ) + } + return copy( + fields = qualifiedFields, + constructors = qualifiedConstructors, + elements = elements.filter { it !is Struct }, + ) +} + +fun Namespace.flattenNestedStructs(): Namespace { + val newElements = mutableListOf() + for (element in elements) { + when (element) { + is Struct -> { + val nested = element.elements.filterIsInstance() + if (nested.isNotEmpty()) { + val nestedNames = nested.map { it.name.pascalCase() }.toSet() + for (nestedStruct in nested) { + newElements.add(nestedStruct.copy(name = Name.of("${element.name.pascalCase()}${nestedStruct.name.pascalCase()}"))) + } + newElements.add(element.qualifyNestedRefs(nestedNames)) + } else { + newElements.add(element) + } + } + else -> newElements.add(element) + } + } + return copy(elements = newElements) +} diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Transform.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Transform.kt new file mode 100644 index 000000000..209550704 --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/core/Transform.kt @@ -0,0 +1,496 @@ +package community.flock.wirespec.ir.core + +import kotlin.js.JsName + +interface Transformer { + fun transformType(type: Type): Type = type.transformChildren(this) + fun transformElement(element: Element): Element = element.transformChildren(this) + fun transformStatement(statement: Statement): Statement = statement.transformChildren(this) + fun transformExpression(expression: Expression): Expression = expression.transformChildren(this) + fun transformField(field: Field): Field = field.transformChildren(this) + fun transformParameter(parameter: Parameter): Parameter = parameter.transformChildren(this) + fun transformConstructor(constructor: Constructor): Constructor = constructor.transformChildren(this) + fun transformCase(case: Case): Case = case.transformChildren(this) +} + +@Dsl +class TransformerBuilder @PublishedApi internal constructor() { + private var transformType: ((Type, Transformer) -> Type)? = null + private var transformElement: ((Element, Transformer) -> Element)? = null + private var transformStatement: ((Statement, Transformer) -> Statement)? = null + private var transformExpression: ((Expression, Transformer) -> Expression)? = null + private var transformField: ((Field, Transformer) -> Field)? = null + private var transformParameter: ((Parameter, Transformer) -> Parameter)? = null + private var transformConstructor: ((Constructor, Transformer) -> Constructor)? = null + private var transformCase: ((Case, Transformer) -> Case)? = null + + fun type(transform: (Type, Transformer) -> Type) { + transformType = transform + } + fun element(transform: (Element, Transformer) -> Element) { + transformElement = transform + } + fun statement(transform: (Statement, Transformer) -> Statement) { + transformStatement = transform + } + fun expression(transform: (Expression, Transformer) -> Expression) { + transformExpression = transform + } + fun field(transform: (Field, Transformer) -> Field) { + transformField = transform + } + fun parameter(transform: (Parameter, Transformer) -> Parameter) { + transformParameter = transform + } + + @JsName("constructorNode") + fun constructor(transform: (Constructor, Transformer) -> Constructor) { + transformConstructor = transform + } + fun case(transform: (Case, Transformer) -> Case) { + transformCase = transform + } + + fun statementAndExpression(block: (Statement, Transformer) -> Statement) { + statement(block) + expression { e, t -> + (e as? Statement)?.let { block(it, t) } ?: e.transformChildren(t) + } + } + + @PublishedApi + internal fun build(): Transformer = object : Transformer { + override fun transformType(type: Type): Type = transformType?.invoke(type, this) ?: type.transformChildren(this) + override fun transformElement(element: Element): Element = transformElement?.invoke(element, this) ?: element.transformChildren(this) + override fun transformStatement(statement: Statement): Statement = transformStatement?.invoke(statement, this) ?: statement.transformChildren(this) + override fun transformExpression(expression: Expression): Expression = transformExpression?.invoke(expression, this) ?: expression.transformChildren(this) + override fun transformField(field: Field): Field = transformField?.invoke(field, this) ?: field.transformChildren(this) + override fun transformParameter(parameter: Parameter): Parameter = transformParameter?.invoke(parameter, this) ?: parameter.transformChildren(this) + override fun transformConstructor(constructor: Constructor): Constructor = transformConstructor?.invoke(constructor, this) ?: constructor.transformChildren(this) + override fun transformCase(case: Case): Case = transformCase?.invoke(case, this) ?: case.transformChildren(this) + } +} + +inline fun transformer(block: TransformerBuilder.() -> Unit): Transformer = TransformerBuilder().apply(block).build() + +fun Type.transformChildren(transformer: Transformer): Type = when (this) { + is Type.Array -> copy(elementType = transformer.transformType(elementType)) + is Type.Dict -> copy( + keyType = transformer.transformType(keyType), + valueType = transformer.transformType(valueType), + ) + is Type.Custom -> copy(generics = generics.map { transformer.transformType(it) }) + is Type.Nullable -> copy(type = transformer.transformType(type)) + is Type.Integer, is Type.Number, Type.Any, Type.String, Type.Boolean, Type.Bytes, Type.Unit, Type.Wildcard, Type.Reflect, is Type.IntegerLiteral, is Type.StringLiteral -> this +} + +fun Element.transformChildren(transformer: Transformer): Element = when (this) { + is File -> copy(elements = elements.map { transformer.transformElement(it) }) + is Package -> this + is Import -> copy(type = transformer.transformType(type) as Type.Custom) + is Struct -> copy( + fields = fields.map { transformer.transformField(it) }, + constructors = constructors.map { transformer.transformConstructor(it) }, + interfaces = interfaces.map { transformer.transformType(it) as Type.Custom }, + elements = elements.map { transformer.transformElement(it) }, + ) + is Function -> copy( + parameters = parameters.map { transformer.transformParameter(it) }, + returnType = returnType?.let { transformer.transformType(it) }, + body = body.map { transformer.transformStatement(it) }, + ) + is Namespace -> copy( + elements = elements.map { transformer.transformElement(it) }, + extends = extends?.let { transformer.transformType(it) as Type.Custom }, + ) + is Interface -> copy( + elements = elements.map { transformer.transformElement(it) }, + extends = extends.map { transformer.transformType(it) as Type.Custom }, + fields = fields.map { transformer.transformField(it) }, + ) + is Union -> copy( + extends = extends?.let { transformer.transformType(it) as Type.Custom }, + members = members.map { transformer.transformType(it) as Type.Custom }, + typeParameters = typeParameters.map { + TypeParameter( + transformer.transformType(it.type), + it.extends.map { e -> transformer.transformType(e) }, + ) + }, + ) + is Enum -> copy( + extends = extends?.let { transformer.transformType(it) as Type.Custom }, + fields = fields.map { transformer.transformField(it) }, + constructors = constructors.map { transformer.transformConstructor(it) }, + elements = elements.map { transformer.transformElement(it) }, + ) + is Main -> copy( + statics = statics.map { transformer.transformElement(it) }, + body = body.map { transformer.transformStatement(it) }, + ) + is RawElement -> this +} + +fun Field.transformChildren(transformer: Transformer): Field = copy(type = transformer.transformType(type)) + +fun Parameter.transformChildren(transformer: Transformer): Parameter = copy(type = transformer.transformType(type)) + +fun Constructor.transformChildren(transformer: Transformer): Constructor = copy( + parameters = parameters.map { transformer.transformParameter(it) }, + body = body.map { transformer.transformStatement(it) }, +) + +fun Statement.transformChildren(transformer: Transformer): Statement = when (this) { + is PrintStatement -> copy(expression = transformer.transformExpression(expression)) + is ReturnStatement -> copy(expression = transformer.transformExpression(expression)) + is ConstructorStatement -> copy( + type = transformer.transformType(type), + namedArguments = namedArguments.mapValues { transformer.transformExpression(it.value) }, + ) + is Literal -> copy(type = transformer.transformType(type)) + is LiteralList -> copy( + values = values.map { transformer.transformExpression(it) }, + type = transformer.transformType(type), + ) + is LiteralMap -> copy( + values = values.mapValues { transformer.transformExpression(it.value) }, + keyType = transformer.transformType(keyType), + valueType = transformer.transformType(valueType), + ) + is Assignment -> copy(value = transformer.transformExpression(value)) + is ErrorStatement -> copy(message = transformer.transformExpression(message)) + is AssertStatement -> copy( + expression = transformer.transformExpression(expression), + ) + is Switch -> copy( + expression = transformer.transformExpression(expression), + cases = cases.map { transformer.transformCase(it) }, + default = default?.map { transformer.transformStatement(it) }, + ) + is RawExpression -> this + is NullLiteral -> this + is NullableEmpty -> this + is VariableReference -> this + is FieldCall -> copy(receiver = receiver?.let { transformer.transformExpression(it) }) + is FunctionCall -> copy( + receiver = receiver?.let { transformer.transformExpression(it) }, + arguments = arguments.mapValues { transformer.transformExpression(it.value) }, + ) + is ArrayIndexCall -> copy( + receiver = transformer.transformExpression(receiver), + index = transformer.transformExpression(index), + ) + is EnumReference -> copy(enumType = transformer.transformType(enumType) as Type.Custom) + is EnumValueCall -> copy(expression = transformer.transformExpression(expression)) + is BinaryOp -> copy( + left = transformer.transformExpression(left), + right = transformer.transformExpression(right), + ) + is TypeDescriptor -> copy(type = transformer.transformType(type)) + is NullCheck -> copy( + expression = transformer.transformExpression(expression), + body = transformer.transformExpression(body), + alternative = alternative?.let { transformer.transformExpression(it) }, + ) + is NullableMap -> copy( + expression = transformer.transformExpression(expression), + body = transformer.transformExpression(body), + alternative = transformer.transformExpression(alternative), + ) + is NullableOf -> copy(expression = transformer.transformExpression(expression)) + is NullableGet -> copy(expression = transformer.transformExpression(expression)) + is Constraint.RegexMatch -> copy(value = transformer.transformExpression(value)) + is Constraint.BoundCheck -> copy(value = transformer.transformExpression(value)) + is NotExpression -> copy(expression = transformer.transformExpression(expression)) + is IfExpression -> copy( + condition = transformer.transformExpression(condition), + thenExpr = transformer.transformExpression(thenExpr), + elseExpr = transformer.transformExpression(elseExpr), + ) + is MapExpression -> copy( + receiver = transformer.transformExpression(receiver), + body = transformer.transformExpression(body), + ) + is FlatMapIndexed -> copy( + receiver = transformer.transformExpression(receiver), + body = transformer.transformExpression(body), + ) + is ListConcat -> copy(lists = lists.map { transformer.transformExpression(it) }) + is StringTemplate -> copy( + parts = parts.map { + when (it) { + is StringTemplate.Part.Text -> it + is StringTemplate.Part.Expr -> StringTemplate.Part.Expr(transformer.transformExpression(it.expression)) + } + }, + ) +} + +fun Expression.transformChildren(transformer: Transformer): Expression = when (this) { + is RawExpression -> this + is Statement -> transformChildren(transformer) as Expression +} + +fun Case.transformChildren(transformer: Transformer): Case = copy( + value = transformer.transformExpression(value), + body = body.map { transformer.transformStatement(it) }, + type = type?.let { transformer.transformType(it) }, +) + +@Suppress("UNCHECKED_CAST") +@PublishedApi +internal fun T.transform(transformer: Transformer): T = transformer.transformElement(this) as T + +@Dsl +class TransformScope @PublishedApi internal constructor( + @PublishedApi internal var element: Element, +) { + inline fun matching(crossinline transform: (M) -> Type) { + element = element.transformMatching(transform) + } + + inline fun matchingElements(crossinline transform: (M) -> Element) { + element = element.transformMatchingElements(transform) + } + + fun fieldsWhere(predicate: (Field) -> Boolean, transform: (Field) -> Field) { + element = element.transformFieldsWhere(predicate, transform) + } + + fun fields(transform: (Field) -> Field) { + fieldsWhere({ true }, transform) + } + + fun parametersWhere(predicate: (Parameter) -> Boolean, transform: (Parameter) -> Parameter) { + element = element.transformParametersWhere(predicate, transform) + } + + fun parameters(transform: (Parameter) -> Parameter) { + parametersWhere({ true }, transform) + } + + fun renameType(oldName: String, newName: String) { + element = element.renameType(oldName, newName) + } + + fun renameField(oldName: Name, newName: Name) { + element = element.renameField(oldName, newName) + } + + fun renameField(oldName: String, newName: String) { + element = element.renameField(oldName, newName) + } + + fun typeByName(name: String, transform: (Type.Custom) -> Type) { + element = element.transformTypeByName(name, transform) + } + + inline fun injectBefore( + crossinline produce: (T) -> List, + ) + where T : Element, T : HasElements { + element = element.injectBefore(produce) + } + + inline fun injectAfter( + crossinline produce: (T) -> List, + ) + where T : Element, T : HasElements { + element = element.injectAfter(produce) + } + + fun apply(transformer: Transformer) { + element = element.transform(transformer) + } + + fun type(transform: (Type, Transformer) -> Type) { + element = element.transform(transformer { type(transform) }) + } + + fun statement(transform: (Statement, Transformer) -> Statement) { + element = element.transform(transformer { statement(transform) }) + } + + fun expression(transform: (Expression, Transformer) -> Expression) { + element = element.transform(transformer { expression(transform) }) + } + + fun field(transform: (Field, Transformer) -> Field) { + element = element.transform(transformer { field(transform) }) + } + + fun parameter(transform: (Parameter, Transformer) -> Parameter) { + element = element.transform(transformer { parameter(transform) }) + } + + @JsName("constructorNode") + fun constructor(transform: (Constructor, Transformer) -> Constructor) { + element = element.transform(transformer { constructor(transform) }) + } + + fun case(transform: (Case, Transformer) -> Case) { + element = element.transform(transformer { case(transform) }) + } + + fun statementAndExpression(block: (Statement, Transformer) -> Statement) { + apply(transformer { statementAndExpression(block) }) + } +} + +@Suppress("UNCHECKED_CAST") +inline fun E.transform(block: TransformScope.() -> Unit): E { + val scope = TransformScope(this) + scope.block() + return scope.element as E +} + +@PublishedApi +internal inline fun E.transformMatching( + crossinline transform: (M) -> Type, +): E = transform( + transformer { + type { type, transformer -> + val transformed = if (type is M) transform(type) else type + transformed.transformChildren(transformer) + } + }, +) + +@PublishedApi +internal inline fun E.transformMatchingElements( + crossinline transform: (M) -> Element, +): E = transform( + transformer { + element { element, transformer -> + val transformed = if (element is M) transform(element) else element + transformed.transformChildren(transformer) + } + }, +) + +internal fun T.transformFieldsWhere( + predicate: (Field) -> Boolean, + transform: (Field) -> Field, +): T = transform( + transformer { + field { field, transformer -> + val transformed = if (predicate(field)) transform(field) else field + transformed.transformChildren(transformer) + } + }, +) + +internal fun T.transformTypeByName( + name: String, + transform: (Type.Custom) -> Type, +): T = transformMatching { type: Type.Custom -> + if (type.name == name) transform(type) else type +} + +internal fun T.renameType(oldName: String, newName: String): T = transformTypeByName(oldName) { it.copy(name = newName) } + +internal fun T.renameField(oldName: Name, newName: Name): T = transformFieldsWhere({ it.name == oldName }) { it.copy(name = newName) } + +internal fun T.renameField(oldName: String, newName: String): T = renameField(Name.of(oldName), Name.of(newName)) + +internal fun T.transformParametersWhere( + predicate: (Parameter) -> Boolean, + transform: (Parameter) -> Parameter, +): T = transform( + transformer { + parameter { parameter, transformer -> + val transformed = if (predicate(parameter)) transform(parameter) else parameter + transformed.transformChildren(transformer) + } + }, +) + +@Suppress("UNCHECKED_CAST") +@PublishedApi +internal fun T.withElements(elements: List): T = ( + when (this) { + is File -> copy(elements = elements) + is Struct -> copy(elements = elements) + is Namespace -> copy(elements = elements) + is Interface -> copy(elements = elements) + is Enum -> copy(elements = elements) + is Main -> this + else -> this + } + ) as T + +@PublishedApi +internal inline fun E.injectBefore( + crossinline produce: (T) -> List, +): E where T : Element, T : HasElements = transformMatchingElements { element -> + val injected = produce(element) + if (injected.isNotEmpty()) element.withElements(injected + element.elements) else element +} + +@PublishedApi +internal inline fun E.injectAfter( + crossinline produce: (T) -> List, +): E where T : Element, T : HasElements = transformMatchingElements { element -> + val injected = produce(element) + if (injected.isNotEmpty()) element.withElements(element.elements + injected) else element +} + +internal fun Element.forEachType(action: (Type) -> Unit) { + transform( + transformer { + type { type, tr -> + action(type) + type.transformChildren(tr) + } + }, + ) +} + +@PublishedApi +internal fun Element.forEachElement(action: (Element) -> Unit) { + transform( + transformer { + element { element, tr -> + action(element) + element.transformChildren(tr) + } + }, + ) +} + +internal fun Element.forEachField(action: (Field) -> Unit) { + transform( + transformer { + field { field, tr -> + action(field) + field.transformChildren(tr) + } + }, + ) +} + +internal fun Element.collectTypes(): List = buildList { + forEachType { add(it) } +} + +internal fun Element.collectCustomTypeNames(): Set = buildSet { + forEachType { type -> + if (type is Type.Custom) add(type.name) + } +} + +inline fun HasElements.findElement(): T? = elements.filterIsInstance().firstOrNull() + +inline fun HasElements.findElement(predicate: (T) -> Boolean): T? = elements.filterIsInstance().firstOrNull(predicate) + +inline fun Element.findAll(): List = buildList { + forEachElement { element -> + if (element is T) add(element) + } +} + +internal inline fun Element.findAllTypes(): List = buildList { + forEachType { type -> + if (type is T) add(type) + } +} diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/emit/IrEmitter.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/emit/IrEmitter.kt new file mode 100644 index 000000000..824fa7e8a --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/emit/IrEmitter.kt @@ -0,0 +1,80 @@ +package community.flock.wirespec.ir.emit + +import arrow.core.NonEmptyList +import community.flock.wirespec.compiler.core.emit.Emitted +import community.flock.wirespec.compiler.core.emit.Emitter +import community.flock.wirespec.compiler.core.emit.Shared +import community.flock.wirespec.compiler.core.parse.ast.AST +import community.flock.wirespec.compiler.core.parse.ast.Channel +import community.flock.wirespec.compiler.core.parse.ast.Definition +import community.flock.wirespec.compiler.core.parse.ast.Endpoint +import community.flock.wirespec.compiler.core.parse.ast.Enum +import community.flock.wirespec.compiler.core.parse.ast.Module +import community.flock.wirespec.compiler.core.parse.ast.Refined +import community.flock.wirespec.compiler.core.parse.ast.Type +import community.flock.wirespec.compiler.core.parse.ast.Union +import community.flock.wirespec.compiler.utils.Logger +import community.flock.wirespec.ir.converter.convertClient +import community.flock.wirespec.ir.converter.convertEndpointClient +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.generator.Generator + +interface IrEmitter : Emitter { + + val shared: Shared? + + val generator: Generator + + override fun emit(ast: AST, logger: Logger): NonEmptyList { + val moduleEmitted = ast.modules.flatMap { m -> + logger.info("Emitting Nodes from ${m.fileUri.value} ") + emit(m, logger) + }.map { file -> Emitted(file.name.value() + "." + extension.value, generator.generate(file)) } + + val allEndpoints = ast.modules.toList().flatMap { it.statements.filterIsInstance() } + return if (allEndpoints.isNotEmpty()) { + val mainClient = emitClient(allEndpoints, logger) + moduleEmitted + Emitted(mainClient.name.value() + "." + extension.value, generator.generate(mainClient)) + } else { + moduleEmitted + } + } + + fun emit(module: Module, logger: Logger): NonEmptyList { + val definitionFiles = module.statements.map { emit(it, module, logger) } + val endpoints = module.statements.toList().filterIsInstance() + val clientFiles = endpoints.map { endpoint -> + logger.info("Emitting Client for endpoint ${endpoint.identifier.value}") + emitEndpointClient(endpoint) + } + return definitionFiles + clientFiles + } + + fun emit(definition: Definition, module: Module, logger: Logger): File = run { + logger.info("Emitting ${definition::class.simpleName} ${definition.identifier.value}") + return when (definition) { + is Type -> emit(definition, module) + is Endpoint -> emit(definition) + is Enum -> emit(definition, module) + is Refined -> emit(definition) + is Union -> emit(definition) + is Channel -> emit(definition) + } + } + + fun emitEndpointClient(endpoint: Endpoint): File = endpoint.convertEndpointClient() + + fun emitClient(endpoints: List, logger: Logger): File { + logger.info("Emitting main Client for ${endpoints.size} endpoints") + return endpoints.convertClient() + } + + fun emit(type: Type, module: Module): File + fun emit(enum: Enum, module: Module): File + fun emit(refined: Refined): File + fun emit(endpoint: Endpoint): File + fun emit(union: Union): File + fun emit(channel: Channel): File + + fun transformTestFile(file: File): File = file +} diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/Generator.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/Generator.kt new file mode 100644 index 000000000..96f8e3ddb --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/Generator.kt @@ -0,0 +1,19 @@ +package community.flock.wirespec.ir.generator + +import community.flock.wirespec.ir.core.Element + +interface Generator { + fun generate(element: Element): String +} + +fun Element.generateJava() = JavaGenerator.generate(this) + +fun Element.generatePython() = PythonGenerator.generate(this) + +fun Element.generateTypeScript() = TypeScriptGenerator.generate(this) + +fun Element.generateKotlin() = KotlinGenerator.generate(this) + +fun Element.generateRust() = RustGenerator.generate(this) + +fun Element.generateScala() = ScalaGenerator.generate(this) diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/JavaGenerator.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/JavaGenerator.kt new file mode 100644 index 000000000..43399ee7f --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/JavaGenerator.kt @@ -0,0 +1,646 @@ +package community.flock.wirespec.ir.generator + +import community.flock.wirespec.ir.core.ArrayIndexCall +import community.flock.wirespec.ir.core.AssertStatement +import community.flock.wirespec.ir.core.Assignment +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.ir.core.Constraint +import community.flock.wirespec.ir.core.Constructor +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Element +import community.flock.wirespec.ir.core.Enum +import community.flock.wirespec.ir.core.EnumReference +import community.flock.wirespec.ir.core.EnumValueCall +import community.flock.wirespec.ir.core.ErrorStatement +import community.flock.wirespec.ir.core.Expression +import community.flock.wirespec.ir.core.Field +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.FlatMapIndexed +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.IfExpression +import community.flock.wirespec.ir.core.Import +import community.flock.wirespec.ir.core.Interface +import community.flock.wirespec.ir.core.ListConcat +import community.flock.wirespec.ir.core.Literal +import community.flock.wirespec.ir.core.LiteralList +import community.flock.wirespec.ir.core.LiteralMap +import community.flock.wirespec.ir.core.Main +import community.flock.wirespec.ir.core.MapExpression +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Namespace +import community.flock.wirespec.ir.core.NotExpression +import community.flock.wirespec.ir.core.NullCheck +import community.flock.wirespec.ir.core.NullLiteral +import community.flock.wirespec.ir.core.NullableEmpty +import community.flock.wirespec.ir.core.NullableGet +import community.flock.wirespec.ir.core.NullableMap +import community.flock.wirespec.ir.core.NullableOf +import community.flock.wirespec.ir.core.Package +import community.flock.wirespec.ir.core.Parameter +import community.flock.wirespec.ir.core.Precision +import community.flock.wirespec.ir.core.PrintStatement +import community.flock.wirespec.ir.core.RawElement +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.ReturnStatement +import community.flock.wirespec.ir.core.Statement +import community.flock.wirespec.ir.core.StringTemplate +import community.flock.wirespec.ir.core.Struct +import community.flock.wirespec.ir.core.Switch +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.TypeDescriptor +import community.flock.wirespec.ir.core.TypeParameter +import community.flock.wirespec.ir.core.Union +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.Function as AstFunction + +object JavaGenerator : Generator { + override fun generate(element: Element): String = when (element) { + is File -> { + val emitter = JavaEmitter(element) + emitter.emitFile() + } + else -> { + val emitter = JavaEmitter(File(Name.of(""), listOf(element))) + emitter.emitFile() + } + } +} + +private class JavaEmitter(val file: File) { + + fun emitFile(): String { + val packages = file.elements.filterIsInstance() + val imports = file.elements.filterIsInstance() + val otherElements = file.elements.filter { it !is Package && it !is Import } + + val packagesStr = packages.joinToString("") { it.emit(0) } + val importsStr = imports.joinToString("") { it.emit(0) } + val elementsStr = otherElements.joinToString("") { it.emit(0, parents = emptyList()) } + + return "$packagesStr$importsStr$elementsStr".removeEmptyLines() + } + + private fun String.removeEmptyLines(): String = lines().filter { it.isNotEmpty() }.joinToString("\n").plus("\n") + + private fun String.indentCode(level: Int): String { + if (level <= 0) return this + val prefix = " ".repeat(level * 2) + return this.lines().joinToString("\n") { line -> + if (line.isEmpty()) line else prefix + line + } + } + + private fun Element.emit(indent: Int, isStatic: Boolean = true, parents: List): String = when (this) { + is Package -> emit(indent) + is Import -> emit(indent) + is Struct -> { + emit(indent, parents) + } + is AstFunction -> { + val lastParent = parents.lastOrNull() + val isInterface = lastParent is Interface + val isStaticContainer = lastParent is Namespace + val isInterfaceBody = isInterface && body.isNotEmpty() + val isInsideStruct = lastParent is Struct + val shouldBeStatic = (isStatic || isStaticContainer || this.isStatic) && !isInterface && (!isInsideStruct || this.isStatic) + val overridePrefix = if (isOverride) "@Override\n" else "" + + if (indent == 0) { + emit(indent, isStatic = true, modifier = "public") + } else if (isInterfaceBody) { + if (this.isStatic) { + emit(indent, isStatic = true, modifier = "public") + } else { + emit(indent, isStatic = false, modifier = "${overridePrefix}default") + } + } else { + val visibility = if (indent == 1) "public" else "" + val staticStr = if (shouldBeStatic) "static" else "" + val modParts = listOf(visibility, staticStr).filter { it.isNotEmpty() } + val modSuffix = modParts.joinToString(" ") + val fullModifier = if (isOverride) { + if (modSuffix.isNotEmpty()) "$overridePrefix$modSuffix" else "@Override" + } else { + modSuffix + } + emit(indent, isStatic = shouldBeStatic, modifier = fullModifier) + } + } + is Namespace -> emit(indent, parents) + is Interface -> emit(indent, parents) + is Union -> emit(indent, parents) + is Enum -> emit(indent) + is Main -> { + val staticContent = statics.joinToString("") { it.emit(1, true, parents) } + val content = body.joinToString("") { it.emit(1) } + "public class ${file.name.pascalCase()} {\n$staticContent" + + " public static void main(String[] args) {\n$content }\n}\n" + } + is File -> elements.joinToString("") { it.emit(indent, isStatic, parents) } + is RawElement -> code.indentCode(indent) + } + + private fun Package.emit(indent: Int): String = "package $path;\n\n".indentCode(indent) + + private fun Import.emit(indent: Int): String = "import $path.${type.name};\n".indentCode(indent) + + private fun Namespace.emit(indent: Int, parents: List): String { + val extStr = extends?.let { " extends ${it.emitGenerics()}" } ?: "" + val content = elements.joinToString("") { it.emit(1, isStatic = true, parents = parents + this) } + return "public interface ${name.pascalCase()}$extStr {\n$content${"}".indentCode(0)}\n\n".indentCode(indent) + } + + private fun Interface.emit(indent: Int, parents: List): String { + val isInsideStaticOrInterface = parents.any { it is Namespace || it is Interface } + val publicStr = if (indent == 0 || isInsideStaticOrInterface) "public " else "" + val sealedStr = if (isSealed) "sealed " else "" + val typeParamsStr = if (typeParameters.isNotEmpty()) "<${typeParameters.joinToString(", ") { it.emit() }}>" else "" + val extStr = if (extends.isNotEmpty()) " extends ${extends.joinToString(", ") { it.emitGenerics() }}" else "" + val fieldsContent = fields.joinToString("") { field -> + "${field.type.emitGenerics()} ${field.name.value()}();\n".indentCode(1) + } + val elementsContent = elements.joinToString("") { it.emit(1, isStatic = false, parents = parents + this) } + val content = fieldsContent + elementsContent + return if (content.isEmpty()) { + "$publicStr${sealedStr}interface ${name.pascalCase()}$typeParamsStr$extStr {\n}\n\n".indentCode(indent) + } else { + "$publicStr${sealedStr}interface ${name.pascalCase()}$typeParamsStr$extStr {\n$content${"}".indentCode(0)}\n\n".indentCode(indent) + } + } + + private fun Union.emit(indent: Int, parents: List): String { + val typeParamsStr = if (typeParameters.isNotEmpty()) "<${typeParameters.joinToString(", ") { it.emit() }}>" else "" + val extendsName = extends?.name + val ext = listOfNotNull(extends?.emitGenerics()) + + parents.filterIsInstance().filter { it.name.pascalCase() != extendsName }.map { it.name.pascalCase() } + + val extStr = if (ext.isEmpty()) "" else " extends ${ext.distinct().joinToString(", ")}" + val permitsStr = if (members.isEmpty()) "" else " permits ${members.joinToString(", ") { it.name }}" + return "public sealed interface ${name.pascalCase()}$typeParamsStr$extStr$permitsStr {}\n\n".indentCode(indent) + } + + private fun Enum.emit(indent: Int): String { + val entriesStr = entries.joinToString(",\n") { entry -> + val e = if (entry.values.isEmpty()) { + entry.name.value() + } else { + "${entry.name.value()}(${entry.values.joinToString(", ")})" + } + e.indentCode(indent + 1) + } + val implStr = extends?.let { " implements ${it.emitGenerics()}" } ?: "" + + val hasContent = fields.isNotEmpty() || constructors.isNotEmpty() || elements.isNotEmpty() + val terminator = if (hasContent) ";\n" else "" + + val fieldsStr = fields.joinToString("\n") { "public final ${it.type.emitGenerics()} ${it.name.value()};".indentCode(indent + 1) } + val constructorsStr = constructors.joinToString("\n") { it.emit(name.pascalCase(), fields, indent + 1, false, "") } + val functionsStr = elements.filterIsInstance().joinToString("\n") { + val isOverride = it.isOverride || it.name.camelCase() == "toString" || it.name.camelCase() == "getLabel" + val overridePrefix = if (isOverride) "@Override\n${"".indentCode(indent + 1)}" else "" + val visibility = "public" + val staticStr = if (it.isStatic) "static" else "" + val modParts = listOf(visibility, staticStr).filter { it.isNotEmpty() } + val fullModifier = "$overridePrefix${modParts.joinToString(" ")}" + + it.emit(indent + 1, it.isStatic, fullModifier).trimEnd() + } + + val content = listOf(fieldsStr, constructorsStr, functionsStr).filter { it.isNotEmpty() }.joinToString("\n") + val sep = if (content.isNotEmpty()) "\n" else "" + + return ("public enum ${name.pascalCase()}$implStr {\n$entriesStr$terminator$sep$content\n${"}".indentCode(indent)}\n".indentCode(indent)).trimEnd() + } + + private fun Struct.emit(indent: Int, parents: List): String { + val implStr = if (interfaces.isEmpty()) "" else " implements ${interfaces.map { it.emitGenerics() }.distinct().joinToString(", ")}" + + val isInsideStaticOrInterface = parents.any { it is Namespace || it is Interface } + val typeModifier = when { + indent == 0 -> "public record" + isInsideStaticOrInterface -> "public static record" + else -> "record" + } + + val customConstructors = constructors.joinToString("") { it.emit(name.pascalCase(), fields, 1, isRecord = true) } + val nestedContent = elements.joinToString("") { it.emit(1, isStatic = true, parents = parents + this) } + + val params = fields.joinToString(",\n") { "${it.type.emitGenerics()} ${it.name.value().sanitize()}".indentCode(1) } + val paramsStr = if (fields.isEmpty()) " ()" else " (\n$params\n)" + + return "$typeModifier ${name.pascalCase()}$paramsStr$implStr {\n$customConstructors$nestedContent};\n\n".indentCode(indent) + } + + private fun Constructor.emit(structName: String, structFields: List, indent: Int, isRecord: Boolean, modifier: String = "public"): String { + val params = parameters.joinToString(", ") { it.emit(0) } + val isDelegating = body.any { it is ConstructorStatement } + val prefix = if (modifier.isEmpty()) "" else "$modifier " + + if (isRecord && !isDelegating) { + val assignments = body.filterIsInstance().associate { + it.name.value().removePrefix("this.") to it.value.emit() + } + val constructorArgs = structFields.map { field -> + assignments[field.name.value()] ?: "null" + } + val otherStatements = body.filter { it !is Assignment || it.name.value().removePrefix("this.") !in structFields.map { f -> f.name.value() } } + val bodyContent = ( + listOf("this(${constructorArgs.joinToString(", ")});\n") + + otherStatements.map { it.emit(0) } + ) + .joinToString("") { it.indentCode(1) } + + return "${prefix}$structName($params) {\n$bodyContent}\n".indentCode(indent) + } + + val bodyContent = body.joinToString("") { it.emit(1, isInsideConstructor = true) } + + return if (isRecord && !isDelegating) { + "${prefix}$structName {\n$bodyContent}\n".indentCode(indent) + } else { + "${prefix}$structName($params) {\n$bodyContent}\n".indentCode(indent) + } + } + + private fun AstFunction.emit(indent: Int, isStatic: Boolean, modifier: String): String { + val rType = if (isAsync) { + "java.util.concurrent.CompletableFuture<${returnType?.emitGenerics() ?: "Void"}>" + } else { + returnType?.takeIf { it != Type.Unit }?.emitGenerics() ?: "void" + } + val params = parameters.joinToString(", ") { it.emit(0) } + val typeParamsStr = if (typeParameters.isNotEmpty()) { + "<${typeParameters.joinToString(", ") { it.emit() }}> " + } else { + "" + } + val prefix = listOfNotNull( + "public".takeIf { indent == 1 && !modifier.contains("public") }, + "static".takeIf { isStatic && !modifier.contains("static") }, + modifier.takeIf { it.isNotEmpty() }, + ).joinToString(" ") + + val fullPrefix = if (prefix.isEmpty()) "" else "$prefix " + + return if (body.isEmpty()) { + "$fullPrefix$typeParamsStr$rType ${name.camelCase()}($params);\n".indentCode(indent) + } else { + val content = body.joinToString("") { it.emit(1) } + "$fullPrefix$typeParamsStr$rType ${name.camelCase()}($params) {\n$content${"}".indentCode(0)}\n\n".indentCode(indent) + } + } + + private fun Parameter.emit(indent: Int): String = "${type.emitGenerics()} ${name.camelCase().sanitize()}".indentCode(indent) + + private fun TypeParameter.emit(): String { + val typeStr = type.emitGenerics() + return if (extends.isEmpty()) { + typeStr + } else { + "$typeStr extends ${extends.joinToString(" & ") { it.emitGenerics() }}" + } + } + + private fun Type.emit(): String = when (this) { + is Type.Integer -> when (precision) { + Precision.P32 -> "Integer" + Precision.P64 -> "Long" + } + is Type.Number -> when (precision) { + Precision.P32 -> "Float" + Precision.P64 -> "Double" + } + Type.Any -> "Object" + Type.String -> "String" + Type.Bytes -> "byte[]" + Type.Boolean -> "Boolean" + Type.Unit -> "Void" + Type.Wildcard -> "?" + Type.Reflect -> "Type" + is Type.Array -> "java.util.List" + is Type.Dict -> "java.util.Map" + is Type.Custom -> name + is Type.Nullable -> "java.util.Optional<${type.emitGenerics()}>" + is Type.IntegerLiteral -> "Integer" + is Type.StringLiteral -> "String" + } + + private fun Type.emitGenerics(): String = when (this) { + is Type.Array -> "${emit()}<${elementType.emitGenerics()}>" + is Type.Dict -> "${emit()}<${keyType.emitGenerics()}, ${valueType.emitGenerics()}>" + is Type.Custom -> { + if (generics.isEmpty()) { + emit() + } else { + "${emit()}<${generics.joinToString(", ") { it.emitGenerics() }}>" + } + } + is Type.Nullable -> "java.util.Optional<${type.emitGenerics()}>" + else -> emit() + } + + private fun Statement.emit(indent: Int, isInsideConstructor: Boolean = false): String = when (this) { + is PrintStatement -> "System.out.println(${expression.emit()});\n".indentCode(indent) + is ReturnStatement -> "return ${expression.emit()};\n".indentCode(indent) + is ConstructorStatement -> { + if (type == Type.Unit) { + "null;\n".indentCode(indent) + } else { + val allArgs = namedArguments.map { it.value.emit() } + val argsStr = when { + allArgs.isEmpty() -> "()" + allArgs.size == 1 -> "(${allArgs.first()})" + else -> "(\n${allArgs.joinToString(",\n") { it.indentCode(1) }}\n)" + } + if (isInsideConstructor) { + "this$argsStr;\n".indentCode(indent) + } else { + "new ${type.emitGenerics()}$argsStr;\n".indentCode(indent) + } + } + } + is Literal -> "${emit()};\n".indentCode(indent) + is LiteralList -> "${emit()};\n".indentCode(indent) + is LiteralMap -> "${emit()};\n".indentCode(indent) + is Assignment -> { + val expr = (value as? ConstructorStatement)?.let { constructorStmt -> + if (constructorStmt.type == Type.Unit) { + "null" + } else { + val allArgs = constructorStmt.namedArguments.map { it.value.emit() } + val argsStr = when { + allArgs.isEmpty() -> "()" + allArgs.size == 1 -> "(${allArgs.first()})" + else -> "(\n${allArgs.joinToString(",\n") { it.indentCode(1) }}\n)" + } + "new ${constructorStmt.type.emitGenerics()}$argsStr" + } + } ?: value.emit() + if (isProperty) { + "${name.value().sanitize()} = $expr;\n".indentCode(indent) + } else { + "final var ${name.camelCase().sanitize()} = $expr;\n".indentCode(indent) + } + } + is ErrorStatement -> "throw new IllegalStateException(${message.emit()});\n".indentCode(indent) + is AssertStatement -> "assert ${expression.emit()} : \"$message\";\n".indentCode(indent) + is Switch -> { + val isPatternSwitch = cases.any { it.type != null } + if (isPatternSwitch) { + // Use if-else chain with instanceof for pattern matching (Java 16+) + val casesStr = cases.mapIndexed { index, case -> + val bodyStr = case.body.joinToString("") { it.emit(1) } + val typeStr = case.type?.emitGenerics() ?: "Object" + val varName = variable?.camelCase() ?: "_" + val prefix = if (index == 0) "if" else " else if" + "$prefix (${expression.emit()} instanceof $typeStr $varName) {\n$bodyStr}" + }.joinToString("") + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(1) } + " else {\n$bodyStr}" + } ?: "" + "$casesStr$defaultStr\n".indentCode(indent) + } else { + // Regular switch with arrow syntax + val casesStr = cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(1) } + "case ${case.value.emit()} -> {\n$bodyStr}\n".indentCode(indent + 1) + } + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(1) } + "default -> {\n$bodyStr}\n".indentCode(indent + 1) + } ?: "" + "switch (${expression.emit()}) {\n$casesStr$defaultStr}\n".indentCode(indent) + } + } + is RawExpression -> "$code;\n".indentCode(indent) + is NullLiteral -> "null;\n".indentCode(indent) + is NullableEmpty -> "java.util.Optional.empty();\n".indentCode(indent) + is VariableReference -> "${name.camelCase().sanitize()};\n".indentCode(indent) + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${field.value().sanitize()}();\n".indentCode(indent) + } + is FunctionCall -> { + val typeArgsStr = if (typeArguments.isNotEmpty()) "<${typeArguments.joinToString(", ") { it.emitGenerics() }}>" else "" + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + val awaitSuffix = if (isAwait) ".join()" else "" + "$receiverStr$typeArgsStr${name.value().sanitize()}(${arguments.values.joinToString(", ") { it.emit() }})$awaitSuffix;\n".indentCode(indent) + } + is ArrayIndexCall -> if (caseSensitive) { + "${receiver.emit()}.get(${index.emit()});\n".indentCode(indent) + } else { + "${receiver.emit()}.entrySet().stream().filter(e -> e.getKey().equalsIgnoreCase(${index.emit()})).findFirst().map(java.util.Map.Entry::getValue).orElse(null);\n".indentCode(indent) + } + is EnumReference -> "${enumType.emitGenerics()}.${entry.value()};\n".indentCode(indent) + is EnumValueCall -> "${expression.emit()}.name();\n".indentCode(indent) + is BinaryOp -> when { + operator == BinaryOp.Operator.EQUALS && right is NullLiteral -> "(${left.emit()} == null);\n".indentCode(indent) + operator == BinaryOp.Operator.NOT_EQUALS && right is NullLiteral -> "(${left.emit()} != null);\n".indentCode(indent) + operator == BinaryOp.Operator.EQUALS && left is NullLiteral -> "(null == ${right.emit()});\n".indentCode(indent) + operator == BinaryOp.Operator.NOT_EQUALS && left is NullLiteral -> "(null != ${right.emit()});\n".indentCode(indent) + operator == BinaryOp.Operator.EQUALS && isPrimitiveLiteral() -> "(${left.emit()} == ${right.emit()});\n".indentCode(indent) + operator == BinaryOp.Operator.NOT_EQUALS && isPrimitiveLiteral() -> "(${left.emit()} != ${right.emit()});\n".indentCode(indent) + operator == BinaryOp.Operator.EQUALS -> "(${left.emit()}.equals(${right.emit()}));\n".indentCode(indent) + operator == BinaryOp.Operator.NOT_EQUALS -> "(!${left.emit()}.equals(${right.emit()}));\n".indentCode(indent) + else -> "(${left.emit()} ${operator.toJava()} ${right.emit()});\n".indentCode(indent) + } + is TypeDescriptor -> error("TypeDescriptor should be transformed before reaching the generator") + is NullCheck -> "${emit()};\n".indentCode(indent) + is NullableMap -> "${emit()};\n".indentCode(indent) + is NullableOf -> "${emit()};\n".indentCode(indent) + is NullableGet -> "${emit()};\n".indentCode(indent) + is Constraint.RegexMatch -> "${emit()};\n".indentCode(indent) + is Constraint.BoundCheck -> "${emit()};\n".indentCode(indent) + is NotExpression -> "!${expression.emit()};\n".indentCode(indent) + is IfExpression -> "${emit()};\n".indentCode(indent) + is MapExpression -> "${emit()};\n".indentCode(indent) + is FlatMapIndexed -> "${emit()};\n".indentCode(indent) + is ListConcat -> "${emit()};\n".indentCode(indent) + is StringTemplate -> "${emit()};\n".indentCode(indent) + } + + private fun BinaryOp.Operator.toJava(): String = when (this) { + BinaryOp.Operator.PLUS -> "+" + BinaryOp.Operator.EQUALS -> "==" + BinaryOp.Operator.NOT_EQUALS -> "!=" + } + + private fun BinaryOp.isPrimitiveLiteral(): Boolean = left is Literal && + ((left as Literal).type is Type.Integer || (left as Literal).type is Type.Number || (left as Literal).type is Type.Boolean) || + right is Literal && + ((right as Literal).type is Type.Integer || (right as Literal).type is Type.Number || (right as Literal).type is Type.Boolean) + + private fun Expression.emit(): String = when (this) { + is ConstructorStatement -> { + if (type == Type.Unit) { + "null" + } else { + val allArgs = namedArguments.map { it.value.emit() } + val argsStr = when { + allArgs.isEmpty() -> "()" + allArgs.size == 1 -> "(${allArgs.first()})" + else -> "(\n${allArgs.joinToString(",\n") { it.indentCode(1) }}\n)" + } + "new ${type.emitGenerics()}$argsStr" + } + } + is Literal -> emit() + is LiteralList -> emit() + is LiteralMap -> emit() + is RawExpression -> code + is NullLiteral -> "null" + is NullableEmpty -> "java.util.Optional.empty()" + is VariableReference -> name.camelCase().sanitize() + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${field.value().sanitize()}()" + } + is FunctionCall -> { + val typeArgsStr = if (typeArguments.isNotEmpty()) "<${typeArguments.joinToString(", ") { it.emitGenerics() }}>" else "" + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + val awaitSuffix = if (isAwait) ".join()" else "" + "$receiverStr$typeArgsStr${name.value().sanitize()}(${arguments.values.joinToString(", ") { it.emit() }})$awaitSuffix" + } + is ArrayIndexCall -> if (caseSensitive) { + "${receiver.emit()}.get(${index.emit()})" + } else { + "${receiver.emit()}.entrySet().stream().filter(e -> e.getKey().equalsIgnoreCase(${index.emit()})).findFirst().map(java.util.Map.Entry::getValue).orElse(null)" + } + is EnumReference -> "${enumType.emitGenerics()}.${entry.value()}" + is EnumValueCall -> "${expression.emit()}.name()" + is BinaryOp -> when { + operator == BinaryOp.Operator.EQUALS && right is NullLiteral -> "(${left.emit()} == null)" + operator == BinaryOp.Operator.NOT_EQUALS && right is NullLiteral -> "(${left.emit()} != null)" + operator == BinaryOp.Operator.EQUALS && left is NullLiteral -> "(null == ${right.emit()})" + operator == BinaryOp.Operator.NOT_EQUALS && left is NullLiteral -> "(null != ${right.emit()})" + operator == BinaryOp.Operator.EQUALS && isPrimitiveLiteral() -> "(${left.emit()} == ${right.emit()})" + operator == BinaryOp.Operator.NOT_EQUALS && isPrimitiveLiteral() -> "(${left.emit()} != ${right.emit()})" + operator == BinaryOp.Operator.EQUALS -> "(${left.emit()}.equals(${right.emit()}))" + operator == BinaryOp.Operator.NOT_EQUALS -> "(!${left.emit()}.equals(${right.emit()}))" + else -> "(${left.emit()} ${operator.toJava()} ${right.emit()})" + } + is TypeDescriptor -> error("TypeDescriptor should be transformed before reaching the generator") + is NullCheck -> { + val orElse = when (val alt = alternative) { + is ErrorStatement -> ".orElseThrow(() -> new IllegalStateException(${alt.message.emit()}))" + null -> "" + else -> ".orElse(${alt.emit()})" + } + "java.util.Optional.ofNullable(${expression.emit()}).map(it -> ${body.emit()})$orElse" + } + is NullableMap -> { + val orElse = when (val alt = alternative) { + is ErrorStatement -> "orElseThrow(() -> new IllegalStateException(${alt.message.emit()}))" + else -> "orElse(${alternative.emit()})" + } + "${expression.emit()}.map(it -> ${body.emit()}).$orElse" + } + is NullableOf -> "java.util.Optional.of(${expression.emit()})" + is NullableGet -> "${expression.emit()}.get()" + is Constraint.RegexMatch -> "java.util.regex.Pattern.compile(\"${pattern.replace("\\", "\\\\")}\").matcher(${value.emit()}).find()" + is Constraint.BoundCheck -> { + val checks = listOfNotNull( + min?.let { "$it <= ${value.emit()}" }, + max?.let { "${value.emit()} <= $it" }, + ).joinToString(" && ").ifEmpty { "true" } + checks + } + is ErrorStatement -> "throw new IllegalStateException(${message.emit()});" + is AssertStatement -> throw IllegalArgumentException("AssertStatement cannot be an expression in Java") + is Switch -> throw IllegalArgumentException("Switch cannot be an expression in Java") + is Assignment -> throw IllegalArgumentException("Assignment cannot be an expression in Java") + is PrintStatement -> throw IllegalArgumentException("PrintStatement cannot be an expression in Java") + is ReturnStatement -> throw IllegalArgumentException("ReturnStatement cannot be an expression in Java") + is NotExpression -> "!${expression.emit()}" + is IfExpression -> "(${condition.emit()} ? ${thenExpr.emit()} : ${elseExpr.emit()})" + is MapExpression -> "${receiver.emit()}.stream().map(${variable.camelCase()} -> ${body.emit()}).toList()" + is FlatMapIndexed -> { + val recv = receiver.emit() + val bodyWithSubstitution = body.emitWithSubstitution(elementVar, "$recv.get(${indexVar.camelCase()})") + "java.util.stream.IntStream.range(0, $recv.size()).mapToObj(${indexVar.camelCase()} -> $bodyWithSubstitution).flatMap(java.util.Collection::stream).toList()" + } + is ListConcat -> when { + lists.isEmpty() -> "java.util.List.of()" + lists.size == 1 -> lists.single().emit() + else -> "java.util.stream.Stream.of(${lists.joinToString(", ") { it.emit() }}).flatMap(java.util.Collection::stream).toList()" + } + is StringTemplate -> parts.joinToString(" + ") { + when (it) { + is StringTemplate.Part.Text -> "\"${it.value}\"" + is StringTemplate.Part.Expr -> it.expression.emit() + } + } + } + + private fun Expression.emitWithSubstitution(varName: Name, replacement: String): String = when (this) { + is VariableReference -> if (name == varName) replacement else emit() + is FunctionCall -> { + val recv = receiver?.emitWithSubstitution(varName, replacement) + val args = arguments.values.map { it.emitWithSubstitution(varName, replacement) } + val typeArgsStr = if (typeArguments.isNotEmpty()) "<${typeArguments.joinToString(", ") { it.emitGenerics() }}>" else "" + val receiverStr = recv?.let { "$it." } ?: "" + "$receiverStr$typeArgsStr${name.value().sanitize()}(${args.joinToString(", ")})" + } + is FieldCall -> { + val recv = receiver?.emitWithSubstitution(varName, replacement) ?: "" + val dot = if (recv.isNotEmpty()) "." else "" + "$recv$dot${field.value().sanitize()}()" + } + is NotExpression -> "!${expression.emitWithSubstitution(varName, replacement)}" + is IfExpression -> "(${condition.emitWithSubstitution(varName, replacement)} ? ${thenExpr.emitWithSubstitution(varName, replacement)} : ${elseExpr.emitWithSubstitution(varName, replacement)})" + is MapExpression -> "${receiver.emitWithSubstitution(varName, replacement)}.stream().map(${variable.camelCase()} -> ${body.emitWithSubstitution(varName, replacement)}).toList()" + is LiteralList -> { + if (values.isEmpty()) { + "java.util.List.<${type.emit()}>of()" + } else { + val list = values.map { it.emitWithSubstitution(varName, replacement) }.joinToString(", ") + "java.util.List.of($list)" + } + } + is StringTemplate -> parts.joinToString(" + ") { + when (it) { + is StringTemplate.Part.Text -> "\"${it.value}\"" + is StringTemplate.Part.Expr -> it.expression.emitWithSubstitution(varName, replacement) + } + } + else -> emit() + } + + private fun LiteralList.emit(): String { + if (values.isEmpty()) return "java.util.List.<${type.emit()}>of()" + val list = values.joinToString(", ") { it.emit() } + return "java.util.List.of($list)" + } + + private fun LiteralMap.emit(): String { + if (values.isEmpty()) return "java.util.Collections.emptyMap()" + val map = values.entries.joinToString(", ") { + "java.util.Map.entry(${Literal(it.key, keyType).emit()}, ${it.value.emit()})" + } + return "java.util.Map.ofEntries($map)" + } + + private fun Literal.emit(): String = when { + type is Type.String -> "\"$value\"" + value is Long -> "${value}L" + else -> value.toString() + } +} + +private fun String.sanitize(): String = if (reservedKeywords.contains(this)) "_$this" else this + +private val reservedKeywords = setOf( + "abstract", "continue", "for", "new", "switch", + "assert", "default", "if", "package", "synchronized", + "boolean", "do", "goto", "private", "this", + "break", "double", "implements", "protected", "throw", + "byte", "else", "import", "public", "throws", + "case", "enum", "instanceof", "return", "transient", + "catch", "extends", "int", "short", "try", + "char", "final", "interface", "static", "void", + "class", "finally", "long", "strictfp", "volatile", + "const", "float", "native", "super", "while", + "true", "false", "null", +) diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/KotlinGenerator.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/KotlinGenerator.kt new file mode 100644 index 000000000..cfab76403 --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/KotlinGenerator.kt @@ -0,0 +1,625 @@ +package community.flock.wirespec.ir.generator + +import community.flock.wirespec.ir.core.ArrayIndexCall +import community.flock.wirespec.ir.core.AssertStatement +import community.flock.wirespec.ir.core.Assignment +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.ir.core.Constraint +import community.flock.wirespec.ir.core.Constructor +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Element +import community.flock.wirespec.ir.core.Enum +import community.flock.wirespec.ir.core.EnumReference +import community.flock.wirespec.ir.core.EnumValueCall +import community.flock.wirespec.ir.core.ErrorStatement +import community.flock.wirespec.ir.core.Expression +import community.flock.wirespec.ir.core.Field +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.FlatMapIndexed +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.IfExpression +import community.flock.wirespec.ir.core.Import +import community.flock.wirespec.ir.core.Interface +import community.flock.wirespec.ir.core.ListConcat +import community.flock.wirespec.ir.core.Literal +import community.flock.wirespec.ir.core.LiteralList +import community.flock.wirespec.ir.core.LiteralMap +import community.flock.wirespec.ir.core.Main +import community.flock.wirespec.ir.core.MapExpression +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Namespace +import community.flock.wirespec.ir.core.NotExpression +import community.flock.wirespec.ir.core.NullCheck +import community.flock.wirespec.ir.core.NullLiteral +import community.flock.wirespec.ir.core.NullableEmpty +import community.flock.wirespec.ir.core.NullableGet +import community.flock.wirespec.ir.core.NullableMap +import community.flock.wirespec.ir.core.NullableOf +import community.flock.wirespec.ir.core.Package +import community.flock.wirespec.ir.core.Parameter +import community.flock.wirespec.ir.core.Precision +import community.flock.wirespec.ir.core.PrintStatement +import community.flock.wirespec.ir.core.RawElement +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.ReturnStatement +import community.flock.wirespec.ir.core.Statement +import community.flock.wirespec.ir.core.StringTemplate +import community.flock.wirespec.ir.core.Struct +import community.flock.wirespec.ir.core.Switch +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.TypeDescriptor +import community.flock.wirespec.ir.core.TypeParameter +import community.flock.wirespec.ir.core.Union +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.Function as AstFunction + +object KotlinGenerator : Generator { + override fun generate(element: Element): String = when (element) { + is File -> { + val emitter = KotlinEmitter(element) + emitter.emitFile() + } + + else -> { + val emitter = KotlinEmitter(File(Name.of(""), listOf(element))) + emitter.emitFile() + } + } +} + +private class KotlinEmitter(val file: File) { + + fun emitFile(): String { + val packages = file.elements.filterIsInstance() + val imports = file.elements.filterIsInstance() + val otherElements = file.elements.filter { it !is Package && it !is Import } + + val packagesStr = packages.joinToString("") { it.emit(0) } + val importsStr = imports.joinToString("") { it.emit(0) } + val elementsStr = otherElements.joinToString("") { it.emit(0, parents = emptyList()) } + + return "$packagesStr\n$importsStr\n$elementsStr".removeEmptyLines() + } + + private fun String.removeEmptyLines(): String = lines().filter { it.isNotEmpty() }.joinToString("\n").plus("\n") + + private fun String.indentCode(level: Int): String { + if (level <= 0) return this + val prefix = " ".repeat(level * 2) + return this.lines().joinToString("\n") { line -> + if (line.isEmpty()) line else prefix + line + } + } + + private fun Element.emit(indent: Int, isStatic: Boolean = false, parents: List): String = when (this) { + is Package -> emit(indent) + is Import -> emit(indent) + is Struct -> emit(indent, parents) + is AstFunction -> emit(indent, parents) + is Namespace -> emit(indent, parents) + is Interface -> emit(indent, parents) + is Union -> emit(indent, parents) + is Enum -> emit(indent) + is Main -> { + val staticContent = statics.joinToString("") { it.emit(indent, false, parents) } + val content = body.joinToString("") { it.emit(1) } + val modifier = if (isAsync) "suspend " else "" + "$staticContent${"${modifier}fun main() {\n$content}\n\n".indentCode(indent)}" + } + is File -> elements.joinToString("") { it.emit(indent, isStatic, parents) } + is RawElement -> "$code\n".indentCode(indent) + } + + private fun Package.emit(indent: Int): String = "package $path\n\n".indentCode(indent) + + private fun Import.emit(indent: Int): String = "import $path.${type.name}\n".indentCode(indent) + + private fun Namespace.emit(indent: Int, parents: List): String { + val extStr = extends?.let { " : ${it.emitGenerics()}" } ?: "" + val content = elements.joinToString("") { it.emit(indent + 1, isStatic = true, parents = parents + this) } + return "object ${name.pascalCase()}$extStr {\n$content${"}".indentCode(0)}\n\n".indentCode(indent) + } + + private fun Interface.emit(indent: Int, parents: List): String { + val sealedStr = if (isSealed) "sealed " else "" + val typeParamsStr = + if (typeParameters.isNotEmpty()) "<${typeParameters.joinToString(", ") { it.emit() }}>" else "" + val extStr = if (extends.isNotEmpty()) " : ${extends.joinToString(", ") { it.emitGenerics() }}" else "" + val fieldsContent = fields.joinToString("") { field -> + val overridePrefix = if (field.isOverride) "override " else "" + "${overridePrefix}val ${field.name.value()}: ${field.type.emitGenerics()}\n".indentCode(indent + 1) + } + val elementsContent = elements.joinToString("") { it.emit(indent + 1, isStatic = false, parents = parents + this) } + val content = fieldsContent + elementsContent + return if (content.isEmpty()) { + "${sealedStr}interface ${name.pascalCase()}$typeParamsStr$extStr\n\n".indentCode(indent) + } else { + "${sealedStr}interface ${name.pascalCase()}$typeParamsStr$extStr {\n$content${"}".indentCode(0)}\n\n".indentCode(indent) + } + } + + private fun Union.emit(indent: Int, parents: List): String { + val typeParamsStr = if (typeParameters.isNotEmpty()) "<${typeParameters.joinToString(", ") { it.emit() }}>" else "" + val extStr = extends?.let { " : ${it.emitGenerics()}" } ?: "" + return "sealed interface ${name.pascalCase()}$typeParamsStr$extStr\n\n".indentCode(indent) + } + + private fun Enum.emit(indent: Int): String { + val entriesStr = entries.joinToString(",\n") { entry -> + val e = if (entry.values.isEmpty()) { + entry.name.value() + } else { + "${entry.name.value()}(${entry.values.joinToString(", ")})" + } + e.indentCode(indent + 1) + } + val hasContent = fields.isNotEmpty() || constructors.isNotEmpty() || elements.isNotEmpty() + val terminator = if (hasContent) ";\n" else "" + + val constructorParamsStr = if (fields.isNotEmpty()) { + " (${fields.joinToString(", ") { "${if (it.isOverride) "override " else ""}val ${it.name.value()}: ${it.type.emitGenerics()}" }})" + } else { + "" + } + + val implStr = + extends?.let { "${if (constructorParamsStr.isNotEmpty()) "" else " "}: ${it.emitGenerics()}" } ?: "" + + val functionsStr = elements.filterIsInstance().joinToString("\n") { + val overridePrefix = if (it.isOverride || it.name.camelCase() == "toString") "override " else "" + it.emitAsMethod(indent + 1, overridePrefix) + } + + val content = listOf(functionsStr).filter { it.isNotEmpty() }.joinToString("\n") + val sep = if (content.isNotEmpty()) "\n" else "" + + return ( + "enum class ${name.pascalCase()}$constructorParamsStr$implStr {\n$entriesStr$terminator$sep$content\n${ + "}".indentCode( + indent, + ) + }".indentCode(indent) + ).trimEnd() + } + + private fun AstFunction.emitAsMethod(indent: Int, prefix: String): String { + val rType = returnType?.takeIf { it != Type.Unit }?.emitGenerics() ?: "Unit" + val params = parameters.joinToString(", ") { it.emit(0) } + return if (body.isEmpty()) { + "${prefix}fun ${name.camelCase()}($params): $rType".indentCode(indent) + } else { + val content = body.joinToString("") { it.emit(1) } + "${prefix}fun ${name.camelCase()}($params): $rType {\n$content${"}".indentCode(0)}\n".indentCode(indent) + } + } + + private fun Struct.emit(indent: Int, parents: List): String { + val implStr = if (interfaces.isEmpty()) "" else " : ${interfaces.map { it.emitGenerics() }.distinct().joinToString(", ")}" + + val nestedContent = elements.joinToString("") { it.emit(indent + 1, isStatic = true, parents = parents + this) } + val customConstructors = constructors.joinToString("") { it.emitKotlin(fields, indent + 1) } + + if (constructors.size == 1 && constructors.single().parameters.isEmpty()) { + val constructor = constructors.single() + val assignments = constructor.body.filterIsInstance() + val fieldProperties = fields.joinToString("\n") { field -> + val assignment = assignments.find { it.name.camelCase() == field.name.value() } + val valueStr = assignment?.let { " = ${it.value.emit()}" } ?: "" + "${if (field.isOverride) "override " else ""}val ${field.name.value().sanitize()}: ${field.type.emitGenerics()}$valueStr".indentCode(indent + 1) + } + val bodyContent = listOf(fieldProperties, nestedContent).filter { it.isNotEmpty() }.joinToString("\n") + return if (bodyContent.isNotEmpty()) { + "data object ${name.pascalCase()}$implStr {\n$bodyContent${"}".indentCode(indent)}\n\n".indentCode(indent) + } else { + "data object ${name.pascalCase()}$implStr\n\n".indentCode(indent) + } + } + + if (fields.isEmpty() && constructors.isEmpty()) { + return if (nestedContent.isNotEmpty()) { + "object ${name.pascalCase()}$implStr {\n$nestedContent${"}".indentCode(indent)}\n\n".indentCode(indent) + } else { + "object ${name.pascalCase()}$implStr\n\n".indentCode(indent) + } + } + + val params = fields.joinToString(",\n") { + "${if (it.isOverride) "override " else ""}val ${it.name.value().sanitize()}: ${it.type.emitGenerics()}".indentCode( + indent + 1, + ) + } + val paramsStr = if (fields.isEmpty()) "" else "(\n$params\n${")".indentCode(indent)}" + + val hasBody = customConstructors.isNotEmpty() || nestedContent.isNotEmpty() + + return if (hasBody) { + "data class ${name.pascalCase()}$paramsStr$implStr {\n$customConstructors$nestedContent${"}".indentCode(indent)}\n\n".indentCode( + indent, + ) + } else { + "data class ${name.pascalCase()}$paramsStr$implStr\n\n".indentCode(indent) + } + } + + private fun Constructor.emitKotlin(structFields: List, indent: Int): String { + val params = parameters.joinToString(", ") { it.emit(0) } + val isDelegating = body.any { it is ConstructorStatement } + + if (isDelegating) { + val delegationStmt = body.filterIsInstance().first() + val delegationArgs = delegationStmt.namedArguments.map { "${it.key.value()} = ${it.value.emit()}" } + val delegationStr = "this(${delegationArgs.joinToString(", ")})" + val otherStatements = body.filter { it !is ConstructorStatement } + return if (otherStatements.isEmpty()) { + "constructor($params) : $delegationStr\n".indentCode(indent) + } else { + val bodyContent = otherStatements.joinToString("") { it.emit(1) } + "constructor($params) : $delegationStr {\n$bodyContent${"}".indentCode(0)}\n".indentCode(indent) + } + } + + val assignments = body.filterIsInstance().associate { + it.name.camelCase() to it.value.emit() + } + val constructorArgs = structFields.map { field -> + assignments[field.name.value()] ?: "null" + } + val otherStatements = body.filter { + it !is Assignment || it.name.camelCase() !in structFields.map { f -> f.name.value() } + } + + return if (otherStatements.isEmpty()) { + "constructor($params) : this(${constructorArgs.joinToString(", ")})\n".indentCode(indent) + } else { + val bodyContent = otherStatements.joinToString("") { it.emit(1) } + "constructor($params) : this(${constructorArgs.joinToString(", ")}) {\n$bodyContent${"}".indentCode(0)}\n".indentCode( + indent, + ) + } + } + + private fun AstFunction.emit(indent: Int, parents: List): String { + val lastParent = parents.lastOrNull() + val isInterface = lastParent is Interface + + val overridePrefix = if (isOverride) "override " else "" + val suspendPrefix = if (isAsync) "suspend " else "" + val typeParamsStr = if (typeParameters.isNotEmpty()) { + "<${typeParameters.joinToString(", ") { it.emit() }}> " + } else { + "" + } + val rType = if (isAsync) { + returnType?.emitGenerics() ?: "Unit" + } else { + returnType?.takeIf { it != Type.Unit }?.emitGenerics() + } + val returnTypeStr = if (rType != null) ": $rType" else "" + val params = parameters.joinToString(", ") { it.emit(0) } + + return if (body.isEmpty()) { + "$overridePrefix${suspendPrefix}fun $typeParamsStr${name.camelCase()}($params)$returnTypeStr\n".indentCode(indent) + } else if (body.size == 1 && body.first() is ReturnStatement) { + val expr = (body.first() as ReturnStatement).expression.emit() + "$overridePrefix${suspendPrefix}fun $typeParamsStr${name.camelCase()}($params)$returnTypeStr =\n${expr.indentCode(1)}\n\n".indentCode( + indent, + ) + } else { + val content = body.joinToString("") { it.emit(1) } + "$overridePrefix${suspendPrefix}fun $typeParamsStr${name.camelCase()}($params)$returnTypeStr {\n$content${"}".indentCode(0)}\n\n".indentCode( + indent, + ) + } + } + + private fun Parameter.emit(indent: Int): String = "${name.camelCase().sanitize()}: ${type.emitGenerics()}".indentCode(indent) + + private fun TypeParameter.emit(): String { + val typeStr = type.emitGenerics() + return if (extends.isEmpty()) { + "$typeStr: Any" + } else { + "$typeStr: ${extends.joinToString(" & ") { it.emitGenerics() }}" + } + } + + private fun Type.emit(): String = when (this) { + is Type.Integer -> when (precision) { + Precision.P32 -> "Int" + Precision.P64 -> "Long" + } + + is Type.Number -> when (precision) { + Precision.P32 -> "Float" + Precision.P64 -> "Double" + } + + Type.Any -> "Any" + Type.String -> "String" + Type.Bytes -> "ByteArray" + Type.Boolean -> "Boolean" + Type.Unit -> "Unit" + Type.Wildcard -> "*" + Type.Reflect -> "KType" + is Type.Array -> "List" + is Type.Dict -> "Map" + is Type.Custom -> name + is Type.Nullable -> "${type.emitGenerics()}?" + is Type.IntegerLiteral -> "Int" + is Type.StringLiteral -> "String" + } + + private fun Type.emitGenerics(): String = when (this) { + is Type.Array -> "${emit()}<${elementType.emitGenerics()}>" + is Type.Dict -> "${emit()}<${keyType.emitGenerics()}, ${valueType.emitGenerics()}>" + is Type.Custom -> { + if (generics.isEmpty()) { + emit() + } else { + "${emit()}<${generics.joinToString(", ") { it.emitGenerics() }}>" + } + } + + is Type.Nullable -> "${type.emitGenerics()}?" + else -> emit() + } + + private fun Statement.emit(indent: Int): String = when (this) { + is PrintStatement -> "println(${expression.emit()})\n".indentCode(indent) + is ReturnStatement -> "return ${expression.emit()}\n".indentCode(indent) + is ConstructorStatement -> { + val allArgs = namedArguments.map { "${it.key.value()} = ${it.value.emit()}" } + val argsStr = when { + allArgs.isEmpty() -> "" + allArgs.size == 1 -> "(${allArgs.first()})" + else -> "(\n${allArgs.joinToString(",\n") { it.indentCode(1) }}\n)" + } + "${type.emitGenerics()}$argsStr\n".indentCode(indent) + } + + is Literal -> "${emit()}\n".indentCode(indent) + is LiteralList -> "${emit()}\n".indentCode(indent) + is LiteralMap -> "${emit()}\n".indentCode(indent) + is Assignment -> { + val expr = (value as? ConstructorStatement)?.let { constructorStmt -> + val allArgs = constructorStmt.namedArguments.map { "${it.key.value()} = ${it.value.emit()}" } + val argsStr = when { + allArgs.isEmpty() -> "" + allArgs.size == 1 -> "(${allArgs.first()})" + else -> "(\n${allArgs.joinToString(",\n") { it.indentCode(1) }}\n)" + } + "${constructorStmt.type.emitGenerics()}$argsStr" + } ?: value.emit() + if (isProperty) { + "${name.value().sanitize()} = $expr\n".indentCode(indent) + } else { + "val ${name.camelCase().sanitize()} = $expr\n".indentCode(indent) + } + } + + is ErrorStatement -> "error(${message.emit()})\n".indentCode(indent) + is AssertStatement -> "assert(${expression.emit()}) { \"$message\" }\n".indentCode(indent) + is Switch -> { + val isPatternSwitch = cases.any { it.type != null } + if (isPatternSwitch) { + val casesStr = cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(1) } + val typeStr = case.type?.emitGenerics() ?: "Any" + "is $typeStr -> {\n$bodyStr}\n".indentCode(indent + 1) + } + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(1) } + "else -> {\n$bodyStr}\n".indentCode(indent + 1) + } ?: "" + val exprStr = variable?.let { "val ${it.camelCase()} = ${expression.emit()}" } ?: expression.emit() + "when($exprStr) {\n$casesStr$defaultStr}\n".indentCode(indent) + } else { + val casesStr = cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(1) } + "${case.value.emit()} -> {\n$bodyStr}\n".indentCode(indent + 1) + } + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(1) } + "else -> {\n$bodyStr}\n".indentCode(indent + 1) + } ?: "" + "when (${expression.emit()}) {\n$casesStr$defaultStr}\n".indentCode(indent) + } + } + + is RawExpression -> "$code\n".indentCode(indent) + is NullLiteral -> "null\n".indentCode(indent) + is NullableEmpty -> "null\n".indentCode(indent) + is VariableReference -> "${name.camelCase().sanitize()}\n".indentCode(indent) + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${field.value().sanitize()}\n".indentCode(indent) + } + + is FunctionCall -> { + val typeArgsStr = + if (typeArguments.isNotEmpty()) "<${typeArguments.joinToString(", ") { it.emitGenerics() }}>" else "" + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${ + name.value().toKotlinStaticCall().sanitize() + }$typeArgsStr(${arguments.values.joinToString(", ") { it.emit() }})\n".indentCode(indent) + } + + is ArrayIndexCall -> if (caseSensitive) { + "${receiver.emit()}[${index.emit()}]\n".indentCode(indent) + } else { + "${receiver.emit()}.entries.find { it.key.equals(${index.emit()}, ignoreCase = true) }?.value\n".indentCode(indent) + } + + is EnumReference -> "${enumType.emitGenerics()}.${entry.value()}\n".indentCode(indent) + is EnumValueCall -> "${expression.emit()}.name\n".indentCode(indent) + is BinaryOp -> "(${left.emit()} ${operator.toKotlin()} ${right.emit()})\n".indentCode(indent) + is TypeDescriptor -> "${emitTypeDescriptor()}\n".indentCode(indent) + is NullCheck -> "${emit()}\n".indentCode(indent) + is NullableMap -> "${emit()}\n".indentCode(indent) + is NullableOf -> "${emit()}\n".indentCode(indent) + is NullableGet -> "${emit()}\n".indentCode(indent) + is Constraint.RegexMatch -> "${emit()}\n".indentCode(indent) + is Constraint.BoundCheck -> "${emit()}\n".indentCode(indent) + is NotExpression -> "!${expression.emit()}\n".indentCode(indent) + is IfExpression -> "${emit()}\n".indentCode(indent) + is MapExpression -> "${emit()}\n".indentCode(indent) + is FlatMapIndexed -> "${emit()}\n".indentCode(indent) + is ListConcat -> "${emit()}\n".indentCode(indent) + is StringTemplate -> "${emit()}\n".indentCode(indent) + } + + private fun BinaryOp.Operator.toKotlin(): String = when (this) { + BinaryOp.Operator.PLUS -> "+" + BinaryOp.Operator.EQUALS -> "==" + BinaryOp.Operator.NOT_EQUALS -> "!=" + } + + private fun String.toKotlinStaticCall(): String = when (this) { + "java.util.Collections.emptyList" -> "emptyList" + "java.util.Collections.emptyMap" -> "emptyMap" + else -> this + } + + private fun Expression.emit(): String = when (this) { + is ConstructorStatement -> { + if (type == Type.Unit) { + type.emitGenerics() + } else { + val allArgs = namedArguments.map { "${it.key.value()} = ${it.value.emit()}" } + val argsStr = when { + allArgs.isEmpty() -> "" + allArgs.size == 1 -> "(${allArgs.first()})" + else -> "(\n${allArgs.joinToString(",\n") { it.indentCode(1) }}\n)" + } + "${type.emitGenerics()}$argsStr" + } + } + + is Literal -> emit() + is LiteralList -> emit() + is LiteralMap -> emit() + is RawExpression -> code + is NullLiteral -> "null" + is NullableEmpty -> "null" + is VariableReference -> name.camelCase().sanitize() + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${field.value().sanitize()}" + } + + is FunctionCall -> { + val typeArgsStr = + if (typeArguments.isNotEmpty()) "<${typeArguments.joinToString(", ") { it.emitGenerics() }}>" else "" + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${ + name.value().toKotlinStaticCall().sanitize() + }$typeArgsStr(${arguments.values.joinToString(", ") { it.emit() }})" + } + + is ArrayIndexCall -> if (caseSensitive) { + "${receiver.emit()}[${index.emit()}]" + } else { + "${receiver.emit()}.entries.find { it.key.equals(${index.emit()}, ignoreCase = true) }?.value" + } + + is EnumReference -> "${enumType.emitGenerics()}.${entry.value()}" + is EnumValueCall -> "${expression.emit()}.name" + is BinaryOp -> "(${left.emit()} ${operator.toKotlin()} ${right.emit()})" + is TypeDescriptor -> emitTypeDescriptor() + is NullCheck -> "(${expression.emit()}?.let { ${body.emit()} }${alternative?.emit()?.let { " ?: $it" } ?: ""})" + is NullableMap -> "(${expression.emit()}?.let { ${body.emit()} } ?: ${alternative.emit()})" + is NullableOf -> expression.emit() + is NullableGet -> "${expression.emit()}!!" + is Constraint.RegexMatch -> "Regex(\"\"\"${pattern}\"\"\").matches(${value.emit()})" + is Constraint.BoundCheck -> { + val checks = listOfNotNull( + min?.let { "$it <= ${value.emit()}" }, + max?.let { "${value.emit()} <= $it" }, + ).joinToString(" && ").ifEmpty { "true" } + checks + } + is ErrorStatement -> "error(${message.emit()})" + is AssertStatement -> throw IllegalArgumentException("AssertStatement cannot be an expression in Kotlin") + is Switch -> { + val isPatternSwitch = cases.any { it.type != null } + if (isPatternSwitch) { + val casesStr = cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(1) } + val typeStr = case.type?.emitGenerics() ?: "Any" + "is $typeStr -> {\n$bodyStr}\n".indentCode(1) + } + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(1) } + "else -> {\n$bodyStr}\n".indentCode(1) + } ?: "" + val exprStr = variable?.let { "val ${it.camelCase()} = ${expression.emit()}" } ?: expression.emit() + "when($exprStr) {\n$casesStr$defaultStr}" + } else { + val casesStr = cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(1) } + "${case.value.emit()} -> {\n$bodyStr}\n".indentCode(1) + } + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(1) } + "else -> {\n$bodyStr}\n".indentCode(1) + } ?: "" + "when (${expression.emit()}) {\n$casesStr$defaultStr}" + } + } + + is Assignment -> throw IllegalArgumentException("Assignment cannot be an expression in Kotlin") + is PrintStatement -> throw IllegalArgumentException("PrintStatement cannot be an expression in Kotlin") + is ReturnStatement -> throw IllegalArgumentException("ReturnStatement cannot be an expression in Kotlin") + is NotExpression -> "!${expression.emit()}" + is IfExpression -> "if (${condition.emit()}) ${thenExpr.emit()} else ${elseExpr.emit()}" + is MapExpression -> "${receiver.emit()}.map { ${variable.camelCase()} -> ${body.emit()} }" + is FlatMapIndexed -> "${receiver.emit()}.flatMapIndexed { ${indexVar.camelCase()}, ${elementVar.camelCase()} -> ${body.emit()} }" + is ListConcat -> when { + lists.isEmpty() -> "emptyList()" + lists.size == 1 -> lists.single().emit() + else -> lists.joinToString(" + ") { expr -> + val emitted = expr.emit() + if (expr is IfExpression) "($emitted)" else emitted + } + } + is StringTemplate -> "\"${parts.joinToString("") { + when (it) { + is StringTemplate.Part.Text -> it.value + is StringTemplate.Part.Expr -> "\${${it.expression.emit()}}" + } + }}\"" + } + + private fun LiteralList.emit(): String { + if (values.isEmpty()) return "emptyList<${type.emitGenerics()}>()" + val list = values.joinToString(", ") { it.emit() } + return "listOf($list)" + } + + private fun LiteralMap.emit(): String { + if (values.isEmpty()) return "emptyMap()" + val map = values.entries.joinToString(", ") { + "${Literal(it.key, keyType).emit()} to ${it.value.emit()}" + } + return "mapOf($map)" + } + + private fun Literal.emit(): String = when (type) { + Type.String -> "\"$value\"" + else -> value.toString() + } + + private fun TypeDescriptor.emitTypeDescriptor(): String = "typeOf<${type.emitGenerics()}>()" +} + +private fun String.sanitize(): String = if (reservedKeywords.contains(this)) "`$this`" else this + +private val reservedKeywords = setOf( + "as", "break", "class", "continue", "do", + "else", "false", "for", "fun", "if", + "in", "interface", "internal", "is", "null", + "object", "open", "package", "return", "super", + "this", "throw", "true", "try", "typealias", + "typeof", "val", "var", "when", "while", +) diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/PythonGenerator.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/PythonGenerator.kt new file mode 100644 index 000000000..082d3916f --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/PythonGenerator.kt @@ -0,0 +1,517 @@ +package community.flock.wirespec.ir.generator + +import community.flock.wirespec.ir.core.ArrayIndexCall +import community.flock.wirespec.ir.core.AssertStatement +import community.flock.wirespec.ir.core.Assignment +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.ir.core.Constraint +import community.flock.wirespec.ir.core.Constructor +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Element +import community.flock.wirespec.ir.core.Enum +import community.flock.wirespec.ir.core.EnumReference +import community.flock.wirespec.ir.core.EnumValueCall +import community.flock.wirespec.ir.core.ErrorStatement +import community.flock.wirespec.ir.core.Expression +import community.flock.wirespec.ir.core.Field +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.FlatMapIndexed +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.IfExpression +import community.flock.wirespec.ir.core.Import +import community.flock.wirespec.ir.core.Interface +import community.flock.wirespec.ir.core.ListConcat +import community.flock.wirespec.ir.core.Literal +import community.flock.wirespec.ir.core.LiteralList +import community.flock.wirespec.ir.core.LiteralMap +import community.flock.wirespec.ir.core.Main +import community.flock.wirespec.ir.core.MapExpression +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Namespace +import community.flock.wirespec.ir.core.NotExpression +import community.flock.wirespec.ir.core.NullCheck +import community.flock.wirespec.ir.core.NullLiteral +import community.flock.wirespec.ir.core.NullableEmpty +import community.flock.wirespec.ir.core.NullableGet +import community.flock.wirespec.ir.core.NullableMap +import community.flock.wirespec.ir.core.NullableOf +import community.flock.wirespec.ir.core.Package +import community.flock.wirespec.ir.core.PrintStatement +import community.flock.wirespec.ir.core.RawElement +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.ReturnStatement +import community.flock.wirespec.ir.core.Statement +import community.flock.wirespec.ir.core.StringTemplate +import community.flock.wirespec.ir.core.Struct +import community.flock.wirespec.ir.core.Switch +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.TypeDescriptor +import community.flock.wirespec.ir.core.Union +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.Function as AstFunction + +object PythonGenerator : Generator { + override fun generate(element: Element): String = when (element) { + is File -> element.emit(0) + else -> File(Name.of(""), listOf(element)).emit(0) + } + + private fun String.indentCode(level: Int): String = " ".repeat(level * 4) + this + + private fun File.emit(indent: Int): String = groupImports(elements).joinToString("") { it.emit(indent) }.removeEmptyLines() + + private fun groupImports(elements: List): List { + val result = mutableListOf() + var i = 0 + while (i < elements.size) { + val element = elements[i] + if (element is Import && element.path != ".") { + val path = element.path + val types = mutableListOf(element.type.name) + while (i + 1 < elements.size) { + val next = elements[i + 1] + if (next is Import && next.path == path) { + types.add(next.type.name) + i++ + } else { + break + } + } + result.add(RawElement("from $path import ${types.joinToString(", ")}")) + } else { + result.add(element) + } + i++ + } + return result + } + + private fun String.removeEmptyLines(): String = lines().filter { it.isNotEmpty() }.joinToString("\n").plus("\n") + + private fun Element.emit(indent: Int, parents: List = emptyList(), isStaticScope: Boolean = false, qualifier: ((String) -> String)? = null): String = when (this) { + is Package -> emit(indent) + is Import -> emit(indent) + is Struct -> emit(indent, parents, qualifier = qualifier) + is AstFunction -> { + val isInClass = parents.any { it is Struct || it is Interface || it is Namespace } + val isInInterface = parents.any { it is Interface } + emit(indent, isInClass = isInClass, isStaticScope = isStaticScope, isInInterface = isInInterface, qualifier = qualifier) + } + is Namespace -> emit(indent, parents) + is Interface -> emit(indent, parents, qualifier = qualifier) + is Union -> emit(indent, parents) + is Enum -> emit(indent) + is Main -> { + val staticContent = statics.joinToString("") { it.emit(indent, parents, isStaticScope, qualifier) } + val content = body.joinToString("") { it.emit(indent + 1) } + val asyncPrefix = if (isAsync) "async " else "" + val runner = if (isAsync) "asyncio.run(main())" else "main()" + val defBlock = "${asyncPrefix}def main():\n$content\n".indentCode(indent) + val guard = "if __name__ == \"__main__\":\n${runner.indentCode(1)}\n".indentCode(indent) + "$staticContent$defBlock$guard" + } + is File -> elements.joinToString("") { it.emit(indent, parents, isStaticScope, qualifier) } + is RawElement -> "$code\n".indentCode(indent) + } + + private fun Package.emit(indent: Int): String = "# package $path\n\n".indentCode(indent) + + private fun Import.emit(indent: Int): String = "from $path import ${type.name}\n".indentCode(indent) + + private fun Namespace.emit(indent: Int, parents: List = emptyList()): String { + val p = mutableListOf() + extends?.let { p.add(it.emit()) } + + val ext = if (p.isEmpty()) "" else "(${p.joinToString(", ")})" + val siblingNames = elements.mapNotNull { elementName(it) }.toSet() + val nameStr = name.pascalCase() + val nsQualifier: ((String) -> String)? = if (siblingNames.isNotEmpty()) { + { typeName -> if (typeName in siblingNames) "$nameStr.$typeName" else typeName } + } else { + null + } + val elementsContent = elements.joinToString("") { it.emit(indent + 1, parents = parents + this, isStaticScope = true, qualifier = nsQualifier) } + val content = elementsContent.ifEmpty { "pass\n".indentCode(indent + 1) } + return "class $nameStr$ext:\n$content\n".indentCode(indent) + } + + private fun Interface.emit(indent: Int, parents: List = emptyList(), qualifier: ((String) -> String)? = null): String { + val p = extends.map { it.emit() }.toMutableList() + p.add("ABC") + if (typeParameters.isNotEmpty()) { + p.add("Generic[${typeParameters.joinToString(", ") { it.type.emit() }}]") + } + val ext = if (p.isEmpty()) "" else "(${p.joinToString(", ")})" + val nestedNames = elements.mapNotNull { elementName(it) }.toSet() + val adjustedQualifier = if (qualifier != null && nestedNames.isNotEmpty()) { + { name: String -> if (name in nestedNames) name else qualifier(name) } + } else { + qualifier + } + val fieldsContent = fields.joinToString("") { field -> + "${field.name.value()}: ${field.type.emit(adjustedQualifier)}\n".indentCode(indent + 1) + } + val elementsContent = elements.joinToString("") { it.emit(indent + 1, parents = parents + this, isStaticScope = false, qualifier = adjustedQualifier) } + val content = (fieldsContent + elementsContent).ifEmpty { "pass\n".indentCode(indent + 1) } + return "class ${name.pascalCase()}$ext:\n$content\n".indentCode(indent) + } + + private fun Union.emit(indent: Int, parents: List = emptyList()): String { + val p = mutableListOf() + extends?.let { p.add(it.emit()) } + parents.filterIsInstance().forEach { p.add(it.name.pascalCase()) } + if (typeParameters.isNotEmpty()) { + p.add("Generic[${typeParameters.joinToString(", ") { it.type.emit() }}]") + } + + val ext = if (p.isEmpty()) "" else "(${p.distinct().joinToString(", ")})" + return "class ${name.pascalCase()}$ext:\n${"pass".indentCode(indent + 1)}\n\n".indentCode(indent) + } + + private fun Enum.emit(indent: Int): String { + val ext = if (extends != null) "(${extends!!.emit()}, enum.Enum)" else "(enum.Enum)" + val entriesStr = if (entries.isEmpty()) { + "pass".indentCode(indent + 1) + } else { + entries.joinToString("\n") { entry -> + val value = entry.values.firstOrNull() ?: "\"${entry.name.value()}\"" + "${entry.name.value()} = $value".indentCode(indent + 1) + } + } + return "class ${name.pascalCase()}$ext:\n$entriesStr\n\n".indentCode(indent) + } + + private fun Struct.emit(indent: Int, parents: List = emptyList(), qualifier: ((String) -> String)? = null): String { + val p = mutableListOf() + interfaces.forEach { p.add(it.emit()) } + + val ext = if (p.isEmpty()) "" else "(${p.distinct().joinToString(", ")})" + val nestedContent = elements.joinToString("") { it.emit(indent + 1, parents = parents + this, isStaticScope = false, qualifier = qualifier) } + val content = if (fields.isEmpty() && constructors.isEmpty()) { + "pass\n".indentCode(indent + 1) + } else { + val fieldDecls = fields.joinToString("") { it.emit(indent + 1, qualifier) } + val customConstructors = constructors.joinToString("") { it.emit(indent + 1, qualifier) } + "$fieldDecls$customConstructors" + } + + val decorator = "@dataclass\n".indentCode(indent) + return decorator + "class ${name.pascalCase()}$ext:\n$content$nestedContent\n".indentCode(indent) + } + + private fun Constructor.emit(indent: Int, qualifier: ((String) -> String)? = null): String { + val content = if (body.isEmpty()) { + "pass\n".indentCode(indent + 1) + } else { + body.joinToString("") { stmt -> + when (stmt) { + is Assignment -> "self.${stmt.name.value()} = ${stmt.value.emit()}\n".indentCode(indent + 1) + else -> stmt.emit(indent + 1).replace("this.", "self.") + } + } + } + if (parameters.isEmpty()) { + return "def __init__(self):\n$content\n".indentCode(indent) + } + val selfParam = "self,\n".indentCode(indent + 1) + val paramLines = parameters.joinToString(",\n") { + "${it.name.camelCase()}: ${it.type.emit(qualifier)}".indentCode(indent + 1) + } + val closeParen = "):\n".indentCode(indent) + return "def __init__(\n$selfParam$paramLines,\n$closeParen$content\n".indentCode(indent) + } + + private fun Field.emit(indent: Int, qualifier: ((String) -> String)? = null): String = "${name.value()}: ${type.emit(qualifier)}\n".indentCode(indent) + + private fun AstFunction.emit(indent: Int, isInClass: Boolean = false, isStaticScope: Boolean = false, isInInterface: Boolean = false, qualifier: ((String) -> String)? = null): String { + val params = parameters.joinToString(", ") { + if (it.name.camelCase() == "self") it.name.camelCase() else "${it.name.camelCase()}: ${it.type.emit(qualifier)}" + } + val effectivelyStatic = isStatic || isStaticScope + val selfPrefix = if (isInClass && !effectivelyStatic && parameters.none { it.name.camelCase() == "self" }) { + if (params.isEmpty()) "self" else "self, " + } else { + "" + } + val staticDecorator = if (isInClass && effectivelyStatic) "@staticmethod\n".indentCode(indent) else "" + val abstractDecorator = if (isInInterface && body.isEmpty()) "@abstractmethod\n".indentCode(indent) else "" + val content = if (body.isEmpty()) { + "...\n".indentCode(indent + 1) + } else { + body.joinToString("") { it.emit(indent + 1) } + } + val prefix = if (isAsync) "async " else "" + val returnAnnotation = returnType?.let { " -> ${it.emit(qualifier)}" } ?: "" + return staticDecorator + abstractDecorator + "${prefix}def ${name.value()}($selfPrefix$params)$returnAnnotation:\n$content\n".indentCode(indent) + } + + private fun Type.emit(qualifier: ((String) -> String)? = null): String = when (this) { + is Type.Integer -> "int" + is Type.Number -> "float" + Type.Any -> "Any" + Type.String -> "str" + Type.Boolean -> "bool" + Type.Bytes -> "bytes" + Type.Unit -> "None" + Type.Wildcard -> "Any" + Type.Reflect -> "type[T]" + is Type.Array -> "list[${elementType.emit(qualifier)}]" + is Type.Dict -> "dict[${keyType.emit(qualifier)}, ${valueType.emit(qualifier)}]" + is Type.Custom -> { + val qualifiedName = qualifier?.invoke(name) ?: name + if (generics.isEmpty()) { + qualifiedName + } else { + "$qualifiedName[${generics.joinToString(", ") { it.emit(qualifier) }}]" + } + } + is Type.Nullable -> "Optional[${type.emit(qualifier)}]" + is Type.IntegerLiteral -> "int" + is Type.StringLiteral -> "str" + } + + private fun Statement.emit(indent: Int): String = when (this) { + is PrintStatement -> "print(${expression.emit()})\n".indentCode(indent) + is ReturnStatement -> "return ${expression.emit()}\n".indentCode(indent) + is ConstructorStatement -> { + if (type == Type.Unit) { + "None\n".indentCode(indent) + } else { + val allArgs = namedArguments.map { "${it.key.value()}=${it.value.emit()}" } + "${type.emit()}(${allArgs.joinToString(", ")})\n".indentCode(indent) + } + } + is Literal -> "${emit()}\n".indentCode(indent) + is LiteralList -> "${emit()}\n".indentCode(indent) + is LiteralMap -> "${emit()}\n".indentCode(indent) + is Assignment -> "${name.camelCase()} = ${value.emit()}\n".indentCode(indent) + is ErrorStatement -> "raise Exception(${message.emit()})\n".indentCode(indent) + is AssertStatement -> "assert ${expression.emit()}, '${message.replace("'", "\\'")}'\n".indentCode(indent) + is Switch -> { + val isPatternSwitch = cases.any { it.type != null } + if (isPatternSwitch) { + val casesStr = cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(indent + 2) } + val typeStr = case.type?.emit() ?: "object" + val varBinding = variable?.let { " as ${it.camelCase()}" } ?: "" + "case $typeStr()$varBinding:\n$bodyStr".indentCode(indent + 1) + } + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(indent + 2) } + "case _:\n$bodyStr".indentCode(indent + 1) + } ?: "" + "match ${expression.emit()}:\n$casesStr$defaultStr".indentCode(indent) + } else { + val casesStr = cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(indent + 2) } + "case ${case.value.emit()}:\n$bodyStr".indentCode(indent + 1) + } + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(indent + 2) } + "case _:\n$bodyStr".indentCode(indent + 1) + } ?: "" + "match ${expression.emit()}:\n$casesStr$defaultStr".indentCode(indent) + } + } + is RawExpression -> "$code\n".indentCode(indent) + is NullLiteral -> "None\n".indentCode(indent) + is NullableEmpty -> "None\n".indentCode(indent) + is VariableReference -> "${name.camelCase()}\n".indentCode(indent) + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${field.value()}\n".indentCode(indent) + } + is FunctionCall -> { + val awaitPrefix = if (isAwait) "await " else "" + val recv = receiver + if (recv != null) { + "$awaitPrefix${recv.emit()}.${name.value()}(${arguments.values.joinToString(", ") { it.emit() }})\n".indentCode(indent) + } else { + "$awaitPrefix${name.value()}(${arguments.map { "${it.key.value()}=${it.value.emit()}" }.joinToString(", ")})\n".indentCode(indent) + } + } + is ArrayIndexCall -> if (caseSensitive) { + "${receiver.emit()}[${index.emit()}]\n".indentCode(indent) + } else { + "next((v for k, v in ${receiver.emit()}.items() if k.lower() == ${index.emit()}.lower()), None)\n".indentCode(indent) + } + is EnumReference -> "${enumType.emit()}.${entry.value()}\n".indentCode(indent) + is EnumValueCall -> "${expression.emit()}.value\n".indentCode(indent) + is BinaryOp -> { + if (operator == BinaryOp.Operator.PLUS && (left is Literal && (left as Literal).type == Type.String || right is Literal && (right as Literal).type == Type.String)) { + val leftStr = if (left is Literal && (left as Literal).type == Type.String) left.emit() else "str(${left.emit()})" + val rightStr = if (right is Literal && (right as Literal).type == Type.String) right.emit() else "str(${right.emit()})" + "($leftStr + $rightStr)\n".indentCode(indent) + } else { + "(${left.emit()} ${operator.toPython()} ${right.emit()})\n".indentCode(indent) + } + } + is TypeDescriptor -> "${type.emit()}\n".indentCode(indent) + is NullCheck -> "${emit()}\n".indentCode(indent) + is NullableMap -> "${emit()}\n".indentCode(indent) + is NullableOf -> "${emit()}\n".indentCode(indent) + is NullableGet -> "${emit()}\n".indentCode(indent) + is Constraint.RegexMatch -> "${emit()}\n".indentCode(indent) + is Constraint.BoundCheck -> "${emit()}\n".indentCode(indent) + is NotExpression -> "not ${expression.emit()}\n".indentCode(indent) + is IfExpression -> "${emit()}\n".indentCode(indent) + is MapExpression -> "${emit()}\n".indentCode(indent) + is FlatMapIndexed -> "${emit()}\n".indentCode(indent) + is ListConcat -> "${emit()}\n".indentCode(indent) + is StringTemplate -> "${emit()}\n".indentCode(indent) + } + + private fun BinaryOp.Operator.toPython(): String = when (this) { + BinaryOp.Operator.PLUS -> "+" + BinaryOp.Operator.EQUALS -> "==" + BinaryOp.Operator.NOT_EQUALS -> "!=" + } + + private fun Expression.emit(): String = when (this) { + is ConstructorStatement -> if (type == Type.Unit) "None" else "${type.emit()}(${namedArguments.map { "${it.key.value()}=${it.value.emit()}" }.joinToString(", ")})" + is Literal -> emit() + is LiteralList -> emit() + is LiteralMap -> emit() + is RawExpression -> code + is NullLiteral -> "None" + is NullableEmpty -> "None" + is VariableReference -> name.camelCase() + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${field.value()}" + } + is FunctionCall -> { + val awaitPrefix = if (isAwait) "await " else "" + val recv = receiver + if (recv != null) { + "$awaitPrefix${recv.emit()}.${name.value()}(${arguments.values.joinToString(", ") { it.emit() }})" + } else { + "$awaitPrefix${name.value()}(${arguments.map { "${it.key.value()}=${it.value.emit()}" }.joinToString(", ")})" + } + } + is ArrayIndexCall -> if (caseSensitive) { + "${receiver.emit()}[${index.emit()}]" + } else { + "next((v for k, v in ${receiver.emit()}.items() if k.lower() == ${index.emit()}.lower()), None)" + } + is EnumReference -> "${enumType.emit()}.${entry.value()}" + is EnumValueCall -> "${expression.emit()}.value" + is BinaryOp -> { + if (operator == BinaryOp.Operator.PLUS && (left is Literal && (left as Literal).type == Type.String || right is Literal && (right as Literal).type == Type.String)) { + val leftStr = if (left is Literal && (left as Literal).type == Type.String) left.emit() else "str(${left.emit()})" + val rightStr = if (right is Literal && (right as Literal).type == Type.String) right.emit() else "str(${right.emit()})" + "($leftStr + $rightStr)" + } else { + "(${left.emit()} ${operator.toPython()} ${right.emit()})" + } + } + is TypeDescriptor -> type.emit() + is NullCheck -> { + val exprStr = expression.emit() + val bodyStr = body.emitWithInlinedIt(exprStr) + val altStr = alternative?.emit() ?: "None" + "$bodyStr if $exprStr is not None else $altStr" + } + is NullableMap -> { + val exprStr = expression.emit() + val bodyStr = body.emitWithInlinedIt(exprStr) + val altStr = alternative.emit() + "$bodyStr if $exprStr is not None else $altStr" + } + is NullableOf -> expression.emit() + is NullableGet -> expression.emit() + is Constraint.RegexMatch -> "bool(re.match(r\"${rawValue}\", ${value.emit()}))" + is Constraint.BoundCheck -> { + val checks = listOfNotNull( + min?.let { "$it <= ${value.emit()}" }, + max?.let { "${value.emit()} <= $it" }, + ).joinToString(" and ").ifEmpty { "True" } + checks + } + is ErrorStatement -> "_raise(${message.emit()})" + is AssertStatement -> throw IllegalArgumentException("AssertStatement cannot be an expression in Python") + is Switch -> throw IllegalArgumentException("Switch cannot be an expression in Python") + is Assignment -> throw IllegalArgumentException("Assignment cannot be an expression in Python") + is PrintStatement -> throw IllegalArgumentException("PrintStatement cannot be an expression in Python") + is ReturnStatement -> throw IllegalArgumentException("ReturnStatement cannot be an expression in Python") + is NotExpression -> "not ${expression.emit()}" + is IfExpression -> "(${thenExpr.emit()} if ${condition.emit()} else ${elseExpr.emit()})" + is MapExpression -> "[${body.emit()} for ${variable.camelCase()} in ${receiver.emit()}]" + is FlatMapIndexed -> "[item for ${indexVar.camelCase()}, ${elementVar.camelCase()} in enumerate(${receiver.emit()}) for item in ${body.emit()}]" + is ListConcat -> when { + lists.isEmpty() -> "[]" + lists.size == 1 -> lists.single().emit() + else -> lists.joinToString(" + ") { it.emit() } + } + is StringTemplate -> "f\"${parts.joinToString("") { + when (it) { + is StringTemplate.Part.Text -> it.value + is StringTemplate.Part.Expr -> "{${it.expression.emit()}}" + } + }}\"" + } + + private fun Expression.emitWithInlinedIt(replacement: String): String = when (this) { + is VariableReference -> if (name.value() == "it") replacement else emit() + is FunctionCall -> { + val recv = receiver + val inlinedArgs = arguments.mapValues { it.value.emitWithInlinedIt(replacement) } + if (recv != null) { + "${recv.emitWithInlinedIt(replacement)}.${name.value()}(${inlinedArgs.values.joinToString(", ")})" + } else { + "${name.value()}(${inlinedArgs.map { "${it.key.value()}=${it.value}" }.joinToString(", ")})" + } + } + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emitWithInlinedIt(replacement)}." } ?: "" + "$receiverStr${field.value()}" + } + is ArrayIndexCall -> if (caseSensitive) { + "${receiver.emitWithInlinedIt(replacement)}[${index.emitWithInlinedIt(replacement)}]" + } else { + "next((v for k, v in ${receiver.emitWithInlinedIt(replacement)}.items() if k.lower() == ${index.emitWithInlinedIt(replacement)}.lower()), None)" + } + is EnumValueCall -> "${expression.emitWithInlinedIt(replacement)}.value" + is NotExpression -> "not ${expression.emitWithInlinedIt(replacement)}" + is IfExpression -> "(${thenExpr.emitWithInlinedIt(replacement)} if ${condition.emitWithInlinedIt(replacement)} else ${elseExpr.emitWithInlinedIt(replacement)})" + is MapExpression -> "[${body.emitWithInlinedIt(replacement)} for ${variable.camelCase()} in ${receiver.emitWithInlinedIt(replacement)}]" + is FlatMapIndexed -> "[item for ${indexVar.camelCase()}, ${elementVar.camelCase()} in enumerate(${receiver.emitWithInlinedIt(replacement)}) for item in ${body.emitWithInlinedIt(replacement)}]" + is ListConcat -> lists.joinToString(" + ") { it.emitWithInlinedIt(replacement) } + is StringTemplate -> "f\"${parts.joinToString("") { + when (it) { + is StringTemplate.Part.Text -> it.value + is StringTemplate.Part.Expr -> "{${it.expression.emitWithInlinedIt(replacement)}}" + } + }}\"" + is LiteralList -> emit() + else -> emit() + } + + private fun LiteralList.emit(): String { + val list = values.joinToString(", ") { it.emit() } + return "[$list]" + } + + private fun LiteralMap.emit(): String { + val map = values.entries.joinToString(", ") { + "${Literal(it.key, keyType).emit()}: ${it.value.emit()}" + } + return "{$map}" + } + + private fun Literal.emit(): String = when (type) { + Type.String -> "'$value'" + Type.Boolean -> value.toString().replaceFirstChar { it.uppercase() } + else -> value.toString() + } + + private fun elementName(element: Element): String? = when (element) { + is Interface -> element.name.pascalCase() + is Struct -> element.name.pascalCase() + is Enum -> element.name.pascalCase() + is Union -> element.name.pascalCase() + is Namespace -> element.name.pascalCase() + else -> null + } +} diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/RustGenerator.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/RustGenerator.kt new file mode 100644 index 000000000..181f7116c --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/RustGenerator.kt @@ -0,0 +1,597 @@ +package community.flock.wirespec.ir.generator + +import community.flock.wirespec.ir.core.ArrayIndexCall +import community.flock.wirespec.ir.core.AssertStatement +import community.flock.wirespec.ir.core.Assignment +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.ir.core.Constraint +import community.flock.wirespec.ir.core.Constructor +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Element +import community.flock.wirespec.ir.core.Enum +import community.flock.wirespec.ir.core.EnumReference +import community.flock.wirespec.ir.core.EnumValueCall +import community.flock.wirespec.ir.core.ErrorStatement +import community.flock.wirespec.ir.core.Expression +import community.flock.wirespec.ir.core.Field +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.FlatMapIndexed +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.IfExpression +import community.flock.wirespec.ir.core.Import +import community.flock.wirespec.ir.core.Interface +import community.flock.wirespec.ir.core.ListConcat +import community.flock.wirespec.ir.core.Literal +import community.flock.wirespec.ir.core.LiteralList +import community.flock.wirespec.ir.core.LiteralMap +import community.flock.wirespec.ir.core.Main +import community.flock.wirespec.ir.core.MapExpression +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Namespace +import community.flock.wirespec.ir.core.NotExpression +import community.flock.wirespec.ir.core.NullCheck +import community.flock.wirespec.ir.core.NullLiteral +import community.flock.wirespec.ir.core.NullableEmpty +import community.flock.wirespec.ir.core.NullableGet +import community.flock.wirespec.ir.core.NullableMap +import community.flock.wirespec.ir.core.NullableOf +import community.flock.wirespec.ir.core.Package +import community.flock.wirespec.ir.core.Precision +import community.flock.wirespec.ir.core.PrintStatement +import community.flock.wirespec.ir.core.RawElement +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.ReturnStatement +import community.flock.wirespec.ir.core.Statement +import community.flock.wirespec.ir.core.StringTemplate +import community.flock.wirespec.ir.core.Struct +import community.flock.wirespec.ir.core.Switch +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.TypeDescriptor +import community.flock.wirespec.ir.core.TypeParameter +import community.flock.wirespec.ir.core.Union +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.Function as AstFunction + +object RustGenerator : Generator { + override fun generate(element: Element): String = when (element) { + is File -> element.emit(0) + else -> File(Name.of(""), listOf(element)).emit(0) + } + + private fun String.indentCode(level: Int): String { + if (level <= 0) return this + val prefix = " ".repeat(level * 4) + return this.lines().joinToString("\n") { line -> + if (line.isEmpty()) line else prefix + line + } + } + + private fun File.emit(indent: Int): String = elements.joinToString("") { it.emit(indent) }.removeEmptyLines() + + private fun String.removeEmptyLines(): String = lines().filterNot(String::isEmpty).joinToString("\n", postfix = "\n") + + private fun Element.emit(indent: Int, parents: List = emptyList(), isStaticScope: Boolean = false): String = when (this) { + is Package -> emit(indent) + is Import -> emit(indent) + is Struct -> emit(indent, parents) + is AstFunction -> { + val isInClass = parents.any { it is Struct || it is Interface || it is Namespace } + val isInInterface = parents.any { it is Interface } + emit(indent, isInClass = isInClass, isStaticScope = isStaticScope, isInInterface = isInInterface) + } + is Namespace -> emit(indent, parents) + is Interface -> emit(indent, parents) + is Union -> emit(indent, parents) + is Enum -> emit(indent) + is Main -> { + val staticContent = statics.joinToString("") { it.emit(indent, parents, isStaticScope) } + val content = body.joinToString("") { it.emit(1) } + if (isAsync) { + "${staticContent}fn main() {\n${"pollster::block_on(async {\n".indentCode(1)}$content${"});\n".indentCode(1)}}\n".indentCode(indent) + } else { + "${staticContent}fn main() {\n$content}\n".indentCode(indent) + } + } + is File -> elements.joinToString("") { it.emit(indent, parents, isStaticScope) } + is RawElement -> "$code\n".indentCode(indent) + } + + private fun Package.emit(indent: Int): String = "// package $path\n\n".indentCode(indent) + + private fun Import.emit(indent: Int): String = "use super::${Name.of(type.name).snakeCase()}::${type.name};\n".indentCode(indent) + + private fun Namespace.emit(indent: Int, parents: List = emptyList()): String { + val hasComplexElements = elements.any { it is Interface || it is Union || it is Enum || it is Struct } + val content = elements.joinToString("") { it.emit(1, parents = parents + this, isStaticScope = !hasComplexElements) } + val rustName = name.pascalCase() + return when { + content.isBlank() -> "pub struct $rustName;\n\n".indentCode(indent) + hasComplexElements -> { + val useSuper = "use super::*;\n".indentCode(1) + "pub mod $rustName {\n$useSuper$content}\n\n".indentCode(indent) + } + else -> "pub struct $rustName;\n\nimpl $rustName {\n$content}\n\n".indentCode(indent) + } + } + + private fun Interface.emit(indent: Int, parents: List = emptyList()): String { + val rustName = name.pascalCase() + val typeParamsStr = if (typeParameters.isNotEmpty()) "<${typeParameters.joinToString(", ") { it.emit() }}>" else "" + val extStr = if (extends.isNotEmpty()) ": ${extends.joinToString(" + ") { it.emit() }}" else "" + val fieldsContent = fields.joinToString("") { field -> + "fn ${field.name.snakeCase().sanitize()}(&self) -> ${field.type.emit()};\n".indentCode(1) + } + val elementsContent = elements.joinToString("") { it.emit(1, parents = parents + this, isStaticScope = false) } + val content = fieldsContent + elementsContent + return if (content.isEmpty()) { + "pub trait $rustName$typeParamsStr$extStr {}\n\n".indentCode(indent) + } else { + "pub trait $rustName$typeParamsStr$extStr {\n$content}\n\n".indentCode(indent) + } + } + + private fun Union.emit(indent: Int, parents: List = emptyList()): String { + val rustName = name.pascalCase() + val typeParamsStr = if (typeParameters.isNotEmpty()) "<${typeParameters.joinToString(", ") { it.emit() }}>" else "" + val enumDef = if (members.isNotEmpty()) { + val variants = members.joinToString("\n") { "${it.name}(${it.name}),".indentCode(1) } + "#[derive(Debug, Clone, PartialEq)]\npub enum $rustName$typeParamsStr {\n$variants\n}\n\n".indentCode(indent) + } else { + "#[derive(Debug, Clone, PartialEq)]\npub enum $rustName$typeParamsStr {}\n\n".indentCode(indent) + } + + return enumDef + } + + private fun Enum.emit(indent: Int): String { + val rustName = name.pascalCase() + val entriesStr = entries.joinToString("\n") { entry -> + "${entry.name.value()},".indentCode(1) + } + val enumDef = "#[derive(Debug, Clone, PartialEq)]\npub enum $rustName {\n$entriesStr\n}\n\n".indentCode(indent) + + val enumImpl = if (entries.isNotEmpty()) { + fun Enum.Entry.wireValue(): String = if (values.isNotEmpty()) values.first().removeSurrounding("\"") else name.value() + + val labelArms = entries.joinToString("\n") { entry -> + "$rustName::${entry.name.value()} => \"${entry.wireValue()}\",".indentCode(3) + } + val fromLabelArms = entries.joinToString("\n") { entry -> + "\"${entry.wireValue()}\" => Some($rustName::${entry.name.value()}),".indentCode(3) + } + """ + |impl Enum for $rustName { + | fn label(&self) -> &str { + | match self { + |$labelArms + | } + | } + | fn from_label(s: &str) -> Option { + | match s { + |$fromLabelArms + | _ => None, + | } + | } + |} + |impl std::fmt::Display for $rustName { + | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + | write!(f, "{}", self.label()) + | } + |} + | + """.trimMargin().indentCode(indent) + } else { + "" + } + + return enumDef + enumImpl + } + + private fun Struct.emit(indent: Int, parents: List = emptyList()): String { + val rustName = name.pascalCase() + val functions = elements.filterIsInstance() + val nonFunctions = elements.filterNot { it is AstFunction } + val nestedContent = nonFunctions.joinToString("") { it.emit(indent, parents = parents + this, isStaticScope = false) } + + val implBlock = if (functions.isNotEmpty()) { + val fnsContent = functions.joinToString("") { it.emit(1, isInClass = true, isStaticScope = false, isInInterface = false) } + "impl $rustName {\n$fnsContent}\n\n".indentCode(indent) + } else { + "" + } + + if (fields.isEmpty() && constructors.isEmpty()) { + val structDef = "pub struct $rustName;\n\n".indentCode(indent) + return "$structDef$implBlock$nestedContent" + } + + val fieldsStr = fields.joinToString("\n") { + val fieldName = it.name.snakeCase().sanitize() + "pub $fieldName: ${it.type.emit()},".indentCode(1) + } + val structDef = "pub struct $rustName {\n$fieldsStr\n}\n\n".indentCode(indent) + + val customConstructors = constructors.joinToString("") { it.emit(rustName, fields, indent) } + + return "$structDef$customConstructors$implBlock$nestedContent" + } + + private fun Constructor.emit(structName: String, structFields: List, indent: Int): String { + val params = parameters.joinToString(", ") { "${it.name.snakeCase().sanitize()}: ${it.type.emit()}" } + val assignments = body.filterIsInstance().associate { + it.name.value() to it.value.emit() + } + val fieldInits = structFields.joinToString(",\n") { field -> + val value = assignments[field.name.value()] ?: "Default::default()" + "${field.name.snakeCase().sanitize()}: $value".indentCode(1) + } + val constructorBody = "$structName {\n$fieldInits\n}".indentCode(1) + val fnBody = "pub fn new($params) -> Self {\n$constructorBody\n}".indentCode(1) + return "impl $structName {\n$fnBody\n}\n\n".indentCode(indent) + } + + private fun AstFunction.emit(indent: Int, isInClass: Boolean = false, isStaticScope: Boolean = false, isInInterface: Boolean = false): String { + val params = parameters.joinToString(", ") { + val paramName = it.name.value() + if (paramName == "self" || paramName == "&self") paramName else "${it.name.snakeCase().sanitize()}: ${it.type.emit()}" + } + val rType = returnType?.takeIf { it != Type.Unit }?.emit() + val returnTypeStr = if (rType != null) " -> $rType" else "" + val prefix = if (isInInterface && body.isEmpty()) "" else "pub " + val asyncPrefix = if (isAsync) "async " else "" + val content = if (body.isEmpty()) { + return "${prefix}${asyncPrefix}fn ${name.snakeCase().sanitize()}($params)$returnTypeStr;\n".indentCode(indent) + } else { + body.joinToString("") { it.emit(1) } + } + return "${prefix}${asyncPrefix}fn ${name.snakeCase().sanitize()}($params)$returnTypeStr {\n$content}\n\n".indentCode(indent) + } + + private fun TypeParameter.emit(): String { + val typeStr = type.emit() + return if (extends.isEmpty()) { + typeStr + } else { + "$typeStr: ${extends.joinToString(" + ") { it.emit() }}" + } + } + + private fun Type.emit(): String = when (this) { + is Type.Integer -> when (precision) { + Precision.P32 -> "i32" + Precision.P64 -> "i64" + } + is Type.Number -> when (precision) { + Precision.P32 -> "f32" + Precision.P64 -> "f64" + } + Type.Any -> "Box" + Type.String -> "String" + Type.Boolean -> "bool" + Type.Bytes -> "Vec" + Type.Unit -> "()" + Type.Wildcard -> "_" + Type.Reflect -> "std::any::TypeId" + is Type.Array -> "Vec<${elementType.emit()}>" + is Type.Dict -> "std::collections::HashMap<${keyType.emit()}, ${valueType.emit()}>" + is Type.Custom -> { + if (generics.isEmpty()) { + name + } else { + "$name<${generics.joinToString(", ") { it.emit() }}>" + } + } + is Type.Nullable -> "Option<${type.emit()}>" + is Type.IntegerLiteral -> "i32" + is Type.StringLiteral -> "String" + } + + private fun emitArrayIndex(receiver: Expression, index: Expression, caseSensitive: Boolean = true): String = when { + !caseSensitive && index is Literal && index.type is Type.String -> + "${receiver.emit()}.iter().find(|(k, _)| k.eq_ignore_ascii_case(\"${index.value}\")).map(|(_, v)| v.clone())" + index is Literal && index.type is Type.String -> "${receiver.emit()}.get(\"${index.value}\")" + index is Literal && (index.type is Type.Integer || index.type is Type.Number) -> "${receiver.emit()}[${index.value}]" + else -> "${receiver.emit()}[${index.emit()}]" + } + + private fun emitErrorMessage(message: Expression): String = when (message) { + is BinaryOp -> { + fun flattenPlus(expr: Expression): List = when { + expr is BinaryOp && expr.operator == BinaryOp.Operator.PLUS -> flattenPlus(expr.left) + flattenPlus(expr.right) + else -> listOf(expr) + } + val parts = flattenPlus(message) + val formatStr = parts.joinToString("") { + if (it is Literal && it.type is Type.String) it.value.toString() else "{}" + } + val args = parts.filterNot { it is Literal && it.type is Type.String }.map { it.emit() } + if (args.isEmpty()) "\"$formatStr\"" else "\"$formatStr\", ${args.joinToString(", ")}" + } + is Literal -> when { + message.type is Type.String -> "\"${message.value}\"" + else -> "\"{}\", ${message.emit()}" + } + else -> "\"{}\", ${message.emit()}" + } + + private fun emitUnwrap(alternative: Expression?): String = when (alternative) { + is ErrorStatement -> { + val msg = alternative.message + if (msg is Literal && msg.type is Type.String) { + ".expect(\"${msg.value}\")" + } else { + ".unwrap_or_else(|| ${alternative.emit()})" + } + } + null -> "" + else -> ".unwrap_or(${alternative.emit()})" + } + + private fun Statement.emit(indent: Int): String = when (this) { + is PrintStatement -> "println!(\"{}\", ${expression.emit()});\n".indentCode(indent) + is ReturnStatement -> "return ${expression.emit()};\n".indentCode(indent) + is ConstructorStatement -> { + if (type == Type.Unit) { + "()\n".indentCode(indent) + } else { + val allArgs = namedArguments.map { "${it.key.snakeCase().sanitize()}: ${it.value.emit()}" } + val argsStr = when { + allArgs.isEmpty() -> " {}" + else -> " {\n${allArgs.joinToString(",\n") { it.indentCode(1) }},\n}" + } + "${type.emit()}$argsStr\n".indentCode(indent) + } + } + is Literal, is LiteralList, is LiteralMap -> "${emit()};\n".indentCode(indent) + is Assignment -> { + val expr = value.emit() + "let ${name.snakeCase().sanitize()} = $expr;\n".indentCode(indent) + } + is ErrorStatement -> "panic!(${emitErrorMessage(message)});\n".indentCode(indent) + is AssertStatement -> "assert!(${expression.emit()}, \"$message\");\n".indentCode(indent) + is Switch -> { + val casesStr = if (cases.any { it.type != null }) { + cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(1) } + val typeStr = case.type?.emit() ?: "_" + val varBinding = variable?.snakeCase() ?: "_" + "$typeStr($varBinding) => {\n$bodyStr}\n".indentCode(1) + } + } else { + cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(1) } + "${case.value.emit()} => {\n$bodyStr}\n".indentCode(1) + } + } + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(1) } + "_ => {\n$bodyStr}\n".indentCode(1) + } ?: "" + "match ${expression.emit()} {\n$casesStr$defaultStr}\n".indentCode(indent) + } + is RawExpression -> "$code;\n".indentCode(indent) + is NullLiteral, is NullableEmpty -> "None;\n".indentCode(indent) + is VariableReference -> "${name.snakeCase().sanitize()};\n".indentCode(indent) + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${field.snakeCase().sanitize()};\n".indentCode(indent) + } + is FunctionCall -> { + val recv = receiver + val funcName = if (name.value().contains("::")) name.value() else name.snakeCase().sanitize() + val awaitSuffix = if (isAwait) ".await" else "" + if (recv != null) { + "${recv.emit()}.$funcName(${arguments.values.joinToString(", ") { it.emit() }})$awaitSuffix;\n".indentCode(indent) + } else { + "$funcName(${arguments.values.joinToString(", ") { it.emit() }})$awaitSuffix;\n".indentCode(indent) + } + } + is ArrayIndexCall -> "${emitArrayIndex(receiver, index, caseSensitive)};\n".indentCode(indent) + is EnumReference -> "${enumType.emit()}::${entry.value()};\n".indentCode(indent) + is EnumValueCall -> "format!(\"{:?}\", ${expression.emit()});\n".indentCode(indent) + is BinaryOp -> "(${left.emit()} ${operator.toRust()} ${right.emit()});\n".indentCode(indent) + is TypeDescriptor -> "std::any::TypeId::of::<${type.emit()}>();\n".indentCode(indent) + is NullCheck, is NullableMap, is NullableOf, is NullableGet, + is Constraint.RegexMatch, is Constraint.BoundCheck, + is IfExpression, is MapExpression, is FlatMapIndexed, + is ListConcat, is StringTemplate, + -> "${emit()};\n".indentCode(indent) + is NotExpression -> "!${expression.emit()};\n".indentCode(indent) + } + + private fun BinaryOp.Operator.toRust(): String = when (this) { + BinaryOp.Operator.PLUS -> "+" + BinaryOp.Operator.EQUALS -> "==" + BinaryOp.Operator.NOT_EQUALS -> "!=" + } + + private fun Expression.emit(): String = when (this) { + is ConstructorStatement -> { + if (type == Type.Unit) { + "()" + } else { + val allArgs = namedArguments.map { "${it.key.snakeCase().sanitize()}: ${it.value.emit()}" } + val argsStr = when { + allArgs.isEmpty() -> " {}" + else -> " { ${allArgs.joinToString(", ")} }" + } + "${type.emit()}$argsStr" + } + } + is Literal -> emit() + is LiteralList -> emit() + is LiteralMap -> emit() + is RawExpression -> code + is NullLiteral, is NullableEmpty -> "None" + is VariableReference -> name.snakeCase().sanitize() + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${field.snakeCase().sanitize()}" + } + is FunctionCall -> { + val recv = receiver + val funcName = if (name.value().contains("::")) name.value() else name.snakeCase().sanitize() + val awaitSuffix = if (isAwait) ".await" else "" + if (recv != null) { + "${recv.emit()}.$funcName(${arguments.values.joinToString(", ") { it.emit() }})$awaitSuffix" + } else { + "$funcName(${arguments.values.joinToString(", ") { it.emit() }})$awaitSuffix" + } + } + is ArrayIndexCall -> emitArrayIndex(receiver, index, caseSensitive) + is EnumReference -> "${enumType.emit()}::${entry.value()}" + is EnumValueCall -> "format!(\"{:?}\", ${expression.emit()})" + is BinaryOp -> "(${left.emit()} ${operator.toRust()} ${right.emit()})" + is TypeDescriptor -> "std::any::TypeId::of::<${type.emit()}>()" + is NullCheck -> { + val exprStr = expression.emit() + val bodyStr = body.emit() + "$exprStr.as_ref().map(|it| $bodyStr)${emitUnwrap(alternative)}" + } + is NullableMap -> { + val exprStr = expression.emit() + val bodyStr = body.emit() + "$exprStr.as_ref().map(|it| $bodyStr)${emitUnwrap(alternative)}" + } + is NullableOf -> "Some(${expression.emit()})" + is NullableGet -> "${expression.emit()}.unwrap()" + is Constraint.RegexMatch -> "regex::Regex::new(r\"$pattern\").unwrap().is_match(&${value.emit()})" + is Constraint.BoundCheck -> listOfNotNull( + min?.let { "$it <= ${value.emit()}" }, + max?.let { "${value.emit()} <= $it" }, + ).joinToString(" && ").ifEmpty { "true" } + is ErrorStatement -> "panic!(${emitErrorMessage(message)})" + is AssertStatement -> throw IllegalArgumentException("AssertStatement cannot be an expression in Rust") + is Switch -> throw IllegalArgumentException("Switch cannot be an expression in Rust") + is Assignment -> throw IllegalArgumentException("Assignment cannot be an expression in Rust") + is PrintStatement -> throw IllegalArgumentException("PrintStatement cannot be an expression in Rust") + is ReturnStatement -> throw IllegalArgumentException("ReturnStatement cannot be an expression in Rust") + is NotExpression -> "!${expression.emit()}" + is IfExpression -> "if ${condition.emit()} { ${thenExpr.emit()} } else { ${elseExpr.emit()} }" + is MapExpression -> "${receiver.emit()}.iter().map(|${variable.snakeCase()}| ${body.emit()}).collect::>()" + is FlatMapIndexed -> "${receiver.emit()}.iter().enumerate().flat_map(|(${indexVar.snakeCase()}, ${elementVar.snakeCase()})| ${body.emit()}).collect::>()" + is ListConcat -> when { + lists.isEmpty() -> "vec![]" + lists.size == 1 -> lists.single().emit() + else -> "vec![${lists.joinToString(", ") { "${it.emit()}.as_slice()" }}].concat()" + } + is StringTemplate -> { + val mapped = parts.map { part -> + when (part) { + is StringTemplate.Part.Text -> part.value to null + is StringTemplate.Part.Expr -> "{}" to part.expression.emit() + } + } + val formatStr = mapped.joinToString("") { it.first } + val args = mapped.mapNotNull { it.second } + if (args.isEmpty()) { + "String::from(\"$formatStr\")" + } else { + "format!(\"$formatStr\", ${args.joinToString(", ")})" + } + } + } + + private fun Expression.emitWithInlinedIt(replacement: String): String = when (this) { + is VariableReference -> if (name.value() == "it") replacement else emit() + is FunctionCall -> { + val recv = receiver + val funcName = if (name.value().contains("::")) name.value() else name.snakeCase().sanitize() + val awaitSuffix = if (isAwait) ".await" else "" + if (recv != null) { + "${recv.emitWithInlinedIt(replacement)}.$funcName(${arguments.values.joinToString(", ") { it.emitWithInlinedIt(replacement) }})$awaitSuffix" + } else { + "$funcName(${arguments.values.joinToString(", ") { it.emitWithInlinedIt(replacement) }})$awaitSuffix" + } + } + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emitWithInlinedIt(replacement)}." } ?: "" + "$receiverStr${field.snakeCase().sanitize()}" + } + is ArrayIndexCall -> { + if (!caseSensitive && index is Literal && index.type is Type.String) { + "${receiver.emitWithInlinedIt(replacement)}.iter().find(|(k, _)| k.eq_ignore_ascii_case(\"${(index as Literal).value}\")).map(|(_, v)| v.clone())" + } else if (index is Literal && index.type is Type.String) { + "${receiver.emitWithInlinedIt(replacement)}.get(\"${(index as Literal).value}\")" + } else { + val idxStr = if (index is Literal && (index.type is Type.Integer || index.type is Type.Number)) { + "${(index as Literal).value}" + } else { + index.emitWithInlinedIt(replacement) + } + "${receiver.emitWithInlinedIt(replacement)}[$idxStr]" + } + } + is EnumValueCall -> "format!(\"{:?}\", ${expression.emitWithInlinedIt(replacement)})" + is NotExpression -> "!${expression.emitWithInlinedIt(replacement)}" + is IfExpression -> "if ${condition.emitWithInlinedIt(replacement)} { ${thenExpr.emitWithInlinedIt(replacement)} } else { ${elseExpr.emitWithInlinedIt(replacement)} }" + is MapExpression -> "${receiver.emitWithInlinedIt(replacement)}.iter().map(|${variable.snakeCase()}| ${body.emitWithInlinedIt(replacement)}).collect::>()" + is FlatMapIndexed -> "${receiver.emitWithInlinedIt(replacement)}.iter().enumerate().flat_map(|(${indexVar.snakeCase()}, ${elementVar.snakeCase()})| ${body.emitWithInlinedIt(replacement)}).collect::>()" + is ListConcat -> when { + lists.isEmpty() -> "vec![]" + lists.size == 1 -> lists.single().emitWithInlinedIt(replacement) + else -> "vec![${lists.joinToString(", ") { "${it.emitWithInlinedIt(replacement)}.as_slice()" }}].concat()" + } + is StringTemplate -> { + val mapped = parts.map { part -> + when (part) { + is StringTemplate.Part.Text -> part.value to null + is StringTemplate.Part.Expr -> "{}" to part.expression.emitWithInlinedIt(replacement) + } + } + val formatStr = mapped.joinToString("") { it.first } + val args = mapped.mapNotNull { it.second } + if (args.isEmpty()) { + "String::from(\"$formatStr\")" + } else { + "format!(\"$formatStr\", ${args.joinToString(", ")})" + } + } + else -> emit() + } + + private fun LiteralList.emit(): String { + if (values.isEmpty()) return "Vec::<${type.emit()}>::new()" + val list = values.joinToString(", ") { it.emit() } + return "vec![$list]" + } + + private fun LiteralMap.emit(): String { + if (values.isEmpty()) return "std::collections::HashMap::new()" + val map = values.entries.joinToString(", ") { + "(${Literal(it.key, keyType).emit()}, ${it.value.emit()})" + } + return "std::collections::HashMap::from([$map])" + } + + private fun Literal.emit(): String = when (type) { + is Type.String -> "String::from(\"$value\")" + is Type.Number -> "${value}_${when (type.precision) { + Precision.P32 -> "f32" + Precision.P64 -> "f64" + }}" + is Type.Integer -> "${value}_${when (type.precision) { + Precision.P32 -> "i32" + Precision.P64 -> "i64" + }}" + Type.Boolean -> value.toString() + else -> value.toString() + } +} + +private fun String.sanitize(): String = if (this in reservedKeywords) "r#$this" else this + +private val reservedKeywords = setOf( + "as", "break", "const", "continue", "crate", + "else", "enum", "extern", "false", "fn", + "for", "if", "impl", "in", "let", + "loop", "match", "mod", "move", "mut", + "pub", "ref", "return", + "static", "struct", "super", "trait", "true", + "type", "unsafe", "use", "where", "while", + "async", "await", "dyn", "abstract", "become", + "box", "do", "final", "macro", "override", + "priv", "typeof", "unsized", "virtual", "yield", + "try", +) diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/ScalaGenerator.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/ScalaGenerator.kt new file mode 100644 index 000000000..319c3e35e --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/ScalaGenerator.kt @@ -0,0 +1,660 @@ +package community.flock.wirespec.ir.generator + +import community.flock.wirespec.ir.core.ArrayIndexCall +import community.flock.wirespec.ir.core.AssertStatement +import community.flock.wirespec.ir.core.Assignment +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.ir.core.Constraint +import community.flock.wirespec.ir.core.Constructor +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Element +import community.flock.wirespec.ir.core.Enum +import community.flock.wirespec.ir.core.EnumReference +import community.flock.wirespec.ir.core.EnumValueCall +import community.flock.wirespec.ir.core.ErrorStatement +import community.flock.wirespec.ir.core.Expression +import community.flock.wirespec.ir.core.Field +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.FlatMapIndexed +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.IfExpression +import community.flock.wirespec.ir.core.Import +import community.flock.wirespec.ir.core.Interface +import community.flock.wirespec.ir.core.ListConcat +import community.flock.wirespec.ir.core.Literal +import community.flock.wirespec.ir.core.LiteralList +import community.flock.wirespec.ir.core.LiteralMap +import community.flock.wirespec.ir.core.Main +import community.flock.wirespec.ir.core.MapExpression +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Namespace +import community.flock.wirespec.ir.core.NotExpression +import community.flock.wirespec.ir.core.NullCheck +import community.flock.wirespec.ir.core.NullLiteral +import community.flock.wirespec.ir.core.NullableEmpty +import community.flock.wirespec.ir.core.NullableGet +import community.flock.wirespec.ir.core.NullableMap +import community.flock.wirespec.ir.core.NullableOf +import community.flock.wirespec.ir.core.Package +import community.flock.wirespec.ir.core.Parameter +import community.flock.wirespec.ir.core.Precision +import community.flock.wirespec.ir.core.PrintStatement +import community.flock.wirespec.ir.core.RawElement +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.ReturnStatement +import community.flock.wirespec.ir.core.Statement +import community.flock.wirespec.ir.core.StringTemplate +import community.flock.wirespec.ir.core.Struct +import community.flock.wirespec.ir.core.Switch +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.TypeDescriptor +import community.flock.wirespec.ir.core.TypeParameter +import community.flock.wirespec.ir.core.Union +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.Function as AstFunction + +object ScalaGenerator : Generator { + override fun generate(element: Element): String = when (element) { + is File -> { + val emitter = ScalaEmitter(element) + emitter.emitFile() + } + + else -> { + val emitter = ScalaEmitter(File(Name.of(""), listOf(element))) + emitter.emitFile() + } + } +} + +private class +ScalaEmitter( + val file: File, +) { + private val objectNames = collectObjectNames(file.elements) + private val primaryFieldNames = collectPrimaryFieldNames(file.elements) + + private fun collectObjectNames(elements: List): Set { + val names = mutableSetOf() + for (element in elements) { + when (element) { + is Struct -> { + val isObject = (element.constructors.size == 1 && element.constructors.single().parameters.isEmpty()) || + (element.fields.isEmpty() && element.constructors.isEmpty()) + if (isObject && !element.isModelStruct()) names.add(element.name.pascalCase()) + names.addAll(collectObjectNames(element.elements)) + } + is Namespace -> names.addAll(collectObjectNames(element.elements)) + is Interface -> names.addAll(collectObjectNames(element.elements)) + else -> {} + } + } + return names + } + + private fun Struct.isModelStruct(): Boolean = interfaces.any { it.name == "Wirespec.Model" } + + private fun collectPrimaryFieldNames(elements: List): Map> { + val result = mutableMapOf>() + for (element in elements) { + when (element) { + is Struct -> { + result[element.name.pascalCase()] = element.fields.map { it.name.value() }.toSet() + result.putAll(collectPrimaryFieldNames(element.elements)) + } + is Namespace -> result.putAll(collectPrimaryFieldNames(element.elements)) + is Interface -> result.putAll(collectPrimaryFieldNames(element.elements)) + else -> {} + } + } + return result + } + + private fun ConstructorStatement.needsNew(): Boolean { + val typeName = (type as? Type.Custom)?.name ?: return false + if (namedArguments.isEmpty()) return false + // Default to true for unknown types - `new CaseClass(args)` is always valid in Scala + val primaryFields = primaryFieldNames[typeName] ?: return true + val argNames = namedArguments.keys.map { it.value() }.toSet() + return argNames != primaryFields + } + + fun emitFile(): String { + val packages = file.elements.filterIsInstance() + val imports = file.elements.filterIsInstance() + val otherElements = file.elements.filter { it !is Package && it !is Import } + + val packagesStr = packages.joinToString("") { it.emit(0) } + val importsStr = imports.joinToString("") { it.emit(0) } + val elementsStr = otherElements.joinToString("") { it.emit(0, parents = emptyList()) } + + return "$packagesStr\n$importsStr\n$elementsStr".removeEmptyLines() + } + + private fun String.removeEmptyLines(): String = lines().filter { it.isNotEmpty() }.joinToString("\n").plus("\n") + + private fun String.indentCode(level: Int): String { + if (level <= 0) return this + val prefix = " ".repeat(level * 2) + return this.lines().joinToString("\n") { line -> + if (line.isEmpty()) line else prefix + line + } + } + + private fun Element.emit(indent: Int, isStatic: Boolean = false, parents: List): String = when (this) { + is Package -> emit(indent) + is Import -> emit(indent) + is Struct -> emit(indent, parents) + is AstFunction -> emit(indent, parents) + is Namespace -> emit(indent, parents) + is Interface -> emit(indent, parents) + is Union -> emit(indent, parents) + is Enum -> emit(indent) + is Main -> { + val staticContent = statics.joinToString("") { it.emit(1, false, parents) } + val content = body.joinToString("") { it.emit(1) } + "object ${file.name.pascalCase()} {\n$staticContent def main(args: Array[String]): Unit = {\n$content }\n}\n\n".indentCode(indent) + } + is File -> elements.joinToString("") { it.emit(indent, isStatic, parents) } + is RawElement -> "$code\n".indentCode(indent) + } + + private fun Package.emit(indent: Int): String = "package $path\n\n".indentCode(indent) + + private fun Import.emit(indent: Int): String = "import $path.${type.name}\n".indentCode(indent) + + private fun Namespace.emit(indent: Int, parents: List): String { + val extStr = extends?.let { " extends ${it.emitTypeAnnotation()}" } ?: "" + val content = elements.joinToString("") { it.emit(indent + 1, isStatic = true, parents = parents + this) } + return "object ${name.pascalCase()}$extStr {\n$content${"}".indentCode(0)}\n\n".indentCode(indent) + } + + private fun Interface.emit(indent: Int, parents: List): String { + val sealedStr = if (isSealed) "sealed " else "" + val typeParamsStr = + if (typeParameters.isNotEmpty()) "[${typeParameters.joinToString(", ") { it.emit() }}]" else "" + val extStr = if (extends.isNotEmpty()) " extends ${extends.joinToString(" with ") { it.emitTypeAnnotation() }}" else "" + val fieldsContent = fields.joinToString("") { field -> + val overridePrefix = if (field.isOverride) "override " else "" + "${overridePrefix}def ${field.name.value()}: ${field.type.emitTypeAnnotation()}\n".indentCode(indent + 1) + } + val elementsContent = elements.joinToString("") { it.emit(indent + 1, isStatic = false, parents = parents + this) } + val content = fieldsContent + elementsContent + return if (content.isEmpty()) { + "${sealedStr}trait ${name.pascalCase()}$typeParamsStr$extStr\n\n".indentCode(indent) + } else { + "${sealedStr}trait ${name.pascalCase()}$typeParamsStr$extStr {\n$content${"}".indentCode(0)}\n\n".indentCode(indent) + } + } + + private fun Union.emit(indent: Int, parents: List): String { + val typeParamsStr = if (typeParameters.isNotEmpty()) "[${typeParameters.joinToString(", ") { it.emit() }}]" else "" + val extStr = extends?.let { " extends ${it.emitTypeAnnotation()}" } ?: "" + return "sealed trait ${name.pascalCase()}$typeParamsStr$extStr\n\n".indentCode(indent) + } + + private fun Enum.emit(indent: Int): String { + val implStr = extends?.let { " extends ${it.emitGenerics()}" } ?: "" + + val hasFields = fields.isNotEmpty() + + if (hasFields) { + val fieldsStr = fields.joinToString(", ") { "${if (it.isOverride) "override " else ""}val ${it.name.value()}: ${it.type.emitGenerics()}" } + val entriesStr = entries.joinToString(",\n") { entry -> + val e = if (entry.values.isEmpty()) { + "case ${entry.name.value()}" + } else { + "case ${entry.name.value()} extends ${name.pascalCase()}(${entry.values.joinToString(", ")})" + } + e.indentCode(indent + 1) + } + val functionsStr = elements.filterIsInstance().joinToString("\n") { + val overridePrefix = if (it.isOverride || it.name.camelCase() == "toString") "override " else "" + it.emitAsMethod(indent + 1, overridePrefix) + } + val content = listOf(entriesStr, functionsStr).filter { it.isNotEmpty() }.joinToString("\n") + return "enum ${name.pascalCase()}($fieldsStr)$implStr {\n$content\n${"}".indentCode(indent)}\n\n".indentCode(indent) + } + + val entriesStr = entries.joinToString("\n") { entry -> + "case ${entry.name.value()}".indentCode(indent + 1) + } + val functionsStr = elements.filterIsInstance().joinToString("\n") { + val overridePrefix = if (it.isOverride || it.name.camelCase() == "toString") "override " else "" + it.emitAsMethod(indent + 1, overridePrefix) + } + val content = listOf(entriesStr, functionsStr).filter { it.isNotEmpty() }.joinToString("\n") + return "enum ${name.pascalCase()}$implStr {\n$content\n${"}".indentCode(indent)}\n\n".indentCode(indent) + } + + private fun AstFunction.emitAsMethod(indent: Int, prefix: String): String { + val rType = returnType?.takeIf { it != Type.Unit }?.emitTypeAnnotation() ?: "Unit" + val params = parameters.joinToString(", ") { it.emit(0) } + return if (body.isEmpty()) { + "${prefix}def ${name.camelCase()}($params): $rType\n".indentCode(indent) + } else { + val content = body.joinToString("") { it.emit(1) } + "${prefix}def ${name.camelCase()}($params): $rType = {\n$content${"}".indentCode(0)}\n".indentCode(indent) + } + } + + private fun Struct.emit(indent: Int, parents: List): String { + val implStr = if (interfaces.isEmpty()) "" else " extends ${interfaces.map { it.emitTypeAnnotation() }.distinct().joinToString(" with ")}" + + val nestedContent = elements.joinToString("") { it.emit(indent + 1, isStatic = true, parents = parents + this) } + val customConstructors = constructors.joinToString("") { it.emitScala(fields, indent + 1) } + + if (constructors.size == 1 && constructors.single().parameters.isEmpty()) { + if (isModelStruct()) { + val bodyContent = listOf(nestedContent).filter { it.isNotEmpty() }.joinToString("\n") + return if (bodyContent.isNotEmpty()) { + "case class ${name.pascalCase()}()$implStr {\n$bodyContent${"}".indentCode(indent)}\n\n".indentCode(indent) + } else { + "case class ${name.pascalCase()}()$implStr\n\n".indentCode(indent) + } + } + val constructor = constructors.single() + val assignments = constructor.body.filterIsInstance() + val fieldProperties = fields.joinToString("\n") { field -> + val assignment = assignments.find { it.name.camelCase() == field.name.value() } + val valueStr = assignment?.let { " = ${it.value.emit()}" } ?: "" + "${if (field.isOverride) "override " else ""}val ${field.name.value().sanitize()}: ${field.type.emitTypeAnnotation()}$valueStr".indentCode(indent + 1) + } + val bodyContent = listOf(fieldProperties, nestedContent).filter { it.isNotEmpty() }.joinToString("\n") + return if (bodyContent.isNotEmpty()) { + "object ${name.pascalCase()}$implStr {\n$bodyContent${"}".indentCode(indent)}\n\n".indentCode(indent) + } else { + "object ${name.pascalCase()}$implStr\n\n".indentCode(indent) + } + } + + if (fields.isEmpty() && constructors.isEmpty()) { + return if (nestedContent.isNotEmpty()) { + "object ${name.pascalCase()}$implStr {\n$nestedContent${"}".indentCode(indent)}\n\n".indentCode(indent) + } else { + "object ${name.pascalCase()}$implStr\n\n".indentCode(indent) + } + } + + val params = fields.joinToString(",\n") { + "${if (it.isOverride) "override " else ""}val ${it.name.value().sanitize()}: ${it.type.emitTypeAnnotation()}".indentCode(indent + 1) + } + val paramsStr = if (fields.isEmpty()) "" else "(\n$params\n${")".indentCode(indent)}" + + val hasBody = customConstructors.isNotEmpty() || nestedContent.isNotEmpty() + + return if (hasBody) { + "case class ${name.pascalCase()}$paramsStr$implStr {\n$customConstructors$nestedContent${"}".indentCode(indent)}\n\n".indentCode(indent) + } else { + "case class ${name.pascalCase()}$paramsStr$implStr\n\n".indentCode(indent) + } + } + + private fun Constructor.emitScala(structFields: List, indent: Int): String { + val params = parameters.joinToString(", ") { it.emit(0) } + val isDelegating = body.any { it is ConstructorStatement } + + if (isDelegating) { + val delegationStmt = body.filterIsInstance().first() + val delegationArgs = delegationStmt.namedArguments.map { "${it.key.value()} = ${it.value.emit()}" } + val delegationStr = "this(${delegationArgs.joinToString(", ")})" + return "def this($params) = $delegationStr\n".indentCode(indent) + } + + val assignments = body.filterIsInstance().associate { + it.name.value() to it.value.emit() + } + val constructorArgs = structFields.map { field -> + assignments[field.name.value()] ?: "null" + } + + return "def this($params) = this(${constructorArgs.joinToString(", ")})\n".indentCode(indent) + } + + private fun AstFunction.emit(indent: Int, parents: List): String { + val overridePrefix = if (isOverride) "override " else "" + val typeParamsStr = if (typeParameters.isNotEmpty()) { + "[${typeParameters.joinToString(", ") { it.emit() }}]" + } else { + "" + } + val rType = returnType?.takeIf { it != Type.Unit }?.emitTypeAnnotation() ?: "Unit" + val returnTypeStr = ": $rType" + val params = parameters.joinToString(", ") { it.emit(0) } + + return if (body.isEmpty()) { + "${overridePrefix}def ${name.camelCase()}$typeParamsStr($params)$returnTypeStr\n".indentCode(indent) + } else if (body.size == 1 && body.first() is ReturnStatement) { + val expr = (body.first() as ReturnStatement).expression.emit() + "${overridePrefix}def ${name.camelCase()}$typeParamsStr($params)$returnTypeStr =\n${expr.indentCode(1)}\n\n".indentCode(indent) + } else { + val content = body.joinToString("") { it.emit(1) } + "${overridePrefix}def ${name.camelCase()}$typeParamsStr($params)$returnTypeStr = {\n$content${"}".indentCode(0)}\n\n".indentCode(indent) + } + } + + private fun Parameter.emit(indent: Int): String = "${name.camelCase().sanitize()}: ${type.emitTypeAnnotation()}".indentCode(indent) + + private fun TypeParameter.emit(): String { + val typeStr = type.emitGenerics() + return if (extends.isEmpty()) { + typeStr + } else { + "$typeStr <: ${extends.joinToString(" with ") { it.emitGenerics() }}" + } + } + + private fun Type.emit(): String = when (this) { + is Type.Integer -> when (precision) { + Precision.P32 -> "Int" + Precision.P64 -> "Long" + } + + is Type.Number -> when (precision) { + Precision.P32 -> "Float" + Precision.P64 -> "Double" + } + + Type.Any -> "Any" + Type.String -> "String" + Type.Bytes -> "Array[Byte]" + Type.Boolean -> "Boolean" + Type.Unit -> "Unit" + Type.Wildcard -> "?" + Type.Reflect -> "scala.reflect.ClassTag[?]" + is Type.Array -> "List" + is Type.Dict -> "Map" + is Type.Custom -> name + is Type.Nullable -> "Option[${type.emitGenerics()}]" + is Type.IntegerLiteral -> "Int" + is Type.StringLiteral -> "String" + } + + private fun Type.emitGenerics(): String = when (this) { + is Type.Array -> "${emit()}[${elementType.emitGenerics()}]" + is Type.Dict -> "${emit()}[${keyType.emitGenerics()}, ${valueType.emitGenerics()}]" + is Type.Custom -> { + if (generics.isEmpty()) { + emit() + } else { + "${emit()}[${generics.joinToString(", ") { it.emitGenerics() }}]" + } + } + + is Type.Nullable -> "Option[${type.emitGenerics()}]" + else -> emit() + } + + // Emit type for use in type annotation positions (field types, parameter types, return types). + // Adds .type suffix for Scala singleton object types and recurses into generics. + private fun Type.emitTypeAnnotation(): String = when (this) { + is Type.Array -> "List[${elementType.emitTypeAnnotation()}]" + is Type.Dict -> "Map[${keyType.emitTypeAnnotation()}, ${valueType.emitTypeAnnotation()}]" + is Type.Custom -> { + if (generics.isEmpty()) { + if (name in objectNames) "$name.type" else name + } else { + "$name[${generics.joinToString(", ") { it.emitTypeAnnotation() }}]" + } + } + is Type.Nullable -> "Option[${type.emitTypeAnnotation()}]" + else -> emit() + } + + private fun Statement.emit(indent: Int): String = when (this) { + is PrintStatement -> "println(${expression.emit()})\n".indentCode(indent) + is ReturnStatement -> "${expression.emit()}\n".indentCode(indent) + is ConstructorStatement -> { + val allArgs = namedArguments.map { "${it.key.value()} = ${it.value.emit()}" } + val argsStr = when { + allArgs.isEmpty() -> "" + allArgs.size == 1 -> "(${allArgs.first()})" + else -> "(\n${allArgs.joinToString(",\n") { it.indentCode(1) }}\n)" + } + val prefix = if (needsNew()) "new " else "" + "$prefix${type.emitGenerics()}$argsStr\n".indentCode(indent) + } + + is Literal -> "${emit()}\n".indentCode(indent) + is LiteralList -> "${emit()}\n".indentCode(indent) + is LiteralMap -> "${emit()}\n".indentCode(indent) + is Assignment -> { + val expr = (value as? ConstructorStatement)?.let { constructorStmt -> + val allArgs = constructorStmt.namedArguments.map { "${it.key.value()} = ${it.value.emit()}" } + val argsStr = when { + allArgs.isEmpty() -> "" + allArgs.size == 1 -> "(${allArgs.first()})" + else -> "(\n${allArgs.joinToString(",\n") { it.indentCode(1) }}\n)" + } + val prefix = if (constructorStmt.needsNew()) "new " else "" + "$prefix${constructorStmt.type.emitGenerics()}$argsStr" + } ?: value.emit() + if (isProperty) { + "${name.value().sanitize()} = $expr\n".indentCode(indent) + } else { + "val ${name.camelCase().sanitize()} = $expr\n".indentCode(indent) + } + } + + is ErrorStatement -> "throw new IllegalStateException(${message.emit()})\n".indentCode(indent) + is AssertStatement -> "assert(${expression.emit()}, \"$message\")\n".indentCode(indent) + is Switch -> { + val isPatternSwitch = cases.any { it.type != null } + if (isPatternSwitch) { + val casesStr = cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(1) } + val typeStr = case.type?.emitGenerics() ?: "Any" + val varName = variable?.camelCase() ?: "_" + "case $varName: $typeStr => {\n$bodyStr}\n".indentCode(indent + 1) + } + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(1) } + "case _ => {\n$bodyStr}\n".indentCode(indent + 1) + } ?: "" + "${expression.emit()} match {\n$casesStr$defaultStr}\n".indentCode(indent) + } else { + val casesStr = cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(1) } + "case ${case.value.emit()} => {\n$bodyStr}\n".indentCode(indent + 1) + } + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(1) } + "case _ => {\n$bodyStr}\n".indentCode(indent + 1) + } ?: "" + "${expression.emit()} match {\n$casesStr$defaultStr}\n".indentCode(indent) + } + } + + is RawExpression -> "$code\n".indentCode(indent) + is NullLiteral -> "null\n".indentCode(indent) + is NullableEmpty -> "None\n".indentCode(indent) + is VariableReference -> "${name.camelCase().sanitize()}\n".indentCode(indent) + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${field.value().sanitize()}\n".indentCode(indent) + } + + is FunctionCall -> { + val typeArgsStr = + if (typeArguments.isNotEmpty()) "[${typeArguments.joinToString(", ") { it.emitGenerics() }}]" else "" + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${name.value().sanitize()}$typeArgsStr(${arguments.values.joinToString(", ") { it.emit() }})\n".indentCode(indent) + } + + is ArrayIndexCall -> "${emitArrayIndex()}\n".indentCode(indent) + + is EnumReference -> "${enumType.emitGenerics()}.${entry.value()}\n".indentCode(indent) + is EnumValueCall -> "${expression.emit()}.toString\n".indentCode(indent) + is BinaryOp -> "(${left.emit()} ${operator.toScala()} ${right.emit()})\n".indentCode(indent) + is TypeDescriptor -> "${emitTypeDescriptor()}\n".indentCode(indent) + is NullCheck -> "${emit()}\n".indentCode(indent) + is NullableMap -> "${emit()}\n".indentCode(indent) + is NullableOf -> "${emit()}\n".indentCode(indent) + is NullableGet -> "${emit()}\n".indentCode(indent) + is Constraint.RegexMatch -> "${emit()}\n".indentCode(indent) + is Constraint.BoundCheck -> "${emit()}\n".indentCode(indent) + is NotExpression -> "!${expression.emit()}\n".indentCode(indent) + is IfExpression -> "${emit()}\n".indentCode(indent) + is MapExpression -> "${emit()}\n".indentCode(indent) + is FlatMapIndexed -> "${emit()}\n".indentCode(indent) + is ListConcat -> "${emit()}\n".indentCode(indent) + is StringTemplate -> "${emit()}\n".indentCode(indent) + } + + private fun BinaryOp.Operator.toScala(): String = when (this) { + BinaryOp.Operator.PLUS -> "+" + BinaryOp.Operator.EQUALS -> "==" + BinaryOp.Operator.NOT_EQUALS -> "!=" + } + + private fun Expression.emit(): String = when (this) { + is ConstructorStatement -> { + if (type == Type.Unit) { + "()" + } else { + val allArgs = namedArguments.map { "${it.key.value()} = ${it.value.emit()}" } + val argsStr = when { + allArgs.isEmpty() -> "" + allArgs.size == 1 -> "(${allArgs.first()})" + else -> "(\n${allArgs.joinToString(",\n") { it.indentCode(1) }}\n)" + } + val prefix = if (needsNew()) "new " else "" + "$prefix${type.emitGenerics()}$argsStr" + } + } + + is Literal -> emit() + is LiteralList -> emit() + is LiteralMap -> emit() + is RawExpression -> code + is NullLiteral -> "null" + is NullableEmpty -> "None" + is VariableReference -> name.camelCase().sanitize() + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${field.value().sanitize()}" + } + + is FunctionCall -> { + val typeArgsStr = + if (typeArguments.isNotEmpty()) "[${typeArguments.joinToString(", ") { it.emitGenerics() }}]" else "" + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${name.value().sanitize()}$typeArgsStr(${arguments.values.joinToString(", ") { it.emit() }})" + } + + is ArrayIndexCall -> emitArrayIndex() + + is EnumReference -> "${enumType.emitGenerics()}.${entry.value()}" + is EnumValueCall -> "${expression.emit()}.toString" + is BinaryOp -> "(${left.emit()} ${operator.toScala()} ${right.emit()})" + is TypeDescriptor -> emitTypeDescriptor() + is NullCheck -> "(${expression.emit()}.map(it => ${body.emit()})${alternative?.emit()?.let { ".getOrElse($it)" } ?: ""})" + is NullableMap -> "(${expression.emit()}.map(it => ${body.emit()}).getOrElse(${alternative.emit()}))" + is NullableOf -> "Some(${expression.emit()})" + is NullableGet -> "${expression.emit()}.get" + is Constraint.RegexMatch -> "\"\"\"${pattern}\"\"\".r.findFirstIn(${value.emit()}).isDefined" + is Constraint.BoundCheck -> { + val checks = listOfNotNull( + min?.let { "$it <= ${value.emit()}" }, + max?.let { "${value.emit()} <= $it" }, + ).joinToString(" && ").ifEmpty { "true" } + checks + } + is ErrorStatement -> "throw new IllegalStateException(${message.emit()})" + is AssertStatement -> throw IllegalArgumentException("AssertStatement cannot be an expression in Scala") + is Switch -> { + val isPatternSwitch = cases.any { it.type != null } + if (isPatternSwitch) { + val casesStr = cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(1) } + val typeStr = case.type?.emitGenerics() ?: "Any" + val varName = variable?.camelCase() ?: "_" + "case $varName: $typeStr => {\n$bodyStr}\n".indentCode(1) + } + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(1) } + "case _ => {\n$bodyStr}\n".indentCode(1) + } ?: "" + "${expression.emit()} match {\n$casesStr$defaultStr}" + } else { + val casesStr = cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(1) } + "case ${case.value.emit()} => {\n$bodyStr}\n".indentCode(1) + } + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(1) } + "case _ => {\n$bodyStr}\n".indentCode(1) + } ?: "" + "${expression.emit()} match {\n$casesStr$defaultStr}" + } + } + + is Assignment -> throw IllegalArgumentException("Assignment cannot be an expression in Scala") + is PrintStatement -> throw IllegalArgumentException("PrintStatement cannot be an expression in Scala") + is ReturnStatement -> throw IllegalArgumentException("ReturnStatement cannot be an expression in Scala") + is NotExpression -> "!${expression.emit()}" + is IfExpression -> "if (${condition.emit()}) ${thenExpr.emit()} else ${elseExpr.emit()}" + is MapExpression -> "${receiver.emit()}.map(${variable.camelCase()} => ${body.emit()})" + is FlatMapIndexed -> "${receiver.emit()}.zipWithIndex.flatMap { case (${elementVar.camelCase()}, ${indexVar.camelCase()}) => ${body.emit()} }" + is ListConcat -> when { + lists.isEmpty() -> "List.empty[String]" + lists.size == 1 -> lists.single().emit() + else -> lists.joinToString(" ++ ") { expr -> + val emitted = expr.emit() + if (expr is IfExpression) "($emitted)" else emitted + } + } + is StringTemplate -> "s\"${parts.joinToString("") { + when (it) { + is StringTemplate.Part.Text -> it.value + is StringTemplate.Part.Expr -> "\${${it.expression.emit()}}" + } + }}\"" + } + + private fun LiteralList.emit(): String { + if (values.isEmpty()) return "List.empty[${type.emitGenerics()}]" + val list = values.joinToString(", ") { it.emit() } + return "List($list)" + } + + private fun LiteralMap.emit(): String { + if (values.isEmpty()) return "Map.empty" + val map = values.entries.joinToString(", ") { + "${Literal(it.key, keyType).emit()} -> ${it.value.emit()}" + } + return "Map($map)" + } + + private fun Literal.emit(): String = when (type) { + Type.String -> "\"$value\"" + is Type.Integer -> if (type.precision == Precision.P64) "${value}L" else value.toString() + else -> value.toString() + } + + private fun ArrayIndexCall.emitArrayIndex(): String { + val isMapAccess = index is Literal && (index as Literal).type == Type.String + return when { + !caseSensitive && isMapAccess -> "${receiver.emit()}.find(_._1.equalsIgnoreCase(${index.emit()})).map(_._2)" + isMapAccess -> "${receiver.emit()}.get(${index.emit()})" + else -> "${receiver.emit()}(${index.emit()})" + } + } + + private fun TypeDescriptor.emitTypeDescriptor(): String = "scala.reflect.classTag[${type.emitGenerics()}]" +} + +private fun String.sanitize(): String = if (reservedKeywords.contains(this)) "`$this`" else this + +private val reservedKeywords = setOf( + "abstract", "case", "class", "def", "do", + "else", "extends", "false", "final", "for", + "forSome", "if", "implicit", "import", "lazy", + "match", "new", "null", "object", "override", + "package", "private", "protected", "return", "sealed", + "super", "this", "throw", "trait", "true", + "try", "type", "val", "var", "while", + "with", "yield", "given", "using", "enum", + "export", "then", +) diff --git a/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/TypeScriptGenerator.kt b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/TypeScriptGenerator.kt new file mode 100644 index 000000000..17a1921dc --- /dev/null +++ b/src/compiler/ir/src/commonMain/kotlin/community/flock/wirespec/ir/generator/TypeScriptGenerator.kt @@ -0,0 +1,704 @@ +package community.flock.wirespec.ir.generator + +import community.flock.wirespec.ir.core.ArrayIndexCall +import community.flock.wirespec.ir.core.AssertStatement +import community.flock.wirespec.ir.core.Assignment +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.ir.core.Constraint +import community.flock.wirespec.ir.core.Constructor +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Element +import community.flock.wirespec.ir.core.Enum +import community.flock.wirespec.ir.core.EnumReference +import community.flock.wirespec.ir.core.EnumValueCall +import community.flock.wirespec.ir.core.ErrorStatement +import community.flock.wirespec.ir.core.Expression +import community.flock.wirespec.ir.core.Field +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.FlatMapIndexed +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.IfExpression +import community.flock.wirespec.ir.core.Import +import community.flock.wirespec.ir.core.Interface +import community.flock.wirespec.ir.core.ListConcat +import community.flock.wirespec.ir.core.Literal +import community.flock.wirespec.ir.core.LiteralList +import community.flock.wirespec.ir.core.LiteralMap +import community.flock.wirespec.ir.core.Main +import community.flock.wirespec.ir.core.MapExpression +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Namespace +import community.flock.wirespec.ir.core.NotExpression +import community.flock.wirespec.ir.core.NullCheck +import community.flock.wirespec.ir.core.NullLiteral +import community.flock.wirespec.ir.core.NullableEmpty +import community.flock.wirespec.ir.core.NullableGet +import community.flock.wirespec.ir.core.NullableMap +import community.flock.wirespec.ir.core.NullableOf +import community.flock.wirespec.ir.core.Package +import community.flock.wirespec.ir.core.Parameter +import community.flock.wirespec.ir.core.PrintStatement +import community.flock.wirespec.ir.core.RawElement +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.ReturnStatement +import community.flock.wirespec.ir.core.Statement +import community.flock.wirespec.ir.core.StringTemplate +import community.flock.wirespec.ir.core.Struct +import community.flock.wirespec.ir.core.Switch +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.TypeDescriptor +import community.flock.wirespec.ir.core.Union +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.forEachElement +import community.flock.wirespec.ir.core.Function as AstFunction + +object TypeScriptGenerator : Generator { + override fun generate(element: Element): String = when (element) { + is File -> TypeScriptFileEmitter(element).emitFile() + else -> TypeScriptFileEmitter(File(Name.of(""), listOf(element))).emitFile() + } + + fun generateExpression(expression: Expression): String = TypeScriptFileEmitter(File(Name.of(""), emptyList())).renderExpression(expression) +} + +private class TypeScriptFileEmitter(val file: File) { + + private val structsWithConstructors: Set = buildSet { + file.forEachElement { element -> + if (element is Struct && element.constructors.isNotEmpty()) { + add(element.name.pascalCase()) + } + } + } + + private val constructorFuncNames: Set = structsWithConstructors + .map { it.replaceFirstChar { c -> c.lowercaseChar() } } + .toSet() + + fun emitFile(): String = file.elements.joinToString("") { it.emit(0) }.removeEmptyLines() + + fun renderExpression(expression: Expression): String = expression.emit() + + private fun String.indentCode(level: Int): String = " ".repeat(level * 2) + this + + private fun String.removeEmptyLines(): String = lines().filter { it.isNotEmpty() }.joinToString("\n").plus("\n") + + private fun Element.emit(indent: Int): String = when (this) { + is Package -> emit(indent) + is Import -> emit(indent) + is Struct -> emit(indent) + is AstFunction -> emit(indent) + is Namespace -> emit(indent) + is Interface -> emit(indent) + is Union -> emit(indent) + is Enum -> emit(indent) + is Main -> { + val staticContent = statics.joinToString("") { it.emit(indent) } + val content = body.joinToString("") { it.emit(indent + 1) } + "$staticContent${";(async () => {\n$content${"})();".indentCode(indent)}\n".indentCode(indent)}" + } + is File -> elements.joinToString("") { it.emit(indent) } + is RawElement -> code.lines().joinToString("\n") { if (it.isEmpty()) it else it.indentCode(indent) } + "\n" + } + + private fun Package.emit(indent: Int): String = "" + + private fun Import.emit(indent: Int): String = "import { ${type.name} } from '$path';\n".indentCode(indent) + + private fun Namespace.emit(indent: Int): String { + val content = elements.joinToString("") { it.emit(indent + 1) } + val closingBrace = "}\n".indentCode(indent) + return "export namespace ${name.pascalCase()} {\n$content$closingBrace".indentCode(indent) + } + + private fun Interface.emit(indent: Int): String { + val typeParamsStr = if (typeParameters.isNotEmpty()) { + "<${typeParameters.joinToString(", ") { tp -> + val extendsStr = if (tp.extends.isNotEmpty()) " extends ${tp.extends.joinToString(" & ") { it.emit() }}" else "" + "${tp.type.emit()}$extendsStr" + }}>" + } else { + "" + } + val ext = extends.map { it.emit() } + val extStr = if (ext.isEmpty()) "" else " extends ${ext.joinToString(", ")}" + val nestedInterfaces = elements.filterIsInstance().associateBy { it.name.pascalCase() } + val nonInterfaceElements = elements.filter { it !is Interface } + val fieldsContent = fields.joinToString("") { field -> + val typeStr = field.type.emitWithInlineInterfaces(nestedInterfaces) + "${field.name.value()}: $typeStr;\n".indentCode(indent + 1) + } + val elementsContent = nonInterfaceElements.joinToString("") { + when (it) { + is AstFunction -> it.emit(indent + 1, nestedInterfaces) + else -> it.emit(indent + 1) + } + } + val content = fieldsContent + elementsContent + return if (content.isEmpty()) { + "export interface ${name.pascalCase()}$typeParamsStr$extStr {}\n".indentCode(indent) + } else { + val closingBrace = "}\n".indentCode(indent) + "export interface ${name.pascalCase()}$typeParamsStr$extStr {\n$content$closingBrace".indentCode(indent) + } + } + + private fun Union.emit(indent: Int): String { + val typeParamsStr = if (typeParameters.isNotEmpty()) { + "<${typeParameters.joinToString(", ") { "${it.type.emit()} = unknown" }}>" + } else { + "" + } + return if (members.isNotEmpty()) { + "export type ${name.pascalCase()}$typeParamsStr = ${members.joinToString(" | ") { it.name }}\n".indentCode(indent) + } else { + val extStr = extends?.let { " extends ${it.emit()}" } ?: "" + "export interface ${name.pascalCase()}$typeParamsStr$extStr {}\n".indentCode(indent) + } + } + + private fun Enum.emit(indent: Int): String = "export type ${name.pascalCase()} = ${entries.joinToString(" | ") { "\"${it.name.value()}\"" }}\n".indentCode(indent) + + private fun Struct.emit(indent: Int): String { + val nestedStructs = elements.filterIsInstance().associateBy { it.name.pascalCase() } + val nonStructElements = elements.filter { it !is Struct } + val fieldsContent = if (fields.isEmpty()) { + "" + } else { + fields.joinToString("") { it.emit(indent + 1, nestedStructs) } + } + val nestedContent = nonStructElements.joinToString("") { it.emit(indent) } + val closingBrace = "}".indentCode(indent) + val pascalName = name.pascalCase() + val typeStr = if (fields.isEmpty() && nonStructElements.isEmpty()) { + "export type $pascalName = {}\n".indentCode(indent) + } else if (fields.isEmpty()) { + "export type $pascalName = {}\n".indentCode(indent) + nestedContent + } else { + "export type $pascalName = {\n$fieldsContent$closingBrace\n".indentCode(indent) + nestedContent + } + val constructorFunctions = constructors.joinToString("") { constructor -> + emitStructConstructor(pascalName, constructor, indent) + } + return typeStr + constructorFunctions + } + + private fun Field.emit(indent: Int, inlineStructs: Map = emptyMap()): String { + val typeStr = type.emitWithInlineStructs(inlineStructs) + return "\"${name.value()}\": $typeStr,\n".indentCode(indent) + } + + private fun Type.emitWithInlineStructs(inlineStructs: Map): String = when (this) { + is Type.Custom -> inlineStructs[name]?.emitInline() ?: emit() + is Type.Nullable -> "${type.emitWithInlineStructs(inlineStructs)} | undefined" + is Type.Array -> { + val element = elementType.emitWithInlineStructs(inlineStructs) + if (elementType is Type.Nullable) "($element)[]" else "$element[]" + } + else -> emit() + } + + private fun Struct.emitInline(): String { + if (fields.isEmpty()) return "{}" + val nestedStructs = elements.filterIsInstance().associateBy { it.name.pascalCase() } + return "{${fields.joinToString(", ") { field -> + val typeStr = field.type.emitWithInlineStructs(nestedStructs) + "\"${field.name.value()}\": $typeStr" + }}}" + } + + private fun emitStructConstructor(structName: String, constructor: Constructor, indent: Int): String { + val funcName = structName.replaceFirstChar { it.lowercaseChar() } + val paramsTypeName = "${structName}Params" + val paramNames = constructor.parameters.map { it.name.value() }.toSet() + + // Emit params type + val paramsTypeContent = if (constructor.parameters.isEmpty()) { + "{}" + } else { + constructor.parameters.joinToString(", ") { param -> + when (val t = param.type) { + is Type.Nullable -> "\"${param.name.value()}\"?: ${t.type.emit()}" + else -> "\"${param.name.value()}\": ${param.type.emit()}" + } + }.let { "{$it}" } + } + val paramsTypeLine = "export type $paramsTypeName = $paramsTypeContent\n".indentCode(indent) + + // Emit constructor arrow function + val paramsArg = if (constructor.parameters.isEmpty()) "" else "params: $paramsTypeName" + val bodyAssignments = constructor.body.filterIsInstance() + val bodyContent = bodyAssignments.joinToString("") { assignment -> + val value = emitConstructorValue(assignment.value, paramNames) + "${assignment.name.value()}: $value,\n".indentCode(indent + 1) + } + val closingParen = "})".indentCode(indent) + val funcLine = "export const $funcName = ($paramsArg): $structName => ({\n$bodyContent$closingParen\n".indentCode(indent) + + return paramsTypeLine + funcLine + } + + private fun emitConstructorValue(expr: Expression, paramNames: Set): String = when (expr) { + is RawExpression -> when { + expr.code in paramNames -> "params.${expr.code}" + else -> expr.code + } + is VariableReference -> when { + expr.name.value() in paramNames -> "params.${expr.name.value()}" + else -> expr.name.value() + } + is EnumReference -> "\"${expr.entry.value()}\"" + is ConstructorStatement -> when { + expr.type == Type.Unit -> "undefined" + expr.namedArguments.isEmpty() -> "{}" + else -> { + val args = expr.namedArguments.entries.joinToString(", ") { (key, value) -> + "\"${key.value()}\": ${emitConstructorArgValue(value, paramNames)}" + } + "{$args}" + } + } + is NullLiteral -> "undefined" + is NullableEmpty -> "undefined" + else -> expr.emit() + } + + private fun emitConstructorArgValue(expr: Expression, paramNames: Set): String = when (expr) { + is RawExpression -> when { + expr.code in paramNames -> "params[\"${expr.code}\"]" + else -> expr.code + } + is VariableReference -> when { + expr.name.value() in paramNames -> "params[\"${expr.name.value()}\"]" + else -> expr.name.value() + } + else -> emitConstructorValue(expr, paramNames) + } + + private fun AstFunction.emit(indent: Int, inlineInterfaces: Map = emptyMap()): String { + val retType = returnType + val rType = retType?.let { ": ${it.emitWithInlineInterfaces(inlineInterfaces)}" } ?: "" + + val typeParamsStr = if (typeParameters.isNotEmpty()) { + "<${typeParameters.joinToString(", ") { tp -> + val extendsStr = if (tp.extends.isNotEmpty()) " extends ${tp.extends.joinToString(" & ") { it.emit() }}" else "" + "${tp.type.emit()}$extendsStr" + }}>" + } else { + "" + } + + // Detect parameter names that collide with constructor function names + val renames = parameters + .filter { it.name.camelCase() in constructorFuncNames } + .associate { it.name.camelCase() to "_${it.name.camelCase()}" } + + val effectiveParams = parameters.map { p -> + renames[p.name.camelCase()]?.let { Parameter(Name(listOf(it)), p.type) } ?: p + } + val effectiveBody = if (renames.isNotEmpty()) { + body.map { stmt -> renameVariables(stmt, renames) } + } else { + body + } + + val params = effectiveParams.joinToString(", ") { it.emitWithInlineInterfaces(inlineInterfaces) } + val prefix = if (isAsync) "async " else "" + return if (effectiveBody.isEmpty()) { + val tsRType = if (isAsync) { + if (retType == null || retType == Type.Unit) { + ": Promise" + } else { + ": Promise<${retType.emitWithInlineInterfaces(inlineInterfaces)}>" + } + } else { + rType + } + "${name.camelCase()}$typeParamsStr($params)$tsRType;\n".indentCode(indent) + } else { + val content = effectiveBody.joinToString("") { it.emit(indent + 1) } + val closingBrace = "}\n".indentCode(indent) + "export ${prefix}function ${name.camelCase()}$typeParamsStr($params)$rType {\n$content$closingBrace".indentCode(indent) + } + } + + @Suppress("UNCHECKED_CAST") + private fun renameVariables(expr: T, renames: Map): T { + if (renames.isEmpty()) return expr + return when (expr) { + is VariableReference -> { + val newName = renames[expr.name.camelCase()] ?: return expr + VariableReference(Name(listOf(newName))) as T + } + is RawExpression -> { + var code = expr.code + for ((old, new) in renames) { + code = code.replace(old, new) + } + RawExpression(code) as T + } + is FieldCall -> FieldCall( + receiver = expr.receiver?.let { renameVariables(it, renames) }, + field = expr.field, + ) as T + is FunctionCall -> FunctionCall( + receiver = expr.receiver?.let { renameVariables(it, renames) }, + typeArguments = expr.typeArguments, + name = expr.name, + arguments = expr.arguments.mapValues { renameVariables(it.value, renames) }, + ) as T + is ArrayIndexCall -> ArrayIndexCall( + receiver = renameVariables(expr.receiver, renames), + index = renameVariables(expr.index, renames), + caseSensitive = expr.caseSensitive, + ) as T + is ConstructorStatement -> ConstructorStatement( + type = expr.type, + namedArguments = expr.namedArguments.mapValues { renameVariables(it.value, renames) }, + ) as T + is ReturnStatement -> ReturnStatement( + expression = renameVariables(expr.expression, renames), + ) as T + is Assignment -> Assignment( + name = expr.name, + value = renameVariables(expr.value, renames), + isProperty = expr.isProperty, + ) as T + is NullCheck -> NullCheck( + expression = renameVariables(expr.expression, renames), + body = renameVariables(expr.body, renames), + alternative = expr.alternative?.let { renameVariables(it, renames) }, + ) as T + is NullableMap -> NullableMap( + expression = renameVariables(expr.expression, renames), + body = renameVariables(expr.body, renames), + alternative = renameVariables(expr.alternative, renames), + ) as T + is BinaryOp -> BinaryOp( + left = renameVariables(expr.left, renames), + operator = expr.operator, + right = renameVariables(expr.right, renames), + ) as T + is Switch -> Switch( + expression = renameVariables(expr.expression, renames), + cases = expr.cases.map { case -> + case.copy(body = case.body.map { renameVariables(it, renames) }) + }, + default = expr.default?.map { renameVariables(it, renames) }, + variable = expr.variable, + ) as T + is PrintStatement -> PrintStatement( + expression = renameVariables(expr.expression, renames), + ) as T + is ErrorStatement -> ErrorStatement( + message = renameVariables(expr.message, renames), + ) as T + is AssertStatement -> AssertStatement( + expression = renameVariables(expr.expression, renames), + message = expr.message, + ) as T + is LiteralMap -> LiteralMap( + values = expr.values.mapValues { renameVariables(it.value, renames) }, + keyType = expr.keyType, + valueType = expr.valueType, + ) as T + is LiteralList -> LiteralList( + values = expr.values.map { renameVariables(it, renames) }, + type = expr.type, + ) as T + is EnumValueCall -> EnumValueCall( + expression = renameVariables(expr.expression, renames), + ) as T + is NullableOf -> NullableOf( + expression = renameVariables(expr.expression, renames), + ) as T + is NullableGet -> NullableGet( + expression = renameVariables(expr.expression, renames), + ) as T + else -> expr + } + } + + private fun Parameter.emit(): String = "${name.camelCase()}: ${type.emit()}" + + private fun Parameter.emitWithInlineInterfaces(inlineInterfaces: Map): String = "${name.camelCase()}: ${type.emitWithInlineInterfaces(inlineInterfaces)}" + + private fun Type.emitWithInlineInterfaces(inlineInterfaces: Map): String = when { + inlineInterfaces.isEmpty() -> emit() + this is Type.Custom && inlineInterfaces.containsKey(name) -> { + val nested = inlineInterfaces[name]!! + if (nested.elements.isEmpty() && nested.extends.isNotEmpty()) { + nested.extends.joinToString(" & ") { it.emit() } + } else if (nested.elements.isEmpty()) { + "{}" + } else { + emit() + } + } + this is Type.Nullable -> "${type.emitWithInlineInterfaces(inlineInterfaces)} | undefined" + this is Type.Array -> { + val element = elementType.emitWithInlineInterfaces(inlineInterfaces) + if (elementType is Type.Nullable) "($element)[]" else "$element[]" + } + else -> emit() + } + + private fun Type.emit(): String = when (this) { + is Type.Integer -> "number" + is Type.Number -> "number" + Type.Any -> "any" + Type.String -> "string" + Type.Boolean -> "boolean" + Type.Bytes -> "Uint8Array" + Type.Unit -> "void" + Type.Wildcard -> "unknown" + Type.Reflect -> "Type" + is Type.Array -> { + val element = elementType.emit() + if (elementType is Type.Nullable) "($element)[]" else "$element[]" + } + is Type.Dict -> "Record<${keyType.emit()}, ${valueType.emit()}>" + is Type.Custom -> { + if (generics.isEmpty()) { + name + } else { + "$name<${generics.joinToString(", ") { it.emit() }}>" + } + } + is Type.Nullable -> "${type.emit()} | undefined" + is Type.IntegerLiteral -> value.toString() + is Type.StringLiteral -> "\"$value\"" + } + + private fun emitConstructorCall(type: Type, namedArguments: Map): String { + val typeName = (type as? Type.Custom)?.name + if (typeName != null && typeName in structsWithConstructors) { + val funcName = typeName.replaceFirstChar { it.lowercaseChar() } + if (namedArguments.isEmpty()) return "$funcName()" + val args = namedArguments.map { "\"${it.key.value()}\": ${it.value.emit()}" }.joinToString(", ") + return "$funcName({$args})" + } + if (type == Type.Unit) return "undefined" + val named = namedArguments.map { "${it.key.value()}: ${it.value.emit()}" }.joinToString(", ") + return if (named.isEmpty()) "{}" else "{ $named }" + } + + private fun Statement.emit(indent: Int): String = when (this) { + is PrintStatement -> "console.log(${expression.emit()});\n".indentCode(indent) + is ReturnStatement -> "return ${expression.emit()};\n".indentCode(indent) + is ConstructorStatement -> "${emitConstructorCall(type, namedArguments)};\n".indentCode(indent) + is Literal -> "${emit()};\n".indentCode(indent) + is LiteralList -> "${emit()};\n".indentCode(indent) + is LiteralMap -> "${emit()};\n".indentCode(indent) + is Assignment -> { + if (isProperty) { + "${name.value()} = ${value.emit()};\n".indentCode(indent) + } else { + "const ${name.camelCase()} = ${value.emit()};\n".indentCode(indent) + } + } + is ErrorStatement -> "throw new Error(${message.emit()});\n".indentCode(indent) + is AssertStatement -> "if (!(${expression.emit()})) throw new Error('${message.replace("'", "\\'")}');\n".indentCode(indent) + is Switch -> { + val isBlockStyle = cases.any { case -> case.body.any { it is Assignment } } + val casesStr = cases.joinToString("") { case -> + val bodyStr = case.body.joinToString("") { it.emit(indent + 2) } + if (isBlockStyle) { + "case ${case.value.emit()}: {\n".indentCode(indent + 1) + bodyStr + "}\n".indentCode(indent + 1) + } else { + "case ${case.value.emit()}:\n$bodyStr${"break;\n".indentCode(indent + 2)}".indentCode(indent + 1) + } + } + val defaultStr = default?.let { + val bodyStr = it.joinToString("") { stmt -> stmt.emit(indent + 2) } + if (isBlockStyle) { + "default: {\n".indentCode(indent + 1) + + bodyStr + + "}\n".indentCode(indent + 1) + } else { + "default:\n$bodyStr".indentCode(indent + 1) + } + } ?: "" + "${"switch (${expression.emit()}) {\n".indentCode(indent)}$casesStr$defaultStr${"}\n".indentCode(indent)}" + } + is RawExpression -> "$code;\n".indentCode(indent) + is NullLiteral -> "undefined;\n".indentCode(indent) + is NullableEmpty -> "undefined;\n".indentCode(indent) + is VariableReference -> "${name.camelCase()};\n".indentCode(indent) + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${field.value()};\n".indentCode(indent) + } + is FunctionCall -> { + val awaitPrefix = if (isAwait) "await " else "" + val recv = receiver + if (recv != null) { + "$awaitPrefix${recv.emit()}.${name.value()}(${arguments.values.joinToString(", ") { it.emit() }});\n".indentCode(indent) + } else { + "$awaitPrefix${name.value()}(${arguments.values.joinToString(", ") { it.emit() }});\n".indentCode(indent) + } + } + is ArrayIndexCall -> if (caseSensitive) { + "${receiver.emit()}[${index.emit()}];\n".indentCode(indent) + } else { + "Object.entries(${receiver.emit()}).find(([k]) => k.toLowerCase() === ${index.emit()}.toLowerCase())?.[1] ?? null;\n".indentCode(indent) + } + is EnumReference -> "${enumType.emit()}.${entry.pascalCase()};\n".indentCode(indent) + is EnumValueCall -> "${expression.emit()};\n".indentCode(indent) + is BinaryOp -> "(${left.emit()} ${operator.toTypeScript()} ${right.emit()});\n".indentCode(indent) + is TypeDescriptor -> "\"${type.emit()}\";\n".indentCode(indent) + is NullCheck -> "${emit()};\n".indentCode(indent) + is NullableMap -> "${emit()};\n".indentCode(indent) + is NullableOf -> "${emit()};\n".indentCode(indent) + is NullableGet -> "${emit()};\n".indentCode(indent) + is Constraint.RegexMatch -> "${emit()};\n".indentCode(indent) + is Constraint.BoundCheck -> "${emit()};\n".indentCode(indent) + is NotExpression -> "!${expression.emit()};\n".indentCode(indent) + is IfExpression -> "${emit()};\n".indentCode(indent) + is MapExpression -> "${emit()};\n".indentCode(indent) + is FlatMapIndexed -> "${emit()};\n".indentCode(indent) + is ListConcat -> "${emit()};\n".indentCode(indent) + is StringTemplate -> "${emit()};\n".indentCode(indent) + } + + private fun BinaryOp.Operator.toTypeScript(): String = when (this) { + BinaryOp.Operator.PLUS -> "+" + BinaryOp.Operator.EQUALS -> "===" + BinaryOp.Operator.NOT_EQUALS -> "!==" + } + + private fun Expression.emit(): String = when (this) { + is ConstructorStatement -> emitConstructorCall(type, namedArguments) + is Literal -> emit() + is LiteralList -> emit() + is LiteralMap -> emit() + is RawExpression -> code + is NullLiteral -> "undefined" + is NullableEmpty -> "undefined" + is VariableReference -> name.camelCase() + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emit()}." } ?: "" + "$receiverStr${field.value()}" + } + is FunctionCall -> { + val awaitPrefix = if (isAwait) "await " else "" + val recv = receiver + if (recv != null) { + "$awaitPrefix${recv.emit()}.${name.value()}(${arguments.values.joinToString(", ") { it.emit() }})" + } else { + "$awaitPrefix${name.value()}(${arguments.values.joinToString(", ") { it.emit() }})" + } + } + is ArrayIndexCall -> if (caseSensitive) { + "${receiver.emit()}[${index.emit()}]" + } else { + "Object.entries(${receiver.emit()}).find(([k]) => k.toLowerCase() === ${index.emit()}.toLowerCase())?.[1]" + } + is EnumReference -> "${enumType.emit()}.${entry.pascalCase()}" + is EnumValueCall -> expression.emit() + is BinaryOp -> "(${left.emit()} ${operator.toTypeScript()} ${right.emit()})" + is TypeDescriptor -> "\"${type.emit()}\"" + is NullCheck -> { + val exprStr = expression.emit() + // When expression might be undefined (e.g. case-insensitive header lookup), + // add non-null assertion for the inlined replacement in the body + val bodyReplacement = if (expression is ArrayIndexCall && !(expression as ArrayIndexCall).caseSensitive) "$exprStr!" else exprStr + val bodyStr = body.emitWithInlinedIt(bodyReplacement) + val altStr = alternative?.emit() ?: "undefined" + "$exprStr != null ? $bodyStr : $altStr" + } + is NullableMap -> { + val exprStr = expression.emit() + val bodyStr = body.emitWithInlinedIt(exprStr) + val altStr = alternative.emit() + "$exprStr != null ? $bodyStr : $altStr" + } + is NullableOf -> expression.emit() + is NullableGet -> "${expression.emit()}!" + is Constraint.RegexMatch -> "$rawValue.test(${value.emit()})" + is Constraint.BoundCheck -> { + val checks = listOfNotNull( + min?.let { "$it <= ${value.emit()}" }, + max?.let { "${value.emit()} <= $it" }, + ).joinToString(" && ").ifEmpty { "true" } + checks + } + is ErrorStatement -> "(() => { throw new Error(${message.emit()}) })()" + is AssertStatement -> throw IllegalArgumentException("AssertStatement cannot be an expression in TypeScript") + is Switch -> throw IllegalArgumentException("Switch cannot be an expression in TypeScript") + is Assignment -> throw IllegalArgumentException("Assignment cannot be an expression in TypeScript") + is PrintStatement -> throw IllegalArgumentException("PrintStatement cannot be an expression in TypeScript") + is ReturnStatement -> throw IllegalArgumentException("ReturnStatement cannot be an expression in TypeScript") + is NotExpression -> "!${expression.emit()}" + is IfExpression -> "(${condition.emit()} ? ${thenExpr.emit()} : ${elseExpr.emit()})" + is MapExpression -> "${receiver.emit()}.map(${variable.camelCase()} => ${body.emit()})" + is FlatMapIndexed -> "${receiver.emit()}.flatMap((${elementVar.camelCase()}, ${indexVar.camelCase()}) => ${body.emit()})" + is ListConcat -> when { + lists.isEmpty() -> "[] as string[]" + lists.size == 1 -> lists.single().emit() + else -> "[${lists.joinToString(", ") { "...${it.emit()}" }}]" + } + is StringTemplate -> "`${parts.joinToString("") { + when (it) { + is StringTemplate.Part.Text -> it.value + is StringTemplate.Part.Expr -> "\${${it.expression.emit()}}" + } + }}`" + } + + private fun Expression.emitWithInlinedIt(replacement: String): String = when (this) { + is VariableReference -> if (name.value() == "it") replacement else emit() + is FunctionCall -> { + val recv = receiver + val inlinedArgs = arguments.mapValues { it.value.emitWithInlinedIt(replacement) } + if (recv != null) { + "${recv.emitWithInlinedIt(replacement)}.${name.value()}(${inlinedArgs.values.joinToString(", ")})" + } else { + "${name.value()}(${inlinedArgs.values.joinToString(", ")})" + } + } + is FieldCall -> { + val receiverStr = receiver?.let { "${it.emitWithInlinedIt(replacement)}." } ?: "" + "$receiverStr${field.value()}" + } + is ArrayIndexCall -> if (caseSensitive) { + "${receiver.emitWithInlinedIt(replacement)}[${index.emitWithInlinedIt(replacement)}]" + } else { + "Object.entries(${receiver.emitWithInlinedIt(replacement)}).find(([k]) => k.toLowerCase() === ${index.emitWithInlinedIt(replacement)}.toLowerCase())?.[1]" + } + is EnumValueCall -> expression.emitWithInlinedIt(replacement) + is NotExpression -> "!${expression.emitWithInlinedIt(replacement)}" + is IfExpression -> "(${condition.emitWithInlinedIt(replacement)} ? ${thenExpr.emitWithInlinedIt(replacement)} : ${elseExpr.emitWithInlinedIt(replacement)})" + is MapExpression -> "${receiver.emitWithInlinedIt(replacement)}.map(${variable.camelCase()} => ${body.emitWithInlinedIt(replacement)})" + is FlatMapIndexed -> "${receiver.emitWithInlinedIt(replacement)}.flatMap((${elementVar.camelCase()}, ${indexVar.camelCase()}) => ${body.emitWithInlinedIt(replacement)})" + is ListConcat -> "[${lists.joinToString(", ") { "...${it.emitWithInlinedIt(replacement)}" }}]" + is StringTemplate -> "`${parts.joinToString("") { + when (it) { + is StringTemplate.Part.Text -> it.value + is StringTemplate.Part.Expr -> "\${${it.expression.emitWithInlinedIt(replacement)}}" + } + }}`" + is LiteralList -> emit() + else -> emit() + } + + private fun LiteralList.emit(): String { + if (values.isEmpty()) return "[] as ${type.emit()}[]" + val list = values.joinToString(", ") { it.emit() } + return "[$list]" + } + + private fun LiteralMap.emit(): String { + if (values.isEmpty()) return "{}" + val map = values.entries.joinToString(", ") { + "${Literal(it.key, keyType).emit()}: ${it.value.emit()}" + } + return "{ $map }" + } + + private fun Literal.emit(): String = when (type) { + Type.String -> "'$value'" + else -> value.toString() + } +} diff --git a/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/converter/IrConverterTest.kt b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/converter/IrConverterTest.kt new file mode 100644 index 000000000..1a5ad4d51 --- /dev/null +++ b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/converter/IrConverterTest.kt @@ -0,0 +1,160 @@ +package community.flock.wirespec.ir.converter + +import arrow.core.getOrElse +import arrow.core.nonEmptyListOf +import community.flock.wirespec.compiler.core.FileUri +import community.flock.wirespec.compiler.core.ModuleContent +import community.flock.wirespec.compiler.core.ParseContext +import community.flock.wirespec.compiler.core.WirespecSpec +import community.flock.wirespec.compiler.core.parse +import community.flock.wirespec.compiler.core.parse.ast.Definition +import community.flock.wirespec.compiler.core.parse.ast.Module +import community.flock.wirespec.compiler.utils.NoLogger +import community.flock.wirespec.ir.core.Constraint +import community.flock.wirespec.ir.core.LiteralList +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.enum +import community.flock.wirespec.ir.core.file +import community.flock.wirespec.ir.core.struct +import community.flock.wirespec.ir.core.union +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.fail +import community.flock.wirespec.compiler.core.parse.ast.Enum as AstEnum +import community.flock.wirespec.compiler.core.parse.ast.Refined as AstRefined +import community.flock.wirespec.compiler.core.parse.ast.Type as AstType + +class IrConverterTest { + + private inline fun parse(source: String): T = object : ParseContext, NoLogger { + override val spec = WirespecSpec + }.parse(nonEmptyListOf(ModuleContent(FileUri("test.ws"), source))) + .map { it.modules.flatMap(Module::statements) } + .getOrElse { fail("Parse failed: $it") } + .first() + .let { it as? T ?: fail("Expected ${T::class.simpleName} but got ${it::class.simpleName}") } + + private fun parseNodes(source: String): List = object : ParseContext, NoLogger { + override val spec = WirespecSpec + }.parse(nonEmptyListOf(ModuleContent(FileUri("test.ws"), source))) + .map { it.modules.flatMap(Module::statements) } + .getOrElse { fail("Parse failed: $it") } + + @Test + fun testLanguageConverter() { + val source = """ + type Foo { + bar: String + } + """.trimIndent() + + val result = parse(source).convert() + + val expected = file("Foo") { + struct("Foo") { + implements(Type.Custom("Wirespec.Model")) + field(Name(listOf("bar")), string) + function("validate", isOverride = true) { + returnType(Type.Array(Type.String)) + returns(LiteralList(emptyList(), Type.String)) + } + } + } + + assertEquals(expected, result) + } + + @Test + fun testEnumConversion() { + val source = """ + enum MyEnum { + FOO, BAR + } + """.trimIndent() + + val result = parse(source).convert() + + val expected = file("MyEnum") { + enum("MyEnum", Type.Custom("Wirespec.Enum")) { + entry("FOO") + entry("BAR") + } + } + + assertEquals(expected, result) + } + + @Test + fun testUnionConversion() { + val source = """ + type MyUnion = Foo | Bar + type Foo { a: String } + type Bar { b: String } + """.trimIndent() + + val result = parseNodes(source).map { it.convert() } + + val expected = listOf( + file("MyUnion") { + union("MyUnion") { + member("Foo") + member("Bar") + } + }, + file("Foo") { + struct("Foo") { + implements(Type.Custom("Wirespec.Model")) + implements(Type.Custom("MyUnion")) + field(Name(listOf("a")), string) + function("validate", isOverride = true) { + returnType(Type.Array(Type.String)) + returns(LiteralList(emptyList(), Type.String)) + } + } + }, + file("Bar") { + struct("Bar") { + implements(Type.Custom("Wirespec.Model")) + implements(Type.Custom("MyUnion")) + field(Name(listOf("b")), string) + function("validate", isOverride = true) { + returnType(Type.Array(Type.String)) + returns(LiteralList(emptyList(), Type.String)) + } + } + }, + ) + + assertEquals(expected, result) + } + + @Test + fun testRefinedConversion() { + val source = """ + type DutchPostalCode = String(/^([0-9]{4}[A-Z]{2})$/g) + """.trimIndent() + + val result = parse(source).convert() + + val expected = file("DutchPostalCode") { + struct("DutchPostalCode") { + implements(type("Wirespec.Refined", string)) + field("value", Type.String) + function("validate") { + returnType(Type.Boolean) + returns( + Constraint.RegexMatch( + pattern = "^([0-9]{4}[A-Z]{2})\$", + rawValue = "/^([0-9]{4}[A-Z]{2})\$/g", + value = VariableReference(Name.of("value")), + ), + ) + } + } + } + + assertEquals(expected, result) + } +} diff --git a/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/core/NameTest.kt b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/core/NameTest.kt new file mode 100644 index 000000000..139528973 --- /dev/null +++ b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/core/NameTest.kt @@ -0,0 +1,195 @@ +package community.flock.wirespec.ir.core + +import kotlin.test.Test +import kotlin.test.assertEquals + +class NameTest { + + @Test + fun testOfSplitsHelloWorld() { + assertEquals(listOf("Hello", "World"), Name.of("HelloWorld").parts) + } + + @Test + fun testOfSplitsUSA() { + assertEquals(listOf("USA"), Name.of("USA").parts) + } + + @Test + fun testOfSplitsCodeUUID() { + assertEquals(listOf("code", "UUID"), Name.of("codeUUID").parts) + } + + @Test + fun testOfSplitsFirstName() { + assertEquals(listOf("first", "_", "name"), Name.of("first_name").parts) + } + + @Test + fun testOfSplitsSomeField() { + assertEquals(listOf("some", "-", "field"), Name.of("some-field").parts) + } + + @Test + fun testOfSplitsGetHTTPResponse() { + assertEquals(listOf("get", "HTTP", "Response"), Name.of("getHTTPResponse").parts) + } + + @Test + fun testOfSplitsRequestColonNew() { + assertEquals(listOf("Request", "::", "new"), Name.of("Request::new").parts) + } + + @Test + fun testOfSplitsSimple() { + assertEquals(listOf("simple"), Name.of("simple").parts) + } + + @Test + fun testOfSplitsCamelCase() { + assertEquals(listOf("first", "Name"), Name.of("firstName").parts) + } + + @Test + fun testOfSplitsHTMLParser() { + assertEquals(listOf("HTML", "Parser"), Name.of("HTMLParser").parts) + } + + @Test + fun testOfSplitsUnderscoreClass() { + assertEquals(listOf("_", "class"), Name.of("_class").parts) + } + + // value() tests + + @Test + fun testValueHelloWorld() { + assertEquals("HelloWorld", Name.of("HelloWorld").value()) + } + + @Test + fun testValueFirstName() { + assertEquals("first_name", Name.of("first_name").value()) + } + + @Test + fun testValueSomeField() { + assertEquals("some-field", Name.of("some-field").value()) + } + + @Test + fun testValueUSA() { + assertEquals("USA", Name.of("USA").value()) + } + + @Test + fun testValueCodeUUID() { + assertEquals("codeUUID", Name.of("codeUUID").value()) + } + + @Test + fun testValueRequestColonNew() { + assertEquals("Request::new", Name.of("Request::new").value()) + } + + @Test + fun testValueSinglePart() { + assertEquals("bar", Name(listOf("bar")).value()) + } + + @Test + fun testValueMultipleParts() { + assertEquals("HelloWorld", Name("Hello", "World").value()) + } + + // camelCase() tests + + @Test + fun testCamelCaseHelloWorld() { + assertEquals("helloWorld", Name.of("HelloWorld").camelCase()) + } + + @Test + fun testCamelCaseFirstName() { + assertEquals("firstName", Name.of("first_name").camelCase()) + } + + @Test + fun testCamelCaseUSA() { + assertEquals("uSA", Name.of("USA").camelCase()) + } + + @Test + fun testCamelCaseCodeUUID() { + assertEquals("codeUUID", Name.of("codeUUID").camelCase()) + } + + @Test + fun testCamelCaseGetHTTPResponse() { + assertEquals("getHTTPResponse", Name.of("getHTTPResponse").camelCase()) + } + + @Test + fun testCamelCaseSimple() { + assertEquals("simple", Name.of("simple").camelCase()) + } + + @Test + fun testCamelCaseVararg() { + assertEquals("helloWorld", Name("Hello", "World").camelCase()) + } + + // pascalCase() tests + + @Test + fun testPascalCaseHelloWorld() { + assertEquals("HelloWorld", Name.of("HelloWorld").pascalCase()) + } + + @Test + fun testPascalCaseFirstName() { + assertEquals("FirstName", Name.of("first_name").pascalCase()) + } + + @Test + fun testPascalCaseSimple() { + assertEquals("Simple", Name.of("simple").pascalCase()) + } + + @Test + fun testPascalCaseVararg() { + assertEquals("HelloWorld", Name("hello", "world").pascalCase()) + } + + // snakeCase() tests + + @Test + fun testSnakeCaseHelloWorld() { + assertEquals("hello_world", Name.of("HelloWorld").snakeCase()) + } + + @Test + fun testSnakeCaseFirstName() { + assertEquals("first_name", Name.of("first_name").snakeCase()) + } + + @Test + fun testSnakeCaseCodeUUID() { + assertEquals("code_uuid", Name.of("codeUUID").snakeCase()) + } + + @Test + fun testSnakeCaseGetHTTPResponse() { + assertEquals("get_http_response", Name.of("getHTTPResponse").snakeCase()) + } + + @Test + fun testSnakeCaseSimple() { + assertEquals("simple", Name.of("simple").snakeCase()) + } + + @Test + fun testSnakeCaseVararg() { + assertEquals("hello_world", Name("Hello", "World").snakeCase()) + } +} diff --git a/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/core/TransformTest.kt b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/core/TransformTest.kt new file mode 100644 index 000000000..ce379f5dd --- /dev/null +++ b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/core/TransformTest.kt @@ -0,0 +1,607 @@ +package community.flock.wirespec.ir.core + +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class TransformTest { + + @Test + fun transformShouldRenameCustomType() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("name"), Type.String), + Field(Name.of("address"), Type.Custom("Address")), + ), + ) + + val result = struct.renameType("Address", "Location") + + assertEquals(Name.of("Person"), result.name) + assertEquals(Type.String, result.fields[0].type) + assertEquals(Type.Custom("Location"), result.fields[1].type) + } + + @Test + fun transformShouldRenameNestedCustomType() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("addresses"), Type.Array(Type.Custom("Address"))), + ), + ) + + val result = struct.renameType("Address", "Location") + + assertEquals(Type.Array(Type.Custom("Location")), result.fields[0].type) + } + + @Test + fun transformShouldRenameTypeInNullable() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("address"), Type.Nullable(Type.Custom("Address"))), + ), + ) + + val result = struct.renameType("Address", "Location") + + assertEquals(Type.Nullable(Type.Custom("Location")), result.fields[0].type) + } + + @Test + fun transformShouldRenameTypeInDict() { + val struct = Struct( + name = Name.of("Registry"), + fields = listOf( + Field(Name.of("items"), Type.Dict(Type.String, Type.Custom("Item"))), + ), + ) + + val result = struct.renameType("Item", "Product") + + assertEquals(Type.Dict(Type.String, Type.Custom("Product")), result.fields[0].type) + } + + @Test + fun transformShouldRenameField() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("firstName"), Type.String), + Field(Name.of("lastName"), Type.String), + ), + ) + + val result = struct.renameField("firstName", "givenName") + + assertEquals(Name.of("givenName"), result.fields[0].name) + assertEquals(Name.of("lastName"), result.fields[1].name) + } + + @Test + fun transformMatchingShouldTransformSpecificType() { + val struct = Struct( + name = Name.of("Container"), + fields = listOf( + Field(Name.of("items"), Type.Array(Type.Custom("Item"))), + Field(Name.of("count"), Type.Integer()), + ), + ) + + val result = struct.transformMatching { array: Type.Array -> + Type.Custom("List", listOf(array.elementType)) + } + + assertEquals(Type.Custom("List", listOf(Type.Custom("Item"))), result.fields[0].type) + assertEquals(Type.Integer(), result.fields[1].type) + } + + @Test + fun transformMatchingElementsShouldTransformSpecificElementType() { + val file = File( + name = Name.of("test.ws"), + elements = listOf( + Struct(Name.of("Person"), listOf(Field(Name.of("name"), Type.String))), + Struct(Name.of("Address"), listOf(Field(Name.of("street"), Type.String))), + ), + ) + + val result = file.transformMatchingElements { struct: Struct -> + struct.copy(name = Name.of("Prefixed${struct.name.pascalCase()}")) + } + + val structs = result.elements.filterIsInstance() + assertEquals(Name.of("PrefixedPerson"), structs[0].name) + assertEquals(Name.of("PrefixedAddress"), structs[1].name) + } + + @Test + fun transformFieldsWhereShouldTransformMatchingFields() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("name"), Type.String), + Field(Name.of("age"), Type.Integer()), + Field(Name.of("description"), Type.String), + ), + ) + + val result = struct.transformFieldsWhere( + predicate = { it.type == Type.String }, + transform = { it.copy(type = Type.Nullable(it.type)) }, + ) + + assertEquals(Type.Nullable(Type.String), result.fields[0].type) + assertEquals(Type.Integer(), result.fields[1].type) + assertEquals(Type.Nullable(Type.String), result.fields[2].type) + } + + @Test + fun forEachTypeShouldVisitAllTypes() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("name"), Type.String), + Field(Name.of("addresses"), Type.Array(Type.Custom("Address"))), + ), + ) + + val types = mutableListOf() + struct.forEachType { types.add(it) } + + assertTrue(types.contains(Type.String)) + assertTrue(types.contains(Type.Array(Type.Custom("Address")))) + assertTrue(types.contains(Type.Custom("Address"))) + } + + @Test + fun forEachElementShouldVisitAllElements() { + val file = File( + name = Name.of("test.ws"), + elements = listOf( + Struct(Name.of("Person"), listOf(Field(Name.of("name"), Type.String))), + Struct(Name.of("Address"), listOf(Field(Name.of("street"), Type.String))), + ), + ) + + val elements = mutableListOf() + file.forEachElement { elements.add(it) } + + assertEquals(3, elements.size) + assertTrue(elements[0] is File) + assertTrue(elements[1] is Struct) + assertTrue(elements[2] is Struct) + } + + @Test + fun forEachFieldShouldVisitAllFields() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("name"), Type.String), + Field(Name.of("age"), Type.Integer()), + ), + ) + + val fields = mutableListOf() + struct.forEachField { fields.add(it) } + + assertEquals(2, fields.size) + assertEquals(Name.of("name"), fields[0].name) + assertEquals(Name.of("age"), fields[1].name) + } + + @Test + fun collectTypesShouldReturnAllTypes() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("name"), Type.String), + Field(Name.of("age"), Type.Integer()), + ), + ) + + val types = struct.collectTypes() + + assertEquals(2, types.size) + assertTrue(types.contains(Type.String)) + assertTrue(types.contains(Type.Integer())) + } + + @Test + fun collectCustomTypeNamesShouldReturnAllCustomTypeNames() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("address"), Type.Custom("Address")), + Field(Name.of("company"), Type.Custom("Company")), + Field(Name.of("name"), Type.String), + ), + ) + + val names = struct.collectCustomTypeNames() + + assertEquals(setOf("Address", "Company"), names) + } + + @Test + fun findAllShouldReturnAllMatchingElements() { + val file = File( + name = Name.of("test.ws"), + elements = listOf( + Struct(Name.of("Person"), listOf(Field(Name.of("name"), Type.String))), + Enum(Name.of("Status"), entries = listOf(Enum.Entry(Name.of("Active"), emptyList()))), + Struct(Name.of("Address"), listOf(Field(Name.of("street"), Type.String))), + ), + ) + + val structs = file.findAll() + + assertEquals(2, structs.size) + assertEquals(Name.of("Person"), structs[0].name) + assertEquals(Name.of("Address"), structs[1].name) + } + + @Test + fun findAllTypesShouldReturnAllMatchingTypes() { + val struct = Struct( + name = Name.of("Container"), + fields = listOf( + Field(Name.of("items"), Type.Array(Type.Custom("Item"))), + Field(Name.of("tags"), Type.Array(Type.String)), + Field(Name.of("name"), Type.String), + ), + ) + + val arrays = struct.findAllTypes() + + assertEquals(2, arrays.size) + } + + @Test + fun transformShouldHandleDeeplyNestedStructures() { + val file = File( + name = Name.of("test.ws"), + elements = listOf( + Struct( + name = Name.of("Outer"), + fields = listOf( + Field(Name.of("inner"), Type.Custom("Inner")), + ), + elements = listOf( + Struct( + name = Name.of("Inner"), + fields = listOf( + Field(Name.of("value"), Type.Custom("Value")), + ), + ), + ), + ), + ), + ) + + val result = file.renameType("Value", "Data") + val outer = result.elements[0] as Struct + val inner = outer.elements[0] as Struct + + assertEquals(Type.Custom("Data"), inner.fields[0].type) + } + + @Test + fun transformShouldHandleFunctionParametersAndReturnTypes() { + val function = Function( + name = Name.of("process"), + parameters = listOf( + Parameter(Name.of("input"), Type.Custom("Input")), + ), + returnType = Type.Custom("Output"), + body = emptyList(), + ) + + val result = function.renameType("Input", "Request") + .renameType("Output", "Response") + + assertEquals(Type.Custom("Request"), result.parameters[0].type) + assertEquals(Type.Custom("Response"), result.returnType) + } + + @Test + fun transformShouldHandleGenericTypes() { + val struct = Struct( + name = Name.of("Container"), + fields = listOf( + Field(Name.of("items"), Type.Custom("List", listOf(Type.Custom("Item")))), + ), + ) + + val result = struct.renameType("Item", "Product") + + assertEquals( + Type.Custom("List", listOf(Type.Custom("Product"))), + result.fields[0].type, + ) + } + + @Test + fun customTransformerShouldAllowComplexTransformations() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("name"), Type.String), + Field(Name.of("age"), Type.Integer()), + ), + ) + + val transformer = object : Transformer { + override fun transformField(field: Field): Field { + val newType = when (field.type) { + Type.String -> Type.Nullable(Type.String) + else -> field.type + } + return field.copy(type = newType).transformChildren(this) + } + } + + val result = struct.transform(transformer) + + assertEquals(Type.Nullable(Type.String), result.fields[0].type) + assertEquals(Type.Integer(), result.fields[1].type) + } + + @Test + fun transformerFunctionShouldCreateTransformerFromLambdas() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("id"), Type.Integer()), + Field(Name.of("score"), Type.Number()), + ), + ) + + val transformer = transformer { + type { type, t -> + when (type) { + is Type.Integer -> Type.Integer(Precision.P64) + is Type.Number -> Type.Number(Precision.P32) + else -> type.transformChildren(t) + } + } + } + + val result = struct.transform(transformer) + + assertEquals(Type.Integer(Precision.P64), result.fields[0].type) + assertEquals(Type.Number(Precision.P32), result.fields[1].type) + } + + @Test + fun transformShouldHandleStatementsWithExpressions() { + val function = Function( + name = Name.of("create"), + parameters = emptyList(), + returnType = Type.Custom("Result"), + body = listOf( + ReturnStatement( + ConstructorStatement( + type = Type.Custom("Result"), + namedArguments = mapOf( + Name.of("value") to Literal("test", Type.String), + ), + ), + ), + ), + ) + + val result = function.renameType("Result", "Response") + + assertEquals(Type.Custom("Response"), result.returnType) + val returnStmt = result.body[0] as ReturnStatement + val constructor = returnStmt.expression as ConstructorStatement + assertEquals(Type.Custom("Response"), constructor.type) + } + + @Test + fun transformShouldHandleSwitchStatements() { + val function = Function( + name = Name.of("process"), + parameters = listOf(Parameter(Name.of("input"), Type.Custom("Input"))), + returnType = Type.Custom("Output"), + body = listOf( + Switch( + expression = RawExpression("input"), + cases = listOf( + Case( + value = RawExpression("case1"), + body = listOf( + ReturnStatement(ConstructorStatement(Type.Custom("Output"))), + ), + type = Type.Custom("Type1"), + ), + ), + default = listOf( + ErrorStatement(Literal("Unknown", Type.String)), + ), + ), + ), + ) + + val result = function.renameType("Output", "Result") + + val switch = result.body[0] as Switch + val case = switch.cases[0] + val returnStmt = case.body[0] as ReturnStatement + val constructor = returnStmt.expression as ConstructorStatement + assertEquals(Type.Custom("Result"), constructor.type) + } + + // TransformScope DSL tests + + @Test + fun transformScopeSingleRenameType() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("address"), Type.Custom("Address")), + ), + ) + + val result = struct.transform { + renameType("Address", "Location") + } + + assertEquals(Type.Custom("Location"), result.fields[0].type) + } + + @Test + fun transformScopeMultipleOperationsInSequence() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("firstName"), Type.Custom("Address")), + ), + ) + + val result = struct.transform { + renameType("Address", "Location") + renameField("firstName", "givenName") + } + + assertEquals(Type.Custom("Location"), result.fields[0].type) + assertEquals(Name.of("givenName"), result.fields[0].name) + } + + @Test + fun transformScopeMatchingElements() { + val file = File( + name = Name.of("test.ws"), + elements = listOf( + Struct(Name.of("Person"), listOf(Field(Name.of("name"), Type.String))), + Struct(Name.of("Address"), listOf(Field(Name.of("street"), Type.String))), + ), + ) + + val result = file.transform { + matchingElements { struct: Struct -> + struct.copy(name = Name.of("Prefixed${struct.name.pascalCase()}")) + } + } + + val structs = result.elements.filterIsInstance() + assertEquals(Name.of("PrefixedPerson"), structs[0].name) + assertEquals(Name.of("PrefixedAddress"), structs[1].name) + } + + @Test + fun transformScopeApplyTransformer() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf( + Field(Name.of("id"), Type.Integer()), + ), + ) + + val result = struct.transform { + type { type, t -> + when (type) { + is Type.Integer -> Type.Integer(Precision.P64) + else -> type.transformChildren(t) + } + } + } + + assertEquals(Type.Integer(Precision.P64), result.fields[0].type) + } + + @Test + fun transformScopeFieldsWhereAndParametersWhere() { + val function = Function( + name = Name.of("process"), + parameters = listOf( + Parameter(Name.of("input"), Type.String), + ), + returnType = Type.String, + body = emptyList(), + ) + + val file = File( + name = Name.of("test.ws"), + elements = listOf( + Struct( + name = Name.of("Person"), + fields = listOf(Field(Name.of("name"), Type.String)), + ), + function, + ), + ) + + val result = file.transform { + fieldsWhere({ it.type == Type.String }) { it.copy(type = Type.Nullable(Type.String)) } + parametersWhere({ it.type == Type.String }) { it.copy(type = Type.Nullable(Type.String)) } + } + + val struct = result.elements[0] as Struct + assertEquals(Type.Nullable(Type.String), struct.fields[0].type) + val fn = result.elements[1] as Function + assertEquals(Type.Nullable(Type.String), fn.parameters[0].type) + } + + @Test + fun transformScopeInjectBeforeAndAfter() { + val file = File( + name = Name.of("test.ws"), + elements = listOf( + Struct( + Name.of("Person"), + listOf(Field(Name.of("name"), Type.String)), + elements = listOf(RawElement("existing")), + ), + ), + ) + + val result = file.transform { + injectBefore { _: Struct -> + listOf(RawElement("before")) + } + injectAfter { _: Struct -> + listOf(RawElement("after")) + } + } + + val struct = result.elements[0] as Struct + assertEquals(3, struct.elements.size) + assertEquals("before", (struct.elements[0] as RawElement).code) + assertEquals("existing", (struct.elements[1] as RawElement).code) + assertEquals("after", (struct.elements[2] as RawElement).code) + } + + @Test + fun transformScopeEmptyIsNoOp() { + val struct = Struct( + name = Name.of("Person"), + fields = listOf(Field(Name.of("name"), Type.String)), + ) + + val result = struct.transform { } + + assertEquals(struct, result) + } + + @Test + fun transformScopePreservesReturnType() { + val file = File( + name = Name.of("test.ws"), + elements = listOf( + Struct(Name.of("Person"), listOf(Field(Name.of("name"), Type.String))), + ), + ) + + val result: File = file.transform { + renameType("String", "Text") + } + + assertEquals(Name.of("test.ws"), result.name) + } +} diff --git a/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/generator/DslTest.kt b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/generator/DslTest.kt new file mode 100644 index 000000000..9f23a9ef5 --- /dev/null +++ b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/generator/DslTest.kt @@ -0,0 +1,1119 @@ +package community.flock.wirespec.ir.generator + +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.LiteralList +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.NullCheck +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.file +import kotlin.test.Test +import kotlin.test.assertTrue + +class DslTest { + + @Test + fun testNestedNullable() { + val file = file("NestedNullableModule") { + struct("Data") { + field("tags", list(string.nullable())) + } + function("process") { + returnType(string.nullable()) + arg("input", list(string.nullable()).nullable()) + returns(RawExpression("null")) + } + } + + val javaCode = JavaGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Nested Nullable) ---") + println(javaCode) + println("--- TypeScript (Nested Nullable) ---") + println(tsCode) + + assertTrue(javaCode.contains("public record Data")) + assertTrue(javaCode.contains("java.util.List> tags")) + assertTrue(javaCode.contains(") {")) + assertTrue(javaCode.contains("java.util.Optional process(java.util.Optional>> input)")) + + assertTrue(tsCode.contains("\"tags\": (string | undefined)[],")) + assertTrue(tsCode.contains("export function process(input: (string | undefined)[] | undefined): string | undefined")) + } + + @Test + fun testDslAndGeneration() { + val file = file("MyModule") { + struct("User") { + field("id", integer) + field("name", string.nullable()) + } + + function("greet") { + returnType(string) + arg("user", type("User")) + print(RawExpression("\"Greeting \" + user.name")) + returns(RawExpression("user.name")) + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java ---") + println(javaCode) + println("--- Python ---") + println(pythonCode) + println("--- TypeScript ---") + println(tsCode) + + assertTrue(javaCode.contains("public record User")) + assertTrue(javaCode.contains("Integer id,")) + assertTrue(javaCode.contains("java.util.Optional name")) + assertTrue(javaCode.contains(") {")) + assertTrue(javaCode.contains("String greet(User user) {")) + + assertTrue(pythonCode.contains("@dataclass")) + assertTrue(pythonCode.contains("class User:")) + assertTrue(pythonCode.contains("id: int")) + assertTrue(pythonCode.contains("name: Optional[str]")) + assertTrue(pythonCode.contains("def greet(user: User) -> str:")) + + assertTrue(tsCode.contains("export type User = {")) + assertTrue(tsCode.contains("\"id\": number,")) + assertTrue(tsCode.contains("\"name\": string | undefined,")) + assertTrue(tsCode.contains("export function greet(user: User): string {")) + } + + @Test + fun testStaticElement() { + val file = file("StaticModule") { + namespace("MyService", extends = type("BaseService")) { + struct("Config") { + field("url", string) + } + function("start") { + print(literal("Starting service")) + } + } + } + + val javaCode = JavaGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + + println("--- Java (Static) ---") + println(javaCode) + println("--- TypeScript (Static) ---") + println(tsCode) + println("--- Python (Static) ---") + println(pythonCode) + + assertTrue(javaCode.contains("public interface MyService extends BaseService {")) + assertTrue(javaCode.contains("public static record Config")) + assertTrue(javaCode.contains("String url")) + assertTrue(javaCode.contains(") {")) + assertTrue(javaCode.contains("static void start() {")) + assertTrue(javaCode.contains("System.out.println(\"Starting service\");")) + + assertTrue(tsCode.contains("export namespace MyService {")) + assertTrue(tsCode.contains("export type Config = {")) + assertTrue(tsCode.contains("\"url\": string,")) + assertTrue(tsCode.contains("export function start() {")) + assertTrue(tsCode.contains("console.log('Starting service');")) + + assertTrue(pythonCode.contains("class MyService(BaseService):")) + assertTrue(pythonCode.contains("@dataclass")) + assertTrue(pythonCode.contains("class Config:")) + assertTrue(pythonCode.contains("url: str")) + } + + @Test + fun testNestedStatic() { + val file = file("NestedModule") { + namespace("Outer", type("Base")) { + namespace("Inner") { + function("test") { + returns(RawExpression("1")) + } + } + } + } + + val javaCode = JavaGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + + println("--- Java (Nested Static) ---") + println(javaCode) + println("--- TypeScript (Nested Static) ---") + println(tsCode) + println("--- Python (Nested Static) ---") + println(pythonCode) + + // Java + assertTrue(javaCode.contains("public interface Outer extends Base {")) + assertTrue(javaCode.contains("public interface Inner {")) + assertTrue(javaCode.contains("public static void test() {")) + + // TypeScript + assertTrue(tsCode.contains("export namespace Outer {")) + assertTrue(tsCode.contains("export namespace Inner {")) + assertTrue(tsCode.contains("export function test() {")) + + // Python + assertTrue(pythonCode.contains("class Outer(Base):")) + assertTrue(pythonCode.contains("class Inner:")) + assertTrue(pythonCode.contains("@staticmethod")) + assertTrue(pythonCode.contains("def test():")) + } + + @Test + fun testAsyncFunction() { + val file = file("AsyncModule") { + asyncFunction("fetchData") { + returnType(string) + arg("id", integer) + returns(literal("data")) + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Async) ---") + println(javaCode) + println("--- Python (Async) ---") + println(pythonCode) + println("--- TypeScript (Async) ---") + println(tsCode) + + assertTrue(javaCode.contains("java.util.concurrent.CompletableFuture fetchData(Integer id) {")) + assertTrue(pythonCode.contains("async def fetchData(id: int) -> str:")) + assertTrue(tsCode.contains("export async function fetchData(id: number): string {")) + } + + @Test + fun testFunctionInterface() { + val file = file("InterfaceModule") { + namespace("Api") { + function("getData") { + returnType(string) + } + asyncFunction("postData") { + arg("data", string) + } + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Interface) ---") + println(javaCode) + println("--- Python (Interface) ---") + println(pythonCode) + println("--- TypeScript (Interface) ---") + println(tsCode) + + // Java + assertTrue(javaCode.contains("public interface Api {")) + assertTrue(javaCode.contains("public static String getData();")) + assertTrue(javaCode.contains("public static java.util.concurrent.CompletableFuture postData(String data);")) + + // Python + assertTrue(pythonCode.contains("class Api:")) + assertTrue(pythonCode.contains("@staticmethod")) + assertTrue(pythonCode.contains("def getData() -> str:")) + assertTrue(pythonCode.contains("...")) + assertTrue(pythonCode.contains("async def postData(data: str):")) + + // TypeScript + assertTrue(tsCode.contains("export namespace Api {")) + assertTrue(tsCode.contains("getData(): string;")) + assertTrue(tsCode.contains("postData(data: string): Promise;")) + } + + @Test + fun testReturnStatement() { + val file = file("ReturnModule") { + function("test") { + returns( + functionCall("myFunc") { + arg("a", RawExpression("1")) + }, + ) + } + function("test2") { + returns( + construct(type("User")) { + arg("id", RawExpression("1")) + }, + ) + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Return Statement) ---") + println(javaCode) + println("--- Python (Return Statement) ---") + println(pythonCode) + println("--- TypeScript (Return Statement) ---") + println(tsCode) + + // Java + assertTrue(javaCode.contains("return myFunc(1);")) + assertTrue(javaCode.contains("return new User(1);")) + + // Python + assertTrue(pythonCode.contains("return myFunc(a=1)")) + assertTrue(pythonCode.contains("return User(id=1)")) + + // TypeScript + assertTrue(tsCode.contains("return myFunc(1);")) + assertTrue(tsCode.contains("return { id: 1 };")) + } + + @Test + fun testPrimitiveTypes() { + val file = file("PrimitiveModule") { + struct("Data") { + field("count", integer) + field("name", string) + field("active", boolean) + field("ratio", number) + } + + function("process") { + returnType(boolean) + arg("count", integer) + arg("name", string) + print(literal("count: count")) + returns(literal(true)) + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Primitives) ---") + println(javaCode) + println("--- Python (Primitives) ---") + println(pythonCode) + println("--- TypeScript (Primitives) ---") + println(tsCode) + + // Java + assertTrue(javaCode.contains("public record Data")) + assertTrue(javaCode.contains("Integer count,")) + assertTrue(javaCode.contains("String name,")) + assertTrue(javaCode.contains("Boolean active,")) + assertTrue(javaCode.contains("Double ratio")) + assertTrue(javaCode.contains(") {")) + assertTrue(javaCode.contains("Boolean process(Integer count, String name) {")) + assertTrue(javaCode.contains("return true;")) + assertTrue(javaCode.contains("System.out.println(\"count: count\");")) + + // Python + assertTrue(pythonCode.contains("class Data:")) + assertTrue(pythonCode.contains("print('count: count')")) + assertTrue(pythonCode.contains("return True")) + + // TypeScript + assertTrue(tsCode.contains("console.log('count: count');")) + assertTrue(tsCode.contains("return true;")) + } + + @Test + fun testLiteralStatement() { + val file = file("LiteralModule") { + function("test") { + literal(1, integer) + literal("hello", string) + literal(true, boolean) + literal(1.2, number) + + returns( + construct(type("User")) { + arg("id", literal(1, integer)) + arg("name", literal("John", string)) + }, + ) + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Literal) ---") + println(javaCode) + println("--- Python (Literal) ---") + println(pythonCode) + println("--- TypeScript (Literal) ---") + println(tsCode) + + // Java + assertTrue(javaCode.contains("1;")) + assertTrue(javaCode.contains("\"hello\";")) + assertTrue(javaCode.contains("true;")) + assertTrue(javaCode.contains("1.2;")) + assertTrue(javaCode.contains("return new User(")) + + // Python + assertTrue(pythonCode.contains("1")) + assertTrue(pythonCode.contains("'hello'")) + assertTrue(pythonCode.contains("True")) + assertTrue(pythonCode.contains("1.2")) + assertTrue(pythonCode.contains("return User(id=1, name='John')")) + + // TypeScript + assertTrue(tsCode.contains("1;")) + assertTrue(tsCode.contains("'hello';")) + assertTrue(tsCode.contains("true;")) + assertTrue(tsCode.contains("1.2;")) + assertTrue(tsCode.contains("return { id: 1, name: 'John' };")) + } + + @Test + fun testArrayType() { + val file = file("ArrayModule") { + struct("User") { + field("tags", list(string)) + field("scores", list(integer)) + } + + function("process") { + returnType(list(boolean)) + arg("tags", list(string)) + returns(RawExpression("tags")) + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Array) ---") + println(javaCode) + println("--- Python (Array) ---") + println(pythonCode) + println("--- TypeScript (Array) ---") + println(tsCode) + + // Java + assertTrue(javaCode.contains("public record User")) + assertTrue(javaCode.contains("java.util.List tags,")) + assertTrue(javaCode.contains("java.util.List scores")) + assertTrue(javaCode.contains(") {")) + assertTrue(javaCode.contains("java.util.List process(java.util.List tags) {")) + assertTrue(javaCode.contains("return tags;")) + + // Python + assertTrue(pythonCode.contains("def process(tags: list[str]) -> list[bool]:")) + + // TypeScript + assertTrue(tsCode.contains("\"tags\": string[],")) + assertTrue(tsCode.contains("\"scores\": number[],")) + assertTrue(tsCode.contains("export function process(tags: string[]): boolean[] {")) + } + + @Test + fun testAssignmentStatement() { + val file = file("AssignmentModule") { + function("test") { + assign("x", literal(1, integer)) + assign( + "y", + functionCall("myFunc") { + arg("a", RawExpression("1")) + }, + ) + returns(RawExpression("x")) + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Assignment) ---") + println(javaCode) + println("--- Python (Assignment) ---") + println(pythonCode) + println("--- TypeScript (Assignment) ---") + println(tsCode) + + // Java + assertTrue(javaCode.contains("final var x = 1;")) + assertTrue(javaCode.contains("final var y = myFunc(1);")) + + // Python + assertTrue(pythonCode.contains("x = 1")) + assertTrue(pythonCode.contains("y = myFunc(a=1)")) + + // TypeScript + assertTrue(tsCode.contains("const x = 1;")) + assertTrue(tsCode.contains("const y = myFunc(1);")) + } + + @Test + fun testLiteralList() { + val file = file("LiteralListModule") { + function("test") { + assign("tags", literalList(listOf(literal("tag1"), literal("tag2")), string)) + returns(RawExpression("tags")) + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Literal List) ---") + println(javaCode) + println("--- Python (Literal List) ---") + println(pythonCode) + println("--- TypeScript (Literal List) ---") + println(tsCode) + + // Java + assertTrue(javaCode.contains("final var tags = java.util.List.of(\"tag1\", \"tag2\");")) + + // Python + assertTrue(pythonCode.contains("tags = ['tag1', 'tag2']")) + + // TypeScript + assertTrue(tsCode.contains("const tags = ['tag1', 'tag2'];")) + } + + @Test + fun testLiteralMap() { + val file = file("LiteralMapModule") { + function("test") { + assign("scores", literalMap(mapOf("Alice" to literal(10), "Bob" to literal(20)), string, integer)) + returns(RawExpression("scores")) + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Literal Map) ---") + println(javaCode) + println("--- Python (Literal Map) ---") + println(pythonCode) + println("--- TypeScript (Literal Map) ---") + println(tsCode) + + // Java + assertTrue(javaCode.contains("final var scores = java.util.Map.ofEntries(java.util.Map.entry(\"Alice\", 10), java.util.Map.entry(\"Bob\", 20));")) + + // Python + assertTrue(pythonCode.contains("scores = {'Alice': 10, 'Bob': 20}")) + + // TypeScript + assertTrue(tsCode.contains("const scores = { 'Alice': 10, 'Bob': 20 };")) + } + + @Test + fun testEmptyLiterals() { + val file = file("EmptyLiteralsModule") { + function("test") { + assign("emptyList", literalList(string)) + assign("emptyMap", literalMap(string, integer)) + + functionCall("myFunc") { + arg("list", emptyList(integer)) + arg("map", emptyMap(string, string)) + } + + construct(type("User")) { + arg("tags", emptyList(string)) + } + } + } + + val javaCode = JavaGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Empty Literals) ---") + println(javaCode) + println("--- TypeScript (Empty Literals) ---") + println(tsCode) + + assertTrue(javaCode.contains("final var emptyList = java.util.List.of();")) + assertTrue(javaCode.contains("final var emptyMap = java.util.Collections.emptyMap();")) + assertTrue(javaCode.contains("myFunc(java.util.List.of(), java.util.Collections.emptyMap());")) + assertTrue(javaCode.contains("new User(java.util.List.of());")) + + assertTrue(tsCode.contains("const emptyList = [] as string[];")) + assertTrue(tsCode.contains("const emptyMap = {};")) + assertTrue(tsCode.contains("myFunc([] as number[], {});")) + assertTrue(tsCode.contains("{ tags: [] as string[] };")) + } + + @Test + fun testStructConstructor() { + val file = file("StructConstructorModule") { + struct("User") { + field("id", integer) + field("name", string) + constructo { + arg("id", integer) + assign("this.id", RawExpression("id")) + assign("this.name", literal("Anonymous")) + } + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + val kotlinCode = KotlinGenerator.generate(file) + + println("--- Java (Struct Constructor) ---") + println(javaCode) + println("--- Python (Struct Constructor) ---") + println(pythonCode) + println("--- TypeScript (Struct Constructor) ---") + println(tsCode) + println("--- Kotlin (Struct Constructor) ---") + println(kotlinCode) + + // Java + assertTrue(javaCode.contains("public record User")) + assertTrue(javaCode.contains("Integer id,")) + assertTrue(javaCode.contains("String name")) + assertTrue(javaCode.contains(") {")) + assertTrue(javaCode.contains("User(Integer id) {")) + assertTrue(javaCode.contains("this(id, \"Anonymous\");")) + + // Kotlin + assertTrue(kotlinCode.contains("data class User(")) + assertTrue(kotlinCode.contains("val id: Int,")) + assertTrue(kotlinCode.contains("val name: String")) + assertTrue(kotlinCode.contains("constructor(id: Int) : this(null, null) {")) + assertTrue(kotlinCode.contains("this.id = id")) + assertTrue(kotlinCode.contains("this.name = \"Anonymous\"")) + } + + @Test + fun testSwitchStatement() { + val file = file("SwitchModule") { + function("test") { + arg("x", integer) + switch(RawExpression("x")) { + case(literal(1)) { + print(literal("one")) + returns(literal("ONE")) + } + case(literal(2)) { + assign("y", literal(22)) + returns(RawExpression("y")) + } + default { + returns(literal("MANY")) + } + } + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Switch) ---") + println(javaCode) + println("--- Python (Switch) ---") + println(pythonCode) + println("--- TypeScript (Switch) ---") + println(tsCode) + + // Java + assertTrue(javaCode.contains("switch (x) {")) + assertTrue(javaCode.contains("case 1 -> {")) + assertTrue(javaCode.contains("System.out.println(\"one\");")) + assertTrue(javaCode.contains("return \"ONE\";")) + assertTrue(javaCode.contains("default -> {")) + assertTrue(javaCode.contains("return \"MANY\";")) + // Python + assertTrue(pythonCode.contains("match x:")) + assertTrue(pythonCode.contains("case 1:")) + assertTrue(pythonCode.contains("print('one')")) + assertTrue(pythonCode.contains("return 'ONE'")) + assertTrue(pythonCode.contains("case 2:")) + assertTrue(pythonCode.contains("y = 22")) + assertTrue(pythonCode.contains("return y")) + assertTrue(pythonCode.contains("case _:")) + assertTrue(pythonCode.contains("return 'MANY'")) + // TypeScript + assertTrue(tsCode.contains("switch (x) {")) + assertTrue(tsCode.contains("case 1:")) + assertTrue(tsCode.contains("console.log('one');")) + assertTrue(tsCode.contains("return 'ONE';")) + assertTrue(tsCode.contains("case 2:")) + assertTrue(tsCode.contains("const y = 22;")) + assertTrue(tsCode.contains("return y;")) + assertTrue(tsCode.contains("default:")) + } + + @Test + fun testErrorStatement() { + val file = file("ErrorModule") { + function("test") { + error(literal("Something went wrong")) + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Error) ---") + println(javaCode) + println("--- Python (Error) ---") + println(pythonCode) + println("--- TypeScript (Error) ---") + println(tsCode) + + // Java + assertTrue(javaCode.contains("throw new IllegalStateException(\"Something went wrong\");")) + + // Python + assertTrue(pythonCode.contains("raise Exception('Something went wrong')")) + + // TypeScript + assertTrue(tsCode.contains("throw new Error('Something went wrong');")) + } + + @Test + fun testStructExtends() { + val file = file("ExtendsModule") { + struct("User") { + implements(type("BaseUser")) + field("id", integer) + } + struct("Data") { + implements(type("BaseData")) + field("value", string) + } + } + + val javaCode = JavaGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + + println("--- Java (Struct Extends) ---") + println(javaCode) + println("--- TypeScript (Struct Extends) ---") + println(tsCode) + println("--- Python (Struct Extends) ---") + println(pythonCode) + + assertTrue(javaCode.contains("public record User")) + assertTrue(javaCode.contains(") implements BaseUser {")) + assertTrue(javaCode.contains("public record Data")) + assertTrue(javaCode.contains(") implements BaseData {")) + + assertTrue(tsCode.contains("export type User = {")) + assertTrue(tsCode.contains("export type Data = {")) + + assertTrue(pythonCode.contains("class User(BaseUser):")) + assertTrue(pythonCode.contains("class Data(BaseData):")) + } + + @Test + fun testImport() { + val file = file("ImportModule") { + import("java.util", "List") + import("com.example", "Other") + struct("MyStruct") { + field("other", type("Other")) + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Import) ---") + println(javaCode) + println("--- Python (Import) ---") + println(pythonCode) + println("--- TypeScript (Import) ---") + println(tsCode) + + assertTrue(javaCode.contains("import java.util.List;")) + assertTrue(javaCode.contains("import com.example.Other;")) + + assertTrue(pythonCode.contains("from java.util import List")) + assertTrue(pythonCode.contains("from com.example import Other")) + + assertTrue(tsCode.contains("import { List } from 'java.util';")) + assertTrue(tsCode.contains("import { Other } from 'com.example';")) + } + + @Test + fun testPackage() { + val file = file("PackageModule") { + `package`("com.example.test") + struct("MyStruct") { + field("id", integer) + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Package) ---") + println(javaCode) + println("--- Python (Package) ---") + println(pythonCode) + println("--- TypeScript (Package) ---") + println(tsCode) + + assertTrue(javaCode.contains("package com.example.test;")) + assertTrue(pythonCode.contains("# package com.example.test")) + // TypeScript: packages are not emitted (they have no meaning in TS) + assertTrue(tsCode.contains("export type MyStruct = {")) + } + + @Test + fun testUnion() { + val file = file("UnionModule") { + union("Response") { + member("Success") + member("Error") + } + struct("Success") { + implements(type("Response")) + field("data", string) + } + struct("Error") { + implements(type("Response")) + field("message", string) + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Union) ---") + println(javaCode) + println("--- Python (Union) ---") + println(pythonCode) + println("--- TypeScript (Union) ---") + println(tsCode) + + // Java + assertTrue(javaCode.contains("public sealed interface Response permits Success, Error {}")) + assertTrue(javaCode.contains("public record Success")) + assertTrue(javaCode.contains(") implements Response {")) + assertTrue(javaCode.contains("public record Error")) + assertTrue(javaCode.contains(") implements Response {")) + + // Python + assertTrue(pythonCode.contains("class Response:")) + assertTrue(pythonCode.contains("class Success(Response):")) + assertTrue(pythonCode.contains("class Error(Response):")) + + // TypeScript + assertTrue(tsCode.contains("export type Response = Success | Error")) + assertTrue(tsCode.contains("export type Success = {")) + assertTrue(tsCode.contains("export type Error = {")) + } + + @Test + fun testMultipleUnions() { + val file = file("MultipleUnionsModule") { + union("Response") { + member("Response2XX") + } + union("Response2XX", extends = type("Response")) { + member("Response200") + } + struct("Response200") { + implements(type("Response2XX")) + field("body", string) + } + } + + val javaCode = JavaGenerator.generate(file) + println("--- Java (Multiple Unions) ---") + println(javaCode) + + assertTrue(javaCode.contains("public sealed interface Response permits Response2XX {}")) + assertTrue(javaCode.contains("public sealed interface Response2XX extends Response permits Response200 {}")) + assertTrue(javaCode.contains("public record Response200")) + assertTrue(javaCode.contains("String body")) + assertTrue(javaCode.contains(") implements Response2XX {")) + } + + @Test + fun testMultiDimensionalUnions() { + val file = file("MultiDimensionalModule") { + union("UnionA") { + member("SharedStruct") + } + union("UnionB") { + member("SharedStruct") + } + struct("SharedStruct") { + implements(type("UnionA")) + implements(type("UnionB")) + field("id", string) + } + } + + val javaCode = JavaGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + + println("--- Java (Multi-Dimensional Unions) ---") + println(javaCode) + println("--- TypeScript (Multi-Dimensional Unions) ---") + println(tsCode) + println("--- Python (Multi-Dimensional Unions) ---") + println(pythonCode) + + // Java + assertTrue(javaCode.contains("public record SharedStruct")) + assertTrue(javaCode.contains(") implements UnionA, UnionB {")) + + // TypeScript + assertTrue(tsCode.contains("export type SharedStruct = {")) + + // Python + assertTrue(pythonCode.contains("class SharedStruct(UnionA, UnionB):")) + } + + @Test + fun testGenericsInExtends() { + val file = file("GenericsExtendsModule") { + struct("Box") { + implements(type("BaseBox", string)) + field("value", string) + } + namespace("Api", extends = type("BaseApi", integer)) { + function("getData") { + returnType(string) + } + } + } + + val javaCode = JavaGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + + println("--- Java (Generics Extends) ---") + println(javaCode) + println("--- Python (Generics Extends) ---") + println(pythonCode) + println("--- TypeScript (Generics Extends) ---") + println(tsCode) + + assertTrue(javaCode.contains("public record Box")) + assertTrue(javaCode.contains(") implements BaseBox {")) + assertTrue(javaCode.contains("public interface Api extends BaseApi {")) + + assertTrue(pythonCode.contains("class Box(BaseBox[str]):")) + assertTrue(pythonCode.contains("class Api(BaseApi[int]):")) + + assertTrue(tsCode.contains("export type Box = {")) + assertTrue(tsCode.contains("export namespace Api {")) + } + + @Test + fun testEmptyStruct() { + val file = file("EmptyStructModule") { + struct("Empty") + } + + val javaCode = JavaGenerator.generate(file) + println("--- Java (Empty Struct) ---") + println(javaCode) + + assertTrue(javaCode.contains("public record Empty"), "Should be a record") + } + + @Test + fun testNestedStructInline() { + val file = file("NestedStructModule") { + struct("Response201") { + field("status", integer) + struct("Headers") { + field("token", type("Token")) + field("refreshToken", type("Token").nullable()) + } + field("headers", type("Headers")) + field("body", type("TodoDto")) + } + } + + val tsCode = TypeScriptGenerator.generate(file) + + println("--- TypeScript (Nested Struct Inline) ---") + println(tsCode) + + // The nested Headers struct should be inlined as an anonymous object type + assertTrue(tsCode.contains("\"headers\": {\"token\": Token, \"refreshToken\": Token | undefined},")) + // There should be no separate "export type Headers" declaration + assertTrue(!tsCode.contains("export type Headers")) + // The parent struct should still be emitted normally + assertTrue(tsCode.contains("export type Response201 = {")) + assertTrue(tsCode.contains("\"status\": number,")) + assertTrue(tsCode.contains("\"body\": TodoDto,")) + } + + @Test + fun testNullCheck() { + // Build AST: NullCheck on request.queries.page with a function call body and empty list fallback + val nullCheckExpr = NullCheck( + expression = FieldCall(FieldCall(VariableReference(Name.of("request")), Name.of("queries")), Name.of("page")), + body = FunctionCall( + receiver = VariableReference(Name.of("serialization")), + name = Name.of("serializeParam"), + arguments = mapOf( + Name.of("value") to VariableReference(Name.of("it")), + ), + ), + alternative = LiteralList(emptyList(), Type.String), + ) + + val file = file("NullCheckModule") { + function("test") { + returns(nullCheckExpr) + } + } + + val javaCode = JavaGenerator.generate(file) + val kotlinCode = KotlinGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + + println("--- Java (NullCheck) ---") + println(javaCode) + println("--- Kotlin (NullCheck) ---") + println(kotlinCode) + println("--- TypeScript (NullCheck) ---") + println(tsCode) + println("--- Python (NullCheck) ---") + println(pythonCode) + + // Java: Optional.ofNullable(...).map(it -> ...).orElse(...) + assertTrue(javaCode.contains("java.util.Optional.ofNullable(request.queries().page())")) + assertTrue(javaCode.contains(".map(it -> serialization.serializeParam(it))")) + assertTrue(javaCode.contains(".orElse(java.util.List.of())")) + + // Kotlin: expression?.let { body } ?: alternative + assertTrue(kotlinCode.contains("request.queries.page?.let { serialization.serializeParam(it) } ?: emptyList()")) + + // TypeScript: ternary with inlined expression + assertTrue(tsCode.contains("request.queries.page != null")) + assertTrue(tsCode.contains("serialization.serializeParam(request.queries.page)")) + assertTrue(tsCode.contains(": [] as string[]")) + + // Python: conditional with inlined expression + assertTrue(pythonCode.contains("serialization.serializeParam(request.queries.page)")) + assertTrue(pythonCode.contains("if request.queries.page is not None else")) + assertTrue(pythonCode.contains("[]")) + } + + @Test + fun testAssertStatement() { + val file = file("AssertModule") { + function("validate") { + arg("value", integer) + assertThat( + RawExpression("value > 0"), + "Value must be positive", + ) + } + } + + val javaCode = JavaGenerator.generate(file) + val kotlinCode = KotlinGenerator.generate(file) + val tsCode = TypeScriptGenerator.generate(file) + val pythonCode = PythonGenerator.generate(file) + + println("--- Java (Assert) ---") + println(javaCode) + println("--- Kotlin (Assert) ---") + println(kotlinCode) + println("--- TypeScript (Assert) ---") + println(tsCode) + println("--- Python (Assert) ---") + println(pythonCode) + + // Java + assertTrue(javaCode.contains("assert value > 0 : \"Value must be positive\";")) + + // Kotlin + assertTrue(kotlinCode.contains("assert(value > 0) { \"Value must be positive\" }")) + + // TypeScript + assertTrue( + tsCode.contains( + "if (!(value > 0)) throw new Error('Value must be positive');" + + "", + ), + ) + + // Python + assertTrue(pythonCode.contains("assert value > 0, 'Value must be positive'")) + } + + @Test + fun testDataObjectWithFields() { + val file = file("DataObjectModule") { + struct("Response400") { + implements(type("Response4XX", Type.Unit)) + implements(type("ResponseUnit")) + field("status", integer, isOverride = true) + field("headers", type("Headers"), isOverride = true) + field("body", Type.Unit, isOverride = true) + constructo { + assign("this.status", RawExpression("400")) + assign("this.headers", RawExpression("Headers")) + assign("this.body", RawExpression("Unit")) + } + struct("Headers") { + implements(type("Wirespec.Response.Headers")) + } + } + } + + val kotlinCode = KotlinGenerator.generate(file) + + println("--- Kotlin (Data Object With Fields) ---") + println(kotlinCode) + + // Should be data object, not data class + assertTrue(kotlinCode.contains("data object Response400")) + // Should implement interfaces + assertTrue(kotlinCode.contains("Response4XX")) + assertTrue(kotlinCode.contains("ResponseUnit")) + // Override fields with default values from constructor + assertTrue(kotlinCode.contains("override val status: Int")) + assertTrue(kotlinCode.contains("override val headers: Headers")) + assertTrue(kotlinCode.contains("override val body: Unit")) + // Nested struct should be present + assertTrue(kotlinCode.contains("object Headers : Wirespec.Response.Headers")) + // Should NOT contain data class + assertTrue(!kotlinCode.contains("data class Response400")) + } +} diff --git a/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/generator/ExtensionsTest.kt b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/generator/ExtensionsTest.kt new file mode 100644 index 000000000..fde1ee607 --- /dev/null +++ b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/generator/ExtensionsTest.kt @@ -0,0 +1,16 @@ +package community.flock.wirespec.ir.generator + +import community.flock.wirespec.ir.core.File +import community.flock.wirespec.ir.core.Name +import kotlin.test.Test +import kotlin.test.assertNotNull + +class ExtensionsTest { + @Test + fun testExtensions() { + val file = File(Name.of(""), emptyList()) + assertNotNull(file.generateJava()) + assertNotNull(file.generatePython()) + assertNotNull(file.generateTypeScript()) + } +} diff --git a/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/generator/JavaStructTest.kt b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/generator/JavaStructTest.kt new file mode 100644 index 000000000..65a5d5ed0 --- /dev/null +++ b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/generator/JavaStructTest.kt @@ -0,0 +1,172 @@ +package community.flock.wirespec.ir.generator + +import community.flock.wirespec.ir.core.file +import kotlin.test.Test +import kotlin.test.assertContains + +class JavaStructTest { + + @Test + fun testEmptyStructEmitAsRecord() { + val file = file("EmptyStruct") { + `package`("com.example") + struct("Empty") {} + } + + val output = JavaGenerator.generate(file) + + // Should be a record + assertContains(output, "public record Empty") + } + + @Test + fun testEmptyStructWithInterfaceEmitAsRecord() { + val file = file("EmptyRecordWithInterface") { + `package`("com.example") + union("MyUnion") { + member("Empty") + } + struct("Empty") { + implements(type("MyUnion")) + } + } + + val output = JavaGenerator.generate(file) + + // Should be a record + assertContains(output, "public record Empty") + assertContains(output, "implements MyUnion") + } + + @Test + fun testEmptyStructExtendingInterfaceEmitAsRecord() { + val file = file("EmptyStructExtendingInterface") { + `package`("com.example") + `interface`("MyInterface") {} + struct("MyStruct") { + implements(type("MyInterface")) + } + } + + val output = JavaGenerator.generate(file) + + // Should be a record + assertContains(output, "public record MyStruct") + assertContains(output, "implements MyInterface") + } + + @Test + fun testEmptyStructExtendingClassEmitAsClass() { + val file = file("EmptyStructExtendingClass") { + `package`("com.example") + struct("MyClass") {} + struct("MyStruct") { + implements(type("MyClass")) + } + } + + val output = JavaGenerator.generate(file) + + // Should be a record implementing the struct (treated as interface) + assertContains(output, "public record MyStruct") + assertContains(output, "implements MyClass") + } + + @Test + fun testStructWithFieldsEmitAsRecord() { + val file = file("FieldStruct") { + `package`("com.example") + struct("WithFields") { + field("id", string) + } + } + + val output = JavaGenerator.generate(file) + + // Should be a record + assertContains(output, "public record WithFields") + assertContains(output, "String id") + } + + @Test + fun testStructWithExtendsEmitAsClass() { + val file = file("ExtendsStruct") { + `package`("com.example") + struct("Extending") { + implements(type("Base")) + field("id", string) + } + } + + val output = JavaGenerator.generate(file) + + // Should be a record implementing the interface (unknown type treated as interface) + assertContains(output, "public record Extending") + assertContains(output, "implements Base") + } + + @Test + fun testRecordWithInterface() { + val file = file("RecordWithInterface") { + `package`("com.example") + union("MyUnion") { + member("MyRecord") + } + struct("MyRecord") { + implements(type("MyUnion")) + field("id", string) + } + } + + val output = JavaGenerator.generate(file) + + // Should be a record implementing the union + assertContains(output, "public record MyRecord") + assertContains(output, "implements MyUnion") + // Check syntax: record MyRecord (...) implements MyUnion { + assertContains(output, "public record MyRecord (") + assertContains(output, "String id") + assertContains(output, ") implements MyUnion {") + } + + @Test + fun testStructExtendingInterfaceEmitAsRecord() { + val file = file("StructExtendingInterface") { + `package`("com.example") + `interface`("MyInterface") {} + struct("MyStruct") { + implements(type("MyInterface")) + field("id", string) + } + } + + val output = JavaGenerator.generate(file) + + // Should be a record because MyInterface is an interface + assertContains(output, "public record MyStruct") + assertContains(output, "implements MyInterface") + assertContains(output, "public record MyStruct (") + assertContains(output, "String id") + assertContains(output, ") implements MyInterface {") + } + + @Test + fun testStructExtendingStructEmitAsClass() { + val file = file("StructExtendingStruct") { + `package`("com.example") + struct("Base") { + field("name", string) + } + struct("Derived") { + implements(type("Base")) + field("id", string) + } + } + + val output = JavaGenerator.generate(file) + + // Should be a record implementing the struct (treated as interface) + assertContains(output, "public record Derived") + assertContains(output, "implements Base") + } +} diff --git a/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/generator/PetTest.kt b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/generator/PetTest.kt new file mode 100644 index 000000000..1ae0cf2cc --- /dev/null +++ b/src/compiler/ir/src/commonTest/kotlin/community/flock/wirespec/ir/generator/PetTest.kt @@ -0,0 +1,324 @@ +package community.flock.wirespec.ir.generator + +import community.flock.wirespec.ir.core.EnumReference +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.NullCheck +import community.flock.wirespec.ir.core.NullLiteral +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.file +import kotlin.test.Test +import kotlin.test.assertTrue + +class PetTest { + + @Test + fun getTodos() { + val getTodos = file("GetTodos") { + `package`("community.flock.wirespec.generated.examples.spring.endpoint") + + import("community.flock.wirespec.generated.examples.spring.model", "Todo") + import("community.flock.wirespec.generated.examples.spring.model", "Error") + + import("community.flock.wirespec.java", "Wirespec") + + namespace("GetTodos", type("Wirespec.Endpoint")) { + struct("Path") { + implements(type("Wirespec.Path")) + } + + struct("Queries") { + implements(type("Wirespec.Queries")) + field("done", boolean.nullable()) + } + + struct("RequestHeaders") { + implements(type("Wirespec.Request.Headers")) + } + + struct("Request") { + implements(type("Wirespec.Request", Type.Unit)) + field("path", type("Path")) + field("method", type("Wirespec.Method")) + field("queries", type("Queries")) + field("headers", type("RequestHeaders")) + field("body", type("Void")) + constructo { + arg("done", boolean.nullable()) + assign("path", construct(type("Path"))) + assign("method", EnumReference(Type.Custom("Wirespec.Method"), Name.of("GET"))) + assign( + "queries", + construct(type("Queries")) { + arg("done", RawExpression("done")) + }, + ) + assign("headers", construct(type("RequestHeaders"))) + } + } + + union("Response", extends = type("Wirespec.Response")) { + member("Response2XX") + member("Response4XX") + } + + union("Response2XX", extends = type("Response")) { + member("Response200") + } + + union("Response4XX", extends = type("Response")) { + member("Response404") + } + + struct("Response200") { + implements(type("Response2XX")) + field("status", integer) + field("headers", type("Headers")) + field("body", list(type("Todo"))) + struct("Headers") { + implements(type("Wirespec.Response.Headers")) + } + constructo { + arg("body", list(type("Todo"))) + assign("status", literal(200)) + assign("headers", construct(type("Headers"))) + assign("body", RawExpression("body")) + } + } + + struct("Response404") { + implements(type("Response4XX")) + field("status", integer) + field("headers", type("Headers")) + field("body", type("Error")) + struct("Headers") { + implements(type("Wirespec.Response.Headers")) + } + constructo { + arg("body", type("Error")) + assign("status", literal(404)) + assign("headers", construct(type("Headers"))) + assign("body", RawExpression("body")) + } + } + + `interface`("Handler") { + extends(type("Wirespec.Handler")) + function("toRequest") { + returnType(type("Wirespec.RawRequest")) + arg("serialization", type("Wirespec.Serializer")) + arg("request", type("Request")) + returns( + construct(type("Wirespec.RawRequest")) { + arg("method", functionCall("request.method.name")) + arg("path", listOf(kotlin.collections.listOf(literal("todos")), string)) + arg( + "queries", + mapOf( + mapOf( + "done" to NullCheck( + expression = RawExpression("request.queries.done"), + body = functionCall("serialization.serializeParam") { + arg("value", RawExpression("it")) + arg( + "type", + functionCall("getType", receiver = RawExpression("Wirespec")) { + arg("type", RawExpression("Boolean.class")) + arg("container", RawExpression("java.util.Optional.class")) + }, + ) + }, + alternative = emptyList(string), + ), + ), + string, + string, + ), + ) + arg("headers", emptyMap(string, string)) + arg("body", NullLiteral) + }, + ) + } + + function("fromRequest") { + returnType(type("Request")) + arg("serialization", type("Wirespec.Deserializer")) + arg("request", type("Wirespec.RawRequest")) + returns( + construct(type("Request")) { + arg( + "done", + NullCheck( + expression = functionCall("request.queries().get") { + arg("key", literal("done")) + }, + body = functionCall("serialization.deserializeParam") { + arg("value", RawExpression("it")) + arg( + "type", + functionCall("getType", receiver = RawExpression("Wirespec")) { + arg("type", RawExpression("Boolean.class")) + arg("container", RawExpression("java.util.Optional.class")) + }, + ) + }, + alternative = NullLiteral, + ), + ) + }, + ) + } + + function("fromResponse") { + returnType(type("Response")) + arg("serialization", type("Wirespec.Deserializer")) + arg("response", type("Wirespec.RawResponse")) + switch(functionCall("response.statusCode")) { + case(literal(200)) { + returns( + construct(type("Response200")) { + arg( + "body", + functionCall("serialization.deserializeBody") { + arg("body", RawExpression("response.body()")) + arg( + "type", + functionCall("getType", receiver = RawExpression("Wirespec")) { + arg("type", RawExpression("Todo.class")) + arg("container", RawExpression("java.util.List.class")) + }, + ) + }, + ) + }, + ) + } + case(literal(404)) { + returns( + construct(type("Response404")) { + arg( + "body", + functionCall("serialization.deserializeBody") { + arg("body", functionCall("response.body")) + arg( + "type", + functionCall("getType", receiver = RawExpression("Wirespec")) { + arg("type", RawExpression("Error.class")) + arg("container", NullLiteral) + }, + ) + }, + ) + }, + ) + } + default { + error(RawExpression("\"Cannot match response with status: \" + response.statusCode()")) + } + } + } + + function("toResponse") { + returnType(type("Wirespec.RawResponse")) + arg("serialization", type("Wirespec.Serializer")) + arg("response", type("Response")) + switch(RawExpression("response"), "r") { + case(type("Response200")) { + returns( + construct(type("Wirespec.RawResponse")) { + arg("statusCode", functionCall("r.status")) + arg("headers", RawExpression("java.util.Collections.emptyMap()")) + arg( + "body", + functionCall("serialization.serializeBody") { + arg("body", RawExpression("r.body")) + arg( + "type", + functionCall("getType", receiver = RawExpression("Wirespec")) { + arg("type", RawExpression("Todo.class")) + arg("container", RawExpression("java.util.List.class")) + }, + ) + }, + ) + }, + ) + } + case(type("Response404")) { + returns( + construct(type("Wirespec.RawResponse")) { + arg("statusCode", functionCall("r.status")) + arg("headers", RawExpression("java.util.Collections.emptyMap()")) + arg( + "body", + functionCall("serialization.serializeBody") { + arg("body", RawExpression("r.body")) + arg( + "type", + functionCall("getType", receiver = RawExpression("Wirespec")) { + arg("type", RawExpression("Error.class")) + arg("container", NullLiteral) + }, + ) + }, + ) + }, + ) + } + default { + error(RawExpression("\"Cannot match response with status: \" + response.status()")) + } + } + } + + asyncFunction("getTodos") { + returnType(type("Response")) + } + } + } + } + + val output = JavaGenerator.generate(getTodos) + + println(output) + + assertTrue(output.contains("package community.flock.wirespec.generated.examples.spring.endpoint;")) + assertTrue(output.contains("import community.flock.wirespec.java.Wirespec;")) + assertTrue(output.contains("import community.flock.wirespec.generated.examples.spring.model.Todo;")) + assertTrue(output.contains("import community.flock.wirespec.generated.examples.spring.model.Error;")) + // Verify the imports are correctly composed from path + type + + assertTrue(output.contains("public sealed interface Response extends Wirespec.Response permits Response2XX, Response4XX {}")) + assertTrue(output.contains("public sealed interface Response2XX extends Response permits Response200 {}")) + assertTrue(output.contains("public sealed interface Response4XX extends Response permits Response404 {}")) + + assertTrue(output.contains("public static record Response200 (")) + assertTrue(output.contains("java.util.List body")) + assertTrue(output.contains(") implements Response2XX {")) + assertTrue(output.contains("public static record Headers () implements Wirespec.Response.Headers {")) + + assertTrue(output.contains("public static record Response404 (")) + assertTrue(output.contains("Error body")) + assertTrue(output.contains(") implements Response4XX {")) + + assertTrue(output.contains("public static record Queries (")) + assertTrue(output.contains("java.util.Optional done")) + assertTrue(output.contains(") implements Wirespec.Queries {")) + assertTrue(output.contains("public static record Request (")) + assertTrue(output.contains(") implements Wirespec.Request {")) + assertTrue(output.contains("public Request(java.util.Optional done) {")) + assertTrue(output.contains("request.method.name()")) + assertTrue(output.contains("java.util.List.of(\"todos\")")) + assertTrue(output.contains("java.util.Map.ofEntries(java.util.Map.entry(\"done\", java.util.Optional.ofNullable(request.queries.done).map(it -> serialization.serializeParam(it, Wirespec.getType(Boolean.class, java.util.Optional.class))).orElse(java.util.List.of())))")) + assertTrue(output.contains("java.util.Collections.emptyMap()")) + assertTrue(output.contains("null")) + assertTrue(output.contains("public interface Handler extends Wirespec.Handler {")) + assertTrue(output.contains("public java.util.concurrent.CompletableFuture getTodos();")) + assertTrue(output.contains("public default Wirespec.RawResponse toResponse(Wirespec.Serializer serialization, Response response) {")) + assertTrue(output.contains("if (response instanceof Response200 r) {")) + assertTrue(output.contains("else if (response instanceof Response404 r) {")) + assertTrue(output.contains("return new Wirespec.RawResponse(")) + assertTrue(output.contains("else {")) + } +} diff --git a/src/plugin/arguments/build.gradle.kts b/src/plugin/arguments/build.gradle.kts index bdec03699..c64b028f9 100644 --- a/src/plugin/arguments/build.gradle.kts +++ b/src/plugin/arguments/build.gradle.kts @@ -39,6 +39,8 @@ kotlin { api(project(":src:compiler:emitters:java")) api(project(":src:compiler:emitters:typescript")) api(project(":src:compiler:emitters:python")) + api(project(":src:compiler:emitters:rust")) + api(project(":src:compiler:emitters:scala")) api(project(":src:compiler:emitters:wirespec")) implementation(project(":src:converter:avro")) implementation(project(":src:converter:openapi")) diff --git a/src/plugin/arguments/src/commonMain/kotlin/community/flock/wirespec/plugin/Language.kt b/src/plugin/arguments/src/commonMain/kotlin/community/flock/wirespec/plugin/Language.kt index 04b53f3f9..e30c637da 100644 --- a/src/plugin/arguments/src/commonMain/kotlin/community/flock/wirespec/plugin/Language.kt +++ b/src/plugin/arguments/src/commonMain/kotlin/community/flock/wirespec/plugin/Language.kt @@ -4,9 +4,15 @@ import community.flock.wirespec.compiler.core.emit.EmitShared import community.flock.wirespec.compiler.core.emit.PackageName import community.flock.wirespec.converter.avro.AvroEmitter import community.flock.wirespec.emitters.java.JavaEmitter +import community.flock.wirespec.emitters.java.JavaIrEmitter import community.flock.wirespec.emitters.kotlin.KotlinEmitter +import community.flock.wirespec.emitters.kotlin.KotlinIrEmitter import community.flock.wirespec.emitters.python.PythonEmitter +import community.flock.wirespec.emitters.python.PythonIrEmitter +import community.flock.wirespec.emitters.rust.RustIrEmitter +import community.flock.wirespec.emitters.scala.ScalaIrEmitter import community.flock.wirespec.emitters.typescript.TypeScriptEmitter +import community.flock.wirespec.emitters.typescript.TypeScriptIrEmitter import community.flock.wirespec.emitters.wirespec.WirespecEmitter import community.flock.wirespec.openapi.v2.OpenAPIV2Emitter import community.flock.wirespec.openapi.v3.OpenAPIV3Emitter @@ -16,6 +22,8 @@ enum class Language { Kotlin, TypeScript, Python, + Rust, + Scala, Wirespec, OpenAPIV2, OpenAPIV3, @@ -32,9 +40,24 @@ fun Language.toEmitter(packageName: PackageName, emitShared: EmitShared) = when Language.Java -> JavaEmitter(packageName, emitShared) Language.Kotlin -> KotlinEmitter(packageName, emitShared) Language.Python -> PythonEmitter(packageName, emitShared) + Language.Rust -> RustIrEmitter(packageName, emitShared) + Language.Scala -> ScalaIrEmitter(packageName, emitShared) Language.TypeScript -> TypeScriptEmitter() Language.Wirespec -> WirespecEmitter() Language.OpenAPIV2 -> OpenAPIV2Emitter Language.OpenAPIV3 -> OpenAPIV3Emitter Language.Avro -> AvroEmitter } + +fun Language.toIrEmitter(packageName: PackageName, emitShared: EmitShared) = when (this) { + Language.Java -> JavaIrEmitter(packageName, emitShared) + Language.Kotlin -> KotlinIrEmitter(packageName, emitShared) + Language.Python -> PythonIrEmitter(packageName, emitShared) + Language.Rust -> RustIrEmitter(packageName, emitShared) + Language.Scala -> ScalaIrEmitter(packageName, emitShared) + Language.TypeScript -> TypeScriptIrEmitter() + Language.Wirespec -> WirespecEmitter() + Language.OpenAPIV2 -> OpenAPIV2Emitter + Language.OpenAPIV3 -> OpenAPIV3Emitter + Language.Avro -> AvroEmitter +} diff --git a/src/plugin/arguments/src/commonMain/kotlin/community/flock/wirespec/plugin/WirespecArguments.kt b/src/plugin/arguments/src/commonMain/kotlin/community/flock/wirespec/plugin/WirespecArguments.kt index a36aa9072..7e5d244ae 100644 --- a/src/plugin/arguments/src/commonMain/kotlin/community/flock/wirespec/plugin/WirespecArguments.kt +++ b/src/plugin/arguments/src/commonMain/kotlin/community/flock/wirespec/plugin/WirespecArguments.kt @@ -20,6 +20,7 @@ sealed interface WirespecArguments { val logger: Logger val shared: Boolean val strict: Boolean + val ir: Boolean } data class CompilerArguments( @@ -31,6 +32,7 @@ data class CompilerArguments( override val logger: Logger, override val shared: Boolean, override val strict: Boolean, + override val ir: Boolean, ) : WirespecArguments data class ConverterArguments( @@ -43,6 +45,7 @@ data class ConverterArguments( override val logger: Logger, override val shared: Boolean, override val strict: Boolean, + override val ir: Boolean, ) : WirespecArguments fun PackageName?.toDirectory() = this?.value diff --git a/src/plugin/arguments/src/commonTest/kotlin/community/flock/wirespec/plugin/LanguageTest.kt b/src/plugin/arguments/src/commonTest/kotlin/community/flock/wirespec/plugin/LanguageTest.kt index 3f7445818..4805818a9 100644 --- a/src/plugin/arguments/src/commonTest/kotlin/community/flock/wirespec/plugin/LanguageTest.kt +++ b/src/plugin/arguments/src/commonTest/kotlin/community/flock/wirespec/plugin/LanguageTest.kt @@ -6,6 +6,6 @@ import kotlin.test.Test class LanguageTest { @Test fun testLanguages() { - Language.toString() shouldBe "Java, Kotlin, TypeScript, Python, Wirespec, OpenAPIV2, OpenAPIV3, Avro" + Language.toString() shouldBe "Java, Kotlin, TypeScript, Python, Rust, Scala, Wirespec, OpenAPIV2, OpenAPIV3, Avro" } } diff --git a/src/plugin/cli/src/commonMain/kotlin/community/flock/wirespec/plugin/cli/CommandLineArgumentsParser.kt b/src/plugin/cli/src/commonMain/kotlin/community/flock/wirespec/plugin/cli/CommandLineArgumentsParser.kt index 997e62611..1d32d5bf7 100644 --- a/src/plugin/cli/src/commonMain/kotlin/community/flock/wirespec/plugin/cli/CommandLineArgumentsParser.kt +++ b/src/plugin/cli/src/commonMain/kotlin/community/flock/wirespec/plugin/cli/CommandLineArgumentsParser.kt @@ -46,6 +46,7 @@ import community.flock.wirespec.plugin.io.read import community.flock.wirespec.plugin.io.wirespecSources import community.flock.wirespec.plugin.io.write import community.flock.wirespec.plugin.toEmitter +import community.flock.wirespec.plugin.toIrEmitter enum class Options(vararg val flags: String) { Input("-i", "--input"), @@ -55,6 +56,7 @@ enum class Options(vararg val flags: String) { LogLevel("--log-level"), Shared("--shared"), Strict("--strict"), + Ir("--ir"), } class WirespecCli : NoOpCliktCommand(name = "wirespec") { @@ -74,6 +76,7 @@ private abstract class CommonOptions : CliktCommand() { val logLevel by option(*Options.LogLevel.flags, help = "Log level: $Level").default("$INFO") val shared by option(*Options.Shared.flags, help = "Generate shared wirespec code").flag(default = false) val strict by option(*Options.Strict.flags, help = "Strict mode").flag() + val ir by option(*Options.Ir.flags, help = "Output intermediate representation").flag(default = false) fun String.toLogLevel() = when (trim().uppercase()) { "DEBUG" -> DEBUG @@ -119,7 +122,7 @@ private class Compile( } } - val emitters = languages.toEmitters(PackageName(packageName), EmitShared(shared)) + val emitters = languages.toEmitters(PackageName(packageName), EmitShared(shared), ir) val outputDir = inputPath?.let { Directory(getOutPutPath(it, output).or(::handleError)) } CompilerArguments( @@ -131,6 +134,7 @@ private class Compile( logger = logger, shared = shared, strict = strict, + ir = ir, ).let(compiler) } } @@ -159,7 +163,7 @@ private class Convert( .also { logger.info("Found 1 file to process: $inputPath") } } - val emitters = languages.toEmitters(PackageName(packageName), EmitShared(shared)) + val emitters = languages.toEmitters(PackageName(packageName), EmitShared(shared), ir) val directory = inputPath?.let { Directory(getOutPutPath(it, output).or(::handleError)) } ConverterArguments( format = format, @@ -171,12 +175,13 @@ private class Convert( logger = logger, shared = shared, strict = strict, + ir = ir, ).let(converter) } } private fun handleError(string: String): Nothing = throw CliktError(string) -private fun List.toEmitters(packageName: PackageName, emitShared: EmitShared) = this - .map { it.toEmitter(packageName, emitShared) } +private fun List.toEmitters(packageName: PackageName, emitShared: EmitShared, ir: Boolean = false) = this + .map { if (ir) it.toIrEmitter(packageName, emitShared) else it.toEmitter(packageName, emitShared) } .toNonEmptySetOrNull() ?: nonEmptySetOf(WirespecEmitter()) diff --git a/src/plugin/cli/src/commonTest/kotlin/community/flock/wirespec/plugin/cli/CommandLineEntitiesTest.kt b/src/plugin/cli/src/commonTest/kotlin/community/flock/wirespec/plugin/cli/CommandLineEntitiesTest.kt index bbc2b84b6..ed5d424ad 100644 --- a/src/plugin/cli/src/commonTest/kotlin/community/flock/wirespec/plugin/cli/CommandLineEntitiesTest.kt +++ b/src/plugin/cli/src/commonTest/kotlin/community/flock/wirespec/plugin/cli/CommandLineEntitiesTest.kt @@ -29,6 +29,7 @@ class CommandLineEntitiesTest { // To enable flags they only need the flag name. Therefore, the 'argument' part is null. Options.Shared -> null Options.Strict -> null + Options.Ir -> null }, ) }.toTypedArray() @@ -45,6 +46,7 @@ class CommandLineEntitiesTest { } it.shared.also(::println) shouldBe true it.strict shouldBe true + it.ir shouldBe true }, noopConverter {}, ).main(arrayOf("compile") + opts) @@ -67,6 +69,7 @@ class CommandLineEntitiesTest { } it.shared shouldBe false it.strict shouldBe false + it.ir shouldBe false }, ).main(arrayOf("convert", "-i", "src/commonTest/resources/openapi/keto.json", "openapiv2")) } diff --git a/src/plugin/gradle/src/main/kotlin/BaseWirespecTask.kt b/src/plugin/gradle/src/main/kotlin/BaseWirespecTask.kt index 62a8b65cc..ff04669f2 100644 --- a/src/plugin/gradle/src/main/kotlin/BaseWirespecTask.kt +++ b/src/plugin/gradle/src/main/kotlin/BaseWirespecTask.kt @@ -16,6 +16,7 @@ import community.flock.wirespec.plugin.io.Name import community.flock.wirespec.plugin.io.Source import community.flock.wirespec.plugin.io.write import community.flock.wirespec.plugin.toEmitter +import community.flock.wirespec.plugin.toIrEmitter import org.gradle.api.DefaultTask import org.gradle.api.file.DirectoryProperty import org.gradle.api.provider.ListProperty @@ -58,6 +59,11 @@ abstract class BaseWirespecTask : DefaultTask() { @get:Option(option = "strict", description = "strict parsing mode") abstract val strict: Property + @get:Input + @get:Optional + @get:Option(option = "ir", description = "output intermediate representation") + abstract val ir: Property + @Internal val wirespecLogger = object : Logger(Level.INFO) { override fun debug(string: String) = logger.debug(string) @@ -88,7 +94,7 @@ abstract class BaseWirespecTask : DefaultTask() { } protected fun emitters() = languages.get() - .map { it.toEmitter(packageNameValue(), sharedValue()) } + .map { if (ir.getOrElse(false)) it.toIrEmitter(packageNameValue(), sharedValue()) else it.toEmitter(packageNameValue(), sharedValue()) } .plus(emitter()) .mapNotNull { it } .toNonEmptySetOrNull() diff --git a/src/plugin/gradle/src/main/kotlin/CompileWirespecTask.kt b/src/plugin/gradle/src/main/kotlin/CompileWirespecTask.kt index c39cfb9b0..23e1427e7 100644 --- a/src/plugin/gradle/src/main/kotlin/CompileWirespecTask.kt +++ b/src/plugin/gradle/src/main/kotlin/CompileWirespecTask.kt @@ -50,6 +50,7 @@ abstract class CompileWirespecTask : BaseWirespecTask() { logger = wirespecLogger, shared = shared.getOrElse(true), strict = strict.getOrElse(false), + ir = ir.getOrElse(false), ).let(::compile) } } diff --git a/src/plugin/gradle/src/main/kotlin/ConvertWirespecTask.kt b/src/plugin/gradle/src/main/kotlin/ConvertWirespecTask.kt index bba6a49bb..b5cabe6e2 100644 --- a/src/plugin/gradle/src/main/kotlin/ConvertWirespecTask.kt +++ b/src/plugin/gradle/src/main/kotlin/ConvertWirespecTask.kt @@ -64,6 +64,7 @@ abstract class ConvertWirespecTask : BaseWirespecTask() { logger = wirespecLogger, shared = shared.getOrElse(true), strict = strict.getOrElse(false), + ir = ir.getOrElse(false), ).let(::convert) } } diff --git a/src/plugin/maven/src/main/kotlin/community/flock/wirespec/plugin/maven/mojo/BaseMojo.kt b/src/plugin/maven/src/main/kotlin/community/flock/wirespec/plugin/maven/mojo/BaseMojo.kt index 40a4dab4a..d8075c7c6 100644 --- a/src/plugin/maven/src/main/kotlin/community/flock/wirespec/plugin/maven/mojo/BaseMojo.kt +++ b/src/plugin/maven/src/main/kotlin/community/flock/wirespec/plugin/maven/mojo/BaseMojo.kt @@ -19,6 +19,7 @@ import community.flock.wirespec.plugin.io.write import community.flock.wirespec.plugin.maven.compiler.JavaCompiler import community.flock.wirespec.plugin.maven.compiler.KotlinCompiler import community.flock.wirespec.plugin.toEmitter +import community.flock.wirespec.plugin.toIrEmitter import org.apache.maven.plugin.AbstractMojo import org.apache.maven.plugins.annotations.Parameter import org.apache.maven.project.MavenProject @@ -71,6 +72,12 @@ abstract class BaseMojo : AbstractMojo() { @Parameter protected var strict: Boolean = true + /** + * Specifies whether to output intermediate representation. Default 'false'. + */ + @Parameter + protected var ir: Boolean = false + /** * Source directory. Default 'null'. */ @@ -87,26 +94,28 @@ abstract class BaseMojo : AbstractMojo() { override fun error(string: String) = log.error(string) } - private fun emitter() = if (emitterClass != null) { - val clazz = getClassLoader(project).loadClass(emitterClass) ?: error("No class found: $emitterClass") - val constructor = clazz.constructors.first() ?: error("No constructor found: $emitterClass") - val args: List = constructor.parameters - .map { - when (it.type) { - PackageName::class.java -> PackageName(packageName) - EmitShared::class.java -> EmitShared(shared) - else -> error("Cannot map constructor parameter: $emitterClass - ${it.type.simpleName}") + private val emitter + get() = try { + val clazz = getClassLoader(project).loadClass(emitterClass) + val constructor = clazz.constructors.first() + val args: List = constructor.parameters + .map { + when (it.type) { + PackageName::class.java -> PackageName(packageName) + EmitShared::class.java -> EmitShared(shared) + else -> error("Cannot map constructor parameter") + } } - } - constructor.newInstance(*args.toTypedArray()) as Emitter - } else { - null - } + constructor.newInstance(*args.toTypedArray()) as Emitter + } catch (e: Exception) { + logger.debug(e.toString()) + null + } val emitters get() = languages - .map { it.toEmitter(PackageName(packageName), EmitShared(shared)) } - .plus(emitter()) + .map { if (ir) it.toIrEmitter(PackageName(packageName), EmitShared(shared)) else it.toEmitter(PackageName(packageName), EmitShared(shared)) } + .plus(emitter) .mapNotNull { it } .toNonEmptySetOrNull() ?: throw PickAtLeastOneLanguageOrEmitter() diff --git a/src/plugin/maven/src/main/kotlin/community/flock/wirespec/plugin/maven/mojo/CompileMojo.kt b/src/plugin/maven/src/main/kotlin/community/flock/wirespec/plugin/maven/mojo/CompileMojo.kt index 5ca92a219..5d5b74542 100644 --- a/src/plugin/maven/src/main/kotlin/community/flock/wirespec/plugin/maven/mojo/CompileMojo.kt +++ b/src/plugin/maven/src/main/kotlin/community/flock/wirespec/plugin/maven/mojo/CompileMojo.kt @@ -53,6 +53,7 @@ class CompileMojo : BaseMojo() { logger = logger, shared = shared, strict = strict, + ir = ir, ).let(::compile) } } diff --git a/src/plugin/maven/src/main/kotlin/community/flock/wirespec/plugin/maven/mojo/ConvertMojo.kt b/src/plugin/maven/src/main/kotlin/community/flock/wirespec/plugin/maven/mojo/ConvertMojo.kt index 5db0ad917..0dbbb2375 100644 --- a/src/plugin/maven/src/main/kotlin/community/flock/wirespec/plugin/maven/mojo/ConvertMojo.kt +++ b/src/plugin/maven/src/main/kotlin/community/flock/wirespec/plugin/maven/mojo/ConvertMojo.kt @@ -136,6 +136,7 @@ class ConvertMojo : BaseMojo() { logger = logger, shared = shared, strict = strict, + ir = ir, ).let(::convert) } } diff --git a/src/plugin/npm/src/jsMain/kotlin/community/flock/wirespec/plugin/npm/Main.kt b/src/plugin/npm/src/jsMain/kotlin/community/flock/wirespec/plugin/npm/Main.kt index 7dc19fd93..601d35eea 100644 --- a/src/plugin/npm/src/jsMain/kotlin/community/flock/wirespec/plugin/npm/Main.kt +++ b/src/plugin/npm/src/jsMain/kotlin/community/flock/wirespec/plugin/npm/Main.kt @@ -23,14 +23,10 @@ import community.flock.wirespec.compiler.utils.NoLogger import community.flock.wirespec.compiler.utils.noLogger import community.flock.wirespec.converter.avro.AvroEmitter import community.flock.wirespec.converter.avro.AvroParser -import community.flock.wirespec.emitters.java.JavaEmitter -import community.flock.wirespec.emitters.java.JavaShared -import community.flock.wirespec.emitters.kotlin.KotlinEmitter -import community.flock.wirespec.emitters.kotlin.KotlinShared -import community.flock.wirespec.emitters.python.PythonEmitter -import community.flock.wirespec.emitters.python.PythonShared -import community.flock.wirespec.emitters.typescript.TypeScriptEmitter -import community.flock.wirespec.emitters.typescript.TypeScriptShared +import community.flock.wirespec.emitters.java.JavaIrEmitter +import community.flock.wirespec.emitters.kotlin.KotlinIrEmitter +import community.flock.wirespec.emitters.python.PythonIrEmitter +import community.flock.wirespec.emitters.typescript.TypeScriptIrEmitter import community.flock.wirespec.emitters.wirespec.WirespecEmitter import community.flock.wirespec.generator.generate import community.flock.wirespec.openapi.v2.OpenAPIV2Emitter @@ -44,10 +40,10 @@ import kotlinx.serialization.json.Json @JsExport enum class Shared(val source: String) { - KOTLIN(KotlinShared.source), - JAVA(JavaShared.source), - TYPESCRIPT(TypeScriptShared.source), - PYTHON(PythonShared.source), + KOTLIN(KotlinIrEmitter().shared.source), + JAVA(JavaIrEmitter().shared.source), + TYPESCRIPT(TypeScriptIrEmitter().shared.source), + PYTHON(PythonIrEmitter().shared.source), } @JsExport @@ -98,43 +94,23 @@ fun generate(source: String, type: String): WsStringResult = object : ParseConte fun emit(wsAst: WsAST, emitter: Emitters, packageName: String, emitShared: Boolean): Array { val ast = wsAst.consume() return when (emitter) { - Emitters.WIRESPEC -> ast.modules.flatMap { WirespecEmitter().emit(it, noLogger) } - Emitters.TYPESCRIPT -> ast.modules.flatMap { TypeScriptEmitter().emit(it, noLogger) } - Emitters.JAVA -> ast.modules.flatMap { - JavaEmitter(PackageName(packageName), EmitShared(emitShared)).emit( - it, - noLogger, - ) - } - - Emitters.KOTLIN -> ast.modules.flatMap { - KotlinEmitter(PackageName(packageName), EmitShared(emitShared)).emit( - it, - noLogger, - ) - } - - Emitters.PYTHON -> ast.modules.flatMap { - PythonEmitter(PackageName(packageName), EmitShared(emitShared)).emit( - it, - noLogger, - ) - } - + Emitters.WIRESPEC -> WirespecEmitter().emit(ast, noLogger) + Emitters.TYPESCRIPT -> TypeScriptIrEmitter().emit(ast, noLogger) + Emitters.JAVA -> JavaIrEmitter(PackageName(packageName), EmitShared(emitShared)).emit(ast, noLogger) + Emitters.KOTLIN -> KotlinIrEmitter(PackageName(packageName), EmitShared(emitShared)).emit(ast, noLogger) + Emitters.PYTHON -> PythonIrEmitter(PackageName(packageName), EmitShared(emitShared)).emit(ast, noLogger) Emitters.OPENAPI_V2 -> OpenAPIV2Emitter .emitSwaggerObject(ast.modules.flatMap { it.statements }, noLogger) .let(encode(OpenAPIV2Model.serializer())) .let { Emitted("openapi", it) } .let { nonEmptyListOf(it) } - Emitters.OPENAPI_V3 -> OpenAPIV3Emitter .emitOpenAPIObject(ast.modules.flatMap { it.statements }, null, noLogger) .let(encode(OpenAPIV3Model.serializer())) .let { Emitted("openapi", it) } .let { nonEmptyListOf(it) } - Emitters.AVRO -> ast.modules .map { ast -> AvroEmitter.emit(ast) } diff --git a/src/plugin/npm/src/jsMain/resources/wirespec-fetch.d.ts b/src/plugin/npm/src/jsMain/resources/wirespec-fetch.d.ts index 51259a8f7..8aa542c87 100644 --- a/src/plugin/npm/src/jsMain/resources/wirespec-fetch.d.ts +++ b/src/plugin/npm/src/jsMain/resources/wirespec-fetch.d.ts @@ -1,6 +1,9 @@ export type Method = "GET" | "PUT" | "POST" | "DELETE" | "OPTIONS" | "HEAD" | "PATCH" | "TRACE" export type RawRequest = { method: Method, path: string[], queries: Record, headers: Record, body?: string } +export type RawRequestIr = { method: string, path: string[], queries: Record, headers: Record, body: Uint8Array | undefined } export type RawResponse = { status: number, headers: Record, body?: string } +export type RawResponseIr = { statusCode:number, headers: Record, body: Uint8Array | undefined } export type HandleFetch = ( path:string, init?:RequestInit) => Promise -export declare function wirespecFetch (rawRequest:RawRequest, handle?: HandleFetch): Promise \ No newline at end of file +export declare function wirespecFetch (rawRequest:(RawRequest), handle?: HandleFetch): Promise +export declare function wirespecFetchIr (rawRequest:(RawRequestIr), handle?: HandleFetch): Promise \ No newline at end of file diff --git a/src/plugin/npm/src/jsMain/resources/wirespec-fetch.mjs b/src/plugin/npm/src/jsMain/resources/wirespec-fetch.mjs index b30a17723..6e3ad3a2c 100644 --- a/src/plugin/npm/src/jsMain/resources/wirespec-fetch.mjs +++ b/src/plugin/npm/src/jsMain/resources/wirespec-fetch.mjs @@ -31,3 +31,37 @@ export async function wirespecFetch(req, handler) { }; } + +export async function wirespecFetchIr(req, handler) { + const contentHeader = req.body ? { 'Content-Type': 'application/json' } : {}; + const body = req.body !== undefined ? req.body : undefined; + const query = Object.entries(req.queries) + .filter(([_, value]) => value !== undefined) + .flatMap(([key, value]) => { + if (value && typeof value === 'string' && value.startsWith('[') && value.endsWith(']')) { + const parsedValue = JSON.parse(value); + if (Array.isArray(parsedValue)) { + return parsedValue.map((item) => `${key}=${item}`); + } + } + return `${key}=${value}`; + }) + .join('&'); + const path = req.path + .map(segment => encodeURIComponent(segment)) + .join('/') + const url = `/${path}${query ? `?${query}` : ''}`; + const init = {method: req.method, body, headers: {...req.headers, ...contentHeader}} + const res = handler ? await handler(url, init) : await fetch(url, init) + const contentType = res.headers.get('Content-Type'); + const contentLength = res.headers.get('Content-Length'); + return { + statusCode: res.status, + headers: { + ...[...res.headers.entries()].reduce((acc, [key, value]) => ({...acc, [key]: [value]}), {}), + 'Content-Type': [contentType], + }, + body: contentLength !== '0' && contentType ? await res.text() : undefined, + }; + +} \ No newline at end of file diff --git a/src/plugin/npm/src/jsMain/resources/wirespec-serialization.d.ts b/src/plugin/npm/src/jsMain/resources/wirespec-serialization.d.ts index c653d2093..3ee4d50d1 100644 --- a/src/plugin/npm/src/jsMain/resources/wirespec-serialization.d.ts +++ b/src/plugin/npm/src/jsMain/resources/wirespec-serialization.d.ts @@ -1,3 +1,14 @@ -export type Serialization = { serialize: (typed: T) => string; deserialize: (raw: string | undefined) => T } +export type Type = string -export const wirespecSerialization: Serialization \ No newline at end of file +export interface Serialization { + serialize(typed: T): string; + deserialize(raw: string | undefined): T; + serializeBody(t: T, type: Type): Uint8Array; + deserializeBody(raw: Uint8Array, type: Type): T; + serializePath(t: T, type: Type): string; + deserializePath(raw: string, type: Type): T; + serializeParam(value: T, type: Type): string[]; + deserializeParam(values: string[], type: Type): T; +} + +export declare const wirespecSerialization: Serialization diff --git a/src/plugin/npm/src/jsMain/resources/wirespec-serialization.mjs b/src/plugin/npm/src/jsMain/resources/wirespec-serialization.mjs index 40d0f24b8..d9ae0f95a 100644 --- a/src/plugin/npm/src/jsMain/resources/wirespec-serialization.mjs +++ b/src/plugin/npm/src/jsMain/resources/wirespec-serialization.mjs @@ -1,3 +1,6 @@ +const encoder = new TextEncoder(); +const decoder = new TextDecoder(); + export const wirespecSerialization = { deserialize(raw) { if (raw === undefined) { @@ -14,4 +17,22 @@ export const wirespecSerialization = { return JSON.stringify(type); }, -}; \ No newline at end of file + serializeBody(t, _type) { + return encoder.encode(JSON.stringify(t)); + }, + deserializeBody(raw, _type) { + return JSON.parse(decoder.decode(raw)); + }, + serializePath(t, _type) { + return String(t); + }, + deserializePath(raw, _type) { + return raw; + }, + serializeParam(value, _type) { + return Array.isArray(value) ? value.map(String) : [String(value)]; + }, + deserializeParam(values, _type) { + return values[0]; + }, +}; diff --git a/src/site/docs/docs/intro/intro-ir.md b/src/site/docs/docs/intro/intro-ir.md new file mode 100644 index 000000000..c77f5146f --- /dev/null +++ b/src/site/docs/docs/intro/intro-ir.md @@ -0,0 +1,241 @@ +--- +title: IR Model +sidebar_position: 4 +--- + +# IR Model + +Wirespec compiles `.ws` source files into code for multiple target languages. After parsing Wirespec source into an AST, the compiler uses an **Intermediate Representation (IR)** — a language-neutral tree that sits between the parser and code generation. This page explains the IR pipeline and the model each Wirespec definition is converted into. + +## Pipeline Overview + +The IR pipeline transforms parsed Wirespec definitions into target-language source code through three stages: + +``` +Wirespec AST (definitions) + │ + ▼ + ┌─────────┐ + │ Convert │ AST → IR (language-neutral element tree) + └────┬─────┘ + │ + ▼ + ┌───────────┐ + │ Transform │ IR → IR (reshape for a specific target language) + └────┬──────┘ + │ + ▼ + ┌──────────┐ + │ Generate │ IR → Source code string + └──────────┘ +``` + +**Convert** maps each Wirespec definition to a language-neutral IR tree. It also produces a shared model containing the base interfaces all generated code depends on. + +**Transform** reshapes the IR to match the conventions of a specific target language (naming, types, patterns) without changing its structure. + +**Generate** walks the final IR tree and emits the target-language source code as a string. + +## Shared Model + +The converter produces a **shared model** — a single file called `Wirespec` that contains the base interfaces and types all generated code depends on. This file is emitted once per target language and provides the common vocabulary that generated types, endpoints, and channels build upon. + +The shared model is wrapped in a `Wirespec` namespace and defines: + +- **Core interfaces** — `Model`, `Enum`, `Refined`, `Endpoint`, and `Channel`. Every generated definition implements one of these. For example, all generated types implement `Wirespec.Model` which provides a `validate` function. +- **HTTP primitives** — A `Method` enum (`GET`, `POST`, etc.) and typed `Request` / `Response` interfaces that describe HTTP messages with path, queries, headers, and body. +- **Serialization contracts** — Serializer and deserializer interfaces for three layers: body (binary), path (string), and params (string lists). These combine into a unified `Serialization` interface. +- **Raw transport** — `RawRequest` and `RawResponse` structs for untyped HTTP messages, and a `Transportation` interface that users implement to plug in their HTTP client. + +Each language emitter transforms the shared model through the same Transform pipeline and may inject additional language-specific interfaces (for example, Java adds `Client` and `Server` interfaces for framework integration). + +## Type + +A Wirespec `type` is converted into a struct that implements `Wirespec.Model`. Each field in the Wirespec shape becomes a field in the struct, and a `validate` function is generated that returns a list of validation errors. + +```wirespec +type Address { + street: String, + number: Integer, + tags: String[] +} +``` + +``` +File("Address") { + Struct("Address") implements Wirespec.Model { + Field("street", String) + Field("number", Integer) + Field("tags", Array(String)) + Function("validate") → Array(String) + } +} +``` + +When a type has fields that reference other types or refined types, the `validate` function is extended with nested validation — calling `validate` on those fields and collecting the results with field path prefixes. + +## Enum + +A Wirespec `enum` is converted into an enum that extends `Wirespec.Enum`. Each entry becomes an enum member. + +```wirespec +enum Role { ADMIN, USER, GUEST } +``` + +``` +File("Role") { + Enum("Role") extends Wirespec.Enum { + Entry("ADMIN") + Entry("USER") + Entry("GUEST") + } +} +``` + +## Union + +A Wirespec `union` is converted into a union with members referencing each entry type. + +```wirespec +union Animal { Cat | Dog | Bird } +``` + +``` +File("Animal") { + Union("Animal") { + Member("Cat") + Member("Dog") + Member("Bird") + } +} +``` + +## Refined + +A Wirespec `refined` type is converted into a struct that implements `Wirespec.Refined`. It holds the underlying value and a `validate` function that checks the constraint (regex or bounds). + +```wirespec +refined Email /^[^@]+@[^@]+$/g +``` + +``` +File("Email") { + Struct("Email") implements Wirespec.Refined { + Field("value", String) + Function("validate") → Boolean // checks regex + } +} +``` + +## Channel + +A Wirespec `channel` is converted into an interface that extends `Wirespec.Channel` with an `invoke` function accepting the message type. + +```wirespec +channel OrderEvents -> OrderEvent +``` + +``` +File("OrderEvents") { + Interface("OrderEvents") extends Wirespec.Channel { + Function("invoke")(message: OrderEvent) → Unit + } +} +``` + +## Endpoint + +A Wirespec `endpoint` produces the most complex IR structure. It is wrapped in a namespace and contains everything needed for typed, serializable HTTP communication. + +```wirespec +endpoint GetTodos GET /todos?{done: Boolean} -> { + 200 -> Todo[] +} +``` + +``` +File("GetTodos") { + Namespace("GetTodos") extends Wirespec.Endpoint { + Struct("Path") implements Wirespec.Path {} + Struct("Queries") implements Wirespec.Queries { + Field("done", Boolean) + } + Struct("RequestHeaders") implements Wirespec.Request.Headers {} + Struct("Request") implements Wirespec.Request { + Field("path", Path) + Field("method", Wirespec.Method) + Field("queries", Queries) + Field("headers", RequestHeaders) + Field("body", Unit) + } + Union("Response") extends Wirespec.Response { + Member("Response2XX") + Member("ResponseTodo") + } + Union("Response2XX") extends Response { + Member("Response200") + } + Struct("Response200") { + Field("status", Integer) + Field("headers", Headers) + Field("body", Array(Todo)) + } + Function("toRawRequest")(serialization, request) → Wirespec.RawRequest + Function("fromRawRequest")(serialization, request) → Request + Function("toRawResponse")(serialization, response) → Wirespec.RawResponse + Function("fromRawResponse")(serialization, response) → Response + Interface("Handler") { + Function("getTodos")(request: Request) → Response<*> + } + } +} +``` + +The key parts of an endpoint IR model are: + +- **Path, Queries, RequestHeaders** — structs for the typed request components, implementing the corresponding shared interfaces +- **Request** — a struct implementing `Wirespec.Request` with a constructor that assembles path, method, queries, headers, and body +- **Response union hierarchy** — a `Response` union with intermediate unions grouped by status prefix (`Response2XX`) and content type (`ResponseTodo`), with concrete response structs for each status code +- **Conversion functions** — `toRawRequest`, `fromRawRequest`, `toRawResponse`, and `fromRawResponse` that bridge between the typed request/response structs and `Wirespec.RawRequest` / `Wirespec.RawResponse` using the serialization interfaces +- **Handler** — an interface with a function matching the endpoint name, representing the server-side handler contract + +## Transform + +The transform stage is where language-specific adaptation happens. Each language emitter reshapes the generic IR into a form that generates idiomatic code for its target. + +Transforms work by walking the IR tree and replacing nodes. A **Transformer** provides override points for each node kind (types, elements, statements, expressions, fields, parameters). Each override defaults to a recursive traversal — you only override what you want to change. + +Language emitters use a high-level `transform { }` block that chains multiple transformations. Common operations include renaming types, replacing type patterns (e.g., mapping `Array` to a language-specific list type), transforming fields conditionally, and injecting elements before or after containers. + +Because transforms recurse automatically, a single type rename propagates through every struct field, function parameter, return type, and nested expression in the tree. + +## Example: Transform and Generate + +Given the IR for the `Address` type shown above, a Java emitter transforms it to match Java conventions: + +``` +File("Address") { + Struct("Address") implements Wirespec.Model { + Field("street", Custom("String")) + Field("number", Custom("int")) + Field("tags", Custom("java.util.List", [Custom("String")])) + Function("validate") → Custom("java.util.List", [Custom("String")]) + } +} +``` + +The generator then walks this transformed IR and emits: + +```java +public record Address( + String street, + int number, + java.util.List tags +) implements Wirespec.Model { + public java.util.List validate() { + return java.util.Collections.emptyList(); + } +} +``` + +The same IR — after different transforms — produces equivalent code in Kotlin, TypeScript, Python, or Rust. diff --git a/src/verify/build.gradle.kts b/src/verify/build.gradle.kts new file mode 100644 index 000000000..e08119ed6 --- /dev/null +++ b/src/verify/build.gradle.kts @@ -0,0 +1,36 @@ +plugins { + alias(libs.plugins.kotlin.jvm) +} + +group = "${libs.versions.group.id.get()}.verify" +version = System.getenv(libs.versions.from.env.get()) ?: libs.versions.default.get() + +repositories { + mavenCentral() + mavenLocal() +} + +kotlin { + jvmToolchain(libs.versions.java.get().toInt()) +} + +tasks.test { + useJUnitPlatform() + systemProperty("buildDir", layout.buildDirectory.get().asFile.absolutePath) + onlyIf { project.hasProperty("verify") } +} + +dependencies { + implementation(project(":src:compiler:core")) + implementation(project(":src:compiler:test")) + implementation(project(":src:compiler:emitters:java")) + implementation(project(":src:compiler:emitters:kotlin")) + implementation(project(":src:compiler:emitters:python")) + implementation(project(":src:compiler:emitters:typescript")) + implementation(project(":src:compiler:emitters:rust")) + implementation(project(":src:compiler:emitters:scala")) + implementation(libs.bundles.kotest) + implementation(libs.testcontainers) + testImplementation(libs.kotlin.test) + testImplementation(libs.kotest.runner.junit5) +} diff --git a/src/verify/src/main/kotlin/community/flock/wirespec/verify/VerifyImage.kt b/src/verify/src/main/kotlin/community/flock/wirespec/verify/VerifyImage.kt new file mode 100644 index 000000000..96738bbdb --- /dev/null +++ b/src/verify/src/main/kotlin/community/flock/wirespec/verify/VerifyImage.kt @@ -0,0 +1,85 @@ +package community.flock.wirespec.verify + +import org.testcontainers.images.builder.ImageFromDockerfile + +enum class VerifyImage { + KOTLIN_1 { + override val image by lazy { + val version = "1.9.24" + ImageFromDockerfile("wirespec-kotlin-verify", false) + .withDockerfileFromBuilder { builder -> + builder + .from("eclipse-temurin:17-jdk") + .run("apt-get update -qq && apt-get install -y -qq wget unzip > /dev/null 2>&1") + .run( + "wget -q https://github.com/JetBrains/kotlin/releases/download/v$version/kotlin-compiler-$version.zip -O /tmp/kotlin.zip && " + + "unzip -q /tmp/kotlin.zip -d /opt && " + + "rm /tmp/kotlin.zip" + ) + .build() + } + .get() + } + }, + KOTLIN_2 { + override val image by lazy { + val version = "2.0.21" + ImageFromDockerfile("wirespec-kotlin-verify", false) + .withDockerfileFromBuilder { builder -> + builder + .from("eclipse-temurin:17-jdk") + .run("apt-get update -qq && apt-get install -y -qq wget unzip > /dev/null 2>&1") + .run( + "wget -q https://github.com/JetBrains/kotlin/releases/download/v$version/kotlin-compiler-$version.zip -O /tmp/kotlin.zip && " + + "unzip -q /tmp/kotlin.zip -d /opt && " + + "rm /tmp/kotlin.zip" + ) + .build() + } + .get() + } + }, + PYTHON { + override val image by lazy { + ImageFromDockerfile("wirespec-python-verify", false) + .withDockerfileFromBuilder { builder -> + builder + .from("python:3.12-slim") + .run("pip install mypy") + .build() + } + .get() + } + }, + RUST { + override val image by lazy { + ImageFromDockerfile("wirespec-rust-verify", false) + .withDockerfileFromBuilder { builder -> + builder + .from("rust:1.83-slim") + .run("cargo init /app && cd /app && cargo add regex serde --features serde/derive && cargo add serde_json && cargo add pollster") + .build() + } + .get() + } + }, + SCALA { + override val image by lazy { + ImageFromDockerfile("wirespec-scala-verify", false) + .withDockerfileFromBuilder { builder -> + builder + .from("eclipse-temurin:17-jdk") + .run("apt-get update -qq && apt-get install -y -qq curl > /dev/null 2>&1") + .run( + "curl -sSLf https://scala-cli.virtuslab.org/get | sh && " + + "ln -s /root/.cache/scalacli/local-repo/bin/scala-cli/scala-cli /usr/local/bin/scala-cli && " + + "scala-cli version" + ) + .build() + } + .get() + } + }; + + abstract val image: String +} diff --git a/src/verify/src/main/kotlin/community/flock/wirespec/verify/VerifySerialization.kt b/src/verify/src/main/kotlin/community/flock/wirespec/verify/VerifySerialization.kt new file mode 100644 index 000000000..762c25a09 --- /dev/null +++ b/src/verify/src/main/kotlin/community/flock/wirespec/verify/VerifySerialization.kt @@ -0,0 +1,239 @@ +package community.flock.wirespec.verify + +import community.flock.wirespec.compiler.core.parse.ast.Type as AstType +import community.flock.wirespec.compiler.test.Fixture +import community.flock.wirespec.emitters.java.JavaIrEmitter +import community.flock.wirespec.emitters.kotlin.KotlinIrEmitter +import community.flock.wirespec.emitters.python.PythonIrEmitter +import community.flock.wirespec.emitters.rust.RustIrEmitter +import community.flock.wirespec.emitters.scala.ScalaIrEmitter +import community.flock.wirespec.emitters.typescript.TypeScriptIrEmitter + +internal fun serializationCode(lang: Language, fixture: Fixture? = null): String = when (lang.emitter) { + is JavaIrEmitter -> """ + |static Wirespec.Serialization serialization = new Wirespec.Serialization() { + | private final java.util.Map store = new java.util.HashMap<>(); + | private String randomKey() { return java.util.UUID.randomUUID().toString(); } + | @Override public byte[] serializeBody(T t, java.lang.reflect.Type type) { String key = randomKey(); store.put(key, t); return key.getBytes(); } + | @Override public T deserializeBody(byte[] raw, java.lang.reflect.Type type) { return (T) store.get(new String(raw)); } + | @Override public String serializePath(T t, java.lang.reflect.Type type) { return t.toString(); } + | @Override public T deserializePath(String raw, java.lang.reflect.Type type) { return (T) raw; } + | @Override public java.util.List serializeParam(T value, java.lang.reflect.Type type) { return java.util.List.of(value.toString()); } + | @Override public T deserializeParam(java.util.List values, java.lang.reflect.Type type) { + | String v = values.get(0); + | Class cls = type instanceof Class ? (Class) type : (Class) ((java.lang.reflect.ParameterizedType) type).getRawType(); + | if (cls == Boolean.class || cls == boolean.class) return (T) Boolean.valueOf(v); + | if (cls == Integer.class || cls == int.class) return (T) Integer.valueOf(v); + | if (cls == String.class) return (T) v; + | try { return (T) cls.getConstructors()[0].newInstance(v); } catch (Exception e) { return (T) v; } + | } + |}; + """.trimMargin() + + is KotlinIrEmitter -> """ + |val serialization = object : Wirespec.Serialization { + | private val store = mutableMapOf() + | private fun randomKey() = java.util.UUID.randomUUID().toString() + | override fun serializeBody(t: T, kType: kotlin.reflect.KType): ByteArray { val key = randomKey(); store[key] = t; return key.toByteArray() } + | override fun deserializeBody(raw: ByteArray, kType: kotlin.reflect.KType): T = store[String(raw)] as T + | override fun serializePath(t: T, kType: kotlin.reflect.KType): String = t.toString() + | override fun deserializePath(raw: String, kType: kotlin.reflect.KType): T = raw as T + | override fun serializeParam(value: T, kType: kotlin.reflect.KType): List = listOf(value.toString()) + | override fun deserializeParam(values: List, kType: kotlin.reflect.KType): T { + | val v = values[0] + | val cls = kType.classifier as? kotlin.reflect.KClass<*> + | if (cls == Boolean::class) return v.toBoolean() as T + | if (cls == Int::class) return v.toInt() as T + | if (cls == String::class) return v as T + | return cls!!.constructors.first().call(v) as T + | } + |} + """.trimMargin() + + is TypeScriptIrEmitter -> """ + |const store: Record = {}; + |let counter = 0; + |const serialization: Wirespec.Serialization = { + | serializeBody: (t: T, type: Wirespec.Type): Uint8Array => { const key = String(counter++); store[key] = t; return new TextEncoder().encode(key); }, + | deserializeBody: (raw: Uint8Array, type: Wirespec.Type): T => store[new TextDecoder().decode(raw)] as T, + | serializePath: (t: T, type: Wirespec.Type): string => String(t), + | deserializePath: (raw: string, type: Wirespec.Type): T => raw as unknown as T, + | serializeParam: (value: T, type: Wirespec.Type): string[] => [String(value)], + | deserializeParam: (values: string[], type: Wirespec.Type): T => { + | const v = values[0]; + | if (type === "boolean") return (v === "true") as unknown as T; + | if (type === "number") return Number(v) as unknown as T; + | if (type === "string") return v as unknown as T; + | return ({ iss: v }) as unknown as T; + | }, + |} + """.trimMargin() + + is PythonIrEmitter -> """ + |class TestSerialization(Wirespec.Serialization): + | def __init__(self): + | self.store = {} + | self.counter = 0 + | def serializeBody(self, t, type): + | key = str(self.counter) + | self.counter += 1 + | self.store[key] = t + | return key.encode() + | def deserializeBody(self, raw, type): + | return self.store[raw.decode()] + | def serializePath(self, t, type): + | return str(t) + | def deserializePath(self, raw, type): + | return raw + | def serializeParam(self, value, type): + | return [str(value)] + | def deserializeParam(self, values, type): + | v = values[0] + | if type == bool: return v.lower() == 'true' + | if type == int: return int(v) + | if type == str: return v + | return type(v) + |serialization = TestSerialization() + """.trimMargin() + + is RustIrEmitter -> rustSerializationCode(fixture) + + is ScalaIrEmitter -> """ + |val serialization = new Wirespec.Serialization { + | private val store = scala.collection.mutable.Map[String, Any]() + | private def randomKey(): String = java.util.UUID.randomUUID().toString + | override def serializeBody[T](t: T, classTag: scala.reflect.ClassTag[?]): Array[Byte] = { val key = randomKey(); store(key) = t; key.getBytes } + | override def deserializeBody[T](raw: Array[Byte], classTag: scala.reflect.ClassTag[?]): T = store(new String(raw)).asInstanceOf[T] + | override def serializePath[T](t: T, classTag: scala.reflect.ClassTag[?]): String = t.toString + | override def deserializePath[T](raw: String, classTag: scala.reflect.ClassTag[?]): T = raw.asInstanceOf[T] + | override def serializeParam[T](value: T, classTag: scala.reflect.ClassTag[?]): List[String] = List(value.toString) + | override def deserializeParam[T](values: List[String], classTag: scala.reflect.ClassTag[?]): T = { + | val v = values.head + | val cls = classTag.runtimeClass + | if (cls == classOf[Boolean]) java.lang.Boolean.parseBoolean(v).asInstanceOf[T] + | else if (cls == classOf[Int]) v.toInt.asInstanceOf[T] + | else if (cls == classOf[String]) v.asInstanceOf[T] + | else cls.getConstructors.head.newInstance(v).asInstanceOf[T] + | } + |} + """.trimMargin() + + else -> error("Unknown emitter: ${lang.emitter::class.simpleName}") +} + +private fun rustSerializationCode(fixture: Fixture?): String { + val types = fixture?.definitions()?.filterIsInstance()?.associate { + it.identifier.value to it.shape.value.map { f -> f.identifier.value to f.reference } + } ?: emptyMap() + + fun serializeBranch(typeName: String, fields: List>): String { + val jsonFields = fields.joinToString(", ") { (name, _) -> "\"$name\": v.$name" } + return "if let Some(v) = any.downcast_ref::<$typeName>() {\n | serde_json::to_vec(&serde_json::json!({$jsonFields})).unwrap()" + } + + fun serializeVecBranch(typeName: String, fields: List>): String { + val jsonFields = fields.joinToString(", ") { (name, _) -> "\"$name\": td.$name" } + return "if let Some(v) = any.downcast_ref::>() {\n | serde_json::to_vec(&v.iter().map(|td| serde_json::json!({$jsonFields})).collect::>()).unwrap()" + } + + fun deserializeBranch(typeName: String, fields: List>): String { + val fieldInits = fields.joinToString(", ") { (name, ref) -> + val accessor = when { + ref.toString().contains("Boolean") -> "v[\"$name\"].as_bool().unwrap_or_default()" + ref.toString().contains("Integer") -> "v[\"$name\"].as_i64().unwrap_or_default() as i64" + else -> "v[\"$name\"].as_str().unwrap_or_default().to_string()" + } + "$name: $accessor" + } + return "if _type == std::any::TypeId::of::<$typeName>() {\n | let v: serde_json::Value = serde_json::from_slice(raw).unwrap();\n | Box::new($typeName { $fieldInits })" + } + + fun deserializeVecBranch(typeName: String, fields: List>): String { + val fieldInits = fields.joinToString(", ") { (name, ref) -> + val accessor = when { + ref.toString().contains("Boolean") -> "v[\"$name\"].as_bool().unwrap_or_default()" + ref.toString().contains("Integer") -> "v[\"$name\"].as_i64().unwrap_or_default() as i64" + else -> "v[\"$name\"].as_str().unwrap_or_default().to_string()" + } + "$name: $accessor" + } + return "if _type == std::any::TypeId::of::>() {\n | let values: Vec = serde_json::from_slice(raw).unwrap();\n | let items: Vec<$typeName> = values.iter().map(|v| $typeName { $fieldInits }).collect();\n | Box::new(items)" + } + + // Build body serializer branches + val serBranches = types.flatMap { (name, fields) -> + listOf(serializeVecBranch(name, fields), serializeBranch(name, fields)) + } + val serBody = if (serBranches.isEmpty()) { + "panic!(\"Unsupported body type for serialization: {:?}\", _type)" + } else { + serBranches.joinToString("\n | } else ") { it } + "\n | } else {\n | panic!(\"Unsupported body type for serialization: {:?}\", _type)\n | }" + } + + // Build body deserializer branches + val deserBranches = types.flatMap { (name, fields) -> + listOf(deserializeVecBranch(name, fields), deserializeBranch(name, fields)) + } + val deserBody = if (deserBranches.isEmpty()) { + "panic!(\"Unsupported body type for deserialization: {:?}\", _type)" + } else { + "let boxed: Box = " + deserBranches.joinToString("\n | } else ") { it } + "\n | } else {\n | panic!(\"Unsupported body type for deserialization: {:?}\", _type)\n | };\n | *boxed.downcast::().unwrap()" + } + + // Build param deserializer branches for custom types + val paramBranches = types.filter { (_, fields) -> fields.size == 1 }.map { (name, fields) -> + val fieldName = fields.first().first + "} else if _type == std::any::TypeId::of::<$name>() {\n | Box::new($name { $fieldName: values.first().cloned().unwrap_or_default() })" + } + val paramExtra = paramBranches.joinToString("\n | ") + + return """ + |struct MockSer; + |impl BodySerializer for MockSer { + | fn serialize_body(&self, t: &T, _type: std::any::TypeId) -> Vec { + | let any: &dyn std::any::Any = t; + | $serBody + | } + |} + |impl BodyDeserializer for MockSer { + | fn deserialize_body(&self, raw: &[u8], _type: std::any::TypeId) -> T { + | $deserBody + | } + |} + |impl PathSerializer for MockSer { + | fn serialize_path(&self, t: &T, _type: std::any::TypeId) -> String { t.to_string() } + |} + |impl PathDeserializer for MockSer { + | fn deserialize_path(&self, raw: &str, _type: std::any::TypeId) -> T where T::Err: std::fmt::Debug { raw.parse().unwrap() } + |} + |impl ParamSerializer for MockSer { + | fn serialize_param(&self, value: &T, _type: std::any::TypeId) -> Vec { + | let any: &dyn std::any::Any = value; + | if let Some(s) = any.downcast_ref::() { vec![s.clone()] } + | else if let Some(b) = any.downcast_ref::() { vec![b.to_string()] } + | else { panic!("Unsupported param type for serialization: {:?}", _type) } + | } + |} + |impl ParamDeserializer for MockSer { + | fn deserialize_param(&self, values: &[String], _type: std::any::TypeId) -> T { + | let boxed: Box = if _type == std::any::TypeId::of::() { + | Box::new(values.first().cloned().unwrap_or_default()) + | } else if _type == std::any::TypeId::of::() { + | Box::new(values.first().map(|v| v == "true").unwrap_or(false)) + | $paramExtra + | } else { + | panic!("Unsupported param type for deserialization: {:?}", _type) + | }; + | *boxed.downcast::().unwrap() + | } + |} + |impl BodySerialization for MockSer {} + |impl PathSerialization for MockSer {} + |impl ParamSerialization for MockSer {} + |impl Serializer for MockSer {} + |impl Deserializer for MockSer {} + |impl Serialization for MockSer {} + |#[allow(non_upper_case_globals)] + |static serialization: MockSer = MockSer; + """.trimMargin() +} diff --git a/src/verify/src/main/kotlin/community/flock/wirespec/verify/VerifyTransportation.kt b/src/verify/src/main/kotlin/community/flock/wirespec/verify/VerifyTransportation.kt new file mode 100644 index 000000000..654bfb774 --- /dev/null +++ b/src/verify/src/main/kotlin/community/flock/wirespec/verify/VerifyTransportation.kt @@ -0,0 +1,89 @@ +package community.flock.wirespec.verify + +import community.flock.wirespec.emitters.java.JavaIrEmitter +import community.flock.wirespec.emitters.kotlin.KotlinIrEmitter +import community.flock.wirespec.emitters.python.PythonIrEmitter +import community.flock.wirespec.emitters.rust.RustIrEmitter +import community.flock.wirespec.emitters.scala.ScalaIrEmitter +import community.flock.wirespec.emitters.typescript.TypeScriptIrEmitter + +internal fun transportationCode(lang: Language): String = when (lang.emitter) { + is JavaIrEmitter -> """ + |static Wirespec.Transportation transportation = (Wirespec.RawRequest rawRequest) -> { + | assert rawRequest.method().equals("GET") : "Method should be GET"; + | assert rawRequest.path().get(0).equals("todos") : "Path should start with todos"; + | TodoDto todo = new TodoDto("test"); + | byte[] body = serialization.serializeBody(java.util.List.of(todo), Wirespec.getType(TodoDto.class, java.util.List.class)); + | return java.util.concurrent.CompletableFuture.completedFuture(new Wirespec.RawResponse(200, java.util.Collections.emptyMap(), java.util.Optional.of(body))); + |}; + """.trimMargin() + + is KotlinIrEmitter -> """ + |val transportation = object : Wirespec.Transportation { + | override suspend fun transport(request: Wirespec.RawRequest): Wirespec.RawResponse { + | assert(request.method == "GET") { "Method should be GET" } + | assert(request.path[0] == "todos") { "Path should start with todos" } + | val todo = TodoDto(description = "test") + | val body = serialization.serializeBody(listOf(todo), kotlin.reflect.typeOf>()) + | return Wirespec.RawResponse(statusCode = 200, headers = emptyMap(), body = body) + | } + |} + """.trimMargin() + + is TypeScriptIrEmitter -> """ + |const transportation: Wirespec.Transportation = { + | transport: async (request: Wirespec.RawRequest): Promise => { + | if (request.method !== "GET") throw new Error("Method should be GET"); + | if (request.path[0] !== "todos") throw new Error("Path should start with todos"); + | const todo: TodoDto = { description: "test" }; + | const body = serialization.serializeBody([todo], "TodoDto"); + | return { statusCode: 200, headers: {}, body }; + | } + |} + """.trimMargin() + + is PythonIrEmitter -> """ + |class TestTransportation(Wirespec.Transportation): + | def __init__(self, serialization): + | self.serialization = serialization + | async def transport(self, request): + | assert request.method == "GET", "Method should be GET" + | assert request.path[0] == "todos", "Path should start with todos" + | todo = TodoDto(description="test") + | body = self.serialization.serializeBody([todo], "List[TodoDto]") + | return Wirespec.RawResponse(statusCode=200, headers={}, body=body) + |transportation = TestTransportation(serialization) + """.trimMargin() + + is RustIrEmitter -> """ + |use generated::wirespec::Transportation; + |struct MockTransport<'a, S: Serialization> { + | serialization: &'a S, + |} + |impl<'a, S: Serialization> Transportation for MockTransport<'a, S> { + | async fn transport(&self, request: &RawRequest) -> RawResponse { + | assert_eq!(request.method, "GET", "Method should be GET"); + | assert_eq!(request.path[0], "todos", "Path should start with todos"); + | let todo = TodoDto { description: "test".to_string() }; + | let body = self.serialization.serialize_body(&vec![todo], std::any::TypeId::of::>()); + | RawResponse { status_code: 200, headers: std::collections::HashMap::new(), body: Some(body) } + | } + |} + |#[allow(non_upper_case_globals)] + |static transportation: MockTransport<'static, MockSer> = MockTransport { serialization: &serialization }; + """.trimMargin() + + is ScalaIrEmitter -> """ + |val transportation = new Wirespec.Transportation { + | override def transport(request: Wirespec.RawRequest): Wirespec.RawResponse = { + | assert(request.method == "GET", "Method should be GET") + | assert(request.path.head == "todos", "Path should start with todos") + | val todo = TodoDto(description = "test") + | val body = serialization.serializeBody(List(todo), scala.reflect.classTag[List[TodoDto]]) + | Wirespec.RawResponse(statusCode = 200, headers = Map.empty, body = Some(body)) + | } + |} + """.trimMargin() + + else -> error("Unknown emitter: ${lang.emitter::class.simpleName}") +} diff --git a/src/verify/src/main/kotlin/community/flock/wirespec/verify/VerifyUtil.kt b/src/verify/src/main/kotlin/community/flock/wirespec/verify/VerifyUtil.kt new file mode 100644 index 000000000..43d77c327 --- /dev/null +++ b/src/verify/src/main/kotlin/community/flock/wirespec/verify/VerifyUtil.kt @@ -0,0 +1,462 @@ +package community.flock.wirespec.verify + +import arrow.core.nonEmptyListOf +import arrow.core.nonEmptySetOf +import community.flock.wirespec.compiler.core.CompilationContext +import community.flock.wirespec.compiler.core.FileUri +import community.flock.wirespec.compiler.core.ModuleContent +import community.flock.wirespec.compiler.core.ParseContext +import community.flock.wirespec.compiler.core.WirespecSpec +import community.flock.wirespec.compiler.core.compile +import community.flock.wirespec.compiler.core.emit.EmitShared +import community.flock.wirespec.compiler.core.emit.Emitter +import community.flock.wirespec.compiler.core.parse +import community.flock.wirespec.compiler.core.emit.importReferences +import community.flock.wirespec.compiler.core.parse.ast.Definition +import community.flock.wirespec.compiler.core.parse.ast.Endpoint +import community.flock.wirespec.compiler.core.parse.ast.Refined +import community.flock.wirespec.ir.core.ContainerBuilder +import community.flock.wirespec.compiler.test.Fixture +import community.flock.wirespec.compiler.utils.NoLogger +import community.flock.wirespec.emitters.java.JavaIrEmitter +import community.flock.wirespec.emitters.kotlin.KotlinIrEmitter +import community.flock.wirespec.emitters.python.PythonIrEmitter +import community.flock.wirespec.emitters.rust.RustIrEmitter +import community.flock.wirespec.emitters.scala.ScalaIrEmitter +import community.flock.wirespec.emitters.typescript.TypeScriptIrEmitter +import community.flock.wirespec.ir.core.AssertStatement +import community.flock.wirespec.ir.core.Assignment +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Expression +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.Import +import community.flock.wirespec.ir.core.Main +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.transformChildren +import community.flock.wirespec.ir.core.transformer +import community.flock.wirespec.ir.emit.IrEmitter +import community.flock.wirespec.ir.generator.generateJava +import community.flock.wirespec.ir.generator.generateKotlin +import community.flock.wirespec.ir.generator.generatePython +import community.flock.wirespec.ir.generator.generateRust +import community.flock.wirespec.ir.generator.generateScala +import community.flock.wirespec.ir.generator.generateTypeScript +import io.kotest.matchers.shouldBe +import org.testcontainers.containers.BindMode +import org.testcontainers.containers.GenericContainer +import java.io.File +import community.flock.wirespec.ir.core.File as AstFile + +val languages = mapOf( + "java-17" to Language(JavaIrEmitter(emitShared = EmitShared(true)), { "eclipse-temurin:17-jdk" }), + "java-21" to Language(JavaIrEmitter(emitShared = EmitShared(true)), { "eclipse-temurin:21-jdk" }), + "kotlin-1" to Language(KotlinIrEmitter(emitShared = EmitShared(true)), { VerifyImage.KOTLIN_1.image }), + "kotlin-2" to Language(KotlinIrEmitter(emitShared = EmitShared(true)), { VerifyImage.KOTLIN_2.image }), + "python" to Language(PythonIrEmitter(emitShared = EmitShared(true)), { VerifyImage.PYTHON.image }), + "typescript" to Language(TypeScriptIrEmitter(), { "node:20-slim" }), + "rust" to Language(RustIrEmitter(emitShared = EmitShared(true)), { VerifyImage.RUST.image }), + "scala" to Language(ScalaIrEmitter(emitShared = EmitShared(true)), { VerifyImage.SCALA.image }), +).onEach { (name, lang) -> lang.name = name } + +class Language( + val emitter: IrEmitter, + val image: () -> String, +) { + lateinit var name: String + override fun toString() = name + lateinit var container: GenericContainer<*> + private lateinit var outputDir: File + private lateinit var fixture: Fixture + + private val tsExtraFiles: (File) -> Unit = { outputDir -> + File(outputDir, "tsconfig.json").writeText( + """ + { + "compilerOptions": { + "strict": true, + "noEmit": true, + "skipLibCheck": true, + "target": "ES2019", + "module": "CommonJS", + "moduleResolution": "node10", + "ignoreDeprecations": "6.0" + }, + "include": ["./**/*.ts"] + } + """.trimIndent() + ) + } + + fun generate(file: AstFile, outputDir: File) { + val name = file.name.pascalCase() + val transformed = emitter.transformTestFile(file) + val (fileName, content) = when (emitter) { + is JavaIrEmitter -> "${name}.java" to transformed.generateJava() + is KotlinIrEmitter -> "${name}.kt" to transformed.generateKotlin() + is PythonIrEmitter -> "${name}.py" to transformed.generatePython() + is RustIrEmitter -> "${name}.rs" to transformed.generateRust() + is ScalaIrEmitter -> "${name}.scala" to transformed.generateScala() + is TypeScriptIrEmitter -> "${name}.ts" to transformed.generateTypeScript() + else -> error("Unknown language: $name") + } + outputDir.resolve(fileName).writeText(content) + } + + fun start(name: String, fixture: Fixture, extraFiles: (File) -> Unit = {}) { + this.fixture = fixture + val (cont, dir) = compileAndVerify( + name = name, + emitter = emitter, + fixture = fixture, + language = this.name.lowercase(), + image = image(), + extraFiles = { dir -> + tsExtraFiles(dir) + extraFiles(dir) + }, + ) + container = cont + outputDir = dir + } + + fun compile() { + val verifyCommand = when (emitter) { + is JavaIrEmitter -> "find /app/gen -name '*.java' | xargs javac -d /tmp/out" + is KotlinIrEmitter -> "/opt/kotlinc/bin/kotlinc -nowarn -include-runtime /app/gen/ -d /tmp/run.jar" + is PythonIrEmitter -> "python -m mypy --disable-error-code=empty-body --disable-error-code=arg-type /app/gen/" + is RustIrEmitter -> "rm -rf /app/src/generated && cp -r /app/gen/community/flock/wirespec/generated /app/src/generated && printf 'mod generated;\\nfn main() {}\\n' > /app/src/main.rs && cd /app && cargo build" + is ScalaIrEmitter -> "find /app/gen -name '*.scala' | xargs scala-cli compile --server=false" + is TypeScriptIrEmitter -> "npm install -g typescript && cd /app/gen && tsc --noEmit" + else -> error("Unknown language: ${emitter::class.simpleName}") + } + exec(verifyCommand) + } + + fun run(testFile: AstFile) { + val resolved = if (emitter is TypeScriptIrEmitter) testFile.adaptForTypeScript(fixture) else testFile + generate(resolved, outputDir) + compile() + val fileName = testFile.name.pascalCase() + val runCommand: String = when (emitter) { + is JavaIrEmitter -> "java -ea -cp /tmp/out $fileName" + is KotlinIrEmitter -> "java -ea -cp /tmp/run.jar ${fileName}Kt" + is PythonIrEmitter -> "cd /app/gen && python -O ${fileName}.py" + is RustIrEmitter -> { + // Build use statements from the test file's imports + val imports = resolved.elements.filterIsInstance() + val hasEndpointImports = imports.any { it.path.contains("endpoint") } + val useStatements = imports.flatMap { imp -> + val typeName = imp.type.name + val snakeName = Name.of(typeName).snakeCase() + when { + imp.path.contains("endpoint") -> listOf( + "use generated::endpoint::${snakeName}::*;", + "use generated::endpoint::${snakeName}::${typeName}::*;", + ) + imp.path.contains("model") -> listOf("use generated::model::${snakeName}::${typeName};") + imp.path.contains("client") -> listOf("use generated::client::${snakeName}::${typeName};") + else -> listOf("use generated::${snakeName}::${typeName};") + } + }.joinToString("\n") + // Import specific wirespec traits to avoid name clashes with endpoint types (Request, Response, etc.) + val wirespecUse = if (hasEndpointImports) { + "use generated::wirespec::{BodySerializer, BodyDeserializer, PathSerializer, PathDeserializer, ParamSerializer, ParamDeserializer, BodySerialization, PathSerialization, ParamSerialization, Serializer, Deserializer, Serialization, RawRequest, RawResponse, Method};" + } else "" + // Generate the test file content (which contains fn main()) + val transformedRust = emitter.transformTestFile(resolved) + val rustContent = transformedRust.generateRust() + // Filter out the import lines that the generator produced (use super::...) + val filteredContent = rustContent.lines() + .filter { !it.startsWith("use super::") } + .joinToString("\n") + val mainRs = "mod generated;\n$useStatements\n$wirespecUse\n\n$filteredContent" + container.execInContainer("sh", "-c", "cat > /app/src/main.rs << 'RUSTEOF'\n$mainRs\nRUSTEOF") + "cd /app && cargo build && cargo run" + } + + is ScalaIrEmitter -> "find /app/gen -name '*.scala' | xargs scala-cli run --server=false --main-class ${fileName}" + is TypeScriptIrEmitter -> "npm install -g tsx && cd /app/gen && tsx ${fileName}.ts" + else -> error("Unknown language: ${name}") + } + exec(runCommand) + } + + fun exec(command: String) { + val result = container.execInContainer("sh", "-c", command) + if (result.stdout.isNotBlank()) { + println("=== stdout ===") + println(result.stdout) + } + if (result.stderr.isNotBlank()) { + println("=== stderr ===") + println(result.stderr) + } + if (result.exitCode != 0) { + println("=== exit code ===") + println(result.exitCode) + } + result.exitCode shouldBe 0 + } +} + +fun compileAndVerify( + name: String, + emitter: Emitter, + fixture: Fixture, + language: String, + image: String, + extraFiles: (File) -> Unit = {}, +): Pair, File> { + val emitted = object : CompilationContext, NoLogger { + override val spec = WirespecSpec + override val emitters = nonEmptySetOf(emitter) + }.compile(nonEmptyListOf(ModuleContent(FileUri("N/A"), fixture.source))) + + val files = emitted.fold( + { error -> throw AssertionError("Compilation failed: $error") }, + { it } + ) + + val outputDir = File(System.getProperty("buildDir"), "generated/$name/$language") + outputDir.deleteRecursively() + outputDir.mkdirs() + print(outputDir.absolutePath) + files.forEach { file -> + val target = File(outputDir, file.file) + target.parentFile.mkdirs() + target.writeText(file.result) + } + + extraFiles(outputDir) + + val container = GenericContainer(image) + .withFileSystemBind(outputDir.absolutePath, "/app/gen", BindMode.READ_ONLY) + .withCommand("tail", "-f", "/dev/null") + + container.start() + + return container to outputDir +} + +fun Fixture.refinedTypeNames(): Set { + val ctx = object : ParseContext, NoLogger { + override val spec = WirespecSpec + } + val ast = ctx.parse(nonEmptyListOf(ModuleContent(FileUri("N/A"), source))) + .getOrNull() ?: return emptySet() + return ast.modules.toList() + .flatMap { it.statements.toList() } + .filterIsInstance() + .map { it.identifier.value } + .toSet() +} + +/** + * Adapts a canonical (non-TS) test file for TypeScript: + * inlines refined wrappers, rewrites validate calls, and rebuilds imports. + */ +fun AstFile.adaptForTypeScript(fixture: Fixture): AstFile { + val refinedTypes = fixture.refinedTypeNames() + val main = elements.filterIsInstance
().firstOrNull() ?: return this + val body = main.body + + // Analyze: variable->type mapping and validate targets + val variableTypes = mutableMapOf() + val validateTargets = mutableSetOf() + for (stmt in body) { + if (stmt !is Assignment) continue + when (val value = stmt.value) { + is ConstructorStatement -> (value.type as? Type.Custom)?.let { variableTypes[stmt.name] = it.name } + is FunctionCall -> if (value.name == Name("validate") && value.receiver is VariableReference) + validateTargets.add((value.receiver as VariableReference).name) + + else -> {} + } + } + + // Refined wrappers to inline: refined type assignments NOT used as validate targets + val inlineMap = body.filterIsInstance() + .filter { variableTypes[it.name] in refinedTypes && it.name !in validateTargets } + .mapNotNull { a -> + (a.value as? ConstructorStatement)?.namedArguments?.get(Name("value"))?.let { a.name to it } + } + .toMap() + + if (validateTargets.isEmpty() && inlineMap.isEmpty()) return this + + // Transform: inline refs + rewrite validate calls + val t = transformer { + expression { expr, self -> + when { + expr is VariableReference && expr.name in inlineMap -> inlineMap.getValue(expr.name) + expr is FunctionCall && expr.name == Name("validate") && expr.receiver is VariableReference -> { + val varName = (expr.receiver as VariableReference).name + val typeName = variableTypes[varName] ?: return@expression expr.transformChildren(self) + val arg = if (typeName in refinedTypes) Name("value") to FieldCall( + VariableReference(varName), + Name("value") + ) + else Name("obj") to VariableReference(varName) + FunctionCall(name = Name("validate$typeName"), arguments = mapOf(arg)) + } + + else -> expr.transformChildren(self) + } + } + } + + // Second pass: wrap variable-to-variable equality in JSON.stringify for TypeScript + // TypeScript uses === which compares references, not structural equality for arrays + val jsonStringifyTransformer = transformer { + statement { stmt, self -> + if (stmt is AssertStatement && stmt.expression is BinaryOp) { + val op = stmt.expression as BinaryOp + if ((op.operator == BinaryOp.Operator.EQUALS || op.operator == BinaryOp.Operator.NOT_EQUALS) && + op.left is VariableReference && op.right is VariableReference + ) { + fun jsonStringify(e: Expression): FunctionCall = FunctionCall( + receiver = RawExpression("JSON"), + name = Name.of("stringify"), + arguments = mapOf(Name.of("_") to e), + ) + AssertStatement( + BinaryOp(jsonStringify(op.left), op.operator, jsonStringify(op.right)), + stmt.message, + ) + } else { + stmt.transformChildren(self) + } + } else { + stmt.transformChildren(self) + } + } + } + + val transformedBody = body + .filter { it !is Assignment || it.name !in inlineMap } + .map { t.transformStatement(it) } + .map { jsonStringifyTransformer.transformStatement(it) } + + // Rebuild imports: only validated types + their validate functions + val newImports = validateTargets.flatMap { varName -> + val typeName = variableTypes[varName] ?: return@flatMap emptyList() + listOf( + Import("./model/$typeName", Type.Custom(typeName)), + Import("./model/$typeName", Type.Custom("validate$typeName")), + ) + }.distinct() + + val existingMain = elements.filterIsInstance
().firstOrNull() + return copy(elements = newImports + elements.filter { it !is Import && it !is Main } + Main(statics = existingMain?.statics.orEmpty(), body = transformedBody)) +} + +fun Fixture.definitions(): List { + val ctx = object : ParseContext, NoLogger { + override val spec = WirespecSpec + } + val ast = ctx.parse(nonEmptyListOf(ModuleContent(FileUri("N/A"), source))) + .getOrNull() ?: return emptyList() + return ast.modules.toList().flatMap { it.statements.toList() } +} + + +fun Fixture.endpointNames(): List = + definitions().filterIsInstance().map { it.identifier.value } + +fun Fixture.modelNames(): List = + definitions().filterIsInstance() + .flatMap { it.importReferences() } + .distinctBy { it.value } + .map { it.value } + +fun ContainerBuilder.endpointClientImports(lang: Language, fixture: Fixture) { + val endpoints = fixture.endpointNames() + val models = fixture.modelNames() + clientImportsShared(lang, endpoints, models) + when (lang.emitter) { + is JavaIrEmitter -> endpoints.forEach { import("community.flock.wirespec.generated.client", "${it}Client") } + is KotlinIrEmitter -> endpoints.forEach { import("community.flock.wirespec.generated.client", "${it}Client") } + is TypeScriptIrEmitter -> endpoints.forEach { + val camel = Name.of(it).camelCase() + import("./client/${it}Client", "${camel}Client") + } + is PythonIrEmitter -> endpoints.forEach { + raw("from community.flock.wirespec.generated.client.${it}Client import ${it}Client") + } + is ScalaIrEmitter -> endpoints.forEach { import("community.flock.wirespec.generated.client", "${it}Client") } + is RustIrEmitter -> endpoints.forEach { import("community.flock.wirespec.generated.client", "${it}Client") } + } +} + +fun ContainerBuilder.mainClientImports(lang: Language, fixture: Fixture) { + val endpoints = fixture.endpointNames() + val models = fixture.modelNames() + clientImportsShared(lang, endpoints, models) + when (lang.emitter) { + is JavaIrEmitter -> import("community.flock.wirespec.generated", "Client") + is KotlinIrEmitter -> import("community.flock.wirespec.generated", "Client") + is TypeScriptIrEmitter -> import("./Client", "client") + is PythonIrEmitter -> raw("from community.flock.wirespec.generated.Client import Client") + is ScalaIrEmitter -> import("community.flock.wirespec.generated", "Client") + is RustIrEmitter -> import("community.flock.wirespec.generated", "Client") + } +} + +fun ContainerBuilder.endpointImports(lang: Language, fixture: Fixture) { + endpointImports(lang, fixture.endpointNames(), fixture.modelNames()) +} + +private fun ContainerBuilder.clientImportsShared(lang: Language, endpoints: List, models: List) { + endpointImports(lang, endpoints, models) + when (lang.emitter) { + is KotlinIrEmitter -> { + import("kotlin.coroutines", "createCoroutine") + import("kotlin.coroutines", "resume") + } + is PythonIrEmitter -> { + endpoints.forEach { raw("from community.flock.wirespec.generated.endpoint.$it import Response200") } + raw("import asyncio") + } + else -> {} + } +} + +private fun ContainerBuilder.endpointImports(lang: Language, endpoints: List, models: List) { + when (lang.emitter) { + is JavaIrEmitter -> { + import("community.flock.wirespec.java", "Wirespec") + endpoints.forEach { import("community.flock.wirespec.generated.endpoint", it) } + models.forEach { import("community.flock.wirespec.generated.model", it) } + } + is KotlinIrEmitter -> { + import("community.flock.wirespec.kotlin", "Wirespec") + import("kotlin.reflect", "typeOf") + endpoints.forEach { import("community.flock.wirespec.generated.endpoint", it) } + models.forEach { import("community.flock.wirespec.generated.model", it) } + } + is TypeScriptIrEmitter -> { + import("./Wirespec", "Wirespec") + endpoints.forEach { import("./endpoint/$it", it) } + models.forEach { import("./model/$it", it) } + } + is PythonIrEmitter -> { + import("community.flock.wirespec.generated.wirespec", "Wirespec") + endpoints.forEach { import("community.flock.wirespec.generated.endpoint.$it", it) } + models.forEach { import("community.flock.wirespec.generated.model.$it", it) } + } + is ScalaIrEmitter -> { + import("community.flock.wirespec.scala", "Wirespec") + endpoints.forEach { import("community.flock.wirespec.generated.endpoint", it) } + models.forEach { import("community.flock.wirespec.generated.model", it) } + } + is RustIrEmitter -> { + endpoints.forEach { import("community.flock.wirespec.generated.endpoint", it) } + models.forEach { import("community.flock.wirespec.generated.model", it) } + } + } +} \ No newline at end of file diff --git a/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyCaseInsensitivityTest.kt b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyCaseInsensitivityTest.kt new file mode 100644 index 000000000..bb7fc249e --- /dev/null +++ b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyCaseInsensitivityTest.kt @@ -0,0 +1,136 @@ +package community.flock.wirespec.verify + +import community.flock.wirespec.compiler.test.CompileFullEndpointTest +import community.flock.wirespec.emitters.rust.RustIrEmitter +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Expression +import community.flock.wirespec.ir.core.fieldCall +import community.flock.wirespec.ir.core.FunctionCall +import community.flock.wirespec.ir.core.Literal +import community.flock.wirespec.ir.core.LiteralList +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.NullableEmpty +import community.flock.wirespec.ir.core.NullableGet +import community.flock.wirespec.ir.core.NullableOf +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.TypeDescriptor +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.file +import io.kotest.core.spec.style.FunSpec + +/** + * Tests that HTTP headers are deserialized case-insensitively (RFC 7230) + * while query parameters remain case-sensitive (RFC 3986). + * + * Uses CompileFullEndpointTest fixture which has: + * endpoint PutTodo PUT PotentialTodoDto /todos/{id: String} + * ?{done: Boolean, name: String?} + * #{token: Token, `Refresh-Token`: Token?} -> { + * 200 -> TodoDto + * 201 -> TodoDto #{token: Token, refreshToken: Token?} + * 500 -> Error + * } + * + * Constructs a RawRequest with differently-cased header keys ("TOKEN", "refresh-token") + * and verifies fromRawRequest deserializes them correctly despite case mismatch. + */ +class VerifyCaseInsensitivityTest : FunSpec({ + + languages.values.forEach { lang -> + test("header case insensitivity - $lang") { + val isRust = lang.emitter is RustIrEmitter + val endpointRef: Expression = RawExpression("PutTodo") + + val testFile = file("CaseInsensitivityTest") { + endpointImports(lang, CompileFullEndpointTest) + + main(statics = { + raw(serializationCode(lang, CompileFullEndpointTest)) + }) { + + assign("rawRequest", construct(type(if (isRust) "RawRequest" else "Wirespec.RawRequest")) { + arg("method", literal("PUT")) + arg("path", listOf( + listOf(Literal("todos", Type.String), Literal("123", Type.String)), + Type.String + )) + arg("queries", mapOf( + mapOf( + "done" to LiteralList(listOf(Literal("true", Type.String)), Type.String), + "name" to LiteralList(listOf(Literal("test", Type.String)), Type.String) + ), + Type.String, Type.Array(Type.String) + )) + arg("headers", mapOf( + mapOf( + "TOKEN" to LiteralList(listOf(Literal("issValue", Type.String)), Type.String), + "refresh-token" to LiteralList(listOf(Literal("refreshIssValue", Type.String)), Type.String) + ), + Type.String, Type.Array(Type.String) + )) + arg("body", NullableOf( + FunctionCall( + receiver = VariableReference(Name.of("serialization")), + name = Name("serialize", "Body"), + typeArguments = listOf(Type.Custom("PotentialTodoDto")), + arguments = mapOf( + Name.of("value") to ConstructorStatement( + Type.Custom("PotentialTodoDto"), + mapOf( + Name.of("name") to Literal("test", Type.String), + Name.of("done") to Literal(true, Type.Boolean), + ), + ), + Name.of("type") to TypeDescriptor(Type.Custom("PotentialTodoDto")), + ), + ), + )) + }) + + // fromRawRequest call + assign("fromRaw", functionCall("fromRawRequest", receiver = if (isRust) null else endpointRef) { + arg("serialization", if (isRust) with(RustIrEmitter) { VariableReference("serialization").borrow() } else VariableReference("serialization")) + arg("rawRequest", VariableReference("rawRequest")) + }) + + // Assert token.iss matches despite case mismatch + assertThat( + BinaryOp( + VariableReference("fromRaw").fieldCall("headers").fieldCall("token").fieldCall("iss"), + BinaryOp.Operator.EQUALS, + Literal("issValue", Type.String) + ), + "Header 'token' should match 'TOKEN' case-insensitively" + ) + + // Assert refreshToken is present + assertThat( + BinaryOp( + VariableReference("fromRaw").fieldCall("headers").fieldCall("refreshToken"), + BinaryOp.Operator.NOT_EQUALS, + NullableEmpty + ), + "Header 'Refresh-Token' should match 'refresh-token' case-insensitively" + ) + + // Assert refreshToken.iss == "refreshIssValue" — unwrap then access + assertThat( + BinaryOp( + NullableGet(VariableReference("fromRaw").fieldCall("headers").fieldCall("refreshToken")).fieldCall("iss"), + BinaryOp.Operator.EQUALS, + Literal("refreshIssValue", Type.String) + ), + "Header 'Refresh-Token' value should be correct" + ) + + } + } + + lang.start(name = "case-insensitivity-test", fixture = CompileFullEndpointTest) + lang.run(testFile) + } + } +}) + diff --git a/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyClientTest.kt b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyClientTest.kt new file mode 100644 index 000000000..33d289a2f --- /dev/null +++ b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyClientTest.kt @@ -0,0 +1,141 @@ +package community.flock.wirespec.verify + +import community.flock.wirespec.compiler.test.CompileMinimalEndpointTest +import community.flock.wirespec.emitters.python.PythonIrEmitter +import community.flock.wirespec.emitters.rust.RustIrEmitter +import community.flock.wirespec.emitters.typescript.TypeScriptIrEmitter +import community.flock.wirespec.ir.core.ArrayIndexCall +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.emitters.rust.RustIrEmitter.Companion.borrow +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.FunctionBuilder +import community.flock.wirespec.ir.core.Literal +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.file +import io.kotest.core.spec.style.FunSpec + +class VerifyClientTest : FunSpec({ + + languages.values.forEach { lang -> + test("endpoint client - $lang") { + lang.start(name = "client-test", fixture = CompileMinimalEndpointTest) + + val isRust = lang.emitter is RustIrEmitter + val isPython = lang.emitter is PythonIrEmitter + val isTypeScript = lang.emitter is TypeScriptIrEmitter + val response200Type = response200Type(isRust, isPython) + + val testFile = file("EndpointClientTest") { + endpointClientImports(lang, CompileMinimalEndpointTest) + + main(isAsync = true, statics = { + raw(serializationCode(lang, CompileMinimalEndpointTest)) + raw(transportationCode(lang)) + }) { + when { + isTypeScript -> assign("endpointClient", functionCall("getTodosClient") { + arg("serialization", VariableReference("serialization")) + arg("transportation", VariableReference("transportation")) + }) + isRust -> assign("endpointClient", construct(Type.Custom("GetTodosClient")) { + arg("serialization", VariableReference("serialization").borrow()) + arg("transportation", VariableReference("transportation").borrow()) + }) + else -> assign("endpointClient", construct(Type.Custom("GetTodosClient")) { + arg("serialization", VariableReference("serialization")) + arg("transportation", VariableReference("transportation")) + }) + } + + val getTodosMethod = if (isPython) "get_todos" else "getTodos" + assign("response", functionCall(getTodosMethod, + receiver = VariableReference("endpointClient"), + isAwait = true, + )) + + assertDescriptionSwitch(response200Type) + } + } + + lang.run(testFile) + } + } + + languages.values.forEach { lang -> + test("main client - $lang") { + lang.start(name = "client-test", fixture = CompileMinimalEndpointTest) + + val isRust = lang.emitter is RustIrEmitter + val isPython = lang.emitter is PythonIrEmitter + val isTypeScript = lang.emitter is TypeScriptIrEmitter + val response200Type = response200Type(isRust, isPython) + + val testFile = file("MainClientTest") { + mainClientImports(lang, CompileMinimalEndpointTest) + + main(isAsync = true, statics = { + raw(serializationCode(lang, CompileMinimalEndpointTest)) + raw(transportationCode(lang)) + }) { + when { + isTypeScript -> assign("mainClient", functionCall("client") { + arg("serialization", VariableReference("serialization")) + arg("transportation", VariableReference("transportation")) + }) + isRust -> assign("mainClient", construct(Type.Custom("Client")) { + arg("serialization", ConstructorStatement(Type.Custom("MockSer"))) + arg("transportation", ConstructorStatement( + Type.Custom("MockTransport"), + mapOf(Name.of("serialization") to VariableReference("serialization").borrow()), + )) + }) + else -> assign("mainClient", construct(Type.Custom("Client")) { + arg("serialization", VariableReference("serialization")) + arg("transportation", VariableReference("transportation")) + }) + } + + val getTodosMethod = if (isPython) "get_todos" else "getTodos" + assign("response", functionCall(getTodosMethod, + receiver = VariableReference("mainClient"), + isAwait = true, + )) + + assertDescriptionSwitch(response200Type) + } + } + + lang.run(testFile) + } + } +}) + +private fun response200Type(isRust: Boolean, isPython: Boolean): Type.Custom = when { + isPython -> Type.Custom("Response200") + isRust -> Type.Custom("Response::Response200") + else -> Type.Custom("GetTodos.Response200") +} + +private fun FunctionBuilder.assertDescriptionSwitch(response200Type: Type.Custom) { + switch(VariableReference("response"), variable = "r") { + case(response200Type) { + assertThat( + BinaryOp( + FieldCall( + ArrayIndexCall( + FieldCall(VariableReference("r"), Name.of("body")), + Literal(0, Type.Integer()), + ), + Name.of("description"), + ), + BinaryOp.Operator.EQUALS, + Literal("test", Type.String), + ), + "Description should be test", + ) + } + } +} diff --git a/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyComplexModelTest.kt b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyComplexModelTest.kt new file mode 100644 index 000000000..a0fdaecdb --- /dev/null +++ b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyComplexModelTest.kt @@ -0,0 +1,119 @@ +package community.flock.wirespec.verify + +import community.flock.wirespec.compiler.test.CompileComplexModelTest +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.ir.core.NullableEmpty +import community.flock.wirespec.ir.core.NullableOf +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.file +import io.kotest.core.spec.style.FunSpec + +class VerifyComplexModelTest : FunSpec({ + + languages.values.forEach { lang -> + test("complex model validation valid - $lang") { + val testFile = file("ComplexModelValidationValid") { + import("community.flock.wirespec.generated.model", "Email") + import("community.flock.wirespec.generated.model", "PhoneNumber") + import("community.flock.wirespec.generated.model", "Tag") + import("community.flock.wirespec.generated.model", "EmployeeAge") + import("community.flock.wirespec.generated.model", "ContactInfo") + import("community.flock.wirespec.generated.model", "Employee") + import("community.flock.wirespec.generated.model", "Department") + import("community.flock.wirespec.generated.model", "Company") + main { + assign("email", construct(type("Email")) { + arg("value", literal("test@example.com")) + }) + assign("phone", construct(type("PhoneNumber")) { + arg("value", literal("+1234567890")) + }) + assign("tag1", construct(type("Tag")) { + arg("value", literal("developer")) + }) + assign("tag2", construct(type("Tag")) { + arg("value", literal("senior")) + }) + assign("age", construct(type("EmployeeAge")) { + arg("value", literal(30L)) + }) + assign("contactInfo", construct(type("ContactInfo")) { + arg("email", VariableReference("email")) + arg("phone", NullableOf(VariableReference("phone"))) + }) + assign("employee", construct(type("Employee")) { + arg("name", literal("John")) + arg("age", VariableReference("age")) + arg("contactInfo", VariableReference("contactInfo")) + arg("tags", listOf(listOf(VariableReference("tag1"), VariableReference("tag2")), type("Tag"))) + }) + assign("department", construct(type("Department")) { + arg("name", literal("Engineering")) + arg("employees", listOf(listOf(VariableReference("employee")), type("Employee"))) + }) + assign("company", construct(type("Company")) { + arg("name", literal("Acme")) + arg("departments", listOf(listOf(VariableReference("department")), type("Department"))) + }) + assign("errors", functionCall("validate", receiver = VariableReference("company"))) + assign("expected", literalList(string)) + assertThat(BinaryOp(VariableReference("errors"), BinaryOp.Operator.EQUALS, VariableReference("expected")), "Valid company should have no validation errors") + } + } + + lang.start(name = "complex-model-valid", fixture = CompileComplexModelTest) + lang.run(testFile) + } + + test("complex model validation invalid - $lang") { + val testFile = file("ComplexModelValidationInvalid") { + import("community.flock.wirespec.generated.model", "Email") + import("community.flock.wirespec.generated.model", "PhoneNumber") + import("community.flock.wirespec.generated.model", "Tag") + import("community.flock.wirespec.generated.model", "EmployeeAge") + import("community.flock.wirespec.generated.model", "ContactInfo") + import("community.flock.wirespec.generated.model", "Employee") + import("community.flock.wirespec.generated.model", "Department") + import("community.flock.wirespec.generated.model", "Company") + main { + assign("email", construct(type("Email")) { + arg("value", literal("not-an-email")) + }) + assign("tag1", construct(type("Tag")) { + arg("value", literal("valid")) + }) + assign("tag2", construct(type("Tag")) { + arg("value", literal("INVALID TAG!")) + }) + assign("age", construct(type("EmployeeAge")) { + arg("value", literal(10L)) + }) + assign("contactInfo", construct(type("ContactInfo")) { + arg("email", VariableReference("email")) + arg("phone", NullableEmpty) + }) + assign("employee", construct(type("Employee")) { + arg("name", literal("John")) + arg("age", VariableReference("age")) + arg("contactInfo", VariableReference("contactInfo")) + arg("tags", listOf(listOf(VariableReference("tag1"), VariableReference("tag2")), type("Tag"))) + }) + assign("department", construct(type("Department")) { + arg("name", literal("Engineering")) + arg("employees", listOf(listOf(VariableReference("employee")), type("Employee"))) + }) + assign("company", construct(type("Company")) { + arg("name", literal("Acme")) + arg("departments", listOf(listOf(VariableReference("department")), type("Department"))) + }) + assign("errors", functionCall("validate", receiver = VariableReference("company"))) + assign("expected", literalList(listOf(literal("departments[0].employees[0].age"), literal("departments[0].employees[0].contactInfo.email"), literal("departments[0].employees[0].tags[1]")), string)) + assertThat(BinaryOp(VariableReference("errors"), BinaryOp.Operator.EQUALS, VariableReference("expected")), "Invalid company should have validation errors for age, email, and tags[1]") + } + } + + lang.start(name = "complex-model-invalid", fixture = CompileComplexModelTest) + lang.run(testFile) + } + } +}) diff --git a/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyConversionTest.kt b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyConversionTest.kt new file mode 100644 index 000000000..a9818e7b9 --- /dev/null +++ b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyConversionTest.kt @@ -0,0 +1,168 @@ +package community.flock.wirespec.verify + +import community.flock.wirespec.compiler.test.CompileMinimalEndpointTest +import community.flock.wirespec.emitters.java.JavaIrEmitter +import community.flock.wirespec.emitters.kotlin.KotlinIrEmitter +import community.flock.wirespec.emitters.python.PythonIrEmitter +import community.flock.wirespec.emitters.rust.RustIrEmitter +import community.flock.wirespec.emitters.scala.ScalaIrEmitter +import community.flock.wirespec.emitters.typescript.TypeScriptIrEmitter +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.ir.core.ConstructorStatement +import community.flock.wirespec.ir.core.Expression +import community.flock.wirespec.ir.core.FieldCall +import community.flock.wirespec.ir.core.Literal +import community.flock.wirespec.ir.core.Name +import community.flock.wirespec.ir.core.RawExpression +import community.flock.wirespec.ir.core.Type +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.file +import io.kotest.core.spec.style.FunSpec + +class VerifyConversionTest : FunSpec({ + + languages.values.forEach { lang -> + test("conversion functions - $lang") { + val isRust = lang.emitter is RustIrEmitter + val isTypeScript = lang.emitter is TypeScriptIrEmitter + val isPython = lang.emitter is PythonIrEmitter + val endpointRef: Expression? = if (isRust) null else RawExpression("GetTodos") + val requestType = if (isRust || isPython) Type.Custom("Request") else Type.Custom("GetTodos.Request") + val response200Type = if (isRust || isPython) Type.Custom("Response200") else Type.Custom("GetTodos.Response200") + val todoDtoType = Type.Custom("TodoDto") + + val testFile = file("ConversionTest") { + when (lang.emitter) { + is JavaIrEmitter -> { + import("community.flock.wirespec.java", "Wirespec") + import("community.flock.wirespec.generated.endpoint", "GetTodos") + import("community.flock.wirespec.generated.model", "TodoDto") + } + + is KotlinIrEmitter -> { + import("community.flock.wirespec.kotlin", "Wirespec") + import("community.flock.wirespec.generated.endpoint", "GetTodos") + import("community.flock.wirespec.generated.model", "TodoDto") + } + + is TypeScriptIrEmitter -> { + import("./Wirespec", "Wirespec") + import("./endpoint/GetTodos", "GetTodos") + import("./model/TodoDto", "TodoDto") + } + + is PythonIrEmitter -> { + import("community.flock.wirespec.generated.wirespec", "Wirespec") + import("community.flock.wirespec.generated.endpoint.GetTodos", "GetTodos") + import("community.flock.wirespec.generated.endpoint.GetTodos", "Request") + import("community.flock.wirespec.generated.endpoint.GetTodos", "Response200") + import("community.flock.wirespec.generated.model.TodoDto", "TodoDto") + } + + is ScalaIrEmitter -> { + import("community.flock.wirespec.scala", "Wirespec") + import("community.flock.wirespec.generated.endpoint", "GetTodos") + import("community.flock.wirespec.generated.model", "TodoDto") + } + + is RustIrEmitter -> { + // Rust imports are handled by run() use statements + import("community.flock.wirespec.generated.endpoint", "GetTodos") + import("community.flock.wirespec.generated.model", "TodoDto") + } + } + + main(statics = { raw(serializationCode(lang, CompileMinimalEndpointTest)) }) { + // toRawRequest + when { + isRust -> { + raw("let request = Request::new()") + raw("let raw_request = to_raw_request(&serialization, request)") + } + isTypeScript -> { + raw("const request = GetTodos.request()") + raw("const rawRequest = GetTodos.toRawRequest(serialization, request)") + } + else -> { + assign("request", construct(requestType)) + assign("rawRequest", functionCall("toRawRequest", receiver = endpointRef) { + arg("serialization", VariableReference("serialization")) + arg("request", VariableReference("request")) + }) + } + } + assertThat( + BinaryOp( + FieldCall(VariableReference("rawRequest"), Name.of("method")), + BinaryOp.Operator.EQUALS, + Literal("GET", Type.String) + ), + "Method should be GET" + ) + + // fromRawRequest + when { + isRust -> raw("let from_raw = from_raw_request(&serialization, raw_request)") + isTypeScript -> raw("const fromRaw = GetTodos.fromRawRequest(serialization, rawRequest)") + else -> assign("fromRaw", functionCall("fromRawRequest", receiver = endpointRef) { + arg("serialization", VariableReference("serialization")) + arg("request", VariableReference("rawRequest")) + }) + } + + // toRawResponse + when { + isRust -> { + raw("""let response200 = Response200::new(vec![TodoDto { description: "test".to_string() }])""") + raw("let raw_response = to_raw_response(&serialization, response200.into())") + } + isTypeScript -> { + raw("const response200 = GetTodos.response200({ body: [{ description: 'test' }] })") + raw("const rawResponse = GetTodos.toRawResponse(serialization, response200)") + } + else -> { + assign("response200", construct(response200Type) { + arg( + "body", listOf( + listOf( + ConstructorStatement( + todoDtoType, + mapOf(Name.of("description") to Literal("test", Type.String)) + ) + ), + todoDtoType + ) + ) + }) + assign("rawResponse", functionCall("toRawResponse", receiver = endpointRef) { + arg("serialization", VariableReference("serialization")) + arg("response", VariableReference("response200")) + }) + } + } + assertThat( + BinaryOp( + FieldCall(VariableReference("rawResponse"), Name.of("statusCode")), + BinaryOp.Operator.EQUALS, + Literal(200, Type.Integer()) + ), + "Status should be 200" + ) + + // fromRawResponse + when { + isRust -> raw("let from_raw_resp = from_raw_response(&serialization, raw_response)") + isTypeScript -> raw("const fromRawResp = GetTodos.fromRawResponse(serialization, rawResponse)") + else -> assign("fromRawResp", functionCall("fromRawResponse", receiver = endpointRef) { + arg("serialization", VariableReference("serialization")) + arg("response", VariableReference("rawResponse")) + }) + } + } + } + + lang.start(name = "conversion-test", fixture = CompileMinimalEndpointTest) + lang.run(testFile) + } + } +}) diff --git a/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyFullEndpointTest.kt b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyFullEndpointTest.kt new file mode 100644 index 000000000..cf1287750 --- /dev/null +++ b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyFullEndpointTest.kt @@ -0,0 +1,17 @@ +package community.flock.wirespec.verify + +import community.flock.wirespec.compiler.test.CompileFullEndpointTest +import io.kotest.core.spec.style.FunSpec + +class VerifyFullEndpointTest : FunSpec({ + + languages.values.forEach { lang -> + test("full endpoint - $lang") { + lang.start( + name = "full-endpoint", + fixture = CompileFullEndpointTest, + ) + lang.compile() + } + } +}) diff --git a/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyModelValidationTest.kt b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyModelValidationTest.kt new file mode 100644 index 000000000..307c97aae --- /dev/null +++ b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyModelValidationTest.kt @@ -0,0 +1,70 @@ +package community.flock.wirespec.verify + +import community.flock.wirespec.compiler.test.CompileNestedTypeTest +import community.flock.wirespec.ir.core.BinaryOp +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.file +import io.kotest.core.spec.style.FunSpec + +class VerifyModelValidationTest : FunSpec({ + + languages.values.forEach { lang -> + test("model validation valid - $lang") { + val testFile = file("ModelValidationValid") { + import("community.flock.wirespec.generated.model", "DutchPostalCode") + import("community.flock.wirespec.generated.model", "Address") + import("community.flock.wirespec.generated.model", "Person") + main { + assign("postalCode", construct(type("DutchPostalCode")) { + arg("value", literal("1234AB")) + }) + assign("address", construct(type("Address")) { + arg("street", literal("Main St")) + arg("houseNumber", literal(42L)) + arg("postalCode", VariableReference("postalCode")) + }) + assign("person", construct(type("Person")) { + arg("name", literal("John")) + arg("address", VariableReference("address")) + arg("tags", emptyList(string)) + }) + assign("errors", functionCall("validate", receiver = VariableReference("person"))) + assign("expected", literalList(string)) + assertThat(BinaryOp(VariableReference("errors"), BinaryOp.Operator.EQUALS, VariableReference("expected")), "Valid person should have no validation errors") + } + } + + lang.start(name = "model-validation-valid", fixture = CompileNestedTypeTest) + lang.run(testFile) + } + + test("model validation - $lang") { + val testFile = file("ModelValidation") { + import("community.flock.wirespec.generated.model", "DutchPostalCode") + import("community.flock.wirespec.generated.model", "Address") + import("community.flock.wirespec.generated.model", "Person") + main { + assign("postalCode", construct(type("DutchPostalCode")) { + arg("value", literal("invalid")) + }) + assign("address", construct(type("Address")) { + arg("street", literal("Main St")) + arg("houseNumber", literal(42L)) + arg("postalCode", VariableReference("postalCode")) + }) + assign("person", construct(type("Person")) { + arg("name", literal("John")) + arg("address", VariableReference("address")) + arg("tags", emptyList(string)) + }) + assign("errors", functionCall("validate", receiver = VariableReference("person"))) + assign("expected", literalList(listOf(literal("address.postalCode")), string)) + assertThat(BinaryOp(VariableReference("errors"), BinaryOp.Operator.EQUALS, VariableReference("expected")), "Refined type is not valid") + } + } + + lang.start(name = "model-validation", fixture = CompileNestedTypeTest) + lang.run(testFile) + } + } +}) diff --git a/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyRefinedTest.kt b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyRefinedTest.kt new file mode 100644 index 000000000..04ee8233d --- /dev/null +++ b/src/verify/src/test/kotlin/community/flock/wirespec/verify/VerifyRefinedTest.kt @@ -0,0 +1,77 @@ +package community.flock.wirespec.verify + +import community.flock.wirespec.compiler.test.CompileRefinedTest +import community.flock.wirespec.ir.core.VariableReference +import community.flock.wirespec.ir.core.file +import io.kotest.core.spec.style.FunSpec + +class VerifyRefinedTest : FunSpec({ + + languages.values.forEach { lang -> + test("refined types - $lang") { + val testFile = file("RefinedValidation") { + import("community.flock.wirespec.generated.model", "TestInt2") + main { + assign("refined", construct(type("TestInt2")) { + arg("value", literal(2L)) + }) + assign("result", functionCall("validate", receiver = VariableReference("refined"))) + assertThat(VariableReference("result"), "Refined type is not valid") + } + } + + lang.start(name = "refined-types", fixture = CompileRefinedTest) + lang.run(testFile) + } + + test("refined type boundary validation - $lang") { + val testFile = file("RefinedBoundaryValidation") { + import("community.flock.wirespec.generated.model", "TestInt2") + import("community.flock.wirespec.generated.model", "TodoId") + import("community.flock.wirespec.generated.model", "TestNum2") + import("community.flock.wirespec.generated.model", "TestInt1") + import("community.flock.wirespec.generated.model", "TestNum1") + main { + assign("int2Min", construct(type("TestInt2")) { + arg("value", literal(1L)) + }) + assign("int2MinResult", functionCall("validate", receiver = VariableReference("int2Min"))) + assertThat(VariableReference("int2MinResult"), "TestInt2(1) should be valid") + + assign("int2Max", construct(type("TestInt2")) { + arg("value", literal(3L)) + }) + assign("int2MaxResult", functionCall("validate", receiver = VariableReference("int2Max"))) + assertThat(VariableReference("int2MaxResult"), "TestInt2(3) should be valid") + + assign("todoId", construct(type("TodoId")) { + arg("value", literal("550e8400-e29b-41d4-a716-446655440000")) + }) + assign("todoIdResult", functionCall("validate", receiver = VariableReference("todoId"))) + assertThat(VariableReference("todoIdResult"), "TodoId with valid UUID should be valid") + + assign("num2", construct(type("TestNum2")) { + arg("value", literal(0.3)) + }) + assign("num2Result", functionCall("validate", receiver = VariableReference("num2"))) + assertThat(VariableReference("num2Result"), "TestNum2(0.3) should be valid") + + assign("int1", construct(type("TestInt1")) { + arg("value", literal(0L)) + }) + assign("int1Result", functionCall("validate", receiver = VariableReference("int1"))) + assertThat(VariableReference("int1Result"), "TestInt1(0) should be valid") + + assign("num1", construct(type("TestNum1")) { + arg("value", literal(0.5)) + }) + assign("num1Result", functionCall("validate", receiver = VariableReference("num1"))) + assertThat(VariableReference("num1Result"), "TestNum1(0.5) should be valid") + } + } + + lang.start(name = "refined-boundary", fixture = CompileRefinedTest) + lang.run(testFile) + } + } +})