diff --git a/cli/test/test_config_generator.py b/cli/test/test_config_generator.py index b3e3ab628..30890087f 100644 --- a/cli/test/test_config_generator.py +++ b/cli/test/test_config_generator.py @@ -11,15 +11,11 @@ def cleanup_env(monkeypatch): monkeypatch.undo() -def test_validate_and_render_happy_path(monkeypatch): - monkeypatch.setenv("PLANO_CONFIG_FILE", "fake_plano_config.yaml") - monkeypatch.setenv("PLANO_CONFIG_SCHEMA_FILE", "fake_plano_config_schema.yaml") - monkeypatch.setenv("ENVOY_CONFIG_TEMPLATE_FILE", "./envoy.template.yaml") - monkeypatch.setenv("PLANO_CONFIG_FILE_RENDERED", "fake_plano_config_rendered.yaml") - monkeypatch.setenv("ENVOY_CONFIG_FILE_RENDERED", "fake_envoy.yaml") - monkeypatch.setenv("TEMPLATE_ROOT", "../") - - plano_config = """ +@pytest.mark.parametrize( + "plano_config", + [ + # Case 1: LLM provider config + """ version: v0.1.0 listeners: @@ -49,40 +45,9 @@ def test_validate_and_render_happy_path(monkeypatch): tracing: random_sampling: 100 -""" - plano_config_schema = "" - with open("../config/plano_config_schema.yaml", "r") as file: - plano_config_schema = file.read() - - m_open = mock.mock_open() - # Provide enough file handles for all open() calls in validate_and_render_schema - m_open.side_effect = [ - # Removed empty read - was causing validation failures - mock.mock_open(read_data=plano_config).return_value, # PLANO_CONFIG_FILE - mock.mock_open( - read_data=plano_config_schema - ).return_value, # PLANO_CONFIG_SCHEMA_FILE - mock.mock_open(read_data=plano_config).return_value, # PLANO_CONFIG_FILE - mock.mock_open( - read_data=plano_config_schema - ).return_value, # PLANO_CONFIG_SCHEMA_FILE - mock.mock_open().return_value, # ENVOY_CONFIG_FILE_RENDERED (write) - mock.mock_open().return_value, # PLANO_CONFIG_FILE_RENDERED (write) - ] - with mock.patch("builtins.open", m_open): - with mock.patch("planoai.config_generator.Environment"): - validate_and_render_schema() - - -def 
test_validate_and_render_happy_path_agent_config(monkeypatch): - monkeypatch.setenv("PLANO_CONFIG_FILE", "fake_plano_config.yaml") - monkeypatch.setenv("PLANO_CONFIG_SCHEMA_FILE", "fake_plano_config_schema.yaml") - monkeypatch.setenv("ENVOY_CONFIG_TEMPLATE_FILE", "./envoy.template.yaml") - monkeypatch.setenv("PLANO_CONFIG_FILE_RENDERED", "fake_plano_config_rendered.yaml") - monkeypatch.setenv("ENVOY_CONFIG_FILE_RENDERED", "fake_envoy.yaml") - monkeypatch.setenv("TEMPLATE_ROOT", "../") - - plano_config = """ +""", + # Case 2: Agent config + """ version: v0.3.0 agents: @@ -122,25 +87,30 @@ def test_validate_and_render_happy_path_agent_config(monkeypatch): model_providers: - access_key: ${OPENAI_API_KEY} model: openai/gpt-4o -""" +""", + ], + ids=["llm_provider_config", "agent_config"], +) +def test_validate_and_render_happy_path(monkeypatch, plano_config): + monkeypatch.setenv("PLANO_CONFIG_FILE", "fake_plano_config.yaml") + monkeypatch.setenv("PLANO_CONFIG_SCHEMA_FILE", "fake_plano_config_schema.yaml") + monkeypatch.setenv("ENVOY_CONFIG_TEMPLATE_FILE", "./envoy.template.yaml") + monkeypatch.setenv("PLANO_CONFIG_FILE_RENDERED", "fake_plano_config_rendered.yaml") + monkeypatch.setenv("ENVOY_CONFIG_FILE_RENDERED", "fake_envoy.yaml") + monkeypatch.setenv("TEMPLATE_ROOT", "../") + plano_config_schema = "" with open("../config/plano_config_schema.yaml", "r") as file: plano_config_schema = file.read() m_open = mock.mock_open() - # Provide enough file handles for all open() calls in validate_and_render_schema m_open.side_effect = [ - # Removed empty read - was causing validation failures - mock.mock_open(read_data=plano_config).return_value, # PLANO_CONFIG_FILE - mock.mock_open( - read_data=plano_config_schema - ).return_value, # PLANO_CONFIG_SCHEMA_FILE - mock.mock_open(read_data=plano_config).return_value, # PLANO_CONFIG_FILE - mock.mock_open( - read_data=plano_config_schema - ).return_value, # PLANO_CONFIG_SCHEMA_FILE - mock.mock_open().return_value, # 
ENVOY_CONFIG_FILE_RENDERED (write) - mock.mock_open().return_value, # PLANO_CONFIG_FILE_RENDERED (write) + mock.mock_open(read_data=plano_config).return_value, + mock.mock_open(read_data=plano_config_schema).return_value, + mock.mock_open(read_data=plano_config).return_value, + mock.mock_open(read_data=plano_config_schema).return_value, + mock.mock_open().return_value, + mock.mock_open().return_value, ] with mock.patch("builtins.open", m_open): with mock.patch("planoai.config_generator.Environment"): @@ -344,124 +314,103 @@ def test_validate_and_render_schema_tests(monkeypatch, plano_config_test_case): validate_and_render_schema() -def test_convert_legacy_llm_providers(): - from planoai.utils import convert_legacy_listeners - - listeners = { - "ingress_traffic": { - "address": "0.0.0.0", - "port": 10000, - "timeout": "30s", - }, - "egress_traffic": { - "address": "0.0.0.0", - "port": 12000, - "timeout": "30s", - }, - } - llm_providers = [ - { - "model": "openai/gpt-4o", - "access_key": "test_key", - } - ] - - updated_providers, llm_gateway, prompt_gateway = convert_legacy_listeners( - listeners, llm_providers - ) - assert isinstance(updated_providers, list) - assert llm_gateway is not None - assert prompt_gateway is not None - print(json.dumps(updated_providers)) - assert updated_providers == [ - { - "name": "egress_traffic", - "type": "model_listener", - "port": 12000, - "address": "0.0.0.0", - "timeout": "30s", - "model_providers": [{"model": "openai/gpt-4o", "access_key": "test_key"}], - }, - { - "name": "ingress_traffic", - "type": "prompt_listener", - "port": 10000, - "address": "0.0.0.0", - "timeout": "30s", - }, - ] - - assert llm_gateway == { - "address": "0.0.0.0", - "model_providers": [ +@pytest.mark.parametrize( + "listeners,expected_providers,expected_llm_gateway,expected_prompt_gateway", + [ + # Case 1: With prompt gateway (ingress + egress) + ( + { + "ingress_traffic": { + "address": "0.0.0.0", + "port": 10000, + "timeout": "30s", + }, + 
"egress_traffic": { + "address": "0.0.0.0", + "port": 12000, + "timeout": "30s", + }, + }, + [ + { + "name": "egress_traffic", + "type": "model_listener", + "port": 12000, + "address": "0.0.0.0", + "timeout": "30s", + "model_providers": [ + {"model": "openai/gpt-4o", "access_key": "test_key"} + ], + }, + { + "name": "ingress_traffic", + "type": "prompt_listener", + "port": 10000, + "address": "0.0.0.0", + "timeout": "30s", + }, + ], + { + "address": "0.0.0.0", + "model_providers": [ + {"access_key": "test_key", "model": "openai/gpt-4o"} + ], + "name": "egress_traffic", + "type": "model_listener", + "port": 12000, + "timeout": "30s", + }, + { + "address": "0.0.0.0", + "name": "ingress_traffic", + "port": 10000, + "timeout": "30s", + "type": "prompt_listener", + }, + ), + # Case 2: Without prompt gateway (egress only) + ( + {"egress_traffic": {"address": "0.0.0.0", "port": 12000, "timeout": "30s"}}, + [ + { + "address": "0.0.0.0", + "model_providers": [ + {"access_key": "test_key", "model": "openai/gpt-4o"} + ], + "name": "egress_traffic", + "port": 12000, + "timeout": "30s", + "type": "model_listener", + } + ], { - "access_key": "test_key", - "model": "openai/gpt-4o", + "address": "0.0.0.0", + "model_providers": [ + {"access_key": "test_key", "model": "openai/gpt-4o"} + ], + "name": "egress_traffic", + "type": "model_listener", + "port": 12000, + "timeout": "30s", }, - ], - "name": "egress_traffic", - "type": "model_listener", - "port": 12000, - "timeout": "30s", - } - - assert prompt_gateway == { - "address": "0.0.0.0", - "name": "ingress_traffic", - "port": 10000, - "timeout": "30s", - "type": "prompt_listener", - } - - -def test_convert_legacy_llm_providers_no_prompt_gateway(): + None, + ), + ], + ids=["with_prompt_gateway", "without_prompt_gateway"], +) +def test_convert_legacy_llm_providers( + listeners, expected_providers, expected_llm_gateway, expected_prompt_gateway +): from planoai.utils import convert_legacy_listeners - listeners = { - "egress_traffic": { 
- "address": "0.0.0.0", - "port": 12000, - "timeout": "30s", - } - } - llm_providers = [ - { - "model": "openai/gpt-4o", - "access_key": "test_key", - } - ] - + llm_providers = [{"model": "openai/gpt-4o", "access_key": "test_key"}] updated_providers, llm_gateway, prompt_gateway = convert_legacy_listeners( listeners, llm_providers ) assert isinstance(updated_providers, list) assert llm_gateway is not None assert prompt_gateway is not None - assert updated_providers == [ - { - "address": "0.0.0.0", - "model_providers": [ - { - "access_key": "test_key", - "model": "openai/gpt-4o", - }, - ], - "name": "egress_traffic", - "port": 12000, - "timeout": "30s", - "type": "model_listener", - } - ] - assert llm_gateway == { - "address": "0.0.0.0", - "model_providers": [ - { - "access_key": "test_key", - "model": "openai/gpt-4o", - }, - ], - "name": "egress_traffic", - "type": "model_listener", - "port": 12000, - "timeout": "30s", - } + assert updated_providers == expected_providers + assert llm_gateway == expected_llm_gateway + if expected_prompt_gateway is not None: + assert prompt_gateway == expected_prompt_gateway diff --git a/cli/test/test_version_check.py b/cli/test/test_version_check.py index a00fba46e..5f3fa44fa 100644 --- a/cli/test/test_version_check.py +++ b/cli/test/test_version_check.py @@ -52,16 +52,17 @@ def test_current_is_newer(self): assert status["is_outdated"] is False assert status["message"] is None - def test_major_version_outdated(self): - status = check_version_status("0.4.1", "1.0.0") - assert status["is_outdated"] is True - - def test_minor_version_outdated(self): - status = check_version_status("0.4.1", "0.5.0") - assert status["is_outdated"] is True - - def test_patch_version_outdated(self): - status = check_version_status("0.4.1", "0.4.2") + @pytest.mark.parametrize( + "current,latest", + [ + ("0.4.1", "1.0.0"), # major + ("0.4.1", "0.5.0"), # minor + ("0.4.1", "0.4.2"), # patch + ], + ids=["major", "minor", "patch"], + ) + def 
test_version_outdated(self, current, latest): + status = check_version_status(current, latest) assert status["is_outdated"] is True def test_latest_is_none(self): @@ -153,11 +154,3 @@ def test_up_to_date_version(self): status = check_version_status(current_version, latest) assert status["is_outdated"] is False - - def test_skip_version_check_env_var(self, monkeypatch): - """Test that PLANO_SKIP_VERSION_CHECK skips the check.""" - monkeypatch.setenv("PLANO_SKIP_VERSION_CHECK", "1") - - import os - - assert os.environ.get("PLANO_SKIP_VERSION_CHECK") == "1" diff --git a/crates/brightstaff/src/handlers/function_calling.rs b/crates/brightstaff/src/handlers/function_calling.rs index 0d6c6d3c8..a2f4fd691 100644 --- a/crates/brightstaff/src/handlers/function_calling.rs +++ b/crates/brightstaff/src/handlers/function_calling.rs @@ -1481,12 +1481,6 @@ mod tests { assert!(config.format_prompt.contains(r#"{\"tool_calls\": [{"#)); } - #[test] - fn test_arch_agent_config_default() { - let config = ArchAgentConfig::default(); - assert_eq!(config.generation_params.temperature, 0.01); // Different from ArchFunctionConfig - } - #[test] fn test_fix_json_string_valid() { let handler = ArchFunctionHandler::new( diff --git a/crates/brightstaff/src/router/orchestrator_model_v1.rs b/crates/brightstaff/src/router/orchestrator_model_v1.rs index c6d3d56d1..9835de3f2 100644 --- a/crates/brightstaff/src/router/orchestrator_model_v1.rs +++ b/crates/brightstaff/src/router/orchestrator_model_v1.rs @@ -415,6 +415,17 @@ mod tests { use super::*; use pretty_assertions::assert_eq; + fn default_orchestrations() -> HashMap> { + serde_json::from_str( + r#"{"gpt-4o": [{"name": "Image generation", "description": "generating image"}]}"#, + ) + .unwrap() + } + + fn default_conversation() -> Vec { + serde_json::from_str(r#"[{"role": "user", "content": "hi"},{"role": "assistant", "content": "Hello! 
How can I assist you today?"},{"role": "user", "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]"#).unwrap() + } + #[test] fn test_spaced_json_formatter() { // Test basic object @@ -509,41 +520,12 @@ Return your answer strictly in JSON as follows: {{"route": ["route_name_1", "route_name_2", "..."]}} If no routes are needed, return an empty list for `route`. "#; - let orchestrations_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let agent_orchestrations = serde_json::from_str::< - HashMap>, - >(orchestrations_str) - .unwrap(); - let orchestration_model = "test-model".to_string(); let orchestrator = OrchestratorModelV1::new( - agent_orchestrations, - orchestration_model.clone(), + default_orchestrations(), + "test-model".to_string(), usize::MAX, ); - - let conversation_str = r#" - [ - { - "role": "user", - "content": "hi" - }, - { - "role": "assistant", - "content": "Hello! How can I assist you today?" - }, - { - "role": "user", - "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson" - } - ] - "#; - let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); + let conversation = default_conversation(); let req = orchestrator.generate_request(&conversation, &None); @@ -591,31 +573,9 @@ Return your answer strictly in JSON as follows: If no routes are needed, return an empty list for `route`. "#; // Empty orchestrations map - not used when usage_preferences are provided - let agent_orchestrations: HashMap> = HashMap::new(); - let orchestration_model = "test-model".to_string(); - let orchestrator = OrchestratorModelV1::new( - agent_orchestrations, - orchestration_model.clone(), - usize::MAX, - ); - - let conversation_str = r#" - [ - { - "role": "user", - "content": "hi" - }, - { - "role": "assistant", - "content": "Hello! How can I assist you today?" 
- }, - { - "role": "user", - "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson" - } - ] - "#; - let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); + let orchestrator = + OrchestratorModelV1::new(HashMap::new(), "test-model".to_string(), usize::MAX); + let conversation = default_conversation(); let usage_preferences = Some(vec![AgentUsagePreference { model: "claude/claude-3-7-sonnet".to_string(), @@ -662,38 +622,9 @@ Return your answer strictly in JSON as follows: If no routes are needed, return an empty list for `route`. "#; - let orchestrations_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let agent_orchestrations = serde_json::from_str::< - HashMap>, - >(orchestrations_str) - .unwrap(); - let orchestration_model = "test-model".to_string(); - let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model, 235); - - let conversation_str = r#" - [ - { - "role": "user", - "content": "hi" - }, - { - "role": "assistant", - "content": "Hello! How can I assist you today?" - }, - { - "role": "user", - "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson" - } - ] - "#; - - let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); + let orchestrator = + OrchestratorModelV1::new(default_orchestrations(), "test-model".to_string(), 235); + let conversation = default_conversation(); let req = orchestrator.generate_request(&conversation, &None); @@ -733,20 +664,8 @@ Return your answer strictly in JSON as follows: If no routes are needed, return an empty list for `route`. 
"#; - let orchestrations_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let agent_orchestrations = serde_json::from_str::< - HashMap>, - >(orchestrations_str) - .unwrap(); - - let orchestration_model = "test-model".to_string(); - let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model, 200); + let orchestrator = + OrchestratorModelV1::new(default_orchestrations(), "test-model".to_string(), 200); let conversation_str = r#" [ @@ -813,19 +732,8 @@ Return your answer strictly in JSON as follows: If no routes are needed, return an empty list for `route`. "#; - let orchestrations_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let agent_orchestrations = serde_json::from_str::< - HashMap>, - >(orchestrations_str) - .unwrap(); - let orchestration_model = "test-model".to_string(); - let orchestrator = OrchestratorModelV1::new(agent_orchestrations, orchestration_model, 230); + let orchestrator = + OrchestratorModelV1::new(default_orchestrations(), "test-model".to_string(), 230); let conversation_str = r#" [ @@ -899,21 +807,9 @@ Return your answer strictly in JSON as follows: {{"route": ["route_name_1", "route_name_2", "..."]}} If no routes are needed, return an empty list for `route`. "#; - let orchestrations_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let agent_orchestrations = serde_json::from_str::< - HashMap>, - >(orchestrations_str) - .unwrap(); - let orchestration_model = "test-model".to_string(); let orchestrator = OrchestratorModelV1::new( - agent_orchestrations, - orchestration_model.clone(), + default_orchestrations(), + "test-model".to_string(), usize::MAX, ); @@ -991,21 +887,9 @@ Return your answer strictly in JSON as follows: {{"route": ["route_name_1", "route_name_2", "..."]}} If no routes are needed, return an empty list for `route`. 
"#; - let orchestrations_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let agent_orchestrations = serde_json::from_str::< - HashMap>, - >(orchestrations_str) - .unwrap(); - let orchestration_model = "test-model".to_string(); let orchestrator = OrchestratorModelV1::new( - agent_orchestrations, - orchestration_model.clone(), + default_orchestrations(), + "test-model".to_string(), usize::MAX, ); diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index 430b4f8e3..07173668b 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -299,6 +299,17 @@ mod tests { use super::*; use pretty_assertions::assert_eq; + fn default_routes() -> HashMap> { + serde_json::from_str( + r#"{"gpt-4o": [{"name": "Image generation", "description": "generating image"}]}"#, + ) + .unwrap() + } + + fn default_conversation() -> Vec { + serde_json::from_str(r#"[{"role": "user", "content": "hi"},{"role": "assistant", "content": "Hello! 
How can I assist you today?"},{"role": "user", "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]"#).unwrap() + } + #[test] fn test_system_prompt_format() { let expected_prompt = r#" @@ -320,35 +331,8 @@ Your task is to decide which route is best suit with user intent on the conversa Based on your analysis, provide your response in the following JSON formats if you decide to match any route: {"route": "route_name"} "#; - let routes_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let llm_routes = - serde_json::from_str::>>(routes_str).unwrap(); - let routing_model = "test-model".to_string(); - let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX); - - let conversation_str = r#" - [ - { - "role": "user", - "content": "hi" - }, - { - "role": "assistant", - "content": "Hello! How can I assist you today?" - }, - { - "role": "user", - "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson" - } - ] - "#; - let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); + let router = RouterModelV1::new(default_routes(), "test-model".to_string(), usize::MAX); + let conversation = default_conversation(); let req = router.generate_request(&conversation, &None); @@ -378,35 +362,8 @@ Your task is to decide which route is best suit with user intent on the conversa Based on your analysis, provide your response in the following JSON formats if you decide to match any route: {"route": "route_name"} "#; - let routes_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let llm_routes = - serde_json::from_str::>>(routes_str).unwrap(); - let routing_model = "test-model".to_string(); - let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX); - - let conversation_str = r#" - [ - { - "role": "user", - "content": "hi" - }, - { - "role": "assistant", - "content": 
"Hello! How can I assist you today?" - }, - { - "role": "user", - "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson" - } - ] - "#; - let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); + let router = RouterModelV1::new(default_routes(), "test-model".to_string(), usize::MAX); + let conversation = default_conversation(); let usage_preferences = Some(vec![ModelUsagePreference { model: "claude/claude-3-7-sonnet".to_string(), @@ -444,36 +401,8 @@ Based on your analysis, provide your response in the following JSON formats if y {"route": "route_name"} "#; - let routes_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let llm_routes = - serde_json::from_str::>>(routes_str).unwrap(); - let routing_model = "test-model".to_string(); - let router = RouterModelV1::new(llm_routes, routing_model, 235); - - let conversation_str = r#" - [ - { - "role": "user", - "content": "hi" - }, - { - "role": "assistant", - "content": "Hello! How can I assist you today?" 
- }, - { - "role": "user", - "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson" - } - ] - "#; - - let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); + let router = RouterModelV1::new(default_routes(), "test-model".to_string(), 235); + let conversation = default_conversation(); let req = router.generate_request(&conversation, &None); @@ -504,18 +433,7 @@ Based on your analysis, provide your response in the following JSON formats if y {"route": "route_name"} "#; - let routes_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let llm_routes = - serde_json::from_str::>>(routes_str).unwrap(); - - let routing_model = "test-model".to_string(); - let router = RouterModelV1::new(llm_routes, routing_model, 200); + let router = RouterModelV1::new(default_routes(), "test-model".to_string(), 200); let conversation_str = r#" [ @@ -565,17 +483,7 @@ Based on your analysis, provide your response in the following JSON formats if y {"route": "route_name"} "#; - let routes_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let llm_routes = - serde_json::from_str::>>(routes_str).unwrap(); - let routing_model = "test-model".to_string(); - let router = RouterModelV1::new(llm_routes, routing_model, 230); + let router = RouterModelV1::new(default_routes(), "test-model".to_string(), 230); let conversation_str = r#" [ @@ -632,17 +540,7 @@ Your task is to decide which route is best suit with user intent on the conversa Based on your analysis, provide your response in the following JSON formats if you decide to match any route: {"route": "route_name"} "#; - let routes_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let llm_routes = - serde_json::from_str::>>(routes_str).unwrap(); - let routing_model = "test-model".to_string(); - let router = 
RouterModelV1::new(llm_routes, routing_model, usize::MAX); + let router = RouterModelV1::new(default_routes(), "test-model".to_string(), usize::MAX); let conversation_str = r#" [ @@ -701,17 +599,7 @@ Your task is to decide which route is best suit with user intent on the conversa Based on your analysis, provide your response in the following JSON formats if you decide to match any route: {"route": "route_name"} "#; - let routes_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let llm_routes = - serde_json::from_str::>>(routes_str).unwrap(); - let routing_model = "test-model".to_string(); - let router = RouterModelV1::new(llm_routes, routing_model, usize::MAX); + let router = RouterModelV1::new(default_routes(), "test-model".to_string(), usize::MAX); let conversation_str = r#" [ @@ -777,17 +665,7 @@ Based on your analysis, provide your response in the following JSON formats if y #[test] fn test_parse_response() { - let routes_str = r#" - { - "gpt-4o": [ - {"name": "Image generation", "description": "generating image"} - ] - } - "#; - let llm_routes = - serde_json::from_str::>>(routes_str).unwrap(); - - let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000); + let router = RouterModelV1::new(default_routes(), "test-model".to_string(), 2000); // Case 1: Valid JSON with non-empty route let input = r#"{"route": "Image generation"}"#; diff --git a/crates/brightstaff/src/signals/analyzer.rs b/crates/brightstaff/src/signals/analyzer.rs index 5ee3c7d9a..a70be4a30 100644 --- a/crates/brightstaff/src/signals/analyzer.rs +++ b/crates/brightstaff/src/signals/analyzer.rs @@ -1959,145 +1959,109 @@ mod tests { // ======================================================================== #[test] - fn test_char_ngram_similarity_exact_match() { - let msg = NormalizedMessage::from_text("thank you very much"); - let similarity = msg.char_ngram_similarity("thank you very much"); - assert!( - similarity > 
0.95, - "Exact match should have very high similarity" - ); - } - - #[test] - fn test_char_ngram_similarity_typo() { - let msg = NormalizedMessage::from_text("thank you very much"); - // Common typo: "thnks" instead of "thanks" - let similarity = msg.char_ngram_similarity("thnks you very much"); - assert!( - similarity > 0.50, - "Should handle single-character typo with decent similarity: {}", - similarity - ); - } - - #[test] - fn test_char_ngram_similarity_small_edit() { - let msg = NormalizedMessage::from_text("this doesn't work"); - let similarity = msg.char_ngram_similarity("this doesnt work"); - assert!( - similarity > 0.70, - "Should handle punctuation removal gracefully: {}", - similarity - ); - } - - #[test] - fn test_char_ngram_similarity_word_insertion() { - let msg = NormalizedMessage::from_text("i don't understand"); - let similarity = msg.char_ngram_similarity("i really don't understand"); - assert!( - similarity > 0.40, - "Should be robust to word insertions: {}", - similarity - ); - } - - #[test] - fn test_token_cosine_similarity_exact_match() { - let msg = NormalizedMessage::from_text("this is not helpful"); - let similarity = msg.token_cosine_similarity("this is not helpful"); - assert!( - (similarity - 1.0).abs() < 0.01, - "Exact match should have cosine similarity of 1.0" - ); - } - - #[test] - fn test_token_cosine_similarity_word_order() { - let msg = NormalizedMessage::from_text("not helpful at all"); - let similarity = msg.token_cosine_similarity("helpful not at all"); - assert!( - similarity > 0.95, - "Should be robust to word order changes: {}", - similarity - ); - } - - #[test] - fn test_token_cosine_similarity_frequency() { - let msg = NormalizedMessage::from_text("help help help please"); - let similarity = msg.token_cosine_similarity("help please"); - assert!( - similarity > 0.7 && similarity < 1.0, - "Should account for frequency differences: {}", - similarity - ); - } - - #[test] - fn 
test_token_cosine_similarity_long_message_with_context() { - let msg = NormalizedMessage::from_text( - "I've been trying to set up my account for the past hour \ - and the verification email never arrived. I checked my spam folder \ - and still nothing. This is really frustrating and not helpful at all.", - ); - let similarity = msg.token_cosine_similarity("not helpful"); - assert!( - similarity > 0.15 && similarity < 0.7, - "Should detect pattern in long message with lower but non-zero similarity: {}", - similarity - ); - } - - #[test] - fn test_layered_matching_exact_hit() { - let msg = NormalizedMessage::from_text("thank you so much"); - assert!( - msg.layered_contains_phrase("thank you", 0.50, 0.60), - "Should match exact phrase in Layer 0" - ); - } - - #[test] - fn test_layered_matching_typo_hit() { - // Test that shows layered matching is more robust than exact matching alone - let msg = NormalizedMessage::from_text("it doesnt work for me"); - - // "doesnt work" should match "doesn't work" via character ngrams (high overlap) - assert!( - msg.layered_contains_phrase("doesn't work", 0.50, 0.60), - "Should match 'doesnt work' to 'doesn't work' via character ngrams" - ); - } - - #[test] - fn test_layered_matching_word_order_hit() { - let msg = NormalizedMessage::from_text("helpful not very"); - assert!( - msg.layered_contains_phrase("not helpful", 0.50, 0.60), - "Should match reordered words via token cosine in Layer 2" - ); + fn test_char_ngram_similarity() { + let cases = [ + ( + "thank you very much", + "thank you very much", + 0.95, + "exact match", + ), + ("thank you very much", "thnks you very much", 0.50, "typo"), + ("this doesn't work", "this doesnt work", 0.70, "small edit"), + ( + "i don't understand", + "i really don't understand", + 0.40, + "word insertion", + ), + ]; + for (msg_text, pattern, threshold, label) in cases { + let msg = NormalizedMessage::from_text(msg_text); + let similarity = msg.char_ngram_similarity(pattern); + assert!( + similarity 
> threshold, + "{}: expected > {}, got {}", + label, + threshold, + similarity + ); + } } #[test] - fn test_layered_matching_long_message_with_pattern() { - let msg = NormalizedMessage::from_text( - "I've tried everything and followed all the instructions \ - but this is not helpful at all and I'm getting frustrated", - ); - assert!( - msg.layered_contains_phrase("not helpful", 0.50, 0.60), - "Should detect pattern buried in long message" - ); + fn test_token_cosine_similarity() { + let cases: Vec<(&str, &str, f64, f64, &str)> = vec![ + ( + "this is not helpful", + "this is not helpful", + 0.99, + 1.01, + "exact match", + ), + ( + "not helpful at all", + "helpful not at all", + 0.95, + 2.0, + "word order", + ), + ( + "help help help please", + "help please", + 0.7, + 1.0, + "frequency", + ), + ( + "I've been trying to set up my account for the past hour \ + and the verification email never arrived. I checked my spam folder \ + and still nothing. This is really frustrating and not helpful at all.", + "not helpful", + 0.15, + 0.7, + "long message with context", + ), + ]; + for (msg_text, pattern, min, max, label) in cases { + let msg = NormalizedMessage::from_text(msg_text); + let similarity = msg.token_cosine_similarity(pattern); + assert!( + similarity > min && similarity < max, + "{}: expected ({}, {}), got {}", + label, + min, + max, + similarity + ); + } } #[test] - fn test_layered_matching_no_match() { - let msg = NormalizedMessage::from_text("everything is working perfectly"); - assert!( - !msg.layered_contains_phrase("not helpful", 0.50, 0.60), - "Should not match completely different content" - ); + fn test_layered_matching() { + let cases = [ + ("thank you so much", "thank you", true, "exact hit"), + ("it doesnt work for me", "doesn't work", true, "typo hit"), + ("helpful not very", "not helpful", true, "word order hit"), + ( + "I've tried everything and followed all the instructions \ + but this is not helpful at all and I'm getting frustrated", + "not 
helpful", + true, + "long message with pattern", + ), + ( + "everything is working perfectly", + "not helpful", + false, + "no match", + ), + ]; + for (msg_text, pattern, expected, label) in cases { + let msg = NormalizedMessage::from_text(msg_text); + let result = msg.layered_contains_phrase(pattern, 0.50, 0.60); + assert_eq!(result, expected, "{}: expected {}", label, expected); + } } #[test] @@ -2139,7 +2103,6 @@ mod tests { #[test] fn test_turn_count_efficient() { - let start = Instant::now(); let analyzer = TextBasedSignalAnalyzer::new(); let messages = vec![ create_message(Role::User, "Hello"), @@ -2154,12 +2117,10 @@ mod tests { assert!(!signal.is_concerning); assert!(!signal.is_excessive); assert!(signal.efficiency_score > 0.9); - println!("test_turn_count_efficient took: {:?}", start.elapsed()); } #[test] fn test_turn_count_excessive() { - let start = Instant::now(); let analyzer = TextBasedSignalAnalyzer::new(); let mut messages = Vec::new(); for i in 0..15 { @@ -2178,12 +2139,10 @@ mod tests { assert!(signal.is_concerning); assert!(signal.is_excessive); assert!(signal.efficiency_score < 0.5); - println!("test_turn_count_excessive took: {:?}", start.elapsed()); } #[test] fn test_follow_up_detection() { - let start = Instant::now(); let analyzer = TextBasedSignalAnalyzer::new(); let messages = vec![ create_message(Role::User, "Show me restaurants"), @@ -2196,12 +2155,10 @@ mod tests { let signal = analyzer.analyze_follow_up(&normalized_messages); assert_eq!(signal.repair_count, 1); assert!(signal.repair_ratio > 0.0); - println!("test_follow_up_detection took: {:?}", start.elapsed()); } #[test] fn test_frustration_detection() { - let start = Instant::now(); let analyzer = TextBasedSignalAnalyzer::new(); let messages = vec![ create_message(Role::User, "THIS IS RIDICULOUS!!!"), @@ -2214,12 +2171,10 @@ mod tests { assert!(signal.has_frustration); assert!(signal.frustration_count >= 2); assert!(signal.severity > 0); - println!("test_frustration_detection took: 
{:?}", start.elapsed()); } #[test] fn test_positive_feedback_detection() { - let start = Instant::now(); let analyzer = TextBasedSignalAnalyzer::new(); let messages = vec![ create_message(Role::User, "Can you help me?"), @@ -2232,15 +2187,10 @@ mod tests { assert!(signal.has_positive_feedback); assert!(signal.positive_count >= 1); assert!(signal.confidence > 0.5); - println!( - "test_positive_feedback_detection took: {:?}", - start.elapsed() - ); } #[test] fn test_escalation_detection() { - let start = Instant::now(); let analyzer = TextBasedSignalAnalyzer::new(); let messages = vec![ create_message(Role::User, "This isn't working"), @@ -2252,12 +2202,10 @@ mod tests { let signal = analyzer.analyze_escalation(&normalized_messages); assert!(signal.escalation_requested); assert_eq!(signal.escalation_count, 1); - println!("test_escalation_detection took: {:?}", start.elapsed()); } #[test] fn test_repetition_detection() { - let start = Instant::now(); let analyzer = TextBasedSignalAnalyzer::new(); let messages = vec![ create_message(Role::User, "What's the weather?"), @@ -2273,22 +2221,13 @@ mod tests { let normalized_messages = preprocess_messages(&messages); let signal = analyzer.analyze_repetition(&normalized_messages); - for rep in &signal.repetitions { - println!( - " - Messages {:?}, similarity: {:.3}, type: {:?}", - rep.message_indices, rep.similarity, rep.repetition_type - ); - } - assert!(signal.repetition_count > 0, "Should detect the subtle repetition between 'I can help you with the weather information' \ and 'Sure, I can help you with the forecast'"); - println!("test_repetition_detection took: {:?}", start.elapsed()); } #[test] fn test_full_analysis_excellent() { - let start = Instant::now(); let analyzer = TextBasedSignalAnalyzer::new(); let messages = vec![ create_message(Role::User, "I need to book a flight"), @@ -2305,12 +2244,10 @@ mod tests { )); assert!(report.positive_feedback.has_positive_feedback); assert!(!report.frustration.has_frustration); - 
println!("test_full_analysis_excellent took: {:?}", start.elapsed()); } #[test] fn test_full_analysis_poor() { - let start = Instant::now(); let analyzer = TextBasedSignalAnalyzer::new(); let messages = vec![ create_message(Role::User, "Help me"), @@ -2329,86 +2266,64 @@ mod tests { )); assert!(report.frustration.has_frustration); assert!(report.escalation.escalation_requested); - println!("test_full_analysis_poor took: {:?}", start.elapsed()); } #[test] - fn test_fuzzy_matching_gratitude() { - let start = Instant::now(); + fn test_fuzzy_matching() { let analyzer = TextBasedSignalAnalyzer::new(); + + // Gratitude with typo let messages = vec![ create_message(Role::User, "Can you help me?"), create_message(Role::Assistant, "Sure!"), create_message(Role::User, "thnaks! that's exactly what i needed."), ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_positive_feedback(&normalized_messages); - assert!(signal.has_positive_feedback); + let normalized = preprocess_messages(&messages); + let signal = analyzer.analyze_positive_feedback(&normalized); + assert!( + signal.has_positive_feedback, + "fuzzy gratitude should be detected" + ); assert!(signal.positive_count >= 1); - println!("test_fuzzy_matching_gratitude took: {:?}", start.elapsed()); - } - #[test] - fn test_fuzzy_matching_escalation() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); + // Escalation with typo let messages = vec![ create_message(Role::User, "This isn't working"), create_message(Role::Assistant, "Let me help"), create_message(Role::User, "i need to speek to a human agnet"), ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_escalation(&normalized_messages); - assert!(signal.escalation_requested); + let normalized = preprocess_messages(&messages); + let signal = analyzer.analyze_escalation(&normalized); + assert!( + signal.escalation_requested, + "fuzzy escalation should be 
detected" + ); assert_eq!(signal.escalation_count, 1); - println!("test_fuzzy_matching_escalation took: {:?}", start.elapsed()); - } - #[test] - fn test_fuzzy_matching_repair() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); + // Repair with typo let messages = vec![ create_message(Role::User, "Show me restaurants"), create_message(Role::Assistant, "Here are some options"), create_message(Role::User, "no i ment Italian restaurants"), create_message(Role::Assistant, "Here are Italian restaurants"), ]; + let normalized = preprocess_messages(&messages); + let signal = analyzer.analyze_follow_up(&normalized); + assert!(signal.repair_count >= 1, "fuzzy repair should be detected"); - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_follow_up(&normalized_messages); - assert!(signal.repair_count >= 1); - println!("test_fuzzy_matching_repair took: {:?}", start.elapsed()); - } - - #[test] - fn test_fuzzy_matching_complaint() { - let start = Instant::now(); - let analyzer = TextBasedSignalAnalyzer::new(); - // Use a complaint that should match - "doesnt work" is close enough to "doesn't work" + // Complaint with typo let messages = vec![ - create_message(Role::User, "this doesnt work at all"), // Common typo: missing apostrophe + create_message(Role::User, "this doesnt work at all"), create_message(Role::Assistant, "I apologize"), ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized_messages); - - // The layered matching should catch this via character ngrams or token cosine - // "doesnt work" has high character-level similarity to "doesn't work" - assert!( - signal.has_frustration, - "Should detect frustration from complaint pattern" - ); + let normalized = preprocess_messages(&messages); + let signal = analyzer.analyze_frustration(&normalized); + assert!(signal.has_frustration, "fuzzy complaint should be detected"); 
assert!(signal.frustration_count >= 1); - println!("test_fuzzy_matching_complaint took: {:?}", start.elapsed()); } #[test] fn test_exact_match_priority() { - let start = Instant::now(); let analyzer = TextBasedSignalAnalyzer::new(); let messages = vec![create_message(Role::User, "thank you so much")]; @@ -2418,7 +2333,6 @@ mod tests { // Should detect exact match, not fuzzy assert!(signal.indicators[0].snippet.contains("thank you")); assert!(!signal.indicators[0].snippet.contains("fuzzy")); - println!("test_exact_match_priority took: {:?}", start.elapsed()); } // ======================================================================== @@ -2426,31 +2340,54 @@ mod tests { // ======================================================================== #[test] - fn test_hello_not_profanity() { + fn test_false_positive_guards() { let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message(Role::User, "hello there")]; - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized_messages); + // "hello" should not trigger profanity + let messages = vec![create_message(Role::User, "hello there")]; + let normalized = preprocess_messages(&messages); + let signal = analyzer.analyze_frustration(&normalized); assert!( !signal.has_frustration, - "\"hello\" should not trigger profanity detection" + "\"hello\" should not trigger profanity" ); - } - #[test] - fn test_prepare_not_escalation() { - let analyzer = TextBasedSignalAnalyzer::new(); + // "prepare" should not trigger escalation let messages = vec![create_message( Role::User, "Can you help me prepare for the meeting?", )]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_escalation(&normalized_messages); + let normalized = preprocess_messages(&messages); + let signal = analyzer.analyze_escalation(&normalized); assert!( !signal.escalation_requested, - "\"prepare\" should not trigger escalation (rep pattern 
removed)" + "\"prepare\" should not trigger escalation" + ); + + // "absolute" should not trigger 'bs' match + let messages = vec![create_message(Role::User, "That's absolute nonsense")]; + let normalized = preprocess_messages(&messages); + let signal = analyzer.analyze_frustration(&normalized); + let has_bs_match = signal + .indicators + .iter() + .any(|ind| ind.snippet.contains("bs")); + assert!( + !has_bs_match, + "\"absolute\" should not trigger 'bs' profanity match" + ); + + // Stopwords-only overlap should not be rephrase + let messages = vec![ + create_message(Role::User, "Help me with X"), + create_message(Role::Assistant, "Sure"), + create_message(Role::User, "Help me with Y"), + ]; + let normalized = preprocess_messages(&messages); + let signal = analyzer.analyze_follow_up(&normalized); + assert_eq!( + signal.repair_count, 0, + "Messages with only stopword overlap should not be rephrases" ); } @@ -2485,42 +2422,6 @@ mod tests { ); } - #[test] - fn test_absolute_not_profanity() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message(Role::User, "That's absolute nonsense")]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized_messages); - // Should match on "nonsense" logic, not on "bs" substring - let has_bs_match = signal - .indicators - .iter() - .any(|ind| ind.snippet.contains("bs")); - assert!( - !has_bs_match, - "\"absolute\" should not trigger 'bs' profanity match" - ); - } - - #[test] - fn test_stopwords_not_rephrase() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "Help me with X"), - create_message(Role::Assistant, "Sure"), - create_message(Role::User, "Help me with Y"), - ]; - - let normalized_messages = preprocess_messages(&messages); - let signal = analyzer.analyze_follow_up(&normalized_messages); - // Should not detect as rephrase since only stopwords overlap - assert_eq!( - 
signal.repair_count, 0, - "Messages with only stopword overlap should not be rephrases" - ); - } - #[test] fn test_frustrated_user_with_legitimate_repair() { let start = Instant::now(); @@ -2794,23 +2695,44 @@ mod tests { // false negative tests #[test] - fn test_dissatisfaction_polite_not_working_for_me() { + fn test_dissatisfaction_and_low_mood() { let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![ - create_message(Role::User, "Thanks, but this still isn't working for me."), // Polite dissatisfaction, e.g., I appreciate it, but this isn't what I was looking for. - create_message(Role::Assistant, "Sorry—what error do you see?"), + + // Cases that should trigger frustration + let frustration_cases = [ + ( + "Thanks, but this still isn't working for me.", + "polite not working", + ), + ( + "I'm running into the same issue again.", + "same problem again", + ), + ("This feels incomplete.", "incomplete"), + ( + "This is overwhelming and I'm not sure what to do.", + "overwhelming", + ), + ( + "I'm exhausted trying to get this working.", + "exhausted trying", + ), ]; - let normalized = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized); - assert!( - signal.has_frustration, - "Polite dissatisfaction should be detected" - ); - } + for (msg, label) in frustration_cases { + let messages = vec![ + create_message(Role::User, msg), + create_message(Role::Assistant, "Sorry about that."), + ]; + let normalized = preprocess_messages(&messages); + let signal = analyzer.analyze_frustration(&normalized); + assert!( + signal.has_frustration, + "{}: should detect frustration", + label + ); + } - #[test] - fn test_dissatisfaction_giving_up_without_escalation() { - let analyzer = TextBasedSignalAnalyzer::new(); + // Case that should trigger escalation (giving up) let messages = vec![create_message( Role::User, "Never mind, I'll figure it out myself.", @@ -2819,61 +2741,7 @@ mod tests { let signal = 
analyzer.analyze_escalation(&normalized); assert!( signal.escalation_requested, - "Giving up should count as escalation/quit intent" - ); - } - - #[test] - fn test_dissatisfaction_same_problem_again() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message( - Role::User, - "I'm running into the same issue again.", - )]; - let normalized = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized); - assert!( - signal.has_frustration, - "'same issue again' should be detected" - ); - } - - #[test] - fn test_unsatisfied_incomplete() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message(Role::User, "This feels incomplete.")]; - let normalized = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized); - assert!( - signal.has_frustration, - "Should detect 'incomplete' dissatisfaction" - ); - } - - #[test] - fn test_low_mood_overwhelming() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message( - Role::User, - "This is overwhelming and I'm not sure what to do.", - )]; - let normalized = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized); - assert!(signal.has_frustration, "Should detect overwhelmed language"); - } - - #[test] - fn test_low_mood_exhausted_trying() { - let analyzer = TextBasedSignalAnalyzer::new(); - let messages = vec![create_message( - Role::User, - "I'm exhausted trying to get this working.", - )]; - let normalized = preprocess_messages(&messages); - let signal = analyzer.analyze_frustration(&normalized); - assert!( - signal.has_frustration, - "Should detect exhaustion/struggle language" + "giving up should count as escalation" ); } diff --git a/crates/brightstaff/src/state/memory.rs b/crates/brightstaff/src/state/memory.rs index be4d82324..92b6b318d 100644 --- a/crates/brightstaff/src/state/memory.rs +++ b/crates/brightstaff/src/state/memory.rs @@ -85,11 
+85,30 @@ impl StateStorage for MemoryConversationalStorage { #[cfg(test)] mod tests { use super::*; + use crate::state::generate_storage_tests; use hermesllm::apis::openai_responses::{ InputContent, InputItem, InputMessage, MessageContent, MessageRole, }; - fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState { + fn create_test_state(response_id: &str) -> OpenAIConversationState { + OpenAIConversationState { + response_id: response_id.to_string(), + input_items: vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: MessageContent::Items(vec![InputContent::InputText { + text: "Test message".to_string(), + }]), + })], + created_at: 1234567890, + model: "claude-3".to_string(), + provider: "anthropic".to_string(), + } + } + + fn create_test_state_with_messages( + response_id: &str, + num_messages: usize, + ) -> OpenAIConversationState { let mut input_items = Vec::new(); for i in 0..num_messages { input_items.push(InputItem::Message(InputMessage { @@ -113,209 +132,8 @@ mod tests { } } - #[tokio::test] - async fn test_put_and_get_success() { - let storage = MemoryConversationalStorage::new(); - let state: OpenAIConversationState = create_test_state("resp_001", 3); - - // Store - storage.put(state.clone()).await.unwrap(); - - // Retrieve - let retrieved = storage.get("resp_001").await.unwrap(); - assert_eq!(retrieved.response_id, state.response_id); - assert_eq!(retrieved.model, state.model); - assert_eq!(retrieved.provider, state.provider); - assert_eq!(retrieved.input_items.len(), 3); - assert_eq!(retrieved.created_at, state.created_at); - } - - #[tokio::test] - async fn test_put_overwrites_existing() { - let storage = MemoryConversationalStorage::new(); - - // First state - let state1 = create_test_state("resp_002", 2); - storage.put(state1).await.unwrap(); - - // Overwrite with new state - let state2 = OpenAIConversationState { - response_id: "resp_002".to_string(), - input_items: vec![], - created_at: 
9999999999, - model: "gpt-4".to_string(), - provider: "openai".to_string(), - }; - storage.put(state2.clone()).await.unwrap(); - - // Should retrieve the new state - let retrieved = storage.get("resp_002").await.unwrap(); - assert_eq!(retrieved.model, "gpt-4"); - assert_eq!(retrieved.provider, "openai"); - assert_eq!(retrieved.input_items.len(), 0); - assert_eq!(retrieved.created_at, 9999999999); - } - - #[tokio::test] - async fn test_get_not_found() { - let storage = MemoryConversationalStorage::new(); - - let result = storage.get("nonexistent").await; - assert!(result.is_err()); - - match result.unwrap_err() { - StateStorageError::NotFound(id) => { - assert_eq!(id, "nonexistent"); - } - _ => panic!("Expected NotFound error"), - } - } - - #[tokio::test] - async fn test_exists_returns_false_for_nonexistent() { - let storage = MemoryConversationalStorage::new(); - assert!(!storage.exists("resp_003").await.unwrap()); - } - - #[tokio::test] - async fn test_exists_returns_true_after_put() { - let storage = MemoryConversationalStorage::new(); - let state = create_test_state("resp_004", 1); - - assert!(!storage.exists("resp_004").await.unwrap()); - storage.put(state).await.unwrap(); - assert!(storage.exists("resp_004").await.unwrap()); - } - - #[tokio::test] - async fn test_delete_success() { - let storage = MemoryConversationalStorage::new(); - let state = create_test_state("resp_005", 2); - - storage.put(state).await.unwrap(); - assert!(storage.exists("resp_005").await.unwrap()); - - // Delete - storage.delete("resp_005").await.unwrap(); - - // Should no longer exist - assert!(!storage.exists("resp_005").await.unwrap()); - assert!(storage.get("resp_005").await.is_err()); - } - - #[tokio::test] - async fn test_delete_not_found() { - let storage = MemoryConversationalStorage::new(); - - let result = storage.delete("nonexistent").await; - assert!(result.is_err()); - - match result.unwrap_err() { - StateStorageError::NotFound(id) => { - assert_eq!(id, "nonexistent"); - } - 
_ => panic!("Expected NotFound error"), - } - } - - #[tokio::test] - async fn test_merge_combines_inputs() { - let storage = MemoryConversationalStorage::new(); - - // Create a previous state with 2 messages - let prev_state = create_test_state("resp_006", 2); - - // Create current input with 1 message - let current_input = vec![InputItem::Message(InputMessage { - role: MessageRole::User, - content: MessageContent::Items(vec![InputContent::InputText { - text: "New message".to_string(), - }]), - })]; - - // Merge - let merged = storage.merge(&prev_state, current_input); - - // Should have 3 messages total (2 from prev + 1 current) - assert_eq!(merged.len(), 3); - } - - #[tokio::test] - async fn test_merge_preserves_order() { - let storage = MemoryConversationalStorage::new(); - - // Previous state has messages 0 and 1 - let prev_state = create_test_state("resp_007", 2); - - // Current input has message 2 - let current_input = vec![InputItem::Message(InputMessage { - role: MessageRole::User, - content: MessageContent::Items(vec![InputContent::InputText { - text: "Message 2".to_string(), - }]), - })]; - - let merged = storage.merge(&prev_state, current_input); - - // Verify order: prev messages first, then current - let InputItem::Message(msg) = &merged[0] else { - panic!("Expected Message") - }; - match &msg.content { - MessageContent::Items(items) => match &items[0] { - InputContent::InputText { text } => assert_eq!(text, "Message 0"), - _ => panic!("Expected InputText"), - }, - _ => panic!("Expected MessageContent::Items"), - } - - let InputItem::Message(msg) = &merged[2] else { - panic!("Expected Message") - }; - match &msg.content { - MessageContent::Items(items) => match &items[0] { - InputContent::InputText { text } => assert_eq!(text, "Message 2"), - _ => panic!("Expected InputText"), - }, - _ => panic!("Expected MessageContent::Items"), - } - } - - #[tokio::test] - async fn test_merge_with_empty_current_input() { - let storage = 
MemoryConversationalStorage::new(); - let prev_state = create_test_state("resp_008", 3); - - let merged = storage.merge(&prev_state, vec![]); - - // Should just have the previous state's items - assert_eq!(merged.len(), 3); - } - - #[tokio::test] - async fn test_merge_with_empty_previous_state() { - let storage = MemoryConversationalStorage::new(); - - let prev_state = OpenAIConversationState { - response_id: "resp_009".to_string(), - input_items: vec![], - created_at: 1234567890, - model: "gpt-4".to_string(), - provider: "openai".to_string(), - }; - - let current_input = vec![InputItem::Message(InputMessage { - role: MessageRole::User, - content: MessageContent::Items(vec![InputContent::InputText { - text: "Only message".to_string(), - }]), - })]; - - let merged = storage.merge(&prev_state, current_input); - - // Should just have the current input - assert_eq!(merged.len(), 1); - } + // Generate the standard CRUD tests via macro + generate_storage_tests!(MemoryConversationalStorage::new()); #[tokio::test] async fn test_concurrent_access() { @@ -327,7 +145,7 @@ mod tests { for i in 0..10 { let storage_clone = storage.clone(); let handle = tokio::spawn(async move { - let state = create_test_state(&format!("resp_{}", i), i % 3); + let state = create_test_state_with_messages(&format!("resp_{}", i), i % 3); storage_clone.put(state).await.unwrap(); }); handles.push(handle); @@ -347,7 +165,7 @@ mod tests { #[tokio::test] async fn test_multiple_operations_on_same_id() { let storage = MemoryConversationalStorage::new(); - let state = create_test_state("resp_010", 1); + let state = create_test_state_with_messages("resp_010", 1); // Put storage.put(state.clone()).await.unwrap(); @@ -360,7 +178,7 @@ mod tests { assert!(storage.exists("resp_010").await.unwrap()); // Put again (overwrite) - let new_state = create_test_state("resp_010", 5); + let new_state = create_test_state_with_messages("resp_010", 5); storage.put(new_state).await.unwrap(); // Get updated @@ -373,266 +191,4 
@@ mod tests { // Should not exist assert!(!storage.exists("resp_010").await.unwrap()); } - - #[tokio::test] - async fn test_merge_with_tool_call_flow() { - // This test simulates a realistic tool call conversation flow: - // 1. User sends message: "What's the weather?" - // 2. Model responds with function call (converted to assistant message) - // 3. User sends function call output in next request with previous_response_id - // The merge should combine: user message + assistant function call + function output - - let storage = MemoryConversationalStorage::new(); - - // Step 1: Previous state contains the initial exchange - // - User message: "What's the weather in SF?" - // - Assistant message (converted from FunctionCall): "Called function: get_weather..." - let prev_state = OpenAIConversationState { - response_id: "resp_tool_001".to_string(), - input_items: vec![ - // Original user message - InputItem::Message(InputMessage { - role: MessageRole::User, - content: MessageContent::Items(vec![InputContent::InputText { - text: "What's the weather in San Francisco?".to_string(), - }]), - }), - // Assistant's function call (converted from OutputItem::FunctionCall) - InputItem::Message(InputMessage { - role: MessageRole::Assistant, - content: MessageContent::Items(vec![InputContent::InputText { - text: "Called function: get_weather with arguments: {\"location\":\"San Francisco, CA\"}".to_string(), - }]), - }), - ], - created_at: 1234567890, - model: "claude-3".to_string(), - provider: "anthropic".to_string(), - }; - - // Step 2: Current request includes function call output - let current_input = vec![InputItem::Message(InputMessage { - role: MessageRole::User, - content: MessageContent::Items(vec![InputContent::InputText { - text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}" - .to_string(), - }]), - })]; - - // Step 3: Merge should combine all conversation history - let merged = storage.merge(&prev_state, current_input); - - // Should have 3 items: 
user question + assistant function call + function output - assert_eq!(merged.len(), 3); - - // Verify the order and content - let InputItem::Message(msg1) = &merged[0] else { - panic!("Expected Message") - }; - assert!(matches!(msg1.role, MessageRole::User)); - match &msg1.content { - MessageContent::Items(items) => match &items[0] { - InputContent::InputText { text } => { - assert!(text.contains("weather in San Francisco")); - } - _ => panic!("Expected InputText"), - }, - _ => panic!("Expected MessageContent::Items"), - } - - let InputItem::Message(msg2) = &merged[1] else { - panic!("Expected Message") - }; - assert!(matches!(msg2.role, MessageRole::Assistant)); - match &msg2.content { - MessageContent::Items(items) => match &items[0] { - InputContent::InputText { text } => { - assert!(text.contains("get_weather")); - } - _ => panic!("Expected InputText"), - }, - _ => panic!("Expected MessageContent::Items"), - } - - let InputItem::Message(msg3) = &merged[2] else { - panic!("Expected Message") - }; - assert!(matches!(msg3.role, MessageRole::User)); - match &msg3.content { - MessageContent::Items(items) => match &items[0] { - InputContent::InputText { text } => { - assert!(text.contains("Function result")); - assert!(text.contains("temperature")); - } - _ => panic!("Expected InputText"), - }, - _ => panic!("Expected MessageContent::Items"), - } - } - - #[tokio::test] - async fn test_merge_with_multiple_tool_calls() { - // Test a more complex scenario with multiple tool calls - let storage = MemoryConversationalStorage::new(); - - // Previous state has: user message + 2 function calls from assistant - let prev_state = OpenAIConversationState { - response_id: "resp_tool_002".to_string(), - input_items: vec![ - InputItem::Message(InputMessage { - role: MessageRole::User, - content: MessageContent::Items(vec![InputContent::InputText { - text: "What's the weather and time in SF?".to_string(), - }]), - }), - InputItem::Message(InputMessage { - role: 
MessageRole::Assistant, - content: MessageContent::Items(vec![InputContent::InputText { - text: "Called function: get_weather with arguments: {\"location\":\"SF\"}".to_string(), - }]), - }), - InputItem::Message(InputMessage { - role: MessageRole::Assistant, - content: MessageContent::Items(vec![InputContent::InputText { - text: "Called function: get_time with arguments: {\"timezone\":\"America/Los_Angeles\"}".to_string(), - }]), - }), - ], - created_at: 1234567890, - model: "gpt-4".to_string(), - provider: "openai".to_string(), - }; - - // Current input: function outputs for both calls - let current_input = vec![ - InputItem::Message(InputMessage { - role: MessageRole::User, - content: MessageContent::Items(vec![InputContent::InputText { - text: "Weather result: {\"temp\": 68}".to_string(), - }]), - }), - InputItem::Message(InputMessage { - role: MessageRole::User, - content: MessageContent::Items(vec![InputContent::InputText { - text: "Time result: {\"time\": \"14:30\"}".to_string(), - }]), - }), - ]; - - let merged = storage.merge(&prev_state, current_input); - - // Should have 5 items total: 1 user + 2 assistant calls + 2 function outputs - assert_eq!(merged.len(), 5); - - // Verify first item is original user message - let InputItem::Message(first) = &merged[0] else { - panic!("Expected Message") - }; - assert!(matches!(first.role, MessageRole::User)); - - // Verify last two are function outputs - let InputItem::Message(second_last) = &merged[3] else { - panic!("Expected Message") - }; - assert!(matches!(second_last.role, MessageRole::User)); - match &second_last.content { - MessageContent::Items(items) => match &items[0] { - InputContent::InputText { text } => assert!(text.contains("Weather result")), - _ => panic!("Expected InputText"), - }, - _ => panic!("Expected MessageContent::Items"), - } - - let InputItem::Message(last) = &merged[4] else { - panic!("Expected Message") - }; - assert!(matches!(last.role, MessageRole::User)); - match &last.content { - 
MessageContent::Items(items) => match &items[0] { - InputContent::InputText { text } => assert!(text.contains("Time result")), - _ => panic!("Expected InputText"), - }, - _ => panic!("Expected MessageContent::Items"), - } - } - - #[tokio::test] - async fn test_merge_preserves_conversation_context_for_multi_turn() { - // Simulate a multi-turn conversation with tool calls - let storage = MemoryConversationalStorage::new(); - - // Previous state: full conversation history up to this point - let prev_state = OpenAIConversationState { - response_id: "resp_tool_003".to_string(), - input_items: vec![ - // Turn 1: User asks about weather - InputItem::Message(InputMessage { - role: MessageRole::User, - content: MessageContent::Items(vec![InputContent::InputText { - text: "What's the weather?".to_string(), - }]), - }), - // Turn 1: Assistant calls get_weather - InputItem::Message(InputMessage { - role: MessageRole::Assistant, - content: MessageContent::Items(vec![InputContent::InputText { - text: "Called function: get_weather".to_string(), - }]), - }), - // Turn 2: User provides function output - InputItem::Message(InputMessage { - role: MessageRole::User, - content: MessageContent::Items(vec![InputContent::InputText { - text: "Weather: sunny, 72°F".to_string(), - }]), - }), - // Turn 2: Assistant responds with text - InputItem::Message(InputMessage { - role: MessageRole::Assistant, - content: MessageContent::Items(vec![InputContent::InputText { - text: "It's sunny and 72°F in San Francisco today!".to_string(), - }]), - }), - ], - created_at: 1234567890, - model: "claude-3".to_string(), - provider: "anthropic".to_string(), - }; - - // Turn 3: User asks follow-up question - let current_input = vec![InputItem::Message(InputMessage { - role: MessageRole::User, - content: MessageContent::Items(vec![InputContent::InputText { - text: "Should I bring an umbrella?".to_string(), - }]), - })]; - - let merged = storage.merge(&prev_state, current_input); - - // Should have all 5 
messages in order - assert_eq!(merged.len(), 5); - - // Verify the entire conversation flow is preserved - let InputItem::Message(first) = &merged[0] else { - panic!("Expected Message") - }; - match &first.content { - MessageContent::Items(items) => match &items[0] { - InputContent::InputText { text } => assert!(text.contains("What's the weather")), - _ => panic!("Expected InputText"), - }, - _ => panic!("Expected MessageContent::Items"), - } - - let InputItem::Message(last) = &merged[4] else { - panic!("Expected Message") - }; - match &last.content { - MessageContent::Items(items) => match &items[0] { - InputContent::InputText { text } => assert!(text.contains("umbrella")), - _ => panic!("Expected InputText"), - }, - _ => panic!("Expected MessageContent::Items"), - } - } } diff --git a/crates/brightstaff/src/state/mod.rs b/crates/brightstaff/src/state/mod.rs index 3d59f359a..57402bb93 100644 --- a/crates/brightstaff/src/state/mod.rs +++ b/crates/brightstaff/src/state/mod.rs @@ -148,13 +148,128 @@ pub async fn retrieve_and_combine_input( Ok(combined_input) } +#[cfg(test)] +macro_rules! 
generate_storage_tests { + ($create_storage:expr) => { + #[tokio::test] + async fn test_put_and_get_success() { + let storage = $create_storage; + let state = create_test_state("resp_001"); + storage.put(state.clone()).await.unwrap(); + let retrieved = storage.get("resp_001").await.unwrap(); + assert_eq!(retrieved.response_id, "resp_001"); + assert_eq!(retrieved.model, state.model); + assert_eq!(retrieved.provider, state.provider); + assert_eq!(retrieved.input_items.len(), state.input_items.len()); + assert_eq!(retrieved.created_at, state.created_at); + } + + #[tokio::test] + async fn test_put_overwrites_existing() { + let storage = $create_storage; + let state1 = create_test_state("resp_002"); + storage.put(state1).await.unwrap(); + let state2 = OpenAIConversationState { + response_id: "resp_002".to_string(), + input_items: vec![], + created_at: 9999999999, + model: "gpt-4".to_string(), + provider: "openai".to_string(), + }; + storage.put(state2).await.unwrap(); + let retrieved = storage.get("resp_002").await.unwrap(); + assert_eq!(retrieved.model, "gpt-4"); + assert_eq!(retrieved.provider, "openai"); + assert_eq!(retrieved.input_items.len(), 0); + assert_eq!(retrieved.created_at, 9999999999); + } + + #[tokio::test] + async fn test_get_not_found() { + let storage = $create_storage; + let result = storage.get("nonexistent").await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + StateStorageError::NotFound(_) + )); + } + + #[tokio::test] + async fn test_exists_returns_false_for_nonexistent() { + let storage = $create_storage; + assert!(!storage.exists("nonexistent").await.unwrap()); + } + + #[tokio::test] + async fn test_exists_returns_true_after_put() { + let storage = $create_storage; + let state = create_test_state("resp_004"); + assert!(!storage.exists("resp_004").await.unwrap()); + storage.put(state).await.unwrap(); + assert!(storage.exists("resp_004").await.unwrap()); + } + + #[tokio::test] + async fn test_delete_success() { + let 
storage = $create_storage; + let state = create_test_state("resp_005"); + storage.put(state).await.unwrap(); + assert!(storage.exists("resp_005").await.unwrap()); + storage.delete("resp_005").await.unwrap(); + assert!(!storage.exists("resp_005").await.unwrap()); + assert!(storage.get("resp_005").await.is_err()); + } + + #[tokio::test] + async fn test_delete_not_found() { + let storage = $create_storage; + let result = storage.delete("nonexistent").await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + StateStorageError::NotFound(_) + )); + } + }; +} + +#[cfg(test)] +pub(crate) use generate_storage_tests; + #[cfg(test)] mod tests { use super::extract_input_items; + use super::memory::MemoryConversationalStorage; + use super::{OpenAIConversationState, StateStorage}; use hermesllm::apis::openai_responses::{ InputContent, InputItem, InputMessage, InputParam, MessageContent, MessageRole, }; + fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState { + let mut input_items = Vec::new(); + for i in 0..num_messages { + input_items.push(InputItem::Message(InputMessage { + role: if i % 2 == 0 { + MessageRole::User + } else { + MessageRole::Assistant + }, + content: MessageContent::Items(vec![InputContent::InputText { + text: format!("Message {}", i), + }]), + })); + } + + OpenAIConversationState { + response_id: response_id.to_string(), + input_items, + created_at: 1234567890, + model: "claude-3".to_string(), + provider: "anthropic".to_string(), + } + } + #[test] fn test_extract_input_items_converts_text_to_user_message_item() { let extracted = extract_input_items(&InputParam::Text("hello world".to_string())); @@ -244,4 +359,320 @@ mod tests { }; assert_eq!(second_text, "second"); } + + // === Merge tests (testing the default trait method) === + + #[tokio::test] + async fn test_merge_combines_inputs() { + let storage = MemoryConversationalStorage::new(); + let prev_state = create_test_state("resp_006", 2); + + let 
current_input = vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: MessageContent::Items(vec![InputContent::InputText { + text: "New message".to_string(), + }]), + })]; + + let merged = storage.merge(&prev_state, current_input); + assert_eq!(merged.len(), 3); + } + + #[tokio::test] + async fn test_merge_preserves_order() { + let storage = MemoryConversationalStorage::new(); + let prev_state = create_test_state("resp_007", 2); + + let current_input = vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: MessageContent::Items(vec![InputContent::InputText { + text: "Message 2".to_string(), + }]), + })]; + + let merged = storage.merge(&prev_state, current_input); + + let InputItem::Message(msg) = &merged[0] else { + panic!("Expected Message") + }; + match &msg.content { + MessageContent::Items(items) => match &items[0] { + InputContent::InputText { text } => assert_eq!(text, "Message 0"), + _ => panic!("Expected InputText"), + }, + _ => panic!("Expected MessageContent::Items"), + } + + let InputItem::Message(msg) = &merged[2] else { + panic!("Expected Message") + }; + match &msg.content { + MessageContent::Items(items) => match &items[0] { + InputContent::InputText { text } => assert_eq!(text, "Message 2"), + _ => panic!("Expected InputText"), + }, + _ => panic!("Expected MessageContent::Items"), + } + } + + #[tokio::test] + async fn test_merge_with_empty_current_input() { + let storage = MemoryConversationalStorage::new(); + let prev_state = create_test_state("resp_008", 3); + + let merged = storage.merge(&prev_state, vec![]); + assert_eq!(merged.len(), 3); + } + + #[tokio::test] + async fn test_merge_with_empty_previous_state() { + let storage = MemoryConversationalStorage::new(); + + let prev_state = OpenAIConversationState { + response_id: "resp_009".to_string(), + input_items: vec![], + created_at: 1234567890, + model: "gpt-4".to_string(), + provider: "openai".to_string(), + }; + + let current_input = 
vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: MessageContent::Items(vec![InputContent::InputText { + text: "Only message".to_string(), + }]), + })]; + + let merged = storage.merge(&prev_state, current_input); + assert_eq!(merged.len(), 1); + } + + #[tokio::test] + async fn test_merge_with_tool_call_flow() { + let storage = MemoryConversationalStorage::new(); + + let prev_state = OpenAIConversationState { + response_id: "resp_tool_001".to_string(), + input_items: vec![ + InputItem::Message(InputMessage { + role: MessageRole::User, + content: MessageContent::Items(vec![InputContent::InputText { + text: "What's the weather in San Francisco?".to_string(), + }]), + }), + InputItem::Message(InputMessage { + role: MessageRole::Assistant, + content: MessageContent::Items(vec![InputContent::InputText { + text: "Called function: get_weather with arguments: {\"location\":\"San Francisco, CA\"}".to_string(), + }]), + }), + ], + created_at: 1234567890, + model: "claude-3".to_string(), + provider: "anthropic".to_string(), + }; + + let current_input = vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: MessageContent::Items(vec![InputContent::InputText { + text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}" + .to_string(), + }]), + })]; + + let merged = storage.merge(&prev_state, current_input); + assert_eq!(merged.len(), 3); + + let InputItem::Message(msg1) = &merged[0] else { + panic!("Expected Message") + }; + assert!(matches!(msg1.role, MessageRole::User)); + match &msg1.content { + MessageContent::Items(items) => match &items[0] { + InputContent::InputText { text } => { + assert!(text.contains("weather in San Francisco")); + } + _ => panic!("Expected InputText"), + }, + _ => panic!("Expected MessageContent::Items"), + } + + let InputItem::Message(msg2) = &merged[1] else { + panic!("Expected Message") + }; + assert!(matches!(msg2.role, MessageRole::Assistant)); + match &msg2.content { + 
MessageContent::Items(items) => match &items[0] { + InputContent::InputText { text } => { + assert!(text.contains("get_weather")); + } + _ => panic!("Expected InputText"), + }, + _ => panic!("Expected MessageContent::Items"), + } + + let InputItem::Message(msg3) = &merged[2] else { + panic!("Expected Message") + }; + assert!(matches!(msg3.role, MessageRole::User)); + match &msg3.content { + MessageContent::Items(items) => match &items[0] { + InputContent::InputText { text } => { + assert!(text.contains("Function result")); + assert!(text.contains("temperature")); + } + _ => panic!("Expected InputText"), + }, + _ => panic!("Expected MessageContent::Items"), + } + } + + #[tokio::test] + async fn test_merge_with_multiple_tool_calls() { + let storage = MemoryConversationalStorage::new(); + + let prev_state = OpenAIConversationState { + response_id: "resp_tool_002".to_string(), + input_items: vec![ + InputItem::Message(InputMessage { + role: MessageRole::User, + content: MessageContent::Items(vec![InputContent::InputText { + text: "What's the weather and time in SF?".to_string(), + }]), + }), + InputItem::Message(InputMessage { + role: MessageRole::Assistant, + content: MessageContent::Items(vec![InputContent::InputText { + text: "Called function: get_weather with arguments: {\"location\":\"SF\"}".to_string(), + }]), + }), + InputItem::Message(InputMessage { + role: MessageRole::Assistant, + content: MessageContent::Items(vec![InputContent::InputText { + text: "Called function: get_time with arguments: {\"timezone\":\"America/Los_Angeles\"}".to_string(), + }]), + }), + ], + created_at: 1234567890, + model: "gpt-4".to_string(), + provider: "openai".to_string(), + }; + + let current_input = vec![ + InputItem::Message(InputMessage { + role: MessageRole::User, + content: MessageContent::Items(vec![InputContent::InputText { + text: "Weather result: {\"temp\": 68}".to_string(), + }]), + }), + InputItem::Message(InputMessage { + role: MessageRole::User, + content: 
MessageContent::Items(vec![InputContent::InputText { + text: "Time result: {\"time\": \"14:30\"}".to_string(), + }]), + }), + ]; + + let merged = storage.merge(&prev_state, current_input); + assert_eq!(merged.len(), 5); + + let InputItem::Message(first) = &merged[0] else { + panic!("Expected Message") + }; + assert!(matches!(first.role, MessageRole::User)); + + let InputItem::Message(second_last) = &merged[3] else { + panic!("Expected Message") + }; + assert!(matches!(second_last.role, MessageRole::User)); + match &second_last.content { + MessageContent::Items(items) => match &items[0] { + InputContent::InputText { text } => assert!(text.contains("Weather result")), + _ => panic!("Expected InputText"), + }, + _ => panic!("Expected MessageContent::Items"), + } + + let InputItem::Message(last) = &merged[4] else { + panic!("Expected Message") + }; + assert!(matches!(last.role, MessageRole::User)); + match &last.content { + MessageContent::Items(items) => match &items[0] { + InputContent::InputText { text } => assert!(text.contains("Time result")), + _ => panic!("Expected InputText"), + }, + _ => panic!("Expected MessageContent::Items"), + } + } + + #[tokio::test] + async fn test_merge_preserves_conversation_context_for_multi_turn() { + let storage = MemoryConversationalStorage::new(); + + let prev_state = OpenAIConversationState { + response_id: "resp_tool_003".to_string(), + input_items: vec![ + InputItem::Message(InputMessage { + role: MessageRole::User, + content: MessageContent::Items(vec![InputContent::InputText { + text: "What's the weather?".to_string(), + }]), + }), + InputItem::Message(InputMessage { + role: MessageRole::Assistant, + content: MessageContent::Items(vec![InputContent::InputText { + text: "Called function: get_weather".to_string(), + }]), + }), + InputItem::Message(InputMessage { + role: MessageRole::User, + content: MessageContent::Items(vec![InputContent::InputText { + text: "Weather: sunny, 72\u{00b0}F".to_string(), + }]), + }), + 
InputItem::Message(InputMessage { + role: MessageRole::Assistant, + content: MessageContent::Items(vec![InputContent::InputText { + text: "It's sunny and 72\u{00b0}F in San Francisco today!".to_string(), + }]), + }), + ], + created_at: 1234567890, + model: "claude-3".to_string(), + provider: "anthropic".to_string(), + }; + + let current_input = vec![InputItem::Message(InputMessage { + role: MessageRole::User, + content: MessageContent::Items(vec![InputContent::InputText { + text: "Should I bring an umbrella?".to_string(), + }]), + })]; + + let merged = storage.merge(&prev_state, current_input); + assert_eq!(merged.len(), 5); + + let InputItem::Message(first) = &merged[0] else { + panic!("Expected Message") + }; + match &first.content { + MessageContent::Items(items) => match &items[0] { + InputContent::InputText { text } => assert!(text.contains("What's the weather")), + _ => panic!("Expected InputText"), + }, + _ => panic!("Expected MessageContent::Items"), + } + + let InputItem::Message(last) = &merged[4] else { + panic!("Expected Message") + }; + match &last.content { + MessageContent::Items(items) => match &items[0] { + InputContent::InputText { text } => assert!(text.contains("umbrella")), + _ => panic!("Expected InputText"), + }, + _ => panic!("Expected MessageContent::Items"), + } + } } diff --git a/crates/brightstaff/src/state/postgresql.rs b/crates/brightstaff/src/state/postgresql.rs index fe27580e6..38f85510a 100644 --- a/crates/brightstaff/src/state/postgresql.rs +++ b/crates/brightstaff/src/state/postgresql.rs @@ -229,6 +229,7 @@ Run that SQL file against your database before using this storage backend. 
#[cfg(test)] mod tests { use super::*; + use crate::state::generate_storage_tests; use hermesllm::apis::openai_responses::{ InputContent, InputItem, InputMessage, MessageContent, MessageRole, }; @@ -267,140 +268,13 @@ mod tests { } } - #[tokio::test] - async fn test_supabase_put_and_get_success() { - let Some(storage) = get_test_storage().await else { - return; - }; - - let state = create_test_state("test_resp_001"); - storage.put(state.clone()).await.unwrap(); - - let retrieved = storage.get("test_resp_001").await.unwrap(); - assert_eq!(retrieved.response_id, "test_resp_001"); - assert_eq!(retrieved.input_items.len(), 1); - assert_eq!(retrieved.model, "gpt-4"); - assert_eq!(retrieved.provider, "openai"); - - // Cleanup - let _ = storage.delete("test_resp_001").await; - } - - #[tokio::test] - async fn test_supabase_put_overwrites_existing() { - let Some(storage) = get_test_storage().await else { - return; - }; - - let state1 = create_test_state("test_resp_002"); - storage.put(state1).await.unwrap(); - - let mut state2 = create_test_state("test_resp_002"); - state2.model = "gpt-4-turbo".to_string(); - state2.input_items.push(InputItem::Message(InputMessage { - role: MessageRole::Assistant, - content: MessageContent::Items(vec![InputContent::InputText { - text: "Response".to_string(), - }]), - })); - storage.put(state2).await.unwrap(); - - let retrieved = storage.get("test_resp_002").await.unwrap(); - assert_eq!(retrieved.model, "gpt-4-turbo"); - assert_eq!(retrieved.input_items.len(), 2); - - // Cleanup - let _ = storage.delete("test_resp_002").await; - } - - #[tokio::test] - async fn test_supabase_get_not_found() { - let Some(storage) = get_test_storage().await else { - return; - }; - - let result = storage.get("nonexistent_id").await; - assert!(result.is_err()); - assert!(matches!( - result.unwrap_err(), - StateStorageError::NotFound(_) - )); - } - - #[tokio::test] - async fn test_supabase_exists_returns_false() { - let Some(storage) = get_test_storage().await 
else { - return; - }; - - let exists = storage.exists("nonexistent_id").await.unwrap(); - assert!(!exists); - } - - #[tokio::test] - async fn test_supabase_exists_returns_true_after_put() { + // Generate the standard CRUD tests via macro + generate_storage_tests!({ let Some(storage) = get_test_storage().await else { return; }; - - let state = create_test_state("test_resp_003"); - storage.put(state).await.unwrap(); - - let exists = storage.exists("test_resp_003").await.unwrap(); - assert!(exists); - - // Cleanup - let _ = storage.delete("test_resp_003").await; - } - - #[tokio::test] - async fn test_supabase_delete_success() { - let Some(storage) = get_test_storage().await else { - return; - }; - - let state = create_test_state("test_resp_004"); - storage.put(state).await.unwrap(); - - storage.delete("test_resp_004").await.unwrap(); - - let exists = storage.exists("test_resp_004").await.unwrap(); - assert!(!exists); - } - - #[tokio::test] - async fn test_supabase_delete_not_found() { - let Some(storage) = get_test_storage().await else { - return; - }; - - let result = storage.delete("nonexistent_id").await; - assert!(result.is_err()); - assert!(matches!( - result.unwrap_err(), - StateStorageError::NotFound(_) - )); - } - - #[tokio::test] - async fn test_supabase_merge_works() { - let Some(storage) = get_test_storage().await else { - return; - }; - - let prev_state = create_test_state("test_resp_005"); - let current_input = vec![InputItem::Message(InputMessage { - role: MessageRole::User, - content: MessageContent::Items(vec![InputContent::InputText { - text: "New message".to_string(), - }]), - })]; - - let merged = storage.merge(&prev_state, current_input); - - // Should have 2 messages (1 from prev + 1 current) - assert_eq!(merged.len(), 2); - } + storage + }); #[tokio::test] async fn test_supabase_table_verification() { @@ -428,7 +302,7 @@ mod tests { let state = create_test_state("manual_test_verification"); storage.put(state).await.unwrap(); - println!("✅ Data 
written to Supabase!"); + println!("Data written to Supabase!"); println!("Check your Supabase dashboard:"); println!( " SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';" diff --git a/crates/hermesllm/src/providers/streaming_response.rs b/crates/hermesllm/src/providers/streaming_response.rs index 66ccc7354..032b2508d 100644 --- a/crates/hermesllm/src/providers/streaming_response.rs +++ b/crates/hermesllm/src/providers/streaming_response.rs @@ -472,6 +472,37 @@ mod tests { use crate::clients::endpoints::SupportedAPIsFromClient; use serde_json::json; + /// Helper to build an SseEvent from optional JSON data and optional event type. + fn make_sse_event(data: Option<serde_json::Value>, event: Option<&str>) -> SseEvent { + match data { + Some(ref json_val) => SseEvent { + data: Some(json_val.to_string()), + event: event.map(|s| s.to_string()), + raw_line: format!("data: {}", json_val), + sse_transformed_lines: format!("data: {}", json_val), + provider_stream_response: None, + }, + None => SseEvent { + data: None, + event: event.map(|s| s.to_string()), + raw_line: event.map(|e| format!("event: {}", e)).unwrap_or_default(), + sse_transformed_lines: event.map(|e| format!("event: {}", e)).unwrap_or_default(), + provider_stream_response: None, + }, + } + } + + /// Helper to build a standard OpenAI content chunk with the given content text. 
+ fn openai_content_chunk(content: &str) -> serde_json::Value { + json!({ + "id": "chatcmpl-123", + "object": "chat.completion.chunk", + "created": 1234567890, + "model": "gpt-4", + "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": null}] + }) + } + #[test] fn test_sse_event_parsing() { // Test valid SSE data line @@ -1099,14 +1130,7 @@ mod tests { } }); - // Create SSE event with this data - let sse_event = SseEvent { - data: Some(openai_stream_chunk.to_string()), - event: None, - raw_line: format!("data: {}", openai_stream_chunk), - sse_transformed_lines: format!("data: {}", openai_stream_chunk), - provider_stream_response: None, - }; + let sse_event = make_sse_event(Some(openai_stream_chunk), None); let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); @@ -1143,26 +1167,7 @@ mod tests { use crate::apis::openai::OpenAIApi; // Create an OpenAI stream response with content (which becomes content_block_delta in Anthropic) - let openai_stream_chunk = json!({ - "id": "chatcmpl-123", - "object": "chat.completion.chunk", - "created": 1234567890, - "model": "gpt-4", - "choices": [{ - "index": 0, - "delta": {"content": "Hello"}, - "finish_reason": null - }] - }); - - // Create SSE event with this data - let sse_event = SseEvent { - data: Some(openai_stream_chunk.to_string()), - event: None, - raw_line: format!("data: {}", openai_stream_chunk), - sse_transformed_lines: format!("data: {}", openai_stream_chunk), - provider_stream_response: None, - }; + let sse_event = make_sse_event(Some(openai_content_chunk("Hello")), None); let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); @@ -1198,13 +1203,7 @@ mod tests { use crate::apis::openai::OpenAIApi; // Create an Anthropic event-only SSE line (no data) - 
let sse_event = SseEvent { - data: None, - event: Some("message_start".to_string()), - raw_line: "event: message_start".to_string(), - sse_transformed_lines: "event: message_start".to_string(), - provider_stream_response: None, - }; + let sse_event = make_sse_event(None, Some("message_start")); let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); @@ -1245,13 +1244,7 @@ mod tests { } }); - let sse_event = SseEvent { - data: Some(anthropic_event.to_string()), - event: None, - raw_line: format!("data: {}", anthropic_event), - sse_transformed_lines: format!("data: {}", anthropic_event), - provider_stream_response: None, - }; + let sse_event = make_sse_event(Some(anthropic_event), None); let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); let upstream_api = SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages); @@ -1279,26 +1272,9 @@ mod tests { use crate::apis::openai::OpenAIApi; // Create an OpenAI stream response - let openai_stream_chunk = json!({ - "id": "chatcmpl-123", - "object": "chat.completion.chunk", - "created": 1234567890, - "model": "gpt-4", - "choices": [{ - "index": 0, - "delta": {"content": "Hello"}, - "finish_reason": null - }] - }); - - let original_data = openai_stream_chunk.to_string(); - let sse_event = SseEvent { - data: Some(original_data.clone()), - event: None, - raw_line: format!("data: {}", original_data), - sse_transformed_lines: format!("data: {}\n\n", original_data), - provider_stream_response: None, - }; + let mut sse_event = make_sse_event(Some(openai_content_chunk("Hello")), None); + // This test requires trailing \n\n in sse_transformed_lines + sse_event.sse_transformed_lines = format!("data: {}\n\n", openai_content_chunk("Hello")); let client_api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); let upstream_api = 
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); @@ -1324,25 +1300,7 @@ mod tests { use crate::apis::openai::OpenAIApi; // Create an OpenAI stream response - let openai_stream_chunk = json!({ - "id": "chatcmpl-123", - "object": "chat.completion.chunk", - "created": 1234567890, - "model": "gpt-4", - "choices": [{ - "index": 0, - "delta": {"content": "Test"}, - "finish_reason": null - }] - }); - - let sse_event = SseEvent { - data: Some(openai_stream_chunk.to_string()), - event: None, - raw_line: format!("data: {}", openai_stream_chunk), - sse_transformed_lines: format!("data: {}", openai_stream_chunk), - provider_stream_response: None, - }; + let sse_event = make_sse_event(Some(openai_content_chunk("Test")), None); let client_api = SupportedAPIsFromClient::AnthropicMessagesAPI(AnthropicApi::Messages); let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions);