mod types;

use std::pin::Pin;

use anyhow::Result;
use futures::Stream;

pub use types::*;

use crate::config::Config;
use crate::tools::ToolDefinition;

/// Trait for LLM providers. Anything that speaks OpenAI chat completions works.
pub trait Provider: Send + Sync {
    fn stream_chat(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
    ) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send + '_>>;
}

/// OpenAI-compatible provider (works with vLLM, Ollama, llama.cpp, OpenAI, etc.)
pub struct OpenAIProvider {
    client: reqwest::Client,
    endpoint: String,
    api_key: Option<String>,
    model: String,
    max_tokens: u32,
    temperature: Option<f32>,
}

impl OpenAIProvider {
    pub fn new(config: &Config) -> Self {
        Self {
            client: reqwest::Client::new(),
            // Normalize the endpoint so `{endpoint}/chat/completions` never
            // produces a double slash.
            endpoint: config.endpoint.trim_end_matches('/').to_string(),
            api_key: config.api_key.clone(),
            model: config.model.clone(),
            max_tokens: config.max_tokens,
            temperature: config.temperature,
        }
    }

    /// Build the chat-completions request body, including tool definitions in
    /// OpenAI's `function` schema when any tools are registered.
    fn build_request_body(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
    ) -> serde_json::Value {
        let mut body = serde_json::json!({
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens,
            "stream": true,
        });
        if let Some(temp) = self.temperature {
            body["temperature"] = serde_json::json!(temp);
        }
        if !tools.is_empty() {
            let tool_defs: Vec<serde_json::Value> = tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "type": "function",
                        "function": {
                            "name": t.name,
                            "description": t.description,
                            "parameters": t.parameters,
                        }
                    })
                })
                .collect();
            body["tools"] = serde_json::json!(tool_defs);
        }
        body
    }
}

impl Provider for OpenAIProvider {
    fn stream_chat(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
    ) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send + '_>> {
        let body = self.build_request_body(messages, tools);
        let url = format!("{}/chat/completions", self.endpoint);
        let mut req = self.client.post(&url).json(&body);
        if let Some(ref key) = self.api_key {
            req = req.bearer_auth(key);
        }

        Box::pin(async_stream::stream! {
            let response = match req.send().await {
                Ok(r) => r,
                Err(e) => {
                    yield Err(anyhow::anyhow!(e));
                    return;
                }
            };

            if !response.status().is_success() {
                let status = response.status();
                let text = response.text().await.unwrap_or_default();
                yield Err(anyhow::anyhow!("API error {status}: {text}"));
                return;
            }

            use futures::StreamExt;
            let mut stream = response.bytes_stream();
            // Buffer raw bytes rather than text: a multi-byte UTF-8 sequence
            // can be split across network chunks, and converting each chunk
            // separately would corrupt it.
            let mut buffer: Vec<u8> = Vec::new();

            while let Some(chunk) = stream.next().await {
                let chunk = match chunk {
                    Ok(c) => c,
                    Err(e) => {
                        yield Err(anyhow::anyhow!(e));
                        return;
                    }
                };
                buffer.extend_from_slice(&chunk);

                // Process complete SSE lines; anything after the last newline
                // stays buffered until the next chunk arrives.
                while let Some(line_end) = buffer.iter().position(|&b| b == b'\n') {
                    let line = String::from_utf8_lossy(&buffer[..line_end]).trim().to_string();
                    buffer.drain(..=line_end);

                    // Skip blank separators and SSE comment lines.
                    if line.is_empty() || line.starts_with(':') {
                        continue;
                    }

                    if let Some(data) = line.strip_prefix("data: ") {
                        // "[DONE]" is the protocol's end-of-stream sentinel.
                        if data == "[DONE]" {
                            yield Ok(StreamEvent::Done);
                            return;
                        }

                        match serde_json::from_str::<StreamChunk>(data) {
                            Ok(chunk) => {
                                for choice in &chunk.choices {
                                    if let Some(ref content) = choice.delta.content {
                                        yield Ok(StreamEvent::Text(content.clone()));
                                    }
                                    if let Some(ref tool_calls) = choice.delta.tool_calls {
                                        for tc in tool_calls {
                                            yield Ok(StreamEvent::ToolCallDelta(ToolCallDelta {
                                                index: tc.index,
                                                id: tc.id.clone(),
                                                name: tc.function.as_ref().and_then(|f| f.name.clone()),
                                                arguments_delta: tc.function.as_ref().and_then(|f| f.arguments.clone()),
                                            }));
                                        }
                                    }
                                    if choice.finish_reason.is_some() {
                                        yield Ok(StreamEvent::Finish);
                                    }
                                }
                            }
                            Err(e) => {
                                tracing::warn!("Failed to parse SSE chunk: {e}: {data}");
                            }
                        }
                    }
                }
            }
        })
    }
}
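
// --- Usage sketch ------------------------------------------------------------
// A minimal, non-normative consumer of `Provider::stream_chat`: it prints text
// deltas and stops on `Done`. A real caller would also accumulate
// `ToolCallDelta::arguments_delta` fragments keyed by `index` until a
// `finish_reason` arrives; that bookkeeping is stubbed out here.
#[allow(dead_code)]
async fn stream_to_stdout(provider: &dyn Provider, messages: &[ChatMessage]) -> Result<()> {
    use futures::StreamExt;

    let mut stream = provider.stream_chat(messages, &[]);
    while let Some(event) = stream.next().await {
        match event? {
            StreamEvent::Text(text) => print!("{text}"),
            StreamEvent::ToolCallDelta(delta) => {
                // Fragments arrive incrementally; collect `delta.arguments_delta`
                // per `delta.index` before parsing the arguments as JSON.
                let _ = delta;
            }
            StreamEvent::Done => break,
            // `Finish` (and any other variants) need no handling in this sketch.
            _ => {}
        }
    }
    Ok(())
}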
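
// A small sanity check for the request builder. The provider is built with a
// struct literal so the test does not depend on how `Config` is constructed;
// the endpoint and model strings are placeholders.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn request_body_omits_tools_and_temperature_when_unset() {
        let provider = OpenAIProvider {
            client: reqwest::Client::new(),
            endpoint: "http://localhost:8000/v1".to_string(),
            api_key: None,
            model: "test-model".to_string(),
            max_tokens: 256,
            temperature: None,
        };

        let body = provider.build_request_body(&[], &[]);
        assert_eq!(body["stream"], serde_json::json!(true));
        assert_eq!(body["model"], serde_json::json!("test-model"));
        assert!(body.get("temperature").is_none());
        assert!(body.get("tools").is_none());
    }
}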