mvp + tests

main
Jörg Thalheim 2 years ago
parent 9e4fd1f36b
commit 63bcc48e31
No known key found for this signature in database

9
.gitignore vendored

@ -0,0 +1,9 @@
.DS_Store
.idea
*.log
systemd-vault
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

@ -0,0 +1,2 @@
vault: vault server -dev
systemd-vault: go run .

@ -4,5 +4,7 @@ mkShell {
go go
vault vault
python3.pkgs.pytest python3.pkgs.pytest
golangci-lint
hivemind
]; ];
} }

@ -0,0 +1,44 @@
package main
import (
"log"
"syscall"
)
const (
EPOLLET = 1 << 31
)
func (s *server) epollWatch(fd int) error {
event := syscall.EpollEvent{
Fd: int32(fd),
Events: syscall.EPOLLHUP | EPOLLET,
}
return syscall.EpollCtl(s.epfd, syscall.EPOLL_CTL_ADD, fd, &event)
}
func (s *server) epollDelete(fd int) error {
return syscall.EpollCtl(s.epfd, syscall.EPOLL_CTL_DEL, fd, &syscall.EpollEvent{})
}
func (s *server) handleEpoll() {
events := make([]syscall.EpollEvent, 1024)
for {
n, errno := syscall.EpollWait(s.epfd, events, -1)
if n == -1 {
if errno == syscall.EINTR {
continue
}
log.Fatalf("connection cleaner: epoll wait failed with %v", errno)
}
ready := events[:n]
for _, event := range ready {
if event.Events&(syscall.EPOLLHUP|syscall.EPOLLERR) != 0 {
s.epollDelete(int(event.Fd))
s.connectionClosed <- int(event.Fd)
} else {
log.Printf("Unhandled epoll event: %d for fd %d", event.Events, event.Fd)
}
}
}
}

@ -1,3 +1,3 @@
module github.com/numtide/vault-nixos module github.com/numtide/systemd-vault
go 1.17 go 1.17

@ -1,9 +1,153 @@
package main package main
import ( import (
"flag"
"fmt" "fmt"
"io"
"log"
"net"
"os"
"path/filepath"
"strings"
"syscall"
) )
type server struct {
Socket string
SecretDir string
epfd int
inotifyRequests chan inotifyRequest
connectionClosed chan int
}
func listenSocket(path string) (*net.UnixListener, error) {
if err := syscall.Unlink(path); err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("Cannot remove old socket: %v", err)
}
abs, err := filepath.Abs(path)
if err != nil {
return nil, fmt.Errorf("'%s' is not a valid socket path: %v", path, err)
}
addr, err := net.ResolveUnixAddr("unix", abs)
if err != nil {
return nil, fmt.Errorf("Failed to resolv '%s' as a unix address: %v", abs, err)
}
listener, err := net.ListenUnix("unix", addr)
if err != nil {
return nil, fmt.Errorf("Failed to open socket at %s: %v", addr.Name, err)
}
return listener, nil
}
func parseCredentialsAddr(addr string) (*string, *string, error) {
// Systemd stores metadata in its local unix address
fields := strings.Split(addr, "/")
if len(fields) != 4 || fields[1] != "unit" {
return nil, nil, fmt.Errorf("Address needs to match this format: @<random_hex>/unit/<service_name>/<secret_id>, got '%s'", addr)
}
return &fields[2], &fields[3], nil
}
func (s *server) serveConnection(conn *net.UnixConn) {
shouldClose := true
defer func() {
if shouldClose {
conn.Close()
}
}()
addr := conn.RemoteAddr().String()
unit, secret, err := parseCredentialsAddr(addr)
if err != nil {
log.Printf("Received connection but remote unix address seems to be not from systemd: %v", err)
return
}
log.Printf("Systemd requested secret for %s/%s", *unit, *secret)
secretName := *unit + "-" + *secret
secretPath := filepath.Join(s.SecretDir, secretName)
f, err := os.Open(secretPath)
if os.IsNotExist(err) {
log.Printf("Block start until %s appears", secretPath)
shouldClose = false
fd, err := connFd(conn)
if err != nil {
// connection was closed while we trying to wait
return
}
if err := s.epollWatch(fd); err != nil {
log.Printf("Cannot get setup epoll for unix socket: %s", err)
return
}
s.inotifyRequests <- inotifyRequest{filename: secretName, conn: conn}
return
} else if err != nil {
log.Printf("Cannot open secret %s/%s: %v", *unit, *secret, err)
return
}
defer f.Close()
if _, err = io.Copy(conn, f); err != nil {
log.Printf("Failed to send secret: %v", err)
}
}
func serveSecrets(s *server) error {
l, err := listenSocket(s.Socket)
if err != nil {
return fmt.Errorf("Failed to setup listening socket: %v", err)
}
defer l.Close()
log.Printf("Listening on %s", s.Socket)
go s.handleEpoll()
for {
conn, err := l.AcceptUnix()
if err != nil {
return fmt.Errorf("Error accepting unix connection: %v", err)
}
go s.serveConnection(conn)
}
}
var secretDir, socketDir string
func init() {
defaultDir := os.Getenv("SYSTEMD_VAULT_SECRETS")
if defaultDir == "" {
defaultDir = "/run/systemd-vault"
}
flag.StringVar(&secretDir, "secrets", defaultDir, "directory where secrets are looked up")
defaultSock := os.Getenv("SYSTEMD_VAULT_SOCK")
if defaultSock == "" {
defaultSock = "/run/systemd-vault.sock"
}
flag.StringVar(&socketDir, "sock", defaultSock, "unix socket to listen to for systemd requests")
flag.Parse()
}
func createServer(secretDir string, socketDir string) (*server, error) {
epfd, err := syscall.EpollCreate1(syscall.EPOLL_CLOEXEC)
if epfd == -1 {
return nil, fmt.Errorf("failed to create epoll fd: %v", err)
}
s := &server{
Socket: socketDir,
SecretDir: secretDir,
epfd: epfd,
inotifyRequests: make(chan inotifyRequest),
connectionClosed: make(chan int),
}
if err := s.setupWatcher(secretDir); err != nil {
return nil, fmt.Errorf("Failed to setup file system watcher: %v", err)
}
return s, nil
}
func main() { func main() {
fmt.Println("Hello world") s, err := createServer(secretDir, socketDir)
if err != nil {
log.Fatalf("Failed to create server: %v", err)
}
if err := serveSecrets(s); err != nil {
log.Fatalf("Failed serve secrets: %v", err)
}
} }

@ -0,0 +1,23 @@
[wheel]
universal = 1
[pycodestyle]
max-line-length = 88
ignore = E501,E741,W503
[flake8]
max-line-length = 88
ignore = E501,E741,W503
exclude = .git,__pycache__,docs/source/conf.py,old,build,dist
[mypy]
warn_redundant_casts = true
disallow_untyped_calls = true
disallow_untyped_defs = true
no_implicit_optional = true
[mypy-pytest.*]
ignore_missing_imports = True
[isort]
profile = black

@ -0,0 +1,81 @@
#!/usr/bin/env python3
import os
import signal
import subprocess
from typing import IO, Any, Dict, Iterator, List, Union
from pathlib import Path
import pytest
_DIR = Union[None, Path, str]
_FILE = Union[None, int, IO[Any]]
def run(
cmd: List[str],
text: bool = True,
check: bool = True,
cwd: _DIR = None,
stderr: _FILE = None,
stdout: _FILE = None,
) -> subprocess.CompletedProcess:
if cwd is not None:
print(f"cd {cwd}")
print("$ " + " ".join(cmd))
return subprocess.run(
cmd, text=text, check=check, cwd=cwd, stderr=stderr, stdout=stdout
)
class Command:
def __init__(self) -> None:
self.processes: List[subprocess.Popen] = []
def run(
self,
command: List[str],
extra_env: Dict[str, str] = {},
stdin: _FILE = None,
stdout: _FILE = None,
stderr: _FILE = None,
text: bool = True,
) -> subprocess.Popen:
env = os.environ.copy()
env.update(extra_env)
# We start a new session here so that we can than more reliably kill all childs as well
p = subprocess.Popen(
command,
env=env,
start_new_session=True,
stdout=stdout,
stderr=stderr,
stdin=stdin,
text=text,
)
self.processes.append(p)
return p
def terminate(self) -> None:
# Stop in reverse order in case there are dependencies.
# We just kill all processes as quickly as possible because we don't
# care about corrupted state and want to make tests fasts.
for p in reversed(self.processes):
try:
os.killpg(os.getpgid(p.pid), signal.SIGKILL)
except OSError:
pass
@pytest.fixture
def command() -> Iterator[Command]:
"""
Starts a background command. The process is automatically terminated in the end.
>>> p = command.run(["some", "daemon"])
>>> print(p.pid)
"""
c = Command()
try:
yield c
finally:
c.terminate()

@ -0,0 +1,8 @@
#!/usr/bin/env python3
pytest_plugins = [
"command",
"root",
"systemd_vault",
"tempdir",
]

@ -0,0 +1,24 @@
#!/usr/bin/env python3
from pathlib import Path
import pytest
TEST_ROOT = Path(__file__).parent.resolve()
PROJECT_ROOT = TEST_ROOT.parent
@pytest.fixture
def test_root() -> Path:
"""
Root directory of the tests
"""
return TEST_ROOT
@pytest.fixture
def project_root() -> Path:
"""
Root directory of the tests
"""
return PROJECT_ROOT

@ -0,0 +1,23 @@
#!/usr/bin/env python3
import os
import pytest
from pathlib import Path
from typing import Optional
from command import run
BIN: Optional[Path] = None
@pytest.fixture
def systemd_vault(project_root: Path) -> Path:
global BIN
if BIN:
return BIN
bin = os.environ.get("SYSTEMD_VAULT_BIN")
if bin:
BIN = Path(bin)
return BIN
run(["go", "build", str(project_root)])
BIN = project_root / "systemd-vault"
return BIN

@ -0,0 +1,12 @@
#!/usr/bin/env python3
import pytest
from tempfile import TemporaryDirectory
from pathlib import Path
from typing import Iterator
@pytest.fixture
def tempdir() -> Iterator[Path]:
with TemporaryDirectory() as dir:
yield Path(dir)

@ -0,0 +1,84 @@
import subprocess
from dataclasses import dataclass
from command import Command, run
from pathlib import Path
import string
import random
def rand_word(n: int) -> str:
return "".join(random.choices(string.ascii_uppercase + string.digits, k=n))
@dataclass
class Service:
name: str
secret_name: str
secret_path: Path
def random_service(secrets_dir: Path) -> Service:
service = f"test-service-{rand_word(8)}.service"
secret_name = "foo"
secret = f"{service}-{secret_name}"
secret_path = secrets_dir / secret
return Service(service, secret_name, secret_path)
def test_service(systemd_vault: Path, command: Command, tempdir: Path) -> None:
secrets_dir = tempdir / "secrets"
sock = tempdir / "sock"
command.run([str(systemd_vault), "-secrets", str(secrets_dir), "-sock", str(sock)])
import time
while not sock.exists():
time.sleep(1)
service = random_service(secrets_dir)
service.secret_path.write_text("foo")
# should not block
out = run(
[
"systemd-run",
"-u",
service.name,
"--collect",
"--user",
"-p",
f"LoadCredential={service.secret_name}:{sock}",
"--wait",
"--pipe",
"cat",
"${CREDENTIALS_DIRECTORY}/" + service.secret_name,
],
stdout=subprocess.PIPE,
)
assert out.stdout == "foo"
assert out.returncode == 0
service = random_service(secrets_dir)
proc = command.run(
[
"systemd-run",
"-u",
service.name,
"--collect",
"--user",
"-p",
f"LoadCredential={service.secret_name}:{sock}",
"--wait",
"--pipe",
"cat",
"${CREDENTIALS_DIRECTORY}/" + service.secret_name,
],
stdout=subprocess.PIPE,
)
time.sleep(0.1)
assert proc.poll() is None, "service should block for secret"
service.secret_path.write_text("foo")
assert proc.stdout is not None and proc.stdout.read() == "foo"
assert proc.wait() == 0

@ -0,0 +1,193 @@
package main
import (
"errors"
"fmt"
"io"
"log"
"net"
"os"
"path/filepath"
"strings"
"syscall"
"unsafe"
)
type inotifyRequest struct {
filename string
conn *net.UnixConn
}
type watcher struct {
requests chan inotifyRequest
connectionEvents chan int
dir string
}
type connection struct {
fd int
connection *net.UnixConn
}
type watch struct {
connections []connection
}
func readEvents(inotifyFd int, events chan string) {
var buf [syscall.SizeofInotifyEvent * 4096]byte // Buffer for a maximum of 4096 raw events
for {
n, err := syscall.Read(inotifyFd, buf[:])
// If a signal interrupted execution, see if we've been asked to close, and try again.
// http://man7.org/linux/man-pages/man7/signal.7.html :
// "Before Linux 3.8, reads from an inotify(7) file descriptor were not restartable"
if errors.Is(err, syscall.EINTR) {
continue
}
if n < syscall.SizeofInotifyEvent {
if n == 0 {
log.Fatalf("notify: read EOF from inotify (cause: %v)", err)
} else if n < 0 {
log.Fatalf("notify: Received error while reading from inotify: %v", err)
} else {
log.Fatal("notify: short read in readEvents()")
}
continue
}
var offset uint32
for offset <= uint32(n-syscall.SizeofInotifyEvent) {
// Point "raw" to the event in the buffer
raw := (*syscall.InotifyEvent)(unsafe.Pointer(&buf[offset]))
mask := uint32(raw.Mask)
nameLen := uint32(raw.Len)
if mask&syscall.IN_Q_OVERFLOW != 0 {
// TODO Re-scan all files in this case
log.Fatal("Overflow in inotify")
}
if nameLen > 0 {
// Point "bytes" at the first byte of the filename
bytes := (*[syscall.PathMax]byte)(unsafe.Pointer(&buf[offset+syscall.SizeofInotifyEvent]))
// The filename is padded with NULL bytes. TrimRight() gets rid of those.
fname := strings.TrimRight(string(bytes[0:nameLen]), "\000")
log.Printf("Detected added file: %s", fname)
events <- fname
}
// Move to the next event in the buffer
offset += syscall.SizeofInotifyEvent + nameLen
}
}
}
func connFd(conn *net.UnixConn) (int, error) {
file, err := conn.File()
if err != nil {
return -1, err
}
return int(file.Fd()), nil
}
func (s *server) watch(inotifyFd int) {
connsForPath := make(map[string][]connection)
fdToPath := make(map[int]string)
defer syscall.Close(inotifyFd)
fsEvents := make(chan string)
go readEvents(inotifyFd, fsEvents)
for {
select {
case req, ok := <-s.inotifyRequests:
if !ok {
return
}
fd, err := connFd(req.conn)
if err != nil {
log.Println("Received inotify request for closed connection")
continue
}
fdToPath[fd] = req.filename
conns, ok := connsForPath[req.filename]
if ok {
conns = append(conns, connection{fd, req.conn})
continue
}
connsForPath[req.filename] = []connection{{fd, req.conn}}
case fname, ok := <-fsEvents:
if !ok {
return
}
conns := connsForPath[fname]
if conns == nil {
log.Printf("Ignore unknown file: %s", fname)
continue
}
delete(connsForPath, fname)
for _, conn := range conns {
f, err := os.Open(filepath.Join(s.SecretDir, fname))
defer f.Close()
defer delete(fdToPath, conn.fd)
if err == nil {
_, err := io.Copy(conn.connection, f)
if err == nil {
log.Printf("Served %s to %s", fname, conn.connection.RemoteAddr().String())
} else {
log.Printf("Failed to send secret: %v", err)
}
s.epollDelete(conn.fd)
if err := syscall.Shutdown(conn.fd, syscall.SHUT_RDWR); err != nil {
log.Printf("Failed to shutdown socket: %v", err)
}
} else {
log.Printf("Failed to open secret: %v", err)
}
}
case fd, ok := <-s.connectionClosed:
if !ok {
return
}
path := fdToPath[fd]
delete(fdToPath, fd)
conns := connsForPath[path]
if conns == nil {
// watcher has been already deregistered
return
}
for idx, c := range conns {
if c.fd == fd {
last := len(conns) - 1
conns[idx] = conns[last]
conns = conns[:last]
c.connection.Close()
break
}
}
if len(conns) == 0 {
delete(connsForPath, path)
}
}
}
}
func (s *server) setupWatcher(dir string) error {
fd, err := syscall.InotifyInit1(syscall.IN_CLOEXEC)
if err != nil {
return fmt.Errorf("Failed to initialize inotify: %v", err)
}
flags := uint32(syscall.IN_CREATE | syscall.IN_MOVED_TO | syscall.IN_ONLYDIR)
res := os.MkdirAll(dir, 0700)
if err != nil && !os.IsNotExist(res) {
return fmt.Errorf("Failed to create secret directory: %v", err)
}
if _, err = syscall.InotifyAddWatch(fd, dir, flags); err != nil {
return fmt.Errorf("Failed to initialize inotify on secret directory %s: %v", dir, err)
}
go s.watch(fd)
return nil
}
Loading…
Cancel
Save