Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 61 additions & 25 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/riverqueue/river/internal/middlewarelookup"
"github.com/riverqueue/river/internal/notifier"
"github.com/riverqueue/river/internal/notifylimiter"
"github.com/riverqueue/river/internal/pluginconfig"
"github.com/riverqueue/river/internal/rivercommon"
"github.com/riverqueue/river/internal/rivermiddleware"
"github.com/riverqueue/river/internal/workunit"
Expand Down Expand Up @@ -219,6 +220,9 @@ type Config struct {
// work hook runs and the insertion hooks on either side of it are skipped.
//
// Jobs may have their own specific hooks by implementing JobArgsWithHooks.
//
// If a type in Hooks also implements rivertype.Middleware, it will be
// installed as middleware too.
Hooks []rivertype.Hook

// Logger is the structured logger to use for logging purposes. If none is
Expand Down Expand Up @@ -252,8 +256,25 @@ type Config struct {
// middlewares will run one after another, and the work middleware between
// them will not run. When a job is worked, the work middleware runs and the
// insertion middlewares on either side of it are skipped.
//
// If a type in Middleware also implements rivertype.Hook, it will be
// installed as a hook too.
Middleware []rivertype.Middleware

// Plugins contains extensions installed globally as both hooks and
// middleware.
//
// A plugin must implement both rivertype.Hook and rivertype.Middleware. It
// may embed PluginDefaults, or embed both HookDefaults and
// MiddlewareDefaults directly, then implement any operation-specific hook or
// middleware interfaces it needs.
//
// Hooks and Middleware are still supported. If a type in Hooks also
// implements middleware, or a type in Middleware also implements hooks, River
// will install it on both sides automatically. The Plugins list exists as a
// generic place for new extensions to be registered.
Plugins []rivertype.Plugin

// PeriodicJobs are a set of periodic jobs to run at the specified intervals
// in the client.
PeriodicJobs []*PeriodicJob
Expand Down Expand Up @@ -476,6 +497,7 @@ func (c *Config) WithDefaults() *Config {
MaxAttempts: cmp.Or(c.MaxAttempts, MaxAttemptsDefault),
Middleware: c.Middleware,
PeriodicJobs: c.PeriodicJobs,
Plugins: c.Plugins,
PollOnly: c.PollOnly,
Queues: c.Queues,
ReindexerIndexNames: reindexerIndexNames,
Expand Down Expand Up @@ -774,7 +796,11 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
}
}

for _, hook := range config.Hooks {
configuredMiddleware := middlewareFromConfig(config)
effectiveHooks := pluginconfig.Hooks(config.Hooks, configuredMiddleware, config.Plugins)
effectiveMiddleware := pluginconfig.Middleware(config.Hooks, configuredMiddleware, config.Plugins)

for _, hook := range effectiveHooks {
if withBaseService, ok := hook.(baseservice.WithBaseService); ok {
baseservice.Init(archetype, withBaseService)
}
Expand All @@ -788,7 +814,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
config: config,
driver: driver,
hookLookupByJob: hooklookup.NewJobHookLookup(),
hookLookupGlobal: hooklookup.NewHookLookup(config.Hooks),
hookLookupGlobal: hooklookup.NewHookLookup(effectiveHooks),
producersByQueueName: make(map[string]*producer),
testSignals: clientTestSignals{},
workCancel: func(cause error) {}, // replaced on start, but here in case StopAndCancel is called before start up
Expand All @@ -806,31 +832,12 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
client.baseService.Name = "Client" // Have to correct the name because base service isn't embedded like it usually is
client.insertNotifyLimiter = notifylimiter.NewLimiter(archetype, config.FetchCooldown)

// Validation ensures that config.JobInsertMiddleware/WorkerMiddleware or
// the more abstract config.Middleware for middleware are set, but not both,
// so in practice we never append all three of these to each other.
// effectiveMiddleware contains configured middleware, hook/middleware
// hybrids, and plugins. Default middleware stays first so user middleware
// wraps inside River's internal defaults like before.
{
middleware := rivermiddleware.DefaultMiddleware()
middleware = append(middleware, config.Middleware...)
for _, jobInsertMiddleware := range config.JobInsertMiddleware {
middleware = append(middleware, jobInsertMiddleware)
}
outerLoop:
for _, workerMiddleware := range config.WorkerMiddleware {
// Don't add the middleware if it also implements JobInsertMiddleware
// and the instance has been added to config.JobInsertMiddleware. This
// is a hedge to make sure we don't accidentally double add middleware
// as we've converted over to the unified config.Middleware setting.
if workerMiddlewareAsJobInsertMiddleware, ok := workerMiddleware.(rivertype.JobInsertMiddleware); ok {
for _, jobInsertMiddleware := range config.JobInsertMiddleware {
if workerMiddlewareAsJobInsertMiddleware == jobInsertMiddleware {
continue outerLoop
}
}
}

middleware = append(middleware, workerMiddleware)
}
middleware = append(middleware, effectiveMiddleware...)

for _, middleware := range middleware {
if withBaseService, ok := middleware.(baseservice.WithBaseService); ok {
Expand Down Expand Up @@ -1040,6 +1047,35 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
return client, nil
}

func middlewareFromConfig(config *Config) []rivertype.Middleware {
middleware := make([]rivertype.Middleware, 0,
len(config.Middleware)+len(config.JobInsertMiddleware)+len(config.WorkerMiddleware))
middleware = append(middleware, config.Middleware...)

for _, jobInsertMiddleware := range config.JobInsertMiddleware {
middleware = append(middleware, jobInsertMiddleware)
}

outerLoop:
for _, workerMiddleware := range config.WorkerMiddleware {
// Don't add the middleware if it also implements JobInsertMiddleware
// and the instance has been added to config.JobInsertMiddleware. This
// is a hedge to make sure we don't accidentally double add middleware
// as we've converted over to the unified config.Middleware setting.
if workerMiddlewareAsJobInsertMiddleware, ok := workerMiddleware.(rivertype.JobInsertMiddleware); ok {
for _, jobInsertMiddleware := range config.JobInsertMiddleware {
if workerMiddlewareAsJobInsertMiddleware == jobInsertMiddleware {
continue outerLoop
}
}
}

middleware = append(middleware, workerMiddleware)
}

return middleware
}

// Start starts the client's job fetching and working loops. Once this is called,
// the client will run in a background goroutine until stopped. All jobs are
// run with a context inheriting from the provided context, but with a timeout
Expand Down
186 changes: 186 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,77 @@ type noOpWorker struct {

func (w *noOpWorker) Work(ctx context.Context, job *Job[noOpArgs]) error { return nil }

var (
_ rivertype.HookInsertBegin = &hookMiddlewareEmbeddedDefaultsPlugin{}
_ rivertype.JobInsertMiddleware = &hookMiddlewareEmbeddedDefaultsPlugin{}
_ rivertype.Plugin = &hookMiddlewareEmbeddedDefaultsPlugin{}

_ rivertype.HookInsertBegin = &hookMiddlewarePlugin{}
_ rivertype.JobInsertMiddleware = &hookMiddlewarePlugin{}
_ rivertype.Plugin = &hookMiddlewarePlugin{}

_ rivertype.HookInsertBegin = hookMiddlewareValuePlugin{}
_ rivertype.JobInsertMiddleware = hookMiddlewareValuePlugin{}
)

type hookMiddlewareEmbeddedDefaultsPlugin struct {
HookDefaults
MiddlewareDefaults

insertBeginCount int
insertManyCount int
}

func (p *hookMiddlewareEmbeddedDefaultsPlugin) InsertBegin(ctx context.Context, params *rivertype.JobInsertParams) error {
p.insertBeginCount++
return nil
}

func (p *hookMiddlewareEmbeddedDefaultsPlugin) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) {
p.insertManyCount++
return doInner(ctx)
}

type hookMiddlewarePlugin struct {
PluginDefaults

insertBeginCount int
insertManyCount int
}

func (p *hookMiddlewarePlugin) InsertBegin(ctx context.Context, params *rivertype.JobInsertParams) error {
p.insertBeginCount++
return nil
}

func (p *hookMiddlewarePlugin) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) {
p.insertManyCount++
return doInner(ctx)
}

type hookMiddlewareValuePlugin struct {
counts *hookMiddlewareValuePluginCounts
}

func (p hookMiddlewareValuePlugin) InsertBegin(ctx context.Context, params *rivertype.JobInsertParams) error {
p.counts.insertBeginCount++
return nil
}

func (p hookMiddlewareValuePlugin) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) {
p.counts.insertManyCount++
return doInner(ctx)
}

func (p hookMiddlewareValuePlugin) IsHook() bool { return true }

func (p hookMiddlewareValuePlugin) IsMiddleware() bool { return true }

type hookMiddlewareValuePluginCounts struct {
insertBeginCount int
insertManyCount int
}

type periodicJobArgs struct{}

func (periodicJobArgs) Kind() string { return "periodic_job" }
Expand Down Expand Up @@ -8195,6 +8266,121 @@ func Test_NewClient_Overrides(t *testing.T) {
require.Len(t, client.config.WorkerMiddleware, 1)
}

func Test_NewClient_PluginsAndHybrids(t *testing.T) {
t.Parallel()

ctx := context.Background()

type testBundle struct {
config *Config
dbPool *pgxpool.Pool
}

setup := func(t *testing.T) *testBundle {
t.Helper()

dbPool := riversharedtest.DBPool(ctx, t)
driver := riverpgxv5.New(dbPool)
schema := riverdbtest.TestSchema(ctx, t, driver, nil)

return &testBundle{
config: newTestConfig(t, schema),
dbPool: dbPool,
}
}

insertAndRequireCounts := func(t *testing.T, bundle *testBundle, plugin *hookMiddlewarePlugin, expectedCount int) {
t.Helper()

client := newTestClient(t, bundle.dbPool, bundle.config)

_, err := client.Insert(ctx, noOpArgs{}, nil)
require.NoError(t, err)

require.Equal(t, expectedCount, plugin.insertBeginCount)
require.Equal(t, expectedCount, plugin.insertManyCount)
}

t.Run("DuplicatesAcrossHooksMiddlewareAndPluginsRunMultipleTimes", func(t *testing.T) {
t.Parallel()

bundle := setup(t)
plugin := &hookMiddlewarePlugin{}
bundle.config.Hooks = []rivertype.Hook{plugin}
bundle.config.Middleware = []rivertype.Middleware{plugin}
bundle.config.Plugins = []rivertype.Plugin{plugin}

insertAndRequireCounts(t, bundle, plugin, 3)
})

t.Run("HookAlsoRegisteredAsMiddleware", func(t *testing.T) {
t.Parallel()

bundle := setup(t)
plugin := &hookMiddlewarePlugin{}
bundle.config.Hooks = []rivertype.Hook{plugin}

insertAndRequireCounts(t, bundle, plugin, 1)
})

t.Run("MiddlewareAlsoRegisteredAsHook", func(t *testing.T) {
t.Parallel()

bundle := setup(t)
plugin := &hookMiddlewarePlugin{}
bundle.config.Middleware = []rivertype.Middleware{plugin}

insertAndRequireCounts(t, bundle, plugin, 1)
})

t.Run("PluginRegisteredAsHookAndMiddleware", func(t *testing.T) {
t.Parallel()

bundle := setup(t)
plugin := &hookMiddlewarePlugin{}
bundle.config.Plugins = []rivertype.Plugin{plugin}

insertAndRequireCounts(t, bundle, plugin, 1)
})

t.Run("PluginRegisteredWithEmbeddedDefaults", func(t *testing.T) {
t.Parallel()

bundle := setup(t)
plugin := &hookMiddlewareEmbeddedDefaultsPlugin{}
bundle.config.Plugins = []rivertype.Plugin{plugin}

client := newTestClient(t, bundle.dbPool, bundle.config)

_, err := client.Insert(ctx, noOpArgs{}, nil)
require.NoError(t, err)

require.Equal(t, 1, plugin.insertBeginCount)
require.Equal(t, 1, plugin.insertManyCount)
})

t.Run("SeparateEqualValueInstancesRunSeparately", func(t *testing.T) {
t.Parallel()

bundle := setup(t)
counts := &hookMiddlewareValuePluginCounts{}
bundle.config.Hooks = []rivertype.Hook{
hookMiddlewareValuePlugin{counts: counts},
}
bundle.config.Middleware = []rivertype.Middleware{
hookMiddlewareValuePlugin{counts: counts},
}

client := newTestClient(t, bundle.dbPool, bundle.config)

_, err := client.Insert(ctx, noOpArgs{}, nil)
require.NoError(t, err)

require.Equal(t, 2, counts.insertBeginCount)
require.Equal(t, 2, counts.insertManyCount)
})
}

func Test_NewClient_ReindexerIndexNamesExplicitEmptyOverride(t *testing.T) {
t.Parallel()

Expand Down
51 changes: 51 additions & 0 deletions internal/pluginconfig/plugin_config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package pluginconfig

import "github.com/riverqueue/river/rivertype"

// Hooks returns the effective hook list from configured hooks, middleware, and
// plugins. Explicit hooks are preserved first, followed by middleware that also
// implement hooks, then plugins.
func Hooks(hooks []rivertype.Hook, middleware []rivertype.Middleware, plugins []rivertype.Plugin) []rivertype.Hook {
effectiveHooks := make([]rivertype.Hook, 0, len(hooks)+len(middleware)+len(plugins))

effectiveHooks = append(effectiveHooks, hooks...)

for _, middlewareItem := range middleware {
hook, ok := middlewareItem.(rivertype.Hook)
if !ok {
continue
}

effectiveHooks = append(effectiveHooks, hook)
}

for _, plugin := range plugins {
effectiveHooks = append(effectiveHooks, plugin)
}

return effectiveHooks
}

// Middleware returns the effective middleware list from configured hooks,
// middleware, and plugins. Explicit middleware are preserved first, followed by
// hooks that also implement middleware, then plugins.
func Middleware(hooks []rivertype.Hook, middleware []rivertype.Middleware, plugins []rivertype.Plugin) []rivertype.Middleware {
effectiveMiddleware := make([]rivertype.Middleware, 0, len(hooks)+len(middleware)+len(plugins))

effectiveMiddleware = append(effectiveMiddleware, middleware...)

for _, hook := range hooks {
middlewareItem, ok := hook.(rivertype.Middleware)
if !ok {
continue
}

effectiveMiddleware = append(effectiveMiddleware, middlewareItem)
}

for _, plugin := range plugins {
effectiveMiddleware = append(effectiveMiddleware, plugin)
}

return effectiveMiddleware
}
Loading
Loading