@@ -335,24 +335,27 @@ func (c *CSAPI) MustSync(t *testing.T, syncReq SyncReq) (gjson.Result, string) {
335
335
// check functions return no error. Returns the final/latest since token.
336
336
//
337
337
// Initial /sync example: (no since token)
338
- // bob.InviteRoom(t, roomID, alice.UserID)
339
- // alice.JoinRoom(t, roomID, nil)
340
- // alice.MustSyncUntil(t, client.SyncReq{}, client.SyncJoinedTo(alice.UserID, roomID))
338
+ //
339
+ // bob.InviteRoom(t, roomID, alice.UserID)
340
+ // alice.JoinRoom(t, roomID, nil)
341
+ // alice.MustSyncUntil(t, client.SyncReq{}, client.SyncJoinedTo(alice.UserID, roomID))
341
342
//
342
343
// Incremental /sync example: (test controls since token)
343
- // since := alice.MustSyncUntil(t, client.SyncReq{TimeoutMillis: "0"}) // get a since token
344
- // bob.InviteRoom(t, roomID, alice.UserID)
345
- // since = alice.MustSyncUntil(t, client.SyncReq{Since: since}, client.SyncInvitedTo(alice.UserID, roomID))
346
- // alice.JoinRoom(t, roomID, nil)
347
- // alice.MustSyncUntil(t, client.SyncReq{Since: since}, client.SyncJoinedTo(alice.UserID, roomID))
344
+ //
345
+ // since := alice.MustSyncUntil(t, client.SyncReq{TimeoutMillis: "0"}) // get a since token
346
+ // bob.InviteRoom(t, roomID, alice.UserID)
347
+ // since = alice.MustSyncUntil(t, client.SyncReq{Since: since}, client.SyncInvitedTo(alice.UserID, roomID))
348
+ // alice.JoinRoom(t, roomID, nil)
349
+ // alice.MustSyncUntil(t, client.SyncReq{Since: since}, client.SyncJoinedTo(alice.UserID, roomID))
348
350
//
349
351
// Checking multiple parts of /sync:
350
- // alice.MustSyncUntil(
351
- // t, client.SyncReq{},
352
- // client.SyncJoinedTo(alice.UserID, roomID),
353
- // client.SyncJoinedTo(alice.UserID, roomID2),
354
- // client.SyncJoinedTo(alice.UserID, roomID3),
355
- // )
352
+ //
353
+ // alice.MustSyncUntil(
354
+ // t, client.SyncReq{},
355
+ // client.SyncJoinedTo(alice.UserID, roomID),
356
+ // client.SyncJoinedTo(alice.UserID, roomID2),
357
+ // client.SyncJoinedTo(alice.UserID, roomID3),
358
+ // )
356
359
//
357
360
// Check functions are unordered and independent. Once a check function returns true it is removed
358
361
// from the list of checks and won't be called again.
@@ -438,7 +441,81 @@ func (c *CSAPI) LoginUser(t *testing.T, localpart, password string) (userID, acc
438
441
return userID , accessToken , deviceID
439
442
}
440
443
441
- //RegisterUser will register the user with given parameters and
444
+ // LoginUserWithDeviceID will log in to a homeserver on an existing device
445
+ func (c * CSAPI ) LoginUserWithDeviceID (t * testing.T , localpart , password , deviceID string ) (userID , accessToken string ) {
446
+ t .Helper ()
447
+ reqBody := map [string ]interface {}{
448
+ "identifier" : map [string ]interface {}{
449
+ "type" : "m.id.user" ,
450
+ "user" : localpart ,
451
+ },
452
+ "device_id" : deviceID ,
453
+ "password" : password ,
454
+ "type" : "m.login.password" ,
455
+ }
456
+ res := c .MustDoFunc (t , "POST" , []string {"_matrix" , "client" , "v3" , "login" }, WithJSONBody (t , reqBody ))
457
+
458
+ body , err := ioutil .ReadAll (res .Body )
459
+ if err != nil {
460
+ t .Fatalf ("unable to read response body: %v" , err )
461
+ }
462
+
463
+ userID = gjson .GetBytes (body , "user_id" ).Str
464
+ accessToken = gjson .GetBytes (body , "access_token" ).Str
465
+ if gjson .GetBytes (body , "device_id" ).Str != deviceID {
466
+ t .Fatalf ("device_id returned by login does not match the one requested" )
467
+ }
468
+ return userID , accessToken
469
+ }
470
+
471
+ // LoginUserWithRefreshToken will log in to a homeserver, with refresh token enabled,
472
+ // and create a new device on an existing user.
473
+ func (c * CSAPI ) LoginUserWithRefreshToken (t * testing.T , localpart , password string ) (userID , accessToken , refreshToken , deviceID string , expiresInMs int64 ) {
474
+ t .Helper ()
475
+ reqBody := map [string ]interface {}{
476
+ "identifier" : map [string ]interface {}{
477
+ "type" : "m.id.user" ,
478
+ "user" : localpart ,
479
+ },
480
+ "password" : password ,
481
+ "type" : "m.login.password" ,
482
+ "refresh_token" : true ,
483
+ }
484
+ res := c .MustDoFunc (t , "POST" , []string {"_matrix" , "client" , "v3" , "login" }, WithJSONBody (t , reqBody ))
485
+
486
+ body , err := ioutil .ReadAll (res .Body )
487
+ if err != nil {
488
+ t .Fatalf ("unable to read response body: %v" , err )
489
+ }
490
+
491
+ userID = gjson .GetBytes (body , "user_id" ).Str
492
+ accessToken = gjson .GetBytes (body , "access_token" ).Str
493
+ deviceID = gjson .GetBytes (body , "device_id" ).Str
494
+ refreshToken = gjson .GetBytes (body , "refresh_token" ).Str
495
+ expiresInMs = gjson .GetBytes (body , "expires_in_ms" ).Int ()
496
+ return userID , accessToken , refreshToken , deviceID , expiresInMs
497
+ }
498
+
499
+ // RefreshToken will consume a refresh token and return a new access token and refresh token.
500
+ func (c * CSAPI ) ConsumeRefreshToken (t * testing.T , refreshToken string ) (newAccessToken , newRefreshToken string , expiresInMs int64 ) {
501
+ t .Helper ()
502
+ reqBody := map [string ]interface {}{
503
+ "refresh_token" : refreshToken ,
504
+ }
505
+ res := c .MustDoFunc (t , "POST" , []string {"_matrix" , "client" , "v3" , "refresh" }, WithJSONBody (t , reqBody ))
506
+
507
+ body , err := ioutil .ReadAll (res .Body )
508
+ if err != nil {
509
+ t .Fatalf ("unable to read response body: %v" , err )
510
+ }
511
+
512
+ newAccessToken = gjson .GetBytes (body , "access_token" ).Str
513
+ newRefreshToken = gjson .GetBytes (body , "refresh_token" ).Str
514
+ expiresInMs = gjson .GetBytes (body , "expires_in_ms" ).Int ()
515
+ return newAccessToken , newRefreshToken , expiresInMs
516
+ }
517
+
518
+ // RegisterUser will register the user with given parameters and
442
519
// return user ID & access token, and fail the test on network error
443
520
func (c * CSAPI ) RegisterUser (t * testing.T , localpart , password string ) (userID , accessToken , deviceID string ) {
444
521
t .Helper ()
@@ -598,12 +675,13 @@ func (c *CSAPI) MustDoFunc(t *testing.T, method string, paths []string, opts ...
598
675
//
599
676
// Fails the test if an HTTP request could not be made or if there was a network error talking to the
600
677
// server. To do assertions on the HTTP response, see the `must` package. For example:
601
- // must.MatchResponse(t, res, match.HTTPResponse{
602
- // StatusCode: 400,
603
- // JSON: []match.JSON{
604
- // match.JSONKeyEqual("errcode", "M_INVALID_USERNAME"),
605
- // },
606
- // })
678
+ //
679
+ // must.MatchResponse(t, res, match.HTTPResponse{
680
+ // StatusCode: 400,
681
+ // JSON: []match.JSON{
682
+ // match.JSONKeyEqual("errcode", "M_INVALID_USERNAME"),
683
+ // },
684
+ // })
607
685
func (c * CSAPI ) DoFunc (t * testing.T , method string , paths []string , opts ... RequestOpt ) * http.Response {
608
686
t .Helper ()
609
687
for i := range paths {
0 commit comments