diff --git a/README.md b/README.md index d9550fd..5b97d80 100644 --- a/README.md +++ b/README.md @@ -31,12 +31,11 @@ Doing a lookup for service providers is also very simple: // Make a channel for results and start listening entriesCh := make(chan *mdns.ServiceEntry, 4) go func() { - for entry := range entriesCh { - fmt.Printf("Got new entry: %v\n", entry) - } + for entry := range entriesCh { + fmt.Printf("Got new entry: %v\n", entry) + } }() // Start the lookup - mdns.Lookup("_foobar._tcp", entriesCh) + mdns.ResolveService("_foobar._tcp.local.", entriesCh) close(entriesCh) - diff --git a/client.go b/client.go index 5b3cbda..f887291 100644 --- a/client.go +++ b/client.go @@ -7,57 +7,71 @@ import ( "github.com/miekg/dns" "log" "net" - "strings" "sync" "time" + //"strings" + "bytes" ) -// ServiceEntry is returned after we query for a service -type ServiceEntry struct { - Name string - Host string - AddrV4 net.IP - AddrV6 net.IP - Port int - Info string - - Addr net.IP // @Deprecated - - hasTXT bool - sent bool -} - -// complete is used to check if we have all the info we need -func (s *ServiceEntry) complete() bool { - return (s.AddrV4 != nil || s.AddrV6 != nil || s.Addr != nil) && s.Port != 0 && s.hasTXT -} - // QueryParam is used to customize how a Lookup is performed type QueryParam struct { - Service string // Service to lookup - Domain string // Lookup domain, default "local" + RecordName string // RecordName to lookup Timeout time.Duration // Lookup timeout, default 1 second Interface *net.Interface // Multicast interface to use - Entries chan<- *ServiceEntry // Entries Channel + QueryType uint16 // dns Type Constant to use + Entries chan<- dns.RR // Entries Channel +} + +type operationType string +const ( + SUBSCRIBE operationType = "SUBSCRIBE" + UNSUBSCRIBE operationType = "UNSUBSCRIBE" + CLOSE operationType = "CLOSE" +) + +type subscriptionMessage struct { + Operation operationType + Channel chan<- dns.RR } // DefaultParams is used to return a default set of QueryParam's -func DefaultParams(service string) *QueryParam { +func DefaultParams(recordName string) *QueryParam { return &QueryParam{ - Service: service, - Domain: "local", + RecordName: recordName, + QueryType: dns.TypeANY, Timeout: time.Second, - Entries: make(chan *ServiceEntry), + Entries: make(chan dns.RR), + } +} + +func EscapeName(name string) string { + var outputBuffer bytes.Buffer + + previousIsSlash := false + for _, c := range name { + if c == ' ' && !previousIsSlash { + outputBuffer.WriteRune('\\') + } + + outputBuffer.WriteRune(c) + + if c == '\\' { + previousIsSlash = true + } else { + previousIsSlash = false + } } + + return outputBuffer.String() } -// Query looks up a given service, in a domain, waiting at most +// Query looks up a given recordName, in a domain, waiting at most // for a timeout before finishing the query. The results are streamed // to a channel. Sends will not block, so clients should make sure to // either read or buffer. func Query(params *QueryParam) error { // Create a new client - client, err := newClient() + client, err := NewClient() if err != nil { return err } @@ -65,32 +79,28 @@ func Query(params *QueryParam) error { // Set the multicast interface if params.Interface != nil { - if err := client.setInterface(params.Interface); err != nil { + if err := client.SetInterface(params.Interface); err != nil { return err } } - // Ensure defaults are set - if params.Domain == "" { - params.Domain = "local" - } if params.Timeout == 0 { params.Timeout = time.Second } // Run the query - return client.query(params) + return client.Query(params) } // Lookup is the same as Query, however it uses all the default parameters -func Lookup(service string, entries chan<- *ServiceEntry) error { - params := DefaultParams(service) +func Lookup(recordName string, entries chan<- dns.RR) error { + params := DefaultParams(recordName) params.Entries = entries return Query(params) } // Client provides a query interface that can be used to -// search for service providers using mDNS +// search for recordName providers using mDNS type client struct { ipv4List *net.UDPConn ipv6List *net.UDPConn @@ -98,17 +108,21 @@ type client struct { closed bool closedCh chan struct{} closeLock sync.Mutex + + msgChan chan *dns.Msg + subscriptionChannel chan subscriptionMessage + //subscriberChans []chan dns.RR } // NewClient creates a new mdns Client that can be used to query // for records -func newClient() (*client, error) { +func NewClient() (*client, error) { // Create a IPv4 listener - ipv4, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + ipv4, err := net.ListenMulticastUDP("udp4", nil, ipv4Addr) if err != nil { log.Printf("[ERR] mdns: Failed to bind to udp4 port: %v", err) } - ipv6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) + ipv6, err := net.ListenMulticastUDP("udp6", nil, ipv6Addr) if err != nil { log.Printf("[ERR] mdns: Failed to bind to udp6 port: %v", err) } @@ -121,10 +135,62 @@ func newClient() (*client, error) { ipv4List: ipv4, ipv6List: ipv6, closedCh: make(chan struct{}), + msgChan: make(chan *dns.Msg, 32), + subscriptionChannel: make(chan subscriptionMessage, 32), } + go c.broadcastAll() return c, nil } +func (c *client) broadcastAll() { + go c.recv(c.ipv4List, c.msgChan) + go c.recv(c.ipv6List, c.msgChan) + + entryCache := make([]dns.RR, 0) + subscriberChans := make([]chan<- dns.RR, 0) + + for { + select { + case msg := <- c.msgChan: + for _, answer := range msg.Answer { + entryCache = append(entryCache, answer) + for _, channel := range subscriberChans { + channel <- answer + } + } + for _, answer := range msg.Extra { + entryCache = append(entryCache, answer) + for _, channel := range subscriberChans { + channel <- answer + } + } + + case msg := <- c.subscriptionChannel: + switch msg.Operation { + case SUBSCRIBE: + //fmt.Println("Subscribe") + subscriberChans = append(subscriberChans, msg.Channel) + for _, entry := range entryCache { + msg.Channel <- entry + } + break + case UNSUBSCRIBE: + //fmt.Println("Unsubscribe") + for idx, channel := range subscriberChans { + if channel == msg.Channel { + subscriberChans = append(subscriberChans[:idx],subscriberChans[idx + 1:]...) + close(channel) + break + } + } + break + case CLOSE: + return + } + } + } +} + // Close is used to cleanup the client func (c *client) Close() error { c.closeLock.Lock() @@ -145,9 +211,9 @@ func (c *client) Close() error { return nil } -// setInterface is used to set the query interface, uses sytem +// setInterface is used to set the query interface, uses system // default if not provided -func (c *client) setInterface(iface *net.Interface) error { +func (c *client) SetInterface(iface *net.Interface) error { p := ipv4.NewPacketConn(c.ipv4List) if err := p.SetMulticastInterface(iface); err != nil { return err @@ -159,94 +225,48 @@ func (c *client) setInterface(iface *net.Interface) error { return nil } -// query is used to perform a lookup and stream results -func (c *client) query(params *QueryParam) error { - // Create the service name - serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain)) +func (c *client) Subscribe() chan dns.RR { + channel := make(chan dns.RR) + c.subscriptionChannel <- subscriptionMessage{ + Operation: SUBSCRIBE, + Channel: channel, + } + return channel +} - // Start listening for response packets - msgCh := make(chan *dns.Msg, 32) - go c.recv(c.ipv4List, msgCh) - go c.recv(c.ipv6List, msgCh) +func (c *client) Unsubscribe(channel chan dns.RR) { + c.subscriptionChannel <- subscriptionMessage{ + Operation: UNSUBSCRIBE, + Channel: channel, + } +} + +// query is used to perform a lookup and stream results +func (c *client) Query(params *QueryParam) error { + // Create the recordName name + recordName := EscapeName(params.RecordName) + answerChan := c.Subscribe() + + go func() { + for answer := range answerChan { + if (answer.Header().Name == recordName) && (params.QueryType == dns.TypeANY || answer.Header().Rrtype == params.QueryType) { + params.Entries <- answer + } + } + }() // Send the query m := new(dns.Msg) - m.SetQuestion(serviceAddr, dns.TypePTR) - m.RecursionDesired = false + m.SetQuestion(recordName, params.QueryType) if err := c.sendQuery(m); err != nil { return nil } - // Map the in-progress responses - inprogress := make(map[string]*ServiceEntry) - - // Listen until we reach the timeout - finish := time.After(params.Timeout) - for { - select { - case resp := <-msgCh: - var inp *ServiceEntry - for _, answer := range resp.Answer { - switch rr := answer.(type) { - case *dns.PTR: - // Create new entry for this - inp = ensureName(inprogress, rr.Ptr) - - case *dns.SRV: - // Check for a target mismatch - if rr.Target != rr.Hdr.Name { - alias(inprogress, rr.Hdr.Name, rr.Target) - } - - // Get the port - inp = ensureName(inprogress, rr.Hdr.Name) - inp.Host = rr.Target - inp.Port = int(rr.Port) - - case *dns.TXT: - // Pull out the txt - inp = ensureName(inprogress, rr.Hdr.Name) - inp.Info = strings.Join(rr.Txt, "|") - inp.hasTXT = true - - case *dns.A: - // Pull out the IP - inp = ensureName(inprogress, rr.Hdr.Name) - inp.Addr = rr.A // @Deprecated - inp.AddrV4 = rr.A - - case *dns.AAAA: - // Pull out the IP - inp = ensureName(inprogress, rr.Hdr.Name) - inp.Addr = rr.AAAA // @Deprecated - inp.AddrV6 = rr.AAAA - } - } - - // Check if this entry is complete - if inp.complete() { - if inp.sent { - continue - } - inp.sent = true - select { - case params.Entries <- inp: - default: - } - } else { - // Fire off a node specific query - m := new(dns.Msg) - m.SetQuestion(inp.Name, dns.TypePTR) - m.RecursionDesired = false - if err := c.sendQuery(m); err != nil { - log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err) - } - } - case <-finish: - return nil - } + select { + case <- time.After(params.Timeout): + c.Unsubscribe(answerChan) + return nil } - return nil } // sendQuery is used to multicast a query out @@ -287,21 +307,3 @@ func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) { } } } - -// ensureName is used to ensure the named node is in progress -func ensureName(inprogress map[string]*ServiceEntry, name string) *ServiceEntry { - if inp, ok := inprogress[name]; ok { - return inp - } - inp := &ServiceEntry{ - Name: name, - } - inprogress[name] = inp - return inp -} - -// alias is used to setup an alias between two entries -func alias(inprogress map[string]*ServiceEntry, src, dst string) { - srcEntry := ensureName(inprogress, src) - inprogress[dst] = srcEntry -} diff --git a/dnssd_client.go b/dnssd_client.go new file mode 100644 index 0000000..faf369e --- /dev/null +++ b/dnssd_client.go @@ -0,0 +1,227 @@ +package mdns + +import ( + "net" + "github.com/miekg/dns" + "fmt" + "strings" +) + +// ServiceEntry is returned after we query for a service +type ServiceEntry struct { + ServiceName string + ServiceInstanceName string + ServiceHost string + Priority uint16 + Weight uint16 + Addr net.IP + Port uint16 + PropertyList map[string]interface{} +} + +func findIP(client *client, entries chan<- ServiceEntry, service ServiceEntry) { + seenList := map[string]bool{} + resultChannel := make(chan dns.RR) + + go func() { + for result := range resultChannel { + if result, ok := result.(*dns.A); ok { + if !seenList[string(result.A)] { + seenList[string(result.A)] = true + newService := ServiceEntry{ + ServiceName: service.ServiceName, + ServiceInstanceName: service.ServiceInstanceName, + ServiceHost: service.ServiceHost, + Port: service.Port, + Priority: service.Priority, + Weight: service.Weight, + PropertyList: service.PropertyList, + Addr: result.A, + } + + entries <- newService + } + } + + if result, ok := result.(*dns.AAAA); ok { + if !seenList[string(result.AAAA)] { + seenList[string(result.AAAA)] = true + newService := ServiceEntry{ + ServiceName: service.ServiceName, + ServiceInstanceName: service.ServiceInstanceName, + ServiceHost: service.ServiceHost, + Port: service.Port, + Priority: service.Priority, + Weight: service.Weight, + PropertyList: service.PropertyList, + Addr: result.AAAA, + } + + entries <- newService + } + } + } + }() + + params := DefaultParams(service.ServiceHost) + params.QueryType = dns.TypeA + params.Entries = resultChannel + go client.Query(params) + + params = DefaultParams(service.ServiceHost) + params.QueryType = dns.TypeAAAA + params.Entries = resultChannel + go client.Query(params) + +} + + +func parseTXT(items []string) map[string]interface{} { + propertyList := map[string]interface{}{} + + for _, item := range items { + if len(item) == 0 { + continue + } + + if item[0] == '=' { + // key cannot start with a '=' + continue + } + + var key string + var value interface{} + for idx, c := range item { + // Find first instance of '=', everything after is value + if c == '=' { + // Keys are case insensitive + key = strings.ToUpper(string(item[:idx])) + value = item[idx + 1:] + break + } else if idx + 1 == len(item) { + // If item does not have a '=', interpret as a bool + key = strings.ToUpper(item) + value = true + fmt.Println(key) + break + } + } + + if propertyList[key] == nil { + // Only the first instance of a key is respected + propertyList[key] = value + } + } + + return propertyList +} + +func findTXT(client *client, entries chan<- ServiceEntry, service ServiceEntry) { + // DNS-SD is required to have a TXT record + + seenList := map[string]bool{} + resultChannel := make(chan dns.RR) + + params := DefaultParams(service.ServiceInstanceName) + params.QueryType = dns.TypeTXT + params.Entries = resultChannel + + go func() { + for result := range resultChannel { + if result, ok := result.(*dns.TXT); ok { + if !seenList[result.Hdr.Name] { + seenList[result.Hdr.Name] = true + newService := ServiceEntry{ + ServiceName: service.ServiceName, + ServiceInstanceName: service.ServiceInstanceName, + ServiceHost: service.ServiceHost, + Port: service.Port, + Priority: service.Priority, + Weight: service.Weight, + PropertyList: parseTXT(result.Txt), + } + + + //entries <- newService + findIP(client, entries, newService) + } + } + } + }() + + go client.Query(params) + +} + +func findSRV(client *client, entries chan<- ServiceEntry, service ServiceEntry) { + seenList := map[string]bool{} + resultChannel := make(chan dns.RR) + + params := DefaultParams(service.ServiceInstanceName) + params.QueryType = dns.TypeSRV + params.Entries = resultChannel + + go func() { + for result := range resultChannel { + if result, ok := result.(*dns.SRV); ok { + if !seenList[result.Target] { + seenList[result.Target] = true + newService := ServiceEntry{ + ServiceName: service.ServiceName, + ServiceInstanceName: service.ServiceInstanceName, + ServiceHost: result.Target, + Port: result.Port, + Priority: result.Priority, + Weight: result.Weight, + } + findTXT(client, entries, newService) + } + } + } + }() + + go client.Query(params) +} + + + + +func findPTR(client *client, params *QueryParam, entries chan<- ServiceEntry) { + seenList := map[string]bool{} + resultChannel := make(chan dns.RR) + params.Entries = resultChannel + + go func() { + for result := range resultChannel { + if result, ok := result.(*dns.PTR); ok { + if !seenList[result.Ptr] { + seenList[result.Ptr] = true + service := ServiceEntry{ + ServiceName: params.RecordName, + ServiceInstanceName: result.Ptr, + } + //fmt.Println("PTR: " + result.Ptr) + findSRV(client, entries, service) + } + } + } + }() +} + +func ResolveService(service string, entries chan<- ServiceEntry) error { + params := DefaultParams(service) + params.QueryType = dns.TypePTR + + client, err := NewClient() + if err != nil { + fmt.Println(err) + return err + } + + findPTR(client, params, entries) + //processResultsSRV(client, resultCh, entries) + + + client.Query(params) + return nil +}