diff --git a/backend/sftp/sftp.go b/backend/sftp/sftp.go index 0dee508c6..5e6898949 100644 --- a/backend/sftp/sftp.go +++ b/backend/sftp/sftp.go @@ -16,6 +16,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/pkg/errors" @@ -286,6 +287,7 @@ type Fs struct { drain *time.Timer // used to drain the pool when we stop using the connections pacer *fs.Pacer // pacer for operations savedpswd string + transfers int32 // count in use references } // Object is a remote SFTP file that has been stat'd (so it exists, but is not necessarily open for reading) @@ -348,6 +350,23 @@ func (c *conn) closed() error { return nil } +// Show that we are doing an upload or download +// +// Call removeTransfer() when done +func (f *Fs) addTransfer() { + atomic.AddInt32(&f.transfers, 1) +} + +// Show the upload or download done +func (f *Fs) removeTransfer() { + atomic.AddInt32(&f.transfers, -1) +} + +// getTransfers shows whether there are any transfers in progress +func (f *Fs) getTransfers() int32 { + return atomic.LoadInt32(&f.transfers) +} + // Open a new connection to the SFTP server. func (f *Fs) sftpConnection(ctx context.Context) (c *conn, err error) { // Rate limit rate of new connections @@ -478,6 +497,13 @@ func (f *Fs) putSftpConnection(pc **conn, err error) { func (f *Fs) drainPool(ctx context.Context) (err error) { f.poolMu.Lock() defer f.poolMu.Unlock() + if transfers := f.getTransfers(); transfers != 0 { + fs.Debugf(f, "Not closing %d unused connections as %d transfers in progress", len(f.pool), transfers) + if f.opt.IdleTimeout > 0 { + f.drain.Reset(time.Duration(f.opt.IdleTimeout)) // nudge on the pool emptying timer + } + return nil + } if f.opt.IdleTimeout > 0 { f.drain.Stop() } @@ -1384,18 +1410,22 @@ func (o *Object) Storable() bool { // objectReader represents a file open for reading on the SFTP server type objectReader struct { + f *Fs sftpFile *sftp.File pipeReader *io.PipeReader done chan struct{} } -func newObjectReader(sftpFile *sftp.File) *objectReader { +func (f *Fs) newObjectReader(sftpFile *sftp.File) *objectReader { pipeReader, pipeWriter := io.Pipe() file := &objectReader{ + f: f, sftpFile: sftpFile, pipeReader: pipeReader, done: make(chan struct{}), } + // Show connection in use + f.addTransfer() go func() { // Use sftpFile.WriteTo to pump data so that it gets a @@ -1425,6 +1455,8 @@ func (file *objectReader) Close() (err error) { _ = file.pipeReader.Close() // Wait for the background process to finish <-file.done + // Show connection no longer in use + file.f.removeTransfer() return err } @@ -1458,12 +1490,14 @@ func (o *Object) Open(ctx context.Context, options ...fs.OpenOption) (in io.Read return nil, errors.Wrap(err, "Open Seek failed") } } - in = readers.NewLimitedReadCloser(newObjectReader(sftpFile), limit) + in = readers.NewLimitedReadCloser(o.fs.newObjectReader(sftpFile), limit) return in, nil } // Update a remote sftp file using the data and ModTime from func (o *Object) Update(ctx context.Context, in io.Reader, src fs.ObjectInfo, options ...fs.OpenOption) error { + o.fs.addTransfer() // Show transfer in progress + defer o.fs.removeTransfer() // Clear the hash cache since we are about to update the object o.md5sum = nil o.sha1sum = nil