diff --git a/backend/pkg/transport/constructor.go b/backend/pkg/transport/constructor.go index a582a093f..d555f40ef 100644 --- a/backend/pkg/transport/constructor.go +++ b/backend/pkg/transport/constructor.go @@ -12,7 +12,7 @@ import ( func NewTransport(baseLogger zerolog.Logger) *Transport { transport := &Transport{ - connectionsMx: &sync.Mutex{}, + connectionsMx: &sync.RWMutex{}, connections: make(map[abstraction.TransportTarget]net.Conn), idToTarget: make(map[abstraction.PacketId]abstraction.TransportTarget), ipToTarget: make(map[string]abstraction.TransportTarget), diff --git a/backend/pkg/transport/packet/data/decoder.go b/backend/pkg/transport/packet/data/decoder.go index a4c11699a..70722a647 100644 --- a/backend/pkg/transport/packet/data/decoder.go +++ b/backend/pkg/transport/packet/data/decoder.go @@ -35,7 +35,7 @@ func (decoder *Decoder) Decode(id abstraction.PacketId, reader io.Reader) (abstr return nil, ErrUnexpectedId{Id: id} } - packet := NewPacket(id) + packet := GetPacket(id) for _, value := range descriptor { val, err := value.Decode(decoder.endianness, reader) if err != nil { diff --git a/backend/pkg/transport/packet/data/packet.go b/backend/pkg/transport/packet/data/packet.go index d44f3c7df..75178e0a6 100644 --- a/backend/pkg/transport/packet/data/packet.go +++ b/backend/pkg/transport/packet/data/packet.go @@ -1,6 +1,7 @@ package data import ( + "sync" "time" "github.com/HyperloopUPV-H8/h9-backend/pkg/abstraction" @@ -27,6 +28,16 @@ func NewPacket(id abstraction.PacketId) *Packet { } } +var packetPool = sync.Pool{ + New: func() any { + return &Packet{ + values: make(map[ValueName]Value), + enabled: make(map[ValueName]bool), + } + }, +} + + // NewPacketWithValues creates a new data packet with the given values func NewPacketWithValues(id abstraction.PacketId, values map[ValueName]Value, enabled map[ValueName]bool) *Packet { return &Packet{ @@ -62,3 +73,35 @@ func (packet *Packet) SetTimestamp(timestamp time.Time) *Packet { packet.timestamp = timestamp return packet } + +func (packet *Packet) Reset() { + clear(packet.values) + clear(packet.enabled) + packet.id = 0 + packet.timestamp = time.Time{} +} + +func GetPacket(id abstraction.PacketId) *Packet { + p := packetPool.Get().(*Packet) + if p.values == nil { + p.values = make(map[ValueName]Value) + } else { + clear(p.values) + } + if p.enabled == nil { + p.enabled = make(map[ValueName]bool) + } else { + clear(p.enabled) + } + p.id = id + p.timestamp = time.Now() + return p +} + +func ReleasePacket(p *Packet) { + if p == nil { + return + } + p.Reset() + packetPool.Put(p) +} diff --git a/backend/pkg/transport/presentation/encoder.go b/backend/pkg/transport/presentation/encoder.go index 2328f0442..7618628d8 100644 --- a/backend/pkg/transport/presentation/encoder.go +++ b/backend/pkg/transport/presentation/encoder.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "io" + "sync" "github.com/HyperloopUPV-H8/h9-backend/pkg/abstraction" "github.com/rs/zerolog" @@ -17,7 +18,8 @@ type Encoder struct { idToEncoder map[abstraction.PacketId]PacketEncoder endianness binary.ByteOrder - logger zerolog.Logger + logger zerolog.Logger + bufPool sync.Pool } // TODO: improve constructor @@ -28,6 +30,9 @@ func NewEncoder(endianness binary.ByteOrder, baseLogger zerolog.Logger) *Encoder endianness: endianness, logger: baseLogger, + bufPool: sync.Pool{ + New: func() any { return new(bytes.Buffer) }, + }, } } @@ -37,23 +42,41 @@ func (encoder *Encoder) SetPacketEncoder(id abstraction.PacketId, enc PacketEnco encoder.logger.Trace().Uint16("id", uint16(id)).Type("encoder", enc).Msg("set encoder") } -// Encode encodes the provided packet into a byte slice, returning any errors -func (encoder *Encoder) Encode(packet abstraction.Packet) ([]byte, error) { +// Encode encodes the provided packet into a pooled buffer. Callers must release +// the buffer via ReleaseBuffer once they are done using the returned data. +func (encoder *Encoder) Encode(packet abstraction.Packet) (*bytes.Buffer, error) { enc, ok := encoder.idToEncoder[packet.Id()] if !ok { encoder.logger.Warn().Uint16("id", uint16(packet.Id())).Msg("no encoder set") return nil, ErrUnexpectedId{Id: packet.Id()} } - buffer := new(bytes.Buffer) + bufferAny := encoder.bufPool.Get() + buffer := bufferAny.(*bytes.Buffer) + buffer.Reset() err := binary.Write(buffer, encoder.endianness, packet.Id()) if err != nil { encoder.logger.Error().Stack().Err(err).Uint16("id", uint16(packet.Id())).Msg("buffering id") - return buffer.Bytes(), err + encoder.ReleaseBuffer(buffer) + return nil, err } encoder.logger.Debug().Uint16("id", uint16(packet.Id())).Type("encoder", enc).Msg("encoding") err = enc.Encode(packet, buffer) - return buffer.Bytes(), err + if err != nil { + encoder.ReleaseBuffer(buffer) + return nil, err + } + + return buffer, nil +} + +// ReleaseBuffer returns a buffer obtained from Encode back to the pool. +func (encoder *Encoder) ReleaseBuffer(buffer *bytes.Buffer) { + if buffer == nil { + return + } + buffer.Reset() + encoder.bufPool.Put(buffer) } diff --git a/backend/pkg/transport/presentation/encoder_test.go b/backend/pkg/transport/presentation/encoder_test.go index ecff59dd3..d0420c913 100644 --- a/backend/pkg/transport/presentation/encoder_test.go +++ b/backend/pkg/transport/presentation/encoder_test.go @@ -379,12 +379,13 @@ func TestEncoder(t *testing.T) { output := make([]byte, 0, len(test.output)) for i := 0; i < len(test.input); i++ { - encoded, err := encoder.Encode(test.input[i]) + buf, err := encoder.Encode(test.input[i]) if err != nil { t.Fatalf("\nError encoding (%d) packet: %s\n", i+1, err) } - output = append(output, encoded...) + output = append(output, buf.Bytes()...) + encoder.ReleaseBuffer(buf) } diff --git a/backend/pkg/transport/transport.go b/backend/pkg/transport/transport.go index bc11b11c8..2fb9ea39f 100644 --- a/backend/pkg/transport/transport.go +++ b/backend/pkg/transport/transport.go @@ -31,7 +31,7 @@ type Transport struct { decoder *presentation.Decoder encoder *presentation.Encoder - connectionsMx *sync.Mutex + connectionsMx *sync.RWMutex connections map[abstraction.TransportTarget]net.Conn ipToTarget map[string]abstraction.TransportTarget @@ -45,22 +45,27 @@ type Transport struct { logger zerolog.Logger + byteReaderPool sync.Pool + errChan chan error } +// For tests +var zeroTime time.Time + // HandleClient connects to the specified client and handles its messages. This method blocks. // This method will continuously try to reconnect to the client if it disconnects, // applying exponential backoff between attempts. func (transport *Transport) HandleClient(config tcp.ClientConfig, remote string) error { client := tcp.NewClient(remote, config, transport.logger) - defer transport.logger.Warn().Str("remoteAddress", remote).Msg("abort connection") + clientLogger := transport.logger.With().Str("remoteAddress", remote).Logger() + defer clientLogger.Warn().Msg("abort connection") var hasConnected = false for { conn, err := client.Dial() if err != nil { - transport.logger.Debug().Stack().Err(err).Str("remoteAddress", remote).Msg("dial failed") - + clientLogger.Debug().Stack().Err(err).Msg("dial failed") // Only return if reconnection is disabled if !config.TryReconnect { if hasConnected { @@ -73,7 +78,7 @@ func (transport *Transport) HandleClient(config tcp.ClientConfig, remote string) // For ErrTooManyRetries, we still want to continue retrying // The client will reset its retry counter on the next Dial() call if _, ok := err.(tcp.ErrTooManyRetries); ok { - transport.logger.Warn().Str("remoteAddress", remote).Msg("reached max retries, will continue attempting to reconnect") + clientLogger.Warn().Msg("reached max retries, will continue attempting to reconnect") // Add a longer delay before restarting the retry cycle time.Sleep(config.ConnectionBackoffFunction(config.MaxConnectionRetries)) } @@ -85,12 +90,12 @@ func (transport *Transport) HandleClient(config tcp.ClientConfig, remote string) err = transport.handleTCPConn(conn) if errors.Is(err, error(ErrTargetAlreadyConnected{})) { - transport.logger.Warn().Stack().Err(err).Str("remoteAddress", remote).Msg("multiple connections for same target") + clientLogger.Warn().Stack().Err(err).Msg("multiple connections for same target") transport.errChan <- err return err } if err != nil { - transport.logger.Debug().Stack().Err(err).Str("remoteAddress", remote).Msg("connection lost") + clientLogger.Debug().Stack().Err(err).Msg("connection lost") if !config.TryReconnect { transport.SendFault() transport.errChan <- err @@ -254,6 +259,10 @@ func (transport *Transport) readLoopTCPConn(conn net.Conn, logger zerolog.Logger logger.Trace().Type("type", packet).Msg("packet") transport.api.Notification(NewPacketNotification(packet, from, to, time.Now())) + + if dataPacket, ok := packet.(*data.Packet); ok { + data.ReleasePacket(dataPacket) + } } }() } @@ -289,30 +298,31 @@ func (transport *Transport) handlePacketEvent(message PacketMessage) error { if message.Id() == 0 { eventLogger.Info().Msg("broadcasting packet id 0") - data, err := transport.encoder.Encode(message.Packet) + buf, err := transport.encoder.Encode(message.Packet) if err != nil { eventLogger.Error().Stack().Err(err).Msg("encode") transport.errChan <- err return err } + defer transport.encoder.ReleaseBuffer(buf) + data := buf.Bytes() - transport.connectionsMx.Lock() - defer transport.connectionsMx.Unlock() + transport.connectionsMx.RLock() + defer transport.connectionsMx.RUnlock() for target, conn := range transport.connections { - eventLogger := eventLogger.With().Str("target", string(target)).Logger() - + targetName := string(target) totalWritten := 0 for totalWritten < len(data) { n, err := conn.Write(data[totalWritten:]) - eventLogger.Trace().Int("amount", n).Msg("written chunk") + eventLogger.Trace().Str("target", targetName).Int("amount", n).Msg("written chunk") totalWritten += n if err != nil { - eventLogger.Error().Stack().Err(err).Msg("write") + eventLogger.Error().Str("target", targetName).Stack().Err(err).Msg("write") transport.errChan <- err return err } } - eventLogger.Info().Msg("sent") + eventLogger.Info().Str("target", targetName).Msg("sent") } return nil } @@ -328,11 +338,11 @@ func (transport *Transport) handlePacketEvent(message PacketMessage) error { eventLogger.Info().Msg("sending") conn, err := func() (net.Conn, error) { - transport.connectionsMx.Lock() - defer transport.connectionsMx.Unlock() + transport.connectionsMx.RLock() + defer transport.connectionsMx.RUnlock() conn, ok := transport.connections[target] if !ok { - eventLogger.Warn().Msg("target not connected") + eventLogger.Warn().Msg("target not connected") err := ErrConnClosed{Target: target} return nil, err @@ -344,12 +354,14 @@ func (transport *Transport) handlePacketEvent(message PacketMessage) error { return err } - data, err := transport.encoder.Encode(message.Packet) + buf, err := transport.encoder.Encode(message.Packet) if err != nil { eventLogger.Error().Stack().Err(err).Msg("encode") transport.errChan <- err return err } + defer transport.encoder.ReleaseBuffer(buf) + data := buf.Bytes() totalWritten := 0 for totalWritten < len(data) { @@ -413,14 +425,30 @@ func (transport *Transport) HandleUDPServer(server *udp.Server) { } } +func (transport *Transport) replicateFault(packet abstraction.Packet, logger zerolog.Logger) { + logger.Info().Msg("replicating packet with id 0 to all boards") + err := transport.handlePacketEvent(NewPacketMessage(packet)) + if err != nil { + logger.Error().Err(err).Msg("failed to replicate packet") + } +} + // handleUDPPacket handles a single UDP packet received by the UDP server func (transport *Transport) handleUDPPacket(udpPacket udp.Packet) { srcAddr := fmt.Sprintf("%s:%d", udpPacket.SourceIP, udpPacket.SourcePort) dstAddr := fmt.Sprintf("%s:%d", udpPacket.DestIP, udpPacket.DestPort) // Create a reader from the payload - reader := bytes.NewReader(udpPacket.Payload) - + readerAny := transport.byteReaderPool.Get() + var reader *bytes.Reader + if readerAny != nil { + reader = readerAny.(*bytes.Reader) + reader.Reset(udpPacket.Payload) + } else { + reader = bytes.NewReader(udpPacket.Payload) + } + defer transport.byteReaderPool.Put(reader) + // Decode the packet packet, err := transport.decoder.DecodeNext(reader) if err != nil { @@ -435,15 +463,15 @@ func (transport *Transport) handleUDPPacket(udpPacket udp.Packet) { // Intercept packets with id == 0 and replicate if transport.propagateFault && packet.Id() == 0 { - transport.logger.Info().Msg("replicating packet with id 0 to all boards") - err := transport.handlePacketEvent(NewPacketMessage(packet)) - if err != nil { - transport.logger.Error().Err(err).Msg("failed to replicate packet") - } + transport.replicateFault(packet, transport.logger) } // Send notification transport.api.Notification(NewPacketNotification(packet, srcAddr, dstAddr, udpPacket.Timestamp)) + + if dataPacket, ok := packet.(*data.Packet); ok { + data.ReleasePacket(dataPacket) + } } // handleConversation is called when the sniffer detects a new conversation and handles its specific packets @@ -463,14 +491,15 @@ func (transport *Transport) handleConversation(socket network.Socket, reader io. // Intercept packets with id == 0 and replicate if transport.propagateFault && packet.Id() == 0 { - conversationLogger.Info().Msg("replicating packet with id 0 to all boards") - err := transport.handlePacketEvent(NewPacketMessage(packet)) - if err != nil { - conversationLogger.Error().Err(err).Msg("failed to replicate packet") - } + transport.replicateFault(packet, transport.logger) } + // Send notification transport.api.Notification(NewPacketNotification(packet, srcAddr, dstAddr, time.Now())) + + if dataPacket, ok := packet.(*data.Packet); ok { + data.ReleasePacket(dataPacket) + } } }() } diff --git a/backend/pkg/transport/transport_test.go b/backend/pkg/transport/transport_test.go index 0a1c9fd6f..428dc4f13 100644 --- a/backend/pkg/transport/transport_test.go +++ b/backend/pkg/transport/transport_test.go @@ -1,18 +1,30 @@ package transport import ( + "bytes" "context" "encoding/binary" "fmt" + "io" "net" + "os" + "strings" "sync" "testing" "time" "github.com/HyperloopUPV-H8/h9-backend/pkg/abstraction" + "github.com/HyperloopUPV-H8/h9-backend/pkg/transport/network" + "github.com/HyperloopUPV-H8/h9-backend/pkg/transport/network/sniffer" "github.com/HyperloopUPV-H8/h9-backend/pkg/transport/network/tcp" + "github.com/HyperloopUPV-H8/h9-backend/pkg/transport/network/tftp" + "github.com/HyperloopUPV-H8/h9-backend/pkg/transport/network/udp" "github.com/HyperloopUPV-H8/h9-backend/pkg/transport/packet/data" "github.com/HyperloopUPV-H8/h9-backend/pkg/transport/presentation" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcap" + "github.com/google/gopacket/pcapgo" + tftpv3 "github.com/pin/tftp/v3" "github.com/rs/zerolog" ) @@ -75,6 +87,26 @@ func (api *TestTransportAPI) Reset() { api.notifications = api.notifications[:0] } +// simpleConn is a net.Conn with specified local and remote addresses +type simpleConn struct { + net.Conn + local net.Addr + remote net.Addr +} + +func (c *simpleConn) LocalAddr() net.Addr { return c.local } +func (c *simpleConn) RemoteAddr() net.Addr { return c.remote } + +func defaultLogger() zerolog.Logger { + return zerolog.New(zerolog.Nop()) +} + +// noopTransportAPI is a no-op implementation of abstraction.TransportAPI +type noopTransportAPI struct{} + +func (noopTransportAPI) Notification(abstraction.TransportNotification) {} +func (noopTransportAPI) ConnectionUpdate(abstraction.TransportTarget, bool) {} + // MockBoardServer simulates a vehicle board type MockBoardServer struct { address string @@ -234,6 +266,7 @@ func createTestTransport(t *testing.T) (*Transport, *TestTransportAPI) { enc := presentation.NewEncoder(binary.BigEndian, logger) dec := presentation.NewDecoder(binary.BigEndian, logger) wireTestPacketCodec(enc, dec, abstraction.PacketId(100)) + wireTestPacketCodec(enc, dec, abstraction.PacketId(0)) transport := NewTransport(logger). @@ -255,6 +288,19 @@ func getAvailablePort(t testing.TB) string { return listener.Addr().String() } +func getAvailableUDPPort(t testing.TB) uint16 { + addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to resolve UDP addr: %v", err) + } + conn, err := net.ListenUDP("udp", addr) + if err != nil { + t.Fatalf("Failed to listen UDP: %v", err) + } + defer conn.Close() + return uint16(conn.LocalAddr().(*net.UDPAddr).Port) +} + // waitForCondition waits for a condition to be true within a timeout func waitForCondition(condition func() bool, timeout time.Duration, message string) error { deadline := time.Now().Add(timeout) @@ -333,39 +379,206 @@ func TestTransport_SetTargetIp(t *testing.T) { } } -func TestTransport_InvalidInputs(t *testing.T) { - transport, _ := createTestTransport(t) +func TestWithTFTP(t *testing.T) { + tr := NewTransport(defaultLogger()) + tr.SetAPI(noopTransportAPI{}) + client := &tftp.Client{} + + out := tr.WithTFTP(client) + if out.tftp != client { + t.Fatalf("expected tftp client to be set") + } +} + +func TestTransportErrors(t *testing.T) { + tests := []struct { + err error + want string + }{ + {ErrUnrecognizedEvent{Event: PacketEvent}, "unrecognized event packet"}, + {ErrTargetAlreadyConnected{Target: "X"}, "X is already connected"}, + {ErrUnrecognizedId{Id: 7}, "could not find target for packet with id 7"}, + {ErrConnClosed{Target: "Y"}, "connection with Y is closed"}, + {ErrUnknownTarget{Remote: &net.TCPAddr{IP: net.ParseIP("1.2.3.4"), Port: 1234}}, "unknown target for 1.2.3.4:1234"}, + } + + for _, tt := range tests { + if got := tt.err.Error(); !strings.Contains(got, tt.want) { + t.Fatalf("expected %q to contain %q", got, tt.want) + } + } +} + +func TestMessages(t *testing.T) { + pm := NewPacketMessage(nil) + if pm.Event() != PacketEvent { + t.Fatalf("packet event mismatch") + } + + fr := bytes.NewBuffer(nil) + fwm := NewFileWriteMessage("a.bin", fr) + if fwm.Event() != FileWriteEvent || fwm.Filename() != "a.bin" { + t.Fatalf("file write message mismatch") + } + + fw := bytes.NewBuffer(nil) + frm := NewFileReadMessage("b.bin", fw) + if frm.Event() != FileReadEvent || frm.Filename() != "b.bin" { + t.Fatalf("file read message mismatch") + } +} + +func TestNotifications(t *testing.T) { + pn := NewPacketNotification(nil, "from", "to", zeroTime) + if pn.Event() != PacketEvent || pn.From != "from" || pn.To != "to" { + t.Fatalf("packet notification mismatch") + } + + en := NewErrorNotification(io.EOF) + if en.Event() != ErrorEvent || en.Err != io.EOF { + t.Fatalf("error notification mismatch") + } +} + +func TestSetpropagateFault(t *testing.T) { + tr := NewTransport(defaultLogger()) + tr.SetAPI(noopTransportAPI{}) + if tr.propagateFault { + t.Fatalf("expected propagateFault false by default") + } + tr.SetpropagateFault(true) + if !tr.propagateFault { + t.Fatalf("expected propagateFault true after setter") + } +} + +func TestTargetFromTCPConnKnown(t *testing.T) { + tr := NewTransport(defaultLogger()) + tr.SetAPI(noopTransportAPI{}) + tr.ipToTarget["127.0.0.1"] = "KNOWN" + pr, pw := net.Pipe() + defer pw.Close() + conn := &simpleConn{ + Conn: pr, + local: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1}, + remote: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 2}, + } + + target, err := tr.targetFromTCPConn(conn) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if target != "KNOWN" { + t.Fatalf("expected target KNOWN, got %s", target) + } +} - // Test invalid ID input - err := transport.SetIdTarget(0, "") +func TestTargetFromTCPConnUnknown(t *testing.T) { + tr := NewTransport(defaultLogger()) + tr.SetAPI(noopTransportAPI{}) + pr, pw := net.Pipe() + defer pw.Close() + conn := &simpleConn{ + Conn: pr, + local: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1}, + remote: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 2}, + } + + _, err := tr.targetFromTCPConn(conn) if err == nil { - t.Errorf("Expected error for invalid ID input, got nil") + t.Fatalf("expected error for unknown target") + } + if _, ok := err.(ErrUnknownTarget); !ok { + t.Fatalf("expected ErrUnknownTarget, got %T", err) + } +} + +func TestRejectIfConnectedTCPConn(t *testing.T) { + tr := NewTransport(defaultLogger()) + tr.SetAPI(noopTransportAPI{}) + tr.connections["X"] = &simpleConn{} + + // new conn to reject + pr, pw := net.Pipe() + defer pw.Close() + conn := &simpleConn{ + Conn: pr, + local: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1}, + remote: &net.TCPAddr{IP: net.ParseIP("127.0.0.2"), Port: 2}, + } + + err := tr.rejectIfConnectedTCPConn("X", conn, defaultLogger()) + if _, ok := err.(ErrTargetAlreadyConnected); !ok { + t.Fatalf("expected ErrTargetAlreadyConnected, got %v", err) + } + // conn should be closed + if _, werr := conn.Write([]byte("test")); werr == nil { + t.Fatalf("expected write to fail on closed conn") } +} - // Test invalid IP input - err = transport.SetTargetIp("", "") +func TestHandlePacketEvent_TargetNotConnected(t *testing.T) { + tr, _ := createTestTransport(t) + tr.SetpropagateFault(false) + tr.idToTarget[42] = "TARGET" + // encoder/decoder wired only for id 100; id 42 will cause ErrUnexpectedId in encoder + pkt := data.NewPacket(42) + err := tr.handlePacketEvent(NewPacketMessage(pkt)) if err == nil { - t.Errorf("Expected error for invalid IP input, got nil") + t.Fatalf("expected error for missing encoder/connection") } } -func TestTransport_RemoveTargets(t *testing.T) { - transport, _ := createTestTransport(t) +func TestReplicateFaultBroadcast(t *testing.T) { + tr, api := createTestTransport(t) + tr.SetpropagateFault(true) + // create a connection to receive broadcast + c1, c2 := net.Pipe() + tr.connectionsMx.Lock() + tr.connections["TARGET"] = c1 + tr.connectionsMx.Unlock() + defer c1.Close() + defer c2.Close() - // Add entries - transport.SetIdTarget(100, "TEST_BOARD") - transport.SetTargetIp("192.168.1.100", "TEST_BOARD") + go tr.replicateFault(data.NewPacket(0), tr.logger) - // Remove entries - delete(transport.idToTarget, 100) - delete(transport.ipToTarget, "192.168.1.100") + buf := make([]byte, 2) + if _, err := io.ReadFull(c2, buf); err != nil { + t.Fatalf("expected broadcast data, got err %v", err) + } + // ensure no error notifications + if len(api.GetNotifications()) != 0 { + t.Fatalf("expected no notifications during replicateFault") + } +} - // Verify removal - if _, exists := transport.idToTarget[100]; exists { - t.Errorf("Expected ID 100 to be removed, but it still exists") +func TestHandleUDPPacket_Success(t *testing.T) { + tr, api := createTestTransport(t) + tr.SetpropagateFault(false) + + pkt := data.NewPacket(100) + pkt.SetTimestamp(time.Unix(0, 0)) + buf, err := tr.encoder.Encode(pkt) + if err != nil { + t.Fatalf("encode failed: %v", err) } - if _, exists := transport.ipToTarget["192.168.1.100"]; exists { - t.Errorf("Expected IP 192.168.1.100 to be removed, but it still exists") + + payload := append([]byte(nil), buf.Bytes()...) + tr.encoder.ReleaseBuffer(buf) + + udpPkt := udp.Packet{ + SourceIP: net.ParseIP("127.0.0.1"), + SourcePort: 9999, + DestIP: net.ParseIP("127.0.0.1"), + DestPort: 9998, + Payload: payload, + Timestamp: time.Unix(0, 0), + } + + tr.handleUDPPacket(udpPkt) + + if len(api.GetNotifications()) == 0 { + t.Fatalf("expected notification after UDP packet") } } @@ -691,6 +904,239 @@ func TestTransport_ReconnectionBehavior(t *testing.T) { } } +func TestHandleServer_AcceptsAndDispatches(t *testing.T) { + tr, api := createTestTransport(t) + target := abstraction.TransportTarget("SERVER_TARGET") + tr.SetTargetIp("127.0.0.1", target) + tr.SetIdTarget(100, target) + + local := getAvailablePort(t) + cfg := tcp.NewServerConfig() + ctx, cancel := context.WithCancel(context.Background()) + cfg.Context = ctx + defer cancel() + + done := make(chan struct{}) + go func() { + _ = tr.HandleServer(cfg, local) + close(done) + }() + + var conn net.Conn + var err error + deadline := time.Now().Add(500 * time.Millisecond) + for time.Now().Before(deadline) { + conn, err = net.Dial("tcp", local) + if err == nil { + break + } + time.Sleep(20 * time.Millisecond) + } + if conn == nil { + t.Fatalf("failed to dial server: %v", err) + } + defer conn.Close() + + packet := data.NewPacket(100) + packet.SetTimestamp(time.Unix(0, 0)) + buf, err := tr.encoder.Encode(packet) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + defer tr.encoder.ReleaseBuffer(buf) + + if _, err := conn.Write(buf.Bytes()); err != nil { + t.Fatalf("failed to write packet: %v", err) + } + + if err := waitForCondition(func() bool { + return len(api.GetNotifications()) > 0 + }, 2*time.Second, "Should receive notification from server connection"); err != nil { + t.Fatal(err) + } + + cancel() + select { + case <-done: + case <-time.After(500 * time.Millisecond): + } +} + +func TestHandleUDPServer_Dispatches(t *testing.T) { + tr, api := createTestTransport(t) + tr.SetpropagateFault(false) + + port := getAvailableUDPPort(t) + logger := zerolog.Nop() + server := udp.NewServer("127.0.0.1", port, &logger) + if err := server.Start(); err != nil { + t.Fatalf("failed to start UDP server: %v", err) + } + defer server.Stop() + + go tr.HandleUDPServer(server) + + packet := data.NewPacket(100) + packet.SetTimestamp(time.Unix(0, 0)) + buf, err := tr.encoder.Encode(packet) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + defer tr.encoder.ReleaseBuffer(buf) + + conn, err := net.DialUDP("udp", nil, &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: int(port)}) + if err != nil { + t.Fatalf("failed to dial UDP server: %v", err) + } + defer conn.Close() + + if _, err := conn.Write(buf.Bytes()); err != nil { + t.Fatalf("failed to send UDP packet: %v", err) + } + + if err := waitForCondition(func() bool { + return len(api.GetNotifications()) > 0 + }, 2*time.Second, "Should receive notification from UDP server"); err != nil { + t.Fatal(err) + } +} + +func TestHandleFileWriteRead_WithRealTFTP(t *testing.T) { + readHandler := func(filename string, rf io.ReaderFrom) error { + _, err := rf.ReadFrom(bytes.NewBufferString("from-server")) + return err + } + writeBuf := &bytes.Buffer{} + writeHandler := func(filename string, wt io.WriterTo) error { + _, err := wt.WriteTo(writeBuf) + return err + } + server := tftpv3.NewServer(readHandler, writeHandler) + addr := fmt.Sprintf("127.0.0.1:%d", getAvailableUDPPort(t)) + go func() { + _ = server.ListenAndServe(addr) + }() + defer server.Shutdown() + time.Sleep(20 * time.Millisecond) + + client, err := tftp.NewClient(addr) + if err != nil { + t.Fatalf("failed to create tftp client: %v", err) + } + + tr := NewTransport(defaultLogger()).WithTFTP(client) + tr.SetAPI(NewTestTransportAPI()) + + if err := tr.handleFileWrite(NewFileWriteMessage("file.bin", bytes.NewBufferString("hello"))); err != nil { + t.Fatalf("handleFileWrite error: %v", err) + } + if writeBuf.String() != "hello" { + t.Fatalf("expected written data 'hello', got %q", writeBuf.String()) + } + + out := &bytes.Buffer{} + if err := tr.handleFileRead(NewFileReadMessage("file.bin", out)); err != nil { + t.Fatalf("handleFileRead error: %v", err) + } + if out.String() != "from-server" { + t.Fatalf("expected read data 'from-server', got %q", out.String()) + } +} + +func TestHandleFileWriteRead_ErrorPath(t *testing.T) { + // Point to an unused UDP port to force WriteFile/ReadFile errors. + addr := fmt.Sprintf("127.0.0.1:%d", getAvailableUDPPort(t)) + client, err := tftp.NewClient(addr, tftp.WithTimeout(50*time.Millisecond), tftp.WithRetries(1)) + if err != nil { + t.Fatalf("failed to create tftp client: %v", err) + } + + tr := NewTransport(defaultLogger()).WithTFTP(client) + api := NewTestTransportAPI() + tr.SetAPI(api) + + if err := tr.handleFileWrite(NewFileWriteMessage("file.bin", bytes.NewBufferString("hello"))); err == nil { + t.Fatalf("expected error writing to unreachable TFTP server") + } + if err := waitForCondition(func() bool { return len(api.GetNotifications()) > 0 }, time.Second, "error notification"); err != nil { + t.Fatalf("expected error notification") + } + + api.Reset() + if err := tr.handleFileRead(NewFileReadMessage("file.bin", &bytes.Buffer{})); err == nil { + t.Fatalf("expected error reading from unreachable TFTP server") + } + if err := waitForCondition(func() bool { return len(api.GetNotifications()) > 0 }, time.Second, "error notification"); err != nil { + t.Fatalf("expected error notification") + } +} + +func TestHandleSniffer_Dispatches(t *testing.T) { + tr, api := createTestTransport(t) + + // empty pcap (header only) to drive HandleSniffer through EOF path + tmp, err := os.CreateTemp("", "sniffer*.pcap") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + writer := pcapgo.NewWriter(tmp) + if err := writer.WriteFileHeader(65535, layers.LinkTypeEthernet); err != nil { + t.Fatalf("write header failed: %v", err) + } + tmp.Close() + + handle, err := pcap.OpenOffline(tmp.Name()) + if err != nil { + t.Fatalf("failed to open pcap: %v", err) + } + sn := sniffer.New(handle, nil, defaultLogger()) + + done := make(chan struct{}) + go func() { + tr.HandleSniffer(sn) + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("HandleSniffer did not return on EOF") + } + + // No notifications expected; just ensure no panic/block. + _ = api +} + +func TestHandleConversation_DispatchesAndStopsOnError(t *testing.T) { + tr, api := createTestTransport(t) + + pkt := data.NewPacket(100) + pkt.SetTimestamp(time.Unix(0, 0)) + buf, err := tr.encoder.Encode(pkt) + if err != nil { + t.Fatalf("encode failed: %v", err) + } + defer tr.encoder.ReleaseBuffer(buf) + + socket := network.Socket{ + SrcIP: "127.0.0.1", + SrcPort: 8000, + DstIP: "127.0.0.1", + DstPort: 8001, + } + + reader := bytes.NewReader(buf.Bytes()) + tr.handleConversation(socket, reader) + + if err := waitForCondition(func() bool { return len(api.GetNotifications()) >= 1 }, time.Second, "packet notification"); err != nil { + t.Fatal(err) + } + // After the first packet, DecodeNext will hit EOF and SendFault will result in an error notification. + if err := waitForCondition(func() bool { return len(api.GetNotifications()) >= 2 }, 2*time.Second, "error notification"); err != nil { + t.Fatal(err) + } +} + // Helper function to mimic errors.As behavior func ErrorAs(err error, target interface{}) bool { switch target := target.(type) {