diff --git a/CONTEXT.md b/CONTEXT.md new file mode 100644 index 0000000..dc33534 --- /dev/null +++ b/CONTEXT.md @@ -0,0 +1,72 @@ +# Needle Context + +Needle is a generic dependency-injection container for Go. The vocabulary below is what the codebase uses; pick these terms over their aliases. + +## Language + +**Container**: +The owner of registrations, lifecycle, and resolution. The thing users construct with `needle.New()`. +_Avoid_: registry (that's an internal sub-component), context (overloaded with `context.Context`). + +**Spec[T]**: +The configuration a caller hands to `Register[T]` to register a service of type `T`. Sum-typed: carries either a `Provider` or a `Value`, plus optional name, dependencies, scope, hooks, pool size, and lazy flag. The user-facing description of a service. +_Avoid_: definition, registration, config, builder. + +**ServiceEntry**: +The internal runtime state of a registered service: the spec's contents plus `sync.Once`, init error, pool channel, instantiation flag. Lives inside the registry; users never see it. +_Avoid_: service record, entry (when ambiguous). + +**Provider**: +A function `func(ctx, Resolver) (T, error)` that constructs an instance of `T`. One of the two things a `Spec[T]` can carry. +_Avoid_: factory, constructor (constructor refers to plain Go constructor functions used by the autowire path). + +**Hook**: +A function `func(ctx) error` called at a service's lifecycle transition. A `Spec[T]` carries at most one `OnStart` and one `OnStop`; multiple hooks compose via `needle.Compose(h1, h2)`. +_Avoid_: callback, listener, observer (observer is reserved for container-wide observation hooks like `WithStartObserver`). + +**Scope**: +When and how a service instance is reused: Singleton, Transient, Request, Pooled. A property of the spec; resolved differently per scope. +_Avoid_: lifetime (overloaded with lifecycle), strategy. + +**Resolver**: +The lookup interface a `Provider` uses to fetch dependencies during construction. In the deepened design this is the Container itself, not a separate adapter. +_Avoid_: locator, injector. + +**Module**: +A deferred recorder of typed registrations. `ModuleRegister[T](m, spec)` captures a closure that calls `Register[T]` against the container at `Apply` time. Modules carry no semantics beyond batched registration. +_Avoid_: bundle, package, group. + +**Decorator**: +A function that wraps a resolved instance of `T` to add cross-cutting behaviour. Registered separately from specs (cross-cutting; not per-service config). +_Avoid_: middleware, interceptor. + +**Binding**: +A spec where the provider resolves another key and returns it as the registered type. Built via `SpecFromBinding[I, T]()`. Lets an interface `I` be served by an implementation `T` already registered under a different key. +_Avoid_: alias, link. + +**Observer**: +A container-wide callback fired on `Resolve`/`Provide`/`Start`/`Stop`. Distinct from per-service Hooks: observers see every service, hooks fire on one service. +_Avoid_: listener, hook (hook is reserved for per-service lifecycle). + +## Relationships + +- A **Container** holds many **ServiceEntries**, one per registered key. +- A **Spec[T]** is the input to `Register`; the **Container** turns it into a **ServiceEntry**. +- A **ServiceEntry** carries at most one **Provider** (or a pre-built value), at most one **OnStart Hook**, at most one **OnStop Hook**, and exactly one **Scope**. +- A **Module** records typed `Spec[T]` closures and replays them against a **Container** at apply time. +- A **Binding** is a **Spec** whose **Provider** delegates to another key's resolution. +- **Decorators** attach to a key independently of the **Spec** for that key; one key can have many decorators. +- **Observers** attach to the **Container**, not to a **Spec**. + +## Example dialogue + +> **Dev:** "If I want a service to start lazily and run an `OnStart` hook the first time it resolves, what do I put on the **Spec**?" +> **Maintainer:** "Set `Lazy: true` and `OnStart: myHook`. The **Container** holds the **Spec** as a **ServiceEntry**; on first `Resolve` it constructs the instance via the **Provider** and runs the **OnStart Hook** because the container is already in the running state." + +> **Dev:** "Can I add two `OnStart` hooks to one service?" +> **Maintainer:** "A **Spec** carries one **Hook** per slot. Compose them with `needle.Compose(h1, h2)` and assign the composite. **Modules** wanting to layer extra behaviour register a **Decorator** instead — that's the cross-cutting path." + +## Flagged ambiguities + +- "service" was used loosely for both the runtime instance and the registration. Resolved: the registration is a **Spec[T]** (user-facing) or **ServiceEntry** (internal); the runtime instance is just "the resolved value" or "instance." +- "hook" previously named both per-service lifecycle callbacks and container-wide callbacks. Resolved: per-service is **Hook**; container-wide is **Observer**. diff --git a/README.md b/README.md index 94a07a5..b92446c 100644 --- a/README.md +++ b/README.md @@ -7,11 +7,11 @@ A modern, type-safe dependency injection framework for Go. ## Features -Needle uses Go generics for compile-time type safety (`Provide[T]`, `Invoke[T]`) and has zero external dependencies. +Needle uses Go generics for compile-time type safety (`Register[T]`, `Invoke[T]`) and has zero external dependencies. -It supports constructor auto-wiring, struct tag injection, multiple scopes (singleton, transient, request, pooled), and lifecycle hooks that run in dependency order. Services can start in parallel, be lazily initialized, or be replaced at runtime without restarting the container. +A single `Register[T]` entry point takes a typed `Spec[T]` carrying provider, dependencies, scope, hooks, pool size, and lazy flag. Constructor helpers (`SpecFromConstructor`, `SpecFromStruct`, `SpecFromBinding`, `SpecValue`) cover auto-wiring, struct-tag injection, interface binding, and pre-built values. Lifecycle hooks run in dependency order, services can start in parallel, and any spec can be replaced at runtime. -You can group providers into modules, bind interfaces to implementations, wrap services with decorators, and resolve optional dependencies with a built-in `Optional[T]` type. Health and readiness checks are supported out of the box. +You can group specs into modules, attach decorators for cross-cutting concerns, and resolve optional dependencies with the built-in `Optional[T]` type. Health and readiness checks are supported out of the box. ## Installation @@ -24,9 +24,11 @@ go get github.com/danpasecinic/needle ```go c := needle.New() -needle.ProvideValue(c, &Config{Port: 8080}) -needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*Server, error) { - return &Server{Config: needle.MustInvoke[*Config](c)}, nil +needle.Register(c, needle.SpecValue(&Config{Port: 8080})) +needle.Register(c, needle.Spec[*Server]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Server, error) { + return &Server{Config: needle.MustInvoke[*Config](c)}, nil + }, }) server := needle.MustInvoke[*Server](c) @@ -37,7 +39,7 @@ server := needle.MustInvoke[*Server](c) See the [examples](examples/) directory: - [basic](examples/basic/) - Simple dependency chain -- [autowire](examples/autowire/) - Struct-based injection +- [autowire](examples/autowire/) - Constructor and struct-tag injection - [httpserver](examples/httpserver/) - HTTP server with lifecycle - [modules](examples/modules/) - Modules and interface binding - [scopes](examples/scopes/) - Singleton, Transient, Request, Pooled @@ -47,6 +49,40 @@ See the [examples](examples/) directory: - [optional](examples/optional/) - Optional dependencies with fallbacks - [parallel](examples/parallel/) - Parallel startup/shutdown +## The Spec Type + +Every registration goes through one type. The defaults (zero values) cover the common case: singleton scope, eager initialization, no hooks. + +```go +type Spec[T any] struct { + Name string // optional, for named services + Provider Provider[T] // factory function (mutually exclusive with SpecValue) + Dependencies []string // explicit dependency keys + Scope Scope // Singleton (default), Transient, Request, Pooled + OnStart Hook // lifecycle hook on container Start + OnStop Hook // lifecycle hook on container Stop + PoolSize int // pool size when Scope is Pooled + Lazy bool // defer instantiation until first Resolve +} +``` + +Constructor helpers fill the spec for common patterns: + +```go +needle.SpecValue(&Config{Port: 8080}) // pre-built value +needle.SpecFromConstructor[*Database](NewDatabase) // auto-wire from func params +needle.SpecFromStruct[*UserService]() // auto-wire from `needle:""` tags +needle.SpecFromBinding[UserRepo, *PostgresRepo]() // bind interface to impl +``` + +Each helper returns a `Spec[T]` that you can field-tweak or chain via `WithName`, `WithScope`, `WithLazy`, etc: + +```go +needle.Register(c, needle.SpecFromConstructor[*Server](NewServer). + WithName("primary"). + WithLazy()) +``` + ## Choosing a Scope | Scope | Lifetime | Use When | @@ -57,10 +93,10 @@ See the [examples](examples/) directory: | **Pooled** | Reusable instances from a fixed-size pool | Expensive-to-create, stateless-between-uses resources: gRPC connections, worker objects | ```go -needle.Provide(c, NewService) // Singleton (default) -needle.Provide(c, NewHandler, needle.WithScope(needle.Transient)) -needle.Provide(c, NewRequestLogger, needle.WithScope(needle.Request)) -needle.Provide(c, NewWorker, needle.WithPoolSize(10)) // Pooled with 10 slots +needle.Register(c, needle.SpecFromConstructor[*Service](NewService)) // Singleton (default) +needle.Register(c, needle.SpecFromConstructor[*Handler](NewHandler).WithScope(needle.Transient)) // Transient +needle.Register(c, needle.SpecFromConstructor[*RequestLogger](NewRequestLogger).WithScope(needle.Request)) +needle.Register(c, needle.SpecFromConstructor[*Worker](NewWorker).WithPoolSize(10)) // Pooled ``` Pooled services must be released by the caller via `c.Release(key, instance)`. If the pool is full, the instance is dropped and a warning is logged. @@ -70,28 +106,35 @@ Pooled services must be released by the caller via `c.Release(key, instance)`. I Replace services at runtime without restarting the container. Useful for feature flags, A/B testing, test doubles, or configuration updates. ```go -// Replace with a new value -needle.ReplaceValue(c, &Config{Port: 9090}) +needle.Replace(c, needle.SpecValue(&Config{Port: 9090})) -// Replace with a new provider -needle.Replace(c, func(ctx context.Context, r needle.Resolver) (*Server, error) { - return &Server{Config: needle.MustInvoke[*Config](c)}, nil +needle.Replace(c, needle.Spec[*Server]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Server, error) { + return &Server{Config: needle.MustInvoke[*Config](c)}, nil + }, }) -// Replace with auto-wired constructor -needle.ReplaceFunc[*Service](c, NewService) +needle.Replace(c, needle.SpecFromConstructor[*Service](NewService)) +needle.Replace(c, needle.SpecFromStruct[*Service]()) -// Replace with struct injection -needle.ReplaceStruct[*Service](c) - -// Named variants -needle.ReplaceNamedValue(c, "primary", &Config{Port: 5432}) -needle.ReplaceNamed(c, "primary", provider) +needle.Replace(c, needle.SpecValue(&Config{Port: 5432}).WithName("primary")) ``` -All Replace functions accept the same options as Provide (`WithScope`, `WithOnStart`, `WithOnStop`, `WithLazy`, `WithPoolSize`). If the service does not exist yet, Replace creates it. If it does exist, the old entry is removed from both the registry and the dependency graph before re-registering. +`Register` errors on a duplicate key. `Replace` overwrites if the key exists, or registers if it does not. The same `Spec[T]` type is the input to both -- only the intent differs. `MustRegister` and `MustReplace` panic on error. + +## Multiple Hooks + +`Spec[T]` carries one `OnStart` and one `OnStop` Hook. Combine multiple hooks with `Compose`: + +```go +needle.Register(c, needle.Spec[*Server]{ + Provider: NewServer, + OnStart: needle.Compose(installRoutes, openListener), + OnStop: needle.Compose(stopGracefully, flushLogs), +}) +``` -`Must` variants (`MustReplace`, `MustReplaceValue`, `MustReplaceFunc`, `MustReplaceStruct`) panic on error. +`Compose` runs hooks in argument order and returns the first error. ## Benchmarks diff --git a/autowire.go b/autowire.go index 5a43039..a541d38 100644 --- a/autowire.go +++ b/autowire.go @@ -15,6 +15,10 @@ func InvokeStruct[T any](c *Container) (T, error) { } func InvokeStructCtx[T any](ctx context.Context, c *Container) (T, error) { + return resolveStruct[T](ctx, c.resolver) +} + +func resolveStruct[T any](ctx context.Context, r Resolver) (T, error) { var zero T t := reflectPkg.TypeOf(zero) @@ -42,14 +46,14 @@ func InvokeStructCtx[T any](ctx context.Context, c *Container) (T, error) { key = field.TypeKey } - if !c.internal.Has(key) { + if !r.Has(key) { if field.Optional { continue } return zero, errServiceNotFound(key) } - instance, err := c.internal.Resolve(ctx, key) + instance, err := r.Resolve(ctx, key) if err != nil { if field.Optional { continue @@ -82,7 +86,7 @@ func InvokeStructCtx[T any](ctx context.Context, c *Container) (T, error) { return structVal.Interface().(T), nil } -func buildFuncProvider[T any](c *Container, constructor any) (Provider[T], []ProviderOption, error) { +func buildFuncProvider[T any](constructor any) (Provider[T], []string, error) { params, returnType, err := reflect.FuncParams(constructor) if err != nil { return nil, nil, err @@ -112,7 +116,7 @@ func buildFuncProvider[T any](c *Container, constructor any) (Provider[T], []Pro args := make([]reflectPkg.Value, len(params)) for i, p := range params { - instance, err := c.internal.Resolve(ctx, p.TypeKey) + instance, err := r.Resolve(ctx, p.TypeKey) if err != nil { return zero, fmt.Errorf("failed to resolve parameter %d (%s): %w", i, p.TypeKey, err) } @@ -128,12 +132,12 @@ func buildFuncProvider[T any](c *Container, constructor any) (Provider[T], []Pro return results[0].Interface().(T), nil } - return provider, []ProviderOption{WithDependencies(deps...)}, nil + return provider, deps, nil } -func buildStructProvider[T any](c *Container) (Provider[T], []ProviderOption) { +func buildStructProvider[T any]() (Provider[T], []string) { provider := func(ctx context.Context, r Resolver) (T, error) { - return InvokeStructCtx[T](ctx, c) + return resolveStruct[T](ctx, r) } fields, _ := reflect.StructFields[T](TagKey) @@ -148,33 +152,5 @@ func buildStructProvider[T any](c *Container) (Provider[T], []ProviderOption) { } } - return provider, []ProviderOption{WithDependencies(deps...)} -} - -func ProvideFunc[T any](c *Container, constructor any, opts ...ProviderOption) error { - provider, depOpts, err := buildFuncProvider[T](c, constructor) - if err != nil { - return err - } - - opts = append(depOpts, opts...) - return Provide(c, provider, opts...) -} - -func MustProvideFunc[T any](c *Container, constructor any, opts ...ProviderOption) { - if err := ProvideFunc[T](c, constructor, opts...); err != nil { - panic(err) - } -} - -func ProvideStruct[T any](c *Container, opts ...ProviderOption) error { - provider, depOpts := buildStructProvider[T](c) - opts = append(depOpts, opts...) - return Provide(c, provider, opts...) -} - -func MustProvideStruct[T any](c *Container, opts ...ProviderOption) { - if err := ProvideStruct[T](c, opts...); err != nil { - panic(err) - } + return provider, deps } diff --git a/autowire_test.go b/autowire_test.go index 0ae42d0..a050af7 100644 --- a/autowire_test.go +++ b/autowire_test.go @@ -36,8 +36,8 @@ func TestInvokeStruct(t *testing.T) { "resolves tagged fields", func(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &TestLogger{Name: "app"}) - _ = needle.ProvideValue(c, &TestDatabase{URL: "postgres://localhost"}) + _ = needle.Register(c, needle.SpecValue(&TestLogger{Name: "app"})) + _ = needle.Register(c, needle.SpecValue(&TestDatabase{URL: "postgres://localhost"})) svc, err := needle.InvokeStruct[*TestServiceWithTags](c) if err != nil { @@ -60,8 +60,8 @@ func TestInvokeStruct(t *testing.T) { "resolves named dependencies", func(t *testing.T) { c := needle.New() - _ = needle.ProvideNamedValue(c, "primary", &TestDatabase{URL: "primary-db"}) - _ = needle.ProvideNamedValue(c, "secondary", &TestDatabase{URL: "secondary-db"}) + _ = needle.Register(c, needle.SpecValue(&TestDatabase{URL: "primary-db"}).WithName("primary")) + _ = needle.Register(c, needle.SpecValue(&TestDatabase{URL: "secondary-db"}).WithName("secondary")) svc, err := needle.InvokeStruct[*TestServiceWithNamedDep](c) if err != nil { @@ -81,7 +81,7 @@ func TestInvokeStruct(t *testing.T) { "fails on missing required dependency", func(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &TestLogger{Name: "app"}) + _ = needle.Register(c, needle.SpecValue(&TestLogger{Name: "app"})) _, err := needle.InvokeStruct[*TestServiceWithTags](c) if err == nil { @@ -94,7 +94,7 @@ func TestInvokeStruct(t *testing.T) { "succeeds with missing optional dependency", func(t *testing.T) { c := needle.New() - _ = needle.ProvideNamedValue(c, "primary", &TestDatabase{URL: "primary-db"}) + _ = needle.Register(c, needle.SpecValue(&TestDatabase{URL: "primary-db"}).WithName("primary")) svc, err := needle.InvokeStruct[*TestServiceWithNamedDep](c) if err != nil { @@ -114,8 +114,8 @@ func TestInvokeStruct(t *testing.T) { "returns non-pointer struct", func(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &TestLogger{Name: "app"}) - _ = needle.ProvideValue(c, &TestDatabase{URL: "postgres://localhost"}) + _ = needle.Register(c, needle.SpecValue(&TestLogger{Name: "app"})) + _ = needle.Register(c, needle.SpecValue(&TestDatabase{URL: "postgres://localhost"})) svc, err := needle.InvokeStruct[TestServiceWithTags](c) if err != nil { @@ -153,13 +153,13 @@ func NewTestUserService(db *TestDatabase, logger *TestLogger) *TestUserService { return &TestUserService{DB: db, Logger: logger} } -func TestProvideFunc(t *testing.T) { +func TestSpecFromConstructor(t *testing.T) { t.Run( "auto-wires constructor parameters", func(t *testing.T) { c := needle.New() - _ = needle.ProvideFunc[*TestLogger](c, NewTestLogger) - _ = needle.ProvideFunc[*TestDatabase](c, NewTestDatabase) + _ = needle.Register(c, needle.SpecFromConstructor[*TestLogger](NewTestLogger)) + _ = needle.Register(c, needle.SpecFromConstructor[*TestDatabase](NewTestDatabase)) db, err := needle.Invoke[*TestDatabase](c) if err != nil { @@ -176,8 +176,8 @@ func TestProvideFunc(t *testing.T) { "handles constructor returning error", func(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &TestLogger{Name: "fail"}) - _ = needle.ProvideFunc[*TestDatabase](c, NewTestDatabaseWithError) + _ = needle.Register(c, needle.SpecValue(&TestLogger{Name: "fail"})) + _ = needle.Register(c, needle.SpecFromConstructor[*TestDatabase](NewTestDatabaseWithError)) _, err := needle.Invoke[*TestDatabase](c) if err == nil { @@ -190,9 +190,9 @@ func TestProvideFunc(t *testing.T) { "chains multiple auto-wired services", func(t *testing.T) { c := needle.New() - _ = needle.ProvideFunc[*TestLogger](c, NewTestLogger) - _ = needle.ProvideFunc[*TestDatabase](c, NewTestDatabase) - _ = needle.ProvideFunc[*TestUserService](c, NewTestUserService) + _ = needle.Register(c, needle.SpecFromConstructor[*TestLogger](NewTestLogger)) + _ = needle.Register(c, needle.SpecFromConstructor[*TestDatabase](NewTestDatabase)) + _ = needle.Register(c, needle.SpecFromConstructor[*TestUserService](NewTestUserService)) svc, err := needle.Invoke[*TestUserService](c) if err != nil { @@ -212,7 +212,7 @@ func TestProvideFunc(t *testing.T) { "fails on missing dependency", func(t *testing.T) { c := needle.New() - _ = needle.ProvideFunc[*TestDatabase](c, NewTestDatabase) + _ = needle.Register(c, needle.SpecFromConstructor[*TestDatabase](NewTestDatabase)) _, err := needle.Invoke[*TestDatabase](c) if err == nil { @@ -225,7 +225,7 @@ func TestProvideFunc(t *testing.T) { "works with zero-arg constructor", func(t *testing.T) { c := needle.New() - _ = needle.ProvideFunc[*TestLogger](c, NewTestLogger) + _ = needle.Register(c, needle.SpecFromConstructor[*TestLogger](NewTestLogger)) logger, err := needle.Invoke[*TestLogger](c) if err != nil { @@ -239,14 +239,14 @@ func TestProvideFunc(t *testing.T) { ) } -func TestProvideStruct(t *testing.T) { +func TestSpecFromStruct(t *testing.T) { t.Run( "registers struct with tagged fields", func(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &TestLogger{Name: "app"}) - _ = needle.ProvideValue(c, &TestDatabase{URL: "postgres://localhost"}) - _ = needle.ProvideStruct[*TestServiceWithTags](c) + _ = needle.Register(c, needle.SpecValue(&TestLogger{Name: "app"})) + _ = needle.Register(c, needle.SpecValue(&TestDatabase{URL: "postgres://localhost"})) + _ = needle.Register(c, needle.SpecFromStruct[*TestServiceWithTags]()) svc, err := needle.Invoke[*TestServiceWithTags](c) if err != nil { @@ -263,8 +263,8 @@ func TestProvideStruct(t *testing.T) { "validates dependencies on registration", func(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &TestLogger{Name: "app"}) - _ = needle.ProvideStruct[*TestServiceWithTags](c) + _ = needle.Register(c, needle.SpecValue(&TestLogger{Name: "app"})) + _ = needle.Register(c, needle.SpecFromStruct[*TestServiceWithTags]()) err := c.Validate() if err == nil { @@ -274,11 +274,11 @@ func TestProvideStruct(t *testing.T) { ) } -func TestProvideStructWithContext(t *testing.T) { +func TestSpecFromStructCtx(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &TestLogger{Name: "ctx-test"}) - _ = needle.ProvideValue(c, &TestDatabase{URL: "ctx-db"}) + _ = needle.Register(c, needle.SpecValue(&TestLogger{Name: "ctx-test"})) + _ = needle.Register(c, needle.SpecValue(&TestDatabase{URL: "ctx-db"})) ctx := context.Background() svc, err := needle.InvokeStructCtx[*TestServiceWithTags](ctx, c) @@ -291,35 +291,12 @@ func TestProvideStructWithContext(t *testing.T) { } } -func TestMustProvideFunc(t *testing.T) { - t.Run( - "panics on invalid constructor", func(t *testing.T) { - c := needle.New() - - defer func() { - if r := recover(); r == nil { - t.Error("expected panic") - } - }() - - needle.MustProvideFunc[*TestLogger](c, "not a function") - }, - ) -} +func TestSpecFromConstructor_InvalidConstructor(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic for invalid constructor") + } + }() -func TestMustProvideStruct(t *testing.T) { - t.Run( - "does not panic on valid struct", func(t *testing.T) { - c := needle.New() - - _ = needle.ProvideValue(c, &TestLogger{Name: "app"}) - _ = needle.ProvideValue(c, &TestDatabase{URL: "db"}) - - needle.MustProvideStruct[*TestServiceWithTags](c) - - if !needle.Has[*TestServiceWithTags](c) { - t.Error("struct should be registered") - } - }, - ) + _ = needle.SpecFromConstructor[*TestLogger]("not a function") } diff --git a/benchmark/invoke_test.go b/benchmark/invoke_test.go index 4186a65..5bed2ca 100644 --- a/benchmark/invoke_test.go +++ b/benchmark/invoke_test.go @@ -13,7 +13,7 @@ import ( func BenchmarkInvoke_Singleton_Needle(b *testing.B) { c := needle.New() - _ = needle.ProvideValue(c, &Config{Host: "localhost", Port: 8080}) + _ = needle.Register(c, needle.SpecValue(&Config{Host: "localhost", Port: 8080})) ctx := context.Background() _ = c.Start(ctx) @@ -69,35 +69,35 @@ func BenchmarkInvoke_Singleton_Fx(b *testing.B) { func BenchmarkInvoke_Chain_Needle(b *testing.B) { c := needle.New() - _ = needle.ProvideValue(c, &Config{Host: "localhost", Port: 8080}) - _ = needle.ProvideValue(c, &Logger{Level: "info"}) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Database, error) { + _ = needle.Register(c, needle.SpecValue(&Config{Host: "localhost", Port: 8080})) + _ = needle.Register(c, needle.SpecValue(&Logger{Level: "info"})) + _ = needle.Register(c, needle.Spec[*Database]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Database, error) { cfg := needle.MustInvoke[*Config](c) log := needle.MustInvoke[*Logger](c) return &Database{Config: cfg, Logger: log}, nil }, - ) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Cache, error) { + }) + _ = needle.Register(c, needle.Spec[*Cache]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Cache, error) { log := needle.MustInvoke[*Logger](c) return &Cache{Logger: log}, nil }, - ) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Repository, error) { + }) + _ = needle.Register(c, needle.Spec[*Repository]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Repository, error) { db := needle.MustInvoke[*Database](c) cache := needle.MustInvoke[*Cache](c) return &Repository{DB: db, Cache: cache}, nil }, - ) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Service, error) { + }) + _ = needle.Register(c, needle.Spec[*Service]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Service, error) { repo := needle.MustInvoke[*Repository](c) log := needle.MustInvoke[*Logger](c) return &Service{Repo: repo, Logger: log}, nil }, - ) + }) ctx := context.Background() _ = c.Start(ctx) diff --git a/benchmark/lifecycle_test.go b/benchmark/lifecycle_test.go index 82d0a08..c4b26f6 100644 --- a/benchmark/lifecycle_test.go +++ b/benchmark/lifecycle_test.go @@ -72,11 +72,12 @@ func benchmarkLifecycleNeedle(b *testing.B, count int, parallel bool) { for j := 0; j < count; j++ { idx := j key := fmt.Sprintf("svc_%d", j) - _ = needle.ProvideNamed( - c, key, func(ctx context.Context, r needle.Resolver) (*Config, error) { + _ = needle.Register(c, needle.Spec[*Config]{ + Name: key, + Provider: func(ctx context.Context, r needle.Resolver) (*Config, error) { return &Config{Port: idx}, nil }, - ) + }) } ctx := context.Background() @@ -99,23 +100,20 @@ func benchmarkLifecycleNeedleWithWork(b *testing.B, count int, parallel bool) { for j := 0; j < count; j++ { idx := j key := fmt.Sprintf("svc_%d", j) - _ = needle.ProvideNamed( - c, key, func(ctx context.Context, r needle.Resolver) (*Config, error) { + _ = needle.Register(c, needle.Spec[*Config]{ + Name: key, + Provider: func(ctx context.Context, r needle.Resolver) (*Config, error) { return &Config{Port: idx}, nil }, - needle.WithOnStart( - func(ctx context.Context) error { - time.Sleep(time.Millisecond) - return nil - }, - ), - needle.WithOnStop( - func(ctx context.Context) error { - time.Sleep(time.Millisecond) - return nil - }, - ), - ) + OnStart: func(ctx context.Context) error { + time.Sleep(time.Millisecond) + return nil + }, + OnStop: func(ctx context.Context) error { + time.Sleep(time.Millisecond) + return nil + }, + }) } ctx := context.Background() diff --git a/benchmark/named_test.go b/benchmark/named_test.go index 7fff55d..4b19d48 100644 --- a/benchmark/named_test.go +++ b/benchmark/named_test.go @@ -19,11 +19,12 @@ func BenchmarkNamed_10_Needle(b *testing.B) { for j := 0; j < 10; j++ { idx := j key := fmt.Sprintf("svc_%d", j) - _ = needle.ProvideNamed( - c, key, func(ctx context.Context, r needle.Resolver) (*Config, error) { + _ = needle.Register(c, needle.Spec[*Config]{ + Name: key, + Provider: func(ctx context.Context, r needle.Resolver) (*Config, error) { return &Config{Port: idx}, nil }, - ) + }) } } } diff --git a/benchmark/provide_test.go b/benchmark/provide_test.go index 3efdaa0..48db94a 100644 --- a/benchmark/provide_test.go +++ b/benchmark/provide_test.go @@ -15,7 +15,7 @@ func BenchmarkProvide_Simple_Needle(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { c := needle.New() - _ = needle.ProvideValue(c, &Config{Host: "localhost", Port: 8080}) + _ = needle.Register(c, needle.SpecValue(&Config{Host: "localhost", Port: 8080})) } } @@ -57,35 +57,35 @@ func BenchmarkProvide_Chain_Needle(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { c := needle.New() - _ = needle.ProvideValue(c, &Config{Host: "localhost", Port: 8080}) - _ = needle.ProvideValue(c, &Logger{Level: "info"}) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Database, error) { + _ = needle.Register(c, needle.SpecValue(&Config{Host: "localhost", Port: 8080})) + _ = needle.Register(c, needle.SpecValue(&Logger{Level: "info"})) + _ = needle.Register(c, needle.Spec[*Database]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Database, error) { cfg := needle.MustInvoke[*Config](c) log := needle.MustInvoke[*Logger](c) return &Database{Config: cfg, Logger: log}, nil }, - ) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Cache, error) { + }) + _ = needle.Register(c, needle.Spec[*Cache]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Cache, error) { log := needle.MustInvoke[*Logger](c) return &Cache{Logger: log}, nil }, - ) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Repository, error) { + }) + _ = needle.Register(c, needle.Spec[*Repository]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Repository, error) { db := needle.MustInvoke[*Database](c) cache := needle.MustInvoke[*Cache](c) return &Repository{DB: db, Cache: cache}, nil }, - ) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Service, error) { + }) + _ = needle.Register(c, needle.Spec[*Service]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Service, error) { repo := needle.MustInvoke[*Repository](c) log := needle.MustInvoke[*Logger](c) return &Service{Repo: repo, Logger: log}, nil }, - ) + }) } } diff --git a/benchmark_test.go b/benchmark_test.go index 87337cd..cef2469 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -138,19 +138,18 @@ func benchmarkStartup(b *testing.B, parallel bool, count int, workDuration time. for j := 0; j < count; j++ { idx := j key := fmt.Sprintf("svc_%d", j) - _ = ProvideNamed( - c, key, func(ctx context.Context, r Resolver) (*benchService, error) { + _ = Register(c, Spec[*benchService]{ + Name: key, + Provider: func(ctx context.Context, r Resolver) (*benchService, error) { return &benchService{id: idx}, nil }, - WithOnStart( - func(ctx context.Context) error { - if workDuration > 0 { - time.Sleep(workDuration) - } - return nil - }, - ), - ) + OnStart: func(ctx context.Context) error { + if workDuration > 0 { + time.Sleep(workDuration) + } + return nil + }, + }) } ctx := context.Background() @@ -176,19 +175,18 @@ func benchmarkShutdown(b *testing.B, parallel bool, count int, workDuration time for j := 0; j < count; j++ { idx := j key := fmt.Sprintf("svc_%d", j) - _ = ProvideNamed( - c, key, func(ctx context.Context, r Resolver) (*benchService, error) { + _ = Register(c, Spec[*benchService]{ + Name: key, + Provider: func(ctx context.Context, r Resolver) (*benchService, error) { return &benchService{id: idx}, nil }, - WithOnStop( - func(ctx context.Context) error { - if workDuration > 0 { - time.Sleep(workDuration) - } - return nil - }, - ), - ) + OnStop: func(ctx context.Context) error { + if workDuration > 0 { + time.Sleep(workDuration) + } + return nil + }, + }) } ctx := context.Background() @@ -223,28 +221,25 @@ func benchmarkDependencyChain(b *testing.B, parallel bool, depth int, workDurati deps = append(deps, prevKey) } - _ = ProvideNamed( - c, key, func(ctx context.Context, r Resolver) (*chainService, error) { + _ = Register(c, Spec[*chainService]{ + Name: key, + Provider: func(ctx context.Context, r Resolver) (*chainService, error) { return &chainService{level: level}, nil }, - WithDependencies(deps...), - WithOnStart( - func(ctx context.Context) error { - if workDuration > 0 { - time.Sleep(workDuration) - } - return nil - }, - ), - WithOnStop( - func(ctx context.Context) error { - if workDuration > 0 { - time.Sleep(workDuration) - } - return nil - }, - ), - ) + Dependencies: deps, + OnStart: func(ctx context.Context) error { + if workDuration > 0 { + time.Sleep(workDuration) + } + return nil + }, + OnStop: func(ctx context.Context) error { + if workDuration > 0 { + time.Sleep(workDuration) + } + return nil + }, + }) prevKey = key } @@ -279,51 +274,45 @@ func benchmarkWideDependencies(b *testing.B, parallel bool, width int, workDurat key := fmt.Sprintf("wide_%d", j) depKeys[j] = key - _ = ProvideNamed( - c, key, func(ctx context.Context, r Resolver) (*wideService, error) { + _ = Register(c, Spec[*wideService]{ + Name: key, + Provider: func(ctx context.Context, r Resolver) (*wideService, error) { return &wideService{id: idx}, nil }, - WithOnStart( - func(ctx context.Context) error { - if workDuration > 0 { - time.Sleep(workDuration) - } - return nil - }, - ), - WithOnStop( - func(ctx context.Context) error { - if workDuration > 0 { - time.Sleep(workDuration) - } - return nil - }, - ), - ) - } - - _ = ProvideNamed( - c, "aggregator", func(ctx context.Context, r Resolver) (*aggregatorService, error) { - return &aggregatorService{}, nil - }, - WithDependencies(depKeys...), - WithOnStart( - func(ctx context.Context) error { + OnStart: func(ctx context.Context) error { if workDuration > 0 { time.Sleep(workDuration) } return nil }, - ), - WithOnStop( - func(ctx context.Context) error { + OnStop: func(ctx context.Context) error { if workDuration > 0 { time.Sleep(workDuration) } return nil }, - ), - ) + }) + } + + _ = Register(c, Spec[*aggregatorService]{ + Name: "aggregator", + Provider: func(ctx context.Context, r Resolver) (*aggregatorService, error) { + return &aggregatorService{}, nil + }, + Dependencies: depKeys, + OnStart: func(ctx context.Context) error { + if workDuration > 0 { + time.Sleep(workDuration) + } + return nil + }, + OnStop: func(ctx context.Context) error { + if workDuration > 0 { + time.Sleep(workDuration) + } + return nil + }, + }) ctx := context.Background() b.StartTimer() diff --git a/bind.go b/bind.go index 4129d31..a0db5d2 100644 --- a/bind.go +++ b/bind.go @@ -9,72 +9,23 @@ import ( type Decorator[T any] func(ctx context.Context, r Resolver, base T) (T, error) -func Bind[I, T any](c *Container, opts ...ProviderOption) error { - cfg := &providerConfig{} - for _, opt := range opts { - opt(cfg) - } - - interfaceKey := reflect.TypeKey[I]() - implKey := reflect.TypeKey[T]() - - if cfg.name != "" { - interfaceKey = reflect.TypeKeyNamed[I](cfg.name) - } - - wrappedProvider := func(ctx context.Context, r container.Resolver) (any, error) { - return r.Resolve(ctx, implKey) - } - - if err := c.internal.Register(interfaceKey, wrappedProvider, []string{implKey}); err != nil { - return err - } - - for _, hook := range cfg.onStart { - c.internal.AddOnStart(interfaceKey, hook) - } - for _, hook := range cfg.onStop { - c.internal.AddOnStop(interfaceKey, hook) - } - - return nil -} - -func BindNamed[I, T any](c *Container, name string, opts ...ProviderOption) error { - opts = append(opts, WithName(name)) - return Bind[I, T](c, opts...) -} - func Decorate[T any](c *Container, decorator Decorator[T]) { - key := reflect.TypeKey[T]() - - c.internal.AddDecorator( - key, func(ctx context.Context, r container.Resolver, instance any) (any, error) { - typed, ok := instance.(T) - if !ok { - var zero T - return zero, errDecoratorTypeMismatch(reflect.TypeName[T]()) - } - - resolver := &resolverAdapter{container: c} - return decorator(ctx, resolver, typed) - }, - ) + decorateKey(c, reflect.TypeKey[T](), decorator) } func DecorateNamed[T any](c *Container, name string, decorator Decorator[T]) { - key := reflect.TypeKeyNamed[T](name) + decorateKey(c, reflect.TypeKeyNamed[T](name), decorator) +} +func decorateKey[T any](c *Container, key string, decorator Decorator[T]) { c.internal.AddDecorator( - key, func(ctx context.Context, r container.Resolver, instance any) (any, error) { + key, func(ctx context.Context, _ container.Resolver, instance any) (any, error) { typed, ok := instance.(T) if !ok { var zero T return zero, errDecoratorTypeMismatch(reflect.TypeName[T]()) } - - resolver := &resolverAdapter{container: c} - return decorator(ctx, resolver, typed) + return decorator(ctx, c.resolver, typed) }, ) } diff --git a/concurrent_test.go b/concurrent_test.go index 9bbd0ba..0798fcf 100644 --- a/concurrent_test.go +++ b/concurrent_test.go @@ -14,7 +14,7 @@ func TestConcurrentSingletonResolve(t *testing.T) { t.Parallel() c := New() - _ = ProvideValue(c, &testCounter{id: 42}) + _ = Register(c, SpecValue(&testCounter{id: 42})) const n = 100 results := make([]*testCounter, n) @@ -41,7 +41,7 @@ func TestConcurrentSingletonResolve(t *testing.T) { } } -func TestConcurrentNamedProvideAndInvoke(t *testing.T) { +func TestConcurrentNamedRegisterAndInvoke(t *testing.T) { t.Parallel() c := New() @@ -52,7 +52,7 @@ func TestConcurrentNamedProvideAndInvoke(t *testing.T) { for i := range n { go func(idx int) { defer wg.Done() - _ = ProvideNamedValue(c, fmt.Sprintf("s%d", idx), &concService{id: idx}) + _ = Register(c, SpecValue(&concService{id: idx}).WithName(fmt.Sprintf("s%d", idx))) }(i) } wg.Wait() @@ -80,13 +80,16 @@ func TestConcurrentPoolAcquireRelease(t *testing.T) { c := New() var created atomic.Int32 - _ = Provide(c, func(_ context.Context, _ Resolver) (*testCounter, error) { - return &testCounter{id: int(created.Add(1))}, nil - }, WithPoolSize(3)) + _ = Register(c, Spec[*testCounter]{ + Provider: func(_ context.Context, _ Resolver) (*testCounter, error) { + return &testCounter{id: int(created.Add(1))}, nil + }, + Scope: Pooled, + PoolSize: 3, + }) key := reflect.TypeKey[*testCounter]() - // Pre-fill: create 3 instances, then release all to pool instances := make([]*testCounter, 3) for i := range 3 { inst, err := Invoke[*testCounter](c) @@ -99,7 +102,6 @@ func TestConcurrentPoolAcquireRelease(t *testing.T) { c.Release(key, inst) } - // Concurrent acquire-release cycles from the pre-filled pool const n = 20 var wg sync.WaitGroup wg.Add(n) @@ -128,9 +130,13 @@ func TestConcurrentTransientDifferentKeys(t *testing.T) { for i := range n { idx := i - _ = ProvideNamed(c, fmt.Sprintf("t%d", idx), func(_ context.Context, _ Resolver) (*concService, error) { - return &concService{id: idx}, nil - }, WithScope(Transient)) + _ = Register(c, Spec[*concService]{ + Name: fmt.Sprintf("t%d", idx), + Provider: func(_ context.Context, _ Resolver) (*concService, error) { + return &concService{id: idx}, nil + }, + Scope: Transient, + }) } var wg sync.WaitGroup @@ -157,9 +163,12 @@ func TestConcurrentRequestScopeIsolation(t *testing.T) { c := New() var created atomic.Int32 - _ = Provide(c, func(_ context.Context, _ Resolver) (*testCounter, error) { - return &testCounter{id: int(created.Add(1))}, nil - }, WithScope(Request)) + _ = Register(c, Spec[*testCounter]{ + Provider: func(_ context.Context, _ Resolver) (*testCounter, error) { + return &testCounter{id: int(created.Add(1))}, nil + }, + Scope: Request, + }) const numContexts = 10 const resolvesPerCtx = 5 @@ -193,7 +202,7 @@ func TestConcurrentReplaceNoRace(t *testing.T) { t.Parallel() c := New() - _ = ProvideValue(c, &testCounter{id: 0}) + _ = Register(c, SpecValue(&testCounter{id: 0})) const n = 50 var wg sync.WaitGroup @@ -202,10 +211,8 @@ func TestConcurrentReplaceNoRace(t *testing.T) { go func(idx int) { defer wg.Done() if idx%2 == 0 { - _ = ReplaceValue(c, &testCounter{id: idx}) + _ = Replace(c, SpecValue(&testCounter{id: idx})) } else { - // Invoke may fail due to concurrent replace, - // we're verifying no panics or data races. _, _ = Invoke[*testCounter](c) } }(i) diff --git a/debug_test.go b/debug_test.go index dd03e74..c79c3e4 100644 --- a/debug_test.go +++ b/debug_test.go @@ -27,13 +27,13 @@ func TestPrintGraph(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &Config{Port: 8080}) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Database, error) { + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080})) + _ = needle.Register(c, needle.Spec[*Database]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Database, error) { _ = needle.MustInvoke[*Config](c) return &Database{}, nil }, - ) + }) var buf bytes.Buffer c.FprintGraph(&buf) @@ -52,7 +52,7 @@ func TestPrintGraphWithInstantiated(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &Config{Port: 8080}) + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080})) _ = needle.MustInvoke[*Config](c) var buf bytes.Buffer @@ -60,7 +60,7 @@ func TestPrintGraphWithInstantiated(t *testing.T) { output := buf.String() if !strings.Contains(output, "●") { - t.Errorf("expected instantiated marker (●), got: %s", output) + t.Errorf("expected instantiated marker, got: %s", output) } } @@ -69,18 +69,18 @@ func TestPrintGraphNotInstantiated(t *testing.T) { c := needle.New() - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Config, error) { + _ = needle.Register(c, needle.Spec[*Config]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Config, error) { return &Config{Port: 8080}, nil }, - ) + }) var buf bytes.Buffer c.FprintGraph(&buf) output := buf.String() if !strings.Contains(output, "○") { - t.Errorf("expected not-instantiated marker (○), got: %s", output) + t.Errorf("expected not-instantiated marker, got: %s", output) } } @@ -88,7 +88,7 @@ func TestSprintGraph(t *testing.T) { t.Parallel() c := needle.New() - _ = needle.ProvideValue(c, &Config{Port: 8080}) + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080})) output := c.SprintGraph() if output == "" { @@ -101,12 +101,13 @@ func TestPrintGraphDOT(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &Config{Port: 8080}) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Database, error) { + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080})) + _ = needle.Register(c, needle.Spec[*Database]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Database, error) { return &Database{}, nil - }, needle.WithDependencies("*needle_test.Config"), - ) + }, + Dependencies: []string{"*needle_test.Config"}, + }) var buf bytes.Buffer c.FprintGraphDOT(&buf) @@ -127,7 +128,7 @@ func TestSprintGraphDOT(t *testing.T) { t.Parallel() c := needle.New() - _ = needle.ProvideValue(c, &Config{Port: 8080}) + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080})) output := c.SprintGraphDOT() if !strings.Contains(output, "digraph") { @@ -140,12 +141,13 @@ func TestGraphInfo(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &Config{Port: 8080}) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Database, error) { + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080})) + _ = needle.Register(c, needle.Spec[*Database]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Database, error) { return &Database{}, nil - }, needle.WithDependencies("*needle_test.Config"), - ) + }, + Dependencies: []string{"*needle_test.Config"}, + }) info := c.Graph() diff --git a/doc.go b/doc.go index b7db138..9c3914f 100644 --- a/doc.go +++ b/doc.go @@ -1,64 +1,77 @@ // Package needle provides a type-safe dependency injection framework for Go 1.25+. // -// Needle is designed to be simple yet powerful, offering compile-time type safety -// through generics, lifecycle management, scoped dependencies, and modular organization. +// Needle has one registration entry point. Every service registers via Register[T] +// passing a Spec[T] that captures provider, dependencies, scope, hooks, pool size, +// and lazy flag. Constructor helpers (SpecValue, SpecFromConstructor, +// SpecFromStruct, SpecFromBinding) cover common patterns. Decorators are attached +// separately via Decorate. // // # Quick Start // -// Create a container and register providers: +// Create a container and register services: // // c := needle.New() // -// needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*Config, error) { -// return &Config{Port: 8080}, nil -// }) +// needle.Register(c, needle.SpecValue(&Config{Port: 8080})) // -// needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*Server, error) { -// cfg := needle.MustInvoke[*Config](c) -// return &Server{config: cfg}, nil +// needle.Register(c, needle.Spec[*Server]{ +// Provider: func(ctx context.Context, r needle.Resolver) (*Server, error) { +// cfg := needle.MustInvoke[*Config](c) +// return &Server{config: cfg}, nil +// }, // }) // // c.Run(ctx) // -// # Providers +// # The Spec // -// Providers are functions that create instances of a type. They receive a context -// and a Resolver for accessing other dependencies: +// Spec[T] is the single configuration object. Zero values mean "singleton, eager, +// no hooks." Set fields directly or chain helpers: // -// needle.Provide[T](c, provider) // Register a provider -// needle.ProvideValue[T](c, value) // Register an existing value -// needle.ProvideNamed[T](c, "name", prov) // Register a named provider +// type Spec[T any] struct { +// Name string +// Provider Provider[T] +// Dependencies []string +// Scope Scope +// OnStart Hook +// OnStop Hook +// PoolSize int +// Lazy bool +// } // -// # Auto-Wiring +// Register errors if the key already exists. Replace overwrites or registers if not +// present. MustRegister and MustReplace panic on error. // -// Reduce boilerplate with constructor auto-wiring and struct tag injection. +// # Constructor Helpers // -// Constructor auto-wiring automatically resolves function parameters: +// SpecValue binds a pre-built value: // -// func NewUserService(db *Database, log *Logger) *UserService { -// return &UserService{db: db, log: log} -// } -// needle.ProvideFunc[*UserService](c, NewUserService) +// needle.Register(c, needle.SpecValue(&Config{Port: 8080})) +// +// SpecFromConstructor auto-wires from a constructor's parameters: +// +// func NewUserService(db *Database, log *Logger) *UserService { ... } +// needle.Register(c, needle.SpecFromConstructor[*UserService](NewUserService)) // -// Struct tag injection uses the `needle` tag to inject fields: +// SpecFromStruct populates a struct via `needle:"..."` tags: // // type UserService struct { -// DB *Database `needle:""` // inject by type -// Log *Logger `needle:"appLogger"` // inject by name -// Cache *Cache `needle:",optional"` // optional dependency +// DB *Database `needle:""` // inject by type +// Log *Logger `needle:"appLogger"` // inject by name +// Cache *Cache `needle:",optional"` // optional dependency // } -// needle.ProvideStruct[*UserService](c) +// needle.Register(c, needle.SpecFromStruct[*UserService]()) // -// Or invoke directly without registering: +// SpecFromBinding wires an interface to a registered implementation: // -// svc, err := needle.InvokeStruct[*UserService](c) +// needle.Register(c, needle.SpecFromBinding[UserRepository, *PostgresUserRepo]()) // // # Resolution // // Resolve dependencies using the Invoke functions: // -// svc, err := needle.Invoke[*Service](c) // Returns value and error -// svc := needle.MustInvoke[*Service](c) // Panics on error +// svc, err := needle.Invoke[*Service](c) // returns value and error +// svc := needle.MustInvoke[*Service](c) // panics on error // // # Optional Dependencies // @@ -72,41 +85,35 @@ // cache := opt.Value() // } // -// // Or use OrElse for default values -// opt, _ := needle.InvokeOptional[*Cache](c) // cache := opt.OrElse(defaultCache) -// -// // OrElseFunc for lazy defaults -// opt, _ := needle.InvokeOptional[*Cache](c) -// cache := opt.OrElseFunc(func() *Cache { -// return NewDefaultCache() -// }) +// cache := opt.OrElseFunc(func() *Cache { return NewDefaultCache() }) // // # Lifecycle // -// Services can participate in the container's lifecycle: +// Specs can carry OnStart and OnStop hooks: // -// needle.Provide(c, NewServer, -// needle.WithOnStart(func(ctx context.Context) error { -// return server.Listen() -// }), -// needle.WithOnStop(func(ctx context.Context) error { -// return server.Shutdown(ctx) -// }), -// ) +// needle.Register(c, needle.Spec[*Server]{ +// Provider: NewServer, +// OnStart: func(ctx context.Context) error { return server.Listen() }, +// OnStop: func(ctx context.Context) error { return server.Shutdown(ctx) }, +// }) // -// c.Start(ctx) // Starts all services in dependency order -// c.Stop(ctx) // Stops all services in reverse order +// c.Start(ctx) // starts all services in dependency order +// c.Stop(ctx) // stops all services in reverse order // c.Run(ctx) // Start + wait for signal + Stop // -// # Lazy Providers +// Multiple hooks compose via Compose, which runs them in order and stops on first error: +// +// OnStart: needle.Compose(installRoutes, openListener) +// +// # Lazy Specs // // Defer instantiation until first use: // -// needle.Provide(c, NewExpensiveService, needle.WithLazy()) +// needle.Register(c, needle.SpecFromConstructor[*Expensive](NewExpensive).WithLazy()) // -// Lazy services are not instantiated during Start(). They are created on first -// Invoke(), and their OnStart hooks run at that time if the container is running. +// Lazy services are not instantiated during Start. They are created on first +// Invoke, and their OnStart hook runs at that time if the container is running. // // # Parallel Startup // @@ -123,7 +130,7 @@ // // c := needle.New(needle.WithShutdownTimeout(30 * time.Second)) // -// The timeout applies to Stop() and is checked between service shutdowns. +// The timeout applies to Stop and is checked between service shutdowns. // Individual OnStop hooks receive the timeout context. // // # Debug Visualization @@ -133,18 +140,18 @@ // c.PrintGraph() // ASCII to stdout // c.PrintGraphDOT() // Graphviz DOT to stdout // output := c.SprintGraph() -// info := c.Graph() // Structured GraphInfo +// info := c.Graph() // structured GraphInfo // // # Modules // -// Group related providers into modules: +// Group related specs into modules: // // var ConfigModule = needle.NewModule("config") -// needle.ModuleProvideValue(ConfigModule, &Config{Port: 8080}) +// needle.ModuleRegister(ConfigModule, needle.SpecValue(&Config{Port: 8080})) // // var HTTPModule = needle.NewModule("http") -// needle.ModuleProvide(HTTPModule, NewServer) -// needle.ModuleProvide(HTTPModule, NewRouter) +// needle.ModuleRegister(HTTPModule, needle.SpecFromConstructor[*Server](NewServer)) +// needle.ModuleRegister(HTTPModule, needle.SpecFromConstructor[*Router](NewRouter)) // // c.Apply(ConfigModule, HTTPModule) // @@ -154,17 +161,6 @@ // Include(ConfigModule). // Include(HTTPModule) // -// # Interface Binding -// -// Bind interfaces to concrete implementations: -// -// needle.Bind[UserRepository, *PostgresUserRepo](c) -// needle.BindNamed[Cache, *RedisCache](c, "session") -// -// Or within modules: -// -// needle.ModuleBind[UserRepository, *PostgresUserRepo](module) -// // # Decorators // // Wrap services with cross-cutting concerns: @@ -180,11 +176,11 @@ // // # Scopes // -// Control instance lifetime with scopes: +// Control instance lifetime with Scope on the spec: // -// needle.Provide(c, NewService, needle.WithScope(needle.Transient)) -// needle.Provide(c, NewService, needle.WithScope(needle.Request)) -// needle.Provide(c, NewService, needle.WithPoolSize(10)) +// needle.Register(c, needle.SpecFromConstructor[*Handler](NewHandler).WithScope(needle.Transient)) +// needle.Register(c, needle.SpecFromConstructor[*ReqLog](NewReqLog).WithScope(needle.Request)) +// needle.Register(c, needle.SpecFromConstructor[*Worker](NewWorker).WithPoolSize(10)) // // Available scopes: Singleton (default), Transient, Request, Pooled. // @@ -198,25 +194,20 @@ // // Check health status: // -// err := c.Live(ctx) // Fails if any HealthChecker returns error -// err := c.Ready(ctx) // Fails if any ReadinessChecker returns error -// reports := c.Health(ctx) // Get detailed health reports with latency +// err := c.Live(ctx) // fails if any HealthChecker returns error +// err := c.Ready(ctx) // fails if any ReadinessChecker returns error +// reports := c.Health(ctx) // detailed health reports with latency // // # Hot Reload / Dynamic Replacement // // Replace services at runtime without restarting the container: // -// needle.ReplaceValue(c, &Config{NewValue: "updated"}) -// needle.Replace(c, newProvider) -// needle.ReplaceFunc[*Service](c, NewServiceConstructor) -// needle.ReplaceStruct[*Service](c) -// -// Named variants are also available: -// -// needle.ReplaceNamedValue(c, "primary", &Config{}) -// needle.ReplaceNamed(c, "primary", provider) +// needle.Replace(c, needle.SpecValue(&Config{NewValue: "updated"})) +// needle.Replace(c, needle.SpecFromConstructor[*Service](NewService)) +// needle.Replace(c, needle.SpecFromStruct[*Service]()) +// needle.Replace(c, needle.SpecValue(&Config{}).WithName("primary")) // -// This is useful for feature flags, A/B testing, or configuration updates. +// Useful for feature flags, A/B testing, or configuration updates. // // # Metrics Observers // diff --git a/examples/autowire/main.go b/examples/autowire/main.go index b058f04..ce0036e 100644 --- a/examples/autowire/main.go +++ b/examples/autowire/main.go @@ -52,19 +52,17 @@ type UserService struct { func main() { c := needle.New() - _ = needle.ProvideValue( - c, &Config{ - DatabaseURL: "postgres://localhost/mydb", - CacheSize: 1000, - }, - ) + _ = needle.Register(c, needle.SpecValue(&Config{ + DatabaseURL: "postgres://localhost/mydb", + CacheSize: 1000, + })) - _ = needle.ProvideFunc[*Logger](c, NewLogger) - _ = needle.ProvideFunc[*Database](c, NewDatabase) - _ = needle.ProvideFunc[*Cache](c, NewCache) + _ = needle.Register(c, needle.SpecFromConstructor[*Logger](NewLogger)) + _ = needle.Register(c, needle.SpecFromConstructor[*Database](NewDatabase)) + _ = needle.Register(c, needle.SpecFromConstructor[*Cache](NewCache)) - _ = needle.ProvideStruct[*UserRepository](c) - _ = needle.ProvideStruct[*UserService](c) + _ = needle.Register(c, needle.SpecFromStruct[*UserRepository]()) + _ = needle.Register(c, needle.SpecFromStruct[*UserService]()) if err := c.Validate(); err != nil { panic(err) @@ -77,19 +75,9 @@ func main() { fmt.Printf(" Repo DB URL: %s\n", svc.Repo.DB.URL) fmt.Printf(" Repo Cache size: %d\n", svc.Repo.Cache.Size) - fmt.Println("\n--- Comparison ---") - fmt.Println("Traditional (verbose):") - fmt.Println( - ` needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*UserService, error) { - repo := needle.MustInvoke[*UserRepository](c) - logger := needle.MustInvoke[*Logger](c) - return &UserService{Repo: repo, Logger: logger}, nil - })`, - ) - - fmt.Println("\nWith ProvideFunc (constructor auto-wiring):") - fmt.Println(` needle.ProvideFunc[*Database](c, NewDatabase)`) - - fmt.Println("\nWith ProvideStruct (struct tag injection):") - fmt.Println(` needle.ProvideStruct[*UserService](c)`) + fmt.Println("\nWith SpecFromConstructor (auto-wired params):") + fmt.Println(` needle.Register(c, needle.SpecFromConstructor[*Database](NewDatabase))`) + + fmt.Println("\nWith SpecFromStruct (struct-tag injection):") + fmt.Println(` needle.Register(c, needle.SpecFromStruct[*UserService]())`) } diff --git a/examples/basic/main.go b/examples/basic/main.go index b4a5339..a4473d1 100644 --- a/examples/basic/main.go +++ b/examples/basic/main.go @@ -47,33 +47,31 @@ func (s *UserService) GetUser(id int) string { func main() { c := needle.New() - _ = needle.ProvideValue( - c, &Config{ - DatabaseURL: "postgres://localhost/mydb", - Port: 8080, - }, - ) + _ = needle.Register(c, needle.SpecValue(&Config{ + DatabaseURL: "postgres://localhost/mydb", + Port: 8080, + })) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Database, error) { + _ = needle.Register(c, needle.Spec[*Database]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Database, error) { cfg := needle.MustInvoke[*Config](c) return NewDatabase(cfg), nil }, - ) + }) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*UserRepository, error) { + _ = needle.Register(c, needle.Spec[*UserRepository]{ + Provider: func(ctx context.Context, r needle.Resolver) (*UserRepository, error) { db := needle.MustInvoke[*Database](c) return NewUserRepository(db), nil }, - ) + }) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*UserService, error) { + _ = needle.Register(c, needle.Spec[*UserService]{ + Provider: func(ctx context.Context, r needle.Resolver) (*UserService, error) { repo := needle.MustInvoke[*UserRepository](c) return NewUserService(repo), nil }, - ) + }) if err := c.Validate(); err != nil { panic(err) diff --git a/examples/decorators/main.go b/examples/decorators/main.go index 5ecfaeb..a96c1a6 100644 --- a/examples/decorators/main.go +++ b/examples/decorators/main.go @@ -72,15 +72,15 @@ func main() { logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) c := needle.New() - _ = needle.ProvideValue(c, logger) + _ = needle.Register(c, needle.SpecValue(logger)) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*PostgresUserRepository, error) { + _ = needle.Register(c, needle.Spec[*PostgresUserRepository]{ + Provider: func(_ context.Context, _ needle.Resolver) (*PostgresUserRepository, error) { return &PostgresUserRepository{}, nil }, - ) + }) - _ = needle.Bind[UserRepository, *PostgresUserRepository](c) + _ = needle.Register(c, needle.SpecFromBinding[UserRepository, *PostgresUserRepository]()) needle.Decorate( c, func(_ context.Context, _ needle.Resolver, repo UserRepository) (UserRepository, error) { diff --git a/examples/healthchecks/main.go b/examples/healthchecks/main.go index 6cbbb1f..ced5668 100644 --- a/examples/healthchecks/main.go +++ b/examples/healthchecks/main.go @@ -80,23 +80,23 @@ func (m *MessageQueue) SetReady(ready bool) { func main() { c := needle.New() - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*Database, error) { + _ = needle.Register(c, needle.Spec[*Database]{ + Provider: func(_ context.Context, _ needle.Resolver) (*Database, error) { return NewDatabase(), nil }, - ) + }) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*Cache, error) { + _ = needle.Register(c, needle.Spec[*Cache]{ + Provider: func(_ context.Context, _ needle.Resolver) (*Cache, error) { return NewCache(), nil }, - ) + }) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*MessageQueue, error) { + _ = needle.Register(c, needle.Spec[*MessageQueue]{ + Provider: func(_ context.Context, _ needle.Resolver) (*MessageQueue, error) { return NewMessageQueue(), nil }, - ) + }) ctx := context.Background() _ = c.Start(ctx) diff --git a/examples/httpserver/main.go b/examples/httpserver/main.go index a3a77e0..ee89abc 100644 --- a/examples/httpserver/main.go +++ b/examples/httpserver/main.go @@ -79,43 +79,37 @@ func main() { c := needle.New(needle.WithLogger(logger)) - _ = needle.ProvideValue( - c, &Config{ - Port: 8080, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - }, - ) + _ = needle.Register(c, needle.SpecValue(&Config{ + Port: 8080, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + })) - _ = needle.ProvideValue(c, logger) + _ = needle.Register(c, needle.SpecValue(logger)) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (http.Handler, error) { + _ = needle.Register(c, needle.Spec[http.Handler]{ + Provider: func(ctx context.Context, r needle.Resolver) (http.Handler, error) { log := needle.MustInvoke[*slog.Logger](c) return NewHandler(log), nil }, - ) + }) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Server, error) { + _ = needle.Register(c, needle.Spec[*Server]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Server, error) { cfg := needle.MustInvoke[*Config](c) handler := needle.MustInvoke[http.Handler](c) log := needle.MustInvoke[*slog.Logger](c) return NewServer(cfg, handler, log), nil }, - needle.WithOnStart( - func(ctx context.Context) error { - srv := needle.MustInvoke[*Server](c) - return srv.Start(ctx) - }, - ), - needle.WithOnStop( - func(ctx context.Context) error { - srv := needle.MustInvoke[*Server](c) - return srv.Stop(ctx) - }, - ), - ) + OnStart: func(ctx context.Context) error { + srv := needle.MustInvoke[*Server](c) + return srv.Start(ctx) + }, + OnStop: func(ctx context.Context) error { + srv := needle.MustInvoke[*Server](c) + return srv.Stop(ctx) + }, + }) logger.Info("starting application") if err := c.Run(context.Background()); err != nil { diff --git a/examples/lazy/main.go b/examples/lazy/main.go index f663562..3b47327 100644 --- a/examples/lazy/main.go +++ b/examples/lazy/main.go @@ -39,42 +39,34 @@ func (s *EagerService) DoWork() { func main() { c := needle.New() - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*EagerService, error) { + _ = needle.Register(c, needle.Spec[*EagerService]{ + Provider: func(_ context.Context, _ needle.Resolver) (*EagerService, error) { return NewEagerService("EagerService"), nil }, - needle.WithOnStart( - func(_ context.Context) error { - fmt.Println("[lifecycle] EagerService OnStart hook running") - return nil - }, - ), - needle.WithOnStop( - func(_ context.Context) error { - fmt.Println("[lifecycle] EagerService OnStop hook running") - return nil - }, - ), - ) - - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*ExpensiveService, error) { + OnStart: func(_ context.Context) error { + fmt.Println("[lifecycle] EagerService OnStart hook running") + return nil + }, + OnStop: func(_ context.Context) error { + fmt.Println("[lifecycle] EagerService OnStop hook running") + return nil + }, + }) + + _ = needle.Register(c, needle.Spec[*ExpensiveService]{ + Provider: func(_ context.Context, _ needle.Resolver) (*ExpensiveService, error) { return NewExpensiveService("LazyExpensiveService"), nil }, - needle.WithLazy(), - needle.WithOnStart( - func(_ context.Context) error { - fmt.Println("[lifecycle] LazyExpensiveService OnStart hook running") - return nil - }, - ), - needle.WithOnStop( - func(_ context.Context) error { - fmt.Println("[lifecycle] LazyExpensiveService OnStop hook running") - return nil - }, - ), - ) + Lazy: true, + OnStart: func(_ context.Context) error { + fmt.Println("[lifecycle] LazyExpensiveService OnStart hook running") + return nil + }, + OnStop: func(_ context.Context) error { + fmt.Println("[lifecycle] LazyExpensiveService OnStop hook running") + return nil + }, + }) fmt.Println("=== Starting container ===") ctx := context.Background() diff --git a/examples/modules/main.go b/examples/modules/main.go index 2803b12..f4193b9 100644 --- a/examples/modules/main.go +++ b/examples/modules/main.go @@ -70,8 +70,8 @@ var RepositoryModule = needle.NewModule("repository") var ServiceModule = needle.NewModule("service") func init() { - needle.ModuleProvide( - DatabaseModule, func(ctx context.Context, r needle.Resolver) (*Database, error) { + needle.ModuleRegister(DatabaseModule, needle.Spec[*Database]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Database, error) { cfg, _ := r.Resolve(ctx, "*main.Config") logger, _ := r.Resolve(ctx, "*log/slog.Logger") return &Database{ @@ -79,17 +79,17 @@ func init() { logger: logger.(*slog.Logger), }, nil }, - ) + }) - needle.ModuleProvide( - CacheModule, func(ctx context.Context, r needle.Resolver) (*Cache, error) { + needle.ModuleRegister(CacheModule, needle.Spec[*Cache]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Cache, error) { cfg, _ := r.Resolve(ctx, "*main.Config") return &Cache{url: cfg.(*Config).CacheURL}, nil }, - ) + }) - needle.ModuleProvide( - RepositoryModule, func(ctx context.Context, r needle.Resolver) (*PostgresUserRepository, error) { + needle.ModuleRegister(RepositoryModule, needle.Spec[*PostgresUserRepository]{ + Provider: func(ctx context.Context, r needle.Resolver) (*PostgresUserRepository, error) { db, _ := r.Resolve(ctx, "*main.Database") cache, _ := r.Resolve(ctx, "*main.Cache") return &PostgresUserRepository{ @@ -97,12 +97,12 @@ func init() { cache: cache.(*Cache), }, nil }, - ) + }) - needle.ModuleBind[UserRepository, *PostgresUserRepository](RepositoryModule) + needle.ModuleRegister(RepositoryModule, needle.SpecFromBinding[UserRepository, *PostgresUserRepository]()) - needle.ModuleProvide( - ServiceModule, func(ctx context.Context, r needle.Resolver) (*UserService, error) { + needle.ModuleRegister(ServiceModule, needle.Spec[*UserService]{ + Provider: func(ctx context.Context, r needle.Resolver) (*UserService, error) { repo, _ := r.Resolve(ctx, "main.UserRepository") logger, _ := r.Resolve(ctx, "*log/slog.Logger") return &UserService{ @@ -110,7 +110,7 @@ func init() { logger: logger.(*slog.Logger), }, nil }, - ) + }) } var AppModule = needle.NewModule("app"). @@ -124,13 +124,11 @@ func main() { logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) c := needle.New(needle.WithLogger(logger)) - _ = needle.ProvideValue(c, logger) - needle.ModuleProvideValue( - ConfigModule, &Config{ - DatabaseURL: "postgres://localhost/mydb", - CacheURL: "redis://localhost:6379", - }, - ) + _ = needle.Register(c, needle.SpecValue(logger)) + needle.ModuleRegister(ConfigModule, needle.SpecValue(&Config{ + DatabaseURL: "postgres://localhost/mydb", + CacheURL: "redis://localhost:6379", + })) if err := c.Apply(AppModule); err != nil { logger.Error("failed to apply modules", "error", err) diff --git a/examples/optional/main.go b/examples/optional/main.go index 545afa3..2a9c6f1 100644 --- a/examples/optional/main.go +++ b/examples/optional/main.go @@ -103,22 +103,22 @@ func main() { func runWithAllDeps() { c := needle.New() - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*RedisCache, error) { + _ = needle.Register(c, needle.Spec[*RedisCache]{ + Provider: func(_ context.Context, _ needle.Resolver) (*RedisCache, error) { return NewRedisCache(), nil }, - ) - _ = needle.Bind[Cache, *RedisCache](c) + }) + _ = needle.Register(c, needle.SpecFromBinding[Cache, *RedisCache]()) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*PrometheusMetrics, error) { + _ = needle.Register(c, needle.Spec[*PrometheusMetrics]{ + Provider: func(_ context.Context, _ needle.Resolver) (*PrometheusMetrics, error) { return &PrometheusMetrics{}, nil }, - ) - _ = needle.Bind[Metrics, *PrometheusMetrics](c) + }) + _ = needle.Register(c, needle.SpecFromBinding[Metrics, *PrometheusMetrics]()) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*UserService, error) { + _ = needle.Register(c, needle.Spec[*UserService]{ + Provider: func(_ context.Context, _ needle.Resolver) (*UserService, error) { cache := mustOptional(needle.InvokeOptional[Cache](c)).OrElseFunc( func() Cache { return NewInMemoryCache() @@ -127,7 +127,7 @@ func runWithAllDeps() { metrics := mustOptional(needle.InvokeOptional[Metrics](c)).OrElse(&NoOpMetrics{}) return &UserService{cache: cache, metrics: metrics}, nil }, - ) + }) svc := needle.MustInvoke[*UserService](c) fmt.Println(svc.GetUser(42)) @@ -137,8 +137,8 @@ func runWithAllDeps() { func runWithoutOptionalDeps() { c := needle.New() - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*UserService, error) { + _ = needle.Register(c, needle.Spec[*UserService]{ + Provider: func(_ context.Context, _ needle.Resolver) (*UserService, error) { cache := mustOptional(needle.InvokeOptional[Cache](c)).OrElseFunc( func() Cache { return NewInMemoryCache() @@ -147,7 +147,7 @@ func runWithoutOptionalDeps() { metrics := mustOptional(needle.InvokeOptional[Metrics](c)).OrElse(&NoOpMetrics{}) return &UserService{cache: cache, metrics: metrics}, nil }, - ) + }) svc := needle.MustInvoke[*UserService](c) fmt.Println(svc.GetUser(42)) @@ -157,8 +157,8 @@ func runWithoutOptionalDeps() { func demonstrateOptionalAPI() { c := needle.New() - _ = needle.ProvideValue(c, &RedisCache{data: make(map[string]string)}) - _ = needle.Bind[Cache, *RedisCache](c) + _ = needle.Register(c, needle.SpecValue(&RedisCache{data: make(map[string]string)})) + _ = needle.Register(c, needle.SpecFromBinding[Cache, *RedisCache]()) fmt.Println("--- Present() and Value() ---") opt := mustOptional(needle.InvokeOptional[Cache](c)) diff --git a/examples/parallel/main.go b/examples/parallel/main.go index 3d4d509..f7ae088 100644 --- a/examples/parallel/main.go +++ b/examples/parallel/main.go @@ -79,46 +79,46 @@ func runParallel() { } func registerProviders(c *needle.Container) { - _ = needle.ProvideValue(c, &Config{Value: "config"}) + _ = needle.Register(c, needle.SpecValue(&Config{Value: "config"})) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*DatabaseA, error) { + _ = needle.Register(c, needle.Spec[*DatabaseA]{ + Provider: func(_ context.Context, _ needle.Resolver) (*DatabaseA, error) { fmt.Printf("[%s] Starting DatabaseA...\n", timestamp()) time.Sleep(100 * time.Millisecond) fmt.Printf("[%s] DatabaseA ready\n", timestamp()) return &DatabaseA{name: "db-a"}, nil }, - ) + }) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*DatabaseB, error) { + _ = needle.Register(c, needle.Spec[*DatabaseB]{ + Provider: func(_ context.Context, _ needle.Resolver) (*DatabaseB, error) { fmt.Printf("[%s] Starting DatabaseB...\n", timestamp()) time.Sleep(100 * time.Millisecond) fmt.Printf("[%s] DatabaseB ready\n", timestamp()) return &DatabaseB{name: "db-b"}, nil }, - ) + }) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*CacheA, error) { + _ = needle.Register(c, needle.Spec[*CacheA]{ + Provider: func(_ context.Context, _ needle.Resolver) (*CacheA, error) { fmt.Printf("[%s] Starting CacheA...\n", timestamp()) time.Sleep(100 * time.Millisecond) fmt.Printf("[%s] CacheA ready\n", timestamp()) return &CacheA{name: "cache-a"}, nil }, - ) + }) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*CacheB, error) { + _ = needle.Register(c, needle.Spec[*CacheB]{ + Provider: func(_ context.Context, _ needle.Resolver) (*CacheB, error) { fmt.Printf("[%s] Starting CacheB...\n", timestamp()) time.Sleep(100 * time.Millisecond) fmt.Printf("[%s] CacheB ready\n", timestamp()) return &CacheB{name: "cache-b"}, nil }, - ) + }) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*ServiceA, error) { + _ = needle.Register(c, needle.Spec[*ServiceA]{ + Provider: func(_ context.Context, _ needle.Resolver) (*ServiceA, error) { db := needle.MustInvoke[*DatabaseA](c) cache := needle.MustInvoke[*CacheA](c) fmt.Printf("[%s] Starting ServiceA...\n", timestamp()) @@ -126,11 +126,11 @@ func registerProviders(c *needle.Container) { fmt.Printf("[%s] ServiceA ready\n", timestamp()) return &ServiceA{db: db, cache: cache}, nil }, - needle.WithDependencies("*main.DatabaseA", "*main.CacheA"), - ) + Dependencies: []string{"*main.DatabaseA", "*main.CacheA"}, + }) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*ServiceB, error) { + _ = needle.Register(c, needle.Spec[*ServiceB]{ + Provider: func(_ context.Context, _ needle.Resolver) (*ServiceB, error) { db := needle.MustInvoke[*DatabaseB](c) cache := needle.MustInvoke[*CacheB](c) fmt.Printf("[%s] Starting ServiceB...\n", timestamp()) @@ -138,11 +138,11 @@ func registerProviders(c *needle.Container) { fmt.Printf("[%s] ServiceB ready\n", timestamp()) return &ServiceB{db: db, cache: cache}, nil }, - needle.WithDependencies("*main.DatabaseB", "*main.CacheB"), - ) + Dependencies: []string{"*main.DatabaseB", "*main.CacheB"}, + }) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*Gateway, error) { + _ = needle.Register(c, needle.Spec[*Gateway]{ + Provider: func(_ context.Context, _ needle.Resolver) (*Gateway, error) { svcA := needle.MustInvoke[*ServiceA](c) svcB := needle.MustInvoke[*ServiceB](c) fmt.Printf("[%s] Starting Gateway...\n", timestamp()) @@ -150,8 +150,8 @@ func registerProviders(c *needle.Container) { fmt.Printf("[%s] Gateway ready\n", timestamp()) return &Gateway{svcA: svcA, svcB: svcB}, nil }, - needle.WithDependencies("*main.ServiceA", "*main.ServiceB"), - ) + Dependencies: []string{"*main.ServiceA", "*main.ServiceB"}, + }) } var startTime = time.Now() diff --git a/examples/scopes/main.go b/examples/scopes/main.go index 6ff8a5b..52c47ef 100644 --- a/examples/scopes/main.go +++ b/examples/scopes/main.go @@ -51,32 +51,33 @@ func NewPooledConnection() *PooledConnection { func main() { c := needle.New() - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*SingletonCounter, error) { + _ = needle.Register(c, needle.Spec[*SingletonCounter]{ + Provider: func(_ context.Context, _ needle.Resolver) (*SingletonCounter, error) { return NewSingletonCounter(), nil }, - ) + }) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*TransientCounter, error) { + _ = needle.Register(c, needle.Spec[*TransientCounter]{ + Provider: func(_ context.Context, _ needle.Resolver) (*TransientCounter, error) { return NewTransientCounter(), nil }, - needle.WithScope(needle.Transient), - ) + Scope: needle.Transient, + }) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*RequestCounter, error) { + _ = needle.Register(c, needle.Spec[*RequestCounter]{ + Provider: func(_ context.Context, _ needle.Resolver) (*RequestCounter, error) { return NewRequestCounter(), nil }, - needle.WithScope(needle.Request), - ) + Scope: needle.Request, + }) - _ = needle.Provide( - c, func(_ context.Context, _ needle.Resolver) (*PooledConnection, error) { + _ = needle.Register(c, needle.Spec[*PooledConnection]{ + Provider: func(_ context.Context, _ needle.Resolver) (*PooledConnection, error) { return NewPooledConnection(), nil }, - needle.WithPoolSize(3), - ) + Scope: needle.Pooled, + PoolSize: 3, + }) fmt.Println("=== Singleton Scope ===") fmt.Println("Same instance returned every time:") diff --git a/internal/container/container.go b/internal/container/container.go index 9fab32b..50e7979 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -8,7 +8,6 @@ import ( "time" "github.com/danpasecinic/needle/internal/graph" - "github.com/danpasecinic/needle/internal/scope" ) type State int @@ -74,63 +73,38 @@ func New(cfg *Config) *Container { } } -func (c *Container) Register(key string, provider ProviderFunc, dependencies []string) error { - if err := c.registerLocked(key, provider, dependencies); err != nil { +func (c *Container) Register(entry *ServiceEntry) error { + if err := c.registerLocked(entry); err != nil { return err } for _, hook := range c.onProvide { - hook(key) + hook(entry.Key) } return nil } -func (c *Container) registerLocked(key string, provider ProviderFunc, dependencies []string) error { +func (c *Container) registerLocked(entry *ServiceEntry) error { c.mu.Lock() defer c.mu.Unlock() - if c.registry.Has(key) { - return fmt.Errorf("service already registered: %s", key) + if c.registry.Has(entry.Key) { + return fmt.Errorf("service already registered: %s", entry.Key) } - _ = c.registry.Register(key, provider, dependencies) - c.graph.AddNode(key, dependencies) + c.registry.Add(entry) + c.graph.AddNode(entry.Key, entry.Dependencies) - if len(dependencies) > 0 && c.graph.HasCycle() { - c.registry.Remove(key) - c.graph.RemoveNode(key) - return fmt.Errorf("circular dependency detected for: %s", key) + if len(entry.Dependencies) > 0 && c.graph.HasCycle() { + c.registry.Remove(entry.Key) + c.graph.RemoveNode(entry.Key) + return fmt.Errorf("circular dependency detected for: %s", entry.Key) } return nil } -func (c *Container) RegisterValue(key string, value any) error { - if err := c.registerValueLocked(key, value); err != nil { - return err - } - - for _, hook := range c.onProvide { - hook(key) - } - - return nil -} - -func (c *Container) registerValueLocked(key string, value any) error { - c.mu.Lock() - defer c.mu.Unlock() - - if c.registry.Has(key) { - return fmt.Errorf("service already registered: %s", key) - } - - _ = c.registry.RegisterValue(key, value) - c.graph.AddNode(key, nil) - return nil -} - func (c *Container) Has(key string) bool { c.mu.RLock() defer c.mu.RUnlock() @@ -196,23 +170,3 @@ func (c *Container) Release(key string, instance any) bool { } return released } - -func (c *Container) AddOnStart(key string, hook Hook) { - c.registry.AddOnStart(key, hook) -} - -func (c *Container) AddOnStop(key string, hook Hook) { - c.registry.AddOnStop(key, hook) -} - -func (c *Container) SetScope(key string, s scope.Scope) { - c.registry.SetScope(key, s) -} - -func (c *Container) SetPoolSize(key string, size int) { - c.registry.SetPoolSize(key, size) -} - -func (c *Container) SetLazy(key string, lazy bool) { - c.registry.SetLazy(key, lazy) -} diff --git a/internal/container/container_test.go b/internal/container/container_test.go index 06d2de5..65933d6 100644 --- a/internal/container/container_test.go +++ b/internal/container/container_test.go @@ -8,16 +8,30 @@ import ( "testing" ) +func registerProvider(c *Container, key string, fn ProviderFunc, deps []string) error { + return c.Register(NewServiceEntry(EntryConfig{ + Key: key, + Provider: fn, + Dependencies: deps, + })) +} + +func registerValue(c *Container, key string, v any) error { + return c.Register(NewServiceEntry(EntryConfig{ + Key: key, + Value: v, + HasValue: true, + })) +} + func TestContainer_RegisterAndResolve(t *testing.T) { t.Parallel() c := New(&Config{}) - err := c.Register( - "config", func(ctx context.Context, r Resolver) (any, error) { - return map[string]string{"port": "8080"}, nil - }, nil, - ) + err := registerProvider(c, "config", func(ctx context.Context, r Resolver) (any, error) { + return map[string]string{"port": "8080"}, nil + }, nil) if err != nil { t.Fatalf("failed to register: %v", err) } @@ -44,7 +58,7 @@ func TestContainer_RegisterValue(t *testing.T) { c := New(&Config{}) value := "test-value" - err := c.RegisterValue("myvalue", value) + err := registerValue(c, "myvalue", value) if err != nil { t.Fatalf("failed to register value: %v", err) } @@ -65,20 +79,18 @@ func TestContainer_DependencyResolution(t *testing.T) { c := New(&Config{}) - err := c.RegisterValue("config", map[string]string{"db": "postgres"}) + err := registerValue(c, "config", map[string]string{"db": "postgres"}) if err != nil { t.Fatalf("failed to register config: %v", err) } - err = c.Register( - "database", func(ctx context.Context, r Resolver) (any, error) { - cfg, err := r.Resolve(ctx, "config") - if err != nil { - return nil, err - } - return "connected to " + cfg.(map[string]string)["db"], nil - }, []string{"config"}, - ) + err = registerProvider(c, "database", func(ctx context.Context, r Resolver) (any, error) { + cfg, err := r.Resolve(ctx, "config") + if err != nil { + return nil, err + } + return "connected to " + cfg.(map[string]string)["db"], nil + }, []string{"config"}) if err != nil { t.Fatalf("failed to register database: %v", err) } @@ -99,12 +111,12 @@ func TestContainer_DuplicateRegistration(t *testing.T) { c := New(&Config{}) - err := c.RegisterValue("test", "value1") + err := registerValue(c, "test", "value1") if err != nil { t.Fatalf("first registration failed: %v", err) } - err = c.RegisterValue("test", "value2") + err = registerValue(c, "test", "value2") if err == nil { t.Error("expected error for duplicate registration") } @@ -115,20 +127,16 @@ func TestContainer_CircularDependency(t *testing.T) { c := New(&Config{}) - err := c.Register( - "A", func(ctx context.Context, r Resolver) (any, error) { - return "A", nil - }, []string{"B"}, - ) + err := registerProvider(c, "A", func(ctx context.Context, r Resolver) (any, error) { + return "A", nil + }, []string{"B"}) if err != nil { t.Fatalf("failed to register A: %v", err) } - err = c.Register( - "B", func(ctx context.Context, r Resolver) (any, error) { - return "B", nil - }, []string{"A"}, - ) + err = registerProvider(c, "B", func(ctx context.Context, r Resolver) (any, error) { + return "B", nil + }, []string{"A"}) if err == nil { t.Error("expected error for circular dependency") } @@ -139,12 +147,10 @@ func TestContainer_MissingDependency(t *testing.T) { c := New(&Config{}) - err := c.Register( - "service", func(ctx context.Context, r Resolver) (any, error) { - _, err := r.Resolve(ctx, "missing") - return nil, err - }, []string{"missing"}, - ) + err := registerProvider(c, "service", func(ctx context.Context, r Resolver) (any, error) { + _, err := r.Resolve(ctx, "missing") + return nil, err + }, []string{"missing"}) if err != nil { t.Fatalf("registration should succeed: %v", err) } @@ -162,11 +168,9 @@ func TestContainer_ProviderError(t *testing.T) { c := New(&Config{}) expectedErr := errors.New("provider failed") - err := c.Register( - "failing", func(ctx context.Context, r Resolver) (any, error) { - return nil, expectedErr - }, nil, - ) + err := registerProvider(c, "failing", func(ctx context.Context, r Resolver) (any, error) { + return nil, expectedErr + }, nil) if err != nil { t.Fatalf("registration failed: %v", err) } @@ -184,12 +188,10 @@ func TestContainer_Singleton(t *testing.T) { c := New(&Config{}) callCount := 0 - err := c.Register( - "counter", func(ctx context.Context, r Resolver) (any, error) { - callCount++ - return callCount, nil - }, nil, - ) + err := registerProvider(c, "counter", func(ctx context.Context, r Resolver) (any, error) { + callCount++ + return callCount, nil + }, nil) if err != nil { t.Fatalf("registration failed: %v", err) } @@ -217,7 +219,7 @@ func TestContainer_Has(t *testing.T) { t.Error("should not have unregistered service") } - _ = c.RegisterValue("test", "value") + _ = registerValue(c, "test", "value") if !c.Has("test") { t.Error("should have registered service") @@ -229,9 +231,9 @@ func TestContainer_Keys(t *testing.T) { c := New(&Config{}) - _ = c.RegisterValue("a", 1) - _ = c.RegisterValue("b", 2) - _ = c.RegisterValue("c", 3) + _ = registerValue(c, "a", 1) + _ = registerValue(c, "b", 2) + _ = registerValue(c, "c", 3) keys := c.Keys() if len(keys) != 3 { @@ -248,8 +250,8 @@ func TestContainer_Size(t *testing.T) { t.Error("empty container should have size 0") } - _ = c.RegisterValue("a", 1) - _ = c.RegisterValue("b", 2) + _ = registerValue(c, "a", 1) + _ = registerValue(c, "b", 2) if c.Size() != 2 { t.Errorf("expected size 2, got %d", c.Size()) @@ -261,12 +263,10 @@ func TestContainer_Validate(t *testing.T) { c := New(&Config{}) - _ = c.RegisterValue("config", "config") - _ = c.Register( - "service", func(ctx context.Context, r Resolver) (any, error) { - return "service", nil - }, []string{"config"}, - ) + _ = registerValue(c, "config", "config") + _ = registerProvider(c, "service", func(ctx context.Context, r Resolver) (any, error) { + return "service", nil + }, []string{"config"}) err := c.Validate() if err != nil { @@ -279,16 +279,14 @@ func TestContainer_ContextCancellation(t *testing.T) { c := New(&Config{}) - _ = c.Register( - "slow", func(ctx context.Context, r Resolver) (any, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - return "done", nil - } - }, nil, - ) + _ = registerProvider(c, "slow", func(ctx context.Context, r Resolver) (any, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + return "done", nil + } + }, nil) ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -304,8 +302,8 @@ func TestContainer_ConcurrentResolve_NoFalseCycle(t *testing.T) { c := New(&Config{}) - _ = c.RegisterValue("dep", "dependency") - _ = c.Register("svc", func(ctx context.Context, r Resolver) (any, error) { + _ = registerValue(c, "dep", "dependency") + _ = registerProvider(c, "svc", func(ctx context.Context, r Resolver) (any, error) { _, _ = r.Resolve(ctx, "dep") return "service", nil }, []string{"dep"}) @@ -338,7 +336,7 @@ func TestContainer_SingletonCalledOnce(t *testing.T) { c := New(&Config{}) var callCount atomic.Int64 - _ = c.Register("singleton", func(ctx context.Context, r Resolver) (any, error) { + _ = registerProvider(c, "singleton", func(ctx context.Context, r Resolver) (any, error) { callCount.Add(1) return "instance", nil }, nil) @@ -362,13 +360,11 @@ func TestContainer_SingletonCalledOnce(t *testing.T) { func BenchmarkContainer_Resolve(b *testing.B) { c := New(&Config{}) - _ = c.RegisterValue("config", map[string]string{"key": "value"}) - _ = c.Register( - "service", func(ctx context.Context, r Resolver) (any, error) { - _, _ = r.Resolve(ctx, "config") - return "service", nil - }, []string{"config"}, - ) + _ = registerValue(c, "config", map[string]string{"key": "value"}) + _ = registerProvider(c, "service", func(ctx context.Context, r Resolver) (any, error) { + _, _ = r.Resolve(ctx, "config") + return "service", nil + }, []string{"config"}) ctx := context.Background() _, _ = c.Resolve(ctx, "service") @@ -384,6 +380,6 @@ func BenchmarkContainer_Register(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { c := New(&Config{}) - _ = c.RegisterValue("test", "value") + _ = registerValue(c, "test", "value") } } diff --git a/internal/container/lifecycle.go b/internal/container/lifecycle.go index 4b8714b..4482294 100644 --- a/internal/container/lifecycle.go +++ b/internal/container/lifecycle.go @@ -2,7 +2,6 @@ package container import ( "context" - "errors" "fmt" "sync" "time" @@ -112,12 +111,10 @@ func (c *Container) startService(ctx context.Context, key string) error { } var startErr error - hooks := c.registry.GetOnStartHooks(key) - for _, hook := range hooks { + if entry, ok := c.registry.GetEntry(key); ok && entry.OnStart != nil { c.logger.Debug("running OnStart hook", "service", key) - if err := hook(ctx); err != nil { + if err := entry.OnStart(ctx); err != nil { startErr = fmt.Errorf("OnStart hook failed for %s: %w", key, err) - break } } @@ -237,17 +234,15 @@ func (c *Container) stopService(ctx context.Context, key string) error { } start := time.Now() - var errs []error + var stopErr error - hooks := c.registry.GetOnStopHooks(key) - for i := len(hooks) - 1; i >= 0; i-- { + if entry.OnStop != nil { c.logger.Debug("running OnStop hook", "service", key) - if err := hooks[i](ctx); err != nil { - errs = append(errs, fmt.Errorf("OnStop hook failed for %s: %w", key, err)) + if err := entry.OnStop(ctx); err != nil { + stopErr = fmt.Errorf("OnStop hook failed for %s: %w", key, err) } } - stopErr := errors.Join(errs...) c.callStopHooks(key, time.Since(start), stopErr) return stopErr } diff --git a/internal/container/lifecycle_test.go b/internal/container/lifecycle_test.go index 4322980..0517301 100644 --- a/internal/container/lifecycle_test.go +++ b/internal/container/lifecycle_test.go @@ -3,59 +3,54 @@ package container import ( "context" "errors" - "fmt" "strings" "testing" ) -func TestStopService_CollectsAllErrors(t *testing.T) { +func TestStopService_OnStopErrorPropagates(t *testing.T) { t.Parallel() c := New(&Config{}) - err1 := errors.New("hook1 failed") - err2 := errors.New("hook2 failed") + stopErr := errors.New("hook failed") - _ = c.Register("svc", func(ctx context.Context, r Resolver) (any, error) { - return "instance", nil - }, nil) - - c.registry.AddOnStop("svc", func(ctx context.Context) error { - return err1 - }) - c.registry.AddOnStop("svc", func(ctx context.Context) error { - return err2 - }) + _ = c.Register(NewServiceEntry(EntryConfig{ + Key: "svc", + Provider: func(ctx context.Context, r Resolver) (any, error) { + return "instance", nil + }, + OnStop: func(ctx context.Context) error { + return stopErr + }, + })) ctx := context.Background() _, _ = c.Resolve(ctx, "svc") - stopErr := c.stopService(ctx, "svc") - if stopErr == nil { + err := c.stopService(ctx, "svc") + if err == nil { t.Fatal("expected error from stopService") } - msg := stopErr.Error() - if !strings.Contains(msg, "hook1 failed") { - t.Errorf("expected error to contain 'hook1 failed', got: %s", msg) - } - if !strings.Contains(msg, "hook2 failed") { - t.Errorf("expected error to contain 'hook2 failed', got: %s", msg) + if !strings.Contains(err.Error(), "hook failed") { + t.Errorf("expected error to contain 'hook failed', got: %s", err.Error()) } } -func TestStopService_NoErrorWhenHooksSucceed(t *testing.T) { +func TestStopService_NoErrorWhenHookSucceeds(t *testing.T) { t.Parallel() c := New(&Config{}) - _ = c.Register("svc", func(ctx context.Context, r Resolver) (any, error) { - return "instance", nil - }, nil) - - c.registry.AddOnStop("svc", func(ctx context.Context) error { - return nil - }) + _ = c.Register(NewServiceEntry(EntryConfig{ + Key: "svc", + Provider: func(ctx context.Context, r Resolver) (any, error) { + return "instance", nil + }, + OnStop: func(ctx context.Context) error { + return nil + }, + })) ctx := context.Background() _, _ = c.Resolve(ctx, "svc") @@ -73,18 +68,20 @@ func TestStartAndStop_Integration(t *testing.T) { var order []string - _ = c.Register("svc", func(ctx context.Context, r Resolver) (any, error) { - return "instance", nil - }, nil) - - c.registry.AddOnStart("svc", func(ctx context.Context) error { - order = append(order, "started") - return nil - }) - c.registry.AddOnStop("svc", func(ctx context.Context) error { - order = append(order, "stopped") - return nil - }) + _ = c.Register(NewServiceEntry(EntryConfig{ + Key: "svc", + Provider: func(ctx context.Context, r Resolver) (any, error) { + return "instance", nil + }, + OnStart: func(ctx context.Context) error { + order = append(order, "started") + return nil + }, + OnStop: func(ctx context.Context) error { + order = append(order, "stopped") + return nil + }, + })) ctx := context.Background() if err := c.Start(ctx); err != nil { @@ -103,35 +100,3 @@ func TestStartAndStop_Integration(t *testing.T) { t.Errorf("expected [started, stopped], got %v", order) } } - -func TestStopService_MultipleFailingHooks_BothPresent(t *testing.T) { - t.Parallel() - - c := New(&Config{}) - - _ = c.Register("svc", func(ctx context.Context, r Resolver) (any, error) { - return "instance", nil - }, nil) - - c.registry.AddOnStop("svc", func(ctx context.Context) error { - return fmt.Errorf("first error") - }) - c.registry.AddOnStop("svc", func(ctx context.Context) error { - return fmt.Errorf("second error") - }) - - ctx := context.Background() - _, _ = c.Resolve(ctx, "svc") - - stopErr := c.stopService(ctx, "svc") - if stopErr == nil { - t.Fatal("expected combined error") - } - - if !strings.Contains(stopErr.Error(), "first error") { - t.Error("missing first error") - } - if !strings.Contains(stopErr.Error(), "second error") { - t.Error("missing second error") - } -} diff --git a/internal/container/registry.go b/internal/container/registry.go index 2c8771f..68524d2 100644 --- a/internal/container/registry.go +++ b/internal/container/registry.go @@ -22,8 +22,8 @@ type ServiceEntry struct { Instance any Instantiated bool Dependencies []string - OnStart []Hook - OnStop []Hook + OnStart Hook + OnStop Hook Scope scope.Scope PoolSize int pool chan any @@ -33,6 +33,38 @@ type ServiceEntry struct { initErr error } +type EntryConfig struct { + Key string + Provider ProviderFunc + Value any + HasValue bool + Dependencies []string + Scope scope.Scope + PoolSize int + Lazy bool + OnStart Hook + OnStop Hook +} + +func NewServiceEntry(cfg EntryConfig) *ServiceEntry { + entry := &ServiceEntry{ + Key: cfg.Key, + Provider: cfg.Provider, + Instance: cfg.Value, + Instantiated: cfg.HasValue, + Dependencies: cfg.Dependencies, + Scope: cfg.Scope, + PoolSize: cfg.PoolSize, + Lazy: cfg.Lazy, + OnStart: cfg.OnStart, + OnStop: cfg.OnStop, + } + if cfg.PoolSize > 0 { + entry.pool = make(chan any, cfg.PoolSize) + } + return entry +} + type Registry struct { mu sync.RWMutex services map[string]*ServiceEntry @@ -44,26 +76,10 @@ func NewRegistry() *Registry { } } -func (r *Registry) Register(key string, provider ProviderFunc, dependencies []string) error { +func (r *Registry) Add(entry *ServiceEntry) { r.mu.Lock() defer r.mu.Unlock() - r.services[key] = &ServiceEntry{ - Key: key, - Provider: provider, - Dependencies: dependencies, - } - return nil -} - -func (r *Registry) RegisterValue(key string, value any) error { - r.mu.Lock() - defer r.mu.Unlock() - r.services[key] = &ServiceEntry{ - Key: key, - Instance: value, - Instantiated: true, - } - return nil + r.services[entry.Key] = entry } func (r *Registry) Has(key string) bool { @@ -92,7 +108,6 @@ func (r *Registry) GetInstance(key string) (any, bool) { return entry.Instance, true } -// GetInstanceFast avoids defer for performance -- this is a hot path called on every Resolve. func (r *Registry) GetInstanceFast(key string) (any, bool) { r.mu.RLock() entry, exists := r.services[key] @@ -177,24 +192,6 @@ func (r *Registry) AllDependencies() map[string][]string { return deps } -func (r *Registry) AddOnStart(key string, hook Hook) { - r.mu.Lock() - defer r.mu.Unlock() - - if entry, exists := r.services[key]; exists { - entry.OnStart = append(entry.OnStart, hook) - } -} - -func (r *Registry) AddOnStop(key string, hook Hook) { - r.mu.Lock() - defer r.mu.Unlock() - - if entry, exists := r.services[key]; exists { - entry.OnStop = append(entry.OnStop, hook) - } -} - func (r *Registry) GetEntry(key string) (*ServiceEntry, bool) { r.mu.RLock() defer r.mu.RUnlock() @@ -214,27 +211,6 @@ func (r *Registry) AllEntries() []*ServiceEntry { return entries } -func (r *Registry) SetScope(key string, s scope.Scope) { - r.mu.Lock() - defer r.mu.Unlock() - - if entry, exists := r.services[key]; exists { - entry.Scope = s - } -} - -func (r *Registry) SetPoolSize(key string, size int) { - r.mu.Lock() - defer r.mu.Unlock() - - if entry, exists := r.services[key]; exists { - entry.PoolSize = size - if size > 0 { - entry.pool = make(chan any, size) - } - } -} - func (r *Registry) AcquireFromPool(key string) (any, bool) { r.mu.RLock() entry, exists := r.services[key] @@ -269,15 +245,6 @@ func (r *Registry) ReleaseToPool(key string, instance any) bool { } } -func (r *Registry) SetLazy(key string, lazy bool) { - r.mu.Lock() - defer r.mu.Unlock() - - if entry, exists := r.services[key]; exists { - entry.Lazy = lazy - } -} - func (r *Registry) IsLazy(key string) bool { r.mu.RLock() defer r.mu.RUnlock() @@ -288,34 +255,6 @@ func (r *Registry) IsLazy(key string) bool { return false } -func (r *Registry) GetOnStartHooks(key string) []Hook { - r.mu.RLock() - defer r.mu.RUnlock() - - entry, exists := r.services[key] - if !exists { - return nil - } - - hooks := make([]Hook, len(entry.OnStart)) - copy(hooks, entry.OnStart) - return hooks -} - -func (r *Registry) GetOnStopHooks(key string) []Hook { - r.mu.RLock() - defer r.mu.RUnlock() - - entry, exists := r.services[key] - if !exists { - return nil - } - - hooks := make([]Hook, len(entry.OnStop)) - copy(hooks, entry.OnStop) - return hooks -} - func (r *Registry) SetStartRan(key string) { r.mu.Lock() defer r.mu.Unlock() diff --git a/internal/container/replace.go b/internal/container/replace.go index 33564e0..bc8ad23 100644 --- a/internal/container/replace.go +++ b/internal/container/replace.go @@ -2,34 +2,22 @@ package container import "fmt" -func (c *Container) Replace(key string, provider ProviderFunc, dependencies []string) error { +func (c *Container) Replace(entry *ServiceEntry) error { c.mu.Lock() defer c.mu.Unlock() - c.registry.Remove(key) - c.graph.RemoveNode(key) + c.registry.Remove(entry.Key) + c.graph.RemoveNode(entry.Key) - _ = c.registry.Register(key, provider, dependencies) - c.graph.AddNode(key, dependencies) + c.registry.Add(entry) + c.graph.AddNode(entry.Key, entry.Dependencies) - if len(dependencies) > 0 && c.graph.HasCycle() { - c.registry.Remove(key) - c.graph.RemoveNode(key) - cyclePath := c.graph.FindCyclePath(key) + if len(entry.Dependencies) > 0 && c.graph.HasCycle() { + c.registry.Remove(entry.Key) + c.graph.RemoveNode(entry.Key) + cyclePath := c.graph.FindCyclePath(entry.Key) return fmt.Errorf("circular dependency detected: %v", cyclePath) } return nil } - -func (c *Container) ReplaceValue(key string, value any) error { - c.mu.Lock() - defer c.mu.Unlock() - - c.registry.Remove(key) - c.graph.RemoveNode(key) - - _ = c.registry.RegisterValue(key, value) - c.graph.AddNode(key, nil) - return nil -} diff --git a/internal/container/resolve.go b/internal/container/resolve.go index a790238..08f7f97 100644 --- a/internal/container/resolve.go +++ b/internal/container/resolve.go @@ -124,16 +124,14 @@ func (c *Container) resolveSingleton(ctx context.Context, key string, entry *Ser return entry.Instance, nil } -func (c *Container) runLazyStart(ctx context.Context, key string, _ *ServiceEntry) error { +func (c *Container) runLazyStart(ctx context.Context, key string, entry *ServiceEntry) error { start := time.Now() var startErr error - hooks := c.registry.GetOnStartHooks(key) - for _, hook := range hooks { + if entry.OnStart != nil { c.logger.Debug("running lazy OnStart hook", "service", key) - if err := hook(ctx); err != nil { + if err := entry.OnStart(ctx); err != nil { startErr = fmt.Errorf("OnStart hook failed for %s: %w", key, err) - break } } diff --git a/lifecycle_test.go b/lifecycle_test.go index dcd7ece..4a0afbc 100644 --- a/lifecycle_test.go +++ b/lifecycle_test.go @@ -18,25 +18,21 @@ func TestContainer_StartStop(t *testing.T) { var startCount, stopCount atomic.Int32 - err := Provide( - c, func(ctx context.Context, r Resolver) (*testService, error) { + err := Register(c, Spec[*testService]{ + Provider: func(ctx context.Context, r Resolver) (*testService, error) { return &testService{name: "test"}, nil }, - WithOnStart( - func(ctx context.Context) error { - startCount.Add(1) - return nil - }, - ), - WithOnStop( - func(ctx context.Context) error { - stopCount.Add(1) - return nil - }, - ), - ) + OnStart: func(ctx context.Context) error { + startCount.Add(1) + return nil + }, + OnStop: func(ctx context.Context) error { + stopCount.Add(1) + return nil + }, + }) if err != nil { - t.Fatalf("failed to provide: %v", err) + t.Fatalf("failed to register: %v", err) } ctx := context.Background() @@ -65,43 +61,36 @@ func TestContainer_StartOrder(t *testing.T) { var order []string - _ = ProvideValue( - c, &testConfig{value: "config"}, - WithOnStart( - func(ctx context.Context) error { - order = append(order, "config") - return nil - }, - ), - ) + _ = Register(c, SpecValue(&testConfig{value: "config"}).WithOnStart( + func(ctx context.Context) error { + order = append(order, "config") + return nil + }, + )) - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testDatabase, error) { + _ = Register(c, Spec[*testDatabase]{ + Provider: func(ctx context.Context, r Resolver) (*testDatabase, error) { _ = MustInvoke[*testConfig](c) return &testDatabase{}, nil }, - WithDependencies(reflect.TypeKey[*testConfig]()), - WithOnStart( - func(ctx context.Context) error { - order = append(order, "database") - return nil - }, - ), - ) + Dependencies: []string{reflect.TypeKey[*testConfig]()}, + OnStart: func(ctx context.Context) error { + order = append(order, "database") + return nil + }, + }) - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testServer, error) { + _ = Register(c, Spec[*testServer]{ + Provider: func(ctx context.Context, r Resolver) (*testServer, error) { _ = MustInvoke[*testDatabase](c) return &testServer{}, nil }, - WithDependencies(reflect.TypeKey[*testDatabase]()), - WithOnStart( - func(ctx context.Context) error { - order = append(order, "server") - return nil - }, - ), - ) + Dependencies: []string{reflect.TypeKey[*testDatabase]()}, + OnStart: func(ctx context.Context) error { + order = append(order, "server") + return nil + }, + }) ctx := context.Background() if err := c.Start(ctx); err != nil { @@ -128,43 +117,36 @@ func TestContainer_StopOrder(t *testing.T) { var order []string - _ = ProvideValue( - c, &testConfig{value: "config"}, - WithOnStop( - func(ctx context.Context) error { - order = append(order, "config") - return nil - }, - ), - ) + _ = Register(c, SpecValue(&testConfig{value: "config"}).WithOnStop( + func(ctx context.Context) error { + order = append(order, "config") + return nil + }, + )) - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testDatabase, error) { + _ = Register(c, Spec[*testDatabase]{ + Provider: func(ctx context.Context, r Resolver) (*testDatabase, error) { _ = MustInvoke[*testConfig](c) return &testDatabase{}, nil }, - WithDependencies(reflect.TypeKey[*testConfig]()), - WithOnStop( - func(ctx context.Context) error { - order = append(order, "database") - return nil - }, - ), - ) + Dependencies: []string{reflect.TypeKey[*testConfig]()}, + OnStop: func(ctx context.Context) error { + order = append(order, "database") + return nil + }, + }) - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testServer, error) { + _ = Register(c, Spec[*testServer]{ + Provider: func(ctx context.Context, r Resolver) (*testServer, error) { _ = MustInvoke[*testDatabase](c) return &testServer{}, nil }, - WithDependencies(reflect.TypeKey[*testDatabase]()), - WithOnStop( - func(ctx context.Context) error { - order = append(order, "server") - return nil - }, - ), - ) + Dependencies: []string{reflect.TypeKey[*testDatabase]()}, + OnStop: func(ctx context.Context) error { + order = append(order, "server") + return nil + }, + }) ctx := context.Background() _ = c.Start(ctx) @@ -191,16 +173,14 @@ func TestContainer_StartError(t *testing.T) { expectedErr := errors.New("start failed") - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testService, error) { + _ = Register(c, Spec[*testService]{ + Provider: func(ctx context.Context, r Resolver) (*testService, error) { return &testService{name: "test"}, nil }, - WithOnStart( - func(ctx context.Context) error { - return expectedErr - }, - ), - ) + OnStart: func(ctx context.Context) error { + return expectedErr + }, + }) ctx := context.Background() err := c.Start(ctx) @@ -220,16 +200,14 @@ func TestContainer_StopError(t *testing.T) { expectedErr := errors.New("stop failed") - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testService, error) { + _ = Register(c, Spec[*testService]{ + Provider: func(ctx context.Context, r Resolver) (*testService, error) { return &testService{name: "test"}, nil }, - WithOnStop( - func(ctx context.Context) error { - return expectedErr - }, - ), - ) + OnStop: func(ctx context.Context) error { + return expectedErr + }, + }) ctx := context.Background() _ = c.Start(ctx) @@ -240,48 +218,44 @@ func TestContainer_StopError(t *testing.T) { } } -func TestContainer_MultipleHooks(t *testing.T) { +func TestContainer_ComposedHooks(t *testing.T) { t.Parallel() c := New() var order []string - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testService, error) { + _ = Register(c, Spec[*testService]{ + Provider: func(ctx context.Context, r Resolver) (*testService, error) { return &testService{name: "test"}, nil }, - WithOnStart( + OnStart: Compose( func(ctx context.Context) error { order = append(order, "start1") return nil }, - ), - WithOnStart( func(ctx context.Context) error { order = append(order, "start2") return nil }, ), - WithOnStop( + OnStop: Compose( func(ctx context.Context) error { order = append(order, "stop1") return nil }, - ), - WithOnStop( func(ctx context.Context) error { order = append(order, "stop2") return nil }, ), - ) + }) ctx := context.Background() _ = c.Start(ctx) _ = c.Stop(ctx) - expected := []string{"start1", "start2", "stop2", "stop1"} + expected := []string{"start1", "start2", "stop1", "stop2"} if len(order) != len(expected) { t.Fatalf("expected %d items, got %d", len(expected), len(order)) } @@ -299,23 +273,19 @@ func TestContainer_Run(t *testing.T) { var started, stopped atomic.Bool - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testService, error) { + _ = Register(c, Spec[*testService]{ + Provider: func(ctx context.Context, r Resolver) (*testService, error) { return &testService{name: "test"}, nil }, - WithOnStart( - func(ctx context.Context) error { - started.Store(true) - return nil - }, - ), - WithOnStop( - func(ctx context.Context) error { - stopped.Store(true) - return nil - }, - ), - ) + OnStart: func(ctx context.Context) error { + started.Store(true) + return nil + }, + OnStop: func(ctx context.Context) error { + stopped.Store(true) + return nil + }, + }) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() @@ -338,7 +308,7 @@ func TestContainer_DoubleStart(t *testing.T) { c := New() - _ = ProvideValue(c, &testConfig{value: "config"}) + _ = Register(c, SpecValue(&testConfig{value: "config"})) ctx := context.Background() if err := c.Start(ctx); err != nil { @@ -358,7 +328,7 @@ func TestContainer_StopWithoutStart(t *testing.T) { c := New() - _ = ProvideValue(c, &testConfig{value: "config"}) + _ = Register(c, SpecValue(&testConfig{value: "config"})) ctx := context.Background() err := c.Stop(ctx) @@ -386,19 +356,17 @@ func TestContainer_LazyProvider(t *testing.T) { var instantiated, started atomic.Bool - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testService, error) { + _ = Register(c, Spec[*testService]{ + Provider: func(ctx context.Context, r Resolver) (*testService, error) { instantiated.Store(true) return &testService{name: "lazy"}, nil }, - WithLazy(), - WithOnStart( - func(ctx context.Context) error { - started.Store(true) - return nil - }, - ), - ) + Lazy: true, + OnStart: func(ctx context.Context) error { + started.Store(true) + return nil + }, + }) ctx := context.Background() if err := c.Start(ctx); err != nil { @@ -434,18 +402,16 @@ func TestContainer_LazyProviderOnStartRunsOnce(t *testing.T) { var startCount atomic.Int32 - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testService, error) { + _ = Register(c, Spec[*testService]{ + Provider: func(ctx context.Context, r Resolver) (*testService, error) { return &testService{name: "lazy"}, nil }, - WithLazy(), - WithOnStart( - func(ctx context.Context) error { - startCount.Add(1) - return nil - }, - ), - ) + Lazy: true, + OnStart: func(ctx context.Context) error { + startCount.Add(1) + return nil + }, + }) ctx := context.Background() _ = c.Start(ctx) @@ -468,18 +434,16 @@ func TestContainer_LazyProviderStopHook(t *testing.T) { var stopped atomic.Bool - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testService, error) { + _ = Register(c, Spec[*testService]{ + Provider: func(ctx context.Context, r Resolver) (*testService, error) { return &testService{name: "lazy"}, nil }, - WithLazy(), - WithOnStop( - func(ctx context.Context) error { - stopped.Store(true) - return nil - }, - ), - ) + Lazy: true, + OnStop: func(ctx context.Context) error { + stopped.Store(true) + return nil + }, + }) ctx := context.Background() _ = c.Start(ctx) @@ -498,18 +462,16 @@ func TestContainer_LazyProviderNotInstantiatedNoStop(t *testing.T) { var stopped atomic.Bool - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testService, error) { + _ = Register(c, Spec[*testService]{ + Provider: func(ctx context.Context, r Resolver) (*testService, error) { return &testService{name: "lazy"}, nil }, - WithLazy(), - WithOnStop( - func(ctx context.Context) error { - stopped.Store(true) - return nil - }, - ), - ) + Lazy: true, + OnStop: func(ctx context.Context) error { + stopped.Store(true) + return nil + }, + }) ctx := context.Background() _ = c.Start(ctx) @@ -527,18 +489,16 @@ func TestContainer_LazyProviderBeforeStart(t *testing.T) { var started atomic.Bool - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testService, error) { + _ = Register(c, Spec[*testService]{ + Provider: func(ctx context.Context, r Resolver) (*testService, error) { return &testService{name: "lazy"}, nil }, - WithLazy(), - WithOnStart( - func(ctx context.Context) error { - started.Store(true) - return nil - }, - ), - ) + Lazy: true, + OnStart: func(ctx context.Context) error { + started.Store(true) + return nil + }, + }) _, err := Invoke[*testService](c) if err != nil { @@ -564,22 +524,20 @@ func TestContainer_ShutdownTimeout(t *testing.T) { var stopped atomic.Bool - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testService, error) { + _ = Register(c, Spec[*testService]{ + Provider: func(ctx context.Context, r Resolver) (*testService, error) { return &testService{name: "slow"}, nil }, - WithOnStop( - func(ctx context.Context) error { - select { - case <-time.After(500 * time.Millisecond): - stopped.Store(true) - return nil - case <-ctx.Done(): - return ctx.Err() - } - }, - ), - ) + OnStop: func(ctx context.Context) error { + select { + case <-time.After(500 * time.Millisecond): + stopped.Store(true) + return nil + case <-ctx.Done(): + return ctx.Err() + } + }, + }) ctx := context.Background() _ = c.Start(ctx) @@ -601,17 +559,15 @@ func TestContainer_ShutdownTimeoutNotSet(t *testing.T) { var stopped atomic.Bool - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testService, error) { + _ = Register(c, Spec[*testService]{ + Provider: func(ctx context.Context, r Resolver) (*testService, error) { return &testService{name: "test"}, nil }, - WithOnStop( - func(ctx context.Context) error { - stopped.Store(true) - return nil - }, - ), - ) + OnStop: func(ctx context.Context) error { + stopped.Store(true) + return nil + }, + }) ctx := context.Background() _ = c.Start(ctx) @@ -634,50 +590,43 @@ func TestContainer_ParallelStartup(t *testing.T) { var order []string var mu sync.Mutex - _ = ProvideValue( - c, &testConfig{value: "config"}, - WithOnStart( - func(ctx context.Context) error { - time.Sleep(10 * time.Millisecond) - mu.Lock() - order = append(order, "config") - mu.Unlock() - return nil - }, - ), - ) + _ = Register(c, SpecValue(&testConfig{value: "config"}).WithOnStart( + func(ctx context.Context) error { + time.Sleep(10 * time.Millisecond) + mu.Lock() + order = append(order, "config") + mu.Unlock() + return nil + }, + )) - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testDatabase, error) { + _ = Register(c, Spec[*testDatabase]{ + Provider: func(ctx context.Context, r Resolver) (*testDatabase, error) { _ = MustInvoke[*testConfig](c) return &testDatabase{}, nil }, - WithDependencies(reflect.TypeKey[*testConfig]()), - WithOnStart( - func(ctx context.Context) error { - mu.Lock() - order = append(order, "database") - mu.Unlock() - return nil - }, - ), - ) + Dependencies: []string{reflect.TypeKey[*testConfig]()}, + OnStart: func(ctx context.Context) error { + mu.Lock() + order = append(order, "database") + mu.Unlock() + return nil + }, + }) - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testServer, error) { + _ = Register(c, Spec[*testServer]{ + Provider: func(ctx context.Context, r Resolver) (*testServer, error) { _ = MustInvoke[*testDatabase](c) return &testServer{}, nil }, - WithDependencies(reflect.TypeKey[*testDatabase]()), - WithOnStart( - func(ctx context.Context) error { - mu.Lock() - order = append(order, "server") - mu.Unlock() - return nil - }, - ), - ) + Dependencies: []string{reflect.TypeKey[*testDatabase]()}, + OnStart: func(ctx context.Context) error { + mu.Lock() + order = append(order, "server") + mu.Unlock() + return nil + }, + }) ctx := context.Background() if err := c.Start(ctx); err != nil { @@ -704,50 +653,44 @@ func TestContainer_ParallelStartupIndependent(t *testing.T) { var mu sync.Mutex startTime := time.Now() - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testConfig, error) { + _ = Register(c, Spec[*testConfig]{ + Provider: func(ctx context.Context, r Resolver) (*testConfig, error) { return &testConfig{value: "a"}, nil }, - WithOnStart( - func(ctx context.Context) error { - mu.Lock() - startTimes = append(startTimes, time.Now()) - mu.Unlock() - time.Sleep(50 * time.Millisecond) - return nil - }, - ), - ) + OnStart: func(ctx context.Context) error { + mu.Lock() + startTimes = append(startTimes, time.Now()) + mu.Unlock() + time.Sleep(50 * time.Millisecond) + return nil + }, + }) - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testDatabase, error) { + _ = Register(c, Spec[*testDatabase]{ + Provider: func(ctx context.Context, r Resolver) (*testDatabase, error) { return &testDatabase{}, nil }, - WithOnStart( - func(ctx context.Context) error { - mu.Lock() - startTimes = append(startTimes, time.Now()) - mu.Unlock() - time.Sleep(50 * time.Millisecond) - return nil - }, - ), - ) + OnStart: func(ctx context.Context) error { + mu.Lock() + startTimes = append(startTimes, time.Now()) + mu.Unlock() + time.Sleep(50 * time.Millisecond) + return nil + }, + }) - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testServer, error) { + _ = Register(c, Spec[*testServer]{ + Provider: func(ctx context.Context, r Resolver) (*testServer, error) { return &testServer{}, nil }, - WithOnStart( - func(ctx context.Context) error { - mu.Lock() - startTimes = append(startTimes, time.Now()) - mu.Unlock() - time.Sleep(50 * time.Millisecond) - return nil - }, - ), - ) + OnStart: func(ctx context.Context) error { + mu.Lock() + startTimes = append(startTimes, time.Now()) + mu.Unlock() + time.Sleep(50 * time.Millisecond) + return nil + }, + }) ctx := context.Background() if err := c.Start(ctx); err != nil { @@ -776,49 +719,42 @@ func TestContainer_ParallelShutdown(t *testing.T) { var stopOrder []string var mu sync.Mutex - _ = ProvideValue( - c, &testConfig{value: "config"}, - WithOnStop( - func(ctx context.Context) error { - mu.Lock() - stopOrder = append(stopOrder, "config") - mu.Unlock() - return nil - }, - ), - ) + _ = Register(c, SpecValue(&testConfig{value: "config"}).WithOnStop( + func(ctx context.Context) error { + mu.Lock() + stopOrder = append(stopOrder, "config") + mu.Unlock() + return nil + }, + )) - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testDatabase, error) { + _ = Register(c, Spec[*testDatabase]{ + Provider: func(ctx context.Context, r Resolver) (*testDatabase, error) { _ = MustInvoke[*testConfig](c) return &testDatabase{}, nil }, - WithDependencies(reflect.TypeKey[*testConfig]()), - WithOnStop( - func(ctx context.Context) error { - mu.Lock() - stopOrder = append(stopOrder, "database") - mu.Unlock() - return nil - }, - ), - ) + Dependencies: []string{reflect.TypeKey[*testConfig]()}, + OnStop: func(ctx context.Context) error { + mu.Lock() + stopOrder = append(stopOrder, "database") + mu.Unlock() + return nil + }, + }) - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testServer, error) { + _ = Register(c, Spec[*testServer]{ + Provider: func(ctx context.Context, r Resolver) (*testServer, error) { _ = MustInvoke[*testDatabase](c) return &testServer{}, nil }, - WithDependencies(reflect.TypeKey[*testDatabase]()), - WithOnStop( - func(ctx context.Context) error { - mu.Lock() - stopOrder = append(stopOrder, "server") - mu.Unlock() - return nil - }, - ), - ) + Dependencies: []string{reflect.TypeKey[*testDatabase]()}, + OnStop: func(ctx context.Context) error { + mu.Lock() + stopOrder = append(stopOrder, "server") + mu.Unlock() + return nil + }, + }) ctx := context.Background() _ = c.Start(ctx) diff --git a/module.go b/module.go index 6e44976..4048b1d 100644 --- a/module.go +++ b/module.go @@ -9,27 +9,16 @@ import ( type Module struct { name string - providers []providerEntry + registers []func(c *Container) error decorators []decoratorEntry - bindings []bindingEntry submodules []*Module } -type providerEntry struct { - register func(c *Container) error -} - type decoratorEntry struct { key string decorator func(ctx context.Context, r Resolver, instance any) (any, error) } -type bindingEntry struct { - interfaceKey string - implKey string - opts []ProviderOption -} - func NewModule(name string) *Module { return &Module{ name: name, @@ -40,30 +29,41 @@ func (m *Module) Name() string { return m.name } -func (m *Module) Provide(provider any, opts ...ProviderOption) *Module { - m.providers = append( - m.providers, providerEntry{ - register: func(c *Container) error { - return provideAny(c, provider, opts...) - }, - }, - ) +func (m *Module) Include(submodule *Module) *Module { + m.submodules = append(m.submodules, submodule) return m } -func (m *Module) ProvideValue(value any, opts ...ProviderOption) *Module { - m.providers = append( - m.providers, providerEntry{ - register: func(c *Container) error { - return provideValueAny(c, value, opts...) - }, - }, - ) +func ModuleRegister[T any](m *Module, spec Spec[T]) *Module { + m.registers = append(m.registers, func(c *Container) error { + return Register(c, spec) + }) return m } -func (m *Module) Include(submodule *Module) *Module { - m.submodules = append(m.submodules, submodule) +func ModuleReplace[T any](m *Module, spec Spec[T]) *Module { + m.registers = append(m.registers, func(c *Container) error { + return Replace(c, spec) + }) + return m +} + +func ModuleDecorate[T any](m *Module, decorator Decorator[T]) *Module { + key := reflect.TypeKey[T]() + + m.decorators = append( + m.decorators, decoratorEntry{ + key: key, + decorator: func(ctx context.Context, r Resolver, instance any) (any, error) { + typed, ok := instance.(T) + if !ok { + var zero T + return zero, errDecoratorTypeMismatch(reflect.TypeName[T]()) + } + return decorator(ctx, r, typed) + }, + }, + ) return m } @@ -74,23 +74,17 @@ func (m *Module) apply(c *Container) error { } } - for _, p := range m.providers { - if err := p.register(c); err != nil { - return err - } - } - - for _, b := range m.bindings { - if err := applyBinding(c, b); err != nil { + for _, register := range m.registers { + if err := register(c); err != nil { return err } } for _, d := range m.decorators { + entry := d c.internal.AddDecorator( - d.key, func(ctx context.Context, r container.Resolver, instance any) (any, error) { - resolver := &resolverAdapter{container: c} - return d.decorator(ctx, resolver, instance) + entry.key, func(ctx context.Context, _ container.Resolver, instance any) (any, error) { + return entry.decorator(ctx, c.resolver, instance) }, ) } @@ -98,69 +92,6 @@ func (m *Module) apply(c *Container) error { return nil } -func applyBinding(c *Container, b bindingEntry) error { - cfg := &providerConfig{} - for _, opt := range b.opts { - opt(cfg) - } - - key := b.interfaceKey - if cfg.name != "" { - key = cfg.name + "#" + b.interfaceKey - } - - wrappedProvider := func(ctx context.Context, r container.Resolver) (any, error) { - return r.Resolve(ctx, b.implKey) - } - - if err := c.internal.Register(key, wrappedProvider, []string{b.implKey}); err != nil { - return err - } - - for _, hook := range cfg.onStart { - c.internal.AddOnStart(key, hook) - } - for _, hook := range cfg.onStop { - c.internal.AddOnStop(key, hook) - } - - return nil -} - -func provideAny(c *Container, provider any, opts ...ProviderOption) error { - switch p := provider.(type) { - case func(context.Context, Resolver) (any, error): - return Provide(c, p, opts...) - default: - return errModuleInvalidProvider(provider) - } -} - -func provideValueAny(c *Container, value any, opts ...ProviderOption) error { - cfg := &providerConfig{} - for _, opt := range opts { - opt(cfg) - } - - key := reflect.TypeKeyFromValue(value) - if cfg.name != "" { - key = reflect.TypeKeyNamedFromValue(value, cfg.name) - } - - if err := c.internal.RegisterValue(key, value); err != nil { - return err - } - - for _, hook := range cfg.onStart { - c.internal.AddOnStart(key, hook) - } - for _, hook := range cfg.onStop { - c.internal.AddOnStop(key, hook) - } - - return nil -} - func (c *Container) Apply(modules ...*Module) error { for _, m := range modules { if err := m.apply(c); err != nil { @@ -177,66 +108,3 @@ func errModuleApplyFailed(moduleName string, cause error) *Error { cause, ) } - -func errModuleInvalidProvider(_ any) *Error { - return newError( - ErrCodeModuleInvalidProvider, - "invalid provider type in module", - nil, - ) -} - -func ModuleProvide[T any](m *Module, provider Provider[T], opts ...ProviderOption) *Module { - m.providers = append( - m.providers, providerEntry{ - register: func(c *Container) error { - return Provide(c, provider, opts...) - }, - }, - ) - return m -} - -func ModuleProvideValue[T any](m *Module, value T, opts ...ProviderOption) *Module { - m.providers = append( - m.providers, providerEntry{ - register: func(c *Container) error { - return ProvideValue(c, value, opts...) - }, - }, - ) - return m -} - -func ModuleBind[I, T any](m *Module, opts ...ProviderOption) *Module { - interfaceKey := reflect.TypeKey[I]() - implKey := reflect.TypeKey[T]() - - m.bindings = append( - m.bindings, bindingEntry{ - interfaceKey: interfaceKey, - implKey: implKey, - opts: opts, - }, - ) - return m -} - -func ModuleDecorate[T any](m *Module, decorator Decorator[T]) *Module { - key := reflect.TypeKey[T]() - - m.decorators = append( - m.decorators, decoratorEntry{ - key: key, - decorator: func(ctx context.Context, r Resolver, instance any) (any, error) { - typed, ok := instance.(T) - if !ok { - var zero T - return zero, errDecoratorTypeMismatch(reflect.TypeName[T]()) - } - return decorator(ctx, r, typed) - }, - }, - ) - return m -} diff --git a/module_test.go b/module_test.go index 5d14503..77b27d0 100644 --- a/module_test.go +++ b/module_test.go @@ -32,14 +32,16 @@ func TestModuleBasic(t *testing.T) { } } -func TestModuleProvide(t *testing.T) { +func TestModuleRegister(t *testing.T) { t.Parallel() c := needle.New() module := needle.NewModule("config") - needle.ModuleProvide(module, func(ctx context.Context, r needle.Resolver) (*Config, error) { - return &Config{Port: 9000, Host: "module.local"}, nil + needle.ModuleRegister(module, needle.Spec[*Config]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Config, error) { + return &Config{Port: 9000, Host: "module.local"}, nil + }, }) err := c.Apply(module) @@ -57,14 +59,14 @@ func TestModuleProvide(t *testing.T) { } } -func TestModuleProvideValue(t *testing.T) { +func TestModuleRegisterValue(t *testing.T) { t.Parallel() c := needle.New() config := &Config{Port: 7000} module := needle.NewModule("values") - needle.ModuleProvideValue(module, config) + needle.ModuleRegister(module, needle.SpecValue(config)) err := c.Apply(module) if err != nil { @@ -87,12 +89,14 @@ func TestModuleInclude(t *testing.T) { c := needle.New() configModule := needle.NewModule("config") - needle.ModuleProvideValue(configModule, &Config{Port: 5000}) + needle.ModuleRegister(configModule, needle.SpecValue(&Config{Port: 5000})) dbModule := needle.NewModule("db") - needle.ModuleProvide(dbModule, func(ctx context.Context, r needle.Resolver) (*Database, error) { - cfg := needle.MustInvoke[*Config](c) - return &Database{Config: cfg, Name: "testdb"}, nil + needle.ModuleRegister(dbModule, needle.Spec[*Database]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Database, error) { + cfg := needle.MustInvoke[*Config](c) + return &Database{Config: cfg, Name: "testdb"}, nil + }, }) appModule := needle.NewModule("app"). @@ -114,18 +118,20 @@ func TestModuleInclude(t *testing.T) { } } -func TestModuleBind(t *testing.T) { +func TestModuleBinding(t *testing.T) { t.Parallel() c := needle.New() module := needle.NewModule("repos") - needle.ModuleProvideValue(module, &Database{Name: "postgres"}) - needle.ModuleProvide(module, func(ctx context.Context, r needle.Resolver) (*PostgresUserRepo, error) { - db := needle.MustInvoke[*Database](c) - return &PostgresUserRepo{DB: db}, nil + needle.ModuleRegister(module, needle.SpecValue(&Database{Name: "postgres"})) + needle.ModuleRegister(module, needle.Spec[*PostgresUserRepo]{ + Provider: func(ctx context.Context, r needle.Resolver) (*PostgresUserRepo, error) { + db := needle.MustInvoke[*Database](c) + return &PostgresUserRepo{DB: db}, nil + }, }) - needle.ModuleBind[UserRepository, *PostgresUserRepo](module) + needle.ModuleRegister(module, needle.SpecFromBinding[UserRepository, *PostgresUserRepo]()) err := c.Apply(module) if err != nil { @@ -149,8 +155,10 @@ func TestModuleDecorate(t *testing.T) { c := needle.New() module := needle.NewModule("logging") - needle.ModuleProvide(module, func(ctx context.Context, r needle.Resolver) (*Logger, error) { - return &Logger{Prefix: "app"}, nil + needle.ModuleRegister(module, needle.Spec[*Logger]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Logger, error) { + return &Logger{Prefix: "app"}, nil + }, }) needle.ModuleDecorate(module, func(ctx context.Context, r needle.Resolver, base *Logger) (*Logger, error) { base.Prefix = "[" + base.Prefix + "]" @@ -172,25 +180,27 @@ func TestModuleDecorate(t *testing.T) { } } -func TestBind(t *testing.T) { +func TestSpecFromBinding(t *testing.T) { t.Parallel() c := needle.New() - err := needle.ProvideValue(c, &Database{Name: "main"}) + err := needle.Register(c, needle.SpecValue(&Database{Name: "main"})) if err != nil { - t.Fatalf("ProvideValue failed: %v", err) + t.Fatalf("Register failed: %v", err) } - err = needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*PostgresUserRepo, error) { - db := needle.MustInvoke[*Database](c) - return &PostgresUserRepo{DB: db}, nil + err = needle.Register(c, needle.Spec[*PostgresUserRepo]{ + Provider: func(ctx context.Context, r needle.Resolver) (*PostgresUserRepo, error) { + db := needle.MustInvoke[*Database](c) + return &PostgresUserRepo{DB: db}, nil + }, }) if err != nil { - t.Fatalf("Provide failed: %v", err) + t.Fatalf("Register failed: %v", err) } - err = needle.Bind[UserRepository, *PostgresUserRepo](c) + err = needle.Register(c, needle.SpecFromBinding[UserRepository, *PostgresUserRepo]()) if err != nil { t.Fatalf("Bind failed: %v", err) } @@ -205,27 +215,29 @@ func TestBind(t *testing.T) { } } -func TestBindNamed(t *testing.T) { +func TestSpecFromBindingNamed(t *testing.T) { t.Parallel() c := needle.New() - err := needle.ProvideValue(c, &Database{Name: "named-db"}) + err := needle.Register(c, needle.SpecValue(&Database{Name: "named-db"})) if err != nil { - t.Fatalf("ProvideValue failed: %v", err) + t.Fatalf("Register failed: %v", err) } - err = needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*PostgresUserRepo, error) { - db := needle.MustInvoke[*Database](c) - return &PostgresUserRepo{DB: db}, nil + err = needle.Register(c, needle.Spec[*PostgresUserRepo]{ + Provider: func(ctx context.Context, r needle.Resolver) (*PostgresUserRepo, error) { + db := needle.MustInvoke[*Database](c) + return &PostgresUserRepo{DB: db}, nil + }, }) if err != nil { - t.Fatalf("Provide failed: %v", err) + t.Fatalf("Register failed: %v", err) } - err = needle.BindNamed[UserRepository, *PostgresUserRepo](c, "users") + err = needle.Register(c, needle.SpecFromBinding[UserRepository, *PostgresUserRepo]().WithName("users")) if err != nil { - t.Fatalf("BindNamed failed: %v", err) + t.Fatalf("Register binding failed: %v", err) } repo, err := needle.InvokeNamed[UserRepository](c, "users") @@ -243,11 +255,13 @@ func TestDecorate(t *testing.T) { c := needle.New() - err := needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*Logger, error) { - return &Logger{Prefix: "base"}, nil + err := needle.Register(c, needle.Spec[*Logger]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Logger, error) { + return &Logger{Prefix: "base"}, nil + }, }) if err != nil { - t.Fatalf("Provide failed: %v", err) + t.Fatalf("Register failed: %v", err) } needle.Decorate(c, func(ctx context.Context, r needle.Resolver, base *Logger) (*Logger, error) { @@ -270,11 +284,13 @@ func TestDecorateChain(t *testing.T) { c := needle.New() - err := needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*Logger, error) { - return &Logger{Prefix: "core"}, nil + err := needle.Register(c, needle.Spec[*Logger]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Logger, error) { + return &Logger{Prefix: "core"}, nil + }, }) if err != nil { - t.Fatalf("Provide failed: %v", err) + t.Fatalf("Register failed: %v", err) } needle.Decorate(c, func(ctx context.Context, r needle.Resolver, base *Logger) (*Logger, error) { @@ -302,11 +318,14 @@ func TestDecorateNamed(t *testing.T) { c := needle.New() - err := needle.ProvideNamed(c, "app", func(ctx context.Context, r needle.Resolver) (*Logger, error) { - return &Logger{Prefix: "app"}, nil + err := needle.Register(c, needle.Spec[*Logger]{ + Name: "app", + Provider: func(ctx context.Context, r needle.Resolver) (*Logger, error) { + return &Logger{Prefix: "app"}, nil + }, }) if err != nil { - t.Fatalf("ProvideNamed failed: %v", err) + t.Fatalf("Register failed: %v", err) } needle.DecorateNamed(c, "app", func(ctx context.Context, r needle.Resolver, base *Logger) (*Logger, error) { @@ -330,12 +349,14 @@ func TestMultipleModules(t *testing.T) { c := needle.New() configModule := needle.NewModule("config") - needle.ModuleProvideValue(configModule, &Config{Port: 8080}) + needle.ModuleRegister(configModule, needle.SpecValue(&Config{Port: 8080})) dbModule := needle.NewModule("db") - needle.ModuleProvide(dbModule, func(ctx context.Context, r needle.Resolver) (*Database, error) { - cfg := needle.MustInvoke[*Config](c) - return &Database{Config: cfg, Name: "app-db"}, nil + needle.ModuleRegister(dbModule, needle.Spec[*Database]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Database, error) { + cfg := needle.MustInvoke[*Config](c) + return &Database{Config: cfg, Name: "app-db"}, nil + }, }) err := c.Apply(configModule, dbModule) diff --git a/needle_test.go b/needle_test.go index ced1d17..387d1de 100644 --- a/needle_test.go +++ b/needle_test.go @@ -44,18 +44,18 @@ func TestNewWithLogger(t *testing.T) { } } -func TestProvideAndInvoke(t *testing.T) { +func TestRegisterAndInvoke(t *testing.T) { t.Parallel() c := needle.New() - err := needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Config, error) { + err := needle.Register(c, needle.Spec[*Config]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Config, error) { return &Config{Port: 8080, Host: "localhost"}, nil }, - ) + }) if err != nil { - t.Fatalf("Provide failed: %v", err) + t.Fatalf("Register failed: %v", err) } cfg, err := needle.Invoke[*Config](c) @@ -71,15 +71,15 @@ func TestProvideAndInvoke(t *testing.T) { } } -func TestProvideValue(t *testing.T) { +func TestRegisterValue(t *testing.T) { t.Parallel() c := needle.New() config := &Config{Port: 3000, Host: "0.0.0.0"} - err := needle.ProvideValue(c, config) + err := needle.Register(c, needle.SpecValue(config)) if err != nil { - t.Fatalf("ProvideValue failed: %v", err) + t.Fatalf("Register value failed: %v", err) } cfg, err := needle.Invoke[*Config](c) @@ -97,30 +97,30 @@ func TestDependencyChain(t *testing.T) { c := needle.New() - err := needle.ProvideValue(c, &Config{Port: 5432, Host: "db.local"}) + err := needle.Register(c, needle.SpecValue(&Config{Port: 5432, Host: "db.local"})) if err != nil { - t.Fatalf("ProvideValue for Config failed: %v", err) + t.Fatalf("Register Config value failed: %v", err) } - err = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Database, error) { + err = needle.Register(c, needle.Spec[*Database]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Database, error) { cfg := needle.MustInvoke[*Config](c) return &Database{Config: cfg, Name: "testdb"}, nil }, - ) + }) if err != nil { - t.Fatalf("Provide for Database failed: %v", err) + t.Fatalf("Register Database failed: %v", err) } - err = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Server, error) { + err = needle.Register(c, needle.Spec[*Server]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Server, error) { db := needle.MustInvoke[*Database](c) cfg := needle.MustInvoke[*Config](c) return &Server{DB: db, Config: cfg}, nil }, - ) + }) if err != nil { - t.Fatalf("Provide for Server failed: %v", err) + t.Fatalf("Register Server failed: %v", err) } server, err := needle.Invoke[*Server](c) @@ -144,22 +144,24 @@ func TestNamedServices(t *testing.T) { c := needle.New() - err := needle.ProvideNamed( - c, "primary", func(ctx context.Context, r needle.Resolver) (*Database, error) { + err := needle.Register(c, needle.Spec[*Database]{ + Name: "primary", + Provider: func(ctx context.Context, r needle.Resolver) (*Database, error) { return &Database{Name: "primary"}, nil }, - ) + }) if err != nil { - t.Fatalf("ProvideNamed for primary failed: %v", err) + t.Fatalf("Register primary failed: %v", err) } - err = needle.ProvideNamed( - c, "replica", func(ctx context.Context, r needle.Resolver) (*Database, error) { + err = needle.Register(c, needle.Spec[*Database]{ + Name: "replica", + Provider: func(ctx context.Context, r needle.Resolver) (*Database, error) { return &Database{Name: "replica"}, nil }, - ) + }) if err != nil { - t.Fatalf("ProvideNamed for replica failed: %v", err) + t.Fatalf("Register replica failed: %v", err) } primary, err := needle.InvokeNamed[*Database](c, "primary") @@ -185,9 +187,9 @@ func TestMustInvoke(t *testing.T) { c := needle.New() - err := needle.ProvideValue(c, &Config{Port: 8080}) + err := needle.Register(c, needle.SpecValue(&Config{Port: 8080})) if err != nil { - t.Fatalf("ProvideValue failed: %v", err) + t.Fatalf("Register failed: %v", err) } cfg := needle.MustInvoke[*Config](c) @@ -220,9 +222,9 @@ func TestTryInvoke(t *testing.T) { t.Error("TryInvoke should return false for missing service") } - err := needle.ProvideValue(c, &Config{Port: 8080}) + err := needle.Register(c, needle.SpecValue(&Config{Port: 8080})) if err != nil { - t.Fatalf("ProvideValue failed: %v", err) + t.Fatalf("Register failed: %v", err) } cfg, ok := needle.TryInvoke[*Config](c) @@ -243,9 +245,9 @@ func TestHas(t *testing.T) { t.Error("Has should return false for missing service") } - err := needle.ProvideValue(c, &Config{}) + err := needle.Register(c, needle.SpecValue(&Config{})) if err != nil { - t.Fatalf("ProvideValue failed: %v", err) + t.Fatalf("Register failed: %v", err) } if !needle.Has[*Config](c) { @@ -262,9 +264,9 @@ func TestHasNamed(t *testing.T) { t.Error("HasNamed should return false for missing service") } - err := needle.ProvideNamedValue(c, "myconfig", &Config{}) + err := needle.Register(c, needle.SpecValue(&Config{}).WithName("myconfig")) if err != nil { - t.Fatalf("ProvideNamedValue failed: %v", err) + t.Fatalf("Register failed: %v", err) } if !needle.HasNamed[*Config](c, "myconfig") { @@ -278,13 +280,13 @@ func TestProviderError(t *testing.T) { c := needle.New() expectedErr := errors.New("provider error") - err := needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Config, error) { + err := needle.Register(c, needle.Spec[*Config]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Config, error) { return nil, expectedErr }, - ) + }) if err != nil { - t.Fatalf("Provide failed: %v", err) + t.Fatalf("Register failed: %v", err) } _, err = needle.Invoke[*Config](c) @@ -298,9 +300,9 @@ func TestContainerValidate(t *testing.T) { c := needle.New() - err := needle.ProvideValue(c, &Config{}) + err := needle.Register(c, needle.SpecValue(&Config{})) if err != nil { - t.Fatalf("ProvideValue failed: %v", err) + t.Fatalf("Register failed: %v", err) } err = c.Validate() @@ -318,8 +320,8 @@ func TestContainerSize(t *testing.T) { t.Error("empty container should have size 0") } - _ = needle.ProvideValue(c, &Config{}) - _ = needle.ProvideValue(c, &Database{}) + _ = needle.Register(c, needle.SpecValue(&Config{})) + _ = needle.Register(c, needle.SpecValue(&Database{})) if c.Size() != 2 { t.Errorf("expected size 2, got %d", c.Size()) @@ -331,8 +333,8 @@ func TestContainerKeys(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &Config{}) - _ = needle.ProvideValue(c, &Database{}) + _ = needle.Register(c, needle.SpecValue(&Config{})) + _ = needle.Register(c, needle.SpecValue(&Database{})) keys := c.Keys() if len(keys) != 2 { @@ -345,13 +347,13 @@ func TestInvokeWithContext(t *testing.T) { c := needle.New() - err := needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*Config, error) { + err := needle.Register(c, needle.Spec[*Config]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Config, error) { return &Config{Port: 8080}, nil }, - ) + }) if err != nil { - t.Fatalf("Provide failed: %v", err) + t.Fatalf("Register failed: %v", err) } ctx := context.Background() @@ -365,9 +367,9 @@ func TestInvokeWithContext(t *testing.T) { } } -func BenchmarkProvideAndInvoke(b *testing.B) { +func BenchmarkRegisterAndInvoke(b *testing.B) { c := needle.New() - _ = needle.ProvideValue(c, &Config{Port: 8080}) + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080})) b.ReportAllocs() for b.Loop() { @@ -377,7 +379,7 @@ func BenchmarkProvideAndInvoke(b *testing.B) { func BenchmarkMustInvoke(b *testing.B) { c := needle.New() - _ = needle.ProvideValue(c, &Config{Port: 8080}) + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080})) b.ReportAllocs() for b.Loop() { @@ -389,7 +391,7 @@ func TestOptionalPresent(t *testing.T) { t.Parallel() c := needle.New() - _ = needle.ProvideValue(c, &Config{Port: 8080, Host: "localhost"}) + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080, Host: "localhost"})) opt, err := needle.InvokeOptional[*Config](c) if err != nil { @@ -452,7 +454,7 @@ func TestOptionalOrElse(t *testing.T) { t.Errorf("expected port 3000, got %d", result.Port) } - _ = needle.ProvideValue(c, &Config{Port: 8080}) + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080})) opt2, err := needle.InvokeOptional[*Config](c) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -486,7 +488,7 @@ func TestOptionalOrElseFunc(t *testing.T) { t.Errorf("expected func to be called once, got %d", callCount) } - _ = needle.ProvideValue(c, &Config{Port: 8080}) + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080})) opt2, err := needle.InvokeOptional[*Config](c) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -508,7 +510,7 @@ func TestOptionalNamed(t *testing.T) { t.Parallel() c := needle.New() - _ = needle.ProvideNamedValue(c, "primary", &Config{Port: 5432}) + _ = needle.Register(c, needle.SpecValue(&Config{Port: 5432}).WithName("primary")) opt, err := needle.InvokeOptionalNamed[*Config](c, "primary") if err != nil { @@ -543,14 +545,16 @@ func TestOptionalInProvider(t *testing.T) { Cache *Cache } - _ = needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*Service, error) { - cacheOpt, err := needle.InvokeOptional[*Cache](c) - if err != nil { - return nil, err - } - return &Service{ - Cache: cacheOpt.OrElse(nil), - }, nil + _ = needle.Register(c, needle.Spec[*Service]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Service, error) { + cacheOpt, err := needle.InvokeOptional[*Cache](c) + if err != nil { + return nil, err + } + return &Service{ + Cache: cacheOpt.OrElse(nil), + }, nil + }, }) svc := needle.MustInvoke[*Service](c) @@ -572,15 +576,17 @@ func TestOptionalInProviderWithValue(t *testing.T) { Cache *Cache } - _ = needle.ProvideValue(c, &Cache{Enabled: true}) - _ = needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*Service, error) { - cacheOpt, err := needle.InvokeOptional[*Cache](c) - if err != nil { - return nil, err - } - return &Service{ - Cache: cacheOpt.OrElse(nil), - }, nil + _ = needle.Register(c, needle.SpecValue(&Cache{Enabled: true})) + _ = needle.Register(c, needle.Spec[*Service]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Service, error) { + cacheOpt, err := needle.InvokeOptional[*Cache](c) + if err != nil { + return nil, err + } + return &Service{ + Cache: cacheOpt.OrElse(nil), + }, nil + }, }) svc := needle.MustInvoke[*Service](c) @@ -597,8 +603,10 @@ func TestOptionalResolutionError(t *testing.T) { c := needle.New() - _ = needle.Provide(c, func(_ context.Context, _ needle.Resolver) (*Config, error) { - return nil, errors.New("provider broken") + _ = needle.Register(c, needle.Spec[*Config]{ + Provider: func(_ context.Context, _ needle.Resolver) (*Config, error) { + return nil, errors.New("provider broken") + }, }) opt, err := needle.InvokeOptional[*Config](c) diff --git a/observability_test.go b/observability_test.go index 89f504c..7d5a7fd 100644 --- a/observability_test.go +++ b/observability_test.go @@ -40,9 +40,9 @@ func TestHealthCheckHealthyService(t *testing.T) { c := needle.New() ctx := context.Background() - err := needle.ProvideValue(c, &HealthyService{}) + err := needle.Register(c, needle.SpecValue(&HealthyService{})) if err != nil { - t.Fatalf("ProvideValue failed: %v", err) + t.Fatalf("Register failed: %v", err) } _ = c.Start(ctx) @@ -68,9 +68,9 @@ func TestHealthCheckUnhealthyService(t *testing.T) { c := needle.New() ctx := context.Background() - err := needle.ProvideValue(c, &UnhealthyService{}) + err := needle.Register(c, needle.SpecValue(&UnhealthyService{})) if err != nil { - t.Fatalf("ProvideValue failed: %v", err) + t.Fatalf("Register failed: %v", err) } _ = c.Start(ctx) @@ -100,9 +100,9 @@ func TestReadinessCheckReadyService(t *testing.T) { c := needle.New() ctx := context.Background() - err := needle.ProvideValue(c, &ReadyService{}) + err := needle.Register(c, needle.SpecValue(&ReadyService{})) if err != nil { - t.Fatalf("ProvideValue failed: %v", err) + t.Fatalf("Register failed: %v", err) } _ = c.Start(ctx) @@ -119,9 +119,9 @@ func TestReadinessCheckNotReadyService(t *testing.T) { c := needle.New() ctx := context.Background() - err := needle.ProvideValue(c, &NotReadyService{}) + err := needle.Register(c, needle.SpecValue(&NotReadyService{})) if err != nil { - t.Fatalf("ProvideValue failed: %v", err) + t.Fatalf("Register failed: %v", err) } _ = c.Start(ctx) @@ -138,9 +138,9 @@ func TestHealthCheckNoHealthCheckers(t *testing.T) { c := needle.New() ctx := context.Background() - err := needle.ProvideValue(c, &Config{Port: 8080}) + err := needle.Register(c, needle.SpecValue(&Config{Port: 8080})) if err != nil { - t.Fatalf("ProvideValue failed: %v", err) + t.Fatalf("Register failed: %v", err) } _ = c.Start(ctx) @@ -171,9 +171,9 @@ func TestResolveObserver(t *testing.T) { }), ) - err := needle.ProvideValue(c, &Config{Port: 8080}) + err := needle.Register(c, needle.SpecValue(&Config{Port: 8080})) if err != nil { - t.Fatalf("ProvideValue failed: %v", err) + t.Fatalf("Register failed: %v", err) } _, err = needle.Invoke[*Config](c) @@ -225,8 +225,8 @@ func TestProvideObserver(t *testing.T) { }), ) - _ = needle.ProvideValue(c, &Config{Port: 8080}) - _ = needle.ProvideValue(c, &Database{Name: "test"}) + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080})) + _ = needle.Register(c, needle.SpecValue(&Database{Name: "test"})) if callCount.Load() != 2 { t.Errorf("expected 2 provide hook calls, got %d", callCount.Load()) @@ -250,7 +250,7 @@ func TestStartObserver(t *testing.T) { }), ) - _ = needle.ProvideValue(c, &Config{Port: 8080}) + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080})) ctx := context.Background() err := c.Start(ctx) @@ -274,11 +274,14 @@ func TestStopObserver(t *testing.T) { }), ) - _ = needle.Provide(c, func(ctx context.Context, r needle.Resolver) (*Server, error) { - return &Server{}, nil - }, needle.WithOnStop(func(ctx context.Context) error { - return nil - })) + _ = needle.Register(c, needle.Spec[*Server]{ + Provider: func(ctx context.Context, r needle.Resolver) (*Server, error) { + return &Server{}, nil + }, + OnStop: func(ctx context.Context) error { + return nil + }, + }) ctx := context.Background() _ = c.Start(ctx) @@ -302,7 +305,7 @@ func TestHealthReportLatency(t *testing.T) { c := needle.New() ctx := context.Background() - _ = needle.ProvideValue(c, &SlowHealthService{}) + _ = needle.Register(c, needle.SpecValue(&SlowHealthService{})) _ = c.Start(ctx) reports := c.Health(ctx) @@ -330,7 +333,7 @@ func TestMultipleObservers(t *testing.T) { }), ) - _ = needle.ProvideValue(c, &Config{Port: 8080}) + _ = needle.Register(c, needle.SpecValue(&Config{Port: 8080})) _, _ = needle.Invoke[*Config](c) if count1.Load() != 1 || count2.Load() != 1 { diff --git a/provider.go b/provider.go deleted file mode 100644 index 3cc0b92..0000000 --- a/provider.go +++ /dev/null @@ -1,158 +0,0 @@ -package needle - -import ( - "context" - - "github.com/danpasecinic/needle/internal/container" - "github.com/danpasecinic/needle/internal/reflect" - "github.com/danpasecinic/needle/internal/scope" -) - -type Provider[T any] func(ctx context.Context, r Resolver) (T, error) - -type ProviderOption func(*providerConfig) - -type providerConfig struct { - name string - dependencies []string - onStart []container.Hook - onStop []container.Hook - scope scope.Scope - poolSize int - lazy bool -} - -func Provide[T any](c *Container, provider Provider[T], opts ...ProviderOption) error { - if len(opts) == 0 { - key := reflect.TypeKey[T]() - resolver := c.resolver - wrappedProvider := func(ctx context.Context, r container.Resolver) (any, error) { - return provider(ctx, resolver) - } - return c.internal.Register(key, wrappedProvider, nil) - } - - cfg := &providerConfig{} - for _, opt := range opts { - opt(cfg) - } - - key := reflect.TypeKey[T]() - if cfg.name != "" { - key = reflect.TypeKeyNamed[T](cfg.name) - } - - resolver := c.resolver - wrappedProvider := func(ctx context.Context, r container.Resolver) (any, error) { - return provider(ctx, resolver) - } - - if err := c.internal.Register(key, wrappedProvider, cfg.dependencies); err != nil { - return err - } - - for _, hook := range cfg.onStart { - c.internal.AddOnStart(key, hook) - } - for _, hook := range cfg.onStop { - c.internal.AddOnStop(key, hook) - } - - if cfg.scope != scope.Singleton { - c.internal.SetScope(key, cfg.scope) - } - if cfg.poolSize > 0 { - c.internal.SetPoolSize(key, cfg.poolSize) - } - if cfg.lazy { - c.internal.SetLazy(key, true) - } - - return nil -} - -func ProvideValue[T any](c *Container, value T, opts ...ProviderOption) error { - cfg := &providerConfig{} - for _, opt := range opts { - opt(cfg) - } - - key := reflect.TypeKey[T]() - if cfg.name != "" { - key = reflect.TypeKeyNamed[T](cfg.name) - } - - if err := c.internal.RegisterValue(key, value); err != nil { - return err - } - - for _, hook := range cfg.onStart { - c.internal.AddOnStart(key, hook) - } - for _, hook := range cfg.onStop { - c.internal.AddOnStop(key, hook) - } - - return nil -} - -func ProvideNamed[T any](c *Container, name string, provider Provider[T], opts ...ProviderOption) error { - if len(opts) == 0 { - key := reflect.TypeKeyNamed[T](name) - resolver := c.resolver - wrappedProvider := func(ctx context.Context, r container.Resolver) (any, error) { - return provider(ctx, resolver) - } - return c.internal.Register(key, wrappedProvider, nil) - } - opts = append(opts, WithName(name)) - return Provide(c, provider, opts...) -} - -func ProvideNamedValue[T any](c *Container, name string, value T, opts ...ProviderOption) error { - opts = append(opts, WithName(name)) - return ProvideValue(c, value, opts...) -} - -func WithName(name string) ProviderOption { - return func(cfg *providerConfig) { - cfg.name = name - } -} - -func WithDependencies(deps ...string) ProviderOption { - return func(cfg *providerConfig) { - cfg.dependencies = deps - } -} - -func WithOnStart(hook Hook) ProviderOption { - return func(cfg *providerConfig) { - cfg.onStart = append(cfg.onStart, container.Hook(hook)) - } -} - -func WithOnStop(hook Hook) ProviderOption { - return func(cfg *providerConfig) { - cfg.onStop = append(cfg.onStop, container.Hook(hook)) - } -} - -func WithScope(s Scope) ProviderOption { - return func(cfg *providerConfig) { - cfg.scope = s - } -} - -func WithPoolSize(size int) ProviderOption { - return func(cfg *providerConfig) { - cfg.scope = scope.Pooled - cfg.poolSize = size - } -} - -func WithLazy() ProviderOption { - return func(cfg *providerConfig) { - cfg.lazy = true - } -} diff --git a/replace.go b/replace.go deleted file mode 100644 index 0af4102..0000000 --- a/replace.go +++ /dev/null @@ -1,123 +0,0 @@ -package needle - -import ( - "context" - - "github.com/danpasecinic/needle/internal/container" - "github.com/danpasecinic/needle/internal/reflect" -) - -func Replace[T any](c *Container, provider Provider[T], opts ...ProviderOption) error { - cfg := &providerConfig{} - for _, opt := range opts { - opt(cfg) - } - - key := reflect.TypeKey[T]() - if cfg.name != "" { - key = reflect.TypeKeyNamed[T](cfg.name) - } - - resolver := c.resolver - wrappedProvider := func(ctx context.Context, r container.Resolver) (any, error) { - return provider(ctx, resolver) - } - - if err := c.internal.Replace(key, wrappedProvider, cfg.dependencies); err != nil { - return err - } - - for _, hook := range cfg.onStart { - c.internal.AddOnStart(key, hook) - } - for _, hook := range cfg.onStop { - c.internal.AddOnStop(key, hook) - } - - if cfg.scope != 0 { - c.internal.SetScope(key, cfg.scope) - } - if cfg.poolSize > 0 { - c.internal.SetPoolSize(key, cfg.poolSize) - } - if cfg.lazy { - c.internal.SetLazy(key, true) - } - - return nil -} - -func ReplaceValue[T any](c *Container, value T, opts ...ProviderOption) error { - cfg := &providerConfig{} - for _, opt := range opts { - opt(cfg) - } - - key := reflect.TypeKey[T]() - if cfg.name != "" { - key = reflect.TypeKeyNamed[T](cfg.name) - } - - if err := c.internal.ReplaceValue(key, value); err != nil { - return err - } - - for _, hook := range cfg.onStart { - c.internal.AddOnStart(key, hook) - } - for _, hook := range cfg.onStop { - c.internal.AddOnStop(key, hook) - } - - return nil -} - -func ReplaceNamed[T any](c *Container, name string, provider Provider[T], opts ...ProviderOption) error { - opts = append(opts, WithName(name)) - return Replace(c, provider, opts...) -} - -func ReplaceNamedValue[T any](c *Container, name string, value T, opts ...ProviderOption) error { - opts = append(opts, WithName(name)) - return ReplaceValue(c, value, opts...) -} - -func MustReplace[T any](c *Container, provider Provider[T], opts ...ProviderOption) { - if err := Replace(c, provider, opts...); err != nil { - panic(err) - } -} - -func MustReplaceValue[T any](c *Container, value T, opts ...ProviderOption) { - if err := ReplaceValue(c, value, opts...); err != nil { - panic(err) - } -} - -func ReplaceFunc[T any](c *Container, constructor any, opts ...ProviderOption) error { - provider, depOpts, err := buildFuncProvider[T](c, constructor) - if err != nil { - return err - } - - opts = append(depOpts, opts...) - return Replace(c, provider, opts...) -} - -func MustReplaceFunc[T any](c *Container, constructor any, opts ...ProviderOption) { - if err := ReplaceFunc[T](c, constructor, opts...); err != nil { - panic(err) - } -} - -func ReplaceStruct[T any](c *Container, opts ...ProviderOption) error { - provider, depOpts := buildStructProvider[T](c) - opts = append(depOpts, opts...) - return Replace(c, provider, opts...) -} - -func MustReplaceStruct[T any](c *Container, opts ...ProviderOption) { - if err := ReplaceStruct[T](c, opts...); err != nil { - panic(err) - } -} diff --git a/replace_test.go b/replace_test.go index 84780d3..556bec2 100644 --- a/replace_test.go +++ b/replace_test.go @@ -20,7 +20,7 @@ func TestReplace(t *testing.T) { "replaces existing provider", func(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &ReplaceConfig{Value: "original"}) + _ = needle.Register(c, needle.SpecValue(&ReplaceConfig{Value: "original"})) cfg, err := needle.Invoke[*ReplaceConfig](c) if err != nil { @@ -30,7 +30,7 @@ func TestReplace(t *testing.T) { t.Errorf("expected 'original', got '%s'", cfg.Value) } - _ = needle.ReplaceValue(c, &ReplaceConfig{Value: "replaced"}) + _ = needle.Replace(c, needle.SpecValue(&ReplaceConfig{Value: "replaced"})) cfg, err = needle.Invoke[*ReplaceConfig](c) if err != nil { @@ -46,27 +46,27 @@ func TestReplace(t *testing.T) { "replaces provider with dependencies", func(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &ReplaceConfig{Value: "v1"}) - _ = needle.Provide( - c, func(ctx context.Context, r needle.Resolver) (*ReplaceService, error) { + _ = needle.Register(c, needle.SpecValue(&ReplaceConfig{Value: "v1"})) + _ = needle.Register(c, needle.Spec[*ReplaceService]{ + Provider: func(ctx context.Context, r needle.Resolver) (*ReplaceService, error) { cfg := needle.MustInvoke[*ReplaceConfig](c) return &ReplaceService{Config: cfg}, nil }, - ) + }) svc := needle.MustInvoke[*ReplaceService](c) if svc.Config.Value != "v1" { t.Errorf("expected 'v1', got '%s'", svc.Config.Value) } - _ = needle.ReplaceValue(c, &ReplaceConfig{Value: "v2"}) + _ = needle.Replace(c, needle.SpecValue(&ReplaceConfig{Value: "v2"})) - _ = needle.Replace( - c, func(ctx context.Context, r needle.Resolver) (*ReplaceService, error) { + _ = needle.Replace(c, needle.Spec[*ReplaceService]{ + Provider: func(ctx context.Context, r needle.Resolver) (*ReplaceService, error) { cfg := needle.MustInvoke[*ReplaceConfig](c) return &ReplaceService{Config: cfg}, nil }, - ) + }) svc = needle.MustInvoke[*ReplaceService](c) if svc.Config.Value != "v2" { @@ -79,7 +79,7 @@ func TestReplace(t *testing.T) { "replace non-existent service creates it", func(t *testing.T) { c := needle.New() - _ = needle.ReplaceValue(c, &ReplaceConfig{Value: "new"}) + _ = needle.Replace(c, needle.SpecValue(&ReplaceConfig{Value: "new"})) cfg, err := needle.Invoke[*ReplaceConfig](c) if err != nil { @@ -97,7 +97,7 @@ func TestReplaceNamed(t *testing.T) { "replaces named provider", func(t *testing.T) { c := needle.New() - _ = needle.ProvideNamedValue(c, "primary", &ReplaceConfig{Value: "orig"}) + _ = needle.Register(c, needle.SpecValue(&ReplaceConfig{Value: "orig"}).WithName("primary")) cfg, err := needle.InvokeNamed[*ReplaceConfig](c, "primary") if err != nil { @@ -107,7 +107,7 @@ func TestReplaceNamed(t *testing.T) { t.Errorf("expected 'orig', got '%s'", cfg.Value) } - _ = needle.ReplaceNamedValue(c, "primary", &ReplaceConfig{Value: "new"}) + _ = needle.Replace(c, needle.SpecValue(&ReplaceConfig{Value: "new"}).WithName("primary")) cfg, err = needle.InvokeNamed[*ReplaceConfig](c, "primary") if err != nil { @@ -125,9 +125,9 @@ func TestMustReplace(t *testing.T) { "does not panic on valid replace", func(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &ReplaceConfig{Value: "original"}) + _ = needle.Register(c, needle.SpecValue(&ReplaceConfig{Value: "original"})) - needle.MustReplaceValue(c, &ReplaceConfig{Value: "replaced"}) + needle.MustReplace(c, needle.SpecValue(&ReplaceConfig{Value: "replaced"})) cfg := needle.MustInvoke[*ReplaceConfig](c) if cfg.Value != "replaced" { @@ -142,14 +142,14 @@ func TestReplaceWithOptions(t *testing.T) { "replaces with scope option", func(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &ReplaceConfig{Value: "singleton"}) + _ = needle.Register(c, needle.SpecValue(&ReplaceConfig{Value: "singleton"})) - _ = needle.Replace( - c, func(ctx context.Context, r needle.Resolver) (*ReplaceConfig, error) { + _ = needle.Replace(c, needle.Spec[*ReplaceConfig]{ + Provider: func(ctx context.Context, r needle.Resolver) (*ReplaceConfig, error) { return &ReplaceConfig{Value: "transient"}, nil }, - needle.WithScope(needle.Transient), - ) + Scope: needle.Transient, + }) cfg1 := needle.MustInvoke[*ReplaceConfig](c) cfg2 := needle.MustInvoke[*ReplaceConfig](c) @@ -165,21 +165,21 @@ func NewReplaceService(cfg *ReplaceConfig) *ReplaceService { return &ReplaceService{Config: cfg} } -func TestReplaceFunc(t *testing.T) { +func TestReplaceConstructor(t *testing.T) { t.Run( "replaces with auto-wired constructor", func(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &ReplaceConfig{Value: "v1"}) - _ = needle.ProvideFunc[*ReplaceService](c, NewReplaceService) + _ = needle.Register(c, needle.SpecValue(&ReplaceConfig{Value: "v1"})) + _ = needle.Register(c, needle.SpecFromConstructor[*ReplaceService](NewReplaceService)) svc := needle.MustInvoke[*ReplaceService](c) if svc.Config.Value != "v1" { t.Errorf("expected 'v1', got '%s'", svc.Config.Value) } - _ = needle.ReplaceValue(c, &ReplaceConfig{Value: "v2"}) - _ = needle.ReplaceFunc[*ReplaceService](c, NewReplaceService) + _ = needle.Replace(c, needle.SpecValue(&ReplaceConfig{Value: "v2"})) + _ = needle.Replace(c, needle.SpecFromConstructor[*ReplaceService](NewReplaceService)) svc = needle.MustInvoke[*ReplaceService](c) if svc.Config.Value != "v2" { @@ -198,16 +198,16 @@ func TestReplaceStruct(t *testing.T) { "replaces with struct injection", func(t *testing.T) { c := needle.New() - _ = needle.ProvideValue(c, &ReplaceConfig{Value: "original"}) - _ = needle.ProvideStruct[*ReplaceStructService](c) + _ = needle.Register(c, needle.SpecValue(&ReplaceConfig{Value: "original"})) + _ = needle.Register(c, needle.SpecFromStruct[*ReplaceStructService]()) svc := needle.MustInvoke[*ReplaceStructService](c) if svc.Config.Value != "original" { t.Errorf("expected 'original', got '%s'", svc.Config.Value) } - _ = needle.ReplaceValue(c, &ReplaceConfig{Value: "replaced"}) - _ = needle.ReplaceStruct[*ReplaceStructService](c) + _ = needle.Replace(c, needle.SpecValue(&ReplaceConfig{Value: "replaced"})) + _ = needle.Replace(c, needle.SpecFromStruct[*ReplaceStructService]()) svc = needle.MustInvoke[*ReplaceStructService](c) if svc.Config.Value != "replaced" { diff --git a/scope_test.go b/scope_test.go index b7e6250..37b143a 100644 --- a/scope_test.go +++ b/scope_test.go @@ -13,12 +13,12 @@ func TestScope_Singleton(t *testing.T) { var callCount atomic.Int32 - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testCounter, error) { + _ = Register(c, Spec[*testCounter]{ + Provider: func(ctx context.Context, r Resolver) (*testCounter, error) { callCount.Add(1) return &testCounter{id: int(callCount.Load())}, nil }, - ) + }) first, _ := Invoke[*testCounter](c) second, _ := Invoke[*testCounter](c) @@ -40,12 +40,13 @@ func TestScope_Transient(t *testing.T) { var callCount atomic.Int32 - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testCounter, error) { + _ = Register(c, Spec[*testCounter]{ + Provider: func(ctx context.Context, r Resolver) (*testCounter, error) { callCount.Add(1) return &testCounter{id: int(callCount.Load())}, nil - }, WithScope(Transient), - ) + }, + Scope: Transient, + }) ctx := context.Background() @@ -73,12 +74,13 @@ func TestScope_Request(t *testing.T) { var callCount atomic.Int32 - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testCounter, error) { + _ = Register(c, Spec[*testCounter]{ + Provider: func(ctx context.Context, r Resolver) (*testCounter, error) { callCount.Add(1) return &testCounter{id: int(callCount.Load())}, nil - }, WithScope(Request), - ) + }, + Scope: Request, + }) ctx1 := WithRequestScope(context.Background()) ctx2 := WithRequestScope(context.Background()) @@ -111,11 +113,12 @@ func TestScope_Request_NoScope(t *testing.T) { c := New() - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testCounter, error) { + _ = Register(c, Spec[*testCounter]{ + Provider: func(ctx context.Context, r Resolver) (*testCounter, error) { return &testCounter{id: 1}, nil - }, WithScope(Request), - ) + }, + Scope: Request, + }) ctx := context.Background() @@ -132,12 +135,14 @@ func TestScope_Pooled(t *testing.T) { var callCount atomic.Int32 - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testCounter, error) { + _ = Register(c, Spec[*testCounter]{ + Provider: func(ctx context.Context, r Resolver) (*testCounter, error) { callCount.Add(1) return &testCounter{id: int(callCount.Load())}, nil - }, WithPoolSize(2), - ) + }, + Scope: Pooled, + PoolSize: 2, + }) ctx := context.Background() @@ -174,12 +179,14 @@ func TestScope_Pooled_Overflow(t *testing.T) { var callCount atomic.Int32 - _ = Provide( - c, func(ctx context.Context, r Resolver) (*testCounter, error) { + _ = Register(c, Spec[*testCounter]{ + Provider: func(ctx context.Context, r Resolver) (*testCounter, error) { callCount.Add(1) return &testCounter{id: int(callCount.Load())}, nil - }, WithPoolSize(1), - ) + }, + Scope: Pooled, + PoolSize: 1, + }) ctx := context.Background() diff --git a/spec.go b/spec.go new file mode 100644 index 0000000..f8c1242 --- /dev/null +++ b/spec.go @@ -0,0 +1,220 @@ +package needle + +import ( + "context" + "errors" + "fmt" + + "github.com/danpasecinic/needle/internal/container" + "github.com/danpasecinic/needle/internal/reflect" +) + +type Provider[T any] func(ctx context.Context, r Resolver) (T, error) + +type Spec[T any] struct { + Name string + Provider Provider[T] + Dependencies []string + Scope Scope + OnStart Hook + OnStop Hook + PoolSize int + Lazy bool + + value T + hasValue bool +} + +func SpecValue[T any](v T) Spec[T] { + return Spec[T]{value: v, hasValue: true} +} + +func (s Spec[T]) WithName(name string) Spec[T] { + s.Name = name + return s +} + +func (s Spec[T]) WithDependencies(deps ...string) Spec[T] { + s.Dependencies = deps + return s +} + +func (s Spec[T]) WithScope(sc Scope) Spec[T] { + s.Scope = sc + return s +} + +func (s Spec[T]) WithPoolSize(n int) Spec[T] { + s.Scope = Pooled + s.PoolSize = n + return s +} + +func (s Spec[T]) WithLazy() Spec[T] { + s.Lazy = true + return s +} + +func (s Spec[T]) WithOnStart(hook Hook) Spec[T] { + s.OnStart = hook + return s +} + +func (s Spec[T]) WithOnStop(hook Hook) Spec[T] { + s.OnStop = hook + return s +} + +func Compose(hooks ...Hook) Hook { + switch len(hooks) { + case 0: + return nil + case 1: + return hooks[0] + } + return func(ctx context.Context) error { + for _, h := range hooks { + if h == nil { + continue + } + if err := h(ctx); err != nil { + return err + } + } + return nil + } +} + +func Register[T any](c *Container, spec Spec[T]) error { + entry, err := buildEntry(c, spec) + if err != nil { + return err + } + return c.internal.Register(entry) +} + +func MustRegister[T any](c *Container, spec Spec[T]) { + if err := Register(c, spec); err != nil { + panic(err) + } +} + +func Replace[T any](c *Container, spec Spec[T]) error { + entry, err := buildEntry(c, spec) + if err != nil { + return err + } + return c.internal.Replace(entry) +} + +func MustReplace[T any](c *Container, spec Spec[T]) { + if err := Replace(c, spec); err != nil { + panic(err) + } +} + +var ( + errSpecBothProviderAndValue = errors.New("spec has both Provider and Value set; exactly one must be provided") + errSpecNeitherProviderNorValue = errors.New("spec has neither Provider nor Value set; exactly one must be provided") + errSpecPoolSizeWithNonPooledScope = errors.New("spec has PoolSize > 0 but Scope is not Pooled; use WithPoolSize or set Scope = Pooled") + errSpecValueWithLazy = errors.New("spec has a Value and Lazy = true; values are eagerly bound and cannot be lazy") + errSpecValueWithNonSingletonScope = errors.New("spec has a Value with non-singleton Scope; values are inherently singleton") +) + +func buildEntry[T any](c *Container, spec Spec[T]) (*container.ServiceEntry, error) { + if spec.Provider != nil && spec.hasValue { + return nil, errSpecBothProviderAndValue + } + if spec.Provider == nil && !spec.hasValue { + return nil, errSpecNeitherProviderNorValue + } + if spec.hasValue { + if spec.Lazy { + return nil, errSpecValueWithLazy + } + if spec.Scope != Singleton { + return nil, errSpecValueWithNonSingletonScope + } + if spec.PoolSize > 0 { + return nil, errSpecPoolSizeWithNonPooledScope + } + } + if spec.PoolSize > 0 && spec.Scope != Pooled { + return nil, errSpecPoolSizeWithNonPooledScope + } + + key := reflect.TypeKey[T]() + if spec.Name != "" { + key = reflect.TypeKeyNamed[T](spec.Name) + } + + cfg := container.EntryConfig{ + Key: key, + Dependencies: spec.Dependencies, + Scope: spec.Scope, + PoolSize: spec.PoolSize, + Lazy: spec.Lazy, + OnStart: container.Hook(spec.OnStart), + OnStop: container.Hook(spec.OnStop), + } + + if spec.hasValue { + cfg.Value = spec.value + cfg.HasValue = true + } else { + provider := spec.Provider + resolver := c.resolver + cfg.Provider = func(ctx context.Context, _ container.Resolver) (any, error) { + return provider(ctx, resolver) + } + } + + return container.NewServiceEntry(cfg), nil +} + +// SpecFromConstructor builds a spec whose Provider auto-resolves the constructor's +// parameters from the resolver and calls the constructor. The constructor must +// return T (and optionally an error as the second return value). Panics if the +// constructor signature is invalid – a programmer error caught at startup. +func SpecFromConstructor[T any](constructor any) Spec[T] { + provider, deps, err := buildFuncProvider[T](constructor) + if err != nil { + panic(fmt.Errorf("needle: SpecFromConstructor: %w", err)) + } + return Spec[T]{ + Provider: provider, + Dependencies: deps, + } +} + +// SpecFromStruct builds a spec whose Provider populates a T (or *T) struct from +// fields tagged `needle:"..."`. Optional fields are silently skipped if absent. +func SpecFromStruct[T any]() Spec[T] { + provider, deps := buildStructProvider[T]() + return Spec[T]{ + Provider: provider, + Dependencies: deps, + } +} + +// SpecFromBinding builds a spec for interface I that resolves implementation T +// from the container and returns it as I. Use to wire an interface to a concrete +// type already registered under T's key. +func SpecFromBinding[I, T any]() Spec[I] { + implKey := reflect.TypeKey[T]() + return Spec[I]{ + Provider: func(ctx context.Context, r Resolver) (I, error) { + var zero I + instance, err := r.Resolve(ctx, implKey) + if err != nil { + return zero, err + } + typed, ok := instance.(I) + if !ok { + return zero, fmt.Errorf("binding type mismatch: %s does not implement %s", reflect.TypeName[T](), reflect.TypeName[I]()) + } + return typed, nil + }, + Dependencies: []string{implKey}, + } +}