Skip to content

Commit d8b26ea

Browse files
committed
feat: initial array type support
1 parent b40b326 commit d8b26ea

File tree

9 files changed

+75
-43
lines changed

9 files changed

+75
-43
lines changed

internal/codegen/queries.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,16 +131,21 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q
131131
body.WriteString("\n")
132132
body.WriteIndentedString(1, "public record "+returnType+"(\n")
133133
for i, ret := range q.Returns {
134-
jt := ret.JavaType
134+
jt := ret.JavaType.Type
135135
if strings.Contains(jt, ".") {
136136
parts := strings.Split(jt, ".")
137137

138138
imports = append(imports, jt)
139139
jt = parts[len(parts)-1]
140140
}
141141

142+
if ret.JavaType.IsList {
143+
imports = append(imports, "java.util.List", "java.util.Arrays")
144+
jt = "List<" + jt + ">"
145+
}
146+
142147
annotation := nonNullAnnotation
143-
if ret.Nullable {
148+
if ret.JavaType.Nullable {
144149
annotation = nullableAnnotation
145150
}
146151

@@ -181,16 +186,21 @@ func BuildQueriesFile(config core.Config, queryFilename string, queries []core.Q
181186
body.WriteString("\n")
182187

183188
for i, arg := range q.Args {
184-
jt := arg.JavaType
189+
jt := arg.JavaType.Type
185190
if strings.Contains(jt, ".") {
186191
parts := strings.Split(jt, ".")
187192

188193
imports = append(imports, jt)
189194
jt = parts[len(parts)-1]
190195
}
191196

197+
if arg.JavaType.IsList {
198+
imports = append(imports, "java.util.List")
199+
jt = "List<" + jt + ">"
200+
}
201+
192202
annotation := nonNullAnnotation
193-
if arg.Nullable {
203+
if arg.JavaType.Nullable {
194204
annotation = nullableAnnotation
195205
}
196206

internal/core/models.go

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,31 @@ func QueryCommandFor(rawCommand string) (QueryCommand, error) {
3636
}
3737
}
3838

39+
type JavaType struct {
40+
SqlType string
41+
Type string
42+
IsList bool
43+
Nullable bool
44+
}
45+
3946
type QueryArg struct {
4047
Number int
4148
Name string
42-
Nullable bool
43-
JavaType string
49+
JavaType JavaType
4450
}
4551

46-
// TODO - array types, enum types
52+
// TODO - enum types
4753
var literalBindTypes = []string{"Long", "Short", "String", "Boolean", "Float", "Double", "BigDecimal"}
4854

4955
func (q QueryArg) BindStmt() string {
50-
typeOnly := q.JavaType[strings.LastIndex(q.JavaType, ".")+1:]
56+
typeOnly := q.JavaType.Type[strings.LastIndex(q.JavaType.Type, ".")+1:]
57+
58+
if q.JavaType.IsList {
59+
if q.JavaType.Nullable {
60+
return fmt.Sprintf("stmt.setArray(%d, %s == null ? null : conn.createArrayOf(\"%s\", %s.toArray()));", q.Number, q.Name, q.JavaType.SqlType, q.Name)
61+
}
62+
return fmt.Sprintf("stmt.setArray(%d, conn.createArrayOf(\"%s\", %s.toArray()));", q.Number, q.JavaType.SqlType, q.Name)
63+
}
5164

5265
if slices.Contains(literalBindTypes, typeOnly) {
5366
return fmt.Sprintf("stmt.set%s(%d, %s);", typeOnly, q.Number, q.Name)
@@ -58,18 +71,20 @@ func (q QueryArg) BindStmt() string {
5871

5972
type QueryReturn struct {
6073
Name string
61-
Nullable bool
62-
JavaType string
74+
JavaType JavaType
6375
}
6476

6577
func (q QueryReturn) ResultStmt(number int) string {
66-
typeOnly := q.JavaType[strings.LastIndex(q.JavaType, ".")+1:]
78+
typeOnly := q.JavaType.Type[strings.LastIndex(q.JavaType.Type, ".")+1:]
79+
80+
if q.JavaType.IsList {
81+
return fmt.Sprintf("Arrays.asList((%s[]) results.getArray(%d).getArray())", typeOnly, number)
82+
}
6783

6884
if slices.Contains(literalBindTypes, typeOnly) {
6985
return fmt.Sprintf("results.get%s(%d)", typeOnly, number)
7086
}
7187

72-
// TODO - check correct!
7388
return fmt.Sprintf("results.getObject(%d, %s.class)", number, typeOnly)
7489
}
7590

internal/gen.go

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"github.com/iancoleman/strcase"
88
"github.com/sqlc-dev/plugin-sdk-go/plugin"
9+
"github.com/sqlc-dev/plugin-sdk-go/sdk"
910
"github.com/tandemdude/sqlc-gen-java/internal/codegen"
1011
"github.com/tandemdude/sqlc-gen-java/internal/core"
1112
"github.com/tandemdude/sqlc-gen-java/internal/sql_types"
@@ -84,11 +85,19 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
8485
return nil, err
8586
}
8687

88+
if arg.Column.ArrayDims > 1 {
89+
return nil, fmt.Errorf("multidimensional arrays are not supported, store JSON instead")
90+
}
91+
8792
args = append(args, core.QueryArg{
88-
Number: int(arg.Number),
89-
Name: arg.Column.Name,
90-
Nullable: arg.Column.NotNull == false,
91-
JavaType: javaType,
93+
Number: int(arg.Number),
94+
Name: arg.Column.Name,
95+
JavaType: core.JavaType{
96+
SqlType: sdk.DataType(arg.Column.Type),
97+
Type: javaType,
98+
IsList: arg.Column.ArrayDims > 0, // TODO check this will always be present
99+
Nullable: !arg.Column.NotNull,
100+
},
92101
})
93102
}
94103

@@ -100,10 +109,18 @@ func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.Generat
100109
return nil, err
101110
}
102111

112+
if ret.ArrayDims > 1 {
113+
return nil, fmt.Errorf("multidimensional arrays are not supported, store JSON instead")
114+
}
115+
103116
returns = append(returns, core.QueryReturn{
104-
Name: ret.Name,
105-
Nullable: ret.NotNull == false,
106-
JavaType: javaType,
117+
Name: ret.Name,
118+
JavaType: core.JavaType{
119+
SqlType: sdk.DataType(ret.Type),
120+
Type: javaType,
121+
IsList: ret.ArrayDims > 0, // TODO check this will always be present
122+
Nullable: !ret.NotNull,
123+
},
107124
})
108125
}
109126

internal/sql_types/postgresql.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func PostgresTypeToJavaType(identifier *plugin.Identifier) (string, error) {
3131
return "java.time.LocalDate", nil
3232
case "pg_catalog.time", "pg_catalog.timetz":
3333
return "java.time.LocalTime", nil
34-
case "pg_catalog.timestamp":
34+
case "pg_catalog.timestamp", "timestamp":
3535
return "java.time.LocalDateTime", nil
3636
case "pg_catalog.timestamptz", "timestamptz":
3737
return "java.time.OffsetDateTime", nil

tests/.gitignore

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ build/
3838
.DS_Store
3939

4040
### Generated Files ###
41-
src/main/java/io/github/tandemdude/sgj/mysql/**.java
41+
src/main/java/io/github/tandemdude/sgj/mysql/*.java
4242
!src/main/java/io/github/tandemdude/sgj/mysql/package-info.java
43-
src/main/java/io/github/tandemdude/sgj/postgresql/**.java
44-
!src/main/java/io/github/tandemdude/sgj/postgresql/package-info.java
45-
src/main/java/io/github/tandemdude/sgj/sqlite/**.java
43+
src/main/java/io/github/tandemdude/sgj/postgres/*.java
44+
!src/main/java/io/github/tandemdude/sgj/postgres/package-info.java
45+
src/main/java/io/github/tandemdude/sgj/sqlite/*.java
4646
!src/main/java/io/github/tandemdude/sgj/sqlite/package-info.java

tests/sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: java
44
wasm:
55
url: file://sqlc-gen-java.wasm
6-
sha256: 8cbaf60688121152cac88287a368251dde08712382b62cc75bd5089e202dbb03
6+
sha256: a8a2028aa1fb7c6fed6ae5dc19e10d876448fd2766a916b4a208d7d11da91f2f
77
sql:
88
- schema: src/main/resources/postgres/schema.sql
99
queries: src/main/resources/postgres/queries.sql

tests/src/main/java/io/github/tandemdude/sgj/Main.java

Lines changed: 0 additions & 17 deletions
This file was deleted.

tests/src/main/resources/postgres/queries.sql

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,10 @@ SELECT * FROM users WHERE user_id = $1;
66

77
-- name: ListUsers :many
88
SELECT * FROM users;
9+
10+
-- name: GetMessage :one
11+
SELECT * FROM messages WHERE message_id = $1;
12+
13+
-- name: CreateMessage :exec
14+
INSERT INTO messages(chat_id, user_id, content, attachments)
15+
VALUES ($1, $2, $3, $4);

tests/src/main/resources/postgres/schema.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ CREATE TABLE messages (
2727
chat_id INTEGER NOT NULL REFERENCES chats(chat_id),
2828
user_id UUID NOT NULL REFERENCES users(user_id),
2929
content TEXT NOT NULL,
30-
-- attachments TEXT[], -- Array column for storing attachment URLs
30+
attachments TEXT[], -- Array column for storing attachment URLs
3131
sent_at_time TIME DEFAULT NOW()::time without time zone,
3232
sent_at_date DATE DEFAULT NOW()::date
3333
);

0 commit comments

Comments
 (0)