Skip to content

Commit d32a3d9

Browse files
authored
Share CLI connection logic to cli-tools (#576)
This will be used by xtask when using actions that require a connection like `applet-install`.
1 parent dfea057 commit d32a3d9

File tree

12 files changed

+137
-82
lines changed

12 files changed

+137
-82
lines changed

crates/cli-tools/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
### Minor
66

7+
- Add `action::{ConnectionOptions,GlobalConnection}` for lazy platform connection
78
- Change the behavior of `fs::copy_if_changed()` to keep an original source
89

910
### Patch

crates/cli-tools/Cargo.lock

+14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/cli-tools/Cargo.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ categories = ["command-line-utilities", "embedded", "wasm"]
1414
[dependencies]
1515
anyhow = { version = "1.0.86", default-features = false, features = ["std"] }
1616
cargo_metadata = { version = "0.18.1", default-features = false }
17-
clap = { version = "4.5.4", default-features = false, features = ["derive", "std"] }
17+
clap = { version = "4.5.4", default-features = false, features = ["derive", "env", "std"] }
18+
data-encoding = { version = "2.6.0", default-features = false, features = ["std"] }
19+
humantime = { version = "2.1.0", default-features = false }
1820
rusb = { version = "0.9.4", default-features = false }
1921
serde = { version = "1.0.202", default-features = false, features = ["derive"] }
2022
toml = { version = "0.8.13", default-features = false, features = ["display", "parse"] }

crates/cli-tools/src/action.rs

+87-2
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,101 @@
1515
use std::fmt::Display;
1616
use std::path::{Path, PathBuf};
1717
use std::process::Command;
18+
use std::time::Duration;
1819

1920
use anyhow::{bail, ensure, Result};
2021
use cargo_metadata::{Metadata, MetadataCommand};
2122
use clap::{ValueEnum, ValueHint};
22-
use rusb::UsbContext;
23+
use data_encoding::HEXLOWER_PERMISSIVE as HEX;
24+
use rusb::{GlobalContext, UsbContext};
2325
use wasefire_protocol::{self as service, applet, Api};
24-
use wasefire_protocol_usb::Connection;
26+
use wasefire_protocol_usb::{Candidate, Connection};
2527

2628
use crate::{cmd, fs};
2729

30+
/// Options to connect to a platform.
31+
#[derive(Clone, clap::Args)]
32+
pub struct ConnectionOptions {
33+
/// Serial of the platform to connect to.
34+
#[arg(long, env = "WASEFIRE_SERIAL")]
35+
serial: Option<String>,
36+
37+
/// Timeout to send or receive on the platform protocol.
38+
#[arg(long, default_value = "1s")]
39+
timeout: humantime::Duration,
40+
}
41+
42+
impl ConnectionOptions {
43+
/// Returns the timeout for platform connection.
44+
pub fn timeout(&self) -> Duration {
45+
*self.timeout
46+
}
47+
}
48+
49+
/// Reusable lazy connection to a platform.
50+
pub enum GlobalConnection {
51+
/// The connection is not yet configured and can be established.
52+
Invalid,
53+
/// The connection is configured but not yet established.
54+
Ready { options: ConnectionOptions },
55+
/// The connection is established.
56+
Connected { connection: Connection<GlobalContext> },
57+
}
58+
59+
impl GlobalConnection {
60+
/// Configures the connection (required to access it).
61+
pub fn configure(&mut self, options: ConnectionOptions) {
62+
match self {
63+
GlobalConnection::Invalid => *self = GlobalConnection::Ready { options },
64+
_ => panic!("connection already configured"),
65+
}
66+
}
67+
68+
/// Accesses the connection (establishing it if needed).
69+
pub fn get(&mut self) -> Result<&Connection<GlobalContext>> {
70+
if let GlobalConnection::Ready { options } = self {
71+
let connection = connect(options.timeout(), options.serial.as_deref())?;
72+
*self = GlobalConnection::Connected { connection };
73+
}
74+
match self {
75+
GlobalConnection::Connected { connection } => Ok(connection),
76+
_ => panic!("connection not yet configured"),
77+
}
78+
}
79+
}
80+
81+
fn connect(timeout: Duration, serial: Option<&str>) -> Result<Connection<GlobalContext>> {
82+
let context = GlobalContext::default();
83+
let mut candidates = wasefire_protocol_usb::list(&context)?;
84+
let candidate = match (serial, candidates.len()) {
85+
(None, 0) => bail!("no connected platforms"),
86+
(None, 1) => candidates.pop().unwrap(),
87+
(None, n) => {
88+
eprintln!("Choose one of the {n} connected platforms using its --serial option:");
89+
for candidate in candidates {
90+
eprintln!(" --serial={}", get_serial(&candidate, timeout)?);
91+
}
92+
bail!("more than one connected platform");
93+
}
94+
(Some(serial), _) => {
95+
match candidates
96+
.into_iter()
97+
.try_find(|x| anyhow::Ok(get_serial(x, timeout)? == serial))?
98+
{
99+
Some(x) => x,
100+
None => bail!("no connected platform with serial={serial}"),
101+
}
102+
}
103+
};
104+
Ok(candidate.connect(timeout)?)
105+
}
106+
107+
fn get_serial(candidate: &Candidate<GlobalContext>, timeout: Duration) -> Result<String> {
108+
let connection = candidate.clone().connect(timeout)?;
109+
let info = connection.call::<service::PlatformInfo>(())?;
110+
Ok(HEX.encode(info.get().serial))
111+
}
112+
28113
/// Parameters for an applet or platform RPC.
29114
#[derive(clap::Args)]
30115
pub struct Rpc {

crates/cli-tools/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
//! This library is also used for the internal maintenance CLI of Wasefire called xtask.
1818
1919
#![feature(path_add_extension)]
20+
#![feature(try_find)]
2021

2122
macro_rules! debug {
2223
($($x:tt)*) => {

crates/cli/CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@
3030

3131
## 0.1.0
3232

33-
<!-- Increment to skip CHANGELOG.md test: 2 -->
33+
<!-- Increment to skip CHANGELOG.md test: 3 -->

crates/cli/Cargo.lock

+2-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/cli/Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ anyhow = { version = "1.0.86", default-features = false }
2020
clap = { version = "4.5.4", default-features = false, features = ["default", "derive", "env"] }
2121
clap_complete = { version = "4.5.2", default-features = false }
2222
data-encoding = { version = "2.6.0", default-features = false, features = ["std"] }
23-
humantime = { version = "2.1.0", default-features = false }
2423
rusb = { version = "0.9.4", default-features = false }
2524
wasefire-cli-tools = { version = "0.1.1-git", path = "../cli-tools" }
2625
wasefire-protocol = { version = "0.1.1-git", path = "../protocol", features = ["host"] }

crates/cli/src/main.rs

+4-76
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#![feature(try_find)]
16-
1715
use std::fs::File;
1816
use std::io::Write;
1917
use std::path::{Path, PathBuf};
@@ -27,29 +25,17 @@ use data_encoding::HEXLOWER_PERMISSIVE as HEX;
2725
use rusb::GlobalContext;
2826
use wasefire_cli_tools::{action, fs};
2927
use wasefire_protocol::{self as service, platform};
30-
use wasefire_protocol_usb::{Candidate, Connection};
3128

3229
#[derive(Parser)]
3330
#[command(version, about)]
3431
struct Flags {
3532
#[command(flatten)]
36-
options: Options,
33+
connection_options: action::ConnectionOptions,
3734

3835
#[command(subcommand)]
3936
action: Action,
4037
}
4138

42-
#[derive(clap::Args)]
43-
struct Options {
44-
/// Serial of the platform to connect to.
45-
#[arg(long, env = "WASEFIRE_SERIAL")]
46-
serial: Option<String>,
47-
48-
/// Timeout to send or receive on the platform protocol.
49-
#[arg(long, value_parser = humantime::parse_duration, default_value = "1s")]
50-
timeout: Duration,
51-
}
52-
5339
#[derive(clap::Subcommand)]
5440
enum Action {
5541
/// Lists the applets installed on a platform.
@@ -113,15 +99,15 @@ impl Completion {
11399

114100
fn main() -> Result<()> {
115101
let flags = Flags::parse();
116-
CONNECTION.lock().unwrap().set(flags.options.timeout, flags.options.serial);
102+
CONNECTION.lock().unwrap().configure(flags.connection_options.clone());
117103
let dir = std::env::current_dir()?;
118104
match flags.action {
119105
Action::AppletList => bail!("not implemented yet"),
120106
Action::AppletInstall => bail!("not implemented yet"),
121107
Action::AppletUpdate => bail!("not implemented yet"),
122108
Action::AppletUninstall => bail!("not implemented yet"),
123109
Action::AppletRpc(x) => x.run(CONNECTION.lock().unwrap().get()?),
124-
Action::PlatformList => platform_list(flags.options.timeout),
110+
Action::PlatformList => platform_list(flags.connection_options.timeout()),
125111
Action::PlatformUpdate => bail!("not implemented yet"),
126112
Action::PlatformReboot(x) => x.run(CONNECTION.lock().unwrap().get()?),
127113
Action::PlatformRpc(x) => x.run(CONNECTION.lock().unwrap().get()?),
@@ -132,33 +118,7 @@ fn main() -> Result<()> {
132118
}
133119
}
134120

135-
enum GlobalConnection {
136-
Invalid,
137-
Ready { timeout: Duration, serial: Option<String> },
138-
Connected { connection: Connection<GlobalContext> },
139-
}
140-
141-
impl GlobalConnection {
142-
fn set(&mut self, timeout: Duration, serial: Option<String>) {
143-
match self {
144-
GlobalConnection::Invalid => *self = GlobalConnection::Ready { timeout, serial },
145-
_ => unreachable!(),
146-
}
147-
}
148-
149-
fn get(&mut self) -> Result<&Connection<GlobalContext>> {
150-
if let GlobalConnection::Ready { timeout, serial } = self {
151-
*self =
152-
GlobalConnection::Connected { connection: connect(*timeout, serial.as_deref())? };
153-
}
154-
match self {
155-
GlobalConnection::Connected { connection } => Ok(connection),
156-
_ => unreachable!(),
157-
}
158-
}
159-
}
160-
161-
static CONNECTION: Mutex<GlobalConnection> = Mutex::new(GlobalConnection::Invalid);
121+
static CONNECTION: Mutex<action::GlobalConnection> = Mutex::new(action::GlobalConnection::Invalid);
162122

163123
fn platform_list(timeout: Duration) -> Result<()> {
164124
let context = GlobalContext::default();
@@ -174,35 +134,3 @@ fn platform_list(timeout: Duration) -> Result<()> {
174134
}
175135
Ok(())
176136
}
177-
178-
fn connect(timeout: Duration, serial: Option<&str>) -> Result<Connection<GlobalContext>> {
179-
let context = GlobalContext::default();
180-
let mut candidates = wasefire_protocol_usb::list(&context)?;
181-
let candidate = match (serial, candidates.len()) {
182-
(None, 0) => bail!("no connected platforms"),
183-
(None, 1) => candidates.pop().unwrap(),
184-
(None, n) => {
185-
eprintln!("Choose one of the {n} connected platforms using its --serial option:");
186-
for candidate in candidates {
187-
eprintln!(" --serial={}", get_serial(&candidate, timeout)?);
188-
}
189-
bail!("more than one connected platform");
190-
}
191-
(Some(serial), _) => {
192-
match candidates
193-
.into_iter()
194-
.try_find(|x| anyhow::Ok(get_serial(x, timeout)? == serial))?
195-
{
196-
Some(x) => x,
197-
None => bail!("no connected platform with serial={serial}"),
198-
}
199-
}
200-
};
201-
Ok(candidate.connect(timeout)?)
202-
}
203-
204-
fn get_serial(candidate: &Candidate<GlobalContext>, timeout: Duration) -> Result<String> {
205-
let connection = candidate.clone().connect(timeout)?;
206-
let info = connection.call::<service::PlatformInfo>(())?;
207-
Ok(HEX.encode(info.get().serial))
208-
}

crates/protocol/crates/schema/Cargo.lock

+14
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/runner-host/Cargo.lock

+2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/xtask/Cargo.lock

+8
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)