1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <MtpDataPacket.h>
18 #include <MtpDevHandle.h>
19 #include <MtpPacketFuzzerUtils.h>
20 #include <functional>
21 #include <fuzzer/FuzzedDataProvider.h>
22 #include <utils/String16.h>
23 
24 using namespace android;
25 
26 class MtpDataPacketFuzzer : MtpPacketFuzzerUtils {
27   public:
MtpDataPacketFuzzer(const uint8_t * data,size_t size)28     MtpDataPacketFuzzer(const uint8_t* data, size_t size) : mFdp(data, size) {
29         mUsbDevFsUrb = (struct usbdevfs_urb*)malloc(sizeof(struct usbdevfs_urb) +
30                                                    sizeof(struct usbdevfs_iso_packet_desc));
31     };
~MtpDataPacketFuzzer()32     ~MtpDataPacketFuzzer() { free(mUsbDevFsUrb); };
33     void process();
34 
35   private:
36     FuzzedDataProvider mFdp;
37 };
38 
process()39 void MtpDataPacketFuzzer::process() {
40     MtpDataPacket mtpDataPacket;
41     while (mFdp.remaining_bytes() > 0) {
42         auto mtpDataAPI = mFdp.PickValueInArray<const std::function<void()>>({
43                 [&]() { mtpDataPacket.allocate(mFdp.ConsumeIntegralInRange(kMinSize, kMaxSize)); },
44                 [&]() { mtpDataPacket.reset(); },
45                 [&]() {
46                     mtpDataPacket.setOperationCode(mFdp.ConsumeIntegralInRange(kMinSize, kMaxSize));
47                 },
48                 [&]() {
49                     mtpDataPacket.setTransactionID(mFdp.ConsumeIntegralInRange(kMinSize, kMaxSize));
50                 },
51                 [&]() {
52                     Int8List* result = mtpDataPacket.getAInt8();
53                     delete result;
54                 },
55                 [&]() {
56                     Int16List* result = mtpDataPacket.getAInt16();
57                     delete result;
58                 },
59                 [&]() {
60                     Int32List* result = mtpDataPacket.getAInt32();
61                     delete result;
62                 },
63                 [&]() {
64                     Int64List* result = mtpDataPacket.getAInt64();
65                     delete result;
66                 },
67                 [&]() {
68                     UInt8List* result = mtpDataPacket.getAUInt8();
69                     delete result;
70                 },
71                 [&]() {
72                     UInt16List* result = mtpDataPacket.getAUInt16();
73                     delete result;
74                 },
75                 [&]() {
76                     UInt32List* result = mtpDataPacket.getAUInt32();
77                     delete result;
78                 },
79                 [&]() {
80                     UInt64List* result = mtpDataPacket.getAUInt64();
81                     delete result;
82                 },
83                 [&]() {
84                     if (mFdp.ConsumeBool()) {
85                         std::vector<uint8_t> initData =
86                                 mFdp.ConsumeBytes<uint8_t>(mFdp.ConsumeIntegral<uint8_t>());
87                         mtpDataPacket.putAUInt8(initData.data(), initData.size());
88                     } else {
89                         mtpDataPacket.putAUInt8(nullptr, 0);
90                     }
91                 },
92                 [&]() {
93                     if (mFdp.ConsumeBool()) {
94                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
95                         uint16_t arr[size];
96                         for (size_t idx = 0; idx < size; ++idx) {
97                             arr[idx] = mFdp.ConsumeIntegral<uint16_t>();
98                         }
99                         mtpDataPacket.putAUInt16(arr, size);
100                     } else {
101                         mtpDataPacket.putAUInt16(nullptr, 0);
102                     }
103                 },
104                 [&]() {
105                     if (mFdp.ConsumeBool()) {
106                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
107                         uint32_t arr[size];
108                         for (size_t idx = 0; idx < size; ++idx) {
109                             arr[idx] = mFdp.ConsumeIntegral<uint32_t>();
110                         }
111                         mtpDataPacket.putAUInt32(arr, size);
112                     } else {
113                         mtpDataPacket.putAUInt32(nullptr, 0);
114                     }
115                 },
116                 [&]() {
117                     if (mFdp.ConsumeBool()) {
118                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
119                         uint64_t arr[size];
120                         for (size_t idx = 0; idx < size; ++idx) {
121                             arr[idx] = mFdp.ConsumeIntegral<uint64_t>();
122                         }
123                         mtpDataPacket.putAUInt64(arr, size);
124                     } else {
125                         mtpDataPacket.putAUInt64(nullptr, 0);
126                     }
127                 },
128                 [&]() {
129                     if (mFdp.ConsumeBool()) {
130                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
131                         int64_t arr[size];
132                         for (size_t idx = 0; idx < size; ++idx) {
133                             arr[idx] = mFdp.ConsumeIntegral<int64_t>();
134                         }
135                         mtpDataPacket.putAInt64(arr, size);
136                     } else {
137                         mtpDataPacket.putAInt64(nullptr, 0);
138                     }
139                 },
140                 [&]() {
141                     if (mFdp.ConsumeBool()) {
142                         std::vector<uint16_t> arr;
143                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
144                         for (size_t idx = 0; idx < size; ++idx) {
145                             arr.push_back(mFdp.ConsumeIntegral<uint16_t>());
146                         }
147                         mtpDataPacket.putAUInt16(&arr);
148                     } else {
149                         mtpDataPacket.putAUInt16(nullptr);
150                     }
151                 },
152                 [&]() {
153                     if (mFdp.ConsumeBool()) {
154                         std::vector<uint32_t> arr;
155                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
156                         for (size_t idx = 0; idx < size; ++idx) {
157                             arr.push_back(mFdp.ConsumeIntegral<uint32_t>());
158                         }
159                         mtpDataPacket.putAUInt32(&arr);
160                     } else {
161                         mtpDataPacket.putAUInt32(nullptr);
162                     }
163                 },
164 
165                 [&]() {
166                     if (mFdp.ConsumeBool()) {
167                         size_t size = mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize);
168                         int32_t arr[size];
169                         for (size_t idx = 0; idx < size; ++idx) {
170                             arr[idx] = mFdp.ConsumeIntegral<int32_t>();
171                         }
172                         mtpDataPacket.putAInt32(arr, size);
173                     } else {
174                         mtpDataPacket.putAInt32(nullptr, 0);
175                     }
176                 },
177                 [&]() {
178                     if (mFdp.ConsumeBool()) {
179                         mtpDataPacket.putString(
180                                 (mFdp.ConsumeRandomLengthString(kMaxLength)).c_str());
181                     } else {
182                         mtpDataPacket.putString(static_cast<char*>(nullptr));
183                     }
184                 },
185                 [&]() {
186                     android::MtpStringBuffer sBuffer(
187                             (mFdp.ConsumeRandomLengthString(kMaxLength)).c_str());
188                     if (mFdp.ConsumeBool()) {
189                         mtpDataPacket.getString(sBuffer);
190                     } else {
191                         mtpDataPacket.putString(sBuffer);
192                     }
193                 },
194                 [&]() {
195                     MtpDevHandle handle;
196                     handle.start(mFdp.ConsumeBool());
197                     std::string text = mFdp.ConsumeRandomLengthString(kMaxLength);
198                     char* data = const_cast<char*>(text.c_str());
199                     handle.read(static_cast<void*>(data), text.length());
200                     if (mFdp.ConsumeBool()) {
201                         mtpDataPacket.read(&handle);
202                     } else if (mFdp.ConsumeBool()) {
203                         mtpDataPacket.write(&handle);
204                     } else {
205                         std::string textData = mFdp.ConsumeRandomLengthString(kMaxLength);
206                         char* Data = const_cast<char*>(textData.c_str());
207                         mtpDataPacket.writeData(&handle, static_cast<void*>(Data),
208                                                 textData.length());
209                     }
210                     handle.close();
211                 },
212                 [&]() {
213                     if (mFdp.ConsumeBool()) {
214                         std::string str = mFdp.ConsumeRandomLengthString(kMaxLength);
215                         android::String16 s(str.c_str());
216                         char16_t* data = const_cast<char16_t*>(s.c_str());
217                         mtpDataPacket.putString(reinterpret_cast<uint16_t*>(data));
218                     } else {
219                         mtpDataPacket.putString(static_cast<uint16_t*>(nullptr));
220                     }
221                 },
222                 [&]() {
223                     if (mFdp.ConsumeBool()) {
224                         std::vector<int8_t> data = mFdp.ConsumeBytes<int8_t>(
225                                 mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize));
226                         mtpDataPacket.putAInt8(data.data(), data.size());
227                     } else {
228                         mtpDataPacket.putAInt8(nullptr, 0);
229                     }
230                 },
231                 [&]() {
232                     if (mFdp.ConsumeBool()) {
233                         std::vector<uint8_t> data = mFdp.ConsumeBytes<uint8_t>(
234                                 mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize));
235                         mtpDataPacket.putAUInt8(data.data(), data.size());
236                     } else {
237                         mtpDataPacket.putAUInt8(nullptr, 0);
238                     }
239                 },
240                 [&]() {
241                     fillFilePath(&mFdp);
242                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
243                     fillUsbRequest(fd, &mFdp);
244                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
245                     std::vector<int8_t> data = mFdp.ConsumeBytes<int8_t>(
246                             mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize));
247                     mtpDataPacket.readData(&mUsbRequest, data.data(), data.size());
248                     usb_device_close(mUsbRequest.dev);
249                 },
250                 [&]() {
251                     fillFilePath(&mFdp);
252                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
253                     fillUsbRequest(fd, &mFdp);
254                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
255                     mtpDataPacket.write(
256                             &mUsbRequest,
257                             mFdp.PickValueInArray<UrbPacketDivisionMode>(kUrbPacketDivisionModes),
258                             fd, mFdp.ConsumeIntegralInRange(kMinSize, kMaxSize));
259                     usb_device_close(mUsbRequest.dev);
260                 },
261                 [&]() {
262                     fillFilePath(&mFdp);
263                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
264                     fillUsbRequest(fd, &mFdp);
265                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
266                     mtpDataPacket.read(&mUsbRequest);
267                     usb_device_close(mUsbRequest.dev);
268                 },
269                 [&]() {
270                     fillFilePath(&mFdp);
271                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
272                     fillUsbRequest(fd, &mFdp);
273                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
274                     mtpDataPacket.write(&mUsbRequest, mFdp.PickValueInArray<UrbPacketDivisionMode>(
275                                                              kUrbPacketDivisionModes));
276                     usb_device_close(mUsbRequest.dev);
277                 },
278                 [&]() {
279                     fillFilePath(&mFdp);
280                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
281                     fillUsbRequest(fd, &mFdp);
282                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
283                     mtpDataPacket.readDataHeader(&mUsbRequest);
284                     usb_device_close(mUsbRequest.dev);
285                 },
286                 [&]() {
287                     fillFilePath(&mFdp);
288                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
289                     fillUsbRequest(fd, &mFdp);
290                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
291                     mtpDataPacket.readDataAsync(&mUsbRequest);
292                     usb_device_close(mUsbRequest.dev);
293                 },
294                 [&]() {
295                     fillFilePath(&mFdp);
296                     int32_t fd = memfd_create(mPath.c_str(), MFD_ALLOW_SEALING);
297                     fillUsbRequest(fd, &mFdp);
298                     mUsbRequest.dev = usb_device_new(mPath.c_str(), fd);
299                     mtpDataPacket.readDataWait(mUsbRequest.dev);
300                     usb_device_close(mUsbRequest.dev);
301                 },
302                 [&]() {
303                     if (mFdp.ConsumeBool()) {
304                         std::vector<int16_t> data;
305                         for (size_t idx = 0;
306                              idx < mFdp.ConsumeIntegralInRange<size_t>(kMinSize, kMaxSize); ++idx) {
307                             data.push_back(mFdp.ConsumeIntegral<int16_t>());
308                         }
309                         mtpDataPacket.putAInt16(data.data(), data.size());
310                     } else {
311                         mtpDataPacket.putAInt16(nullptr, 0);
312                     }
313                 },
314                 [&]() {
315                     int32_t arr[4];
316                     for (size_t idx = 0; idx < 4; ++idx) {
317                         arr[idx] = mFdp.ConsumeIntegral<int32_t>();
318                     }
319                     mtpDataPacket.putInt128(arr);
320                 },
321                 [&]() { mtpDataPacket.putInt64(mFdp.ConsumeIntegral<int64_t>()); },
322                 [&]() {
323                     int16_t out;
324                     mtpDataPacket.getInt16(out);
325                 },
326                 [&]() {
327                     int32_t out;
328                     mtpDataPacket.getInt32(out);
329                 },
330                 [&]() {
331                     int8_t out;
332                     mtpDataPacket.getInt8(out);
333                 },
334                 [&]() {
335                     uint32_t arr[4];
336                     for (size_t idx = 0; idx < 4; ++idx) {
337                         arr[idx] = mFdp.ConsumeIntegral<uint32_t>();
338                     }
339                     if (mFdp.ConsumeBool()) {
340                         mtpDataPacket.putUInt128(arr);
341                     } else {
342                         mtpDataPacket.getUInt128(arr);
343                     }
344                 },
345                 [&]() { mtpDataPacket.putUInt64(mFdp.ConsumeIntegral<uint64_t>()); },
346                 [&]() {
347                     uint64_t out;
348                     mtpDataPacket.getUInt64(out);
349                 },
350                 [&]() { mtpDataPacket.putInt128(mFdp.ConsumeIntegral<int64_t>()); },
351                 [&]() { mtpDataPacket.putUInt128(mFdp.ConsumeIntegral<uint64_t>()); },
352                 [&]() {
353                     int32_t length;
354                     void* data = mtpDataPacket.getData(&length);
355                     free(data);
356                 },
357         });
358         mtpDataAPI();
359     }
360 }
361 
LLVMFuzzerTestOneInput(const uint8_t * data,size_t size)362 extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
363     MtpDataPacketFuzzer mtpDataPacketFuzzer(data, size);
364     mtpDataPacketFuzzer.process();
365     return 0;
366 }
367