From eb9c359d7145bde1a65a81868810a85b71665248 Mon Sep 17 00:00:00 2001 From: Antonin Bas Date: Fri, 22 Nov 2024 12:12:47 -0800 Subject: [PATCH] Delete existing template if a new invalid one is received (#383) This is particularly useful for UDP collection, as there is no feedback mechanism to let the sender know that the new template is invalid (while with TCP, we can close the connection). If we keep the old template and the sender sends data records which use the new template, we would try to decode them according to the old template, which would cause issues. Instead we will ignore data records for that observation domain and template ID until a new valid template is received. Signed-off-by: Antonin Bas --- pkg/collector/process.go | 40 ++++++++--- pkg/collector/process_test.go | 124 ++++++++++++++++++++++++---------- 2 files changed, 119 insertions(+), 45 deletions(-) diff --git a/pkg/collector/process.go b/pkg/collector/process.go index b247eaf..1f1406e 100644 --- a/pkg/collector/process.go +++ b/pkg/collector/process.go @@ -261,12 +261,7 @@ func (cp *CollectingProcess) decodeTemplateSet(templateBuffer *bytes.Buffer, obs return nil, err } - templateSet := entities.NewSet(true) - if err := templateSet.PrepareSet(entities.Template, templateID); err != nil { - return nil, err - } - elementsWithValue := make([]entities.InfoElementWithValue, int(fieldCount)) - for i := 0; i < int(fieldCount); i++ { + decodeField := func() (entities.InfoElementWithValue, error) { var element *entities.InfoElement var enterpriseID uint32 var elementID uint16 @@ -320,12 +315,37 @@ func (cp *CollectingProcess) decodeTemplateSet(templateBuffer *bytes.Buffer, obs element = entities.NewInfoElement("", elementID, entities.OctetArray, enterpriseID, elementLength) } } - if elementsWithValue[i], err = entities.DecodeAndCreateInfoElementWithValue(element, nil); err != nil { - return nil, err - } + + return entities.DecodeAndCreateInfoElementWithValue(element, nil) } - err := templateSet.AddRecordV2(elementsWithValue, templateID) + + elementsWithValue, err := func() ([]entities.InfoElementWithValue, error) { + elementsWithValue := make([]entities.InfoElementWithValue, int(fieldCount)) + for i := range fieldCount { + elementWithValue, err := decodeField() + if err != nil { + return nil, err + } + elementsWithValue[i] = elementWithValue + } + return elementsWithValue, nil + }() if err != nil { + // Delete existing template (if one exists) from template map if the new one is invalid. + // This is particularly useful for UDP collection, as there is no feedback mechanism + // to let the sender know that the new template is invalid (while with TCP, we can close + // the connection). If we keep the old template and the sender sends data records + // which use the new template, we would try to decode them according to the old + // template, which would cause issues. + cp.deleteTemplate(obsDomainID, templateID) + return nil, err + } + + templateSet := entities.NewSet(true) + if err := templateSet.PrepareSet(entities.Template, templateID); err != nil { + return nil, err + } + if err := templateSet.AddRecordV2(elementsWithValue, templateID); err != nil { return nil, err } cp.addTemplate(obsDomainID, templateID, elementsWithValue) diff --git a/pkg/collector/process_test.go b/pkg/collector/process_test.go index 360cf31..5f07efc 100644 --- a/pkg/collector/process_test.go +++ b/pkg/collector/process_test.go @@ -349,43 +349,97 @@ func TestUDPCollectingProcess_DecodePacketError(t *testing.T) { } func TestCollectingProcess_DecodeTemplateRecord(t *testing.T) { - cp := CollectingProcess{} - cp.templatesMap = make(map[uint32]map[uint16]*template) - cp.mutex = sync.RWMutex{} - address, err := net.ResolveTCPAddr(tcpTransport, hostPortIPv4) - if err != nil { - t.Error(err) + // This is the observation domain ID used by test template records + const obsDomainID = uint32(1) + // This is the template ID used by test template records + const templateID = uint16(256) + + testCases := []struct { + name string + existingTemplates map[uint32]map[uint16]*template + templateRecord []byte + expectedErr string + // whether an entry is expected in the templates map after decoding the packet + isTemplateExpected bool + }{ + { + name: "valid template", + existingTemplates: map[uint32]map[uint16]*template{}, + templateRecord: validTemplatePacket, + isTemplateExpected: true, + }, + { + name: "invalid version", + existingTemplates: map[uint32]map[uint16]*template{ + obsDomainID: { + templateID: &template{}, + }, + }, + templateRecord: []byte{0, 9, 0, 40, 95, 40, 211, 236, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0, 3, 0, 8, 0, 4, 0, 12, 0, 4, 128, 105, 255, 255, 0, 0, 218, 21}, + expectedErr: "collector only supports IPFIX (v10)", + // Invalid version means we stop decoding the packet right away, so we will not modify the existing template map + isTemplateExpected: true, + }, + { + name: "malformed record fields", + existingTemplates: map[uint32]map[uint16]*template{ + obsDomainID: { + templateID: &template{}, + }, + }, + templateRecord: []byte{0, 10, 0, 40, 95, 40, 211, 236, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0, 3, 0, 8, 0, 4, 0, 12, 0, 4, 128, 105, 255, 255, 0, 0}, + expectedErr: "error in decoding data", + isTemplateExpected: false, + }, + { + name: "malformed record header", + existingTemplates: map[uint32]map[uint16]*template{ + obsDomainID: { + templateID: &template{}, + }, + }, + // We truncate the record header (3 bytes instead of 4) + templateRecord: []byte{0, 10, 0, 40, 95, 154, 107, 127, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0}, + expectedErr: "error in decoding data", + // If we cannot decode the message to get a template ID, then the existing template entry will not be removed + isTemplateExpected: true, + }, } - cp.netAddress = address - cp.messageChan = make(chan *entities.Message) - go func() { // remove the message from the message channel - for range cp.GetMsgChan() { - } - }() - message, err := cp.decodePacket(bytes.NewBuffer(validTemplatePacket), address.String()) - if err != nil { - t.Fatalf("Got error in decoding template record: %v", err) - } - assert.Equal(t, uint16(10), message.GetVersion(), "Flow record version should be 10.") - assert.Equal(t, uint32(1), message.GetObsDomainID(), "Flow record obsDomainID should be 1.") - assert.NotNil(t, cp.templatesMap[message.GetObsDomainID()], "Template should be stored in template map") - templateSet := message.GetSet() - assert.NotNil(t, templateSet, "Template record should be stored in message flowset") - sourceIPv4Address, _, exist := templateSet.GetRecords()[0].GetInfoElementWithValue("sourceIPv4Address") - assert.Equal(t, true, exist) - assert.Equal(t, uint32(0), sourceIPv4Address.GetInfoElement().EnterpriseId, "Template record is not stored correctly.") - // Invalid version - templateRecord := []byte{0, 9, 0, 40, 95, 40, 211, 236, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0, 3, 0, 8, 0, 4, 0, 12, 0, 4, 128, 105, 255, 255, 0, 0, 218, 21} - _, err = cp.decodePacket(bytes.NewBuffer(templateRecord), address.String()) - assert.NotNil(t, err, "Error should be logged for invalid version") - // Malformed record - templateRecord = []byte{0, 10, 0, 40, 95, 40, 211, 236, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 24, 1, 0, 0, 3, 0, 8, 0, 4, 0, 12, 0, 4, 128, 105, 255, 255, 0, 0} - cp.templatesMap = make(map[uint32]map[uint16]*template) - _, err = cp.decodePacket(bytes.NewBuffer(templateRecord), address.String()) - assert.NotNil(t, err, "Error should be logged for malformed template record") - if _, exist := cp.templatesMap[uint32(1)]; exist { - t.Fatal("Template should not be stored for malformed template record") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cp := CollectingProcess{} + cp.templatesMap = tc.existingTemplates + cp.mutex = sync.RWMutex{} + address, err := net.ResolveTCPAddr(tcpTransport, hostPortIPv4) + require.NoError(t, err) + cp.netAddress = address + cp.messageChan = make(chan *entities.Message) + go func() { // remove the message from the message channel + for range cp.GetMsgChan() { + } + }() + message, err := cp.decodePacket(bytes.NewBuffer(tc.templateRecord), address.String()) + if tc.expectedErr != "" { + require.ErrorContains(t, err, tc.expectedErr) + } else { + require.NoError(t, err, "failed to decode template record") + + assert.Equal(t, uint16(10), message.GetVersion(), "Unexpected IPFIX version in message") + assert.Equal(t, obsDomainID, message.GetObsDomainID(), "Unexpected obsDomainID in message") + + templateSet := message.GetSet() + assert.NotNil(t, templateSet, "Template record should be stored in message flowset") + sourceIPv4Address, _, exist := templateSet.GetRecords()[0].GetInfoElementWithValue("sourceIPv4Address") + assert.Equal(t, true, exist) + assert.Equal(t, uint32(0), sourceIPv4Address.GetInfoElement().EnterpriseId, "Template record is not stored correctly.") + } + if tc.isTemplateExpected { + assert.NotNil(t, cp.templatesMap[obsDomainID][templateID], "Template should be stored in template map") + } else { + assert.Nil(t, cp.templatesMap[obsDomainID][templateID], "Template should not be stored in template map") + } + }) } }