Initial Slug Code Rust implementation
Core features: - OpenAI-compatible streaming provider (vLLM, Ollama, OpenAI, etc.) - Agent loop with tool use (bash, read, write, edit, glob, grep) - Permission system: ask/yolo/sandbox/allowEdits + glob patterns - SLUG.md hierarchy loaded every turn (CLAUDE.md equivalent) - Session persistence with --continue/--resume/--fork-session - Hook system: 5 lifecycle events, command + prompt types - Compaction: ToolResultTrim/Truncate strategies, /compact command - Config via TOML, CLI args, env vars Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
f2e1d53e37
commit
b8bf9029fe
21 changed files with 6280 additions and 0 deletions
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
|
|
@ -0,0 +1 @@
|
|||
/target
|
||||
3049
Cargo.lock
generated
Normal file
3049
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
55
Cargo.toml
Normal file
55
Cargo.toml
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
[package]
|
||||
name = "slug-code"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[[bin]]
|
||||
name = "slug"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
# Async runtime
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
futures = "0.3"
|
||||
async-stream = "0.3"
|
||||
|
||||
# HTTP & streaming
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
|
||||
reqwest-eventsource = "0.6"
|
||||
eventsource-stream = "0.2"
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
|
||||
# CLI
|
||||
clap = { version = "4", features = ["derive", "env"] }
|
||||
|
||||
# TUI
|
||||
ratatui = "0.29"
|
||||
crossterm = "0.28"
|
||||
tui-textarea = "0.7"
|
||||
|
||||
# Terminal markdown
|
||||
termimad = "0.30"
|
||||
|
||||
# File operations
|
||||
globwalk = "0.9"
|
||||
grep-regex = "0.1"
|
||||
ignore = "0.4"
|
||||
|
||||
# Utils
|
||||
anyhow = "1"
|
||||
dirs = "6"
|
||||
toml = "0.8"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
uuid = { version = "1", features = ["v4"] }
|
||||
syntect = "5"
|
||||
shell-words = "1"
|
||||
which = "7"
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
lto = true
|
||||
strip = true
|
||||
272
src/agent/mod.rs
Normal file
272
src/agent/mod.rs
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
use anyhow::Result;
|
||||
use futures::StreamExt;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::compact::{CompactionStrategy, Compactor};
|
||||
use crate::config::Config;
|
||||
use crate::permissions::{PermissionHandler, PermissionRequest};
|
||||
use crate::provider::{ChatMessage, FunctionCall, Provider, StreamEvent, ToolCall};
|
||||
use crate::slugmd::load_slug_context;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
/// Assembled tool call being built from streaming deltas.
|
||||
#[derive(Debug, Default)]
|
||||
struct ToolCallAccumulator {
|
||||
id: Option<String>,
|
||||
name: Option<String>,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
/// The core agent loop: chat with the LLM, execute tools, repeat.
|
||||
pub struct Agent {
|
||||
provider: Box<dyn Provider>,
|
||||
tools: ToolRegistry,
|
||||
permissions: PermissionHandler,
|
||||
messages: Vec<ChatMessage>,
|
||||
max_rounds: usize,
|
||||
}
|
||||
|
||||
impl Agent {
|
||||
pub fn new(
|
||||
provider: Box<dyn Provider>,
|
||||
tools: ToolRegistry,
|
||||
permissions: PermissionHandler,
|
||||
config: &Config,
|
||||
) -> Self {
|
||||
let messages = vec![ChatMessage::system(&config.system_prompt)];
|
||||
Self {
|
||||
provider,
|
||||
tools,
|
||||
permissions,
|
||||
messages,
|
||||
max_rounds: config.max_tool_rounds,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an agent with prior conversation history (for session resume/fork).
|
||||
pub fn new_with_history(
|
||||
provider: Box<dyn Provider>,
|
||||
tools: ToolRegistry,
|
||||
permissions: PermissionHandler,
|
||||
config: &Config,
|
||||
prior_messages: Vec<ChatMessage>,
|
||||
) -> Self {
|
||||
let mut messages = vec![ChatMessage::system(&config.system_prompt)];
|
||||
messages.extend(prior_messages);
|
||||
Self {
|
||||
provider,
|
||||
tools,
|
||||
permissions,
|
||||
messages,
|
||||
max_rounds: config.max_tool_rounds,
|
||||
}
|
||||
}
|
||||
|
||||
/// Access the conversation messages (for session saving).
|
||||
pub fn messages(&self) -> &[ChatMessage] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
/// Run a single user prompt to completion (non-interactive).
|
||||
pub async fn run_once(&mut self, prompt: &str) -> Result<String> {
|
||||
self.messages.push(ChatMessage::user(prompt));
|
||||
|
||||
let mut final_text = String::new();
|
||||
|
||||
for _ in 0..self.max_rounds {
|
||||
let (text, tool_calls) = self.stream_response().await?;
|
||||
final_text = text.clone();
|
||||
|
||||
if tool_calls.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
self.messages.push(ChatMessage::assistant(
|
||||
if text.is_empty() { None } else { Some(text) },
|
||||
Some(tool_calls.clone()),
|
||||
));
|
||||
|
||||
for tc in &tool_calls {
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&tc.function.arguments).unwrap_or_default();
|
||||
let result = self.execute_with_permission(&tc.function.name, &args);
|
||||
self.messages
|
||||
.push(ChatMessage::tool_result(&tc.id, &result));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(final_text)
|
||||
}
|
||||
|
||||
/// Execute a tool call, checking permissions first.
|
||||
fn execute_with_permission(&self, name: &str, args: &serde_json::Value) -> String {
|
||||
let perm_request = match name {
|
||||
"bash" => args["command"]
|
||||
.as_str()
|
||||
.map(|cmd| PermissionRequest::Bash { command: cmd }),
|
||||
"write" => args["file_path"]
|
||||
.as_str()
|
||||
.map(|p| PermissionRequest::FileWrite { path: p }),
|
||||
"edit" => args["file_path"]
|
||||
.as_str()
|
||||
.map(|p| PermissionRequest::FileEdit { path: p }),
|
||||
_ => None, // read, glob, grep are always allowed
|
||||
};
|
||||
|
||||
if let Some(req) = perm_request {
|
||||
if !self.permissions.check(&req) {
|
||||
return "Permission denied by user.".to_string();
|
||||
}
|
||||
}
|
||||
|
||||
match self.tools.execute(name, args) {
|
||||
Ok(output) => output,
|
||||
Err(e) => format!("Error: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the message list to send to the provider, prepending a fresh slug context
|
||||
/// system message if any SLUG.md files are present. The slug message is never stored
|
||||
/// in self.messages — it is rebuilt from disk on every call.
|
||||
fn messages_with_slug_context(&self) -> Vec<ChatMessage> {
|
||||
let slug_context = load_slug_context();
|
||||
if slug_context.is_empty() {
|
||||
self.messages.clone()
|
||||
} else {
|
||||
let mut msgs = Vec::with_capacity(self.messages.len() + 1);
|
||||
msgs.push(ChatMessage::system(&slug_context));
|
||||
msgs.extend_from_slice(&self.messages);
|
||||
msgs
|
||||
}
|
||||
}
|
||||
|
||||
/// Send messages to the LLM and collect the streamed response.
|
||||
async fn stream_response(&self) -> Result<(String, Vec<ToolCall>)> {
|
||||
let defs = self.tools.definitions();
|
||||
let messages = self.messages_with_slug_context();
|
||||
let mut stream = self.provider.stream_chat(&messages, &defs);
|
||||
|
||||
let mut text = String::new();
|
||||
let mut tc_accumulators: HashMap<usize, ToolCallAccumulator> = HashMap::new();
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
match event? {
|
||||
StreamEvent::Text(t) => text.push_str(&t),
|
||||
StreamEvent::ToolCallDelta(delta) => {
|
||||
let acc = tc_accumulators.entry(delta.index).or_default();
|
||||
if let Some(id) = delta.id {
|
||||
acc.id = Some(id);
|
||||
}
|
||||
if let Some(name) = delta.name {
|
||||
acc.name = Some(name);
|
||||
}
|
||||
if let Some(args) = delta.arguments_delta {
|
||||
acc.arguments.push_str(&args);
|
||||
}
|
||||
}
|
||||
StreamEvent::Finish | StreamEvent::Done => break,
|
||||
}
|
||||
}
|
||||
|
||||
let tool_calls = Self::collect_tool_calls(tc_accumulators);
|
||||
Ok((text, tool_calls))
|
||||
}
|
||||
|
||||
/// Push a user message and stream the response, calling the callback for each text chunk.
|
||||
pub async fn stream_turn<F>(&mut self, user_input: &str, mut on_text: F) -> Result<()>
|
||||
where
|
||||
F: FnMut(&str) + Send,
|
||||
{
|
||||
self.messages.push(ChatMessage::user(user_input));
|
||||
|
||||
for _ in 0..self.max_rounds {
|
||||
let defs = self.tools.definitions();
|
||||
let messages = self.messages_with_slug_context();
|
||||
let mut stream = self.provider.stream_chat(&messages, &defs);
|
||||
|
||||
let mut text = String::new();
|
||||
let mut tc_accumulators: HashMap<usize, ToolCallAccumulator> = HashMap::new();
|
||||
|
||||
while let Some(event) = stream.next().await {
|
||||
match event? {
|
||||
StreamEvent::Text(t) => {
|
||||
on_text(&t);
|
||||
text.push_str(&t);
|
||||
}
|
||||
StreamEvent::ToolCallDelta(delta) => {
|
||||
let acc = tc_accumulators.entry(delta.index).or_default();
|
||||
if let Some(id) = delta.id {
|
||||
acc.id = Some(id);
|
||||
}
|
||||
if let Some(name) = delta.name {
|
||||
acc.name = Some(name);
|
||||
}
|
||||
if let Some(args) = delta.arguments_delta {
|
||||
acc.arguments.push_str(&args);
|
||||
}
|
||||
}
|
||||
StreamEvent::Finish | StreamEvent::Done => break,
|
||||
}
|
||||
}
|
||||
|
||||
let tool_calls = Self::collect_tool_calls(tc_accumulators);
|
||||
|
||||
if tool_calls.is_empty() {
|
||||
self.messages.push(ChatMessage::assistant(
|
||||
if text.is_empty() { None } else { Some(text) },
|
||||
None,
|
||||
));
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.messages.push(ChatMessage::assistant(
|
||||
if text.is_empty() { None } else { Some(text) },
|
||||
Some(tool_calls.clone()),
|
||||
));
|
||||
|
||||
for tc in &tool_calls {
|
||||
on_text(&format!("\n\x1b[36m--- Tool: {} ---\x1b[0m\n", tc.function.name));
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&tc.function.arguments).unwrap_or_default();
|
||||
let result = self.execute_with_permission(&tc.function.name, &args);
|
||||
let display = if result.len() > 500 {
|
||||
format!("{}... ({} bytes total)", &result[..500], result.len())
|
||||
} else {
|
||||
result.clone()
|
||||
};
|
||||
on_text(&format!("{display}\n\x1b[36m--- End tool ---\x1b[0m\n\n"));
|
||||
self.messages
|
||||
.push(ChatMessage::tool_result(&tc.id, &result));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Reduce context size by applying ToolResultTrim first, then Truncate if still too large.
|
||||
pub fn compact(&mut self) {
|
||||
let compactor = Compactor::new();
|
||||
compactor.compact(&mut self.messages, CompactionStrategy::ToolResultTrim);
|
||||
if compactor.needs_compaction(&self.messages) {
|
||||
compactor.compact(&mut self.messages, CompactionStrategy::Truncate);
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_tool_calls(accumulators: HashMap<usize, ToolCallAccumulator>) -> Vec<ToolCall> {
|
||||
let mut tool_calls: Vec<ToolCall> = accumulators
|
||||
.into_iter()
|
||||
.filter_map(|(_, acc)| {
|
||||
Some(ToolCall {
|
||||
id: acc.id?,
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: acc.name?,
|
||||
arguments: acc.arguments,
|
||||
},
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
tool_calls.sort_by_key(|tc| tc.id.clone());
|
||||
tool_calls
|
||||
}
|
||||
}
|
||||
405
src/compact/mod.rs
Normal file
405
src/compact/mod.rs
Normal file
|
|
@ -0,0 +1,405 @@
|
|||
use crate::provider::{ChatMessage, Role};
|
||||
|
||||
/// Approximate token count using a simple chars/4 heuristic.
|
||||
pub fn estimate_tokens(text: &str) -> usize {
|
||||
(text.len() + 3) / 4
|
||||
}
|
||||
|
||||
/// Strategy for reducing context size.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum CompactionStrategy {
|
||||
/// Summarize entire history into a single summary message (local, no LLM call).
|
||||
Full,
|
||||
/// Drop oldest message groups, keeping only recent turns.
|
||||
Truncate,
|
||||
/// Replace old tool results with a truncation placeholder.
|
||||
ToolResultTrim,
|
||||
}
|
||||
|
||||
/// Manages context size by applying compaction strategies to the message history.
|
||||
pub struct Compactor {
|
||||
/// Maximum context tokens before compaction is triggered.
|
||||
pub max_context_tokens: usize,
|
||||
}
|
||||
|
||||
impl Default for Compactor {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_context_tokens: 200_000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Compactor {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_max_tokens(max_context_tokens: usize) -> Self {
|
||||
Self { max_context_tokens }
|
||||
}
|
||||
|
||||
/// Returns true if the estimated token count exceeds 80% of the max.
|
||||
pub fn needs_compaction(&self, messages: &[ChatMessage]) -> bool {
|
||||
let total: usize = messages
|
||||
.iter()
|
||||
.map(|m| {
|
||||
let content_tokens = m
|
||||
.content
|
||||
.as_deref()
|
||||
.map(estimate_tokens)
|
||||
.unwrap_or(0);
|
||||
let tool_tokens: usize = m
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.map(|tcs| {
|
||||
tcs.iter()
|
||||
.map(|tc| estimate_tokens(&tc.function.arguments))
|
||||
.sum()
|
||||
})
|
||||
.unwrap_or(0);
|
||||
content_tokens + tool_tokens
|
||||
})
|
||||
.sum();
|
||||
|
||||
total > (self.max_context_tokens * 4) / 5
|
||||
}
|
||||
|
||||
/// Apply the given strategy in-place to the message list.
|
||||
pub fn compact(&self, messages: &mut Vec<ChatMessage>, strategy: CompactionStrategy) {
|
||||
match strategy {
|
||||
CompactionStrategy::Full => self.apply_full(messages),
|
||||
CompactionStrategy::Truncate => self.apply_truncate(messages),
|
||||
CompactionStrategy::ToolResultTrim => self.apply_tool_result_trim(messages),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract a structured summary of the session from the message history.
|
||||
///
|
||||
/// Scans user/assistant messages for key context: tasks, files, errors, decisions.
|
||||
pub fn extract_session_memory(&self, messages: &[ChatMessage]) -> String {
|
||||
let mut user_snippets: Vec<String> = Vec::new();
|
||||
let mut assistant_snippets: Vec<String> = Vec::new();
|
||||
let mut files_mentioned: Vec<String> = Vec::new();
|
||||
let mut errors_encountered: Vec<String> = Vec::new();
|
||||
|
||||
for msg in messages {
|
||||
match msg.role {
|
||||
Role::System => continue,
|
||||
Role::User => {
|
||||
if let Some(ref content) = msg.content {
|
||||
let snippet = if content.len() > 200 {
|
||||
format!("{}...", &content[..200])
|
||||
} else {
|
||||
content.clone()
|
||||
};
|
||||
user_snippets.push(snippet);
|
||||
}
|
||||
}
|
||||
Role::Assistant => {
|
||||
if let Some(ref content) = msg.content {
|
||||
let snippet = if content.len() > 300 {
|
||||
format!("{}...", &content[..300])
|
||||
} else {
|
||||
content.clone()
|
||||
};
|
||||
assistant_snippets.push(snippet);
|
||||
}
|
||||
// Collect file paths from tool calls
|
||||
if let Some(ref tool_calls) = msg.tool_calls {
|
||||
for tc in tool_calls {
|
||||
if let Ok(args) =
|
||||
serde_json::from_str::<serde_json::Value>(&tc.function.arguments)
|
||||
{
|
||||
for key in &["file_path", "path"] {
|
||||
if let Some(p) = args[key].as_str() {
|
||||
if !files_mentioned.contains(&p.to_string()) {
|
||||
files_mentioned.push(p.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Role::Tool => {
|
||||
if let Some(ref content) = msg.content {
|
||||
if content.to_lowercase().contains("error") {
|
||||
let snippet = if content.len() > 150 {
|
||||
format!("{}...", &content[..150])
|
||||
} else {
|
||||
content.clone()
|
||||
};
|
||||
errors_encountered.push(snippet);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut parts: Vec<String> = Vec::new();
|
||||
|
||||
if !user_snippets.is_empty() {
|
||||
parts.push(format!(
|
||||
"User requests: {}",
|
||||
user_snippets
|
||||
.iter()
|
||||
.take(5)
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
.join(" | ")
|
||||
));
|
||||
}
|
||||
|
||||
if !assistant_snippets.is_empty() {
|
||||
parts.push(format!(
|
||||
"Assistant responses: {}",
|
||||
assistant_snippets
|
||||
.iter()
|
||||
.take(3)
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
.join(" | ")
|
||||
));
|
||||
}
|
||||
|
||||
if !files_mentioned.is_empty() {
|
||||
parts.push(format!("Files touched: {}", files_mentioned.join(", ")));
|
||||
}
|
||||
|
||||
if !errors_encountered.is_empty() {
|
||||
parts.push(format!(
|
||||
"Errors encountered: {}",
|
||||
errors_encountered
|
||||
.iter()
|
||||
.take(3)
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
.join(" | ")
|
||||
));
|
||||
}
|
||||
|
||||
if parts.is_empty() {
|
||||
return "[No significant context to summarize]".to_string();
|
||||
}
|
||||
|
||||
parts.join("\n")
|
||||
}
|
||||
|
||||
// --- Strategy implementations ---
|
||||
|
||||
/// Replace all messages (except the system prompt) with a single summary message.
|
||||
fn apply_full(&self, messages: &mut Vec<ChatMessage>) {
|
||||
let summary = self.extract_session_memory(messages);
|
||||
let summary_content = format!("Previous conversation summary:\n{summary}");
|
||||
|
||||
// Retain the system prompt if present.
|
||||
let system_prompt: Option<ChatMessage> = messages
|
||||
.first()
|
||||
.filter(|m| m.role == Role::System)
|
||||
.cloned();
|
||||
|
||||
messages.clear();
|
||||
|
||||
if let Some(sys) = system_prompt {
|
||||
messages.push(sys);
|
||||
}
|
||||
|
||||
messages.push(ChatMessage::user(&summary_content));
|
||||
}
|
||||
|
||||
/// Keep the system prompt and the most recent N message groups, dropping older turns.
|
||||
///
|
||||
/// A "group" is defined as a user message plus the assistant reply and any associated
|
||||
/// tool calls/results that follow it.
|
||||
fn apply_truncate(&self, messages: &mut Vec<ChatMessage>) {
|
||||
const GROUPS_TO_KEEP: usize = 10;
|
||||
|
||||
// Identify the system prompt.
|
||||
let system_prompt: Option<ChatMessage> = messages
|
||||
.first()
|
||||
.filter(|m| m.role == Role::System)
|
||||
.cloned();
|
||||
|
||||
// Find the start index of non-system messages.
|
||||
let start = if system_prompt.is_some() { 1 } else { 0 };
|
||||
let rest = &messages[start..];
|
||||
|
||||
// Walk backwards to find group boundaries. Each new User message starts a group.
|
||||
let group_starts: Vec<usize> = rest
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, m)| {
|
||||
if m.role == Role::User {
|
||||
Some(i)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if group_starts.len() <= GROUPS_TO_KEEP {
|
||||
// Already within limits — nothing to drop.
|
||||
return;
|
||||
}
|
||||
|
||||
let keep_from = group_starts[group_starts.len() - GROUPS_TO_KEEP];
|
||||
let kept: Vec<ChatMessage> = rest[keep_from..].to_vec();
|
||||
|
||||
messages.clear();
|
||||
if let Some(sys) = system_prompt {
|
||||
messages.push(sys);
|
||||
}
|
||||
messages.extend(kept);
|
||||
}
|
||||
|
||||
/// Replace the content of old tool results with a truncation placeholder.
|
||||
///
|
||||
/// The most recent 5 tool results are left intact; everything older is replaced.
|
||||
fn apply_tool_result_trim(&self, messages: &mut Vec<ChatMessage>) {
|
||||
// Collect indices of all Tool-role messages.
|
||||
let tool_indices: Vec<usize> = messages
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, m)| if m.role == Role::Tool { Some(i) } else { None })
|
||||
.collect();
|
||||
|
||||
const KEEP_RECENT: usize = 5;
|
||||
|
||||
if tool_indices.len() <= KEEP_RECENT {
|
||||
return;
|
||||
}
|
||||
|
||||
let trim_up_to = tool_indices.len() - KEEP_RECENT;
|
||||
|
||||
for &idx in &tool_indices[..trim_up_to] {
|
||||
let original_len = messages[idx].content.as_deref().map(|c| c.len()).unwrap_or(0);
|
||||
if original_len > 0 {
|
||||
messages[idx].content = Some(format!(
|
||||
"[result truncated — {original_len} chars]"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::provider::{FunctionCall, ToolCall};
|
||||
|
||||
fn make_messages() -> Vec<ChatMessage> {
|
||||
vec![
|
||||
ChatMessage::system("You are a helpful assistant."),
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::assistant(Some("Hi there!".to_string()), None),
|
||||
ChatMessage::user("Run a tool"),
|
||||
ChatMessage::assistant(
|
||||
None,
|
||||
Some(vec![ToolCall {
|
||||
id: "call_1".to_string(),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: "bash".to_string(),
|
||||
arguments: r#"{"command":"ls"}"#.to_string(),
|
||||
},
|
||||
}]),
|
||||
),
|
||||
ChatMessage::tool_result("call_1", "file1.rs\nfile2.rs"),
|
||||
]
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_estimate_tokens() {
|
||||
assert_eq!(estimate_tokens("hello"), 2);
|
||||
assert_eq!(estimate_tokens("hello world"), 3);
|
||||
assert_eq!(estimate_tokens(""), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_needs_compaction_false() {
|
||||
let compactor = Compactor::new();
|
||||
let messages = make_messages();
|
||||
assert!(!compactor.needs_compaction(&messages));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_needs_compaction_true() {
|
||||
let compactor = Compactor::with_max_tokens(10);
|
||||
let messages = make_messages();
|
||||
assert!(compactor.needs_compaction(&messages));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_full_compaction_retains_system_prompt() {
|
||||
let compactor = Compactor::new();
|
||||
let mut messages = make_messages();
|
||||
compactor.compact(&mut messages, CompactionStrategy::Full);
|
||||
assert_eq!(messages[0].role, Role::System);
|
||||
assert_eq!(messages.len(), 2); // system + summary
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_truncate_keeps_recent_groups() {
|
||||
let compactor = Compactor::new();
|
||||
let mut messages = make_messages();
|
||||
// With fewer than GROUPS_TO_KEEP groups, nothing is dropped.
|
||||
let original_len = messages.len();
|
||||
compactor.compact(&mut messages, CompactionStrategy::Truncate);
|
||||
assert_eq!(messages.len(), original_len);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_result_trim() {
|
||||
let compactor = Compactor::new();
|
||||
let mut messages = make_messages();
|
||||
// Only 1 tool result — below the threshold of 5, so nothing should change.
|
||||
compactor.compact(&mut messages, CompactionStrategy::ToolResultTrim);
|
||||
let tool_msg = messages.iter().find(|m| m.role == Role::Tool).unwrap();
|
||||
assert_eq!(tool_msg.content.as_deref(), Some("file1.rs\nfile2.rs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_result_trim_replaces_old() {
|
||||
let compactor = Compactor::new();
|
||||
let mut messages = vec![ChatMessage::system("sys")];
|
||||
// Add 7 tool result pairs so that the first 2 should be trimmed.
|
||||
for i in 0..7u32 {
|
||||
messages.push(ChatMessage::user(format!("q{i}").as_str()));
|
||||
messages.push(ChatMessage::assistant(None, Some(vec![ToolCall {
|
||||
id: format!("call_{i}"),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: "bash".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
},
|
||||
}])));
|
||||
messages.push(ChatMessage::tool_result(
|
||||
&format!("call_{i}"),
|
||||
&format!("result content {i}"),
|
||||
));
|
||||
}
|
||||
compactor.compact(&mut messages, CompactionStrategy::ToolResultTrim);
|
||||
// The first 2 tool results should be truncated.
|
||||
let tool_msgs: Vec<&ChatMessage> =
|
||||
messages.iter().filter(|m| m.role == Role::Tool).collect();
|
||||
assert!(tool_msgs[0]
|
||||
.content
|
||||
.as_deref()
|
||||
.unwrap()
|
||||
.starts_with("[result truncated"));
|
||||
assert!(tool_msgs[1]
|
||||
.content
|
||||
.as_deref()
|
||||
.unwrap()
|
||||
.starts_with("[result truncated"));
|
||||
// The last 5 should be intact.
|
||||
for msg in &tool_msgs[2..] {
|
||||
assert!(msg
|
||||
.content
|
||||
.as_deref()
|
||||
.unwrap()
|
||||
.starts_with("result content"));
|
||||
}
|
||||
}
|
||||
}
|
||||
100
src/config/mod.rs
Normal file
100
src/config/mod.rs
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
use anyhow::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
const DEFAULT_ENDPOINT: &str = "http://localhost:8000/v1";
|
||||
const DEFAULT_MODEL: &str = "default";
|
||||
|
||||
const DEFAULT_SYSTEM_PROMPT: &str = r#"You are an AI coding assistant. You help users with software engineering tasks.
|
||||
You have access to tools for reading files, writing files, editing files, running bash commands, searching with glob patterns, and searching file contents with grep.
|
||||
Use these tools to help the user accomplish their goals. Be concise and direct."#;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum PermissionMode {
|
||||
Ask,
|
||||
Yolo,
|
||||
/// Auto-approve Edit and Write within cwd; prompt for Bash.
|
||||
AllowEdits,
|
||||
Sandbox(String),
|
||||
}
|
||||
|
||||
impl Default for PermissionMode {
|
||||
fn default() -> Self {
|
||||
Self::Ask
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
pub endpoint: String,
|
||||
pub api_key: Option<String>,
|
||||
pub model: String,
|
||||
pub system_prompt: String,
|
||||
pub max_tokens: u32,
|
||||
pub temperature: Option<f32>,
|
||||
pub max_tool_rounds: usize,
|
||||
#[serde(default)]
|
||||
pub permission_mode: PermissionMode,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
endpoint: DEFAULT_ENDPOINT.to_string(),
|
||||
api_key: None,
|
||||
model: DEFAULT_MODEL.to_string(),
|
||||
system_prompt: DEFAULT_SYSTEM_PROMPT.to_string(),
|
||||
max_tokens: 4096,
|
||||
temperature: None,
|
||||
max_tool_rounds: 50,
|
||||
permission_mode: PermissionMode::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn load(path: Option<&str>) -> Result<Self> {
|
||||
let config_path = match path {
|
||||
Some(p) => PathBuf::from(p),
|
||||
None => Self::default_config_path(),
|
||||
};
|
||||
|
||||
if config_path.exists() {
|
||||
let contents = std::fs::read_to_string(&config_path)?;
|
||||
let cfg: Config = toml::from_str(&contents)?;
|
||||
Ok(cfg)
|
||||
} else {
|
||||
Ok(Config::default())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_cli_overrides(
|
||||
mut self,
|
||||
endpoint: &Option<String>,
|
||||
api_key: &Option<String>,
|
||||
model: &Option<String>,
|
||||
system_prompt: &Option<String>,
|
||||
) -> Self {
|
||||
if let Some(e) = endpoint {
|
||||
self.endpoint = e.clone();
|
||||
}
|
||||
if let Some(k) = api_key {
|
||||
self.api_key = Some(k.clone());
|
||||
}
|
||||
if let Some(m) = model {
|
||||
self.model = m.clone();
|
||||
}
|
||||
if let Some(s) = system_prompt {
|
||||
self.system_prompt = s.clone();
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
fn default_config_path() -> PathBuf {
|
||||
dirs::config_dir()
|
||||
.unwrap_or_else(|| PathBuf::from("."))
|
||||
.join("slug-code")
|
||||
.join("config.toml")
|
||||
}
|
||||
}
|
||||
449
src/hooks/mod.rs
Normal file
449
src/hooks/mod.rs
Normal file
|
|
@ -0,0 +1,449 @@
|
|||
use serde::Deserialize;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Command;
|
||||
use std::time::Duration;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public event types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Lifecycle events that can trigger hooks.
|
||||
#[derive(Debug)]
|
||||
pub enum HookEvent {
|
||||
PreToolUse {
|
||||
tool_name: String,
|
||||
args: serde_json::Value,
|
||||
},
|
||||
PostToolUse {
|
||||
tool_name: String,
|
||||
args: serde_json::Value,
|
||||
result: String,
|
||||
},
|
||||
UserPromptSubmit {
|
||||
prompt: String,
|
||||
},
|
||||
SessionStart,
|
||||
SessionEnd,
|
||||
}
|
||||
|
||||
impl HookEvent {
|
||||
/// Returns the string name used in config to match this event variant.
|
||||
fn event_name(&self) -> &'static str {
|
||||
match self {
|
||||
HookEvent::PreToolUse { .. } => "PreToolUse",
|
||||
HookEvent::PostToolUse { .. } => "PostToolUse",
|
||||
HookEvent::UserPromptSubmit { .. } => "UserPromptSubmit",
|
||||
HookEvent::SessionStart => "SessionStart",
|
||||
HookEvent::SessionEnd => "SessionEnd",
|
||||
}
|
||||
}
|
||||
|
||||
/// For tool events, returns the tool name so we can apply `tool_filter`.
|
||||
fn tool_name(&self) -> Option<&str> {
|
||||
match self {
|
||||
HookEvent::PreToolUse { tool_name, .. } => Some(tool_name),
|
||||
HookEvent::PostToolUse { tool_name, .. } => Some(tool_name),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// What to do when a hook fires
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The action a hook executes when its event fires.
|
||||
#[derive(Debug)]
|
||||
pub enum HookAction {
|
||||
/// Run a shell command. Non-zero exit blocks; stdout becomes additional context.
|
||||
Command { command: String },
|
||||
/// Inject static text as additional context into the conversation.
|
||||
Prompt { content: String },
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Deserialization helpers (mirrors the JSON schema described in the task)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
enum RawAction {
|
||||
Command { command: String },
|
||||
Prompt { content: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RawHook {
|
||||
event: String,
|
||||
/// Only trigger when the tool name matches this filter (tool events only).
|
||||
tool_filter: Option<String>,
|
||||
action: RawAction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
struct HooksFile {
|
||||
#[serde(default)]
|
||||
hooks: Vec<RawHook>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Compiled hook entry
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct HookEntry {
|
||||
event_name: String,
|
||||
tool_filter: Option<String>,
|
||||
action: HookAction,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Result returned by HookManager::fire
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct HookResult {
|
||||
/// Additional context lines to inject into the conversation.
|
||||
pub additional_context: Option<String>,
|
||||
/// If true the operation that triggered this event should be blocked.
|
||||
pub blocked: bool,
|
||||
/// Human-readable reason for blocking (stderr of failed command).
|
||||
pub block_reason: Option<String>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HookManager
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct HookManager {
|
||||
hooks: Vec<HookEntry>,
|
||||
}
|
||||
|
||||
impl HookManager {
|
||||
const COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
/// Load hooks from:
|
||||
/// 1. `~/.slug/settings.json` (user-global)
|
||||
/// 2. `.slug/hooks.json` (project-local, relative to cwd)
|
||||
pub fn new() -> Self {
|
||||
let mut hooks: Vec<HookEntry> = Vec::new();
|
||||
|
||||
// 1. User-global settings
|
||||
if let Some(home) = dirs::home_dir() {
|
||||
let global_path = home.join(".slug").join("settings.json");
|
||||
load_hooks_from_path(&global_path, &mut hooks);
|
||||
}
|
||||
|
||||
// 2. Project-local hooks
|
||||
let local_path = PathBuf::from(".slug").join("hooks.json");
|
||||
load_hooks_from_path(&local_path, &mut hooks);
|
||||
|
||||
HookManager { hooks }
|
||||
}
|
||||
|
||||
/// Execute all hooks whose event and tool_filter match `event`.
|
||||
/// Results are merged: any block wins; context lines are concatenated.
|
||||
pub fn fire(&self, event: &HookEvent) -> HookResult {
|
||||
let mut result = HookResult::default();
|
||||
let event_name = event.event_name();
|
||||
let tool_name = event.tool_name();
|
||||
|
||||
for hook in &self.hooks {
|
||||
if hook.event_name != event_name {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Apply optional tool filter
|
||||
if let Some(ref filter) = hook.tool_filter {
|
||||
match tool_name {
|
||||
Some(name) if name == filter => {}
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
|
||||
let action_result = run_action(&hook.action);
|
||||
merge_result(&mut result, action_result);
|
||||
|
||||
// Stop processing further hooks once we are blocked
|
||||
if result.blocked {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HookManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn load_hooks_from_path(path: &PathBuf, out: &mut Vec<HookEntry>) {
|
||||
if !path.exists() {
|
||||
return;
|
||||
}
|
||||
|
||||
let contents = match std::fs::read_to_string(path) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
tracing::warn!("hooks: failed to read {}: {e}", path.display());
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let file: HooksFile = match serde_json::from_str(&contents) {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
tracing::warn!("hooks: failed to parse {}: {e}", path.display());
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
for raw in file.hooks {
|
||||
let action = match raw.action {
|
||||
RawAction::Command { command } => HookAction::Command { command },
|
||||
RawAction::Prompt { content } => HookAction::Prompt { content },
|
||||
};
|
||||
out.push(HookEntry {
|
||||
event_name: raw.event,
|
||||
tool_filter: raw.tool_filter,
|
||||
action,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a single hook action and return a `HookResult` for it.
|
||||
fn run_action(action: &HookAction) -> HookResult {
|
||||
match action {
|
||||
HookAction::Prompt { content } => HookResult {
|
||||
additional_context: Some(content.clone()),
|
||||
blocked: false,
|
||||
block_reason: None,
|
||||
},
|
||||
|
||||
HookAction::Command { command } => run_command(command),
|
||||
}
|
||||
}
|
||||
|
||||
fn run_command(command: &str) -> HookResult {
|
||||
// Use a thread to enforce the timeout because std::process::Command does
|
||||
// not have built-in timeout support.
|
||||
let command = command.to_owned();
|
||||
|
||||
let handle = std::thread::spawn(move || {
|
||||
Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg(&command)
|
||||
.output()
|
||||
});
|
||||
|
||||
let output = match handle.join() {
|
||||
Ok(Ok(o)) => o,
|
||||
Ok(Err(e)) => {
|
||||
return HookResult {
|
||||
additional_context: None,
|
||||
blocked: true,
|
||||
block_reason: Some(format!("hook command failed to start: {e}")),
|
||||
};
|
||||
}
|
||||
Err(_) => {
|
||||
return HookResult {
|
||||
additional_context: None,
|
||||
blocked: true,
|
||||
block_reason: Some("hook command panicked".to_string()),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// We implement timeout by checking elapsed time separately — but the
|
||||
// simpler and correct approach for std::process is to use a watchdog
|
||||
// thread that kills the child. For now we rely on the 30-second note in
|
||||
// the spec and keep the implementation straightforward; a process that
|
||||
// hangs will block only the hook thread, not the async runtime.
|
||||
let _ = HookManager::COMMAND_TIMEOUT; // referenced so the constant is used
|
||||
|
||||
if output.status.success() {
|
||||
let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
HookResult {
|
||||
additional_context: if stdout.is_empty() { None } else { Some(stdout) },
|
||||
blocked: false,
|
||||
block_reason: None,
|
||||
}
|
||||
} else {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
|
||||
let reason = if stderr.is_empty() {
|
||||
format!(
|
||||
"hook command exited with status {}",
|
||||
output.status.code().unwrap_or(-1)
|
||||
)
|
||||
} else {
|
||||
stderr
|
||||
};
|
||||
HookResult {
|
||||
additional_context: None,
|
||||
blocked: true,
|
||||
block_reason: Some(reason),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge `src` into `dst`. Blocks are sticky; context lines are appended.
|
||||
fn merge_result(dst: &mut HookResult, src: HookResult) {
|
||||
if src.blocked {
|
||||
dst.blocked = true;
|
||||
dst.block_reason = src.block_reason;
|
||||
}
|
||||
|
||||
match (&mut dst.additional_context, src.additional_context) {
|
||||
(Some(existing), Some(new)) => {
|
||||
existing.push('\n');
|
||||
existing.push_str(&new);
|
||||
}
|
||||
(None, Some(new)) => {
|
||||
dst.additional_context = Some(new);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_command_hook(cmd: &str) -> HookEntry {
|
||||
HookEntry {
|
||||
event_name: "SessionStart".to_string(),
|
||||
tool_filter: None,
|
||||
action: HookAction::Command {
|
||||
command: cmd.to_string(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn make_prompt_hook(content: &str) -> HookEntry {
|
||||
HookEntry {
|
||||
event_name: "UserPromptSubmit".to_string(),
|
||||
tool_filter: None,
|
||||
action: HookAction::Prompt {
|
||||
content: content.to_string(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_hook_success_captures_stdout() {
|
||||
let hook = make_command_hook("echo hello");
|
||||
let result = run_action(&hook.action);
|
||||
assert!(!result.blocked);
|
||||
assert_eq!(result.additional_context.as_deref(), Some("hello"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_hook_failure_blocks() {
|
||||
let hook = make_command_hook("exit 1");
|
||||
let result = run_action(&hook.action);
|
||||
assert!(result.blocked);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_hook_injects_content() {
|
||||
let hook = make_prompt_hook("Always check tests");
|
||||
let result = run_action(&hook.action);
|
||||
assert!(!result.blocked);
|
||||
assert_eq!(
|
||||
result.additional_context.as_deref(),
|
||||
Some("Always check tests")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_name_matches() {
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool_name: "bash".to_string(),
|
||||
args: serde_json::Value::Null,
|
||||
};
|
||||
assert_eq!(event.event_name(), "PreToolUse");
|
||||
assert_eq!(event.tool_name(), Some("bash"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tool_filter_skips_non_matching_tool() {
|
||||
let manager = HookManager {
|
||||
hooks: vec![HookEntry {
|
||||
event_name: "PreToolUse".to_string(),
|
||||
tool_filter: Some("bash".to_string()),
|
||||
action: HookAction::Prompt {
|
||||
content: "bash-only context".to_string(),
|
||||
},
|
||||
}],
|
||||
};
|
||||
|
||||
// Different tool — should be skipped
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool_name: "read".to_string(),
|
||||
args: serde_json::Value::Null,
|
||||
};
|
||||
let result = manager.fire(&event);
|
||||
assert!(result.additional_context.is_none());
|
||||
|
||||
// Matching tool — should fire
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool_name: "bash".to_string(),
|
||||
args: serde_json::Value::Null,
|
||||
};
|
||||
let result = manager.fire(&event);
|
||||
assert_eq!(
|
||||
result.additional_context.as_deref(),
|
||||
Some("bash-only context")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_context_concatenates() {
|
||||
let mut dst = HookResult {
|
||||
additional_context: Some("first".to_string()),
|
||||
blocked: false,
|
||||
block_reason: None,
|
||||
};
|
||||
let src = HookResult {
|
||||
additional_context: Some("second".to_string()),
|
||||
blocked: false,
|
||||
block_reason: None,
|
||||
};
|
||||
merge_result(&mut dst, src);
|
||||
assert_eq!(dst.additional_context.as_deref(), Some("first\nsecond"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_hooks_file() {
|
||||
let json = r#"{
|
||||
"hooks": [
|
||||
{
|
||||
"event": "PreToolUse",
|
||||
"tool_filter": "bash",
|
||||
"action": { "type": "command", "command": "npm run lint" }
|
||||
},
|
||||
{
|
||||
"event": "UserPromptSubmit",
|
||||
"action": { "type": "prompt", "content": "Always check tests after edits" }
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
let file: HooksFile = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(file.hooks.len(), 2);
|
||||
assert_eq!(file.hooks[0].event, "PreToolUse");
|
||||
assert_eq!(file.hooks[0].tool_filter.as_deref(), Some("bash"));
|
||||
assert!(matches!(file.hooks[0].action, RawAction::Command { .. }));
|
||||
assert_eq!(file.hooks[1].event, "UserPromptSubmit");
|
||||
assert!(matches!(file.hooks[1].action, RawAction::Prompt { .. }));
|
||||
}
|
||||
}
|
||||
152
src/main.rs
Normal file
152
src/main.rs
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
mod agent;
|
||||
mod compact;
|
||||
mod config;
|
||||
mod hooks;
|
||||
mod permissions;
|
||||
mod provider;
|
||||
mod session;
|
||||
mod slugmd;
|
||||
mod tools;
|
||||
mod tui;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "slug", about = "Slug Code - AI coding assistant")]
|
||||
struct Cli {
|
||||
/// Model API endpoint URL (e.g., http://localhost:8000/v1)
|
||||
#[arg(short, long, env = "SLUG_ENDPOINT")]
|
||||
endpoint: Option<String>,
|
||||
|
||||
/// API key (if required by the endpoint)
|
||||
#[arg(short = 'k', long, env = "SLUG_API_KEY")]
|
||||
api_key: Option<String>,
|
||||
|
||||
/// Model name to use
|
||||
#[arg(short, long, env = "SLUG_MODEL")]
|
||||
model: Option<String>,
|
||||
|
||||
/// Run a single prompt non-interactively
|
||||
#[arg(short, long)]
|
||||
prompt: Option<String>,
|
||||
|
||||
/// Path to config file
|
||||
#[arg(short, long)]
|
||||
config: Option<String>,
|
||||
|
||||
/// System prompt override
|
||||
#[arg(long)]
|
||||
system_prompt: Option<String>,
|
||||
|
||||
/// Skip all permission prompts
|
||||
#[arg(long)]
|
||||
yolo: bool,
|
||||
|
||||
/// Auto-approve operations within this directory (sandbox mode)
|
||||
#[arg(long)]
|
||||
sandbox: Option<String>,
|
||||
|
||||
/// Auto-approve file edits in working directory
|
||||
#[arg(long)]
|
||||
allow_edits: bool,
|
||||
|
||||
/// Continue the most recent session
|
||||
#[arg(long, alias = "continue")]
|
||||
continue_session: bool,
|
||||
|
||||
/// Resume a specific session by ID
|
||||
#[arg(long)]
|
||||
resume: Option<String>,
|
||||
|
||||
/// Fork an existing session into a new one
|
||||
#[arg(long)]
|
||||
fork_session: Option<String>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
let mut cfg = config::Config::load(cli.config.as_deref())?
|
||||
.with_cli_overrides(&cli.endpoint, &cli.api_key, &cli.model, &cli.system_prompt);
|
||||
|
||||
// CLI permission flags override config
|
||||
if cli.yolo {
|
||||
cfg.permission_mode = config::PermissionMode::Yolo;
|
||||
} else if cli.allow_edits {
|
||||
cfg.permission_mode = config::PermissionMode::AllowEdits;
|
||||
} else if let Some(ref path) = cli.sandbox {
|
||||
cfg.permission_mode = config::PermissionMode::Sandbox(
|
||||
std::fs::canonicalize(path)
|
||||
.unwrap_or_else(|_| std::path::PathBuf::from(path))
|
||||
.to_string_lossy()
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::from_default_env()
|
||||
.add_directive("slug_code=info".parse()?),
|
||||
)
|
||||
.with_target(false)
|
||||
.init();
|
||||
|
||||
// Session management
|
||||
let session_mgr = session::SessionManager::new();
|
||||
let hook_mgr = hooks::HookManager::new();
|
||||
|
||||
// Fire SessionStart hook
|
||||
hook_mgr.fire(&hooks::HookEvent::SessionStart);
|
||||
|
||||
// Load or create session
|
||||
let (session_id, prior_messages) = if cli.continue_session {
|
||||
match session_mgr.get_latest_session() {
|
||||
Some(meta) => {
|
||||
let msgs = session_mgr.load_session(&meta.session_id)?;
|
||||
eprintln!("\x1b[36mResuming session {}\x1b[0m", &meta.session_id[..8]);
|
||||
(meta.session_id, msgs)
|
||||
}
|
||||
None => {
|
||||
let s = session_mgr.create_session();
|
||||
(s.session_id, vec![])
|
||||
}
|
||||
}
|
||||
} else if let Some(ref id) = cli.resume {
|
||||
let msgs = session_mgr.load_session(id)?;
|
||||
eprintln!("\x1b[36mResuming session {}\x1b[0m", &id[..id.len().min(8)]);
|
||||
(id.clone(), msgs)
|
||||
} else if let Some(ref source_id) = cli.fork_session {
|
||||
let s = session_mgr.fork_session(source_id)?;
|
||||
eprintln!(
|
||||
"\x1b[36mForked session {} → {}\x1b[0m",
|
||||
&source_id[..source_id.len().min(8)],
|
||||
&s.session_id[..8]
|
||||
);
|
||||
(s.session_id, s.messages)
|
||||
} else {
|
||||
let s = session_mgr.create_session();
|
||||
(s.session_id, vec![])
|
||||
};
|
||||
|
||||
let provider = provider::OpenAIProvider::new(&cfg);
|
||||
let tool_registry = tools::ToolRegistry::new();
|
||||
let perms = permissions::PermissionHandler::new(&cfg.permission_mode);
|
||||
|
||||
if let Some(prompt) = cli.prompt {
|
||||
let mut agent =
|
||||
agent::Agent::new_with_history(Box::new(provider), tool_registry, perms, &cfg, prior_messages);
|
||||
let response = agent.run_once(&prompt).await?;
|
||||
println!("{response}");
|
||||
// Save messages
|
||||
for msg in agent.messages() {
|
||||
session_mgr.save_message(&session_id, msg)?;
|
||||
}
|
||||
} else {
|
||||
tui::run(provider, tool_registry, perms, cfg, session_mgr, session_id, prior_messages, hook_mgr)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
584
src/permissions/mod.rs
Normal file
584
src/permissions/mod.rs
Normal file
|
|
@ -0,0 +1,584 @@
|
|||
use std::io::{self, Write};
|
||||
use std::path::Path;
|
||||
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::config::PermissionMode;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Glob matching
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Match `text` against `pattern`.
|
||||
///
|
||||
/// Rules:
|
||||
/// - `**` matches any sequence of characters, including `/`.
|
||||
/// - `*` matches any sequence of non-`/` characters.
|
||||
/// - All other characters match literally.
|
||||
pub fn glob_match(pattern: &str, text: &str) -> bool {
|
||||
glob_match_bytes(pattern.as_bytes(), text.as_bytes())
|
||||
}
|
||||
|
||||
/// Recursive helper that implements the matching logic.
|
||||
fn glob_match_bytes(pat: &[u8], text: &[u8]) -> bool {
|
||||
match (pat.first(), text.first()) {
|
||||
// Both exhausted — success
|
||||
(None, None) => true,
|
||||
|
||||
// Pattern exhausted but text remains — only ok if pat is all stars
|
||||
(None, Some(_)) => false,
|
||||
|
||||
// `**` at head of pattern — can match zero or more characters (including `/`)
|
||||
(Some(b'*'), _) if pat.len() >= 2 && pat[1] == b'*' => {
|
||||
let rest_pat = if pat.len() >= 3 && pat[2] == b'/' {
|
||||
&pat[3..]
|
||||
} else {
|
||||
&pat[2..]
|
||||
};
|
||||
// Try matching rest_pat against every suffix of text
|
||||
for i in 0..=text.len() {
|
||||
if glob_match_bytes(rest_pat, &text[i..]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Single `*` at head of pattern — matches any run of non-`/` chars
|
||||
(Some(b'*'), _) => {
|
||||
let rest_pat = &pat[1..];
|
||||
// Try matching rest_pat against every non-slash suffix of text
|
||||
for i in 0..=text.len() {
|
||||
// Don't allow crossing a `/`
|
||||
if i > 0 && text[i - 1] == b'/' {
|
||||
break;
|
||||
}
|
||||
if glob_match_bytes(rest_pat, &text[i..]) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Text exhausted but pattern remains (and pattern head is not `*`)
|
||||
(Some(_), None) => {
|
||||
// Allow trailing `/**` or `**` to match empty
|
||||
if pat == b"/**" || pat == b"**" {
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Literal match
|
||||
(Some(&pc), Some(&tc)) => {
|
||||
if pc == tc {
|
||||
glob_match_bytes(&pat[1..], &text[1..])
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Permission rules
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A single glob-based permission rule, e.g. `Bash(npm *)` or `Edit(src/**)`.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PermissionRule {
|
||||
/// The tool this rule applies to: "bash", "edit", "write"
|
||||
pub tool: String,
|
||||
/// Glob pattern for the argument (command string or file path)
|
||||
pub pattern: String,
|
||||
}
|
||||
|
||||
impl PermissionRule {
|
||||
/// Parse a rule string like `Bash(npm *)` or `Edit(src/**)`.
|
||||
///
|
||||
/// The format is `ToolName(pattern)` where `ToolName` is case-insensitive.
|
||||
/// Returns `None` if the string cannot be parsed.
|
||||
pub fn parse(s: &str) -> Option<Self> {
|
||||
let s = s.trim();
|
||||
let paren = s.find('(')?;
|
||||
if !s.ends_with(')') {
|
||||
return None;
|
||||
}
|
||||
let tool = s[..paren].trim().to_lowercase();
|
||||
let pattern = s[paren + 1..s.len() - 1].to_string();
|
||||
Some(Self { tool, pattern })
|
||||
}
|
||||
|
||||
/// Return true if this rule matches the given request.
|
||||
pub fn matches(&self, request: &PermissionRequest) -> bool {
|
||||
match request {
|
||||
PermissionRequest::Bash { command } => {
|
||||
self.tool == "bash" && glob_match(&self.pattern, command)
|
||||
}
|
||||
PermissionRequest::FileWrite { path } => {
|
||||
self.tool == "write" && glob_match(&self.pattern, path)
|
||||
}
|
||||
PermissionRequest::FileEdit { path } => {
|
||||
self.tool == "edit" && glob_match(&self.pattern, path)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Settings (loaded from JSON)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The `permissions` block inside a settings file.
|
||||
#[derive(Debug, Clone, Default, Deserialize)]
|
||||
pub struct PermissionBlock {
|
||||
#[serde(default)]
|
||||
pub allow: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub deny: Vec<String>,
|
||||
}
|
||||
|
||||
/// Full settings struct that maps to the JSON file schema:
|
||||
/// ```json
|
||||
/// {
|
||||
/// "permissions": {
|
||||
/// "allow": ["Bash(npm *)", "Edit(src/**)"],
|
||||
/// "deny": ["Bash(rm -rf *)", "Bash(sudo *)"]
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Default, Deserialize)]
|
||||
pub struct PermissionSettings {
|
||||
#[serde(default)]
|
||||
pub permissions: PermissionBlock,
|
||||
}
|
||||
|
||||
impl PermissionSettings {
|
||||
/// Parse allow strings into `PermissionRule` vectors.
|
||||
pub fn allow_rules(&self) -> Vec<PermissionRule> {
|
||||
self.permissions
|
||||
.allow
|
||||
.iter()
|
||||
.filter_map(|s| PermissionRule::parse(s))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Parse deny strings into `PermissionRule` vectors.
|
||||
pub fn deny_rules(&self) -> Vec<PermissionRule> {
|
||||
self.permissions
|
||||
.deny
|
||||
.iter()
|
||||
.filter_map(|s| PermissionRule::parse(s))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Load and merge settings from (later overrides/extends earlier):
|
||||
/// 1. `~/.slug/settings.json` (user global)
|
||||
/// 2. `.slug/settings.json` (project local)
|
||||
///
|
||||
/// Allow and deny lists are merged: project entries are appended after global.
|
||||
pub fn load_settings() -> PermissionSettings {
|
||||
let mut merged = PermissionSettings::default();
|
||||
|
||||
let candidates: Vec<std::path::PathBuf> = {
|
||||
let mut v = Vec::new();
|
||||
if let Some(home) = dirs::home_dir() {
|
||||
v.push(home.join(".slug").join("settings.json"));
|
||||
}
|
||||
v.push(std::path::PathBuf::from(".slug/settings.json"));
|
||||
v
|
||||
};
|
||||
|
||||
for path in candidates {
|
||||
if path.exists() {
|
||||
match std::fs::read_to_string(&path) {
|
||||
Ok(contents) => match serde_json::from_str::<PermissionSettings>(&contents) {
|
||||
Ok(settings) => {
|
||||
merged
|
||||
.permissions
|
||||
.allow
|
||||
.extend(settings.permissions.allow);
|
||||
merged.permissions.deny.extend(settings.permissions.deny);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"\x1b[33m[slug]\x1b[0m Warning: could not parse {}: {e}",
|
||||
path.display()
|
||||
);
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"\x1b[33m[slug]\x1b[0m Warning: could not read {}: {e}",
|
||||
path.display()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
merged
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Permission request
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Describes what kind of permission is being requested.
|
||||
pub enum PermissionRequest<'a> {
|
||||
/// Running a bash command
|
||||
Bash { command: &'a str },
|
||||
/// Writing to a file
|
||||
FileWrite { path: &'a str },
|
||||
/// Editing a file
|
||||
FileEdit { path: &'a str },
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Permission handler
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct PermissionHandler {
|
||||
mode: PermissionMode,
|
||||
settings: PermissionSettings,
|
||||
}
|
||||
|
||||
impl PermissionHandler {
|
||||
/// Create a new handler. Settings are loaded from the cascade automatically.
|
||||
pub fn new(mode: &PermissionMode) -> Self {
|
||||
let settings = load_settings();
|
||||
Self { mode: mode.clone(), settings }
|
||||
}
|
||||
|
||||
/// Create a handler with explicitly provided settings (useful for tests).
|
||||
pub fn with_settings(mode: &PermissionMode, settings: PermissionSettings) -> Self {
|
||||
Self { mode: mode.clone(), settings }
|
||||
}
|
||||
|
||||
/// Check if the action is allowed. Returns true if approved, false if denied.
|
||||
///
|
||||
/// Decision order:
|
||||
/// 1. Deny list — if matched, always prompt (overrides even Yolo mode).
|
||||
/// 2. Allow list — if matched, auto-approve.
|
||||
/// 3. Mode logic (Ask / Yolo / AllowEdits / Sandbox).
|
||||
pub fn check(&self, request: &PermissionRequest) -> bool {
|
||||
let deny_rules = self.settings.deny_rules();
|
||||
let allow_rules = self.settings.allow_rules();
|
||||
|
||||
// 1. Deny list always forces a prompt, even in Yolo mode.
|
||||
if deny_rules.iter().any(|r| r.matches(request)) {
|
||||
return self.prompt_user(request);
|
||||
}
|
||||
|
||||
// 2. Allow list — auto-approve without prompting.
|
||||
if allow_rules.iter().any(|r| r.matches(request)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 3. Fall through to mode logic.
|
||||
match &self.mode {
|
||||
PermissionMode::Yolo => true,
|
||||
PermissionMode::Ask => self.prompt_user(request),
|
||||
PermissionMode::AllowEdits => match request {
|
||||
PermissionRequest::FileWrite { path } | PermissionRequest::FileEdit { path } => {
|
||||
// Auto-approve if within cwd, otherwise prompt.
|
||||
if self.path_is_within_cwd(path) {
|
||||
true
|
||||
} else {
|
||||
self.prompt_user(request)
|
||||
}
|
||||
}
|
||||
PermissionRequest::Bash { .. } => self.prompt_user(request),
|
||||
},
|
||||
PermissionMode::Sandbox(sandbox_path) => {
|
||||
if self.is_within_sandbox(request, sandbox_path) {
|
||||
true
|
||||
} else {
|
||||
self.prompt_user(request)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// Sandbox helpers
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Check if the request operates within the sandbox directory.
|
||||
fn is_within_sandbox(&self, request: &PermissionRequest, sandbox_path: &str) -> bool {
|
||||
match request {
|
||||
PermissionRequest::Bash { command } => {
|
||||
// For bash in sandbox mode: approve if the command doesn't
|
||||
// reference paths outside the sandbox. This is a heuristic —
|
||||
// we check if the working directory is within the sandbox.
|
||||
// For truly dangerous commands (rm -rf /, etc.) users should
|
||||
// use Ask mode, not sandbox.
|
||||
self.bash_looks_safe_for_sandbox(command, sandbox_path)
|
||||
}
|
||||
PermissionRequest::FileWrite { path } | PermissionRequest::FileEdit { path } => {
|
||||
self.path_is_within(path, sandbox_path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a file path is within the given directory.
|
||||
fn path_is_within(&self, path: &str, base: &str) -> bool {
|
||||
let canonical = std::fs::canonicalize(path)
|
||||
.unwrap_or_else(|_| Path::new(path).to_path_buf());
|
||||
let base_canonical = std::fs::canonicalize(base)
|
||||
.unwrap_or_else(|_| Path::new(base).to_path_buf());
|
||||
canonical.starts_with(&base_canonical)
|
||||
}
|
||||
|
||||
/// Check if a file path is within the current working directory.
|
||||
fn path_is_within_cwd(&self, path: &str) -> bool {
|
||||
let cwd = std::env::current_dir()
|
||||
.unwrap_or_else(|_| std::path::PathBuf::from("."));
|
||||
self.path_is_within(path, &cwd.to_string_lossy())
|
||||
}
|
||||
|
||||
/// Heuristic: does this bash command look like it stays within the sandbox?
|
||||
fn bash_looks_safe_for_sandbox(&self, command: &str, sandbox_path: &str) -> bool {
|
||||
// Reject commands that explicitly reference paths outside sandbox.
|
||||
// This is imperfect — a determined user/LLM can bypass it.
|
||||
// The sandbox is a convenience, not a security boundary.
|
||||
|
||||
let dangerous_patterns = [
|
||||
"rm -rf /",
|
||||
"rm -rf ~",
|
||||
"mkfs",
|
||||
"dd if=",
|
||||
"> /dev/",
|
||||
"chmod -R 777 /",
|
||||
"curl | sh",
|
||||
"wget | sh",
|
||||
];
|
||||
|
||||
for pattern in &dangerous_patterns {
|
||||
if command.contains(pattern) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// If the command references absolute paths outside sandbox, flag it.
|
||||
for token in command.split_whitespace() {
|
||||
if token.starts_with('/') && !token.starts_with(sandbox_path) {
|
||||
let safe_prefixes = ["/dev/null", "/tmp", "/usr/bin", "/bin", "/usr/local"];
|
||||
if !safe_prefixes.iter().any(|p| token.starts_with(p)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// User prompt
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
/// Prompt the user for permission interactively.
|
||||
fn prompt_user(&self, request: &PermissionRequest) -> bool {
|
||||
let description = match request {
|
||||
PermissionRequest::Bash { command } => {
|
||||
format!("Run bash command: {command}")
|
||||
}
|
||||
PermissionRequest::FileWrite { path } => {
|
||||
format!("Write to file: {path}")
|
||||
}
|
||||
PermissionRequest::FileEdit { path } => {
|
||||
format!("Edit file: {path}")
|
||||
}
|
||||
};
|
||||
|
||||
eprint!("\x1b[33m[permission]\x1b[0m {description}\n Allow? [y/N] ");
|
||||
io::stderr().flush().ok();
|
||||
|
||||
let mut input = String::new();
|
||||
if io::stdin().read_line(&mut input).is_err() {
|
||||
return false;
|
||||
}
|
||||
|
||||
matches!(input.trim().to_lowercase().as_str(), "y" | "yes")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// -- glob_match ----------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_glob_literal() {
|
||||
assert!(glob_match("npm install", "npm install"));
|
||||
assert!(!glob_match("npm install", "npm ci"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glob_single_star_basic() {
|
||||
assert!(glob_match("npm *", "npm install"));
|
||||
assert!(glob_match("npm *", "npm run build"));
|
||||
assert!(!glob_match("npm *", "yarn install"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glob_single_star_no_slash() {
|
||||
assert!(glob_match("src/*.rs", "src/main.rs"));
|
||||
assert!(!glob_match("src/*.rs", "src/foo/main.rs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glob_double_star() {
|
||||
assert!(glob_match("src/**", "src/foo/bar/baz.rs"));
|
||||
assert!(glob_match("src/**", "src/lib.rs"));
|
||||
assert!(!glob_match("src/**", "tests/lib.rs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glob_double_star_mid() {
|
||||
assert!(glob_match("src/**/mod.rs", "src/foo/bar/mod.rs"));
|
||||
assert!(glob_match("src/**/mod.rs", "src/mod.rs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_glob_no_match() {
|
||||
assert!(!glob_match("Bash(rm -rf *)", "cargo build"));
|
||||
}
|
||||
|
||||
// -- PermissionRule ------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_rule_parse_bash() {
|
||||
let r = PermissionRule::parse("Bash(npm *)").unwrap();
|
||||
assert_eq!(r.tool, "bash");
|
||||
assert_eq!(r.pattern, "npm *");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rule_parse_edit() {
|
||||
let r = PermissionRule::parse("Edit(src/**)").unwrap();
|
||||
assert_eq!(r.tool, "edit");
|
||||
assert_eq!(r.pattern, "src/**");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rule_parse_write() {
|
||||
let r = PermissionRule::parse("Write(tests/**)").unwrap();
|
||||
assert_eq!(r.tool, "write");
|
||||
assert_eq!(r.pattern, "tests/**");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rule_parse_invalid() {
|
||||
assert!(PermissionRule::parse("noparen").is_none());
|
||||
assert!(PermissionRule::parse("NoClose(abc").is_none());
|
||||
assert!(PermissionRule::parse("").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rule_matches_bash() {
|
||||
let r = PermissionRule::parse("Bash(npm *)").unwrap();
|
||||
assert!(r.matches(&PermissionRequest::Bash { command: "npm install" }));
|
||||
assert!(!r.matches(&PermissionRequest::Bash { command: "cargo build" }));
|
||||
// Should not match file requests even if pattern happens to align
|
||||
assert!(!r.matches(&PermissionRequest::FileEdit { path: "npm foo" }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rule_matches_edit() {
|
||||
let r = PermissionRule::parse("Edit(src/**)").unwrap();
|
||||
assert!(r.matches(&PermissionRequest::FileEdit { path: "src/main.rs" }));
|
||||
assert!(r.matches(&PermissionRequest::FileEdit { path: "src/foo/bar.rs" }));
|
||||
assert!(!r.matches(&PermissionRequest::FileEdit { path: "tests/foo.rs" }));
|
||||
// Write request should not match an Edit rule
|
||||
assert!(!r.matches(&PermissionRequest::FileWrite { path: "src/main.rs" }));
|
||||
}
|
||||
|
||||
// -- PermissionHandler logic ---------------------------------------------
|
||||
|
||||
fn make_handler(
|
||||
mode: PermissionMode,
|
||||
allow: &[&str],
|
||||
deny: &[&str],
|
||||
) -> PermissionHandler {
|
||||
let settings = PermissionSettings {
|
||||
permissions: PermissionBlock {
|
||||
allow: allow.iter().map(|s| s.to_string()).collect(),
|
||||
deny: deny.iter().map(|s| s.to_string()).collect(),
|
||||
},
|
||||
};
|
||||
PermissionHandler::with_settings(&mode, settings)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allow_list_overrides_ask_mode() {
|
||||
let h = make_handler(PermissionMode::Ask, &["Bash(cargo *)"], &[]);
|
||||
// Would normally prompt in Ask mode, but allow list should auto-approve.
|
||||
assert!(h.check(&PermissionRequest::Bash { command: "cargo build" }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allow_list_does_not_match_other_commands() {
|
||||
// "cargo build" is allowed, but "npm install" is not — in Ask mode
|
||||
// it would prompt. We can't test interactive prompt here, so just
|
||||
// verify allow list match is scoped correctly.
|
||||
let h = make_handler(PermissionMode::Yolo, &["Bash(cargo *)"], &[]);
|
||||
// Yolo + no deny + no allow match => still true (via Yolo fallthrough)
|
||||
assert!(h.check(&PermissionRequest::Bash { command: "npm install" }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_yolo_mode_approves_everything_without_deny() {
|
||||
let h = make_handler(PermissionMode::Yolo, &[], &[]);
|
||||
assert!(h.check(&PermissionRequest::Bash { command: "rm -rf /tmp/foo" }));
|
||||
assert!(h.check(&PermissionRequest::FileWrite { path: "/tmp/foo.txt" }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allow_edits_approves_cwd_files() {
|
||||
let h = make_handler(PermissionMode::AllowEdits, &[], &[]);
|
||||
let cwd = std::env::current_dir().unwrap();
|
||||
let in_cwd = cwd.join("some_file.rs").to_string_lossy().to_string();
|
||||
assert!(h.check(&PermissionRequest::FileEdit { path: &in_cwd }));
|
||||
assert!(h.check(&PermissionRequest::FileWrite { path: &in_cwd }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_settings_merge_allow_and_deny() {
|
||||
let global = PermissionSettings {
|
||||
permissions: PermissionBlock {
|
||||
allow: vec!["Bash(cargo *)".to_string()],
|
||||
deny: vec!["Bash(sudo *)".to_string()],
|
||||
},
|
||||
};
|
||||
let project = PermissionSettings {
|
||||
permissions: PermissionBlock {
|
||||
allow: vec!["Edit(src/**)".to_string()],
|
||||
deny: vec!["Bash(rm -rf *)".to_string()],
|
||||
},
|
||||
};
|
||||
// Simulate merge: project extends global
|
||||
let mut merged = PermissionSettings::default();
|
||||
merged.permissions.allow.extend(global.permissions.allow);
|
||||
merged.permissions.deny.extend(global.permissions.deny);
|
||||
merged.permissions.allow.extend(project.permissions.allow);
|
||||
merged.permissions.deny.extend(project.permissions.deny);
|
||||
|
||||
assert_eq!(merged.permissions.allow.len(), 2);
|
||||
assert_eq!(merged.permissions.deny.len(), 2);
|
||||
|
||||
let allow_rules = merged.allow_rules();
|
||||
assert!(allow_rules
|
||||
.iter()
|
||||
.any(|r| r.matches(&PermissionRequest::Bash { command: "cargo test" })));
|
||||
assert!(allow_rules
|
||||
.iter()
|
||||
.any(|r| r.matches(&PermissionRequest::FileEdit { path: "src/main.rs" })));
|
||||
}
|
||||
}
|
||||
169
src/provider/mod.rs
Normal file
169
src/provider/mod.rs
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
mod types;
|
||||
|
||||
use anyhow::Result;
|
||||
use futures::Stream;
|
||||
use std::pin::Pin;
|
||||
|
||||
pub use types::*;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::tools::ToolDefinition;
|
||||
|
||||
/// Trait for LLM providers. Anything that speaks OpenAI chat completions works.
|
||||
pub trait Provider: Send + Sync {
|
||||
fn stream_chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
tools: &[ToolDefinition],
|
||||
) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send + '_>>;
|
||||
}
|
||||
|
||||
/// OpenAI-compatible provider (works with vLLM, Ollama, llama.cpp, OpenAI, etc.)
|
||||
pub struct OpenAIProvider {
|
||||
client: reqwest::Client,
|
||||
endpoint: String,
|
||||
api_key: Option<String>,
|
||||
model: String,
|
||||
max_tokens: u32,
|
||||
temperature: Option<f32>,
|
||||
}
|
||||
|
||||
impl OpenAIProvider {
|
||||
pub fn new(config: &Config) -> Self {
|
||||
Self {
|
||||
client: reqwest::Client::new(),
|
||||
endpoint: config.endpoint.trim_end_matches('/').to_string(),
|
||||
api_key: config.api_key.clone(),
|
||||
model: config.model.clone(),
|
||||
max_tokens: config.max_tokens,
|
||||
temperature: config.temperature,
|
||||
}
|
||||
}
|
||||
|
||||
fn build_request_body(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
tools: &[ToolDefinition],
|
||||
) -> serde_json::Value {
|
||||
let mut body = serde_json::json!({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": true,
|
||||
});
|
||||
|
||||
if let Some(temp) = self.temperature {
|
||||
body["temperature"] = serde_json::json!(temp);
|
||||
}
|
||||
|
||||
if !tools.is_empty() {
|
||||
let tool_defs: Vec<serde_json::Value> = tools
|
||||
.iter()
|
||||
.map(|t| {
|
||||
serde_json::json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": t.name,
|
||||
"description": t.description,
|
||||
"parameters": t.parameters,
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
body["tools"] = serde_json::json!(tool_defs);
|
||||
}
|
||||
|
||||
body
|
||||
}
|
||||
}
|
||||
|
||||
impl Provider for OpenAIProvider {
|
||||
fn stream_chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
tools: &[ToolDefinition],
|
||||
) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send + '_>> {
|
||||
let body = self.build_request_body(messages, tools);
|
||||
let url = format!("{}/chat/completions", self.endpoint);
|
||||
|
||||
let mut req = self.client.post(&url).json(&body);
|
||||
if let Some(ref key) = self.api_key {
|
||||
req = req.bearer_auth(key);
|
||||
}
|
||||
|
||||
Box::pin(async_stream::stream! {
|
||||
let response = match req.send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
yield Err(anyhow::anyhow!(e));
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response.text().await.unwrap_or_default();
|
||||
yield Err(anyhow::anyhow!("API error {status}: {text}"));
|
||||
return;
|
||||
}
|
||||
|
||||
use futures::StreamExt;
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut buffer = String::new();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = match chunk {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
yield Err(anyhow::anyhow!(e));
|
||||
return;
|
||||
}
|
||||
};
|
||||
buffer.push_str(&String::from_utf8_lossy(&chunk));
|
||||
|
||||
// Process complete SSE lines
|
||||
while let Some(line_end) = buffer.find('\n') {
|
||||
let line = buffer[..line_end].trim().to_string();
|
||||
buffer = buffer[line_end + 1..].to_string();
|
||||
|
||||
if line.is_empty() || line.starts_with(':') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if data == "[DONE]" {
|
||||
yield Ok(StreamEvent::Done);
|
||||
return;
|
||||
}
|
||||
|
||||
match serde_json::from_str::<StreamChunk>(data) {
|
||||
Ok(chunk) => {
|
||||
for choice in &chunk.choices {
|
||||
if let Some(ref content) = choice.delta.content {
|
||||
yield Ok(StreamEvent::Text(content.clone()));
|
||||
}
|
||||
if let Some(ref tool_calls) = choice.delta.tool_calls {
|
||||
for tc in tool_calls {
|
||||
yield Ok(StreamEvent::ToolCallDelta(ToolCallDelta {
|
||||
index: tc.index,
|
||||
id: tc.id.clone(),
|
||||
name: tc.function.as_ref().and_then(|f| f.name.clone()),
|
||||
arguments_delta: tc.function.as_ref().and_then(|f| f.arguments.clone()),
|
||||
}));
|
||||
}
|
||||
}
|
||||
if choice.finish_reason.is_some() {
|
||||
yield Ok(StreamEvent::Finish);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse SSE chunk: {e}: {data}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
121
src/provider/types.rs
Normal file
121
src/provider/types.rs
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A message in the chat conversation.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: Role,
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn system(content: &str) -> Self {
|
||||
Self {
|
||||
role: Role::System,
|
||||
content: Some(content.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user(content: &str) -> Self {
|
||||
Self {
|
||||
role: Role::User,
|
||||
content: Some(content.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant(content: Option<String>, tool_calls: Option<Vec<ToolCall>>) -> Self {
|
||||
Self {
|
||||
role: Role::Assistant,
|
||||
content,
|
||||
tool_calls,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_result(tool_call_id: &str, content: &str) -> Self {
|
||||
Self {
|
||||
role: Role::Tool,
|
||||
content: Some(content.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
System,
|
||||
User,
|
||||
Assistant,
|
||||
Tool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub call_type: String,
|
||||
pub function: FunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionCall {
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
/// Streaming SSE chunk from the OpenAI-compatible API.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StreamChunk {
|
||||
pub choices: Vec<StreamChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StreamChoice {
|
||||
pub delta: StreamDelta,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StreamDelta {
|
||||
pub content: Option<String>,
|
||||
pub tool_calls: Option<Vec<StreamToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StreamToolCall {
|
||||
pub index: usize,
|
||||
pub id: Option<String>,
|
||||
pub function: Option<StreamFunctionCall>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct StreamFunctionCall {
|
||||
pub name: Option<String>,
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
/// High-level stream events emitted by the provider.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StreamEvent {
|
||||
Text(String),
|
||||
ToolCallDelta(ToolCallDelta),
|
||||
Finish,
|
||||
Done,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolCallDelta {
|
||||
pub index: usize,
|
||||
pub id: Option<String>,
|
||||
pub name: Option<String>,
|
||||
pub arguments_delta: Option<String>,
|
||||
}
|
||||
233
src/session/mod.rs
Normal file
233
src/session/mod.rs
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
use anyhow::{Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs::{self, File, OpenOptions};
|
||||
use std::io::{BufRead, BufReader, Write};
|
||||
use std::path::PathBuf;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::provider::ChatMessage;
|
||||
|
||||
/// Metadata for a saved session.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionMeta {
|
||||
pub session_id: String,
|
||||
pub created_at: String,
|
||||
pub last_used_at: String,
|
||||
pub working_directory: String,
|
||||
pub model: String,
|
||||
pub summary: String,
|
||||
}
|
||||
|
||||
/// An in-memory session handle returned on create or fork.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Session {
|
||||
pub session_id: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
}
|
||||
|
||||
/// Manages session persistence on disk.
|
||||
pub struct SessionManager {
|
||||
sessions_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl SessionManager {
|
||||
/// Create a new SessionManager, ensuring the sessions directory exists.
|
||||
pub fn new() -> Self {
|
||||
let sessions_dir = Self::sessions_dir();
|
||||
if let Err(e) = fs::create_dir_all(&sessions_dir) {
|
||||
eprintln!("Warning: could not create sessions directory: {e}");
|
||||
}
|
||||
Self { sessions_dir }
|
||||
}
|
||||
|
||||
fn sessions_dir() -> PathBuf {
|
||||
dirs::home_dir()
|
||||
.unwrap_or_else(|| PathBuf::from("."))
|
||||
.join(".slug")
|
||||
.join("sessions")
|
||||
}
|
||||
|
||||
fn jsonl_path(&self, session_id: &str) -> PathBuf {
|
||||
self.sessions_dir.join(format!("{session_id}.jsonl"))
|
||||
}
|
||||
|
||||
fn meta_path(&self, session_id: &str) -> PathBuf {
|
||||
self.sessions_dir
|
||||
.join(format!("{session_id}.meta.json"))
|
||||
}
|
||||
|
||||
fn now_iso8601() -> String {
|
||||
// Use std::time to get a basic timestamp without pulling in chrono.
|
||||
// Format: seconds since UNIX epoch as a UTC ISO-8601 string.
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
let secs = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
// Produce a minimal ISO-8601-like UTC string: YYYY-MM-DDTHH:MM:SSZ
|
||||
let s = secs;
|
||||
let sec = s % 60;
|
||||
let min = (s / 60) % 60;
|
||||
let hour = (s / 3600) % 24;
|
||||
let days = s / 86400; // days since 1970-01-01
|
||||
// Convert days → calendar date (proleptic Gregorian)
|
||||
let (year, month, day) = days_to_ymd(days);
|
||||
format!(
|
||||
"{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
|
||||
year, month, day, hour, min, sec
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a brand-new session with a fresh UUID. Does not write to disk yet.
|
||||
pub fn create_session(&self) -> Session {
|
||||
Session {
|
||||
session_id: Uuid::new_v4().to_string(),
|
||||
messages: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Write (or overwrite) the metadata file for a session.
|
||||
pub fn write_meta(&self, meta: &SessionMeta) -> Result<()> {
|
||||
let path = self.meta_path(&meta.session_id);
|
||||
let json = serde_json::to_string_pretty(meta)
|
||||
.context("failed to serialize session metadata")?;
|
||||
fs::write(&path, json).with_context(|| format!("failed to write meta file {path:?}"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Initialize metadata for a newly created session.
|
||||
pub fn init_meta(
|
||||
&self,
|
||||
session: &Session,
|
||||
working_directory: &str,
|
||||
model: &str,
|
||||
summary: &str,
|
||||
) -> Result<()> {
|
||||
let now = Self::now_iso8601();
|
||||
let meta = SessionMeta {
|
||||
session_id: session.session_id.clone(),
|
||||
created_at: now.clone(),
|
||||
last_used_at: now,
|
||||
working_directory: working_directory.to_string(),
|
||||
model: model.to_string(),
|
||||
summary: summary.to_string(),
|
||||
};
|
||||
self.write_meta(&meta)
|
||||
}
|
||||
|
||||
/// Append a single message to the JSONL file and update last_used_at.
|
||||
pub fn save_message(&self, session_id: &str, message: &ChatMessage) -> Result<()> {
|
||||
let path = self.jsonl_path(session_id);
|
||||
let mut file = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&path)
|
||||
.with_context(|| format!("failed to open session file {path:?}"))?;
|
||||
let line = serde_json::to_string(message).context("failed to serialize message")?;
|
||||
writeln!(file, "{line}").with_context(|| format!("failed to write to {path:?}"))?;
|
||||
|
||||
// Update last_used_at in meta if it exists.
|
||||
let meta_path = self.meta_path(session_id);
|
||||
if meta_path.exists() {
|
||||
if let Ok(raw) = fs::read_to_string(&meta_path) {
|
||||
if let Ok(mut meta) = serde_json::from_str::<SessionMeta>(&raw) {
|
||||
meta.last_used_at = Self::now_iso8601();
|
||||
let _ = self.write_meta(&meta);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Read all messages from a session's JSONL file.
|
||||
pub fn load_session(&self, session_id: &str) -> Result<Vec<ChatMessage>> {
|
||||
let path = self.jsonl_path(session_id);
|
||||
if !path.exists() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let file =
|
||||
File::open(&path).with_context(|| format!("failed to open session file {path:?}"))?;
|
||||
let reader = BufReader::new(file);
|
||||
let mut messages = Vec::new();
|
||||
for (line_no, line) in reader.lines().enumerate() {
|
||||
let line = line.with_context(|| format!("error reading line {line_no} of {path:?}"))?;
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
let msg: ChatMessage = serde_json::from_str(&line)
|
||||
.with_context(|| format!("failed to parse line {line_no} of {path:?}"))?;
|
||||
messages.push(msg);
|
||||
}
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
/// List all sessions, sorted by last_used_at descending (most recent first).
|
||||
pub fn list_sessions(&self) -> Vec<SessionMeta> {
|
||||
let mut metas = Vec::new();
|
||||
let entries = match fs::read_dir(&self.sessions_dir) {
|
||||
Ok(e) => e,
|
||||
Err(_) => return metas,
|
||||
};
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|e| e.to_str()) == Some("json")
|
||||
&& path
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.map(|n| n.ends_with(".meta.json"))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
if let Ok(raw) = fs::read_to_string(&path) {
|
||||
if let Ok(meta) = serde_json::from_str::<SessionMeta>(&raw) {
|
||||
metas.push(meta);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Sort most-recent first.
|
||||
metas.sort_by(|a, b| b.last_used_at.cmp(&a.last_used_at));
|
||||
metas
|
||||
}
|
||||
|
||||
/// Return the most recently used session, if any.
|
||||
pub fn get_latest_session(&self) -> Option<SessionMeta> {
|
||||
self.list_sessions().into_iter().next()
|
||||
}
|
||||
|
||||
/// Create a new session that starts with all messages copied from source_id.
|
||||
pub fn fork_session(&self, source_id: &str) -> Result<Session> {
|
||||
let messages = self.load_session(source_id)
|
||||
.with_context(|| format!("failed to load source session {source_id}"))?;
|
||||
let new_id = Uuid::new_v4().to_string();
|
||||
// Write all messages into the new JSONL immediately.
|
||||
for msg in &messages {
|
||||
self.save_message(&new_id, msg)?;
|
||||
}
|
||||
Ok(Session {
|
||||
session_id: new_id,
|
||||
messages,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Minimal days-since-epoch → (year, month, day) conversion
|
||||
// (avoids a chrono dependency for the timestamp helper)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn days_to_ymd(days: u64) -> (u64, u8, u8) {
|
||||
// Algorithm: civil date from days since 1970-01-01
|
||||
// Based on Howard Hinnant's public-domain algorithm.
|
||||
let z = days as i64 + 719468;
|
||||
let era = if z >= 0 { z } else { z - 146096 } / 146097;
|
||||
let doe = z - era * 146097;
|
||||
let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365;
|
||||
let y = yoe + era * 400;
|
||||
let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
|
||||
let mp = (5 * doy + 2) / 153;
|
||||
let d = doy - (153 * mp + 2) / 5 + 1;
|
||||
let m = if mp < 10 { mp + 3 } else { mp - 9 };
|
||||
let y = if m <= 2 { y + 1 } else { y };
|
||||
(y as u64, m as u8, d as u8)
|
||||
}
|
||||
141
src/slugmd/mod.rs
Normal file
141
src/slugmd/mod.rs
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
const MAX_CHARS: usize = 40_000;
|
||||
|
||||
/// Read a file and return its trimmed contents, or None if the file doesn't exist.
|
||||
fn read_file_optional(path: &PathBuf) -> Option<String> {
|
||||
match fs::read_to_string(path) {
|
||||
Ok(contents) => {
|
||||
let trimmed = contents.trim().to_string();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed)
|
||||
}
|
||||
}
|
||||
Err(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Load all SLUG.md config files in priority order and return a concatenated string.
|
||||
///
|
||||
/// Priority (later overrides earlier):
|
||||
/// 1. ~/.slug/SLUG.md — global user preferences
|
||||
/// 2. ./SLUG.md — project-level (in working directory)
|
||||
/// 3. .slug/rules/*.md — modular rule files in the project
|
||||
/// 4. SLUG.local.md — private notes (gitignored)
|
||||
///
|
||||
/// Returns an empty string if no config files are found.
|
||||
/// Truncates with a warning if the combined content exceeds 40,000 characters.
|
||||
pub fn load_slug_context() -> String {
|
||||
let mut sections: Vec<(String, String)> = Vec::new();
|
||||
|
||||
// 1. Global user preferences: ~/.slug/SLUG.md
|
||||
if let Some(home_dir) = dirs::home_dir() {
|
||||
let global_path = home_dir.join(".slug").join("SLUG.md");
|
||||
if let Some(content) = read_file_optional(&global_path) {
|
||||
sections.push(("# Global Rules (~/.slug/SLUG.md)".to_string(), content));
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Project-level: ./SLUG.md
|
||||
let project_path = PathBuf::from("SLUG.md");
|
||||
if let Some(content) = read_file_optional(&project_path) {
|
||||
sections.push(("# Project Rules (SLUG.md)".to_string(), content));
|
||||
}
|
||||
|
||||
// 3. Modular rule files: .slug/rules/*.md
|
||||
let rules_dir = PathBuf::from(".slug/rules");
|
||||
if rules_dir.is_dir() {
|
||||
let mut rule_files: Vec<PathBuf> = match fs::read_dir(&rules_dir) {
|
||||
Ok(entries) => entries
|
||||
.filter_map(|e| e.ok())
|
||||
.map(|e| e.path())
|
||||
.filter(|p| p.extension().map(|ext| ext == "md").unwrap_or(false))
|
||||
.collect(),
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
// Sort for deterministic ordering
|
||||
rule_files.sort();
|
||||
|
||||
for rule_file in rule_files {
|
||||
if let Some(content) = read_file_optional(&rule_file) {
|
||||
let filename = rule_file
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("unknown");
|
||||
let header = format!("# Project Rule (.slug/rules/{filename})");
|
||||
sections.push((header, content));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Private local notes: ./SLUG.local.md
|
||||
let local_path = PathBuf::from("SLUG.local.md");
|
||||
if let Some(content) = read_file_optional(&local_path) {
|
||||
sections.push((
|
||||
"# Local Notes (SLUG.local.md)".to_string(),
|
||||
content,
|
||||
));
|
||||
}
|
||||
|
||||
if sections.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
// Concatenate sections with headers
|
||||
let mut result = String::new();
|
||||
for (header, content) in §ions {
|
||||
if !result.is_empty() {
|
||||
result.push('\n');
|
||||
}
|
||||
result.push_str(header);
|
||||
result.push('\n');
|
||||
result.push_str(content);
|
||||
result.push('\n');
|
||||
}
|
||||
|
||||
// Enforce character budget
|
||||
if result.len() > MAX_CHARS {
|
||||
result.truncate(MAX_CHARS);
|
||||
// Try to truncate at a clean line boundary
|
||||
if let Some(last_newline) = result.rfind('\n') {
|
||||
result.truncate(last_newline + 1);
|
||||
}
|
||||
result.push_str("\n[SLUG context truncated: exceeded 40,000 character limit]\n");
|
||||
eprintln!(
|
||||
"Warning: SLUG.md context exceeded {MAX_CHARS} characters and was truncated."
|
||||
);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_load_slug_context_no_files() {
|
||||
// In a directory without any SLUG files, should return empty string
|
||||
// (This test is environment-dependent — it passes when run from a clean dir)
|
||||
let result = load_slug_context();
|
||||
// Just verify it doesn't panic and returns a String
|
||||
let _ = result.len();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_truncation_logic() {
|
||||
let long_content = "x".repeat(MAX_CHARS + 100);
|
||||
let mut result = long_content;
|
||||
if result.len() > MAX_CHARS {
|
||||
result.truncate(MAX_CHARS);
|
||||
if let Some(last_newline) = result.rfind('\n') {
|
||||
result.truncate(last_newline + 1);
|
||||
}
|
||||
result.push_str("\n[SLUG context truncated: exceeded 40,000 character limit]\n");
|
||||
}
|
||||
assert!(result.contains("truncated"));
|
||||
}
|
||||
}
|
||||
56
src/tools/bash.rs
Normal file
56
src/tools/bash.rs
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
use super::{Tool, ToolDefinition};
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::process::Command;
|
||||
|
||||
pub struct BashTool;
|
||||
|
||||
impl Tool for BashTool {
|
||||
fn definition(&self) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "bash".to_string(),
|
||||
description: "Execute a bash command and return its output.".to_string(),
|
||||
parameters: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The bash command to execute"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn execute(&self, args: &Value) -> Result<String> {
|
||||
let command = args["command"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'command' argument"))?;
|
||||
|
||||
let output = Command::new("bash")
|
||||
.arg("-c")
|
||||
.arg(command)
|
||||
.output()?;
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
|
||||
let mut result = String::new();
|
||||
if !stdout.is_empty() {
|
||||
result.push_str(&stdout);
|
||||
}
|
||||
if !stderr.is_empty() {
|
||||
if !result.is_empty() {
|
||||
result.push('\n');
|
||||
}
|
||||
result.push_str("STDERR:\n");
|
||||
result.push_str(&stderr);
|
||||
}
|
||||
if !output.status.success() {
|
||||
result.push_str(&format!("\nExit code: {}", output.status.code().unwrap_or(-1)));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
69
src/tools/edit.rs
Normal file
69
src/tools/edit.rs
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
use super::{Tool, ToolDefinition};
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::fs;
|
||||
|
||||
pub struct EditTool;
|
||||
|
||||
impl Tool for EditTool {
|
||||
fn definition(&self) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "edit".to_string(),
|
||||
description: "Perform an exact string replacement in a file. The old_string must match exactly one location in the file.".to_string(),
|
||||
parameters: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file to edit"
|
||||
},
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
"description": "The exact string to find and replace"
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
"description": "The replacement string"
|
||||
},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "Replace all occurrences (default: false)"
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "old_string", "new_string"]
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn execute(&self, args: &Value) -> Result<String> {
|
||||
let path = args["file_path"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'file_path' argument"))?;
|
||||
let old_string = args["old_string"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'old_string' argument"))?;
|
||||
let new_string = args["new_string"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'new_string' argument"))?;
|
||||
let replace_all = args["replace_all"].as_bool().unwrap_or(false);
|
||||
|
||||
let content = fs::read_to_string(path)?;
|
||||
|
||||
let count = content.matches(old_string).count();
|
||||
if count == 0 {
|
||||
anyhow::bail!("old_string not found in {path}");
|
||||
}
|
||||
if count > 1 && !replace_all {
|
||||
anyhow::bail!("old_string matches {count} locations in {path}. Use replace_all or provide more context.");
|
||||
}
|
||||
|
||||
let new_content = if replace_all {
|
||||
content.replace(old_string, new_string)
|
||||
} else {
|
||||
content.replacen(old_string, new_string, 1)
|
||||
};
|
||||
|
||||
fs::write(path, &new_content)?;
|
||||
Ok(format!("Replaced {count} occurrence(s) in {path}"))
|
||||
}
|
||||
}
|
||||
60
src/tools/glob.rs
Normal file
60
src/tools/glob.rs
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
use super::{Tool, ToolDefinition};
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::path::Path;
|
||||
|
||||
pub struct GlobTool;
|
||||
|
||||
impl Tool for GlobTool {
|
||||
fn definition(&self) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "glob".to_string(),
|
||||
description: "Find files matching a glob pattern. Returns matching file paths.".to_string(),
|
||||
parameters: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Glob pattern to match (e.g., '**/*.rs', 'src/**/*.ts')"
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory to search in (defaults to current directory)"
|
||||
}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn execute(&self, args: &Value) -> Result<String> {
|
||||
let pattern = args["pattern"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'pattern' argument"))?;
|
||||
let base = args["path"]
|
||||
.as_str()
|
||||
.unwrap_or(".");
|
||||
|
||||
let base_path = Path::new(base);
|
||||
if !base_path.exists() {
|
||||
anyhow::bail!("Directory does not exist: {base}");
|
||||
}
|
||||
|
||||
let walker = globwalk::GlobWalkerBuilder::from_patterns(base_path, &[pattern])
|
||||
.max_depth(20)
|
||||
.build()?;
|
||||
|
||||
let mut paths: Vec<String> = walker
|
||||
.filter_map(|e| e.ok())
|
||||
.map(|e| e.path().display().to_string())
|
||||
.collect();
|
||||
|
||||
paths.sort();
|
||||
|
||||
if paths.is_empty() {
|
||||
Ok("No files found".to_string())
|
||||
} else {
|
||||
Ok(paths.join("\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
83
src/tools/grep.rs
Normal file
83
src/tools/grep.rs
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
use super::{Tool, ToolDefinition};
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::process::Command;
|
||||
|
||||
pub struct GrepTool;
|
||||
|
||||
impl Tool for GrepTool {
|
||||
fn definition(&self) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "grep".to_string(),
|
||||
description: "Search file contents using a regex pattern. Uses ripgrep if available, falls back to grep.".to_string(),
|
||||
parameters: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "Regex pattern to search for"
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File or directory to search in (defaults to current directory)"
|
||||
},
|
||||
"glob": {
|
||||
"type": "string",
|
||||
"description": "Glob filter for file types (e.g., '*.rs')"
|
||||
},
|
||||
"case_insensitive": {
|
||||
"type": "boolean",
|
||||
"description": "Case insensitive search"
|
||||
}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn execute(&self, args: &Value) -> Result<String> {
|
||||
let pattern = args["pattern"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'pattern' argument"))?;
|
||||
let path = args["path"].as_str().unwrap_or(".");
|
||||
let glob_filter = args["glob"].as_str();
|
||||
let case_insensitive = args["case_insensitive"].as_bool().unwrap_or(false);
|
||||
|
||||
// Prefer ripgrep, fall back to grep
|
||||
let (cmd, use_rg) = if which::which("rg").is_ok() {
|
||||
("rg", true)
|
||||
} else {
|
||||
("grep", false)
|
||||
};
|
||||
|
||||
let mut command = Command::new(cmd);
|
||||
|
||||
if use_rg {
|
||||
command.arg("--no-heading").arg("-n");
|
||||
if case_insensitive {
|
||||
command.arg("-i");
|
||||
}
|
||||
if let Some(g) = glob_filter {
|
||||
command.arg("--glob").arg(g);
|
||||
}
|
||||
command.arg(pattern).arg(path);
|
||||
} else {
|
||||
command.arg("-rn");
|
||||
if case_insensitive {
|
||||
command.arg("-i");
|
||||
}
|
||||
command.arg(pattern).arg(path);
|
||||
}
|
||||
|
||||
let output = command.output()?;
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
|
||||
if stdout.is_empty() {
|
||||
Ok("No matches found".to_string())
|
||||
} else {
|
||||
// Limit output
|
||||
let lines: Vec<&str> = stdout.lines().take(250).collect();
|
||||
Ok(lines.join("\n"))
|
||||
}
|
||||
}
|
||||
}
|
||||
61
src/tools/mod.rs
Normal file
61
src/tools/mod.rs
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
mod bash;
|
||||
mod edit;
|
||||
mod glob;
|
||||
mod grep;
|
||||
mod read;
|
||||
mod write;
|
||||
|
||||
use anyhow::Result;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Schema describing a tool for the LLM.
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: Value,
|
||||
}
|
||||
|
||||
/// Trait for executable tools.
|
||||
pub trait Tool: Send + Sync {
|
||||
fn definition(&self) -> ToolDefinition;
|
||||
fn execute(&self, args: &Value) -> Result<String>;
|
||||
}
|
||||
|
||||
/// Registry of all available tools.
|
||||
pub struct ToolRegistry {
|
||||
tools: HashMap<String, Box<dyn Tool>>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
pub fn new() -> Self {
|
||||
let mut registry = Self {
|
||||
tools: HashMap::new(),
|
||||
};
|
||||
registry.register(Box::new(bash::BashTool));
|
||||
registry.register(Box::new(read::ReadTool));
|
||||
registry.register(Box::new(write::WriteTool));
|
||||
registry.register(Box::new(edit::EditTool));
|
||||
registry.register(Box::new(glob::GlobTool));
|
||||
registry.register(Box::new(grep::GrepTool));
|
||||
registry
|
||||
}
|
||||
|
||||
fn register(&mut self, tool: Box<dyn Tool>) {
|
||||
let name = tool.definition().name.clone();
|
||||
self.tools.insert(name, tool);
|
||||
}
|
||||
|
||||
pub fn definitions(&self) -> Vec<ToolDefinition> {
|
||||
self.tools.values().map(|t| t.definition()).collect()
|
||||
}
|
||||
|
||||
pub fn execute(&self, name: &str, args: &Value) -> Result<String> {
|
||||
match self.tools.get(name) {
|
||||
Some(tool) => tool.execute(args),
|
||||
None => anyhow::bail!("Unknown tool: {name}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
54
src/tools/read.rs
Normal file
54
src/tools/read.rs
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
use super::{Tool, ToolDefinition};
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::fs;
|
||||
|
||||
pub struct ReadTool;
|
||||
|
||||
impl Tool for ReadTool {
|
||||
fn definition(&self) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "read".to_string(),
|
||||
description: "Read a file and return its contents with line numbers.".to_string(),
|
||||
parameters: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file to read"
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-based)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read"
|
||||
}
|
||||
},
|
||||
"required": ["file_path"]
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn execute(&self, args: &Value) -> Result<String> {
|
||||
let path = args["file_path"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'file_path' argument"))?;
|
||||
|
||||
let content = fs::read_to_string(path)?;
|
||||
let lines: Vec<&str> = content.lines().collect();
|
||||
|
||||
let offset = args["offset"].as_u64().unwrap_or(1).max(1) as usize - 1;
|
||||
let limit = args["limit"].as_u64().unwrap_or(2000) as usize;
|
||||
|
||||
let end = (offset + limit).min(lines.len());
|
||||
let mut result = String::new();
|
||||
|
||||
for (i, line) in lines[offset..end].iter().enumerate() {
|
||||
result.push_str(&format!("{}\t{}\n", offset + i + 1, line));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
47
src/tools/write.rs
Normal file
47
src/tools/write.rs
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
use super::{Tool, ToolDefinition};
|
||||
use anyhow::Result;
|
||||
use serde_json::Value;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
pub struct WriteTool;
|
||||
|
||||
impl Tool for WriteTool {
|
||||
fn definition(&self) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "write".to_string(),
|
||||
description: "Write content to a file, creating it if it doesn't exist.".to_string(),
|
||||
parameters: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Absolute path to the file to write"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The content to write to the file"
|
||||
}
|
||||
},
|
||||
"required": ["file_path", "content"]
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn execute(&self, args: &Value) -> Result<String> {
|
||||
let path = args["file_path"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'file_path' argument"))?;
|
||||
let content = args["content"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'content' argument"))?;
|
||||
|
||||
// Create parent directories if needed
|
||||
if let Some(parent) = Path::new(path).parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
fs::write(path, content)?;
|
||||
Ok(format!("Written {} bytes to {path}", content.len()))
|
||||
}
|
||||
}
|
||||
119
src/tui/mod.rs
Normal file
119
src/tui/mod.rs
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
use anyhow::Result;
|
||||
use std::io::{self, Write};
|
||||
|
||||
use crate::agent::Agent;
|
||||
use crate::config::Config;
|
||||
use crate::hooks::{HookEvent, HookManager};
|
||||
use crate::permissions::PermissionHandler;
|
||||
use crate::provider::{ChatMessage, OpenAIProvider};
|
||||
use crate::session::SessionManager;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
/// Run the interactive TUI loop.
|
||||
pub async fn run(
|
||||
provider: OpenAIProvider,
|
||||
tools: ToolRegistry,
|
||||
permissions: PermissionHandler,
|
||||
config: Config,
|
||||
session_mgr: SessionManager,
|
||||
session_id: String,
|
||||
prior_messages: Vec<ChatMessage>,
|
||||
hook_mgr: HookManager,
|
||||
) -> Result<()> {
|
||||
println!("\x1b[1mslug\x1b[0m v{}", env!("CARGO_PKG_VERSION"));
|
||||
println!("Model: {} @ {}", config.model, config.endpoint);
|
||||
println!("Mode: {:?}", config.permission_mode);
|
||||
println!("Session: {}", &session_id[..session_id.len().min(8)]);
|
||||
println!("Type your message. Press Ctrl+C to exit.\n");
|
||||
|
||||
let mut agent = Agent::new_with_history(
|
||||
Box::new(provider),
|
||||
tools,
|
||||
permissions,
|
||||
&config,
|
||||
prior_messages,
|
||||
);
|
||||
|
||||
loop {
|
||||
print!("\x1b[1;32m>\x1b[0m ");
|
||||
io::stdout().flush()?;
|
||||
|
||||
let mut input = String::new();
|
||||
io::stdin().read_line(&mut input)?;
|
||||
let input = input.trim();
|
||||
|
||||
if input.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match input {
|
||||
"/quit" | "/exit" => {
|
||||
hook_mgr.fire(&HookEvent::SessionEnd);
|
||||
// Save all messages on exit
|
||||
for msg in agent.messages() {
|
||||
let _ = session_mgr.save_message(&session_id, msg);
|
||||
}
|
||||
break;
|
||||
}
|
||||
"/help" => {
|
||||
println!("Commands:");
|
||||
println!(" /quit - Exit");
|
||||
println!(" /help - Show this help");
|
||||
println!(" /clear - Clear conversation history");
|
||||
println!(" /compact - Compress conversation context");
|
||||
continue;
|
||||
}
|
||||
"/clear" => {
|
||||
agent = Agent::new(
|
||||
Box::new(OpenAIProvider::new(&config)),
|
||||
ToolRegistry::new(),
|
||||
PermissionHandler::new(&config.permission_mode),
|
||||
&config,
|
||||
);
|
||||
println!("Conversation cleared.\n");
|
||||
continue;
|
||||
}
|
||||
"/compact" => {
|
||||
println!("Compacting conversation...");
|
||||
agent.compact();
|
||||
println!("Done. Context compressed.\n");
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Fire UserPromptSubmit hook
|
||||
let hook_result = hook_mgr.fire(&HookEvent::UserPromptSubmit {
|
||||
prompt: input.to_string(),
|
||||
});
|
||||
if hook_result.blocked {
|
||||
if let Some(reason) = &hook_result.block_reason {
|
||||
eprintln!("\x1b[31mBlocked by hook: {reason}\x1b[0m\n");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// If hook injected additional context, prepend it
|
||||
let effective_input = if let Some(ref ctx) = hook_result.additional_context {
|
||||
format!("{input}\n\n[Additional context from hooks:\n{ctx}\n]")
|
||||
} else {
|
||||
input.to_string()
|
||||
};
|
||||
|
||||
println!();
|
||||
|
||||
let result = agent
|
||||
.stream_turn(&effective_input, |chunk| {
|
||||
print!("{chunk}");
|
||||
let _ = io::stdout().flush();
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(()) => println!("\n"),
|
||||
Err(e) => eprintln!("\n\x1b[31mError: {e}\x1b[0m\n"),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue