diff --git a/lib/http/middleware.go b/lib/http/middleware.go index 3ee5c0cdc..06b82278b 100644 --- a/lib/http/middleware.go +++ b/lib/http/middleware.go @@ -195,6 +195,14 @@ func MiddlewareCORS(allowOrigin string) Middleware { // MiddlewareStripPrefix instantiates middleware that removes the BaseURL from the path func MiddlewareStripPrefix(prefix string) Middleware { return func(next http.Handler) http.Handler { - return http.StripPrefix(prefix, next) + stripPrefixHandler := http.StripPrefix(prefix, next) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Allow OPTIONS on the root only + if r.URL.Path == "/" && r.Method == "OPTIONS" { + next.ServeHTTP(w, r) + return + } + stripPrefixHandler.ServeHTTP(w, r) + }) } } diff --git a/lib/http/middleware_test.go b/lib/http/middleware_test.go index 15d0b0525..845e28013 100644 --- a/lib/http/middleware_test.go +++ b/lib/http/middleware_test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/http" + "strings" "testing" "github.com/stretchr/testify/require" @@ -329,8 +330,11 @@ var _testCORSHeaderKeys = []string{ func TestMiddlewareCORS(t *testing.T) { servers := []struct { - name string - http Config + name string + http Config + tryRoot bool + method string + status int }{ { name: "CustomOrigin", @@ -338,6 +342,40 @@ func TestMiddlewareCORS(t *testing.T) { ListenAddr: []string{"127.0.0.1:0"}, AllowOrigin: "http://test.rclone.org", }, + method: "GET", + status: http.StatusOK, + }, + { + name: "WithBaseURL", + http: Config{ + ListenAddr: []string{"127.0.0.1:0"}, + AllowOrigin: "http://test.rclone.org", + BaseURL: "/baseurl/", + }, + method: "GET", + status: http.StatusOK, + }, + { + name: "WithBaseURLTryRootGET", + http: Config{ + ListenAddr: []string{"127.0.0.1:0"}, + AllowOrigin: "http://test.rclone.org", + BaseURL: "/baseurl/", + }, + method: "GET", + status: http.StatusNotFound, + tryRoot: true, + }, + { + name: "WithBaseURLTryRootOPTIONS", + http: Config{ + ListenAddr: []string{"127.0.0.1:0"}, + AllowOrigin: "http://test.rclone.org", + BaseURL: "/baseurl/", + }, + method: "OPTIONS", + status: http.StatusOK, + tryRoot: true, }, } @@ -354,9 +392,14 @@ func TestMiddlewareCORS(t *testing.T) { s.Serve() url := testGetServerURL(t, s) + // Try the query on the root, ignoring the baseURL + if ss.tryRoot { + slash := strings.LastIndex(url[:len(url)-1], "/") + url = url[:slash+1] + } client := &http.Client{} - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequest(ss.method, url, nil) require.NoError(t, err) resp, err := client.Do(req) @@ -365,8 +408,11 @@ func TestMiddlewareCORS(t *testing.T) { _ = resp.Body.Close() }() - require.Equal(t, http.StatusOK, resp.StatusCode, "should return ok") + require.Equal(t, ss.status, resp.StatusCode, "should return expected error code") + if ss.status == http.StatusNotFound { + return + } testExpectRespBody(t, resp, expected) for _, key := range _testCORSHeaderKeys {