1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "ThreadCapture.h"

#include <errno.h>
#include <fcntl.h>
#include <pthread.h>
#include <signal.h>
#include <string.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <gtest/gtest.h>

#include "Allocator.h"
#include "ScopedDisableMalloc.h"
#include "ScopedPipe.h"

#include <android-base/threads.h>

using namespace std::chrono_literals;
38 
39 namespace android {
40 
41 class ThreadListTest : public ::testing::TestWithParam<int> {
42  public:
ThreadListTest()43   ThreadListTest() : stop_(false) {}
44 
~ThreadListTest()45   ~ThreadListTest() {
46     // pthread_join may return before the entry in /proc/pid/task/ is gone,
47     // loop until ListThreads only finds the main thread so the next test
48     // doesn't fail.
49     WaitForThreads();
50   }
51 
TearDown()52   virtual void TearDown() { ASSERT_TRUE(heap.empty()); }
53 
54  protected:
55   template <class Function>
StartThreads(unsigned int threads,Function && func)56   void StartThreads(unsigned int threads, Function&& func) {
57     threads_.reserve(threads);
58     tids_.reserve(threads);
59     for (unsigned int i = 0; i < threads; i++) {
60       threads_.emplace_back([&, threads, this]() {
61         {
62           std::lock_guard<std::mutex> lk(m_);
63           tids_.push_back(gettid());
64           if (tids_.size() == threads) {
65             cv_start_.notify_one();
66           }
67         }
68 
69         func();
70 
71         {
72           std::unique_lock<std::mutex> lk(m_);
73           cv_stop_.wait(lk, [&] { return stop_; });
74         }
75       });
76     }
77 
78     {
79       std::unique_lock<std::mutex> lk(m_);
80       cv_start_.wait(lk, [&] { return tids_.size() == threads; });
81     }
82   }
83 
StopThreads()84   void StopThreads() {
85     {
86       std::lock_guard<std::mutex> lk(m_);
87       stop_ = true;
88     }
89     cv_stop_.notify_all();
90 
91     for (auto i = threads_.begin(); i != threads_.end(); i++) {
92       i->join();
93     }
94     threads_.clear();
95     tids_.clear();
96   }
97 
tids()98   std::vector<pid_t>& tids() { return tids_; }
99 
100   Heap heap;
101 
102  private:
WaitForThreads()103   void WaitForThreads() {
104     auto tids = TidList{heap};
105     ThreadCapture thread_capture{getpid(), heap};
106 
107     for (unsigned int i = 0; i < 100; i++) {
108       EXPECT_TRUE(thread_capture.ListThreads(tids));
109       if (tids.size() == 1) {
110         break;
111       }
112       std::this_thread::sleep_for(10ms);
113     }
114     EXPECT_EQ(1U, tids.size());
115   }
116 
117   std::mutex m_;
118   std::condition_variable cv_start_;
119   std::condition_variable cv_stop_;
120   bool stop_;
121   std::vector<pid_t> tids_;
122 
123   std::vector<std::thread> threads_;
124 };
125 
TEST_F(ThreadListTest, list_one) {
  ScopedDisableMallocTimeout disable_malloc;

  ThreadCapture capture(getpid(), heap);

  // Exactly one task is expected: the calling process itself, listed by its pid.
  auto want = allocator::vector<pid_t>(1, getpid(), heap);
  auto got = allocator::vector<pid_t>(heap);

  ASSERT_TRUE(capture.ListThreads(got));
  ASSERT_EQ(want, got);

  // A timeout is only reported when nothing above failed first.
  if (HasFailure()) {
    return;
  }
  ASSERT_FALSE(disable_malloc.timed_out());
}
142 
TEST_P(ThreadListTest, list_some) {
  // GetParam() is the total number of threads, including the main thread.
  const unsigned int threads = GetParam() - 1;

  StartThreads(threads, []() {});
  std::vector<pid_t> expected_tids = tids();
  expected_tids.push_back(getpid());

  auto list_tids = allocator::vector<pid_t>(heap);

  // Only record results while malloc is disabled and assert afterwards. An
  // ASSERT_* failing in here would return before StopThreads(), and gtest's
  // failure reporting presumably needs malloc.
  bool listed = false;
  bool timed_out = false;
  {
    ScopedDisableMallocTimeout disable_malloc;

    ThreadCapture thread_capture(getpid(), heap);

    listed = thread_capture.ListThreads(list_tids);
    timed_out = disable_malloc.timed_out();
  }

  // Always join the helpers before asserting: returning with joinable
  // std::threads would abort the process via std::terminate.
  StopThreads();

  ASSERT_TRUE(listed);
  ASSERT_FALSE(timed_out);

  std::sort(list_tids.begin(), list_tids.end());
  std::sort(expected_tids.begin(), expected_tids.end());

  ASSERT_EQ(expected_tids.size(), list_tids.size());
  EXPECT_TRUE(std::equal(expected_tids.begin(), expected_tids.end(), list_tids.begin()));
}
172 
// Run the parameterized ThreadListTest cases with 1, 2, 10 and 1024 total threads
// (list_some starts GetParam() - 1 helpers next to the main thread).
INSTANTIATE_TEST_CASE_P(ThreadListTest, ThreadListTest, ::testing::Values(1, 2, 10, 1024));
174 
175 class ThreadCaptureTest : public ThreadListTest {
176  public:
ThreadCaptureTest()177   ThreadCaptureTest() {}
~ThreadCaptureTest()178   ~ThreadCaptureTest() {}
Fork(std::function<void ()> && child_init,std::function<void ()> && child_cleanup,std::function<void (pid_t)> && parent)179   void Fork(std::function<void()>&& child_init, std::function<void()>&& child_cleanup,
180             std::function<void(pid_t)>&& parent) {
181     ScopedPipe start_pipe;
182     ScopedPipe stop_pipe;
183 
184     int pid = fork();
185 
186     if (pid == 0) {
187       // child
188       child_init();
189       EXPECT_EQ(1, TEMP_FAILURE_RETRY(write(start_pipe.Sender(), "+", 1))) << strerror(errno);
190       char buf;
191       EXPECT_EQ(1, TEMP_FAILURE_RETRY(read(stop_pipe.Receiver(), &buf, 1))) << strerror(errno);
192       child_cleanup();
193       _exit(0);
194     } else {
195       // parent
196       ASSERT_GT(pid, 0);
197       char buf;
198       ASSERT_EQ(1, TEMP_FAILURE_RETRY(read(start_pipe.Receiver(), &buf, 1))) << strerror(errno);
199 
200       parent(pid);
201 
202       ASSERT_EQ(1, TEMP_FAILURE_RETRY(write(stop_pipe.Sender(), "+", 1))) << strerror(errno);
203       siginfo_t info{};
204       ASSERT_EQ(0, TEMP_FAILURE_RETRY(waitid(P_PID, pid, &info, WEXITED))) << strerror(errno);
205     }
206   }
207 };
208 
TEST_P(ThreadCaptureTest, capture_some) {
  // GetParam() is the total number of threads the child ends up with.
  const unsigned int threads = GetParam();

  // Child side: run threads - 1 helpers next to the child's main thread, and
  // stop them again once the parent is done.
  auto start_helpers = [&]() { StartThreads(threads - 1, []() {}); };
  auto stop_helpers = [&]() { StopThreads(); };

  // Parent side: list, capture and release every thread of the child.
  auto inspect_child = [&](pid_t child) {
    ASSERT_GT(child, 0);

    ScopedDisableMallocTimeout disable_malloc;

    ThreadCapture capture(child, heap);
    auto listed = allocator::vector<pid_t>(heap);

    ASSERT_TRUE(capture.ListThreads(listed));
    ASSERT_EQ(threads, listed.size());

    ASSERT_TRUE(capture.CaptureThreads());

    auto infos = allocator::vector<ThreadInfo>(heap);
    ASSERT_TRUE(capture.CapturedThreadInfo(infos));
    ASSERT_EQ(threads, infos.size());
    ASSERT_TRUE(capture.ReleaseThreads());

    // A timeout is only reported when nothing above failed first.
    if (HasFailure()) {
      return;
    }
    ASSERT_FALSE(disable_malloc.timed_out());
  };

  Fork(start_helpers, stop_helpers, inspect_child);
}
247 
// Run capture_some against a child process with 1, 2, 10 and 1024 threads in
// total (the child starts GetParam() - 1 helpers next to its main thread).
INSTANTIATE_TEST_CASE_P(ThreadCaptureTest, ThreadCaptureTest, ::testing::Values(1, 2, 10, 1024));
249 
// Kills the only thread of a child process while it is being captured and
// checks that CaptureThreads() reports failure.
TEST_F(ThreadCaptureTest, capture_kill) {
  int ret = fork();

  if (ret == 0) {
    // child
    // Normally killed by the SIGKILL injected below. Never return into gtest from
    // here: that would run the rest of the test program again in this forked copy.
    sleep(10);
    _exit(0);
  }

  // parent
  ASSERT_GT(ret, 0);

  // Run the capture in a lambda so that a failed ASSERT_* only leaves the lambda
  // and the child cleanup below still runs.
  [&]() {
    ScopedDisableMallocTimeout disable_malloc;

    ThreadCapture thread_capture(ret, heap);
    thread_capture.InjectTestFunc([&](pid_t tid) {
      tgkill(ret, tid, SIGKILL);
      usleep(10000);
    });
    auto list_tids = allocator::vector<pid_t>(heap);

    ASSERT_TRUE(thread_capture.ListThreads(list_tids));
    ASSERT_EQ(1U, list_tids.size());

    ASSERT_FALSE(thread_capture.CaptureThreads());

    if (!HasFailure()) {
      ASSERT_FALSE(disable_malloc.timed_out());
    }
  }();

  // Make sure the child neither outlives the test nor lingers as a zombie. It may
  // already have been reaped during the capture, so probe with WNOWAIT first
  // rather than signalling a pid that could have been reused.
  siginfo_t info{};
  if (TEMP_FAILURE_RETRY(waitid(P_PID, ret, &info, WEXITED | WNOHANG | WNOWAIT)) == 0) {
    if (info.si_pid == 0) {
      // Still running, e.g. because the capture failed before the kill was injected.
      kill(ret, SIGKILL);
    }
    EXPECT_EQ(0, TEMP_FAILURE_RETRY(waitid(P_PID, ret, &info, WEXITED))) << strerror(errno);
  }
}
281 
TEST_F(ThreadCaptureTest,capture_signal)282 TEST_F(ThreadCaptureTest, capture_signal) {
283   const int sig = SIGUSR1;
284 
285   ScopedPipe pipe;
286 
287   // For signal handler
288   static ScopedPipe* g_pipe;
289 
290   Fork(
291       [&]() {
292         // child init
293         pipe.CloseReceiver();
294 
295         g_pipe = &pipe;
296 
297         struct sigaction act {};
298         act.sa_handler = [](int) {
299           char buf = '+';
300           write(g_pipe->Sender(), &buf, 1);
301           g_pipe->CloseSender();
302         };
303         sigaction(sig, &act, NULL);
304         sigset_t set;
305         sigemptyset(&set);
306         sigaddset(&set, sig);
307         pthread_sigmask(SIG_UNBLOCK, &set, NULL);
308       },
309       [&]() {
310         // child cleanup
311         g_pipe = nullptr;
312         pipe.Close();
313       },
314       [&](pid_t child) {
315         // parent
316         ASSERT_GT(child, 0);
317         pipe.CloseSender();
318 
319         {
320           ScopedDisableMallocTimeout disable_malloc;
321 
322           ThreadCapture thread_capture(child, heap);
323           thread_capture.InjectTestFunc([&](pid_t tid) {
324             tgkill(child, tid, sig);
325             usleep(10000);
326           });
327           auto list_tids = allocator::vector<pid_t>(heap);
328 
329           ASSERT_TRUE(thread_capture.ListThreads(list_tids));
330           ASSERT_EQ(1U, list_tids.size());
331 
332           ASSERT_TRUE(thread_capture.CaptureThreads());
333 
334           auto thread_info = allocator::vector<ThreadInfo>(heap);
335           ASSERT_TRUE(thread_capture.CapturedThreadInfo(thread_info));
336           ASSERT_EQ(1U, thread_info.size());
337           ASSERT_TRUE(thread_capture.ReleaseThreads());
338 
339           usleep(100000);
340           char buf;
341           ASSERT_EQ(1, TEMP_FAILURE_RETRY(read(pipe.Receiver(), &buf, 1)));
342           ASSERT_EQ(buf, '+');
343 
344           if (!HasFailure()) {
345             ASSERT_FALSE(disable_malloc.timed_out());
346           }
347         }
348       });
349 }
350 
351 }  // namespace android
352