Skip to content

Commit 18181da

Browse files
lucasmrodgetvictor
authored andcommitted
Happy path implementation of the calendar cron job (#17713)
Happy path for #17441.
1 parent a38f414 commit 18181da

15 files changed

+1050
-3
lines changed

cmd/fleet/calendar_cron.go

+454
Large diffs are not rendered by default.

cmd/fleet/calendar_cron_test.go

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package main
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestGetPreferredCalendarEventDate(t *testing.T) {
11+
date := func(year int, month time.Month, day int) time.Time {
12+
return time.Date(year, month, day, 0, 0, 0, 0, time.UTC)
13+
}
14+
for _, tc := range []struct {
15+
name string
16+
year int
17+
month time.Month
18+
days int
19+
20+
expected time.Time
21+
}{
22+
{
23+
year: 2024,
24+
month: 3,
25+
days: 31,
26+
name: "March 2024",
27+
expected: date(2024, 3, 19),
28+
},
29+
{
30+
year: 2024,
31+
month: 4,
32+
days: 30,
33+
name: "April 2024",
34+
expected: date(2024, 4, 16),
35+
},
36+
} {
37+
t.Run(tc.name, func(t *testing.T) {
38+
for day := 1; day <= tc.days; day++ {
39+
actual := getPreferredCalendarEventDate(tc.year, tc.month, day)
40+
require.NotEqual(t, actual.Weekday(), time.Saturday)
41+
require.NotEqual(t, actual.Weekday(), time.Sunday)
42+
if day <= tc.expected.Day() {
43+
require.Equal(t, tc.expected, actual)
44+
} else {
45+
today := date(tc.year, tc.month, day)
46+
if weekday := today.Weekday(); weekday == time.Friday {
47+
require.Equal(t, today.AddDate(0, 0, +3), actual)
48+
} else if weekday == time.Saturday {
49+
require.Equal(t, today.AddDate(0, 0, +2), actual)
50+
} else {
51+
require.Equal(t, today.AddDate(0, 0, +1), actual)
52+
}
53+
}
54+
}
55+
})
56+
}
57+
}

cmd/fleet/serve.go

+12
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,18 @@ the way that the Fleet server works.
768768
}
769769
}
770770

771+
if license.IsPremium() {
772+
if err := cronSchedules.StartCronSchedule(
773+
func() (fleet.CronSchedule, error) {
774+
return newCalendarSchedule(
775+
ctx, instanceID, ds, logger,
776+
)
777+
},
778+
); err != nil {
779+
initFatal(err, "failed to register calendar schedule")
780+
}
781+
}
782+
771783
level.Info(logger).Log("msg", fmt.Sprintf("started cron schedules: %s", strings.Join(cronSchedules.ScheduleNames(), ", ")))
772784

773785
// StartCollectors starts a goroutine per collector, using ctx to cancel.
+150
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package mysql
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"fmt"
7+
"time"
8+
9+
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
10+
"github.com/fleetdm/fleet/v4/server/fleet"
11+
"github.com/jmoiron/sqlx"
12+
)
13+
14+
func (ds *Datastore) NewCalendarEvent(
15+
ctx context.Context,
16+
email string,
17+
startTime time.Time,
18+
endTime time.Time,
19+
data []byte,
20+
hostID uint,
21+
) (*fleet.CalendarEvent, error) {
22+
var calendarEvent *fleet.CalendarEvent
23+
if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
24+
const calendarEventsQuery = `
25+
INSERT INTO calendar_events (
26+
email,
27+
start_time,
28+
end_time,
29+
event
30+
) VALUES (?, ?, ?, ?);
31+
`
32+
result, err := tx.ExecContext(
33+
ctx,
34+
calendarEventsQuery,
35+
email,
36+
startTime,
37+
endTime,
38+
data,
39+
)
40+
if err != nil {
41+
return ctxerr.Wrap(ctx, err, "insert calendar event")
42+
}
43+
44+
id, _ := result.LastInsertId()
45+
calendarEvent = &fleet.CalendarEvent{
46+
ID: uint(id),
47+
Email: email,
48+
StartTime: startTime,
49+
EndTime: endTime,
50+
Data: data,
51+
}
52+
53+
const hostCalendarEventsQuery = `
54+
INSERT INTO host_calendar_events (
55+
host_id,
56+
calendar_event_id,
57+
webhook_status
58+
) VALUES (?, ?, ?);
59+
`
60+
result, err = tx.ExecContext(
61+
ctx,
62+
hostCalendarEventsQuery,
63+
hostID,
64+
calendarEvent.ID,
65+
fleet.CalendarWebhookStatusPending,
66+
)
67+
if err != nil {
68+
return ctxerr.Wrap(ctx, err, "insert host calendar event")
69+
}
70+
return nil
71+
}); err != nil {
72+
return nil, ctxerr.Wrap(ctx, err)
73+
}
74+
return calendarEvent, nil
75+
}
76+
77+
func (ds *Datastore) GetCalendarEvent(ctx context.Context, email string) (*fleet.CalendarEvent, error) {
78+
const calendarEventsQuery = `
79+
SELECT * FROM calendar_events WHERE email = ?;
80+
`
81+
var calendarEvent fleet.CalendarEvent
82+
err := sqlx.GetContext(ctx, ds.reader(ctx), &calendarEvent, calendarEventsQuery, email)
83+
if err != nil {
84+
if err == sql.ErrNoRows {
85+
return nil, ctxerr.Wrap(ctx, notFound("CalendarEvent").WithMessage(fmt.Sprintf("email: %s", email)))
86+
}
87+
return nil, ctxerr.Wrap(ctx, err, "get calendar event")
88+
}
89+
return &calendarEvent, nil
90+
}
91+
92+
func (ds *Datastore) UpdateCalendarEvent(ctx context.Context, calendarEventID uint, startTime time.Time, endTime time.Time, data []byte) error {
93+
const calendarEventsQuery = `
94+
UPDATE calendar_events SET
95+
start_time = ?,
96+
end_time = ?,
97+
event = ?
98+
WHERE id = ?;
99+
`
100+
if _, err := ds.writer(ctx).ExecContext(ctx, calendarEventsQuery, startTime, endTime, data, calendarEventID); err != nil {
101+
return ctxerr.Wrap(ctx, err, "update calendar event")
102+
}
103+
return nil
104+
}
105+
106+
func (ds *Datastore) DeleteCalendarEvent(ctx context.Context, calendarEventID uint) error {
107+
const calendarEventsQuery = `
108+
DELETE FROM calendar_events WHERE id = ?;
109+
`
110+
if _, err := ds.writer(ctx).ExecContext(ctx, calendarEventsQuery, calendarEventID); err != nil {
111+
return ctxerr.Wrap(ctx, err, "delete calendar event")
112+
}
113+
return nil
114+
}
115+
116+
func (ds *Datastore) GetHostCalendarEvent(ctx context.Context, hostID uint) (*fleet.HostCalendarEvent, *fleet.CalendarEvent, error) {
117+
const hostCalendarEventsQuery = `
118+
SELECT * FROM host_calendar_events WHERE host_id = ?
119+
`
120+
var hostCalendarEvent fleet.HostCalendarEvent
121+
if err := sqlx.GetContext(ctx, ds.reader(ctx), &hostCalendarEvent, hostCalendarEventsQuery, hostID); err != nil {
122+
if err == sql.ErrNoRows {
123+
return nil, nil, ctxerr.Wrap(ctx, notFound("HostCalendarEvent").WithMessage(fmt.Sprintf("host_id: %d", hostID)))
124+
}
125+
return nil, nil, ctxerr.Wrap(ctx, err, "get host calendar event")
126+
}
127+
const calendarEventsQuery = `
128+
SELECT * FROM calendar_events WHERE id = ?
129+
`
130+
var calendarEvent fleet.CalendarEvent
131+
if err := sqlx.GetContext(ctx, ds.reader(ctx), &calendarEvent, calendarEventsQuery, hostCalendarEvent.CalendarEventID); err != nil {
132+
if err == sql.ErrNoRows {
133+
return nil, nil, ctxerr.Wrap(ctx, notFound("CalendarEvent").WithID(hostCalendarEvent.CalendarEventID))
134+
}
135+
return nil, nil, ctxerr.Wrap(ctx, err, "get calendar event")
136+
}
137+
return &hostCalendarEvent, &calendarEvent, nil
138+
}
139+
140+
func (ds *Datastore) UpdateHostCalendarWebhookStatus(ctx context.Context, hostID uint, status fleet.CalendarWebhookStatus) error {
141+
const calendarEventsQuery = `
142+
UPDATE host_calendar_events SET
143+
webhook_status = ?
144+
WHERE host_id = ?;
145+
`
146+
if _, err := ds.writer(ctx).ExecContext(ctx, calendarEventsQuery, status, hostID); err != nil {
147+
return ctxerr.Wrap(ctx, err, "update host calendar event webhook status")
148+
}
149+
return nil
150+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package mysql
2+
3+
import "testing"
4+
5+
func TestCalendarEvents(t *testing.T) {
6+
}

server/datastore/mysql/policies.go

+51-1
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@ import (
55
"database/sql"
66
"encoding/json"
77
"fmt"
8-
"golang.org/x/text/unicode/norm"
98
"sort"
109
"strings"
1110
"time"
1211

12+
"golang.org/x/text/unicode/norm"
13+
1314
"github.com/doug-martin/goqu/v9"
1415
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
1516
"github.com/fleetdm/fleet/v4/server/fleet"
@@ -1159,3 +1160,52 @@ func (ds *Datastore) UpdateHostPolicyCounts(ctx context.Context) error {
11591160

11601161
return nil
11611162
}
1163+
1164+
func (ds *Datastore) GetCalendarPolicies(ctx context.Context, teamID uint) ([]fleet.PolicyCalendarData, error) {
1165+
query := `SELECT id, name FROM policies WHERE team_id = ? AND calendar_events_enabled;`
1166+
var policies []fleet.PolicyCalendarData
1167+
err := sqlx.SelectContext(ctx, ds.reader(ctx), &policies, query, teamID)
1168+
if err != nil {
1169+
return nil, ctxerr.Wrap(ctx, err, "get calendar policies")
1170+
}
1171+
return policies, nil
1172+
}
1173+
1174+
// TODO(lucas): Must be tested at scale.
1175+
func (ds *Datastore) GetHostsPolicyMemberships(ctx context.Context, domain string, policyIDs []uint) ([]fleet.HostPolicyMembershipData, error) {
1176+
query := `
1177+
SELECT
1178+
COALESCE(sh.email, '') AS email,
1179+
pm.passing AS passing,
1180+
h.id AS host_id,
1181+
hdn.display_name AS host_display_name,
1182+
h.hardware_serial AS host_hardware_serial
1183+
FROM (
1184+
SELECT host_id, BIT_AND(COALESCE(passes, 0)) AS passing
1185+
FROM policy_membership
1186+
WHERE policy_id IN (?)
1187+
GROUP BY host_id
1188+
) pm
1189+
LEFT JOIN (
1190+
SELECT MIN(h.host_id) as host_id, h.email as email
1191+
FROM (
1192+
SELECT host_id, MIN(email) AS email
1193+
FROM host_emails WHERE email LIKE CONCAT('%@', ?)
1194+
GROUP BY host_id
1195+
) h GROUP BY h.email
1196+
) sh ON sh.host_id = pm.host_id
1197+
JOIN hosts h ON h.id = pm.host_id
1198+
LEFT JOIN host_display_names hdn ON hdn.host_id = pm.host_id;
1199+
`
1200+
1201+
query, args, err := sqlx.In(query, policyIDs, domain)
1202+
if err != nil {
1203+
return nil, ctxerr.Wrapf(ctx, err, "build select get team hosts policy memberships query")
1204+
}
1205+
var hosts []fleet.HostPolicyMembershipData
1206+
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &hosts, query, args...); err != nil {
1207+
return nil, ctxerr.Wrap(ctx, err, "listing policies")
1208+
}
1209+
1210+
return hosts, nil
1211+
}

server/datastore/mysql/policies_test.go

+55-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ func TestPolicies(t *testing.T) {
5959
{"TestPoliciesNameUnicode", testPoliciesNameUnicode},
6060
{"TestPoliciesNameEmoji", testPoliciesNameEmoji},
6161
{"TestPoliciesNameSort", testPoliciesNameSort},
62+
{"TestGetCalendarPolicies", testGetCalendarPolicies},
6263
}
6364
for _, c := range cases {
6465
t.Run(c.name, func(t *testing.T) {
@@ -2784,7 +2785,6 @@ func testPoliciesNameEmoji(t *testing.T, ds *Datastore) {
27842785
assert.NoError(t, err)
27852786
require.Len(t, policies, 1)
27862787
assert.Equal(t, emoji1, policies[0].Name)
2787-
27882788
}
27892789

27902790
// Ensure case-insensitive sort order for policy names
@@ -2806,3 +2806,57 @@ func testPoliciesNameSort(t *testing.T, ds *Datastore) {
28062806
assert.Equal(t, policy.Name, policiesResult[i].Name)
28072807
}
28082808
}
2809+
2810+
func testGetCalendarPolicies(t *testing.T, ds *Datastore) {
2811+
ctx := context.Background()
2812+
2813+
// Test with non-existent team.
2814+
_, err := ds.GetCalendarPolicies(ctx, 999)
2815+
require.NoError(t, err)
2816+
2817+
team, err := ds.NewTeam(ctx, &fleet.Team{
2818+
Name: "Foobar",
2819+
})
2820+
require.NoError(t, err)
2821+
2822+
// Test when the team has no policies.
2823+
_, err = ds.GetCalendarPolicies(ctx, team.ID)
2824+
require.NoError(t, err)
2825+
2826+
// Create a global query to test that only team policies are returned.
2827+
_, err = ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{
2828+
Name: "Global Policy",
2829+
Query: "SELECT * FROM time;",
2830+
})
2831+
require.NoError(t, err)
2832+
2833+
_, err = ds.NewTeamPolicy(ctx, team.ID, nil, fleet.PolicyPayload{
2834+
Name: "Team Policy 1",
2835+
Query: "SELECT * FROM system_info;",
2836+
CalendarEventsEnabled: false,
2837+
})
2838+
require.NoError(t, err)
2839+
2840+
// Test when the team has policies, but none is configured for calendar.
2841+
_, err = ds.GetCalendarPolicies(ctx, team.ID)
2842+
require.NoError(t, err)
2843+
2844+
teamPolicy2, err := ds.NewTeamPolicy(ctx, team.ID, nil, fleet.PolicyPayload{
2845+
Name: "Team Policy 2",
2846+
Query: "SELECT * FROM osquery_info;",
2847+
CalendarEventsEnabled: true,
2848+
})
2849+
require.NoError(t, err)
2850+
teamPolicy3, err := ds.NewTeamPolicy(ctx, team.ID, nil, fleet.PolicyPayload{
2851+
Name: "Team Policy 3",
2852+
Query: "SELECT * FROM os_version;",
2853+
CalendarEventsEnabled: true,
2854+
})
2855+
require.NoError(t, err)
2856+
2857+
calendarPolicies, err := ds.GetCalendarPolicies(ctx, team.ID)
2858+
require.NoError(t, err)
2859+
require.Len(t, calendarPolicies, 2)
2860+
require.Equal(t, calendarPolicies[0].ID, teamPolicy2.ID)
2861+
require.Equal(t, calendarPolicies[1].ID, teamPolicy3.ID)
2862+
}

server/fleet/app.go

+7
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,13 @@ func (c *AppConfig) Copy() *AppConfig {
571571
clone.Integrations.Zendesk[i] = &zd
572572
}
573573
}
574+
if len(c.Integrations.GoogleCalendar) > 0 {
575+
clone.Integrations.GoogleCalendar = make([]*GoogleCalendarIntegration, len(c.Integrations.GoogleCalendar))
576+
for i, g := range c.Integrations.GoogleCalendar {
577+
gc := *g
578+
clone.Integrations.GoogleCalendar[i] = &gc
579+
}
580+
}
574581

575582
if c.MDM.MacOSSettings.CustomSettings != nil {
576583
clone.MDM.MacOSSettings.CustomSettings = make([]MDMProfileSpec, len(c.MDM.MacOSSettings.CustomSettings))

0 commit comments

Comments
 (0)