From: Sebastian Andrzej Siewior <sebastian@xxxxxxxxxxxxx>

The "zstd@openssh.com" compression method enables zstd-based compression
(the Zstandard format is specified in RFC 8478). Like "zlib@openssh.com",
compression is delayed until the server has sent SSH_MSG_USERAUTH_SUCCESS;
only the compression algorithm differs.

Support is enabled with --with-libzstd (libzstd >= 1.4.0, the first
release with a stable ZSTD_compressStream2()/ZSTD_c_* API). The
Compression option of ssh_config and sshd_config additionally accepts
"zlib" and "zstd"; "yes" offers all compiled-in methods.

Signed-off-by: Sebastian Andrzej Siewior <sebastian@xxxxxxxxxxxxx>
---
 Makefile.in  |   4 +-
 cipher.c     |  30 +++++-
 configure.ac |   9 ++
 kex.c        |   5 +
 kex.h        |   3 +
 myproposal.h |   2 +-
 packet.c     | 305 ++++++++++++++++++++++++++++++++++++++++++++++++------
 readconf.c   |   8 +-
 servconf.c   |  12 ++-
 ssh.c        |   4 +-
 10 files changed, 331 insertions(+), 51 deletions(-)

diff --git a/Makefile.in b/Makefile.in
index 895beb9d0aae7..55836d17439d1 100644
--- a/Makefile.in
+++ b/Makefile.in
@@ -46,7 +46,7 @@ CFLAGS=@CFLAGS@
 CFLAGS_NOPIE=@CFLAGS_NOPIE@
-CPPFLAGS=-I. -I$(srcdir) @CPPFLAGS@ $(PATHS) @DEFS@
+CPPFLAGS=-I. -I$(srcdir) @CPPFLAGS@ @LIBZSTD_CFLAGS@ $(PATHS) @DEFS@
 PICFLAG=@PICFLAG@
-LIBS=@LIBS@
+LIBS=@LIBS@ @LIBZSTD_LIBS@
 K5LIBS=@K5LIBS@
 GSSLIBS=@GSSLIBS@
 SSHLIBS=@SSHLIBS@
diff --git a/cipher.c b/cipher.c
index 820bc6ace3e8c..176dc664412e1 100644
--- a/cipher.c
+++ b/cipher.c
@@ -48,6 +48,7 @@
 #include "sshbuf.h"
 #include "ssherr.h"
 #include "digest.h"
+#include "kex.h"
 
 #include "openbsd-compat/openssl-compat.h"
 
@@ -146,12 +147,33 @@ cipher_alg_list(char sep, int auth_only)
 const char *
 compression_alg_list(int compression)
 {
-#ifdef WITH_ZLIB
-	return compression ? "zlib@openssh.com,zlib,none" :
-	    "none,zlib@openssh.com,zlib";
+#ifdef HAVE_LIBZSTD
+#define COMP_ZSTD_WITH		"zstd@openssh.com,"
+#define COMP_ZSTD_NONE		",zstd@openssh.com"
 #else
-	return "none";
+#define COMP_ZSTD_WITH		""
+#define COMP_ZSTD_NONE		""
 #endif
+
+#ifdef WITH_ZLIB
+#define COMP_ZLIB_C_WITH	"zlib@openssh.com,zlib,"
+#define COMP_ZLIB_S_WITH	"zlib@openssh.com,"
+
+#define COMP_ZLIB_C_NONE	",zlib@openssh.com,zlib"
+#else
+#define COMP_ZLIB_C_WITH	""
+#define COMP_ZLIB_S_WITH	""
+#define COMP_ZLIB_C_NONE	""
+#endif
+	switch (compression) {
+	case COMP_ZLIB:		return COMP_ZLIB_C_WITH "none";
+	case COMP_DELAYED:	return COMP_ZLIB_S_WITH "none";
+	case COMP_ZSTD:		return COMP_ZSTD_WITH "none";
+	case COMP_ALL_C:	return COMP_ZSTD_WITH COMP_ZLIB_C_WITH "none";
+	case COMP_ALL_S:	return COMP_ZSTD_WITH COMP_ZLIB_S_WITH "none";
+	default:
+	case 0:			return "none" COMP_ZSTD_NONE COMP_ZLIB_C_NONE;
+	}
 }
 
 u_int
diff --git a/configure.ac b/configure.ac
index 0fefa421b8365..99bf500fc0cad 100644
--- a/configure.ac
+++ b/configure.ac
@@ -1394,6 +1394,15 @@ See http://www.gzip.org/zlib/ for details.])
 	)
 fi
 
+dnl libzstd >= 1.4.0: ZSTD_compressStream2() and ZSTD_c_* are stable API
+AC_ARG_WITH([libzstd], AS_HELP_STRING([--with-libzstd], [Build with libzstd.]))
+AS_IF([test "x$with_libzstd" = "xyes"],
+	[
+		PKG_CHECK_MODULES([LIBZSTD], [libzstd >= 1.4.0], [AC_DEFINE([HAVE_LIBZSTD], [1], [Use LIBZSTD])])
+	])
+AC_SUBST([LIBZSTD_CFLAGS])
+AC_SUBST([LIBZSTD_LIBS])
+
 dnl UnixWare 2.x
 AC_CHECK_FUNC([strcasecmp],
 	[], [ AC_CHECK_LIB([resolv], [strcasecmp], [LIBS="$LIBS -lresolv"]) ]
diff --git a/kex.c b/kex.c
index ce85f043958ed..95560756ade61 100644
--- a/kex.c
+++ b/kex.c
@@ -805,6 +805,11 @@ choose_comp(struct sshcomp *comp, char *client, char *server)
 		comp->type = COMP_ZLIB;
 	} else
 #endif	/* WITH_ZLIB */
+#ifdef HAVE_LIBZSTD
+	if (strcmp(name, "zstd@openssh.com") == 0) {
+		comp->type = COMP_ZSTD;
+	} else
+#endif	/* HAVE_LIBZSTD */
 	if (strcmp(name, "none") == 0) {
 		comp->type = COMP_NONE;
 	} else {
diff --git a/kex.h b/kex.h
index a5ae6ac050a78..5efe146d796c6 100644
--- a/kex.h
+++ b/kex.h
@@ -68,6 +68,9 @@
 /* pre-auth compression (COMP_ZLIB) is only supported in the client */
 #define COMP_ZLIB	1
 #define COMP_DELAYED	2
+#define COMP_ZSTD	3	/* zstd@openssh.com, started after auth */
+#define COMP_ALL_C	4	/* config only: all compiled-in methods (client) */
+#define COMP_ALL_S	5	/* config only: all compiled-in methods (server) */
 
 #define CURVE25519_SIZE 32
 
diff --git a/myproposal.h b/myproposal.h
index 5312e60581ced..4840bef213584 100644
--- a/myproposal.h
+++ b/myproposal.h
@@ -89,7 +89,7 @@
 	"rsa-sha2-512," \
 	"rsa-sha2-256"
 
-#define	KEX_DEFAULT_COMP	"none,zlib@openssh.com"
+#define	KEX_DEFAULT_COMP	"none,zstd@openssh.com,zlib@openssh.com"
 #define	KEX_DEFAULT_LANG	""
 
 #define KEX_CLIENT \
diff --git a/packet.c b/packet.c
index 6d3e9172db6cd..cdae1401196eb 100644
--- a/packet.c
+++ b/packet.c
@@ -79,6 +79,9 @@
 #ifdef WITH_ZLIB
 #include <zlib.h>
 #endif
+#ifdef HAVE_LIBZSTD
+#include <zstd.h>
+#endif
 
 #include "xmalloc.h"
 #include "compat.h"
@@ -156,6 +159,14 @@ struct session_state {
 	/* Incoming/outgoing compression dictionaries */
 	z_stream compression_in_stream;
 	z_stream compression_out_stream;
+#endif
+#ifdef HAVE_LIBZSTD
+	ZSTD_DCtx	*compression_zstd_in_stream;
+	ZSTD_CCtx	*compression_zstd_out_stream;
+	u_int64_t	compress_zstd_in_raw;
+	u_int64_t	compress_zstd_in_comp;
+	u_int64_t	compress_zstd_out_raw;
+	u_int64_t	compress_zstd_out_comp;
 #endif
 	int	compression_in_started;
 	int	compression_out_started;
@@ -613,11 +624,11 @@ ssh_packet_close_internal(struct ssh *ssh, int do_close)
 		state->newkeys[mode] = NULL;
 		ssh_clear_newkeys(ssh, mode);		/* next keys */
 	}
-#ifdef WITH_ZLIB
 	/* compression state is in shared mem, so we can only release it once */
 	if (do_close && state->compression_buffer) {
 		sshbuf_free(state->compression_buffer);
-		if (state->compression_out_started) {
+#ifdef WITH_ZLIB
+		if (state->compression_out_started == COMP_ZLIB) {
 			z_streamp stream = &state->compression_out_stream;
 			debug("compress outgoing: "
 			    "raw data %llu, compressed %llu, factor %.2f",
@@ -628,7 +639,7 @@ ssh_packet_close_internal(struct ssh *ssh, int do_close)
 			if (state->compression_out_failures == 0)
 				deflateEnd(stream);
 		}
-		if (state->compression_in_started) {
+		if (state->compression_in_started == COMP_ZLIB) {
 			z_streamp stream = &state->compression_in_stream;
 			debug("compress incoming: "
 			    "raw data %llu, compressed %llu, factor %.2f",
@@ -639,8 +650,33 @@ ssh_packet_close_internal(struct ssh *ssh, int do_close)
 			if (state->compression_in_failures == 0)
 				inflateEnd(stream);
 		}
+#endif	/* WITH_ZLIB */
+#ifdef HAVE_LIBZSTD
+		if (state->compression_out_started == COMP_ZSTD) {
+			debug("compress outgoing: "
+			    "raw data %llu, compressed %llu, factor %.2f",
+			    (unsigned long long)state->compress_zstd_out_raw,
+			    (unsigned long long)state->compress_zstd_out_comp,
+			    state->compress_zstd_out_raw == 0 ? 0.0 :
+			    (double) state->compress_zstd_out_comp /
+			    state->compress_zstd_out_raw);
+		}
+		if (state->compression_in_started == COMP_ZSTD) {
+			debug("compress incoming: "
+			    "raw data %llu, compressed %llu, factor %.2f",
+			    (unsigned long long)state->compress_zstd_in_raw,
+			    (unsigned long long)state->compress_zstd_in_comp,
+			    state->compress_zstd_in_raw == 0 ? 0.0 :
+			    (double) state->compress_zstd_in_comp /
+			    state->compress_zstd_in_raw);
+		}
+		/* Release the zstd contexts; ZSTD_free*Ctx(NULL) is a no-op. */
+		ZSTD_freeCCtx(state->compression_zstd_out_stream);
+		state->compression_zstd_out_stream = NULL;
+		ZSTD_freeDCtx(state->compression_zstd_in_stream);
+		state->compression_zstd_in_stream = NULL;
+#endif	/* HAVE_LIBZSTD */
 	}
-#endif	/* WITH_ZLIB */
 	cipher_free(state->send_context);
 	cipher_free(state->receive_context);
 	state->send_context = state->receive_context = NULL;
@@ -703,11 +739,11 @@ start_compression_out(struct ssh *ssh, int level)
 	if (level < 1 || level > 9)
 		return SSH_ERR_INVALID_ARGUMENT;
 	debug("Enabling compression at level %d.", level);
-	if (ssh->state->compression_out_started == 1)
+	if (ssh->state->compression_out_started == COMP_ZLIB)
 		deflateEnd(&ssh->state->compression_out_stream);
 	switch (deflateInit(&ssh->state->compression_out_stream, level)) {
 	case Z_OK:
-		ssh->state->compression_out_started = 1;
+		ssh->state->compression_out_started = COMP_ZLIB;
 		break;
 	case Z_MEM_ERROR:
 		return SSH_ERR_ALLOC_FAIL;
@@ -720,11 +756,11 @@ start_compression_out(struct ssh *ssh, int level)
 static int
 start_compression_in(struct ssh *ssh)
 {
-	if (ssh->state->compression_in_started == 1)
+	if (ssh->state->compression_in_started == COMP_ZLIB)
 		inflateEnd(&ssh->state->compression_in_stream);
 	switch (inflateInit(&ssh->state->compression_in_stream)) {
 	case Z_OK:
-		ssh->state->compression_in_started = 1;
+		ssh->state->compression_in_started = COMP_ZLIB;
 		break;
 	case Z_MEM_ERROR:
 		return SSH_ERR_ALLOC_FAIL;
@@ -741,7 +777,7 @@ compress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
 	u_char buf[4096];
 	int r, status;
 
-	if (ssh->state->compression_out_started != 1)
+	if (ssh->state->compression_out_started != COMP_ZLIB)
 		return SSH_ERR_INTERNAL_ERROR;
 
 	/* This case is not handled below. */
@@ -787,7 +823,7 @@ uncompress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
 	u_char buf[4096];
 	int r, status;
 
-	if (ssh->state->compression_in_started != 1)
+	if (ssh->state->compression_in_started != COMP_ZLIB)
 		return SSH_ERR_INTERNAL_ERROR;
 
 	if ((ssh->state->compression_in_stream.next_in =
@@ -855,6 +891,201 @@ uncompress_buffer(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
 }
 #endif	/* WITH_ZLIB */
 
+#ifdef HAVE_LIBZSTD
+/*
+ * zstd level for outgoing data.  Every packet is flushed on its own, so
+ * a fast level keeps latency and bulk throughput reasonable; high levels
+ * cost far more CPU and use a larger window that the receiving side has
+ * to allocate.
+ */
+#define SSH_ZSTD_CLEVEL	3
+
+static int
+start_compression_zstd_out(struct ssh *ssh)
+{
+	struct session_state *state = ssh->state;
+
+	debug("Enabling ZSTD compression.");
+	if (state->compression_zstd_out_stream != NULL) {
+		/*
+		 * Reuse the context but always start a new session, even if
+		 * zlib was used in between: the receiving side resets its
+		 * decompressor at the same point.
+		 */
+		ZSTD_CCtx_reset(state->compression_zstd_out_stream,
+		    ZSTD_reset_session_only);
+	} else {
+		if ((state->compression_zstd_out_stream =
+		    ZSTD_createCCtx()) == NULL)
+			return SSH_ERR_ALLOC_FAIL;
+		if (ZSTD_isError(ZSTD_CCtx_setParameter(
+		    state->compression_zstd_out_stream,
+		    ZSTD_c_compressionLevel, SSH_ZSTD_CLEVEL))) {
+			ZSTD_freeCCtx(state->compression_zstd_out_stream);
+			state->compression_zstd_out_stream = NULL;
+			return SSH_ERR_INTERNAL_ERROR;
+		}
+	}
+	state->compression_out_started = COMP_ZSTD;
+	return 0;
+}
+
+static int
+start_compression_zstd_in(struct ssh *ssh)
+{
+	struct session_state *state = ssh->state;
+
+	if (state->compression_zstd_in_stream != NULL) {
+		/* Always start a new session; see start_compression_zstd_out. */
+		ZSTD_DCtx_reset(state->compression_zstd_in_stream,
+		    ZSTD_reset_session_only);
+	} else if ((state->compression_zstd_in_stream =
+	    ZSTD_createDCtx()) == NULL) {
+		return SSH_ERR_ALLOC_FAIL;
+	}
+	state->compression_in_started = COMP_ZSTD;
+	return 0;
+}
+
+static int
+compress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
+{
+	struct session_state *state = ssh->state;
+	u_char buf[4096];
+	ZSTD_inBuffer in_buff;
+	ZSTD_outBuffer out_buff;
+	size_t remaining;
+	int r;
+
+	if (state->compression_out_started != COMP_ZSTD)
+		return SSH_ERR_INTERNAL_ERROR;
+
+	/* This case is not handled below. */
+	if (sshbuf_len(in) == 0)
+		return 0;
+
+	/* Input is the contents of the input buffer. */
+	if ((in_buff.src = sshbuf_ptr(in)) == NULL)
+		return SSH_ERR_INTERNAL_ERROR;
+	in_buff.size = sshbuf_len(in);
+	in_buff.pos = 0;
+	state->compress_zstd_out_raw += in_buff.size;
+
+	/*
+	 * ZSTD_e_flush makes everything passed in so far decodable by the
+	 * peer.  Loop until zstd reports nothing left to flush; the size_t
+	 * return value is either an error code or the bytes still pending.
+	 */
+	do {
+		out_buff.dst = buf;
+		out_buff.size = sizeof(buf);
+		out_buff.pos = 0;
+		remaining = ZSTD_compressStream2(
+		    state->compression_zstd_out_stream,
+		    &out_buff, &in_buff, ZSTD_e_flush);
+		if (ZSTD_isError(remaining))
+			return SSH_ERR_INTERNAL_ERROR;
+		if ((r = sshbuf_put(out, buf, out_buff.pos)) != 0)
+			return r;
+		state->compress_zstd_out_comp += out_buff.pos;
+	} while (remaining > 0);
+	return 0;
+}
+
+static int
+uncompress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
+{
+	struct session_state *state = ssh->state;
+	u_char buf[4096];
+	ZSTD_inBuffer in_buff;
+	ZSTD_outBuffer out_buff;
+	size_t ret;
+	int r;
+
+	if (state->compression_in_started != COMP_ZSTD)
+		return SSH_ERR_INTERNAL_ERROR;
+
+	if ((in_buff.src = sshbuf_ptr(in)) == NULL)
+		return SSH_ERR_INTERNAL_ERROR;
+	in_buff.size = sshbuf_len(in);
+	in_buff.pos = 0;
+	state->compress_zstd_in_comp += in_buff.size;
+
+	/*
+	 * Decode until all input is consumed and the output buffer was not
+	 * filled completely, i.e. zstd has no more data buffered for us.
+	 */
+	for (;;) {
+		out_buff.dst = buf;
+		out_buff.size = sizeof(buf);
+		out_buff.pos = 0;
+		ret = ZSTD_decompressStream(state->compression_zstd_in_stream,
+		    &out_buff, &in_buff);
+		if (ZSTD_isError(ret))
+			return SSH_ERR_INVALID_FORMAT;
+		if ((r = sshbuf_put(out, buf, out_buff.pos)) != 0)
+			return r;
+		state->compress_zstd_in_raw += out_buff.pos;
+		if (in_buff.pos == in_buff.size && out_buff.pos < sizeof(buf))
+			return 0;
+	}
+}
+#else	/* HAVE_LIBZSTD */
+
+static int
+start_compression_zstd_out(struct ssh *ssh)
+{
+	return SSH_ERR_INTERNAL_ERROR;
+}
+
+static int
+start_compression_zstd_in(struct ssh *ssh)
+{
+	return SSH_ERR_INTERNAL_ERROR;
+}
+
+static int
+compress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
+{
+	return SSH_ERR_INTERNAL_ERROR;
+}
+
+static int
+uncompress_buffer_zstd(struct ssh *ssh, struct sshbuf *in, struct sshbuf *out)
+{
+	return SSH_ERR_INTERNAL_ERROR;
+}
+#endif	/* HAVE_LIBZSTD */
+
+/*
+ * Start compression (MODE_OUT) or decompression (MODE_IN) for the method
+ * negotiated in comp and record which implementation is active in
+ * comp->enabled (COMP_ZSTD or COMP_ZLIB).  Shared by ssh_set_newkeys()
+ * and ssh_packet_enable_delayed_compress().
+ */
+static int
+start_compression_mode(struct ssh *ssh, struct sshcomp *comp, int mode)
+{
+	int r;
+
+	if ((r = ssh_packet_init_compression(ssh)) != 0)
+		return r;
+	if (comp->type == COMP_ZSTD) {
+		r = mode == MODE_OUT ? start_compression_zstd_out(ssh) :
+		    start_compression_zstd_in(ssh);
+		if (r != 0)
+			return r;
+		comp->enabled = COMP_ZSTD;
+	} else {
+		r = mode == MODE_OUT ? start_compression_out(ssh, 6) :
+		    start_compression_in(ssh);
+		if (r != 0)
+			return r;
+		comp->enabled = COMP_ZLIB;
+	}
+	return 0;
+}
+
 void
 ssh_clear_newkeys(struct ssh *ssh, int mode)
 {
@@ -931,18 +1162,11 @@ ssh_set_newkeys(struct ssh *ssh, int mode)
 	   explicit_bzero(enc->key, enc->key_len);
 	   explicit_bzero(mac->key, mac->key_len); */
 	if ((comp->type == COMP_ZLIB ||
-	    (comp->type == COMP_DELAYED &&
+	    ((comp->type == COMP_DELAYED ||
+	    comp->type == COMP_ZSTD) &&
 	     state->after_authentication)) && comp->enabled == 0) {
-		if ((r = ssh_packet_init_compression(ssh)) < 0)
-			return r;
-		if (mode == MODE_OUT) {
-			if ((r = start_compression_out(ssh, 6)) != 0)
-				return r;
-		} else {
-			if ((r = start_compression_in(ssh)) != 0)
-				return r;
-		}
-		comp->enabled = 1;
+		if ((r = start_compression_mode(ssh, comp, mode)) != 0)
+			return r;
 	}
 	/*
 	 * The 2^(blocksize*2) limit is too expensive for 3DES,
@@ -1031,17 +1255,10 @@ ssh_packet_enable_delayed_compress(struct ssh *ssh)
 		if (state->newkeys[mode] == NULL)
 			continue;
 		comp = &state->newkeys[mode]->comp;
-		if (comp && !comp->enabled && comp->type == COMP_DELAYED) {
-			if ((r = ssh_packet_init_compression(ssh)) != 0)
-				return r;
-			if (mode == MODE_OUT) {
-				if ((r = start_compression_out(ssh, 6)) != 0)
-					return r;
-			} else {
-				if ((r = start_compression_in(ssh)) != 0)
-					return r;
-			}
-			comp->enabled = 1;
+		if (comp && !comp->enabled &&
+		    (comp->type == COMP_DELAYED || comp->type == COMP_ZSTD)) {
+			if ((r = start_compression_mode(ssh, comp, mode)) != 0)
+				return r;
 		}
 	}
 	return 0;
@@ -1102,9 +1319,15 @@ ssh_packet_send2_wrapped(struct ssh *ssh)
 		if ((r = sshbuf_consume(state->outgoing_packet, 5)) != 0)
 			goto out;
 		sshbuf_reset(state->compression_buffer);
-		if ((r = compress_buffer(ssh, state->outgoing_packet,
-		    state->compression_buffer)) != 0)
-			goto out;
+		if (comp->enabled == COMP_ZSTD) {
+			if ((r = compress_buffer_zstd(ssh, state->outgoing_packet,
+			    state->compression_buffer)) != 0)
+				goto out;
+		} else {
+			if ((r = compress_buffer(ssh, state->outgoing_packet,
+			    state->compression_buffer)) != 0)
+				goto out;
+		}
 		sshbuf_reset(state->outgoing_packet);
 		if ((r = sshbuf_put(state->outgoing_packet,
 		    "\0\0\0\0\0", 5)) != 0 ||
@@ -1663,9 +1886,15 @@ ssh_packet_read_poll2(struct ssh *ssh, u_char *typep, u_int32_t *seqnr_p)
 	    sshbuf_len(state->incoming_packet)));
 	if (comp && comp->enabled) {
 		sshbuf_reset(state->compression_buffer);
-		if ((r = uncompress_buffer(ssh, state->incoming_packet,
-		    state->compression_buffer)) != 0)
-			goto out;
+		if (comp->enabled == COMP_ZSTD) {
+			if ((r = uncompress_buffer_zstd(ssh, state->incoming_packet,
+			    state->compression_buffer)) != 0)
+				goto out;
+		} else {
+			if ((r = uncompress_buffer(ssh, state->incoming_packet,
+			    state->compression_buffer)) != 0)
+				goto out;
+		}
 		sshbuf_reset(state->incoming_packet);
 		if ((r = sshbuf_putb(state->incoming_packet,
 		    state->compression_buffer)) != 0)
diff --git a/readconf.c b/readconf.c
index f3cac6b3a89d2..66c0e08ac15b2 100644
--- a/readconf.c
+++ b/readconf.c
@@ -838,8 +838,14 @@ static const struct multistate multistate_canonicalizehostname[] = {
 	{ NULL, -1 }
 };
 static const struct multistate multistate_compression[] = {
+#if defined(WITH_ZLIB) || defined(HAVE_LIBZSTD)
+	{ "yes",			COMP_ALL_C },
+#endif
 #ifdef WITH_ZLIB
-	{ "yes",			COMP_ZLIB },
+	{ "zlib",			COMP_ZLIB },
+#endif
+#ifdef HAVE_LIBZSTD
+	{ "zstd",			COMP_ZSTD },
 #endif
 	{ "no",				COMP_NONE },
 	{ NULL, -1 }
diff --git a/servconf.c b/servconf.c
index 31699254e6c41..08308362a2da3 100644
--- a/servconf.c
+++ b/servconf.c
@@ -392,11 +392,11 @@ fill_default_server_options(ServerOptions *options)
 		options->permit_user_env_whitelist = NULL;
 	}
 	if (options->compression == -1)
-#ifdef WITH_ZLIB
-		options->compression = COMP_DELAYED;
+#if defined(WITH_ZLIB) || defined(HAVE_LIBZSTD)
+		options->compression = COMP_ALL_S;
 #else
 		options->compression = COMP_NONE;
 #endif
 	if (options->rekey_limit == -1)
 		options->rekey_limit = 0;
 
@@ -1219,9 +1219,15 @@ static const struct multistate multistate_permitrootlogin[] = {
 	{ NULL, -1 }
 };
 static const struct multistate multistate_compression[] = {
+#if defined(WITH_ZLIB) || defined(HAVE_LIBZSTD)
+	{ "yes",			COMP_ALL_S },
+#endif
 #ifdef WITH_ZLIB
-	{ "yes",			COMP_DELAYED },
 	{ "delayed",			COMP_DELAYED },
+	{ "zlib",			COMP_DELAYED },
+#endif
+#ifdef HAVE_LIBZSTD
+	{ "zstd",			COMP_ZSTD },
 #endif
 	{ "no",				COMP_NONE },
 	{ NULL, -1 }
diff --git a/ssh.c b/ssh.c
index 15aee569e55d9..7ef3d10bb6d54 100644
--- a/ssh.c
+++ b/ssh.c
@@ -975,8 +975,8 @@ main(int ac, char **av)
 			break;
 
 		case 'C':
-#ifdef WITH_ZLIB
-			options.compression = 1;
+#if defined(HAVE_LIBZSTD) || defined(WITH_ZLIB)
+			options.compression = COMP_ALL_C;
 #else
 			error("Compression not supported, disabling.");
 #endif
-- 
2.26.0.rc2

_______________________________________________
openssh-unix-dev mailing list
openssh-unix-dev@xxxxxxxxxxx
https://lists.mindrot.org/mailman/listinfo/openssh-unix-dev