From 7c36193a144e3337e532e374de3509bfcde37299 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Thalheim?= Date: Sat, 11 Jun 2022 10:35:01 +0200 Subject: [PATCH] implement systemd socket activation --- default.nix | 25 ++++++++++++++++----- main.go | 30 +++++++++++++++++++++++++ systemd_sockets.go | 52 +++++++++++++++++++++++++++++++++++++++++++ tests/test_service.py | 21 +++++++++++++---- 4 files changed, 119 insertions(+), 9 deletions(-) create mode 100644 systemd_sockets.go diff --git a/default.nix b/default.nix index 5883284..f57f2fd 100644 --- a/default.nix +++ b/default.nix @@ -1,10 +1,25 @@ with import {}; -mkShell { - nativeBuildInputs = [ - go - vault + +buildGoModule { + name = "systemd-vault"; + src = ./.; + vendorSha256 = null; + checkInputs = [ python3.pkgs.pytest golangci-lint - hivemind + vault ]; + meta = with lib; { + description = "A proxy for secrets between systemd services and vault"; + homepage = "https://github.com/numtide/systemd-vault"; + license = licenses.mit; + maintainers = with maintainers; [ mic92 ]; + platforms = platforms.unix; + }; } +#mkShell { +# nativeBuildInputs = [ +# go +# hivemind +# ]; +#} diff --git a/main.go b/main.go index 033fe71..7c005c9 100644 --- a/main.go +++ b/main.go @@ -20,7 +20,37 @@ type server struct { connectionClosed chan int } +func inheritSocket() *net.UnixListener { + socks := systemdSockets(true) + stat := &syscall.Stat_t {} + for _, s := range socks { + fd := s.Fd() + err := syscall.Fstat(int(fd), stat); + if err != nil { + log.Printf("Received invalid file descriptor from systemd for fd%d: %v", fd, err) + continue + } + listener, err := net.FileListener(s) + if err != nil { + log.Printf("Received file descriptor %d from systemd that is not a valid socket: %v", fd, err) + continue + } + unixListener, ok := listener.(*net.UnixListener); + if !ok { + log.Printf("Ignore file descriptor %d from systemd, which is not a unix socket", fd) + continue + } + log.Printf("Use unix socket received from systemd") + return unixListener + } + return nil +} + func listenSocket(path string) (*net.UnixListener, error) { + s := inheritSocket() + if s != nil { + return s, nil + } if err := syscall.Unlink(path); err != nil && !os.IsNotExist(err) { return nil, fmt.Errorf("Cannot remove old socket: %v", err) } diff --git a/systemd_sockets.go b/systemd_sockets.go new file mode 100644 index 0000000..e428e64 --- /dev/null +++ b/systemd_sockets.go @@ -0,0 +1,52 @@ +package main + +import ( + "os" + "strconv" + "strings" + "syscall" +) + +const ( + // listenFdsStart corresponds to `SD_LISTEN_FDS_START`. + listenFdsStart = 3 +) + +// Files returns a slice containing a `os.File` object for each +// file descriptor passed to this process via systemd fd-passing protocol. +// +// The order of the file descriptors is preserved in the returned slice. +// `unsetEnv` is typically set to `true` in order to avoid clashes in +// fd usage and to avoid leaking environment flags to child processes. +func systemdSockets(unsetEnv bool) []*os.File { + if unsetEnv { + defer os.Unsetenv("LISTEN_PID") + defer os.Unsetenv("LISTEN_FDS") + defer os.Unsetenv("LISTEN_FDNAMES") + } + + pid, err := strconv.Atoi(os.Getenv("LISTEN_PID")) + if err != nil || pid != os.Getpid() { + return nil + } + + nfds, err := strconv.Atoi(os.Getenv("LISTEN_FDS")) + if err != nil || nfds == 0 { + return nil + } + + names := strings.Split(os.Getenv("LISTEN_FDNAMES"), ":") + + files := make([]*os.File, 0, nfds) + for fd := listenFdsStart; fd < listenFdsStart+nfds; fd++ { + syscall.CloseOnExec(fd) + name := "LISTEN_FD_" + strconv.Itoa(fd) + offset := fd - listenFdsStart + if offset < len(names) && len(names[offset]) > 0 { + name = names[offset] + } + files = append(files, os.NewFile(uintptr(fd), name)) + } + + return files +} diff --git a/tests/test_service.py b/tests/test_service.py index fdd7157..ea21edb 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -3,6 +3,7 @@ import subprocess from dataclasses import dataclass from command import Command, run from pathlib import Path +import time import string import random @@ -27,14 +28,17 @@ def random_service(secrets_dir: Path) -> Service: return Service(service, secret_name, secret_path) -def test_service(systemd_vault: Path, command: Command, tempdir: Path) -> None: +def test_socket_activation( + systemd_vault: Path, command: Command, tempdir: Path +) -> None: secrets_dir = tempdir / "secrets" + secrets_dir.mkdir() sock = tempdir / "sock" - command.run([str(systemd_vault), "-secrets", str(secrets_dir), "-sock", str(sock)]) - import time + + command.run(["systemd-socket-activate", "--listen", str(sock), str(systemd_vault), "-secrets", str(secrets_dir), "-sock", str(sock)]) while not sock.exists(): - time.sleep(1) + time.sleep(0.1) service = random_service(secrets_dir) service.secret_path.write_text("foo") @@ -59,6 +63,15 @@ def test_service(systemd_vault: Path, command: Command, tempdir: Path) -> None: assert out.stdout == "foo" assert out.returncode == 0 + +def test_blocking_secret(systemd_vault: Path, command: Command, tempdir: Path) -> None: + secrets_dir = tempdir / "secrets" + sock = tempdir / "sock" + command.run([str(systemd_vault), "-secrets", str(secrets_dir), "-sock", str(sock)]) + + while not sock.exists(): + time.sleep(0.1) + service = random_service(secrets_dir) proc = command.run(