diff --git a/hosts.go b/hosts.go index 81b64be..1f7f1a5 100644 --- a/hosts.go +++ b/hosts.go @@ -4,11 +4,13 @@ import ( "bufio" "bytes" "fmt" + "io" "net" "os" "path/filepath" "sort" "strings" + "time" "github.com/asaskevich/govalidator" "github.com/dimchansky/utfbom" @@ -16,8 +18,9 @@ import ( // Hosts represents hosts file with the path and parsed contents of each line type Hosts struct { - Path string // Path to the location of the hosts file that will be loaded/flushed - Lines []HostsLine // Slice containing all the lines parsed from the hosts file + Path string // Path to the location of the hosts file that will be loaded/flushed + Lines []HostsLine // Slice containing all the lines parsed from the hosts file + modTime time.Time // Track file modification time ips lookup hosts lookup @@ -88,6 +91,13 @@ func (h *Hosts) Load() error { } defer file.Close() + // Capture modification time for concurrent modification detection + info, err := file.Stat() + if err != nil { + return err + } + h.modTime = info.ModTime() + h.Clear() // reset the lines and lookups in case anything was previously set scanner := bufio.NewScanner(utfbom.SkipOnly(file)) @@ -98,6 +108,15 @@ func (h *Hosts) Load() error { return scanner.Err() } +// HasBeenModified checks if the hosts file was modified since it was loaded +func (h *Hosts) HasBeenModified() (bool, error) { + info, err := os.Stat(h.Path) + if err != nil { + return false, err + } + return !info.ModTime().Equal(h.modTime), nil +} + // Flush writes to the file located at Path the contents of Lines in a hostsfile format func (h *Hosts) Flush() error { if err := h.preFlush(); err != nil { @@ -128,6 +147,34 @@ func (h *Hosts) Flush() error { return h.Load() } +// BackupPath returns the default backup path for the hosts file +func (h *Hosts) BackupPath() string { + return h.Path + ".bak" +} + +// Backup creates a backup of the current hosts file +func (h *Hosts) Backup() error { + return h.BackupTo(h.BackupPath()) +} + +// BackupTo creates a backup of the hosts file to a specified path +func (h *Hosts) BackupTo(path string) error { + source, err := os.Open(h.Path) + if err != nil { + return err + } + defer source.Close() + + dest, err := os.Create(path) + if err != nil { + return err + } + defer dest.Close() + + _, err = io.Copy(dest, source) + return err +} + // AddRaw takes a line from a hosts file and parses/adds the HostsLine func (h *Hosts) AddRaw(raw ...string) error { for _, r := range raw { @@ -242,6 +289,41 @@ func (h *Hosts) HasIP(ip string) bool { return len(h.ips.get(ip)) > 0 } +// HasAll returns true if the IP has ALL specified hostnames mapped to it +func (h *Hosts) HasAll(ip string, hosts ...string) bool { + if len(hosts) == 0 { + return h.HasIP(ip) + } + for _, host := range hosts { + if !h.Has(ip, host) { + return false + } + } + return true +} + +// HasAny returns true if the IP has ANY of the specified hostnames mapped to it +func (h *Hosts) HasAny(ip string, hosts ...string) bool { + if len(hosts) == 0 { + return h.HasIP(ip) + } + for _, host := range hosts { + if h.Has(ip, host) { + return true + } + } + return false +} + +// CheckAll returns a map indicating which hostnames exist for the IP +func (h *Hosts) CheckAll(ip string, hosts ...string) map[string]bool { + result := make(map[string]bool) + for _, host := range hosts { + result[host] = h.Has(ip, host) + } + return result +} + // Remove takes an ip and an optional host(s), if only an ip is passed the whole line is removed // when the optional hosts param is passed it will remove only those specific hosts from that ip func (h *Hosts) Remove(ip string, hosts ...string) error { diff --git a/hosts_test.go b/hosts_test.go index 1e1ee51..eb9184e 100644 --- a/hosts_test.go +++ b/hosts_test.go @@ -804,3 +804,97 @@ func TestHosts_SortIPs(t *testing.T) { "", }, eol), hosts.String()) } + +func TestHosts_HasAll(t *testing.T) { + hosts := newHosts() + assert.Nil(t, hosts.AddRaw("127.0.0.1 host1 host2 host3")) + + // All hosts exist + assert.True(t, hosts.HasAll("127.0.0.1", "host1", "host2")) + assert.True(t, hosts.HasAll("127.0.0.1", "host1", "host2", "host3")) + + // One host missing + assert.False(t, hosts.HasAll("127.0.0.1", "host1", "host4")) + assert.False(t, hosts.HasAll("127.0.0.1", "host4")) + + // IP doesn't exist + assert.False(t, hosts.HasAll("10.0.0.1", "host1")) + + // Empty hosts - should check IP exists + assert.True(t, hosts.HasAll("127.0.0.1")) + assert.False(t, hosts.HasAll("10.0.0.1")) +} + +func TestHosts_HasAny(t *testing.T) { + hosts := newHosts() + assert.Nil(t, hosts.AddRaw("127.0.0.1 host1 host2")) + + // At least one exists + assert.True(t, hosts.HasAny("127.0.0.1", "host1", "host4")) + assert.True(t, hosts.HasAny("127.0.0.1", "host4", "host2")) + + // None exist + assert.False(t, hosts.HasAny("127.0.0.1", "host3", "host4")) + assert.False(t, hosts.HasAny("10.0.0.1", "host1")) + + // Empty hosts - should check IP exists + assert.True(t, hosts.HasAny("127.0.0.1")) + assert.False(t, hosts.HasAny("10.0.0.1")) +} + +func TestHosts_CheckAll(t *testing.T) { + hosts := newHosts() + assert.Nil(t, hosts.AddRaw("127.0.0.1 host1 host2")) + + result := hosts.CheckAll("127.0.0.1", "host1", "host2", "host3") + assert.True(t, result["host1"]) + assert.True(t, result["host2"]) + assert.False(t, result["host3"]) + + // Empty result for IP that doesn't exist + result = hosts.CheckAll("10.0.0.1", "host1") + assert.False(t, result["host1"]) +} + +func TestHosts_Backup(t *testing.T) { + // Create a temporary hosts file + fp := filepath.Join(os.TempDir(), fmt.Sprintf("hostsfile-test-%s", randomString(8))) + f, err := os.Create(fp) + assert.Nil(t, err) + defer os.Remove(fp) + + // Write test content + testContent := "127.0.0.1 localhost\n192.168.1.1 testhost\n" + _, err = f.WriteString(testContent) + assert.Nil(t, err) + assert.Nil(t, f.Close()) + + // Load the hosts file + hosts, err := NewCustomHosts(fp) + assert.Nil(t, err) + + // Create backup + err = hosts.Backup() + assert.Nil(t, err) + defer os.Remove(hosts.BackupPath()) + + // Verify backup exists and has correct content + backupData, err := os.ReadFile(hosts.BackupPath()) + assert.Nil(t, err) + assert.Equal(t, testContent, string(backupData)) + + // Test BackupTo with custom path + customBackupPath := fp + ".custom.bak" + err = hosts.BackupTo(customBackupPath) + assert.Nil(t, err) + defer os.Remove(customBackupPath) + + customBackupData, err := os.ReadFile(customBackupPath) + assert.Nil(t, err) + assert.Equal(t, testContent, string(customBackupData)) +} + +func TestHosts_BackupPath(t *testing.T) { + hosts := &Hosts{Path: "/etc/hosts"} + assert.Equal(t, "/etc/hosts.bak", hosts.BackupPath()) +} diff --git a/hostsline.go b/hostsline.go index d0b9f3b..d5b23b4 100644 --- a/hostsline.go +++ b/hostsline.go @@ -24,9 +24,9 @@ func NewHostsLine(raw string) HostsLine { output := HostsLine{Raw: raw} if output.HasComment() { //trailing comment - commentSplit := strings.Split(output.Raw, commentChar) - raw = commentSplit[0] - output.Comment = commentSplit[1] + idx := strings.Index(output.Raw, commentChar) + raw = output.Raw[:idx] + output.Comment = output.Raw[idx+1:] } if output.IsComment() { //whole line is comment diff --git a/hostsline_test.go b/hostsline_test.go index e1dba18..8482dd4 100644 --- a/hostsline_test.go +++ b/hostsline_test.go @@ -24,3 +24,23 @@ func TestHosts_combine(t *testing.T) { hl2.Combine(hl1) // should have dupes removed assert.Equal(t, "127.0.0.1 test2 test1 test2", hl2.String()) } + +func TestHostsline_CommentWithMultipleHashes(t *testing.T) { + // Test that comments with multiple # characters are preserved correctly + raw := "127.0.0.1 localhost # comment with # symbol in it" + hl := NewHostsLine(raw) + + assert.Equal(t, "127.0.0.1", hl.IP) + assert.Equal(t, []string{"localhost"}, hl.Hosts) + assert.Equal(t, " comment with # symbol in it", hl.Comment) + assert.Equal(t, raw, hl.ToRaw()) + + // Test another case + raw2 := "192.168.1.1 host1 host2 # first # second # third" + hl2 := NewHostsLine(raw2) + + assert.Equal(t, "192.168.1.1", hl2.IP) + assert.Equal(t, []string{"host1", "host2"}, hl2.Hosts) + assert.Equal(t, " first # second # third", hl2.Comment) + assert.Equal(t, raw2, hl2.ToRaw()) +}