From 84646441d8f9af2fb696f89eb56eaf84fedb2016 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Thalheim?= Date: Wed, 5 Apr 2023 12:33:03 +0200 Subject: [PATCH 1/2] systemd-vaultd-update-secrets: do not depend on CREDENTIALS_DIRECTORY to be accessible --- cmd/systemd-vaultd-update-secrets/main.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/cmd/systemd-vaultd-update-secrets/main.go b/cmd/systemd-vaultd-update-secrets/main.go index 22ff597..8de83d8 100644 --- a/cmd/systemd-vaultd-update-secrets/main.go +++ b/cmd/systemd-vaultd-update-secrets/main.go @@ -13,18 +13,17 @@ const ( systemdVaultdir = "/run/systemd-vaultd/secrets" ) -func updateSecrets(credentialsDirectory, target string) error { +func updateSecrets(serviceName, target string) error { // get systemd service name from credentials directory - serviceName := path.Base(credentialsDirectory) - stat, err := os.Stat(credentialsDirectory) + stat, err := os.Stat(target) if err != nil { - return fmt.Errorf("failed to stat %s: %w", credentialsDirectory, err) + return fmt.Errorf("failed to stat target %s: %w", target, err) } // inherit the owner and group of the credentials directory uid := stat.Sys().(*syscall.Stat_t).Uid gid := stat.Sys().(*syscall.Stat_t).Gid - jsonPath := path.Join(credentialsDirectory, fmt.Sprintf("%s.json", serviceName)) + jsonPath := path.Join(systemdVaultdir, fmt.Sprintf("%s.json", serviceName)) var content []byte for i := 0; i < 10; i++ { content, err = os.ReadFile(jsonPath) @@ -45,7 +44,6 @@ func updateSecrets(credentialsDirectory, target string) error { } for key, value := range data { targetPath := path.Join(target, key) - err := os.MkdirAll(path.Dir(targetPath), 0o700) os.Chown(path.Dir(targetPath), int(uid), int(gid)) if err != nil { @@ -63,14 +61,14 @@ func main() { fmt.Println("Usage: systemd-vaultd-update-secrets ") os.Exit(1) } - credentialsDirectory := os.Getenv("CREDENTIALS_DIRECTORY") - if credentialsDirectory == "" { - fmt.Println("CREDENTIALS_DIRECTORY environment variable must be set") + serviceName := os.Getenv("SYSTEMD_ACTIVATION_UNIT") + if serviceName == "" { + fmt.Println("SYSTEMD_ACTIVATION_UNIT not set") os.Exit(1) } target := os.Args[1] - if err := updateSecrets(credentialsDirectory, target); err != nil { + if err := updateSecrets(serviceName, target); err != nil { fmt.Println(err) os.Exit(1) } From 2fd4e8a5c931bd42c2777658085f69fe23d81213 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Thalheim?= Date: Wed, 5 Apr 2023 12:51:34 +0200 Subject: [PATCH 2/2] systemd-vaultd-update-secrets: fix on race condition if json is not up-to-date --- cmd/systemd-vaultd-update-secrets/main.go | 39 ++++++++++++++++++++--- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/cmd/systemd-vaultd-update-secrets/main.go b/cmd/systemd-vaultd-update-secrets/main.go index 8de83d8..8ee6ee9 100644 --- a/cmd/systemd-vaultd-update-secrets/main.go +++ b/cmd/systemd-vaultd-update-secrets/main.go @@ -3,6 +3,7 @@ package main import ( "encoding/json" "fmt" + "log" "os" "path" "syscall" @@ -26,6 +27,24 @@ func updateSecrets(serviceName, target string) error { jsonPath := path.Join(systemdVaultdir, fmt.Sprintf("%s.json", serviceName)) var content []byte for i := 0; i < 10; i++ { + jsonStat, err := os.Stat(jsonPath) + if err != nil { + if os.IsNotExist(err) { + // wait for the file to be created + fmt.Printf("waiting for %s to be created", jsonPath) + time.Sleep(1 * time.Second) + continue + } + return fmt.Errorf("failed to stat vault json file %s: %w", serviceName, err) + } + + if jsonStat.ModTime().Before(stat.ModTime()) { + // wait for the file to be updated + fmt.Printf("waiting for %s to be updated", jsonPath) + time.Sleep(1 * time.Second) + continue + } + content, err = os.ReadFile(jsonPath) if err != nil { if os.IsNotExist(err) { @@ -44,13 +63,23 @@ func updateSecrets(serviceName, target string) error { } for key, value := range data { targetPath := path.Join(target, key) - os.Chown(path.Dir(targetPath), int(uid), int(gid)) - + tempPath := targetPath + ".tmp" + err = os.WriteFile(tempPath, []byte(value.(string)), 0o400) + if err != nil { + return fmt.Errorf("failed to write file %s: %w", targetPath, err) + } + err = os.Chown(tempPath, int(uid), int(gid)) if err != nil { - return fmt.Errorf("failed to create directory %s: %w", path.Dir(targetPath), err) + return fmt.Errorf("failed to chown file %s: %w", targetPath, err) } - os.WriteFile(targetPath, []byte(value.(string)), 0o400) - os.Chown(targetPath, int(uid), int(gid)) + err = os.Rename(tempPath, targetPath) + if err != nil { + return fmt.Errorf("failed to rename file %s: %w", targetPath, err) + } + } + err = os.Chtimes(target, time.Now(), time.Now()) + if err != nil { + log.Printf("failed to update modification time of %s: %v", target, err) } return nil