2026-05-12 00:09:01 +08:00

470 lines
10 KiB
Go

package websocket
import (
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/orca/orca/pkg/actor"
"github.com/orca/orca/pkg/bus"
"github.com/orca/orca/pkg/kernel"
)
// upgrader turns /ws HTTP requests into WebSocket connections using 1 KiB
// read and write buffers.
//
// NOTE(review): CheckOrigin accepts every Origin header, so any web page a
// user has open can open a WebSocket to this server from their browser
// (cross-site WebSocket hijacking). Consider restricting it to the UI's
// origin(s).
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin:     func(*http.Request) bool { return true },
}
type (
	// Server bridges the kernel to browser clients: it accepts WebSocket
	// connections, answers a small JSON HTTP API and relays kernel bus
	// events to every connected client.
	Server struct {
		kernel    *kernel.Kernel     // source of stats, agents, sessions and chat replies
		port      int                // TCP port to listen on
		clients   map[string]*Client // registered connections keyed by Client.ID
		clientsMu sync.RWMutex       // guards clients
		counter   int                // last number used to build a "client-<n>" ID
	}

	// Client is one WebSocket connection. Outbound frames are queued on
	// Send and written to Conn by the client's write pump.
	Client struct {
		ID     string          // assigned on connect, e.g. "client-1"
		Conn   *websocket.Conn // underlying connection
		Send   chan []byte     // outbound queue drained by writePump
		Server *Server         // owning server
	}
)
type (
	// Message is the JSON envelope for every frame exchanged over the
	// WebSocket. Type selects the meaning; which other fields are populated
	// depends on it (e.g. "chat" carries Message from the client, "token"
	// carries Text, "stats" carries Stats, "agents" carries Agents).
	//
	// Stats is a struct value, so encoding/json ignores its omitempty
	// option: every encoded Message includes a "stats" object, zeroed when
	// unused.
	Message struct {
		Type    string      `json:"type"`
		Content string      `json:"content,omitempty"`
		Message string      `json:"message,omitempty"`
		Text    string      `json:"text,omitempty"`
		Agent   string      `json:"agent,omitempty"`
		Stats   Stats       `json:"stats,omitempty"`
		Agents  []AgentInfo `json:"agents,omitempty"`
	}

	// Stats counts the tools, skills and agents currently available.
	Stats struct {
		Tools  int `json:"tools"`
		Skills int `json:"skills"`
		Agents int `json:"agents"`
	}

	// AgentInfo pairs an agent ID with a simplified status string
	// ("idle" or "running").
	AgentInfo struct {
		ID     string `json:"id"`
		Status string `json:"status"`
	}
)
// NewServer builds a Server for kernel k that will listen on port. When the
// kernel exposes a message bus, the server subscribes to the "agent_events"
// topic so agent activity is relayed to connected WebSocket clients.
func NewServer(k *kernel.Kernel, port int) *Server {
	srv := &Server{
		kernel:  k,
		port:    port,
		clients: map[string]*Client{},
	}

	mb := k.Bus()
	if mb == nil {
		return srv
	}
	mb.Subscribe("agent_events", srv.broadcastAgentEvent)
	return srv
}
// Start registers the WebSocket endpoint (/ws), the JSON API (/api/stats,
// /api/agents, /api/sessions, /api/sessions/{id}) and, when web/dist exists
// relative to the working directory, a single-page-app file server. It then
// listens on s.port on all interfaces. It blocks until the server stops and
// returns the resulting error.
func (s *Server) Start() error {
	// ServeMux dispatches to the most specific matching pattern, so
	// registration order does not matter: "/" only receives paths that no
	// other pattern claims.
	mux := http.NewServeMux()
	mux.HandleFunc("/ws", s.handleWebSocket)
	mux.HandleFunc("/api/stats", s.handleStats)
	mux.HandleFunc("/api/agents", s.handleAgents)
	mux.HandleFunc("/api/sessions", s.handleSessions)
	mux.HandleFunc("/api/sessions/", s.handleSessionMessages)

	// Serve the React build only if it has been built.
	webDir := filepath.Join("web", "dist")
	if _, err := os.Stat(webDir); err == nil {
		mux.Handle("/", spaFileHandler(webDir))
	}

	addr := fmt.Sprintf(":%d", s.port)
	srv := &http.Server{
		Addr:    addr,
		Handler: mux,
		// Bound how long a client may take to send request headers, so idle
		// half-open connections cannot pile up (slowloris). ReadTimeout and
		// WriteTimeout are intentionally not set here. NOTE(review): if you
		// add them, check they do not interfere with long-lived /ws
		// connections.
		ReadHeaderTimeout: 10 * time.Second,
	}
	log.Printf("WebSocket server starting on http://localhost%s", addr)
	return srv.ListenAndServe()
}

// spaFileHandler serves static files from webDir. For "/" and for any path
// that does not exist on disk it serves webDir/index.html instead, so
// client-side routes of the single-page app resolve. Unregistered /api/
// paths and /ws get a 404 rather than the app shell.
func spaFileHandler(webDir string) http.Handler {
	fileServer := http.FileServer(http.Dir(webDir))
	index := filepath.Join(webDir, "index.html")
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if strings.HasPrefix(r.URL.Path, "/api/") || r.URL.Path == "/ws" {
			http.NotFound(w, r)
			return
		}
		if r.URL.Path == "/" {
			http.ServeFile(w, r, index)
			return
		}
		if _, err := os.Stat(filepath.Join(webDir, r.URL.Path)); os.IsNotExist(err) {
			http.ServeFile(w, r, index)
			return
		}
		fileServer.ServeHTTP(w, r)
	})
}
// handleWebSocket upgrades a /ws request to a WebSocket connection. It
// registers the new Client under a "client-<n>" ID and pushes fresh stats
// and agent lists to every connected client. It then starts the client's
// write and read pumps.
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket upgrade error: %v", err)
		return
	}

	client := &Client{
		Conn:   conn,
		Send:   make(chan []byte, 256),
		Server: s,
	}

	// HTTP handlers run concurrently, so the ID counter is advanced under
	// the same lock that guards the clients map. An unsynchronized
	// s.counter++ is a data race and can hand out the same ID twice; the
	// second registration would then overwrite the first client's entry.
	s.clientsMu.Lock()
	s.counter++
	client.ID = fmt.Sprintf("client-%d", s.counter)
	s.clients[client.ID] = client
	s.clientsMu.Unlock()

	// Send initial stats and agents (queued on Send until writePump runs).
	s.broadcastStats()
	s.broadcastAgents()

	go client.writePump()
	go client.readPump()
}
// readPump owns the read side of the connection.
//
// It enforces a 512 KiB message limit and a 60-second read deadline, which
// is pushed forward whenever a pong arrives. Each incoming frame is decoded
// as a Message; "chat" messages are handed to handleChat on their own
// goroutine. Frames that are not valid JSON, and any other message type,
// are ignored.
//
// When reading stops (peer closed, deadline expired, or any read error), the
// client is unregistered and the connection is closed.
func (c *Client) readPump() {
	defer func() {
		c.Server.removeClient(c)
		c.Conn.Close()
	}()

	const readWindow = 60 * time.Second
	extendDeadline := func() {
		c.Conn.SetReadDeadline(time.Now().Add(readWindow))
	}

	c.Conn.SetReadLimit(512 * 1024)
	extendDeadline()
	c.Conn.SetPongHandler(func(string) error {
		extendDeadline()
		return nil
	})

	for {
		_, raw, err := c.Conn.ReadMessage()
		if err != nil {
			// Normal "going away" / abnormal closures are expected; log the rest.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("WebSocket error: %v", err)
			}
			return
		}

		var incoming Message
		if json.Unmarshal(raw, &incoming) != nil {
			continue
		}
		if incoming.Type != "chat" {
			continue
		}
		go c.handleChat(incoming.Message)
	}
}
// handleChat forwards one user chat message to the kernel (sender "user",
// recipient "llm") and reports the outcome to this client. On success it
// sends a "complete" message with the reply; on failure it sends an "error"
// message carrying the error text. Before sending, it installs a wsWriter
// for this client with SetStreamWriter.
//
// NOTE(review): SetStreamWriter is called on the kernel shared by all
// clients. If it replaces a single kernel-wide writer, concurrent chats from
// different clients would receive each other's streamed tokens. Confirm
// against the kernel API.
func (c *Client) handleChat(userMessage string) {
	k := c.Server.kernel
	k.SetStreamWriter(&wsWriter{client: c})

	reply, err := k.SendMessage("user", "llm", userMessage)
	outcome := Message{Type: "complete", Content: reply}
	if err != nil {
		outcome = Message{Type: "error", Content: err.Error()}
	}
	c.sendJSON(outcome)
}
// writePump owns the write side of the connection. In this file, all writes
// to c.Conn happen here.
//
// It writes each payload queued on c.Send as a text frame and sends a ping
// every 54 seconds, inside readPump's 60-second read window. Each write gets
// a 10-second deadline.
//
// It exits when Send is closed (after sending a close frame) or when any
// write fails. On exit it closes the connection.
func (c *Client) writePump() {
	ticker := time.NewTicker(54 * time.Second)
	defer func() {
		ticker.Stop()
		c.Conn.Close()
	}()

	for {
		select {
		case message, ok := <-c.Send:
			c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
			if !ok {
				// removeClient closed Send: say goodbye to the peer and stop.
				c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			if err := c.Conn.WriteMessage(websocket.TextMessage, message); err != nil {
				// The connection is broken or the write timed out; further
				// writes would fail too. Closing the conn in the deferred
				// func should also make readPump's read fail, which
				// unregisters the client.
				return
			}
		case <-ticker.C:
			c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
			if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}
// sendJSON marshals v and queues it for this client only.
//
// The send is non-blocking: if the Send buffer is full, the frame is
// dropped. It is also skipped when the client has already been unregistered.
// handleChat runs on its own goroutine and can finish after readPump has
// called removeClient, which closes Send. Sending on a closed channel panics
// even inside a select, so membership is checked first.
func (c *Client) sendJSON(v interface{}) {
	data, err := json.Marshal(v)
	if err != nil {
		log.Printf("websocket: marshal message for %s: %v", c.ID, err)
		return
	}

	// removeClient deletes the map entry under the write lock and only then
	// closes Send. While we hold the read lock and still see c registered,
	// Send cannot be closed.
	s := c.Server
	s.clientsMu.RLock()
	defer s.clientsMu.RUnlock()
	if s.clients[c.ID] != c {
		return
	}

	select {
	case c.Send <- data:
	default:
		// Channel full, drop message.
	}
}
// removeClient unregisters c and closes its Send channel, which makes
// writePump send a close frame and return.
//
// The map entry is deleted under the write lock before Send is closed.
// Because of that ordering, a sender that checks membership while holding
// clientsMu (as broadcast does) never sends on the closed channel.
//
// It must be called at most once per client; a second call would close Send
// again and panic.
func (s *Server) removeClient(c *Client) {
	func() {
		s.clientsMu.Lock()
		defer s.clientsMu.Unlock()
		delete(s.clients, c.ID)
	}()
	close(c.Send)
}
// broadcastJSON marshals v and fans it out to every connected client via
// broadcast, which shares the same non-blocking, drop-when-full delivery.
// Values that cannot be marshaled are logged and dropped.
func (s *Server) broadcastJSON(v interface{}) {
	data, err := json.Marshal(v)
	if err != nil {
		log.Printf("websocket: marshal broadcast payload: %v", err)
		return
	}
	s.broadcast(data)
}
// broadcastAgentEvent relays an "agent_events" bus message to every
// connected client. msg.Content must be a map[string]interface{} with an
// "event" key; any other content is ignored.
//
//   - "token" events become {"type":"agent_token","agent":…,"content":…}
//     and are dropped unless both "text" and "agent" are non-empty strings.
//   - Any other event is forwarded with the event name as Type, "agent" as
//     Agent, "task" (when it is a string) as Message and a non-empty
//     "result" string as Content.
func (s *Server) broadcastAgentEvent(msg bus.Message) {
	content, ok := msg.Content.(map[string]interface{})
	if !ok {
		return
	}
	eventType, _ := content["event"].(string)
	// Comma-ok assertion: a missing or non-string "agent" yields "" instead
	// of panicking inside the bus subscription callback.
	agent, _ := content["agent"].(string)

	var out Message
	switch eventType {
	case "token":
		text, _ := content["text"].(string)
		if text == "" || agent == "" {
			return
		}
		out = Message{Type: "agent_token", Agent: agent, Content: text}
	default:
		out = Message{Type: eventType, Agent: agent}
		if task, ok := content["task"].(string); ok {
			out.Message = task
		}
		if result, ok := content["result"].(string); ok && result != "" {
			out.Content = result
		}
	}

	data, err := json.Marshal(out)
	if err != nil {
		log.Printf("websocket: marshal agent event %q: %v", eventType, err)
		return
	}
	s.broadcast(data)
}
// broadcast queues an already-encoded frame for every registered client.
// Delivery is non-blocking: a client whose Send buffer is full misses this
// frame rather than stalling the caller. The read lock is held for the whole
// loop. removeClient closes Send only after deleting the client under the
// write lock, so no channel seen here can be closed mid-broadcast.
func (s *Server) broadcast(data []byte) {
	s.clientsMu.RLock()
	defer s.clientsMu.RUnlock()

	for id := range s.clients {
		c := s.clients[id]
		select {
		case c.Send <- data:
		default: // buffer full: drop for this client
		}
	}
}
// snapshotStats collects the counts served by /api/stats and pushed in
// "stats" WebSocket messages. A kernel subsystem that is unavailable (nil)
// contributes zero.
func (s *Server) snapshotStats() Stats {
	var stats Stats
	if tm := s.kernel.ToolManager(); tm != nil {
		stats.Tools = tm.Count()
	}
	if sm := s.kernel.SkillManager(); sm != nil {
		stats.Skills = len(sm.ListSkills())
	}
	if as := s.kernel.ActorSystem(); as != nil {
		stats.Agents = as.AgentCount()
	}
	return stats
}

// snapshotAgents lists every agent in the actor system with a simplified
// status: "running" when it is in actor.StatusProcessing, "idle" otherwise.
// It always returns a non-nil slice so that JSON encodes "no agents" as []
// rather than null.
func (s *Server) snapshotAgents() []AgentInfo {
	agents := []AgentInfo{}
	as := s.kernel.ActorSystem()
	if as == nil {
		return agents
	}
	for _, info := range as.AgentInfos() {
		status := "idle"
		if info.Status == actor.StatusProcessing {
			status = "running"
		}
		agents = append(agents, AgentInfo{ID: info.ID, Status: status})
	}
	return agents
}

// broadcastStats pushes the current counts to every client as a "stats"
// message.
func (s *Server) broadcastStats() {
	s.broadcastJSON(Message{Type: "stats", Stats: s.snapshotStats()})
}

// broadcastAgents pushes the current agent list to every client as an
// "agents" message. With no agents, the "agents" field is omitted.
func (s *Server) broadcastAgents() {
	s.broadcastJSON(Message{Type: "agents", Agents: s.snapshotAgents()})
}

// handleStats serves /api/stats as a JSON Stats object.
func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(s.snapshotStats()); err != nil {
		// Headers are already sent; all we can do is record it.
		log.Printf("encoding /api/stats response: %v", err)
	}
}

// handleAgents serves /api/agents as a JSON array of AgentInfo ([] when
// there are none).
func (s *Server) handleAgents(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(s.snapshotAgents()); err != nil {
		log.Printf("encoding /api/agents response: %v", err)
	}
}
type (
	// SessionInfo summarizes one stored session in the /api/sessions list.
	SessionInfo struct {
		ID           string `json:"id"`
		MessageCount int    `json:"message_count"`
		CreatedAt    string `json:"created_at"` // formatted as RFC 3339 by the handler
	}

	// SessionMessage is one message of a session as returned by
	// /api/sessions/{id}.
	SessionMessage struct {
		Role      string `json:"role"`
		Content   string `json:"content"`
		Timestamp string `json:"timestamp"` // formatted as RFC 3339 by the handler
	}
)
// handleSessions serves /api/sessions as a JSON array with one SessionInfo
// per session known to the kernel's session manager.
//
// The response is always an array: [] both when there is no session manager
// and when it has no sessions. Previously the second case encoded null.
// Sessions that fail to load are skipped. A failure to list sessions is
// reported as 500.
func (s *Server) handleSessions(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	result := []SessionInfo{}
	if sessionMgr := s.kernel.SessionManager(); sessionMgr != nil {
		sessionIDs, err := sessionMgr.ListSessions()
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		for _, id := range sessionIDs {
			session, err := sessionMgr.GetSession(id)
			if err != nil {
				continue // best effort: one unreadable session should not hide the rest
			}
			result = append(result, SessionInfo{
				ID:           session.ID,
				MessageCount: len(session.Messages),
				CreatedAt:    session.CreatedAt.Format(time.RFC3339),
			})
		}
	}

	if err := json.NewEncoder(w).Encode(result); err != nil {
		log.Printf("encoding /api/sessions response: %v", err)
	}
}
// handleSessionMessages serves /api/sessions/{id}: the messages of one
// session as a JSON array of SessionMessage (role, content, RFC 3339
// timestamp).
//
// It responds 400 when the ID is missing and 404 when the session manager
// cannot load the session. The body is always an array: [] both when there
// is no session manager and when the session has no messages. Previously the
// second case encoded null.
func (s *Server) handleSessionMessages(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Content-Type", "application/json")

	// The mux routes every path below /api/sessions/ here, so the rest of
	// the path is the session ID.
	// NOTE(review): the ID is client-controlled and passed straight to
	// GetSession. If the session manager maps IDs to file names, validate
	// it (e.g. reject "/" and "..") before this call.
	sessionID := strings.TrimPrefix(r.URL.Path, "/api/sessions/")
	if sessionID == "" {
		http.Error(w, "session ID required", http.StatusBadRequest)
		return
	}

	result := []SessionMessage{}
	if sessionMgr := s.kernel.SessionManager(); sessionMgr != nil {
		session, err := sessionMgr.GetSession(sessionID)
		if err != nil {
			http.Error(w, err.Error(), http.StatusNotFound)
			return
		}
		result = make([]SessionMessage, 0, len(session.Messages))
		for _, msg := range session.Messages {
			result = append(result, SessionMessage{
				Role:      string(msg.Role),
				Content:   msg.Content,
				Timestamp: msg.Timestamp.Format(time.RFC3339),
			})
		}
	}

	if err := json.NewEncoder(w).Encode(result); err != nil {
		log.Printf("encoding /api/sessions/%s response: %v", sessionID, err)
	}
}
// wsWriter implements kernel.StreamWriter by forwarding each chunk of
// streamed output to a single WebSocket client as a {"type":"token"} frame.
type wsWriter struct {
	client *Client
	mu     sync.Mutex // serializes Writes so chunks are queued in call order
}

// Write queues p as one token frame for the client. Empty writes are
// ignored.
//
// The frame is dropped, not retried, when the client's Send buffer is full
// or the client has already been unregistered. handleChat runs on its own
// goroutine, so streaming can continue after readPump has called
// removeClient and closed Send. Sending on that closed channel would panic,
// hence the membership check.
//
// Write always reports len(p), nil, so the caller sees no error for dropped
// frames.
func (w *wsWriter) Write(p []byte) (n int, err error) {
	if len(p) == 0 {
		return 0, nil
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	// Message holds only strings, ints and slices of such, so Marshal
	// cannot fail here.
	data, _ := json.Marshal(Message{Type: "token", Text: string(p)})

	// removeClient deletes the map entry under the write lock before it
	// closes Send. While we hold the read lock and still see the client
	// registered, Send is open.
	c := w.client
	c.Server.clientsMu.RLock()
	defer c.Server.clientsMu.RUnlock()
	if c.Server.clients[c.ID] != c {
		return len(p), nil
	}

	select {
	case c.Send <- data:
	default:
		// Buffer full: drop this chunk.
	}
	return len(p), nil
}

// Flush is a no-op: every Write already queues its frame immediately.
func (w *wsWriter) Flush() error {
	return nil
}