Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- name: Run Clippy against main lib
run: cargo clippy --workspace --exclude tests
- name: Run tests
run: cargo test --verbose -- --test-threads=1
run: cargo test --verbose --
env:
RUST_LOG: info
- name: Rust Format
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = '2018'
license = "MIT"

[dependencies]
libc = "0.2"
log = "0.4"
env_logger = "0.7"
futures = { version = "0.3"}
Expand Down
121 changes: 116 additions & 5 deletions src/bin/post_meetup.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
extern crate tokio;
use clap::{crate_authors, App as ClApp, Arg};
use clap::{crate_authors, App as ClApp, Arg, ArgGroup};
use futures::future;
use futures::StreamExt;
use log::*;
Expand All @@ -15,6 +15,65 @@ use post::find_service::{
};
use tonic::transport::Server;

use std::os::raw::{c_int, c_void};
use std::os::unix::io as unix_io;
use std::os::unix::io::FromRawFd;

#[derive(Debug)]
pub struct InvalidSocketDescriptor {
fd: unix_io::RawFd,
}

impl InvalidSocketDescriptor {
pub fn new(fd: unix_io::RawFd) -> InvalidSocketDescriptor {
InvalidSocketDescriptor { fd }
}
}

impl std::fmt::Display for InvalidSocketDescriptor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Invalid Socket File Descriptor: {}", self.fd)
}
}

impl std::error::Error for InvalidSocketDescriptor {}

unsafe fn getsocketopt<T>(fd: unix_io::RawFd, level: c_int, opt: c_int) -> std::io::Result<T> {
let mut val: T = std::mem::zeroed();
let mut val_size: libc::socklen_t = std::mem::size_of::<T>() as u32;
let ret = libc::getsockopt(
fd,
level,
opt,
&mut val as *mut T as *mut c_void,
&mut val_size as *mut libc::socklen_t,
);
if ret < 0 {
Err(std::io::Error::last_os_error())
} else {
Ok(val)
}
}

fn listener_from_raw(
fd: unix_io::RawFd,
) -> Result<std::net::TcpListener, Box<dyn std::error::Error>> {
let family: c_int = unsafe { getsocketopt(fd, libc::SOL_SOCKET, libc::SO_DOMAIN) }?;

if family != libc::AF_INET && family != libc::AF_INET6 {
eprint!("Unable to use socket of family: {}", family);
return Err(Box::new(InvalidSocketDescriptor::new(fd)));
}

let sock_type: c_int = unsafe { getsocketopt(fd, libc::SOL_SOCKET, libc::SO_TYPE) }?;

if sock_type != libc::SOCK_STREAM {
return Err(Box::new(InvalidSocketDescriptor::new(fd)));
}

Ok(unsafe { std::net::TcpListener::from_raw_fd(fd) })
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init();
Expand All @@ -27,10 +86,20 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.short("b")
.long("bind")
.takes_value(true)
.required(true)
.validator(socket_validator)
.help("IP and port to bind to in the form of <IP>:<Port>. Use 0.0.0.0 for any IP on the system."),
)
.arg(
Arg::with_name("fd")
.long("fd")
.takes_value(true)
.help("File descriptor of bound TCP socket ready for `listen` to be called on it")
)
.group(
ArgGroup::with_name("Socket description")
.args(&["bind","fd"])
.required(true)
)
.arg(
Arg::with_name("publisher-timeout")
.short("t")
Expand Down Expand Up @@ -62,7 +131,20 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.parse()
.unwrap(),
);
let bind_info = matches.value_of("bind").unwrap().parse().unwrap();
let listener = match matches.value_of("bind") {
Some(addr) => {
tokio::net::TcpListener::bind(
std::net::ToSocketAddrs::to_socket_addrs(addr)?
.next()
.expect("No Address found"),
)
.await?
}
None => tokio::net::TcpListener::from_std(listener_from_raw(
matches.value_of("fd").unwrap().parse::<c_int>().unwrap(),
)?)?,
};
let local_address = listener.local_addr()?;

let publisher_store = HashMapPublisherStore::new(RwLock::new(HashMap::new()));
let meetup_server_options = MeetupServerOptions {
Expand All @@ -74,9 +156,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

let server = Server::builder()
.add_service(FindMeServer::new(meetup_server))
.serve(bind_info);
.serve_with_incoming(listener);

info!("server listening at: {}", local_address);

info!("Started server {}", bind_info);
remove_expired_publishers(publisher_store, scan_interval);
server.await?;

Expand Down Expand Up @@ -128,3 +211,31 @@ pub fn socket_validator(v: String) -> Result<(), String> {
)),
}
}

#[cfg(test)]
mod tests {
use std::os::unix::io::AsRawFd;

fn init() {
let _ = env_logger::builder().is_test(true).try_init();
}

#[test]
fn test_bad_fd() {
let orig_listener = std::net::UdpSocket::bind(("127.0.0.1", 0))
.expect("Unable to make a udp socket bound to localhost");
let fd = orig_listener.as_raw_fd();

let _listener =
super::listener_from_raw(fd).expect_err("Sent invalid socket, recieved success");
}

#[test]
fn test_good_fd() {
let orig_listener = std::net::TcpListener::bind(("127.0.0.1", 0))
.expect("Unable to make a tcp socket bound to localhost");
let fd = orig_listener.as_raw_fd();

let _listener = super::listener_from_raw(fd).expect("Sent valid socket, recieved error");
}
}
40 changes: 34 additions & 6 deletions tests/common/find_service_setup.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,38 @@
use post::find_service::Client;
use std::os::unix::io::AsRawFd;
use std::os::unix::io::RawFd;

///Wraps an external find service process and provides easy access to its functions
pub struct FindService {
_proc: tokio::process::Child,
client: post::find_service::Client,
}

pub async fn retry_client(url: &'static str) -> post::find_service::Client {
fn unset_close_on_exec(fd: RawFd) -> std::io::Result<RawFd> {
use std::io;
let flags = unsafe { libc::fcntl(fd, libc::F_GETFD) };
if flags == -1 {
return Err(io::Error::last_os_error());
}
let was_set = flags & libc::FD_CLOEXEC != 0;
log::info!("State of Close on exec: {}", was_set);
let result = unsafe { libc::fcntl(fd, libc::F_SETFD, flags & !libc::FD_CLOEXEC) };
if result == -1 {
Err(io::Error::last_os_error())
} else {
Ok(fd)
}
}

pub async fn retry_client(url: String) -> post::find_service::Client {
let retries: i32 = 10;
let mut retry: i32 = 0;

loop {
if retry >= retries {
panic!("Retries exceeded");
}
if let Ok(mut client) = post::find_service::Client::from_url(url)
if let Ok(mut client) = post::find_service::Client::from_url(url.clone())
.unwrap()
.set_connect_timeout(std::time::Duration::from_secs(60))
.connect()
Expand All @@ -35,16 +53,26 @@ impl FindService {
pub async fn new() -> FindService {
log::info!("Starting new meetup service");
let path = "target/debug/post-meetup";
let url = "http://127.0.0.1:8080/";
let bind = "127.0.0.1:8080";

let listener = std::net::TcpListener::bind(("127.0.0.1", 0))
.expect("could not reserve address for find service");
let port = listener
.local_addr()
.expect("could not retrieve port from OS for find service")
.port();
let listener_fd =
unset_close_on_exec(listener.as_raw_fd()).expect("could not disable close on exec");

log::info!("meetup server starting on port {}", port);
let url = format!("http://127.0.0.1:{}/", port);

let _proc = tokio::process::Command::new(path)
.arg("-s")
.arg("5")
.arg("-t")
.arg("5")
.arg("--bind")
.arg(bind)
.arg("--fd")
.arg(format!("{}", listener_fd))
.kill_on_drop(true)
.spawn()
.expect("Failed to start meetup");
Expand Down
2 changes: 1 addition & 1 deletion tests/publisher_subscriber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ async fn publisher_cleanup() {
let desc = post::PublisherDesc {
name: publisher_name.clone(),
host_name: "127.0.0.1".to_string(),
port: 5000,
port: 5001,
subscriber_expiration_interval: std::time::Duration::from_secs(2),
};

Expand Down