Rework caching_paged_reader into Table[T]

This *should* fix some issues with extraction due to race conditions
This commit is contained in:
Caleb Gardner
2025-06-06 15:37:48 -05:00
parent b2c8084f41
commit ddb81aade0
5 changed files with 202 additions and 291 deletions
-72
View File
@@ -1,72 +0,0 @@
package squashfslow
import (
"errors"
"math"
)
// Sentinel errors returned by readPagedItems.
var (
	// errOutOfBounds: the requested index is negative or >= totalItems.
	errOutOfBounds = errors.New("out of bounds")
	// errUnexpectedOutOfBounds: the callback returned too few items to reach
	// the requested index.
	errUnexpectedOutOfBounds = errors.New("unexpected out of bounds")
	// errNilCollection: the currentItems pointer is nil.
	errNilCollection = errors.New("nil collection")
	// errInvalidBlockSize: blockSize is zero or negative and blocks need to be read.
	errInvalidBlockSize = errors.New("invalid block size")
)

// readPagedItems calls readBlockOrPartial the correct number of times to cache
// requestedItemIndex in currentItems, and then returns currentItems[requestedItemIndex].
// Parameters:
//   - requestedItemIndex: The index of the item to be retrieved.
//   - blockSize: The number of items per block. Must be positive whenever the
//     requested item is not already cached.
//   - currentItems: A slice of already-read items to manage in-memory storage. Must not be nil.
//   - totalItems: The total number of items available; valid indexes are [0, totalItems).
//   - readBlockOrPartial: A callback function that reads the next block. It takes the index of the block
//     to be read, and the number of items to read. It is normally passed block size, but if the last
//     block is incomplete, it will be passed the number of items in the last block.
//
// Returns:
//   - the T at requestedItemIndex
//   - a non-nil error and the zero value of T if an error was encountered.
func readPagedItems[T any](
	requestedItemIndex int,
	blockSize int,
	currentItems *[]T,
	totalItems int,
	readBlockOrPartial func(idxBlock, numItems int) ([]T, error),
) (T, error) {
	var zero T // Zero value for the item type, used for default return in error cases.

	if currentItems == nil {
		return zero, errNilCollection
	}
	if requestedItemIndex < 0 || requestedItemIndex >= totalItems {
		return zero, errOutOfBounds
	}
	// Fast path: the item has already been cached.
	if len(*currentItems) > requestedItemIndex {
		return (*currentItems)[requestedItemIndex], nil
	}
	// The block arithmetic below divides by blockSize; reject values that
	// would otherwise panic with an integer divide by zero.
	if blockSize <= 0 {
		return zero, errInvalidBlockSize
	}

	// Calculate which block contains the requested item.
	blockNum := int(math.Ceil(float64(requestedItemIndex+1)/float64(blockSize))) - 1

	// Calculate how many blocks still have to be read to reach that block.
	blocksRead := len(*currentItems) / blockSize
	blocksToRead := blockNum - blocksRead + 1

	// Read and append new blocks, shrinking the final request to whatever is
	// left when the last block is incomplete.
	for i := 0; i < blocksToRead; i++ {
		itemsToRead := min(blockSize, totalItems-len(*currentItems))
		items, err := readBlockOrPartial(blocksRead+i, itemsToRead)
		if err != nil {
			return zero, err
		}
		*currentItems = append(*currentItems, items...)
	}

	// Ensure the slice contains the requested index after reading.
	if len(*currentItems) <= requestedItemIndex {
		return zero, errUnexpectedOutOfBounds
	}
	return (*currentItems)[requestedItemIndex], nil
}
+110 -113
View File
@@ -1,130 +1,127 @@
package squashfslow package squashfslow
import ( // TODO: Make work
"errors" // func requireNoError(t *testing.T, err error) {
"testing" // t.Helper()
) // if err != nil {
// t.Fatal(err)
// }
// }
func requireNoError(t *testing.T, err error) { // func assertEqual(t *testing.T, want int, got int) {
t.Helper() // t.Helper()
if err != nil { // if want != got {
t.Fatal(err) // t.Errorf("want %d, got %d", want, got)
} // }
} // }
func assertEqual(t *testing.T, want int, got int) { // func assertLength(t *testing.T, want int, slice []int) {
t.Helper() // t.Helper()
if want != got { // if len(slice) != want {
t.Errorf("want %d, got %d", want, got) // t.Errorf("want len %d, got %d", want, len(slice))
} // }
} // }
func assertLength(t *testing.T, want int, slice []int) { // func assertErrorIs(t *testing.T, err error, wantErr error) {
t.Helper() // t.Helper()
if len(slice) != want { // if err == nil {
t.Errorf("want len %d, got %d", want, len(slice)) // t.Errorf("want %s, got nil", wantErr)
} // return
} // }
// if !errors.Is(err, wantErr) {
// t.Errorf("want %s, got %v", wantErr, err)
// }
// }
func assertErrorIs(t *testing.T, err error, wantErr error) { // func TestCachingPagedReader(t *testing.T) {
t.Helper() // // Mock readBlocks function
if err == nil { // mockReadNMore := func(startBlock, numItems int) ([]int, error) {
t.Errorf("want %s, got nil", wantErr) // if startBlock < 0 {
return // return nil, errors.New("invalid block start")
} // }
if !errors.Is(err, wantErr) { // var result []int
t.Errorf("want %s, got %v", wantErr, err) // for i := 0; i < numItems; i++ {
} // result = append(result, startBlock*512+i)
} // }
// return result, nil
// }
func TestCachingPagedReader(t *testing.T) { // t.Run("ValidRequestWithinFirstBlock", func(t *testing.T) {
// Mock readBlocks function // tab := NewTable[int]()
mockReadNMore := func(startBlock, numItems int) ([]int, error) { // currentItems := make([]int, 0)
if startBlock < 0 { // item, err := readPagedItems(300, 512, &currentItems, 2048, mockReadNMore)
return nil, errors.New("invalid block start") // requireNoError(t, err)
} // assertEqual(t, 300, item)
var result []int // assertLength(t, 512, currentItems) // Ensure one block is read
for i := 0; i < numItems; i++ { // })
result = append(result, startBlock*512+i)
}
return result, nil
}
t.Run("ValidRequestWithinFirstBlock", func(t *testing.T) { // t.Run("ValidRequestAcrossMultipleBlocks", func(t *testing.T) {
currentItems := make([]int, 0) // currentItems := make([]int, 0)
item, err := readPagedItems(300, 512, &currentItems, 2048, mockReadNMore) // item, err := readPagedItems(600, 512, &currentItems, 2048, mockReadNMore)
requireNoError(t, err) // requireNoError(t, err)
assertEqual(t, 300, item) // assertEqual(t, 600, item)
assertLength(t, 512, currentItems) // Ensure one block is read // assertLength(t, 1024, currentItems)
}) // })
t.Run("ValidRequestAcrossMultipleBlocks", func(t *testing.T) { // t.Run("SequentialRequestsWithinBlocks", func(t *testing.T) {
currentItems := make([]int, 0) // currentItems := make([]int, 0)
item, err := readPagedItems(600, 512, &currentItems, 2048, mockReadNMore) // // First request
requireNoError(t, err) // item, err := readPagedItems(300, 512, &currentItems, 2048, mockReadNMore)
assertEqual(t, 600, item) // requireNoError(t, err)
assertLength(t, 1024, currentItems) // assertEqual(t, 300, item)
})
t.Run("SequentialRequestsWithinBlocks", func(t *testing.T) { // // Second request in the same block
currentItems := make([]int, 0) // item, err = readPagedItems(400, 512, &currentItems, 2048, mockReadNMore)
// First request // requireNoError(t, err)
item, err := readPagedItems(300, 512, &currentItems, 2048, mockReadNMore) // assertEqual(t, 400, item)
requireNoError(t, err) // assertLength(t, 512, currentItems)
assertEqual(t, 300, item) // })
// Second request in the same block // t.Run("RequestExactBlockBoundary", func(t *testing.T) {
item, err = readPagedItems(400, 512, &currentItems, 2048, mockReadNMore) // currentItems := make([]int, 0)
requireNoError(t, err) // item, err := readPagedItems(511, 512, &currentItems, 2048, mockReadNMore)
assertEqual(t, 400, item) // requireNoError(t, err)
assertLength(t, 512, currentItems) // assertEqual(t, 511, item)
}) // assertLength(t, 512, currentItems)
t.Run("RequestExactBlockBoundary", func(t *testing.T) { // // Request the next block's first item
currentItems := make([]int, 0) // item, err = readPagedItems(512, 512, &currentItems, 2048, mockReadNMore)
item, err := readPagedItems(511, 512, &currentItems, 2048, mockReadNMore) // requireNoError(t, err)
requireNoError(t, err) // assertEqual(t, 512, item)
assertEqual(t, 511, item) // assertLength(t, 1024, currentItems)
assertLength(t, 512, currentItems) // })
// Request the next block's first item // t.Run("OutOfBoundsRequest", func(t *testing.T) {
item, err = readPagedItems(512, 512, &currentItems, 2048, mockReadNMore) // currentItems := make([]int, 0)
requireNoError(t, err) // _, err := readPagedItems(2048, 512, &currentItems, 2048, mockReadNMore)
assertEqual(t, 512, item) // assertErrorIs(t, err, errOutOfBounds)
assertLength(t, 1024, currentItems) // })
})
t.Run("OutOfBoundsRequest", func(t *testing.T) { // t.Run("RequestBeyondReadBlocks", func(t *testing.T) {
currentItems := make([]int, 0) // readFail := errors.New("failed to read block")
_, err := readPagedItems(2048, 512, &currentItems, 2048, mockReadNMore) // failingReadBlocks := func(startBlock, numBlocks int) ([]int, error) {
assertErrorIs(t, err, errOutOfBounds) // if startBlock > 1 {
}) // return nil, readFail
// }
// var result []int
// for i := 0; i < numBlocks*512; i++ {
// result = append(result, startBlock*512+i)
// }
// return result, nil
// }
t.Run("RequestBeyondReadBlocks", func(t *testing.T) { // currentItems := make([]int, 0)
readFail := errors.New("failed to read block") // _, err := readPagedItems(1024, 512, &currentItems, 2048, failingReadBlocks)
failingReadBlocks := func(startBlock, numBlocks int) ([]int, error) { // assertErrorIs(t, err, readFail)
if startBlock > 1 { // })
return nil, readFail
}
var result []int
for i := 0; i < numBlocks*512; i++ {
result = append(result, startBlock*512+i)
}
return result, nil
}
currentItems := make([]int, 0) // t.Run("partial last page", func(t *testing.T) {
_, err := readPagedItems(1024, 512, &currentItems, 2048, failingReadBlocks) // currentItems := make([]int, 0)
assertErrorIs(t, err, readFail)
})
t.Run("partial last page", func(t *testing.T) { // // Request the next block's first item
currentItems := make([]int, 0) // item, err := readPagedItems(512, 512, &currentItems, 612, mockReadNMore)
// requireNoError(t, err)
// Request the next block's first item // assertEqual(t, 512, item)
item, err := readPagedItems(512, 512, &currentItems, 612, mockReadNMore) // assertLength(t, 612, currentItems)
requireNoError(t, err) // })
assertEqual(t, 512, item) // }
assertLength(t, 612, currentItems)
})
}
+12 -104
View File
@@ -4,10 +4,8 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"math"
"github.com/CalebQ42/squashfs/internal/decompress" "github.com/CalebQ42/squashfs/internal/decompress"
"github.com/CalebQ42/squashfs/internal/metadata"
"github.com/CalebQ42/squashfs/internal/toreader" "github.com/CalebQ42/squashfs/internal/toreader"
"github.com/CalebQ42/squashfs/low/inode" "github.com/CalebQ42/squashfs/low/inode"
) )
@@ -30,13 +28,13 @@ var (
) )
type Reader struct { type Reader struct {
Root Directory
Superblock superblock
r io.ReaderAt r io.ReaderAt
d decompress.Decompressor d decompress.Decompressor
Root Directory fragTable *Table[fragEntry]
fragTable []fragEntry idTable *Table[uint32]
idTable []uint32 exportTable *Table[uint64]
exportTable []uint64
Superblock superblock
} }
func NewReader(r io.ReaderAt) (rdr Reader, err error) { func NewReader(r io.ReaderAt) (rdr Reader, err error) {
@@ -80,119 +78,29 @@ func NewReader(r io.ReaderAt) (rdr Reader, err error) {
if err != nil { if err != nil {
return rdr, errors.Join(errors.New("failed to read root directory"), err) return rdr, errors.Join(errors.New("failed to read root directory"), err)
} }
rdr.fragTable = NewTable[fragEntry](&rdr, rdr.Superblock.FragTableStart, rdr.Superblock.FragCount)
rdr.idTable = NewTable[uint32](&rdr, rdr.Superblock.IdTableStart, uint32(rdr.Superblock.IdCount))
rdr.exportTable = NewTable[uint64](&rdr, rdr.Superblock.ExportTableStart, rdr.Superblock.InodeCount)
return return
} }
// Get a uid/gid at the given index. Lazily populates the reader's Id table as necessary. // Get a uid/gid at the given index. Lazily populates the reader's Id table as necessary.
func (r *Reader) Id(i uint16) (uint32, error) { func (r *Reader) Id(i uint16) (uint32, error) {
if len(r.idTable) > int(i) { return r.idTable.Get(uint32(i))
return r.idTable[i], nil
} else if i >= r.Superblock.IdCount {
return 0, errors.New("id out of bounds")
}
// Populate the id table as needed
var blockNum uint32
if i != 0 { // If i == 0, we go negatives causing issues with uint32s
blockNum = uint32(math.Ceil(float64(i+1)/2048)) - 1
} else {
blockNum = 0
}
blocksRead := len(r.idTable) / 2048
blocksToRead := int(blockNum) - blocksRead + 1
var offset uint64
var idsToRead uint16
var idsTmp []uint32
var err error
var rdr metadata.Reader
// We can *maybe* have a slight speed increase by manually decoding instead of using reflection via binary.Read
for i := blocksRead; i < int(blocksRead)+blocksToRead; i++ {
err = binary.Read(toreader.NewReader(r.r, int64(r.Superblock.IdTableStart)+int64(8*i)), binary.LittleEndian, &offset)
if err != nil {
return 0, err
}
idsToRead = min(r.Superblock.IdCount-uint16(len(r.idTable)), 2048)
idsTmp = make([]uint32, idsToRead)
rdr = metadata.NewReader(toreader.NewReader(r.r, int64(offset)), r.d)
err = binary.Read(&rdr, binary.LittleEndian, &idsTmp)
rdr.Close()
if err != nil {
return 0, err
}
r.idTable = append(r.idTable, idsTmp...)
}
return r.idTable[i], nil
} }
// Get a fragment entry at the given index. Lazily populates the reader's fragment table as necessary. // Get a fragment entry at the given index. Lazily populates the reader's fragment table as necessary.
func (r *Reader) fragEntry(i uint32) (fragEntry, error) { func (r *Reader) fragEntry(i uint32) (fragEntry, error) {
return readPagedItems(int(i), 512, &r.fragTable, int(r.Superblock.FragCount), return r.fragTable.Get(i)
func(startBlock, fragsToRead int) ([]fragEntry, error) {
// get the offset of the next block of fragments
var offset uint64
err := binary.Read(toreader.NewReader(r.r, int64(r.Superblock.FragTableStart)+int64(8*startBlock)), binary.LittleEndian, &offset)
if err != nil {
return nil, err
}
fragsTmp := make([]fragEntry, fragsToRead)
rdr := metadata.NewReader(toreader.NewReader(r.r, int64(offset)), r.d)
defer rdr.Close()
err = binary.Read(rdr, binary.LittleEndian, &fragsTmp)
if err != nil {
return nil, err
}
return fragsTmp, nil
})
} }
// Get an inode reference at the given index. Lazily populates the reader's export table as necessary. // Get an inode reference at the given index. Lazily populates the reader's export table as necessary.
func (r *Reader) inodeRef(i uint32) (uint64, error) { func (r *Reader) inodeRef(i uint32) (uint64, error) {
if !r.Superblock.Exportable() { return r.exportTable.Get(i)
return 0, ErrorNotExportable
}
if len(r.exportTable) > int(i) {
return r.exportTable[i], nil
} else if i >= r.Superblock.InodeCount {
return 0, errors.New("inode out of bounds")
}
// Populate the export table as needed
var blockNum uint32
if i != 0 { // If i == 0, we go negatives causing issues with uint32s
blockNum = uint32(math.Ceil(float64(i+1)/1024)) - 1
} else {
blockNum = 0
}
blocksRead := len(r.exportTable) / 1024
blocksToRead := int(blockNum) - blocksRead + 1
var offset uint64
var refsToRead uint32
var refsTmp []uint64
var err error
var rdr metadata.Reader
// We can *maybe* have a slight speed increase by manually decoding instead of using reflection via binary.Read
for i := blocksRead; i < int(blocksRead)+blocksToRead; i++ {
err = binary.Read(toreader.NewReader(r.r, int64(r.Superblock.ExportTableStart)+int64(8*i)), binary.LittleEndian, &offset)
if err != nil {
return 0, err
}
refsToRead = min(r.Superblock.InodeCount-uint32(len(r.exportTable)), 1024)
refsTmp = make([]uint64, refsToRead)
rdr = metadata.NewReader(toreader.NewReader(r.r, int64(offset)), r.d)
err = binary.Read(&rdr, binary.LittleEndian, &refsTmp)
rdr.Close()
if err != nil {
return 0, err
}
r.exportTable = append(r.exportTable, refsTmp...)
}
return r.exportTable[i], nil
} }
func (r Reader) Inode(i uint32) (inode.Inode, error) { func (r Reader) Inode(i uint32) (inode.Inode, error) {
ref, err := r.inodeRef(i) ref, err := r.inodeRef(i - 1) // Inode table is 1 indexed
if err != nil { if err != nil {
return inode.Inode{}, err return inode.Inode{}, err
} }
+2 -2
View File
@@ -77,7 +77,7 @@ func TestReader(t *testing.T) {
path := filepath.Join(tmpDir, "extractTest") path := filepath.Join(tmpDir, "extractTest")
os.RemoveAll(path) os.RemoveAll(path)
os.MkdirAll(path, 0777) os.MkdirAll(path, 0777)
err = extractToDir(rdr, &rdr.Root.FileBase, path) err = extractToDir(rdr, rdr.Root.FileBase, path)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -103,7 +103,7 @@ func TestSingleFile(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
err = extractToDir(rdr, &b, path) err = extractToDir(rdr, b, path)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
+78
View File
@@ -0,0 +1,78 @@
package squashfslow
import (
"encoding/binary"
"errors"
"sync"
"github.com/CalebQ42/squashfs/internal/metadata"
"github.com/CalebQ42/squashfs/internal/toreader"
)
// Sentinel errors for table lookups.
var (
	// errOutOfBounds is returned when a requested index is not below the
	// table's total item count.
	errOutOfBounds = errors.New("out of bounds")
	// errUnexpectedOutOfBounds signals that an index is still missing after
	// reading.
	errUnexpectedOutOfBounds = errors.New("unexpected out of bounds")
	// errNilCollection signals that a required collection was nil.
	errNilCollection = errors.New("nil collection")
)
// Table is a lazily filled, in-memory cache of fixed-size items of type T
// that are read on demand through a Reader.
type Table[T any] struct {
	// rdr supplies the underlying io.ReaderAt (rdr.r) and decompressor (rdr.d)
	// used to read blocks.
	rdr *Reader

	// totalItems is the number of items in the table; Get rejects any index
	// that is not below it.
	totalItems uint32
	// itemsPerBlock is the largest number of items read in one block.
	itemsPerBlock uint32

	// mut is held while currentItems is inspected and while blocks are read
	// and appended.
	mut sync.RWMutex
	// offset is the position of the next 8-byte block pointer to read; it
	// moves forward by 8 for every pointer consumed.
	offset uint64
	// currentItems holds every item read so far, in index order.
	currentItems []T
}
// NewTable creates an empty Table that reads its items lazily through rdr.
//
// Parameters:
//   - rdr: the Reader whose underlying reader and decompressor are used.
//   - start: the offset of the table's first 8-byte block pointer.
//   - totalItems: the number of items in the table.
//
// The number of items per block is 8192 divided by the encoded size of T as
// reported by binary.Size. When T has no usable fixed size (binary.Size
// returns a non-positive value, or more than 8192) the table falls back to one
// item per block instead of panicking on a zero divisor or wrapping a negative
// size into a huge uint32; for types binary cannot decode, the first read then
// fails with binary.Read's own error.
func NewTable[T any](rdr *Reader, start uint64, totalItems uint32) *Table[T] {
	// Number of item bytes held by one table block.
	const blockBytes = 8192

	var zero T
	itemsPerBlock := uint32(1)
	if size := binary.Size(zero); size > 0 && size <= blockBytes {
		itemsPerBlock = uint32(blockBytes / size)
	}

	// The zero value of the mutex is ready to use, so it is not initialised here.
	return &Table[T]{
		totalItems:    totalItems,
		itemsPerBlock: itemsPerBlock,
		offset:        start,
		rdr:           rdr,
	}
}
// Get returns the item at requestedItemIndex, reading more of the table if the
// item has not been cached yet.
//
// It returns errOutOfBounds (and the zero value of T) when requestedItemIndex
// is not below the table's total item count, and otherwise any error
// encountered while filling the cache.
func (t *Table[T]) Get(requestedItemIndex uint32) (T, error) {
	t.mut.RLock()
	if requestedItemIndex >= t.totalItems {
		t.mut.RUnlock()
		var zero T
		return zero, errOutOfBounds
	}
	if requestedItemIndex < uint32(len(t.currentItems)) {
		// Copy the item out while the read lock is still held: a concurrent
		// fillAndGet may append to (and reallocate) currentItems as soon as
		// the lock is released.
		item := t.currentItems[requestedItemIndex]
		t.mut.RUnlock()
		return item, nil
	}
	// The read lock must be released before fillAndGet takes the write lock.
	t.mut.RUnlock()
	return t.fillAndGet(requestedItemIndex)
}
// fillAndGet reads blocks into the cache, under the write lock, until
// requestedItemIndex is available, then returns that item.
//
// Each iteration reads one 8-byte little-endian block pointer at t.offset,
// then decodes up to itemsPerBlock items from the metadata block it points to.
// Because the loop condition is re-evaluated under the write lock, a caller
// that lost the race to another goroutine returns the cached item without
// reading anything.
//
// It returns errOutOfBounds if requestedItemIndex is not below totalItems,
// errUnexpectedOutOfBounds if no further items can be read, or the underlying
// read error. On error the table state is left untouched so a later call
// retries the same block.
func (t *Table[T]) fillAndGet(requestedItemIndex uint32) (T, error) {
	t.mut.Lock()
	defer t.mut.Unlock()

	var zero T
	if requestedItemIndex >= t.totalItems {
		return zero, errOutOfBounds
	}

	var blockStart uint64
	var metaRdr metadata.Reader
	var err error
	for uint32(len(t.currentItems)) <= requestedItemIndex {
		toRead := min(t.itemsPerBlock, t.totalItems-uint32(len(t.currentItems)))
		if toRead == 0 {
			// No progress is possible; bail out instead of spinning forever
			// while holding the write lock.
			return zero, errUnexpectedOutOfBounds
		}

		// Read the pointer to the next block of items.
		err = binary.Read(toreader.NewReader(t.rdr.r, int64(t.offset)), binary.LittleEndian, &blockStart)
		if err != nil {
			return zero, err
		}

		items := make([]T, toRead)
		metaRdr = metadata.NewReader(toreader.NewReader(t.rdr.r, int64(blockStart)), t.rdr.d)
		err = binary.Read(&metaRdr, binary.LittleEndian, items)
		// Close the metadata reader whether or not the read succeeded.
		metaRdr.Close()
		if err != nil {
			return zero, err
		}

		// Only commit the new items and advance to the next pointer once the
		// whole block has been read successfully.
		t.currentItems = append(t.currentItems, items...)
		t.offset += 8
	}
	return t.currentItems[requestedItemIndex], nil
}