diff --git a/pkg/collector/process.go b/pkg/collector/process.go index b247eaf..1f1406e 100644 --- a/pkg/collector/process.go +++ b/pkg/collector/process.go @@ -261,12 +261,7 @@ func (cp *CollectingProcess) decodeTemplateSet(templateBuffer *bytes.Buffer, obs return nil, err } - templateSet := entities.NewSet(true) - if err := templateSet.PrepareSet(entities.Template, templateID); err != nil { - return nil, err - } - elementsWithValue := make([]entities.InfoElementWithValue, int(fieldCount)) - for i := 0; i < int(fieldCount); i++ { + decodeField := func() (entities.InfoElementWithValue, error) { var element *entities.InfoElement var enterpriseID uint32 var elementID uint16 @@ -320,12 +315,37 @@ func (cp *CollectingProcess) decodeTemplateSet(templateBuffer *bytes.Buffer, obs element = entities.NewInfoElement("", elementID, entities.OctetArray, enterpriseID, elementLength) } } - if elementsWithValue[i], err = entities.DecodeAndCreateInfoElementWithValue(element, nil); err != nil { - return nil, err - } + + return entities.DecodeAndCreateInfoElementWithValue(element, nil) } - err := templateSet.AddRecordV2(elementsWithValue, templateID) + + elementsWithValue, err := func() ([]entities.InfoElementWithValue, error) { + elementsWithValue := make([]entities.InfoElementWithValue, int(fieldCount)) + for i := range fieldCount { + elementWithValue, err := decodeField() + if err != nil { + return nil, err + } + elementsWithValue[i] = elementWithValue + } + return elementsWithValue, nil + }() if err != nil { + // Delete existing template (if one exists) from template map if the new one is invalid. + // This is particularly useful for UDP collection, as there is no feedback mechanism + // to let the sender know that the new template is invalid (while with TCP, we can close + // the connection). If we keep the old template and the sender sends data records + // which use the new template, we would try to decode them according to the old + // template, which would cause issues. + cp.deleteTemplate(obsDomainID, templateID) + return nil, err + } + + templateSet := entities.NewSet(true) + if err := templateSet.PrepareSet(entities.Template, templateID); err != nil { + return nil, err + } + if err := templateSet.AddRecordV2(elementsWithValue, templateID); err != nil { return nil, err } cp.addTemplate(obsDomainID, templateID, elementsWithValue) diff --git a/pkg/collector/process_test.go b/pkg/collector/process_test.go index 360cf31..5f07efc 100644 --- a/pkg/collector/process_test.go +++ b/pkg/collector/process_test.go @@ -349,43 +349,97 @@ func TestUDPCollectingProcess_DecodePacketError(t *testing.T) { } func TestCollectingProcess_DecodeTemplateRecord(t *testing.T) { - cp := CollectingProcess{} - cp.templatesMap = make(map[uint32]map[uint16]*template) - cp.mutex = sync.RWMutex{} - address, err := net.ResolveTCPAddr(tcpTransport, hostPortIPv4) - if err != nil { - t.Error(err) + // This is the observation domain ID used by test template records + const obsDomainID = uint32(1) + // This is the template ID used by test template records + const templateID = uint16(256) + + testCases := []struct { + name string + existingTemplates map[uint32]map[uint16]*template + templateRecord []byte + expectedErr string + // whether an entry is expected in the templates map after decoding the packet + isTemplateExpected bool + }{ + { + name: "valid template", + existingTemplates: map[uint32]map[uint16]*template{}, + templateRecord: validTemplatePacket, + isTemplateExpected: true, + }, + { + name: "invalid version", + existingTemplates: map[uint32]map[uint16]*template{ + obsDomainID: { + templateID: &template{}, + }, + }, + templateRecord: []byte{0, 9, 0, 40, 95, 40, 211, 236, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0, 3, 0, 8, 0, 4, 0, 12, 0, 4, 128, 105, 255, 255, 0, 0, 218, 21}, + expectedErr: "collector only supports IPFIX (v10)", + // Invalid version means we stop decoding the packet right away, so we will not modify the existing template map + isTemplateExpected: true, + }, + { + name: "malformed record fields", + existingTemplates: map[uint32]map[uint16]*template{ + obsDomainID: { + templateID: &template{}, + }, + }, + templateRecord: []byte{0, 10, 0, 40, 95, 40, 211, 236, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0, 3, 0, 8, 0, 4, 0, 12, 0, 4, 128, 105, 255, 255, 0, 0}, + expectedErr: "error in decoding data", + isTemplateExpected: false, + }, + { + name: "malformed record header", + existingTemplates: map[uint32]map[uint16]*template{ + obsDomainID: { + templateID: &template{}, + }, + }, + // We truncate the record header (3 bytes instead of 4) + templateRecord: []byte{0, 10, 0, 40, 95, 154, 107, 127, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0}, + expectedErr: "error in decoding data", + // If we cannot decode the message to get a template ID, then the existing template entry will not be removed + isTemplateExpected: true, + }, } - cp.netAddress = address - cp.messageChan = make(chan *entities.Message) - go func() { // remove the message from the message channel - for range cp.GetMsgChan() { - } - }() - message, err := cp.decodePacket(bytes.NewBuffer(validTemplatePacket), address.String()) - if err != nil { - t.Fatalf("Got error in decoding template record: %v", err) - } - assert.Equal(t, uint16(10), message.GetVersion(), "Flow record version should be 10.") - assert.Equal(t, uint32(1), message.GetObsDomainID(), "Flow record obsDomainID should be 1.") - assert.NotNil(t, cp.templatesMap[message.GetObsDomainID()], "Template should be stored in template map") - templateSet := message.GetSet() - assert.NotNil(t, templateSet, "Template record should be stored in message flowset") - sourceIPv4Address, _, exist := templateSet.GetRecords()[0].GetInfoElementWithValue("sourceIPv4Address") - assert.Equal(t, true, exist) - assert.Equal(t, uint32(0), sourceIPv4Address.GetInfoElement().EnterpriseId, "Template record is not stored correctly.") - // Invalid version - templateRecord := []byte{0, 9, 0, 40, 95, 40, 211, 236, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0, 3, 0, 8, 0, 4, 0, 12, 0, 4, 128, 105, 255, 255, 0, 0, 218, 21} - _, err = cp.decodePacket(bytes.NewBuffer(templateRecord), address.String()) - assert.NotNil(t, err, "Error should be logged for invalid version") - // Malformed record - templateRecord = []byte{0, 10, 0, 40, 95, 40, 211, 236, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0, 3, 0, 8, 0, 4, 0, 12, 0, 4, 128, 105, 255, 255, 0, 0} - cp.templatesMap = make(map[uint32]map[uint16]*template) - _, err = cp.decodePacket(bytes.NewBuffer(templateRecord), address.String()) - assert.NotNil(t, err, "Error should be logged for malformed template record") - if _, exist := cp.templatesMap[uint32(1)]; exist { - t.Fatal("Template should not be stored for malformed template record") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cp := CollectingProcess{} + cp.templatesMap = tc.existingTemplates + cp.mutex = sync.RWMutex{} + address, err := net.ResolveTCPAddr(tcpTransport, hostPortIPv4) + require.NoError(t, err) + cp.netAddress = address + cp.messageChan = make(chan *entities.Message) + go func() { // remove the message from the message channel + for range cp.GetMsgChan() { + } + }() + message, err := cp.decodePacket(bytes.NewBuffer(tc.templateRecord), address.String()) + if tc.expectedErr != "" { + require.ErrorContains(t, err, tc.expectedErr) + } else { + require.NoError(t, err, "failed to decode template record") + + assert.Equal(t, uint16(10), message.GetVersion(), "Unexpected IPFIX version in message") + assert.Equal(t, obsDomainID, message.GetObsDomainID(), "Unexpected obsDomainID in message") + + templateSet := message.GetSet() + assert.NotNil(t, templateSet, "Template record should be stored in message flowset") + sourceIPv4Address, _, exist := templateSet.GetRecords()[0].GetInfoElementWithValue("sourceIPv4Address") + assert.Equal(t, true, exist) + assert.Equal(t, uint32(0), sourceIPv4Address.GetInfoElement().EnterpriseId, "Template record is not stored correctly.") + } + if tc.isTemplateExpected { + assert.NotNil(t, cp.templatesMap[obsDomainID][templateID], "Template should be stored in template map") + } else { + assert.Nil(t, cp.templatesMap[obsDomainID][templateID], "Template should not be stored in template map") + } + }) } }