From 1ecd7830babf859c861ca1c7ad851f4283b47889 Mon Sep 17 00:00:00 2001 From: zhoujian2 Date: Tue, 24 Aug 2021 11:48:10 +0800 Subject: [PATCH] fix bug and improve implementation --- .gitignore | 1 + main.go | 95 +++++++++++++++++++++++++++++++++--------------------- 2 files changed, 59 insertions(+), 37 deletions(-) diff --git a/.gitignore b/.gitignore index 37ba729..af570f8 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ _cgo* *.prof bench gonetbench +main diff --git a/main.go b/main.go index 19d6869..62c3411 100644 --- a/main.go +++ b/main.go @@ -4,29 +4,28 @@ import ( "flag" "fmt" "net" - "time" - "sync" "os" "os/signal" + "sync" + "syscall" + "time" ) var ( Addr string - Cmd string Num uint Size uint - wg sync.WaitGroup - quit chan bool - sc chan bool // sent packets - lc chan bool // lost packets - ac chan bool // acctive connections - pc chan bool // pending connections - ec chan bool // error connections + quit chan struct{} + sc chan struct{} // sent packets + lc chan struct{} // lost packets + ac chan struct{} // acctive connections + pc chan struct{} // pending connections + ec chan struct{} // error connections pack []byte ) func init() { - flag.UintVar(&Num, "n", 10, "number of concurrent clients") + flag.UintVar(&Num, "n", 10, "number of concurrent connections") flag.UintVar(&Size, "s", 1, "packet size (in bytes)") flag.Parse() @@ -38,13 +37,13 @@ func init() { } func usage() { - fmt.Fprintf(os.Stderr, "Usage: time %s [options] addr cmd\n", os.Args[0]) + fmt.Fprintf(os.Stderr, "Usage: %s [options] addr cmd\n", os.Args[0]) flag.PrintDefaults() } -func status() { - wg.Add(1) +func status(wg *sync.WaitGroup) { defer wg.Done() + acs, ecs, pcs, scs, lcs := 0, 0, 0, 0, 0 fmt.Printf("Benchmarking %s with %d concurrent connections:\n\n", Addr, Num) for { @@ -71,54 +70,76 @@ func status() { func read(conn net.Conn) { for { - buf := make([]byte, 1024) - conn.Read(buf) + select { + case <-quit: + return + default: + buf := make([]byte, 1024) + _, _ = conn.Read(buf) + } } } func client() { conn, err := net.Dial("tcp", Addr) if err != nil { - ec <- true + ec <- struct{}{} return } - ac <- true defer conn.Close() + + ac <- struct{}{} go read(conn) for { - n, err := conn.Write(pack) - if err != nil || n != len(pack) { - lc <- true + select { + case <-quit: + return + default: + n, err := conn.Write(pack) + if err != nil || n != len(pack) { + lc <- struct{}{} + } + sc <- struct{}{} + <-time.After(1 * time.Second) } - sc <- true - <-time.After(1 * time.Second) } } func startAll() { for i := 0; i < int(Num); i += 1 { - pc <- true + pc <- struct{}{} go client() } } func main() { - quit = make(chan bool) - ac = make(chan bool) - sc = make(chan bool) - pc = make(chan bool) - ec = make(chan bool) - lc = make(chan bool) + quit = make(chan struct{}, 2*Num+1) + ac = make(chan struct{}, 128) + sc = make(chan struct{}, 128) + pc = make(chan struct{}, 128) + ec = make(chan struct{}, 128) + lc = make(chan struct{}, 128) pack = make([]byte, Size) - for i, _ := range pack { + for i := range pack { pack[i] = 'x' } - - go status() - go startAll() - <-signal.Incoming - quit <- true - wg.Wait() + wg := &sync.WaitGroup{} + defer func() { + wg.Wait() + }() + + wg.Add(1) + go status(wg) + startAll() + + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt, syscall.SIGTERM, syscall.SIGINT) + <-sig + + var i uint + for i = 0; i < 2*Num+1; i++ { + quit <- struct{}{} + } }