blob: 08abea4d3db53ce6dce51f7bb4562946ea761807 [file] [log] [blame]
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Secure gRPC server implementation with custom TLS configuration and client certificate verification.
//!
//! This module provides functionality to run a secure gRPC server with mutual TLS authentication,
//! certificate revocation checking, and custom client certificate validation using MTLS.
use crate::grpc::telemetry_server::BmcTelemetryService;
use crate::grpc::third_party_voyager::machine_telemetry_server::MachineTelemetryServer;
use futures_util::StreamExt;
use rustls::{
    client::danger::HandshakeSignatureValid, pki_types::UnixTime,
    server::danger::ClientCertVerifier, DigitallySignedStruct, DistinguishedName, Error,
    RootCertStore, SignatureScheme,
};
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use std::fs::File;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use x509_parser::prelude::*;
/// Represents connection information for MTLS.
///
/// A value is attached to every accepted TLS connection and handed to tonic
/// through the `tonic::transport::server::Connected` implementation of
/// `TlsStreamWithAddr`.
#[derive(Clone, Debug)]
struct MtlsConnectInfo {
    /// Address of the remote peer, as returned by `TcpListener::accept`.
    // Not read anywhere in this module; the allow silences the dead-code lint
    // while the address is carried along as per-connection metadata.
    #[allow(dead_code)]
    remote_addr: std::net::SocketAddr,
}
/// A server-side TLS stream paired with the metadata captured when its
/// underlying TCP connection was accepted.
struct TlsStreamWithAddr {
    /// Per-connection metadata returned to tonic via `Connected::connect_info`.
    connect_info: MtlsConnectInfo,
    /// The established TLS stream over the accepted TCP socket.
    stream: tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
}
// TlsStreamWithAddr implements AsyncRead, AsyncWrite and
// tonic::transport::server::Connected so tonic can serve gRPC over it.
impl AsyncRead for TlsStreamWithAddr {
    /// Forwards the read to the wrapped TLS stream.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // Both fields are Unpin, so the pinned wrapper can be unwrapped to a
        // plain mutable reference and the inner stream re-pinned.
        let tls = &mut self.get_mut().stream;
        Pin::new(tls).poll_read(cx, buf)
    }
}
/// Forwards every write operation to the wrapped TLS stream.
///
/// Vectored writes are forwarded as well. Without these overrides the trait
/// defaults would write only the first non-empty buffer per call and report
/// `is_write_vectored() == false`, no matter what the inner stream supports.
impl AsyncWrite for TlsStreamWithAddr {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut self.stream).poll_write(cx, buf)
    }
    fn poll_write_vectored(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
    }
    fn is_write_vectored(&self) -> bool {
        self.stream.is_write_vectored()
    }
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.stream).poll_flush(cx)
    }
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        Pin::new(&mut self.stream).poll_shutdown(cx)
    }
}
impl tonic::transport::server::Connected for TlsStreamWithAddr {
type ConnectInfo = MtlsConnectInfo;
fn connect_info(&self) -> Self::ConnectInfo {
self.connect_info.clone()
}
}
/// Custom client certificate verifier that includes CRL checking and MTLS validation.
///
/// Chain validation is delegated to a WebPKI verifier. On top of that, the
/// end-entity certificate's raw serial number is looked up in the set of
/// serials revoked by the CRLs loaded at construction time.
#[derive(Debug)]
struct CustomClientCertVerifier {
    /// WebPKI verifier built from the configured root certificate store.
    inner: Arc<dyn ClientCertVerifier>,
    /// Raw serial numbers revoked by any successfully decoded CRL. Empty when no
    /// CRL directory was configured or nothing could be loaded from it.
    revoked_serials: std::collections::HashSet<Vec<u8>>,
}
impl CustomClientCertVerifier {
    /// Creates a new CustomClientCertVerifier.
    ///
    /// Every regular file in `crls_dir` is read. Each file may contain one
    /// DER-encoded CRL or any number of PEM-encoded CRL blocks, and all of them
    /// are used. Files that cannot be read or decoded are reported and skipped,
    /// so revocation checking is best effort.
    ///
    /// # Panics
    ///
    /// Panics if the WebPKI client verifier cannot be built from
    /// `root_cert_store`. It is only meant to be called during server init.
    ///
    /// # Arguments
    ///
    /// * `root_cert_store` - The root certificate store for verifying client certificates.
    /// * `crls_dir` - Optional directory containing Certificate Revocation Lists.
    fn new(root_cert_store: RootCertStore, crls_dir: Option<&str>) -> Self {
        let inner = rustls::server::WebPkiClientVerifier::builder(Arc::new(root_cert_store))
            .build()
            .expect("failed to build the WebPKI client certificate verifier");
        // CRLs are decoded once here instead of on every handshake.
        let revoked_serials = crls_dir
            .map(Self::load_revoked_serials)
            .unwrap_or_default();
        CustomClientCertVerifier {
            inner,
            revoked_serials,
        }
    }

    /// Reads every regular file in `dir` and collects the serial numbers
    /// revoked by all CRLs found in them.
    fn load_revoked_serials(dir: &str) -> std::collections::HashSet<Vec<u8>> {
        let mut revoked = std::collections::HashSet::new();
        let entries = match std::fs::read_dir(dir) {
            Ok(entries) => entries,
            Err(e) => {
                println!("WARNING! Failed to read CRL directory {}: {}", dir, e);
                return revoked;
            }
        };
        let mut loaded_any = false;
        for entry in entries.flatten() {
            if !matches!(entry.file_type(), Ok(file_type) if file_type.is_file()) {
                continue;
            }
            let path = entry.path();
            match std::fs::read(&path) {
                Ok(data) => {
                    if Self::collect_crl_file(&data, &mut revoked) {
                        loaded_any = true;
                    } else {
                        println!("WARNING! Could not decode any CRL in file {:?}", path);
                    }
                }
                Err(e) => println!("WARNING! Failed to read CRL file {:?}: {}", path, e),
            }
        }
        if loaded_any {
            println!(
                "Loaded {} revoked certificate serial(s) from {}",
                revoked.len(),
                dir
            );
        } else {
            println!("WARNING! No CRL files were successfully read from {}", dir);
        }
        revoked
    }

    /// Adds the revoked serials of every CRL contained in `data` to `revoked`.
    ///
    /// `data` is first tried as a single DER-encoded CRL. It is then scanned
    /// for PEM blocks, and every block that decodes as a CRL is used. Returns
    /// `true` if at least one CRL was decoded.
    fn collect_crl_file(data: &[u8], revoked: &mut std::collections::HashSet<Vec<u8>>) -> bool {
        // PEM text never parses as DER (its first byte is not a SEQUENCE tag),
        // so trying DER first is a safe content sniff regardless of length encoding.
        if Self::collect_crl_der(data, revoked) {
            return true;
        }
        let mut found = false;
        let mut rest = data;
        while let Ok((remaining, pem)) = parse_x509_pem(rest) {
            if Self::collect_crl_der(&pem.contents, revoked) {
                found = true;
            }
            // Guard against a parser that makes no progress.
            if remaining.len() >= rest.len() {
                break;
            }
            rest = remaining;
        }
        found
    }

    /// Parses one DER-encoded CRL and adds its revoked serials to `revoked`.
    /// Returns `false` if `der` is not a decodable CRL.
    fn collect_crl_der(der: &[u8], revoked: &mut std::collections::HashSet<Vec<u8>>) -> bool {
        match parse_x509_crl(der) {
            Ok((_, crl)) => {
                revoked.extend(
                    crl.iter_revoked_certificates()
                        .map(|rc| rc.raw_serial().to_vec()),
                );
                true
            }
            Err(_) => false,
        }
    }

    /// Checks if a certificate is revoked using the loaded CRLs.
    ///
    /// Only the raw serial number is compared. The CRL issuer is not matched
    /// against the certificate issuer, and CRL signatures are not verified here.
    ///
    /// # Arguments
    ///
    /// * `certificate` - The X509 certificate to check.
    ///
    /// # Returns
    ///
    /// `true` if the certificate is revoked, `false` otherwise (including when
    /// no CRLs are loaded).
    fn check_crl(&self, certificate: &X509Certificate) -> bool {
        let serial = certificate.tbs_certificate.raw_serial();
        let revoked = self.revoked_serials.contains(serial);
        if revoked {
            println!("certificate serial {:?} is revoked", serial);
        }
        revoked
    }
}
/// Applies the MTLS authorization step to a client certificate.
///
/// `CustomClientCertVerifier::verify_client_cert` calls this after the
/// certificate chain has been verified and the CRL check has passed.
///
/// NOTE(review): this is currently a placeholder. It ignores all of its
/// arguments and always succeeds, so any client whose chain verifies and is
/// not revoked is admitted.
///
/// # Arguments
///
/// * `_peer_dns_names` - DNS names from the client certificate's SAN extension.
/// * `_peer_uri_names` - URI names from the client certificate's SAN extension.
/// * `_root_subject_str` - Root subject string for validation.
///
/// # Returns
///
/// Always `Ok(ClientCertVerified)` at present.
fn mtls_verify_client_cert(
    _peer_dns_names: Vec<String>,
    _peer_uri_names: Vec<String>,
    _root_subject_str: &str,
) -> Result<rustls::server::danger::ClientCertVerified, rustls::Error> {
    // TODO: do authentication based on client cert's URI, DNS, SANs
    println!("MTLS client validation passed");
    let verdict = rustls::server::danger::ClientCertVerified::assertion();
    Ok(verdict)
}
// Implement ClientCertVerifier for CustomClientCertVerifier
impl ClientCertVerifier for CustomClientCertVerifier {
    /// Delegates to the wrapped WebPKI verifier, so the advertised client-auth
    /// behaviour matches the verifier that actually validates the chain.
    fn offer_client_auth(&self) -> bool {
        self.inner.offer_client_auth()
    }
    /// Delegates to the wrapped WebPKI verifier.
    fn client_auth_mandatory(&self) -> bool {
        self.inner.client_auth_mandatory()
    }
    fn root_hint_subjects(&self) -> &[DistinguishedName] {
        self.inner.root_hint_subjects()
    }
    /// Verifies a client certificate in four steps:
    /// 1. chain validation by the wrapped WebPKI verifier;
    /// 2. parsing the end-entity certificate, rejecting trailing bytes;
    /// 3. a CRL lookup of its serial number;
    /// 4. SAN extraction followed by `mtls_verify_client_cert`.
    ///
    /// Any failing step aborts the handshake with `rustls::Error::General`
    /// (or the inner verifier's error for step 1).
    fn verify_client_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        intermediates: &[CertificateDer<'_>],
        now: UnixTime,
    ) -> Result<rustls::server::danger::ClientCertVerified, rustls::Error> {
        // Verify the client certificate chain
        self.inner
            .verify_client_cert(end_entity, intermediates, now)?;
        // TODO: set root_subject_str from the parsed root hint subjects; it is
        // currently always empty.
        let root_subject_str = String::new();
        // Borrow the subjects directly; no need to clone the slice.
        for name in self.inner.root_hint_subjects() {
            match parse_distinguished_name(name) {
                Ok(name_string) => println!("Distinguished Name: {}", name_string),
                Err(e) => eprintln!("Failed to parse distinguished name: {}", e),
            }
        }
        // Parse the DER-encoded certificate; trailing garbage is rejected.
        let cert = match X509Certificate::from_der(end_entity) {
            Ok((rem, _)) if !rem.is_empty() => {
                return Err(rustls::Error::General(
                    "Certificate parser did not consume all input".to_string(),
                ))
            }
            Ok((_, cert)) => cert,
            Err(e) => {
                return Err(rustls::Error::General(format!(
                    "Error parsing certificate: {}",
                    e
                )))
            }
        };
        // Check if the certificate is revoked
        if self.check_crl(&cert) {
            return Err(rustls::Error::General(
                "Client certificate revoked".to_string(),
            ));
        }
        match extract_sans(&cert) {
            Ok((dns_names, uris)) => {
                println!("DNS Names: {:?}", dns_names);
                println!("URIs: {:?}", uris);
                mtls_verify_client_cert(dns_names, uris, &root_subject_str)
            }
            Err(e) => {
                // The handshake is rejected here, so the log must not claim otherwise.
                println!("Error extracting SANs: {}", e);
                Err(rustls::Error::General(
                    "ValidatePeer failed to extract SANs".to_string(),
                ))
            }
        }
    }
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, Error> {
        self.inner.verify_tls12_signature(message, cert, dss)
    }
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, Error> {
        self.inner.verify_tls13_signature(message, cert, dss)
    }
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        self.inner.supported_verify_schemes()
    }
}
/// Extracts Subject Alternative Names (SANs) from an X509 certificate.
///
/// Walks the SAN extension once, sorting DNS-name and URI entries into two
/// lists in the order they appear. Every other kind of general name is ignored.
///
/// # Arguments
///
/// * `cert` - The X509 certificate to extract SANs from.
///
/// # Returns
///
/// `(dns_names, uris)` on success. Returns an error if the SAN extension cannot
/// be parsed or is absent.
fn extract_sans(
    cert: &X509Certificate,
) -> Result<(Vec<String>, Vec<String>), Box<dyn std::error::Error>> {
    let san_extension = match cert.tbs_certificate.subject_alternative_name()? {
        Some(extension) => extension,
        None => return Err("No SAN extension found".into()),
    };
    let mut dns_names = Vec::new();
    let mut uris = Vec::new();
    for general_name in &san_extension.value.general_names {
        match general_name {
            GeneralName::DNSName(dns) => dns_names.push(dns.to_string()),
            GeneralName::URI(uri) => uris.push(uri.to_string()),
            _ => {}
        }
    }
    Ok((dns_names, uris))
}
/// Parses a DistinguishedName into a string representation.
///
/// The raw bytes are decoded as an X.509 `Name` and rendered with that type's
/// `Display` implementation. Any trailing bytes after the name are ignored.
///
/// # Arguments
///
/// * `dn` - The DistinguishedName to parse.
///
/// # Returns
///
/// The rendered name, or the decoding error boxed as `dyn Error`.
fn parse_distinguished_name(dn: &DistinguishedName) -> Result<String, Box<dyn std::error::Error>> {
    let der: &[u8] = dn.as_ref();
    match X509Name::from_der(der) {
        Ok((_unparsed, name)) => Ok(name.to_string()),
        Err(err) => Err(err.into()),
    }
}
/// Loads certificates from a PEM file.
///
/// Every certificate PEM section in the file is returned, in file order.
///
/// # Arguments
///
/// * `file_name` - The name of the file containing the certificates.
///
/// # Errors
///
/// Returns the first I/O or PEM decoding error encountered.
pub fn load_cert(file_name: &str) -> std::io::Result<Vec<CertificateDer<'static>>> {
    let mut pem_reader = std::io::BufReader::new(File::open(file_name)?);
    let mut certificates = Vec::new();
    for certificate in rustls_pemfile::certs(&mut pem_reader) {
        certificates.push(certificate?);
    }
    Ok(certificates)
}
/// Loads a private key from a PEM file.
///
/// Returns the first private-key PEM section found in the file, as reported by
/// `rustls_pemfile::private_key`.
///
/// # Arguments
///
/// * `file_name` - The name of the file containing the private key.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or read. Returns an
/// `std::io::ErrorKind::InvalidData` error if the file contains no private key.
/// Previously this case panicked.
fn load_key(file_name: &str) -> std::io::Result<PrivateKeyDer<'static>> {
    let file = File::open(file_name)?;
    let mut reader = std::io::BufReader::new(file);
    rustls_pemfile::private_key(&mut reader)?.ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("no private key found in {}", file_name),
        )
    })
}
/// Runs the secure gRPC server.
///
/// Clients must present a certificate that chains to one of the CA
/// certificates in `cacert`. The certificate is then checked by
/// `CustomClientCertVerifier`.
///
/// TLS handshakes are driven concurrently (at most `MAX_CONCURRENT_HANDSHAKES`
/// at a time), and each one is bounded by `TLS_HANDSHAKE_TIMEOUT`. This stops
/// a client that stalls mid-handshake from blocking new connections.
///
/// # Arguments
///
/// * `port` - The port number to listen on (bound on `[::]`).
/// * `key` - Path to the server's private key file.
/// * `cert` - Path to the server's certificate file.
/// * `cacert` - Path to the CA certificate file for verifying client certificates.
/// * `_policy` - Path to the MTLS policy file (currently unused).
/// * `crls` - Optional path to the directory containing CRLs.
/// * `grpc` - The gRPC service to serve.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the key or certificate files cannot be loaded;
/// - rustls rejects the server certificate/key;
/// - the listener cannot be bound;
/// - the tonic server fails.
///
/// Failed accepts and failed or timed-out handshakes are yielded to tonic as
/// per-connection errors instead of ending the incoming stream.
///
/// # Panics
///
/// Panics if `CustomClientCertVerifier::new` cannot build the client verifier
/// from the loaded CA certificates.
pub async fn run_secure_server(
    port: u16,
    key: &str,
    cert: &str,
    cacert: &str,
    _policy: &str,
    crls: Option<&str>,
    grpc: MachineTelemetryServer<BmcTelemetryService>,
) -> Result<(), Box<dyn std::error::Error>> {
    // Upper bound on how long a client may take to complete the TLS handshake.
    const TLS_HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
    // Maximum number of TLS handshakes driven at the same time.
    const MAX_CONCURRENT_HANDSHAKES: usize = 128;
    // Pause after a failed accept() so a persistent error (e.g. descriptor
    // exhaustion) does not turn the accept loop into a busy spin.
    const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(100);

    // Load server's private key and certificate
    let cert = load_cert(cert)?;
    let key = load_key(key)?;

    // Load root CA certificates used to verify clients. A certificate the store
    // rejects is still skipped, but it is now reported instead of dropped silently.
    let mut root_cert_store = rustls::RootCertStore::empty();
    for ca_cert in load_cert(cacert)? {
        if let Err(e) = root_cert_store.add(ca_cert) {
            println!("WARNING! Failed to add CA certificate from {}: {}", cacert, e);
        }
    }
    let client_verifier = Arc::new(CustomClientCertVerifier::new(root_cert_store, crls));
    let mut tls_config = rustls::ServerConfig::builder()
        .with_client_cert_verifier(client_verifier)
        .with_single_cert(cert, key)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
    tls_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()];
    let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(tls_config));

    // Built directly instead of parsing "[::]:{port}", which needed an unwrap.
    let addr = std::net::SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, port));
    let listener = tokio::net::TcpListener::bind(addr).await?;
    println!("Server listening on {}", addr);

    // Raw TCP connections. Returning `None` on an accept error would end the
    // incoming stream and with it the server. The error is instead passed on
    // after a short back-off, and the listener keeps accepting.
    let tcp_connections = futures_util::stream::unfold(listener, |listener| async move {
        let accepted = listener.accept().await;
        if accepted.is_err() {
            tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
        }
        Some((accepted, listener))
    });

    // Handshakes run concurrently and each is bounded by a timeout. Previously
    // they ran one at a time inside the accept loop, so one idle client could
    // stall every subsequent connection indefinitely.
    let incoming_tls_stream = tcp_connections
        .map(move |accepted| {
            let acceptor = acceptor.clone();
            async move {
                let (socket, remote_addr) = match accepted {
                    Ok(conn) => conn,
                    Err(e) => return Err(e),
                };
                match tokio::time::timeout(TLS_HANDSHAKE_TIMEOUT, acceptor.accept(socket)).await {
                    Ok(Ok(stream)) => Ok(TlsStreamWithAddr {
                        stream,
                        connect_info: MtlsConnectInfo { remote_addr },
                    }),
                    Ok(Err(e)) => Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        format!("TLS Error from {}: {}", remote_addr, e),
                    )),
                    Err(_) => Err(std::io::Error::new(
                        std::io::ErrorKind::TimedOut,
                        format!("TLS handshake with {} timed out", remote_addr),
                    )),
                }
            }
        })
        .buffer_unordered(MAX_CONCURRENT_HANDSHAKES);

    tonic::transport::Server::builder()
        .add_service(grpc)
        .serve_with_incoming(incoming_tls_stream)
        .await?;
    Ok(())
}