Skip to content

Commit 820366e

Browse files
authored
Generate new ObjectID only when required (#1479)
1 parent 4dbe540 commit 820366e

File tree

5 files changed

+40
-4
lines changed

5 files changed

+40
-4
lines changed

bson/primitive/objectid_test.go

+7
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ func BenchmarkObjectIDFromHex(b *testing.B) {
3737
}
3838
}
3939

40+
func BenchmarkNewObjectIDFromTimestamp(b *testing.B) {
41+
for i := 0; i < b.N; i++ {
42+
timestamp := time.Now().Add(time.Duration(i) * time.Millisecond)
43+
_ = NewObjectIDFromTimestamp(timestamp)
44+
}
45+
}
46+
4047
func TestFromHex_RoundTrip(t *testing.T) {
4148
before := NewObjectID()
4249
after, err := ObjectIDFromHex(before.Hex())

mongo/bulk_write.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera
171171
if err != nil {
172172
return operation.InsertResult{}, err
173173
}
174-
doc, _, err = ensureID(doc, primitive.NewObjectID(), bw.collection.bsonOpts, bw.collection.registry)
174+
doc, _, err = ensureID(doc, primitive.NilObjectID, bw.collection.bsonOpts, bw.collection.registry)
175175
if err != nil {
176176
return operation.InsertResult{}, err
177177
}

mongo/collection.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{},
256256
if err != nil {
257257
return nil, err
258258
}
259-
bsoncoreDoc, id, err := ensureID(bsoncoreDoc, primitive.NewObjectID(), coll.bsonOpts, coll.registry)
259+
bsoncoreDoc, id, err := ensureID(bsoncoreDoc, primitive.NilObjectID, coll.bsonOpts, coll.registry)
260260
if err != nil {
261261
return nil, err
262262
}

mongo/mongo.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,11 @@ func marshal(
177177
}
178178

179179
// ensureID inserts the given ObjectID as an element named "_id" at the
180-
// beginning of the given BSON document if there is not an "_id" already. If
181-
// there is already an element named "_id", the document is not modified. It
180+
// beginning of the given BSON document if there is not an "_id" already.
181+
// If the given ObjectID is primitive.NilObjectID, a new object ID will be
182+
// generated with time.Now().
183+
//
184+
// If there is already an element named "_id", the document is not modified. It
182185
// returns the resulting document and the decoded Go value of the "_id" element.
183186
func ensureID(
184187
doc bsoncore.Document,
@@ -219,6 +222,9 @@ func ensureID(
219222
const extraSpace = 17
220223
doc = make(bsoncore.Document, 0, len(olddoc)+extraSpace)
221224
_, doc = bsoncore.ReserveLength(doc)
225+
if oid.IsZero() {
226+
oid = primitive.NewObjectID()
227+
}
222228
doc = bsoncore.AppendObjectIDElement(doc, "_id", oid)
223229

224230
// Remove and re-write the BSON document length header.

mongo/mongo_test.go

+23
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,29 @@ func TestEnsureID(t *testing.T) {
134134
}
135135
}
136136

137+
func TestEnsureID_NilObjectID(t *testing.T) {
138+
t.Parallel()
139+
140+
doc := bsoncore.NewDocumentBuilder().
141+
AppendString("foo", "bar").
142+
Build()
143+
144+
got, gotIDI, err := ensureID(doc, primitive.NilObjectID, nil, nil)
145+
assert.NoError(t, err)
146+
147+
gotID, ok := gotIDI.(primitive.ObjectID)
148+
149+
assert.True(t, ok)
150+
assert.NotEqual(t, primitive.NilObjectID, gotID)
151+
152+
want := bsoncore.NewDocumentBuilder().
153+
AppendObjectID("_id", gotID).
154+
AppendString("foo", "bar").
155+
Build()
156+
157+
assert.Equal(t, want, got)
158+
}
159+
137160
func TestMarshalAggregatePipeline(t *testing.T) {
138161
// []byte of [{{"$limit", 12345}}]
139162
index, arr := bsoncore.AppendArrayStart(nil)

0 commit comments

Comments
 (0)