diff --git a/build.go b/build.go index 8ea30c3..8df7c25 100644 --- a/build.go +++ b/build.go @@ -75,18 +75,6 @@ const ( Password Format = "password" ) */ -// common media types -const ( - Json MIMEType = "application/json" - Xml MIMEType = "application/xml" - Text MIMEType = "text/plain" - General MIMEType = "application/octet-stream" - Html MIMEType = "text/html" - XForm MIMEType = "application/x-www-form-urlencoded" - Jscript MIMEType = "application/javascript" - Form MIMEType = "multipart/form-data" -) - func (o *OpenAPI) AddTags(t ...Tag) { o.Tags = append(o.Tags, t...) } diff --git a/openapi.go b/openapi.go index 252b128..6272458 100644 --- a/openapi.go +++ b/openapi.go @@ -85,6 +85,19 @@ func (c *Code) UnmarshalText(b []byte) error { } type MIMEType string + +// common media types +const ( + Json MIMEType = "application/json" + Xml MIMEType = "application/xml" + Text MIMEType = "text/plain" + General MIMEType = "application/octet-stream" + Html MIMEType = "text/html" + XForm MIMEType = "application/x-www-form-urlencoded" + Jscript MIMEType = "application/javascript" + Form MIMEType = "multipart/form-data" +) + type Content map[MIMEType]Media type Media struct { diff --git a/paths.go b/paths.go index 6b2f6a1..2814ede 100644 --- a/paths.go +++ b/paths.go @@ -201,9 +201,10 @@ func (r Response) WithNamedExample(name string, i any) Response { if r.Content == nil { r.Content = make(Content) } - m := r.Content[Json] + c := getContentType(i) + m := r.Content[c] m.AddExample(name, i) - r.Content[Json] = m + r.Content[c] = m return r } @@ -261,6 +262,53 @@ func (r RequestBody) WithNamedJsonString(name string, s string) RequestBody { return r.WithExample(m) } +// getContentType attempts to determine the MIME type based on the data provided. +func getContentType(i any) MIMEType { + switch v := i.(type) { + case string: + // Try to detect content type from string content + s := strings.TrimSpace(v) + if len(s) == 0 { + return Text + } + + // Check for XML/HTML content first (both start with <) + if strings.HasPrefix(s, "<") { + // Check for XML declaration + if strings.HasPrefix(s, "", "", "
    ", "
  1. ", "", "console.log", "document.", "window."} + for _, pattern := range jsPatterns { + if strings.Contains(s, pattern) { + return Jscript + } + } + + // Default to plain text for other strings + return Text + case []byte: + return General + case map[string]any, []any, JSONString: + } + // default to JSON for any other type + return Json +} + // Deprecated: use WithExample(JSONString(s)) instead func (r RequestBody) WithJSONString(s string) RequestBody { return r.WithNamedJsonString("", s) @@ -274,9 +322,10 @@ func (r RequestBody) WithNamedExample(name string, i any) RequestBody { if r.Content == nil { r.Content = make(Content) } - m := r.Content[Json] + c := getContentType(i) + m := r.Content[c] m.AddExample(name, i) - r.Content[Json] = m + r.Content[c] = m return r } diff --git a/paths_test.go b/paths_test.go index b9f9d63..01e4630 100644 --- a/paths_test.go +++ b/paths_test.go @@ -333,7 +333,102 @@ func TestMarshalRoute(t *testing.T) { }, } trial.New(fn, cases).SubTest(t) +} +func TestGetContentType(t *testing.T) { + fn := func(i any) (MIMEType, error) { + return getContentType(i), nil + } + cases := trial.Cases[any, MIMEType]{ + "jsonString": { + Input: JSONString(`{"key":"value"}`), + Expected: Json, + }, + "csv": { + Input: "data1,data2,data3", + Expected: Text, + }, + "xml": { + Input: "value", + Expected: Xml, + }, + "xmlWithDeclaration": { + Input: `value`, + Expected: Xml, + }, + "html": { + Input: "Hello World", + Expected: Html, + }, + "htmlWithHead": { + Input: "TestContent", + Expected: Html, + }, + "htmlWithDiv": { + Input: "
    Hello World
    ", + Expected: Html, + }, + "htmlWithParagraph": { + Input: "

    This is a paragraph

    ", + Expected: Html, + }, + "htmlWithHeading": { + Input: "

    Main Heading

    ", + Expected: Html, + }, + "htmlWithForm": { + Input: "
    ", + Expected: Html, + }, + "javascript": { + Input: "function hello() { console.log('Hello World'); }", + Expected: Jscript, + }, + "javascriptWithVar": { + Input: "var name = 'John'; let age = 30; const city = 'NYC';", + Expected: Jscript, + }, + "javascriptWithArrow": { + Input: "const greet = () => 'Hello';", + Expected: Jscript, + }, + "javascriptWithConsole": { + Input: "console.log('Debug info');", + Expected: Jscript, + }, + "javascriptWithDocument": { + Input: "document.getElementById('myElement');", + Expected: Jscript, + }, + "javascriptWithWindow": { + Input: "window.location.href = '/new-page';", + Expected: Jscript, + }, + "plainText": { + Input: "This is just plain text without any special formatting", + Expected: Text, + }, + "emptyString": { + Input: "", + Expected: Text, + }, + "whitespaceString": { + Input: " \n\t ", + Expected: Text, + }, + "struct": { + Input: struct{ Name string }{Name: "example"}, + Expected: Json, + }, + "bytes": {Input: []byte(`902kn219jsk`), + Expected: General, + }, + "map": { + Input: map[string]int{"name": 1}, + Expected: Json, + }, + } + trial.New(fn, cases).SubTest(t) } func TestRouteAddSecurityRequirement(t *testing.T) {