Skip to content

Commit

Permalink
Add support for unix domain sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
argerus committed Nov 14, 2024
1 parent fd49ba8 commit 1838a30
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 26 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ tokio = { version = "1.20", features = ["full"] }
tokio-stream = { version = "0.1.8" }
tonic = { version = "0.11.0", default-features = false }
tonic-build = { version = "0.11.0", default-features = false }
tower = { version = "0.4" }
clap = { version = "4.2", features = [
"std",
"env",
Expand Down
5 changes: 5 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ struct Args {
#[clap(long, display_order = 4, default_value_t = 55555)]
port: u64,

/// Unix socket path of databroker
#[clap(long = "unix-socket", display_order = 5)]
unix_socket_path: Option<String>,

/// Seconds to run (skip) before measuring the latency.
#[clap(long, display_order = 5, value_name = "SECONDS")]
skip_seconds: Option<u64>,
Expand Down Expand Up @@ -136,6 +140,7 @@ async fn main() -> Result<()> {
let measurement_config = MeasurementConfig {
host: args.host,
port: args.port,
unix_socket_path: args.unix_socket_path,
duration: args.duration,
interval: 0,
skip_seconds: args.skip_seconds,
Expand Down
68 changes: 42 additions & 26 deletions src/measure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,18 @@ use indicatif::{ProgressBar, ProgressStyle};
use log::error;
use std::collections::HashMap;
use std::fmt;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::{
sync::atomic::Ordering,
time::{Duration, SystemTime},
};
use tokio::net::UnixStream;
use tokio::sync::{mpsc::Sender, RwLock};
use tokio::task;
use tokio::{select, task::JoinSet, time::Instant};
use tonic::transport::Endpoint;
use tonic::transport::Channel;
use tower::service_fn;

#[derive(Clone, PartialEq)]
pub enum Api {
Expand All @@ -54,6 +57,7 @@ pub struct Provider {
pub struct MeasurementConfig {
pub host: String,
pub port: u64,
pub unix_socket_path: Option<String>,
pub duration: Option<u64>,
pub interval: u16,
pub skip_seconds: Option<u64>,
Expand Down Expand Up @@ -89,29 +93,31 @@ impl fmt::Display for Api {
}
}

async fn setup_subscriber(
endpoint: &Endpoint,
async fn create_subscriber(
channel: Channel,
signals: Vec<Signal>,
api: &Api,
initial_values_sender: Sender<HashMap<String, DataValue>>,
) -> Result<Subscriber> {
let subscriber_channel = endpoint.connect().await.with_context(|| {
let host = endpoint.uri().host().unwrap_or("unknown host");
let port = endpoint
.uri()
.port()
.map_or("unknown port".to_string(), |p| p.to_string());
format!("Failed to connect to server {}:{}", host, port)
})?;

let subscriber =
subscriber::Subscriber::new(subscriber_channel, signals, api, initial_values_sender)
.await?;
subscriber::Subscriber::new(channel, signals, api, initial_values_sender).await?;

Ok(subscriber)
}

fn create_databroker_endpoint(host: String, port: u64) -> Result<Endpoint> {
async fn create_unix_socket_channel(path: impl AsRef<Path>) -> Result<Channel> {
let path_buf = PathBuf::from(path.as_ref());
tonic::transport::Endpoint::try_from("http://[::]:50051")?
.connect_with_connector(service_fn(move |_| {
let path = path_buf.clone();
// Connect to a unix socket
UnixStream::connect(path)
}))
.await
.with_context(|| format!("Failed to connect to server {}", path.as_ref().display()))
}

async fn create_tcp_channel(host: String, port: u64) -> Result<Channel> {
let databroker_address = format!("{}:{}", host, port);

let endpoint = tonic::transport::Channel::from_shared(databroker_address.clone())
Expand All @@ -124,10 +130,6 @@ fn create_databroker_endpoint(host: String, port: u64) -> Result<Endpoint> {
.keep_alive_timeout(Duration::from_secs(1))
.timeout(Duration::from_secs(1));

Ok(endpoint)
}

async fn create_provider(endpoint: &Endpoint, api: &Api) -> Result<Provider> {
let channel = endpoint.connect().await.with_context(|| {
let host = endpoint.uri().host().unwrap_or("unknown host");
let port = endpoint
Expand All @@ -137,6 +139,10 @@ async fn create_provider(endpoint: &Endpoint, api: &Api) -> Result<Provider> {
format!("Failed to connect to server {}:{}", host, port)
})?;

Ok(channel)
}

fn create_provider(channel: Channel, api: &Api) -> Result<Provider> {
if *api == Api::KuksaValV2 {
let provider =
kuksa_val_v2::Provider::new(channel).with_context(|| "Failed to setup provider")?;
Expand All @@ -163,11 +169,19 @@ pub async fn perform_measurement(
config_groups: Vec<Group>,
shutdown_handler: ShutdownHandler,
) -> Result<()> {
let provider_endpoint =
create_databroker_endpoint(measurement_config.host.clone(), measurement_config.port)?;
let provider_channel = match measurement_config.unix_socket_path {
Some(ref path) => create_unix_socket_channel(path).await?,
None => {
create_tcp_channel(measurement_config.host.clone(), measurement_config.port).await?
}
};

let subscriber_endpoint =
create_databroker_endpoint(measurement_config.host.clone(), measurement_config.port)?;
let subscriber_channel = match measurement_config.unix_socket_path {
Some(ref path) => create_unix_socket_channel(path).await?,
None => {
create_tcp_channel(measurement_config.host.clone(), measurement_config.port).await?
}
};

// Create references to be used among tokio::tasks
let shutdown_handler_ref = Arc::new(RwLock::new(shutdown_handler));
Expand All @@ -176,8 +190,9 @@ pub async fn perform_measurement(
let mut tasks: JoinSet<Result<MeasurementResult>> = JoinSet::new();

for group in config_groups.clone() {
let provider_channel = provider_channel.clone();
// Initialize provider
let mut provider = create_provider(&provider_endpoint, &measurement_config.api).await?;
let mut provider = create_provider(provider_channel, &measurement_config.api)?;

// Validate metadata signals
let signals = provider
Expand All @@ -191,8 +206,9 @@ pub async fn perform_measurement(
let (initial_values_sender, mut initial_values_reciever) =
tokio::sync::mpsc::channel::<HashMap<String, DataValue>>(10);

let subscriber = setup_subscriber(
&subscriber_endpoint,
let subscriber_channel = subscriber_channel.clone();
let subscriber = create_subscriber(
subscriber_channel,
signals,
&measurement_config.api,
initial_values_sender,
Expand Down

0 comments on commit 1838a30

Please sign in to comment.