Browse Source
- AIProvider interface (Chat + ChatStream) - OpenAI-compatible adapter (covers OpenAI, Qwen, Zhipu, DeepSeek) - Anthropic adapter (Claude models) - Factory pattern for provider creation - Unified types (ChatRequest, ChatResponse, StreamChunk)master
5 changed files with 436 additions and 0 deletions
@ -0,0 +1,216 @@ |
|||
package provider |
|||
|
|||
import ( |
|||
"context" |
|||
"fmt" |
|||
"strings" |
|||
|
|||
"github.com/anthropics/anthropic-sdk-go" |
|||
"github.com/anthropics/anthropic-sdk-go/option" |
|||
) |
|||
|
|||
// AnthropicProvider implements AIProvider for Anthropic Claude models.
|
|||
type AnthropicProvider struct { |
|||
client anthropic.Client |
|||
name string |
|||
} |
|||
|
|||
// NewAnthropicProvider creates a new Anthropic provider.
|
|||
// baseUrl can be empty to use the default Anthropic endpoint.
|
|||
func NewAnthropicProvider(baseUrl, apiKey string) *AnthropicProvider { |
|||
opts := []option.RequestOption{ |
|||
option.WithAPIKey(apiKey), |
|||
} |
|||
if baseUrl != "" { |
|||
opts = append(opts, option.WithBaseURL(baseUrl)) |
|||
} |
|||
client := anthropic.NewClient(opts...) |
|||
return &AnthropicProvider{ |
|||
client: client, |
|||
name: "anthropic", |
|||
} |
|||
} |
|||
|
|||
// Name returns the provider name.
|
|||
func (p *AnthropicProvider) Name() string { |
|||
return p.name |
|||
} |
|||
|
|||
// Chat sends a synchronous message request to Anthropic.
|
|||
func (p *AnthropicProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { |
|||
systemPrompt, messages := convertToAnthropicMessages(req.Messages) |
|||
|
|||
maxTokens := int64(req.MaxTokens) |
|||
if maxTokens <= 0 { |
|||
maxTokens = 4096 |
|||
} |
|||
|
|||
params := anthropic.MessageNewParams{ |
|||
Model: anthropic.Model(req.Model), |
|||
Messages: messages, |
|||
MaxTokens: maxTokens, |
|||
} |
|||
|
|||
if req.Temperature > 0 { |
|||
params.Temperature = anthropic.Float(req.Temperature) |
|||
} |
|||
|
|||
if len(systemPrompt) > 0 { |
|||
params.System = systemPrompt |
|||
} |
|||
|
|||
resp, err := p.client.Messages.New(ctx, params) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("anthropic message creation failed: %w", err) |
|||
} |
|||
|
|||
// Extract text content from response blocks
|
|||
content := extractTextContent(resp.Content) |
|||
|
|||
return &ChatResponse{ |
|||
Content: content, |
|||
Model: string(resp.Model), |
|||
InputTokens: int(resp.Usage.InputTokens), |
|||
OutputTokens: int(resp.Usage.OutputTokens), |
|||
FinishReason: string(resp.StopReason), |
|||
}, nil |
|||
} |
|||
|
|||
// ChatStream sends a streaming message request to Anthropic. It returns a
|
|||
// channel that delivers StreamChunk values. The channel is closed when
|
|||
// the stream ends.
|
|||
func (p *AnthropicProvider) ChatStream(ctx context.Context, req *ChatRequest) (<-chan *StreamChunk, error) { |
|||
systemPrompt, messages := convertToAnthropicMessages(req.Messages) |
|||
|
|||
maxTokens := int64(req.MaxTokens) |
|||
if maxTokens <= 0 { |
|||
maxTokens = 4096 |
|||
} |
|||
|
|||
params := anthropic.MessageNewParams{ |
|||
Model: anthropic.Model(req.Model), |
|||
Messages: messages, |
|||
MaxTokens: maxTokens, |
|||
} |
|||
|
|||
if req.Temperature > 0 { |
|||
params.Temperature = anthropic.Float(req.Temperature) |
|||
} |
|||
|
|||
if len(systemPrompt) > 0 { |
|||
params.System = systemPrompt |
|||
} |
|||
|
|||
stream := p.client.Messages.NewStreaming(ctx, params) |
|||
|
|||
ch := make(chan *StreamChunk, 64) |
|||
|
|||
go func() { |
|||
defer close(ch) |
|||
defer stream.Close() |
|||
|
|||
for stream.Next() { |
|||
event := stream.Current() |
|||
|
|||
switch event.Type { |
|||
case "content_block_delta": |
|||
// Text delta — the main content streaming event
|
|||
delta := event.Delta |
|||
if delta.Text != "" { |
|||
select { |
|||
case ch <- &StreamChunk{Content: delta.Text}: |
|||
case <-ctx.Done(): |
|||
return |
|||
} |
|||
} |
|||
|
|||
case "message_delta": |
|||
// Message delta carries the stop reason and final usage
|
|||
chunk := &StreamChunk{} |
|||
if event.Delta.StopReason != "" { |
|||
chunk.FinishReason = string(event.Delta.StopReason) |
|||
} |
|||
if event.Usage.OutputTokens > 0 { |
|||
chunk.InputTokens = int(event.Usage.InputTokens) |
|||
chunk.OutputTokens = int(event.Usage.OutputTokens) |
|||
} |
|||
select { |
|||
case ch <- chunk: |
|||
case <-ctx.Done(): |
|||
return |
|||
} |
|||
|
|||
case "message_stop": |
|||
// Stream is done
|
|||
select { |
|||
case ch <- &StreamChunk{Done: true}: |
|||
case <-ctx.Done(): |
|||
} |
|||
return |
|||
|
|||
// content_block_start, content_block_stop, message_start: no action needed
|
|||
default: |
|||
continue |
|||
} |
|||
} |
|||
|
|||
// Check for stream errors
|
|||
if err := stream.Err(); err != nil { |
|||
select { |
|||
case ch <- &StreamChunk{ |
|||
Content: fmt.Sprintf("[stream error: %v]", err), |
|||
Done: true, |
|||
FinishReason: "error", |
|||
}: |
|||
case <-ctx.Done(): |
|||
} |
|||
return |
|||
} |
|||
|
|||
// If we exit the loop without message_stop, still signal done
|
|||
select { |
|||
case ch <- &StreamChunk{Done: true}: |
|||
case <-ctx.Done(): |
|||
} |
|||
}() |
|||
|
|||
return ch, nil |
|||
} |
|||
|
|||
// convertToAnthropicMessages separates system messages and converts the rest
|
|||
// to Anthropic MessageParam format. Anthropic does not support a "system" role
|
|||
// in messages; instead, system prompts are passed as a separate field.
|
|||
func convertToAnthropicMessages(messages []ChatMessage) ([]anthropic.TextBlockParam, []anthropic.MessageParam) { |
|||
var systemBlocks []anthropic.TextBlockParam |
|||
var result []anthropic.MessageParam |
|||
|
|||
for _, msg := range messages { |
|||
switch msg.Role { |
|||
case "system": |
|||
systemBlocks = append(systemBlocks, anthropic.TextBlockParam{ |
|||
Text: msg.Content, |
|||
}) |
|||
case "user": |
|||
result = append(result, anthropic.NewUserMessage( |
|||
anthropic.NewTextBlock(msg.Content), |
|||
)) |
|||
case "assistant": |
|||
result = append(result, anthropic.NewAssistantMessage( |
|||
anthropic.NewTextBlock(msg.Content), |
|||
)) |
|||
} |
|||
} |
|||
|
|||
return systemBlocks, result |
|||
} |
|||
|
|||
// extractTextContent concatenates all text blocks from an Anthropic response.
|
|||
func extractTextContent(blocks []anthropic.ContentBlockUnion) string { |
|||
var parts []string |
|||
for _, block := range blocks { |
|||
if block.Type == "text" { |
|||
parts = append(parts, block.Text) |
|||
} |
|||
} |
|||
return strings.Join(parts, "") |
|||
} |
|||
@ -0,0 +1,18 @@ |
|||
package provider |
|||
|
|||
import "fmt" |
|||
|
|||
// NewProvider creates a provider based on SDK type.
|
|||
// Supported sdkType values:
|
|||
// - "openai_compat": OpenAI-compatible APIs (OpenAI, Qwen, Zhipu, DeepSeek, etc.)
|
|||
// - "anthropic": Anthropic Claude models
|
|||
func NewProvider(sdkType, baseUrl, apiKey string) (AIProvider, error) { |
|||
switch sdkType { |
|||
case "openai_compat": |
|||
return NewOpenAIProvider(baseUrl, apiKey), nil |
|||
case "anthropic": |
|||
return NewAnthropicProvider(baseUrl, apiKey), nil |
|||
default: |
|||
return nil, fmt.Errorf("unsupported sdk_type: %s", sdkType) |
|||
} |
|||
} |
|||
@ -0,0 +1,155 @@ |
|||
package provider |
|||
|
|||
import ( |
|||
"context" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
|
|||
openai "github.com/sashabaranov/go-openai" |
|||
) |
|||
|
|||
// OpenAIProvider implements AIProvider for OpenAI-compatible APIs
|
|||
// (OpenAI, Qwen, Zhipu, DeepSeek, etc.)
|
|||
type OpenAIProvider struct { |
|||
client *openai.Client |
|||
name string |
|||
} |
|||
|
|||
// NewOpenAIProvider creates a new OpenAI-compatible provider.
|
|||
// baseUrl can be empty to use the default OpenAI endpoint.
|
|||
func NewOpenAIProvider(baseUrl, apiKey string) *OpenAIProvider { |
|||
config := openai.DefaultConfig(apiKey) |
|||
if baseUrl != "" { |
|||
config.BaseURL = baseUrl |
|||
} |
|||
return &OpenAIProvider{ |
|||
client: openai.NewClientWithConfig(config), |
|||
name: "openai_compat", |
|||
} |
|||
} |
|||
|
|||
// Name returns the provider name.
|
|||
func (p *OpenAIProvider) Name() string { |
|||
return p.name |
|||
} |
|||
|
|||
// Chat sends a synchronous chat completion request.
|
|||
func (p *OpenAIProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { |
|||
messages := convertToOpenAIMessages(req.Messages) |
|||
|
|||
openaiReq := openai.ChatCompletionRequest{ |
|||
Model: req.Model, |
|||
Messages: messages, |
|||
MaxTokens: req.MaxTokens, |
|||
Temperature: float32(req.Temperature), |
|||
} |
|||
|
|||
resp, err := p.client.CreateChatCompletion(ctx, openaiReq) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("openai chat completion failed: %w", err) |
|||
} |
|||
|
|||
if len(resp.Choices) == 0 { |
|||
return nil, fmt.Errorf("openai chat completion returned no choices") |
|||
} |
|||
|
|||
return &ChatResponse{ |
|||
Content: resp.Choices[0].Message.Content, |
|||
Model: resp.Model, |
|||
InputTokens: resp.Usage.PromptTokens, |
|||
OutputTokens: resp.Usage.CompletionTokens, |
|||
FinishReason: string(resp.Choices[0].FinishReason), |
|||
}, nil |
|||
} |
|||
|
|||
// ChatStream sends a streaming chat completion request. It returns a channel
|
|||
// that delivers StreamChunk values. The channel is closed when the stream
|
|||
// ends or an error occurs. The final chunk has Done=true and may include
|
|||
// token usage if the API provides it.
|
|||
func (p *OpenAIProvider) ChatStream(ctx context.Context, req *ChatRequest) (<-chan *StreamChunk, error) { |
|||
messages := convertToOpenAIMessages(req.Messages) |
|||
|
|||
openaiReq := openai.ChatCompletionRequest{ |
|||
Model: req.Model, |
|||
Messages: messages, |
|||
MaxTokens: req.MaxTokens, |
|||
Temperature: float32(req.Temperature), |
|||
Stream: true, |
|||
StreamOptions: &openai.StreamOptions{ |
|||
IncludeUsage: true, |
|||
}, |
|||
} |
|||
|
|||
stream, err := p.client.CreateChatCompletionStream(ctx, openaiReq) |
|||
if err != nil { |
|||
return nil, fmt.Errorf("openai stream creation failed: %w", err) |
|||
} |
|||
|
|||
ch := make(chan *StreamChunk, 64) |
|||
|
|||
go func() { |
|||
defer close(ch) |
|||
defer stream.Close() |
|||
|
|||
for { |
|||
response, err := stream.Recv() |
|||
if errors.Is(err, io.EOF) { |
|||
// Stream finished normally
|
|||
select { |
|||
case ch <- &StreamChunk{Done: true}: |
|||
case <-ctx.Done(): |
|||
} |
|||
return |
|||
} |
|||
if err != nil { |
|||
// Send error indication as the final chunk
|
|||
select { |
|||
case ch <- &StreamChunk{ |
|||
Content: fmt.Sprintf("[stream error: %v]", err), |
|||
Done: true, |
|||
FinishReason: "error", |
|||
}: |
|||
case <-ctx.Done(): |
|||
} |
|||
return |
|||
} |
|||
|
|||
chunk := &StreamChunk{} |
|||
|
|||
// Extract content delta from choices
|
|||
if len(response.Choices) > 0 { |
|||
chunk.Content = response.Choices[0].Delta.Content |
|||
if response.Choices[0].FinishReason != "" { |
|||
chunk.FinishReason = string(response.Choices[0].FinishReason) |
|||
} |
|||
} |
|||
|
|||
// Extract usage from the final usage chunk (when StreamOptions.IncludeUsage is true)
|
|||
if response.Usage != nil { |
|||
chunk.InputTokens = response.Usage.PromptTokens |
|||
chunk.OutputTokens = response.Usage.CompletionTokens |
|||
} |
|||
|
|||
select { |
|||
case ch <- chunk: |
|||
case <-ctx.Done(): |
|||
return |
|||
} |
|||
} |
|||
}() |
|||
|
|||
return ch, nil |
|||
} |
|||
|
|||
// convertToOpenAIMessages converts our unified ChatMessage slice to OpenAI format.
|
|||
func convertToOpenAIMessages(messages []ChatMessage) []openai.ChatCompletionMessage { |
|||
result := make([]openai.ChatCompletionMessage, 0, len(messages)) |
|||
for _, msg := range messages { |
|||
result = append(result, openai.ChatCompletionMessage{ |
|||
Role: msg.Role, |
|||
Content: msg.Content, |
|||
}) |
|||
} |
|||
return result |
|||
} |
|||
@ -0,0 +1,13 @@ |
|||
package provider |
|||
|
|||
import "context" |
|||
|
|||
// AIProvider defines the interface for AI model providers
|
|||
type AIProvider interface { |
|||
// Chat sends a synchronous chat request
|
|||
Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) |
|||
// ChatStream sends a streaming chat request, returning chunks via channel
|
|||
ChatStream(ctx context.Context, req *ChatRequest) (<-chan *StreamChunk, error) |
|||
// Name returns the provider name
|
|||
Name() string |
|||
} |
|||
@ -0,0 +1,34 @@ |
|||
package provider |
|||
|
|||
// ChatMessage represents a single message in a conversation
|
|||
type ChatMessage struct { |
|||
Role string `json:"role"` // user, assistant, system
|
|||
Content string `json:"content"` |
|||
} |
|||
|
|||
// ChatRequest is the unified request format
|
|||
type ChatRequest struct { |
|||
Model string `json:"model"` |
|||
Messages []ChatMessage `json:"messages"` |
|||
MaxTokens int `json:"max_tokens,omitempty"` |
|||
Temperature float64 `json:"temperature,omitempty"` |
|||
Stream bool `json:"stream"` |
|||
} |
|||
|
|||
// ChatResponse is the unified non-streaming response
|
|||
type ChatResponse struct { |
|||
Content string `json:"content"` |
|||
Model string `json:"model"` |
|||
InputTokens int `json:"input_tokens"` |
|||
OutputTokens int `json:"output_tokens"` |
|||
FinishReason string `json:"finish_reason"` |
|||
} |
|||
|
|||
// StreamChunk represents a single SSE chunk
|
|||
type StreamChunk struct { |
|||
Content string `json:"content,omitempty"` |
|||
FinishReason string `json:"finish_reason,omitempty"` |
|||
InputTokens int `json:"input_tokens,omitempty"` |
|||
OutputTokens int `json:"output_tokens,omitempty"` |
|||
Done bool `json:"done"` |
|||
} |
|||
Loading…
Reference in new issue