Skip to content

Commit 0c60de0

Browse files
stackinspectorsdroege
authored andcommitted
Add a simple WebSocketStream::send method to replace Sink trait usage
And also bump MSRV to 1.64. Fixes #142
1 parent ce58323 commit 0c60de0

File tree

8 files changed

+125
-30
lines changed

8 files changed

+125
-30
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ jobs:
9797
strategy:
9898
matrix:
9999
rust:
100-
- 1.63.0
100+
- 1.64.0
101101

102102
steps:
103103
- name: Checkout sources

Cargo.lock.msrv

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ version = "0.28.0"
1212
edition = "2018"
1313
readme = "README.md"
1414
include = ["examples/**/*", "src/**/*", "LICENSE", "README.md", "CHANGELOG.md"]
15-
rust-version = "1.63"
15+
rust-version = "1.64"
1616

1717
[features]
18-
default = ["handshake"]
18+
default = ["handshake", "futures-03-sink"]
19+
futures-03-sink = ["futures-util"]
1920
handshake = ["tungstenite/handshake"]
2021
async-std-runtime = ["async-std", "handshake"]
2122
tokio-runtime = ["tokio", "handshake"]
@@ -37,10 +38,17 @@ features = ["async-std-runtime", "tokio-runtime", "gio-runtime", "async-tls", "a
3738

3839
[dependencies]
3940
log = "0.4"
40-
futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] }
41+
futures-core = { version = "0.3", default-features = false }
42+
atomic-waker = { version = "1.1", default-features = false }
4143
futures-io = { version = "0.3", default-features = false, features = ["std"] }
4244
pin-project-lite = "0.2"
4345

46+
[dependencies.futures-util]
47+
optional = true
48+
version = "0.3"
49+
default-features = false
50+
features = ["sink"]
51+
4452
[dependencies.tungstenite]
4553
version = "0.24"
4654
default-features = false
@@ -141,7 +149,7 @@ required-features = ["async-std-runtime"]
141149

142150
[[example]]
143151
name = "autobahn-server"
144-
required-features = ["async-std-runtime"]
152+
required-features = ["async-std-runtime", "futures-03-sink"]
145153

146154
[[example]]
147155
name = "server"
@@ -153,7 +161,7 @@ required-features = ["async-std-runtime"]
153161

154162
[[example]]
155163
name = "server-headers"
156-
required-features = ["async-std-runtime", "handshake"]
164+
required-features = ["async-std-runtime", "handshake", "futures-util"]
157165

158166
[[example]]
159167
name = "interval-server"
@@ -173,4 +181,4 @@ required-features = ["tokio-runtime"]
173181

174182
[[example]]
175183
name = "server-custom-accept"
176-
required-features = ["tokio-runtime"]
184+
required-features = ["tokio-runtime", "futures-util"]

examples/autobahn-client.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ async fn run_test(case: u32) -> Result<()> {
3232
while let Some(msg) = ws_stream.next().await {
3333
let msg = msg?;
3434
if msg.is_text() || msg.is_binary() {
35+
// for Sink of futures 0.3, see autobahn-server example
3536
ws_stream.send(msg).await?;
3637
}
3738
}

examples/autobahn-server.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ async fn handle_connection(peer: SocketAddr, stream: TcpStream) -> Result<()> {
2323
while let Some(msg) = ws_stream.next().await {
2424
let msg = msg?;
2525
if msg.is_text() || msg.is_binary() {
26-
ws_stream.send(msg).await?;
26+
// here we explicitly using futures 0.3's Sink implementation for send message
27+
// for WebSocketStream::send, see autobahn-client example
28+
futures::SinkExt::send(&mut ws_stream, msg).await?;
2729
}
2830
}
2931

examples/server-headers.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use async_tungstenite::{
2424
use url::Url;
2525
#[macro_use]
2626
extern crate log;
27-
use futures_util::{SinkExt, StreamExt};
27+
use futures_util::StreamExt;
2828

2929
#[async_std::main]
3030
async fn main() {

src/compat.rs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
use log::*;
33
use std::io::{Read, Write};
44
use std::pin::Pin;
5-
use std::task::{Context, Poll};
5+
use std::task::{Context, Poll, Wake, Waker};
66

7+
use atomic_waker::AtomicWaker;
78
use futures_io::{AsyncRead, AsyncWrite};
8-
use futures_util::task;
99
use std::sync::Arc;
1010
use tungstenite::Error as WsError;
1111

@@ -49,18 +49,20 @@ pub(crate) struct AllowStd<S> {
4949
// read waker slot for this, but any would do.
5050
//
5151
// Don't ever use this from multiple tasks at the same time!
52+
#[cfg(feature = "handshake")]
5253
pub(crate) trait SetWaker {
53-
fn set_waker(&self, waker: &task::Waker);
54+
fn set_waker(&self, waker: &Waker);
5455
}
5556

57+
#[cfg(feature = "handshake")]
5658
impl<S> SetWaker for AllowStd<S> {
57-
fn set_waker(&self, waker: &task::Waker) {
59+
fn set_waker(&self, waker: &Waker) {
5860
self.set_waker(ContextWaker::Read, waker);
5961
}
6062
}
6163

6264
impl<S> AllowStd<S> {
63-
pub(crate) fn new(inner: S, waker: &task::Waker) -> Self {
65+
pub(crate) fn new(inner: S, waker: &Waker) -> Self {
6466
let res = Self {
6567
inner,
6668
write_waker_proxy: Default::default(),
@@ -83,7 +85,7 @@ impl<S> AllowStd<S> {
8385
//
8486
// Write: this is only supposde to be called by write operations, i.e. the Sink impl on the
8587
// WebSocketStream.
86-
pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &task::Waker) {
88+
pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &Waker) {
8789
match kind {
8890
ContextWaker::Read => {
8991
self.write_waker_proxy.read_waker.register(waker);
@@ -103,11 +105,11 @@ impl<S> AllowStd<S> {
103105
// reads and writes, and the same for writes.
104106
#[derive(Debug, Default)]
105107
struct WakerProxy {
106-
read_waker: task::AtomicWaker,
107-
write_waker: task::AtomicWaker,
108+
read_waker: AtomicWaker,
109+
write_waker: AtomicWaker,
108110
}
109111

110-
impl std::task::Wake for WakerProxy {
112+
impl Wake for WakerProxy {
111113
fn wake(self: Arc<Self>) {
112114
self.wake_by_ref()
113115
}
@@ -129,10 +131,10 @@ where
129131
#[cfg(feature = "verbose-logging")]
130132
trace!("{}:{} AllowStd.with_context", file!(), line!());
131133
let waker = match kind {
132-
ContextWaker::Read => task::Waker::from(self.read_waker_proxy.clone()),
133-
ContextWaker::Write => task::Waker::from(self.write_waker_proxy.clone()),
134+
ContextWaker::Read => Waker::from(self.read_waker_proxy.clone()),
135+
ContextWaker::Write => Waker::from(self.write_waker_proxy.clone()),
134136
};
135-
let mut context = task::Context::from_waker(&waker);
137+
let mut context = Context::from_waker(&waker);
136138
f(&mut context, Pin::new(&mut self.inner))
137139
}
138140

src/lib.rs

Lines changed: 89 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,16 @@ mod handshake;
5858
))]
5959
pub mod stream;
6060

61-
use std::io::{Read, Write};
61+
use std::{
62+
io::{Read, Write},
63+
pin::Pin,
64+
task::{ready, Context, Poll},
65+
};
6266

6367
use compat::{cvt, AllowStd, ContextWaker};
68+
use futures_core::stream::{FusedStream, Stream};
6469
use futures_io::{AsyncRead, AsyncWrite};
65-
use futures_util::{
66-
sink::{Sink, SinkExt},
67-
stream::{FusedStream, Stream},
68-
};
6970
use log::*;
70-
use std::pin::Pin;
71-
use std::task::{Context, Poll};
7271

7372
#[cfg(feature = "handshake")]
7473
use tungstenite::{
@@ -227,6 +226,7 @@ where
227226
#[derive(Debug)]
228227
pub struct WebSocketStream<S> {
229228
inner: WebSocket<AllowStd<S>>,
229+
#[cfg(feature = "futures-03-sink")]
230230
closing: bool,
231231
ended: bool,
232232
/// Tungstenite is probably ready to receive more data.
@@ -269,6 +269,7 @@ impl<S> WebSocketStream<S> {
269269
pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
270270
Self {
271271
inner: ws,
272+
#[cfg(feature = "futures-03-sink")]
272273
closing: false,
273274
ended: false,
274275
ready: true,
@@ -337,7 +338,7 @@ where
337338
return Poll::Ready(None);
338339
}
339340

340-
match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
341+
match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
341342
#[cfg(feature = "verbose-logging")]
342343
trace!(
343344
"{}:{} Stream.with_context poll_next -> read()",
@@ -368,7 +369,8 @@ where
368369
}
369370
}
370371

371-
impl<T> Sink<Message> for WebSocketStream<T>
372+
#[cfg(feature = "futures-03-sink")]
373+
impl<T> futures_util::Sink<Message> for WebSocketStream<T>
372374
where
373375
T: AsyncRead + AsyncWrite + Unpin,
374376
{
@@ -446,6 +448,84 @@ where
446448
}
447449
}
448450

451+
impl<S> WebSocketStream<S> {
452+
/// Simple send method to replace `futures_sink::Sink` (till v0.3).
453+
pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
454+
where
455+
S: AsyncRead + AsyncWrite + Unpin,
456+
{
457+
Send::new(self, msg).await
458+
}
459+
}
460+
461+
struct Send<'a, S> {
462+
ws: &'a mut WebSocketStream<S>,
463+
msg: Option<Message>,
464+
}
465+
466+
impl<'a, S> Send<'a, S>
467+
where
468+
S: AsyncRead + AsyncWrite + Unpin,
469+
{
470+
fn new(ws: &'a mut WebSocketStream<S>, msg: Message) -> Self {
471+
Self { ws, msg: Some(msg) }
472+
}
473+
}
474+
475+
impl<S> std::future::Future for Send<'_, S>
476+
where
477+
S: AsyncRead + AsyncWrite + Unpin,
478+
{
479+
type Output = Result<(), WsError>;
480+
481+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
482+
if self.msg.is_some() {
483+
if !self.ws.ready {
484+
// Currently blocked so try to flush the blockage away
485+
let polled = self
486+
.ws
487+
.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
488+
.map(|r| {
489+
self.ws.ready = true;
490+
r
491+
});
492+
ready!(polled)?
493+
}
494+
495+
let msg = self.msg.take().expect("unreachable");
496+
match self.ws.with_context(None, |s| s.write(msg)) {
497+
Ok(_) => Ok(()),
498+
Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
499+
// the message was accepted and queued so not an error
500+
//
501+
// set to false here for cancellation safety of *this* Future
502+
self.ws.ready = false;
503+
Ok(())
504+
}
505+
Err(e) => {
506+
debug!("websocket start_send error: {}", e);
507+
Err(e)
508+
}
509+
}?;
510+
}
511+
512+
let polled = self
513+
.ws
514+
.with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
515+
.map(|r| {
516+
self.ws.ready = true;
517+
match r {
518+
// WebSocket connection has just been closed. Flushing completed, not an error.
519+
Err(WsError::ConnectionClosed) => Ok(()),
520+
other => other,
521+
}
522+
});
523+
ready!(polled)?;
524+
525+
Poll::Ready(Ok(()))
526+
}
527+
}
528+
449529
#[cfg(any(
450530
feature = "async-tls",
451531
feature = "async-std-runtime",

0 commit comments

Comments
 (0)