diff --git a/beacon_node/http_api/src/lib.rs b/beacon_node/http_api/src/lib.rs index 07e20b44377..f101e35ed9a 100644 --- a/beacon_node/http_api/src/lib.rs +++ b/beacon_node/http_api/src/lib.rs @@ -68,6 +68,7 @@ use slog::{crit, debug, error, info, warn, Logger}; use slot_clock::SlotClock; use ssz::Encode; pub use state_id::StateId; +use std::collections::HashSet; use std::future::Future; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::path::PathBuf; @@ -85,13 +86,14 @@ use tokio_stream::{ wrappers::{errors::BroadcastStreamRecvError, BroadcastStream}, StreamExt, }; +use types::AttestationData; use types::{ - fork_versioned_response::EmptyMetadata, Attestation, AttestationData, AttestationShufflingId, - AttesterSlashing, BeaconStateError, ChainSpec, Checkpoint, CommitteeCache, ConfigAndPreset, - Epoch, EthSpec, ForkName, ForkVersionedResponse, Hash256, ProposerPreparationData, - ProposerSlashing, RelativeEpoch, SignedAggregateAndProof, SignedBlindedBeaconBlock, - SignedBlsToExecutionChange, SignedContributionAndProof, SignedValidatorRegistrationData, - SignedVoluntaryExit, Slot, SyncCommitteeMessage, SyncContributionData, + fork_versioned_response::EmptyMetadata, Attestation, AttestationShufflingId, AttesterSlashing, + BeaconStateError, ChainSpec, Checkpoint, CommitteeCache, ConfigAndPreset, Epoch, EthSpec, + ForkName, ForkVersionedResponse, Hash256, ProposerPreparationData, ProposerSlashing, + RelativeEpoch, SignedAggregateAndProof, SignedBlindedBeaconBlock, SignedBlsToExecutionChange, + SignedContributionAndProof, SignedValidatorRegistrationData, SignedVoluntaryExit, Slot, + SyncCommitteeMessage, SyncContributionData, }; use validator::pubkey_to_validator_index; use version::{ @@ -2032,11 +2034,11 @@ pub fn serve( chain: Arc>, query: api_types::AttestationPoolQuery| { task_spawner.blocking_response_task(Priority::P1, move || { - let query_filter = |data: &AttestationData| { + let query_filter = |data: &AttestationData, committee_indices: HashSet| { query.slot.is_none_or(|slot| slot == data.slot) && query .committee_index - .is_none_or(|index| index == data.index) + .is_none_or(|index| committee_indices.contains(&index)) }; let mut attestations = chain.op_pool.get_filtered_attestations(query_filter); @@ -2045,7 +2047,9 @@ pub fn serve( .naive_aggregation_pool .read() .iter() - .filter(|&att| query_filter(att.data())) + .filter(|&att| { + query_filter(att.data(), att.get_committee_indices_map()) + }) .cloned(), ); // Use the current slot to find the fork version, and convert all messages to the diff --git a/beacon_node/http_api/tests/tests.rs b/beacon_node/http_api/tests/tests.rs index b573302e8ba..a5aeb30e1ab 100644 --- a/beacon_node/http_api/tests/tests.rs +++ b/beacon_node/http_api/tests/tests.rs @@ -28,6 +28,7 @@ use http_api::{ use lighthouse_network::{types::SyncState, Enr, EnrExt, PeerId}; use logging::test_logger; use network::NetworkReceivers; +use operation_pool::attestation_storage::CheckpointKey; use proto_array::ExecutionStatus; use sensitive_url::SensitiveUrl; use slot_clock::SlotClock; @@ -2119,7 +2120,7 @@ impl ApiTester { self } - pub async fn test_get_beacon_pool_attestations(self) -> Self { + pub async fn test_get_beacon_pool_attestations(self) { let result = self .client .get_beacon_pool_attestations_v1(None, None) @@ -2138,9 +2139,80 @@ impl ApiTester { .await .unwrap() .data; + assert_eq!(result, expected); - self + let result_committee_index_filtered = self + .client + .get_beacon_pool_attestations_v1(None, Some(0)) + .await + .unwrap() + .data; + + let expected_committee_index_filtered = expected + .clone() + .into_iter() + .filter(|att| att.get_committee_indices_map().contains(&0)) + .collect::>(); + + assert_eq!( + result_committee_index_filtered, + expected_committee_index_filtered + ); + + let result_committee_index_filtered = self + .client + .get_beacon_pool_attestations_v1(None, Some(1)) + .await + .unwrap() + .data; + + let expected_committee_index_filtered = expected + .clone() + .into_iter() + .filter(|att| att.get_committee_indices_map().contains(&1)) + .collect::>(); + + assert_eq!( + result_committee_index_filtered, + expected_committee_index_filtered + ); + + let fork_name = self + .harness + .chain + .spec + .fork_name_at_slot::(self.harness.chain.slot().unwrap()); + + // aggregate electra attestations + if fork_name.electra_enabled() { + // Take and drop the lock in a block to avoid clippy complaining + // about taking locks across await points + { + let mut all_attestations = self.chain.op_pool.attestations.write(); + let (prev_epoch_key, curr_epoch_key) = + CheckpointKey::keys_for_state(&self.harness.get_current_state()); + all_attestations.aggregate_across_committees(prev_epoch_key); + all_attestations.aggregate_across_committees(curr_epoch_key); + } + let result_committee_index_filtered = self + .client + .get_beacon_pool_attestations_v2(None, Some(0)) + .await + .unwrap() + .data; + let mut expected = self.chain.op_pool.get_all_attestations(); + expected.extend(self.chain.naive_aggregation_pool.read().iter().cloned()); + let expected_committee_index_filtered = expected + .clone() + .into_iter() + .filter(|att| att.get_committee_indices_map().contains(&0)) + .collect::>(); + assert_eq!( + result_committee_index_filtered, + expected_committee_index_filtered + ); + } } pub async fn test_post_beacon_pool_attester_slashings_valid_v1(mut self) -> Self { @@ -6463,10 +6535,30 @@ async fn beacon_get_blocks() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn beacon_get_pools() { +async fn test_beacon_pool_attestations_electra() { + let mut config = ApiTesterConfig::default(); + config.spec.altair_fork_epoch = Some(Epoch::new(0)); + config.spec.bellatrix_fork_epoch = Some(Epoch::new(0)); + config.spec.capella_fork_epoch = Some(Epoch::new(0)); + config.spec.deneb_fork_epoch = Some(Epoch::new(0)); + config.spec.electra_fork_epoch = Some(Epoch::new(0)); + ApiTester::new_from_config(config) + .await + .test_get_beacon_pool_attestations() + .await; +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn test_beacon_pool_attestations_base() { ApiTester::new() .await .test_get_beacon_pool_attestations() + .await; +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn beacon_get_pools() { + ApiTester::new() .await .test_get_beacon_pool_attester_slashings() .await diff --git a/beacon_node/operation_pool/src/attestation_storage.rs b/beacon_node/operation_pool/src/attestation_storage.rs index 49ef5c279c5..67c24b9c7a1 100644 --- a/beacon_node/operation_pool/src/attestation_storage.rs +++ b/beacon_node/operation_pool/src/attestation_storage.rs @@ -1,6 +1,6 @@ use crate::AttestationStats; use itertools::Itertools; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; use types::{ attestation::{AttestationBase, AttestationElectra}, superstruct, AggregateSignature, Attestation, AttestationData, BeaconState, BitList, BitVector, @@ -119,6 +119,18 @@ impl CompactAttestationRef<'_, E> { } } + pub fn get_committee_indices_map(&self) -> HashSet { + match self.indexed { + CompactIndexedAttestation::Base(_) => HashSet::from([self.data.index]), + CompactIndexedAttestation::Electra(indexed_att) => indexed_att + .committee_bits + .iter() + .enumerate() + .filter_map(|(index, bit)| if bit { Some(index as u64) } else { None }) + .collect(), + } + } + pub fn clone_as_attestation(&self) -> Attestation { match self.indexed { CompactIndexedAttestation::Base(indexed_att) => Attestation::Base(AttestationBase { @@ -268,7 +280,11 @@ impl CompactIndexedAttestationElectra { } pub fn committee_index(&self) -> Option { - self.get_committee_indices().first().copied() + self.committee_bits + .iter() + .enumerate() + .find(|&(_, bit)| bit) + .map(|(index, _)| index as u64) } pub fn get_committee_indices(&self) -> Vec { diff --git a/beacon_node/operation_pool/src/lib.rs b/beacon_node/operation_pool/src/lib.rs index 584a5f9f323..ec8c6640b1d 100644 --- a/beacon_node/operation_pool/src/lib.rs +++ b/beacon_node/operation_pool/src/lib.rs @@ -1,5 +1,5 @@ mod attestation; -mod attestation_storage; +pub mod attestation_storage; mod attester_slashing; mod bls_to_execution_changes; mod max_cover; @@ -47,7 +47,7 @@ type SyncContributions = RwLock { /// Map from attestation ID (see below) to vectors of attestations. - attestations: RwLock>, + pub attestations: RwLock>, /// Map from sync aggregate ID to the best `SyncCommitteeContribution`s seen for that ID. sync_contributions: SyncContributions, /// Set of attester slashings, and the fork version they were verified against. @@ -673,12 +673,12 @@ impl OperationPool { /// This method may return objects that are invalid for block inclusion. pub fn get_filtered_attestations(&self, filter: F) -> Vec> where - F: Fn(&AttestationData) -> bool, + F: Fn(&AttestationData, HashSet) -> bool, { self.attestations .read() .iter() - .filter(|att| filter(&att.attestation_data())) + .filter(|att| filter(&att.attestation_data(), att.get_committee_indices_map())) .map(|att| att.clone_as_attestation()) .collect() } diff --git a/consensus/types/src/attestation.rs b/consensus/types/src/attestation.rs index 1485842edbd..08953770637 100644 --- a/consensus/types/src/attestation.rs +++ b/consensus/types/src/attestation.rs @@ -5,6 +5,7 @@ use derivative::Derivative; use serde::{Deserialize, Serialize}; use ssz_derive::{Decode, Encode}; use ssz_types::BitVector; +use std::collections::HashSet; use std::hash::{Hash, Hasher}; use superstruct::superstruct; use test_random_derive::TestRandom; @@ -209,6 +210,13 @@ impl Attestation { } } + pub fn get_committee_indices_map(&self) -> HashSet { + match self { + Attestation::Base(att) => HashSet::from([att.data.index]), + Attestation::Electra(att) => att.get_committee_indices().into_iter().collect(), + } + } + pub fn is_aggregation_bits_zero(&self) -> bool { match self { Attestation::Base(att) => att.aggregation_bits.is_zero(), @@ -292,7 +300,11 @@ impl AttestationRef<'_, E> { impl AttestationElectra { pub fn committee_index(&self) -> Option { - self.get_committee_indices().first().cloned() + self.committee_bits + .iter() + .enumerate() + .find(|&(_, bit)| bit) + .map(|(index, _)| index as u64) } pub fn get_aggregation_bits(&self) -> Vec {