Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions aligned/redis/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ async def to_lazy_polars(self) -> pl.LazyFrame:
[
(
pl.concat_str(
[pl.lit(request.location.identifier), pl.lit(":")]
+ [pl.col(col) for col in sorted(request.entity_names)]
[pl.lit(request.location.identifier)]
+ [pl.col(col) for col in sorted(request.entity_names)],
separator=":"
)
).alias(redis_combine_id),
pl.col(list(request.entity_names)),
Expand All @@ -67,10 +68,23 @@ async def to_lazy_polars(self) -> pl.LazyFrame:

features = list(feature.name for feature in needed_features)

import snappy

async with redis.pipeline(transaction=False) as pipe:
for entity in entities[redis_combine_id]:
pipe.hmget(entity, keys=features)
result = await pipe.execute()

result_bytes = await pipe.execute()

result = []
for row in result_bytes:
result.append(
[
snappy.uncompress(val).decode()
if val else None
for val in row
]
)

reqs: pl.DataFrame = pl.concat(
[entities, pl.DataFrame(result, schema=features, orient="row")],
Expand Down
22 changes: 13 additions & 9 deletions aligned/redis/tests/test_redis_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from aligned.schemas.date_formatter import DateFormatter
from aligned.schemas.feature import Feature, FeatureLocation, FeatureType
from aligned.sources.redis import RedisConfig, RedisSource
import snappy


@pytest.fixture
Expand All @@ -30,7 +31,8 @@ def retrieval_request() -> RetrievalRequest:

@pytest.mark.asyncio
async def test_factual_redis_job(mocker, retrieval_request) -> None: # type: ignore[no-untyped-def]
values = ["20", "44"]
raw_values = ["20", "44"]
values = [[snappy.compress(val)] for val in raw_values]

redis_mock = mocker.patch.object(Pipeline, "execute", return_value=values)

Expand All @@ -47,13 +49,14 @@ async def test_factual_redis_job(mocker, retrieval_request) -> None: # type: ig

result = await job.to_pandas()
redis_mock.assert_called_once()
x_result = [int(value) for value in values] + [0, 0]
x_result = [int(v) for v in raw_values] + [0, 0]
assert np.all(result["x"].fillna(0).values == x_result), f"Got {result}" # type: ignore


@pytest.mark.asyncio
async def test_factual_redis_job_int_as_str(mocker, retrieval_request) -> None: # type: ignore[no-untyped-def]
values = ["20", "44"]
raw_values = ["20", "44"]
values = [[snappy.compress(val)] for val in raw_values]

redis_mock = mocker.patch.object(Pipeline, "execute", return_value=values)

Expand All @@ -71,13 +74,13 @@ async def test_factual_redis_job_int_as_str(mocker, retrieval_request) -> None:

result = await job.to_pandas()
redis_mock.assert_called_once()
x_result = [int(value) for value in values] + [0, 0]
x_result = [int(v) for v in raw_values] + [0, 0]
assert np.all(result["x"].fillna(0).values == x_result) # type: ignore


@pytest.mark.asyncio
async def test_nan_entities_job(mocker, retrieval_request) -> None: # type: ignore[no-untyped-def]
values = ["20", "44"]
values = [snappy.compress(val) for val in ["20", "44"]]

redis_mock = mocker.patch.object(Pipeline, "execute", return_value=values)

Expand All @@ -98,7 +101,7 @@ async def test_nan_entities_job(mocker, retrieval_request) -> None: # type: ign

@pytest.mark.asyncio
async def test_no_entities_job(mocker, retrieval_request) -> None: # type: ignore[no-untyped-def]
values = ["20", "44"]
values = [snappy.compress(val) for val in ["20", "44"]]

redis_mock = mocker.patch.object(Pipeline, "execute", return_value=values)

Expand Down Expand Up @@ -130,7 +133,8 @@ async def test_factual_redis_job_int_entity(mocker) -> None: # type: ignore[no-
event_timestamp=None,
)

values = ["20", "44", "55"]
raw_values = ["20", "44", "55"]
values = [[snappy.compress(val)] for val in raw_values]

redis_mock = mocker.patch.object(Pipeline, "execute", return_value=values)

Expand All @@ -147,15 +151,15 @@ async def test_factual_redis_job_int_entity(mocker) -> None: # type: ignore[no-

result = await job.to_pandas()
redis_mock.assert_called_once()
x_result = [int(value) for value in values] + [0]
x_result = [int(v) for v in raw_values] + [0]
assert np.all(result["x"].fillna(0).values == x_result) # type: ignore


@pytest.mark.asyncio
async def test_write_job(mocker, retrieval_request: RetrievalRequest) -> None: # type: ignore[no-untyped-def]
import fakeredis.aioredis

redis = fakeredis.aioredis.FakeRedis(decode_responses=True)
redis = fakeredis.aioredis.FakeRedis(decode_responses=False)

_ = mocker.patch.object(RedisConfig, "redis", return_value=redis)

Expand Down
83 changes: 69 additions & 14 deletions aligned/sources/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ class DatabricksConnectionConfig:
azure_client_secret: ConfigValue | None = None
azure_tenant_id: ConfigValue | None = None

spark_config: dict[str, str] | None = None

def __init__(
self,
host: str | ConfigValue,
Expand All @@ -200,13 +202,15 @@ def __init__(
azure_client_id: ConfigValue | None = None,
azure_client_secret: ConfigValue | None = None,
azure_tenant_id: ConfigValue | None = None,
spark_config: dict[str, str] | None = None,
) -> None:
self.host = LiteralValue.from_value(host)
self.cluster_id = LiteralValue.from_value(cluster_id) if cluster_id else None
self.token = LiteralValue.from_value(token) if token else None
self.azure_client_secret = azure_client_secret
self.azure_client_id = azure_client_id
self.azure_tenant_id = azure_tenant_id
self.spark_config = spark_config

def storage_provider(self) -> pl.CredentialProvider | None:
if self.azure_client_id and self.azure_client_secret and self.azure_tenant_id:
Expand All @@ -221,11 +225,24 @@ def storage_provider(self) -> pl.CredentialProvider | None:
return pl.CredentialProviderAzure(credential=creds)
return None

def with_spark_config(self, spark_config: dict[str, str]) -> DatabricksConnectionConfig:
merged = dict(self.spark_config or {})
merged.update(spark_config)
return DatabricksConnectionConfig(
host=self.host,
cluster_id=self.cluster_id,
token=self.token,
azure_client_id=self.azure_client_id,
azure_client_secret=self.azure_client_secret,
azure_tenant_id=self.azure_tenant_id,
spark_config=merged,
)

def with_auth(
self, token: str | ConfigValue, host: str | ConfigValue
) -> DatabricksConnectionConfig:
return DatabricksConnectionConfig(
cluster_id=self.cluster_id, token=token, host=host
cluster_id=self.cluster_id, token=token, host=host, spark_config=self.spark_config
)

@staticmethod
Expand Down Expand Up @@ -257,7 +274,7 @@ def with_cluster_id(
def catalog(self, catalog: str | ConfigValue) -> UnityCatalog:
return UnityCatalog(self, LiteralValue.from_value(catalog))

def connection(self) -> SparkSession:
def connection(self, spark_config: dict[str, str] | None = None) -> SparkSession:
from pyspark.errors import PySparkException

cluster_id = self.cluster_id
Expand Down Expand Up @@ -285,6 +302,13 @@ def connection(self) -> SparkSession:
if self.token:
builder = builder.token(self.token.read())

effective_config = dict(self.spark_config or {})
if spark_config:
effective_config.update(spark_config)

for key, value in effective_config.items():
builder = builder.config(key, value)

if cluster_id_value == "serverless":
spark = builder.getOrCreate()
try:
Expand Down Expand Up @@ -551,9 +575,22 @@ class UCFeatureTableSource(
):
config: DatabricksConnectionConfig
table: UnityCatalogTableConfig
spark_config: dict[str, str] | None = None

type_name = "uc_feature_table"

def _configured_connection(self) -> SparkSession:
return self.config.connection(spark_config=self.spark_config)

def with_spark_config(self, spark_config: dict[str, str]) -> UCFeatureTableSource:
merged = dict(self.spark_config or {})
merged.update(spark_config)
return UCFeatureTableSource(
config=self.config,
table=self.table,
spark_config=merged,
)

def job_group_key(self) -> str:
return "uc_feature_table"

Expand Down Expand Up @@ -661,7 +698,7 @@ async def freshness(self, feature: Feature) -> datetime | None:
.freshness()
)
"""
spark = self.config.connection()
spark = self._configured_connection()
return (
spark.sql(
f"SELECT MAX({feature.name}) as {feature.name} FROM {self.table.identifier()}"
Expand All @@ -684,15 +721,15 @@ async def overwrite(
) -> None:
client = databricks_fe.FeatureEngineeringClient()

conn = self.config.connection()
conn = self._configured_connection()
df = conn.createDataFrame(await job.unique_entities().to_pandas())

client.create_table(
name=self.table.identifier(), primary_keys=list(request.entity_names), df=df
)

def with_config(self, config: DatabricksConnectionConfig) -> UCFeatureTableSource:
return UCFeatureTableSource(config, self.table)
return UCFeatureTableSource(config, self.table, spark_config=self.spark_config)


def features_to_read(
Expand Down Expand Up @@ -801,6 +838,7 @@ class UnityCatalogTableAllJob(RetrievalJob, DatabricksSource):
_limit: int | None
where: Expression | None = field(default=None)
renamer: Renamer | None = field(default=None)
spark_config: dict[str, str] | None = field(default=None)

@property
def request_result(self) -> RequestResult:
Expand Down Expand Up @@ -830,7 +868,7 @@ async def to_pandas(self) -> pd.DataFrame:
return (await self.to_spark()).toPandas()

async def to_spark(self, session: SparkSession | None = None) -> SparkFrame:
con = session or self.config.connection()
con = session or self.config.connection(spark_config=self.spark_config)
spark_df = con.read.table(self.table.identifier())

if self.request.features_to_include:
Expand Down Expand Up @@ -914,9 +952,24 @@ class UCTableSource(CodableBatchDataSource, WritableFeatureSource, DatabricksSou
table: UnityCatalogTableConfig
should_overwrite_schema: bool = False
renamer: Renamer | None = None
spark_config: dict[str, str] | None = None

type_name = "uc_table"

def _configured_connection(self) -> SparkSession:
return self.config.connection(spark_config=self.spark_config)

def with_spark_config(self, spark_config: dict[str, str]) -> UCTableSource:
merged = dict(self.spark_config or {})
merged.update(spark_config)
return UCTableSource(
config=self.config,
table=self.table,
should_overwrite_schema=self.should_overwrite_schema,
renamer=self.renamer,
spark_config=merged,
)

def job_group_key(self) -> str:
return f"uc_table-{self.table.identifier()}"

Expand All @@ -926,19 +979,21 @@ def overwrite_schema(self, should_overwrite_schema: bool = True) -> UCTableSourc
table=self.table,
should_overwrite_schema=should_overwrite_schema,
renamer=self.renamer,
spark_config=self.spark_config,
)

def with_renames(self, renames: dict[str, str] | Renamer | None) -> UCTableSource:
if isinstance(renames, dict):
renames = Renamer.noop(renames)

return UCTableSource(
self.config, self.table, self.should_overwrite_schema, renames
self.config, self.table, self.should_overwrite_schema, renames, self.spark_config
)

def all_data(self, request: RetrievalRequest, limit: int | None) -> RetrievalJob:
return UnityCatalogTableAllJob(
self.config, self.table, request, limit, renamer=self.renamer
self.config, self.table, request, limit, renamer=self.renamer,
spark_config=self.spark_config,
)

def all_between_dates(
Expand Down Expand Up @@ -995,7 +1050,7 @@ async def schema(self) -> dict[str, FeatureType]:
Returns:
dict[str, FeatureType]: A dictionary containing the column name and the feature type
"""
spark = self.config.connection()
spark = self._configured_connection()
schema = spark.table(self.table.identifier()).schema

aligned_schema: dict[str, FeatureType] = {}
Expand All @@ -1011,7 +1066,7 @@ async def freshness(self, feature: Feature) -> datetime | None:
.freshness()
)
"""
spark = self.config.connection()
spark = self._configured_connection()
return (
spark.sql(
f"SELECT MAX({feature.name}) as {feature.name} FROM {self.table.identifier()}"
Expand All @@ -1025,7 +1080,7 @@ async def insert(self, job: RetrievalJob, request: RetrievalRequest) -> None:

expected_schema = request.spark_schema()

conn = self.config.connection()
conn = self._configured_connection()
spark_df = await job.to_spark(conn)

df = spark_df.select(
Expand All @@ -1046,7 +1101,7 @@ async def upsert(self, job: RetrievalJob, request: RetrievalRequest) -> None:

expected_schema = request.spark_schema()

conn = self.config.connection()
conn = self._configured_connection()
spark_df = await job.to_spark(conn)

df = spark_df.select(
Expand Down Expand Up @@ -1088,7 +1143,7 @@ async def overwrite(

expected_schema = request.spark_schema()

conn = self.config.connection()
conn = self._configured_connection()
spark_df = await job.to_spark(conn)

df = spark_df.select(
Expand All @@ -1103,7 +1158,7 @@ async def overwrite(
).saveAsTable(self.table.identifier())

def with_config(self, config: DatabricksConnectionConfig) -> UCTableSource:
return UCTableSource(config, self.table)
return UCTableSource(config, self.table, spark_config=self.spark_config)

async def feature_view_code(self, view_name: str) -> str:
from aligned.sources.renamer import snake_to_pascal
Expand Down
2 changes: 1 addition & 1 deletion aligned/sources/iceberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

@dataclass
class IcebergCatalog(Codable):
name: ConfigValue = field(default=ConfigValue.from_value("default"))
name: ConfigValue = field(default_factory=lambda: ConfigValue.from_value("default"))
config: dict[str, ConfigValue] = field(
default_factory=lambda: {
"type": ConfigValue.from_value("in-memory"),
Expand Down
6 changes: 4 additions & 2 deletions aligned/sources/lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ async def insert(self, job: "RetrievalJob", request: RetrievalRequest) -> None:
if df.is_empty():
return

arrow_table = df.to_arrow()
columns = list(request.all_returned_columns)
arrow_table = df.select(columns).to_arrow()
await table.add(arrow_table)

async def overwrite(self, job: "RetrievalJob", request: RetrievalRequest) -> None:
Expand Down Expand Up @@ -193,7 +194,8 @@ def first_embedding(features: set[Feature]) -> Feature | None:
polars_df = polars_df.select(pl.exclude("_distance"))
if df_cols > 1:
logger.info(f"Stacking {polars_df.columns} and {item.keys()}")
polars_df = polars_df.select(pl.exclude(org_columns)).hstack(
cols_to_exclude = [c for c in org_columns if c in polars_df.columns]
polars_df = polars_df.select(pl.exclude(cols_to_exclude)).hstack(
pl.DataFrame([item] * polars_df.height)
.select(org_columns)
.select(pl.exclude(embedding.name))
Expand Down
Loading
Loading