Skip to content

Commit 6c651d8

Browse files
authored
tls: bug fix and enhancement (#91)
1 parent 270aa3a commit 6c651d8

File tree

8 files changed

+73
-41
lines changed

8 files changed

+73
-41
lines changed

Cargo.lock

+5-5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/hstreamdb/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
edition = "2021"
33
name = "hstreamdb"
4-
version = "0.2.0"
4+
version = "0.2.1"
55

66
license = "BSD-3-Clause"
77
description = "Rust client library for HStreamDB"

src/hstreamdb/src/channel_provider.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ use std::iter::FromIterator;
55
use hstreamdb_pb::h_stream_api_client::HStreamApiClient;
66
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
77
use tokio::sync::oneshot;
8-
use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
8+
use tonic::transport::{Channel, Endpoint};
99

1010
use crate::client::get_available_node_addrs;
1111
use crate::common;
12+
use crate::tls::ClientTlsConfig;
1213

1314
#[derive(Debug)]
1415
pub(crate) struct Request(

src/hstreamdb/src/client.rs

+29-33
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@ use hstreamdb_pb::{
55
GetSubscriptionRequest, ListConsumersRequest, ListStreamsRequest, ListSubscriptionsRequest,
66
LookupSubscriptionRequest, NodeState,
77
};
8-
use tonic::transport::{Channel, ClientTlsConfig};
8+
use tonic::transport::{Channel, Endpoint};
99
use tonic::Request;
1010
use url::Url;
1111

1212
use crate::appender::Appender;
1313
use crate::channel_provider::{new_channel_provider, ChannelProviderSettings, Channels};
1414
use crate::common::Error::PBUnwrapError;
1515
use crate::producer::{FlushCallback, FlushSettings, Producer};
16+
use crate::tls::ClientTlsConfig;
1617
use crate::{common, flow_controller, format_url, producer};
1718

1819
pub struct Client {
@@ -29,41 +30,29 @@ impl Client {
2930
where
3031
Destination: std::convert::Into<String>,
3132
{
32-
const HSTREAM_PREFIX: &str = "hstream";
33-
let server_url = server_url.into();
34-
let (url_scheme, url) = {
35-
let url = {
36-
let mut url = Url::parse(&server_url)?;
37-
if url.port().is_none() {
38-
url.set_port(Some(6570))
39-
.map_err(|()| common::Error::SetPortError(server_url.to_string()))?;
40-
}
41-
url
42-
};
43-
44-
if url.scheme() == HSTREAM_PREFIX {
45-
let url_scheme = if channel_provider_settings.client_tls_config.is_none() {
46-
"http"
47-
} else {
48-
"https"
49-
};
50-
let server_url = &server_url[7..];
51-
(
52-
url_scheme.to_string(),
53-
Url::parse(format!("{url_scheme}{server_url}").as_str())?,
54-
)
55-
} else if url.scheme() == "hstreams" {
56-
let url_scheme = "https";
57-
(
58-
url_scheme.to_string(),
59-
Url::parse(format!("{url_scheme}{server_url}").as_str())?,
60-
)
61-
} else {
62-
(url.scheme().to_string(), url)
33+
let server_url: String = server_url.into();
34+
Url::parse(&server_url)?;
35+
let server_url = set_scheme(&server_url).ok_or(common::Error::InvalidUrl(server_url))?;
36+
let (url_scheme, server_url) = {
37+
let mut server_url = Url::parse(&server_url)?;
38+
let port = server_url.port();
39+
if port.is_none() {
40+
server_url
41+
.set_port(Some(6570))
42+
.map_err(|()| common::Error::InvalidUrl(server_url.to_string()))?;
6343
}
44+
(server_url.scheme().to_string(), server_url)
6445
};
65-
let mut hstream_api_client = HStreamApiClient::connect(String::from(url)).await?;
46+
47+
log::debug!("client init connect: scheme = {url_scheme}, url = {server_url}");
6648
let tls_config = channel_provider_settings.client_tls_config.clone();
49+
let mut hstream_api_client = HStreamApiClient::new({
50+
let mut endpoint = Endpoint::new(server_url.to_string())?;
51+
if let Some(tls_config) = tls_config.clone() {
52+
endpoint = endpoint.tls_config(tls_config)?;
53+
}
54+
endpoint.connect().await?
55+
});
6756
let channels = new_channel_provider(
6857
&url_scheme,
6958
&mut hstream_api_client,
@@ -78,6 +67,13 @@ impl Client {
7867
}
7968
}
8069

70+
fn set_scheme(url: &str) -> Option<String> {
71+
Some(
72+
url.replace("hstream://", "http://")
73+
.replace("hstreams://", "https://"),
74+
)
75+
}
76+
8177
impl Client {
8278
async fn new_channel_provider(
8379
&self,

src/hstreamdb/src/common.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ pub enum Error {
9595
#[error(transparent)]
9696
AppenderSendError(producer::SendError),
9797
#[error("the URL {0} is cannot-be-a-base, or does not have a host, or has the file scheme")]
98-
SetPortError(String),
98+
InvalidUrl(String),
9999
}
100100

101101
#[derive(Debug, thiserror::Error)]

src/hstreamdb/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pub mod consumer;
88
mod flow_controller;
99
pub mod producer;
1010
pub mod reader;
11+
pub mod tls;
1112
pub mod utils;
1213

1314
pub use channel_provider::ChannelProviderSettings;

src/hstreamdb/src/tls.rs

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub use tonic::transport::channel::ClientTlsConfig;
2+
pub use tonic::transport::{Certificate, Identity};

src/hstreamdb/tests/tls_test.rs

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// use std::{env, fs};
2+
3+
// use hstreamdb::tls::{Certificate, ClientTlsConfig, Identity};
4+
// use hstreamdb::{ChannelProviderSettings, Client};
5+
6+
// async fn _test_tls_impl() {
7+
// env::set_var("RUST_LOG", "DEBUG");
8+
// env_logger::init();
9+
10+
// let server_url: &str = todo!();
11+
// let tls_dir: &str = todo!();
12+
13+
// let ca_certificate =
14+
// Certificate::from_pem(fs::read(format!("{tls_dir}/root_ca.crt")).unwrap());
15+
// let cert = fs::read(format!("{tls_dir}/client.crt")).unwrap();
16+
// let key = fs::read(format!("{tls_dir}/client.key")).unwrap();
17+
18+
// let client = Client::new(
19+
// server_url,
20+
// ChannelProviderSettings::builder()
21+
// .set_tls_config(
22+
// ClientTlsConfig::new()
23+
// .ca_certificate(ca_certificate)
24+
// .identity(Identity::from_pem(cert, key)),
25+
// )
26+
// .build(),
27+
// )
28+
// .await
29+
// .unwrap();
30+
31+
// log::info!("{:?}", client.list_streams().await.unwrap());
32+
// }

0 commit comments

Comments
 (0)