summaryrefslogtreecommitdiffstats
path: root/vendor/modernc.org/sqlite/sqlite.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/modernc.org/sqlite/sqlite.go')
-rw-r--r--vendor/modernc.org/sqlite/sqlite.go225
1 files changed, 218 insertions, 7 deletions
diff --git a/vendor/modernc.org/sqlite/sqlite.go b/vendor/modernc.org/sqlite/sqlite.go
index f8c696f6..7fce68ee 100644
--- a/vendor/modernc.org/sqlite/sqlite.go
+++ b/vendor/modernc.org/sqlite/sqlite.go
@@ -1090,14 +1090,18 @@ func (c *conn) bindText(pstmt uintptr, idx1 int, value string) (uintptr, error)
// int sqlite3_bind_blob(sqlite3_stmt*, int, const void*, int n, void(*)(void*));
func (c *conn) bindBlob(pstmt uintptr, idx1 int, value []byte) (uintptr, error) {
+ if len(value) == 0 {
+ if rc := sqlite3.Xsqlite3_bind_zeroblob(c.tls, pstmt, int32(idx1), 0); rc != sqlite3.SQLITE_OK {
+ return 0, c.errstr(rc)
+ }
+ return 0, nil
+ }
+
p, err := c.malloc(len(value))
if err != nil {
return 0, err
}
-
- if len(value) != 0 {
- copy((*libc.RawMem)(unsafe.Pointer(p))[:len(value):len(value)], value)
- }
+ copy((*libc.RawMem)(unsafe.Pointer(p))[:len(value):len(value)], value)
if rc := sqlite3.Xsqlite3_bind_blob(c.tls, pstmt, int32(idx1), p, int32(len(value)), 0); rc != sqlite3.SQLITE_OK {
c.free(p)
return 0, c.errstr(rc)
@@ -1307,6 +1311,7 @@ func (c *conn) Close() error {
c.db = 0
}
+
if c.tls != nil {
c.tls.Close()
c.tls = nil
@@ -1323,6 +1328,32 @@ func (c *conn) closeV2(db uintptr) error {
return nil
}
+type userDefinedFunction struct {
+ zFuncName uintptr
+ nArg int32
+ eTextRep int32
+ xFunc func(*libc.TLS, uintptr, int32, uintptr)
+
+ freeOnce sync.Once
+}
+
+func (c *conn) createFunctionInternal(fun *userDefinedFunction) error {
+ if rc := sqlite3.Xsqlite3_create_function(
+ c.tls,
+ c.db,
+ fun.zFuncName,
+ fun.nArg,
+ fun.eTextRep,
+ 0,
+ *(*uintptr)(unsafe.Pointer(&fun.xFunc)),
+ 0,
+ 0,
+ ); rc != sqlite3.SQLITE_OK {
+ return c.errstr(rc)
+ }
+ return nil
+}
+
// Execer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Execer, the sql package's DB.Exec will first
@@ -1388,9 +1419,14 @@ func (c *conn) query(ctx context.Context, query string, args []driver.NamedValue
}
// Driver implements database/sql/driver.Driver.
-type Driver struct{}
+type Driver struct {
+ // user defined functions that are added to every new connection on Open
+ udfs map[string]*userDefinedFunction
+}
+
+var d = &Driver{udfs: make(map[string]*userDefinedFunction)}
-func newDriver() *Driver { return &Driver{} }
+func newDriver() *Driver { return d }
// Open returns a new connection to the database. The name is a string in a
// driver-specific format.
@@ -1422,5 +1458,180 @@ func newDriver() *Driver { return &Driver{} }
// available at
// https://www.sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
func (d *Driver) Open(name string) (driver.Conn, error) {
- return newConn(name)
+ c, err := newConn(name)
+ if err != nil {
+ return nil, err
+ }
+
+ for _, udf := range d.udfs {
+ if err = c.createFunctionInternal(udf); err != nil {
+ c.Close()
+ return nil, err
+ }
+ }
+ return c, nil
+}
+
+// FunctionContext represents the context user defined functions execute in.
+// Fields and/or methods of this type may get addedd in the future.
+type FunctionContext struct{}
+
+const sqliteValPtrSize = unsafe.Sizeof(&sqlite3.Sqlite3_value{})
+
+// RegisterScalarFunction registers a scalar function named zFuncName with nArg
+// arguments. Passing -1 for nArg indicates the function is variadic.
+//
+// The new function will be available to all new connections opened after
+// executing RegisterScalarFunction.
+func RegisterScalarFunction(
+ zFuncName string,
+ nArg int32,
+ xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
+) error {
+ return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8, xFunc)
+}
+
+// MustRegisterScalarFunction is like RegisterScalarFunction but panics on
+// error.
+func MustRegisterScalarFunction(
+ zFuncName string,
+ nArg int32,
+ xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
+) {
+ if err := RegisterScalarFunction(zFuncName, nArg, xFunc); err != nil {
+ panic(err)
+ }
+}
+
+// MustRegisterDeterministicScalarFunction is like
+// RegisterDeterministicScalarFunction but panics on error.
+func MustRegisterDeterministicScalarFunction(
+ zFuncName string,
+ nArg int32,
+ xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
+) {
+ if err := RegisterDeterministicScalarFunction(zFuncName, nArg, xFunc); err != nil {
+ panic(err)
+ }
+}
+
+// RegisterDeterministicScalarFunction registers a deterministic scalar
+// function named zFuncName with nArg arguments. Passing -1 for nArg indicates
+// the function is variadic. A deterministic function means that the function
+// always gives the same output when the input parameters are the same.
+//
+// The new function will be available to all new connections opened after
+// executing RegisterDeterministicScalarFunction.
+func RegisterDeterministicScalarFunction(
+ zFuncName string,
+ nArg int32,
+ xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
+) error {
+ return registerScalarFunction(zFuncName, nArg, sqlite3.SQLITE_UTF8|sqlite3.SQLITE_DETERMINISTIC, xFunc)
+}
+
+func registerScalarFunction(
+ zFuncName string,
+ nArg int32,
+ eTextRep int32,
+ xFunc func(ctx *FunctionContext, args []driver.Value) (driver.Value, error),
+) error {
+
+ if _, ok := d.udfs[zFuncName]; ok {
+ return fmt.Errorf("a function named %q is already registered", zFuncName)
+ }
+
+ // dont free, functions registered on the driver live as long as the program
+ name, err := libc.CString(zFuncName)
+ if err != nil {
+ return err
+ }
+
+ udf := &userDefinedFunction{
+ zFuncName: name,
+ nArg: nArg,
+ eTextRep: eTextRep,
+ xFunc: func(tls *libc.TLS, ctx uintptr, argc int32, argv uintptr) {
+ setErrorResult := func(res error) {
+ errmsg, cerr := libc.CString(res.Error())
+ if cerr != nil {
+ panic(cerr)
+ }
+ defer libc.Xfree(tls, errmsg)
+ sqlite3.Xsqlite3_result_error(tls, ctx, errmsg, -1)
+ sqlite3.Xsqlite3_result_error_code(tls, ctx, sqlite3.SQLITE_ERROR)
+ }
+
+ args := make([]driver.Value, argc)
+ for i := int32(0); i < argc; i++ {
+ valPtr := *(*uintptr)(unsafe.Pointer(argv + uintptr(i)*sqliteValPtrSize))
+
+ switch valType := sqlite3.Xsqlite3_value_type(tls, valPtr); valType {
+ case sqlite3.SQLITE_TEXT:
+ args[i] = libc.GoString(sqlite3.Xsqlite3_value_text(tls, valPtr))
+ case sqlite3.SQLITE_INTEGER:
+ args[i] = sqlite3.Xsqlite3_value_int64(tls, valPtr)
+ case sqlite3.SQLITE_FLOAT:
+ args[i] = sqlite3.Xsqlite3_value_double(tls, valPtr)
+ case sqlite3.SQLITE_NULL:
+ args[i] = nil
+ case sqlite3.SQLITE_BLOB:
+ size := sqlite3.Xsqlite3_value_bytes(tls, valPtr)
+ blobPtr := sqlite3.Xsqlite3_value_blob(tls, valPtr)
+ v := make([]byte, size)
+ copy(v, (*libc.RawMem)(unsafe.Pointer(blobPtr))[:size:size])
+ args[i] = v
+ default:
+ panic(fmt.Sprintf("unexpected argument type %q passed by sqlite", valType))
+ }
+ }
+
+ res, err := xFunc(&FunctionContext{}, args)
+ if err != nil {
+ setErrorResult(err)
+ return
+ }
+
+ switch resTyped := res.(type) {
+ case nil:
+ sqlite3.Xsqlite3_result_null(tls, ctx)
+ case int64:
+ sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped)
+ case float64:
+ sqlite3.Xsqlite3_result_double(tls, ctx, resTyped)
+ case bool:
+ sqlite3.Xsqlite3_result_int(tls, ctx, libc.Bool32(resTyped))
+ case time.Time:
+ sqlite3.Xsqlite3_result_int64(tls, ctx, resTyped.Unix())
+ case string:
+ size := int32(len(resTyped))
+ cstr, err := libc.CString(resTyped)
+ if err != nil {
+ panic(err)
+ }
+ defer libc.Xfree(tls, cstr)
+ sqlite3.Xsqlite3_result_text(tls, ctx, cstr, size, sqlite3.SQLITE_TRANSIENT)
+ case []byte:
+ size := int32(len(resTyped))
+ if size == 0 {
+ sqlite3.Xsqlite3_result_zeroblob(tls, ctx, 0)
+ return
+ }
+ p := libc.Xmalloc(tls, types.Size_t(size))
+ if p == 0 {
+ panic(fmt.Sprintf("unable to allocate space for blob: %d", size))
+ }
+ defer libc.Xfree(tls, p)
+ copy((*libc.RawMem)(unsafe.Pointer(p))[:size:size], resTyped)
+
+ sqlite3.Xsqlite3_result_blob(tls, ctx, p, size, sqlite3.SQLITE_TRANSIENT)
+ default:
+ setErrorResult(fmt.Errorf("function did not return a valid driver.Value: %T", resTyped))
+ return
+ }
+ },
+ }
+ d.udfs[zFuncName] = udf
+
+ return nil
}