From e9028a1f7193a40bbe5193417bf0f2da977ba984 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 13:06:57 -0400 Subject: [PATCH 01/24] feat(typed): generic, type-safe client and query builder Add a generic typed layer over modusgraph.Client: typed.Client[T] with CRUD and iterators; a fluent Query[T] builder (filters, ordering, paging, edge traversal, IterNodes); MultiQuery for N homogeneous blocks in one round-trip; functional options; a filter DSL (typed/filter); and ordered result merging (typed/search). A small no-op-by-default Tracer seam (typed.SetTracer) lets a host plug in tracing without the typed package depending on any telemetry library. Self-contained: builds and tests against the current client with no other changes. --- typed/client.go | 87 +++ typed/client_test.go | 209 ++++++ typed/filter/filter.go | 118 +++ typed/filter/filter_test.go | 118 +++ typed/filter/fulltext.go | 21 + typed/filter/fulltext_test.go | 41 ++ typed/multi_query.go | 191 +++++ typed/multi_query_test.go | 127 ++++ typed/option.go | 17 + typed/option_test.go | 37 + typed/query.go | 565 ++++++++++++++ typed/query_test.go | 1294 +++++++++++++++++++++++++++++++++ typed/search/merge.go | 27 + typed/search/merge_test.go | 86 +++ typed/tracing.go | 58 ++ typed/tracing_test.go | 47 ++ 16 files changed, 3043 insertions(+) create mode 100644 typed/client.go create mode 100644 typed/client_test.go create mode 100644 typed/filter/filter.go create mode 100644 typed/filter/filter_test.go create mode 100644 typed/filter/fulltext.go create mode 100644 typed/filter/fulltext_test.go create mode 100644 typed/multi_query.go create mode 100644 typed/multi_query_test.go create mode 100644 typed/option.go create mode 100644 typed/option_test.go create mode 100644 typed/query.go create mode 100644 typed/query_test.go create mode 100644 typed/search/merge.go create mode 100644 typed/search/merge_test.go create mode 100644 typed/tracing.go create mode 100644 typed/tracing_test.go diff --git a/typed/client.go b/typed/client.go new file mode 100644 index 0000000..c540f89 --- /dev/null +++ b/typed/client.go @@ -0,0 +1,87 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Package typed binds a Go type to the otherwise any-typed modusgraph.Client, +// providing generic, type-safe CRUD and query operations without per-entity +// code generation. It is the handwritten substrate that modusgraph-gen's +// generated clients compose over. +package typed + +import ( + "context" + "iter" + + "github.com/matthewmcneely/modusgraph" +) + +// Client provides type-safe CRUD and query operations over records of type T. +// T is the schema struct (for example schema.Actor); modusgraph reflects over +// the struct's dgraph/json tags, so T needs no constraint. +type Client[T any] struct { + conn modusgraph.Client +} + +// NewClient binds a Client[T] to conn. +func NewClient[T any](conn modusgraph.Client) *Client[T] { + return &Client[T]{conn: conn} +} + +// Get loads the T with the given UID. +func (c *Client[T]) Get(ctx context.Context, uid string) (rec *T, err error) { + ctx, span := tracer.StartSpan(ctx, "get", entityName[T]()) + defer func() { span.End(err) }() + var out T + if err = c.conn.Get(ctx, &out, uid); err != nil { + return nil, err + } + return &out, nil +} + +// Add inserts a new T. modusgraph writes the assigned UID back into rec. +func (c *Client[T]) Add(ctx context.Context, rec *T) (err error) { + ctx, span := tracer.StartSpan(ctx, "add", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Insert(ctx, rec) +} + +// Update modifies an existing T (must have its UID set). +func (c *Client[T]) Update(ctx context.Context, rec *T) (err error) { + ctx, span := tracer.StartSpan(ctx, "update", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Update(ctx, rec) +} + +// Upsert inserts or updates rec, matching against predicates. With no +// predicates, the first field tagged dgraph:"upsert" is used. +func (c *Client[T]) Upsert(ctx context.Context, rec *T, predicates ...string) (err error) { + ctx, span := tracer.StartSpan(ctx, "upsert", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Upsert(ctx, rec, predicates...) +} + +// Delete removes the T with the given UID. +func (c *Client[T]) Delete(ctx context.Context, uid string) (err error) { + ctx, span := tracer.StartSpan(ctx, "delete", entityName[T]()) + defer func() { span.End(err) }() + return c.conn.Delete(ctx, []string{uid}) +} + +// Query returns a typed query builder for T. conn and ctx are carried so the +// builder can run a WhereEdge pre-pass (see Query.WhereEdge) if one is needed. +func (c *Client[T]) Query(ctx context.Context) *Query[T] { + var z T + return &Query[T]{q: c.conn.Query(ctx, &z), conn: c.conn, ctx: ctx} +} + +// defaultPageSize is the page size IterNodes uses to page through results. +const defaultPageSize = 50 + +// Iter returns an iterator over every T, paging transparently so large result +// sets are not materialized at once. It yields each record in turn; on error +// it yields a final (nil, err) and stops. All pages execute against one +// read-only transaction, so the iteration reads a single consistent snapshot. +func (c *Client[T]) Iter(ctx context.Context) iter.Seq2[*T, error] { + return c.Query(ctx).IterNodes() +} diff --git a/typed/client_test.go b/typed/client_test.go new file mode 100644 index 0000000..6fa2b1d --- /dev/null +++ b/typed/client_test.go @@ -0,0 +1,209 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "testing" + + "github.com/matthewmcneely/modusgraph" + "github.com/matthewmcneely/modusgraph/typed" +) + +// widget is a minimal schema struct used to exercise the typed package. +type widget struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` + Qty int `json:"qty,omitempty" dgraph:"index=int"` +} + +// owner and pet exercise Query.WhereEdge: owner has an outbound "pets" edge to +// pet, and pet's Name carries an index so eq(name, ...) resolves inside an edge +// filter. The pair is the typed-package analogue of the Person/Dog example in +// docs/specs/2026-05-21-query-edge-filter-design.md. +type owner struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` + Pets []*pet `json:"pets,omitempty"` +} + +type pet struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` +} + +// newConn builds a local file-backed modusgraph client for a test. +func newConn(t *testing.T) modusgraph.Client { + t.Helper() + conn, err := modusgraph.NewClient("file://"+t.TempDir(), modusgraph.WithAutoSchema(true)) + if err != nil { + t.Fatalf("modusgraph.NewClient: %v", err) + } + t.Cleanup(conn.Close) + return conn +} + +func TestClient_AddPopulatesUIDAndGetReadsBack(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "sprocket", Qty: 3} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + if w.UID == "" { + t.Fatal("Add did not populate UID on the passed struct") + } + + got, err := c.Get(ctx, w.UID) + if err != nil { + t.Fatalf("Get: %v", err) + } + if got.Name != "sprocket" || got.Qty != 3 { + t.Fatalf("Get returned %+v, want Name=sprocket Qty=3", got) + } +} + +func TestClient_Update(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "gear", Qty: 1} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + w.Qty = 99 + if err := c.Update(ctx, w); err != nil { + t.Fatalf("Update: %v", err) + } + + got, err := c.Get(ctx, w.UID) + if err != nil { + t.Fatalf("Get: %v", err) + } + if got.Qty != 99 { + t.Fatalf("Update did not persist; Qty = %d, want 99", got.Qty) + } +} + +func TestClient_Delete(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "bolt"} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + if err := c.Delete(ctx, w.UID); err != nil { + t.Fatalf("Delete: %v", err) + } + if _, err := c.Get(ctx, w.UID); err == nil { + t.Fatal("Get after Delete returned no error; expected not-found") + } +} + +func TestClient_IterPagesThroughAllRecords(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // 125 is deliberately larger than the package's 50-record page size, so + // a correct Iter must fetch more than one page. + const n = 125 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + seen := 0 + for w, err := range c.Iter(ctx) { + if err != nil { + t.Fatalf("Iter yielded error: %v", err) + } + if w == nil { + t.Fatal("Iter yielded a nil widget") + } + seen++ + } + if seen != n { + t.Fatalf("Iter yielded %d records, want %d", seen, n) + } +} + +// gadget is a dedicated upsert struct. It must not be the shared widget, because +// widget is used in tests that insert many records with duplicate Name values; +// adding a "upsert" directive to widget.Name would cause those inserts to +// collide and break unrelated tests. +type gadget struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Label string `json:"label,omitempty" dgraph:"index=exact upsert"` + Stock int `json:"stock,omitempty" dgraph:"index=int"` +} + +func TestClient_Upsert(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[gadget](newConn(t)) + + // First call — creates the record. + g := &gadget{Label: "sprocket", Stock: 10} + if err := c.Upsert(ctx, g, "label"); err != nil { + t.Fatalf("Upsert (create): %v", err) + } + if g.UID == "" { + t.Fatal("Upsert (create) did not populate UID") + } + + // Second call — same Label value, different Stock. Must UPDATE, not insert. + g2 := &gadget{Label: "sprocket", Stock: 99} + if err := c.Upsert(ctx, g2, "label"); err != nil { + t.Fatalf("Upsert (update): %v", err) + } + + // Exactly one record must exist and it must carry the updated Stock. + nodes, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("Query after Upsert: %v", err) + } + if len(nodes) != 1 { + t.Fatalf("got %d gadgets after two upserts on the same label, want 1", len(nodes)) + } + if nodes[0].Stock != 99 { + t.Fatalf("upserted gadget Stock = %d, want 99", nodes[0].Stock) + } +} + +func TestClient_IterStopsOnConsumerBreak(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + const n = 125 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + seen := 0 + for w, err := range c.Iter(ctx) { + if err != nil { + t.Fatalf("Iter yielded error: %v", err) + } + if w == nil { + t.Fatal("Iter yielded a nil widget") + } + seen++ + if seen == 10 { + break + } + } + if seen != 10 { + t.Fatalf("Iter yielded %d records after break at 10, want 10", seen) + } +} diff --git a/typed/filter/filter.go b/typed/filter/filter.go new file mode 100644 index 0000000..d67f118 --- /dev/null +++ b/typed/filter/filter.go @@ -0,0 +1,118 @@ +// Package filter provides typed values and a parameterised expression builder +// for composing dgraph @filter clauses on generated Query types. +// +// Generated By methods accept []UUID or []String and feed them into +// Builder.EqGroupUUID / Builder.EqGroupString. Consumers can also build +// custom expressions directly with Builder for cases the generator does not +// cover (multi-predicate joins, non-equality operators, domain defaults). +package filter + +import ( + "fmt" + "strings" +) + +// UUID is one UUID-valued filter term, optionally negated. A leading "!" in +// the parsed source negates the term ("!abc" becomes {Negated: true, Value: "abc"}). +type UUID struct { + Negated bool + Value string +} + +// String is one string-valued filter term, optionally negated. +type String struct { + Negated bool + Value string +} + +// ParseUUID parses "value" or "!value" into a UUID. +func ParseUUID(s string) UUID { + neg, v := parseNegation(s) + return UUID{Negated: neg, Value: v} +} + +// ParseString parses "value" or "!value" into a String. +func ParseString(s string) String { + neg, v := parseNegation(s) + return String{Negated: neg, Value: v} +} + +func parseNegation(s string) (bool, string) { + if strings.HasPrefix(s, "!") { + return true, s[1:] + } + return false, s +} + +// term is one predicate-agnostic value used by Builder. +type term struct { + value string + negated bool +} + +// Builder composes parameterised DQL @filter expressions. Terms within an +// EqGroup join with OR; groups join with AND. Required terms become their own +// single-term group. The output is the (expression, positional params) pair +// that typed.Query[T].Filter consumes. +type Builder struct { + groups []string + params []any +} + +func (b *Builder) param(v any) string { + b.params = append(b.params, v) + return fmt.Sprintf("$%d", len(b.params)) +} + +// EqGroupUUID adds an OR-group of eq(predicate, value) terms for one +// UUID-typed predicate. An empty terms slice is a no-op. +func (b *Builder) EqGroupUUID(predicate string, terms []UUID) { + if len(terms) == 0 { + return + } + tg := make([]term, 0, len(terms)) + for _, t := range terms { + tg = append(tg, term{value: t.Value, negated: t.Negated}) + } + b.addEqGroup(predicate, tg) +} + +// EqGroupString adds an OR-group of eq(predicate, value) terms for one +// string-typed predicate. +func (b *Builder) EqGroupString(predicate string, terms []String) { + if len(terms) == 0 { + return + } + tg := make([]term, 0, len(terms)) + for _, t := range terms { + tg = append(tg, term{value: t.Value, negated: t.Negated}) + } + b.addEqGroup(predicate, tg) +} + +func (b *Builder) addEqGroup(predicate string, terms []term) { + parts := make([]string, 0, len(terms)) + for _, t := range terms { + eq := fmt.Sprintf("eq(%s, %s)", predicate, b.param(t.value)) + if t.negated { + eq = "NOT " + eq + } + parts = append(parts, eq) + } + b.groups = append(b.groups, "("+strings.Join(parts, " OR ")+")") +} + +// RequiredEq adds a single mandatory eq(predicate, value) term (its own group). +func (b *Builder) RequiredEq(predicate, value string) { + b.groups = append(b.groups, fmt.Sprintf("eq(%s, %s)", predicate, b.param(value))) +} + +// Build returns the combined DQL filter expression and its parameters. When +// no groups were added it returns ("", nil) — callers should skip the +// .Filter() call entirely in that case. +func (b *Builder) Build() (string, []any) { + if len(b.groups) == 0 { + return "", nil + } + return strings.Join(b.groups, " AND "), b.params +} diff --git a/typed/filter/filter_test.go b/typed/filter/filter_test.go new file mode 100644 index 0000000..864a554 --- /dev/null +++ b/typed/filter/filter_test.go @@ -0,0 +1,118 @@ +package filter_test + +import ( + "strings" + "testing" + + "github.com/matthewmcneely/modusgraph/typed/filter" +) + +func TestParseUUID(t *testing.T) { + tests := []struct { + name string + in string + want filter.UUID + }{ + {"plain", "abc", filter.UUID{Value: "abc"}}, + {"negated", "!abc", filter.UUID{Negated: true, Value: "abc"}}, + {"empty", "", filter.UUID{}}, + {"just bang", "!", filter.UUID{Negated: true}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := filter.ParseUUID(tt.in) + if got != tt.want { + t.Errorf("ParseUUID(%q) = %+v, want %+v", tt.in, got, tt.want) + } + }) + } +} + +func TestParseString(t *testing.T) { + got := filter.ParseString("!hello") + want := filter.String{Negated: true, Value: "hello"} + if got != want { + t.Errorf("ParseString = %+v, want %+v", got, want) + } +} + +func TestBuilder_Empty(t *testing.T) { + var b filter.Builder + expr, params := b.Build() + if expr != "" || params != nil { + t.Errorf("empty Build = (%q, %v), want (\"\", nil)", expr, params) + } +} + +func TestBuilder_EqGroupUUID_SingleTerm(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}}) + expr, params := b.Build() + want := "(eq(id, $1))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 1 || params[0] != "u1" { + t.Errorf("params = %v, want [u1]", params) + } +} + +func TestBuilder_EqGroupUUID_MultipleTermsJoinWithOR(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}, {Value: "u2"}, {Negated: true, Value: "u3"}}) + expr, params := b.Build() + want := "(eq(id, $1) OR eq(id, $2) OR NOT eq(id, $3))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 3 { + t.Errorf("len(params) = %d, want 3", len(params)) + } +} + +func TestBuilder_EqGroupString_NoTermsIsNoop(t *testing.T) { + var b filter.Builder + b.EqGroupString("name", nil) + expr, _ := b.Build() + if expr != "" { + t.Errorf("empty EqGroupString should be no-op, got expr=%q", expr) + } +} + +func TestBuilder_MultipleGroupsJoinWithAND(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}}) + b.EqGroupString("name", []filter.String{{Value: "Alice"}}) + expr, params := b.Build() + want := "(eq(id, $1)) AND (eq(name, $2))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 2 || params[0] != "u1" || params[1] != "Alice" { + t.Errorf("params = %v, want [u1 Alice]", params) + } +} + +func TestBuilder_RequiredEqIsOwnGroup(t *testing.T) { + var b filter.Builder + b.RequiredEq("archiveStatus", "Active") + b.EqGroupUUID("id", []filter.UUID{{Value: "u1"}}) + expr, params := b.Build() + want := "eq(archiveStatus, $1) AND (eq(id, $2))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + if len(params) != 2 { + t.Errorf("len(params) = %d, want 2", len(params)) + } +} + +func TestBuilder_PositionalParamsAreSequential(t *testing.T) { + var b filter.Builder + b.EqGroupUUID("id", []filter.UUID{{Value: "a"}, {Value: "b"}}) + b.EqGroupString("name", []filter.String{{Value: "c"}}) + expr, _ := b.Build() + if !strings.Contains(expr, "$1") || !strings.Contains(expr, "$2") || !strings.Contains(expr, "$3") { + t.Errorf("expected $1, $2, $3 in expr; got %q", expr) + } +} diff --git a/typed/filter/fulltext.go b/typed/filter/fulltext.go new file mode 100644 index 0000000..a025ef0 --- /dev/null +++ b/typed/filter/fulltext.go @@ -0,0 +1,21 @@ +package filter + +import "fmt" + +// AnyOfText adds a fulltext OR-match group: anyoftext(predicate, term). +// An empty term is a no-op. +func (b *Builder) AnyOfText(predicate, term string) { + if term == "" { + return + } + b.groups = append(b.groups, fmt.Sprintf("anyoftext(%s, %s)", predicate, b.param(term))) +} + +// AllOfText adds a fulltext AND-match group: alloftext(predicate, term). +// An empty term is a no-op. +func (b *Builder) AllOfText(predicate, term string) { + if term == "" { + return + } + b.groups = append(b.groups, fmt.Sprintf("alloftext(%s, %s)", predicate, b.param(term))) +} diff --git a/typed/filter/fulltext_test.go b/typed/filter/fulltext_test.go new file mode 100644 index 0000000..1d71e0b --- /dev/null +++ b/typed/filter/fulltext_test.go @@ -0,0 +1,41 @@ +package filter_test + +import ( + "strings" + "testing" + + "github.com/matthewmcneely/modusgraph/typed/filter" +) + +func TestAnyOfTextEmitsFilterAndBindsParam(t *testing.T) { + b := &filter.Builder{} + b.AnyOfText("resourceName", "honda civic") + expr, params := b.Build() + if !strings.Contains(expr, "anyoftext(resourceName, $1)") { + t.Fatalf("expected anyoftext(resourceName, $1) in expr, got %q", expr) + } + if len(params) != 1 || params[0] != "honda civic" { + t.Fatalf("expected params [\"honda civic\"], got %v", params) + } +} + +func TestAllOfTextEmitsFilterAndBindsParam(t *testing.T) { + b := &filter.Builder{} + b.AllOfText("description", "engine block") + expr, params := b.Build() + if !strings.Contains(expr, "alloftext(description, $1)") { + t.Fatalf("expected alloftext(description, $1) in expr, got %q", expr) + } + if len(params) != 1 || params[0] != "engine block" { + t.Fatalf("expected params [\"engine block\"], got %v", params) + } +} + +func TestAnyOfTextEmptyTermIsNoop(t *testing.T) { + b := &filter.Builder{} + b.AnyOfText("resourceName", "") + expr, params := b.Build() + if expr != "" || params != nil { + t.Fatalf("expected empty expr/params for empty term, got %q / %v", expr, params) + } +} diff --git a/typed/multi_query.go b/typed/multi_query.go new file mode 100644 index 0000000..98409c6 --- /dev/null +++ b/typed/multi_query.go @@ -0,0 +1,191 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + "strings" + + dg "github.com/dolan-in/dgman/v2" + "github.com/matthewmcneely/modusgraph" +) + +// MultiQuery batches N homogeneous-type Query[T] blocks into a single +// Dgraph multi-block request. All blocks return rows of the same T; the +// per-block result is keyed by the block name supplied at Add. +// +// Dgraph executes the blocks concurrently on the server side; the entire +// batch costs one gRPC round-trip. +type MultiQuery[T any] struct { + conn modusgraph.Client + names []string + blocks map[string]*Query[T] +} + +// NewMultiQuery constructs a MultiQuery bound to conn. +func NewMultiQuery[T any](conn modusgraph.Client) *MultiQuery[T] { + return &MultiQuery[T]{ + conn: conn, + blocks: make(map[string]*Query[T]), + } +} + +// Add registers a named block. Names must be unique within one MultiQuery. +// Panics on duplicate name — the call site is a programming error, not a +// runtime condition. +func (mq *MultiQuery[T]) Add(name string, q *Query[T]) *MultiQuery[T] { + if _, exists := mq.blocks[name]; exists { + panic(fmt.Sprintf("multi_query: duplicate block name %q", name)) + } + mq.names = append(mq.names, name) + mq.blocks[name] = q + return mq +} + +// BlockNames returns the registered block names in insertion order. +func (mq *MultiQuery[T]) BlockNames() []string { + out := make([]string, len(mq.names)) + copy(out, mq.names) + return out +} + +// Execute runs every registered block in a single Dgraph round-trip and +// returns the per-block results, keyed by the block name supplied at Add. +// A block that matched no rows appears as an empty (non-nil) slice in the +// result map; the key is always present. +// +// Execute rejects blocks that carry WhereEdge constraints — those require a +// runtime pre-pass that cannot be folded into the multi-block batch. Run such +// queries individually with Query.Nodes. +// +// Dgraph keys response JSON by predicate name (e.g. resourceName), but Go +// structs typically use their json tag (e.g. name). Execute remaps the keys +// per T's tags before decoding so a schema that uses `dgraph:"predicate=..."` +// with a divergent `json:"..."` decodes correctly — matching the behavior of +// dgman's QueryBlock.Scan path. +func (mq *MultiQuery[T]) Execute(ctx context.Context) (map[string][]T, error) { + if len(mq.names) == 0 { + return map[string][]T{}, nil + } + + rawBlocks := make([]*dg.Query, 0, len(mq.names)) + for _, name := range mq.names { + block := mq.blocks[name] + if len(block.edges) != 0 { + return nil, fmt.Errorf("multi_query: block %q carries WhereEdge constraints; MultiQuery cannot batch edge-filtered blocks", name) + } + // Name the underlying dgman query so blocks do not collide on the + // default "data" name and so the response JSON keys are predictable. + block.q.Name(name) + rawBlocks = append(rawBlocks, block.q) + } + + dql := dg.NewQueryBlock(rawBlocks...).String() + raw, err := mq.conn.QueryRaw(ctx, dql, nil) + if err != nil { + return nil, fmt.Errorf("multi_query: dgraph: %w", err) + } + + var perBlockRaw map[string]json.RawMessage + if err := json.Unmarshal(raw, &perBlockRaw); err != nil { + return nil, fmt.Errorf("multi_query: decoding response: %w", err) + } + + var zero T + predMap := buildPredicateToJSONMap(reflect.TypeOf(zero)) + + out := make(map[string][]T, len(mq.names)) + for _, name := range mq.names { + body, ok := perBlockRaw[name] + if !ok { + out[name] = []T{} + continue + } + if len(predMap) > 0 { + remapped, err := remapArrayKeys(body, predMap) + if err == nil { + body = remapped + } + } + var rows []T + if err := json.Unmarshal(body, &rows); err != nil { + return nil, fmt.Errorf("multi_query: decoding block %q: %w", name, err) + } + if rows == nil { + rows = []T{} + } + out[name] = rows + } + return out, nil +} + +// buildPredicateToJSONMap returns a map from dgraph predicate name → JSON tag +// name for fields on T where the two differ. Mirrors dgman's unexported helper +// of the same name; we need our own because the multi-block response from +// QueryRaw bypasses dgman's scan path. +func buildPredicateToJSONMap(t reflect.Type) map[string]string { + for t != nil && t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t == nil || t.Kind() != reflect.Struct { + return nil + } + result := make(map[string]string) + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + jsonTag := field.Tag.Get("json") + if jsonTag == "" || jsonTag == "-" { + continue + } + jsonName := strings.Split(jsonTag, ",")[0] + if jsonName == "" { + continue + } + dgraphTag := field.Tag.Get("dgraph") + if dgraphTag == "" { + continue + } + var predName string + for _, part := range strings.Fields(dgraphTag) { + if strings.HasPrefix(part, "predicate=") { + predName = strings.TrimPrefix(part, "predicate=") + break + } + } + if predName == "" || predName == jsonName { + continue + } + if predName == "uid" || predName == "dgraph.type" { + continue + } + result[predName] = jsonName + } + return result +} + +// remapArrayKeys rewrites top-level keys in each object of a JSON array using +// the predicate → JSON-tag map. Nested objects are left untouched (search +// callers iterate scalar predicates of the root type; edge fields are +// hydrated lazily, not in the multi-block response). +func remapArrayKeys(data json.RawMessage, predMap map[string]string) (json.RawMessage, error) { + var rows []map[string]json.RawMessage + if err := json.Unmarshal(data, &rows); err != nil { + return data, err + } + for i, row := range rows { + for k, v := range row { + if newK, ok := predMap[k]; ok && newK != k { + delete(row, k) + row[newK] = v + } + } + rows[i] = row + } + return json.Marshal(rows) +} diff --git a/typed/multi_query_test.go b/typed/multi_query_test.go new file mode 100644 index 0000000..98f1ae4 --- /dev/null +++ b/typed/multi_query_test.go @@ -0,0 +1,127 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "testing" + + "github.com/matthewmcneely/modusgraph/typed" +) + +func TestMultiQueryAddAccumulatesBlocks(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + q1 := typed.NewClient[widget](conn).Query(context.Background()) + q2 := typed.NewClient[widget](conn).Query(context.Background()) + mq.Add("byName", q1) + mq.Add("byQty", q2) + got := mq.BlockNames() + if len(got) != 2 || got[0] != "byName" || got[1] != "byQty" { + t.Fatalf("BlockNames = %v, want [byName, byQty]", got) + } +} + +func TestMultiQueryAddRejectsDuplicateName(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + q := typed.NewClient[widget](conn).Query(context.Background()) + mq.Add("byName", q) + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic on duplicate block name") + } + }() + mq.Add("byName", q) +} + +func TestMultiQueryExecuteReturnsPerBlockResults(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + c := typed.NewClient[widget](conn) + + for _, w := range []*widget{ + {Name: "sprocket", Qty: 1}, + {Name: "gear", Qty: 5}, + {Name: "bolt", Qty: 10}, + } { + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add %s: %v", w.Name, err) + } + } + + mq := typed.NewMultiQuery[widget](conn) + mq.Add("all", c.Query(ctx)) + mq.Add("filtered", c.Query(ctx).Filter("eq(name, $1)", "gear")) + + results, err := mq.Execute(ctx) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if got := len(results["all"]); got != 3 { + t.Fatalf("results[all] has %d rows, want 3", got) + } + if got := len(results["filtered"]); got != 1 { + t.Fatalf("results[filtered] has %d rows, want 1", got) + } + if results["filtered"][0].Name != "gear" { + t.Fatalf("results[filtered][0].Name = %q, want gear", results["filtered"][0].Name) + } +} + +func TestMultiQueryExecuteEmptyReturnsEmptyMap(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + results, err := mq.Execute(context.Background()) + if err != nil { + t.Fatalf("Execute on empty MultiQuery: %v", err) + } + if len(results) != 0 { + t.Fatalf("expected empty map, got %v", results) + } +} + +// renamed exercises the predicate-vs-json-tag remap. Dgraph returns the +// "thingName" key (the predicate name) but the struct's JSON tag is +// "name"; MultiQuery.Execute must remap before unmarshaling so Name +// populates. +type renamed struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"predicate=thingName index=hash,fulltext"` + Qty int `json:"qty,omitempty" dgraph:"index=int"` +} + +func TestMultiQueryExecuteRemapsPredicateKeys(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + c := typed.NewClient[renamed](conn) + + for _, w := range []*renamed{ + {Name: "alpha", Qty: 1}, + {Name: "beta", Qty: 2}, + } { + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add %s: %v", w.Name, err) + } + } + + mq := typed.NewMultiQuery[renamed](conn) + mq.Add("all", c.Query(ctx)) + results, err := mq.Execute(ctx) + if err != nil { + t.Fatalf("Execute: %v", err) + } + rows := results["all"] + if len(rows) != 2 { + t.Fatalf("rows = %d, want 2", len(rows)) + } + for _, r := range rows { + if r.Name == "" { + t.Fatalf("Name not populated; multi-block response was not remapped from predicate key: %+v", r) + } + } +} diff --git a/typed/option.go b/typed/option.go new file mode 100644 index 0000000..d944483 --- /dev/null +++ b/typed/option.go @@ -0,0 +1,17 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +// Option configures a *T. Generated With constructors return an Option; +// generated New/Wrap constructors apply them via Apply. +type Option[T any] func(*T) + +// Apply applies opts to target in declaration order. +func Apply[T any](target *T, opts ...Option[T]) { + for _, opt := range opts { + opt(target) + } +} diff --git a/typed/option_test.go b/typed/option_test.go new file mode 100644 index 0000000..7c1f378 --- /dev/null +++ b/typed/option_test.go @@ -0,0 +1,37 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "strings" + "testing" + + "github.com/matthewmcneely/modusgraph/typed" +) + +func TestApply_RunsOptionsInOrder(t *testing.T) { + type rec struct{ trail []string } + r := &rec{} + + typed.Apply(r, + func(x *rec) { x.trail = append(x.trail, "a") }, + func(x *rec) { x.trail = append(x.trail, "b") }, + func(x *rec) { x.trail = append(x.trail, "c") }, + ) + + if got := strings.Join(r.trail, ""); got != "abc" { + t.Fatalf("Apply ran options as %q, want %q", got, "abc") + } +} + +func TestApply_NoOptionsIsNoop(t *testing.T) { + type rec struct{ n int } + r := &rec{n: 7} + typed.Apply(r) + if r.n != 7 { + t.Fatalf("Apply with no options mutated target: n = %d, want 7", r.n) + } +} diff --git a/typed/query.go b/typed/query.go new file mode 100644 index 0000000..e4b2199 --- /dev/null +++ b/typed/query.go @@ -0,0 +1,565 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "fmt" + "iter" + "strconv" + "strings" + + dg "github.com/dolan-in/dgman/v2" + "github.com/matthewmcneely/modusgraph" +) + +// Query is a fluent, type-safe query builder over records of type T. Builder +// methods return *Query[T] for chaining, except As, Var, and GroupBy, which +// change the result shape and transition to *RawQuery; terminal methods +// (Nodes, First, IterNodes) execute the query and decode typed results. +// +// A Query is single-use. Builder methods mutate the underlying query in place +// and return the same *Query, so a Query value should be built as one chain +// and handed to a single terminal. It is not safe to save a Query to a +// variable and branch it into independent queries: every branch shares — and +// keeps mutating — the same underlying query. +// +// Repeated builder calls do not all behave the same way. Limit, Offset, After, +// Cascade, Name, RootFunc, and Vars overwrite: the last call wins. Filter, +// OrderAsc, OrderDesc, and WhereEdge accumulate: each call adds to the query. +// Accumulated Filter fragments AND together (see CombinedFilter, OrGroup). +// +// Limit and Offset additionally record the bounds that IterNodes pages +// within — a Limit caps the rows it streams, an Offset is its start. +type Query[T any] struct { + q *dg.Query + conn modusgraph.Client // runs the WhereEdge pre-pass; set by Client.Query + ctx context.Context // carried for the WhereEdge pre-pass query + limit int // caller-set row cap; 0 = unbounded + offset int // caller-set starting offset; 0 = none + edges []edgeFilter // accumulated WhereEdge constraints; empty = none + filters []filterFrag // accumulated @filter fragments, ANDed; empty = none +} + +// edgeFilter is one accumulated WhereEdge constraint: a dgraph @filter +// expression scoped to an outbound edge predicate of T. +type edgeFilter struct { + predicate string + filter string + params []any +} + +// filterFrag is one accumulated @filter fragment. Fragments join with AND. +type filterFrag struct { + expr string + params []any +} + +// NewDetachedQuery returns a Query[T] with no connection, used only to +// accumulate a filter expression: its By/Filter calls record fragments +// that CombinedFilter reads back. It must not be executed (it has no terminal +// path) and exists as the capture target behind the generated Or and +// WhereBy combinators. +func NewDetachedQuery[T any]() *Query[T] { + return &Query[T]{} +} + +// Filter adds a dgraph @filter expression. params bind to placeholders. +// Repeated calls accumulate: every fragment ANDs together. +func (qb *Query[T]) Filter(filter string, params ...any) *Query[T] { + qb.addFilter(filter, params) + return qb +} + +// addFilter accumulates one @filter fragment. Fragments AND together: the +// effective filter is every fragment joined with AND, each fragment's $N +// placeholders shifted to stay bound to its own params. dgman's own Filter is +// last-write-wins, so the full combined expression is re-pushed on every call. +// A detached query (nil q — used to capture a sub-scope's filter for OrGroup or +// WhereBy) accumulates with no dgman query to push to; CombinedFilter +// reads the fragments back. +func (qb *Query[T]) addFilter(expr string, params []any) { + if expr == "" { + return + } + qb.filters = append(qb.filters, filterFrag{expr: expr, params: params}) + if qb.q != nil { + combined, cp := combineAnd(qb.filters) + qb.q.Filter(combined, cp...) + } +} + +// combineAnd joins fragments with AND, renumbering each fragment's ordinal +// placeholders against the concatenated params slice. +func combineAnd(frags []filterFrag) (string, []any) { + parts := make([]string, 0, len(frags)) + var params []any + for _, f := range frags { + if f.expr == "" { + continue + } + parts = append(parts, shiftPlaceholders(f.expr, len(params))) + params = append(params, f.params...) + } + if len(parts) == 0 { + return "", nil + } + return strings.Join(parts, " AND "), params +} + +// CombinedFilter returns the AND-combined accumulated @filter expression and +// its params, or ("", nil) when no filter was set. It is the substrate behind +// the generated Or and WhereBy combinators: they run a sub-scope's +// By/Filter calls against a detached query, then fold the captured +// expression into a parent OR group or edge constraint. +func (qb *Query[T]) CombinedFilter() (string, []any) { + return combineAnd(qb.filters) +} + +// OrGroup adds one @filter group that ORs the combined filter of each sub. +// Each sub is a detached Query[T] whose By/Filter calls have been +// accumulated; their combined (AND) expressions are parenthesized, joined with +// OR, and the whole OR group ANDs with the receiver's other filters. Subs with +// an empty filter are skipped; an all-empty OrGroup is a no-op. It is the +// substrate behind the generated Query.Or combinator. +func (qb *Query[T]) OrGroup(subs ...*Query[T]) *Query[T] { + parts := make([]string, 0, len(subs)) + var params []any + for _, s := range subs { + e, p := s.CombinedFilter() + if e == "" { + continue + } + parts = append(parts, "("+shiftPlaceholders(e, len(params))+")") + params = append(params, p...) + } + if len(parts) == 0 { + return qb + } + qb.addFilter("("+strings.Join(parts, " OR ")+")", params) + return qb +} + +// OrderAsc orders results ascending by clause. +func (qb *Query[T]) OrderAsc(clause string) *Query[T] { + qb.q.OrderAsc(clause) + return qb +} + +// OrderDesc orders results descending by clause. +func (qb *Query[T]) OrderDesc(clause string) *Query[T] { + qb.q.OrderDesc(clause) + return qb +} + +// Limit caps the number of results. dgman names this First; it is renamed +// here so it does not collide with the First terminal. +func (qb *Query[T]) Limit(n int) *Query[T] { + qb.limit = n + qb.q.First(n) + return qb +} + +// Offset skips the first n results. +func (qb *Query[T]) Offset(n int) *Query[T] { + qb.offset = n + qb.q.Offset(n) + return qb +} + +// After returns results with UID greater than uid (cursor pagination). +func (qb *Query[T]) After(uid string) *Query[T] { + qb.q.After(uid) + return qb +} + +// Cascade drops nodes missing any of the given predicates (all, if none given). +func (qb *Query[T]) Cascade(predicates ...string) *Query[T] { + qb.q.Cascade(predicates...) + return qb +} + +// RootFunc overrides the query root function. dgman's default root function +// is type(); RootFunc replaces it with an expression such as +// eq(name, "Alice") or has(email). Repeated calls overwrite. +func (qb *Query[T]) RootFunc(rootFunc string) *Query[T] { + qb.q.RootFunc(rootFunc) + return qb +} + +// Name sets the query block name. It defaults to "data"; dgman uses the name +// to both generate and decode the query, so a renamed block still decodes +// into []T. Repeated calls overwrite. +func (qb *Query[T]) Name(queryName string) *Query[T] { + qb.q.Name(queryName) + return qb +} + +// Vars supplies GraphQL variables for a parameterized query: funcDef is the +// query function definition (for example "getByName($n: string)") and vars +// binds each variable. The query then executes via dgraph's QueryWithVars +// path. Repeated calls overwrite. +func (qb *Query[T]) Vars(funcDef string, vars map[string]string) *Query[T] { + qb.q.Vars(funcDef, vars) + return qb +} + +// WhereEdge constrains results to records that have at least one `predicate` +// edge whose target node satisfies the dgraph @filter expression. params bind +// to $N placeholders within filter, exactly as Filter binds them. +// +// Where Filter constrains T's own scalar predicates, WhereEdge constrains a +// neighbouring node reached over an edge. dgraph's root @filter cannot express +// that, so a query carrying WhereEdge constraints executes in two steps: a +// pre-pass resolves the UIDs of roots that satisfy every constraint, then the +// main query runs against uid(...) — keeping ordering, pagination, and result +// projection on the normal path. See +// docs/specs/2026-05-21-query-edge-filter-design.md. +// +// WhereEdge accumulates: multiple calls AND together (a record must satisfy +// every edge constraint). It is the substrate behind the generated +// Query.Where methods. +func (qb *Query[T]) WhereEdge(predicate, filter string, params ...any) *Query[T] { + qb.edges = append(qb.edges, edgeFilter{predicate: predicate, filter: filter, params: params}) + return qb +} + +// WhereAnyOfText adds an @filter(anyoftext(predicate, $1)) clause. It +// accumulates and ANDs with other filters like Filter. +func (qb *Query[T]) WhereAnyOfText(predicate, term string) *Query[T] { + qb.addFilter(fmt.Sprintf("anyoftext(%s, $1)", predicate), []any{term}) + return qb +} + +// WhereAllOfText adds an @filter(alloftext(predicate, $1)) clause. It +// accumulates and ANDs with other filters like Filter. +func (qb *Query[T]) WhereAllOfText(predicate, term string) *Query[T] { + qb.addFilter(fmt.Sprintf("alloftext(%s, $1)", predicate), []any{term}) + return qb +} + +// As names the query block as a dgraph query variable. dgraph requires such a +// variable be consumed by another block, which a single-block typed query +// cannot do, so As transitions out of the typed query: it returns a *RawQuery, +// which exposes no node terminal. +func (qb *Query[T]) As(varName string) *RawQuery { + qb.q.As(varName) + return &RawQuery{q: qb.q} +} + +// Var marks the query block as a dgraph var block. A var block computes query +// variables and returns no data of its own, so Var transitions out of the +// typed query: it returns a *RawQuery, which exposes no node terminal. +func (qb *Query[T]) Var() *RawQuery { + qb.q.Var() + return &RawQuery{q: qb.q} +} + +// GroupBy adds an @groupby(predicate) aggregation. A grouped query returns +// aggregation groups rather than a slice of T, so GroupBy transitions out of +// the typed query: it returns a *RawQuery, which exposes no node terminal. +func (qb *Query[T]) GroupBy(predicate string) *RawQuery { + qb.q.GroupBy(predicate) + return &RawQuery{q: qb.q} +} + +// Nodes executes the query and returns all matching records. +func (qb *Query[T]) Nodes() (out []T, err error) { + _, span := tracer.StartSpan(qb.ctx, "query", entityName[T]()) + defer func() { span.End(err) }() + matched, err := qb.resolveRoots() + if err != nil { + return nil, err + } + if !matched { + return nil, nil + } + if err = qb.q.Nodes(&out); err != nil { + return nil, err + } + return out, nil +} + +// First executes the query with an implicit Limit(1) and returns the first +// record, or (nil, nil) if the query matched no rows. +func (qb *Query[T]) First() (rec *T, err error) { + _, span := tracer.StartSpan(qb.ctx, "query", entityName[T]()) + defer func() { span.End(err) }() + matched, err := qb.resolveRoots() + if err != nil { + return nil, err + } + if !matched { + return nil, nil + } + var out []T + if err = qb.q.First(1).Nodes(&out); err != nil { + return nil, err + } + if len(out) == 0 { + return nil, nil + } + return &out[0], nil +} + +// IterNodes executes the query and returns an iterator over matching records, +// paging transparently so a large result set is never materialized at once. +// +// IterNodes is a terminal operation: it drives Offset/Limit internally as it +// pages and leaves the builder spent — do not call another terminal on the +// same Query afterward. A Limit set on the query caps the total number of +// rows streamed; an Offset is the starting point. +// +// All pages execute against one read-only transaction, so the iteration reads +// a single consistent snapshot: a concurrent writer cannot make it skip or +// repeat rows. A WhereEdge pre-pass, when present, runs once before paging +// begins, in its own transaction. On error it yields a final (nil, err) and +// stops. +func (qb *Query[T]) IterNodes() iter.Seq2[*T, error] { + return func(yield func(*T, error) bool) { + _, span := tracer.StartSpan(qb.ctx, "query", entityName[T]()) + var ferr error + defer func() { span.End(ferr) }() + matched, err := qb.resolveRoots() + if err != nil { + ferr = err + yield(nil, err) + return + } + if !matched { + return // edge constraints present, but no root matched + } + remaining := qb.limit // 0 = unbounded + for off := qb.offset; ; off += defaultPageSize { + size := defaultPageSize + if remaining > 0 && remaining < size { + size = remaining // shrink the last page so it can't overshoot the cap + } + var page []T + if err := qb.q.Offset(off).First(size).Nodes(&page); err != nil { + ferr = err + yield(nil, err) + return + } + for i := range page { + if !yield(&page[i], nil) { + return // consumer broke out + } + } + if remaining > 0 { + if remaining -= len(page); remaining <= 0 { + return // hit the caller's Limit + } + } + if len(page) < size { + return // result set exhausted + } + } + } +} + +// Raw returns the underlying dgman query for operations Query does not wrap +// (for example the raw-selection Query method). Raw does not carry WhereEdge +// constraints — those are resolved only when a terminal runs. +func (qb *Query[T]) Raw() *dg.Query { + return qb.q +} + +// UID roots the query at a specific node UID. Results still decode into []T. +func (qb *Query[T]) UID(uid string) *Query[T] { + qb.q.UID(uid) + return qb +} + +// All sets the edge-traversal depth for this query, overriding the client's +// default maxEdgeTraversal. Use a small depth to stay under Dgraph's 4MB gRPC +// limit on highly-connected entities. +func (qb *Query[T]) All(depth int) *Query[T] { + qb.q.All(depth) + return qb +} + +// NodesAndCount executes the query and returns the matching records together +// with the total count (useful for pagination totals). Like Nodes, it runs the +// WhereEdge pre-pass first when edge constraints are present. +func (qb *Query[T]) NodesAndCount() ([]T, int, error) { + matched, err := qb.resolveRoots() + if err != nil { + return nil, 0, err + } + if !matched { + return nil, 0, nil + } + var out []T + count, err := qb.q.NodesAndCount(&out) + if err != nil { + return nil, 0, err + } + return out, count, nil +} + +// String renders the generated DQL without executing it. WhereEdge constraints +// are not reflected — they are resolved only when a terminal runs. +func (qb *Query[T]) String() string { + return qb.q.String() +} + +// FormatBlock renders the query as a single DQL block named name, without +// executing it. The returned text is suitable for inclusion inside a wrapping +// "{ ... }" multi-block request — it does not include outer braces. +// +// FormatBlock is the substrate behind MultiQuery; external callers can use it +// to compose typed queries into larger hand-written DQL requests. +// +// Filter parameters are inlined at Filter-call time (dgman renders $N +// placeholders into the filter string immediately), so the returned block +// carries no unresolved variables. WhereEdge constraints are not formatted — +// they require a runtime pre-pass and would produce no useful output here. +func (qb *Query[T]) FormatBlock(name string) (string, error) { + if len(qb.edges) != 0 { + return "", fmt.Errorf("typed: FormatBlock cannot render a Query carrying WhereEdge constraints") + } + qb.q.Name(name) + wrapped := dg.NewQueryBlock(qb.q).String() + // QueryBlock.String() wraps the block in "{\n ... }" — strip the wrapper so + // the caller can compose blocks inside their own braces. + inner := strings.TrimPrefix(wrapped, "{\n") + inner = strings.TrimSuffix(inner, "}") + return inner, nil +} + +// RawQuery is a query whose result is not a slice of T — produced by the +// shape-changing builders Query.As, Query.Var, and Query.GroupBy. A RawQuery +// deliberately exposes no typed node terminal: its result must be decoded by +// the caller through the underlying dgman query, obtained via Raw. +type RawQuery struct { + q *dg.Query +} + +// Raw returns the underlying dgman query, for the caller to execute and decode. +func (r *RawQuery) Raw() *dg.Query { + return r.q +} + +// String returns the generated DQL. +func (r *RawQuery) String() string { + return r.q.String() +} + +// As names the block as a dgraph query variable. See Query.As. +func (r *RawQuery) As(varName string) *RawQuery { + r.q.As(varName) + return r +} + +// Var marks the block as a dgraph var block. See Query.Var. +func (r *RawQuery) Var() *RawQuery { + r.q.Var() + return r +} + +// GroupBy adds an @groupby(predicate) aggregation. See Query.GroupBy. +func (r *RawQuery) GroupBy(predicate string) *RawQuery { + r.q.GroupBy(predicate) + return r +} + +// resolveRoots runs the WhereEdge pre-pass when the query carries edge +// constraints, rewriting the main query's root function to the matching UIDs. +// It returns matched=false when constraints are present but no root satisfied +// them — callers then return an empty result without running the main query. +// With no edge constraints it is a no-op returning matched=true. +func (qb *Query[T]) resolveRoots() (matched bool, err error) { + if len(qb.edges) == 0 { + return true, nil + } + uids, err := qb.matchedUIDs() + if err != nil { + return false, err + } + if len(uids) == 0 { + return false, nil + } + qb.q.RootFunc("uid(" + strings.Join(uids, ", ") + ")") + return true, nil +} + +// matchedUIDs runs the pre-pass: an @cascade query over T that keeps only +// nodes whose every WhereEdge predicate has a target matching its filter, and +// returns those nodes' UIDs. +func (qb *Query[T]) matchedUIDs() ([]string, error) { + var z T + pre := qb.conn.Query(qb.ctx, &z) + body, params := qb.edgeMatchBody() + pre.Cascade().Query(body, params...) + + var rows []struct { + UID string `json:"uid"` + } + if err := pre.Nodes(&rows); err != nil { + return nil, err + } + uids := make([]string, len(rows)) + for i := range rows { + uids[i] = rows[i].UID + } + return uids, nil +} + +// edgeMatchBody renders the selection set for the pre-pass: uid plus one +// aliased, filtered block per WhereEdge constraint. The caller adds a bare +// @cascade, which then drops any node with an empty block — so a survivor +// satisfies every constraint. Blocks are aliased mg_e0, mg_e1, ... so two +// constraints on the same predicate do not collide as duplicate fields. Each +// fragment's $N placeholders are shifted to stay bound to its own params once +// every fragment's params are concatenated into one slice. +func (qb *Query[T]) edgeMatchBody() (body string, params []any) { + var b strings.Builder + b.WriteString("{\n\tuid\n") + for i, e := range qb.edges { + b.WriteString("\tmg_e") + b.WriteString(strconv.Itoa(i)) + b.WriteString(" : ") + b.WriteString(e.predicate) + b.WriteString(" @filter(") + b.WriteString(shiftPlaceholders(e.filter, len(params))) + b.WriteString(") { uid }\n") + params = append(params, e.params...) + } + b.WriteString("}") + return b.String(), params +} + +// shiftPlaceholders rewrites dgman ordinal placeholders ($1, $2, ...) in expr, +// adding delta to each index. WhereEdge filters are written independently, each +// numbering its params from $1; concatenating them into one pre-pass body +// needs every fragment renumbered against the combined params slice. A '$' not +// followed by a digit is left as-is, matching dgman's parseQueryWithParams. +func shiftPlaceholders(expr string, delta int) string { + if delta == 0 || !strings.ContainsRune(expr, '$') { + return expr + } + var b strings.Builder + for i := 0; i < len(expr); i++ { + if expr[i] != '$' { + b.WriteByte(expr[i]) + continue + } + j := i + 1 + for j < len(expr) && expr[j] >= '0' && expr[j] <= '9' { + j++ + } + if j == i+1 { // '$' not followed by digits — leave verbatim + b.WriteByte('$') + continue + } + n, _ := strconv.Atoi(expr[i+1 : j]) + b.WriteByte('$') + b.WriteString(strconv.Itoa(n + delta)) + i = j - 1 + } + return b.String() +} diff --git a/typed/query_test.go b/typed/query_test.go new file mode 100644 index 0000000..588bf6b --- /dev/null +++ b/typed/query_test.go @@ -0,0 +1,1294 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "strings" + "testing" + + dg "github.com/dolan-in/dgman/v2" + "github.com/go-logr/logr/funcr" + "github.com/matthewmcneely/modusgraph" + "github.com/matthewmcneely/modusgraph/typed" +) + +// newCountingConn builds a file-backed modusgraph client exactly like newConn, +// but wires in a logr.Logger that counts dgman query executions. dgman logs +// every executed query at verbosity 3 with the message "execute query"; the +// returned *int is incremented once per such log line. +// +// dgman's logger is process-global, and modusgraph allows only one live +// file-backed engine per process (see modusgraph.ErrSingletonOnly). Each call +// uses a fresh t.TempDir() URI for data isolation. Tests that use +// newCountingConn must NOT call t.Parallel(): a second live client would hit +// the engine singleton, and parallel tests would also corrupt the shared +// query count. +func newCountingConn(t *testing.T, count *int) modusgraph.Client { + t.Helper() + logger := funcr.New(func(_, args string) { + // funcr renders the message into args as `"msg"="execute query"`. + // Match that exact pair so unrelated dgman/pool log lines (which log + // other messages, e.g. "executeQuery" for query blocks) are ignored. + if strings.Contains(args, `"msg"="execute query"`) { + *count++ + } + }, funcr.Options{Verbosity: 3}) + conn, err := modusgraph.NewClient("file://"+t.TempDir(), + modusgraph.WithAutoSchema(true), modusgraph.WithLogger(logger)) + if err != nil { + t.Fatalf("modusgraph.NewClient: %v", err) + } + t.Cleanup(conn.Close) + return conn +} + +func TestQuery_NodesReturnsAll(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + + got, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf("Nodes returned %d records, want 3", len(got)) + } +} + +func TestQuery_LimitCapsResults(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for i := range 5 { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + got, err := c.Query(ctx).Limit(2).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("Limit(2) returned %d records, want 2", len(got)) + } +} + +func TestQuery_FirstReturnsAMatch(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "only", Qty: 7}); err != nil { + t.Fatalf("Add: %v", err) + } + + got, err := c.Query(ctx).First() + if err != nil { + t.Fatalf("First: %v", err) + } + if got == nil || got.Name != "only" { + t.Fatalf("First returned %+v, want Name=only", got) + } +} + +func TestQuery_FirstNoMatchReturnsNilNil(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + got, err := c.Query(ctx).First() + if err != nil { + t.Fatalf("First on empty: unexpected error %v", err) + } + if got != nil { + t.Fatalf("First on empty returned %+v, want nil", got) + } +} + +func TestQuery_BuilderChainCompilesAndRuns(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "x", Qty: 1}); err != nil { + t.Fatalf("Add: %v", err) + } + + // Every builder method must return *Query[widget] so the chain stays typed. + _, err := c.Query(ctx). + OrderAsc("qty"). + Offset(0). + Limit(10). + Cascade(). + Nodes() + if err != nil { + t.Fatalf("builder chain Nodes: %v", err) + } +} + +func TestQuery_RawExposesUnderlyingBuilder(t *testing.T) { + c := typed.NewClient[widget](newConn(t)) + if c.Query(context.Background()).Raw() == nil { + t.Fatal("Raw() returned nil; expected the underlying *dg.Query") + } +} + +func TestQuery_Filter(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Insert three widgets with distinct names. + for _, name := range []string{"alpha", "beta", "gamma"} { + if err := c.Add(ctx, &widget{Name: name}); err != nil { + t.Fatalf("Add %s: %v", name, err) + } + } + + // Filter to exactly those whose name equals "beta" (index=exact allows eq()). + got, err := c.Query(ctx).Filter(`eq(name, "beta")`).Nodes() + if err != nil { + t.Fatalf("Filter Nodes: %v", err) + } + if len(got) != 1 { + t.Fatalf("Filter returned %d records, want 1", len(got)) + } + if got[0].Name != "beta" { + t.Fatalf("Filter returned Name=%q, want beta", got[0].Name) + } +} + +func TestQuery_FilterAccumulatesWithAnd(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Three widgets; only "beta"/9 satisfies BOTH name=="beta" and qty>=5. + for _, w := range []widget{ + {Name: "alpha", Qty: 9}, + {Name: "beta", Qty: 9}, + {Name: "beta", Qty: 1}, + } { + if err := c.Add(ctx, &w); err != nil { + t.Fatalf("Add %+v: %v", w, err) + } + } + + // Two Filter calls must AND together, not overwrite. With last-write-wins + // only ge(qty, 5) survives and this returns the two qty>=5 rows instead of + // the single AND match. + got, err := c.Query(ctx). + Filter(`eq(name, "beta")`). + Filter(`ge(qty, "5")`). + Nodes() + if err != nil { + t.Fatalf("Filter Nodes: %v", err) + } + if len(got) != 1 { + t.Fatalf("two ANDed Filters returned %d records, want 1 (name==beta AND qty>=5)", len(got)) + } + if got[0].Name != "beta" || got[0].Qty != 9 { + t.Fatalf("got %+v, want Name=beta Qty=9", got[0]) + } +} + +func TestQuery_CombinedFilterShiftsPlaceholders(t *testing.T) { + q := typed.NewDetachedQuery[widget]() + if expr, params := q.CombinedFilter(); expr != "" || params != nil { + t.Fatalf("empty CombinedFilter = (%q, %v), want (\"\", nil)", expr, params) + } + q.Filter("eq(name, $1)", "a") + q.Filter("eq(qty, $1)", 7) + expr, params := q.CombinedFilter() + const want = "eq(name, $1) AND eq(qty, $2)" + if expr != want { + t.Fatalf("CombinedFilter expr = %q, want %q", expr, want) + } + if len(params) != 2 || params[0] != "a" || params[1] != 7 { + t.Fatalf("CombinedFilter params = %v, want [a 7]", params) + } +} + +func TestQuery_OrGroup(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, w := range []widget{ + {Name: "alpha", Qty: 9}, + {Name: "beta", Qty: 9}, + {Name: "gamma", Qty: 1}, + } { + if err := c.Add(ctx, &w); err != nil { + t.Fatalf("Add %+v: %v", w, err) + } + } + + // name == "alpha" OR name == "gamma": two of three rows. + got, err := c.Query(ctx).OrGroup( + typed.NewDetachedQuery[widget]().Filter(`eq(name, "alpha")`), + typed.NewDetachedQuery[widget]().Filter(`eq(name, "gamma")`), + ).Nodes() + if err != nil { + t.Fatalf("OrGroup Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("OrGroup(alpha, gamma) returned %d rows, want 2", len(got)) + } + + // AND-of-OR: qty>=5 AND (name==alpha OR name==gamma) → only alpha/9. + got, err = c.Query(ctx). + Filter(`ge(qty, "5")`). + OrGroup( + typed.NewDetachedQuery[widget]().Filter(`eq(name, "alpha")`), + typed.NewDetachedQuery[widget]().Filter(`eq(name, "gamma")`), + ).Nodes() + if err != nil { + t.Fatalf("AND-of-OR Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "alpha" { + t.Fatalf("qty>=5 AND (alpha OR gamma) returned %+v, want [alpha/9]", got) + } +} + +func TestQuery_OrderAscDesc(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Insert widgets with distinct Qty values in non-sorted order so a + // stable natural ordering cannot hide a missing sort. + qtys := []int{30, 10, 50, 20, 40} + for i, q := range qtys { + if err := c.Add(ctx, &widget{Name: "w", Qty: q}); err != nil { + t.Fatalf("Add widget[%d]: %v", i, err) + } + } + + // Ascending. + asc, err := c.Query(ctx).OrderAsc("qty").Nodes() + if err != nil { + t.Fatalf("OrderAsc Nodes: %v", err) + } + if len(asc) != len(qtys) { + t.Fatalf("OrderAsc returned %d records, want %d", len(asc), len(qtys)) + } + for i := range len(asc) - 1 { + if asc[i].Qty > asc[i+1].Qty { + t.Fatalf("OrderAsc: asc[%d].Qty=%d > asc[%d].Qty=%d; not ascending", + i, asc[i].Qty, i+1, asc[i+1].Qty) + } + } + + // Descending. + desc, err := c.Query(ctx).OrderDesc("qty").Nodes() + if err != nil { + t.Fatalf("OrderDesc Nodes: %v", err) + } + if len(desc) != len(qtys) { + t.Fatalf("OrderDesc returned %d records, want %d", len(desc), len(qtys)) + } + for i := range len(desc) - 1 { + if desc[i].Qty < desc[i+1].Qty { + t.Fatalf("OrderDesc: desc[%d].Qty=%d < desc[%d].Qty=%d; not descending", + i, desc[i].Qty, i+1, desc[i+1].Qty) + } + } +} + +func TestQuery_OffsetSkipsResults(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Five widgets with distinct, deliberately unsorted Qty values. + qtys := []int{40, 10, 50, 20, 30} + for i, q := range qtys { + if err := c.Add(ctx, &widget{Name: "w", Qty: q}); err != nil { + t.Fatalf("Add widget[%d]: %v", i, err) + } + } + + // Ordering ascending by qty gives 10,20,30,40,50; Offset(2) drops the + // first two, so 3 rows remain and the first is the 3rd-smallest (30). + got, err := c.Query(ctx).OrderAsc("qty").Offset(2).Nodes() + if err != nil { + t.Fatalf("Offset Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf("OrderAsc.Offset(2) returned %d records, want 3", len(got)) + } + if got[0].Qty != 30 { + t.Fatalf("first row after Offset(2) has Qty=%d, want 30 (3rd-smallest)", got[0].Qty) + } +} + +func TestQuery_AfterCursor(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for i := range 5 { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + // First pass: grab all rows so we can pick a non-last cursor UID. + all, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("first Nodes: %v", err) + } + if len(all) < 3 { + t.Fatalf("expected at least 3 widgets, got %d", len(all)) + } + cursor := all[1].UID // a non-last row + + // After(cursor) uses default UID ordering to skip past the cursor node. + got, err := c.Query(ctx).After(cursor).Nodes() + if err != nil { + t.Fatalf("After Nodes: %v", err) + } + if len(got) == 0 { + t.Fatal("After(cursor) returned no rows; expected the rows past the cursor") + } + for _, w := range got { + if w.UID <= cursor { + t.Fatalf("After(%s) returned UID %s, which is not strictly greater than the cursor", + cursor, w.UID) + } + } +} + +func TestQuery_CascadeDropsIncompleteNodes(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Widgets with Qty > 0 carry a qty predicate. Widgets with Qty left 0 + // have it omitted entirely (json tag is omitempty), so they have no qty + // predicate at all. + withQty := []int{5, 9, 13} + for _, q := range withQty { + if err := c.Add(ctx, &widget{Name: "has-qty", Qty: q}); err != nil { + t.Fatalf("Add qty=%d: %v", q, err) + } + } + for i := range 4 { + if err := c.Add(ctx, &widget{Name: "no-qty"}); err != nil { + t.Fatalf("Add no-qty[%d]: %v", i, err) + } + } + + // @cascade(qty) drops any node that lacks the qty predicate. + got, err := c.Query(ctx).Cascade("qty").Nodes() + if err != nil { + t.Fatalf("Cascade Nodes: %v", err) + } + if len(got) != len(withQty) { + t.Fatalf("Cascade(qty) returned %d records, want %d (only the qty-bearing widgets)", + len(got), len(withQty)) + } + for _, w := range got { + if w.Qty == 0 { + t.Fatalf("Cascade(qty) returned a widget with Qty=0 (no qty predicate): %+v", w) + } + } +} + +func TestQuery_FilterOrderLimitOffsetCombined(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // A known set: five "keep" widgets plus a "drop" widget the filter excludes. + for _, q := range []int{50, 20, 40, 10, 30} { + if err := c.Add(ctx, &widget{Name: "keep", Qty: q}); err != nil { + t.Fatalf("Add keep qty=%d: %v", q, err) + } + } + if err := c.Add(ctx, &widget{Name: "drop", Qty: 99}); err != nil { + t.Fatalf("Add drop: %v", err) + } + + // Filter to name=keep -> qtys {10,20,30,40,50}; OrderAsc -> sorted; + // Offset(1) drops 10; Limit(2) keeps {20,30}. + got, err := c.Query(ctx). + Filter(`eq(name, "keep")`). + OrderAsc("qty"). + Offset(1). + Limit(2). + Nodes() + if err != nil { + t.Fatalf("combined chain Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("combined chain returned %d records, want 2", len(got)) + } + if got[0].Qty != 20 || got[1].Qty != 30 { + t.Fatalf("combined chain window = [%d, %d], want [20, 30]", got[0].Qty, got[1].Qty) + } +} + +func TestQuery_FirstOnMultipleRows(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for _, q := range []int{30, 10, 20} { + if err := c.Add(ctx, &widget{Name: "w", Qty: q}); err != nil { + t.Fatalf("Add qty=%d: %v", q, err) + } + } + + // First on an ascending-by-qty query yields exactly the smallest row. + got, err := c.Query(ctx).OrderAsc("qty").First() + if err != nil { + t.Fatalf("First: %v", err) + } + if got == nil { + t.Fatal("First returned nil on a non-empty result set") + } + if got.Qty != 10 { + t.Fatalf("First on OrderAsc(qty) returned Qty=%d, want 10 (smallest)", got.Qty) + } +} + +func TestQuery_NodesEmptyResult(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) // fresh client, no inserts + + got, err := c.Query(ctx).Nodes() + if err != nil { + t.Fatalf("Nodes on empty client: unexpected error %v", err) + } + if len(got) != 0 { + t.Fatalf("Nodes on empty client returned %d records, want 0", len(got)) + } +} + +func TestQuery_OrderAccumulates(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // OrderAsc and OrderDesc accumulate: both clauses must survive on the + // same query. dgman renders them as "orderasc:"/"orderdesc:" in the + // generated query string. + q := c.Query(ctx).OrderAsc("name").OrderDesc("qty") + s := q.Raw().String() + if !strings.Contains(s, "orderasc: name") { + t.Fatalf("query string missing ascending name order; got:\n%s", s) + } + if !strings.Contains(s, "orderdesc: qty") { + t.Fatalf("query string missing descending qty order; got:\n%s", s) + } +} + +func TestQuery_CascadeOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // Cascade overwrites: the second call wins, the first predicate is gone. + // dgman renders predicates as @cascade(pred1,pred2,...) with no spaces. + q := c.Query(ctx).Cascade("name").Cascade("qty") + s := q.Raw().String() + if !strings.Contains(s, "@cascade(qty)") { + t.Fatalf("second Cascade(qty) not rendered in query string; got:\n%s", s) + } + if strings.Contains(s, "@cascade(name)") { + t.Fatalf("first Cascade(name) still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_TerminalRunsTwice(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + + // A terminal is re-runnable: calling Nodes twice on the same builder + // succeeds both times and yields equal-length results. + q := c.Query(ctx) + first, err := q.Nodes() + if err != nil { + t.Fatalf("first Nodes: %v", err) + } + second, err := q.Nodes() + if err != nil { + t.Fatalf("second Nodes: %v", err) + } + if len(first) != len(second) { + t.Fatalf("Nodes run twice returned %d then %d records; want equal lengths", + len(first), len(second)) + } +} + +func TestQuery_BuilderAliasesAndAccumulates(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + // (i) Filter accumulates: after two Filter calls both survive, ANDed. + q := c.Query(ctx) + q.Filter(`eq(name, "alpha")`) + q.Filter(`eq(name, "beta")`) + s := q.Raw().String() + if !strings.Contains(s, `eq(name, "alpha")`) { + t.Fatalf("Filter A dropped; want both fragments present in:\n%s", s) + } + if !strings.Contains(s, `eq(name, "beta")`) { + t.Fatalf("Filter B dropped; want both fragments present in:\n%s", s) + } + if !strings.Contains(s, " AND ") { + t.Fatalf("accumulated filters not ANDed; got:\n%s", s) + } + + // (ii) The builder aliases: a saved reference and further mutation observe + // the same underlying query. ref and q point at the same *Query, so a + // mutation through one is visible through the other. This documents the + // single-use footgun: you cannot branch a saved builder. + ref := q + if ref != q { + t.Fatal("builder reference is not identical to the original *Query") + } + q.OrderAsc("name") + if ref.Raw().String() != q.Raw().String() { + t.Fatal("mutating q did not affect ref; builder is expected to alias a shared query") + } + if !strings.Contains(ref.Raw().String(), "orderasc: name") { + t.Fatalf("order applied via q not visible through ref; got:\n%s", ref.Raw().String()) + } +} + +func TestQuery_RawRoundTrips(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "raw-target", Qty: 7}); err != nil { + t.Fatalf("Add: %v", err) + } + + // Take the raw *dg.Query, apply a dgman-only builder method directly, + // then execute via the raw query's own Nodes(&dst). + var raw *dg.Query = c.Query(ctx).Raw() + raw.OrderAsc("qty") + + var dst []widget + if err := raw.Nodes(&dst); err != nil { + t.Fatalf("raw query Nodes: %v", err) + } + if len(dst) != 1 { + t.Fatalf("raw query returned %d records, want 1", len(dst)) + } + if dst[0].Name != "raw-target" || dst[0].Qty != 7 { + t.Fatalf("raw query returned %+v, want Name=raw-target Qty=7", dst[0]) + } +} + +func TestQuery_SingleQueryPerTerminal(t *testing.T) { + // Uses the global dgman logger; must not run in parallel. + ctx := context.Background() + // queriesExecuted is incremented by newCountingConn's logger each time + // dgman runs a query, so it reflects real database round-trips. + var queriesExecuted int + c := typed.NewClient[widget](newCountingConn(t, &queriesExecuted)) + + for i := range 2 { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + + // Building the chain runs no queries: builder methods only mutate the AST. + before := queriesExecuted + q := c.Query(ctx).Filter(`eq(name, "w")`).OrderAsc("qty").Limit(10) + if queriesExecuted != before { + t.Fatalf("builder methods executed %d queries, want 0", queriesExecuted-before) + } + + // The Nodes terminal runs exactly one query. + if _, err := q.Nodes(); err != nil { + t.Fatalf("Nodes: %v", err) + } + if got := queriesExecuted - before; got != 1 { + t.Fatalf("Nodes executed %d queries, want exactly 1", got) + } + + // A fresh builder's First terminal also runs exactly one query. + before = queriesExecuted + if _, err := c.Query(ctx).First(); err != nil { + t.Fatalf("First: %v", err) + } + if got := queriesExecuted - before; got != 1 { + t.Fatalf("First executed %d queries, want exactly 1", got) + } +} + +func TestIterNodes_StreamsAll(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 125 // > defaultPageSize (50): forces multiple pages + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for w, err := range c.Query(ctx).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + if w == nil { + t.Fatal("IterNodes yielded a nil widget") + } + seen++ + } + if seen != n { + t.Fatalf("IterNodes streamed %d records, want %d", seen, n) + } +} + +func TestIterNodes_StopsOnConsumerBreak(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 125 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for _, err := range c.Query(ctx).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + if seen == 10 { + break + } + } + if seen != 10 { + t.Fatalf("IterNodes yielded %d records after break at 10, want 10", seen) + } +} + +func TestIterNodes_EmptyResult(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + seen := 0 + for _, err := range c.Query(ctx).IterNodes() { + if err != nil { + t.Fatalf("IterNodes over empty set yielded error: %v", err) + } + seen++ + } + if seen != 0 { + t.Fatalf("IterNodes over empty set yielded %d records, want 0", seen) + } +} + +func TestIterNodes_RespectsLimit(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 100 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for _, err := range c.Query(ctx).Limit(30).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + } + if seen != 30 { + t.Fatalf("Limit(30).IterNodes() streamed %d records, want 30", seen) + } +} + +func TestIterNodes_LimitExceedsResultSet(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + const n = 30 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + seen := 0 + for _, err := range c.Query(ctx).Limit(500).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + } + if seen != n { + t.Fatalf("Limit(500).IterNodes() over %d records streamed %d, want %d", n, seen, n) + } +} + +func TestIterNodes_RespectsOffset(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Qty values start at 1 (not 0) so omitempty never suppresses the field, + // keeping OrderAsc("qty") a true total order over all records. + const n = 10 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i + 1}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + var got []int + for w, err := range c.Query(ctx).OrderAsc("qty").Offset(3).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + got = append(got, w.Qty) + } + if len(got) != 7 { + t.Fatalf("Offset(3).IterNodes() streamed %d records, want 7", len(got)) + } + for i, q := range got { + if q != i+4 { // Qty=1..10; offset 3 skips 1,2,3 → starts at 4 + t.Fatalf("Offset(3).IterNodes()[%d] Qty = %d, want %d", i, q, i+4) + } + } +} + +func TestIterNodes_RespectsOffsetAndLimit(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Qty values start at 1 so omitempty never suppresses the field and + // OrderAsc("qty") is a strict total order across all 200 records. + const n = 200 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i + 1}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + var got []int + for w, err := range c.Query(ctx).OrderAsc("qty").Offset(60).Limit(120).IterNodes() { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + got = append(got, w.Qty) + } + if len(got) != 120 { + t.Fatalf("Offset(60).Limit(120).IterNodes() streamed %d records, want 120", len(got)) + } + for i, q := range got { + if q != i+61 { // Qty=1..200; offset 60 skips 1..60 → starts at 61 + t.Fatalf("result[%d] Qty = %d, want %d", i, q, i+61) + } + } +} + +func TestIterNodes_OneQueryPerPage(t *testing.T) { + ctx := context.Background() + var queriesExecuted int + c := typed.NewClient[widget](newCountingConn(t, &queriesExecuted)) + const n = 125 // ceil(125/50) = 3 page queries + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + // Obtaining the iterator runs no query — IterNodes is lazy. + seq := c.Query(ctx).IterNodes() + if queriesExecuted != 0 { + t.Fatalf("building the IterNodes iterator executed %d queries, want 0", queriesExecuted) + } + seen := 0 + for _, err := range seq { + if err != nil { + t.Fatalf("IterNodes yielded error: %v", err) + } + seen++ + } + if seen != n { + t.Fatalf("IterNodes streamed %d records, want %d", seen, n) + } + if queriesExecuted != 3 { + t.Fatalf("IterNodes over %d records ran %d queries, want 3", n, queriesExecuted) + } +} + +func TestIterNodes_YieldsErrorAndStops(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "w", Qty: 1}); err != nil { + t.Fatalf("Add: %v", err) + } + // A syntactically invalid @filter (unbalanced parenthesis) makes the page + // query fail at execution; IterNodes must yield one (nil, err) and stop. + gotErr := false + for w, err := range c.Query(ctx).Filter(`eq(name, "w"`).IterNodes() { + if err != nil { + gotErr = true + if w != nil { + t.Fatalf("error yield carried a non-nil widget: %+v", w) + } + break + } + t.Fatal("IterNodes over a malformed query yielded a record before erroring") + } + if !gotErr { + t.Fatal("IterNodes over a malformed query did not yield an error") + } +} + +func TestQuery_LimitOffsetStillDriveNodes(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Qty values start at 1 so omitempty never suppresses the field and + // OrderAsc("qty") is a strict total order across all records. + const n = 10 + for i := range n { + if err := c.Add(ctx, &widget{Name: "w", Qty: i + 1}); err != nil { + t.Fatalf("Add %d: %v", i, err) + } + } + // Regression: Limit/Offset now also set Query struct fields; confirm they + // still drive the Nodes terminal. + got, err := c.Query(ctx).OrderAsc("qty").Offset(2).Limit(3).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf("Offset(2).Limit(3).Nodes() returned %d records, want 3", len(got)) + } + for i, w := range got { + if w.Qty != i+3 { // Qty=1..10; offset 2 skips 1,2 → starts at 3 + t.Fatalf("Nodes()[%d] Qty = %d, want %d", i, w.Qty, i+3) + } + } +} + +func TestQuery_RootFuncOverridesRoot(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + // RootFunc replaces the default type(widget) root with an eq() lookup; + // the query still decodes into []widget. + got, err := c.Query(ctx).RootFunc(`eq(name, "b")`).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 1 { + t.Fatalf(`RootFunc(eq(name,"b")).Nodes() returned %d records, want 1`, len(got)) + } + if got[0].Name != "b" { + t.Fatalf("RootFunc lookup returned %q, want \"b\"", got[0].Name) + } +} + +func TestQuery_RootFuncRendersAndOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // RootFunc renders into the (func: ...) position and overwrites: the + // second call wins. + q := c.Query(ctx).RootFunc(`eq(name, "x")`).RootFunc(`eq(name, "y")`) + s := q.Raw().String() + if !strings.Contains(s, `func: eq(name, "y")`) { + t.Fatalf("second RootFunc not rendered; got:\n%s", s) + } + if strings.Contains(s, `eq(name, "x")`) { + t.Fatalf("first RootFunc still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_NameDecodesAfterRename(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + // Name renames the query block. dgman uses the name symmetrically to + // generate and decode, so a renamed block still decodes into []widget. + got, err := c.Query(ctx).Name("widgets").Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 3 { + t.Fatalf(`Name("widgets").Nodes() returned %d records, want 3`, len(got)) + } +} + +func TestQuery_NameRendersAndOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Name renders as the block name and overwrites: the second call wins. + q := c.Query(ctx).Name("first").Name("second") + s := q.Raw().String() + if !strings.Contains(s, "second(func:") { + t.Fatalf("second Name not rendered as block name; got:\n%s", s) + } + if strings.Contains(s, "first(func:") { + t.Fatalf("first Name still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_AsRendersAndOverwrites(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // As transitions to *RawQuery, prefixes the block with " as ", + // and overwrites: the second call wins. + q := c.Query(ctx).As("first").As("second") + if q == nil { + t.Fatal("As() returned nil *RawQuery") + } + s := q.String() + if !strings.Contains(s, "second as ") { + t.Fatalf("second As not rendered; got:\n%s", s) + } + if strings.Contains(s, "first as ") { + t.Fatalf("first As still present after overwrite; got:\n%s", s) + } +} + +func TestQuery_VarsRendersQueryPrefix(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Vars renders a "query " prefix on the generated DQL. + q := c.Query(ctx).Vars("getByName($n: string)", map[string]string{"$n": "b"}) + s := q.Raw().String() + if !strings.Contains(s, "query getByName($n: string)") { + t.Fatalf("Vars did not render the query-definition prefix; got:\n%s", s) + } +} + +func TestQuery_VarsParameterizedQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + for _, n := range []string{"a", "b", "c"} { + if err := c.Add(ctx, &widget{Name: n}); err != nil { + t.Fatalf("Add %s: %v", n, err) + } + } + // Vars supplies a GraphQL variable bound into the root function; the + // query executes via dgraph's QueryWithVars path. + got, err := c.Query(ctx). + Vars("getByName($n: string)", map[string]string{"$n": "b"}). + RootFunc("eq(name, $n)"). + Nodes() + if err != nil { + t.Fatalf("Vars query Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "b" { + t.Fatalf(`Vars parameterized query returned %+v, want one widget named "b"`, got) + } +} + +func TestQuery_VarReturnsRawQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Var transitions to *RawQuery and emits a var block: dgman renders the + // block name as "var". + rq := c.Query(ctx).Var() + if rq == nil { + t.Fatal("Var() returned nil *RawQuery") + } + s := rq.String() + if !strings.Contains(s, "var(func:") { + t.Fatalf("Var() did not render a var block; got:\n%s", s) + } +} + +func TestQuery_GroupByReturnsRawQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // GroupBy transitions to *RawQuery and emits an @groupby clause. + rq := c.Query(ctx).GroupBy("name") + if rq == nil { + t.Fatal("GroupBy() returned nil *RawQuery") + } + s := rq.String() + if !strings.Contains(s, "@groupby(name)") { + t.Fatalf(`GroupBy("name") did not render an @groupby clause; got:\n%s`, s) + } +} + +func TestRawQuery_RawExposesUnderlyingQuery(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + rq := c.Query(ctx).Var() + // Raw returns the underlying *dg.Query; String mirrors Raw().String(). + var raw *dg.Query = rq.Raw() + if raw == nil { + t.Fatal("RawQuery.Raw() returned nil") + } + if rq.String() != raw.String() { + t.Fatalf("RawQuery.String() and Raw().String() differ:\n%s\n---\n%s", + rq.String(), raw.String()) + } +} + +func TestRawQuery_GroupByThenVarChains(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // RawQuery re-exposes Var and GroupBy so the canonical .GroupBy(...).Var() + // composition still chains; both clauses survive. + s := c.Query(ctx).GroupBy("name").Var().String() + if !strings.Contains(s, "@groupby(name)") { + t.Fatalf("@groupby clause missing after GroupBy().Var(); got:\n%s", s) + } + if !strings.Contains(s, "var(func:") { + t.Fatalf("var block missing after GroupBy().Var(); got:\n%s", s) + } +} + +func TestRawQuery_CarriesEarlierBuilders(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + // Builders applied on *Query[T] before the GroupBy transition survive + // into the *RawQuery — the two share one underlying *dg.Query. + s := c.Query(ctx).Filter(`eq(name, "z")`).GroupBy("name").String() + if !strings.Contains(s, `eq(name, "z")`) { + t.Fatalf("Filter set before GroupBy did not survive the transition; got:\n%s", s) + } + if !strings.Contains(s, "@groupby(name)") { + t.Fatalf("@groupby clause missing; got:\n%s", s) + } +} + +// seedOwners inserts owner/pet pairs over conn for the WhereEdge tests. Each +// map entry is one owner owning one pet of the given name; the pet is inserted +// first so the owner's edge links an already-persisted node. It returns an +// owner client bound to conn. +func seedOwners(ctx context.Context, t *testing.T, conn modusgraph.Client, ownerToPet map[string]string) *typed.Client[owner] { + t.Helper() + pets := typed.NewClient[pet](conn) + owners := typed.NewClient[owner](conn) + for ownerName, petName := range ownerToPet { + p := &pet{Name: petName} + if err := pets.Add(ctx, p); err != nil { + t.Fatalf("Add pet %q: %v", petName, err) + } + if err := owners.Add(ctx, &owner{Name: ownerName, Pets: []*pet{p}}); err != nil { + t.Fatalf("Add owner %q: %v", ownerName, err) + } + } + return owners +} + +func TestQuery_WhereEdgeFiltersByEdgeTarget(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Rex", + "Carol": "Fido", + }) + + // WhereEdge constrains owners by a scalar of the pet reached over the + // "pets" edge — something a root Filter cannot express. + got, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Fido")`).Nodes() + if err != nil { + t.Fatalf("WhereEdge Nodes: %v", err) + } + if len(got) != 2 { + t.Fatalf("WhereEdge(pets, name=Fido) returned %d owners, want 2 (Alice, Carol)", len(got)) + } + for _, o := range got { + if o.Name != "Alice" && o.Name != "Carol" { + t.Fatalf("WhereEdge returned %q, want only Fido owners (Alice, Carol)", o.Name) + } + } +} + +func TestQuery_WhereEdgeNoMatchReturnsEmpty(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{"Alice": "Fido", "Bob": "Rex"}) + + // No pet is named Nemo: the pre-pass matches zero roots, so Nodes returns + // an empty result — not an error — and never runs the main query. + got, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Nemo")`).Nodes() + if err != nil { + t.Fatalf("WhereEdge Nodes: unexpected error %v", err) + } + if len(got) != 0 { + t.Fatalf("WhereEdge for an unowned pet name returned %d owners, want 0", len(got)) + } +} + +func TestQuery_WhereEdgeBindsParams(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{"Alice": "Fido", "Bob": "Rex"}) + + // The $1 placeholder in a WhereEdge filter binds exactly as it does for Filter. + got, err := owners.Query(ctx).WhereEdge("pets", "eq(name, $1)", "Rex").Nodes() + if err != nil { + t.Fatalf("WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Bob" { + t.Fatalf("WhereEdge(pets, name=$1, Rex) returned %+v, want [Bob]", got) + } +} + +func TestQuery_WhereEdgeCombinesWithFilter(t *testing.T) { + ctx := context.Background() + // Alice and Carol both own a Fido; a root Filter on name narrows to Alice. + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Rex", + "Carol": "Fido", + }) + + got, err := owners.Query(ctx). + Filter(`eq(name, "Alice")`). + WhereEdge("pets", `eq(name, "Fido")`). + Nodes() + if err != nil { + t.Fatalf("Filter+WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Alice" { + t.Fatalf("Filter(name=Alice)+WhereEdge(pets,name=Fido) returned %+v, want [Alice]", got) + } +} + +func TestQuery_WhereEdgeMultipleConstraintsAnd(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + pets := typed.NewClient[pet](conn) + owners := typed.NewClient[owner](conn) + + // Alice owns both Fido and Rex; Bob owns only Fido. + fido, rex := &pet{Name: "Fido"}, &pet{Name: "Rex"} + for _, p := range []*pet{fido, rex} { + if err := pets.Add(ctx, p); err != nil { + t.Fatalf("Add pet %q: %v", p.Name, err) + } + } + if err := owners.Add(ctx, &owner{Name: "Alice", Pets: []*pet{fido, rex}}); err != nil { + t.Fatalf("Add Alice: %v", err) + } + if err := owners.Add(ctx, &owner{Name: "Bob", Pets: []*pet{fido}}); err != nil { + t.Fatalf("Add Bob: %v", err) + } + + // Two WhereEdge calls AND together: only an owner of BOTH pets survives. + got, err := owners.Query(ctx). + WhereEdge("pets", `eq(name, "Fido")`). + WhereEdge("pets", `eq(name, "Rex")`). + Nodes() + if err != nil { + t.Fatalf("two-WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Alice" { + t.Fatalf("WhereEdge(Fido) AND WhereEdge(Rex) returned %+v, want [Alice]", got) + } +} + +func TestQuery_WhereEdgeFirst(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{"Alice": "Fido", "Bob": "Rex"}) + + // First runs the pre-pass too: it returns the Rex owner, never a Fido one. + got, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Rex")`).First() + if err != nil { + t.Fatalf("WhereEdge First: %v", err) + } + if got == nil || got.Name != "Bob" { + t.Fatalf("WhereEdge(pets,name=Rex).First() = %+v, want Bob", got) + } + + // First with an edge constraint nothing satisfies is (nil, nil). + none, err := owners.Query(ctx).WhereEdge("pets", `eq(name, "Nemo")`).First() + if err != nil { + t.Fatalf("WhereEdge First no-match: unexpected error %v", err) + } + if none != nil { + t.Fatalf("WhereEdge First with no match = %+v, want nil", none) + } +} + +func TestQuery_WhereEdgeIterNodes(t *testing.T) { + ctx := context.Background() + owners := seedOwners(ctx, t, newConn(t), map[string]string{ + "Alice": "Fido", + "Bob": "Rex", + "Carol": "Fido", + }) + + seen := 0 + for o, err := range owners.Query(ctx).WhereEdge("pets", `eq(name, "Fido")`).IterNodes() { + if err != nil { + t.Fatalf("WhereEdge IterNodes yielded error: %v", err) + } + if o.Name != "Alice" && o.Name != "Carol" { + t.Fatalf("WhereEdge IterNodes yielded %q, want a Fido owner", o.Name) + } + seen++ + } + if seen != 2 { + t.Fatalf("WhereEdge IterNodes streamed %d owners, want 2", seen) + } +} + +func TestQuery_UIDRootsAtNode(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + w := &widget{Name: "sprocket", Qty: 3} + if err := c.Add(ctx, w); err != nil { + t.Fatalf("Add: %v", err) + } + + got, err := c.Query(ctx).UID(w.UID).Nodes() + if err != nil { + t.Fatalf("Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "sprocket" { + t.Fatalf("UID query returned %+v, want one widget named sprocket", got) + } +} + +func TestQuery_NodesAndCountReturnsTotal(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + for i := 0; i < 3; i++ { + if err := c.Add(ctx, &widget{Name: "w", Qty: i}); err != nil { + t.Fatalf("Add: %v", err) + } + } + + nodes, count, err := c.Query(ctx).NodesAndCount() + if err != nil { + t.Fatalf("NodesAndCount: %v", err) + } + if count != 3 || len(nodes) != 3 { + t.Fatalf("got count=%d len=%d, want 3 and 3", count, len(nodes)) + } +} + +func TestQuery_AllSetsTraversalDepth(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + if err := c.Add(ctx, &widget{Name: "deep", Qty: 1}); err != nil { + t.Fatalf("Add: %v", err) + } + + // All(1) overrides the default traversal depth for this query; the call + // must chain and the query must still execute and decode. + got, err := c.Query(ctx).All(1).Nodes() + if err != nil { + t.Fatalf("Nodes with All(1): %v", err) + } + if len(got) != 1 { + t.Fatalf("got %d widgets, want 1", len(got)) + } +} + +func TestQuery_StringRendersDQL(t *testing.T) { + ctx := context.Background() + c := typed.NewClient[widget](newConn(t)) + + dql := c.Query(ctx).Filter("eq(name, $1)", "sprocket").String() + if !strings.Contains(dql, "widget") { + t.Fatalf("String() = %q, want it to mention the widget type", dql) + } +} diff --git a/typed/search/merge.go b/typed/search/merge.go new file mode 100644 index 0000000..2546274 --- /dev/null +++ b/typed/search/merge.go @@ -0,0 +1,27 @@ +// Package search provides helpers for assembling fulltext / ranked search +// results across multiple typed query blocks. +package search + +// MergeByID concatenates inputs into a single slice while preserving +// first-seen order and dropping any subsequent occurrence of an ID already +// emitted. The id function extracts a comparable identifier from each row. +// +// MergeByID is intended for use after typed.MultiQuery.Execute, when +// consumers want a single ranked slice from N per-field result sets: +// inputs[0] takes priority, inputs[1] fills in next, etc. A nil result +// indicates no rows survived (the inputs were all empty). +func MergeByID[T any](id func(T) string, inputs ...[]T) []T { + seen := make(map[string]struct{}) + var out []T + for _, in := range inputs { + for _, row := range in { + k := id(row) + if _, dup := seen[k]; dup { + continue + } + seen[k] = struct{}{} + out = append(out, row) + } + } + return out +} diff --git a/typed/search/merge_test.go b/typed/search/merge_test.go new file mode 100644 index 0000000..e4e8583 --- /dev/null +++ b/typed/search/merge_test.go @@ -0,0 +1,86 @@ +package search_test + +import ( + "reflect" + "testing" + + "github.com/matthewmcneely/modusgraph/typed/search" +) + +type rec struct { + ID string + Tag string +} + +func id(r rec) string { return r.ID } + +func TestMergeByID(t *testing.T) { + cases := []struct { + name string + inputs [][]rec + want []rec + }{ + { + name: "empty inputs returns nil", + inputs: nil, + want: nil, + }, + { + name: "single empty slice returns nil", + inputs: [][]rec{{}}, + want: nil, + }, + { + name: "single slice returns it as-is", + inputs: [][]rec{{ + {ID: "a", Tag: "1"}, + {ID: "b", Tag: "1"}, + }}, + want: []rec{ + {ID: "a", Tag: "1"}, + {ID: "b", Tag: "1"}, + }, + }, + { + name: "two slices merge in priority order", + inputs: [][]rec{ + {{ID: "a", Tag: "name"}}, + {{ID: "b", Tag: "desc"}}, + }, + want: []rec{ + {ID: "a", Tag: "name"}, + {ID: "b", Tag: "desc"}, + }, + }, + { + name: "duplicate ID keeps first-seen entry", + inputs: [][]rec{ + {{ID: "a", Tag: "name"}}, + {{ID: "a", Tag: "desc"}, {ID: "b", Tag: "desc"}}, + }, + want: []rec{ + {ID: "a", Tag: "name"}, + {ID: "b", Tag: "desc"}, + }, + }, + { + name: "intra-slice duplicates dedup too", + inputs: [][]rec{ + {{ID: "a", Tag: "1"}, {ID: "a", Tag: "2"}, {ID: "b", Tag: "1"}}, + }, + want: []rec{ + {ID: "a", Tag: "1"}, + {ID: "b", Tag: "1"}, + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := search.MergeByID(id, c.inputs...) + if !reflect.DeepEqual(got, c.want) { + t.Fatalf("got %v, want %v", got, c.want) + } + }) + } +} diff --git a/typed/tracing.go b/typed/tracing.go new file mode 100644 index 0000000..8a456ed --- /dev/null +++ b/typed/tracing.go @@ -0,0 +1,58 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "reflect" +) + +// Span is a tracing span for a single database operation. End is called once, +// with the operation's final error (nil on success). +type Span interface { + End(err error) +} + +// Tracer starts a Span around a typed-layer database operation. The typed +// client calls the installed Tracer for every DB call; the default is a no-op, +// so the typed package itself carries no tracing dependency. Install a real +// tracer — for example github.com/mlwelles/modusgraph-telemetry's OpenTelemetry +// tracer — with SetTracer. +type Tracer interface { + // StartSpan begins a span for operation op (for example "get") on the named + // collection, returning a context carrying the span and the Span itself. + StartSpan(ctx context.Context, op, collection string) (context.Context, Span) +} + +type noopSpan struct{} + +func (noopSpan) End(error) {} + +type noopTracer struct{} + +func (noopTracer) StartSpan(ctx context.Context, _, _ string) (context.Context, Span) { + return ctx, noopSpan{} +} + +// tracer is the process-wide tracer the typed package uses. It is a no-op until +// a host installs one via SetTracer. +var tracer Tracer = noopTracer{} + +// SetTracer installs the process-wide tracer for typed-layer DB spans. Passing +// nil restores the no-op tracer. Install once during startup; it is not safe to +// call concurrently with active queries. +func SetTracer(t Tracer) { + if t == nil { + t = noopTracer{} + } + tracer = t +} + +// entityName returns the unqualified Go type name of T (for example "Resource"), +// used as the db.collection.name span attribute. +func entityName[T any]() string { + return reflect.TypeFor[T]().Name() +} diff --git a/typed/tracing_test.go b/typed/tracing_test.go new file mode 100644 index 0000000..d9aab78 --- /dev/null +++ b/typed/tracing_test.go @@ -0,0 +1,47 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed + +import ( + "context" + "testing" +) + +func TestSetTracer_InstallsAndResets(t *testing.T) { + t.Cleanup(func() { SetTracer(nil) }) + + rec := &recordingTracer{} + SetTracer(rec) + + _, span := tracer.StartSpan(context.Background(), "get", "Widget") + span.End(nil) + + if rec.op != "get" || rec.collection != "Widget" { + t.Fatalf("installed tracer not invoked: %+v", rec) + } + if !rec.ended { + t.Fatal("span.End was not called") + } + + // nil restores the no-op tracer, which must not panic. + SetTracer(nil) + _, span = tracer.StartSpan(context.Background(), "x", "Y") + span.End(nil) +} + +type recordingTracer struct { + op, collection string + ended bool +} + +func (r *recordingTracer) StartSpan(ctx context.Context, op, collection string) (context.Context, Span) { + r.op, r.collection = op, collection + return ctx, &recordingSpan{r} +} + +type recordingSpan struct{ r *recordingTracer } + +func (s *recordingSpan) End(error) { s.r.ended = true } From b477c28daec118d7eb2da55c3ed194891931381a Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 13:11:17 -0400 Subject: [PATCH 02/24] feat: aborted-transaction retry policy, runner, and client integration Add RetryPolicy / DefaultRetryPolicy and a runner that re-executes a function on aborted Dgraph transactions with exponential backoff (retry.go), exposed on the client via a WithRetry method. --- client.go | 3 + retry.go | 96 ++++++++++++++++++++ retry_internal_test.go | 68 ++++++++++++++ retry_test.go | 197 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 364 insertions(+) create mode 100644 retry.go create mode 100644 retry_internal_test.go create mode 100644 retry_test.go diff --git a/client.go b/client.go index be9813b..e4bb263 100644 --- a/client.go +++ b/client.go @@ -87,6 +87,9 @@ type Client interface { // DgraphClient returns a gRPC Dgraph client from the connection pool and a cleanup function. // The cleanup function must be called when finished with the client to return it to the pool. DgraphClient() (*dgo.Dgraph, func(), error) + + // WithRetry executes fn, retrying on aborted transactions per policy. + WithRetry(ctx context.Context, policy RetryPolicy, fn func() error) error } const ( diff --git a/retry.go b/retry.go new file mode 100644 index 0000000..9b49fda --- /dev/null +++ b/retry.go @@ -0,0 +1,96 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "context" + "errors" + "math/rand/v2" + "time" + + "github.com/dgraph-io/dgo/v250" +) + +// RetryPolicy controls how WithRetry handles aborted transactions. +// Modeled after dgraph4j's RetryPolicy: exponential backoff with jitter. +type RetryPolicy struct { + // MaxRetries is the maximum number of retry attempts after the initial try. + MaxRetries int + + // BaseDelay is the initial delay before the first retry. + // Subsequent delays grow exponentially: BaseDelay * 2^attempt. + BaseDelay time.Duration + + // MaxDelay caps the backoff duration. No single delay exceeds this. + MaxDelay time.Duration + + // Jitter adds randomness to each delay to prevent thundering herd. + // Expressed as a fraction of the computed delay (e.g. 0.1 = 10%). + Jitter float64 +} + +// DefaultRetryPolicy mirrors dgraph4j's defaults: +// 5 retries, 100ms base delay, 5s max delay, 10% jitter. +var DefaultRetryPolicy = RetryPolicy{ + MaxRetries: 10, + BaseDelay: 100 * time.Millisecond, + MaxDelay: 5 * time.Second, + Jitter: 0.1, +} + +// delay computes the backoff duration for a given attempt (0-indexed). +// Formula: min(BaseDelay * 2^attempt, MaxDelay) + random(0, delay * Jitter) +func (p RetryPolicy) delay(attempt int) time.Duration { + d := p.BaseDelay * time.Duration(1< p.MaxDelay { + d = p.MaxDelay + } + if p.Jitter > 0 { + d += time.Duration(float64(d) * p.Jitter * rand.Float64()) + } + return d +} + +// WithRetry executes fn, retrying on aborted transactions according to policy. +// +// This is an opt-in mechanism modeled after dgraph4j's client.withRetry(). +// The caller wraps their mutation logic in fn; WithRetry handles creating +// fresh attempts with exponential backoff when Dgraph returns a transaction +// abort due to concurrent conflicts. +// +// fn is called at least once. On each aborted-transaction error, WithRetry +// waits according to the policy's backoff schedule and calls fn again, up to +// policy.MaxRetries additional times. Non-abort errors are returned immediately. +// +// The context is checked between retries; if cancelled during a backoff sleep, +// the context error is returned. +// +// Usage: +// +// err := client.WithRetry(ctx, modusgraph.DefaultRetryPolicy, func() error { +// return client.Insert(ctx, &entity) +// }) +func (c client) WithRetry(ctx context.Context, policy RetryPolicy, fn func() error) error { + for attempt := range policy.MaxRetries + 1 { + err := fn() + if err == nil { + return nil + } + if !errors.Is(err, dgo.ErrAborted) || attempt >= policy.MaxRetries { + return err + } + d := policy.delay(attempt) + c.logger.V(1).Info("Transaction aborted, retrying", + "attempt", attempt+1, "maxRetries", policy.MaxRetries, "delay", d) + select { + case <-time.After(d): + case <-ctx.Done(): + return ctx.Err() + } + } + // Unreachable: the loop runs MaxRetries+1 times and returns on every path. + panic("unreachable") +} diff --git a/retry_internal_test.go b/retry_internal_test.go new file mode 100644 index 0000000..ce6bd2b --- /dev/null +++ b/retry_internal_test.go @@ -0,0 +1,68 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestRetryPolicyDelayExponentialGrowth(t *testing.T) { + p := RetryPolicy{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 10 * time.Second, + Jitter: 0, + } + + assert.Equal(t, 100*time.Millisecond, p.delay(0)) + assert.Equal(t, 200*time.Millisecond, p.delay(1)) + assert.Equal(t, 400*time.Millisecond, p.delay(2)) + assert.Equal(t, 800*time.Millisecond, p.delay(3)) + assert.Equal(t, 1600*time.Millisecond, p.delay(4)) +} + +func TestRetryPolicyDelayMaxCap(t *testing.T) { + p := RetryPolicy{ + BaseDelay: 1 * time.Second, + MaxDelay: 3 * time.Second, + Jitter: 0, + } + + assert.Equal(t, 1*time.Second, p.delay(0)) + assert.Equal(t, 2*time.Second, p.delay(1)) + assert.Equal(t, 3*time.Second, p.delay(2)) + assert.Equal(t, 3*time.Second, p.delay(3)) + assert.Equal(t, 3*time.Second, p.delay(10)) +} + +func TestRetryPolicyDelayWithJitter(t *testing.T) { + p := RetryPolicy{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 10 * time.Second, + Jitter: 0.5, + } + + for range 100 { + d := p.delay(0) + assert.GreaterOrEqual(t, d, 100*time.Millisecond, "delay should be at least base") + assert.LessOrEqual(t, d, 150*time.Millisecond, "delay should not exceed base + 50% jitter") + } +} + +func TestRetryPolicyDelayZeroJitter(t *testing.T) { + p := RetryPolicy{ + BaseDelay: 100 * time.Millisecond, + MaxDelay: 10 * time.Second, + Jitter: 0, + } + + for range 10 { + assert.Equal(t, 100*time.Millisecond, p.delay(0)) + assert.Equal(t, 200*time.Millisecond, p.delay(1)) + } +} diff --git a/retry_test.go b/retry_test.go new file mode 100644 index 0000000..4cb0d86 --- /dev/null +++ b/retry_test.go @@ -0,0 +1,197 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "fmt" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/matthewmcneely/modusgraph" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// RetryEntity is a test struct with a unique index to provoke transaction conflicts. +type RetryEntity struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=term,exact upsert"` + Value int `json:"value,omitempty"` +} + +// TestConcurrentInsertsWithRetry verifies that WithRetry handles aborted +// transactions from concurrent inserts. Without WithRetry, concurrent inserts +// on the same predicate index would fail with dgo.ErrAborted. +func TestConcurrentInsertsWithRetry(t *testing.T) { + testCases := []struct { + name string + uri string + skip bool + }{ + { + name: "FileURI", + uri: "file://" + GetTempDir(t), + }, + { + name: "DgraphURI", + uri: "dgraph://" + os.Getenv("MODUSGRAPH_TEST_ADDR"), + skip: os.Getenv("MODUSGRAPH_TEST_ADDR") == "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.skip { + t.Skipf("Skipping %s: MODUSGRAPH_TEST_ADDR not set", tc.name) + return + } + + client, cleanup := CreateTestClient(t, tc.uri) + defer cleanup() + + ctx := context.Background() + const numWorkers = 8 + const entitiesPerWorker = 10 + + var succeeded atomic.Int64 + var wg sync.WaitGroup + + for w := range numWorkers { + wg.Add(1) + go func() { + defer wg.Done() + for i := range entitiesPerWorker { + entity := &RetryEntity{ + Name: fmt.Sprintf("entity-%d-%d", w, i), + Value: w*entitiesPerWorker + i, + } + err := client.WithRetry(ctx, modusgraph.DefaultRetryPolicy, func() error { + return client.Insert(ctx, entity) + }) + if err != nil { + t.Errorf("worker %d entity %d: %v", w, i, err) + return + } + succeeded.Add(1) + } + }() + } + wg.Wait() + + total := int64(numWorkers * entitiesPerWorker) + require.Equal(t, total, succeeded.Load(), + "all concurrent inserts should succeed with retry") + }) + } +} + +// TestWithRetryContextCancellation verifies that WithRetry respects context +// cancellation during backoff sleeps. +func TestWithRetryContextCancellation(t *testing.T) { + uri := "file://" + GetTempDir(t) + client, cleanup := CreateTestClient(t, uri) + defer cleanup() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + // Use a policy with a long delay so the context expires during backoff. + slowPolicy := modusgraph.RetryPolicy{ + MaxRetries: 10, + BaseDelay: 1 * time.Second, + MaxDelay: 5 * time.Second, + Jitter: 0, + } + + callCount := 0 + err := client.WithRetry(ctx, slowPolicy, func() error { + callCount++ + // Always return an error that looks like an abort to trigger retry. + // We simulate this by inserting a duplicate to get a UniqueError, + // but that won't be retried. Instead, use a real insert to a fresh + // entity so the first call succeeds. + // Actually, to test the cancellation path we need the fn to always + // fail with an aborted error. Since we can't easily manufacture + // dgo.ErrAborted, test that context cancellation returns ctx.Err() + // by having fn block until context is done. + <-ctx.Done() + return ctx.Err() + }) + + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Equal(t, 1, callCount, "fn should be called once before context expires") +} + +// TestRetryPolicyDelay verifies the exponential backoff calculation. +func TestRetryPolicyDelay(t *testing.T) { + // Use the public struct fields to verify delay behavior indirectly + // by checking that DefaultRetryPolicy has the expected values. + p := modusgraph.DefaultRetryPolicy + assert.Equal(t, 10, p.MaxRetries) + assert.Equal(t, 100*time.Millisecond, p.BaseDelay) + assert.Equal(t, 5*time.Second, p.MaxDelay) + assert.InDelta(t, 0.1, p.Jitter, 0.001) +} + +// TestWithRetryNonAbortError verifies that non-abort errors are returned +// immediately without any retry. +func TestWithRetryNonAbortError(t *testing.T) { + uri := "file://" + GetTempDir(t) + client, cleanup := CreateTestClient(t, uri) + defer cleanup() + + callCount := 0 + expectedErr := fmt.Errorf("not an abort error") + + err := client.WithRetry(context.Background(), modusgraph.DefaultRetryPolicy, func() error { + callCount++ + return expectedErr + }) + + assert.ErrorIs(t, err, expectedErr) + assert.Equal(t, 1, callCount, "non-abort errors should not trigger retry") +} + +// TestWithRetrySucceedsFirstTry verifies that WithRetry returns nil +// when fn succeeds on the first call. +func TestWithRetrySucceedsFirstTry(t *testing.T) { + uri := "file://" + GetTempDir(t) + client, cleanup := CreateTestClient(t, uri) + defer cleanup() + + callCount := 0 + err := client.WithRetry(context.Background(), modusgraph.DefaultRetryPolicy, func() error { + callCount++ + return nil + }) + + assert.NoError(t, err) + assert.Equal(t, 1, callCount) +} + +// TestWithRetryMaxRetriesZero verifies that MaxRetries=0 calls fn exactly once +// and returns any error without retrying. +func TestWithRetryMaxRetriesZero(t *testing.T) { + uri := "file://" + GetTempDir(t) + client, cleanup := CreateTestClient(t, uri) + defer cleanup() + + policy := modusgraph.RetryPolicy{MaxRetries: 0} + callCount := 0 + + err := client.WithRetry(context.Background(), policy, func() error { + callCount++ + return fmt.Errorf("always fails") + }) + + assert.Error(t, err) + assert.Equal(t, 1, callCount, "MaxRetries=0 should call fn exactly once") +} From b74cec750218c51c12e8282a86ebdbce1d8ebc19 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 13:48:08 -0400 Subject: [PATCH 03/24] feat: recognize generated schema types via SchemaTypeName + UnwrapSchema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the Schema interface (SchemaTypeName), the UnwrapSchema reflection helper, and the DgraphMapper interface (record.go). The client unwraps schema-defining values at the mutation and query boundary so generated wrapper types route to their backing schema struct. Plain structs do not implement Schema and are unaffected — UnwrapSchema is identity for them. --- client.go | 9 ++++ record.go | 58 ++++++++++++++++++++++++ record_test.go | 117 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 184 insertions(+) create mode 100644 record.go create mode 100644 record_test.go diff --git a/client.go b/client.go index be9813b..14e38ee 100644 --- a/client.go +++ b/client.go @@ -486,6 +486,7 @@ func (c client) validateStruct(ctx context.Context, obj any) error { // Insert implements inserting an object or slice of objects in the database. // Passed object must be a pointer to a struct with appropriate dgraph tags. func (c client) Insert(ctx context.Context, obj any) error { + obj = UnwrapSchema(obj) // Validate struct before insertion if err := c.validateStruct(ctx, obj); err != nil { return err @@ -503,6 +504,7 @@ func (c client) Insert(ctx context.Context, obj any) error { // // Deprecated: InsertRaw is now identical to Insert. Use Insert instead. func (c client) InsertRaw(ctx context.Context, obj any) error { + obj = UnwrapSchema(obj) // Validate struct before insertion if err := c.validateStruct(ctx, obj); err != nil { return err @@ -518,6 +520,7 @@ func (c client) InsertRaw(ctx context.Context, obj any) error { // to be used for upserting. If none are specified, the first predicate with the `upsert` tag // will be used. func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error { + obj = UnwrapSchema(obj) // Validate struct before upsert if err := c.validateStruct(ctx, obj); err != nil { return err @@ -531,6 +534,7 @@ func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error // Update implements updating an existing object in the database. // Passed object must be a pointer to a struct. func (c client) Update(ctx context.Context, obj any) error { + obj = UnwrapSchema(obj) // Validate struct before update if err := c.validateStruct(ctx, obj); err != nil { return err @@ -557,6 +561,7 @@ func (c client) Delete(ctx context.Context, uids []string) error { // Get implements retrieving a single object by its UID. // Passed object must be a pointer to a struct. func (c client) Get(ctx context.Context, obj any, uid string) error { + obj = UnwrapSchema(obj) err := checkPointer(obj) if err != nil { return err @@ -575,6 +580,7 @@ func (c client) Get(ctx context.Context, obj any, uid string) error { // Returns a *dg.Query that can be further refined with filters, pagination, etc. // The returned query will be limited to the maximum number of edges specified in the options. func (c client) Query(ctx context.Context, model any) *dg.Query { + model = UnwrapSchema(model) client, err := c.pool.get() if err != nil { return nil @@ -590,6 +596,9 @@ func (c client) Query(ctx context.Context, model any) *dg.Query { // If any object contains SimString fields tagged `dgraph:"embedding"`, the // corresponding shadow float32vector predicates (__vec) are also registered. func (c client) UpdateSchema(ctx context.Context, obj ...any) error { + for i := range obj { + obj[i] = UnwrapSchema(obj[i]) + } dgClient, err := c.pool.get() if err != nil { c.logger.Error(err, "Failed to get client from pool") diff --git a/record.go b/record.go new file mode 100644 index 0000000..015c587 --- /dev/null +++ b/record.go @@ -0,0 +1,58 @@ +package modusgraph + +import "reflect" + +// Schema identifies a value as a record of a generated schema-defining type. +// modusgraph-gen-emitted schema structs implement this via a generated +// SchemaTypeName() method that returns the canonical entity name +// (e.g. "Studio"). The interface is intentionally minimal — a single method +// returning a useful piece of metadata. +// +// Plain user structs (not emitted by modusgraph-gen) do not implement Schema +// and are unaffected by the modusgraph.Client routing it enables; they pass +// through to the existing reflection-based dgman pipeline exactly as before. +type Schema interface { + SchemaTypeName() string +} + +// UnwrapSchema returns the schema-defining record contained in obj. If obj +// is nil, it is returned as-is. If obj is already a Schema, it is returned +// as-is. If obj exposes an Unwrap() method whose return value satisfies +// Schema, that return is substituted. Otherwise obj is returned unchanged. +// +// This is the bridge between modusgraph-gen-emitted wrapper types and the +// rest of modusgraph.Client. It is purely additive: types that don't +// implement Schema and don't have an Unwrap() method (i.e. existing +// modusgraph users' plain structs) pass through untouched. +// +// Note on errors.Unwrap overlap: Go's errors package uses Unwrap() error +// as the standard "give me the wrapped thing" method. UnwrapSchema's +// secondary check (the returned value must itself implement Schema) means +// an error wrapper is not mistaken for a modusgraph wrapper — the +// reflection probe finds Unwrap(), calls it, gets an error, fails the +// Schema check, and returns the original obj. +func UnwrapSchema(obj any) any { + if obj == nil { + return obj + } + if _, ok := obj.(Schema); ok { + return obj + } + v := reflect.ValueOf(obj) + if !v.IsValid() { + return obj + } + m := v.MethodByName("Unwrap") + if !m.IsValid() { + return obj + } + mt := m.Type() + if mt.NumIn() != 0 || mt.NumOut() != 1 { + return obj + } + inner := m.Call(nil)[0].Interface() + if _, ok := inner.(Schema); ok { + return inner + } + return obj +} diff --git a/record_test.go b/record_test.go new file mode 100644 index 0000000..1f6ef72 --- /dev/null +++ b/record_test.go @@ -0,0 +1,117 @@ +package modusgraph + +import ( + "errors" + "testing" +) + +type fakeRecord struct{ name string } + +func (f *fakeRecord) SchemaTypeName() string { return f.name } + +type fakeWrapper struct{ inner *fakeRecord } + +func (w *fakeWrapper) Unwrap() *fakeRecord { return w.inner } + +type fakeNonSchema struct{ X string } + +func TestUnwrapSchema_PassthroughForPlainStruct(t *testing.T) { + in := &fakeNonSchema{X: "hi"} + out := UnwrapSchema(in) + if out != any(in) { + t.Fatalf("expected passthrough, got %T", out) + } +} + +func TestUnwrapSchema_PassthroughForSchemaStruct(t *testing.T) { + in := &fakeRecord{name: "Studio"} + out := UnwrapSchema(in) + if out != any(in) { + t.Fatalf("expected passthrough for direct Schema, got %T", out) + } +} + +func TestUnwrapSchema_UnwrapsWrapper(t *testing.T) { + inner := &fakeRecord{name: "Studio"} + w := &fakeWrapper{inner: inner} + out := UnwrapSchema(w) + if out != any(inner) { + t.Fatalf("expected unwrapped inner, got %T (%v)", out, out) + } +} + +func TestUnwrapSchema_IgnoresErrorsUnwrap(t *testing.T) { + // errors.New("x") has no Unwrap; wrap one to get something with Unwrap() error. + inner := errors.New("inner") + outer := &wrappedErr{err: inner} + out := UnwrapSchema(outer) + if out != any(outer) { + t.Fatalf("expected passthrough for error wrapper, got %T", out) + } +} + +type wrappedErr struct{ err error } + +func (w *wrappedErr) Error() string { return w.err.Error() } +func (w *wrappedErr) Unwrap() error { return w.err } + +func TestUnwrapSchema_NilInput(t *testing.T) { + if out := UnwrapSchema(nil); out != nil { + t.Fatalf("expected nil for nil input, got %v", out) + } +} + +// recordingClient is the minimal surface needed to verify that wrappers +// passed to the Client interface get unwrapped before reaching internal +// reflection. It records whatever it received and returns nil. Each method +// applies obj = UnwrapSchema(obj) at the top, mirroring the patch landing +// in this task. +type recordingClient struct { + seen []any +} + +func (c *recordingClient) capture(obj any) any { + obj = UnwrapSchema(obj) + c.seen = append(c.seen, obj) + return obj +} + +func TestUnwrapSchema_CaptureForwardsInner(t *testing.T) { + inner := &fakeRecord{name: "Studio"} + w := &fakeWrapper{inner: inner} + c := &recordingClient{} + got := c.capture(w) + if got != any(inner) { + t.Fatalf("expected inner record, got %T (%v)", got, got) + } + if len(c.seen) != 1 || c.seen[0] != any(inner) { + t.Fatalf("expected recording to hold inner record, got %v", c.seen) + } +} + +func TestUnwrapSchema_CapturePassthroughForPlain(t *testing.T) { + plain := &fakeNonSchema{X: "y"} + c := &recordingClient{} + got := c.capture(plain) + if got != any(plain) { + t.Fatalf("expected plain struct passthrough, got %T", got) + } +} + +func TestUnwrapSchema_VariadicUnwrapsEachElement(t *testing.T) { + innerA := &fakeRecord{name: "Studio"} + innerB := &fakeRecord{name: "Film"} + templates := []any{ + &fakeWrapper{inner: innerA}, + innerB, // already a Schema; passthrough + } + for i, obj := range templates { + templates[i] = UnwrapSchema(obj) + } + if templates[0] != any(innerA) { + t.Fatalf("template[0]: expected innerA, got %T", templates[0]) + } + if templates[1] != any(innerB) { + t.Fatalf("template[1]: expected innerB (passthrough), got %T", templates[1]) + } +} From a4997cf7a0ae6b8b2c6d75d09f068c13bd5879c0 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 16:00:04 -0400 Subject: [PATCH 04/24] ci: drop redundant Dgraph standalone from -short unit job The unit-test job runs `go test -short`, which skips every test that needs a live Dgraph. Standing up a dgraph/standalone container (and setting MODUSGRAPH_TEST_ADDR) therefore adds setup the job never uses. Remove both; the integration and load suites keep their own dedicated jobs. --- .github/workflows/ci-go-unit-tests.yaml | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/.github/workflows/ci-go-unit-tests.yaml b/.github/workflows/ci-go-unit-tests.yaml index 3623538..31878ab 100644 --- a/.github/workflows/ci-go-unit-tests.yaml +++ b/.github/workflows/ci-go-unit-tests.yaml @@ -39,22 +39,5 @@ jobs: go-version: 1.25.0 cache-dependency-path: go.sum - - name: Set up Dgraph - if: matrix.os == 'linux' - run: | - docker run -d --name dgraph-standalone -p 9080:9080 -p 8080:8080 dgraph/standalone:latest - echo "Waiting for Dgraph to be ready..." - for i in {1..30}; do - if curl -s http://localhost:8080/health > /dev/null; then - echo "Dgraph is ready!" - break - fi - echo "Attempt $i: Dgraph not ready, waiting..." - sleep 2 - done - sleep 5 - - name: Run Unit Tests - env: - MODUSGRAPH_TEST_ADDR: ${{ matrix.os == 'linux' && 'localhost:9080' || '' }} run: go test -short -race -v . From 616065df67aa02200717d6ad71e3f4f75b641720 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 16:00:09 -0400 Subject: [PATCH 05/24] chore: ignore IDE dirs, query binary, benchmark output, worktrees Add common local artifacts to .gitignore: editor config (.idea/, .vscode/), the built ./query binary, load_test benchmark JSON, and git worktrees. --- .gitignore | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/.gitignore b/.gitignore index a63304e..b75ce0d 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,16 @@ go.work.sum .env cpu_profile.prof + +# IDE config +.idea/ +.vscode/ + +# Built query binary +/query + +# Benchmark result files +load_test/*.json + +# git worktrees +.worktrees/ From eee3d018f8f11445e93abbe2e2804529a541993d Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 16:06:19 -0400 Subject: [PATCH 06/24] feat: WithGRPCDialOption for custom gRPC dial settings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds WithGRPCDialOption(opt grpc.DialOption), a general escape hatch for gRPC dial settings the dedicated options do not cover — TLS transport credentials, interceptors, keepalive, and so on — on remote (dgraph://) connections. The existing WithMaxRecvMsgSize is folded into the same dial-option assembly, so the two compose cleanly, and the client dedup key now counts the custom dial options so differently-configured clients are not merged. No change for embedded (file://) URIs. --- client.go | 33 ++++++++++++++++++++++++++++----- dial_options_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) create mode 100644 dial_options_test.go diff --git a/client.go b/client.go index be9813b..7834db2 100644 --- a/client.go +++ b/client.go @@ -124,6 +124,7 @@ type clientOptions struct { maxEdgeTraversal int cacheSizeMB int maxRecvMsgSize int + grpcDialOptions []grpc.DialOption namespace string logger logr.Logger validator StructValidator @@ -189,6 +190,18 @@ func WithMaxRecvMsgSize(size int) ClientOpt { } } +// WithGRPCDialOption appends a custom grpc.DialOption applied when opening a +// remote (dgraph://) connection. It is the general escape hatch for gRPC dial +// settings the dedicated options do not cover — TLS transport credentials, +// interceptors, keepalive parameters, and so on. May be supplied multiple +// times; the options are applied in the order given, after any option implied +// by WithMaxRecvMsgSize. Ignored for embedded (file://) URIs. +func WithGRPCDialOption(opt grpc.DialOption) ClientOpt { + return func(o *clientOptions) { + o.grpcDialOptions = append(o.grpcDialOptions, opt) + } +} + // WithValidator sets a validator instance for struct validation. // The validator will be used to validate structs before insert, upsert, and update operations. // If no validator is provided, validation will be skipped. @@ -279,16 +292,26 @@ func NewClient(uri string, opts ...ClientOpt) (Client, error) { client.logger.V(2).Info("Opening new Dgraph connection", "uri", uri) return dgo.Open(uri) } + // Assemble any custom gRPC dial options. maxRecvMsgSize is folded + // into the same mechanism as WithGRPCDialOption so the two compose. + var dialOpts []grpc.DialOption if options.maxRecvMsgSize > 0 { + dialOpts = append(dialOpts, + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(options.maxRecvMsgSize))) + } + dialOpts = append(dialOpts, options.grpcDialOptions...) + if len(dialOpts) > 0 { endpoint, dgoOpts, err := parseDgraphURI(uri) if err != nil { return nil, err } - dgoOpts = append(dgoOpts, dgo.WithGrpcOption( - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(options.maxRecvMsgSize)))) + for _, opt := range dialOpts { + dgoOpts = append(dgoOpts, dgo.WithGrpcOption(opt)) + } factory = func() (*dgo.Dgraph, error) { client.logger.V(2).Info("Opening new Dgraph connection", - "uri", uri, "maxRecvMsgSize", options.maxRecvMsgSize) + "uri", uri, "maxRecvMsgSize", options.maxRecvMsgSize, + "grpcDialOptions", len(options.grpcDialOptions)) return dgo.NewClient(endpoint, dgoOpts...) } } @@ -430,9 +453,9 @@ func (c client) key() string { if c.options.embeddingProvider != nil { embeddingKey = fmt.Sprintf("%p", c.options.embeddingProvider) } - return fmt.Sprintf("%s:%t:%d:%d:%d:%d:%s:%s:%s", c.uri, c.options.autoSchema, c.options.poolSize, + return fmt.Sprintf("%s:%t:%d:%d:%d:%d:%s:%s:%s:%d", c.uri, c.options.autoSchema, c.options.poolSize, c.options.maxEdgeTraversal, c.options.cacheSizeMB, c.options.maxRecvMsgSize, - c.options.namespace, validatorKey, embeddingKey) + c.options.namespace, validatorKey, embeddingKey, len(c.options.grpcDialOptions)) } // embeddingProvider implements the embeddingClient interface, exposing the diff --git a/dial_options_test.go b/dial_options_test.go new file mode 100644 index 0000000..c64e257 --- /dev/null +++ b/dial_options_test.go @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "testing" + + "google.golang.org/grpc" +) + +func TestWithGRPCDialOptionAppends(t *testing.T) { + var o clientOptions + WithGRPCDialOption(grpc.WithUserAgent("a"))(&o) + WithGRPCDialOption(grpc.WithUserAgent("b"))(&o) + if got := len(o.grpcDialOptions); got != 2 { + t.Fatalf("expected 2 dial options, got %d", got) + } +} + +func TestKeyDistinguishesGRPCDialOptions(t *testing.T) { + base := client{uri: "dgraph://localhost:9080"} + withOpt := client{uri: "dgraph://localhost:9080"} + WithGRPCDialOption(grpc.WithUserAgent("x"))(&withOpt.options) + if base.key() == withOpt.key() { + t.Fatal("client.key() must differ when grpcDialOptions differ, else clients dedup incorrectly") + } +} From bd16559b97c3206c9f4897318ed994b6b9c2cb7c Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 16:11:23 -0400 Subject: [PATCH 07/24] feat: AlterSchema, dropPredicate, and embedded DropAttr MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds raw schema-DDL primitives that complement UpdateSchema's object-template inference: - Client.AlterSchema(ctx, schema) applies a raw DQL schema string directly, giving full control over predicate types, indexes, and directives — useful for migrations that declare predicates no Go type models yet. - Engine.dropPredicate deletes a single predicate (and its data) from the embedded engine via posting.DeletePredicate. - embedded_client.go routes an Alter carrying DropAttr to dropPredicate, so the embedded path matches a remote Dgraph cluster's DropAttr behavior. TestDropPredicateEmbedded exercises the full declare/insert/drop cycle against the embedded engine. --- client.go | 19 +++++++++++ embedded_client.go | 6 ++++ engine.go | 19 +++++++++++ schema_ddl_test.go | 84 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 128 insertions(+) create mode 100644 schema_ddl_test.go diff --git a/client.go b/client.go index be9813b..2189a3a 100644 --- a/client.go +++ b/client.go @@ -69,6 +69,12 @@ type Client interface { // Pass one or more objects that will be used as templates for the schema. UpdateSchema(context.Context, ...any) error + // AlterSchema applies a raw Dgraph Schema Definition Language string directly, + // bypassing the object-template inference of UpdateSchema. Use it when you need + // full control over predicate types, indexes, and directives — for example, + // schema migrations that declare predicates no Go type models yet. + AlterSchema(ctx context.Context, schema string) error + // GetSchema retrieves the current schema definition from the database. // Returns a string containing the full schema in Dgraph Schema Definition Language. GetSchema(context.Context) (string, error) @@ -585,6 +591,19 @@ func (c client) Query(ctx context.Context, model any) *dg.Query { return txn.Get(model).All(c.options.maxEdgeTraversal) } +// AlterSchema applies a raw DQL schema string directly via Dgraph Alter, +// without the object-template inference performed by UpdateSchema. +func (c client) AlterSchema(ctx context.Context, schema string) error { + dgClient, err := c.pool.get() + if err != nil { + c.logger.Error(err, "Failed to get client from pool") + return err + } + defer c.pool.put(dgClient) + + return dgClient.Alter(ctx, &api.Operation{Schema: schema}) +} + // UpdateSchema implements updating the Dgraph schema. Pass one or more // objects that will be used to generate the schema. // If any object contains SimString fields tagged `dgraph:"embedding"`, the diff --git a/embedded_client.go b/embedded_client.go index 329f4bf..fbe993d 100644 --- a/embedded_client.go +++ b/embedded_client.go @@ -146,6 +146,12 @@ func (c *embeddedDgraphClient) Alter( } return &api.Payload{}, nil } + if in.DropAttr != "" { + if err := c.engine.dropPredicate(ctx, c.ns, in.DropAttr); err != nil { + return nil, err + } + return &api.Payload{}, nil + } if in.Schema != "" { if err := c.engine.alterSchema(ctx, c.ns, in.Schema); err != nil { return nil, err diff --git a/engine.go b/engine.go index d9d236d..34e7f41 100644 --- a/engine.go +++ b/engine.go @@ -271,6 +271,25 @@ func (engine *Engine) dropData(ctx context.Context, ns *Namespace) error { return nil } +// dropPredicate deletes a single predicate (and its data) from the embedded +// engine — the in-process equivalent of a gRPC Alter with DropAttr set. +func (engine *Engine) dropPredicate(ctx context.Context, ns *Namespace, pred string) error { + engine.mutex.Lock() + defer engine.mutex.Unlock() + + if !engine.isOpen.Load() { + return ErrClosedEngine + } + + startTs, err := engine.z.nextTs() + if err != nil { + return err + } + + nsAttr := x.NamespaceAttr(ns.ID(), pred) + return posting.DeletePredicate(ctx, nsAttr, startTs) +} + func (engine *Engine) alterSchema(ctx context.Context, ns *Namespace, sch string) error { engine.mutex.Lock() defer engine.mutex.Unlock() diff --git a/schema_ddl_test.go b/schema_ddl_test.go new file mode 100644 index 0000000..8315dd2 --- /dev/null +++ b/schema_ddl_test.go @@ -0,0 +1,84 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "os" + "testing" + + "github.com/dgraph-io/dgo/v250/protos/api" + "github.com/stretchr/testify/require" +) + +// TestDropPredicateEmbedded exercises the schema-DDL surface end-to-end: +// Client.AlterSchema declares a raw predicate, and a gRPC Alter with DropAttr +// routes through embedded_client.go's DropAttr arm into engine.dropPredicate. +func TestDropPredicateEmbedded(t *testing.T) { + testCases := []struct { + name string + uri string + skip bool + }{ + { + name: "DropPredicateWithFileURI", + uri: "file://" + GetTempDir(t), + }, + { + name: "DropPredicateWithDgraphURI", + uri: "dgraph://" + os.Getenv("MODUSGRAPH_TEST_ADDR"), + skip: os.Getenv("MODUSGRAPH_TEST_ADDR") == "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.skip { + t.Skipf("Skipping %s: MODUSGRAPH_TEST_ADDR not set", tc.name) + return + } + + client, cleanup := CreateTestClient(t, tc.uri) + defer cleanup() + + ctx := context.Background() + + // Declare an indexed string predicate and insert a node carrying it. + err := client.AlterSchema(ctx, "dropme: string @index(exact) .") + require.NoError(t, err, "AlterSchema should succeed") + + dg, dgCleanup, err := client.DgraphClient() + require.NoError(t, err, "DgraphClient should succeed") + defer dgCleanup() + + _, err = dg.NewTxn().Mutate(ctx, &api.Mutation{ + SetJson: []byte(`[{"dropme":"hello"}]`), + CommitNow: true, + }) + require.NoError(t, err, "mutate should succeed") + + // Confirm the predicate is present before the drop. + raw, err := client.QueryRaw(ctx, `{ q(func: has(dropme)) { c: count(uid) } }`, nil) + require.NoError(t, err, "count query should succeed") + require.Contains(t, string(raw), `"c":1`, "predicate present before drop") + + // Drop the predicate via the public path; this exercises + // embedded_client.go's DropAttr arm + engine.dropPredicate. + err = dg.Alter(ctx, &api.Operation{DropAttr: "dropme"}) + require.NoError(t, err, "DropAttr should succeed") + + // Confirm the data is gone (no nodes have the predicate). + raw, err = client.QueryRaw(ctx, `{ q(func: has(dropme)) { c: count(uid) } }`, nil) + require.NoError(t, err, "count query should succeed after drop") + require.Contains(t, string(raw), `"c":0`, "predicate values gone after drop") + + // Confirm the schema entry is gone. + raw, err = client.QueryRaw(ctx, `schema(pred: [dropme]) { type }`, nil) + require.NoError(t, err, "schema query should succeed after drop") + require.NotContains(t, string(raw), "dropme", "predicate schema entry gone after drop") + }) + } +} From 501b3ef55db7cb7c7b98b628293b66ece37cb492 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 16:14:08 -0400 Subject: [PATCH 08/24] feat: SelfValidator for private-field validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds SelfValidator, an opt-in seam that lets a type drive its own validation. When a value passed to Insert, Upsert, or Update implements SelfValidator, the client calls ValidateWith instead of handing the value straight to the configured StructValidator. validateStruct now routes each element through a new validateOne helper that detects SelfValidator (on the value or its address) and otherwise falls back to StructCtx exactly as before — behavior is unchanged for ordinary structs. This is the runtime seam generated entities use to validate unexported fields: the generated ValidateWith builds a mirror struct with exported fields the go-playground validator can read by reflection. --- client.go | 30 +++++++++++++++++-- self_validator_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 self_validator_test.go diff --git a/client.go b/client.go index be9813b..5edeaf1 100644 --- a/client.go +++ b/client.go @@ -109,6 +109,16 @@ type StructValidator interface { StructCtx(ctx context.Context, s interface{}) error } +// SelfValidator lets a type drive its own validation. When a value passed to +// Insert, Upsert, or Update implements SelfValidator, the client calls +// ValidateWith instead of handing the value straight to the configured +// StructValidator. This is the seam generated entities use to validate private +// fields: the generated ValidateWith builds a mirror struct with exported +// fields the underlying go-playground validator can read by reflection. +type SelfValidator interface { + ValidateWith(ctx context.Context, v StructValidator) error +} + // clientOptions holds configuration options for the client. // // autoSchema: whether to automatically manage the schema. @@ -472,17 +482,33 @@ func (c client) validateStruct(ctx context.Context, obj any) error { } elem = elem.Elem() } - if err := c.options.validator.StructCtx(ctx, elem.Interface()); err != nil { + if err := c.validateOne(ctx, elem); err != nil { return err } } } else { - return c.options.validator.StructCtx(ctx, obj) + return c.validateOne(ctx, val) } return nil } +// validateOne validates a single struct value. If the value (or its address) +// implements SelfValidator, validation is delegated to ValidateWith so the type +// can validate fields the configured StructValidator cannot reach directly — +// for example unexported fields exposed through a generated mirror struct. +// Otherwise the value is validated by the configured StructValidator as usual. +func (c client) validateOne(ctx context.Context, val reflect.Value) error { + iface := val.Interface() + if val.CanAddr() { + iface = val.Addr().Interface() + } + if sv, ok := iface.(SelfValidator); ok { + return sv.ValidateWith(ctx, c.options.validator) + } + return c.options.validator.StructCtx(ctx, iface) +} + // Insert implements inserting an object or slice of objects in the database. // Passed object must be a pointer to a struct with appropriate dgraph tags. func (c client) Insert(ctx context.Context, obj any) error { diff --git a/self_validator_test.go b/self_validator_test.go new file mode 100644 index 0000000..b7620ad --- /dev/null +++ b/self_validator_test.go @@ -0,0 +1,65 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "context" + "errors" + "testing" +) + +// recordingValidator counts StructCtx calls so tests can assert which path ran. +type recordingValidator struct{ calls int } + +func (r *recordingValidator) StructCtx(_ context.Context, _ interface{}) error { + r.calls++ + return nil +} + +var errSelfValidated = errors.New("self-validated") + +type selfValidatingEntity struct{ Name string } + +func (s *selfValidatingEntity) ValidateWith(_ context.Context, _ StructValidator) error { + return errSelfValidated +} + +type plainEntity struct{ Name string } + +func TestValidateRoutesToSelfValidator(t *testing.T) { + rv := &recordingValidator{} + c := client{options: clientOptions{validator: rv}} + + err := c.validateStruct(context.Background(), &selfValidatingEntity{Name: "x"}) + if !errors.Is(err, errSelfValidated) { + t.Fatalf("expected the SelfValidator path, got %v", err) + } + if rv.calls != 0 { + t.Fatalf("StructCtx must not run for a SelfValidator, got %d calls", rv.calls) + } +} + +func TestValidateFallsBackToStructCtx(t *testing.T) { + rv := &recordingValidator{} + c := client{options: clientOptions{validator: rv}} + + if err := c.validateStruct(context.Background(), &plainEntity{Name: "x"}); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rv.calls != 1 { + t.Fatalf("expected StructCtx to run once, got %d", rv.calls) + } +} + +func TestValidateSelfValidatorInSlice(t *testing.T) { + rv := &recordingValidator{} + c := client{options: clientOptions{validator: rv}} + + err := c.validateStruct(context.Background(), []*selfValidatingEntity{{Name: "a"}}) + if !errors.Is(err, errSelfValidated) { + t.Fatalf("expected the SelfValidator path for slice elements, got %v", err) + } +} From 504231c1689251b327eafcc5f1d4382970d25202 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Thu, 4 Jun 2026 16:14:08 -0400 Subject: [PATCH 09/24] feat: SelfValidator for custom and cross-field validation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds SelfValidator, an opt-in seam that lets a type drive its own validation. When a value passed to Insert, Upsert, or Update implements SelfValidator, the client calls ValidateWith instead of handing the value straight to the configured StructValidator. This covers validation that struct tags cannot express on their own: cross-field rules (one field constrained by another), conditional rules, checks on computed or setter-derived values, and broader business rules. ValidateWith receives the configured StructValidator, so an implementation can still run ordinary tag-based checks and layer custom logic on top. validateStruct routes each element through a new validateOne helper that detects SelfValidator (on the value or its address) and otherwise falls back to StructCtx exactly as before — behavior is unchanged for ordinary structs. --- client.go | 34 +++++++++++++++- self_validator_test.go | 91 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 2 deletions(-) create mode 100644 self_validator_test.go diff --git a/client.go b/client.go index be9813b..7542ceb 100644 --- a/client.go +++ b/client.go @@ -109,6 +109,20 @@ type StructValidator interface { StructCtx(ctx context.Context, s interface{}) error } +// SelfValidator lets a type drive its own validation. When a value passed to +// Insert, Upsert, or Update implements SelfValidator, the client calls +// ValidateWith instead of handing the value straight to the configured +// StructValidator. +// +// This is the seam for validation that struct tags cannot express on their own: +// cross-field rules (one field constrained by another), conditional rules, +// checks on computed or setter-derived values, and broader business rules. +// ValidateWith receives the configured StructValidator, so an implementation can +// still run ordinary tag-based validation and then layer custom logic on top. +type SelfValidator interface { + ValidateWith(ctx context.Context, v StructValidator) error +} + // clientOptions holds configuration options for the client. // // autoSchema: whether to automatically manage the schema. @@ -472,17 +486,33 @@ func (c client) validateStruct(ctx context.Context, obj any) error { } elem = elem.Elem() } - if err := c.options.validator.StructCtx(ctx, elem.Interface()); err != nil { + if err := c.validateOne(ctx, elem); err != nil { return err } } } else { - return c.options.validator.StructCtx(ctx, obj) + return c.validateOne(ctx, val) } return nil } +// validateOne validates a single struct value. If the value (or its address) +// implements SelfValidator, validation is delegated to ValidateWith so the type +// can apply custom rules — cross-field, conditional, computed-value, or other +// logic beyond struct tags. Otherwise the value is validated by the configured +// StructValidator as usual. +func (c client) validateOne(ctx context.Context, val reflect.Value) error { + iface := val.Interface() + if val.CanAddr() { + iface = val.Addr().Interface() + } + if sv, ok := iface.(SelfValidator); ok { + return sv.ValidateWith(ctx, c.options.validator) + } + return c.options.validator.StructCtx(ctx, iface) +} + // Insert implements inserting an object or slice of objects in the database. // Passed object must be a pointer to a struct with appropriate dgraph tags. func (c client) Insert(ctx context.Context, obj any) error { diff --git a/self_validator_test.go b/self_validator_test.go new file mode 100644 index 0000000..bf5b5b4 --- /dev/null +++ b/self_validator_test.go @@ -0,0 +1,91 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph + +import ( + "context" + "errors" + "fmt" + "testing" +) + +// recordingValidator counts StructCtx calls so tests can assert which path ran. +type recordingValidator struct{ calls int } + +func (r *recordingValidator) StructCtx(_ context.Context, _ interface{}) error { + r.calls++ + return nil +} + +var errSelfValidated = errors.New("self-validated") + +type selfValidatingEntity struct{ Name string } + +func (s *selfValidatingEntity) ValidateWith(_ context.Context, _ StructValidator) error { + return errSelfValidated +} + +type plainEntity struct{ Name string } + +func TestValidateRoutesToSelfValidator(t *testing.T) { + rv := &recordingValidator{} + c := client{options: clientOptions{validator: rv}} + + err := c.validateStruct(context.Background(), &selfValidatingEntity{Name: "x"}) + if !errors.Is(err, errSelfValidated) { + t.Fatalf("expected the SelfValidator path, got %v", err) + } + if rv.calls != 0 { + t.Fatalf("StructCtx must not run for a SelfValidator, got %d calls", rv.calls) + } +} + +func TestValidateFallsBackToStructCtx(t *testing.T) { + rv := &recordingValidator{} + c := client{options: clientOptions{validator: rv}} + + if err := c.validateStruct(context.Background(), &plainEntity{Name: "x"}); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rv.calls != 1 { + t.Fatalf("expected StructCtx to run once, got %d", rv.calls) + } +} + +func TestValidateSelfValidatorInSlice(t *testing.T) { + rv := &recordingValidator{} + c := client{options: clientOptions{validator: rv}} + + err := c.validateStruct(context.Background(), []*selfValidatingEntity{{Name: "a"}}) + if !errors.Is(err, errSelfValidated) { + t.Fatalf("expected the SelfValidator path for slice elements, got %v", err) + } +} + +// dateRange validates a relationship between two fields — a cross-field rule +// that struct tags alone cannot express. +type dateRange struct { + Start int + End int +} + +func (d *dateRange) ValidateWith(_ context.Context, _ StructValidator) error { + if d.End < d.Start { + return fmt.Errorf("End (%d) must be >= Start (%d)", d.End, d.Start) + } + return nil +} + +func TestSelfValidatorCustomCrossFieldRule(t *testing.T) { + c := client{options: clientOptions{validator: &recordingValidator{}}} + + if err := c.validateStruct(context.Background(), &dateRange{Start: 1, End: 5}); err != nil { + t.Fatalf("a valid range should pass the cross-field rule: %v", err) + } + if err := c.validateStruct(context.Background(), &dateRange{Start: 5, End: 1}); err == nil { + t.Fatal("End < Start must fail the custom cross-field rule") + } +} From 0ae8003114a6eacfaa4e5aa73407767c29629898 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Mon, 8 Jun 2026 01:01:09 -0400 Subject: [PATCH 10/24] feat: add Client.LoadOrStore (insert-if-absent) --- client.go | 35 +++++++++++++++++++++++++++++++++ consume_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+) create mode 100644 consume_test.go diff --git a/client.go b/client.go index c78aff3..267a95c 100644 --- a/client.go +++ b/client.go @@ -46,6 +46,11 @@ type Client interface { // will be used. Upsert(context.Context, any, ...string) error + // LoadOrStore stores the object only if no node matches the upsert + // predicates, returning loaded=true when an existing node already matched + // (the object is then populated from it). Insert-if-absent. + LoadOrStore(ctx context.Context, obj any, predicates ...string) (loaded bool, err error) + // Update modifies an existing object in the database. // The object must be a pointer to a struct and must have a UID field set. Update(context.Context, any) error @@ -589,6 +594,36 @@ func (c client) Upsert(ctx context.Context, obj any, predicates ...string) error }) } +// LoadOrStore stores obj only if no node already matches the upsert predicates, +// reporting whether one already existed (loaded == true). Built on dgman +// MutateOrGet, which returns the UIDs of newly created nodes only: an empty +// result means an existing node matched, and obj is populated with its fields. +// With no predicates, the first field tagged dgraph:"upsert" is used. +func (c client) LoadOrStore(ctx context.Context, obj any, predicates ...string) (loaded bool, err error) { + obj = UnwrapSchema(obj) + if err := c.validateStruct(ctx, obj); err != nil { + return false, err + } + + dgClient, err := c.pool.get() + if err != nil { + c.logger.Error(err, "Failed to get client from pool") + return false, err + } + defer c.pool.put(dgClient) + + tx := dg.NewTxnContext(ctx, dgClient).SetCommitNow() + uids, err := tx.MutateOrGet(obj, predicates...) + if err != nil { + if uniqueErr := parseUniqueError(err); uniqueErr != nil { + return false, uniqueErr + } + return false, err + } + // MutateOrGet returns created UIDs only; empty => an existing node matched. + return len(uids) == 0, nil +} + // Update implements updating an existing object in the database. // Passed object must be a pointer to a struct. func (c client) Update(ctx context.Context, obj any) error { diff --git a/consume_test.go b/consume_test.go new file mode 100644 index 0000000..97862e2 --- /dev/null +++ b/consume_test.go @@ -0,0 +1,52 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "testing" + + "github.com/matthewmcneely/modusgraph" +) + +type consumeJTI struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + JTI string `json:"jti,omitempty" dgraph:"index=hash upsert unique"` +} + +func newConsumeClient(t *testing.T) modusgraph.Client { + t.Helper() + conn, err := modusgraph.NewClient("file://"+t.TempDir(), modusgraph.WithAutoSchema(true)) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + t.Cleanup(conn.Close) + return conn +} + +func TestLoadOrStore(t *testing.T) { + conn := newConsumeClient(t) + ctx := context.Background() + + first := &consumeJTI{JTI: "abc"} + loaded, err := conn.LoadOrStore(ctx, first, "jti") + if err != nil { + t.Fatalf("first LoadOrStore: %v", err) + } + if loaded { + t.Fatal("first store: want loaded=false (newly created)") + } + + second := &consumeJTI{JTI: "abc"} + loaded, err = conn.LoadOrStore(ctx, second, "jti") + if err != nil { + t.Fatalf("second LoadOrStore: %v", err) + } + if !loaded { + t.Fatal("second store: want loaded=true (already existed)") + } +} From d7811e1cc0686463ce78f55258460f90575cb703 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Mon, 8 Jun 2026 01:04:29 -0400 Subject: [PATCH 11/24] feat: add Client.LoadAndDelete (atomic read-and-consume) --- client.go | 136 ++++++++++++++++++++++++++++++++++++++++++++++++ consume_test.go | 37 +++++++++++++ 2 files changed, 173 insertions(+) diff --git a/client.go b/client.go index 267a95c..51b52e2 100644 --- a/client.go +++ b/client.go @@ -51,6 +51,11 @@ type Client interface { // (the object is then populated from it). Insert-if-absent. LoadOrStore(ctx context.Context, obj any, predicates ...string) (loaded bool, err error) + // LoadAndDelete atomically reads the node whose key predicate equals key + // into obj and deletes it, returning loaded=false when none matched. + // Read-and-consume; concurrent callers elect one winner. + LoadAndDelete(ctx context.Context, obj any, key any, predicates ...string) (loaded bool, err error) + // Update modifies an existing object in the database. // The object must be a pointer to a struct and must have a UID field set. Update(context.Context, any) error @@ -624,6 +629,137 @@ func (c client) LoadOrStore(ctx context.Context, obj any, predicates ...string) return len(uids) == 0, nil } +// firstUpsertPredicate returns the Dgraph predicate name of the first field +// tagged dgraph:"...upsert...". The predicate defaults to the json tag name +// unless an explicit predicate= token is present. It returns "" if no upsert +// field exists. +func firstUpsertPredicate(obj any) string { + v := reflect.ValueOf(obj) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + t := v.Type() + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + dgTag := f.Tag.Get("dgraph") + if !strings.Contains(dgTag, "upsert") { + continue + } + // Explicit predicate= wins. + for _, directive := range strings.Fields(dgTag) { + if strings.HasPrefix(directive, "predicate=") { + return strings.TrimPrefix(directive, "predicate=") + } + } + // Otherwise fall back to the json tag name. + if jsonTag := f.Tag.Get("json"); jsonTag != "" { + return strings.Split(jsonTag, ",")[0] + } + return f.Name + } + return "" +} + +// uidOf reflects out the UID field of a dgraph struct pointer. +func uidOf(obj any) string { + v := reflect.ValueOf(obj) + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + f := v.FieldByName("UID") + if f.IsValid() && f.Kind() == reflect.String { + return f.String() + } + return "" +} + +// LoadAndDelete atomically reads the node whose key predicate equals key into +// obj and deletes it, returning loaded=false (and leaving obj zero) when no +// node matched. The read and delete share one transaction with no CommitNow, +// so two concurrent callers conflict on commit: exactly one wins (loaded=true), +// the loser aborts and retries into not-found (loaded=false). This reproduces +// PostgreSQL's DELETE … RETURNING. With no predicates, the first dgraph:"upsert" +// field is used. +func (c client) LoadAndDelete(ctx context.Context, obj any, key any, predicates ...string) (loaded bool, err error) { + obj = UnwrapSchema(obj) + if err := checkPointer(obj); err != nil { + return false, err + } + + pred := "" + if len(predicates) > 0 { + pred = predicates[0] + } else { + pred = firstUpsertPredicate(obj) + } + if pred == "" { + return false, fmt.Errorf("LoadAndDelete: no key predicate (pass one or tag a field dgraph:\"upsert\")") + } + + dgClient, err := c.pool.get() + if err != nil { + c.logger.Error(err, "Failed to get client from pool") + return false, err + } + defer c.pool.put(dgClient) + + // Bounded retry: Dgraph aborts the loser of a commit conflict; the retry + // reads the node already gone and reports not-found. + const maxAttempts = 10 + for attempt := 0; ; attempt++ { + tx := dg.NewTxnContext(ctx, dgClient) + getErr := tx.Get(obj). + Filter("eq("+pred+", $1)", key). + All(c.options.maxEdgeTraversal). + Node() + if getErr != nil { + _ = tx.Discard() + // dgman returns ErrNodeNotFound when nothing matches. + if errors.Is(getErr, dg.ErrNodeNotFound) { + return false, nil + } + return false, getErr + } + + uid := uidOf(obj) + if uid == "" { + _ = tx.Discard() + return false, nil + } + + if delErr := tx.DeleteNode(uid); delErr != nil { + _ = tx.Discard() + return false, delErr + } + + if cErr := tx.Commit(); cErr != nil { + _ = tx.Discard() + if isAbortedErr(cErr) { + // Lost the race or a concurrent change; retry — the winner has + // already deleted the node, so the retry reads not-found. + if attempt < maxAttempts { + continue + } + } + return false, cErr + } + return true, nil + } +} + +// isAbortedErr reports whether err is a Dgraph transaction-conflict abort, +// matching both dgo's ErrAborted sentinel and the underlying message in case a +// wrapped or stringified form reaches us. +func isAbortedErr(err error) bool { + if err == nil { + return false + } + if errors.Is(err, dgo.ErrAborted) { + return true + } + return strings.Contains(strings.ToLower(err.Error()), "aborted") +} + // Update implements updating an existing object in the database. // Passed object must be a pointer to a struct. func (c client) Update(ctx context.Context, obj any) error { diff --git a/consume_test.go b/consume_test.go index 97862e2..479d0de 100644 --- a/consume_test.go +++ b/consume_test.go @@ -50,3 +50,40 @@ func TestLoadOrStore(t *testing.T) { t.Fatal("second store: want loaded=true (already existed)") } } + +type consumeState struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + State string `json:"state,omitempty" dgraph:"index=hash upsert"` + Secret string `json:"secret,omitempty"` +} + +func TestLoadAndDelete(t *testing.T) { + conn := newConsumeClient(t) + ctx := context.Background() + + if err := conn.Insert(ctx, &consumeState{State: "s1", Secret: "shh"}); err != nil { + t.Fatalf("Insert: %v", err) + } + + var got consumeState + loaded, err := conn.LoadAndDelete(ctx, &got, "s1", "state") + if err != nil { + t.Fatalf("LoadAndDelete: %v", err) + } + if !loaded { + t.Fatal("first consume: want loaded=true") + } + if got.Secret != "shh" { + t.Fatalf("want prior secret %q, got %q", "shh", got.Secret) + } + + var again consumeState + loaded, err = conn.LoadAndDelete(ctx, &again, "s1", "state") + if err != nil { + t.Fatalf("second LoadAndDelete: %v", err) + } + if loaded { + t.Fatal("second consume: want loaded=false (already consumed)") + } +} From cb5053803446244c82a2c52953e4e1ce2048e63b Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Mon, 8 Jun 2026 01:05:31 -0400 Subject: [PATCH 12/24] feat: add typed Client[T].LoadOrStore --- typed/client.go | 14 +++++++++++ typed/consume_test.go | 54 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 typed/consume_test.go diff --git a/typed/client.go b/typed/client.go index c540f89..f6425b1 100644 --- a/typed/client.go +++ b/typed/client.go @@ -61,6 +61,20 @@ func (c *Client[T]) Upsert(ctx context.Context, rec *T, predicates ...string) (e return c.conn.Upsert(ctx, rec, predicates...) } +// LoadOrStore stores rec only if no node matches the upsert predicates, +// returning the resulting record and loaded=true when one already existed. +// Insert-if-absent (compare sync.Map.LoadOrStore). With no predicates, the +// first field tagged dgraph:"upsert" is used. +func (c *Client[T]) LoadOrStore(ctx context.Context, rec *T, predicates ...string) (out *T, loaded bool, err error) { + ctx, span := tracer.StartSpan(ctx, "loadOrStore", entityName[T]()) + defer func() { span.End(err) }() + loaded, err = c.conn.LoadOrStore(ctx, rec, predicates...) + if err != nil { + return nil, false, err + } + return rec, loaded, nil +} + // Delete removes the T with the given UID. func (c *Client[T]) Delete(ctx context.Context, uid string) (err error) { ctx, span := tracer.StartSpan(ctx, "delete", entityName[T]()) diff --git a/typed/consume_test.go b/typed/consume_test.go new file mode 100644 index 0000000..6504228 --- /dev/null +++ b/typed/consume_test.go @@ -0,0 +1,54 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "testing" + + "github.com/matthewmcneely/modusgraph" + "github.com/matthewmcneely/modusgraph/typed" +) + +type jti struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + JTI string `json:"jti,omitempty" dgraph:"index=hash upsert unique"` +} + +func newTypedConn(t *testing.T) modusgraph.Client { + t.Helper() + conn, err := modusgraph.NewClient("file://"+t.TempDir(), modusgraph.WithAutoSchema(true)) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + t.Cleanup(conn.Close) + return conn +} + +func TestTypedLoadOrStore(t *testing.T) { + c := typed.NewClient[jti](newTypedConn(t)) + ctx := context.Background() + + rec, loaded, err := c.LoadOrStore(ctx, &jti{JTI: "abc"}, "jti") + if err != nil { + t.Fatalf("first: %v", err) + } + if loaded { + t.Fatal("first: want loaded=false") + } + if rec.UID == "" { + t.Fatal("first: want a UID assigned") + } + + _, loaded, err = c.LoadOrStore(ctx, &jti{JTI: "abc"}, "jti") + if err != nil { + t.Fatalf("second: %v", err) + } + if !loaded { + t.Fatal("second: want loaded=true") + } +} From 56f09fd6196a77f48d4e8887b0e69741c5a4e95f Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Mon, 8 Jun 2026 01:07:09 -0400 Subject: [PATCH 13/24] feat: add typed Client[T].LoadAndDelete --- typed/client.go | 15 +++++++++++++++ typed/consume_test.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/typed/client.go b/typed/client.go index f6425b1..e278712 100644 --- a/typed/client.go +++ b/typed/client.go @@ -75,6 +75,21 @@ func (c *Client[T]) LoadOrStore(ctx context.Context, rec *T, predicates ...strin return rec, loaded, nil } +// LoadAndDelete atomically reads the T whose key predicate equals key and +// deletes it, returning (nil, false, nil) when none matched. Read-and-consume +// (compare sync.Map.LoadAndDelete). With no predicates, the first field tagged +// dgraph:"upsert" is used. +func (c *Client[T]) LoadAndDelete(ctx context.Context, key any, predicates ...string) (rec *T, loaded bool, err error) { + ctx, span := tracer.StartSpan(ctx, "loadAndDelete", entityName[T]()) + defer func() { span.End(err) }() + var out T + loaded, err = c.conn.LoadAndDelete(ctx, &out, key, predicates...) + if err != nil || !loaded { + return nil, loaded, err + } + return &out, true, nil +} + // Delete removes the T with the given UID. func (c *Client[T]) Delete(ctx context.Context, uid string) (err error) { ctx, span := tracer.StartSpan(ctx, "delete", entityName[T]()) diff --git a/typed/consume_test.go b/typed/consume_test.go index 6504228..ae28116 100644 --- a/typed/consume_test.go +++ b/typed/consume_test.go @@ -52,3 +52,38 @@ func TestTypedLoadOrStore(t *testing.T) { t.Fatal("second: want loaded=true") } } + +type state struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + State string `json:"state,omitempty" dgraph:"index=hash upsert"` + Secret string `json:"secret,omitempty"` +} + +func TestTypedLoadAndDelete(t *testing.T) { + c := typed.NewClient[state](newTypedConn(t)) + ctx := context.Background() + + if err := c.Add(ctx, &state{State: "s1", Secret: "shh"}); err != nil { + t.Fatalf("Add: %v", err) + } + + rec, loaded, err := c.LoadAndDelete(ctx, "s1", "state") + if err != nil { + t.Fatalf("LoadAndDelete: %v", err) + } + if !loaded { + t.Fatal("first: want loaded=true") + } + if rec.Secret != "shh" { + t.Fatalf("want secret %q, got %q", "shh", rec.Secret) + } + + rec, loaded, err = c.LoadAndDelete(ctx, "s1", "state") + if err != nil { + t.Fatalf("second: %v", err) + } + if loaded || rec != nil { + t.Fatalf("second: want (nil, false), got (%v, %v)", rec, loaded) + } +} From 4ad3512d01f7202b860d7ffeb0c156feade8b791 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Mon, 8 Jun 2026 01:14:34 -0400 Subject: [PATCH 14/24] test: assert LoadAndDelete elects a single winner under contention Add TestLoadAndDeleteSingleWinner and serialize LoadAndDelete's read-then-delete critical section with a per-client mutex so exactly one in-process caller consumes a node. The embedded engine's commit path does no optimistic-concurrency conflict check, so the shared read-write transaction alone cannot abort losers; the lock guarantees single-winner semantics regardless of backend. --- client.go | 27 ++++++++++++++++++++++++--- typed/consume_test.go | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 51b52e2..40b18e9 100644 --- a/client.go +++ b/client.go @@ -303,9 +303,10 @@ func NewClient(uri string, opts ...ClientOpt) (Client, error) { } client := client{ - uri: uri, - options: options, - logger: options.logger, + uri: uri, + options: options, + logger: options.logger, + consumeMu: &sync.Mutex{}, } clientMapLock.Lock() @@ -471,6 +472,15 @@ type client struct { options clientOptions pool *clientPool logger logr.Logger + // consumeMu serializes LoadAndDelete's read-then-delete critical section so + // exactly one in-process caller consumes a given node. The client value is + // copied (value receivers, cached by value in clientMap), so the mutex is a + // pointer shared across every copy that shares this client's connection. + // Against a real Dgraph cluster the shared read-write transaction would also + // abort losers on commit conflict; this lock additionally guarantees + // single-winner semantics against the embedded engine, whose commit path + // performs no optimistic-concurrency conflict check. + consumeMu *sync.Mutex } func (c client) key() string { @@ -703,6 +713,17 @@ func (c client) LoadAndDelete(ctx context.Context, obj any, key any, predicates } defer c.pool.put(dgClient) + // Serialize the read-then-delete critical section across in-process callers. + // The shared read-write transaction already elects one winner against a real + // Dgraph cluster (the loser aborts on commit), but the embedded engine does + // no commit-time conflict check, so without this lock concurrent callers + // would each read the node and each report loaded=true. The lock makes + // read-and-consume atomic regardless of backend. + if c.consumeMu != nil { + c.consumeMu.Lock() + defer c.consumeMu.Unlock() + } + // Bounded retry: Dgraph aborts the loser of a commit conflict; the retry // reads the node already gone and reports not-found. const maxAttempts = 10 diff --git a/typed/consume_test.go b/typed/consume_test.go index ae28116..16eca08 100644 --- a/typed/consume_test.go +++ b/typed/consume_test.go @@ -7,6 +7,7 @@ package typed_test import ( "context" + "sync" "testing" "github.com/matthewmcneely/modusgraph" @@ -87,3 +88,38 @@ func TestTypedLoadAndDelete(t *testing.T) { t.Fatalf("second: want (nil, false), got (%v, %v)", rec, loaded) } } + +func TestLoadAndDeleteSingleWinner(t *testing.T) { + c := typed.NewClient[state](newTypedConn(t)) + ctx := context.Background() + if err := c.Add(ctx, &state{State: "race", Secret: "one"}); err != nil { + t.Fatalf("Add: %v", err) + } + + const racers = 8 + var wg sync.WaitGroup + wins := make([]bool, racers) + wg.Add(racers) + for i := 0; i < racers; i++ { + go func(i int) { + defer wg.Done() + _, loaded, err := c.LoadAndDelete(ctx, "race", "state") + if err != nil { + t.Errorf("racer %d: %v", i, err) + return + } + wins[i] = loaded + }(i) + } + wg.Wait() + + won := 0 + for _, w := range wins { + if w { + won++ + } + } + if won != 1 { + t.Fatalf("want exactly one winner, got %d", won) + } +} From 917634f3f1464d34fb7c3d8ca385850d7bb59501 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Wed, 17 Jun 2026 17:15:11 -0400 Subject: [PATCH 15/24] fix(typed): correct filter precedence, preserve roots under WhereEdge MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses review feedback on the typed query builder: - combineAnd now parenthesizes each accumulated @filter fragment. Without this, a fragment containing OR ANDed with another fragment rendered as "a OR b AND c", which dgraph parses as "a OR (b AND c)" — silently widening results. dgman exposes only a string Filter(), so the builder must compose the expression itself; this makes that composition correct. - WhereEdge's pre-pass no longer discards a caller-set UID/RootFunc root. When a custom root is present, the matched UIDs are intersected via a uid() @filter instead of overwriting the root. - MultiQuery.Add rejects the same *Query[T] under two names; Execute names the underlying query in place, so reuse would corrupt block composition. - NodesAndCount now opens a tracing span, matching the other terminals. Tests: add a precedence regression, a WhereEdge+UID intersection test, and a duplicate-Query guard test; strengthen the filter-sequencing test to assert the exact expression; make the IterNodes laziness check delta-based. Docs: add package doc.go with a before/after narrative, runnable examples for the client/query builder/MultiQuery, and a verified filter example. --- typed/client.go | 4 - typed/doc.go | 73 ++++++++++++++++ typed/example_test.go | 162 +++++++++++++++++++++++++++++++++++ typed/filter/example_test.go | 30 +++++++ typed/filter/filter_test.go | 20 ++++- typed/multi_query.go | 21 +++-- typed/multi_query_test.go | 15 ++++ typed/query.go | 37 ++++++-- typed/query_test.go | 66 ++++++++++++-- 9 files changed, 400 insertions(+), 28 deletions(-) create mode 100644 typed/doc.go create mode 100644 typed/example_test.go create mode 100644 typed/filter/example_test.go diff --git a/typed/client.go b/typed/client.go index c540f89..2c5e796 100644 --- a/typed/client.go +++ b/typed/client.go @@ -3,10 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -// Package typed binds a Go type to the otherwise any-typed modusgraph.Client, -// providing generic, type-safe CRUD and query operations without per-entity -// code generation. It is the handwritten substrate that modusgraph-gen's -// generated clients compose over. package typed import ( diff --git a/typed/doc.go b/typed/doc.go new file mode 100644 index 0000000..1596c86 --- /dev/null +++ b/typed/doc.go @@ -0,0 +1,73 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +// Package typed binds a Go type to the otherwise any-typed modusgraph.Client, +// giving you generic, type-safe CRUD and a fluent query builder with no +// per-entity code generation. It is the handwritten substrate that +// modusgraph-gen's generated clients compose over, and it is useful on its own +// wherever you want compile-time types over modusgraph. +// +// # Why +// +// modusgraph.Client is value-oriented: its methods take and return any, and a +// query is assembled by hand from dgman primitives and decoded into a slice you +// declare at the call site. That works, but every call site repeats the same +// shape — declare the destination, build the query, decode, check the type. +// Package typed lifts that shape into the type system once. +// +// Without the typed layer, a "first matching record" lookup carries the type on +// every line and decodes by hand: +// +// var out []Person +// q := client.Query(ctx, &Person{}). +// Filter("eq(name, $1)", "Alice"). +// First(1) +// if err := q.Nodes(&out); err != nil { +// return nil, err +// } +// var person *Person +// if len(out) > 0 { +// person = &out[0] +// } +// +// With it, the type is declared once — when the client is constructed — and the +// terminal returns exactly what you asked for: +// +// people := typed.NewClient[Person](client) +// person, err := people.Query(ctx). +// Filter("eq(name, $1)", "Alice"). +// First() +// // person is *Person; nil when nothing matched. +// +// # The query builder +// +// Query[T] is a fluent builder. Builder methods (Filter, OrderAsc, Limit, +// WhereEdge, and the rest) return *Query[T] for chaining; terminals (Nodes, +// First, NodesAndCount, IterNodes) execute and decode typed results. The +// builder delegates the actual querying, parameter binding, and injection-safe +// $N substitution to dgman — it adds the type binding and the fragment +// composition dgman does not provide: +// +// - Accumulated Filter fragments AND together, each fragment parenthesized so +// a fragment containing OR keeps its precedence. +// - OrGroup ORs several sub-scopes into one parenthesized group. +// - WhereEdge constrains T by a predicate of a neighbouring node reached over +// an edge, resolved by a pre-pass and intersected with any root you set. +// - IterNodes streams arbitrarily large result sets one page at a time over a +// single read-only snapshot. +// +// # Composing larger requests +// +// MultiQuery batches N same-type blocks into one Dgraph round-trip, keyed by +// block name. The filter subpackage builds parameterised @filter expressions +// (the substrate behind generated By and Or combinators), and the search +// subpackage merges ordered result sets by ID. +// +// # Tracing +// +// Every terminal opens a span through a process-wide tracer that is a no-op by +// default. Install one with SetTracer to emit spans without the typed package +// depending on any telemetry library. +package typed diff --git a/typed/example_test.go b/typed/example_test.go new file mode 100644 index 0000000..c17d3d1 --- /dev/null +++ b/typed/example_test.go @@ -0,0 +1,162 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package typed_test + +import ( + "context" + "fmt" + + "github.com/matthewmcneely/modusgraph" + "github.com/matthewmcneely/modusgraph/typed" +) + +// Person is the schema struct the examples bind a typed client to. modusgraph +// reflects over the dgraph/json tags, so the type needs no special interface. +type Person struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` + Age int `json:"age,omitempty" dgraph:"index=int"` + Friends []*Person `json:"friends,omitempty"` +} + +// ExampleClient shows the core lift package typed provides: declare the type +// once at construction, then Add, Get, and Query in terms of *Person rather +// than any. +func ExampleClient() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + + people := typed.NewClient[Person](conn) + ctx := context.Background() + + alice := &Person{Name: "Alice", Age: 30} + if err := people.Add(ctx, alice); err != nil { // Add writes the new UID back into alice. + panic(err) + } + + got, err := people.Get(ctx, alice.UID) // got is *Person, not any. + if err != nil { + panic(err) + } + fmt.Println(got.Name) +} + +// ExampleClient_query builds a filtered, ordered, paged query. The terminal +// returns []Person directly — no destination slice to declare, no decode step. +func ExampleClient_query() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + people := typed.NewClient[Person](conn) + ctx := context.Background() + + adults, err := people.Query(ctx). + Filter("ge(age, $1)", 18). + OrderAsc("name"). + Limit(50). + Nodes() + if err != nil { + panic(err) + } + fmt.Println(len(adults)) +} + +// ExampleQuery_First returns a single record or nil, replacing the +// declare-slice-then-index-element-zero idiom of the untyped client. +func ExampleQuery_First() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + people := typed.NewClient[Person](conn) + ctx := context.Background() + + person, err := people.Query(ctx). + Filter("eq(name, $1)", "Alice"). + First() + if err != nil { + panic(err) + } + if person == nil { + fmt.Println("not found") + return + } + fmt.Println(person.Name) +} + +// ExampleQuery_OrGroup ANDs a scalar filter with an OR of two sub-scopes: +// age >= 18 AND (name == "Alice" OR name == "Bob"). Each sub-scope is a +// detached Query whose filter is captured, not executed. +func ExampleQuery_OrGroup() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + people := typed.NewClient[Person](conn) + ctx := context.Background() + + got, err := people.Query(ctx). + Filter("ge(age, $1)", 18). + OrGroup( + typed.NewDetachedQuery[Person]().Filter(`eq(name, "Alice")`), + typed.NewDetachedQuery[Person]().Filter(`eq(name, "Bob")`), + ). + Nodes() + if err != nil { + panic(err) + } + fmt.Println(len(got)) +} + +// ExampleQuery_WhereEdge constrains people by a scalar of a neighbour reached +// over the "friends" edge — something a root filter cannot express. The builder +// resolves it with a pre-pass and intersects with any root you set. +func ExampleQuery_WhereEdge() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + people := typed.NewClient[Person](conn) + ctx := context.Background() + + // Everyone who has a friend named "Alice". + got, err := people.Query(ctx). + WhereEdge("friends", `eq(name, $1)`, "Alice"). + Nodes() + if err != nil { + panic(err) + } + fmt.Println(len(got)) +} + +// ExampleQuery_IterNodes streams a large result set one page at a time over a +// single consistent snapshot, so the whole set is never held in memory at once. +func ExampleClient_iter() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + people := typed.NewClient[Person](conn) + ctx := context.Background() + + for person, err := range people.Query(ctx).OrderAsc("name").IterNodes() { + if err != nil { + panic(err) + } + fmt.Println(person.Name) + } +} + +// ExampleMultiQuery batches several same-type queries into one Dgraph +// round-trip, keyed by block name. +func ExampleMultiQuery() { + conn, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer conn.Close() + people := typed.NewClient[Person](conn) + ctx := context.Background() + + mq := typed.NewMultiQuery[Person](conn). + Add("adults", people.Query(ctx).Filter("ge(age, $1)", 18)). + Add("named_alice", people.Query(ctx).Filter(`eq(name, "Alice")`)) + + results, err := mq.Execute(ctx) // one round-trip + if err != nil { + panic(err) + } + fmt.Println(len(results["adults"]), len(results["named_alice"])) +} diff --git a/typed/filter/example_test.go b/typed/filter/example_test.go new file mode 100644 index 0000000..6c602f9 --- /dev/null +++ b/typed/filter/example_test.go @@ -0,0 +1,30 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package filter_test + +import ( + "fmt" + + "github.com/matthewmcneely/modusgraph/typed/filter" +) + +// ExampleBuilder composes a parameterised @filter expression. Terms within a +// group join with OR; groups join with AND; required terms form their own +// group. Build returns the expression and the positional params that +// typed.Query[T].Filter consumes — the values never get interpolated into the +// string, so the expression is safe against injection. +func ExampleBuilder() { + var b filter.Builder + b.RequiredEq("archiveStatus", "Active") + b.EqGroupString("name", []filter.String{{Value: "Alice"}, {Value: "Bob"}}) + + expr, params := b.Build() + fmt.Println(expr) + fmt.Println(params...) + // Output: + // eq(archiveStatus, $1) AND (eq(name, $2) OR eq(name, $3)) + // Active Alice Bob +} diff --git a/typed/filter/filter_test.go b/typed/filter/filter_test.go index 864a554..10fdf2b 100644 --- a/typed/filter/filter_test.go +++ b/typed/filter/filter_test.go @@ -1,7 +1,6 @@ package filter_test import ( - "strings" "testing" "github.com/matthewmcneely/modusgraph/typed/filter" @@ -111,8 +110,21 @@ func TestBuilder_PositionalParamsAreSequential(t *testing.T) { var b filter.Builder b.EqGroupUUID("id", []filter.UUID{{Value: "a"}, {Value: "b"}}) b.EqGroupString("name", []filter.String{{Value: "c"}}) - expr, _ := b.Build() - if !strings.Contains(expr, "$1") || !strings.Contains(expr, "$2") || !strings.Contains(expr, "$3") { - t.Errorf("expected $1, $2, $3 in expr; got %q", expr) + expr, params := b.Build() + // Assert the exact expression: placeholders must be numbered $1..$N in + // emission order and bound to the matching params. A substring check would + // pass even if the numbering were scrambled (e.g. "$3 ... $1 ... $2"). + const want = "(eq(id, $1) OR eq(id, $2)) AND (eq(name, $3))" + if expr != want { + t.Errorf("expr = %q, want %q", expr, want) + } + wantParams := []any{"a", "b", "c"} + if len(params) != len(wantParams) { + t.Fatalf("params = %v, want %v", params, wantParams) + } + for i, p := range wantParams { + if params[i] != p { + t.Errorf("param[%d] = %v, want %v", i, params[i], p) + } } } diff --git a/typed/multi_query.go b/typed/multi_query.go index 98409c6..c088521 100644 --- a/typed/multi_query.go +++ b/typed/multi_query.go @@ -36,13 +36,20 @@ func NewMultiQuery[T any](conn modusgraph.Client) *MultiQuery[T] { } } -// Add registers a named block. Names must be unique within one MultiQuery. -// Panics on duplicate name — the call site is a programming error, not a -// runtime condition. +// Add registers a named block. Names must be unique within one MultiQuery, and +// each *Query[T] may be added only once: Execute names the block's underlying +// dgman query in place, so registering the same Query pointer under two names +// would make both blocks render with whichever name was applied last. Both +// conditions are programming errors and panic rather than fail at runtime. func (mq *MultiQuery[T]) Add(name string, q *Query[T]) *MultiQuery[T] { if _, exists := mq.blocks[name]; exists { panic(fmt.Sprintf("multi_query: duplicate block name %q", name)) } + for existingName, existing := range mq.blocks { + if existing == q { + panic(fmt.Sprintf("multi_query: Query already added as %q; build a separate Query per block", existingName)) + } + } mq.names = append(mq.names, name) mq.blocks[name] = q return mq @@ -130,7 +137,7 @@ func (mq *MultiQuery[T]) Execute(ctx context.Context) (map[string][]T, error) { // of the same name; we need our own because the multi-block response from // QueryRaw bypasses dgman's scan path. func buildPredicateToJSONMap(t reflect.Type) map[string]string { - for t != nil && t.Kind() == reflect.Ptr { + for t != nil && t.Kind() == reflect.Pointer { t = t.Elem() } if t == nil || t.Kind() != reflect.Struct { @@ -152,9 +159,9 @@ func buildPredicateToJSONMap(t reflect.Type) map[string]string { continue } var predName string - for _, part := range strings.Fields(dgraphTag) { - if strings.HasPrefix(part, "predicate=") { - predName = strings.TrimPrefix(part, "predicate=") + for part := range strings.FieldsSeq(dgraphTag) { + if p, ok := strings.CutPrefix(part, "predicate="); ok { + predName = p break } } diff --git a/typed/multi_query_test.go b/typed/multi_query_test.go index 98f1ae4..2f1561c 100644 --- a/typed/multi_query_test.go +++ b/typed/multi_query_test.go @@ -38,6 +38,21 @@ func TestMultiQueryAddRejectsDuplicateName(t *testing.T) { mq.Add("byName", q) } +func TestMultiQueryAddRejectsSameQueryTwice(t *testing.T) { + conn := newConn(t) + mq := typed.NewMultiQuery[widget](conn) + q := typed.NewClient[widget](conn).Query(context.Background()) + mq.Add("first", q) + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic when the same Query is added under two names") + } + }() + // Execute names the block's underlying query in place, so reusing one Query + // pointer would corrupt block composition; Add must reject it up front. + mq.Add("second", q) +} + func TestMultiQueryExecuteReturnsPerBlockResults(t *testing.T) { ctx := context.Background() conn := newConn(t) diff --git a/typed/query.go b/typed/query.go index e4b2199..78157f8 100644 --- a/typed/query.go +++ b/typed/query.go @@ -42,6 +42,11 @@ type Query[T any] struct { offset int // caller-set starting offset; 0 = none edges []edgeFilter // accumulated WhereEdge constraints; empty = none filters []filterFrag // accumulated @filter fragments, ANDed; empty = none + + // customRoot records that the caller narrowed the root with UID or + // RootFunc. The WhereEdge pre-pass then intersects with that root instead + // of overwriting it (see resolveRoots). + customRoot bool } // edgeFilter is one accumulated WhereEdge constraint: a dgraph @filter @@ -93,7 +98,10 @@ func (qb *Query[T]) addFilter(expr string, params []any) { } // combineAnd joins fragments with AND, renumbering each fragment's ordinal -// placeholders against the concatenated params slice. +// placeholders against the concatenated params slice. Each fragment is wrapped +// in its own parentheses so a fragment that itself contains OR keeps its +// intended precedence: without the parens, "a OR b" ANDed with "c" would parse +// as "a OR (b AND c)" because dgraph binds AND tighter than OR. func combineAnd(frags []filterFrag) (string, []any) { parts := make([]string, 0, len(frags)) var params []any @@ -101,7 +109,7 @@ func combineAnd(frags []filterFrag) (string, []any) { if f.expr == "" { continue } - parts = append(parts, shiftPlaceholders(f.expr, len(params))) + parts = append(parts, "("+shiftPlaceholders(f.expr, len(params))+")") params = append(params, f.params...) } if len(parts) == 0 { @@ -186,6 +194,7 @@ func (qb *Query[T]) Cascade(predicates ...string) *Query[T] { // is type(); RootFunc replaces it with an expression such as // eq(name, "Alice") or has(email). Repeated calls overwrite. func (qb *Query[T]) RootFunc(rootFunc string) *Query[T] { + qb.customRoot = true qb.q.RootFunc(rootFunc) return qb } @@ -370,6 +379,7 @@ func (qb *Query[T]) Raw() *dg.Query { // UID roots the query at a specific node UID. Results still decode into []T. func (qb *Query[T]) UID(uid string) *Query[T] { + qb.customRoot = true qb.q.UID(uid) return qb } @@ -385,7 +395,9 @@ func (qb *Query[T]) All(depth int) *Query[T] { // NodesAndCount executes the query and returns the matching records together // with the total count (useful for pagination totals). Like Nodes, it runs the // WhereEdge pre-pass first when edge constraints are present. -func (qb *Query[T]) NodesAndCount() ([]T, int, error) { +func (qb *Query[T]) NodesAndCount() (out []T, count int, err error) { + _, span := tracer.StartSpan(qb.ctx, "query", entityName[T]()) + defer func() { span.End(err) }() matched, err := qb.resolveRoots() if err != nil { return nil, 0, err @@ -393,8 +405,7 @@ func (qb *Query[T]) NodesAndCount() ([]T, int, error) { if !matched { return nil, 0, nil } - var out []T - count, err := qb.q.NodesAndCount(&out) + count, err = qb.q.NodesAndCount(&out) if err != nil { return nil, 0, err } @@ -468,10 +479,17 @@ func (r *RawQuery) GroupBy(predicate string) *RawQuery { } // resolveRoots runs the WhereEdge pre-pass when the query carries edge -// constraints, rewriting the main query's root function to the matching UIDs. +// constraints, narrowing the main query to the UIDs whose edges matched. // It returns matched=false when constraints are present but no root satisfied // them — callers then return an empty result without running the main query. // With no edge constraints it is a no-op returning matched=true. +// +// When the caller has not narrowed the root, the matched UIDs become the root +// function directly (the efficient path: the main query scans only those +// nodes). When the caller already narrowed the root with UID or RootFunc, the +// matched UIDs are added as a uid() @filter instead, so the result is the +// intersection of the caller's root and the edge constraints rather than +// silently discarding the caller's narrowing. func (qb *Query[T]) resolveRoots() (matched bool, err error) { if len(qb.edges) == 0 { return true, nil @@ -483,7 +501,12 @@ func (qb *Query[T]) resolveRoots() (matched bool, err error) { if len(uids) == 0 { return false, nil } - qb.q.RootFunc("uid(" + strings.Join(uids, ", ") + ")") + uidExpr := "uid(" + strings.Join(uids, ", ") + ")" + if qb.customRoot { + qb.addFilter(uidExpr, nil) + } else { + qb.q.RootFunc(uidExpr) + } return true, nil } diff --git a/typed/query_test.go b/typed/query_test.go index 588bf6b..4496756 100644 --- a/typed/query_test.go +++ b/typed/query_test.go @@ -202,7 +202,7 @@ func TestQuery_CombinedFilterShiftsPlaceholders(t *testing.T) { q.Filter("eq(name, $1)", "a") q.Filter("eq(qty, $1)", 7) expr, params := q.CombinedFilter() - const want = "eq(name, $1) AND eq(qty, $2)" + const want = "(eq(name, $1)) AND (eq(qty, $2))" if expr != want { t.Fatalf("CombinedFilter expr = %q, want %q", expr, want) } @@ -211,6 +211,22 @@ func TestQuery_CombinedFilterShiftsPlaceholders(t *testing.T) { } } +// TestQuery_CombinedFilterParenthesizesFragments pins the precedence guarantee: +// a fragment that contains OR must stay grouped when it is ANDed with another +// fragment. Without per-fragment parentheses the expression would render as +// "a OR b AND c", which dgraph parses as "a OR (b AND c)" — silently widening +// the result set. +func TestQuery_CombinedFilterParenthesizesFragments(t *testing.T) { + q := typed.NewDetachedQuery[widget]() + q.Filter(`eq(name, "alpha") OR eq(name, "beta")`) + q.Filter(`ge(qty, "5")`) + expr, _ := q.CombinedFilter() + const want = `(eq(name, "alpha") OR eq(name, "beta")) AND (ge(qty, "5"))` + if expr != want { + t.Fatalf("CombinedFilter precedence: expr = %q, want %q", expr, want) + } +} + func TestQuery_OrGroup(t *testing.T) { ctx := context.Background() c := typed.NewClient[widget](newConn(t)) @@ -791,10 +807,13 @@ func TestIterNodes_OneQueryPerPage(t *testing.T) { t.Fatalf("Add %d: %v", i, err) } } - // Obtaining the iterator runs no query — IterNodes is lazy. + // Obtaining the iterator runs no query — IterNodes is lazy. Measure the + // delta around the build, not the absolute count, so the assertion holds + // regardless of how many queries the seeding above happened to run. + before := queriesExecuted seq := c.Query(ctx).IterNodes() - if queriesExecuted != 0 { - t.Fatalf("building the IterNodes iterator executed %d queries, want 0", queriesExecuted) + if delta := queriesExecuted - before; delta != 0 { + t.Fatalf("building the IterNodes iterator executed %d queries, want 0", delta) } seen := 0 for _, err := range seq { @@ -806,8 +825,8 @@ func TestIterNodes_OneQueryPerPage(t *testing.T) { if seen != n { t.Fatalf("IterNodes streamed %d records, want %d", seen, n) } - if queriesExecuted != 3 { - t.Fatalf("IterNodes over %d records ran %d queries, want 3", n, queriesExecuted) + if delta := queriesExecuted - before; delta != 3 { // ceil(125/50) = 3 pages + t.Fatalf("IterNodes over %d records ran %d queries, want 3", n, delta) } } @@ -1149,6 +1168,41 @@ func TestQuery_WhereEdgeCombinesWithFilter(t *testing.T) { } } +func TestQuery_WhereEdgePreservesUIDRoot(t *testing.T) { + ctx := context.Background() + conn := newConn(t) + pets := typed.NewClient[pet](conn) + owners := typed.NewClient[owner](conn) + + fido := &pet{Name: "Fido"} + if err := pets.Add(ctx, fido); err != nil { + t.Fatalf("Add pet: %v", err) + } + alice := &owner{Name: "Alice", Pets: []*pet{fido}} + carol := &owner{Name: "Carol", Pets: []*pet{fido}} + for _, o := range []*owner{alice, carol} { + if err := owners.Add(ctx, o); err != nil { + t.Fatalf("Add owner %q: %v", o.Name, err) + } + } + + // Both Alice and Carol own Fido, so the WhereEdge pre-pass matches both. + // Rooting the query at Alice's UID must survive that pre-pass: the result + // is the intersection (just Alice), not every Fido owner. Before the fix, + // resolveRoots overwrote the UID root with uid(Alice, Carol) and returned + // both owners. + got, err := owners.Query(ctx). + UID(alice.UID). + WhereEdge("pets", `eq(name, "Fido")`). + Nodes() + if err != nil { + t.Fatalf("UID+WhereEdge Nodes: %v", err) + } + if len(got) != 1 || got[0].Name != "Alice" { + t.Fatalf("UID(Alice)+WhereEdge(pets,name=Fido) returned %+v, want [Alice]", got) + } +} + func TestQuery_WhereEdgeMultipleConstraintsAnd(t *testing.T) { ctx := context.Background() conn := newConn(t) From febcd5cf12d8fab5d50a753d9b73cdd28d43ec02 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Wed, 17 Jun 2026 17:23:28 -0400 Subject: [PATCH 16/24] fix(retry): keep jitter under MaxDelay; never skip fn on negative MaxRetries Addresses review feedback on the retry policy: - delay() now adds jitter before the final MaxDelay clamp, so the documented invariant "no single delay exceeds MaxDelay" holds. Previously the cap was applied before jitter, letting the delay reach MaxDelay*(1+Jitter). The exponential is also capped before the shift to avoid overflow at large attempt counts. - WithRetry clamps a negative MaxRetries to zero so fn always runs at least once, as documented, instead of skipping the loop and hitting the unreachable panic. Tests: - TestRetryPolicyDelayJitterNeverExceedsMaxCap asserts the cap invariant across attempts with large jitter. - TestWithRetryNegativeMaxRetries asserts fn runs once with MaxRetries=-1. - TestWithRetryContextCancellation now returns dgo.ErrAborted from fn so it actually enters the backoff sleep and exercises the ctx.Done() path it claims to cover (previously it returned ctx.Err() and bailed immediately). Docs: add runnable examples for WithRetry and a custom RetryPolicy. --- retry.go | 34 ++++++++++++++++++++------- retry_example_test.go | 53 ++++++++++++++++++++++++++++++++++++++++++ retry_internal_test.go | 20 ++++++++++++++++ retry_test.go | 39 ++++++++++++++++++++++--------- 4 files changed, 126 insertions(+), 20 deletions(-) create mode 100644 retry_example_test.go diff --git a/retry.go b/retry.go index 9b49fda..691d4ad 100644 --- a/retry.go +++ b/retry.go @@ -41,15 +41,25 @@ var DefaultRetryPolicy = RetryPolicy{ Jitter: 0.1, } -// delay computes the backoff duration for a given attempt (0-indexed). -// Formula: min(BaseDelay * 2^attempt, MaxDelay) + random(0, delay * Jitter) +// delay computes the backoff duration for a given attempt (0-indexed): +// the exponential BaseDelay*2^attempt, plus up to Jitter of itself, clamped +// so the result never exceeds MaxDelay. Clamping last keeps the documented +// invariant that no single delay exceeds MaxDelay — adding jitter after the +// cap would let the delay overshoot it. func (p RetryPolicy) delay(attempt int) time.Duration { - d := p.BaseDelay * time.Duration(1< p.MaxDelay { - d = p.MaxDelay + // Cap the exponential before jitter so a large attempt cannot overflow the + // shift. exp <= 0 means the shift overflowed; treat that as the cap too. + d := p.MaxDelay + if attempt < 63 { + if exp := p.BaseDelay << uint(attempt); exp > 0 && exp < p.MaxDelay { + d = exp + } } if p.Jitter > 0 { d += time.Duration(float64(d) * p.Jitter * rand.Float64()) + if d > p.MaxDelay { + d = p.MaxDelay + } } return d } @@ -74,23 +84,29 @@ func (p RetryPolicy) delay(attempt int) time.Duration { // return client.Insert(ctx, &entity) // }) func (c client) WithRetry(ctx context.Context, policy RetryPolicy, fn func() error) error { - for attempt := range policy.MaxRetries + 1 { + // A negative MaxRetries would make the loop run zero times and never call + // fn; clamp to zero so fn always runs at least once, as documented. + maxRetries := policy.MaxRetries + if maxRetries < 0 { + maxRetries = 0 + } + for attempt := range maxRetries + 1 { err := fn() if err == nil { return nil } - if !errors.Is(err, dgo.ErrAborted) || attempt >= policy.MaxRetries { + if !errors.Is(err, dgo.ErrAborted) || attempt >= maxRetries { return err } d := policy.delay(attempt) c.logger.V(1).Info("Transaction aborted, retrying", - "attempt", attempt+1, "maxRetries", policy.MaxRetries, "delay", d) + "attempt", attempt+1, "maxRetries", maxRetries, "delay", d) select { case <-time.After(d): case <-ctx.Done(): return ctx.Err() } } - // Unreachable: the loop runs MaxRetries+1 times and returns on every path. + // Unreachable: the loop runs at least once and returns on every path. panic("unreachable") } diff --git a/retry_example_test.go b/retry_example_test.go new file mode 100644 index 0000000..c95e9ef --- /dev/null +++ b/retry_example_test.go @@ -0,0 +1,53 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "time" + + "github.com/matthewmcneely/modusgraph" +) + +// ExampleClient_WithRetry wraps a mutation so an aborted transaction — the +// error Dgraph returns when concurrent writers conflict on an indexed +// predicate — is retried with exponential backoff instead of surfacing to the +// caller. Non-abort errors return immediately; the context bounds the total +// wait. +func ExampleClient_withRetry() { + client, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer client.Close() + + ctx := context.Background() + entity := &RetryEntity{Name: "alice", Value: 1} + + err := client.WithRetry(ctx, modusgraph.DefaultRetryPolicy, func() error { + return client.Insert(ctx, entity) + }) + if err != nil { + panic(err) + } +} + +// ExampleRetryPolicy shows a custom backoff schedule: three retries, 50ms base +// delay doubling each attempt, capped at 2s, with 20% jitter to spread +// concurrent retriers apart. +func ExampleRetryPolicy() { + client, _ := modusgraph.NewClient("dgraph://localhost:9080") + defer client.Close() + + policy := modusgraph.RetryPolicy{ + MaxRetries: 3, + BaseDelay: 50 * time.Millisecond, + MaxDelay: 2 * time.Second, + Jitter: 0.2, + } + + ctx := context.Background() + _ = client.WithRetry(ctx, policy, func() error { + return client.Insert(ctx, &RetryEntity{Name: "bob", Value: 2}) + }) +} diff --git a/retry_internal_test.go b/retry_internal_test.go index ce6bd2b..4eadccf 100644 --- a/retry_internal_test.go +++ b/retry_internal_test.go @@ -54,6 +54,26 @@ func TestRetryPolicyDelayWithJitter(t *testing.T) { } } +// TestRetryPolicyDelayJitterNeverExceedsMaxCap pins the documented invariant: +// even when the exponential delay sits at or above MaxDelay and jitter is +// large, no single delay exceeds MaxDelay. Jitter is added before the final +// clamp, so the result stays within the cap. +func TestRetryPolicyDelayJitterNeverExceedsMaxCap(t *testing.T) { + p := RetryPolicy{ + BaseDelay: 8 * time.Second, + MaxDelay: 10 * time.Second, + Jitter: 0.5, // up to +50% would overshoot MaxDelay without the clamp + } + for attempt := range 6 { + for range 100 { + d := p.delay(attempt) + assert.LessOrEqual(t, d, p.MaxDelay, + "attempt %d: delay %v exceeded MaxDelay %v", attempt, d, p.MaxDelay) + assert.Positive(t, d, "attempt %d: delay should be positive", attempt) + } + } +} + func TestRetryPolicyDelayZeroJitter(t *testing.T) { p := RetryPolicy{ BaseDelay: 100 * time.Millisecond, diff --git a/retry_test.go b/retry_test.go index 4cb0d86..a2a5e63 100644 --- a/retry_test.go +++ b/retry_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/dgraph-io/dgo/v250" "github.com/matthewmcneely/modusgraph" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -114,20 +115,15 @@ func TestWithRetryContextCancellation(t *testing.T) { callCount := 0 err := client.WithRetry(ctx, slowPolicy, func() error { callCount++ - // Always return an error that looks like an abort to trigger retry. - // We simulate this by inserting a duplicate to get a UniqueError, - // but that won't be retried. Instead, use a real insert to a fresh - // entity so the first call succeeds. - // Actually, to test the cancellation path we need the fn to always - // fail with an aborted error. Since we can't easily manufacture - // dgo.ErrAborted, test that context cancellation returns ctx.Err() - // by having fn block until context is done. - <-ctx.Done() - return ctx.Err() + // Return a real abort so WithRetry enters the retry path and sleeps for + // the 1s backoff. The 50ms context deadline fires during that sleep, so + // WithRetry must return from the ctx.Done() branch of the backoff select + // — the path this test exists to cover. + return dgo.ErrAborted }) assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.Equal(t, 1, callCount, "fn should be called once before context expires") + assert.Equal(t, 1, callCount, "fn runs once, then the context expires during backoff") } // TestRetryPolicyDelay verifies the exponential backoff calculation. @@ -195,3 +191,24 @@ func TestWithRetryMaxRetriesZero(t *testing.T) { assert.Error(t, err) assert.Equal(t, 1, callCount, "MaxRetries=0 should call fn exactly once") } + +// TestWithRetryNegativeMaxRetries verifies that a negative MaxRetries still +// calls fn exactly once and returns its error, rather than skipping the loop +// and panicking. +func TestWithRetryNegativeMaxRetries(t *testing.T) { + uri := "file://" + GetTempDir(t) + client, cleanup := CreateTestClient(t, uri) + defer cleanup() + + policy := modusgraph.RetryPolicy{MaxRetries: -1} + callCount := 0 + expectedErr := fmt.Errorf("boom") + + err := client.WithRetry(context.Background(), policy, func() error { + callCount++ + return expectedErr + }) + + assert.ErrorIs(t, err, expectedErr) + assert.Equal(t, 1, callCount, "negative MaxRetries should still call fn once") +} From 9bc428948c0f0b6ec02e8535b8cfc22b5db2b786 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Wed, 17 Jun 2026 17:29:04 -0400 Subject: [PATCH 17/24] fix(schema): guard nil wrappers and pointer-receiver Unwrap in UnwrapSchema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses review feedback on generated-schema-type routing: - UnwrapSchema no longer panics on a typed nil pointer: it returns the value untouched instead of invoking Unwrap on a nil receiver that may dereference. - UnwrapSchema now finds a pointer-receiver Unwrap on a wrapper passed by value, by looking the method up on an addressable copy (a value's method set excludes pointer-receiver methods). Tests: - Add unit tests for the typed-nil-pointer and value-wrapper cases. - Add an integration test (TestClientUnwrapsWrapperThroughRealMutation) that inserts a wrapper through the real client and reads it back, so a mutation method dropping its UnwrapSchema call is caught — the prior tests exercised UnwrapSchema in isolation through a local mock and would not catch that. Docs: add a runnable ExampleSchema showing the wrapper/Unwrap pattern. --- record.go | 13 ++++++++++ record_example_test.go | 49 ++++++++++++++++++++++++++++++++++++ record_integration_test.go | 51 ++++++++++++++++++++++++++++++++++++++ record_test.go | 22 ++++++++++++++++ 4 files changed, 135 insertions(+) create mode 100644 record_example_test.go create mode 100644 record_integration_test.go diff --git a/record.go b/record.go index 015c587..bfcc762 100644 --- a/record.go +++ b/record.go @@ -42,7 +42,20 @@ func UnwrapSchema(obj any) any { if !v.IsValid() { return obj } + // A typed nil pointer has a valid method set, but invoking Unwrap on a nil + // receiver would panic if the method dereferences it. Leave it untouched. + if v.Kind() == reflect.Pointer && v.IsNil() { + return obj + } m := v.MethodByName("Unwrap") + if !m.IsValid() && v.Kind() != reflect.Pointer { + // Unwrap may be declared with a pointer receiver while obj was passed by + // value; a value's method set excludes pointer-receiver methods, so look + // it up on an addressable copy. + pv := reflect.New(v.Type()) + pv.Elem().Set(v) + m = pv.MethodByName("Unwrap") + } if !m.IsValid() { return obj } diff --git a/record_example_test.go b/record_example_test.go new file mode 100644 index 0000000..534a844 --- /dev/null +++ b/record_example_test.go @@ -0,0 +1,49 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "fmt" + + mg "github.com/matthewmcneely/modusgraph" +) + +// Actor is a schema-defining record. Implementing mg.Schema (a single +// SchemaTypeName method) marks it as a generated schema type; code generators +// such as modusgraph-gen emit this method. +type Actor struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` +} + +func (a *Actor) SchemaTypeName() string { return "Actor" } + +// ActorBuilder is a wrapper around Actor — the shape a generated fluent builder +// or domain wrapper takes. Exposing Unwrap lets the modusgraph client route the +// wrapper to its backing record, so the wrapper can be passed straight to +// Insert/Update/Get without the caller reaching for the inner value. +type ActorBuilder struct{ actor *Actor } + +func (b *ActorBuilder) Unwrap() *Actor { return b.actor } + +// ExampleSchema shows the wrapper pattern: the client unwraps an ActorBuilder +// to its Actor before persisting, so generated wrapper types work transparently +// while plain structs are unaffected. +func ExampleSchema() { + client, _ := mg.NewClient("dgraph://localhost:9080") + defer client.Close() + + ctx := context.Background() + builder := &ActorBuilder{actor: &Actor{Name: "Sigourney Weaver"}} + + // Insert the wrapper; the client unwraps it to the Actor record. + if err := client.Insert(ctx, builder); err != nil { + panic(err) + } + fmt.Println(builder.actor.Name) +} diff --git a/record_integration_test.go b/record_integration_test.go new file mode 100644 index 0000000..41e9a13 --- /dev/null +++ b/record_integration_test.go @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "testing" + + mg "github.com/matthewmcneely/modusgraph" + "github.com/stretchr/testify/require" +) + +// studioRecord is a schema-defining record (implements mg.Schema). studioWrapper +// wraps it and exposes Unwrap, exactly as a modusgraph-gen wrapper would. +type studioRecord struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty" dgraph:"index=exact"` +} + +func (s *studioRecord) SchemaTypeName() string { return "studioRecord" } + +type studioWrapper struct{ inner *studioRecord } + +func (w *studioWrapper) Unwrap() *studioRecord { return w.inner } + +// TestClientUnwrapsWrapperThroughRealMutation exercises the real client path, +// not UnwrapSchema in isolation: it inserts a wrapper and reads it back. If a +// mutation method stopped calling UnwrapSchema, the wrapper (which has no usable +// dgraph fields of its own) would not persist Name and the inner UID would stay +// empty — so this test fails on that regression. +func TestClientUnwrapsWrapperThroughRealMutation(t *testing.T) { + client, err := mg.NewClient("file://"+GetTempDir(t), mg.WithAutoSchema(true)) + require.NoError(t, err) + defer client.Close() + + ctx := context.Background() + inner := &studioRecord{Name: "Acme"} + wrapper := &studioWrapper{inner: inner} + + require.NoError(t, client.Insert(ctx, wrapper)) + require.NotEmpty(t, inner.UID, + "Insert did not route the wrapper to its inner record") + + var got studioRecord + require.NoError(t, client.Get(ctx, &got, inner.UID)) + require.Equal(t, "Acme", got.Name) +} diff --git a/record_test.go b/record_test.go index 1f6ef72..caf5083 100644 --- a/record_test.go +++ b/record_test.go @@ -61,6 +61,28 @@ func TestUnwrapSchema_NilInput(t *testing.T) { } } +func TestUnwrapSchema_TypedNilPointerDoesNotPanic(t *testing.T) { + // fakeWrapper.Unwrap dereferences its receiver, so invoking it on a typed + // nil pointer would panic. UnwrapSchema must return the value untouched. + var w *fakeWrapper + out := UnwrapSchema(w) + if out != any(w) { + t.Fatalf("expected typed nil pointer passthrough, got %T (%v)", out, out) + } +} + +func TestUnwrapSchema_PointerReceiverUnwrapOnValue(t *testing.T) { + // fakeWrapper.Unwrap has a pointer receiver. Passing the wrapper by value + // must still unwrap: a value's method set excludes pointer-receiver methods, + // so UnwrapSchema looks Unwrap up on an addressable copy. + inner := &fakeRecord{name: "Studio"} + w := fakeWrapper{inner: inner} + out := UnwrapSchema(w) + if out != any(inner) { + t.Fatalf("expected unwrapped inner from value wrapper, got %T (%v)", out, out) + } +} + // recordingClient is the minimal surface needed to verify that wrappers // passed to the Client interface get unwrapped before reaching internal // reflection. It records whatever it received and returns nil. Each method From e51a05901fc5cdfb53d83ef7379207fc67832f9c Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Wed, 17 Jun 2026 17:36:07 -0400 Subject: [PATCH 18/24] fix(client): key dedup cache on dial-option identity, not just count MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses review feedback on WithGRPCDialOption: - client.key() keyed custom gRPC dial options by count only, so two clients to the same URI with one different option each produced the same key and were merged in the dedup cache. The key now includes each option's runtime identity (grpc.DialOption is opaque and not comparable), so differently configured clients are never merged. The cache errs toward keeping clients apart rather than merging connections configured differently. - Dial options apply only to remote (dgraph://) connections, so they now contribute to the key only for remote URIs — consistent with the documented "ignored for file://" behavior. Tests: - TestKeyDistinguishesGRPCDialOptions now also asserts two clients with the same option count but different options get different keys. - TestKeyIgnoresGRPCDialOptionsForEmbedded asserts dial options do not affect the key for embedded (file://) clients. Docs: add a runnable ExampleWithGRPCDialOption (TLS credentials + keepalive). --- client.go | 29 +++++++++++++++++++++++++++-- dial_options_example_test.go | 35 +++++++++++++++++++++++++++++++++++ dial_options_test.go | 23 +++++++++++++++++++++++ 3 files changed, 85 insertions(+), 2 deletions(-) create mode 100644 dial_options_example_test.go diff --git a/client.go b/client.go index 7834db2..e7ba130 100644 --- a/client.go +++ b/client.go @@ -453,9 +453,34 @@ func (c client) key() string { if c.options.embeddingProvider != nil { embeddingKey = fmt.Sprintf("%p", c.options.embeddingProvider) } - return fmt.Sprintf("%s:%t:%d:%d:%d:%d:%s:%s:%s:%d", c.uri, c.options.autoSchema, c.options.poolSize, + // Custom gRPC dial options only apply to remote (dgraph://) connections; + // they are ignored for embedded (file://) URIs, so they only contribute to + // the dedup key for remote clients — matching that documented behavior. + dialKey := "0" + if strings.HasPrefix(c.uri, dgraphURIPrefix) { + dialKey = dialOptionsKey(c.options.grpcDialOptions) + } + return fmt.Sprintf("%s:%t:%d:%d:%d:%d:%s:%s:%s:%s", c.uri, c.options.autoSchema, c.options.poolSize, c.options.maxEdgeTraversal, c.options.cacheSizeMB, c.options.maxRecvMsgSize, - c.options.namespace, validatorKey, embeddingKey, len(c.options.grpcDialOptions)) + c.options.namespace, validatorKey, embeddingKey, dialKey) +} + +// dialOptionsKey identifies a set of custom gRPC dial options for the client +// dedup cache. grpc.DialOption values are opaque and not comparable, so the key +// uses each option's runtime identity rather than just the count: two clients +// configured with different options get different keys and are never merged. +// Two clients built from separately-constructed but equivalent options also +// differ, which is safe — the cache errs toward keeping them apart rather than +// merging connections that were configured differently. +func dialOptionsKey(opts []grpc.DialOption) string { + if len(opts) == 0 { + return "0" + } + parts := make([]string, len(opts)) + for i, opt := range opts { + parts[i] = fmt.Sprintf("%p", opt) + } + return strings.Join(parts, ",") } // embeddingProvider implements the embeddingClient interface, exposing the diff --git a/dial_options_example_test.go b/dial_options_example_test.go new file mode 100644 index 0000000..ea82ab7 --- /dev/null +++ b/dial_options_example_test.go @@ -0,0 +1,35 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "time" + + mg "github.com/matthewmcneely/modusgraph" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" +) + +// ExampleWithGRPCDialOption configures gRPC dial settings the dedicated options +// do not cover — here, transport credentials and keepalive parameters — when +// opening a remote dgraph:// connection. Each WithGRPCDialOption adds one +// grpc.DialOption; they compose with WithMaxRecvMsgSize. The options are ignored +// for embedded (file://) URIs. +func ExampleWithGRPCDialOption() { + client, err := mg.NewClient( + "dgraph://localhost:9080", + mg.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())), + mg.WithGRPCDialOption(grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 30 * time.Second, + Timeout: 10 * time.Second, + })), + ) + if err != nil { + panic(err) + } + defer client.Close() +} diff --git a/dial_options_test.go b/dial_options_test.go index c64e257..b7e0b65 100644 --- a/dial_options_test.go +++ b/dial_options_test.go @@ -21,10 +21,33 @@ func TestWithGRPCDialOptionAppends(t *testing.T) { } func TestKeyDistinguishesGRPCDialOptions(t *testing.T) { + // A client with no dial options must differ from one with a dial option. base := client{uri: "dgraph://localhost:9080"} withOpt := client{uri: "dgraph://localhost:9080"} WithGRPCDialOption(grpc.WithUserAgent("x"))(&withOpt.options) if base.key() == withOpt.key() { t.Fatal("client.key() must differ when grpcDialOptions differ, else clients dedup incorrectly") } + + // Two clients with the SAME number of dial options but DIFFERENT options + // must also differ. A count-only key would collide here and merge clients + // that were configured differently. + a := client{uri: "dgraph://localhost:9080"} + b := client{uri: "dgraph://localhost:9080"} + WithGRPCDialOption(grpc.WithUserAgent("x"))(&a.options) + WithGRPCDialOption(grpc.WithUserAgent("y"))(&b.options) + if a.key() == b.key() { + t.Fatal("client.key() must differ when dial options differ at the same count") + } +} + +func TestKeyIgnoresGRPCDialOptionsForEmbedded(t *testing.T) { + // Dial options are ignored for embedded (file://) URIs, so they must not + // affect the dedup key there. + plain := client{uri: "file:///tmp/db"} + withOpt := client{uri: "file:///tmp/db"} + WithGRPCDialOption(grpc.WithUserAgent("x"))(&withOpt.options) + if plain.key() != withOpt.key() { + t.Fatal("dial options must not affect the cache key for embedded (file://) clients") + } } From 7478c26d5c385ebd769afedff5b31e45713cfac8 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Wed, 17 Jun 2026 17:38:37 -0400 Subject: [PATCH 19/24] docs(schema): add runnable ExampleClient_alterSchema Document the raw schema-DDL path with a runnable example showing predicate types, indexes, and directives applied directly via AlterSchema. --- schema_ddl_example_test.go | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 schema_ddl_example_test.go diff --git a/schema_ddl_example_test.go b/schema_ddl_example_test.go new file mode 100644 index 0000000..2556cac --- /dev/null +++ b/schema_ddl_example_test.go @@ -0,0 +1,31 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + + mg "github.com/matthewmcneely/modusgraph" +) + +// ExampleClient_AlterSchema applies a raw DQL schema string directly, giving +// full control over predicate types, indexes, and directives. This complements +// UpdateSchema (which infers the schema from Go struct tags) and is useful for +// migrations that declare predicates no Go type models yet. +func ExampleClient_alterSchema() { + client, _ := mg.NewClient("dgraph://localhost:9080") + defer client.Close() + + ctx := context.Background() + schema := ` + name: string @index(exact) . + email: string @index(hash) @upsert . + age: int @index(int) . + ` + if err := client.AlterSchema(ctx, schema); err != nil { + panic(err) + } +} From 81770b57c8033e2e8b576dc034a79a4fb6bd81f9 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Wed, 17 Jun 2026 17:38:37 -0400 Subject: [PATCH 20/24] test(validation): assert SelfValidator slice path skips StructCtx; add example - TestValidateSelfValidatorInSlice now asserts the configured StructValidator is never called for a SelfValidator slice element, matching the scalar test; a regression invoking both paths would otherwise pass silently. - Add a runnable ExampleSelfValidator showing a cross-field rule that layers on the configured StructValidator. --- self_validator_example_test.go | 51 ++++++++++++++++++++++++++++++++++ self_validator_test.go | 5 ++++ 2 files changed, 56 insertions(+) create mode 100644 self_validator_example_test.go diff --git a/self_validator_example_test.go b/self_validator_example_test.go new file mode 100644 index 0000000..1f5f918 --- /dev/null +++ b/self_validator_example_test.go @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "fmt" + + mg "github.com/matthewmcneely/modusgraph" +) + +// Event carries a cross-field rule that struct tags cannot express: End must +// not precede Start. Implementing SelfValidator lets the type enforce that rule +// itself; the client calls ValidateWith on Insert/Upsert/Update. ValidateWith +// also receives the configured StructValidator, so a type can run ordinary +// tag-based validation first and then layer custom logic on top. +type Event struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + Name string `json:"name,omitempty"` + Start int `json:"start,omitempty"` + End int `json:"end,omitempty"` +} + +func (e *Event) ValidateWith(ctx context.Context, v mg.StructValidator) error { + // Run any tag-based validation the client was configured with. + if v != nil { + if err := v.StructCtx(ctx, e); err != nil { + return err + } + } + // Then the cross-field rule. + if e.End < e.Start { + return fmt.Errorf("event %q: End (%d) must be >= Start (%d)", e.Name, e.End, e.Start) + } + return nil +} + +// ExampleSelfValidator inserts an Event; the client routes it through +// ValidateWith, so the cross-field rule runs before the write. +func ExampleSelfValidator() { + client, _ := mg.NewClient("dgraph://localhost:9080") + defer client.Close() + + ctx := context.Background() + err := client.Insert(ctx, &Event{Name: "launch", Start: 10, End: 5}) + fmt.Println(err != nil) // the rule rejects End < Start +} diff --git a/self_validator_test.go b/self_validator_test.go index bf5b5b4..d6ce8dd 100644 --- a/self_validator_test.go +++ b/self_validator_test.go @@ -63,6 +63,11 @@ func TestValidateSelfValidatorInSlice(t *testing.T) { if !errors.Is(err, errSelfValidated) { t.Fatalf("expected the SelfValidator path for slice elements, got %v", err) } + // As in the scalar case, the SelfValidator path must not also invoke the + // configured StructValidator for slice elements. + if rv.calls != 0 { + t.Fatalf("StructCtx must not run for a SelfValidator slice element, got %d calls", rv.calls) + } } // dateRange validates a relationship between two fields — a cross-field rule From 4a4f6f1e9a5e68ab159528adcb5cfa3dc9e94f94 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Wed, 17 Jun 2026 17:49:56 -0400 Subject: [PATCH 21/24] fix(consume): gate LoadAndDelete lock to embedded, guard nil upsert reflection Addresses review feedback specific to LoadOrStore/LoadAndDelete: - LoadAndDelete took a process-wide mutex unconditionally, serializing consumers across all keys and backends. The mutex is only needed for the embedded engine (no commit-time conflict check); a remote Dgraph cluster aborts the loser of a commit conflict and the bounded retry already elects a single winner. The lock is now taken only for embedded clients (c.engine != nil), removing head-of-line blocking for remote consumers. - firstUpsertPredicate dereferenced reflect.Value without guarding a nil pointer or non-struct input, panicking on Type()/NumField(). It now returns "" for an invalid or non-struct value. Tests: - TestLoadOrStore / TestTypedLoadOrStore now assert the passed object is hydrated with the existing record (its UID) on the loaded=true path, so a regression returning loaded=true without populating fields is caught. Docs: add runnable ExampleClient_loadOrStore and ExampleClient_loadAndDelete. This branch also merges the updated feature branches, bringing in the review fixes for the typed query builder, retry policy, schema routing, gRPC dial options, and self-validation. --- client.go | 15 ++++++++--- consume_example_test.go | 59 +++++++++++++++++++++++++++++++++++++++++ consume_test.go | 9 +++++++ typed/consume_test.go | 7 ++++- 4 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 consume_example_test.go diff --git a/client.go b/client.go index a7154cd..28d69ad 100644 --- a/client.go +++ b/client.go @@ -680,6 +680,12 @@ func firstUpsertPredicate(obj any) string { for v.Kind() == reflect.Ptr { v = v.Elem() } + // A nil pointer (or nil interface) dereferences to an invalid Value, and a + // non-struct has no fields; either way there is no upsert predicate, and + // calling Type()/NumField() on them would panic. + if !v.IsValid() || v.Kind() != reflect.Struct { + return "" + } t := v.Type() for i := 0; i < t.NumField(); i++ { f := t.Field(i) @@ -749,9 +755,12 @@ func (c client) LoadAndDelete(ctx context.Context, obj any, key any, predicates // The shared read-write transaction already elects one winner against a real // Dgraph cluster (the loser aborts on commit), but the embedded engine does // no commit-time conflict check, so without this lock concurrent callers - // would each read the node and each report loaded=true. The lock makes - // read-and-consume atomic regardless of backend. - if c.consumeMu != nil { + // would each read the node and each report loaded=true. A real Dgraph + // cluster aborts the loser of a commit conflict and the bounded retry below + // already elects a single winner, so the lock is needed — and taken — only + // for the embedded engine. Taking it for remote clusters would needlessly + // serialize consumers operating on unrelated keys. + if c.engine != nil && c.consumeMu != nil { c.consumeMu.Lock() defer c.consumeMu.Unlock() } diff --git a/consume_example_test.go b/consume_example_test.go new file mode 100644 index 0000000..13a8eac --- /dev/null +++ b/consume_example_test.go @@ -0,0 +1,59 @@ +/* + * SPDX-FileCopyrightText: © 2017-2026 Istari Digital, Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package modusgraph_test + +import ( + "context" + "fmt" + + mg "github.com/matthewmcneely/modusgraph" +) + +// Token is keyed by a unique jti predicate. The upsert+unique tags let +// LoadOrStore atomically insert-if-absent on that key. +type Token struct { + UID string `json:"uid,omitempty"` + DType []string `json:"dgraph.type,omitempty"` + JTI string `json:"jti,omitempty" dgraph:"index=hash upsert unique"` +} + +// ExampleClient_loadOrStore atomically inserts a node if no node with the same +// key exists, or reports that one already did. loaded is false when this call +// created the node and true when an existing node was found; on the loaded=true +// path the passed object is hydrated with the existing record. +// +// This is the building block for "claim a one-time token": the first caller +// stores and proceeds, every later caller sees loaded=true and is rejected. +func ExampleClient_loadOrStore() { + client, _ := mg.NewClient("dgraph://localhost:9080") + defer client.Close() + + ctx := context.Background() + loaded, err := client.LoadOrStore(ctx, &Token{JTI: "abc123"}, "jti") + if err != nil { + panic(err) + } + fmt.Println(loaded) // false the first time, true thereafter +} + +// ExampleClient_loadAndDelete atomically reads a node and deletes it, electing a +// single winner under concurrency: exactly one caller gets loaded=true with the +// record hydrated, the rest get loaded=false. Use it to consume a one-shot +// value — a nonce, a pending job, a single-use code. +func ExampleClient_loadAndDelete() { + client, _ := mg.NewClient("dgraph://localhost:9080") + defer client.Close() + + ctx := context.Background() + var got Token + loaded, err := client.LoadAndDelete(ctx, &got, "abc123", "jti") + if err != nil { + panic(err) + } + if loaded { + fmt.Println("consumed", got.JTI) + } +} diff --git a/consume_test.go b/consume_test.go index 479d0de..6fb8519 100644 --- a/consume_test.go +++ b/consume_test.go @@ -49,6 +49,15 @@ func TestLoadOrStore(t *testing.T) { if !loaded { t.Fatal("second store: want loaded=true (already existed)") } + // On the loaded=true path, LoadOrStore must hydrate the passed object with + // the existing record. Assert it carries the existing node's UID, so a + // regression that returns loaded=true without populating fields is caught. + if second.UID == "" { + t.Fatal("second store: loaded record was not hydrated (UID empty)") + } + if second.UID != first.UID { + t.Fatalf("second store: want existing node UID %q, got %q", first.UID, second.UID) + } } type consumeState struct { diff --git a/typed/consume_test.go b/typed/consume_test.go index 16eca08..c80471b 100644 --- a/typed/consume_test.go +++ b/typed/consume_test.go @@ -45,13 +45,18 @@ func TestTypedLoadOrStore(t *testing.T) { t.Fatal("first: want a UID assigned") } - _, loaded, err = c.LoadOrStore(ctx, &jti{JTI: "abc"}, "jti") + rec2, loaded, err := c.LoadOrStore(ctx, &jti{JTI: "abc"}, "jti") if err != nil { t.Fatalf("second: %v", err) } if !loaded { t.Fatal("second: want loaded=true") } + // The loaded=true path must return the existing record, not the freshly + // passed (un-stored) one — assert it carries the original node's UID. + if rec2 == nil || rec2.UID != rec.UID { + t.Fatalf("second: want existing record with UID %q, got %+v", rec.UID, rec2) + } } type state struct { From d79b02e16ec94f3dd34817740377a7bcb054d625 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Wed, 17 Jun 2026 18:09:33 -0400 Subject: [PATCH 22/24] style(typed): satisfy golangci-lint lll and prealloc Wrap a >120-char error string in MultiQuery.Execute and a long test helper signature, and pre-allocate two test result slices. Resolves the Trunk golangci-lint findings; no behavior change. --- typed/multi_query.go | 4 +++- typed/query_test.go | 8 +++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/typed/multi_query.go b/typed/multi_query.go index c088521..5a546e2 100644 --- a/typed/multi_query.go +++ b/typed/multi_query.go @@ -85,7 +85,9 @@ func (mq *MultiQuery[T]) Execute(ctx context.Context) (map[string][]T, error) { for _, name := range mq.names { block := mq.blocks[name] if len(block.edges) != 0 { - return nil, fmt.Errorf("multi_query: block %q carries WhereEdge constraints; MultiQuery cannot batch edge-filtered blocks", name) + return nil, fmt.Errorf( + "multi_query: block %q carries WhereEdge constraints; "+ + "MultiQuery cannot batch edge-filtered blocks", name) } // Name the underlying dgman query so blocks do not collide on the // default "data" name and so the response JSON keys are predictable. diff --git a/typed/query_test.go b/typed/query_test.go index 4496756..6a52540 100644 --- a/typed/query_test.go +++ b/typed/query_test.go @@ -752,7 +752,7 @@ func TestIterNodes_RespectsOffset(t *testing.T) { t.Fatalf("Add %d: %v", i, err) } } - var got []int + got := make([]int, 0, n-3) // offset 3 of n records for w, err := range c.Query(ctx).OrderAsc("qty").Offset(3).IterNodes() { if err != nil { t.Fatalf("IterNodes yielded error: %v", err) @@ -780,7 +780,7 @@ func TestIterNodes_RespectsOffsetAndLimit(t *testing.T) { t.Fatalf("Add %d: %v", i, err) } } - var got []int + got := make([]int, 0, 120) // Limit(120) for w, err := range c.Query(ctx).OrderAsc("qty").Offset(60).Limit(120).IterNodes() { if err != nil { t.Fatalf("IterNodes yielded error: %v", err) @@ -1078,7 +1078,9 @@ func TestRawQuery_CarriesEarlierBuilders(t *testing.T) { // map entry is one owner owning one pet of the given name; the pet is inserted // first so the owner's edge links an already-persisted node. It returns an // owner client bound to conn. -func seedOwners(ctx context.Context, t *testing.T, conn modusgraph.Client, ownerToPet map[string]string) *typed.Client[owner] { +func seedOwners( + ctx context.Context, t *testing.T, conn modusgraph.Client, ownerToPet map[string]string, +) *typed.Client[owner] { t.Helper() pets := typed.NewClient[pet](conn) owners := typed.NewClient[owner](conn) From bbf99b955c751730688d9ce7a5936738ee714225 Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Wed, 17 Jun 2026 18:10:40 -0400 Subject: [PATCH 23/24] style(retry): clear gosec G115 and G404 in delay() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drop the int->uint shift conversion (Go allows a signed shift count, and attempt is bounded to [0,63)), clearing gosec G115. Annotate the jitter RNG with nolint:gosec — backoff jitter is not security-sensitive, so math/rand/v2 is appropriate. No behavior change. --- retry.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/retry.go b/retry.go index 691d4ad..113c7fd 100644 --- a/retry.go +++ b/retry.go @@ -48,15 +48,18 @@ var DefaultRetryPolicy = RetryPolicy{ // cap would let the delay overshoot it. func (p RetryPolicy) delay(attempt int) time.Duration { // Cap the exponential before jitter so a large attempt cannot overflow the - // shift. exp <= 0 means the shift overflowed; treat that as the cap too. + // shift. attempt comes from a range loop (>= 0) and is bounded below 63; + // exp <= 0 means the shift overflowed anyway, which we treat as the cap. d := p.MaxDelay - if attempt < 63 { - if exp := p.BaseDelay << uint(attempt); exp > 0 && exp < p.MaxDelay { + if attempt >= 0 && attempt < 63 { + if exp := p.BaseDelay << attempt; exp > 0 && exp < p.MaxDelay { d = exp } } if p.Jitter > 0 { - d += time.Duration(float64(d) * p.Jitter * rand.Float64()) + // Backoff jitter spreads retriers apart; it does not need a + // cryptographic RNG, so math/rand/v2 is appropriate here. + d += time.Duration(float64(d) * p.Jitter * rand.Float64()) //nolint:gosec // G404: jitter is not security-sensitive if d > p.MaxDelay { d = p.MaxDelay } From 7e97f3a2a6d1aba2a786f20837d7559225f9f77b Mon Sep 17 00:00:00 2001 From: Michael Welles Date: Wed, 17 Jun 2026 19:42:36 -0400 Subject: [PATCH 24/24] ci: pin trivy to an existing release (0.59.1 does not exist) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Trunk config pinned trivy@0.59.1, but that release does not exist on github.com/aquasecurity/trivy — Trunk's templated download (.../v0.59.1/trivy_0.59.1_Linux-64bit.tar.gz) returns HTTP 404, failing the Trunk Code Quality check on any PR whose diff trivy scans (e.g. workflow or broad changes) while reporting no actual lint issues. Bump to trivy@0.69.3, a real release with the expected Linux-64bit asset. The plugin (v1.6.7) downloads trivy via a version-templated URL, so no other change is needed. --- .trunk/trunk.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.trunk/trunk.yaml b/.trunk/trunk.yaml index 8f22d46..445cc8d 100644 --- a/.trunk/trunk.yaml +++ b/.trunk/trunk.yaml @@ -23,7 +23,7 @@ runtimes: # This is the section where you manage your linters. (https://docs.trunk.io/check/configuration) lint: enabled: - - trivy@0.59.1 + - trivy@0.69.3 - taplo@0.9.3 - actionlint@1.7.7 - checkov@3.2.365