summaryrefslogtreecommitdiffhomepage
path: root/internal
diff options
context:
space:
mode:
authoradamdottv <[email protected]>2025-05-05 14:23:29 -0500
committeradamdottv <[email protected]>2025-05-05 14:23:29 -0500
commit3cc08494a56b30bab8663935eb158906a68bed20 (patch)
tree5296f2ef113cf98f44be315330e4527f3775595c /internal
parentafcdabd09534fd97c09b128b4b62baa318b92f19 (diff)
downloadopencode-3cc08494a56b30bab8663935eb158906a68bed20.tar.gz
opencode-3cc08494a56b30bab8663935eb158906a68bed20.zip
fix: pubsub leak and shutdown seq
Diffstat (limited to 'internal')
-rw-r--r--internal/pubsub/broker.go37
-rw-r--r--internal/pubsub/broker_test.go145
2 files changed, 166 insertions, 16 deletions
diff --git a/internal/pubsub/broker.go b/internal/pubsub/broker.go
index 0de1be063..88a59f60a 100644
--- a/internal/pubsub/broker.go
+++ b/internal/pubsub/broker.go
@@ -64,22 +64,27 @@ func (b *Broker[T]) Subscribe(ctx context.Context) <-chan Event[T] {
b.subs[sub] = struct{}{}
b.subCount++
- go func() {
- <-ctx.Done()
-
- b.mu.Lock()
- defer b.mu.Unlock()
-
- select {
- case <-b.done:
- return
- default:
- }
-
- delete(b.subs, sub)
- close(sub)
- b.subCount--
- }()
+ // Only start a goroutine if the context can actually be canceled
+ if ctx.Done() != nil {
+ go func() {
+ <-ctx.Done()
+
+ b.mu.Lock()
+ defer b.mu.Unlock()
+
+ select {
+ case <-b.done:
+ return
+ default:
+ }
+
+ if _, exists := b.subs[sub]; exists {
+ delete(b.subs, sub)
+ close(sub)
+ b.subCount--
+ }
+ }()
+ }
return sub
}
diff --git a/internal/pubsub/broker_test.go b/internal/pubsub/broker_test.go
new file mode 100644
index 000000000..6fae4874f
--- /dev/null
+++ b/internal/pubsub/broker_test.go
@@ -0,0 +1,145 @@
+package pubsub
+
+import (
+ "context"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestBrokerSubscribe(t *testing.T) {
+ t.Parallel()
+
+ t.Run("with cancellable context", func(t *testing.T) {
+ t.Parallel()
+ broker := NewBroker[string]()
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ ch := broker.Subscribe(ctx)
+ assert.NotNil(t, ch)
+ assert.Equal(t, 1, broker.GetSubscriberCount())
+
+ // Cancel the context should remove the subscription
+ cancel()
+ time.Sleep(10 * time.Millisecond) // Give time for goroutine to process
+ assert.Equal(t, 0, broker.GetSubscriberCount())
+ })
+
+ t.Run("with background context", func(t *testing.T) {
+ t.Parallel()
+ broker := NewBroker[string]()
+
+ // Using context.Background() should not leak goroutines
+ ch := broker.Subscribe(context.Background())
+ assert.NotNil(t, ch)
+ assert.Equal(t, 1, broker.GetSubscriberCount())
+
+ // Shutdown should clean up all subscriptions
+ broker.Shutdown()
+ assert.Equal(t, 0, broker.GetSubscriberCount())
+ })
+}
+
+func TestBrokerPublish(t *testing.T) {
+ t.Parallel()
+ broker := NewBroker[string]()
+ ctx := t.Context()
+
+ ch := broker.Subscribe(ctx)
+
+ // Publish a message
+ broker.Publish(CreatedEvent, "test message")
+
+ // Verify message is received
+ select {
+ case event := <-ch:
+ assert.Equal(t, CreatedEvent, event.Type)
+ assert.Equal(t, "test message", event.Payload)
+ case <-time.After(100 * time.Millisecond):
+ t.Fatal("timeout waiting for message")
+ }
+}
+
+func TestBrokerShutdown(t *testing.T) {
+ t.Parallel()
+ broker := NewBroker[string]()
+
+ // Create multiple subscribers
+ ch1 := broker.Subscribe(context.Background())
+ ch2 := broker.Subscribe(context.Background())
+
+ assert.Equal(t, 2, broker.GetSubscriberCount())
+
+ // Shutdown should close all channels and clean up
+ broker.Shutdown()
+
+ // Verify channels are closed
+ _, ok1 := <-ch1
+ _, ok2 := <-ch2
+ assert.False(t, ok1, "channel 1 should be closed")
+ assert.False(t, ok2, "channel 2 should be closed")
+
+ // Verify subscriber count is reset
+ assert.Equal(t, 0, broker.GetSubscriberCount())
+}
+
+func TestBrokerConcurrency(t *testing.T) {
+ t.Parallel()
+ broker := NewBroker[int]()
+
+ // Create a large number of subscribers
+ const numSubscribers = 100
+ var wg sync.WaitGroup
+ wg.Add(numSubscribers)
+
+ // Create a channel to collect received events
+ receivedEvents := make(chan int, numSubscribers)
+
+ for i := range numSubscribers {
+ go func(id int) {
+ defer wg.Done()
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ ch := broker.Subscribe(ctx)
+
+ // Receive one message then cancel
+ select {
+ case event := <-ch:
+ receivedEvents <- event.Payload
+ case <-time.After(1 * time.Second):
+ t.Errorf("timeout waiting for message %d", id)
+ }
+ cancel()
+ }(i)
+ }
+
+ // Give subscribers time to set up
+ time.Sleep(10 * time.Millisecond)
+
+ // Publish messages to all subscribers
+ for i := range numSubscribers {
+ broker.Publish(CreatedEvent, i)
+ }
+
+ // Wait for all subscribers to finish
+ wg.Wait()
+ close(receivedEvents)
+
+ // Give time for cleanup goroutines to run
+ time.Sleep(10 * time.Millisecond)
+
+ // Verify all subscribers are cleaned up
+ assert.Equal(t, 0, broker.GetSubscriberCount())
+
+ // Verify we received the expected number of events
+ count := 0
+ for range receivedEvents {
+ count++
+ }
+ assert.Equal(t, numSubscribers, count)
+}
+