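// Package agent implements the agent service: it runs model generations for a
// session, executes the tool calls the model requests, tracks token usage and
// cost, and can summarize a session into a new one.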
package agent
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/opencode-ai/opencode/internal/config"
|
|
"github.com/opencode-ai/opencode/internal/llm/models"
|
|
"github.com/opencode-ai/opencode/internal/llm/prompt"
|
|
"github.com/opencode-ai/opencode/internal/llm/provider"
|
|
"github.com/opencode-ai/opencode/internal/llm/tools"
|
|
"github.com/opencode-ai/opencode/internal/logging"
|
|
"github.com/opencode-ai/opencode/internal/message"
|
|
"github.com/opencode-ai/opencode/internal/permission"
|
|
"github.com/opencode-ai/opencode/internal/pubsub"
|
|
"github.com/opencode-ai/opencode/internal/session"
|
|
)
|
|
|
|
// Sentinel errors reported by the agent service.

// ErrRequestCancelled is reported in the resulting AgentEvent when a running
// generation is interrupted by cancellation.
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned by Run and Summarize when the session already has
// an active request.
var ErrSessionBusy = errors.New("session is currently processing another request")
|
|
|
|
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

// Event kinds emitted by the agent.
const (
	// AgentEventTypeError marks an event that carries an error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks an event that carries the final assistant message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a progress or completion event of a summarization.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
|
|
|
|
type AgentEvent struct {
|
|
Type AgentEventType
|
|
Message message.Message
|
|
Error error
|
|
|
|
// When summarizing
|
|
SessionID string
|
|
Progress string
|
|
Done bool
|
|
}
|
|
|
|
type Service interface {
|
|
pubsub.Suscriber[AgentEvent]
|
|
Model() models.Model
|
|
Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
|
|
Cancel(sessionID string)
|
|
IsSessionBusy(sessionID string) bool
|
|
IsBusy() bool
|
|
Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error)
|
|
Summarize(ctx context.Context, sessionID string) error
|
|
}
|
|
|
|
type agent struct {
|
|
*pubsub.Broker[AgentEvent]
|
|
sessions session.Service
|
|
messages message.Service
|
|
|
|
tools []tools.BaseTool
|
|
provider provider.Provider
|
|
|
|
titleProvider provider.Provider
|
|
summarizeProvider provider.Provider
|
|
|
|
activeRequests sync.Map
|
|
}
|
|
|
|
func NewAgent(
|
|
agentName config.AgentName,
|
|
sessions session.Service,
|
|
messages message.Service,
|
|
agentTools []tools.BaseTool,
|
|
) (Service, error) {
|
|
agentProvider, err := createAgentProvider(agentName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var titleProvider provider.Provider
|
|
// Only generate titles for the coder agent
|
|
if agentName == config.AgentCoder {
|
|
titleProvider, err = createAgentProvider(config.AgentTitle)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
var summarizeProvider provider.Provider
|
|
if agentName == config.AgentCoder {
|
|
summarizeProvider, err = createAgentProvider(config.AgentSummarizer)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
agent := &agent{
|
|
Broker: pubsub.NewBroker[AgentEvent](),
|
|
provider: agentProvider,
|
|
messages: messages,
|
|
sessions: sessions,
|
|
tools: agentTools,
|
|
titleProvider: titleProvider,
|
|
summarizeProvider: summarizeProvider,
|
|
activeRequests: sync.Map{},
|
|
}
|
|
|
|
return agent, nil
|
|
}
|
|
|
|
func (a *agent) Model() models.Model {
|
|
return a.provider.Model()
|
|
}
|
|
|
|
func (a *agent) Cancel(sessionID string) {
|
|
// Cancel regular requests
|
|
if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
|
|
if cancel, ok := cancelFunc.(context.CancelFunc); ok {
|
|
logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
|
|
cancel()
|
|
}
|
|
}
|
|
|
|
// Also check for summarize requests
|
|
if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
|
|
if cancel, ok := cancelFunc.(context.CancelFunc); ok {
|
|
logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
|
|
cancel()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *agent) IsBusy() bool {
|
|
busy := false
|
|
a.activeRequests.Range(func(key, value interface{}) bool {
|
|
if cancelFunc, ok := value.(context.CancelFunc); ok {
|
|
if cancelFunc != nil {
|
|
busy = true
|
|
return false // Stop iterating
|
|
}
|
|
}
|
|
return true // Continue iterating
|
|
})
|
|
return busy
|
|
}
|
|
|
|
func (a *agent) IsSessionBusy(sessionID string) bool {
|
|
_, busy := a.activeRequests.Load(sessionID)
|
|
return busy
|
|
}
|
|
|
|
func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
|
|
if content == "" {
|
|
return nil
|
|
}
|
|
if a.titleProvider == nil {
|
|
return nil
|
|
}
|
|
session, err := a.sessions.Get(ctx, sessionID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
parts := []message.ContentPart{message.TextContent{Text: content}}
|
|
response, err := a.titleProvider.SendMessages(
|
|
ctx,
|
|
[]message.Message{
|
|
{
|
|
Role: message.User,
|
|
Parts: parts,
|
|
},
|
|
},
|
|
make([]tools.BaseTool, 0),
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
title := strings.TrimSpace(strings.ReplaceAll(response.Content, "\n", " "))
|
|
if title == "" {
|
|
return nil
|
|
}
|
|
|
|
session.Title = title
|
|
_, err = a.sessions.Save(ctx, session)
|
|
return err
|
|
}
|
|
|
|
func (a *agent) err(err error) AgentEvent {
|
|
return AgentEvent{
|
|
Type: AgentEventTypeError,
|
|
Error: err,
|
|
}
|
|
}
|
|
|
|
func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
|
|
if !a.provider.Model().SupportsAttachments && attachments != nil {
|
|
attachments = nil
|
|
}
|
|
events := make(chan AgentEvent)
|
|
if a.IsSessionBusy(sessionID) {
|
|
return nil, ErrSessionBusy
|
|
}
|
|
|
|
genCtx, cancel := context.WithCancel(ctx)
|
|
|
|
a.activeRequests.Store(sessionID, cancel)
|
|
go func() {
|
|
logging.Debug("Request started", "sessionID", sessionID)
|
|
defer logging.RecoverPanic("agent.Run", func() {
|
|
events <- a.err(fmt.Errorf("panic while running the agent"))
|
|
})
|
|
var attachmentParts []message.ContentPart
|
|
for _, attachment := range attachments {
|
|
attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
|
|
}
|
|
result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
|
|
if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
|
|
logging.ErrorPersist(result.Error.Error())
|
|
}
|
|
logging.Debug("Request completed", "sessionID", sessionID)
|
|
a.activeRequests.Delete(sessionID)
|
|
cancel()
|
|
a.Publish(pubsub.CreatedEvent, result)
|
|
events <- result
|
|
close(events)
|
|
}()
|
|
return events, nil
|
|
}
|
|
|
|
func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
|
|
// List existing messages; if none, start title generation asynchronously.
|
|
msgs, err := a.messages.List(ctx, sessionID)
|
|
if err != nil {
|
|
return a.err(fmt.Errorf("failed to list messages: %w", err))
|
|
}
|
|
if len(msgs) == 0 {
|
|
go func() {
|
|
defer logging.RecoverPanic("agent.Run", func() {
|
|
logging.ErrorPersist("panic while generating title")
|
|
})
|
|
titleErr := a.generateTitle(context.Background(), sessionID, content)
|
|
if titleErr != nil {
|
|
logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
|
|
}
|
|
}()
|
|
}
|
|
|
|
userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
|
|
if err != nil {
|
|
return a.err(fmt.Errorf("failed to create user message: %w", err))
|
|
}
|
|
// Append the new user message to the conversation history.
|
|
msgHistory := append(msgs, userMsg)
|
|
|
|
for {
|
|
// Check for cancellation before each iteration
|
|
select {
|
|
case <-ctx.Done():
|
|
return a.err(ctx.Err())
|
|
default:
|
|
// Continue processing
|
|
}
|
|
agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
|
|
if err != nil {
|
|
if errors.Is(err, context.Canceled) {
|
|
agentMessage.AddFinish(message.FinishReasonCanceled)
|
|
a.messages.Update(context.Background(), agentMessage)
|
|
return a.err(ErrRequestCancelled)
|
|
}
|
|
return a.err(fmt.Errorf("failed to process events: %w", err))
|
|
}
|
|
logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
|
|
if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
|
|
// We are not done, we need to respond with the tool response
|
|
msgHistory = append(msgHistory, agentMessage, *toolResults)
|
|
continue
|
|
}
|
|
return AgentEvent{
|
|
Type: AgentEventTypeResponse,
|
|
Message: agentMessage,
|
|
Done: true,
|
|
}
|
|
}
|
|
}
|
|
|
|
func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
|
|
parts := []message.ContentPart{message.TextContent{Text: content}}
|
|
parts = append(parts, attachmentParts...)
|
|
return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
|
|
Role: message.User,
|
|
Parts: parts,
|
|
})
|
|
}
|
|
|
|
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
|
|
eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
|
|
|
|
assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
|
|
Role: message.Assistant,
|
|
Parts: []message.ContentPart{},
|
|
Model: a.provider.Model().ID,
|
|
})
|
|
if err != nil {
|
|
return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
|
|
}
|
|
|
|
// Add the session and message ID into the context if needed by tools.
|
|
ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
|
|
ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
|
|
|
|
// Process each event in the stream.
|
|
for event := range eventChan {
|
|
if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
|
|
a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
|
|
return assistantMsg, nil, processErr
|
|
}
|
|
if ctx.Err() != nil {
|
|
a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
|
|
return assistantMsg, nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
|
|
toolCalls := assistantMsg.ToolCalls()
|
|
for i, toolCall := range toolCalls {
|
|
select {
|
|
case <-ctx.Done():
|
|
a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
|
|
// Make all future tool calls cancelled
|
|
for j := i; j < len(toolCalls); j++ {
|
|
toolResults[j] = message.ToolResult{
|
|
ToolCallID: toolCalls[j].ID,
|
|
Content: "Tool execution canceled by user",
|
|
IsError: true,
|
|
}
|
|
}
|
|
goto out
|
|
default:
|
|
// Continue processing
|
|
var tool tools.BaseTool
|
|
for _, availableTools := range a.tools {
|
|
if availableTools.Info().Name == toolCall.Name {
|
|
tool = availableTools
|
|
}
|
|
}
|
|
|
|
// Tool not found
|
|
if tool == nil {
|
|
toolResults[i] = message.ToolResult{
|
|
ToolCallID: toolCall.ID,
|
|
Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
|
|
IsError: true,
|
|
}
|
|
continue
|
|
}
|
|
toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
|
|
ID: toolCall.ID,
|
|
Name: toolCall.Name,
|
|
Input: toolCall.Input,
|
|
})
|
|
if toolErr != nil {
|
|
if errors.Is(toolErr, permission.ErrorPermissionDenied) {
|
|
toolResults[i] = message.ToolResult{
|
|
ToolCallID: toolCall.ID,
|
|
Content: "Permission denied",
|
|
IsError: true,
|
|
}
|
|
for j := i + 1; j < len(toolCalls); j++ {
|
|
toolResults[j] = message.ToolResult{
|
|
ToolCallID: toolCalls[j].ID,
|
|
Content: "Tool execution canceled by user",
|
|
IsError: true,
|
|
}
|
|
}
|
|
a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
|
|
break
|
|
}
|
|
}
|
|
toolResults[i] = message.ToolResult{
|
|
ToolCallID: toolCall.ID,
|
|
Content: toolResult.Content,
|
|
Metadata: toolResult.Metadata,
|
|
IsError: toolResult.IsError,
|
|
}
|
|
}
|
|
}
|
|
out:
|
|
if len(toolResults) == 0 {
|
|
return assistantMsg, nil, nil
|
|
}
|
|
parts := make([]message.ContentPart, 0)
|
|
for _, tr := range toolResults {
|
|
parts = append(parts, tr)
|
|
}
|
|
msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
|
|
Role: message.Tool,
|
|
Parts: parts,
|
|
})
|
|
if err != nil {
|
|
return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
|
|
}
|
|
|
|
return assistantMsg, &msg, err
|
|
}
|
|
|
|
func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
|
|
msg.AddFinish(finishReson)
|
|
_ = a.messages.Update(ctx, *msg)
|
|
}
|
|
|
|
func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
default:
|
|
// Continue processing.
|
|
}
|
|
|
|
switch event.Type {
|
|
case provider.EventThinkingDelta:
|
|
assistantMsg.AppendReasoningContent(event.Content)
|
|
return a.messages.Update(ctx, *assistantMsg)
|
|
case provider.EventContentDelta:
|
|
assistantMsg.AppendContent(event.Content)
|
|
return a.messages.Update(ctx, *assistantMsg)
|
|
case provider.EventToolUseStart:
|
|
assistantMsg.AddToolCall(*event.ToolCall)
|
|
return a.messages.Update(ctx, *assistantMsg)
|
|
// TODO: see how to handle this
|
|
// case provider.EventToolUseDelta:
|
|
// tm := time.Unix(assistantMsg.UpdatedAt, 0)
|
|
// assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
|
|
// if time.Since(tm) > 1000*time.Millisecond {
|
|
// err := a.messages.Update(ctx, *assistantMsg)
|
|
// assistantMsg.UpdatedAt = time.Now().Unix()
|
|
// return err
|
|
// }
|
|
case provider.EventToolUseStop:
|
|
assistantMsg.FinishToolCall(event.ToolCall.ID)
|
|
return a.messages.Update(ctx, *assistantMsg)
|
|
case provider.EventError:
|
|
if errors.Is(event.Error, context.Canceled) {
|
|
logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
|
|
return context.Canceled
|
|
}
|
|
logging.ErrorPersist(event.Error.Error())
|
|
return event.Error
|
|
case provider.EventComplete:
|
|
assistantMsg.SetToolCalls(event.Response.ToolCalls)
|
|
assistantMsg.AddFinish(event.Response.FinishReason)
|
|
if err := a.messages.Update(ctx, *assistantMsg); err != nil {
|
|
return fmt.Errorf("failed to update message: %w", err)
|
|
}
|
|
return a.TrackUsage(ctx, sessionID, a.provider.Model(), event.Response.Usage)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
|
|
sess, err := a.sessions.Get(ctx, sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get session: %w", err)
|
|
}
|
|
|
|
cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
|
|
model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
|
|
model.CostPer1MIn/1e6*float64(usage.InputTokens) +
|
|
model.CostPer1MOut/1e6*float64(usage.OutputTokens)
|
|
|
|
sess.Cost += cost
|
|
sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
|
|
sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
|
|
|
|
_, err = a.sessions.Save(ctx, sess)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to save session: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a *agent) Update(agentName config.AgentName, modelID models.ModelID) (models.Model, error) {
|
|
if a.IsBusy() {
|
|
return models.Model{}, fmt.Errorf("cannot change model while processing requests")
|
|
}
|
|
|
|
if err := config.UpdateAgentModel(agentName, modelID); err != nil {
|
|
return models.Model{}, fmt.Errorf("failed to update config: %w", err)
|
|
}
|
|
|
|
provider, err := createAgentProvider(agentName)
|
|
if err != nil {
|
|
return models.Model{}, fmt.Errorf("failed to create provider for model %s: %w", modelID, err)
|
|
}
|
|
|
|
a.provider = provider
|
|
|
|
return a.provider.Model(), nil
|
|
}
|
|
|
|
func (a *agent) Summarize(ctx context.Context, sessionID string) error {
|
|
if a.summarizeProvider == nil {
|
|
return fmt.Errorf("summarize provider not available")
|
|
}
|
|
|
|
// Check if session is busy
|
|
if a.IsSessionBusy(sessionID) {
|
|
return ErrSessionBusy
|
|
}
|
|
|
|
// Create a new context with cancellation
|
|
summarizeCtx, cancel := context.WithCancel(ctx)
|
|
|
|
// Store the cancel function in activeRequests to allow cancellation
|
|
a.activeRequests.Store(sessionID+"-summarize", cancel)
|
|
|
|
go func() {
|
|
defer a.activeRequests.Delete(sessionID + "-summarize")
|
|
defer cancel()
|
|
event := AgentEvent{
|
|
Type: AgentEventTypeSummarize,
|
|
Progress: "Starting summarization...",
|
|
}
|
|
|
|
a.Publish(pubsub.CreatedEvent, event)
|
|
// Get all messages from the session
|
|
msgs, err := a.messages.List(summarizeCtx, sessionID)
|
|
if err != nil {
|
|
event = AgentEvent{
|
|
Type: AgentEventTypeError,
|
|
Error: fmt.Errorf("failed to list messages: %w", err),
|
|
Done: true,
|
|
}
|
|
a.Publish(pubsub.CreatedEvent, event)
|
|
return
|
|
}
|
|
|
|
if len(msgs) == 0 {
|
|
event = AgentEvent{
|
|
Type: AgentEventTypeError,
|
|
Error: fmt.Errorf("no messages to summarize"),
|
|
Done: true,
|
|
}
|
|
a.Publish(pubsub.CreatedEvent, event)
|
|
return
|
|
}
|
|
|
|
event = AgentEvent{
|
|
Type: AgentEventTypeSummarize,
|
|
Progress: "Analyzing conversation...",
|
|
}
|
|
a.Publish(pubsub.CreatedEvent, event)
|
|
|
|
// Add a system message to guide the summarization
|
|
summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
|
|
|
|
// Create a new message with the summarize prompt
|
|
promptMsg := message.Message{
|
|
Role: message.User,
|
|
Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
|
|
}
|
|
|
|
// Append the prompt to the messages
|
|
msgsWithPrompt := append(msgs, promptMsg)
|
|
|
|
event = AgentEvent{
|
|
Type: AgentEventTypeSummarize,
|
|
Progress: "Generating summary...",
|
|
}
|
|
|
|
a.Publish(pubsub.CreatedEvent, event)
|
|
|
|
// Send the messages to the summarize provider
|
|
response, err := a.summarizeProvider.SendMessages(
|
|
summarizeCtx,
|
|
msgsWithPrompt,
|
|
make([]tools.BaseTool, 0),
|
|
)
|
|
if err != nil {
|
|
event = AgentEvent{
|
|
Type: AgentEventTypeError,
|
|
Error: fmt.Errorf("failed to summarize: %w", err),
|
|
Done: true,
|
|
}
|
|
a.Publish(pubsub.CreatedEvent, event)
|
|
return
|
|
}
|
|
|
|
summary := strings.TrimSpace(response.Content)
|
|
if summary == "" {
|
|
event = AgentEvent{
|
|
Type: AgentEventTypeError,
|
|
Error: fmt.Errorf("empty summary returned"),
|
|
Done: true,
|
|
}
|
|
a.Publish(pubsub.CreatedEvent, event)
|
|
return
|
|
}
|
|
event = AgentEvent{
|
|
Type: AgentEventTypeSummarize,
|
|
Progress: "Creating new session...",
|
|
}
|
|
|
|
a.Publish(pubsub.CreatedEvent, event)
|
|
oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
|
|
if err != nil {
|
|
event = AgentEvent{
|
|
Type: AgentEventTypeError,
|
|
Error: fmt.Errorf("failed to get session: %w", err),
|
|
Done: true,
|
|
}
|
|
|
|
a.Publish(pubsub.CreatedEvent, event)
|
|
return
|
|
}
|
|
// Create a new session with the summary
|
|
newSession, err := a.sessions.Create(summarizeCtx, oldSession.Title+" - Continuation")
|
|
if err != nil {
|
|
event = AgentEvent{
|
|
Type: AgentEventTypeError,
|
|
Error: fmt.Errorf("failed to create new session: %w", err),
|
|
Done: true,
|
|
}
|
|
a.Publish(pubsub.CreatedEvent, event)
|
|
return
|
|
}
|
|
|
|
// Create a message in the new session with the summary
|
|
_, err = a.messages.Create(summarizeCtx, newSession.ID, message.CreateMessageParams{
|
|
Role: message.Assistant,
|
|
Parts: []message.ContentPart{message.TextContent{Text: summary}},
|
|
Model: a.summarizeProvider.Model().ID,
|
|
})
|
|
if err != nil {
|
|
event = AgentEvent{
|
|
Type: AgentEventTypeError,
|
|
Error: fmt.Errorf("failed to create summary message: %w", err),
|
|
Done: true,
|
|
}
|
|
|
|
a.Publish(pubsub.CreatedEvent, event)
|
|
return
|
|
}
|
|
event = AgentEvent{
|
|
Type: AgentEventTypeSummarize,
|
|
SessionID: newSession.ID,
|
|
Progress: "Summary complete",
|
|
Done: true,
|
|
}
|
|
a.Publish(pubsub.CreatedEvent, event)
|
|
// Send final success event with the new session ID
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func createAgentProvider(agentName config.AgentName) (provider.Provider, error) {
|
|
cfg := config.Get()
|
|
agentConfig, ok := cfg.Agents[agentName]
|
|
if !ok {
|
|
return nil, fmt.Errorf("agent %s not found", agentName)
|
|
}
|
|
model, ok := models.SupportedModels[agentConfig.Model]
|
|
if !ok {
|
|
return nil, fmt.Errorf("model %s not supported", agentConfig.Model)
|
|
}
|
|
|
|
providerCfg, ok := cfg.Providers[model.Provider]
|
|
if !ok {
|
|
return nil, fmt.Errorf("provider %s not supported", model.Provider)
|
|
}
|
|
if providerCfg.Disabled {
|
|
return nil, fmt.Errorf("provider %s is not enabled", model.Provider)
|
|
}
|
|
maxTokens := model.DefaultMaxTokens
|
|
if agentConfig.MaxTokens > 0 {
|
|
maxTokens = agentConfig.MaxTokens
|
|
}
|
|
opts := []provider.ProviderClientOption{
|
|
provider.WithAPIKey(providerCfg.APIKey),
|
|
provider.WithModel(model),
|
|
provider.WithSystemMessage(prompt.GetAgentPrompt(agentName, model.Provider)),
|
|
provider.WithMaxTokens(maxTokens),
|
|
}
|
|
if model.Provider == models.ProviderOpenAI && model.CanReason {
|
|
opts = append(
|
|
opts,
|
|
provider.WithOpenAIOptions(
|
|
provider.WithReasoningEffort(agentConfig.ReasoningEffort),
|
|
),
|
|
)
|
|
} else if model.Provider == models.ProviderAnthropic && model.CanReason && agentName == config.AgentCoder {
|
|
opts = append(
|
|
opts,
|
|
provider.WithAnthropicOptions(
|
|
provider.WithAnthropicShouldThinkFn(provider.DefaultShouldThinkFn),
|
|
),
|
|
)
|
|
}
|
|
agentProvider, err := provider.NewProvider(
|
|
model.Provider,
|
|
opts...,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("could not create provider: %v", err)
|
|
}
|
|
|
|
return agentProvider, nil
|
|
}
|