From 15b6371584f67ca9cd1b0cf342f81996ceccd36b Mon Sep 17 00:00:00 2001
From: Alexander Tumin <iamtakingiteasy@eientei.org>
Date: Thu, 1 Sep 2022 11:14:44 +0300
Subject: [PATCH] Fix empty structure serialization for optional fields Closes
 #1, Closes #3

---
 v1/apollows/proto.go      | 52 ++++++++++++++++++++++++++++++++++-----
 v1/apollows/proto_test.go | 52 +++++++++++++++++++++++++++++++++++++++
 v1/server_websocket.go    |  8 +++---
 3 files changed, 103 insertions(+), 9 deletions(-)

diff --git a/v1/apollows/proto.go b/v1/apollows/proto.go
index 41cd36e..224a262 100644
--- a/v1/apollows/proto.go
+++ b/v1/apollows/proto.go
@@ -265,15 +265,35 @@ type PayloadOperation struct {
 	OperationName string                 `json:"operationName"`
 }
 
-// PayloadData provides server-side response for previously started operation
-type PayloadData struct {
-	Data Data `json:"data,omitempty"`
+// PayloadDataRaw provides server-side response for previously started operation
+type PayloadDataRaw struct {
+	Data Data `json:"data"`
 
 	// see https://github.com/graph-gophers/graphql-go#custom-errors
 	// for adding custom error attributes
 	Errors []error `json:"errors,omitempty"`
 }
 
+// PayloadData type-alias for serialization
+type PayloadData PayloadDataRaw
+
+// MarshalJSON serializes PayloadData to JSON, excluding empty data
+func (payload PayloadData) MarshalJSON() (bs []byte, err error) {
+	var data *Data
+
+	if payload.Data.Value != nil {
+		data = &payload.Data
+	}
+
+	return json.Marshal(struct {
+		Data *Data `json:"data,omitempty"`
+		PayloadDataRaw
+	}{
+		PayloadDataRaw: PayloadDataRaw(payload),
+		Data:           data,
+	})
+}
+
 // PayloadErrorLocation error location in originating request
 type PayloadErrorLocation struct {
 	Line   int `json:"line"`
@@ -294,9 +314,29 @@ type PayloadDataResponse struct {
 	Errors []PayloadError         `json:"errors,omitempty"`
 }
 
-// Message encapsulates every message within apollows protocol in both directions
-type Message struct {
+// MessageRaw encapsulates every message within apollows protocol in both directions
+type MessageRaw struct {
 	ID      string    `json:"id,omitempty"`
 	Type    Operation `json:"type"`
-	Payload Data      `json:"payload,omitempty"`
+	Payload Data      `json:"payload"`
+}
+
+// Message type-alias for (de-)serialization
+type Message MessageRaw
+
+// MarshalJSON serializes Message to JSON, excluding empty id or payload from serialized fields.
+func (message Message) MarshalJSON() (bs []byte, err error) {
+	var payload *Data
+
+	if message.Payload.Value != nil {
+		payload = &message.Payload
+	}
+
+	return json.Marshal(struct {
+		Payload *Data `json:"payload,omitempty"`
+		MessageRaw
+	}{
+		MessageRaw: MessageRaw(message),
+		Payload:    payload,
+	})
 }
diff --git a/v1/apollows/proto_test.go b/v1/apollows/proto_test.go
index efef6d2..5fd8f71 100644
--- a/v1/apollows/proto_test.go
+++ b/v1/apollows/proto_test.go
@@ -182,3 +182,55 @@ func TestWrapError(t *testing.T) {
 
 	assert.True(t, errors.Is(err, io.EOF))
 }
+
+func TestMessageMarshal(t *testing.T) {
+	data := Message{
+		ID:   "123",
+		Type: OperationConnectionAck,
+		Payload: Data{
+			Value: "foo",
+		},
+	}
+
+	bs, err := json.Marshal(data)
+
+	assert.NoError(t, err)
+	assert.JSONEq(t, `{"id":"123","type":"connection_ack","payload":"foo"}`, string(bs))
+}
+
+func TestMessageMarshalEmpty(t *testing.T) {
+	data := Message{
+		ID:      "123",
+		Type:    OperationConnectionAck,
+		Payload: Data{},
+	}
+
+	bs, err := json.Marshal(data)
+
+	assert.NoError(t, err)
+	assert.JSONEq(t, `{"id":"123","type":"connection_ack"}`, string(bs))
+}
+
+func TestPayloadDataMarshal(t *testing.T) {
+	data := PayloadData{
+		Data: Data{
+			Value: "foo",
+		},
+	}
+
+	bs, err := json.Marshal(data)
+
+	assert.NoError(t, err)
+	assert.JSONEq(t, `{"data":"foo"}`, string(bs))
+}
+
+func TestPayloadDataMarshalEmpty(t *testing.T) {
+	data := PayloadData{
+		Data: Data{},
+	}
+
+	bs, err := json.Marshal(data)
+
+	assert.NoError(t, err)
+	assert.JSONEq(t, `{}`, string(bs))
+}
diff --git a/v1/server_websocket.go b/v1/server_websocket.go
index 29b2ea8..1265a4e 100644
--- a/v1/server_websocket.go
+++ b/v1/server_websocket.go
@@ -205,9 +205,11 @@ func (req *websocketRequest) writeWebsocketMessage(ctx mutable.Context, t apollo
 func (req *websocketRequest) readWebsocketInit(msg *apollows.Message) (err error) {
 	init := make(apollows.PayloadInit)
 
-	err = json.Unmarshal(msg.Payload.RawMessage, &init)
-	if err != nil {
-		return
+	if len(msg.Payload.RawMessage) > 0 {
+		err = json.Unmarshal(msg.Payload.RawMessage, &init)
+		if err != nil {
+			return
+		}
 	}
 
 	err = req.server.callbacks.OnConnect(req.ctx, init)
-- 
GitLab