1use std::str::FromStr;
2use std::sync::Arc;
3use std::thread::available_parallelism;
4use std::time::Duration;
5
6use log::{debug, error, info, trace};
7use nethsm_sdk_rs::ureq::{Agent, AgentBuilder};
8use rustls::client::{
9 ClientConfig,
10 danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
11};
12use rustls::crypto::{CryptoProvider, ring as tls_provider};
13use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
14use rustls::{DigitallySignedStruct, SignatureScheme};
15use serde::{Deserialize, Serialize};
16use sha2::{Digest, Sha256};
17
18use crate::Error;
19#[cfg(doc)]
20use crate::NetHsm;
21
22#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
24pub struct CertFingerprint(
25 #[serde(
26 deserialize_with = "hex::serde::deserialize",
27 serialize_with = "hex::serde::serialize"
28 )]
29 Vec<u8>,
30);
31
32impl FromStr for CertFingerprint {
33 type Err = Error;
34 fn from_str(s: &str) -> Result<Self, Self::Err> {
35 Ok(Self(s.as_bytes().to_vec()))
36 }
37}
38
39impl From<Vec<u8>> for CertFingerprint {
40 fn from(value: Vec<u8>) -> Self {
41 Self(value)
42 }
43}
44
45#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
48pub struct HostCertificateFingerprints {
49 sha256: Option<Vec<CertFingerprint>>,
51}
52
53#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
55pub enum ConnectionSecurity {
56 Unsafe,
58 Native,
60 Fingerprints(HostCertificateFingerprints),
62}
63
64impl FromStr for ConnectionSecurity {
65 type Err = Error;
66
67 fn from_str(s: &str) -> Result<Self, Self::Err> {
95 match s {
96 "unsafe" | "Unsafe" => Ok(Self::Unsafe),
97 "native" | "Native" => Ok(Self::Native),
98 _ => {
99 let sha256_fingerprints: Vec<Vec<u8>> = s
100 .split(',')
101 .filter_map(|checksum| {
102 checksum
103 .strip_prefix("sha256:")
104 .filter(|x| x.len() == 64 && x.chars().all(|x| x.is_ascii_hexdigit()))
105 .map(|checksum| checksum.as_bytes().to_vec())
106 })
107 .collect();
108 if sha256_fingerprints.is_empty() {
109 Err(Error::Default(
110 "No valid TLS certificate fingerprints detected.".to_string(),
111 ))
112 } else {
113 Ok(Self::Fingerprints(HostCertificateFingerprints {
114 sha256: Some(
115 sha256_fingerprints
116 .iter()
117 .map(|checksum| checksum.clone().into())
118 .collect(),
119 ),
120 }))
121 }
122 }
123 }
124 }
125}
126
127#[derive(Debug)]
134pub struct DangerIgnoreVerifier(pub CryptoProvider);
135
136impl ServerCertVerifier for DangerIgnoreVerifier {
137 fn verify_server_cert(
138 &self,
139 _end_entity: &CertificateDer<'_>,
140 _intermediates: &[CertificateDer<'_>],
141 _server_name: &ServerName<'_>,
142 _ocsp_response: &[u8],
143 _now: UnixTime,
144 ) -> Result<ServerCertVerified, rustls::Error> {
145 Ok(ServerCertVerified::assertion())
147 }
148
149 fn verify_tls12_signature(
150 &self,
151 message: &[u8],
152 cert: &CertificateDer<'_>,
153 dss: &DigitallySignedStruct,
154 ) -> Result<HandshakeSignatureValid, rustls::Error> {
155 rustls::crypto::verify_tls12_signature(
156 message,
157 cert,
158 dss,
159 &self.0.signature_verification_algorithms,
160 )
161 }
162
163 fn verify_tls13_signature(
164 &self,
165 message: &[u8],
166 cert: &CertificateDer<'_>,
167 dss: &DigitallySignedStruct,
168 ) -> Result<HandshakeSignatureValid, rustls::Error> {
169 rustls::crypto::verify_tls13_signature(
170 message,
171 cert,
172 dss,
173 &self.0.signature_verification_algorithms,
174 )
175 }
176
177 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
178 self.0.signature_verification_algorithms.supported_schemes()
179 }
180}
181
182#[derive(Debug)]
190pub struct FingerprintVerifier {
191 pub fingerprints: HostCertificateFingerprints,
192 pub provider: CryptoProvider,
193}
194
195impl ServerCertVerifier for FingerprintVerifier {
196 fn verify_server_cert(
197 &self,
198 end_entity: &CertificateDer<'_>,
199 _intermediates: &[CertificateDer<'_>],
200 _server_name: &ServerName<'_>,
201 _ocsp_response: &[u8],
202 _now: UnixTime,
203 ) -> Result<ServerCertVerified, rustls::Error> {
204 if let Some(sha256_fingerprints) = self.fingerprints.sha256.as_ref() {
205 let mut hasher = Sha256::new();
206 hasher.update(end_entity.as_ref());
207 let result = hasher.finalize();
208 for fingerprint in sha256_fingerprints.iter() {
209 if fingerprint.0 == result[..] {
210 trace!("Certificate fingerprint matches");
211 return Ok(ServerCertVerified::assertion());
212 }
213 }
214 } else {
215 return Err(rustls::Error::General(
216 "Could not verify certificate fingerprint as no fingerprints were provided to match against".to_string(),
217 ));
218 }
219 Err(rustls::Error::General(
220 "Could not verify certificate fingerprint".to_string(),
221 ))
222 }
223
224 fn verify_tls12_signature(
225 &self,
226 message: &[u8],
227 cert: &CertificateDer<'_>,
228 dss: &DigitallySignedStruct,
229 ) -> Result<HandshakeSignatureValid, rustls::Error> {
230 rustls::crypto::verify_tls12_signature(
231 message,
232 cert,
233 dss,
234 &self.provider.signature_verification_algorithms,
235 )
236 }
237
238 fn verify_tls13_signature(
239 &self,
240 message: &[u8],
241 cert: &CertificateDer<'_>,
242 dss: &DigitallySignedStruct,
243 ) -> Result<HandshakeSignatureValid, rustls::Error> {
244 rustls::crypto::verify_tls13_signature(
245 message,
246 cert,
247 dss,
248 &self.provider.signature_verification_algorithms,
249 )
250 }
251
252 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
253 self.provider
254 .signature_verification_algorithms
255 .supported_schemes()
256 }
257}
258
259pub(crate) fn create_agent(
276 tls_security: ConnectionSecurity,
277 max_idle_connections: Option<usize>,
278 timeout_seconds: Option<u64>,
279) -> Result<Agent, Error> {
280 let tls_conf = {
281 let tls_conf = ClientConfig::builder_with_provider(Arc::new(CryptoProvider {
282 cipher_suites: tls_provider::ALL_CIPHER_SUITES.into(),
283 ..tls_provider::default_provider()
284 }))
285 .with_protocol_versions(rustls::DEFAULT_VERSIONS)?;
286
287 match tls_security {
288 ConnectionSecurity::Unsafe => {
289 let dangerous = tls_conf.dangerous();
290 dangerous
291 .with_custom_certificate_verifier(Arc::new(DangerIgnoreVerifier(
292 tls_provider::default_provider(),
293 )))
294 .with_no_client_auth()
295 }
296 ConnectionSecurity::Native => {
297 let native_certs = rustls_native_certs::load_native_certs();
298 if !native_certs.errors.is_empty() {
299 return Err(Error::CertLoading(native_certs.errors));
300 }
301 let native_certs = native_certs.certs;
302
303 let roots = {
304 let mut roots = rustls::RootCertStore::empty();
305 let (added, failed) = roots.add_parsable_certificates(native_certs);
306 debug!("Added {added} certificates and failed to parse {failed} certificates");
307 if added == 0 {
308 error!("Added no native certificates");
309 return Err(Error::NoSystemCertsAdded { failed });
310 }
311 roots
312 };
313
314 tls_conf.with_root_certificates(roots).with_no_client_auth()
315 }
316 ConnectionSecurity::Fingerprints(fingerprints) => {
317 let dangerous = tls_conf.dangerous();
318 dangerous
319 .with_custom_certificate_verifier(Arc::new(FingerprintVerifier {
320 fingerprints,
321 provider: tls_provider::default_provider(),
322 }))
323 .with_no_client_auth()
324 }
325 }
326 };
327
328 let max_idle_connections = max_idle_connections
329 .or_else(|| available_parallelism().ok().map(Into::into))
330 .unwrap_or(100);
331 let timeout_seconds = timeout_seconds.unwrap_or(10);
332 info!(
333 "NetHSM connection configured with \"max_idle_connection\" {} and \"timeout_seconds\" {}.",
334 max_idle_connections, timeout_seconds
335 );
336
337 Ok(AgentBuilder::new()
338 .tls_config(Arc::new(tls_conf))
339 .max_idle_connections(max_idle_connections)
340 .max_idle_connections_per_host(max_idle_connections)
341 .timeout_connect(Duration::from_secs(timeout_seconds))
342 .build())
343}
344
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use testresult::TestResult;

    use super::*;

    /// Checks that [`ConnectionSecurity::from_str`] accepts the supported forms and rejects
    /// malformed input.
    #[rstest]
    #[case("native", Some(ConnectionSecurity::Native))]
    #[case("Native", Some(ConnectionSecurity::Native))]
    #[case("unsafe", Some(ConnectionSecurity::Unsafe))]
    #[case("Unsafe", Some(ConnectionSecurity::Unsafe))]
    #[case(
        "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7",
        Some(ConnectionSecurity::Fingerprints(HostCertificateFingerprints {
            sha256: Some(vec![CertFingerprint::from_str(
                "324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7"
            )?]),
        }))
    )]
    // Several fingerprints are accepted and keep their order.
    #[case(
        "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7,sha256:9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08",
        Some(ConnectionSecurity::Fingerprints(HostCertificateFingerprints {
            sha256: Some(vec![
                CertFingerprint::from_str(
                    "324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7"
                )?,
                CertFingerprint::from_str(
                    "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"
                )?,
            ]),
        }))
    )]
    // Entries that are not `sha256:` fingerprints are skipped.
    #[case(
        "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7,not-a-fingerprint",
        Some(ConnectionSecurity::Fingerprints(HostCertificateFingerprints {
            sha256: Some(vec![CertFingerprint::from_str(
                "324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7"
            )?]),
        }))
    )]
    // Missing `sha256:` prefix.
    #[case(
        "324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7",
        None
    )]
    // Too short.
    #[case(
        "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e",
        None
    )]
    // Too long.
    #[case(
        "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b73553b7",
        None
    )]
    // Correct length, but not hexadecimal.
    #[case(
        "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553bg",
        None
    )]
    // Empty input.
    #[case("", None)]
    fn connection_security_fromstr(
        #[case] input: &str,
        #[case] expected: Option<ConnectionSecurity>,
    ) -> TestResult {
        if let Some(expected) = expected {
            assert_eq!(ConnectionSecurity::from_str(input)?, expected);
        } else {
            assert!(ConnectionSecurity::from_str(input).is_err());
        }
        Ok(())
    }
}