Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create consistent workunit IDs across all nodes #1194

Merged
merged 9 commits into from
Nov 12, 2024
2 changes: 1 addition & 1 deletion docs/source/user_guide/configuration_options.rst
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ Node
- Default value
- Type
* - ``id``
- Node ID
- Node ID. May only contain a-z, A-Z, 0-9, or the special characters . - _ @
- local hostname
- string
* - ``datadir``
Expand Down
2 changes: 1 addition & 1 deletion pkg/controlsvc/controlsvc.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ func (s *Server) RunControlSession(conn net.Conn) {
}
}
} else {
writeMsg := "ERROR: Unknown command\n"
writeMsg := fmt.Sprintf("ERROR: Unknown command, %v\n", cmd)
if writeToConnWithLog(conn, s.nc, writeMsg, writeControlServiceError) {
return
}
Expand Down
7 changes: 7 additions & 0 deletions pkg/types/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"regexp"
"strings"

"github.com/ansible/receptor/pkg/controlsvc"
Expand Down Expand Up @@ -34,6 +35,12 @@ func (cfg NodeCfg) Init() error {
return fmt.Errorf("no node ID specified and local host name is localhost")
}
cfg.ID = host
} else {
submitIDRegex := regexp.MustCompile(`^[.\-_@a-zA-Z0-9]*$`)
match := submitIDRegex.FindSubmatch([]byte(cfg.ID))
if match == nil {
return fmt.Errorf("node id can only contain a-z, A-Z, 0-9 or special characters . - _ @ but received: %s", cfg.ID)
}
}
if strings.ToLower(cfg.ID) == "localhost" {
return fmt.Errorf("node ID \"localhost\" is reserved")
Expand Down
8 changes: 6 additions & 2 deletions pkg/workceptor/controlsvc.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ func (c *workceptorCommand) ControlFunc(ctx context.Context, nc controlsvc.Netce
if err != nil {
signature = ""
}
workUnitID, err := strFromMap(c.params, "workUnitID")
if err != nil {
workUnitID = ""
}
workParams := make(map[string]string)
nonParams := []string{"command", "subcommand", "node", "worktype", "tlsclient", "ttl", "signwork", "signature"}
inNonParams := func(p string) bool {
Expand Down Expand Up @@ -283,9 +287,9 @@ func (c *workceptorCommand) ControlFunc(ctx context.Context, nc controlsvc.Netce
if ttl != "" {
return nil, fmt.Errorf("ttl option is intended for remote work only")
}
worker, err = c.w.AllocateUnit(workType, workParams)
worker, err = c.w.AllocateUnit(workType, workUnitID, workParams)
} else {
worker, err = c.w.AllocateRemoteUnit(workNode, workType, tlsClient, ttl, signWork, workParams)
worker, err = c.w.AllocateRemoteUnit(workNode, workType, workUnitID, tlsClient, ttl, signWork, workParams)
}
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion pkg/workceptor/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func TestWorkceptorJson(t *testing.T) {
if err != nil {
t.Fatal(err)
}
cw, err := w.AllocateUnit("command", make(map[string]string))
cw, err := w.AllocateUnit("command", "", make(map[string]string))
if err != nil {
t.Fatal(err)
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/workceptor/remote_work.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ func (rw *remoteUnit) startRemoteUnit(ctx context.Context, conn net.Conn, reader
for k, v := range red.RemoteParams {
workSubmitCmd[k] = v
}
workSubmitCmd["workUnitID"] = rw.ID()
workSubmitCmd["command"] = "work"
workSubmitCmd["subcommand"] = "submit"
workSubmitCmd["node"] = red.RemoteNode
Expand Down Expand Up @@ -180,7 +181,7 @@ func (rw *remoteUnit) startRemoteUnit(ctx context.Context, conn net.Conn, reader
if err != nil {
return fmt.Errorf("read error reading from %s: %s", red.RemoteNode, err)
}
submitIDRegex := regexp.MustCompile(`with ID ([a-zA-Z0-9]+)\.`)
submitIDRegex := regexp.MustCompile(`with ID ([.\-_@a-zA-Z0-9]+)\.`)
match := submitIDRegex.FindSubmatch([]byte(response))
if match == nil || len(match) != 2 {
return fmt.Errorf("could not parse response: %s", strings.TrimRight(response, "\n"))
Expand Down
23 changes: 17 additions & 6 deletions pkg/workceptor/workceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,25 @@ func (w *Workceptor) RegisterWorker(typeName string, newWorkerFunc NewWorkerFunc
return nil
}

func (w *Workceptor) generateUnitID(lock bool) (string, error) {
func (w *Workceptor) generateUnitID(lock bool, workUnitID string) (string, error) {
if lock {
w.activeUnitsLock.RLock()
defer w.activeUnitsLock.RUnlock()
}
var ident string
for {
ident = randstr.RandomString(8)
if workUnitID == "" {
rstr := randstr.RandomString(8)
nid := w.nc.NodeID()
ident = fmt.Sprintf("%s%s", nid, rstr)
} else {
ident = workUnitID
unitdir := path.Join(w.dataDir, ident)
_, err := os.Stat(unitdir)
if err == nil {
return "", fmt.Errorf("workunit ID %s is already in use, cannot use the same workunit ID more than once", ident)
}
}
_, ok := w.activeUnits[ident]
if !ok {
unitdir := path.Join(w.dataDir, ident)
Expand Down Expand Up @@ -243,7 +254,7 @@ func (w *Workceptor) VerifySignature(signature string) error {
}

// AllocateUnit creates a new local work unit, identified by workUnitID when one is supplied or by a newly generated identifier otherwise.
func (w *Workceptor) AllocateUnit(workTypeName string, params map[string]string) (WorkUnit, error) {
func (w *Workceptor) AllocateUnit(workTypeName string, workUnitID string, params map[string]string) (WorkUnit, error) {
w.workTypesLock.RLock()
wt, ok := w.workTypes[workTypeName]
w.workTypesLock.RUnlock()
Expand All @@ -252,7 +263,7 @@ func (w *Workceptor) AllocateUnit(workTypeName string, params map[string]string)
}
w.activeUnitsLock.Lock()
defer w.activeUnitsLock.Unlock()
ident, err := w.generateUnitID(false)
ident, err := w.generateUnitID(false, workUnitID)
if err != nil {
return nil, err
}
Expand All @@ -270,7 +281,7 @@ func (w *Workceptor) AllocateUnit(workTypeName string, params map[string]string)
}

// AllocateRemoteUnit creates a new remote work unit whose local identifier is workUnitID when one is supplied, or a newly generated identifier otherwise.
func (w *Workceptor) AllocateRemoteUnit(remoteNode, remoteWorkType, tlsClient, ttl string, signWork bool, params map[string]string) (WorkUnit, error) {
func (w *Workceptor) AllocateRemoteUnit(remoteNode, remoteWorkType, workUnitID string, tlsClient, ttl string, signWork bool, params map[string]string) (WorkUnit, error) {
if tlsClient != "" {
_, err := w.nc.GetClientTLSConfig(tlsClient, "testhost", netceptor.ExpectedHostnameTypeReceptor)
if err != nil {
Expand All @@ -288,7 +299,7 @@ func (w *Workceptor) AllocateRemoteUnit(remoteNode, remoteWorkType, tlsClient, t
if hasSecrets && tlsClient == "" {
return nil, fmt.Errorf("cannot send secrets over a non-TLS connection")
}
rw, err := w.AllocateUnit("remote", params)
rw, err := w.AllocateUnit("remote", workUnitID, params)
if err != nil {
return nil, err
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/workceptor/workceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func testSetup(t *testing.T) (*gomock.Controller, *mock_workceptor.MockNetceptor

ctx := context.Background()
mockNetceptor := mock_workceptor.NewMockNetceptorForWorkceptor(ctrl)
mockNetceptor.EXPECT().NodeID().Return("test")
mockNetceptor.EXPECT().NodeID().Return("test").AnyTimes()

logger := logger.NewReceptorLogger("")
mockNetceptor.EXPECT().GetLogger().AnyTimes().Return(logger)
Expand Down Expand Up @@ -48,7 +48,7 @@ func TestAllocateUnit(t *testing.T) {
return mockWorkUnit
}

mockNetceptor.EXPECT().NodeID().Return("test")
mockNetceptor.EXPECT().NodeID().Return("test").Times(4)
w, err := workceptor.New(ctx, mockNetceptor, "/tmp")
if err != nil {
t.Errorf("Error while creating Workceptor: %v", err)
Expand Down Expand Up @@ -124,7 +124,7 @@ func TestAllocateUnit(t *testing.T) {
mockWorkUnit.EXPECT().Save().Return(tc.saveError).Times(1)
}

_, err := w.AllocateUnit(tc.workType, map[string]string{"param": "value"})
_, err := w.AllocateUnit(tc.workType, "", map[string]string{"param": "value"})
checkError(err, tc.expectedError, t)
})
}
Expand Down Expand Up @@ -195,7 +195,7 @@ func TestRegisterWorker(t *testing.T) {
hasError: false,
expectedCalls: func() {
mockNetceptor.EXPECT().AddWorkCommand(gomock.Any(), gomock.Any())
w.AllocateUnit("remote", map[string]string{})
w.AllocateUnit("remote", "", map[string]string{})
},
},
}
Expand Down Expand Up @@ -350,7 +350,7 @@ func TestAllocateRemoteUnit(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tc.expectedCalls()
_, err := w.AllocateRemoteUnit("", "", tc.tlsClient, tc.ttl, tc.signWork, tc.params)
_, err := w.AllocateRemoteUnit("", "", "", tc.tlsClient, tc.ttl, tc.signWork, tc.params)

if tc.errorMsg != "" && tc.errorMsg != err.Error() && err != nil {
t.Errorf("expected: %s, received: %s", tc.errorMsg, err)
Expand Down