diff --git a/handlers.go b/handlers.go index 8545781..993a139 100644 --- a/handlers.go +++ b/handlers.go @@ -1,6 +1,7 @@ package bot import ( + "errors" "regexp" "strings" @@ -45,64 +46,103 @@ func (h handler) match(update *models.Update) bool { return h.matchFunc(update) } - var data string - var entities []models.MessageEntity + data, entities, err := getDataFromUpdate(update, h.handlerType) + if err != nil { + return false + } + + switch h.matchType { + case MatchTypeExact: + return h.matchExact(data) + case MatchTypePrefix: + return h.matchPrefix(data) + case MatchTypeContains: + return h.matchContains(data) + case MatchTypeCommand: + return h.matchCommand(data, entities) + case MatchTypeCommandStartOnly: + return h.matchCommandStartOnly(data, entities) + case matchTypeRegexp: + return h.matchRegexp(data) + default: + return false + } +} - switch h.handlerType { +func getDataFromUpdate(update *models.Update, handlerType HandlerType) (data string, entities []models.MessageEntity, err error) { + switch handlerType { case HandlerTypeMessageText: if update.Message == nil { - return false + return "", nil, errors.New("message is nil") } data = update.Message.Text entities = update.Message.Entities case HandlerTypeCallbackQueryData: if update.CallbackQuery == nil { - return false + return "", nil, errors.New("callback query is nil") } data = update.CallbackQuery.Data case HandlerTypeCallbackQueryGameShortName: if update.CallbackQuery == nil { - return false + return "", nil, errors.New("callback query is nil") } data = update.CallbackQuery.GameShortName case HandlerTypePhotoCaption: if update.Message == nil { - return false + return "", nil, errors.New("message is nil") } data = update.Message.Caption entities = update.Message.CaptionEntities } + return +} - if h.matchType == MatchTypeExact { - return data == h.pattern - } - if h.matchType == MatchTypePrefix { - return strings.HasPrefix(data, h.pattern) - } - if h.matchType == MatchTypeContains { - return strings.Contains(data, h.pattern) +func (h handler) matchExact(data string) bool { + return data == h.pattern +} + +func (h handler) matchPrefix(data string) bool { + return strings.HasPrefix(data, h.pattern) +} + +func (h handler) matchContains(data string) bool { + return strings.Contains(data, h.pattern) +} + +func (h handler) matchRegexp(data string) bool { + return h.re.Match([]byte(data)) +} + +func extractCommand(data string, entity models.MessageEntity) string { + // Checking the correctness of boundaries to avoid panic + if entity.Offset < 0 || entity.Length <= 1 || entity.Offset+entity.Length > len(data) { + return "" } - if h.matchType == MatchTypeCommand { - for _, e := range entities { - if e.Type == models.MessageEntityTypeBotCommand { - if data[e.Offset+1:e.Offset+e.Length] == h.pattern { - return true - } + // Skipping the "/" character at the beginning of the command + return data[entity.Offset+1 : entity.Offset+entity.Length] +} + +func (h handler) matchCommand(data string, entities []models.MessageEntity) bool { + for _, e := range entities { + if e.Type == models.MessageEntityTypeBotCommand { + command := extractCommand(data, e) + if command == h.pattern { + return true } } } - if h.matchType == MatchTypeCommandStartOnly { - for _, e := range entities { - if e.Type == models.MessageEntityTypeBotCommand { - if e.Offset == 0 && data[e.Offset+1:e.Offset+e.Length] == h.pattern { - return true - } + return false +} + +func (h handler) matchCommandStartOnly(data string, entities []models.MessageEntity) bool { + for _, e := range entities { + if e.Type == models.MessageEntityTypeBotCommand && e.Offset == 0 { + command := extractCommand(data, e) + if command == h.pattern { + return true } } } - if h.matchType == matchTypeRegexp { - return h.re.Match([]byte(data)) - } return false } diff --git a/handlers_test.go b/handlers_test.go index 4a48df3..6a19d90 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -1,6 +1,7 @@ package bot import ( + "reflect" "regexp" "testing" @@ -330,3 +331,614 @@ func Test_match_command_start(t *testing.T) { } }) } + +func Test_match_NilUpdateMessageIsFalse(t *testing.T) { + b := &Bot{} + id := b.RegisterHandler(HandlerTypeMessageText, "foo", MatchTypeCommand, nil) + h := findHandler(b, id) + + u := models.Update{ + ID: 42, + Message: nil, + } + + res := h.match(&u) + if res { + t.Error("want 'false', but got 'true'") + } +} + +func Test_getDataFromUpdate(t *testing.T) { + tests := []struct { + name string + update *models.Update + handlerType HandlerType + wantData string + wantEntities []models.MessageEntity + wantError bool + }{ + { + name: "HandlerTypeMessageText - valid message, is ok", + update: &models.Update{ + Message: &models.Message{ + Text: "Hello, world!", + Entities: []models.MessageEntity{ + {Type: models.MessageEntityTypeBold, Offset: 0, Length: 5}, + }, + }, + }, + handlerType: HandlerTypeMessageText, + wantData: "Hello, world!", + wantEntities: []models.MessageEntity{ + {Type: models.MessageEntityTypeBold, Offset: 0, Length: 5}, + }, + wantError: false, + }, + { + name: "HandlerTypeMessageText - nil message, is error", + update: &models.Update{ + Message: nil, + }, + handlerType: HandlerTypeMessageText, + wantData: "", + wantEntities: nil, + wantError: true, + }, + { + name: "HandlerTypeMessageText - empty message, is ok", + update: &models.Update{ + Message: &models.Message{ + Text: "", + Entities: nil, + }, + }, + handlerType: HandlerTypeMessageText, + wantData: "", + wantEntities: nil, + wantError: false, + }, + { + name: "HandlerTypeCallbackQueryData - valid callback query, is ok", + update: &models.Update{ + CallbackQuery: &models.CallbackQuery{ + Data: "callback_data", + }, + }, + handlerType: HandlerTypeCallbackQueryData, + wantData: "callback_data", + wantEntities: nil, + wantError: false, + }, + { + name: "HandlerTypeCallbackQueryData - nil callback query, is error", + update: &models.Update{ + CallbackQuery: nil, + }, + handlerType: HandlerTypeCallbackQueryData, + wantData: "", + wantEntities: nil, + wantError: true, + }, + { + name: "HandlerTypeCallbackQueryData - empty data, is ok", + update: &models.Update{ + CallbackQuery: &models.CallbackQuery{ + Data: "", + }, + }, + handlerType: HandlerTypeCallbackQueryData, + wantData: "", + wantEntities: nil, + wantError: false, + }, + { + name: "HandlerTypeCallbackQueryGameShortName - valid game short name, is ok", + update: &models.Update{ + CallbackQuery: &models.CallbackQuery{ + GameShortName: "snake_game", + }, + }, + handlerType: HandlerTypeCallbackQueryGameShortName, + wantData: "snake_game", + wantEntities: nil, + wantError: false, + }, + { + name: "HandlerTypeCallbackQueryGameShortName - nil callback query, is error", + update: &models.Update{ + CallbackQuery: nil, + }, + handlerType: HandlerTypeCallbackQueryGameShortName, + wantData: "", + wantEntities: nil, + wantError: true, + }, + { + name: "HandlerTypePhotoCaption - valid photo caption, is ok", + update: &models.Update{ + Message: &models.Message{ + Caption: "Photo caption", + CaptionEntities: []models.MessageEntity{ + {Type: models.MessageEntityTypeItalic, Offset: 6, Length: 7}, + }, + }, + }, + handlerType: HandlerTypePhotoCaption, + wantData: "Photo caption", + wantEntities: []models.MessageEntity{ + {Type: models.MessageEntityTypeItalic, Offset: 6, Length: 7}, + }, + wantError: false, + }, + { + name: "HandlerTypePhotoCaption - nil message, is error", + update: &models.Update{ + Message: nil, + }, + handlerType: HandlerTypePhotoCaption, + wantData: "", + wantEntities: nil, + wantError: true, + }, + { + name: "HandlerTypePhotoCaption - empty caption, is ok", + update: &models.Update{ + Message: &models.Message{ + Caption: "", + CaptionEntities: nil, + }, + }, + handlerType: HandlerTypePhotoCaption, + wantData: "", + wantEntities: nil, + wantError: false, + }, + { + name: "Invalid handler type returns empty values, is ok", + update: &models.Update{ + Message: &models.Message{ + Text: "some text", + }, + }, + handlerType: HandlerType(999), // Unknown type + wantData: "", + wantEntities: nil, + wantError: false, // Func just returns empty values for unknown types + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, entities, err := getDataFromUpdate(tt.update, tt.handlerType) + + if tt.wantError && err == nil { + t.Errorf("want error but got none") + } + if !tt.wantError && err != nil { + t.Errorf("unexpected error: %v", err) + } + + if data != tt.wantData { + t.Errorf("want data %q, got %q", tt.wantData, data) + } + + if !reflect.DeepEqual(entities, tt.wantEntities) { + t.Errorf("entities mismatch:\nwant: %+v\ngot: %+v", tt.wantEntities, entities) + } + }) + } +} + +func TestHandler_matchExact(t *testing.T) { + tests := []struct { + name string + pattern string + data string + wantMatch bool + }{ + { + name: "exact match", + pattern: "test", + data: "test", + wantMatch: true, + }, + { + name: "no match", + pattern: "test", + data: "testing", + wantMatch: false, + }, + { + name: "empty pattern", + pattern: "", + data: "", + wantMatch: true, + }, + { + name: "empty data", + pattern: "test", + data: "", + wantMatch: false, + }, + { + name: "case sensitive", + pattern: "Test", + data: "test", + wantMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := handler{pattern: tt.pattern} + result := h.matchExact(tt.data) + if result != tt.wantMatch { + t.Errorf("matchExact() = %v, want %v", result, tt.wantMatch) + } + }) + } +} + +func TestHandler_matchPrefix(t *testing.T) { + tests := []struct { + name string + pattern string + data string + wantMatch bool + }{ + { + name: "has prefix", + pattern: "test", + data: "testing", + wantMatch: true, + }, + { + name: "exact match", + pattern: "test", + data: "test", + wantMatch: true, + }, + { + name: "no prefix", + pattern: "test", + data: "notest", + wantMatch: false, + }, + { + name: "empty pattern", + pattern: "", + data: "anything", + wantMatch: true, + }, + { + name: "empty data", + pattern: "test", + data: "", + wantMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := handler{pattern: tt.pattern} + result := h.matchPrefix(tt.data) + if result != tt.wantMatch { + t.Errorf("matchPrefix() = %v, want %v", result, tt.wantMatch) + } + }) + } +} + +func TestHandler_matchContains(t *testing.T) { + tests := []struct { + name string + pattern string + data string + wantMatch bool + }{ + { + name: "contains", + pattern: "test", + data: "atestb", + wantMatch: true, + }, + { + name: "exact match", + pattern: "test", + data: "test", + wantMatch: true, + }, + { + name: "no contains", + pattern: "test", + data: "nothing", + wantMatch: false, + }, + { + name: "empty pattern", + pattern: "", + data: "anything", + wantMatch: true, + }, + { + name: "empty data", + pattern: "test", + data: "", + wantMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := handler{pattern: tt.pattern} + result := h.matchContains(tt.data) + if result != tt.wantMatch { + t.Errorf("matchContains() = %v, want %v", result, tt.wantMatch) + } + }) + } +} + +func TestHandler_matchRegexp(t *testing.T) { + tests := []struct { + name string + regexp string + data string + wantMatch bool + }{ + { + name: "simple match", + regexp: "^test$", + data: "test", + wantMatch: true, + }, + { + name: "no match", + regexp: "^test$", + data: "testing", + wantMatch: false, + }, + { + name: "partial match", + regexp: "test", + data: "testing", + wantMatch: true, + }, + { + name: "digit pattern", + regexp: "\\d+", + data: "123abc", + wantMatch: true, + }, + { + name: "no digits", + regexp: "\\d+", + data: "abc", + wantMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + re := regexp.MustCompile(tt.regexp) + h := handler{re: re} + result := h.matchRegexp(tt.data) + if result != tt.wantMatch { + t.Errorf("matchRegexp() = %v, want %v", result, tt.wantMatch) + } + }) + } +} + +func TestExtractCommand(t *testing.T) { + tests := []struct { + name string + data string + entity models.MessageEntity + want string + }{ + { + name: "valid command", + data: "/start arg1", + entity: models.MessageEntity{ + Type: models.MessageEntityTypeBotCommand, + Offset: 0, + Length: 6, + }, + want: "start", + }, + { + name: "command in middle", + data: "text /help more", + entity: models.MessageEntity{ + Type: models.MessageEntityTypeBotCommand, + Offset: 5, + Length: 5, + }, + want: "help", + }, + { + name: "invalid offset negative", + data: "/start", + entity: models.MessageEntity{ + Type: models.MessageEntityTypeBotCommand, + Offset: -1, + Length: 6, + }, + want: "", + }, + { + name: "invalid length too long", + data: "/start", + entity: models.MessageEntity{ + Type: models.MessageEntityTypeBotCommand, + Offset: 0, + Length: 100, + }, + want: "", + }, + { + name: "length too small", + data: "/", + entity: models.MessageEntity{ + Type: models.MessageEntityTypeBotCommand, + Offset: 0, + Length: 1, + }, + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractCommand(tt.data, tt.entity) + if result != tt.want { + t.Errorf("extractCommand() = %q, want %q", result, tt.want) + } + }) + } +} + +func TestHandler_matchCommand(t *testing.T) { + tests := []struct { + name string + pattern string + data string + entities []models.MessageEntity + wantMatch bool + }{ + { + name: "command match anywhere", + pattern: "start", + data: "text /start arg1", + entities: []models.MessageEntity{ + { + Type: models.MessageEntityTypeBotCommand, + Offset: 5, + Length: 6, + }, + }, + wantMatch: true, + }, + { + name: "no command match", + pattern: "help", + data: "text /start arg1", + entities: []models.MessageEntity{ + { + Type: models.MessageEntityTypeBotCommand, + Offset: 5, + Length: 6, + }, + }, + wantMatch: false, + }, + { + name: "no entities", + pattern: "start", + data: "/start", + entities: []models.MessageEntity{}, + wantMatch: false, + }, + { + name: "multiple commands, one matches", + pattern: "help", + data: "/start /help /stop", + entities: []models.MessageEntity{ + { + Type: models.MessageEntityTypeBotCommand, + Offset: 0, + Length: 6, + }, + { + Type: models.MessageEntityTypeBotCommand, + Offset: 7, + Length: 5, + }, + { + Type: models.MessageEntityTypeBotCommand, + Offset: 13, + Length: 5, + }, + }, + wantMatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := handler{pattern: tt.pattern} + result := h.matchCommand(tt.data, tt.entities) + if result != tt.wantMatch { + t.Errorf("matchCommand() = %v, want %v", result, tt.wantMatch) + } + }) + } +} + +func TestHandler_matchCommandStartOnly(t *testing.T) { + tests := []struct { + name string + pattern string + data string + entities []models.MessageEntity + wantMatch bool + }{ + { + name: "command at start", + pattern: "start", + data: "/start arg1", + entities: []models.MessageEntity{ + { + Type: models.MessageEntityTypeBotCommand, + Offset: 0, + Length: 6, + }, + }, + wantMatch: true, + }, + { + name: "command not at start", + pattern: "start", + data: "text /start arg1", + entities: []models.MessageEntity{ + { + Type: models.MessageEntityTypeBotCommand, + Offset: 5, + Length: 6, + }, + }, + wantMatch: false, + }, + { + name: "wrong command at start", + pattern: "help", + data: "/start arg1", + entities: []models.MessageEntity{ + { + Type: models.MessageEntityTypeBotCommand, + Offset: 0, + Length: 6, + }, + }, + wantMatch: false, + }, + { + name: "no entities", + pattern: "start", + data: "/start", + entities: []models.MessageEntity{}, + wantMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := handler{pattern: tt.pattern} + result := h.matchCommandStartOnly(tt.data, tt.entities) + if result != tt.wantMatch { + t.Errorf("matchCommandStartOnly() = %v, want %v", result, tt.wantMatch) + } + }) + } +}