Skip to content

Commit cd1482e

Browse files
committed
Fixed payload length bug; Added protocol header, ping/pong, data masking.
16-bit payload lengths were invalid due to a typo, && should be &. Added Sec-WebSocket-Protocol header. getData and sendData methods now take an opcode value. This is used to implement ping/pong. sendEncodedData now applies a mask to the data. WebSocket constants have been #defined.
1 parent 0b21afa commit cd1482e

File tree

2 files changed

+120
-92
lines changed

2 files changed

+120
-92
lines changed

WebSocketClient.cpp

+99-86
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ bool WebSocketClient::analyzeRequest() {
7676
socket_client->print(F("Sec-WebSocket-Key: "));
7777
socket_client->print(key);
7878
socket_client->print(CRLF);
79+
socket_client->print(F("Sec-WebSocket-Protocol: "));
80+
socket_client->print(protocol);
81+
socket_client->print(CRLF);
7982
socket_client->print(F("Sec-WebSocket-Version: 13\r\n"));
8083
socket_client->print(CRLF);
8184

@@ -131,7 +134,7 @@ bool WebSocketClient::analyzeRequest() {
131134
}
132135

133136

134-
String WebSocketClient::handleStream() {
137+
bool WebSocketClient::handleStream(String& data, uint8_t *opcode) {
135138
uint8_t msgtype;
136139
uint8_t bite;
137140
unsigned int length;
@@ -140,93 +143,96 @@ String WebSocketClient::handleStream() {
140143
unsigned int i;
141144
bool hasMask = false;
142145

143-
// String to hold bytes sent by server to client
144-
String socketString;
146+
if (!socket_client->connected() || !socket_client->available())
147+
{
148+
return false;
149+
}
145150

146-
if (socket_client->connected() && socket_client->available()) {
151+
msgtype = timedRead();
152+
if (!socket_client->connected()) {
153+
return false;
154+
}
147155

148-
msgtype = timedRead();
149-
if (!socket_client->connected()) {
150-
return socketString;
151-
}
156+
length = timedRead();
157+
158+
if (length & WS_MASK) {
159+
hasMask = true;
160+
length = length & ~WS_MASK;
161+
}
152162

153-
length = timedRead();
154163

155-
if (length > 127) {
156-
hasMask = true;
157-
length = length & 127;
158-
}
164+
if (!socket_client->connected()) {
165+
return false;
166+
}
159167

168+
index = 6;
160169

170+
if (length == WS_SIZE16) {
171+
length = timedRead() << 8;
161172
if (!socket_client->connected()) {
162-
return socketString;
173+
return false;
163174
}
164-
165-
index = 6;
166-
167-
if (length == 126) {
168-
length = timedRead() << 8;
169-
if (!socket_client->connected()) {
170-
return socketString;
171-
}
172175

173-
length |= timedRead();
174-
if (!socket_client->connected()) {
175-
return socketString;
176-
}
176+
length |= timedRead();
177+
if (!socket_client->connected()) {
178+
return false;
179+
}
177180

178-
} else if (length == 127) {
181+
} else if (length == WS_SIZE64) {
179182
#ifdef DEBUGGING
180-
Serial.println(F("No support for over 16 bit sized messages"));
183+
Serial.println(F("No support for over 16 bit sized messages"));
181184
#endif
182-
while(1) {
183-
// halt, can't handle this case
184-
}
185-
}
185+
return false;
186+
}
186187

187-
if (hasMask) {
188-
// get the mask
189-
mask[0] = timedRead();
190-
if (!socket_client->connected()) {
191-
return socketString;
192-
}
188+
if (hasMask) {
189+
// get the mask
190+
mask[0] = timedRead();
191+
if (!socket_client->connected()) {
192+
return false;
193+
}
193194

194-
mask[1] = timedRead();
195-
if (!socket_client->connected()) {
195+
mask[1] = timedRead();
196+
if (!socket_client->connected()) {
196197

197-
return socketString;
198-
}
198+
return false;
199+
}
199200

200-
mask[2] = timedRead();
201-
if (!socket_client->connected()) {
202-
return socketString;
203-
}
201+
mask[2] = timedRead();
202+
if (!socket_client->connected()) {
203+
return false;
204+
}
204205

205-
mask[3] = timedRead();
206+
mask[3] = timedRead();
207+
if (!socket_client->connected()) {
208+
return false;
209+
}
210+
}
211+
212+
data = "";
213+
214+
if (opcode != NULL)
215+
{
216+
*opcode = msgtype & ~WS_FIN;
217+
}
218+
219+
if (hasMask) {
220+
for (i=0; i<length; ++i) {
221+
data += (char) (timedRead() ^ mask[i % 4]);
206222
if (!socket_client->connected()) {
207-
return socketString;
223+
return false;
208224
}
209225
}
210-
211-
if (hasMask) {
212-
for (i=0; i<length; ++i) {
213-
socketString += (char) (timedRead() ^ mask[i % 4]);
214-
if (!socket_client->connected()) {
215-
return socketString;
216-
}
226+
} else {
227+
for (i=0; i<length; ++i) {
228+
data += (char) timedRead();
229+
if (!socket_client->connected()) {
230+
return false;
217231
}
218-
} else {
219-
for (i=0; i<length; ++i) {
220-
socketString += (char) timedRead();
221-
if (!socket_client->connected()) {
222-
return socketString;
223-
}
224-
}
225-
}
226-
232+
}
227233
}
228-
229-
return socketString;
234+
235+
return true;
230236
}
231237

232238
void WebSocketClient::disconnectStream() {
@@ -242,31 +248,27 @@ void WebSocketClient::disconnectStream() {
242248
socket_client->stop();
243249
}
244250

245-
String WebSocketClient::getData() {
246-
String data;
251+
bool WebSocketClient::getData(String& data, uint8_t *opcode) {
252+
return handleStream(data, opcode);
253+
}
247254

248-
data = handleStream();
249-
250-
return data;
251-
}
252-
253-
void WebSocketClient::sendData(const char *str) {
255+
void WebSocketClient::sendData(const char *str, uint8_t opcode) {
254256
#ifdef DEBUGGING
255257
Serial.print(F("Sending data: "));
256258
Serial.println(str);
257259
#endif
258260
if (socket_client->connected()) {
259-
sendEncodedData(str);
261+
sendEncodedData(str, opcode);
260262
}
261263
}
262264

263-
void WebSocketClient::sendData(String str) {
265+
void WebSocketClient::sendData(String str, uint8_t opcode) {
264266
#ifdef DEBUGGING
265267
Serial.print(F("Sending data: "));
266268
Serial.println(str);
267269
#endif
268270
if (socket_client->connected()) {
269-
sendEncodedData(str);
271+
sendEncodedData(str, opcode);
270272
}
271273
}
272274

@@ -278,31 +280,42 @@ int WebSocketClient::timedRead() {
278280
return socket_client->read();
279281
}
280282

281-
void WebSocketClient::sendEncodedData(char *str) {
283+
void WebSocketClient::sendEncodedData(char *str, uint8_t opcode) {
284+
uint8_t mask[4];
282285
int size = strlen(str);
283286

284-
// string type
285-
socket_client->write(0x81);
287+
// Opcode; final fragment
288+
socket_client->write(opcode | WS_FIN);
286289

287290
// NOTE: no support for > 16-bit sized messages
288291
if (size > 125) {
289-
socket_client->write(126);
292+
socket_client->write(WS_SIZE16 | WS_MASK);
290293
socket_client->write((uint8_t) (size >> 8));
291-
socket_client->write((uint8_t) (size && 0xFF));
294+
socket_client->write((uint8_t) (size & 0xFF));
292295
} else {
293-
socket_client->write((uint8_t) size);
296+
socket_client->write((uint8_t) size | WS_MASK);
294297
}
295298

299+
mask[0] = random(0, 256);
300+
mask[1] = random(0, 256);
301+
mask[2] = random(0, 256);
302+
mask[3] = random(0, 256);
303+
304+
socket_client->write(mask[0]);
305+
socket_client->write(mask[1]);
306+
socket_client->write(mask[2]);
307+
socket_client->write(mask[3]);
308+
296309
for (int i=0; i<size; ++i) {
297-
socket_client->write(str[i]);
310+
socket_client->write(str[i] ^ mask[i % 4]);
298311
}
299312
}
300313

301-
void WebSocketClient::sendEncodedData(String str) {
314+
void WebSocketClient::sendEncodedData(String str, uint8_t opcode) {
302315
int size = str.length() + 1;
303316
char cstr[size];
304317

305318
str.toCharArray(cstr, size);
306319

307-
sendEncodedData(cstr);
320+
sendEncodedData(cstr, opcode);
308321
}

WebSocketClient.h

+21-6
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,20 @@ Currently based off of "The Web Socket protocol" draft (v 75):
6969

7070
#define SIZE(array) (sizeof(array) / sizeof(*array))
7171

72+
// WebSocket protocol constants
73+
// First byte
74+
#define WS_FIN 0x80
75+
#define WS_OPCODE_TEXT 0x01
76+
#define WS_OPCODE_BINARY 0x02
77+
#define WS_OPCODE_CLOSE 0x08
78+
#define WS_OPCODE_PING 0x09
79+
#define WS_OPCODE_PONG 0x0a
80+
// Second byte
81+
#define WS_MASK 0x80
82+
#define WS_SIZE16 126
83+
#define WS_SIZE64 127
84+
85+
7286
class WebSocketClient {
7387
public:
7488

@@ -77,14 +91,15 @@ class WebSocketClient {
7791
bool handshake(Client &client);
7892

7993
// Get data off of the stream
80-
String getData();
94+
bool getData(String& data, uint8_t *opcode = NULL);
8195

8296
// Write data to the stream
83-
void sendData(const char *str);
84-
void sendData(String str);
97+
void sendData(const char *str, uint8_t opcode = WS_OPCODE_TEXT);
98+
void sendData(String str, uint8_t opcode = WS_OPCODE_TEXT);
8599

86100
char *path;
87101
char *host;
102+
char *protocol;
88103

89104
private:
90105
Client *socket_client;
@@ -96,15 +111,15 @@ class WebSocketClient {
96111
// websocket connection.
97112
bool analyzeRequest();
98113

99-
String handleStream();
114+
bool handleStream(String& data, uint8_t *opcode);
100115

101116
// Disconnect user gracefully.
102117
void disconnectStream();
103118

104119
int timedRead();
105120

106-
void sendEncodedData(char *str);
107-
void sendEncodedData(String str);
121+
void sendEncodedData(char *str, uint8_t opcode);
122+
void sendEncodedData(String str, uint8_t opcode);
108123
};
109124

110125

0 commit comments

Comments
 (0)