diff --git a/ENVIRONMENT.md b/ENVIRONMENT.md index 0dad97ed..cac8bd5f 100644 --- a/ENVIRONMENT.md +++ b/ENVIRONMENT.md @@ -16,6 +16,16 @@ If 1, always prints the Homeserver container logs even on success. When used wit This allows you to override the base image used for a particular named homeserver. For example, `COMPLEMENT_BASE_IMAGE_HS1=complement-dendrite:latest` would use `complement-dendrite:latest` for the `hs1` homeserver in blueprints, but not any other homeserver (e.g `hs2`). This matching is case-insensitive. This allows Complement to test how different homeserver implementations work with each other. - Type: `map[string]string` +#### `COMPLEMENT_CONTAINER_CPU_CORES` +The number of CPU cores available for the container to use (can be fractional like 0.5). This is passed to Docker as the `--cpus`/`NanoCPUs` argument. If 0, no limit is set and the container can use all available host CPUs. This is useful to mimic a resource-constrained environment, like a CI environment. +- Type: `float64` +- Default: 0 + +#### `COMPLEMENT_CONTAINER_MEMORY` +The maximum amount of memory the container can use (ex. "1GB"). Valid units are "B", (decimal: "KB", "MB", "GB", "TB"), (binary: "KiB", "MiB", "GiB", "TiB") or no units (bytes) (case-insensitive). We also support "K", "M", "G" as per Docker's CLI. The number of bytes is passed to Docker as the `--memory`/`Memory` argument. If 0, no limit is set and the container can use all available host memory. This is useful to mimic a resource-constrained environment, like a CI environment. +- Type: `int64` +- Default: 0 + #### `COMPLEMENT_DEBUG` If 1, prints out more verbose logging such as HTTP request/response bodies. - Type: `bool` diff --git a/README.md b/README.md index 6a0b2cf5..4e3797f9 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ To solve this, you will need to configure your firewall to allow such requests. 
If you are using [ufw](https://code.launchpad.net/ufw), this can be done with: ```sh -sudo ufw allow in on br-+ +sudo ufw allow in on br-+ comment "(from Matrix Complement testing) Allow traffic from custom Docker networks to the host machine (host.docker.internal)" ``` ### Running using Podman @@ -96,6 +96,7 @@ If you're looking to run against a custom Dockerfile, it must meet the following - The Dockerfile must `EXPOSE 8008` and `EXPOSE 8448` for client and federation traffic respectively. - The homeserver should run and listen on these ports. +- The homeserver should listen on plain HTTP for client traffic and HTTPS for federation traffic. See [Complement PKI](#Complement-PKI) below. - The homeserver should become healthy within `COMPLEMENT_SPAWN_HS_TIMEOUT_SECS` if a `HEALTHCHECK` is specified in the Dockerfile. - The homeserver needs to `200 OK` requests to `GET /_matrix/client/versions`. - The homeserver needs to manage its own storage within the image. diff --git a/cmd/gendoc/main.go b/cmd/gendoc/main.go index eaa6938a..a55b5c30 100644 --- a/cmd/gendoc/main.go +++ b/cmd/gendoc/main.go @@ -1,3 +1,5 @@ +// Usage: `go run ./cmd/gendoc --config config/config.go > ENVIRONMENT.md` + package main import ( diff --git a/cmd/homerunner/README.md b/cmd/homerunner/README.md index f0f19136..1635b3b6 100644 --- a/cmd/homerunner/README.md +++ b/cmd/homerunner/README.md @@ -28,7 +28,7 @@ HOMERUNNER_KEEP_BLUEPRINTS='name-of-blueprint' ./homerunner ``` This is neccessary to stop Homerunner from cleaning up the image. 
Then perform a single POST request: ``` -curl -XPOST -d '{"blueprint_name":"name-of-blueprint"}' +curl -XPOST -d '{"blueprint_name":"name-of-blueprint"}' http://localhost:54321/create { "homeservers":{ "hs1":{ diff --git a/cmd/homerunner/routes.go b/cmd/homerunner/routes.go index 616524fe..38b6e44b 100644 --- a/cmd/homerunner/routes.go +++ b/cmd/homerunner/routes.go @@ -10,8 +10,8 @@ import ( func Routes(rt *Runtime, cfg *Config) http.Handler { mux := mux.NewRouter() - mux.Path("/create").Methods("POST").HandlerFunc( - util.WithCORSOptions(util.MakeJSONAPI(util.NewJSONRequestHandler( + mux.Path("/create").Methods("POST", "OPTIONS").HandlerFunc( + withCORS(util.MakeJSONAPI(util.NewJSONRequestHandler( func(req *http.Request) util.JSONResponse { rc := ReqCreate{} if err := json.NewDecoder(req.Body).Decode(&rc); err != nil { @@ -21,8 +21,8 @@ func Routes(rt *Runtime, cfg *Config) http.Handler { }, ))), ) - mux.Path("/destroy").Methods("POST").HandlerFunc( - util.WithCORSOptions(util.MakeJSONAPI(util.NewJSONRequestHandler( + mux.Path("/destroy").Methods("POST", "OPTIONS").HandlerFunc( + withCORS(util.MakeJSONAPI(util.NewJSONRequestHandler( func(req *http.Request) util.JSONResponse { rc := ReqDestroy{} if err := json.NewDecoder(req.Body).Decode(&rc); err != nil { @@ -32,10 +32,18 @@ func Routes(rt *Runtime, cfg *Config) http.Handler { }, ))), ) - mux.Path("/health").Methods("GET").HandlerFunc( - func(res http.ResponseWriter, req *http.Request) { + mux.Path("/health").Methods("GET", "OPTIONS").HandlerFunc( + withCORS(func(res http.ResponseWriter, req *http.Request) { res.WriteHeader(200) - }, + }), ) return mux } + +// withCORS intercepts all requests and adds CORS headers. 
+func withCORS(handler http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + util.SetCORSHeaders(w) + handler(w, req) + } +} diff --git a/config/config.go b/config/config.go index 5260a62e..9394b37a 100644 --- a/config/config.go +++ b/config/config.go @@ -11,6 +11,7 @@ import ( "math/big" "os" "regexp" + "sort" "strconv" "strings" "time" ) @@ -52,6 +53,23 @@ type Complement struct { // starting the container. Responsiveness is detected by `HEALTHCHECK` being healthy *and* // the `/versions` endpoint returning 200 OK. SpawnHSTimeout time.Duration + // Name: COMPLEMENT_CONTAINER_CPU_CORES + // Default: 0 + // Description: The number of CPU cores available for the container to use (can be + // fractional like 0.5). This is passed to Docker as the `--cpus`/`NanoCPUs` argument. + // If 0, no limit is set and the container can use all available host CPUs. This is + // useful to mimic a resource-constrained environment, like a CI environment. + ContainerCPUCores float64 + // Name: COMPLEMENT_CONTAINER_MEMORY + // Default: 0 + // Description: The maximum amount of memory the container can use (ex. "1GB"). Valid + // units are "B", (decimal: "KB", "MB", "GB", "TB"), (binary: "KiB", "MiB", "GiB", + // "TiB") or no units (bytes) (case-insensitive). We also support "K", "M", "G" + // as per Docker's CLI. The number of bytes is passed to Docker as the + // `--memory`/`Memory` argument. If 0, no limit is set and the container can use all + // available host memory. This is useful to mimic a resource-constrained environment, + // like a CI environment. + ContainerMemoryBytes int64 // Name: COMPLEMENT_KEEP_BLUEPRINTS // Description: A list of space separated blueprint names to not clean up after running. 
For example, // `one_to_one_room alice` would not delete the homeserver images for the blueprints `alice` and @@ -145,8 +163,13 @@ func NewConfigFromEnvVars(pkgNamespace, baseImageURI string) *Complement { // each iteration had a 50ms sleep between tries so the timeout is 50 * iteration ms cfg.SpawnHSTimeout = time.Duration(50*parseEnvWithDefault("COMPLEMENT_VERSION_CHECK_ITERATIONS", 100)) * time.Millisecond } + cfg.ContainerCPUCores = parseEnvAsFloatWithDefault("COMPLEMENT_CONTAINER_CPU_CORES", 0) + parsedMemoryBytes, err := parseByteSizeString(os.Getenv("COMPLEMENT_CONTAINER_MEMORY")) + if err != nil { + panic("COMPLEMENT_CONTAINER_MEMORY parse error: " + err.Error()) + } + cfg.ContainerMemoryBytes = parsedMemoryBytes cfg.KeepBlueprints = strings.Split(os.Getenv("COMPLEMENT_KEEP_BLUEPRINTS"), " ") - var err error hostMounts := os.Getenv("COMPLEMENT_HOST_MOUNTS") if hostMounts != "" { cfg.HostMounts, err = newHostMounts(strings.Split(hostMounts, ";")) @@ -214,17 +237,132 @@ func (c *Complement) CAPrivateKeyBytes() ([]byte, error) { return caKey.Bytes(), err } -func parseEnvWithDefault(key string, def int) int { - s := os.Getenv(key) - if s != "" { - i, err := strconv.Atoi(s) - if err != nil { - // Don't bother trying to report it - return def +func parseEnvWithDefault(key string, defaultValue int) int { + inputString := os.Getenv(key) + if inputString == "" { + return defaultValue + } + + parsedNumber, err := strconv.Atoi(inputString) + if err != nil { + panic(key + " parse error: " + err.Error()) + } + return parsedNumber +} + +func parseEnvAsFloatWithDefault(key string, defaultValue float64) float64 { + inputString := os.Getenv(key) + if inputString == "" { + return defaultValue + } + + parsedNumber, err := strconv.ParseFloat(inputString, 64) + if err != nil { + panic(key + " parse error: " + err.Error()) + } + return parsedNumber +} + +// parseByteSizeString parses a byte size string (case insensitive) like "512MB" +// or "2GB" into bytes. 
If the string is empty, 0 is returned. Returns an error if the +// string does not match one of the valid units or is an invalid integer. +// +// Valid units are "B", (decimal: "KB", "MB", "GB, "TB, "PB"), (binary: "KiB", "MiB", +// "GiB", "TiB", "PiB") or no units (bytes). We also support "K", "M", "G" as per +// Docker's CLI. +func parseByteSizeString(inputString string) (int64, error) { + // Strip spaces and normalize to lowercase + normalizedString := strings.TrimSpace(strings.ToLower(inputString)) + if normalizedString == "" { + return 0, nil + } + unitToByteMultiplierMap := map[string]int64{ + // No unit (bytes) + "": 1, + "b": 1, + "kb": intPow(10, 3), + "mb": intPow(10, 6), + "gb": intPow(10, 9), + "tb": intPow(10, 12), + "kib": 1024, + "mib": intPow(1024, 2), + "gib": intPow(1024, 3), + "tib": intPow(1024, 4), + // These are also supported to match Docker's CLI + "k": 1024, + "m": intPow(1024, 2), + "g": intPow(1024, 3), + } + availableUnitsSorted := make([]string, 0, len(unitToByteMultiplierMap)) + for unit := range unitToByteMultiplierMap { + availableUnitsSorted = append(availableUnitsSorted, unit) + } + // Sort units by length descending so that longer units are matched first + // (e.g "mib" before "b") + sort.Slice(availableUnitsSorted, func(i, j int) bool { + return len(availableUnitsSorted[i]) > len(availableUnitsSorted[j]) + }) + + // Find the number part of the string and the unit used + numberPart := "" + byteUnit := "" + byteMultiplier := int64(0) + for _, unit := range availableUnitsSorted { + if strings.HasSuffix(normalizedString, unit) { + byteUnit = unit + // Handle the case where there is a space between the number and the unit (e.g "512 MB") + numberPart = strings.TrimSpace(normalizedString[:len(normalizedString)-len(unit)]) + byteMultiplier = unitToByteMultiplierMap[unit] + break } - return i } - return def + + // Failed to find a valid unit + if byteUnit == "" { + return 0, fmt.Errorf("parseByteSizeString: invalid byte unit used in 
string: %s (supported units: %s)", + inputString, + strings.Join(availableUnitsSorted, ", "), + ) + } + // Assert to sanity check our logic above is sound + if byteMultiplier == 0 { + panic(fmt.Sprintf( + "parseByteSizeString: byteMultiplier is unexpectedly 0 for unit: %s. "+ + "This is probably a problem with the function itself.", byteUnit, + )) + } + + // Parse the number part as an int64 + parsedNumber, err := strconv.ParseInt(strings.TrimSpace(numberPart), 10, 64) + if err != nil { + return 0, fmt.Errorf("parseByteSizeString: failed to parse number part of string: %s (%w)", + numberPart, + err, + ) + } + + // Calculate the total bytes + totalBytes := parsedNumber * byteMultiplier + return totalBytes, nil +} + +// intPow calculates n to the mth power. Since the result is an int, it is assumed that m is a positive power +// +// via https://stackoverflow.com/questions/64108933/how-to-use-math-pow-with-integers-in-go/66429580#66429580 +func intPow(n, m int64) int64 { + if m == 0 { + return 1 + } + + if m == 1 { + return n + } + + result := n + for i := int64(2); i <= m; i++ { + result *= n + } + return result } func newHostMounts(mounts []string) ([]HostMount, error) { diff --git a/internal/docker/builder.go b/internal/docker/builder.go index fe11902d..8b0edc99 100644 --- a/internal/docker/builder.go +++ b/internal/docker/builder.go @@ -32,13 +32,6 @@ import ( "github.com/matrix-org/complement/internal/instruction" ) -var ( - // HostnameRunningDocker is the hostname of the docker daemon from the perspective of Complement. - HostnameRunningDocker = "localhost" - // HostnameRunningComplement is the hostname of Complement from the perspective of a Homeserver. 
- HostnameRunningComplement = "host.docker.internal" -) - const complementLabel = "complement_context" type Builder struct { diff --git a/internal/docker/deployer.go b/internal/docker/deployer.go index 0a8a511a..6255a945 100644 --- a/internal/docker/deployer.go +++ b/internal/docker/deployer.go @@ -348,7 +348,7 @@ func deployImage( // interact with a complement-controlled test server. // Note: this feature of docker landed in Docker 20.10, // see https://github.com/moby/moby/pull/40007 - extraHosts = []string{"host.docker.internal:host-gateway"} + extraHosts = []string{fmt.Sprintf("%s:host-gateway", cfg.HostnameRunningComplement)} } for _, m := range cfg.HostMounts { @@ -399,6 +399,18 @@ func deployImage( PublishAllPorts: true, ExtraHosts: extraHosts, Mounts: mounts, + // https://docs.docker.com/engine/containers/resource_constraints/ + Resources: container.Resources{ + // Constrain the the number of CPU cores this container can use + // + // The number of CPU cores in 1e9 increments + // + // `NanoCPUs` is the option that is "Applicable to all platforms" instead of + // `CPUPeriod`/`CPUQuota` (Unix only) or `CPUCount`/`CPUPercent` (Windows only). + NanoCPUs: int64(cfg.ContainerCPUCores * 1e9), + // Constrain the maximum memory the container can use + Memory: cfg.ContainerMemoryBytes, + }, }, &network.NetworkingConfig{ EndpointsConfig: map[string]*network.EndpointSettings{ networkName: { @@ -415,7 +427,20 @@ func deployImage( containerID := body.ID if cfg.DebugLoggingEnabled { - log.Printf("%s: Created container '%s' using image '%s' on network '%s'", contextStr, containerID, imageID, networkName) + constraintStrings := []string{} + if cfg.ContainerCPUCores > 0 { + constraintStrings = append(constraintStrings, fmt.Sprintf("%.1f CPU cores", cfg.ContainerCPUCores)) + } + if cfg.ContainerMemoryBytes > 0 { + // TODO: It would be nice to pretty print this in MB/GB etc. 
+ constraintStrings = append(constraintStrings, fmt.Sprintf("%d bytes of memory", cfg.ContainerMemoryBytes)) + } + constrainedResourcesDisplayString := "" + if len(constraintStrings) > 0 { + constrainedResourcesDisplayString = fmt.Sprintf("(%s)", strings.Join(constraintStrings, ", ")) + } + + log.Printf("%s: Created container '%s' using image '%s' on network '%s' %s", contextStr, containerID, imageID, networkName, constrainedResourcesDisplayString) } stubDeployment := &HomeserverDeployment{ ContainerID: containerID, diff --git a/tests/csapi/upload_keys_test.go b/tests/csapi/upload_keys_test.go index fcf5f5ec..7308e3b4 100644 --- a/tests/csapi/upload_keys_test.go +++ b/tests/csapi/upload_keys_test.go @@ -180,11 +180,11 @@ func TestKeyClaimOrdering(t *testing.T) { deployment := complement.Deploy(t, 1) defer deployment.Destroy(t) alice := deployment.Register(t, "hs1", helpers.RegistrationOpts{}) - _, oneTimeKeys := alice.MustGenerateOneTimeKeys(t, 2) + deviceKeys, oneTimeKeys := alice.MustGenerateOneTimeKeys(t, 2) // first upload key 1, sleep a bit, then upload key 0. otk1 := map[string]interface{}{"signed_curve25519:1": oneTimeKeys["signed_curve25519:1"]} - alice.MustUploadKeys(t, nil, otk1) + alice.MustUploadKeys(t, deviceKeys, otk1) // Ensure that there is a difference in timestamp between the two upload requests. 
time.Sleep(1 * time.Second) diff --git a/tests/msc4140/delayed_event_test.go b/tests/msc4140/delayed_event_test.go index 8def039e..3e59cf20 100644 --- a/tests/msc4140/delayed_event_test.go +++ b/tests/msc4140/delayed_event_test.go @@ -2,6 +2,8 @@ package tests import ( "fmt" + "io" + "math" "net/http" "net/url" "testing" @@ -50,7 +52,7 @@ func TestDelayedEvents(t *testing.T) { user2.MustJoinRoom(t, roomID, nil) t.Run("delayed events are empty on startup", func(t *testing.T) { - matchDelayedEvents(t, user, 0) + matchDelayedEvents(t, user, delayedEventsNumberEqual(0)) }) t.Run("delayed event lookups are authenticated", func(t *testing.T) { @@ -100,14 +102,14 @@ func TestDelayedEvents(t *testing.T) { } countExpected = 0 - matchDelayedEvents(t, user, numEvents) + matchDelayedEvents(t, user, delayedEventsNumberEqual(numEvents)) t.Run("cannot get delayed events of another user", func(t *testing.T) { - matchDelayedEvents(t, user2, 0) + matchDelayedEvents(t, user2, delayedEventsNumberEqual(0)) }) time.Sleep(1 * time.Second) - matchDelayedEvents(t, user, 0) + matchDelayedEvents(t, user, delayedEventsNumberEqual(0)) queryParams := url.Values{} queryParams.Set("dir", "f") queryParams.Set("from", token) @@ -149,7 +151,7 @@ func TestDelayedEvents(t *testing.T) { getDelayQueryParam("900"), ) - matchDelayedEvents(t, user, 1) + matchDelayedEvents(t, user, delayedEventsNumberEqual(1)) res = getDelayedEvents(t, user) must.MatchResponse(t, res, match.HTTPResponse{ @@ -172,7 +174,7 @@ func TestDelayedEvents(t *testing.T) { }) time.Sleep(1 * time.Second) - matchDelayedEvents(t, user, 0) + matchDelayedEvents(t, user, delayedEventsNumberEqual(0)) res = user.MustDo(t, "GET", getPathForState(roomID, eventType, stateKey)) must.MatchResponse(t, res, match.HTTPResponse{ JSON: []match.JSON{ @@ -244,7 +246,7 @@ func TestDelayedEvents(t *testing.T) { delayID := client.GetJSONFieldStr(t, client.ParseJSON(t, res), "delay_id") time.Sleep(1 * time.Second) - matchDelayedEvents(t, user, 1) + 
matchDelayedEvents(t, user, delayedEventsNumberEqual(1)) res = user.Do(t, "GET", getPathForState(roomID, eventType, stateKey)) must.MatchResponse(t, res, match.HTTPResponse{ StatusCode: 404, @@ -256,7 +258,7 @@ func TestDelayedEvents(t *testing.T) { getPathForUpdateDelayedEvent(delayID, DelayedEventActionCancel), client.WithJSONBody(t, map[string]interface{}{}), ) - matchDelayedEvents(t, user, 0) + matchDelayedEvents(t, user, delayedEventsNumberEqual(0)) time.Sleep(1 * time.Second) res = user.Do(t, "GET", getPathForState(roomID, eventType, stateKey)) @@ -286,7 +288,7 @@ func TestDelayedEvents(t *testing.T) { delayID := client.GetJSONFieldStr(t, client.ParseJSON(t, res), "delay_id") time.Sleep(1 * time.Second) - matchDelayedEvents(t, user, 1) + matchDelayedEvents(t, user, delayedEventsNumberEqual(1)) res = user.Do(t, "GET", getPathForState(roomID, eventType, stateKey)) must.MatchResponse(t, res, match.HTTPResponse{ StatusCode: 404, @@ -298,7 +300,7 @@ func TestDelayedEvents(t *testing.T) { getPathForUpdateDelayedEvent(delayID, DelayedEventActionSend), client.WithJSONBody(t, map[string]interface{}{}), ) - matchDelayedEvents(t, user, 0) + matchDelayedEvents(t, user, delayedEventsNumberEqual(0)) res = user.Do(t, "GET", getPathForState(roomID, eventType, stateKey)) must.MatchResponse(t, res, match.HTTPResponse{ JSON: []match.JSON{ @@ -328,7 +330,7 @@ func TestDelayedEvents(t *testing.T) { delayID := client.GetJSONFieldStr(t, client.ParseJSON(t, res), "delay_id") time.Sleep(1 * time.Second) - matchDelayedEvents(t, user, 1) + matchDelayedEvents(t, user, delayedEventsNumberEqual(1)) res = user.Do(t, "GET", getPathForState(roomID, eventType, stateKey)) must.MatchResponse(t, res, match.HTTPResponse{ StatusCode: 404, @@ -342,14 +344,14 @@ func TestDelayedEvents(t *testing.T) { ) time.Sleep(1 * time.Second) - matchDelayedEvents(t, user, 1) + matchDelayedEvents(t, user, delayedEventsNumberEqual(1)) res = user.Do(t, "GET", getPathForState(roomID, eventType, stateKey)) 
must.MatchResponse(t, res, match.HTTPResponse{ StatusCode: 404, }) time.Sleep(1 * time.Second) - matchDelayedEvents(t, user, 0) + matchDelayedEvents(t, user, delayedEventsNumberEqual(0)) res = user.MustDo(t, "GET", getPathForState(roomID, eventType, stateKey)) must.MatchResponse(t, res, match.HTTPResponse{ JSON: []match.JSON{ @@ -376,7 +378,7 @@ func TestDelayedEvents(t *testing.T) { }), getDelayQueryParam("900"), ) - matchDelayedEvents(t, user, 1) + matchDelayedEvents(t, user, delayedEventsNumberEqual(1)) user.MustDo( t, @@ -386,7 +388,7 @@ func TestDelayedEvents(t *testing.T) { setterKey: "manual", }), ) - matchDelayedEvents(t, user, 1) + matchDelayedEvents(t, user, delayedEventsNumberEqual(1)) time.Sleep(1 * time.Second) res = user.MustDo(t, "GET", getPathForState(roomID, eventType, stateKey)) @@ -415,7 +417,7 @@ func TestDelayedEvents(t *testing.T) { }), getDelayQueryParam("900"), ) - matchDelayedEvents(t, user, 1) + matchDelayedEvents(t, user, delayedEventsNumberEqual(1)) setterExpected := "manual" user2.MustDo( @@ -426,7 +428,7 @@ func TestDelayedEvents(t *testing.T) { setterKey: setterExpected, }), ) - matchDelayedEvents(t, user, 0) + matchDelayedEvents(t, user, delayedEventsNumberEqual(0)) time.Sleep(1 * time.Second) res = user.MustDo(t, "GET", getPathForState(roomID, eventType, stateKey)) @@ -446,6 +448,10 @@ func TestDelayedEvents(t *testing.T) { stateKey1 := "1" stateKey2 := "2" + numberOfDelayedEvents := 0 + + // Send an initial delayed event that will be ready to send as soon as the server + // comes back up. 
user.MustDo( t, "PUT", @@ -453,24 +459,78 @@ func TestDelayedEvents(t *testing.T) { client.WithJSONBody(t, map[string]interface{}{}), getDelayQueryParam("900"), ) - user.MustDo( - t, - "PUT", - getPathForState(roomID, eventType, stateKey2), - client.WithJSONBody(t, map[string]interface{}{}), - getDelayQueryParam("9900"), - ) - matchDelayedEvents(t, user, 2) + numberOfDelayedEvents++ + + // Previously, this was naively using a single delayed event with a 10 second delay. + // But because we're stopping and starting servers here, it could take up to + // `deployment.GetConfig().SpawnHSTimeout` (defaults to 30 seconds) for the server + // to start up again so by the time the server is back up, the delayed event may + // have already been sent invalidating our assertions below (which expect some + // delayed events to still be pending and then see one of them be sent after the + // server is back up). + // + // We could account for this by setting the delayed event delay to be longer than + // `deployment.GetConfig().SpawnHSTimeout` but that would make the test suite take + // longer to run in all cases even for homeservers that are quick to restart because + // we have to wait for that large delay. + // + // We instead account for this by scheduling many delayed events at short intervals + // (we chose 10 seconds because that's what the test naively chose before). Then + // whenever the servers comes back, we can just check until it decrements by 1. + // + // We add 1 to the number of intervals to ensure that we have at least one interval + // to check against no matter how things are configured. + numberOf10SecondIntervals := int(math.Ceil(deployment.GetConfig().SpawnHSTimeout.Seconds()/10)) + 1 + for i := 0; i < numberOf10SecondIntervals; i++ { + // +1 as we want to start at 10 seconds and so we don't end up with -100ms delay + // on the first one. 
+ delay := time.Duration(i+1)*10*time.Second - 100*time.Millisecond + + user.MustDo( + t, + "PUT", + // Avoid clashing state keys as that would cancel previous delayed events on the + // same key (start at 2). + getPathForState(roomID, eventType, fmt.Sprintf("%d", i+2)), + client.WithJSONBody(t, map[string]interface{}{}), + getDelayQueryParam(fmt.Sprintf("%d", delay.Milliseconds())), + ) + numberOfDelayedEvents++ + } + // We expect all of the delayed events to be scheduled and not sent yet. + matchDelayedEvents(t, user, delayedEventsNumberEqual(numberOfDelayedEvents)) + // Restart the server and wait until it's back up. deployment.StopServer(t, hsName) + // Wait one second which will cause the first delayed event to be ready to be sent + // when the server is back up. time.Sleep(1 * time.Second) deployment.StartServer(t, hsName) - matchDelayedEvents(t, user, 1) + delayedEventResponse := matchDelayedEvents(t, user, + // We should still see some delayed events left after the restart. + delayedEventsNumberGreaterThan(0), + // We should see at least one less than we had before the restart (the first + // delayed event should have been sent). Other delayed events may have been sent + // by the time the server actually came back up. + delayedEventsNumberLessThan(numberOfDelayedEvents), + ) + // Capture whatever number of delayed events are remaining after the server restart. + remainingDelayedEventCount := countDelayedEvents(t, delayedEventResponse) + // Sanity check that the room state was updated correctly with the delayed events + // that were sent. user.MustDo(t, "GET", getPathForState(roomID, eventType, stateKey1)) - time.Sleep(9 * time.Second) - matchDelayedEvents(t, user, 0) + // Wait until we see another delayed event being sent (ensure things resumed and are continuing). 
+ time.Sleep(10 * time.Second) + matchDelayedEvents(t, user, + delayedEventsNumberLessThan(remainingDelayedEventCount), + ) + // Sanity check that the other delayed events also updated the room state correctly. + // + // FIXME: Ideally, we'd check specifically for the last one that was sent but it + // will be a bit of a juggle and fiddly to get this right so for now we just check + // one. user.MustDo(t, "GET", getPathForState(roomID, eventType, stateKey2)) }) } @@ -502,25 +562,93 @@ func getDelayedEvents(t *testing.T, user *client.CSAPI) *http.Response { return user.MustDo(t, "GET", getPathForDelayedEvents()) } -// Checks if the number of delayed events match the given number. This will +// countDelayedEvents counts the number of delayed events in the response. Assumes the +// response is well-formed. +func countDelayedEventsInternal(res *http.Response) (int, error) { + body, err := io.ReadAll(res.Body) + if err != nil { + return 0, fmt.Errorf("countDelayedEventsInternal: Failed to read response body: %s", err) + } + + parsedBody := gjson.ParseBytes(body) + return len(parsedBody.Get("delayed_events").Array()), nil +} + +func countDelayedEvents(t *testing.T, res *http.Response) int { + t.Helper() + count, err := countDelayedEventsInternal(res) + if err != nil { + t.Fatalf("countDelayedEvents: %s", err) + } + return count +} + +type delayedEventsCheckOpt func(res *http.Response) error + +// delayedEventsNumberEqual returns a check option that checks if the number of delayed events +// is equal to the given number. 
+func delayedEventsNumberEqual(wantNumber int) delayedEventsCheckOpt { + return func(res *http.Response) error { + _, err := should.MatchResponse(res, match.HTTPResponse{ + StatusCode: 200, + JSON: []match.JSON{ + match.JSONKeyArrayOfSize("delayed_events", wantNumber), + }, + }) + if err == nil { + return nil + } + return fmt.Errorf("delayedEventsNumberEqual(%d): %s", wantNumber, err) + } +} + +// delayedEventsNumberGreaterThan returns a check option that checks if the number of delayed events +// is greater than the given number. +func delayedEventsNumberGreaterThan(target int) delayedEventsCheckOpt { + return func(res *http.Response) error { + count, err := countDelayedEventsInternal(res) + if err != nil { + return fmt.Errorf("delayedEventsNumberGreaterThan(%d): %s", target, err) + } + if count > target { + return nil + } + return fmt.Errorf("delayedEventsNumberGreaterThan(%d): got %d", target, count) + } +} + +// delayedEventsNumberLessThan returns a check option that checks if the number of delayed events +// is less than the given number. +func delayedEventsNumberLessThan(target int) delayedEventsCheckOpt { + return func(res *http.Response) error { + count, err := countDelayedEventsInternal(res) + if err != nil { + return fmt.Errorf("delayedEventsNumberLessThan(%d): %s", target, err) + } + if count < target { + return nil + } + return fmt.Errorf("delayedEventsNumberLessThan(%d): got %d", target, count) + } +} + +// matchDelayedEvents will run the given checks on the delayed events response. This will // retry to handle replication lag. -func matchDelayedEvents(t *testing.T, user *client.CSAPI, wantNumber int) { +func matchDelayedEvents(t *testing.T, user *client.CSAPI, checks ...delayedEventsCheckOpt) *http.Response { t.Helper() // We need to retry this as replication can sometimes lag. 
- user.MustDo(t, "GET", getPathForDelayedEvents(), + return user.MustDo(t, "GET", getPathForDelayedEvents(), client.WithRetryUntil( 500*time.Millisecond, func(res *http.Response) bool { - _, err := should.MatchResponse(res, match.HTTPResponse{ - StatusCode: 200, - JSON: []match.JSON{ - match.JSONKeyArrayOfSize("delayed_events", wantNumber), - }, - }) - if err != nil { - t.Log(err) - return false + for _, check := range checks { + err := check(res) + + if err != nil { + t.Log(err) + return false + } } return true }, @@ -543,5 +671,5 @@ func cleanupDelayedEvents(t *testing.T, user *client.CSAPI) { ) } - matchDelayedEvents(t, user, 0) + matchDelayedEvents(t, user, delayedEventsNumberEqual(0)) }