diff --git a/core/dbt/artifacts/resources/__init__.py b/core/dbt/artifacts/resources/__init__.py index a8aecfd9990..d55ac1c411e 100644 --- a/core/dbt/artifacts/resources/__init__.py +++ b/core/dbt/artifacts/resources/__init__.py @@ -6,6 +6,7 @@ from dbt.artifacts.resources.v1.components import ( ColumnInfo, CompiledResource, + ConceptArgs, Contract, DeferRelation, DependsOn, @@ -19,6 +20,12 @@ RefArgs, Time, ) +from dbt.artifacts.resources.v1.concept import ( + Concept, + ConceptColumn, + ConceptConfig, + ConceptJoin, +) from dbt.artifacts.resources.v1.config import ( Hook, NodeAndTestConfig, diff --git a/core/dbt/artifacts/resources/types.py b/core/dbt/artifacts/resources/types.py index 838104ea7b5..eeef566b992 100644 --- a/core/dbt/artifacts/resources/types.py +++ b/core/dbt/artifacts/resources/types.py @@ -30,6 +30,7 @@ class NodeType(StrEnum): Macro = "macro" Exposure = "exposure" Metric = "metric" + Concept = "concept" Group = "group" SavedQuery = "saved_query" SemanticModel = "semantic_model" diff --git a/core/dbt/artifacts/resources/v1/components.py b/core/dbt/artifacts/resources/v1/components.py index ec2c6cc828c..f8fe16d184b 100644 --- a/core/dbt/artifacts/resources/v1/components.py +++ b/core/dbt/artifacts/resources/v1/components.py @@ -69,6 +69,22 @@ def keyword_args(self) -> Dict[str, Optional[NodeVersion]]: return {} +@dataclass +class ConceptArgs(dbtClassMixin): + """Arguments for referencing a concept""" + + name: str + package: Optional[str] = None + columns: List[str] = field(default_factory=list) + + @property + def positional_args(self) -> List[str]: + if self.package: + return [self.package, self.name] + else: + return [self.name] + + @dataclass class ColumnInfo(AdditionalPropertiesMixin, ExtensibleDbtClassMixin): """Used in all ManifestNodes and SourceDefinition""" @@ -241,6 +257,7 @@ class CompiledResource(ParsedResource): refs: List[RefArgs] = field(default_factory=list) sources: List[List[str]] = field(default_factory=list) metrics: List[List[str]] = field(default_factory=list) + concepts: List[ConceptArgs] = field(default_factory=list) # For tracking concept dependencies depends_on: DependsOn = field(default_factory=DependsOn) compiled_path: Optional[str] = None compiled: bool = False diff --git a/core/dbt/artifacts/resources/v1/concept.py b/core/dbt/artifacts/resources/v1/concept.py new file mode 100644 index 00000000000..cbcac4a6b5a --- /dev/null +++ b/core/dbt/artifacts/resources/v1/concept.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union + +from dbt.artifacts.resources.base import GraphResource +from dbt.artifacts.resources.v1.components import DependsOn +from dbt_common.dataclass_schema import dbtClassMixin + + +@dataclass +class ConceptJoin(dbtClassMixin): + """Represents a join relationship in a concept definition.""" + + name: str # name of the model or concept to join + base_key: str # column in base model for join + foreign_key: Optional[str] = None # column in joined model (defaults to primary_key) + alias: Optional[str] = None # alias for the joined table + columns: List[str] = field(default_factory=list) # columns to expose from join + join_type: str = "left" # type of join (left, inner, etc.) + + +@dataclass +class ConceptColumn(dbtClassMixin): + """Represents a column definition in a concept.""" + + name: str + description: Optional[str] = None + alias: Optional[str] = None # optional alias for the column + + +@dataclass +class ConceptConfig(dbtClassMixin): + """Configuration for a concept.""" + + enabled: bool = True + meta: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class Concept(GraphResource): + """A concept resource definition.""" + + name: str + base_model: str # reference to the base model + description: str = "" + primary_key: Union[str, List[str]] = "id" # primary key column(s) + columns: List[ConceptColumn] = field(default_factory=list) + joins: List[ConceptJoin] = field(default_factory=list) + config: ConceptConfig = field(default_factory=ConceptConfig) + meta: Dict[str, Any] = field(default_factory=dict) + tags: List[str] = field(default_factory=list) + depends_on: DependsOn = field(default_factory=DependsOn) + + +# Type alias for concept resource +ConceptResource = Concept diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 4f96bc54640..68ce3a77561 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -1,6 +1,7 @@ import abc import os from copy import deepcopy +from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, @@ -799,6 +800,183 @@ def resolve(self, target_name: str, target_package: Optional[str] = None) -> Met return ResolvedMetricReference(target_metric, self.manifest) +# `cref` implementations. +@dataclass +class ConceptReference: + name: str + package: Optional[str] = None + columns: Optional[List[str]] = None + + def __post_init__(self): + if self.columns is None: + self.columns = [] + + +class BaseConceptResolver: + def __init__( + self, + db_wrapper: BaseDatabaseWrapper, + model: Resource, + config: RuntimeConfig, + manifest: Manifest, + ) -> None: + self.db_wrapper = db_wrapper + self.model = model + self.config = config + self.manifest = manifest + self.current_project = config.project_name + self.Relation = db_wrapper.Relation + + def __call__( + self, concept_name: str, columns: List[str], package: Optional[str] = None + ) -> str: + """Entry point for cref() calls from Jinja templates.""" + return self.resolve(concept_name, columns, package) + + @abc.abstractmethod + def resolve(self, concept_name: str, columns: List[str], package: Optional[str] = None) -> str: + """Abstract method to resolve concept references.""" + pass + + def _repack_args( + self, name: str, package: Optional[str], columns: Optional[List[str]] + ) -> ConceptReference: + return ConceptReference(name, package, columns) + + +class ParseConceptResolver(BaseConceptResolver): + def resolve(self, name: str, columns: List[str], package: Optional[str] = None) -> str: + from dbt.artifacts.resources import ConceptArgs + + # During parsing, we just track the dependency and return a placeholder + concept_args = ConceptArgs(name=name, package=package, columns=columns) + + # Only nodes that inherit from CompiledResource have the concepts attribute + if hasattr(self.model, "concepts"): + self.model.concepts.append(concept_args) + + # Return a placeholder that will be replaced during compilation + return f"/* cref placeholder for {name} */" + + +class RuntimeConceptResolver(BaseConceptResolver): + def resolve(self, concept_name: str, columns: List[str], package: Optional[str] = None) -> str: + # Resolve the concept from the manifest + target_concept = self.manifest.resolve_concept( + concept_name, + package, + self.current_project, + self.model.package_name, + ) + + if target_concept is None: + raise TargetNotFoundError( + node=self.model, + target_name=concept_name, + target_kind="concept", + target_package=package, + ) + + # Generate the SQL for the concept reference + return self._generate_concept_sql(target_concept, columns) + + def _generate_concept_sql(self, concept, requested_columns: List[str]) -> str: + """Generate the SQL subquery for a concept reference.""" + + # Validate that all requested columns are available in the concept + available_columns = self._get_available_columns(concept) + for col in requested_columns: + if col not in available_columns: + raise CompilationError( + f"Column '{col}' is not available in concept '{concept.name}'. " + f"Available columns: {', '.join(sorted(available_columns.keys()))}" + ) + + # Determine which joins are needed based on requested columns + required_joins = self._determine_required_joins(concept, requested_columns) + + # Build the SQL + sql_parts = [] + + # SELECT clause + select_columns = [] + for col in requested_columns: + column_info = available_columns[col] + if column_info["source"] == "base": + select_columns.append(f"base.{col}") + else: + alias = column_info["alias"] + select_columns.append(f"{alias}.{col}") + + sql_parts.append("SELECT") + sql_parts.append(" " + ",\n ".join(select_columns)) + + # FROM clause (base model) + base_ref = f"{{{{ref('{concept.base_model}')}}}}" + sql_parts.append(f"FROM {base_ref} AS base") + + # JOIN clauses + for join in required_joins: + join_sql = self._generate_join_sql(join, concept) + sql_parts.append(join_sql) + + return "(\n" + "\n".join(sql_parts) + "\n)" + + def _get_available_columns(self, concept) -> Dict[str, Dict[str, str]]: + """Get all available columns from the concept and its joins.""" + available = {} + + # Add base model columns + for col in concept.columns: + available[col.name] = {"source": "base", "alias": "base", "original_name": col.name} + + # Add columns from joins + for join in concept.joins: + alias = join.alias or join.name + for col in join.columns: + available[col.name] = {"source": "join", "alias": alias, "original_name": col.name} + + return available + + def _determine_required_joins(self, concept, requested_columns: List[str]) -> List: + """Determine which joins are needed for the requested columns.""" + available_columns = self._get_available_columns(concept) + needed_joins = set() + + for col in requested_columns: + column_info = available_columns[col] + if column_info["source"] == "join": + # Find the join that provides this column + for join in concept.joins: + alias = join.alias or join.name + if alias == column_info["alias"]: + needed_joins.add(id(join)) # Use id to ensure uniqueness + break + + # Return the actual join objects + required_joins = [] + for join in concept.joins: + if id(join) in needed_joins: + required_joins.append(join) + + return required_joins + + def _generate_join_sql(self, join, concept) -> str: + """Generate SQL for a single join to a model.""" + join_alias = join.alias or join.name + foreign_key = join.foreign_key or concept.primary_key + + # Handle model references - joins only support models, not other concepts + if join.name.startswith("ref("): + # Direct model reference (e.g., "ref('stg_customers')") + join_ref = join.name + else: + # Model name that needs ref() wrapping (e.g., "stg_customers") + join_ref = f"{{{{ref('{join.name}')}}}}" + + return f"LEFT JOIN {join_ref} AS {join_alias} ON base.{join.base_key} = {join_alias}.{foreign_key}" + + # `var` implementations. class ModelConfiguredVar(Var): def __init__( @@ -871,6 +1049,7 @@ class Provider(Protocol): ref: Type[BaseRefResolver] source: Type[BaseSourceResolver] metric: Type[BaseMetricResolver] + cref: Type[BaseConceptResolver] class ParseProvider(Provider): @@ -881,6 +1060,7 @@ class ParseProvider(Provider): ref = ParseRefResolver source = ParseSourceResolver metric = ParseMetricResolver + cref = ParseConceptResolver class GenerateNameProvider(Provider): @@ -891,6 +1071,7 @@ class GenerateNameProvider(Provider): ref = ParseRefResolver source = ParseSourceResolver metric = ParseMetricResolver + cref = ParseConceptResolver class RuntimeProvider(Provider): @@ -901,6 +1082,7 @@ class RuntimeProvider(Provider): ref = RuntimeRefResolver source = RuntimeSourceResolver metric = RuntimeMetricResolver + cref = RuntimeConceptResolver class RuntimeUnitTestProvider(Provider): @@ -911,6 +1093,7 @@ class RuntimeUnitTestProvider(Provider): ref = RuntimeUnitTestRefResolver source = RuntimeUnitTestSourceResolver metric = RuntimeMetricResolver + cref = RuntimeConceptResolver class OperationProvider(RuntimeProvider): @@ -1153,6 +1336,21 @@ def source(self) -> Callable: def metric(self) -> Callable: return self.provider.metric(self.db_wrapper, self.model, self.config, self.manifest) + @contextproperty() + def cref(self) -> Callable: + """The `cref()` function allows you to reference a concept and select + specific columns from it. A concept defines a base model and its + joinable features, allowing for dynamic SQL generation based on the + columns you request. + + Usage: + select * from {{ cref('orders', ['order_id', 'customer_name', 'total_amount']) }} + + This will generate a subquery that includes only the necessary joins + to provide the requested columns from the 'orders' concept. + """ + return self.provider.cref(self.db_wrapper, self.model, self.config, self.manifest) + @contextproperty("config") def ctx_config(self) -> Config: """The `config` variable exists to handle end-user configuration for diff --git a/core/dbt/contracts/files.py b/core/dbt/contracts/files.py index 15e951e026c..ca4b800086a 100644 --- a/core/dbt/contracts/files.py +++ b/core/dbt/contracts/files.py @@ -192,6 +192,7 @@ class SchemaSourceFile(BaseSourceFile): sources: List[str] = field(default_factory=list) exposures: List[str] = field(default_factory=list) metrics: List[str] = field(default_factory=list) + concepts: List[str] = field(default_factory=list) snapshots: List[str] = field(default_factory=list) # The following field will no longer be used. Leaving # here to avoid breaking existing projects. To be removed diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index e53ae1a48b1..d6338c0c337 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -55,6 +55,7 @@ ManifestNode, Metric, ModelNode, + ParsedConcept, SavedQuery, SeedNode, SemanticModel, @@ -686,6 +687,7 @@ class Disabled(Generic[D]): MaybeSavedQueryNode = Optional[Union[SavedQuery, Disabled[SavedQuery]]] +MaybeConceptNode = Optional[Union[ParsedConcept, Disabled[ParsedConcept]]] MaybeDocumentation = Optional[Documentation] @@ -878,6 +880,7 @@ class Manifest(MacroMethods, dbtClassMixin): docs: MutableMapping[str, Documentation] = field(default_factory=dict) exposures: MutableMapping[str, Exposure] = field(default_factory=dict) metrics: MutableMapping[str, Metric] = field(default_factory=dict) + concepts: MutableMapping[str, ParsedConcept] = field(default_factory=dict) groups: MutableMapping[str, Group] = field(default_factory=dict) selectors: MutableMapping[str, Any] = field(default_factory=dict) files: MutableMapping[str, AnySourceFile] = field(default_factory=dict) @@ -1416,6 +1419,34 @@ def resolve_metric( return Disabled(disabled[0]) return None + def resolve_concept( + self, + target_concept_name: str, + target_concept_package: Optional[str], + current_project: str, + node_package: str, + ) -> MaybeConceptNode: + disabled = None + + candidates = _packages_to_search(current_project, node_package, target_concept_package) + for pkg in candidates: + # Look for concept in the concepts dictionary + for concept_unique_id, concept_obj in self.concepts.items(): + if ( + concept_obj.name == target_concept_name + and concept_obj.package_name == pkg + and concept_obj.config.enabled + ): + return concept_obj + + # Check if it's disabled + if disabled is None: + disabled = self.disabled_lookup.find(f"{target_concept_name}", pkg) + + if disabled: + return Disabled(disabled[0]) + return None + def resolve_saved_query( self, target_saved_query_name: str, @@ -1646,6 +1677,11 @@ def add_group(self, source_file: SchemaSourceFile, group: Group): self.groups[group.unique_id] = group source_file.groups.append(group.unique_id) + def add_concept(self, source_file: SchemaSourceFile, concept: ParsedConcept): + _check_duplicates(concept, self.concepts) + self.concepts[concept.unique_id] = concept + source_file.concepts.append(concept.unique_id) + def add_disabled_nofile(self, node: GraphMemberNode): # There can be multiple disabled nodes for the same unique_id if node.unique_id in self.disabled: diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 6ae96084f3a..2ff31451573 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -22,13 +22,9 @@ from dbt.adapters.base import ConstraintSupport from dbt.adapters.factory import get_adapter_constraint_support from dbt.artifacts.resources import Analysis as AnalysisResource -from dbt.artifacts.resources import ( - BaseResource, - ColumnInfo, - CompiledResource, - DependsOn, - Docs, -) +from dbt.artifacts.resources import BaseResource, ColumnInfo, CompiledResource +from dbt.artifacts.resources import Concept as ConceptResource +from dbt.artifacts.resources import DependsOn, Docs from dbt.artifacts.resources import Documentation as DocumentationResource from dbt.artifacts.resources import Exposure as ExposureResource from dbt.artifacts.resources import FileHash @@ -1536,6 +1532,58 @@ def to_logging_dict(self) -> Dict[str, Union[str, Dict[str, str]]]: } +# ==================================== +# Concept node +# ==================================== + + +@dataclass +class ParsedConcept(GraphNode, ConceptResource): + """A parsed concept that defines a base model and its joinable features.""" + + @property + def depends_on_nodes(self): + return self.depends_on.nodes + + @property + def search_name(self): + return self.name + + @classmethod + def resource_class(cls) -> Type[ConceptResource]: + return ConceptResource + + def same_description(self, old: "ParsedConcept") -> bool: + return self.description == old.description + + def same_base_model(self, old: "ParsedConcept") -> bool: + return self.base_model == old.base_model + + def same_primary_key(self, old: "ParsedConcept") -> bool: + return self.primary_key == old.primary_key + + def same_joins(self, old: "ParsedConcept") -> bool: + return self.joins == old.joins + + def same_columns(self, old: "ParsedConcept") -> bool: + return self.columns == old.columns + + def same_config(self, old: "ParsedConcept") -> bool: + return self.config == old.config + + def same_contents(self, other: Optional["ParsedConcept"]) -> bool: + if other is None: + return False + return ( + self.same_description(other) + and self.same_base_model(other) + and self.same_primary_key(other) + and self.same_joins(other) + and self.same_columns(other) + and self.same_config(other) + ) + + # ==================================== # SemanticModel node # ==================================== @@ -1752,6 +1800,7 @@ class ParsedSingularTestPatch(ParsedPatch): ResultNode, Exposure, Metric, + ParsedConcept, SavedQuery, SemanticModel, UnitTestDefinition, diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py index 46c56f72482..58a2a0ff098 100644 --- a/core/dbt/contracts/graph/unparsed.py +++ b/core/dbt/contracts/graph/unparsed.py @@ -792,3 +792,49 @@ def validate(cls, data): if data.get("versions", None): if data["versions"].get("include") and data["versions"].get("exclude"): raise ValidationError("Unit tests can not both include and exclude versions.") + + +@dataclass +class UnparsedConceptJoin(dbtClassMixin): + """Represents an unparsed join relationship in a concept definition.""" + + name: str # name of the model or concept to join + base_key: str # column in base model for join + foreign_key: Optional[str] = None # column in joined model (defaults to primary_key) + alias: Optional[str] = None # alias for the joined table + columns: List[str] = field(default_factory=list) # columns to expose from join + join_type: str = "left" # type of join (left, inner, etc.) + + +@dataclass +class UnparsedConceptColumn(dbtClassMixin): + """Represents an unparsed column definition in a concept.""" + + name: str + description: Optional[str] = None + alias: Optional[str] = None # optional alias for the column + + +@dataclass +class UnparsedConcept(dbtClassMixin): + """Represents an unparsed concept definition.""" + + name: str + base_model: str # reference to the base model + description: str = "" + primary_key: Union[str, List[str]] = "id" # primary key column(s) + columns: List[Union[str, UnparsedConceptColumn]] = field(default_factory=list) + joins: List[UnparsedConceptJoin] = field(default_factory=list) + config: Dict[str, Any] = field(default_factory=dict) + meta: Dict[str, Any] = field(default_factory=dict) + tags: List[str] = field(default_factory=list) + + @classmethod + def validate(cls, data): + super(UnparsedConcept, cls).validate(data) + if "name" in data: + # name can only contain alphanumeric chars and underscores + if not (re.match(r"[\w-]+$", data["name"])): + raise ParsingError( + f"Invalid concept name '{data['name']}'. Names must contain only letters, numbers, and underscores." + ) diff --git a/core/dbt/parser/schema_yaml_readers.py b/core/dbt/parser/schema_yaml_readers.py index 050b2695fdf..d6c298c62fc 100644 --- a/core/dbt/parser/schema_yaml_readers.py +++ b/core/dbt/parser/schema_yaml_readers.py @@ -2,6 +2,9 @@ from typing import Any, Dict, List, Optional, Union from dbt.artifacts.resources import ( + ConceptColumn, + ConceptConfig, + ConceptJoin, ConversionTypeParams, CumulativeTypeParams, Dimension, @@ -34,8 +37,16 @@ generate_parse_semantic_models, ) from dbt.contracts.files import SchemaSourceFile -from dbt.contracts.graph.nodes import Exposure, Group, Metric, SavedQuery, SemanticModel +from dbt.contracts.graph.nodes import ( + Exposure, + Group, + Metric, + ParsedConcept, + SavedQuery, + SemanticModel, +) from dbt.contracts.graph.unparsed import ( + UnparsedConcept, UnparsedConversionTypeParams, UnparsedCumulativeTypeParams, UnparsedDimension, @@ -192,6 +203,84 @@ def parse(self) -> None: self.parse_exposure(unparsed) +class ConceptParser(YamlReader): + def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock) -> None: + super().__init__(schema_parser, yaml, NodeType.Concept.pluralize()) + self.schema_parser = schema_parser + self.yaml = yaml + + def parse_concept(self, unparsed: UnparsedConcept) -> None: + package_name = self.project.project_name + unique_id = f"{NodeType.Concept}.{package_name}.{unparsed.name}" + path = self.yaml.path.relative_path + + fqn = self.schema_parser.get_fqn_prefix(path) + fqn.append(unparsed.name) + + # Convert unparsed columns to ConceptColumn objects + columns = [] + for col in unparsed.columns: + if isinstance(col, str): + columns.append(ConceptColumn(name=col)) + else: + columns.append( + ConceptColumn(name=col.name, description=col.description, alias=col.alias) + ) + + # Convert unparsed joins to ConceptJoin objects + joins = [] + for join in unparsed.joins: + joins.append( + ConceptJoin( + name=join.name, + base_key=join.base_key, + foreign_key=join.foreign_key, + alias=join.alias, + columns=join.columns, + join_type=join.join_type, + ) + ) + + config = ConceptConfig( + enabled=unparsed.config.get("enabled", True), meta=unparsed.config.get("meta", {}) + ) + + # Create the parsed concept + concept = ParsedConcept( + package_name=package_name, + path=path, + original_file_path=self.yaml.path.original_file_path, + unique_id=unique_id, + fqn=fqn, + resource_type=NodeType.Concept, + name=unparsed.name, + description=unparsed.description, + base_model=unparsed.base_model, + primary_key=unparsed.primary_key, + columns=columns, + joins=joins, + config=config, + meta=unparsed.meta, + tags=unparsed.tags, + ) + + # Add to manifest + from dbt.contracts.files import SchemaSourceFile + + if isinstance(self.yaml.file, SchemaSourceFile): + self.manifest.add_concept(self.yaml.file, concept) + + def parse(self) -> None: + for data in self.get_key_dicts(): + try: + UnparsedConcept.validate(data) + unparsed = UnparsedConcept.from_dict(data) + except (ValidationError, JSONValidationError) as exc: + raise YamlParseDictError(self.yaml.path, self.key, data, exc) + + self.parse_concept(unparsed) + + class MetricParser(YamlReader): def __init__(self, schema_parser: SchemaParser, yaml: YamlBlock) -> None: super().__init__(schema_parser, yaml, NodeType.Metric.pluralize()) diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index 71aa5f6bbb2..b20dff0c5b8 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -293,6 +293,13 @@ def parse_file(self, block: FileBlock, dct: Optional[Dict] = None) -> None: group_parser = GroupParser(self, yaml_block) group_parser.parse() + # ConceptParser.parse() + if "concepts" in dct: + from dbt.parser.schema_yaml_readers import ConceptParser + + concept_parser = ConceptParser(self, yaml_block) + concept_parser.parse() + if "semantic_models" in dct: from dbt.parser.schema_yaml_readers import SemanticModelParser diff --git a/tests/functional/concepts/__init__.py b/tests/functional/concepts/__init__.py new file mode 100644 index 00000000000..90c287543ad --- /dev/null +++ b/tests/functional/concepts/__init__.py @@ -0,0 +1 @@ +# Concepts functional tests diff --git a/tests/functional/concepts/fixtures.py b/tests/functional/concepts/fixtures.py new file mode 100644 index 00000000000..54178890f44 --- /dev/null +++ b/tests/functional/concepts/fixtures.py @@ -0,0 +1,151 @@ +"""Fixtures for concept functional tests.""" + +# Basic concept definition with joins +basic_concept_yml = """ +version: 2 + +concepts: + - name: orders + description: "Orders concept with customer data" + base_model: stg_orders + primary_key: order_id + columns: + - name: order_id + description: "Primary key for orders" + - name: customer_id + description: "Foreign key to customers" + - name: order_date + description: "Date when order was placed" + - name: status + description: "Order status" + joins: + - name: stg_customers + base_key: customer_id + foreign_key: id + alias: customer + columns: + - name: customer_name + description: "Customer name" + - name: email + description: "Customer email" +""" + +# Base staging models +stg_orders_sql = """ +select * from {{ ref('raw_orders') }} +""" + +stg_customers_sql = """ +select * from {{ ref('raw_customers') }} +""" + +# Model using cref +orders_report_sql = """ +select + order_id, + order_date, + customer_name +from {{ cref('orders', ['order_id', 'order_date', 'customer_name']) }} +where order_date >= '2023-01-01' +""" + +# Seed data +raw_orders_csv = """order_id,customer_id,order_date,status +1,1,2023-01-01,completed +2,2,2023-01-02,pending +3,1,2023-01-03,completed +4,3,2023-01-04,cancelled +""" + +raw_customers_csv = """id,customer_name,email +1,Alice,alice@example.com +2,Bob,bob@example.com +3,Charlie,charlie@example.com +""" + +# Concept with only base columns (no joins) +simple_concept_yml = """ +version: 2 + +concepts: + - name: simple_orders + description: "Simple orders concept with only base columns" + base_model: stg_orders + primary_key: order_id + columns: + - name: order_id + - name: customer_id + - name: order_date + - name: status +""" + +# Invalid concept with missing base_model +invalid_concept_yml = """ +version: 2 + +concepts: + - name: invalid_orders + description: "Invalid concept" + columns: + - name: order_id +""" + +# Concept with multiple joins +multi_join_concept_yml = """ +version: 2 + +concepts: + - name: enriched_orders + description: "Orders with customer and product data" + base_model: stg_orders + primary_key: order_id + columns: + - name: order_id + - name: customer_id + - name: order_date + - name: status + joins: + - name: stg_customers + base_key: customer_id + foreign_key: id + alias: customer + columns: + - name: customer_name + - name: email + - name: stg_products + base_key: product_id + foreign_key: id + alias: product + columns: + - name: product_name + - name: price +""" + +# Additional staging model for multi-join test +stg_products_sql = """ +select * from {{ ref('raw_products') }} +""" + +# Additional seed for multi-join test +raw_products_csv = """id,product_name,price +1,Widget,10.00 +2,Gadget,20.00 +3,Doohickey,15.00 +""" + +# Model using multi-join concept with partial columns +partial_join_model_sql = """ +select + order_id, + customer_name, + product_name +from {{ cref('enriched_orders', ['order_id', 'customer_name', 'product_name']) }} +""" + +# Model using only base columns (should generate no joins) +base_only_model_sql = """ +select + order_id, + order_date +from {{ cref('orders', ['order_id', 'order_date']) }} +""" diff --git a/tests/functional/concepts/test_concepts.py b/tests/functional/concepts/test_concepts.py new file mode 100644 index 00000000000..d0f16207d53 --- /dev/null +++ b/tests/functional/concepts/test_concepts.py @@ -0,0 +1,228 @@ +import pytest + +from dbt.cli.main import dbtRunner +from dbt.contracts.graph.manifest import Manifest +from dbt.exceptions import CompilationError, ParsingError +from dbt.tests.util import check_relations_equal, get_manifest, run_dbt +from tests.functional.concepts.fixtures import ( + base_only_model_sql, + basic_concept_yml, + invalid_concept_yml, + multi_join_concept_yml, + orders_report_sql, + partial_join_model_sql, + raw_customers_csv, + raw_orders_csv, + raw_products_csv, + simple_concept_yml, + stg_customers_sql, + stg_orders_sql, + stg_products_sql, +) + + +class TestBasicConcepts: + @pytest.fixture(scope="class") + def models(self): + return { + "concept_schema.yml": basic_concept_yml, + "stg_orders.sql": stg_orders_sql, + "stg_customers.sql": stg_customers_sql, + "orders_report.sql": orders_report_sql, + } + + @pytest.fixture(scope="class") + def seeds(self): + return { + "raw_orders.csv": raw_orders_csv, + "raw_customers.csv": raw_customers_csv, + } + + def test_parse_basic_concept(self, project): + """Test that a basic concept definition can be parsed.""" + runner = dbtRunner() + result = runner.invoke(["parse"]) + assert result.success + assert isinstance(result.result, Manifest) + + manifest = get_manifest(project.project_root) + + # Check that concept was parsed and stored in manifest + assert "concept.test.orders" in manifest.concepts + concept = manifest.concepts["concept.test.orders"] + + # Verify concept properties + assert concept.name == "orders" + assert concept.base_model == "stg_orders" + assert concept.primary_key == "order_id" + assert len(concept.columns) == 4 # order_id, customer_id, order_date, status + assert len(concept.joins) == 1 # stg_customers join + + # Verify join properties + join = concept.joins[0] + assert join.name == "stg_customers" + assert join.base_key == "customer_id" + assert join.foreign_key == "id" + assert join.alias == "customer" + assert len(join.columns) == 2 # customer_name, email + + def test_compile_cref_usage(self, project): + """Test that models using cref can be compiled.""" + runner = dbtRunner() + result = runner.invoke(["parse"]) + assert result.success + + # Compile the project + result = runner.invoke(["compile"]) + assert result.success + + manifest = get_manifest(project.project_root) + + # Check that the orders_report model was compiled + assert "model.test.orders_report" in manifest.nodes + compiled_node = manifest.nodes["model.test.orders_report"] + + # Verify that dependencies were tracked + expected_deps = {"model.test.stg_orders", "model.test.stg_customers"} + assert set(compiled_node.depends_on.nodes) == expected_deps + + def test_cref_sql_generation(self, project): + """Test that cref generates correct SQL.""" + runner = dbtRunner() + result = runner.invoke(["compile"]) + assert result.success + + manifest = get_manifest(project.project_root) + compiled_node = manifest.nodes["model.test.orders_report"] + + # The compiled SQL should contain JOIN logic + compiled_sql = compiled_node.compiled_code + + # Basic checks that the SQL was expanded + assert "SELECT" in compiled_sql.upper() + assert "FROM" in compiled_sql.upper() + assert "LEFT JOIN" in compiled_sql.upper() + + # Should reference the base and joined models + assert "stg_orders" in compiled_sql + assert "stg_customers" in compiled_sql + + +class TestSimpleConcepts: + @pytest.fixture(scope="class") + def models(self): + return { + "simple_concept_schema.yml": simple_concept_yml, + "stg_orders.sql": stg_orders_sql, + "base_only.sql": base_only_model_sql, + } + + @pytest.fixture(scope="class") + def seeds(self): + return { + "raw_orders.csv": raw_orders_csv, + } + + def test_concept_with_no_joins(self, project): + """Test concept that has no joins (only base columns).""" + runner = dbtRunner() + result = runner.invoke(["parse"]) + assert result.success + + manifest = get_manifest(project.project_root) + assert "concept.test.simple_orders" in manifest.concepts + + concept = manifest.concepts["concept.test.simple_orders"] + assert len(concept.joins) == 0 + assert len(concept.columns) == 4 + + def test_base_only_cref_compilation(self, project): + """Test that cref with only base columns compiles without joins.""" + runner = dbtRunner() + result = runner.invoke(["compile"]) + assert result.success + + manifest = get_manifest(project.project_root) + compiled_node = manifest.nodes["model.test.base_only"] + + # Should only depend on base model + assert compiled_node.depends_on.nodes == ["model.test.stg_orders"] + + # Compiled SQL should not contain JOIN + compiled_sql = compiled_node.compiled_code + assert "JOIN" not in compiled_sql.upper() + + +class TestConceptErrors: + @pytest.fixture(scope="class") + def models(self): + return { + "invalid_concept_schema.yml": invalid_concept_yml, + } + + def test_invalid_concept_parsing(self, project): + """Test that invalid concept definitions raise parsing errors.""" + runner = dbtRunner() + result = runner.invoke(["parse"]) + assert not result.success + # Should fail because base_model is missing + assert isinstance(result.exception, (ParsingError, Exception)) + + +class TestMultiJoinConcepts: + @pytest.fixture(scope="class") + def models(self): + return { + "multi_join_schema.yml": multi_join_concept_yml, + "stg_orders.sql": stg_orders_sql, + "stg_customers.sql": stg_customers_sql, + "stg_products.sql": stg_products_sql, + "partial_join.sql": partial_join_model_sql, + } + + @pytest.fixture(scope="class") + def seeds(self): + return { + "raw_orders.csv": raw_orders_csv, + "raw_customers.csv": raw_customers_csv, + "raw_products.csv": raw_products_csv, + } + + def test_multi_join_concept_parsing(self, project): + """Test parsing concept with multiple joins.""" + runner = dbtRunner() + result = runner.invoke(["parse"]) + assert result.success + + manifest = get_manifest(project.project_root) + concept = manifest.concepts["concept.test.enriched_orders"] + + assert len(concept.joins) == 2 + join_names = [join.name for join in concept.joins] + assert "stg_customers" in join_names + assert "stg_products" in join_names + + def test_partial_join_compilation(self, project): + """Test that only needed joins are included in compilation.""" + runner = dbtRunner() + result = runner.invoke(["compile"]) + assert result.success + + manifest = get_manifest(project.project_root) + compiled_node = manifest.nodes["model.test.partial_join"] + + # Should depend on base and both joined models + # (conservative dependency tracking) + expected_deps = { + "model.test.stg_orders", + "model.test.stg_customers", + "model.test.stg_products", + } + assert set(compiled_node.depends_on.nodes) == expected_deps + + # Compiled SQL should contain both joins since we requested + # columns from both (customer_name and product_name) + compiled_sql = compiled_node.compiled_code + assert "LEFT JOIN" in compiled_sql.upper() + assert "stg_customers" in compiled_sql + assert "stg_products" in compiled_sql diff --git a/tests/unit/contracts/graph/test_manifest.py b/tests/unit/contracts/graph/test_manifest.py index da835cd5801..38421f719cf 100644 --- a/tests/unit/contracts/graph/test_manifest.py +++ b/tests/unit/contracts/graph/test_manifest.py @@ -99,6 +99,7 @@ "time_spine", "batch", "freshness", + "concepts", } ) diff --git a/tests/unit/parser/test_concept_parser.py b/tests/unit/parser/test_concept_parser.py new file mode 100644 index 00000000000..ca256c4d8a7 --- /dev/null +++ b/tests/unit/parser/test_concept_parser.py @@ -0,0 +1,203 @@ +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest + +from dbt.contracts.files import SchemaSourceFile +from dbt.contracts.graph.unparsed import ( + UnparsedConcept, + UnparsedConceptColumn, + UnparsedConceptJoin, +) +from dbt.exceptions import ParsingError +from dbt.parser.schema_yaml_readers import ConceptParser + + +class TestConceptParser: + @pytest.fixture + def mock_schema_parser(self): + """Mock schema parser for testing.""" + schema_parser = Mock() + schema_parser.manifest = Mock() + schema_parser.manifest.add_concept = Mock() + schema_parser.project = Mock() + schema_parser.project.project_name = "test_project" + schema_parser.get_fqn_prefix = Mock(return_value=["test_project"]) + return schema_parser + + @pytest.fixture + def mock_yaml_block(self): + """Mock YAML block for testing.""" + from dbt.contracts.files import SchemaSourceFile + + yaml_block = Mock() + yaml_block.path = Mock() + yaml_block.path.relative_path = "models/schema.yml" + # Mock the file to be a SchemaSourceFile instance + yaml_block.file = Mock(spec=SchemaSourceFile) + return yaml_block + + @pytest.fixture + def concept_parser(self, mock_schema_parser, mock_yaml_block): + """Create a ConceptParser instance for testing.""" + parser = ConceptParser(schema_parser=mock_schema_parser, yaml=mock_yaml_block) + return parser + + def test_parse_basic_concept(self, concept_parser, mock_schema_parser): + """Test parsing a basic concept definition.""" + # Create test concept data + concept_data = { + "name": "orders", + "description": "Orders concept", + "base_model": "stg_orders", + "primary_key": "order_id", + "columns": [ + {"name": "order_id", "description": "Primary key"}, + {"name": "customer_id", "description": "Foreign key"}, + ], + "joins": [ + { + "name": "stg_customers", + "base_key": "customer_id", + "foreign_key": "id", + "alias": "customer", + "columns": [{"name": "customer_name"}, {"name": "email"}], + } + ], + } + + # Create unparsed concept + unparsed = UnparsedConcept( + name=concept_data["name"], + description=concept_data["description"], + base_model=concept_data["base_model"], + primary_key=concept_data["primary_key"], + columns=[ + UnparsedConceptColumn(name=col["name"], description=col.get("description")) + for col in concept_data["columns"] + ], + joins=[ + UnparsedConceptJoin( + name=join["name"], + base_key=join["base_key"], + foreign_key=join["foreign_key"], + alias=join["alias"], + columns=[UnparsedConceptColumn(name=col["name"]) for col in join["columns"]], + ) + for join in concept_data["joins"] + ], + ) + + # Parse the concept + concept_parser.parse_concept(unparsed=unparsed) + + # The parse_concept method doesn't return the concept, it adds it to the manifest + # So we'll verify it was called correctly + mock_schema_parser.manifest.add_concept.assert_called_once() + + # Get the parsed concept from the call arguments + call_args = mock_schema_parser.manifest.add_concept.call_args[0] + parsed_concept = call_args[1] # Second argument is the concept + + # Verify the parsed concept + assert parsed_concept.name == "orders" + assert parsed_concept.description == "Orders concept" + assert parsed_concept.base_model == "stg_orders" + assert parsed_concept.primary_key == "order_id" + assert len(parsed_concept.columns) == 2 + assert len(parsed_concept.joins) == 1 + + # Verify the join + join = parsed_concept.joins[0] + assert join.name == "stg_customers" + assert join.base_key == "customer_id" + assert join.foreign_key == "id" + assert join.alias == "customer" + assert len(join.columns) == 2 + + def test_parse_concept_empty_base_model(self, concept_parser): + """Test that parsing works with empty base_model.""" + concept_data = { + "name": "invalid_concept", + "base_model": "", # Empty base model + "columns": [{"name": "id"}], + } + + unparsed = UnparsedConcept( + name=concept_data["name"], + base_model=concept_data["base_model"], + columns=[UnparsedConceptColumn(name="id")], + ) + + # This should parse successfully but with empty base_model + concept_parser.parse_concept(unparsed=unparsed) + + # Verify it was added to manifest + concept_parser.manifest.add_concept.assert_called_once() + + def test_parse_concept_with_no_joins(self, concept_parser, mock_schema_parser): + """Test parsing a concept with no joins.""" + concept_data = { + "name": "simple_orders", + "base_model": "stg_orders", + "primary_key": "order_id", + "columns": [{"name": "order_id"}, {"name": "status"}], + "joins": [], + } + + unparsed = UnparsedConcept( + name=concept_data["name"], + base_model=concept_data["base_model"], + primary_key=concept_data["primary_key"], + columns=[UnparsedConceptColumn(name=col["name"]) for col in concept_data["columns"]], + joins=[], + ) + + concept_parser.parse_concept(unparsed=unparsed) + + mock_schema_parser.manifest.add_concept.assert_called_once() + + # Get the parsed concept from the call arguments + call_args = mock_schema_parser.manifest.add_concept.call_args[0] + parsed_concept = call_args[1] # Second argument is the concept + + assert parsed_concept.name == "simple_orders" + assert len(parsed_concept.joins) == 0 + assert len(parsed_concept.columns) == 2 + + def test_parse_multiple_concepts(self, concept_parser, mock_schema_parser): + """Test parsing multiple concepts in one file.""" + concepts_data = [ + { + "name": "orders", + "base_model": "stg_orders", + "primary_key": "order_id", + "columns": [{"name": "order_id"}], + "joins": [], + }, + { + "name": "customers", + "base_model": "stg_customers", + "primary_key": "customer_id", + "columns": [{"name": "customer_id"}], + "joins": [], + }, + ] + + unparsed_concepts = [ + UnparsedConcept( + name=concept["name"], + base_model=concept["base_model"], + primary_key=concept["primary_key"], + columns=[UnparsedConceptColumn(name="order_id")], + joins=[], + ) + for concept in concepts_data + ] + + # Parse all concepts + for unparsed in unparsed_concepts: + concept_parser.parse_concept(unparsed=unparsed) + + # Should have called add_concept twice + assert mock_schema_parser.manifest.add_concept.call_count == 2 diff --git a/tests/unit/test_concept_implementation.py b/tests/unit/test_concept_implementation.py new file mode 100644 index 00000000000..13aef2fc479 --- /dev/null +++ b/tests/unit/test_concept_implementation.py @@ -0,0 +1,173 @@ +from unittest.mock import Mock + +import pytest + +from dbt.artifacts.resources.v1.concept import Concept, ConceptColumn, ConceptJoin +from dbt.context.providers import ParseConceptResolver, RuntimeConceptResolver +from dbt.contracts.graph.nodes import ParsedConcept +from dbt.contracts.graph.unparsed import ( + UnparsedConcept, + UnparsedConceptColumn, + UnparsedConceptJoin, +) + + +class TestConceptImplementation: + def test_concept_column_creation(self): + """Test that ConceptColumn can be created with basic attributes.""" + column = ConceptColumn(name="test_column", description="A test column") + assert column.name == "test_column" + assert column.description == "A test column" + + def test_concept_join_creation(self): + """Test that ConceptJoin can be created with join attributes.""" + join = ConceptJoin( + name="test_join", + base_key="id", + foreign_key="test_id", + alias="test_alias", + columns=[ConceptColumn(name="col1")], + ) + assert join.name == "test_join" + assert join.base_key == "id" + assert join.foreign_key == "test_id" + assert join.alias == "test_alias" + assert len(join.columns) == 1 + + def test_unparsed_concept_creation(self): + """Test that UnparsedConcept can be created.""" + unparsed = UnparsedConcept( + name="test_concept", base_model="base_table", primary_key="id", columns=[], joins=[] + ) + assert unparsed.name == "test_concept" + assert unparsed.base_model == "base_table" + assert unparsed.primary_key == "id" + + def test_concept_resolver_initialization(self): + """Test that concept resolvers can be initialized.""" + # Mock dependencies + mock_db_wrapper = Mock() + mock_model = Mock() + mock_config = Mock() + mock_manifest = Mock() + + # Add required attributes + mock_config.project_name = "test_project" + mock_db_wrapper.Relation = Mock() + + parse_resolver = ParseConceptResolver( + db_wrapper=mock_db_wrapper, + model=mock_model, + config=mock_config, + manifest=mock_manifest, + ) + + runtime_resolver = RuntimeConceptResolver( + db_wrapper=mock_db_wrapper, + model=mock_model, + config=mock_config, + manifest=mock_manifest, + ) + + assert parse_resolver.current_project == "test_project" + assert runtime_resolver.current_project == "test_project" + + def test_concept_available_columns_mapping(self): + """Test that RuntimeConceptResolver can map available columns.""" + # Mock dependencies + mock_db_wrapper = Mock() + mock_model = Mock() + mock_config = Mock() + mock_manifest = Mock() + + # Add required attributes + mock_config.project_name = "test_project" + mock_db_wrapper.Relation = Mock() + + resolver = RuntimeConceptResolver( + db_wrapper=mock_db_wrapper, + model=mock_model, + config=mock_config, + manifest=mock_manifest, + ) + + # Create a mock concept + concept = Mock() + concept.columns = [ConceptColumn(name="base_col1"), ConceptColumn(name="base_col2")] + concept.joins = [ + ConceptJoin( + name="join1", + base_key="id", + foreign_key="join_id", + alias="j1", + columns=[ConceptColumn(name="join_col1")], + ) + ] + + available_columns = resolver._get_available_columns(concept) + + # Should include base columns and join columns + assert "base_col1" in available_columns + assert "base_col2" in available_columns + assert "join_col1" in available_columns + + # Check column source mapping + assert available_columns["base_col1"]["source"] == "base" + assert available_columns["join_col1"]["source"] == "join" + + def test_determine_required_joins(self): + """Test that RuntimeConceptResolver can determine required joins.""" + # Mock dependencies + mock_db_wrapper = Mock() + mock_model = Mock() + mock_config = Mock() + mock_manifest = Mock() + + # Add required attributes + mock_config.project_name = "test_project" + mock_db_wrapper.Relation = Mock() + + resolver = RuntimeConceptResolver( + db_wrapper=mock_db_wrapper, + model=mock_model, + config=mock_config, + manifest=mock_manifest, + ) + + # Create a mock concept for testing + concept = Mock() + concept.columns = [ConceptColumn(name="base_col")] + concept.joins = [ + ConceptJoin( + name="join1", + alias="j1", + base_key="id", + foreign_key="join_id", + columns=[ConceptColumn(name="join_col")], + ), + ConceptJoin( + name="join2", + alias="j2", + base_key="id", + foreign_key="join_id", + columns=[ConceptColumn(name="other_join_col")], + ), + ] + + # Test with columns that require only one join + requested_columns = ["base_col", "join_col"] + required_joins = resolver._determine_required_joins(concept, requested_columns) + + # Should only include j1 join, not j2 + assert len(required_joins) == 1 + assert required_joins[0].alias == "j1" + + # Test with columns that require both joins + requested_columns = ["base_col", "join_col", "other_join_col"] + required_joins = resolver._determine_required_joins(concept, requested_columns) + + # Should include both joins + assert len(required_joins) == 2 + aliases = [join.alias for join in required_joins] + assert "j1" in aliases + assert "j2" in aliases diff --git a/tests/unit/test_node_types.py b/tests/unit/test_node_types.py index 87bbf51e3a1..df73b377574 100644 --- a/tests/unit/test_node_types.py +++ b/tests/unit/test_node_types.py @@ -21,6 +21,7 @@ NodeType.Unit: "unit_tests", NodeType.SavedQuery: "saved_queries", NodeType.Fixture: "fixtures", + NodeType.Concept: "concepts", }