diff --git a/e2e/forwarder/service.go b/e2e/forwarder/service.go index c308fc86e..d9d4e5dd2 100644 --- a/e2e/forwarder/service.go +++ b/e2e/forwarder/service.go @@ -181,6 +181,16 @@ func (s *Service) WithDenyDomainExclude(domains ...string) *Service { return s } +func (s *Service) WithReadLimit(limit string) *Service { + s.Environment["FORWARDER_READ_LIMIT"] = limit + return s +} + +func (s *Service) WithWriteLimit(limit string) *Service { + s.Environment["FORWARDER_WRITE_LIMIT"] = limit + return s +} + func (s *Service) Service() *compose.Service { return (*compose.Service)(s) } diff --git a/e2e/setups.go b/e2e/setups.go index 03bb2c93b..c7b1eb250 100644 --- a/e2e/setups.go +++ b/e2e/setups.go @@ -40,6 +40,7 @@ func AllSetups() []setup.Setup { SetupFlagInsecure(l) SetupFlagMITM(l) SetupFlagDenyDomain(l) + SetupFlagRateLimit(l) SetupSC2450(l) return l.Build() @@ -360,6 +361,33 @@ func SetupFlagDenyDomain(l *setupList) { ) } +func SetupFlagRateLimit(l *setupList) { + l.Add( + setup.Setup{ + Name: "flag-read-limit", + Compose: compose.NewBuilder(). + AddService( + forwarder.HttpbinService()). + AddService( + forwarder.ProxyService(). + WithReadLimit("1M")). + MustBuild(), + Run: "^TestFlagReadLimit$", + }, + setup.Setup{ + Name: "flag-write-limit", + Compose: compose.NewBuilder(). + AddService( + forwarder.HttpbinService()). + AddService( + forwarder.ProxyService(). + WithWriteLimit("1M")). + MustBuild(), + Run: "^TestFlagWriteLimit$", + }, + ) +} + func SetupSC2450(l *setupList) { l.Add(setup.Setup{ Name: "sc-2450", diff --git a/e2e/tests/flag_test.go b/e2e/tests/flag_test.go index 53ec7c929..0582830cf 100644 --- a/e2e/tests/flag_test.go +++ b/e2e/tests/flag_test.go @@ -9,9 +9,12 @@ package tests import ( + "fmt" "net" "net/http" + "sync" "testing" + "time" ) func TestFlagProxyLocalhost(t *testing.T) { @@ -94,3 +97,50 @@ func TestFlagDenyDomain(t *testing.T) { newClient(t, "https://www.google.com").GET("/").ExpectStatus(http.StatusForbidden) newClient(t, httpbin).GET("/status/200").ExpectStatus(http.StatusOK) } + +func TestFlagWriteLimit(t *testing.T) { + size := 10 * 1024 * 1024 // 10MiB + // It streams 10MiB, the write limit is 1MiB/s, but minimum burst is 4MiB, so it should take approximately 6 seconds. + expectedTime := 6 * time.Second + + ts := time.Now() + c := newClient(t, httpbin) + c.GET(fmt.Sprintf("/stream-bytes/%d", size)).ExpectStatus(http.StatusOK).ExpectBodySize(size) + if elapsed := time.Since(ts); elapsed.Round(time.Second) != expectedTime { + t.Fatalf("Expected request to take approximately %s, took %s", expectedTime, elapsed) + } +} + +func TestFlagReadLimit(t *testing.T) { + var ( + workers = 4 + requests = 500 + // GET /status/200 requests with 140 `X-Test-Read-Limit` headers takes around 4KB. + // The read limit is 1MiB/s, but minimum burst is 4MiB, + // 4 * 500 * 4000 = 8000000 bytes, so it should take approximately 4 seconds. + expectedTime = 4 * time.Second + + wg sync.WaitGroup + ) + + ts := time.Now() + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + c := newClient(t, httpbin) + for j := 0; j < requests; j++ { + c.GET("/status/200", func(req *http.Request) { + for i := 0; i < 140; i++ { + req.Header.Add(fmt.Sprintf("X-Test-Read-Limit-%v", i), "test") + } + }).ExpectStatus(http.StatusOK) + } + }() + } + wg.Wait() + + if elapsed := time.Since(ts); elapsed.Round(time.Second) != expectedTime { + t.Fatalf("Expected request to take approximately %s, took %s", expectedTime, elapsed) + } +}