wg_noise: avoid handshake/keypair type confusion

So the last change broke consuming responses, as it may return an
invalid remote pointer. Thanks for the catch zx2c4. We just pass a flag
"lookup_keypair" which will lookup the keypair when we want (for cookie)
and will not when we don't (for consuming responses).

It would be possible to merge both noise_remote_index_lookup and
noise_keypair_lookup, but the result would probably need to return a
void * (for both keypair and remote) or a noise_index * which would need
to be cast to the relevant type somewhere. The trickiest thing here
would be for if_wg to "put" the result of the function, as it may be a
remote or a keypair (which store their refcount in different locations).
Perhaps it would return a noise_index * which could contain the refcount
for both keypair and remote. It all seems easier to leave them separate.
The only argument for combining them would be to reduce duplication of
(similar) functions.

Signed-off-by: Matt Dunwoodie <ncon@noconroy.net>
This commit is contained in:
Matt Dunwoodie 2021-04-20 15:35:58 +10:00
parent 123c24e6af
commit 87a62f1322
3 changed files with 16 additions and 7 deletions

View File

@ -1356,7 +1356,7 @@ wg_handshake(struct wg_softc *sc, struct wg_packet *pkt)
case WG_PKT_COOKIE:
cook = mtod(m, struct wg_pkt_cookie *);
if ((remote = noise_remote_index_lookup(sc->sc_local, cook->r_idx)) == NULL) {
if ((remote = noise_remote_index(sc->sc_local, cook->r_idx)) == NULL) {
DPRINTF(sc, "Unknown cookie index\n");
goto error;
}

View File

@ -127,6 +127,8 @@ struct noise_local {
static void noise_precompute_ss(struct noise_local *, struct noise_remote *);
static void noise_remote_index_insert(struct noise_local *, struct noise_remote *);
static struct noise_remote *
noise_remote_index_lookup(struct noise_local *, uint32_t, int);
static int noise_remote_index_remove(struct noise_local *, struct noise_remote *);
static void noise_remote_expire_current(struct noise_remote *);
@ -397,8 +399,8 @@ assign_id:
r->r_handshake_alive = 1;
}
struct noise_remote *
noise_remote_index_lookup(struct noise_local *l, uint32_t idx0)
static struct noise_remote *
noise_remote_index_lookup(struct noise_local *l, uint32_t idx0, int lookup_keypair)
{
struct epoch_tracker et;
struct noise_index *i;
@ -409,11 +411,13 @@ noise_remote_index_lookup(struct noise_local *l, uint32_t idx0)
NET_EPOCH_ENTER(et);
CK_LIST_FOREACH(i, &l->l_index_hash[idx], i_entry) {
if (i->i_local_index == idx0) {
if (i->i_is_keypair) {
if (!i->i_is_keypair) {
r = (struct noise_remote *) i;
} else if (lookup_keypair) {
kp = (struct noise_keypair *) i;
r = kp->kp_remote;
} else {
r = (struct noise_remote *) i;
break;
}
if (refcount_acquire_if_not_zero(&r->r_refcnt))
ret = r;
@ -424,6 +428,11 @@ noise_remote_index_lookup(struct noise_local *l, uint32_t idx0)
return (ret);
}
struct noise_remote *
noise_remote_index(struct noise_local *l, uint32_t idx) {
return noise_remote_index_lookup(l, idx, 1);
}
static int
noise_remote_index_remove(struct noise_local *l, struct noise_remote *r)
{
@ -1093,7 +1102,7 @@ noise_consume_response(struct noise_local *l, struct noise_remote **rp,
struct noise_remote *r = NULL;
int ret = EINVAL;
if ((r = noise_remote_index_lookup(l, r_idx)) == NULL)
if ((r = noise_remote_index_lookup(l, r_idx, 0)) == NULL)
return (ret);
rw_rlock(&l->l_identity_lock);

View File

@ -51,7 +51,7 @@ void noise_remote_disable(struct noise_remote *);
struct noise_remote *
noise_remote_lookup(struct noise_local *, const uint8_t[NOISE_PUBLIC_KEY_LEN]);
struct noise_remote *
noise_remote_index_lookup(struct noise_local *, uint32_t);
noise_remote_index(struct noise_local *, uint32_t);
struct noise_remote *
noise_remote_ref(struct noise_remote *);
void noise_remote_put(struct noise_remote *);