Skip to content

Commit dd1740c

Browse files
Refactor the websocket client and add fixes
The websocket client and hub interaction has been simplified a bit. The hub now acts only as a tee writer to the various clients that register. Clients must register and unregister explicitly. The hub is no longer passed in to the client. Websocket clients now watch for password changes or jwt token expiration times. Clients are disconnected if auth token expires or if the password is changed. Various aditional safety checks have been added. Signed-off-by: Gabriel Adrian Samfira <[email protected]>
1 parent ca7f20b commit dd1740c

File tree

17 files changed

+423
-140
lines changed

17 files changed

+423
-140
lines changed

apiserver/controllers/controllers.go

+10-8
Original file line numberDiff line numberDiff line change
@@ -183,14 +183,9 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request)
183183
slog.With(slog.Any("error", err)).ErrorContext(ctx, "error upgrading to websockets")
184184
return
185185
}
186+
defer conn.Close()
186187

187-
// nolint:golangci-lint,godox
188-
// TODO (gsamfira): Handle ExpiresAt. Right now, if a client uses
189-
// a valid token to authenticate, and keeps the websocket connection
190-
// open, it will allow that client to stream logs via websockets
191-
// until the connection is broken. We need to forcefully disconnect
192-
// the client once the token expires.
193-
client, err := wsWriter.NewClient(conn, a.hub)
188+
client, err := wsWriter.NewClient(ctx, conn)
194189
if err != nil {
195190
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to create new client")
196191
return
@@ -199,7 +194,14 @@ func (a *APIController) WSHandler(writer http.ResponseWriter, req *http.Request)
199194
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to register new client")
200195
return
201196
}
202-
client.Go()
197+
defer a.hub.Unregister(client)
198+
199+
if err := client.Start(); err != nil {
200+
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to start client")
201+
return
202+
}
203+
<-client.Done()
204+
slog.Info("client disconnected", "client_id", client.ID())
203205
}
204206

205207
// NotFoundHandler is returned when an invalid URL is acccessed

auth/auth.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -55,17 +55,19 @@ func (a *Authenticator) GetJWTToken(ctx context.Context) (string, error) {
5555
expires := &jwt.NumericDate{
5656
Time: expireToken,
5757
}
58+
generation := PasswordGeneration(ctx)
5859
claims := JWTClaims{
5960
RegisteredClaims: jwt.RegisteredClaims{
6061
ExpiresAt: expires,
6162
// nolint:golangci-lint,godox
6263
// TODO: make this configurable
6364
Issuer: "garm",
6465
},
65-
UserID: UserID(ctx),
66-
TokenID: tokenID,
67-
IsAdmin: IsAdmin(ctx),
68-
FullName: FullName(ctx),
66+
UserID: UserID(ctx),
67+
TokenID: tokenID,
68+
IsAdmin: IsAdmin(ctx),
69+
FullName: FullName(ctx),
70+
Generation: generation,
6971
}
7072
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
7173
tokenString, err := token.SignedString([]byte(a.cfg.Secret))
@@ -182,5 +184,5 @@ func (a *Authenticator) AuthenticateUser(ctx context.Context, info params.Passwo
182184
return ctx, runnerErrors.ErrUnauthorized
183185
}
184186

185-
return PopulateContext(ctx, user), nil
187+
return PopulateContext(ctx, user, nil), nil
186188
}

auth/context.go

+36-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package auth
1616

1717
import (
1818
"context"
19+
"time"
1920

2021
runnerErrors "github.com/cloudbase/garm-provider-common/errors"
2122
"github.com/cloudbase/garm/params"
@@ -28,9 +29,11 @@ const (
2829
fullNameKey contextFlags = "full_name"
2930
readMetricsKey contextFlags = "read_metrics"
3031
// UserIDFlag is the User ID flag we set in the context
31-
UserIDFlag contextFlags = "user_id"
32-
isEnabledFlag contextFlags = "is_enabled"
33-
jwtTokenFlag contextFlags = "jwt_token"
32+
UserIDFlag contextFlags = "user_id"
33+
isEnabledFlag contextFlags = "is_enabled"
34+
jwtTokenFlag contextFlags = "jwt_token"
35+
authExpiresFlag contextFlags = "auth_expires"
36+
passwordGenerationFlag contextFlags = "password_generation"
3437

3538
instanceIDKey contextFlags = "id"
3639
instanceNameKey contextFlags = "name"
@@ -169,14 +172,43 @@ func PopulateInstanceContext(ctx context.Context, instance params.Instance) cont
169172

170173
// PopulateContext sets the appropriate fields in the context, based on
171174
// the user object
172-
func PopulateContext(ctx context.Context, user params.User) context.Context {
175+
func PopulateContext(ctx context.Context, user params.User, authExpires *time.Time) context.Context {
173176
ctx = SetUserID(ctx, user.ID)
174177
ctx = SetAdmin(ctx, user.IsAdmin)
175178
ctx = SetIsEnabled(ctx, user.Enabled)
176179
ctx = SetFullName(ctx, user.FullName)
180+
ctx = SetExpires(ctx, authExpires)
181+
ctx = SetPasswordGeneration(ctx, user.Generation)
177182
return ctx
178183
}
179184

185+
func SetExpires(ctx context.Context, expires *time.Time) context.Context {
186+
if expires == nil {
187+
return ctx
188+
}
189+
return context.WithValue(ctx, authExpiresFlag, expires)
190+
}
191+
192+
func Expires(ctx context.Context) *time.Time {
193+
elem := ctx.Value(authExpiresFlag)
194+
if elem == nil {
195+
return nil
196+
}
197+
return elem.(*time.Time)
198+
}
199+
200+
func SetPasswordGeneration(ctx context.Context, val uint) context.Context {
201+
return context.WithValue(ctx, passwordGenerationFlag, val)
202+
}
203+
204+
func PasswordGeneration(ctx context.Context) uint {
205+
elem := ctx.Value(passwordGenerationFlag)
206+
if elem == nil {
207+
return 0
208+
}
209+
return elem.(uint)
210+
}
211+
180212
// SetFullName sets the user full name in the context
181213
func SetFullName(ctx context.Context, fullName string) context.Context {
182214
return context.WithValue(ctx, fullNameKey, fullName)

auth/jwt.go

+14-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"log/slog"
2222
"net/http"
2323
"strings"
24+
"time"
2425

2526
jwt "github.com/golang-jwt/jwt/v5"
2627

@@ -37,6 +38,7 @@ type JWTClaims struct {
3738
FullName string `json:"full_name"`
3839
IsAdmin bool `json:"is_admin"`
3940
ReadMetrics bool `json:"read_metrics"`
41+
Generation uint `json:"generation"`
4042
jwt.RegisteredClaims
4143
}
4244

@@ -69,7 +71,18 @@ func (amw *jwtMiddleware) claimsToContext(ctx context.Context, claims *JWTClaims
6971
return ctx, runnerErrors.ErrUnauthorized
7072
}
7173

72-
ctx = PopulateContext(ctx, userInfo)
74+
var expiresAt *time.Time
75+
if claims.ExpiresAt != nil {
76+
expires := claims.ExpiresAt.Time.UTC()
77+
expiresAt = &expires
78+
}
79+
80+
if userInfo.Generation != claims.Generation {
81+
// Password was reset since token was issued. Invalidate.
82+
return ctx, runnerErrors.ErrUnauthorized
83+
}
84+
85+
ctx = PopulateContext(ctx, userInfo, expiresAt)
7386
return ctx, nil
7487
}
7588

cmd/garm-cli/cmd/log.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616

1717
"github.com/cloudbase/garm-provider-common/util"
1818
apiParams "github.com/cloudbase/garm/apiserver/params"
19+
garmWs "github.com/cloudbase/garm/websocket"
1920
)
2021

2122
var logCmd = &cobra.Command{
@@ -66,7 +67,9 @@ var logCmd = &cobra.Command{
6667
for {
6768
_, message, err := c.ReadMessage()
6869
if err != nil {
69-
slog.With(slog.Any("error", err)).Error("reading log message")
70+
if garmWs.IsErrorOfInterest(err) {
71+
slog.With(slog.Any("error", err)).Error("reading log message")
72+
}
7073
return
7174
}
7275
fmt.Println(util.SanitizeLogEntry(string(message)))

cmd/garm/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ func main() {
320320
slog.With(slog.Any("error", err)).ErrorContext(ctx, "graceful api server shutdown failed")
321321
}
322322

323-
slog.With(slog.Any("error", err)).ErrorContext(ctx, "waiting for runner to stop")
323+
slog.With(slog.Any("error", err)).InfoContext(ctx, "waiting for runner to stop")
324324
if err := runner.Wait(); err != nil {
325325
slog.With(slog.Any("error", err)).ErrorContext(ctx, "failed to shutdown workers")
326326
os.Exit(1)

database/sql/github_test.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ func (s *GithubTestSuite) TestCreateCredentials() {
284284
func (s *GithubTestSuite) TestCreateCredentialsFailsOnDuplicateCredentials() {
285285
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
286286
testUser := garmTesting.CreateGARMTestUser(ctx, "testuser", s.db, s.T())
287-
testUserCtx := auth.PopulateContext(context.Background(), testUser)
287+
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
288288

289289
credParams := params.CreateGithubCredentialsParams{
290290
Name: testCredsName,
@@ -313,8 +313,8 @@ func (s *GithubTestSuite) TestNormalUsersCanOnlySeeTheirOwnCredentialsAdminCanSe
313313
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
314314
testUser := garmTesting.CreateGARMTestUser(ctx, "testuser1", s.db, s.T())
315315
testUser2 := garmTesting.CreateGARMTestUser(ctx, "testuser2", s.db, s.T())
316-
testUserCtx := auth.PopulateContext(context.Background(), testUser)
317-
testUser2Ctx := auth.PopulateContext(context.Background(), testUser2)
316+
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
317+
testUser2Ctx := auth.PopulateContext(context.Background(), testUser2, nil)
318318

319319
credParams := params.CreateGithubCredentialsParams{
320320
Name: testCredsName,
@@ -370,7 +370,7 @@ func (s *GithubTestSuite) TestGetGithubCredentialsFailsWhenCredentialsDontExist(
370370
func (s *GithubTestSuite) TestGetGithubCredentialsByNameReturnsOnlyCurrentUserCredentials() {
371371
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
372372
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user1", s.db, s.T())
373-
testUserCtx := auth.PopulateContext(context.Background(), testUser)
373+
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
374374

375375
credParams := params.CreateGithubCredentialsParams{
376376
Name: testCredsName,
@@ -472,7 +472,7 @@ func (s *GithubTestSuite) TestDeleteGithubCredentials() {
472472
func (s *GithubTestSuite) TestDeleteGithubCredentialsByNonAdminUser() {
473473
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
474474
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user4", s.db, s.T())
475-
testUserCtx := auth.PopulateContext(context.Background(), testUser)
475+
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
476476

477477
credParams := params.CreateGithubCredentialsParams{
478478
Name: testCredsName,
@@ -682,7 +682,7 @@ func (s *GithubTestSuite) TestUpdateCredentialsFailsForNonExistingCredentials()
682682
func (s *GithubTestSuite) TestUpdateCredentialsFailsIfCredentialsAreOwnedByNonAdminUser() {
683683
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
684684
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user5", s.db, s.T())
685-
testUserCtx := auth.PopulateContext(context.Background(), testUser)
685+
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
686686

687687
credParams := params.CreateGithubCredentialsParams{
688688
Name: testCredsName,
@@ -711,7 +711,7 @@ func (s *GithubTestSuite) TestUpdateCredentialsFailsIfCredentialsAreOwnedByNonAd
711711
func (s *GithubTestSuite) TestAdminUserCanUpdateAnyGithubCredentials() {
712712
ctx := garmTesting.ImpersonateAdminContext(context.Background(), s.db, s.T())
713713
testUser := garmTesting.CreateGARMTestUser(ctx, "test-user5", s.db, s.T())
714-
testUserCtx := auth.PopulateContext(context.Background(), testUser)
714+
testUserCtx := auth.PopulateContext(context.Background(), testUser, nil)
715715

716716
credParams := params.CreateGithubCredentialsParams{
717717
Name: testCredsName,

database/sql/models.go

+7-6
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,13 @@ type Instance struct {
195195
type User struct {
196196
Base
197197

198-
Username string `gorm:"uniqueIndex;varchar(64)"`
199-
FullName string `gorm:"type:varchar(254)"`
200-
Email string `gorm:"type:varchar(254);unique;index:idx_email"`
201-
Password string `gorm:"type:varchar(60)"`
202-
IsAdmin bool
203-
Enabled bool
198+
Username string `gorm:"uniqueIndex;varchar(64)"`
199+
FullName string `gorm:"type:varchar(254)"`
200+
Email string `gorm:"type:varchar(254);unique;index:idx_email"`
201+
Password string `gorm:"type:varchar(60)"`
202+
Generation uint
203+
IsAdmin bool
204+
Enabled bool
204205
}
205206

206207
type ControllerInfo struct {

database/sql/sql.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ func (s *sqlDatabase) migrateCredentialsToDB() (err error) {
239239
// user. GARM is not yet multi-user, so it's safe to assume we only have this
240240
// one user.
241241
adminCtx := context.Background()
242-
adminCtx = auth.PopulateContext(adminCtx, adminUser)
242+
adminCtx = auth.PopulateContext(adminCtx, adminUser, nil)
243243

244244
slog.Info("migrating credentials to DB")
245245
slog.Info("creating github endpoints table")

0 commit comments

Comments
 (0)