470 lines
10 KiB
Go
470 lines
10 KiB
Go
package websocket
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"github.com/orca/orca/pkg/actor"
|
|
"github.com/orca/orca/pkg/bus"
|
|
"github.com/orca/orca/pkg/kernel"
|
|
)
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
return true
|
|
},
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
}
|
|
|
|
type Server struct {
|
|
kernel *kernel.Kernel
|
|
port int
|
|
clients map[string]*Client
|
|
clientsMu sync.RWMutex
|
|
counter int
|
|
}
|
|
|
|
type Client struct {
|
|
ID string
|
|
Conn *websocket.Conn
|
|
Send chan []byte
|
|
Server *Server
|
|
}
|
|
|
|
type (
	// Message is the JSON envelope exchanged over the WebSocket. Clients send
	// {"type":"chat","message":...}; the server sends "token", "complete",
	// "error", "stats", "agents", "agent_token" and relayed agent-event
	// messages, each filling only the fields relevant to its type.
	Message struct {
		Type    string `json:"type"`
		Content string `json:"content,omitempty"`
		Message string `json:"message,omitempty"`
		Text    string `json:"text,omitempty"`
		Agent   string `json:"agent,omitempty"`
		// encoding/json ignores omitempty on struct values, so every encoded
		// Message carries a "stats" object (all zeros when unset).
		Stats  Stats       `json:"stats,omitempty"`
		Agents []AgentInfo `json:"agents,omitempty"`
	}

	// Stats holds the counts reported by the "stats" message and /api/stats.
	Stats struct {
		Tools  int `json:"tools"`
		Skills int `json:"skills"`
		Agents int `json:"agents"`
	}

	// AgentInfo describes one agent; Status is "running" or "idle".
	AgentInfo struct {
		ID     string `json:"id"`
		Status string `json:"status"`
	}
)
|
|
|
|
func NewServer(k *kernel.Kernel, port int) *Server {
|
|
s := &Server{
|
|
kernel: k,
|
|
port: port,
|
|
clients: make(map[string]*Client),
|
|
}
|
|
|
|
if mb := k.Bus(); mb != nil {
|
|
mb.Subscribe("agent_events", func(msg bus.Message) {
|
|
s.broadcastAgentEvent(msg)
|
|
})
|
|
}
|
|
|
|
return s
|
|
}
|
|
|
|
func (s *Server) Start() error {
|
|
mux := http.NewServeMux()
|
|
|
|
// WebSocket endpoint
|
|
mux.HandleFunc("/ws", s.handleWebSocket)
|
|
|
|
// API endpoints - must be registered before static files
|
|
mux.HandleFunc("/api/stats", s.handleStats)
|
|
mux.HandleFunc("/api/agents", s.handleAgents)
|
|
mux.HandleFunc("/api/sessions", s.handleSessions)
|
|
mux.HandleFunc("/api/sessions/", s.handleSessionMessages)
|
|
|
|
// Static files - serve React build
|
|
webDir := filepath.Join("web", "dist")
|
|
if _, err := os.Stat(webDir); err == nil {
|
|
fs := http.FileServer(http.Dir(webDir))
|
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
// Skip API paths
|
|
if strings.HasPrefix(r.URL.Path, "/api/") || r.URL.Path == "/ws" {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
path := filepath.Join(webDir, r.URL.Path)
|
|
_, err := os.Stat(path)
|
|
if os.IsNotExist(err) || r.URL.Path == "/" {
|
|
http.ServeFile(w, r, filepath.Join(webDir, "index.html"))
|
|
return
|
|
}
|
|
fs.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
addr := fmt.Sprintf(":%d", s.port)
|
|
log.Printf("WebSocket server starting on http://localhost%s", addr)
|
|
return http.ListenAndServe(addr, mux)
|
|
}
|
|
|
|
func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) {
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Printf("WebSocket upgrade error: %v", err)
|
|
return
|
|
}
|
|
|
|
s.counter++
|
|
client := &Client{
|
|
ID: fmt.Sprintf("client-%d", s.counter),
|
|
Conn: conn,
|
|
Send: make(chan []byte, 256),
|
|
Server: s,
|
|
}
|
|
|
|
s.clientsMu.Lock()
|
|
s.clients[client.ID] = client
|
|
s.clientsMu.Unlock()
|
|
|
|
// Send initial stats and agents
|
|
s.broadcastStats()
|
|
s.broadcastAgents()
|
|
|
|
go client.writePump()
|
|
go client.readPump()
|
|
}
|
|
|
|
func (c *Client) readPump() {
|
|
defer func() {
|
|
c.Server.removeClient(c)
|
|
c.Conn.Close()
|
|
}()
|
|
|
|
c.Conn.SetReadLimit(512 * 1024)
|
|
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
c.Conn.SetPongHandler(func(string) error {
|
|
c.Conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
return nil
|
|
})
|
|
|
|
for {
|
|
_, message, err := c.Conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
log.Printf("WebSocket error: %v", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
var msg Message
|
|
if err := json.Unmarshal(message, &msg); err != nil {
|
|
continue
|
|
}
|
|
|
|
if msg.Type == "chat" {
|
|
go c.handleChat(msg.Message)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) handleChat(userMessage string) {
|
|
writer := &wsWriter{client: c}
|
|
c.Server.kernel.SetStreamWriter(writer)
|
|
|
|
resp, err := c.Server.kernel.SendMessage("user", "llm", userMessage)
|
|
if err != nil {
|
|
c.sendJSON(Message{Type: "error", Content: err.Error()})
|
|
return
|
|
}
|
|
|
|
c.sendJSON(Message{Type: "complete", Content: resp})
|
|
}
|
|
|
|
func (c *Client) writePump() {
|
|
ticker := time.NewTicker(54 * time.Second)
|
|
defer func() {
|
|
ticker.Stop()
|
|
c.Conn.Close()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case message, ok := <-c.Send:
|
|
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if !ok {
|
|
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
|
return
|
|
}
|
|
c.Conn.WriteMessage(websocket.TextMessage, message)
|
|
|
|
case <-ticker.C:
|
|
c.Conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) sendJSON(v interface{}) {
|
|
data, err := json.Marshal(v)
|
|
if err != nil {
|
|
return
|
|
}
|
|
select {
|
|
case c.Send <- data:
|
|
default:
|
|
// Channel full, drop message
|
|
}
|
|
}
|
|
|
|
func (s *Server) removeClient(c *Client) {
|
|
s.clientsMu.Lock()
|
|
delete(s.clients, c.ID)
|
|
s.clientsMu.Unlock()
|
|
close(c.Send)
|
|
}
|
|
|
|
func (s *Server) broadcastJSON(v interface{}) {
|
|
data, err := json.Marshal(v)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
s.clientsMu.RLock()
|
|
defer s.clientsMu.RUnlock()
|
|
|
|
for _, client := range s.clients {
|
|
select {
|
|
case client.Send <- data:
|
|
default:
|
|
// Channel full
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) broadcastAgentEvent(msg bus.Message) {
|
|
content, ok := msg.Content.(map[string]interface{})
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
eventType, _ := content["event"].(string)
|
|
|
|
switch eventType {
|
|
case "token":
|
|
text, _ := content["text"].(string)
|
|
agent, _ := content["agent"].(string)
|
|
if text == "" || agent == "" {
|
|
return
|
|
}
|
|
msg := Message{
|
|
Type: "agent_token",
|
|
Agent: agent,
|
|
Content: text,
|
|
}
|
|
data, _ := json.Marshal(msg)
|
|
s.broadcast(data)
|
|
default:
|
|
event := Message{
|
|
Type: eventType,
|
|
Agent: content["agent"].(string),
|
|
Message: "",
|
|
}
|
|
if task, ok := content["task"].(string); ok {
|
|
event.Message = task
|
|
}
|
|
if result, ok := content["result"].(string); ok && result != "" {
|
|
event.Content = result
|
|
}
|
|
data, err := json.Marshal(event)
|
|
if err != nil {
|
|
return
|
|
}
|
|
s.broadcast(data)
|
|
}
|
|
}
|
|
|
|
func (s *Server) broadcast(data []byte) {
|
|
s.clientsMu.RLock()
|
|
defer s.clientsMu.RUnlock()
|
|
|
|
for _, client := range s.clients {
|
|
select {
|
|
case client.Send <- data:
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) broadcastStats() {
|
|
stats := Stats{}
|
|
if tm := s.kernel.ToolManager(); tm != nil {
|
|
stats.Tools = tm.Count()
|
|
}
|
|
if sm := s.kernel.SkillManager(); sm != nil {
|
|
stats.Skills = len(sm.ListSkills())
|
|
}
|
|
if as := s.kernel.ActorSystem(); as != nil {
|
|
stats.Agents = as.AgentCount()
|
|
}
|
|
s.broadcastJSON(Message{Type: "stats", Stats: stats})
|
|
}
|
|
|
|
func (s *Server) broadcastAgents() {
|
|
var agents []AgentInfo
|
|
if as := s.kernel.ActorSystem(); as != nil {
|
|
for _, info := range as.AgentInfos() {
|
|
status := "idle"
|
|
if info.Status == actor.StatusProcessing {
|
|
status = "running"
|
|
}
|
|
agents = append(agents, AgentInfo{ID: info.ID, Status: status})
|
|
}
|
|
}
|
|
s.broadcastJSON(Message{Type: "agents", Agents: agents})
|
|
}
|
|
|
|
func (s *Server) handleStats(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
stats := Stats{}
|
|
if tm := s.kernel.ToolManager(); tm != nil {
|
|
stats.Tools = tm.Count()
|
|
}
|
|
if sm := s.kernel.SkillManager(); sm != nil {
|
|
stats.Skills = len(sm.ListSkills())
|
|
}
|
|
if as := s.kernel.ActorSystem(); as != nil {
|
|
stats.Agents = as.AgentCount()
|
|
}
|
|
json.NewEncoder(w).Encode(stats)
|
|
}
|
|
|
|
func (s *Server) handleAgents(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
var agents []AgentInfo
|
|
if as := s.kernel.ActorSystem(); as != nil {
|
|
for _, info := range as.AgentInfos() {
|
|
status := "idle"
|
|
if info.Status == actor.StatusProcessing {
|
|
status = "running"
|
|
}
|
|
agents = append(agents, AgentInfo{ID: info.ID, Status: status})
|
|
}
|
|
}
|
|
json.NewEncoder(w).Encode(agents)
|
|
}
|
|
|
|
type (
	// SessionInfo is one entry of the /api/sessions response.
	SessionInfo struct {
		ID           string `json:"id"`
		MessageCount int    `json:"message_count"`
		CreatedAt    string `json:"created_at"` // RFC 3339
	}

	// SessionMessage is one entry of the /api/sessions/{id} response.
	SessionMessage struct {
		Role      string `json:"role"`
		Content   string `json:"content"`
		Timestamp string `json:"timestamp"` // RFC 3339
	}
)
|
|
|
|
func (s *Server) handleSessions(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
sessionMgr := s.kernel.SessionManager()
|
|
if sessionMgr == nil {
|
|
json.NewEncoder(w).Encode([]SessionInfo{})
|
|
return
|
|
}
|
|
|
|
sessionIDs, err := sessionMgr.ListSessions()
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
var result []SessionInfo
|
|
for _, id := range sessionIDs {
|
|
session, err := sessionMgr.GetSession(id)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
result = append(result, SessionInfo{
|
|
ID: session.ID,
|
|
MessageCount: len(session.Messages),
|
|
CreatedAt: session.CreatedAt.Format(time.RFC3339),
|
|
})
|
|
}
|
|
|
|
json.NewEncoder(w).Encode(result)
|
|
}
|
|
|
|
func (s *Server) handleSessionMessages(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
|
|
path := strings.TrimPrefix(r.URL.Path, "/api/sessions/")
|
|
if path == "" || path == "/api/sessions" {
|
|
http.Error(w, "session ID required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
sessionMgr := s.kernel.SessionManager()
|
|
if sessionMgr == nil {
|
|
json.NewEncoder(w).Encode([]SessionMessage{})
|
|
return
|
|
}
|
|
|
|
session, err := sessionMgr.GetSession(path)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
var result []SessionMessage
|
|
for _, msg := range session.Messages {
|
|
result = append(result, SessionMessage{
|
|
Role: string(msg.Role),
|
|
Content: msg.Content,
|
|
Timestamp: msg.Timestamp.Format(time.RFC3339),
|
|
})
|
|
}
|
|
|
|
json.NewEncoder(w).Encode(result)
|
|
}
|
|
|
|
// wsWriter implements kernel.StreamWriter
|
|
type wsWriter struct {
|
|
client *Client
|
|
mu sync.Mutex
|
|
buf strings.Builder
|
|
}
|
|
|
|
func (w *wsWriter) Write(p []byte) (n int, err error) {
|
|
w.mu.Lock()
|
|
defer w.mu.Unlock()
|
|
|
|
w.buf.Write(p)
|
|
text := w.buf.String()
|
|
w.buf.Reset()
|
|
|
|
msg := Message{Type: "token", Text: text}
|
|
data, _ := json.Marshal(msg)
|
|
|
|
select {
|
|
case w.client.Send <- data:
|
|
default:
|
|
}
|
|
|
|
return len(p), nil
|
|
}
|
|
|
|
func (w *wsWriter) Flush() error {
|
|
return nil
|
|
}
|