From d890932d5c972ddf48cc826347534f8ad5194f53 Mon Sep 17 00:00:00 2001 From: Caleb Gardner Date: Thu, 27 Feb 2025 07:19:04 -0600 Subject: [PATCH] Use WriterAt if it's available for FullReader --- low/data/fullreader.go | 99 +++++++++++++++++++++++++++++++++++++----- low/reader.go | 5 +-- squashfs_test.go | 30 +++++++------ 3 files changed, 105 insertions(+), 29 deletions(-) diff --git a/low/data/fullreader.go b/low/data/fullreader.go index c9c694f..d075844 100644 --- a/low/data/fullreader.go +++ b/low/data/fullreader.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/CalebQ42/squashfs/internal/decompress" + "github.com/CalebQ42/squashfs/internal/routinemanager" "github.com/CalebQ42/squashfs/internal/toreader" ) @@ -18,7 +19,6 @@ type FullReader struct { r io.ReaderAt d decompress.Decompressor frag FragReaderConstructor - retPool *sync.Pool sizes []uint32 initialOffset int64 finalBlockSize uint64 @@ -35,11 +35,6 @@ func NewFullReader(r io.ReaderAt, initialOffset int64, d decompress.Decompressor goroutineLimit: uint16(runtime.NumCPU()), finalBlockSize: finalBlockSize, blockSize: blockSize, - retPool: &sync.Pool{ - New: func() any { - return &retValue{} - }, - }, } } @@ -57,8 +52,8 @@ type retValue struct { index uint64 } -func (r *FullReader) process(index uint64, fileOffset uint64, retChan chan *retValue) { - ret := r.retPool.Get().(*retValue) +func (r FullReader) process(index uint64, fileOffset uint64, pool *sync.Pool, retChan chan *retValue) { + ret := pool.Get().(*retValue) ret.index = index realSize := r.sizes[index] &^ (1 << 24) if realSize == 0 { @@ -79,7 +74,10 @@ func (r *FullReader) process(index uint64, fileOffset uint64, retChan chan *retV retChan <- ret } -func (r *FullReader) WriteTo(w io.Writer) (int64, error) { +func (r FullReader) WriteTo(w io.Writer) (int64, error) { + if wa, is := w.(io.WriterAt); is { + return r.writeToWriteAt(wa) + } var curIndex uint64 var curOffset uint64 var toProcess uint16 @@ -87,11 +85,16 @@ func (r *FullReader) WriteTo(w io.Writer) (int64, error) { cache := make(map[uint64]*retValue) var errCache []error retChan := make(chan *retValue, r.goroutineLimit) + pool := &sync.Pool{ + New: func() any { + return &retValue{} + }, + } for i := uint64(0); i < uint64(math.Ceil(float64(len(r.sizes))/float64(r.goroutineLimit))); i++ { toProcess = min(uint16(len(r.sizes))-(uint16(i)*r.goroutineLimit), r.goroutineLimit) // Start all the goroutines for j := uint16(0); j < toProcess; j++ { - go r.process((i*uint64(r.goroutineLimit))+uint64(j), curOffset, retChan) + go r.process((i*uint64(r.goroutineLimit))+uint64(j), curOffset, pool, retChan) curOffset += uint64(r.sizes[(i*uint64(r.goroutineLimit))+uint64(j)]) &^ (1 << 24) } // Then consume the results on retChan @@ -125,7 +128,7 @@ func (r *FullReader) WriteTo(w io.Writer) (int64, error) { } continue } - r.retPool.Put(res) + pool.Put(res) curIndex++ // Now we recursively try to clear the cache for len(cache) > 0 { @@ -143,7 +146,7 @@ func (r *FullReader) WriteTo(w io.Writer) (int64, error) { break } delete(cache, curIndex) - r.retPool.Put(res) + pool.Put(res) curIndex++ } } @@ -169,3 +172,75 @@ func (r *FullReader) WriteTo(w io.Writer) (int64, error) { } return wrote, nil } + +func (r FullReader) writeToWriteAt(w io.WriterAt) (out int64, outErr error) { + wait := sync.WaitGroup{} + wait.Add(len(r.sizes)) + mgr := routinemanager.NewManager(r.goroutineLimit) + curOffset := r.initialOffset + for i := uint64(0); i < uint64(len(r.sizes)); i++ { + go func(index uint64, fileOffset int64) { + lckNum := mgr.Lock() + defer mgr.Unlock(lckNum) + defer wait.Done() + realSize := r.sizes[index] &^ (1 << 24) + if realSize == 0 { + if index == uint64(len(r.sizes))-1 && r.frag == nil { + _, err := w.WriteAt([]byte{0}, int64((uint64(r.blockSize)*index)+r.finalBlockSize)-1) + if err != nil { + outErr = errors.Join(outErr, err) + return + } + out = max(out, int64((uint64(r.blockSize)*index)+r.finalBlockSize)) + } + return + } + data := make([]byte, realSize) + err := binary.Read(toreader.NewReader(r.r, int64(r.initialOffset)+int64(fileOffset)), binary.LittleEndian, &data) + if err != nil { + outErr = errors.Join(outErr, err) + return + } + if r.sizes[index] == realSize { + data, err = r.d.Decompress(data) + } + if err != nil { + outErr = errors.Join(outErr, err) + return + } + _, err = w.WriteAt(data, int64(uint64(r.blockSize)*index)) + if err != nil { + outErr = errors.Join(outErr, err) + return + } + out = max(out, int64(uint64(r.blockSize)*(index+1))) + }(i, curOffset) + curOffset += int64(r.sizes[i]) &^ (1 << 24) + } + if r.frag != nil { + wait.Add(1) + go func() { + lckNum := mgr.Lock() + defer mgr.Unlock(lckNum) + defer wait.Done() + rdr, err := r.frag() + if err != nil { + outErr = errors.Join(outErr, err) + return + } + dat, err := io.ReadAll(rdr) + if err != nil { + outErr = errors.Join(outErr, err) + return + } + _, err = w.WriteAt(dat, int64(int(r.blockSize)*len(r.sizes))) + if err != nil { + outErr = errors.Join(outErr, err) + return + } + out = int64(int(r.blockSize)*len(r.sizes)) + int64(r.finalBlockSize) + }() + } + wait.Wait() + return +} diff --git a/low/reader.go b/low/reader.go index 614a628..1535dde 100644 --- a/low/reader.go +++ b/low/reader.go @@ -188,10 +188,7 @@ func (r *Reader) inodeRef(i uint32) (uint64, error) { if err != nil { return 0, err } - refsToRead = r.Superblock.InodeCount - uint32(len(r.exportTable)) - if refsToRead > 1024 { - refsToRead = 1024 - } + refsToRead = min(r.Superblock.InodeCount-uint32(len(r.exportTable)), 1024) refsTmp = make([]uint64, refsToRead) rdr = metadata.NewReader(toreader.NewReader(r.r, int64(offset)), r.d) err = binary.Read(rdr, binary.LittleEndian, &refsTmp) diff --git a/squashfs_test.go b/squashfs_test.go index 1479774..ed74d3f 100644 --- a/squashfs_test.go +++ b/squashfs_test.go @@ -1,4 +1,4 @@ -package squashfs_test +package squashfs //Actually proper tests go here. @@ -13,13 +13,11 @@ import ( "strconv" "testing" "time" - - "github.com/CalebQ42/squashfs" ) const ( squashfsURL = "https://darkstorm.tech/files/LinuxPATest.sfs" - squashfsName = "airootfs.sfs" + squashfsName = "LinuxPATest.sfs" ) func preTest(dir string) (fil *os.File, err error) { @@ -61,7 +59,7 @@ func TestMisc(t *testing.T) { if err != nil { t.Fatal(err) } - rdr, err := squashfs.NewReader(fil) + rdr, err := NewReader(fil) if err != nil { t.Fatal(err) } @@ -81,10 +79,10 @@ func BenchmarkRace(b *testing.B) { os.RemoveAll(libPath) os.RemoveAll(unsquashPath) var libTime, unsquashTime time.Duration - op := squashfs.FastOptions() + op := FastOptions() op.IgnorePerm = true start := time.Now() - rdr, err := squashfs.NewReader(fil) + rdr, err := NewReader(fil) if err != nil { b.Fatal(err) } @@ -104,13 +102,13 @@ func BenchmarkRace(b *testing.B) { unsquashTime = time.Since(start) // b.Log("Library took:", libTime.Round(time.Millisecond)) // b.Log("unsquashfs took:", unsquashTime.Round(time.Millisecond)) - b.Fatal("unsquashfs is", strconv.FormatFloat(float64(libTime.Milliseconds())/float64(unsquashTime.Milliseconds()), 'f', 2, 64), "times faster") + b.Log("unsquashfs is", strconv.FormatFloat(float64(libTime.Milliseconds())/float64(unsquashTime.Milliseconds()), 'f', 2, 64), "times faster") } func TestExtractQuick(t *testing.T) { //First, setup everything and extract the archive using the library and unsquashfs - // tmpDir := b.TempDir() + // tmpDir := bTempDir() tmpDir := "testing" fil, err := preTest(tmpDir) if err != nil { @@ -120,13 +118,13 @@ func TestExtractQuick(t *testing.T) { unsquashPath := filepath.Join(tmpDir, "ExtractSquashfs") os.RemoveAll(libPath) os.RemoveAll(unsquashPath) - rdr, err := squashfs.NewReader(fil) + rdr, err := NewReader(fil) if err != nil { t.Fatal(err) } os.RemoveAll(filepath.Join(tmpDir, "testLog.txt")) logFil, _ := os.Create(filepath.Join(tmpDir, "testLog.txt")) - op := squashfs.DefaultOptions() + op := DefaultOptions() op.Verbose = true op.IgnorePerm = true op.LogOutput = logFil @@ -178,7 +176,7 @@ func TestSingleFile(t *testing.T) { t.Fatal(err) } os.Remove(filepath.Join(tmpDir, filePath)) - rdr, err := squashfs.NewReader(fil) + rdr, err := NewReader(fil) if err != nil { t.Fatal(err) } @@ -186,9 +184,15 @@ func TestSingleFile(t *testing.T) { if err != nil { t.Fatal(err) } - err = f.(*squashfs.File).ExtractWithOptions("testing", &squashfs.ExtractionOptions{Verbose: true}) + err = f.(*File).ExtractWithOptions("testing", &ExtractionOptions{Verbose: true}) if err != nil { t.Fatal(err) } t.Fatal("HI") } + +func TestStuff(t *testing.T) { + fil, _ := os.Create("testing/stuff.txt") + _, err := fil.WriteAt([]byte("Yo"), 1024) + t.Fatal(err) +}