diff --git a/cmd/baker/main.go b/cmd/baker/main.go index 05a9044..75e6381 100644 --- a/cmd/baker/main.go +++ b/cmd/baker/main.go @@ -13,6 +13,7 @@ import ( "ella.to/baker/driver" "ella.to/baker/internal/acme" "ella.to/baker/internal/httpclient" + "ella.to/baker/rule" ) var Version = "master" @@ -49,7 +50,15 @@ https://ella.to/baker docker := driver.NewDocker(dockerGetter) - handler := baker.NewServer(baker.WithBufferSize(bufferSize), baker.WithPingDuration(pingDuration)) + handler := baker.NewServer( + baker.WithBufferSize(bufferSize), + baker.WithPingDuration(pingDuration), + baker.WithRules( + rule.RegisterAppendPath(), + rule.RegisterReplacePath(), + rule.RegisterRateLimiter(), + ), + ) handler.RegisterDriver(docker.RegisterDriver) if acmeEnable { diff --git a/data.go b/data.go index 5d686e5..75f2ae9 100644 --- a/data.go +++ b/data.go @@ -3,6 +3,7 @@ package baker import ( "encoding/json" "net/netip" + "strings" ) type Container struct { @@ -17,6 +18,15 @@ type Endpoint struct { Rules []Rule `json:"rules"` } +func (e *Endpoint) getHashKey() string { + var sb strings.Builder + + sb.WriteString(e.Domain) + sb.WriteString(e.Path) + + return sb.String() +} + type Rule struct { Type string `json:"type"` Args json.RawMessage `json:"args"` diff --git a/internal/collection/map.go b/internal/collection/map.go new file mode 100644 index 0000000..c590cf2 --- /dev/null +++ b/internal/collection/map.go @@ -0,0 +1,54 @@ +package collection + +import "sync" + +type Map[T any] struct { + rw sync.RWMutex + collection map[string]T +} + +func (m *Map[T]) Put(key string, val T) { + m.rw.Lock() + defer m.rw.Unlock() + + m.collection[key] = val +} + +func (m *Map[T]) Get(key string) (val T, ok bool) { + m.rw.RLock() + defer m.rw.RUnlock() + + val, ok = m.collection[key] + return +} + +func (s *Map[T]) Len() int { + s.rw.RLock() + defer s.rw.RUnlock() + + return len(s.collection) +} + +func (m *Map[T]) Delete(key string) { + m.rw.Lock() + defer m.rw.Unlock() + + delete(m.collection, key) +} + +func (m *Map[T]) GetAndUpdate(key string, update func(old T, found bool) T) (val T) { + m.rw.Lock() + defer m.rw.Unlock() + + val, found := m.collection[key] + val = update(val, found) + m.collection[key] = val + + return +} + +func NewMap[T any]() *Map[T] { + return &Map[T]{ + collection: make(map[string]T), + } +} diff --git a/server.go b/server.go index b14f320..aea18e5 100644 --- a/server.go +++ b/server.go @@ -14,6 +14,7 @@ import ( "ella.to/baker/internal/httpclient" "ella.to/baker/internal/trie" "ella.to/baker/rule" + "github.com/alinz/baker.go/pkg/collection" ) type containerInfo struct { @@ -24,13 +25,14 @@ type containerInfo struct { } type Server struct { - bufferSize int - pingDuration time.Duration - containersMap map[string]*containerInfo // containerID -> containerInfo - domainsMap map[string]*trie.Node[*Service] // domain -> path -> containers - rules map[string]rule.BuilderFunc - runner *ActionRunner - close chan struct{} + bufferSize int + pingDuration time.Duration + containersMap map[string]*containerInfo // containerID -> containerInfo + domainsMap map[string]*trie.Node[*Service] // domain -> path -> containers + rules map[string]rule.BuilderFunc + middlewareCacheMap *collection.Map[rule.Middleware] + runner *ActionRunner + close chan struct{} } var _ http.Handler = (*Server)(nil) @@ -102,6 +104,16 @@ func (s *Server) getMiddlewares(endpoint *Endpoint) ([]rule.Middleware, error) { return nil, fmt.Errorf("failed to parse args for rule %s: %w", r.Type, err) } + if middleware.IsCachable() { + middleware = s.middlewareCacheMap.GetAndUpdate(endpoint.getHashKey(), func(old rule.Middleware, found bool) rule.Middleware { + if found { + return old.UpdateMiddelware(middleware) + } + + return middleware.UpdateMiddelware(nil) + }) + } + middlewares = append(middlewares, middleware) } @@ -249,6 +261,7 @@ func (s *Server) removeContainer(container *Container) { service.Containers = append(service.Containers[:i], service.Containers[i+1:]...) if len(service.Containers) == 0 { paths.Del([]rune(containerInfo.path)) + s.middlewareCacheMap.Delete(service.Endpoint.getHashKey()) } else { paths.Put([]rune(containerInfo.path), service) } @@ -283,38 +296,58 @@ func (s *Server) getContainer(domain, path string) (container *Container, endpoi } type serverOpt interface { - configureServer(*Server) + configureServer(*Server) error } -type serverOptFunc func(*Server) +type serverOptFunc func(*Server) error -func (f serverOptFunc) configureServer(s *Server) { - f(s) +func (f serverOptFunc) configureServer(s *Server) error { + return f(s) } func WithBufferSize(size int) serverOptFunc { - return func(s *Server) { + return func(s *Server) error { s.bufferSize = size + return nil } } func WithPingDuration(d time.Duration) serverOptFunc { - return func(s *Server) { + return func(s *Server) error { s.pingDuration = d + return nil + } +} + +func WithRules(rules ...rule.RegisterFunc) serverOptFunc { + return func(s *Server) error { + s.rules = make(map[string]rule.BuilderFunc) + + for _, r := range rules { + if err := r(s.rules); err != nil { + return err + } + } + + return nil } } func NewServer(opts ...serverOpt) *Server { s := &Server{ - bufferSize: 100, - pingDuration: 10 * time.Second, - containersMap: make(map[string]*containerInfo), - domainsMap: make(map[string]*trie.Node[*Service]), - close: make(chan struct{}), + bufferSize: 100, + pingDuration: 10 * time.Second, + containersMap: make(map[string]*containerInfo), + domainsMap: make(map[string]*trie.Node[*Service]), + middlewareCacheMap: collection.NewMap[rule.Middleware](), + close: make(chan struct{}), } for _, opt := range opts { - opt.configureServer(s) + if err := opt.configureServer(s); err != nil { + slog.Error("failed to configure server", "error", err) + return nil + } } s.runner = NewActionRunner( diff --git a/server_test.go b/server_test.go index e8bb983..db3aceb 100644 --- a/server_test.go +++ b/server_test.go @@ -12,6 +12,7 @@ import ( "time" "ella.to/baker" + "ella.to/baker/rule" ) var count int @@ -51,7 +52,14 @@ func createDummyContainer(t *testing.T, config *baker.Config) *baker.Container { } func createBakerServer(t *testing.T) (*baker.Server, string) { - handler := baker.NewServer(baker.WithPingDuration(2 * time.Second)) + handler := baker.NewServer( + baker.WithPingDuration(2*time.Second), + baker.WithRules( + rule.RegisterAppendPath(), + rule.RegisterReplacePath(), + rule.RegisterRateLimiter(), + ), + ) server := httptest.NewServer(handler) t.Cleanup(func() { handler.Close() @@ -101,3 +109,70 @@ func TestServer(t *testing.T) { resp.Body.Close() } + +func TestRateLimiter(t *testing.T) { + t.Skip("skipping test for now") + + slog.SetLogLoggerLevel(slog.LevelDebug) + + container1 := createDummyContainer(t, &baker.Config{ + Endpoints: []baker.Endpoint{ + { + Domain: "example.com", + Path: "/ella/a", + Rules: []baker.Rule{ + { + Type: "RateLimiter", + Args: json.RawMessage(`{"request_limit":2,"window_duration":"3s"}`), + }, + }, + }, + }, + }) + + server, url := createBakerServer(t) + + var driver baker.Driver + + server.RegisterDriver(func(d baker.Driver) { + driver = d + }) + + driver.Add(container1) + + // Wait for the server to process the container + time.Sleep(4 * time.Second) + + for range 2 { + if err := makeCall(url, "/ella/a", "example.com"); err != nil { + t.Fatal(err) + } + } + + if err := makeCall(url, "/ella/a", "example.com"); err == nil { + t.Fatal("expected error, got nil") + } + + fmt.Println("waiting for rate limiter to reset") +} + +func makeCall(url, path, host string) error { + req, err := http.NewRequest(http.MethodGet, url+path, nil) + if err != nil { + return err + } + + req.Host = host + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusOK { + return err + } + + resp.Body.Close() + return nil +}