You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
184 lines
3.9 KiB
184 lines
3.9 KiB
package proto |
|
|
|
import ( |
|
"bytes" |
|
"fmt" |
|
"io" |
|
"sync" |
|
"sync/atomic" |
|
|
|
"github.com/go-faster/errors" |
|
"github.com/klauspost/compress/gzip" |
|
"go.uber.org/multierr" |
|
|
|
"github.com/gotd/td/bin" |
|
) |
|
|
|
type gzipPool struct { |
|
writers sync.Pool |
|
readers sync.Pool |
|
} |
|
|
|
func newGzipPool() *gzipPool { |
|
return &gzipPool{ |
|
writers: sync.Pool{ |
|
New: func() interface{} { |
|
return gzip.NewWriter(nil) |
|
}, |
|
}, |
|
readers: sync.Pool{}, |
|
} |
|
} |
|
|
|
func (g *gzipPool) GetWriter(w io.Writer) *gzip.Writer { |
|
writer := g.writers.Get().(*gzip.Writer) |
|
writer.Reset(w) |
|
return writer |
|
} |
|
|
|
func (g *gzipPool) PutWriter(w *gzip.Writer) { |
|
g.writers.Put(w) |
|
} |
|
|
|
func (g *gzipPool) GetReader(r io.Reader) (*gzip.Reader, error) { |
|
reader, ok := g.readers.Get().(*gzip.Reader) |
|
if !ok { |
|
r, err := gzip.NewReader(r) |
|
if err != nil { |
|
return nil, err |
|
} |
|
return r, nil |
|
} |
|
|
|
if err := reader.Reset(r); err != nil { |
|
g.readers.Put(reader) |
|
return nil, err |
|
} |
|
return reader, nil |
|
} |
|
|
|
func (g *gzipPool) PutReader(w *gzip.Reader) { |
|
g.readers.Put(w) |
|
} |
|
|
|
// GZIP represents a Packed Object (gzip_packed).
//
// It stands in for any other object, or more precisely for that object's
// serialized form, carrying it in archived (gzip-compressed) representation.
type GZIP struct {
	// Data is the uncompressed serialization: Encode compresses it, and
	// Decode stores the decompressed payload here.
	Data []byte
}
|
|
|
// GZIPTypeID is the TL type (constructor) ID written before a GZIP payload;
// it corresponds to gzip_packed#3072cfa1 in the MTProto TL schema.
const GZIPTypeID = 0x3072cfa1
|
|
|
var ( |
|
gzipRWPool = newGzipPool() |
|
gzipBufPool = sync.Pool{New: func() interface{} { |
|
return bytes.NewBuffer(nil) |
|
}} |
|
) |
|
|
|
// Encode implements bin.Encoder.
//
// It gzips g.Data and appends GZIPTypeID followed by the compressed payload
// (as TL bytes) to b. Compression happens in a pooled scratch buffer before
// anything is written to b, so on error b is left untouched rather than
// holding a dangling type ID.
func (g GZIP) Encode(b *bin.Buffer) error {
	buf := gzipBufPool.Get().(*bytes.Buffer)
	buf.Reset()
	defer gzipBufPool.Put(buf)

	w := gzipRWPool.GetWriter(buf)
	// GetWriter Resets the writer before every reuse, so it can go back to
	// the pool even if it was not closed (e.g. after a failed Write).
	defer gzipRWPool.PutWriter(w)

	if _, err := w.Write(g.Data); err != nil {
		return errors.Wrap(err, "compress")
	}
	// Close flushes pending compressed data and writes the gzip footer;
	// buf is incomplete until it succeeds.
	if err := w.Close(); err != nil {
		return errors.Wrap(err, "close")
	}

	b.PutID(GZIPTypeID)
	b.PutBytes(buf.Bytes())
	return nil
}
|
|
|
type countReader struct { |
|
reader io.Reader |
|
read int64 |
|
} |
|
|
|
func (c *countReader) Total() int64 { |
|
return atomic.LoadInt64(&c.read) |
|
} |
|
|
|
func (c *countReader) Read(p []byte) (n int, err error) { |
|
n, err = c.reader.Read(p) |
|
atomic.AddInt64(&c.read, int64(n)) |
|
return n, err |
|
} |
|
|
|
// DecompressionBombErr means that GZIP decode detected decompression bomb |
|
// which decompressed payload is significantly higher than initial compressed |
|
// size and stopped decompression to prevent OOM. |
|
type DecompressionBombErr struct { |
|
Compressed int |
|
Decompressed int |
|
} |
|
|
|
func (d *DecompressionBombErr) Error() string { |
|
return fmt.Sprintf("payload too big (expanded %d bytes to greater than %d)", |
|
d.Compressed, d.Decompressed, |
|
) |
|
} |
|
|
|
// Decode implements bin.Decoder.
//
// It consumes GZIPTypeID and a TL bytes payload from b, gunzips the payload
// and stores the result in g.Data. Decompressed output larger than 10 MiB is
// rejected with a wrapped *DecompressionBombErr to bound memory use on
// untrusted input. g.Data is assigned only when decoding succeeds.
func (g *GZIP) Decode(b *bin.Buffer) (rErr error) {
	if err := b.ConsumeID(GZIPTypeID); err != nil {
		return err
	}
	buf, err := b.Bytes()
	if err != nil {
		return err
	}

	r, err := gzipRWPool.GetReader(bytes.NewReader(buf))
	if err != nil {
		return errors.Wrap(err, "gzip error")
	}
	defer func() {
		if closeErr := r.Close(); closeErr != nil {
			closeErr = errors.Wrap(closeErr, "close")
			multierr.AppendInto(&rErr, closeErr)
		}
		gzipRWPool.PutReader(r)
	}()

	// Apply mitigation for reading too much data which can result in OOM.
	// Allow one byte past the limit so that a payload of exactly
	// maxUncompressedSize bytes is accepted and only larger ones are flagged.
	const maxUncompressedSize = 1024 * 1024 * 10 // 10 MiB
	reader := &countReader{
		reader: io.LimitReader(r, maxUncompressedSize+1),
	}
	data, err := io.ReadAll(reader)
	if err != nil {
		// Reading to EOF is also where the gzip trailer (CRC-32 and size)
		// is verified, so corrupted payloads surface here.
		return errors.Wrap(err, "decompress")
	}
	if reader.Total() > maxUncompressedSize {
		// Read limit exceeded, possible decompression bomb detected.
		return errors.Wrap(&DecompressionBombErr{
			Compressed:   len(buf),
			Decompressed: maxUncompressedSize,
		}, "decompress")
	}

	g.Data = data
	return nil
}
|
|
|