Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"os"
"os/signal"
"strconv"
"syscall"
"time"
)
Expand All @@ -19,6 +20,15 @@ func getEnv(key, defaultVal string) string {
return defaultVal
}

func getEnvInt(key string, defaultVal int) int {
if v := os.Getenv(key); v != "" {
if n, err := strconv.Atoi(v); err == nil {
return n
}
}
return defaultVal
}

// securityHeaders はセキュリティ関連のHTTPヘッダを付与するミドルウェア。
func securityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -47,10 +57,13 @@ func main() {
mux := http.NewServeMux()
registerRoutes(mux, store)

rl := newRateLimiter(getEnvInt("RATE_LIMIT", 60), time.Minute)
rl.startCleanup()

addr := ":" + getEnv("PORT", "8080")
srv := &http.Server{
Addr: addr,
Handler: securityHeaders(requestLogger(mux)),
Handler: securityHeaders(rateLimitMiddleware(rl)(requestLogger(mux))),
ReadTimeout: 10 * time.Second,
WriteTimeout: 30 * time.Second,
IdleTimeout: 60 * time.Second,
Expand Down
106 changes: 106 additions & 0 deletions ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package main

import (
"log/slog"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
)

type rateLimiter struct {
mu sync.Mutex
visitors map[string]*visitor
limit int
window time.Duration
}

type visitor struct {
count int
resetTime time.Time
}

func newRateLimiter(limit int, window time.Duration) *rateLimiter {
return &rateLimiter{
visitors: make(map[string]*visitor),
limit: limit,
window: window,
}
}

func (rl *rateLimiter) startCleanup() {
go func() {
ticker := time.NewTicker(rl.window)
defer ticker.Stop()
for range ticker.C {
rl.mu.Lock()
now := time.Now()
for ip, v := range rl.visitors {
if now.After(v.resetTime) {
delete(rl.visitors, ip)
}
}
rl.mu.Unlock()
}
}()
}

func (rl *rateLimiter) allow(ip string) (bool, int, time.Time) {
rl.mu.Lock()
defer rl.mu.Unlock()

now := time.Now()
v, exists := rl.visitors[ip]
if !exists || now.After(v.resetTime) {
rl.visitors[ip] = &visitor{count: 1, resetTime: now.Add(rl.window)}
return true, rl.limit - 1, now.Add(rl.window)
}
v.count++
remaining := rl.limit - v.count
if remaining < 0 {
remaining = 0
}
return v.count <= rl.limit, remaining, v.resetTime
}

func clientIP(r *http.Request) string {
if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
parts := strings.SplitN(forwarded, ",", 2)
ip := strings.TrimSpace(parts[0])
if ip != "" {
return ip
}
}
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
return strings.TrimSpace(realIP)
}
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}

func rateLimitMiddleware(rl *rateLimiter) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet {
next.ServeHTTP(w, r)
return
}
ip := clientIP(r)
allowed, remaining, resetTime := rl.allow(ip)
w.Header().Set("X-RateLimit-Limit", strconv.Itoa(rl.limit))
w.Header().Set("X-RateLimit-Remaining", strconv.Itoa(remaining))
w.Header().Set("X-RateLimit-Reset", strconv.FormatInt(resetTime.Unix(), 10))
if !allowed {
slog.Warn("レート制限超過", "ip", ip)
http.Error(w, "リクエストが多すぎます。しばらくお待ちください。", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
194 changes: 194 additions & 0 deletions ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
package main

import (
"net/http"
"net/http/httptest"
"testing"
"time"
)

func TestRateLimiterAllow(t *testing.T) {
rl := newRateLimiter(3, time.Minute)
ip := "192.168.1.1"

for i := 0; i < 3; i++ {
allowed, _, _ := rl.allow(ip)
if !allowed {
t.Fatalf("request %d should be allowed", i+1)
}
}

allowed, _, _ := rl.allow(ip)
if allowed {
t.Fatal("4th request should be denied")
}
}

func TestRateLimiterDifferentIPs(t *testing.T) {
rl := newRateLimiter(1, time.Minute)

allowed1, _, _ := rl.allow("1.1.1.1")
allowed2, _, _ := rl.allow("2.2.2.2")

if !allowed1 || !allowed2 {
t.Fatal("different IPs should have separate limits")
}
}

func TestRateLimiterWindowReset(t *testing.T) {
rl := newRateLimiter(1, 50*time.Millisecond)
ip := "192.168.1.1"

allowed, _, _ := rl.allow(ip)
if !allowed {
t.Fatal("first request should be allowed")
}

allowed, _, _ = rl.allow(ip)
if allowed {
t.Fatal("second request should be denied")
}

time.Sleep(60 * time.Millisecond)

allowed, _, _ = rl.allow(ip)
if !allowed {
t.Fatal("request after window reset should be allowed")
}
}

func TestRateLimiterRemaining(t *testing.T) {
rl := newRateLimiter(3, time.Minute)
ip := "192.168.1.1"

_, remaining, _ := rl.allow(ip)
if remaining != 2 {
t.Fatalf("expected 2 remaining, got %d", remaining)
}

_, remaining, _ = rl.allow(ip)
if remaining != 1 {
t.Fatalf("expected 1 remaining, got %d", remaining)
}

_, remaining, _ = rl.allow(ip)
if remaining != 0 {
t.Fatalf("expected 0 remaining, got %d", remaining)
}

_, remaining, _ = rl.allow(ip)
if remaining != 0 {
t.Fatalf("expected 0 remaining after exhaustion, got %d", remaining)
}
}

func TestClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
forwarded string
realIP string
expected string
}{
{"RemoteAddr", "192.168.1.1:12345", "", "", "192.168.1.1"},
{"X-Forwarded-For", "10.0.0.1:12345", "203.0.113.50, 70.41.3.18", "", "203.0.113.50"},
{"X-Real-IP", "10.0.0.1:12345", "", "203.0.113.99", "203.0.113.99"},
{"X-Forwarded-For優先", "10.0.0.1:12345", "203.0.113.50", "203.0.113.99", "203.0.113.50"},
{"ポートなしRemoteAddr", "192.168.1.1", "", "", "192.168.1.1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := httptest.NewRequest("GET", "/", nil)
r.RemoteAddr = tt.remoteAddr
if tt.forwarded != "" {
r.Header.Set("X-Forwarded-For", tt.forwarded)
}
if tt.realIP != "" {
r.Header.Set("X-Real-IP", tt.realIP)
}
got := clientIP(r)
if got != tt.expected {
t.Errorf("expected %q, got %q", tt.expected, got)
}
})
}
}

func TestRateLimitMiddlewareGETBypass(t *testing.T) {
rl := newRateLimiter(1, time.Minute)
handler := rateLimitMiddleware(rl)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

for i := 0; i < 5; i++ {
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("GET request %d should bypass rate limit, got %d", i+1, w.Code)
}
}
}

func TestRateLimitMiddlewarePOSTLimited(t *testing.T) {
rl := newRateLimiter(2, time.Minute)
handler := rateLimitMiddleware(rl)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

for i := 0; i < 2; i++ {
req := httptest.NewRequest("POST", "/lists", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("POST request %d should be allowed, got %d", i+1, w.Code)
}
}

req := httptest.NewRequest("POST", "/lists", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Fatalf("3rd POST should be rate limited, got %d", w.Code)
}
}

func TestRateLimitHeaders(t *testing.T) {
rl := newRateLimiter(5, time.Minute)
handler := rateLimitMiddleware(rl)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest("POST", "/lists", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)

if w.Header().Get("X-RateLimit-Limit") != "5" {
t.Errorf("expected X-RateLimit-Limit=5, got %s", w.Header().Get("X-RateLimit-Limit"))
}
if w.Header().Get("X-RateLimit-Remaining") != "4" {
t.Errorf("expected X-RateLimit-Remaining=4, got %s", w.Header().Get("X-RateLimit-Remaining"))
}
if w.Header().Get("X-RateLimit-Reset") == "" {
t.Error("expected X-RateLimit-Reset to be set")
}
}

func TestRateLimiterConcurrentAccess(t *testing.T) {
rl := newRateLimiter(100, time.Minute)
done := make(chan bool, 50)

for i := 0; i < 50; i++ {
go func() {
rl.allow("192.168.1.1")
done <- true
}()
}

for i := 0; i < 50; i++ {
<-done
}
}
Loading