1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "berberis/backend/x86_64/insn_folding.h"
18 
19 #include <cstdint>
20 #include <tuple>
21 
22 #include "berberis/backend/common/machine_ir.h"
23 #include "berberis/backend/x86_64/machine_ir.h"
24 
25 #include "berberis/backend/code_emitter.h"  // for CodeEmitter::Condition
26 #include "berberis/base/algorithm.h"
27 #include "berberis/base/bit_util.h"
28 #include "berberis/base/logging.h"
29 
30 namespace berberis::x86_64 {
31 
MapDefRegs(const MachineInsn * insn)32 void DefMap::MapDefRegs(const MachineInsn* insn) {
33   for (int op = 0; op < insn->NumRegOperands(); ++op) {
34     MachineReg reg = insn->RegAt(op);
35     if (insn->RegKindAt(op).RegClass()->IsSubsetOf(&x86_64::kFLAGS)) {
36       if (flags_reg_ == kInvalidMachineReg) {
37         flags_reg_ = reg;
38       }
39       // Some optimizations assume flags is the same virtual register everywhere.
40       CHECK(reg == flags_reg_);
41     }
42     if (insn->RegKindAt(op).IsDef()) {
43       Set(reg, insn);
44     }
45   }
46 }
47 
ProcessInsn(const MachineInsn * insn)48 void DefMap::ProcessInsn(const MachineInsn* insn) {
49   MapDefRegs(insn);
50   ++index_;
51 }
52 
Initialize()53 void DefMap::Initialize() {
54   std::fill(def_map_.begin(), def_map_.end(), std::pair(nullptr, 0));
55   flags_reg_ = kInvalidMachineReg;
56   index_ = 0;
57 }
58 
IsRegImm(MachineReg reg,uint64_t * imm) const59 bool InsnFolding::IsRegImm(MachineReg reg, uint64_t* imm) const {
60   auto [general_insn, _] = def_map_.Get(reg);
61   if (!general_insn) {
62     return false;
63   }
64   const auto* insn = AsMachineInsnX86_64(general_insn);
65   if (insn->opcode() == kMachineOpMovqRegImm) {
66     *imm = insn->imm();
67     return true;
68   } else if (insn->opcode() == kMachineOpMovlRegImm) {
69     // Take into account zero-extension by MOVL.
70     *imm = static_cast<uint64_t>(static_cast<uint32_t>(insn->imm()));
71     return true;
72   }
73   return false;
74 }
75 
NewImmInsnFromRegInsn(const MachineInsn * insn,int32_t imm32)76 MachineInsn* InsnFolding::NewImmInsnFromRegInsn(const MachineInsn* insn, int32_t imm32) {
77   MachineInsn* folded_insn;
78   switch (insn->opcode()) {
79     case kMachineOpAddqRegReg:
80       folded_insn = machine_ir_->NewInsn<AddqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
81       break;
82     case kMachineOpSubqRegReg:
83       folded_insn = machine_ir_->NewInsn<SubqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
84       break;
85     case kMachineOpCmpqRegReg:
86       folded_insn = machine_ir_->NewInsn<CmpqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
87       break;
88     case kMachineOpOrqRegReg:
89       folded_insn = machine_ir_->NewInsn<OrqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
90       break;
91     case kMachineOpXorqRegReg:
92       folded_insn = machine_ir_->NewInsn<XorqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
93       break;
94     case kMachineOpAndqRegReg:
95       folded_insn = machine_ir_->NewInsn<AndqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
96       break;
97     case kMachineOpTestqRegReg:
98       folded_insn = machine_ir_->NewInsn<TestqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
99       break;
100     case kMachineOpMovlRegReg:
101       folded_insn = machine_ir_->NewInsn<MovlRegImm>(insn->RegAt(0), imm32);
102       break;
103     case kMachineOpAddlRegReg:
104       folded_insn = machine_ir_->NewInsn<AddlRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
105       break;
106     case kMachineOpSublRegReg:
107       folded_insn = machine_ir_->NewInsn<SublRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
108       break;
109     case kMachineOpCmplRegReg:
110       folded_insn = machine_ir_->NewInsn<CmplRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
111       break;
112     case kMachineOpOrlRegReg:
113       folded_insn = machine_ir_->NewInsn<OrlRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
114       break;
115     case kMachineOpXorlRegReg:
116       folded_insn = machine_ir_->NewInsn<XorlRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
117       break;
118     case kMachineOpAndlRegReg:
119       folded_insn = machine_ir_->NewInsn<AndlRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
120       break;
121     case kMachineOpTestlRegReg:
122       folded_insn = machine_ir_->NewInsn<TestlRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
123       break;
124     case kMachineOpMovlMemBaseDispReg:
125       folded_insn = machine_ir_->NewInsn<MovlMemBaseDispImm>(
126           insn->RegAt(0), AsMachineInsnX86_64(insn)->disp(), imm32);
127       break;
128     case kMachineOpMovqMemBaseDispReg:
129       folded_insn = machine_ir_->NewInsn<MovqMemBaseDispImm>(
130           insn->RegAt(0), AsMachineInsnX86_64(insn)->disp(), imm32);
131       break;
132     default:
133       LOG_ALWAYS_FATAL("unexpected opcode");
134   }
135   // Inherit the additional attributes.
136   folded_insn->set_recovery_bb(insn->recovery_bb());
137   folded_insn->set_recovery_pc(insn->recovery_pc());
138   return folded_insn;
139 }
140 
IsWritingSameFlagsValue(const MachineInsn * write_flags_insn) const141 bool InsnFolding::IsWritingSameFlagsValue(const MachineInsn* write_flags_insn) const {
142   CHECK(write_flags_insn && write_flags_insn->opcode() == kMachineOpPseudoWriteFlags);
143   MachineReg src_reg = write_flags_insn->RegAt(0);
144   auto [def_insn, def_insn_pos] = def_map_.Get(src_reg);
145   // Warning: We are assuming that all flags writes in IR happen to the same virtual register.
146   while (true) {
147     if (!def_insn) {
148       return false;
149     }
150 
151     int opcode = def_insn->opcode();
152     if (opcode == kMachineOpPseudoCopy) {
153       src_reg = def_insn->RegAt(1);
154       std::tie(def_insn, def_insn_pos) = def_map_.Get(src_reg, def_insn_pos);
155       continue;
156     } else if (opcode == kMachineOpPseudoReadFlags) {
157       break;
158     }
159     return false;
160   }
161 
162   // Instruction is PseudoReadFlags.
163   if (write_flags_insn->RegAt(1) != def_insn->RegAt(1)) {
164     return false;
165   }
166   auto [flag_def_insn, _] = def_map_.Get(write_flags_insn->RegAt(1), def_insn_pos);
167   return flag_def_insn != nullptr;
168 }
169 
170 template <bool is_input_64bit>
TryFoldImmediateInput(const MachineInsn * insn)171 std::tuple<bool, MachineInsn*> InsnFolding::TryFoldImmediateInput(const MachineInsn* insn) {
172   auto src = insn->RegAt(1);
173   uint64_t imm64;
174   if (!IsRegImm(src, &imm64)) {
175     return {false, nullptr};
176   }
177 
178   // MovqRegReg is the only instruction that can encode full 64-bit immediate.
179   if (insn->opcode() == kMachineOpMovqRegReg) {
180     return {true, machine_ir_->NewInsn<MovqRegImm>(insn->RegAt(0), imm64)};
181   }
182 
183   int64_t signed_imm = bit_cast<int64_t>(imm64);
184   int32_t signed_imm32 = static_cast<int32_t>(signed_imm);
185   if (!is_input_64bit) {
186     // Use the lower half of the register as the immediate operand.
187     return {true, NewImmInsnFromRegInsn(insn, signed_imm32)};
188   }
189 
190   // Except for MOVQ x86 doesn't allow to encode 64-bit immediates. That said,
191   // we can encode 32-bit immediates that are sign-extended by hardware to
192   // 64-bit during instruction execution.
193   if (signed_imm == static_cast<int64_t>(signed_imm32)) {
194     return {true, NewImmInsnFromRegInsn(insn, signed_imm32)};
195   }
196 
197   return {false, nullptr};
198 }
199 
TryFoldRedundantMovl(const MachineInsn * insn)200 std::tuple<bool, MachineInsn*> InsnFolding::TryFoldRedundantMovl(const MachineInsn* insn) {
201   CHECK_EQ(insn->opcode(), kMachineOpMovlRegReg);
202   auto src = insn->RegAt(1);
203   auto [def_insn, _] = def_map_.Get(src);
204 
205   if (!def_insn) {
206     return {false, nullptr};
207   }
208 
209   // If the definition of src clears its upper half, then we can replace MOVL with PseudoCopy.
210   switch (def_insn->opcode()) {
211     case kMachineOpMovlRegReg:
212     case kMachineOpAndlRegReg:
213     case kMachineOpXorlRegReg:
214     case kMachineOpOrlRegReg:
215     case kMachineOpSublRegReg:
216     case kMachineOpAddlRegReg:
217       return {true, machine_ir_->NewInsn<PseudoCopy>(insn->RegAt(0), src, 4)};
218     default:
219       return {false, nullptr};
220   }
221 }
222 
TryFoldInsn(const MachineInsn * insn)223 std::tuple<bool, MachineInsn*> InsnFolding::TryFoldInsn(const MachineInsn* insn) {
224   switch (insn->opcode()) {
225     case kMachineOpMovqMemBaseDispReg:
226     case kMachineOpMovqRegReg:
227     case kMachineOpAndqRegReg:
228     case kMachineOpTestqRegReg:
229     case kMachineOpXorqRegReg:
230     case kMachineOpOrqRegReg:
231     case kMachineOpSubqRegReg:
232     case kMachineOpCmpqRegReg:
233     case kMachineOpAddqRegReg:
234       return TryFoldImmediateInput<true>(insn);
235     case kMachineOpMovlRegReg: {
236       auto [is_folded, folded_insn] = TryFoldImmediateInput<false>(insn);
237       if (is_folded) {
238         return {is_folded, folded_insn};
239       }
240 
241       return TryFoldRedundantMovl(insn);
242     }
243     case kMachineOpMovlMemBaseDispReg:
244     case kMachineOpAndlRegReg:
245     case kMachineOpTestlRegReg:
246     case kMachineOpXorlRegReg:
247     case kMachineOpOrlRegReg:
248     case kMachineOpSublRegReg:
249     case kMachineOpCmplRegReg:
250     case kMachineOpAddlRegReg:
251       return TryFoldImmediateInput<false>(insn);
252     case kMachineOpPseudoWriteFlags: {
253       if (IsWritingSameFlagsValue(insn)) {
254         return {true, nullptr};
255       }
256       break;
257     }
258     default:
259       return {false, nullptr};
260   }
261   return {false, nullptr};
262 }
263 
FoldInsns(MachineIR * machine_ir)264 void FoldInsns(MachineIR* machine_ir) {
265   DefMap def_map(machine_ir->NumVReg(), machine_ir->arena());
266   for (auto* bb : machine_ir->bb_list()) {
267     def_map.Initialize();
268     InsnFolding insn_folding(def_map, machine_ir);
269     MachineInsnList& insn_list = bb->insn_list();
270 
271     for (auto insn_it = insn_list.begin(); insn_it != insn_list.end();) {
272       auto [is_folded, new_insn] = insn_folding.TryFoldInsn(*insn_it);
273 
274       if (is_folded) {
275         insn_it = insn_list.erase(insn_it);
276         if (new_insn) {
277           insn_list.insert(insn_it, new_insn);
278           def_map.ProcessInsn(new_insn);
279         }
280       } else {
281         def_map.ProcessInsn(*insn_it);
282         ++insn_it;
283       }
284     }
285   }
286 }
287 
288 // TODO(b/179708579): Maybe combine with FoldInsns.
// Rewrites a block-terminating [PseudoWriteFlags; PseudoCondBranch] pair into
// [TESTW flags_src, mask; inverted branch]: instead of materializing flags and
// branching on the hardware condition, test the corresponding bit in the
// flags-source register directly. Only simple single-flag conditions are
// handled; flags must be dead at block exit.
void FoldWriteFlags(MachineIR* machine_ir) {
  for (auto* bb : machine_ir->bb_list()) {
    CHECK(!bb->insn_list().empty());
    // Only blocks that end with a conditional branch are candidates.
    auto insn_it = std::prev(bb->insn_list().end());
    if ((*insn_it)->opcode() != kMachineOpPseudoCondBranch) {
      continue;
    }

    // The branch must be immediately preceded by the flags write.
    auto* branch = static_cast<PseudoCondBranch*>(*insn_it);
    const auto* write_flags = *(--insn_it);
    if (write_flags->opcode() != kMachineOpPseudoWriteFlags) {
      continue;
    }
    // There is only one flags register, so CondBranch must read flags from WriteFlags.
    MachineReg flags = write_flags->RegAt(1);
    CHECK_EQ(flags.reg(), branch->RegAt(0).reg());

    const auto& live_out = bb->live_out();
    if (Contains(live_out, flags)) {
      // Flags are living-out. Cannot remove.
      // TODO(b/179708579): This shouldn't happen. Consider conversion to an assert.
      continue;
    }

    // Map the branch condition to (a) the flag bit to test in the source
    // register and (b) the inverted condition: TESTW sets ZF when the masked
    // bit is clear, so e.g. "branch if zero-flag set" becomes "TESTW kZero;
    // branch if not-zero".
    using Cond = CodeEmitter::Condition;
    Cond new_cond = Cond::kInvalidCondition;
    PseudoWriteFlags::Flags flags_mask;

    switch (branch->cond()) {
      // Verify that the flags are within the bottom 16 bits, so we can use Testw.
      static_assert(sizeof(PseudoWriteFlags::Flags) == 2);
      case Cond::kZero:
        new_cond = Cond::kNotZero;
        flags_mask = PseudoWriteFlags::Flags::kZero;
        break;
      case Cond::kNotZero:
        new_cond = Cond::kZero;
        flags_mask = PseudoWriteFlags::Flags::kZero;
        break;
      case Cond::kCarry:
        new_cond = Cond::kNotZero;
        flags_mask = PseudoWriteFlags::Flags::kCarry;
        break;
      case Cond::kNotCarry:
        new_cond = Cond::kZero;
        flags_mask = PseudoWriteFlags::Flags::kCarry;
        break;
      case Cond::kNegative:
        new_cond = Cond::kNotZero;
        flags_mask = PseudoWriteFlags::Flags::kNegative;
        break;
      case Cond::kNotSign:
        new_cond = Cond::kZero;
        flags_mask = PseudoWriteFlags::Flags::kNegative;
        break;
      case Cond::kOverflow:
        new_cond = Cond::kNotZero;
        flags_mask = PseudoWriteFlags::Flags::kOverflow;
        break;
      case Cond::kNoOverflow:
        new_cond = Cond::kZero;
        flags_mask = PseudoWriteFlags::Flags::kOverflow;
        break;
      default:
        // Compound conditions (e.g. signed comparisons) are not folded.
        continue;
    }

    // Replace WriteFlags with "TESTW flags_src, mask" (defining flags so the
    // branch's register use stays satisfied) and invert the branch condition.
    MachineReg flags_src = write_flags->RegAt(0);
    MachineInsn* new_write_flags =
        machine_ir->NewInsn<x86_64::TestwRegImm>(flags_src, flags_mask, flags);
    insn_it = bb->insn_list().erase(insn_it);
    bb->insn_list().insert(insn_it, new_write_flags);
    branch->set_cond(new_cond);
  }
}
364 
365 }  // namespace berberis::x86_64
366