Skip to content

Tags refactor #1250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
1b8554d
Get rid of TagRelationships
tomchop May 2, 2025
0165061
Fix some more tests
tomchop May 3, 2025
5a1d7e5
Bugfix
tomchop May 3, 2025
5034072
Update tagging event message publishing
tomchop May 3, 2025
6b6f946
Bugfix
tomchop May 3, 2025
4fd5f34
Migration script
tomchop May 3, 2025
6a6773b
Fix bug where tags were not getting cleared
tomchop May 3, 2025
ddef35d
Don't merge attributes when updating documents
tomchop May 3, 2025
66c1844
Add comment
tomchop May 3, 2025
13e95d1
Fix recursion loop
tomchop May 3, 2025
677e817
Remove tag graph query from neighbors
tomchop May 3, 2025
ad4cf51
[breaking] Move from objs to arrays for tags
tomchop May 4, 2025
ee0853c
Add analytics test to GH action
tomchop May 4, 2025
59e1705
Migration script
tomchop May 4, 2025
4501729
Fix more tests
tomchop May 4, 2025
9e901b5
Can't run these for some reason
tomchop May 4, 2025
c5af34a
Better error reporting
tomchop May 5, 2025
14e5347
Change the way indexes and view params are generated
tomchop May 5, 2025
e5a1816
Include all fields in indexing
tomchop May 5, 2025
af54554
Create generic view of all objects
tomchop May 5, 2025
fbdde54
Manually create collections at start
tomchop May 5, 2025
62f435d
Add some missing tables
tomchop May 5, 2025
afbca4f
Change type checking
tomchop May 6, 2025
d442bf8
Change from strict / clear in the schema
tomchop May 6, 2025
bd974d9
Account for removed tags
tomchop May 6, 2025
15400ff
Change test for typing
tomchop May 6, 2025
b4afeb5
Merge branch 'main' into tags
tomchop May 6, 2025
e72c702
Fix linting
tomchop May 6, 2025
d2adf77
Don't assume tags are being cleared when we're extending tags
tomchop May 6, 2025
d2be77c
We just added an event to deletion
tomchop May 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 24 additions & 212 deletions core/database_arango.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Relationship,
RelationshipTypes,
RoleRelationship,
TagRelationship,
)


Expand Down Expand Up @@ -422,24 +421,21 @@ def save(
Returns:
The created Yeti object.
"""
exclude = ["tags"] + self._exclude_overwrite
exclude = self._exclude_overwrite
doc_dict = self.model_dump(exclude_unset=True, exclude=exclude)
if doc_dict.get("id") is not None:
exclude = ["tags", "acls"] + self._exclude_overwrite
exclude = ["acls"] + self._exclude_overwrite
result = self._update(self.model_dump_json(exclude=exclude))
event_type = message.EventType.update
else:
exclude = ["tags", "acls", "id"] + self._exclude_overwrite
exclude = ["acls", "id"] + self._exclude_overwrite
result = self._insert(self.model_dump_json(exclude=exclude))
event_type = message.EventType.new
if not result:
exclude = exclude_overwrite + self._exclude_overwrite
result = self._update(self.model_dump_json(exclude=exclude))
event_type = message.EventType.update
yeti_object = self.__class__(**result)
# TODO: Override this if we decide to implement YetiTagModel
if hasattr(self, "tags"):
yeti_object.get_tags()
if self._collection_name not in ("auditlog", "timeline"):
try:
event = message.ObjectEvent(type=event_type, yeti_object=yeti_object)
Expand Down Expand Up @@ -522,191 +518,6 @@ def find(cls: Type[TYetiObject], **kwargs) -> TYetiObject | None:
document["__id"] = document.pop("_key")
return cls.load(document)

def tag(
    self: TYetiObject,
    tags: List[str],
    strict: bool = False,
    normalized: bool = True,
    expiration: datetime.timedelta | None = None,
) -> TYetiObject:
    """Connects object to tag graph.

    Every provided name is stripped, de-duplicated and normalized, then
    resolved to a Tag: a tag whose ``replaces`` list contains the normalized
    name takes precedence, otherwise the tag with that exact name is used and
    created if it does not exist yet. Tags listed in a resolved tag's
    ``produces`` are applied too, transitively.

    Args:
        tags: Tag names to apply. Blank names are ignored.
        strict: If True, all existing tags are cleared before tagging.
        normalized: Not used by this method; kept for interface compatibility.
        expiration: Lifetime of the explicitly provided tags. When falsy, each
            tag falls back to its own ``default_expiration``. Produced tags
            always use their own default.

    Returns:
        The object itself, with ``_tags`` updated for every linked tag.

    Raises:
        RuntimeError: If the object has not been saved, or a name normalizes
            to an empty string.
        ValueError: If ``tags`` is not a list, set or tuple.
    """
    # Import at runtime to avoid circular dependency.
    from core.schemas import tag

    if self.id is None:
        raise RuntimeError(
            "Cannot tag unsaved object, make sure to save() it first."
        )

    if not isinstance(tags, (list, set, tuple)):
        raise ValueError("Tags must be of type list, set or tuple.")

    pending = {t.strip() for t in tags if t.strip()}
    if strict:
        self.clear_tags()

    # Names already processed in this call. Tracking them lets produced tags
    # be expanded iteratively without looping forever when two tags produce
    # each other.
    handled: set[str] = set()
    requested_expiration = expiration
    while pending:
        handled |= pending
        produced: set[str] = set()
        for provided_tag_name in pending:
            tag_name = tag.normalize_name(provided_tag_name)
            if not tag_name:
                raise RuntimeError(
                    f"Cannot tag object with empty tag: '{provided_tag_name}' -> '{tag_name}'"
                )
            replacements, _ = tag.Tag.filter({"in__replaces": [tag_name]}, count=1)
            if replacements:
                new_tag = replacements[0]
            else:
                # Attempt to find the actual tag, creating it when missing.
                new_tag = tag.Tag.find(name=tag_name) or tag.Tag(name=tag_name).save()

            # Resolve the expiration per tag: reusing one variable across
            # iterations would leak the first tag's default onto the others.
            tag_expiration = requested_expiration or new_tag.default_expiration
            self._tags[new_tag.name] = self.link_to_tag(
                new_tag.name, expiration=tag_expiration
            )
            handled.add(new_tag.name)
            produced.update(new_tag.produces)

        pending = {t.strip() for t in produced if t.strip()} - handled
        # Produced tags never inherit the caller-supplied expiration.
        requested_expiration = None

    return self

def link_to_tag(
    self, tag_name: str, expiration: datetime.timedelta
) -> "TagRelationship":
    """Links a YetiObject to a Tag object, or refreshes the existing link.

    When the object already carries the tag, the existing edge gets a new
    ``last_seen`` and is marked fresh. Otherwise the tag is looked up (and
    created if missing), its ``count`` is incremented, and a new edge is
    inserted into the "tagged" edge collection.

    Args:
        tag_name: The name of the tag to link to.
        expiration: Added to the current UTC time to compute ``expires`` on a
            newly created edge. Not applied when an existing edge is refreshed.

    Returns:
        The refreshed or newly created TagRelationship.
    """
    # Import at runtime to avoid circular dependency.
    from core.schemas.graph import TagRelationship
    from core.schemas.tag import Tag

    tag_graph = self._db.graph("tags")

    current = next(
        (
            (relationship, linked_tag)
            for relationship, linked_tag in self.get_tags()
            if linked_tag.name == tag_name
        ),
        None,
    )

    if current is not None:
        existing_relationship, existing_tag = current
        existing_relationship.last_seen = datetime.datetime.now(
            datetime.timezone.utc
        )
        existing_relationship.fresh = True
        payload = json.loads(existing_relationship.model_dump_json())
        payload["_id"] = existing_relationship.id
        tag_graph.update_edge(payload)
        if self._collection_name not in ("auditlog", "timeline"):
            # Event publishing is best-effort; failures are only logged.
            try:
                producer.publish_event(
                    message.TagEvent(
                        type=message.EventType.update,
                        tagged_object=self,
                        tag_object=existing_tag,
                    )
                )
            except Exception:
                logging.exception("Error while publishing event")
        return existing_relationship

    # Relationship doesn't exist, check if tag is already in the db
    tag_obj = Tag.find(name=tag_name) or Tag(name=tag_name).save()
    tag_obj.count += 1
    tag_obj.save()

    new_relationship = TagRelationship(
        source=self.extended_id,
        target=tag_obj.extended_id,
        last_seen=datetime.datetime.now(datetime.timezone.utc),
        expires=datetime.datetime.now(datetime.timezone.utc) + expiration,
        fresh=True,
    )

    insert_job = tag_graph.edge_collection("tagged").link(
        self.extended_id,
        tag_obj.extended_id,
        data=json.loads(new_relationship.model_dump_json()),
        return_new=True,
    )
    # Poll the job until the database reports it as done.
    while True:
        if insert_job.status() == "done":
            break
        time.sleep(ASYNC_JOB_WAIT_TIME)

    document = insert_job.result()["new"]
    document["__id"] = document.pop("_key")
    if self._collection_name not in ("auditlog", "timeline"):
        try:
            producer.publish_event(
                message.TagEvent(
                    type=message.EventType.new,
                    tagged_object=self,
                    tag_object=tag_obj,
                )
            )
        except Exception:
            logging.exception("Error while publishing event")
    return TagRelationship.load(document)

def expire_tag(self, tag_name: str) -> "TagRelationship":
    """Expires a tag on an Observable by marking its edge as not fresh.

    Args:
        tag_name: The name of the tag to expire.

    Returns:
        The updated TagRelationship.

    Raises:
        ValueError: If the object does not carry a tag with that name.
    """
    tag_graph = self._db.graph("tags")

    relationship = next(
        (
            candidate
            for candidate, linked_tag in self.get_tags()
            if linked_tag.name == tag_name
        ),
        None,
    )
    if relationship is None:
        raise ValueError(
            f"Tag '{tag_name}' not found on observable '{self.extended_id}'"
        )

    relationship.fresh = False
    payload = json.loads(relationship.model_dump_json())
    payload["_id"] = relationship.id
    tag_graph.update_edge(payload)
    return relationship

def clear_tags(self):
    """Clears all tags on an Observable.

    Every edge of this object in the "tagged" edge collection is deleted. For
    objects outside the "auditlog" and "timeline" collections a delete
    TagEvent is published first; a failure to build or publish the event is
    logged and does not prevent the edge from being removed.
    """

    def _await_result(async_job):
        # Poll the job until the database reports it as done.
        while async_job.status() != "done":
            time.sleep(ASYNC_JOB_WAIT_TIME)
        return async_job.result()

    tag_graph = self._db.graph("tags")

    self.get_tags()
    listing = _await_result(tag_graph.edge_collection("tagged").edges(self.extended_id))
    for edge in listing["edges"]:
        if self._collection_name not in ("auditlog", "timeline"):
            try:
                tag_relationship = _await_result(
                    self._db.collection("tagged").get(edge["_id"])
                )
                tag_collection, tag_id = tag_relationship["target"].split("/")
                tag_obj = _await_result(self._db.collection(tag_collection).get(tag_id))
                producer.publish_event(
                    message.TagEvent(
                        type=message.EventType.delete,
                        tagged_object=self,
                        tag_object=tag_obj,
                    )
                )
            except Exception:
                logging.exception("Error while publishing event")

        delete_job = tag_graph.edge_collection("tagged").delete(edge["_id"])
        while delete_job.status() != "done":
            time.sleep(ASYNC_JOB_WAIT_TIME)

    self._tags = {}

def link_to(
self, target, relationship_type: str, description: str
) -> "Relationship":
Expand Down Expand Up @@ -1044,9 +855,7 @@ def _build_edges(self, arango_edges) -> List["RelationshipTypes"]:
edge["__id"] = edge.pop("_key")
edge["source"] = edge.pop("_from")
edge["target"] = edge.pop("_to")
if "tagged" in edge["_id"]:
relationships.append(graph.TagRelationship.load(edge))
elif "acls" in edge["_id"]:
if "acls" in edge["_id"]:
relationships.append(graph.RoleRelationship.load(edge))
else:
relationships.append(graph.Relationship.load(edge))
Expand Down Expand Up @@ -1184,16 +993,22 @@ def filter(
elif key in {"labels", "relevant_tags"}:
conditions.append(f"o.@arg{i}_key IN @arg{i}_value")
aql_args[f"arg{i}_key"] = key
elif key in ("created", "expires"):
elif key in ("created", "modified", "tag.expires"):
# Value is a string, we're checking the first character.
operator = value[0]
if operator not in ["<", ">"]:
operator = "="
else:
aql_args[f"arg{i}_value"] = value[1:]
filter_conditions.append(
f"DATE_TIMESTAMP(o.{key}) {operator}= DATE_TIMESTAMP(@arg{i}_value)"
)
sorts.append(f"o.{key}")
if key == "tag.expires":
filter_conditions.append(
f"VALUES(o.tags)[* RETURN DATE_TIMESTAMP(CURRENT.expires)] ANY {operator} DATE_TIMESTAMP(@arg{i}_value)"
)
else:
filter_conditions.append(
f"DATE_TIMESTAMP(o.{key}) {operator}= DATE_TIMESTAMP(@arg{i}_value)"
)
sorts.append(f"o.{key}")
elif key in ("name", "value"):
if using_view and not using_regex:
aql_args[f"arg{i}_value"] = f"%{value}%"
Expand Down Expand Up @@ -1262,7 +1077,7 @@ def filter(
tag_filter_query = ""
if tag_filter:
tag_filter_query = (
" FILTER COUNT(INTERSECTION(ATTRIBUTES(MERGE(tags)), @tag_names)) > 0"
" FILTER COUNT(INTERSECTION(ATTRIBUTES(o.tags), @tag_names)) > 0"
)
aql_args["tag_names"] = tag_filter

Expand Down Expand Up @@ -1435,22 +1250,19 @@ def _get_collection(cls):

def tagged_observables_export(cls, args):
aql = """
WITH tags

FOR o in observables
FILTER (o.type IN @acts_on OR @acts_on == [])
LET tags = MERGE(
FOR v, e in 1..1 OUTBOUND o tagged
FILTER v.name NOT IN @ignore
FILTER (e.fresh OR NOT @fresh)
RETURN {[v.name]: MERGE(e, {id: e._id})}
FILTER o.tags != {}
LET tagnames = (
FOR t IN VALUES(o.tags)
FILTER t.name NOT IN @ignore
FILTER (t.fresh OR NOT @fresh)
RETURN t.name
)
FILTER tags != {}
LET tagnames = ATTRIBUTES(tags)

FILTER COUNT(tagnames) > 0
FILTER COUNT(INTERSECTION(tagnames, @include)) > 0 OR @include == []
FILTER COUNT(INTERSECTION(tagnames, @exclude)) == 0
RETURN MERGE(o, {tags: tags})
RETURN o
"""
documents = db.aql.execute(aql, bind_vars=args, count=True, full_count=True)
results = []
Expand Down
1 change: 0 additions & 1 deletion core/events/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def yeti_object_discriminator(v):
Annotated["tag.Tag", PydanticTag("tag")],
Annotated["template.Template", PydanticTag("template")],
Annotated["graph.Relationship", PydanticTag("relationship")],
Annotated["graph.TagRelationship", PydanticTag("tag_relationship")],
Annotated["rbac.Group", PydanticTag("rbacgroup")],
],
Field(discriminator=Discriminator(yeti_object_discriminator)),
Expand Down
35 changes: 1 addition & 34 deletions core/schemas/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,39 +54,6 @@ def load(cls, object: dict):
return cls(**object)


# Pydantic model for an edge document stored in the "tagged" collection.
# NOTE(review): this whole class is removed by this pull request.
class TagRelationship(BaseModel, database_arango.ArangoYetiConnector):
# Strip leading/trailing whitespace from every string field on validation.
model_config = ConfigDict(str_strip_whitespace=True)
# No extra fields are excluded when the document is overwritten on save.
_exclude_overwrite: list[str] = list()
# Name of the collection these documents are stored in.
_collection_name: ClassVar[str] = "tagged"
# Value exposed through the `root_type` computed field below.
_root_type: Literal["tag_relationship"] = "tag_relationship"
_type_filter: None = None
# Database key of the edge; filled in __init__ from the "__id" entry.
__id: str | None = None

# Identifiers of the two endpoints of the edge.
# NOTE(review): presumably source is the tagged object and target the tag — confirm against callers.
source: str
target: str
# Timestamp of the most recent time the tag was applied.
last_seen: datetime.datetime
# Optional expiry timestamp; None when not set.
expires: datetime.datetime | None = None
# Whether the tag is currently considered fresh.
fresh: bool

def __init__(self, **data):
super().__init__(**data)
# "__id" is not a declared field, so it is copied over by hand.
self.__id = data.get("__id", None)

# Serialized as part of the model output.
@computed_field(return_type=Literal["tag_relationship"])
@property
def root_type(self):
return self._root_type

# Exposes the private database key as a read-only `id` field.
@computed_field(return_type=str)
@property
def id(self):
return self.__id

# Builds an instance from a raw document dict (expects "__id" rather than "_key").
@classmethod
def load(cls, object: dict):
return cls(**object)


class RoleRelationship(BaseModel, database_arango.ArangoYetiConnector):
model_config = ConfigDict(str_strip_whitespace=True)
_exclude_overwrite: list[str] = list()
Expand Down Expand Up @@ -141,4 +108,4 @@ def has_permissions(
return False


RelationshipTypes = Relationship | TagRelationship | RoleRelationship
RelationshipTypes = Relationship | RoleRelationship
Loading
Loading