diff options
Diffstat (limited to 'vendor/modernc.org/sqlite/sqlite.go')
-rw-r--r-- | vendor/modernc.org/sqlite/sqlite.go | 225 |
1 files changed, 218 insertions, 7 deletions
diff --git a/vendor/modernc.org/sqlite/sqlite.go b/vendor/modernc.org/sqlite/sqlite.go index f8c696f6..7fce68ee 100644 --- a/vendor/modernc.org/sqlite/sqlite.go +++ b/vendor/modernc.org/sqlite/sqlite.go @@ -1090,14 +1090,18 @@ func (c *conn) bindText(pstmt uintptr, idx1 int, value string) (uintptr, error) // int sqlite3_bind_blob(sqlite3_stmt*, int, const void*, int n, void(*)(void*)); func (c *conn) bindBlob(pstmt uintptr, idx1 int, value []byte) (uintptr, error) { + if len(value) == 0 { + if rc := sqlite3.Xsqlite3_bind_zeroblob(c.tls, pstmt, int32(idx1), 0); rc != sqlite3.SQLITE_OK { + return 0, c.errstr(rc) + } + return 0, nil + } + p, err := c.malloc(len(value)) if err != nil { return 0, err } - - if len(value) != 0 { - copy((*libc.RawMem)(unsafe.Pointer(p))[:len(value):len(value)], value) - } + copy((*libc.RawMem)(unsafe.Pointer(p))[:len(value):len(value)], value) if rc := sqlite3.Xsqlite3_bind_blob(c.tls, pstmt, int32(idx1), p, int32(len(value)), 0); rc != sqlite3.SQLITE_OK { c.free(p) return 0, c.errstr(rc) @@ -1307,6 +1311,7 @@ func (c *conn) Close() error { c.db = 0 } + if c.tls != nil { c.tls.Close() c.tls = nil @@ -1323,6 +1328,32 @@ func (c *conn) closeV2(db uintptr) error { return nil } +type userDefinedFunction struct { + zFuncName uintptr + nArg int32 + eTextRep int32 + xFunc func(*libc.TLS, uintptr, int32, uintptr) + + freeOnce sync.Once +} + +func (c *conn) createFunctionInternal(fun *userDefinedFunction) error { + if rc := sqlite3.Xsqlite3_create_function( + c.tls, + c.db, + fun.zFuncName, + fun.nArg, + fun.eTextRep, + 0, + *(*uintptr)(unsafe.Pointer(&fun.xFunc)), + 0, + 0, + ); rc != sqlite3.SQLITE_OK { + return c.errstr(rc) + } + return nil +} + // Execer is an optional interface that may be implemented by a Conn. // // If a Conn does not implement Execer, the sql package's DB.Exec will first @@ -1388,9 +1419,14 @@ func (c *conn) query(ctx context.Context, query string, args []driver.NamedValue } // Driver implements database/sql/driver.Driver. -type Driver struct{} +type Driver struct { + // user defined functions that are added to every new connection on Open + udfs map[string]*userDefinedFunction +} + +var d = &Driver{udfs: make(map[string]*userDefinedFunction)} -func newDriver() *Driver { return &Driver{} } +func newDriver() *Driver { return d } // Open returns a new connection to the database. The name is a string in a // driver-specific format. @@ -1422,5 +1458,180 @@ func newDriver() *Driver { return &Driver{} } // available at // https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions func (d *Driver) Open(name string) (driver.Conn, error) { - return newConn(name) + c, err := newConn(name) + if err != nil { + return nil, err + } + + for _, udf := range d.udfs { + if err = c.createFunctionInternal(udf); err != nil { + c.Close() + return nil, err + } + } + return c, nil +} + +// FunctionContext represents the context user defined functions execute in. +// Fields and/or methods of this type may get addedd in the future. +type FunctionContext struct{} + +const sqliteValPtrSize = unsafe.Sizeof(&sqlite3.Sqlite3_value{}) + +// RegisterScalarFunction registers a scalar function named zFuncName with nArg +// arguments. Passing -1 for nArg indicates the function is variadic. +// +// The new function will be available to all new connections opened after +// executing RegisterScalarFunction. +func RegisterScalarFunction( + zFuncName string, + nArg int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) error { + return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8, xFunc) +} + +// MustRegisterScalarFunction is like RegisterScalarFunction but panics on +// error. +func MustRegisterScalarFunction( + zFuncName string, + nArg int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) { + if err := RegisterScalarFunction(zFuncName, nArg, xFunc); err != nil { + panic(err) + } +} + +// MustRegisterDeterministicScalarFunction is like +// RegisterDeterministicScalarFunction but panics on error. +func MustRegisterDeterministicScalarFunction( + zFuncName string, + nArg int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) { + if err := RegisterDeterministicScalarFunction(zFuncName, nArg, xFunc); err != nil { + panic(err) + } +} + +// RegisterDeterministicScalarFunction registers a deterministic scalar +// function named zFuncName with nArg arguments. Passing -1 for nArg indicates +// the function is variadic. A deterministic function means that the function +// always gives the same output when the input parameters are the same. +// +// The new function will be available to all new connections opened after +// executing RegisterDeterministicScalarFunction. +func RegisterDeterministicScalarFunction( + zFuncName string, + nArg int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) error { + return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8|sqlite3.SQLITE_DETERMINISTIC, xFunc) +} + +func registerScalarFunction( + zFuncName string, + nArg int32, + eTextRep int32, + xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error), +) error { + + if _, ok := d.udfs[zFuncName]; ok { + return fmt.Errorf("a function named %q is already registered", zFuncName) + } + + // dont free, functions registered on the driver live as long as the program + name, err := libc.CString(zFuncName) + if err != nil { + return err + } + + udf := &userDefinedFunction{ + zFuncName: name, + nArg: nArg, + eTextRep: eTextRep, + xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) { + setErrorResult := func(res error) { + errmsg, cerr := libc.CString(res.Error()) + if cerr != nil { + panic(cerr) + } + defer libc.Xfree(tls, errmsg) + sqlite3.Xsqlite3_result_error(tls, ctx, errmsg, -1) + sqlite3.Xsqlite3_result_error_code(tls, ctx, sqlite3.SQLITE_ERROR) + } + + args := make([]driver.Value, argc) + for i := int32(0); i < argc; i++ { + valPtr := *(*uintptr)(unsafe.Pointer(argv + uintptr(i)*sqliteValPtrSize)) + + switch valType := sqlite3.Xsqlite3_value_type(tls, valPtr); valType { + case sqlite3.SQLITE_TEXT: + args[i] = libc.GoString(sqlite3.Xsqlite3_value_text(tls, valPtr)) + case sqlite3.SQLITE_INTEGER: + args[i] = sqlite3.Xsqlite3_value_int64(tls, valPtr) + case sqlite3.SQLITE_FLOAT: + args[i] = sqlite3.Xsqlite3_value_double(tls, valPtr) + case sqlite3.SQLITE_NULL: + args[i] = nil + case sqlite3.SQLITE_BLOB: + size := sqlite3.Xsqlite3_value_bytes(tls, valPtr) + blobPtr := sqlite3.Xsqlite3_value_blob(tls, valPtr) + v := make([]byte, size) + copy(v, (*libc.RawMem)(unsafe.Pointer(blobPtr))[:size:size]) + args[i] = v + default: + panic(fmt.Sprintf("unexpected argument type %q passed by sqlite", valType)) + } + } + + res, err := xFunc(&FunctionContext{}, args) + if err != nil { + setErrorResult(err) + return + } + + switch resTyped := res.(type) { + case nil: + sqlite3.Xsqlite3_result_null(tls, ctx) + case int64: + sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped) + case float64: + sqlite3.Xsqlite3_result_double(tls, ctx, resTyped) + case bool: + sqlite3.Xsqlite3_result_int(tls, ctx, libc.Bool32(resTyped)) + case time.Time: + sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped.Unix()) + case string: + size := int32(len(resTyped)) + cstr, err := libc.CString(resTyped) + if err != nil { + panic(err) + } + defer libc.Xfree(tls, cstr) + sqlite3.Xsqlite3_result_text(tls, ctx, cstr, size, sqlite3.SQLITE_TRANSIENT) + case []byte: + size := int32(len(resTyped)) + if size == 0 { + sqlite3.Xsqlite3_result_zeroblob(tls, ctx, 0) + return + } + p := libc.Xmalloc(tls, types.Size_t(size)) + if p == 0 { + panic(fmt.Sprintf("unable to allocate space for blob: %d", size)) + } + defer libc.Xfree(tls, p) + copy((*libc.RawMem)(unsafe.Pointer(p))[:size:size], resTyped) + + sqlite3.Xsqlite3_result_blob(tls, ctx, p, size, sqlite3.SQLITE_TRANSIENT) + default: + setErrorResult(fmt.Errorf("function did not return a valid driver.Value: %T", resTyped)) + return + } + }, + } + d.udfs[zFuncName] = udf + + return nil } |