Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions sqlmesh/core/model/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1706,6 +1706,7 @@ def render(

def render_seed(self) -> t.Iterator[QueryOrDF]:
import numpy as np
import pandas as pd

self._ensure_hydrated()

Expand Down Expand Up @@ -1746,8 +1747,6 @@ def render_seed(self) -> t.Iterator[QueryOrDF]:

# convert all date/time types to native pandas timestamp
for column in [*date_columns, *datetime_columns]:
import pandas as pd

df[column] = pd.to_datetime(df[column], infer_datetime_format=True, errors="ignore") # type: ignore

# extract datetime.date from pandas timestamp for DATE columns
Expand All @@ -1763,7 +1762,7 @@ def render_seed(self) -> t.Iterator[QueryOrDF]:
)

for column in bool_columns:
df[column] = df[column].apply(lambda i: str_to_bool(str(i)))
df[column] = df[column].apply(lambda i: None if pd.isna(i) else str_to_bool(str(i)))

df.loc[:, string_columns] = df[string_columns].mask(
cond=lambda x: x.notna(), # type: ignore
Expand Down
25 changes: 25 additions & 0 deletions tests/dbt/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,31 @@ def test_seed_single_whitespace_is_na(tmp_path):
assert df["col_b"].to_list() == [1, None]


def test_seed_boolean_nulls_are_preserved(tmp_path):
    """Check that `null` cells in a boolean-typed seed column render as None.

    A four-row CSV is written to ``tmp_path`` whose ``test_ind`` column is
    declared as ``boolean`` and holds ``null``, ``false``, ``true``, ``null``.
    The first rendered dataframe is expected to expose that column as
    ``[None, False, True, None]``.
    """
    seed_csv = tmp_path / "seed.csv"
    csv_lines = (
        "id,test_ind\n",
        "1,null\n",
        "2,false\n",
        "3,true\n",
        "4,null\n",
    )
    with open(seed_csv, "w", encoding="utf-8") as fd:
        fd.writelines(csv_lines)

    # Minimal dbt context: project name plus a DuckDB target.
    context = DbtContext()
    context.project_name = "foo"
    context.target = DuckDbConfig(name="target", schema="test")

    seed = SeedConfig(
        name="test_model",
        package="foo",
        path=Path(seed_csv),
        column_types={"test_ind": "boolean"},
    )
    sqlmesh_seed = seed.to_sqlmesh(context)

    # Only the first item yielded by render_seed() is inspected.
    rendered = sqlmesh_seed.render_seed()
    df = next(rendered)

    expected_flags = [None, False, True, None]
    assert df["test_ind"].to_list() == expected_flags


def test_seed_partial_column_inference(tmp_path):
seed_csv = tmp_path / "seed.csv"
with open(seed_csv, "w", encoding="utf-8") as fd:
Expand Down
Loading