You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
185 lines
3.9 KiB
185 lines
3.9 KiB
3 years ago
|
package proto
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
|
||
|
"github.com/go-faster/errors"
|
||
|
"github.com/klauspost/compress/gzip"
|
||
|
"go.uber.org/multierr"
|
||
|
|
||
|
"github.com/gotd/td/bin"
|
||
|
)
|
||
|
|
||
|
// gzipPool keeps reusable gzip writers and readers so that Encode and
// Decode do not allocate a fresh (de)compressor on every call.
//
// writers is populated on demand by the New func installed in
// newGzipPool; readers has no New func, so GetReader builds a reader
// itself whenever the pool comes back empty.
type gzipPool struct {
	writers, readers sync.Pool
}
|
||
|
|
||
|
func newGzipPool() *gzipPool {
|
||
|
return &gzipPool{
|
||
|
writers: sync.Pool{
|
||
|
New: func() interface{} {
|
||
|
return gzip.NewWriter(nil)
|
||
|
},
|
||
|
},
|
||
|
readers: sync.Pool{},
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (g *gzipPool) GetWriter(w io.Writer) *gzip.Writer {
|
||
|
writer := g.writers.Get().(*gzip.Writer)
|
||
|
writer.Reset(w)
|
||
|
return writer
|
||
|
}
|
||
|
|
||
|
func (g *gzipPool) PutWriter(w *gzip.Writer) {
|
||
|
g.writers.Put(w)
|
||
|
}
|
||
|
|
||
|
func (g *gzipPool) GetReader(r io.Reader) (*gzip.Reader, error) {
|
||
|
reader, ok := g.readers.Get().(*gzip.Reader)
|
||
|
if !ok {
|
||
|
r, err := gzip.NewReader(r)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return r, nil
|
||
|
}
|
||
|
|
||
|
if err := reader.Reset(r); err != nil {
|
||
|
g.readers.Put(reader)
|
||
|
return nil, err
|
||
|
}
|
||
|
return reader, nil
|
||
|
}
|
||
|
|
||
|
func (g *gzipPool) PutReader(w *gzip.Reader) {
|
||
|
g.readers.Put(w)
|
||
|
}
|
||
|
|
||
|
// GZIP represents a Packed Object.
//
// Used to replace any other object (or rather, a serialization thereof)
// with its archived (gzipped) representation.
//
// Data always holds the uncompressed serialization: Encode gzip-compresses
// it when writing, and Decode overwrites it with the decompressed payload.
type GZIP struct {
	// Data is the uncompressed serialized object.
	Data []byte
}
|
||
|
|
||
|
// GZIPTypeID is TL type id of GZIP.
//
// Encode writes it ahead of the compressed bytes, and Decode requires it
// (via ConsumeID) before reading them.
const GZIPTypeID = 0x3072cfa1
|
||
|
|
||
|
var (
|
||
|
gzipRWPool = newGzipPool()
|
||
|
gzipBufPool = sync.Pool{New: func() interface{} {
|
||
|
return bytes.NewBuffer(nil)
|
||
|
}}
|
||
|
)
|
||
|
|
||
|
// Encode implements bin.Encoder.
|
||
|
func (g GZIP) Encode(b *bin.Buffer) (rErr error) {
|
||
|
b.PutID(GZIPTypeID)
|
||
|
|
||
|
// Writing compressed data to buf.
|
||
|
buf := gzipBufPool.Get().(*bytes.Buffer)
|
||
|
buf.Reset()
|
||
|
defer gzipBufPool.Put(buf)
|
||
|
|
||
|
w := gzipRWPool.GetWriter(buf)
|
||
|
defer func() {
|
||
|
if closeErr := w.Close(); closeErr != nil {
|
||
|
closeErr = errors.Wrap(closeErr, "close")
|
||
|
multierr.AppendInto(&rErr, closeErr)
|
||
|
}
|
||
|
gzipRWPool.PutWriter(w)
|
||
|
}()
|
||
|
if _, err := w.Write(g.Data); err != nil {
|
||
|
return errors.Wrap(err, "compress")
|
||
|
}
|
||
|
if err := w.Close(); err != nil {
|
||
|
return errors.Wrap(err, "close")
|
||
|
}
|
||
|
|
||
|
// Writing compressed data as bytes.
|
||
|
b.PutBytes(buf.Bytes())
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// countReader is an io.Reader decorator that keeps a running tally of how
// many bytes have passed through Read. The tally is accessed with atomic
// operations.
type countReader struct {
	reader io.Reader
	read   int64
}

// Total reports the number of bytes returned by Read so far.
func (c *countReader) Total() int64 {
	total := atomic.LoadInt64(&c.read)
	return total
}

// Read delegates to the wrapped reader and adds the byte count it reports
// to the tally. The count is added even when err is non-nil, since io.Reader
// permits returning data together with an error.
func (c *countReader) Read(p []byte) (int, error) {
	got, err := c.reader.Read(p)
	atomic.AddInt64(&c.read, int64(got))
	return got, err
}
|
||
|
|
||
|
// DecompressionBombErr means that GZIP decode detected a decompression
// bomb: a payload whose decompressed size is far larger than its
// compressed size. Decompression is stopped at that point to prevent
// running out of memory.
type DecompressionBombErr struct {
	// Compressed is the size reported as "expanded N bytes" in Error.
	Compressed int
	// Decompressed is the size reported as "greater than N" in Error.
	Decompressed int
}

// Error implements the error interface.
func (d *DecompressionBombErr) Error() string {
	const format = "payload too big (expanded %d bytes to greater than %d)"
	return fmt.Sprintf(format, d.Compressed, d.Decompressed)
}
|
||
|
|
||
|
// Decode implements bin.Decoder.
|
||
|
func (g *GZIP) Decode(b *bin.Buffer) (rErr error) {
|
||
|
if err := b.ConsumeID(GZIPTypeID); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
buf, err := b.Bytes()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
r, err := gzipRWPool.GetReader(bytes.NewReader(buf))
|
||
|
if err != nil {
|
||
|
return errors.Wrap(err, "gzip error")
|
||
|
}
|
||
|
defer func() {
|
||
|
if closeErr := r.Close(); closeErr != nil {
|
||
|
closeErr = errors.Wrap(closeErr, "close")
|
||
|
multierr.AppendInto(&rErr, closeErr)
|
||
|
}
|
||
|
gzipRWPool.PutReader(r)
|
||
|
}()
|
||
|
|
||
|
// Apply mitigation for reading too much data which can result in OOM.
|
||
|
const maxUncompressedSize = 1024 * 1024 * 10 // 10 mb
|
||
|
reader := &countReader{
|
||
|
reader: io.LimitReader(r, maxUncompressedSize),
|
||
|
}
|
||
|
if g.Data, err = io.ReadAll(reader); err != nil {
|
||
|
return errors.Wrap(err, "decompress")
|
||
|
}
|
||
|
if reader.Total() >= maxUncompressedSize {
|
||
|
// Read limit reached, possible decompression bomb detected.
|
||
|
return errors.Wrap(&DecompressionBombErr{
|
||
|
Compressed: maxUncompressedSize,
|
||
|
Decompressed: int(reader.Total()),
|
||
|
}, "decompress")
|
||
|
}
|
||
|
|
||
|
if err := r.Close(); err != nil {
|
||
|
return errors.Wrap(err, "checksum")
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|