Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/rust-mcp-sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ reqwest = { workspace = true, default-features = false, features = [
"cookies",
"multipart",
] }
tempfile = "3.23.0"
tracing-subscriber = { workspace = true, features = [
"env-filter",
"std",
Expand Down
169 changes: 168 additions & 1 deletion crates/rust-mcp-sdk/src/hyper_servers/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ impl HyperServerOptions {
}

pub fn streamable_http_endpoint(&self) -> &str {
self.custom_messages_endpoint
self.custom_streamable_http_endpoint
.as_deref()
.unwrap_or(DEFAULT_STREAMABLE_HTTP_ENDPOINT)
}
Expand Down Expand Up @@ -490,3 +490,170 @@ async fn shutdown_signal(handle: Handle, state: Arc<McpAppState>) {
// Trigger graceful shutdown with a timeout
handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
}

#[cfg(test)]
mod tests {
use super::*;

use tempfile::NamedTempFile;

#[test]
fn test_server_options_base_url_custom() {
let options = HyperServerOptions {
host: String::from("127.0.0.1"),
port: 8081,
enable_ssl: true,
..Default::default()
};
assert_eq!(options.base_url(), "https://127.0.0.1:8081");
}

#[test]
fn test_server_options_streamable_http_custom() {
let options = HyperServerOptions {
custom_streamable_http_endpoint: Some(String::from("/abcd/mcp")),
host: String::from("127.0.0.1"),
port: 8081,
enable_ssl: true,
..Default::default()
};
assert_eq!(
options.streamable_http_url(),
"https://127.0.0.1:8081/abcd/mcp"
);
assert_eq!(options.streamable_http_endpoint(), "/abcd/mcp");
}

#[test]
fn test_server_options_sse_custom() {
let options = HyperServerOptions {
custom_sse_endpoint: Some(String::from("/abcd/sse")),
host: String::from("127.0.0.1"),
port: 8081,
enable_ssl: true,
..Default::default()
};
assert_eq!(options.sse_url(), "https://127.0.0.1:8081/abcd/sse");
assert_eq!(options.sse_endpoint(), "/abcd/sse");
}

#[test]
fn test_server_options_sse_messages_custom() {
let options = HyperServerOptions {
custom_messages_endpoint: Some(String::from("/abcd/messages")),
..Default::default()
};
assert_eq!(
options.sse_message_url(),
"http://127.0.0.1:8080/abcd/messages"
);
assert_eq!(options.sse_messages_endpoint(), "/abcd/messages");
}

#[test]
fn test_server_options_needs_dns_protection() {
let options = HyperServerOptions::default();

// should be false by default
assert!(!options.needs_dns_protection());

// should still be false unless allowed_hosts or allowed_origins are also provided
let options = HyperServerOptions {
dns_rebinding_protection: true,
..Default::default()
};
assert!(!options.needs_dns_protection());

// should be true when dns_rebinding_protection is true and allowed_hosts is provided
let options = HyperServerOptions {
dns_rebinding_protection: true,
allowed_hosts: Some(vec![String::from("127.0.0.1")]),
..Default::default()
};
assert!(options.needs_dns_protection());

// should be true when dns_rebinding_protection is true and allowed_origins is provided
let options = HyperServerOptions {
dns_rebinding_protection: true,
allowed_origins: Some(vec![String::from("http://127.0.0.1:8080")]),
..Default::default()
};
assert!(options.needs_dns_protection());
}

#[test]
fn test_server_options_validate() {
let options = HyperServerOptions::default();
assert!(options.validate().is_ok());

// with ssl enabled but no cert or key provided, validate should fail
let options = HyperServerOptions {
enable_ssl: true,
..Default::default()
};
assert!(options.validate().is_err());

// with ssl enabled and invalid cert/key paths, validate should fail
let options = HyperServerOptions {
enable_ssl: true,
ssl_cert_path: Some(String::from("/invalid/path/to/cert.pem")),
ssl_key_path: Some(String::from("/invalid/path/to/key.pem")),
..Default::default()
};
assert!(options.validate().is_err());

// with ssl enabled and valid cert/key paths, validate should succeed
let cert_file =
NamedTempFile::with_suffix(".pem").expect("Expected to create test cert file");
let ssl_cert_path = cert_file
.path()
.to_str()
.expect("Expected to get cert path")
.to_string();
let key_file =
NamedTempFile::with_suffix(".pem").expect("Expected to create test key file");
Comment on lines +607 to +614
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just as a note, these files should be dropped when they go out of scope

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @jsudano , you are a pro!

the second adds a dev dependency in order to properly test SSL validation. I broke it into two commits in case you'd prefer not to add any dependencies, in which case I'll remove the second.

tempfile is great, no problem to have it in the dev dependency πŸ‘ , thanks for including the various test scenarios.

I plan to have a release in a couple of days including this fix.

let ssl_key_path = key_file
.path()
.to_str()
.expect("Expected to get key path")
.to_string();

let options = HyperServerOptions {
enable_ssl: true,
ssl_cert_path: Some(ssl_cert_path),
ssl_key_path: Some(ssl_key_path),
..Default::default()
};
assert!(options.validate().is_ok());
}

#[tokio::test]
async fn test_server_options_resolve_server_address() {
let options = HyperServerOptions::default();
assert!(options.resolve_server_address().await.is_ok());

// valid host should still work
let options = HyperServerOptions {
host: String::from("8.6.7.5"),
port: 309,
..Default::default()
};
assert!(options.resolve_server_address().await.is_ok());

// valid host (prepended with http://) should still work
let options = HyperServerOptions {
host: String::from("http://8.6.7.5"),
port: 309,
..Default::default()
};
assert!(options.resolve_server_address().await.is_ok());

// invalid host should raise an error
let options = HyperServerOptions {
host: String::from("invalid-host"),
port: 309,
..Default::default()
};
assert!(options.resolve_server_address().await.is_err());
}
}