1 /*
2  * Copyright (C) 2015 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ART_DEX2OAT_LINKER_RELATIVE_PATCHER_TEST_H_
18 #define ART_DEX2OAT_LINKER_RELATIVE_PATCHER_TEST_H_
19 
#include <algorithm>
#include <cstring>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <gtest/gtest.h>

#include "arch/instruction_set.h"
#include "arch/instruction_set_features.h"
#include "base/array_ref.h"
#include "base/globals.h"
#include "base/macros.h"
#include "dex/method_reference.h"
#include "dex/string_reference.h"
#include "driver/compiled_method-inl.h"
#include "driver/compiled_method_storage.h"
#include "linker/relative_patcher.h"
#include "oat/oat_quick_method_header.h"
#include "stream/vector_output_stream.h"
34 
35 namespace art {
36 namespace linker {
37 
38 // Base class providing infrastructure for architecture-specific tests.
39 class RelativePatcherTest : public testing::Test {
40  protected:
RelativePatcherTest(InstructionSet instruction_set,const std::string & variant)41   RelativePatcherTest(InstructionSet instruction_set, const std::string& variant)
42       : storage_(/*swap_fd=*/ -1),
43         instruction_set_(instruction_set),
44         instruction_set_features_(nullptr),
45         method_offset_map_(),
46         patcher_(nullptr),
47         bss_begin_(0u),
48         compiled_method_refs_(),
49         compiled_methods_(),
50         patched_code_(),
51         output_(),
52         out_(nullptr) {
53     std::string error_msg;
54     instruction_set_features_ =
55         InstructionSetFeatures::FromVariant(instruction_set, variant, &error_msg);
56     CHECK(instruction_set_features_ != nullptr) << error_msg;
57 
58     patched_code_.reserve(16 * KB);
59   }
60 
SetUp()61   void SetUp() override {
62     Reset();
63   }
64 
TearDown()65   void TearDown() override {
66     thunk_provider_.Reset();
67     compiled_methods_.clear();
68     patcher_.reset();
69     bss_begin_ = 0u;
70     string_index_to_offset_map_.clear();
71     method_index_to_offset_map_.clear();
72     compiled_method_refs_.clear();
73     compiled_methods_.clear();
74     patched_code_.clear();
75     output_.clear();
76     out_.reset();
77   }
78 
79   // Reset the helper to start another test. Creating and tearing down the Runtime is expensive,
80   // so we merge related tests together.
Reset()81   virtual void Reset() {
82     thunk_provider_.Reset();
83     method_offset_map_.map.clear();
84     patcher_ = RelativePatcher::Create(instruction_set_,
85                                        instruction_set_features_.get(),
86                                        &thunk_provider_,
87                                        &method_offset_map_);
88     bss_begin_ = 0u;
89     string_index_to_offset_map_.clear();
90     method_index_to_offset_map_.clear();
91     compiled_method_refs_.clear();
92     compiled_methods_.clear();
93     patched_code_.clear();
94     output_.clear();
95     out_.reset(new VectorOutputStream("test output stream", &output_));
96   }
97 
MethodRef(uint32_t method_idx)98   MethodReference MethodRef(uint32_t method_idx) {
99     CHECK_NE(method_idx, 0u);
100     return MethodReference(nullptr, method_idx);
101   }
102 
103   void AddCompiledMethod(
104       MethodReference method_ref,
105       const ArrayRef<const uint8_t>& code,
106       const ArrayRef<const LinkerPatch>& patches = ArrayRef<const LinkerPatch>()) {
107     compiled_method_refs_.push_back(method_ref);
108     compiled_methods_.emplace_back(new CompiledMethod(
109         &storage_,
110         instruction_set_,
111         code,
112         /* vmap_table */ ArrayRef<const uint8_t>(),
113         /* cfi_info */ ArrayRef<const uint8_t>(),
114         patches));
115   }
116 
CodeAlignmentSize(uint32_t header_offset_to_align)117   uint32_t CodeAlignmentSize(uint32_t header_offset_to_align) {
118     // We want to align the code rather than the preheader.
119     uint32_t unaligned_code_offset = header_offset_to_align + sizeof(OatQuickMethodHeader);
120     uint32_t aligned_code_offset =
121         CompiledMethod::AlignCode(unaligned_code_offset, instruction_set_);
122     return aligned_code_offset - unaligned_code_offset;
123   }
124 
Link()125   void Link() {
126     // Reserve space.
127     static_assert(kTrampolineOffset == 0u, "Unexpected trampoline offset.");
128     uint32_t offset = kTrampolineSize;
129     size_t idx = 0u;
130     for (auto& compiled_method : compiled_methods_) {
131       offset = patcher_->ReserveSpace(offset, compiled_method.get(), compiled_method_refs_[idx]);
132 
133       uint32_t alignment_size = CodeAlignmentSize(offset);
134       offset += alignment_size;
135 
136       offset += sizeof(OatQuickMethodHeader);
137       uint32_t quick_code_offset = offset + compiled_method->GetEntryPointAdjustment();
138       const auto code = compiled_method->GetQuickCode();
139       offset += code.size();
140 
141       method_offset_map_.map.Put(compiled_method_refs_[idx], quick_code_offset);
142       ++idx;
143     }
144     offset = patcher_->ReserveSpaceEnd(offset);
145     uint32_t output_size = offset;
146     output_.reserve(output_size);
147 
148     // Write data.
149     DCHECK(output_.empty());
150     uint8_t fake_trampoline[kTrampolineSize];
151     memset(fake_trampoline, 0, sizeof(fake_trampoline));
152     out_->WriteFully(fake_trampoline, kTrampolineSize);
153     offset = kTrampolineSize;
154     static const uint8_t kPadding[] = {
155         0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u
156     };
157     uint8_t fake_header[sizeof(OatQuickMethodHeader)];
158     memset(fake_header, 0, sizeof(fake_header));
159     for (auto& compiled_method : compiled_methods_) {
160       offset = patcher_->WriteThunks(out_.get(), offset);
161 
162       uint32_t alignment_size = CodeAlignmentSize(offset);
163       CHECK_LE(alignment_size, sizeof(kPadding));
164       out_->WriteFully(kPadding, alignment_size);
165       offset += alignment_size;
166 
167       out_->WriteFully(fake_header, sizeof(OatQuickMethodHeader));
168       offset += sizeof(OatQuickMethodHeader);
169       ArrayRef<const uint8_t> code = compiled_method->GetQuickCode();
170       if (!compiled_method->GetPatches().empty()) {
171         patched_code_.assign(code.begin(), code.end());
172         code = ArrayRef<const uint8_t>(patched_code_);
173         for (const LinkerPatch& patch : compiled_method->GetPatches()) {
174           if (patch.GetType() == LinkerPatch::Type::kCallRelative) {
175             auto result = method_offset_map_.FindMethodOffset(patch.TargetMethod());
176             uint32_t target_offset =
177                 result.first ? result.second
178                              : kTrampolineOffset + compiled_method->GetEntryPointAdjustment();
179             patcher_->PatchCall(&patched_code_,
180                                 patch.LiteralOffset(),
181                                 offset + patch.LiteralOffset(),
182                                 target_offset);
183           } else if (patch.GetType() == LinkerPatch::Type::kStringBssEntry) {
184             uint32_t target_offset =
185                 bss_begin_ +
186                 string_index_to_offset_map_.Get(patch.TargetString().StringIndex().index_);
187             patcher_->PatchPcRelativeReference(&patched_code_,
188                                                patch,
189                                                offset + patch.LiteralOffset(),
190                                                target_offset);
191           } else if (patch.GetType() == LinkerPatch::Type::kMethodBssEntry) {
192             uint32_t target_offset =
193                 bss_begin_ + method_index_to_offset_map_.Get(patch.TargetMethod().index);
194             patcher_->PatchPcRelativeReference(&patched_code_,
195                                                patch,
196                                                offset + patch.LiteralOffset(),
197                                                target_offset);
198           } else if (patch.GetType() == LinkerPatch::Type::kStringRelative) {
199             uint32_t target_offset =
200                 string_index_to_offset_map_.Get(patch.TargetString().StringIndex().index_);
201             patcher_->PatchPcRelativeReference(&patched_code_,
202                                                patch,
203                                                offset + patch.LiteralOffset(),
204                                                target_offset);
205           } else if (patch.GetType() == LinkerPatch::Type::kCallEntrypoint) {
206             patcher_->PatchEntrypointCall(&patched_code_,
207                                           patch,
208                                           offset + patch.LiteralOffset());
209           } else if (patch.GetType() == LinkerPatch::Type::kBakerReadBarrierBranch) {
210             patcher_->PatchBakerReadBarrierBranch(&patched_code_,
211                                                   patch,
212                                                   offset + patch.LiteralOffset());
213           } else {
214             LOG(FATAL) << "Bad patch type. " << patch.GetType();
215             UNREACHABLE();
216           }
217         }
218       }
219       out_->WriteFully(&code[0], code.size());
220       offset += code.size();
221     }
222     offset = patcher_->WriteThunks(out_.get(), offset);
223     CHECK_EQ(offset, output_size);
224     CHECK_EQ(output_.size(), output_size);
225   }
226 
CheckLinkedMethod(MethodReference method_ref,const ArrayRef<const uint8_t> & expected_code)227   bool CheckLinkedMethod(MethodReference method_ref, const ArrayRef<const uint8_t>& expected_code) {
228     // Check that the original code size must match linked_code.size().
229     size_t idx = 0u;
230     for (auto ref : compiled_method_refs_) {
231       if (ref == method_ref) {
232         break;
233       }
234       ++idx;
235     }
236     CHECK_NE(idx, compiled_method_refs_.size());
237     CHECK_EQ(compiled_methods_[idx]->GetQuickCode().size(), expected_code.size());
238 
239     auto result = method_offset_map_.FindMethodOffset(method_ref);
240     CHECK(result.first);  // Must have been linked.
241     size_t offset = result.second - compiled_methods_[idx]->GetEntryPointAdjustment();
242     CHECK_LT(offset, output_.size());
243     CHECK_LE(offset + expected_code.size(), output_.size());
244     ArrayRef<const uint8_t> linked_code(&output_[offset], expected_code.size());
245     if (linked_code == expected_code) {
246       return true;
247     }
248     // Log failure info.
249     DumpDiff(expected_code, linked_code);
250     return false;
251   }
252 
DumpDiff(const ArrayRef<const uint8_t> & expected_code,const ArrayRef<const uint8_t> & linked_code)253   void DumpDiff(const ArrayRef<const uint8_t>& expected_code,
254                 const ArrayRef<const uint8_t>& linked_code) {
255     std::ostringstream expected_hex;
256     std::ostringstream linked_hex;
257     std::ostringstream diff_indicator;
258     static const char digits[] = "0123456789abcdef";
259     bool found_diff = false;
260     for (size_t i = 0; i != expected_code.size(); ++i) {
261       expected_hex << " " << digits[expected_code[i] >> 4] << digits[expected_code[i] & 0xf];
262       linked_hex << " " << digits[linked_code[i] >> 4] << digits[linked_code[i] & 0xf];
263       if (!found_diff) {
264         found_diff = (expected_code[i] != linked_code[i]);
265         diff_indicator << (found_diff ? " ^^" : "   ");
266       }
267     }
268     CHECK(found_diff);
269     std::string expected_hex_str = expected_hex.str();
270     std::string linked_hex_str = linked_hex.str();
271     std::string diff_indicator_str = diff_indicator.str();
272     if (diff_indicator_str.length() > 60) {
273       CHECK_EQ(diff_indicator_str.length() % 3u, 0u);
274       size_t remove = diff_indicator_str.length() / 3 - 5;
275       std::ostringstream oss;
276       oss << "[stripped " << remove << "]";
277       std::string replacement = oss.str();
278       expected_hex_str.replace(0u, remove * 3u, replacement);
279       linked_hex_str.replace(0u, remove * 3u, replacement);
280       diff_indicator_str.replace(0u, remove * 3u, replacement);
281     }
282     LOG(ERROR) << "diff expected_code linked_code";
283     LOG(ERROR) << "<" << expected_hex_str;
284     LOG(ERROR) << ">" << linked_hex_str;
285     LOG(ERROR) << " " << diff_indicator_str;
286   }
287 
288   class ThunkProvider : public RelativePatcherThunkProvider {
289    public:
ThunkProvider()290     ThunkProvider() {}
291 
SetThunkCode(const LinkerPatch & patch,ArrayRef<const uint8_t> code,const std::string & debug_name)292     void SetThunkCode(const LinkerPatch& patch,
293                       ArrayRef<const uint8_t> code,
294                       const std::string& debug_name) {
295       thunk_map_.emplace(ThunkKey(patch), ThunkValue(code, debug_name));
296     }
297 
GetThunkCode(const LinkerPatch & patch,ArrayRef<const uint8_t> * code,std::string * debug_name)298     void GetThunkCode(const LinkerPatch& patch,
299                       /*out*/ ArrayRef<const uint8_t>* code,
300                       /*out*/ std::string* debug_name) override {
301       auto it = thunk_map_.find(ThunkKey(patch));
302       CHECK(it != thunk_map_.end());
303       const ThunkValue& value = it->second;
304       CHECK(code != nullptr);
305       *code = value.GetCode();
306       CHECK(debug_name != nullptr);
307       *debug_name = value.GetDebugName();
308     }
309 
Reset()310     void Reset() {
311       thunk_map_.clear();
312     }
313 
314    private:
315     class ThunkKey {
316      public:
ThunkKey(const LinkerPatch & patch)317       explicit ThunkKey(const LinkerPatch& patch)
318           : type_(patch.GetType()),
319             custom_value1_(CustomValue1(patch)),
320             custom_value2_(CustomValue2(patch)) {
321         CHECK(patch.GetType() == LinkerPatch::Type::kCallEntrypoint ||
322               patch.GetType() == LinkerPatch::Type::kBakerReadBarrierBranch ||
323               patch.GetType() == LinkerPatch::Type::kCallRelative);
324       }
325 
326       bool operator<(const ThunkKey& other) const {
327         if (custom_value1_ != other.custom_value1_) {
328           return custom_value1_ < other.custom_value1_;
329         }
330         if (custom_value2_ != other.custom_value2_) {
331           return custom_value2_ < other.custom_value2_;
332         }
333         return type_ < other.type_;
334       }
335 
336      private:
CustomValue1(const LinkerPatch & patch)337       static uint32_t CustomValue1(const LinkerPatch& patch) {
338         switch (patch.GetType()) {
339           case LinkerPatch::Type::kCallEntrypoint:
340             return patch.EntrypointOffset();
341           case LinkerPatch::Type::kBakerReadBarrierBranch:
342             return patch.GetBakerCustomValue1();
343           default:
344             return 0;
345         }
346       }
347 
CustomValue2(const LinkerPatch & patch)348       static uint32_t CustomValue2(const LinkerPatch& patch) {
349         switch (patch.GetType()) {
350           case LinkerPatch::Type::kBakerReadBarrierBranch:
351             return patch.GetBakerCustomValue2();
352           default:
353             return 0;
354         }
355       }
356 
357       const LinkerPatch::Type type_;
358       const uint32_t custom_value1_;
359       const uint32_t custom_value2_;
360     };
361 
362     class ThunkValue {
363      public:
ThunkValue(ArrayRef<const uint8_t> code,const std::string & debug_name)364       ThunkValue(ArrayRef<const uint8_t> code, const std::string& debug_name)
365           : code_(code.begin(), code.end()), debug_name_(debug_name) {}
GetCode()366       ArrayRef<const uint8_t> GetCode() const { return ArrayRef<const uint8_t>(code_); }
GetDebugName()367       const std::string& GetDebugName() const { return debug_name_; }
368 
369      private:
370       const std::vector<uint8_t> code_;
371       const std::string debug_name_;
372     };
373 
374     std::map<ThunkKey, ThunkValue> thunk_map_;
375   };
376 
377   // Map method reference to assinged offset.
378   // Wrap the map in a class implementing RelativePatcherTargetProvider.
379   class MethodOffsetMap final : public RelativePatcherTargetProvider {
380    public:
FindMethodOffset(MethodReference ref)381     std::pair<bool, uint32_t> FindMethodOffset(MethodReference ref) override {
382       auto it = map.find(ref);
383       if (it == map.end()) {
384         return std::pair<bool, uint32_t>(false, 0u);
385       } else {
386         return std::pair<bool, uint32_t>(true, it->second);
387       }
388     }
389     SafeMap<MethodReference, uint32_t> map;
390   };
391 
392   static const uint32_t kTrampolineSize = 4u;
393   static const uint32_t kTrampolineOffset = 0u;
394 
395   CompiledMethodStorage storage_;
396   InstructionSet instruction_set_;
397   std::unique_ptr<const InstructionSetFeatures> instruction_set_features_;
398 
399   ThunkProvider thunk_provider_;
400   MethodOffsetMap method_offset_map_;
401   std::unique_ptr<RelativePatcher> patcher_;
402   uint32_t bss_begin_;
403   SafeMap<uint32_t, uint32_t> string_index_to_offset_map_;
404   SafeMap<uint32_t, uint32_t> method_index_to_offset_map_;
405   std::vector<MethodReference> compiled_method_refs_;
406   std::vector<std::unique_ptr<CompiledMethod>> compiled_methods_;
407   std::vector<uint8_t> patched_code_;
408   std::vector<uint8_t> output_;
409   std::unique_ptr<VectorOutputStream> out_;
410 };
411 
412 }  // namespace linker
413 }  // namespace art
414 
415 #endif  // ART_DEX2OAT_LINKER_RELATIVE_PATCHER_TEST_H_
416