From 63bcc48e31aac2a355a669fc19747807d9434ee7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Thalheim?= Date: Sat, 11 Jun 2022 09:45:43 +0200 Subject: [PATCH] mvp + tests --- .gitignore | 9 ++ Procfile | 2 + default.nix | 2 + epoll.go | 44 ++++++++++ go.mod | 2 +- main.go | 146 ++++++++++++++++++++++++++++++- setup.cfg | 23 +++++ tests/command.py | 81 +++++++++++++++++ tests/conftest.py | 8 ++ tests/root.py | 24 +++++ tests/systemd_vault.py | 23 +++++ tests/tempdir.py | 12 +++ tests/test_service.py | 84 ++++++++++++++++++ watcher.go | 193 +++++++++++++++++++++++++++++++++++++++++ 14 files changed, 651 insertions(+), 2 deletions(-) create mode 100644 .gitignore create mode 100644 Procfile create mode 100644 epoll.go create mode 100644 setup.cfg create mode 100644 tests/command.py create mode 100644 tests/conftest.py create mode 100644 tests/root.py create mode 100644 tests/systemd_vault.py create mode 100644 tests/tempdir.py create mode 100644 tests/test_service.py create mode 100644 watcher.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3c2977f --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +.DS_Store +.idea +*.log +systemd-vault + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class diff --git a/Procfile b/Procfile new file mode 100644 index 0000000..be8bcb1 --- /dev/null +++ b/Procfile @@ -0,0 +1,2 @@ +vault: vault server -dev +systemd-vault: go run . diff --git a/default.nix b/default.nix index 04d82ba..5883284 100644 --- a/default.nix +++ b/default.nix @@ -4,5 +4,7 @@ mkShell { go vault python3.pkgs.pytest + golangci-lint + hivemind ]; } diff --git a/epoll.go b/epoll.go new file mode 100644 index 0000000..8f4dc06 --- /dev/null +++ b/epoll.go @@ -0,0 +1,44 @@ +package main + +import ( + "log" + "syscall" +) + +const ( + EPOLLET = 1 << 31 +) + +func (s *server) epollWatch(fd int) error { + event := syscall.EpollEvent{ + Fd: int32(fd), + Events: syscall.EPOLLHUP | EPOLLET, + } + return syscall.EpollCtl(s.epfd, syscall.EPOLL_CTL_ADD, fd, &event) +} + +func (s *server) epollDelete(fd int) error { + return syscall.EpollCtl(s.epfd, syscall.EPOLL_CTL_DEL, fd, &syscall.EpollEvent{}) +} + +func (s *server) handleEpoll() { + events := make([]syscall.EpollEvent, 1024) + for { + n, errno := syscall.EpollWait(s.epfd, events, -1) + if n == -1 { + if errno == syscall.EINTR { + continue + } + log.Fatalf("connection cleaner: epoll wait failed with %v", errno) + } + ready := events[:n] + for _, event := range ready { + if event.Events&(syscall.EPOLLHUP|syscall.EPOLLERR) != 0 { + s.epollDelete(int(event.Fd)) + s.connectionClosed <- int(event.Fd) + } else { + log.Printf("Unhandled epoll event: %d for fd %d", event.Events, event.Fd) + } + } + } +} diff --git a/go.mod b/go.mod index 8e81507..27db5fc 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/numtide/vault-nixos +module github.com/numtide/systemd-vault go 1.17 diff --git a/main.go b/main.go index 24dcac8..033fe71 100644 --- a/main.go +++ b/main.go @@ -1,9 +1,153 @@ package main import ( + "flag" "fmt" + "io" + "log" + "net" + "os" + "path/filepath" + "strings" + "syscall" ) +type server struct { + Socket string + SecretDir string + epfd int + inotifyRequests chan inotifyRequest + connectionClosed chan int +} + +func listenSocket(path string) (*net.UnixListener, error) { + if err := syscall.Unlink(path); err != nil && !os.IsNotExist(err) { + return nil, fmt.Errorf("Cannot remove old socket: %v", err) + } + abs, err := filepath.Abs(path) + if err != nil { + return nil, fmt.Errorf("'%s' is not a valid socket path: %v", path, err) + } + addr, err := net.ResolveUnixAddr("unix", abs) + if err != nil { + return nil, fmt.Errorf("Failed to resolv '%s' as a unix address: %v", abs, err) + } + listener, err := net.ListenUnix("unix", addr) + if err != nil { + return nil, fmt.Errorf("Failed to open socket at %s: %v", addr.Name, err) + } + return listener, nil +} + +func parseCredentialsAddr(addr string) (*string, *string, error) { + // Systemd stores metadata in its local unix address + fields := strings.Split(addr, "/") + if len(fields) != 4 || fields[1] != "unit" { + return nil, nil, fmt.Errorf("Address needs to match this format: @/unit//, got '%s'", addr) + } + return &fields[2], &fields[3], nil +} + +func (s *server) serveConnection(conn *net.UnixConn) { + shouldClose := true + defer func() { + if shouldClose { + conn.Close() + } + }() + + addr := conn.RemoteAddr().String() + unit, secret, err := parseCredentialsAddr(addr) + if err != nil { + log.Printf("Received connection but remote unix address seems to be not from systemd: %v", err) + return + } + log.Printf("Systemd requested secret for %s/%s", *unit, *secret) + secretName := *unit + "-" + *secret + secretPath := filepath.Join(s.SecretDir, secretName) + f, err := os.Open(secretPath) + if os.IsNotExist(err) { + log.Printf("Block start until %s appears", secretPath) + shouldClose = false + fd, err := connFd(conn) + if err != nil { + // connection was closed while we trying to wait + return + } + if err := s.epollWatch(fd); err != nil { + log.Printf("Cannot get setup epoll for unix socket: %s", err) + return + } + s.inotifyRequests <- inotifyRequest{filename: secretName, conn: conn} + return + } else if err != nil { + log.Printf("Cannot open secret %s/%s: %v", *unit, *secret, err) + return + } + defer f.Close() + if _, err = io.Copy(conn, f); err != nil { + log.Printf("Failed to send secret: %v", err) + } +} + +func serveSecrets(s *server) error { + l, err := listenSocket(s.Socket) + if err != nil { + return fmt.Errorf("Failed to setup listening socket: %v", err) + } + defer l.Close() + log.Printf("Listening on %s", s.Socket) + go s.handleEpoll() + for { + conn, err := l.AcceptUnix() + if err != nil { + return fmt.Errorf("Error accepting unix connection: %v", err) + } + go s.serveConnection(conn) + } +} + +var secretDir, socketDir string + +func init() { + defaultDir := os.Getenv("SYSTEMD_VAULT_SECRETS") + if defaultDir == "" { + defaultDir = "/run/systemd-vault" + } + flag.StringVar(&secretDir, "secrets", defaultDir, "directory where secrets are looked up") + + defaultSock := os.Getenv("SYSTEMD_VAULT_SOCK") + if defaultSock == "" { + defaultSock = "/run/systemd-vault.sock" + } + flag.StringVar(&socketDir, "sock", defaultSock, "unix socket to listen to for systemd requests") + flag.Parse() +} + +func createServer(secretDir string, socketDir string) (*server, error) { + epfd, err := syscall.EpollCreate1(syscall.EPOLL_CLOEXEC) + if epfd == -1 { + return nil, fmt.Errorf("failed to create epoll fd: %v", err) + } + s := &server{ + Socket: socketDir, + SecretDir: secretDir, + epfd: epfd, + inotifyRequests: make(chan inotifyRequest), + connectionClosed: make(chan int), + } + if err := s.setupWatcher(secretDir); err != nil { + return nil, fmt.Errorf("Failed to setup file system watcher: %v", err) + } + return s, nil +} + func main() { - fmt.Println("Hello world") + s, err := createServer(secretDir, socketDir) + if err != nil { + log.Fatalf("Failed to create server: %v", err) + } + if err := serveSecrets(s); err != nil { + log.Fatalf("Failed serve secrets: %v", err) + } } diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..1294fa5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,23 @@ +[wheel] +universal = 1 + +[pycodestyle] +max-line-length = 88 +ignore = E501,E741,W503 + +[flake8] +max-line-length = 88 +ignore = E501,E741,W503 +exclude = .git,__pycache__,docs/source/conf.py,old,build,dist + +[mypy] +warn_redundant_casts = true +disallow_untyped_calls = true +disallow_untyped_defs = true +no_implicit_optional = true + +[mypy-pytest.*] +ignore_missing_imports = True + +[isort] +profile = black diff --git a/tests/command.py b/tests/command.py new file mode 100644 index 0000000..533df6e --- /dev/null +++ b/tests/command.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 + +import os +import signal +import subprocess +from typing import IO, Any, Dict, Iterator, List, Union +from pathlib import Path + +import pytest + +_DIR = Union[None, Path, str] +_FILE = Union[None, int, IO[Any]] + + +def run( + cmd: List[str], + text: bool = True, + check: bool = True, + cwd: _DIR = None, + stderr: _FILE = None, + stdout: _FILE = None, +) -> subprocess.CompletedProcess: + if cwd is not None: + print(f"cd {cwd}") + print("$ " + " ".join(cmd)) + return subprocess.run( + cmd, text=text, check=check, cwd=cwd, stderr=stderr, stdout=stdout + ) + + +class Command: + def __init__(self) -> None: + self.processes: List[subprocess.Popen] = [] + + def run( + self, + command: List[str], + extra_env: Dict[str, str] = {}, + stdin: _FILE = None, + stdout: _FILE = None, + stderr: _FILE = None, + text: bool = True, + ) -> subprocess.Popen: + env = os.environ.copy() + env.update(extra_env) + # We start a new session here so that we can than more reliably kill all childs as well + p = subprocess.Popen( + command, + env=env, + start_new_session=True, + stdout=stdout, + stderr=stderr, + stdin=stdin, + text=text, + ) + self.processes.append(p) + return p + + def terminate(self) -> None: + # Stop in reverse order in case there are dependencies. + # We just kill all processes as quickly as possible because we don't + # care about corrupted state and want to make tests fasts. + for p in reversed(self.processes): + try: + os.killpg(os.getpgid(p.pid), signal.SIGKILL) + except OSError: + pass + + +@pytest.fixture +def command() -> Iterator[Command]: + """ + Starts a background command. The process is automatically terminated in the end. + >>> p = command.run(["some", "daemon"]) + >>> print(p.pid) + """ + c = Command() + try: + yield c + finally: + c.terminate() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..2eecc6b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 + +pytest_plugins = [ + "command", + "root", + "systemd_vault", + "tempdir", +] diff --git a/tests/root.py b/tests/root.py new file mode 100644 index 0000000..087c307 --- /dev/null +++ b/tests/root.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 + +from pathlib import Path + +import pytest + +TEST_ROOT = Path(__file__).parent.resolve() +PROJECT_ROOT = TEST_ROOT.parent + + +@pytest.fixture +def test_root() -> Path: + """ + Root directory of the tests + """ + return TEST_ROOT + + +@pytest.fixture +def project_root() -> Path: + """ + Root directory of the tests + """ + return PROJECT_ROOT diff --git a/tests/systemd_vault.py b/tests/systemd_vault.py new file mode 100644 index 0000000..b350879 --- /dev/null +++ b/tests/systemd_vault.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 + +import os +import pytest +from pathlib import Path +from typing import Optional +from command import run + +BIN: Optional[Path] = None + + +@pytest.fixture +def systemd_vault(project_root: Path) -> Path: + global BIN + if BIN: + return BIN + bin = os.environ.get("SYSTEMD_VAULT_BIN") + if bin: + BIN = Path(bin) + return BIN + run(["go", "build", str(project_root)]) + BIN = project_root / "systemd-vault" + return BIN diff --git a/tests/tempdir.py b/tests/tempdir.py new file mode 100644 index 0000000..2a4adc3 --- /dev/null +++ b/tests/tempdir.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 + +import pytest +from tempfile import TemporaryDirectory +from pathlib import Path +from typing import Iterator + + +@pytest.fixture +def tempdir() -> Iterator[Path]: + with TemporaryDirectory() as dir: + yield Path(dir) diff --git a/tests/test_service.py b/tests/test_service.py new file mode 100644 index 0000000..fdd7157 --- /dev/null +++ b/tests/test_service.py @@ -0,0 +1,84 @@ +import subprocess + +from dataclasses import dataclass +from command import Command, run +from pathlib import Path + +import string +import random + + +def rand_word(n: int) -> str: + return "".join(random.choices(string.ascii_uppercase + string.digits, k=n)) + + +@dataclass +class Service: + name: str + secret_name: str + secret_path: Path + + +def random_service(secrets_dir: Path) -> Service: + service = f"test-service-{rand_word(8)}.service" + secret_name = "foo" + secret = f"{service}-{secret_name}" + secret_path = secrets_dir / secret + return Service(service, secret_name, secret_path) + + +def test_service(systemd_vault: Path, command: Command, tempdir: Path) -> None: + secrets_dir = tempdir / "secrets" + sock = tempdir / "sock" + command.run([str(systemd_vault), "-secrets", str(secrets_dir), "-sock", str(sock)]) + import time + + while not sock.exists(): + time.sleep(1) + + service = random_service(secrets_dir) + service.secret_path.write_text("foo") + + # should not block + out = run( + [ + "systemd-run", + "-u", + service.name, + "--collect", + "--user", + "-p", + f"LoadCredential={service.secret_name}:{sock}", + "--wait", + "--pipe", + "cat", + "${CREDENTIALS_DIRECTORY}/" + service.secret_name, + ], + stdout=subprocess.PIPE, + ) + assert out.stdout == "foo" + assert out.returncode == 0 + + service = random_service(secrets_dir) + + proc = command.run( + [ + "systemd-run", + "-u", + service.name, + "--collect", + "--user", + "-p", + f"LoadCredential={service.secret_name}:{sock}", + "--wait", + "--pipe", + "cat", + "${CREDENTIALS_DIRECTORY}/" + service.secret_name, + ], + stdout=subprocess.PIPE, + ) + time.sleep(0.1) + assert proc.poll() is None, "service should block for secret" + service.secret_path.write_text("foo") + assert proc.stdout is not None and proc.stdout.read() == "foo" + assert proc.wait() == 0 diff --git a/watcher.go b/watcher.go new file mode 100644 index 0000000..865335b --- /dev/null +++ b/watcher.go @@ -0,0 +1,193 @@ +package main + +import ( + "errors" + "fmt" + "io" + "log" + "net" + "os" + "path/filepath" + "strings" + "syscall" + "unsafe" +) + +type inotifyRequest struct { + filename string + conn *net.UnixConn +} + +type watcher struct { + requests chan inotifyRequest + connectionEvents chan int + dir string +} + +type connection struct { + fd int + connection *net.UnixConn +} + +type watch struct { + connections []connection +} + +func readEvents(inotifyFd int, events chan string) { + var buf [syscall.SizeofInotifyEvent * 4096]byte // Buffer for a maximum of 4096 raw events + for { + n, err := syscall.Read(inotifyFd, buf[:]) + // If a signal interrupted execution, see if we've been asked to close, and try again. + // http://man7.org/linux/man-pages/man7/signal.7.html : + // "Before Linux 3.8, reads from an inotify(7) file descriptor were not restartable" + if errors.Is(err, syscall.EINTR) { + continue + } + + if n < syscall.SizeofInotifyEvent { + if n == 0 { + log.Fatalf("notify: read EOF from inotify (cause: %v)", err) + } else if n < 0 { + log.Fatalf("notify: Received error while reading from inotify: %v", err) + } else { + log.Fatal("notify: short read in readEvents()") + } + continue + } + var offset uint32 + for offset <= uint32(n-syscall.SizeofInotifyEvent) { + // Point "raw" to the event in the buffer + raw := (*syscall.InotifyEvent)(unsafe.Pointer(&buf[offset])) + + mask := uint32(raw.Mask) + nameLen := uint32(raw.Len) + + if mask&syscall.IN_Q_OVERFLOW != 0 { + // TODO Re-scan all files in this case + log.Fatal("Overflow in inotify") + } + if nameLen > 0 { + // Point "bytes" at the first byte of the filename + bytes := (*[syscall.PathMax]byte)(unsafe.Pointer(&buf[offset+syscall.SizeofInotifyEvent])) + // The filename is padded with NULL bytes. TrimRight() gets rid of those. + fname := strings.TrimRight(string(bytes[0:nameLen]), "\000") + log.Printf("Detected added file: %s", fname) + events <- fname + } + + // Move to the next event in the buffer + offset += syscall.SizeofInotifyEvent + nameLen + } + } +} + +func connFd(conn *net.UnixConn) (int, error) { + file, err := conn.File() + if err != nil { + return -1, err + } + return int(file.Fd()), nil +} + +func (s *server) watch(inotifyFd int) { + connsForPath := make(map[string][]connection) + fdToPath := make(map[int]string) + + defer syscall.Close(inotifyFd) + + fsEvents := make(chan string) + go readEvents(inotifyFd, fsEvents) + for { + select { + case req, ok := <-s.inotifyRequests: + if !ok { + return + } + fd, err := connFd(req.conn) + if err != nil { + log.Println("Received inotify request for closed connection") + continue + } + fdToPath[fd] = req.filename + conns, ok := connsForPath[req.filename] + if ok { + conns = append(conns, connection{fd, req.conn}) + continue + } + + connsForPath[req.filename] = []connection{{fd, req.conn}} + case fname, ok := <-fsEvents: + if !ok { + return + } + conns := connsForPath[fname] + if conns == nil { + log.Printf("Ignore unknown file: %s", fname) + continue + } + delete(connsForPath, fname) + + for _, conn := range conns { + f, err := os.Open(filepath.Join(s.SecretDir, fname)) + defer f.Close() + defer delete(fdToPath, conn.fd) + + if err == nil { + _, err := io.Copy(conn.connection, f) + if err == nil { + log.Printf("Served %s to %s", fname, conn.connection.RemoteAddr().String()) + } else { + log.Printf("Failed to send secret: %v", err) + } + s.epollDelete(conn.fd) + if err := syscall.Shutdown(conn.fd, syscall.SHUT_RDWR); err != nil { + log.Printf("Failed to shutdown socket: %v", err) + } + } else { + log.Printf("Failed to open secret: %v", err) + } + } + case fd, ok := <-s.connectionClosed: + if !ok { + return + } + path := fdToPath[fd] + delete(fdToPath, fd) + conns := connsForPath[path] + if conns == nil { + // watcher has been already deregistered + return + } + for idx, c := range conns { + if c.fd == fd { + last := len(conns) - 1 + conns[idx] = conns[last] + conns = conns[:last] + + c.connection.Close() + break + } + } + if len(conns) == 0 { + delete(connsForPath, path) + } + } + } +} + +func (s *server) setupWatcher(dir string) error { + fd, err := syscall.InotifyInit1(syscall.IN_CLOEXEC) + if err != nil { + return fmt.Errorf("Failed to initialize inotify: %v", err) + } + flags := uint32(syscall.IN_CREATE | syscall.IN_MOVED_TO | syscall.IN_ONLYDIR) + res := os.MkdirAll(dir, 0700) + if err != nil && !os.IsNotExist(res) { + return fmt.Errorf("Failed to create secret directory: %v", err) + } + if _, err = syscall.InotifyAddWatch(fd, dir, flags); err != nil { + return fmt.Errorf("Failed to initialize inotify on secret directory %s: %v", dir, err) + } + go s.watch(fd) + return nil +}