1 /**
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "InputChannelTest"
18
19 #include "../includes/common.h"
20
21 #include <android-base/stringprintf.h>
22 #include <input/InputTransport.h>
23
24 using namespace android;
25 using android::base::StringPrintf;
26
/**
 * Render a block of memory as space-separated, two-digit uppercase hexadecimal bytes.
 *
 * Output is identical to formatting every byte with "%02X ", but built with a lookup table
 * into a pre-reserved string instead of one formatted temporary per byte.
 *
 * @param address  start of the memory to dump; may be null only when numBytes is 0
 * @param numBytes number of bytes to read from address
 * @return e.g. "00 1F AB " -- every byte is followed by a single space, including the last one
 */
static std::string memoryAsHexString(const void* const address, size_t numBytes) {
    static constexpr char kHexDigits[] = "0123456789ABCDEF";
    const uint8_t* const bytes = static_cast<const uint8_t*>(address);

    std::string str;
    str.reserve(numBytes * 3);  // two hex digits + one separator per byte
    for (size_t i = 0; i < numBytes; i++) {
        const uint8_t byte = bytes[i];
        str += kHexDigits[byte >> 4];
        str += kHexDigits[byte & 0x0F];
        str += ' ';
    }
    return str;
}
34
35 /**
36 * There could be non-zero bytes in-between InputMessage fields. Force-initialize the entire
37 * memory to zero, then only copy the valid bytes on a per-field basis.
38 * Input: message msg
39 * Output: cleaned message outMsg
40 */
sanitizeMessage(const InputMessage & msg,InputMessage * outMsg)41 static void sanitizeMessage(const InputMessage& msg, InputMessage* outMsg) {
42 memset(outMsg, 0, sizeof(*outMsg));
43
44 // Write the header
45 outMsg->header.type = msg.header.type;
46 outMsg->header.seq = msg.header.seq;
47
48 // Write the body
49 switch(msg.header.type) {
50 case InputMessage::Type::KEY: {
51 // int32_t eventId
52 outMsg->body.key.eventId = msg.body.key.eventId;
53 // nsecs_t eventTime
54 outMsg->body.key.eventTime = msg.body.key.eventTime;
55 // int32_t deviceId
56 outMsg->body.key.deviceId = msg.body.key.deviceId;
57 // int32_t source
58 outMsg->body.key.source = msg.body.key.source;
59 // int32_t displayId
60 outMsg->body.key.displayId = msg.body.key.displayId;
61 // std::array<uint8_t, 32> hmac
62 outMsg->body.key.hmac = msg.body.key.hmac;
63 // int32_t action
64 outMsg->body.key.action = msg.body.key.action;
65 // int32_t flags
66 outMsg->body.key.flags = msg.body.key.flags;
67 // int32_t keyCode
68 outMsg->body.key.keyCode = msg.body.key.keyCode;
69 // int32_t scanCode
70 outMsg->body.key.scanCode = msg.body.key.scanCode;
71 // int32_t metaState
72 outMsg->body.key.metaState = msg.body.key.metaState;
73 // int32_t repeatCount
74 outMsg->body.key.repeatCount = msg.body.key.repeatCount;
75 // nsecs_t downTime
76 outMsg->body.key.downTime = msg.body.key.downTime;
77 break;
78 }
79 case InputMessage::Type::MOTION: {
80 // int32_t eventId
81 outMsg->body.motion.eventId = msg.body.key.eventId;
82 // uint32_t pointerCount
83 outMsg->body.motion.pointerCount = msg.body.motion.pointerCount;
84 // nsecs_t eventTime
85 outMsg->body.motion.eventTime = msg.body.motion.eventTime;
86 // int32_t deviceId
87 outMsg->body.motion.deviceId = msg.body.motion.deviceId;
88 // int32_t source
89 outMsg->body.motion.source = msg.body.motion.source;
90 // int32_t displayId
91 outMsg->body.motion.displayId = msg.body.motion.displayId;
92 // std::array<uint8_t, 32> hmac
93 outMsg->body.motion.hmac = msg.body.motion.hmac;
94 // int32_t action
95 outMsg->body.motion.action = msg.body.motion.action;
96 // int32_t actionButton
97 outMsg->body.motion.actionButton = msg.body.motion.actionButton;
98 // int32_t flags
99 outMsg->body.motion.flags = msg.body.motion.flags;
100 // int32_t metaState
101 outMsg->body.motion.metaState = msg.body.motion.metaState;
102 // int32_t buttonState
103 outMsg->body.motion.buttonState = msg.body.motion.buttonState;
104 // MotionClassification classification
105 outMsg->body.motion.classification = msg.body.motion.classification;
106 // int32_t edgeFlags
107 outMsg->body.motion.edgeFlags = msg.body.motion.edgeFlags;
108 // nsecs_t downTime
109 outMsg->body.motion.downTime = msg.body.motion.downTime;
110 // float dsdx
111 outMsg->body.motion.dsdx = msg.body.motion.dsdx;
112 // float dtdx
113 outMsg->body.motion.dtdx = msg.body.motion.dtdx;
114 // float dtdy
115 outMsg->body.motion.dtdy = msg.body.motion.dtdy;
116 // float dsdy
117 outMsg->body.motion.dsdy = msg.body.motion.dsdy;
118 // float tx
119 outMsg->body.motion.tx = msg.body.motion.tx;
120 // float ty
121 outMsg->body.motion.ty = msg.body.motion.ty;
122 // float xPrecision
123 outMsg->body.motion.xPrecision = msg.body.motion.xPrecision;
124 // float yPrecision
125 outMsg->body.motion.yPrecision = msg.body.motion.yPrecision;
126 // float xCursorPosition
127 outMsg->body.motion.xCursorPosition = msg.body.motion.xCursorPosition;
128 // float yCursorPosition
129 outMsg->body.motion.yCursorPosition = msg.body.motion.yCursorPosition;
130 // float dsdxDisplay
131 outMsg->body.motion.dsdxRaw = msg.body.motion.dsdxRaw;
132 // float dtdxDisplay
133 outMsg->body.motion.dtdxRaw = msg.body.motion.dtdxRaw;
134 // float dtdyDisplay
135 outMsg->body.motion.dtdyRaw = msg.body.motion.dtdyRaw;
136 // float dsdyDisplay
137 outMsg->body.motion.dsdyRaw = msg.body.motion.dsdyRaw;
138 // float txDisplay
139 outMsg->body.motion.txRaw = msg.body.motion.txRaw;
140 // float tyDisplay
141 outMsg->body.motion.tyRaw = msg.body.motion.tyRaw;
142 //struct Pointer pointers[MAX_POINTERS]
143 for (size_t i = 0; i < msg.body.motion.pointerCount; i++) {
144 // PointerProperties properties
145 outMsg->body.motion.pointers[i].properties.id =
146 msg.body.motion.pointers[i].properties.id;
147 outMsg->body.motion.pointers[i].properties.toolType =
148 msg.body.motion.pointers[i].properties.toolType;
149 // PointerCoords coords
150 outMsg->body.motion.pointers[i].coords.bits =
151 msg.body.motion.pointers[i].coords.bits;
152 const uint32_t count = BitSet64::count(msg.body.motion.pointers[i].coords.bits);
153 memcpy(&outMsg->body.motion.pointers[i].coords.values[0],
154 &msg.body.motion.pointers[i].coords.values[0],
155 count * sizeof(msg.body.motion.pointers[i].coords.values[0]));
156 outMsg->body.motion.pointers[i].coords.isResampled =
157 msg.body.motion.pointers[i].coords.isResampled;
158 }
159 break;
160 }
161 case InputMessage::Type::FINISHED: {
162 outMsg->body.finished.handled = msg.body.finished.handled;
163 outMsg->body.finished.consumeTime = msg.body.finished.consumeTime;
164 break;
165 }
166 case InputMessage::Type::FOCUS: {
167 outMsg->body.focus.eventId = msg.body.focus.eventId;
168 outMsg->body.focus.hasFocus = msg.body.focus.hasFocus;
169 break;
170 }
171 case InputMessage::Type::CAPTURE: {
172 outMsg->body.capture.eventId = msg.body.capture.eventId;
173 outMsg->body.capture.pointerCaptureEnabled = msg.body.capture.pointerCaptureEnabled;
174 break;
175 }
176 case InputMessage::Type::DRAG: {
177 outMsg->body.capture.eventId = msg.body.capture.eventId;
178 outMsg->body.drag.isExiting = msg.body.drag.isExiting;
179 outMsg->body.drag.x = msg.body.drag.x;
180 outMsg->body.drag.y = msg.body.drag.y;
181 break;
182 }
183 case InputMessage::Type::TIMELINE: {
184 outMsg->body.timeline.eventId = msg.body.timeline.eventId;
185 outMsg->body.timeline.graphicsTimeline = msg.body.timeline.graphicsTimeline;
186 break;
187 }
188 case InputMessage::Type::TOUCH_MODE: {
189 outMsg->body.touchMode.eventId = msg.body.timeline.eventId;
190 outMsg->body.touchMode.isInTouchMode = msg.body.touchMode.isInTouchMode;
191 }
192 }
193 }
194
makeMessageValid(InputMessage & msg)195 static void makeMessageValid(InputMessage& msg) {
196 InputMessage::Type type = msg.header.type;
197 if (type == InputMessage::Type::MOTION) {
198 // Message is considered invalid if it has more than MAX_POINTERS pointers.
199 msg.body.motion.pointerCount = MAX_POINTERS;
200 }
201 if (type == InputMessage::Type::TIMELINE) {
202 // Message is considered invalid if presentTime <= gpuCompletedTime
203 msg.body.timeline.graphicsTimeline[GraphicsTimeline::GPU_COMPLETED_TIME] = 10;
204 msg.body.timeline.graphicsTimeline[GraphicsTimeline::PRESENT_TIME] = 20;
205 }
206 }
207
208 /**
209 * Return false if vulnerability is found for a given message type
210 */
checkMessage(InputChannel & server,InputChannel & client,InputMessage::Type type)211 static bool checkMessage(InputChannel& server, InputChannel& client, InputMessage::Type type) {
212 InputMessage serverMsg;
213 // Set all potentially uninitialized bytes to 1, for easier comparison
214
215 memset(&serverMsg, 1, sizeof(serverMsg));
216 serverMsg.header.type = type;
217 makeMessageValid(serverMsg);
218 status_t result = server.sendMessage(&serverMsg);
219 if (result != OK) {
220 ALOGE("Could not send message to the input channel");
221 return false;
222 }
223
224 InputMessage clientMsg;
225 result = client.receiveMessage(&clientMsg);
226 if (result != OK) {
227 ALOGE("Could not receive message from the input channel");
228 return false;
229 }
230 if (serverMsg.header.type != clientMsg.header.type) {
231 ALOGE("Types do not match");
232 return false;
233 }
234
235 InputMessage sanitizedClientMsg;
236 sanitizeMessage(clientMsg, &sanitizedClientMsg);
237 if (memcmp(&clientMsg, &sanitizedClientMsg, clientMsg.size()) != 0) {
238 ALOGE("Client received un-sanitized message");
239 ALOGE("Received message: %s", memoryAsHexString(&clientMsg, clientMsg.size()).c_str());
240 ALOGE("Expected message: %s",
241 memoryAsHexString(&sanitizedClientMsg, clientMsg.size()).c_str());
242 return false;
243 }
244
245 return true;
246 }
247
248 /**
249 * Create an unsanitized message
250 * Send
251 * Receive
252 * Compare the received message to a sanitized expected message
253 * Do this for all message types
254 */
main()255 int main() {
256 std::unique_ptr<InputChannel> server, client;
257
258 status_t result = InputChannel::openInputChannelPair("channel name", server, client);
259 if (result != OK) {
260 ALOGE("Could not open input channel pair");
261 return 0;
262 }
263
264 InputMessage::Type types[] = {
265 InputMessage::Type::KEY, InputMessage::Type::MOTION,
266 InputMessage::Type::FINISHED, InputMessage::Type::FOCUS,
267 InputMessage::Type::CAPTURE, InputMessage::Type::DRAG,
268 InputMessage::Type::TIMELINE, InputMessage::Type::TOUCH_MODE,
269 };
270 for (InputMessage::Type type : types) {
271 bool success = checkMessage(*server, *client, type);
272 if (!success) {
273 ALOGE("Check message failed for type %i", type);
274 return EXIT_VULNERABLE;
275 }
276 }
277
278 return 0;
279 }
280