@@ -76,6 +76,9 @@ bool WebSocketClient::analyzeRequest() {
76
76
socket_client->print (F (" Sec-WebSocket-Key: " ));
77
77
socket_client->print (key);
78
78
socket_client->print (CRLF);
79
+ socket_client->print (F (" Sec-WebSocket-Protocol: " ));
80
+ socket_client->print (protocol);
81
+ socket_client->print (CRLF);
79
82
socket_client->print (F (" Sec-WebSocket-Version: 13\r\n " ));
80
83
socket_client->print (CRLF);
81
84
@@ -131,7 +134,7 @@ bool WebSocketClient::analyzeRequest() {
131
134
}
132
135
133
136
134
- String WebSocketClient::handleStream () {
137
+ bool WebSocketClient::handleStream (String& data, uint8_t *opcode ) {
135
138
uint8_t msgtype;
136
139
uint8_t bite;
137
140
unsigned int length;
@@ -140,93 +143,96 @@ String WebSocketClient::handleStream() {
140
143
unsigned int i;
141
144
bool hasMask = false ;
142
145
143
- // String to hold bytes sent by server to client
144
- String socketString;
146
+ if (!socket_client->connected () || !socket_client->available ())
147
+ {
148
+ return false ;
149
+ }
145
150
146
- if (socket_client->connected () && socket_client->available ()) {
151
+ msgtype = timedRead ();
152
+ if (!socket_client->connected ()) {
153
+ return false ;
154
+ }
147
155
148
- msgtype = timedRead ();
149
- if (!socket_client->connected ()) {
150
- return socketString;
151
- }
156
+ length = timedRead ();
157
+
158
+ if (length & WS_MASK) {
159
+ hasMask = true ;
160
+ length = length & ~WS_MASK;
161
+ }
152
162
153
- length = timedRead ();
154
163
155
- if (length > 127 ) {
156
- hasMask = true ;
157
- length = length & 127 ;
158
- }
164
+ if (!socket_client->connected ()) {
165
+ return false ;
166
+ }
159
167
168
+ index = 6 ;
160
169
170
+ if (length == WS_SIZE16) {
171
+ length = timedRead () << 8 ;
161
172
if (!socket_client->connected ()) {
162
- return socketString ;
173
+ return false ;
163
174
}
164
-
165
- index = 6 ;
166
-
167
- if (length == 126 ) {
168
- length = timedRead () << 8 ;
169
- if (!socket_client->connected ()) {
170
- return socketString;
171
- }
172
175
173
- length |= timedRead ();
174
- if (!socket_client->connected ()) {
175
- return socketString ;
176
- }
176
+ length |= timedRead ();
177
+ if (!socket_client->connected ()) {
178
+ return false ;
179
+ }
177
180
178
- } else if (length == 127 ) {
181
+ } else if (length == WS_SIZE64 ) {
179
182
#ifdef DEBUGGING
180
- Serial.println (F (" No support for over 16 bit sized messages" ));
183
+ Serial.println (F (" No support for over 16 bit sized messages" ));
181
184
#endif
182
- while (1 ) {
183
- // halt, can't handle this case
184
- }
185
- }
185
+ return false ;
186
+ }
186
187
187
- if (hasMask) {
188
- // get the mask
189
- mask[0 ] = timedRead ();
190
- if (!socket_client->connected ()) {
191
- return socketString ;
192
- }
188
+ if (hasMask) {
189
+ // get the mask
190
+ mask[0 ] = timedRead ();
191
+ if (!socket_client->connected ()) {
192
+ return false ;
193
+ }
193
194
194
- mask[1 ] = timedRead ();
195
- if (!socket_client->connected ()) {
195
+ mask[1 ] = timedRead ();
196
+ if (!socket_client->connected ()) {
196
197
197
- return socketString ;
198
- }
198
+ return false ;
199
+ }
199
200
200
- mask[2 ] = timedRead ();
201
- if (!socket_client->connected ()) {
202
- return socketString ;
203
- }
201
+ mask[2 ] = timedRead ();
202
+ if (!socket_client->connected ()) {
203
+ return false ;
204
+ }
204
205
205
- mask[3 ] = timedRead ();
206
+ mask[3 ] = timedRead ();
207
+ if (!socket_client->connected ()) {
208
+ return false ;
209
+ }
210
+ }
211
+
212
+ data = " " ;
213
+
214
+ if (opcode != NULL )
215
+ {
216
+ *opcode = msgtype & ~WS_FIN;
217
+ }
218
+
219
+ if (hasMask) {
220
+ for (i=0 ; i<length; ++i) {
221
+ data += (char ) (timedRead () ^ mask[i % 4 ]);
206
222
if (!socket_client->connected ()) {
207
- return socketString ;
223
+ return false ;
208
224
}
209
225
}
210
-
211
- if (hasMask) {
212
- for (i=0 ; i<length; ++i) {
213
- socketString += (char ) (timedRead () ^ mask[i % 4 ]);
214
- if (!socket_client->connected ()) {
215
- return socketString;
216
- }
226
+ } else {
227
+ for (i=0 ; i<length; ++i) {
228
+ data += (char ) timedRead ();
229
+ if (!socket_client->connected ()) {
230
+ return false ;
217
231
}
218
- } else {
219
- for (i=0 ; i<length; ++i) {
220
- socketString += (char ) timedRead ();
221
- if (!socket_client->connected ()) {
222
- return socketString;
223
- }
224
- }
225
- }
226
-
232
+ }
227
233
}
228
-
229
- return socketString ;
234
+
235
+ return true ;
230
236
}
231
237
232
238
void WebSocketClient::disconnectStream () {
@@ -242,31 +248,27 @@ void WebSocketClient::disconnectStream() {
242
248
socket_client->stop ();
243
249
}
244
250
245
- String WebSocketClient::getData () {
246
- String data;
251
+ bool WebSocketClient::getData (String& data, uint8_t *opcode) {
252
+ return handleStream (data, opcode);
253
+ }
247
254
248
- data = handleStream ();
249
-
250
- return data;
251
- }
252
-
253
- void WebSocketClient::sendData (const char *str) {
255
+ void WebSocketClient::sendData (const char *str, uint8_t opcode) {
254
256
#ifdef DEBUGGING
255
257
Serial.print (F (" Sending data: " ));
256
258
Serial.println (str);
257
259
#endif
258
260
if (socket_client->connected ()) {
259
- sendEncodedData (str);
261
+ sendEncodedData (str, opcode );
260
262
}
261
263
}
262
264
263
- void WebSocketClient::sendData (String str) {
265
+ void WebSocketClient::sendData (String str, uint8_t opcode ) {
264
266
#ifdef DEBUGGING
265
267
Serial.print (F (" Sending data: " ));
266
268
Serial.println (str);
267
269
#endif
268
270
if (socket_client->connected ()) {
269
- sendEncodedData (str);
271
+ sendEncodedData (str, opcode );
270
272
}
271
273
}
272
274
@@ -278,31 +280,42 @@ int WebSocketClient::timedRead() {
278
280
return socket_client->read ();
279
281
}
280
282
281
- void WebSocketClient::sendEncodedData (char *str) {
283
+ void WebSocketClient::sendEncodedData (char *str, uint8_t opcode) {
284
+ uint8_t mask[4 ];
282
285
int size = strlen (str);
283
286
284
- // string type
285
- socket_client->write (0x81 );
287
+ // Opcode; final fragment
288
+ socket_client->write (opcode | WS_FIN );
286
289
287
290
// NOTE: no support for > 16-bit sized messages
288
291
if (size > 125 ) {
289
- socket_client->write (126 );
292
+ socket_client->write (WS_SIZE16 | WS_MASK );
290
293
socket_client->write ((uint8_t ) (size >> 8 ));
291
- socket_client->write ((uint8_t ) (size && 0xFF ));
294
+ socket_client->write ((uint8_t ) (size & 0xFF ));
292
295
} else {
293
- socket_client->write ((uint8_t ) size);
296
+ socket_client->write ((uint8_t ) size | WS_MASK );
294
297
}
295
298
299
+ mask[0 ] = random (0 , 256 );
300
+ mask[1 ] = random (0 , 256 );
301
+ mask[2 ] = random (0 , 256 );
302
+ mask[3 ] = random (0 , 256 );
303
+
304
+ socket_client->write (mask[0 ]);
305
+ socket_client->write (mask[1 ]);
306
+ socket_client->write (mask[2 ]);
307
+ socket_client->write (mask[3 ]);
308
+
296
309
for (int i=0 ; i<size; ++i) {
297
- socket_client->write (str[i]);
310
+ socket_client->write (str[i] ^ mask[i % 4 ] );
298
311
}
299
312
}
300
313
301
- void WebSocketClient::sendEncodedData (String str) {
314
+ void WebSocketClient::sendEncodedData (String str, uint8_t opcode ) {
302
315
int size = str.length () + 1 ;
303
316
char cstr[size];
304
317
305
318
str.toCharArray (cstr, size);
306
319
307
- sendEncodedData (cstr);
320
+ sendEncodedData (cstr, opcode );
308
321
}
0 commit comments