1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "calling_convention_riscv64.h"
18 
19 #include <android-base/logging.h>
20 
21 #include "arch/instruction_set.h"
22 #include "arch/riscv64/jni_frame_riscv64.h"
23 #include "utils/riscv64/managed_register_riscv64.h"
24 
25 namespace art HIDDEN {
26 namespace riscv64 {
27 
// Integer/reference ("int-like") argument registers A0-A7, in assignment order.
// The managed convention reserves index 0 (A0) for the method and assigns non-FP
// args starting at index 1. The JNI convention assigns non-FP args starting at
// index 0 and, once FA0-FA7 are exhausted, FP args as well. The order of this
// table is therefore significant: it is indexed directly by CurrentParamRegister().
static constexpr ManagedRegister kXArgumentRegisters[] = {
    Riscv64ManagedRegister::FromXRegister(A0),
    Riscv64ManagedRegister::FromXRegister(A1),
    Riscv64ManagedRegister::FromXRegister(A2),
    Riscv64ManagedRegister::FromXRegister(A3),
    Riscv64ManagedRegister::FromXRegister(A4),
    Riscv64ManagedRegister::FromXRegister(A5),
    Riscv64ManagedRegister::FromXRegister(A6),
    Riscv64ManagedRegister::FromXRegister(A7),
};
// The table must have exactly kMaxIntLikeArgumentRegisters entries.
static_assert(kMaxIntLikeArgumentRegisters == arraysize(kXArgumentRegisters));
39 
40 static const FRegister kFArgumentRegisters[] = {
41   FA0, FA1, FA2, FA3, FA4, FA5, FA6, FA7
42 };
43 static_assert(kMaxFloatOrDoubleArgumentRegisters == arraysize(kFArgumentRegisters));
44 
// Managed callee-save registers, returned by CalleeSaveRegisters() for
// non-@CriticalNative JNI stubs. The layout is relied upon elsewhere in this file:
// - all core registers precede the FP registers, and RA is the last core entry
//   (its index is derived from the popcount of the core spill mask);
// - S3-S11 occupy indices 2..10 (handed out by CalleeSaveScratchRegisters()).
static constexpr ManagedRegister kCalleeSaveRegisters[] = {
    // Core registers.
    Riscv64ManagedRegister::FromXRegister(S0),
    // ART thread register (TR = S1) is not saved on the stack.
    Riscv64ManagedRegister::FromXRegister(S2),
    Riscv64ManagedRegister::FromXRegister(S3),
    Riscv64ManagedRegister::FromXRegister(S4),
    Riscv64ManagedRegister::FromXRegister(S5),
    Riscv64ManagedRegister::FromXRegister(S6),
    Riscv64ManagedRegister::FromXRegister(S7),
    Riscv64ManagedRegister::FromXRegister(S8),
    Riscv64ManagedRegister::FromXRegister(S9),
    Riscv64ManagedRegister::FromXRegister(S10),
    Riscv64ManagedRegister::FromXRegister(S11),
    Riscv64ManagedRegister::FromXRegister(RA),

    // Hard float registers.
    Riscv64ManagedRegister::FromFRegister(FS0),
    Riscv64ManagedRegister::FromFRegister(FS1),
    Riscv64ManagedRegister::FromFRegister(FS2),
    Riscv64ManagedRegister::FromFRegister(FS3),
    Riscv64ManagedRegister::FromFRegister(FS4),
    Riscv64ManagedRegister::FromFRegister(FS5),
    Riscv64ManagedRegister::FromFRegister(FS6),
    Riscv64ManagedRegister::FromFRegister(FS7),
    Riscv64ManagedRegister::FromFRegister(FS8),
    Riscv64ManagedRegister::FromFRegister(FS9),
    Riscv64ManagedRegister::FromFRegister(FS10),
    Riscv64ManagedRegister::FromFRegister(FS11),
};
75 
76 template <size_t size>
CalculateCoreCalleeSpillMask(const ManagedRegister (& callee_saves)[size])77 static constexpr uint32_t CalculateCoreCalleeSpillMask(
78     const ManagedRegister (&callee_saves)[size]) {
79   uint32_t result = 0u;
80   for (auto&& r : callee_saves) {
81     if (r.AsRiscv64().IsXRegister()) {
82       result |= (1u << r.AsRiscv64().AsXRegister());
83     }
84   }
85   return result;
86 }
87 
88 template <size_t size>
CalculateFpCalleeSpillMask(const ManagedRegister (& callee_saves)[size])89 static constexpr uint32_t CalculateFpCalleeSpillMask(const ManagedRegister (&callee_saves)[size]) {
90   uint32_t result = 0u;
91   for (auto&& r : callee_saves) {
92     if (r.AsRiscv64().IsFRegister()) {
93       result |= (1u << r.AsRiscv64().AsFRegister());
94     }
95   }
96   return result;
97 }
98 
// Bit masks of the managed callee-save registers (bit N set for xN / fN respectively).
static constexpr uint32_t kCoreCalleeSpillMask = CalculateCoreCalleeSpillMask(kCalleeSaveRegisters);
static constexpr uint32_t kFpCalleeSpillMask = CalculateFpCalleeSpillMask(kCalleeSaveRegisters);

// Callee-save registers of the native calling convention. Unlike the managed list, this
// includes S1. Only the derived masks below are used, to static_assert that every managed
// callee-save register is also preserved by native code.
static constexpr ManagedRegister kNativeCalleeSaveRegisters[] = {
    // Core registers.
    Riscv64ManagedRegister::FromXRegister(S0),
    Riscv64ManagedRegister::FromXRegister(S1),
    Riscv64ManagedRegister::FromXRegister(S2),
    Riscv64ManagedRegister::FromXRegister(S3),
    Riscv64ManagedRegister::FromXRegister(S4),
    Riscv64ManagedRegister::FromXRegister(S5),
    Riscv64ManagedRegister::FromXRegister(S6),
    Riscv64ManagedRegister::FromXRegister(S7),
    Riscv64ManagedRegister::FromXRegister(S8),
    Riscv64ManagedRegister::FromXRegister(S9),
    Riscv64ManagedRegister::FromXRegister(S10),
    Riscv64ManagedRegister::FromXRegister(S11),
    Riscv64ManagedRegister::FromXRegister(RA),

    // Hard float registers.
    Riscv64ManagedRegister::FromFRegister(FS0),
    Riscv64ManagedRegister::FromFRegister(FS1),
    Riscv64ManagedRegister::FromFRegister(FS2),
    Riscv64ManagedRegister::FromFRegister(FS3),
    Riscv64ManagedRegister::FromFRegister(FS4),
    Riscv64ManagedRegister::FromFRegister(FS5),
    Riscv64ManagedRegister::FromFRegister(FS6),
    Riscv64ManagedRegister::FromFRegister(FS7),
    Riscv64ManagedRegister::FromFRegister(FS8),
    Riscv64ManagedRegister::FromFRegister(FS9),
    Riscv64ManagedRegister::FromFRegister(FS10),
    Riscv64ManagedRegister::FromFRegister(FS11),
};

// Bit masks of the native callee-save registers.
static constexpr uint32_t kNativeCoreCalleeSpillMask =
    CalculateCoreCalleeSpillMask(kNativeCalleeSaveRegisters);
static constexpr uint32_t kNativeFpCalleeSpillMask =
    CalculateFpCalleeSpillMask(kNativeCalleeSaveRegisters);
137 
ReturnRegisterForShorty(std::string_view shorty)138 static ManagedRegister ReturnRegisterForShorty(std::string_view shorty) {
139   if (shorty[0] == 'F' || shorty[0] == 'D') {
140     return Riscv64ManagedRegister::FromFRegister(FA0);
141   } else if (shorty[0] == 'V') {
142     return Riscv64ManagedRegister::NoRegister();
143   } else {
144     // All other return types use A0. Note that there is no managed type wide enough to use A1/FA1.
145     return Riscv64ManagedRegister::FromXRegister(A0);
146   }
147 }
148 
149 // Managed runtime calling convention
150 
ReturnRegister() const151 ManagedRegister Riscv64ManagedRuntimeCallingConvention::ReturnRegister() const {
152   return ReturnRegisterForShorty(GetShorty());
153 }
154 
MethodRegister()155 ManagedRegister Riscv64ManagedRuntimeCallingConvention::MethodRegister() {
156   return Riscv64ManagedRegister::FromXRegister(A0);
157 }
158 
ArgumentRegisterForMethodExitHook()159 ManagedRegister Riscv64ManagedRuntimeCallingConvention::ArgumentRegisterForMethodExitHook() {
160   DCHECK(!Riscv64ManagedRegister::FromXRegister(A4).Overlaps(ReturnRegister().AsRiscv64()));
161   return Riscv64ManagedRegister::FromXRegister(A4);
162 }
163 
IsCurrentParamInRegister()164 bool Riscv64ManagedRuntimeCallingConvention::IsCurrentParamInRegister() {
165   // Note: The managed ABI does not pass FP args in general purpose registers.
166   // This differs from the native ABI which does that after using all FP arg registers.
167   if (IsCurrentParamAFloatOrDouble()) {
168     return itr_float_and_doubles_ < kMaxFloatOrDoubleArgumentRegisters;
169   } else {
170     size_t non_fp_arg_number = itr_args_ - itr_float_and_doubles_;
171     return /* method */ 1u + non_fp_arg_number < kMaxIntLikeArgumentRegisters;
172   }
173 }
174 
IsCurrentParamOnStack()175 bool Riscv64ManagedRuntimeCallingConvention::IsCurrentParamOnStack() {
176   return !IsCurrentParamInRegister();
177 }
178 
CurrentParamRegister()179 ManagedRegister Riscv64ManagedRuntimeCallingConvention::CurrentParamRegister() {
180   DCHECK(IsCurrentParamInRegister());
181   if (IsCurrentParamAFloatOrDouble()) {
182     return Riscv64ManagedRegister::FromFRegister(kFArgumentRegisters[itr_float_and_doubles_]);
183   } else {
184     size_t non_fp_arg_number = itr_args_ - itr_float_and_doubles_;
185     return kXArgumentRegisters[/* method */ 1u + non_fp_arg_number];
186   }
187 }
188 
CurrentParamStackOffset()189 FrameOffset Riscv64ManagedRuntimeCallingConvention::CurrentParamStackOffset() {
190   return FrameOffset(displacement_.Int32Value() +  // displacement
191                      kFramePointerSize +  // Method ref
192                      (itr_slots_ * sizeof(uint32_t)));  // offset into in args
193 }
194 
195 // JNI calling convention
196 
Riscv64JniCallingConvention(bool is_static,bool is_synchronized,bool is_fast_native,bool is_critical_native,std::string_view shorty)197 Riscv64JniCallingConvention::Riscv64JniCallingConvention(bool is_static,
198                                                          bool is_synchronized,
199                                                          bool is_fast_native,
200                                                          bool is_critical_native,
201                                                          std::string_view shorty)
202     : JniCallingConvention(is_static,
203                            is_synchronized,
204                            is_fast_native,
205                            is_critical_native,
206                            shorty,
207                            kRiscv64PointerSize) {
208 }
209 
ReturnRegister() const210 ManagedRegister Riscv64JniCallingConvention::ReturnRegister() const {
211   return ReturnRegisterForShorty(GetShorty());
212 }
213 
IntReturnRegister() const214 ManagedRegister Riscv64JniCallingConvention::IntReturnRegister() const {
215   return Riscv64ManagedRegister::FromXRegister(A0);
216 }
217 
FrameSize() const218 size_t Riscv64JniCallingConvention::FrameSize() const {
219   if (is_critical_native_) {
220     CHECK(!SpillsMethod());
221     CHECK(!HasLocalReferenceSegmentState());
222     return 0u;  // There is no managed frame for @CriticalNative.
223   }
224 
225   // Method*, callee save area size, local reference segment state
226   DCHECK(SpillsMethod());
227   size_t method_ptr_size = static_cast<size_t>(kFramePointerSize);
228   size_t callee_save_area_size = CalleeSaveRegisters().size() * kFramePointerSize;
229   size_t total_size = method_ptr_size + callee_save_area_size;
230 
231   DCHECK(HasLocalReferenceSegmentState());
232   // Cookie is saved in one of the spilled registers.
233 
234   return RoundUp(total_size, kStackAlignment);
235 }
236 
OutFrameSize() const237 size_t Riscv64JniCallingConvention::OutFrameSize() const {
238   // Count param args, including JNIEnv* and jclass*.
239   size_t all_args = NumberOfExtraArgumentsForJni() + NumArgs();
240   size_t num_fp_args = NumFloatOrDoubleArgs();
241   DCHECK_GE(all_args, num_fp_args);
242   size_t num_non_fp_args = all_args - num_fp_args;
243   // The size of outgoing arguments.
244   size_t size = GetNativeOutArgsSize(num_fp_args, num_non_fp_args);
245 
246   // @CriticalNative can use tail call as all managed callee saves are preserved by AAPCS64.
247   static_assert((kCoreCalleeSpillMask & ~kNativeCoreCalleeSpillMask) == 0u);
248   static_assert((kFpCalleeSpillMask & ~kNativeFpCalleeSpillMask) == 0u);
249 
250   // For @CriticalNative, we can make a tail call if there are no stack args.
251   // Otherwise, add space for return PC.
252   // Note: Result does not neeed to be zero- or sign-extended.
253   DCHECK(!RequiresSmallResultTypeExtension());
254   if (is_critical_native_ && size != 0u) {
255     size += kFramePointerSize;  // We need to spill RA with the args.
256   }
257   size_t out_args_size = RoundUp(size, kNativeStackAlignment);
258   if (UNLIKELY(IsCriticalNative())) {
259     DCHECK_EQ(out_args_size, GetCriticalNativeStubFrameSize(GetShorty()));
260   }
261   return out_args_size;
262 }
263 
CalleeSaveRegisters() const264 ArrayRef<const ManagedRegister> Riscv64JniCallingConvention::CalleeSaveRegisters() const {
265   if (UNLIKELY(IsCriticalNative())) {
266     if (UseTailCall()) {
267       return ArrayRef<const ManagedRegister>();  // Do not spill anything.
268     } else {
269       // Spill RA with out args.
270       static_assert((kCoreCalleeSpillMask & (1 << RA)) != 0u);  // Contains RA.
271       constexpr size_t ra_index = POPCOUNT(kCoreCalleeSpillMask) - 1u;
272       static_assert(kCalleeSaveRegisters[ra_index].Equals(
273                         Riscv64ManagedRegister::FromXRegister(RA)));
274       return ArrayRef<const ManagedRegister>(kCalleeSaveRegisters).SubArray(
275           /*pos=*/ ra_index, /*length=*/ 1u);
276     }
277   } else {
278     return ArrayRef<const ManagedRegister>(kCalleeSaveRegisters);
279   }
280 }
281 
CalleeSaveScratchRegisters() const282 ArrayRef<const ManagedRegister> Riscv64JniCallingConvention::CalleeSaveScratchRegisters() const {
283   DCHECK(!IsCriticalNative());
284   // Use S3-S11 from managed callee saves. All these registers are also native callee saves.
285   constexpr size_t kStart = 2u;
286   constexpr size_t kLength = 9u;
287   static_assert(kCalleeSaveRegisters[kStart].Equals(Riscv64ManagedRegister::FromXRegister(S3)));
288   static_assert(kCalleeSaveRegisters[kStart + kLength - 1u].Equals(
289                     Riscv64ManagedRegister::FromXRegister(S11)));
290   static_assert((kCoreCalleeSpillMask & ~kNativeCoreCalleeSpillMask) == 0u);
291   return ArrayRef<const ManagedRegister>(kCalleeSaveRegisters).SubArray(kStart, kLength);
292 }
293 
ArgumentScratchRegisters() const294 ArrayRef<const ManagedRegister> Riscv64JniCallingConvention::ArgumentScratchRegisters() const {
295   DCHECK(!IsCriticalNative());
296   ArrayRef<const ManagedRegister> scratch_regs(kXArgumentRegisters);
297   // Exclude return register (A0) even if unused. Using the same scratch registers helps
298   // making more JNI stubs identical for better reuse, such as deduplicating them in oat files.
299   static_assert(kXArgumentRegisters[0].Equals(Riscv64ManagedRegister::FromXRegister(A0)));
300   scratch_regs = scratch_regs.SubArray(/*pos=*/ 1u);
301   DCHECK(std::none_of(scratch_regs.begin(),
302                       scratch_regs.end(),
303                       [return_reg = ReturnRegister().AsRiscv64()](ManagedRegister reg) {
304                         return return_reg.Overlaps(reg.AsRiscv64());
305                       }));
306   return scratch_regs;
307 }
308 
CoreSpillMask() const309 uint32_t Riscv64JniCallingConvention::CoreSpillMask() const {
310   return is_critical_native_ ? 0u : kCoreCalleeSpillMask;
311 }
312 
FpSpillMask() const313 uint32_t Riscv64JniCallingConvention::FpSpillMask() const {
314   return is_critical_native_ ? 0u : kFpCalleeSpillMask;
315 }
316 
CurrentParamSize() const317 size_t Riscv64JniCallingConvention::CurrentParamSize() const {
318   if (IsCurrentArgExtraForJni()) {
319     return static_cast<size_t>(frame_pointer_size_);  // JNIEnv or jobject/jclass
320   } else {
321     size_t arg_pos = GetIteratorPositionWithinShorty();
322     DCHECK_LT(arg_pos, NumArgs());
323     if (IsStatic()) {
324       ++arg_pos;  // 0th argument must skip return value at start of the shorty
325     } else if (arg_pos == 0) {
326       return static_cast<size_t>(kRiscv64PointerSize);  // this argument
327     }
328     // The riscv64 native calling convention specifies that integers narrower than XLEN (64)
329     // bits are "widened according to the sign of their type up to 32 bits, then sign-extended
330     // to XLEN bits." Thus, everything other than `float` (which has the high 32 bits undefined)
331     // is passed as 64 bits, whether in register, or on the stack.
332     return (GetShorty()[arg_pos] == 'F') ? 4u : static_cast<size_t>(kRiscv64PointerSize);
333   }
334 }
335 
IsCurrentParamInRegister()336 bool Riscv64JniCallingConvention::IsCurrentParamInRegister() {
337   // FP args use FPRs, then GPRs and only then the stack.
338   if (itr_float_and_doubles_ < kMaxFloatOrDoubleArgumentRegisters) {
339     if (IsCurrentParamAFloatOrDouble()) {
340       return true;
341     } else {
342       size_t num_non_fp_args = itr_args_ - itr_float_and_doubles_;
343       return num_non_fp_args < kMaxIntLikeArgumentRegisters;
344     }
345   } else {
346     return (itr_args_ < kMaxFloatOrDoubleArgumentRegisters + kMaxIntLikeArgumentRegisters);
347   }
348 }
349 
IsCurrentParamOnStack()350 bool Riscv64JniCallingConvention::IsCurrentParamOnStack() {
351   return !IsCurrentParamInRegister();
352 }
353 
CurrentParamRegister()354 ManagedRegister Riscv64JniCallingConvention::CurrentParamRegister() {
355   // FP args use FPRs, then GPRs and only then the stack.
356   CHECK(IsCurrentParamInRegister());
357   if (itr_float_and_doubles_ < kMaxFloatOrDoubleArgumentRegisters) {
358     if (IsCurrentParamAFloatOrDouble()) {
359       return Riscv64ManagedRegister::FromFRegister(kFArgumentRegisters[itr_float_and_doubles_]);
360     } else {
361       size_t num_non_fp_args = itr_args_ - itr_float_and_doubles_;
362       DCHECK_LT(num_non_fp_args, kMaxIntLikeArgumentRegisters);
363       return kXArgumentRegisters[num_non_fp_args];
364     }
365   } else {
366     // This argument is in a GPR, whether it's a FP arg or a non-FP arg.
367     DCHECK_LT(itr_args_, kMaxFloatOrDoubleArgumentRegisters + kMaxIntLikeArgumentRegisters);
368     return kXArgumentRegisters[itr_args_ - kMaxFloatOrDoubleArgumentRegisters];
369   }
370 }
371 
CurrentParamStackOffset()372 FrameOffset Riscv64JniCallingConvention::CurrentParamStackOffset() {
373   CHECK(IsCurrentParamOnStack());
374   // Account for FP arguments passed through FA0-FA7.
375   // All other args are passed through A0-A7 (even FP args) and the stack.
376   size_t num_gpr_and_stack_args =
377       itr_args_ - std::min<size_t>(kMaxFloatOrDoubleArgumentRegisters, itr_float_and_doubles_);
378   size_t args_on_stack =
379       num_gpr_and_stack_args - std::min(kMaxIntLikeArgumentRegisters, num_gpr_and_stack_args);
380   size_t offset = displacement_.Int32Value() - OutFrameSize() + (args_on_stack * kFramePointerSize);
381   CHECK_LT(offset, OutFrameSize());
382   return FrameOffset(offset);
383 }
384 
RequiresSmallResultTypeExtension() const385 bool Riscv64JniCallingConvention::RequiresSmallResultTypeExtension() const {
386   // RISC-V native calling convention requires values to be returned the way that the first
387   // argument would be passed. Arguments are zero-/sign-extended to 32 bits based on their
388   // type, then sign-extended to 64 bits. This is the same as in the ART mamaged ABI.
389   // (Not applicable to FP args which are returned in `FA0`. A `float` is NaN-boxed.)
390   return false;
391 }
392 
393 // T0 is neither managed callee-save, nor argument register. It is suitable for use as the
394 // locking argument for synchronized methods and hidden argument for @CriticalNative methods.
AssertT0IsNeitherCalleeSaveNorArgumentRegister()395 static void AssertT0IsNeitherCalleeSaveNorArgumentRegister() {
396   // TODO: Change to static_assert; std::none_of should be constexpr since C++20.
397   DCHECK(std::none_of(kCalleeSaveRegisters,
398                       kCalleeSaveRegisters + std::size(kCalleeSaveRegisters),
399                       [](ManagedRegister callee_save) constexpr {
400                         return callee_save.Equals(Riscv64ManagedRegister::FromXRegister(T0));
401                       }));
402   DCHECK(std::none_of(kXArgumentRegisters,
403                       kXArgumentRegisters + std::size(kXArgumentRegisters),
404                       [](ManagedRegister arg) { return arg.AsRiscv64().AsXRegister() == T0; }));
405 }
406 
LockingArgumentRegister() const407 ManagedRegister Riscv64JniCallingConvention::LockingArgumentRegister() const {
408   DCHECK(!IsFastNative());
409   DCHECK(!IsCriticalNative());
410   DCHECK(IsSynchronized());
411   AssertT0IsNeitherCalleeSaveNorArgumentRegister();
412   return Riscv64ManagedRegister::FromXRegister(T0);
413 }
414 
HiddenArgumentRegister() const415 ManagedRegister Riscv64JniCallingConvention::HiddenArgumentRegister() const {
416   DCHECK(IsCriticalNative());
417   AssertT0IsNeitherCalleeSaveNorArgumentRegister();
418   return Riscv64ManagedRegister::FromXRegister(T0);
419 }
420 
421 // Whether to use tail call (used only for @CriticalNative).
UseTailCall() const422 bool Riscv64JniCallingConvention::UseTailCall() const {
423   CHECK(IsCriticalNative());
424   return OutFrameSize() == 0u;
425 }
426 
427 }  // namespace riscv64
428 }  // namespace art
429