diff --git a/pkg/schema/container.go b/pkg/schema/container.go index 6342f3e..7d9d933 100644 --- a/pkg/schema/container.go +++ b/pkg/schema/container.go @@ -108,6 +108,16 @@ func getMandatoryChildren(e *yang.Entry) []*sdcpb.MandatoryChild { result = append(result, c) } } + // include mandatory children from augments + for _, v := range e.Augmented { + if v.Mandatory == yang.TSTrue { + c := &sdcpb.MandatoryChild{ + Name: v.Name, + IsState: isState(v), + } + result = append(result, c) + } + } return result } @@ -157,6 +167,16 @@ func getChoiceInfo(e *yang.Entry) *sdcpb.ChoiceInfo { processChoice(de, ci) } + // also consider choices introduced via augments + for _, de := range e.Augmented { + if !de.IsChoice() { + continue + } + if ci == nil { + ci = &sdcpb.ChoiceInfo{Choice: map[string]*sdcpb.ChoiceInfoChoice{}} + } + processChoice(de, ci) + } return ci } diff --git a/pkg/schema/expand.go b/pkg/schema/expand.go index 7a9a09e..8527277 100644 --- a/pkg/schema/expand.go +++ b/pkg/schema/expand.go @@ -41,7 +41,7 @@ func (sc *Schema) ExpandPath(p *sdcpb.Path, dt sdcpb.DataType) ([]*sdcpb.Path, e for _, k := range strings.Fields(e.Key) { keys[k] = struct{}{} } - for _, c := range e.Dir { + for _, c := range getChildren(e) { // skip keys if _, ok := keys[c.Name]; ok { continue @@ -109,7 +109,7 @@ func (sc *Schema) getPathElems(e *yang.Entry, dt sdcpb.DataType) [][]*sdcpb.Path kmap[k] = struct{}{} } - for _, c := range e.Dir { + for _, c := range getChildren(e) { if _, ok := kmap[c.Name]; ok { continue } @@ -126,7 +126,7 @@ func (sc *Schema) getPathElems(e *yang.Entry, dt sdcpb.DataType) [][]*sdcpb.Path case e.IsContainer(): log.Debugf("got container: %s", e.Name) containerPE := &sdcpb.PathElem{Name: e.Name, Key: make(map[string]string)} - for _, c := range e.Dir { + for _, c := range getChildren(e) { log.Debugf("container parent adding child: %s", c.Name) childrenPE := sc.getPathElems(c, dt) diff --git a/pkg/schema/object.go b/pkg/schema/object.go index df34983..4a5bf32 100644 --- a/pkg/schema/object.go +++ b/pkg/schema/object.go @@ -182,7 +182,7 @@ func getEntry(e *yang.Entry, pe []string) (*yang.Entry, error) { } return e, nil default: - if e.Dir == nil { + if e.Dir == nil && len(e.Augmented) == 0 { return nil, errors.New("not found") } for _, ee := range getChildren(e) { @@ -261,6 +261,22 @@ func (sc *Schema) buildPath(pe []string, p *sdcpb.Path, e *yang.Entry) error { return nil } + // helper to find a direct child by name from Dir or Augmented + childByName := func(parent *yang.Entry, name string) *yang.Entry { + if parent == nil { + return nil + } + if ce, ok := parent.Dir[name]; ok { + return ce + } + for _, ae := range parent.Augmented { + if ae.Name == name { + return ae + } + } + return nil + } + switch { case e.IsList(): if cpe.GetKey() == nil { @@ -281,7 +297,7 @@ func (sc *Schema) buildPath(pe []string, p *sdcpb.Path, e *yang.Entry) error { return nil } nxt := pe[count] - if ee, ok := e.Dir[nxt]; ok { + if ee := childByName(e, nxt); ee != nil { return sc.buildPath(pe[count:], p, ee) } // find choices/cases @@ -316,20 +332,20 @@ func (sc *Schema) buildPath(pe []string, p *sdcpb.Path, e *yang.Entry) error { return fmt.Errorf("case %s - unknown element %s", e.Name, pe[0]) case e.IsContainer(): // implicit case: child with same name which is a choice - if ee, ok := e.Dir[pe[0]]; ee != nil && ok { + if ee := childByName(e, pe[0]); ee != nil { if ee.IsChoice() { return sc.buildPath(pe[1:], p, ee) } } p.Elem = append(p.Elem, cpe) - if ee, ok := e.Dir[pe[0]]; ok { + if ee := childByName(e, pe[0]); ee != nil { return sc.buildPath(pe, p, ee) } if lpe == 1 { return nil } - if ee, ok := e.Dir[pe[1]]; ok { + if ee := childByName(e, pe[1]); ee != nil { return sc.buildPath(pe[1:], p, ee) } // find choice/case @@ -359,7 +375,7 @@ func (sc *Schema) buildPath(pe []string, p *sdcpb.Path, e *yang.Entry) error { func getChildren(e *yang.Entry) []*yang.Entry { switch { case e.IsChoice(), e.IsCase(), e.IsContainer(), e.IsList(): - rs := make([]*yang.Entry, 0, len(e.Dir)) + rs := make([]*yang.Entry, 0, len(e.Dir)+len(e.Augmented)) for _, ee := range e.Dir { if ee.IsChoice() || ee.IsCase() { rs = append(rs, getChildren(ee)...) @@ -367,6 +383,14 @@ func getChildren(e *yang.Entry) []*yang.Entry { } rs = append(rs, ee) } + // add augmented children as well + for _, ee := range e.Augmented { + if ee.IsChoice() || ee.IsCase() { + rs = append(rs, getChildren(ee)...) + continue + } + rs = append(rs, ee) + } //sort.Slice(rs, sortFn(rs)) return rs // case e.IsCase(): @@ -562,21 +586,41 @@ func (sc *Schema) findChoiceCase(e *yang.Entry, pe []string) (*yang.Entry, error if len(pe) == 0 { return e, nil } - for _, ee := range e.Dir { - if !ee.IsChoice() { - continue + // pe is expected to contain at least the current element name at index 0 and + // the sought child element name at index 1. + // + // This is used from container/list resolution paths where the next element may + // live under a choice/case. Case nodes do not exist in the data tree, so we + // must search through cases to find the actual schema node. + if len(pe) < 2 { + return nil, fmt.Errorf("unknown element %s", pe[0]) + } + + // scan choice nodes in both direct and augmented children + choices := make([]*yang.Entry, 0) + for _, child := range e.Dir { + if child != nil && child.IsChoice() { + choices = append(choices, child) } - if eee, ok := ee.Dir[pe[1]]; ok && !eee.IsCase() { - return eee, nil + } + for _, child := range e.Augmented { + if child != nil && child.IsChoice() { + choices = append(choices, child) } - // assume there was a case obj, - // search one step deeper - for _, eee := range ee.Dir { - if !eee.IsCase() { + } + + for _, choice := range choices { + // implicit case: choice directly contains the data node + if direct, ok := choice.Dir[pe[1]]; ok && direct != nil && !direct.IsCase() { + return direct, nil + } + // explicit cases: the data node is under a case + for _, cc := range choice.Dir { + if cc == nil || !cc.IsCase() { continue } - if eeee, ok := eee.Dir[pe[1]]; ok { - return eeee, nil + if target, ok := cc.Dir[pe[1]]; ok && target != nil { + return target, nil } } } diff --git a/pkg/schema/references.go b/pkg/schema/references.go index a3f646a..3a1e6d7 100644 --- a/pkg/schema/references.go +++ b/pkg/schema/references.go @@ -70,6 +70,21 @@ func (sc *Schema) buildReferences(e *yang.Entry) error { return err } } + // also recurse into augmented children + for _, ce := range e.Augmented { + if ce.IsCase() || ce.IsChoice() { + for _, cce := range ce.Dir { + err := sc.buildReferences(cce) + if err != nil { + return err + } + } + } + err := sc.buildReferences(ce) + if err != nil { + return err + } + } return nil } diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index 9ab9de0..dbeea39 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -141,14 +141,28 @@ func (s *Schema) Walk(e *yang.Entry, fn func(ec *yang.Entry) error) error { return err } } + // also walk augments merged into this entry + for _, ee := range e.Augmented { + err = s.Walk(ee, fn) + if err != nil { + return err + } + } return nil } err = fn(e) if err != nil { return err } - for _, e := range e.Dir { - err = s.Walk(e, fn) + for _, ce := range e.Dir { + err = s.Walk(ce, fn) + if err != nil { + return err + } + } + // include augmented children at every level + for _, ce := range e.Augmented { + err = s.Walk(ce, fn) if err != nil { return err } diff --git a/pkg/store/persiststore/persiststore.go b/pkg/store/persiststore/persiststore.go index 2bad4c7..2d83fec 100644 --- a/pkg/store/persiststore/persiststore.go +++ b/pkg/store/persiststore/persiststore.go @@ -733,6 +733,8 @@ func getModules(txn *badger.Txn, sc store.SchemaKey) ([]string, error) { func (s *persistStore) getSchema(_ context.Context, req *sdcpb.GetSchemaRequest, sck store.SchemaKey) (*sdcpb.GetSchemaResponse, error) { pes := utils.ToStrings(req.GetPath(), false, true) + log.Debugf("[persiststore][getSchema] raw path elems=%v", pes) + cKey := cacheKey{ SchemaKey: sck, Path: strings.Join(pes, "/"), @@ -750,92 +752,104 @@ func (s *persistStore) getSchema(_ context.Context, req *sdcpb.GetSchemaRequest, return rsp, nil } } + var err error sce := new(sdcpb.SchemaElem) + var modules []string - // key all i.e "root" - if lpes := len(pes); lpes == 0 || (lpes == 1 && pes[0] == "") { - err = s.db.View(func(txn *badger.Txn) error { + origin := req.GetPath().GetOrigin() + + // Root schema + if len(pes) == 0 || (len(pes) == 1 && pes[0] == "") { + var sce sdcpb.SchemaElem + err := s.db.View(func(txn *badger.Txn) error { k := buildEntryKey(sck, []string{schema.RootName}) item, err := txn.Get(k) if err != nil { return err } - if item == nil { - return ErrKeyNotFound - } v, err := item.ValueCopy(nil) if err != nil { return err } - err = proto.Unmarshal(v, sce) - if err != nil { - return err - } - return nil + return proto.Unmarshal(v, &sce) }) if err != nil { return nil, err } - return &sdcpb.GetSchemaResponse{Schema: sce}, nil + return &sdcpb.GetSchemaResponse{Schema: &sce}, nil } - moduleName := "" - if index := strings.Index(pes[0], ":"); index > 0 { - moduleName = pes[0][:index] - pes[0] = pes[0][index+1:] + + // Parse per-element prefixes and build unprefixed names + pp := parsePathElems(pes) + names := make([]string, 0, len(pp)) + for _, pe := range pp { + names = append(names, pe.name) } - var modules []string - // path has module prefix - if moduleName != "" { - modules = []string{moduleName} - } else { - // path does not have module prefix - modules, err = s.getModules(sck) - if err != nil { - return nil, err + + // Apply gNMI origin as module hint on first element (if present) + if origin != "" && len(pp) > 0 && pp[0].module == "" { + pp[0].module = origin + } + + log.Debugf("[persiststore][getSchema] parsed path elems=%+v", pp) + + // Validate first element's module prefix if present + // Note: non-first element prefixes can refer to augmented modules not in root list + allModules, err := s.getModules(sck) + if err != nil { + return nil, err + } + if len(pp) > 0 && pp[0].module != "" { + modSet := make(map[string]struct{}, len(allModules)) + for _, m := range allModules { + modSet[m] = struct{}{} + } + if _, ok := modSet[pp[0].module]; !ok { + return nil, status.Errorf(codes.InvalidArgument, "unknown module prefix %q", pp[0].module) + } + if pp[0].name == "" { + return nil, status.Errorf(codes.InvalidArgument, "empty identifier after prefix %q", pp[0].module) } + } + + // Decide candidate modules: use first element's module if present, else try all + if len(pp) > 0 && pp[0].module != "" { + modules = []string{pp[0].module} + } else { + // Prefer modules in a stable, deprioritized order + modules = append(modules, allModules...) sort.Slice(modules, func(i, j int) bool { return utils.SortModulesAB(modules[i], modules[j], config.DeprioritizedModules) }) } - npe := make([]string, 1+len(pes)) - copy(npe[1:], pes) err = s.db.View(func(txn *badger.Txn) error { for _, module := range modules { - var k []byte - if npe[1] == module { // query module name - k = buildEntryKey(sck, npe[1:]) - } else { - npe[0] = module - k = buildEntryKey(sck, npe) - } + keyPath := make([]string, 0, 1+len(names)) + keyPath = append(keyPath, module) + keyPath = append(keyPath, names...) + k := buildEntryKey(sck, keyPath) + item, err := txn.Get(k) - if err != nil { - continue - } - if item == nil { + if err != nil || item == nil { continue } v, err := item.ValueCopy(nil) if err != nil { return err } - err = proto.Unmarshal(v, sce) - if err != nil { + if err := proto.Unmarshal(v, sce); err != nil { return err } return nil } - return fmt.Errorf("%s: %w", req.GetPath(), ErrKeyNotFound) + return fmt.Errorf("schema path not found: %s", req.GetPath()) }) if err != nil { return nil, err } - rsp := &sdcpb.GetSchemaResponse{Schema: sce} - if s.cache != nil { - s.cache.Set(cKey, rsp, ttlcache.DefaultTTL) - } + return &sdcpb.GetSchemaResponse{Schema: sce}, nil } diff --git a/pkg/store/persiststore/persiststore_test.go b/pkg/store/persiststore/persiststore_test.go new file mode 100644 index 0000000..cdc217f --- /dev/null +++ b/pkg/store/persiststore/persiststore_test.go @@ -0,0 +1,246 @@ +// Copyright 2024 Nokia +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package persiststore + +import ( + "context" + "testing" + + "github.com/dgraph-io/badger/v4" + "github.com/sdcio/schema-server/pkg/schema" + "github.com/sdcio/schema-server/pkg/store" + sdcpb "github.com/sdcio/sdc-protos/sdcpb" + "google.golang.org/protobuf/proto" +) + +// +// ---------- helpers ---------- +// + +func newTestStore(t *testing.T) *persistStore { + t.Helper() + + dir := t.TempDir() + db, err := badger.Open(badger.DefaultOptions(dir).WithLogger(nil)) + if err != nil { + t.Fatalf("failed to open badger: %v", err) + } + + t.Cleanup(func() { _ = db.Close() }) + + return &persistStore{db: db} +} + +func testSchemaKey() store.SchemaKey { + return store.SchemaKey{ + Name: "M", + Vendor: "V", + Version: "1", + } +} + +func insertSchemaMeta(t *testing.T, ps *persistStore, sk store.SchemaKey) { + t.Helper() + + key := buildSchemaKey(sk) + err := ps.db.Update(func(txn *badger.Txn) error { + return txn.Set(key, []byte(`{}`)) + }) + if err != nil { + t.Fatalf("failed inserting schema meta: %v", err) + } +} + +func insertRootEntry(t *testing.T, ps *persistStore, sk store.SchemaKey, modules []string) { + t.Helper() + + // Create a root container listing available modules as children + root := &sdcpb.SchemaElem{ + Schema: &sdcpb.SchemaElem_Container{Container: &sdcpb.ContainerSchema{ + Name: schema.RootName, + Children: modules, + }}, + } + b, err := proto.Marshal(root) + if err != nil { + t.Fatalf("marshal root: %v", err) + } + key := buildEntryKey(sk, []string{schema.RootName}) + if err := ps.db.Update(func(txn *badger.Txn) error { return txn.Set(key, b) }); err != nil { + t.Fatalf("insert root failed: %v", err) + } +} + +func insertEntry(t *testing.T, ps *persistStore, sk store.SchemaKey, keyPath []string, se *sdcpb.SchemaElem) { + t.Helper() + b, err := proto.Marshal(se) + if err != nil { + t.Fatalf("marshal entry: %v", err) + } + key := buildEntryKey(sk, keyPath) + if err := ps.db.Update(func(txn *badger.Txn) error { return txn.Set(key, b) }); err != nil { + t.Fatalf("insert entry failed: %v", err) + } +} + +// +// ---------- helper function tests ---------- +// + +func TestSchemaKeyString(t *testing.T) { + sk := store.SchemaKey{Name: "n", Vendor: "v", Version: "1"} + if got := schemaKeyString(sk); got != "n@v@1" { + t.Fatalf("unexpected schemaKeyString: %q", got) + } +} + +func TestStripPrefix(t *testing.T) { + cases := map[string]string{ + "a": "a", + "m:a": "a", + "foo:bar": "bar", + } + + for in, exp := range cases { + if got := stripPrefix(in); got != exp { + t.Fatalf("stripPrefix(%q)=%q, want %q", in, got, exp) + } + } +} + +func TestHasPrefix(t *testing.T) { + if !hasPrefix("m:a") { + t.Fatalf("expected prefix") + } + if hasPrefix("a") { + t.Fatalf("unexpected prefix") + } +} + +// +// ---------- HasSchema tests ---------- +// + +func TestHasSchema(t *testing.T) { + ps := newTestStore(t) + sk := testSchemaKey() + + if ps.HasSchema(sk) { + t.Fatalf("schema should not exist") + } + + insertSchemaMeta(t, ps, sk) + + if !ps.HasSchema(sk) { + t.Fatalf("schema should exist") + } +} + +// +// ---------- GetSchema tests (negative + strict) ---------- +// + +func TestGetSchema_UnknownSchema(t *testing.T) { + ps := newTestStore(t) + + _, err := ps.GetSchema(context.Background(), &sdcpb.GetSchemaRequest{ + Schema: &sdcpb.Schema{ + Name: "M", + Vendor: "V", + Version: "1", + }, + }) + if err == nil { + t.Fatalf("expected error for unknown schema") + } +} + +func TestGetSchema_StrictPrefixRejected(t *testing.T) { + // Replace strict-prefix behavior with validation of unknown module hints + ps := newTestStore(t) + sk := testSchemaKey() + insertSchemaMeta(t, ps, sk) + // Insert root with one known module + insertRootEntry(t, ps, sk, []string{"known"}) + + _, err := ps.GetSchema(context.Background(), &sdcpb.GetSchemaRequest{ + Schema: &sdcpb.Schema{Name: sk.Name, Vendor: sk.Vendor, Version: sk.Version}, + Path: &sdcpb.Path{Elem: []*sdcpb.PathElem{{Name: "unknown:foo"}}}, + }) + if err == nil { + t.Fatalf("expected error for unknown module prefix") + } +} + +func TestGetSchema_RootSchemaMissingEntry(t *testing.T) { + ps := newTestStore(t) + sk := testSchemaKey() + insertSchemaMeta(t, ps, sk) + + _, err := ps.GetSchema(context.Background(), &sdcpb.GetSchemaRequest{ + Schema: &sdcpb.Schema{ + Name: sk.Name, + Vendor: sk.Vendor, + Version: sk.Version, + }, + }) + if err == nil { + t.Fatalf("expected error due to missing root entry") + } +} + +func TestGetSchema_ModuleLessPathResolves(t *testing.T) { + ps := newTestStore(t) + sk := testSchemaKey() + insertSchemaMeta(t, ps, sk) + // Set up root with module list and a concrete entry under that module + insertRootEntry(t, ps, sk, []string{"ietf-nss"}) + // Create container entry for ietf-nss:network-instances + se := &sdcpb.SchemaElem{Schema: &sdcpb.SchemaElem_Container{Container: &sdcpb.ContainerSchema{Name: "network-instances"}}} + insertEntry(t, ps, sk, []string{"ietf-nss", "network-instances"}, se) + + rsp, err := ps.GetSchema(context.Background(), &sdcpb.GetSchemaRequest{ + Schema: &sdcpb.Schema{Name: sk.Name, Vendor: sk.Vendor, Version: sk.Version}, + Path: &sdcpb.Path{Elem: []*sdcpb.PathElem{{Name: "network-instances"}}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + got := rsp.GetSchema().GetContainer().GetName() + if got != "network-instances" { + t.Fatalf("unexpected schema name: %q", got) + } +} + +func TestGetSchema_ModulePrefixedPathResolves(t *testing.T) { + ps := newTestStore(t) + sk := testSchemaKey() + insertSchemaMeta(t, ps, sk) + insertRootEntry(t, ps, sk, []string{"ietf-nss"}) + se := &sdcpb.SchemaElem{Schema: &sdcpb.SchemaElem_Container{Container: &sdcpb.ContainerSchema{Name: "network-instances"}}} + insertEntry(t, ps, sk, []string{"ietf-nss", "network-instances"}, se) + + rsp, err := ps.GetSchema(context.Background(), &sdcpb.GetSchemaRequest{ + Schema: &sdcpb.Schema{Name: sk.Name, Vendor: sk.Vendor, Version: sk.Version}, + Path: &sdcpb.Path{Elem: []*sdcpb.PathElem{{Name: "ietf-nss:network-instances"}}}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + got := rsp.GetSchema().GetContainer().GetName() + if got != "network-instances" { + t.Fatalf("unexpected schema name: %q", got) + } +} diff --git a/pkg/store/persiststore/yang_helpers.go b/pkg/store/persiststore/yang_helpers.go new file mode 100644 index 0000000..d54f120 --- /dev/null +++ b/pkg/store/persiststore/yang_helpers.go @@ -0,0 +1,95 @@ +package persiststore + +import ( + "strings" + + sdcpb "github.com/sdcio/sdc-protos/sdcpb" +) + +// pathElem represents a parsed path element with an optional module hint. +// +// The module is derived from a prefix in the form ":" when present. +// The name field always contains the unprefixed element name. +type pathElem struct { + name string // unprefixed name + module string // optional module hint +} + +// parsePathElems parses gNMI-like path elements that may optionally carry a module +// prefix in the form ":". +// +// It returns a slice of pathElem values where: +// - name is always the unprefixed element name +// - module is set only when a prefix was present +// +// Note: only the first ':' is treated as the prefix separator. + +func parsePathElems(pes []string) []pathElem { + out := make([]pathElem, 0, len(pes)) + for _, pe := range pes { + if i := strings.IndexByte(pe, ':'); i > 0 { + out = append(out, pathElem{ + module: pe[:i], + name: pe[i+1:], + }) + } else { + out = append(out, pathElem{ + name: pe, + }) + } + } + return out +} + +// uniqueStrings returns the input slice with duplicates removed while preserving +// the original order of first occurrence. + +func uniqueStrings(in []string) []string { + m := make(map[string]struct{}) + out := make([]string, 0, len(in)) + for _, s := range in { + if _, ok := m[s]; !ok { + m[s] = struct{}{} + out = append(out, s) + } + } + return out +} + +// schemaElemModuleName extracts the module name from a SchemaElem, regardless of +// whether the element is a container, leaf, or leaf-list. +// +// If the schema element is nil or of an unknown/unsupported oneof type, an empty +// string is returned. + +func schemaElemModuleName(sce *sdcpb.SchemaElem) string { + switch s := sce.Schema.(type) { + case *sdcpb.SchemaElem_Container: + return s.Container.ModuleName + case *sdcpb.SchemaElem_Leaflist: + return s.Leaflist.ModuleName + case *sdcpb.SchemaElem_Field: + return s.Field.ModuleName + default: + return "" + } +} + +// hasPrefix reports whether the provided path element contains a module prefix. +// A prefix is identified by the presence of ':' anywhere in the string. + +func hasPrefix(pe string) bool { + return strings.Contains(pe, ":") +} + +// stripPrefix removes the module prefix from a path element of the form +// ":". +// +// If no ':' is present, the input is returned unchanged. + +func stripPrefix(pe string) string { + if i := strings.IndexByte(pe, ':'); i != -1 { + return pe[i+1:] + } + return pe +}