mirror of
https://github.com/mpolden/echoip.git
synced 2024-11-10 07:27:22 +01:00
Add support for multiple trusted headers
This commit is contained in:
parent
e282ac2729
commit
91f0c17c94
@ -14,13 +14,13 @@ import (
|
||||
|
||||
func main() {
|
||||
var opts struct {
|
||||
CountryDBPath string `short:"f" long:"country-db" description:"Path to GeoIP country database" value-name:"FILE" default:""`
|
||||
CityDBPath string `short:"c" long:"city-db" description:"Path to GeoIP city database" value-name:"FILE" default:""`
|
||||
Listen string `short:"l" long:"listen" description:"Listening address" value-name:"ADDR" default:":8080"`
|
||||
ReverseLookup bool `short:"r" long:"reverse-lookup" description:"Perform reverse hostname lookups"`
|
||||
PortLookup bool `short:"p" long:"port-lookup" description:"Enable port lookup"`
|
||||
Template string `short:"t" long:"template" description:"Path to template" default:"index.html" value-name:"FILE"`
|
||||
IPHeader string `short:"H" long:"trusted-header" description:"Header to trust for remote IP, if present (e.g. X-Real-IP)" value-name:"NAME"`
|
||||
CountryDBPath string `short:"f" long:"country-db" description:"Path to GeoIP country database" value-name:"FILE" default:""`
|
||||
CityDBPath string `short:"c" long:"city-db" description:"Path to GeoIP city database" value-name:"FILE" default:""`
|
||||
Listen string `short:"l" long:"listen" description:"Listening address" value-name:"ADDR" default:":8080"`
|
||||
ReverseLookup bool `short:"r" long:"reverse-lookup" description:"Perform reverse hostname lookups"`
|
||||
PortLookup bool `short:"p" long:"port-lookup" description:"Enable port lookup"`
|
||||
Template string `short:"t" long:"template" description:"Path to template" default:"index.html" value-name:"FILE"`
|
||||
IPHeaders []string `short:"H" long:"trusted-header" description:"Header to trust for remote IP, if present (e.g. X-Real-IP)" value-name:"NAME"`
|
||||
}
|
||||
_, err := flags.ParseArgs(&opts, os.Args)
|
||||
if err != nil {
|
||||
@ -35,7 +35,7 @@ func main() {
|
||||
|
||||
server := http.New(db)
|
||||
server.Template = opts.Template
|
||||
server.IPHeader = opts.IPHeader
|
||||
server.IPHeaders = opts.IPHeaders
|
||||
if opts.ReverseLookup {
|
||||
log.Println("Enabling reverse lookup")
|
||||
server.LookupAddr = iputil.LookupAddr
|
||||
@ -44,8 +44,8 @@ func main() {
|
||||
log.Println("Enabling port lookup")
|
||||
server.LookupPort = iputil.LookupPort
|
||||
}
|
||||
if opts.IPHeader != "" {
|
||||
log.Printf("Trusting header %s to contain correct remote IP", opts.IPHeader)
|
||||
if len(opts.IPHeaders) > 0 {
|
||||
log.Printf("Trusting header(s) %+v to contain correct remote IP", opts.IPHeaders)
|
||||
}
|
||||
|
||||
log.Printf("Listening on http://%s", opts.Listen)
|
||||
|
18
http/http.go
18
http/http.go
@ -22,7 +22,7 @@ const (
|
||||
|
||||
type Server struct {
|
||||
Template string
|
||||
IPHeader string
|
||||
IPHeaders []string
|
||||
LookupAddr func(net.IP) (string, error)
|
||||
LookupPort func(net.IP, uint64) error
|
||||
db database.Client
|
||||
@ -47,8 +47,14 @@ func New(db database.Client) *Server {
|
||||
return &Server{db: db}
|
||||
}
|
||||
|
||||
func ipFromRequest(header string, r *http.Request) (net.IP, error) {
|
||||
remoteIP := r.Header.Get(header)
|
||||
func ipFromRequest(headers []string, r *http.Request) (net.IP, error) {
|
||||
remoteIP := ""
|
||||
for _, header := range headers {
|
||||
remoteIP = r.Header.Get(header)
|
||||
if remoteIP != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
if remoteIP == "" {
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
@ -64,7 +70,7 @@ func ipFromRequest(header string, r *http.Request) (net.IP, error) {
|
||||
}
|
||||
|
||||
func (s *Server) newResponse(r *http.Request) (Response, error) {
|
||||
ip, err := ipFromRequest(s.IPHeader, r)
|
||||
ip, err := ipFromRequest(s.IPHeaders, r)
|
||||
if err != nil {
|
||||
return Response{}, err
|
||||
}
|
||||
@ -91,7 +97,7 @@ func (s *Server) newPortResponse(r *http.Request) (PortResponse, error) {
|
||||
if err != nil || port < 1 || port > 65355 {
|
||||
return PortResponse{Port: port}, fmt.Errorf("invalid port: %d", port)
|
||||
}
|
||||
ip, err := ipFromRequest(s.IPHeader, r)
|
||||
ip, err := ipFromRequest(s.IPHeaders, r)
|
||||
if err != nil {
|
||||
return PortResponse{Port: port}, err
|
||||
}
|
||||
@ -104,7 +110,7 @@ func (s *Server) newPortResponse(r *http.Request) (PortResponse, error) {
|
||||
}
|
||||
|
||||
func (s *Server) CLIHandler(w http.ResponseWriter, r *http.Request) *appError {
|
||||
ip, err := ipFromRequest(s.IPHeader, r)
|
||||
ip, err := ipFromRequest(s.IPHeaders, r)
|
||||
if err != nil {
|
||||
return internalServerError(err)
|
||||
}
|
||||
|
@ -149,16 +149,17 @@ func TestJSONHandlers(t *testing.T) {
|
||||
|
||||
func TestIPFromRequest(t *testing.T) {
|
||||
var tests = []struct {
|
||||
remoteAddr string
|
||||
headerKey string
|
||||
headerValue string
|
||||
trustedHeader string
|
||||
out string
|
||||
remoteAddr string
|
||||
headerKey string
|
||||
headerValue string
|
||||
trustedHeaders []string
|
||||
out string
|
||||
}{
|
||||
{"127.0.0.1:9999", "", "", "", "127.0.0.1"}, // No header given
|
||||
{"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "", "127.0.0.1"}, // Trusted header is empty
|
||||
{"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "X-Foo-Bar", "127.0.0.1"}, // Trusted header does not match
|
||||
{"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", "X-Real-IP", "1.3.3.7"}, // Trusted header matches
|
||||
{"127.0.0.1:9999", "", "", nil, "127.0.0.1"}, // No header given
|
||||
{"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", nil, "127.0.0.1"}, // Trusted header is empty
|
||||
{"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", []string{"X-Foo-Bar"}, "127.0.0.1"}, // Trusted header does not match
|
||||
{"127.0.0.1:9999", "X-Real-IP", "1.3.3.7", []string{"X-Real-IP", "X-Forwarded-For"}, "1.3.3.7"}, // Trusted header matches
|
||||
{"127.0.0.1:9999", "X-Forwarded-For", "1.3.3.7", []string{"X-Real-IP", "X-Forwarded-For"}, "1.3.3.7"}, // Second trusted header matches
|
||||
}
|
||||
for _, tt := range tests {
|
||||
r := &http.Request{
|
||||
@ -166,7 +167,7 @@ func TestIPFromRequest(t *testing.T) {
|
||||
Header: http.Header{},
|
||||
}
|
||||
r.Header.Add(tt.headerKey, tt.headerValue)
|
||||
ip, err := ipFromRequest(tt.trustedHeader, r)
|
||||
ip, err := ipFromRequest(tt.trustedHeaders, r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user