lib/pool: add DelayAccounting() to fix accounting when reading hashes

This commit is contained in:
Nick Craig-Wood 2023-08-24 16:42:09 +01:00
parent f4b1a51af6
commit bc986b44b2
2 changed files with 219 additions and 76 deletions

View File

@ -19,6 +19,8 @@ type RW struct {
out int // offset we are reading from
lastOffset int // size in last page
account RWAccount // account for a read
reads int // count how many times the data has been read
accountOn int // only account on or after this read
}
var (
@ -50,10 +52,40 @@ func (rw *RW) SetAccounting(account RWAccount) *RW {
return rw
}
// DelayAccountinger enables an accounting delay
type DelayAccountinger interface {
// DelayAccounting makes sure the accounting function only
// gets called on the i-th or later read of the data from this
// point (counting from 1).
//
// This is useful so that we don't account initial reads of
// the data e.g. when calculating hashes.
//
// Set this to 0 to account everything.
DelayAccounting(i int)
}
// DelayAccounting makes sure the accounting function only gets called
// on the i-th or later read of the data from this point (counting
// from 1).
//
// This is useful so that we don't account initial reads of the data
// e.g. when calculating hashes.
//
// Set this to 0 to account everything.
func (rw *RW) DelayAccounting(i int) {
rw.accountOn = i
rw.reads = 0
}
// Returns the page and offset of i for reading.
//
// Ensure there are pages before calling this.
func (rw *RW) readPage(i int) (page []byte) {
// Count a read of the data if we read the first page
if i == 0 {
rw.reads++
}
pageNumber := i / rw.pool.bufferSize
offset := i % rw.pool.bufferSize
page = rw.pages[pageNumber]
@ -69,7 +101,14 @@ func (rw *RW) accountRead(n int) error {
if rw.account == nil {
return nil
}
return rw.account(n)
// Don't start accounting until we've reached this many reads
//
// rw.reads will be 1 the first time this is called
// rw.accountOn 2 means start accounting on the 2nd read through
if rw.reads >= rw.accountOn {
return rw.account(n)
}
return nil
}
// Read reads up to len(p) bytes into p. It returns the number of
@ -227,10 +266,11 @@ func (rw *RW) Size() int64 {
// Check interfaces
var (
_ io.Reader = (*RW)(nil)
_ io.ReaderFrom = (*RW)(nil)
_ io.Writer = (*RW)(nil)
_ io.WriterTo = (*RW)(nil)
_ io.Seeker = (*RW)(nil)
_ io.Closer = (*RW)(nil)
_ io.Reader = (*RW)(nil)
_ io.ReaderFrom = (*RW)(nil)
_ io.Writer = (*RW)(nil)
_ io.WriterTo = (*RW)(nil)
_ io.Seeker = (*RW)(nil)
_ io.Closer = (*RW)(nil)
_ DelayAccountinger = (*RW)(nil)
)

View File

@ -9,6 +9,7 @@ import (
"github.com/rclone/rclone/lib/random"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const blockSize = 4096
@ -178,71 +179,164 @@ func TestRW(t *testing.T) {
assert.Equal(t, testData[7:10], dst)
})
errBoom := errors.New("accounting error")
t.Run("Account", func(t *testing.T) {
errBoom := errors.New("accounting error")
t.Run("AccountRead", func(t *testing.T) {
// Test accounting errors
rw := newRW()
defer close(rw)
t.Run("Read", func(t *testing.T) {
rw := newRW()
defer close(rw)
var total int
rw.SetAccounting(func(n int) error {
total += n
return nil
var total int
rw.SetAccounting(func(n int) error {
total += n
return nil
})
dst = make([]byte, 3)
n, err = rw.Read(dst)
assert.Equal(t, 3, n)
assert.NoError(t, err)
assert.Equal(t, 3, total)
})
dst = make([]byte, 3)
n, err = rw.Read(dst)
assert.Equal(t, 3, n)
assert.NoError(t, err)
assert.Equal(t, 3, total)
})
t.Run("WriteTo", func(t *testing.T) {
rw := newRW()
defer close(rw)
var b bytes.Buffer
t.Run("AccountWriteTo", func(t *testing.T) {
rw := newRW()
defer close(rw)
var b bytes.Buffer
var total int
rw.SetAccounting(func(n int) error {
total += n
return nil
})
var total int
rw.SetAccounting(func(n int) error {
total += n
return nil
n, err := rw.WriteTo(&b)
assert.NoError(t, err)
assert.Equal(t, 10, total)
assert.Equal(t, int64(10), n)
assert.Equal(t, testData, b.Bytes())
})
n, err := rw.WriteTo(&b)
assert.NoError(t, err)
assert.Equal(t, 10, total)
assert.Equal(t, int64(10), n)
assert.Equal(t, testData, b.Bytes())
})
t.Run("ReadDelay", func(t *testing.T) {
rw := newRW()
defer close(rw)
t.Run("AccountReadError", func(t *testing.T) {
// Test accounting errors
rw := newRW()
defer close(rw)
var total int
rw.SetAccounting(func(n int) error {
total += n
return nil
})
rw.SetAccounting(func(n int) error {
return errBoom
rewind := func() {
_, err := rw.Seek(0, io.SeekStart)
require.NoError(t, err)
}
rw.DelayAccounting(3)
dst = make([]byte, 16)
n, err = rw.Read(dst)
assert.Equal(t, 10, n)
assert.Equal(t, io.EOF, err)
assert.Equal(t, 0, total)
rewind()
n, err = rw.Read(dst)
assert.Equal(t, 10, n)
assert.Equal(t, io.EOF, err)
assert.Equal(t, 0, total)
rewind()
n, err = rw.Read(dst)
assert.Equal(t, 10, n)
assert.Equal(t, io.EOF, err)
assert.Equal(t, 10, total)
rewind()
n, err = rw.Read(dst)
assert.Equal(t, 10, n)
assert.Equal(t, io.EOF, err)
assert.Equal(t, 20, total)
rewind()
})
dst = make([]byte, 3)
n, err = rw.Read(dst)
assert.Equal(t, 3, n)
assert.Equal(t, errBoom, err)
})
t.Run("WriteToDelay", func(t *testing.T) {
rw := newRW()
defer close(rw)
var b bytes.Buffer
t.Run("AccountWriteToError", func(t *testing.T) {
rw := newRW()
defer close(rw)
rw.SetAccounting(func(n int) error {
return errBoom
var total int
rw.SetAccounting(func(n int) error {
total += n
return nil
})
rw.DelayAccounting(3)
rewind := func() {
_, err := rw.Seek(0, io.SeekStart)
require.NoError(t, err)
b.Reset()
}
n, err := rw.WriteTo(&b)
assert.NoError(t, err)
assert.Equal(t, 0, total)
assert.Equal(t, int64(10), n)
assert.Equal(t, testData, b.Bytes())
rewind()
n, err = rw.WriteTo(&b)
assert.NoError(t, err)
assert.Equal(t, 0, total)
assert.Equal(t, int64(10), n)
assert.Equal(t, testData, b.Bytes())
rewind()
n, err = rw.WriteTo(&b)
assert.NoError(t, err)
assert.Equal(t, 10, total)
assert.Equal(t, int64(10), n)
assert.Equal(t, testData, b.Bytes())
rewind()
n, err = rw.WriteTo(&b)
assert.NoError(t, err)
assert.Equal(t, 20, total)
assert.Equal(t, int64(10), n)
assert.Equal(t, testData, b.Bytes())
rewind()
})
var b bytes.Buffer
n, err := rw.WriteTo(&b)
assert.Equal(t, errBoom, err)
assert.Equal(t, int64(10), n)
assert.Equal(t, testData, b.Bytes())
t.Run("ReadError", func(t *testing.T) {
// Test accounting errors
rw := newRW()
defer close(rw)
rw.SetAccounting(func(n int) error {
return errBoom
})
dst = make([]byte, 3)
n, err = rw.Read(dst)
assert.Equal(t, 3, n)
assert.Equal(t, errBoom, err)
})
t.Run("WriteToError", func(t *testing.T) {
rw := newRW()
defer close(rw)
rw.SetAccounting(func(n int) error {
return errBoom
})
var b bytes.Buffer
n, err := rw.WriteTo(&b)
assert.Equal(t, errBoom, err)
assert.Equal(t, int64(10), n)
assert.Equal(t, testData, b.Bytes())
})
})
}
@ -363,26 +457,35 @@ func TestRWBoundaryConditions(t *testing.T) {
assert.Equal(t, int64(len(data)), nn)
}
type test struct {
name string
fn func(*RW, []byte, int)
}
// Read and Write the data with a range of block sizes and functions
for _, writeFn := range []func(*RW, []byte, int){write, readFrom} {
for _, readFn := range []func(*RW, []byte, int){read, writeTo} {
for _, size := range sizes {
data := buf[:size]
for _, chunkSize := range sizes {
//t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize)
rw := NewRW(rwPool)
assert.Equal(t, int64(0), rw.Size())
accounted = 0
rw.SetAccounting(account)
assert.Equal(t, 0, accounted)
writeFn(rw, data, chunkSize)
assert.Equal(t, int64(size), rw.Size())
assert.Equal(t, 0, accounted)
readFn(rw, data, chunkSize)
assert.NoError(t, rw.Close())
assert.Equal(t, size, accounted)
}
for _, write := range []test{{"Write", write}, {"ReadFrom", readFrom}} {
t.Run(write.name, func(t *testing.T) {
for _, read := range []test{{"Read", read}, {"WriteTo", writeTo}} {
t.Run(read.name, func(t *testing.T) {
for _, size := range sizes {
data := buf[:size]
for _, chunkSize := range sizes {
//t.Logf("Testing size=%d chunkSize=%d", useWrite, size, chunkSize)
rw := NewRW(rwPool)
assert.Equal(t, int64(0), rw.Size())
accounted = 0
rw.SetAccounting(account)
assert.Equal(t, 0, accounted)
write.fn(rw, data, chunkSize)
assert.Equal(t, int64(size), rw.Size())
assert.Equal(t, 0, accounted)
read.fn(rw, data, chunkSize)
assert.NoError(t, rw.Close())
assert.Equal(t, size, accounted)
}
}
})
}
}
})
}
}