diff --git a/fs/accounting/accounting.go b/fs/accounting/accounting.go index 1e1e1f29c..0e6b766ec 100644 --- a/fs/accounting/accounting.go +++ b/fs/accounting/accounting.go @@ -154,21 +154,23 @@ func (acc *Account) averageLoop() { } } -// read bytes from the io.Reader passed in and account them -func (acc *Account) read(in io.Reader, p []byte) (n int, err error) { +// Check the read is valid +func (acc *Account) checkRead() (err error) { acc.statmu.Lock() if acc.max >= 0 && Stats.GetBytes() >= acc.max { acc.statmu.Unlock() - return 0, ErrorMaxTransferLimitReached + return ErrorMaxTransferLimitReached } // Set start time. if acc.start.IsZero() { acc.start = time.Now() } acc.statmu.Unlock() + return nil +} - n, err = in.Read(p) - +// Account the read and limit bandwidth +func (acc *Account) accountRead(n int) { // Update Stats acc.statmu.Lock() acc.lpBytes += n @@ -178,7 +180,16 @@ func (acc *Account) read(in io.Reader, p []byte) (n int, err error) { Stats.Bytes(int64(n)) limitBandwidth(n) - return +} + +// read bytes from the io.Reader passed in and account them +func (acc *Account) read(in io.Reader, p []byte) (n int, err error) { + err = acc.checkRead() + if err == nil { + n, err = in.Read(p) + acc.accountRead(n) + } + return n, err } // Read bytes from the object - see io.Reader @@ -188,6 +199,17 @@ func (acc *Account) Read(p []byte) (n int, err error) { return acc.read(acc.in, p) } +// AccountRead account having read n bytes +func (acc *Account) AccountRead(n int) (err error) { + acc.mu.Lock() + defer acc.mu.Unlock() + err = acc.checkRead() + if err == nil { + acc.accountRead(n) + } + return err +} + // Close the object func (acc *Account) Close() error { acc.mu.Lock() @@ -198,6 +220,9 @@ func (acc *Account) Close() error { acc.closed = true close(acc.exit) Stats.inProgress.clear(acc.name) + if acc.close == nil { + return nil + } return acc.close.Close() }