/**
 * This fuzz target performs an LZ4 streaming round-trip test
 * (compress & decompress), compares the result with the original, and calls
 * abort() on corruption.
 */
6 
7 #include <stddef.h>
8 #include <stdint.h>
9 #include <stdlib.h>
10 #include <string.h>
11 
12 #include "fuzz_helpers.h"
13 #define LZ4_STATIC_LINKING_ONLY
14 #include "lz4.h"
15 #define LZ4_HC_STATIC_LINKING_ONLY
16 #include "lz4hc.h"
17 
/* Read-only byte buffer together with a current offset into it. */
typedef struct {
  const char* buf; /* first byte of the buffer */
  size_t size;     /* number of bytes available at buf */
  size_t pos;      /* offset of the next byte to consume */
} const_cursor_t;
23 
/* Writable byte buffer together with a current offset into it. */
typedef struct {
  char* buf;   /* first byte of the buffer */
  size_t size; /* capacity of buf in bytes */
  size_t pos;  /* offset of the next byte to write */
} cursor_t;
29 
30 typedef struct {
31   LZ4_stream_t* cstream;
32   LZ4_streamHC_t* cstreamHC;
33   LZ4_streamDecode_t* dstream;
34   const_cursor_t data;
35   cursor_t compressed;
36   cursor_t roundTrip;
37   uint32_t seed;
38   int level;
39 } state_t;
40 
cursor_create(size_t size)41 cursor_t cursor_create(size_t size)
42 {
43   cursor_t cursor;
44   cursor.buf = (char*)malloc(size);
45   cursor.size = size;
46   cursor.pos = 0;
47   FUZZ_ASSERT(cursor.buf);
48   return cursor;
49 }
50 
51 typedef void (*round_trip_t)(state_t* state);
52 
/**
 * Releases the buffer held by a cursor.
 *
 * The cursor is passed by value, so the caller's copy keeps its (now stale)
 * pointer and must not be used afterwards.
 */
void cursor_free(cursor_t cursor)
{
    char* const owned = cursor.buf;
    free(owned);
}
57 
state_create(char const * data,size_t size,uint32_t seed)58 state_t state_create(char const* data, size_t size, uint32_t seed)
59 {
60     state_t state;
61 
62     state.seed = seed;
63 
64     state.data.buf = (char const*)data;
65     state.data.size = size;
66     state.data.pos = 0;
67 
68     /* Extra margin because we are streaming. */
69     state.compressed = cursor_create(1024 + 2 * LZ4_compressBound(size));
70     state.roundTrip = cursor_create(size);
71 
72     state.cstream = LZ4_createStream();
73     FUZZ_ASSERT(state.cstream);
74     state.cstreamHC = LZ4_createStreamHC();
75     FUZZ_ASSERT(state.cstream);
76     state.dstream = LZ4_createStreamDecode();
77     FUZZ_ASSERT(state.dstream);
78 
79     return state;
80 }
81 
/** Releases the LZ4 streams and the two heap buffers held by `state`. */
void state_free(state_t state)
{
    /* Streams */
    LZ4_freeStreamDecode(state.dstream);
    LZ4_freeStreamHC(state.cstreamHC);
    LZ4_freeStream(state.cstream);

    /* Buffers (state.data is not released here) */
    cursor_free(state.roundTrip);
    cursor_free(state.compressed);
}
90 
/**
 * Prepares `state` for the next round trip.
 *
 * Draws an HC compression level from `seed`, resets the three LZ4 streams,
 * and rewinds all cursors to offset 0. FUZZ_rand32() is handed the address of
 * the local `seed`, and whatever value it leaves there is what gets stored in
 * state->seed.
 */
static void state_reset(state_t* state, uint32_t seed)
{
    int const level = FUZZ_rand32(&seed, LZ4HC_CLEVEL_MIN, LZ4HC_CLEVEL_MAX);

    state->level = level;
    state->seed = seed;

    state->data.pos = 0;
    state->compressed.pos = 0;
    state->roundTrip.pos = 0;

    LZ4_resetStream_fast(state->cstream);
    LZ4_resetStreamHC_fast(state->cstreamHC, level);
    LZ4_setStreamDecode(state->dstream, NULL, 0);
}
102 
/**
 * Decompresses one block and appends the result to the round-trip buffer.
 *
 * @param state    Supplies the decode stream and the round-trip cursor.
 * @param src      Compressed block.
 * @param srcSize  Size of the compressed block in bytes.
 *
 * A negative return from LZ4_decompress_safe_continue() fails FUZZ_ASSERT.
 */
static void state_decompress(state_t* state, char const* src, int srcSize)
{
    cursor_t* const out = &state->roundTrip;
    int const room = out->size - out->pos;
    int const dSize = LZ4_decompress_safe_continue(
        state->dstream, src, out->buf + out->pos, srcSize, room);

    FUZZ_ASSERT(dSize >= 0);
    out->pos += dSize;
}
112 
/**
 * Verifies that the round-trip buffer reproduces the original input:
 * same length, then same bytes.
 */
static void state_checkRoundTrip(state_t const* state)
{
    size_t const size = state->data.size;
    char const* const data = state->data.buf;

    /* Length first: the comparison below reads `size` bytes of both buffers. */
    FUZZ_ASSERT_MSG(size == state->roundTrip.pos, "Incorrect size!");
    FUZZ_ASSERT_MSG(!memcmp(data, state->roundTrip.buf, size), "Corruption!");
}
120 
121 /**
122  * Picks a dictionary size and trims the dictionary off of the data.
123  * We copy the dictionary to the roundTrip so our validation passes.
124  */
state_trimDict(state_t * state)125 static size_t state_trimDict(state_t* state)
126 {
127     /* 64 KB is the max dict size, allow slightly beyond that to test trim. */
128     uint32_t maxDictSize = MIN(70 * 1024, state->data.size);
129     size_t const dictSize = FUZZ_rand32(&state->seed, 0, maxDictSize);
130     DEBUGLOG(2, "dictSize = %zu", dictSize);
131     FUZZ_ASSERT(state->data.pos == 0);
132     FUZZ_ASSERT(state->roundTrip.pos == 0);
133     memcpy(state->roundTrip.buf, state->data.buf, dictSize);
134     state->data.pos += dictSize;
135     state->roundTrip.pos += dictSize;
136     return dictSize;
137 }
138 
/**
 * Compresses the remaining input in randomly sized chunks taken directly from
 * the input buffer, so each chunk starts where the previous one ended, and
 * decompresses every block right after producing it.
 */
static void state_prefixRoundTrip(state_t* state)
{
    const_cursor_t* const in = &state->data;
    cursor_t* const out = &state->compressed;

    while (in->pos != in->size) {
        char* const block = out->buf + out->pos;
        int const remaining = in->size - in->pos;
        int const chunkSize = FUZZ_rand32(&state->seed, 0, remaining);
        int const room = out->size - out->pos;
        int const cSize = LZ4_compress_fast_continue(
            state->cstream, in->buf + in->pos, block, chunkSize, room, 0);

        FUZZ_ASSERT(cSize > 0);
        in->pos += chunkSize;
        out->pos += cSize;
        state_decompress(state, block, cSize);
    }
}
155 
/**
 * Same chunked round trip as state_prefixRoundTrip(), except that successive
 * chunks are read alternately from a private copy of the input and from the
 * input itself (the copy goes first). Consecutive chunks therefore never sit
 * next to each other in memory, which presumably is what pushes the
 * compressor onto its external-dictionary path.
 */
static void state_extDictRoundTrip(state_t* state)
{
    const_cursor_t* const in = &state->data;
    cursor_t* const out = &state->compressed;
    cursor_t mirror = cursor_create(in->size);
    int useOriginal = 0;

    memcpy(mirror.buf, in->buf, in->size);
    while (in->pos != in->size) {
        char const* const base = useOriginal ? in->buf : mirror.buf;
        char* const block = out->buf + out->pos;
        int const remaining = in->size - in->pos;
        int const chunkSize = FUZZ_rand32(&state->seed, 0, remaining);
        int const room = out->size - out->pos;
        int const cSize = LZ4_compress_fast_continue(
            state->cstream, base + in->pos, block, chunkSize, room, 0);

        FUZZ_ASSERT(cSize > 0);
        useOriginal = !useOriginal;
        in->pos += chunkSize;
        out->pos += cSize;
        state_decompress(state, block, cSize);
    }
    cursor_free(mirror);
}
177 
/**
 * Runs exactly one of two round-trip strategies, selected by a random draw
 * from state->seed: `rt0` on a non-zero draw, `rt1` on zero.
 */
static void state_randomRoundTrip(state_t* state, round_trip_t rt0,
                                  round_trip_t rt1)
{
    round_trip_t const chosen = FUZZ_rand32(&state->seed, 0, 1) ? rt0 : rt1;

    chosen(state);
}
187 
/**
 * Round trip with a dictionary loaded straight into the compression stream.
 *
 * The leading bytes of the input (size chosen by state_trimDict()) are given
 * to both the compressor and the decoder as dictionary; the rest of the input
 * then goes through a randomly chosen prefix or ext-dict round trip.
 */
static void state_loadDictRoundTrip(state_t* state)
{
    /* The dictionary begins at the very start of the input buffer. */
    char const* const dictStart = state->data.buf;
    size_t const dictBytes = state_trimDict(state);

    LZ4_loadDict(state->cstream, dictStart, dictBytes);
    LZ4_setStreamDecode(state->dstream, dictStart, dictBytes);

    state_randomRoundTrip(state, state_prefixRoundTrip, state_extDictRoundTrip);
}
196 
/**
 * Round trip with a dictionary attached through a separate stream.
 *
 * The leading bytes of the input (size chosen by state_trimDict()) are loaded
 * into a temporary stream, which is attached to the working compression
 * stream; the decoder gets the same bytes as its dictionary. The rest of the
 * input then goes through a randomly chosen prefix or ext-dict round trip.
 */
static void state_attachDictRoundTrip(state_t* state)
{
    char const* dict = state->data.buf;
    size_t const dictSize = state_trimDict(state);
    LZ4_stream_t* dictStream = LZ4_createStream();

    /* Same allocation check state_create() applies to its streams. */
    FUZZ_ASSERT(dictStream);
    LZ4_loadDict(dictStream, dict, dictSize);
    LZ4_attach_dictionary(state->cstream, dictStream);
    LZ4_setStreamDecode(state->dstream, dict, dictSize);
    state_randomRoundTrip(state, state_prefixRoundTrip, state_extDictRoundTrip);
    /* Released only once the round trip that used it has finished. */
    LZ4_freeStream(dictStream);
}
208 
/**
 * HC counterpart of the prefix round trip: compresses the remaining input in
 * randomly sized chunks taken directly from the input buffer with
 * LZ4_compress_HC_continue(), decompressing every block right after
 * producing it.
 */
static void state_prefixHCRoundTrip(state_t* state)
{
    const_cursor_t* const in = &state->data;
    cursor_t* const out = &state->compressed;

    while (in->pos != in->size) {
        char* const block = out->buf + out->pos;
        int const remaining = in->size - in->pos;
        int const chunkSize = FUZZ_rand32(&state->seed, 0, remaining);
        int const room = out->size - out->pos;
        int const cSize = LZ4_compress_HC_continue(
            state->cstreamHC, in->buf + in->pos, block, chunkSize, room);

        FUZZ_ASSERT(cSize > 0);
        in->pos += chunkSize;
        out->pos += cSize;
        state_decompress(state, block, cSize);
    }
}
225 
/**
 * HC counterpart of the ext-dict round trip: successive chunks are read
 * alternately from a private copy of the input and from the input itself
 * (the copy goes first), so consecutive chunks never sit next to each other
 * in memory. Each block is decompressed right after it is produced.
 */
static void state_extDictHCRoundTrip(state_t* state)
{
    const_cursor_t* const in = &state->data;
    cursor_t* const out = &state->compressed;
    cursor_t mirror = cursor_create(in->size);
    int useOriginal = 0;

    DEBUGLOG(2, "extDictHC");
    memcpy(mirror.buf, in->buf, in->size);
    while (in->pos != in->size) {
        char const* const base = useOriginal ? in->buf : mirror.buf;
        char* const block = out->buf + out->pos;
        int const remaining = in->size - in->pos;
        int const chunkSize = FUZZ_rand32(&state->seed, 0, remaining);
        int const room = out->size - out->pos;
        int const cSize = LZ4_compress_HC_continue(
            state->cstreamHC, base + in->pos, block, chunkSize, room);

        FUZZ_ASSERT(cSize > 0);
        DEBUGLOG(2, "srcSize = %d", chunkSize);
        useOriginal = !useOriginal;
        in->pos += chunkSize;
        out->pos += cSize;
        state_decompress(state, block, cSize);
    }
    cursor_free(mirror);
}
249 
/**
 * HC round trip with a dictionary loaded straight into the HC stream.
 *
 * The leading bytes of the input (size chosen by state_trimDict()) are given
 * to both the HC compressor and the decoder as dictionary; the rest of the
 * input then goes through a randomly chosen HC prefix or ext-dict round trip.
 */
static void state_loadDictHCRoundTrip(state_t* state)
{
    /* The dictionary begins at the very start of the input buffer. */
    char const* const dictStart = state->data.buf;
    size_t const dictBytes = state_trimDict(state);

    LZ4_loadDictHC(state->cstreamHC, dictStart, dictBytes);
    LZ4_setStreamDecode(state->dstream, dictStart, dictBytes);

    state_randomRoundTrip(state, state_prefixHCRoundTrip,
                          state_extDictHCRoundTrip);
}
259 
/**
 * HC round trip with a dictionary attached through a separate HC stream.
 *
 * The leading bytes of the input (size chosen by state_trimDict()) are loaded
 * into a temporary HC stream configured with state->level, which is then
 * attached to the working HC stream; the decoder gets the same bytes as its
 * dictionary. The rest of the input goes through a randomly chosen HC prefix
 * or ext-dict round trip.
 */
static void state_attachDictHCRoundTrip(state_t* state)
{
    char const* dict = state->data.buf;
    size_t const dictSize = state_trimDict(state);
    LZ4_streamHC_t* dictStream = LZ4_createStreamHC();

    /* Same allocation check state_create() applies to its streams. */
    FUZZ_ASSERT(dictStream);
    LZ4_setCompressionLevel(dictStream, state->level);
    LZ4_loadDictHC(dictStream, dict, dictSize);
    LZ4_attach_HC_dictionary(state->cstreamHC, dictStream);
    LZ4_setStreamDecode(state->dstream, dict, dictSize);
    state_randomRoundTrip(state, state_prefixHCRoundTrip,
                          state_extDictHCRoundTrip);
    /* Released only once the round trip that used it has finished. */
    LZ4_freeStreamHC(dictStream);
}
273 
274 round_trip_t roundTrips[] = {
275   &state_prefixRoundTrip,
276   &state_extDictRoundTrip,
277   &state_loadDictRoundTrip,
278   &state_attachDictRoundTrip,
279   &state_prefixHCRoundTrip,
280   &state_extDictHCRoundTrip,
281   &state_loadDictHCRoundTrip,
282   &state_attachDictHCRoundTrip,
283 };
284 
LLVMFuzzerTestOneInput(const uint8_t * data,size_t size)285 int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size)
286 {
287     uint32_t seed = FUZZ_seed(&data, &size);
288     state_t state = state_create((char const*)data, size, seed);
289     const int n = sizeof(roundTrips) / sizeof(round_trip_t);
290     int i;
291 
292     for (i = 0; i < n; ++i) {
293         DEBUGLOG(2, "Round trip %d", i);
294         state_reset(&state, seed);
295         roundTrips[i](&state);
296         state_checkRoundTrip(&state);
297     }
298 
299     state_free(state);
300 
301     return 0;
302 }
303