From b7dd3ce6085d8959686e236ea3b616ffcdaaf384 Mon Sep 17 00:00:00 2001
From: Nick Craig-Wood <nick@craig-wood.com>
Date: Thu, 30 Jul 2020 10:52:32 +0100
Subject: [PATCH] s3: preserve metadata when doing multipart copy

Before this change the s3 multipart server side copy was not
preserving the metadata of the object. This was most noticeable
because the modtime was not preserved.

This change fetches the metadata from the object before starting the
copy and overwrites it if requires.

It will also mean any other metadata is preserved.

See: https://forum.rclone.org/t/copying-files-within-a-b2-bucket/16680/70
---
 backend/s3/s3.go | 100 +++++++++++++++++++++++++++--------------------
 1 file changed, 58 insertions(+), 42 deletions(-)

diff --git a/backend/s3/s3.go b/backend/s3/s3.go
index 856927fab..430f06ea6 100644
--- a/backend/s3/s3.go
+++ b/backend/s3/s3.go
@@ -1967,7 +1967,7 @@ func pathEscape(s string) string {
 //
 // It adds the boiler plate to the req passed in and calls the s3
 // method
-func (f *Fs) copy(ctx context.Context, req *s3.CopyObjectInput, dstBucket, dstPath, srcBucket, srcPath string, srcSize int64) error {
+func (f *Fs) copy(ctx context.Context, req *s3.CopyObjectInput, dstBucket, dstPath, srcBucket, srcPath string, src *Object) error {
 	req.Bucket = &dstBucket
 	req.ACL = &f.opt.ACL
 	req.Key = &dstPath
@@ -1983,8 +1983,8 @@ func (f *Fs) copy(ctx context.Context, req *s3.CopyObjectInput, dstBucket, dstPa
 		req.StorageClass = &f.opt.StorageClass
 	}
 
-	if srcSize >= int64(f.opt.CopyCutoff) {
-		return f.copyMultipart(ctx, req, dstBucket, dstPath, srcBucket, srcPath, srcSize)
+	if src.bytes >= int64(f.opt.CopyCutoff) {
+		return f.copyMultipart(ctx, req, dstBucket, dstPath, srcBucket, srcPath, src)
 	}
 	return f.pacer.Call(func() (bool, error) {
 		_, err := f.c.CopyObjectWithContext(ctx, req)
@@ -2005,14 +2005,33 @@ func calculateRange(partSize, partIndex, numParts, totalSize int64) string {
 	return fmt.Sprintf("bytes=%v-%v", start, ends)
 }
 
-func (f *Fs) copyMultipart(ctx context.Context, req *s3.CopyObjectInput, dstBucket, dstPath, srcBucket, srcPath string, srcSize int64) (err error) {
+func (f *Fs) copyMultipart(ctx context.Context, copyReq *s3.CopyObjectInput, dstBucket, dstPath, srcBucket, srcPath string, src *Object) (err error) {
+	info, err := src.headObject(ctx)
+	if err != nil {
+		return err
+	}
+
+	req := &s3.CreateMultipartUploadInput{}
+
+	// Fill in the request from the head info
+	structs.SetFrom(req, info)
+
+	// If copy metadata was set then set the Metadata to that read
+	// from the head request
+	if aws.StringValue(copyReq.MetadataDirective) == s3.MetadataDirectiveCopy {
+		copyReq.Metadata = info.Metadata
+	}
+
+	// Overwrite any from the copyReq
+	structs.SetFrom(req, copyReq)
+
+	req.Bucket = &dstBucket
+	req.Key = &dstPath
+
 	var cout *s3.CreateMultipartUploadOutput
 	if err := f.pacer.Call(func() (bool, error) {
 		var err error
-		cout, err = f.c.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{
-			Bucket: &dstBucket,
-			Key:    &dstPath,
-		})
+		cout, err = f.c.CreateMultipartUploadWithContext(ctx, req)
 		return f.shouldRetry(err)
 	}); err != nil {
 		return err
@@ -2021,7 +2040,7 @@ func (f *Fs) copyMultipart(ctx context.Context, req *s3.CopyObjectInput, dstBuck
 
 	defer atexit.OnError(&err, func() {
 		// Try to abort the upload, but ignore the error.
-		fs.Debugf(nil, "Cancelling multipart copy")
+		fs.Debugf(src, "Cancelling multipart copy")
 		_ = f.pacer.Call(func() (bool, error) {
 			_, err := f.c.AbortMultipartUploadWithContext(context.Background(), &s3.AbortMultipartUploadInput{
 				Bucket:       &dstBucket,
@@ -2033,33 +2052,23 @@ func (f *Fs) copyMultipart(ctx context.Context, req *s3.CopyObjectInput, dstBuck
 		})
 	})()
 
+	srcSize := src.bytes
 	partSize := int64(f.opt.CopyCutoff)
 	numParts := (srcSize-1)/partSize + 1
 
+	fs.Debugf(src, "Starting  multipart copy with %d parts", numParts)
+
 	var parts []*s3.CompletedPart
 	for partNum := int64(1); partNum <= numParts; partNum++ {
 		if err := f.pacer.Call(func() (bool, error) {
 			partNum := partNum
-			uploadPartReq := &s3.UploadPartCopyInput{
-				Bucket:          &dstBucket,
-				Key:             &dstPath,
-				PartNumber:      &partNum,
-				UploadId:        uid,
-				CopySourceRange: aws.String(calculateRange(partSize, partNum-1, numParts, srcSize)),
-				// Args copy from req
-				CopySource:                     req.CopySource,
-				CopySourceIfMatch:              req.CopySourceIfMatch,
-				CopySourceIfModifiedSince:      req.CopySourceIfModifiedSince,
-				CopySourceIfNoneMatch:          req.CopySourceIfNoneMatch,
-				CopySourceIfUnmodifiedSince:    req.CopySourceIfUnmodifiedSince,
-				CopySourceSSECustomerAlgorithm: req.CopySourceSSECustomerAlgorithm,
-				CopySourceSSECustomerKey:       req.CopySourceSSECustomerKey,
-				CopySourceSSECustomerKeyMD5:    req.CopySourceSSECustomerKeyMD5,
-				RequestPayer:                   req.RequestPayer,
-				SSECustomerAlgorithm:           req.SSECustomerAlgorithm,
-				SSECustomerKey:                 req.SSECustomerKey,
-				SSECustomerKeyMD5:              req.SSECustomerKeyMD5,
-			}
+			uploadPartReq := &s3.UploadPartCopyInput{}
+			structs.SetFrom(uploadPartReq, copyReq)
+			uploadPartReq.Bucket = &dstBucket
+			uploadPartReq.Key = &dstPath
+			uploadPartReq.PartNumber = &partNum
+			uploadPartReq.UploadId = uid
+			uploadPartReq.CopySourceRange = aws.String(calculateRange(partSize, partNum-1, numParts, srcSize))
 			uout, err := f.c.UploadPartCopyWithContext(ctx, uploadPartReq)
 			if err != nil {
 				return f.shouldRetry(err)
@@ -2112,7 +2121,7 @@ func (f *Fs) Copy(ctx context.Context, src fs.Object, remote string) (fs.Object,
 	req := s3.CopyObjectInput{
 		MetadataDirective: aws.String(s3.MetadataDirectiveCopy),
 	}
-	err = f.copy(ctx, &req, dstBucket, dstPath, srcBucket, srcPath, srcObj.Size())
+	err = f.copy(ctx, &req, dstBucket, dstPath, srcBucket, srcPath, srcObj)
 	if err != nil {
 		return nil, err
 	}
@@ -2509,19 +2518,12 @@ func (o *Object) Size() int64 {
 	return o.bytes
 }
 
-// readMetaData gets the metadata if it hasn't already been fetched
-//
-// it also sets the info
-func (o *Object) readMetaData(ctx context.Context) (err error) {
-	if o.meta != nil {
-		return nil
-	}
+func (o *Object) headObject(ctx context.Context) (resp *s3.HeadObjectOutput, err error) {
 	bucket, bucketPath := o.split()
 	req := s3.HeadObjectInput{
 		Bucket: &bucket,
 		Key:    &bucketPath,
 	}
-	var resp *s3.HeadObjectOutput
 	err = o.fs.pacer.Call(func() (bool, error) {
 		var err error
 		resp, err = o.fs.c.HeadObjectWithContext(ctx, &req)
@@ -2530,12 +2532,26 @@ func (o *Object) readMetaData(ctx context.Context) (err error) {
 	if err != nil {
 		if awsErr, ok := err.(awserr.RequestFailure); ok {
 			if awsErr.StatusCode() == http.StatusNotFound {
-				return fs.ErrorObjectNotFound
+				return nil, fs.ErrorObjectNotFound
 			}
 		}
-		return err
+		return nil, err
 	}
 	o.fs.cache.MarkOK(bucket)
+	return resp, nil
+}
+
+// readMetaData gets the metadata if it hasn't already been fetched
+//
+// it also sets the info
+func (o *Object) readMetaData(ctx context.Context) (err error) {
+	if o.meta != nil {
+		return nil
+	}
+	resp, err := o.headObject(ctx)
+	if err != nil {
+		return err
+	}
 	var size int64
 	// Ignore missing Content-Length assuming it is 0
 	// Some versions of ceph do this due their apache proxies
@@ -2606,7 +2622,7 @@ func (o *Object) SetModTime(ctx context.Context, modTime time.Time) error {
 		Metadata:          o.meta,
 		MetadataDirective: aws.String(s3.MetadataDirectiveReplace), // replace metadata with that passed in
 	}
-	return o.fs.copy(ctx, &req, bucket, bucketPath, bucket, bucketPath, o.bytes)
+	return o.fs.copy(ctx, &req, bucket, bucketPath, bucket, bucketPath, o)
 }
 
 // Storable raturns a boolean indicating if this object is storable
@@ -3046,7 +3062,7 @@ func (o *Object) SetTier(tier string) (err error) {
 		MetadataDirective: aws.String(s3.MetadataDirectiveCopy),
 		StorageClass:      aws.String(tier),
 	}
-	err = o.fs.copy(ctx, &req, bucket, bucketPath, bucket, bucketPath, o.bytes)
+	err = o.fs.copy(ctx, &req, bucket, bucketPath, bucket, bucketPath, o)
 	if err != nil {
 		return err
 	}