mvp + tests
parent
9e4fd1f36b
commit
63bcc48e31
@ -0,0 +1,9 @@
|
|||||||
|
.DS_Store
|
||||||
|
.idea
|
||||||
|
*.log
|
||||||
|
systemd-vault
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
@ -0,0 +1,2 @@
|
|||||||
|
vault: vault server -dev
|
||||||
|
systemd-vault: go run .
|
@ -0,0 +1,44 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
EPOLLET = 1 << 31
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *server) epollWatch(fd int) error {
|
||||||
|
event := syscall.EpollEvent{
|
||||||
|
Fd: int32(fd),
|
||||||
|
Events: syscall.EPOLLHUP | EPOLLET,
|
||||||
|
}
|
||||||
|
return syscall.EpollCtl(s.epfd, syscall.EPOLL_CTL_ADD, fd, &event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) epollDelete(fd int) error {
|
||||||
|
return syscall.EpollCtl(s.epfd, syscall.EPOLL_CTL_DEL, fd, &syscall.EpollEvent{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) handleEpoll() {
|
||||||
|
events := make([]syscall.EpollEvent, 1024)
|
||||||
|
for {
|
||||||
|
n, errno := syscall.EpollWait(s.epfd, events, -1)
|
||||||
|
if n == -1 {
|
||||||
|
if errno == syscall.EINTR {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Fatalf("connection cleaner: epoll wait failed with %v", errno)
|
||||||
|
}
|
||||||
|
ready := events[:n]
|
||||||
|
for _, event := range ready {
|
||||||
|
if event.Events&(syscall.EPOLLHUP|syscall.EPOLLERR) != 0 {
|
||||||
|
s.epollDelete(int(event.Fd))
|
||||||
|
s.connectionClosed <- int(event.Fd)
|
||||||
|
} else {
|
||||||
|
log.Printf("Unhandled epoll event: %d for fd %d", event.Events, event.Fd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,3 +1,3 @@
|
|||||||
module github.com/numtide/vault-nixos
|
module github.com/numtide/systemd-vault
|
||||||
|
|
||||||
go 1.17
|
go 1.17
|
||||||
|
@ -1,9 +1,153 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type server struct {
|
||||||
|
Socket string
|
||||||
|
SecretDir string
|
||||||
|
epfd int
|
||||||
|
inotifyRequests chan inotifyRequest
|
||||||
|
connectionClosed chan int
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenSocket(path string) (*net.UnixListener, error) {
|
||||||
|
if err := syscall.Unlink(path); err != nil && !os.IsNotExist(err) {
|
||||||
|
return nil, fmt.Errorf("Cannot remove old socket: %v", err)
|
||||||
|
}
|
||||||
|
abs, err := filepath.Abs(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("'%s' is not a valid socket path: %v", path, err)
|
||||||
|
}
|
||||||
|
addr, err := net.ResolveUnixAddr("unix", abs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to resolv '%s' as a unix address: %v", abs, err)
|
||||||
|
}
|
||||||
|
listener, err := net.ListenUnix("unix", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to open socket at %s: %v", addr.Name, err)
|
||||||
|
}
|
||||||
|
return listener, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCredentialsAddr(addr string) (*string, *string, error) {
|
||||||
|
// Systemd stores metadata in its local unix address
|
||||||
|
fields := strings.Split(addr, "/")
|
||||||
|
if len(fields) != 4 || fields[1] != "unit" {
|
||||||
|
return nil, nil, fmt.Errorf("Address needs to match this format: @<random_hex>/unit/<service_name>/<secret_id>, got '%s'", addr)
|
||||||
|
}
|
||||||
|
return &fields[2], &fields[3], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) serveConnection(conn *net.UnixConn) {
|
||||||
|
shouldClose := true
|
||||||
|
defer func() {
|
||||||
|
if shouldClose {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
addr := conn.RemoteAddr().String()
|
||||||
|
unit, secret, err := parseCredentialsAddr(addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Received connection but remote unix address seems to be not from systemd: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Printf("Systemd requested secret for %s/%s", *unit, *secret)
|
||||||
|
secretName := *unit + "-" + *secret
|
||||||
|
secretPath := filepath.Join(s.SecretDir, secretName)
|
||||||
|
f, err := os.Open(secretPath)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
log.Printf("Block start until %s appears", secretPath)
|
||||||
|
shouldClose = false
|
||||||
|
fd, err := connFd(conn)
|
||||||
|
if err != nil {
|
||||||
|
// connection was closed while we trying to wait
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := s.epollWatch(fd); err != nil {
|
||||||
|
log.Printf("Cannot get setup epoll for unix socket: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.inotifyRequests <- inotifyRequest{filename: secretName, conn: conn}
|
||||||
|
return
|
||||||
|
} else if err != nil {
|
||||||
|
log.Printf("Cannot open secret %s/%s: %v", *unit, *secret, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
if _, err = io.Copy(conn, f); err != nil {
|
||||||
|
log.Printf("Failed to send secret: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func serveSecrets(s *server) error {
|
||||||
|
l, err := listenSocket(s.Socket)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to setup listening socket: %v", err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
log.Printf("Listening on %s", s.Socket)
|
||||||
|
go s.handleEpoll()
|
||||||
|
for {
|
||||||
|
conn, err := l.AcceptUnix()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error accepting unix connection: %v", err)
|
||||||
|
}
|
||||||
|
go s.serveConnection(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var secretDir, socketDir string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
defaultDir := os.Getenv("SYSTEMD_VAULT_SECRETS")
|
||||||
|
if defaultDir == "" {
|
||||||
|
defaultDir = "/run/systemd-vault"
|
||||||
|
}
|
||||||
|
flag.StringVar(&secretDir, "secrets", defaultDir, "directory where secrets are looked up")
|
||||||
|
|
||||||
|
defaultSock := os.Getenv("SYSTEMD_VAULT_SOCK")
|
||||||
|
if defaultSock == "" {
|
||||||
|
defaultSock = "/run/systemd-vault.sock"
|
||||||
|
}
|
||||||
|
flag.StringVar(&socketDir, "sock", defaultSock, "unix socket to listen to for systemd requests")
|
||||||
|
flag.Parse()
|
||||||
|
}
|
||||||
|
|
||||||
|
func createServer(secretDir string, socketDir string) (*server, error) {
|
||||||
|
epfd, err := syscall.EpollCreate1(syscall.EPOLL_CLOEXEC)
|
||||||
|
if epfd == -1 {
|
||||||
|
return nil, fmt.Errorf("failed to create epoll fd: %v", err)
|
||||||
|
}
|
||||||
|
s := &server{
|
||||||
|
Socket: socketDir,
|
||||||
|
SecretDir: secretDir,
|
||||||
|
epfd: epfd,
|
||||||
|
inotifyRequests: make(chan inotifyRequest),
|
||||||
|
connectionClosed: make(chan int),
|
||||||
|
}
|
||||||
|
if err := s.setupWatcher(secretDir); err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to setup file system watcher: %v", err)
|
||||||
|
}
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
fmt.Println("Hello world")
|
s, err := createServer(secretDir, socketDir)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to create server: %v", err)
|
||||||
|
}
|
||||||
|
if err := serveSecrets(s); err != nil {
|
||||||
|
log.Fatalf("Failed serve secrets: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,23 @@
|
|||||||
|
[wheel]
|
||||||
|
universal = 1
|
||||||
|
|
||||||
|
[pycodestyle]
|
||||||
|
max-line-length = 88
|
||||||
|
ignore = E501,E741,W503
|
||||||
|
|
||||||
|
[flake8]
|
||||||
|
max-line-length = 88
|
||||||
|
ignore = E501,E741,W503
|
||||||
|
exclude = .git,__pycache__,docs/source/conf.py,old,build,dist
|
||||||
|
|
||||||
|
[mypy]
|
||||||
|
warn_redundant_casts = true
|
||||||
|
disallow_untyped_calls = true
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
no_implicit_optional = true
|
||||||
|
|
||||||
|
[mypy-pytest.*]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
[isort]
|
||||||
|
profile = black
|
@ -0,0 +1,81 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
from typing import IO, Any, Dict, Iterator, List, Union
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
_DIR = Union[None, Path, str]
|
||||||
|
_FILE = Union[None, int, IO[Any]]
|
||||||
|
|
||||||
|
|
||||||
|
def run(
|
||||||
|
cmd: List[str],
|
||||||
|
text: bool = True,
|
||||||
|
check: bool = True,
|
||||||
|
cwd: _DIR = None,
|
||||||
|
stderr: _FILE = None,
|
||||||
|
stdout: _FILE = None,
|
||||||
|
) -> subprocess.CompletedProcess:
|
||||||
|
if cwd is not None:
|
||||||
|
print(f"cd {cwd}")
|
||||||
|
print("$ " + " ".join(cmd))
|
||||||
|
return subprocess.run(
|
||||||
|
cmd, text=text, check=check, cwd=cwd, stderr=stderr, stdout=stdout
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Command:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.processes: List[subprocess.Popen] = []
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
command: List[str],
|
||||||
|
extra_env: Dict[str, str] = {},
|
||||||
|
stdin: _FILE = None,
|
||||||
|
stdout: _FILE = None,
|
||||||
|
stderr: _FILE = None,
|
||||||
|
text: bool = True,
|
||||||
|
) -> subprocess.Popen:
|
||||||
|
env = os.environ.copy()
|
||||||
|
env.update(extra_env)
|
||||||
|
# We start a new session here so that we can than more reliably kill all childs as well
|
||||||
|
p = subprocess.Popen(
|
||||||
|
command,
|
||||||
|
env=env,
|
||||||
|
start_new_session=True,
|
||||||
|
stdout=stdout,
|
||||||
|
stderr=stderr,
|
||||||
|
stdin=stdin,
|
||||||
|
text=text,
|
||||||
|
)
|
||||||
|
self.processes.append(p)
|
||||||
|
return p
|
||||||
|
|
||||||
|
def terminate(self) -> None:
|
||||||
|
# Stop in reverse order in case there are dependencies.
|
||||||
|
# We just kill all processes as quickly as possible because we don't
|
||||||
|
# care about corrupted state and want to make tests fasts.
|
||||||
|
for p in reversed(self.processes):
|
||||||
|
try:
|
||||||
|
os.killpg(os.getpgid(p.pid), signal.SIGKILL)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def command() -> Iterator[Command]:
|
||||||
|
"""
|
||||||
|
Starts a background command. The process is automatically terminated in the end.
|
||||||
|
>>> p = command.run(["some", "daemon"])
|
||||||
|
>>> print(p.pid)
|
||||||
|
"""
|
||||||
|
c = Command()
|
||||||
|
try:
|
||||||
|
yield c
|
||||||
|
finally:
|
||||||
|
c.terminate()
|
@ -0,0 +1,8 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
pytest_plugins = [
|
||||||
|
"command",
|
||||||
|
"root",
|
||||||
|
"systemd_vault",
|
||||||
|
"tempdir",
|
||||||
|
]
|
@ -0,0 +1,24 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
TEST_ROOT = Path(__file__).parent.resolve()
|
||||||
|
PROJECT_ROOT = TEST_ROOT.parent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_root() -> Path:
|
||||||
|
"""
|
||||||
|
Root directory of the tests
|
||||||
|
"""
|
||||||
|
return TEST_ROOT
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def project_root() -> Path:
|
||||||
|
"""
|
||||||
|
Root directory of the tests
|
||||||
|
"""
|
||||||
|
return PROJECT_ROOT
|
@ -0,0 +1,23 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
from command import run
|
||||||
|
|
||||||
|
BIN: Optional[Path] = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def systemd_vault(project_root: Path) -> Path:
|
||||||
|
global BIN
|
||||||
|
if BIN:
|
||||||
|
return BIN
|
||||||
|
bin = os.environ.get("SYSTEMD_VAULT_BIN")
|
||||||
|
if bin:
|
||||||
|
BIN = Path(bin)
|
||||||
|
return BIN
|
||||||
|
run(["go", "build", str(project_root)])
|
||||||
|
BIN = project_root / "systemd-vault"
|
||||||
|
return BIN
|
@ -0,0 +1,12 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tempdir() -> Iterator[Path]:
|
||||||
|
with TemporaryDirectory() as dir:
|
||||||
|
yield Path(dir)
|
@ -0,0 +1,84 @@
|
|||||||
|
import subprocess
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from command import Command, run
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import string
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def rand_word(n: int) -> str:
|
||||||
|
return "".join(random.choices(string.ascii_uppercase + string.digits, k=n))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Service:
|
||||||
|
name: str
|
||||||
|
secret_name: str
|
||||||
|
secret_path: Path
|
||||||
|
|
||||||
|
|
||||||
|
def random_service(secrets_dir: Path) -> Service:
|
||||||
|
service = f"test-service-{rand_word(8)}.service"
|
||||||
|
secret_name = "foo"
|
||||||
|
secret = f"{service}-{secret_name}"
|
||||||
|
secret_path = secrets_dir / secret
|
||||||
|
return Service(service, secret_name, secret_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_service(systemd_vault: Path, command: Command, tempdir: Path) -> None:
|
||||||
|
secrets_dir = tempdir / "secrets"
|
||||||
|
sock = tempdir / "sock"
|
||||||
|
command.run([str(systemd_vault), "-secrets", str(secrets_dir), "-sock", str(sock)])
|
||||||
|
import time
|
||||||
|
|
||||||
|
while not sock.exists():
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
service = random_service(secrets_dir)
|
||||||
|
service.secret_path.write_text("foo")
|
||||||
|
|
||||||
|
# should not block
|
||||||
|
out = run(
|
||||||
|
[
|
||||||
|
"systemd-run",
|
||||||
|
"-u",
|
||||||
|
service.name,
|
||||||
|
"--collect",
|
||||||
|
"--user",
|
||||||
|
"-p",
|
||||||
|
f"LoadCredential={service.secret_name}:{sock}",
|
||||||
|
"--wait",
|
||||||
|
"--pipe",
|
||||||
|
"cat",
|
||||||
|
"${CREDENTIALS_DIRECTORY}/" + service.secret_name,
|
||||||
|
],
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
assert out.stdout == "foo"
|
||||||
|
assert out.returncode == 0
|
||||||
|
|
||||||
|
service = random_service(secrets_dir)
|
||||||
|
|
||||||
|
proc = command.run(
|
||||||
|
[
|
||||||
|
"systemd-run",
|
||||||
|
"-u",
|
||||||
|
service.name,
|
||||||
|
"--collect",
|
||||||
|
"--user",
|
||||||
|
"-p",
|
||||||
|
f"LoadCredential={service.secret_name}:{sock}",
|
||||||
|
"--wait",
|
||||||
|
"--pipe",
|
||||||
|
"cat",
|
||||||
|
"${CREDENTIALS_DIRECTORY}/" + service.secret_name,
|
||||||
|
],
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
time.sleep(0.1)
|
||||||
|
assert proc.poll() is None, "service should block for secret"
|
||||||
|
service.secret_path.write_text("foo")
|
||||||
|
assert proc.stdout is not None and proc.stdout.read() == "foo"
|
||||||
|
assert proc.wait() == 0
|
@ -0,0 +1,193 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type inotifyRequest struct {
|
||||||
|
filename string
|
||||||
|
conn *net.UnixConn
|
||||||
|
}
|
||||||
|
|
||||||
|
type watcher struct {
|
||||||
|
requests chan inotifyRequest
|
||||||
|
connectionEvents chan int
|
||||||
|
dir string
|
||||||
|
}
|
||||||
|
|
||||||
|
type connection struct {
|
||||||
|
fd int
|
||||||
|
connection *net.UnixConn
|
||||||
|
}
|
||||||
|
|
||||||
|
type watch struct {
|
||||||
|
connections []connection
|
||||||
|
}
|
||||||
|
|
||||||
|
func readEvents(inotifyFd int, events chan string) {
|
||||||
|
var buf [syscall.SizeofInotifyEvent * 4096]byte // Buffer for a maximum of 4096 raw events
|
||||||
|
for {
|
||||||
|
n, err := syscall.Read(inotifyFd, buf[:])
|
||||||
|
// If a signal interrupted execution, see if we've been asked to close, and try again.
|
||||||
|
// http://man7.org/linux/man-pages/man7/signal.7.html :
|
||||||
|
// "Before Linux 3.8, reads from an inotify(7) file descriptor were not restartable"
|
||||||
|
if errors.Is(err, syscall.EINTR) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if n < syscall.SizeofInotifyEvent {
|
||||||
|
if n == 0 {
|
||||||
|
log.Fatalf("notify: read EOF from inotify (cause: %v)", err)
|
||||||
|
} else if n < 0 {
|
||||||
|
log.Fatalf("notify: Received error while reading from inotify: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Fatal("notify: short read in readEvents()")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var offset uint32
|
||||||
|
for offset <= uint32(n-syscall.SizeofInotifyEvent) {
|
||||||
|
// Point "raw" to the event in the buffer
|
||||||
|
raw := (*syscall.InotifyEvent)(unsafe.Pointer(&buf[offset]))
|
||||||
|
|
||||||
|
mask := uint32(raw.Mask)
|
||||||
|
nameLen := uint32(raw.Len)
|
||||||
|
|
||||||
|
if mask&syscall.IN_Q_OVERFLOW != 0 {
|
||||||
|
// TODO Re-scan all files in this case
|
||||||
|
log.Fatal("Overflow in inotify")
|
||||||
|
}
|
||||||
|
if nameLen > 0 {
|
||||||
|
// Point "bytes" at the first byte of the filename
|
||||||
|
bytes := (*[syscall.PathMax]byte)(unsafe.Pointer(&buf[offset+syscall.SizeofInotifyEvent]))
|
||||||
|
// The filename is padded with NULL bytes. TrimRight() gets rid of those.
|
||||||
|
fname := strings.TrimRight(string(bytes[0:nameLen]), "\000")
|
||||||
|
log.Printf("Detected added file: %s", fname)
|
||||||
|
events <- fname
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move to the next event in the buffer
|
||||||
|
offset += syscall.SizeofInotifyEvent + nameLen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func connFd(conn *net.UnixConn) (int, error) {
|
||||||
|
file, err := conn.File()
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return int(file.Fd()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) watch(inotifyFd int) {
|
||||||
|
connsForPath := make(map[string][]connection)
|
||||||
|
fdToPath := make(map[int]string)
|
||||||
|
|
||||||
|
defer syscall.Close(inotifyFd)
|
||||||
|
|
||||||
|
fsEvents := make(chan string)
|
||||||
|
go readEvents(inotifyFd, fsEvents)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case req, ok := <-s.inotifyRequests:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fd, err := connFd(req.conn)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Received inotify request for closed connection")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fdToPath[fd] = req.filename
|
||||||
|
conns, ok := connsForPath[req.filename]
|
||||||
|
if ok {
|
||||||
|
conns = append(conns, connection{fd, req.conn})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
connsForPath[req.filename] = []connection{{fd, req.conn}}
|
||||||
|
case fname, ok := <-fsEvents:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conns := connsForPath[fname]
|
||||||
|
if conns == nil {
|
||||||
|
log.Printf("Ignore unknown file: %s", fname)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(connsForPath, fname)
|
||||||
|
|
||||||
|
for _, conn := range conns {
|
||||||
|
f, err := os.Open(filepath.Join(s.SecretDir, fname))
|
||||||
|
defer f.Close()
|
||||||
|
defer delete(fdToPath, conn.fd)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
_, err := io.Copy(conn.connection, f)
|
||||||
|
if err == nil {
|
||||||
|
log.Printf("Served %s to %s", fname, conn.connection.RemoteAddr().String())
|
||||||
|
} else {
|
||||||
|
log.Printf("Failed to send secret: %v", err)
|
||||||
|
}
|
||||||
|
s.epollDelete(conn.fd)
|
||||||
|
if err := syscall.Shutdown(conn.fd, syscall.SHUT_RDWR); err != nil {
|
||||||
|
log.Printf("Failed to shutdown socket: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Printf("Failed to open secret: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case fd, ok := <-s.connectionClosed:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
path := fdToPath[fd]
|
||||||
|
delete(fdToPath, fd)
|
||||||
|
conns := connsForPath[path]
|
||||||
|
if conns == nil {
|
||||||
|
// watcher has been already deregistered
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for idx, c := range conns {
|
||||||
|
if c.fd == fd {
|
||||||
|
last := len(conns) - 1
|
||||||
|
conns[idx] = conns[last]
|
||||||
|
conns = conns[:last]
|
||||||
|
|
||||||
|
c.connection.Close()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(conns) == 0 {
|
||||||
|
delete(connsForPath, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *server) setupWatcher(dir string) error {
|
||||||
|
fd, err := syscall.InotifyInit1(syscall.IN_CLOEXEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to initialize inotify: %v", err)
|
||||||
|
}
|
||||||
|
flags := uint32(syscall.IN_CREATE | syscall.IN_MOVED_TO | syscall.IN_ONLYDIR)
|
||||||
|
res := os.MkdirAll(dir, 0700)
|
||||||
|
if err != nil && !os.IsNotExist(res) {
|
||||||
|
return fmt.Errorf("Failed to create secret directory: %v", err)
|
||||||
|
}
|
||||||
|
if _, err = syscall.InotifyAddWatch(fd, dir, flags); err != nil {
|
||||||
|
return fmt.Errorf("Failed to initialize inotify on secret directory %s: %v", dir, err)
|
||||||
|
}
|
||||||
|
go s.watch(fd)
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
Reference in New Issue