Skip to content

Commit 2a8964b

Browse files
committed
v3: protect from partial DML execution
V3 sometimes breaks up a single DML into multiple statements. Sometimes, this may lead to failures in the middle. The code change detects such failures and rolls back the transaction to prevent commits that are accidentally partial.
1 parent a0ca6d8 commit 2a8964b

File tree

8 files changed

+83
-25
lines changed

8 files changed

+83
-25
lines changed

go/vt/vtgate/engine/primitive.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ const ListVarName = "__vals"
2222
// VCursor defines the interface the engine will use
2323
// to execute routes.
2424
type VCursor interface {
25-
Execute(query string, bindvars map[string]interface{}) (*sqltypes.Result, error)
26-
ExecuteMultiShard(keyspace string, shardQueries map[string]querytypes.BoundQuery) (*sqltypes.Result, error)
25+
Execute(query string, bindvars map[string]interface{}, isDML bool) (*sqltypes.Result, error)
26+
ExecuteMultiShard(keyspace string, shardQueries map[string]querytypes.BoundQuery, isDML bool) (*sqltypes.Result, error)
2727
ExecuteStandalone(query string, bindvars map[string]interface{}, keyspace, shard string) (*sqltypes.Result, error)
2828
StreamExecuteMulti(query string, keyspace string, shardVars map[string]map[string]interface{}, callback func(reply *sqltypes.Result) error) error
2929
GetKeyspaceShards(vkeyspace *vindexes.Keyspace) (string, []*topodatapb.ShardReference, error)

go/vt/vtgate/engine/route.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,12 @@ func (route *Route) Execute(vcursor VCursor, bindVars, joinVars map[string]inter
257257

258258
var err error
259259
var params *scatterParams
260+
isDML := false
260261
switch route.Opcode {
261-
case SelectUnsharded, UpdateUnsharded, DeleteUnsharded, SelectScatter:
262+
case SelectUnsharded, SelectScatter:
263+
params, err = route.paramsAllShards(vcursor, bindVars)
264+
case UpdateUnsharded, DeleteUnsharded:
265+
isDML = true
262266
params, err = route.paramsAllShards(vcursor, bindVars)
263267
case SelectEqual, SelectEqualUnique:
264268
params, err = route.paramsSelectEqual(vcursor, bindVars)
@@ -273,7 +277,7 @@ func (route *Route) Execute(vcursor VCursor, bindVars, joinVars map[string]inter
273277
}
274278

275279
shardQueries := route.getShardQueries(route.Query, params)
276-
return vcursor.ExecuteMultiShard(params.ks, shardQueries)
280+
return vcursor.ExecuteMultiShard(params.ks, shardQueries, isDML)
277281
}
278282

279283
// StreamExecute performs a streaming exec.
@@ -312,7 +316,7 @@ func (route *Route) GetFields(vcursor VCursor, bindVars, joinVars map[string]int
312316
return nil, err
313317
}
314318

315-
return route.execShard(vcursor, route.FieldQuery, bindVars, ks, shard)
319+
return route.execShard(vcursor, route.FieldQuery, bindVars, ks, shard, false /* isDML */)
316320
}
317321

318322
func combineVars(bv1, bv2 map[string]interface{}) map[string]interface{} {
@@ -390,7 +394,7 @@ func (route *Route) execUpdateEqual(vcursor VCursor, bindVars map[string]interfa
390394
return &sqltypes.Result{}, nil
391395
}
392396
rewritten := sqlannotation.AddKeyspaceIDs(route.Query, [][]byte{ksid}, "")
393-
return route.execShard(vcursor, rewritten, bindVars, ks, shard)
397+
return route.execShard(vcursor, rewritten, bindVars, ks, shard, true /* isDML */)
394398
}
395399

396400
func (route *Route) execDeleteEqual(vcursor VCursor, bindVars map[string]interface{}) (*sqltypes.Result, error) {
@@ -412,7 +416,7 @@ func (route *Route) execDeleteEqual(vcursor VCursor, bindVars map[string]interfa
412416
}
413417
}
414418
rewritten := sqlannotation.AddKeyspaceIDs(route.Query, [][]byte{ksid}, "")
415-
return route.execShard(vcursor, rewritten, bindVars, ks, shard)
419+
return route.execShard(vcursor, rewritten, bindVars, ks, shard, true /* isDML */)
416420
}
417421

418422
func (route *Route) execInsertUnsharded(vcursor VCursor, bindVars map[string]interface{}) (*sqltypes.Result, error) {
@@ -426,7 +430,7 @@ func (route *Route) execInsertUnsharded(vcursor VCursor, bindVars map[string]int
426430
}
427431

428432
shardQueries := route.getShardQueries(route.Query, params)
429-
result, err := vcursor.ExecuteMultiShard(params.ks, shardQueries)
433+
result, err := vcursor.ExecuteMultiShard(params.ks, shardQueries, true /* isDML */)
430434
if err != nil {
431435
return nil, fmt.Errorf("execInsertUnsharded: %v", err)
432436
}
@@ -451,7 +455,7 @@ func (route *Route) execInsertSharded(vcursor VCursor, bindVars map[string]inter
451455
return nil, fmt.Errorf("execInsertSharded: %v", err)
452456
}
453457

454-
result, err := vcursor.ExecuteMultiShard(keyspace, shardQueries)
458+
result, err := vcursor.ExecuteMultiShard(keyspace, shardQueries, true /* isDML */)
455459

456460
if err != nil {
457461
return nil, fmt.Errorf("execInsertSharded: %v", err)
@@ -640,7 +644,7 @@ func (route *Route) resolveSingleShard(vcursor VCursor, bindVars map[string]inte
640644
}
641645

642646
func (route *Route) deleteVindexEntries(vcursor VCursor, bindVars map[string]interface{}, ks, shard string, ksid []byte) error {
643-
result, err := route.execShard(vcursor, route.Subquery, bindVars, ks, shard)
647+
result, err := route.execShard(vcursor, route.Subquery, bindVars, ks, shard, false /* isDML */)
644648
if err != nil {
645649
return err
646650
}
@@ -807,13 +811,13 @@ func (route *Route) handleNonPrimary(vcursor VCursor, vindexKeys []interface{},
807811
return nil
808812
}
809813

810-
func (route *Route) execShard(vcursor VCursor, query string, bindVars map[string]interface{}, keyspace, shard string) (*sqltypes.Result, error) {
814+
func (route *Route) execShard(vcursor VCursor, query string, bindVars map[string]interface{}, keyspace, shard string, isDML bool) (*sqltypes.Result, error) {
811815
return vcursor.ExecuteMultiShard(keyspace, map[string]querytypes.BoundQuery{
812816
shard: {
813817
Sql: query,
814818
BindVariables: bindVars,
815819
},
816-
})
820+
}, isDML)
817821
}
818822

819823
func (route *Route) anyShard(vcursor VCursor, keyspace *vindexes.Keyspace) (string, string, error) {

go/vt/vtgate/executor.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,13 @@ func (e *Executor) handleExec(ctx context.Context, session *vtgatepb.Session, sq
145145
if err != nil {
146146
return nil, err
147147
}
148-
return plan.Instructions.Execute(vcursor, bindVars, make(map[string]interface{}), true)
148+
qr, err := plan.Instructions.Execute(vcursor, bindVars, make(map[string]interface{}), true)
149+
// Check if there was partial DML execution. If so, rollback the transaction.
150+
if err != nil && session.InTransaction && vcursor.hasPartialDML {
151+
_ = e.txConn.Rollback(ctx, NewSafeSession(session))
152+
err = vterrors.Errorf(vtrpcpb.Code_ABORTED, "transaction rolled back due to partial DML execution: %v", err)
153+
}
154+
return qr, err
149155
}
150156

151157
func (e *Executor) shardExec(ctx context.Context, session *vtgatepb.Session, sql string, bindVars map[string]interface{}, target querypb.Target) (*sqltypes.Result, error) {

go/vt/vtgate/executor_dml_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package vtgate
66

77
import (
8+
"context"
89
"reflect"
910
"strings"
1011
"testing"
@@ -15,6 +16,7 @@ import (
1516
"github.com/youtube/vitess/go/vt/vttablet/tabletserver/querytypes"
1617

1718
querypb "github.com/youtube/vitess/go/vt/proto/query"
19+
vtgatepb "github.com/youtube/vitess/go/vt/proto/vtgate"
1820
vtrpcpb "github.com/youtube/vitess/go/vt/proto/vtrpc"
1921
)
2022

@@ -785,6 +787,36 @@ func TestInsertFail(t *testing.T) {
785787
}
786788
}
787789

790+
func TestInsertPartialFail(t *testing.T) {
791+
executor, sbc1, _, sbclookup := createExecutorEnv()
792+
793+
// If the first DML fails, there should be no rollback.
794+
sbclookup.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
795+
_, err := executor.Execute(
796+
context.Background(),
797+
&vtgatepb.Session{InTransaction: true},
798+
"insert into user(id, v, name) values (1, 2, 'myname')",
799+
nil,
800+
)
801+
want := "execInsertSharded:"
802+
if err == nil || !strings.HasPrefix(err.Error(), want) {
803+
t.Errorf("insert first DML fail: %v, must start with %s", err, want)
804+
}
805+
806+
// If the second DML fails, we should rollback.
807+
sbc1.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
808+
_, err = executor.Execute(
809+
context.Background(),
810+
&vtgatepb.Session{InTransaction: true},
811+
"insert into user(id, v, name) values (1, 2, 'myname')",
812+
nil,
813+
)
814+
want = "transaction rolled back"
815+
if err == nil || !strings.HasPrefix(err.Error(), want) {
816+
t.Errorf("insert first DML fail: %v, must start with %s", err, want)
817+
}
818+
}
819+
788820
func TestMultiInsertSharded(t *testing.T) {
789821
executor, sbc1, sbc2, sbclookup := createExecutorEnv()
790822

go/vt/vtgate/vcursor_impl.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ type vcursorImpl struct {
2727
target querypb.Target
2828
trailingComments string
2929
executor *Executor
30+
// hasPartialDML is set to true if any DML was successfully
31+
// executed. If there was a subsequent failure, the transaction
32+
// must be forced to rollback.
33+
hasPartialDML bool
3034
}
3135

3236
// newVcursorImpl creates a vcursorImpl. Before creating this object, you have to separate out any trailingComments that came with
@@ -54,13 +58,21 @@ func (vc *vcursorImpl) Find(keyspace, tablename sqlparser.TableIdent) (table *vi
5458
}
5559

5660
// Execute performs a V3 level execution of the query. It does not take any routing directives.
57-
func (vc *vcursorImpl) Execute(query string, BindVars map[string]interface{}) (*sqltypes.Result, error) {
58-
return vc.executor.Execute(vc.ctx, vc.session, query+vc.trailingComments, BindVars)
61+
func (vc *vcursorImpl) Execute(query string, BindVars map[string]interface{}, isDML bool) (*sqltypes.Result, error) {
62+
qr, err := vc.executor.Execute(vc.ctx, vc.session, query+vc.trailingComments, BindVars)
63+
if err == nil {
64+
vc.hasPartialDML = true
65+
}
66+
return qr, err
5967
}
6068

6169
// ExecuteMultiShard executes different queries on different shards and returns the combined result.
62-
func (vc *vcursorImpl) ExecuteMultiShard(keyspace string, shardQueries map[string]querytypes.BoundQuery) (*sqltypes.Result, error) {
63-
return vc.executor.scatterConn.ExecuteMultiShard(vc.ctx, keyspace, commentedShardQueries(shardQueries, vc.trailingComments), vc.target.TabletType, NewSafeSession(vc.session), false, vc.session.Options)
70+
func (vc *vcursorImpl) ExecuteMultiShard(keyspace string, shardQueries map[string]querytypes.BoundQuery, isDML bool) (*sqltypes.Result, error) {
71+
qr, err := vc.executor.scatterConn.ExecuteMultiShard(vc.ctx, keyspace, commentedShardQueries(shardQueries, vc.trailingComments), vc.target.TabletType, NewSafeSession(vc.session), false, vc.session.Options)
72+
if err == nil {
73+
vc.hasPartialDML = true
74+
}
75+
return qr, err
6476
}
6577

6678
// ExecuteStandalone executes the specified query on keyspace:shard, but outside of the current transaction, as an independent statement.
@@ -71,7 +83,11 @@ func (vc *vcursorImpl) ExecuteStandalone(query string, BindVars map[string]inter
7183
BindVariables: BindVars,
7284
},
7385
}
74-
return vc.executor.scatterConn.ExecuteMultiShard(vc.ctx, keyspace, bq, vc.target.TabletType, NewSafeSession(nil), false, vc.session.Options)
86+
qr, err := vc.executor.scatterConn.ExecuteMultiShard(vc.ctx, keyspace, bq, vc.target.TabletType, NewSafeSession(nil), false, vc.session.Options)
87+
if err == nil {
88+
vc.hasPartialDML = true
89+
}
90+
return qr, err
7591
}
7692

7793
// StreamExeculteMulti is the streaming version of ExecuteMultiShard.

go/vt/vtgate/vindexes/lookup_hash_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ type vcursor struct {
2323
bq *querytypes.BoundQuery
2424
}
2525

26-
func (vc *vcursor) Execute(query string, bindvars map[string]interface{}) (*sqltypes.Result, error) {
26+
func (vc *vcursor) Execute(query string, bindvars map[string]interface{}, isDML bool) (*sqltypes.Result, error) {
2727
vc.bq = &querytypes.BoundQuery{
2828
Sql: query,
2929
BindVariables: bindvars,

go/vt/vtgate/vindexes/lookup_internal.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func (lkp *lookup) MapUniqueLookup(vcursor VCursor, ids []interface{}) ([][]byte
3737
for _, id := range ids {
3838
result, err := vcursor.Execute(lkp.sel, map[string]interface{}{
3939
lkp.From: id,
40-
})
40+
}, false /* isDML */)
4141
if err != nil {
4242
return nil, fmt.Errorf("lookup.Map: %v", err)
4343
}
@@ -67,7 +67,7 @@ func (lkp *lookup) MapNonUniqueLookup(vcursor VCursor, ids []interface{}) ([][][
6767
for _, id := range ids {
6868
result, err := vcursor.Execute(lkp.sel, map[string]interface{}{
6969
lkp.From: id,
70-
})
70+
}, false /* isDML */)
7171
if err != nil {
7272
return nil, fmt.Errorf("lookup.Map: %v", err)
7373
}
@@ -124,7 +124,7 @@ func (lkp *lookup) Verify(vcursor VCursor, ids []interface{}, ksids [][]byte) (b
124124
bindVars[toStr] = val[rowNum]
125125
}
126126
lkp.ver = fmt.Sprintf("select %s from %s where %s", lkp.From, lkp.Table, strings.Trim(colBuff.String(), "or")+")")
127-
result, err := vcursor.Execute(lkp.ver, bindVars)
127+
result, err := vcursor.Execute(lkp.ver, bindVars, false /* isDML */)
128128
if err != nil {
129129
return false, fmt.Errorf("lookup.Verify: %v", err)
130130
}
@@ -168,7 +168,7 @@ func (lkp *lookup) Create(vcursor VCursor, ids []interface{}, ksids [][]byte) er
168168
bindVars[toStr] = val[rowNum]
169169
}
170170
lkp.ins = strings.Trim(insBuffer.String(), ",")
171-
if _, err := vcursor.Execute(lkp.ins, bindVars); err != nil {
171+
if _, err := vcursor.Execute(lkp.ins, bindVars, true /* isDML */); err != nil {
172172
return fmt.Errorf("lookup.Create: %v", err)
173173
}
174174
return nil
@@ -191,7 +191,7 @@ func (lkp *lookup) Delete(vcursor VCursor, ids []interface{}, ksid []byte) error
191191
}
192192
for _, id := range ids {
193193
bindvars[lkp.From] = id
194-
if _, err := vcursor.Execute(lkp.del, bindvars); err != nil {
194+
if _, err := vcursor.Execute(lkp.del, bindvars, true /* isDML */); err != nil {
195195
return fmt.Errorf("lookup.Delete: %v", err)
196196
}
197197
}

go/vt/vtgate/vindexes/vindex.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
// in the current context and session of a VTGate request. Vindexes
1717
// can use this interface to execute lookup queries.
1818
type VCursor interface {
19-
Execute(query string, bindvars map[string]interface{}) (*sqltypes.Result, error)
19+
Execute(query string, bindvars map[string]interface{}, isDML bool) (*sqltypes.Result, error)
2020
}
2121

2222
// Vindex defines the interface required to register a vindex.

0 commit comments

Comments
 (0)