diff --git a/pkg/collector/process.go b/pkg/collector/process.go index a686447f..48fe2087 100644 --- a/pkg/collector/process.go +++ b/pkg/collector/process.go @@ -31,13 +31,32 @@ import ( "github.com/vmware/go-ipfix/pkg/util" ) +// DecodingMode specifies how unknown information elements (in templates) are handled when decoding. +// Unknown information elements are elements which are not part of the static registry included with +// the library. +// Note that regardless of the DecodingMode, data sets must always match the corresponding template. +type DecodingMode string + +const ( + // DecodingModeStrict will cause decoding to fail when an unknown IE is encountered in a template. + DecodingModeStrict DecodingMode = "Strict" + // DecodingModeLenientKeepUnknown will accept unknown IEs in templates. When decoding the + // corresponding field in a data record, the value will be preserved (as an octet array). + DecodingModeLenientKeepUnknown DecodingMode = "LenientKeepUnknown" + // DecodingModeLenientDropUnknown will accept unknown IEs in templates. When decoding the + // corresponding field in a data record, the value will be dropped (information element will + // not be present in the resulting Record). Be careful when using this mode as the IEs + // included in the resulting Record will no longer match the received template. + DecodingModeLenientDropUnknown DecodingMode = "LenientDropUnknown" +) + type CollectingProcess struct { // for each obsDomainID, there is a map of templates templatesMap map[uint32]map[uint16][]*entities.InfoElement // mutex allows multiple readers or one writer at the same time mutex sync.RWMutex // template lifetime - templateTTL uint32 + templateTTL time.Duration // server information address string // server protocol @@ -57,6 +76,9 @@ type CollectingProcess struct { // numExtraElements specifies number of elements that could be added after // decoding the IPFIX data packet. numExtraElements int + // decodingMode specifies how unknown information elements (in templates) are handled when + // decoding. + decodingMode DecodingMode // caCert, serverCert and serverKey are for storing encryption info when using TLS/DTLS caCert []byte serverCert []byte @@ -80,6 +102,10 @@ type CollectorInput struct { ServerCert []byte ServerKey []byte NumExtraElements int + // DecodingMode specifies how unknown information elements (in templates) are handled when + // decoding. The default value is DecodingModeStrict for historical reasons. For most uses, + // DecodingModeLenientKeepUnknown is the most appropriate mode. + DecodingMode DecodingMode } type clientHandler struct { @@ -88,10 +114,24 @@ type clientHandler struct { } func InitCollectingProcess(input CollectorInput) (*CollectingProcess, error) { + templateTTLSeconds := input.TemplateTTL + if input.Protocol == "udp" && templateTTLSeconds == 0 { + templateTTLSeconds = entities.TemplateTTL + } + templateTTL := time.Duration(templateTTLSeconds) * time.Second + decodingMode := input.DecodingMode + if decodingMode == "" { + decodingMode = DecodingModeStrict + } + klog.InfoS( + "Initializing the collecting process", + "encrypted", input.IsEncrypted, "address", input.Address, "protocol", input.Protocol, "maxBufferSize", input.MaxBufferSize, + "templateTTL", templateTTL, "numExtraElements", input.NumExtraElements, "decodingMode", decodingMode, + ) collectProc := &CollectingProcess{ templatesMap: make(map[uint32]map[uint16][]*entities.InfoElement), mutex: sync.RWMutex{}, - templateTTL: input.TemplateTTL, + templateTTL: templateTTL, address: input.Address, protocol: input.Protocol, maxBufferSize: input.MaxBufferSize, @@ -103,11 +143,13 @@ func InitCollectingProcess(input CollectorInput) (*CollectingProcess, error) { serverCert: input.ServerCert, serverKey: input.ServerKey, numExtraElements: input.NumExtraElements, + decodingMode: decodingMode, } return collectProc, nil } func (cp *CollectingProcess) Start() { + klog.Info("Starting the collecting process") if cp.protocol == "tcp" { cp.startTCPServer() } else if cp.protocol == "udp" { @@ -119,7 +161,7 @@ func (cp *CollectingProcess) Stop() { close(cp.stopChan) // wait for all connections to be safely deleted and returned cp.wg.Wait() - klog.Info("stopping the collecting process") + klog.Info("Stopping the collecting process") } func (cp *CollectingProcess) GetAddress() net.Addr { @@ -228,7 +270,11 @@ func (cp *CollectingProcess) decodeTemplateSet(templateBuffer *bytes.Buffer, obs enterpriseID = registry.IANAEnterpriseID element, err = registry.GetInfoElementFromID(elementID, enterpriseID) if err != nil { - return nil, err + if cp.decodingMode == DecodingModeStrict { + return nil, err + } + klog.InfoS("Template includes an information element that is not present in registry", "obsDomainID", obsDomainID, "templateID", templateID, "enterpriseID", enterpriseID, "elementID", elementID) + element = entities.NewInfoElement("", elementID, entities.OctetArray, enterpriseID, elementLength) } } else { /* @@ -254,7 +300,11 @@ func (cp *CollectingProcess) decodeTemplateSet(templateBuffer *bytes.Buffer, obs elementID = binary.BigEndian.Uint16(elementid) element, err = registry.GetInfoElementFromID(elementID, enterpriseID) if err != nil { - return nil, err + if cp.decodingMode == DecodingModeStrict { + return nil, err + } + klog.InfoS("Template includes an information element that is not present in registry", "obsDomainID", obsDomainID, "templateID", templateID, "enterpriseID", enterpriseID, "elementID", elementID) + element = entities.NewInfoElement("", elementID, entities.OctetArray, enterpriseID, elementLength) } } if elementsWithValue[i], err = entities.DecodeAndCreateInfoElementWithValue(element, nil); err != nil { @@ -281,17 +331,24 @@ func (cp *CollectingProcess) decodeDataSet(dataBuffer *bytes.Buffer, obsDomainID } for dataBuffer.Len() > 0 { - elements := make([]entities.InfoElementWithValue, len(template), len(template)+cp.numExtraElements) - for i, element := range template { + elements := make([]entities.InfoElementWithValue, 0, len(template)+cp.numExtraElements) + for _, ie := range template { var length int - if element.Len == entities.VariableLength { // string + if ie.Len == entities.VariableLength { // string / octet array length = getFieldLength(dataBuffer) } else { - length = int(element.Len) + length = int(ie.Len) } - if elements[i], err = entities.DecodeAndCreateInfoElementWithValue(element, dataBuffer.Next(length)); err != nil { + element, err := entities.DecodeAndCreateInfoElementWithValue(ie, dataBuffer.Next(length)) + if err != nil { return nil, err } + // A missing name means an unknown element was received + if cp.decodingMode == DecodingModeLenientDropUnknown && ie.Name == "" { + klog.V(5).InfoS("Dropping field for unknown information element", "obsDomainID", obsDomainID, "ie", ie) + continue + } + elements = append(elements, element) } err = dataSet.AddRecordV2(elements, templateID) if err != nil { @@ -313,16 +370,12 @@ func (cp *CollectingProcess) addTemplate(obsDomainID uint32, templateID uint16, } cp.templatesMap[obsDomainID][templateID] = elements // template lifetime management - if cp.protocol == "tcp" { + if cp.protocol != "udp" { return } - // Handle udp template expiration - if cp.templateTTL == 0 { - cp.templateTTL = entities.TemplateTTL // Default value - } go func() { - ticker := time.NewTicker(time.Duration(cp.templateTTL) * time.Second) + ticker := time.NewTicker(cp.templateTTL) defer ticker.Stop() select { case <-ticker.C: diff --git a/pkg/collector/process_test.go b/pkg/collector/process_test.go index 25e100a1..c7ba19e8 100644 --- a/pkg/collector/process_test.go +++ b/pkg/collector/process_test.go @@ -19,6 +19,8 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/binary" + "fmt" "net" "runtime" "sync" @@ -31,6 +33,7 @@ import ( "k8s.io/apimachinery/pkg/util/wait" "github.com/vmware/go-ipfix/pkg/entities" + "github.com/vmware/go-ipfix/pkg/exporter" "github.com/vmware/go-ipfix/pkg/registry" testcerts "github.com/vmware/go-ipfix/pkg/test/certs" ) @@ -89,6 +92,7 @@ func TestUDPCollectingProcess_ReceiveTemplateRecord(t *testing.T) { t.Fatalf("UDP Collecting Process does not start correctly: %v", err) } go cp.Start() + // wait until collector is ready waitForCollectorReady(t, cp) collectorAddr := cp.GetAddress() @@ -605,6 +609,79 @@ func TestUDPCollectingProcessIPv6(t *testing.T) { assert.Equal(t, net.ParseIP("2001:0:3238:DFE1:63::FEFB"), ie.GetIPAddressValue()) } +// TestUnknownInformationElement validates that message decoding when dealing with unknown IEs (not +// part of the static registry included in this project). All 3 supported decoding modes are tested. +func TestUnknownInformationElement(t *testing.T) { + const ( + templateID = 100 + obsDomainID = 0xabcd + unknownID = 999 + unknownValue = uint32(0x1234) + ) + + for _, enterpriseID := range []uint32{registry.IANAEnterpriseID, registry.AntreaEnterpriseID} { + for _, mode := range []DecodingMode{DecodingModeStrict, DecodingModeLenientKeepUnknown, DecodingModeLenientDropUnknown} { + t.Run(fmt.Sprintf("enterpriseID-%d_%s", enterpriseID, mode), func(t *testing.T) { + input := getCollectorInput(tcpTransport, false, false) + input.DecodingMode = mode + cp, err := InitCollectingProcess(input) + require.NoError(t, err) + defer cp.Stop() + + go func() { // remove the message from the message channel + for range cp.GetMsgChan() { + } + }() + + // First, send template set. + + unknownIE := entities.NewInfoElement("foo", unknownID, entities.Unsigned32, enterpriseID, 4) + knownIE1, _ := registry.GetInfoElement("octetDeltaCount", registry.IANAEnterpriseID) + knownIE2, _ := registry.GetInfoElement("sourceNodeName", registry.AntreaEnterpriseID) + templateSet, err := entities.MakeTemplateSet(templateID, []*entities.InfoElement{knownIE1, unknownIE, knownIE2}) + require.NoError(t, err) + templateBytes, err := exporter.CreateIPFIXMsg(templateSet, obsDomainID, 0 /* seqNumber */, time.Now()) + require.NoError(t, err) + _, err = cp.decodePacket(bytes.NewBuffer(templateBytes), "1.2.3.4:12345") + // If decoding is strict, there will be an error and we need to stop the test. + if mode == DecodingModeStrict { + require.Error(t, err) + return + } + require.NoError(t, err) + + // Second, send data set. + + unknownIEWithValue := entities.NewUnsigned32InfoElement(unknownIE, unknownValue) + knownIE1WithValue := entities.NewUnsigned64InfoElement(knownIE1, 0x100) + knownIE2WithValue := entities.NewStringInfoElement(knownIE2, "node-1") + dataSet, err := entities.MakeDataSet(templateID, []entities.InfoElementWithValue{knownIE1WithValue, unknownIEWithValue, knownIE2WithValue}) + require.NoError(t, err) + dataBytes, err := exporter.CreateIPFIXMsg(dataSet, obsDomainID, 1 /* seqNumber */, time.Now()) + require.NoError(t, err) + msg, err := cp.decodePacket(bytes.NewBuffer(dataBytes), "1.2.3.4:12345") + require.NoError(t, err) + records := msg.GetSet().GetRecords() + require.Len(t, records, 1) + record := records[0] + ies := record.GetOrderedElementList() + + if mode == DecodingModeLenientKeepUnknown { + require.Len(t, ies, 3) + // the unknown IE after decoding + ieWithValue := ies[1] + // the decoded IE has no name and the type always defaults to OctetArray + require.Equal(t, entities.NewInfoElement("", unknownID, entities.OctetArray, enterpriseID, 4), ieWithValue.GetInfoElement()) + value := ieWithValue.GetOctetArrayValue() + assert.Equal(t, unknownValue, binary.BigEndian.Uint32(value)) + } else if mode == DecodingModeLenientDropUnknown { + require.Len(t, ies, 2) + } + }) + } + } +} + func getCollectorInput(network string, isEncrypted bool, isIPv6 bool) CollectorInput { if network == tcpTransport { var address string diff --git a/pkg/entities/ie.go b/pkg/entities/ie.go index e18428fb..3b3190bd 100644 --- a/pkg/entities/ie.go +++ b/pkg/entities/ie.go @@ -162,6 +162,8 @@ func decodeToIEDataType(dataType IEDataType, val interface{}) (interface{}, erro return nil, fmt.Errorf("error when converting value to []bytes for decoding") } switch dataType { + case OctetArray: + return value, nil case Unsigned8: return value[0], nil case Unsigned16: @@ -211,6 +213,12 @@ func decodeToIEDataType(dataType IEDataType, val interface{}) (interface{}, erro // returns appropriate InfoElementWithValue. func DecodeAndCreateInfoElementWithValue(element *InfoElement, value []byte) (InfoElementWithValue, error) { switch element.DataType { + case OctetArray: + var val []byte + if value != nil { + val = append(val, value...) + } + return NewOctetArrayInfoElement(element, val), nil case Unsigned8: var val uint8 if value == nil { @@ -354,6 +362,10 @@ func DecodeAndCreateInfoElementWithValue(element *InfoElement, value []byte) (In // used for testing. func EncodeToIEDataType(dataType IEDataType, val interface{}) ([]byte, error) { switch dataType { + case OctetArray: + // Supporting the type properly would require knowing whether we are dealing with a + // fixed-length or variable-length element. + return nil, fmt.Errorf("octet array data type not supported by this method yet") case Unsigned8: v, ok := val.(uint8) if !ok { @@ -521,6 +533,25 @@ func encodeInfoElementValueToBuff(element InfoElementWithValue, buffer []byte, i return fmt.Errorf("buffer size is not enough for encoding") } switch element.GetDataType() { + case OctetArray: + v := element.GetOctetArrayValue() + ieLen := element.GetInfoElement().Len + if ieLen < VariableLength { + // fixed length case + if len(v) != int(ieLen) { + return fmt.Errorf("invalid value for fixed-length octet array: length mismatch") + } + copy(buffer[index:], v) + } else if len(v) < 255 { + buffer[index] = uint8(len(v)) + copy(buffer[index+1:], v) + } else if len(v) <= math.MaxUint16 { + buffer[index] = byte(255) // marker byte for long array + binary.BigEndian.PutUint16(buffer[index+1:index+3], uint16(len(v))) + copy(buffer[index+3:], v) + } else { + return fmt.Errorf("provided OctetArray value is too long and cannot be encoded: len=%d, maxlen=%d", len(v), math.MaxUint16) + } case Unsigned8: copy(buffer[index:index+1], []byte{element.GetUnsigned8Value()}) case Unsigned16: diff --git a/pkg/entities/ie_test.go b/pkg/entities/ie_test.go index 43ae6a02..36e3c978 100644 --- a/pkg/entities/ie_test.go +++ b/pkg/entities/ie_test.go @@ -78,7 +78,7 @@ func TestNewInfoElementWithValue(t *testing.T) { func BenchmarkEncodeInfoElementValueToBuffShortString(b *testing.B) { // a short string has a max length of 254 str := strings.Repeat("x", 128) - element := NewStringInfoElement(NewInfoElement("interfaceDescription", 83, 13, 0, 65535), str) + element := NewStringInfoElement(NewInfoElement("interfaceDescription", 83, 13, 0, VariableLength), str) const numCopies = 1000 length := element.GetLength() buffer := make([]byte, numCopies*length) @@ -95,7 +95,7 @@ func BenchmarkEncodeInfoElementValueToBuffShortString(b *testing.B) { func BenchmarkEncodeInfoElementValueToBuffLongString(b *testing.B) { // a long string has a max length of 65535 str := strings.Repeat("x", 10000) - element := NewStringInfoElement(NewInfoElement("interfaceDescription", 83, 13, 0, 65535), str) + element := NewStringInfoElement(NewInfoElement("interfaceDescription", 83, 13, 0, VariableLength), str) const numCopies = 1000 length := element.GetLength() buffer := make([]byte, numCopies*length) @@ -108,3 +108,43 @@ func BenchmarkEncodeInfoElementValueToBuffLongString(b *testing.B) { } } } + +func TestEncodeInfoElementValueToBuffOctetArray(t *testing.T) { + shortArray := make([]byte, 128) + longArray := make([]byte, 10000) + testCases := []struct { + name string + ieLen uint16 + array []byte + expectedBuffer []byte + }{ + { + name: "fixed length", + ieLen: uint16(len(shortArray)), + array: shortArray, + expectedBuffer: shortArray, + }, + { + name: "variable length - short", + ieLen: VariableLength, + array: shortArray, + expectedBuffer: append([]byte{128}, shortArray...), + }, + { + name: "variable length - long", + ieLen: VariableLength, + array: longArray, + // 10000 is 0x2710 + expectedBuffer: append([]byte{255, 0x27, 0x10}, longArray...), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + element := NewOctetArrayInfoElement(NewInfoElement("", 999, OctetArray, 56506, tc.ieLen), tc.array) + buffer := make([]byte, len(tc.expectedBuffer)) + require.NoError(t, encodeInfoElementValueToBuff(element, buffer, 0)) + assert.Equal(t, tc.expectedBuffer, buffer) + }) + } +} diff --git a/pkg/entities/ie_value.go b/pkg/entities/ie_value.go index fef66b03..9c24dc9c 100644 --- a/pkg/entities/ie_value.go +++ b/pkg/entities/ie_value.go @@ -11,6 +11,7 @@ type InfoElementWithValue interface { // TODO: Handle error to make it more robust if it is called prior to AddInfoElement. GetInfoElement() *InfoElement AddInfoElement(infoElement *InfoElement) + GetOctetArrayValue() []byte GetUnsigned8Value() uint8 GetUnsigned16Value() uint16 GetUnsigned32Value() uint32 @@ -25,6 +26,7 @@ type InfoElementWithValue interface { GetMacAddressValue() net.HardwareAddr GetStringValue() string GetIPAddressValue() net.IP + SetOctetArrayValue(val []byte) SetUnsigned8Value(val uint8) SetUnsigned16Value(val uint16) SetUnsigned32Value(val uint32) @@ -64,6 +66,10 @@ func (b *baseInfoElement) AddInfoElement(infoElement *InfoElement) { b.element = infoElement } +func (b *baseInfoElement) GetOctetArrayValue() []byte { + panic("accessing value of wrong data type") +} + func (b *baseInfoElement) GetUnsigned8Value() uint8 { panic("accessing value of wrong data type") } @@ -120,6 +126,10 @@ func (b *baseInfoElement) GetIPAddressValue() net.IP { panic("accessing value of wrong data type") } +func (b *baseInfoElement) SetOctetArrayValue(val []byte) { + panic("setting value with wrong data type") +} + func (b *baseInfoElement) SetUnsigned8Value(val uint8) { panic("setting value with wrong data type") } @@ -180,6 +190,46 @@ func (b *baseInfoElement) GetLength() int { return int(b.element.Len) } +type OctetArrayInfoElement struct { + value []byte + baseInfoElement +} + +func NewOctetArrayInfoElement(element *InfoElement, val []byte) *OctetArrayInfoElement { + infoElem := &OctetArrayInfoElement{ + value: val, + } + infoElem.element = element + return infoElem +} + +func (a *OctetArrayInfoElement) GetOctetArrayValue() []byte { + return a.value +} + +func (a *OctetArrayInfoElement) GetLength() int { + if a.element.Len < VariableLength { + return int(a.element.Len) + } + if len(a.value) < 255 { + return len(a.value) + 1 + } else { + return len(a.value) + 3 + } +} + +func (a *OctetArrayInfoElement) SetOctetArrayValue(val []byte) { + a.value = val +} + +func (a *OctetArrayInfoElement) IsValueEmpty() bool { + return a.value == nil +} + +func (a *OctetArrayInfoElement) ResetValue() { + a.value = nil +} + type Unsigned8InfoElement struct { value uint8 baseInfoElement diff --git a/pkg/entities/set.go b/pkg/entities/set.go index 24961c1b..c8efb54b 100644 --- a/pkg/entities/set.go +++ b/pkg/entities/set.go @@ -68,7 +68,7 @@ type set struct { length int } -func NewSet(isDecoding bool) Set { +func NewSet(isDecoding bool) *set { if isDecoding { return &set{ records: make([]Record, 0), @@ -189,3 +189,34 @@ func (s *set) createHeader(setType ContentType, templateID uint16) { binary.BigEndian.PutUint16(s.headerBuffer[0:2], templateID) } } + +// MakeTemplateSet is a convenience function which creates a template Set with a single Record. +func MakeTemplateSet(templateID uint16, ies []*InfoElement) (*set, error) { + tempSet := NewSet(false) + if err := tempSet.PrepareSet(Template, templateID); err != nil { + return nil, err + } + elements := make([]InfoElementWithValue, len(ies)) + for idx, ie := range ies { + var err error + if elements[idx], err = DecodeAndCreateInfoElementWithValue(ie, nil); err != nil { + return nil, err + } + } + if err := tempSet.AddRecord(elements, templateID); err != nil { + return nil, err + } + return tempSet, nil +} + +// MakeDataSet is a convenience function which creates a data Set with a single Record. +func MakeDataSet(templateID uint16, ies []InfoElementWithValue) (*set, error) { + dataSet := NewSet(false) + if err := dataSet.PrepareSet(Data, templateID); err != nil { + return nil, err + } + if err := dataSet.AddRecord(ies, templateID); err != nil { + return nil, err + } + return dataSet, nil +} diff --git a/pkg/exporter/process.go b/pkg/exporter/process.go index 0f2256ad..1f1c0459 100644 --- a/pkg/exporter/process.go +++ b/pkg/exporter/process.go @@ -428,18 +428,7 @@ func (ep *ExportingProcess) sendRefreshedTemplates() error { ep.templateMutex.Lock() for templateID, tempValue := range ep.templatesMap { - tempSet := entities.NewSet(false) - if err := tempSet.PrepareSet(entities.Template, templateID); err != nil { - return err - } - elements := make([]entities.InfoElementWithValue, len(tempValue.elements)) - var err error - for i, element := range tempValue.elements { - if elements[i], err = entities.DecodeAndCreateInfoElementWithValue(element, nil); err != nil { - return err - } - } - err = tempSet.AddRecord(elements, templateID) + tempSet, err := entities.MakeTemplateSet(templateID, tempValue.elements) if err != nil { return err } diff --git a/pkg/test/util.go b/pkg/test/util.go index ae2ac878..d188eee8 100644 --- a/pkg/test/util.go +++ b/pkg/test/util.go @@ -159,18 +159,15 @@ func getTestRecord(isSrcNode, isIPv6 bool, options ...testRecordOptions) *testRe } func createTemplateSet(templateID uint16, isIPv6 bool) entities.Set { - templateSet := entities.NewSet(false) - templateSet.PrepareSet(entities.Template, templateID) - elements := make([]entities.InfoElementWithValue, 0, numFields) + ies := make([]*entities.InfoElement, 0, numFields) ianaFields := ianaIPv4Fields if isIPv6 { ianaFields = ianaIPv6Fields } ianaFields = append(ianaFields, commonFields...) for _, name := range ianaFields { - element, _ := registry.GetInfoElement(name, registry.IANAEnterpriseID) - ie, _ := entities.DecodeAndCreateInfoElementWithValue(element, nil) - elements = append(elements, ie) + ie, _ := registry.GetInfoElement(name, registry.IANAEnterpriseID) + ies = append(ies, ie) } antreaFields := antreaCommonFields if !isIPv6 { @@ -179,16 +176,14 @@ func createTemplateSet(templateID uint16, isIPv6 bool) entities.Set { antreaFields = append(antreaFields, antreaIPv6...) } for _, name := range antreaFields { - element, _ := registry.GetInfoElement(name, registry.AntreaEnterpriseID) - ie, _ := entities.DecodeAndCreateInfoElementWithValue(element, nil) - elements = append(elements, ie) + ie, _ := registry.GetInfoElement(name, registry.AntreaEnterpriseID) + ies = append(ies, ie) } for _, name := range reverseFields { - element, _ := registry.GetInfoElement(name, registry.IANAReversedEnterpriseID) - ie, _ := entities.DecodeAndCreateInfoElementWithValue(element, nil) - elements = append(elements, ie) + ie, _ := registry.GetInfoElement(name, registry.IANAReversedEnterpriseID) + ies = append(ies, ie) } - templateSet.AddRecord(elements, templateID) + templateSet, _ := entities.MakeTemplateSet(templateID, ies) return templateSet }