From eb1519b0b4d97665d4a30a91bbc686edcac13176 Mon Sep 17 00:00:00 2001 From: Josh Humphries Date: Tue, 27 Jun 2017 17:57:33 -0400 Subject: [PATCH] unmarshal unknown extensions into XXX_unrecognized instead of into extension map --- proto/decode.go | 23 ++++++++++--- proto/extensions.go | 8 +++-- proto/extensions_test.go | 71 +++++++++++++++++++++++++++++++++++----- 3 files changed, 85 insertions(+), 17 deletions(-) diff --git a/proto/decode.go b/proto/decode.go index aa207298f9..ee144df013 100644 --- a/proto/decode.go +++ b/proto/decode.go @@ -462,6 +462,8 @@ func (p *Buffer) Unmarshal(pb Message) error { func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group bool, base structPointer) error { var state errorState required, reqFields := prop.reqCount, uint64(0) + var regExt map[int32]*ExtensionDesc + regExtInit := false var err error for err == nil && o.index < len(o.buf) { @@ -492,11 +494,22 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group // Maybe it's an extension? if prop.extendable { if e, _ := extendable(structPointer_Interface(base, st)); isExtensionField(e, int32(tag)) { - if err = o.skip(st, tag, wire); err == nil { - extmap := e.extensionsWrite() - ext := extmap[int32(tag)] // may be missing - ext.enc = append(ext.enc, o.buf[oi:o.index]...) - extmap[int32(tag)] = ext + if !regExtInit { + msgType := reflect.Zero(reflect.PtrTo(st)).Interface().(Message) + regExt = RegisteredExtensions(msgType) + } + extdesc := regExt[int32(tag)] + if extdesc == nil { + // unknown extension + err = o.skipAndSave(st, tag, wire, base, prop.unrecField) + } else { + if err = o.skip(st, tag, wire); err == nil { + extmap := e.extensionsWrite() + ext := extmap[int32(tag)] // may be missing + ext.enc = append(ext.enc, o.buf[oi:o.index]...) + ext.desc = extdesc + extmap[int32(tag)] = ext + } } continue } diff --git a/proto/extensions.go b/proto/extensions.go index eaad218312..e5665ee5ae 100644 --- a/proto/extensions.go +++ b/proto/extensions.go @@ -491,8 +491,10 @@ func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, e } // ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order. -// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing -// just the Field field, which defines the extension's field number. +// If the message was de-serialized from a stream that referenced unknown extensions (e.g. fields +// with a tag number in an extension range, but not registered), they will not be returned by +// this function. Instead, they can only be found by examining the message's XXX_unrecognized +// data. func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) { epb, ok := extendable(pb) if !ok { @@ -512,7 +514,7 @@ func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) { if desc == nil { desc = registeredExtensions[extid] if desc == nil { - desc = &ExtensionDesc{Field: extid} + continue } } diff --git a/proto/extensions_test.go b/proto/extensions_test.go index b6d9114c56..9668ae625c 100644 --- a/proto/extensions_test.go +++ b/proto/extensions_test.go @@ -34,6 +34,7 @@ package proto_test import ( "bytes" "fmt" + "math" "reflect" "sort" "testing" @@ -41,6 +42,7 @@ import ( "github.com/golang/protobuf/proto" pb "github.com/golang/protobuf/proto/testdata" "golang.org/x/sync/errgroup" + "io" ) func TestGetExtensionsWithMissingExtensions(t *testing.T) { @@ -64,27 +66,40 @@ func TestGetExtensionsWithMissingExtensions(t *testing.T) { } } -func TestExtensionDescsWithMissingExtensions(t *testing.T) { +func TestExtensionDescsWithUnrecognizedExtensions(t *testing.T) { msg := &pb.MyMessage{Count: proto.Int32(0)} - extdesc1 := pb.E_Ext_More if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil { t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err) } - ext1 := &pb.Ext{} - if err := proto.SetExtension(msg, extdesc1, ext1); err != nil { + extdesc1 := pb.E_Ext_More + if err := proto.SetExtension(msg, extdesc1, &pb.Ext{}); err != nil { t.Fatalf("Could not set ext1: %s", err) } - extdesc2 := &proto.ExtensionDesc{ + extdesc2 := pb.E_Ext_Number + if err := proto.SetExtension(msg, extdesc2, proto.Int32(int32(101))); err != nil { + t.Fatalf("Could not set ext2: %s", err) + } + + unknownExtdesc1 := &proto.ExtensionDesc{ ExtendedType: (*pb.MyMessage)(nil), ExtensionType: (*bool)(nil), Field: 123456789, Name: "a.b", Tag: "varint,123456789,opt", } - ext2 := proto.Bool(false) - if err := proto.SetExtension(msg, extdesc2, ext2); err != nil { - t.Fatalf("Could not set ext2: %s", err) + if err := proto.SetExtension(msg, unknownExtdesc1, proto.Bool(true)); err != nil { + t.Fatalf("Could not set unknownExtdesc1: %s", err) + } + unknownExtdesc2 := &proto.ExtensionDesc{ + ExtendedType: (*pb.MyMessage)(nil), + ExtensionType: (*float64)(nil), + Field: 123456790, + Name: "a.c", + Tag: "fixed64,123456790,opt", + } + if err := proto.SetExtension(msg, unknownExtdesc2, proto.Float64(12.34)); err != nil { + t.Fatalf("Could not set unknownExtdesc2: %s", err) } b, err := proto.Marshal(msg) @@ -100,10 +115,48 @@ func TestExtensionDescsWithMissingExtensions(t *testing.T) { t.Fatalf("proto.ExtensionDescs: got error %v", err) } sortExtDescs(descs) - wantDescs := []*proto.ExtensionDesc{extdesc1, &proto.ExtensionDesc{Field: extdesc2.Field}} + wantDescs := []*proto.ExtensionDesc{extdesc1, extdesc2} if !reflect.DeepEqual(descs, wantDescs) { t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs) } + + // make sure the unrecognized fields are serialized correctly + bb := proto.NewBuffer(msg.XXX_unrecognized) + + // unrecognized extension #1 + expectedTagAndWire := uint64((unknownExtdesc1.Field << 3) | proto.WireVarint) + if tagAndWire, err := bb.DecodeVarint(); err != nil { + t.Fatalf("Could not read unrecognized field tag and wire type: %v", err) + } else if tagAndWire != expectedTagAndWire { + t.Fatalf("Wrong tag and wire type: %d != %d", tagAndWire, expectedTagAndWire) + } + if val, err := bb.DecodeVarint(); err != nil { + t.Fatalf("Could not read unrecognized field value: %v", err) + } else if val != 1 /* varint value of bool "true" */ { + t.Fatalf("Wrong value for unrecognized extension 1: %d != 1", val) + } + + // unrecognized extension #2 + expectedTagAndWire = uint64((unknownExtdesc2.Field << 3) | proto.WireFixed64) + if tagAndWire, err := bb.DecodeVarint(); err != nil { + t.Fatalf("Could not read unrecognized field tag and wire type: %v", err) + } else if tagAndWire != expectedTagAndWire { + t.Fatalf("Wrong tag and wire type: %d != %d", tagAndWire, expectedTagAndWire) + } + if val, err := bb.DecodeFixed64(); err != nil { + t.Fatalf("Could not read unrecognized field value: %v", err) + } else if math.Float64frombits(val) != 12.34 { + t.Fatalf("Wrong value for unrecognized extension 1: %f != 12.34", math.Float64frombits(val)) + } + + // we should have reached EOF of the unknown fields + if _, err := bb.DecodeFixed32(); err != io.ErrUnexpectedEOF { + if err == nil { + t.Fatalf("Unexpected unrecognized data after expected extensions") + } else { + t.Fatalf("Unexpected error checking for buffer EOF: %v", err) + } + } } type ExtensionDescSlice []*proto.ExtensionDesc