1
mirror of https://github.com/rclone/rclone synced 2025-01-25 07:47:29 +01:00

http servers: allow CORS to be set with --allow-origin flag - fixes #5078

Some changes about test cases:
Because MiddlewareCORS will return early on OPTIONS request,
this middleware should only be used once at NewServer function.
Test cases should pass AllowOrigin config instead of adding
this middleware again.

A new test case was added to test CORS preflight request with
an authenticator. Preflight request should always return 200 OK
regardless of autentications.

Co-authored-by: yuudi <yuudi@users.noreply.github.com>
This commit is contained in:
yuudi 2023-07-26 05:15:54 -04:00 committed by GitHub
parent 3ed4a2e963
commit 6c8148ef39
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 95 additions and 75 deletions

View File

@ -30,7 +30,6 @@ type Options struct {
WebGUIForceUpdate bool // set to force download new update
WebGUINoOpenBrowser bool // set to disable auto opening browser
WebGUIFetchURL string // set the default url for fetching webgui
AccessControlAllowOrigin string // set the access control for CORS configuration
EnableMetrics bool // set to disable prometheus metrics on /metrics
JobExpireDuration time.Duration
JobExpireInterval time.Duration

View File

@ -27,7 +27,6 @@ func AddFlags(flagSet *pflag.FlagSet) {
flags.BoolVarP(flagSet, &Opt.WebGUIForceUpdate, "rc-web-gui-force-update", "", false, "Force update to latest version of web gui")
flags.BoolVarP(flagSet, &Opt.WebGUINoOpenBrowser, "rc-web-gui-no-open-browser", "", false, "Don't open the browser automatically")
flags.StringVarP(flagSet, &Opt.WebGUIFetchURL, "rc-web-fetch-url", "", "https://api.github.com/repos/rclone/rclone-webui-react/releases/latest", "URL to fetch the releases for webgui")
flags.StringVarP(flagSet, &Opt.AccessControlAllowOrigin, "rc-allow-origin", "", "", "Set the allowed origin for CORS")
flags.BoolVarP(flagSet, &Opt.EnableMetrics, "rc-enable-metrics", "", false, "Enable prometheus metrics on /metrics")
flags.DurationVarP(flagSet, &Opt.JobExpireDuration, "rc-job-expire-duration", "", Opt.JobExpireDuration, "Expire finished async jobs older than this value")
flags.DurationVarP(flagSet, &Opt.JobExpireInterval, "rc-job-expire-interval", "", Opt.JobExpireInterval, "Interval to check for expired async jobs")

View File

@ -15,7 +15,6 @@ import (
"regexp"
"sort"
"strings"
"sync"
"time"
"github.com/go-chi/chi/v5/middleware"
@ -38,7 +37,6 @@ import (
)
var promHandler http.Handler
var onlyOnceWarningAllowOrigin sync.Once
func init() {
rcloneCollector := accounting.NewRcloneCollector(context.Background())
@ -214,23 +212,6 @@ func writeError(path string, in rc.Params, w http.ResponseWriter, err error, sta
func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
path := strings.TrimLeft(r.URL.Path, "/")
allowOrigin := rcflags.Opt.AccessControlAllowOrigin
if allowOrigin != "" {
onlyOnceWarningAllowOrigin.Do(func() {
if allowOrigin == "*" {
fs.Logf(nil, "Warning: Allow origin set to *. This can cause serious security problems.")
}
})
w.Header().Add("Access-Control-Allow-Origin", allowOrigin)
} else {
urls := s.server.URLs()
if len(urls) == 1 {
w.Header().Add("Access-Control-Allow-Origin", urls[0])
} else {
fs.Errorf(nil, "Warning, need exactly 1 URL for Access-Control-Allow-Origin, got %d %q", len(urls), urls)
}
}
// echo back access control headers client needs
//reqAccessHeaders := r.Header.Get("Access-Control-Request-Headers")
w.Header().Add("Access-Control-Request-Method", "POST, OPTIONS, GET, HEAD")

View File

@ -552,32 +552,6 @@ Unknown command
testServer(t, tests, &opt)
}
func TestMethods(t *testing.T) {
tests := []testRun{{
Name: "options",
URL: "",
Method: "OPTIONS",
Status: http.StatusOK,
Expected: "",
Headers: map[string]string{
"Access-Control-Allow-Origin": "testURL",
"Access-Control-Request-Method": "POST, OPTIONS, GET, HEAD",
"Access-Control-Allow-Headers": "authorization, Content-Type",
},
}, {
Name: "bad",
URL: "",
Method: "POTATO",
Status: http.StatusMethodNotAllowed,
Expected: `Method Not Allowed
`,
}}
opt := newTestOpt()
opt.Serve = true
opt.Files = testFs
testServer(t, tests, &opt)
}
func TestMetrics(t *testing.T) {
stats := accounting.GlobalStats()
tests := makeMetricsTestCases(stats)

View File

@ -181,6 +181,13 @@ func MiddlewareCORS(allowOrigin string) Middleware {
w.Header().Add("Access-Control-Request-Method", "POST, OPTIONS, GET, HEAD")
w.Header().Add("Access-Control-Allow-Headers", "authorization, Content-Type")
if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK)
return
// Because CORS preflight OPTIONS requests are not authenticated,
// and require a 200 OK response, we will return early here.
}
next.ServeHTTP(w, r)
})
}

View File

@ -331,21 +331,20 @@ func TestMiddlewareCORS(t *testing.T) {
servers := []struct {
name string
http Config
origin string
}{
{
name: "EmptyOrigin",
http: Config{
ListenAddr: []string{"127.0.0.1:0"},
AllowOrigin: "",
},
origin: "",
},
{
name: "CustomOrigin",
http: Config{
ListenAddr: []string{"127.0.0.1:0"},
AllowOrigin: "http://test.rclone.org",
},
origin: "http://test.rclone.org",
},
}
@ -357,8 +356,6 @@ func TestMiddlewareCORS(t *testing.T) {
require.NoError(t, s.Shutdown())
}()
s.Router().Use(MiddlewareCORS(ss.origin))
expected := []byte("data")
s.Router().Mount("/", testEchoHandler(expected))
s.Serve()
@ -384,8 +381,69 @@ func TestMiddlewareCORS(t *testing.T) {
}
expectedOrigin := url
if ss.origin != "" {
expectedOrigin = ss.origin
if ss.http.AllowOrigin != "" {
expectedOrigin = ss.http.AllowOrigin
}
require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match")
})
}
}
func TestMiddlewareCORSWithAuth(t *testing.T) {
authServers := []struct {
name string
http Config
auth AuthConfig
}{
{
name: "ServerWithAuth",
http: Config{
ListenAddr: []string{"127.0.0.1:0"},
AllowOrigin: "http://test.rclone.org",
},
auth: AuthConfig{
Realm: "test",
BasicUser: "test_user",
BasicPass: "test_pass",
},
},
}
for _, ss := range authServers {
t.Run(ss.name, func(t *testing.T) {
s, err := NewServer(context.Background(), WithConfig(ss.http))
require.NoError(t, err)
defer func() {
require.NoError(t, s.Shutdown())
}()
expected := []byte("data")
s.Router().Mount("/", testEchoHandler(expected))
s.Serve()
url := testGetServerURL(t, s)
client := &http.Client{}
req, err := http.NewRequest("OPTIONS", url, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer func() {
_ = resp.Body.Close()
}()
require.Equal(t, http.StatusOK, resp.StatusCode, "OPTIONS should return ok even if not authenticated")
testExpectRespBody(t, resp, []byte{})
for _, key := range _testCORSHeaderKeys {
require.Contains(t, resp.Header, key, "CORS headers should be sent even if not authenticated")
}
expectedOrigin := url
if ss.http.AllowOrigin != "" {
expectedOrigin = ss.http.AllowOrigin
}
require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match")
})

View File

@ -109,6 +109,7 @@ type Config struct {
TLSKeyBody []byte // TLS PEM Private key body, ignores TLSKey
ClientCA string // Client certificate authority to verify clients with
MinTLSVersion string // MinTLSVersion contains the minimum TLS version that is acceptable.
AllowOrigin string // AllowOrigin sets the Access-Control-Allow-Origin header
}
// AddFlagsPrefix adds flags for the httplib
@ -122,6 +123,7 @@ func (cfg *Config) AddFlagsPrefix(flagSet *pflag.FlagSet, prefix string) {
flags.StringVarP(flagSet, &cfg.ClientCA, prefix+"client-ca", "", cfg.ClientCA, "Client certificate authority to verify clients with")
flags.StringVarP(flagSet, &cfg.BaseURL, prefix+"baseurl", "", cfg.BaseURL, "Prefix for URLs - leave blank for root")
flags.StringVarP(flagSet, &cfg.MinTLSVersion, prefix+"min-tls-version", "", cfg.MinTLSVersion, "Minimum TLS version that is acceptable")
flags.StringVarP(flagSet, &cfg.AllowOrigin, prefix+"allow-origin", "", cfg.AllowOrigin, "Origin which cross-domain request (CORS) can be executed from")
}
// AddHTTPFlagsPrefix adds flags for the httplib
@ -236,6 +238,8 @@ func NewServer(ctx context.Context, options ...Option) (*Server, error) {
return nil, err
}
s.mux.Use(MiddlewareCORS(s.cfg.AllowOrigin))
s.initAuth()
for _, addr := range s.cfg.ListenAddr {

View File

@ -82,8 +82,6 @@ func TestNewServerUnix(t *testing.T) {
require.Empty(t, s.URLs(), "unix socket should not appear in URLs")
s.Router().Use(MiddlewareCORS(""))
expected := []byte("hello world")
s.Router().Mount("/", testEchoHandler(expected))
s.Serve()