diff --git a/contract/hooks.go b/contract/hooks.go index 86431ba..62c9897 100644 --- a/contract/hooks.go +++ b/contract/hooks.go @@ -1,6 +1,10 @@ package contract -import "net/http" +import ( + "net/http" + "slices" + "sync" +) // hooksKey is the unexported type used as the context key // for storing and retrieving [Hooks] from a request context. @@ -26,28 +30,91 @@ type BeforeWriteHeaderHook = func(w http.ResponseWriter, status int) // slice that is about to be sent. type BeforeWriteHook = func(w http.ResponseWriter, content []byte) -// Hooks defines the contract for registering and retrieving -// lifecycle callbacks during HTTP request processing. Middleware -// and handlers use these hooks to observe response events. -type Hooks interface { - // AfterResponse registers one or more callbacks to be invoked - // after the HTTP response has been fully written. - AfterResponse(callbacks ...AfterResponseHook) +// Hooks provides lifecycle hook registration for the HTTP +// request/response cycle. Middleware and handlers can attach +// callbacks that fire before headers are written, before the +// body is written, and after the response completes. +// +// All methods are safe for concurrent use. +type Hooks struct { + mutex sync.Mutex + afterResponseHooks []AfterResponseHook + beforeWriteHeaderHooks []BeforeWriteHeaderHook + beforeWriteHooks []BeforeWriteHook +} + +// NewHooks creates a [Hooks] instance with empty callback slices +// ready to accept registrations via the Before* and After* methods. +func NewHooks() *Hooks { + return &Hooks{ + beforeWriteHeaderHooks: []BeforeWriteHeaderHook{}, + beforeWriteHooks: []BeforeWriteHook{}, + afterResponseHooks: []AfterResponseHook{}, + } +} + +// AfterResponse registers one or more callbacks to be invoked +// after the HTTP response has been fully written. +func (hooks *Hooks) AfterResponse(callbacks ...AfterResponseHook) { + hooks.mutex.Lock() + defer hooks.mutex.Unlock() + + hooks.afterResponseHooks = append(hooks.afterResponseHooks, callbacks...) +} + +// AfterResponseFuncs returns a reversed clone of the registered +// AfterResponse callbacks. The reversal ensures that the most +// recently registered callback executes first (LIFO order). +func (hooks *Hooks) AfterResponseFuncs() []AfterResponseHook { + hooks.mutex.Lock() + defer hooks.mutex.Unlock() + + clone := slices.Clone(hooks.afterResponseHooks) + slices.Reverse(clone) + + return clone +} + +// BeforeWrite registers one or more callbacks to be invoked +// just before response body bytes are written. +func (hooks *Hooks) BeforeWrite(callbacks ...BeforeWriteHook) { + hooks.mutex.Lock() + defer hooks.mutex.Unlock() + + hooks.beforeWriteHooks = append(hooks.beforeWriteHooks, callbacks...) +} + +// BeforeWriteFuncs returns a reversed clone of the registered +// BeforeWrite callbacks. The reversal ensures that the most +// recently registered callback executes first (LIFO order). +func (hooks *Hooks) BeforeWriteFuncs() []BeforeWriteHook { + hooks.mutex.Lock() + defer hooks.mutex.Unlock() - // AfterResponseFuncs returns all registered after-response callbacks. - AfterResponseFuncs() []AfterResponseHook + clone := slices.Clone(hooks.beforeWriteHooks) + slices.Reverse(clone) - // BeforeWrite registers one or more callbacks to be invoked - // just before response body bytes are written. - BeforeWrite(callbacks ...BeforeWriteHook) + return clone +} + +// BeforeWriteHeader registers one or more callbacks to be invoked +// just before the response status code is written. +func (hooks *Hooks) BeforeWriteHeader(callbacks ...BeforeWriteHeaderHook) { + hooks.mutex.Lock() + defer hooks.mutex.Unlock() + + hooks.beforeWriteHeaderHooks = append(hooks.beforeWriteHeaderHooks, callbacks...) +} - // BeforeWriteFuncs returns all registered before-write callbacks. - BeforeWriteFuncs() []BeforeWriteHook +// BeforeWriteHeaderFuncs returns a reversed clone of the registered +// BeforeWriteHeader callbacks. The reversal ensures that the most +// recently registered callback executes first (LIFO order). +func (hooks *Hooks) BeforeWriteHeaderFuncs() []BeforeWriteHeaderHook { + hooks.mutex.Lock() + defer hooks.mutex.Unlock() - // BeforeWriteHeader registers one or more callbacks to be invoked - // just before the response status code is written. - BeforeWriteHeader(callbacks ...BeforeWriteHeaderHook) + clone := slices.Clone(hooks.beforeWriteHeaderHooks) + slices.Reverse(clone) - // BeforeWriteHeaderFuncs returns all registered before-write-header callbacks. - BeforeWriteHeaderFuncs() []BeforeWriteHeaderHook + return clone } diff --git a/contract/hooks_test.go b/contract/hooks_test.go index 3408d95..d3ffde6 100644 --- a/contract/hooks_test.go +++ b/contract/hooks_test.go @@ -1,6 +1,9 @@ package contract_test import ( + "net/http" + "net/http/httptest" + "sync" "testing" "github.com/stretchr/testify/require" @@ -20,3 +23,143 @@ func TestHooksKeyIsDistinctType(t *testing.T) { require.NotEqual(t, other, contract.HooksKey) } + +func TestNewHooksReturnsNonNil(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + require.NotNil(t, hooks) +} + +func TestHooksAfterResponseRegisters(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var called bool + hooks.AfterResponse(func(err error) { called = true }) + + fns := hooks.AfterResponseFuncs() + + require.Len(t, fns, 1) + + fns[0](nil) + + require.True(t, called) +} + +func TestHooksAfterResponseFuncsReturnsLIFO(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var order []int + hooks.AfterResponse(func(err error) { order = append(order, 1) }) + hooks.AfterResponse(func(err error) { order = append(order, 2) }) + + for _, fn := range hooks.AfterResponseFuncs() { + fn(nil) + } + + require.Equal(t, []int{2, 1}, order) +} + +func TestHooksBeforeWriteRegisters(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var called bool + hooks.BeforeWrite(func(w http.ResponseWriter, content []byte) { called = true }) + + fns := hooks.BeforeWriteFuncs() + + require.Len(t, fns, 1) + + fns[0](httptest.NewRecorder(), nil) + + require.True(t, called) +} + +func TestHooksBeforeWriteFuncsReturnsLIFO(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var order []int + hooks.BeforeWrite(func(w http.ResponseWriter, content []byte) { order = append(order, 1) }) + hooks.BeforeWrite(func(w http.ResponseWriter, content []byte) { order = append(order, 2) }) + + for _, fn := range hooks.BeforeWriteFuncs() { + fn(httptest.NewRecorder(), nil) + } + + require.Equal(t, []int{2, 1}, order) +} + +func TestHooksBeforeWriteHeaderRegisters(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var called bool + hooks.BeforeWriteHeader(func(w http.ResponseWriter, status int) { called = true }) + + fns := hooks.BeforeWriteHeaderFuncs() + + require.Len(t, fns, 1) + + fns[0](httptest.NewRecorder(), 200) + + require.True(t, called) +} + +func TestHooksBeforeWriteHeaderFuncsReturnsLIFO(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var order []int + hooks.BeforeWriteHeader(func(w http.ResponseWriter, status int) { order = append(order, 1) }) + hooks.BeforeWriteHeader(func(w http.ResponseWriter, status int) { order = append(order, 2) }) + + for _, fn := range hooks.BeforeWriteHeaderFuncs() { + fn(httptest.NewRecorder(), 200) + } + + require.Equal(t, []int{2, 1}, order) +} + +func TestHooksEmptyFuncsReturnsEmptySlice(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + require.Empty(t, hooks.AfterResponseFuncs()) + require.Empty(t, hooks.BeforeWriteFuncs()) + require.Empty(t, hooks.BeforeWriteHeaderFuncs()) +} + +func TestHooksConcurrentAccess(t *testing.T) { + t.Parallel() + + hooks := contract.NewHooks() + + var wg sync.WaitGroup + + for range 100 { + wg.Add(1) + + go func() { + defer wg.Done() + + hooks.AfterResponse(func(err error) {}) + hooks.AfterResponseFuncs() + }() + } + + wg.Wait() + + require.Len(t, hooks.AfterResponseFuncs(), 100) +} diff --git a/contract/mock/hooks.go b/contract/mock/hooks.go deleted file mode 100644 index ed46244..0000000 --- a/contract/mock/hooks.go +++ /dev/null @@ -1,319 +0,0 @@ -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package mock - -import ( - mock "github.com/stretchr/testify/mock" - "github.com/studiolambda/cosmos/contract" -) - -// NewHooksMock creates a new instance of HooksMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewHooksMock(t interface { - mock.TestingT - Cleanup(func()) -}) *HooksMock { - mock := &HooksMock{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// HooksMock is an autogenerated mock type for the Hooks type -type HooksMock struct { - mock.Mock -} - -type HooksMock_Expecter struct { - mock *mock.Mock -} - -func (_m *HooksMock) EXPECT() *HooksMock_Expecter { - return &HooksMock_Expecter{mock: &_m.Mock} -} - -// AfterResponse provides a mock function for the type HooksMock -func (_mock *HooksMock) AfterResponse(callbacks ...contract.AfterResponseHook) { - if len(callbacks) > 0 { - _mock.Called(callbacks) - } else { - _mock.Called() - } - - return -} - -// HooksMock_AfterResponse_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AfterResponse' -type HooksMock_AfterResponse_Call struct { - *mock.Call -} - -// AfterResponse is a helper method to define mock.On call -// - callbacks ...contract.AfterResponseHook -func (_e *HooksMock_Expecter) AfterResponse(callbacks ...interface{}) *HooksMock_AfterResponse_Call { - return &HooksMock_AfterResponse_Call{Call: _e.mock.On("AfterResponse", - append([]interface{}{}, callbacks...)...)} -} - -func (_c *HooksMock_AfterResponse_Call) Run(run func(callbacks ...contract.AfterResponseHook)) *HooksMock_AfterResponse_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 []contract.AfterResponseHook - var variadicArgs []contract.AfterResponseHook - if len(args) > 0 { - variadicArgs = args[0].([]contract.AfterResponseHook) - } - arg0 = variadicArgs - run( - arg0..., - ) - }) - return _c -} - -func (_c *HooksMock_AfterResponse_Call) Return() *HooksMock_AfterResponse_Call { - _c.Call.Return() - return _c -} - -func (_c *HooksMock_AfterResponse_Call) RunAndReturn(run func(callbacks ...contract.AfterResponseHook)) *HooksMock_AfterResponse_Call { - _c.Run(run) - return _c -} - -// AfterResponseFuncs provides a mock function for the type HooksMock -func (_mock *HooksMock) AfterResponseFuncs() []contract.AfterResponseHook { - ret := _mock.Called() - - if len(ret) == 0 { - panic("no return value specified for AfterResponseFuncs") - } - - var r0 []contract.AfterResponseHook - if returnFunc, ok := ret.Get(0).(func() []contract.AfterResponseHook); ok { - r0 = returnFunc() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]contract.AfterResponseHook) - } - } - return r0 -} - -// HooksMock_AfterResponseFuncs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AfterResponseFuncs' -type HooksMock_AfterResponseFuncs_Call struct { - *mock.Call -} - -// AfterResponseFuncs is a helper method to define mock.On call -func (_e *HooksMock_Expecter) AfterResponseFuncs() *HooksMock_AfterResponseFuncs_Call { - return &HooksMock_AfterResponseFuncs_Call{Call: _e.mock.On("AfterResponseFuncs")} -} - -func (_c *HooksMock_AfterResponseFuncs_Call) Run(run func()) *HooksMock_AfterResponseFuncs_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *HooksMock_AfterResponseFuncs_Call) Return(vs []contract.AfterResponseHook) *HooksMock_AfterResponseFuncs_Call { - _c.Call.Return(vs) - return _c -} - -func (_c *HooksMock_AfterResponseFuncs_Call) RunAndReturn(run func() []contract.AfterResponseHook) *HooksMock_AfterResponseFuncs_Call { - _c.Call.Return(run) - return _c -} - -// BeforeWrite provides a mock function for the type HooksMock -func (_mock *HooksMock) BeforeWrite(callbacks ...contract.BeforeWriteHook) { - if len(callbacks) > 0 { - _mock.Called(callbacks) - } else { - _mock.Called() - } - - return -} - -// HooksMock_BeforeWrite_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BeforeWrite' -type HooksMock_BeforeWrite_Call struct { - *mock.Call -} - -// BeforeWrite is a helper method to define mock.On call -// - callbacks ...contract.BeforeWriteHook -func (_e *HooksMock_Expecter) BeforeWrite(callbacks ...interface{}) *HooksMock_BeforeWrite_Call { - return &HooksMock_BeforeWrite_Call{Call: _e.mock.On("BeforeWrite", - append([]interface{}{}, callbacks...)...)} -} - -func (_c *HooksMock_BeforeWrite_Call) Run(run func(callbacks ...contract.BeforeWriteHook)) *HooksMock_BeforeWrite_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 []contract.BeforeWriteHook - var variadicArgs []contract.BeforeWriteHook - if len(args) > 0 { - variadicArgs = args[0].([]contract.BeforeWriteHook) - } - arg0 = variadicArgs - run( - arg0..., - ) - }) - return _c -} - -func (_c *HooksMock_BeforeWrite_Call) Return() *HooksMock_BeforeWrite_Call { - _c.Call.Return() - return _c -} - -func (_c *HooksMock_BeforeWrite_Call) RunAndReturn(run func(callbacks ...contract.BeforeWriteHook)) *HooksMock_BeforeWrite_Call { - _c.Run(run) - return _c -} - -// BeforeWriteFuncs provides a mock function for the type HooksMock -func (_mock *HooksMock) BeforeWriteFuncs() []contract.BeforeWriteHook { - ret := _mock.Called() - - if len(ret) == 0 { - panic("no return value specified for BeforeWriteFuncs") - } - - var r0 []contract.BeforeWriteHook - if returnFunc, ok := ret.Get(0).(func() []contract.BeforeWriteHook); ok { - r0 = returnFunc() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]contract.BeforeWriteHook) - } - } - return r0 -} - -// HooksMock_BeforeWriteFuncs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BeforeWriteFuncs' -type HooksMock_BeforeWriteFuncs_Call struct { - *mock.Call -} - -// BeforeWriteFuncs is a helper method to define mock.On call -func (_e *HooksMock_Expecter) BeforeWriteFuncs() *HooksMock_BeforeWriteFuncs_Call { - return &HooksMock_BeforeWriteFuncs_Call{Call: _e.mock.On("BeforeWriteFuncs")} -} - -func (_c *HooksMock_BeforeWriteFuncs_Call) Run(run func()) *HooksMock_BeforeWriteFuncs_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *HooksMock_BeforeWriteFuncs_Call) Return(vs []contract.BeforeWriteHook) *HooksMock_BeforeWriteFuncs_Call { - _c.Call.Return(vs) - return _c -} - -func (_c *HooksMock_BeforeWriteFuncs_Call) RunAndReturn(run func() []contract.BeforeWriteHook) *HooksMock_BeforeWriteFuncs_Call { - _c.Call.Return(run) - return _c -} - -// BeforeWriteHeader provides a mock function for the type HooksMock -func (_mock *HooksMock) BeforeWriteHeader(callbacks ...contract.BeforeWriteHeaderHook) { - if len(callbacks) > 0 { - _mock.Called(callbacks) - } else { - _mock.Called() - } - - return -} - -// HooksMock_BeforeWriteHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BeforeWriteHeader' -type HooksMock_BeforeWriteHeader_Call struct { - *mock.Call -} - -// BeforeWriteHeader is a helper method to define mock.On call -// - callbacks ...contract.BeforeWriteHeaderHook -func (_e *HooksMock_Expecter) BeforeWriteHeader(callbacks ...interface{}) *HooksMock_BeforeWriteHeader_Call { - return &HooksMock_BeforeWriteHeader_Call{Call: _e.mock.On("BeforeWriteHeader", - append([]interface{}{}, callbacks...)...)} -} - -func (_c *HooksMock_BeforeWriteHeader_Call) Run(run func(callbacks ...contract.BeforeWriteHeaderHook)) *HooksMock_BeforeWriteHeader_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 []contract.BeforeWriteHeaderHook - var variadicArgs []contract.BeforeWriteHeaderHook - if len(args) > 0 { - variadicArgs = args[0].([]contract.BeforeWriteHeaderHook) - } - arg0 = variadicArgs - run( - arg0..., - ) - }) - return _c -} - -func (_c *HooksMock_BeforeWriteHeader_Call) Return() *HooksMock_BeforeWriteHeader_Call { - _c.Call.Return() - return _c -} - -func (_c *HooksMock_BeforeWriteHeader_Call) RunAndReturn(run func(callbacks ...contract.BeforeWriteHeaderHook)) *HooksMock_BeforeWriteHeader_Call { - _c.Run(run) - return _c -} - -// BeforeWriteHeaderFuncs provides a mock function for the type HooksMock -func (_mock *HooksMock) BeforeWriteHeaderFuncs() []contract.BeforeWriteHeaderHook { - ret := _mock.Called() - - if len(ret) == 0 { - panic("no return value specified for BeforeWriteHeaderFuncs") - } - - var r0 []contract.BeforeWriteHeaderHook - if returnFunc, ok := ret.Get(0).(func() []contract.BeforeWriteHeaderHook); ok { - r0 = returnFunc() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]contract.BeforeWriteHeaderHook) - } - } - return r0 -} - -// HooksMock_BeforeWriteHeaderFuncs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BeforeWriteHeaderFuncs' -type HooksMock_BeforeWriteHeaderFuncs_Call struct { - *mock.Call -} - -// BeforeWriteHeaderFuncs is a helper method to define mock.On call -func (_e *HooksMock_Expecter) BeforeWriteHeaderFuncs() *HooksMock_BeforeWriteHeaderFuncs_Call { - return &HooksMock_BeforeWriteHeaderFuncs_Call{Call: _e.mock.On("BeforeWriteHeaderFuncs")} -} - -func (_c *HooksMock_BeforeWriteHeaderFuncs_Call) Run(run func()) *HooksMock_BeforeWriteHeaderFuncs_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *HooksMock_BeforeWriteHeaderFuncs_Call) Return(vs []contract.BeforeWriteHeaderHook) *HooksMock_BeforeWriteHeaderFuncs_Call { - _c.Call.Return(vs) - return _c -} - -func (_c *HooksMock_BeforeWriteHeaderFuncs_Call) RunAndReturn(run func() []contract.BeforeWriteHeaderHook) *HooksMock_BeforeWriteHeaderFuncs_Call { - _c.Call.Return(run) - return _c -} diff --git a/contract/pagination.go b/contract/pagination.go new file mode 100644 index 0000000..931bd05 --- /dev/null +++ b/contract/pagination.go @@ -0,0 +1,159 @@ +package contract + +import ( + "encoding/base64" + "encoding/json" + "errors" +) + +// ErrCursorEncode is returned when a cursor value fails to encode. +var ErrCursorEncode = errors.New("failed to encode cursor") + +// ErrCursorDecode is returned when a cursor string fails to decode. +var ErrCursorDecode = errors.New("failed to decode cursor") + +// Page represents an offset-based paginated result set. +type Page[T any] struct { + Items []T `json:"items"` + Total int64 `json:"total"` + PerPage int `json:"per_page"` + CurrentPage int `json:"current_page"` + LastPage int `json:"last_page"` +} + +// Cursor represents a cursor-based paginated result set. +type Cursor[T any] struct { + Items []T `json:"items"` + PerPage int `json:"per_page"` + NextCursor string `json:"next_cursor,omitempty"` + PrevCursor string `json:"prev_cursor,omitempty"` +} + +// Paginate creates a new [Page] from the given items, total count, +// current page number, and items per page. It computes the last +// page automatically. The current page is clamped to [1, LastPage]. +// +// page, perPage := request.Pagination(r) +// +// var users []User +// db.Select(ctx, "SELECT * FROM users LIMIT $1 OFFSET $2", &users, perPage, (page-1)*perPage) +// +// var total int64 +// db.Find(ctx, "SELECT COUNT(*) FROM users", &total) +// +// result := contract.Paginate(users, total, page, perPage) +func Paginate[T any](items []T, total int64, page, perPage int) Page[T] { + perPage = max(perPage, 1) + lastPage := max(int((total+int64(perPage)-1)/int64(perPage)), 1) + page = min(max(page, 1), lastPage) + + if items == nil { + items = []T{} + } + + return Page[T]{ + Items: items, + Total: total, + PerPage: perPage, + CurrentPage: page, + LastPage: lastPage, + } +} + +// CursorPaginate creates a new [Cursor] from the given items. The encode +// function determines how each item is transformed into an opaque +// cursor string. When hasNext is true, the last item is encoded to +// produce the next cursor. When hasPrev is true, the first item is +// encoded to produce the previous cursor. +// +// cursor, perPage := request.CursorPagination(r) +// +// var startID int64 +// if cursor != "" { +// startID, _ = contract.CursorDecode[int64](cursor) +// } +// +// var items []FeedItem +// db.Select(ctx, "SELECT * FROM feed WHERE id > $1 ORDER BY id LIMIT $2", &items, startID, perPage+1) +// +// hasNext := len(items) > perPage +// if hasNext { +// items = items[:perPage] +// } +// +// result, err := contract.CursorPaginate(items, perPage, hasNext, cursor != "", func(item FeedItem) (string, error) { +// return contract.CursorEncode(item.ID) +// }) +func CursorPaginate[T any](items []T, perPage int, hasNext, hasPrev bool, encode func(T) (string, error)) (Cursor[T], error) { + if items == nil { + items = []T{} + } + + result := Cursor[T]{ + Items: items, + PerPage: perPage, + } + + if len(items) == 0 { + return result, nil + } + + if hasNext { + encoded, err := encode(items[len(items)-1]) + + if err != nil { + return result, errors.Join(ErrCursorEncode, err) + } + + result.NextCursor = encoded + } + + if hasPrev { + encoded, err := encode(items[0]) + + if err != nil { + return result, errors.Join(ErrCursorEncode, err) + } + + result.PrevCursor = encoded + } + + return result, nil +} + +// CursorEncode encodes a value into an opaque cursor string using +// JSON serialization and base64url encoding. Use this as the encoding +// helper inside the encode function passed to [CursorPaginate]. +// +// encoded, err := contract.CursorEncode(user.ID) +func CursorEncode[V any](value V) (string, error) { + data, err := json.Marshal(value) + + if err != nil { + return "", err + } + + return base64.RawURLEncoding.EncodeToString(data), nil +} + +// CursorDecode decodes an opaque cursor string back into a typed +// value. It reverses the encoding performed by [CursorEncode]. +// +// id, err := contract.CursorDecode[int64](cursorString) +func CursorDecode[V any](cursor string) (V, error) { + var value V + + data, err := base64.RawURLEncoding.DecodeString(cursor) + + if err != nil { + var zero V + return zero, errors.Join(ErrCursorDecode, err) + } + + if err := json.Unmarshal(data, &value); err != nil { + var zero V + return zero, errors.Join(ErrCursorDecode, err) + } + + return value, nil +} diff --git a/contract/pagination_example_test.go b/contract/pagination_example_test.go new file mode 100644 index 0000000..c3afe16 --- /dev/null +++ b/contract/pagination_example_test.go @@ -0,0 +1,81 @@ +package contract_test + +import ( + "fmt" + + "github.com/studiolambda/cosmos/contract" +) + +func ExamplePaginate() { + items := []string{"a", "b", "c"} + + page := contract.Paginate(items, 10, 2, 3) + + fmt.Println(page.CurrentPage) + fmt.Println(page.LastPage) + fmt.Println(page.PerPage) + fmt.Println(page.Total) + fmt.Println(page.Items) + // Output: + // 2 + // 4 + // 3 + // 10 + // [a b c] +} + +func ExampleCursorPaginate() { + type User struct { + ID int + Name string + } + + users := []User{{ID: 10, Name: "Alice"}, {ID: 20, Name: "Bob"}} + + cursor, err := contract.CursorPaginate(users, 2, true, false, func(u User) (string, error) { + return contract.CursorEncode(u.ID) + }) + + if err != nil { + panic(err) + } + + fmt.Println(cursor.PerPage) + fmt.Println(cursor.NextCursor != "") + fmt.Println(cursor.PrevCursor) + + id, err := contract.CursorDecode[int](cursor.NextCursor) + + if err != nil { + panic(err) + } + + fmt.Println(id) + // Output: + // 2 + // true + // + // 20 +} + +func ExampleCursorEncode() { + encoded, err := contract.CursorEncode(42) + + if err != nil { + panic(err) + } + + fmt.Println(encoded) + // Output: NDI +} + +func ExampleCursorDecode() { + value, err := contract.CursorDecode[int]("NDI") + + if err != nil { + panic(err) + } + + fmt.Println(value) + // Output: 42 +} diff --git a/contract/pagination_test.go b/contract/pagination_test.go new file mode 100644 index 0000000..1df13f7 --- /dev/null +++ b/contract/pagination_test.go @@ -0,0 +1,277 @@ +package contract_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + "github.com/studiolambda/cosmos/contract" +) + +func TestPaginateComputesLastPage(t *testing.T) { + t.Parallel() + + page := contract.Paginate([]string{"a", "b"}, 10, 1, 5) + + require.Equal(t, 2, page.LastPage) +} + +func TestPaginateComputesLastPageWithRemainder(t *testing.T) { + t.Parallel() + + page := contract.Paginate([]string{"a", "b"}, 11, 1, 5) + + require.Equal(t, 3, page.LastPage) +} + +func TestPaginateClampsPageBelowOne(t *testing.T) { + t.Parallel() + + page := contract.Paginate([]string{"a"}, 10, 0, 5) + + require.Equal(t, 1, page.CurrentPage) +} + +func TestPaginateClampsPageAboveLastPage(t *testing.T) { + t.Parallel() + + page := contract.Paginate([]string{}, 10, 99, 5) + + require.Equal(t, 2, page.CurrentPage) +} + +func TestPaginateClampsPerPageBelowOne(t *testing.T) { + t.Parallel() + + page := contract.Paginate([]string{"a"}, 5, 1, 0) + + require.Equal(t, 1, page.PerPage) +} + +func TestPaginateZeroTotalSetsLastPageOne(t *testing.T) { + t.Parallel() + + page := contract.Paginate([]string{}, 0, 1, 10) + + require.Equal(t, 1, page.LastPage) + require.Equal(t, 1, page.CurrentPage) +} + +func TestPaginateNilItemsBecomesEmptySlice(t *testing.T) { + t.Parallel() + + page := contract.Paginate[string](nil, 0, 1, 10) + + require.NotNil(t, page.Items) + require.Empty(t, page.Items) +} + +func TestPaginatePreservesItems(t *testing.T) { + t.Parallel() + + items := []int{1, 2, 3} + page := contract.Paginate(items, 100, 3, 10) + + require.Equal(t, items, page.Items) + require.Equal(t, int64(100), page.Total) + require.Equal(t, 3, page.CurrentPage) + require.Equal(t, 10, page.PerPage) +} + +func TestCursorPaginateEncodesNextCursor(t *testing.T) { + t.Parallel() + + items := []int{1, 2, 3} + cursor, err := contract.CursorPaginate(items, 3, true, false, func(item int) (string, error) { + return contract.CursorEncode(item) + }) + + require.NoError(t, err) + require.NotEmpty(t, cursor.NextCursor) + require.Empty(t, cursor.PrevCursor) +} + +func TestCursorPaginateEncodesPrevCursor(t *testing.T) { + t.Parallel() + + items := []int{1, 2, 3} + cursor, err := contract.CursorPaginate(items, 3, false, true, func(item int) (string, error) { + return contract.CursorEncode(item) + }) + + require.NoError(t, err) + require.Empty(t, cursor.NextCursor) + require.NotEmpty(t, cursor.PrevCursor) +} + +func TestCursorPaginateEncodesBothCursors(t *testing.T) { + t.Parallel() + + items := []int{1, 2, 3} + cursor, err := contract.CursorPaginate(items, 3, true, true, func(item int) (string, error) { + return contract.CursorEncode(item) + }) + + require.NoError(t, err) + require.NotEmpty(t, cursor.NextCursor) + require.NotEmpty(t, cursor.PrevCursor) +} + +func TestCursorPaginateEmptyItemsNoCursors(t *testing.T) { + t.Parallel() + + cursor, err := contract.CursorPaginate([]int{}, 10, true, true, func(item int) (string, error) { + return contract.CursorEncode(item) + }) + + require.NoError(t, err) + require.Empty(t, cursor.NextCursor) + require.Empty(t, cursor.PrevCursor) +} + +func TestCursorPaginateNilItemsBecomesEmptySlice(t *testing.T) { + t.Parallel() + + cursor, err := contract.CursorPaginate[int](nil, 10, false, false, func(item int) (string, error) { + return contract.CursorEncode(item) + }) + + require.NoError(t, err) + require.NotNil(t, cursor.Items) + require.Empty(t, cursor.Items) +} + +func TestCursorPaginatePreservesPerPage(t *testing.T) { + t.Parallel() + + cursor, err := contract.CursorPaginate([]int{1}, 25, false, false, func(item int) (string, error) { + return contract.CursorEncode(item) + }) + + require.NoError(t, err) + require.Equal(t, 25, cursor.PerPage) +} + +func TestCursorPaginateNextEncodeErrorReturnsErrCursorEncode(t *testing.T) { + t.Parallel() + + _, err := contract.CursorPaginate([]int{1}, 10, true, false, func(item int) (string, error) { + return "", errors.New("encode failed") + }) + + require.ErrorIs(t, err, contract.ErrCursorEncode) +} + +func TestCursorPaginatePrevEncodeErrorReturnsErrCursorEncode(t *testing.T) { + t.Parallel() + + _, err := contract.CursorPaginate([]int{1}, 10, false, true, func(item int) (string, error) { + return "", errors.New("encode failed") + }) + + require.ErrorIs(t, err, contract.ErrCursorEncode) +} + +func TestCursorPaginateCustomEncoder(t *testing.T) { + t.Parallel() + + items := []int{1, 2, 3} + cursor, err := contract.CursorPaginate(items, 3, true, false, func(item int) (string, error) { + return "custom-cursor", nil + }) + + require.NoError(t, err) + require.Equal(t, "custom-cursor", cursor.NextCursor) +} + +func TestCursorEncodeDecoderRoundTrip(t *testing.T) { + t.Parallel() + + encoded, err := contract.CursorEncode(42) + + require.NoError(t, err) + + value, err := contract.CursorDecode[int](encoded) + + require.NoError(t, err) + require.Equal(t, 42, value) +} + +func TestCursorEncodeStringRoundTrip(t *testing.T) { + t.Parallel() + + encoded, err := contract.CursorEncode("hello") + + require.NoError(t, err) + + value, err := contract.CursorDecode[string](encoded) + + require.NoError(t, err) + require.Equal(t, "hello", value) +} + +func TestCursorEncodeStructRoundTrip(t *testing.T) { + t.Parallel() + + type key struct { + ID int `json:"id"` + Name string `json:"name"` + } + + original := key{ID: 7, Name: "test"} + encoded, err := contract.CursorEncode(original) + + require.NoError(t, err) + + value, err := contract.CursorDecode[key](encoded) + + require.NoError(t, err) + require.Equal(t, original, value) +} + +func TestCursorEncodeUnencodableReturnsError(t *testing.T) { + t.Parallel() + + _, err := contract.CursorEncode(make(chan int)) + + require.Error(t, err) +} + +func TestCursorDecodeInvalidBase64ReturnsErrCursorDecode(t *testing.T) { + t.Parallel() + + _, err := contract.CursorDecode[int]("not-valid-base64!!!") + + require.ErrorIs(t, err, contract.ErrCursorDecode) +} + +func TestCursorDecodeInvalidJSONReturnsErrCursorDecode(t *testing.T) { + t.Parallel() + + // Valid base64 but invalid JSON. + _, err := contract.CursorDecode[int]("bm90LWpzb24") + + require.ErrorIs(t, err, contract.ErrCursorDecode) +} + +func TestCursorPaginateWithMarshalCursorEndToEnd(t *testing.T) { + t.Parallel() + + type User struct { + ID int + Name string + } + + users := []User{{ID: 10, Name: "Alice"}, {ID: 20, Name: "Bob"}} + + cursor, err := contract.CursorPaginate(users, 2, true, false, func(u User) (string, error) { + return contract.CursorEncode(u.ID) + }) + + require.NoError(t, err) + + id, err := contract.CursorDecode[int](cursor.NextCursor) + + require.NoError(t, err) + require.Equal(t, 20, id) +} diff --git a/contract/request/hooks.go b/contract/request/hooks.go index 7737acc..00cf465 100644 --- a/contract/request/hooks.go +++ b/contract/request/hooks.go @@ -22,8 +22,8 @@ var ErrNoHooksMiddleware = problem.Problem{ // WARNING: This function panics when hooks are missing. Use // [TryHooks] for a non-panicking alternative, or ensure the // [framework.Recover] middleware is in place. -func Hooks(r *http.Request) contract.Hooks { - if hooks, ok := r.Context().Value(contract.HooksKey).(contract.Hooks); ok { +func Hooks(r *http.Request) *contract.Hooks { + if hooks, ok := r.Context().Value(contract.HooksKey).(*contract.Hooks); ok { return hooks } @@ -34,8 +34,8 @@ func Hooks(r *http.Request) contract.Hooks { // context without panicking. The boolean return value indicates // whether hooks were found. This is the safe alternative to // [Hooks] for use outside the framework handler chain. -func TryHooks(r *http.Request) (contract.Hooks, bool) { - hooks, ok := r.Context().Value(contract.HooksKey).(contract.Hooks) +func TryHooks(r *http.Request) (*contract.Hooks, bool) { + hooks, ok := r.Context().Value(contract.HooksKey).(*contract.Hooks) return hooks, ok } diff --git a/contract/request/hooks_test.go b/contract/request/hooks_test.go index 1b6a57e..83d7022 100644 --- a/contract/request/hooks_test.go +++ b/contract/request/hooks_test.go @@ -11,24 +11,12 @@ import ( "github.com/studiolambda/cosmos/contract/request" ) -// stubHooks is a minimal implementation of contract.Hooks for testing. -type stubHooks struct{} - -func (stubHooks) AfterResponse(...contract.AfterResponseHook) {} -func (stubHooks) AfterResponseFuncs() []contract.AfterResponseHook { return nil } -func (stubHooks) BeforeWrite(...contract.BeforeWriteHook) {} -func (stubHooks) BeforeWriteFuncs() []contract.BeforeWriteHook { return nil } -func (stubHooks) BeforeWriteHeader(...contract.BeforeWriteHeaderHook) {} -func (stubHooks) BeforeWriteHeaderFuncs() []contract.BeforeWriteHeaderHook { - return nil -} - func TestHooksReturnsHooksFromContext(t *testing.T) { t.Parallel() - hooks := stubHooks{} + hooks := contract.NewHooks() ctx := context.WithValue( - context.Background(), contract.HooksKey, contract.Hooks(hooks), + context.Background(), contract.HooksKey, hooks, ) r := httptest.NewRequest( http.MethodGet, "/", nil, @@ -65,9 +53,9 @@ func TestHooksPanicsWithErrNoHooksMiddleware(t *testing.T) { func TestTryHooksReturnsTrueWhenPresent(t *testing.T) { t.Parallel() - hooks := stubHooks{} + hooks := contract.NewHooks() ctx := context.WithValue( - context.Background(), contract.HooksKey, contract.Hooks(hooks), + context.Background(), contract.HooksKey, hooks, ) r := httptest.NewRequest( http.MethodGet, "/", nil, diff --git a/contract/request/pagination.go b/contract/request/pagination.go new file mode 100644 index 0000000..083e8cd --- /dev/null +++ b/contract/request/pagination.go @@ -0,0 +1,41 @@ +package request + +import "net/http" + +// Pagination extracts the page number and per-page count from the +// request query parameters "page" and "per_page". It applies sensible +// defaults: page 1, 25 items per page, and a maximum of 100 items +// per page. Use [PaginationWith] for custom defaults and limits. +func Pagination(r *http.Request) (page, perPage int) { + return PaginationWith(r, 1, 25, 100) +} + +// PaginationWith extracts the page number and per-page count from the +// request query parameters "page" and "per_page" using the provided +// defaults and maximum per-page limit. The page is floored at 1 and +// the per-page is clamped between 1 and maxPerPage. +func PaginationWith(r *http.Request, defaultPage, defaultPerPage, maxPerPage int) (page, perPage int) { + page = max(QueryIntOr(r, "page", defaultPage), 1) + perPage = min(max(QueryIntOr(r, "per_page", defaultPerPage), 1), maxPerPage) + + return page, perPage +} + +// CursorPagination extracts the cursor string and per-page count from +// the request query parameters "cursor" and "per_page". It applies +// sensible defaults: 25 items per page and a maximum of 100 items per +// page. Use [CursorPaginationWith] for custom defaults and limits. +func CursorPagination(r *http.Request) (cursor string, perPage int) { + return CursorPaginationWith(r, 25, 100) +} + +// CursorPaginationWith extracts the cursor string and per-page count +// from the request query parameters "cursor" and "per_page" using the +// provided defaults and maximum per-page limit. The per-page is clamped +// between 1 and maxPerPage. +func CursorPaginationWith(r *http.Request, defaultPerPage, maxPerPage int) (cursor string, perPage int) { + cursor = Query(r, "cursor") + perPage = min(max(QueryIntOr(r, "per_page", defaultPerPage), 1), maxPerPage) + + return cursor, perPage +} diff --git a/contract/request/pagination_test.go b/contract/request/pagination_test.go new file mode 100644 index 0000000..6cb6e38 --- /dev/null +++ b/contract/request/pagination_test.go @@ -0,0 +1,195 @@ +package request_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/studiolambda/cosmos/contract/request" +) + +func TestPaginationReturnsDefaults(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/", nil) + + page, perPage := request.Pagination(r) + + require.Equal(t, 1, page) + require.Equal(t, 25, perPage) +} + +func TestPaginationParsesQueryParams(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?page=3&per_page=50", nil) + + page, perPage := request.Pagination(r) + + require.Equal(t, 3, page) + require.Equal(t, 50, perPage) +} + +func TestPaginationClampsPerPageToMax(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=999", nil) + + _, perPage := request.Pagination(r) + + require.Equal(t, 100, perPage) +} + +func TestPaginationClampsPageBelowOne(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?page=0", nil) + + page, _ := request.Pagination(r) + + require.Equal(t, 1, page) +} + +func TestPaginationClampsNegativePage(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?page=-5", nil) + + page, _ := request.Pagination(r) + + require.Equal(t, 1, page) +} + +func TestPaginationClampsPerPageBelowOne(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=0", nil) + + _, perPage := request.Pagination(r) + + require.Equal(t, 1, perPage) +} + +func TestPaginationWithCustomDefaults(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/", nil) + + page, perPage := request.PaginationWith(r, 2, 10, 50) + + require.Equal(t, 2, page) + require.Equal(t, 10, perPage) +} + +func TestPaginationWithCustomMax(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=100", nil) + + _, perPage := request.PaginationWith(r, 1, 10, 50) + + require.Equal(t, 50, perPage) +} + +func TestPaginationIgnoresInvalidPage(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?page=abc", nil) + + page, _ := request.Pagination(r) + + require.Equal(t, 1, page) +} + +func TestPaginationIgnoresInvalidPerPage(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=abc", nil) + + _, perPage := request.Pagination(r) + + require.Equal(t, 25, perPage) +} + +func TestCursorPaginationReturnsDefaults(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/", nil) + + cursor, perPage := request.CursorPagination(r) + + require.Empty(t, cursor) + require.Equal(t, 25, perPage) +} + +func TestCursorPaginationParsesCursor(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?cursor=abc123&per_page=50", nil) + + cursor, perPage := request.CursorPagination(r) + + require.Equal(t, "abc123", cursor) + require.Equal(t, 50, perPage) +} + +func TestCursorPaginationClampsPerPageToMax(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=999", nil) + + _, perPage := request.CursorPagination(r) + + require.Equal(t, 100, perPage) +} + +func TestCursorPaginationClampsPerPageBelowOne(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=0", nil) + + _, perPage := request.CursorPagination(r) + + require.Equal(t, 1, perPage) +} + +func TestCursorPaginationWithCustomDefaults(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/", nil) + + _, perPage := request.CursorPaginationWith(r, 10, 50) + + require.Equal(t, 10, perPage) +} + +func TestCursorPaginationWithCustomMax(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=100", nil) + + _, perPage := request.CursorPaginationWith(r, 10, 50) + + require.Equal(t, 50, perPage) +} + +func TestCursorPaginationIgnoresInvalidPerPage(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=abc", nil) + + _, perPage := request.CursorPagination(r) + + require.Equal(t, 25, perPage) +} + +func TestCursorPaginationNegativePerPage(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest(http.MethodGet, "/?per_page=-5", nil) + + _, perPage := request.CursorPagination(r) + + require.Equal(t, 1, perPage) +} diff --git a/contract/response/static_test.go b/contract/response/static_test.go index 9ecd2cb..7df5ecb 100644 --- a/contract/response/static_test.go +++ b/contract/response/static_test.go @@ -438,16 +438,6 @@ func TestSafeRedirectRejectsUnparseableURL(t *testing.T) { require.ErrorIs(t, err, response.ErrUnsafeRedirect) } -func TestErrUnsafeRedirectMessage(t *testing.T) { - t.Parallel() - - require.Equal( - t, - "unsafe redirect URL: must be a relative path", - response.ErrUnsafeRedirect.Error(), - ) -} - func TestStringTemplateBuffersBeforeWritingStatus(t *testing.T) { t.Parallel() diff --git a/contract/session.go b/contract/session.go index 337a9c4..2181241 100644 --- a/contract/session.go +++ b/contract/session.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "encoding/base64" + "maps" "sync" "time" ) @@ -106,9 +107,7 @@ func (session *Session) All() map[string]any { result := make(map[string]any, len(session.storage)) - for k, v := range session.storage { - result[k] = v - } + maps.Copy(result, session.storage) return result } diff --git a/framework/handler.go b/framework/handler.go index db2007f..824b499 100644 --- a/framework/handler.go +++ b/framework/handler.go @@ -101,7 +101,7 @@ func (handler Handler) ServeHTTP( w http.ResponseWriter, r *http.Request, ) { - hooks := NewHooks() + hooks := contract.NewHooks() wrapped := NewResponseWriter(w, hooks) ctx := context.WithValue(r.Context(), contract.HooksKey, hooks) err := handler(wrapped, r.WithContext(ctx)) diff --git a/framework/handler_test.go b/framework/handler_test.go index a8a79fe..d1a2232 100644 --- a/framework/handler_test.go +++ b/framework/handler_test.go @@ -159,7 +159,7 @@ func TestServeHTTPAfterResponseHooksRun(t *testing.T) { var receivedErr atomic.Value h := framework.Handler(func(w http.ResponseWriter, r *http.Request) error { - hooks := r.Context().Value(contract.HooksKey).(contract.Hooks) + hooks := r.Context().Value(contract.HooksKey).(*contract.Hooks) hooks.AfterResponse(func(err error) { hookCalled.Store(true) receivedErr.Store(err) @@ -180,7 +180,7 @@ func TestServeHTTPAfterResponseHookPanicRecovered(t *testing.T) { t.Parallel() h := framework.Handler(func(w http.ResponseWriter, r *http.Request) error { - hooks := r.Context().Value(contract.HooksKey).(contract.Hooks) + hooks := r.Context().Value(contract.HooksKey).(*contract.Hooks) hooks.AfterResponse(func(err error) { panic("hook panic") }) @@ -203,7 +203,7 @@ func TestServeHTTPHooksInContext(t *testing.T) { var foundHooks atomic.Bool h := framework.Handler(func(w http.ResponseWriter, r *http.Request) error { - hooks, ok := r.Context().Value(contract.HooksKey).(contract.Hooks) + hooks, ok := r.Context().Value(contract.HooksKey).(*contract.Hooks) foundHooks.Store(ok && hooks != nil) return nil @@ -223,7 +223,7 @@ func TestServeHTTPAfterResponseHookReceivesNilOnSuccess(t *testing.T) { errChan := make(chan error, 1) h := framework.Handler(func(w http.ResponseWriter, r *http.Request) error { - hooks := r.Context().Value(contract.HooksKey).(contract.Hooks) + hooks := r.Context().Value(contract.HooksKey).(*contract.Hooks) hooks.AfterResponse(func(err error) { hookCalled.Store(true) errChan <- err diff --git a/framework/hooks.go b/framework/hooks.go deleted file mode 100644 index 532524e..0000000 --- a/framework/hooks.go +++ /dev/null @@ -1,102 +0,0 @@ -package framework - -import ( - "slices" - "sync" - - "github.com/studiolambda/cosmos/contract" -) - -// Hooks provides lifecycle hook registration for the HTTP -// request/response cycle. Middleware and handlers can attach -// callbacks that fire before headers are written, before the -// body is written, and after the response completes. -// -// All methods are safe for concurrent use. -type Hooks struct { - // mutex guards all hook slices. - mutex sync.Mutex - afterResponseHooks []contract.AfterResponseHook - beforeWriteHeaderHooks []contract.BeforeWriteHeaderHook - beforeWriteHooks []contract.BeforeWriteHook -} - -// NewHooks creates a Hooks instance with empty callback slices -// ready to accept registrations via the Before* and After* methods. -func NewHooks() *Hooks { - return &Hooks{ - beforeWriteHeaderHooks: []contract.BeforeWriteHeaderHook{}, - beforeWriteHooks: []contract.BeforeWriteHook{}, - afterResponseHooks: []contract.AfterResponseHook{}, - } -} - -// BeforeWriteHeader registers one or more callbacks that will be -// invoked just before the response status code is written. This -// is the last opportunity to inspect or modify headers. -func (hooks *Hooks) BeforeWriteHeader(callbacks ...contract.BeforeWriteHeaderHook) { - hooks.mutex.Lock() - defer hooks.mutex.Unlock() - - hooks.beforeWriteHeaderHooks = append(hooks.beforeWriteHeaderHooks, callbacks...) -} - -// BeforeWriteHeaderFuncs returns a reversed clone of the registered -// BeforeWriteHeader callbacks. The reversal ensures that the most -// recently registered callback executes first (LIFO order). -func (hooks *Hooks) BeforeWriteHeaderFuncs() []contract.BeforeWriteHeaderHook { - hooks.mutex.Lock() - defer hooks.mutex.Unlock() - - clone := slices.Clone(hooks.beforeWriteHeaderHooks) - slices.Reverse(clone) - - return clone -} - -// BeforeWrite registers one or more callbacks that will be -// invoked just before the response body bytes are written. -// This is useful for logging, metrics, or content transformation. -func (hooks *Hooks) BeforeWrite(callbacks ...contract.BeforeWriteHook) { - hooks.mutex.Lock() - defer hooks.mutex.Unlock() - - hooks.beforeWriteHooks = append(hooks.beforeWriteHooks, callbacks...) -} - -// BeforeWriteFuncs returns a reversed clone of the registered -// BeforeWrite callbacks. The reversal ensures that the most -// recently registered callback executes first (LIFO order). -func (hooks *Hooks) BeforeWriteFuncs() []contract.BeforeWriteHook { - hooks.mutex.Lock() - defer hooks.mutex.Unlock() - - clone := slices.Clone(hooks.beforeWriteHooks) - slices.Reverse(clone) - - return clone -} - -// AfterResponse registers one or more callbacks that will be -// invoked after the handler has completed and all response data -// has been written. The callback receives the handler's error -// (or nil if the handler succeeded). -func (hooks *Hooks) AfterResponse(callbacks ...contract.AfterResponseHook) { - hooks.mutex.Lock() - defer hooks.mutex.Unlock() - - hooks.afterResponseHooks = append(hooks.afterResponseHooks, callbacks...) -} - -// AfterResponseFuncs returns a reversed clone of the registered -// AfterResponse callbacks. The reversal ensures that the most -// recently registered callback executes first (LIFO order). -func (hooks *Hooks) AfterResponseFuncs() []contract.AfterResponseHook { - hooks.mutex.Lock() - defer hooks.mutex.Unlock() - - clone := slices.Clone(hooks.afterResponseHooks) - slices.Reverse(clone) - - return clone -} diff --git a/framework/hooks_test.go b/framework/hooks_test.go deleted file mode 100644 index 108badf..0000000 --- a/framework/hooks_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package framework_test - -import ( - "net/http" - "testing" - - "github.com/studiolambda/cosmos/contract" - "github.com/studiolambda/cosmos/framework" - - "github.com/stretchr/testify/require" -) - -func TestNewHooksEmpty(t *testing.T) { - t.Parallel() - - hooks := framework.NewHooks() - - require.Empty(t, hooks.BeforeWriteHeaderFuncs()) - require.Empty(t, hooks.BeforeWriteFuncs()) - require.Empty(t, hooks.AfterResponseFuncs()) -} - -func TestBeforeWriteHeaderRegistersAndReturnsReversed(t *testing.T) { - t.Parallel() - - hooks := framework.NewHooks() - var order []int - - first := contract.BeforeWriteHeaderHook( - func(w http.ResponseWriter, status int) { order = append(order, 1) }, - ) - - second := contract.BeforeWriteHeaderHook( - func(w http.ResponseWriter, status int) { order = append(order, 2) }, - ) - - hooks.BeforeWriteHeader(first, second) - - funcs := hooks.BeforeWriteHeaderFuncs() - - require.Len(t, funcs, 2) - - for _, fn := range funcs { - fn(nil, 0) - } - - // Reversed: second fires first. - require.Equal(t, []int{2, 1}, order) -} - -func TestBeforeWriteRegistersAndReturnsReversed(t *testing.T) { - t.Parallel() - - hooks := framework.NewHooks() - var order []int - - first := contract.BeforeWriteHook( - func(w http.ResponseWriter, content []byte) { order = append(order, 1) }, - ) - - second := contract.BeforeWriteHook( - func(w http.ResponseWriter, content []byte) { order = append(order, 2) }, - ) - - hooks.BeforeWrite(first, second) - - funcs := hooks.BeforeWriteFuncs() - - require.Len(t, funcs, 2) - - for _, fn := range funcs { - fn(nil, nil) - } - - require.Equal(t, []int{2, 1}, order) -} - -func TestAfterResponseRegistersAndReturnsReversed(t *testing.T) { - t.Parallel() - - hooks := framework.NewHooks() - var order []int - - first := contract.AfterResponseHook( - func(err error) { order = append(order, 1) }, - ) - - second := contract.AfterResponseHook( - func(err error) { order = append(order, 2) }, - ) - - hooks.AfterResponse(first, second) - - funcs := hooks.AfterResponseFuncs() - - require.Len(t, funcs, 2) - - for _, fn := range funcs { - fn(nil) - } - - require.Equal(t, []int{2, 1}, order) -} - -func TestHooksMultipleRegistrations(t *testing.T) { - t.Parallel() - - hooks := framework.NewHooks() - var order []int - - hooks.AfterResponse(func(err error) { order = append(order, 1) }) - hooks.AfterResponse(func(err error) { order = append(order, 2) }) - hooks.AfterResponse(func(err error) { order = append(order, 3) }) - - funcs := hooks.AfterResponseFuncs() - - require.Len(t, funcs, 3) - - for _, fn := range funcs { - fn(nil) - } - - // LIFO: 3, 2, 1. - require.Equal(t, []int{3, 2, 1}, order) -} - -func TestHooksFuncsReturnClone(t *testing.T) { - t.Parallel() - - hooks := framework.NewHooks() - - hooks.AfterResponse(func(err error) {}) - - funcs := hooks.AfterResponseFuncs() - funcs[0] = nil - - // Original should be unaffected by mutation of the returned slice. - require.NotNil(t, hooks.AfterResponseFuncs()[0]) -} diff --git a/framework/hooks_writer.go b/framework/hooks_writer.go index 71cbde0..1594d4c 100644 --- a/framework/hooks_writer.go +++ b/framework/hooks_writer.go @@ -4,6 +4,8 @@ import ( "log/slog" "net/http" "sync/atomic" + + "github.com/studiolambda/cosmos/contract" ) // ResponseWriter wraps an http.ResponseWriter to intercept @@ -14,7 +16,7 @@ import ( // sync/atomic for safe concurrent access. type ResponseWriter struct { http.ResponseWriter - *Hooks + *contract.Hooks writeHeaderCalled atomic.Bool } @@ -40,7 +42,7 @@ type WrappedResponseWriter interface { // the given hooks on write operations. If the underlying writer // implements http.Flusher, the returned value also satisfies // http.Flusher via ResponseWriterFlusher. -func NewResponseWriter(writer http.ResponseWriter, hooks *Hooks) WrappedResponseWriter { +func NewResponseWriter(writer http.ResponseWriter, hooks *contract.Hooks) WrappedResponseWriter { wrapped := &ResponseWriter{ ResponseWriter: writer, Hooks: hooks, diff --git a/framework/hooks_writer_test.go b/framework/hooks_writer_test.go index e3ba629..e68fcda 100644 --- a/framework/hooks_writer_test.go +++ b/framework/hooks_writer_test.go @@ -6,6 +6,7 @@ import ( "sync/atomic" "testing" + "github.com/studiolambda/cosmos/contract" "github.com/studiolambda/cosmos/framework" "github.com/stretchr/testify/require" @@ -27,7 +28,7 @@ func (writer *flusherWriter) Flush() { func TestNewResponseWriterNonFlusher(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(&plainWriter{rec}, hooks) @@ -39,7 +40,7 @@ func TestNewResponseWriterNonFlusher(t *testing.T) { func TestNewResponseWriterFlusher(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter( &flusherWriter{ResponseWriter: rec}, @@ -56,7 +57,7 @@ func TestNewResponseWriterFlusher(t *testing.T) { func TestWriteHeaderCalledInitiallyFalse(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -66,7 +67,7 @@ func TestWriteHeaderCalledInitiallyFalse(t *testing.T) { func TestWriteHeaderSetsCalledFlag(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -78,7 +79,7 @@ func TestWriteHeaderSetsCalledFlag(t *testing.T) { func TestWriteHeaderFiresBeforeWriteHeaderHooks(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -99,7 +100,7 @@ func TestWriteHeaderFiresBeforeWriteHeaderHooks(t *testing.T) { func TestWriteHeaderSecondCallIsNoop(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -121,7 +122,7 @@ func TestWriteHeaderSecondCallIsNoop(t *testing.T) { func TestWriteFiresBeforeWriteHooks(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -145,7 +146,7 @@ func TestWriteFiresBeforeWriteHooks(t *testing.T) { func TestWriteAutoCallsWriteHeaderWith200(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -159,7 +160,7 @@ func TestWriteAutoCallsWriteHeaderWith200(t *testing.T) { func TestWriteAfterWriteHeaderDoesNotCallAgain(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -183,7 +184,7 @@ func TestWriteAfterWriteHeaderDoesNotCallAgain(t *testing.T) { func TestBeforeWriteHeaderHookPanicIsRecovered(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -202,7 +203,7 @@ func TestBeforeWriteHeaderHookPanicIsRecovered(t *testing.T) { func TestBeforeWriteHookPanicIsRecovered(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks) @@ -226,7 +227,7 @@ func TestBeforeWriteHookPanicIsRecovered(t *testing.T) { func TestResponseWriterUnwrapReturnsUnderlying(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() plain := &plainWriter{rec} wrapped := framework.NewResponseWriter(plain, hooks) @@ -244,7 +245,7 @@ func TestResponseWriterUnwrapReturnsUnderlying(t *testing.T) { func TestResponseControllerFlushThroughWrappedWriter(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() fw := &flusherWriter{ResponseWriter: rec} wrapped := framework.NewResponseWriter(fw, hooks) @@ -259,7 +260,7 @@ func TestResponseControllerFlushThroughWrappedWriter(t *testing.T) { func TestResponseWriterFlusherUnwrapReturnsUnderlying(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() fw := &flusherWriter{ResponseWriter: rec} wrapped := framework.NewResponseWriter(fw, hooks) @@ -277,7 +278,7 @@ func TestResponseWriterFlusherUnwrapReturnsUnderlying(t *testing.T) { func TestWriteHeaderHookReceivesUnderlyingWriter(t *testing.T) { t.Parallel() - hooks := framework.NewHooks() + hooks := contract.NewHooks() rec := httptest.NewRecorder() wrapped := framework.NewResponseWriter(rec, hooks)