Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 200 additions & 74 deletions src/sharding/router/shard-router.service.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Injectable, Logger } from '@nestjs/common';
import { Injectable, Logger, OnModuleDestroy } from '@nestjs/common';
import { createHash } from 'crypto';
import {
ShardConfig,
Expand All @@ -8,6 +8,67 @@ import {
} from '../interfaces/shard.interface';
import { ShardConfigService } from '../shard-config.service';

type RangeBucket = { min: number; max: number; shardId: string };

interface RoutingSnapshot {
ring: ShardNode[];
rangeBuckets: RangeBucket[];
shardsById: Map<string, ShardConfig>;
}

class AsyncReadWriteLock {
private activeReaders = 0;
private activeWriter = false;
private waitingWriters: Array<() => void> = [];

async write<T>(operation: () => T | Promise<T>): Promise<T> {
await this.acquireWrite();
try {
return await operation();
} finally {
this.releaseWrite();
}
}

read<T>(operation: () => T): T {
this.activeReaders++;
try {
return operation();
} finally {
this.releaseRead();
}
}

private async acquireWrite(): Promise<void> {
if (!this.activeWriter && this.activeReaders === 0) {
this.activeWriter = true;
return;
}

await new Promise<void>((resolve) => this.waitingWriters.push(resolve));
this.activeWriter = true;
}

private releaseRead(): void {
this.activeReaders--;
this.drainWriters();
}

private releaseWrite(): void {
this.activeWriter = false;
this.drainWriters();
}

private drainWriters(): void {
if (this.activeWriter || this.activeReaders > 0 || this.waitingWriters.length === 0) {
return;
}

const nextWriter = this.waitingWriters.shift();
nextWriter?.();
}
}

/**
* ShardRouter
*
Expand All @@ -22,21 +83,31 @@ import { ShardConfigService } from '../shard-config.service';
* ensure an even key distribution even with a small shard count.
*/
@Injectable()
export class ShardRouter {
export class ShardRouter implements OnModuleDestroy {
private readonly logger = new Logger(ShardRouter.name);
private readonly VIRTUAL_NODES_PER_SHARD = 150;
private readonly MAX_UINT32 = 0xffffffff;
private readonly routingLock = new AsyncReadWriteLock();
private unsubscribeConfigUpdates?: () => void;

/** Sorted list of virtual-node → shard mappings for consistent hashing */
private ring: ShardNode[] = [];

/** Range buckets: [min, max) → shardId */
private rangeBuckets: Array<{ min: number; max: number; shardId: string }> = [];
/** Current immutable routing view used by route() calls */
private routingSnapshot: RoutingSnapshot = {
ring: [],
rangeBuckets: [],
shardsById: new Map(),
};

constructor(private readonly shardConfigService: ShardConfigService) {
this.unsubscribeConfigUpdates = this.shardConfigService.onConfigUpdated(() =>
this.reloadConfig(),
);
this.rebuildRing();
}

onModuleDestroy(): void {
this.unsubscribeConfigUpdates?.();
}

// ---------------------------------------------------------------------------
// Public routing API
// ---------------------------------------------------------------------------
Expand All @@ -54,48 +125,51 @@ export class ShardRouter {
): ShardRoutingResult {
const start = Date.now();

let shard: ShardConfig;

switch (strategy) {
case ShardStrategy.TENANT_BASED:
shard = this.routeByTenant(key);
break;
case ShardStrategy.RANGE_BASED:
shard = this.routeByRange(key);
break;
case ShardStrategy.HASH_BASED:
default:
shard = this.routeByHash(key);
break;
}
return this.routingLock.read(() => {
const snapshot = this.routingSnapshot;
let shard: ShardConfig;

switch (strategy) {
case ShardStrategy.TENANT_BASED:
shard = this.routeByTenant(snapshot, key);
break;
case ShardStrategy.RANGE_BASED:
shard = this.routeByRange(snapshot, key);
break;
case ShardStrategy.HASH_BASED:
default:
shard = this.routeByHash(snapshot, key);
break;
}

const isReplica = false;
if (forRead && shard.readReplicas?.length) {
// Pick a replica using weighted random selection
const replica = this.pickWeightedReplica(shard);
if (replica) {
// Return a synthetic ShardConfig representing the replica
const replicaShard: ShardConfig = {
...shard,
id: replica.id,
host: replica.host,
port: replica.port,
};
return {
shard: replicaShard,
isReplica: true,
routingKey: key,
resolutionTimeMs: Date.now() - start,
};
const isReplica = false;
if (forRead && shard.readReplicas?.length) {
// Pick a replica using weighted random selection
const replica = this.pickWeightedReplica(shard);
if (replica) {
// Return a synthetic ShardConfig representing the replica
const replicaShard: ShardConfig = {
...shard,
id: replica.id,
host: replica.host,
port: replica.port,
};
return {
shard: replicaShard,
isReplica: true,
routingKey: key,
resolutionTimeMs: Date.now() - start,
};
}
}
}

return {
shard,
isReplica,
routingKey: key,
resolutionTimeMs: Date.now() - start,
};
return {
shard,
isReplica,
routingKey: key,
resolutionTimeMs: Date.now() - start,
};
});
}

/**
Expand All @@ -104,15 +178,63 @@ export class ShardRouter {
*/
rebuildRing(): void {
const activeShards = this.shardConfigService.getActiveShards();
if (activeShards.length === 0) {
const nextSnapshot = this.buildRoutingSnapshot(activeShards, this.routingSnapshot.rangeBuckets);
this.routingSnapshot = nextSnapshot;

if (nextSnapshot.ring.length === 0) {
this.logger.warn('No active shards available — consistent-hash ring is empty');
this.ring = [];
return;
}

this.logger.log(
`Consistent-hash ring rebuilt with ${nextSnapshot.ring.length} virtual nodes ` +
`across ${activeShards.length} active shard(s)`,
);
}

/**
* Reload shard configuration and atomically publish a new routing snapshot.
*/
async reloadConfig(): Promise<void> {
await this.routingLock.write(async () => {
this.shardConfigService.reloadConfig();
const activeShards = this.shardConfigService.getActiveShards();
const nextSnapshot = this.buildRoutingSnapshot(
activeShards,
this.routingSnapshot.rangeBuckets,
);
this.routingSnapshot = nextSnapshot;

if (nextSnapshot.ring.length === 0) {
this.logger.warn('Shard config reload produced an empty consistent-hash ring');
return;
}

this.logger.log(
`Shard config reloaded; ring now has ${nextSnapshot.ring.length} virtual nodes ` +
`across ${activeShards.length} active shard(s)`,
);
});
}

private buildRoutingSnapshot(
activeShards: ShardConfig[],
rangeBuckets: RangeBucket[],
): RoutingSnapshot {
if (activeShards.length === 0) {
return {
ring: [],
rangeBuckets: [...rangeBuckets],
shardsById: new Map(),
};
}

const nodes: ShardNode[] = [];
const shardsById = new Map<string, ShardConfig>();

for (const shard of activeShards) {
shardsById.set(shard.id, shard);

// Scale virtual-node count by weight (100 = default)
const vnodeCount = Math.round((this.VIRTUAL_NODES_PER_SHARD * shard.weight) / 100);

Expand All @@ -124,72 +246,76 @@ export class ShardRouter {

// Sort ascending by virtual-node position
nodes.sort((a, b) => a.virtualNode - b.virtualNode);
this.ring = nodes;

this.logger.log(
`Consistent-hash ring rebuilt with ${this.ring.length} virtual nodes ` +
`across ${activeShards.length} active shard(s)`,
);
return {
ring: nodes,
rangeBuckets: [...rangeBuckets],
shardsById,
};
}

/**
* Configure range buckets for RANGE_BASED routing.
* @param buckets Ordered, non-overlapping range definitions
*/
setRangeBuckets(buckets: Array<{ min: number; max: number; shardId: string }>): void {
this.rangeBuckets = [...buckets].sort((a, b) => a.min - b.min);
this.logger.log(`Range buckets configured: ${JSON.stringify(this.rangeBuckets)}`);
setRangeBuckets(buckets: RangeBucket[]): void {
const rangeBuckets = [...buckets].sort((a, b) => a.min - b.min);
this.routingSnapshot = {
...this.routingSnapshot,
rangeBuckets,
};
this.logger.log(`Range buckets configured: ${JSON.stringify(rangeBuckets)}`);
}

// ---------------------------------------------------------------------------
// Strategy implementations
// ---------------------------------------------------------------------------

private routeByHash(key: string): ShardConfig {
if (this.ring.length === 0) {
private routeByHash(snapshot: RoutingSnapshot, key: string): ShardConfig {
if (snapshot.ring.length === 0) {
throw new Error('ShardRouter: consistent-hash ring is empty — no active shards');
}

const keyHash = this.hash32(key);
const idx = this.findRingPosition(keyHash);
const shardId = this.ring[idx].shardId;
const idx = this.findRingPosition(snapshot.ring, keyHash);
const shardId = snapshot.ring[idx].shardId;

const shard = this.shardConfigService.getShardById(shardId);
const shard = snapshot.shardsById.get(shardId);
if (!shard) {
throw new Error(`ShardRouter: shard "${shardId}" not found in configuration`);
}
return shard;
}

private routeByTenant(tenantKey: string): ShardConfig {
private routeByTenant(snapshot: RoutingSnapshot, tenantKey: string): ShardConfig {
// Tenant keys are expected in the form "tenant:<tenantId>:<entityKey>" or just a tenantId.
// We normalise by stripping the prefix and hashing the tenant segment only
// so that all data for a given tenant always lands on the same shard.
const tenantId = tenantKey.replace(/^tenant:/, '').split(':')[0];
return this.routeByHash(`tenant:${tenantId}`);
return this.routeByHash(snapshot, `tenant:${tenantId}`);
}

private routeByRange(key: string): ShardConfig {
private routeByRange(snapshot: RoutingSnapshot, key: string): ShardConfig {
const numeric = parseInt(key, 10);
if (isNaN(numeric)) {
this.logger.warn(`RANGE_BASED routing: non-numeric key "${key}" — falling back to hash`);
return this.routeByHash(key);
return this.routeByHash(snapshot, key);
}

if (this.rangeBuckets.length === 0) {
if (snapshot.rangeBuckets.length === 0) {
this.logger.warn('RANGE_BASED routing: no range buckets configured — falling back to hash');
return this.routeByHash(key);
return this.routeByHash(snapshot, key);
}

const bucket = this.rangeBuckets.find((b) => numeric >= b.min && numeric < b.max);
const bucket = snapshot.rangeBuckets.find((b) => numeric >= b.min && numeric < b.max);
if (!bucket) {
this.logger.warn(
`RANGE_BASED routing: key ${numeric} falls outside all buckets — falling back to hash`,
);
return this.routeByHash(key);
return this.routeByHash(snapshot, key);
}

const shard = this.shardConfigService.getShardById(bucket.shardId);
const shard = snapshot.shardsById.get(bucket.shardId);
if (!shard) {
throw new Error(`ShardRouter: range bucket points to unknown shard "${bucket.shardId}"`);
}
Expand All @@ -204,20 +330,20 @@ export class ShardRouter {
* Binary search for the first virtual-node at or after `hash`.
* Wraps around to index 0 when hash exceeds the last virtual-node.
*/
private findRingPosition(hash: number): number {
private findRingPosition(ring: ShardNode[], hash: number): number {
let lo = 0;
let hi = this.ring.length - 1;
let hi = ring.length - 1;

while (lo <= hi) {
const mid = (lo + hi) >>> 1;
if (this.ring[mid].virtualNode < hash) {
if (ring[mid].virtualNode < hash) {
lo = mid + 1;
} else {
hi = mid - 1;
}
}

return lo % this.ring.length; // wrap around
return lo % ring.length; // wrap around
}

/** Deterministic 32-bit FNV-1a-style hash via Node's crypto module */
Expand Down
Loading
Loading