diff --git a/low/caching_paged_reader.go b/low/caching_paged_reader.go
new file mode 100644
index 0000000..9c94aea
--- /dev/null
+++ b/low/caching_paged_reader.go
@@ -0,0 +1,83 @@
+package squashfslow
+
+import "errors"
+
+var (
+	errOutOfBounds           = errors.New("out of bounds")
+	errUnexpectedOutOfBounds = errors.New("unexpected out of bounds")
+	errNilCollection         = errors.New("nil collection")
+	errInvalidBlockSize      = errors.New("invalid block size")
+)
+
+// readPagedItems returns the item at requestedItemIndex, calling readBlockOrPartial
+// as many times as needed to extend the currentItems cache far enough to contain it.
+//
+// Parameters:
+//   - requestedItemIndex: The index of the item to be retrieved.
+//   - blockSize: The number of items per block. Must be positive.
+//   - currentItems: A slice of already-read items to manage in-memory storage. Newly read
+//     blocks are appended to it. Must not be nil.
+//   - totalItems: The total number of items available. It bounds requestedItemIndex and
+//     limits how many items are requested for the last block.
+//   - readBlockOrPartial: A callback function that reads the next block. It takes the index of the block
+//     to be read, and the number of items to read. It is normally passed block size, but if the last
+//     block is incomplete, it will be passed the number of items in the last block.
+//
+// Returns:
+//   - the T at requestedItemIndex
+//   - a non-nil error and the zero value of T if an error was encountered.
+func readPagedItems[T any](
+	requestedItemIndex int,
+	blockSize int,
+	currentItems *[]T,
+	totalItems int,
+	readBlockOrPartial func(idxBlock, numItems int) ([]T, error),
+) (T, error) {
+	var zero T // Zero value for the item type, used for default return in error cases.
+	if currentItems == nil {
+		return zero, errNilCollection
+	}
+
+	if requestedItemIndex < 0 || requestedItemIndex >= totalItems {
+		return zero, errOutOfBounds
+	}
+
+	if len(*currentItems) > requestedItemIndex {
+		return (*currentItems)[requestedItemIndex], nil
+	}
+
+	// Reject a non-positive block size before dividing by it below.
+	if blockSize <= 0 {
+		return zero, errInvalidBlockSize
+	}
+
+	// Calculate which block contains the requested item. Integer division is exact here
+	// because requestedItemIndex is non-negative.
+	blockNum := requestedItemIndex / blockSize
+
+	// Calculate blocks to read
+	blocksRead := len(*currentItems) / blockSize
+	blocksToRead := blockNum - blocksRead + 1
+
+	// Read and append new blocks
+	for i := 0; i < blocksToRead; i++ {
+		startBlock := blocksRead + i
+		itemsLeft := totalItems - len(*currentItems)
+		itemsToRead := blockSize
+		if itemsToRead > itemsLeft {
+			itemsToRead = itemsLeft
+		}
+		items, err := readBlockOrPartial(startBlock, itemsToRead)
+		if err != nil {
+			return zero, err
+		}
+		*currentItems = append(*currentItems, items...)
+	}
+
+	// Ensure the slice contains the requested index after reading
+	if len(*currentItems) <= requestedItemIndex {
+		return zero, errUnexpectedOutOfBounds
+	}
+
+	return (*currentItems)[requestedItemIndex], nil
+}
diff --git a/low/caching_paged_reader_test.go b/low/caching_paged_reader_test.go
new file mode 100644
index 0000000..51e8ba9
--- /dev/null
+++ b/low/caching_paged_reader_test.go
@@ -0,0 +1,136 @@
+package squashfslow
+
+import (
+	"errors"
+	"testing"
+)
+
+func requireNoError(t *testing.T, err error) {
+	t.Helper()
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func assertEqual(t *testing.T, want int, got int) {
+	t.Helper()
+	if want != got {
+		t.Errorf("want %d, got %d", want, got)
+	}
+}
+
+func assertLength(t *testing.T, want int, slice []int) {
+	t.Helper()
+	if len(slice) != want {
+		t.Errorf("want len %d, got %d", want, len(slice))
+	}
+}
+
+func assertErrorIs(t *testing.T, err error, wantErr error) {
+	t.Helper()
+	if err == nil {
+		t.Errorf("want %s, got nil", wantErr)
+		return
+	}
+	if !errors.Is(err, wantErr) {
+		t.Errorf("want %s, got %v", wantErr, err)
+	}
+}
+
+func TestCachingPagedReader(t *testing.T) {
+	// Mock block reader: item i of block b has the value b*512+i.
+	mockReadNMore := func(startBlock, numItems int) ([]int, error) {
+		if startBlock < 0 {
+			return nil, errors.New("invalid block start")
+		}
+		var result []int
+		for i := 0; i < numItems; i++ {
+			result = append(result, startBlock*512+i)
+		}
+		return result, nil
+	}
+
+	t.Run("ValidRequestWithinFirstBlock", func(t *testing.T) {
+		currentItems := make([]int, 0)
+		item, err := readPagedItems(300, 512, &currentItems, 2048, mockReadNMore)
+		requireNoError(t, err)
+		assertEqual(t, 300, item)
+		assertLength(t, 512, currentItems) // Ensure one block is read
+	})
+
+	t.Run("ValidRequestAcrossMultipleBlocks", func(t *testing.T) {
+		currentItems := make([]int, 0)
+		item, err := readPagedItems(600, 512, &currentItems, 2048, mockReadNMore)
+		requireNoError(t, err)
+		assertEqual(t, 600, item)
+		assertLength(t, 1024, currentItems)
+	})
+
+	t.Run("SequentialRequestsWithinBlocks", func(t *testing.T) {
+		currentItems := make([]int, 0)
+		// First request
+		item, err := readPagedItems(300, 512, &currentItems, 2048, mockReadNMore)
+		requireNoError(t, err)
+		assertEqual(t, 300, item)
+
+		// Second request in the same block
+		item, err = readPagedItems(400, 512, &currentItems, 2048, mockReadNMore)
+		requireNoError(t, err)
+		assertEqual(t, 400, item)
+		assertLength(t, 512, currentItems)
+	})
+
+	t.Run("RequestExactBlockBoundary", func(t *testing.T) {
+		currentItems := make([]int, 0)
+		item, err := readPagedItems(511, 512, &currentItems, 2048, mockReadNMore)
+		requireNoError(t, err)
+		assertEqual(t, 511, item)
+		assertLength(t, 512, currentItems)
+
+		// Request the next block's first item
+		item, err = readPagedItems(512, 512, &currentItems, 2048, mockReadNMore)
+		requireNoError(t, err)
+		assertEqual(t, 512, item)
+		assertLength(t, 1024, currentItems)
+	})
+
+	t.Run("OutOfBoundsRequest", func(t *testing.T) {
+		currentItems := make([]int, 0)
+		_, err := readPagedItems(2048, 512, &currentItems, 2048, mockReadNMore)
+		assertErrorIs(t, err, errOutOfBounds)
+	})
+
+	t.Run("RequestBeyondReadBlocks", func(t *testing.T) {
+		readFail := errors.New("failed to read block")
+		failingReadBlocks := func(startBlock, numItems int) ([]int, error) {
+			if startBlock > 1 {
+				return nil, readFail
+			}
+			var result []int
+			for i := 0; i < numItems; i++ {
+				result = append(result, startBlock*512+i)
+			}
+			return result, nil
+		}
+
+		currentItems := make([]int, 0)
+		_, err := readPagedItems(1024, 512, &currentItems, 2048, failingReadBlocks)
+		assertErrorIs(t, err, readFail)
+	})
+
+	t.Run("partial last page", func(t *testing.T) {
+		currentItems := make([]int, 0)
+
+		// Request the first item of the second block, which holds only 100 items
+		item, err := readPagedItems(512, 512, &currentItems, 612, mockReadNMore)
+		requireNoError(t, err)
+		assertEqual(t, 512, item)
+		assertLength(t, 612, currentItems)
+	})
+
+	t.Run("InvalidBlockSize", func(t *testing.T) {
+		currentItems := make([]int, 0)
+		_, err := readPagedItems(0, 0, &currentItems, 2048, mockReadNMore)
+		assertErrorIs(t, err, errInvalidBlockSize)
+	})
+}
diff --git a/low/reader.go b/low/reader.go
index 4f8efb9..397a499 100644
--- a/low/reader.go
+++ b/low/reader.go
@@ -126,43 +126,25 @@ func (r *Reader) Id(i uint16) (uint32, error) {
 
 // Get a fragment entry at the given index. Lazily populates the reader's fragment table as necessary.
func (r *Reader) fragEntry(i uint32) (fragEntry, error) { - if len(r.fragTable) > int(i) { - return r.fragTable[i], nil - } else if i >= r.Superblock.FragCount { - return fragEntry{}, errors.New("fragment out of bounds") - } - // Populate the fragment table as needed - var blockNum uint32 - if i != 0 { // If i == 0, we go negatives causing issues with uint32s - blockNum = uint32(math.Ceil(float64(i+1)/512)) - 1 - } else { - blockNum = 0 - } - blocksRead := len(r.fragTable) / 512 - blocksToRead := int(blockNum) - blocksRead + 1 + return readPagedItems(int(i), 512, &r.fragTable, int(r.Superblock.FragCount), + func(startBlock, fragsToRead int) ([]fragEntry, error) { + // Read the uint64 stored at FragTableStart + 8*startBlock; it is used below as the position of this block's fragment entries. + var offset uint64 + err := binary.Read(toreader.NewReader(r.r, int64(r.Superblock.FragTableStart)+int64(8*startBlock)), binary.LittleEndian, &offset) + if err != nil { + return nil, err + } - var offset uint64 - var fragsToRead uint32 - var fragsTmp []fragEntry - var err error - var rdr metadata.Reader - // We can *maybe* have a slight speed increase by manually decoding instead of using reflection via binary.Read - for i := blocksRead; i < int(blocksRead)+blocksToRead; i++ { - err = binary.Read(toreader.NewReader(r.r, int64(r.Superblock.FragTableStart)+int64(8*i)), binary.LittleEndian, &offset) - if err != nil { - return fragEntry{}, err - } - fragsToRead = min(r.Superblock.FragCount-uint32(len(r.fragTable)), 512) - fragsTmp = make([]fragEntry, fragsToRead) - rdr = metadata.NewReader(toreader.NewReader(r.r, int64(offset)), r.d) - err = binary.Read(&rdr, binary.LittleEndian, &fragsTmp) - rdr.Close() - if err != nil { - return fragEntry{}, err - } - r.fragTable = append(r.fragTable, fragsTmp...) 
- } - return r.fragTable[i], nil + fragsTmp := make([]fragEntry, fragsToRead) + rdr := metadata.NewReader(toreader.NewReader(r.r, int64(offset)), r.d) + defer rdr.Close() // deferred inside the callback, so it runs once per block read + err = binary.Read(rdr, binary.LittleEndian, &fragsTmp) // NOTE(review): the removed code passed &rdr here; confirm rdr itself is accepted as the io.Reader + if err != nil { + return nil, err + } + + return fragsTmp, nil + }) } // Get an inode reference at the given index. Lazily populates the reader's export table as necessary. diff --git a/low/reader_test.go b/low/reader_test.go index 403a7ba..9e895a6 100644 --- a/low/reader_test.go +++ b/low/reader_test.go @@ -77,8 +77,10 @@ func TestReader(t *testing.T) { path := filepath.Join(tmpDir, "extractTest") os.RemoveAll(path) os.MkdirAll(path, 0777) - err = extractToDir(rdr, rdr.Root.FileBase, path) - t.Fatal(err) + err = extractToDir(rdr, &rdr.Root.FileBase, path) // NOTE(review): now passes a pointer, while the extractToDir signature shown as unchanged context at the end of this diff takes b FileBase; confirm this compiles + if err != nil { + t.Fatal(err) + } } var singleFile = "PortableApps/CPU-X/CPU-X-v4.2.0-x86_64.AppImage" @@ -101,8 +103,10 @@ func TestSingleFile(t *testing.T) { if err != nil { t.Fatal(err) } - err = extractToDir(rdr, b, path) - t.Fatal(err) + err = extractToDir(rdr, &b, path) + if err != nil { + t.Fatal(err) + } } func extractToDir(rdr Reader, b FileBase, folder string) error {