Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 66 additions & 22 deletions pkg/funcspec/function_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,10 +522,16 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
if strings.HasPrefix(kLower, "x-fieldfilter-") {
colname := strings.ReplaceAll(kLower, "x-fieldfilter-", "")
if strings.Contains(strings.ToLower(sqlquery), colname) {
if val == "" || val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
if val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = 0", ValidSQL(colname, "colname")))
} else if val == "" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(colname, "colname")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
if IsNumeric(val) {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
}
}
}
}
Expand Down Expand Up @@ -662,7 +668,10 @@ func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables ma
for k, v := range pathVars {
kword := fmt.Sprintf("[%s]", k)
if strings.Contains(sqlquery, kword) {
sqlquery = strings.ReplaceAll(sqlquery, kword, fmt.Sprintf("%v", v))
// Sanitize the value before replacing
vStr := fmt.Sprintf("%v", v)
sanitized := ValidSQL(vStr, "colvalue")
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
}
variables[k] = v

Expand Down Expand Up @@ -690,7 +699,9 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m
// Replace in SQL if placeholder exists
if strings.Contains(sqlquery, kword) && len(val) > 0 {
if strings.HasPrefix(parmk, "p-") {
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
// Sanitize the parameter value before replacing
sanitized := ValidSQL(val, "colvalue")
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
}
}

Expand All @@ -702,15 +713,36 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m
// Apply filters if allowed
if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) {
if len(parmv) > 1 {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(parmv, ",")))
// Sanitize each value in the IN clause with appropriate quoting
sanitizedValues := make([]string, len(parmv))
for i, v := range parmv {
if IsNumeric(v) {
// Numeric values don't need quotes
sanitizedValues[i] = ValidSQL(v, "colvalue")
} else {
// String values need quotes
sanitized := ValidSQL(v, "colvalue")
sanitizedValues[i] = fmt.Sprintf("'%s'", sanitized)
}
}
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s IN (%s)", ValidSQL(parmk, "colname"), strings.Join(sanitizedValues, ",")))
} else {
if strings.Contains(val, "match=") {
colval := strings.ReplaceAll(val, "match=", "")
// Escape single quotes and backslashes for LIKE patterns
// But don't escape wildcards % and _ which are intentional
colval = strings.ReplaceAll(colval, "\\", "\\\\")
colval = strings.ReplaceAll(colval, "'", "''")
if colval != "*" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), ValidSQL(colval, "colvalue")))
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), colval))
}
} else if val == "" || val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = %[2]s OR %[1]s IS NULL)", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
// For empty/zero values, treat as literal 0 or empty string with quotes
if val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = 0 OR %[1]s IS NULL)", ValidSQL(parmk, "colname")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(parmk, "colname")))
}
} else {
if IsNumeric(val) {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(parmk, "colname"), ValidSQL(val, "colvalue")))
Expand Down Expand Up @@ -743,16 +775,24 @@ func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables

kword := fmt.Sprintf("[%s]", k)
if strings.Contains(sqlquery, kword) {
sqlquery = strings.ReplaceAll(sqlquery, kword, val)
// Sanitize the header value before replacing
sanitized := ValidSQL(val, "colvalue")
sqlquery = strings.ReplaceAll(sqlquery, kword, sanitized)
}

// Handle special headers
if strings.Contains(k, "x-fieldfilter-") {
colname := strings.ReplaceAll(k, "x-fieldfilter-", "")
if val == "" || val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
if val == "0" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("COALESCE(%s, 0) = 0", ValidSQL(colname, "colname")))
} else if val == "" {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("(%[1]s = '' OR %[1]s IS NULL)", ValidSQL(colname, "colname")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
if IsNumeric(val) {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = %s", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
} else {
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s = '%s'", ValidSQL(colname, "colname"), ValidSQL(val, "colvalue")))
}
}
}

Expand Down Expand Up @@ -862,19 +902,23 @@ func ValidSQL(input, mode string) string {
reg := regexp.MustCompile(`[^a-zA-Z0-9_\.]`)
return reg.ReplaceAllString(input, "")
case "colvalue":
// For column values, escape single quotes
return strings.ReplaceAll(input, "'", "''")
// For column values, escape single quotes and backslashes
// Note: Backslashes must be escaped first, then single quotes
result := strings.ReplaceAll(input, "\\", "\\\\")
result = strings.ReplaceAll(result, "'", "''")
return result
case "select":
// For SELECT clauses, be more permissive but still safe
// Remove semicolons and common SQL injection patterns
dangerous := []string{";", "--", "/*", "*/", "xp_", "sp_", "DROP ", "DELETE ", "TRUNCATE ", "UPDATE ", "INSERT "}
result := input
for _, d := range dangerous {
result = strings.ReplaceAll(result, d, "")
result = strings.ReplaceAll(result, strings.ToLower(d), "")
result = strings.ReplaceAll(result, strings.ToUpper(d), "")
// Remove semicolons and common SQL injection patterns (case-insensitive)
dangerous := []string{
";", "--", "/\\*", "\\*/", "xp_", "sp_",
"drop ", "delete ", "truncate ", "update ", "insert ",
"exec ", "execute ", "union ", "declare ", "alter ", "create ",
}
return result
// Build a single regex pattern with all dangerous keywords
pattern := "(?i)(" + strings.Join(dangerous, "|") + ")"
re := regexp.MustCompile(pattern)
return re.ReplaceAllString(input, "")
default:
return input
}
Expand Down