From 57105460b9296a067426f3c36dc0394dbef85ae0 Mon Sep 17 00:00:00 2001 From: Night Owl Nerd <256460992+nightowlnerd@users.noreply.github.com> Date: Mon, 2 Feb 2026 07:17:58 +0100 Subject: [PATCH] feat: add parallel burst testing for >5 servers --- e2e_test.go | 84 +++++++++++++++++++++++++++++++-- main.go | 131 +++++++++++++++++++++++++++++++++++++--------------- scanner.go | 86 +++++++++++++++++++++++++++++++++- 3 files changed, 259 insertions(+), 42 deletions(-) diff --git a/e2e_test.go b/e2e_test.go index 04d531e..914075d 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -136,7 +136,7 @@ func TestE2EBurstTest(t *testing.T) { } defer mock.Close() - result := BurstTest(mock.ip, "test.example.com", mock.port, 2*time.Second) + result := BurstTest(context.Background(), mock.ip, "test.example.com", mock.port, 2*time.Second) if result.Queries != BurstQueries { t.Errorf("Expected %d queries, got %d", BurstQueries, result.Queries) @@ -243,7 +243,7 @@ func TestE2EBurstQPS(t *testing.T) { } defer mock.Close() - result := BurstTest(mock.ip, "test.example.com", mock.port, 2*time.Second) + result := BurstTest(context.Background(), mock.ip, "test.example.com", mock.port, 2*time.Second) if result.QPS() <= 0 { t.Errorf("QPS should be positive, got %.2f", result.QPS()) @@ -253,6 +253,84 @@ func TestE2EBurstQPS(t *testing.T) { } } +func TestE2EParallelBurstTest(t *testing.T) { + // Start 3 mock servers + var mocks []*mockDNSServer + var ips []string + for i := 0; i < 3; i++ { + mock, err := newMockDNSServer("93.184.216.34") + if err != nil { + t.Fatalf("Failed to start mock %d: %v", i, err) + } + mocks = append(mocks, mock) + ips = append(ips, mock.ip) + } + defer func() { + for _, m := range mocks { + m.Close() + } + }() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // All mocks use same port pattern, so we use first mock's port + resultChan := ParallelBurstTest(ctx, ips, "test.example.com", mocks[0].port, 2*time.Second, 3) + + var count, passed int + for r := range resultChan { + count++ + if r.Passed() { + passed++ + } + } + + if count != 3 { + t.Errorf("Expected 3 results, got %d", count) + } + if passed != 3 { + t.Errorf("Expected 3 passed, got %d", passed) + } +} + +func TestE2EBurstTestContextCancellation(t *testing.T) { + mock, err := newMockDNSServer("93.184.216.34") + if err != nil { + t.Fatalf("Failed to start mock DNS: %v", err) + } + defer mock.Close() + + // Cancel immediately + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result := BurstTest(ctx, mock.ip, "test.example.com", mock.port, 2*time.Second) + + // Should return early with partial/no results + if result.Successful == BurstQueries { + t.Error("Expected early termination with cancelled context") + } +} + +func TestBurstProgress(t *testing.T) { + prog := NewBurstProgress(10, true) + + prog.Tested() + prog.Tested() + prog.Passed() + + tested, passed, total := prog.Stats() + if tested != 2 { + t.Errorf("Expected 2 tested, got %d", tested) + } + if passed != 1 { + t.Errorf("Expected 1 passed, got %d", passed) + } + if total != 10 { + t.Errorf("Expected total 10, got %d", total) + } +} + func TestE2EJSONServerFromBurstResult(t *testing.T) { mock, err := newMockDNSServer("93.184.216.34") if err != nil { @@ -260,7 +338,7 @@ func TestE2EJSONServerFromBurstResult(t *testing.T) { } defer mock.Close() - result := BurstTest(mock.ip, "test.example.com", mock.port, 2*time.Second) + result := BurstTest(context.Background(), mock.ip, "test.example.com", mock.port, 2*time.Second) server := JSONServer{ IP: result.IP, diff --git a/main.go b/main.go index eb96684..3ad0350 100644 --- a/main.go +++ b/main.go @@ -127,6 +127,11 @@ func main() { // Set data directory DataDir = *dataDir + // JSON mode disables progress - machine output only + if *jsonOutput { + *progress = false + } + if *showVersion { fmt.Printf("dnscan %s\n", version) os.Exit(0) @@ -355,50 +360,94 @@ resultLoop: // Phase 3: Burst test to verify servers handle concurrent load var burstResults []*BurstResult if *domain != "" && len(workingDNS) > 0 { - if *progress { - fmt.Fprintf(os.Stderr, "\nBurst testing %d candidates (%d queries, %d%% required)...\n", - len(workingDNS), BurstQueries, BurstMinSuccess) - } - total := len(workingDNS) - width := len(fmt.Sprintf("%d", total)) - for i, ip := range workingDNS { - // Check for interrupt - select { - case <-ctx.Done(): - if *progress { - fmt.Fprintf(os.Stderr, "\nInterrupted during burst test\n") - } - goto burstDone - default: - } + if total <= 5 { + // Sequential for small lists - nicer per-IP output if *progress { - fmt.Fprintf(os.Stderr, "[%*d/%d] %-15s ", width, i+1, total, ip) + fmt.Fprintf(os.Stderr, "\nBurst testing %d candidates (%d queries, %d%% required)...\n", + total, BurstQueries, BurstMinSuccess) } - result := BurstTest(ip, *domain, 53, *timeout) + width := len(fmt.Sprintf("%d", total)) + for i, ip := range workingDNS { + select { + case <-ctx.Done(): + if *progress { + fmt.Fprintf(os.Stderr, "\nInterrupted during burst test\n") + } + goto burstDone + default: + } - if result.Passed() { - burstResults = append(burstResults, result) if *progress { - // Green for >=85%, yellow for 70-84% - color := "\033[33m" // yellow - if result.SuccessRate() >= 85 { - color = "\033[32m" // green + fmt.Fprintf(os.Stderr, "[%*d/%d] %-15s ", width, i+1, total, ip) + } + + result := BurstTest(ctx, ip, *domain, 53, *timeout) + + if result.Passed() { + burstResults = append(burstResults, result) + if *progress { + color := "\033[33m" + if result.SuccessRate() >= 85 { + color = "\033[32m" + } + fmt.Fprintf(os.Stderr, "%sOK %.0f%% (%.1f qps, p50=%v)\033[0m\n", + color, result.SuccessRate(), result.QPS(), result.P50().Round(time.Millisecond)) + } + } else { + if *progress { + fmt.Fprintf(os.Stderr, "FAIL %.0f%%\n", result.SuccessRate()) } - fmt.Fprintf(os.Stderr, "%sOK %.0f%% (%.1f qps, p50=%v)\033[0m\n", - color, result.SuccessRate(), result.QPS(), result.P50().Round(time.Millisecond)) } - } else { - if *progress { - fmt.Fprintf(os.Stderr, "FAIL %.0f%%\n", result.SuccessRate()) + } + } else { + // Parallel for larger lists + burstWorkers := min(total, 10) + if *progress { + fmt.Fprintf(os.Stderr, "\nBurst testing %d candidates in parallel (%d workers)...\n", + total, burstWorkers) + } + + burstProg := NewBurstProgress(total, *progress) + var progressDone chan struct{} + + if *progress { + progressDone = make(chan struct{}) + go func() { + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + tested, passed, tot := burstProg.Stats() + fmt.Fprintf(os.Stderr, "\rBurst testing: %d/%d tested, %d passed ", tested, tot, passed) + case <-ctx.Done(): + return + case <-progressDone: + return + } + } + }() + } + + resultChan := ParallelBurstTest(ctx, workingDNS, *domain, 53, *timeout, burstWorkers) + for result := range resultChan { + burstProg.Tested() + if result.Passed() { + burstProg.Passed() + burstResults = append(burstResults, result) } } + + if progressDone != nil { + close(progressDone) + fmt.Fprintf(os.Stderr, "\r \r") + } } burstDone: - // Sort by QPS descending (highest throughput first) sort.Slice(burstResults, func(i, j int) bool { return burstResults[i].QPS() > burstResults[j].QPS() }) @@ -406,9 +455,16 @@ resultLoop: if *progress { fmt.Fprintf(os.Stderr, "---\n") fmt.Fprintf(os.Stderr, "Burst test: %d/%d passed (sorted by throughput)\n", len(burstResults), len(workingDNS)) + for _, r := range burstResults { + color := "\033[33m" + if r.SuccessRate() >= 85 { + color = "\033[32m" + } + fmt.Fprintf(os.Stderr, "%s%-15s OK %.0f%% (%.1f qps, p50=%v)\033[0m\n", + color, r.IP, r.SuccessRate(), r.QPS(), r.P50().Round(time.Millisecond)) + } } - // Extract sorted IPs workingDNS = nil for _, r := range burstResults { workingDNS = append(workingDNS, r.IP) @@ -457,15 +513,14 @@ resultLoop: } enc.Encode(output) } else { - // Plain text output (default) + // Plain text output - skip stdout when progress shows colored stats if len(workingDNS) > 0 { - if *progress { - fmt.Fprintf(os.Stderr, "---\n") - } - for _, ip := range workingDNS { - if outFile != nil { + if outFile != nil { + for _, ip := range workingDNS { fmt.Fprintln(outFile, ip) - } else { + } + } else if !*progress { + for _, ip := range workingDNS { fmt.Println(ip) } } diff --git a/scanner.go b/scanner.go index 7df53d4..8f40f49 100644 --- a/scanner.go +++ b/scanner.go @@ -285,7 +285,7 @@ func (r *BurstResult) Passed() bool { } // BurstTest runs concurrent DNS queries to test server reliability under load -func BurstTest(ip, domain string, port int, timeout time.Duration) *BurstResult { +func BurstTest(ctx context.Context, ip, domain string, port int, timeout time.Duration) *BurstResult { if port == 0 { port = 53 } @@ -308,6 +308,13 @@ func BurstTest(ip, domain string, port int, timeout time.Duration) *BurstResult start := time.Now() for i := 0; i < BurstQueries; i++ { + select { + case <-ctx.Done(): + result.Duration = time.Since(start) + return result + default: + } + wg.Add(1) sem <- struct{}{} @@ -315,6 +322,15 @@ func BurstTest(ip, domain string, port int, timeout time.Duration) *BurstResult defer wg.Done() defer func() { <-sem }() + select { + case <-ctx.Done(): + mu.Lock() + result.Failed++ + mu.Unlock() + return + default: + } + subdomain := randomSlipstreamSubdomain() m := new(dns.Msg) m.SetQuestion(dns.Fqdn(subdomain+"."+domain), dns.TypeTXT) @@ -338,3 +354,71 @@ func BurstTest(ip, domain string, port int, timeout time.Duration) *BurstResult result.Duration = time.Since(start) return result } + +// BurstProgress tracks parallel burst test progress with atomic counters +type BurstProgress struct { + total int64 + tested int64 + passed int64 + enabled bool +} + +func NewBurstProgress(total int, enabled bool) *BurstProgress { + return &BurstProgress{total: int64(total), enabled: enabled} +} + +func (p *BurstProgress) Tested() { atomic.AddInt64(&p.tested, 1) } +func (p *BurstProgress) Passed() { atomic.AddInt64(&p.passed, 1) } +func (p *BurstProgress) Stats() (tested, passed, total int64) { + return atomic.LoadInt64(&p.tested), atomic.LoadInt64(&p.passed), p.total +} + +// ParallelBurstTest runs burst tests on multiple IPs concurrently +func ParallelBurstTest(ctx context.Context, ips []string, domain string, port int, + timeout time.Duration, workers int) <-chan *BurstResult { + + results := make(chan *BurstResult, workers) + ipChan := make(chan string, len(ips)) + + go func() { + defer close(ipChan) + for _, ip := range ips { + select { + case ipChan <- ip: + case <-ctx.Done(): + return + } + } + }() + + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + case ip, ok := <-ipChan: + if !ok { + return + } + result := BurstTest(ctx, ip, domain, port, timeout) + select { + case results <- result: + case <-ctx.Done(): + return + } + } + } + }() + } + + go func() { + wg.Wait() + close(results) + }() + + return results +}