diff --git a/go.mod b/go.mod index 3959ca3..fc670f4 100644 --- a/go.mod +++ b/go.mod @@ -44,6 +44,7 @@ require ( github.com/google/uuid v1.6.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect + github.com/klauspost/reedsolomon v1.14.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect diff --git a/go.sum b/go.sum index 40d6a60..0e7736d 100644 --- a/go.sum +++ b/go.sum @@ -61,6 +61,8 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/klauspost/reedsolomon v1.14.0 h1:5YSZeclzSYg5nl349+GDG/agDtQ6MZiwUYXvVKN1Jx0= +github.com/klauspost/reedsolomon v1.14.0/go.mod h1:yjqqjgMTQkBUHSG97/rm4zipffCNbCiZcB3kTqr++sQ= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= diff --git a/server/gcas/db.go b/server/gcas/db.go index 9709978..c8d5565 100644 --- a/server/gcas/db.go +++ b/server/gcas/db.go @@ -19,7 +19,7 @@ const pragmaString = `PRAGMA journal_mode=WAL; //go:embed migrations/*.sql var migrations embed.FS -const dbVersion = 1 +const dbVersion = 2 func OpenDB(dbPath string) (*sql.DB, error) { return OpenDBWithVersion(dbPath, dbVersion) diff --git a/server/gcas/gcas.go b/server/gcas/gcas.go index 40227c7..1cd9805 100644 --- a/server/gcas/gcas.go +++ b/server/gcas/gcas.go @@ -1,5 +1,7 @@ package gcas +import "context" + // GCAS is a content-addressible storage service that combines multiple CAS nodes into a single CAS. // It uses erasure coding to provide efficient redundancy. // The erasure coding used is Reed-Solomon coding. @@ -8,6 +10,8 @@ type GCAS interface { AddNode(node NamedCAS) RemoveNode(nodeName string) ReplaceNode(node NamedCAS) // replaces the node with the same name with the new node + RunMaintenance(ctx context.Context) error + Repair(ctx context.Context) error } type NamedCAS interface { diff --git a/server/gcas/gcas_impl.go b/server/gcas/gcas_impl.go index 4a34b41..93ea8e6 100644 --- a/server/gcas/gcas_impl.go +++ b/server/gcas/gcas_impl.go @@ -2,28 +2,44 @@ package gcas import ( "context" + "crypto/sha256" "database/sql" "errors" + "fmt" + "log" "math/rand" "sync" + + "github.com/klauspost/reedsolomon" ) -// NewGCAS creates a new GCAS instance. -// db is the database connection to use for storing metadata +const defaultDataShards = 4 +const parityShards = 2 + +// NewGCAS creates a new GCAS instance with the default number of data shards. func NewGCAS(db *sql.DB) GCAS { + return NewGCASWithDataShards(db, defaultDataShards) +} + +// NewGCASWithDataShards creates a new GCAS instance with the given number of data shards per stripe. +// A stripe requires dataShards+2 distinct nodes to form. Fewer nodes disables stripe formation. +func NewGCASWithDataShards(db *sql.DB, dataShards int) GCAS { return &GcasImpl{ db: db, + dataShards: dataShards, nodes: make(map[string]CAS), shardedLocker: newShardedLocker(), } } type GcasImpl struct { - db *sql.DB + db *sql.DB + dataShards int // nodes connected to the cluster - nodesLock sync.RWMutex - nodes map[string]CAS - shardedLocker *shardedLocker + nodesLock sync.RWMutex + nodes map[string]CAS + shardedLocker *shardedLocker + maintenanceLock sync.Mutex // enforces that at most one maintenance runs at a time } // ReplaceNode implements [GCAS]. @@ -52,35 +68,23 @@ func (g *GcasImpl) Delete(ctx context.Context, hash Hash) error { g.shardedLocker.Lock(hash) defer g.shardedLocker.Unlock(hash) - // which node has the chunk? - // query chunks table in database - var nodeID string - err := g.db.QueryRowContext(ctx, "SELECT node_id FROM chunks WHERE hash = ?", hash[:]).Scan(&nodeID) + // soft delete from the database + // must check is_data = 1 to get the correct number of rows affected for response code + result, err := g.db.ExecContext(ctx, "UPDATE chunks SET is_data = 0 WHERE hash = ? AND is_data = 1", hash[:]) if err != nil { - if err == sql.ErrNoRows { - return HashNotFoundError{} - } return err } - // if the node is currently connected, call Delete on the node's CAS - g.nodesLock.RLock() - cas, ok := g.nodes[nodeID] - g.nodesLock.RUnlock() - - if ok { - err = cas.Delete(ctx, hash) - // if delete failed for any reason other than HashNotFoundError, propagate without touching the database - if err != nil && !errors.Is(err, HashNotFoundError{}) { - return err - } - } - - // delete from the database - _, err = g.db.ExecContext(ctx, "DELETE FROM chunks WHERE hash = ?", hash[:]) + // if no rows were updated, the chunk does not exist + numRowsAffected, err := result.RowsAffected() if err != nil { return err } + + if numRowsAffected == 0 { + return HashNotFoundError{} + } + return nil } @@ -131,7 +135,7 @@ func (g *GcasImpl) Get(ctx context.Context, hash Hash) ([]byte, error) { defer g.shardedLocker.RUnlock(hash) var nodeID string - err := g.db.QueryRowContext(ctx, "SELECT node_id FROM chunks WHERE hash = ?", hash[:]).Scan(&nodeID) + err := g.db.QueryRowContext(ctx, "SELECT node_id FROM chunks WHERE hash = ? AND is_data = 1", hash[:]).Scan(&nodeID) if err != nil { if err == sql.ErrNoRows { return nil, HashNotFoundError{} @@ -142,48 +146,181 @@ func (g *GcasImpl) Get(ctx context.Context, hash Hash) ([]byte, error) { g.nodesLock.RLock() cas, ok := g.nodes[nodeID] g.nodesLock.RUnlock() + + var primaryErr error if ok { - return cas.Get(ctx, hash) + data, err := cas.Get(ctx, hash) + if err == nil { + return data, nil + } + primaryErr = err + } else { + primaryErr = errors.New("node not connected") } - // if the chunk exists but the node is not connected, give a server error - return nil, errors.New("node not connected") + // attempt erasure coding recovery + data, err := g.ecRecover(ctx, hash) + if err == nil { + return data, nil + } + + return nil, primaryErr } -// List implements [CAS]. -func (g *GcasImpl) List(ctx context.Context) (<-chan Hash, error) { - visited := make(map[Hash]struct{}) - ch := make(chan Hash) - // the list of nodes might change while we are iterating over it. - // holding the lock while iterating could result in a deadlock if the channel is not drained. - // thus we copy the list of nodes first, accepting that the list might not be up to date. +// ecRecover attempts to reconstruct a data chunk from its erasure group. +// Returns an error if the chunk has no erasure group or recovery is not possible. +func (g *GcasImpl) ecRecover(ctx context.Context, hash Hash) ([]byte, error) { + type memberRow struct { + sliceIdx int + hashID Hash + size int + nodeID string + } + + // look up the erasure group for this chunk + var groupID int64 + var dataShards, pShards, shardSize int + err := g.db.QueryRowContext(ctx, ` + SELECT eg.id, eg.data_shards, eg.parity_shards, eg.shard_size + FROM erasure_group eg + JOIN erasure_group_member egm ON egm.erasure_group_id = eg.id + WHERE egm.hash_id = ?`, hash[:]).Scan(&groupID, &dataShards, &pShards, &shardSize) + if err != nil { + return nil, fmt.Errorf("not in any erasure group: %w", err) + } + + // load all members + rows, err := g.db.QueryContext(ctx, ` + SELECT egm.slice_idx, egm.hash_id, c.size, c.node_id + FROM erasure_group_member egm + JOIN chunks c ON c.hash = egm.hash_id + WHERE egm.erasure_group_id = ?`, groupID) + if err != nil { + return nil, err + } + defer rows.Close() + + members := make([]memberRow, 0, dataShards+pShards) + for rows.Next() { + var m memberRow + var hashBytes []byte + if err := rows.Scan(&m.sliceIdx, &hashBytes, &m.size, &m.nodeID); err != nil { + return nil, err + } + if len(hashBytes) != len(Hash{}) { + continue + } + copy(m.hashID[:], hashBytes) + members = append(members, m) + } + + // build the target slice index + targetSlice := -1 + for _, m := range members { + if m.hashID == hash { + targetSlice = m.sliceIdx + break + } + } + if targetSlice < 0 || targetSlice >= dataShards { + return nil, errors.New("target chunk is not a data shard") + } + + // fetch each shard; nil means unavailable + shards := make([][]byte, dataShards+pShards) g.nodesLock.RLock() - nodes := make([]CAS, 0, len(g.nodes)) - for _, node := range g.nodes { - nodes = append(nodes, node) + nodesCopy := make(map[string]CAS, len(g.nodes)) + for k, v := range g.nodes { + nodesCopy[k] = v } g.nodesLock.RUnlock() + present := 0 + for _, m := range members { + cas, ok := nodesCopy[m.nodeID] + if !ok { + continue + } + data, err := cas.Get(ctx, m.hashID) + if err != nil { + continue + } + // pad to shard_size so RS can operate on equal-length slices + padded := make([]byte, shardSize) + copy(padded, data) + shards[m.sliceIdx] = padded + present++ + } + + if present < dataShards { + return nil, fmt.Errorf("only %d/%d shards available for recovery", present, dataShards) + } + + enc, err := reedsolomon.New(dataShards, pShards) + if err != nil { + return nil, err + } + if err := enc.ReconstructData(shards); err != nil { + return nil, err + } + + // find original size of the target chunk + origSize := 0 + for _, m := range members { + if m.hashID == hash { + origSize = m.size + break + } + } + + result := shards[targetSlice][:origSize] + + // verify hash + if sha256.Sum256(result) != hash { + return nil, DataCorruptError{} + } + return result, nil +} + +// List implements [CAS]. +func (g *GcasImpl) List(ctx context.Context) (<-chan Hash, error) { + // select only data chunks (not parity) + rows, err := g.db.QueryContext(ctx, "SELECT hash FROM chunks WHERE is_data = 1") + if err != nil { + return nil, err + } + + ch := make(chan Hash) + go func() { defer close(ch) - for _, node := range nodes { - hashes, err := node.List(ctx) + defer rows.Close() + + for rows.Next() { + var h []byte + err := rows.Scan(&h) + + if len(h) != len(Hash{}) { + log.Printf("List: Invalid hash length %d (expected %d)", len(h), len(Hash{})) + continue + } + if err != nil { + log.Printf("List: Error scanning hash: %v", err) return } - for hash := range hashes { - if _, ok := visited[hash]; ok { - continue - } - visited[hash] = struct{}{} - select { - case ch <- hash: - case <-ctx.Done(): - return - } + + var hash Hash + copy(hash[:], h) + + select { + case ch <- hash: + case <-ctx.Done(): + return } } }() + return ch, nil } @@ -192,22 +329,27 @@ func (g *GcasImpl) Put(ctx context.Context, hash Hash, data []byte) error { g.shardedLocker.Lock(hash) defer g.shardedLocker.Unlock(hash) - // pick a random node to store the chunk - // note: golang internally randomizes the starting point of map iteration, - // however this is not guaranteed and not meant to be relied upon. - // check if the chunk already exists { var nodeID string - err := g.db.QueryRowContext(ctx, "SELECT node_id FROM chunks WHERE hash = ?", hash[:]).Scan(&nodeID) + err := g.db.QueryRowContext(ctx, "SELECT node_id FROM chunks WHERE hash = ? AND is_data = 1", hash[:]).Scan(&nodeID) if err != sql.ErrNoRows { if err != nil { return err } - - // if the chunk already exists, return HashExistsError return HashExistsError{} } + + // try to update is_data to 1 from 0 (from deleted) if the chunk exists + result, err := g.db.ExecContext(ctx, "UPDATE chunks SET is_data = 1 WHERE hash = ? AND is_data = 0", hash[:]) + if err != nil { + return err + } + + numRowsAffected, _ := result.RowsAffected() + if numRowsAffected != 0 { + return nil + } } type nodePair struct { @@ -234,7 +376,7 @@ func (g *GcasImpl) Put(ctx context.Context, hash Hash, data []byte) error { err := node.cas.Put(ctx, hash, data) - if err != nil { + if err != nil && !errors.Is(err, HashExistsError{}) { return err } @@ -242,4 +384,495 @@ func (g *GcasImpl) Put(ctx context.Context, hash Hash, data []byte) error { return err } +// RunGC runs the garbage collection process. +func (g *GcasImpl) RunGC(ctx context.Context) error { + // clean up erasure groups whose data chunks are all deleted + _, err := g.db.ExecContext(ctx, ` + DELETE FROM erasure_group + WHERE id NOT IN ( + SELECT DISTINCT egm.erasure_group_id + FROM erasure_group_member egm + JOIN chunks ON chunks.hash = egm.hash_id + WHERE chunks.is_data = 1 + )`) + if err != nil { + return err + } + + // remove members of deleted groups (sqlite has no cascade here) + _, err = g.db.ExecContext(ctx, ` + DELETE FROM erasure_group_member + WHERE erasure_group_id NOT IN (SELECT id FROM erasure_group)`) + if err != nil { + return err + } + + // remove all chunks that have been marked as deleted and are not used for parity + rows, err := g.db.QueryContext(ctx, "DELETE FROM chunks WHERE is_data = 0 AND NOT EXISTS (SELECT 1 FROM erasure_group_member WHERE hash_id = chunks.hash) RETURNING hash, node_id") + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var hash Hash + var nodeID string + if err := rows.Scan(hash[:], &nodeID); err != nil { + continue + } + + g.nodesLock.RLock() + cas, ok := g.nodes[nodeID] + g.nodesLock.RUnlock() + if !ok { + continue + } + if err := cas.Delete(ctx, hash); err != nil { + continue + } + } + return nil +} + +type chunkInfo struct { + hash Hash + size int + nodeID string +} + +// formStripes groups unstriped data chunks into erasure-coded stripes and stores parity. +// A stripe requires g.dataShards data chunks (each on a distinct node) plus 2 parity nodes. +func (g *GcasImpl) formStripes(ctx context.Context) error { + for { + // fetch unstriped data chunks ordered by hash for determinism + rows, err := g.db.QueryContext(ctx, ` + SELECT hash, size, node_id FROM chunks + WHERE is_data = 1 + AND hash NOT IN (SELECT hash_id FROM erasure_group_member) + ORDER BY hash`) + if err != nil { + return err + } + + // group by node_id, one chunk per node (no two data chunks on the same node in a stripe) + seen := make(map[string]bool) + var batch []chunkInfo + for rows.Next() { + var ci chunkInfo + var hashBytes []byte + if err := rows.Scan(&hashBytes, &ci.size, &ci.nodeID); err != nil { + rows.Close() + return err + } + if len(hashBytes) != len(Hash{}) { + continue + } + if seen[ci.nodeID] { + continue + } + seen[ci.nodeID] = true + copy(ci.hash[:], hashBytes) + batch = append(batch, ci) + if len(batch) == g.dataShards { + break + } + } + rows.Close() + + if len(batch) < g.dataShards { + return nil // not enough distinct-node chunks for a full stripe + } + + if err := g.encodeStripe(ctx, batch); err != nil { + log.Printf("formStripes: failed to encode stripe: %v", err) + return nil // non-fatal; try again next maintenance + } + } +} + +// encodeStripe computes parity for the given data chunks, stores parity on nodes, +// and records the erasure group in the DB. +func (g *GcasImpl) encodeStripe(ctx context.Context, dataChunks []chunkInfo) error { + k := len(dataChunks) + m := parityShards + + // snapshot nodes so we can pick parity destinations + g.nodesLock.RLock() + nodesCopy := make(map[string]CAS, len(g.nodes)) + for id, cas := range g.nodes { + nodesCopy[id] = cas + } + g.nodesLock.RUnlock() + + // read data for each chunk + shards := make([][]byte, k+m) + shardSize := 0 + for i, ci := range dataChunks { + cas, ok := nodesCopy[ci.nodeID] + if !ok { + return fmt.Errorf("node %s not connected", ci.nodeID) + } + data, err := cas.Get(ctx, ci.hash) + if err != nil { + return fmt.Errorf("read chunk %x: %w", ci.hash[:4], err) + } + shards[i] = data + if len(data) > shardSize { + shardSize = len(data) + } + } + + if shardSize == 0 { + shardSize = 1 // reedsolomon requires non-zero shard size + } + + // pad data shards to shardSize + for i := range dataChunks { + if len(shards[i]) < shardSize { + padded := make([]byte, shardSize) + copy(padded, shards[i]) + shards[i] = padded + } + } + // allocate parity shards + for i := k; i < k+m; i++ { + shards[i] = make([]byte, shardSize) + } + + enc, err := reedsolomon.New(k, m) + if err != nil { + return err + } + if err := enc.Encode(shards); err != nil { + return err + } + + // collect nodes already used by data chunks + usedNodes := make(map[string]bool, k) + for _, ci := range dataChunks { + usedNodes[ci.nodeID] = true + } + + // pick m distinct nodes not in usedNodes for parity + parityNodes := make([]struct { + id string + cas CAS + }, 0, m) + for id, cas := range nodesCopy { + if !usedNodes[id] { + parityNodes = append(parityNodes, struct { + id string + cas CAS + }{id, cas}) + usedNodes[id] = true + if len(parityNodes) == m { + break + } + } + } + if len(parityNodes) < m { + return fmt.Errorf("not enough distinct nodes for parity (%d available, need %d)", len(parityNodes), m) + } + + // store parity chunks + type parityRecord struct { + hash Hash + nodeID string + } + parityRecords := make([]parityRecord, m) + for i := 0; i < m; i++ { + ph := sha256.Sum256(shards[k+i]) + parityLock := ph + g.shardedLocker.Lock(parityLock) + + err := parityNodes[i].cas.Put(ctx, ph, shards[k+i]) + if err != nil && !errors.Is(err, HashExistsError{}) { + g.shardedLocker.Unlock(parityLock) + return fmt.Errorf("store parity shard %d: %w", i, err) + } + _, dbErr := g.db.ExecContext(ctx, + "INSERT OR IGNORE INTO chunks (hash, size, node_id, is_data) VALUES (?, ?, ?, 0)", + ph[:], shardSize, parityNodes[i].id) + g.shardedLocker.Unlock(parityLock) + if dbErr != nil { + return dbErr + } + parityRecords[i] = parityRecord{hash: ph, nodeID: parityNodes[i].id} + } + + // create erasure group record + result, err := g.db.ExecContext(ctx, + "INSERT INTO erasure_group (data_shards, parity_shards, shard_size) VALUES (?, ?, ?)", + k, m, shardSize) + if err != nil { + return err + } + groupID, err := result.LastInsertId() + if err != nil { + return err + } + + // insert members: data shards + for i, ci := range dataChunks { + _, err := g.db.ExecContext(ctx, + "INSERT INTO erasure_group_member (hash_id, erasure_group_id, slice_idx) VALUES (?, ?, ?)", + ci.hash[:], groupID, i) + if err != nil { + return err + } + } + // insert members: parity shards + for i, pr := range parityRecords { + _, err := g.db.ExecContext(ctx, + "INSERT INTO erasure_group_member (hash_id, erasure_group_id, slice_idx) VALUES (?, ?, ?)", + pr.hash[:], groupID, k+i) + if err != nil { + return err + } + } + + return nil +} + +// Repair implements [GCAS]. +// It scans all erasure groups for missing or corrupt shards and reconstructs them +// onto available nodes. +func (g *GcasImpl) Repair(ctx context.Context) error { + type groupRow struct { + id int64 + dataShards int + parityShards int + shardSize int + } + + groupRows, err := g.db.QueryContext(ctx, + "SELECT id, data_shards, parity_shards, shard_size FROM erasure_group") + if err != nil { + return err + } + + var groups []groupRow + for groupRows.Next() { + var gr groupRow + if err := groupRows.Scan(&gr.id, &gr.dataShards, &gr.parityShards, &gr.shardSize); err != nil { + groupRows.Close() + return err + } + groups = append(groups, gr) + } + groupRows.Close() + + for _, gr := range groups { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if err := g.repairGroup(ctx, gr.id, gr.dataShards, gr.parityShards, gr.shardSize); err != nil { + log.Printf("Repair: group %d: %v", gr.id, err) + } + } + return nil +} + +func (g *GcasImpl) repairGroup(ctx context.Context, groupID int64, dataShards, pShards, shardSize int) error { + type memberInfo struct { + sliceIdx int + hash Hash + size int + nodeID string + } + + rows, err := g.db.QueryContext(ctx, ` + SELECT egm.slice_idx, egm.hash_id, c.size, c.node_id + FROM erasure_group_member egm + JOIN chunks c ON c.hash = egm.hash_id + WHERE egm.erasure_group_id = ? + ORDER BY egm.slice_idx`, groupID) + if err != nil { + return err + } + defer rows.Close() + + members := make([]memberInfo, 0, dataShards+pShards) + for rows.Next() { + var mi memberInfo + var hashBytes []byte + if err := rows.Scan(&mi.sliceIdx, &hashBytes, &mi.size, &mi.nodeID); err != nil { + return err + } + if len(hashBytes) != len(Hash{}) { + continue + } + copy(mi.hash[:], hashBytes) + members = append(members, mi) + } + rows.Close() + + g.nodesLock.RLock() + nodesCopy := make(map[string]CAS, len(g.nodes)) + for id, cas := range g.nodes { + nodesCopy[id] = cas + } + g.nodesLock.RUnlock() + + // try to read each shard + total := dataShards + pShards + shards := make([][]byte, total) + broken := make([]bool, total) + + for _, mi := range members { + cas, ok := nodesCopy[mi.nodeID] + if !ok { + broken[mi.sliceIdx] = true + continue + } + data, err := cas.Get(ctx, mi.hash) + if err != nil { + broken[mi.sliceIdx] = true + continue + } + padded := make([]byte, shardSize) + copy(padded, data) + shards[mi.sliceIdx] = padded + } + + // count broken shards + brokenCount := 0 + for _, b := range broken { + if b { + brokenCount++ + } + } + if brokenCount == 0 { + return nil // all good + } + + present := total - brokenCount + if present < dataShards { + return fmt.Errorf("unrecoverable: only %d/%d shards present", present, dataShards) + } + + // allocate nil shards so reedsolomon knows which to reconstruct + for i, b := range broken { + if b { + shards[i] = nil + } + } + + enc, err := reedsolomon.New(dataShards, pShards) + if err != nil { + return err + } + if err := enc.Reconstruct(shards); err != nil { + return err + } + + // build set of nodes currently used in this stripe + usedNodes := make(map[string]bool, total) + for _, mi := range members { + if !broken[mi.sliceIdx] { + usedNodes[mi.nodeID] = true + } + } + + // store recovered shards + for _, mi := range members { + if !broken[mi.sliceIdx] { + continue + } + + shard := shards[mi.sliceIdx] + origSize := mi.size + if mi.sliceIdx >= dataShards { + // parity shard: store full padded size + origSize = shardSize + } + reconstructed := shard[:origSize] + + // find a node not already in the stripe + var targetID string + var targetCAS CAS + for id, cas := range nodesCopy { + if !usedNodes[id] { + targetID = id + targetCAS = cas + break + } + } + if targetCAS == nil { + // fall back: try to re-use the original node if it's now connected + if cas, ok := nodesCopy[mi.nodeID]; ok { + targetID = mi.nodeID + targetCAS = cas + } + } + if targetCAS == nil { + log.Printf("Repair: no available node for shard %d of group %d", mi.sliceIdx, groupID) + continue + } + + g.shardedLocker.Lock(mi.hash) + err := targetCAS.Put(ctx, mi.hash, reconstructed) + if err != nil && !errors.Is(err, HashExistsError{}) { + g.shardedLocker.Unlock(mi.hash) + log.Printf("Repair: put shard %d: %v", mi.sliceIdx, err) + continue + } + _, dbErr := g.db.ExecContext(ctx, + "UPDATE chunks SET node_id = ? WHERE hash = ?", + targetID, mi.hash[:]) + g.shardedLocker.Unlock(mi.hash) + if dbErr != nil { + return dbErr + } + + usedNodes[targetID] = true + } + + return nil +} + +// RunMaintenance does a one-off maintenance cycle. +func (g *GcasImpl) RunMaintenance(ctx context.Context) error { + lock := g.maintenanceLock.TryLock() + if !lock { + return fmt.Errorf("maintenance already running") + } + defer g.maintenanceLock.Unlock() + + if err := g.RunGC(ctx); err != nil { + log.Printf("error while running gc: %v", err) + } + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if err := g.formStripes(ctx); err != nil { + log.Printf("error while forming stripes: %v", err) + } + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if err := g.Repair(ctx); err != nil { + log.Printf("error while repairing: %v", err) + } + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + return nil +} + var _ GCAS = (*GcasImpl)(nil) diff --git a/server/gcas/gcas_test.go b/server/gcas/gcas_test.go index b1bb64e..e7d666b 100644 --- a/server/gcas/gcas_test.go +++ b/server/gcas/gcas_test.go @@ -331,9 +331,9 @@ func TestGCASList(t *testing.T) { } } -// TestGCASListDeduplication verifies that a hash present on multiple nodes is returned -// only once by GCAS.List. -func TestGCASListDeduplication(t *testing.T) { +// TestGCASInternalList verifies that GCAS uses its own database to look up hashes, +// and does not rely on the accuracy of the nodes' lists +func TestGCASInternalList(t *testing.T) { gcas, db, err := createTestGCAS(0) if err != nil { t.Fatal(err) @@ -345,8 +345,14 @@ func TestGCASListDeduplication(t *testing.T) { data := []byte("shared") hash := sha256.Sum256(data) - node0.DirectPut(hash, data) - node1.DirectPut(hash, data) + + // directly insert hash into nodes, should not be listed + if err := node0.Put(context.Background(), hash, data); err != nil { + t.Fatal(err) + } + if err := node1.Put(context.Background(), hash, data); err != nil { + t.Fatal(err) + } gcas.AddNode(node0) gcas.AddNode(node1) @@ -359,8 +365,25 @@ func TestGCASListDeduplication(t *testing.T) { for range ch { count++ } + if count != 0 { + t.Errorf("expected empty list, got %d elements", count) + } + + // now put hash through gcas and check that it's listed + if err := gcas.Put(context.Background(), hash, data); err != nil { + t.Fatal(err) + } + + ch, err = gcas.List(context.Background()) + if err != nil { + t.Fatal(err) + } + count = 0 + for range ch { + count++ + } if count != 1 { - t.Errorf("expected 1 deduplicated hash, got %d", count) + t.Errorf("expected 1 hash, got %d", count) } } @@ -544,39 +567,8 @@ func (d *deleteErrCAS) Delete(_ context.Context, _ Hash) error { return d.deleteErr } -// TestGCASDeleteNodeError verifies that when a connected node returns a non-HashNotFound -// error from Delete, GCAS propagates that error without modifying the database record. -func TestGCASDeleteNodeError(t *testing.T) { - gcas, db, err := createTestGCAS(0) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - inner := NewMockCAS("node0") - sentinelErr := errors.New("network failure") - errCAS := &deleteErrCAS{mockCAS: inner, deleteErr: sentinelErr} - gcas.AddNode(errCAS) - - data := []byte("hello") - hash := sha256.Sum256(data) - if err = gcas.Put(context.Background(), hash, data); err != nil { - t.Fatal(err) - } - - if err = gcas.Delete(context.Background(), hash); !errors.Is(err, sentinelErr) { - t.Errorf("expected sentinel error, got %v", err) - } - - // The DB record must still exist because the delete was aborted. - _, err = gcas.Get(context.Background(), hash) - if errors.Is(err, HashNotFoundError{}) { - t.Error("DB record was removed despite node delete failure") - } -} - -// TestGCASDeleteExecError verifies that a DB failure on the DELETE statement is propagated. -// It uses a SQLite BEFORE DELETE trigger to make the ExecContext call fail after the +// TestGCASDeleteExecError verifies that a DB failure on the UPDATE statement is propagated. +// It uses a SQLite BEFORE UPDATE trigger to make the ExecContext call fail after the // initial SELECT succeeds. func TestGCASDeleteExecError(t *testing.T) { gcas, db, err := createTestGCAS(1) @@ -591,7 +583,7 @@ func TestGCASDeleteExecError(t *testing.T) { t.Fatal(err) } - _, err = db.Exec(`CREATE TRIGGER prevent_delete BEFORE DELETE ON chunks BEGIN SELECT RAISE(ABORT, 'delete prevented'); END`) + _, err = db.Exec(`CREATE TRIGGER prevent_delete BEFORE UPDATE ON chunks BEGIN SELECT RAISE(ABORT, 'delete prevented'); END`) if err != nil { t.Fatal(err) } @@ -648,15 +640,68 @@ func TestGCASPutDBError(t *testing.T) { } } +func TestGCASRunMaintenance(t *testing.T) { + gcas, db, err := createTestGCAS(1) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + // put first chunk + // this chunk will be deleted later + data1 := []byte("hello") + hash1 := sha256.Sum256(data1) + if err = gcas.Put(context.Background(), hash1, data1); err != nil { + t.Fatal(err) + } + + // put second chunk + data2 := []byte("world") + hash2 := sha256.Sum256(data2) + if err = gcas.Put(context.Background(), hash2, data2); err != nil { + t.Fatal(err) + } + + // delete first chunk + if err = gcas.Delete(context.Background(), hash1); err != nil { + t.Fatal(err) + } + + // run maintenance. it will garbage collect the first chunk + if err = gcas.RunMaintenance(context.Background()); err != nil { + t.Fatal(err) + } + + // try to get the first chunk. it should fail + _, err = gcas.Get(context.Background(), hash1) + if !errors.Is(err, HashNotFoundError{}) { + t.Errorf("expected HashNotFoundError after GC, got %v", err) + } + + // get the second chunk. it should not fail + dataRetreived, err := gcas.Get(context.Background(), hash2) + if err != nil { + t.Errorf("expected success after GC, got %v", err) + } + + if !bytes.Equal(dataRetreived, data2) { + t.Errorf("expected data %v after GC, got %v", data2, dataRetreived) + } +} + func createTestGCAS(numNodes int) (GCAS, *sql.DB, error) { - db, err := OpenDB(":memory:") - gcas := NewGCAS(db) + return createTestGCASWithDataShards(numNodes, defaultDataShards) +} +func createTestGCASWithDataShards(numNodes, dataShards int) (GCAS, *sql.DB, error) { + db, err := OpenDB(":memory:") if err != nil { return nil, nil, err } - nodes := make([]NamedCAS, numNodes) + gcas := NewGCASWithDataShards(db, dataShards) + + nodes := make([]*mockCAS, numNodes) for i := 0; i < numNodes; i++ { nodes[i] = NewMockCAS(fmt.Sprintf("node%d", i)) } @@ -667,3 +712,249 @@ func createTestGCAS(numNodes int) (GCAS, *sql.DB, error) { return gcas, db, nil } + +// testPutToNode directly places a chunk on a specific node, bypassing Put's random +// assignment. Used by EC tests to guarantee deterministic stripe layout. +func testPutToNode(t *testing.T, db *sql.DB, nodes map[string]*mockCAS, nodeID string, hash Hash, data []byte) { + t.Helper() + if err := nodes[nodeID].Put(context.Background(), hash, data); err != nil && !errors.Is(err, HashExistsError{}) { + t.Fatalf("testPutToNode %s: %v", nodeID, err) + } + if _, err := db.Exec("INSERT OR IGNORE INTO chunks (hash, size, node_id) VALUES (?, ?, ?)", hash[:], len(data), nodeID); err != nil { + t.Fatalf("testPutToNode DB insert: %v", err) + } +} + +// testSetupStripe places k chunks on nodes 0..k-1 deterministically and runs +// maintenance to form a stripe. Returns the k data hashes. +func testSetupStripe(t *testing.T, gcas GCAS, db *sql.DB, nodes map[string]*mockCAS, k int) []Hash { + t.Helper() + hashes := make([]Hash, k) + for i := 0; i < k; i++ { + data := []byte(fmt.Sprintf("stripe-data-%d", i)) + h := sha256.Sum256(data) + hashes[i] = h + testPutToNode(t, db, nodes, fmt.Sprintf("node%d", i), h, data) + } + if err := gcas.RunMaintenance(context.Background()); err != nil { + t.Fatalf("RunMaintenance: %v", err) + } + return hashes +} + +// TestGCASErasureCoding verifies that maintenance forms an erasure group when +// enough distinct-node chunks exist. +func TestGCASErasureCoding(t *testing.T) { + const k = 2 + gcas, db, nodes := createTestGCASWithNodes(t, k+parityShards, k) + defer db.Close() + + testSetupStripe(t, gcas, db, nodes, k) + + var groupCount int + if err := db.QueryRow("SELECT COUNT(*) FROM erasure_group").Scan(&groupCount); err != nil { + t.Fatal(err) + } + if groupCount != 1 { + t.Errorf("expected 1 erasure group, got %d", groupCount) + } + + var memberCount int + if err := db.QueryRow("SELECT COUNT(*) FROM erasure_group_member").Scan(&memberCount); err != nil { + t.Fatal(err) + } + if memberCount != k+parityShards { + t.Errorf("expected %d erasure group members, got %d", k+parityShards, memberCount) + } +} + +// TestGCASStripeNodeConstraint verifies that no two stripe members share a node. +func TestGCASStripeNodeConstraint(t *testing.T) { + const k = 2 + gcas, db, nodes := createTestGCASWithNodes(t, k+parityShards, k) + defer db.Close() + + testSetupStripe(t, gcas, db, nodes, k) + + rows, err := db.Query(` + SELECT c.node_id FROM erasure_group_member egm + JOIN chunks c ON c.hash = egm.hash_id + WHERE egm.erasure_group_id = 1`) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + seen := make(map[string]bool) + for rows.Next() { + var nodeID string + if err := rows.Scan(&nodeID); err != nil { + t.Fatal(err) + } + if seen[nodeID] { + t.Errorf("node %s appears more than once in the stripe", nodeID) + } + seen[nodeID] = true + } +} + +// TestGCASGetNodeFailure verifies that Get succeeds via EC recovery when a +// single node holding a data chunk is removed. +func TestGCASGetNodeFailure(t *testing.T) { + const k = 2 + gcas, db, nodes := createTestGCASWithNodes(t, k+parityShards, k) + defer db.Close() + + hashes := testSetupStripe(t, gcas, db, nodes, k) + + // hashes[0] is on node0 (deterministic placement) + gcas.RemoveNode("node0") + + // Get should recover via EC + data, err := gcas.Get(context.Background(), hashes[0]) + if err != nil { + t.Errorf("expected EC recovery to succeed, got: %v", err) + } + expected := []byte("stripe-data-0") + if string(data) != string(expected) { + t.Errorf("recovered data mismatch: got %q, want %q", data, expected) + } +} + +// TestGCASGetTwoNodeFailure verifies EC recovery with 2 nodes down (maximum +// tolerable for 2 parity shards). +func TestGCASGetTwoNodeFailure(t *testing.T) { + const k = 2 + gcas, db, nodes := createTestGCASWithNodes(t, k+parityShards, k) + defer db.Close() + + hashes := testSetupStripe(t, gcas, db, nodes, k) + + // Remove both data nodes (node0 and node1); only the 2 parity nodes survive + gcas.RemoveNode("node0") + gcas.RemoveNode("node1") + + // With k=2 data shards and 2 parity shards, losing 2 shards is still recoverable + data, err := gcas.Get(context.Background(), hashes[0]) + if err != nil { + t.Errorf("expected EC recovery with 2 node failures, got: %v", err) + } + expected := []byte("stripe-data-0") + if string(data) != string(expected) { + t.Errorf("recovered data mismatch: got %q, want %q", data, expected) + } +} + +// TestGCASGetUnrecoverableFailure verifies that Get fails when more nodes are +// down than the parity count allows. +func TestGCASGetUnrecoverableFailure(t *testing.T) { + const k = 2 + gcas, db, nodes := createTestGCASWithNodes(t, k+parityShards, k) + defer db.Close() + + hashes := testSetupStripe(t, gcas, db, nodes, k) + + // Remove all 4 stripe nodes — 0 shards survive, need k=2 for recovery + rows, err := db.Query(` + SELECT DISTINCT c.node_id FROM erasure_group_member egm + JOIN chunks c ON c.hash = egm.hash_id + WHERE egm.erasure_group_id = 1`) + if err != nil { + t.Fatal(err) + } + var toRemove []string + for rows.Next() { + var n string + rows.Scan(&n) + toRemove = append(toRemove, n) + } + rows.Close() + for _, n := range toRemove { + gcas.RemoveNode(n) + } + + _, err = gcas.Get(context.Background(), hashes[0]) + if err == nil { + t.Error("expected Get to fail with all nodes down, got nil") + } +} + +// TestGCASRepairAndGet removes a node, runs Repair, and verifies Get succeeds +// without the original node. +func TestGCASRepairAndGet(t *testing.T) { + const k = 2 + // k+parityShards nodes for the stripe + 1 spare so Repair has a node to place recovered shard + gcas, db, nodes := createTestGCASWithNodes(t, k+parityShards+1, k) + defer db.Close() + + hashes := testSetupStripe(t, gcas, db, nodes, k) + + // hashes[0] is on node0 (deterministic placement); remove it + gcas.RemoveNode("node0") + + // Before repair: Get should fail (primary node gone, EC recovery still works but + // after repair the shard is placed on the spare node and Get uses the direct path) + // Run repair to restore the shard to the spare node + if err := gcas.Repair(context.Background()); err != nil { + t.Fatalf("Repair: %v", err) + } + + // After repair, Get should succeed even without node0 + data, err := gcas.Get(context.Background(), hashes[0]) + if err != nil { + t.Errorf("expected Get to succeed after Repair, got: %v", err) + } + expected := []byte("stripe-data-0") + if string(data) != string(expected) { + t.Errorf("data mismatch after repair: got %q, want %q", data, expected) + } +} + +// TestGCASRepairCorruptData verifies that Repair restores a shard whose data +// has been corrupted on the node. +func TestGCASRepairCorruptData(t *testing.T) { + const k = 2 + // +1 spare node so Repair can place the recovered shard somewhere other than the corrupt node + gcas, db, nodes := createTestGCASWithNodes(t, k+parityShards+1, k) + defer db.Close() + + hashes := testSetupStripe(t, gcas, db, nodes, k) + + // Corrupt data for hashes[0] on node0 (deterministic placement) + nodes["node0"].CorruptData(hashes[0]) + + // Repair should reconstruct hashes[0] onto the spare node + if err := gcas.Repair(context.Background()); err != nil { + t.Fatalf("Repair: %v", err) + } + + // After repair, Get should return correct data + data, err := gcas.Get(context.Background(), hashes[0]) + if err != nil { + t.Errorf("expected Get to succeed after Repair, got: %v", err) + } + expected := []byte("stripe-data-0") + if string(data) != string(expected) { + t.Errorf("data mismatch after repair: got %q, want %q", data, expected) + } +} + +// createTestGCASWithNodes is like createTestGCASWithDataShards but returns the +// node map so tests can corrupt or inspect individual nodes. +func createTestGCASWithNodes(t *testing.T, numNodes, dataShards int) (GCAS, *sql.DB, map[string]*mockCAS) { + t.Helper() + db, err := OpenDB(":memory:") + if err != nil { + t.Fatal(err) + } + + gcas := NewGCASWithDataShards(db, dataShards) + nodeMap := make(map[string]*mockCAS, numNodes) + for i := 0; i < numNodes; i++ { + name := fmt.Sprintf("node%d", i) + node := NewMockCAS(name) + nodeMap[name] = node + gcas.AddNode(node) + } + return gcas, db, nodeMap +} diff --git a/server/gcas/migrations/2_erasure_coding.down.sql b/server/gcas/migrations/2_erasure_coding.down.sql new file mode 100644 index 0000000..2dd15db --- /dev/null +++ b/server/gcas/migrations/2_erasure_coding.down.sql @@ -0,0 +1,9 @@ +-- drop erasure coding tables +DROP TABLE erasure_group_member; +DROP TABLE erasure_group; + +-- remove chunks that are not data +DELETE FROM chunks WHERE is_data = 0; + +-- remove is_data column from chunks +ALTER TABLE chunks DROP COLUMN is_data; \ No newline at end of file diff --git a/server/gcas/migrations/2_erasure_coding.up.sql b/server/gcas/migrations/2_erasure_coding.up.sql new file mode 100644 index 0000000..6e41a73 --- /dev/null +++ b/server/gcas/migrations/2_erasure_coding.up.sql @@ -0,0 +1,20 @@ +-- add column to chunks to indicate if the chunk is part of the data +ALTER TABLE chunks +ADD COLUMN is_data BOOLEAN DEFAULT 1; + +-- table of erasure coding groups +CREATE TABLE erasure_group ( + id INTEGER PRIMARY KEY, + data_shards INTEGER NOT NULL, + parity_shards INTEGER NOT NULL DEFAULT 2, + shard_size INTEGER NOT NULL -- max chunk size in stripe (bytes), used for padding +); + +-- map data and parity chunks to their erasure group +-- slice_idx: 0..data_shards-1 = data chunks, data_shards..data_shards+parity_shards-1 = parity +CREATE TABLE erasure_group_member ( + hash_id BLOB(32) PRIMARY KEY, + erasure_group_id INTEGER NOT NULL, + slice_idx INTEGER NOT NULL, + FOREIGN KEY (erasure_group_id) REFERENCES erasure_group(id) +); \ No newline at end of file