fix: prevent index out of range on long frag tables

Previously, reading fragment 512 would panic with index out of range.
Fix that panic by introducing an abstraction over reading blocks of
items, caching the intermediate result, and returning an item at a
particular index. The primary goal of this abstraction is to make edge
cases like requesting items on page boundaries easy to unit test for.

Additionally, fix the unit tests by guarding their t.Fatal calls with nil
checks on the error values, so the tests no longer fail unconditionally.
This commit is contained in:
Will Murphy
2024-11-26 07:07:00 -05:00
committed by Caleb Gardner
parent 81b663b48a
commit 0905141013
4 changed files with 228 additions and 40 deletions
+72
View File
@@ -0,0 +1,72 @@
package squashfslow
import (
"errors"
"math"
)
// Sentinel errors returned by readPagedItems.
var (
	// errOutOfBounds: the requested index is negative or >= totalItems.
	errOutOfBounds = errors.New("out of bounds")
	// errUnexpectedOutOfBounds: the callback(s) succeeded but the cache still
	// does not contain the requested index.
	errUnexpectedOutOfBounds = errors.New("unexpected out of bounds")
	// errNilCollection: the currentItems pointer is nil.
	errNilCollection = errors.New("nil collection")
	// errInvalidBlockSize: blockSize is zero or negative, so no block
	// arithmetic is possible.
	errInvalidBlockSize = errors.New("invalid block size")
)

// readPagedItems calls readBlockOrPartial the correct number of times to cache
// requestedItemIndex in currentItems, and then returns currentItems[requestedItemIndex].
// Parameters:
//   - requestedItemIndex: The index of the item to be retrieved.
//   - blockSize: The number of items per block. Must be positive whenever a
//     read is required.
//   - currentItems: A slice of already-read items to manage in-memory storage. Must not be nil.
//     Newly read blocks are appended to it.
//   - totalItems: The total number of items that exist; valid indexes are
//     0 <= requestedItemIndex < totalItems.
//   - readBlockOrPartial: A callback function that reads the next block. It takes the index of the block
//     to be read, and the number of items to read. It is normally passed block size, but if the last
//     block is incomplete, it will be passed the number of items in the last block.
//
// Returns:
//   - the T at requestedItemIndex
//   - a non-nil error and the zero value of T if an error was encountered.
//
// Blocks that were read successfully before a callback error stay appended to
// currentItems.
func readPagedItems[T any](
	requestedItemIndex int,
	blockSize int,
	currentItems *[]T,
	totalItems int,
	readBlockOrPartial func(idxBlock, numItems int) ([]T, error),
) (T, error) {
	var zero T // Zero value for the item type, used for default return in error cases.

	if currentItems == nil {
		return zero, errNilCollection
	}
	if requestedItemIndex < 0 || requestedItemIndex >= totalItems {
		return zero, errOutOfBounds
	}

	// Fast path: the item is already cached.
	if len(*currentItems) > requestedItemIndex {
		return (*currentItems)[requestedItemIndex], nil
	}

	// Everything below divides by blockSize. Reject non-positive sizes here
	// (after the fast path, so cache hits keep working) instead of panicking
	// with an integer divide by zero.
	if blockSize <= 0 {
		return zero, errInvalidBlockSize
	}

	// Index of the block that contains the requested item. For non-negative
	// indexes this is the same as requestedItemIndex / blockSize.
	blockNum := int(math.Ceil(float64(requestedItemIndex+1)/float64(blockSize))) - 1

	// Number of whole blocks already cached, and how many more are needed to
	// reach blockNum (inclusive).
	blocksRead := len(*currentItems) / blockSize
	blocksToRead := blockNum - blocksRead + 1

	// Read and append the missing blocks in order.
	for i := 0; i < blocksToRead; i++ {
		startBlock := blocksRead + i

		// The last block may be partial: never ask for more items than remain.
		itemsLeft := totalItems - len(*currentItems)
		itemsToRead := blockSize
		if itemsToRead > itemsLeft {
			itemsToRead = itemsLeft
		}

		items, err := readBlockOrPartial(startBlock, itemsToRead)
		if err != nil {
			return zero, err
		}
		*currentItems = append(*currentItems, items...)
	}

	// The callback may have returned fewer items than requested; make sure the
	// slice really contains the requested index before indexing into it.
	if len(*currentItems) <= requestedItemIndex {
		return zero, errUnexpectedOutOfBounds
	}
	return (*currentItems)[requestedItemIndex], nil
}
+130
View File
@@ -0,0 +1,130 @@
package squashfslow
import (
"errors"
"testing"
)
// requireNoError stops the calling test immediately via t.Fatal when err is
// non-nil. It is marked as a helper so failures are reported at the caller.
func requireNoError(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatal(err)
}
// assertEqual records a test failure (without stopping the test) when got
// differs from want.
func assertEqual(t *testing.T, want int, got int) {
	t.Helper()
	if got == want {
		return
	}
	t.Errorf("want %d, got %d", want, got)
}
// assertLength records a test failure (without stopping the test) when slice
// does not contain exactly want elements.
func assertLength(t *testing.T, want int, slice []int) {
	t.Helper()
	if n := len(slice); n != want {
		t.Errorf("want len %d, got %d", want, n)
	}
}
// assertErrorIs records a test failure (without stopping the test) when err is
// nil or when wantErr is not found in err's chain according to errors.Is.
func assertErrorIs(t *testing.T, err error, wantErr error) {
	t.Helper()
	switch {
	case err == nil:
		t.Errorf("want %s, got nil", wantErr)
	case !errors.Is(err, wantErr):
		t.Errorf("want %s, got %v", wantErr, err)
	}
}
// TestCachingPagedReader exercises readPagedItems with a block size of 512:
// reads within one block, across blocks, on block boundaries, out of bounds,
// with a failing callback, and with a partial last block.
func TestCachingPagedReader(t *testing.T) {
	// Mock readBlockOrPartial callback: item k of block b has the value
	// b*512+k, so every item's value equals its global index.
	mockReadNMore := func(startBlock, numItems int) ([]int, error) {
		if startBlock < 0 {
			return nil, errors.New("invalid block start")
		}
		var result []int
		for i := 0; i < numItems; i++ {
			result = append(result, startBlock*512+i)
		}
		return result, nil
	}

	t.Run("ValidRequestWithinFirstBlock", func(t *testing.T) {
		currentItems := make([]int, 0)
		item, err := readPagedItems(300, 512, &currentItems, 2048, mockReadNMore)
		requireNoError(t, err)
		assertEqual(t, 300, item)
		assertLength(t, 512, currentItems) // Ensure one block is read
	})

	t.Run("ValidRequestAcrossMultipleBlocks", func(t *testing.T) {
		currentItems := make([]int, 0)
		item, err := readPagedItems(600, 512, &currentItems, 2048, mockReadNMore)
		requireNoError(t, err)
		assertEqual(t, 600, item)
		assertLength(t, 1024, currentItems)
	})

	t.Run("SequentialRequestsWithinBlocks", func(t *testing.T) {
		currentItems := make([]int, 0)
		// First request
		item, err := readPagedItems(300, 512, &currentItems, 2048, mockReadNMore)
		requireNoError(t, err)
		assertEqual(t, 300, item)
		// Second request in the same block
		item, err = readPagedItems(400, 512, &currentItems, 2048, mockReadNMore)
		requireNoError(t, err)
		assertEqual(t, 400, item)
		assertLength(t, 512, currentItems)
	})

	t.Run("RequestExactBlockBoundary", func(t *testing.T) {
		currentItems := make([]int, 0)
		item, err := readPagedItems(511, 512, &currentItems, 2048, mockReadNMore)
		requireNoError(t, err)
		assertEqual(t, 511, item)
		assertLength(t, 512, currentItems)
		// Request the next block's first item
		item, err = readPagedItems(512, 512, &currentItems, 2048, mockReadNMore)
		requireNoError(t, err)
		assertEqual(t, 512, item)
		assertLength(t, 1024, currentItems)
	})

	t.Run("OutOfBoundsRequest", func(t *testing.T) {
		currentItems := make([]int, 0)
		_, err := readPagedItems(2048, 512, &currentItems, 2048, mockReadNMore)
		assertErrorIs(t, err, errOutOfBounds)
	})

	t.Run("RequestBeyondReadBlocks", func(t *testing.T) {
		readFail := errors.New("failed to read block")
		// The second argument is an item count, not a block count. Blocks 0
		// and 1 succeed; any later block fails.
		failingReadBlocks := func(startBlock, numItems int) ([]int, error) {
			if startBlock > 1 {
				return nil, readFail
			}
			var result []int
			for i := 0; i < numItems; i++ {
				result = append(result, startBlock*512+i)
			}
			return result, nil
		}
		currentItems := make([]int, 0)
		// Index 1024 lives in block 2, which the mock refuses to read.
		_, err := readPagedItems(1024, 512, &currentItems, 2048, failingReadBlocks)
		assertErrorIs(t, err, readFail)
		// The two blocks read before the failure stay cached.
		assertLength(t, 1024, currentItems)
	})

	t.Run("partial last page", func(t *testing.T) {
		currentItems := make([]int, 0)
		// Request the first item of the second block, which holds only the
		// remaining 100 of the 612 total items.
		item, err := readPagedItems(512, 512, &currentItems, 612, mockReadNMore)
		requireNoError(t, err)
		assertEqual(t, 512, item)
		assertLength(t, 612, currentItems)
	})
}
+18 -36
View File
@@ -126,43 +126,25 @@ func (r *Reader) Id(i uint16) (uint32, error) {
// Get a fragment entry at the given index. Lazily populates the reader's fragment table as necessary. // Get a fragment entry at the given index. Lazily populates the reader's fragment table as necessary.
func (r *Reader) fragEntry(i uint32) (fragEntry, error) { func (r *Reader) fragEntry(i uint32) (fragEntry, error) {
if len(r.fragTable) > int(i) { return readPagedItems(int(i), 512, &r.fragTable, int(r.Superblock.FragCount),
return r.fragTable[i], nil func(startBlock, fragsToRead int) ([]fragEntry, error) {
} else if i >= r.Superblock.FragCount { // get the offset of the next block of fragments
return fragEntry{}, errors.New("fragment out of bounds") var offset uint64
} err := binary.Read(toreader.NewReader(r.r, int64(r.Superblock.FragTableStart)+int64(8*startBlock)), binary.LittleEndian, &offset)
// Populate the fragment table as needed if err != nil {
var blockNum uint32 return nil, err
if i != 0 { // If i == 0, we go negatives causing issues with uint32s }
blockNum = uint32(math.Ceil(float64(i+1)/512)) - 1
} else {
blockNum = 0
}
blocksRead := len(r.fragTable) / 512
blocksToRead := int(blockNum) - blocksRead + 1
var offset uint64 fragsTmp := make([]fragEntry, fragsToRead)
var fragsToRead uint32 rdr := metadata.NewReader(toreader.NewReader(r.r, int64(offset)), r.d)
var fragsTmp []fragEntry defer rdr.Close()
var err error err = binary.Read(rdr, binary.LittleEndian, &fragsTmp)
var rdr metadata.Reader if err != nil {
// We can *maybe* have a slight speed increase by manually decoding instead of using reflection via binary.Read return nil, err
for i := blocksRead; i < int(blocksRead)+blocksToRead; i++ { }
err = binary.Read(toreader.NewReader(r.r, int64(r.Superblock.FragTableStart)+int64(8*i)), binary.LittleEndian, &offset)
if err != nil { return fragsTmp, nil
return fragEntry{}, err })
}
fragsToRead = min(r.Superblock.FragCount-uint32(len(r.fragTable)), 512)
fragsTmp = make([]fragEntry, fragsToRead)
rdr = metadata.NewReader(toreader.NewReader(r.r, int64(offset)), r.d)
err = binary.Read(&rdr, binary.LittleEndian, &fragsTmp)
rdr.Close()
if err != nil {
return fragEntry{}, err
}
r.fragTable = append(r.fragTable, fragsTmp...)
}
return r.fragTable[i], nil
} }
// Get an inode reference at the given index. Lazily populates the reader's export table as necessary. // Get an inode reference at the given index. Lazily populates the reader's export table as necessary.
+8 -4
View File
@@ -77,8 +77,10 @@ func TestReader(t *testing.T) {
path := filepath.Join(tmpDir, "extractTest") path := filepath.Join(tmpDir, "extractTest")
os.RemoveAll(path) os.RemoveAll(path)
os.MkdirAll(path, 0777) os.MkdirAll(path, 0777)
err = extractToDir(rdr, rdr.Root.FileBase, path) err = extractToDir(rdr, &rdr.Root.FileBase, path)
t.Fatal(err) if err != nil {
t.Fatal(err)
}
} }
var singleFile = "PortableApps/CPU-X/CPU-X-v4.2.0-x86_64.AppImage" var singleFile = "PortableApps/CPU-X/CPU-X-v4.2.0-x86_64.AppImage"
@@ -101,8 +103,10 @@ func TestSingleFile(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = extractToDir(rdr, b, path) err = extractToDir(rdr, &b, path)
t.Fatal(err) if err != nil {
t.Fatal(err)
}
} }
func extractToDir(rdr Reader, b FileBase, folder string) error { func extractToDir(rdr Reader, b FileBase, folder string) error {