feat(storage/transfermanager): checksum full object downloads (#10569) · googleapis/google-cloud-go@c366c90
@@ -18,6 +18,8 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"hash"
+	"hash/crc32"
 	"io"
 	"io/fs"
 	"math"
@@ -31,6 +33,12 @@ import (
 	"google.golang.org/api/iterator"
 )
 
+// maxChecksumZeroArraySize is the maximum amount of memory to allocate for
+// updating the checksum. A larger size will occupy more memory but will require
+// fewer updates when computing the crc32c of a full object.
+// TODO: test the performance of smaller values for this.
+const maxChecksumZeroArraySize = 4 * 1024 * 1024
+
 // Downloader manages a set of parallelized downloads.
 type Downloader struct {
 	client *storage.Client
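The new constant only bounds the scratch buffer that joinCRC32C (added at the bottom of this diff) feeds to crc32.Update when zero-extending a checksum: chunking the zero padding never changes the resulting CRC32C, only how much memory is held and how many update calls are made. A tiny standalone check of that property, not part of the commit (sizes are made up):

package main

import (
	"fmt"
	"hash/crc32"
)

func main() {
	table := crc32.MakeTable(crc32.Castagnoli)
	start := crc32.Checksum([]byte("first shard"), table)

	const total = 10 * 1024 * 1024 // 10 MiB of zero padding

	// One big update.
	oneShot := crc32.Update(start, table, make([]byte, total))

	// The same padding applied in 4 MiB chunks (the value of maxChecksumZeroArraySize).
	chunked := start
	buf := make([]byte, 4*1024*1024)
	for padded := 0; padded < total; {
		n := len(buf)
		if total-padded < n {
			n = total - padded
		}
		chunked = crc32.Update(chunked, table, buf[:n])
		padded += n
	}

	fmt.Println(oneShot == chunked) // prints true
}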
@@ -288,7 +296,7 @@ func (d *Downloader) addNewInputs(inputs []DownloadObjectInput) {
 }
 
 func (d *Downloader) addResult(input *DownloadObjectInput, result *DownloadOutput) {
-	copiedResult := *result // make a copy so that callbacks do not affect the result
+	copiedResult := *result // make a copy so that callbacks do not affect the result
 
 	if input.directory {
 		f := input.Destination.(*os.File)
@@ -305,7 +313,6 @@ func (d *Downloader) addResult(input *DownloadObjectInput, result *DownloadOutpu
 			input.directoryObjectOutputs <- copiedResult
 		}
 	}
-	// TODO: check checksum if full object
 
 	if d.config.asynchronous || input.directory {
 		input.Callback(result)
@@ -337,27 +344,10 @@ func (d *Downloader) downloadWorker() {
 			break // no more work; exit
 		}
 
-		out := input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize)
-
 		if input.shard == 0 {
-			if out.Err != nil {
-				// Don't queue more shards if the first failed.
-				d.addResult(input, out)
-			} else {
-				numShards := numShards(out.Attrs, input.Range, d.config.partSize)
-
-				if numShards <= 1 {
-					// Download completed with a single shard.
-					d.addResult(input, out)
-				} else {
-					// Queue more shards.
-					outs := d.queueShards(input, out.Attrs.Generation, numShards)
-					// Start a goroutine that gathers shards sent to the output
-					// channel and adds the result once it has received all shards.
-					go d.gatherShards(input, outs, numShards)
-				}
-			}
+			d.startDownload(input)
 		} else {
+			out := input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize)
 			// If this isn't the first shard, send to the output channel specific to the object.
 			// This should never block since the channel is buffered to exactly the number of shards.
 			input.shardOutputs <- out
@@ -366,6 +356,47 @@ func (d *Downloader) downloadWorker() {
 	d.workers.Done()
 }
 
+// startDownload downloads the first shard and schedules subsequent shards
+// if necessary.
+func (d *Downloader) startDownload(input *DownloadObjectInput) {
+	var out *DownloadOutput
+
+	// Full object read. Request the full object and only read partSize bytes
+	// (or the full object, if smaller than partSize), so that we can avoid a
+	// metadata call to grab the CRC32C for JSON downloads.
+	if fullObjectRead(input.Range) {
+		input.checkCRC = true
+		out = input.downloadFirstShard(d.client, d.config.perOperationTimeout, d.config.partSize)
+	} else {
+		out = input.downloadShard(d.client, d.config.perOperationTimeout, d.config.partSize)
+	}
+
+	if out.Err != nil {
+		// Don't queue more shards if the first failed.
+		d.addResult(input, out)
+		return
+	}
+
+	numShards := numShards(out.Attrs, input.Range, d.config.partSize)
+	input.checkCRC = input.checkCRC && !out.Attrs.Decompressed // do not checksum if the object was decompressed
+
+	if numShards > 1 {
+		outs := d.queueShards(input, out.Attrs.Generation, numShards)
+		// Start a goroutine that gathers shards sent to the output
+		// channel and adds the result once it has received all shards.
+		go d.gatherShards(input, out, outs, numShards, out.crc32c)
+
+	} else {
+		// Download completed with a single shard.
+		if input.checkCRC {
+			if err := checksumObject(out.crc32c, out.Attrs.CRC32C); err != nil {
+				out.Err = err
+			}
+		}
+		d.addResult(input, out)
+	}
+}
+
 // queueShards queues all subsequent shards of an object after the first.
 // The results should be forwarded to the returned channel.
 func (d *Downloader) queueShards(in *DownloadObjectInput, gen int64, shards int) <-chan *DownloadOutput {
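With startDownload in place, any request that reads the whole object (no Range, or an offset-0, negative-length range) is checksummed transparently: the worker computes a CRC32C over the bytes it writes and compares it to the object's CRC32C, and a mismatch is reported through DownloadOutput.Err. A minimal caller-side sketch follows; it is not part of the commit, and the constructor, option, and buffer helpers (NewDownloader, WithWorkers, NewDownloadBuffer, WaitAndClose) are written from memory of the transfermanager surface and should be checked against the package docs:

package main

import (
	"context"
	"log"

	"cloud.google.com/go/storage"
	"cloud.google.com/go/storage/transfermanager"
)

func main() {
	ctx := context.Background()
	client, err := storage.NewClient(ctx)
	if err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	d, err := transfermanager.NewDownloader(client, transfermanager.WithWorkers(4))
	if err != nil {
		log.Fatal(err)
	}

	// Destination must be an io.WriterAt; DownloadBuffer wraps a []byte that
	// should be large enough to hold the object (hypothetical size here).
	buf := transfermanager.NewDownloadBuffer(make([]byte, 32<<20))

	// No Range is set, so this is a full object read and the combined CRC32C
	// is verified against the object's checksum.
	if err := d.DownloadObject(ctx, &transfermanager.DownloadObjectInput{
		Bucket:      "my-bucket",
		Object:      "my-object",
		Destination: buf,
	}); err != nil {
		log.Fatal(err)
	}

	outs, err := d.WaitAndClose()
	if err != nil {
		log.Fatal(err)
	}
	for _, out := range outs {
		if out.Err != nil {
			log.Printf("%s: %v", out.Object, out.Err) // includes any CRC mismatch
		}
	}
}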
@@ -397,12 +428,12 @@ var errCancelAllShards = errors.New("cancelled because another shard failed")
 // It will add the result to the Downloader once it has received all shards.
 // gatherShards cancels remaining shards if any shard errored.
 // It does not do any checking to verify that shards are for the same object.
-func (d *Downloader) gatherShards(in *DownloadObjectInput, outs <-chan *DownloadOutput, shards int) {
+func (d *Downloader) gatherShards(in *DownloadObjectInput, out *DownloadOutput, outs <-chan *DownloadOutput, shards int, firstPieceCRC uint32) {
 	errs := []error{}
-	var shardOut *DownloadOutput
+	orderedChecksums := make([]crc32cPiece, shards-1)
+
 	for i := 1; i < shards; i++ {
-		// Add monitoring here? This could hang if any individual piece does.
-		shardOut = <-outs
+		shardOut := <-outs
 
 		// We can ignore errors that resulted from a previous error.
 		// Note that we may still get some cancel errors if they
@@ -412,20 +443,30 @@ func (d *Downloader) gatherShards(in *DownloadObjectInput, outs <-chan *Download
 			errs = append(errs, shardOut.Err)
 			in.cancelCtx(errCancelAllShards)
 		}
+
+		orderedChecksums[shardOut.shard-1] = crc32cPiece{sum: shardOut.crc32c, length: shardOut.shardLength}
+	}
+
+	// All pieces gathered.
+	if len(errs) == 0 && in.checkCRC && out.Attrs != nil {
+		fullCrc := joinCRC32C(firstPieceCRC, orderedChecksums)
+		if err := checksumObject(fullCrc, out.Attrs.CRC32C); err != nil {
+			errs = append(errs, err)
+		}
 	}
 
-	// All pieces gathered; return output. Any shard output will do.
-	shardOut.Range = in.Range
+	// Prepare output.
+	out.Range = in.Range
 	if len(errs) != 0 {
-		shardOut.Err = fmt.Errorf("download shard errors:\n%w", errors.Join(errs...))
+		out.Err = fmt.Errorf("download shard errors:\n%w", errors.Join(errs...))
 	}
-	if shardOut.Attrs != nil {
-		shardOut.Attrs.StartOffset = 0
+	if out.Attrs != nil {
+		out.Attrs.StartOffset = 0
 		if in.Range != nil {
-			shardOut.Attrs.StartOffset = in.Range.Offset
+			out.Attrs.StartOffset = in.Range.Offset
 		}
 	}
-	d.addResult(in, shardOut)
+	d.addResult(in, out)
 }
 
 // gatherObjectOutputs receives from the given channel exactly numObjects times.
@@ -563,45 +604,18 @@ type DownloadObjectInput struct {
 	shardOutputs           chan<- *DownloadOutput
 	directory              bool // input was queued by calling DownloadDirectory
 	directoryObjectOutputs chan<- DownloadOutput
+	checkCRC               bool
 }
 
 // downloadShard will read a specific object piece into in.Destination.
 // If timeout is less than 0, no timeout is set.
 func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout time.Duration, partSize int64) (out *DownloadOutput) {
 	out = &DownloadOutput{Bucket: in.Bucket, Object: in.Object, Range: in.Range}
 
-	// Set timeout.
-	ctx := in.ctx
-	if timeout > 0 {
-		c, cancel := context.WithTimeout(ctx, timeout)
-		defer cancel()
-		ctx = c
-	}
-
-	// The first shard will be sent as download many, since we do not know yet
-	// if it will be sharded.
-	method := downloadMany
-	if in.shard != 0 {
-		method = downloadSharded
-	}
-	ctx = setUsageMetricHeader(ctx, method)
-
-	// Set options on the object.
-	o := client.Bucket(in.Bucket).Object(in.Object)
-
-	if in.Conditions != nil {
-		o = o.If(*in.Conditions)
-	}
-	if in.Generation != nil {
-		o = o.Generation(*in.Generation)
-	}
-	if len(in.EncryptionKey) > 0 {
-		o = o.Key(in.EncryptionKey)
-	}
-
 	objRange := shardRange(in.Range, partSize, in.shard)
+	ctx, cancel := in.setOptionsOnContext(timeout)
+	defer cancel()
+	o := in.setOptionsOnObject(client)
 
-	// Read.
 	r, err := o.NewRangeReader(ctx, objRange.Offset, objRange.Length)
 	if err != nil {
 		out.Err = err
@@ -618,9 +632,63 @@ func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout tim
 		}
 	}
 
-	w := io.NewOffsetWriter(in.Destination, offset)
-	_, err = io.Copy(w, r)
+	var w io.Writer
+	w = io.NewOffsetWriter(in.Destination, offset)
+
+	var crcHash hash.Hash32
+	if in.checkCRC {
+		crcHash = crc32.New(crc32.MakeTable(crc32.Castagnoli))
+		w = io.MultiWriter(w, crcHash)
+	}
+
+	n, err := io.Copy(w, r)
+	if err != nil {
+		out.Err = err
+		r.Close()
+		return
+	}
+
+	if err = r.Close(); err != nil {
+		out.Err = err
+		return
+	}
+
+	out.Attrs = &r.Attrs
+	out.shard = in.shard
+	out.shardLength = n
+	if in.checkCRC {
+		out.crc32c = crcHash.Sum32()
+	}
+	return
+}
+
+// downloadFirstShard will read the first object piece into in.Destination.
+// If timeout is less than 0, no timeout is set.
+func (in *DownloadObjectInput) downloadFirstShard(client *storage.Client, timeout time.Duration, partSize int64) (out *DownloadOutput) {
+	out = &DownloadOutput{Bucket: in.Bucket, Object: in.Object, Range: in.Range}
+
+	ctx, cancel := in.setOptionsOnContext(timeout)
+	defer cancel()
+	o := in.setOptionsOnObject(client)
+
+	r, err := o.NewReader(ctx)
 	if err != nil {
+		out.Err = err
+		return
+	}
+
+	var w io.Writer
+	w = io.NewOffsetWriter(in.Destination, 0)
+
+	var crcHash hash.Hash32
+	if in.checkCRC {
+		crcHash = crc32.New(crc32.MakeTable(crc32.Castagnoli))
+		w = io.MultiWriter(w, crcHash)
+	}
+
+	// Copy only the first partSize bytes before closing the reader.
+	// If we encounter an EOF, the file was smaller than partSize.
+	n, err := io.CopyN(w, r, partSize)
+	if err != nil && err != io.EOF {
 		out.Err = err
 		r.Close()
 		return
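downloadFirstShard and the reworked downloadShard both lean on the same standard-library pattern: tee the response body through a hash.Hash32 with io.MultiWriter while copying, and (for the first shard) treat io.EOF from io.CopyN as "the object was smaller than partSize". A self-contained sketch of that pattern, not part of the commit (the data and part size are made up):

package main

import (
	"bytes"
	"fmt"
	"hash/crc32"
	"io"
	"strings"
)

func main() {
	const partSize = 1024 // assumed shard size for the sketch
	src := strings.NewReader("object bytes that fit in a single shard")

	var dst bytes.Buffer
	crcHash := crc32.New(crc32.MakeTable(crc32.Castagnoli))
	w := io.MultiWriter(&dst, crcHash) // every byte written to dst also updates the CRC

	n, err := io.CopyN(w, src, partSize)
	if err != nil && err != io.EOF { // io.EOF: source was shorter than partSize
		panic(err)
	}

	fmt.Printf("copied %d bytes, crc32c=%d\n", n, crcHash.Sum32())
}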
@@ -632,9 +700,45 @@ func (in *DownloadObjectInput) downloadShard(client *storage.Client, timeout tim
 	}
 
 	out.Attrs = &r.Attrs
+	out.shard = in.shard
+	out.shardLength = n
+	if in.checkCRC {
+		out.crc32c = crcHash.Sum32()
+	}
 	return
 }
 
+func (in *DownloadObjectInput) setOptionsOnContext(timeout time.Duration) (context.Context, context.CancelFunc) {
+	ctx := in.ctx
+	cancel := func() {}
+	if timeout > 0 {
+		ctx, cancel = context.WithTimeout(ctx, timeout)
+	}
+
+	// The first shard will be sent as download many, since we do not know yet
+	// if it will be sharded.
+	method := downloadMany
+	if in.shard != 0 {
+		method = downloadSharded
+	}
+	return setUsageMetricHeader(ctx, method), cancel
+}
+
+func (in *DownloadObjectInput) setOptionsOnObject(client *storage.Client) *storage.ObjectHandle {
+	o := client.Bucket(in.Bucket).Object(in.Object)
+	if in.Conditions != nil {
+		o = o.If(*in.Conditions)
+	}
+	if in.Generation != nil {
+		o = o.Generation(*in.Generation)
+	}
+	if len(in.EncryptionKey) > 0 {
+		o = o.Key(in.EncryptionKey)
+	}
+	return o
+}
+
 // DownloadDirectoryInput is the input for a directory to download.
 type DownloadDirectoryInput struct {
 	// Bucket is the bucket in GCS to download from. Required.
@@ -686,6 +790,10 @@ type DownloadOutput struct {
 	Range  *DownloadRange             // requested range, if it was specified
 	Err    error                      // error occurring during download
 	Attrs  *storage.ReaderObjectAttrs // attributes of downloaded object, if successful
+
+	shard       int
+	shardLength int64
+	crc32c      uint32
 }
 
 // TODO: use built-in after go < 1.21 is dropped.
@@ -784,3 +892,48 @@ func setUsageMetricHeader(ctx context.Context, method string) context.Context {
 	header := fmt.Sprintf("%s/%s", usageMetricKey, method)
 	return callctx.SetHeaders(ctx, xGoogHeaderKey, header)
 }
+
+type crc32cPiece struct {
+	sum    uint32 // crc32c checksum of the piece
+	length int64  // number of bytes in this piece
+}
+
+// joinCRC32C pieces together the initial checksum with the orderedChecksums
+// provided to calculate the checksum of the whole.
+func joinCRC32C(initialChecksum uint32, orderedChecksums []crc32cPiece) uint32 {
+	base := initialChecksum
+
+	zeroes := make([]byte, maxChecksumZeroArraySize)
+	for _, part := range orderedChecksums {
+		// Precondition Base (flip every bit)
+		base ^= 0xFFFFFFFF
+
+		// Zero pad base crc32c. To conserve memory, do so with only maxChecksumZeroArraySize
+		// at a time. Reuse the zeroes array where possible.
+		var padded int64 = 0
+		for padded < part.length {
+			desiredZeroes := min(part.length-padded, maxChecksumZeroArraySize)
+			base = crc32.Update(base, crc32.MakeTable(crc32.Castagnoli), zeroes[:desiredZeroes])
+			padded += desiredZeroes
+		}
+
+		// Postcondition Base (same as precondition, this switches the bits back)
+		base ^= 0xFFFFFFFF
+
+		// Bitwise XOR between Base and Part to produce a new Base
+		base ^= part.sum
+	}
+	return base
+}
+
+func fullObjectRead(r *DownloadRange) bool {
+	return r == nil || (r.Offset == 0 && r.Length < 0)
+}
+
+func checksumObject(got, want uint32) error {
+	// Only checksum the object if we have a valid CRC32C.
+	if want != 0 && want != got {
+		return fmt.Errorf("bad CRC on read: got %d, want %d", got, want)
+	}
+	return nil
+}
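joinCRC32C combines per-shard checksums by zero-extending the running CRC32C by each shard's length (flipping the bits before and after crc32.Update to undo the standard pre/post inversion) and then XOR-ing in the shard's own checksum. A quick way to sanity-check the combination is a test along these lines; it is not part of the commit and assumes it sits in the same package as joinCRC32C and crc32cPiece:

package transfermanager

import (
	"bytes"
	"hash/crc32"
	"testing"
)

func TestJoinCRC32CSketch(t *testing.T) {
	table := crc32.MakeTable(crc32.Castagnoli)
	data := bytes.Repeat([]byte("checksum full object downloads "), 4096)
	want := crc32.Checksum(data, table)

	// Checksum the first shard on its own, as downloadFirstShard would.
	const shardSize = 1000
	first := crc32.Checksum(data[:shardSize], table)

	// Checksum the remaining shards independently, in order.
	var pieces []crc32cPiece
	for off := shardSize; off < len(data); off += shardSize {
		end := off + shardSize
		if end > len(data) {
			end = len(data)
		}
		pieces = append(pieces, crc32cPiece{
			sum:    crc32.Checksum(data[off:end], table),
			length: int64(end - off),
		})
	}

	// Joining the pieces must reproduce the checksum of the whole object.
	if got := joinCRC32C(first, pieces); got != want {
		t.Errorf("joinCRC32C = %d, want %d", got, want)
	}
}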