You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
304 lines
6.1 KiB
304 lines
6.1 KiB
package sanitize |
|
|
|
import ( |
|
"bytes" |
|
"encoding/hex" |
|
"fmt" |
|
"strconv" |
|
"strings" |
|
"time" |
|
"unicode/utf8" |
|
) |
|
|
|
// Part is one element of a parsed Query. It holds either a string, which is
// emitted verbatim as raw SQL, or an int, which is the 1-based number of the
// argument placeholder ($1, $2, ...) to substitute at that position.
type Part interface{}
|
|
|
type Query struct { |
|
Parts []Part |
|
} |
|
|
|
func (q *Query) Sanitize(args ...interface{}) (string, error) { |
|
argUse := make([]bool, len(args)) |
|
buf := &bytes.Buffer{} |
|
|
|
for _, part := range q.Parts { |
|
var str string |
|
switch part := part.(type) { |
|
case string: |
|
str = part |
|
case int: |
|
argIdx := part - 1 |
|
if argIdx >= len(args) { |
|
return "", fmt.Errorf("insufficient arguments") |
|
} |
|
arg := args[argIdx] |
|
switch arg := arg.(type) { |
|
case nil: |
|
str = "null" |
|
case int64: |
|
str = strconv.FormatInt(arg, 10) |
|
case float64: |
|
str = strconv.FormatFloat(arg, 'f', -1, 64) |
|
case bool: |
|
str = strconv.FormatBool(arg) |
|
case []byte: |
|
str = QuoteBytes(arg) |
|
case string: |
|
str = QuoteString(arg) |
|
case time.Time: |
|
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") |
|
default: |
|
return "", fmt.Errorf("invalid arg type: %T", arg) |
|
} |
|
argUse[argIdx] = true |
|
default: |
|
return "", fmt.Errorf("invalid Part type: %T", part) |
|
} |
|
buf.WriteString(str) |
|
} |
|
|
|
for i, used := range argUse { |
|
if !used { |
|
return "", fmt.Errorf("unused argument: %d", i) |
|
} |
|
} |
|
return buf.String(), nil |
|
} |
|
|
|
func NewQuery(sql string) (*Query, error) { |
|
l := &sqlLexer{ |
|
src: sql, |
|
stateFn: rawState, |
|
} |
|
|
|
for l.stateFn != nil { |
|
l.stateFn = l.stateFn(l) |
|
} |
|
|
|
query := &Query{Parts: l.parts} |
|
|
|
return query, nil |
|
} |
|
|
|
// QuoteString returns str as a single-quoted SQL string literal, with every
// single quote inside str doubled.
func QuoteString(str string) string {
	var sb strings.Builder
	sb.Grow(len(str) + 2)
	sb.WriteByte('\'')
	sb.WriteString(strings.ReplaceAll(str, "'", "''"))
	sb.WriteByte('\'')
	return sb.String()
}
|
|
|
// QuoteBytes returns buf as a single-quoted literal consisting of the prefix
// \x followed by the lowercase hexadecimal encoding of buf.
func QuoteBytes(buf []byte) string {
	encoded := hex.EncodeToString(buf)
	return `'\x` + encoded + `'`
}
|
|
|
type sqlLexer struct { |
|
src string |
|
start int |
|
pos int |
|
nested int // multiline comment nesting level. |
|
stateFn stateFn |
|
parts []Part |
|
} |
|
|
|
type stateFn func(*sqlLexer) stateFn |
|
|
|
func rawState(l *sqlLexer) stateFn { |
|
for { |
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
l.pos += width |
|
|
|
switch r { |
|
case 'e', 'E': |
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
if nextRune == '\'' { |
|
l.pos += width |
|
return escapeStringState |
|
} |
|
case '\'': |
|
return singleQuoteState |
|
case '"': |
|
return doubleQuoteState |
|
case '$': |
|
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
if '0' <= nextRune && nextRune <= '9' { |
|
if l.pos-l.start > 0 { |
|
l.parts = append(l.parts, l.src[l.start:l.pos-width]) |
|
} |
|
l.start = l.pos |
|
return placeholderState |
|
} |
|
case '-': |
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
if nextRune == '-' { |
|
l.pos += width |
|
return oneLineCommentState |
|
} |
|
case '/': |
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
if nextRune == '*' { |
|
l.pos += width |
|
return multilineCommentState |
|
} |
|
case utf8.RuneError: |
|
if l.pos-l.start > 0 { |
|
l.parts = append(l.parts, l.src[l.start:l.pos]) |
|
l.start = l.pos |
|
} |
|
return nil |
|
} |
|
} |
|
} |
|
|
|
func singleQuoteState(l *sqlLexer) stateFn { |
|
for { |
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
l.pos += width |
|
|
|
switch r { |
|
case '\'': |
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
if nextRune != '\'' { |
|
return rawState |
|
} |
|
l.pos += width |
|
case utf8.RuneError: |
|
if l.pos-l.start > 0 { |
|
l.parts = append(l.parts, l.src[l.start:l.pos]) |
|
l.start = l.pos |
|
} |
|
return nil |
|
} |
|
} |
|
} |
|
|
|
func doubleQuoteState(l *sqlLexer) stateFn { |
|
for { |
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
l.pos += width |
|
|
|
switch r { |
|
case '"': |
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
if nextRune != '"' { |
|
return rawState |
|
} |
|
l.pos += width |
|
case utf8.RuneError: |
|
if l.pos-l.start > 0 { |
|
l.parts = append(l.parts, l.src[l.start:l.pos]) |
|
l.start = l.pos |
|
} |
|
return nil |
|
} |
|
} |
|
} |
|
|
|
// placeholderState consumes a placeholder value. The $ must have already has |
|
// already been consumed. The first rune must be a digit. |
|
func placeholderState(l *sqlLexer) stateFn { |
|
num := 0 |
|
|
|
for { |
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
l.pos += width |
|
|
|
if '0' <= r && r <= '9' { |
|
num *= 10 |
|
num += int(r - '0') |
|
} else { |
|
l.parts = append(l.parts, num) |
|
l.pos -= width |
|
l.start = l.pos |
|
return rawState |
|
} |
|
} |
|
} |
|
|
|
func escapeStringState(l *sqlLexer) stateFn { |
|
for { |
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
l.pos += width |
|
|
|
switch r { |
|
case '\\': |
|
_, width = utf8.DecodeRuneInString(l.src[l.pos:]) |
|
l.pos += width |
|
case '\'': |
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
if nextRune != '\'' { |
|
return rawState |
|
} |
|
l.pos += width |
|
case utf8.RuneError: |
|
if l.pos-l.start > 0 { |
|
l.parts = append(l.parts, l.src[l.start:l.pos]) |
|
l.start = l.pos |
|
} |
|
return nil |
|
} |
|
} |
|
} |
|
|
|
func oneLineCommentState(l *sqlLexer) stateFn { |
|
for { |
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
l.pos += width |
|
|
|
switch r { |
|
case '\\': |
|
_, width = utf8.DecodeRuneInString(l.src[l.pos:]) |
|
l.pos += width |
|
case '\n': |
|
return rawState |
|
case utf8.RuneError: |
|
if l.pos-l.start > 0 { |
|
l.parts = append(l.parts, l.src[l.start:l.pos]) |
|
l.start = l.pos |
|
} |
|
return nil |
|
} |
|
} |
|
} |
|
|
|
func multilineCommentState(l *sqlLexer) stateFn { |
|
for { |
|
r, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
l.pos += width |
|
|
|
switch r { |
|
case '/': |
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
if nextRune == '*' { |
|
l.pos += width |
|
l.nested++ |
|
} |
|
case '*': |
|
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) |
|
if nextRune != '/' { |
|
continue |
|
} |
|
|
|
l.pos += width |
|
if l.nested == 0 { |
|
return rawState |
|
} |
|
l.nested-- |
|
|
|
case utf8.RuneError: |
|
if l.pos-l.start > 0 { |
|
l.parts = append(l.parts, l.src[l.start:l.pos]) |
|
l.start = l.pos |
|
} |
|
return nil |
|
} |
|
} |
|
} |
|
|
|
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args |
|
// as necessary. This function is only safe when standard_conforming_strings is |
|
// on. |
|
func SanitizeSQL(sql string, args ...interface{}) (string, error) { |
|
query, err := NewQuery(sql) |
|
if err != nil { |
|
return "", err |
|
} |
|
return query.Sanitize(args...) |
|
}
|
|
|