From 0905141013b3e5a597a3c2db0dbac82de1c761fc Mon Sep 17 00:00:00 2001 From: Will Murphy Date: Tue, 26 Nov 2024 07:07:00 -0500 Subject: [PATCH] fix: prevent index out of range on long frag tables Previously, reading fragment 512 would panic with index out of range. Fix that panic by introducing an abstraction over reading blocks of items, caching the intermediate result, and returning an item at a particular index. The primary goal of this abstraction is to make edge cases like requesting items on page boundaries easy to unit test for. Additionally, fix unit tests by making t.Fatal calls protected by nil checks on the error values. --- low/caching_paged_reader.go | 72 +++++++++++++++++ low/caching_paged_reader_test.go | 130 +++++++++++++++++++++++++++++++ low/reader.go | 54 +++++-------- low/reader_test.go | 12 ++- 4 files changed, 228 insertions(+), 40 deletions(-) create mode 100644 low/caching_paged_reader.go create mode 100644 low/caching_paged_reader_test.go diff --git a/low/caching_paged_reader.go b/low/caching_paged_reader.go new file mode 100644 index 0000000..9c94aea --- /dev/null +++ b/low/caching_paged_reader.go @@ -0,0 +1,72 @@ +package squashfslow + +import ( + "errors" + "math" +) + +var errOutOfBounds = errors.New("out of bounds") +var errUnexpectedOutOfBounds = errors.New("unexpected out of bounds") +var errNilCollection = errors.New("nil collection") + +// readPagedItems calls readBLockOrPartial the correct number of times to cache +// requestedItemIndex in currentItems, and then returns currentItems[requestedItemIndex]. +// Parameters: +// - requestedItemIndex: The index of the item to be retrieved. +// - blockSize: The number of items per block. +// - currentItems: A slice of already-read items to manage in-memory storage. Must not be nil. +// - readBlockOrPartial: A callback function that reads the next block. It takes the index of the block +// to be read, and the number of items to read. It is normally passed block size, but if the last +// block is incomplete, it will be passed the number of items in the last block. +// Returns: +// - the T at requestedItemIndex +// - a non-nil error and the zero value of T if an error was encountered. +func readPagedItems[T any]( + requestedItemIndex int, + blockSize int, + currentItems *[]T, + totalItems int, + readBlockOrPartial func(idxBlock, numItems int) ([]T, error), +) (T, error) { + var zero T // Zero value for the item type, used for default return in error cases. + if currentItems == nil { + return zero, errNilCollection + } + + if requestedItemIndex < 0 || requestedItemIndex >= totalItems { + return zero, errOutOfBounds + } + + if len(*currentItems) > requestedItemIndex { + return (*currentItems)[requestedItemIndex], nil + } + + // Calculate which block contains the requested item + blockNum := int(math.Ceil(float64(requestedItemIndex+1)/float64(blockSize))) - 1 + + // Calculate blocks to read + blocksRead := len(*currentItems) / blockSize + blocksToRead := blockNum - blocksRead + 1 + + // Read and append new blocks + for i := 0; i < blocksToRead; i++ { + startBlock := blocksRead + i + itemsLeft := totalItems - len(*currentItems) + itemsToRead := blockSize + if itemsToRead > itemsLeft { + itemsToRead = itemsLeft + } + items, err := readBlockOrPartial(startBlock, itemsToRead) + if err != nil { + return zero, err + } + *currentItems = append(*currentItems, items...) + } + + // Ensure the slice contains the requested index after reading + if len(*currentItems) <= requestedItemIndex { + return zero, errUnexpectedOutOfBounds + } + + return (*currentItems)[requestedItemIndex], nil +} diff --git a/low/caching_paged_reader_test.go b/low/caching_paged_reader_test.go new file mode 100644 index 0000000..51e8ba9 --- /dev/null +++ b/low/caching_paged_reader_test.go @@ -0,0 +1,130 @@ +package squashfslow + +import ( + "errors" + "testing" +) + +func requireNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + +func assertEqual(t *testing.T, want int, got int) { + t.Helper() + if want != got { + t.Errorf("want %d, got %d", want, got) + } +} + +func assertLength(t *testing.T, want int, slice []int) { + t.Helper() + if len(slice) != want { + t.Errorf("want len %d, got %d", want, len(slice)) + } +} + +func assertErrorIs(t *testing.T, err error, wantErr error) { + t.Helper() + if err == nil { + t.Errorf("want %s, got nil", wantErr) + return + } + if !errors.Is(err, wantErr) { + t.Errorf("want %s, got %v", wantErr, err) + } +} + +func TestCachingPagedReader(t *testing.T) { + // Mock readBlocks function + mockReadNMore := func(startBlock, numItems int) ([]int, error) { + if startBlock < 0 { + return nil, errors.New("invalid block start") + } + var result []int + for i := 0; i < numItems; i++ { + result = append(result, startBlock*512+i) + } + return result, nil + } + + t.Run("ValidRequestWithinFirstBlock", func(t *testing.T) { + currentItems := make([]int, 0) + item, err := readPagedItems(300, 512, ¤tItems, 2048, mockReadNMore) + requireNoError(t, err) + assertEqual(t, 300, item) + assertLength(t, 512, currentItems) // Ensure one block is read + }) + + t.Run("ValidRequestAcrossMultipleBlocks", func(t *testing.T) { + currentItems := make([]int, 0) + item, err := readPagedItems(600, 512, ¤tItems, 2048, mockReadNMore) + requireNoError(t, err) + assertEqual(t, 600, item) + assertLength(t, 1024, currentItems) + }) + + t.Run("SequentialRequestsWithinBlocks", func(t *testing.T) { + currentItems := make([]int, 0) + // First request + item, err := readPagedItems(300, 512, ¤tItems, 2048, mockReadNMore) + requireNoError(t, err) + assertEqual(t, 300, item) + + // Second request in the same block + item, err = readPagedItems(400, 512, ¤tItems, 2048, mockReadNMore) + requireNoError(t, err) + assertEqual(t, 400, item) + assertLength(t, 512, currentItems) + }) + + t.Run("RequestExactBlockBoundary", func(t *testing.T) { + currentItems := make([]int, 0) + item, err := readPagedItems(511, 512, ¤tItems, 2048, mockReadNMore) + requireNoError(t, err) + assertEqual(t, 511, item) + assertLength(t, 512, currentItems) + + // Request the next block's first item + item, err = readPagedItems(512, 512, ¤tItems, 2048, mockReadNMore) + requireNoError(t, err) + assertEqual(t, 512, item) + assertLength(t, 1024, currentItems) + }) + + t.Run("OutOfBoundsRequest", func(t *testing.T) { + currentItems := make([]int, 0) + _, err := readPagedItems(2048, 512, ¤tItems, 2048, mockReadNMore) + assertErrorIs(t, err, errOutOfBounds) + }) + + t.Run("RequestBeyondReadBlocks", func(t *testing.T) { + readFail := errors.New("failed to read block") + failingReadBlocks := func(startBlock, numBlocks int) ([]int, error) { + if startBlock > 1 { + return nil, readFail + } + var result []int + for i := 0; i < numBlocks*512; i++ { + result = append(result, startBlock*512+i) + } + return result, nil + } + + currentItems := make([]int, 0) + _, err := readPagedItems(1024, 512, ¤tItems, 2048, failingReadBlocks) + assertErrorIs(t, err, readFail) + }) + + t.Run("partial last page", func(t *testing.T) { + currentItems := make([]int, 0) + + // Request the next block's first item + item, err := readPagedItems(512, 512, ¤tItems, 612, mockReadNMore) + requireNoError(t, err) + assertEqual(t, 512, item) + assertLength(t, 612, currentItems) + }) +} diff --git a/low/reader.go b/low/reader.go index 4f8efb9..397a499 100644 --- a/low/reader.go +++ b/low/reader.go @@ -126,43 +126,25 @@ func (r *Reader) Id(i uint16) (uint32, error) { // Get a fragment entry at the given index. Lazily populates the reader's fragment table as necessary. func (r *Reader) fragEntry(i uint32) (fragEntry, error) { - if len(r.fragTable) > int(i) { - return r.fragTable[i], nil - } else if i >= r.Superblock.FragCount { - return fragEntry{}, errors.New("fragment out of bounds") - } - // Populate the fragment table as needed - var blockNum uint32 - if i != 0 { // If i == 0, we go negatives causing issues with uint32s - blockNum = uint32(math.Ceil(float64(i+1)/512)) - 1 - } else { - blockNum = 0 - } - blocksRead := len(r.fragTable) / 512 - blocksToRead := int(blockNum) - blocksRead + 1 + return readPagedItems(int(i), 512, &r.fragTable, int(r.Superblock.FragCount), + func(startBlock, fragsToRead int) ([]fragEntry, error) { + // get the offset of the next block of fragments + var offset uint64 + err := binary.Read(toreader.NewReader(r.r, int64(r.Superblock.FragTableStart)+int64(8*startBlock)), binary.LittleEndian, &offset) + if err != nil { + return nil, err + } - var offset uint64 - var fragsToRead uint32 - var fragsTmp []fragEntry - var err error - var rdr metadata.Reader - // We can *maybe* have a slight speed increase by manually decoding instead of using reflection via binary.Read - for i := blocksRead; i < int(blocksRead)+blocksToRead; i++ { - err = binary.Read(toreader.NewReader(r.r, int64(r.Superblock.FragTableStart)+int64(8*i)), binary.LittleEndian, &offset) - if err != nil { - return fragEntry{}, err - } - fragsToRead = min(r.Superblock.FragCount-uint32(len(r.fragTable)), 512) - fragsTmp = make([]fragEntry, fragsToRead) - rdr = metadata.NewReader(toreader.NewReader(r.r, int64(offset)), r.d) - err = binary.Read(&rdr, binary.LittleEndian, &fragsTmp) - rdr.Close() - if err != nil { - return fragEntry{}, err - } - r.fragTable = append(r.fragTable, fragsTmp...) - } - return r.fragTable[i], nil + fragsTmp := make([]fragEntry, fragsToRead) + rdr := metadata.NewReader(toreader.NewReader(r.r, int64(offset)), r.d) + defer rdr.Close() + err = binary.Read(rdr, binary.LittleEndian, &fragsTmp) + if err != nil { + return nil, err + } + + return fragsTmp, nil + }) } // Get an inode reference at the given index. Lazily populates the reader's export table as necessary. diff --git a/low/reader_test.go b/low/reader_test.go index 403a7ba..9e895a6 100644 --- a/low/reader_test.go +++ b/low/reader_test.go @@ -77,8 +77,10 @@ func TestReader(t *testing.T) { path := filepath.Join(tmpDir, "extractTest") os.RemoveAll(path) os.MkdirAll(path, 0777) - err = extractToDir(rdr, rdr.Root.FileBase, path) - t.Fatal(err) + err = extractToDir(rdr, &rdr.Root.FileBase, path) + if err != nil { + t.Fatal(err) + } } var singleFile = "PortableApps/CPU-X/CPU-X-v4.2.0-x86_64.AppImage" @@ -101,8 +103,10 @@ func TestSingleFile(t *testing.T) { if err != nil { t.Fatal(err) } - err = extractToDir(rdr, b, path) - t.Fatal(err) + err = extractToDir(rdr, &b, path) + if err != nil { + t.Fatal(err) + } } func extractToDir(rdr Reader, b FileBase, folder string) error {