Skip to content

refactor: replace more GAT-based async trait with RPITIT #12271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions src/batch/src/exchange_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::fmt::Debug;
use std::future::Future;

use risingwave_common::array::DataChunk;
use risingwave_common::error::Result;

use crate::execution::grpc_exchange::GrpcExchangeSource;
use crate::execution::local_exchange::LocalExchangeSource;
Expand All @@ -24,11 +25,7 @@ use crate::task::TaskId;

/// Each `ExchangeSource` maps to one task, it takes the execution result from task chunk by chunk.
pub trait ExchangeSource: Send + Debug {
type TakeDataFuture<'a>: Future<Output = risingwave_common::error::Result<Option<DataChunk>>>
+ 'a
where
Self: 'a;
fn take_data(&mut self) -> Self::TakeDataFuture<'_>;
fn take_data(&mut self) -> impl Future<Output = Result<Option<DataChunk>>> + '_;

/// Get upstream task id.
fn get_task_id(&self) -> TaskId;
Expand All @@ -42,9 +39,7 @@ pub enum ExchangeSourceImpl {
}

impl ExchangeSourceImpl {
pub(crate) async fn take_data(
&mut self,
) -> risingwave_common::error::Result<Option<DataChunk>> {
pub(crate) async fn take_data(&mut self) -> Result<Option<DataChunk>> {
match self {
ExchangeSourceImpl::Grpc(grpc) => grpc.take_data().await,
ExchangeSourceImpl::Local(local) => local.take_data().await,
Expand Down
35 changes: 15 additions & 20 deletions src/batch/src/execution/grpc_exchange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

use std::fmt::{Debug, Formatter};
use std::future::Future;

use futures::StreamExt;
use risingwave_common::array::DataChunk;
Expand Down Expand Up @@ -73,26 +72,22 @@ impl Debug for GrpcExchangeSource {
}

impl ExchangeSource for GrpcExchangeSource {
type TakeDataFuture<'a> = impl Future<Output = Result<Option<DataChunk>>> + 'a;

fn take_data(&mut self) -> Self::TakeDataFuture<'_> {
async {
let res = match self.stream.next().await {
None => {
return Ok(None);
}
Some(r) => r,
};
let task_data = res?;
let data = DataChunk::from_protobuf(task_data.get_record_batch()?)?.compact();
trace!(
"Receiver taskOutput = {:?}, data = {:?}",
self.task_output_id,
data
);
async fn take_data(&mut self) -> Result<Option<DataChunk>> {
let res = match self.stream.next().await {
None => {
return Ok(None);
}
Some(r) => r,
};
let task_data = res?;
let data = DataChunk::from_protobuf(task_data.get_record_batch()?)?.compact();
trace!(
"Receiver taskOutput = {:?}, data = {:?}",
self.task_output_id,
data
);

Ok(Some(data))
}
Ok(Some(data))
}

fn get_task_id(&self) -> TaskId {
Expand Down
31 changes: 13 additions & 18 deletions src/batch/src/execution/local_exchange.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

use std::fmt::{Debug, Formatter};
use std::future::Future;

use risingwave_common::array::DataChunk;
use risingwave_common::error::Result;
Expand Down Expand Up @@ -52,23 +51,19 @@ impl Debug for LocalExchangeSource {
}

impl ExchangeSource for LocalExchangeSource {
type TakeDataFuture<'a> = impl Future<Output = Result<Option<DataChunk>>> + 'a;

fn take_data(&mut self) -> Self::TakeDataFuture<'_> {
async {
let ret = self.task_output.direct_take_data().await?;
if let Some(data) = ret {
let data = data.compact();
trace!(
"Receiver task: {:?}, source task output: {:?}, data: {:?}",
self.task_id,
self.task_output.id(),
data
);
Ok(Some(data))
} else {
Ok(None)
}
async fn take_data(&mut self) -> Result<Option<DataChunk>> {
let ret = self.task_output.direct_take_data().await?;
if let Some(data) = ret {
let data = data.compact();
trace!(
"Receiver task: {:?}, source task output: {:?}, data: {:?}",
self.task_id,
self.task_output.id(),
data
);
Ok(Some(data))
} else {
Ok(None)
}
}

Expand Down
15 changes: 5 additions & 10 deletions src/batch/src/executor/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

use std::collections::VecDeque;
use std::future::Future;

use assert_matches::assert_matches;
use futures::StreamExt;
Expand Down Expand Up @@ -246,15 +245,11 @@ impl FakeExchangeSource {
}

impl ExchangeSource for FakeExchangeSource {
type TakeDataFuture<'a> = impl Future<Output = Result<Option<DataChunk>>> + 'a;

fn take_data(&mut self) -> Self::TakeDataFuture<'_> {
async {
if let Some(chunk) = self.chunks.pop() {
Ok(chunk)
} else {
Ok(None)
}
async fn take_data(&mut self) -> Result<Option<DataChunk>> {
if let Some(chunk) = self.chunks.pop() {
Ok(chunk)
} else {
Ok(None)
}
}

Expand Down
1 change: 1 addition & 0 deletions src/batch/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#![feature(result_option_inspect)]
#![feature(assert_matches)]
#![feature(lazy_cell)]
#![feature(return_position_impl_trait_in_trait)]

mod error;
pub mod exchange_source;
Expand Down
58 changes: 23 additions & 35 deletions src/connector/src/source/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,13 +200,9 @@ impl MySqlOffset {
}

pub trait ExternalTableReader {
type CdcOffsetFuture<'a>: Future<Output = ConnectorResult<CdcOffset>> + Send + 'a
where
Self: 'a;

fn get_normalized_table_name(&self, table_name: &SchemaTableName) -> String;

fn current_cdc_offset(&self) -> Self::CdcOffsetFuture<'_>;
fn current_cdc_offset(&self) -> impl Future<Output = ConnectorResult<CdcOffset>> + Send + '_;

fn parse_binlog_offset(&self, offset: &str) -> ConnectorResult<CdcOffset>;

Expand Down Expand Up @@ -248,32 +244,28 @@ pub struct ExternalTableConfig {
}

impl ExternalTableReader for MySqlExternalTableReader {
type CdcOffsetFuture<'a> = impl Future<Output = ConnectorResult<CdcOffset>> + 'a;

fn get_normalized_table_name(&self, table_name: &SchemaTableName) -> String {
format!("`{}`", table_name.table_name)
}

fn current_cdc_offset(&self) -> Self::CdcOffsetFuture<'_> {
async move {
let mut conn = self
.pool
.get_conn()
.await
.map_err(|e| ConnectorError::Connection(anyhow!(e)))?;

let sql = "SHOW MASTER STATUS".to_string();
let mut rs = conn.query::<mysql_async::Row, _>(sql).await?;
let row = rs
.iter_mut()
.exactly_one()
.map_err(|e| ConnectorError::Internal(anyhow!("read binlog error: {}", e)))?;

Ok(CdcOffset::MySql(MySqlOffset {
filename: row.take("File").unwrap(),
position: row.take("Position").unwrap(),
}))
}
async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
let mut conn = self
.pool
.get_conn()
.await
.map_err(|e| ConnectorError::Connection(anyhow!(e)))?;

let sql = "SHOW MASTER STATUS".to_string();
let mut rs = conn.query::<mysql_async::Row, _>(sql).await?;
let row = rs
.iter_mut()
.exactly_one()
.map_err(|e| ConnectorError::Internal(anyhow!("read binlog error: {}", e)))?;

Ok(CdcOffset::MySql(MySqlOffset {
filename: row.take("File").unwrap(),
position: row.take("Position").unwrap(),
}))
}

fn parse_binlog_offset(&self, offset: &str) -> ConnectorResult<CdcOffset> {
Expand Down Expand Up @@ -478,21 +470,17 @@ impl MySqlExternalTableReader {
}

impl ExternalTableReader for ExternalTableReaderImpl {
type CdcOffsetFuture<'a> = impl Future<Output = ConnectorResult<CdcOffset>> + 'a;

fn get_normalized_table_name(&self, table_name: &SchemaTableName) -> String {
match self {
ExternalTableReaderImpl::MySql(mysql) => mysql.get_normalized_table_name(table_name),
ExternalTableReaderImpl::Mock(mock) => mock.get_normalized_table_name(table_name),
}
}

fn current_cdc_offset(&self) -> Self::CdcOffsetFuture<'_> {
async move {
match self {
ExternalTableReaderImpl::MySql(mysql) => mysql.current_cdc_offset().await,
ExternalTableReaderImpl::Mock(mock) => mock.current_cdc_offset().await,
}
async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
match self {
ExternalTableReaderImpl::MySql(mysql) => mysql.current_cdc_offset().await,
ExternalTableReaderImpl::Mock(mock) => mock.current_cdc_offset().await,
}
}

Expand Down
24 changes: 10 additions & 14 deletions src/connector/src/source/mock_external_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::future::Future;
use std::sync::atomic::AtomicUsize;

use futures::stream::BoxStream;
Expand Down Expand Up @@ -91,24 +90,21 @@ impl MockExternalTableReader {
}

impl ExternalTableReader for MockExternalTableReader {
type CdcOffsetFuture<'a> = impl Future<Output = ConnectorResult<CdcOffset>> + 'a;

fn get_normalized_table_name(&self, _table_name: &SchemaTableName) -> String {
"`mock_table`".to_string()
}

fn current_cdc_offset(&self) -> Self::CdcOffsetFuture<'_> {
async fn current_cdc_offset(&self) -> ConnectorResult<CdcOffset> {
static IDX: AtomicUsize = AtomicUsize::new(0);
async move {
let idx = IDX.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if idx < self.binlog_watermarks.len() {
Ok(CdcOffset::MySql(self.binlog_watermarks[idx].clone()))
} else {
Ok(CdcOffset::MySql(MySqlOffset {
filename: "1.binlog".to_string(),
position: u64::MAX,
}))
}

let idx = IDX.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if idx < self.binlog_watermarks.len() {
Ok(CdcOffset::MySql(self.binlog_watermarks[idx].clone()))
} else {
Ok(CdcOffset::MySql(MySqlOffset {
filename: "1.binlog".to_string(),
position: u64::MAX,
}))
}
}

Expand Down
Loading