context BUGFIX remove applied augments and deviations when removing a module

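- deviations and augments defined in submodules were not removed
- neither deviations nor augments were removed in ly_ctx_remove_module()
- the loop removing deviations reused its index variable as the module
  iterator cursor when clearing the deviation flag, corrupting the loop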
diff --git a/src/tree_schema.c b/src/tree_schema.c
index 351111f..b96e661 100644
--- a/src/tree_schema.c
+++ b/src/tree_schema.c
@@ -3513,90 +3513,115 @@
 
 #endif
 
-void
-lys_sub_module_remove_devs_augs(struct lys_module *module)
+static void
+remove_dev(struct lys_deviation *dev, const struct lys_module *module)
 {
-    uint32_t i = 0, j;
-    struct lys_node *last, *elem;
+    uint32_t idx = 0, j;
     const struct lys_module *mod;
     struct lys_module *target_mod;
     const char *ptr;
 
-    /* remove applied deviations */
-    for (i = 0; i < module->deviation_size; ++i) {
-        if (module->deviation[i].orig_node) {
-            target_mod = lys_node_module(module->deviation[i].orig_node);
-        } else {
-            target_mod = (struct lys_module *)lys_get_import_module(module, NULL, 0, module->deviation[i].target_name + 1,
-                                                                    strcspn(module->deviation[i].target_name, ":") - 1);
-            target_mod = (struct lys_module *)lys_implemented_module(target_mod);
+    if (dev->orig_node) {
+        target_mod = lys_node_module(dev->orig_node);
+    } else {
+        target_mod = (struct lys_module *)lys_get_import_module(module, NULL, 0, dev->target_name + 1,
+                                                                strcspn(dev->target_name, ":") - 1);
+        target_mod = (struct lys_module *)lys_implemented_module(target_mod);
+    }
+    lys_switch_deviation(dev, module);
+
+    /* clear the deviation flag if possible */
+    while ((mod = ly_ctx_get_module_iter(module->ctx, &idx))) {
+        if ((mod == module) || (mod == target_mod)) {
+            continue;
         }
-        lys_switch_deviation(&module->deviation[i], module);
 
-        /* clear the deviation flag if possible */
-        while ((mod = ly_ctx_get_module_iter(module->ctx, &i))) {
-            if ((mod == module) || (mod == target_mod)) {
-                continue;
-            }
-
-            for (j = 0; j < mod->deviation_size; ++j) {
-                ptr = strstr(mod->deviation[j].target_name, target_mod->name);
-                if (ptr && (ptr[strlen(target_mod->name)] == ':')) {
-                    /* some other module deviation targets the inspected module, flag remains */
-                    break;
-                }
-            }
-
-            if (j < mod->deviation_size) {
+        for (j = 0; j < mod->deviation_size; ++j) {
+            ptr = strstr(mod->deviation[j].target_name, target_mod->name);
+            if (ptr && (ptr[strlen(target_mod->name)] == ':')) {
+                /* some other module deviation targets the inspected module, flag remains */
                 break;
             }
         }
 
-        if (!mod) {
-            target_mod->deviated = 0;
+        if (j < mod->deviation_size) {
+            break;
         }
     }
 
+    if (!mod) {
+        target_mod->deviated = 0;
+    }
+}
+
+static void
+remove_aug(struct lys_node_augment *augment)
+{
+    struct lys_node *last, *elem;
+
+    if (!augment->target) {
+        /* skip not resolved augments */
+        return;
+    }
+
+    elem = augment->child;
+    if (elem) {
+        LY_TREE_FOR(elem, last) {
+            if (!last->next || (last->next->parent != (struct lys_node *)augment)) {
+                break;
+            }
+        }
+        /* elem is first augment child, last is the last child */
+
+        /* parent child ptr */
+        if (augment->target->child == elem) {
+            augment->target->child = last->next;
+        }
+
+        /* parent child next ptr */
+        if (elem->prev->next) {
+            elem->prev->next = last->next;
+        }
+
+        /* parent child prev ptr */
+        if (last->next) {
+            last->next->prev = elem->prev;
+        } else if (augment->target->child) {
+            augment->target->child->prev = elem->prev;
+        }
+
+        /* update augment children themselves */
+        elem->prev = last;
+        last->next = NULL;
+    }
+
+    /* needs to be NULL for lys_augment_free() to free the children */
+    augment->target = NULL;
+}
+
+void
+lys_sub_module_remove_devs_augs(struct lys_module *module)
+{
+    uint8_t u, v;
+
+    /* remove applied deviations */
+    for (u = 0; u < module->deviation_size; ++u) {
+        remove_dev(&module->deviation[u], module);
+    }
     /* remove applied augments */
-    for (i = 0; i < module->augment_size; ++i) {
-        if (!module->augment[i].target) {
-            /* skip not resolved augments */
-            continue;
+    for (u = 0; u < module->augment_size; ++u) {
+        remove_aug(&module->augment[u]);
+    }
+
+    /* remove deviation and augments defined in submodules */
+    for (v = 0; v < module->inc_size; ++v) {
+        for (u = 0; u < module->inc[v].submodule->deviation_size; ++u) {
+            remove_dev(&module->inc[v].submodule->deviation[u], module);
         }
 
-        elem = module->augment[i].child;
-        if (elem) {
-            LY_TREE_FOR(elem, last) {
-                if (!last->next || (last->next->parent != (struct lys_node *)&module->augment[i])) {
-                    break;
-                }
-            }
-            /* elem is first augment child, last is the last child */
-
-            /* parent child ptr */
-            if (module->augment[i].target->child == elem) {
-                module->augment[i].target->child = last->next;
-            }
-
-            /* parent child next ptr */
-            if (elem->prev->next) {
-                elem->prev->next = last->next;
-            }
-
-            /* parent child prev ptr */
-            if (last->next) {
-                last->next->prev = elem->prev;
-            } else if (module->augment[i].target->child) {
-                module->augment[i].target->child->prev = elem->prev;
-            }
-
-            /* update augment children themselves */
-            elem->prev = last;
-            last->next = NULL;
+        for (u = 0; u < module->inc[v].submodule->augment_size; ++u) {
+            remove_aug(&module->inc[v].submodule->augment[u]);
         }
-
-        /* needs to be NULL for lys_augment_free() to free the children */
-        module->augment[i].target = NULL;
     }
 }