You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
374 lines
9.2 KiB
374 lines
9.2 KiB
package fse |
|
|
|
import ( |
|
"errors" |
|
"fmt" |
|
) |
|
|
|
// tablelogAbsoluteMax is the largest table log that readNCount accepts from a
// block header; a header declaring a larger one is rejected as corrupt.
const tablelogAbsoluteMax = 15
|
|
|
// Decompress a block of data. |
|
// You can provide a scratch buffer to avoid allocations. |
|
// If nil is provided a temporary one will be allocated. |
|
// It is possible, but by no way guaranteed that corrupt data will |
|
// return an error. |
|
// It is up to the caller to verify integrity of the returned data. |
|
// Use a predefined Scrach to set maximum acceptable output size. |
|
func Decompress(b []byte, s *Scratch) ([]byte, error) { |
|
s, err := s.prepare(b) |
|
if err != nil { |
|
return nil, err |
|
} |
|
s.Out = s.Out[:0] |
|
err = s.readNCount() |
|
if err != nil { |
|
return nil, err |
|
} |
|
err = s.buildDtable() |
|
if err != nil { |
|
return nil, err |
|
} |
|
err = s.decompress() |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
return s.Out, nil |
|
} |
|
|
|
// readNCount will read the symbol distribution so decoding tables can be constructed. |
|
func (s *Scratch) readNCount() error { |
|
var ( |
|
charnum uint16 |
|
previous0 bool |
|
b = &s.br |
|
) |
|
iend := b.remain() |
|
if iend < 4 { |
|
return errors.New("input too small") |
|
} |
|
bitStream := b.Uint32() |
|
nbBits := uint((bitStream & 0xF) + minTablelog) // extract tableLog |
|
if nbBits > tablelogAbsoluteMax { |
|
return errors.New("tableLog too large") |
|
} |
|
bitStream >>= 4 |
|
bitCount := uint(4) |
|
|
|
s.actualTableLog = uint8(nbBits) |
|
remaining := int32((1 << nbBits) + 1) |
|
threshold := int32(1 << nbBits) |
|
gotTotal := int32(0) |
|
nbBits++ |
|
|
|
for remaining > 1 { |
|
if previous0 { |
|
n0 := charnum |
|
for (bitStream & 0xFFFF) == 0xFFFF { |
|
n0 += 24 |
|
if b.off < iend-5 { |
|
b.advance(2) |
|
bitStream = b.Uint32() >> bitCount |
|
} else { |
|
bitStream >>= 16 |
|
bitCount += 16 |
|
} |
|
} |
|
for (bitStream & 3) == 3 { |
|
n0 += 3 |
|
bitStream >>= 2 |
|
bitCount += 2 |
|
} |
|
n0 += uint16(bitStream & 3) |
|
bitCount += 2 |
|
if n0 > maxSymbolValue { |
|
return errors.New("maxSymbolValue too small") |
|
} |
|
for charnum < n0 { |
|
s.norm[charnum&0xff] = 0 |
|
charnum++ |
|
} |
|
|
|
if b.off <= iend-7 || b.off+int(bitCount>>3) <= iend-4 { |
|
b.advance(bitCount >> 3) |
|
bitCount &= 7 |
|
bitStream = b.Uint32() >> bitCount |
|
} else { |
|
bitStream >>= 2 |
|
} |
|
} |
|
|
|
max := (2*(threshold) - 1) - (remaining) |
|
var count int32 |
|
|
|
if (int32(bitStream) & (threshold - 1)) < max { |
|
count = int32(bitStream) & (threshold - 1) |
|
bitCount += nbBits - 1 |
|
} else { |
|
count = int32(bitStream) & (2*threshold - 1) |
|
if count >= threshold { |
|
count -= max |
|
} |
|
bitCount += nbBits |
|
} |
|
|
|
count-- // extra accuracy |
|
if count < 0 { |
|
// -1 means +1 |
|
remaining += count |
|
gotTotal -= count |
|
} else { |
|
remaining -= count |
|
gotTotal += count |
|
} |
|
s.norm[charnum&0xff] = int16(count) |
|
charnum++ |
|
previous0 = count == 0 |
|
for remaining < threshold { |
|
nbBits-- |
|
threshold >>= 1 |
|
} |
|
if b.off <= iend-7 || b.off+int(bitCount>>3) <= iend-4 { |
|
b.advance(bitCount >> 3) |
|
bitCount &= 7 |
|
} else { |
|
bitCount -= (uint)(8 * (len(b.b) - 4 - b.off)) |
|
b.off = len(b.b) - 4 |
|
} |
|
bitStream = b.Uint32() >> (bitCount & 31) |
|
} |
|
s.symbolLen = charnum |
|
|
|
if s.symbolLen <= 1 { |
|
return fmt.Errorf("symbolLen (%d) too small", s.symbolLen) |
|
} |
|
if s.symbolLen > maxSymbolValue+1 { |
|
return fmt.Errorf("symbolLen (%d) too big", s.symbolLen) |
|
} |
|
if remaining != 1 { |
|
return fmt.Errorf("corruption detected (remaining %d != 1)", remaining) |
|
} |
|
if bitCount > 32 { |
|
return fmt.Errorf("corruption detected (bitCount %d > 32)", bitCount) |
|
} |
|
if gotTotal != 1<<s.actualTableLog { |
|
return fmt.Errorf("corruption detected (total %d != %d)", gotTotal, 1<<s.actualTableLog) |
|
} |
|
b.advance((bitCount + 7) >> 3) |
|
return nil |
|
} |
|
|
|
// decSymbol is one entry of the decoding table. For the state at which it is
// stored, it holds the symbol to emit and how to compute the following state:
// read nbBits bits from the stream and add them to newState.
type decSymbol struct {
	newState uint16 // base of the next state; the bits read are added to it
	symbol   uint8  // symbol emitted for this state
	nbBits   uint8  // number of bits to read for the next state
}
|
|
|
// allocDtable will allocate decoding tables if they are not big enough. |
|
func (s *Scratch) allocDtable() { |
|
tableSize := 1 << s.actualTableLog |
|
if cap(s.decTable) < tableSize { |
|
s.decTable = make([]decSymbol, tableSize) |
|
} |
|
s.decTable = s.decTable[:tableSize] |
|
|
|
if cap(s.ct.tableSymbol) < 256 { |
|
s.ct.tableSymbol = make([]byte, 256) |
|
} |
|
s.ct.tableSymbol = s.ct.tableSymbol[:256] |
|
|
|
if cap(s.ct.stateTable) < 256 { |
|
s.ct.stateTable = make([]uint16, 256) |
|
} |
|
s.ct.stateTable = s.ct.stateTable[:256] |
|
} |
|
|
|
// buildDtable will build the decoding table. |
|
func (s *Scratch) buildDtable() error { |
|
tableSize := uint32(1 << s.actualTableLog) |
|
highThreshold := tableSize - 1 |
|
s.allocDtable() |
|
symbolNext := s.ct.stateTable[:256] |
|
|
|
// Init, lay down lowprob symbols |
|
s.zeroBits = false |
|
{ |
|
largeLimit := int16(1 << (s.actualTableLog - 1)) |
|
for i, v := range s.norm[:s.symbolLen] { |
|
if v == -1 { |
|
s.decTable[highThreshold].symbol = uint8(i) |
|
highThreshold-- |
|
symbolNext[i] = 1 |
|
} else { |
|
if v >= largeLimit { |
|
s.zeroBits = true |
|
} |
|
symbolNext[i] = uint16(v) |
|
} |
|
} |
|
} |
|
// Spread symbols |
|
{ |
|
tableMask := tableSize - 1 |
|
step := tableStep(tableSize) |
|
position := uint32(0) |
|
for ss, v := range s.norm[:s.symbolLen] { |
|
for i := 0; i < int(v); i++ { |
|
s.decTable[position].symbol = uint8(ss) |
|
position = (position + step) & tableMask |
|
for position > highThreshold { |
|
// lowprob area |
|
position = (position + step) & tableMask |
|
} |
|
} |
|
} |
|
if position != 0 { |
|
// position must reach all cells once, otherwise normalizedCounter is incorrect |
|
return errors.New("corrupted input (position != 0)") |
|
} |
|
} |
|
|
|
// Build Decoding table |
|
{ |
|
tableSize := uint16(1 << s.actualTableLog) |
|
for u, v := range s.decTable { |
|
symbol := v.symbol |
|
nextState := symbolNext[symbol] |
|
symbolNext[symbol] = nextState + 1 |
|
nBits := s.actualTableLog - byte(highBits(uint32(nextState))) |
|
s.decTable[u].nbBits = nBits |
|
newState := (nextState << nBits) - tableSize |
|
if newState >= tableSize { |
|
return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) |
|
} |
|
if newState == uint16(u) && nBits == 0 { |
|
// Seems weird that this is possible with nbits > 0. |
|
return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) |
|
} |
|
s.decTable[u].newState = newState |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
// decompress will decompress the bitstream. |
|
// If the buffer is over-read an error is returned. |
|
func (s *Scratch) decompress() error { |
|
br := &s.bits |
|
br.init(s.br.unread()) |
|
|
|
var s1, s2 decoder |
|
// Initialize and decode first state and symbol. |
|
s1.init(br, s.decTable, s.actualTableLog) |
|
s2.init(br, s.decTable, s.actualTableLog) |
|
|
|
// Use temp table to avoid bound checks/append penalty. |
|
var tmp = s.ct.tableSymbol[:256] |
|
var off uint8 |
|
|
|
// Main part |
|
if !s.zeroBits { |
|
for br.off >= 8 { |
|
br.fillFast() |
|
tmp[off+0] = s1.nextFast() |
|
tmp[off+1] = s2.nextFast() |
|
br.fillFast() |
|
tmp[off+2] = s1.nextFast() |
|
tmp[off+3] = s2.nextFast() |
|
off += 4 |
|
// When off is 0, we have overflowed and should write. |
|
if off == 0 { |
|
s.Out = append(s.Out, tmp...) |
|
if len(s.Out) >= s.DecompressLimit { |
|
return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit) |
|
} |
|
} |
|
} |
|
} else { |
|
for br.off >= 8 { |
|
br.fillFast() |
|
tmp[off+0] = s1.next() |
|
tmp[off+1] = s2.next() |
|
br.fillFast() |
|
tmp[off+2] = s1.next() |
|
tmp[off+3] = s2.next() |
|
off += 4 |
|
if off == 0 { |
|
s.Out = append(s.Out, tmp...) |
|
// When off is 0, we have overflowed and should write. |
|
if len(s.Out) >= s.DecompressLimit { |
|
return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit) |
|
} |
|
} |
|
} |
|
} |
|
s.Out = append(s.Out, tmp[:off]...) |
|
|
|
// Final bits, a bit more expensive check |
|
for { |
|
if s1.finished() { |
|
s.Out = append(s.Out, s1.final(), s2.final()) |
|
break |
|
} |
|
br.fill() |
|
s.Out = append(s.Out, s1.next()) |
|
if s2.finished() { |
|
s.Out = append(s.Out, s2.final(), s1.final()) |
|
break |
|
} |
|
s.Out = append(s.Out, s2.next()) |
|
if len(s.Out) >= s.DecompressLimit { |
|
return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit) |
|
} |
|
} |
|
return br.close() |
|
} |
|
|
|
// decoder keeps track of the current state and updates it from the bitstream. |
|
type decoder struct { |
|
state uint16 |
|
br *bitReader |
|
dt []decSymbol |
|
} |
|
|
|
// init will initialize the decoder and read the first state from the stream. |
|
func (d *decoder) init(in *bitReader, dt []decSymbol, tableLog uint8) { |
|
d.dt = dt |
|
d.br = in |
|
d.state = in.getBits(tableLog) |
|
} |
|
|
|
// next returns the next symbol and sets the next state. |
|
// At least tablelog bits must be available in the bit reader. |
|
func (d *decoder) next() uint8 { |
|
n := &d.dt[d.state] |
|
lowBits := d.br.getBits(n.nbBits) |
|
d.state = n.newState + lowBits |
|
return n.symbol |
|
} |
|
|
|
// finished returns true if all bits have been read from the bitstream |
|
// and the next state would require reading bits from the input. |
|
func (d *decoder) finished() bool { |
|
return d.br.finished() && d.dt[d.state].nbBits > 0 |
|
} |
|
|
|
// final returns the current state symbol without decoding the next. |
|
func (d *decoder) final() uint8 { |
|
return d.dt[d.state].symbol |
|
} |
|
|
|
// nextFast returns the next symbol and sets the next state. |
|
// This can only be used if no symbols are 0 bits. |
|
// At least tablelog bits must be available in the bit reader. |
|
func (d *decoder) nextFast() uint8 { |
|
n := d.dt[d.state] |
|
lowBits := d.br.getBitsFast(n.nbBits) |
|
d.state = n.newState + lowBits |
|
return n.symbol |
|
}
|
|
|