Skip to content

Commit a184f92

Browse files
tamirddjc
authored andcommitted
Improve type safety, extract identical code
Avoid fragility of tracking objects and their FDs separately.
1 parent 58afd15 commit a184f92

File tree

1 file changed

+73
-52
lines changed

1 file changed

+73
-52
lines changed

src/unix_term.rs

Lines changed: 73 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
use std::env;
22
use std::fmt::Display;
33
use std::fs;
4-
use std::io;
5-
use std::io::{BufRead, BufReader};
4+
use std::io::{self, BufRead, BufReader};
65
use std::mem;
7-
use std::os::unix::io::AsRawFd;
6+
use std::os::fd::{AsRawFd, RawFd};
87
use std::str;
98

109
#[cfg(not(target_os = "macos"))]
@@ -18,7 +17,7 @@ pub(crate) use crate::common_term::*;
1817
pub(crate) const DEFAULT_WIDTH: u16 = 80;
1918

2019
#[inline]
21-
pub(crate) fn is_a_terminal(out: &Term) -> bool {
20+
pub(crate) fn is_a_terminal(out: &impl AsRawFd) -> bool {
2221
unsafe { libc::isatty(out.as_raw_fd()) != 0 }
2322
}
2423

@@ -66,41 +65,73 @@ pub(crate) fn terminal_size(out: &Term) -> Option<(u16, u16)> {
6665
}
6766
}
6867

69-
pub(crate) fn read_secure() -> io::Result<String> {
70-
let mut f_tty;
71-
let fd = unsafe {
72-
if libc::isatty(libc::STDIN_FILENO) == 1 {
73-
f_tty = None;
74-
libc::STDIN_FILENO
75-
} else {
76-
let f = fs::OpenOptions::new()
77-
.read(true)
78-
.write(true)
79-
.open("/dev/tty")?;
80-
let fd = f.as_raw_fd();
81-
f_tty = Some(BufReader::new(f));
82-
fd
68+
enum Input<T> {
69+
Stdin(io::Stdin),
70+
File(T),
71+
}
72+
73+
fn unbuffered_input() -> io::Result<Input<fs::File>> {
74+
let stdin = io::stdin();
75+
if is_a_terminal(&stdin) {
76+
Ok(Input::Stdin(stdin))
77+
} else {
78+
let f = fs::OpenOptions::new()
79+
.read(true)
80+
.write(true)
81+
.open("/dev/tty")?;
82+
Ok(Input::File(f))
83+
}
84+
}
85+
86+
fn buffered_input() -> io::Result<Input<BufReader<fs::File>>> {
87+
Ok(match unbuffered_input()? {
88+
Input::Stdin(s) => Input::Stdin(s),
89+
Input::File(f) => Input::File(BufReader::new(f)),
90+
})
91+
}
92+
93+
// NB: this is not a full BufRead implementation because io::Stdin does not implement BufRead.
94+
impl<T: BufRead> Input<T> {
95+
fn read_line(&mut self, buf: &mut String) -> io::Result<usize> {
96+
match self {
97+
Self::Stdin(s) => s.read_line(buf),
98+
Self::File(f) => f.read_line(buf),
8399
}
84-
};
100+
}
101+
}
102+
103+
impl AsRawFd for Input<fs::File> {
104+
fn as_raw_fd(&self) -> RawFd {
105+
match self {
106+
Self::Stdin(s) => s.as_raw_fd(),
107+
Self::File(f) => f.as_raw_fd(),
108+
}
109+
}
110+
}
111+
112+
impl AsRawFd for Input<BufReader<fs::File>> {
113+
fn as_raw_fd(&self) -> RawFd {
114+
match self {
115+
Self::Stdin(s) => s.as_raw_fd(),
116+
Self::File(f) => f.get_ref().as_raw_fd(),
117+
}
118+
}
119+
}
120+
121+
pub(crate) fn read_secure() -> io::Result<String> {
122+
let mut input = buffered_input()?;
85123

86124
let mut termios = mem::MaybeUninit::uninit();
87-
c_result(|| unsafe { libc::tcgetattr(fd, termios.as_mut_ptr()) })?;
125+
c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?;
88126
let mut termios = unsafe { termios.assume_init() };
89127
let original = termios;
90128
termios.c_lflag &= !libc::ECHO;
91-
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSAFLUSH, &termios) })?;
129+
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSAFLUSH, &termios) })?;
92130
let mut rv = String::new();
93131

94-
let read_rv = if let Some(f) = &mut f_tty {
95-
f.read_line(&mut rv)
96-
} else {
97-
io::stdin().read_line(&mut rv)
98-
};
132+
let read_rv = input.read_line(&mut rv);
99133

100-
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSAFLUSH, &original) })?;
101-
102-
// Ensure the fd is only closed after everything has been restored.
103-
drop(f_tty);
134+
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSAFLUSH, &original) })?;
104135

105136
read_rv.map(|_| {
106137
let len = rv.trim_end_matches(&['\r', '\n'][..]).len();
@@ -109,7 +140,7 @@ pub(crate) fn read_secure() -> io::Result<String> {
109140
})
110141
}
111142

112-
fn poll_fd(fd: i32, timeout: i32) -> io::Result<bool> {
143+
fn poll_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
113144
let mut pollfd = libc::pollfd {
114145
fd,
115146
events: libc::POLLIN,
@@ -124,7 +155,7 @@ fn poll_fd(fd: i32, timeout: i32) -> io::Result<bool> {
124155
}
125156

126157
#[cfg(target_os = "macos")]
127-
fn select_fd(fd: i32, timeout: i32) -> io::Result<bool> {
158+
fn select_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
128159
unsafe {
129160
let mut read_fd_set: libc::fd_set = mem::zeroed();
130161

@@ -156,7 +187,7 @@ fn select_fd(fd: i32, timeout: i32) -> io::Result<bool> {
156187
}
157188
}
158189

159-
fn select_or_poll_term_fd(fd: i32, timeout: i32) -> io::Result<bool> {
190+
fn select_or_poll_term_fd(fd: RawFd, timeout: i32) -> io::Result<bool> {
160191
// There is a bug on macos that ttys cannot be polled, only select()
161192
// works. However given how problematic select is in general, we
162193
// normally want to use poll there too.
@@ -169,7 +200,7 @@ fn select_or_poll_term_fd(fd: i32, timeout: i32) -> io::Result<bool> {
169200
poll_fd(fd, timeout)
170201
}
171202

172-
fn read_single_char(fd: i32) -> io::Result<Option<char>> {
203+
fn read_single_char(fd: RawFd) -> io::Result<Option<char>> {
173204
// timeout of zero means that it will not block
174205
let is_ready = select_or_poll_term_fd(fd, 0)?;
175206

@@ -188,7 +219,7 @@ fn read_single_char(fd: i32) -> io::Result<Option<char>> {
188219
// Similar to libc::read. Read count bytes into slice buf from descriptor fd.
189220
// If successful, return the number of bytes read.
190221
// Will return an error if nothing was read, i.e when called at end of file.
191-
fn read_bytes(fd: i32, buf: &mut [u8], count: u8) -> io::Result<u8> {
222+
fn read_bytes(fd: RawFd, buf: &mut [u8], count: u8) -> io::Result<u8> {
192223
let read = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut _, count as usize) };
193224
if read < 0 {
194225
Err(io::Error::last_os_error())
@@ -207,7 +238,7 @@ fn read_bytes(fd: i32, buf: &mut [u8], count: u8) -> io::Result<u8> {
207238
}
208239
}
209240

210-
fn read_single_key_impl(fd: i32) -> Result<Key, io::Error> {
241+
fn read_single_key_impl(fd: RawFd) -> Result<Key, io::Error> {
211242
loop {
212243
match read_single_char(fd)? {
213244
Some('\x1b') => {
@@ -301,27 +332,17 @@ fn read_single_key_impl(fd: i32) -> Result<Key, io::Error> {
301332
}
302333

303334
pub(crate) fn read_single_key(ctrlc_key: bool) -> io::Result<Key> {
304-
let tty_f;
305-
let fd = unsafe {
306-
if libc::isatty(libc::STDIN_FILENO) == 1 {
307-
libc::STDIN_FILENO
308-
} else {
309-
tty_f = fs::OpenOptions::new()
310-
.read(true)
311-
.write(true)
312-
.open("/dev/tty")?;
313-
tty_f.as_raw_fd()
314-
}
315-
};
335+
let input = unbuffered_input()?;
336+
316337
let mut termios = core::mem::MaybeUninit::uninit();
317-
c_result(|| unsafe { libc::tcgetattr(fd, termios.as_mut_ptr()) })?;
338+
c_result(|| unsafe { libc::tcgetattr(input.as_raw_fd(), termios.as_mut_ptr()) })?;
318339
let mut termios = unsafe { termios.assume_init() };
319340
let original = termios;
320341
unsafe { libc::cfmakeraw(&mut termios) };
321342
termios.c_oflag = original.c_oflag;
322-
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSADRAIN, &termios) })?;
323-
let rv: io::Result<Key> = read_single_key_impl(fd);
324-
c_result(|| unsafe { libc::tcsetattr(fd, libc::TCSADRAIN, &original) })?;
343+
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &termios) })?;
344+
let rv: io::Result<Key> = read_single_key_impl(input.as_raw_fd());
345+
c_result(|| unsafe { libc::tcsetattr(input.as_raw_fd(), libc::TCSADRAIN, &original) })?;
325346

326347
// if the user hit ^C we want to signal SIGINT to ourselves.
327348
if let Err(ref err) = rv {

0 commit comments

Comments
 (0)