Instead of keeping half of the state in RedSASL and half in RedSASLAuth, move everything into RedSASLAuth. This also reduces memory usage when SASL is used, as the authentication state is now freed as soon as the authentication step finishes. Signed-off-by: Frediano Ziglio <fziglio@xxxxxxxxxx> --- server/red-stream.c | 139 +++++++++++++++++++++++++++------------------------- 1 file changed, 73 insertions(+), 66 deletions(-) diff --git a/server/red-stream.c b/server/red-stream.c index 0e372743a..898253c2e 100644 --- a/server/red-stream.c +++ b/server/red-stream.c @@ -67,13 +67,6 @@ typedef struct RedSASL { unsigned int encodedOffset; SpiceBuffer inbuffer; - - char *mechlist; - char *mechname; - - /* temporary data during authentication */ - unsigned int len; - char *data; } RedSASL; #endif @@ -351,13 +344,8 @@ void red_stream_free(RedStream *s) #if HAVE_SASL if (s->priv->sasl.conn) { s->priv->sasl.runSSF = s->priv->sasl.wantSSF = 0; - s->priv->sasl.len = 0; s->priv->sasl.encodedLength = s->priv->sasl.encodedOffset = 0; s->priv->sasl.encoded = NULL; - g_free(s->priv->sasl.mechlist); - g_free(s->priv->sasl.mechname); - s->priv->sasl.mechlist = NULL; - g_free(s->priv->sasl.data); sasl_dispose(&s->priv->sasl.conn); s->priv->sasl.conn = NULL; } @@ -735,17 +723,37 @@ static int auth_sasl_check_ssf(RedSASL *sasl, int *runSSF) typedef struct RedSASLAuth { RedStream *stream; + // copy (g_strdup) of the mechanism list sent to the client, freed by red_sasl_async_deinit + char *mechlist; + // mechanism name received from the client + char *mechname; + uint32_t len; + char *data; + // callback called by red_sasl_async_result with the final result (success or error) RedSaslResult result_cb; void *result_opaque; + // saved async read error callback, restored when authentication ends; + // red_sasl_error chains to it because it must be called with a different opaque pointer AsyncReadError saved_error; } RedSASLAuth; +static void red_sasl_async_deinit(RedSASLAuth *opaque) +{ + g_free(opaque->data); + opaque->data = NULL; + g_free(opaque->mechname); + opaque->mechname = NULL; + g_free(opaque->mechlist); + opaque->mechlist = NULL; + opaque->stream->priv->async_read.error = opaque->saved_error; +}
+ // handle SASL termination, either success or error // NOTE: After this function is called usually there should be a // return or the function should exit static void red_sasl_async_result(RedSASLAuth *auth, RedSaslError err) { - auth->stream->priv->async_read.error = auth->saved_error; + red_sasl_async_deinit(auth); auth->result_cb(auth->result_opaque, err); g_free(auth); } @@ -753,7 +761,7 @@ static void red_sasl_async_result(RedSASLAuth *auth, RedSaslError err) static void red_sasl_error(void *opaque, int err) { RedSASLAuth *auth = opaque; - auth->stream->priv->async_read.error = auth->saved_error; + red_sasl_async_deinit(auth); if (auth->saved_error) { auth->saved_error(auth->result_opaque, err); } @@ -796,32 +804,33 @@ static void red_sasl_handle_auth_steplen(void *opaque); static void red_sasl_handle_auth_step(void *opaque) { - RedStream *stream = ((RedSASLAuth *)opaque)->stream; + RedSASLAuth *auth = opaque; + RedStream *stream = auth->stream; const char *serverout; unsigned int serveroutlen; int err; char *clientdata = NULL; RedSASL *sasl = &stream->priv->sasl; - uint32_t datalen = sasl->len; + uint32_t datalen = auth->len; /* NB, distinction of NULL vs "" is *critical* in SASL */ if (datalen) { - clientdata = sasl->data; + clientdata = auth->data; clientdata[datalen - 1] = '\0'; /* Wire includes '\0', but make sure */ datalen--; /* Don't count NULL byte when passing to _start() */ } - if (sasl->mechname != NULL) { + if (auth->mechname != NULL) { spice_debug("Start SASL auth with mechanism %s.
Data %p (%d bytes)", - sasl->mechname, clientdata, datalen); + auth->mechname, clientdata, datalen); err = sasl_server_start(sasl->conn, - sasl->mechname, + auth->mechname, clientdata, datalen, &serverout, &serveroutlen); - g_free(sasl->mechname); - sasl->mechname = NULL; + g_free(auth->mechname); + auth->mechname = NULL; } else { spice_debug("Step using SASL Data %p (%d bytes)", clientdata, datalen); err = sasl_server_step(sasl->conn, @@ -834,13 +843,13 @@ static void red_sasl_handle_auth_step(void *opaque) err != SASL_CONTINUE) { spice_warning("sasl step failed %d (%s)", err, sasl_errdetail(sasl->conn)); - return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC); + return red_sasl_async_result(auth, RED_SASL_ERROR_GENERIC); } if (serveroutlen > SASL_DATA_MAX_LEN) { spice_warning("sasl step reply data too long %d", serveroutlen); - return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA); + return red_sasl_async_result(auth, RED_SASL_ERROR_INVALID_DATA); } spice_debug("SASL return data %d bytes, %p", serveroutlen, serverout); @@ -859,8 +868,8 @@ static void red_sasl_handle_auth_step(void *opaque) if (err == SASL_CONTINUE) { spice_debug("%s", "Authentication must continue"); /* Wait for step length */ - red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t), - red_sasl_handle_auth_steplen, opaque); + red_stream_async_read(stream, (uint8_t *)&auth->len, sizeof(uint32_t), + red_sasl_handle_auth_steplen, auth); return; } else { int ssf; @@ -879,7 +888,7 @@ static void red_sasl_handle_auth_step(void *opaque) sasl->runSSF = ssf; red_stream_disable_writev(stream); /* make sure writev isn't called directly anymore */ - return red_sasl_async_result(opaque, RED_SASL_ERROR_OK); + return red_sasl_async_result(auth, RED_SASL_ERROR_OK); } authreject: @@ -887,73 +896,72 @@ authreject: red_stream_write_u32(stream, sizeof("Authentication failed")); red_stream_write_all(stream, "Authentication failed", sizeof("Authentication failed")); -
red_sasl_async_result(opaque, RED_SASL_ERROR_AUTH_FAILED); + red_sasl_async_result(auth, RED_SASL_ERROR_AUTH_FAILED); } static void red_sasl_handle_auth_steplen(void *opaque) { - RedStream *stream = ((RedSASLAuth *)opaque)->stream; - RedSASL *sasl = &stream->priv->sasl; + RedSASLAuth *auth = opaque; - sasl->len = GUINT32_FROM_LE(sasl->len); - spice_debug("Got steplen %d", sasl->len); - if (sasl->len > SASL_DATA_MAX_LEN) { - spice_warning("Too much SASL data %d", sasl->len); - return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA); + auth->len = GUINT32_FROM_LE(auth->len); + uint32_t len = auth->len; + spice_debug("Got steplen %d", len); + if (len > SASL_DATA_MAX_LEN) { + spice_warning("Too much SASL data %d", len); + return red_sasl_async_result(auth, RED_SASL_ERROR_INVALID_DATA); } - if (sasl->len == 0) { - return red_sasl_handle_auth_step(opaque); + if (len == 0) { + return red_sasl_handle_auth_step(auth); } - sasl->data = g_realloc(sasl->data, sasl->len); - red_stream_async_read(stream, (uint8_t *)sasl->data, sasl->len, - red_sasl_handle_auth_step, opaque); + auth->data = g_realloc(auth->data, len); + red_stream_async_read(auth->stream, (uint8_t *)auth->data, len, + red_sasl_handle_auth_step, auth); } static void red_sasl_handle_auth_mechname(void *opaque) { - RedStream *stream = ((RedSASLAuth *)opaque)->stream; - RedSASL *sasl = &stream->priv->sasl; + RedSASLAuth *auth = opaque; - sasl->mechname[sasl->len] = '\0'; + auth->mechname[auth->len] = '\0'; spice_debug("Got client mechname '%s' check against '%s'", - sasl->mechname, sasl->mechlist); + auth->mechname, auth->mechlist); char quoted_mechname[SASL_MAX_MECHNAME_LEN + 4]; - sprintf(quoted_mechname, ",%s,", sasl->mechname); + sprintf(quoted_mechname, ",%s,", auth->mechname); - if (strchr(sasl->mechname, ',') || strstr(sasl->mechlist, quoted_mechname) == NULL) { - return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA); + if (strchr(auth->mechname, ',') || strstr(auth->mechlist,
quoted_mechname) == NULL) { + return red_sasl_async_result(auth, RED_SASL_ERROR_INVALID_DATA); } - spice_debug("Validated mechname '%s'", sasl->mechname); + spice_debug("Validated mechname '%s'", auth->mechname); - red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t), - red_sasl_handle_auth_steplen, opaque); + red_stream_async_read(auth->stream, (uint8_t *)&auth->len, sizeof(uint32_t), + red_sasl_handle_auth_steplen, auth); } static void red_sasl_handle_auth_mechlen(void *opaque) { - RedStream *stream = ((RedSASLAuth *)opaque)->stream; - RedSASL *sasl = &stream->priv->sasl; + RedSASLAuth *auth = opaque; - sasl->len = GUINT32_FROM_LE(sasl->len); - if (sasl->len < 1 || sasl->len > SASL_MAX_MECHNAME_LEN) { - spice_warning("Got bad client mechname len %d", sasl->len); - return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA); + auth->len = GUINT32_FROM_LE(auth->len); + uint32_t len = auth->len; + if (len < 1 || len > SASL_MAX_MECHNAME_LEN) { + spice_warning("Got bad client mechname len %d", len); + return red_sasl_async_result(auth, RED_SASL_ERROR_INVALID_DATA); } - sasl->mechname = g_malloc(sasl->len + 1); + auth->mechname = g_malloc(len + 1); spice_debug("Wait for client mechname"); - red_stream_async_read(stream, (uint8_t *)sasl->mechname, sasl->len, - red_sasl_handle_auth_mechname, opaque); + red_stream_async_read(auth->stream, (uint8_t *)auth->mechname, len, + red_sasl_handle_auth_mechname, auth); } -bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *opaque) +bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *result_opaque) { const char *mechlist = NULL; sasl_security_properties_t secprops; @@ -1047,11 +1055,9 @@ bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *opaqu spice_debug("Available mechanisms for client: '%s'", mechlist); - sasl->mechlist = g_strdup(mechlist); - mechlistlen = strlen(mechlist); if (!red_stream_write_u32(stream, mechlistlen) ||
!red_stream_write_all(stream, sasl->mechlist, mechlistlen)) { + || !red_stream_write_all(stream, mechlist, mechlistlen)) { spice_warning("SASL mechanisms write error"); goto error; } @@ -1059,12 +1065,13 @@ bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *opaqu auth = g_new0(RedSASLAuth, 1); auth->stream = stream; auth->result_cb = result_cb; - auth->result_opaque = opaque; + auth->result_opaque = result_opaque; auth->saved_error = stream->priv->async_read.error; - stream->priv->async_read.error = red_sasl_error; + auth->mechlist = g_strdup(mechlist); spice_debug("Wait for client mechname length"); - red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t), + red_stream_set_async_error_handler(stream, red_sasl_error); + red_stream_async_read(stream, (uint8_t *)&auth->len, sizeof(uint32_t), red_sasl_handle_auth_mechlen, auth); return true; -- 2.14.3 _______________________________________________ Spice-devel mailing list Spice-devel@xxxxxxxxxxxxxxxxxxxxx https://lists.freedesktop.org/mailman/listinfo/spice-devel