diff --git a/sei-tendermint/internal/protoutils/alloc_scan.go b/sei-tendermint/internal/protoutils/alloc_scan.go new file mode 100644 index 0000000000..7a24cdcd67 --- /dev/null +++ b/sei-tendermint/internal/protoutils/alloc_scan.go @@ -0,0 +1,302 @@ +package protoutils + +import ( + "fmt" + "reflect" + + gogoproto "github.com/gogo/protobuf/proto" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" +) + +var ( + pointerSize = int(reflect.TypeFor[*byte]().Size()) // 8 on 64-bit + sliceHeaderSize = int(reflect.TypeFor[[]byte]().Size()) // 24 on 64-bit + stringHeaderSize = int(reflect.TypeFor[string]().Size()) // 16 on 64-bit + + // mapEntryOverhead is added per map entry to account for Go runtime map + // internals (hmap struct, bucket array slots, tophash bytes) that are not + // captured by the map-entry message struct size. Adding it per entry + // over-counts the fixed hmap header for multi-entry maps, which is + // conservative. 8×pointerSize matches the number of fields in runtime.hmap. + mapEntryOverhead = 8 * int(reflect.TypeFor[*byte]().Size()) +) + +// allocEstimate walks raw protobuf wire bytes and returns a conservative +// upper-bound on the heap bytes that proto.Unmarshal would allocate. +// +// The estimate accounts for: +// - the Go struct for each message occurrence (looked up from protoregistry) +// - backing arrays for bytes and string fields +// - pointer elements in repeated-message slices +// +// Unknown fields (field numbers not in the descriptor) are stored verbatim by +// proto.Unmarshal in a single raw []byte blob on the struct without decoding. +// Their allocation cost equals their wire size, so we add the wire bytes for +// each unknown field occurrence. This is exact for bytes-type unknown fields +// and a slight over-count for scalar unknown fields (which store tag+value +// together), but both are correct in the conservative direction. +// +// The function returns an error on corrupt or truncated wire bytes. Truncation +// (n < 0 from a Consume* call) surfaces as protowire.ParseError("unexpected +// end of data") and is treated identically to corruption — the caller receives +// an error and Unmarshal is never called. +func allocEstimate(data []byte, desc protoreflect.MessageDescriptor) (int, error) { + total := msgStructSize(desc) + + for len(data) > 0 { + num, typ, tagLen := protowire.ConsumeTag(data) + if tagLen <= 0 { + return 0, fmt.Errorf("tag: %w", protowire.ParseError(tagLen)) + } + if num == 0 { + // Field number 0 is reserved and illegal in the protobuf spec. + return 0, fmt.Errorf("invalid field number 0") + } + data = data[tagLen:] + + fd := desc.Fields().ByNumber(num) + + switch typ { + case protowire.BytesType: + val, n := protowire.ConsumeBytes(data) + if n <= 0 { + return 0, fmt.Errorf("field %d bytes: %w", num, protowire.ParseError(n)) + } + data = data[n:] + add, err := bytesFieldSize(tagLen, n, val, fd) + if err != nil { + return 0, err + } + total += add + + case protowire.VarintType: + _, n := protowire.ConsumeVarint(data) + if n <= 0 { + return 0, fmt.Errorf("field %d varint: %w", num, protowire.ParseError(n)) + } + data = data[n:] + if fd == nil || !isVarintKind(fd.Kind()) { + // Unknown field or known field with wrong wire type: proto.Unmarshal + // stores it verbatim in the unknown fields blob. + total += tagLen + n + } else if fd.IsList() { + total += sliceHeaderSize + scalarElementSize(fd.Kind()) + } + + case protowire.Fixed32Type: + _, n := protowire.ConsumeFixed32(data) + if n <= 0 { + return 0, fmt.Errorf("field %d fixed32: %w", num, protowire.ParseError(n)) + } + data = data[n:] + if fd == nil || !isFixed32Kind(fd.Kind()) { + // Unknown field or known field with wrong wire type: proto.Unmarshal + // stores it verbatim in the unknown fields blob. + total += tagLen + n + } else if fd.IsList() { + total += sliceHeaderSize + scalarElementSize(fd.Kind()) + } + + case protowire.Fixed64Type: + _, n := protowire.ConsumeFixed64(data) + if n <= 0 { + return 0, fmt.Errorf("field %d fixed64: %w", num, protowire.ParseError(n)) + } + data = data[n:] + if fd == nil || !isFixed64Kind(fd.Kind()) { + // Unknown field or known field with wrong wire type: proto.Unmarshal + // stores it verbatim in the unknown fields blob. + total += tagLen + n + } else if fd.IsList() { + total += sliceHeaderSize + scalarElementSize(fd.Kind()) + } + + case protowire.StartGroupType: + val, n := protowire.ConsumeGroup(num, data) + if n <= 0 { + return 0, fmt.Errorf("field %d group: %w", num, protowire.ParseError(n)) + } + data = data[n:] + if fd != nil && (fd.Kind() == protoreflect.GroupKind) { + sub, err := allocEstimate(val, fd.Message()) + if err != nil { + return 0, err + } + total += sub + } else { + total += tagLen + n + } + + default: + return 0, fmt.Errorf("unknown wire type %d at field %d", typ, num) + } + } + return total, nil +} + +// bytesFieldSize returns the allocation estimate for one BytesType wire record. +// tagLen+n is the full wire record size (tag + length varint + payload). +// val is the payload slice (without tag or length prefix). +// fd is nil for unknown fields. +func bytesFieldSize(tagLen, n int, val []byte, fd protoreflect.FieldDescriptor) (int, error) { + if fd == nil { + // Unknown field: proto.Unmarshal appends the full wire record to the + // unknown-fields blob. tagLen+n covers tag + length varint + payload. + return tagLen + n, nil + } + + total := 0 + if fd.IsMap() { + // Map fields: Go allocates a runtime map (hmap), not a slice. Add + // per-entry overhead for hmap fields, bucket slots, and tophash. + // Over-counts the fixed hmap header across N entries, which is + // conservative. + total += mapEntryOverhead + } else if fd.IsList() { + // Repeated fields: one slice header per field. Adding it per + // occurrence over-counts by (N-1)*sliceHeaderSize across N elements + // — noise near a 1MB limit. + total += sliceHeaderSize + } + + switch fd.Kind() { + case protoreflect.MessageKind, protoreflect.GroupKind: + if fd.IsList() { + total += pointerSize // pointer element in the backing array + } + sub, err := allocEstimate(val, fd.Message()) + if err != nil { + return 0, err + } + total += sub + case protoreflect.BytesKind: + total += sliceHeaderSize + len(val) + case protoreflect.StringKind: + total += stringHeaderSize + len(val) + default: + // Packed repeated scalar. Fixed-width kinds have equal wire and Go + // sizes so len(val) is exact. Varint kinds (bool, int32, int64, + // uint32, uint64, sint32, sint64, enum) encode small values in 1 byte + // on the wire while occupying 4 or 8 bytes in the Go slice — up to 8× + // amplification. We count elements and multiply by Go element size. + packed, err := packedAllocSize(val, fd.Kind()) + if err != nil { + return 0, err + } + total += packed + } + return total, nil +} + +// packedAllocSize returns the estimated Go heap bytes for a packed repeated +// scalar field whose raw wire bytes are bz. +// +// Fixed-width wire types (float, double, fixed32, fixed64, sfixed32, sfixed64) +// have the same size on the wire and in Go, so len(bz) is exact. +// +// Varint-encoded types (bool, int32, int64, uint32, uint64, sint32, sint64, +// enum) can encode small values in as little as 1 byte while each element +// occupies 4 or 8 bytes in the Go slice backing array. We walk the payload +// counting elements and multiply by the Go element size. +func packedAllocSize(bz []byte, kind protoreflect.Kind) (int, error) { + switch kind { + case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, protoreflect.FloatKind: + return len(bz), nil // 4 bytes wire == 4 bytes Go, exact + case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind: + return len(bz), nil // 8 bytes wire == 8 bytes Go, exact + case protoreflect.BoolKind: + n, err := countVarintsInPacked(bz) + return n * 1, err // bool = 1 byte in Go + case protoreflect.Int32Kind, protoreflect.Uint32Kind, + protoreflect.Sint32Kind, protoreflect.EnumKind: + n, err := countVarintsInPacked(bz) + return n * 4, err + case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Sint64Kind: + n, err := countVarintsInPacked(bz) + return n * 8, err + default: + panic(fmt.Sprintf("packedAllocSize: unexpected kind %v", kind)) + } +} + +// scalarElementSize returns the size in bytes of one element in the Go slice +// backing array for a repeated scalar field. +func scalarElementSize(kind protoreflect.Kind) int { + switch kind { + case protoreflect.BoolKind: + return 1 + case protoreflect.Int32Kind, protoreflect.Uint32Kind, protoreflect.Sint32Kind, + protoreflect.EnumKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, + protoreflect.FloatKind: + return 4 + case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Sint64Kind, + protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind: + return 8 + default: + panic(fmt.Sprintf("scalarElementSize: unexpected kind %v", kind)) + } +} + +func isVarintKind(k protoreflect.Kind) bool { + switch k { + case protoreflect.BoolKind, protoreflect.EnumKind, + protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Uint32Kind, + protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Uint64Kind: + return true + } + return false +} + +func isFixed32Kind(k protoreflect.Kind) bool { + switch k { + case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, protoreflect.FloatKind: + return true + } + return false +} + +func isFixed64Kind(k protoreflect.Kind) bool { + switch k { + case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind: + return true + } + return false +} + +// countVarintsInPacked counts the number of varint-encoded elements in a +// packed repeated field payload. +func countVarintsInPacked(bz []byte) (int, error) { + count := 0 + for len(bz) > 0 { + _, n := protowire.ConsumeVarint(bz) + if n <= 0 { + return 0, fmt.Errorf("packed varint: %w", protowire.ParseError(n)) + } + bz = bz[n:] + count++ + } + return count, nil +} + +// msgStructSize returns the size of the Go struct backing desc. +// It tries the google protobuf v2 registry first (for protoc-gen-go types), +// then falls back to the gogoproto registry (for protoc-gen-gogofaster types +// used by Tendermint P2P). Panics if the type is not registered in either +// registry, since that indicates a programming error. +func msgStructSize(desc protoreflect.MessageDescriptor) int { + if desc.IsMapEntry() { + // Synthetic map-entry types have no standalone Go struct; the runtime map + // stores keys and values in bucket arrays. Return 0 here; mapEntryOverhead + // is added per entry at the call site to account for runtime overhead. + return 0 + } + if mt, err := protoregistry.GlobalTypes.FindMessageByName(desc.FullName()); err == nil { + return int(reflect.TypeOf(mt.Zero().Interface()).Elem().Size()) + } + if t := gogoproto.MessageType(string(desc.FullName())); t != nil { + return int(t.Elem().Size()) + } + panic(fmt.Sprintf("protoutils: message type not registered: %s", desc.FullName())) +} diff --git a/sei-tendermint/internal/protoutils/alloc_scan_load_test.go b/sei-tendermint/internal/protoutils/alloc_scan_load_test.go new file mode 100644 index 0000000000..f7715ec008 --- /dev/null +++ b/sei-tendermint/internal/protoutils/alloc_scan_load_test.go @@ -0,0 +1,111 @@ +package protoutils_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" + + autopb "github.com/sei-protocol/sei-chain/sei-tendermint/internal/autobahn/pb" + "github.com/sei-protocol/sei-chain/sei-tendermint/internal/protoutils" + "github.com/sei-protocol/sei-chain/sei-tendermint/internal/protoutils/test/a/pb" +) + +const ( + maxTxsPerBlock = 2000 // Payload.txs max_count + maxTotalTxBytes = 2048000 // Payload.txs max_total_size + maxBytesPerTx = maxTotalTxBytes / maxTxsPerBlock // 1024 +) + +// maxBlock builds a Block at the wireguard limits: maxTxsPerBlock transactions +// each of maxBytesPerTx bytes, with a minimal valid BlockHeader. +func maxBlock() *autopb.Block { + txs := make([][]byte, maxTxsPerBlock) + for i := range txs { + txs[i] = make([]byte, maxBytesPerTx) + } + return &autopb.Block{ + Header: &autopb.BlockHeader{ + Lane: &autopb.PublicKey{Ed25519: make([]byte, 32)}, + BlockNumber: proto.Uint64(1), + ParentHash: make([]byte, 32), + PayloadHash: make([]byte, 32), + }, + Payload: &autopb.Payload{ + CreatedAt: &autopb.Timestamp{Seconds: proto.Int64(1), Nanos: proto.Int32(0)}, + TotalGasWanted: proto.Uint64(100_000_000), + TotalGasEstimated: proto.Uint64(100_000_000), + Txs: txs, + }, + } +} + +// TestUnmarshalWithLimit_MaxBlockAccepted verifies that a legitimately +// max-sized Block (max txs × max bytes/tx + valid header) is accepted by +// UnmarshalWithLimit with a generous limit. This guards against the estimate +// being so conservative that valid messages are incorrectly rejected. +func TestUnmarshalWithLimit_MaxBlockAccepted(t *testing.T) { + block := maxBlock() + bz, err := proto.Marshal(block) + require.NoError(t, err) + t.Logf("max block wire size: %d bytes", len(bz)) + + // Limit is 4MB: 2× the 2MB tx payload, leaving room for header overhead + // and the ~8× varint amplification in the worst case. + const limitBytes = 4 << 20 + _, err = protoutils.UnmarshalWithLimit[*autopb.Block](bz, limitBytes) + require.NoError(t, err, "legitimately max-sized block must be accepted") +} + +// amplifiedPayload builds a wire payload for OuterNotSized: 10k empty +// SizedOk message entries. Each encodes as 2 bytes on the wire (tag + length 0) +// but would allocate a full SizedOk struct during proto.Unmarshal. +// Total wire size: ~20KB. Total allocation without the limit guard: many MB. +func amplifiedPayload() []byte { + var bz []byte + for range 10_000 { + bz = protowire.AppendTag(bz, 2, protowire.BytesType) // OuterNotSized.b (repeated SizedOk) + bz = protowire.AppendBytes(bz, nil) // empty SizedOk + } + return bz +} + +// BenchmarkUnmarshalWithLimit_MaxBlock measures the overhead of allocEstimate +// on a max-sized Block wire payload. The pre-scan should be cheap relative to +// proto.Unmarshal itself. +func BenchmarkUnmarshalWithLimit_MaxBlock(b *testing.B) { + block := maxBlock() + bz, err := proto.Marshal(block) + require.NoError(b, err) + b.Logf("wire size: %d bytes", len(bz)) + + const limitBytes = 4 << 20 + b.ResetTimer() + b.SetBytes(int64(len(bz))) + for range b.N { + _, err := protoutils.UnmarshalWithLimit[*autopb.Block](bz, limitBytes) + if err != nil { + b.Fatal(err) + } + } +} + +// BenchmarkUnmarshalWithLimit_AmplifiedPayload measures the extreme bad case: +// 10k empty repeated-message entries that are ~20KB on the wire but would +// allocate many MB of Go structs during proto.Unmarshal. allocEstimate must +// catch and reject this quickly. Unmarshal is never called for rejected messages. +func BenchmarkUnmarshalWithLimit_AmplifiedPayload(b *testing.B) { + bz := amplifiedPayload() + b.Logf("wire size: %d bytes", len(bz)) + + const limitBytes = 1 << 20 // 1MB + b.ResetTimer() + b.SetBytes(int64(len(bz))) + for range b.N { + _, err := protoutils.UnmarshalWithLimit[*pb.OuterNotSized](bz, limitBytes) + if err == nil { + b.Fatal("amplified payload must be rejected") + } + } +} diff --git a/sei-tendermint/internal/protoutils/alloc_scan_test.go b/sei-tendermint/internal/protoutils/alloc_scan_test.go new file mode 100644 index 0000000000..b218900edf --- /dev/null +++ b/sei-tendermint/internal/protoutils/alloc_scan_test.go @@ -0,0 +1,303 @@ +package protoutils_test + +import ( + "fmt" + "testing" + + gogoproto "github.com/gogo/protobuf/proto" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" + + "github.com/sei-protocol/sei-chain/sei-tendermint/internal/protoutils" + "github.com/sei-protocol/sei-chain/sei-tendermint/internal/protoutils/test/a/pb" + tmproto "github.com/sei-protocol/sei-chain/sei-tendermint/proto/tendermint/types" +) + +// TestUnmarshalWithLimit_SmallMessageAccepted verifies that a legitimate small +// message is accepted when the limit is generous. +func TestUnmarshalWithLimit_SmallMessageAccepted(t *testing.T) { + msg := &pb.OuterNotSized{ + B: make([]*pb.SizedOk, 3), + } + for i := range msg.B { + msg.B[i] = &pb.SizedOk{} + } + bz := protoutils.Marshal(msg) + _, err := protoutils.UnmarshalWithLimit[*pb.OuterNotSized](bz, 1<<20 /* 1MB */) + require.NoError(t, err) +} + +// TestUnmarshalWithLimit_ManyEmptyEntriesRejected verifies the core amplification +// scenario: many empty repeated-message entries are tiny on the wire but each +// cause a Go heap allocation. The limit catches this before proto.Unmarshal runs. +func TestUnmarshalWithLimit_ManyEmptyEntriesRejected(t *testing.T) { + // 10_000 empty SizedOk entries — small on the wire but each allocates a struct. + msg := &pb.OuterNotSized{B: make([]*pb.SizedOk, 10_000)} + for i := range msg.B { + msg.B[i] = &pb.SizedOk{} + } + bz := protoutils.Marshal(msg) + require.Less(t, len(bz), 1<<20, "wire bytes should be well under 1MB") + + _, err := protoutils.UnmarshalWithLimit[*pb.OuterNotSized](bz, 1<<20 /* 1MB */) + require.Error(t, err, "10k empty entries should exceed the 1MB allocation estimate") +} + +// TestUnmarshalWithLimit_ZeroLimitPanics verifies that limitBytes=0 panics. +// A zero (or negative) limit is a programming error, not a runtime condition. +func TestUnmarshalWithLimit_ZeroLimitPanics(t *testing.T) { + msg := &pb.OuterNotSized{B: []*pb.SizedOk{{}}} + bz := protoutils.Marshal(msg) + require.Panics(t, func() { + _, _ = protoutils.UnmarshalWithLimit[*pb.OuterNotSized](bz, 0) + }) + require.Panics(t, func() { + _, _ = protoutils.UnmarshalWithLimit[*pb.OuterNotSized](bz, -1) + }) +} + +// TestUnmarshalWithLimit_ResultIsCorrect verifies that a message passing the +// limit check is correctly unmarshalled. +func TestUnmarshalWithLimit_ResultIsCorrect(t *testing.T) { + msg := &pb.Msg{StringValue: "hello", RepeatedValue: []string{"a", "b"}} + bz := protoutils.Marshal(msg) + got, err := protoutils.UnmarshalWithLimit[*pb.Msg](bz, 1<<20) + require.NoError(t, err) + require.Equal(t, "hello", got.StringValue) + require.Equal(t, []string{"a", "b"}, got.RepeatedValue) +} + +// TestUnmarshalWithLimit_LargePayloadRejected verifies that a single large +// bytes field is rejected when it exceeds the limit. +func TestUnmarshalWithLimit_LargePayloadRejected(t *testing.T) { + msg := &pb.NotSized{LargeField: make([]byte, 2<<20 /* 2MB */)} + bz := protoutils.Marshal(msg) + _, err := protoutils.UnmarshalWithLimit[*pb.NotSized](bz, 1<<20 /* 1MB */) + require.Error(t, err) +} + +// TestUnmarshalWithLimit_PackedVarintAmplificationCounted verifies that packed +// repeated uint64 fields with small values are not undercounted. Each value +// encodes as 1 byte on the wire but occupies 8 bytes in the Go slice, giving +// up to 8× amplification that must be accounted for. +func TestUnmarshalWithLimit_PackedVarintAmplificationCounted(t *testing.T) { + // 200k uint64 values of 0: each encodes as 1 varint byte = ~200KB wire, + // but the Go slice is 200k×8 = ~1.6MB. + vals := make([]uint64, 200_000) + msg := &pb.SizedOk{U64Count: vals[:4]} // SizedOk.U64Count is repeated uint64 + // Build a message with a large packed uint64 field using raw wire bytes + // since SizedOk caps U64Count at 4. Use field 11 (u64_count) directly. + var bz []byte + bz = protowire.AppendTag(bz, 11, protowire.BytesType) + var packed []byte + for range 200_000 { + packed = protowire.AppendVarint(packed, 0) + } + bz = protowire.AppendBytes(bz, packed) + _ = msg + + require.Less(t, len(bz), 1<<20, "wire bytes should be under 1MB") + _, err := protoutils.UnmarshalWithLimit[*pb.SizedOk](bz, 1<<20 /* 1MB */) + require.Error(t, err, "200k packed uint64 values should exceed 1MB allocation estimate due to 8x wire-to-Go amplification") +} + +// TestUnmarshalWithLimit_UnknownBytesFieldCounted verifies that a large unknown +// bytes field (field number not in the schema) is counted toward the limit. +// proto.Unmarshal stores unknown fields verbatim, allocating exactly len(val) +// bytes, so our estimate must catch this even without knowing the field type. +func TestUnmarshalWithLimit_UnknownBytesFieldCounted(t *testing.T) { + // Field 999 is unknown to pb.NotSized (which only has field 1). + var bz []byte + bz = protowire.AppendTag(bz, 999, protowire.BytesType) + bz = protowire.AppendBytes(bz, make([]byte, 2<<20 /* 2MB */)) + + _, err := protoutils.UnmarshalWithLimit[*pb.NotSized](bz, 1<<20 /* 1MB */) + require.Error(t, err, "large unknown bytes field must be counted toward the limit") +} + +// TestUnmarshalGogoWithLimit_ManyEmptySignaturesRejected verifies that +// UnmarshalGogoWithLimit protects gogoproto-generated Tendermint P2P types. +// A Commit with 100k empty CommitSig entries is tiny on the wire but would +// allocate many structs during Unmarshal. +func TestUnmarshalGogoWithLimit_ManyEmptySignaturesRejected(t *testing.T) { + // CommitSig has a non-nullable time.Time timestamp that encodes non-zero + // even when empty, so wire bytes grow faster than a purely empty message. + // Use 10k entries: small enough to stay comfortably under 1MB on the wire + // while still causing enough struct allocations to exceed the 1MB limit. + msg := &tmproto.Commit{Signatures: make([]tmproto.CommitSig, 10_000)} + bz, err := gogoproto.Marshal(msg) + require.NoError(t, err) + require.Less(t, len(bz), 1<<20, "wire bytes should be well under 1MB") + + out := &tmproto.Commit{} + err = protoutils.UnmarshalGogoWithLimit(bz, out, 1<<20 /* 1MB */) + require.Error(t, err, "10k CommitSig entries should exceed the 1MB allocation estimate") +} + +// TestUnmarshalGogoWithLimit_SingularFieldMerge documents and verifies that +// allocEstimate accumulates all wire occurrences of a singular field. +// +// Protobuf allows a singular field to appear multiple times on the wire; +// gogoproto merges them by appending repeated sub-fields. wireguard Scan +// checks each occurrence independently and passes each one. allocEstimate +// recurses into every occurrence and accumulates the totals, so the budget +// is consumed proportionally to the true decoded size. +func TestUnmarshalGogoWithLimit_SingularFieldMerge(t *testing.T) { + // Build 500 Commit blobs, each with 50 CommitSig entries. Each occurrence + // is individually small (~100 bytes wire), but gogoproto would merge them + // into a single Commit with 25,000 signatures. Total wire size: ~50KB. + commitPerOccurrence := &tmproto.Commit{Signatures: make([]tmproto.CommitSig, 50)} + commitBz, err := gogoproto.Marshal(commitPerOccurrence) + require.NoError(t, err) + + // Build a raw Proposal wire payload with last_commit (field 10) repeated + // 500 times. Each occurrence is individually small. + var proposalBz []byte + for range 500 { + proposalBz = protowire.AppendTag(proposalBz, 10, protowire.BytesType) + proposalBz = protowire.AppendBytes(proposalBz, commitBz) + } + require.Less(t, len(proposalBz), 1<<20, "wire bytes should be well under 1MB") + + // allocEstimate recurses into all 500 wire occurrences of last_commit, + // counting 500 × 50 CommitSig structs — exceeding the 1MB limit. + out := &tmproto.Proposal{} + err = protoutils.UnmarshalGogoWithLimit(proposalBz, out, 1<<20 /* 1MB */) + require.Error(t, err, "500 occurrences × 50 signatures should exceed the 1MB allocation estimate") +} + +// TestUnmarshalGogoWithLimit_SmallCommitAccepted verifies that a legitimate +// Commit with a handful of signatures passes the limit check. +func TestUnmarshalGogoWithLimit_SmallCommitAccepted(t *testing.T) { + msg := &tmproto.Commit{Height: 100, Signatures: make([]tmproto.CommitSig, 10)} + bz, err := gogoproto.Marshal(msg) + require.NoError(t, err) + + out := &tmproto.Commit{} + err = protoutils.UnmarshalGogoWithLimit(bz, out, 1<<20 /* 1MB */) + require.NoError(t, err) + require.Equal(t, int64(100), out.Height) +} + +// TestUnmarshalWithLimit_WireTypeMismatchNoPanic verifies that wire bytes +// presenting a repeated message field with a mismatched scalar wire type do +// not panic. scalarElementSize panics on MessageKind; the mismatch is instead +// counted as unknown-field bytes so the process stays alive. +func TestUnmarshalWithLimit_WireTypeMismatchNoPanic(t *testing.T) { + // Field 2 of OuterNotSized is repeated SizedOk (MessageKind), but we encode + // it as a varint — a wire type mismatch. Each occurrence is stored in the + // unknown-fields blob (~2 bytes). + var one []byte + one = protowire.AppendTag(one, 2, protowire.VarintType) + one = protowire.AppendVarint(one, 42) + + // A single mismatch: well within limit, must not panic. + _, err := protoutils.UnmarshalWithLimit[*pb.OuterNotSized](one, 1<<20) + require.NoError(t, err) + + // Many mismatches: each contributes ~2 bytes to the estimate, so enough + // occurrences must exceed the 1 byte limit, proving they are counted. + var many []byte + for range 1_000_000 { + many = protowire.AppendTag(many, 2, protowire.VarintType) + many = protowire.AppendVarint(many, 42) + } + _, err = protoutils.UnmarshalWithLimit[*pb.OuterNotSized](many, 1) + require.Error(t, err, "1M mismatched-type occurrences must exceed a 1-byte limit") +} + +// TestUnmarshalWithLimit_LargeMapRejected verifies that a message with many map +// entries is rejected when the entries' total allocation would exceed the limit. +// Map fields encode as repeated message entries; Go allocates runtime map +// overhead per entry beyond the key+value content. +func TestUnmarshalWithLimit_LargeMapRejected(t *testing.T) { + // 50k entries in a map: small values but each entry adds + // string headers + key + value + map runtime overhead. + msg := &pb.Msg{MapValue: make(map[string]string, 50_000)} + for i := range 50_000 { + k := fmt.Sprintf("key%d", i) + msg.MapValue[k] = "v" + } + bz := protoutils.Marshal(msg) + _, err := protoutils.UnmarshalWithLimit[*pb.Msg](bz, 1<<20 /* 1MB */) + require.Error(t, err, "50k map entries should exceed the 1MB allocation estimate") +} + +// TestUnmarshalWithLimit_SmallMapAccepted verifies that a small map passes. +func TestUnmarshalWithLimit_SmallMapAccepted(t *testing.T) { + msg := &pb.Msg{MapValue: map[string]string{"a": "b", "c": "d"}} + bz := protoutils.Marshal(msg) + _, err := protoutils.UnmarshalWithLimit[*pb.Msg](bz, 1<<20 /* 1MB */) + require.NoError(t, err) +} + +// TestUnmarshalWithLimit_TruncatedInputReturnsError verifies that wire bytes +// cut off mid-field return an error rather than panicking or silently +// accepting partial data. Truncation surfaces as protowire.ParseError +// ("unexpected end of data"), the same path as corrupt wire bytes. +// +// Note: a prefix that ends exactly on a field boundary is a valid shorter +// proto message (proto has no end-of-message marker), so we construct inputs +// that are definitely cut mid-field. +func TestUnmarshalWithLimit_TruncatedInputReturnsError(t *testing.T) { + // Case 1: tag present but bytes-field length prefix missing. + // A lone tag byte for a BytesType field with no following length varint. + var tagOnly []byte + tagOnly = protowire.AppendTag(tagOnly, 1, protowire.BytesType) + _, err := protoutils.UnmarshalWithLimit[*pb.NotSized](tagOnly, 1<<20) + require.Error(t, err, "tag with no value should return an error") + + // Case 2: bytes-field length prefix present but payload truncated. + // Claim 100 bytes follow, but provide only 10. + var truncBytes []byte + truncBytes = protowire.AppendTag(truncBytes, 1, protowire.BytesType) + truncBytes = protowire.AppendVarint(truncBytes, 100) // length = 100 + truncBytes = append(truncBytes, make([]byte, 10)...) // only 10 bytes + _, err = protoutils.UnmarshalWithLimit[*pb.NotSized](truncBytes, 1<<20) + require.Error(t, err, "truncated bytes payload should return an error") + + // Case 3: varint field tag present but varint value truncated mid-byte. + // A varint with the MSB set signals continuation; cut before the last byte. + var truncVarint []byte + truncVarint = protowire.AppendTag(truncVarint, 1, protowire.VarintType) + truncVarint = append(truncVarint, 0x80) // first varint byte with continuation bit set + _, err = protoutils.UnmarshalWithLimit[*pb.Msg](truncVarint, 1<<20) + require.Error(t, err, "truncated mid-varint should return an error") +} + +// TestUnmarshalWithLimit_UnpackedRepeatedScalarSliceHeaderCounted verifies that +// non-packed repeated scalar fields (Fixed64Type wire encoding) include the +// slice header in the allocation estimate. Each occurrence contributes +// sliceHeaderSize + elementSize, not just elementSize. +func TestUnmarshalWithLimit_UnpackedRepeatedScalarSliceHeaderCounted(t *testing.T) { + // SizedOk.f64_count is repeated fixed64 (field 14). fixed64 always uses + // Fixed64Type wire encoding — never packed. Each occurrence costs 8 bytes + // element + 24 bytes slice header = 32 bytes in the estimate. + // 100k occurrences × 32 = ~3.2MB, well over the 1MB limit. + // Wire size: 100k × (1 tag + 8 value) = ~900KB, under 1MB. + var bz []byte + for range 100_000 { + bz = protowire.AppendTag(bz, 14, protowire.Fixed64Type) + bz = protowire.AppendFixed64(bz, 0) + } + require.Less(t, len(bz), 1<<20, "wire bytes should be under 1MB") + _, err := protoutils.UnmarshalWithLimit[*pb.SizedOk](bz, 1<<20) + require.Error(t, err, "100k unpacked fixed64 elements should exceed 1MB due to per-occurrence slice header cost") +} + +// TestUnmarshalWithLimit_SmallUnknownFieldsAccepted verifies that small unknown +// scalar fields (varint, fixed32, fixed64) are accepted within a generous limit. +func TestUnmarshalWithLimit_SmallUnknownFieldsAccepted(t *testing.T) { + var bz []byte + for i := protowire.Number(100); i < 200; i++ { + bz = protowire.AppendTag(bz, i, protowire.VarintType) + bz = protowire.AppendVarint(bz, 42) + bz = protowire.AppendTag(bz, i+1000, protowire.Fixed32Type) + bz = protowire.AppendFixed32(bz, 0xdeadbeef) + bz = protowire.AppendTag(bz, i+2000, protowire.Fixed64Type) + bz = protowire.AppendFixed64(bz, 0xdeadbeefcafe) + } + + _, err := protoutils.UnmarshalWithLimit[*pb.NotSized](bz, 1<<20 /* 1MB */) + require.NoError(t, err, "small unknown scalar fields should be well within a 1MB limit") +} diff --git a/sei-tendermint/internal/protoutils/msg.go b/sei-tendermint/internal/protoutils/msg.go index 87d12d7720..f3d1630a54 100644 --- a/sei-tendermint/internal/protoutils/msg.go +++ b/sei-tendermint/internal/protoutils/msg.go @@ -1,9 +1,11 @@ package protoutils import ( + "fmt" "reflect" gogoproto "github.com/gogo/protobuf/proto" + golangproto "github.com/golang/protobuf/proto" //nolint:staticcheck // MessageReflect is the only bridge from gogoproto to protoreflect.Message "google.golang.org/protobuf/proto" "github.com/sei-protocol/sei-chain/sei-tendermint/internal/protoutils/runtime" @@ -40,6 +42,60 @@ func Unmarshal[T Message](bytes []byte) (T, error) { return t, err } +// UnmarshalWithLimit estimates the heap allocation that proto.Unmarshal would +// make for bytes and returns an error if the estimate exceeds limitBytes. +// This bounds allocation amplification where a small wire payload encodes many +// empty repeated-field entries, each causing a Go heap allocation. The estimate +// is conservative (may over-count) so legitimate messages must stay well within +// the limit. +func UnmarshalWithLimit[T Message](bytes []byte, limitBytes int) (T, error) { + if limitBytes <= 0 { + panic(fmt.Sprintf("protoutils: limitBytes must be positive, got %d", limitBytes)) + } + if err := Scan[T](bytes); err != nil { + return utils.Zero[T](), err + } + desc := New[T]().ProtoReflect().Descriptor() + est, err := allocEstimate(bytes, desc) + if err != nil { + return utils.Zero[T](), fmt.Errorf("protoutils: alloc scan: %w", err) + } + if est > limitBytes { + return utils.Zero[T](), fmt.Errorf("protoutils: message would allocate ~%d bytes, limit is %d", est, limitBytes) + } + t := New[T]() + if err := proto.Unmarshal(bytes, t); err != nil { + return utils.Zero[T](), err + } + return t, nil +} + +// UnmarshalGogoWithLimit is the gogoproto variant of UnmarshalWithLimit. +// It uses github.com/golang/protobuf's reflection bridge to obtain the +// protoreflect.MessageDescriptor from a gogoproto-generated type (which does +// not implement google.golang.org/protobuf/proto.Message directly), allowing +// the same allocEstimate walk to protect Tendermint P2P messages. +func UnmarshalGogoWithLimit(bz []byte, msg gogoproto.Message, limitBytes int) error { + if limitBytes <= 0 { + panic(fmt.Sprintf("protoutils: limitBytes must be positive, got %d", limitBytes)) + } + if msg == nil { + return fmt.Errorf("protoutils: nil message") + } + if err := ScanAny(bz, msg); err != nil { + return err + } + desc := golangproto.MessageReflect(msg).Descriptor() //nolint:staticcheck + est, err := allocEstimate(bz, desc) + if err != nil { + return fmt.Errorf("protoutils: alloc scan: %w", err) + } + if est > limitBytes { + return fmt.Errorf("protoutils: message would allocate ~%d bytes, limit is %d", est, limitBytes) + } + return gogoproto.Unmarshal(bz, msg) +} + // Scan walks bz once, applying the schema registered for T. Returns nil on // success, an error on malformed wire bytes or a rule violation. If T has no // registered schema, Scan is a no-op.