1 /*
2  * Copyright 2019 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     https://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "cppbor.h"

#include <inttypes.h>

#include <algorithm>
#include <cstdint>
#include <string_view>
#include <type_traits>

#include <openssl/sha.h>

#include "cppbor_parse.h"
24 
25 using std::string;
26 using std::vector;
27 
28 #ifndef __TRUSTY__
29 #include <android-base/logging.h>
30 #define LOG_TAG "CppBor"
31 #else
32 #define CHECK(x) (void)(x)
33 #endif
34 
35 namespace cppbor {
36 
37 namespace {
38 
/**
 * Serializes `value` through `pos` in big-endian order (most significant byte first).
 *
 * The overload only participates in overload resolution when `T` is an unsigned type; the
 * constraint uses enable_if_t so that a signed `T` is actually rejected at substitution time.
 *
 * @param value unsigned integer to serialize; exactly sizeof(T) bytes are written.
 * @param pos   output iterator that receives the bytes.  The caller must guarantee that
 *              sizeof(T) bytes can be written through it.
 * @return `pos` advanced past the last byte written.
 */
template <typename T, typename Iterator,
          typename = std::enable_if_t<std::is_unsigned<T>::value>>
Iterator writeBigEndian(T value, Iterator pos) {
    for (unsigned i = 0; i < sizeof(value); ++i) {
        // Emit the current top byte, then shift the next byte into the top position.  The cast
        // back to T drops any bits that integer promotion carried beyond T's width.
        *pos++ = static_cast<uint8_t>(value >> (8 * (sizeof(value) - 1)));
        value = static_cast<T>(value << 8);
    }
    return pos;
}
47 
/**
 * Serializes `value` in big-endian order (most significant byte first), handing each byte to
 * `cb` in turn.
 *
 * The overload only participates in overload resolution when `T` is an unsigned type; the
 * constraint uses enable_if_t so that a signed `T` is actually rejected at substitution time.
 *
 * @param value unsigned integer to serialize; `cb` is invoked exactly sizeof(T) times.
 * @param cb    byte sink.  Taken by const reference because invoking a std::function does not
 *              modify it, which also lets callers pass const or temporary callbacks.
 */
template <typename T, typename = std::enable_if_t<std::is_unsigned<T>::value>>
void writeBigEndian(T value, const std::function<void(uint8_t)>& cb) {
    for (unsigned i = 0; i < sizeof(value); ++i) {
        // Emit the current top byte, then shift the next byte into the top position.  The cast
        // back to T drops any bits that integer promotion carried beyond T's width.
        cb(static_cast<uint8_t>(value >> (8 * (sizeof(value) - 1))));
        value = static_cast<T>(value << 8);
    }
}
55 
cborAreAllElementsNonCompound(const Item * compoundItem)56 bool cborAreAllElementsNonCompound(const Item* compoundItem) {
57     if (compoundItem->type() == ARRAY) {
58         const Array* array = compoundItem->asArray();
59         for (size_t n = 0; n < array->size(); n++) {
60             const Item* entry = (*array)[n].get();
61             switch (entry->type()) {
62                 case ARRAY:
63                 case MAP:
64                     return false;
65                 default:
66                     break;
67             }
68         }
69     } else {
70         const Map* map = compoundItem->asMap();
71         for (auto& [keyEntry, valueEntry] : *map) {
72             switch (keyEntry->type()) {
73                 case ARRAY:
74                 case MAP:
75                     return false;
76                 default:
77                     break;
78             }
79             switch (valueEntry->type()) {
80                 case ARRAY:
81                 case MAP:
82                     return false;
83                 default:
84                     break;
85             }
86         }
87     }
88     return true;
89 }
90 
prettyPrintInternal(const Item * item,string & out,size_t indent,size_t maxBStrSize,const vector<string> & mapKeysToNotPrint)91 bool prettyPrintInternal(const Item* item, string& out, size_t indent, size_t maxBStrSize,
92                          const vector<string>& mapKeysToNotPrint) {
93     if (!item) {
94         out.append("<NULL>");
95         return false;
96     }
97 
98     char buf[80];
99 
100     string indentString(indent, ' ');
101 
102     size_t tagCount = item->semanticTagCount();
103     while (tagCount > 0) {
104         --tagCount;
105         snprintf(buf, sizeof(buf), "tag %" PRIu64 " ", item->semanticTag(tagCount));
106         out.append(buf);
107     }
108 
109     switch (item->type()) {
110         case SEMANTIC:
111             // Handled above.
112             break;
113 
114         case UINT:
115             snprintf(buf, sizeof(buf), "%" PRIu64, item->asUint()->unsignedValue());
116             out.append(buf);
117             break;
118 
119         case NINT:
120             snprintf(buf, sizeof(buf), "%" PRId64, item->asNint()->value());
121             out.append(buf);
122             break;
123 
124         case BSTR: {
125             const uint8_t* valueData;
126             size_t valueSize;
127             const Bstr* bstr = item->asBstr();
128             if (bstr != nullptr) {
129                 const vector<uint8_t>& value = bstr->value();
130                 valueData = value.data();
131                 valueSize = value.size();
132             } else {
133                 const ViewBstr* viewBstr = item->asViewBstr();
134                 assert(viewBstr != nullptr);
135 
136                 valueData = viewBstr->view().data();
137                 valueSize = viewBstr->view().size();
138             }
139 
140             if (valueSize > maxBStrSize) {
141                 unsigned char digest[SHA_DIGEST_LENGTH];
142                 SHA_CTX ctx;
143                 SHA1_Init(&ctx);
144                 SHA1_Update(&ctx, valueData, valueSize);
145                 SHA1_Final(digest, &ctx);
146                 char buf2[SHA_DIGEST_LENGTH * 2 + 1];
147                 for (size_t n = 0; n < SHA_DIGEST_LENGTH; n++) {
148                     snprintf(buf2 + n * 2, 3, "%02x", digest[n]);
149                 }
150                 snprintf(buf, sizeof(buf), "<bstr size=%zd sha1=%s>", valueSize, buf2);
151                 out.append(buf);
152             } else {
153                 out.append("{");
154                 for (size_t n = 0; n < valueSize; n++) {
155                     if (n > 0) {
156                         out.append(", ");
157                     }
158                     snprintf(buf, sizeof(buf), "0x%02x", valueData[n]);
159                     out.append(buf);
160                 }
161                 out.append("}");
162             }
163         } break;
164 
165         case TSTR:
166             out.append("'");
167             {
168                 // TODO: escape "'" characters
169                 if (item->asTstr() != nullptr) {
170                     out.append(item->asTstr()->value().c_str());
171                 } else {
172                     const ViewTstr* viewTstr = item->asViewTstr();
173                     assert(viewTstr != nullptr);
174                     out.append(viewTstr->view());
175                 }
176             }
177             out.append("'");
178             break;
179 
180         case ARRAY: {
181             const Array* array = item->asArray();
182             if (array->size() == 0) {
183                 out.append("[]");
184             } else if (cborAreAllElementsNonCompound(array)) {
185                 out.append("[");
186                 for (size_t n = 0; n < array->size(); n++) {
187                     if (!prettyPrintInternal((*array)[n].get(), out, indent + 2, maxBStrSize,
188                                              mapKeysToNotPrint)) {
189                         return false;
190                     }
191                     out.append(", ");
192                 }
193                 out.append("]");
194             } else {
195                 out.append("[\n" + indentString);
196                 for (size_t n = 0; n < array->size(); n++) {
197                     out.append("  ");
198                     if (!prettyPrintInternal((*array)[n].get(), out, indent + 2, maxBStrSize,
199                                              mapKeysToNotPrint)) {
200                         return false;
201                     }
202                     out.append(",\n" + indentString);
203                 }
204                 out.append("]");
205             }
206         } break;
207 
208         case MAP: {
209             const Map* map = item->asMap();
210 
211             if (map->size() == 0) {
212                 out.append("{}");
213             } else {
214                 out.append("{\n" + indentString);
215                 for (auto& [map_key, map_value] : *map) {
216                     out.append("  ");
217 
218                     if (!prettyPrintInternal(map_key.get(), out, indent + 2, maxBStrSize,
219                                              mapKeysToNotPrint)) {
220                         return false;
221                     }
222                     out.append(" : ");
223                     if (map_key->type() == TSTR &&
224                         std::find(mapKeysToNotPrint.begin(), mapKeysToNotPrint.end(),
225                                   map_key->asTstr()->value()) != mapKeysToNotPrint.end()) {
226                         out.append("<not printed>");
227                     } else {
228                         if (!prettyPrintInternal(map_value.get(), out, indent + 2, maxBStrSize,
229                                                  mapKeysToNotPrint)) {
230                             return false;
231                         }
232                     }
233                     out.append(",\n" + indentString);
234                 }
235                 out.append("}");
236             }
237         } break;
238 
239         case SIMPLE:
240             const Bool* asBool = item->asSimple()->asBool();
241             const Null* asNull = item->asSimple()->asNull();
242             if (asBool != nullptr) {
243                 out.append(asBool->value() ? "true" : "false");
244             } else if (asNull != nullptr) {
245                 out.append("null");
246             } else {
247 #ifndef __TRUSTY__
248                 LOG(ERROR) << "Only boolean/null is implemented for SIMPLE";
249 #endif  // __TRUSTY__
250                 return false;
251             }
252             break;
253     }
254 
255     return true;
256 }
257 
258 }  // namespace
259 
headerSize(uint64_t addlInfo)260 size_t headerSize(uint64_t addlInfo) {
261     if (addlInfo < ONE_BYTE_LENGTH) return 1;
262     if (addlInfo <= std::numeric_limits<uint8_t>::max()) return 2;
263     if (addlInfo <= std::numeric_limits<uint16_t>::max()) return 3;
264     if (addlInfo <= std::numeric_limits<uint32_t>::max()) return 5;
265     return 9;
266 }
267 
encodeHeader(MajorType type,uint64_t addlInfo,uint8_t * pos,const uint8_t * end)268 uint8_t* encodeHeader(MajorType type, uint64_t addlInfo, uint8_t* pos, const uint8_t* end) {
269     size_t sz = headerSize(addlInfo);
270     if (end - pos < static_cast<ssize_t>(sz)) return nullptr;
271     switch (sz) {
272         case 1:
273             *pos++ = type | static_cast<uint8_t>(addlInfo);
274             return pos;
275         case 2:
276             *pos++ = type | static_cast<MajorType>(ONE_BYTE_LENGTH);
277             *pos++ = static_cast<uint8_t>(addlInfo);
278             return pos;
279         case 3:
280             *pos++ = type | static_cast<MajorType>(TWO_BYTE_LENGTH);
281             return writeBigEndian(static_cast<uint16_t>(addlInfo), pos);
282         case 5:
283             *pos++ = type | static_cast<MajorType>(FOUR_BYTE_LENGTH);
284             return writeBigEndian(static_cast<uint32_t>(addlInfo), pos);
285         case 9:
286             *pos++ = type | static_cast<MajorType>(EIGHT_BYTE_LENGTH);
287             return writeBigEndian(addlInfo, pos);
288         default:
289             CHECK(false);  // Impossible to get here.
290             return nullptr;
291     }
292 }
293 
encodeHeader(MajorType type,uint64_t addlInfo,EncodeCallback encodeCallback)294 void encodeHeader(MajorType type, uint64_t addlInfo, EncodeCallback encodeCallback) {
295     size_t sz = headerSize(addlInfo);
296     switch (sz) {
297         case 1:
298             encodeCallback(type | static_cast<uint8_t>(addlInfo));
299             break;
300         case 2:
301             encodeCallback(type | static_cast<MajorType>(ONE_BYTE_LENGTH));
302             encodeCallback(static_cast<uint8_t>(addlInfo));
303             break;
304         case 3:
305             encodeCallback(type | static_cast<MajorType>(TWO_BYTE_LENGTH));
306             writeBigEndian(static_cast<uint16_t>(addlInfo), encodeCallback);
307             break;
308         case 5:
309             encodeCallback(type | static_cast<MajorType>(FOUR_BYTE_LENGTH));
310             writeBigEndian(static_cast<uint32_t>(addlInfo), encodeCallback);
311             break;
312         case 9:
313             encodeCallback(type | static_cast<MajorType>(EIGHT_BYTE_LENGTH));
314             writeBigEndian(addlInfo, encodeCallback);
315             break;
316         default:
317             CHECK(false);  // Impossible to get here.
318     }
319 }
320 
operator ==(const Item & other) const321 bool Item::operator==(const Item& other) const& {
322     if (type() != other.type()) return false;
323     switch (type()) {
324         case UINT:
325             return *asUint() == *(other.asUint());
326         case NINT:
327             return *asNint() == *(other.asNint());
328         case BSTR:
329             if (asBstr() != nullptr && other.asBstr() != nullptr) {
330                 return *asBstr() == *(other.asBstr());
331             }
332             if (asViewBstr() != nullptr && other.asViewBstr() != nullptr) {
333                 return *asViewBstr() == *(other.asViewBstr());
334             }
335             // Interesting corner case: comparing a Bstr and ViewBstr with
336             // identical contents. The function currently returns false for
337             // this case.
338             // TODO: if it should return true, this needs a deep comparison
339             return false;
340         case TSTR:
341             if (asTstr() != nullptr && other.asTstr() != nullptr) {
342                 return *asTstr() == *(other.asTstr());
343             }
344             if (asViewTstr() != nullptr && other.asViewTstr() != nullptr) {
345                 return *asViewTstr() == *(other.asViewTstr());
346             }
347             // Same corner case as Bstr
348             return false;
349         case ARRAY:
350             return *asArray() == *(other.asArray());
351         case MAP:
352             return *asMap() == *(other.asMap());
353         case SIMPLE:
354             return *asSimple() == *(other.asSimple());
355         case SEMANTIC:
356             return *asSemanticTag() == *(other.asSemanticTag());
357         default:
358             CHECK(false);  // Impossible to get here.
359             return false;
360     }
361 }
362 
Nint(int64_t v)363 Nint::Nint(int64_t v) : mValue(v) {
364     CHECK(v < 0);
365 }
366 
operator ==(const Simple & other) const367 bool Simple::operator==(const Simple& other) const& {
368     if (simpleType() != other.simpleType()) return false;
369 
370     switch (simpleType()) {
371         case BOOLEAN:
372             return *asBool() == *(other.asBool());
373         case NULL_T:
374             return true;
375         default:
376             CHECK(false);  // Impossible to get here.
377             return false;
378     }
379 }
380 
encode(uint8_t * pos,const uint8_t * end) const381 uint8_t* Bstr::encode(uint8_t* pos, const uint8_t* end) const {
382     pos = encodeHeader(mValue.size(), pos, end);
383     if (!pos || end - pos < static_cast<ptrdiff_t>(mValue.size())) return nullptr;
384     return std::copy(mValue.begin(), mValue.end(), pos);
385 }
386 
encodeValue(EncodeCallback encodeCallback) const387 void Bstr::encodeValue(EncodeCallback encodeCallback) const {
388     for (auto c : mValue) {
389         encodeCallback(c);
390     }
391 }
392 
encode(uint8_t * pos,const uint8_t * end) const393 uint8_t* ViewBstr::encode(uint8_t* pos, const uint8_t* end) const {
394     pos = encodeHeader(mView.size(), pos, end);
395     if (!pos || end - pos < static_cast<ptrdiff_t>(mView.size())) return nullptr;
396     return std::copy(mView.begin(), mView.end(), pos);
397 }
398 
encodeValue(EncodeCallback encodeCallback) const399 void ViewBstr::encodeValue(EncodeCallback encodeCallback) const {
400     for (auto c : mView) {
401         encodeCallback(static_cast<uint8_t>(c));
402     }
403 }
404 
encode(uint8_t * pos,const uint8_t * end) const405 uint8_t* Tstr::encode(uint8_t* pos, const uint8_t* end) const {
406     pos = encodeHeader(mValue.size(), pos, end);
407     if (!pos || end - pos < static_cast<ptrdiff_t>(mValue.size())) return nullptr;
408     return std::copy(mValue.begin(), mValue.end(), pos);
409 }
410 
encodeValue(EncodeCallback encodeCallback) const411 void Tstr::encodeValue(EncodeCallback encodeCallback) const {
412     for (auto c : mValue) {
413         encodeCallback(static_cast<uint8_t>(c));
414     }
415 }
416 
encode(uint8_t * pos,const uint8_t * end) const417 uint8_t* ViewTstr::encode(uint8_t* pos, const uint8_t* end) const {
418     pos = encodeHeader(mView.size(), pos, end);
419     if (!pos || end - pos < static_cast<ptrdiff_t>(mView.size())) return nullptr;
420     return std::copy(mView.begin(), mView.end(), pos);
421 }
422 
encodeValue(EncodeCallback encodeCallback) const423 void ViewTstr::encodeValue(EncodeCallback encodeCallback) const {
424     for (auto c : mView) {
425         encodeCallback(static_cast<uint8_t>(c));
426     }
427 }
428 
operator ==(const Array & other) const429 bool Array::operator==(const Array& other) const& {
430     return size() == other.size()
431            // Can't use vector::operator== because the contents are pointers.  std::equal lets us
432            // provide a predicate that does the dereferencing.
433            && std::equal(mEntries.begin(), mEntries.end(), other.mEntries.begin(),
434                          [](auto& a, auto& b) -> bool { return *a == *b; });
435 }
436 
encode(uint8_t * pos,const uint8_t * end) const437 uint8_t* Array::encode(uint8_t* pos, const uint8_t* end) const {
438     pos = encodeHeader(size(), pos, end);
439     if (!pos) return nullptr;
440     for (auto& entry : mEntries) {
441         pos = entry->encode(pos, end);
442         if (!pos) return nullptr;
443     }
444     return pos;
445 }
446 
encode(EncodeCallback encodeCallback) const447 void Array::encode(EncodeCallback encodeCallback) const {
448     encodeHeader(size(), encodeCallback);
449     for (auto& entry : mEntries) {
450         entry->encode(encodeCallback);
451     }
452 }
453 
clone() const454 std::unique_ptr<Item> Array::clone() const {
455     auto res = std::make_unique<Array>();
456     for (size_t i = 0; i < mEntries.size(); i++) {
457         res->add(mEntries[i]->clone());
458     }
459     return res;
460 }
461 
operator ==(const Map & other) const462 bool Map::operator==(const Map& other) const& {
463     return size() == other.size()
464            // Can't use vector::operator== because the contents are pairs of pointers.  std::equal
465            // lets us provide a predicate that does the dereferencing.
466            && std::equal(begin(), end(), other.begin(), [](auto& a, auto& b) {
467                   return *a.first == *b.first && *a.second == *b.second;
468               });
469 }
470 
encode(uint8_t * pos,const uint8_t * end) const471 uint8_t* Map::encode(uint8_t* pos, const uint8_t* end) const {
472     pos = encodeHeader(size(), pos, end);
473     if (!pos) return nullptr;
474     for (auto& entry : mEntries) {
475         pos = entry.first->encode(pos, end);
476         if (!pos) return nullptr;
477         pos = entry.second->encode(pos, end);
478         if (!pos) return nullptr;
479     }
480     return pos;
481 }
482 
encode(EncodeCallback encodeCallback) const483 void Map::encode(EncodeCallback encodeCallback) const {
484     encodeHeader(size(), encodeCallback);
485     for (auto& entry : mEntries) {
486         entry.first->encode(encodeCallback);
487         entry.second->encode(encodeCallback);
488     }
489 }
490 
keyLess(const Item * a,const Item * b)491 bool Map::keyLess(const Item* a, const Item* b) {
492     // CBOR map canonicalization rules are:
493 
494     // 1. If two keys have different lengths, the shorter one sorts earlier.
495     if (a->encodedSize() < b->encodedSize()) return true;
496     if (a->encodedSize() > b->encodedSize()) return false;
497 
498     // 2. If two keys have the same length, the one with the lower value in (byte-wise) lexical
499     // order sorts earlier.  This requires encoding both items.
500     auto encodedA = a->encode();
501     auto encodedB = b->encode();
502 
503     return std::lexicographical_compare(encodedA.begin(), encodedA.end(),  //
504                                         encodedB.begin(), encodedB.end());
505 }
506 
recursivelyCanonicalize(std::unique_ptr<Item> & item)507 void recursivelyCanonicalize(std::unique_ptr<Item>& item) {
508     switch (item->type()) {
509         case UINT:
510         case NINT:
511         case BSTR:
512         case TSTR:
513         case SIMPLE:
514             return;
515 
516         case ARRAY:
517             std::for_each(item->asArray()->begin(), item->asArray()->end(),
518                           recursivelyCanonicalize);
519             return;
520 
521         case MAP:
522             item->asMap()->canonicalize(true /* recurse */);
523             return;
524 
525         case SEMANTIC:
526             // This can't happen.  SemanticTags delegate their type() method to the contained Item's
527             // type.
528             assert(false);
529             return;
530     }
531 }
532 
canonicalize(bool recurse)533 Map& Map::canonicalize(bool recurse) & {
534     if (recurse) {
535         for (auto& entry : mEntries) {
536             recursivelyCanonicalize(entry.first);
537             recursivelyCanonicalize(entry.second);
538         }
539     }
540 
541     if (size() < 2 || mCanonicalized) {
542         // Trivially or already canonical; do nothing.
543         return *this;
544     }
545 
546     std::sort(begin(), end(),
547               [](auto& a, auto& b) { return keyLess(a.first.get(), b.first.get()); });
548     mCanonicalized = true;
549     return *this;
550 }
551 
clone() const552 std::unique_ptr<Item> Map::clone() const {
553     auto res = std::make_unique<Map>();
554     for (auto& [key, value] : *this) {
555         res->add(key->clone(), value->clone());
556     }
557     res->mCanonicalized = mCanonicalized;
558     return res;
559 }
560 
clone() const561 std::unique_ptr<Item> SemanticTag::clone() const {
562     return std::make_unique<SemanticTag>(mValue, mTaggedItem->clone());
563 }
564 
encode(uint8_t * pos,const uint8_t * end) const565 uint8_t* SemanticTag::encode(uint8_t* pos, const uint8_t* end) const {
566     // Can't use the encodeHeader() method that calls type() to get the major type, since that will
567     // return the tagged Item's type.
568     pos = ::cppbor::encodeHeader(kMajorType, mValue, pos, end);
569     if (!pos) return nullptr;
570     return mTaggedItem->encode(pos, end);
571 }
572 
encode(EncodeCallback encodeCallback) const573 void SemanticTag::encode(EncodeCallback encodeCallback) const {
574     // Can't use the encodeHeader() method that calls type() to get the major type, since that will
575     // return the tagged Item's type.
576     ::cppbor::encodeHeader(kMajorType, mValue, encodeCallback);
577     mTaggedItem->encode(encodeCallback);
578 }
579 
semanticTagCount() const580 size_t SemanticTag::semanticTagCount() const {
581     size_t levelCount = 1;  // Count this level.
582     const SemanticTag* cur = this;
583     while (cur->mTaggedItem && (cur = cur->mTaggedItem->asSemanticTag()) != nullptr) ++levelCount;
584     return levelCount;
585 }
586 
semanticTag(size_t nesting) const587 uint64_t SemanticTag::semanticTag(size_t nesting) const {
588     // Getting the value of a specific nested tag is a bit tricky, because we start with the outer
589     // tag and don't know how many are inside.  We count the number of nesting levels to find out
590     // how many there are in total, then to get the one we want we have to walk down levelCount -
591     // nesting steps.
592     size_t levelCount = semanticTagCount();
593     if (nesting >= levelCount) return 0;
594 
595     levelCount -= nesting;
596     const SemanticTag* cur = this;
597     while (--levelCount > 0) cur = cur->mTaggedItem->asSemanticTag();
598 
599     return cur->mValue;
600 }
601 
prettyPrint(const Item * item,size_t maxBStrSize,const vector<string> & mapKeysToNotPrint)602 string prettyPrint(const Item* item, size_t maxBStrSize, const vector<string>& mapKeysToNotPrint) {
603     string out;
604     prettyPrintInternal(item, out, 0, maxBStrSize, mapKeysToNotPrint);
605     return out;
606 }
prettyPrint(const vector<uint8_t> & encodedCbor,size_t maxBStrSize,const vector<string> & mapKeysToNotPrint)607 string prettyPrint(const vector<uint8_t>& encodedCbor, size_t maxBStrSize,
608                    const vector<string>& mapKeysToNotPrint) {
609     auto [item, _, message] = parse(encodedCbor);
610     if (item == nullptr) {
611 #ifndef __TRUSTY__
612         LOG(ERROR) << "Data to pretty print is not valid CBOR: " << message;
613 #endif  // __TRUSTY__
614         return "";
615     }
616 
617     return prettyPrint(item.get(), maxBStrSize, mapKeysToNotPrint);
618 }
619 
620 }  // namespace cppbor
621