nethsm/
tls.rs

1use std::sync::Arc;
2use std::thread::available_parallelism;
3use std::time::Duration;
4use std::{fmt::Display, str::FromStr};
5
6use log::{debug, error, info, trace};
7use nethsm_sdk_rs::ureq::{Agent, AgentBuilder};
8use rustls::client::{
9    ClientConfig,
10    danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
11};
12use rustls::crypto::{CryptoProvider, ring as tls_provider};
13use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
14use rustls::{DigitallySignedStruct, SignatureScheme};
15use serde::{Deserialize, Serialize};
16use sha2::{Digest, Sha256};
17
18use crate::Error;
19#[cfg(doc)]
20use crate::NetHsm;
21
/// The default maximum number of idle TLS connections for a [`NetHsm`].
///
/// This value is only used as a fallback: if no explicit maximum is requested, the available
/// parallelism of the system is used instead, and only if that can not be determined does this
/// default apply.
pub const DEFAULT_MAX_IDLE_CONNECTIONS: usize = 100;

/// The default timeout in seconds for establishing the socket connection to a [`NetHsm`].
pub const DEFAULT_TIMEOUT_SECONDS: u64 = 10;
27
28/// The fingerprint of a TLS certificate (as hex)
29#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
30pub struct CertFingerprint(
31    #[serde(
32        deserialize_with = "hex::serde::deserialize",
33        serialize_with = "hex::serde::serialize"
34    )]
35    Vec<u8>,
36);
37
38impl Display for CertFingerprint {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        for byte in self.0.iter() {
41            write!(f, "{byte:02x?}")?
42        }
43        Ok(())
44    }
45}
46
47impl FromStr for CertFingerprint {
48    type Err = Error;
49    fn from_str(s: &str) -> Result<Self, Self::Err> {
50        Ok(Self(s.as_bytes().to_vec()))
51    }
52}
53
54impl From<Vec<u8>> for CertFingerprint {
55    fn from(value: Vec<u8>) -> Self {
56        Self(value)
57    }
58}
59
60/// Certificate fingerprints to use for matching against a host's TLS
61/// certificate
62#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
63pub struct HostCertificateFingerprints {
64    /// An optional list of SHA-256 checksums
65    sha256: Option<Vec<CertFingerprint>>,
66}
67
68impl Display for HostCertificateFingerprints {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        write!(
71            f,
72            "{}",
73            if let Some(fingerprints) = self.sha256.as_ref() {
74                if fingerprints.is_empty() {
75                    "n/a".to_string()
76                } else {
77                    fingerprints
78                        .iter()
79                        .map(|fingerprint| format!("sha256:{fingerprint}"))
80                        .collect::<Vec<String>>()
81                        .join("\n")
82                }
83            } else {
84                "n/a".to_string()
85            }
86        )
87    }
88}
89
90/// The security model chosen for a [`crate::NetHsm`]'s TLS connection
91#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
92pub enum ConnectionSecurity {
93    /// Always trust the TLS certificate associated with a host
94    Unsafe,
95    /// Use the native trust store to evaluate the trust of a host
96    Native,
97    /// Use a list of checksums (fingerprints) to verify a host's TLS certificate
98    Fingerprints(HostCertificateFingerprints),
99}
100
101impl Display for ConnectionSecurity {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        match self {
104            Self::Unsafe => write!(f, "unsafe"),
105            Self::Native => write!(f, "native"),
106            Self::Fingerprints(fingerprints) => write!(f, "{fingerprints}"),
107        }
108    }
109}
110
111impl FromStr for ConnectionSecurity {
112    type Err = Error;
113
114    /// Create a ConnectionSecurity from string
115    ///
116    /// Valid inputs are either "Unsafe" (or "unsafe"), "Native" (or "native") or "sha256:checksum"
117    /// where "checksum" denotes 64 ASCII hexadecimal chars.
118    ///
119    /// # Errors
120    ///
121    /// Returns an [`Error`] if the input is neither "Unsafe" nor "Native" and also no valid
122    /// certificate fingerprint can be derived from the input.
123    ///
124    /// # Examples
125    ///
126    /// ```
127    /// use std::str::FromStr;
128    ///
129    /// use nethsm::ConnectionSecurity;
130    ///
131    /// assert!(ConnectionSecurity::from_str("unsafe").is_ok());
132    /// assert!(ConnectionSecurity::from_str("native").is_ok());
133    /// assert!(
134    ///     ConnectionSecurity::from_str(
135    ///         "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7"
136    ///     )
137    ///     .is_ok()
138    /// );
139    /// assert!(ConnectionSecurity::from_str("something").is_err());
140    /// ```
141    fn from_str(s: &str) -> Result<Self, Self::Err> {
142        match s {
143            "unsafe" | "Unsafe" => Ok(Self::Unsafe),
144            "native" | "Native" => Ok(Self::Native),
145            _ => {
146                let sha256_fingerprints: Vec<Vec<u8>> = s
147                    .split(',')
148                    .filter_map(|checksum| {
149                        checksum
150                            .strip_prefix("sha256:")
151                            .filter(|x| x.len() == 64 && x.chars().all(|x| x.is_ascii_hexdigit()))
152                            .map(|checksum| checksum.as_bytes().to_vec())
153                    })
154                    .collect();
155                if sha256_fingerprints.is_empty() {
156                    Err(Error::Default(
157                        "No valid TLS certificate fingerprints detected.".to_string(),
158                    ))
159                } else {
160                    Ok(Self::Fingerprints(HostCertificateFingerprints {
161                        sha256: Some(
162                            sha256_fingerprints
163                                .iter()
164                                .map(|checksum| checksum.clone().into())
165                                .collect(),
166                        ),
167                    }))
168                }
169            }
170        }
171    }
172}
173
174/// A verifier for server certificates that always accepts them
175///
176/// This verifier is used when choosing [`ConnectionSecurity::Unsafe`]. It is **unsafe** and should
177/// not be used unless for initial setup scenarios of a NetHSM! Instead use [`FingerprintVerifier`]
178/// (selected by [`ConnectionSecurity::Fingerprints`]) or better yet rely on
179/// [`ConnectionSecurity::Native`].
180#[derive(Debug)]
181pub struct DangerIgnoreVerifier(pub CryptoProvider);
182
183impl ServerCertVerifier for DangerIgnoreVerifier {
184    fn verify_server_cert(
185        &self,
186        _end_entity: &CertificateDer<'_>,
187        _intermediates: &[CertificateDer<'_>],
188        _server_name: &ServerName<'_>,
189        _ocsp_response: &[u8],
190        _now: UnixTime,
191    ) -> Result<ServerCertVerified, rustls::Error> {
192        // always accept the certificate
193        Ok(ServerCertVerified::assertion())
194    }
195
196    fn verify_tls12_signature(
197        &self,
198        message: &[u8],
199        cert: &CertificateDer<'_>,
200        dss: &DigitallySignedStruct,
201    ) -> Result<HandshakeSignatureValid, rustls::Error> {
202        rustls::crypto::verify_tls12_signature(
203            message,
204            cert,
205            dss,
206            &self.0.signature_verification_algorithms,
207        )
208    }
209
210    fn verify_tls13_signature(
211        &self,
212        message: &[u8],
213        cert: &CertificateDer<'_>,
214        dss: &DigitallySignedStruct,
215    ) -> Result<HandshakeSignatureValid, rustls::Error> {
216        rustls::crypto::verify_tls13_signature(
217            message,
218            cert,
219            dss,
220            &self.0.signature_verification_algorithms,
221        )
222    }
223
224    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
225        self.0.signature_verification_algorithms.supported_schemes()
226    }
227}
228
229/// A verifier for server certificates that verifies them based on fingerprints
230///
231/// This verifier is selected when using [`ConnectionSecurity::Fingerprints`] and relies on
232/// [`HostCertificateFingerprints`] to be able to match a host certificate fingerprint against a
233/// predefined list of fingerprints. It should be preferred over the use of [`DangerIgnoreVerifier`]
234/// (selected by [`ConnectionSecurity::Unsafe`]), but ideally a setup should make use of
235/// [`ConnectionSecurity::Native`] instead!
236#[derive(Debug)]
237pub struct FingerprintVerifier {
238    pub fingerprints: HostCertificateFingerprints,
239    pub provider: CryptoProvider,
240}
241
242impl ServerCertVerifier for FingerprintVerifier {
243    fn verify_server_cert(
244        &self,
245        end_entity: &CertificateDer<'_>,
246        _intermediates: &[CertificateDer<'_>],
247        _server_name: &ServerName<'_>,
248        _ocsp_response: &[u8],
249        _now: UnixTime,
250    ) -> Result<ServerCertVerified, rustls::Error> {
251        if let Some(sha256_fingerprints) = self.fingerprints.sha256.as_ref() {
252            let mut hasher = Sha256::new();
253            hasher.update(end_entity.as_ref());
254            let result = hasher.finalize();
255            for fingerprint in sha256_fingerprints.iter() {
256                if fingerprint.0 == result[..] {
257                    trace!("Certificate fingerprint matches");
258                    return Ok(ServerCertVerified::assertion());
259                }
260            }
261        } else {
262            return Err(rustls::Error::General(
263                "Could not verify certificate fingerprint as no fingerprints were provided to match against".to_string(),
264            ));
265        }
266        Err(rustls::Error::General(
267            "Could not verify certificate fingerprint".to_string(),
268        ))
269    }
270
271    fn verify_tls12_signature(
272        &self,
273        message: &[u8],
274        cert: &CertificateDer<'_>,
275        dss: &DigitallySignedStruct,
276    ) -> Result<HandshakeSignatureValid, rustls::Error> {
277        rustls::crypto::verify_tls12_signature(
278            message,
279            cert,
280            dss,
281            &self.provider.signature_verification_algorithms,
282        )
283    }
284
285    fn verify_tls13_signature(
286        &self,
287        message: &[u8],
288        cert: &CertificateDer<'_>,
289        dss: &DigitallySignedStruct,
290    ) -> Result<HandshakeSignatureValid, rustls::Error> {
291        rustls::crypto::verify_tls13_signature(
292            message,
293            cert,
294            dss,
295            &self.provider.signature_verification_algorithms,
296        )
297    }
298
299    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
300        self.provider
301            .signature_verification_algorithms
302            .supported_schemes()
303    }
304}
305
306/// Creates an [`Agent`] for the use in a [`NetHsm`] connection.
307///
308/// Takes a [`ConnectionSecurity`] to define the TLS security model for the connection.
309/// Allows setting the maximum idle connections per host using the optional
310/// `max_idle_connections` (defaults to [`available_parallelism`] and falls back to
311/// [`DEFAULT_MAX_IDLE_CONNECTIONS`] if unavailable).
312/// Also allows setting the timeout in seconds for a successful socket connection
313/// using the optional `timeout_seconds` (defaults to [`DEFAULT_TIMEOUT_SECONDS`]).
314///
315/// # Errors
316///
317/// Returns an error if
318///
319/// - the TLS client configuration can not be created,
320/// - [`ConnectionSecurity::Native`] is provided as `tls_security`, but no certification authority
321///   certificates are available on the system.
322pub(crate) fn create_agent(
323    tls_security: ConnectionSecurity,
324    max_idle_connections: Option<usize>,
325    timeout_seconds: Option<u64>,
326) -> Result<Agent, Error> {
327    let tls_conf = {
328        let tls_conf = ClientConfig::builder_with_provider(Arc::new(CryptoProvider {
329            cipher_suites: tls_provider::ALL_CIPHER_SUITES.into(),
330            ..tls_provider::default_provider()
331        }))
332        .with_protocol_versions(rustls::DEFAULT_VERSIONS)?;
333
334        match tls_security {
335            ConnectionSecurity::Unsafe => {
336                let dangerous = tls_conf.dangerous();
337                dangerous
338                    .with_custom_certificate_verifier(Arc::new(DangerIgnoreVerifier(
339                        tls_provider::default_provider(),
340                    )))
341                    .with_no_client_auth()
342            }
343            ConnectionSecurity::Native => {
344                let native_certs = rustls_native_certs::load_native_certs();
345                if !native_certs.errors.is_empty() {
346                    return Err(Error::CertLoading(native_certs.errors));
347                }
348                let native_certs = native_certs.certs;
349
350                let roots = {
351                    let mut roots = rustls::RootCertStore::empty();
352                    let (added, failed) = roots.add_parsable_certificates(native_certs);
353                    debug!("Added {added} certificates and failed to parse {failed} certificates");
354                    if added == 0 {
355                        error!("Added no native certificates");
356                        return Err(Error::NoSystemCertsAdded { failed });
357                    }
358                    roots
359                };
360
361                tls_conf.with_root_certificates(roots).with_no_client_auth()
362            }
363            ConnectionSecurity::Fingerprints(fingerprints) => {
364                let dangerous = tls_conf.dangerous();
365                dangerous
366                    .with_custom_certificate_verifier(Arc::new(FingerprintVerifier {
367                        fingerprints,
368                        provider: tls_provider::default_provider(),
369                    }))
370                    .with_no_client_auth()
371            }
372        }
373    };
374
375    let max_idle_connections = max_idle_connections
376        .or_else(|| available_parallelism().ok().map(Into::into))
377        .unwrap_or(DEFAULT_MAX_IDLE_CONNECTIONS);
378    let timeout_seconds = timeout_seconds.unwrap_or(DEFAULT_TIMEOUT_SECONDS);
379    info!(
380        "NetHSM connection configured with \"max_idle_connection\" {max_idle_connections} and \"timeout_seconds\" {timeout_seconds}."
381    );
382
383    Ok(AgentBuilder::new()
384        .tls_config(Arc::new(tls_conf))
385        .max_idle_connections(max_idle_connections)
386        .max_idle_connections_per_host(max_idle_connections)
387        .timeout_connect(Duration::from_secs(timeout_seconds))
388        .build())
389}
390
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use testresult::TestResult;

    use super::*;

    // A CertFingerprint built from raw digest bytes displays as lowercase hex, two digits per
    // byte, without separators.
    #[test]
    fn certfingerprint_display() -> TestResult {
        // the SHA-256 digest of "foo\n" (i.e. the output of `echo foo | sha256sum`)
        let digest = vec![
            181, 187, 157, 128, 20, 160, 249, 177, 214, 30, 33, 231, 150, 215, 141, 204, 223, 19,
            82, 242, 60, 211, 40, 18, 244, 133, 11, 135, 138, 228, 148, 76,
        ];
        let cert_fingerprint = CertFingerprint::from(digest);

        assert_eq!(
            cert_fingerprint.to_string(),
            "b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c"
        );

        Ok(())
    }

    // HostCertificateFingerprints prefixes each fingerprint with "sha256:" and renders both an
    // empty list and a missing list as "n/a".
    #[rstest]
    #[case(HostCertificateFingerprints { sha256: Some(vec![CertFingerprint::from(vec![
            181, 187, 157, 128, 20, 160, 249, 177, 214, 30, 33, 231, 150, 215, 141, 204, 223, 19,
            82, 242, 60, 211, 40, 18, 244, 133, 11, 135, 138, 228, 148, 76,
        ])]) }, "sha256:b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c")]
    #[case(HostCertificateFingerprints { sha256: Some(Vec::new()) }, "n/a")]
    #[case(HostCertificateFingerprints { sha256: None }, "n/a")]
    fn hostcertfingerprints_display(
        #[case] fingerprints: HostCertificateFingerprints,
        #[case] expected: &str,
    ) -> TestResult {
        assert_eq!(fingerprints.to_string(), expected);
        Ok(())
    }

    // Each ConnectionSecurity variant has a stable textual representation.
    #[rstest]
    #[case(ConnectionSecurity::Native, "native")]
    #[case(ConnectionSecurity::Unsafe, "unsafe")]
    #[case(ConnectionSecurity::Fingerprints(HostCertificateFingerprints { sha256: Some(vec![CertFingerprint::from(vec![
            181, 187, 157, 128, 20, 160, 249, 177, 214, 30, 33, 231, 150, 215, 141, 204, 223, 19,
            82, 242, 60, 211, 40, 18, 244, 133, 11, 135, 138, 228, 148, 76,
        ])]) }), "sha256:b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c")]
    fn connectionsecurity_display(
        #[case] connection_security: ConnectionSecurity,
        #[case] expected: &str,
    ) -> TestResult {
        assert_eq!(connection_security.to_string(), expected);
        Ok(())
    }

    // Parsing accepts the keywords and "sha256:"-prefixed 64-char hex checksums; a checksum
    // without prefix, or one that is too short or too long, is rejected.
    // `expected == None` means the input must fail to parse.
    #[rstest]
    #[case("native", Some(ConnectionSecurity::Native))]
    #[case("unsafe", Some(ConnectionSecurity::Unsafe))]
    #[case("sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7", Some(ConnectionSecurity::Fingerprints(HostCertificateFingerprints { sha256: Some(vec![CertFingerprint::from_str("324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7")?]) })))]
    #[case(
        "324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7",
        None
    )]
    #[case(
        "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e",
        None
    )]
    #[case(
        "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b73553b7",
        None
    )]
    fn connection_security_fromstr(
        #[case] input: &str,
        #[case] expected: Option<ConnectionSecurity>,
    ) -> TestResult {
        if let Some(expected) = expected {
            assert_eq!(ConnectionSecurity::from_str(input)?, expected);
        } else {
            assert!(ConnectionSecurity::from_str(input).is_err());
        }
        Ok(())
    }
}