diff --git a/client.go b/client.go index 3a8b133..2bc6a23 100644 --- a/client.go +++ b/client.go @@ -409,6 +409,43 @@ func (c *Client) StreamCommandResponse(ctx context.Context, creds keys.Credentia return c.streamingPostDNClient(ctx, message.CommandResponse, value, creds.HostID, creds.Counter, creds.PrivateKey) } +func (c *Client) Reauthenticate(ctx context.Context, creds keys.Credentials) (*message.ReauthenticateResponse, error) { + value, err := json.Marshal(message.ReauthenticateRequest{}) + if err != nil { + return nil, fmt.Errorf("failed to marshal DNClient message: %s", err) + } + + resp, err := c.postDNClient(ctx, message.Reauthenticate, value, creds.HostID, creds.Counter, creds.PrivateKey) + if err != nil { + return nil, err + } + + resultWrapper := message.SignedResponseWrapper{} + err = json.Unmarshal(resp, &resultWrapper) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal signed response wrapper: %s", err) + } + + // Verify the signature + valid := false + for _, caPubkey := range creds.TrustedKeys { + if caPubkey.Verify(resultWrapper.Data.Message, resultWrapper.Data.Signature) { + valid = true + break + } + } + if !valid { + return nil, fmt.Errorf("failed to verify signed API result") + } + + var response message.ReauthenticateResponse + if err := json.Unmarshal(resultWrapper.Data.Message, &response); err != nil { + return nil, fmt.Errorf("failed to unmarshal DNClient response: %s", err) + } + + return &response, nil +} + // streamingPostDNClient wraps and signs the given dnclientRequestWrapper message, and makes a streaming API call. // On success, it returns a StreamController to interact with the request. On error, the error is returned. func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, value []byte, hostID string, counter uint, privkey keys.PrivateKey) (*StreamController, error) { diff --git a/client_test.go b/client_test.go index b3faa67..9d6b58a 100644 --- a/client_test.go +++ b/client_test.go @@ -490,7 +490,7 @@ func TestDoUpdate_P256(t *testing.T) { config, pkey, creds, _, err := c.Enroll(ctx, testutil.NewTestLogger(), "foobar") require.NoError(t, err) - // convert privkey to private key + // convert private key to public key pubkey, err := keys.MarshalHostP256PublicKey(creds.PrivateKey.Public().Unwrap().(*ecdsa.PublicKey)) require.NoError(t, err) @@ -898,6 +898,106 @@ func TestStreamCommandResponse(t *testing.T) { assert.Equal(t, 0, ts.RequestsRemaining(), ts.ExpectedRequests()) } +func TestReauthenticate(t *testing.T) { + t.Parallel() + + useragent := "testClient" + ts := dnapitest.NewServer(useragent) + t.Cleanup(func() { ts.Close() }) + + ca, caPrivkey := dnapitest.NebulaCACert() + caPEM, err := ca.MarshalToPEM() + require.NoError(t, err) + + c := NewClient(useragent, ts.URL) + + code := "foobar" + ts.ExpectEnrollment(code, message.NetworkCurve25519, func(req message.EnrollRequest) []byte { + cfg, err := yaml.Marshal(m{ + // we need to send this or we'll get an error from the api client + "pki": m{"ca": string(caPEM)}, + // here we reflect values back to the client for test purposes + "test": m{"code": req.Code, "dhPubkey": req.NebulaPubkeyX25519}, + }) + if err != nil { + return jsonMarshal(message.EnrollResponse{ + Errors: message.APIErrors{{ + Code: "ERR_FAILED_TO_MARSHAL_YAML", + Message: "failed to marshal test response config", + }}, + }) + } + + return jsonMarshal(message.EnrollResponse{ + Data: message.EnrollResponseData{ + HostID: "foobar", + Counter: 1, + Config: cfg, + TrustedKeys: marshalCAPublicKey(ca.Details.Curve, ca.Details.PublicKey), + Organization: message.HostOrgMetadata{ + ID: "foobaz", + Name: "foobar's foo org", + }, + Network: message.HostNetworkMetadata{ + ID: "qux", + Name: "the best network", + Curve: message.NetworkCurve25519, + CIDR: "192.168.100.0/24", + }, + Host: message.HostHostMetadata{ + ID: "quux", + Name: "foo host", + IPAddress: "192.168.100.2", + }, + }, + }) + }) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + config, pkey, creds, _, err := c.Enroll(ctx, testutil.NewTestLogger(), "foobar") + require.NoError(t, err) + + // make sure all credential values were set + assert.NotEmpty(t, creds.HostID) + assert.NotEmpty(t, creds.PrivateKey) + assert.NotEmpty(t, creds.TrustedKeys) + assert.NotEmpty(t, creds.Counter) + + // make sure we got a config back + assert.NotEmpty(t, config) + assert.NotEmpty(t, pkey) + + // This time sign the response with the correct CA key. + ts.ExpectDNClientRequest(message.Reauthenticate, http.StatusOK, func(r message.RequestWrapper) []byte { + newConfigResponse := message.ReauthenticateResponse{ + LoginURL: "https://auth.example.com/login?authcode=123", + } + rawRes := jsonMarshal(newConfigResponse) + + return jsonMarshal(message.SignedResponseWrapper{ + Data: message.SignedResponse{ + Version: 1, + Message: rawRes, + Signature: ed25519.Sign(caPrivkey, rawRes), + }, + }) + }) + + ctx, cancel = context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + resp, err := c.Reauthenticate(ctx, *creds) + require.NoError(t, err) + assert.Empty(t, ts.Errors()) + assert.Equal(t, 0, ts.RequestsRemaining()) + + // make sure we got a login URL back + assert.NotEmpty(t, resp) + assert.NotEmpty(t, resp.LoginURL) + assert.Equal(t, "https://auth.example.com/login?authcode=123", resp.LoginURL) + +} + func jsonMarshal(v interface{}) []byte { b, err := json.Marshal(v) if err != nil { diff --git a/dnapitest/dnapitest.go b/dnapitest/dnapitest.go index 8076d1e..47364f0 100644 --- a/dnapitest/dnapitest.go +++ b/dnapitest/dnapitest.go @@ -294,6 +294,15 @@ func (s *Server) handlerDNClient(w http.ResponseWriter, r *http.Request) { return } + case message.Reauthenticate: + var reauth message.ReauthenticateRequest + err = json.Unmarshal(msg.Value, &reauth) + if err != nil { + s.errors = append(s.errors, fmt.Errorf("failed to unmarshal ReauthenticateRequest: %w", err)) + http.Error(w, "failed to unmarshal ReauthenticateRequest", http.StatusInternalServerError) + return + } + } if res.isStreamingRequest { diff --git a/message/message.go b/message/message.go index 0324db6..a490ab3 100644 --- a/message/message.go +++ b/message/message.go @@ -13,6 +13,7 @@ const ( DoUpdate = "DoUpdate" LongPollWait = "LongPollWait" CommandResponse = "CommandResponse" + Reauthenticate = "Reauthenticate" ) // EndpointV1 is the version 1 DNClient API endpoint. @@ -108,7 +109,7 @@ type CommandResponseRequest struct { Response any `json:"response"` } -// DNClientCommandResponseResponse is the response message associated with a CommandResponse call. +// CommandResponseResponse is the response message associated with a CommandResponse call. type CommandResponseResponse struct{} type ClientInfo struct { @@ -118,6 +119,16 @@ type ClientInfo struct { Architecture string `json:"architecture"` } +// ReauthenticateRequest is the request sent for a Reauthenticate request. +type ReauthenticateRequest struct { + // Add fields as needed +} + +// ReauthenticateResponse is the response message associated with a Reauthenticate request. +type ReauthenticateResponse struct { + LoginURL string `json:"loginURL"` +} + // EnrollEndpoint is the REST enrollment endpoint. const EnrollEndpoint = "/v2/enroll"