/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "berberis/backend/x86_64/code_emit.h"

#include <iterator>  // std::next
#include <utility>

#include "berberis/assembler/x86_64.h"
#include "berberis/backend/code_emitter.h"
#include "berberis/backend/x86_64/machine_ir.h"
#include "berberis/base/arena_vector.h"
#include "berberis/base/logging.h"
#include "berberis/code_gen_lib/code_gen_lib.h"
#include "berberis/guest_state/guest_addr.h"
#include "berberis/runtime_primitives/host_code.h"  // AsHostCode

namespace berberis {

using Assembler = x86_64::Assembler;

namespace x86_64 {

namespace {

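// The EmitMov* helpers below copy a value between two locations: a general-purpose
// register ("GReg"), an XMM register ("XReg"), or a spill slot in the stack frame
// addressed relative to rsp ("Mem"). EmitCopy() dispatches to the right helper based
// on the kinds of the destination and source registers.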
void EmitMovGRegGReg(CodeEmitter* as, MachineReg dst, MachineReg src, int /* size */) {
  as->Movq(GetGReg(dst), GetGReg(src));
}

void EmitMovGRegXReg(CodeEmitter* as, MachineReg dst, MachineReg src, int /* size */) {
  as->Movq(GetGReg(dst), GetXReg(src));
}

void EmitMovGRegMem(CodeEmitter* as, MachineReg dst, MachineReg src, int /* size */) {
  // TODO(b/207399902): Make this cast safe
  int offset = static_cast<int>(src.GetSpilledRegIndex());
  as->Movq(GetGReg(dst), {.base = Assembler::rsp, .disp = offset});
}

void EmitMovXRegGReg(CodeEmitter* as, MachineReg dst, MachineReg src, int /* size */) {
  as->Movq(GetXReg(dst), GetGReg(src));
}

void EmitMovXRegXReg(CodeEmitter* as, MachineReg dst, MachineReg src, int /* size */) {
  as->Pmov(GetXReg(dst), GetXReg(src));
}

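// EmitMovXRegMem and EmitMovMemXReg below pick the memory access width from the
// register size: a 16-byte Movdqu for sizes above 8, an 8-byte Movsd for sizes
// above 4, and a 4-byte Movss otherwise.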
void EmitMovXRegMem(CodeEmitter* as, MachineReg dst, MachineReg src, int size) {
  // TODO(b/207399902): Make this cast safe
  int offset = static_cast<int>(src.GetSpilledRegIndex());
  if (size > 8) {
    as->MovdquXRegMemBaseDisp(GetXReg(dst), Assembler::rsp, offset);
  } else if (size > 4) {
    as->MovsdXRegMemBaseDisp(GetXReg(dst), Assembler::rsp, offset);
  } else {
    as->Movss(GetXReg(dst), {.base = Assembler::rsp, .disp = offset});
  }
}

void EmitMovMemGReg(CodeEmitter* as, MachineReg dst, MachineReg src, int /* size */) {
  // TODO(b/207399902): Make this cast safe
  int offset = static_cast<int>(dst.GetSpilledRegIndex());
  as->Movq({.base = Assembler::rsp, .disp = offset}, GetGReg(src));
}

void EmitMovMemXReg(CodeEmitter* as, MachineReg dst, MachineReg src, int size) {
  // TODO(b/207399902): Make this cast safe
  int offset = static_cast<int>(dst.GetSpilledRegIndex());
  if (size > 8) {
    as->MovdquMemBaseDispXReg(Assembler::rsp, offset, GetXReg(src));
  } else if (size > 4) {
    as->MovsdMemBaseDispXReg(Assembler::rsp, offset, GetXReg(src));
  } else {
    as->Movss({.base = Assembler::rsp, .disp = offset}, GetXReg(src));
  }
}

void EmitMovMemMem(CodeEmitter* as, MachineReg dst, MachineReg src, int size) {
  // ATTENTION: memory to memory copy, very inefficient!
  // TODO(b/207399902): Make this cast safe
  int dst_offset = static_cast<int>(dst.GetSpilledRegIndex());
  int src_offset = static_cast<int>(src.GetSpilledRegIndex());
  for (int part = 0; part < size; part += 8) {
    // The source address is computed with rsp BEFORE Pushq decrements it.
    as->Pushq({.base = Assembler::rsp, .disp = src_offset + part});
    // The destination address is computed with rsp AFTER Popq increments it.
    as->Popq({.base = Assembler::rsp, .disp = dst_offset + part});
  }
}

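// Dispatches to the EmitMov* helper that matches the location kinds of dst and src
// (rows: dst, columns: src; "Emit" prefix omitted):
//
//            mem           XMM           GP
//   mem      MovMemMem     MovMemXReg    MovMemGReg
//   XMM      MovXRegMem    MovXRegXReg   MovXRegGReg
//   GP       MovGRegMem    MovGRegXReg   MovGRegGReg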
void EmitCopy(CodeEmitter* as, MachineReg dst, MachineReg src, int size) {
  if (dst.IsSpilledReg()) {
    if (src.IsSpilledReg()) {
      EmitMovMemMem(as, dst, src, size);
    } else if (IsXReg(src)) {
      EmitMovMemXReg(as, dst, src, size);
    } else {
      EmitMovMemGReg(as, dst, src, size);
    }
  } else if (IsXReg(dst)) {
    if (src.IsSpilledReg()) {
      EmitMovXRegMem(as, dst, src, size);
    } else if (IsXReg(src)) {
      EmitMovXRegXReg(as, dst, src, size);
    } else {
      EmitMovXRegGReg(as, dst, src, size);
    }
  } else {
    if (src.IsSpilledReg()) {
      EmitMovGRegMem(as, dst, src, size);
    } else if (IsXReg(src)) {
      EmitMovGRegXReg(as, dst, src, size);
    } else {
      EmitMovGRegGReg(as, dst, src, size);
    }
  }
}

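// Recovery labels pair a label bound inside the generated code with the guest address
// to resume from. EmitRecoveryLabels() emits, for each recorded pair, a stub that loads
// the guest address into rax and jumps to a single shared exit.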
using RecoveryLabels = ArenaVector<std::pair<CodeEmitter::Label*, GuestAddr>>;

void EmitRecoveryLabels(CodeEmitter* as, const RecoveryLabels& labels) {
  if (labels.empty()) {
    return;
  }

  auto* exit_label = as->MakeLabel();

  for (auto pair : labels) {
    as->Bind(pair.first);
    // EmitExitGeneratedCode is more efficient if it receives the target in rax.
    as->Movq(as->rax, pair.second);
    // The exit uses a Jmp to a full 64-bit address and is 14 bytes long, which is expensive.
    // Thus we generate a local relative jump to the common exit label here.
    // It's up to 5 bytes, but likely 2 bytes since the distance is expected to be short.
    as->Jmp(*exit_label);
  }

  as->Bind(exit_label);

  if (as->exit_label_for_testing()) {
    as->Jmp(*as->exit_label_for_testing());
    return;
  }

  EmitExitGeneratedCode(as, as->rax);
}

}  // namespace

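// Maps a general-purpose MachineReg to the corresponding host register.
// The table is indexed by MachineReg::reg(); index 0 maps to no_register.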
Assembler::Register GetGReg(MachineReg r) {
  static constexpr Assembler::Register kHardRegs[] = {Assembler::no_register,
                                                      Assembler::r8,
                                                      Assembler::r9,
                                                      Assembler::r10,
                                                      Assembler::r11,
                                                      Assembler::rsi,
                                                      Assembler::rdi,
                                                      Assembler::rax,
                                                      Assembler::rbx,
                                                      Assembler::rcx,
                                                      Assembler::rdx,
                                                      Assembler::rbp,
                                                      Assembler::rsp,
                                                      Assembler::r12,
                                                      Assembler::r13,
                                                      Assembler::r14,
                                                      Assembler::r15};
  CHECK_LT(static_cast<unsigned>(r.reg()), std::size(kHardRegs));
  return kHardRegs[r.reg()];
}

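// Maps an XMM MachineReg to the corresponding host XMM register. XMM machine
// registers are numbered consecutively starting at kMachineRegXMM0.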
Assembler::XMMRegister GetXReg(MachineReg r) {
  static constexpr Assembler::XMMRegister kHardRegs[] = {
      Assembler::xmm0,
      Assembler::xmm1,
      Assembler::xmm2,
      Assembler::xmm3,
      Assembler::xmm4,
      Assembler::xmm5,
      Assembler::xmm6,
      Assembler::xmm7,
      Assembler::xmm8,
      Assembler::xmm9,
      Assembler::xmm10,
      Assembler::xmm11,
      Assembler::xmm12,
      Assembler::xmm13,
      Assembler::xmm14,
      Assembler::xmm15,
  };
  CHECK_GE(r.reg(), kMachineRegXMM0.reg());
  CHECK_LT(static_cast<unsigned>(r.reg() - kMachineRegXMM0.reg()), std::size(kHardRegs));
  return kHardRegs[r.reg() - kMachineRegXMM0.reg()];
}

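// Converts a MachineIR memory operand scale to the assembler's scale factor encoding.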
Assembler::ScaleFactor ToScaleFactor(MachineMemOperandScale scale) {
  switch (scale) {
    case MachineMemOperandScale::kOne:
      return Assembler::kTimesOne;
    case MachineMemOperandScale::kTwo:
      return Assembler::kTimesTwo;
    case MachineMemOperandScale::kFour:
      return Assembler::kTimesFour;
    case MachineMemOperandScale::kEight:
      return Assembler::kTimesEight;
  }
}

void CallImm::Emit(CodeEmitter* as) const {
  as->Call(AsHostCode(imm()));
}

}  // namespace x86_64

void PseudoBranch::Emit(CodeEmitter* as) const {
  const Assembler::Label* then_label = as->GetLabelAt(then_bb()->id());

  if (as->next_label() == then_label) {
    // We do not need to emit any instruction as we fall through to
    // the next basic block.
    return;
  }

  as->Jmp(*then_label);
}

void PseudoCondBranch::Emit(CodeEmitter* as) const {
  const Assembler::Label* then_label = as->GetLabelAt(then_bb()->id());
  const Assembler::Label* else_label = as->GetLabelAt(else_bb()->id());

  if (as->next_label() == else_label) {
    // We do not need to emit JMP as our "else" arm falls through to
    // the next basic block.
    as->Jcc(cond_, *then_label);
  } else if (as->next_label() == then_label) {
    // Reverse the condition and emit Jcc to else_label().  We do not
    // need to emit JMP as our original (that is, before reversing)
    // "then" arm falls through to the next basic block.
    as->Jcc(ToReverseCond(cond()), *else_label);
  } else {
    // Neither our "then" nor "else" arm falls through to the next
    // basic block.  We need to emit both Jcc and Jmp.
    as->Jcc(cond(), *then_label);
    as->Jmp(*else_label);
  }
}

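// Emits a jump to a guest address: frees the stack frame and then dispatches according
// to the jump kind (direct dispatch with or without a pending-signals check, syscall,
// or exit from generated code).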
void PseudoJump::Emit(CodeEmitter* as) const {
  EmitFreeStackFrame(as, as->frame_size());

  if (as->exit_label_for_testing()) {
    as->Movq(as->rax, target_);
    as->Jmp(*as->exit_label_for_testing());
    return;
  }

  switch (kind_) {
    case Kind::kJumpWithPendingSignalsCheck:
      EmitDirectDispatch(as, target_, true);
      break;
    case Kind::kJumpWithoutPendingSignalsCheck:
      EmitDirectDispatch(as, target_, false);
      break;
    case Kind::kSyscall:
      EmitSyscall(as, target_);
      break;
    case Kind::kExitGeneratedCode:
      as->Movq(as->rax, target_);
      EmitExitGeneratedCode(as, as->rax);
      break;
  }
}

void PseudoIndirectJump::Emit(CodeEmitter* as) const {
  EmitFreeStackFrame(as, as->frame_size());
  if (as->exit_label_for_testing()) {
    as->Movq(as->rax, x86_64::GetGReg(RegAt(0)));
    as->Jmp(*as->exit_label_for_testing());
    return;
  }
  EmitIndirectDispatch(as, x86_64::GetGReg(RegAt(0)));
}

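// Copies between the locations assigned to the two operand registers; nothing is
// emitted when they are the same. Operands are expected to share a register class
// (see the CHECK below).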
void PseudoCopy::Emit(CodeEmitter* as) const {
  MachineReg dst = RegAt(0);
  MachineReg src = RegAt(1);
  if (src == dst) {
    return;
  }
  // Operands should have equal register classes!
  CHECK_EQ(RegKindAt(0).RegClass(), RegKindAt(1).RegClass());
  // TODO(b/232598137): Why get size by class then pick insn by size instead of pick insn by class?
  int size = RegKindAt(0).RegClass()->RegSize();
  x86_64::EmitCopy(as, dst, src, size);
}

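// Reads host flags into rax: Lahf copies SF, ZF, AF, PF and CF into AH, and the low
// byte is filled with the overflow flag (via Setcc) or with zero when overflow is not
// requested.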
void PseudoReadFlags::Emit(CodeEmitter* as) const {
  as->Lahf();
  if (with_overflow()) {
    as->Setcc(CodeEmitter::Condition::kOverflow, as->rax);
  } else {
    // Still need to fill overflow with zero.
    as->Movb(as->rax, int8_t{0});
  }
}

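// Restores host flags from rax, which is expected to be in the PseudoReadFlags layout
// with the overflow flag as 0 or 1 in the low byte: adding 0x7f overflows the signed
// byte exactly when it is 1 (1 + 0x7f = 0x80 sets OF, 0 + 0x7f = 0x7f does not), and
// Sahf then restores SF, ZF, AF, PF and CF from AH.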
void PseudoWriteFlags::Emit(CodeEmitter* as) const {
  as->Addb(as->rax, int8_t{0x7f});
  as->Sahf();
}

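// Emits the whole MachineIR: allocates the stack frame, emits each basic block in
// order while recording recovery labels, and finally emits the shared recovery exits.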
void MachineIR::Emit(CodeEmitter* as) const {
  EmitAllocStackFrame(as, as->frame_size());
  ArenaVector<std::pair<CodeEmitter::Label*, GuestAddr>> recovery_labels(arena());

  for (auto bb_it = bb_list().begin(); bb_it != bb_list().end(); ++bb_it) {
    const MachineBasicBlock* bb = *bb_it;
    as->Bind(as->GetLabelAt(bb->id()));

    // Let CodeEmitter know the label of the next basic block, if any.
    // This label can be used, e.g., by PseudoBranch and PseudoCondBranch
    // to avoid generating jumps to the next basic block.
    auto next_bb_it = std::next(bb_it);
    if (next_bb_it == bb_list().end()) {
      as->set_next_label(nullptr);
    } else {
      as->set_next_label(as->GetLabelAt((*next_bb_it)->id()));
    }

    for (const auto* insn : bb->insn_list()) {
      if (insn->recovery_bb()) {
        as->SetRecoveryPoint(as->GetLabelAt(insn->recovery_bb()->id()));
      } else if (insn->recovery_pc() != kNullGuestAddr) {
        auto* label = as->MakeLabel();
        as->SetRecoveryPoint(label);
        recovery_labels.push_back(std::make_pair(label, insn->recovery_pc()));
      }
      insn->Emit(as);
    }
  }

  x86_64::EmitRecoveryLabels(as, recovery_labels);
}

}  // namespace berberis