1 /**
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 
17 #define LOG_TAG "InputChannelTest"
18 
19 #include "../includes/common.h"
20 
21 #include <android-base/stringprintf.h>
22 #include <input/InputTransport.h>
23 
24 using namespace android;
25 using android::base::StringPrintf;
26 
memoryAsHexString(const void * const address,size_t numBytes)27 static std::string memoryAsHexString(const void* const address, size_t numBytes) {
28     std::string str;
29     for (size_t i = 0; i < numBytes; i++) {
30         str += StringPrintf("%02X ", static_cast<const uint8_t* const>(address)[i]);
31     }
32     return str;
33 }
34 
35 /**
36  * There could be non-zero bytes in-between InputMessage fields. Force-initialize the entire
37  * memory to zero, then only copy the valid bytes on a per-field basis.
38  * Input: message msg
39  * Output: cleaned message outMsg
40  */
sanitizeMessage(const InputMessage & msg,InputMessage * outMsg)41 static void sanitizeMessage(const InputMessage& msg, InputMessage* outMsg) {
42     memset(outMsg, 0, sizeof(*outMsg));
43 
44     // Write the header
45     outMsg->header.type = msg.header.type;
46     outMsg->header.seq = msg.header.seq;
47 
48     // Write the body
49     switch(msg.header.type) {
50         case InputMessage::Type::KEY: {
51             // int32_t eventId
52             outMsg->body.key.eventId = msg.body.key.eventId;
53             // nsecs_t eventTime
54             outMsg->body.key.eventTime = msg.body.key.eventTime;
55             // int32_t deviceId
56             outMsg->body.key.deviceId = msg.body.key.deviceId;
57             // int32_t source
58             outMsg->body.key.source = msg.body.key.source;
59             // int32_t displayId
60             outMsg->body.key.displayId = msg.body.key.displayId;
61             // std::array<uint8_t, 32> hmac
62             outMsg->body.key.hmac = msg.body.key.hmac;
63             // int32_t action
64             outMsg->body.key.action = msg.body.key.action;
65             // int32_t flags
66             outMsg->body.key.flags = msg.body.key.flags;
67             // int32_t keyCode
68             outMsg->body.key.keyCode = msg.body.key.keyCode;
69             // int32_t scanCode
70             outMsg->body.key.scanCode = msg.body.key.scanCode;
71             // int32_t metaState
72             outMsg->body.key.metaState = msg.body.key.metaState;
73             // int32_t repeatCount
74             outMsg->body.key.repeatCount = msg.body.key.repeatCount;
75             // nsecs_t downTime
76             outMsg->body.key.downTime = msg.body.key.downTime;
77             break;
78         }
79         case InputMessage::Type::MOTION: {
80             // int32_t eventId
81             outMsg->body.motion.eventId = msg.body.key.eventId;
82             // uint32_t pointerCount
83             outMsg->body.motion.pointerCount = msg.body.motion.pointerCount;
84             // nsecs_t eventTime
85             outMsg->body.motion.eventTime = msg.body.motion.eventTime;
86             // int32_t deviceId
87             outMsg->body.motion.deviceId = msg.body.motion.deviceId;
88             // int32_t source
89             outMsg->body.motion.source = msg.body.motion.source;
90             // int32_t displayId
91             outMsg->body.motion.displayId = msg.body.motion.displayId;
92             // std::array<uint8_t, 32> hmac
93             outMsg->body.motion.hmac = msg.body.motion.hmac;
94             // int32_t action
95             outMsg->body.motion.action = msg.body.motion.action;
96             // int32_t actionButton
97             outMsg->body.motion.actionButton = msg.body.motion.actionButton;
98             // int32_t flags
99             outMsg->body.motion.flags = msg.body.motion.flags;
100             // int32_t metaState
101             outMsg->body.motion.metaState = msg.body.motion.metaState;
102             // int32_t buttonState
103             outMsg->body.motion.buttonState = msg.body.motion.buttonState;
104             // MotionClassification classification
105             outMsg->body.motion.classification = msg.body.motion.classification;
106             // int32_t edgeFlags
107             outMsg->body.motion.edgeFlags = msg.body.motion.edgeFlags;
108             // nsecs_t downTime
109             outMsg->body.motion.downTime = msg.body.motion.downTime;
110             // float dsdx
111             outMsg->body.motion.dsdx = msg.body.motion.dsdx;
112             // float dtdx
113             outMsg->body.motion.dtdx = msg.body.motion.dtdx;
114             // float dtdy
115             outMsg->body.motion.dtdy = msg.body.motion.dtdy;
116             // float dsdy
117             outMsg->body.motion.dsdy = msg.body.motion.dsdy;
118             // float tx
119             outMsg->body.motion.tx = msg.body.motion.tx;
120             // float ty
121             outMsg->body.motion.ty = msg.body.motion.ty;
122             // float xPrecision
123             outMsg->body.motion.xPrecision = msg.body.motion.xPrecision;
124             // float yPrecision
125             outMsg->body.motion.yPrecision = msg.body.motion.yPrecision;
126             // float xCursorPosition
127             outMsg->body.motion.xCursorPosition = msg.body.motion.xCursorPosition;
128             // float yCursorPosition
129             outMsg->body.motion.yCursorPosition = msg.body.motion.yCursorPosition;
130             // float dsdxDisplay
131             outMsg->body.motion.dsdxRaw = msg.body.motion.dsdxRaw;
132             // float dtdxDisplay
133             outMsg->body.motion.dtdxRaw = msg.body.motion.dtdxRaw;
134             // float dtdyDisplay
135             outMsg->body.motion.dtdyRaw = msg.body.motion.dtdyRaw;
136             // float dsdyDisplay
137             outMsg->body.motion.dsdyRaw = msg.body.motion.dsdyRaw;
138             // float txDisplay
139             outMsg->body.motion.txRaw = msg.body.motion.txRaw;
140             // float tyDisplay
141             outMsg->body.motion.tyRaw = msg.body.motion.tyRaw;
142             //struct Pointer pointers[MAX_POINTERS]
143             for (size_t i = 0; i < msg.body.motion.pointerCount; i++) {
144                 // PointerProperties properties
145                 outMsg->body.motion.pointers[i].properties.id =
146                         msg.body.motion.pointers[i].properties.id;
147                 outMsg->body.motion.pointers[i].properties.toolType =
148                         msg.body.motion.pointers[i].properties.toolType;
149                 // PointerCoords coords
150                 outMsg->body.motion.pointers[i].coords.bits =
151                         msg.body.motion.pointers[i].coords.bits;
152                 const uint32_t count = BitSet64::count(msg.body.motion.pointers[i].coords.bits);
153                 memcpy(&outMsg->body.motion.pointers[i].coords.values[0],
154                         &msg.body.motion.pointers[i].coords.values[0],
155                         count * sizeof(msg.body.motion.pointers[i].coords.values[0]));
156                 outMsg->body.motion.pointers[i].coords.isResampled =
157                         msg.body.motion.pointers[i].coords.isResampled;
158             }
159             break;
160         }
161         case InputMessage::Type::FINISHED: {
162             outMsg->body.finished.handled = msg.body.finished.handled;
163             outMsg->body.finished.consumeTime = msg.body.finished.consumeTime;
164             break;
165         }
166         case InputMessage::Type::FOCUS: {
167             outMsg->body.focus.eventId = msg.body.focus.eventId;
168             outMsg->body.focus.hasFocus = msg.body.focus.hasFocus;
169             break;
170         }
171         case InputMessage::Type::CAPTURE: {
172             outMsg->body.capture.eventId = msg.body.capture.eventId;
173             outMsg->body.capture.pointerCaptureEnabled = msg.body.capture.pointerCaptureEnabled;
174             break;
175         }
176         case InputMessage::Type::DRAG: {
177             outMsg->body.capture.eventId = msg.body.capture.eventId;
178             outMsg->body.drag.isExiting = msg.body.drag.isExiting;
179             outMsg->body.drag.x = msg.body.drag.x;
180             outMsg->body.drag.y = msg.body.drag.y;
181             break;
182         }
183         case InputMessage::Type::TIMELINE: {
184             outMsg->body.timeline.eventId = msg.body.timeline.eventId;
185             outMsg->body.timeline.graphicsTimeline = msg.body.timeline.graphicsTimeline;
186             break;
187         }
188         case InputMessage::Type::TOUCH_MODE: {
189             outMsg->body.touchMode.eventId = msg.body.timeline.eventId;
190             outMsg->body.touchMode.isInTouchMode = msg.body.touchMode.isInTouchMode;
191         }
192     }
193 }
194 
makeMessageValid(InputMessage & msg)195 static void makeMessageValid(InputMessage& msg) {
196     InputMessage::Type type = msg.header.type;
197     if (type == InputMessage::Type::MOTION) {
198         // Message is considered invalid if it has more than MAX_POINTERS pointers.
199         msg.body.motion.pointerCount = MAX_POINTERS;
200     }
201     if (type == InputMessage::Type::TIMELINE) {
202         // Message is considered invalid if presentTime <= gpuCompletedTime
203         msg.body.timeline.graphicsTimeline[GraphicsTimeline::GPU_COMPLETED_TIME] = 10;
204         msg.body.timeline.graphicsTimeline[GraphicsTimeline::PRESENT_TIME] = 20;
205     }
206 }
207 
208 /**
209  * Return false if vulnerability is found for a given message type
210  */
checkMessage(InputChannel & server,InputChannel & client,InputMessage::Type type)211 static bool checkMessage(InputChannel& server, InputChannel& client, InputMessage::Type type) {
212     InputMessage serverMsg;
213     // Set all potentially uninitialized bytes to 1, for easier comparison
214 
215     memset(&serverMsg, 1, sizeof(serverMsg));
216     serverMsg.header.type = type;
217     makeMessageValid(serverMsg);
218     status_t result = server.sendMessage(&serverMsg);
219     if (result != OK) {
220         ALOGE("Could not send message to the input channel");
221         return false;
222     }
223 
224     InputMessage clientMsg;
225     result = client.receiveMessage(&clientMsg);
226     if (result != OK) {
227         ALOGE("Could not receive message from the input channel");
228         return false;
229     }
230     if (serverMsg.header.type != clientMsg.header.type) {
231         ALOGE("Types do not match");
232         return false;
233     }
234 
235     InputMessage sanitizedClientMsg;
236     sanitizeMessage(clientMsg, &sanitizedClientMsg);
237     if (memcmp(&clientMsg, &sanitizedClientMsg, clientMsg.size()) != 0) {
238         ALOGE("Client received un-sanitized message");
239         ALOGE("Received message: %s", memoryAsHexString(&clientMsg, clientMsg.size()).c_str());
240         ALOGE("Expected message: %s",
241                 memoryAsHexString(&sanitizedClientMsg, clientMsg.size()).c_str());
242         return false;
243     }
244 
245     return true;
246 }
247 
248 /**
249  * Create an unsanitized message
250  * Send
251  * Receive
252  * Compare the received message to a sanitized expected message
253  * Do this for all message types
254  */
main()255 int main() {
256     std::unique_ptr<InputChannel> server, client;
257 
258     status_t result = InputChannel::openInputChannelPair("channel name", server, client);
259     if (result != OK) {
260         ALOGE("Could not open input channel pair");
261         return 0;
262     }
263 
264     InputMessage::Type types[] = {
265             InputMessage::Type::KEY,      InputMessage::Type::MOTION,
266             InputMessage::Type::FINISHED, InputMessage::Type::FOCUS,
267             InputMessage::Type::CAPTURE,  InputMessage::Type::DRAG,
268             InputMessage::Type::TIMELINE, InputMessage::Type::TOUCH_MODE,
269     };
270     for (InputMessage::Type type : types) {
271         bool success = checkMessage(*server, *client, type);
272         if (!success) {
273             ALOGE("Check message failed for type %i", type);
274             return EXIT_VULNERABLE;
275         }
276     }
277 
278     return 0;
279 }
280