1use std::sync::Arc;
2use std::thread::available_parallelism;
3use std::time::Duration;
4use std::{fmt::Display, str::FromStr};
5
6use log::{debug, error, info, trace};
7use nethsm_sdk_rs::ureq::{Agent, AgentBuilder};
8use rustls::client::{
9 ClientConfig,
10 danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
11};
12use rustls::crypto::{CryptoProvider, ring as tls_provider};
13use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
14use rustls::{DigitallySignedStruct, SignatureScheme};
15use serde::{Deserialize, Serialize};
16use sha2::{Digest, Sha256};
17
18use crate::Error;
19#[cfg(doc)]
20use crate::NetHsm;
21
/// Fallback for the maximum number of idle connections kept by the HTTP agent.
///
/// Used by `create_agent` only when no explicit value is passed *and*
/// [`available_parallelism`] cannot be determined.
pub const DEFAULT_MAX_IDLE_CONNECTIONS: usize = 100;

/// Connect timeout (in seconds) used by `create_agent` when none is passed.
pub const DEFAULT_TIMEOUT_SECONDS: u64 = 10;
27
28#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
30pub struct CertFingerprint(
31 #[serde(
32 deserialize_with = "hex::serde::deserialize",
33 serialize_with = "hex::serde::serialize"
34 )]
35 Vec<u8>,
36);
37
38impl Display for CertFingerprint {
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 for byte in self.0.iter() {
41 write!(f, "{byte:02x?}")?
42 }
43 Ok(())
44 }
45}
46
47impl FromStr for CertFingerprint {
48 type Err = Error;
49 fn from_str(s: &str) -> Result<Self, Self::Err> {
50 Ok(Self(s.as_bytes().to_vec()))
51 }
52}
53
54impl From<Vec<u8>> for CertFingerprint {
55 fn from(value: Vec<u8>) -> Self {
56 Self(value)
57 }
58}
59
60#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
63pub struct HostCertificateFingerprints {
64 sha256: Option<Vec<CertFingerprint>>,
66}
67
68impl Display for HostCertificateFingerprints {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 write!(
71 f,
72 "{}",
73 if let Some(fingerprints) = self.sha256.as_ref() {
74 if fingerprints.is_empty() {
75 "n/a".to_string()
76 } else {
77 fingerprints
78 .iter()
79 .map(|fingerprint| format!("sha256:{fingerprint}"))
80 .collect::<Vec<String>>()
81 .join("\n")
82 }
83 } else {
84 "n/a".to_string()
85 }
86 )
87 }
88}
89
90#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
92pub enum ConnectionSecurity {
93 Unsafe,
95 Native,
97 Fingerprints(HostCertificateFingerprints),
99}
100
101impl Display for ConnectionSecurity {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 match self {
104 Self::Unsafe => write!(f, "unsafe"),
105 Self::Native => write!(f, "native"),
106 Self::Fingerprints(fingerprints) => write!(f, "{fingerprints}"),
107 }
108 }
109}
110
111impl FromStr for ConnectionSecurity {
112 type Err = Error;
113
114 fn from_str(s: &str) -> Result<Self, Self::Err> {
142 match s {
143 "unsafe" | "Unsafe" => Ok(Self::Unsafe),
144 "native" | "Native" => Ok(Self::Native),
145 _ => {
146 let sha256_fingerprints: Vec<Vec<u8>> = s
147 .split(',')
148 .filter_map(|checksum| {
149 checksum
150 .strip_prefix("sha256:")
151 .filter(|x| x.len() == 64 && x.chars().all(|x| x.is_ascii_hexdigit()))
152 .map(|checksum| checksum.as_bytes().to_vec())
153 })
154 .collect();
155 if sha256_fingerprints.is_empty() {
156 Err(Error::Default(
157 "No valid TLS certificate fingerprints detected.".to_string(),
158 ))
159 } else {
160 Ok(Self::Fingerprints(HostCertificateFingerprints {
161 sha256: Some(
162 sha256_fingerprints
163 .iter()
164 .map(|checksum| checksum.clone().into())
165 .collect(),
166 ),
167 }))
168 }
169 }
170 }
171 }
172}
173
174#[derive(Debug)]
181pub struct DangerIgnoreVerifier(pub CryptoProvider);
182
183impl ServerCertVerifier for DangerIgnoreVerifier {
184 fn verify_server_cert(
185 &self,
186 _end_entity: &CertificateDer<'_>,
187 _intermediates: &[CertificateDer<'_>],
188 _server_name: &ServerName<'_>,
189 _ocsp_response: &[u8],
190 _now: UnixTime,
191 ) -> Result<ServerCertVerified, rustls::Error> {
192 Ok(ServerCertVerified::assertion())
194 }
195
196 fn verify_tls12_signature(
197 &self,
198 message: &[u8],
199 cert: &CertificateDer<'_>,
200 dss: &DigitallySignedStruct,
201 ) -> Result<HandshakeSignatureValid, rustls::Error> {
202 rustls::crypto::verify_tls12_signature(
203 message,
204 cert,
205 dss,
206 &self.0.signature_verification_algorithms,
207 )
208 }
209
210 fn verify_tls13_signature(
211 &self,
212 message: &[u8],
213 cert: &CertificateDer<'_>,
214 dss: &DigitallySignedStruct,
215 ) -> Result<HandshakeSignatureValid, rustls::Error> {
216 rustls::crypto::verify_tls13_signature(
217 message,
218 cert,
219 dss,
220 &self.0.signature_verification_algorithms,
221 )
222 }
223
224 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
225 self.0.signature_verification_algorithms.supported_schemes()
226 }
227}
228
229#[derive(Debug)]
237pub struct FingerprintVerifier {
238 pub fingerprints: HostCertificateFingerprints,
239 pub provider: CryptoProvider,
240}
241
242impl ServerCertVerifier for FingerprintVerifier {
243 fn verify_server_cert(
244 &self,
245 end_entity: &CertificateDer<'_>,
246 _intermediates: &[CertificateDer<'_>],
247 _server_name: &ServerName<'_>,
248 _ocsp_response: &[u8],
249 _now: UnixTime,
250 ) -> Result<ServerCertVerified, rustls::Error> {
251 if let Some(sha256_fingerprints) = self.fingerprints.sha256.as_ref() {
252 let mut hasher = Sha256::new();
253 hasher.update(end_entity.as_ref());
254 let result = hasher.finalize();
255 for fingerprint in sha256_fingerprints.iter() {
256 if fingerprint.0 == result[..] {
257 trace!("Certificate fingerprint matches");
258 return Ok(ServerCertVerified::assertion());
259 }
260 }
261 } else {
262 return Err(rustls::Error::General(
263 "Could not verify certificate fingerprint as no fingerprints were provided to match against".to_string(),
264 ));
265 }
266 Err(rustls::Error::General(
267 "Could not verify certificate fingerprint".to_string(),
268 ))
269 }
270
271 fn verify_tls12_signature(
272 &self,
273 message: &[u8],
274 cert: &CertificateDer<'_>,
275 dss: &DigitallySignedStruct,
276 ) -> Result<HandshakeSignatureValid, rustls::Error> {
277 rustls::crypto::verify_tls12_signature(
278 message,
279 cert,
280 dss,
281 &self.provider.signature_verification_algorithms,
282 )
283 }
284
285 fn verify_tls13_signature(
286 &self,
287 message: &[u8],
288 cert: &CertificateDer<'_>,
289 dss: &DigitallySignedStruct,
290 ) -> Result<HandshakeSignatureValid, rustls::Error> {
291 rustls::crypto::verify_tls13_signature(
292 message,
293 cert,
294 dss,
295 &self.provider.signature_verification_algorithms,
296 )
297 }
298
299 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
300 self.provider
301 .signature_verification_algorithms
302 .supported_schemes()
303 }
304}
305
306pub(crate) fn create_agent(
323 tls_security: ConnectionSecurity,
324 max_idle_connections: Option<usize>,
325 timeout_seconds: Option<u64>,
326) -> Result<Agent, Error> {
327 let tls_conf = {
328 let tls_conf = ClientConfig::builder_with_provider(Arc::new(CryptoProvider {
329 cipher_suites: tls_provider::ALL_CIPHER_SUITES.into(),
330 ..tls_provider::default_provider()
331 }))
332 .with_protocol_versions(rustls::DEFAULT_VERSIONS)?;
333
334 match tls_security {
335 ConnectionSecurity::Unsafe => {
336 let dangerous = tls_conf.dangerous();
337 dangerous
338 .with_custom_certificate_verifier(Arc::new(DangerIgnoreVerifier(
339 tls_provider::default_provider(),
340 )))
341 .with_no_client_auth()
342 }
343 ConnectionSecurity::Native => {
344 let native_certs = rustls_native_certs::load_native_certs();
345 if !native_certs.errors.is_empty() {
346 return Err(Error::CertLoading(native_certs.errors));
347 }
348 let native_certs = native_certs.certs;
349
350 let roots = {
351 let mut roots = rustls::RootCertStore::empty();
352 let (added, failed) = roots.add_parsable_certificates(native_certs);
353 debug!("Added {added} certificates and failed to parse {failed} certificates");
354 if added == 0 {
355 error!("Added no native certificates");
356 return Err(Error::NoSystemCertsAdded { failed });
357 }
358 roots
359 };
360
361 tls_conf.with_root_certificates(roots).with_no_client_auth()
362 }
363 ConnectionSecurity::Fingerprints(fingerprints) => {
364 let dangerous = tls_conf.dangerous();
365 dangerous
366 .with_custom_certificate_verifier(Arc::new(FingerprintVerifier {
367 fingerprints,
368 provider: tls_provider::default_provider(),
369 }))
370 .with_no_client_auth()
371 }
372 }
373 };
374
375 let max_idle_connections = max_idle_connections
376 .or_else(|| available_parallelism().ok().map(Into::into))
377 .unwrap_or(DEFAULT_MAX_IDLE_CONNECTIONS);
378 let timeout_seconds = timeout_seconds.unwrap_or(DEFAULT_TIMEOUT_SECONDS);
379 info!(
380 "NetHSM connection configured with \"max_idle_connection\" {max_idle_connections} and \"timeout_seconds\" {timeout_seconds}."
381 );
382
383 Ok(AgentBuilder::new()
384 .tls_config(Arc::new(tls_conf))
385 .max_idle_connections(max_idle_connections)
386 .max_idle_connections_per_host(max_idle_connections)
387 .timeout_connect(Duration::from_secs(timeout_seconds))
388 .build())
389}
390
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use testresult::TestResult;

    use super::*;

    /// Raw SHA-256 digest shared by several cases below.
    /// Its lower-case hex form is `b5bb9d80...e4944c`.
    const DIGEST: [u8; 32] = [
        181, 187, 157, 128, 20, 160, 249, 177, 214, 30, 33, 231, 150, 215, 141, 204, 223, 19,
        82, 242, 60, 211, 40, 18, 244, 133, 11, 135, 138, 228, 148, 76,
    ];

    #[test]
    fn certfingerprint_display() -> TestResult {
        let rendered = CertFingerprint::from(DIGEST.to_vec()).to_string();
        assert_eq!(
            rendered,
            "b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c"
        );
        Ok(())
    }

    #[rstest]
    #[case(
        HostCertificateFingerprints { sha256: Some(vec![CertFingerprint::from(DIGEST.to_vec())]) },
        "sha256:b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c"
    )]
    #[case(HostCertificateFingerprints { sha256: Some(Vec::new()) }, "n/a")]
    #[case(HostCertificateFingerprints { sha256: None }, "n/a")]
    fn hostcertfingerprints_display(
        #[case] fingerprints: HostCertificateFingerprints,
        #[case] expected: &str,
    ) -> TestResult {
        let rendered = fingerprints.to_string();
        assert_eq!(rendered, expected);
        Ok(())
    }

    #[rstest]
    #[case(ConnectionSecurity::Native, "native")]
    #[case(ConnectionSecurity::Unsafe, "unsafe")]
    #[case(
        ConnectionSecurity::Fingerprints(HostCertificateFingerprints {
            sha256: Some(vec![CertFingerprint::from(DIGEST.to_vec())]),
        }),
        "sha256:b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c"
    )]
    fn connectionsecurity_display(
        #[case] connection_security: ConnectionSecurity,
        #[case] expected: &str,
    ) -> TestResult {
        let rendered = connection_security.to_string();
        assert_eq!(rendered, expected);
        Ok(())
    }

    #[rstest]
    #[case("native", Some(ConnectionSecurity::Native))]
    #[case("unsafe", Some(ConnectionSecurity::Unsafe))]
    #[case(
        "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7",
        Some(ConnectionSecurity::Fingerprints(HostCertificateFingerprints {
            sha256: Some(vec![CertFingerprint::from_str(
                "324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7"
            )?]),
        }))
    )]
    // Missing `sha256:` prefix.
    #[case(
        "324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7",
        None
    )]
    // Digest too short.
    #[case(
        "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e",
        None
    )]
    // Digest too long.
    #[case(
        "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b73553b7",
        None
    )]
    fn connection_security_fromstr(
        #[case] input: &str,
        #[case] expected: Option<ConnectionSecurity>,
    ) -> TestResult {
        match expected {
            Some(expected) => assert_eq!(ConnectionSecurity::from_str(input)?, expected),
            None => assert!(ConnectionSecurity::from_str(input).is_err()),
        }
        Ok(())
    }
}