about summary refs log tree commit diff
path: root/tvix/nar-bridge/pkg/importer
diff options
context:
space:
mode:
Diffstat (limited to 'tvix/nar-bridge/pkg/importer')
-rw-r--r--tvix/nar-bridge/pkg/importer/importer.go43
-rw-r--r--tvix/nar-bridge/pkg/importer/importer_test.go82
2 files changed, 63 insertions, 62 deletions
diff --git a/tvix/nar-bridge/pkg/importer/importer.go b/tvix/nar-bridge/pkg/importer/importer.go
index de5456be0b3e..465b3bc84fdb 100644
--- a/tvix/nar-bridge/pkg/importer/importer.go
+++ b/tvix/nar-bridge/pkg/importer/importer.go
@@ -30,9 +30,9 @@ func Import(
 	// The reader the data is read from
 	r io.Reader,
 	// callback function called with each regular file content
-	blobCb func(fileReader io.Reader) error,
+	blobCb func(fileReader io.Reader) ([]byte, error),
 	// callback function called with each finalized directory node
-	directoryCb func(directory *castorev1pb.Directory) error,
+	directoryCb func(directory *castorev1pb.Directory) ([]byte, error),
 ) (*storev1pb.PathInfo, error) {
 	// wrap the passed reader in a reader that records the number of bytes read and
 	// their sha256 sum.
@@ -65,24 +65,21 @@ func Import(
 		toPop := stack[len(stack)-1]
 		stack = stack[:len(stack)-1]
 
+		// call the directoryCb
+		directoryDigest, err := directoryCb(toPop.directory)
+		if err != nil {
+			return fmt.Errorf("failed calling directoryCb: %w", err)
+		}
+
 		// if there's still a parent left on the stack, refer to it from there.
 		if len(stack) > 0 {
-			dgst, err := toPop.directory.Digest()
-			if err != nil {
-				return fmt.Errorf("unable to calculate directory digest: %w", err)
-			}
-
 			topOfStack := stack[len(stack)-1].directory
 			topOfStack.Directories = append(topOfStack.Directories, &castorev1pb.DirectoryNode{
 				Name:   []byte(path.Base(toPop.path)),
-				Digest: dgst,
+				Digest: directoryDigest,
 				Size:   toPop.directory.Size(),
 			})
 		}
-		// call the directoryCb
-		if err := directoryCb(toPop.directory); err != nil {
-			return fmt.Errorf("failed calling directoryCb: %w", err)
-		}
 		// Keep track that we have encounter at least one directory
 		stackDirectory = toPop.directory
 		return nil
@@ -106,7 +103,7 @@ func Import(
 			hdr, err := narReader.Next()
 
 			// If this returns an error, it's either EOF (when we're done reading from the NAR),
-			// or another error
+			// or another error.
 			if err != nil {
 				// if this returns no EOF, bail out
 				if !errors.Is(err, io.EOF) {
@@ -206,28 +203,22 @@ func Import(
 			}
 			if hdr.Type == nar.TypeRegular {
 				// wrap reader with a reader calculating the blake3 hash
-				fileReader := hashers.NewHasher(narReader, blake3.New(32, nil))
+				blobReader := hashers.NewHasher(narReader, blake3.New(32, nil))
 
-				err := blobCb(fileReader)
+				blobDigest, err := blobCb(blobReader)
 				if err != nil {
 					return nil, fmt.Errorf("failure from blobCb: %w", err)
 				}
 
-				// drive the file reader to the end, in case the CB function doesn't read
-				// all the way to the end on its own
-				if fileReader.BytesWritten() != uint32(hdr.Size) {
-					_, err := io.ReadAll(fileReader)
-					if err != nil {
-						return nil, fmt.Errorf("unable to read until the end of the file content: %w", err)
-					}
+				// ensure blobCb did read all the way to the end.
+				// If it didn't, the blobCb function is wrong and we should bail out.
+				if blobReader.BytesWritten() != uint32(hdr.Size) {
+					panic("not read to end")
 				}
 
-				// read the blake3 hash
-				dgst := fileReader.Sum(nil)
-
 				fileNode := &castorev1pb.FileNode{
 					Name:       []byte(getBasename(hdr.Path)),
-					Digest:     dgst,
+					Digest:     blobDigest,
 					Size:       uint32(hdr.Size),
 					Executable: hdr.Executable,
 				}
diff --git a/tvix/nar-bridge/pkg/importer/importer_test.go b/tvix/nar-bridge/pkg/importer/importer_test.go
index 0557c1d6dd1f..de0548da9398 100644
--- a/tvix/nar-bridge/pkg/importer/importer_test.go
+++ b/tvix/nar-bridge/pkg/importer/importer_test.go
@@ -1,6 +1,7 @@
 package importer_test
 
 import (
+	"bytes"
 	"context"
 	"errors"
 	"io"
@@ -13,6 +14,7 @@ import (
 	"github.com/google/go-cmp/cmp"
 	"github.com/stretchr/testify/require"
 	"google.golang.org/protobuf/testing/protocmp"
+	"lukechampine.com/blake3"
 )
 
 func requireProtoEq(t *testing.T, expected interface{}, actual interface{}) {
@@ -21,7 +23,7 @@ func requireProtoEq(t *testing.T, expected interface{}, actual interface{}) {
 	}
 }
 
-func mustDigest(d *castorev1pb.Directory) []byte {
+func mustDirectoryDigest(d *castorev1pb.Directory) []byte {
 	dgst, err := d.Digest()
 	if err != nil {
 		panic(err)
@@ -29,6 +31,15 @@ func mustDigest(d *castorev1pb.Directory) []byte {
 	return dgst
 }
 
+func mustBlobDigest(r io.Reader) []byte {
+	hasher := blake3.New(32, nil)
+	_, err := io.Copy(hasher, r)
+	if err != nil {
+		panic(err)
+	}
+	return hasher.Sum([]byte{})
+}
+
 func TestSymlink(t *testing.T) {
 	f, err := os.Open("../../testdata/symlink.nar")
 	require.NoError(t, err)
@@ -36,9 +47,9 @@ func TestSymlink(t *testing.T) {
 	actualPathInfo, err := importer.Import(
 		context.Background(),
 		f,
-		func(fileReader io.Reader) error {
+		func(blobReader io.Reader) ([]byte, error) {
 			panic("no file contents expected!")
-		}, func(directory *castorev1pb.Directory) error {
+		}, func(directory *castorev1pb.Directory) ([]byte, error) {
 			panic("no directories expected!")
 		},
 	)
@@ -74,12 +85,12 @@ func TestRegular(t *testing.T) {
 	actualPathInfo, err := importer.Import(
 		context.Background(),
 		f,
-		func(fileReader io.Reader) error {
-			contents, err := io.ReadAll(fileReader)
-			require.NoError(t, err, "reading fileReader should not error")
-			require.Equal(t, []byte{0x01}, contents, "contents read from fileReader should match expectations")
-			return nil
-		}, func(directory *castorev1pb.Directory) error {
+		func(blobReader io.Reader) ([]byte, error) {
+			contents, err := io.ReadAll(blobReader)
+			require.NoError(t, err, "reading blobReader should not error")
+			require.Equal(t, []byte{0x01}, contents, "contents read from blobReader should match expectations")
+			return mustBlobDigest(bytes.NewBuffer(contents)), nil
+		}, func(directory *castorev1pb.Directory) ([]byte, error) {
 			panic("no directories expected!")
 		},
 	)
@@ -129,11 +140,11 @@ func TestEmptyDirectory(t *testing.T) {
 	actualPathInfo, err := importer.Import(
 		context.Background(),
 		f,
-		func(fileReader io.Reader) error {
+		func(blobReader io.Reader) ([]byte, error) {
 			panic("no file contents expected!")
-		}, func(directory *castorev1pb.Directory) error {
+		}, func(directory *castorev1pb.Directory) ([]byte, error) {
 			requireProtoEq(t, expectedDirectory, directory)
-			return nil
+			return mustDirectoryDigest(directory), nil
 		},
 	)
 	require.NoError(t, err)
@@ -143,7 +154,7 @@ func TestEmptyDirectory(t *testing.T) {
 			Node: &castorev1pb.Node_Directory{
 				Directory: &castorev1pb.DirectoryNode{
 					Name:   []byte(""),
-					Digest: mustDigest(expectedDirectory),
+					Digest: mustDirectoryDigest(expectedDirectory),
 					Size:   expectedDirectory.Size(),
 				},
 			},
@@ -415,17 +426,17 @@ func TestFull(t *testing.T) {
 		Directories: []*castorev1pb.DirectoryNode{
 			{
 				Name:   []byte("man1"),
-				Digest: mustDigest(expectedDirectories["/share/man/man1"]),
+				Digest: mustDirectoryDigest(expectedDirectories["/share/man/man1"]),
 				Size:   expectedDirectories["/share/man/man1"].Size(),
 			},
 			{
 				Name:   []byte("man5"),
-				Digest: mustDigest(expectedDirectories["/share/man/man5"]),
+				Digest: mustDirectoryDigest(expectedDirectories["/share/man/man5"]),
 				Size:   expectedDirectories["/share/man/man5"].Size(),
 			},
 			{
 				Name:   []byte("man8"),
-				Digest: mustDigest(expectedDirectories["/share/man/man8"]),
+				Digest: mustDirectoryDigest(expectedDirectories["/share/man/man8"]),
 				Size:   expectedDirectories["/share/man/man8"].Size(),
 			},
 		},
@@ -438,7 +449,7 @@ func TestFull(t *testing.T) {
 		Directories: []*castorev1pb.DirectoryNode{
 			{
 				Name:   []byte("man"),
-				Digest: mustDigest(expectedDirectories["/share/man"]),
+				Digest: mustDirectoryDigest(expectedDirectories["/share/man"]),
 				Size:   expectedDirectories["/share/man"].Size(),
 			},
 		},
@@ -451,12 +462,12 @@ func TestFull(t *testing.T) {
 		Directories: []*castorev1pb.DirectoryNode{
 			{
 				Name:   []byte("bin"),
-				Digest: mustDigest(expectedDirectories["/bin"]),
+				Digest: mustDirectoryDigest(expectedDirectories["/bin"]),
 				Size:   expectedDirectories["/bin"].Size(),
 			},
 			{
 				Name:   []byte("share"),
-				Digest: mustDigest(expectedDirectories["/share"]),
+				Digest: mustDirectoryDigest(expectedDirectories["/share"]),
 				Size:   expectedDirectories["/share"].Size(),
 			},
 		},
@@ -476,14 +487,12 @@ func TestFull(t *testing.T) {
 	actualPathInfo, err := importer.Import(
 		context.Background(),
 		f,
-		func(fileReader io.Reader) error {
+		func(blobReader io.Reader) ([]byte, error) {
 			// Don't really bother reading and comparing the contents here,
 			// We already verify the right digests are produced by comparing the
 			// directoryCb calls, and TestRegular ensures the reader works.
-			// This also covers the case when the client doesn't read from the reader, and that the
-			// importer will take care of reading all the way to the end no matter what.
-			return nil
-		}, func(directory *castorev1pb.Directory) error {
+			return mustBlobDigest(blobReader), nil
+		}, func(directory *castorev1pb.Directory) ([]byte, error) {
 			// use actualDirectoryOrder to look up the Directory object we expect at this specific invocation.
 			currentDirectoryPath := expectedDirectoryPaths[numDirectoriesReceived]
 
@@ -493,7 +502,7 @@ func TestFull(t *testing.T) {
 			requireProtoEq(t, expectedDirectory, directory)
 
 			numDirectoriesReceived += 1
-			return nil
+			return mustDirectoryDigest(directory), nil
 		},
 	)
 	require.NoError(t, err)
@@ -503,7 +512,7 @@ func TestFull(t *testing.T) {
 			Node: &castorev1pb.Node_Directory{
 				Directory: &castorev1pb.DirectoryNode{
 					Name:   []byte(""),
-					Digest: mustDigest(expectedDirectories["/"]),
+					Digest: mustDirectoryDigest(expectedDirectories["/"]),
 					Size:   expectedDirectories["/"].Size(),
 				},
 			},
@@ -524,7 +533,7 @@ func TestFull(t *testing.T) {
 // TestCallbackErrors ensures that errors returned from the callback function
 // bubble up to the importer process, and are not ignored.
 func TestCallbackErrors(t *testing.T) {
-	t.Run("callback file", func(t *testing.T) {
+	t.Run("callback blob", func(t *testing.T) {
 		// Pick an example NAR with a regular file.
 		f, err := os.Open("../../testdata/onebyteregular.nar")
 		require.NoError(t, err)
@@ -534,9 +543,9 @@ func TestCallbackErrors(t *testing.T) {
 		_, err = importer.Import(
 			context.Background(),
 			f,
-			func(fileReader io.Reader) error {
-				return targetErr
-			}, func(directory *castorev1pb.Directory) error {
+			func(blobReader io.Reader) ([]byte, error) {
+				return nil, targetErr
+			}, func(directory *castorev1pb.Directory) ([]byte, error) {
 				panic("no directories expected!")
 			},
 		)
@@ -552,10 +561,10 @@ func TestCallbackErrors(t *testing.T) {
 		_, err = importer.Import(
 			context.Background(),
 			f,
-			func(fileReader io.Reader) error {
+			func(blobReader io.Reader) ([]byte, error) {
 				panic("no file contents expected!")
-			}, func(directory *castorev1pb.Directory) error {
-				return targetErr
+			}, func(directory *castorev1pb.Directory) ([]byte, error) {
+				return nil, targetErr
 			},
 		)
 		require.ErrorIs(t, err, targetErr)
@@ -582,9 +591,10 @@ func TestPopDirectories(t *testing.T) {
 	_, err = importer.Import(
 		context.Background(),
 		f,
-		func(fileReader io.Reader) error { return nil },
-		func(directory *castorev1pb.Directory) error {
-			return directory.Validate()
+		func(blobReader io.Reader) ([]byte, error) { return mustBlobDigest(blobReader), nil },
+		func(directory *castorev1pb.Directory) ([]byte, error) {
+			require.NoError(t, directory.Validate(), "directory validation shouldn't error")
+			return mustDirectoryDigest(directory), nil
 		},
 	)
 	require.NoError(t, err)