在我先前的文章中,我通过 Java 实现了使用多个线程交替打印 ABC。刚好最近在学习 Rust,于是就来尝试一下如何它来解决这个简单的问题。

类似地,我通过原子变量和互斥锁两种实现方式来完成这个任务。Rust 的实现与 Java 相比在细节上有着更多需要注意的地方(即非常多坑点),但最后写完后还是让我不得不佩服 Rust 的设计理念。强大且规则严格的编译器虽然在写代码的时候有些痛苦,但比起黑盒的 GC 的确是更加让人安心,是一门真正能让代码苦手也能写出高质量代码的语言。

完整代码

下面直接贴出全部的代码,后续其它小节将对每个实现进行详细的分析。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

const MAX_LEN: usize = 12; // 最大输出次数

/// 通过 `AtomicUsize` 实现交替打印目标字符串的能力
fn by_atomic(raw_data: Vec<String>) {
let data = Arc::new(raw_data);
let len = data.len();
let status = Arc::new(AtomicUsize::new(0));
let mut handle_vec = Vec::new();

for i in 0..len {
let status_clone = status.clone();
let data_clone = data.clone();
handle_vec.push(thread::spawn(move || {
loop {
while status_clone.load(Ordering::Relaxed) % len != i {
std::hint::spin_loop(); // 自旋
}
// 进入临界区
let s = status_clone.load(Ordering::Relaxed);
println!("Val {0:?} Index {1:?}", data_clone[s % len], s);
// 离开临界区
status_clone.fetch_add(1, Ordering::Relaxed);

// 必须将最大循环次数减去线程数,才能保证最终打印的数量正确
// 这是因为我们通过 `fetch_add` 自增操作来控制 `status`
if s >= MAX_LEN - len {
break;
}
}
}));
}

for handle in handle_vec {
handle.join().unwrap();
}
}

fn by_mutex(raw_data: Vec<String>) {
let data = Arc::new(raw_data);
let len = data.len();
let status = Arc::new((Mutex::new(0), Condvar::new()));
let mut handle_vec = Vec::new();

for i in 0..len {
let data_clone = data.clone();
let status_tup = status.clone();
handle_vec.push(thread::spawn(move || {
let (mutex_clone, condvar_clone) = &*status_tup;
loop {
let mut s = mutex_clone.lock().unwrap();
while *s % len != i {
s = condvar_clone.wait(s).unwrap(); // 阻塞
}
// 进入临界区
if *s < MAX_LEN {
// 需要判断 `status` 是否超出 `MAX_LEN`
// 因为 `status` 在最终依次销毁线程时一定会超出最大值,此时无需打印
// 每存在一个线程,`status` 最终值就将增加1
println!("Val {0:?} Index {1:?}", data_clone[*s % len], *s);
}
*s += 1;
// 离开临界区, 唤醒其他线程
condvar_clone.notify_all();

if *s >= MAX_LEN {
break;
}
}
}));
}

for handle in handle_vec {
handle.join().unwrap();
}
println!("Last status: {0:?}", status.0.lock().unwrap()); // 打印 MAX_LEN + len - 1
}

fn main() {
let input = vec!["A".to_string(), "B".to_string(), "C".to_string()];
by_atomic(input.clone());
println!("------------------------------------------");
by_mutex(input.clone());
}

标准输出如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
Val "A" Index 0
Val "B" Index 1
Val "C" Index 2
Val "A" Index 3
Val "B" Index 4
Val "C" Index 5
Val "A" Index 6
Val "B" Index 7
Val "C" Index 8
Val "A" Index 9
Val "B" Index 10
Val "C" Index 11
------------------------------------------
Val "A" Index 0
Val "B" Index 1
Val "C" Index 2
Val "A" Index 3
Val "B" Index 4
Val "C" Index 5
Val "A" Index 6
Val "B" Index 7
Val "C" Index 8
Val "A" Index 9
Val "B" Index 10
Val "C" Index 11
Last status: 14

原子操作实现

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
fn by_atomic(raw_data: Vec<String>) {
let data = Arc::new(raw_data);
let len = data.len();
let status = Arc::new(AtomicUsize::new(0));
let mut handle_vec = Vec::new();

for i in 0..len {
let status_clone = status.clone();
let data_clone = data.clone();
handle_vec.push(thread::spawn(move || {
loop {
while status_clone.load(Ordering::Relaxed) % len != i {
std::hint::spin_loop(); // 自旋
}
// 进入临界区
let s = status_clone.load(Ordering::Relaxed);
println!("Val {0:?} Index {1:?}", data_clone[s % len], s);
// 离开临界区
status_clone.fetch_add(1, Ordering::Relaxed);

// 必须将最大循环次数减去线程数,才能保证最终打印的数量正确
// 这是因为我们通过 `fetch_add` 自增操作来控制 `status`
if s >= MAX_LEN - len {
break;
}
}
}));
}

for handle in handle_vec {
handle.join().unwrap();
}
}

首先我们通过原子化引用计数智能指针 Arc 来共享输入数据和原子状态变量 status。同时需要保存每个线程的句柄到 handle_vec 中,以便在最后等待所有线程结束。

1
2
3
4
let data = Arc::new(raw_data);
let len = data.len();
let status = Arc::new(AtomicUsize::new(0));
let mut handle_vec = Vec::new();

我们需要创建 len 个线程,len 是字符串数组的长度,每个线程都负责打印一个字符。对每一个线程,我们都需要赋予其输入数据和状态变量的所有权。

1
2
3
4
5
6
7
for i in 0..len {
let status_clone = status.clone();
let data_clone = data.clone();
handle_vec.push(thread::spawn(move || {
...
}));
}

对于每一个线程,我们需要一个无限循环来确保线程持续地打印字符。通过原子变量 status_clone 和自旋锁 std::hint::spin_loop() 来保证输出顺序,当前不应输出的线程将一直处于自旋状态直到 status 的值满足条件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
loop {
while status_clone.load(Ordering::Relaxed) % len != i {
std::hint::spin_loop(); // 自旋
}
// 进入临界区
let s = status_clone.load(Ordering::Relaxed);
println!("Val {0:?} Index {1:?}", data_clone[s % len], s);
// 离开临界区
status_clone.fetch_add(1, Ordering::Relaxed);

// 必须将最大循环次数减去线程数,才能保证最终打印的数量正确
// 这是因为我们通过 `fetch_add` 自增操作来控制 `status`
// 在最后一轮循环中,`status` 将会超出最大值,此时无需打印直接退出即可
if s >= MAX_LEN - len {
break;
}
}

有一个非常容易忽略的点是,如果将

1
2
3
let s = status_clone.load(Ordering::Relaxed);
println!("Val {0:?} Index {1:?}", data_clone[s % len], s);
status_clone.fetch_add(1, Ordering::Relaxed);

改为下述两行代码,尽管status_clone的值在两种情况下都是相同的,但下述代码将无法保证打印的顺序执行

1
2
let s = status_clone.fetch_add(1, Ordering::SeqCst);
println!("Val {0:?} Index {1:?}", data_clone[s % len], s);

这是因为 status 的值一旦修改,就表示脱离了临界区,此时无法保证该线程能按预期顺序立即打印字母,即 status 可以看作一个互斥锁,脱离 spin_loop() 自旋表示获取锁,状态 改变 表示释放锁。

互斥锁实现

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
fn by_mutex(raw_data: Vec<String>) {
let data = Arc::new(raw_data);
let len = data.len();
let status = Arc::new((Mutex::new(0), Condvar::new()));
let mut handle_vec = Vec::new();

for i in 0..len {
let data_clone = data.clone();
let status_tup = status.clone();
handle_vec.push(thread::spawn(move || {
let (mutex_clone, condvar_clone) = &*status_tup;
loop {
let mut s = mutex_clone.lock().unwrap();
while *s % len != i {
s = condvar_clone.wait(s).unwrap(); // 阻塞
}
// 进入临界区
if *s < MAX_LEN {
// 需要判断 `status` 是否超出 `MAX_LEN`
// 因为 `status` 在最终依次销毁线程时一定会超出最大值,此时无需打印
// 每存在一个线程,`status` 最终值就将增加1
println!("Val {0:?} Index {1:?}", data_clone[*s % len], *s);
}
*s += 1;
// 离开临界区, 唤醒其他线程
condvar_clone.notify_all();

if *s >= MAX_LEN {
break;
}
}
}));
}

for handle in handle_vec {
handle.join().unwrap();
}
println!("Last status: {0:?}", status.0.lock().unwrap()); // 打印 MAX_LEN + len - 1
}

在互斥锁实现中,我使用了 MutexCondvar 来实现线程间的同步。Mutex 用于保护共享数据的访问,而 Condvar 用于在条件不满足时阻塞线程。

1
2
3
4
5
6
7
8
9
10
11
12
let data = Arc::new(raw_data);
let len = data.len();
let status = Arc::new((Mutex::new(0), Condvar::new()));
let mut handle_vec = Vec::new();

for i in 0..len {
let data_clone = data.clone();
let status_tup = status.clone();
handle_vec.push(thread::spawn(move || {
...
}));
}

对于互斥锁和状态变量元组,应当通过解引用再取地址的方式来获取其值。否则编译器无法推断元组内数据的类型。

1
let (mutex_clone, condvar_clone) = &*status_tup;

类似地,我们可以通过无限循环字符的持续打印。互斥锁的使用和原子变量非常类似,不同的是在满足阻塞条件时,不使用自旋而是使用基于 Condvar 的阻塞方式。没有获取锁的线程将一直阻塞直到其他线程通过 notify_all() 唤醒。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
loop {
let mut s = mutex_clone.lock().unwrap();
while *s % len != i {
s = condvar_clone.wait(s).unwrap(); // 阻塞
}
// 进入临界区
if *s < MAX_LEN {
// 需要判断 `status` 是否超出 `MAX_LEN`
// 因为 `status` 在最终依次销毁线程时一定会超出最大值,此时无需打印
// 每存在一个线程,`status` 最终值就将增加1(相对 MAX_LEN)
println!("Val {0:?} Index {1:?}", data_clone[*s % len], *s);
}
*s += 1;
// 离开临界区, 唤醒其他线程
condvar_clone.notify_all();

if *s >= MAX_LEN {
break;
}
}

和原子操作实现的不同点在于,我们需要在打印完字符后,首先 判断 status 是否超出最大值 MAX_LEN。因为原子操作中进入临界区的条件基于原子变量数值的大小,而在互斥锁实现中进入临界区则是由 Condvar 来控制的。

可以在函数最后打印 Mutex 的值来看看最终的 status 是多少。