diff --git a/src/script/descriptor.cpp b/src/script/descriptor.cpp index 95291ef317..ac872c8cd2 100644 --- a/src/script/descriptor.cpp +++ b/src/script/descriptor.cpp @@ -160,8 +160,12 @@ public: virtual ~PubkeyProvider() = default; - /** Derive a public key. If key==nullptr, only info is desired. */ - virtual bool GetPubKey(int pos, const SigningProvider& arg, CPubKey* key, KeyOriginInfo& info) const = 0; + /** Derive a public key. + * read_cache is the cache to read keys from (if not nullptr) + * write_cache is the cache to write keys to (if not nullptr) + * Caches are not exclusive but this is not tested. Currently we use them exclusively + */ + virtual bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const = 0; /** Whether this represent multiple public keys at different positions. */ virtual bool IsRange() const = 0; @@ -191,9 +195,9 @@ class OriginPubkeyProvider final : public PubkeyProvider public: OriginPubkeyProvider(uint32_t exp_index, KeyOriginInfo info, std::unique_ptr provider) : PubkeyProvider(exp_index), m_origin(std::move(info)), m_provider(std::move(provider)) {} - bool GetPubKey(int pos, const SigningProvider& arg, CPubKey* key, KeyOriginInfo& info) const override + bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override { - if (!m_provider->GetPubKey(pos, arg, key, info)) return false; + if (!m_provider->GetPubKey(pos, arg, key, info, read_cache, write_cache)) return false; std::copy(std::begin(m_origin.fingerprint), std::end(m_origin.fingerprint), info.fingerprint); info.path.insert(info.path.begin(), m_origin.path.begin(), m_origin.path.end()); return true; @@ -221,9 +225,9 @@ class ConstPubkeyProvider final : public PubkeyProvider public: ConstPubkeyProvider(uint32_t exp_index, const CPubKey& pubkey) : PubkeyProvider(exp_index), m_pubkey(pubkey) {} - bool GetPubKey(int pos, const SigningProvider& arg, CPubKey* key, KeyOriginInfo& info) const override + bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key, KeyOriginInfo& info, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override { - if (key) *key = m_pubkey; + key = m_pubkey; info.path.clear(); CKeyID keyid = m_pubkey.GetID(); std::copy(keyid.begin(), keyid.begin() + sizeof(info.fingerprint), info.fingerprint); @@ -271,6 +275,16 @@ class BIP32PubkeyProvider final : public PubkeyProvider return true; } + // Derives the last xprv + bool GetDerivedExtKey(const SigningProvider& arg, CExtKey& xprv) const + { + if (!GetExtKey(arg, xprv)) return false; + for (auto entry : m_path) { + xprv.Derive(xprv, entry); + } + return true; + } + bool IsHardened() const { if (m_derive == DeriveType::HARDENED) return true; @@ -284,29 +298,47 @@ public: BIP32PubkeyProvider(uint32_t exp_index, const CExtPubKey& extkey, KeyPath path, DeriveType derive) : PubkeyProvider(exp_index), m_root_extkey(extkey), m_path(std::move(path)), m_derive(derive) {} bool IsRange() const override { return m_derive != DeriveType::NO; } size_t GetSize() const override { return 33; } - bool GetPubKey(int pos, const SigningProvider& arg, CPubKey* key, KeyOriginInfo& info) const override + bool GetPubKey(int pos, const SigningProvider& arg, CPubKey& key_out, KeyOriginInfo& final_info_out, const DescriptorCache* read_cache = nullptr, DescriptorCache* write_cache = nullptr) const override { - if (key) { - if (IsHardened()) { - CKey priv_key; - if (!GetPrivKey(pos, arg, priv_key)) return false; - *key = priv_key.GetPubKey(); - } else { - // TODO: optimize by caching - CExtPubKey extkey = m_root_extkey; - for (auto entry : m_path) { - extkey.Derive(extkey, entry); - } - if (m_derive == DeriveType::UNHARDENED) extkey.Derive(extkey, pos); - assert(m_derive != DeriveType::HARDENED); - *key = extkey.pubkey; - } - } + // Info of parent of the to be derived pubkey + KeyOriginInfo parent_info; CKeyID keyid = m_root_extkey.pubkey.GetID(); - std::copy(keyid.begin(), keyid.begin() + sizeof(info.fingerprint), info.fingerprint); - info.path = m_path; - if (m_derive == DeriveType::UNHARDENED) info.path.push_back((uint32_t)pos); - if (m_derive == DeriveType::HARDENED) info.path.push_back(((uint32_t)pos) | 0x80000000L); + std::copy(keyid.begin(), keyid.begin() + sizeof(parent_info.fingerprint), parent_info.fingerprint); + parent_info.path = m_path; + + // Info of the derived key itself which is copied out upon successful completion + KeyOriginInfo final_info_out_tmp = parent_info; + if (m_derive == DeriveType::UNHARDENED) final_info_out_tmp.path.push_back((uint32_t)pos); + if (m_derive == DeriveType::HARDENED) final_info_out_tmp.path.push_back(((uint32_t)pos) | 0x80000000L); + + // Derive keys or fetch them from cache + CExtPubKey final_extkey = m_root_extkey; + bool der = true; + if (read_cache) { + if (!read_cache->GetCachedDerivedExtPubKey(m_expr_index, pos, final_extkey)) return false; + } else if (IsHardened()) { + CExtKey xprv; + if (!GetDerivedExtKey(arg, xprv)) return false; + if (m_derive == DeriveType::UNHARDENED) der = xprv.Derive(xprv, pos); + if (m_derive == DeriveType::HARDENED) der = xprv.Derive(xprv, pos | 0x80000000UL); + final_extkey = xprv.Neuter(); + } else { + for (auto entry : m_path) { + der = final_extkey.Derive(final_extkey, entry); + assert(der); + } + if (m_derive == DeriveType::UNHARDENED) der = final_extkey.Derive(final_extkey, pos); + assert(m_derive != DeriveType::HARDENED); + } + assert(der); + + final_info_out = final_info_out_tmp; + key_out = final_extkey.pubkey; + + if (write_cache) { + write_cache->CacheDerivedExtPubKey(m_expr_index, pos, final_extkey); + } + return true; } std::string ToString() const override @@ -332,10 +364,7 @@ public: bool GetPrivKey(int pos, const SigningProvider& arg, CKey& key) const override { CExtKey extkey; - if (!GetExtKey(arg, extkey)) return false; - for (auto entry : m_path) { - extkey.Derive(extkey, entry); - } + if (!GetDerivedExtKey(arg, extkey)) return false; if (m_derive == DeriveType::UNHARDENED) extkey.Derive(extkey, pos); if (m_derive == DeriveType::HARDENED) extkey.Derive(extkey, pos | 0x80000000UL); key = extkey.key; @@ -434,7 +463,7 @@ public: return ret; } - bool ExpandHelper(int pos, const SigningProvider& arg, Span* cache_read, std::vector& output_scripts, FlatSigningProvider& out, std::vector* cache_write) const + bool ExpandHelper(int pos, const SigningProvider& arg, const DescriptorCache* read_cache, std::vector& output_scripts, FlatSigningProvider& out, DescriptorCache* write_cache) const { std::vector> entries; entries.reserve(m_pubkey_args.size()); @@ -442,27 +471,12 @@ public: // Construct temporary data in `entries` and `subscripts`, to avoid producing output in case of failure. for (const auto& p : m_pubkey_args) { entries.emplace_back(); - // If we have a cache, we don't need GetPubKey to compute the public key. - // Pass in nullptr to signify only origin info is desired. - if (!p->GetPubKey(pos, arg, cache_read ? nullptr : &entries.back().first, entries.back().second)) return false; - if (cache_read) { - // Cached expanded public key exists, use it. - if (cache_read->size() == 0) return false; - bool compressed = ((*cache_read)[0] == 0x02 || (*cache_read)[0] == 0x03) && cache_read->size() >= 33; - bool uncompressed = ((*cache_read)[0] == 0x04) && cache_read->size() >= 65; - if (!(compressed || uncompressed)) return false; - CPubKey pubkey(cache_read->begin(), cache_read->begin() + (compressed ? 33 : 65)); - entries.back().first = pubkey; - *cache_read = cache_read->subspan(compressed ? 33 : 65); - } - if (cache_write) { - cache_write->insert(cache_write->end(), entries.back().first.begin(), entries.back().first.end()); - } + if (!p->GetPubKey(pos, arg, entries.back().first, entries.back().second, read_cache, write_cache)) return false; } std::vector subscripts; if (m_subdescriptor_arg) { FlatSigningProvider subprovider; - if (!m_subdescriptor_arg->ExpandHelper(pos, arg, cache_read, subscripts, subprovider, cache_write)) return false; + if (!m_subdescriptor_arg->ExpandHelper(pos, arg, read_cache, subscripts, subprovider, write_cache)) return false; out = Merge(out, subprovider); } @@ -486,15 +500,14 @@ public: return true; } - bool Expand(int pos, const SigningProvider& provider, std::vector& output_scripts, FlatSigningProvider& out, std::vector* cache = nullptr) const final + bool Expand(int pos, const SigningProvider& provider, std::vector& output_scripts, FlatSigningProvider& out, DescriptorCache* write_cache = nullptr) const final { - return ExpandHelper(pos, provider, nullptr, output_scripts, out, cache); + return ExpandHelper(pos, provider, nullptr, output_scripts, out, write_cache); } - bool ExpandFromCache(int pos, const std::vector& cache, std::vector& output_scripts, FlatSigningProvider& out) const final + bool ExpandFromCache(int pos, const DescriptorCache& read_cache, std::vector& output_scripts, FlatSigningProvider& out) const final { - Span span = MakeSpan(cache); - return ExpandHelper(pos, DUMMY_SIGNING_PROVIDER, &span, output_scripts, out, nullptr) && span.size() == 0; + return ExpandHelper(pos, DUMMY_SIGNING_PROVIDER, &read_cache, output_scripts, out, nullptr); } void ExpandPrivate(int pos, const SigningProvider& provider, FlatSigningProvider& out) const final diff --git a/src/script/descriptor.h b/src/script/descriptor.h index 5c686d68c1..34cd5760de 100644 --- a/src/script/descriptor.h +++ b/src/script/descriptor.h @@ -96,18 +96,18 @@ struct Descriptor { * @param[in] provider The provider to query for private keys in case of hardened derivation. * @param[out] output_scripts The expanded scriptPubKeys. * @param[out] out Scripts and public keys necessary for solving the expanded scriptPubKeys (may be equal to `provider`). - * @param[out] cache Cache data necessary to evaluate the descriptor at this point without access to private keys. + * @param[out] write_cache Cache data necessary to evaluate the descriptor at this point without access to private keys. */ - virtual bool Expand(int pos, const SigningProvider& provider, std::vector& output_scripts, FlatSigningProvider& out, std::vector* cache = nullptr) const = 0; + virtual bool Expand(int pos, const SigningProvider& provider, std::vector& output_scripts, FlatSigningProvider& out, DescriptorCache* write_cache = nullptr) const = 0; /** Expand a descriptor at a specified position using cached expansion data. * * @param[in] pos The position at which to expand the descriptor. If IsRange() is false, this is ignored. - * @param[in] cache Cached expansion data. + * @param[in] read_cache Cached expansion data. * @param[out] output_scripts The expanded scriptPubKeys. * @param[out] out Scripts and public keys necessary for solving the expanded scriptPubKeys (may be equal to `provider`). */ - virtual bool ExpandFromCache(int pos, const std::vector& cache, std::vector& output_scripts, FlatSigningProvider& out) const = 0; + virtual bool ExpandFromCache(int pos, const DescriptorCache& read_cache, std::vector& output_scripts, FlatSigningProvider& out) const = 0; /** Expand the private key for a descriptor at a specified position, if possible. * diff --git a/src/test/descriptor_tests.cpp b/src/test/descriptor_tests.cpp index e99aee724e..f20a6bcd1d 100644 --- a/src/test/descriptor_tests.cpp +++ b/src/test/descriptor_tests.cpp @@ -135,14 +135,14 @@ void DoCheck(const std::string& prv, const std::string& pub, int flags, const st // Evaluate the descriptor selected by `t` in poisition `i`. FlatSigningProvider script_provider, script_provider_cached; std::vector spks, spks_cached; - std::vector cache; - BOOST_CHECK((t ? parse_priv : parse_pub)->Expand(i, key_provider, spks, script_provider, &cache)); + DescriptorCache desc_cache; + BOOST_CHECK((t ? parse_priv : parse_pub)->Expand(i, key_provider, spks, script_provider, &desc_cache)); // Compare the output with the expected result. BOOST_CHECK_EQUAL(spks.size(), ref.size()); // Try to expand again using cached data, and compare. - BOOST_CHECK(parse_pub->ExpandFromCache(i, cache, spks_cached, script_provider_cached)); + BOOST_CHECK(parse_pub->ExpandFromCache(i, desc_cache, spks_cached, script_provider_cached)); BOOST_CHECK(spks == spks_cached); BOOST_CHECK(script_provider.pubkeys == script_provider_cached.pubkeys); BOOST_CHECK(script_provider.scripts == script_provider_cached.scripts);