You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
216 lines
5.2 KiB
216 lines
5.2 KiB
package provider
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/anthropics/anthropic-sdk-go"
|
|
"github.com/anthropics/anthropic-sdk-go/option"
|
|
)
|
|
|
|
// AnthropicProvider implements AIProvider for Anthropic Claude models.
type AnthropicProvider struct {
	client anthropic.Client // underlying Anthropic SDK client, configured in NewAnthropicProvider
	name   string           // provider identifier; set to "anthropic" by NewAnthropicProvider
}
|
|
|
|
// NewAnthropicProvider creates a new Anthropic provider.
|
|
// baseUrl can be empty to use the default Anthropic endpoint.
|
|
func NewAnthropicProvider(baseUrl, apiKey string) *AnthropicProvider {
|
|
opts := []option.RequestOption{
|
|
option.WithAPIKey(apiKey),
|
|
}
|
|
if baseUrl != "" {
|
|
opts = append(opts, option.WithBaseURL(baseUrl))
|
|
}
|
|
client := anthropic.NewClient(opts...)
|
|
return &AnthropicProvider{
|
|
client: client,
|
|
name: "anthropic",
|
|
}
|
|
}
|
|
|
|
// Name returns the provider name (the string assigned by NewAnthropicProvider).
func (p *AnthropicProvider) Name() string {
	return p.name
}
|
|
|
|
// Chat sends a synchronous message request to Anthropic.
|
|
func (p *AnthropicProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
|
systemPrompt, messages := convertToAnthropicMessages(req.Messages)
|
|
|
|
maxTokens := int64(req.MaxTokens)
|
|
if maxTokens <= 0 {
|
|
maxTokens = 4096
|
|
}
|
|
|
|
params := anthropic.MessageNewParams{
|
|
Model: anthropic.Model(req.Model),
|
|
Messages: messages,
|
|
MaxTokens: maxTokens,
|
|
}
|
|
|
|
if req.Temperature > 0 {
|
|
params.Temperature = anthropic.Float(req.Temperature)
|
|
}
|
|
|
|
if len(systemPrompt) > 0 {
|
|
params.System = systemPrompt
|
|
}
|
|
|
|
resp, err := p.client.Messages.New(ctx, params)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("anthropic message creation failed: %w", err)
|
|
}
|
|
|
|
// Extract text content from response blocks
|
|
content := extractTextContent(resp.Content)
|
|
|
|
return &ChatResponse{
|
|
Content: content,
|
|
Model: string(resp.Model),
|
|
InputTokens: int(resp.Usage.InputTokens),
|
|
OutputTokens: int(resp.Usage.OutputTokens),
|
|
FinishReason: string(resp.StopReason),
|
|
}, nil
|
|
}
|
|
|
|
// ChatStream sends a streaming message request to Anthropic. It returns a
// channel that delivers StreamChunk values. The channel is closed when
// the stream ends. Errors that occur mid-stream are delivered in-band as a
// final chunk with FinishReason "error" rather than through a Go error,
// since callers only observe the channel after this function returns.
func (p *AnthropicProvider) ChatStream(ctx context.Context, req *ChatRequest) (<-chan *StreamChunk, error) {
	systemPrompt, messages := convertToAnthropicMessages(req.Messages)

	// Anthropic requires max_tokens; default to 4096 when the caller
	// leaves it unset or non-positive (mirrors Chat).
	maxTokens := int64(req.MaxTokens)
	if maxTokens <= 0 {
		maxTokens = 4096
	}

	params := anthropic.MessageNewParams{
		Model:     anthropic.Model(req.Model),
		Messages:  messages,
		MaxTokens: maxTokens,
	}

	// Zero temperature is treated as "unset" and omitted so the API
	// default applies (same convention as Chat).
	if req.Temperature > 0 {
		params.Temperature = anthropic.Float(req.Temperature)
	}

	if len(systemPrompt) > 0 {
		params.System = systemPrompt
	}

	stream := p.client.Messages.NewStreaming(ctx, params)

	// Buffered so the forwarding goroutine rarely blocks on a slow consumer.
	ch := make(chan *StreamChunk, 64)

	// The goroutine owns both the SDK stream and the output channel: it is
	// the only sender, so it alone closes ch, and it always closes the
	// stream on exit. Every send is paired with a ctx.Done() case so a
	// cancelled consumer cannot leak this goroutine.
	go func() {
		defer close(ch)
		defer stream.Close()

		for stream.Next() {
			event := stream.Current()

			switch event.Type {
			case "content_block_delta":
				// Text delta — the main content streaming event.
				// Empty deltas are skipped to avoid no-op chunks.
				delta := event.Delta
				if delta.Text != "" {
					select {
					case ch <- &StreamChunk{Content: delta.Text}:
					case <-ctx.Done():
						return
					}
				}

			case "message_delta":
				// Message delta carries the stop reason and final usage.
				// Either field may be absent, so each is copied only
				// when present; the chunk is sent regardless.
				chunk := &StreamChunk{}
				if event.Delta.StopReason != "" {
					chunk.FinishReason = string(event.Delta.StopReason)
				}
				if event.Usage.OutputTokens > 0 {
					chunk.InputTokens = int(event.Usage.InputTokens)
					chunk.OutputTokens = int(event.Usage.OutputTokens)
				}
				select {
				case ch <- chunk:
				case <-ctx.Done():
					return
				}

			case "message_stop":
				// Stream is done: emit the terminal Done chunk (best
				// effort under cancellation) and stop reading.
				select {
				case ch <- &StreamChunk{Done: true}:
				case <-ctx.Done():
				}
				return

			// content_block_start, content_block_stop, message_start: no action needed
			default:
				continue
			}
		}

		// Check for stream errors. Surface them in-band as a final chunk —
		// the caller has no other way to observe a failure at this point.
		if err := stream.Err(); err != nil {
			select {
			case ch <- &StreamChunk{
				Content:      fmt.Sprintf("[stream error: %v]", err),
				Done:         true,
				FinishReason: "error",
			}:
			case <-ctx.Done():
			}
			return
		}

		// If we exit the loop without message_stop, still signal done so
		// consumers ranging over the channel see a terminal chunk.
		select {
		case ch <- &StreamChunk{Done: true}:
		case <-ctx.Done():
		}
	}()

	return ch, nil
}
|
|
|
|
// convertToAnthropicMessages separates system messages and converts the rest
|
|
// to Anthropic MessageParam format. Anthropic does not support a "system" role
|
|
// in messages; instead, system prompts are passed as a separate field.
|
|
func convertToAnthropicMessages(messages []ChatMessage) ([]anthropic.TextBlockParam, []anthropic.MessageParam) {
|
|
var systemBlocks []anthropic.TextBlockParam
|
|
var result []anthropic.MessageParam
|
|
|
|
for _, msg := range messages {
|
|
switch msg.Role {
|
|
case "system":
|
|
systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
|
|
Text: msg.Content,
|
|
})
|
|
case "user":
|
|
result = append(result, anthropic.NewUserMessage(
|
|
anthropic.NewTextBlock(msg.Content),
|
|
))
|
|
case "assistant":
|
|
result = append(result, anthropic.NewAssistantMessage(
|
|
anthropic.NewTextBlock(msg.Content),
|
|
))
|
|
}
|
|
}
|
|
|
|
return systemBlocks, result
|
|
}
|
|
|
|
// extractTextContent concatenates all text blocks from an Anthropic response.
|
|
func extractTextContent(blocks []anthropic.ContentBlockUnion) string {
|
|
var parts []string
|
|
for _, block := range blocks {
|
|
if block.Type == "text" {
|
|
parts = append(parts, block.Text)
|
|
}
|
|
}
|
|
return strings.Join(parts, "")
|
|
}
|
|
|