Skip to content

Commit a68707a

Browse files
authored
[ISSUES #190] Add invoker extension and enhance extension design (#195)
* add Invoker extension * refactor extension mod * remove Extension constraint * cargo fmt * add license header * cargo fmt * add error handing * fix compile error * add load invoker method
1 parent 6943c9f commit a68707a

File tree

7 files changed

+506
-69
lines changed

7 files changed

+506
-69
lines changed
Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
use crate::{
19+
extension::{
20+
invoker_extension::proxy::InvokerProxy, Extension, ExtensionFactories, ExtensionMetaInfo,
21+
LoadExtensionPromise,
22+
},
23+
params::extension_param::{ExtensionName, ExtensionType},
24+
url::UrlParam,
25+
StdError, Url,
26+
};
27+
use async_trait::async_trait;
28+
use bytes::Bytes;
29+
use futures_core::Stream;
30+
use std::{collections::HashMap, future::Future, marker::PhantomData, pin::Pin};
31+
use thiserror::Error;
32+
33+
#[async_trait]
34+
pub trait Invoker {
35+
async fn invoke(
36+
&self,
37+
invocation: GrpcInvocation,
38+
) -> Result<Pin<Box<dyn Stream<Item = Bytes> + Send + 'static>>, StdError>;
39+
40+
async fn url(&self) -> Result<Url, StdError>;
41+
}
42+
43+
pub enum CallType {
44+
Unary,
45+
ClientStream,
46+
ServerStream,
47+
BiStream,
48+
}
49+
50+
pub struct GrpcInvocation {
51+
service_name: String,
52+
method_name: String,
53+
arguments: Vec<Argument>,
54+
attachments: HashMap<String, String>,
55+
call_type: CallType,
56+
}
57+
58+
pub struct Argument {
59+
name: String,
60+
value: Box<dyn Stream<Item = Box<dyn Serializable + Send + 'static>> + Send + 'static>,
61+
}
62+
63+
pub trait Serializable {
64+
fn serialize(&self, serialization_type: String) -> Result<Bytes, StdError>;
65+
}
66+
67+
pub trait Deserializable {
68+
fn deserialize(&self, bytes: Bytes, deserialization_type: String) -> Result<Self, StdError>
69+
where
70+
Self: Sized;
71+
}
72+
73+
pub mod proxy {
74+
use crate::{
75+
extension::invoker_extension::{GrpcInvocation, Invoker},
76+
StdError, Url,
77+
};
78+
use async_trait::async_trait;
79+
use bytes::Bytes;
80+
use futures_core::Stream;
81+
use std::pin::Pin;
82+
use thiserror::Error;
83+
use tokio::sync::{mpsc::Sender, oneshot};
84+
use tracing::error;
85+
86+
pub(super) enum InvokerOpt {
87+
Invoke(
88+
GrpcInvocation,
89+
oneshot::Sender<Result<Pin<Box<dyn Stream<Item = Bytes> + Send + 'static>>, StdError>>,
90+
),
91+
Url(oneshot::Sender<Result<Url, StdError>>),
92+
}
93+
94+
#[derive(Clone)]
95+
pub struct InvokerProxy {
96+
tx: Sender<InvokerOpt>,
97+
}
98+
99+
#[async_trait]
100+
impl Invoker for InvokerProxy {
101+
async fn invoke(
102+
&self,
103+
invocation: GrpcInvocation,
104+
) -> Result<Pin<Box<dyn Stream<Item = Bytes> + Send + 'static>>, StdError> {
105+
let (tx, rx) = oneshot::channel();
106+
let ret = self.tx.send(InvokerOpt::Invoke(invocation, tx)).await;
107+
match ret {
108+
Ok(_) => {}
109+
Err(err) => {
110+
error!(
111+
"call invoke method failed by invoker proxy, error: {:?}",
112+
err
113+
);
114+
return Err(InvokerProxyError::new(
115+
"call invoke method failed by invoker proxy",
116+
)
117+
.into());
118+
}
119+
}
120+
let ret = rx.await?;
121+
ret
122+
}
123+
124+
async fn url(&self) -> Result<Url, StdError> {
125+
let (tx, rx) = oneshot::channel();
126+
let ret = self.tx.send(InvokerOpt::Url(tx)).await;
127+
match ret {
128+
Ok(_) => {}
129+
Err(err) => {
130+
error!("call url method failed by invoker proxy, error: {:?}", err);
131+
return Err(
132+
InvokerProxyError::new("call url method failed by invoker proxy").into(),
133+
);
134+
}
135+
}
136+
let ret = rx.await?;
137+
ret
138+
}
139+
}
140+
141+
impl From<Box<dyn Invoker + Send + 'static>> for InvokerProxy {
142+
fn from(invoker: Box<dyn Invoker + Send + 'static>) -> Self {
143+
let (tx, mut rx) = tokio::sync::mpsc::channel(64);
144+
tokio::spawn(async move {
145+
while let Some(opt) = rx.recv().await {
146+
match opt {
147+
InvokerOpt::Invoke(invocation, tx) => {
148+
let result = invoker.invoke(invocation).await;
149+
let callback_ret = tx.send(result);
150+
match callback_ret {
151+
Ok(_) => {}
152+
Err(Err(err)) => {
153+
error!("invoke method has been called, but callback to caller failed. {:?}", err);
154+
}
155+
_ => {}
156+
}
157+
}
158+
InvokerOpt::Url(tx) => {
159+
let ret = tx.send(invoker.url().await);
160+
match ret {
161+
Ok(_) => {}
162+
Err(err) => {
163+
error!("url method has been called, but callback to caller failed. {:?}", err);
164+
}
165+
}
166+
}
167+
}
168+
}
169+
});
170+
InvokerProxy { tx }
171+
}
172+
}
173+
174+
#[derive(Error, Debug)]
175+
#[error("invoker proxy error: {0}")]
176+
pub struct InvokerProxyError(String);
177+
178+
impl InvokerProxyError {
179+
pub fn new(msg: &str) -> Self {
180+
InvokerProxyError(msg.to_string())
181+
}
182+
}
183+
}
184+
185+
#[derive(Default)]
186+
pub(super) struct InvokerExtensionLoader {
187+
factories: HashMap<String, InvokerExtensionFactory>,
188+
}
189+
190+
impl InvokerExtensionLoader {
191+
pub fn register(&mut self, extension_name: String, factory: InvokerExtensionFactory) {
192+
self.factories.insert(extension_name, factory);
193+
}
194+
195+
pub fn remove(&mut self, extension_name: String) {
196+
self.factories.remove(&extension_name);
197+
}
198+
199+
pub fn load(&mut self, url: Url) -> Result<LoadExtensionPromise<InvokerProxy>, StdError> {
200+
let extension_name = url.query::<ExtensionName>();
201+
let Some(extension_name) = extension_name else {
202+
return Err(InvokerExtensionLoaderError::new(
203+
"load invoker extension failed, extension mustn't be empty",
204+
)
205+
.into());
206+
};
207+
let extension_name = extension_name.value();
208+
let factory = self.factories.get_mut(&extension_name);
209+
let Some(factory) = factory else {
210+
let err_msg = format!(
211+
"load {} invoker extension failed, can not found extension factory",
212+
extension_name
213+
);
214+
return Err(InvokerExtensionLoaderError(err_msg).into());
215+
};
216+
factory.create(url)
217+
}
218+
}
219+
220+
type InvokerExtensionConstructor = fn(
221+
Url,
222+
) -> Pin<
223+
Box<dyn Future<Output = Result<Box<dyn Invoker + Send + 'static>, StdError>> + Send + 'static>,
224+
>;
225+
pub(crate) struct InvokerExtensionFactory {
226+
constructor: InvokerExtensionConstructor,
227+
instances: HashMap<String, LoadExtensionPromise<InvokerProxy>>,
228+
}
229+
230+
impl InvokerExtensionFactory {
231+
pub fn new(constructor: InvokerExtensionConstructor) -> Self {
232+
Self {
233+
constructor,
234+
instances: HashMap::default(),
235+
}
236+
}
237+
}
238+
239+
impl InvokerExtensionFactory {
240+
pub fn create(&mut self, url: Url) -> Result<LoadExtensionPromise<InvokerProxy>, StdError> {
241+
let key = url.to_string();
242+
243+
match self.instances.get(&key) {
244+
Some(instance) => Ok(instance.clone()),
245+
None => {
246+
let constructor = self.constructor;
247+
let creator = move |url: Url| {
248+
let invoker_future = constructor(url);
249+
Box::pin(async move {
250+
let invoker = invoker_future.await?;
251+
Ok(InvokerProxy::from(invoker))
252+
})
253+
as Pin<
254+
Box<
255+
dyn Future<Output = Result<InvokerProxy, StdError>>
256+
+ Send
257+
+ 'static,
258+
>,
259+
>
260+
};
261+
262+
let promise: LoadExtensionPromise<InvokerProxy> =
263+
LoadExtensionPromise::new(Box::new(creator), url);
264+
self.instances.insert(key, promise.clone());
265+
Ok(promise)
266+
}
267+
}
268+
}
269+
}
270+
271+
pub struct InvokerExtension<T>(PhantomData<T>)
272+
where
273+
T: Invoker + Send + 'static;
274+
275+
impl<T> ExtensionMetaInfo for InvokerExtension<T>
276+
where
277+
T: Invoker + Send + 'static,
278+
T: Extension<Target = Box<dyn Invoker + Send + 'static>>,
279+
{
280+
fn name() -> String {
281+
T::name()
282+
}
283+
284+
fn extension_type() -> ExtensionType {
285+
ExtensionType::Invoker
286+
}
287+
288+
fn extension_factory() -> ExtensionFactories {
289+
ExtensionFactories::InvokerExtensionFactory(InvokerExtensionFactory::new(
290+
<T as Extension>::create,
291+
))
292+
}
293+
}
294+
295+
#[derive(Error, Debug)]
296+
#[error("{0}")]
297+
pub struct InvokerExtensionLoaderError(String);
298+
299+
impl InvokerExtensionLoaderError {
300+
pub fn new(msg: &str) -> Self {
301+
InvokerExtensionLoaderError(msg.to_string())
302+
}
303+
}

0 commit comments

Comments
 (0)