diff --git a/fs/accounting/accounting.go b/fs/accounting/accounting.go index 2686f2ebf..eeef4715a 100644 --- a/fs/accounting/accounting.go +++ b/fs/accounting/accounting.go @@ -216,6 +216,10 @@ func (acc *Account) averageLoop() { // Check the read before it has happened is valid returning the number // of bytes remaining to read. func (acc *Account) checkReadBefore() (bytesUntilLimit int64, err error) { + // Check to see if context is cancelled + if err = acc.ctx.Err(); err != nil { + return 0, err + } acc.values.mu.Lock() if acc.values.max >= 0 { bytesUntilLimit = acc.values.max - acc.stats.GetBytes() @@ -235,7 +239,7 @@ func (acc *Account) checkReadBefore() (bytesUntilLimit int64, err error) { } // Check the read call after the read has happened -func checkReadAfter(bytesUntilLimit int64, n int, err error) (outN int, outErr error) { +func (acc *Account) checkReadAfter(bytesUntilLimit int64, n int, err error) (outN int, outErr error) { bytesUntilLimit -= int64(n) if bytesUntilLimit < 0 { // chop the overage off @@ -304,7 +308,7 @@ func (acc *Account) read(in io.Reader, p []byte) (n int, err error) { if err == nil { n, err = in.Read(p) acc.accountRead(n) - n, err = checkReadAfter(bytesUntilLimit, n, err) + n, err = acc.checkReadAfter(bytesUntilLimit, n, err) } return n, err } @@ -333,7 +337,7 @@ func (awt *accountWriteTo) Write(p []byte) (n int, err error) { bytesUntilLimit, err := awt.acc.checkReadBefore() if err == nil { n, err = awt.w.Write(p) - n, err = checkReadAfter(bytesUntilLimit, n, err) + n, err = awt.acc.checkReadAfter(bytesUntilLimit, n, err) awt.acc.accountRead(n) } return n, err @@ -361,7 +365,7 @@ func (acc *Account) AccountRead(n int) (err error) { defer acc.mu.Unlock() bytesUntilLimit, err := acc.checkReadBefore() if err == nil { - n, err = checkReadAfter(bytesUntilLimit, n, err) + n, err = acc.checkReadAfter(bytesUntilLimit, n, err) acc.accountRead(n) } return err diff --git a/fs/accounting/accounting_test.go b/fs/accounting/accounting_test.go index 8161522f2..67b7e1d49 100644 --- a/fs/accounting/accounting_test.go +++ b/fs/accounting/accounting_test.go @@ -312,6 +312,25 @@ func TestAccountMaxTransferWriteTo(t *testing.T) { assert.Equal(t, ErrorMaxTransferLimitReachedFatal, err) } +func TestAccountReadCtx(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + in := ioutil.NopCloser(bytes.NewBuffer(make([]byte, 100))) + stats := NewStats() + acc := newAccountSizeName(ctx, stats, in, 1, "test") + + var b = make([]byte, 10) + + n, err := acc.Read(b) + assert.Equal(t, 10, n) + assert.NoError(t, err) + + cancel() + + n, err = acc.Read(b) + assert.Equal(t, 0, n) + assert.Equal(t, context.Canceled, err) +} + func TestShortenName(t *testing.T) { for _, test := range []struct { in string diff --git a/fs/sync/sync_test.go b/fs/sync/sync_test.go index 95cdfcef6..b43f2adc2 100644 --- a/fs/sync/sync_test.go +++ b/fs/sync/sync_test.go @@ -1040,8 +1040,6 @@ func TestSyncWithMaxDuration(t *testing.T) { startTime := time.Now() err := Sync(context.Background(), r.Fremote, r.Flocal, false) require.Equal(t, context.DeadlineExceeded, errors.Cause(err)) - err = accounting.GlobalStats().GetLastError() - require.NoError(t, err) elapsed := time.Since(startTime) maxTransferTime := (time.Duration(len(testFiles)) * 60 * time.Second) / time.Duration(bytesPerSecond)