diff --git a/internal/data/fullreader.go b/internal/data/fullreader.go index 60d4ec4..c868702 100644 --- a/internal/data/fullreader.go +++ b/internal/data/fullreader.go @@ -40,6 +40,7 @@ type outDat struct { func (r FullReader) process(index int, offset int64, out chan outDat) { var err error var dat []byte + var rdr io.ReadCloser size := realSize(r.sizes[index]) if size == 0 { out <- outDat{ @@ -49,12 +50,24 @@ func (r FullReader) process(index int, offset int64, out chan outDat) { } return } - rdr := io.LimitReader(toreader.NewReader(r.r, offset), int64(size)) + // rdr := io.LimitReader(toreader.NewReader(r.r, offset), int64(size)) if size == r.sizes[index] { - rdr, err = r.d.Reader(rdr) - } - if err == nil { - dat, err = io.ReadAll(rdr) + //Special workaround for zstd for increased performancce. + if zstd, ok := r.d.(*decompress.Zstd); ok { + dat = make([]byte, size) + _, err = r.r.ReadAt(dat, offset) + if err == nil { + dat, err = zstd.Decode(dat) + } + } else { + rdr, err = r.d.Reader(io.LimitReader(toreader.NewReader(r.r, offset), int64(size))) + if err == nil { + dat, err = io.ReadAll(rdr) + } + } + } else { + dat = make([]byte, size) + _, err = r.r.ReadAt(dat, offset) } out <- outDat{ i: index, diff --git a/internal/data/reader.go b/internal/data/reader.go index 3914796..d4c51c2 100644 --- a/internal/data/reader.go +++ b/internal/data/reader.go @@ -12,17 +12,20 @@ type Reader struct { cur io.Reader fragRdr io.Reader d decompress.Decompressor + comRdr io.Reader blockSizes []uint32 blockSize uint32 + resetable bool } func NewReader(r io.Reader, d decompress.Decompressor, blockSizes []uint32, blockSize uint32) *Reader { - var out Reader - out.d = d - out.master = r - out.blockSizes = blockSizes - out.blockSize = blockSize - return &out + return &Reader{ + d: d, + master: r, + blockSizes: blockSizes, + blockSize: blockSize, + resetable: true, + } } func (r *Reader) AddFragment(rdr io.Reader) { @@ -50,7 +53,19 @@ func (r *Reader) advance() (err error) { } else { r.cur = io.LimitReader(r.master, int64(size)) if size == r.blockSizes[0] { - r.cur, err = r.d.Reader(r.cur) + if r.d.Resetable() { + if r.comRdr == nil { + r.cur, err = r.d.Reader(r.cur) + if err != nil { + return + } + } else { + err = r.d.Reset(r.comRdr, r.cur) + r.cur = r.comRdr + } + } else { + r.cur, err = r.d.Reader(r.cur) + } } } } diff --git a/internal/decompress/gzip.go b/internal/decompress/gzip.go index c7ced32..879686f 100644 --- a/internal/decompress/gzip.go +++ b/internal/decompress/gzip.go @@ -11,3 +11,9 @@ type GZip struct{} func (g GZip) Reader(src io.Reader) (io.ReadCloser, error) { return zlib.NewReader(src) } + +func (g GZip) Resetable() bool { return true } + +func (g GZip) Reset(old, src io.Reader) error { + return old.(zlib.Resetter).Reset(src, nil) +} diff --git a/internal/decompress/interface.go b/internal/decompress/interface.go index 177e45d..af522f1 100644 --- a/internal/decompress/interface.go +++ b/internal/decompress/interface.go @@ -1,7 +1,19 @@ package decompress -import "io" +import ( + "errors" + "io" +) + +var ErrNotResetable = errors.New("decompressor not resetable") type Decompressor interface { + //Creates a new decompressor reading from src. Reader(src io.Reader) (io.ReadCloser, error) + //Reports whether Reset will work or not. + Resetable() bool + //Reset attempts to re-use an old decompressor with new data. + //Will return ErrNotResetable if not Resetable(). + //Must ALWAYS be provided with a reader created with Reader. + Reset(old, src io.Reader) error } diff --git a/internal/decompress/lz4.go b/internal/decompress/lz4.go index af399ba..c2bab28 100644 --- a/internal/decompress/lz4.go +++ b/internal/decompress/lz4.go @@ -11,3 +11,10 @@ type Lz4 struct{} func (l Lz4) Reader(r io.Reader) (io.ReadCloser, error) { return io.NopCloser(lz4.NewReader(r)), nil } + +func (l Lz4) Resetable() bool { return true } + +func (l Lz4) Reset(old, src io.Reader) error { + old.(*lz4.Reader).Reset(src) + return nil +} diff --git a/internal/decompress/lzma.go b/internal/decompress/lzma.go index 9add4ae..f93b29d 100644 --- a/internal/decompress/lzma.go +++ b/internal/decompress/lzma.go @@ -12,3 +12,7 @@ func (l Lzma) Reader(r io.Reader) (io.ReadCloser, error) { rdr, err := lzma.NewReader(r) return io.NopCloser(rdr), err } + +func (l Lzma) Resetable() bool { return false } + +func (l Lzma) Reset(old, src io.Reader) error { return ErrNotResetable } diff --git a/internal/decompress/lzo.go b/internal/decompress/lzo.go index 76333e7..15d0f2e 100644 --- a/internal/decompress/lzo.go +++ b/internal/decompress/lzo.go @@ -16,3 +16,7 @@ func (l Lzo) Reader(r io.Reader) (io.ReadCloser, error) { } return io.NopCloser(bytes.NewReader(cache)), nil } + +func (l Lzo) Resetable() bool { return false } + +func (l Lzo) Reset(old, src io.Reader) error { return ErrNotResetable } diff --git a/internal/decompress/xz.go b/internal/decompress/xz.go index 9a22256..ad260bd 100644 --- a/internal/decompress/xz.go +++ b/internal/decompress/xz.go @@ -12,3 +12,9 @@ func (x Xz) Reader(r io.Reader) (io.ReadCloser, error) { rdr, err := xz.NewReader(r, 0) return io.NopCloser(rdr), err } + +func (x Xz) Resetable() bool { return true } + +func (x Xz) Reset(old, src io.Reader) error { + return old.(*xz.Reader).Reset(src) +} diff --git a/internal/decompress/zstd.go b/internal/decompress/zstd.go index 4f5df6b..544500d 100644 --- a/internal/decompress/zstd.go +++ b/internal/decompress/zstd.go @@ -1,31 +1,29 @@ package decompress import ( - "bytes" "io" "github.com/klauspost/compress/zstd" ) -type Zstd struct{} +type Zstd struct { + writeToReader *zstd.Decoder +} func (z Zstd) Reader(src io.Reader) (io.ReadCloser, error) { r, err := zstd.NewReader(src) return r.IOReadCloser(), err } -type ZstdDecodeAll struct { - rdr *zstd.Decoder +func (z Zstd) Resetable() bool { return true } + +func (z Zstd) Reset(old, src io.Reader) error { + return old.(*zstd.Decoder).Reset(src) } -func (z *ZstdDecodeAll) Reader(src io.Reader) (io.ReadCloser, error) { - if z.rdr == nil { - z.rdr, _ = zstd.NewReader(nil) +func (z *Zstd) Decode(in []byte) (out []byte, err error) { + if z.writeToReader == nil { + z.writeToReader, _ = zstd.NewReader(nil) } - data, err := io.ReadAll(src) - if err != nil { - return nil, err - } - out, err := z.rdr.DecodeAll(data, nil) - return io.NopCloser(bytes.NewReader(out)), err + return z.writeToReader.DecodeAll(in, nil) } diff --git a/internal/metadata/reader.go b/internal/metadata/reader.go index 4567cbe..0768084 100644 --- a/internal/metadata/reader.go +++ b/internal/metadata/reader.go @@ -11,6 +11,7 @@ type Reader struct { master io.Reader cur io.Reader d decompress.Decompressor + comRdr io.Reader } func NewReader(master io.Reader, d decompress.Decompressor) *Reader { @@ -25,8 +26,10 @@ func realSize(siz uint16) uint16 { } func (r *Reader) advance() (err error) { - if clr, ok := r.cur.(io.Closer); ok { - clr.Close() + if !r.d.Resetable() { + if clr, ok := r.cur.(io.Closer); ok { + clr.Close() + } } var raw uint16 err = binary.Read(r.master, binary.LittleEndian, &raw) @@ -36,7 +39,19 @@ func (r *Reader) advance() (err error) { size := realSize(raw) r.cur = io.LimitReader(r.master, int64(size)) if size == raw { - r.cur, err = r.d.Reader(r.cur) + if r.d.Resetable() { + if r.comRdr == nil { + r.cur, err = r.d.Reader(r.cur) + if err != nil { + return + } + } else { + err = r.d.Reset(r.comRdr, r.cur) + r.cur = r.comRdr + } + } else { + r.cur, err = r.d.Reader(r.cur) + } } return } diff --git a/reader.go b/reader.go index b72383d..a0dc1f0 100644 --- a/reader.go +++ b/reader.go @@ -75,7 +75,7 @@ func NewReader(r io.ReaderAt) (*Reader, error) { case LZ4Compression: squash.d = decompress.Lz4{} case ZSTDCompression: - squash.d = decompress.Zstd{} + squash.d = &decompress.Zstd{} default: return nil, errors.New("uh, I need to do this, OR something if very wrong") } diff --git a/fragment.go b/reader_frag.go similarity index 100% rename from fragment.go rename to reader_frag.go