Browse Source
- AIProvider interface (Chat + ChatStream) - OpenAI-compatible adapter (covers OpenAI, Qwen, Zhipu, DeepSeek) - Anthropic adapter (Claude models) - Factory pattern for provider creation - Unified types (ChatRequest, ChatResponse, StreamChunk)master
5 changed files with 436 additions and 0 deletions
@ -0,0 +1,216 @@ |
|||||
|
package provider |
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"fmt" |
||||
|
"strings" |
||||
|
|
||||
|
"github.com/anthropics/anthropic-sdk-go" |
||||
|
"github.com/anthropics/anthropic-sdk-go/option" |
||||
|
) |
||||
|
|
||||
|
// AnthropicProvider implements AIProvider for Anthropic Claude models.
|
||||
|
type AnthropicProvider struct { |
||||
|
client anthropic.Client |
||||
|
name string |
||||
|
} |
||||
|
|
||||
|
// NewAnthropicProvider creates a new Anthropic provider.
|
||||
|
// baseUrl can be empty to use the default Anthropic endpoint.
|
||||
|
func NewAnthropicProvider(baseUrl, apiKey string) *AnthropicProvider { |
||||
|
opts := []option.RequestOption{ |
||||
|
option.WithAPIKey(apiKey), |
||||
|
} |
||||
|
if baseUrl != "" { |
||||
|
opts = append(opts, option.WithBaseURL(baseUrl)) |
||||
|
} |
||||
|
client := anthropic.NewClient(opts...) |
||||
|
return &AnthropicProvider{ |
||||
|
client: client, |
||||
|
name: "anthropic", |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Name returns the provider name.
|
||||
|
func (p *AnthropicProvider) Name() string { |
||||
|
return p.name |
||||
|
} |
||||
|
|
||||
|
// Chat sends a synchronous message request to Anthropic.
|
||||
|
func (p *AnthropicProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { |
||||
|
systemPrompt, messages := convertToAnthropicMessages(req.Messages) |
||||
|
|
||||
|
maxTokens := int64(req.MaxTokens) |
||||
|
if maxTokens <= 0 { |
||||
|
maxTokens = 4096 |
||||
|
} |
||||
|
|
||||
|
params := anthropic.MessageNewParams{ |
||||
|
Model: anthropic.Model(req.Model), |
||||
|
Messages: messages, |
||||
|
MaxTokens: maxTokens, |
||||
|
} |
||||
|
|
||||
|
if req.Temperature > 0 { |
||||
|
params.Temperature = anthropic.Float(req.Temperature) |
||||
|
} |
||||
|
|
||||
|
if len(systemPrompt) > 0 { |
||||
|
params.System = systemPrompt |
||||
|
} |
||||
|
|
||||
|
resp, err := p.client.Messages.New(ctx, params) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("anthropic message creation failed: %w", err) |
||||
|
} |
||||
|
|
||||
|
// Extract text content from response blocks
|
||||
|
content := extractTextContent(resp.Content) |
||||
|
|
||||
|
return &ChatResponse{ |
||||
|
Content: content, |
||||
|
Model: string(resp.Model), |
||||
|
InputTokens: int(resp.Usage.InputTokens), |
||||
|
OutputTokens: int(resp.Usage.OutputTokens), |
||||
|
FinishReason: string(resp.StopReason), |
||||
|
}, nil |
||||
|
} |
||||
|
|
||||
|
// ChatStream sends a streaming message request to Anthropic. It returns a
|
||||
|
// channel that delivers StreamChunk values. The channel is closed when
|
||||
|
// the stream ends.
|
||||
|
func (p *AnthropicProvider) ChatStream(ctx context.Context, req *ChatRequest) (<-chan *StreamChunk, error) { |
||||
|
systemPrompt, messages := convertToAnthropicMessages(req.Messages) |
||||
|
|
||||
|
maxTokens := int64(req.MaxTokens) |
||||
|
if maxTokens <= 0 { |
||||
|
maxTokens = 4096 |
||||
|
} |
||||
|
|
||||
|
params := anthropic.MessageNewParams{ |
||||
|
Model: anthropic.Model(req.Model), |
||||
|
Messages: messages, |
||||
|
MaxTokens: maxTokens, |
||||
|
} |
||||
|
|
||||
|
if req.Temperature > 0 { |
||||
|
params.Temperature = anthropic.Float(req.Temperature) |
||||
|
} |
||||
|
|
||||
|
if len(systemPrompt) > 0 { |
||||
|
params.System = systemPrompt |
||||
|
} |
||||
|
|
||||
|
stream := p.client.Messages.NewStreaming(ctx, params) |
||||
|
|
||||
|
ch := make(chan *StreamChunk, 64) |
||||
|
|
||||
|
go func() { |
||||
|
defer close(ch) |
||||
|
defer stream.Close() |
||||
|
|
||||
|
for stream.Next() { |
||||
|
event := stream.Current() |
||||
|
|
||||
|
switch event.Type { |
||||
|
case "content_block_delta": |
||||
|
// Text delta — the main content streaming event
|
||||
|
delta := event.Delta |
||||
|
if delta.Text != "" { |
||||
|
select { |
||||
|
case ch <- &StreamChunk{Content: delta.Text}: |
||||
|
case <-ctx.Done(): |
||||
|
return |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
case "message_delta": |
||||
|
// Message delta carries the stop reason and final usage
|
||||
|
chunk := &StreamChunk{} |
||||
|
if event.Delta.StopReason != "" { |
||||
|
chunk.FinishReason = string(event.Delta.StopReason) |
||||
|
} |
||||
|
if event.Usage.OutputTokens > 0 { |
||||
|
chunk.InputTokens = int(event.Usage.InputTokens) |
||||
|
chunk.OutputTokens = int(event.Usage.OutputTokens) |
||||
|
} |
||||
|
select { |
||||
|
case ch <- chunk: |
||||
|
case <-ctx.Done(): |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
case "message_stop": |
||||
|
// Stream is done
|
||||
|
select { |
||||
|
case ch <- &StreamChunk{Done: true}: |
||||
|
case <-ctx.Done(): |
||||
|
} |
||||
|
return |
||||
|
|
||||
|
// content_block_start, content_block_stop, message_start: no action needed
|
||||
|
default: |
||||
|
continue |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Check for stream errors
|
||||
|
if err := stream.Err(); err != nil { |
||||
|
select { |
||||
|
case ch <- &StreamChunk{ |
||||
|
Content: fmt.Sprintf("[stream error: %v]", err), |
||||
|
Done: true, |
||||
|
FinishReason: "error", |
||||
|
}: |
||||
|
case <-ctx.Done(): |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
// If we exit the loop without message_stop, still signal done
|
||||
|
select { |
||||
|
case ch <- &StreamChunk{Done: true}: |
||||
|
case <-ctx.Done(): |
||||
|
} |
||||
|
}() |
||||
|
|
||||
|
return ch, nil |
||||
|
} |
||||
|
|
||||
|
// convertToAnthropicMessages separates system messages and converts the rest
|
||||
|
// to Anthropic MessageParam format. Anthropic does not support a "system" role
|
||||
|
// in messages; instead, system prompts are passed as a separate field.
|
||||
|
func convertToAnthropicMessages(messages []ChatMessage) ([]anthropic.TextBlockParam, []anthropic.MessageParam) { |
||||
|
var systemBlocks []anthropic.TextBlockParam |
||||
|
var result []anthropic.MessageParam |
||||
|
|
||||
|
for _, msg := range messages { |
||||
|
switch msg.Role { |
||||
|
case "system": |
||||
|
systemBlocks = append(systemBlocks, anthropic.TextBlockParam{ |
||||
|
Text: msg.Content, |
||||
|
}) |
||||
|
case "user": |
||||
|
result = append(result, anthropic.NewUserMessage( |
||||
|
anthropic.NewTextBlock(msg.Content), |
||||
|
)) |
||||
|
case "assistant": |
||||
|
result = append(result, anthropic.NewAssistantMessage( |
||||
|
anthropic.NewTextBlock(msg.Content), |
||||
|
)) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
return systemBlocks, result |
||||
|
} |
||||
|
|
||||
|
// extractTextContent concatenates all text blocks from an Anthropic response.
|
||||
|
func extractTextContent(blocks []anthropic.ContentBlockUnion) string { |
||||
|
var parts []string |
||||
|
for _, block := range blocks { |
||||
|
if block.Type == "text" { |
||||
|
parts = append(parts, block.Text) |
||||
|
} |
||||
|
} |
||||
|
return strings.Join(parts, "") |
||||
|
} |
||||
@ -0,0 +1,18 @@ |
|||||
|
package provider |
||||
|
|
||||
|
import "fmt" |
||||
|
|
||||
|
// NewProvider creates a provider based on SDK type.
|
||||
|
// Supported sdkType values:
|
||||
|
// - "openai_compat": OpenAI-compatible APIs (OpenAI, Qwen, Zhipu, DeepSeek, etc.)
|
||||
|
// - "anthropic": Anthropic Claude models
|
||||
|
func NewProvider(sdkType, baseUrl, apiKey string) (AIProvider, error) { |
||||
|
switch sdkType { |
||||
|
case "openai_compat": |
||||
|
return NewOpenAIProvider(baseUrl, apiKey), nil |
||||
|
case "anthropic": |
||||
|
return NewAnthropicProvider(baseUrl, apiKey), nil |
||||
|
default: |
||||
|
return nil, fmt.Errorf("unsupported sdk_type: %s", sdkType) |
||||
|
} |
||||
|
} |
||||
@ -0,0 +1,155 @@ |
|||||
|
package provider |
||||
|
|
||||
|
import ( |
||||
|
"context" |
||||
|
"errors" |
||||
|
"fmt" |
||||
|
"io" |
||||
|
|
||||
|
openai "github.com/sashabaranov/go-openai" |
||||
|
) |
||||
|
|
||||
|
// OpenAIProvider implements AIProvider for OpenAI-compatible APIs
|
||||
|
// (OpenAI, Qwen, Zhipu, DeepSeek, etc.)
|
||||
|
type OpenAIProvider struct { |
||||
|
client *openai.Client |
||||
|
name string |
||||
|
} |
||||
|
|
||||
|
// NewOpenAIProvider creates a new OpenAI-compatible provider.
|
||||
|
// baseUrl can be empty to use the default OpenAI endpoint.
|
||||
|
func NewOpenAIProvider(baseUrl, apiKey string) *OpenAIProvider { |
||||
|
config := openai.DefaultConfig(apiKey) |
||||
|
if baseUrl != "" { |
||||
|
config.BaseURL = baseUrl |
||||
|
} |
||||
|
return &OpenAIProvider{ |
||||
|
client: openai.NewClientWithConfig(config), |
||||
|
name: "openai_compat", |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Name returns the provider name.
|
||||
|
func (p *OpenAIProvider) Name() string { |
||||
|
return p.name |
||||
|
} |
||||
|
|
||||
|
// Chat sends a synchronous chat completion request.
|
||||
|
func (p *OpenAIProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { |
||||
|
messages := convertToOpenAIMessages(req.Messages) |
||||
|
|
||||
|
openaiReq := openai.ChatCompletionRequest{ |
||||
|
Model: req.Model, |
||||
|
Messages: messages, |
||||
|
MaxTokens: req.MaxTokens, |
||||
|
Temperature: float32(req.Temperature), |
||||
|
} |
||||
|
|
||||
|
resp, err := p.client.CreateChatCompletion(ctx, openaiReq) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("openai chat completion failed: %w", err) |
||||
|
} |
||||
|
|
||||
|
if len(resp.Choices) == 0 { |
||||
|
return nil, fmt.Errorf("openai chat completion returned no choices") |
||||
|
} |
||||
|
|
||||
|
return &ChatResponse{ |
||||
|
Content: resp.Choices[0].Message.Content, |
||||
|
Model: resp.Model, |
||||
|
InputTokens: resp.Usage.PromptTokens, |
||||
|
OutputTokens: resp.Usage.CompletionTokens, |
||||
|
FinishReason: string(resp.Choices[0].FinishReason), |
||||
|
}, nil |
||||
|
} |
||||
|
|
||||
|
// ChatStream sends a streaming chat completion request. It returns a channel
|
||||
|
// that delivers StreamChunk values. The channel is closed when the stream
|
||||
|
// ends or an error occurs. The final chunk has Done=true and may include
|
||||
|
// token usage if the API provides it.
|
||||
|
func (p *OpenAIProvider) ChatStream(ctx context.Context, req *ChatRequest) (<-chan *StreamChunk, error) { |
||||
|
messages := convertToOpenAIMessages(req.Messages) |
||||
|
|
||||
|
openaiReq := openai.ChatCompletionRequest{ |
||||
|
Model: req.Model, |
||||
|
Messages: messages, |
||||
|
MaxTokens: req.MaxTokens, |
||||
|
Temperature: float32(req.Temperature), |
||||
|
Stream: true, |
||||
|
StreamOptions: &openai.StreamOptions{ |
||||
|
IncludeUsage: true, |
||||
|
}, |
||||
|
} |
||||
|
|
||||
|
stream, err := p.client.CreateChatCompletionStream(ctx, openaiReq) |
||||
|
if err != nil { |
||||
|
return nil, fmt.Errorf("openai stream creation failed: %w", err) |
||||
|
} |
||||
|
|
||||
|
ch := make(chan *StreamChunk, 64) |
||||
|
|
||||
|
go func() { |
||||
|
defer close(ch) |
||||
|
defer stream.Close() |
||||
|
|
||||
|
for { |
||||
|
response, err := stream.Recv() |
||||
|
if errors.Is(err, io.EOF) { |
||||
|
// Stream finished normally
|
||||
|
select { |
||||
|
case ch <- &StreamChunk{Done: true}: |
||||
|
case <-ctx.Done(): |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
if err != nil { |
||||
|
// Send error indication as the final chunk
|
||||
|
select { |
||||
|
case ch <- &StreamChunk{ |
||||
|
Content: fmt.Sprintf("[stream error: %v]", err), |
||||
|
Done: true, |
||||
|
FinishReason: "error", |
||||
|
}: |
||||
|
case <-ctx.Done(): |
||||
|
} |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
chunk := &StreamChunk{} |
||||
|
|
||||
|
// Extract content delta from choices
|
||||
|
if len(response.Choices) > 0 { |
||||
|
chunk.Content = response.Choices[0].Delta.Content |
||||
|
if response.Choices[0].FinishReason != "" { |
||||
|
chunk.FinishReason = string(response.Choices[0].FinishReason) |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// Extract usage from the final usage chunk (when StreamOptions.IncludeUsage is true)
|
||||
|
if response.Usage != nil { |
||||
|
chunk.InputTokens = response.Usage.PromptTokens |
||||
|
chunk.OutputTokens = response.Usage.CompletionTokens |
||||
|
} |
||||
|
|
||||
|
select { |
||||
|
case ch <- chunk: |
||||
|
case <-ctx.Done(): |
||||
|
return |
||||
|
} |
||||
|
} |
||||
|
}() |
||||
|
|
||||
|
return ch, nil |
||||
|
} |
||||
|
|
||||
|
// convertToOpenAIMessages converts our unified ChatMessage slice to OpenAI format.
|
||||
|
func convertToOpenAIMessages(messages []ChatMessage) []openai.ChatCompletionMessage { |
||||
|
result := make([]openai.ChatCompletionMessage, 0, len(messages)) |
||||
|
for _, msg := range messages { |
||||
|
result = append(result, openai.ChatCompletionMessage{ |
||||
|
Role: msg.Role, |
||||
|
Content: msg.Content, |
||||
|
}) |
||||
|
} |
||||
|
return result |
||||
|
} |
||||
@ -0,0 +1,13 @@ |
|||||
|
package provider |
||||
|
|
||||
|
import "context" |
||||
|
|
||||
|
// AIProvider defines the interface for AI model providers
|
||||
|
type AIProvider interface { |
||||
|
// Chat sends a synchronous chat request
|
||||
|
Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) |
||||
|
// ChatStream sends a streaming chat request, returning chunks via channel
|
||||
|
ChatStream(ctx context.Context, req *ChatRequest) (<-chan *StreamChunk, error) |
||||
|
// Name returns the provider name
|
||||
|
Name() string |
||||
|
} |
||||
@ -0,0 +1,34 @@ |
|||||
|
package provider |
||||
|
|
||||
|
// ChatMessage represents a single message in a conversation
|
||||
|
type ChatMessage struct { |
||||
|
Role string `json:"role"` // user, assistant, system
|
||||
|
Content string `json:"content"` |
||||
|
} |
||||
|
|
||||
|
// ChatRequest is the unified request format
|
||||
|
type ChatRequest struct { |
||||
|
Model string `json:"model"` |
||||
|
Messages []ChatMessage `json:"messages"` |
||||
|
MaxTokens int `json:"max_tokens,omitempty"` |
||||
|
Temperature float64 `json:"temperature,omitempty"` |
||||
|
Stream bool `json:"stream"` |
||||
|
} |
||||
|
|
||||
|
// ChatResponse is the unified non-streaming response
|
||||
|
type ChatResponse struct { |
||||
|
Content string `json:"content"` |
||||
|
Model string `json:"model"` |
||||
|
InputTokens int `json:"input_tokens"` |
||||
|
OutputTokens int `json:"output_tokens"` |
||||
|
FinishReason string `json:"finish_reason"` |
||||
|
} |
||||
|
|
||||
|
// StreamChunk represents a single SSE chunk
|
||||
|
type StreamChunk struct { |
||||
|
Content string `json:"content,omitempty"` |
||||
|
FinishReason string `json:"finish_reason,omitempty"` |
||||
|
InputTokens int `json:"input_tokens,omitempty"` |
||||
|
OutputTokens int `json:"output_tokens,omitempty"` |
||||
|
Done bool `json:"done"` |
||||
|
} |
||||
Loading…
Reference in new issue