Skip to content

Allow ACLs to be reloaded with SIGHUP #601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 3, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
- Add option to enable/disable logtail (Tailscale's logging infrastructure) [#596](https://github.com/juanfont/headscale/pull/596)
- This change disables the logs by default
- Use [Prometheus]'s duration parser, supporting days (`d`), weeks (`w`) and years (`y`) [#598](https://github.com/juanfont/headscale/pull/598)
- Add support for reloading ACLs with SIGHUP [#601](https://github.com/juanfont/headscale/pull/601)

## 0.15.0 (2022-03-20)

Expand Down
74 changes: 61 additions & 13 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ type Config struct {
LogTail LogTailConfig

CLI CLIConfig

ACL ACLConfig
}

type OIDCConfig struct {
Expand Down Expand Up @@ -152,6 +154,10 @@ type CLIConfig struct {
Insecure bool
}

type ACLConfig struct {
PolicyPath string
}

// Headscale represents the base app of the service.
type Headscale struct {
cfg Config
Expand Down Expand Up @@ -568,19 +574,6 @@ func (h *Headscale) Serve() error {
return fmt.Errorf("failed change permission of gRPC socket: %w", err)
}

// Handle common process-killing signals so we can gracefully shut down:
sigc := make(chan os.Signal, 1)
signal.Notify(sigc, os.Interrupt, syscall.SIGTERM)
go func(c chan os.Signal) {
// Wait for a SIGINT or SIGKILL:
sig := <-c
log.Printf("Caught signal %s: shutting down.", sig)
// Stop listening (and unlink the socket if unix type):
socketListener.Close()
// And we're done:
os.Exit(0)
}(sigc)

grpcGatewayMux := runtime.NewServeMux()

// Make the grpc-gateway connect to grpc over socket
Expand Down Expand Up @@ -725,6 +718,61 @@ func (h *Headscale) Serve() error {
log.Info().
Msgf("listening and serving metrics on: %s", h.cfg.MetricsAddr)

// Handle common process-killing signals so we can gracefully shut down:
sigc := make(chan os.Signal, 1)
signal.Notify(sigc,
syscall.SIGHUP,
syscall.SIGINT,
syscall.SIGTERM,
syscall.SIGQUIT,
syscall.SIGHUP)
go func(c chan os.Signal) {
// Wait for a SIGINT or SIGKILL:
for {
sig := <-c
switch sig {
case syscall.SIGHUP:
log.Info().
Str("signal", sig.String()).
Msg("Received SIGHUP, reloading ACL and Config")

// TODO(kradalby): Reload config on SIGHUP

if h.cfg.ACL.PolicyPath != "" {
aclPath := AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath)
err := h.LoadACLPolicy(aclPath)
if err != nil {
log.Error().Err(err).Msg("Failed to reload ACL policy")
}
log.Info().
Str("path", aclPath).
Msg("ACL policy successfully reloaded")
}

default:
log.Info().
Str("signal", sig.String()).
Msg("Received signal to stop, shutting down gracefully")

// Gracefully shut down servers
promHTTPServer.Shutdown(ctx)
httpServer.Shutdown(ctx)
grpcSocket.GracefulStop()

// Close network listeners
promHTTPListener.Close()
httpListener.Close()
grpcGatewayConn.Close()

// Stop listening (and unlink the socket if unix type):
socketListener.Close()

// And we're done:
os.Exit(0)
}
}
}(sigc)

return errorGroup.Wait()
}

Expand Down
48 changes: 22 additions & 26 deletions cmd/headscale/cli/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"io/fs"
"net/url"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -169,14 +168,22 @@ func GetDERPConfig() headscale.DERPConfig {
}
}

func GetLogConfig() headscale.LogTailConfig {
func GetLogTailConfig() headscale.LogTailConfig {
enabled := viper.GetBool("logtail.enabled")

return headscale.LogTailConfig{
Enabled: enabled,
}
}

func GetACLConfig() headscale.ACLConfig {
policyPath := viper.GetString("acl_policy_path")

return headscale.ACLConfig{
PolicyPath: policyPath,
}
}

func GetDNSConfig() (*tailcfg.DNSConfig, string) {
if viper.IsSet("dns_config") {
dnsConfig := &tailcfg.DNSConfig{}
Expand Down Expand Up @@ -264,23 +271,10 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
return nil, ""
}

func absPath(path string) string {
// If a relative path is provided, prefix it with the the directory where
// the config file was found.
if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
dir, _ := filepath.Split(viper.ConfigFileUsed())
if dir != "" {
path = filepath.Join(dir, path)
}
}

return path
}

func getHeadscaleConfig() headscale.Config {
func GetHeadscaleConfig() headscale.Config {
dnsConfig, baseDomain := GetDNSConfig()
derpConfig := GetDERPConfig()
logConfig := GetLogConfig()
logConfig := GetLogTailConfig()

configuredPrefixes := viper.GetStringSlice("ip_prefixes")
parsedPrefixes := make([]netaddr.IPPrefix, 0, len(configuredPrefixes)+1)
Expand Down Expand Up @@ -342,7 +336,7 @@ func getHeadscaleConfig() headscale.Config {
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),

IPPrefixes: prefixes,
PrivateKeyPath: absPath(viper.GetString("private_key_path")),
PrivateKeyPath: headscale.AbsolutePathFromConfigPath(viper.GetString("private_key_path")),
BaseDomain: baseDomain,

DERP: derpConfig,
Expand All @@ -352,7 +346,7 @@ func getHeadscaleConfig() headscale.Config {
),

DBtype: viper.GetString("db_type"),
DBpath: absPath(viper.GetString("db_path")),
DBpath: headscale.AbsolutePathFromConfigPath(viper.GetString("db_path")),
DBhost: viper.GetString("db_host"),
DBport: viper.GetInt("db_port"),
DBname: viper.GetString("db_name"),
Expand All @@ -361,13 +355,13 @@ func getHeadscaleConfig() headscale.Config {

TLSLetsEncryptHostname: viper.GetString("tls_letsencrypt_hostname"),
TLSLetsEncryptListen: viper.GetString("tls_letsencrypt_listen"),
TLSLetsEncryptCacheDir: absPath(
TLSLetsEncryptCacheDir: headscale.AbsolutePathFromConfigPath(
viper.GetString("tls_letsencrypt_cache_dir"),
),
TLSLetsEncryptChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),

TLSCertPath: absPath(viper.GetString("tls_cert_path")),
TLSKeyPath: absPath(viper.GetString("tls_key_path")),
TLSCertPath: headscale.AbsolutePathFromConfigPath(viper.GetString("tls_cert_path")),
TLSKeyPath: headscale.AbsolutePathFromConfigPath(viper.GetString("tls_key_path")),
TLSClientAuthMode: tlsClientAuthMode,

DNSConfig: dnsConfig,
Expand Down Expand Up @@ -397,6 +391,8 @@ func getHeadscaleConfig() headscale.Config {
Timeout: viper.GetDuration("cli.timeout"),
Insecure: viper.GetBool("cli.insecure"),
},

ACL: GetACLConfig(),
}
}

Expand All @@ -416,7 +412,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
return nil, err
}

cfg := getHeadscaleConfig()
cfg := GetHeadscaleConfig()

app, err := headscale.NewHeadscale(cfg)
if err != nil {
Expand All @@ -425,8 +421,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) {

// We are doing this here, as in the future could be cool to have it also hot-reload

if viper.GetString("acl_policy_path") != "" {
aclPath := absPath(viper.GetString("acl_policy_path"))
if cfg.ACL.PolicyPath != "" {
aclPath := headscale.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath)
err = app.LoadACLPolicy(aclPath)
if err != nil {
log.Fatal().
Expand All @@ -440,7 +436,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
}

func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) {
cfg := getHeadscaleConfig()
cfg := GetHeadscaleConfig()

log.Debug().
Dur("timeout", cfg.CLI.Timeout).
Expand Down
16 changes: 16 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ import (
"encoding/json"
"fmt"
"net"
"os"
"path/filepath"
"reflect"
"strings"

"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"inet.af/netaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
Expand Down Expand Up @@ -334,3 +337,16 @@ func IsStringInSlice(slice []string, str string) bool {

return false
}

func AbsolutePathFromConfigPath(path string) string {
// If a relative path is provided, prefix it with the the directory where
// the config file was found.
if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
dir, _ := filepath.Split(viper.ConfigFileUsed())
if dir != "" {
path = filepath.Join(dir, path)
}
}

return path
}