-
Notifications
You must be signed in to change notification settings - Fork 881
feat(protoutils): UnmarshalWithLimit — pre-decode allocation estimate #3615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
wen-coding
wants to merge
3
commits into
main
Choose a base branch
from
wen/unmarshal_with_limit
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+752
−0
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,302 @@ | ||
| package protoutils | ||
|
|
||
| import ( | ||
| "fmt" | ||
| "reflect" | ||
|
|
||
| gogoproto "github.com/gogo/protobuf/proto" | ||
| "google.golang.org/protobuf/encoding/protowire" | ||
| "google.golang.org/protobuf/reflect/protoreflect" | ||
| "google.golang.org/protobuf/reflect/protoregistry" | ||
| ) | ||
|
|
||
| var ( | ||
| pointerSize = int(reflect.TypeFor[*byte]().Size()) // 8 on 64-bit | ||
| sliceHeaderSize = int(reflect.TypeFor[[]byte]().Size()) // 24 on 64-bit | ||
| stringHeaderSize = int(reflect.TypeFor[string]().Size()) // 16 on 64-bit | ||
|
|
||
| // mapEntryOverhead is added per map entry to account for Go runtime map | ||
| // internals (hmap struct, bucket array slots, tophash bytes) that are not | ||
| // captured by the map-entry message struct size. Adding it per entry | ||
| // over-counts the fixed hmap header for multi-entry maps, which is | ||
| // conservative. 8×pointerSize matches the number of fields in runtime.hmap. | ||
| mapEntryOverhead = 8 * int(reflect.TypeFor[*byte]().Size()) | ||
| ) | ||
|
|
||
| // allocEstimate walks raw protobuf wire bytes and returns a conservative | ||
| // upper-bound on the heap bytes that proto.Unmarshal would allocate. | ||
| // | ||
| // The estimate accounts for: | ||
| // - the Go struct for each message occurrence (looked up from protoregistry) | ||
| // - backing arrays for bytes and string fields | ||
| // - pointer elements in repeated-message slices | ||
| // | ||
| // Unknown fields (field numbers not in the descriptor) are stored verbatim by | ||
| // proto.Unmarshal in a single raw []byte blob on the struct without decoding. | ||
| // Their allocation cost equals their wire size, so we add the wire bytes for | ||
| // each unknown field occurrence. This is exact for bytes-type unknown fields | ||
| // and a slight over-count for scalar unknown fields (which store tag+value | ||
| // together), but both are correct in the conservative direction. | ||
| // | ||
| // The function returns an error on corrupt or truncated wire bytes. Truncation | ||
| // (n < 0 from a Consume* call) surfaces as protowire.ParseError("unexpected | ||
| // end of data") and is treated identically to corruption — the caller receives | ||
| // an error and Unmarshal is never called. | ||
| func allocEstimate(data []byte, desc protoreflect.MessageDescriptor) (int, error) { | ||
| total := msgStructSize(desc) | ||
|
|
||
| for len(data) > 0 { | ||
| num, typ, tagLen := protowire.ConsumeTag(data) | ||
| if tagLen <= 0 { | ||
| return 0, fmt.Errorf("tag: %w", protowire.ParseError(tagLen)) | ||
| } | ||
| if num == 0 { | ||
| // Field number 0 is reserved and illegal in the protobuf spec. | ||
| return 0, fmt.Errorf("invalid field number 0") | ||
| } | ||
| data = data[tagLen:] | ||
|
|
||
| fd := desc.Fields().ByNumber(num) | ||
|
|
||
| switch typ { | ||
| case protowire.BytesType: | ||
| val, n := protowire.ConsumeBytes(data) | ||
| if n <= 0 { | ||
| return 0, fmt.Errorf("field %d bytes: %w", num, protowire.ParseError(n)) | ||
| } | ||
| data = data[n:] | ||
| add, err := bytesFieldSize(tagLen, n, val, fd) | ||
| if err != nil { | ||
| return 0, err | ||
| } | ||
| total += add | ||
|
|
||
| case protowire.VarintType: | ||
| _, n := protowire.ConsumeVarint(data) | ||
| if n <= 0 { | ||
| return 0, fmt.Errorf("field %d varint: %w", num, protowire.ParseError(n)) | ||
| } | ||
| data = data[n:] | ||
| if fd == nil || !isVarintKind(fd.Kind()) { | ||
| // Unknown field or known field with wrong wire type: proto.Unmarshal | ||
| // stores it verbatim in the unknown fields blob. | ||
| total += tagLen + n | ||
| } else if fd.IsList() { | ||
| total += scalarElementSize(fd.Kind()) | ||
| } | ||
|
|
||
| case protowire.Fixed32Type: | ||
| _, n := protowire.ConsumeFixed32(data) | ||
| if n <= 0 { | ||
| return 0, fmt.Errorf("field %d fixed32: %w", num, protowire.ParseError(n)) | ||
| } | ||
| data = data[n:] | ||
| if fd == nil || !isFixed32Kind(fd.Kind()) { | ||
| // Unknown field or known field with wrong wire type: proto.Unmarshal | ||
| // stores it verbatim in the unknown fields blob. | ||
| total += tagLen + n | ||
| } else if fd.IsList() { | ||
| total += scalarElementSize(fd.Kind()) | ||
| } | ||
|
|
||
| case protowire.Fixed64Type: | ||
| _, n := protowire.ConsumeFixed64(data) | ||
| if n <= 0 { | ||
| return 0, fmt.Errorf("field %d fixed64: %w", num, protowire.ParseError(n)) | ||
| } | ||
| data = data[n:] | ||
| if fd == nil || !isFixed64Kind(fd.Kind()) { | ||
| // Unknown field or known field with wrong wire type: proto.Unmarshal | ||
| // stores it verbatim in the unknown fields blob. | ||
| total += tagLen + n | ||
| } else if fd.IsList() { | ||
| total += scalarElementSize(fd.Kind()) | ||
| } | ||
|
cursor[bot] marked this conversation as resolved.
|
||
|
|
||
| case protowire.StartGroupType: | ||
| val, n := protowire.ConsumeGroup(num, data) | ||
| if n <= 0 { | ||
| return 0, fmt.Errorf("field %d group: %w", num, protowire.ParseError(n)) | ||
| } | ||
| data = data[n:] | ||
| if fd != nil && (fd.Kind() == protoreflect.GroupKind) { | ||
| sub, err := allocEstimate(val, fd.Message()) | ||
| if err != nil { | ||
| return 0, err | ||
| } | ||
| total += sub | ||
| } else { | ||
| total += tagLen + n | ||
| } | ||
|
|
||
| default: | ||
| return 0, fmt.Errorf("unknown wire type %d at field %d", typ, num) | ||
| } | ||
| } | ||
| return total, nil | ||
| } | ||
|
|
||
| // bytesFieldSize returns the allocation estimate for one BytesType wire record. | ||
| // tagLen+n is the full wire record size (tag + length varint + payload). | ||
| // val is the payload slice (without tag or length prefix). | ||
| // fd is nil for unknown fields. | ||
| func bytesFieldSize(tagLen, n int, val []byte, fd protoreflect.FieldDescriptor) (int, error) { | ||
| if fd == nil { | ||
| // Unknown field: proto.Unmarshal appends the full wire record to the | ||
| // unknown-fields blob. tagLen+n covers tag + length varint + payload. | ||
| return tagLen + n, nil | ||
| } | ||
|
|
||
| total := 0 | ||
| if fd.IsMap() { | ||
| // Map fields: Go allocates a runtime map (hmap), not a slice. Add | ||
| // per-entry overhead for hmap fields, bucket slots, and tophash. | ||
| // Over-counts the fixed hmap header across N entries, which is | ||
| // conservative. | ||
| total += mapEntryOverhead | ||
| } else if fd.IsList() { | ||
| // Repeated fields: one slice header per field. Adding it per | ||
| // occurrence over-counts by (N-1)*sliceHeaderSize across N elements | ||
| // — noise near a 1MB limit. | ||
| total += sliceHeaderSize | ||
| } | ||
|
|
||
| switch fd.Kind() { | ||
| case protoreflect.MessageKind, protoreflect.GroupKind: | ||
| if fd.IsList() { | ||
| total += pointerSize // pointer element in the backing array | ||
| } | ||
| sub, err := allocEstimate(val, fd.Message()) | ||
| if err != nil { | ||
| return 0, err | ||
| } | ||
| total += sub | ||
| case protoreflect.BytesKind: | ||
| total += sliceHeaderSize + len(val) | ||
| case protoreflect.StringKind: | ||
| total += stringHeaderSize + len(val) | ||
| default: | ||
| // Packed repeated scalar. Fixed-width kinds have equal wire and Go | ||
| // sizes so len(val) is exact. Varint kinds (bool, int32, int64, | ||
| // uint32, uint64, sint32, sint64, enum) encode small values in 1 byte | ||
| // on the wire while occupying 4 or 8 bytes in the Go slice — up to 8× | ||
| // amplification. We count elements and multiply by Go element size. | ||
| packed, err := packedAllocSize(val, fd.Kind()) | ||
| if err != nil { | ||
| return 0, err | ||
| } | ||
| total += packed | ||
| } | ||
| return total, nil | ||
| } | ||
|
|
||
| // packedAllocSize returns the estimated Go heap bytes for a packed repeated | ||
| // scalar field whose raw wire bytes are bz. | ||
| // | ||
| // Fixed-width wire types (float, double, fixed32, fixed64, sfixed32, sfixed64) | ||
| // have the same size on the wire and in Go, so len(bz) is exact. | ||
| // | ||
| // Varint-encoded types (bool, int32, int64, uint32, uint64, sint32, sint64, | ||
| // enum) can encode small values in as little as 1 byte while each element | ||
| // occupies 4 or 8 bytes in the Go slice backing array. We walk the payload | ||
| // counting elements and multiply by the Go element size. | ||
| func packedAllocSize(bz []byte, kind protoreflect.Kind) (int, error) { | ||
| switch kind { | ||
| case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, protoreflect.FloatKind: | ||
| return len(bz), nil // 4 bytes wire == 4 bytes Go, exact | ||
| case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind: | ||
| return len(bz), nil // 8 bytes wire == 8 bytes Go, exact | ||
| case protoreflect.BoolKind: | ||
| n, err := countVarintsInPacked(bz) | ||
| return n * 1, err // bool = 1 byte in Go | ||
| case protoreflect.Int32Kind, protoreflect.Uint32Kind, | ||
| protoreflect.Sint32Kind, protoreflect.EnumKind: | ||
| n, err := countVarintsInPacked(bz) | ||
| return n * 4, err | ||
| case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Sint64Kind: | ||
| n, err := countVarintsInPacked(bz) | ||
| return n * 8, err | ||
| default: | ||
| panic(fmt.Sprintf("packedAllocSize: unexpected kind %v", kind)) | ||
| } | ||
| } | ||
|
|
||
| // scalarElementSize returns the size in bytes of one element in the Go slice | ||
| // backing array for a repeated scalar field. | ||
| func scalarElementSize(kind protoreflect.Kind) int { | ||
| switch kind { | ||
| case protoreflect.BoolKind: | ||
| return 1 | ||
| case protoreflect.Int32Kind, protoreflect.Uint32Kind, protoreflect.Sint32Kind, | ||
| protoreflect.EnumKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, | ||
| protoreflect.FloatKind: | ||
| return 4 | ||
| case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Sint64Kind, | ||
| protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind: | ||
| return 8 | ||
| default: | ||
| panic(fmt.Sprintf("scalarElementSize: unexpected kind %v", kind)) | ||
| } | ||
| } | ||
|
|
||
| func isVarintKind(k protoreflect.Kind) bool { | ||
| switch k { | ||
| case protoreflect.BoolKind, protoreflect.EnumKind, | ||
| protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Uint32Kind, | ||
| protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Uint64Kind: | ||
| return true | ||
| } | ||
| return false | ||
| } | ||
|
|
||
| func isFixed32Kind(k protoreflect.Kind) bool { | ||
| switch k { | ||
| case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind, protoreflect.FloatKind: | ||
| return true | ||
| } | ||
| return false | ||
| } | ||
|
|
||
| func isFixed64Kind(k protoreflect.Kind) bool { | ||
| switch k { | ||
| case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind, protoreflect.DoubleKind: | ||
| return true | ||
| } | ||
| return false | ||
| } | ||
|
|
||
| // countVarintsInPacked counts the number of varint-encoded elements in a | ||
| // packed repeated field payload. | ||
| func countVarintsInPacked(bz []byte) (int, error) { | ||
| count := 0 | ||
| for len(bz) > 0 { | ||
| _, n := protowire.ConsumeVarint(bz) | ||
| if n <= 0 { | ||
| return 0, fmt.Errorf("packed varint: %w", protowire.ParseError(n)) | ||
| } | ||
| bz = bz[n:] | ||
| count++ | ||
| } | ||
| return count, nil | ||
| } | ||
|
|
||
| // msgStructSize returns the size of the Go struct backing desc. | ||
| // It tries the google protobuf v2 registry first (for protoc-gen-go types), | ||
| // then falls back to the gogoproto registry (for protoc-gen-gogofaster types | ||
| // used by Tendermint P2P). Panics if the type is not registered in either | ||
| // registry, since that indicates a programming error. | ||
| func msgStructSize(desc protoreflect.MessageDescriptor) int { | ||
| if desc.IsMapEntry() { | ||
| // Synthetic map-entry types have no standalone Go struct; the runtime map | ||
| // stores keys and values in bucket arrays. Return 0 here; mapEntryOverhead | ||
| // is added per entry at the call site to account for runtime overhead. | ||
| return 0 | ||
| } | ||
| if mt, err := protoregistry.GlobalTypes.FindMessageByName(desc.FullName()); err == nil { | ||
| return int(reflect.TypeOf(mt.Zero().Interface()).Elem().Size()) | ||
| } | ||
| if t := gogoproto.MessageType(string(desc.FullName())); t != nil { | ||
| return int(t.Elem().Size()) | ||
| } | ||
| panic(fmt.Sprintf("protoutils: message type not registered: %s", desc.FullName())) | ||
| } | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.