diff --git a/agents.go b/agents.go deleted file mode 100644 index df5693e..0000000 --- a/agents.go +++ /dev/null @@ -1,63 +0,0 @@ -package groq - -import ( - "context" - "log/slog" - - "github.com/conneroisu/groq-go/pkg/tools" -) - -type ( - // Agenter is an interface for an agent. - Agenter interface { - ToolManager - } - // ToolManager is an interface for a tool manager. - ToolManager interface { - ToolGetter - ToolRunner - } - // ToolGetter is an interface for a tool getter. - ToolGetter interface { - Get( - ctx context.Context, - params ToolGetParams, - ) ([]tools.Tool, error) - } - // ToolRunner is an interface for a tool runner. - ToolRunner interface { - Run( - ctx context.Context, - response ChatCompletionResponse, - ) ([]ChatCompletionMessage, error) - } - // ToolGetParams are the parameters for getting tools. - ToolGetParams struct { - } - // Router is an agent router. - // - // It is used to route messages to the appropriate model. - Router struct { - // Agents is the agents of the router. - Agents []Agent - // Logger is the logger of the router. - Logger *slog.Logger - } -) - -// Agent is an agent. -type Agent struct { - client *Client - logger *slog.Logger -} - -// NewAgent creates a new agent. -func NewAgent( - client *Client, - logger *slog.Logger, -) *Agent { - return &Agent{ - client: client, - logger: logger, - } -} diff --git a/extensions/e2b/model.go b/extensions/e2b/model.go deleted file mode 100644 index 80e26c7..0000000 --- a/extensions/e2b/model.go +++ /dev/null @@ -1,68 +0,0 @@ -package e2b - -import ( - "context" - "io" - "time" -) - -type ( - // Receiver is an interface for a constantly receiving instance that - // can closed. - // - // Implementations should be conccurent safe. - Receiver interface { - Read(ctx context.Context) error - io.Closer - } - // Identifier is an interface for a constantly running process to - // identify new request ids. 
- Identifier interface { - Identify(ctx context.Context) - } - // Sandboxer is an interface for a sandbox. - Sandboxer interface { - // KeepAlive keeps the underlying interface alive. - // - // If the context is cancelled before requesting the timeout, - // the error will be ctx.Err(). - KeepAlive( - ctx context.Context, - timeout time.Duration, - ) error - // NewProcess creates a new process. - NewProcess( - cmd string, - ) (*Processor, error) - - // Write writes a file to the filesystem. - Write( - ctx context.Context, - method Method, - params []any, - respCh chan<- []byte, - ) - // Read reads a file from the filesystem. - Read( - ctx context.Context, - path string, - ) (string, error) - } - // Processor is an interface for a process. - Processor interface { - Start( - ctx context.Context, - cmd string, - timeout time.Duration, - ) - SubscribeStdout() (events chan Event, err error) - SubscribeStderr() (events chan Event, err error) - } - // Watcher is an interface for a instance that can watch a filesystem. - Watcher interface { - Watch( - ctx context.Context, - path string, - ) (<-chan Event, error) - } -) diff --git a/extensions/e2b/options.go b/extensions/e2b/options.go index 03f7690..ebf9db6 100644 --- a/extensions/e2b/options.go +++ b/extensions/e2b/options.go @@ -38,8 +38,20 @@ func WithCwd(cwd string) Option { } // WithWsURL sets the websocket url resolving function for the e2b sandbox. +// +// This is useful for testing. func WithWsURL(wsURL func(s *Sandbox) string) Option { return func(s *Sandbox) { s.wsURL = wsURL } } // Process Options + +// ProcessWithEnv sets the environment variables for the process. +func ProcessWithEnv(env map[string]string) ProcessOption { + return func(p *Process) { p.Env = env } +} + +// ProcessWithCwd sets the current working directory for the process. 
+func ProcessWithCwd(cwd string) ProcessOption { + return func(p *Process) { p.Cwd = cwd } +} diff --git a/extensions/e2b/sandbox.go b/extensions/e2b/sandbox.go index 4e6d235..13a6d40 100644 --- a/extensions/e2b/sandbox.go +++ b/extensions/e2b/sandbox.go @@ -26,42 +26,42 @@ type ( // // The sandbox is like an isolated, but interactive system. Sandbox struct { - ID string `json:"sandboxID"` // ID of the sandbox. - Metadata map[string]string `json:"metadata"` // Metadata of the sandbox. - Template SandboxTemplate `json:"templateID"` // Template of the sandbox. - ClientID string `json:"clientID"` // ClientID of the sandbox. - Cwd string `json:"cwd"` // Cwd is the sandbox's current working directory. - - logger *slog.Logger `json:"-"` // logger is the sandbox's logger. - apiKey string `json:"-"` // apiKey is the sandbox's api key. - baseURL string `json:"-"` // baseAPIURL is the base api url of the sandbox. - client *http.Client `json:"-"` // client is the sandbox's http client. - header builders.Header `json:"-"` // header is the sandbox's request header builder. - ws *websocket.Conn `json:"-"` // ws is the sandbox's websocket connection. - wsURL func(s *Sandbox) string `json:"-"` // wsURL is the sandbox's websocket url. - Map sync.Map `json:"-"` // Map is the map of the sandbox. - idCh chan int `json:"-"` // idCh is the channel to generate ids for requests. - toolW ToolingWrapper `json:"-"` // toolW is the tooling wrapper for the sandbox. + ID string `json:"sandboxID"` // ID of the sandbox. + ClientID string `json:"clientID"` // ClientID of the sandbox. + Cwd string `json:"cwd"` // Cwd is the sandbox's current working directory. + apiKey string `json:"-"` // apiKey is the sandbox's api key. + Template SandboxTemplate `json:"templateID"` // Template of the sandbox. + baseURL string `json:"-"` // baseAPIURL is the base api url of the sandbox. + Metadata map[string]string `json:"metadata"` // Metadata of the sandbox. 
+ logger *slog.Logger `json:"-"` // logger is the sandbox's logger. + client *http.Client `json:"-"` // client is the sandbox's http client. + header builders.Header `json:"-"` // header is the sandbox's request header builder. + ws *websocket.Conn `json:"-"` // ws is the sandbox's websocket connection. + wsURL func(s *Sandbox) string `json:"-"` // wsURL is the sandbox's websocket url. + Map *sync.Map `json:"-"` // Map is the map of the sandbox. + idCh chan int `json:"-"` // idCh is the channel to generate ids for requests. + toolW ToolingWrapper `json:"-"` // toolW is the tooling wrapper for the sandbox. } + // Option is an option for the sandbox. + Option func(*Sandbox) // Process is a process in the sandbox. Process struct { - sb *Sandbox // sb is the sandbox the process belongs to. - ctx context.Context // ctx is the context for the process. id string // ID is process id. cmd string // cmd is process's command. Cwd string // cwd is process's current working directory. + ctx context.Context // ctx is the context for the process. + sb *Sandbox // sb is the sandbox the process belongs to. Env map[string]string // env is process's environment variables. } - // Option is an option for the sandbox. - Option func(*Sandbox) + // ProcessOption is an option for the process. + ProcessOption func(*Process) // Event is a file system event. Event struct { - Path string `json:"path"` // Path is the path of the event. - Name string `json:"name"` // Name is the name of file or directory. - Timestamp int64 `json:"timestamp"` // Timestamp is the timestamp of the event. - Error string `json:"error"` // Error is the possible error of the event. - Params EventParams `json:"params"` // Params is the parameters of the event. - Operation OperationType `json:"operation"` // Operation is the operation type of the event. + Path string `json:"path"` // Path is the path of the event. + Name string `json:"name"` // Name is the name of file or directory. 
+ Timestamp int64 `json:"timestamp"` // Timestamp is the timestamp of the event. + Error string `json:"error"` // Error is the possible error of the event. + Params EventParams `json:"params"` // Params is the parameters of the event. } // EventParams is the params for subscribing to a process event. EventParams struct { @@ -76,8 +76,6 @@ type ( IsDirectory bool `json:"isDirectory"` Error string `json:"error"` } - // OperationType is an operation type. - OperationType int // Request is a JSON-RPC request. Request struct { JSONRPC string `json:"jsonrpc"` // JSONRPC is the JSON-RPC version of the request. @@ -110,10 +108,6 @@ const ( OnStderr ProcessEvents = "onStderr" // OnStderr is the event for the stderr. OnExit ProcessEvents = "onExit" // OnExit is the event for the exit. - EventTypeCreate OperationType = iota // EventTypeCreate is an event for the creation of a file/dir. - EventTypeWrite // EventTypeWrite is an event for the write to a file. - EventTypeRemove // EventTypeRemove is an event for the removal of a file/dir. 
- rpc = "2.0" charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" defaultBaseURL = "https://api.e2b.dev" @@ -151,20 +145,23 @@ func NewSandbox( }, client: http.DefaultClient, logger: slog.New(slog.NewJSONHandler(io.Discard, nil)), - idCh: make(chan int), toolW: defaultToolWrapper, + idCh: make(chan int), + Map: new(sync.Map), wsURL: func(s *Sandbox) string { return fmt.Sprintf("wss://49982-%s-%s.e2b.dev/ws", s.ID, s.ClientID) }, + header: builders.Header{ + SetCommonHeaders: func(req *http.Request) { + req.Header.Set("X-API-Key", apiKey) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + }, + }, } for _, opt := range opts { opt(&sb) } - sb.header.SetCommonHeaders = func(req *http.Request) { - req.Header.Set("X-API-Key", sb.apiKey) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json") - } req, err := builders.NewRequest( ctx, sb.header, http.MethodPost, fmt.Sprintf("%s%s", sb.baseURL, sandboxesRoute), @@ -185,7 +182,7 @@ func NewSandbox( go func() { err := sb.read(ctx) if err != nil { - fmt.Println(err) + sb.logger.Error("failed to read sandbox", "error", err) } }() return &sb, nil @@ -416,19 +413,24 @@ func (s *Sandbox) Watch( // NewProcess creates a new process startable in the sandbox. func (s *Sandbox) NewProcess( cmd string, - proc Process, + opts ...ProcessOption, ) (*Process, error) { - proc.cmd = cmd b := make([]byte, 12) for i := range b { b[i] = charset[rand.Intn(len(charset))] } - proc.id = string(b) - proc.sb = s + proc := &Process{ + id: string(b), + sb: s, + cmd: cmd, + } + for _, opt := range opts { + opt(proc) + } if proc.Cwd == "" { proc.Cwd = s.Cwd } - return &proc, nil + return proc, nil } // Start starts a process in the sandbox. @@ -478,18 +480,18 @@ func (p *Process) Done() <-chan struct{} { } // SubscribeStdout subscribes to the process's stdout. 
-func (p *Process) SubscribeStdout() (chan Event, chan error) { - return p.subscribe(p.ctx, OnStdout) +func (p *Process) SubscribeStdout(ctx context.Context) (chan Event, chan error) { + return p.subscribe(ctx, OnStdout) } // SubscribeStderr subscribes to the process's stderr. -func (p *Process) SubscribeStderr() (chan Event, chan error) { - return p.subscribe(p.ctx, OnStderr) +func (p *Process) SubscribeStderr(ctx context.Context) (chan Event, chan error) { + return p.subscribe(ctx, OnStderr) } // SubscribeExit subscribes to the process's exit. -func (p *Process) SubscribeExit() (chan Event, chan error) { - return p.subscribe(p.ctx, OnExit) +func (p *Process) SubscribeExit(ctx context.Context) (chan Event, chan error) { + return p.subscribe(ctx, OnExit) } // Subscribe subscribes to a process event. @@ -513,45 +515,37 @@ func (p *Process) subscribe( errCh <- err } p.sb.Map.Store(res.Result, respCh) + loop: for { select { case eventBd := <-respCh: - p.sb.logger.Debug("eventByCh", "event", string(eventBd)) var event Event _ = json.Unmarshal(eventBd, &event) if event.Error != "" { - p.sb.logger.Debug("failed to read event", "error", event.Error) - continue - } - if event.Params.Subscription != res.Result { - p.sb.logger.Debug("subscription id mismatch", "expected", res.Result, "got", event.Params.Subscription) + p.sb.logger.Error("failed to read event", "error", event.Error) continue } events <- event case <-ctx.Done(): - p.sb.Map.Delete(res.Result) - finishCtx, cancel := context.WithCancel(context.Background()) - defer cancel() - p.sb.logger.Debug("unsubscribing from process", "event", event, "id", res.Result) - _ = p.sb.writeRequest(finishCtx, processUnsubscribe, []any{res.Result}, respCh) - unsubRes, _ := decodeResponse[bool, string](<-respCh) - if unsubRes.Error != "" || !unsubRes.Result { - p.sb.logger.Debug("failed to unsubscribe from process", "error", unsubRes.Error) - } - return + break loop case <-p.Done(): - return + break loop } } + + 
p.sb.Map.Delete(res.Result) + finishCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + p.sb.logger.Debug("unsubscribing from process", "event", event, "id", res.Result) + _ = p.sb.writeRequest(finishCtx, processUnsubscribe, []any{res.Result}, respCh) + unsubRes, _ := decodeResponse[bool, string](<-respCh) + if unsubRes.Error != "" || !unsubRes.Result { + p.sb.logger.Debug("failed to unsubscribe from process", "error", unsubRes.Error) + } }(errs) return events, errs } func (s *Sandbox) sendRequest(req *http.Request, v interface{}) error { - req.Header.Set("Accept", "application/json") - contentType := req.Header.Get("Content-Type") - if contentType == "" { - req.Header.Set("Content-Type", "application/json") - } res, err := s.client.Do(req) if err != nil { return err @@ -584,19 +578,9 @@ func decodeResponse[T any, Q any](body []byte) (*Response[T, Q], error) { } return decResp, nil } -func (s *Sandbox) identify(ctx context.Context) { - id := 1 - for { - select { - case <-ctx.Done(): - return - default: - s.idCh <- id - id++ - } - } -} -func (s *Sandbox) read(ctx context.Context) (err error) { +func (s *Sandbox) read(ctx context.Context) error { + var body []byte + var err error type decResp struct { Method string `json:"method"` ID int `json:"id"` @@ -604,9 +588,11 @@ func (s *Sandbox) read(ctx context.Context) (err error) { Subscription string `json:"subscription"` } } - var body []byte defer func() { - err = s.ws.Close() + err := s.ws.Close() + if err != nil { + s.logger.Error("failed to close sandbox", "error", err) + } }() msgCh := make(chan []byte, 10) for { @@ -617,50 +603,35 @@ func (s *Sandbox) read(ctx context.Context) (err error) { if err != nil { return err } - if decResp.Params.Subscription != "" { - toR, ok := s.Map.Load(decResp.Params.Subscription) - if !ok { - msgCh <- body - continue - } - toRCh, ok := toR.(chan []byte) - if !ok { - msgCh <- body - continue - } - s.logger.Debug("read", - "subscription", 
decResp.Params.Subscription, - "body", body, - "sandbox", s.ID, - ) - toRCh <- body - } + var key any + key = decResp.Params.Subscription if decResp.ID != 0 { - toR, ok := s.Map.Load(decResp.ID) - if !ok { - msgCh <- body - continue - } - toRCh, ok := toR.(chan []byte) - if !ok { - msgCh <- body - continue - } - s.logger.Debug("read", - "id", decResp.ID, - "body", body, - "sandbox", s.ID, - ) - toRCh <- body + key = decResp.ID + } + toR, ok := s.Map.Load(key) + if !ok { + msgCh <- body + continue } + toRCh, ok := toR.(chan []byte) + if !ok { + msgCh <- body + continue + } + s.logger.Debug("read", + "subscription", decResp.Params.Subscription, + "body", body, + "sandbox", s.ID, + ) + toRCh <- body case <-ctx.Done(): return ctx.Err() default: - _, body, err := s.ws.ReadMessage() + _, msg, err := s.ws.ReadMessage() if err != nil { return err } - msgCh <- body + msgCh <- msg } } } @@ -681,10 +652,10 @@ func (s *Sandbox) writeRequest( ID: id, } s.logger.Debug("request", - "method", req.Method, - "id", req.ID, - "params", req.Params, - "sandbox", s.ID, + "sandbox", id, + "method", method, + "id", id, + "params", params, ) s.Map.Store(req.ID, respCh) jsVal, err := json.Marshal(req) @@ -703,3 +674,15 @@ func (s *Sandbox) writeRequest( return nil } } +func (s *Sandbox) identify(ctx context.Context) { + id := 1 + for { + select { + case <-ctx.Done(): + return + default: + s.idCh <- id + id++ + } + } +} diff --git a/extensions/e2b/sandbox_test.go b/extensions/e2b/sandbox_test.go index 02806ab..401e18b 100644 --- a/extensions/e2b/sandbox_test.go +++ b/extensions/e2b/sandbox_test.go @@ -116,6 +116,9 @@ func decode(bod []byte) Request { } func TestNewSandbox(t *testing.T) { + if test.IsIntegrationTest() { + t.Skip() + } a := assert.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -159,12 +162,12 @@ func TestNewSandbox(t *testing.T) { a.NoError(err) a.Equal("hello", readRes) - proc, err := sb.NewProcess("sleep 5 && echo 'hello world!'", Process{}) 
+ proc, err := sb.NewProcess("sleep 5 && echo 'hello world!'") a.NoError(err) err = proc.Start(ctx) a.NoError(err) - e, errCh := proc.SubscribeStdout() + e, errCh := proc.SubscribeStdout(ctx) select { case <-errCh: t.Fatal("got error from SubscribeStdout") diff --git a/extensions/e2b/tools.go b/extensions/e2b/tools.go index 3148e2e..976a530 100644 --- a/extensions/e2b/tools.go +++ b/extensions/e2b/tools.go @@ -127,15 +127,15 @@ var ( s *Sandbox, params *Params, ) (groq.ChatCompletionMessage, error) { - proc, err := s.NewProcess(params.Cmd, Process{}) + proc, err := s.NewProcess(params.Cmd) if err != nil { return groq.ChatCompletionMessage{}, err } - e, errCh := proc.SubscribeStdout() + e, errCh := proc.SubscribeStdout(ctx) if err != nil { return groq.ChatCompletionMessage{}, err } - e2, errCh := proc.SubscribeStderr() + e2, errCh := proc.SubscribeStderr(ctx) if err != nil { return groq.ChatCompletionMessage{}, err } diff --git a/extensions/e2b/unit_test.go b/extensions/e2b/unit_test.go index 0346a24..cfb48e3 100644 --- a/extensions/e2b/unit_test.go +++ b/extensions/e2b/unit_test.go @@ -90,20 +90,17 @@ func TestCreateProcess(t *testing.T) { e2b.WithLogger(test.DefaultLogger), ) a.NoError(err, "NewSandbox error") - proc, err := sb.NewProcess("echo 'Hello World!'", - e2b.Process{ - Env: map[string]string{ - "FOO": "bar", - }, - }) + proc, err := sb.NewProcess("echo 'Hello World!'", e2b.ProcessWithEnv(map[string]string{ + "FOO": "bar", + })) a.NoError(err, "could not create process") err = proc.Start(ctx) a.NoError(err) - proc, err = sb.NewProcess("sleep 2 && echo 'Hello World!'", e2b.Process{}) + proc, err = sb.NewProcess("sleep 2 && echo 'Hello World!'") a.NoError(err, "could not create process") err = proc.Start(ctx) a.NoError(err) - stdOutEvents, errCh := proc.SubscribeStdout() + stdOutEvents, errCh := proc.SubscribeStdout(ctx) a.NoError(err) select { case <-errCh: diff --git a/extensions/jigsawstack/tts.mp3 b/extensions/jigsawstack/tts.mp3 index 4e38023..e69de29 
100644 Binary files a/extensions/jigsawstack/tts.mp3 and b/extensions/jigsawstack/tts.mp3 differ diff --git a/extensions/toolhouse/toolhouse.go b/extensions/toolhouse/toolhouse.go index d38070e..318ad1b 100644 --- a/extensions/toolhouse/toolhouse.go +++ b/extensions/toolhouse/toolhouse.go @@ -23,10 +23,10 @@ type ( Toolhouse struct { apiKey string baseURL string - client *http.Client provider string - metadata map[string]any bundle string + client *http.Client + metadata map[string]any logger *slog.Logger header builders.Header } diff --git a/go.mod b/go.mod index 58356d9..6270f73 100644 --- a/go.mod +++ b/go.mod @@ -3,17 +3,14 @@ module github.com/conneroisu/groq-go go 1.23.2 require ( + github.com/buger/jsonparser v1.1.1 github.com/gorilla/websocket v1.5.3 github.com/stretchr/testify v1.9.0 - github.com/wk8/go-ordered-map/v2 v2.1.8 ) require ( - github.com/bahlo/generic-list-go v0.2.0 // indirect - github.com/buger/jsonparser v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/kr/pretty v0.3.0 // indirect - github.com/mailru/easyjson v0.7.7 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.8.1 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect diff --git a/go.sum b/go.sum index f7c9cc1..c6c3137 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= -github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -7,7 +5,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= @@ -16,8 +13,6 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= -github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -26,8 +21,6 @@ github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XF github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= -github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 
v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/pkg/list/doc.go b/pkg/list/doc.go new file mode 100644 index 0000000..15e3852 --- /dev/null +++ b/pkg/list/doc.go @@ -0,0 +1,2 @@ +// Package list contains the implementation of a doubly linked list. +package list diff --git a/pkg/list/element.go b/pkg/list/element.go new file mode 100644 index 0000000..cdd6738 --- /dev/null +++ b/pkg/list/element.go @@ -0,0 +1,35 @@ +package list + +// Element is an element of a linked list. +type Element[T any] struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *Element[T] + + // The list to which this element belongs. + list *List[T] + + // The value stored with this element. + Value T +} + +// Next returns the next list element or nil. +func (e *Element[T]) Next() *Element[T] { + p := e.next + if e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *Element[T]) Prev() *Element[T] { + p := e.prev + if e.list != nil && p != &e.list.root { + return p + } + return nil +} diff --git a/pkg/list/list.go b/pkg/list/list.go new file mode 100644 index 0000000..b78a5ca --- /dev/null +++ b/pkg/list/list.go @@ -0,0 +1,192 @@ +package list + +// List represents a doubly linked list. +// The zero value for List is an empty list ready to use. +type List[T any] struct { + root Element[T] // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. 
+func (l *List[T]) Init() *List[T] { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// New returns an initialized list. +func New[T any]() *List[T] { return new(List[T]).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *List[T]) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *List[T]) Front() *Element[T] { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *List[T]) Back() *Element[T] { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *List[T]) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *List[T]) insert(e, at *Element[T]) *Element[T] { + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] { + return l.insert(&Element[T]{Value: v}, at) +} + +// remove removes e from its list, decrements l.len +func (l *List[T]) remove(e *Element[T]) { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- +} + +// move moves e to next to at. +func (l *List[T]) move(e, at *Element[T]) { + if e == at { + return + } + e.prev.next = e.next + e.next.prev = e.prev + + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. 
+func (l *List[T]) Remove(e *Element[T]) T { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *List[T]) PushFront(v T) *Element[T] { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *List[T]) PushBack(v T) *Element[T] { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *List[T]) MoveToFront(e *Element[T]) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. 
+func (l *List[T]) MoveToBack(e *Element[T]) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *List[T]) MoveBefore(e, mark *Element[T]) { + if e.list != l || e == mark || mark.list != l { + return + } + l.move(e, mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *List[T]) MoveAfter(e, mark *Element[T]) { + if e.list != l || e == mark || mark.list != l { + return + } + l.move(e, mark) +} + +// PushBackList inserts a copy of another list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *List[T]) PushBackList(other *List[T]) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of another list at the front of list l. +// The lists l and other may be the same. They must not be nil. 
+func (l *List[T]) PushFrontList(other *List[T]) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/pkg/list/list_test.go b/pkg/list/list_test.go new file mode 100644 index 0000000..cd071c4 --- /dev/null +++ b/pkg/list/list_test.go @@ -0,0 +1,292 @@ +package list + +import "testing" + +func checkListLen[T any](t *testing.T, l *List[T], length int) bool { + if n := l.Len(); n != length { + t.Errorf("l.Len() = %d, want %d", n, length) + return false + } + return true +} +func checkListPointers[T any](t *testing.T, l *List[T], es []*Element[T]) { + root := &l.root + if !checkListLen(t, l, len(es)) { + return + } + // zero length lists must be the zero value or properly initialized (sentinel circle) + if len(es) == 0 { + if l.root.next != nil && l.root.next != root || l.root.prev != nil && l.root.prev != root { + t.Errorf("l.root.next = %p, l.root.prev = %p; both should both be nil or %p", l.root.next, l.root.prev, root) + } + return + } + // len(es) > 0 + // check internal and external prev/next connections + for i, e := range es { + prev := root + Prev := (*Element[T])(nil) + if i > 0 { + prev = es[i-1] + Prev = prev + } + if p := e.prev; p != prev { + t.Errorf("elt[%d](%p).prev = %p, want %p", i, e, p, prev) + } + if p := e.Prev(); p != Prev { + t.Errorf("elt[%d](%p).Prev() = %p, want %p", i, e, p, Prev) + } + next := root + Next := (*Element[T])(nil) + if i < len(es)-1 { + next = es[i+1] + Next = next + } + if n := e.next; n != next { + t.Errorf("elt[%d](%p).next = %p, want %p", i, e, n, next) + } + if n := e.Next(); n != Next { + t.Errorf("elt[%d](%p).Next() = %p, want %p", i, e, n, Next) + } + } +} +func TestList(t *testing.T) { + // Single element list + { + l := New[string]() + checkListPointers(t, l, []*Element[string]{}) + e := l.PushFront("a") + checkListPointers(t, l, []*Element[string]{e}) + l.MoveToFront(e) + checkListPointers(t, l, []*Element[string]{e}) + 
l.MoveToBack(e) + checkListPointers(t, l, []*Element[string]{e}) + l.Remove(e) + checkListPointers(t, l, []*Element[string]{}) + } + // Bigger list + l := New[int]() + checkListPointers(t, l, []*Element[int]{}) + e2 := l.PushFront(2) + e1 := l.PushFront(1) + e3 := l.PushBack(3) + e4 := l.PushBack(4) + checkListPointers(t, l, []*Element[int]{e1, e2, e3, e4}) + l.Remove(e2) + checkListPointers(t, l, []*Element[int]{e1, e3, e4}) + l.MoveToFront(e3) // move from middle + checkListPointers(t, l, []*Element[int]{e3, e1, e4}) + l.MoveToFront(e1) + l.MoveToBack(e3) // move from middle + checkListPointers(t, l, []*Element[int]{e1, e4, e3}) + l.MoveToFront(e3) // move from back + checkListPointers(t, l, []*Element[int]{e3, e1, e4}) + l.MoveToFront(e3) // should be no-op + checkListPointers(t, l, []*Element[int]{e3, e1, e4}) + l.MoveToBack(e3) // move from front + checkListPointers(t, l, []*Element[int]{e1, e4, e3}) + l.MoveToBack(e3) // should be no-op + checkListPointers(t, l, []*Element[int]{e1, e4, e3}) + e2 = l.InsertBefore(2, e1) // insert before front + checkListPointers(t, l, []*Element[int]{e2, e1, e4, e3}) + l.Remove(e2) + e2 = l.InsertBefore(2, e4) // insert before middle + checkListPointers(t, l, []*Element[int]{e1, e2, e4, e3}) + l.Remove(e2) + e2 = l.InsertBefore(2, e3) // insert before back + checkListPointers(t, l, []*Element[int]{e1, e4, e2, e3}) + l.Remove(e2) + e2 = l.InsertAfter(2, e1) // insert after front + checkListPointers(t, l, []*Element[int]{e1, e2, e4, e3}) + l.Remove(e2) + e2 = l.InsertAfter(2, e4) // insert after middle + checkListPointers(t, l, []*Element[int]{e1, e4, e2, e3}) + l.Remove(e2) + e2 = l.InsertAfter(2, e3) // insert after back + checkListPointers(t, l, []*Element[int]{e1, e4, e3, e2}) + l.Remove(e2) + // Check standard iteration. 
+ sum := 0 + for e := l.Front(); e != nil; e = e.Next() { + sum += e.Value + } + if sum != 8 { + t.Errorf("sum over l = %d, want 8", sum) + } + // Clear all elements by iterating + var next *Element[int] + for e := l.Front(); e != nil; e = next { + next = e.Next() + l.Remove(e) + } + checkListPointers(t, l, []*Element[int]{}) +} +func checkList[T int](t *testing.T, l *List[T], es []T) { + if !checkListLen(t, l, len(es)) { + return + } + i := 0 + for e := l.Front(); e != nil; e = e.Next() { + le := e.Value + if le != es[i] { + t.Errorf("elt[%d].Value = %v, want %v", i, le, es[i]) + } + i++ + } +} +func TestExtending(t *testing.T) { + l1 := New[int]() + l2 := New[int]() + l1.PushBack(1) + l1.PushBack(2) + l1.PushBack(3) + l2.PushBack(4) + l2.PushBack(5) + l3 := New[int]() + l3.PushBackList(l1) + checkList(t, l3, []int{1, 2, 3}) + l3.PushBackList(l2) + checkList(t, l3, []int{1, 2, 3, 4, 5}) + l3 = New[int]() + l3.PushFrontList(l2) + checkList(t, l3, []int{4, 5}) + l3.PushFrontList(l1) + checkList(t, l3, []int{1, 2, 3, 4, 5}) + checkList(t, l1, []int{1, 2, 3}) + checkList(t, l2, []int{4, 5}) + l3 = New[int]() + l3.PushBackList(l1) + checkList(t, l3, []int{1, 2, 3}) + l3.PushBackList(l3) + checkList(t, l3, []int{1, 2, 3, 1, 2, 3}) + l3 = New[int]() + l3.PushFrontList(l1) + checkList(t, l3, []int{1, 2, 3}) + l3.PushFrontList(l3) + checkList(t, l3, []int{1, 2, 3, 1, 2, 3}) + l3 = New[int]() + l1.PushBackList(l3) + checkList(t, l1, []int{1, 2, 3}) + l1.PushFrontList(l3) + checkList(t, l1, []int{1, 2, 3}) +} +func TestRemove(t *testing.T) { + l := New[int]() + e1 := l.PushBack(1) + e2 := l.PushBack(2) + checkListPointers(t, l, []*Element[int]{e1, e2}) + e := l.Front() + l.Remove(e) + checkListPointers(t, l, []*Element[int]{e2}) + l.Remove(e) + checkListPointers(t, l, []*Element[int]{e2}) +} +func TestIssue4103(t *testing.T) { + l1 := New[int]() + l1.PushBack(1) + l1.PushBack(2) + l2 := New[int]() + l2.PushBack(3) + l2.PushBack(4) + e := l1.Front() + l2.Remove(e) // l2 
should not change because e is not an element of l2 + if n := l2.Len(); n != 2 { + t.Errorf("l2.Len() = %d, want 2", n) + } + l1.InsertBefore(8, e) + if n := l1.Len(); n != 3 { + t.Errorf("l1.Len() = %d, want 3", n) + } +} +func TestIssue6349(t *testing.T) { + l := New[int]() + l.PushBack(1) + l.PushBack(2) + e := l.Front() + l.Remove(e) + if e.Value != 1 { + t.Errorf("e.value = %d, want 1", e.Value) + } + if e.Next() != nil { + t.Errorf("e.Next() != nil") + } + if e.Prev() != nil { + t.Errorf("e.Prev() != nil") + } +} +func TestMove(t *testing.T) { + l := New[int]() + e1 := l.PushBack(1) + e2 := l.PushBack(2) + e3 := l.PushBack(3) + e4 := l.PushBack(4) + l.MoveAfter(e3, e3) + checkListPointers(t, l, []*Element[int]{e1, e2, e3, e4}) + l.MoveBefore(e2, e2) + checkListPointers(t, l, []*Element[int]{e1, e2, e3, e4}) + l.MoveAfter(e3, e2) + checkListPointers(t, l, []*Element[int]{e1, e2, e3, e4}) + l.MoveBefore(e2, e3) + checkListPointers(t, l, []*Element[int]{e1, e2, e3, e4}) + l.MoveBefore(e2, e4) + checkListPointers(t, l, []*Element[int]{e1, e3, e2, e4}) + e2, e3 = e3, e2 + l.MoveBefore(e4, e1) + checkListPointers(t, l, []*Element[int]{e4, e1, e2, e3}) + e1, e2, e3, e4 = e4, e1, e2, e3 + l.MoveAfter(e4, e1) + checkListPointers(t, l, []*Element[int]{e1, e4, e2, e3}) + e2, e3, e4 = e4, e2, e3 + l.MoveAfter(e2, e3) + checkListPointers(t, l, []*Element[int]{e1, e3, e2, e4}) +} + +// Test PushFront, PushBack, PushFrontList, PushBackList with uninitialized List +func TestZeroList(t *testing.T) { + var l1 = new(List[int]) + l1.PushFront(1) + checkList(t, l1, []int{1}) + var l2 = new(List[int]) + l2.PushBack(1) + checkList(t, l2, []int{1}) + var l3 = new(List[int]) + l3.PushFrontList(l1) + checkList(t, l3, []int{1}) + var l4 = new(List[int]) + l4.PushBackList(l2) + checkList(t, l4, []int{1}) +} + +// Test that a list l is not modified when calling InsertBefore with a mark that is not an element of l. 
+func TestInsertBeforeUnknownMark(t *testing.T) { + var l List[int] + l.PushBack(1) + l.PushBack(2) + l.PushBack(3) + l.InsertBefore(1, new(Element[int])) + checkList(t, &l, []int{1, 2, 3}) +} + +// Test that a list l is not modified when calling InsertAfter with a mark that is not an element of l. +func TestInsertAfterUnknownMark(t *testing.T) { + var l List[int] + l.PushBack(1) + l.PushBack(2) + l.PushBack(3) + l.InsertAfter(1, new(Element[int])) + checkList(t, &l, []int{1, 2, 3}) +} + +// Test that a list l is not modified when calling MoveAfter or MoveBefore with a mark that is not an element of l. +func TestMoveUnknownMark(t *testing.T) { + var l1 List[int] + e1 := l1.PushBack(1) + var l2 List[int] + e2 := l2.PushBack(2) + l1.MoveAfter(e1, e2) + checkList(t, &l1, []int{1}) + checkList(t, &l2, []int{2}) + l1.MoveBefore(e1, e2) + checkList(t, &l1, []int{1}) + checkList(t, &l2, []int{2}) +} diff --git a/pkg/omap/doc.go b/pkg/omap/doc.go new file mode 100644 index 0000000..2928ce6 --- /dev/null +++ b/pkg/omap/doc.go @@ -0,0 +1,2 @@ +// Package omap provides an ordered map implementation. +package omap diff --git a/pkg/omap/json.go b/pkg/omap/json.go new file mode 100644 index 0000000..3c094a0 --- /dev/null +++ b/pkg/omap/json.go @@ -0,0 +1,181 @@ +package omap + +import ( + "bytes" + "encoding" + "encoding/json" + "fmt" + "reflect" + "unicode/utf8" + + "github.com/buger/jsonparser" +) + +var ( + _ json.Marshaler = &OrderedMap[int, any]{} + _ json.Unmarshaler = &OrderedMap[int, any]{} +) + +// MarshalJSON implements the json.Marshaler interface. 
+func (om *OrderedMap[K, V]) MarshalJSON() ([]byte, error) { //nolint:funlen + if om == nil || om.list == nil { + return []byte("null"), nil + } + + writer := Writer{} + writer.RawByte('{') + + for pair, firstIteration := om.Oldest(), true; pair != nil; pair = pair.Next() { + if firstIteration { + firstIteration = false + } else { + writer.RawByte(',') + } + + switch key := any(pair.Key).(type) { + case string: + writer.String(key) + case encoding.TextMarshaler: + writer.RawByte('"') + writer.Raw(key.MarshalText()) + writer.RawByte('"') + case int: + writer.IntStr(key) + case int8: + writer.Int8Str(key) + case int16: + writer.Int16Str(key) + case int32: + writer.Int32Str(key) + case int64: + writer.Int64Str(key) + case uint: + writer.UintStr(key) + case uint8: + writer.Uint8Str(key) + case uint16: + writer.Uint16Str(key) + case uint32: + writer.Uint32Str(key) + case uint64: + writer.Uint64Str(key) + default: + + // this switch takes care of wrapper types around primitive types, such as + // type myType string + switch keyValue := reflect.ValueOf(key); keyValue.Type().Kind() { + case reflect.String: + writer.String(keyValue.String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + writer.Int64Str(keyValue.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + writer.Uint64Str(keyValue.Uint()) + default: + return nil, fmt.Errorf("unsupported key type: %T", key) + } + } + + writer.RawByte(':') + // the error is checked at the end of the function + writer.Raw(json.Marshal(pair.Value)) //nolint:errchkjson + } + + writer.RawByte('}') + + return dumpWriter(&writer) +} + +func dumpWriter(writer *Writer) ([]byte, error) { + if writer.Error != nil { + return nil, writer.Error + } + + var buf bytes.Buffer + buf.Grow(writer.Size()) + if _, err := writer.DumpTo(&buf); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface. 
+func (om *OrderedMap[K, V]) UnmarshalJSON(data []byte) error { + if om.list == nil { + om.initialize(0) + } + + return jsonparser.ObjectEach( + data, + func(keyData []byte, valueData []byte, dataType jsonparser.ValueType, offset int) error { + if dataType == jsonparser.String { + // jsonparser removes the enclosing quotes; we need to restore them to make a valid JSON + valueData = data[offset-len(valueData)-2 : offset] + } + + var key K + var value V + + switch typedKey := any(&key).(type) { + case *string: + s, err := decodeUTF8(keyData) + if err != nil { + return err + } + *typedKey = s + case encoding.TextUnmarshaler: + if err := typedKey.UnmarshalText(keyData); err != nil { + return err + } + case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64: + if err := json.Unmarshal(keyData, typedKey); err != nil { + return err + } + default: + // this switch takes care of wrapper types around primitive types, such as + // type myType string + switch reflect.TypeOf(key).Kind() { + case reflect.String: + s, err := decodeUTF8(keyData) + if err != nil { + return err + } + + convertedKeyData := reflect.ValueOf(s).Convert(reflect.TypeOf(key)) + reflect.ValueOf(&key).Elem().Set(convertedKeyData) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := json.Unmarshal(keyData, &key); err != nil { + return err + } + default: + return fmt.Errorf("unsupported key type: %T", key) + } + } + + if err := json.Unmarshal(valueData, &value); err != nil { + return err + } + + om.Set(key, value) + return nil + }) +} + +func decodeUTF8(input []byte) (string, error) { + remaining, offset := input, 0 + runes := make([]rune, 0, len(remaining)) + + for len(remaining) > 0 { + r, size := utf8.DecodeRune(remaining) + if r == utf8.RuneError && size <= 1 { + return "", fmt.Errorf("not a valid UTF-8 string (at position %d): %s", offset, string(input)) + } + + 
runes = append(runes, r) + remaining = remaining[size:] + offset += size + } + + return string(runes), nil +} diff --git a/pkg/omap/json_fuzz_test.go b/pkg/omap/json_fuzz_test.go new file mode 100644 index 0000000..cb81d64 --- /dev/null +++ b/pkg/omap/json_fuzz_test.go @@ -0,0 +1,111 @@ +package omap + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func FuzzRoundTripJSON(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + for _, testCase := range []struct { + name string + constructor func() any + equalityAssertion func(*testing.T, any, any) bool + }{ + { + name: "with a string -> string map", + constructor: func() any { return &OrderedMap[string, string]{} }, + equalityAssertion: assertOrderedMapsEqual[string, string], + }, + { + name: "with a string -> int map", + constructor: func() any { return &OrderedMap[string, int]{} }, + equalityAssertion: assertOrderedMapsEqual[string, int], + }, + { + name: "with a string -> any map", + constructor: func() any { return &OrderedMap[string, any]{} }, + equalityAssertion: assertOrderedMapsEqual[string, any], + }, + { + name: "with a struct with map fields", + constructor: func() any { return new(testFuzzStruct) }, + equalityAssertion: assertTestFuzzStructEqual, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + v1 := testCase.constructor() + if json.Unmarshal(data, v1) != nil { + return + } + + jsonData, err := json.Marshal(v1) + require.NoError(t, err) + + v2 := testCase.constructor() + require.NoError(t, json.Unmarshal(jsonData, v2)) + + if !assert.True(t, testCase.equalityAssertion(t, v1, v2), "failed with input data %q", string(data)) { + var m1 map[string]any + require.NoError(t, json.Unmarshal(data, &m1)) + + mapJSONData, err := json.Marshal(m1) + require.NoError(t, err) + + var m2 map[string]any + require.NoError(t, json.Unmarshal(mapJSONData, &m2)) + + t.Logf("initial data = %s", string(data)) + t.Logf("unmarshalled map = 
%v", m1)
+				t.Logf("re-marshalled from map = %s", string(mapJSONData))
+				t.Logf("re-marshalled from test obj = %s", string(jsonData))
+				// %v, not %s: m2 is a map[string]any; %s would render
+				// non-string values as "%!s(int=…)" noise in the log.
+				t.Logf("re-unmarshalled map = %v", m2)
+			}
+		})
+		}
+	})
+}
+
+func assertOrderedMapsEqual[K comparable, V any](t *testing.T, v1, v2 any) bool {
+	om1, ok1 := v1.(*OrderedMap[K, V])
+	om2, ok2 := v2.(*OrderedMap[K, V])
+
+	if !assert.True(t, ok1, "v1 not an orderedmap") ||
+		!assert.True(t, ok2, "v2 not an orderedmap") {
+		return false
+	}
+
+	success := assert.Equal(t, om1.Len(), om2.Len(), "om1 and om2 have different lengths: %d vs %d", om1.Len(), om2.Len())
+
+	for i, pair1, pair2 := 0, om1.Oldest(), om2.Oldest(); pair1 != nil && pair2 != nil; i, pair1, pair2 = i+1, pair1.Next(), pair2.Next() {
+		success = assert.Equal(t, pair1.Key, pair2.Key, "different keys at position %d: %v vs %v", i, pair1.Key, pair2.Key) && success
+		success = assert.Equal(t, pair1.Value, pair2.Value, "different values at position %d: %v vs %v", i, pair1.Value, pair2.Value) && success
+	}
+
+	return success
+}
+
+type testFuzzStruct struct {
+	M1 *OrderedMap[int, any]
+	M2 *OrderedMap[int, string]
+	M3 *OrderedMap[string, string]
+}
+
+func assertTestFuzzStructEqual(t *testing.T, v1, v2 any) bool {
+	s1, ok1 := v1.(*testFuzzStruct)
+	s2, ok2 := v2.(*testFuzzStruct)
+
+	if !assert.True(t, ok1, "v1 not an testFuzzStruct") ||
+		!assert.True(t, ok2, "v2 not an testFuzzStruct") {
+		return false
+	}
+
+	success := assertOrderedMapsEqual[int, any](t, s1.M1, s2.M1)
+	success = assertOrderedMapsEqual[int, string](t, s1.M2, s2.M2) && success
+	success = assertOrderedMapsEqual[string, string](t, s1.M3, s2.M3) && success
+
+	return success
+}
diff --git a/pkg/omap/json_test.go b/pkg/omap/json_test.go
new file mode 100644
index 0000000..43b1ec6
--- /dev/null
+++ b/pkg/omap/json_test.go
@@ -0,0 +1,338 @@
+package omap
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"strconv"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	
"github.com/stretchr/testify/require" +) + +// to test marshalling TextMarshalers and unmarshalling TextUnmarshalers +type marshallable int + +func (m marshallable) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("#%d#", m)), nil +} + +func (m *marshallable) UnmarshalText(text []byte) error { + if len(text) < 3 { + return errors.New("too short") + } + if text[0] != '#' || text[len(text)-1] != '#' { + return errors.New("missing prefix or suffix") + } + + value, err := strconv.Atoi(string(text[1 : len(text)-1])) + if err != nil { + return err + } + + *m = marshallable(value) + return nil +} + +func TestMarshalJSON(t *testing.T) { + t.Run("int key", func(t *testing.T) { + om := New[int, any]() + om.Set(1, "bar") + om.Set(7, "baz") + om.Set(2, 28) + om.Set(3, 100) + om.Set(4, "baz") + om.Set(5, "28") + om.Set(6, "100") + om.Set(8, "baz") + om.Set(8, "baz") + om.Set(9, "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque auctor augue accumsan mi maximus, quis viverra massa pretium. Phasellus imperdiet sapien a interdum sollicitudin. Duis at commodo lectus, a lacinia sem.") + + b, err := json.Marshal(om) + assert.NoError(t, err) + assert.Equal(t, `{"1":"bar","7":"baz","2":28,"3":100,"4":"baz","5":"28","6":"100","8":"baz","9":"Lorem ipsum dolor sit amet, consectetur adipiscing elit. Quisque auctor augue accumsan mi maximus, quis viverra massa pretium. Phasellus imperdiet sapien a interdum sollicitudin. 
Duis at commodo lectus, a lacinia sem."}`, string(b)) + }) + + t.Run("string key", func(t *testing.T) { + om := New[string, any]() + om.Set("test", "bar") + om.Set("abc", true) + + b, err := json.Marshal(om) + assert.NoError(t, err) + assert.Equal(t, `{"test":"bar","abc":true}`, string(b)) + }) + + t.Run("typed string key", func(t *testing.T) { + type myString string + om := New[myString, any]() + om.Set("test", "bar") + om.Set("abc", true) + + b, err := json.Marshal(om) + assert.NoError(t, err) + assert.Equal(t, `{"test":"bar","abc":true}`, string(b)) + }) + + t.Run("typed int key", func(t *testing.T) { + type myInt uint32 + om := New[myInt, any]() + om.Set(1, "bar") + om.Set(7, "baz") + om.Set(2, 28) + om.Set(3, 100) + om.Set(4, "baz") + + b, err := json.Marshal(om) + assert.NoError(t, err) + assert.Equal(t, `{"1":"bar","7":"baz","2":28,"3":100,"4":"baz"}`, string(b)) + }) + + t.Run("TextMarshaller key", func(t *testing.T) { + om := New[marshallable, any]() + om.Set(marshallable(1), "bar") + om.Set(marshallable(28), true) + + b, err := json.Marshal(om) + assert.NoError(t, err) + assert.Equal(t, `{"#1#":"bar","#28#":true}`, string(b)) + }) + + t.Run("empty map", func(t *testing.T) { + om := New[string, any]() + + b, err := json.Marshal(om) + assert.NoError(t, err) + assert.Equal(t, `{}`, string(b)) + }) +} + +func TestUnmarshallJSON(t *testing.T) { + t.Run("int key", func(t *testing.T) { + data := `{"1":"bar","7":"baz","2":28,"3":100,"4":"baz","5":"28","6":"100","8":"baz"}` + + om := New[int, any]() + require.NoError(t, json.Unmarshal([]byte(data), &om)) + + assertOrderedPairsEqual(t, om, + []int{1, 7, 2, 3, 4, 5, 6, 8}, + []any{"bar", "baz", float64(28), float64(100), "baz", "28", "100", "baz"}) + }) + + t.Run("string key", func(t *testing.T) { + data := `{"test":"bar","abc":true}` + + om := New[string, any]() + require.NoError(t, json.Unmarshal([]byte(data), &om)) + + assertOrderedPairsEqual(t, om, + []string{"test", "abc"}, + []any{"bar", true}) + }) + + 
t.Run("typed string key", func(t *testing.T) { + data := `{"test":"bar","abc":true}` + + type myString string + om := New[myString, any]() + require.NoError(t, json.Unmarshal([]byte(data), &om)) + + assertOrderedPairsEqual(t, om, + []myString{"test", "abc"}, + []any{"bar", true}) + }) + + t.Run("typed int key", func(t *testing.T) { + data := `{"1":"bar","7":"baz","2":28,"3":100,"4":"baz","5":"28","6":"100","8":"baz"}` + + type myInt uint32 + om := New[myInt, any]() + require.NoError(t, json.Unmarshal([]byte(data), &om)) + + assertOrderedPairsEqual(t, om, + []myInt{1, 7, 2, 3, 4, 5, 6, 8}, + []any{"bar", "baz", float64(28), float64(100), "baz", "28", "100", "baz"}) + }) + + t.Run("TextUnmarshaler key", func(t *testing.T) { + data := `{"#1#":"bar","#28#":true}` + + om := New[marshallable, any]() + require.NoError(t, json.Unmarshal([]byte(data), &om)) + + assertOrderedPairsEqual(t, om, + []marshallable{1, 28}, + []any{"bar", true}) + }) + + t.Run("when fed with an input that's not an object", func(t *testing.T) { + for _, data := range []string{"true", `["foo"]`, "42", `"foo"`} { + om := New[int, any]() + require.Error(t, json.Unmarshal([]byte(data), &om)) + } + }) + + t.Run("empty map", func(t *testing.T) { + data := `{}` + + om := New[int, any]() + require.NoError(t, json.Unmarshal([]byte(data), &om)) + + assertLenEqual(t, om, 0) + }) +} + +// const specialCharacters = "\\\\/\"\b\f\n\r\t\x00\uffff\ufffd世界\u007f\u00ff\U0010FFFF" +const specialCharacters = "\uffff\ufffd世界\u007f\u00ff\U0010FFFF" + +func TestJSONSpecialCharacters(t *testing.T) { + baselineMap := map[string]any{specialCharacters: specialCharacters} + baselineData, err := json.Marshal(baselineMap) + require.NoError(t, err) // baseline proves this key is supported by official json library + t.Logf("specialCharacters: %#v as []rune:%v", specialCharacters, []rune(specialCharacters)) + t.Logf("baseline json data: %s", baselineData) + + t.Run("marshal special characters", func(t *testing.T) { + om := 
New[string, any]() + om.Set(specialCharacters, specialCharacters) + b, err := json.Marshal(om) + require.NoError(t, err) + require.Equal(t, baselineData, b) + + type myString string + om2 := New[myString, myString]() + om2.Set(specialCharacters, specialCharacters) + b, err = json.Marshal(om2) + require.NoError(t, err) + require.Equal(t, baselineData, b) + }) + + t.Run("unmarshall special characters", func(t *testing.T) { + om := New[string, any]() + require.NoError(t, json.Unmarshal(baselineData, &om)) + assertOrderedPairsEqual(t, om, + []string{specialCharacters}, + []any{specialCharacters}) + + type myString string + om2 := New[myString, myString]() + require.NoError(t, json.Unmarshal(baselineData, &om2)) + assertOrderedPairsEqual(t, om2, + []myString{specialCharacters}, + []myString{specialCharacters}) + }) +} + +// to test structs that have nested map fields +type nestedMaps struct { + X int `json:"x" yaml:"x"` + M *OrderedMap[string, []*OrderedMap[int, *OrderedMap[string, any]]] `json:"m" yaml:"m"` +} + +func TestJSONRoundTrip(t *testing.T) { + for _, testCase := range []struct { + name string + input string + targetFactory func() any + isPrettyPrinted bool + }{ + { + name: "", + input: `{ + "x": 28, + "m": { + "foo": [ + { + "12": { + "i": 12, + "b": true, + "n": null, + "m": { + "a": "b", + "c": 28 + } + }, + "28": { + "a": false, + "b": [ + 1, + 2, + 3 + ] + } + }, + { + "3": { + "c": null, + "d": 87 + }, + "4": { + "e": true + }, + "5": { + "f": 4, + "g": 5, + "h": 6 + } + } + ], + "bar": [ + { + "5": { + "foo": "bar" + } + } + ] + } +}`, + targetFactory: func() any { return &nestedMaps{} }, + isPrettyPrinted: true, + }, + { + name: "with UTF-8 special chars in key", + input: `{"�":0}`, + targetFactory: func() any { return &OrderedMap[string, int]{} }, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + target := testCase.targetFactory() + + require.NoError(t, json.Unmarshal([]byte(testCase.input), target)) + + var ( + out []byte + err error + ) + if 
testCase.isPrettyPrinted { + out, err = json.MarshalIndent(target, "", " ") + } else { + out, err = json.Marshal(target) + } + + if assert.NoError(t, err) { + assert.Equal(t, strings.TrimSpace(testCase.input), string(out)) + } + }) + } +} + +func BenchmarkMarshalJSON(b *testing.B) { + om := New[int, any]() + om.Set(1, "bar") + om.Set(7, "baz") + om.Set(2, 28) + om.Set(3, 100) + om.Set(4, "baz") + om.Set(5, "28") + om.Set(6, "100") + om.Set(8, "baz") + om.Set(8, "baz") + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, _ = json.Marshal(om) + } +} diff --git a/pkg/omap/omap.go b/pkg/omap/omap.go new file mode 100644 index 0000000..307e193 --- /dev/null +++ b/pkg/omap/omap.go @@ -0,0 +1,292 @@ +package omap + +import ( + "fmt" + + "github.com/conneroisu/groq-go/pkg/list" +) + +// Pair is a generic pair. +type Pair[K comparable, V any] struct { + Key K + Value V + + element *list.Element[*Pair[K, V]] +} + +// OrderedMap is a generic ordered map. +type OrderedMap[K comparable, V any] struct { + pairs map[K]*Pair[K, V] + list *list.List[*Pair[K, V]] +} + +type initConfig[K comparable, V any] struct { + capacity int + initialData []Pair[K, V] +} + +// InitOption is an option for initializing an OrderedMap. +type InitOption[K comparable, V any] func(config *initConfig[K, V]) + +// WithCapacity allows giving a capacity hint for the map, akin to the standard make(map[K]V, capacity). +func WithCapacity[K comparable, V any](capacity int) InitOption[K, V] { + return func(c *initConfig[K, V]) { + c.capacity = capacity + } +} + +// WithInitialData allows passing in initial data for the map. +func WithInitialData[K comparable, V any](initialData ...Pair[K, V]) InitOption[K, V] { + return func(c *initConfig[K, V]) { + c.initialData = initialData + if c.capacity < len(initialData) { + c.capacity = len(initialData) + } + } +} + +// New creates a new OrderedMap. 
+// options can either be one or several InitOption[K, V], or a single integer, +// which is then interpreted as a capacity hint, à la make(map[K]V, capacity). +func New[K comparable, V any](options ...any) *OrderedMap[K, V] { //nolint:varnamelen + orderedMap := &OrderedMap[K, V]{} + + var config initConfig[K, V] + for _, untypedOption := range options { + switch option := untypedOption.(type) { + case int: + if len(options) != 1 { + invalidOption() + } + config.capacity = option + + case InitOption[K, V]: + option(&config) + + default: + invalidOption() + } + } + + orderedMap.initialize(config.capacity) + orderedMap.AddPairs(config.initialData...) + + return orderedMap +} + +const invalidOptionMessage = `when using orderedmap.New[K,V]() with options, either provide one or several InitOption[K, V]; or a single integer which is then interpreted as a capacity hint, à la make(map[K]V, capacity).` //nolint:lll + +func invalidOption() { panic(invalidOptionMessage) } + +func (om *OrderedMap[K, V]) initialize(capacity int) { + om.pairs = make(map[K]*Pair[K, V], capacity) + om.list = list.New[*Pair[K, V]]() +} + +// Get looks for the given key, and returns the value associated with it, +// or V's nil value if not found. The boolean it returns says whether the key is present in the map. +func (om *OrderedMap[K, V]) Get(key K) (val V, present bool) { + if pair, present := om.pairs[key]; present { + return pair.Value, true + } + + return +} + +// Load is an alias for Get, mostly to present an API similar to `sync.Map`'s. +func (om *OrderedMap[K, V]) Load(key K) (V, bool) { + return om.Get(key) +} + +// Value returns the value associated with the given key or the zero value. +func (om *OrderedMap[K, V]) Value(key K) (val V) { + if pair, present := om.pairs[key]; present { + val = pair.Value + } + return +} + +// GetPair looks for the given key, and returns the pair associated with it, +// or nil if not found. 
The Pair struct can then be used to iterate over the ordered map +// from that point, either forward or backward. +func (om *OrderedMap[K, V]) GetPair(key K) *Pair[K, V] { + return om.pairs[key] +} + +// Set sets the key-value pair, and returns what `Get` would have returned +// on that key prior to the call to `Set`. +func (om *OrderedMap[K, V]) Set(key K, value V) (val V, present bool) { + if pair, present := om.pairs[key]; present { + oldValue := pair.Value + pair.Value = value + return oldValue, true + } + + pair := &Pair[K, V]{ + Key: key, + Value: value, + } + pair.element = om.list.PushBack(pair) + om.pairs[key] = pair + + return +} + +// AddPairs allows setting multiple pairs at a time. It's equivalent to calling +// Set on each pair sequentially. +func (om *OrderedMap[K, V]) AddPairs(pairs ...Pair[K, V]) { + for _, pair := range pairs { + om.Set(pair.Key, pair.Value) + } +} + +// Store is an alias for Set, mostly to present an API similar to `sync.Map`'s. +func (om *OrderedMap[K, V]) Store(key K, value V) (V, bool) { + return om.Set(key, value) +} + +// Delete removes the key-value pair, and returns what `Get` would have returned +// on that key prior to the call to `Delete`. +func (om *OrderedMap[K, V]) Delete(key K) (val V, present bool) { + if pair, present := om.pairs[key]; present { + om.list.Remove(pair.element) + delete(om.pairs, key) + return pair.Value, true + } + return +} + +// Len returns the length of the ordered map. +func (om *OrderedMap[K, V]) Len() int { + if om == nil || om.pairs == nil { + return 0 + } + return len(om.pairs) +} + +// Oldest returns a pointer to the oldest pair. 
It's meant to be used to iterate on the ordered map's
+// pairs from the oldest to the newest, e.g.:
+// for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() { fmt.Printf("%v => %v\n", pair.Key, pair.Value) }
+func (om *OrderedMap[K, V]) Oldest() *Pair[K, V] {
+	if om == nil || om.list == nil {
+		return nil
+	}
+	return listElementToPair(om.list.Front())
+}
+
+// Newest returns a pointer to the newest pair. It's meant to be used to iterate on the ordered map's
+// pairs from the newest to the oldest, e.g.:
+// for pair := orderedMap.Newest(); pair != nil; pair = pair.Prev() { fmt.Printf("%v => %v\n", pair.Key, pair.Value) }
+func (om *OrderedMap[K, V]) Newest() *Pair[K, V] {
+	if om == nil || om.list == nil {
+		return nil
+	}
+	return listElementToPair(om.list.Back())
+}
+
+// Next returns a pointer to the next pair.
+func (p *Pair[K, V]) Next() *Pair[K, V] {
+	return listElementToPair(p.element.Next())
+}
+
+// Prev returns a pointer to the previous pair.
+func (p *Pair[K, V]) Prev() *Pair[K, V] {
+	return listElementToPair(p.element.Prev())
+}
+
+// listElementToPair unwraps a list element into its Pair, mapping a nil
+// element (off either end of the list) to a nil Pair.
+func listElementToPair[K comparable, V any](element *list.Element[*Pair[K, V]]) *Pair[K, V] {
+	if element == nil {
+		return nil
+	}
+	return element.Value
+}
+
+// KeyNotFoundError may be returned by functions in this package when they're called with keys that are not present
+// in the map.
+type KeyNotFoundError[K comparable] struct {
+	MissingKey K
+}
+
+func (e *KeyNotFoundError[K]) Error() string {
+	return fmt.Sprintf("missing key: %v", e.MissingKey)
+}
+
+// MoveAfter moves the value associated with key to its new position after the one associated with markKey.
+// Returns an error iff key or markKey are not present in the map. If an error is returned,
+// it will be a KeyNotFoundError. 
+func (om *OrderedMap[K, V]) MoveAfter(key, markKey K) error { + elements, err := om.getElements(key, markKey) + if err != nil { + return err + } + om.list.MoveAfter(elements[0], elements[1]) + return nil +} + +// MoveBefore moves the value associated with key to its new position before the one associated with markKey. +// Returns an error iff key or markKey are not present in the map. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) MoveBefore(key, markKey K) error { + elements, err := om.getElements(key, markKey) + if err != nil { + return err + } + om.list.MoveBefore(elements[0], elements[1]) + return nil +} + +func (om *OrderedMap[K, V]) getElements(keys ...K) ([]*list.Element[*Pair[K, V]], error) { + elements := make([]*list.Element[*Pair[K, V]], len(keys)) + for i, k := range keys { + pair, present := om.pairs[k] + if !present { + return nil, &KeyNotFoundError[K]{k} + } + elements[i] = pair.element + } + return elements, nil +} + +// MoveToBack moves the value associated with key to the back of the ordered map, +// i.e. makes it the newest pair in the map. +// Returns an error iff key is not present in the map. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) MoveToBack(key K) error { + _, err := om.GetAndMoveToBack(key) + return err +} + +// MoveToFront moves the value associated with key to the front of the ordered map, +// i.e. makes it the oldest pair in the map. +// Returns an error iff key is not present in the map. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) MoveToFront(key K) error { + _, err := om.GetAndMoveToFront(key) + return err +} + +// GetAndMoveToBack combines Get and MoveToBack in the same call. If an error is returned, +// it will be a KeyNotFoundError. 
+func (om *OrderedMap[K, V]) GetAndMoveToBack(key K) (val V, err error) { + if pair, present := om.pairs[key]; present { + val = pair.Value + om.list.MoveToBack(pair.element) + } else { + err = &KeyNotFoundError[K]{key} + } + + return +} + +// GetAndMoveToFront combines Get and MoveToFront in the same call. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) GetAndMoveToFront(key K) (val V, err error) { + if pair, present := om.pairs[key]; present { + val = pair.Value + om.list.MoveToFront(pair.element) + } else { + err = &KeyNotFoundError[K]{key} + } + + return +} diff --git a/pkg/omap/omap_test.go b/pkg/omap/omap_test.go new file mode 100644 index 0000000..4bebc02 --- /dev/null +++ b/pkg/omap/omap_test.go @@ -0,0 +1,385 @@ +package omap + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBasicFeatures(t *testing.T) { + n := 100 + om := New[int, int]() + + // set(i, 2 * i) + for i := 0; i < n; i++ { + assertLenEqual(t, om, i) + oldValue, present := om.Set(i, 2*i) + assertLenEqual(t, om, i+1) + + assert.Equal(t, 0, oldValue) + assert.False(t, present) + } + + // get what we just set + for i := 0; i < n; i++ { + value, present := om.Get(i) + + assert.Equal(t, 2*i, value) + assert.Equal(t, value, om.Value(i)) + assert.True(t, present) + } + + // get pairs of what we just set + for i := 0; i < n; i++ { + pair := om.GetPair(i) + + assert.NotNil(t, pair) + assert.Equal(t, 2*i, pair.Value) + } + + // forward iteration + i := 0 + for pair := om.Oldest(); pair != nil; pair = pair.Next() { + assert.Equal(t, i, pair.Key) + assert.Equal(t, 2*i, pair.Value) + i++ + } + // backward iteration + i = n - 1 + for pair := om.Newest(); pair != nil; pair = pair.Prev() { + assert.Equal(t, i, pair.Key) + assert.Equal(t, 2*i, pair.Value) + i-- + } + + // forward iteration starting from known key + i = 42 + for pair := om.GetPair(i); pair != nil; pair = pair.Next() { + assert.Equal(t, i, pair.Key) + assert.Equal(t, 
2*i, pair.Value) + i++ + } + + // double values for pairs with even keys + for j := 0; j < n/2; j++ { + i = 2 * j + oldValue, present := om.Set(i, 4*i) + + assert.Equal(t, 2*i, oldValue) + assert.True(t, present) + } + // and delete pairs with odd keys + for j := 0; j < n/2; j++ { + i = 2*j + 1 + assertLenEqual(t, om, n-j) + value, present := om.Delete(i) + assertLenEqual(t, om, n-j-1) + + assert.Equal(t, 2*i, value) + assert.True(t, present) + + // deleting again shouldn't change anything + value, present = om.Delete(i) + assertLenEqual(t, om, n-j-1) + assert.Equal(t, 0, value) + assert.False(t, present) + } + + // get the whole range + for j := 0; j < n/2; j++ { + i = 2 * j + value, present := om.Get(i) + assert.Equal(t, 4*i, value) + assert.Equal(t, value, om.Value(i)) + assert.True(t, present) + + i = 2*j + 1 + value, present = om.Get(i) + assert.Equal(t, 0, value) + assert.Equal(t, value, om.Value(i)) + assert.False(t, present) + } + + // check iterations again + i = 0 + for pair := om.Oldest(); pair != nil; pair = pair.Next() { + assert.Equal(t, i, pair.Key) + assert.Equal(t, 4*i, pair.Value) + i += 2 + } + i = 2 * ((n - 1) / 2) + for pair := om.Newest(); pair != nil; pair = pair.Prev() { + assert.Equal(t, i, pair.Key) + assert.Equal(t, 4*i, pair.Value) + i -= 2 + } +} + +func TestUpdatingDoesntChangePairsOrder(t *testing.T) { + om := New[string, any]() + om.Set("foo", "bar") + om.Set("wk", 28) + om.Set("po", 100) + om.Set("bar", "baz") + + oldValue, present := om.Set("po", 102) + assert.Equal(t, 100, oldValue) + assert.True(t, present) + + assertOrderedPairsEqual(t, om, + []string{"foo", "wk", "po", "bar"}, + []any{"bar", 28, 102, "baz"}) +} + +func TestDeletingAndReinsertingChangesPairsOrder(t *testing.T) { + om := New[string, any]() + om.Set("foo", "bar") + om.Set("wk", 28) + om.Set("po", 100) + om.Set("bar", "baz") + + // delete a pair + oldValue, present := om.Delete("po") + assert.Equal(t, 100, oldValue) + assert.True(t, present) + + // re-insert the 
same pair + oldValue, present = om.Set("po", 100) + assert.Nil(t, oldValue) + assert.False(t, present) + + assertOrderedPairsEqual(t, om, + []string{"foo", "wk", "bar", "po"}, + []any{"bar", 28, "baz", 100}) +} + +func TestEmptyMapOperations(t *testing.T) { + om := New[string, any]() + + oldValue, present := om.Get("foo") + assert.Nil(t, oldValue) + assert.Nil(t, om.Value("foo")) + assert.False(t, present) + + oldValue, present = om.Delete("bar") + assert.Nil(t, oldValue) + assert.False(t, present) + + assertLenEqual(t, om, 0) + + assert.Nil(t, om.Oldest()) + assert.Nil(t, om.Newest()) +} + +type dummyTestStruct struct { + value string +} + +func TestPackUnpackStructs(t *testing.T) { + om := New[string, dummyTestStruct]() + om.Set("foo", dummyTestStruct{"foo!"}) + om.Set("bar", dummyTestStruct{"bar!"}) + + value, present := om.Get("foo") + assert.True(t, present) + assert.Equal(t, value, om.Value("foo")) + if assert.NotNil(t, value) { + assert.Equal(t, "foo!", value.value) + } + + value, present = om.Set("bar", dummyTestStruct{"baz!"}) + assert.True(t, present) + if assert.NotNil(t, value) { + assert.Equal(t, "bar!", value.value) + } + + value, present = om.Get("bar") + assert.Equal(t, value, om.Value("bar")) + assert.True(t, present) + if assert.NotNil(t, value) { + assert.Equal(t, "baz!", value.value) + } +} + +// shamelessly stolen from https://github.com/python/cpython/blob/e19a91e45fd54a56e39c2d12e6aaf4757030507f/Lib/test/test_ordered_dict.py#L55-L61 +func TestShuffle(t *testing.T) { + ranLen := 100 + + for _, n := range []int{0, 10, 20, 100, 1000, 10000} { + t.Run(fmt.Sprintf("shuffle test with %d items", n), func(t *testing.T) { + om := New[string, string]() + + keys := make([]string, n) + values := make([]string, n) + + for i := 0; i < n; i++ { + // we prefix with the number to ensure that we don't get any duplicates + keys[i] = fmt.Sprintf("%d_%s", i, randomHexString(t, ranLen)) + values[i] = randomHexString(t, ranLen) + + value, present := om.Set(keys[i], 
values[i]) + assert.Equal(t, "", value) + assert.False(t, present) + } + + assertOrderedPairsEqual(t, om, keys, values) + }) + } +} + +func TestMove(t *testing.T) { + om := New[int, any]() + om.Set(1, "bar") + om.Set(2, 28) + om.Set(3, 100) + om.Set(4, "baz") + om.Set(5, "28") + om.Set(6, "100") + om.Set(7, "baz") + om.Set(8, "baz") + + err := om.MoveAfter(2, 3) + assert.Nil(t, err) + assertOrderedPairsEqual(t, om, + []int{1, 3, 2, 4, 5, 6, 7, 8}, + []any{"bar", 100, 28, "baz", "28", "100", "baz", "baz"}) + + err = om.MoveBefore(6, 4) + assert.Nil(t, err) + assertOrderedPairsEqual(t, om, + []int{1, 3, 2, 6, 4, 5, 7, 8}, + []any{"bar", 100, 28, "100", "baz", "28", "baz", "baz"}) + + err = om.MoveToBack(3) + assert.Nil(t, err) + assertOrderedPairsEqual(t, om, + []int{1, 2, 6, 4, 5, 7, 8, 3}, + []any{"bar", 28, "100", "baz", "28", "baz", "baz", 100}) + + err = om.MoveToFront(5) + assert.Nil(t, err) + assertOrderedPairsEqual(t, om, + []int{5, 1, 2, 6, 4, 7, 8, 3}, + []any{"28", "bar", 28, "100", "baz", "baz", "baz", 100}) + + err = om.MoveToFront(100) + assert.Equal(t, &KeyNotFoundError[int]{100}, err) +} + +func TestGetAndMove(t *testing.T) { + om := New[int, any]() + om.Set(1, "bar") + om.Set(2, 28) + om.Set(3, 100) + om.Set(4, "baz") + om.Set(5, "28") + om.Set(6, "100") + om.Set(7, "baz") + om.Set(8, "baz") + + value, err := om.GetAndMoveToBack(3) + assert.Nil(t, err) + assert.Equal(t, 100, value) + assertOrderedPairsEqual(t, om, + []int{1, 2, 4, 5, 6, 7, 8, 3}, + []any{"bar", 28, "baz", "28", "100", "baz", "baz", 100}) + + value, err = om.GetAndMoveToFront(5) + assert.Nil(t, err) + assert.Equal(t, "28", value) + assertOrderedPairsEqual(t, om, + []int{5, 1, 2, 4, 6, 7, 8, 3}, + []any{"28", "bar", 28, "baz", "100", "baz", "baz", 100}) + + value, err = om.GetAndMoveToBack(100) + assert.Equal(t, &KeyNotFoundError[int]{100}, err) + assert.Nil(t, value) +} + +func TestAddPairs(t *testing.T) { + om := New[int, any]() + om.AddPairs( + Pair[int, any]{ + Key: 28, + Value: 
"foo",
+		},
+		Pair[int, any]{
+			Key:   12,
+			Value: "bar",
+		},
+		Pair[int, any]{
+			Key:   28,
+			Value: "baz",
+		},
+	)
+
+	assertOrderedPairsEqual(t, om,
+		[]int{28, 12},
+		[]any{"baz", "bar"})
+}
+
+// sadly, we can't test the "actual" capacity here, see https://github.com/golang/go/issues/52157
+func TestNewWithCapacity(t *testing.T) {
+	zero := New[int, string](0)
+	assert.Empty(t, zero.Len())
+
+	assert.PanicsWithValue(t, invalidOptionMessage, func() {
+		_ = New[int, string](1, 2)
+	})
+	assert.PanicsWithValue(t, invalidOptionMessage, func() {
+		_ = New[int, string](1, 2, 3)
+	})
+
+	om := New[int, string](-1)
+	om.Set(1337, "quarante-deux")
+	assert.Equal(t, 1, om.Len())
+}
+
+func TestNewWithOptions(t *testing.T) {
+	t.Run("with capacity", func(t *testing.T) {
+		om := New[string, any](WithCapacity[string, any](98))
+		assert.Equal(t, 0, om.Len())
+	})
+
+	t.Run("with initial data", func(t *testing.T) {
+		om := New[string, int](WithInitialData(
+			Pair[string, int]{
+				Key:   "a",
+				Value: 1,
+			},
+			Pair[string, int]{
+				Key:   "b",
+				Value: 2,
+			},
+			Pair[string, int]{
+				Key:   "c",
+				Value: 3,
+			},
+		))
+
+		assertOrderedPairsEqual(t, om,
+			[]string{"a", "b", "c"},
+			[]int{1, 2, 3})
+	})
+
+	t.Run("with an invalid option type", func(t *testing.T) {
+		assert.PanicsWithValue(t, invalidOptionMessage, func() {
+			_ = New[int, string]("foo")
+		})
+	})
+}
+
+func TestNilMap(t *testing.T) {
+	// we want certain behaviors of a nil ordered map to be the same as they are for standard nil maps
+	var om *OrderedMap[int, any]
+
+	t.Run("len", func(t *testing.T) {
+		assert.Equal(t, 0, om.Len())
+	})
+
+	t.Run("iterating - akin to range", func(t *testing.T) {
+		assert.Nil(t, om.Oldest())
+		assert.Nil(t, om.Newest())
+	})
+}
diff --git a/pkg/omap/utils_test.go b/pkg/omap/utils_test.go
new file mode 100644
index 0000000..bf15175
--- /dev/null
+++ b/pkg/omap/utils_test.go
@@ -0,0 +1,76 @@
+package omap
+
+import (
+	"crypto/rand"
+	"encoding/hex"
+	"fmt"
+	"testing"
+
"github.com/stretchr/testify/assert" +) + +// assertOrderedPairsEqual asserts that the map contains the given keys and values +// from oldest to newest. +func assertOrderedPairsEqual[K comparable, V any]( + t *testing.T, orderedMap *OrderedMap[K, V], expectedKeys []K, expectedValues []V, +) { + t.Helper() + + assertOrderedPairsEqualFromNewest(t, orderedMap, expectedKeys, expectedValues) + assertOrderedPairsEqualFromOldest(t, orderedMap, expectedKeys, expectedValues) +} + +func assertOrderedPairsEqualFromNewest[K comparable, V any]( + t *testing.T, orderedMap *OrderedMap[K, V], expectedKeys []K, expectedValues []V, +) { + t.Helper() + + if assert.Equal(t, len(expectedKeys), len(expectedValues)) && assert.Equal(t, len(expectedKeys), orderedMap.Len()) { + i := orderedMap.Len() - 1 + for pair := orderedMap.Newest(); pair != nil; pair = pair.Prev() { + assert.Equal(t, expectedKeys[i], pair.Key, "from newest index=%d on key", i) + assert.Equal(t, expectedValues[i], pair.Value, "from newest index=%d on value", i) + i-- + } + } +} + +func assertOrderedPairsEqualFromOldest[K comparable, V any]( + t *testing.T, orderedMap *OrderedMap[K, V], expectedKeys []K, expectedValues []V, +) { + t.Helper() + + if assert.Equal(t, len(expectedKeys), len(expectedValues)) && assert.Equal(t, len(expectedKeys), orderedMap.Len()) { + i := 0 + for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() { + assert.Equal(t, expectedKeys[i], pair.Key, "from oldest index=%d on key", i) + assert.Equal(t, expectedValues[i], pair.Value, "from oldest index=%d on value", i) + i++ + } + } +} + +func assertLenEqual[K comparable, V any](t *testing.T, orderedMap *OrderedMap[K, V], expectedLen int) { + t.Helper() + + assert.Equal(t, expectedLen, orderedMap.Len()) + + // also check the list length, for good measure + assert.Equal(t, expectedLen, orderedMap.list.Len()) +} + +func randomHexString(t *testing.T, length int) string { + t.Helper() + + b := length / 2 //nolint:gomnd + randBytes := make([]byte, 
b)
+
+	if n, err := rand.Read(randBytes); err != nil || n != b {
+		if err == nil {
+			err = fmt.Errorf("only got %v random bytes, expected %v", n, b)
+		}
+		t.Fatal(err)
+	}
+
+	return hex.EncodeToString(randBytes)
+}
diff --git a/pkg/omap/wbuf.go b/pkg/omap/wbuf.go
new file mode 100644
index 0000000..d962405
--- /dev/null
+++ b/pkg/omap/wbuf.go
@@ -0,0 +1,276 @@
+package omap
+
+import (
+	"io"
+	"net"
+	"sync"
+)
+
+// PoolConfig contains configuration for the allocation and reuse strategy.
+type PoolConfig struct {
+	StartSize  int // Minimum chunk size that is allocated.
+	PooledSize int // Minimum chunk size that is reused, reusing chunks too small will result in overhead.
+	MaxSize    int // Maximum chunk size that will be allocated.
+}
+
+var config = PoolConfig{
+	StartSize:  128,
+	PooledSize: 512,
+	MaxSize:    32768,
+}
+
+// Reuse pool: chunk size -> pool.
+var buffers = map[int]*sync.Pool{}
+
+func initBuffers() {
+	for l := config.PooledSize; l <= config.MaxSize; l *= 2 {
+		buffers[l] = new(sync.Pool)
+	}
+}
+
+func init() {
+	initBuffers()
+}
+
+// Init sets up a non-default pooling and allocation strategy. Should be run before serialization is done.
+func Init(cfg PoolConfig) {
+	config = cfg
+	initBuffers()
+}
+
+// putBuf puts a chunk to reuse pool if it can be reused.
+func putBuf(buf []byte) {
+	size := cap(buf)
+	if size < config.PooledSize {
+		return
+	}
+	if c := buffers[size]; c != nil {
+		c.Put(buf[:0])
+	}
+}
+
+// getBuf gets a chunk from reuse pool or creates a new one if reuse failed.
+func getBuf(size int) []byte {
+	if size >= config.PooledSize {
+		if c := buffers[size]; c != nil {
+			v := c.Get()
+			if v != nil {
+				return v.([]byte)
+			}
+		}
+	}
+	return make([]byte, 0, size)
+}
+
+// Buffer is a buffer optimized for serialization without extra copying.
+type Buffer struct {
+
+	// Buf is the current chunk that can be used for serialization.
+ Buf []byte + + toPool []byte + bufs [][]byte +} + +// EnsureSpace makes sure that the current chunk contains at least s free bytes, +// possibly creating a new chunk. +func (b *Buffer) EnsureSpace(s int) { + if cap(b.Buf)-len(b.Buf) < s { + b.ensureSpaceSlow() + } +} + +func (b *Buffer) ensureSpaceSlow() { + l := len(b.Buf) + if l > 0 { + if cap(b.toPool) != cap(b.Buf) { + // Chunk was reallocated, toPool can be pooled. + putBuf(b.toPool) + } + if cap(b.bufs) == 0 { + b.bufs = make([][]byte, 0, 8) + } + b.bufs = append(b.bufs, b.Buf) + l = cap(b.toPool) * 2 + } else { + l = config.StartSize + } + + if l > config.MaxSize { + l = config.MaxSize + } + b.Buf = getBuf(l) + b.toPool = b.Buf +} + +// AppendByte appends a single byte to buffer. +func (b *Buffer) AppendByte(data byte) { + b.EnsureSpace(1) + b.Buf = append(b.Buf, data) +} + +// AppendBytes appends a byte slice to buffer. +func (b *Buffer) AppendBytes(data []byte) { + if len(data) <= cap(b.Buf)-len(b.Buf) { + b.Buf = append(b.Buf, data...) // fast path + } else { + b.appendBytesSlow(data) + } +} + +func (b *Buffer) appendBytesSlow(data []byte) { + for len(data) > 0 { + b.EnsureSpace(1) + + sz := cap(b.Buf) - len(b.Buf) + if sz > len(data) { + sz = len(data) + } + + b.Buf = append(b.Buf, data[:sz]...) + data = data[sz:] + } +} + +// AppendString appends a string to buffer. +func (b *Buffer) AppendString(data string) { + if len(data) <= cap(b.Buf)-len(b.Buf) { + b.Buf = append(b.Buf, data...) // fast path + } else { + b.appendStringSlow(data) + } +} + +func (b *Buffer) appendStringSlow(data string) { + for len(data) > 0 { + b.EnsureSpace(1) + + sz := cap(b.Buf) - len(b.Buf) + if sz > len(data) { + sz = len(data) + } + + b.Buf = append(b.Buf, data[:sz]...) + data = data[sz:] + } +} + +// Size computes the size of a buffer by adding sizes of every chunk. 
+func (b *Buffer) Size() int { + size := len(b.Buf) + for _, buf := range b.bufs { + size += len(buf) + } + return size +} + +// DumpTo outputs the contents of a buffer to a writer and resets the buffer. +func (b *Buffer) DumpTo(w io.Writer) (written int, err error) { + bufs := net.Buffers(b.bufs) + if len(b.Buf) > 0 { + bufs = append(bufs, b.Buf) + } + n, err := bufs.WriteTo(w) + + for _, buf := range b.bufs { + putBuf(buf) + } + putBuf(b.toPool) + + b.bufs = nil + b.Buf = nil + b.toPool = nil + + return int(n), err +} + +// BuildBytes creates a single byte slice with all the contents of the buffer. Data is +// copied if it does not fit in a single chunk. You can optionally provide one byte +// slice as argument that it will try to reuse. +func (b *Buffer) BuildBytes(reuse ...[]byte) []byte { + if len(b.bufs) == 0 { + ret := b.Buf + b.toPool = nil + b.Buf = nil + return ret + } + + var ret []byte + size := b.Size() + + // If we got a buffer as argument and it is big enough, reuse it. + if len(reuse) == 1 && cap(reuse[0]) >= size { + ret = reuse[0][:0] + } else { + ret = make([]byte, 0, size) + } + for _, buf := range b.bufs { + ret = append(ret, buf...) + putBuf(buf) + } + + ret = append(ret, b.Buf...) + putBuf(b.toPool) + + b.bufs = nil + b.toPool = nil + b.Buf = nil + + return ret +} + +type readCloser struct { + offset int + bufs [][]byte +} + +func (r *readCloser) Read(p []byte) (n int, err error) { + for _, buf := range r.bufs { + // Copy as much as we can. + x := copy(p[n:], buf[r.offset:]) + n += x // Increment how much we filled. + + // Did we empty the whole buffer? + if r.offset+x == len(buf) { + // On to the next buffer. + r.offset = 0 + r.bufs = r.bufs[1:] + + // We can release this buffer. + putBuf(buf) + } else { + r.offset += x + } + + if n == len(p) { + break + } + } + // No buffers left or nothing read? + if len(r.bufs) == 0 { + err = io.EOF + } + return +} + +func (r *readCloser) Close() error { + // Release all remaining buffers. 
+ for _, buf := range r.bufs { + putBuf(buf) + } + // In case Close gets called multiple times. + r.bufs = nil + + return nil +} + +// ReadCloser creates an io.ReadCloser with all the contents of the buffer. +func (b *Buffer) ReadCloser() io.ReadCloser { + ret := &readCloser{0, append(b.bufs, b.Buf)} + + b.bufs = nil + b.toPool = nil + b.Buf = nil + + return ret +} diff --git a/pkg/omap/wbuf_test.go b/pkg/omap/wbuf_test.go new file mode 100644 index 0000000..3fc1ce9 --- /dev/null +++ b/pkg/omap/wbuf_test.go @@ -0,0 +1,107 @@ +package omap + +import ( + "bytes" + "testing" +) + +func TestAppendByte(t *testing.T) { + var b Buffer + var want []byte + + for i := 0; i < 1000; i++ { + b.AppendByte(1) + b.AppendByte(2) + want = append(want, 1, 2) + } + + got := b.BuildBytes() + if !bytes.Equal(got, want) { + t.Errorf("BuildBytes() = %v; want %v", got, want) + } +} + +func TestAppendBytes(t *testing.T) { + var b Buffer + var want []byte + + for i := 0; i < 1000; i++ { + b.AppendBytes([]byte{1, 2}) + want = append(want, 1, 2) + } + + got := b.BuildBytes() + if !bytes.Equal(got, want) { + t.Errorf("BuildBytes() = %v; want %v", got, want) + } +} + +func TestAppendString(t *testing.T) { + var b Buffer + var want []byte + + s := "test" + for i := 0; i < 1000; i++ { + b.AppendString(s) + want = append(want, s...) + } + + got := b.BuildBytes() + if !bytes.Equal(got, want) { + t.Errorf("BuildBytes() = %v; want %v", got, want) + } +} + +func TestDumpTo(t *testing.T) { + var b Buffer + var want []byte + + s := "test" + for i := 0; i < 1000; i++ { + b.AppendBytes([]byte(s)) + want = append(want, s...) 
+ } + + out := &bytes.Buffer{} + n, err := b.DumpTo(out) + if err != nil { + t.Errorf("DumpTo() error: %v", err) + } + + got := out.Bytes() + if !bytes.Equal(got, want) { + t.Errorf("DumpTo(): got %v; want %v", got, want) + } + + if n != len(want) { + t.Errorf("DumpTo() = %v; want %v", n, len(want)) + } +} + +func TestReadCloser(t *testing.T) { + var b Buffer + var want []byte + + s := "test" + for i := 0; i < 1000; i++ { + b.AppendBytes([]byte(s)) + want = append(want, s...) + } + + out := &bytes.Buffer{} + rc := b.ReadCloser() + n, err := out.ReadFrom(rc) + if err != nil { + t.Errorf("ReadCloser() error: %v", err) + } + rc.Close() // Will always return nil + + got := out.Bytes() + if !bytes.Equal(got, want) { + t.Errorf("DumpTo(): got %v; want %v", got, want) + } + + if n != int64(len(want)) { + t.Errorf("DumpTo() = %v; want %v", n, len(want)) + } +} diff --git a/pkg/omap/writer.go b/pkg/omap/writer.go new file mode 100644 index 0000000..9240c4a --- /dev/null +++ b/pkg/omap/writer.go @@ -0,0 +1,430 @@ +package omap + +import ( + "io" + "strconv" + "unicode/utf8" +) + +// Flags describe various encoding options. The behavior may be actually implemented in the encoder, but +// Flags field in Writer is used to set and pass them around. +type Flags int + +const ( + // NilMapAsEmpty encodes nil map as '{}' rather than 'null'. + NilMapAsEmpty Flags = 1 << iota + // NilSliceAsEmpty encodes nil slice as '[]' rather than 'null'. + NilSliceAsEmpty +) + +// Writer is a JSON writer. +type Writer struct { + Flags Flags + + Error error + Buffer Buffer + NoEscapeHTML bool +} + +// Size returns the size of the data that was written out. +func (w *Writer) Size() int { + return w.Buffer.Size() +} + +// DumpTo outputs the data to given io.Writer, resetting the buffer. +func (w *Writer) DumpTo(out io.Writer) (written int, err error) { + return w.Buffer.DumpTo(out) +} + +// BuildBytes returns writer data as a single byte slice. 
You can optionally provide one byte slice +// as argument that it will try to reuse. +func (w *Writer) BuildBytes(reuse ...[]byte) ([]byte, error) { + if w.Error != nil { + return nil, w.Error + } + + return w.Buffer.BuildBytes(reuse...), nil +} + +// ReadCloser returns an io.ReadCloser that can be used to read the data. +// ReadCloser also resets the buffer. +func (w *Writer) ReadCloser() (io.ReadCloser, error) { + if w.Error != nil { + return nil, w.Error + } + + return w.Buffer.ReadCloser(), nil +} + +// RawByte appends raw binary data to the buffer. +func (w *Writer) RawByte(c byte) { + w.Buffer.AppendByte(c) +} + +// RawString appends raw binary data to the buffer. +func (w *Writer) RawString(s string) { + w.Buffer.AppendString(s) +} + +// Raw appends raw binary data to the buffer or sets the error if it is given. Useful for +// calling with results of MarshalJSON-like functions. +func (w *Writer) Raw(data []byte, err error) { + switch { + case w.Error != nil: + return + case err != nil: + w.Error = err + case len(data) > 0: + w.Buffer.AppendBytes(data) + default: + w.RawString("null") + } +} + +// RawText encloses raw binary data in quotes and appends in to the buffer. +// Useful for calling with results of MarshalText-like functions. +func (w *Writer) RawText(data []byte, err error) { + switch { + case w.Error != nil: + return + case err != nil: + w.Error = err + case len(data) > 0: + w.String(string(data)) + default: + w.RawString("null") + } +} + +// Base64Bytes appends data to the buffer after base64 encoding it +func (w *Writer) Base64Bytes(data []byte) { + if data == nil { + w.Buffer.AppendString("null") + return + } + w.Buffer.AppendByte('"') + w.base64(data) + w.Buffer.AppendByte('"') +} + +// Uint8 appends an uint8 to the buffer. +func (w *Writer) Uint8(n uint8) { + w.Buffer.EnsureSpace(3) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) +} + +// Uint16 appends an uint16 to the buffer. 
+func (w *Writer) Uint16(n uint16) { + w.Buffer.EnsureSpace(5) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) +} + +// Uint32 appends an uint32 to the buffer. +func (w *Writer) Uint32(n uint32) { + w.Buffer.EnsureSpace(10) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) +} + +// Uint appends an uint to the buffer. +func (w *Writer) Uint(n uint) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) +} + +// Uint64 appends an uint64 to the buffer. +func (w *Writer) Uint64(n uint64) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, n, 10) +} + +// Int8 appends an int8 to the buffer. +func (w *Writer) Int8(n int8) { + w.Buffer.EnsureSpace(4) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) +} + +// Int16 appends an int16 to the buffer. +func (w *Writer) Int16(n int16) { + w.Buffer.EnsureSpace(6) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) +} + +// Int32 appends an int32 to the buffer. +func (w *Writer) Int32(n int32) { + w.Buffer.EnsureSpace(11) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) +} + +// Int appends an int to the buffer. +func (w *Writer) Int(n int) { + w.Buffer.EnsureSpace(21) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) +} + +// Int64 appends an int64 to the buffer. +func (w *Writer) Int64(n int64) { + w.Buffer.EnsureSpace(21) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, n, 10) +} + +// Uint8Str appends an uint8 to the buffer as a quoted string. +func (w *Writer) Uint8Str(n uint8) { + w.Buffer.EnsureSpace(3) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +// Uint16Str appends an uint16 to the buffer as a quoted string. 
+func (w *Writer) Uint16Str(n uint16) { + w.Buffer.EnsureSpace(5) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +// Uint32Str appends an uint32 to the buffer as a quoted string. +func (w *Writer) Uint32Str(n uint32) { + w.Buffer.EnsureSpace(10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +// UintStr appends an uint to the buffer as a quoted string. +func (w *Writer) UintStr(n uint) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +// Uint64Str appends an uint64 to the buffer as a quoted string. +func (w *Writer) Uint64Str(n uint64) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, n, 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +// UintptrStr appends an uintptr to the buffer as a quoted string. +func (w *Writer) UintptrStr(n uintptr) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +// Int8Str appends an int8 to the buffer as a quoted string. +func (w *Writer) Int8Str(n int8) { + w.Buffer.EnsureSpace(4) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +// Int16Str appends an int16 to the buffer as a quoted string. +func (w *Writer) Int16Str(n int16) { + w.Buffer.EnsureSpace(6) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +// Int32Str appends an int32 to the buffer as a quoted string. 
+func (w *Writer) Int32Str(n int32) { + w.Buffer.EnsureSpace(11) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +// IntStr appends an int to the buffer as a quoted string. +func (w *Writer) IntStr(n int) { + w.Buffer.EnsureSpace(21) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +// Int64Str appends an int64 to the buffer as a quoted string. +func (w *Writer) Int64Str(n int64) { + w.Buffer.EnsureSpace(21) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, n, 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +// Float32 appends a float32 to the buffer. +func (w *Writer) Float32(n float32) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, float64(n), 'g', -1, 32) +} + +// Float32Str appends a float32 to the buffer as a quoted string. +func (w *Writer) Float32Str(n float32) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, float64(n), 'g', -1, 32) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +// Float64 appends a float64 to the buffer. +func (w *Writer) Float64(n float64) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, n, 'g', -1, 64) +} + +// Float64Str appends a float64 to the buffer as a quoted string. +func (w *Writer) Float64Str(n float64) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, float64(n), 'g', -1, 64) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +// Bool appends a bool to the buffer. +func (w *Writer) Bool(v bool) { + w.Buffer.EnsureSpace(5) + if v { + w.Buffer.Buf = append(w.Buffer.Buf, "true"...) + } else { + w.Buffer.Buf = append(w.Buffer.Buf, "false"...) 
+ } +} + +const chars = "0123456789abcdef" + +func getTable(falseValues ...int) [128]bool { + table := [128]bool{} + + for i := 0; i < 128; i++ { + table[i] = true + } + + for _, v := range falseValues { + table[v] = false + } + + return table +} + +var ( + htmlEscapeTable = getTable(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, '"', '&', '<', '>', '\\') + htmlNoEscapeTable = getTable(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, '"', '\\') +) + +func (w *Writer) String(s string) { + w.Buffer.AppendByte('"') + + // Portions of the string that contain no escapes are appended as + // byte slices. + + p := 0 // last non-escape symbol + + escapeTable := &htmlEscapeTable + if w.NoEscapeHTML { + escapeTable = &htmlNoEscapeTable + } + + for i := 0; i < len(s); { + c := s[i] + + if c < utf8.RuneSelf { + if escapeTable[c] { + // single-width character, no escaping is required + i++ + continue + } + + w.Buffer.AppendString(s[p:i]) + switch c { + case '\t': + w.Buffer.AppendString(`\t`) + case '\r': + w.Buffer.AppendString(`\r`) + case '\n': + w.Buffer.AppendString(`\n`) + case '\\': + w.Buffer.AppendString(`\\`) + case '"': + w.Buffer.AppendString(`\"`) + default: + w.Buffer.AppendString(`\u00`) + w.Buffer.AppendByte(chars[c>>4]) + w.Buffer.AppendByte(chars[c&0xf]) + } + + i++ + p = i + continue + } + + // broken utf + runeValue, runeWidth := utf8.DecodeRuneInString(s[i:]) + if runeValue == utf8.RuneError && runeWidth == 1 { + w.Buffer.AppendString(s[p:i]) + w.Buffer.AppendString(`\ufffd`) + i++ + p = i + continue + } + + // jsonp stuff - tab separator and line separator + if runeValue == '\u2028' || runeValue == '\u2029' { + w.Buffer.AppendString(s[p:i]) + w.Buffer.AppendString(`\u202`) + w.Buffer.AppendByte(chars[runeValue&0xf]) + i += runeWidth + p = i + continue + } + i += runeWidth + } + w.Buffer.AppendString(s[p:]) + 
w.Buffer.AppendByte('"') +} + +const encode = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" +const padChar = '=' + +func (w *Writer) base64(in []byte) { + + if len(in) == 0 { + return + } + + w.Buffer.EnsureSpace(((len(in)-1)/3 + 1) * 4) + + si := 0 + n := (len(in) / 3) * 3 + + for si < n { + // Convert 3x 8bit source bytes into 4 bytes + val := uint(in[si+0])<<16 | uint(in[si+1])<<8 | uint(in[si+2]) + + w.Buffer.Buf = append(w.Buffer.Buf, encode[val>>18&0x3F], encode[val>>12&0x3F], encode[val>>6&0x3F], encode[val&0x3F]) + + si += 3 + } + + remain := len(in) - si + if remain == 0 { + return + } + + // Add the remaining small block + val := uint(in[si+0]) << 16 + if remain == 2 { + val |= uint(in[si+1]) << 8 + } + + w.Buffer.Buf = append(w.Buffer.Buf, encode[val>>18&0x3F], encode[val>>12&0x3F]) + + switch remain { + case 2: + w.Buffer.Buf = append(w.Buffer.Buf, encode[val>>6&0x3F], byte(padChar)) + case 1: + w.Buffer.Buf = append(w.Buffer.Buf, byte(padChar), byte(padChar)) + } +} diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index 94ed0f7..545f6e6 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -12,7 +12,7 @@ import ( "strings" "time" - orderedmap "github.com/wk8/go-ordered-map/v2" + "github.com/conneroisu/groq-go/pkg/omap" ) const ( @@ -328,7 +328,7 @@ type ( // // Omitting this field has the same assertion behavior as an empty // object. - Properties *orderedmap.OrderedMap[string, *Schema] `json:"properties,omitempty"` + Properties *omap.OrderedMap[string, *Schema] `json:"properties,omitempty"` // PatternProperties are the pattern properties of the schema as specified in section 10.3.2.2 of RFC // draft-bhutton-json-schema-00. // @@ -1425,8 +1425,8 @@ func ToSnakeCase(str string) string { // newProperties is a helper method to instantiate a new properties ordered // map. 
-func newProperties() *orderedmap.OrderedMap[string, *Schema] { - return orderedmap.New[string, *Schema]() +func newProperties() *omap.OrderedMap[string, *Schema] { + return omap.New[string, *Schema]() } // Validate is used to check if the ID looks like a proper schema. diff --git a/scripts/generate-jigsaw-accents/go.mod b/scripts/generate-jigsaw-accents/go.mod index d9a6860..ce42a30 100644 --- a/scripts/generate-jigsaw-accents/go.mod +++ b/scripts/generate-jigsaw-accents/go.mod @@ -1,3 +1,5 @@ module github.com/conneroisu/groq-go/generate-jigsaw-accents go 1.23.2 + +require golang.org/x/text v0.18.0 diff --git a/scripts/generate-jigsaw-accents/go.sum b/scripts/generate-jigsaw-accents/go.sum new file mode 100644 index 0000000..94d17ef --- /dev/null +++ b/scripts/generate-jigsaw-accents/go.sum @@ -0,0 +1 @@ +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=