Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 84 additions & 2 deletions hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@ import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"os"
"path/filepath"
"sort"
"strings"
"time"

"github.com/asaskevich/govalidator"
"github.com/dimchansky/utfbom"
)

// Hosts represents hosts file with the path and parsed contents of each line
type Hosts struct {
Path string // Path to the location of the hosts file that will be loaded/flushed
Lines []HostsLine // Slice containing all the lines parsed from the hosts file
Path string // Path to the location of the hosts file that will be loaded/flushed
Lines []HostsLine // Slice containing all the lines parsed from the hosts file
modTime time.Time // Track file modification time

ips lookup
hosts lookup
Expand Down Expand Up @@ -88,6 +91,13 @@ func (h *Hosts) Load() error {
}
defer file.Close()

// Capture modification time for concurrent modification detection
info, err := file.Stat()
if err != nil {
return err
}
h.modTime = info.ModTime()

h.Clear() // reset the lines and lookups in case anything was previously set

scanner := bufio.NewScanner(utfbom.SkipOnly(file))
Expand All @@ -98,6 +108,15 @@ func (h *Hosts) Load() error {
return scanner.Err()
}

// HasBeenModified checks if the hosts file was modified since it was loaded
func (h *Hosts) HasBeenModified() (bool, error) {
info, err := os.Stat(h.Path)
if err != nil {
return false, err
}
return !info.ModTime().Equal(h.modTime), nil
}

// Flush writes to the file located at Path the contents of Lines in a hostsfile format
func (h *Hosts) Flush() error {
if err := h.preFlush(); err != nil {
Expand Down Expand Up @@ -128,6 +147,34 @@ func (h *Hosts) Flush() error {
return h.Load()
}

// BackupPath returns the default backup path for the hosts file
func (h *Hosts) BackupPath() string {
return h.Path + ".bak"
}

// Backup creates a backup of the current hosts file
func (h *Hosts) Backup() error {
return h.BackupTo(h.BackupPath())
}

// BackupTo creates a backup of the hosts file to a specified path
func (h *Hosts) BackupTo(path string) error {
source, err := os.Open(h.Path)
if err != nil {
return err
}
defer source.Close()

dest, err := os.Create(path)
if err != nil {
return err
}
defer dest.Close()

_, err = io.Copy(dest, source)
return err
}

// AddRaw takes a line from a hosts file and parses/adds the HostsLine
func (h *Hosts) AddRaw(raw ...string) error {
for _, r := range raw {
Expand Down Expand Up @@ -242,6 +289,41 @@ func (h *Hosts) HasIP(ip string) bool {
return len(h.ips.get(ip)) > 0
}

// HasAll returns true if the IP has ALL specified hostnames mapped to it
func (h *Hosts) HasAll(ip string, hosts ...string) bool {
if len(hosts) == 0 {
return h.HasIP(ip)
}
for _, host := range hosts {
if !h.Has(ip, host) {
return false
}
}
return true
}

// HasAny returns true if the IP has ANY of the specified hostnames mapped to it
func (h *Hosts) HasAny(ip string, hosts ...string) bool {
if len(hosts) == 0 {
return h.HasIP(ip)
}
for _, host := range hosts {
if h.Has(ip, host) {
return true
}
}
return false
}

// CheckAll returns a map indicating which hostnames exist for the IP
func (h *Hosts) CheckAll(ip string, hosts ...string) map[string]bool {
result := make(map[string]bool)
for _, host := range hosts {
result[host] = h.Has(ip, host)
}
return result
}

// Remove takes an ip and an optional host(s), if only an ip is passed the whole line is removed
// when the optional hosts param is passed it will remove only those specific hosts from that ip
func (h *Hosts) Remove(ip string, hosts ...string) error {
Expand Down
94 changes: 94 additions & 0 deletions hosts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -804,3 +804,97 @@ func TestHosts_SortIPs(t *testing.T) {
"",
}, eol), hosts.String())
}

func TestHosts_HasAll(t *testing.T) {
hosts := newHosts()
assert.Nil(t, hosts.AddRaw("127.0.0.1 host1 host2 host3"))

// All hosts exist
assert.True(t, hosts.HasAll("127.0.0.1", "host1", "host2"))
assert.True(t, hosts.HasAll("127.0.0.1", "host1", "host2", "host3"))

// One host missing
assert.False(t, hosts.HasAll("127.0.0.1", "host1", "host4"))
assert.False(t, hosts.HasAll("127.0.0.1", "host4"))

// IP doesn't exist
assert.False(t, hosts.HasAll("10.0.0.1", "host1"))

// Empty hosts - should check IP exists
assert.True(t, hosts.HasAll("127.0.0.1"))
assert.False(t, hosts.HasAll("10.0.0.1"))
}

func TestHosts_HasAny(t *testing.T) {
hosts := newHosts()
assert.Nil(t, hosts.AddRaw("127.0.0.1 host1 host2"))

// At least one exists
assert.True(t, hosts.HasAny("127.0.0.1", "host1", "host4"))
assert.True(t, hosts.HasAny("127.0.0.1", "host4", "host2"))

// None exist
assert.False(t, hosts.HasAny("127.0.0.1", "host3", "host4"))
assert.False(t, hosts.HasAny("10.0.0.1", "host1"))

// Empty hosts - should check IP exists
assert.True(t, hosts.HasAny("127.0.0.1"))
assert.False(t, hosts.HasAny("10.0.0.1"))
}

func TestHosts_CheckAll(t *testing.T) {
hosts := newHosts()
assert.Nil(t, hosts.AddRaw("127.0.0.1 host1 host2"))

result := hosts.CheckAll("127.0.0.1", "host1", "host2", "host3")
assert.True(t, result["host1"])
assert.True(t, result["host2"])
assert.False(t, result["host3"])

// Empty result for IP that doesn't exist
result = hosts.CheckAll("10.0.0.1", "host1")
assert.False(t, result["host1"])
}

func TestHosts_Backup(t *testing.T) {
// Create a temporary hosts file
fp := filepath.Join(os.TempDir(), fmt.Sprintf("hostsfile-test-%s", randomString(8)))
f, err := os.Create(fp)
assert.Nil(t, err)
defer os.Remove(fp)

// Write test content
testContent := "127.0.0.1 localhost\n192.168.1.1 testhost\n"
_, err = f.WriteString(testContent)
assert.Nil(t, err)
assert.Nil(t, f.Close())

// Load the hosts file
hosts, err := NewCustomHosts(fp)
assert.Nil(t, err)

// Create backup
err = hosts.Backup()
assert.Nil(t, err)
defer os.Remove(hosts.BackupPath())

// Verify backup exists and has correct content
backupData, err := os.ReadFile(hosts.BackupPath())
assert.Nil(t, err)
assert.Equal(t, testContent, string(backupData))

// Test BackupTo with custom path
customBackupPath := fp + ".custom.bak"
err = hosts.BackupTo(customBackupPath)
assert.Nil(t, err)
defer os.Remove(customBackupPath)

customBackupData, err := os.ReadFile(customBackupPath)
assert.Nil(t, err)
assert.Equal(t, testContent, string(customBackupData))
}

func TestHosts_BackupPath(t *testing.T) {
hosts := &Hosts{Path: "/etc/hosts"}
assert.Equal(t, "/etc/hosts.bak", hosts.BackupPath())
}
6 changes: 3 additions & 3 deletions hostsline.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ func NewHostsLine(raw string) HostsLine {
output := HostsLine{Raw: raw}

if output.HasComment() { //trailing comment
commentSplit := strings.Split(output.Raw, commentChar)
raw = commentSplit[0]
output.Comment = commentSplit[1]
idx := strings.Index(output.Raw, commentChar)
raw = output.Raw[:idx]
output.Comment = output.Raw[idx+1:]
}

if output.IsComment() { //whole line is comment
Expand Down
20 changes: 20 additions & 0 deletions hostsline_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,23 @@ func TestHosts_combine(t *testing.T) {
hl2.Combine(hl1) // should have dupes removed
assert.Equal(t, "127.0.0.1 test2 test1 test2", hl2.String())
}

func TestHostsline_CommentWithMultipleHashes(t *testing.T) {
// Test that comments with multiple # characters are preserved correctly
raw := "127.0.0.1 localhost # comment with # symbol in it"
hl := NewHostsLine(raw)

assert.Equal(t, "127.0.0.1", hl.IP)
assert.Equal(t, []string{"localhost"}, hl.Hosts)
assert.Equal(t, " comment with # symbol in it", hl.Comment)
assert.Equal(t, raw, hl.ToRaw())

// Test another case
raw2 := "192.168.1.1 host1 host2 # first # second # third"
hl2 := NewHostsLine(raw2)

assert.Equal(t, "192.168.1.1", hl2.IP)
assert.Equal(t, []string{"host1", "host2"}, hl2.Hosts)
assert.Equal(t, " first # second # third", hl2.Comment)
assert.Equal(t, raw2, hl2.ToRaw())
}
Loading