@@ -133,14 +133,16 @@ type T struct {
133133 succeeded []* event.CommandSucceededEvent
134134 failed []* event.CommandFailedEvent
135135
136- Client * mongo.Client
137- DB * mongo.Database
138- Coll * mongo.Collection
136+ Client * mongo.Client
137+ fpClients map [* mongo.Client ]bool
138+ DB * mongo.Database
139+ Coll * mongo.Collection
139140}
140141
141142func newT (wrapped * testing.T , opts ... * Options ) * T {
142143 t := & T {
143- T : wrapped ,
144+ T : wrapped ,
145+ fpClients : make (map [* mongo.Client ]bool ),
144146 }
145147 for _ , opt := range opts {
146148 for _ , optFn := range opt .optFuncs {
@@ -207,6 +209,12 @@ func (t *T) cleanup() {
207209 // always disconnect the client regardless of clientType because Client.Disconnect will work against
208210 // all deployments
209211 _ = t .Client .Disconnect (context .Background ())
212+ for client , v := range t .fpClients {
213+ if v {
214+ client .Disconnect (context .Background ())
215+ }
216+ }
217+ t .fpClients = make (map [* mongo.Client ]bool )
210218}
211219
212220// Run creates a new T instance for a sub-test and runs the given callback. It also creates a new collection using the
@@ -261,7 +269,9 @@ func (t *T) RunOpts(name string, opts *Options, callback func(mt *T)) {
261269 }
262270 // only disconnect client if it's not being shared
263271 if sub .shareClient == nil || ! * sub .shareClient {
264- _ = sub .Client .Disconnect (context .Background ())
272+ if v , ok := sub .fpClients [sub .Client ]; ! ok || ! v {
273+ _ = sub .Client .Disconnect (context .Background ())
274+ }
265275 }
266276 assert .Equal (sub , 0 , sessions , "%v sessions checked out" , sessions )
267277 assert .Equal (sub , 0 , conns , "%v connections checked out" , conns )
@@ -410,7 +420,9 @@ func (t *T) ResetClient(opts *options.ClientOptions) {
410420 t .clientOpts = opts
411421 }
412422
413- _ = t .Client .Disconnect (context .Background ())
423+ if v , ok := t .fpClients [t .Client ]; ! ok || ! v {
424+ _ = t .Client .Disconnect (context .Background ())
425+ }
414426 t .createTestClient ()
415427 t .DB = t .Client .Database (t .dbName )
416428 t .Coll = t .DB .Collection (t .collName , t .collOpts )
@@ -564,42 +576,31 @@ func (t *T) SetFailPoint(fp FailPoint) {
564576 }
565577 }
566578
567- client , err := mongo .NewClient (t .clientOpts )
568- if err != nil {
569- t .Fatalf ("error creating client: %v" , err )
570- }
571- if err = client .Connect (context .Background ()); err != nil {
572- t .Fatalf ("error connecting client: %v" , err )
573- }
574- if err = SetFailPoint (fp , client ); err != nil {
579+ if err := SetFailPoint (fp , t .Client ); err != nil {
575580 t .Fatal (err )
576581 }
577- t .failPoints = append (t .failPoints , failPoint {fp .ConfigureFailPoint , client })
582+ t .fpClients [t .Client ] = true
583+ t .failPoints = append (t .failPoints , failPoint {fp .ConfigureFailPoint , t .Client })
578584}
579585
580586// SetFailPointFromDocument sets the fail point represented by the given document for the client associated with T. This
581587// method assumes that the given document is in the form {configureFailPoint: <failPointName>, ...}. Commands to create
582588// the failpoint will appear in command monitoring channels. The fail point will be automatically disabled after this
583589// test has run.
584590func (t * T ) SetFailPointFromDocument (fp bson.Raw ) {
585- client , err := mongo .NewClient (t .clientOpts )
586- if err != nil {
587- t .Fatalf ("error creating client: %v" , err )
588- }
589- if err = client .Connect (context .Background ()); err != nil {
590- t .Fatalf ("error connecting client: %v" , err )
591- }
592- if err = SetRawFailPoint (fp , client ); err != nil {
591+ if err := SetRawFailPoint (fp , t .Client ); err != nil {
593592 t .Fatal (err )
594593 }
595594
595+ t .fpClients [t .Client ] = true
596596 name := fp .Index (0 ).Value ().StringValue ()
597- t .failPoints = append (t .failPoints , failPoint {name , client })
597+ t .failPoints = append (t .failPoints , failPoint {name , t . Client })
598598}
599599
600600// TrackFailPoint adds the given fail point to the list of fail points to be disabled when the current test finishes.
601601// This function does not create a fail point on the server.
602602func (t * T ) TrackFailPoint (fpName string , client * mongo.Client ) {
603+ t .fpClients [client ] = true
603604 t .failPoints = append (t .failPoints , failPoint {fpName , client })
604605}
605606
@@ -614,7 +615,10 @@ func (t *T) ClearFailPoints() {
614615 if err != nil {
615616 t .Fatalf ("error clearing fail point %s: %v" , fp .name , err )
616617 }
617- _ = fp .client .Disconnect (context .Background ())
618+ if fp .client != t .Client {
619+ _ = fp .client .Disconnect (context .Background ())
620+ t .fpClients [fp .client ] = false
621+ }
618622 }
619623 t .failPoints = t .failPoints [:0 ]
620624}
0 commit comments