From 0402b0a2ee7f4466367eb109af52385232f1f92a Mon Sep 17 00:00:00 2001 From: Caleb Gardner Date: Sun, 26 Sep 2021 18:30:08 -0500 Subject: [PATCH] Bringing rawreader from expiremental branch. Now allows creation of a squashfs.Reader from an io.Reader --- internal/rawreader/rawreader.go | 133 ++++++++++++++++++++++++++++ reader.go | 149 +++++++++++++++++--------------- 2 files changed, 213 insertions(+), 69 deletions(-) create mode 100644 internal/rawreader/rawreader.go diff --git a/internal/rawreader/rawreader.go b/internal/rawreader/rawreader.go new file mode 100644 index 0000000..4acdaa9 --- /dev/null +++ b/internal/rawreader/rawreader.go @@ -0,0 +1,133 @@ +package rawreader + +import ( + "errors" + "io" +) + +func ConvertReader(r io.Reader) RawReader { + if rr, ok := r.(RawReader); ok { + return rr + } + if rs, is := r.(io.ReadSeeker); is { + return &fromReadSeeker{ + ReadSeeker: rs, + } + } + return &fromReader{ + rdr: r, + cache: make([]byte, 0), + } +} + +func ConvertReaderAt(r io.ReaderAt) RawReader { + if rr, ok := r.(RawReader); ok { + return rr + } + return &fromReaderAt{ + ReaderAt: r, + } +} + +//TODO: Add way to discard data from fromReader +//RawReader implements the needed interfaces for reading a squashfs archive. +type RawReader interface { + io.ReadSeeker + io.ReaderAt +} + +type fromReader struct { + rdr io.Reader + cache []byte + off int +} + +func (r *fromReader) increaseCache(len int) error { + newCache := make([]byte, len) + _, err := r.rdr.Read(newCache) + if err != nil { + return err + } + r.cache = append(r.cache, newCache...) + return nil +} + +func (r *fromReader) ReadAt(p []byte, off int64) (n int, err error) { + if int(off)+len(p) > len(r.cache) { + r.increaseCache((int(off) + len(p)) - len(r.cache)) + } + for i := int64(0); i < int64(len(p)); i++ { + p[i] = r.cache[off+i] + } + return +} + +func (r *fromReader) Seek(off int64, whence int) (n int64, err error) { + switch whence { + case io.SeekEnd: + return 0, errors.New("cannot SeekEnd RawReader") + case io.SeekCurrent: + r.off += int(off) + case io.SeekStart: + r.off = int(off) + } + if r.off > len(r.cache) { + err = r.increaseCache(len(r.cache) - r.off) + if err != nil { + r.off = len(r.cache) + } + } + return int64(r.off), err +} + +func (r *fromReader) Read(p []byte) (n int, err error) { + if len(p)+r.off > len(r.cache) { + err = r.increaseCache((len(p) + r.off) - len(r.cache)) + if err != nil { + return + } + } + for i := 0; i < len(p); i++ { + p[i] = r.cache[r.off+i] + } + r.off += len(p) + return +} + +type fromReadSeeker struct { + io.ReadSeeker +} + +func (r *fromReadSeeker) ReadAt(p []byte, off int64) (n int, err error) { + tmp, _ := r.Seek(0, io.SeekCurrent) + defer r.Seek(tmp, io.SeekStart) + _, err = r.Seek(off, io.SeekStart) + if err != nil { + return + } + return r.Read(p) +} + +type fromReaderAt struct { + io.ReaderAt + + off int +} + +func (r *fromReaderAt) Read(p []byte) (n int, err error) { + n, err = r.ReadAt(p, int64(r.off)) + r.off += n + return +} + +func (r *fromReaderAt) Seek(off int64, whence int) (n int64, err error) { + switch whence { + case io.SeekEnd: + return 0, errors.New("cannot SeekEnd RawReader") + case io.SeekCurrent: + r.off += int(off) + case io.SeekStart: + r.off = int(off) + } + return int64(r.off), nil +} diff --git a/reader.go b/reader.go index 0ed5f53..a647f73 100644 --- a/reader.go +++ b/reader.go @@ -9,6 +9,7 @@ import ( "github.com/CalebQ42/squashfs/internal/compression" "github.com/CalebQ42/squashfs/internal/inode" + "github.com/CalebQ42/squashfs/internal/rawreader" ) const ( @@ -20,16 +21,12 @@ var ( errNoMagic = errors.New("magic number doesn't match. Either isn't a squashfs or corrupted") //ErrIncompatibleCompression is returned if the compression type in the superblock doesn't work. errIncompatibleCompression = errors.New("compression type unsupported") - //ErrOptions is returned when compression options that I haven't tested is set. When this is returned, the Reader is also returned. - ErrOptions = errors.New("possibly incompatible compressor options") ) -//TODO: implement fs.FS, possibly more FS types for compatibility. Most of this work will mostly be handed off to root anyway so this shouldn't be too difficult. - //Reader processes and reads a squashfs archive. type Reader struct { FS - r *io.SectionReader + r rawreader.RawReader decompressor compression.Decompressor fragOffsets []uint64 idTable []uint32 @@ -40,140 +37,154 @@ type Reader struct { //NewSquashfsReader returns a new squashfs.Reader from an io.ReaderAt func NewSquashfsReader(r io.ReaderAt) (*Reader, error) { var rdr Reader - err := binary.Read(io.NewSectionReader(r, 0, int64(binary.Size(rdr.super))), binary.LittleEndian, &rdr.super) + rdr.r = rawreader.ConvertReaderAt(r) + err := rdr.Init() if err != nil { return nil, err } - rdr.r = io.NewSectionReader(r, 0, int64(rdr.super.BytesUsed)) - if rdr.super.Magic != magic { - return nil, errNoMagic + return &rdr, nil +} + +//NewSquashfsReaderFromReader returns a new squashfs.Reader from an io.Reader. +//If possible, try to use a io.ReaderAt. +//With an io.Reader, much of the data has to be cached to memory. +func NewSquashfsReaderFromReader(r io.Reader) (*Reader, error) { + var rdr Reader + rdr.r = rawreader.ConvertReader(r) + err := rdr.Init() + if err != nil { + return nil, err } - if rdr.super.BlockLog != uint16(math.Log2(float64(rdr.super.BlockSize))) { - return nil, errors.New("BlockSize and BlockLog doesn't match. The archive is probably corrupt") + return &rdr, nil +} + +func (r *Reader) Init() error { + err := binary.Read(r.r, binary.LittleEndian, &r.super) + if err != nil { + return err } - rdr.r.Seek(96, io.SeekStart) - hasUnsupportedOptions := false - rdr.flags = rdr.super.GetFlags() - if rdr.flags.compressorOptions { - switch rdr.super.CompressionType { + if r.super.Magic != magic { + return errNoMagic + } + if r.super.BlockLog != uint16(math.Log2(float64(r.super.BlockSize))) { + return errors.New("BlockSize and BlockLog doesn't match. The archive is probably corrupt") + } + r.r.Seek(96, io.SeekStart) + r.flags = r.super.GetFlags() + if r.flags.compressorOptions { + switch r.super.CompressionType { case GzipCompression: var gzip *compression.Gzip - gzip, err = compression.NewGzipCompressorWithOptions(rdr.r) + gzip, err = compression.NewGzipCompressorWithOptions(r.r) if err != nil { - return nil, err + return err } - if gzip.HasCustomWindow || gzip.HasStrategies { - hasUnsupportedOptions = true - } - rdr.decompressor = gzip + r.decompressor = gzip case XzCompression: var xz *compression.Xz - xz, err = compression.NewXzCompressorWithOptions(rdr.r) + xz, err = compression.NewXzCompressorWithOptions(r.r) if err != nil { - return nil, err + return err } - rdr.decompressor = xz + r.decompressor = xz case LzoCompression: var lz *compression.Lzo - lz, err = compression.NewLzoCompressorWithOptions(rdr.r) + lz, err = compression.NewLzoCompressorWithOptions(r.r) if err != nil { - return nil, err + return err } - rdr.decompressor = lz + r.decompressor = lz case Lz4Compression: var lz4 *compression.Lz4 - lz4, err = compression.NewLz4CompressorWithOptions(rdr.r) + lz4, err = compression.NewLz4CompressorWithOptions(r.r) if err != nil { - return nil, err + return err } - rdr.decompressor = lz4 + r.decompressor = lz4 case ZstdCompression: var zstd *compression.Zstd - zstd, err = compression.NewZstdCompressorWithOptions(rdr.r) + zstd, err = compression.NewZstdCompressorWithOptions(r.r) if err != nil { - return nil, err + return err } - rdr.decompressor = zstd + r.decompressor = zstd default: - return nil, errIncompatibleCompression + return errIncompatibleCompression } } else { - switch rdr.super.CompressionType { + switch r.super.CompressionType { case GzipCompression: - rdr.decompressor = &compression.Gzip{} + r.decompressor = &compression.Gzip{} case LzmaCompression: - rdr.decompressor = &compression.Lzma{} + r.decompressor = &compression.Lzma{} case LzoCompression: - rdr.decompressor = &compression.Lzo{} + r.decompressor = &compression.Lzo{} case XzCompression: - rdr.decompressor = &compression.Xz{} + r.decompressor = &compression.Xz{} case Lz4Compression: - rdr.decompressor = &compression.Lz4{} + r.decompressor = &compression.Lz4{} case ZstdCompression: - rdr.decompressor = &compression.Zstd{} + r.decompressor = &compression.Zstd{} default: //TODO: all compression types. - return nil, errIncompatibleCompression + return errIncompatibleCompression } } - fragBlocks := int(math.Ceil(float64(rdr.super.FragCount) / 512)) + fragBlocks := int(math.Ceil(float64(r.super.FragCount) / 512)) if fragBlocks > 0 { - offset := int64(rdr.super.FragTableStart) + offset := int64(r.super.FragTableStart) for i := 0; i < fragBlocks; i++ { tmp := make([]byte, 8) - _, err = r.ReadAt(tmp, offset) + _, err = r.r.ReadAt(tmp, offset) if err != nil { - return nil, err + return err } - rdr.fragOffsets = append(rdr.fragOffsets, binary.LittleEndian.Uint64(tmp)) + r.fragOffsets = append(r.fragOffsets, binary.LittleEndian.Uint64(tmp)) offset += 8 } } - unread := rdr.super.IDCount - blockOffsets := make([]uint64, int(math.Ceil(float64(rdr.super.IDCount)/2048))) - rdr.r.Seek(int64(rdr.super.IDTableStart), io.SeekStart) + unread := r.super.IDCount + blockOffsets := make([]uint64, int(math.Ceil(float64(r.super.IDCount)/2048))) + r.r.Seek(int64(r.super.IDTableStart), io.SeekStart) for i := range blockOffsets { - err = binary.Read(rdr.r, binary.LittleEndian, &blockOffsets[i]) + err = binary.Read(r.r, binary.LittleEndian, &blockOffsets[i]) if err != nil { - return nil, err + return err } var idRdr *metadataReader - idRdr, err = rdr.newMetadataReader(int64(blockOffsets[i])) + idRdr, err = r.newMetadataReader(int64(blockOffsets[i])) if err != nil { - return nil, err + return err } read := uint16(math.Min(float64(unread), 2048)) for i := uint16(0); i < read; i++ { var tmp uint32 err = binary.Read(idRdr, binary.LittleEndian, &tmp) if err != nil { - return nil, err + return err } - rdr.idTable = append(rdr.idTable, tmp) + r.idTable = append(r.idTable, tmp) } unread -= read } - metaRdr, err := rdr.newMetadataReaderFromInodeRef(rdr.super.RootInodeRef) + metaRdr, err := r.newMetadataReaderFromInodeRef(r.super.RootInodeRef) if err != nil { - return nil, err + return err } - i, err := inode.ProcessInode(metaRdr, rdr.super.BlockSize) + i, err := inode.ProcessInode(metaRdr, r.super.BlockSize) if err != nil { - return nil, err + return err } - entries, err := rdr.readDirFromInode(i) + entries, err := r.readDirFromInode(i) if err != nil { - return nil, err + return err } - rdr.FS = FS{ - r: &rdr, + r.FS = FS{ + r: r, name: "/", entries: entries, } - if hasUnsupportedOptions { - return &rdr, ErrOptions - } - return &rdr, nil + return nil } //ModTime is the last time the file was modified/created.