diff --git a/example_conf.yaml b/example_conf.yaml new file mode 100644 index 0000000..4c4f669 --- /dev/null +++ b/example_conf.yaml @@ -0,0 +1,4 @@ +# TFTP path regex Map to http/https url +/goog/it: http://192.168.10.72/test0 +/abc/xyz/(.*\.iso): http://192.168.10.72/test1/$1 +/abc/def/(.*)/foobar/(.*\.iso): http://192.168.10.72/test2/$1/X/$2 diff --git a/tftp2http.go b/tftp2http.go index 4bad9dd..e6ab578 100644 --- a/tftp2http.go +++ b/tftp2http.go @@ -2,50 +2,96 @@ package main import ( "flag" + "errors" "fmt" + "gopkg.in/yaml.v2" "io" + "io/ioutil" "log" "net/http" "net/url" "os" + "regexp" + "strings" "time" - "github.com/pin/tftp" ) var ( + map_file = flag.String("map-file", "", "YAML file which maps TFTP path regular-expressions into HTTP urls") listenAddr = flag.String("listen", ":69", "Listen address and port") httpTimeout = flag.Int("http-timeout", 5, "HTTP timeout") httpMaxIdle = flag.Int("http-max-idle", 10, "HTTP max idle connections") tftpTimeout = flag.Float64("tftp-timeout", 1, "TFTP timeout") tftpRetries = flag.Int("tftp-retries", 5, "TFTP retries") - serverUrl *url.URL ) +type RedirectItem struct { + tftp_regexp regexp.Regexp + http_url url.URL +} +var redirect_map = make(map[string]RedirectItem) var client *http.Client func processFlags() error { flag.Parse() - var err error - if len(flag.Args()) != 1 { - return fmt.Errorf("must provide server") + var path_map = make(map[string]string) + if *map_file != "" { + cfg_yaml, err := ioutil.ReadFile(*map_file) + err = yaml.Unmarshal(cfg_yaml, &path_map) + if err != nil { + return err + } } - server := flag.Args()[0] - serverUrl, err = url.Parse(server) - if err != nil { - return fmt.Errorf("url: error parsing: '%s'", server) + if len(flag.Args()) == 1 { + server := flag.Args()[0] + path_map["/"] = server } - if serverUrl.Scheme != "http" && serverUrl.Scheme != "https" { - return fmt.Errorf("url: invalid scheme: '%s'", serverUrl.Scheme) + if len(path_map) == 0 { + return fmt.Errorf("No servers defined") } - if len(serverUrl.Host) == 0 { - return fmt.Errorf("url: host must be provided") + for k, v := range path_map { + log.Printf("\t%s\t%s", k, v) + http_url, err := url.Parse(v) + if err != nil { + return fmt.Errorf("url: error parsing: '%s'", v) + } + if http_url.Scheme != "http" && http_url.Scheme != "https" { + return fmt.Errorf("url: invalid scheme: '%s'", http_url.Scheme) + } + if len(http_url.Host) == 0 { + return fmt.Errorf("url: host must be provided") + } + regexp_path, err := regexp.Compile(k) + if err != nil { + return fmt.Errorf("url: illegal regexp '" + k + "'") + } + redirect_map[k] = RedirectItem{ + tftp_regexp: *regexp_path, + http_url: *http_url, + } } - return nil } +func map_tftp_path(tftp_path string) (url.URL, error) { + for k, redirect_item := range redirect_map { + path_regexp := redirect_item.tftp_regexp + if ! path_regexp.MatchString(tftp_path) { + continue + } + new_url := redirect_item.http_url + result_path := path_regexp.ReplaceAllString(tftp_path, new_url.Path) + log.Printf("map_tftp_path: Matched '%s' (result_path '%s')", k, result_path) + new_url.Path = result_path + log.Printf("map_tftp_path: translate '%s' -> '%s'", tftp_path, new_url.String()) + return new_url, nil + } + log.Printf("map_tftp_path: did not find '%s'", tftp_path) + return *new(url.URL), errors.New("Cannot map") +} + func setForwardedHeader(req *http.Request, from string) { req.Header.Add("X-Forwarded-From", from) req.Header.Add("X-Forwarded-Proto", "tftp") @@ -57,9 +103,15 @@ func readHandler(filename string, rf io.ReaderFrom) error { from := raddr.String() log.Printf("{%s} received RRQ '%s'", from, filename) - - u := *serverUrl - u.Path += filename + if ! strings.HasPrefix(filename, "/") { + filename = "/" + filename + } + http_url, err := map_tftp_path(filename) + if err != nil { + log.Printf("Cannot map '%s'", filename) + return err + } + log.Printf("Translate to '%s'", http_url.String()) start := time.Now() var n int64 @@ -68,7 +120,7 @@ func readHandler(filename string, rf io.ReaderFrom) error { log.Printf("{%s} completed RRQ '%s' bytes:%d,duration:%s", from, filename, n, elapsed) }() - req, err := http.NewRequest("GET", u.String(), nil) + req, err := http.NewRequest("GET", http_url.String(), nil) setForwardedHeader(req, from) res, err := client.Do(req) @@ -106,9 +158,15 @@ func writeHandler(filename string, wt io.WriterTo) error { from := raddr.String() log.Printf("{%s} received WRQ '%s'", filename, from) - - u := *serverUrl - u.Path += filename + if ! strings.HasPrefix(filename, "/") { + filename = "/" + filename + } + http_url, err := map_tftp_path(filename) + if err != nil { + log.Printf("Cannot map '%s'", filename) + return err + } + log.Printf("Translate to '%s'", http_url.String()) start := time.Now() var n int64 @@ -125,7 +183,7 @@ func writeHandler(filename string, wt io.WriterTo) error { done <- n }() - req, err := http.NewRequest("PUT", u.String(), r) + req, err := http.NewRequest("PUT", http_url.String(), r) setForwardedHeader(req, from) req.Header.Add("Content-Type", "application/octet-stream") req.ContentLength = -1 @@ -160,7 +218,7 @@ func main() { s.SetRetries(*tftpRetries) // s.SetBackoff(func (int) time.Duration { return 0 }) - log.Printf("proxying TFTP requests on %s to %s", *listenAddr, serverUrl) + log.Printf("proxying TFTP requests on %s", *listenAddr) err := s.ListenAndServe(*listenAddr) if err != nil { log.Fatalf("server: %v\n", err)