| // Copyright 2025 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| //! Secure gRPC server implementation with custom TLS configuration and client certificate verification. |
| //! |
| //! This module provides functionality to run a secure gRPC server with mutual TLS authentication, |
| //! certificate revocation checking, and custom client certificate validation using MTLS. |
| |
| use crate::grpc::telemetry_server::BmcTelemetryService; |
| use crate::grpc::third_party_voyager::machine_telemetry_server::MachineTelemetryServer; |
| use rustls::{ |
| client::danger::HandshakeSignatureValid, pki_types::UnixTime, |
| server::danger::ClientCertVerifier, DigitallySignedStruct, DistinguishedName, Error, |
| RootCertStore, SignatureScheme, |
| }; |
| use rustls_pki_types::{CertificateDer, PrivateKeyDer}; |
| use std::fs::File; |
| use std::pin::Pin; |
| use std::sync::Arc; |
| use std::task::{Context, Poll}; |
| use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; |
| use x509_parser::prelude::*; |
| |
/// Per-connection metadata that tonic obtains through `Connected::connect_info`.
#[derive(Clone)]
struct MtlsConnectInfo {
    /// Address of the peer that opened the TCP connection.
    // Nothing in this module reads the field yet; the allowance silences the dead-code lint.
    #[allow(dead_code)]
    remote_addr: std::net::SocketAddr,
}
| |
| /// A wrapper around a TLS stream that includes connection information. |
| struct TlsStreamWithAddr { |
| stream: tokio_rustls::server::TlsStream<tokio::net::TcpStream>, |
| connect_info: MtlsConnectInfo, |
| } |
| |
| // Implement AsyncRead, AsyncWrite, and tonic::transport::server::Connected traits for TlsStreamWithAddr |
| impl AsyncRead for TlsStreamWithAddr { |
| fn poll_read( |
| mut self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| buf: &mut ReadBuf<'_>, |
| ) -> Poll<std::io::Result<()>> { |
| Pin::new(&mut self.stream).poll_read(cx, buf) |
| } |
| } |
| |
| impl AsyncWrite for TlsStreamWithAddr { |
| fn poll_write( |
| mut self: Pin<&mut Self>, |
| cx: &mut Context<'_>, |
| buf: &[u8], |
| ) -> Poll<std::io::Result<usize>> { |
| Pin::new(&mut self.stream).poll_write(cx, buf) |
| } |
| |
| fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> { |
| Pin::new(&mut self.stream).poll_flush(cx) |
| } |
| |
| fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> { |
| Pin::new(&mut self.stream).poll_shutdown(cx) |
| } |
| } |
| |
| impl tonic::transport::server::Connected for TlsStreamWithAddr { |
| type ConnectInfo = MtlsConnectInfo; |
| |
| fn connect_info(&self) -> Self::ConnectInfo { |
| self.connect_info.clone() |
| } |
| } |
| |
| /// Custom client certificate verifier that includes CRL checking and MTLS validation. |
| #[derive(Debug)] |
| struct CustomClientCertVerifier { |
| inner: Arc<dyn ClientCertVerifier>, |
| crls: Option<Vec<u8>>, |
| } |
| |
| impl CustomClientCertVerifier { |
| /// Creates a new CustomClientCertVerifier. |
| /// |
| /// SAFETY: unwrap in this function is allowed, should only be called in server init |
| /// |
| /// # Arguments |
| /// |
| /// * `root_cert_store` - The root certificate store for verifying client certificates. |
| /// * `crls_dir` - Optional directory containing Certificate Revocation Lists. |
| fn new(root_cert_store: RootCertStore, crls_dir: Option<&str>) -> Self { |
| let client_verifier = |
| rustls::server::WebPkiClientVerifier::builder(Arc::new(root_cert_store)) |
| .build() |
| .unwrap(); |
| |
| let crls = match crls_dir { |
| Some(dir) => { |
| let mut all_crls = Vec::new(); |
| match std::fs::read_dir(dir) { |
| Ok(entries) => { |
| for entry in entries.flatten() { |
| if let Ok(file_type) = entry.file_type() { |
| if file_type.is_file() { |
| match std::fs::read(entry.path()) { |
| Ok(data) => all_crls.extend(data), |
| Err(e) => println!( |
| "WARNING! Failed to read CRL file {:?}: {}", |
| entry.path(), |
| e |
| ), |
| } |
| } |
| } |
| } |
| if all_crls.is_empty() { |
| println!("WARNING! No CRL files were successfully read from {}", dir); |
| None |
| } else { |
| Some(all_crls) |
| } |
| } |
| Err(e) => { |
| println!("WARNING! Failed to read CRL directory {}: {}", dir, e); |
| None |
| } |
| } |
| } |
| None => None, |
| }; |
| |
| CustomClientCertVerifier { |
| inner: client_verifier, |
| crls, |
| } |
| } |
| |
| /// Checks if a certificate is revoked using the loaded CRLs. |
| /// |
| /// # Arguments |
| /// |
| /// * `certificate` - The X509 certificate to check. |
| /// |
| /// # Returns |
| /// |
| /// `true` if the certificate is revoked, `false` otherwise. |
| fn check_crl(&self, certificate: &X509Certificate) -> bool { |
| if let Some(data) = &self.crls { |
| // Determine if the data is DER or PEM format |
| // SAFETY: starts_with has out-of-range check |
| let der_data: Vec<u8> = if data.starts_with(&[0x30, 0x82]) { |
| // Data is likely in DER format |
| data.clone() |
| } else { |
| // Data is likely in PEM format, try to parse it |
| match parse_x509_pem(data) { |
| Ok((_, pem)) => pem.contents.clone(), |
| Err(_) => { |
| println!("Could not decode the PEM file"); |
| return false; // If parsing fails, return false |
| } |
| } |
| }; |
| |
| match parse_x509_crl(&der_data) { |
| Ok((_, crl)) => { |
| // Check if the certificate's serial number is in the list of revoked certificates |
| crl.iter_revoked_certificates().any(|rc| { |
| println!("revoked serial {:?}", rc.raw_serial()); |
| println!( |
| "certificate serial {:?}, is revoked {:?}", |
| certificate.tbs_certificate.raw_serial(), |
| rc.raw_serial() == certificate.tbs_certificate.raw_serial() |
| ); |
| rc.raw_serial() == certificate.tbs_certificate.raw_serial() |
| }) |
| } |
| Err(_) => { |
| println!("Could not decode DER data"); |
| false // If parsing fails, return false |
| } |
| } |
| } else { |
| false // If no CRL data is present, return false |
| } |
| } |
| } |
| |
| /// Verifies a client certificate using MTLS. |
| /// |
| /// # Arguments |
| /// |
| /// * `peer_dns_names` - DNS names from the client certificate. |
| /// * `peer_uri_names` - URI names from the client certificate. |
| /// * `root_subject_str` - Root subject string for validation. |
| /// |
| /// # Returns |
| /// |
| /// A result indicating whether the client certificate is valid. |
| fn mtls_verify_client_cert( |
| _peer_dns_names: Vec<String>, |
| _peer_uri_names: Vec<String>, |
| _root_subject_str: &str, |
| ) -> Result<rustls::server::danger::ClientCertVerified, rustls::Error> { |
| // TODO: do authentication based on client cert's URI, DNS, SANs |
| println!("MTLS client validation passed"); |
| Ok(rustls::server::danger::ClientCertVerified::assertion()) |
| } |
| |
| // Implement ClientCertVerifier for CustomClientCertVerifier |
| impl ClientCertVerifier for CustomClientCertVerifier { |
| fn root_hint_subjects(&self) -> &[DistinguishedName] { |
| self.inner.root_hint_subjects() |
| } |
| |
| fn verify_client_cert( |
| &self, |
| end_entity: &CertificateDer<'_>, |
| intermediates: &[CertificateDer<'_>], |
| now: UnixTime, |
| ) -> Result<rustls::server::danger::ClientCertVerified, rustls::Error> { |
| // Verify the client certificate chain |
| self.inner |
| .verify_client_cert(end_entity, intermediates, now)?; |
| |
| let root_subject_str = "".to_string(); |
| let subject = self.inner.root_hint_subjects().to_owned(); |
| for name in &subject { |
| match parse_distinguished_name(name) { |
| // TODO: set the name_string to root_subject_str |
| Ok(name_string) => println!("Distinguished Name: {}", name_string), |
| Err(e) => eprintln!("Failed to parse distinguished name: {}", e), |
| } |
| } |
| |
| // Parse the DER-encoded certificate |
| let (_rem, cert) = match X509Certificate::from_der(end_entity) { |
| Ok((rem, cert)) => { |
| if !rem.is_empty() { |
| return Err(rustls::Error::General( |
| "Certificate parser did not consume all input".to_string(), |
| )); |
| } |
| (rem, cert) |
| } |
| Err(e) => { |
| return Err(rustls::Error::General(format!( |
| "Error parsing certificate: {}", |
| e |
| ))) |
| } |
| }; |
| |
| // Check if the certificate is revoked |
| if self.check_crl(&cert) { |
| return Err(rustls::Error::General( |
| "Client certificate revoked".to_string(), |
| )); |
| } |
| |
| let sans = extract_sans(&cert); |
| match sans { |
| Ok((dns_names, uris)) => { |
| println!("DNS Names: {:?}", dns_names); |
| println!("URIs: {:?}", uris); |
| mtls_verify_client_cert(dns_names, uris, &root_subject_str) |
| } |
| Err(e) => { |
| println!("Error extracting SANs, but continuing: {}", e); |
| Err(rustls::Error::General( |
| "ValidatePeer failed to extract SANs".to_string(), |
| )) |
| } |
| } |
| } |
| |
| fn verify_tls12_signature( |
| &self, |
| message: &[u8], |
| cert: &CertificateDer<'_>, |
| dss: &DigitallySignedStruct, |
| ) -> Result<HandshakeSignatureValid, Error> { |
| self.inner.verify_tls12_signature(message, cert, dss) |
| } |
| |
| fn verify_tls13_signature( |
| &self, |
| message: &[u8], |
| cert: &CertificateDer<'_>, |
| dss: &DigitallySignedStruct, |
| ) -> Result<HandshakeSignatureValid, Error> { |
| self.inner.verify_tls13_signature(message, cert, dss) |
| } |
| |
| fn supported_verify_schemes(&self) -> Vec<SignatureScheme> { |
| self.inner.supported_verify_schemes() |
| } |
| } |
| |
| /// Extracts Subject Alternative Names (SANs) from an X509 certificate. |
| /// |
| /// # Arguments |
| /// |
| /// * `cert` - The X509 certificate to extract SANs from. |
| /// |
| /// # Returns |
| /// |
| /// A tuple containing vectors of DNS names and URIs. |
| fn extract_sans( |
| cert: &X509Certificate, |
| ) -> Result<(Vec<String>, Vec<String>), Box<dyn std::error::Error>> { |
| // Extract SANs |
| let sans = cert.tbs_certificate.subject_alternative_name()?; |
| let sans = sans.ok_or("No SAN extension found")?; |
| let dns_names = sans |
| .value |
| .general_names |
| .iter() |
| .filter_map(|name| { |
| if let GeneralName::DNSName(dns) = name { |
| Some(dns.to_string()) |
| } else { |
| None |
| } |
| }) |
| .collect(); |
| |
| let uris = sans |
| .value |
| .general_names |
| .iter() |
| .filter_map(|name| { |
| if let GeneralName::URI(uri) = name { |
| Some(uri.to_string()) |
| } else { |
| None |
| } |
| }) |
| .collect(); |
| |
| Ok((dns_names, uris)) |
| } |
| |
| /// Parses a DistinguishedName into a string representation. |
| /// |
| /// # Arguments |
| /// |
| /// * `dn` - The DistinguishedName to parse. |
| /// |
| /// # Returns |
| /// |
| /// A string representation of the DistinguishedName. |
| fn parse_distinguished_name(dn: &DistinguishedName) -> Result<String, Box<dyn std::error::Error>> { |
| // Attempt to parse the DistinguishedName as an X509Name |
| let parsed_name = X509Name::from_der(dn.as_ref())?; |
| Ok(parsed_name.1.to_string()) |
| } |
| |
| /// Loads certificates from a PEM file. |
| /// |
| /// # Arguments |
| /// |
| /// * `file_name` - The name of the file containing the certificates. |
| /// |
| /// # Returns |
| /// |
| /// A vector of loaded certificates. |
| pub fn load_cert(file_name: &str) -> std::io::Result<Vec<CertificateDer<'static>>> { |
| let file = File::open(file_name)?; |
| let mut reader = std::io::BufReader::new(file); |
| rustls_pemfile::certs(&mut reader).collect() |
| } |
| |
| /// Loads a private key from a PEM file. |
| /// |
| /// SAFETY: unwrap in this function is allowed |
| /// |
| /// # Arguments |
| /// |
| /// * `file_name` - The name of the file containing the private key. |
| /// |
| /// # Returns |
| /// |
| /// The loaded private key. |
| fn load_key(file_name: &str) -> std::io::Result<PrivateKeyDer<'static>> { |
| let file = File::open(file_name)?; |
| let mut reader = std::io::BufReader::new(file); |
| rustls_pemfile::private_key(&mut reader).map(|key| key.unwrap()) |
| } |
| |
| /// Runs the secure gRPC server. |
| /// |
| /// SAFETY: unwrap in this function is allowed |
| /// |
| /// # Arguments |
| /// |
| /// * `port` - The port number to listen on. |
| /// * `key` - Path to the server's private key file. |
| /// * `cert` - Path to the server's certificate file. |
| /// * `cacert` - Path to the CA certificate file for verifying client certificates. |
| /// * `policy` - Path to the MTLS policy file. |
| /// * `crls` - Optional path to the directory containing CRLs. |
| /// * `grpc` - The gRPC service to serve. |
| /// |
| /// # Returns |
| /// |
| /// A result indicating whether the server ran successfully. |
| pub async fn run_secure_server( |
| port: u16, |
| key: &str, |
| cert: &str, |
| cacert: &str, |
| _policy: &str, |
| crls: Option<&str>, |
| grpc: MachineTelemetryServer<BmcTelemetryService>, |
| ) -> Result<(), Box<dyn std::error::Error>> { |
| // Load server's private key and certificate |
| let cert = load_cert(cert)?; |
| let key = load_key(key)?; |
| |
| // Load root CA certificate to verify clients |
| let client_ca_cert: Vec<CertificateDer<'static>> = load_cert(cacert)?; |
| let mut root_cert_store = rustls::RootCertStore::empty(); |
| for ca_cert in client_ca_cert { |
| let _ = root_cert_store.add(ca_cert); |
| } |
| |
| let client_verifier = Arc::new(CustomClientCertVerifier::new(root_cert_store, crls)); |
| let mut tls_config = rustls::ServerConfig::builder() |
| .with_client_cert_verifier(client_verifier) |
| .with_single_cert(cert, key) |
| .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; |
| tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; |
| |
| let tls_config = Arc::new(tls_config); |
| |
| let acceptor = tokio_rustls::TlsAcceptor::from(tls_config); |
| |
| let addr = format!("[::]:{port}"); |
| let addr: std::net::SocketAddr = addr.parse().unwrap(); |
| let listener = tokio::net::TcpListener::bind(addr).await?; |
| println!("Server listening on {}", addr); |
| |
| let incoming_tls_stream = futures_util::stream::unfold(listener, |listener| async { |
| let (socket, remote_addr) = match listener.accept().await { |
| Ok(conn) => conn, |
| Err(_) => return None, |
| }; |
| match acceptor.accept(socket).await { |
| Ok(tls_stream) => { |
| let connect_info = MtlsConnectInfo { remote_addr }; |
| Some(( |
| Ok(TlsStreamWithAddr { |
| stream: tls_stream, |
| connect_info, |
| }), |
| listener, |
| )) |
| } |
| Err(_) => Some(( |
| Err(std::io::Error::new(std::io::ErrorKind::Other, "TLS Error")), |
| listener, |
| )), |
| } |
| }); |
| |
| tonic::transport::Server::builder() |
| .add_service(grpc) |
| .serve_with_incoming(incoming_tls_stream) |
| .await?; |
| |
| Ok(()) |
| } |