about summary refs log tree commit diff
path: root/tvix/nar-bridge
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nar-bridge')
-rw-r--r--tvix/nar-bridge/pkg/hashers/hashers.go66
-rw-r--r--tvix/nar-bridge/pkg/importer/counting_writer.go21
-rw-r--r--tvix/nar-bridge/pkg/importer/importer.go30
3 files changed, 37 insertions, 80 deletions
diff --git a/tvix/nar-bridge/pkg/hashers/hashers.go b/tvix/nar-bridge/pkg/hashers/hashers.go
deleted file mode 100644
index 0c9e611799e3..000000000000
--- a/tvix/nar-bridge/pkg/hashers/hashers.go
+++ /dev/null
@@ -1,66 +0,0 @@
-package hashers
-
-import (
-	"errors"
-	"fmt"
-	"hash"
-	"io"
-)
-
-var _ io.Reader = &Hasher{}
-
-// Hasher wraps io.Reader.
-// You can ask it for the digest of the hash function used internally, and the
-// number of bytes written.
-type Hasher struct {
-	r         io.Reader
-	h         hash.Hash
-	bytesRead uint32
-}
-
-func NewHasher(r io.Reader, h hash.Hash) *Hasher {
-	return &Hasher{
-		r:         r,
-		h:         h,
-		bytesRead: 0,
-	}
-}
-
-func (h *Hasher) Read(p []byte) (int, error) {
-	nRead, rdErr := h.r.Read(p)
-
-	// write the number of bytes read from the reader to the hash.
-	// We need to do this independently on whether there's been error.
-	// n always describes the number of successfully written bytes.
-	nHash, hashErr := h.h.Write(p[0:nRead])
-	if hashErr != nil {
-		return nRead, fmt.Errorf("unable to write to hash: %w", hashErr)
-	}
-
-	// We assume here the hash function accepts the whole p in one Go,
-	// and doesn't early-return on the Write.
-	// We compare it with nRead and bail out if that was not the case.
-	if nHash != nRead {
-		return nRead, fmt.Errorf("hash didn't accept the full write")
-	}
-
-	// update bytesWritten
-	h.bytesRead += uint32(nRead)
-
-	if rdErr != nil {
-		if errors.Is(rdErr, io.EOF) {
-			return nRead, rdErr
-		}
-		return nRead, fmt.Errorf("error from underlying reader: %w", rdErr)
-	}
-
-	return nRead, hashErr
-}
-
-func (h *Hasher) BytesWritten() uint32 {
-	return h.bytesRead
-}
-
-func (h *Hasher) Sum(b []byte) []byte {
-	return h.h.Sum(b)
-}
diff --git a/tvix/nar-bridge/pkg/importer/counting_writer.go b/tvix/nar-bridge/pkg/importer/counting_writer.go
new file mode 100644
index 000000000000..d003a4b11bfd
--- /dev/null
+++ b/tvix/nar-bridge/pkg/importer/counting_writer.go
@@ -0,0 +1,21 @@
+package importer
+
+import (
+	"io"
+)
+
+// CountingWriter implements io.Writer.
+var _ io.Writer = &CountingWriter{}
+
+type CountingWriter struct {
+	bytesWritten uint64
+}
+
+func (cw *CountingWriter) Write(p []byte) (n int, err error) {
+	cw.bytesWritten += uint64(len(p))
+	return len(p), nil
+}
+
+func (cw *CountingWriter) BytesWritten() uint64 {
+	return cw.bytesWritten
+}
diff --git a/tvix/nar-bridge/pkg/importer/importer.go b/tvix/nar-bridge/pkg/importer/importer.go
index 465b3bc84fdb..9d6a7178a2c5 100644
--- a/tvix/nar-bridge/pkg/importer/importer.go
+++ b/tvix/nar-bridge/pkg/importer/importer.go
@@ -10,10 +10,8 @@ import (
 	"strings"
 
 	castorev1pb "code.tvl.fyi/tvix/castore/protos"
-	"code.tvl.fyi/tvix/nar-bridge/pkg/hashers"
 	storev1pb "code.tvl.fyi/tvix/store/protos"
 	"github.com/nix-community/go-nix/pkg/nar"
-	"lukechampine.com/blake3"
 )
 
 // An item on the directories stack
@@ -34,12 +32,15 @@ func Import(
 	// callback function called with each finalized directory node
 	directoryCb func(directory *castorev1pb.Directory) ([]byte, error),
 ) (*storev1pb.PathInfo, error) {
-	// wrap the passed reader in a reader that records the number of bytes read and
-	// their sha256 sum.
-	hr := hashers.NewHasher(r, sha256.New())
-
-	// construct a NAR reader from the underlying data.
-	narReader, err := nar.NewReader(hr)
+	// We need to wrap the underlying reader a bit.
+	// - we want to keep track of the number of bytes read in total
+	// - we calculate the sha256 digest over all data read
+	// Express these two things in a MultiWriter, and give the NAR reader a
+	// TeeReader that writes to it.
+	narCountW := &CountingWriter{}
+	sha256W := sha256.New()
+	multiW := io.MultiWriter(narCountW, sha256W)
+	narReader, err := nar.NewReader(io.TeeReader(r, multiW))
 	if err != nil {
 		return nil, fmt.Errorf("failed to instantiate nar reader: %w", err)
 	}
@@ -132,8 +133,8 @@ func Import(
 					Node:       nil,
 					References: [][]byte{},
 					Narinfo: &storev1pb.NARInfo{
-						NarSize:        uint64(hr.BytesWritten()),
-						NarSha256:      hr.Sum(nil),
+						NarSize:        narCountW.BytesWritten(),
+						NarSha256:      sha256W.Sum(nil),
 						Signatures:     []*storev1pb.NARInfo_Signature{},
 						ReferenceNames: []string{},
 					},
@@ -202,8 +203,9 @@ func Import(
 
 			}
 			if hdr.Type == nar.TypeRegular {
-				// wrap reader with a reader calculating the blake3 hash
-				blobReader := hashers.NewHasher(narReader, blake3.New(32, nil))
+				// wrap reader with a reader counting the number of bytes read
+				blobCountW := &CountingWriter{}
+				blobReader := io.TeeReader(narReader, blobCountW)
 
 				blobDigest, err := blobCb(blobReader)
 				if err != nil {
@@ -212,8 +214,8 @@ func Import(
 
 				// ensure blobCb did read all the way to the end.
 				// If it didn't, the blobCb function is wrong and we should bail out.
-				if blobReader.BytesWritten() != uint32(hdr.Size) {
-					panic("not read to end")
+				if blobCountW.BytesWritten() != uint64(hdr.Size) {
+					panic("blobCB did not read to end")
 				}
 
 				fileNode := &castorev1pb.FileNode{