diff options
Diffstat (limited to 'vendor/modernc.org/sqlite/sqlite.go')
-rw-r--r-- | vendor/modernc.org/sqlite/sqlite.go | 53 |
1 files changed, 43 insertions, 10 deletions
diff --git a/vendor/modernc.org/sqlite/sqlite.go b/vendor/modernc.org/sqlite/sqlite.go index 1bef9f47..ade56472 100644 --- a/vendor/modernc.org/sqlite/sqlite.go +++ b/vendor/modernc.org/sqlite/sqlite.go @@ -657,14 +657,14 @@ type tx struct { c *conn } -func newTx(c *conn) (*tx, error) { +func newTx(c *conn, opts driver.TxOptions) (*tx, error) { r := &tx{c: c} - var sql string - if c.beginMode != "" { + + sql := "begin" + if !opts.ReadOnly && c.beginMode != "" { sql = "begin " + c.beginMode - } else { - sql = "begin" } + if err := r.exec(context.Background(), sql); err != nil { return nil, err } @@ -752,7 +752,7 @@ type conn struct { } func newConn(dsn string) (*conn, error) { - var query string + var query, vfsName string // Parse the query parameters from the dsn and them from the dsn if not prefixed by file: // https://github.com/mattn/go-sqlite3/blob/3392062c729d77820afc1f5cae3427f0de39e954/sqlite3.go#L1046 @@ -760,6 +760,12 @@ func newConn(dsn string) (*conn, error) { pos := strings.IndexRune(dsn, '?') if pos >= 1 { query = dsn[pos+1:] + var err error + vfsName, err = getVFSName(query) + if err != nil { + return nil, err + } + if !strings.HasPrefix(dsn, "file:") { dsn = dsn[:pos] } @@ -768,6 +774,7 @@ func newConn(dsn string) (*conn, error) { c := &conn{tls: libc.NewTLS()} db, err := c.openV2( dsn, + vfsName, sqlite3.SQLITE_OPEN_READWRITE|sqlite3.SQLITE_OPEN_CREATE| sqlite3.SQLITE_OPEN_FULLMUTEX| sqlite3.SQLITE_OPEN_URI, @@ -790,6 +797,23 @@ func newConn(dsn string) (*conn, error) { return c, nil } +func getVFSName(query string) (r string, err error) { + q, err := url.ParseQuery(query) + if err != nil { + return "", err + } + + for _, v := range q["vfs"] { + if r != "" && r != v { + return "", fmt.Errorf("conflicting vfs query parameters: %v", q["vfs"]) + } + + r = v + } + + return r, nil +} + func applyQueryParams(c *conn, query string) error { q, err := url.ParseQuery(query) if err != nil { @@ -1225,8 +1249,8 @@ func (c *conn) extendedResultCodes(on bool) error { // int flags, /* Flags */ // const char *zVfs /* Name of VFS module to use */ // ); -func (c *conn) openV2(name string, flags int32) (uintptr, error) { - var p, s uintptr +func (c *conn) openV2(name, vfsName string, flags int32) (uintptr, error) { + var p, s, vfs uintptr defer func() { if p != 0 { @@ -1235,6 +1259,9 @@ func (c *conn) openV2(name string, flags int32) (uintptr, error) { if s != 0 { c.free(s) } + if vfs != 0 { + c.free(vfs) + } }() p, err := c.malloc(int(ptrSize)) @@ -1246,7 +1273,13 @@ func (c *conn) openV2(name string, flags int32) (uintptr, error) { return 0, err } - if rc := sqlite3.Xsqlite3_open_v2(c.tls, s, p, flags, 0); rc != sqlite3.SQLITE_OK { + if vfsName != "" { + if vfs, err = libc.CString(vfsName); err != nil { + return 0, err + } + } + + if rc := sqlite3.Xsqlite3_open_v2(c.tls, s, p, flags, vfs); rc != sqlite3.SQLITE_OK { return 0, c.errstr(rc) } @@ -1292,7 +1325,7 @@ func (c *conn) Begin() (driver.Tx, error) { } func (c *conn) begin(ctx context.Context, opts driver.TxOptions) (t driver.Tx, err error) { - return newTx(c) + return newTx(c, opts) } // Close invalidates and potentially stops any current prepared statements and |