Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/pkg/transport/constructor.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

func NewTransport(baseLogger zerolog.Logger) *Transport {
transport := &Transport{
connectionsMx: &sync.Mutex{},
connectionsMx: &sync.RWMutex{},
connections: make(map[abstraction.TransportTarget]net.Conn),
idToTarget: make(map[abstraction.PacketId]abstraction.TransportTarget),
ipToTarget: make(map[string]abstraction.TransportTarget),
Expand Down
2 changes: 1 addition & 1 deletion backend/pkg/transport/packet/data/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (decoder *Decoder) Decode(id abstraction.PacketId, reader io.Reader) (abstr
return nil, ErrUnexpectedId{Id: id}
}

packet := NewPacket(id)
packet := GetPacket(id)
for _, value := range descriptor {
val, err := value.Decode(decoder.endianness, reader)
if err != nil {
Expand Down
43 changes: 43 additions & 0 deletions backend/pkg/transport/packet/data/packet.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package data

import (
"sync"
"time"

"github.com/HyperloopUPV-H8/h9-backend/pkg/abstraction"
Expand All @@ -27,6 +28,16 @@ func NewPacket(id abstraction.PacketId) *Packet {
}
}

var packetPool = sync.Pool{
New: func() any {
return &Packet{
values: make(map[ValueName]Value),
enabled: make(map[ValueName]bool),
}
},
}


// NewPacketWithValues creates a new data packet with the given values
func NewPacketWithValues(id abstraction.PacketId, values map[ValueName]Value, enabled map[ValueName]bool) *Packet {
return &Packet{
Expand Down Expand Up @@ -62,3 +73,35 @@ func (packet *Packet) SetTimestamp(timestamp time.Time) *Packet {
packet.timestamp = timestamp
return packet
}

func (packet *Packet) Reset() {
clear(packet.values)
clear(packet.enabled)
packet.id = 0
packet.timestamp = time.Time{}
}

func GetPacket(id abstraction.PacketId) *Packet {
p := packetPool.Get().(*Packet)
if p.values == nil {
p.values = make(map[ValueName]Value)
} else {
clear(p.values)
}
if p.enabled == nil {
p.enabled = make(map[ValueName]bool)
} else {
clear(p.enabled)
}
p.id = id
p.timestamp = time.Now()
return p
}

func ReleasePacket(p *Packet) {
if p == nil {
return
}
p.Reset()
packetPool.Put(p)
}
35 changes: 29 additions & 6 deletions backend/pkg/transport/presentation/encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"io"
"sync"

"github.com/HyperloopUPV-H8/h9-backend/pkg/abstraction"
"github.com/rs/zerolog"
Expand All @@ -17,7 +18,8 @@ type Encoder struct {
idToEncoder map[abstraction.PacketId]PacketEncoder
endianness binary.ByteOrder

logger zerolog.Logger
logger zerolog.Logger
bufPool sync.Pool
}

// TODO: improve constructor
Expand All @@ -28,6 +30,9 @@ func NewEncoder(endianness binary.ByteOrder, baseLogger zerolog.Logger) *Encoder
endianness: endianness,

logger: baseLogger,
bufPool: sync.Pool{
New: func() any { return new(bytes.Buffer) },
},
}
}

Expand All @@ -37,23 +42,41 @@ func (encoder *Encoder) SetPacketEncoder(id abstraction.PacketId, enc PacketEnco
encoder.logger.Trace().Uint16("id", uint16(id)).Type("encoder", enc).Msg("set encoder")
}

// Encode encodes the provided packet into a byte slice, returning any errors
func (encoder *Encoder) Encode(packet abstraction.Packet) ([]byte, error) {
// Encode encodes the provided packet into a pooled buffer. Callers must release
// the buffer via ReleaseBuffer once they are done using the returned data.
func (encoder *Encoder) Encode(packet abstraction.Packet) (*bytes.Buffer, error) {
enc, ok := encoder.idToEncoder[packet.Id()]
if !ok {
encoder.logger.Warn().Uint16("id", uint16(packet.Id())).Msg("no encoder set")
return nil, ErrUnexpectedId{Id: packet.Id()}
}

buffer := new(bytes.Buffer)
bufferAny := encoder.bufPool.Get()
buffer := bufferAny.(*bytes.Buffer)
buffer.Reset()

err := binary.Write(buffer, encoder.endianness, packet.Id())
if err != nil {
encoder.logger.Error().Stack().Err(err).Uint16("id", uint16(packet.Id())).Msg("buffering id")
return buffer.Bytes(), err
encoder.ReleaseBuffer(buffer)
return nil, err
}

encoder.logger.Debug().Uint16("id", uint16(packet.Id())).Type("encoder", enc).Msg("encoding")
err = enc.Encode(packet, buffer)
return buffer.Bytes(), err
if err != nil {
encoder.ReleaseBuffer(buffer)
return nil, err
}

return buffer, nil
}

// ReleaseBuffer returns a buffer obtained from Encode back to the pool.
func (encoder *Encoder) ReleaseBuffer(buffer *bytes.Buffer) {
if buffer == nil {
return
}
buffer.Reset()
encoder.bufPool.Put(buffer)
}
5 changes: 3 additions & 2 deletions backend/pkg/transport/presentation/encoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,12 +379,13 @@ func TestEncoder(t *testing.T) {

output := make([]byte, 0, len(test.output))
for i := 0; i < len(test.input); i++ {
encoded, err := encoder.Encode(test.input[i])
buf, err := encoder.Encode(test.input[i])
if err != nil {
t.Fatalf("\nError encoding (%d) packet: %s\n", i+1, err)
}

output = append(output, encoded...)
output = append(output, buf.Bytes()...)
encoder.ReleaseBuffer(buf)

}

Expand Down
91 changes: 60 additions & 31 deletions backend/pkg/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type Transport struct {
decoder *presentation.Decoder
encoder *presentation.Encoder

connectionsMx *sync.Mutex
connectionsMx *sync.RWMutex
connections map[abstraction.TransportTarget]net.Conn

ipToTarget map[string]abstraction.TransportTarget
Expand All @@ -45,22 +45,27 @@ type Transport struct {

logger zerolog.Logger

byteReaderPool sync.Pool

errChan chan error
}

// For tests
var zeroTime time.Time

// HandleClient connects to the specified client and handles its messages. This method blocks.
// This method will continuously try to reconnect to the client if it disconnects,
// applying exponential backoff between attempts.
func (transport *Transport) HandleClient(config tcp.ClientConfig, remote string) error {
client := tcp.NewClient(remote, config, transport.logger)
defer transport.logger.Warn().Str("remoteAddress", remote).Msg("abort connection")
clientLogger := transport.logger.With().Str("remoteAddress", remote).Logger()
defer clientLogger.Warn().Msg("abort connection")
var hasConnected = false

for {
conn, err := client.Dial()
if err != nil {
transport.logger.Debug().Stack().Err(err).Str("remoteAddress", remote).Msg("dial failed")

clientLogger.Debug().Stack().Err(err).Msg("dial failed")
// Only return if reconnection is disabled
if !config.TryReconnect {
if hasConnected {
Expand All @@ -73,7 +78,7 @@ func (transport *Transport) HandleClient(config tcp.ClientConfig, remote string)
// For ErrTooManyRetries, we still want to continue retrying
// The client will reset its retry counter on the next Dial() call
if _, ok := err.(tcp.ErrTooManyRetries); ok {
transport.logger.Warn().Str("remoteAddress", remote).Msg("reached max retries, will continue attempting to reconnect")
clientLogger.Warn().Msg("reached max retries, will continue attempting to reconnect")
// Add a longer delay before restarting the retry cycle
time.Sleep(config.ConnectionBackoffFunction(config.MaxConnectionRetries))
}
Expand All @@ -85,12 +90,12 @@ func (transport *Transport) HandleClient(config tcp.ClientConfig, remote string)

err = transport.handleTCPConn(conn)
if errors.Is(err, error(ErrTargetAlreadyConnected{})) {
transport.logger.Warn().Stack().Err(err).Str("remoteAddress", remote).Msg("multiple connections for same target")
clientLogger.Warn().Stack().Err(err).Msg("multiple connections for same target")
transport.errChan <- err
return err
}
if err != nil {
transport.logger.Debug().Stack().Err(err).Str("remoteAddress", remote).Msg("connection lost")
clientLogger.Debug().Stack().Err(err).Msg("connection lost")
if !config.TryReconnect {
transport.SendFault()
transport.errChan <- err
Expand Down Expand Up @@ -254,6 +259,10 @@ func (transport *Transport) readLoopTCPConn(conn net.Conn, logger zerolog.Logger

logger.Trace().Type("type", packet).Msg("packet")
transport.api.Notification(NewPacketNotification(packet, from, to, time.Now()))

if dataPacket, ok := packet.(*data.Packet); ok {
data.ReleasePacket(dataPacket)
}
}
}()
}
Expand Down Expand Up @@ -289,30 +298,31 @@ func (transport *Transport) handlePacketEvent(message PacketMessage) error {

if message.Id() == 0 {
eventLogger.Info().Msg("broadcasting packet id 0")
data, err := transport.encoder.Encode(message.Packet)
buf, err := transport.encoder.Encode(message.Packet)
if err != nil {
eventLogger.Error().Stack().Err(err).Msg("encode")
transport.errChan <- err
return err
}
defer transport.encoder.ReleaseBuffer(buf)
data := buf.Bytes()

transport.connectionsMx.Lock()
defer transport.connectionsMx.Unlock()
transport.connectionsMx.RLock()
defer transport.connectionsMx.RUnlock()
for target, conn := range transport.connections {
eventLogger := eventLogger.With().Str("target", string(target)).Logger()

targetName := string(target)
totalWritten := 0
for totalWritten < len(data) {
n, err := conn.Write(data[totalWritten:])
eventLogger.Trace().Int("amount", n).Msg("written chunk")
eventLogger.Trace().Str("target", targetName).Int("amount", n).Msg("written chunk")
totalWritten += n
if err != nil {
eventLogger.Error().Stack().Err(err).Msg("write")
eventLogger.Error().Str("target", targetName).Stack().Err(err).Msg("write")
transport.errChan <- err
return err
}
}
eventLogger.Info().Msg("sent")
eventLogger.Info().Str("target", targetName).Msg("sent")
}
return nil
}
Expand All @@ -328,11 +338,11 @@ func (transport *Transport) handlePacketEvent(message PacketMessage) error {
eventLogger.Info().Msg("sending")

conn, err := func() (net.Conn, error) {
transport.connectionsMx.Lock()
defer transport.connectionsMx.Unlock()
transport.connectionsMx.RLock()
defer transport.connectionsMx.RUnlock()
conn, ok := transport.connections[target]
if !ok {
eventLogger.Warn().Msg("target not connected")
eventLogger.Warn().Msg("target not connected")

err := ErrConnClosed{Target: target}
return nil, err
Expand All @@ -344,12 +354,14 @@ func (transport *Transport) handlePacketEvent(message PacketMessage) error {
return err
}

data, err := transport.encoder.Encode(message.Packet)
buf, err := transport.encoder.Encode(message.Packet)
if err != nil {
eventLogger.Error().Stack().Err(err).Msg("encode")
transport.errChan <- err
return err
}
defer transport.encoder.ReleaseBuffer(buf)
data := buf.Bytes()

totalWritten := 0
for totalWritten < len(data) {
Expand Down Expand Up @@ -413,14 +425,30 @@ func (transport *Transport) HandleUDPServer(server *udp.Server) {
}
}

func (transport *Transport) replicateFault(packet abstraction.Packet, logger zerolog.Logger) {
logger.Info().Msg("replicating packet with id 0 to all boards")
err := transport.handlePacketEvent(NewPacketMessage(packet))
if err != nil {
logger.Error().Err(err).Msg("failed to replicate packet")
}
}

// handleUDPPacket handles a single UDP packet received by the UDP server
func (transport *Transport) handleUDPPacket(udpPacket udp.Packet) {
srcAddr := fmt.Sprintf("%s:%d", udpPacket.SourceIP, udpPacket.SourcePort)
dstAddr := fmt.Sprintf("%s:%d", udpPacket.DestIP, udpPacket.DestPort)

// Create a reader from the payload
reader := bytes.NewReader(udpPacket.Payload)

readerAny := transport.byteReaderPool.Get()
var reader *bytes.Reader
if readerAny != nil {
reader = readerAny.(*bytes.Reader)
reader.Reset(udpPacket.Payload)
} else {
reader = bytes.NewReader(udpPacket.Payload)
}
defer transport.byteReaderPool.Put(reader)

// Decode the packet
packet, err := transport.decoder.DecodeNext(reader)
if err != nil {
Expand All @@ -435,15 +463,15 @@ func (transport *Transport) handleUDPPacket(udpPacket udp.Packet) {

// Intercept packets with id == 0 and replicate
if transport.propagateFault && packet.Id() == 0 {
transport.logger.Info().Msg("replicating packet with id 0 to all boards")
err := transport.handlePacketEvent(NewPacketMessage(packet))
if err != nil {
transport.logger.Error().Err(err).Msg("failed to replicate packet")
}
transport.replicateFault(packet, transport.logger)
}

// Send notification
transport.api.Notification(NewPacketNotification(packet, srcAddr, dstAddr, udpPacket.Timestamp))

if dataPacket, ok := packet.(*data.Packet); ok {
data.ReleasePacket(dataPacket)
}
}

// handleConversation is called when the sniffer detects a new conversation and handles its specific packets
Expand All @@ -463,14 +491,15 @@ func (transport *Transport) handleConversation(socket network.Socket, reader io.

// Intercept packets with id == 0 and replicate
if transport.propagateFault && packet.Id() == 0 {
conversationLogger.Info().Msg("replicating packet with id 0 to all boards")
err := transport.handlePacketEvent(NewPacketMessage(packet))
if err != nil {
conversationLogger.Error().Err(err).Msg("failed to replicate packet")
}
transport.replicateFault(packet, transport.logger)
}

// Send notification
transport.api.Notification(NewPacketNotification(packet, srcAddr, dstAddr, time.Now()))

if dataPacket, ok := packet.(*data.Packet); ok {
data.ReleasePacket(dataPacket)
}
}
}()
}
Expand Down
Loading
Loading