diff --git a/conn.go b/conn.go index 3be0dfd..92299b0 100644 --- a/conn.go +++ b/conn.go @@ -18,10 +18,10 @@ type conn struct { ConnHook } -func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (dd driver.Result, err error) { +func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { ctx, err = c.BeforeExecContext(ctx, query, args) defer func() { - _, err = c.AfterExecContext(ctx, query, args, err) + _, result, err = c.AfterExecContext(ctx, query, args, result, err) }() if err != nil { return nil, err @@ -47,10 +47,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name } } -func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (dd driver.Rows, err error) { +func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { ctx, err = c.BeforeQueryContext(ctx, query, args) defer func() { - _, err = c.AfterQueryContext(ctx, query, args, err) + _, rows, err = c.AfterQueryContext(ctx, query, args, rows, err) }() if err != nil { return @@ -77,10 +77,10 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam } } -func (c *conn) PrepareContext(ctx context.Context, query string) (dd driver.Stmt, err error) { +func (c *conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt, err error) { ctx, err = c.BeforePrepareContext(ctx, query) defer func() { - _, err = c.AfterPrepareContext(ctx, query, err) + _, s, err = c.AfterPrepareContext(ctx, query, s, err) }() if err != nil { return nil, err @@ -103,7 +103,7 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (dd driver.Stmt func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (dd driver.Tx, err error) { ctx, err = c.BeforeBeginTx(ctx, opts) defer func() { - _, err = c.AfterBeginTx(ctx, opts, err) + _, dd, err = c.AfterBeginTx(ctx, opts, dd, err) }() if err != nil { return nil, err diff --git a/connector.go b/connector.go index 37787aa..e74ee13 100644 --- a/connector.go +++ b/connector.go @@ -12,10 +12,10 @@ type connector struct { ConnectorHook } -func (c *connector) Connect(ctx context.Context) (dd driver.Conn, err error) { +func (c *connector) Connect(ctx context.Context) (dc driver.Conn, err error) { ctx, err = c.BeforeConnect(ctx) defer func() { - _, err = c.AfterConnect(ctx, err) + _, dc, err = c.AfterConnect(ctx, dc, err) }() if err != nil { return nil, err diff --git a/hook.go b/hook.go index 8fef54c..b0122a4 100644 --- a/hook.go +++ b/hook.go @@ -14,18 +14,18 @@ type Hook interface { type ConnHook interface { BeforeExecContext(ctx context.Context, query string, args []driver.NamedValue) (context.Context, error) - AfterExecContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, error) + AfterExecContext(ctx context.Context, query string, args []driver.NamedValue, r driver.Result, err error) (context.Context, driver.Result, error) BeforeBeginTx(ctx context.Context, opts driver.TxOptions) (context.Context, error) - AfterBeginTx(ctx context.Context, opts driver.TxOptions, err error) (context.Context, error) + AfterBeginTx(ctx context.Context, opts driver.TxOptions, dd driver.Tx, err error) (context.Context, driver.Tx, error) BeforeQueryContext(ctx context.Context, query string, args []driver.NamedValue) (context.Context, error) - AfterQueryContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, error) + AfterQueryContext(ctx context.Context, query string, args []driver.NamedValue, rows driver.Rows, err error) (context.Context, driver.Rows, error) BeforePrepareContext(ctx context.Context, query string) (context.Context, error) - AfterPrepareContext(ctx context.Context, query string, err error) (context.Context, error) + AfterPrepareContext(ctx context.Context, query string, s driver.Stmt, err error) (context.Context, driver.Stmt, error) } type ConnectorHook interface { BeforeConnect(ctx context.Context) (context.Context, error) - AfterConnect(ctx context.Context, err error) (context.Context, error) + AfterConnect(ctx context.Context, dc driver.Conn, err error) (context.Context, driver.Conn, error) } type TxHook interface { @@ -37,7 +37,7 @@ type TxHook interface { type StmtHook interface { BeforeStmtQueryContext(ctx context.Context, query string, args []driver.NamedValue) (context.Context, error) - AfterStmtQueryContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, error) + AfterStmtQueryContext(ctx context.Context, query string, args []driver.NamedValue, rows driver.Rows, err error) (context.Context, driver.Rows, error) BeforeStmtExecContext(ctx context.Context, query string, args []driver.NamedValue) (context.Context, error) - AfterStmtExecContext(ctx context.Context, query string, args []driver.NamedValue, err error) (context.Context, error) + AfterStmtExecContext(ctx context.Context, query string, args []driver.NamedValue, r driver.Result, err error) (context.Context, driver.Result, error) } diff --git a/stmt.go b/stmt.go index 3b36b25..8c34192 100644 --- a/stmt.go +++ b/stmt.go @@ -20,7 +20,7 @@ type stmt struct { func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows driver.Rows, err error) { ctx, err = s.BeforeStmtQueryContext(ctx, s.query, args) defer func() { - _, err = s.AfterStmtQueryContext(ctx, s.query, args, err) + _, rows, err = s.AfterStmtQueryContext(ctx, s.query, args, rows, err) }() if err != nil { return nil, err @@ -46,7 +46,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (rows func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) { ctx, err = s.BeforeStmtExecContext(ctx, s.query, args) defer func() { - _, err = s.AfterStmtExecContext(ctx, s.query, args, err) + _, r, err = s.AfterStmtExecContext(ctx, s.query, args, r, err) }() if err != nil { return nil, err