diff --git a/doc.go b/doc.go index bd92fa7..854a99e 100644 --- a/doc.go +++ b/doc.go @@ -18,6 +18,9 @@ // ParseMessage copies the input, extracts delimiters from the MSH header, // merges any ADD continuation segments into their preceding segments per // HL7v2.5.1 Section 2.5.2, and splits the message into segments by \r. +// ADD segments that immediately follow MSH are left as standalone segments. +// ADD segments that follow MSA, DSC, PID, QRD, QRF, URD, or URS are invalid +// per the spec and cause ParseMessage to return ErrInvalidADDContinuation. // This is the only phase that allocates. All deeper access — fields, // repetitions, components, and subcomponents — is performed by scanning the // segment's raw bytes on each call. No parsed results are cached; the diff --git a/error.go b/error.go index 8414662..be75fc7 100644 --- a/error.go +++ b/error.go @@ -35,7 +35,8 @@ var ( ErrBatchStructure = errors.New("hl7: invalid batch structure") ErrMSHDelimiterField = errors.New("hl7: cannot modify MSH-1 or MSH-2 (delimiter fields)") ErrDelimiterMismatch = errors.New("hl7: delimiter mismatch between messages") - ErrCannotDeleteMSH = errors.New("hl7: cannot delete MSH segment") + ErrCannotDeleteMSH = errors.New("hl7: cannot delete MSH segment") + ErrInvalidADDContinuation = errors.New("hl7: ADD cannot continue segment") ) // ParseError provides detailed context about a parsing failure. diff --git a/message.go b/message.go index a855c39..6c7bddf 100644 --- a/message.go +++ b/message.go @@ -49,7 +49,10 @@ func ParseMessage(data []byte) (*Message, error) { // Make an owned copy, merging any ADD continuation segments into their // preceding segment per HL7v2.5.1 Section 2.5.2. - owned := mergeADD(data, delims.Field) + owned, err := mergeADD(data, delims.Field) + if err != nil { + return nil, err + } segments := splitSegments(owned, delims) if len(segments) == 0 { @@ -129,7 +132,7 @@ func (m *Message) Raw() []byte { // "ADD" type marker are stripped, but the field separator that follows is kept, // so each ADD field becomes the next additional field of the preceding segment. // -// ADD segments without a field separator (e.g., "ADD\\r") are not merged and +// ADD segments without a field separator (e.g., "ADD\r") are not merged and // remain as standalone segments. // // ADD segments that immediately follow MSH are also left as standalone segments. @@ -137,7 +140,10 @@ func (m *Message) Raw() []byte { // ADD visible after MSH allows Concatenate to correctly reassemble cross-message // continuations where page N+1 starts with MSH followed by an ADD that continues // the last segment of page N. -func mergeADD(data []byte, fieldSep byte) []byte { +// +// ADD segments that follow MSA, DSC, PID, QRD, QRF, URD, or URS are invalid +// per HL7v2.5.1 Section 2.5.2 and cause mergeADD to return ErrInvalidADDContinuation. +func mergeADD(data []byte, fieldSep byte) ([]byte, error) { pattern1 := [5]byte{'\r', 'A', 'D', 'D', fieldSep} pattern2 := [5]byte{'\n', 'A', 'D', 'D', fieldSep} @@ -145,11 +151,13 @@ func mergeADD(data []byte, fieldSep byte) []byte { if !bytes.Contains(data, pattern1[:]) && !bytes.Contains(data, pattern2[:]) { owned := make([]byte, len(data)) copy(owned, data) - return owned + return owned, nil } // Slow path: merge ADD segments, skipping merge when the preceding segment - // is MSH. segStart tracks where the current segment begins in owned. + // is MSH, and returning an error when the preceding segment is one of the + // types that cannot be continued (MSA, DSC, PID, QRD, QRF, URD, URS). + // segStart tracks where the current segment begins in owned. owned := make([]byte, 0, len(data)) segStart := 0 i := 0 @@ -161,6 +169,8 @@ func mergeADD(data []byte, fieldSep byte) []byte { owned = append(owned, '\r', '\n') segStart = len(owned) i += 2 // skip \r\n; ADD will be copied normally + } else if isAddErrorSeg(owned[segStart:], fieldSep) { + return nil, ErrInvalidADDContinuation } else { i += 5 // skip \r\nADD, keep as the field boundary } @@ -173,6 +183,8 @@ func mergeADD(data []byte, fieldSep byte) []byte { owned = append(owned, '\r') segStart = len(owned) i++ // skip \r; ADD will be copied normally + } else if isAddErrorSeg(owned[segStart:], fieldSep) { + return nil, ErrInvalidADDContinuation } else { i += 4 // skip \rADD, keep as the field boundary } @@ -185,6 +197,8 @@ func mergeADD(data []byte, fieldSep byte) []byte { owned = append(owned, '\n') segStart = len(owned) i++ // skip \n; ADD will be copied normally + } else if isAddErrorSeg(owned[segStart:], fieldSep) { + return nil, ErrInvalidADDContinuation } else { i += 4 // skip \nADD, keep as the field boundary } @@ -204,7 +218,23 @@ func mergeADD(data []byte, fieldSep byte) []byte { segStart = len(owned) } } - return owned + return owned, nil +} + +// isAddErrorSeg reports whether seg begins with a segment type that cannot be +// continued by an ADD segment per HL7v2.5.1 Section 2.5.2: MSA, DSC, PID, +// QRD, QRF, URD, URS. +// (MSH is handled separately by isMSHSeg — ADD after MSH is left standalone.) +func isAddErrorSeg(seg []byte, fieldSep byte) bool { + if len(seg) < 4 || seg[3] != fieldSep { + return false + } + a, b, c := seg[0], seg[1], seg[2] + return (a == 'M' && b == 'S' && c == 'A') || + (a == 'D' && b == 'S' && c == 'C') || + (a == 'P' && b == 'I' && c == 'D') || + (a == 'Q' && b == 'R' && (c == 'D' || c == 'F')) || + (a == 'U' && b == 'R' && (c == 'D' || c == 'S')) } // isMSHSeg reports whether seg is an MSH segment (starts with "MSH" + fieldSep). diff --git a/message_test.go b/message_test.go index a8b11c6..da4dfd9 100644 --- a/message_test.go +++ b/message_test.go @@ -15,6 +15,7 @@ package hl7 import ( + "errors" "os" "path/filepath" "sync" @@ -419,21 +420,31 @@ func TestParseMessageADD_NoADD(t *testing.T) { } func TestParseMessageADD_EmptyADD(t *testing.T) { + // ADD following PID is invalid per spec; ParseMessage must return an error. raw := []byte("MSH|^~\\&|S|F|R|RF|20240101||ADT^A01|1|P|2.5.1\rPID|1||123\rADD|\rOBX|1") - msg, err := ParseMessage(raw) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // ADD merges into PID (empty content), OBX is separate. - if got := len(msg.Segments()); got != 3 { - t.Fatalf("len(Segments()) = %d, want 3", got) + _, err := ParseMessage(raw) + if err == nil { + t.Fatal("expected parse error for ADD after PID, got nil") } - if got := msg.Segments()[1].Type(); got != "PID" { - t.Errorf("segment[1] type = %q, want PID", got) + if !errors.Is(err, ErrInvalidADDContinuation) { + t.Errorf("error = %v, want ErrInvalidADDContinuation", err) } - if got := msg.Segments()[2].Type(); got != "OBX" { - t.Errorf("segment[2] type = %q, want OBX", got) +} + +func TestParseMessageADD_InvalidContinuation(t *testing.T) { + types := []string{"MSA", "DSC", "PID", "QRD", "QRF", "URD", "URS"} + for _, typ := range types { + t.Run(typ, func(t *testing.T) { + raw := []byte("MSH|^~\\&|S|F|R|F|20240101||ADT^A01|1|P|2.5.1\r" + + typ + "|1\rADD|extra") + _, err := ParseMessage(raw) + if err == nil { + t.Fatal("expected parse error, got nil") + } + if !errors.Is(err, ErrInvalidADDContinuation) { + t.Errorf("error = %v, want ErrInvalidADDContinuation", err) + } + }) } } @@ -753,7 +764,7 @@ func BenchmarkMergeADD(b *testing.B) { b.ResetTimer() b.ReportAllocs() for b.Loop() { - _ = mergeADD(data, '|') + _, _ = mergeADD(data, '|') } }) @@ -762,7 +773,7 @@ func BenchmarkMergeADD(b *testing.B) { b.ResetTimer() b.ReportAllocs() for b.Loop() { - _ = mergeADD(data, '|') + _, _ = mergeADD(data, '|') } }) @@ -775,7 +786,7 @@ func BenchmarkMergeADD(b *testing.B) { b.ResetTimer() b.ReportAllocs() for b.Loop() { - _ = mergeADD(data, '|') + _, _ = mergeADD(data, '|') } }) }