diff --git a/internal/decompress/zlib.go b/internal/decompress/zlib.go index 3da3415..f658c26 100644 --- a/internal/decompress/zlib.go +++ b/internal/decompress/zlib.go @@ -2,17 +2,31 @@ package decompress import ( "bytes" - "compress/zlib" "io" + "sync" + + "github.com/klauspost/compress/zlib" ) -type Zlib struct{} +type Zlib struct { + pool sync.Pool +} -func (z Zlib) Decompress(data []byte) ([]byte, error) { - rdr, err := zlib.NewReader(bytes.NewReader(data)) +func NewZlib() *Zlib { + return &Zlib{} +} + +func (z *Zlib) Decompress(data []byte) ([]byte, error) { + rdr := z.pool.Get() + defer z.pool.Put(rdr) + var err error + if rdr == nil { + rdr, err = zlib.NewReader(bytes.NewReader(data)) + } else { + err = rdr.(zlib.Resetter).Reset(bytes.NewReader(data), nil) + } if err != nil { return nil, err } - defer rdr.Close() - return io.ReadAll(rdr) + return io.ReadAll(rdr.(io.ReadCloser)) } diff --git a/internal/decompress/zstd.go b/internal/decompress/zstd.go index d65afac..acf977f 100644 --- a/internal/decompress/zstd.go +++ b/internal/decompress/zstd.go @@ -1,16 +1,31 @@ package decompress import ( + "sync" + "github.com/klauspost/compress/zstd" ) -type Zstd struct{} +type Zstd struct { + pool sync.Pool +} -func (z Zstd) Decompress(data []byte) ([]byte, error) { - rdr, err := zstd.NewReader(nil, zstd.WithDecoderLowmem(true), zstd.WithDecoderConcurrency(1)) - if err != nil { - return nil, err +func NewZstd() *Zstd { + return &Zstd{ + pool: sync.Pool{ + New: func() any { + rdr, _ := zstd.NewReader(nil, zstd.WithDecoderLowmem(true), zstd.WithDecoderConcurrency(1)) + return rdr + }, + }, } - defer rdr.Close() +} + +func (z *Zstd) Decompress(data []byte) ([]byte, error) { + rdr := z.pool.Get().(*zstd.Decoder) + defer func() { + rdr.Reset(nil) + z.pool.Put(rdr) + }() return rdr.DecodeAll(data, nil) } diff --git a/low/reader.go b/low/reader.go index 3cf24b2..4f8efb9 100644 --- a/low/reader.go +++ b/low/reader.go @@ -56,7 +56,7 @@ func NewReader(r io.ReaderAt) (rdr Reader, err error) { } switch rdr.Superblock.CompType { case ZlibCompression: - rdr.d = decompress.Zlib{} + rdr.d = decompress.NewZlib() case LZMACompression: rdr.d, err = decompress.NewLzma() if err != nil { @@ -72,7 +72,7 @@ func NewReader(r io.ReaderAt) (rdr Reader, err error) { case LZ4Compression: rdr.d = decompress.NewLz4() case ZSTDCompression: - rdr.d = decompress.Zstd{} + rdr.d = decompress.NewZstd() default: return rdr, errors.New("invalid compression type. possible corrupted archive") }