1 /*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "berberis/backend/x86_64/insn_folding.h"
18
19 #include <cstdint>
20 #include <tuple>
21
22 #include "berberis/backend/common/machine_ir.h"
23 #include "berberis/backend/x86_64/machine_ir.h"
24
25 #include "berberis/backend/code_emitter.h" // for CodeEmitter::Condition
26 #include "berberis/base/algorithm.h"
27 #include "berberis/base/bit_util.h"
28 #include "berberis/base/logging.h"
29
30 namespace berberis::x86_64 {
31
MapDefRegs(const MachineInsn * insn)32 void DefMap::MapDefRegs(const MachineInsn* insn) {
33 for (int op = 0; op < insn->NumRegOperands(); ++op) {
34 MachineReg reg = insn->RegAt(op);
35 if (insn->RegKindAt(op).RegClass()->IsSubsetOf(&x86_64::kFLAGS)) {
36 if (flags_reg_ == kInvalidMachineReg) {
37 flags_reg_ = reg;
38 }
39 // Some optimizations assume flags is the same virtual register everywhere.
40 CHECK(reg == flags_reg_);
41 }
42 if (insn->RegKindAt(op).IsDef()) {
43 Set(reg, insn);
44 }
45 }
46 }
47
ProcessInsn(const MachineInsn * insn)48 void DefMap::ProcessInsn(const MachineInsn* insn) {
49 MapDefRegs(insn);
50 ++index_;
51 }
52
Initialize()53 void DefMap::Initialize() {
54 std::fill(def_map_.begin(), def_map_.end(), std::pair(nullptr, 0));
55 flags_reg_ = kInvalidMachineReg;
56 index_ = 0;
57 }
58
IsRegImm(MachineReg reg,uint64_t * imm) const59 bool InsnFolding::IsRegImm(MachineReg reg, uint64_t* imm) const {
60 auto [general_insn, _] = def_map_.Get(reg);
61 if (!general_insn) {
62 return false;
63 }
64 const auto* insn = AsMachineInsnX86_64(general_insn);
65 if (insn->opcode() == kMachineOpMovqRegImm) {
66 *imm = insn->imm();
67 return true;
68 } else if (insn->opcode() == kMachineOpMovlRegImm) {
69 // Take into account zero-extension by MOVL.
70 *imm = static_cast<uint64_t>(static_cast<uint32_t>(insn->imm()));
71 return true;
72 }
73 return false;
74 }
75
// Builds the reg-imm twin of a reg-reg (or mem-reg store) instruction, with
// the second source operand replaced by the 32-bit immediate `imm32`.
// The caller has already verified imm32 is a legal encoding for the insn.
// Fatals on opcodes with no imm form handled here.
MachineInsn* InsnFolding::NewImmInsnFromRegInsn(const MachineInsn* insn, int32_t imm32) {
  MachineInsn* folded_insn;
  // For the ALU forms, operand 0 is dst/src1 and operand 2 is the flags
  // register def, carried over unchanged; operand 1 (src2) becomes imm32.
  switch (insn->opcode()) {
    case kMachineOpAddqRegReg:
      folded_insn = machine_ir_->NewInsn<AddqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpSubqRegReg:
      folded_insn = machine_ir_->NewInsn<SubqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpCmpqRegReg:
      folded_insn = machine_ir_->NewInsn<CmpqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpOrqRegReg:
      folded_insn = machine_ir_->NewInsn<OrqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpXorqRegReg:
      folded_insn = machine_ir_->NewInsn<XorqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpAndqRegReg:
      folded_insn = machine_ir_->NewInsn<AndqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpTestqRegReg:
      folded_insn = machine_ir_->NewInsn<TestqRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpMovlRegReg:
      // Plain move: no flags operand to propagate.
      folded_insn = machine_ir_->NewInsn<MovlRegImm>(insn->RegAt(0), imm32);
      break;
    case kMachineOpAddlRegReg:
      folded_insn = machine_ir_->NewInsn<AddlRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpSublRegReg:
      folded_insn = machine_ir_->NewInsn<SublRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpCmplRegReg:
      folded_insn = machine_ir_->NewInsn<CmplRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpOrlRegReg:
      folded_insn = machine_ir_->NewInsn<OrlRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpXorlRegReg:
      folded_insn = machine_ir_->NewInsn<XorlRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpAndlRegReg:
      folded_insn = machine_ir_->NewInsn<AndlRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpTestlRegReg:
      folded_insn = machine_ir_->NewInsn<TestlRegImm>(insn->RegAt(0), imm32, insn->RegAt(2));
      break;
    case kMachineOpMovlMemBaseDispReg:
      // Store: operand 0 is the base register; disp is preserved, stored
      // value becomes the immediate.
      folded_insn = machine_ir_->NewInsn<MovlMemBaseDispImm>(
          insn->RegAt(0), AsMachineInsnX86_64(insn)->disp(), imm32);
      break;
    case kMachineOpMovqMemBaseDispReg:
      folded_insn = machine_ir_->NewInsn<MovqMemBaseDispImm>(
          insn->RegAt(0), AsMachineInsnX86_64(insn)->disp(), imm32);
      break;
    default:
      LOG_ALWAYS_FATAL("unexpected opcode");
  }
  // Inherit the additional attributes.
  folded_insn->set_recovery_bb(insn->recovery_bb());
  folded_insn->set_recovery_pc(insn->recovery_pc());
  return folded_insn;
}
140
// Returns true when `write_flags_insn` (a PseudoWriteFlags) would store the
// value the flags register already holds, i.e. its source traces back through
// PseudoCopy chains to a PseudoReadFlags of the same flags register, and the
// flags have not been redefined since that read. Such a write is removable.
bool InsnFolding::IsWritingSameFlagsValue(const MachineInsn* write_flags_insn) const {
  CHECK(write_flags_insn && write_flags_insn->opcode() == kMachineOpPseudoWriteFlags);
  MachineReg src_reg = write_flags_insn->RegAt(0);
  auto [def_insn, def_insn_pos] = def_map_.Get(src_reg);
  // Warning: We are assuming that all flags writes in IR happen to the same virtual register.
  //
  // Walk backwards through PseudoCopy defs until we hit the origin of the
  // value being written (or run out of known defs).
  while (true) {
    if (!def_insn) {
      return false;
    }

    int opcode = def_insn->opcode();
    if (opcode == kMachineOpPseudoCopy) {
      // Follow the copy one step back; the position bound keeps the lookup
      // to defs that reach this copy, not later ones.
      src_reg = def_insn->RegAt(1);
      std::tie(def_insn, def_insn_pos) = def_map_.Get(src_reg, def_insn_pos);
      continue;
    } else if (opcode == kMachineOpPseudoReadFlags) {
      break;
    }
    // Value comes from something other than a flags read: not removable.
    return false;
  }

  // Instruction is PseudoReadFlags.
  // The read must have captured the same flags register this write targets.
  if (write_flags_insn->RegAt(1) != def_insn->RegAt(1)) {
    return false;
  }
  // NOTE(review): presumably Get(reg, pos) returns a non-null def only when
  // the current def of `reg` was recorded before `pos`; a non-null result
  // then means flags were not redefined after the PseudoReadFlags, so the
  // write restores the value flags already hold — confirm against DefMap::Get.
  auto [flag_def_insn, _] = def_map_.Get(write_flags_insn->RegAt(1), def_insn_pos);
  return flag_def_insn != nullptr;
}
169
170 template <bool is_input_64bit>
TryFoldImmediateInput(const MachineInsn * insn)171 std::tuple<bool, MachineInsn*> InsnFolding::TryFoldImmediateInput(const MachineInsn* insn) {
172 auto src = insn->RegAt(1);
173 uint64_t imm64;
174 if (!IsRegImm(src, &imm64)) {
175 return {false, nullptr};
176 }
177
178 // MovqRegReg is the only instruction that can encode full 64-bit immediate.
179 if (insn->opcode() == kMachineOpMovqRegReg) {
180 return {true, machine_ir_->NewInsn<MovqRegImm>(insn->RegAt(0), imm64)};
181 }
182
183 int64_t signed_imm = bit_cast<int64_t>(imm64);
184 int32_t signed_imm32 = static_cast<int32_t>(signed_imm);
185 if (!is_input_64bit) {
186 // Use the lower half of the register as the immediate operand.
187 return {true, NewImmInsnFromRegInsn(insn, signed_imm32)};
188 }
189
190 // Except for MOVQ x86 doesn't allow to encode 64-bit immediates. That said,
191 // we can encode 32-bit immediates that are sign-extended by hardware to
192 // 64-bit during instruction execution.
193 if (signed_imm == static_cast<int64_t>(signed_imm32)) {
194 return {true, NewImmInsnFromRegInsn(insn, signed_imm32)};
195 }
196
197 return {false, nullptr};
198 }
199
TryFoldRedundantMovl(const MachineInsn * insn)200 std::tuple<bool, MachineInsn*> InsnFolding::TryFoldRedundantMovl(const MachineInsn* insn) {
201 CHECK_EQ(insn->opcode(), kMachineOpMovlRegReg);
202 auto src = insn->RegAt(1);
203 auto [def_insn, _] = def_map_.Get(src);
204
205 if (!def_insn) {
206 return {false, nullptr};
207 }
208
209 // If the definition of src clears its upper half, then we can replace MOVL with PseudoCopy.
210 switch (def_insn->opcode()) {
211 case kMachineOpMovlRegReg:
212 case kMachineOpAndlRegReg:
213 case kMachineOpXorlRegReg:
214 case kMachineOpOrlRegReg:
215 case kMachineOpSublRegReg:
216 case kMachineOpAddlRegReg:
217 return {true, machine_ir_->NewInsn<PseudoCopy>(insn->RegAt(0), src, 4)};
218 default:
219 return {false, nullptr};
220 }
221 }
222
// Dispatches `insn` to the matching folding routine.
// Returns {true, replacement} when the insn should be replaced,
// {true, nullptr} when it should be deleted outright, and
// {false, nullptr} when nothing applies.
std::tuple<bool, MachineInsn*> InsnFolding::TryFoldInsn(const MachineInsn* insn) {
  switch (insn->opcode()) {
    // 64-bit consumers: the immediate must survive sign-extension to 64 bits
    // (checked inside TryFoldImmediateInput<true>).
    case kMachineOpMovqMemBaseDispReg:
    case kMachineOpMovqRegReg:
    case kMachineOpAndqRegReg:
    case kMachineOpTestqRegReg:
    case kMachineOpXorqRegReg:
    case kMachineOpOrqRegReg:
    case kMachineOpSubqRegReg:
    case kMachineOpCmpqRegReg:
    case kMachineOpAddqRegReg:
      return TryFoldImmediateInput<true>(insn);
    case kMachineOpMovlRegReg: {
      // Prefer the immediate fold; fall back to eliding a redundant
      // zero-extension if the source is not a known constant.
      auto [is_folded, folded_insn] = TryFoldImmediateInput<false>(insn);
      if (is_folded) {
        return {is_folded, folded_insn};
      }

      return TryFoldRedundantMovl(insn);
    }
    // 32-bit consumers: any 32-bit immediate is encodable.
    case kMachineOpMovlMemBaseDispReg:
    case kMachineOpAndlRegReg:
    case kMachineOpTestlRegReg:
    case kMachineOpXorlRegReg:
    case kMachineOpOrlRegReg:
    case kMachineOpSublRegReg:
    case kMachineOpCmplRegReg:
    case kMachineOpAddlRegReg:
      return TryFoldImmediateInput<false>(insn);
    case kMachineOpPseudoWriteFlags: {
      // A write that restores the value flags already hold is a pure no-op:
      // delete it with no replacement.
      if (IsWritingSameFlagsValue(insn)) {
        return {true, nullptr};
      }
      break;
    }
    default:
      return {false, nullptr};
  }
  return {false, nullptr};
}
263
// Single forward pass over each basic block, replacing instructions with
// cheaper equivalents (immediate operands, elided zero-extensions, deleted
// no-op flag writes). Definition tracking is local to a basic block.
void FoldInsns(MachineIR* machine_ir) {
  DefMap def_map(machine_ir->NumVReg(), machine_ir->arena());
  for (auto* bb : machine_ir->bb_list()) {
    // Reset per-block state: defs from predecessor blocks are not tracked.
    def_map.Initialize();
    InsnFolding insn_folding(def_map, machine_ir);
    MachineInsnList& insn_list = bb->insn_list();

    for (auto insn_it = insn_list.begin(); insn_it != insn_list.end();) {
      auto [is_folded, new_insn] = insn_folding.TryFoldInsn(*insn_it);

      if (is_folded) {
        // erase() leaves insn_it on the next instruction; inserting the
        // replacement before it keeps the loop from revisiting new_insn.
        insn_it = insn_list.erase(insn_it);
        if (new_insn) {
          insn_list.insert(insn_it, new_insn);
          // Record the replacement's defs so later insns can fold against it.
          def_map.ProcessInsn(new_insn);
        }
      } else {
        def_map.ProcessInsn(*insn_it);
        ++insn_it;
      }
    }
  }
}
287
288 // TODO(b/179708579): Maybe combine with FoldInsns.
// TODO(b/179708579): Maybe combine with FoldInsns.
//
// Rewrites the tail pattern
//   PseudoWriteFlags flags_src, flags
//   PseudoCondBranch cond(flags)
// into
//   TestwRegImm flags_src, mask, flags
//   PseudoCondBranch kZero/kNotZero(flags)
// which tests the relevant flag bit directly in the source register instead
// of materializing the whole flags value. Only applies when the branch
// condition maps to a single flag bit and flags are dead after the block.
void FoldWriteFlags(MachineIR* machine_ir) {
  for (auto* bb : machine_ir->bb_list()) {
    CHECK(!bb->insn_list().empty());
    // Pattern-match the last two instructions of the block.
    auto insn_it = std::prev(bb->insn_list().end());
    if ((*insn_it)->opcode() != kMachineOpPseudoCondBranch) {
      continue;
    }

    auto* branch = static_cast<PseudoCondBranch*>(*insn_it);
    const auto* write_flags = *(--insn_it);
    if (write_flags->opcode() != kMachineOpPseudoWriteFlags) {
      continue;
    }
    // There is only one flags register, so CondBranch must read flags from WriteFlags.
    MachineReg flags = write_flags->RegAt(1);
    CHECK_EQ(flags.reg(), branch->RegAt(0).reg());

    const auto& live_out = bb->live_out();
    if (Contains(live_out, flags)) {
      // Flags are living-out. Cannot remove.
      // TODO(b/179708579): This shouldn't happen. Consider conversion to an assert.
      continue;
    }

    using Cond = CodeEmitter::Condition;
    Cond new_cond = Cond::kInvalidCondition;
    PseudoWriteFlags::Flags flags_mask;

    // Map the original condition to (flag bit to test, condition on the TEST
    // result). TEST yields nonzero iff the bit is set, so a "flag set"
    // condition becomes kNotZero and a "flag clear" condition becomes kZero.
    switch (branch->cond()) {
      // Verify that the flags are within the bottom 16 bits, so we can use Testw.
      static_assert(sizeof(PseudoWriteFlags::Flags) == 2);
      case Cond::kZero:
        new_cond = Cond::kNotZero;
        flags_mask = PseudoWriteFlags::Flags::kZero;
        break;
      case Cond::kNotZero:
        new_cond = Cond::kZero;
        flags_mask = PseudoWriteFlags::Flags::kZero;
        break;
      case Cond::kCarry:
        new_cond = Cond::kNotZero;
        flags_mask = PseudoWriteFlags::Flags::kCarry;
        break;
      case Cond::kNotCarry:
        new_cond = Cond::kZero;
        flags_mask = PseudoWriteFlags::Flags::kCarry;
        break;
      case Cond::kNegative:
        new_cond = Cond::kNotZero;
        flags_mask = PseudoWriteFlags::Flags::kNegative;
        break;
      case Cond::kNotSign:
        new_cond = Cond::kZero;
        flags_mask = PseudoWriteFlags::Flags::kNegative;
        break;
      case Cond::kOverflow:
        new_cond = Cond::kNotZero;
        flags_mask = PseudoWriteFlags::Flags::kOverflow;
        break;
      case Cond::kNoOverflow:
        new_cond = Cond::kZero;
        flags_mask = PseudoWriteFlags::Flags::kOverflow;
        break;
      default:
        // Compound conditions (e.g. signed comparisons) involve multiple
        // flag bits and are left untouched.
        continue;
    }

    // Swap WriteFlags for the TEST: insn_it currently points at write_flags;
    // erase() advances it to the branch, and the TEST is inserted before it.
    MachineReg flags_src = write_flags->RegAt(0);
    MachineInsn* new_write_flags =
        machine_ir->NewInsn<x86_64::TestwRegImm>(flags_src, flags_mask, flags);
    insn_it = bb->insn_list().erase(insn_it);
    bb->insn_list().insert(insn_it, new_write_flags);
    branch->set_cond(new_cond);
  }
}
364
365 } // namespace berberis::x86_64
366