Skip to content

Commit f8b6468

Browse files
Manicbenemasab
andauthored
Support setting principal and SASL extensions in oauth_cb, handle failures (confluentinc#1402)
* Support setting principal and SASL extensions in oauth_cb and handle token failures * removed global variables Co-authored-by: Emanuele Sabellico <[email protected]>
1 parent 9ea3aae commit f8b6468

File tree

2 files changed

+164
-38
lines changed

2 files changed

+164
-38
lines changed

src/confluent_kafka/src/confluent_kafka.c

+97-6
Original file line numberDiff line numberDiff line change
@@ -1522,13 +1522,73 @@ static void log_cb (const rd_kafka_t *rk, int level,
15221522
CallState_resume(cs);
15231523
}
15241524

1525+
/**
1526+
* @brief Translate Python \p key and \p value to C types and set on
1527+
* provided \p extensions char* array at the provided index.
1528+
*
1529+
* @returns 1 on success or 0 if an exception was raised.
1530+
*/
1531+
static int py_extensions_to_c (char **extensions, Py_ssize_t idx,
1532+
PyObject *key, PyObject *value) {
1533+
PyObject *ks, *ks8, *vo8 = NULL;
1534+
const char *k;
1535+
const char *v;
1536+
Py_ssize_t ksize = 0;
1537+
Py_ssize_t vsize = 0;
1538+
1539+
if (!(ks = cfl_PyObject_Unistr(key))) {
1540+
PyErr_SetString(PyExc_TypeError,
1541+
"expected extension key to be unicode "
1542+
"string");
1543+
return 0;
1544+
}
1545+
1546+
k = cfl_PyUnistr_AsUTF8(ks, &ks8);
1547+
ksize = (Py_ssize_t)strlen(k);
1548+
1549+
if (cfl_PyUnistr(_Check(value))) {
1550+
/* Unicode string, translate to utf-8. */
1551+
v = cfl_PyUnistr_AsUTF8(value, &vo8);
1552+
if (!v) {
1553+
Py_DECREF(ks);
1554+
Py_XDECREF(ks8);
1555+
return 0;
1556+
}
1557+
vsize = (Py_ssize_t)strlen(v);
1558+
} else {
1559+
PyErr_Format(PyExc_TypeError,
1560+
"expected extension value to be "
1561+
"unicode string, not %s",
1562+
((PyTypeObject *)PyObject_Type(value))->
1563+
tp_name);
1564+
Py_DECREF(ks);
1565+
Py_XDECREF(ks8);
1566+
return 0;
1567+
}
1568+
1569+
extensions[idx] = (char*)malloc(ksize);
1570+
strcpy(extensions[idx], k);
1571+
extensions[idx + 1] = (char*)malloc(vsize);
1572+
strcpy(extensions[idx + 1], v);
1573+
1574+
Py_DECREF(ks);
1575+
Py_XDECREF(ks8);
1576+
Py_XDECREF(vo8);
1577+
1578+
return 1;
1579+
}
1580+
15251581
static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
15261582
void *opaque) {
15271583
Handle *h = opaque;
15281584
PyObject *eo, *result;
15291585
CallState *cs;
15301586
const char *token;
15311587
double expiry;
1588+
const char *principal = "";
1589+
PyObject *extensions = NULL;
1590+
char **rd_extensions = NULL;
1591+
Py_ssize_t rd_extensions_size = 0;
15321592
char err_msg[2048];
15331593
rd_kafka_resp_err_t err_code;
15341594

@@ -1539,26 +1599,57 @@ static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
15391599
Py_DECREF(eo);
15401600

15411601
if (!result) {
1542-
goto err;
1602+
goto fail;
15431603
}
1544-
if (!PyArg_ParseTuple(result, "sd", &token, &expiry)) {
1604+
if (!PyArg_ParseTuple(result, "sd|sO!", &token, &expiry, &principal, &PyDict_Type, &extensions)) {
15451605
Py_DECREF(result);
1546-
PyErr_Format(PyExc_TypeError,
1606+
PyErr_SetString(PyExc_TypeError,
15471607
"expect returned value from oauth_cb "
15481608
"to be (token_str, expiry_time) tuple");
15491609
goto err;
15501610
}
1611+
1612+
if (extensions) {
1613+
int len = (int)PyDict_Size(extensions);
1614+
rd_extensions = (char **)malloc(2 * len * sizeof(char *));
1615+
Py_ssize_t pos = 0;
1616+
PyObject *ko, *vo;
1617+
while (PyDict_Next(extensions, &pos, &ko, &vo)) {
1618+
if (!py_extensions_to_c(rd_extensions, rd_extensions_size, ko, vo)) {
1619+
Py_DECREF(result);
1620+
free(rd_extensions);
1621+
goto err;
1622+
}
1623+
rd_extensions_size = rd_extensions_size + 2;
1624+
}
1625+
}
1626+
15511627
err_code = rd_kafka_oauthbearer_set_token(h->rk, token,
15521628
(int64_t)(expiry * 1000),
1553-
"", NULL, 0, err_msg,
1629+
principal, (const char **)rd_extensions, rd_extensions_size, err_msg,
15541630
sizeof(err_msg));
15551631
Py_DECREF(result);
1556-
if (err_code) {
1632+
if (rd_extensions) {
1633+
for(int i = 0; i < rd_extensions_size; i++) {
1634+
free(rd_extensions[i]);
1635+
}
1636+
free(rd_extensions);
1637+
}
1638+
1639+
if (err_code != RD_KAFKA_RESP_ERR_NO_ERROR) {
15571640
PyErr_Format(PyExc_ValueError, "%s", err_msg);
1558-
goto err;
1641+
goto fail;
15591642
}
15601643
goto done;
15611644

1645+
fail:
1646+
err_code = rd_kafka_oauthbearer_set_token_failure(h->rk, "OAuth callback raised exception");
1647+
if (err_code != RD_KAFKA_RESP_ERR_NO_ERROR) {
1648+
PyErr_SetString(PyExc_ValueError, "Failed to set token failure");
1649+
goto err;
1650+
}
1651+
PyErr_Clear();
1652+
goto done;
15621653
err:
15631654
CallState_crash(cs);
15641655
rd_kafka_yield(h->rk);

tests/test_misc.py

+67-32
Original file line numberDiff line numberDiff line change
@@ -24,49 +24,41 @@ def test_version():
2424
assert confluent_kafka.version()[0] == confluent_kafka.__version__
2525

2626

27-
# global variable for error_cb call back function
28-
seen_error_cb = False
29-
30-
3127
def test_error_cb():
3228
""" Tests error_cb. """
29+
seen_error_cb = False
3330

3431
def error_cb(error_msg):
35-
global seen_error_cb
32+
nonlocal seen_error_cb
3633
seen_error_cb = True
3734
acceptable_error_codes = (confluent_kafka.KafkaError._TRANSPORT, confluent_kafka.KafkaError._ALL_BROKERS_DOWN)
3835
assert error_msg.code() in acceptable_error_codes
3936

4037
conf = {'bootstrap.servers': 'localhost:65531', # Purposely cause connection refused error
4138
'group.id': 'test',
42-
'socket.timeout.ms': '100',
4339
'session.timeout.ms': 1000, # Avoid close() blocking too long
4440
'error_cb': error_cb
4541
}
4642

4743
kc = confluent_kafka.Consumer(**conf)
4844
kc.subscribe(["test"])
4945
while not seen_error_cb:
50-
kc.poll(timeout=1)
46+
kc.poll(timeout=0.1)
5147

5248
kc.close()
5349

5450

55-
# global variable for stats_cb call back function
56-
seen_stats_cb = False
57-
58-
5951
def test_stats_cb():
6052
""" Tests stats_cb. """
53+
seen_stats_cb = False
6154

6255
def stats_cb(stats_json_str):
63-
global seen_stats_cb
56+
nonlocal seen_stats_cb
6457
seen_stats_cb = True
6558
stats_json = json.loads(stats_json_str)
6659
assert len(stats_json['name']) > 0
6760

6861
conf = {'group.id': 'test',
69-
'socket.timeout.ms': '100',
7062
'session.timeout.ms': 1000, # Avoid close() blocking too long
7163
'statistics.interval.ms': 200,
7264
'stats_cb': stats_cb
@@ -76,22 +68,20 @@ def stats_cb(stats_json_str):
7668

7769
kc.subscribe(["test"])
7870
while not seen_stats_cb:
79-
kc.poll(timeout=1)
71+
kc.poll(timeout=0.1)
8072
kc.close()
8173

8274

83-
seen_stats_cb_check_no_brokers = False
84-
85-
8675
def test_conf_none():
8776
""" Issue #133
8877
Test that None can be passed for NULL by setting bootstrap.servers
8978
to None. If None would be converted to a string then a broker would
9079
show up in statistics. Verify that it doesnt. """
80+
seen_stats_cb_check_no_brokers = False
9181

9282
def stats_cb_check_no_brokers(stats_json_str):
9383
""" Make sure no brokers are reported in stats """
94-
global seen_stats_cb_check_no_brokers
84+
nonlocal seen_stats_cb_check_no_brokers
9585
stats = json.loads(stats_json_str)
9686
assert len(stats['brokers']) == 0, "expected no brokers in stats: %s" % stats_json_str
9787
seen_stats_cb_check_no_brokers = True
@@ -101,9 +91,8 @@ def stats_cb_check_no_brokers(stats_json_str):
10191
'stats_cb': stats_cb_check_no_brokers}
10292

10393
p = confluent_kafka.Producer(conf)
104-
p.poll(timeout=1)
94+
p.poll(timeout=0.1)
10595

106-
global seen_stats_cb_check_no_brokers
10796
assert seen_stats_cb_check_no_brokers
10897

10998

@@ -130,23 +119,19 @@ def test_throttle_event_types():
130119
assert str(throttle_event) == "broker/0 throttled for 10000 ms"
131120

132121

133-
# global variable for oauth_cb call back function
134-
seen_oauth_cb = False
135-
136-
137122
def test_oauth_cb():
138123
""" Tests oauth_cb. """
124+
seen_oauth_cb = False
139125

140126
def oauth_cb(oauth_config):
141-
global seen_oauth_cb
127+
nonlocal seen_oauth_cb
142128
seen_oauth_cb = True
143129
assert oauth_config == 'oauth_cb'
144130
return 'token', time.time() + 300.0
145131

146132
conf = {'group.id': 'test',
147133
'security.protocol': 'sasl_plaintext',
148134
'sasl.mechanisms': 'OAUTHBEARER',
149-
'socket.timeout.ms': '100',
150135
'session.timeout.ms': 1000, # Avoid close() blocking too long
151136
'sasl.oauthbearer.config': 'oauth_cb',
152137
'oauth_cb': oauth_cb
@@ -155,7 +140,59 @@ def oauth_cb(oauth_config):
155140
kc = confluent_kafka.Consumer(**conf)
156141

157142
while not seen_oauth_cb:
158-
kc.poll(timeout=1)
143+
kc.poll(timeout=0.1)
144+
kc.close()
145+
146+
147+
def test_oauth_cb_principal_sasl_extensions():
148+
""" Tests oauth_cb. """
149+
seen_oauth_cb = False
150+
151+
def oauth_cb(oauth_config):
152+
nonlocal seen_oauth_cb
153+
seen_oauth_cb = True
154+
assert oauth_config == 'oauth_cb'
155+
return 'token', time.time() + 300.0, oauth_config, {"extone": "extoneval", "exttwo": "exttwoval"}
156+
157+
conf = {'group.id': 'test',
158+
'security.protocol': 'sasl_plaintext',
159+
'sasl.mechanisms': 'OAUTHBEARER',
160+
'session.timeout.ms': 100, # Avoid close() blocking too long
161+
'sasl.oauthbearer.config': 'oauth_cb',
162+
'oauth_cb': oauth_cb
163+
}
164+
165+
kc = confluent_kafka.Consumer(**conf)
166+
167+
while not seen_oauth_cb:
168+
kc.poll(timeout=0.1)
169+
kc.close()
170+
171+
172+
def test_oauth_cb_failure():
173+
""" Tests oauth_cb. """
174+
oauth_cb_count = 0
175+
176+
def oauth_cb(oauth_config):
177+
nonlocal oauth_cb_count
178+
oauth_cb_count += 1
179+
assert oauth_config == 'oauth_cb'
180+
if oauth_cb_count == 2:
181+
return 'token', time.time() + 100.0, oauth_config, {"extthree": "extthreeval"}
182+
raise Exception
183+
184+
conf = {'group.id': 'test',
185+
'security.protocol': 'sasl_plaintext',
186+
'sasl.mechanisms': 'OAUTHBEARER',
187+
'session.timeout.ms': 1000, # Avoid close() blocking too long
188+
'sasl.oauthbearer.config': 'oauth_cb',
189+
'oauth_cb': oauth_cb
190+
}
191+
192+
kc = confluent_kafka.Consumer(**conf)
193+
194+
while oauth_cb_count < 2:
195+
kc.poll(timeout=0.1)
159196
kc.close()
160197

161198

@@ -194,11 +231,9 @@ def test_unordered_dict(init_func):
194231
client.poll(0)
195232

196233

197-
# global variable for on_delivery call back function
198-
seen_delivery_cb = False
199-
200-
201234
def test_topic_config_update():
235+
seen_delivery_cb = False
236+
202237
# *NOTE* default.topic.config has been deprecated.
203238
# This example remains to ensure backward-compatibility until its removal.
204239
confs = [{"message.timeout.ms": 600000, "default.topic.config": {"message.timeout.ms": 1000}},
@@ -207,7 +242,7 @@ def test_topic_config_update():
207242

208243
def on_delivery(err, msg):
209244
# Since there is no broker, produced messages should time out.
210-
global seen_delivery_cb
245+
nonlocal seen_delivery_cb
211246
seen_delivery_cb = True
212247
assert err.code() == confluent_kafka.KafkaError._MSG_TIMED_OUT
213248

0 commit comments

Comments
 (0)