Skip to content

Commit 794a322

Browse files
authored
More fixes to support users with hosts in same team and hosts in different teams (#17789)
#17441
1 parent 5219427 commit 794a322

File tree

6 files changed

+306
-26
lines changed

6 files changed

+306
-26
lines changed

cmd/fleet/calendar_cron.go

+50-10
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func cronCalendarEventsForTeam(
125125
for _, policy := range policies {
126126
policyIDs = append(policyIDs, policy.ID)
127127
}
128-
hosts, err := ds.GetHostsPolicyMemberships(ctx, domain, policyIDs)
128+
hosts, err := ds.GetTeamHostsPolicyMemberships(ctx, domain, team.ID, policyIDs)
129129
if err != nil {
130130
return fmt.Errorf("get team hosts failing policies: %w", err)
131131
}
@@ -150,22 +150,28 @@ func cronCalendarEventsForTeam(
150150
}
151151
level.Debug(logger).Log(
152152
"msg", "summary",
153+
"team_id", team.ID,
153154
"passing_hosts", len(passingHosts),
154155
"failing_hosts", len(failingHosts),
155156
"failing_hosts_without_associated_email", len(failingHostsWithoutAssociatedEmail),
156157
)
157158

159+
// Remove calendar events from hosts that are passing the calendar policies.
160+
//
161+
// We execute this first to remove any calendar events for a user that is now passing
162+
// policies on one of its hosts, and possibly create a new calendar event if they have
163+
// another failing host on the same team.
164+
if err := removeCalendarEventsFromPassingHosts(ctx, ds, calendar, passingHosts); err != nil {
165+
level.Info(logger).Log("msg", "removing calendar events from passing hosts", "err", err)
166+
}
167+
168+
// Process hosts that are failing calendar policies.
158169
if err := processCalendarFailingHosts(
159170
ctx, ds, calendar, orgName, failingHosts, logger,
160171
); err != nil {
161172
level.Info(logger).Log("msg", "processing failing hosts", "err", err)
162173
}
163174

164-
// Remove calendar events from hosts that are passing the policies.
165-
if err := removeCalendarEventsFromPassingHosts(ctx, ds, calendar, passingHosts); err != nil {
166-
level.Info(logger).Log("msg", "removing calendar events from passing hosts", "err", err)
167-
}
168-
169175
// At last we want to log the hosts that are failing and don't have an associated email.
170176
logHostsWithoutAssociatedEmail(
171177
domain,
@@ -184,14 +190,26 @@ func processCalendarFailingHosts(
184190
hosts []fleet.HostPolicyMembershipData,
185191
logger kitlog.Logger,
186192
) error {
193+
hosts = filterHostsWithSameEmail(hosts)
194+
187195
for _, host := range hosts {
188196
logger := log.With(logger, "host_id", host.HostID)
189197

190-
hostCalendarEvent, calendarEvent, err := ds.GetHostCalendarEvent(ctx, host.HostID)
198+
hostCalendarEvent, calendarEvent, err := ds.GetHostCalendarEventByEmail(ctx, host.Email)
191199

192200
expiredEvent := false
193201
webhookAlreadyFiredThisMonth := false
194202
if err == nil {
203+
if hostCalendarEvent.HostID != host.HostID {
204+
// This calendar event belongs to another host with this associated email,
205+
// thus we skip this entry.
206+
continue // continue with next host
207+
}
208+
if hostCalendarEvent.WebhookStatus == fleet.CalendarWebhookStatusPending {
209+
// This can happen if the host went offline (and never returned results)
210+
// after setting the webhook as pending.
211+
continue // continue with next host
212+
}
195213
now := time.Now()
196214
webhookAlreadyFired := hostCalendarEvent.WebhookStatus == fleet.CalendarWebhookStatusSent
197215
if webhookAlreadyFired && sameDate(now, calendarEvent.StartTime) {
@@ -200,7 +218,7 @@ func processCalendarFailingHosts(
200218
continue // continue with next host
201219
}
202220
webhookAlreadyFiredThisMonth = webhookAlreadyFired && sameMonth(now, calendarEvent.StartTime)
203-
if calendarEvent.EndTime.Before(time.Now()) {
221+
if calendarEvent.EndTime.Before(now) {
204222
expiredEvent = true
205223
}
206224
}
@@ -232,6 +250,25 @@ func processCalendarFailingHosts(
232250
return nil
233251
}
234252

253+
func filterHostsWithSameEmail(hosts []fleet.HostPolicyMembershipData) []fleet.HostPolicyMembershipData {
254+
minHostPerEmail := make(map[string]fleet.HostPolicyMembershipData)
255+
for _, host := range hosts {
256+
minHost, ok := minHostPerEmail[host.Email]
257+
if !ok {
258+
minHostPerEmail[host.Email] = host
259+
continue
260+
}
261+
if host.HostID < minHost.HostID {
262+
minHostPerEmail[host.Email] = host
263+
}
264+
}
265+
filtered := make([]fleet.HostPolicyMembershipData, 0, len(minHostPerEmail))
266+
for _, host := range minHostPerEmail {
267+
filtered = append(filtered, host)
268+
}
269+
return filtered
270+
}
271+
235272
func processFailingHostExistingCalendarEvent(
236273
ctx context.Context,
237274
ds fleet.Datastore,
@@ -416,10 +453,13 @@ func removeCalendarEventsFromPassingHosts(
416453
hosts []fleet.HostPolicyMembershipData,
417454
) error {
418455
for _, host := range hosts {
419-
calendarEvent, err := ds.GetCalendarEvent(ctx, host.Email)
456+
hostCalendarEvent, calendarEvent, err := ds.GetHostCalendarEventByEmail(ctx, host.Email)
420457
switch {
421458
case err == nil:
422-
// OK
459+
if hostCalendarEvent.HostID != host.HostID {
460+
// This calendar event belongs to another host, thus we skip this entry.
461+
continue
462+
}
423463
case fleet.IsNotFound(err):
424464
continue
425465
default:

server/datastore/mysql/calendar_events.go

+24
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,30 @@ func (ds *Datastore) GetHostCalendarEvent(ctx context.Context, hostID uint) (*fl
167167
return &hostCalendarEvent, &calendarEvent, nil
168168
}
169169

170+
func (ds *Datastore) GetHostCalendarEventByEmail(ctx context.Context, email string) (*fleet.HostCalendarEvent, *fleet.CalendarEvent, error) {
171+
const calendarEventsQuery = `
172+
SELECT * FROM calendar_events WHERE email = ?
173+
`
174+
var calendarEvent fleet.CalendarEvent
175+
if err := sqlx.GetContext(ctx, ds.reader(ctx), &calendarEvent, calendarEventsQuery, email); err != nil {
176+
if err == sql.ErrNoRows {
177+
return nil, nil, ctxerr.Wrap(ctx, notFound("CalendarEvent").WithMessage(fmt.Sprintf("email: %s", email)))
178+
}
179+
return nil, nil, ctxerr.Wrap(ctx, err, "get calendar event")
180+
}
181+
const hostCalendarEventsQuery = `
182+
SELECT * FROM host_calendar_events WHERE calendar_event_id = ?
183+
`
184+
var hostCalendarEvent fleet.HostCalendarEvent
185+
if err := sqlx.GetContext(ctx, ds.reader(ctx), &hostCalendarEvent, hostCalendarEventsQuery, calendarEvent.ID); err != nil {
186+
if err == sql.ErrNoRows {
187+
return nil, nil, ctxerr.Wrap(ctx, notFound("HostCalendarEvent").WithID(calendarEvent.ID))
188+
}
189+
return nil, nil, ctxerr.Wrap(ctx, err, "get host calendar event")
190+
}
191+
return &hostCalendarEvent, &calendarEvent, nil
192+
}
193+
170194
func (ds *Datastore) UpdateHostCalendarWebhookStatus(ctx context.Context, hostID uint, status fleet.CalendarWebhookStatus) error {
171195
const calendarEventsQuery = `
172196
UPDATE host_calendar_events SET

server/datastore/mysql/policies.go

+12-9
Original file line numberDiff line numberDiff line change
@@ -1172,8 +1172,12 @@ func (ds *Datastore) GetCalendarPolicies(ctx context.Context, teamID uint) ([]fl
11721172
}
11731173

11741174
// TODO(lucas): Must be tested at scale.
1175-
// TODO(lucas): Filter out hosts with team_id == NULL
1176-
func (ds *Datastore) GetHostsPolicyMemberships(ctx context.Context, domain string, policyIDs []uint) ([]fleet.HostPolicyMembershipData, error) {
1175+
func (ds *Datastore) GetTeamHostsPolicyMemberships(
1176+
ctx context.Context,
1177+
domain string,
1178+
teamID uint,
1179+
policyIDs []uint,
1180+
) ([]fleet.HostPolicyMembershipData, error) {
11771181
query := `
11781182
SELECT
11791183
COALESCE(sh.email, '') AS email,
@@ -1188,18 +1192,17 @@ func (ds *Datastore) GetHostsPolicyMemberships(ctx context.Context, domain strin
11881192
GROUP BY host_id
11891193
) pm
11901194
LEFT JOIN (
1191-
SELECT MIN(h.host_id) as host_id, h.email as email
1192-
FROM (
1193-
SELECT host_id, MIN(email) AS email
1194-
FROM host_emails WHERE email LIKE CONCAT('%@', ?)
1195-
GROUP BY host_id
1196-
) h GROUP BY h.email
1195+
SELECT host_id, MIN(email) AS email
1196+
FROM host_emails
1197+
JOIN hosts ON host_emails.host_id=hosts.id
1198+
WHERE email LIKE CONCAT('%@', ?) AND team_id = ?
1199+
GROUP BY host_id
11971200
) sh ON sh.host_id = pm.host_id
11981201
JOIN hosts h ON h.id = pm.host_id
11991202
LEFT JOIN host_display_names hdn ON hdn.host_id = pm.host_id;
12001203
`
12011204

1202-
query, args, err := sqlx.In(query, policyIDs, domain)
1205+
query, args, err := sqlx.In(query, policyIDs, domain, teamID)
12031206
if err != nil {
12041207
return nil, ctxerr.Wrapf(ctx, err, "build select get team hosts policy memberships query")
12051208
}

0 commit comments

Comments
 (0)