From: Sebastian Andrzej Siewior <sebastian@xxxxxxxxxxxxx> The "zstd@xxxxxxxxxxxxx" compression algorithm enables ZSTD-based compression as defined in RFC 8478 (since obsoleted by RFC 8878). Compression is delayed until the server sends SSH_MSG_USERAUTH_SUCCESS, the same point at which the "zlib@xxxxxxxxxxx" method starts. Signed-off-by: Sebastian Andrzej Siewior <sebastian@xxxxxxxxxxxxx> --- cipher.c | 30 +++++- configure.ac | 8 ++ kex.c | 5 + kex.h | 3 + myproposal.h | 2 +- packet.c | 272 +++++++++++++++++++++++++++++++++++++++++++++------ readconf.c | 8 +- servconf.c | 14 +-- ssh.c | 4 +- 9 files changed, 300 insertions(+), 46 deletions(-) diff --git a/cipher.c b/cipher.c index 02aea4089ff91..1634bb4019c86 100644 --- a/cipher.c +++ b/cipher.c @@ -48,6 +48,7 @@ #include "sshbuf.h" #include "ssherr.h" #include "digest.h" +#include "kex.h" #include "openbsd-compat/openssl-compat.h" @@ -142,12 +143,39 @@ cipher_alg_list(char sep, int auth_only) const char * compression_alg_list(int compression) { -#ifdef WITH_ZLIB - return compression ? 
"zlib@xxxxxxxxxxx,zlib,none" : - "none,zlib@xxxxxxxxxxx,zlib"; +#ifdef HAVE_LIBZSTD +#define COMP_ZSTD_WITH "zstd@xxxxxxxxxxxxx," +#define COMP_ZSTD_NONE ",zstd@xxxxxxxxxxxxx" #else - return "none"; +#define COMP_ZSTD_WITH "" +#define COMP_ZSTD_NONE "" #endif + +#ifdef WITH_ZLIB +#define COMP_ZLIB_C_WITH "zlib@xxxxxxxxxxx,zlib," +#define COMP_ZLIB_S_WITH "zlib@xxxxxxxxxxx," + +#define COMP_ZLIB_C_NONE ",zlib@xxxxxxxxxxx,zlib" +#else +#define COMP_ZLIB_C_WITH "" +#define COMP_ZLIB_S_WITH "" +#define COMP_ZLIB_C_NONE "" +#endif + /* Server lists omit plain "zlib": pre-auth compression is client-only. */ + switch (compression) { + case COMP_ZLIB: return COMP_ZLIB_C_WITH "none"; + case COMP_DELAYED: return COMP_ZLIB_S_WITH "none"; + case COMP_ZSTD: return COMP_ZSTD_WITH "none"; + case COMP_ALL_C: return COMP_ZSTD_WITH COMP_ZLIB_C_WITH "none"; + case COMP_ALL_S: return COMP_ZSTD_WITH COMP_ZLIB_S_WITH "none"; + default: + case COMP_NONE: return "none" COMP_ZSTD_NONE COMP_ZLIB_C_NONE; + } } +#undef COMP_ZSTD_WITH +#undef COMP_ZSTD_NONE +#undef COMP_ZLIB_C_WITH +#undef COMP_ZLIB_S_WITH +#undef COMP_ZLIB_C_NONE u_int diff --git a/configure.ac b/configure.ac index 22fee70f604a2..91ef386788be3 100644 --- a/configure.ac +++ b/configure.ac @@ -1498,6 +1498,14 @@ See http://www.gzip.org/zlib/ for details.]) LIBS="$saved_LIBS" fi +AC_ARG_WITH([libzstd], AS_HELP_STRING([--with-libzstd], [Build with libzstd.])) +AS_IF([test "x$with_libzstd" = "xyes"], + [ + PKG_CHECK_MODULES([LIBZSTD], [libzstd >= 1.4.0], [AC_DEFINE([HAVE_LIBZSTD], [1], [Use LIBZSTD])]) + LIBS="$LIBS ${LIBZSTD_LIBS}" + CFLAGS="$CFLAGS ${LIBZSTD_CFLAGS}" + ]) + dnl UnixWare 2.x AC_CHECK_FUNC([strcasecmp], [], [ AC_CHECK_LIB([resolv], [strcasecmp], [LIBS="$LIBS -lresolv"]) ] diff --git a/kex.c b/kex.c index 7731ca9004fc8..d71fd777a3123 100644 --- a/kex.c +++ b/kex.c @@ -826,6 +826,11 @@ choose_comp(struct sshcomp *comp, char *client, char *server) comp->type = COMP_ZLIB; } else #endif /* WITH_ZLIB */ +#ifdef HAVE_LIBZSTD + if (strcmp(name, "zstd@xxxxxxxxxxxxx") == 0) { + comp->type = COMP_ZSTD; + } else +#endif /* HAVE_LIBZSTD */ if (strcmp(name, "none") == 0) { comp->type = COMP_NONE; } else { diff --git a/kex.h
b/kex.h index c35329501871a..159cfc794bd67 100644 --- a/kex.h +++ b/kex.h @@ -68,6 +68,9 @@ /* pre-auth compression (COMP_ZLIB) is only supported in the client */ #define COMP_ZLIB 1 #define COMP_DELAYED 2 +#define COMP_ZSTD 3 /* zstd, started only after authentication */ +#define COMP_ALL_C 4 /* all built-in methods, client offer list */ +#define COMP_ALL_S 5 /* all built-in methods, server offer list */ #define CURVE25519_SIZE 32 diff --git a/myproposal.h b/myproposal.h index ee6e9f7415261..a015190b35d9f 100644 --- a/myproposal.h +++ b/myproposal.h @@ -88,7 +88,16 @@ "rsa-sha2-512," \ "rsa-sha2-256" -#define KEX_DEFAULT_COMP "none,zlib@xxxxxxxxxxx" +/* Advertise only the methods choose_comp() was built to accept. */ +#if defined(HAVE_LIBZSTD) && defined(WITH_ZLIB) +#define KEX_DEFAULT_COMP "none,zstd@xxxxxxxxxxxxx,zlib@xxxxxxxxxxx" +#elif defined(HAVE_LIBZSTD) +#define KEX_DEFAULT_COMP "none,zstd@xxxxxxxxxxxxx" +#elif defined(WITH_ZLIB) +#define KEX_DEFAULT_COMP "none,zlib@xxxxxxxxxxx" +#else +#define KEX_DEFAULT_COMP "none" +#endif #define KEX_DEFAULT_LANG "" #define KEX_CLIENT \ diff --git a/packet.c b/packet.c index 3f64d2d32854a..a39b8d7fbd963 100644 --- a/packet.c +++ b/packet.c @@ -79,6 +79,9 @@ #ifdef WITH_ZLIB #include <zlib.h> #endif +#ifdef HAVE_LIBZSTD +#include <zstd.h> +#endif #include "xmalloc.h" #include "compat.h" @@ -156,6 +159,14 @@ struct session_state { /* Incoming/outgoing compression dictionaries */ z_stream compression_in_stream; z_stream compression_out_stream; +#endif +#ifdef HAVE_LIBZSTD + ZSTD_DCtx *compression_zstd_in_stream; + ZSTD_CCtx *compression_zstd_out_stream; + u_int64_t compress_zstd_in_raw; + u_int64_t compress_zstd_in_comp; + u_int64_t compress_zstd_out_raw; + u_int64_t compress_zstd_out_comp; #endif int compression_in_started; int compression_out_started; @@ -604,11 +615,11 @@ ssh_packet_close_internal(struct ssh *ssh, int do_close) state->newkeys[mode] = NULL; ssh_clear_newkeys(ssh, mode); /* next keys */ } -#ifdef WITH_ZLIB /* compression state is in shared mem, so we can only release it once */ if (do_close && state->compression_buffer) { sshbuf_free(state->compression_buffer); - if (state->compression_out_started) { +#ifdef WITH_ZLIB + if (state->compression_out_started == COMP_ZLIB) { z_streamp stream = &state->compression_out_stream; debug("compress outgoing: " "raw data %llu, compressed %llu, factor %.2f", @@ -619,7 +630,7 @@ ssh_packet_close_internal(struct ssh *ssh,
int do_close) if (state->compression_out_failures == 0) deflateEnd(stream); } - if (state->compression_in_started) { + if (state->compression_in_started == COMP_ZLIB) { z_streamp stream = &state->compression_in_stream; debug("compress incoming: " "raw data %llu, compressed %llu, factor %.2f", @@ -630,8 +641,33 @@ ssh_packet_close_internal(struct ssh *ssh, int do_close) if (state->compression_in_failures == 0) inflateEnd(stream); } +#endif /* WITH_ZLIB */ +#ifdef HAVE_LIBZSTD + if (state->compression_out_started == COMP_ZSTD) { + debug("compress outgoing: " + "raw data %llu, compressed %llu, factor %.2f", + (unsigned long long)state->compress_zstd_out_raw, + (unsigned long long)state->compress_zstd_out_comp, + state->compress_zstd_out_raw == 0 ? 0.0 : + (double) state->compress_zstd_out_comp / + state->compress_zstd_out_raw); + } + if (state->compression_in_started == COMP_ZSTD) { + debug("compress incoming: " + "raw data %llu, compressed %llu, factor %.2f", + (unsigned long long)state->compress_zstd_in_raw, + (unsigned long long)state->compress_zstd_in_comp, + state->compress_zstd_in_raw == 0 ?
0.0 : + (double) state->compress_zstd_in_comp / + state->compress_zstd_in_raw); + } + /* Release the zstd contexts; ZSTD_freeCCtx()/ZSTD_freeDCtx() accept NULL. */ + ZSTD_freeCCtx(state->compression_zstd_out_stream); + ZSTD_freeDCtx(state->compression_zstd_in_stream); + state->compression_zstd_out_stream = NULL; + state->compression_zstd_in_stream = NULL; +#endif /* HAVE_LIBZSTD */ } -#endif /* WITH_ZLIB */ cipher_free(state->send_context); cipher_free(state->receive_context); state->send_context = state->receive_context = NULL; @@ -696,11 +727,11 @@ start_compression_out(struct ssh *ssh, int level) if (level < 1 || level > 9) return SSH_ERR_INVALID_ARGUMENT; debug("Enabling compression at level %d.", level); - if (ssh->state->compression_out_started == 1) + if (ssh->state->compression_out_started == COMP_ZLIB) deflateEnd(&ssh->state->compression_out_stream); switch (deflateInit(&ssh->state->compression_out_stream, level)) { case Z_OK: - ssh->state->compression_out_started = 1; + ssh->state->compression_out_started = COMP_ZLIB; break; case Z_MEM_ERROR: return SSH_ERR_ALLOC_FAIL; @@ -713,11 +744,11 @@ start_compression_out(struct ssh *ssh, int level) static int start_compression_in(struct ssh *ssh) { - if (ssh->state->compression_in_started == 1) + if (ssh->state->compression_in_started == COMP_ZLIB) inflateEnd(&ssh->state->compression_in_stream); switch (inflateInit(&ssh->state->compression_in_stream)) { case Z_OK: - ssh->state->compression_in_started = 1; + ssh->state->compression_in_started = COMP_ZLIB; break; case Z_MEM_ERROR: return SSH_ERR_ALLOC_FAIL; @@ -734,7 +765,7 @@ compress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out) u_char buf[4096]; int r, status; - if (ssh->state->compression_out_started != 1) + if (ssh->state->compression_out_started != COMP_ZLIB) return SSH_ERR_INTERNAL_ERROR; /* This case is not handled below.
*/ @@ -780,7 +811,7 @@ uncompress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out) u_char buf[4096]; int r, status; - if (ssh->state->compression_in_started != 1) + if (ssh->state->compression_in_started != COMP_ZLIB) return SSH_ERR_INTERNAL_ERROR; if ((ssh->state->compression_in_stream.next_in = @@ -848,6 +879,146 @@ uncompress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out) } #endif /* WITH_ZLIB */ +#ifdef HAVE_LIBZSTD +static int +start_compression_zstd_out(struct ssh *ssh) +{ + debug("Enabling ZSTD compression."); + if (ssh->state->compression_out_started == COMP_ZSTD) + ZSTD_CCtx_reset(ssh->state->compression_zstd_out_stream, ZSTD_reset_session_only); + if (!ssh->state->compression_zstd_out_stream) + ssh->state->compression_zstd_out_stream = ZSTD_createCCtx(); + if (!ssh->state->compression_zstd_out_stream) + return SSH_ERR_ALLOC_FAIL; + ssh->state->compression_out_started = COMP_ZSTD; + return 0; +} + +static int +start_compression_zstd_in(struct ssh *ssh) +{ + if (ssh->state->compression_in_started == COMP_ZSTD) + ZSTD_DCtx_reset(ssh->state->compression_zstd_in_stream, ZSTD_reset_session_only); + if (!ssh->state->compression_zstd_in_stream) + ssh->state->compression_zstd_in_stream = ZSTD_createDCtx(); + if (!ssh->state->compression_zstd_in_stream) + return SSH_ERR_ALLOC_FAIL; + + ssh->state->compression_in_started = COMP_ZSTD; + return 0; +} + +static int +compress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out) +{ + u_char buf[4096]; + ZSTD_inBuffer in_buff; + ZSTD_outBuffer out_buff; + int r; + size_t comp; + + if (ssh->state->compression_out_started != COMP_ZSTD) + return SSH_ERR_INTERNAL_ERROR; + + if (sshbuf_len(in) == 0) + return 0; + + in_buff.src = sshbuf_ptr(in); + if (in_buff.src == NULL) + return SSH_ERR_INTERNAL_ERROR; + in_buff.size = sshbuf_len(in); + in_buff.pos = 0; + + ssh->state->compress_zstd_out_raw += in_buff.size; + out_buff.dst = buf; + out_buff.size = sizeof(buf); + + /* + * Consume input and
immediately flush compressed data. It will loop + * multiple times if the output does not fit into the buffer + */ + do { + out_buff.pos = 0; + + comp = ZSTD_compressStream2(ssh->state->compression_zstd_out_stream, + &out_buff, &in_buff, ZSTD_e_flush); + if (ZSTD_isError(comp)) + return SSH_ERR_INTERNAL_ERROR; + /* Append compressed data to the output buffer. */ + r = sshbuf_put(out, buf, out_buff.pos); + if (r != 0) + return r; + ssh->state->compress_zstd_out_comp += out_buff.pos; + } while (comp > 0); + return 0; +} + +static int +uncompress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out) +{ + u_char buf[4096]; + ZSTD_inBuffer in_buff; + ZSTD_outBuffer out_buff; + int r; + size_t decomp; + + if (ssh->state->compression_in_started != COMP_ZSTD) + return SSH_ERR_INTERNAL_ERROR; + + in_buff.src = sshbuf_ptr(in); + if (in_buff.src == NULL) + return SSH_ERR_INTERNAL_ERROR; + in_buff.size = sshbuf_len(in); + in_buff.pos = 0; + ssh->state->compress_zstd_in_comp += in_buff.size; + for (;;) { + /* Set up fixed-size output buffer.
*/ + out_buff.dst = buf; + out_buff.size = sizeof(buf); + out_buff.pos = 0; + + decomp = ZSTD_decompressStream(ssh->state->compression_zstd_in_stream, + &out_buff, &in_buff); + if (ZSTD_isError(decomp)) + return SSH_ERR_INVALID_FORMAT; + + r = sshbuf_put(out, buf, out_buff.pos); + if (r != 0) + return r; + ssh->state->compress_zstd_in_raw += out_buff.pos; + /* Done once input is consumed and buf was not filled (a full buf may leave output pending). */ + if (in_buff.size == in_buff.pos && + out_buff.pos < sizeof(buf)) + return 0; + } +} +#else /* HAVE_LIBZSTD */ + +static int +start_compression_zstd_out(struct ssh *ssh) +{ + return SSH_ERR_INTERNAL_ERROR; +} + +static int +start_compression_zstd_in(struct ssh *ssh) +{ + return SSH_ERR_INTERNAL_ERROR; +} + +static int +compress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out) +{ + return SSH_ERR_INTERNAL_ERROR; +} + +static int +uncompress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out) +{ + return SSH_ERR_INTERNAL_ERROR; +} +#endif /* HAVE_LIBZSTD */ + void ssh_clear_newkeys(struct ssh *ssh, int mode) { @@ -924,18 +1092,29 @@ ssh_set_newkeys(struct ssh *ssh, int mode) explicit_bzero(enc->key, enc->key_len); explicit_bzero(mac->key, mac->key_len); */ if ((comp->type == COMP_ZLIB || - (comp->type == COMP_DELAYED && + ((comp->type == COMP_DELAYED || comp->type == COMP_ZSTD) && state->after_authentication)) && comp->enabled == 0) { if ((r = ssh_packet_init_compression(ssh)) < 0) return r; - if (mode == MODE_OUT) { - if ((r = start_compression_out(ssh, 6)) != 0) - return r; + if (comp->type == COMP_ZSTD) { + if (mode == MODE_OUT) { + if ((r = start_compression_zstd_out(ssh)) != 0) + return r; + } else { + if ((r = start_compression_zstd_in(ssh)) != 0) + return r; + } + comp->enabled = COMP_ZSTD; } else { - if ((r = start_compression_in(ssh)) != 0) - return r; + if (mode == MODE_OUT) { + if ((r = start_compression_out(ssh, 6)) != 0) + return r; + } else { + if ((r = start_compression_in(ssh)) != 0) + return r; + } + comp->enabled = COMP_ZLIB; } - comp->enabled = 1; } /* * The
2^(blocksize*2) limit is too expensive for 3DES, @@ -1022,6 +1201,7 @@ ssh_packet_enable_delayed_compress(struct ssh *ssh) struct session_state *state = ssh->state; struct sshcomp *comp = NULL; int r, mode; + int type = 0; /* * Remember that we are past the authentication step, so rekeying @@ -1033,17 +1213,31 @@ ssh_packet_enable_delayed_compress(struct ssh *ssh) if (state->newkeys[mode] == NULL) continue; comp = &state->newkeys[mode]->comp; - if (comp && !comp->enabled && comp->type == COMP_DELAYED) { + /* Evaluate each direction on its own; never restart an enabled stream. */ + type = comp->type; + if (!comp->enabled && + (type == COMP_DELAYED || type == COMP_ZSTD)) { if ((r = ssh_packet_init_compression(ssh)) != 0) return r; - if (mode == MODE_OUT) { - if ((r = start_compression_out(ssh, 6)) != 0) - return r; - } else { - if ((r = start_compression_in(ssh)) != 0) - return r; - } - comp->enabled = 1; + if (type == COMP_DELAYED) { + if (mode == MODE_OUT) { + if ((r = start_compression_out(ssh, 6)) != 0) + return r; + } else { + if ((r = start_compression_in(ssh)) != 0) + return r; + } + comp->enabled = COMP_ZLIB; + } else { + if (mode == MODE_OUT) { + if ((r = start_compression_zstd_out(ssh)) != 0) + return r; + } else { + if ((r = start_compression_zstd_in(ssh)) != 0) + return r; + } + comp->enabled = COMP_ZSTD; + } } } return 0; @@ -1104,9 +1300,15 @@ ssh_packet_send2_wrapped(struct ssh *ssh) if ((r = sshbuf_consume(state->outgoing_packet, 5)) != 0) goto out; sshbuf_reset(state->compression_buffer); - if ((r = compress_buffer(ssh, state->outgoing_packet, - state->compression_buffer)) != 0) - goto out; + if (comp->enabled == COMP_ZSTD) { + if ((r = compress_buffer_zstd(ssh, state->outgoing_packet, + state->compression_buffer)) != 0) + goto out; + } else { + if ((r = compress_buffer(ssh, state->outgoing_packet, + state->compression_buffer)) != 0) + goto out; + } sshbuf_reset(state->outgoing_packet); if ((r = sshbuf_put(state->outgoing_packet,
"\0\0\0\0\0", 5)) != 0 || @@ -1657,9 +1859,15 @@ ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p) sshbuf_len(state->incoming_packet))); if (comp && comp->enabled) { sshbuf_reset(state->compression_buffer); - if ((r = uncompress_buffer(ssh, state->incoming_packet, - state->compression_buffer)) != 0) - goto out; + if (comp->enabled == COMP_ZSTD) { + if ((r = uncompress_buffer_zstd(ssh, state->incoming_packet, + state->compression_buffer)) != 0) + goto out; + } else { + if ((r = uncompress_buffer(ssh, state->incoming_packet, + state->compression_buffer)) != 0) + goto out; + } sshbuf_reset(state->incoming_packet); if ((r = sshbuf_putb(state->incoming_packet, state->compression_buffer)) != 0) diff --git a/readconf.c b/readconf.c index cf79498848f6d..f05aab2316c8a 100644 --- a/readconf.c +++ b/readconf.c @@ -899,8 +899,14 @@ static const struct multistate multistate_pubkey_auth[] = { { NULL, -1 } }; static const struct multistate multistate_compression[] = { +#if defined(WITH_ZLIB) || defined(HAVE_LIBZSTD) + { "yes", COMP_ALL_C }, +#endif #ifdef WITH_ZLIB - { "yes", COMP_ZLIB }, + { "zlib", COMP_ZLIB }, +#endif +#ifdef HAVE_LIBZSTD + { "zstd", COMP_ZSTD }, #endif { "no", COMP_NONE }, { NULL, -1 } diff --git a/servconf.c b/servconf.c index 2e039da8b95e8..a82ef128c79f7 100644 --- a/servconf.c +++ b/servconf.c @@ -375,11 +375,11 @@ fill_default_server_options(ServerOptions *options) options->permit_user_env_allowlist = NULL; } if (options->compression == -1) -#ifdef WITH_ZLIB - options->compression = COMP_DELAYED; +#if defined(WITH_ZLIB) || defined(HAVE_LIBZSTD) + options->compression = COMP_ALL_S; #else options->compression = COMP_NONE; #endif if (options->rekey_limit == -1) options->rekey_limit = 0; @@ -1303,9 +1299,15 @@ static const struct multistate multistate_permitrootlogin[] = { { NULL, -1 } }; static const struct multistate multistate_compression[] = { +#if defined(WITH_ZLIB) || defined(HAVE_LIBZSTD) + { "yes", COMP_ALL_S }, +#endif #ifdef
WITH_ZLIB - { "yes", COMP_DELAYED }, { "delayed", COMP_DELAYED }, + { "zlib", COMP_DELAYED }, +#endif +#ifdef HAVE_LIBZSTD + { "zstd", COMP_ZSTD }, #endif { "no", COMP_NONE }, { NULL, -1 } diff --git a/ssh.c b/ssh.c index 918389bccba25..ae67808a36215 100644 --- a/ssh.c +++ b/ssh.c @@ -1011,8 +1011,8 @@ main(int ac, char **av) break; case 'C': -#ifdef WITH_ZLIB - options.compression = 1; +#if defined(HAVE_LIBZSTD) || defined(WITH_ZLIB) + options.compression = COMP_ALL_C; #else error("Compression not supported, disabling."); #endif -- 2.39.2 _______________________________________________ openssh-unix-dev mailing list openssh-unix-dev@xxxxxxxxxxx https://lists.mindrot.org/mailman/listinfo/openssh-unix-dev