Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions server/control/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type ClientAuthenticator interface {
type ClientAuthentication []byte

type ClientRelays interface {
Directs(ctx context.Context, endpoint model.Endpoint, role model.Role, cert *x509.Certificate, auth ClientAuthentication,
Relays(ctx context.Context, endpoint model.Endpoint, role model.Role, cert *x509.Certificate, auth ClientAuthentication,
notify func(map[RelayID]*pbclient.Relay) error) error
}

Expand Down Expand Up @@ -744,9 +744,9 @@ func (s *clientStream) relay(ctx context.Context, req *pbclient.Request_Relay) e
g.Go(quicc.CancelStream(s.stream))

g.Go(func(ctx context.Context) error {
defer s.conn.logger.Debug("completed direct relay notify")
return s.conn.server.relays.Directs(ctx, endpoint, role, clientCert, s.conn.auth, func(relays map[RelayID]*pbclient.Relay) error {
s.conn.logger.Debug("updated direct relay list", "relays", len(relays))
defer s.conn.logger.Debug("completed relay notify")
return s.conn.server.relays.Relays(ctx, endpoint, role, clientCert, s.conn.auth, func(relays map[RelayID]*pbclient.Relay) error {
s.conn.logger.Debug("updated relay list", "relays", len(relays))
if err := proto.Write(s.stream, &pbclient.Response{
Relay: &pbclient.Response_Relays{
Relays: slices.Collect(maps.Values(relays)),
Expand Down
96 changes: 49 additions & 47 deletions server/control/relays.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,18 @@ func newRelayServer(
stores Stores,
logger *slog.Logger,
) (*relayServer, error) {
directs, err := stores.RelayDirects()
conns, err := stores.RelayConns()
if err != nil {
return nil, fmt.Errorf("relay directs store open: %w", err)
return nil, fmt.Errorf("relay conns store open: %w", err)
}

directsMsgs, directsOffset, err := directs.Snapshot()
connsMsgs, connsOffset, err := conns.Snapshot()
if err != nil {
return nil, fmt.Errorf("relay conns snapshot: %w", err)
}
directsCache := map[RelayID]directRelay{}
for _, msg := range directsMsgs {
directsCache[msg.Key.ID] = directRelay{
connsCache := map[RelayID]cachedRelay{}
for _, msg := range connsMsgs {
connsCache[msg.Key.ID] = cachedRelay{
auth: msg.Value.Authentication,
authSealKey: msg.Value.AuthenticationSealKey,
template: &pbclient.Relay{
Expand Down Expand Up @@ -112,9 +112,9 @@ func newRelayServer(

reconnect: &reconnectToken{[32]byte(serverSecret.Bytes)},

directs: directs,
directsCache: directsCache,
directsOffset: directsOffset,
conns: conns,
connsCache: connsCache,
connsOffset: connsOffset,
}, nil
}

Expand All @@ -128,26 +128,26 @@ type relayServer struct {

reconnect *reconnectToken

directs logc.KV[RelayConnKey, RelayDirectValue]
directsCache map[RelayID]directRelay
directsOffset int64
directsMu sync.RWMutex
conns logc.KV[RelayConnKey, RelayConnValue]
connsCache map[RelayID]cachedRelay
connsOffset int64
connsMu sync.RWMutex
}

type directRelay struct {
type cachedRelay struct {
auth RelayAuthentication
authSealKey *[32]byte
template *pbclient.Relay
}

func (s *relayServer) cachedDirects() (map[RelayID]directRelay, int64) {
s.directsMu.RLock()
defer s.directsMu.RUnlock()
func (s *relayServer) cachedRelays() (map[RelayID]cachedRelay, int64) {
s.connsMu.RLock()
defer s.connsMu.RUnlock()

return maps.Clone(s.directsCache), s.directsOffset
return maps.Clone(s.connsCache), s.connsOffset
}

func (s *relayServer) Directs(ctx context.Context, endpoint model.Endpoint, role model.Role, cert *x509.Certificate, auth ClientAuthentication,
func (s *relayServer) Relays(ctx context.Context, endpoint model.Endpoint, role model.Role, cert *x509.Certificate, auth ClientAuthentication,
notify func(map[RelayID]*pbclient.Relay) error) error {

authenticationData, err := protobuf.Marshal(&pbrelay.ClientAuthentication{
Expand All @@ -164,13 +164,15 @@ func (s *relayServer) Directs(ctx context.Context, endpoint model.Endpoint, role
return box.SealAfterPrecomputation(nonce[:], authenticationData, &nonce, key)
}

directRelays, offset := s.cachedDirects()
localDirectRelays := map[RelayID]*pbclient.Relay{}
for id, relay := range directRelays {
localRelays := map[RelayID]*pbclient.Relay{}

// load initial state
globalRelays, offset := s.cachedRelays()
for id, relay := range globalRelays {
if ok, err := s.auth.Allow(relay.auth, auth, endpoint); err != nil {
return fmt.Errorf("auth allow error: %w", err)
} else if ok {
localDirectRelays[id] = &pbclient.Relay{
localRelays[id] = &pbclient.Relay{
Id: relay.template.Id,
Addresses: relay.template.Addresses,
ServerCertificate: relay.template.ServerCertificate,
Expand All @@ -179,25 +181,25 @@ func (s *relayServer) Directs(ctx context.Context, endpoint model.Endpoint, role
}
}
}
if err := notify(localDirectRelays); err != nil {
if err := notify(localRelays); err != nil {
return err
}

for {
msgs, nextOffset, err := s.directs.Consume(ctx, offset)
msgs, nextOffset, err := s.conns.Consume(ctx, offset)
if err != nil {
return err
}

var changed bool
for _, msg := range msgs {
if msg.Delete {
delete(localDirectRelays, msg.Key.ID)
delete(localRelays, msg.Key.ID)
changed = true
} else if ok, err := s.auth.Allow(msg.Value.Authentication, auth, endpoint); err != nil {
return fmt.Errorf("auth allow error: %w", err)
} else if ok {
localDirectRelays[msg.Key.ID] = &pbclient.Relay{
localRelays[msg.Key.ID] = &pbclient.Relay{
Id: msg.Key.ID.string,
Addresses: model.PBsFromHostPorts(msg.Value.Hostports),
ServerCertificate: msg.Value.Certificate.Raw,
Expand All @@ -211,7 +213,7 @@ func (s *relayServer) Directs(ctx context.Context, endpoint model.Endpoint, role
offset = nextOffset

if changed {
if err := notify(localDirectRelays); err != nil {
if err := notify(localRelays); err != nil {
return err
}
}
Expand All @@ -224,9 +226,9 @@ func (s *relayServer) run(ctx context.Context) error {
for _, ingress := range s.ingresses {
g.Go(reliable.Bind(ingress, s.runListener))
}
g.Go(s.runDirectsCache)
g.Go(s.runConnsCache)

g.Go(logc.ScheduleCompact(s.directs))
g.Go(logc.ScheduleCompact(s.conns))

return g.Wait()
}
Expand Down Expand Up @@ -294,15 +296,15 @@ func (s *relayServer) runListener(ctx context.Context, ingress Ingress) error {
}
}

func (s *relayServer) runDirectsCache(ctx context.Context) error {
update := func(msg logc.Message[RelayConnKey, RelayDirectValue]) {
s.directsMu.Lock()
defer s.directsMu.Unlock()
func (s *relayServer) runConnsCache(ctx context.Context) error {
update := func(msg logc.Message[RelayConnKey, RelayConnValue]) {
s.connsMu.Lock()
defer s.connsMu.Unlock()

if msg.Delete {
delete(s.directsCache, msg.Key.ID)
delete(s.connsCache, msg.Key.ID)
} else {
s.directsCache[msg.Key.ID] = directRelay{
s.connsCache[msg.Key.ID] = cachedRelay{
auth: msg.Value.Authentication,
authSealKey: msg.Value.AuthenticationSealKey,
template: &pbclient.Relay{
Expand All @@ -314,15 +316,15 @@ func (s *relayServer) runDirectsCache(ctx context.Context) error {
}
}

s.directsOffset = msg.Offset + 1
s.connsOffset = msg.Offset + 1
}

for {
s.directsMu.RLock()
offset := s.directsOffset
s.directsMu.RUnlock()
s.connsMu.RLock()
offset := s.connsOffset
s.connsMu.RUnlock()

msgs, nextOffset, err := s.directs.Consume(ctx, offset)
msgs, nextOffset, err := s.conns.Consume(ctx, offset)
if err != nil {
return fmt.Errorf("relay conns consume: %w", err)
}
Expand All @@ -331,9 +333,9 @@ func (s *relayServer) runDirectsCache(ctx context.Context) error {
update(msg)
}

s.directsMu.Lock()
s.directsOffset = nextOffset
s.directsMu.Unlock()
s.connsMu.Lock()
s.connsOffset = nextOffset
s.connsMu.Unlock()
}
}

Expand Down Expand Up @@ -387,12 +389,12 @@ func (c *relayConn) runErr(ctx context.Context) error {
defer c.logger.Info("relay disconnected", "addr", c.conn.RemoteAddr(), "metadata", c.metadata)

key := RelayConnKey{ID: c.id}
value := RelayDirectValue{c.auth, c.hostports, c.metadata, c.certificate, c.authSignKey}
if err := c.server.directs.Put(key, value); err != nil {
value := RelayConnValue{c.auth, c.hostports, c.metadata, c.certificate, c.authSignKey}
if err := c.server.conns.Put(key, value); err != nil {
return err
}
defer func() {
if err := c.server.directs.Del(key); err != nil {
if err := c.server.conns.Del(key); err != nil {
c.logger.Warn("failed to delete conn", "key", key, "err", err)
}
}()
Expand Down
2 changes: 1 addition & 1 deletion server/control/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func (s *Server) getEndpoints() (map[string]StatusEndpoint, error) {
}

func (s *Server) getRelays() (map[string]StatusRelay, error) {
msgs, _, err := s.relays.directs.Snapshot()
msgs, _, err := s.relays.conns.Snapshot()
if err != nil {
return nil, err
}
Expand Down
20 changes: 10 additions & 10 deletions server/control/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Stores interface {
ClientConns() (logc.KV[ClientConnKey, ClientConnValue], error)
ClientPeers() (logc.KV[ClientPeerKey, ClientPeerValue], error)

RelayDirects() (logc.KV[RelayConnKey, RelayDirectValue], error)
RelayConns() (logc.KV[RelayConnKey, RelayConnValue], error)

RemoveDeprecated() error
}
Expand All @@ -43,8 +43,8 @@ func (f *fileStores) ClientPeers() (logc.KV[ClientPeerKey, ClientPeerValue], err
return logc.NewKV[ClientPeerKey, ClientPeerValue](filepath.Join(f.dir, "client-peers"))
}

func (f *fileStores) RelayDirects() (logc.KV[RelayConnKey, RelayDirectValue], error) {
return logc.NewKV[RelayConnKey, RelayDirectValue](filepath.Join(f.dir, "relay-directs"))
func (f *fileStores) RelayConns() (logc.KV[RelayConnKey, RelayConnValue], error) {
return logc.NewKV[RelayConnKey, RelayConnValue](filepath.Join(f.dir, "relay-directs")) // TODO rename in v0.15
}

func (f *fileStores) RemoveDeprecated() error {
Expand Down Expand Up @@ -100,24 +100,24 @@ type RelayConnKey struct {
ID RelayID `json:"id"`
}

type RelayDirectValue struct {
type RelayConnValue struct {
Authentication RelayAuthentication `json:"authentication"`
Hostports []model.HostPort `json:"hostports"`
Metadata string `json:"metadata"`
Certificate *x509.Certificate `json:"certificate"`
AuthenticationSealKey *[32]byte `json:"authentication-seal-key"`
}

type jsonRelayDirectValue struct {
type jsonRelayConnValue struct {
Authentication RelayAuthentication `json:"authentication"`
Hostports []model.HostPort `json:"hostports"`
Metadata string `json:"metadata"`
Certificate []byte `json:"certificate"`
AuthenticationSealKey []byte `json:"authentication-seal-key"`
}

func (v RelayDirectValue) MarshalJSON() ([]byte, error) {
return json.Marshal(jsonRelayDirectValue{
func (v RelayConnValue) MarshalJSON() ([]byte, error) {
return json.Marshal(jsonRelayConnValue{
Authentication: v.Authentication,
Hostports: v.Hostports,
Metadata: v.Metadata,
Expand All @@ -126,8 +126,8 @@ func (v RelayDirectValue) MarshalJSON() ([]byte, error) {
})
}

func (v *RelayDirectValue) UnmarshalJSON(b []byte) error {
s := jsonRelayDirectValue{}
func (v *RelayConnValue) UnmarshalJSON(b []byte) error {
s := jsonRelayConnValue{}
if err := json.Unmarshal(b, &s); err != nil {
return err
}
Expand All @@ -139,6 +139,6 @@ func (v *RelayDirectValue) UnmarshalJSON(b []byte) error {

var authKey [32]byte
copy(authKey[:], s.AuthenticationSealKey)
*v = RelayDirectValue{s.Authentication, s.Hostports, s.Metadata, cert, &authKey}
*v = RelayConnValue{s.Authentication, s.Hostports, s.Metadata, cert, &authKey}
return nil
}
17 changes: 8 additions & 9 deletions server/relay/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,19 @@ type clientsServer struct {
}

func newClientsServer(cfg Config, cert *certc.Cert, auth ClientAuthenticator) (*clientsServer, error) {
directTLS, err := cert.TLSCert()
tlsCert, err := cert.TLSCert()
if err != nil {
return nil, fmt.Errorf("direct TLS cert: %w", err)
}
directTLSConf := &tls.Config{
ServerName: directTLS.Leaf.DNSNames[0],
Certificates: []tls.Certificate{directTLS},
ClientAuth: tls.RequireAnyClientCert,
NextProtos: iterc.MapVarStrings(model.ConnectRelayV02),
}

return &clientsServer{
tlsConf: directTLSConf,
auth: auth,
tlsConf: &tls.Config{
ServerName: tlsCert.Leaf.DNSNames[0],
Certificates: []tls.Certificate{tlsCert},
ClientAuth: tls.RequireAnyClientCert,
NextProtos: iterc.MapVarStrings(model.ConnectRelayV02),
},
auth: auth,

endpoints: map[model.Endpoint]*endpointClients{},

Expand Down
16 changes: 8 additions & 8 deletions server/relay/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ import (
)

type controlClient struct {
hostports []model.HostPort
direct *certc.Cert
metadata string
hostports []model.HostPort
clientsCert *certc.Cert
metadata string

controlAddr *net.UDPAddr
controlToken string
Expand All @@ -47,15 +47,15 @@ type controlClient struct {
logger *slog.Logger
}

func newControlClient(cfg Config, direct *certc.Cert, configStore logc.KV[ConfigKey, ConfigValue]) (*controlClient, error) {
func newControlClient(cfg Config, clientsCert *certc.Cert, configStore logc.KV[ConfigKey, ConfigValue]) (*controlClient, error) {
hostports := iterc.FlattenSlice(iterc.MapSlice(cfg.Ingress, func(in Ingress) []model.HostPort {
return in.Hostports
}))

c := &controlClient{
hostports: hostports,
direct: direct,
metadata: cfg.Metadata,
hostports: hostports,
clientsCert: clientsCert,
metadata: cfg.Metadata,

controlAddr: cfg.ControlAddr,
controlToken: cfg.ControlToken,
Expand Down Expand Up @@ -197,7 +197,7 @@ func (s *controlClient) authenticate(authStream *quic.Stream, reconnConfig Confi
ReconnectToken: reconnConfig.Bytes,
BuildVersion: model.BuildVersion(),
Metadata: s.metadata,
ServerCertificate: s.direct.Raw(),
ServerCertificate: s.clientsCert.Raw(),
RelayAuthenticationKey: relayPk[:],
}); err != nil {
return fmt.Errorf("auth write error: %w", err)
Expand Down
Loading