nethsm/
tls.rs

1use std::str::FromStr;
2use std::sync::Arc;
3use std::thread::available_parallelism;
4use std::time::Duration;
5
6use log::{debug, error, info, trace};
7use nethsm_sdk_rs::ureq::{Agent, AgentBuilder};
8use rustls::client::{
9    ClientConfig,
10    danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
11};
12use rustls::crypto::{CryptoProvider, ring as tls_provider};
13use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
14use rustls::{DigitallySignedStruct, SignatureScheme};
15use serde::{Deserialize, Serialize};
16use sha2::{Digest, Sha256};
17
18use crate::Error;
19#[cfg(doc)]
20use crate::NetHsm;
21
22/// The fingerprint of a TLS certificate (as hex)
23#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
24pub struct CertFingerprint(
25    #[serde(
26        deserialize_with = "hex::serde::deserialize",
27        serialize_with = "hex::serde::serialize"
28    )]
29    Vec<u8>,
30);
31
32impl FromStr for CertFingerprint {
33    type Err = Error;
34    fn from_str(s: &str) -> Result<Self, Self::Err> {
35        Ok(Self(s.as_bytes().to_vec()))
36    }
37}
38
39impl From<Vec<u8>> for CertFingerprint {
40    fn from(value: Vec<u8>) -> Self {
41        Self(value)
42    }
43}
44
45/// Certificate fingerprints to use for matching against a host's TLS
46/// certificate
47#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
48pub struct HostCertificateFingerprints {
49    /// An optional list of SHA-256 checksums
50    sha256: Option<Vec<CertFingerprint>>,
51}
52
53/// The security model chosen for a [`crate::NetHsm`]'s TLS connection
54#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
55pub enum ConnectionSecurity {
56    /// Always trust the TLS certificate associated with a host
57    Unsafe,
58    /// Use the native trust store to evaluate the trust of a host
59    Native,
60    /// Use a list of checksums (fingerprints) to verify a host's TLS certificate
61    Fingerprints(HostCertificateFingerprints),
62}
63
64impl FromStr for ConnectionSecurity {
65    type Err = Error;
66
67    /// Create a ConnectionSecurity from string
68    ///
69    /// Valid inputs are either "Unsafe" (or "unsafe"), "Native" (or "native") or "sha256:checksum"
70    /// where "checksum" denotes 64 ASCII hexadecimal chars.
71    ///
72    /// # Errors
73    ///
74    /// Returns an [`Error`] if the input is neither "Unsafe" nor "Native" and also no valid
75    /// certificate fingerprint can be derived from the input.
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use std::str::FromStr;
81    ///
82    /// use nethsm::ConnectionSecurity;
83    ///
84    /// assert!(ConnectionSecurity::from_str("unsafe").is_ok());
85    /// assert!(ConnectionSecurity::from_str("native").is_ok());
86    /// assert!(
87    ///     ConnectionSecurity::from_str(
88    ///         "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7"
89    ///     )
90    ///     .is_ok()
91    /// );
92    /// assert!(ConnectionSecurity::from_str("something").is_err());
93    /// ```
94    fn from_str(s: &str) -> Result<Self, Self::Err> {
95        match s {
96            "unsafe" | "Unsafe" => Ok(Self::Unsafe),
97            "native" | "Native" => Ok(Self::Native),
98            _ => {
99                let sha256_fingerprints: Vec<Vec<u8>> = s
100                    .split(',')
101                    .filter_map(|checksum| {
102                        checksum
103                            .strip_prefix("sha256:")
104                            .filter(|x| x.len() == 64 && x.chars().all(|x| x.is_ascii_hexdigit()))
105                            .map(|checksum| checksum.as_bytes().to_vec())
106                    })
107                    .collect();
108                if sha256_fingerprints.is_empty() {
109                    Err(Error::Default(
110                        "No valid TLS certificate fingerprints detected.".to_string(),
111                    ))
112                } else {
113                    Ok(Self::Fingerprints(HostCertificateFingerprints {
114                        sha256: Some(
115                            sha256_fingerprints
116                                .iter()
117                                .map(|checksum| checksum.clone().into())
118                                .collect(),
119                        ),
120                    }))
121                }
122            }
123        }
124    }
125}
126
127/// A verifier for server certificates that always accepts them
128///
129/// This verifier is used when choosing [`ConnectionSecurity::Unsafe`]. It is **unsafe** and should
130/// not be used unless for initial setup scenarios of a NetHSM! Instead use [`FingerprintVerifier`]
131/// (selected by [`ConnectionSecurity::Fingerprints`]) or better yet rely on
132/// [`ConnectionSecurity::Native`].
133#[derive(Debug)]
134pub struct DangerIgnoreVerifier(pub CryptoProvider);
135
136impl ServerCertVerifier for DangerIgnoreVerifier {
137    fn verify_server_cert(
138        &self,
139        _end_entity: &CertificateDer<'_>,
140        _intermediates: &[CertificateDer<'_>],
141        _server_name: &ServerName<'_>,
142        _ocsp_response: &[u8],
143        _now: UnixTime,
144    ) -> Result<ServerCertVerified, rustls::Error> {
145        // always accept the certificate
146        Ok(ServerCertVerified::assertion())
147    }
148
149    fn verify_tls12_signature(
150        &self,
151        message: &[u8],
152        cert: &CertificateDer<'_>,
153        dss: &DigitallySignedStruct,
154    ) -> Result<HandshakeSignatureValid, rustls::Error> {
155        rustls::crypto::verify_tls12_signature(
156            message,
157            cert,
158            dss,
159            &self.0.signature_verification_algorithms,
160        )
161    }
162
163    fn verify_tls13_signature(
164        &self,
165        message: &[u8],
166        cert: &CertificateDer<'_>,
167        dss: &DigitallySignedStruct,
168    ) -> Result<HandshakeSignatureValid, rustls::Error> {
169        rustls::crypto::verify_tls13_signature(
170            message,
171            cert,
172            dss,
173            &self.0.signature_verification_algorithms,
174        )
175    }
176
177    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
178        self.0.signature_verification_algorithms.supported_schemes()
179    }
180}
181
182/// A verifier for server certificates that verifies them based on fingerprints
183///
184/// This verifier is selected when using [`ConnectionSecurity::Fingerprints`] and relies on
185/// [`HostCertificateFingerprints`] to be able to match a host certificate fingerprint against a
186/// predefined list of fingerprints. It should be preferred over the use of [`DangerIgnoreVerifier`]
187/// (selected by [`ConnectionSecurity::Unsafe`]), but ideally a setup should make use of
188/// [`ConnectionSecurity::Native`] instead!
189#[derive(Debug)]
190pub struct FingerprintVerifier {
191    pub fingerprints: HostCertificateFingerprints,
192    pub provider: CryptoProvider,
193}
194
195impl ServerCertVerifier for FingerprintVerifier {
196    fn verify_server_cert(
197        &self,
198        end_entity: &CertificateDer<'_>,
199        _intermediates: &[CertificateDer<'_>],
200        _server_name: &ServerName<'_>,
201        _ocsp_response: &[u8],
202        _now: UnixTime,
203    ) -> Result<ServerCertVerified, rustls::Error> {
204        if let Some(sha256_fingerprints) = self.fingerprints.sha256.as_ref() {
205            let mut hasher = Sha256::new();
206            hasher.update(end_entity.as_ref());
207            let result = hasher.finalize();
208            for fingerprint in sha256_fingerprints.iter() {
209                if fingerprint.0 == result[..] {
210                    trace!("Certificate fingerprint matches");
211                    return Ok(ServerCertVerified::assertion());
212                }
213            }
214        } else {
215            return Err(rustls::Error::General(
216                "Could not verify certificate fingerprint as no fingerprints were provided to match against".to_string(),
217            ));
218        }
219        Err(rustls::Error::General(
220            "Could not verify certificate fingerprint".to_string(),
221        ))
222    }
223
224    fn verify_tls12_signature(
225        &self,
226        message: &[u8],
227        cert: &CertificateDer<'_>,
228        dss: &DigitallySignedStruct,
229    ) -> Result<HandshakeSignatureValid, rustls::Error> {
230        rustls::crypto::verify_tls12_signature(
231            message,
232            cert,
233            dss,
234            &self.provider.signature_verification_algorithms,
235        )
236    }
237
238    fn verify_tls13_signature(
239        &self,
240        message: &[u8],
241        cert: &CertificateDer<'_>,
242        dss: &DigitallySignedStruct,
243    ) -> Result<HandshakeSignatureValid, rustls::Error> {
244        rustls::crypto::verify_tls13_signature(
245            message,
246            cert,
247            dss,
248            &self.provider.signature_verification_algorithms,
249        )
250    }
251
252    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
253        self.provider
254            .signature_verification_algorithms
255            .supported_schemes()
256    }
257}
258
259/// Creates an [`Agent`] for the use in a [`NetHsm`] connection.
260///
261/// Takes a [`ConnectionSecurity`] to define the TLS security model for the connection.
262/// Allows setting the maximum idle connections per host using the optional
263/// `max_idle_connections` (defaults to [`available_parallelism`] and falls back to `100` if
264/// unavailable).
265/// Also allows setting the timeout in seconds for a successful socket connection
266/// using the optional `timeout_seconds` (defaults to `10`).
267///
268/// # Errors
269///
270/// Returns an error if
271///
272/// - the TLS client configuration can not be created,
273/// - [`ConnectionSecurity::Native`] is provided as `tls_security`, but no certification authority
274///   certificates are available on the system.
275pub(crate) fn create_agent(
276    tls_security: ConnectionSecurity,
277    max_idle_connections: Option<usize>,
278    timeout_seconds: Option<u64>,
279) -> Result<Agent, Error> {
280    let tls_conf = {
281        let tls_conf = ClientConfig::builder_with_provider(Arc::new(CryptoProvider {
282            cipher_suites: tls_provider::ALL_CIPHER_SUITES.into(),
283            ..tls_provider::default_provider()
284        }))
285        .with_protocol_versions(rustls::DEFAULT_VERSIONS)?;
286
287        match tls_security {
288            ConnectionSecurity::Unsafe => {
289                let dangerous = tls_conf.dangerous();
290                dangerous
291                    .with_custom_certificate_verifier(Arc::new(DangerIgnoreVerifier(
292                        tls_provider::default_provider(),
293                    )))
294                    .with_no_client_auth()
295            }
296            ConnectionSecurity::Native => {
297                let native_certs = rustls_native_certs::load_native_certs();
298                if !native_certs.errors.is_empty() {
299                    return Err(Error::CertLoading(native_certs.errors));
300                }
301                let native_certs = native_certs.certs;
302
303                let roots = {
304                    let mut roots = rustls::RootCertStore::empty();
305                    let (added, failed) = roots.add_parsable_certificates(native_certs);
306                    debug!("Added {added} certificates and failed to parse {failed} certificates");
307                    if added == 0 {
308                        error!("Added no native certificates");
309                        return Err(Error::NoSystemCertsAdded { failed });
310                    }
311                    roots
312                };
313
314                tls_conf.with_root_certificates(roots).with_no_client_auth()
315            }
316            ConnectionSecurity::Fingerprints(fingerprints) => {
317                let dangerous = tls_conf.dangerous();
318                dangerous
319                    .with_custom_certificate_verifier(Arc::new(FingerprintVerifier {
320                        fingerprints,
321                        provider: tls_provider::default_provider(),
322                    }))
323                    .with_no_client_auth()
324            }
325        }
326    };
327
328    let max_idle_connections = max_idle_connections
329        .or_else(|| available_parallelism().ok().map(Into::into))
330        .unwrap_or(100);
331    let timeout_seconds = timeout_seconds.unwrap_or(10);
332    info!(
333        "NetHSM connection configured with \"max_idle_connection\" {} and \"timeout_seconds\" {}.",
334        max_idle_connections, timeout_seconds
335    );
336
337    Ok(AgentBuilder::new()
338        .tls_config(Arc::new(tls_conf))
339        .max_idle_connections(max_idle_connections)
340        .max_idle_connections_per_host(max_idle_connections)
341        .timeout_connect(Duration::from_secs(timeout_seconds))
342        .build())
343}
344
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use testresult::TestResult;

    use super::*;

    /// Exercises every documented input form of [`ConnectionSecurity::from_str`].
    ///
    /// `None` as `expected` means the input must be rejected.
    #[rstest]
    #[case("native", Some(ConnectionSecurity::Native))]
    #[case("Native", Some(ConnectionSecurity::Native))]
    #[case("unsafe", Some(ConnectionSecurity::Unsafe))]
    #[case("Unsafe", Some(ConnectionSecurity::Unsafe))]
    #[case("sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7", Some(ConnectionSecurity::Fingerprints(HostCertificateFingerprints { sha256: Some(vec![CertFingerprint::from_str("324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7")?]) })))]
    #[case("sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7,sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", Some(ConnectionSecurity::Fingerprints(HostCertificateFingerprints { sha256: Some(vec![CertFingerprint::from_str("324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7")?, CertFingerprint::from_str("e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")?]) })))]
    // Invalid entries are skipped as long as at least one entry is valid.
    #[case("sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7,garbage", Some(ConnectionSecurity::Fingerprints(HostCertificateFingerprints { sha256: Some(vec![CertFingerprint::from_str("324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7")?]) })))]
    #[case(
        "324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7",
        None
    )]
    #[case(
        "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e",
        None
    )]
    #[case(
        "sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b73553b7",
        None
    )]
    #[case(
        "sha256:zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz",
        None
    )]
    #[case("", None)]
    #[case("something", None)]
    fn connection_security_fromstr(
        #[case] input: &str,
        #[case] expected: Option<ConnectionSecurity>,
    ) -> TestResult {
        match expected {
            Some(expected) => assert_eq!(ConnectionSecurity::from_str(input)?, expected),
            None => assert!(ConnectionSecurity::from_str(input).is_err()),
        }
        Ok(())
    }

    /// [`FingerprintVerifier`] accepts a certificate whose SHA-256 digest is listed and rejects
    /// any other certificate, as well as any certificate when no fingerprints are configured.
    #[test]
    fn fingerprint_verifier_matches_sha256_digest() -> TestResult {
        let cert = CertificateDer::from(b"not a real certificate".to_vec());
        let other = CertificateDer::from(b"another fake certificate".to_vec());
        let server_name = ServerName::try_from("localhost")?;

        let verifier = FingerprintVerifier {
            fingerprints: HostCertificateFingerprints {
                sha256: Some(vec![CertFingerprint::from(
                    Sha256::digest(cert.as_ref()).to_vec(),
                )]),
            },
            provider: tls_provider::default_provider(),
        };
        assert!(
            verifier
                .verify_server_cert(&cert, &[], &server_name, &[], UnixTime::now())
                .is_ok()
        );
        assert!(
            verifier
                .verify_server_cert(&other, &[], &server_name, &[], UnixTime::now())
                .is_err()
        );

        let empty = FingerprintVerifier {
            fingerprints: HostCertificateFingerprints { sha256: None },
            provider: tls_provider::default_provider(),
        };
        assert!(
            empty
                .verify_server_cert(&cert, &[], &server_name, &[], UnixTime::now())
                .is_err()
        );
        Ok(())
    }
}