From 146c270cbbf533b9a139cc6fcdf6f6b3228798aa Mon Sep 17 00:00:00 2001 From: Ionut Mihalcea Date: Tue, 17 Dec 2019 16:27:08 +0000 Subject: [PATCH] Improve usage of unwrap and expect This commit looks at improving our usage of unwrap and expect and the errors they could return instead. Calls that could legitimately panic during execution were changed to return an error. The rest were commented to specify why they shouldn't fail in a normal execution environment. Signed-off-by: Ionut Mihalcea --- src/abstraction/transient.rs | 16 ++++----- src/lib.rs | 68 +++++++++++++----------------------- src/response_code.rs | 11 +++++- src/utils.rs | 64 +++++++++++++++++++-------------- 4 files changed, 81 insertions(+), 78 deletions(-) diff --git a/src/abstraction/transient.rs b/src/abstraction/transient.rs index 31ca21d1..584685b3 100644 --- a/src/abstraction/transient.rs +++ b/src/abstraction/transient.rs @@ -35,7 +35,7 @@ impl TransientObjectContext { if root_key_auth_size > 32 { return Err(Error::local_error(ErrorKind::WrongParamSize)); } - if root_key_size < 1024 { + if root_key_size < 1024 || root_key_size > 4096 { return Err(Error::local_error(ErrorKind::WrongParamSize)); } let mut context = Context::new(tcti)?; @@ -50,7 +50,7 @@ impl TransientObjectContext { let root_key_handle = context.create_primary_key( ESYS_TR_RH_OWNER, - &get_rsa_public(true, true, false, root_key_size.try_into().unwrap()), + &get_rsa_public(true, true, false, root_key_size.try_into().unwrap()), // should not fail on supported targets, given the checks above &root_key_auth, &[], &[], @@ -83,7 +83,7 @@ impl TransientObjectContext { if auth_size > 32 { return Err(Error::local_error(ErrorKind::WrongParamSize)); } - if key_size < 1024 { + if key_size < 1024 || key_size > 4096 { return Err(Error::local_error(ErrorKind::WrongParamSize)); } let key_auth = if auth_size > 0 { @@ -95,7 +95,7 @@ impl TransientObjectContext { self.set_session_attrs()?; let (key_priv, key_pub) = self.context.create_key( self.root_key_handle, - &get_rsa_public(false, false, true, key_size.try_into().unwrap()), + &get_rsa_public(false, false, true, key_size.try_into().unwrap()), // should not fail on valid targets, given the checks above &key_auth, &[], &[], @@ -122,7 +122,7 @@ impl TransientObjectContext { let pk = TPMU_PUBLIC_ID { rsa: TPM2B_PUBLIC_KEY_RSA { - size: public_key.len().try_into().unwrap(), + size: public_key.len().try_into().unwrap(), // should not fail on valid targets, given the checks above buffer: pk_buffer, }, }; @@ -131,7 +131,7 @@ impl TransientObjectContext { false, false, true, - u16::try_from(public_key.len()).unwrap() * 8u16, + u16::try_from(public_key.len()).unwrap() * 8u16, // should not fail on valid targets, given the checks above ); public.publicArea.unique = pk; @@ -160,7 +160,7 @@ impl TransientObjectContext { let key = match PublicIdUnion::from_public(&key_pub_id) { PublicIdUnion::Rsa(pub_key) => { let mut key = pub_key.buffer.to_vec(); - key.truncate(pub_key.size.try_into().unwrap()); + key.truncate(pub_key.size.try_into().unwrap()); // should not fail on supported targets key } _ => unimplemented!(), @@ -257,7 +257,7 @@ mod tests { let mut ctx = TransientObjectContext::new(Tcti::Mssim, 2048, 32, &[]).unwrap(); for _ in 0..4 { let (key, auth) = ctx.create_rsa_signing_key(2048, 16).unwrap(); - let mut signature = ctx.sign(key.clone(), &auth, &HASH).unwrap(); + let signature = ctx.sign(key.clone(), &auth, &HASH).unwrap(); let pub_key = ctx.read_public_key(key.clone()).unwrap(); let pub_key = ctx.load_external_rsa_public_key(&pub_key).unwrap(); ctx.verify_signature(pub_key, &HASH, signature).unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 79db97e3..bd0f3219 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,10 +54,13 @@ use utils::{Signature, TpmaSession, TpmsContext}; #[macro_use] macro_rules! wrap_buffer { ($buf:expr, $buf_type:ty, $buf_size:expr) => {{ + if $buf.len() > $buf_size { + return Err(Error::local_error(ErrorKind::WrongParamSize)); + } let mut buffer = [0u8; $buf_size]; buffer[..$buf.len()].clone_from_slice(&$buf[..$buf.len()]); let mut buf_struct: $buf_type = Default::default(); - buf_struct.size = $buf.len().try_into().unwrap(); + buf_struct.size = $buf.len().try_into().unwrap(); // should not fail since the length is checked above buf_struct.buffer = buffer; buf_struct }}; @@ -115,7 +118,7 @@ impl Context { let ret = unsafe { tss2_esys::Esys_Initialize( &mut esys_context, - tcti_context.as_mut().unwrap().as_mut_ptr(), + tcti_context.as_mut().unwrap().as_mut_ptr(), // will not panic as per how tcti_context is initialised null_mut(), ) }; @@ -162,10 +165,6 @@ impl Context { symmetric: TPMT_SYM_DEF, auth_hash: TPMI_ALG_HASH, ) -> Result { - if nonce.len() > 64 { - return Err(Error::local_error(ErrorKind::WrongParamSize)); - } - let nonce_caller = wrap_buffer!(nonce, TPM2B_NONCE, 64); let mut sess = ESYS_TR_NONE; @@ -218,30 +217,26 @@ impl Context { outside_info: &[u8], creation_pcrs: &[TPMS_PCR_SELECTION], ) -> Result { - if auth_value.len() > 64 - || initial_data.len() > 256 - || outside_info.len() > 64 - || creation_pcrs.len() > 16 - { - return Err(Error::local_error(ErrorKind::WrongParamSize)); - } - let sensitive_create = TPM2B_SENSITIVE_CREATE { size: std::mem::size_of::() .try_into() - .unwrap(), + .unwrap(), // will not fail on targets of at least 16 bits sensitive: TPMS_SENSITIVE_CREATE { userAuth: wrap_buffer!(auth_value, TPM2B_AUTH, 64), data: wrap_buffer!(initial_data, TPM2B_SENSITIVE_DATA, 256), }, }; - let outside_info = wrap_buffer!(outside_info, TPM2B_DATA, 64); + + if creation_pcrs.len() > 16 { + return Err(Error::local_error(ErrorKind::WrongParamSize)); + } + let mut creation_pcrs_buffer = [Default::default(); 16]; creation_pcrs_buffer[..creation_pcrs.len()] .clone_from_slice(&creation_pcrs[..creation_pcrs.len()]); let creation_pcrs = TPML_PCR_SELECTION { - count: creation_pcrs.len().try_into().unwrap(), + count: creation_pcrs.len().try_into().unwrap(), // will not fail given the len checks above pcrSelections: creation_pcrs_buffer, }; @@ -297,18 +292,10 @@ impl Context { outside_info: &[u8], creation_pcrs: &[TPMS_PCR_SELECTION], ) -> Result<(TPM2B_PRIVATE, TPM2B_PUBLIC)> { - if auth_value.len() > 64 - || initial_data.len() > 256 - || outside_info.len() > 64 - || creation_pcrs.len() > 16 - { - return Err(Error::local_error(ErrorKind::WrongParamSize)); - } - let sensitive_create = TPM2B_SENSITIVE_CREATE { size: std::mem::size_of::() .try_into() - .unwrap(), + .unwrap(), // will not fail on targets of at least 16 bits sensitive: TPMS_SENSITIVE_CREATE { userAuth: wrap_buffer!(auth_value, TPM2B_AUTH, 64), data: wrap_buffer!(initial_data, TPM2B_SENSITIVE_DATA, 256), @@ -317,11 +304,14 @@ impl Context { let outside_info = wrap_buffer!(outside_info, TPM2B_DATA, 64); + if creation_pcrs.len() > 16 { + return Err(Error::local_error(ErrorKind::WrongParamSize)); + } let mut creation_pcrs_buffer = [Default::default(); 16]; creation_pcrs_buffer[..creation_pcrs.len()] .clone_from_slice(&creation_pcrs[..creation_pcrs.len()]); let creation_pcrs = TPML_PCR_SELECTION { - count: creation_pcrs.len().try_into().unwrap(), + count: creation_pcrs.len().try_into().unwrap(), // will not fail given the len checks above pcrSelections: creation_pcrs_buffer, }; @@ -403,9 +393,6 @@ impl Context { scheme: TPMT_SIG_SCHEME, validation: &TPMT_TK_HASHCHECK, ) -> Result { - if digest.len() > 64 { - return Err(Error::local_error(ErrorKind::WrongParamSize)); - } let mut signature = null_mut(); let digest = wrap_buffer!(digest, TPM2B_DIGEST, 64); let ret = unsafe { @@ -438,9 +425,6 @@ impl Context { digest: &[u8], signature: &TPMT_SIGNATURE, ) -> Result { - if digest.len() > 64 { - return Err(Error::local_error(ErrorKind::WrongParamSize)); - } let mut validation = null_mut(); let digest = wrap_buffer!(digest, TPM2B_DIGEST, 64); let ret = unsafe { @@ -577,7 +561,7 @@ impl Context { let ret = Error::from_tss_rc(ret); if ret.is_success() { let context = unsafe { MBox::::from_raw(context) }; - Ok((*context).into()) + Ok((*context).try_into()?) } else { error!("Error in saving context: {}.", ret); Err(ret) @@ -612,7 +596,9 @@ impl Context { self.sessions.0, self.sessions.1, self.sessions.2, - num_bytes.try_into().unwrap(), + num_bytes + .try_into() + .or_else(|_| Err(Error::local_error(ErrorKind::WrongParamSize)))?, &mut buffer, ) }; @@ -621,7 +607,7 @@ impl Context { if ret.is_success() { let buffer = unsafe { MBox::from_raw(buffer) }; let mut random = buffer.buffer.to_vec(); - random.truncate(buffer.size.try_into().unwrap()); + random.truncate(buffer.size.try_into().unwrap()); // should not panic given the TryInto above Ok(random) } else { error!("Error in flushing context: {}.", ret); @@ -630,10 +616,6 @@ impl Context { } pub fn set_handle_auth(&mut self, handle: ESYS_TR, auth_value: &[u8]) -> Result<()> { - if auth_value.len() > 64 { - return Err(Error::local_error(ErrorKind::WrongParamSize)); - } - let auth = wrap_buffer!(auth_value, TPM2B_AUTH, 64); let ret = unsafe { Esys_TR_SetAuth(self.mut_context(), handle, &auth) }; let ret = Error::from_tss_rc(ret); @@ -657,7 +639,7 @@ impl Context { } fn mut_context(&mut self) -> *mut ESYS_CONTEXT { - self.esys_context.as_mut().unwrap().as_mut_ptr() + self.esys_context.as_mut().unwrap().as_mut_ptr() // will only fail if called from Drop after .take() } } @@ -673,8 +655,8 @@ impl Drop for Context { } }); - let esys_context = self.esys_context.take().unwrap(); - let tcti_context = self.tcti_context.take().unwrap(); + let esys_context = self.esys_context.take().unwrap(); // should not fail based on how the context is initialised/used + let tcti_context = self.tcti_context.take().unwrap(); // should not fail based on how the context is initialised/used // Close the TCTI context. unsafe { diff --git a/src/response_code.rs b/src/response_code.rs index d8658c13..fe3df75b 100644 --- a/src/response_code.rs +++ b/src/response_code.rs @@ -328,7 +328,7 @@ impl std::fmt::Display for Tss2ResponseCode { if kind.is_none() { return write!(f, "response code not recognized"); } - match self.kind().unwrap() { + match self.kind().unwrap() { // should not panic, given the check above Tss2ResponseCodeKind::Success => write!(f, "success"), Tss2ResponseCodeKind::TpmVendorSpecific => write!(f, "vendor specific error: {}", self.error_number()), // Format Zero @@ -475,6 +475,8 @@ impl std::fmt::Display for Error { #[derive(Copy, Clone, PartialEq, Debug)] pub enum WrapperErrorKind { WrongParamSize, + ParamsMissing, + InconsistentParams, } impl std::fmt::Display for WrapperErrorKind { @@ -483,6 +485,13 @@ impl std::fmt::Display for WrapperErrorKind { WrapperErrorKind::WrongParamSize => { write!(f, "parameter provided is of the wrong size") } + WrapperErrorKind::ParamsMissing => { + write!(f, "some of the required parameters were not provided") + } + WrapperErrorKind::InconsistentParams => write!( + f, + "the provided parameters have inconsistent values or variants" + ), } } } diff --git a/src/utils.rs b/src/utils.rs index 964e2e28..445c1dfb 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -71,7 +71,7 @@ impl Tpm2BPublicBuilder { self } - pub fn build(mut self) -> TPM2B_PUBLIC { + pub fn build(mut self) -> Result { match self.type_ { Some(TPM2_ALG_RSA) => { // RSA key @@ -80,9 +80,9 @@ impl Tpm2BPublicBuilder { if let Some(PublicParmsUnion::RsaDetail(parms)) = self.parameters { parameters = TPMU_PUBLIC_PARMS { rsaDetail: parms }; } else if self.parameters.is_none() { - panic!("No key parameters provided"); + return Err(Error::local_error(WrapperErrorKind::ParamsMissing)); } else { - panic!("Wrong parameter type provided"); + return Err(Error::local_error(WrapperErrorKind::InconsistentParams)); } if let Some(PublicIdUnion::Rsa(rsa_unique)) = self.unique { @@ -90,7 +90,7 @@ impl Tpm2BPublicBuilder { } else if self.unique.is_none() { unique = Default::default(); } else { - panic!("Wrong unique type provided"); + return Err(Error::local_error(WrapperErrorKind::InconsistentParams)); } if self.object_attributes.sign_encrypt() && self.object_attributes.decrypt() { @@ -111,19 +111,21 @@ impl Tpm2BPublicBuilder { parameters.rsaDetail.symmetric.algorithm = TPM2_ALG_NULL; } - TPM2B_PUBLIC { + Ok(TPM2B_PUBLIC { size: std::mem::size_of::() .try_into() - .expect("Failed to convert usize to u16"), + .expect("Failed to convert usize to u16"), // should not fail on valid targets publicArea: TPMT_PUBLIC { - type_: self.type_.expect("Object type not provided"), + type_: self + .type_ + .ok_or_else(|| Error::local_error(WrapperErrorKind::ParamsMissing))?, nameAlg: self.name_alg, objectAttributes: self.object_attributes.0, authPolicy: self.auth_policy, parameters, unique, }, - } + }) } _ => unimplemented!(), } @@ -177,16 +179,16 @@ impl TpmsRsaParmsBuilder { self } - pub fn build(self) -> TPMS_RSA_PARMS { - TPMS_RSA_PARMS { + pub fn build(self) -> Result { + Ok(TPMS_RSA_PARMS { symmetric: self.symmetric, scheme: self .scheme - .expect("Scheme was not provided") + .ok_or_else(|| Error::local_error(WrapperErrorKind::ParamsMissing))? .get_rsa_scheme(), keyBits: self.key_bits, exponent: self.exponent, - } + }) } } @@ -225,7 +227,7 @@ impl TpmtSymDefBuilder { self } - pub fn build_object(self) -> TPMT_SYM_DEF_OBJECT { + pub fn build_object(self) -> Result { let key_bits; let mode; match self.algorithm { @@ -263,11 +265,13 @@ impl TpmtSymDefBuilder { _ => unimplemented!(), } - TPMT_SYM_DEF_OBJECT { - algorithm: self.algorithm.expect("No algorithm provided"), + Ok(TPMT_SYM_DEF_OBJECT { + algorithm: self + .algorithm + .ok_or_else(|| Error::local_error(WrapperErrorKind::ParamsMissing))?, keyBits: key_bits, mode, - } + }) } pub fn aes_256_cfb() -> TPMT_SYM_DEF { @@ -469,7 +473,7 @@ impl TryFrom for TPMT_SIGNATURE { rsassa: TPMS_SIGNATURE_RSA { hash: hash_alg, sig: TPM2B_PUBLIC_KEY_RSA { - size: len.try_into().expect("Failed to convert length to u16"), // Should never panic + size: len.try_into().expect("Failed to convert length to u16"), // Should never panic as per the check above buffer, }, }, @@ -517,18 +521,24 @@ pub struct TpmsContext { context_blob: Vec, } -impl From for TpmsContext { - fn from(tss2_context: TPMS_CONTEXT) -> Self { +impl TryFrom for TpmsContext { + type Error = Error; + + fn try_from(tss2_context: TPMS_CONTEXT) -> Result { let mut context = TpmsContext { sequence: tss2_context.sequence, saved_handle: tss2_context.savedHandle, hierarchy: tss2_context.hierarchy, context_blob: tss2_context.contextBlob.buffer.to_vec(), }; - context - .context_blob - .truncate(tss2_context.contextBlob.size.try_into().unwrap()); - context + context.context_blob.truncate( + tss2_context + .contextBlob + .size + .try_into() + .or_else(|_| Err(Error::local_error(WrapperErrorKind::WrongParamSize)))?, + ); + Ok(context) } } @@ -538,7 +548,7 @@ impl TryFrom for TPMS_CONTEXT { fn try_from(context: TpmsContext) -> Result { let buffer_size = context.context_blob.len(); if buffer_size > 5188 { - return Err(Error::from_tss_rc(TPM2_RC_SIZE)); + return Err(Error::local_error(WrapperErrorKind::WrongParamSize)); } let mut buffer = [0u8; 5188]; for (i, val) in context.context_blob.into_iter().enumerate() { @@ -549,7 +559,7 @@ impl TryFrom for TPMS_CONTEXT { savedHandle: context.saved_handle, hierarchy: context.hierarchy, contextBlob: TPM2B_CONTEXT_DATA { - size: buffer_size.try_into().unwrap(), + size: buffer_size.try_into().unwrap(), // should not panic given the check above buffer, }, }) @@ -563,7 +573,8 @@ pub fn get_rsa_public(restricted: bool, decrypt: bool, sign: bool, key_bits: u16 .with_symmetric(symmetric) .with_key_bits(key_bits) .with_scheme(scheme) - .build(); + .build() + .unwrap(); // should not fail as we control the params let mut object_attributes = ObjectAttributes(0); object_attributes.set_fixed_tpm(true); object_attributes.set_fixed_parent(true); @@ -579,4 +590,5 @@ pub fn get_rsa_public(restricted: bool, decrypt: bool, sign: bool, key_bits: u16 .with_object_attributes(object_attributes) .with_parms(PublicParmsUnion::RsaDetail(rsa_parms)) .build() + .unwrap() // should not fail as we control the params }