Skip to content
This repository was archived by the owner on Mar 11, 2025. It is now read-only.

token-2022: Introduce PackedSizeOf and relax BaseState trait #6332

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions token/program-2022/src/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use {
transfer_fee::{TransferFeeAmount, TransferFeeConfig},
transfer_hook::{TransferHook, TransferHookAccount},
},
state::{Account, Mint, Multisig},
state::{Account, Mint, Multisig, PackedSizeOf},
},
bytemuck::{Pod, Zeroable},
num_enum::{IntoPrimitive, TryFromPrimitive},
Expand Down Expand Up @@ -298,7 +298,7 @@ fn type_and_tlv_indices<S: BaseState>(
if rest_input.is_empty() {
Ok(None)
} else {
let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::LEN);
let account_type_index = BASE_ACCOUNT_LENGTH.saturating_sub(S::SIZE_OF);
// check padding is all zeroes
let tlv_start_index = account_type_index.saturating_add(size_of::<AccountType>());
if rest_input.len() <= tlv_start_index {
Expand Down Expand Up @@ -427,7 +427,7 @@ pub trait BaseStateWithExtensions<S: BaseState> {
fn try_get_account_len(&self) -> Result<usize, ProgramError> {
let tlv_info = get_tlv_data_info(self.get_tlv_data())?;
if tlv_info.extension_types.is_empty() {
Ok(S::LEN)
Ok(S::SIZE_OF)
} else {
let total_len = tlv_info
.used_len
Expand Down Expand Up @@ -468,13 +468,13 @@ pub struct StateWithExtensionsOwned<S: BaseState> {
/// Raw TLV data, deserialized on demand
tlv_data: Vec<u8>,
}
impl<S: BaseState> StateWithExtensionsOwned<S> {
impl<S: BaseState + Pack> StateWithExtensionsOwned<S> {
/// Unpack base state, leaving the extension data as a slice
///
/// Fails if the base state is not initialized.
pub fn unpack(mut input: Vec<u8>) -> Result<Self, ProgramError> {
check_min_len_and_not_multisig(&input, S::LEN)?;
let mut rest = input.split_off(S::LEN);
check_min_len_and_not_multisig(&input, S::SIZE_OF)?;
let mut rest = input.split_off(S::SIZE_OF);
let base = S::unpack(&input)?;
if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(&rest)? {
// type_and_tlv_indices() checks that returned indexes are within range
Expand All @@ -501,19 +501,19 @@ impl<S: BaseState> BaseStateWithExtensions<S> for StateWithExtensionsOwned<S> {
/// Encapsulates immutable base state data (mint or account) with possible
/// extensions
#[derive(Debug, PartialEq)]
pub struct StateWithExtensions<'data, S: BaseState> {
pub struct StateWithExtensions<'data, S: BaseState + Pack> {
/// Unpacked base data
pub base: S,
/// Slice of data containing all TLV data, deserialized on demand
tlv_data: &'data [u8],
}
impl<'data, S: BaseState> StateWithExtensions<'data, S> {
impl<'data, S: BaseState + Pack> StateWithExtensions<'data, S> {
/// Unpack base state, leaving the extension data as a slice
///
/// Fails if the base state is not initialized.
pub fn unpack(input: &'data [u8]) -> Result<Self, ProgramError> {
check_min_len_and_not_multisig(input, S::LEN)?;
let (base_data, rest) = input.split_at(S::LEN);
check_min_len_and_not_multisig(input, S::SIZE_OF)?;
let (base_data, rest) = input.split_at(S::SIZE_OF);
let base = S::unpack(base_data)?;
if let Some((account_type_index, tlv_start_index)) = type_and_tlv_indices::<S>(rest)? {
// type_and_tlv_indices() checks that returned indexes are within range
Expand All @@ -532,7 +532,7 @@ impl<'data, S: BaseState> StateWithExtensions<'data, S> {
}
}
}
impl<'a, S: BaseState> BaseStateWithExtensions<S> for StateWithExtensions<'a, S> {
impl<'a, S: BaseState + Pack> BaseStateWithExtensions<S> for StateWithExtensions<'a, S> {
fn get_tlv_data(&self) -> &[u8] {
self.tlv_data
}
Expand Down Expand Up @@ -784,13 +784,13 @@ pub struct StateWithExtensionsMut<'data, S: BaseState> {
/// Slice of data containing all TLV data, deserialized on demand
tlv_data: &'data mut [u8],
}
impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
impl<'data, S: BaseState + Pack> StateWithExtensionsMut<'data, S> {
/// Unpack base state, leaving the extension data as a mutable slice
///
/// Fails if the base state is not initialized.
pub fn unpack(input: &'data mut [u8]) -> Result<Self, ProgramError> {
check_min_len_and_not_multisig(input, S::LEN)?;
let (base_data, rest) = input.split_at_mut(S::LEN);
check_min_len_and_not_multisig(input, S::SIZE_OF)?;
let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
let base = S::unpack(base_data)?;
let (account_type, tlv_data) = unpack_type_and_tlv_data::<S>(rest)?;
Ok(Self {
Expand All @@ -806,8 +806,8 @@ impl<'data, S: BaseState> StateWithExtensionsMut<'data, S> {
///
/// Fails if the base state has already been initialized.
pub fn unpack_uninitialized(input: &'data mut [u8]) -> Result<Self, ProgramError> {
check_min_len_and_not_multisig(input, S::LEN)?;
let (base_data, rest) = input.split_at_mut(S::LEN);
check_min_len_and_not_multisig(input, S::SIZE_OF)?;
let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
let base = S::unpack_unchecked(base_data)?;
if base.is_initialized() {
return Err(TokenError::AlreadyInUse.into());
Expand Down Expand Up @@ -892,8 +892,8 @@ fn unpack_uninitialized_type_and_tlv_data<S: BaseState>(
/// This method assumes that the `base_data` has already been packed with data
/// of the desired type.
pub fn set_account_type<S: BaseState>(input: &mut [u8]) -> Result<(), ProgramError> {
check_min_len_and_not_multisig(input, S::LEN)?;
let (base_data, rest) = input.split_at_mut(S::LEN);
check_min_len_and_not_multisig(input, S::SIZE_OF)?;
let (base_data, rest) = input.split_at_mut(S::SIZE_OF);
if S::ACCOUNT_TYPE == AccountType::Account && !is_initialized_account(base_data)? {
return Err(ProgramError::InvalidAccountData);
}
Expand Down Expand Up @@ -1113,7 +1113,7 @@ impl ExtensionType {
extension_types: &[Self],
) -> Result<usize, ProgramError> {
if extension_types.is_empty() {
Ok(S::LEN)
Ok(S::SIZE_OF)
} else {
let extension_size = Self::try_get_total_tlv_len(extension_types)?;
let total_len = extension_size.saturating_add(BASE_ACCOUNT_AND_TYPE_LENGTH);
Expand Down Expand Up @@ -1216,7 +1216,7 @@ impl ExtensionType {
}

/// Trait for base states, specifying the associated enum
pub trait BaseState: Pack + IsInitialized {
pub trait BaseState: PackedSizeOf + IsInitialized {
/// Associated extension type enum, checked at the start of TLV entries
const ACCOUNT_TYPE: AccountType;
}
Expand Down Expand Up @@ -1287,7 +1287,7 @@ impl Extension for AccountPaddingTest {
/// NOTE: Since this function deals with fixed-size extensions, it does not
/// handle _decreasing_ the size of an account's data buffer, like the function
/// `alloc_and_serialize_variable_len_extension` does.
pub fn alloc_and_serialize<S: BaseState, V: Default + Extension + Pod>(
pub fn alloc_and_serialize<S: BaseState + Pack, V: Default + Extension + Pod>(
account_info: &AccountInfo,
new_extension: &V,
overwrite: bool,
Expand Down Expand Up @@ -1324,7 +1324,10 @@ pub fn alloc_and_serialize<S: BaseState, V: Default + Extension + Pod>(
///
/// NOTE: Unlike the `reallocate` instruction, this function will reduce the
/// size of an account if it has too many bytes allocated for the given value.
pub fn alloc_and_serialize_variable_len_extension<S: BaseState, V: Extension + VariableLenPack>(
pub fn alloc_and_serialize_variable_len_extension<
S: BaseState + Pack,
V: Extension + VariableLenPack,
>(
account_info: &AccountInfo,
new_extension: &V,
overwrite: bool,
Expand Down
17 changes: 17 additions & 0 deletions token/program-2022/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ use {
},
};

/// Simplified version of the `Pack` trait which only gives the size of the
/// packed struct. Useful when a function doesn't need a type to implement all
/// of `Pack`, but a size is still needed.
pub trait PackedSizeOf {
/// The packed size of the struct
const SIZE_OF: usize;
}

/// Mint data.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
Expand Down Expand Up @@ -86,6 +94,9 @@ impl Pack for Mint {
pack_coption_key(freeze_authority, freeze_authority_dst);
}
}
impl PackedSizeOf for Mint {
const SIZE_OF: usize = Self::LEN;
}

/// Account data.
#[repr(C)]
Expand Down Expand Up @@ -184,6 +195,9 @@ impl Pack for Account {
pack_coption_key(close_authority, close_authority_dst);
}
}
impl PackedSizeOf for Account {
const SIZE_OF: usize = Self::LEN;
}

/// Account state.
#[repr(u8)]
Expand Down Expand Up @@ -254,6 +268,9 @@ impl Pack for Multisig {
}
}
}
impl PackedSizeOf for Multisig {
const SIZE_OF: usize = Self::LEN;
}

// Helpers
pub(crate) fn pack_coption_key(src: &COption<Pubkey>, dst: &mut [u8; 36]) {
Expand Down