diff --git a/cmd/systemd-vaultd-update-secrets/main.go b/cmd/systemd-vaultd-update-secrets/main.go index 8ee6ee9..a5ae241 100644 --- a/cmd/systemd-vaultd-update-secrets/main.go +++ b/cmd/systemd-vaultd-update-secrets/main.go @@ -6,6 +6,7 @@ import ( "log" "os" "path" + "strings" "syscall" "time" ) @@ -85,14 +86,31 @@ func updateSecrets(serviceName, target string) error { return nil } +func getSystemdServiceName() (string, error) { + mainPid := os.Getenv("MAINPID") + if mainPid == "" { + return "", fmt.Errorf("MAINPID not set") + } + p := fmt.Sprintf("/proc/%s/cgroup", mainPid) + content, err := os.ReadFile(p) + if err != nil { + return "", fmt.Errorf("failed to read cgroup file %s: %w", p, err) + } + line := strings.SplitN(string(content), "\n", 2)[0] + if !strings.HasSuffix(line, ".service") { + return "", fmt.Errorf("cgroup file %s does not end with .service", p) + } + return path.Base(line), nil +} + func main() { if len(os.Args) != 2 { fmt.Println("Usage: systemd-vaultd-update-secrets ") os.Exit(1) } - serviceName := os.Getenv("SYSTEMD_ACTIVATION_UNIT") - if serviceName == "" { - fmt.Println("SYSTEMD_ACTIVATION_UNIT not set") + serviceName, err := getSystemdServiceName() + if err != nil { + fmt.Println(err) os.Exit(1) }