diff --git a/examples/proxyservice/server.go b/examples/proxyservice/server.go index 92f4d87d..2e41f68c 100644 --- a/examples/proxyservice/server.go +++ b/examples/proxyservice/server.go @@ -3,6 +3,7 @@ package main import ( "flag" "log" + "net" "os" "os/signal" "syscall" @@ -25,7 +26,20 @@ var ( func main() { flag.Parse() - server, err := zeroconf.RegisterProxy(*name, *service, *domain, *port, *host, []string{*ip}, []string{"txtv=0", "lo=1", "la=2"}, nil) + ifaces, _ := net.Interfaces() + + config := zeroconf.ProxyRegistrationConfig{ + Instance: "GoZeroconfGo", + Service: "_workstation._tcp", + Domain: "local.", + Port: 42424, + Host: "pc1", + IPs: []string{"txtv=0", "lo=1", "la=2"}, + Text: []string{"key=value", "env=dev"}, + Ifaces: []net.Interface{ifaces[0]}, // use the first available interface + } + + server, err := zeroconf.RegisterProxy(config) if err != nil { panic(err) } diff --git a/go.mod b/go.mod index 76ce30c4..abdf6fad 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,18 @@ module github.com/grandcat/zeroconf -go 1.13 +go 1.22 require ( github.com/cenkalti/backoff v2.2.1+incompatible github.com/miekg/dns v1.1.41 github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.10.0 golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/sys v0.0.0-20210426080607-c94f62235c83 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 73cf68dc..1528e2d3 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,15 @@ github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/miekg/dns v1.1.41 h1:WMszZWJG0XmzbK9FEmzH2TVcqYzFesusSIB41b8KHxY= github.com/miekg/dns v1.1.41/go.mod h1:p6aan82bvRIyn+zDIv9xYNUpwa73JcSh9BKwknJysuI= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 h1:0PC75Fz/kyMGhL0e1QnypqK2kQMqKt9csD1GnMJR+Zk= golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= @@ -18,3 +24,7 @@ golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9sn golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/server.go b/server.go old mode 100644 new mode 100755 index 4d907f93..e09ba48a --- a/server.go +++ b/server.go @@ -19,7 +19,8 @@ import ( const ( // Number of Multicast responses sent for a query message (default: 1 < x < 9) - multicastRepetitions = 2 + multicastRepetitions = 2 + MulticastInterfaceError = "[WARN] mdns: Failed to set multicast interface: %v" ) // Register a service by given arguments. This call will take the system's hostname @@ -80,13 +81,24 @@ func Register(instance, service, domain string, port int, text []string, ifaces return s, nil } +type ProxyRegistrationConfig struct { + Instance string + Service string + Domain string + Port int + Host string + IPs []string + Text []string + Ifaces []net.Interface +} + // RegisterProxy registers a service proxy. This call will skip the hostname/IP lookup and // will use the provided values. -func RegisterProxy(instance, service, domain string, port int, host string, ips []string, text []string, ifaces []net.Interface) (*Server, error) { - entry := NewServiceEntry(instance, service, domain) - entry.Port = port - entry.Text = text - entry.HostName = host +func RegisterProxy(cfg ProxyRegistrationConfig) (*Server, error) { + entry := NewServiceEntry(cfg.Instance, cfg.Service, cfg.Domain) + entry.Port = cfg.Port + entry.Text = cfg.Text + entry.HostName = cfg.Host if entry.Instance == "" { return nil, fmt.Errorf("missing service instance name") @@ -108,7 +120,7 @@ func RegisterProxy(instance, service, domain string, port int, host string, ips entry.HostName = fmt.Sprintf("%s.%s.", trimDot(entry.HostName), trimDot(entry.Domain)) } - for _, ip := range ips { + for _, ip := range cfg.IPs { ipAddr := net.ParseIP(ip) if ipAddr == nil { return nil, fmt.Errorf("failed to parse given IP: %v", ip) @@ -121,11 +133,11 @@ func RegisterProxy(instance, service, domain string, port int, host string, ips } } - if len(ifaces) == 0 { - ifaces = listMulticastInterfaces() + if len(cfg.Ifaces) == 0 { + cfg.Ifaces = listMulticastInterfaces() } - s, err := newServer(ifaces) + s, err := newServer(cfg.Ifaces) if err != nil { return nil, err } @@ -159,11 +171,11 @@ type Server struct { func newServer(ifaces []net.Interface) (*Server, error) { ipv4conn, err4 := joinUdp4Multicast(ifaces) if err4 != nil { - log.Printf("[zeroconf] no suitable IPv4 interface: %s", err4.Error()) + log.Printf("[zeroconf] not suitable IPv4 interface: %s", err4.Error()) } ipv6conn, err6 := joinUdp6Multicast(ifaces) if err6 != nil { - log.Printf("[zeroconf] no suitable IPv6 interface: %s", err6.Error()) + log.Printf("[zeroconf] not suitable IPv6 interface: %s", err6.Error()) } if err4 != nil && err6 != nil { // No supported interface left. @@ -289,7 +301,6 @@ func (s *Server) recv6(c *ipv6.PacketConn) { func (s *Server) parsePacket(packet []byte, ifIndex int, from net.Addr) error { var msg dns.Msg if err := msg.Unpack(packet); err != nil { - // log.Printf("[ERR] zeroconf: Failed to unpack packet: %v", err) return err } return s.handleQuery(&msg, ifIndex, from) @@ -314,7 +325,6 @@ func (s *Server) handleQuery(query *dns.Msg, ifIndex int, from net.Addr) error { resp.Answer = []dns.RR{} resp.Extra = []dns.RR{} if err = s.handleQuestion(q, &resp, query, ifIndex); err != nil { - // log.Printf("[ERR] zeroconf: failed to handle question %v: %v", q, err) continue } // Check if there is an answer @@ -344,24 +354,29 @@ func isKnownAnswer(resp *dns.Msg, query *dns.Msg) bool { return false } - if resp.Answer[0].Header().Rrtype != dns.TypePTR { - return false - } - answer := resp.Answer[0].(*dns.PTR) - - for _, known := range query.Answer { - hdr := known.Header() - if hdr.Rrtype != answer.Hdr.Rrtype { + for _, answerRR := range resp.Answer { + if answerRR.Header().Rrtype != dns.TypePTR { continue } - ptr := known.(*dns.PTR) - if ptr.Ptr == answer.Ptr && hdr.Ttl >= answer.Hdr.Ttl/2 { - // log.Printf("skipping known answer: %v", ptr) - return true + answer := answerRR.(*dns.PTR) + + matched := false + for _, known := range query.Answer { + if known.Header().Rrtype != dns.TypePTR { + continue + } + ptr := known.(*dns.PTR) + if ptr.Ptr == answer.Ptr && known.Header().Ttl >= answer.Header().Ttl/2 { + matched = true + break + } + } + if !matched { + return false // at least one PTR answer was not known } } - return false + return true } // handleQuestion is used to handle an incoming question @@ -369,7 +384,6 @@ func (s *Server) handleQuestion(q dns.Question, resp *dns.Msg, query *dns.Msg, i if s.service == nil { return nil } - switch q.Name { case s.service.ServiceTypeName(): s.serviceTypeName(resp, s.ttl) @@ -388,7 +402,6 @@ func (s *Server) handleQuestion(q dns.Question, resp *dns.Msg, query *dns.Msg, i default: // handle matching subtype query for _, subtype := range s.service.Subtypes { - subtype = fmt.Sprintf("%s._sub.%s", subtype, s.service.ServiceName()) if q.Name == subtype { s.composeBrowsingAnswers(resp, ifIndex) if isKnownAnswer(resp, query) { @@ -414,6 +427,21 @@ func (s *Server) composeBrowsingAnswers(resp *dns.Msg, ifIndex int) { } resp.Answer = append(resp.Answer, ptr) + // PTRs for subtypes + for _, subtype := range s.service.Subtypes { + subtypePTR := &dns.PTR{ + Hdr: dns.RR_Header{ + Name: subtype, + Rrtype: dns.TypePTR, + Class: dns.ClassINET, + Ttl: s.ttl, + }, + Ptr: s.service.ServiceInstanceName(), + } + resp.Answer = append(resp.Answer, subtypePTR) + } + + // SRV + TXT txt := &dns.TXT{ Hdr: dns.RR_Header{ Name: s.service.ServiceInstanceName(), @@ -525,7 +553,7 @@ func (s *Server) serviceTypeName(resp *dns.Msg, ttl uint32) { } // Perform probing & announcement -//TODO: implement a proper probing & conflict resolution +// TODO: implement a proper probing & conflict resolution func (s *Server) probe() { q := new(dns.Msg) q.SetQuestion(s.service.ServiceInstanceName(), dns.TypePTR) @@ -733,7 +761,7 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { default: iface, _ := net.InterfaceByIndex(ifIndex) if err := s.ipv4conn.SetMulticastInterface(iface); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + log.Printf(MulticastInterfaceError, err) } } s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) @@ -744,7 +772,7 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { wcm.IfIndex = intf.Index default: if err := s.ipv4conn.SetMulticastInterface(&intf); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + log.Printf(MulticastInterfaceError, err) } } s.ipv4conn.WriteTo(buf, &wcm, ipv4Addr) @@ -764,7 +792,7 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { default: iface, _ := net.InterfaceByIndex(ifIndex) if err := s.ipv6conn.SetMulticastInterface(iface); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + log.Printf(MulticastInterfaceError, err) } } s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) @@ -775,7 +803,7 @@ func (s *Server) multicastResponse(msg *dns.Msg, ifIndex int) error { wcm.IfIndex = intf.Index default: if err := s.ipv6conn.SetMulticastInterface(&intf); err != nil { - log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err) + log.Printf(MulticastInterfaceError, err) } } s.ipv6conn.WriteTo(buf, &wcm, ipv6Addr) diff --git a/service_record_test.go b/service_record_test.go new file mode 100644 index 00000000..9769c95e --- /dev/null +++ b/service_record_test.go @@ -0,0 +1,46 @@ +package zeroconf + +import ( + "reflect" + "testing" +) + +func TestNewServiceRecordSubtypes(t *testing.T) { + tests := []struct { + name string + serviceInput string + domain string + expected []string + }{ + { + name: "single subtype", + serviceInput: "_http._tcp,_printer", + domain: "local", + expected: []string{"_printer._sub._http._tcp.local."}, + }, + { + name: "multiple subtypes", + serviceInput: "_ftp._tcp,_secure,_fast", + domain: "local", + expected: []string{ + "_secure._sub._ftp._tcp.local.", + "_fast._sub._ftp._tcp.local.", + }, + }, + { + name: "no subtypes", + serviceInput: "_ssh._tcp", + domain: "local", + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sr := NewServiceRecord("TestInstance", tt.serviceInput, tt.domain) + if !reflect.DeepEqual(sr.Subtypes, tt.expected) { + t.Errorf("Subtypes = %v; want %v", sr.Subtypes, tt.expected) + } + }) + } +} diff --git a/service_test.go b/service_test.go index 2c5a23ed..ff4315f1 100644 --- a/service_test.go +++ b/service_test.go @@ -3,10 +3,14 @@ package zeroconf import ( "context" "log" + "net" + "os" + "strings" "testing" "time" "github.com/pkg/errors" + "github.com/stretchr/testify/assert" ) var ( @@ -17,6 +21,16 @@ var ( mdnsPort = 8888 ) +const ( + expectedResolverSuccessMsg = "Expected create resolver success, but got %v" + expectedBrowseSuccessMsg = "Expected browse success, but got %v" + expectedNumEntriesMsg = "Expected number of service entries is 1, but got %d" + expectedDomainMsg = "Expected domain is %s, but got %s" + expectedServiceMsg = "Expected service is %s, but got %s" + expectedInstanceMsg = "Expected instance is %s, but got %s" + expectedPortMsg = "Expected port is %d, but got %d" +) + func startMDNS(ctx context.Context, port int, name, service, domain string) { // 5353 is default mdns port server, err := Register(name, service, domain, port, []string{"txtv=0", "lo=1", "la=2"}, nil) @@ -42,36 +56,36 @@ func TestBasic(t *testing.T) { resolver, err := NewResolver(nil) if err != nil { - t.Fatalf("Expected create resolver success, but got %v", err) + t.Fatalf(expectedResolverSuccessMsg, err) } entries := make(chan *ServiceEntry, 100) if err := resolver.Browse(ctx, mdnsService, mdnsDomain, entries); err != nil { - t.Fatalf("Expected browse success, but got %v", err) + t.Fatalf(expectedBrowseSuccessMsg, err) } <-ctx.Done() if len(entries) != 1 { - t.Fatalf("Expected number of service entries is 1, but got %d", len(entries)) + t.Fatalf(expectedNumEntriesMsg, len(entries)) } result := <-entries if result.Domain != mdnsDomain { - t.Fatalf("Expected domain is %s, but got %s", mdnsDomain, result.Domain) + t.Fatalf(expectedDomainMsg, mdnsDomain, result.Domain) } if result.Service != mdnsService { - t.Fatalf("Expected service is %s, but got %s", mdnsService, result.Service) + t.Fatalf(expectedServiceMsg, mdnsService, result.Service) } if result.Instance != mdnsName { - t.Fatalf("Expected instance is %s, but got %s", mdnsName, result.Instance) + t.Fatalf(expectedInstanceMsg, mdnsName, result.Instance) } if result.Port != mdnsPort { - t.Fatalf("Expected port is %d, but got %d", mdnsPort, result.Port) + t.Fatalf(expectedPortMsg, mdnsPort, result.Port) } } func TestNoRegister(t *testing.T) { resolver, err := NewResolver(nil) if err != nil { - t.Fatalf("Expected create resolver success, but got %v", err) + t.Fatalf(expectedResolverSuccessMsg, err) } // before register, mdns resolve shuold not have any entry @@ -85,7 +99,7 @@ func TestNoRegister(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) if err := resolver.Browse(ctx, mdnsService, mdnsDomain, entries); err != nil { - t.Fatalf("Expected browse success, but got %v", err) + t.Fatalf(expectedBrowseSuccessMsg, err) } <-ctx.Done() cancel() @@ -102,29 +116,29 @@ func TestSubtype(t *testing.T) { resolver, err := NewResolver(nil) if err != nil { - t.Fatalf("Expected create resolver success, but got %v", err) + t.Fatalf(expectedResolverSuccessMsg, err) } entries := make(chan *ServiceEntry, 100) if err := resolver.Browse(ctx, mdnsSubtype, mdnsDomain, entries); err != nil { - t.Fatalf("Expected browse success, but got %v", err) + t.Fatalf(expectedBrowseSuccessMsg, err) } <-ctx.Done() if len(entries) != 1 { - t.Fatalf("Expected number of service entries is 1, but got %d", len(entries)) + t.Fatalf(expectedNumEntriesMsg, len(entries)) } result := <-entries if result.Domain != mdnsDomain { - t.Fatalf("Expected domain is %s, but got %s", mdnsDomain, result.Domain) + t.Fatalf(expectedDomainMsg, mdnsDomain, result.Domain) } if result.Service != mdnsService { - t.Fatalf("Expected service is %s, but got %s", mdnsService, result.Service) + t.Fatalf(expectedServiceMsg, mdnsService, result.Service) } if result.Instance != mdnsName { - t.Fatalf("Expected instance is %s, but got %s", mdnsName, result.Instance) + t.Fatalf(expectedInstanceMsg, mdnsName, result.Instance) } if result.Port != mdnsPort { - t.Fatalf("Expected port is %d, but got %d", mdnsPort, result.Port) + t.Fatalf(expectedPortMsg, mdnsPort, result.Port) } }) @@ -138,7 +152,7 @@ func TestSubtype(t *testing.T) { resolver, err := NewResolver(nil) if err != nil { - t.Fatalf("Expected create resolver success, but got %v", err) + t.Fatalf(expectedResolverSuccessMsg, err) } entries := make(chan *ServiceEntry, 100) if err := resolver.Browse(ctx, mdnsService, mdnsDomain, entries); err != nil { @@ -147,7 +161,7 @@ func TestSubtype(t *testing.T) { <-ctx.Done() if len(entries) != 1 { - t.Fatalf("Expected number of service entries is 1, but got %d", len(entries)) + t.Fatalf(expectedNumEntriesMsg, len(entries)) } result := <-entries if result.Domain != mdnsDomain { @@ -157,10 +171,94 @@ func TestSubtype(t *testing.T) { t.Fatalf("Expected service is %s, but got %s", mdnsService, result.Service) } if result.Instance != mdnsName { - t.Fatalf("Expected instance is %s, but got %s", mdnsName, result.Instance) + t.Fatalf(expectedInstanceMsg, mdnsName, result.Instance) } if result.Port != mdnsPort { - t.Fatalf("Expected port is %d, but got %d", mdnsPort, result.Port) + t.Fatalf(expectedPortMsg, mdnsPort, result.Port) } }) } + +func TestRegister(t *testing.T) { + instance := "my-service" + service := "_http._tcp" + domain := "local" + port := 8080 + text := []string{"foo=bar", "version=1"} + + ifaces, err := net.Interfaces() + if err != nil { + t.Fatalf("failed to get interfaces: %v", err) + } + + server, err := Register(instance, service, domain, port, text, ifaces) + if err != nil { + t.Fatalf("Register() failed: %v", err) + } + defer server.Shutdown() + + entry := server.service + + // Verify instance, service, domain + if entry.Instance != instance { + t.Errorf("expected instance %q, got %q", instance, entry.Instance) + } + if entry.Service != service { + t.Errorf("expected service %q, got %q", service, entry.Service) + } + if strings.TrimSuffix(entry.Domain, ".") != domain { + t.Errorf("expected domain %q, got %q", domain, entry.Domain) + } + + // Verify HostName includes system hostname and domain + hostname, _ := os.Hostname() + expectedSuffix := "." + domain + "." + if !strings.HasPrefix(entry.HostName, hostname+".") || !strings.HasSuffix(entry.HostName, expectedSuffix) { + t.Errorf("unexpected HostName format: got %q", entry.HostName) + } + + // Verify IPs are populated + if len(entry.AddrIPv4) == 0 && len(entry.AddrIPv6) == 0 { + t.Error("expected at least one IP address") + } +} + +func TestProxyRegistrationConfigSetup(t *testing.T) { + ifaces, err := net.Interfaces() + if err != nil || len(ifaces) == 0 { + t.Fatalf("Failed to get network interfaces: %v", err) + } + + config := ProxyRegistrationConfig{ + Instance: "test-instance", + Service: "_http._tcp", + Domain: "local.", + Port: 8080, + Host: "test-host", + IPs: []string{"192.168.1.10", "10.0.0.1"}, + Text: []string{"key=value", "env=dev"}, + Ifaces: []net.Interface{ifaces[0]}, // use the first available interface + } + + // Register the service proxy (Updated to new RegisterProxy function) + server, err := RegisterProxy(config) + assert.NoError(t, err) + assert.NotNil(t, server) + + // Ensure service proxy is registered correctly + assert.Equal(t, config.Instance, server.service.Instance) + assert.Equal(t, config.Service, server.service.Service) + assert.Equal(t, config.Port, server.service.Port) + assert.Equal(t, config.Host+"."+config.Domain, server.service.HostName) + assert.Equal(t, config.Text, server.service.Text) + + // Verify HostName includes system hostname and domain + expectedHostname := config.Host + "." + config.Domain + assert.Equal(t, expectedHostname, server.service.HostName) + + // Act: Test the server's proxy behavior, including probing and announcement + server.probe() + + // Assert: Verify server operations or shutdown + server.Shutdown() +}