diff --git a/main.go b/main.go index afcb401..295aee5 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "net/http" "os" "os/signal" + "strconv" "syscall" "time" ) @@ -19,6 +20,15 @@ func getEnv(key, defaultVal string) string { return defaultVal } +func getEnvInt(key string, defaultVal int) int { + if v := os.Getenv(key); v != "" { + if n, err := strconv.Atoi(v); err == nil { + return n + } + } + return defaultVal +} + // securityHeaders はセキュリティ関連のHTTPヘッダを付与するミドルウェア。 func securityHeaders(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -47,10 +57,13 @@ func main() { mux := http.NewServeMux() registerRoutes(mux, store) + rl := newRateLimiter(getEnvInt("RATE_LIMIT", 60), time.Minute) + rl.startCleanup() + addr := ":" + getEnv("PORT", "8080") srv := &http.Server{ Addr: addr, - Handler: securityHeaders(requestLogger(mux)), + Handler: securityHeaders(rateLimitMiddleware(rl)(requestLogger(mux))), ReadTimeout: 10 * time.Second, WriteTimeout: 30 * time.Second, IdleTimeout: 60 * time.Second, diff --git a/ratelimit.go b/ratelimit.go new file mode 100644 index 0000000..3d87f66 --- /dev/null +++ b/ratelimit.go @@ -0,0 +1,106 @@ +package main + +import ( + "log/slog" + "net" + "net/http" + "strconv" + "strings" + "sync" + "time" +) + +type rateLimiter struct { + mu sync.Mutex + visitors map[string]*visitor + limit int + window time.Duration +} + +type visitor struct { + count int + resetTime time.Time +} + +func newRateLimiter(limit int, window time.Duration) *rateLimiter { + return &rateLimiter{ + visitors: make(map[string]*visitor), + limit: limit, + window: window, + } +} + +func (rl *rateLimiter) startCleanup() { + go func() { + ticker := time.NewTicker(rl.window) + defer ticker.Stop() + for range ticker.C { + rl.mu.Lock() + now := time.Now() + for ip, v := range rl.visitors { + if now.After(v.resetTime) { + delete(rl.visitors, ip) + } + } + rl.mu.Unlock() + } + }() +} + +func (rl *rateLimiter) allow(ip string) (bool, int, time.Time) { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + v, exists := rl.visitors[ip] + if !exists || now.After(v.resetTime) { + rl.visitors[ip] = &visitor{count: 1, resetTime: now.Add(rl.window)} + return true, rl.limit - 1, now.Add(rl.window) + } + v.count++ + remaining := rl.limit - v.count + if remaining < 0 { + remaining = 0 + } + return v.count <= rl.limit, remaining, v.resetTime +} + +func clientIP(r *http.Request) string { + if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" { + parts := strings.SplitN(forwarded, ",", 2) + ip := strings.TrimSpace(parts[0]) + if ip != "" { + return ip + } + } + if realIP := r.Header.Get("X-Real-IP"); realIP != "" { + return strings.TrimSpace(realIP) + } + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} + +func rateLimitMiddleware(rl *rateLimiter) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + next.ServeHTTP(w, r) + return + } + ip := clientIP(r) + allowed, remaining, resetTime := rl.allow(ip) + w.Header().Set("X-RateLimit-Limit", strconv.Itoa(rl.limit)) + w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(remaining)) + w.Header().Set("X-RateLimit-Reset", strconv.FormatInt(resetTime.Unix(), 10)) + if !allowed { + slog.Warn("レート制限超過", "ip", ip) + http.Error(w, "リクエストが多すぎます。しばらくお待ちください。", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) + } +} diff --git a/ratelimit_test.go b/ratelimit_test.go new file mode 100644 index 0000000..8998cc9 --- /dev/null +++ b/ratelimit_test.go @@ -0,0 +1,194 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestRateLimiterAllow(t *testing.T) { + rl := newRateLimiter(3, time.Minute) + ip := "192.168.1.1" + + for i := 0; i < 3; i++ { + allowed, _, _ := rl.allow(ip) + if !allowed { + t.Fatalf("request %d should be allowed", i+1) + } + } + + allowed, _, _ := rl.allow(ip) + if allowed { + t.Fatal("4th request should be denied") + } +} + +func TestRateLimiterDifferentIPs(t *testing.T) { + rl := newRateLimiter(1, time.Minute) + + allowed1, _, _ := rl.allow("1.1.1.1") + allowed2, _, _ := rl.allow("2.2.2.2") + + if !allowed1 || !allowed2 { + t.Fatal("different IPs should have separate limits") + } +} + +func TestRateLimiterWindowReset(t *testing.T) { + rl := newRateLimiter(1, 50*time.Millisecond) + ip := "192.168.1.1" + + allowed, _, _ := rl.allow(ip) + if !allowed { + t.Fatal("first request should be allowed") + } + + allowed, _, _ = rl.allow(ip) + if allowed { + t.Fatal("second request should be denied") + } + + time.Sleep(60 * time.Millisecond) + + allowed, _, _ = rl.allow(ip) + if !allowed { + t.Fatal("request after window reset should be allowed") + } +} + +func TestRateLimiterRemaining(t *testing.T) { + rl := newRateLimiter(3, time.Minute) + ip := "192.168.1.1" + + _, remaining, _ := rl.allow(ip) + if remaining != 2 { + t.Fatalf("expected 2 remaining, got %d", remaining) + } + + _, remaining, _ = rl.allow(ip) + if remaining != 1 { + t.Fatalf("expected 1 remaining, got %d", remaining) + } + + _, remaining, _ = rl.allow(ip) + if remaining != 0 { + t.Fatalf("expected 0 remaining, got %d", remaining) + } + + _, remaining, _ = rl.allow(ip) + if remaining != 0 { + t.Fatalf("expected 0 remaining after exhaustion, got %d", remaining) + } +} + +func TestClientIP(t *testing.T) { + tests := []struct { + name string + remoteAddr string + forwarded string + realIP string + expected string + }{ + {"RemoteAddr", "192.168.1.1:12345", "", "", "192.168.1.1"}, + {"X-Forwarded-For", "10.0.0.1:12345", "203.0.113.50, 70.41.3.18", "", "203.0.113.50"}, + {"X-Real-IP", "10.0.0.1:12345", "", "203.0.113.99", "203.0.113.99"}, + {"X-Forwarded-For優先", "10.0.0.1:12345", "203.0.113.50", "203.0.113.99", "203.0.113.50"}, + {"ポートなしRemoteAddr", "192.168.1.1", "", "", "192.168.1.1"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := httptest.NewRequest("GET", "/", nil) + r.RemoteAddr = tt.remoteAddr + if tt.forwarded != "" { + r.Header.Set("X-Forwarded-For", tt.forwarded) + } + if tt.realIP != "" { + r.Header.Set("X-Real-IP", tt.realIP) + } + got := clientIP(r) + if got != tt.expected { + t.Errorf("expected %q, got %q", tt.expected, got) + } + }) + } +} + +func TestRateLimitMiddlewareGETBypass(t *testing.T) { + rl := newRateLimiter(1, time.Minute) + handler := rateLimitMiddleware(rl)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("GET request %d should bypass rate limit, got %d", i+1, w.Code) + } + } +} + +func TestRateLimitMiddlewarePOSTLimited(t *testing.T) { + rl := newRateLimiter(2, time.Minute) + handler := rateLimitMiddleware(rl)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + for i := 0; i < 2; i++ { + req := httptest.NewRequest("POST", "/lists", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("POST request %d should be allowed, got %d", i+1, w.Code) + } + } + + req := httptest.NewRequest("POST", "/lists", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + if w.Code != http.StatusTooManyRequests { + t.Fatalf("3rd POST should be rate limited, got %d", w.Code) + } +} + +func TestRateLimitHeaders(t *testing.T) { + rl := newRateLimiter(5, time.Minute) + handler := rateLimitMiddleware(rl)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("POST", "/lists", nil) + req.RemoteAddr = "192.168.1.1:12345" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + + if w.Header().Get("X-RateLimit-Limit") != "5" { + t.Errorf("expected X-RateLimit-Limit=5, got %s", w.Header().Get("X-RateLimit-Limit")) + } + if w.Header().Get("X-RateLimit-Remaining") != "4" { + t.Errorf("expected X-RateLimit-Remaining=4, got %s", w.Header().Get("X-RateLimit-Remaining")) + } + if w.Header().Get("X-RateLimit-Reset") == "" { + t.Error("expected X-RateLimit-Reset to be set") + } +} + +func TestRateLimiterConcurrentAccess(t *testing.T) { + rl := newRateLimiter(100, time.Minute) + done := make(chan bool, 50) + + for i := 0; i < 50; i++ { + go func() { + rl.allow("192.168.1.1") + done <- true + }() + } + + for i := 0; i < 50; i++ { + <-done + } +}