diff --git a/packages/snaps-rpc-methods/jest.config.js b/packages/snaps-rpc-methods/jest.config.js index 7ac6aa730a..9d6aa634ac 100644 --- a/packages/snaps-rpc-methods/jest.config.js +++ b/packages/snaps-rpc-methods/jest.config.js @@ -10,10 +10,10 @@ module.exports = deepmerge(baseConfig, { ], coverageThreshold: { global: { - branches: 96.68, - functions: 99.2, - lines: 99.06, - statements: 98.78, + branches: 97.28, + functions: 98.84, + lines: 99.14, + statements: 98.81, }, }, }); diff --git a/packages/snaps-rpc-methods/src/permissions.test.ts b/packages/snaps-rpc-methods/src/permissions.test.ts index 87b4120ad4..013f5587a3 100644 --- a/packages/snaps-rpc-methods/src/permissions.test.ts +++ b/packages/snaps-rpc-methods/src/permissions.test.ts @@ -1,3 +1,5 @@ +import { Messenger } from '@metamask/messenger'; + import { buildSnapEndowmentSpecifications, buildSnapRestrictedMethodSpecifications, @@ -193,7 +195,11 @@ describe('buildSnapEndowmentSpecifications', () => { describe('buildSnapRestrictedMethodSpecifications', () => { it('returns the expected object', () => { - const specifications = buildSnapRestrictedMethodSpecifications([], {}); + const specifications = buildSnapRestrictedMethodSpecifications( + [], + {}, + new Messenger({ namespace: 'SnapsRestrictedMethods' }), + ); expect(specifications).toMatchInlineSnapshot(` { "snap_dialog": { diff --git a/packages/snaps-rpc-methods/src/permissions.ts b/packages/snaps-rpc-methods/src/permissions.ts index 6c0e6d43dc..bed80afdec 100644 --- a/packages/snaps-rpc-methods/src/permissions.ts +++ b/packages/snaps-rpc-methods/src/permissions.ts @@ -1,6 +1,8 @@ -import type { - PermissionConstraint, - PermissionSpecificationConstraint, +import { selectHooks } from '@metamask/json-rpc-engine/v2'; +import { + createRestrictedMethodMessenger, + type PermissionConstraint, + type PermissionSpecificationConstraint, } from '@metamask/permission-controller'; import type { SnapPermissions } from '@metamask/snaps-utils'; import { hasProperty } from '@metamask/utils'; @@ -9,11 +11,14 @@ import { endowmentCaveatMappers, endowmentPermissionBuilders, } from './endowments'; +import type { + RestrictedMethodActions, + RestrictedMethodMessenger, +} from './restricted'; import { caveatMappers, restrictedMethodPermissionBuilders, } from './restricted'; -import { selectHooks } from './utils'; /** * Map initial permissions as defined in a Snap manifest to something that can @@ -64,18 +69,28 @@ export const buildSnapEndowmentSpecifications = ( export const buildSnapRestrictedMethodSpecifications = ( excludedPermissions: string[], hooks: Record, + messenger: RestrictedMethodMessenger, ) => Object.values(restrictedMethodPermissionBuilders).reduce< Record - >((specifications, { targetName, specificationBuilder, methodHooks }) => { - if (!excludedPermissions.includes(targetName)) { - specifications[targetName] = specificationBuilder({ - // @ts-expect-error The selectHooks type is wonky - methodHooks: selectHooks( - hooks, - methodHooks, - ) as Pick, - }); - } - return specifications; - }, {}); + >( + ( + specifications, + { targetName, specificationBuilder, methodHooks, actionNames }, + ) => { + if (!excludedPermissions.includes(targetName)) { + specifications[targetName] = specificationBuilder({ + methodHooks: selectHooks(hooks, methodHooks), + messenger: createRestrictedMethodMessenger({ + namespace: targetName, + rootMessenger: messenger, + actionNames: actionNames as readonly [ + RestrictedMethodActions['type'], + ], + }), + }); + } + return specifications; + }, + {}, + ); diff --git a/packages/snaps-rpc-methods/src/restricted/dialog.test.tsx b/packages/snaps-rpc-methods/src/restricted/dialog.test.tsx index 2b295af243..bb4ca52c2f 100644 --- a/packages/snaps-rpc-methods/src/restricted/dialog.test.tsx +++ b/packages/snaps-rpc-methods/src/restricted/dialog.test.tsx @@ -1,9 +1,12 @@ +import type { MockAnyNamespace } from '@metamask/messenger'; +import { MOCK_ANY_NAMESPACE, Messenger } from '@metamask/messenger'; import { PermissionType, SubjectType } from '@metamask/permission-controller'; import { rpcErrors } from '@metamask/rpc-errors'; +import type { SnapId } from '@metamask/snaps-sdk'; import { DialogType, NodeType } from '@metamask/snaps-sdk'; import { Box, Text } from '@metamask/snaps-sdk/jsx'; -import type { DialogMethodHooks } from './dialog'; +import type { DialogMessengerActions } from './dialog'; import { DIALOG_APPROVAL_TYPES, dialogBuilder, @@ -15,24 +18,19 @@ describe('builder', () => { expect(dialogBuilder).toMatchObject({ targetName: 'snap_dialog', specificationBuilder: expect.any(Function), - methodHooks: { - requestUserApproval: true, - createInterface: true, - getInterface: true, - setInterfaceDisplayed: true, - }, + actionNames: [ + 'ApprovalController:addRequest', + 'SnapInterfaceController:createInterface', + 'SnapInterfaceController:getInterface', + 'SnapInterfaceController:setInterfaceDisplayed', + ], }); }); it('builder outputs expected specification', () => { expect( dialogBuilder.specificationBuilder({ - methodHooks: { - requestUserApproval: jest.fn(), - createInterface: jest.fn(), - getInterface: jest.fn(), - setInterfaceDisplayed: jest.fn(), - }, + messenger: new Messenger({ namespace: 'Dialog' }), }), ).toStrictEqual({ permissionType: PermissionType.RestrictedMethod, @@ -45,21 +43,43 @@ describe('builder', () => { }); describe('implementation', () => { - const getMockDialogHooks = () => - ({ - requestUserApproval: jest.fn(), - createInterface: jest.fn().mockReturnValue('bar'), - getInterface: jest.fn().mockReturnValue({ + const getMessenger = () => { + const messenger = new Messenger({ + namespace: MOCK_ANY_NAMESPACE, + }); + + messenger.registerActionHandler( + 'ApprovalController:addRequest', + async () => null, + ); + + messenger.registerActionHandler( + 'SnapInterfaceController:createInterface', + () => 'bar', + ); + + messenger.registerActionHandler( + 'SnapInterfaceController:getInterface', + () => ({ content: { type: NodeType.Text as const, value: 'foo' }, state: {}, - snapId: 'foo', + snapId: 'foo' as SnapId, }), - setInterfaceDisplayed: jest.fn(), - }) as DialogMethodHooks; + ); + + messenger.registerActionHandler( + 'SnapInterfaceController:setInterfaceDisplayed', + () => null, + ); + + jest.spyOn(messenger, 'call'); + + return messenger; + }; it('accepts string dialog types', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await implementation({ context: { origin: 'foo' }, method: 'snap_dialog', @@ -75,21 +95,25 @@ describe('implementation', () => { }, }); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: undefined, - origin: 'foo', - type: DIALOG_APPROVAL_TYPES[DialogType.Alert], - requestData: { - id: 'bar', - placeholder: undefined, + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { + id: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES[DialogType.Alert], + requestData: { + id: 'bar', + placeholder: undefined, + }, }, - }); + true, + ); }); it('accepts no dialog type with an interface ID', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await implementation({ context: { origin: 'foo' }, method: 'snap_dialog', @@ -98,21 +122,25 @@ describe('implementation', () => { }, }); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: 'bar', - origin: 'foo', - type: DIALOG_APPROVAL_TYPES.default, - requestData: { + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { id: 'bar', - placeholder: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES.default, + requestData: { + id: 'bar', + placeholder: undefined, + }, }, - }); + true, + ); }); it('accepts no dialog type with content', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); const content = ( @@ -128,32 +156,30 @@ describe('implementation', () => { }, }); - expect(hooks.createInterface).toHaveBeenCalledWith('foo', content); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: 'bar', - origin: 'foo', - type: DIALOG_APPROVAL_TYPES.default, - requestData: { + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'SnapInterfaceController:createInterface', + 'foo', + content, + ); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { id: 'bar', - placeholder: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES.default, + requestData: { + id: 'bar', + placeholder: undefined, + }, }, - }); + true, + ); }); it('gets the interface data if an interface ID is passed', async () => { - const hooks = { - requestUserApproval: jest.fn(), - createInterface: jest.fn().mockReturnValue('bar'), - getInterface: jest.fn().mockReturnValue({ - content: { type: NodeType.Text as const, value: 'foo' }, - state: {}, - snapId: 'foo', - }), - setInterfaceDisplayed: jest.fn(), - }; - - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await implementation({ context: { origin: 'foo' }, @@ -164,21 +190,25 @@ describe('implementation', () => { }, }); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: undefined, - origin: 'foo', - type: DIALOG_APPROVAL_TYPES[DialogType.Alert], - requestData: { - id: 'bar', - placeholder: undefined, + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { + id: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES[DialogType.Alert], + requestData: { + id: 'bar', + placeholder: undefined, + }, }, - }); + true, + ); }); it('creates a new interface if some content is passed', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); const content = { type: NodeType.Panel as const, @@ -197,22 +227,30 @@ describe('implementation', () => { }, }); - expect(hooks.createInterface).toHaveBeenCalledWith('foo', content); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: undefined, - origin: 'foo', - type: DIALOG_APPROVAL_TYPES[DialogType.Alert], - requestData: { - id: 'bar', - placeholder: undefined, + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'SnapInterfaceController:createInterface', + 'foo', + content, + ); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { + id: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES[DialogType.Alert], + requestData: { + id: 'bar', + placeholder: undefined, + }, }, - }); + true, + ); }); it('creates a new interface if a JSX element is passed', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); const content = ( @@ -229,22 +267,30 @@ describe('implementation', () => { }, }); - expect(hooks.createInterface).toHaveBeenCalledWith('foo', content); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: undefined, - origin: 'foo', - type: DIALOG_APPROVAL_TYPES[DialogType.Alert], - requestData: { - id: 'bar', - placeholder: undefined, + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'SnapInterfaceController:createInterface', + 'foo', + content, + ); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { + id: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES[DialogType.Alert], + requestData: { + id: 'bar', + placeholder: undefined, + }, }, - }); + true, + ); }); it('sets an interface as displayed', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await implementation({ context: { origin: 'foo' }, @@ -255,13 +301,17 @@ describe('implementation', () => { }, }); - expect(hooks.setInterfaceDisplayed).toHaveBeenCalledTimes(1); - expect(hooks.setInterfaceDisplayed).toHaveBeenCalledWith('foo', 'bar'); + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'SnapInterfaceController:setInterfaceDisplayed', + 'foo', + 'bar', + ); }); it('sets an interface as displayed if content is passed without an ID', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await implementation({ context: { origin: 'foo' }, @@ -272,21 +322,27 @@ describe('implementation', () => { }, }); - expect(hooks.setInterfaceDisplayed).toHaveBeenCalledTimes(1); - expect(hooks.setInterfaceDisplayed).toHaveBeenCalledWith('foo', 'bar'); + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'SnapInterfaceController:setInterfaceDisplayed', + 'foo', + 'bar', + ); }); it('throws if the requested interface does not exist.', async () => { - const hooks = { - requestUserApproval: jest.fn(), - createInterface: jest.fn(), - getInterface: jest.fn().mockImplementation((_snapId, id) => { + const messenger = new Messenger({ namespace: MOCK_ANY_NAMESPACE }); + + messenger.registerActionHandler( + 'SnapInterfaceController:getInterface', + (_origin: string, id: string) => { throw new Error(`Interface with id '${id}' not found.`); - }), - setInterfaceDisplayed: jest.fn(), - }; + }, + ); + + const spy = jest.spyOn(messenger, 'call'); - const implementation = getDialogImplementation(hooks); + const implementation = getDialogImplementation({ messenger }); await expect( implementation({ @@ -303,14 +359,18 @@ describe('implementation', () => { ), ); - expect(hooks.getInterface).toHaveBeenCalledTimes(1); - expect(hooks.getInterface).toHaveBeenCalledWith('foo', 'bar'); + expect(spy).toHaveBeenCalledTimes(1); + expect(spy).toHaveBeenCalledWith( + 'SnapInterfaceController:getInterface', + 'foo', + 'bar', + ); }); describe('alerts', () => { it('handles alerts', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await implementation({ context: { origin: 'foo' }, method: 'snap_dialog', @@ -326,21 +386,25 @@ describe('implementation', () => { }, }); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: undefined, - origin: 'foo', - type: DIALOG_APPROVAL_TYPES[DialogType.Alert], - requestData: { - id: 'bar', - placeholder: undefined, + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { + id: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES[DialogType.Alert], + requestData: { + id: 'bar', + placeholder: undefined, + }, }, - }); + true, + ); }); it('handles JSX alerts', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await implementation({ context: { origin: 'foo' }, method: 'snap_dialog', @@ -354,23 +418,27 @@ describe('implementation', () => { }, }); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: undefined, - origin: 'foo', - type: DIALOG_APPROVAL_TYPES[DialogType.Alert], - requestData: { - id: 'bar', - placeholder: undefined, + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { + id: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES[DialogType.Alert], + requestData: { + id: 'bar', + placeholder: undefined, + }, }, - }); + true, + ); }); }); describe('confirmations', () => { it('handles confirmations', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await implementation({ context: { origin: 'foo' }, method: 'snap_dialog', @@ -386,21 +454,25 @@ describe('implementation', () => { }, }); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: undefined, - origin: 'foo', - type: DIALOG_APPROVAL_TYPES[DialogType.Confirmation], - requestData: { - id: 'bar', - placeholder: undefined, + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { + id: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES[DialogType.Confirmation], + requestData: { + id: 'bar', + placeholder: undefined, + }, }, - }); + true, + ); }); it('handles JSX confirmations', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await implementation({ context: { origin: 'foo' }, method: 'snap_dialog', @@ -414,21 +486,25 @@ describe('implementation', () => { }, }); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: undefined, - origin: 'foo', - type: DIALOG_APPROVAL_TYPES[DialogType.Confirmation], - requestData: { - id: 'bar', - placeholder: undefined, + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { + id: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES[DialogType.Confirmation], + requestData: { + id: 'bar', + placeholder: undefined, + }, }, - }); + true, + ); }); it('handles confirmations using an ID', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await implementation({ context: { origin: 'foo' }, method: 'snap_dialog', @@ -438,23 +514,27 @@ describe('implementation', () => { }, }); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: undefined, - origin: 'foo', - type: DIALOG_APPROVAL_TYPES[DialogType.Confirmation], - requestData: { - id: 'baz', - placeholder: undefined, + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { + id: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES[DialogType.Confirmation], + requestData: { + id: 'baz', + placeholder: undefined, + }, }, - }); + true, + ); }); }); describe('prompts', () => { it('handles prompts', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await implementation({ context: { origin: 'foo' }, method: 'snap_dialog', @@ -471,21 +551,25 @@ describe('implementation', () => { }, }); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: undefined, - origin: 'foo', - type: DIALOG_APPROVAL_TYPES[DialogType.Prompt], - requestData: { - id: 'bar', - placeholder: 'foobar', + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { + id: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES[DialogType.Prompt], + requestData: { + id: 'bar', + placeholder: 'foobar', + }, }, - }); + true, + ); }); it('handles JSX prompts', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await implementation({ context: { origin: 'foo' }, method: 'snap_dialog', @@ -500,21 +584,25 @@ describe('implementation', () => { }, }); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: undefined, - origin: 'foo', - type: DIALOG_APPROVAL_TYPES[DialogType.Prompt], - requestData: { - id: 'bar', - placeholder: 'foobar', + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { + id: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES[DialogType.Prompt], + requestData: { + id: 'bar', + placeholder: 'foobar', + }, }, - }); + true, + ); }); it('handles prompts using an ID', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await implementation({ context: { origin: 'foo' }, method: 'snap_dialog', @@ -525,16 +613,20 @@ describe('implementation', () => { }, }); - expect(hooks.requestUserApproval).toHaveBeenCalledTimes(1); - expect(hooks.requestUserApproval).toHaveBeenCalledWith({ - id: undefined, - origin: 'foo', - type: DIALOG_APPROVAL_TYPES[DialogType.Prompt], - requestData: { - id: 'baz', - placeholder: 'foobar', + expect(messenger.call).toHaveBeenCalledTimes(3); + expect(messenger.call).toHaveBeenCalledWith( + 'ApprovalController:addRequest', + { + id: undefined, + origin: 'foo', + type: DIALOG_APPROVAL_TYPES[DialogType.Prompt], + requestData: { + id: 'baz', + placeholder: 'foobar', + }, }, - }); + true, + ); }); }); @@ -542,8 +634,8 @@ describe('implementation', () => { it.each([undefined, null, false, 2])( 'rejects invalid parameter object', async (value) => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await expect( implementation({ @@ -558,8 +650,8 @@ describe('implementation', () => { ); it('rejects empty parameter object', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await expect( implementation({ @@ -575,8 +667,8 @@ describe('implementation', () => { it.each([{ type: false }, { type: '' }, { type: 'foo' }])( 'rejects invalid types', async (value) => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await expect( implementation({ @@ -599,8 +691,8 @@ describe('implementation', () => { { type: DialogType.Alert, content: 2 }, { type: DialogType.Alert, content: [] }, ])('rejects invalid fields', async (value) => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await expect( implementation({ @@ -616,8 +708,8 @@ describe('implementation', () => { it.each([true, 2, [], {}, new (class {})()])( 'rejects invalid placeholder contents', async (value: any) => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await expect( implementation({ @@ -642,8 +734,8 @@ describe('implementation', () => { ); it('rejects placeholders with invalid length', async () => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await expect( implementation({ @@ -669,8 +761,8 @@ describe('implementation', () => { it.each([DialogType.Alert, DialogType.Confirmation])( 'rejects placeholder field for alerts and confirmations', async (type) => { - const hooks = getMockDialogHooks(); - const implementation = getDialogImplementation(hooks); + const messenger = getMessenger(); + const implementation = getDialogImplementation({ messenger }); await expect( implementation({ context: { origin: 'foo' }, diff --git a/packages/snaps-rpc-methods/src/restricted/dialog.ts b/packages/snaps-rpc-methods/src/restricted/dialog.ts index b178bb55a9..b78322e72b 100644 --- a/packages/snaps-rpc-methods/src/restricted/dialog.ts +++ b/packages/snaps-rpc-methods/src/restricted/dialog.ts @@ -1,3 +1,4 @@ +import type { Messenger } from '@metamask/messenger'; import type { PermissionSpecificationBuilder, RestrictedMethodOptions, @@ -14,21 +15,21 @@ import { import type { DialogParams, Component, - InterfaceState, - SnapId, PromptDialog, - ComponentOrElement, - InterfaceContext, - ContentType, DialogResult, } from '@metamask/snaps-sdk'; import type { InferMatching } from '@metamask/snaps-utils'; import type { Infer } from '@metamask/superstruct'; import { create, object, optional, size, string } from '@metamask/superstruct'; -import type { Json, NonEmptyArray } from '@metamask/utils'; +import type { NonEmptyArray } from '@metamask/utils'; import { hasProperty, isObject, isPlainObject } from '@metamask/utils'; -import { type MethodHooksObject } from '../utils'; +import type { + ApprovalControllerAddRequestAction, + SnapInterfaceControllerCreateInterfaceAction, + SnapInterfaceControllerGetInterfaceAction, + SnapInterfaceControllerSetInterfaceDisplayedAction, +} from '../types'; const methodName = 'snap_dialog'; @@ -47,67 +48,15 @@ const PlaceholderStruct = optional(size(string(), 1, 40)); export type Placeholder = Infer; -type RequestUserApprovalOptions = { - id?: string; - origin: string; - type: string; - requestData: { - id: string; - placeholder?: string; - }; -}; - -type RequestUserApproval = ( - opts: RequestUserApprovalOptions, -) => Promise; - -type CreateInterface = ( - snapId: string, - content: ComponentOrElement, - context?: InterfaceContext, - contentType?: ContentType, -) => Promise; - -type GetInterface = ( - snapId: string, - id: string, -) => { content: ComponentOrElement; snapId: SnapId; state: InterfaceState }; - -export type DialogMethodHooks = { - /** - * @param opts - The `requestUserApproval` options. - * @param opts.id - The approval ID. If not provided, a new approval ID will be generated. - * @param opts.origin - The origin of the request. In this case, the Snap ID. - * @param opts.type - The type of the approval request. - * @param opts.requestData - The data of the approval request. - * @param opts.requestData.id - The ID of the interface. - * @param opts.requestData.placeholder - The placeholder of the `Prompt` dialog. - */ - requestUserApproval: RequestUserApproval; - - /** - * @param snapId - The Snap ID creating the interface. - * @param content - The content of the interface. - */ - createInterface: CreateInterface; - /** - * @param snapId - The SnapId requesting the interface. - * @param id - The interface ID. - */ - getInterface: GetInterface; - - /** - * Set the interface as displayed. - * - * @param snapId - The Snap ID requesting the interface. - * @param id - The interface ID. - */ - setInterfaceDisplayed: (snapId: string, id: string) => void; -}; +export type DialogMessengerActions = + | ApprovalControllerAddRequestAction + | SnapInterfaceControllerCreateInterfaceAction + | SnapInterfaceControllerGetInterfaceAction + | SnapInterfaceControllerSetInterfaceDisplayedAction; type DialogSpecificationBuilderOptions = { allowedCaveats?: Readonly> | null; - methodHooks: DialogMethodHooks; + messenger: Messenger; }; type DialogSpecification = ValidPermissionSpecification<{ @@ -127,8 +76,7 @@ type DialogSpecification = ValidPermissionSpecification<{ * @param options - The specification builder options. * @param options.allowedCaveats - The optional allowed caveats for the * permission. - * @param options.methodHooks - The RPC method hooks needed by the method - * implementation. + * @param options.messenger - The messenger. * @returns The specification for the `snap_dialog` permission. */ const specificationBuilder: PermissionSpecificationBuilder< @@ -137,24 +85,17 @@ const specificationBuilder: PermissionSpecificationBuilder< DialogSpecification > = ({ allowedCaveats = null, - methodHooks, + messenger, }: DialogSpecificationBuilderOptions) => { return { permissionType: PermissionType.RestrictedMethod, targetName: methodName, allowedCaveats, - methodImplementation: getDialogImplementation(methodHooks), + methodImplementation: getDialogImplementation({ messenger }), subjectTypes: [SubjectType.Snap], }; }; -const methodHooks: MethodHooksObject = { - requestUserApproval: true, - createInterface: true, - getInterface: true, - setInterfaceDisplayed: true, -}; - /* eslint-disable jsdoc/check-indentation */ /** * Display a [dialog](https://docs.metamask.io/snaps/features/custom-ui/dialogs/) @@ -200,7 +141,12 @@ const methodHooks: MethodHooksObject = { export const dialogBuilder = Object.freeze({ targetName: methodName, specificationBuilder, - methodHooks, + actionNames: [ + 'ApprovalController:addRequest', + 'SnapInterfaceController:createInterface', + 'SnapInterfaceController:getInterface', + 'SnapInterfaceController:setInterfaceDisplayed', + ], } as const); /* eslint-enable jsdoc/check-indentation */ @@ -301,22 +247,14 @@ export type DialogParameters = InferMatching< /** * Builds the method implementation for `snap_dialog`. * - * @param hooks - The RPC method hooks. - * @param hooks.requestUserApproval - A function that creates a new Approval in the ApprovalController. - * This function should return a Promise that resolves with the appropriate value when the user has approved or rejected the request. - * @param hooks.createInterface - A function that creates the interface in SnapInterfaceController. - * @param hooks.getInterface - A function that gets an interface from SnapInterfaceController. - * @param hooks.setInterfaceDisplayed - A function that sets the interface as - * displayed in SnapInterfaceController. + * @param options - The options. + * @param options.messenger - The messenger. * @returns The method implementation which return value depends on the dialog * type, valid return types are: string, boolean, null. */ export function getDialogImplementation({ - requestUserApproval, - createInterface, - getInterface, - setInterfaceDisplayed, -}: DialogMethodHooks) { + messenger, +}: DialogSpecificationBuilderOptions) { return async function dialogImplementation( args: RestrictedMethodOptions, ): Promise { @@ -346,33 +284,51 @@ export function getDialogImplementation({ ]; if (hasProperty(validatedParams, 'content')) { - const id = await createInterface( + const id = messenger.call( + 'SnapInterfaceController:createInterface', origin, validatedParams.content as Component, ); - setInterfaceDisplayed(origin, id); - - return requestUserApproval({ - id: approvalType === DIALOG_APPROVAL_TYPES.default ? id : undefined, + messenger.call( + 'SnapInterfaceController:setInterfaceDisplayed', origin, - type: approvalType, - requestData: { id, placeholder }, - }); + id, + ); + + return messenger.call( + 'ApprovalController:addRequest', + { + id: approvalType === DIALOG_APPROVAL_TYPES.default ? id : undefined, + origin, + type: approvalType, + requestData: { id, placeholder }, + }, + true, + ); } - validateInterface(origin, validatedParams.id, getInterface); - setInterfaceDisplayed(origin, validatedParams.id); + validateInterface(origin, validatedParams.id, messenger); - return requestUserApproval({ - id: - approvalType === DIALOG_APPROVAL_TYPES.default - ? validatedParams.id - : undefined, + messenger.call( + 'SnapInterfaceController:setInterfaceDisplayed', origin, - type: approvalType, - requestData: { id: validatedParams.id, placeholder }, - }); + validatedParams.id, + ); + + return messenger.call( + 'ApprovalController:addRequest', + { + id: + approvalType === DIALOG_APPROVAL_TYPES.default + ? validatedParams.id + : undefined, + origin, + type: approvalType, + requestData: { id: validatedParams.id, placeholder }, + }, + true, + ); }; } /** @@ -380,15 +336,15 @@ export function getDialogImplementation({ * * @param origin - The origin of the request. * @param id - The interface ID. - * @param getInterface - The function to get the interface. + * @param messenger - The messenger. */ function validateInterface( origin: string, id: string, - getInterface: GetInterface, + messenger: Messenger, ) { try { - getInterface(origin, id); + messenger.call('SnapInterfaceController:getInterface', origin, id); } catch (error) { throw rpcErrors.invalidParams({ message: `Invalid params: ${error.message}`, diff --git a/packages/snaps-rpc-methods/src/restricted/getBip32Entropy.test.ts b/packages/snaps-rpc-methods/src/restricted/getBip32Entropy.test.ts index d584035b1f..192d2e9d3b 100644 --- a/packages/snaps-rpc-methods/src/restricted/getBip32Entropy.test.ts +++ b/packages/snaps-rpc-methods/src/restricted/getBip32Entropy.test.ts @@ -1,3 +1,5 @@ +import type { MockAnyNamespace } from '@metamask/messenger'; +import { MOCK_ANY_NAMESPACE, Messenger } from '@metamask/messenger'; import { PermissionType, SubjectType } from '@metamask/permission-controller'; import { SnapCaveatType } from '@metamask/snaps-utils'; import { @@ -8,6 +10,7 @@ import { import { hmac } from '@noble/hashes/hmac'; import { sha512 } from '@noble/hashes/sha512'; +import type { GetBip32EntropyMessengerActions } from './getBip32Entropy'; import { getBip32EntropyBuilder, getBip32EntropyImplementation, @@ -15,14 +18,13 @@ import { describe('specificationBuilder', () => { const methodHooks = { - getMnemonic: jest.fn(), - getMnemonicSeed: jest.fn(), getUnlockPromise: jest.fn(), getClientCryptography: jest.fn(), }; const specification = getBip32EntropyBuilder.specificationBuilder({ methodHooks, + messenger: new Messenger({ namespace: 'GetBip32Entropy' }), }); it('outputs expected specification', () => { @@ -64,23 +66,41 @@ describe('specificationBuilder', () => { }); describe('getBip32EntropyImplementation', () => { + const getMessenger = () => { + const messenger = new Messenger< + MockAnyNamespace, + GetBip32EntropyMessengerActions + >({ + namespace: MOCK_ANY_NAMESPACE, + }); + + messenger.registerActionHandler( + 'KeyringController:withKeyring', + async (_selector, operation) => + operation({ + keyring: { + type: 'HD Key Tree', + mnemonic: TEST_SECRET_RECOVERY_PHRASE_BYTES, + seed: TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES, + }, + }), + ); + + jest.spyOn(messenger, 'call'); + + return messenger; + }; + describe('getBip32Entropy', () => { it('derives the entropy from the path', async () => { const getUnlockPromise = jest.fn().mockResolvedValue(undefined); - const getMnemonic = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip32EntropyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, // @ts-expect-error Missing other required properties. })({ params: { path: ['m', "44'", "1'"], curve: 'secp256k1' }, @@ -102,20 +122,13 @@ describe('getBip32EntropyImplementation', () => { it('derives a BIP-44 path', async () => { const getUnlockPromise = jest.fn().mockResolvedValue(undefined); - const getMnemonic = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip32EntropyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, // @ts-expect-error Missing other required properties. })({ params: { @@ -140,21 +153,13 @@ describe('getBip32EntropyImplementation', () => { it('derives a path using ed25519', async () => { const getUnlockPromise = jest.fn().mockResolvedValue(undefined); - const getMnemonic = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); - const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip32EntropyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, // @ts-expect-error Missing other required properties. })({ params: { @@ -179,21 +184,13 @@ describe('getBip32EntropyImplementation', () => { it('derives a path using ed25519Bip32', async () => { const getUnlockPromise = jest.fn().mockResolvedValue(undefined); - const getMnemonic = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); - const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip32EntropyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, // @ts-expect-error Missing other required properties. })({ params: { @@ -217,22 +214,14 @@ describe('getBip32EntropyImplementation', () => { }); it('calls `getMnemonic` with a different entropy source', async () => { - const getMnemonic = jest - .fn() - .mockImplementation(() => TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); - const getUnlockPromise = jest.fn(); const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip32EntropyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, })({ method: 'snap_getBip32Entropy', context: { origin: MOCK_SNAP_ID }, @@ -256,27 +245,23 @@ describe('getBip32EntropyImplementation', () => { } `); - expect(getMnemonic).toHaveBeenCalledWith('source-id'); - expect(getMnemonicSeed).not.toHaveBeenCalled(); + expect(messenger.call).toHaveBeenCalledTimes(1); + expect(messenger.call).toHaveBeenCalledWith( + 'KeyringController:withKeyring', + { id: 'source-id' }, + expect.any(Function), + ); }); it('calls `getMnemonicSeed` with a different entropy source', async () => { - const getMnemonic = jest - .fn() - .mockImplementation(() => TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); - const getUnlockPromise = jest.fn(); const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip32EntropyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, })({ method: 'snap_getBip32Entropy', context: { origin: MOCK_SNAP_ID }, @@ -300,18 +285,16 @@ describe('getBip32EntropyImplementation', () => { } `); - expect(getMnemonicSeed).toHaveBeenCalledWith('source-id'); - expect(getMnemonic).not.toHaveBeenCalled(); + expect(messenger.call).toHaveBeenCalledTimes(1); + expect(messenger.call).toHaveBeenCalledWith( + 'KeyringController:withKeyring', + { id: 'source-id' }, + expect.any(Function), + ); }); it('uses custom client cryptography functions', async () => { const getUnlockPromise = jest.fn().mockResolvedValue(undefined); - const getMnemonic = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); const hmacSha512 = jest .fn() @@ -321,13 +304,12 @@ describe('getBip32EntropyImplementation', () => { const getClientCryptography = jest.fn().mockReturnValue({ hmacSha512, }); + const messenger = getMessenger(); expect( await getBip32EntropyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, // @ts-expect-error Missing other required properties. })({ params: { path: ['m', "44'", "1'"], curve: 'secp256k1' }, @@ -348,5 +330,133 @@ describe('getBip32EntropyImplementation', () => { expect(hmacSha512).toHaveBeenCalledTimes(3); }); + + it('throws if invalid primary keyring is returned', async () => { + const getUnlockPromise = jest.fn().mockResolvedValue(undefined); + const getClientCryptography = jest.fn().mockReturnValue({}); + + const messenger = new Messenger< + MockAnyNamespace, + GetBip32EntropyMessengerActions + >({ namespace: MOCK_ANY_NAMESPACE }); + + messenger.registerActionHandler( + 'KeyringController:withKeyring', + async (_selector, operation) => + operation({ + keyring: { + type: 'HD Key Tree', + }, + }), + ); + + await expect( + getBip32EntropyImplementation({ + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, + // @ts-expect-error Missing other required properties. + })({ + params: { path: ['m', "44'", "1'"], curve: 'secp256k1' }, + }), + ).rejects.toThrow('Primary keyring mnemonic unavailable.'); + }); + + it('throws if invalid primary keyring is returned for ed25519Bip32', async () => { + const getUnlockPromise = jest.fn().mockResolvedValue(undefined); + const getClientCryptography = jest.fn().mockReturnValue({}); + + const messenger = new Messenger< + MockAnyNamespace, + GetBip32EntropyMessengerActions + >({ namespace: MOCK_ANY_NAMESPACE }); + + messenger.registerActionHandler( + 'KeyringController:withKeyring', + async (_selector, operation) => + operation({ + keyring: { + type: 'HD Key Tree', + }, + }), + ); + + await expect( + getBip32EntropyImplementation({ + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, + // @ts-expect-error Missing other required properties. + })({ + params: { path: ['m', "44'", "1'"], curve: 'ed25519Bip32' }, + }), + ).rejects.toThrow('Primary keyring mnemonic unavailable.'); + }); + + it('throws if invalid keyring is returned when selected using entropy source ID', async () => { + const getUnlockPromise = jest.fn().mockResolvedValue(undefined); + const getClientCryptography = jest.fn().mockReturnValue({}); + + const messenger = new Messenger< + MockAnyNamespace, + GetBip32EntropyMessengerActions + >({ namespace: MOCK_ANY_NAMESPACE }); + + messenger.registerActionHandler( + 'KeyringController:withKeyring', + async (_selector, operation) => + operation({ + keyring: { + type: 'HD Key Tree', + }, + }), + ); + + await expect( + getBip32EntropyImplementation({ + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, + // @ts-expect-error Missing other required properties. + })({ + params: { + path: ['m', "44'", "1'"], + curve: 'secp256k1', + source: 'foo', + }, + }), + ).rejects.toThrow('Entropy source with ID "foo" not found.'); + }); + + it('throws if invalid keyring is returned when selected using entropy source ID for ed25519Bip32', async () => { + const getUnlockPromise = jest.fn().mockResolvedValue(undefined); + const getClientCryptography = jest.fn().mockReturnValue({}); + + const messenger = new Messenger< + MockAnyNamespace, + GetBip32EntropyMessengerActions + >({ namespace: MOCK_ANY_NAMESPACE }); + + messenger.registerActionHandler( + 'KeyringController:withKeyring', + async (_selector, operation) => + operation({ + keyring: { + type: 'HD Key Tree', + }, + }), + ); + + await expect( + getBip32EntropyImplementation({ + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, + // @ts-expect-error Missing other required properties. + })({ + params: { + path: ['m', "44'", "1'"], + curve: 'ed25519Bip32', + source: 'foo', + }, + }), + ).rejects.toThrow('Entropy source with ID "foo" not found.'); + }); }); }); diff --git a/packages/snaps-rpc-methods/src/restricted/getBip32Entropy.ts b/packages/snaps-rpc-methods/src/restricted/getBip32Entropy.ts index ae16964e7b..a55ca235c1 100644 --- a/packages/snaps-rpc-methods/src/restricted/getBip32Entropy.ts +++ b/packages/snaps-rpc-methods/src/restricted/getBip32Entropy.ts @@ -1,4 +1,5 @@ import type { CryptographicFunctions } from '@metamask/key-tree'; +import type { Messenger } from '@metamask/messenger'; import type { PermissionSpecificationBuilder, PermissionValidatorConstraint, @@ -15,8 +16,11 @@ import { SnapCaveatType } from '@metamask/snaps-utils'; import type { NonEmptyArray } from '@metamask/utils'; import { assert } from '@metamask/utils'; +import type { KeyringControllerWithKeyringAction } from '../types'; import type { MethodHooksObject } from '../utils'; import { + getMnemonic, + getMnemonicSeed, getNodeFromMnemonic, getNodeFromSeed, getValueFromEntropySource, @@ -25,26 +29,6 @@ import { const targetName = 'snap_getBip32Entropy'; export type GetBip32EntropyMethodHooks = { - /** - * Get the mnemonic of the provided source. If no source is provided, the - * mnemonic of the primary keyring will be returned. - * - * @param source - The optional ID of the source to get the mnemonic of. - * @returns The mnemonic of the provided source, or the default source if no - * source is provided. - */ - getMnemonic: (source?: string | undefined) => Promise; - - /** - * Get the mnemonic seed of the provided source. If no source is provided, the - * mnemonic seed of the primary keyring will be returned. - * - * @param source - The optional ID of the source to get the mnemonic of. - * @returns The mnemonic seed of the provided source, or the default source if no - * source is provided. - */ - getMnemonicSeed: (source?: string | undefined) => Promise; - /** * Waits for the extension to be unlocked. * @@ -62,8 +46,12 @@ export type GetBip32EntropyMethodHooks = { getClientCryptography: () => CryptographicFunctions | undefined; }; +export type GetBip32EntropyMessengerActions = + KeyringControllerWithKeyringAction; + type GetBip32EntropySpecificationBuilderOptions = { methodHooks: GetBip32EntropyMethodHooks; + messenger: Messenger; }; type GetBip32EntropySpecification = ValidPermissionSpecification<{ @@ -80,6 +68,7 @@ type GetBip32EntropySpecification = ValidPermissionSpecification<{ * BIP-32 node. * * @param options - The specification builder options. + * @param options.messenger - The messenger. * @param options.methodHooks - The RPC method hooks needed by the method implementation. * @returns The specification for the `snap_getBip32Entropy` permission. */ @@ -87,12 +76,18 @@ const specificationBuilder: PermissionSpecificationBuilder< PermissionType.RestrictedMethod, GetBip32EntropySpecificationBuilderOptions, GetBip32EntropySpecification -> = ({ methodHooks }: GetBip32EntropySpecificationBuilderOptions) => { +> = ({ + methodHooks, + messenger, +}: GetBip32EntropySpecificationBuilderOptions) => { return { permissionType: PermissionType.RestrictedMethod, targetName, allowedCaveats: [SnapCaveatType.PermittedDerivationPaths], - methodImplementation: getBip32EntropyImplementation(methodHooks), + methodImplementation: getBip32EntropyImplementation({ + methodHooks, + messenger, + }), validator: ({ caveats }) => { if ( caveats?.length !== 1 || @@ -108,8 +103,6 @@ const specificationBuilder: PermissionSpecificationBuilder< }; const methodHooks: MethodHooksObject = { - getMnemonic: true, - getMnemonicSeed: true, getUnlockPromise: true, getClientCryptography: true, }; @@ -173,27 +166,26 @@ export const getBip32EntropyBuilder = Object.freeze({ targetName, specificationBuilder, methodHooks, + actionNames: ['KeyringController:withKeyring'], } as const); /** * Builds the method implementation for `snap_getBip32Entropy`. * - * @param hooks - The RPC method hooks. - * @param hooks.getMnemonic - A function to retrieve the Secret Recovery Phrase of the user. - * @param hooks.getMnemonicSeed - A function to retrieve the BIP-39 seed of the user. - * @param hooks.getUnlockPromise - A function that resolves once the MetaMask extension is unlocked + * @param options - The options. + * @param options.messenger - The messenger. + * @param options.methodHooks - The RPC method hooks. + * @param options.methodHooks.getUnlockPromise - A function that resolves once the MetaMask extension is unlocked * and prompts the user to unlock their MetaMask if it is locked. - * @param hooks.getClientCryptography - A function to retrieve the cryptographic + * @param options.methodHooks.getClientCryptography - A function to retrieve the cryptographic * functions to use for the client. * @returns The method implementation which returns a `JsonSLIP10Node`. * @throws If the params are invalid. */ export function getBip32EntropyImplementation({ - getMnemonic, - getMnemonicSeed, - getUnlockPromise, - getClientCryptography, -}: GetBip32EntropyMethodHooks) { + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, +}: GetBip32EntropySpecificationBuilderOptions) { return async function getBip32Entropy( args: RestrictedMethodOptions, ): Promise { @@ -205,7 +197,7 @@ export function getBip32EntropyImplementation({ // Using the seed is much faster, but we can only do it for these specific curves. if (params.curve === 'secp256k1' || params.curve === 'ed25519') { const seed = await getValueFromEntropySource( - getMnemonicSeed, + getMnemonicSeed.bind(null, messenger), params.source, ); @@ -220,7 +212,7 @@ export function getBip32EntropyImplementation({ } const secretRecoveryPhrase = await getValueFromEntropySource( - getMnemonic, + getMnemonic.bind(null, messenger), params.source, ); diff --git a/packages/snaps-rpc-methods/src/restricted/getBip32PublicKey.test.ts b/packages/snaps-rpc-methods/src/restricted/getBip32PublicKey.test.ts index fef25b3e89..20b217c59b 100644 --- a/packages/snaps-rpc-methods/src/restricted/getBip32PublicKey.test.ts +++ b/packages/snaps-rpc-methods/src/restricted/getBip32PublicKey.test.ts @@ -1,3 +1,5 @@ +import type { MockAnyNamespace } from '@metamask/messenger'; +import { MOCK_ANY_NAMESPACE, Messenger } from '@metamask/messenger'; import { PermissionType, SubjectType } from '@metamask/permission-controller'; import { SnapCaveatType } from '@metamask/snaps-utils'; import { @@ -8,6 +10,7 @@ import { import { hmac } from '@noble/hashes/hmac'; import { sha512 } from '@noble/hashes/sha512'; +import type { GetBip32PublicKeyMessengerActions } from './getBip32PublicKey'; import { getBip32PublicKeyBuilder, getBip32PublicKeyImplementation, @@ -15,14 +18,13 @@ import { describe('specificationBuilder', () => { const methodHooks = { - getMnemonic: jest.fn(), - getMnemonicSeed: jest.fn(), getUnlockPromise: jest.fn(), getClientCryptography: jest.fn(), }; const specification = getBip32PublicKeyBuilder.specificationBuilder({ methodHooks, + messenger: new Messenger({ namespace: 'GetBip32PublicKey' }), }); it('outputs expected specification', () => { @@ -64,23 +66,41 @@ describe('specificationBuilder', () => { }); describe('getBip32PublicKeyImplementation', () => { + const getMessenger = () => { + const messenger = new Messenger< + MockAnyNamespace, + GetBip32PublicKeyMessengerActions + >({ + namespace: MOCK_ANY_NAMESPACE, + }); + + messenger.registerActionHandler( + 'KeyringController:withKeyring', + async (_selector, operation) => + operation({ + keyring: { + type: 'HD Key Tree', + mnemonic: TEST_SECRET_RECOVERY_PHRASE_BYTES, + seed: TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES, + }, + }), + ); + + jest.spyOn(messenger, 'call'); + + return messenger; + }; + describe('getBip32PublicKey', () => { it('derives the public key from the path', async () => { const getUnlockPromise = jest.fn().mockResolvedValue(undefined); - const getMnemonic = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip32PublicKeyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, // @ts-expect-error Missing other required properties. })({ params: { @@ -95,20 +115,13 @@ describe('getBip32PublicKeyImplementation', () => { it('derives the ed25519 public key from the path', async () => { const getUnlockPromise = jest.fn().mockResolvedValue(undefined); - const getMnemonic = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip32PublicKeyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, // @ts-expect-error Missing other required properties. })({ params: { @@ -123,20 +136,13 @@ describe('getBip32PublicKeyImplementation', () => { it('derives the ed25519Bip32 public key from the path', async () => { const getUnlockPromise = jest.fn().mockResolvedValue(undefined); - const getMnemonic = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip32PublicKeyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, // @ts-expect-error Missing other required properties. })({ params: { @@ -151,20 +157,13 @@ describe('getBip32PublicKeyImplementation', () => { it('derives the compressed public key from the path', async () => { const getUnlockPromise = jest.fn().mockResolvedValue(undefined); - const getMnemonic = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip32PublicKeyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, // @ts-expect-error Missing other required properties. })({ params: { @@ -178,23 +177,37 @@ describe('getBip32PublicKeyImplementation', () => { ); }); - it('calls `getMnemonic` with a different entropy source', async () => { - const getMnemonic = jest - .fn() - .mockImplementation(() => TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); + it('derives the compressed public key from the path using ed25519Bip32', async () => { + const getUnlockPromise = jest.fn().mockResolvedValue(undefined); + const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); + + expect( + await getBip32PublicKeyImplementation({ + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, + // @ts-expect-error Missing other required properties. + })({ + params: { + path: ['m', "44'", "1'", '1', '2', '3'], + curve: 'ed25519Bip32', + compressed: true, + }, + }), + ).toMatchInlineSnapshot( + `"0x03303da49ddfafc90587b7559eacdd5523028e75be81f2a9f158733fee1211a6"`, + ); + }); + it('calls `getMnemonic` with a different entropy source', async () => { const getUnlockPromise = jest.fn(); const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip32PublicKeyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, })({ method: 'snap_getBip32PublicKey', context: { origin: MOCK_SNAP_ID }, @@ -208,27 +221,23 @@ describe('getBip32PublicKeyImplementation', () => { `"0x03303da49ddfafc90587b7559eacdd5523028e75be81f2a9f158733fee1211a6"`, ); - expect(getMnemonic).toHaveBeenCalledWith('source-id'); - expect(getMnemonicSeed).not.toHaveBeenCalled(); + expect(messenger.call).toHaveBeenCalledTimes(1); + expect(messenger.call).toHaveBeenCalledWith( + 'KeyringController:withKeyring', + { id: 'source-id' }, + expect.any(Function), + ); }); it('calls `getMnemonicSeed` with a different entropy source', async () => { - const getMnemonic = jest - .fn() - .mockImplementation(() => TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); - const getUnlockPromise = jest.fn(); const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip32PublicKeyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, })({ method: 'snap_getBip32PublicKey', context: { origin: MOCK_SNAP_ID }, @@ -242,18 +251,16 @@ describe('getBip32PublicKeyImplementation', () => { `"0x042de17487a660993177ce2a85bb73b6cd9ad436184d57bdf5a93f5db430bea914f7c31d378fe68f4723b297a04e49ef55fbf490605c4a3f9ca947a4af4f06526a"`, ); - expect(getMnemonicSeed).toHaveBeenCalledWith('source-id'); - expect(getMnemonic).not.toHaveBeenCalled(); + expect(messenger.call).toHaveBeenCalledTimes(1); + expect(messenger.call).toHaveBeenCalledWith( + 'KeyringController:withKeyring', + { id: 'source-id' }, + expect.any(Function), + ); }); it('uses custom client cryptography functions', async () => { const getUnlockPromise = jest.fn().mockResolvedValue(undefined); - const getMnemonic = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_BYTES); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); const hmacSha512 = jest .fn() @@ -263,13 +270,12 @@ describe('getBip32PublicKeyImplementation', () => { const getClientCryptography = jest.fn().mockReturnValue({ hmacSha512, }); + const messenger = getMessenger(); expect( await getBip32PublicKeyImplementation({ - getUnlockPromise, - getMnemonic, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, // @ts-expect-error Missing other required properties. })({ params: { diff --git a/packages/snaps-rpc-methods/src/restricted/getBip32PublicKey.ts b/packages/snaps-rpc-methods/src/restricted/getBip32PublicKey.ts index 45ab805657..530641859e 100644 --- a/packages/snaps-rpc-methods/src/restricted/getBip32PublicKey.ts +++ b/packages/snaps-rpc-methods/src/restricted/getBip32PublicKey.ts @@ -1,4 +1,5 @@ import type { CryptographicFunctions } from '@metamask/key-tree'; +import type { Messenger } from '@metamask/messenger'; import type { PermissionSpecificationBuilder, PermissionValidatorConstraint, @@ -21,36 +22,19 @@ import { boolean, object, optional, string } from '@metamask/superstruct'; import type { NonEmptyArray } from '@metamask/utils'; import { assertStruct } from '@metamask/utils'; +import type { KeyringControllerWithKeyringAction } from '../types'; import type { MethodHooksObject } from '../utils'; import { - getValueFromEntropySource, + getMnemonic, + getMnemonicSeed, getNodeFromMnemonic, getNodeFromSeed, + getValueFromEntropySource, } from '../utils'; const targetName = 'snap_getBip32PublicKey'; export type GetBip32PublicKeyMethodHooks = { - /** - * Get the mnemonic of the provided source. If no source is provided, the - * mnemonic of the primary keyring will be returned. - * - * @param source - The optional ID of the source to get the mnemonic of. - * @returns The mnemonic of the provided source, or the default source if no - * source is provided. - */ - getMnemonic: (source?: string | undefined) => Promise; - - /** - * Get the mnemonic seed of the provided source. If no source is provided, the - * mnemonic seed of the primary keyring will be returned. - * - * @param source - The optional ID of the source to get the mnemonic of. - * @returns The mnemonic seed of the provided source, or the default source if no - * source is provided. - */ - getMnemonicSeed: (source?: string | undefined) => Promise; - /** * Waits for the extension to be unlocked. * @@ -68,8 +52,12 @@ export type GetBip32PublicKeyMethodHooks = { getClientCryptography: () => CryptographicFunctions | undefined; }; +export type GetBip32PublicKeyMessengerActions = + KeyringControllerWithKeyringAction; + type GetBip32PublicKeySpecificationBuilderOptions = { methodHooks: GetBip32PublicKeyMethodHooks; + messenger: Messenger; }; type GetBip32PublicKeySpecification = ValidPermissionSpecification<{ @@ -95,6 +83,7 @@ export const Bip32PublicKeyArgsStruct = bip32entropy( * BIP-32 node. * * @param options - The specification builder options. + * @param options.messenger - The messenger. * @param options.methodHooks - The RPC method hooks needed by the method implementation. * @returns The specification for the `snap_getBip32PublicKey` permission. */ @@ -102,12 +91,18 @@ const specificationBuilder: PermissionSpecificationBuilder< PermissionType.RestrictedMethod, GetBip32PublicKeySpecificationBuilderOptions, GetBip32PublicKeySpecification -> = ({ methodHooks }: GetBip32PublicKeySpecificationBuilderOptions) => { +> = ({ + methodHooks, + messenger, +}: GetBip32PublicKeySpecificationBuilderOptions) => { return { permissionType: PermissionType.RestrictedMethod, targetName, allowedCaveats: [SnapCaveatType.PermittedDerivationPaths], - methodImplementation: getBip32PublicKeyImplementation(methodHooks), + methodImplementation: getBip32PublicKeyImplementation({ + methodHooks, + messenger, + }), validator: ({ caveats }) => { if ( caveats?.length !== 1 || @@ -123,8 +118,6 @@ const specificationBuilder: PermissionSpecificationBuilder< }; const methodHooks: MethodHooksObject = { - getMnemonic: true, - getMnemonicSeed: true, getUnlockPromise: true, getClientCryptography: true, }; @@ -169,27 +162,26 @@ export const getBip32PublicKeyBuilder = Object.freeze({ targetName, specificationBuilder, methodHooks, + actionNames: ['KeyringController:withKeyring'], } as const); /** * Builds the method implementation for `snap_getBip32PublicKey`. * - * @param hooks - The RPC method hooks. - * @param hooks.getMnemonic - A function to retrieve the Secret Recovery Phrase of the user. - * @param hooks.getMnemonicSeed - A function to retrieve the BIP-39 seed of the user. - * @param hooks.getUnlockPromise - A function that resolves once the MetaMask extension is unlocked + * @param options - The options. + * @param options.messenger - The messenger. + * @param options.methodHooks - The RPC method hooks. + * @param options.methodHooks.getUnlockPromise - A function that resolves once the MetaMask extension is unlocked * and prompts the user to unlock their MetaMask if it is locked. - * @param hooks.getClientCryptography - A function to retrieve the cryptographic + * @param options.methodHooks.getClientCryptography - A function to retrieve the cryptographic * functions to use for the client. * @returns The method implementation which returns a public key. * @throws If the params are invalid. */ export function getBip32PublicKeyImplementation({ - getMnemonic, - getMnemonicSeed, - getUnlockPromise, - getClientCryptography, -}: GetBip32PublicKeyMethodHooks) { + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, +}: GetBip32PublicKeySpecificationBuilderOptions) { return async function getBip32PublicKey( args: RestrictedMethodOptions, ): Promise { @@ -207,7 +199,7 @@ export function getBip32PublicKeyImplementation({ // Using the seed is much faster, but we can only do it for these specific curves. if (params.curve === 'secp256k1' || params.curve === 'ed25519') { const seed = await getValueFromEntropySource( - getMnemonicSeed, + getMnemonicSeed.bind(null, messenger), params.source, ); @@ -226,7 +218,7 @@ export function getBip32PublicKeyImplementation({ } const secretRecoveryPhrase = await getValueFromEntropySource( - getMnemonic, + getMnemonic.bind(null, messenger), params.source, ); diff --git a/packages/snaps-rpc-methods/src/restricted/getBip44Entropy.test.ts b/packages/snaps-rpc-methods/src/restricted/getBip44Entropy.test.ts index bfcd794948..b2022946a2 100644 --- a/packages/snaps-rpc-methods/src/restricted/getBip44Entropy.test.ts +++ b/packages/snaps-rpc-methods/src/restricted/getBip44Entropy.test.ts @@ -1,3 +1,5 @@ +import type { MockAnyNamespace } from '@metamask/messenger'; +import { MOCK_ANY_NAMESPACE, Messenger } from '@metamask/messenger'; import { SubjectType, PermissionType } from '@metamask/permission-controller'; import { SnapCaveatType } from '@metamask/snaps-utils'; import { @@ -7,6 +9,7 @@ import { import { hmac } from '@noble/hashes/hmac'; import { sha512 } from '@noble/hashes/sha512'; +import type { GetBip44EntropyMessengerActions } from './getBip44Entropy'; import { getBip44EntropyBuilder, getBip44EntropyImplementation, @@ -14,13 +17,13 @@ import { describe('specificationBuilder', () => { const methodHooks = { - getMnemonicSeed: jest.fn(), getUnlockPromise: jest.fn(), getClientCryptography: jest.fn(), }; const specification = getBip44EntropyBuilder.specificationBuilder({ methodHooks, + messenger: new Messenger({ namespace: 'GetBip44Entropy' }), }); it('outputs expected specification', () => { @@ -62,19 +65,40 @@ describe('specificationBuilder', () => { }); describe('getBip44EntropyImplementation', () => { + const getMessenger = () => { + const messenger = new Messenger< + MockAnyNamespace, + GetBip44EntropyMessengerActions + >({ + namespace: MOCK_ANY_NAMESPACE, + }); + + messenger.registerActionHandler( + 'KeyringController:withKeyring', + async (_selector, operation) => + operation({ + keyring: { + type: 'HD Key Tree', + seed: TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES, + }, + }), + ); + + jest.spyOn(messenger, 'call'); + + return messenger; + }; + describe('getBip44Entropy', () => { it('derives the entropy from the path', async () => { const getUnlockPromise = jest.fn().mockResolvedValue(undefined); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip44EntropyImplementation({ - getUnlockPromise, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, // @ts-expect-error Missing other required properties. })({ params: { coinType: 1 }, @@ -96,18 +120,14 @@ describe('getBip44EntropyImplementation', () => { }); it('calls `getMnemonic` with a different entropy source', async () => { - const getMnemonicSeed = jest - .fn() - .mockImplementation(() => TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); - const getUnlockPromise = jest.fn(); const getClientCryptography = jest.fn().mockReturnValue({}); + const messenger = getMessenger(); expect( await getBip44EntropyImplementation({ - getUnlockPromise, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, })({ method: 'snap_getBip44Entropy', context: { origin: MOCK_SNAP_ID }, @@ -128,14 +148,16 @@ describe('getBip44EntropyImplementation', () => { } `); - expect(getMnemonicSeed).toHaveBeenCalledWith('source-id'); + expect(messenger.call).toHaveBeenCalledTimes(1); + expect(messenger.call).toHaveBeenCalledWith( + 'KeyringController:withKeyring', + { id: 'source-id' }, + expect.any(Function), + ); }); it('uses custom client cryptography functions', async () => { const getUnlockPromise = jest.fn().mockResolvedValue(undefined); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); const hmacSha512 = jest .fn() @@ -145,12 +167,12 @@ describe('getBip44EntropyImplementation', () => { const getClientCryptography = jest.fn().mockReturnValue({ hmacSha512, }); + const messenger = getMessenger(); expect( await getBip44EntropyImplementation({ - getUnlockPromise, - getMnemonicSeed, - getClientCryptography, + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, // @ts-expect-error Missing other required properties. })({ params: { coinType: 1 }, diff --git a/packages/snaps-rpc-methods/src/restricted/getBip44Entropy.ts b/packages/snaps-rpc-methods/src/restricted/getBip44Entropy.ts index f7ec1beb43..e668cf811f 100644 --- a/packages/snaps-rpc-methods/src/restricted/getBip44Entropy.ts +++ b/packages/snaps-rpc-methods/src/restricted/getBip44Entropy.ts @@ -1,5 +1,6 @@ import type { CryptographicFunctions } from '@metamask/key-tree'; import { BIP44CoinTypeNode } from '@metamask/key-tree'; +import type { Messenger } from '@metamask/messenger'; import type { PermissionSpecificationBuilder, PermissionValidatorConstraint, @@ -15,22 +16,13 @@ import type { import { SnapCaveatType } from '@metamask/snaps-utils'; import type { NonEmptyArray } from '@metamask/utils'; +import type { KeyringControllerWithKeyringAction } from '../types'; import type { MethodHooksObject } from '../utils'; -import { getValueFromEntropySource } from '../utils'; +import { getMnemonicSeed, getValueFromEntropySource } from '../utils'; const targetName = 'snap_getBip44Entropy'; export type GetBip44EntropyMethodHooks = { - /** - * Get the mnemonic seed of the provided source. If no source is provided, the - * mnemonic seed of the primary keyring will be returned. - * - * @param source - The optional ID of the source to get the mnemonic of. - * @returns The mnemonic seed of the provided source, or the default source if no - * source is provided. - */ - getMnemonicSeed: (source?: string | undefined) => Promise; - /** * Waits for the extension to be unlocked. * @@ -48,8 +40,12 @@ export type GetBip44EntropyMethodHooks = { getClientCryptography: () => CryptographicFunctions | undefined; }; +export type GetBip44EntropyMessengerActions = + KeyringControllerWithKeyringAction; + type GetBip44EntropySpecificationBuilderOptions = { methodHooks: GetBip44EntropyMethodHooks; + messenger: Messenger; }; type GetBip44EntropySpecification = ValidPermissionSpecification<{ @@ -66,6 +62,7 @@ type GetBip44EntropySpecification = ValidPermissionSpecification<{ * BIP-32 coin type. * * @param options - The specification builder options. + * @param options.messenger - The messenger. * @param options.methodHooks - The RPC method hooks needed by the method * implementation. * @returns The specification for the `snap_getBip44Entropy` permission. @@ -74,12 +71,18 @@ const specificationBuilder: PermissionSpecificationBuilder< PermissionType.RestrictedMethod, GetBip44EntropySpecificationBuilderOptions, GetBip44EntropySpecification -> = ({ methodHooks }: GetBip44EntropySpecificationBuilderOptions) => { +> = ({ + methodHooks, + messenger, +}: GetBip44EntropySpecificationBuilderOptions) => { return { permissionType: PermissionType.RestrictedMethod, targetName, allowedCaveats: [SnapCaveatType.PermittedCoinTypes], - methodImplementation: getBip44EntropyImplementation(methodHooks), + methodImplementation: getBip44EntropyImplementation({ + methodHooks, + messenger, + }), validator: ({ caveats }) => { if ( caveats?.length !== 1 || @@ -95,7 +98,6 @@ const specificationBuilder: PermissionSpecificationBuilder< }; const methodHooks: MethodHooksObject = { - getMnemonicSeed: true, getUnlockPromise: true, getClientCryptography: true, }; @@ -157,27 +159,27 @@ export const getBip44EntropyBuilder = Object.freeze({ targetName, specificationBuilder, methodHooks, + actionNames: ['KeyringController:withKeyring'], } as const); /** * Builds the method implementation for `snap_getBip44Entropy`. * - * @param hooks - The RPC method hooks. - * @param hooks.getMnemonicSeed - A function to retrieve the BIP-39 seed - * of the user. - * @param hooks.getUnlockPromise - A function that resolves once the MetaMask + * @param options - The options. + * @param options.messenger - The messenger. + * @param options.methodHooks - The RPC method hooks. + * @param options.methodHooks.getUnlockPromise - A function that resolves once the MetaMask * extension is unlocked and prompts the user to unlock their MetaMask if it is * locked. - * @param hooks.getClientCryptography - A function to retrieve the cryptographic + * @param options.methodHooks.getClientCryptography - A function to retrieve the cryptographic * functions to use for the client. * @returns The method implementation which returns a `BIP44CoinTypeNode`. * @throws If the params are invalid. */ export function getBip44EntropyImplementation({ - getMnemonicSeed, - getUnlockPromise, - getClientCryptography, -}: GetBip44EntropyMethodHooks) { + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, +}: GetBip44EntropySpecificationBuilderOptions) { return async function getBip44Entropy( args: RestrictedMethodOptions, ): Promise { @@ -186,7 +188,7 @@ export function getBip44EntropyImplementation({ // `args.params` is validated by the decorator, so it's safe to assert here. const params = args.params as GetBip44EntropyParams; const seed = await getValueFromEntropySource( - getMnemonicSeed, + getMnemonicSeed.bind(null, messenger), params.source, ); diff --git a/packages/snaps-rpc-methods/src/restricted/getEntropy.test.ts b/packages/snaps-rpc-methods/src/restricted/getEntropy.test.ts index b666d5b98e..fa983809a1 100644 --- a/packages/snaps-rpc-methods/src/restricted/getEntropy.test.ts +++ b/packages/snaps-rpc-methods/src/restricted/getEntropy.test.ts @@ -1,3 +1,5 @@ +import type { MockAnyNamespace } from '@metamask/messenger'; +import { MOCK_ANY_NAMESPACE, Messenger } from '@metamask/messenger'; import { PermissionType, SubjectType } from '@metamask/permission-controller'; import { MOCK_SNAP_ID, @@ -6,6 +8,7 @@ import { import { hmac } from '@noble/hashes/hmac'; import { sha512 } from '@noble/hashes/sha512'; +import type { GetEntropyMessengerActions } from './getEntropy'; import { getEntropyBuilder } from './getEntropy'; describe('getEntropyBuilder', () => { @@ -14,22 +17,24 @@ describe('getEntropyBuilder', () => { targetName: 'snap_getEntropy', specificationBuilder: expect.any(Function), methodHooks: { - getMnemonicSeed: true, getUnlockPromise: true, getClientCryptography: true, }, + actionNames: ['KeyringController:withKeyring'], }); }); it('returns the expected specification', () => { const methodHooks = { - getMnemonicSeed: jest.fn(), getUnlockPromise: jest.fn(), getClientCryptography: jest.fn(), }; expect( - getEntropyBuilder.specificationBuilder({ methodHooks }), + getEntropyBuilder.specificationBuilder({ + methodHooks, + messenger: new Messenger({ namespace: 'GetEntropy' }), + }), ).toStrictEqual({ permissionType: PermissionType.RestrictedMethod, targetName: 'snap_getEntropy', @@ -41,22 +46,43 @@ describe('getEntropyBuilder', () => { }); describe('getEntropyImplementation', () => { - it('returns the expected result', async () => { - const getMnemonicSeed = jest - .fn() - .mockImplementation(() => TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); + const getMessenger = () => { + const messenger = new Messenger< + MockAnyNamespace, + GetEntropyMessengerActions + >({ + namespace: MOCK_ANY_NAMESPACE, + }); + + messenger.registerActionHandler( + 'KeyringController:withKeyring', + async (_selector, operation) => + operation({ + keyring: { + type: 'HD Key Tree', + seed: TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES, + }, + }), + ); + + jest.spyOn(messenger, 'call'); + return messenger; + }; + + it('returns the expected result', async () => { const getUnlockPromise = jest.fn(); const getClientCryptography = jest.fn().mockReturnValue({}); const methodHooks = { - getMnemonicSeed, getUnlockPromise, getClientCryptography, }; + const messenger = getMessenger(); const implementation = getEntropyBuilder.specificationBuilder({ methodHooks, + messenger, }).methodImplementation; const result = await implementation({ @@ -76,21 +102,18 @@ describe('getEntropyImplementation', () => { }); it('calls `getMnemonic` with a different entropy source', async () => { - const getMnemonicSeed = jest - .fn() - .mockImplementation(() => TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); - const getUnlockPromise = jest.fn(); const getClientCryptography = jest.fn().mockReturnValue({}); const methodHooks = { - getMnemonicSeed, getUnlockPromise, getClientCryptography, }; + const messenger = getMessenger(); const implementation = getEntropyBuilder.specificationBuilder({ methodHooks, + messenger, }).methodImplementation; const result = await implementation({ @@ -109,14 +132,15 @@ describe('getEntropyImplementation', () => { '0x6d8e92de419401c7da3cedd5f60ce5635b26059c2a4a8003877fec83653a4921', ); - expect(getMnemonicSeed).toHaveBeenCalledWith('source-id'); + expect(messenger.call).toHaveBeenCalledWith( + 'KeyringController:withKeyring', + { id: 'source-id' }, + expect.any(Function), + ); }); it('uses custom client cryptography functions', async () => { const getUnlockPromise = jest.fn().mockResolvedValue(undefined); - const getMnemonicSeed = jest - .fn() - .mockResolvedValue(TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES); const hmacSha512 = jest .fn() @@ -128,13 +152,14 @@ describe('getEntropyImplementation', () => { }); const methodHooks = { - getMnemonicSeed, getUnlockPromise, getClientCryptography, }; + const messenger = getMessenger(); const implementation = getEntropyBuilder.specificationBuilder({ methodHooks, + messenger, }).methodImplementation; const result = await implementation({ diff --git a/packages/snaps-rpc-methods/src/restricted/getEntropy.ts b/packages/snaps-rpc-methods/src/restricted/getEntropy.ts index b493452548..7a874aad45 100644 --- a/packages/snaps-rpc-methods/src/restricted/getEntropy.ts +++ b/packages/snaps-rpc-methods/src/restricted/getEntropy.ts @@ -1,4 +1,5 @@ import type { CryptographicFunctions } from '@metamask/key-tree'; +import type { Messenger } from '@metamask/messenger'; import type { PermissionSpecificationBuilder, RestrictedMethodOptions, @@ -13,14 +14,22 @@ import { literal, object, optional, string } from '@metamask/superstruct'; import type { NonEmptyArray } from '@metamask/utils'; import { assertStruct } from '@metamask/utils'; +import type { KeyringControllerWithKeyringAction } from '../types'; import type { MethodHooksObject } from '../utils'; -import { getValueFromEntropySource, deriveEntropyFromSeed } from '../utils'; +import { + deriveEntropyFromSeed, + getMnemonicSeed, + getValueFromEntropySource, +} from '../utils'; const targetName = 'snap_getEntropy'; +export type GetEntropyMessengerActions = KeyringControllerWithKeyringAction; + type GetEntropySpecificationBuilderOptions = { allowedCaveats?: Readonly> | null; methodHooks: GetEntropyHooks; + messenger: Messenger; }; type GetEntropySpecification = ValidPermissionSpecification<{ @@ -51,18 +60,18 @@ const specificationBuilder: PermissionSpecificationBuilder< > = ({ allowedCaveats = null, methodHooks, + messenger, }: GetEntropySpecificationBuilderOptions) => { return { permissionType: PermissionType.RestrictedMethod, targetName, allowedCaveats, - methodImplementation: getEntropyImplementation(methodHooks), + methodImplementation: getEntropyImplementation({ methodHooks, messenger }), subjectTypes: [SubjectType.Snap], }; }; const methodHooks: MethodHooksObject = { - getMnemonicSeed: true, getUnlockPromise: true, getClientCryptography: true, }; @@ -106,19 +115,10 @@ export const getEntropyBuilder = Object.freeze({ targetName, specificationBuilder, methodHooks, + actionNames: ['KeyringController:withKeyring'], } as const); export type GetEntropyHooks = { - /** - * Get the mnemonic seed of the provided source. If no source is provided, the - * mnemonic seed of the primary keyring will be returned. - * - * @param source - The optional ID of the source to get the mnemonic of. - * @returns The mnemonic seed of the provided source, or the default source if no - * source is provided. - */ - getMnemonicSeed: (source?: string | undefined) => Promise; - /** * Waits for the extension to be unlocked. * @@ -141,20 +141,19 @@ export type GetEntropyHooks = { * is based on the reference implementation of * [SIP-6](https://metamask.github.io/SIPs/SIPS/sip-6). * - * @param hooks - The RPC method hooks. - * @param hooks.getMnemonicSeed - A function to retrieve the BIP-39 seed - * of the user. - * @param hooks.getUnlockPromise - The method to get a promise that resolves + * @param options - The options. + * @param options.messenger - The messenger. + * @param options.methodHooks - The RPC method hooks. + * @param options.methodHooks.getUnlockPromise - The method to get a promise that resolves * once the extension is unlocked. - * @param hooks.getClientCryptography - A function to retrieve the cryptographic + * @param options.methodHooks.getClientCryptography - A function to retrieve the cryptographic * functions to use for the client. * @returns The method implementation. */ function getEntropyImplementation({ - getMnemonicSeed, - getUnlockPromise, - getClientCryptography, -}: GetEntropyHooks) { + methodHooks: { getUnlockPromise, getClientCryptography }, + messenger, +}: GetEntropySpecificationBuilderOptions) { return async function getEntropy( options: RestrictedMethodOptions, ): Promise { @@ -171,8 +170,9 @@ function getEntropyImplementation({ ); await getUnlockPromise(true); + const seed = await getValueFromEntropySource( - getMnemonicSeed, + getMnemonicSeed.bind(null, messenger), params.source, ); diff --git a/packages/snaps-rpc-methods/src/restricted/index.ts b/packages/snaps-rpc-methods/src/restricted/index.ts index 93722da6d8..e76281f355 100644 --- a/packages/snaps-rpc-methods/src/restricted/index.ts +++ b/packages/snaps-rpc-methods/src/restricted/index.ts @@ -1,42 +1,89 @@ -import type { DialogMethodHooks } from './dialog'; +import type { Messenger } from '@metamask/messenger'; +import type { + PermissionSpecificationBuilder, + PermissionType, + RestrictedMethodSpecificationConstraint, +} from '@metamask/permission-controller'; + +import type { DialogMessengerActions } from './dialog'; import { dialogBuilder } from './dialog'; -import type { GetBip32EntropyMethodHooks } from './getBip32Entropy'; +import type { + GetBip32EntropyMessengerActions, + GetBip32EntropyMethodHooks, +} from './getBip32Entropy'; import { getBip32EntropyBuilder } from './getBip32Entropy'; -import type { GetBip32PublicKeyMethodHooks } from './getBip32PublicKey'; +import type { + GetBip32PublicKeyMessengerActions, + GetBip32PublicKeyMethodHooks, +} from './getBip32PublicKey'; import { getBip32PublicKeyBuilder } from './getBip32PublicKey'; -import type { GetBip44EntropyMethodHooks } from './getBip44Entropy'; +import type { + GetBip44EntropyMessengerActions, + GetBip44EntropyMethodHooks, +} from './getBip44Entropy'; import { getBip44EntropyBuilder } from './getBip44Entropy'; -import type { GetEntropyHooks } from './getEntropy'; +import type { GetEntropyHooks, GetEntropyMessengerActions } from './getEntropy'; import { getEntropyBuilder } from './getEntropy'; import type { GetLocaleMethodHooks } from './getLocale'; import { getLocaleBuilder } from './getLocale'; import type { GetPreferencesMethodHooks } from './getPreferences'; import { getPreferencesBuilder } from './getPreferences'; -import type { InvokeSnapMethodHooks } from './invokeSnap'; +import type { InvokeSnapMessengerActions } from './invokeSnap'; import { invokeSnapBuilder } from './invokeSnap'; import type { ManageAccountsMethodHooks } from './manageAccounts'; import { manageAccountsBuilder } from './manageAccounts'; -import type { ManageStateMethodHooks } from './manageState'; +import type { + ManageStateMessengerActions, + ManageStateMethodHooks, +} from './manageState'; import { manageStateBuilder } from './manageState'; -import type { NotifyMethodHooks } from './notify'; +import type { NotifyMessengerActions, NotifyMethodHooks } from './notify'; import { notifyBuilder } from './notify'; +import type { MethodHooksObject } from '../utils'; export { WALLET_SNAP_PERMISSION_KEY } from './invokeSnap'; export { getEncryptionEntropy } from './manageState'; -export type RestrictedMethodHooks = DialogMethodHooks & - GetBip32EntropyMethodHooks & +export type RestrictedMethodActions = + | DialogMessengerActions + | GetBip32EntropyMessengerActions + | GetBip32PublicKeyMessengerActions + | GetBip44EntropyMessengerActions + | GetEntropyMessengerActions + | InvokeSnapMessengerActions + | ManageStateMessengerActions + | NotifyMessengerActions; + +export type RestrictedMethodMessenger = Messenger< + string, + RestrictedMethodActions +>; + +export type RestrictedMethodHooks = GetBip32EntropyMethodHooks & GetBip32PublicKeyMethodHooks & GetBip44EntropyMethodHooks & GetEntropyHooks & - InvokeSnapMethodHooks & ManageStateMethodHooks & NotifyMethodHooks & ManageAccountsMethodHooks & GetLocaleMethodHooks & GetPreferencesMethodHooks; -export const restrictedMethodPermissionBuilders = { +type RestrictedMethodPermissionBuilder = { + targetName: string; + specificationBuilder: PermissionSpecificationBuilder< + PermissionType.RestrictedMethod, + any, + RestrictedMethodSpecificationConstraint + >; + actionNames?: readonly RestrictedMethodActions['type'][]; + methodHooks?: MethodHooksObject>; +}; + +export const restrictedMethodPermissionBuilders: Record< + string, + RestrictedMethodPermissionBuilder +> = { [dialogBuilder.targetName]: dialogBuilder, [getBip32EntropyBuilder.targetName]: getBip32EntropyBuilder, [getBip32PublicKeyBuilder.targetName]: getBip32PublicKeyBuilder, diff --git a/packages/snaps-rpc-methods/src/restricted/invokeSnap.test.ts b/packages/snaps-rpc-methods/src/restricted/invokeSnap.test.ts index 44968b375a..4fe27d89c6 100644 --- a/packages/snaps-rpc-methods/src/restricted/invokeSnap.test.ts +++ b/packages/snaps-rpc-methods/src/restricted/invokeSnap.test.ts @@ -1,4 +1,5 @@ -import { Messenger } from '@metamask/messenger'; +import type { MockAnyNamespace } from '@metamask/messenger'; +import { MOCK_ANY_NAMESPACE, Messenger } from '@metamask/messenger'; import type { PermissionsRequest } from '@metamask/permission-controller'; import { PermissionType } from '@metamask/permission-controller'; import { SnapCaveatType } from '@metamask/snaps-utils'; @@ -11,6 +12,7 @@ import { } from '@metamask/snaps-utils/test-utils'; import type { + InvokeSnapMessengerActions, SnapControllerInstallSnapsAction, SnapControllerGetPermittedSnapsAction, } from './invokeSnap'; @@ -26,18 +28,14 @@ describe('builder', () => { expect(invokeSnapBuilder).toMatchObject({ targetName: WALLET_SNAP_PERMISSION_KEY, specificationBuilder: expect.any(Function), - methodHooks: { - handleSnapRpcRequest: true, - }, + actionNames: ['SnapController:handleRequest'], }); }); it('builder outputs expected specification', () => { expect( invokeSnapBuilder.specificationBuilder({ - methodHooks: { - handleSnapRpcRequest: jest.fn(), - }, + messenger: new Messenger({ namespace: 'InvokeSnap' }), }), ).toStrictEqual({ permissionType: PermissionType.RestrictedMethod, @@ -54,9 +52,7 @@ describe('builder', () => { describe('specificationBuilder', () => { const specification = invokeSnapBuilder.specificationBuilder({ - methodHooks: { - handleSnapRpcRequest: jest.fn(), - }, + messenger: new Messenger({ namespace: 'InvokeSnap' }), }); describe('validator', () => { it('throws if the caveat is not a single "snapIds"', () => { @@ -84,14 +80,27 @@ describe('specificationBuilder', () => { }); describe('implementation', () => { - const getMockHooks = () => - ({ - getSnap: jest.fn(), - handleSnapRpcRequest: jest.fn(), - }) as any; - it('calls handleSnapRpcRequest', async () => { - const hooks = getMockHooks(); - const implementation = getInvokeSnapImplementation(hooks); + const getMessenger = () => { + const messenger = new Messenger< + MockAnyNamespace, + InvokeSnapMessengerActions + >({ + namespace: MOCK_ANY_NAMESPACE, + }); + + messenger.registerActionHandler( + 'SnapController:handleRequest', + async () => null, + ); + + jest.spyOn(messenger, 'call'); + + return messenger; + }; + + it('calls SnapController:handleRequest', async () => { + const messenger = getMessenger(); + const implementation = getInvokeSnapImplementation({ messenger }); await implementation({ context: { origin: MOCK_ORIGIN }, method: WALLET_SNAP_PERMISSION_KEY, @@ -101,15 +110,18 @@ describe('implementation', () => { }, }); - expect(hooks.handleSnapRpcRequest).toHaveBeenCalledWith({ - handler: 'onRpcRequest', - origin: MOCK_ORIGIN, - request: { - method: 'hello', - params: {}, + expect(messenger.call).toHaveBeenCalledWith( + 'SnapController:handleRequest', + { + handler: 'onRpcRequest', + origin: MOCK_ORIGIN, + request: { + method: 'hello', + params: {}, + }, + snapId: MOCK_SNAP_ID, }, - snapId: MOCK_SNAP_ID, - }); + ); }); }); diff --git a/packages/snaps-rpc-methods/src/restricted/invokeSnap.ts b/packages/snaps-rpc-methods/src/restricted/invokeSnap.ts index 177d4537c7..a4feac7208 100644 --- a/packages/snaps-rpc-methods/src/restricted/invokeSnap.ts +++ b/packages/snaps-rpc-methods/src/restricted/invokeSnap.ts @@ -1,3 +1,4 @@ +import type { Messenger } from '@metamask/messenger'; import type { PermissionSpecificationBuilder, RestrictedMethodOptions, @@ -13,11 +14,10 @@ import type { RequestSnapsParams, RequestSnapsResult, } from '@metamask/snaps-sdk'; -import type { SnapRpcHookArgs } from '@metamask/snaps-utils'; import { HandlerType, SnapCaveatType } from '@metamask/snaps-utils'; import type { Json, NonEmptyArray } from '@metamask/utils'; -import type { MethodHooksObject } from '../utils'; +import type { SnapControllerHandleRequestAction } from '../types'; export const WALLET_SNAP_PERMISSION_KEY = 'wallet_snap'; @@ -39,18 +39,11 @@ type AllowedActions = | SnapControllerInstallSnapsAction | SnapControllerGetPermittedSnapsAction; -export type InvokeSnapMethodHooks = { - handleSnapRpcRequest: ({ - snapId, - origin, - handler, - request, - }: SnapRpcHookArgs & { snapId: string }) => Promise; -}; +export type InvokeSnapMessengerActions = SnapControllerHandleRequestAction; type InvokeSnapSpecificationBuilderOptions = { allowedCaveats?: Readonly> | null; - methodHooks: InvokeSnapMethodHooks; + messenger: Messenger; }; type InvokeSnapSpecification = ValidPermissionSpecification<{ @@ -109,19 +102,21 @@ export const handleSnapInstall: PermissionSideEffect< * and install it if it's not available yet. * * @param options - The specification builder options. - * @param options.methodHooks - The RPC method hooks needed by the method implementation. + * @param options.messenger - The messenger. * @returns The specification for the `wallet_snap_*` permission. */ const specificationBuilder: PermissionSpecificationBuilder< PermissionType.RestrictedMethod, InvokeSnapSpecificationBuilderOptions, InvokeSnapSpecification -> = ({ methodHooks }: InvokeSnapSpecificationBuilderOptions) => { +> = ({ messenger }: InvokeSnapSpecificationBuilderOptions) => { return { permissionType: PermissionType.RestrictedMethod, targetName: WALLET_SNAP_PERMISSION_KEY, allowedCaveats: [SnapCaveatType.SnapIds], - methodImplementation: getInvokeSnapImplementation(methodHooks), + methodImplementation: getInvokeSnapImplementation({ + messenger, + }), validator: ({ caveats }) => { if (caveats?.length !== 1 || caveats[0].type !== SnapCaveatType.SnapIds) { throw rpcErrors.invalidParams({ @@ -135,10 +130,6 @@ const specificationBuilder: PermissionSpecificationBuilder< }; }; -const methodHooks: MethodHooksObject = { - handleSnapRpcRequest: true, -}; - /** * Calls the specified JSON-RPC API method of the specified Snap. The Snap * must be installed and the dapp must have permission to communicate with the @@ -164,20 +155,20 @@ const methodHooks: MethodHooksObject = { export const invokeSnapBuilder = Object.freeze({ targetName: WALLET_SNAP_PERMISSION_KEY, specificationBuilder, - methodHooks, + actionNames: ['SnapController:handleRequest'], } as const); /** * Builds the method implementation for `wallet_snap_*`. * - * @param hooks - The RPC method hooks. - * @param hooks.handleSnapRpcRequest - A function that sends an RPC request to a snap's RPC handler or throws if that fails. - * @returns The method implementation which returns the result of `handleSnapRpcRequest`. + * @param options - The options. + * @param options.messenger - The messenger. + * @returns The method implementation which returns the result of `SnapController:handleRequest`. * @throws If the params are invalid. */ export function getInvokeSnapImplementation({ - handleSnapRpcRequest, -}: InvokeSnapMethodHooks) { + messenger, +}: InvokeSnapSpecificationBuilderOptions) { return async function invokeSnap( options: RestrictedMethodOptions, ): Promise { @@ -187,7 +178,7 @@ export function getInvokeSnapImplementation({ const { origin } = context; - return (await handleSnapRpcRequest({ + return (await messenger.call('SnapController:handleRequest', { snapId, origin, request, diff --git a/packages/snaps-rpc-methods/src/restricted/manageState.test.ts b/packages/snaps-rpc-methods/src/restricted/manageState.test.ts index 6bf6de772e..64ffe91177 100644 --- a/packages/snaps-rpc-methods/src/restricted/manageState.test.ts +++ b/packages/snaps-rpc-methods/src/restricted/manageState.test.ts @@ -1,10 +1,15 @@ +import type { MockAnyNamespace } from '@metamask/messenger'; +import { MOCK_ANY_NAMESPACE, Messenger } from '@metamask/messenger'; import { PermissionType, SubjectType } from '@metamask/permission-controller'; import { ManageStateOperation } from '@metamask/snaps-sdk'; +import type { Snap } from '@metamask/snaps-utils'; import { MOCK_SNAP_ID, TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES, } from '@metamask/snaps-utils/test-utils'; +import type { Json } from '@metamask/utils'; +import type { ManageStateMessengerActions } from './manageState'; import { getEncryptionEntropy, getManageStateImplementation, @@ -41,17 +46,14 @@ describe('snap_manageState', () => { describe('specification', () => { it('builds specification', () => { const methodHooks = { - clearSnapState: jest.fn(), - getSnapState: jest.fn(), - updateSnapState: jest.fn(), getUnlockPromise: jest.fn(), - getSnap: jest.fn(), }; expect( specificationBuilder({ allowedCaveats: null, methodHooks, + messenger: new Messenger({ namespace: 'ManageState' }), }), ).toStrictEqual({ allowedCaveats: null, @@ -64,6 +66,42 @@ describe('snap_manageState', () => { }); describe('getManageStateImplementation', () => { + const getMessenger = ({ + snap = { preinstalled: false } as Snap, + snapState = null as Record | null, + }: { + snap?: Snap | null; + snapState?: Record | null; + } = {}) => { + const messenger = new Messenger< + MockAnyNamespace, + ManageStateMessengerActions + >({ + namespace: MOCK_ANY_NAMESPACE, + }); + + messenger.registerActionHandler('SnapController:getSnap', () => snap); + + messenger.registerActionHandler( + 'SnapController:getSnapState', + async () => snapState, + ); + + messenger.registerActionHandler( + 'SnapController:clearSnapState', + () => undefined, + ); + + messenger.registerActionHandler( + 'SnapController:updateSnapState', + async () => undefined, + ); + + jest.spyOn(messenger, 'call'); + + return messenger; + }; + it('gets snap state', async () => { const mockSnapState = { some: { @@ -71,17 +109,11 @@ describe('snap_manageState', () => { }, }; - const clearSnapState = jest.fn().mockReturnValueOnce(true); - const getSnapState = jest.fn().mockReturnValueOnce(mockSnapState); - const updateSnapState = jest.fn().mockReturnValueOnce(true); - const getSnap = jest.fn().mockReturnValue({ preinstalled: false }); + const messenger = getMessenger({ snapState: mockSnapState }); const manageStateImplementation = getManageStateImplementation({ - clearSnapState, - getSnapState, - updateSnapState, - getUnlockPromise: jest.fn(), - getSnap, + methodHooks: { getUnlockPromise: jest.fn() }, + messenger, }); const result = await manageStateImplementation({ @@ -90,7 +122,11 @@ describe('snap_manageState', () => { params: { operation: ManageStateOperation.GetState }, }); - expect(getSnapState).toHaveBeenCalledWith(MOCK_SNAP_ID, true); + expect(messenger.call).toHaveBeenCalledWith( + 'SnapController:getSnapState', + MOCK_SNAP_ID, + true, + ); expect(result).toStrictEqual(mockSnapState); }); @@ -101,18 +137,12 @@ describe('snap_manageState', () => { }, }; - const clearSnapState = jest.fn().mockReturnValueOnce(true); - const getSnapState = jest.fn().mockReturnValueOnce(mockSnapState); - const updateSnapState = jest.fn().mockReturnValueOnce(true); const getUnlockPromise = jest.fn(); - const getSnap = jest.fn().mockReturnValue({ preinstalled: false }); + const messenger = getMessenger({ snapState: mockSnapState }); const manageStateImplementation = getManageStateImplementation({ - clearSnapState, - getSnapState, - updateSnapState, - getUnlockPromise, - getSnap, + methodHooks: { getUnlockPromise }, + messenger, }); const result = await manageStateImplementation({ @@ -121,23 +151,21 @@ describe('snap_manageState', () => { params: { operation: ManageStateOperation.GetState, encrypted: false }, }); - expect(getSnapState).toHaveBeenCalledWith(MOCK_SNAP_ID, false); + expect(messenger.call).toHaveBeenCalledWith( + 'SnapController:getSnapState', + MOCK_SNAP_ID, + false, + ); expect(getUnlockPromise).not.toHaveBeenCalled(); expect(result).toStrictEqual(mockSnapState); }); it('supports empty state', async () => { - const clearSnapState = jest.fn().mockReturnValueOnce(true); - const getSnapState = jest.fn().mockReturnValueOnce(null); - const updateSnapState = jest.fn().mockReturnValueOnce(true); - const getSnap = jest.fn().mockReturnValue({ preinstalled: false }); + const messenger = getMessenger({ snapState: null }); const manageStateImplementation = getManageStateImplementation({ - clearSnapState, - getSnapState, - updateSnapState, - getUnlockPromise: jest.fn(), - getSnap, + methodHooks: { getUnlockPromise: jest.fn() }, + messenger, }); const result = await manageStateImplementation({ @@ -146,23 +174,21 @@ describe('snap_manageState', () => { params: { operation: ManageStateOperation.GetState }, }); - expect(getSnapState).toHaveBeenCalledWith(MOCK_SNAP_ID, true); + expect(messenger.call).toHaveBeenCalledWith( + 'SnapController:getSnapState', + MOCK_SNAP_ID, + true, + ); expect(result).toBeNull(); }); it('clears snap state', async () => { - const clearSnapState = jest.fn().mockReturnValueOnce(true); - const getSnapState = jest.fn().mockReturnValueOnce(true); - const updateSnapState = jest.fn().mockReturnValueOnce(true); const getUnlockPromise = jest.fn(); - const getSnap = jest.fn().mockReturnValue({ preinstalled: false }); + const messenger = getMessenger(); const manageStateImplementation = getManageStateImplementation({ - clearSnapState, - getSnapState, - updateSnapState, - getUnlockPromise, - getSnap, + methodHooks: { getUnlockPromise }, + messenger, }); await manageStateImplementation({ @@ -171,23 +197,21 @@ describe('snap_manageState', () => { params: { operation: ManageStateOperation.ClearState }, }); - expect(clearSnapState).toHaveBeenCalledWith(MOCK_SNAP_ID, true); + expect(messenger.call).toHaveBeenCalledWith( + 'SnapController:clearSnapState', + MOCK_SNAP_ID, + true, + ); expect(getUnlockPromise).not.toHaveBeenCalled(); }); it('clears unencrypted snap state', async () => { - const clearSnapState = jest.fn().mockReturnValueOnce(true); - const getSnapState = jest.fn().mockReturnValueOnce(true); - const updateSnapState = jest.fn().mockReturnValueOnce(true); const getUnlockPromise = jest.fn(); - const getSnap = jest.fn().mockReturnValue({ preinstalled: false }); + const messenger = getMessenger(); const manageStateImplementation = getManageStateImplementation({ - clearSnapState, - getSnapState, - updateSnapState, - getUnlockPromise, - getSnap, + methodHooks: { getUnlockPromise }, + messenger, }); await manageStateImplementation({ @@ -199,7 +223,11 @@ describe('snap_manageState', () => { }, }); - expect(clearSnapState).toHaveBeenCalledWith(MOCK_SNAP_ID, false); + expect(messenger.call).toHaveBeenCalledWith( + 'SnapController:clearSnapState', + MOCK_SNAP_ID, + false, + ); expect(getUnlockPromise).not.toHaveBeenCalled(); }); @@ -210,17 +238,11 @@ describe('snap_manageState', () => { }, }; - const clearSnapState = jest.fn().mockReturnValueOnce(true); - const getSnapState = jest.fn().mockReturnValueOnce(true); - const updateSnapState = jest.fn().mockReturnValueOnce(true); - const getSnap = jest.fn().mockReturnValue({ preinstalled: false }); + const messenger = getMessenger(); const manageStateImplementation = getManageStateImplementation({ - clearSnapState, - getSnapState, - updateSnapState, - getUnlockPromise: jest.fn(), - getSnap, + methodHooks: { getUnlockPromise: jest.fn() }, + messenger, }); await manageStateImplementation({ @@ -232,7 +254,8 @@ describe('snap_manageState', () => { }, }); - expect(updateSnapState).toHaveBeenCalledWith( + expect(messenger.call).toHaveBeenCalledWith( + 'SnapController:updateSnapState', MOCK_SNAP_ID, mockSnapState, true, @@ -246,20 +269,12 @@ describe('snap_manageState', () => { }, }; - const clearSnapState = jest.fn().mockReturnValueOnce(true); - const getSnapState = jest - .fn() - .mockReturnValueOnce(JSON.stringify(mockSnapState)); - const updateSnapState = jest.fn().mockReturnValueOnce(true); const getUnlockPromise = jest.fn(); - const getSnap = jest.fn().mockReturnValue({ preinstalled: false }); + const messenger = getMessenger(); const manageStateImplementation = getManageStateImplementation({ - clearSnapState, - getSnapState, - updateSnapState, - getUnlockPromise, - getSnap, + methodHooks: { getUnlockPromise }, + messenger, }); await manageStateImplementation({ @@ -272,7 +287,8 @@ describe('snap_manageState', () => { }, }); - expect(updateSnapState).toHaveBeenCalledWith( + expect(messenger.call).toHaveBeenCalledWith( + 'SnapController:updateSnapState', MOCK_SNAP_ID, mockSnapState, false, @@ -287,17 +303,11 @@ describe('snap_manageState', () => { }, }; - const clearSnapState = jest.fn().mockReturnValueOnce(true); - const getSnapState = jest.fn().mockReturnValueOnce(true); - const updateSnapState = jest.fn().mockReturnValueOnce(true); - const getSnap = jest.fn().mockReturnValue({ preinstalled: false }); + const messenger = getMessenger(); const manageStateImplementation = getManageStateImplementation({ - clearSnapState, - getSnapState, - updateSnapState, - getUnlockPromise: jest.fn(), - getSnap, + methodHooks: { getUnlockPromise: jest.fn() }, + messenger, }); expect(async () => @@ -313,17 +323,11 @@ describe('snap_manageState', () => { }); it('throws an error on update if the new state is not plain object', async () => { - const clearSnapState = jest.fn().mockReturnValueOnce(true); - const getSnapState = jest.fn().mockReturnValueOnce(true); - const updateSnapState = jest.fn().mockReturnValueOnce(true); - const getSnap = jest.fn().mockReturnValue({ preinstalled: false }); + const messenger = getMessenger(); const manageStateImplementation = getManageStateImplementation({ - clearSnapState, - getSnapState, - updateSnapState, - getUnlockPromise: jest.fn(), - getSnap, + methodHooks: { getUnlockPromise: jest.fn() }, + messenger, }); const newState = (a: unknown) => { @@ -344,7 +348,8 @@ describe('snap_manageState', () => { 'Invalid snap_manageState "newState" parameter: The new state must be a plain object.', ); - expect(updateSnapState).not.toHaveBeenCalledWith( + expect(messenger.call).not.toHaveBeenCalledWith( + 'SnapController:updateSnapState', MOCK_SNAP_ID, newState, true, @@ -352,17 +357,11 @@ describe('snap_manageState', () => { }); it('throws an error on update if the new state is not valid json serializable object', async () => { - const clearSnapState = jest.fn().mockReturnValueOnce(true); - const getSnapState = jest.fn().mockReturnValueOnce(true); - const updateSnapState = jest.fn().mockReturnValueOnce(true); - const getSnap = jest.fn().mockReturnValue({ preinstalled: false }); + const messenger = getMessenger(); const manageStateImplementation = getManageStateImplementation({ - clearSnapState, - getSnapState, - updateSnapState, - getUnlockPromise: jest.fn(), - getSnap, + methodHooks: { getUnlockPromise: jest.fn() }, + messenger, }); const newState = { @@ -387,7 +386,8 @@ describe('snap_manageState', () => { 'Invalid snap_manageState "newState" parameter: The new state must be JSON serializable.', ); - expect(updateSnapState).not.toHaveBeenCalledWith( + expect(messenger.call).not.toHaveBeenCalledWith( + 'SnapController:updateSnapState', 'snap-origin', newState, true, @@ -395,17 +395,11 @@ describe('snap_manageState', () => { }); it('throws an error on update if the new state is too large', async () => { - const clearSnapState = jest.fn().mockReturnValueOnce(true); - const getSnapState = jest.fn().mockReturnValueOnce(true); - const updateSnapState = jest.fn().mockReturnValueOnce(true); - const getSnap = jest.fn().mockReturnValue({ preinstalled: false }); + const messenger = getMessenger(); const manageStateImplementation = getManageStateImplementation({ - clearSnapState, - getSnapState, - updateSnapState, - getUnlockPromise: jest.fn(), - getSnap, + methodHooks: { getUnlockPromise: jest.fn() }, + messenger, }); const newState = { @@ -425,7 +419,8 @@ describe('snap_manageState', () => { 'Invalid snap_manageState "newState" parameter: The new state must not exceed 64 MB in size.', ); - expect(updateSnapState).not.toHaveBeenCalledWith( + expect(messenger.call).not.toHaveBeenCalledWith( + 'SnapController:updateSnapState', 'snap-origin', newState, true, diff --git a/packages/snaps-rpc-methods/src/restricted/manageState.ts b/packages/snaps-rpc-methods/src/restricted/manageState.ts index bb0372c4ec..775c699fb1 100644 --- a/packages/snaps-rpc-methods/src/restricted/manageState.ts +++ b/packages/snaps-rpc-methods/src/restricted/manageState.ts @@ -1,4 +1,5 @@ import type { CryptographicFunctions } from '@metamask/key-tree'; +import type { Messenger } from '@metamask/messenger'; import type { PermissionSpecificationBuilder, RestrictedMethodOptions, @@ -8,14 +9,19 @@ import { PermissionType, SubjectType } from '@metamask/permission-controller'; import { rpcErrors } from '@metamask/rpc-errors'; import type { ManageStateParams, ManageStateResult } from '@metamask/snaps-sdk'; import { ManageStateOperation } from '@metamask/snaps-sdk'; -import type { Snap } from '@metamask/snaps-utils'; import { getJsonSizeUnsafe, STATE_ENCRYPTION_MAGIC_VALUE, } from '@metamask/snaps-utils'; -import type { Json, NonEmptyArray } from '@metamask/utils'; +import type { NonEmptyArray } from '@metamask/utils'; import { isObject, isValidJson } from '@metamask/utils'; +import type { + SnapControllerClearSnapStateAction, + SnapControllerGetSnapAction, + SnapControllerGetSnapStateAction, + SnapControllerUpdateSnapStateAction, +} from '../types'; import type { MethodHooksObject } from '../utils'; import { deriveEntropyFromSeed } from '../utils'; @@ -31,44 +37,18 @@ export type ManageStateMethodHooks = { * @returns A promise that resolves once the extension is unlocked. */ getUnlockPromise: (shouldShowUnlockRequest: boolean) => Promise; - - /** - * A function that clears the state of the requesting Snap. - */ - clearSnapState: (snapId: string, encrypted: boolean) => void; - - /** - * A function that gets the encrypted state of the requesting Snap. - * - * @returns The current state of the Snap. - */ - getSnapState: ( - snapId: string, - encrypted: boolean, - ) => Promise>; - - /** - * A function that updates the state of the requesting Snap. - * - * @param newState - The new state of the Snap. - */ - updateSnapState: ( - snapId: string, - newState: Record, - encrypted: boolean, - ) => Promise; - - /** - * Get Snap metadata. - * - * @param snapId - The ID of a Snap. - */ - getSnap: (snapId: string) => Snap | undefined; }; +export type ManageStateMessengerActions = + | SnapControllerClearSnapStateAction + | SnapControllerGetSnapAction + | SnapControllerGetSnapStateAction + | SnapControllerUpdateSnapStateAction; + type ManageStateSpecificationBuilderOptions = { allowedCaveats?: Readonly> | null; methodHooks: ManageStateMethodHooks; + messenger: Messenger; }; type ManageStateSpecification = ValidPermissionSpecification<{ @@ -85,6 +65,7 @@ type ManageStateSpecification = ValidPermissionSpecification<{ * * @param options - The specification builder options. * @param options.allowedCaveats - The optional allowed caveats for the permission. + * @param options.messenger - The messenger. * @param options.methodHooks - The RPC method hooks needed by the method implementation. * @returns The specification for the `snap_manageState` permission. */ @@ -95,22 +76,22 @@ export const specificationBuilder: PermissionSpecificationBuilder< > = ({ allowedCaveats = null, methodHooks, + messenger, }: ManageStateSpecificationBuilderOptions) => { return { permissionType: PermissionType.RestrictedMethod, targetName: methodName, allowedCaveats, - methodImplementation: getManageStateImplementation(methodHooks), + methodImplementation: getManageStateImplementation({ + methodHooks, + messenger, + }), subjectTypes: [SubjectType.Snap], }; }; const methodHooks: MethodHooksObject = { getUnlockPromise: true, - clearSnapState: true, - getSnapState: true, - updateSnapState: true, - getSnap: true, }; /** @@ -159,6 +140,12 @@ export const manageStateBuilder = Object.freeze({ targetName: methodName, specificationBuilder, methodHooks, + actionNames: [ + 'SnapController:clearSnapState', + 'SnapController:getSnap', + 'SnapController:getSnapState', + 'SnapController:updateSnapState', + ], } as const); export const STORAGE_SIZE_LIMIT = 64_000_000; // In bytes (64 MB) @@ -201,28 +188,20 @@ export async function getEncryptionEntropy({ /** * Builds the method implementation for `snap_manageState`. * - * @param hooks - The RPC method hooks. - * @param hooks.clearSnapState - A function that clears the state stored for a - * snap. - * @param hooks.getSnapState - A function that fetches the persisted decrypted - * state for a snap. - * @param hooks.updateSnapState - A function that updates the state stored for a - * snap. - * @param hooks.getUnlockPromise - A function that resolves once the MetaMask + * @param options - The options. + * @param options.messenger - The messenger. + * @param options.methodHooks - The RPC method hooks. + * @param options.methodHooks.getUnlockPromise - A function that resolves once the MetaMask * extension is unlocked and prompts the user to unlock their MetaMask if it is * locked. - * @param hooks.getSnap - The hook function to get Snap metadata. * @returns The method implementation which either returns `null` for a * successful state update/deletion or returns the decrypted state. * @throws If the params are invalid. */ export function getManageStateImplementation({ - getUnlockPromise, - clearSnapState, - getSnapState, - updateSnapState, - getSnap, -}: ManageStateMethodHooks) { + methodHooks: { getUnlockPromise }, + messenger, +}: ManageStateSpecificationBuilderOptions) { return async function manageState( options: RestrictedMethodOptions, ): Promise { @@ -233,7 +212,7 @@ export function getManageStateImplementation({ } = options; const validatedParams = getValidatedParams(params, method); - const snap = getSnap(origin); + const snap = messenger.call('SnapController:getSnap', origin); if ( !snap?.preinstalled && @@ -264,15 +243,24 @@ export function getManageStateImplementation({ switch (validatedParams.operation) { case ManageStateOperation.ClearState: - clearSnapState(origin, shouldEncrypt); + messenger.call('SnapController:clearSnapState', origin, shouldEncrypt); return null; case ManageStateOperation.GetState: { - return await getSnapState(origin, shouldEncrypt); + return await messenger.call( + 'SnapController:getSnapState', + origin, + shouldEncrypt, + ); } case ManageStateOperation.UpdateState: { - await updateSnapState(origin, validatedParams.newState, shouldEncrypt); + await messenger.call( + 'SnapController:updateSnapState', + origin, + validatedParams.newState, + shouldEncrypt, + ); return null; } diff --git a/packages/snaps-rpc-methods/src/restricted/notify.test.tsx b/packages/snaps-rpc-methods/src/restricted/notify.test.tsx index c78a28ad1a..a6f59c364f 100644 --- a/packages/snaps-rpc-methods/src/restricted/notify.test.tsx +++ b/packages/snaps-rpc-methods/src/restricted/notify.test.tsx @@ -1,7 +1,10 @@ +import type { MockAnyNamespace } from '@metamask/messenger'; +import { MOCK_ANY_NAMESPACE, Messenger } from '@metamask/messenger'; import { PermissionType, SubjectType } from '@metamask/permission-controller'; import { ContentType, NotificationType } from '@metamask/snaps-sdk'; import { Box, Text } from '@metamask/snaps-sdk/jsx'; +import type { NotifyMessengerActions } from './notify'; import { getImplementation, getValidatedParams, @@ -14,20 +17,39 @@ describe('snap_notify', () => { message: 'Some message', }; + const getMessenger = () => { + const messenger = new Messenger({ + namespace: MOCK_ANY_NAMESPACE, + }); + + messenger.registerActionHandler( + 'RateLimitController:call', + async () => true, + ); + + messenger.registerActionHandler( + 'SnapInterfaceController:createInterface', + () => 'foo', + ); + + messenger.registerActionHandler('SnapController:getSnap', () => null); + + jest.spyOn(messenger, 'call'); + + return messenger; + }; + describe('specification', () => { it('builds specification', () => { const methodHooks = { - showNativeNotification: jest.fn(), - showInAppNotification: jest.fn(), isOnPhishingList: jest.fn(), maybeUpdatePhishingList: jest.fn(), - createInterface: jest.fn(), - getSnap: jest.fn(), }; expect( specificationBuilder({ methodHooks, + messenger: new Messenger({ namespace: 'Notify' }), }), ).toStrictEqual({ allowedCaveats: null, @@ -41,20 +63,13 @@ describe('snap_notify', () => { describe('getImplementation', () => { it('shows inApp notification', async () => { - const showNativeNotification = jest.fn().mockResolvedValueOnce(true); - const showInAppNotification = jest.fn().mockResolvedValueOnce(true); const isOnPhishingList = jest.fn().mockResolvedValueOnce(false); - const getSnap = jest.fn(); const maybeUpdatePhishingList = jest.fn(); - const createInterface = jest.fn(); + const messenger = getMessenger(); const notificationImplementation = getImplementation({ - showNativeNotification, - showInAppNotification, - isOnPhishingList, - maybeUpdatePhishingList, - createInterface, - getSnap, + methodHooks: { isOnPhishingList, maybeUpdatePhishingList }, + messenger, }); await notificationImplementation({ @@ -68,27 +83,28 @@ describe('snap_notify', () => { }, }); - expect(showInAppNotification).toHaveBeenCalledWith('extension', { - type: NotificationType.InApp, - message: 'Some message', - }); + expect(messenger.call).toHaveBeenCalledWith( + 'RateLimitController:call', + 'extension', + 'showInAppNotification', + 'extension', + { + interfaceId: undefined, + message: 'Some message', + title: undefined, + footerLink: undefined, + }, + ); }); it('shows inApp notifications with a detailed view', async () => { - const showNativeNotification = jest.fn().mockResolvedValueOnce(true); - const showInAppNotification = jest.fn().mockResolvedValueOnce(true); const isOnPhishingList = jest.fn().mockResolvedValueOnce(false); const maybeUpdatePhishingList = jest.fn(); - const createInterface = jest.fn().mockResolvedValueOnce(1); - const getSnap = jest.fn(); + const messenger = getMessenger(); const notificationImplementation = getImplementation({ - showNativeNotification, - showInAppNotification, - isOnPhishingList, - maybeUpdatePhishingList, - createInterface, - getSnap, + methodHooks: { isOnPhishingList, maybeUpdatePhishingList }, + messenger, }); await notificationImplementation({ @@ -104,36 +120,36 @@ describe('snap_notify', () => { }, }); - expect(showInAppNotification).toHaveBeenCalledWith('extension', { - type: NotificationType.InApp, - message: 'Some message', - title: 'Detailed view title', - content: 1, - }); - - expect(createInterface).toHaveBeenCalledWith( + expect(messenger.call).toHaveBeenCalledWith( + 'SnapInterfaceController:createInterface', 'extension', Hello, undefined, ContentType.Notification, ); + + expect(messenger.call).toHaveBeenCalledWith( + 'RateLimitController:call', + 'extension', + 'showInAppNotification', + 'extension', + { + interfaceId: 'foo', + message: 'Some message', + title: 'Detailed view title', + footerLink: undefined, + }, + ); }); it('shows native notification', async () => { - const showNativeNotification = jest.fn().mockResolvedValueOnce(true); - const showInAppNotification = jest.fn().mockResolvedValueOnce(true); const isOnPhishingList = jest.fn().mockResolvedValueOnce(false); const maybeUpdatePhishingList = jest.fn(); - const createInterface = jest.fn(); - const getSnap = jest.fn(); + const messenger = getMessenger(); const notificationImplementation = getImplementation({ - showNativeNotification, - showInAppNotification, - isOnPhishingList, - maybeUpdatePhishingList, - createInterface, - getSnap, + methodHooks: { isOnPhishingList, maybeUpdatePhishingList }, + messenger, }); await notificationImplementation({ @@ -147,27 +163,23 @@ describe('snap_notify', () => { }, }); - expect(showNativeNotification).toHaveBeenCalledWith('extension', { - type: NotificationType.Native, - message: 'Some message', - }); + expect(messenger.call).toHaveBeenCalledWith( + 'RateLimitController:call', + 'extension', + 'showNativeNotification', + 'extension', + 'Some message', + ); }); it('accepts string notification types', async () => { - const showNativeNotification = jest.fn().mockResolvedValueOnce(true); - const showInAppNotification = jest.fn().mockResolvedValueOnce(true); const isOnPhishingList = jest.fn().mockResolvedValueOnce(false); const maybeUpdatePhishingList = jest.fn(); - const createInterface = jest.fn(); - const getSnap = jest.fn(); + const messenger = getMessenger(); const notificationImplementation = getImplementation({ - showNativeNotification, - showInAppNotification, - isOnPhishingList, - maybeUpdatePhishingList, - createInterface, - getSnap, + methodHooks: { isOnPhishingList, maybeUpdatePhishingList }, + messenger, }); await notificationImplementation({ @@ -181,27 +193,23 @@ describe('snap_notify', () => { }, }); - expect(showNativeNotification).toHaveBeenCalledWith('extension', { - type: NotificationType.Native, - message: 'Some message', - }); + expect(messenger.call).toHaveBeenCalledWith( + 'RateLimitController:call', + 'extension', + 'showNativeNotification', + 'extension', + 'Some message', + ); }); it('throws an error if the notification type is invalid', async () => { - const showNativeNotification = jest.fn().mockResolvedValueOnce(true); - const showInAppNotification = jest.fn().mockResolvedValueOnce(true); const isOnPhishingList = jest.fn().mockResolvedValueOnce(false); const maybeUpdatePhishingList = jest.fn(); - const createInterface = jest.fn(); - const getSnap = jest.fn(); + const messenger = getMessenger(); const notificationImplementation = getImplementation({ - showNativeNotification, - showInAppNotification, - isOnPhishingList, - maybeUpdatePhishingList, - createInterface, - getSnap, + methodHooks: { isOnPhishingList, maybeUpdatePhishingList }, + messenger, }); await expect( @@ -223,9 +231,9 @@ describe('snap_notify', () => { describe('getValidatedParams', () => { it('throws an error if the params is not an object', () => { const isOnPhishingList = jest.fn().mockResolvedValue(true); - expect(() => getValidatedParams([], isOnPhishingList, jest.fn())).toThrow( - 'Expected params to be a single object.', - ); + expect(() => + getValidatedParams([], isOnPhishingList, getMessenger()), + ).toThrow('Expected params to be a single object.'); }); it('throws an error if the type is missing from params object', () => { @@ -234,7 +242,7 @@ describe('snap_notify', () => { getValidatedParams( { type: undefined, message: 'Something happened.' }, isOnPhishingList, - jest.fn(), + getMessenger(), ), ).toThrow('Must specify a valid notification "type".'); }); @@ -245,7 +253,7 @@ describe('snap_notify', () => { getValidatedParams( { type: NotificationType.InApp, message: '' }, isOnPhishingList, - jest.fn(), + getMessenger(), ), ).toThrow( 'Must specify a non-empty string "message" less than 500 characters long.', @@ -258,7 +266,7 @@ describe('snap_notify', () => { getValidatedParams( { type: NotificationType.InApp, message: 123 }, isOnPhishingList, - jest.fn(), + getMessenger(), ), ).toThrow( 'Must specify a non-empty string "message" less than 500 characters long.', @@ -274,7 +282,7 @@ describe('snap_notify', () => { message: 'test'.repeat(20), }, isOnPhishingList, - jest.fn(), + getMessenger(), ), ).toThrow( 'Must specify a non-empty string "message" less than 50 characters long.', @@ -290,7 +298,7 @@ describe('snap_notify', () => { message: 'test'.repeat(150), }, isOnPhishingList, - jest.fn(), + getMessenger(), ), ).toThrow( 'Must specify a non-empty string "message" less than 500 characters long.', @@ -298,20 +306,13 @@ describe('snap_notify', () => { }); it('throws an error if a link in the `message` property is on the phishing list', async () => { - const showNativeNotification = jest.fn().mockResolvedValueOnce(true); - const showInAppNotification = jest.fn().mockResolvedValueOnce(true); const isOnPhishingList = jest.fn().mockResolvedValue(true); const maybeUpdatePhishingList = jest.fn(); - const createInterface = jest.fn(); - const getSnap = jest.fn(); + const messenger = getMessenger(); const notificationImplementation = getImplementation({ - showNativeNotification, - showInAppNotification, - isOnPhishingList, - maybeUpdatePhishingList, - createInterface, - getSnap, + methodHooks: { isOnPhishingList, maybeUpdatePhishingList }, + messenger, }); await expect( @@ -329,20 +330,13 @@ describe('snap_notify', () => { }); it('throws an error if a link in the `message` property is invalid', async () => { - const showNativeNotification = jest.fn().mockResolvedValueOnce(true); - const showInAppNotification = jest.fn().mockResolvedValueOnce(true); const isOnPhishingList = jest.fn().mockResolvedValue(true); const maybeUpdatePhishingList = jest.fn(); - const createInterface = jest.fn(); - const getSnap = jest.fn(); + const messenger = getMessenger(); const notificationImplementation = getImplementation({ - showNativeNotification, - showInAppNotification, - isOnPhishingList, - maybeUpdatePhishingList, - createInterface, - getSnap, + methodHooks: { isOnPhishingList, maybeUpdatePhishingList }, + messenger, }); await expect( @@ -362,20 +356,13 @@ describe('snap_notify', () => { }); it('throws an error if a link in the `footerLink` property is on the phishing list', async () => { - const showNativeNotification = jest.fn().mockResolvedValueOnce(true); - const showInAppNotification = jest.fn().mockResolvedValueOnce(true); const isOnPhishingList = jest.fn().mockResolvedValue(true); const maybeUpdatePhishingList = jest.fn(); - const createInterface = jest.fn(); - const getSnap = jest.fn(); + const messenger = getMessenger(); const notificationImplementation = getImplementation({ - showNativeNotification, - showInAppNotification, - isOnPhishingList, - maybeUpdatePhishingList, - createInterface, - getSnap, + methodHooks: { isOnPhishingList, maybeUpdatePhishingList }, + messenger, }); const content = ( @@ -402,20 +389,13 @@ describe('snap_notify', () => { }); it('throws an error if a link in the `footerLink` property is invalid', async () => { - const showNativeNotification = jest.fn().mockResolvedValueOnce(true); - const showInAppNotification = jest.fn().mockResolvedValueOnce(true); const isOnPhishingList = jest.fn().mockResolvedValue(true); const maybeUpdatePhishingList = jest.fn(); - const createInterface = jest.fn(); - const getSnap = jest.fn(); + const messenger = getMessenger(); const notificationImplementation = getImplementation({ - showNativeNotification, - showInAppNotification, - isOnPhishingList, - maybeUpdatePhishingList, - createInterface, - getSnap, + methodHooks: { isOnPhishingList, maybeUpdatePhishingList }, + messenger, }); const content = ( @@ -446,7 +426,7 @@ describe('snap_notify', () => { it('returns valid parameters', () => { const isNotOnPhishingList = jest.fn().mockResolvedValueOnce(false); expect( - getValidatedParams(validParams, isNotOnPhishingList, jest.fn()), + getValidatedParams(validParams, isNotOnPhishingList, getMessenger()), ).toStrictEqual(validParams); }); }); diff --git a/packages/snaps-rpc-methods/src/restricted/notify.ts b/packages/snaps-rpc-methods/src/restricted/notify.ts index bc6575ddac..97a849e940 100644 --- a/packages/snaps-rpc-methods/src/restricted/notify.ts +++ b/packages/snaps-rpc-methods/src/restricted/notify.ts @@ -1,3 +1,4 @@ +import type { Messenger } from '@metamask/messenger'; import type { PermissionSpecificationBuilder, RestrictedMethodOptions, @@ -8,7 +9,6 @@ import { rpcErrors } from '@metamask/rpc-errors'; import type { NotifyParams, NotifyResult, - InterfaceContext, ComponentOrElement, } from '@metamask/snaps-sdk'; import { @@ -24,11 +24,16 @@ import { validateLink, validateTextLinks, } from '@metamask/snaps-utils'; -import type { InferMatching, Snap } from '@metamask/snaps-utils'; +import type { InferMatching } from '@metamask/snaps-utils'; import { object, string, optional } from '@metamask/superstruct'; import type { NonEmptyArray } from '@metamask/utils'; -import { hasProperty, isObject } from '@metamask/utils'; +import { assertExhaustive, hasProperty, isObject } from '@metamask/utils'; +import type { + RateLimitControllerCallAction, + SnapControllerGetSnapAction, + SnapInterfaceControllerCreateInterfaceAction, +} from '../types'; import { type MethodHooksObject } from '../utils'; const methodName = 'snap_notify'; @@ -68,40 +73,20 @@ export type NotificationArgs = InferMatching< >; export type NotifyMethodHooks = { - /** - * @param snapId - The ID of the Snap that created the notification. - * @param args - The notification arguments. - */ - showNativeNotification: ( - snapId: string, - args: NotificationArgs, - ) => Promise; - - /** - * @param snapId - The ID of the Snap that created the notification. - * @param args - The notification arguments. - */ - showInAppNotification: ( - snapId: string, - args: NotificationArgs, - ) => Promise; - isOnPhishingList: (url: string) => boolean; maybeUpdatePhishingList: () => Promise; - - createInterface: ( - origin: string, - content: ComponentOrElement, - context?: InterfaceContext, - contentType?: ContentType, - ) => Promise; - getSnap: (snapId: string) => Snap | null; }; +export type NotifyMessengerActions = + | RateLimitControllerCallAction + | SnapControllerGetSnapAction + | SnapInterfaceControllerCreateInterfaceAction; + type SpecificationBuilderOptions = { allowedCaveats?: Readonly> | null; methodHooks: NotifyMethodHooks; + messenger: Messenger; }; type Specification = ValidPermissionSpecification<{ @@ -117,6 +102,7 @@ type Specification = ValidPermissionSpecification<{ * * @param options - The specification builder options. * @param options.allowedCaveats - The optional allowed caveats for the permission. + * @param options.messenger - The messenger. * @param options.methodHooks - The RPC method hooks needed by the method implementation. * @returns The specification for the `snap_notify` permission. */ @@ -124,23 +110,23 @@ export const specificationBuilder: PermissionSpecificationBuilder< PermissionType.RestrictedMethod, SpecificationBuilderOptions, Specification -> = ({ allowedCaveats = null, methodHooks }: SpecificationBuilderOptions) => { +> = ({ + allowedCaveats = null, + methodHooks, + messenger, +}: SpecificationBuilderOptions) => { return { permissionType: PermissionType.RestrictedMethod, targetName: methodName, allowedCaveats, - methodImplementation: getImplementation(methodHooks), + methodImplementation: getImplementation({ methodHooks, messenger }), subjectTypes: [SubjectType.Snap], }; }; const methodHooks: MethodHooksObject = { - showNativeNotification: true, - showInAppNotification: true, isOnPhishingList: true, maybeUpdatePhishingList: true, - createInterface: true, - getSnap: true, }; /** @@ -221,29 +207,28 @@ export const notifyBuilder = Object.freeze({ targetName: methodName, specificationBuilder, methodHooks, + actionNames: [ + 'RateLimitController:call', + 'SnapController:getSnap', + 'SnapInterfaceController:createInterface', + ], } as const); /** * Builds the method implementation for `snap_notify`. * - * @param hooks - The RPC method hooks. - * @param hooks.showNativeNotification - A function that shows a native browser notification. - * @param hooks.showInAppNotification - A function that shows a notification in the MetaMask UI. - * @param hooks.isOnPhishingList - A function that checks for links against the phishing list. - * @param hooks.maybeUpdatePhishingList - A function that updates the phishing list if needed. - * @param hooks.createInterface - A function that creates the interface in SnapInterfaceController. - * @param hooks.getSnap - A function that checks if a snap is installed. + * @param options - The options. + * @param options.messenger - The messenger. + * @param options.methodHooks - The RPC method hooks. + * @param options.methodHooks.isOnPhishingList - A function that checks for links against the phishing list. + * @param options.methodHooks.maybeUpdatePhishingList - A function that updates the phishing list if needed. * @returns The method implementation which returns `null` on success. * @throws If the params are invalid. */ export function getImplementation({ - showNativeNotification, - showInAppNotification, - isOnPhishingList, - maybeUpdatePhishingList, - createInterface, - getSnap, -}: NotifyMethodHooks) { + methodHooks: { isOnPhishingList, maybeUpdatePhishingList }, + messenger, +}: SpecificationBuilderOptions) { return async function implementation( args: RestrictedMethodOptions, ): Promise { @@ -257,11 +242,12 @@ export function getImplementation({ const validatedParams = getValidatedParams( params, isOnPhishingList, - getSnap, + messenger, ); if (hasProperty(validatedParams, 'content')) { - const id = await createInterface( + const id = messenger.call( + 'SnapInterfaceController:createInterface', origin, validatedParams.content as ComponentOrElement, undefined, @@ -272,13 +258,36 @@ export function getImplementation({ switch (validatedParams.type) { case NotificationType.Native: - return await showNativeNotification(origin, validatedParams); - case NotificationType.InApp: - return await showInAppNotification(origin, validatedParams); + return (await messenger.call( + 'RateLimitController:call', + origin, + 'showNativeNotification', + origin, + validatedParams.message, + )) as NotifyResult; + case NotificationType.InApp: { + const { content, message, title, footerLink } = + validatedParams as NotificationArgs & { + content?: string; + title?: string; + footerLink?: { href: string; text: string }; + }; + return (await messenger.call( + 'RateLimitController:call', + origin, + 'showInAppNotification', + origin, + { + interfaceId: content, + message, + title, + footerLink, + }, + )) as NotifyResult; + } + /* istanbul ignore next */ default: - throw rpcErrors.invalidParams({ - message: 'Must specify a valid notification "type".', - }); + return assertExhaustive(validatedParams.type as never); } }; } @@ -289,14 +298,14 @@ export function getImplementation({ * * @param params - The unvalidated params object from the method request. * @param isOnPhishingList - The function that checks for links against the phishing list. - * @param getSnap - A function that checks if a snap is installed. + * @param messenger - The messenger. * @returns The validated method parameter object. * @throws If the params are invalid. */ export function getValidatedParams( params: unknown, isOnPhishingList: NotifyMethodHooks['isOnPhishingList'], - getSnap: NotifyMethodHooks['getSnap'], + messenger: Messenger, ): NotifyParams { if (!isObject(params)) { throw rpcErrors.invalidParams({ @@ -345,6 +354,9 @@ export function getValidatedParams( 'type', ); + const getSnap = (snapId: string) => + messenger.call('SnapController:getSnap', snapId); + validateTextLinks(validatedParams.message, isOnPhishingList, getSnap); if (hasProperty(validatedParams, 'footerLink')) { diff --git a/packages/snaps-rpc-methods/src/types.ts b/packages/snaps-rpc-methods/src/types.ts index 7e99e5b4d3..09ad4d558d 100644 --- a/packages/snaps-rpc-methods/src/types.ts +++ b/packages/snaps-rpc-methods/src/types.ts @@ -2,6 +2,14 @@ import type { JsonRpcEngineEndCallback, JsonRpcEngineNextCallback, } from '@metamask/json-rpc-engine'; +import type { + ComponentOrElement, + ContentType, + InterfaceContext, + InterfaceState, + SnapId, +} from '@metamask/snaps-sdk'; +import type { Snap, SnapRpcHookArgs } from '@metamask/snaps-utils'; import type { Json, JsonRpcParams, @@ -48,3 +56,105 @@ export type PermittedHandlerExport< hookNames: HookNames; methodNames: string[]; }; + +export type HdKeyring = { + type: 'HD Key Tree'; + seed?: Uint8Array; + mnemonic?: Uint8Array; +}; + +export type KeyringControllerWithKeyringAction = { + type: 'KeyringController:withKeyring'; + handler: ( + selector: + | { + type: string; + index?: number; + } + | { id: string }, + operation: (args: { keyring: HdKeyring }) => Promise, + ) => Promise; +}; + +export type ApprovalControllerAddRequestAction = { + type: 'ApprovalController:addRequest'; + handler: ( + opts: { + id?: string; + origin: string; + type: string; + requestData?: Record; + requestState?: Record; + }, + shouldShowRequest: boolean, + ) => Promise; +}; + +export type SnapInterfaceControllerCreateInterfaceAction = { + type: 'SnapInterfaceController:createInterface'; + handler: ( + snapId: string, + content: ComponentOrElement, + context?: InterfaceContext, + contentType?: ContentType, + ) => string; +}; + +export type SnapInterfaceControllerGetInterfaceAction = { + type: 'SnapInterfaceController:getInterface'; + handler: ( + snapId: string, + id: string, + ) => { + content: ComponentOrElement; + snapId: SnapId; + state: InterfaceState; + context: InterfaceContext | null; + }; +}; + +export type SnapInterfaceControllerSetInterfaceDisplayedAction = { + type: 'SnapInterfaceController:setInterfaceDisplayed'; + handler: (snapId: string, id: string) => void; +}; + +export type SnapControllerHandleRequestAction = { + type: 'SnapController:handleRequest'; + handler: (args: SnapRpcHookArgs & { snapId: string }) => Promise; +}; + +export type SnapControllerGetSnapAction = { + type: 'SnapController:getSnap'; + handler: (snapId: string) => Snap | null; +}; + +export type SnapControllerClearSnapStateAction = { + type: 'SnapController:clearSnapState'; + handler: (snapId: string, encrypted: boolean) => void; +}; + +export type SnapControllerGetSnapStateAction = { + type: 'SnapController:getSnapState'; + handler: ( + snapId: string, + encrypted: boolean, + ) => Promise | null>; +}; + +export type SnapControllerUpdateSnapStateAction = { + type: 'SnapController:updateSnapState'; + handler: ( + snapId: string, + newState: Record, + encrypted: boolean, + ) => Promise; +}; + +export type RateLimitControllerCallAction = { + type: 'RateLimitController:call'; + handler: ( + origin: string, + type: string, + ...args: unknown[] + ) => Promise; +}; diff --git a/packages/snaps-rpc-methods/src/utils.ts b/packages/snaps-rpc-methods/src/utils.ts index 84b91a21e5..ca9abee802 100644 --- a/packages/snaps-rpc-methods/src/utils.ts +++ b/packages/snaps-rpc-methods/src/utils.ts @@ -6,6 +6,7 @@ import type { CryptographicFunctions, } from '@metamask/key-tree'; import { SLIP10Node } from '@metamask/key-tree'; +import type { Messenger } from '@metamask/messenger'; import { rpcErrors } from '@metamask/rpc-errors'; import type { MagicValue } from '@metamask/snaps-utils'; import { refine, string } from '@metamask/superstruct'; @@ -20,6 +21,7 @@ import { import { keccak_256 as keccak256 } from '@noble/hashes/sha3'; import { SnapEndowments } from './endowments'; +import type { KeyringControllerWithKeyringAction } from './types'; const HARDENED_VALUE = 0x80000000; @@ -381,3 +383,116 @@ export const UI_PERMISSIONS = [ SnapEndowments.TransactionInsight, SnapEndowments.SignatureInsight, ] as const; + +export const HD_KEYRING = 'HD Key Tree'; + +/** + * Get the mnemonic for a given entropy source. If no source is + * provided, the primary HD keyring's mnemonic will be returned. + * + * @param messenger - The messenger. + * @param source - The ID of the entropy source keyring. + * @returns The mnemonic. + */ +export async function getMnemonic( + messenger: Messenger, + source?: string | undefined, +): Promise { + if (!source) { + const mnemonic = (await messenger.call( + 'KeyringController:withKeyring', + { + type: HD_KEYRING, + index: 0, + }, + async ({ keyring }) => keyring.mnemonic, + )) as Uint8Array | null; + + if (!mnemonic) { + throw new Error('Primary keyring mnemonic unavailable.'); + } + + return mnemonic; + } + + try { + const keyringData = await messenger.call( + 'KeyringController:withKeyring', + { + id: source, + }, + async ({ keyring }) => ({ + type: keyring.type, + mnemonic: keyring.mnemonic, + }), + ); + + const { type, mnemonic } = keyringData as { + type: string; + mnemonic?: Uint8Array; + }; + + // The keyring isn't guaranteed to have a mnemonic (e.g., + // hardware wallets, which can't be used as entropy sources), + // so we throw an error if it doesn't. + assert(type === HD_KEYRING && mnemonic); + + return mnemonic; + } catch { + throw new Error(`Entropy source with ID "${source}" not found.`); + } +} + +/** + * Get the mnemonic seed for a given entropy source. If no source is + * provided, the primary HD keyring's mnemonic seed will be returned. + * + * @param messenger - The messenger. + * @param source - The ID of the entropy source keyring. + * @returns The mnemonic seed. + */ +export async function getMnemonicSeed( + messenger: Messenger, + source?: string | undefined, +): Promise { + if (!source) { + const seed = (await messenger.call( + 'KeyringController:withKeyring', + { + type: HD_KEYRING, + index: 0, + }, + async ({ keyring }) => keyring.seed, + )) as Uint8Array | null; + + if (!seed) { + throw new Error('Primary keyring mnemonic unavailable.'); + } + + return seed; + } + + try { + const keyringData = await messenger.call( + 'KeyringController:withKeyring', + { + id: source, + }, + async ({ keyring }) => ({ + type: keyring.type, + seed: keyring.seed, + }), + ); + + const { type, seed } = keyringData as { type: string; seed?: Uint8Array }; + + // The keyring isn't guaranteed to have a mnemonic (e.g., + // hardware wallets, which can't be used as entropy sources), + // so we throw an error if it doesn't. + assert(type === HD_KEYRING && seed); + + return seed; + } catch { + throw new Error(`Entropy source with ID "${source}" not found.`); + } +} diff --git a/packages/snaps-simulation/src/controllers.test.ts b/packages/snaps-simulation/src/controllers.test.ts index e2ef5c5616..44f87214b2 100644 --- a/packages/snaps-simulation/src/controllers.test.ts +++ b/packages/snaps-simulation/src/controllers.test.ts @@ -21,7 +21,6 @@ describe('getControllers', () => { .fn() .mockResolvedValue(mnemonicPhraseToBytes(DEFAULT_SRP)), getClientCryptography: jest.fn(), - getMnemonicSeed: jest.fn(), getSimulationState: jest.fn(), getSnap: jest.fn(), setCurrentChain: jest.fn(), @@ -46,7 +45,6 @@ describe('getControllers', () => { .fn() .mockResolvedValue(mnemonicPhraseToBytes(DEFAULT_SRP)), getClientCryptography: jest.fn(), - getMnemonicSeed: jest.fn(), getSimulationState: jest.fn(), getSnap: jest.fn(), setCurrentChain: jest.fn(), @@ -82,7 +80,6 @@ describe('getControllers', () => { .fn() .mockResolvedValue(mnemonicPhraseToBytes(DEFAULT_SRP)), getClientCryptography: jest.fn(), - getMnemonicSeed: jest.fn(), getSimulationState: jest.fn(), getSnap: jest.fn(), setCurrentChain: jest.fn(), diff --git a/packages/snaps-simulation/src/methods/hooks/get-mnemonic-seed.test.ts b/packages/snaps-simulation/src/methods/hooks/get-mnemonic-seed.test.ts deleted file mode 100644 index 691a655f65..0000000000 --- a/packages/snaps-simulation/src/methods/hooks/get-mnemonic-seed.test.ts +++ /dev/null @@ -1,49 +0,0 @@ -import { TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES } from '@metamask/snaps-utils/test-utils'; - -import { getGetMnemonicSeedImplementation } from './get-mnemonic-seed'; -import { DEFAULT_ALTERNATIVE_SRP } from '../../constants'; - -describe('getGetMnemonicSeedImplementation', () => { - const alternativeSeedBytes = new Uint8Array([ - 94, 176, 11, 189, 220, 240, 105, 8, 72, 137, 168, 171, 145, 85, 86, 129, - 101, 245, 196, 83, 204, 184, 94, 112, 129, 26, 174, 214, 246, 218, 95, 193, - 154, 90, 196, 11, 56, 156, 211, 112, 208, 134, 32, 109, 236, 138, 166, 196, - 61, 174, 166, 105, 15, 32, 173, 61, 141, 72, 178, 210, 206, 158, 56, 228, - ]); - - it('returns the default mnemonic seed', async () => { - const getMnemonicSeed = getGetMnemonicSeedImplementation(); - expect(await getMnemonicSeed()).toStrictEqual( - TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES, - ); - - expect(await getMnemonicSeed('default')).toStrictEqual( - TEST_SECRET_RECOVERY_PHRASE_SEED_BYTES, - ); - }); - - it('returns the seed of the provided default mnemonic phrase', async () => { - const getMnemonicSeed = getGetMnemonicSeedImplementation( - DEFAULT_ALTERNATIVE_SRP, - ); - expect(await getMnemonicSeed()).toStrictEqual(alternativeSeedBytes); - - expect(await getMnemonicSeed('default')).toStrictEqual( - alternativeSeedBytes, - ); - }); - - it('returns the alternative mnemonic seed', async () => { - const getMnemonicSeed = getGetMnemonicSeedImplementation(); - expect(await getMnemonicSeed('alternative')).toStrictEqual( - alternativeSeedBytes, - ); - }); - - it('throws an error for an unknown entropy source', async () => { - const getMnemonicSeed = getGetMnemonicSeedImplementation(); - await expect(getMnemonicSeed('unknown')).rejects.toThrow( - 'Entropy source with ID "unknown" not found.', - ); - }); -}); diff --git a/packages/snaps-simulation/src/methods/hooks/get-mnemonic-seed.ts b/packages/snaps-simulation/src/methods/hooks/get-mnemonic-seed.ts deleted file mode 100644 index be5453f030..0000000000 --- a/packages/snaps-simulation/src/methods/hooks/get-mnemonic-seed.ts +++ /dev/null @@ -1,29 +0,0 @@ -import { mnemonicToSeed } from '@metamask/key-tree'; - -import { DEFAULT_ALTERNATIVE_SRP, DEFAULT_SRP } from '../../constants'; - -/** - * Get the implementation of the `getMnemonicSeed` method. - * - * @param defaultSecretRecoveryPhrase - The default secret recovery phrase to - * use. - * @returns The implementation of the `getMnemonicSeed` method. - */ -export function getGetMnemonicSeedImplementation( - defaultSecretRecoveryPhrase: string = DEFAULT_SRP, -) { - return async (source?: string | undefined): Promise => { - if (!source) { - return mnemonicToSeed(defaultSecretRecoveryPhrase); - } - - switch (source) { - case 'default': - return mnemonicToSeed(defaultSecretRecoveryPhrase); - case 'alternative': - return mnemonicToSeed(DEFAULT_ALTERNATIVE_SRP); - default: - throw new Error(`Entropy source with ID "${source}" not found.`); - } - }; -} diff --git a/packages/snaps-simulation/src/methods/hooks/index.ts b/packages/snaps-simulation/src/methods/hooks/index.ts index 6668d05319..5cbb44772f 100644 --- a/packages/snaps-simulation/src/methods/hooks/index.ts +++ b/packages/snaps-simulation/src/methods/hooks/index.ts @@ -4,7 +4,6 @@ export * from './get-entropy-sources'; export * from './get-mnemonic'; export * from './get-preferences'; export * from './get-snap'; -export * from './interface'; export * from './notifications'; export * from './permitted'; export * from './request-user-approval'; diff --git a/packages/snaps-simulation/src/methods/hooks/interface.test.ts b/packages/snaps-simulation/src/methods/hooks/interface.test.ts deleted file mode 100644 index 5187df1f18..0000000000 --- a/packages/snaps-simulation/src/methods/hooks/interface.test.ts +++ /dev/null @@ -1,111 +0,0 @@ -import { SnapInterfaceController } from '@metamask/snaps-controllers'; -import { NodeType } from '@metamask/snaps-sdk'; -import { getJsxElementFromComponent } from '@metamask/snaps-utils'; -import { MOCK_SNAP_ID } from '@metamask/snaps-utils/test-utils'; - -import { - getCreateInterfaceImplementation, - getGetInterfaceImplementation, - getSetInterfaceDisplayedImplementation, -} from './interface'; -import { - getRestrictedSnapInterfaceControllerMessenger, - getRootControllerMessenger, -} from '../../test-utils'; - -describe('getCreateInterfaceImplementation', () => { - it('returns the implementation of the `createInterface` hook', async () => { - const controllerMessenger = getRootControllerMessenger(); - - const interfaceController = new SnapInterfaceController({ - messenger: - getRestrictedSnapInterfaceControllerMessenger(controllerMessenger), - }); - - jest.spyOn(controllerMessenger, 'call'); - - const fn = getCreateInterfaceImplementation(controllerMessenger); - - const content = { type: NodeType.Text as const, value: 'bar' }; - - const id = await fn(MOCK_SNAP_ID, content); - - const result = interfaceController.getInterface(MOCK_SNAP_ID, id); - - expect(controllerMessenger.call).toHaveBeenCalledWith( - 'SnapInterfaceController:createInterface', - MOCK_SNAP_ID, - content, - undefined, - undefined, - ); - - expect(result.content).toStrictEqual(getJsxElementFromComponent(content)); - }); -}); - -describe('getGetInterfaceImplementation', () => { - it('returns the implementation of the `getInterface` hook', async () => { - const controllerMessenger = getRootControllerMessenger(); - - const interfaceController = new SnapInterfaceController({ - messenger: - getRestrictedSnapInterfaceControllerMessenger(controllerMessenger), - }); - - jest.spyOn(controllerMessenger, 'call'); - - const fn = getGetInterfaceImplementation(controllerMessenger); - - const content = { type: NodeType.Text as const, value: 'bar' }; - - const id = interfaceController.createInterface(MOCK_SNAP_ID, content); - - const result = fn(MOCK_SNAP_ID, id); - - expect(controllerMessenger.call).toHaveBeenCalledWith( - 'SnapInterfaceController:getInterface', - MOCK_SNAP_ID, - id, - ); - expect(result).toStrictEqual({ - content: getJsxElementFromComponent(content), - state: {}, - snapId: MOCK_SNAP_ID, - context: null, - contentType: null, - displayed: false, - }); - }); -}); - -describe('getSetInterfaceDisplayedImplementation', () => { - it('returns the implementation of the `setInterfaceDisplayed` hook', () => { - const controllerMessenger = getRootControllerMessenger(); - - const interfaceController = new SnapInterfaceController({ - messenger: - getRestrictedSnapInterfaceControllerMessenger(controllerMessenger), - }); - - jest.spyOn(controllerMessenger, 'call'); - - const fn = getSetInterfaceDisplayedImplementation(controllerMessenger); - - const content = { type: NodeType.Text as const, value: 'bar' }; - - const id = interfaceController.createInterface(MOCK_SNAP_ID, content); - - fn(MOCK_SNAP_ID, id); - - expect(controllerMessenger.call).toHaveBeenCalledWith( - 'SnapInterfaceController:setInterfaceDisplayed', - MOCK_SNAP_ID, - id, - ); - - expect(interfaceController.getInterface(MOCK_SNAP_ID, id).displayed).toBe( - true, - ); - }); -}); diff --git a/packages/snaps-simulation/src/methods/hooks/interface.ts b/packages/snaps-simulation/src/methods/hooks/interface.ts deleted file mode 100644 index f79fcb4007..0000000000 --- a/packages/snaps-simulation/src/methods/hooks/interface.ts +++ /dev/null @@ -1,66 +0,0 @@ -import type { - Component, - ContentType, - InterfaceContext, - SnapId, -} from '@metamask/snaps-sdk'; - -import type { RootControllerMessenger } from '../../controllers'; - -/** - * Get the implementation of the `createInterface` hook. - * - * @param controllerMessenger - The controller messenger used to call actions. - * @returns The implementation of the `createInterface` hook. - */ -export function getCreateInterfaceImplementation( - controllerMessenger: RootControllerMessenger, -) { - return async ( - snapId: SnapId, - content: Component, - context?: InterfaceContext, - contentType?: ContentType, - ) => - controllerMessenger.call( - 'SnapInterfaceController:createInterface', - snapId, - content, - context, - contentType, - ); -} - -/** - * Get the implementation of the `getInterface` hook. - * - * @param controllerMessenger - The controller messenger used to call actions. - * @returns The implementation of the `getInterface` hook. - */ -export function getGetInterfaceImplementation( - controllerMessenger: RootControllerMessenger, -) { - return (snapId: SnapId, id: string) => - controllerMessenger.call( - 'SnapInterfaceController:getInterface', - snapId, - id, - ); -} - -/** - * Get the implementation of the `setInterfaceDisplayed` hook. - * - * @param controllerMessenger - The controller messenger used to call actions. - * @returns The implementation of the `setInterfaceDisplayed` hook. - */ -export function getSetInterfaceDisplayedImplementation( - controllerMessenger: RootControllerMessenger, -) { - return (snapId: SnapId, id: string) => - controllerMessenger.call( - 'SnapInterfaceController:setInterfaceDisplayed', - snapId, - id, - ); -} diff --git a/packages/snaps-simulation/src/methods/hooks/state.test.ts b/packages/snaps-simulation/src/methods/hooks/state.test.ts index b7a00f44cb..1347a761c7 100644 --- a/packages/snaps-simulation/src/methods/hooks/state.test.ts +++ b/packages/snaps-simulation/src/methods/hooks/state.test.ts @@ -57,7 +57,7 @@ describe('getUpdateSnapStateMethodImplementation', () => { expect(getState(true)(store.getState())).toBeNull(); - fn(MOCK_SNAP_ID, { foo: 'bar' }); + await fn(MOCK_SNAP_ID, { foo: 'bar' }); expect(getState(true)(store.getState())).toStrictEqual( JSON.stringify({ @@ -72,7 +72,7 @@ describe('getUpdateSnapStateMethodImplementation', () => { expect(getState(false)(store.getState())).toBeNull(); - fn(MOCK_SNAP_ID, { foo: 'bar' }, false); + await fn(MOCK_SNAP_ID, { foo: 'bar' }, false); expect(getState(false)(store.getState())).toStrictEqual( JSON.stringify({ diff --git a/packages/snaps-simulation/src/methods/hooks/state.ts b/packages/snaps-simulation/src/methods/hooks/state.ts index 60a8d726e4..30fec33ee3 100644 --- a/packages/snaps-simulation/src/methods/hooks/state.ts +++ b/packages/snaps-simulation/src/methods/hooks/state.ts @@ -65,7 +65,7 @@ function* updateSnapStateImplementation( export function getUpdateSnapStateMethodImplementation( runSaga: RunSagaFunction, ) { - return (...args: Parameters) => { + return async (...args: Parameters) => { runSaga(updateSnapStateImplementation, ...args).result(); }; } diff --git a/packages/snaps-simulation/src/methods/specifications.test.ts b/packages/snaps-simulation/src/methods/specifications.test.ts index 96e1742e3f..e5a7ca11f2 100644 --- a/packages/snaps-simulation/src/methods/specifications.test.ts +++ b/packages/snaps-simulation/src/methods/specifications.test.ts @@ -19,7 +19,6 @@ import { getMockOptions } from '../test-utils/options'; const MOCK_HOOKS = { getClientCryptography: jest.fn(), getMnemonic: jest.fn(), - getMnemonicSeed: jest.fn(), getIsLocked: jest.fn(), }; @@ -42,7 +41,6 @@ describe('getPermissionSpecifications', () => { expect( getPermissionSpecifications({ hooks: MOCK_HOOKS, - runSaga: jest.fn(), options: getMockOptions(), controllerMessenger: new Messenger({ namespace: MOCK_ANY_NAMESPACE, diff --git a/packages/snaps-simulation/src/methods/specifications.ts b/packages/snaps-simulation/src/methods/specifications.ts index 706962932e..9754ec4f2b 100644 --- a/packages/snaps-simulation/src/methods/specifications.ts +++ b/packages/snaps-simulation/src/methods/specifications.ts @@ -1,9 +1,11 @@ import { caip25EndowmentBuilder } from '@metamask/chain-agnostic-permission'; +import type { CryptographicFunctions } from '@metamask/key-tree'; import type { GenericPermissionController, PermissionSpecificationConstraint, PermissionSpecificationMap, } from '@metamask/permission-controller'; +import type { RestrictedMethodMessenger } from '@metamask/snaps-rpc-methods'; import { endowmentPermissionBuilders, buildSnapEndowmentSpecifications, @@ -16,35 +18,23 @@ import { EXCLUDED_SNAP_ENDOWMENTS, EXCLUDED_SNAP_PERMISSIONS, } from './constants'; -import { - getGetPreferencesMethodImplementation, - getClearSnapStateMethodImplementation, - getGetSnapStateMethodImplementation, - getUpdateSnapStateMethodImplementation, - getShowInAppNotificationImplementation, - getShowNativeNotificationImplementation, - getCreateInterfaceImplementation, - getGetInterfaceImplementation, - getRequestUserApprovalImplementation, - getSetInterfaceDisplayedImplementation, -} from './hooks'; +import { getGetPreferencesMethodImplementation } from './hooks'; import type { RootControllerMessenger } from '../controllers'; import type { SimulationOptions } from '../options'; -import type { RunSagaFunction } from '../store'; export type PermissionSpecificationsHooks = { /** - * A hook that returns the user's secret recovery phrase. + * Get the cryptographic functions to use for the client. This may return an + * empty object to fall back to the default cryptographic functions. * - * @returns The user's secret recovery phrase. + * @returns The cryptographic functions to use for the client. */ - getMnemonic: () => Promise; + getClientCryptography: () => CryptographicFunctions; }; export type GetPermissionSpecificationsOptions = { controllerMessenger: RootControllerMessenger; hooks: PermissionSpecificationsHooks; - runSaga: RunSagaFunction; options: SimulationOptions; }; @@ -75,44 +65,35 @@ export function asyncResolve(result?: Type) { * @param options - The options. * @param options.controllerMessenger - The controller messenger. * @param options.hooks - The hooks. - * @param options.runSaga - The function to run a saga outside the usual Redux - * flow. * @param options.options - The simulation options. * @returns The permission specifications for the Snap. */ export function getPermissionSpecifications({ controllerMessenger, hooks, - runSaga, options, }: GetPermissionSpecificationsOptions): PermissionSpecificationMap { return { [caip25EndowmentBuilder.targetName]: caip25EndowmentBuilder.specificationBuilder({}), ...buildSnapEndowmentSpecifications(EXCLUDED_SNAP_ENDOWMENTS), - ...buildSnapRestrictedMethodSpecifications(EXCLUDED_SNAP_PERMISSIONS, { - // Shared hooks. - ...hooks, + ...buildSnapRestrictedMethodSpecifications( + EXCLUDED_SNAP_PERMISSIONS, + { + // Shared hooks. + ...hooks, - // Snaps-specific hooks. - clearSnapState: getClearSnapStateMethodImplementation(runSaga), - getPreferences: getGetPreferencesMethodImplementation(options), - getSnapState: getGetSnapStateMethodImplementation(runSaga), - getUnlockPromise: asyncResolve(true), + // Snaps-specific hooks. + getPreferences: getGetPreferencesMethodImplementation(options), + getUnlockPromise: asyncResolve(true), - // TODO: Allow the user to specify the result of this function. - isOnPhishingList: resolve(false), + // TODO: Allow the user to specify the result of this function. + isOnPhishingList: resolve(false), - maybeUpdatePhishingList: asyncResolve(), - requestUserApproval: getRequestUserApprovalImplementation(runSaga), - showInAppNotification: getShowInAppNotificationImplementation(runSaga), - showNativeNotification: getShowNativeNotificationImplementation(runSaga), - updateSnapState: getUpdateSnapStateMethodImplementation(runSaga), - createInterface: getCreateInterfaceImplementation(controllerMessenger), - getInterface: getGetInterfaceImplementation(controllerMessenger), - setInterfaceDisplayed: - getSetInterfaceDisplayedImplementation(controllerMessenger), - }), + maybeUpdatePhishingList: asyncResolve(), + }, + controllerMessenger as RestrictedMethodMessenger, + ), }; } diff --git a/packages/snaps-simulation/src/simulation.test.ts b/packages/snaps-simulation/src/simulation.test.ts index 6b174d33c7..9f56dc73a7 100644 --- a/packages/snaps-simulation/src/simulation.test.ts +++ b/packages/snaps-simulation/src/simulation.test.ts @@ -2,7 +2,7 @@ import { Caip25CaveatType, Caip25EndowmentPermissionName, } from '@metamask/chain-agnostic-permission'; -import { mnemonicPhraseToBytes } from '@metamask/key-tree'; +import { mnemonicPhraseToBytes, mnemonicToSeed } from '@metamask/key-tree'; import { PermissionDoesNotExistError } from '@metamask/permission-controller'; import { detectSnapLocation, @@ -20,7 +20,7 @@ import { MOCK_SNAP_ID, } from '@metamask/snaps-utils/test-utils'; -import { DEFAULT_SRP } from './constants'; +import { DEFAULT_ALTERNATIVE_SRP, DEFAULT_SRP } from './constants'; import { MOCK_CAVEAT } from './middleware/multichain/test-utils'; import { getMultichainHooks, @@ -946,4 +946,56 @@ describe('registerActions', () => { }, }); }); + + it('registers `KeyringController:withKeyring`', async () => { + registerActions(controllerMessenger, runSaga, options, MOCK_SNAP_ID); + + expect( + await controllerMessenger.call( + 'KeyringController:withKeyring', + { type: 'HD Key Tree' }, + ({ keyring }) => keyring, + ), + ).toStrictEqual({ + type: 'HD Key Tree', + mnemonic: mnemonicPhraseToBytes(DEFAULT_SRP), + seed: await mnemonicToSeed(DEFAULT_SRP), + }); + + expect( + await controllerMessenger.call( + 'KeyringController:withKeyring', + { id: 'alternative' }, + ({ keyring }) => keyring, + ), + ).toStrictEqual({ + type: 'HD Key Tree', + mnemonic: mnemonicPhraseToBytes(DEFAULT_ALTERNATIVE_SRP), + seed: await mnemonicToSeed(DEFAULT_ALTERNATIVE_SRP), + }); + }); + + it('registers `RateLimitController:call`', async () => { + registerActions(controllerMessenger, runSaga, options, MOCK_SNAP_ID); + + expect( + await controllerMessenger.call( + 'RateLimitController:call', + MOCK_SNAP_ID, + 'showNativeNotification', + MOCK_SNAP_ID, + { message: 'Hello world!' }, + ), + ).toBeNull(); + + expect( + await controllerMessenger.call( + 'RateLimitController:call', + MOCK_SNAP_ID, + 'showInAppNotification', + MOCK_SNAP_ID, + { message: 'Hello world!' }, + ), + ).toBeNull(); + }); }); diff --git a/packages/snaps-simulation/src/simulation.ts b/packages/snaps-simulation/src/simulation.ts index eb32289ebf..5ccf68104a 100644 --- a/packages/snaps-simulation/src/simulation.ts +++ b/packages/snaps-simulation/src/simulation.ts @@ -1,5 +1,6 @@ import { createEngineStream } from '@metamask/json-rpc-middleware-stream'; import type { CryptographicFunctions } from '@metamask/key-tree'; +import { mnemonicToSeed } from '@metamask/key-tree'; import type { ActionConstraint, EventConstraint, @@ -35,6 +36,7 @@ import type { } from '@metamask/snaps-sdk'; import type { FetchedSnapFiles, Snap } from '@metamask/snaps-utils'; import { logError } from '@metamask/snaps-utils'; +import { assertExhaustive, hasProperty } from '@metamask/utils'; import type { CaipAssetType, Hex, Json } from '@metamask/utils'; import type { Duplex } from 'readable-stream'; import { pipeline } from 'readable-stream'; @@ -60,8 +62,13 @@ import { getEndTraceImplementation, getStartTraceImplementation, getSetCurrentChainImplementation, + getClearSnapStateMethodImplementation, + getGetSnapStateMethodImplementation, + getUpdateSnapStateMethodImplementation, + getRequestUserApprovalImplementation, + getShowInAppNotificationImplementation, + getShowNativeNotificationImplementation, } from './methods/hooks'; -import { getGetMnemonicSeedImplementation } from './methods/hooks/get-mnemonic-seed'; import { createJsonRpcEngine } from './middleware'; import type { SimulationAccount, @@ -133,18 +140,9 @@ export type RestrictedMiddlewareHooks = { /** * A hook that returns the user's secret recovery phrase. * - * @param source - The entropy source to get the mnemonic from. * @returns The user's secret recovery phrase. */ - getMnemonic: (source?: string | undefined) => Promise; - - /** - * A hook that returns the seed derived from the user's secret recovery phrase. - * - * @param source - The entropy source to get the seed from. - * @returns The seed. - */ - getMnemonicSeed: (source?: string | undefined) => Promise; + getMnemonic: () => Promise; /** * A hook that returns whether the client is locked or not. @@ -566,9 +564,6 @@ export function getRestrictedHooks( ): RestrictedMiddlewareHooks { return { getMnemonic: getGetMnemonicImplementation(options.secretRecoveryPhrase), - getMnemonicSeed: getGetMnemonicSeedImplementation( - options.secretRecoveryPhrase, - ), getIsLocked: () => false, getClientCryptography: () => ({}), getSnap: getGetSnapImplementation(true), @@ -699,6 +694,12 @@ export function getMultichainHooks( }; } +/** + * Get the mock mnemonic for a given source ID. + * + * @param options - The simulation options. + * @returns The mnemonic. + */ /** * Register mocked action handlers. * @@ -811,4 +812,90 @@ export function registerActions( return { value }; }, ); + + controllerMessenger.registerActionHandler( + 'ApprovalController:addRequest', + // @ts-expect-error Types of property 'requestData' are incompatible. + getRequestUserApprovalImplementation(runSaga), + ); + + controllerMessenger.registerActionHandler( + 'SnapController:getSnap', + getGetSnapImplementation(true), + ); + + controllerMessenger.registerActionHandler( + 'SnapController:getSnapState', + getGetSnapStateMethodImplementation(runSaga), + ); + + controllerMessenger.registerActionHandler( + 'SnapController:updateSnapState', + getUpdateSnapStateMethodImplementation(runSaga), + ); + + controllerMessenger.registerActionHandler( + 'SnapController:clearSnapState', + getClearSnapStateMethodImplementation(runSaga), + ); + + const showNativeNotification = + getShowNativeNotificationImplementation(runSaga); + const showInAppNotification = getShowInAppNotificationImplementation(runSaga); + + controllerMessenger.registerActionHandler( + // @ts-expect-error - `RateLimitController` is not part of the simulation messenger types. + 'RateLimitController:call', + async ( + _origin: string, + type: 'showNativeNotification' | 'showInAppNotification', + ...args: unknown[] + ) => { + switch (type) { + case 'showNativeNotification': + return await showNativeNotification(args[0] as string, { + type: 'native', + message: args[1] as string, + }); + case 'showInAppNotification': + return await showInAppNotification( + args[0] as string, + args[1] as Parameters[1], + ); + /* istanbul ignore next */ + default: + return assertExhaustive(type); + } + }, + ); + + const getMnemonic = getGetMnemonicImplementation( + options.secretRecoveryPhrase, + ); + + controllerMessenger.registerActionHandler( + // @ts-expect-error - `KeyringController` is not part of the simulation messenger types. + 'KeyringController:withKeyring', + async ( + selector: { type: string; index?: number } | { id: string }, + operation: (args: { + keyring: { type: string; mnemonic: Uint8Array; seed: Uint8Array }; + }) => Promise, + ) => { + const source = hasProperty(selector, 'id') + ? (selector.id as string) + : undefined; + + const mnemonic = await getMnemonic(source); + const seed = await mnemonicToSeed(mnemonic); + + return await operation({ + keyring: { + type: 'HD Key Tree', + mnemonic, + seed, + }, + }); + }, + ); }