diff --git a/pkg/mtr/mtr.go b/pkg/mtr/mtr.go index 8b37f9f..8d49082 100644 --- a/pkg/mtr/mtr.go +++ b/pkg/mtr/mtr.go @@ -87,8 +87,10 @@ func (m *MTR) registerStatistic(ttl int, r icmp.ICMPReturn) *hop.HopStatistic { s.Targets = addTarget(s.Targets, r.Addr) - s.Packets = s.Packets.Prev() - s.Packets.Value = r + if m.ringBufferSize > 0 { + s.Packets = s.Packets.Prev() + s.Packets.Value = r + } if !r.Success { s.Lost++ @@ -150,7 +152,7 @@ func (m *MTR) Render(offset int) { } func (m *MTR) RunWithContext(ctx context.Context, count int) error { - err := m.discover(ctx, count) + err := m.loopDiscover(ctx, count) close(m.channel) return err } @@ -161,8 +163,8 @@ func (m *MTR) Run(count int) error { return err } -// discover discovers all hops on the route -func (m *MTR) discover(ctx context.Context, count int) error { +// loopDiscover: discovers all hops on the route +func (m *MTR) loopDiscover(ctx context.Context, count int) error { // Sequences are incrementing as we don't won't to get old replys which might be from a previous run (where we timed out and continued). // We can't use the process id as unique identifier as there might be multiple runs within a single binary, thus we use a fixed pseudo random number. rand.Seed(time.Now().UnixNano()) @@ -170,48 +172,63 @@ func (m *MTR) discover(ctx context.Context, count int) error { id := rand.Intn(math.MaxUint16) & 0xffff ipAddr := net.IPAddr{IP: net.ParseIP(m.Address)} + wg := new(sync.WaitGroup) for i := 1; i <= count; i++ { select { case <-ctx.Done(): return ErrTimeout case <-time.After(m.interval): - unknownHopsCount := 0 - - ttlLoop: - for ttl := 1; ttl < m.maxHops; ttl++ { - seq++ - select { - case <-ctx.Done(): - return ErrTimeout - case <-time.After(m.hopsleep): - var hopReturn icmp.ICMPReturn - var err error - if ipAddr.IP.To4() != nil { - hopReturn, err = icmp.SendDiscoverICMP(m.SrcAddress, &ipAddr, ttl, id, m.timeout, seq) - } else { - hopReturn, err = icmp.SendDiscoverICMPv6(m.SrcAddress, &ipAddr, ttl, id, m.timeout, seq) - } - - m.mutex.Lock() - s := m.registerStatistic(ttl, hopReturn) - s.Dest = &ipAddr - s.PID = id - m.mutex.Unlock() - m.channel <- struct{}{} - if hopReturn.Addr == m.Address { - break ttlLoop - } - if err != nil || !hopReturn.Success { - unknownHopsCount++ - if unknownHopsCount >= m.maxUnknownHops { - return ErrMaxUnknownHops - } - continue ttlLoop - } - unknownHopsCount = 0 + wg.Add(1) + go func() { + defer wg.Done() + + _ = m.discover(seq, id, ipAddr, ctx) + }() + } + } + wg.Wait() + return nil +} + +func (m *MTR) discover(seq, id int, ipAddr net.IPAddr, ctx context.Context) error { + unknownHopsCount := 0 +ttlLoop: + for ttl := 1; ttl < m.maxHops; ttl++ { + seq++ + select { + case <-ctx.Done(): + return ErrTimeout + case <-time.After(m.hopsleep): + var hopReturn icmp.ICMPReturn + var err error + if ipAddr.IP.To4() != nil { + hopReturn, err = icmp.SendDiscoverICMP(m.SrcAddress, &ipAddr, ttl, id, m.timeout, seq) + } else { + hopReturn, err = icmp.SendDiscoverICMPv6(m.SrcAddress, &ipAddr, ttl, id, m.timeout, seq) + } + + go func(ttl int, hopReturn icmp.ICMPReturn) { + m.mutex.Lock() + defer m.mutex.Unlock() + + s := m.registerStatistic(ttl, hopReturn) + s.Dest = &ipAddr + s.PID = id + }(ttl, hopReturn) + + m.channel <- struct{}{} + if hopReturn.Addr == m.Address { + break ttlLoop + } + if err != nil || !hopReturn.Success { + unknownHopsCount++ + if unknownHopsCount >= m.maxUnknownHops { + return ErrMaxUnknownHops } + continue ttlLoop } + unknownHopsCount = 0 } } return nil