package bus

import (
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

// waitGroupWithTimeout blocks until wg's counter reaches zero or timeout
// elapses, failing the test in the latter case. Subscriber handlers are
// synchronized on via a WaitGroup in these tests; an unbounded wg.Wait would
// hang until the global `go test` timeout if a message were ever lost.
// On timeout the helper goroutine stays blocked in wg.Wait; that leak is
// confined to an already-failing test.
func waitGroupWithTimeout(t *testing.T, wg *sync.WaitGroup, timeout time.Duration) {
	t.Helper()
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()
	select {
	case <-done:
	case <-time.After(timeout):
		t.Fatalf("timed out after %v waiting for message delivery", timeout)
	}
}

// TestNewBus checks that the constructor returns a usable, non-nil bus.
func TestNewBus(t *testing.T) {
	b := New()
	if b == nil {
		t.Fatal("New() returned nil")
	}
	defer b.Close()
}

// TestPublishSubscribe checks that a single subscriber receives exactly one
// message published on its topic.
func TestPublishSubscribe(t *testing.T) {
	b := New()
	defer b.Close()

	var received int32
	var wg sync.WaitGroup
	wg.Add(1)

	sub, err := b.Subscribe("test", func(msg Message) {
		atomic.AddInt32(&received, 1)
		wg.Done()
	})
	if err != nil {
		t.Fatalf("Subscribe failed: %v", err)
	}
	defer sub.Unsubscribe()

	err = b.Publish("test", Message{
		ID:   "msg-1",
		Type: MsgTypeSystem,
		From: "test",
	})
	if err != nil {
		t.Fatalf("Publish failed: %v", err)
	}

	waitGroupWithTimeout(t, &wg, 5*time.Second)
	// Read through atomic, like every write to received.
	if n := atomic.LoadInt32(&received); n != 1 {
		t.Errorf("expected 1 message, got %d", n)
	}
}

// TestPublishNoSubscribers checks that publishing to a topic nobody has
// subscribed to is not an error.
func TestPublishNoSubscribers(t *testing.T) {
	b := New()
	defer b.Close()

	if err := b.Publish("nonexistent", Message{ID: "msg-1"}); err != nil {
		t.Fatalf("Publish to nonexistent topic should not error: %v", err)
	}
}

// TestMultipleSubscribers checks that every subscriber on a topic receives a
// copy of a single published message.
func TestMultipleSubscribers(t *testing.T) {
	const subscribers = 3

	b := New()
	defer b.Close()

	var received int32
	var wg sync.WaitGroup
	wg.Add(subscribers)

	for i := 0; i < subscribers; i++ {
		sub, err := b.Subscribe("multi", func(msg Message) {
			atomic.AddInt32(&received, 1)
			wg.Done()
		})
		if err != nil {
			t.Fatalf("Subscribe %d failed: %v", i, err)
		}
		// Deferred inside the loop on purpose: every subscription must stay
		// active until the test returns, and must be released before the
		// deferred b.Close() above runs.
		defer sub.Unsubscribe()
	}

	if err := b.Publish("multi", Message{ID: "msg-1"}); err != nil {
		t.Fatalf("Publish failed: %v", err)
	}

	waitGroupWithTimeout(t, &wg, 5*time.Second)
	if n := atomic.LoadInt32(&received); n != subscribers {
		t.Errorf("expected %d messages, got %d", subscribers, n)
	}
}

// TestUnsubscribe checks that a subscriber stops receiving messages once it
// has unsubscribed.
func TestUnsubscribe(t *testing.T) {
	b := New()
	defer b.Close()

	var received int32
	// delivered signals the first delivery. The send is non-blocking so an
	// unexpected extra delivery can never wedge the handler.
	delivered := make(chan struct{}, 1)
	sub, err := b.Subscribe("test", func(msg Message) {
		atomic.AddInt32(&received, 1)
		select {
		case delivered <- struct{}{}:
		default:
		}
	})
	if err != nil {
		t.Fatalf("Subscribe failed: %v", err)
	}

	// Publish before unsubscribe and wait for that delivery explicitly,
	// rather than sleeping and hoping it has already arrived.
	if err := b.Publish("test", Message{ID: "msg-1"}); err != nil {
		t.Fatalf("Publish before unsubscribe failed: %v", err)
	}
	select {
	case <-delivered:
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for message published before unsubscribe")
	}

	sub.Unsubscribe()

	// Publish after unsubscribe. A sleep is unavoidable here: the assertion
	// is that nothing arrives, so a stray delivery is given a window in which
	// to show up.
	if err := b.Publish("test", Message{ID: "msg-2"}); err != nil {
		t.Fatalf("Publish after unsubscribe failed: %v", err)
	}
	time.Sleep(50 * time.Millisecond)

	if n := atomic.LoadInt32(&received); n != 1 {
		t.Errorf("expected 1 message after unsubscribe, got %d", n)
	}
}
func TestSubscribeAfterClose(t *testing.T) { b := New() b.Close() _, err := b.Subscribe("test", func(msg Message) {}) if err == nil { t.Error("expected error subscribing to closed bus") } } func TestPublishAfterClose(t *testing.T) { b := New() b.Close() err := b.Publish("test", Message{ID: "msg-1"}) if err == nil { t.Error("expected error publishing to closed bus") } } func TestSubscriptionID(t *testing.T) { b := New() defer b.Close() sub1, _ := b.Subscribe("a", func(msg Message) {}) defer sub1.Unsubscribe() sub2, _ := b.Subscribe("b", func(msg Message) {}) defer sub2.Unsubscribe() if sub1.ID() == sub2.ID() { t.Error("subscription IDs should be unique") } if sub1.Topic() != "a" || sub2.Topic() != "b" { t.Error("topic mismatch") } } func TestConcurrentPublish(t *testing.T) { b := New() defer b.Close() var received int32 var wg sync.WaitGroup wg.Add(100) sub, err := b.Subscribe("concurrent", func(msg Message) { atomic.AddInt32(&received, 1) wg.Done() }) if err != nil { t.Fatalf("Subscribe failed: %v", err) } defer sub.Unsubscribe() for i := 0; i < 100; i++ { go func(i int) { b.Publish("concurrent", Message{ ID: time.Now().String(), Type: MsgTypeSystem, }) }(i) } wg.Wait() if n := atomic.LoadInt32(&received); n != 100 { t.Errorf("expected 100 messages, got %d", n) } } func TestDifferentTopics(t *testing.T) { b := New() defer b.Close() var topics []string var mu sync.Mutex sub1, _ := b.Subscribe("topic-a", func(msg Message) { mu.Lock() topics = append(topics, "a") mu.Unlock() }) defer sub1.Unsubscribe() sub2, _ := b.Subscribe("topic-b", func(msg Message) { mu.Lock() topics = append(topics, "b") mu.Unlock() }) defer sub2.Unsubscribe() b.Publish("topic-a", Message{ID: "msg-1"}) time.Sleep(50 * time.Millisecond) if len(topics) != 1 || topics[0] != "a" { t.Errorf("expected only topic-a to receive message, got %v", topics) } } func TestCloseIdempotent(t *testing.T) { b := New() err1 := b.Close() err2 := b.Close() if err1 != nil { t.Fatalf("first Close failed: %v", err1) } 
if err2 != nil { t.Fatalf("second Close should be idempotent: %v", err2) } } func TestMessageTypeString(t *testing.T) { tests := []struct { mt MessageType want string }{ {MsgTypeSystem, "system"}, {MsgTypeTaskRequest, "task_request"}, {MsgTypeTaskResponse, "task_response"}, {MsgTypeToolCall, "tool_call"}, {MsgTypeToolResult, "tool_result"}, {MsgTypeObservation, "observation"}, {MsgTypeError, "error"}, {MsgTypeLog, "log"}, {MessageType(99), "unknown"}, } for _, tt := range tests { if got := tt.mt.String(); got != tt.want { t.Errorf("MessageType(%d).String() = %q, want %q", tt.mt, got, tt.want) } } }