lib/encoding/zstd: properly apply size limits

Previously, the zstd Decoder didn't take into account the request size limits
applied by VictoriaMetrics components. In case of an incorrectly formed zstd block, a VictoriaMetrics
component could allocate extra memory, which could lead to OOM errors.

This commit makes ingest endpoints validate the zstd frame content size and window size headers against the configured maximum request size limits.
This commit is contained in:
Nikolay
2025-11-13 18:11:18 +01:00
committed by f41gh7
parent fa85726a82
commit 10f7cd2ffc
8 changed files with 193 additions and 7 deletions

View File

@@ -22,6 +22,8 @@ func CompressZSTDLevel(dst, src []byte, compressLevel int) []byte {
// DecompressZSTD decompresses src, appends the result to dst and returns
// the appended dst.
//
// This function must be called only for the trusted src.
func DecompressZSTD(dst, src []byte) ([]byte, error) {
decompressCalls.Inc()
b, err := zstd.Decompress(dst, src)

View File

@@ -7,10 +7,17 @@ import (
)
// Decompress appends decompressed src to dst and returns the result.
//
// This function must be called only for the trusted src, since no limit is
// enforced on the decompressed size. Use DecompressLimited for untrusted src.
func Decompress(dst, src []byte) ([]byte, error) {
	return gozstd.Decompress(dst, src)
}
// DecompressLimited appends decompressed src to dst and returns the result.
//
// It delegates to gozstd.DecompressLimited, passing maxDataSizeBytes as the
// upper bound for the decompressed data size, so it is suitable for untrusted src.
func DecompressLimited(dst, src []byte, maxDataSizeBytes int) ([]byte, error) {
	return gozstd.DecompressLimited(dst, src, maxDataSizeBytes)
}
// CompressLevel appends compressed src to dst and returns the result.
//
// The given compressionLevel is used for the compression.

View File

@@ -11,7 +11,8 @@ import (
)
var (
decoder *zstd.Decoder
decodersMu sync.Mutex
decoders atomic.Value
mu sync.Mutex
@@ -24,15 +25,27 @@ func init() {
av.Store(r)
var err error
decoder, err = zstd.NewReader(nil)
decoder, err := zstd.NewReader(nil)
if err != nil {
logger.Panicf("BUG: failed to create ZSTD reader: %s", err)
}
d := make(map[int]*zstd.Decoder)
d[0] = decoder
decoders.Store(d)
}
// Decompress appends decompressed src to dst and returns the result.
//
// This function must be called only for the trusted src.
func Decompress(dst, src []byte) ([]byte, error) {
return decoder.DecodeAll(src, dst)
d := getDecoder(0)
return d.DecodeAll(src, dst)
}
// DecompressLimited appends decompressed src to dst and returns the result.
//
// The decoder is configured with maxDataSizeBytes as its maximum memory limit
// (see getDecoder), so oversized payloads return an error instead of
// allocating unbounded memory. This makes it suitable for untrusted src.
func DecompressLimited(dst, src []byte, maxDataSizeBytes int) ([]byte, error) {
	d := getDecoder(maxDataSizeBytes)
	return d.DecodeAll(src, dst)
}
// CompressLevel appends compressed src to dst and returns the result.
@@ -50,6 +63,34 @@ func CompressLevel(dst, src []byte, compressionLevel int) []byte {
return e.EncodeAll(src, dst)
}
// getDecoder returns a shared *zstd.Decoder configured with the given
// maxMemory limit (in bytes) via zstd.WithDecoderMaxMemory.
//
// Decoders are cached in the package-level `decoders` map (held in an
// atomic.Value), so at most one decoder is created per distinct maxMemory value.
func getDecoder(maxMemory int) *zstd.Decoder {
	// Fast path: lock-free lookup in the current immutable map snapshot.
	r := decoders.Load().(map[int]*zstd.Decoder)
	d := r[maxMemory]
	if d != nil {
		return d
	}
	decodersMu.Lock()
	// Create the decoder under lock in order to prevent from wasted work
	// when concurrent goroutines create decoder for the same maxMemory value.
	// Re-load the map after acquiring the lock, since another goroutine may
	// have published the decoder meanwhile (double-checked locking).
	r1 := decoders.Load().(map[int]*zstd.Decoder)
	if d = r1[maxMemory]; d == nil {
		var err error
		d, err = zstd.NewReader(nil, zstd.WithDecoderMaxMemory(uint64(maxMemory)))
		if err != nil {
			logger.Panicf("BUG: failed to create ZSTD reader: %s", err)
		}
		// Copy-on-write: publish a fresh map copy so lock-free readers never
		// observe a map being mutated concurrently.
		r2 := make(map[int]*zstd.Decoder)
		for k, v := range r1 {
			r2[k] = v
		}
		r2[maxMemory] = d
		decoders.Store(r2)
	}
	decodersMu.Unlock()
	return d
}
func getEncoder(compressionLevel zstd.EncoderLevel) *zstd.Encoder {
r := av.Load().(map[zstd.EncoderLevel]*zstd.Encoder)
e := r[compressionLevel]

View File

@@ -0,0 +1,69 @@
//go:build !cgo
package zstd
import (
"bytes"
"encoding/hex"
"fmt"
"testing"
)
// TestDecompressLimitedOk verifies that DecompressLimited succeeds when the
// limit is big enough for the decompressed data, or when the limit is zero
// (meaning "no explicit limit").
//
// The original name contained a typo ("Decomrpess"); fixed here. Tests have no
// callers, so the rename is safe.
func TestDecompressLimitedOk(t *testing.T) {
	f := func(compressedData []byte, limit int) {
		t.Helper()
		_, err := DecompressLimited(nil, compressedData, limit)
		if err != nil {
			t.Fatalf("cannot decompress data with limit=%d: %s", limit, err)
		}
	}

	// Build a payload big enough to span multiple zstd blocks.
	var bb bytes.Buffer
	for bb.Len() < 12*128*1024 {
		fmt.Fprintf(&bb, "compress/decompress big data %d, ", bb.Len())
	}
	originData := bb.Bytes()

	// block decompression
	cd := CompressLevel(nil, originData, 0)

	// decompressed size matches block limit
	f(cd, len(originData))

	// unlimited
	f(cd, 0)
}
// TestDecompressLimitedFail verifies that DecompressLimited rejects inputs
// which would exceed the given limit or are otherwise malformed.
func TestDecompressLimitedFail(t *testing.T) {
	mustFail := func(data []byte, limit int) {
		t.Helper()
		if _, err := DecompressLimited(nil, data, limit); err == nil {
			t.Errorf("unexpected nil-error for decompress with limit: %d", limit)
		}
	}

	var buf bytes.Buffer
	for buf.Len() < 12*128*1024 {
		fmt.Fprintf(&buf, "compress/decompress big data %d, ", buf.Len())
	}
	// payload bigger than the limit must be rejected
	mustFail(buf.Bytes(), 1024)

	// frame whose declared content size exceeds the actual payload
	frame, err := hex.DecodeString("28b52ffd8400005ed0b209000030ecaf4412")
	if err != nil {
		t.Fatalf("BUG: unexpected hex input: %s", err)
	}
	mustFail(frame, 512)

	// stream frame whose window size header exceeds the limit
	frame, err = hex.DecodeString("28b52ffd04981900003030304e8da22b")
	if err != nil {
		t.Fatalf("BUG: unexpected hex input: %s", err)
	}
	mustFail(frame, 8*1e6*10)
}

View File

@@ -3,6 +3,9 @@
package zstd
import (
"bytes"
"encoding/hex"
"fmt"
"math/rand"
"testing"
@@ -10,6 +13,66 @@ import (
cgo "github.com/valyala/gozstd"
)
// TestDecompressLimitedOK verifies that DecompressLimited succeeds when the
// limit is big enough for the decompressed data, or when the limit is zero
// (meaning "no explicit limit").
//
// The original name contained a typo ("Decomrpess"); fixed here. Tests have no
// callers, so the rename is safe.
func TestDecompressLimitedOK(t *testing.T) {
	f := func(compressedData []byte, limit int) {
		t.Helper()
		_, err := DecompressLimited(nil, compressedData, limit)
		if err != nil {
			t.Fatalf("cannot decompress data with limit=%d: %s", limit, err)
		}
	}

	// Build a payload big enough to span multiple zstd blocks.
	var bb bytes.Buffer
	for bb.Len() < 12*128*1024 {
		fmt.Fprintf(&bb, "compress/decompress big data %d, ", bb.Len())
	}
	originData := bb.Bytes()

	// block decompression
	cd := CompressLevel(nil, originData, 0)

	// decompressed size matches block limit
	f(cd, len(originData))

	// unlimited
	f(cd, 0)
}
// TestDecompressLimitedFail verifies that DecompressLimited rejects inputs
// which would exceed the given limit or are otherwise malformed.
func TestDecompressLimitedFail(t *testing.T) {
	mustFail := func(data []byte, limit int) {
		t.Helper()
		if _, err := DecompressLimited(nil, data, limit); err == nil {
			t.Errorf("unexpected nil-error for decompress with limit: %d", limit)
		}
	}

	var buf bytes.Buffer
	for buf.Len() < 12*128*1024 {
		fmt.Fprintf(&buf, "compress/decompress big data %d, ", buf.Len())
	}
	// payload bigger than the limit must be rejected
	mustFail(buf.Bytes(), 1024)

	// frame whose declared content size exceeds the actual payload
	frame, err := hex.DecodeString("28b52ffd8400005ed0b209000030ecaf4412")
	if err != nil {
		t.Fatalf("BUG: unexpected hex input: %s", err)
	}
	mustFail(frame, 512)

	// stream frame whose window size header exceeds the limit
	frame, err = hex.DecodeString("28b52ffd04981900003030304e8da22b")
	if err != nil {
		t.Fatalf("BUG: unexpected hex input: %s", err)
	}
	mustFail(frame, 8*1e6*10)
}
func TestCompressDecompress(t *testing.T) {
testCrossCompressDecompress(t, []byte("a"))
testCrossCompressDecompress(t, []byte("foobarbaz"))