diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index cc971695..81e77296 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -11,94 +11,40 @@ env: UV_VERSION: 0.5.21 jobs: - unittests_py_less_3_11: - name: Unit Tests Python=${{ matrix.python-version }} Pyspark=${{ matrix.pyspark-version }} Tensorflow=${{ matrix.tensorflow-version }} + unittests: + name: Unit Tests Python=${{ matrix.python-version }} Pyspark=${{ matrix.pyspark-version }} Keras=${{ matrix.keras-version }} runs-on: [ ubuntu-latest ] strategy: matrix: - # We match the last 2 Databricks LTS Runtime versions for pyspark - # and 3 Tensorflow versions within our package range that are not compatible with Python 3.11 - python-version: ["3.9", "3.10"] + python-version: ["3.10", "3.11", "3.12"] pyspark-version: ["3.4.1", "3.5.0"] - tensorflow-version: ["2.9.1", "2.10.1", "2.11.1"] + keras-version: ["3.3.0", "3.7.0", "3.10.0", "3.12.0"] steps: - name: Setup Local Repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install uv run: | pip install --upgrade pip pip install "uv==$UV_VERSION" - - name: Fix Python Pyspark & Tensorflow Versions + - name: Install project and pin matrix versions run: | uv venv --python ${{ matrix.python-version }} - uv add pyspark==${{ matrix.pyspark-version }} - uv add tensorflow==${{ matrix.tensorflow-version }} + uv pip install -e ".[tensorflow]" + uv pip install pyspark==${{ matrix.pyspark-version }} keras==${{ matrix.keras-version }} - name: Run tests run: uv run -p ${{ matrix.python-version }} python -m pytest -n auto . 
- unittests_py_3_11: - name: Unit Tests Python=3.11 Pyspark=${{ matrix.pyspark-version }} Tensorflow=${{ matrix.tensorflow-version }} - runs-on: [ ubuntu-latest ] - strategy: - matrix: - # Only certain versions of pyspark and tensorflow are compatible with Python 3.11 - pyspark-version: ["3.4.1", "3.5.0"] - tensorflow-version: ["2.12.1", "2.13.1", "2.14.1", "2.15.1", "2.16.2", "2.17.1", "2.18.0"] - steps: - - name: Setup Local Repo - uses: actions/checkout@v3 - - name: Setup Python - uses: actions/setup-python@v3 - with: - python-version: "3.11" - - name: Install uv - run: | - pip install --upgrade pip - pip install "uv==$UV_VERSION" - - name: Fix Python Pyspark & Tensorflow Versions - run: | - uv venv --python 3.11 - uv add pyspark==${{ matrix.pyspark-version }} - uv add "tensorflow==${{ matrix.tensorflow-version }}; python_version >='3.9' and python_version <'3.12'" - - name: Run tests - run: uv run -p 3.11 python -m pytest -n auto . - unittests_py_3_12: - name: Unit Tests Python=3.12 Pyspark=${{ matrix.pyspark-version }} Tensorflow=${{ matrix.tensorflow-version }} - runs-on: [ ubuntu-latest ] - strategy: - matrix: - # Only certain versions of pyspark and tensorflow are compatible with Python 3.12 - pyspark-version: [ "3.4.1", "3.5.0" ] - tensorflow-version: [ "2.16.2", "2.17.1", "2.18.0" ] - steps: - - name: Setup Local Repo - uses: actions/checkout@v3 - - name: Setup Python - uses: actions/setup-python@v3 - with: - python-version: "3.12" - - name: Install uv - run: | - pip install --upgrade pip - pip install "uv==$UV_VERSION" - - name: Fix Python Pyspark & Tensorflow Versions - run: | - uv venv --python 3.12 - uv add pyspark==${{ matrix.pyspark-version }} - uv add "tensorflow==${{ matrix.tensorflow-version }}; python_version >='3.9' and python_version <='3.12'" - - name: Run tests - run: uv run -p 3.12 python -m pytest -n auto . 
formatting: name: Formatting Checks runs-on: [ ubuntu-latest ] steps: - name: Setup Local Repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install uv @@ -112,9 +58,9 @@ jobs: runs-on: [ ubuntu-latest ] steps: - name: Setup Local Repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install uv @@ -128,9 +74,9 @@ jobs: runs-on: [ ubuntu-latest ] steps: - name: Setup Local Repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install Pre-commit @@ -145,9 +91,9 @@ jobs: runs-on: [ ubuntu-latest ] steps: - name: Setup Local Repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install uv diff --git a/README.md b/README.md index 761d453c..c212053b 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,11 @@ [![CI](https://github.com/ExpediaGroup/kamae/actions/workflows/ci.yaml/badge.svg)](https://github.com/ExpediaGroup/kamae/actions/workflows/ci.yaml) ![PyPI - Version](https://img.shields.io/pypi/v/kamae) -Kamae bridges the gap between offline data processing and online model serving. Build preprocessing pipelines in [Spark](https://spark.apache.org/) for big data workloads, then export them as [Keras](https://keras.io/) models for low-latency inference. +Kamae bridges the gap between offline data processing and online model serving. Build preprocessing pipelines in [Spark](https://spark.apache.org/) for big data workloads, then export them as [Keras 3](https://keras.io/) models for low-latency inference. 
**Multi-backend support** allows numeric operations to run on TensorFlow, JAX, or PyTorch backends, while string and datetime operations require TensorFlow. ## Why Kamae? -Training and serving often happen on different platforms. Spark for batch processing at scale, TensorFlow for low-latency inference. Manually reimplementing preprocessing logic in both places creates: +Training and serving often happen on different platforms: Spark for batch processing at scale, Keras for low-latency inference. Manually reimplementing preprocessing logic in both places creates: - **Training/serving skew**: Subtle bugs from inconsistent implementations - **Development overhead**: Writing and maintaining duplicate code - **Deployment friction**: Changes require updates in multiple systems @@ -19,8 +19,6 @@ Kamae solves this by generating the inference model directly from your Spark pip pip install kamae ``` -**Platform notes**: Kamae supports `tensorflow>=2.9.1,<2.19.0`. For Mac ARM with `tensorflow<2.13.0`, install `tensorflow-macos` manually. TensorFlow no longer supports Mac x86_64 from version 2.18.0 onwards. - ## Quick Start ```python @@ -62,7 +60,26 @@ keras_model.save("./preprocessing_model.keras") **Direct Keras Layers**: Import and compose Keras layers directly for non-tabular data or custom workflows. Browse available layers in the [transformation table](#supported-preprocessing-layers) below. -For Scikit-learn support (experimental, unmaintained), see [sklearn examples](examples/sklearn). +**Backend Selection**: Set the `KERAS_BACKEND` environment variable before importing Keras: +```python +import os +os.environ['KERAS_BACKEND'] = 'tensorflow' # or 'jax' or 'torch' +``` + +**Multi-backend layers** (numeric operations) work on all backends. **TensorFlow-only layers** (string/datetime operations) require the TensorFlow backend. 
See the [Backend column](#supported-preprocessing-layers) in the transformation table below, or use the discovery API: + +```python +import kamae +# Get layers/transformers compatible with current backend +layers = kamae.get_compatible_layers() +transformers = kamae.get_compatible_transformers() + +# Get layers/transformers compatible with specific backend +jax_layers = kamae.get_compatible_layers('jax') +torch_transformers = kamae.get_compatible_transformers('torch') +``` + +**Note:** TensorFlow is a required dependency for Kamae, as the package includes TensorFlow-only layers. JAX and PyTorch backends provide an alternative execution path for numeric operations only. ## Documentation @@ -72,83 +89,82 @@ For Scikit-learn support (experimental, unmaintained), see [sklearn examples](ex - **[Shape parity](docs/achieving_shape_parity.md)**: Ensuring consistent shapes between Spark and Keras - **[Testing inference](docs/testing_inference.md)**: Validate model outputs with TensorFlow Serving - **[Adding transformers](docs/adding_transformer.md)**: Contributing new transformations +- **[Keras 3 Migration](docs/keras3_migration.md)**: Migrating to Keras 3 multi-backend (Kamae >3.0.0) ## Supported Preprocessing Layers -| Transformation | Description | Keras Layer | Spark Transformer | Scikit-learn Transformer | -|:-------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------:|:-------------------------------------------------------------------------:|:-----------------------------------------------------------:| -| AbsoluteValue | Applies the `abs(x)` transform. 
| [Link](src/kamae/tensorflow/layers/absolute_value.py) | [Link](src/kamae/spark/transformers/absolute_value.py) | Not yet implemented | -| ArrayConcatenate | Assembles multiple features into a single array. | [Link](src/kamae/tensorflow/layers/array_concatenate.py) | [Link](src/kamae/spark/transformers/array_concatenate.py) | [Link](src/kamae/sklearn/transformers/array_concatenate.py) | -| ArrayCrop | Crops or pads a feature array to a consistent size. | [Link](src/kamae/tensorflow/layers/array_crop.py) | [Link](src/kamae/spark/transformers/array_crop.py) | Not yet implemented | -| ArrayReduceMax | Reduces the last dimension of a tensor by taking the maximum. | [Link](src/kamae/tensorflow/layers/array_reduce_max.py) | [Link](src/kamae/spark/transformers/array_reduce_max.py) | Not yet implemented | -| ArraySplit | Splits a feature array into multiple features. | [Link](src/kamae/tensorflow/layers/array_split.py) | [Link](src/kamae/spark/transformers/array_split.py) | [Link](src/kamae/sklearn/transformers/array_split.py) | -| ArraySubtractMinimum | Subtracts the minimum element in an array from therest to compute a timestamp difference. Ignores padded values. | [Link](src/kamae/tensorflow/layers/array_subtract_minimum.py) | [Link](src/kamae/spark/transformers/array_subtract_minimum.py) | Not yet implemented | -| BearingAngle | Compute the bearing angle (https://en.wikipedia.org/wiki/Bearing_(navigation)) between two pairs of lat/long. | [Link](src/kamae/tensorflow/layers/bearing_angle.py) | [Link](src/kamae/spark/transformers/bearing_angle.py) | Not yet implemented | -| Bin | Bins a numerical column into string categorical bins. Users can specify the bin values, labels and a default label. | [Link](src/kamae/tensorflow/layers/bin.py) | [Link](src/kamae/spark/transformers/bin.py) | Not yet implemented | -| BloomEncode | Hash encodes a string feature multiple times to create an array of indices. Useful for compressing input dimensions for embeddings. 
Paper: https://arxiv.org/pdf/1706.03993.pdf | [Link](src/kamae/tensorflow/layers/bloom_encode.py) | [Link](src/kamae/spark/transformers/bloom_encode.py) | Not yet implemented | -| Bucketize | Buckets a numerical column into integer bins. | [Link](src/kamae/tensorflow/layers/bucketize.py) | [Link](src/kamae/spark/transformers/bucketize.py) | Not yet implemented | -| ConditionalStandardScale | Normalises by the mean and standard deviation, with ability to: apply a mask on another column, not scale the zeros, and apply a non standard scaling function. | [Link](src/kamae/tensorflow/layers/conditional_standard_scale.py) | [Link](src/kamae/spark/estimators/conditional_standard_scale.py) | Not yet implemented | -| CosineSimilarity | Computes the cosine similarity between two array features. | [Link](src/kamae/tensorflow/layers/cosine_similarity.py) | [Link](src/kamae/spark/transformers/cosine_similarity.py) | Not yet implemented | -| CurrentDate | Returns the current date for use in other transformers. | [Link](src/kamae/tensorflow/layers/current_date.py) | [Link](src/kamae/spark/transformers/current_date.py) | Not yet implemented | -| CurrentDateTime | Returns the current date time in the format yyyy-MM-dd HH:mm:ss.SSS for use in other transformers. | [Link](src/kamae/tensorflow/layers/current_date_time.py) | [Link](src/kamae/spark/transformers/current_date_time.py) | Not yet implemented | -| CurrentUnixTimestamp | Returns the current unix timestamp in either seconds or milliseconds for use in other transformers. | [Link](src/kamae/tensorflow/layers/current_unix_timestamp.py) | [Link](src/kamae/spark/transformers/current_unix_timestamp.py) | Not yet implemented | -| DateAdd | Adds a static or dynamic number of days to a date feature. NOTE: Destroys any time component of the datetime if present. 
| [Link](src/kamae/tensorflow/layers/date_add.py) | [Link](src/kamae/spark/transformers/date_add.py) | Not yet implemented | -| DateDiff | Computes the number of days between two date features. | [Link](src/kamae/tensorflow/layers/date_diff.py) | [Link](src/kamae/spark/transformers/date_diff.py) | Not yet implemented | -| DateParse | Parses a string date of format YYYY-MM-DD to extract a given date part. E.g. day of year. | [Link](src/kamae/tensorflow/layers/date_parse.py) | [Link](src/kamae/spark/transformers/date_parse.py) | Not yet implemented | -| DateTimeToUnixTimestamp | Converts a UTC datetime string to unix timestamp. | [Link](src/kamae/tensorflow/layers/date_time_to_unix_timestamp.py) | [Link](src/kamae/spark/transformers/date_time_to_unix_timestamp.py) | Not yet implemented | -| Divide | Divides a single feature by a constant or divides multiple features against each other. | [Link](src/kamae/tensorflow/layers/divide.py) | [Link](src/kamae/spark/transformers/divide.py) | Not yet implemented | -| Exp | Applies the exp(x) operation to the feature. | [Link](src/kamae/tensorflow/layers/exp.py) | [Link](src/kamae/spark/transformers/exp.py) | Not yet implemented | -| Exponent | Applies the x^exponent to a single feature or x^y for multiple features. | [Link](src/kamae/tensorflow/layers/exponent.py) | [Link](src/kamae/spark/transformers/exponent.py) | Not yet implemented | -| HashIndex | Transforms strings to indices via a hash table of predeterminded size. | [Link](src/kamae/tensorflow/layers/hash_index.py) | [Link](src/kamae/spark/transformers/hash_index.py) | Not yet implemented | -| HaversineDistance | Computes the [haversine distance](https://en.wikipedia.org/wiki/Haversine_formula) between latitude and longitude pairs. | [Link](src/kamae/tensorflow/layers/haversine_distance.py) | [Link](src/kamae/spark/transformers/haversine_distance.py) | Not yet implemented | -| Identity | Applies the identity operation, leaving the input the same. 
| [Link](src/kamae/tensorflow/layers/identity.py) | [Link](src/kamae/spark/transformers/identity.py) | [Link](src/kamae/sklearn/transformers/identity.py) | -| IfStatement | Computes a simple if statement on a set of columns/tensors and/or constants. | [Link](src/kamae/tensorflow/layers/if_statement.py) | [Link](src/kamae/spark/transformers/if_statement.py) | Not yet implemented | -| Impute | Performs imputation of either mean or median value of the data over a specified mask. | [Link](src/kamae/tensorflow/layers/impute.py) | [Link](src/kamae/spark/transformers/impute.py) | Not yet implemented | -| LambdaFunction | Transforms an input (or multiple inputs) to an output (or multiple outputs) with a user provided tensorflow function. | [Link](src/kamae/tensorflow/layers/lambda_function.py) | [Link](src/kamae/spark/transformers/lambda_function.py) | Not yet implemented | -| ListMax | Computes the listwise max of a feature, optionally calculated only on the top items based on another given feature. | [Link](src/kamae/tensorflow/layers/list_max.py) | [Link](src/kamae/spark/transformers/list_max.py) | Not yet implemented | -| ListMean | Computes the listwise mean of a feature, optionally calculated only on the top items based on another given feature. | [Link](src/kamae/tensorflow/layers/list_mean.py) | [Link](src/kamae/spark/transformers/list_mean.py) | Not yet implemented | -| ListMedian | Computes the listwise median of a feature, optionally calculated only on the top items based on another given feature. | [Link](src/kamae/tensorflow/layers/list_median.py) | [Link](src/kamae/spark/transformers/list_median.py) | Not yet implemented | -| ListMin | Computes the listwise min of a feature, optionally calculated only on the top items based on another given feature. | [Link](src/kamae/tensorflow/layers/list_min.py) | [Link](src/kamae/spark/transformers/list_min.py) | Not yet implemented | -| ListRank | Computes the listwise rank (ordering) of a feature. 
| [Link](src/kamae/tensorflow/layers/list_rank.py) | [Link](src/kamae/spark/transformers/list_rank.py) | Not yet implemented | -| ListStdDev | Computes the listwise standard deviation of a feature, optionally calculated only on the top items based on another given feature. | [Link](src/kamae/tensorflow/layers/list_std_dev.py) | [Link](src/kamae/spark/transformers/list_std_dev.py) | Not yet implemented | -| Log | Applies the natural logarithm `log(alpha + x)` transform . | [Link](src/kamae/tensorflow/layers/log.py) | [Link](src/kamae/spark/transformers/log.py) | [Link](src/kamae/sklearn/transformers/log.py) | -| LogicalAnd | Performs an and(x, y) operation on multiple boolean features. | [Link](src/kamae/tensorflow/layers/logical_and.py) | [Link](src/kamae/spark/transformers/logical_and.py) | Not yet implemented | -| LogicalNot | Performs a not(x) operation on a single boolean feature. | [Link](src/kamae/tensorflow/layers/logical_not.py) | [Link](src/kamae/spark/transformers/logical_not.py) | Not yet implemented | -| LogicalOr | Performs an or(x, y) operation on multiple boolean features. | [Link](src/kamae/tensorflow/layers/logical_or.py) | [Link](src/kamae/spark/transformers/logical_or.py) | Not yet implemented | -| Max | Computes the maximum of a feature with a constant or multiple other features. | [Link](src/kamae/tensorflow/layers/max.py) | [Link](src/kamae/spark/transformers/max.py) | Not yet implemented | -| Mean | Computes the mean of a feature with a constant or multiple other features. | [Link](src/kamae/tensorflow/layers/mean.py) | [Link](src/kamae/spark/transformers/mean.py) | Not yet implemented | -| Min | Computes the minimum of a feature with a constant or multiple other features. | [Link](src/kamae/tensorflow/layers/min.py) | [Link](src/kamae/spark/transformers/min.py) | Not yet implemented | -| MinHashIndex | Creates an integer bit array from a set of strings using the [MinHash algorithm](https://en.wikipedia.org/wiki/MinHash). 
| [Link](src/kamae/tensorflow/layers/min_hash_index.py) | [Link](src/kamae/spark/transformers/min_hash_index.py) | Not yet implemented | -| MinMaxScale | Scales the input feature by the min/max resulting in a feature in [0, 1]. | [Link](src/kamae/tensorflow/layers/min_max_scale.py) | [Link](src/kamae/spark/transformers/min_max_scale.py) | Not yet implemented | -| Modulo | Computes the modulo of a feature with the mod divisor being a constant or another feature. | [Link](src/kamae/tensorflow/layers/modulo.py) | [Link](src/kamae/spark/transformers/modulo.py) | Not yet implemented | -| Multiply | Multiplies a single feature by a constant or multiples multiple features together. | [Link](src/kamae/tensorflow/layers/multiply.py) | [Link](src/kamae/spark/transformers/multiply.py) | Not yet implemented | -| NumericalIfStatement | Performs a simple if else statement witha given operator. Value to check, result if true or false can be constants or features. | [Link](src/kamae/tensorflow/layers/numerical_if_statement.py) | [Link](src/kamae/spark/transformers/numerical_if_statement.py) | Not yet implemented | -| OneHotEncode | Transforms a string to a one-hot array. | [Link](src/kamae/tensorflow/layers/one_hot_encode.py) | [Link](src/kamae/spark/estimators/one_hot_encode.py) | Not yet implemented | -| OrdinalArrayEncode | Encodes strings in an array according to the order in which they appear. Only for 2D tensors. | [Link](src/kamae/tensorflow/layers/ordinal_array_encoder.py) | [Link](src/kamae/spark/estimators/ordinal_array_encoder.py) | Not yet implemented | -| PairwiseCosineSimilarity | Computes the cosine similarity between an embedding and a list of candidate embeddings. | [Link](src/kamae/tensorflow/layers/pairwise_cosine_similarity.py) | [Link](src/kamae/spark/transformers/pairwise_cosine_similarity.py) | Not yet implemented | -| Round | Rounds a floating feature to the nearest integer using `ceil`, `floor` or a standard `round` op. 
| [Link](src/kamae/tensorflow/layers/round.py) | [Link](src/kamae/spark/transformers/round.py) | Not yet implemented | -| RoundToDecimal | Rounds a floating feature to the nearest decimal precision. | [Link](src/kamae/tensorflow/layers/round_to_decimal.py) | [Link](src/kamae/spark/transformers/round_to_decimal.py) | Not yet implemented | -| SharedOneHotEncode | Transforms a string to a one-hot array, using labels across multiple inputs to determine the one-hot size. | [Link](src/kamae/tensorflow/layers/one_hot_encode.py) | [Link](src/kamae/spark/estimators/shared_one_hot_encode.py) | Not yet implemented | -| SharedStringIndex | Transforms strings to indices via a vocabulary lookup, sharing the vocabulary across multiple inputs. | [Link](src/kamae/tensorflow/layers/string_index.py) | [Link](src/kamae/spark/estimators/shared_string_index.py) | Not yet implemented | -| SingleFeatureArrayStandardScale | Normalises by the mean and standard deviation calculated over all elements of all inputs, with ability to mask a specified value. | [Link](src/kamae/tensorflow/layers/standard_scale.py) | [Link](src/kamae/spark/estimators/single_feature_array_standard_scale.py) | Not yet implemented | -| StandardScale | Normalises by the mean and standard deviation, with ability to mask a specified value. | [Link](src/kamae/tensorflow/layers/standard_scale.py) | [Link](src/kamae/spark/estimators/standard_scale.py) | [Link](src/kamae/sklearn/estimators/standard_scale.py) | -| StringAffix | Prefixes and suffixes a string with provided constants. | [Link](src/kamae/tensorflow/layers/string_affix.py) | [Link](src/kamae/spark/transformers/string_affix.py) | Not yet implemented | -| StringArrayConstant | Inserts provided string array constant into a column. | [Link](src/kamae/tensorflow/layers/string_array_constant.py) | [Link](src/kamae/spark/transformers/string_array_constant.py) | Not yet implemented | -| StringCase | Applies an upper or lower casing operation to the feature. 
| [Link](src/kamae/tensorflow/layers/string_case.py) | [Link](src/kamae/spark/transformers/string_case.py) | Not yet implemented | -| StringConcatenate | Joins string columns using the provided separator. | [Link](src/kamae/tensorflow/layers/string_concatenate.py) | [Link](src/kamae/spark/transformers/string_concatenate.py) | Not yet implemented | -| StringContains | Checks for the existence of a constant or tensor-element substring within a feature. | [Link](src/kamae/tensorflow/layers/string_contains.py) | [Link](src/kamae/spark/transformers/string_contains.py) | Not yet implemented | -| StringContainsList | Checks for the existence of any string from a list of string constants within a feature. | [Link](src/kamae/tensorflow/layers/string_contains_list.py) | [Link](src/kamae/spark/transformers/string_contains_list.py) | Not yet implemented | -| StringEqualsIfStatement | Performs a simple if else statement on string equality. Value to check, result if true or false can be constants or features. | [Link](src/kamae/tensorflow/layers/string_equals_if_statement.py) | [Link](src/kamae/spark/transformers/string_equals_if_statement.py) | Not yet implemented | -| StringIndex | Transforms strings to indices via a vocabulary lookup | [Link](src/kamae/tensorflow/layers/string_index.py) | [Link](src/kamae/spark/estimators/string_index.py) | Not yet implemented | -| StringListToString | Concatenates a list of strings to a single string with a given delimiter. | [Link](src/kamae/tensorflow/layers/string_list_to_string.py) | [Link](src/kamae/spark/transformers/string_list_to_string.py) | Not yet implemented | -| StringMap | Maps a list of string values to a list of other string values with a standard CASE WHEN statement. Can provide a default value for ELSE. | [Link](src/kamae/tensorflow/layers/string_map.py) | [Link](src/kamae/spark/transformers/string_map.py) | Not yet implemented | -| StringIsInList | Checks if the feature is equal to at least one of the strings provided. 
| [Link](src/kamae/tensorflow/layers/string_isin_list.py) | [Link](src/kamae/spark/transformers/string_isin_list.py) | Not yet implemented | -| StringReplace | Performs a regex replace operation on a feature with constant params or between multiple features | [Link](src/kamae/tensorflow/layers/string_replace.py) | [Link](src/kamae/spark/transformers/string_replace.py) | Not yet implemented | -| StringToStringList | Splits a string by a separator, returning a list of parametrised length (with a default value for missing inputs). | [Link](src/kamae/tensorflow/layers/string_to_string_list.py) | [Link](src/kamae/spark/transformers/string_to_string_list.py) | Not yet implemented | -| SubStringDelimAtIndex | Splits a string column using the provided delimiter, and returns the value at the index given. If the index is out of bounds, returns a given default value | [Link](src/kamae/tensorflow/layers/sub_string_delim_at_index.py) | [Link](src/kamae/spark/transformers/sub_string_delim_at_index.py) | Not yet implemented | -| Subtract | Subtracts a constant from a single feature or subtracts multiple features from each other. | [Link](src/kamae/tensorflow/layers/subtract.py) | [Link](src/kamae/spark/transformers/subtract.py) | Not yet implemented | -| Sum | Adds a constant to a single feature or sums multiple features together. | [Link](src/kamae/tensorflow/layers/sum.py) | [Link](src/kamae/spark/transformers/sum.py) | Not yet implemented | -| UnixTimestampToDateTime | Converts a unix timestamp to a UTC datetime string. 
| [Link](src/kamae/tensorflow/layers/unix_timestamp_to_date_time.py) | [Link](src/kamae/spark/transformers/unix_timestamp_to_date_time.py) | Not yet implemented | +| Transformation | Description | Keras Layer | Backend | Spark Transformer | +|:-------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------:|:----------------:|:-------------------------------------------------------------------------:| +| AbsoluteValue | Applies the `abs(x)` transform. | [Link](src/kamae/keras/core/layers/absolute_value.py) | Multi-backend | [Link](src/kamae/spark/transformers/absolute_value.py) | +| ArrayConcatenate | Assembles multiple features into a single array. | [Link](src/kamae/keras/core/layers/array_concatenate.py) | Multi-backend | [Link](src/kamae/spark/transformers/array_concatenate.py) | +| ArrayCrop | Crops or pads a feature array to a consistent size. | [Link](src/kamae/keras/core/layers/array_crop.py) | Multi-backend | [Link](src/kamae/spark/transformers/array_crop.py) | +| ArraySplit | Splits a feature array into multiple features. | [Link](src/kamae/keras/core/layers/array_split.py) | Multi-backend | [Link](src/kamae/spark/transformers/array_split.py) | +| ArraySubtractMinimum | Subtracts the minimum element in an array from the rest to compute a timestamp difference. Ignores padded values. | [Link](src/kamae/keras/core/layers/array_subtract_minimum.py) | Multi-backend | [Link](src/kamae/spark/transformers/array_subtract_minimum.py) | +| BearingAngle | Compute the bearing angle (https://en.wikipedia.org/wiki/Bearing_(navigation)) between two pairs of lat/long. | [Link](src/kamae/keras/core/layers/bearing_angle.py) | Multi-backend | [Link](src/kamae/spark/transformers/bearing_angle.py) | +| Bin | Bins a numerical column into string categorical bins. 
Users can specify the bin values, labels and a default label. | [Link](src/kamae/keras/core/layers/bin.py) | Multi-backend | [Link](src/kamae/spark/transformers/bin.py) | +| BloomEncode | Hash encodes a string feature multiple times to create an array of indices. Useful for compressing input dimensions for embeddings. Paper: https://arxiv.org/pdf/1706.03993.pdf | [Link](src/kamae/keras/tensorflow/layers/bloom_encode.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/bloom_encode.py) | +| Bucketize | Buckets a numerical column into integer bins. | [Link](src/kamae/keras/tensorflow/layers/bucketize.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/bucketize.py) | +| ConditionalStandardScale | Normalises by the mean and standard deviation, with ability to: apply a mask on another column, not scale the zeros, and apply a non standard scaling function. | [Link](src/kamae/keras/core/layers/conditional_standard_scale.py) | Multi-backend | [Link](src/kamae/spark/estimators/conditional_standard_scale.py) | +| CosineSimilarity | Computes the cosine similarity between two array features. | [Link](src/kamae/keras/core/layers/cosine_similarity.py) | Multi-backend | [Link](src/kamae/spark/transformers/cosine_similarity.py) | +| CurrentDate | Returns the current date for use in other transformers. | [Link](src/kamae/keras/tensorflow/layers/current_date.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/current_date.py) | +| CurrentDateTime | Returns the current date time in the format yyyy-MM-dd HH:mm:ss.SSS for use in other transformers. | [Link](src/kamae/keras/tensorflow/layers/current_date_time.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/current_date_time.py) | +| CurrentUnixTimestamp | Returns the current unix timestamp in either seconds or milliseconds for use in other transformers. 
| [Link](src/kamae/keras/tensorflow/layers/current_unix_timestamp.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/current_unix_timestamp.py) | +| DateAdd | Adds a static or dynamic number of days to a date feature. NOTE: Destroys any time component of the datetime if present. | [Link](src/kamae/keras/tensorflow/layers/date_add.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/date_add.py) | +| DateDiff | Computes the number of days between two date features. | [Link](src/kamae/keras/tensorflow/layers/date_diff.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/date_diff.py) | +| DateParse | Parses a string date of format YYYY-MM-DD to extract a given date part. E.g. day of year. | [Link](src/kamae/keras/tensorflow/layers/date_parse.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/date_parse.py) | +| DateTimeToUnixTimestamp | Converts a UTC datetime string to unix timestamp. | [Link](src/kamae/keras/tensorflow/layers/date_time_to_unix_timestamp.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/date_time_to_unix_timestamp.py) | +| Divide | Divides a single feature by a constant or divides multiple features against each other. | [Link](src/kamae/keras/core/layers/divide.py) | Multi-backend | [Link](src/kamae/spark/transformers/divide.py) | +| Exp | Applies the exp(x) operation to the feature. | [Link](src/kamae/keras/core/layers/exp.py) | Multi-backend | [Link](src/kamae/spark/transformers/exp.py) | +| Exponent | Applies the x^exponent to a single feature or x^y for multiple features. | [Link](src/kamae/keras/core/layers/exponent.py) | Multi-backend | [Link](src/kamae/spark/transformers/exponent.py) | +| HashIndex | Transforms strings to indices via a hash table of predetermined size. 
| [Link](src/kamae/keras/tensorflow/layers/hash_index.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/hash_index.py) | +| HaversineDistance | Computes the [haversine distance](https://en.wikipedia.org/wiki/Haversine_formula) between latitude and longitude pairs. | [Link](src/kamae/keras/core/layers/haversine_distance.py) | Multi-backend | [Link](src/kamae/spark/transformers/haversine_distance.py) | +| Identity | Applies the identity operation, leaving the input the same. | [Link](src/kamae/keras/core/layers/identity.py) | Multi-backend | [Link](src/kamae/spark/transformers/identity.py) | +| IfStatement | Computes a simple if statement on a set of columns/tensors and/or constants. | [Link](src/kamae/keras/tensorflow/layers/if_statement.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/if_statement.py) | +| Impute | Performs imputation of either mean or median value of the data over a specified mask. | [Link](src/kamae/keras/core/layers/impute.py) | Multi-backend | [Link](src/kamae/spark/transformers/impute.py) | +| LambdaFunction | Transforms an input (or multiple inputs) to an output (or multiple outputs) with a user provided tensorflow function. | [Link](src/kamae/keras/tensorflow/layers/lambda_function.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/lambda_function.py) | +| ListMax | Computes the listwise max of a feature, optionally calculated only on the top items based on another given feature. | [Link](src/kamae/keras/tensorflow/layers/list_max.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/list_max.py) | +| ListMean | Computes the listwise mean of a feature, optionally calculated only on the top items based on another given feature. | [Link](src/kamae/keras/tensorflow/layers/list_mean.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/list_mean.py) | +| ListMedian | Computes the listwise median of a feature, optionally calculated only on the top items based on another given feature. 
| [Link](src/kamae/keras/tensorflow/layers/list_median.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/list_median.py) | +| ListMin | Computes the listwise min of a feature, optionally calculated only on the top items based on another given feature. | [Link](src/kamae/keras/tensorflow/layers/list_min.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/list_min.py) | +| ListRank | Computes the listwise rank (ordering) of a feature. | [Link](src/kamae/keras/tensorflow/layers/list_rank.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/list_rank.py) | +| ListStdDev | Computes the listwise standard deviation of a feature, optionally calculated only on the top items based on another given feature. | [Link](src/kamae/keras/tensorflow/layers/list_std_dev.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/list_std_dev.py) | +| Log | Applies the natural logarithm `log(alpha + x)` transform. | [Link](src/kamae/keras/core/layers/log.py) | Multi-backend | [Link](src/kamae/spark/transformers/log.py) | +| LogicalAnd | Performs an and(x, y) operation on multiple boolean features. | [Link](src/kamae/keras/core/layers/logical_and.py) | Multi-backend | [Link](src/kamae/spark/transformers/logical_and.py) | +| LogicalNot | Performs a not(x) operation on a single boolean feature. | [Link](src/kamae/keras/core/layers/logical_not.py) | Multi-backend | [Link](src/kamae/spark/transformers/logical_not.py) | +| LogicalOr | Performs an or(x, y) operation on multiple boolean features. | [Link](src/kamae/keras/core/layers/logical_or.py) | Multi-backend | [Link](src/kamae/spark/transformers/logical_or.py) | +| Max | Computes the maximum of a feature with a constant or multiple other features. | [Link](src/kamae/keras/core/layers/max.py) | Multi-backend | [Link](src/kamae/spark/transformers/max.py) | +| Mean | Computes the mean of a feature with a constant or multiple other features.
| [Link](src/kamae/keras/core/layers/mean.py) | Multi-backend | [Link](src/kamae/spark/transformers/mean.py) | +| Min | Computes the minimum of a feature with a constant or multiple other features. | [Link](src/kamae/keras/core/layers/min.py) | Multi-backend | [Link](src/kamae/spark/transformers/min.py) | +| MinHashIndex | Creates an integer bit array from a set of strings using the [MinHash algorithm](https://en.wikipedia.org/wiki/MinHash). | [Link](src/kamae/keras/tensorflow/layers/min_hash_index.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/min_hash_index.py) | +| MinMaxScale | Scales the input feature by the min/max resulting in a feature in [0, 1]. | [Link](src/kamae/keras/core/layers/min_max_scale.py) | Multi-backend | [Link](src/kamae/spark/transformers/min_max_scale.py) | +| Modulo | Computes the modulo of a feature with the mod divisor being a constant or another feature. | [Link](src/kamae/keras/core/layers/modulo.py) | Multi-backend | [Link](src/kamae/spark/transformers/modulo.py) | +| Multiply | Multiplies a single feature by a constant or multiplies multiple features together. | [Link](src/kamae/keras/core/layers/multiply.py) | Multi-backend | [Link](src/kamae/spark/transformers/multiply.py) | +| NumericalIfStatement | Performs a simple if else statement with a given operator. Value to check, result if true or false can be constants or features. | [Link](src/kamae/keras/core/layers/numerical_if_statement.py) | Multi-backend | [Link](src/kamae/spark/transformers/numerical_if_statement.py) | +| OneHotEncode | Transforms a string to a one-hot array. | [Link](src/kamae/keras/tensorflow/layers/one_hot_encode.py) | TensorFlow-only | [Link](src/kamae/spark/estimators/one_hot_encode.py) | +| OrdinalArrayEncode | Encodes strings in an array according to the order in which they appear. Only for 2D tensors.
| [Link](src/kamae/keras/tensorflow/layers/ordinal_array_encoder.py) | TensorFlow-only | [Link](src/kamae/spark/estimators/ordinal_array_encoder.py) | +| Round | Rounds a floating feature to the nearest integer using `ceil`, `floor` or a standard `round` op. | [Link](src/kamae/keras/core/layers/round.py) | Multi-backend | [Link](src/kamae/spark/transformers/round.py) | +| RoundToDecimal | Rounds a floating feature to the nearest decimal precision. | [Link](src/kamae/keras/core/layers/round_to_decimal.py) | Multi-backend | [Link](src/kamae/spark/transformers/round_to_decimal.py) | +| SharedOneHotEncode | Transforms a string to a one-hot array, using labels across multiple inputs to determine the one-hot size. | [Link](src/kamae/keras/tensorflow/layers/one_hot_encode.py) | TensorFlow-only | [Link](src/kamae/spark/estimators/shared_one_hot_encode.py) | +| SharedStringIndex | Transforms strings to indices via a vocabulary lookup, sharing the vocabulary across multiple inputs. | [Link](src/kamae/keras/tensorflow/layers/string_index.py) | TensorFlow-only | [Link](src/kamae/spark/estimators/shared_string_index.py) | +| SingleFeatureArrayStandardScale | Normalises by the mean and standard deviation calculated over all elements of all inputs, with ability to mask a specified value. | [Link](src/kamae/keras/tensorflow/layers/standard_scale.py) | TensorFlow-only | [Link](src/kamae/spark/estimators/single_feature_array_standard_scale.py) | +| StandardScale | Normalises by the mean and standard deviation, with ability to mask a specified value. | [Link](src/kamae/keras/core/layers/standard_scale.py) | Multi-backend | [Link](src/kamae/spark/estimators/standard_scale.py) | +| StringAffix | Prefixes and suffixes a string with provided constants. | [Link](src/kamae/keras/tensorflow/layers/string_affix.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/string_affix.py) | +| StringArrayConstant | Inserts provided string array constant into a column. 
| [Link](src/kamae/keras/tensorflow/layers/string_array_constant.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/string_array_constant.py) | +| StringCase | Applies an upper or lower casing operation to the feature. | [Link](src/kamae/keras/tensorflow/layers/string_case.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/string_case.py) | +| StringConcatenate | Joins string columns using the provided separator. | [Link](src/kamae/keras/tensorflow/layers/string_concatenate.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/string_concatenate.py) | +| StringContains | Checks for the existence of a constant or tensor-element substring within a feature. | [Link](src/kamae/keras/tensorflow/layers/string_contains.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/string_contains.py) | +| StringContainsList | Checks for the existence of any string from a list of string constants within a feature. | [Link](src/kamae/keras/tensorflow/layers/string_contains_list.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/string_contains_list.py) | +| StringEqualsIfStatement | Performs a simple if else statement on string equality. Value to check, result if true or false can be constants or features. | [Link](src/kamae/keras/tensorflow/layers/string_equals_if_statement.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/string_equals_if_statement.py) | +| StringIndex | Transforms strings to indices via a vocabulary lookup | [Link](src/kamae/keras/tensorflow/layers/string_index.py) | TensorFlow-only | [Link](src/kamae/spark/estimators/string_index.py) | +| StringListToString | Concatenates a list of strings to a single string with a given delimiter. | [Link](src/kamae/keras/tensorflow/layers/string_list_to_string.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/string_list_to_string.py) | +| StringMap | Maps a list of string values to a list of other string values with a standard CASE WHEN statement. 
Can provide a default value for ELSE. | [Link](src/kamae/keras/tensorflow/layers/string_map.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/string_map.py) | +| StringIsInList | Checks if the feature is equal to at least one of the strings provided. | [Link](src/kamae/keras/tensorflow/layers/string_isin_list.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/string_isin_list.py) | +| StringReplace | Performs a regex replace operation on a feature with constant params or between multiple features | [Link](src/kamae/keras/tensorflow/layers/string_replace.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/string_replace.py) | +| StringToStringList | Splits a string by a separator, returning a list of parametrised length (with a default value for missing inputs). | [Link](src/kamae/keras/tensorflow/layers/string_to_string_list.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/string_to_string_list.py) | +| SubStringDelimAtIndex | Splits a string column using the provided delimiter, and returns the value at the index given. If the index is out of bounds, returns a given default value | [Link](src/kamae/keras/tensorflow/layers/sub_string_delim_at_index.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/sub_string_delim_at_index.py) | +| Subtract | Subtracts a constant from a single feature or subtracts multiple features from each other. | [Link](src/kamae/keras/core/layers/subtract.py) | Multi-backend | [Link](src/kamae/spark/transformers/subtract.py) | +| Sum | Adds a constant to a single feature or sums multiple features together. | [Link](src/kamae/keras/core/layers/sum.py) | Multi-backend | [Link](src/kamae/spark/transformers/sum.py) | +| UnixTimestampToDateTime | Converts a unix timestamp to a UTC datetime string. 
| [Link](src/kamae/keras/tensorflow/layers/unix_timestamp_to_date_time.py) | TensorFlow-only | [Link](src/kamae/spark/transformers/unix_timestamp_to_date_time.py) | ## Development diff --git a/docs/adding_transformer.md b/docs/adding_transformer.md index b495ad85..b02fbba0 100644 --- a/docs/adding_transformer.md +++ b/docs/adding_transformer.md @@ -1,4 +1,4 @@ -# Contributing a Keras layer and Spark/Scikit-learn transformer +# Contributing a Keras layer and Spark transformer Follow this guide to contribute a new transformer to the project. @@ -6,8 +6,6 @@ Follow this guide to contribute a new transformer to the project. In order to contribute a new transformer, you will need to implement a Spark Transformer, a corresponding Keras layer, and a Spark Estimator if your transformer needs a fit method. We also require unit tests for all new classes, in particular parity tests ensuring your Spark Transformer and Keras layer produce the same output. -You may wish to also implement a Scikit-learn transformer, however we deem the scikit-learn usage pattern to be experimental for now and so this is not required. - ## Naming In order to avoid name clashes and to keep consistency, we have a naming convention for all new classes. @@ -15,40 +13,46 @@ If an operation is called `` then: - `Estimator` = Spark estimator (if applicable) - `Transformer` = Spark transformer -- `Layer` = Tensorflow/Keras layer +- `Layer` = Keras layer - `Params` = Spark params class We just keep the verb stem. E.g string indexing is StringIndexTransformer, not StringIndexerTransformer. -The name of the file should then be `.py`. E.g. `src/kame/spark/transformers/string_index.py` and `src/kame/tensorflow/layers/string_index.py`. +The name of the file should then be `.py`. E.g. `src/kamae/spark/transformers/string_index.py` and `src/kamae/keras/core/layers/string_index.py` (for multi-backend layers) or `src/kamae/keras/tensorflow/layers/string_index.py` (for TensorFlow-only layers). 
Finally, if you need to create an estimator, then the estimator and its corresponding transformer should be in different files. E.g. `src/kame/spark/transformers/string_index.py` and `src/kame/spark/estimators/string_index.py`. ## Keras layer -Your Keras layer should extend [BaseLayer](../src/kamae/tensorflow/layers/base.py) and implement the `_call` method. Furthermore, you will need to define the `compatible_dtypes` property which should return a list of compatible dtypes for the layer (or `None` if the layer is compatible with all dtypes). +Your Keras layer should extend [BaseLayer](../src/kamae/keras/core/base.py) and implement the `_call` method. Furthermore, you will need to define the `compatible_dtypes` property which should return a list of compatible dtype strings (or `None` if the layer is compatible with all dtypes). You should ensure your layer is serializable by implementing the `get_config` method. -You also need to add the decorator `@tf.keras.utils.register_keras_serializable(package=kamae.__name__)` to the class. +You also need to add the decorator `@keras.saving.register_keras_serializable(package=kamae.__name__)` to the class. + +**Note:** Multi-backend layers should be placed in `src/kamae/keras/core/layers/` and use only Keras 3 operations. TensorFlow-only layers (those requiring TensorFlow-specific operations) should be placed in `src/kamae/keras/tensorflow/layers/` and can import TensorFlow for backend-specific functionality. 
### Example ```python from typing import List, Optional -import tensorflow as tf +import keras import kamae -from .base import BaseLayer +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class MyLayer(BaseLayer): + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__(self, name, input_dtype, output_dtype, my_param, **kwargs): # Ensure that the name, input_dtype, and output_dtype are passed to the super constructor super().__init__(name=name, input_dtype=input_dtype, output_dtype=output_dtype, **kwargs) self.my_param = my_param @property - def compatible_dtypes(self) -> Optional[List[tf.DType]]: - return [tf.float32, tf.float64] + def compatible_dtypes(self) -> Optional[List[str]]: + return ["float32", "float64"] def _call(self, inputs): # do something with inputs @@ -62,18 +66,19 @@ class MyLayer(BaseLayer): ### Checklist -- [ ] I have implemented a Keras layer that extends [BaseLayer](../src/kamae/tensorflow/layers/base.py) +- [ ] I have implemented a Keras layer that extends [BaseLayer](../src/kamae/keras/core/base.py) +- [ ] I have defined `supported_backends` and `jit_compatible` class attributes on my layer. - [ ] I have implemented the `_call` method of my Keras layer. -- [ ] I have defined the `compatible_dtypes` property of my Keras layer. -- [ ] I have added the decorator `@tf.keras.utils.register_keras_serializable(package=kamae.__name__)` to my Keras layer. +- [ ] I have defined the `compatible_dtypes` property of my Keras layer, returning a list of dtype strings (e.g., `["float32", "float64"]`) or `None`. +- [ ] I have added the decorator `@keras.saving.register_keras_serializable(package=kamae.__name__)` to my Keras layer. 
- [ ] I have ensured that my layer takes a `name`, `input_dtype`, and `output_dtype` as arguments to the constructor and that this is passed to the super constructor. - [ ] My Keras layer is serializable. I have implemented the `get_config` method and added the decorator seen above to the class. - [ ] I have unit tests of my implementation. -- [ ] I have a specific test of layer serialisation added [here](../../tests/tensorflow/test_layer_serialisation.py). +- [ ] I have a specific test of layer serialisation added [here](../../tests/kamae/keras/test_layer_serialisation.py). ## Spark Transformer/Estimator Your Spark Transformer should extend [BaseTransformer](../src/kamae/spark/transformers/base.py). -In this it should implement the `get_tf_layer` method, which should return an instance of your Keras layer. +In this it should implement the `get_keras_layer` method, which should return an instance of your Keras layer. If your transformer needs a fit method, you should also implement a Spark Estimator (which extends [BaseEstimator](../src/kamae/spark/estimators/base.py)) whose fit method returns an instance of your transformer. Spark has a peculiar way of building constructors, in that the `__init__` calls a `setParams` method, which sets the parameters of the transformer. @@ -112,11 +117,13 @@ Note that the methods are named `_fit` and `_transform`. 
`fit` and `transform` w ```python from typing import List, Optional +import keras from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import DataFrame from pyspark.sql.types import DataType, StringType, BinaryType +from kamae.keras.core.backend import ALL_BACKENDS from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.transformers import BaseTransformer from kamae.spark.estimators import BaseEstimator @@ -144,6 +151,8 @@ class MyEstimator( SingleInputSingleOutputParams, MyCustomParams ): + supported_backends = ALL_BACKENDS + jit_compatible = True @keyword_only def __init__( @@ -180,6 +189,8 @@ class MyTransformer( SingleInputSingleOutputParams, MyCustomParams ): + supported_backends = ALL_BACKENDS + jit_compatible = True @keyword_only def __init__( @@ -199,13 +210,13 @@ class MyTransformer( def compatible_dtypes(self) -> Optional[List[DataType]]: return [StringType(), BinaryType()] - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: # Ensure that the layer has the layer name, input dtype, and output dtype # as arguments `name`, `input_dtype`, and `output_dtype` respectively. return MyLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - out_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), my_param=self.getMyParam(), ) @@ -220,72 +231,10 @@ class MyTransformer( ### Checklist - [ ] I have implemented a Spark Transformer that extends [BaseTransformer](../src/kamae/spark/transformers/base.py). - [ ] If my transformer needs a fit method, I have implemented a Spark Estimator that extends [BaseEstimator](../src/kamae/spark/estimators/base.py). +- [ ] I have defined `supported_backends` and `jit_compatible` class attributes on my transformer/estimator (not in the Params class). - [ ] I have followed the instructions for the `__init__` and `setParams` methods.
- [ ] I have used one (or more) of the input/output mixin classes from [base.py](../src/kamae/spark/params/base.py). - [ ] If my transformer requires more parameters that would need to be serialised to the Spark ML pipeline, I have added a parameter class by extending the `Params` class [here](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.param.Params.html). - [ ] I have defined the `compatible_dtypes` property to specify the input/output data types that my transformer/estimator supports. -- [ ] I used a Keras subclassed layer for my `get_tf_layer` method. +- [ ] I used a Keras subclassed layer for my `get_keras_layer` method. - [ ] I have unit tests of my implementation. In particular, I have parity tests between the Spark and Keras implementations. - -## Scikit-learn Transformer/Estimator - -If your transformer is a wrapper around an existing Scikit-Learn transformer, you should -also extend the [BaseTransformerMixin](../src/kamae/sklearn/transformers/base.py) class. This will provide the required functionality -to be incorporated into the Kamae framework. - -If you are writing a custom transformer, you should extend the [BaseTransformer](../src/kamae/sklearn/transformers/base.py) class. -The only difference between these classes is that the `BaseTransformer` class also extends -the `BaseEstimator` and `TransformerMixin` classes from scikit-learn. If you are wrapping -an existing transformer, these are already extended by the transformer you are wrapping. -See the [StandardScaleEstimator](../src/kamae/sklearn/estimators/standard_scale.py) for an example of a wrapper around an existing transformer. -See the [IdentityTransformer](../src/kamae/sklearn/transformers/identity.py) for an example of a custom transformer. 
- -Additionally, your transformer should use one (or more) of the input/output mixin classes from [base.py](../src/kamae/sklearn/params/base.py) -- SingleInputSingleOutputMixin -- SingleInputMultiOutputMixin -- MultiInputSingleOutputMixin -- MultiInputMultiOutputMixin - -Only use more than one if you want to support two usages of your transformer. -We have no scikit-learn examples of this yet, only Spark. The behaviour is the same. -See above to the Spark section to understand why you may want to do this. - -In scikit-learn, everything is an estimator. If your transformer does not require a fit method, -just return `self` from the `fit` method. If your transformer does require a fit method, you -should implement it within the `fit` method of your transformer. - -### Example -```python -import pandas as pd -import tensorflow as tf -from kamae.sklearn.params import SingleInputSingleOutputMixin -from kamae.sklearn.transformers import BaseTransformer - -class MyTransformer( - BaseTransformer, SingleInputSingleOutputMixin -): - def __init__(self, input_col: str, output_col: str, layer_name: str) -> None: - super().__init__() - self.input_col = input_col - self.output_col = output_col - self.layer_name = layer_name - - def fit(self, X: pd.DataFrame, y=None) -> "MyTransformer": - return self - - def transform(self, X: pd.DataFrame, y=None) -> pd.DataFrame: - X[self.output_col] = output_of_transform - return X - - def get_tf_layer(self) -> tf.keras.layers.Layer: - return MyLayer( - name=self.layer_name, - ) -``` - -### Checklist -- [ ] I have implemented a Scikit-learn Transformer that extends [BaseTransformer](../src/kamae/sklearn/transformers/base.py) (if custom) or [BaseTransformerMixin](../src/kamae/sklearn/transformers/base.py) (if wrapping an existing transformer). -- [ ] If my transformer needs a fit method, I have implemented it within the `fit` method of my transformer. 
-- [ ] I have used one (or more) of the input/output mixin classes from [base.py](../src/kamae/sklearn/params/base.py). -- [ ] I used a Keras subclassed layer for my `get_tf_layer` method. -- [ ] I have unit tests of my implementation. In particular, I have parity tests between the scikit-learn and Keras implementations. diff --git a/docs/chaining_models.md b/docs/chaining_models.md index dbd3cb6c..ac0b521f 100644 --- a/docs/chaining_models.md +++ b/docs/chaining_models.md @@ -9,32 +9,20 @@ This method will return a Keras model that you can use to process your data. ### Accessing model inputs -The way in which you specify the `tf_input_schema` to this method can influence how you access your model inputs. +The way in which you specify the `input_schema` to this method can influence how you access your model inputs. #### 1. **List of dictionary config.** -This is the standard way of specifying the `tf_input_schema`. -In this case, you would pass the `tf_input_schema` as a list of dictionaries, where each dictionary specifies (at least) the name, shape and type of the input. -These dictionaries will be passed directly into the [`tf.keras.layers.Input`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/InputLayer) via ** kwargs, and so the names of the arguments will be the keys specified in the dictionary. +This is the standard way of specifying the `input_schema`. +In this case, you would pass the `input_schema` as a list of dictionaries, where each dictionary specifies (at least) the name, shape and dtype of the input. +These dictionaries will be passed directly into [`keras.layers.Input`](https://keras.io/api/layers/core_layers/input/) via ** kwargs, and so the names of the arguments will be the keys specified in the dictionary. -In this case, when accessing your model inputs, you can use the `inputs` attribute of the model, which is a list of `tf.keras.Input` objects. 
+In this case, when accessing your model inputs, you can use the `inputs` attribute of the model, which is a list of `keras.Input` objects. You can access the `name` attribute of each of these objects to get the name of the input. -These will match the names specified in the `tf_input_schema` dictionary. +These will match the names specified in the `input_schema` dictionary. -#### 2. **List of tf.TypeSpec.** - -If you have more complex inputs (e.g. a [`RaggedTensor`](https://www.tensorflow.org/api_docs/python/tf/RaggedTensor)) then you may find using [`tf.TypeSpec`](https://www.tensorflow.org/api_docs/python/tf/TypeSpec?hl=en) objects easier. -In this case, you would pass the `tf_input_schema` as a list of `tf.TypeSpec` objects. -Under the hood, these will be passed to the [`tf.keras.layers.Input`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/InputLayer) via the `typespec` argument. - -However, in this case, accessing the inputs of your model via the `inputs` attribute will return inputs with missing names (i.e. `None`). This is detailed in this [GitHub issue](https://github.com/keras-team/tf-keras/issues/406). - -In order to fix this you will need to zip the `input_names` attribute of your model with the `inputs` attribute, to assign the names to the inputs. - -```python -inputs_with_names = list(zip(model.input_names, model.inputs)) -``` +**Note**: For Keras 3, use dictionary config (method 1 above) as it's the most portable across backends. Complex TensorFlow-specific inputs like RaggedTensors are only supported on the TensorFlow backend. 
### Accessing model outputs @@ -49,7 +37,7 @@ The output names match the pipeline output column names directly, and can be acc Assuming we have two models, `prepro_model` and `trained_model` which we want to chain together, we can do the following: ```python -import tensorflow as tf +import keras # Get the inputs of the prepro model prepro_inputs = prepro_model.inputs @@ -70,7 +58,7 @@ prepro_outputs_dict = { combined_outputs = trained_model(prepro_outputs_dict) # Create a new model with the prepro inputs and combined outputs -combined_model = tf.keras.Model(inputs=prepro_inputs, outputs=combined_outputs) +combined_model = keras.Model(inputs=prepro_inputs, outputs=combined_outputs) ``` ### Postprocessing example @@ -78,7 +66,7 @@ combined_model = tf.keras.Model(inputs=prepro_inputs, outputs=combined_outputs) Postprocessing works in a very similar way, you just change which model is applied to the other: ```python -import tensorflow as tf +import keras # Get the inputs of the trained model trained_inputs = trained_model.inputs @@ -99,5 +87,5 @@ trained_outputs_dict = { combined_outputs = postpro_model(trained_outputs_dict) # Create a new model with the trained inputs and combined outputs -combined_model = tf.keras.Model(inputs=trained_inputs, outputs=combined_outputs) +combined_model = keras.Model(inputs=trained_inputs, outputs=combined_outputs) ``` diff --git a/docs/keras3_migration.md b/docs/keras3_migration.md new file mode 100644 index 00000000..939eefad --- /dev/null +++ b/docs/keras3_migration.md @@ -0,0 +1,243 @@ +# Keras 3 Migration Guide + +This document summarizes the migration of Kamae to Keras 3. + +## Overview + +Kamae has been migrated from Keras 2 (tf.keras) to Keras 3, enabling multi-backend support while maintaining full backward compatibility for existing TensorFlow-based workflows. + +## Key Changes + +### 1. Multi-Backend Architecture + +Kamae now supports three backends: **TensorFlow**, **JAX**, and **PyTorch**. 
+ +```python +# Set backend before importing keras +import os +os.environ['KERAS_BACKEND'] = 'tensorflow' # or 'jax' or 'torch' + +import keras +from kamae.keras.core.layers import AbsoluteValueLayer # Works on all backends +``` + +### 2. Package Structure + +``` +kamae/ +├── keras/ +│ ├── core/ # Backend-agnostic layers (numeric ops) +│ │ ├── base.py # Unified BaseLayer +│ │ ├── layers/ # 31 multi-backend layers +│ │ └── utils/ # Backend-agnostic utilities +│ └── tensorflow/ # TensorFlow-specific layers +│ ├── layers/ # 36 TF-only layers (strings, datetime) +│ └── utils/ # TF-specific utilities +├── spark/ # Spark transformers (unchanged) +├── graph/ # Pipeline graph (now backend-agnostic) +└── utils/ # General utilities +``` + +**Removed:** +- `kamae.tensorflow.layers/` - moved to `kamae.keras.core.layers/` or `kamae.keras.tensorflow.layers/` +- `kamae.sklearn/` - removed (was experimental, not maintained) + +### 3. Layer Categories + +#### Multi-Backend Layers (31 layers) +Located in `kamae.keras.core.layers/`, work on TensorFlow, JAX, and PyTorch: + +- **Numeric operations**: AbsoluteValue, Divide, Exp, Exponent, Log, Max, Mean, Min, Modulo, Multiply, Subtract, Sum +- **Array operations**: ArrayConcatenate, ArrayCrop, ArraySplit, ArraySubtractMinimum +- **Statistical operations**: StandardScale, MinMaxScale, ConditionalStandardScale, Impute +- **Mathematical operations**: BearingAngle, CosineSimilarity, HaversineDistance +- **Logical operations**: LogicalAnd, LogicalNot, LogicalOr +- **Binning/Rounding**: Bin, Round, RoundToDecimal +- **Control flow**: NumericalIfStatement +- **Utility**: Identity + +#### TensorFlow-Only Layers (36 layers) +Located in `kamae.keras.tensorflow.layers/`, require TensorFlow backend: + +- **String operations**: StringAffix, StringArrayConstant, StringCase, StringConcatenate, StringContains, StringContainsList, StringEqualsIfStatement, StringIndex, StringIsInList, StringListToString, StringMap, StringReplace, StringToStringList, 
SubStringDelimAtIndex +- **DateTime operations**: CurrentDate, CurrentDateTime, CurrentUnixTimestamp, DateAdd, DateDiff, DateParse, DateTimeToUnixTimestamp, UnixTimestampToDateTime +- **List operations**: ListMax, ListMean, ListMedian, ListMin, ListRank, ListStdDev +- **Encoding**: BloomEncode, HashIndex, MinHashIndex, OneHotEncode, OrdinalArrayEncode, SharedOneHotEncode, SharedStringIndex +- **Other**: Bucketize, IfStatement, LambdaFunction, SingleFeatureArrayStandardScale + +### 4. Model Serialization + +**Keras 3 uses `.keras` format** (replaces `.h5`): + +```python +# OLD (Keras 2) +model.save("path/to/model") +model = tf.keras.models.load_model("path/to/model") + +# NEW (Keras 3) +model.save("model.keras") +model = keras.models.load_model("model.keras") +``` + +### 5. Import Changes + +```python +# OLD (Keras 2) +import tensorflow as tf +from kamae.tensorflow.layers import AbsoluteValueLayer + +layer = AbsoluteValueLayer() +model = tf.keras.Model(inputs=inputs, outputs=outputs) +model.save("path/to/model") + +# NEW (Keras 3) +import keras +from kamae.keras.core.layers import AbsoluteValueLayer + +layer = AbsoluteValueLayer() +model = keras.Model(inputs=inputs, outputs=outputs) +model.save("model.keras") +``` + +### 6. DType Changes + +```python +# OLD (Keras 2) +from kamae.utils import DType +dtype = DType.INT +tf_dtype = dtype.tf_dtype # Returns tf.int32 + +# NEW (Keras 3) +from kamae.utils import DType +dtype = DType.INT +keras_dtype = dtype.keras_dtype # Returns "int32" (string) +``` + +### 7. Type Annotations + +```python +# OLD (Keras 2) +from typing import Optional, List +import tensorflow as tf + +def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + return [tf.float32, tf.float64] + +# NEW (Keras 3 - Multi-backend) +from typing import Optional, List + +def compatible_dtypes(self) -> Optional[List[str]]: + return ["float32", "float64"] +``` + +### 8. 
API Method Renames + +**Methods renamed for backend-agnostic naming:** + +| Old Name (Keras 2) | New Name (Keras 3) | Location | +|-------------------|-------------------|----------| +| `get_tf_layer()` | `get_keras_layer()` | All transformers | +| `getInputTFDtype()` | `getInputKerasDtype()` | Transformer parameters | +| `getOutputTFDtype()` | `getOutputKerasDtype()` | Transformer parameters | +| `get_all_tf_layers()` | `get_all_keras_layers()` | PipelineModel | +| `tf_input_schema` parameter | `input_schema` parameter | `build_keras_model()`, `get_keras_tuner_model_builder()` | + +## Migration Checklist + +### For Users + +- [ ] Update model save/load to use the `.keras` extension +- [ ] Change `tf.keras` imports to `keras` +- [ ] Update `tf.keras.models.load_model()` to `keras.models.load_model()` +- [ ] Remove Keras 2 vs 3 version checking code +- [ ] Set the `KERAS_BACKEND` environment variable if not using TensorFlow +- [ ] Update the `tf_input_schema` parameter to `input_schema` in `build_keras_model()` and `get_keras_tuner_model_builder()` calls + +### For Contributors + +- [ ] Use `kamae.keras.core.layers` for new numeric operations (multi-backend) +- [ ] Use `kamae.keras.tensorflow.layers` for string/datetime operations (TF-only) +- [ ] Import `BaseLayer` from `kamae.keras.core.base` (not `kamae.tensorflow.layers.base`) +- [ ] Use `@keras.saving.register_keras_serializable` decorator (not `tf.keras.utils`) +- [ ] Return string dtypes from `compatible_dtypes` property (not tf.DType objects) +- [ ] Use `keras.ops` for numeric operations (not `tf.math`) +- [ ] Add tests to the corresponding test directory (`tests/kamae/keras/core/layers/` for multi-backend layers, `tests/kamae/keras/tensorflow/layers/` for TF-only layers) +- [ ] Use `get_keras_layer()` instead of `get_tf_layer()` in transformer implementations +- [ ] Use `getInputKerasDtype()` and `getOutputKerasDtype()` instead of TF-prefixed versions + +## Backend-Specific String Operations + +The `BaseLayer` class supports string operations, but they **only work on the TensorFlow backend**: +
+```python +import os +os.environ['KERAS_BACKEND'] = 'tensorflow' + +import keras +from kamae.keras.core.layers import BinLayer + +# String output types work on TensorFlow backend +layer = BinLayer( + condition_operators=["lt", "gt"], + bin_values=[5, 10], + bin_labels=["small", "large"], + default_label="medium" +) +``` + +If you try to use string dtypes on JAX or PyTorch backends, you'll get a clear error message. + +## Testing + +All existing tests pass. Test organization now mirrors source structure: +- `tests/kamae/keras/core/layers/` - 32 test files for multi-backend layers +- `tests/kamae/keras/tensorflow/layers/` - 36 test files for TF-only layers + +## Backward Compatibility + +Spark pipelines continue to work as before, apart from the renames listed under Breaking Changes: +- Spark transformers behave the same (only the TF-prefixed method names changed) +- `build_keras_model()` builds the same models (its `tf_input_schema` argument is now `input_schema`) +- Generated Keras models are backward compatible with TensorFlow Serving + +## Performance + +No performance regressions. Multi-backend layers use `keras.ops`, which compiles efficiently on all backends. + +## Documentation + +All documentation updated: +- README.md - Updated to Keras 3, removed sklearn references +- docs/adding_transformer.md - Updated for Keras 3 layer development +- docs/chaining_models.md - Updated code examples to use `keras` imports +- examples/spark/*.py - All examples updated to Keras 3 + +## Breaking Changes + +1. **Removed sklearn support** - `kamae.sklearn` package removed (was experimental) +2. **Module paths changed**: + - `kamae.tensorflow.layers` → `kamae.keras.core.layers` or `kamae.keras.tensorflow.layers` + - `kamae.tensorflow.utils` → `kamae.keras.core.utils` or `kamae.keras.tensorflow.utils` + - `kamae.tensorflow.typing` → `kamae.keras.tensorflow.utils.typing` +3. **DType enum** - `tf_dtype` attribute renamed to `keras_dtype` (returns string, not tf.DType) +4. **Model format** - Models should be saved with the `.keras` extension (`.h5` still works but is deprecated) +5.
**API method names** - All TensorFlow-prefixed methods renamed for backend-agnostic naming: + - `get_tf_layer()` → `get_keras_layer()` + - `getInputTFDtype()` → `getInputKerasDtype()` + - `getOutputTFDtype()` → `getOutputKerasDtype()` + - `get_all_tf_layers()` → `get_all_keras_layers()` + - `tf_input_schema` parameter → `input_schema` + +## Benefits + +1. **Multi-backend support** - Run on TensorFlow, JAX, or PyTorch +2. **Cleaner architecture** - Clear separation between multi-backend and TF-only code +3. **Better maintainability** - Unified BaseLayer, no code duplication +4. **Future-proof** - Built on Keras 3, the future of Keras +5. **Smaller package** - Removed unmaintained sklearn code + +## Resources + +- [Keras 3 Documentation](https://keras.io/) +- [Introducing Keras 3 (multi-backend overview)](https://keras.io/keras_3/) +- [Multi-GPU distributed training with JAX](https://keras.io/guides/distributed_training_with_jax/) diff --git a/examples/sklearn/example_pipeline.py b/examples/sklearn/example_pipeline.py deleted file mode 100644 index 48ef6d8a..00000000 --- a/examples/sklearn/example_pipeline.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
- -import joblib -import pandas as pd - -from kamae.sklearn.estimators import StandardScaleEstimator -from kamae.sklearn.pipeline import KamaeSklearnPipeline -from kamae.sklearn.transformers import ( - ArrayConcatenateTransformer, - ArraySplitTransformer, - IdentityTransformer, - LogTransformer, -) - -if __name__ == "__main__": - pd.options.display.max_columns = None - pd.options.display.max_rows = None - - # Create some dummy pandas data - df = pd.DataFrame( - { - "col1": [10, 4.8, 7.3], - "col2": [2.5, 5.3, 8.2], - "col3": [3.7, 6.4, 9.4], - "col4": [[1.6, 4.0, 7.0], [2.4, 5.5, 8.1], [3.1, 6.4, 9.1]], - }, - ) - print("Original dataframe:") - print(df.head()) - - # Create a scikit-learn pipeline - log_transformer = LogTransformer( - input_col="col1", - output_col="log_col1", - alpha=1, - layer_name="log_one_plus_x", - ) - identity_transformer = IdentityTransformer( - input_col="col3", - output_col="identity_col3", - layer_name="identity_col3_output", - ) - vector_assembler = ArrayConcatenateTransformer( - input_cols=["log_col1", "col2", "identity_col3", "col4"], - output_col="vec_assembled", - layer_name="vector_assembler", - ) - standard_scaler = StandardScaleEstimator( - input_col="vec_assembled", - output_col="scaled_assembled_vec", - layer_name="standard_scaler", - ) - vector_slicer = ArraySplitTransformer( - input_col="scaled_assembled_vec", - output_cols=[ - "sliced_col1", - "sliced_col2", - "sliced_col3", - "sliced_col4_1", - "sliced_col4_2", - "sliced_col4_3", - ], - layer_name="vector_slicer", - ) - test_pipeline = KamaeSklearnPipeline( - steps=[ - ("identity_transformer", identity_transformer), - ("log_transformer", log_transformer), - ("vec_assembler", vector_assembler), - ("standard_scaler", standard_scaler), - ("vector_slicer", vector_slicer), - ] - ) - - # Fit the pipeline - test_pipeline.fit(df) - # Transform the pipeline - transformed_df = test_pipeline.transform(df) - - print("Transformed dataframe:") - print(transformed_df.head()) - - 
print("Saving pipeline using joblib...") - joblib.dump(test_pipeline, "./output/test_sklearn_pipeline.joblib") - - print("Loading pipeline using joblib...") - loaded_pipeline = joblib.load("./output/test_sklearn_pipeline.joblib") - - print("Transforming dataframe using loaded pipeline...") - loaded_transformed_df = loaded_pipeline.transform(df) - print(loaded_transformed_df.head()) - - # Get keras model - tf_input_schema = [ - { - "name": "col1", - "dtype": "float32", - "shape": (None, 1), - }, - { - "name": "col2", - "dtype": "float32", - "shape": (None, 1), - }, - { - "name": "col3", - "dtype": "float32", - "shape": (None, 1), - }, - { - "name": "col4", - "dtype": "float32", - "shape": (None, 3), - }, - ] - print("Building keras model...") - keras_model = loaded_pipeline.build_keras_model(tf_input_schema=tf_input_schema) - print(keras_model.summary()) diff --git a/examples/sklearn/example_simple_keras_tuner_pipeline.py b/examples/sklearn/example_simple_keras_tuner_pipeline.py deleted file mode 100644 index efb5d1bd..00000000 --- a/examples/sklearn/example_simple_keras_tuner_pipeline.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import joblib -import keras -import keras_tuner as kt -import pandas as pd -import tensorflow as tf -from packaging.version import Version - -from kamae.sklearn.estimators import StandardScaleEstimator -from kamae.sklearn.pipeline import KamaeSklearnPipeline -from kamae.sklearn.transformers import ArrayConcatenateTransformer, LogTransformer - -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - -if __name__ == "__main__": - print( - """Starting test of Spark pipeline, - integration with Tensorflow and Keras Tuner""" - ) - - pd.options.display.max_columns = None - pd.options.display.max_rows = None - - # Create some dummy pandas data - df = pd.DataFrame( - { - "col1": [10, 4.8, 7.3], - "col2": [2.5, 5.3, 8.2], - "col3": [3.7, 6.4, 9.4], - "col4": ["a", "b", "c"], - }, - ) - - print("Original dataframe:") - print(df.head()) - - # Setup transformers, can use set methods or just pass in the args to the constructor. - log_transformer = LogTransformer( - input_col="col1", - output_col="log_col1", - alpha=1, - layer_name="log_col1_one_plus_x", - ) - log_transformer2 = LogTransformer( - input_col="col2", - output_col="log_col2", - alpha=1, - layer_name="log_col2_one_plus_x", - ) - vector_assembler = ArrayConcatenateTransformer( - input_cols=["log_col1", "log_col2", "col3"], - output_col="features", - layer_name="vec_assemble_log_col1_col2_col3", - ) - - standard_scalar_layer = StandardScaleEstimator( - input_col="features", - output_col="scaled_features", - layer_name="standard_scaler", - ) - - print("Creating pipeline and writing to disk") - test_pipeline = KamaeSklearnPipeline( - steps=[ - ("log_transformer_1", log_transformer), - ("log_transformer_2", log_transformer2), - ("vector_assembler", vector_assembler), - ("standard_scaler", standard_scalar_layer), - ] - ) - - joblib.dump(test_pipeline, "./output/test_pipeline.joblib") - - print("Loading pipeline from disk") - loaded_pipeline = joblib.load("./output/test_pipeline.joblib") - - print("Transforming 
data with loaded pipeline") - fit_pipeline = loaded_pipeline.fit(df) - print(fit_pipeline.transform(df).head()) - - print("Building keras tuner model builder function from fit pipeline") - # Create input schema for keras model - # The values here will be inserted into tf.keras.Input layers - # using kwargs ** syntax. - tf_input_schema = [ - { - "name": "col1", - "dtype": "float32", - "shape": (1,), - }, - { - "name": "col2", - "dtype": "float32", - "shape": (1,), - }, - { - "name": "col3", - "dtype": "float32", - "shape": (1,), - }, - ] - - # In order to use the keras tuner we need to define a dictionary of hyperparameters - # The format is as follows: - # { - # "layer_name": [ - # { - # "arg_name": , - # "method": , e.g. "choice" - # "kwargs": { - # - # } - # } - # ] - # } - - hyper_param_dict = { - "log_col1_one_plus_x": [ - { - "arg_name": "alpha", - "method": "choice", - "kwargs": { - "name": "log_one_plus_x_alpha", - "values": [1, 10, 20], - }, - } - ], - "log_col2_one_plus_x": [ - { - "arg_name": "alpha", - "method": "float", - "kwargs": { - "name": "log2_one_plus_x_alpha", - "min_value": 1.0, - "max_value": 20.0, - }, - } - ], - } - - build_prepro_model = fit_pipeline.get_keras_tuner_model_builder( - tf_input_schema=tf_input_schema, - hp_dict=hyper_param_dict, - ) - - # Next we setup the model builder function. Here we will use the function - # we just got for the preprocessing hyperparameters and then add a dense layer - # with a hyperparameter for the number of units. - - def build_model(hp): - prepro_model = build_prepro_model(hp) - prepro_output_layer = prepro_model.outputs[0] - # Add dense layer with hyperparameter on top of prepro model output. 
- dense_layer = tf.keras.layers.Dense( - units=hp.Int("units", min_value=32, max_value=512, step=32), - activation="relu", - name="dense_layer", - )(prepro_output_layer) - output_layer = tf.keras.layers.Dense( - units=1, - activation="relu", - name="output_layer", - )(dense_layer) - - # We need to be careful not to end up with a disconnected graph when combining - # the preprocessing model and the rest of the training. - - model = tf.keras.Model( - inputs=prepro_model.inputs, - outputs=output_layer, - ) - - model.compile( - optimizer=tf.keras.optimizers.Adam( - hp.Choice("learning_rate", values=[1e-2, 1e-3, 1e-4]) - ), - loss="mse", - metrics=["mse"], - ) - return model - - print("Creating keras tuner object") - tuner = kt.RandomSearch( - build_model, - objective="val_loss", - max_trials=5, - project_name="output/test_keras_tuner_simple", - ) - - # Create some fake data for training and validation. This will be used in the keras - # tuner to train and evaluate the model. - x_train = [ - tf.constant( - [ - [1.0], - [2.0], - [3.0], - [4.0], - [5.0], - [6.0], - ] - ), - tf.constant( - [ - [45.0], - [48.0], - [51.0], - [54.0], - [57.0], - [60.0], - ] - ), - tf.constant( - [ - [5.0], - [8.0], - [1.0], - [4.0], - [7.0], - [0.0], - ] - ), - ] - - y_train = tf.constant( - [ - [1.0], - [2.0], - [3.0], - [4.0], - [52.0], - [53.0], - ] - ) - - x_val = [ - tf.constant( - [ - [3.0], - [5.0], - [6.0], - ] - ), - tf.constant( - [ - [45.0], - [48.0], - [54.0], - ] - ), - tf.constant( - [ - [5.0], - [8.0], - [4.0], - ] - ), - ] - - y_val = tf.constant( - [ - [1.0], - [3.0], - [53.0], - ] - ) - - print("Running keras tuner search") - tuner.search(x_train, y_train, epochs=5, validation_data=(x_val, y_val)) - - print("Best model summary") - best_model = tuner.get_best_models()[0] - print(best_model.summary()) - - print("Best hyperparameters") - best_hp = tuner.get_best_hyperparameters()[0] - print(best_hp.values) - - print("Saving best model") - model_path = 
"output/test_keras_tuner_simple_best_model" - if is_keras_3: - model_path += ".keras" - best_model.save(model_path) - - print("Loading best model") - loaded_best_model = tf.keras.models.load_model(model_path) - - print("Predict with best model") - print(loaded_best_model.predict(x_val)) diff --git a/examples/spark/example_array_transform.py b/examples/spark/example_array_transform.py index d0e1965f..2f1a0e63 100644 --- a/examples/spark/example_array_transform.py +++ b/examples/spark/example_array_transform.py @@ -14,7 +14,6 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.estimators import StringIndexEstimator @@ -25,8 +24,6 @@ OrdinalArrayEncodeTransformer, ) -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -124,7 +121,7 @@ print("Transformed array fake data") loaded_fitted_pipeline.transform(array_fake_data_to_transform).show(20, False) - tf_input_schema = [ + input_schema = [ { "name": "col4", "dtype": "string", @@ -136,9 +133,7 @@ "shape": (None, None), }, ] - keras_model = loaded_fitted_pipeline.build_keras_model( - tf_input_schema=tf_input_schema - ) + keras_model = loaded_fitted_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) print("Start: Predicting with the model with reg_inputs") @@ -212,16 +207,14 @@ # Saving model in pb format print("Saving model in pb format") - model_path = "./output/test_saved_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_saved_model.keras" keras_model.save(model_path) print("Model saved in pb format") # Load model from SavedModel format print("Loading model from pb format") - loaded_model = tf.keras.models.load_model(model_path) + loaded_model = keras.models.load_model(model_path) print("Model loaded from pb format") # Predict with the loaded model diff --git 
a/examples/spark/example_cosine_sim_pipeline.py b/examples/spark/example_cosine_sim_pipeline.py index 3d895204..23c4e4c2 100644 --- a/examples/spark/example_cosine_sim_pipeline.py +++ b/examples/spark/example_cosine_sim_pipeline.py @@ -14,14 +14,11 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.pipeline import KamaeSparkPipeline, KamaeSparkPipelineModel from kamae.spark.transformers import CosineSimilarityTransformer -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -71,7 +68,7 @@ print("Building keras model from fit pipeline") # Create input schema for keras model. - tf_input_schema = [ + input_schema = [ { "name": "col1", "dtype": "float32", @@ -83,17 +80,13 @@ "shape": (None, 4), }, ] - keras_model = loaded_fitted_pipeline.build_keras_model( - tf_input_schema=tf_input_schema - ) + keras_model = loaded_fitted_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_saved_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_saved_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = [ tf.constant( [ diff --git a/examples/spark/example_date_diff_transform.py b/examples/spark/example_date_diff_transform.py index 2ac17bc1..791878e9 100644 --- a/examples/spark/example_date_diff_transform.py +++ b/examples/spark/example_date_diff_transform.py @@ -14,14 +14,11 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.pipeline import KamaeSparkPipeline from kamae.spark.transformers import DateDiffTransformer -is_keras_3 = Version(keras.__version__) >= 
Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -59,7 +56,7 @@ # Create input schema for keras model. # Or a list of dicts. - tf_input_schema = [ + input_schema = [ { "name": "col1", "dtype": tf.string, @@ -81,15 +78,13 @@ "shape": (None, 1), }, ] - keras_model = fit_pipeline.build_keras_model(tf_input_schema=tf_input_schema) + keras_model = fit_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_saved_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_saved_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = [ tf.constant( [ diff --git a/examples/spark/example_date_parse_pipeline.py b/examples/spark/example_date_parse_pipeline.py index 3cf7d4cb..fbe360e9 100644 --- a/examples/spark/example_date_parse_pipeline.py +++ b/examples/spark/example_date_parse_pipeline.py @@ -14,14 +14,11 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.pipeline import KamaeSparkPipeline from kamae.spark.transformers import DateParseTransformer -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -88,7 +85,7 @@ fit_pipeline.transform(fake_data).show(20, False) # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col5", "dtype": tf.string, @@ -100,15 +97,13 @@ "shape": (None, 3, 1), }, ] - keras_model = fit_pipeline.build_keras_model(tf_input_schema=tf_input_schema) + keras_model = fit_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_saved_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_saved_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = [ tf.constant( [ diff --git a/examples/spark/example_hash_indexer_keras_tuner_pipeline.py b/examples/spark/example_hash_indexer_keras_tuner_pipeline.py index 62e0bd17..c57e112a 100644 --- a/examples/spark/example_hash_indexer_keras_tuner_pipeline.py +++ b/examples/spark/example_hash_indexer_keras_tuner_pipeline.py @@ -15,14 +15,11 @@ import keras import keras_tuner as kt import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.pipeline import KamaeSparkPipeline, KamaeSparkPipelineModel from kamae.spark.transformers import HashIndexTransformer -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print( """ @@ -95,7 +92,7 @@ print("Building keras tuner model builder function from fit pipeline") # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col1", "dtype": "string", @@ -163,7 +160,7 @@ } build_prepro_model = loaded_fitted_pipeline.get_keras_tuner_model_builder( - tf_input_schema=tf_input_schema, + input_schema=input_schema, hp_dict=hyper_param_dict, ) @@ -342,13 +339,11 @@ def build_model(hp): print(best_hp.values) print("Saving best model") - model_path = "./output/test_keras_tuner_hash_best_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_keras_tuner_hash_best_model.keras" best_model.save(model_path) print("Loading best model") - loaded_best_model = tf.keras.models.load_model(model_path) + loaded_best_model = keras.models.load_model(model_path) print("Predict with best model") print(loaded_best_model.predict(x_val)) diff --git a/examples/spark/example_haversine_distance_pipeline.py b/examples/spark/example_haversine_distance_pipeline.py index a8343a48..ac560125 100644 --- a/examples/spark/example_haversine_distance_pipeline.py +++ b/examples/spark/example_haversine_distance_pipeline.py @@ -14,14 +14,11 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.pipeline import KamaeSparkPipeline from kamae.spark.transformers import HaversineDistanceTransformer -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -63,7 +60,7 @@ # Create input schema for keras model. # Or a list of dicts. 
- tf_input_schema = [ + input_schema = [ { "name": "lat1", "dtype": tf.float32, @@ -85,11 +82,9 @@ "shape": (None, 1), }, ] - keras_model = fit_pipeline.build_keras_model(tf_input_schema=tf_input_schema) + keras_model = fit_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_saved_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_saved_model.keras" keras_model.save(model_path) # print("Loading keras model from disk") diff --git a/examples/spark/example_if_statements_pipeline.py b/examples/spark/example_if_statements_pipeline.py index d7b2e534..e7bf64ad 100644 --- a/examples/spark/example_if_statements_pipeline.py +++ b/examples/spark/example_if_statements_pipeline.py @@ -14,7 +14,6 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.pipeline import KamaeSparkPipeline, KamaeSparkPipelineModel @@ -23,8 +22,6 @@ StringEqualsIfStatementTransformer, ) -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -83,7 +80,7 @@ print("Building keras model from fit pipeline") # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col1", "dtype": tf.int32, @@ -105,17 +102,13 @@ "shape": (None, 1), }, ] - keras_model = loaded_fitted_pipeline.build_keras_model( - tf_input_schema=tf_input_schema - ) + keras_model = loaded_fitted_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_saved_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_saved_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = [ tf.constant([[[1], [4], [7]]]), tf.constant([[[2], [5], [8]]]), diff --git a/examples/spark/example_imputation.py b/examples/spark/example_imputation.py index f7dc0f2a..23575919 100644 --- a/examples/spark/example_imputation.py +++ b/examples/spark/example_imputation.py @@ -14,14 +14,11 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.estimators import ImputeEstimator from kamae.spark.pipeline import KamaeSparkPipeline, KamaeSparkPipelineModel -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -88,7 +85,7 @@ print("Building keras model from fit pipeline") # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col1", "dtype": "int32", @@ -110,17 +107,13 @@ "shape": (None, 1), }, ] - keras_model = loaded_fitted_pipeline.build_keras_model( - tf_input_schema=tf_input_schema - ) + keras_model = loaded_fitted_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_saved_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_saved_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = [ tf.constant([[[1], [4], [7], [100]]]), tf.constant([[[2], [5], [8], [100]]]), diff --git a/examples/spark/example_listwise_stats.py b/examples/spark/example_listwise_stats.py index ad3e4645..4611c097 100644 --- a/examples/spark/example_listwise_stats.py +++ b/examples/spark/example_listwise_stats.py @@ -14,7 +14,6 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.pipeline import KamaeSparkPipeline @@ -24,8 +23,6 @@ ListMinTransformer, ) -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -132,7 +129,7 @@ fit_pipeline.transform(fake_data).show(20, False) # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col2", "dtype": "float32", @@ -149,15 +146,13 @@ "shape": (None, 1), }, ] - keras_model = fit_pipeline.build_keras_model(tf_input_schema=tf_input_schema) + keras_model = fit_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_keras_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_keras_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = { "col2": tf.constant( [ diff --git a/examples/spark/example_logical_operations_pipeline.py b/examples/spark/example_logical_operations_pipeline.py index 7c14a233..d09ce08d 100755 --- a/examples/spark/example_logical_operations_pipeline.py +++ b/examples/spark/example_logical_operations_pipeline.py @@ -14,7 +14,6 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.pipeline import KamaeSparkPipeline, KamaeSparkPipelineModel @@ -24,8 +23,6 @@ LogicalOrTransformer, ) -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -100,7 +97,7 @@ print("Building keras model from fit pipeline") # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col1", "dtype": tf.bool, @@ -112,17 +109,13 @@ "shape": (1,), }, ] - keras_model = loaded_fitted_pipeline.build_keras_model( - tf_input_schema=tf_input_schema - ) + keras_model = loaded_fitted_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_saved_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_saved_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = [ tf.constant([[True], [True], [False], [False]]), tf.constant([[True], [False], [True], [False]]), diff --git a/examples/spark/example_oh_encoder_pipeline.py b/examples/spark/example_oh_encoder_pipeline.py index 69dd3b1c..1ea6b9db 100644 --- a/examples/spark/example_oh_encoder_pipeline.py +++ b/examples/spark/example_oh_encoder_pipeline.py @@ -14,7 +14,6 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.estimators import OneHotEncodeEstimator, StandardScaleEstimator @@ -25,8 +24,6 @@ LogTransformer, ) -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -114,7 +111,7 @@ print("Building keras model from fit pipeline") # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col1", "dtype": tf.int32, @@ -136,17 +133,13 @@ "shape": (None, 1), }, ] - keras_model = loaded_fitted_pipeline.build_keras_model( - tf_input_schema=tf_input_schema - ) + keras_model = loaded_fitted_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_keras_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_keras_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = [ tf.constant([[[1], [4], [7], [7]]]), tf.constant([[[2], [5], [8], [8]]]), diff --git a/examples/spark/example_pipeline.py b/examples/spark/example_pipeline.py index 9f9b43d7..d594c6e8 100755 --- a/examples/spark/example_pipeline.py +++ b/examples/spark/example_pipeline.py @@ -13,8 +13,6 @@ # limitations under the License. import keras -import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.estimators import StandardScaleEstimator, StringIndexEstimator @@ -27,8 +25,6 @@ LogTransformer, ) -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -148,7 +144,7 @@ print("Building keras model from fit pipeline") # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col1", "dtype": "int32", @@ -170,17 +166,16 @@ "shape": (None, 1), }, ] - keras_model = loaded_fitted_pipeline.build_keras_model( - tf_input_schema=tf_input_schema - ) + keras_model = loaded_fitted_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_keras_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_keras_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) + + import tensorflow as tf + inputs = [ tf.constant([[[1], [4], [7]]]), tf.constant([[[2], [5], [8]]]), diff --git a/examples/spark/example_pipeline_lambda_fn.py b/examples/spark/example_pipeline_lambda_fn.py index f2514cb0..06fff4c2 100755 --- a/examples/spark/example_pipeline_lambda_fn.py +++ b/examples/spark/example_pipeline_lambda_fn.py @@ -15,15 +15,12 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from pyspark.sql.types import ArrayType, FloatType from kamae.spark.pipeline import KamaeSparkPipeline, KamaeSparkPipelineModel from kamae.spark.transformers import LambdaFunctionTransformer -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -115,7 +112,7 @@ def my_multi_input_multi_output_fn(x): ) # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col2", "dtype": "int32", @@ -127,17 +124,13 @@ def my_multi_input_multi_output_fn(x): "shape": (None, 1), }, ] - keras_model = loaded_fitted_pipeline.build_keras_model( - tf_input_schema=tf_input_schema - ) + keras_model = loaded_fitted_pipeline.build_keras_model(input_schema=input_schema) # print(keras_model.summary()) - model_path = "./output/test_keras_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_keras_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = [ tf.constant([[[2], [5], [8]]]), tf.constant([[[3], [6], [9]]]), diff --git a/examples/spark/example_pipeline_strings.py b/examples/spark/example_pipeline_strings.py index bcd57b2b..086ee02e 100755 --- a/examples/spark/example_pipeline_strings.py +++ b/examples/spark/example_pipeline_strings.py @@ -14,7 +14,6 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.pipeline import KamaeSparkPipeline, KamaeSparkPipelineModel @@ -25,8 +24,6 @@ StringListToStringTransformer, ) -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -94,24 +91,20 @@ print("Building keras model from fit pipeline") # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col4", "dtype": tf.string, "shape": (None, 1), }, ] - keras_model = loaded_fitted_pipeline.build_keras_model( - tf_input_schema=tf_input_schema - ) + keras_model = loaded_fitted_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_keras_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_keras_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = [ tf.constant([[["a"], ["b"], ["c"]]]), ] diff --git a/examples/spark/example_pipeline_with_nulls.py b/examples/spark/example_pipeline_with_nulls.py index a4590981..c389d188 100644 --- a/examples/spark/example_pipeline_with_nulls.py +++ b/examples/spark/example_pipeline_with_nulls.py @@ -14,7 +14,6 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.estimators import StandardScaleEstimator, StringIndexEstimator @@ -26,8 +25,6 @@ LogTransformer, ) -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -136,7 +133,7 @@ print("Building keras model from fit pipeline") # Create input schema for keras model. 
# Or a list of dicts - tf_input_schema = [ + input_schema = [ { "name": "col1", "dtype": tf.int32, @@ -158,17 +155,13 @@ "shape": (None, 1), }, ] - keras_model = loaded_fitted_pipeline.build_keras_model( - tf_input_schema=tf_input_schema - ) + keras_model = loaded_fitted_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_keras_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_keras_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = [ tf.constant([[[1], [4], [7]]]), tf.constant([[[2], [5], [8]]]), diff --git a/examples/spark/example_round_mod_pipeline.py b/examples/spark/example_round_mod_pipeline.py index f029ce6e..2c9a2c28 100644 --- a/examples/spark/example_round_mod_pipeline.py +++ b/examples/spark/example_round_mod_pipeline.py @@ -14,7 +14,6 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.pipeline import KamaeSparkPipeline, KamaeSparkPipelineModel @@ -24,8 +23,6 @@ RoundTransformer, ) -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -114,7 +111,7 @@ print("Building keras model from fit pipeline") # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col1", "dtype": tf.float32, @@ -131,17 +128,13 @@ "shape": (None, 1), }, ] - keras_model = loaded_fitted_pipeline.build_keras_model( - tf_input_schema=tf_input_schema - ) + keras_model = loaded_fitted_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_keras_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_keras_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = [ tf.constant([[[1.4567], [4.2343], [7.1234435]]]), tf.constant([[[2.23424], [5.46456], [8.45657567]]]), diff --git a/examples/spark/example_simple_jax_pipeline.py b/examples/spark/example_simple_jax_pipeline.py index 742512ff..253f0519 100644 --- a/examples/spark/example_simple_jax_pipeline.py +++ b/examples/spark/example_simple_jax_pipeline.py @@ -173,7 +173,7 @@ fit_pipeline = test_pipeline.fit(fit_data) fit_pipeline.transform(fit_data).show() - tf_input_schema = [ + input_schema = [ { "name": col, "dtype": tf.float32, @@ -182,7 +182,7 @@ for col in x_schema ] - tf_preproc_model = fit_pipeline.build_keras_model(tf_input_schema=tf_input_schema) + tf_preproc_model = fit_pipeline.build_keras_model(input_schema=input_schema) tf_preproc_model.summary() print("\n* Build and train a JAX neural network\n") diff --git a/examples/spark/example_simple_keras_tuner_pipeline.py b/examples/spark/example_simple_keras_tuner_pipeline.py index e74d1845..ee82a75e 100644 --- a/examples/spark/example_simple_keras_tuner_pipeline.py +++ b/examples/spark/example_simple_keras_tuner_pipeline.py @@ -15,15 +15,12 @@ import keras import keras_tuner as kt import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.estimators import StandardScaleEstimator from kamae.spark.pipeline 
import KamaeSparkPipeline, KamaeSparkPipelineModel from kamae.spark.transformers import ArrayConcatenateTransformer, LogTransformer -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print( """Starting test of Spark pipeline, @@ -100,7 +97,7 @@ print("Building keras tuner model builder function from fit pipeline") # Create input schema for keras model. - tf_input_schema = [ + input_schema = [ { "name": "col1", "dtype": tf.int32, @@ -157,7 +154,7 @@ } build_prepro_model = loaded_fitted_pipeline.get_keras_tuner_model_builder( - tf_input_schema=tf_input_schema, + input_schema=input_schema, hp_dict=hyper_param_dict, ) @@ -297,13 +294,11 @@ def build_model(hp): print(best_hp.values) print("Saving best model") - model_path = "output/test_keras_tuner_simple_best_model" - if is_keras_3: - model_path += ".keras" + model_path = "output/test_keras_tuner_simple_best_model.keras" best_model.save(model_path) print("Loading best model") - loaded_best_model = tf.keras.models.load_model(model_path) + loaded_best_model = keras.models.load_model(model_path) print("Predict with best model") print(loaded_best_model.predict(x_val)) diff --git a/examples/spark/example_string_list_to_list_pipeline.py b/examples/spark/example_string_list_to_list_pipeline.py index 0daa2182..9fad2544 100644 --- a/examples/spark/example_string_list_to_list_pipeline.py +++ b/examples/spark/example_string_list_to_list_pipeline.py @@ -14,7 +14,6 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.pipeline import KamaeSparkPipeline, KamaeSparkPipelineModel @@ -23,8 +22,6 @@ StringToStringListTransformer, ) -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -93,7 +90,7 @@ print("Building keras model from fit pipeline") # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col1", "dtype": tf.string, @@ -110,17 +107,13 @@ "shape": (1,), }, ] - keras_model = loaded_fitted_pipeline.build_keras_model( - tf_input_schema=tf_input_schema - ) + keras_model = loaded_fitted_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_keras_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_keras_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = [ tf.constant([[["a", "b", "c"], ["d", "e", "f"]]]), tf.constant([[["g", "h", "i"], ["j", "k", "l"]]]), diff --git a/examples/spark/example_string_pipeline.py b/examples/spark/example_string_pipeline.py index 28fa0f61..f26b1b62 100644 --- a/examples/spark/example_string_pipeline.py +++ b/examples/spark/example_string_pipeline.py @@ -14,7 +14,6 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.pipeline import KamaeSparkPipeline, KamaeSparkPipelineModel @@ -23,8 +22,6 @@ SubStringDelimAtIndexTransformer, ) -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -82,7 +79,7 @@ loaded_fit_pipeline = KamaeSparkPipelineModel.load("./output/test_pipeline/") # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col4", "dtype": tf.string, @@ -94,15 +91,13 @@ "shape": (None, None, 1), }, ] - keras_model = loaded_fit_pipeline.build_keras_model(tf_input_schema=tf_input_schema) + keras_model = loaded_fit_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_keras_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_keras_model.keras" keras_model.save(model_path) print("Loading keras model from disk") - loaded_keras_model = tf.keras.models.load_model(model_path) + loaded_keras_model = keras.models.load_model(model_path) inputs = [ tf.constant( [ diff --git a/examples/spark/example_string_replace_pipeline.py b/examples/spark/example_string_replace_pipeline.py index 1d7bc87f..c9e1b553 100644 --- a/examples/spark/example_string_replace_pipeline.py +++ b/examples/spark/example_string_replace_pipeline.py @@ -14,14 +14,11 @@ import keras import tensorflow as tf -from packaging.version import Version from pyspark.sql import SparkSession from kamae.spark.pipeline import KamaeSparkPipeline from kamae.spark.transformers import StringReplaceTransformer -is_keras_3 = Version(keras.__version__) >= Version("3.0.0") - if __name__ == "__main__": print("Starting test of Spark pipeline and integration with Tensorflow") @@ -64,7 +61,7 @@ fit_pipeline.transform(fake_data).show(20, False) # Create input schema for keras model. 
- tf_input_schema = [ + input_schema = [ { "name": "col4", "dtype": tf.string, @@ -81,15 +78,13 @@ "shape": (None, 1), }, ] - keras_model = fit_pipeline.build_keras_model(tf_input_schema=tf_input_schema) + keras_model = fit_pipeline.build_keras_model(input_schema=input_schema) print(keras_model.summary()) - model_path = "./output/test_keras_model" - if is_keras_3: - model_path += ".keras" + model_path = "./output/test_keras_model.keras" keras_model.save(model_path) # print("Loading keras model from disk") - # loaded_keras_model = tf.keras.models.load_model("./output/test_keras_model/") + # loaded_keras_model = keras.models.load_model("./output/test_keras_model.keras") inputs = [ tf.constant( [[["EXPEDIA"], ["EXPEDIA.._UK"], ["EXPEDIA_.UK_4EVA.UK_4EV_WHEHEIW"]]] diff --git a/pyproject.toml b/pyproject.toml index a4adc07e..ed5d9050 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,20 +8,25 @@ authors = [ readme = "README.md" license = "Apache-2.0" license-files = ["LICENSE.txt"] -requires-python = ">=3.8.1,<3.13" +requires-python = ">=3.10,<3.13" dependencies = [ "pyspark>=3.4.0,<4.0.0", "pandas>=1.3.4,<3.0.0", "networkx>=2.6.3,<3.0.0", "pyfarmhash>=0.3.2,<0.4.0", - "keras-tuner>=1.0.4,<2.0.0", - "scikit-learn>=1.0.0,<2.0.0", - "joblib>=1.0.0,<2.0.0", + "keras>=3.0.0,<4.0.0", + "keras-tuner>=1.4.0,<2.0.0", "numpy>=1.22.0,<2.0.0", - "tensorflow>=2.9.1,<2.19.0", + "tensorflow>=2.16.0,<3.0.0", "dill>=0.3.0,<1.0.0", ] +[project.optional-dependencies] +# JAX backend (for future multi-backend support) +jax = ["jax>=0.4.0", "jaxlib>=0.4.0"] +# PyTorch backend (for future multi-backend support) +torch = ["torch>=2.0.0"] + [dependency-groups] dev = [ "pre-commit>=3.3.3,<4", diff --git a/src/kamae/__init__.py b/src/kamae/__init__.py index b0206498..8141aff4 100644 --- a/src/kamae/__init__.py +++ b/src/kamae/__init__.py @@ -21,3 +21,10 @@ __version__ = "2.40.0" __name__ = "kamae" + +from .discovery import ( # noqa: F401 + get_compatible_layers, + 
get_compatible_transformers, + get_jit_compatible_layers, + get_jit_compatible_transformers, +) diff --git a/src/kamae/discovery.py b/src/kamae/discovery.py new file mode 100644 index 00000000..c2b2b7fb --- /dev/null +++ b/src/kamae/discovery.py @@ -0,0 +1,169 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Discovery utilities for finding backend-compatible layers and transformers. +""" + +import inspect +from typing import Any, Callable, Dict, Union + +import kamae.keras.core.layers as core_layers +import kamae.keras.tensorflow.layers as tf_layers +import kamae.spark.estimators as estimators +import kamae.spark.transformers as transformers +from kamae.keras.core.backend import ALL_BACKENDS, current_backend +from kamae.keras.core.base import BaseLayer +from kamae.spark.estimators.base import BaseEstimator +from kamae.spark.transformers.base import BaseTransformer + + +def _inspect_modules( + modules: list[Any], attribute: str, condition: Callable[[Any], bool] +) -> Dict[str, type]: + """ + Helper to inspect multiple modules for classes matching a condition. 
+ + :param modules: List of modules to inspect + :param attribute: Attribute name to check on each class + :param condition: Function that returns True if the attribute value matches + :returns: Dict mapping class names to class objects + """ + compatible = {} + for module in modules: + for name, obj in inspect.getmembers(module, inspect.isclass): + if hasattr(obj, attribute) and condition(getattr(obj, attribute)): + compatible[name] = obj + return compatible + + +def get_compatible_layers(backend: str = None) -> Dict[str, type[BaseLayer]]: + """ + Returns a dict of Keras layer classes compatible with the specified backend. + + :param backend: Backend name ('tensorflow', 'jax', or 'torch'). If None, uses + the current backend. + :returns: Dict mapping layer names to layer class objects that work on the + specified backend. + :raises ValueError: If backend name is invalid. + + Example: + >>> from kamae.discovery import get_compatible_layers + >>> # Get layers that work on JAX + >>> jax_layers = get_compatible_layers('jax') + >>> # Instantiate a layer by name + >>> layer = jax_layers['MultiplyLayer'](multiplier=2.0) + >>> # List available layer names + >>> print(list(jax_layers.keys())) + """ + if backend is None: + backend = current_backend() + + if backend not in ALL_BACKENDS: + raise ValueError( + f"Invalid backend '{backend}'. Must be one of {sorted(ALL_BACKENDS)}" + ) + + return _inspect_modules( + modules=[core_layers, tf_layers], + attribute="supported_backends", + condition=lambda backends: backend in backends, + ) + + +def get_compatible_transformers( + backend: str = None, +) -> Dict[str, Union[type[BaseTransformer], type[BaseEstimator]]]: + """ + Returns a dict of Spark transformer/estimator classes compatible with the + specified backend. + + :param backend: Backend name ('tensorflow', 'jax', or 'torch'). If None, uses + the current backend. + :returns: Dict mapping transformer/estimator names to class objects that work + on the specified backend. 
+ :raises ValueError: If backend name is invalid. + + Example: + >>> from kamae.discovery import get_compatible_transformers + >>> # Get transformers that work on PyTorch + >>> torch_transformers = get_compatible_transformers('torch') + >>> # Instantiate a transformer by name + >>> transformer = torch_transformers['LogTransformer'](inputCol="x", outputCol="y") + >>> # List available transformer names + >>> print(list(torch_transformers.keys())) + """ + if backend is None: + backend = current_backend() + + if backend not in ALL_BACKENDS: + raise ValueError( + f"Invalid backend '{backend}'. Must be one of {sorted(ALL_BACKENDS)}" + ) + + return _inspect_modules( + modules=[transformers, estimators], + attribute="supported_backends", + condition=lambda backends: backend in backends, + ) + + +def get_jit_compatible_layers() -> Dict[str, type[BaseLayer]]: + """ + Returns a dict of Keras layer classes that are JIT-compatible. + + JIT-compatible layers can be compiled with @tf.function or jax.jit for improved + performance. + + :returns: Dict mapping layer names to JIT-compatible layer class objects. + + Example: + >>> from kamae.discovery import get_jit_compatible_layers + >>> jit_layers = get_jit_compatible_layers() + >>> # Instantiate a JIT-compatible layer by name + >>> layer = jit_layers['MultiplyLayer'](multiplier=2.0) + >>> # See how many JIT-compatible layers exist + >>> print(f"Found {len(jit_layers)} JIT-compatible layers") + """ + return _inspect_modules( + modules=[core_layers, tf_layers], + attribute="jit_compatible", + condition=lambda jit: jit is True, + ) + + +def get_jit_compatible_transformers() -> ( + Dict[str, Union[type[BaseTransformer], type[BaseEstimator]]] +): + """ + Returns a dict of Spark transformer/estimator classes that are JIT-compatible. + + JIT-compatible transformers generate Keras layers that can be compiled with + @tf.function or jax.jit for improved performance. 
+ + :returns: Dict mapping transformer/estimator names to JIT-compatible class objects. + + Example: + >>> from kamae.discovery import get_jit_compatible_transformers + >>> jit_transformers = get_jit_compatible_transformers() + >>> # Instantiate a JIT-compatible transformer by name + >>> transformer = jit_transformers['LogTransformer'](inputCol="x", outputCol="y") + >>> # See all JIT-compatible transformer names + >>> print(list(jit_transformers.keys())) + """ + return _inspect_modules( + modules=[transformers, estimators], + attribute="jit_compatible", + condition=lambda jit: jit is True, + ) diff --git a/src/kamae/graph/pipeline_graph.py b/src/kamae/graph/pipeline_graph.py index bee0e877..ad784795 100644 --- a/src/kamae/graph/pipeline_graph.py +++ b/src/kamae/graph/pipeline_graph.py @@ -17,10 +17,7 @@ import keras import keras_tuner import networkx as nx -import tensorflow as tf -from packaging.version import Version - -keras_version = Version(keras.__version__) +from keras import KerasTensor class PipelineGraph: @@ -33,7 +30,7 @@ class PipelineGraph: The graph is then topologically sorted to determine the order in which the layers should be constructed. Iterating through this order, the layers are constructed by - calling the get_tf_layer method of each stage. The inputs to the layer are + calling the get_keras_layer method of each stage. The inputs to the layer are determined by the outputs of the previous layers. 
""" @@ -56,7 +53,7 @@ def __init__(self, stage_dict: Dict[str, Any]) -> None: self.inputs = {} def update_layer_store_with_key( - self, layer_key: str, layer_output: tf.Tensor + self, layer_key: str, layer_output: KerasTensor ) -> None: """ Updates the layer store at a specific key with the layer output and whether @@ -72,7 +69,7 @@ def update_layer_store_with_key( else: self.layer_store[layer_key] = {"output": layer_output, "reused": False} - def update_layer_store(self, layer_dict: Dict[str, tf.Tensor]) -> None: + def update_layer_store(self, layer_dict: Dict[str, KerasTensor]) -> None: """ Given a dictionary of layer output names and tensor outputs, update the layer store. @@ -83,7 +80,7 @@ def update_layer_store(self, layer_dict: Dict[str, tf.Tensor]) -> None: for name, output in layer_dict.items(): self.update_layer_store_with_key(layer_key=name, layer_output=output) - def get_layer_output_from_layer_store(self, layer_output_name: str) -> tf.Tensor: + def get_layer_output_from_layer_store(self, layer_output_name: str) -> KerasTensor: """ Given a layer name and index, get the output from the layer store. @@ -121,7 +118,7 @@ def add_stage_edges(self, graph: nx.DiGraph) -> nx.DiGraph: def get_model_outputs( self, output_names: Optional[List[str]] = None - ) -> Dict[str, tf.Tensor]: + ) -> Dict[str, KerasTensor]: """ Gets the outputs of the model. If output_names is provided, we use this to find the outputs for the model. Otherwise, the outputs are those that are not reused @@ -143,9 +140,7 @@ def get_model_outputs( k: v["output"] for k, v in self.layer_store.items() if k in output_names } - def build_keras_inputs( - self, tf_input_schema: Union[List[tf.TypeSpec], List[Dict[str, Any]]] - ) -> None: + def build_keras_inputs(self, input_schema: List[Dict[str, Any]]) -> None: """ Builds a Keras input layer for the given node. @@ -157,32 +152,17 @@ def build_keras_inputs( keras input layer. We then store this layer as an input and update the layer store. 
- :param tf_input_schema: List of tf.TypeSpec objects containing the input schema - for the model or a list of dict config to be passed to the Input constructor. + :param input_schema: List of dict config to be passed to the Input constructor. :returns: None - layer store is updated and input layer is stored in the inputs dict. """ - if isinstance(tf_input_schema, list) and all( - isinstance(x, tf.TypeSpec) for x in tf_input_schema - ): - if keras_version >= Version("3.0.0"): - raise ValueError( - "tf.TypeSpec is not supported in Keras 3, please use a dict config" - ) - input_config = [ - { - "name": spec.name, - "type_spec": spec, - } - for spec in tf_input_schema - ] - elif isinstance(tf_input_schema, list) and all( - isinstance(x, dict) for x in tf_input_schema + if not isinstance(input_schema, list) or not all( + isinstance(x, dict) for x in input_schema ): - input_config = tf_input_schema - else: - raise ValueError("tf_input_schema must be a list of tf.TypeSpec or dict!") + raise ValueError("input_schema must be a list of dict!") + + input_config = input_schema for conf in input_config: name = conf.get("name", None) @@ -190,13 +170,13 @@ def build_keras_inputs( raise ValueError( "Input schema must have names for all inputs, but found None" ) - input_layer = tf.keras.layers.Input(**conf) + input_layer = keras.layers.Input(**conf) self.inputs[name] = input_layer self.update_layer_store_with_key(layer_key=name, layer_output=input_layer) def sort_inputs( - self, layer_name: str, input_dict: Dict[str, tf.Tensor] - ) -> List[tf.Tensor]: + self, layer_name: str, input_dict: Dict[str, KerasTensor] + ) -> List[KerasTensor]: """ Sorts the inputs for a given layer based on the order of the inputs in the stage dict. 
This is needed because layers with multiple inputs are not @@ -212,7 +192,7 @@ def sort_inputs( def build_transform_layer_inputs( self, node: str, in_edges: List[Tuple[str, str]] - ) -> List[tf.Tensor]: + ) -> List[KerasTensor]: """ Constructs all the layers that are connected to the current node. These are either input layers or the outputs of previous layers. @@ -271,9 +251,9 @@ def build_transform_layer_inputs( @staticmethod def override_hyperparameters( - layer: Union[tf.keras.layers.Layer, List[tf.keras.layers.Layer]], + layer: Union[keras.layers.Layer, List[keras.layers.Layer]], hp_override: Dict[str, Any] = None, - ) -> Union[tf.keras.layers.Layer, List[tf.keras.layers.Layer]]: + ) -> Union[keras.layers.Layer, List[keras.layers.Layer]]: """ Overrides layer arguments with hyperparameters provided in the hyperparameter override dictionary. @@ -284,8 +264,8 @@ def override_hyperparameters( """ def update_layer( - layer: tf.keras.layers.Layer, hp_override: Dict[str, Any] - ) -> tf.keras.layers.Layer: + layer: keras.layers.Layer, hp_override: Dict[str, Any] + ) -> keras.layers.Layer: config = layer.get_config() config.update(hp_override) updated_layer = type(layer).from_config(config) @@ -393,17 +373,17 @@ def get_keras_hyperparam_from_config( def get_keras_tuner_model_builder( self, - tf_input_schema: Union[List[tf.TypeSpec], List[Dict[str, Any]]], + input_schema: List[Dict[str, Any]], hp_dict: Dict[str, List[Dict[str, Any]]], output_names: Optional[List[str]] = None, - ) -> Callable[[keras_tuner.HyperParameters], tf.keras.Model]: + ) -> Callable[[keras_tuner.HyperParameters], keras.Model]: """ Returns a Keras tuner model builder function for the current graph. This allows the user to tune the hyperparameters of the preprocessing model. Useful for scenarios where the best preprocessing variables are not known a priori. For example, the num_bins to use for a HashIndexLayer. 
- :param tf_input_schema: List of tf.TypeSpec objects containing the input schema + :param input_schema: List of dict config containing the input schema for the model. Specifically the name, shape and dtype of each input. These will be passed as is to the Keras Input layer. :param hp_dict: Dictionary of possible hyperparameters for each layer. @@ -427,12 +407,12 @@ def get_keras_tuner_model_builder( transform_order = self.transform_order - def keras_model_builder(hp: keras_tuner.HyperParameters) -> tf.keras.Model: + def keras_model_builder(hp: keras_tuner.HyperParameters) -> keras.Model: # We need to clear the layer store and inputs each time we build a model. self.layer_store = {} self.inputs = {} # Build the input layers - self.build_keras_inputs(tf_input_schema=tf_input_schema) + self.build_keras_inputs(input_schema=input_schema) for node in transform_order: in_edges = list(self.graph.in_edges(node)) @@ -449,7 +429,7 @@ def keras_model_builder(hp: keras_tuner.HyperParameters) -> tf.keras.Model: ) sorted_inputs = [self.inputs[k] for k in sorted(self.inputs)] - return tf.keras.Model( + return keras.Model( inputs=sorted_inputs, outputs=self.get_model_outputs(output_names=output_names), ) @@ -458,21 +438,21 @@ def keras_model_builder(hp: keras_tuner.HyperParameters) -> tf.keras.Model: def build_keras_model( self, - tf_input_schema: Union[List[tf.TypeSpec], List[Dict[str, Any]]], + input_schema: List[Dict[str, Any]], output_names: Optional[List[str]] = None, - ) -> tf.keras.Model: + ) -> keras.Model: """ Builds a Keras model from the graph. - :param tf_input_schema: List of tf.TypeSpec objects containing the input schema - for the model. Each TypeSpec object must define a unique `name` attribute. + :param input_schema: List of dict config containing the input schema + for the model. Each dict must have a 'name' key. These will be passed as is to the Keras Input layer. :param output_names: Optional list of output names for the Keras model. 
If provided, only the outputs specified are used as model outputs. :returns: Keras model to be applied to a tensors dictionary: {name: Tensor}. """ # Build the input layers - self.build_keras_inputs(tf_input_schema=tf_input_schema) + self.build_keras_inputs(input_schema=input_schema) for node in self.transform_order: in_edges = list(self.graph.in_edges(node)) @@ -482,7 +462,7 @@ def build_keras_model( # with all inputs/outputs specified. # We can now build the model by specifying the inputs and outputs. sorted_inputs = {k: self.inputs[k] for k in sorted(self.inputs)} - return tf.keras.Model( + return keras.Model( inputs=sorted_inputs, outputs=self.get_model_outputs(output_names=output_names), ) diff --git a/src/kamae/sklearn/pipeline/__init__.py b/src/kamae/keras/__init__.py similarity index 72% rename from src/kamae/sklearn/pipeline/__init__.py rename to src/kamae/keras/__init__.py index ead1d06b..f5864d2d 100644 --- a/src/kamae/sklearn/pipeline/__init__.py +++ b/src/kamae/keras/__init__.py @@ -12,4 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .pipeline import KamaeSklearnPipeline # noqa: F401 +""" +Kamae Keras 3 module with multi-backend support. + +This package provides: +- keras.core: Backend-agnostic layers (numeric operations only) +- keras.tensorflow: TensorFlow-specific layers (strings, datetime, TF-only ops) +""" diff --git a/src/kamae/sklearn/estimators/__init__.py b/src/kamae/keras/core/__init__.py similarity index 72% rename from src/kamae/sklearn/estimators/__init__.py rename to src/kamae/keras/core/__init__.py index 5c2460ed..789f1a17 100644 --- a/src/kamae/sklearn/estimators/__init__.py +++ b/src/kamae/keras/core/__init__.py @@ -12,4 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .standard_scale import StandardScaleEstimator # noqa: F401 +""" +Backend-agnostic Keras layers for numeric operations. 
+ +These layers work with TensorFlow, JAX, and PyTorch backends via keras.ops. +They do NOT handle string or datetime operations (use keras.tensorflow for those). +""" diff --git a/src/kamae/keras/core/backend.py b/src/kamae/keras/core/backend.py new file mode 100644 index 00000000..76f08c09 --- /dev/null +++ b/src/kamae/keras/core/backend.py @@ -0,0 +1,68 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Backend detection and enforcement utilities for Keras 3 multi-backend support. +""" + +from typing import FrozenSet + +import keras + +ALL_BACKENDS: FrozenSet[str] = frozenset({"tensorflow", "jax", "torch"}) +TENSORFLOW_ONLY: FrozenSet[str] = frozenset({"tensorflow"}) + + +def current_backend() -> str: + """ + Returns the current Keras backend. + + :returns: Backend name: 'tensorflow', 'jax', or 'torch' + """ + return keras.backend.backend() + + +def require_tensorflow() -> None: + """ + Raises RuntimeError if not running on TensorFlow backend. + + This should be called in the __init__ of TensorFlow-only layers + to fail fast with a clear error message. + + :raises RuntimeError: If current backend is not TensorFlow + """ + backend = current_backend() + if backend != "tensorflow": + raise RuntimeError( + f"This layer requires TensorFlow backend. " + f"Current backend: {backend}. " + f"Set KERAS_BACKEND=tensorflow before importing keras." 
+ ) + + +def validate_backend(class_name: str, supported_backends: FrozenSet[str]) -> None: + """ + Validates that the current backend is supported by the layer/operation. + + :param class_name: Name of the class being validated + :param supported_backends: Frozenset of supported backend names + :raises RuntimeError: If current backend is not in supported_backends + """ + backend = current_backend() + if backend not in supported_backends: + raise RuntimeError( + f"{class_name} requires one of {sorted(supported_backends)} backends. " + f"Current backend: '{backend}'. " + f"Set KERAS_BACKEND=tensorflow before importing keras." + ) diff --git a/src/kamae/tensorflow/layers/base.py b/src/kamae/keras/core/base.py similarity index 52% rename from src/kamae/tensorflow/layers/base.py rename to src/kamae/keras/core/base.py index 507ca332..e7c6be61 100644 --- a/src/kamae/tensorflow/layers/base.py +++ b/src/kamae/keras/core/base.py @@ -12,24 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Multi-backend base layer with string support on TensorFlow backend. + +This base layer provides casting and dtype validation for layers that work across +TensorFlow, JAX, and PyTorch backends. + +String operations (input_dtype="string" or output_dtype="string") are supported +only when running on TensorFlow backend. Multi-backend numeric operations work +on all backends. 
+""" + from abc import ABC, abstractmethod from functools import reduce from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +import keras import tensorflow as tf +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.backend import ( + current_backend, + require_tensorflow, + validate_backend, +) +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) -class BaseLayer(tf.keras.layers.Layer, ABC): +@keras.saving.register_keras_serializable(package=kamae.__name__) +class BaseLayer(keras.layers.Layer, ABC): """ - Abstract base layer that performs casting of inputs and outputs to specified - data types. All layers should inherit from this class. + Abstract base layer for multi-backend layers with TensorFlow string support. + + Provides: + - Multi-backend numeric dtype casting (works on TensorFlow, JAX, PyTorch) + - String dtype casting (TensorFlow backend only) + - Dtype compatibility validation + - Numeric constant type coercion + - Boolean string parsing (TensorFlow backend only) + + String operations automatically work when running on TensorFlow backend. + Attempting to use string dtypes on JAX or PyTorch backends raises an error. """ + supported_backends: frozenset + jit_compatible: bool + def __init__( self, name: Optional[str] = None, @@ -38,7 +67,7 @@ def __init__( **kwargs: Any, ) -> None: """ - Initialises the BaseLayer. + Initializes the BaseLayer. :param name: Name of the layer, defaults to `None`. :param input_dtype: Input data type of the layer. If specified, inputs will be @@ -46,12 +75,11 @@ def __init__( :param output_dtype: Output data type of the layer. Defaults to `None`. If specified, the output will be cast to this data type before being returned. 
""" + validate_backend(self.__class__.__name__, self.supported_backends) super().__init__(name=name, **kwargs) - # We handle casting of inputs and outputs in the call method - # Allowing keras to also autocast causes issues in some layers that require - # 64 bit precision. Such as timestamp layers after the year 2038. + # Disable Keras automatic casting to prevent float32 coercion + # This is critical for layers that require 64-bit precision (e.g., timestamps) self._autocast = False - # Needed to ensure keras 3 does not autocast inputs to float32 self._convert_input_args = False self._input_dtype = input_dtype self._output_dtype = output_dtype @@ -60,25 +88,27 @@ def __init__( @property @abstractmethod - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ - List of compatible data types for the layer. + List of compatible data type names for the layer. If the computation can be performed on any data type, return None. - :returns: List of compatible data types for the layer. + :returns: List of compatible dtype names (e.g., ['float32', 'float64']) + or None if any dtype is compatible. """ raise NotImplementedError - def _string_to_bool_cast(self, inputs: Tensor) -> Tensor: + def _string_to_bool_cast(self, inputs: KerasTensor) -> KerasTensor: """ Casts a string tensor to a bool tensor. :param inputs: Input string tensor :returns: Bool tensor. """ - if inputs.dtype.name != "string": + if keras.backend.standardize_dtype(inputs.dtype) != "string": raise TypeError( - f"Expected a string tensor, but got a {inputs.dtype.name} tensor." + f"Expected a string tensor, but got a " + f"{keras.backend.standardize_dtype(inputs.dtype)} tensor." 
) # Replace true strings with "1" and false strings with "0" @@ -111,12 +141,12 @@ def _string_to_bool_cast(self, inputs: Tensor) -> Tensor: ) bool_float_tensor = tf.strings.to_number( - string_bool_tensor_with_invalid, out_type=tf.float32 + string_bool_tensor_with_invalid, out_type="float32" ) return tf.cast(bool_float_tensor, tf.bool) @staticmethod - def _float_to_string_cast(inputs: Tensor) -> Tensor: + def _float_to_string_cast(inputs: KerasTensor) -> KerasTensor: """ Casts a float tensor to a string tensor. Ensures that the precision of the float does not impact the string representation. Specifically, we want the string @@ -149,18 +179,18 @@ def _float_to_string_cast(inputs: Tensor) -> Tensor: shortest_float_string, ) - def _to_string_cast(self, inputs: Tensor) -> Tensor: + def _to_string_cast(self, inputs: KerasTensor) -> KerasTensor: """ Casts inputs to string tensor. :param inputs: Input tensor. :returns: String tensor. """ - if inputs.dtype.is_floating: + if "float" in keras.backend.standardize_dtype(inputs.dtype): return self._float_to_string_cast(inputs) return tf.strings.as_string(inputs) - def _from_string_cast(self, inputs: Tensor, cast_dtype: str) -> Tensor: + def _from_string_cast(self, inputs: KerasTensor, cast_dtype: str) -> KerasTensor: """ Casts inputs to the desired dtype when inputs are a string tensor. @@ -168,107 +198,168 @@ def _from_string_cast(self, inputs: Tensor, cast_dtype: str) -> Tensor: :param cast_dtype: Dtype to cast to. :returns: Tensor cast to the desired dtype. """ - if inputs.dtype.name != "string": + if keras.backend.standardize_dtype(inputs.dtype) != "string": raise TypeError("inputs is not a string Tensor.") if cast_dtype in ["float32", "float64", "int32", "int64"]: - # If the casting dtype is supported by tf.strings.to_number, we use that. 
return tf.strings.to_number(inputs, out_type=cast_dtype) - elif tf.as_dtype(cast_dtype).is_integer: - # If the casting dtype is an integer, we need to cast to int64 first + elif "int" in cast_dtype: intermediate_cast = tf.strings.to_number(inputs, out_type="int64") - return tf.cast(intermediate_cast, cast_dtype) - elif tf.as_dtype(cast_dtype).is_floating: - # If the casting dtype is a float, we need to cast to float64 first + return ops.cast(intermediate_cast, cast_dtype) + elif "float" in cast_dtype: intermediate_cast = tf.strings.to_number(inputs, out_type="float64") - return tf.cast(intermediate_cast, cast_dtype) - elif tf.as_dtype(cast_dtype).is_bool: - # If the casting dtype is a boolean, we need to use a custom function - # to cast the string to boolean. + return ops.cast(intermediate_cast, cast_dtype) + elif cast_dtype == "bool": return self._string_to_bool_cast(inputs) else: raise TypeError(f"Casting string to dtype {cast_dtype} is not supported.") - def _string_cast(self, inputs: Tensor, cast_dtype: str) -> Tensor: + def _string_cast(self, inputs: KerasTensor, cast_dtype: str) -> KerasTensor: """ Casts from and to string tensors. Either inputs is a string tensor, and we want to cast it to the desired dtype, or inputs is not a string tensor, and we want to cast it to a string tensor. + Requires TensorFlow backend. + :param inputs: Input tensor. :param cast_dtype: Dtype to cast to. :returns: Tensor cast to the desired dtype. """ - if inputs.dtype.name == "string" and cast_dtype == "string": + require_tensorflow() + + if ( + keras.backend.standardize_dtype(inputs.dtype) == "string" + and cast_dtype == "string" + ): return inputs if cast_dtype == "string": return self._to_string_cast(inputs) return self._from_string_cast(inputs, cast_dtype) @staticmethod - def _numeric_cast(inputs: Tensor, cast_dtype: str) -> Tensor: + def _check_string_dtype_backend_compatibility(dtype_str: str) -> None: """ - Casts a numeric tensor to the desired (non-string) dtype. 
+ Check if string dtype is used on a non-TensorFlow backend. + + String operations are only supported on TensorFlow backend. JAX and PyTorch + do not support string tensors. + + :param dtype_str: Dtype string to check (e.g., 'float32', 'string') + :raises RuntimeError: If string dtype is used on JAX or PyTorch backend. + """ + if dtype_str == "string": + backend = keras.backend.backend() + if backend != "tensorflow": + raise RuntimeError( + f"String dtype is not supported on '{backend}' backend. " + f"String operations require TensorFlow backend. " + f"Set KERAS_BACKEND=tensorflow before importing keras." + ) + + @staticmethod + def _numeric_cast(inputs: KerasTensor, cast_dtype: str) -> KerasTensor: + """ + Casts a numeric tensor to the desired dtype using keras.ops. :param inputs: Input numeric tensor - :param cast_dtype: Dtype to cast to. + :param cast_dtype: Dtype to cast to (e.g., 'float32', 'int64') :returns: Tensor cast to the desired dtype. """ - return tf.cast(inputs, cast_dtype) + # keras.ops.cast doesn't support string dtype, even on TF backend + # Check if we're on TF backend and dealing with strings + if cast_dtype == "string" or ( + hasattr(inputs, "dtype") + and keras.backend.standardize_dtype(inputs.dtype) == "string" + ): + if keras.backend.backend() == "tensorflow": + return ( + tf.strings.as_string(inputs) + if cast_dtype == "string" + else tf.cast(inputs, cast_dtype) + ) + else: + # String operations not supported on JAX/PyTorch backends + raise ValueError( + f"String dtype casting not supported on {keras.backend.backend()} backend. " + "String operations require TensorFlow backend." + ) + return ops.cast(inputs, cast_dtype) - def _cast(self, inputs: Tensor, cast_dtype: str) -> Tensor: + def _cast(self, inputs: KerasTensor, cast_dtype: str) -> KerasTensor: """ Casts inputs to the desired dtype. + Routes to string casting when string dtype is involved (TensorFlow backend only), + otherwise uses numeric casting for multi-backend compatibility. 
+ :param inputs: Input tensor. :param cast_dtype: Dtype to cast to. :returns: Tensor cast to the desired dtype. """ - if inputs.dtype.name == "string" or cast_dtype == "string": - # If input tensor is a string tensor, or we are casting to a string, - # we need to use the string_cast function. + # Check if string dtype is involved + if ( + keras.backend.standardize_dtype(inputs.dtype) == "string" + or cast_dtype == "string" + ): return self._string_cast(inputs, cast_dtype) - else: - return self._numeric_cast(inputs, cast_dtype) + return self._numeric_cast(inputs, cast_dtype) def _force_cast_to_compatible_numeric_type( - self, inputs: Tensor, constant: Union[float, int] - ) -> Tuple[Tensor, Tensor]: + self, inputs: KerasTensor, constant: Union[float, int] + ) -> Tuple[KerasTensor, KerasTensor]: """ - Casts an input tensor and a single constant to compatible tensors. + Casts an input tensor and a single constant to compatible numeric tensors. - If the provided input is a float, create the constant tensor as a float of the - same precision. If the provided input is an integer, check if the constant is - non-floating, and if so, create the constant tensor as an integer of the same - precision. If the constant is floating, cast the input to a float with the same - precision as its integer dtype and create the constant tensor likewise. + This ensures operations between tensors and constants work correctly: + - If input is float, constant becomes float of same precision + - If input is int and constant is int, keep as int of same precision + - If input is int but constant is float, cast input to float :param inputs: Input numeric tensor :param constant: The constant to cast to the compatible dtype. 
- :returns: Tuple of tensors cast to compatible types + :returns: Tuple of (cast_input, cast_constant) with compatible types """ - if inputs.dtype.is_floating: + input_dtype = keras.backend.standardize_dtype(inputs.dtype) + + # Check if dtype is floating point + if "float" in input_dtype: + # Input is float - cast constant to same precision if isinstance(constant, float): - return inputs, tf.constant(constant, dtype=inputs.dtype) - return inputs, tf.constant(float(constant), dtype=inputs.dtype) - if inputs.dtype.is_integer: + return inputs, ops.convert_to_tensor(constant, dtype=input_dtype) + return inputs, ops.convert_to_tensor(float(constant), dtype=input_dtype) + + # Check if dtype is integer + if "int" in input_dtype: + # Input is integer if isinstance(constant, int): - return inputs, tf.constant(constant, dtype=inputs.dtype) + # Constant is also int - keep as int + return inputs, ops.convert_to_tensor(constant, dtype=input_dtype) + if isinstance(constant, float) and constant.is_integer(): - return inputs, tf.constant(int(constant), dtype=inputs.dtype) + # Constant is float but represents an integer + return inputs, ops.convert_to_tensor(int(constant), dtype=input_dtype) + if isinstance(constant, float): - precision = inputs.dtype.size * 8 + # Constant is truly float - need to cast input to float + # Extract precision (e.g., int32 -> 32 bits) + if "64" in input_dtype: + float_dtype = "float64" + else: + float_dtype = "float32" return ( - self._cast(inputs, f"float{precision}"), - tf.constant(constant, dtype=f"float{precision}"), + self._cast(inputs, float_dtype), + ops.convert_to_tensor(constant, dtype=float_dtype), ) + raise TypeError( - "inputs must be a numeric tensor and constant must be a numeric value." + f"inputs must be a numeric tensor (got {input_dtype}) " + f"and constant must be a numeric value (got {type(constant)})." 
) def _cast_input_output_tensors( - self, tensors: Union[Tensor, List[Tensor]], ingress: bool - ) -> Union[Tensor, List[Tensor]]: + self, tensors: Union[KerasTensor, List[KerasTensor]], ingress: bool + ) -> Union[KerasTensor, List[KerasTensor]]: """ Casts either the input or output tensors to the given input/output dtype, if specified. Ingress is a boolean that indicates whether we are casting the @@ -281,41 +372,45 @@ def _cast_input_output_tensors( """ if ingress: cast_dtype = self._input_dtype + # Validate input_dtype is compatible if ( cast_dtype is not None and self.compatible_dtypes is not None - and cast_dtype not in [dtype.name for dtype in self.compatible_dtypes] + and cast_dtype not in self.compatible_dtypes ): raise ValueError( - f"""input_dtype {cast_dtype} is not a compatible dtype for - this layer. Compatible dtypes are {[ - dtype.name for dtype in self.compatible_dtypes - ]}.""" + f"input_dtype {cast_dtype} is not a compatible dtype for " + f"this layer. Compatible dtypes are {self.compatible_dtypes}." 
) else: cast_dtype = self._output_dtype if cast_dtype is not None: - if tf.is_tensor(tensors): + # Check if string dtype is being used on non-TF backend + self._check_string_dtype_backend_compatibility(cast_dtype) + # Check if tensors is a single tensor + if not isinstance(tensors, list): + current_dtype = keras.backend.standardize_dtype(tensors.dtype) return ( self._cast(tensors, cast_dtype) - if tensors.dtype.name != cast_dtype + if current_dtype != cast_dtype else tensors ) + # Handle list of tensors return [ - self._cast(inp, cast_dtype) if inp.dtype.name != cast_dtype else inp + self._cast(inp, cast_dtype) + if keras.backend.standardize_dtype(inp.dtype) != cast_dtype + else inp for inp in tensors ] return tensors def cast_input_tensors( - self, inputs: Union[Tensor, List[Tensor]] - ) -> Union[Tensor, List[Tensor]]: + self, inputs: Union[KerasTensor, List[KerasTensor]] + ) -> Union[KerasTensor, List[KerasTensor]]: """ Casts the input tensors to the given input dtype, if specified. All tensors are - cast to this. This might not be ideal, there may be layers where some inputs are - expected to be different types. In these cases, the subclass should - implement the cast_input_tensors method. + cast to this. Subclasses can override for more complex casting behavior. :param inputs: The input tensor(s) to the layer. :returns: The input tensor(s) cast to the desired input_dtype. @@ -323,20 +418,18 @@ def cast_input_tensors( return self._cast_input_output_tensors(tensors=inputs, ingress=True) def cast_output_tensors( - self, outputs: Union[Tensor, List[Tensor]] - ) -> Union[Tensor, List[Tensor]]: + self, outputs: Union[KerasTensor, List[KerasTensor]] + ) -> Union[KerasTensor, List[KerasTensor]]: """ Casts the output tensors to the given output dtype, if specified. All tensors - are cast to this. This might not be ideal, there may be layers where some - outputs are expected to be different types. 
In these cases, the subclass should - implement the cast_output_tensors method. + are cast to this. Subclasses can override for more complex casting behavior. :param outputs: The output tensor(s) of the layer. :returns: The output tensor(s) cast to the desired output_dtype. """ return self._cast_input_output_tensors(tensors=outputs, ingress=False) - def _check_input_dtypes_compatible(self, inputs: List[Tensor]) -> None: + def _check_input_dtypes_compatible(self, inputs: List[KerasTensor]) -> None: """ Checks if the input tensors are compatible with the compatible_dtypes of the layer. @@ -346,23 +439,26 @@ def _check_input_dtypes_compatible(self, inputs: List[Tensor]) -> None: compatible_dtypes of the layer. :returns: None """ + if self.compatible_dtypes is None: + # Any dtype is compatible, but check for string dtype on non-TF backends + for inp in inputs: + inp_dtype = keras.backend.standardize_dtype(inp.dtype) + self._check_string_dtype_backend_compatibility(inp_dtype) + return + for inp in inputs: - if ( - self.compatible_dtypes is not None - and inp.dtype not in self.compatible_dtypes - ): + inp_dtype = keras.backend.standardize_dtype(inp.dtype) + if inp_dtype not in self.compatible_dtypes: raise TypeError( - f"""Input tensor with dtype {inp.dtype.name} - is not a compatible dtype for this layer. - Compatible dtypes are {[ - dtype.name for dtype in self.compatible_dtypes - ]}.""" + f"Input tensor with dtype {inp_dtype} " + f"is not a compatible dtype for this layer. " + f"Compatible dtypes are {self.compatible_dtypes}." ) @allow_single_or_multiple_tensor_input def call( - self, inputs: Iterable[Tensor], **kwargs: Any - ) -> Union[Tensor, List[Tensor]]: + self, inputs: Iterable[KerasTensor], **kwargs: Any + ) -> Union[KerasTensor, List[KerasTensor]]: """ Casts inputs to the given `input_dtype`, calls the internal `_call` method, and casts the outputs to the given `output_dtype`. 
@@ -382,13 +478,16 @@ def call( @abstractmethod def _call( - self, inputs: Union[Tensor, List[Tensor]], **kwargs: Any - ) -> Union[Tensor, List[Tensor]]: + self, inputs: Union[KerasTensor, List[KerasTensor]], **kwargs: Any + ) -> Union[KerasTensor, List[KerasTensor]]: """ The internal call method that should be implemented by the layer. - :param inputs: The input tensor(s) to the layer. - :returns: The output tensor(s) of the layer. + Subclasses implement this method to define the layer's computation. + Input and output casting is handled by the base class `call()` method. + + :param inputs: The input tensor(s) to the layer (after input casting). + :returns: The output tensor(s) of the layer (before output casting). """ raise NotImplementedError diff --git a/src/kamae/keras/core/layers/__init__.py b/src/kamae/keras/core/layers/__init__.py new file mode 100644 index 00000000..474df48c --- /dev/null +++ b/src/kamae/keras/core/layers/__init__.py @@ -0,0 +1,89 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Backend-agnostic Keras layers. + +Multi-backend layers that work across TensorFlow, JAX, and PyTorch backends. 
+""" + +from .absolute_value import AbsoluteValueLayer +from .array_concatenate import ArrayConcatenateLayer +from .array_crop import ArrayCropLayer +from .array_reduce_max import ArrayReduceMaxLayer +from .array_split import ArraySplitLayer +from .array_subtract_minimum import ArraySubtractMinimumLayer +from .bearing_angle import BearingAngleLayer +from .bin import BinLayer +from .conditional_standard_scale import ConditionalStandardScaleLayer +from .cosine_similarity import CosineSimilarityLayer +from .divide import DivideLayer +from .exp import ExpLayer +from .exponent import ExponentLayer +from .haversine_distance import HaversineDistanceLayer +from .identity import IdentityLayer +from .impute import ImputeLayer +from .log import LogLayer +from .logical_and import LogicalAndLayer +from .logical_not import LogicalNotLayer +from .logical_or import LogicalOrLayer +from .max import MaxLayer +from .mean import MeanLayer +from .min import MinLayer +from .min_max_scale import MinMaxScaleLayer +from .modulo import ModuloLayer +from .multiply import MultiplyLayer +from .numerical_if_statement import NumericalIfStatementLayer +from .pairwise_cosine_similarity import PairwiseCosineSimilarityLayer +from .round import RoundLayer +from .round_to_decimal import RoundToDecimalLayer +from .standard_scale import StandardScaleLayer +from .subtract import SubtractLayer +from .sum import SumLayer + +__all__ = [ + "IdentityLayer", + "AbsoluteValueLayer", + "MultiplyLayer", + "ExpLayer", + "LogLayer", + "DivideLayer", + "SubtractLayer", + "RoundLayer", + "RoundToDecimalLayer", + "ModuloLayer", + "SumLayer", + "MaxLayer", + "MinLayer", + "MeanLayer", + "ExponentLayer", + "LogicalAndLayer", + "LogicalOrLayer", + "LogicalNotLayer", + "NumericalIfStatementLayer", + "ArrayConcatenateLayer", + "ArrayReduceMaxLayer", + "ArraySplitLayer", + "ArrayCropLayer", + "ArraySubtractMinimumLayer", + "StandardScaleLayer", + "ConditionalStandardScaleLayer", + "MinMaxScaleLayer", + "ImputeLayer", + 
"BinLayer", + "BearingAngleLayer", + "CosineSimilarityLayer", + "PairwiseCosineSimilarityLayer", + "HaversineDistanceLayer", +] diff --git a/src/kamae/tensorflow/layers/absolute_value.py b/src/kamae/keras/core/layers/absolute_value.py similarity index 73% rename from src/kamae/tensorflow/layers/absolute_value.py rename to src/kamae/keras/core/layers/absolute_value.py index 29fad865..d748b728 100644 --- a/src/kamae/tensorflow/layers/absolute_value.py +++ b/src/kamae/keras/core/layers/absolute_value.py @@ -14,21 +14,24 @@ from typing import Any, Dict, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class AbsoluteValueLayer(BaseLayer): """ - Performs the abs(x) operation on a given input tensor + Performs the abs(x) operation on a given input tensor. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -48,24 +51,24 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. - :returns: The compatible dtypes of the layer. 
+ :returns: List of compatible dtype names """ return [ - tf.float16, - tf.float32, - tf.float64, - tf.int32, - tf.int64, - tf.complex64, - tf.complex128, + "float16", + "float32", + "float64", + "int32", + "int64", + "complex64", + "complex128", ] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the abs(x) operation on a given input tensor. @@ -76,7 +79,7 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: :param inputs: Tensor to perform the abs(x) operation on. :returns: The absolute value of the input tensor. """ - outputs = tf.math.abs(inputs) + outputs = ops.absolute(inputs) return outputs def get_config(self) -> Dict[str, Any]: diff --git a/src/kamae/tensorflow/layers/array_concatenate.py b/src/kamae/keras/core/layers/array_concatenate.py similarity index 78% rename from src/kamae/tensorflow/layers/array_concatenate.py rename to src/kamae/keras/core/layers/array_concatenate.py index cb544da8..1b6d4780 100644 --- a/src/kamae/tensorflow/layers/array_concatenate.py +++ b/src/kamae/keras/core/layers/array_concatenate.py @@ -14,21 +14,25 @@ from typing import Any, Dict, Iterable, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_multiple_tensor_input, reshape_to_equal_rank +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_multiple_tensor_input +from kamae.keras.core.utils.shape_utils import reshape_to_equal_rank -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class ArrayConcatenateLayer(BaseLayer): """ Performs a concatenation of the input tensors. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -57,7 +61,7 @@ def __init__( self.auto_broadcast = auto_broadcast @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. Returns `None` as the compatible dtypes are not restricted. @@ -67,7 +71,7 @@ def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: return None @enforce_multiple_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: """ Concatenates the input tensors along the specified axis. If auto_broadcast is set to True, the tensors are broadcasted to the @@ -99,11 +103,11 @@ def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: max_static_shape.append(max(shapes)) # Determine the maximum dynamic shape for each dimension, except last one - # Since shapes can be dynamic (None), we need to use tf.shape + # Since shapes can be dynamic (None), we need to use ops.shape max_dynamic_shape = [] for i in range(max_rank - 1): - shapes = [tf.shape(x)[i] for x in reshaped_inputs] - max_dynamic_shape.append(tf.reduce_max(shapes)) + shapes = [ops.shape(x)[i] for x in reshaped_inputs] + max_dynamic_shape.append(ops.max(ops.stack(shapes))) # Broadcast tensors to the maximum dynamic shape if the static is different # WARNING: It assumes that when the static shapes of two tensors are None @@ -112,19 +116,25 @@ def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: x_static_shape = x.shape[:-1] if x_static_shape != max_static_shape: last_dim = x.shape[-1] - broadcast_shape = tf.concat([max_dynamic_shape, [last_dim]], axis=0) - broadcasted_x = tf.broadcast_to(x, broadcast_shape) + broadcast_shape = ops.concatenate( + [ + ops.stack(max_dynamic_shape), + ops.convert_to_tensor([last_dim]), + ], + 
axis=0, + ) + broadcasted_x = ops.broadcast_to(x, broadcast_shape) reshaped_inputs[idx] = broadcasted_x inputs = reshaped_inputs - return tf.concat(inputs, axis=self.axis) + return ops.concatenate(inputs, axis=self.axis) def get_config(self) -> Dict[str, Any]: """ - Gets the configuration of the VectorConcat layer. + Gets the configuration of the ArrayConcatenate layer. Used for saving and loading from a model. - Specifically, adds the `axis` to the config. + Specifically, adds the `axis` and `auto_broadcast` to the config. :returns: Dictionary of the configuration of the layer. """ diff --git a/src/kamae/tensorflow/layers/array_crop.py b/src/kamae/keras/core/layers/array_crop.py similarity index 61% rename from src/kamae/tensorflow/layers/array_crop.py rename to src/kamae/keras/core/layers/array_crop.py index 66021642..a9494ba7 100644 --- a/src/kamae/tensorflow/layers/array_crop.py +++ b/src/kamae/keras/core/layers/array_crop.py @@ -14,16 +14,16 @@ from typing import Any, Dict, List, Optional, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class ArrayCropLayer(BaseLayer): """ Performs a cropping of the input tensor to a certain length. @@ -33,6 +33,9 @@ class ArrayCropLayer(BaseLayer): TODO: Currently only supports cropping the final dimension of the tensor. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -63,7 +66,7 @@ def __init__( self.pad_value = pad_value @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. @@ -72,38 +75,49 @@ def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: return None @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Crops the tensor to specified length and pads with specified value. :param inputs: Tensor to split. :returns: Cropped and padded tensor """ - inputs_shape = tf.shape(inputs) - # Crop final dimension of tensor - crop_length = tf.minimum(self.array_length, inputs_shape[-1]) - cropped = inputs[..., :crop_length] + # Use static shape for slicing if available, otherwise dynamic + if inputs.shape[-1] is not None: + crop_length = min(self.array_length, inputs.shape[-1]) + cropped = inputs[..., :crop_length] + padding_needed = max(self.array_length - inputs.shape[-1], 0) + else: + # Dynamic shape - need runtime computation + dynamic_last_dim = ops.shape(inputs)[-1] + crop_length = ops.minimum(self.array_length, dynamic_last_dim) + cropped = inputs[..., :crop_length] + padding_needed = ops.maximum(self.array_length - dynamic_last_dim, 0) # Pad final dim of tensor if necessary - padding_length = tf.maximum(self.array_length - inputs_shape[-1], 0) - paddings = [[0, 0]] * (inputs_shape.shape[0] - 1) + [[0, padding_length]] - padded = tf.pad(cropped, paddings, constant_values=self.pad_value) - new_shape = tf.concat( - [ - tf.shape(padded)[:-1], - tf.expand_dims(tf.constant(self.array_length), axis=-1), - ], - axis=0, - ) - return tf.reshape(padded, new_shape) + rank = len(inputs.shape) + paddings = [[0, 0]] * (rank - 1) + [[0, padding_needed]] + padded = ops.pad(cropped, 
paddings, constant_values=self.pad_value) + + # Build target shape tuple for reshape + # Use static shape dimensions where available, dynamic where needed + new_shape_list = [] + for i in range(rank - 1): + if padded.shape[i] is not None: + new_shape_list.append(padded.shape[i]) + else: + new_shape_list.append(ops.shape(padded)[i]) + new_shape_list.append(self.array_length) + + return ops.reshape(padded, new_shape_list) def get_config(self) -> Dict[str, Any]: """ Gets the configuration of the ArrayCrop layer. Used for saving and loading from a model. - Specifically, adds the `array_length` amd `pad_value to the config. + Specifically, adds the `array_length` and `pad_value` to the config. :returns: Dictionary of the configuration of the layer. """ diff --git a/src/kamae/tensorflow/layers/array_reduce_max.py b/src/kamae/keras/core/layers/array_reduce_max.py similarity index 61% rename from src/kamae/tensorflow/layers/array_reduce_max.py rename to src/kamae/keras/core/layers/array_reduce_max.py index 7482bc23..ee7cdccc 100644 --- a/src/kamae/tensorflow/layers/array_reduce_max.py +++ b/src/kamae/keras/core/layers/array_reduce_max.py @@ -12,26 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class ArrayReduceMaxLayer(BaseLayer): """ Reduces the last dimension of a tensor by taking the maximum. Input: (..., N) Output: (...) + + NaN values in the result are replaced with the configured default_value. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -46,15 +51,20 @@ def __init__( self.default_value = default_value @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: - return [tf.bfloat16, tf.float16, tf.float32, tf.float64] + def compatible_dtypes(self) -> Optional[List[str]]: + return [ + "bfloat16", + "float16", + "float32", + "float64", + ] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: - result = tf.reduce_max(inputs, axis=-1) - return tf.where( - tf.math.is_nan(result), - tf.constant(self.default_value, dtype=result.dtype), + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: + result = ops.max(inputs, axis=-1) + return ops.where( + ops.isnan(result), + ops.cast(self.default_value, dtype=result.dtype), result, ) diff --git a/src/kamae/tensorflow/layers/array_split.py b/src/kamae/keras/core/layers/array_split.py similarity index 79% rename from src/kamae/tensorflow/layers/array_split.py rename to src/kamae/keras/core/layers/array_split.py index 13d4065e..bfc710f9 100644 --- 
a/src/kamae/tensorflow/layers/array_split.py +++ b/src/kamae/keras/core/layers/array_split.py @@ -14,22 +14,25 @@ from typing import Any, Dict, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class ArraySplitLayer(BaseLayer): """ Performs a splitting of the input tensor into a list of tensors. Expands dimensions to ensure the output tensors are the same shape as the input. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -52,7 +55,7 @@ def __init__( self.axis = axis @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. @@ -61,7 +64,7 @@ def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: return None @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> List[Tensor]: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> List[KerasTensor]: """ Splits the input tensor along the specified axis. @@ -73,13 +76,13 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> List[Tensor]: :returns: List of split tensors. """ return [ - tf.expand_dims(y, axis=self.axis) - for y in tf.unstack(inputs, axis=self.axis) + ops.expand_dims(y, axis=self.axis) + for y in ops.unstack(inputs, axis=self.axis) ] def get_config(self) -> Dict[str, Any]: """ - Gets the configuration of the VectorSplit layer. + Gets the configuration of the ArraySplit layer. 
Used for saving and loading from a model. Specifically, adds the `axis` to the config. diff --git a/src/kamae/tensorflow/layers/array_subtract_minimum.py b/src/kamae/keras/core/layers/array_subtract_minimum.py similarity index 64% rename from src/kamae/tensorflow/layers/array_subtract_minimum.py rename to src/kamae/keras/core/layers/array_subtract_minimum.py index f6b34701..ff964815 100644 --- a/src/kamae/tensorflow/layers/array_subtract_minimum.py +++ b/src/kamae/keras/core/layers/array_subtract_minimum.py @@ -14,30 +14,34 @@ from typing import Any, Dict, List, Optional, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input +from kamae.keras.core.utils.tensor_utils import get_dtype_max -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class ArraySubtractMinimumLayer(BaseLayer): """ - TensorFlow layer that computes the difference across an axis from the minimum - non-paded element in the input tensor. + Computes the difference across an axis from the minimum non-padded element + in the input tensor. It takes a tensor of numerical value and calculates the differences between each value and the minimum value in the tensor. The calculation preserves the pad value elements. The principal use case for this layer is to calculate the time difference - from the first event to all events in a sequence, where the tensor is a array of + from the first event to all events in a sequence, where the tensor is an array of timestamps. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -65,55 +69,50 @@ def __init__( self.pad_value = pad_value @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.uint8, - tf.int8, - tf.uint16, - tf.int16, - tf.int32, - tf.int64, - tf.uint32, - tf.uint64, + "bfloat16", + "float16", + "float32", + "float64", + "uint8", + "int8", + "uint16", + "int16", + "int32", + "int64", + "uint32", + "uint64", ] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the calculation of the differences on the input tensor. Example: - input_tensor = tf.Tensor([ - [19, 18, 13, 11, 10, -1, -1, -1], - [12, 2, 1, -1, -1, -1, -1, -1], - ] - ) - layer = ArraySubtractMinimumLayer() + input_tensor = [[19, 18, 13, 11, 10, -1, -1, -1], + [12, 2, 1, -1, -1, -1, -1, -1]] + layer = ArraySubtractMinimumLayer(pad_value=-1) differences = layer(input_tensor) - print(differences) - Output: tf.Tensor([[ - [9, 8, 3, 1, 0, -1, -1, -1], - [11, 1, 0, -1, -1, -1, -1, -1], - ] - ) + Output: [[9, 8, 3, 1, 0, -1, -1, -1], + [11, 1, 0, -1, -1, -1, -1, -1]] :param inputs: The input tensor. - :returns: Tensor of differences from the minimum (non-padded) timestamp. + :returns: Tensor of differences from the minimum (non-padded) value. """ if self.pad_value is None: # If pad value is not defined, then the smallest value in the tensor is # considered as the first value and subtracted from all the values. 
- first_value = tf.reduce_min(inputs, axis=self.axis) - subtracted_val = tf.subtract(inputs, tf.expand_dims(first_value, self.axis)) + first_value = ops.min(inputs, axis=self.axis) + subtracted_val = ops.subtract( + inputs, ops.expand_dims(first_value, self.axis) + ) return subtracted_val # Otherwise, we find the smallest non padded value and subtract it from all @@ -121,14 +120,20 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: inputs, pad_tensor = self._force_cast_to_compatible_numeric_type( inputs, self.pad_value ) - first_non_pad_value = tf.reduce_min( - tf.where(tf.equal(inputs, pad_tensor), inputs.dtype.max, inputs), + + # Get the dtype max value for masking + dtype_str = keras.backend.standardize_dtype(inputs.dtype) + dtype_max = get_dtype_max(dtype_str) + dtype_max_tensor = ops.convert_to_tensor(dtype_max, dtype=inputs.dtype) + + first_non_pad_value = ops.min( + ops.where(ops.equal(inputs, pad_tensor), dtype_max_tensor, inputs), axis=self.axis, ) - subtracted_val = tf.subtract( - inputs, tf.expand_dims(first_non_pad_value, self.axis) + subtracted_val = ops.subtract( + inputs, ops.expand_dims(first_non_pad_value, self.axis) ) - return tf.where(tf.equal(inputs, pad_tensor), inputs, subtracted_val) + return ops.where(ops.equal(inputs, pad_tensor), inputs, subtracted_val) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/bearing_angle.py b/src/kamae/keras/core/layers/bearing_angle.py similarity index 69% rename from src/kamae/tensorflow/layers/bearing_angle.py rename to src/kamae/keras/core/layers/bearing_angle.py index b50c27a3..f43d6d9b 100644 --- a/src/kamae/tensorflow/layers/bearing_angle.py +++ b/src/kamae/keras/core/layers/bearing_angle.py @@ -12,23 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import math from typing import Any, Dict, Iterable, List, Optional -import tensorflow as tf -from tensorflow.math import atan2, cos, mod, sin +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_multiple_tensor_input +from kamae.keras.core.utils.ops_utils import get_degrees, get_radians -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class BearingAngleLayer(BaseLayer): """ Computes the Bearing angle operation on a given input tensor. + If lat_lon_constant is not set, inputs must be a list of 4 tensors, in the order of lat1, lon1, lat2, lon2. If lat_lon_constant is set, inputs must be a tensor of 2 tensors, @@ -38,6 +38,9 @@ class BearingAngleLayer(BaseLayer): For lat, this is [-90, 90] and for lon, this is [-180, 180]. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -63,42 +66,17 @@ def __init__( self.lat_lon_constant = lat_lon_constant @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.bfloat16, tf.float16, tf.float32, tf.float64] - - @staticmethod - def get_radians(degrees: Tensor) -> Tensor: - """ - Converts degrees tensor to radians. We need to cast to float64 otherwise - pi / 180 will lose precision. - - :param degrees: Tensor of degrees. - :returns: Tensor of radians. 
- """ - return tf.cast(degrees, dtype=tf.float64) * tf.constant( - math.pi / 180, dtype=tf.float64 - ) - - @staticmethod - def get_degrees(radians: Tensor) -> Tensor: - """ - Converts radians tensor to degrees. - - :param radians: Tensor of degrees. - :returns: Tensor of degrees. - """ - return tf.cast(radians, dtype=tf.float64) * tf.constant( - 180 / math.pi, dtype=tf.float64 - ) + return ["bfloat16", "float16", "float32", "float64"] def compute_bearing_angle( - self, lat1: Tensor, lon1: Tensor, lat2: Tensor, lon2: Tensor - ) -> Tensor: + self, lat1: KerasTensor, lon1: KerasTensor, lat2: KerasTensor, lon2: KerasTensor + ) -> KerasTensor: """ Computes the bearing angle between two lat/lon pairs. @@ -108,25 +86,25 @@ def compute_bearing_angle( :param lon2: Tensor of longitudes of the second point. :returns: Tensor of bearing angles. """ - lat1_radians = self.get_radians(lat1) - lon1_radians = self.get_radians(lon1) - lat2_radians = self.get_radians(lat2) - lon2_radians = self.get_radians(lon2) + lat1_radians = get_radians(lat1) + lon1_radians = get_radians(lon1) + lat2_radians = get_radians(lat2) + lon2_radians = get_radians(lon2) lon_difference = lon2_radians - lon1_radians # Bearing formula calculation - y = sin(lon_difference) * cos(lat2_radians) + y = ops.sin(lon_difference) * ops.cos(lat2_radians) - x = cos(lat1_radians) * sin(lat2_radians) - x -= sin(lat1_radians) * cos(lat2_radians) * cos(lon_difference) + x = ops.cos(lat1_radians) * ops.sin(lat2_radians) + x -= ops.sin(lat1_radians) * ops.cos(lat2_radians) * ops.cos(lon_difference) # Calculate bearing in degrees - bearing = atan2(y, x) - bearing_deg = mod(self.get_degrees(bearing) + 360, 360) + bearing = ops.arctan2(y, x) + bearing_deg = ops.mod(get_degrees(bearing) + 360, 360) return bearing_deg @enforce_multiple_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: """ Computes the bearing angle 
between two lat/lon pairs. @@ -148,8 +126,8 @@ def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: return self.compute_bearing_angle( inputs[0], inputs[1], - tf.constant(self.lat_lon_constant[0]), - tf.constant(self.lat_lon_constant[1]), + ops.convert_to_tensor(self.lat_lon_constant[0]), + ops.convert_to_tensor(self.lat_lon_constant[1]), ) else: if not isinstance(inputs, list) or len(inputs) != 4: @@ -169,7 +147,7 @@ def get_config(self) -> Dict[str, Any]: Gets the configuration of the Bearing Angle layer. Used for saving and loading from a model. - Specifically, we add the `lat_lon_constant` and `unit` to the config. + Specifically, we add the `lat_lon_constant` to the config. :returns: Dictionary of the configuration of the layer. """ diff --git a/src/kamae/tensorflow/layers/bin.py b/src/kamae/keras/core/layers/bin.py similarity index 84% rename from src/kamae/tensorflow/layers/bin.py rename to src/kamae/keras/core/layers/bin.py index d4e6fc1d..b4a5348b 100644 --- a/src/kamae/tensorflow/layers/bin.py +++ b/src/kamae/keras/core/layers/bin.py @@ -14,17 +14,17 @@ from typing import Any, Dict, List, Optional, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input from kamae.utils import get_condition_operator -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class BinLayer(BaseLayer): """ Performs a binning operation on a given input tensor. @@ -36,6 +36,9 @@ class BinLayer(BaseLayer): If no conditions evaluate to True, the default label is returned. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, condition_operators: List[str], @@ -50,9 +53,6 @@ def __init__( """ Initializes the BinLayer layer - :param name: Name of the layer, defaults to `None`. - :param input_dtype: The dtype to cast the input to. Defaults to `None`. - :param output_dtype: The dtype to cast the output to. Defaults to `None`. :param condition_operators: List of operators to use in the if statement. Can be one of: - "eq": Equal to @@ -66,6 +66,9 @@ def __init__( :param bin_labels: List of labels to use for each bin. Must be the same length as condition_operators. :param default_label: Label to use if none of the conditions are met. + :param name: Name of the layer, defaults to `None`. + :param input_dtype: The dtype to cast the input to. Defaults to `None`. + :param output_dtype: The dtype to cast the output to. Defaults to `None`. """ super().__init__( name=name, input_dtype=input_dtype, output_dtype=output_dtype, **kwargs @@ -82,33 +85,33 @@ def __init__( self.default_label = default_label @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.int8, - tf.uint8, - tf.int16, - tf.uint16, - tf.int32, - tf.uint32, - tf.int64, - tf.uint64, + "bfloat16", + "float16", + "float32", + "float64", + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", ] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs a binning operation on a given input tensor. 
- Creates a string tensor of the same shape as the input tensor, where each + Creates a tensor of the same shape as the input tensor, where each element is the label of the bin that the corresponding element in the input tensor belongs to. The bin labels are determined by successively applying the condition operators to the input tensor, and returning the label of the @@ -124,7 +127,7 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: cond_op_fns = [get_condition_operator(op) for op in self.condition_operators] # Build default output tensor - outputs = tf.constant(self.default_label) + outputs = ops.convert_to_tensor(self.default_label) # Loop through the conditions. # Reverse the list of conditions so that we start from the last condition @@ -137,12 +140,12 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: cast_input, cast_value = self._force_cast_to_compatible_numeric_type( inputs, value ) - outputs = tf.where( + outputs = ops.where( cond_op( cast_input, cast_value, ), - tf.constant(label), + ops.convert_to_tensor(label), outputs, ) diff --git a/src/kamae/tensorflow/layers/conditional_standard_scale.py b/src/kamae/keras/core/layers/conditional_standard_scale.py similarity index 76% rename from src/kamae/tensorflow/layers/conditional_standard_scale.py rename to src/kamae/keras/core/layers/conditional_standard_scale.py index 07aff3b2..a8dd62e8 100644 --- a/src/kamae/tensorflow/layers/conditional_standard_scale.py +++ b/src/kamae/keras/core/layers/conditional_standard_scale.py @@ -14,27 +14,35 @@ from typing import Any, Dict, List, Optional, Union +import keras import numpy as np -import tensorflow as tf +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import NormalizeLayer, enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input +from kamae.keras.core.utils.normalize_layer 
import NormalizeLayer +from kamae.keras.core.utils.ops_utils import divide_no_nan -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class ConditionalStandardScaleLayer(NormalizeLayer): """ Performs the standard scaling of the input with a masking condition. + This layer will shift and scale inputs into a distribution centered around 0 with standard deviation 1. It accomplishes this by precomputing the mean and variance of the data, and calling `(input - mean) / sqrt(var)` at runtime. + The skip_zeros parameter allows to apply the standard scaling process only when input is not equal to zero. If equal to zero, it will remain zero in the output value as it was in the input value. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, mean: Union[List[float], np.array], @@ -48,7 +56,8 @@ def __init__( **kwargs: Any, ) -> None: """ - Intialise the ConditionalStandardScaleLayer layer. + Initialise the ConditionalStandardScaleLayer layer. + :param mean: The mean value(s) to use during normalization. The passed value(s) will be broadcast to the shape of the kept axes above; if the value(s) cannot be broadcast, an error will be raised when this layer's @@ -73,7 +82,7 @@ def __init__( :param input_dtype: The dtype to cast the input to. Defaults to `None`. :param output_dtype: The dtype to cast the output to. Defaults to `None`. :param epsilon: Small value to add to conditional check of zeros. Valid only - when skipZeros is True. Defaults to 1e-4. + when skipZeros is True. Defaults to 0. """ super().__init__( name=name, @@ -88,38 +97,44 @@ def __init__( self.epsilon = epsilon @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ - Performs normalization on the input tensor(s) by calling the keras - ConditionalStandardScaleLayer layer. 
+ Performs normalization on the input tensor(s). + It applies the scaling only to values matching the mask condition, if set. It applies the scaling only to values not equal to zero, if skip_zeros is set. + Decorated with `@enforce_single_tensor_input` to ensure that the input is a single tensor. Raises an error if multiple tensors are passed in as an iterable. + :param inputs: Input tensor to perform the normalization on. :returns: The input tensor with the normalization applied. """ # Ensure mean and variance match input dtype. - mean = self._cast(self.mean, inputs.dtype.name) - variance = self._cast(self.variance, inputs.dtype.name) - normalized_outputs = tf.math.divide_no_nan( - tf.math.subtract(inputs, mean), - tf.math.maximum( - tf.sqrt(variance), tf.constant(self.epsilon, dtype=inputs.dtype) - ), + input_dtype_str = keras.backend.standardize_dtype(inputs.dtype) + mean = self._cast(self.mean, input_dtype_str) + variance = self._cast(self.variance, input_dtype_str) + + # Compute (input - mean) / sqrt(variance) using safe division + numerator = ops.subtract(inputs, mean) + denominator = ops.maximum( + ops.sqrt(variance), ops.convert_to_tensor(self.epsilon, dtype=inputs.dtype) ) - # output is 0 if variance is 0 - normalized_outputs = tf.where( - tf.equal(variance, 0), - tf.zeros_like(normalized_outputs), + normalized_outputs = divide_no_nan(numerator, denominator) + + # Output is 0 if variance is 0 + normalized_outputs = ops.where( + ops.equal(variance, 0), + ops.zeros_like(normalized_outputs), normalized_outputs, ) + if self.skip_zeros: - eps = tf.constant(self.epsilon, dtype=inputs.dtype) - normalized_outputs = tf.where( - tf.abs(inputs) <= eps, # x = (0 +- eps) - tf.zeros_like(normalized_outputs), + eps = ops.convert_to_tensor(self.epsilon, dtype=inputs.dtype) + normalized_outputs = ops.where( + ops.less_equal(ops.abs(inputs), eps), # x = (0 +- eps) + ops.zeros_like(normalized_outputs), normalized_outputs, ) return normalized_outputs diff --git 
a/src/kamae/tensorflow/layers/cosine_similarity.py b/src/kamae/keras/core/layers/cosine_similarity.py similarity index 77% rename from src/kamae/tensorflow/layers/cosine_similarity.py rename to src/kamae/keras/core/layers/cosine_similarity.py index c1a8fb9e..045419ca 100644 --- a/src/kamae/tensorflow/layers/cosine_similarity.py +++ b/src/kamae/keras/core/layers/cosine_similarity.py @@ -14,21 +14,25 @@ from typing import Any, Dict, Iterable, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_multiple_tensor_input +from kamae.keras.core.utils.ops_utils import l2_normalize -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class CosineSimilarityLayer(BaseLayer): """ Computes the cosine similarity between two input tensors. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -56,23 +60,23 @@ def __init__( self.keepdims = keepdims @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.complex64, - tf.complex128, + "bfloat16", + "float16", + "float32", + "float64", + "complex64", + "complex128", ] @enforce_multiple_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: """ Computes the cosine similarity between two input tensors. 
If `keepdims` is `True`, the shape is retained. Otherwise, the shape is reduced along the @@ -91,10 +95,10 @@ def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: raise ValueError( f"Expected 2 inputs, received {len(inputs)} inputs instead." ) - x = tf.nn.l2_normalize(inputs[0], axis=self.axis) - y = tf.nn.l2_normalize(inputs[1], axis=self.axis) + x = l2_normalize(inputs[0], axis=self.axis) + y = l2_normalize(inputs[1], axis=self.axis) - return tf.reduce_sum(tf.multiply(x, y), axis=self.axis, keepdims=self.keepdims) + return ops.sum(ops.multiply(x, y), axis=self.axis, keepdims=self.keepdims) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/divide.py b/src/kamae/keras/core/layers/divide.py similarity index 79% rename from src/kamae/tensorflow/layers/divide.py rename to src/kamae/keras/core/layers/divide.py index 2223b028..a72eafcb 100644 --- a/src/kamae/tensorflow/layers/divide.py +++ b/src/kamae/keras/core/layers/divide.py @@ -15,22 +15,26 @@ from functools import reduce from typing import Any, Dict, Iterable, List, Optional, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.utils.ops_utils import divide_no_nan -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class DivideLayer(BaseLayer): """ Performs the divide(x, y) operation on a given input tensor. If divisor is not set, inputs must be a list. If divisor is set, inputs must be a tensor. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -53,7 +57,7 @@ def __init__( self.divisor = divisor @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. @@ -64,14 +68,16 @@ def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: # error for the any inputs of size > 2 since we then try to divide a float64 # by an int. return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, + "bfloat16", + "float16", + "float32", + "float64", ] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tensor: + def _call( + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> KerasTensor: """ Performs the divide(x, y) operation on either an iterable of input tensors or a single input tensor and a constant. @@ -87,13 +93,12 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso if self.divisor is not None: if len(inputs) > 1: raise ValueError("If divisor is set, cannot have multiple inputs") - return tf.math.divide_no_nan( - inputs[0], tf.constant(self.divisor, dtype=inputs[0].dtype) - ) + divisor_tensor = ops.cast(self.divisor, dtype=inputs[0].dtype) + return divide_no_nan(inputs[0], divisor_tensor) else: if not len(inputs) > 1: raise ValueError("If divisor is not set, must have multiple inputs") - return reduce(tf.math.divide_no_nan, inputs) + return reduce(divide_no_nan, inputs) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/exp.py b/src/kamae/keras/core/layers/exp.py similarity index 73% rename from src/kamae/tensorflow/layers/exp.py rename to src/kamae/keras/core/layers/exp.py index f7083b00..e3d8632d 100644 --- a/src/kamae/tensorflow/layers/exp.py +++ b/src/kamae/keras/core/layers/exp.py @@ -14,21 +14,24 @@ from typing 
import Any, Dict, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class ExpLayer(BaseLayer): """ - Performs the exp(x) operation on a given input tensor + Performs the exp(x) operation on a given input tensor. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -48,23 +51,23 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. - :returns: The compatible dtypes of the layer. + :returns: List of compatible dtype names """ return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.complex64, - tf.complex128, + "bfloat16", + "float16", + "float32", + "float64", + "complex64", + "complex128", ] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the exp(x) operation on a given input tensor. @@ -75,7 +78,7 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: :param inputs: Tensor to perform the exp(x) operation on. :returns: The exp of the input tensor. 
""" - return tf.math.exp(inputs) + return ops.exp(inputs) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/exponent.py b/src/kamae/keras/core/layers/exponent.py similarity index 71% rename from src/kamae/tensorflow/layers/exponent.py rename to src/kamae/keras/core/layers/exponent.py index 5c020eba..62c84090 100644 --- a/src/kamae/tensorflow/layers/exponent.py +++ b/src/kamae/keras/core/layers/exponent.py @@ -13,21 +13,24 @@ # limitations under the License. from typing import Any, Dict, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class ExponentLayer(BaseLayer): """ - Performs the x^exponent operation on a given input tensor + Performs the x^exponent operation on a given input tensor. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -50,44 +53,38 @@ def __init__( self.exponent = exponent @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" return [ - tf.float16, - tf.float32, - tf.float64, - tf.complex64, - tf.complex128, + "float16", + "float32", + "float64", + "complex64", + "complex128", ] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ - Performs the x^exponent operation on a given input tensor. - - Decorated with `@allow_single_or_multiple_tensor_input` to ensure that the input - is either a single tensor or an iterable of tensors. Returns this result as a - list of tensors for easier use here. - :param inputs: Single tensor or iterable of tensors to perform the x^pow - operation on. + operation on. :returns: The tensor raised to the power of the exponent. """ if self.exponent is not None: if len(inputs) > 1: raise ValueError("If exponent is set, cannot have multiple inputs") - return tf.math.pow( + return ops.power( inputs[0], - self._cast(tf.constant(self.exponent), cast_dtype=inputs[0].dtype.name), + ops.cast(self.exponent, dtype=inputs[0].dtype), ) else: if not len(inputs) == 2: raise ValueError("If exponent is not set, must have exactly 2 inputs") - return tf.math.pow(inputs[0], inputs[1]) + return ops.power(inputs[0], inputs[1]) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/haversine_distance.py b/src/kamae/keras/core/layers/haversine_distance.py similarity index 75% rename from src/kamae/tensorflow/layers/haversine_distance.py rename to src/kamae/keras/core/layers/haversine_distance.py index 7a17ba82..bb36e178 100644 --- a/src/kamae/tensorflow/layers/haversine_distance.py +++ b/src/kamae/keras/core/layers/haversine_distance.py @@ -12,22 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import math from typing import Any, Dict, Iterable, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_multiple_tensor_input +from kamae.keras.core.utils.ops_utils import get_radians -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class HaversineDistanceLayer(BaseLayer): """ Computes the haversine distance operation on a given input tensor. + If lat_lon_constant is not set, inputs must be a list of 4 tensors, in the order of lat1, lon1, lat2, lon2. If lat_lon_constant is set, inputs must be a tensor of 2 tensors, @@ -37,6 +38,9 @@ class HaversineDistanceLayer(BaseLayer): For lat, this is [-90, 90] and for lon, this is [-180, 180]. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -68,30 +72,17 @@ def __init__( self.earth_radius = 6371.0 if unit == "km" else 3958.8 @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.bfloat16, tf.float16, tf.float32, tf.float64] - - @staticmethod - def get_radians(degrees: Tensor) -> Tensor: - """ - Converts degrees tensor to radians. We need to cast to float64 otherwise - pi / 180 will lose precision. - - :param degrees: Tensor of degrees. - :returns: Tensor of radians. 
- """ - return tf.cast(degrees, dtype=tf.float64) * tf.constant( - math.pi / 180, dtype=tf.float64 - ) + return ["bfloat16", "float16", "float32", "float64"] def compute_haversine_distance( - self, lat1: Tensor, lon1: Tensor, lat2: Tensor, lon2: Tensor - ) -> Tensor: + self, lat1: KerasTensor, lon1: KerasTensor, lat2: KerasTensor, lon2: KerasTensor + ) -> KerasTensor: """ Computes the haversine distance between two lat/lon pairs. @@ -101,24 +92,24 @@ def compute_haversine_distance( :param lon2: Tensor of longitudes of the second point. :returns: Tensor of haversine distances. """ - lat1_radians = self.get_radians(lat1) - lon1_radians = self.get_radians(lon1) - lat2_radians = self.get_radians(lat2) - lon2_radians = self.get_radians(lon2) + lat1_radians = get_radians(lat1) + lon1_radians = get_radians(lon1) + lat2_radians = get_radians(lat2) + lon2_radians = get_radians(lon2) lat_diff = lat2_radians - lat1_radians lon_diff = lon2_radians - lon1_radians - a = tf.math.pow(tf.math.sin(lat_diff / 2.0), 2.0) + tf.math.cos( - lat1_radians - ) * tf.math.cos(lat2_radians) * tf.math.pow(tf.math.sin(lon_diff / 2.0), 2.0) - c = 2.0 * tf.math.asin(pow(a, 0.5)) - # Radius of earth in kilometers. - r = tf.constant(self.earth_radius, dtype=c.dtype) + a = ops.power(ops.sin(lat_diff / 2.0), 2.0) + ops.cos(lat1_radians) * ops.cos( + lat2_radians + ) * ops.power(ops.sin(lon_diff / 2.0), 2.0) + c = 2.0 * ops.arcsin(ops.power(a, 0.5)) + # Radius of earth in kilometers or miles + r = ops.convert_to_tensor(self.earth_radius, dtype=c.dtype) return c * r @enforce_multiple_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: """ Computes the haversine distance between two lat/lon pairs. 
@@ -140,8 +131,8 @@ def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: return self.compute_haversine_distance( inputs[0], inputs[1], - tf.constant(self.lat_lon_constant[0]), - tf.constant(self.lat_lon_constant[1]), + ops.convert_to_tensor(self.lat_lon_constant[0]), + ops.convert_to_tensor(self.lat_lon_constant[1]), ) else: if not isinstance(inputs, list) or len(inputs) != 4: diff --git a/src/kamae/tensorflow/layers/identity.py b/src/kamae/keras/core/layers/identity.py similarity index 73% rename from src/kamae/tensorflow/layers/identity.py rename to src/kamae/keras/core/layers/identity.py index 5588eb7e..1b7e39e1 100644 --- a/src/kamae/tensorflow/layers/identity.py +++ b/src/kamae/keras/core/layers/identity.py @@ -14,21 +14,24 @@ from typing import Any, Dict, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class IdentityLayer(BaseLayer): """ Performs an identity transform on the input tensor. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -48,16 +51,16 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. - :returns: The compatible dtypes of the layer. 
+ :returns: None (all dtypes are compatible) """ return None @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs an identity transform on the input tensor. @@ -65,10 +68,12 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: the input is a single tensor. Raises an error if multiple tensors are passed in as an iterable. - :param inputs: Tensor to be apply the identity transform to. + :param inputs: Tensor to apply the identity transform to. :returns: The input tensor. """ - return tf.identity(inputs) + # For identity, simply return the input unchanged + # Note: keras.ops.identity() exists but has bugs in TensorFlow backend + return inputs def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/impute.py b/src/kamae/keras/core/layers/impute.py similarity index 76% rename from src/kamae/tensorflow/layers/impute.py rename to src/kamae/keras/core/layers/impute.py index d16b799f..4cf4696c 100644 --- a/src/kamae/tensorflow/layers/impute.py +++ b/src/kamae/keras/core/layers/impute.py @@ -14,25 +14,30 @@ from typing import Any, Dict, List, Optional, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class ImputeLayer(BaseLayer): """ Performs imputation on the input. + Where the input data is equal to the specified mask value, this layer will replace the data with the impute value calculated at preprocessing time. 
+ The impute value is either the mean or median and is computed while ignoring rows in the data which are equal to the mask value or are null. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, impute_value: Union[float, str, int], @@ -44,12 +49,13 @@ def __init__( ) -> None: """ Initialise the ImputeLayer layer. + :param impute_value: The value to use for imputation. + :param mask_value: Value which should be replaced by the + impute value at inference. :param name: The name of the layer. Defaults to `None`. :param input_dtype: The dtype to cast the input to. Defaults to `None`. :param output_dtype: The dtype to cast the output to. Defaults to `None`. - :param mask_value: Value which should be replaced by the - impute value at inference. """ super().__init__( name=name, input_dtype=input_dtype, output_dtype=output_dtype, **kwargs @@ -62,7 +68,7 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. @@ -71,18 +77,22 @@ def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: return None @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ - Performs imputation on the input tensor(s) by calling the keras - ImputeLayer layer. It imputes over values which are equal to the - mask_value. + Performs imputation on the input tensor(s). It imputes over values which + are equal to the mask_value. + Decorated with `@enforce_single_tensor_input` to ensure that the input is a single tensor. Raises an error if multiple tensors are passed in as an iterable. + :param inputs: Input tensor to perform the imputation on. :returns: The input tensor with the imputation applied. 
""" - if inputs.dtype.is_floating or inputs.dtype.is_integer: + input_dtype_str = keras.backend.standardize_dtype(inputs.dtype) + + # Check if dtype is numeric (floating or integer) + if "float" in input_dtype_str or "int" in input_dtype_str: inputs, mask = self._force_cast_to_compatible_numeric_type( inputs, self.mask_value ) @@ -90,11 +100,14 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: inputs, self.impute_value ) else: - mask = self._cast(tf.constant(self.mask_value), inputs.dtype.name) - impute_value = self._cast(tf.constant(self.impute_value), inputs.dtype.name) + # For non-numeric types (like strings) + mask = self._cast(ops.convert_to_tensor(self.mask_value), input_dtype_str) + impute_value = self._cast( + ops.convert_to_tensor(self.impute_value), input_dtype_str + ) - mask = tf.equal(inputs, mask) - imputed_outputs = tf.where( + mask = ops.equal(inputs, mask) + imputed_outputs = ops.where( mask, impute_value, inputs, diff --git a/src/kamae/tensorflow/layers/log.py b/src/kamae/keras/core/layers/log.py similarity index 76% rename from src/kamae/tensorflow/layers/log.py rename to src/kamae/keras/core/layers/log.py index 0e8f7d09..13ff77e1 100644 --- a/src/kamae/tensorflow/layers/log.py +++ b/src/kamae/keras/core/layers/log.py @@ -14,21 +14,24 @@ from typing import Any, Dict, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class LogLayer(BaseLayer): """ - Performs the log(alpha + x) operation on a given input tensor + Performs the log(alpha + x) 
operation on a given input tensor. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -52,23 +55,23 @@ def __init__( self.alpha = alpha @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. - :returns: The compatible dtypes of the layer. + :returns: List of compatible dtype names """ return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.complex64, - tf.complex128, + "bfloat16", + "float16", + "float32", + "float64", + "complex64", + "complex128", ] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the log(alpha + x) operation on a given input tensor @@ -79,11 +82,11 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: :param inputs: Input tensor to perform the log(alpha + x) operation on. :returns: The input tensor with the log(alpha + x) operation applied. """ - return tf.math.log(tf.math.add(inputs, self.alpha)) + return ops.log(ops.add(inputs, self.alpha)) def get_config(self) -> Dict[str, Any]: """ - Gets the configuration of the LogAlphaP layer. + Gets the configuration of the Log layer. Used for saving and loading from a model. Specifically adds the `alpha` value to the configuration. 
diff --git a/src/kamae/tensorflow/layers/logical_and.py b/src/kamae/keras/core/layers/logical_and.py similarity index 81% rename from src/kamae/tensorflow/layers/logical_and.py rename to src/kamae/keras/core/layers/logical_and.py index 53e8e836..03acd29d 100644 --- a/src/kamae/tensorflow/layers/logical_and.py +++ b/src/kamae/keras/core/layers/logical_and.py @@ -15,21 +15,24 @@ from functools import reduce from typing import Any, Dict, Iterable, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_multiple_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class LogicalAndLayer(BaseLayer): """ Performs the and(x, y) operation on a given input tensor. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -49,16 +52,16 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" - return [tf.bool] + return ["bool"] @enforce_multiple_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: """ Performs the and(x, y) operation on an iterable of input tensors @@ -71,7 +74,7 @@ def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: """ if len(inputs) == 1: raise ValueError("Expected multiple inputs, but got a single input") - return reduce(tf.math.logical_and, inputs) + return reduce(ops.logical_and, inputs) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/logical_not.py b/src/kamae/keras/core/layers/logical_not.py similarity index 80% rename from src/kamae/tensorflow/layers/logical_not.py rename to src/kamae/keras/core/layers/logical_not.py index 8f907b60..bd9a9f75 100644 --- a/src/kamae/tensorflow/layers/logical_not.py +++ b/src/kamae/keras/core/layers/logical_not.py @@ -14,21 +14,24 @@ from typing import Any, Dict, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class LogicalNotLayer(BaseLayer): """ Performs the not operation on a given input tensor. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -48,16 +51,16 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. 
:returns: The compatible dtypes of the layer. """ - return [tf.bool] + return ["bool"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the not operation on a single input tensor @@ -68,7 +71,7 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: :param inputs: Input tensor to perform the not operation on. :returns: The tensor resulting from the or(x, y) operation. """ - return tf.math.logical_not(inputs) + return ops.logical_not(inputs) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/logical_or.py b/src/kamae/keras/core/layers/logical_or.py similarity index 81% rename from src/kamae/tensorflow/layers/logical_or.py rename to src/kamae/keras/core/layers/logical_or.py index 5c043262..92ee53bf 100644 --- a/src/kamae/tensorflow/layers/logical_or.py +++ b/src/kamae/keras/core/layers/logical_or.py @@ -15,21 +15,24 @@ from functools import reduce from typing import Any, Dict, Iterable, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_multiple_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class LogicalOrLayer(BaseLayer): """ Performs the or(x, y) operation on a given input tensor. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -49,16 +52,16 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.bool] + return ["bool"] @enforce_multiple_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: """ Performs the or(x, y) operation on an iterable of input tensors @@ -71,7 +74,7 @@ def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: """ if len(inputs) == 1: raise ValueError("Expected multiple inputs, but got a single input") - return reduce(tf.math.logical_or, inputs) + return reduce(ops.logical_or, inputs) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/max.py b/src/kamae/keras/core/layers/max.py similarity index 73% rename from src/kamae/tensorflow/layers/max.py rename to src/kamae/keras/core/layers/max.py index e29bb7bf..81784f3c 100644 --- a/src/kamae/tensorflow/layers/max.py +++ b/src/kamae/keras/core/layers/max.py @@ -15,24 +15,28 @@ from functools import reduce from typing import Any, Dict, Iterable, List, Optional, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class MaxLayer(BaseLayer): """ Performs the 
max(x, y) operation on a given input tensor. + If max_constant is not set, inputs are assumed to be a list of tensors and the max of all the tensors is computed. If max_constant is set, inputs must be a tensor. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -55,39 +59,34 @@ def __init__( self.max_constant = max_constant @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.int8, - tf.uint8, - tf.int16, - tf.uint16, - tf.int32, - tf.uint32, - tf.int64, - tf.uint64, + "bfloat16", + "float16", + "float32", + "float64", + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", ] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tensor: + def _call( + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> KerasTensor: """ - Performs the max(x, y) operation on either an iterable of input tensors or - a single input tensor and a constant. - - Decorated with `@allow_single_or_multiple_tensor_input` to ensure that the input - is either a single tensor or an iterable of tensors. Returns this result as a - list of tensors for easier use here. - :param inputs: Single tensor or iterable of tensors to perform the - max(x, y) operation on. + max(x, y) operation on. :returns: The tensor resulting from the max(x, y) operation. 
""" if self.max_constant is not None: @@ -96,7 +95,7 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso cast_input, cast_max_constant = self._force_cast_to_compatible_numeric_type( inputs[0], self.max_constant ) - return tf.math.maximum( + return ops.maximum( cast_input, cast_max_constant, ) @@ -105,7 +104,7 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso raise ValueError( "If max_constant is not set, must have multiple inputs" ) - return reduce(tf.math.maximum, inputs) + return reduce(ops.maximum, inputs) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/mean.py b/src/kamae/keras/core/layers/mean.py similarity index 72% rename from src/kamae/tensorflow/layers/mean.py rename to src/kamae/keras/core/layers/mean.py index 07114da0..72888d15 100644 --- a/src/kamae/tensorflow/layers/mean.py +++ b/src/kamae/keras/core/layers/mean.py @@ -15,24 +15,28 @@ from functools import reduce from typing import Any, Dict, Iterable, List, Optional, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class MeanLayer(BaseLayer): """ Performs the mean(x, y) operation on a given input tensor. + If mean_constant is not set, inputs are assumed to be a list of tensors and the mean of all the tensors is computed. If mean_constant is set, inputs must be a tensor. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -56,39 +60,34 @@ def __init__( self.mean_constant = mean_constant @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.int8, - tf.uint8, - tf.int16, - tf.uint16, - tf.int32, - tf.uint32, - tf.int64, - tf.uint64, + "bfloat16", + "float16", + "float32", + "float64", + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", ] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tensor: + def _call( + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> KerasTensor: """ - Performs the mean(x, y) operation on either an iterable of input tensors or - a single input tensor and a constant. - - Decorated with `@allow_single_or_multiple_tensor_input` to ensure that the input - is either a single tensor or an iterable of tensors. Returns this result as a - list of tensors for easier use here. - :param inputs: Single tensor or iterable of tensors to perform the - mean(x, y) operation on. + mean(x, y) operation on. :returns: The tensor resulting from the mean(x, y) operation. 
""" if self.mean_constant is not None: @@ -101,14 +100,14 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso inputs[0], self.mean_constant, ) - return tf.truediv(tf.math.add(cast_input, cast_mean_constant), 2) + return ops.true_divide(ops.add(cast_input, cast_mean_constant), 2) else: if not len(inputs) > 1: raise ValueError( "If mean_constant is not set, must have multiple inputs" ) - return tf.truediv(reduce(tf.math.add, inputs), len(inputs)) + return ops.true_divide(reduce(ops.add, inputs), len(inputs)) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/min.py b/src/kamae/keras/core/layers/min.py similarity index 73% rename from src/kamae/tensorflow/layers/min.py rename to src/kamae/keras/core/layers/min.py index 7d95cd9b..5c08f7d2 100644 --- a/src/kamae/tensorflow/layers/min.py +++ b/src/kamae/keras/core/layers/min.py @@ -15,24 +15,28 @@ from functools import reduce from typing import Any, Dict, Iterable, List, Optional, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class MinLayer(BaseLayer): """ Performs the min(x, y) operation on a given input tensor. + If min_constant is not set, inputs are assumed to be a list of tensors and the min of all the tensors is computed. If min_constant is set, inputs must be a tensor. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -55,39 +59,34 @@ def __init__( self.min_constant = min_constant @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.int8, - tf.uint8, - tf.int16, - tf.uint16, - tf.int32, - tf.uint32, - tf.int64, - tf.uint64, + "bfloat16", + "float16", + "float32", + "float64", + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", ] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tensor: + def _call( + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> KerasTensor: """ - Performs the min(x, y) operation on either an iterable of input tensors or - a single input tensor and a constant. - - Decorated with `@allow_single_or_multiple_tensor_input` to ensure that the input - is either a single tensor or an iterable of tensors. Returns this result as a - list of tensors for easier use here. - :param inputs: Single tensor or iterable of tensors to perform the - min(x, y) operation on. + min(x, y) operation on. :returns: The tensor resulting from the min(x, y) operation. 
""" if self.min_constant is not None: @@ -96,7 +95,7 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso cast_input, cast_min_constant = self._force_cast_to_compatible_numeric_type( inputs[0], self.min_constant ) - return tf.math.minimum( + return ops.minimum( cast_input, cast_min_constant, ) @@ -106,7 +105,7 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso "If min_constant is not set, must have multiple inputs" ) - return reduce(tf.math.minimum, inputs) + return reduce(ops.minimum, inputs) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/min_max_scale.py b/src/kamae/keras/core/layers/min_max_scale.py similarity index 71% rename from src/kamae/tensorflow/layers/min_max_scale.py rename to src/kamae/keras/core/layers/min_max_scale.py index b52832f8..370d73a5 100644 --- a/src/kamae/tensorflow/layers/min_max_scale.py +++ b/src/kamae/keras/core/layers/min_max_scale.py @@ -14,26 +14,32 @@ from typing import Any, Dict, List, Optional, Tuple, Union +import keras import numpy as np -import tensorflow as tf +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input, listify_tensors +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input +from kamae.keras.core.utils.ops_utils import divide_no_nan +from kamae.keras.core.utils.tensor_utils import listify_tensors -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class MinMaxScaleLayer(BaseLayer): """ Performs a min-max scaling operation on the input tensor(s). + This is used to standardize/transform the input tensor to the range [0, 1] using the minimum and maximum values. 
Formula: (x - min)/(max - min) """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, min: Union[List[float], np.array], @@ -46,7 +52,8 @@ def __init__( **kwargs: Any, ) -> None: """ - Intialise the MinMaxScaleLayer layer. + Initialise the MinMaxScaleLayer layer. + :param min: The min value(s) to use during scaling. :param max: The max value(s) to use during scaling. :param name: The name of the layer. Defaults to `None`. @@ -62,7 +69,8 @@ def __init__( input_dtype=input_dtype, output_dtype=output_dtype, **kwargs, - ) # Standardize `axis` to a tuple. + ) + # Standardize `axis` to a tuple. if axis is None: axis = () elif isinstance(axis, int): @@ -76,13 +84,13 @@ def __init__( self.mask_value = mask_value @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.bfloat16, tf.float16, tf.float32, tf.float64] + return ["bfloat16", "float16", "float32", "float64"] def build(self, input_shape: Tuple[int]) -> None: """ @@ -96,18 +104,23 @@ def build(self, input_shape: Tuple[int]) -> None: """ super().build(input_shape) - if isinstance(input_shape, (list, tuple)) and all( - isinstance(shape, (tf.TensorShape, list, tuple)) for shape in input_shape - ): - # This seems to be needed to handle sending in multiple inputs as a list. - # Although this layer should only have one input, so this is a bit of a - # hack. We catch this nicely in call method with a decorator. Maybe we - # should do the same here? 
- input_shape = input_shape[0] + # Save the original input_shape for serialization + # Store as tuple to ensure consistent format + if isinstance(input_shape, (list, tuple)): + self._build_input_shape = tuple(input_shape) + else: + self._build_input_shape = input_shape + + # Ensure input_shape is a list for easier manipulation + if not isinstance(input_shape, list): + input_shape = list(input_shape) + + # Handle Keras serialization quirk: when a tuple like (100, 10, 5) is saved + # and deserialized, Keras may wrap it as [(100, 10, 5)] + if len(input_shape) == 1 and isinstance(input_shape[0], (list, tuple)): + input_shape = list(input_shape[0]) - input_shape = tf.TensorShape(input_shape).as_list() ndim = len(input_shape) - self._build_input_shape = input_shape if any(a < -ndim or a >= ndim for a in self.axis): raise ValueError( @@ -128,11 +141,17 @@ def build(self, input_shape: Tuple[int]) -> None: ) # Broadcast any reduced axes. broadcast_shape = [input_shape[d] if d in keep_axis else 1 for d in range(ndim)] - min_and_max_shape = tuple(input_shape[d] for d in keep_axis) + # Extract shape dimensions - handle both int and tuple (e.g., 5 or (5,)) + min_and_max_shape = tuple( + int(input_shape[d][0]) + if isinstance(input_shape[d], tuple) + else int(input_shape[d]) + for d in keep_axis + ) min_tensor = self.input_min * np.ones(min_and_max_shape) max_tensor = self.input_max * np.ones(min_and_max_shape) - self.min = tf.reshape(min_tensor, broadcast_shape) - self.max = tf.reshape(max_tensor, broadcast_shape) + self.min = ops.reshape(min_tensor, broadcast_shape) + self.max = ops.reshape(max_tensor, broadcast_shape) def get_config(self) -> Dict[str, Any]: """ @@ -142,12 +161,13 @@ def get_config(self) -> Dict[str, Any]: :returns: Dictionary of the configuration of the layer. """ config = super().get_config() - # Ensure mean and variance are lists for serialization. + # Ensure min and max are lists for serialization. 
config.update( { "min": listify_tensors(self.input_min), "max": listify_tensors(self.input_max), "axis": self.axis, + "mask_value": self.mask_value, } ) return config @@ -177,25 +197,30 @@ def build_from_config(self, config: Dict[str, Any]) -> None: self.build(config["input_shape"]) @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs normalization on the input tensor(s) to scale it to the range [0, 1] + Decorated with `@enforce_single_tensor_input` to ensure that the input is a single tensor. Raises an error if multiple tensors are passed in as an iterable. + :param inputs: Input tensor to perform the normalization on. :returns: The input tensor with the normalization applied. """ # Ensure min and max match input dtype. - min_tensor = self._cast(self.min, inputs.dtype.name) - max_tensor = self._cast(self.max, inputs.dtype.name) - normalized_outputs = tf.math.divide_no_nan( - tf.math.subtract(inputs, min_tensor), - tf.math.subtract(max_tensor, min_tensor), - ) + input_dtype_str = keras.backend.standardize_dtype(inputs.dtype) + min_tensor = self._cast(self.min, input_dtype_str) + max_tensor = self._cast(self.max, input_dtype_str) + + # Compute (input - min) / (max - min) using safe division + numerator = ops.subtract(inputs, min_tensor) + denominator = ops.subtract(max_tensor, min_tensor) + normalized_outputs = divide_no_nan(numerator, denominator) + if self.mask_value is not None: - mask = tf.equal(inputs, self.mask_value) - normalized_outputs = tf.where( - mask, inputs, self._cast(normalized_outputs, inputs.dtype.name) + mask = ops.equal(inputs, self.mask_value) + normalized_outputs = ops.where( + mask, inputs, self._cast(normalized_outputs, input_dtype_str) ) return normalized_outputs diff --git a/src/kamae/tensorflow/layers/modulo.py b/src/kamae/keras/core/layers/modulo.py similarity index 78% rename from src/kamae/tensorflow/layers/modulo.py 
rename to src/kamae/keras/core/layers/modulo.py index 5f408454..2b919ca4 100644 --- a/src/kamae/tensorflow/layers/modulo.py +++ b/src/kamae/keras/core/layers/modulo.py @@ -14,16 +14,16 @@ from typing import Any, Dict, Iterable, List, Optional, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class ModuloLayer(BaseLayer): """ Performs the modulo(x, y) operation on a given input tensor. @@ -32,6 +32,9 @@ class ModuloLayer(BaseLayer): If divisor is set, inputs must be a tensor. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -54,29 +57,31 @@ def __init__( self.divisor = divisor @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" return [ - tf.int8, - tf.int16, - tf.int32, - tf.int64, - tf.uint8, - tf.uint16, - tf.uint32, - tf.uint64, - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "bfloat16", + "float16", + "float32", + "float64", ] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tensor: + def _call( + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> KerasTensor: """ Performs the modulo(x, y) operation on either an iterable of input tensors or a single input tensor and a constant. @@ -95,14 +100,11 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso cast_input, cast_divisor = self._force_cast_to_compatible_numeric_type( inputs[0], self.divisor ) - return tf.math.floormod( - cast_input, - cast_divisor, - ) + return ops.mod(cast_input, cast_divisor) else: if len(inputs) != 2: raise ValueError("If divisor is not set, must have exactly 2 inputs") - return tf.math.floormod(inputs[0], inputs[1]) + return ops.mod(inputs[0], inputs[1]) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/multiply.py b/src/kamae/keras/core/layers/multiply.py similarity index 78% rename from src/kamae/tensorflow/layers/multiply.py rename to src/kamae/keras/core/layers/multiply.py index b93432f2..85991c3a 100644 --- a/src/kamae/tensorflow/layers/multiply.py +++ b/src/kamae/keras/core/layers/multiply.py @@ -15,16 +15,16 @@ from functools import reduce from typing import Any, Dict, Iterable, List, Optional, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import 
allow_single_or_multiple_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class MultiplyLayer(BaseLayer): """ Performs the multiply(x, y) operation on a given input tensor. @@ -32,6 +32,9 @@ class MultiplyLayer(BaseLayer): If multiplier is set, inputs must be a tensor. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -54,29 +57,31 @@ def __init__( self.multiplier = multiplier @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. - :returns: The compatible dtypes of the layer. + :returns: List of compatible dtype names """ return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.uint8, - tf.int8, - tf.uint16, - tf.int16, - tf.int32, - tf.int64, - tf.complex64, - tf.complex128, + "bfloat16", + "float16", + "float32", + "float64", + "uint8", + "int8", + "uint16", + "int16", + "int32", + "int64", + "complex64", + "complex128", ] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tensor: + def _call( + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> KerasTensor: """ Performs the multiply(x, y) operation on either an iterable of input tensors or a single input tensor and a constant. 
@@ -95,7 +100,7 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso cast_input, cast_multiplier = self._force_cast_to_compatible_numeric_type( inputs[0], self.multiplier ) - return tf.math.multiply( + return ops.multiply( cast_input, cast_multiplier, ) @@ -103,7 +108,7 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso if not len(inputs) > 1: raise ValueError("If multiplier is not set, must have multiple inputs") - return reduce(tf.math.multiply, inputs) + return reduce(ops.multiply, inputs) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/numerical_if_statement.py b/src/kamae/keras/core/layers/numerical_if_statement.py similarity index 87% rename from src/kamae/tensorflow/layers/numerical_if_statement.py rename to src/kamae/keras/core/layers/numerical_if_statement.py index 940a151b..6b2b5dbe 100644 --- a/src/kamae/tensorflow/layers/numerical_if_statement.py +++ b/src/kamae/keras/core/layers/numerical_if_statement.py @@ -14,18 +14,17 @@ from typing import Any, Dict, Iterable, List, Optional, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input from kamae.utils import get_condition_operator -from .base import BaseLayer - -# TODO: Deprecate this in favor of IfStatementLayer in next major release. 
-@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class NumericalIfStatementLayer(BaseLayer): """ Performs a numerical if statement on the input tensor, @@ -50,6 +49,9 @@ class NumericalIfStatementLayer(BaseLayer): not None, then inputs is expected to be a tensor. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, condition_operator: str, @@ -88,17 +90,17 @@ def __init__( self.result_if_false = result_if_false @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.bfloat16, tf.float16, tf.float32, tf.float64] + return ["bfloat16", "float16", "float32", "float64"] def _construct_input_tensors( - self, inputs: Iterable[tf.Tensor] - ) -> Iterable[tf.Tensor]: + self, inputs: Iterable[KerasTensor] + ) -> Iterable[KerasTensor]: """ Constructs the input tensors for the layer in the case where all the optional parameters are not specified. We need to run through the provided inputs and @@ -132,11 +134,15 @@ def _construct_input_tensors( else: # Otherwise, we create a constant tensor for the parameter # and do not increment the counter. - multiple_inputs.append(tf.constant(param, dtype=inputs[0].dtype)) + multiple_inputs.append( + ops.convert_to_tensor(param, dtype=inputs[0].dtype) + ) return multiple_inputs @allow_single_or_multiple_tensor_input - def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tensor: + def _call( + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> KerasTensor: """ Performs the numerical if statement on the inputs. 
If the inputs are a tensor, we assume that the value_to_compare, result_if_true, and result_if_false are @@ -168,17 +174,17 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso "If inputs is a tensor, value_to_compare, result_if_true, and " "result_if_false must be specified." ) - cond = tf.where( + cond = ops.where( condition_op(inputs[0], self.value_to_compare), - tf.constant(self.result_if_true, dtype=inputs[0].dtype), - tf.constant(self.result_if_false, dtype=inputs[0].dtype), + ops.convert_to_tensor(self.result_if_true, dtype=inputs[0].dtype), + ops.convert_to_tensor(self.result_if_false, dtype=inputs[0].dtype), ) return cond else: # If the input is a list, we assume that the value_to_compare, # result_if_true, and result_if_false are potentially provided in the inputs input_tensors = self._construct_input_tensors(inputs) - cond = tf.where( + cond = ops.where( condition_op(input_tensors[0], input_tensors[1]), input_tensors[2], input_tensors[3], diff --git a/src/kamae/tensorflow/layers/pairwise_cosine_similarity.py b/src/kamae/keras/core/layers/pairwise_cosine_similarity.py similarity index 64% rename from src/kamae/tensorflow/layers/pairwise_cosine_similarity.py rename to src/kamae/keras/core/layers/pairwise_cosine_similarity.py index c3664a64..8d2ce280 100644 --- a/src/kamae/tensorflow/layers/pairwise_cosine_similarity.py +++ b/src/kamae/keras/core/layers/pairwise_cosine_similarity.py @@ -14,16 +14,17 @@ from typing import Any, Dict, Iterable, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_multiple_tensor_input +from kamae.keras.core.utils.ops_utils import l2_normalize -from .base import BaseLayer - 
-@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class PairwiseCosineSimilarityLayer(BaseLayer): """ Computes pairwise cosine similarity between a query embedding and @@ -34,6 +35,9 @@ class PairwiseCosineSimilarityLayer(BaseLayer): Output: (..., N) -- cosine similarity per candidate """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -48,11 +52,16 @@ def __init__( self.embedding_dim = embedding_dim @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: - return [tf.bfloat16, tf.float16, tf.float32, tf.float64] + def compatible_dtypes(self) -> Optional[List[str]]: + return [ + "bfloat16", + "float16", + "float32", + "float64", + ] @enforce_multiple_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: if len(inputs) != 2: raise ValueError(f"Expected 2 inputs, received {len(inputs)} instead.") @@ -60,27 +69,25 @@ def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: flat_candidates = inputs[1] # (..., N*D) # Reshape: (..., N*D) -> (..., N, D) - orig_shape = tf.shape(flat_candidates) + orig_shape = ops.shape(flat_candidates) num_candidates = orig_shape[-1] // self.embedding_dim - new_shape = tf.concat( - [orig_shape[:-1], [num_candidates, self.embedding_dim]], axis=0 - ) - candidates = tf.reshape(flat_candidates, new_shape) + new_shape = list(orig_shape[:-1]) + [num_candidates, self.embedding_dim] + candidates = ops.reshape(flat_candidates, new_shape) # (..., D) -> (..., 1, D) for broadcasting - query_expanded = tf.expand_dims(query, axis=-2) + query_expanded = ops.expand_dims(query, axis=-2) # L2 normalize along embedding dimension - q_norm = tf.nn.l2_normalize(query_expanded, axis=-1) - c_norm = tf.nn.l2_normalize(candidates, axis=-1) + q_norm = l2_normalize(query_expanded, 
axis=-1) + c_norm = l2_normalize(candidates, axis=-1) # Dot product along last axis: (..., N) - similarities = tf.reduce_sum(tf.multiply(q_norm, c_norm), axis=-1) + similarities = ops.sum(ops.multiply(q_norm, c_norm), axis=-1) # Zero-vector → NaN from normalization → replace with 0.0 - return tf.where( - tf.math.is_nan(similarities), - tf.zeros_like(similarities), + return ops.where( + ops.isnan(similarities), + ops.zeros_like(similarities), similarities, ) diff --git a/src/kamae/tensorflow/layers/round.py b/src/kamae/keras/core/layers/round.py similarity index 83% rename from src/kamae/tensorflow/layers/round.py rename to src/kamae/keras/core/layers/round.py index a11d0616..74eaeeea 100644 --- a/src/kamae/tensorflow/layers/round.py +++ b/src/kamae/keras/core/layers/round.py @@ -14,16 +14,16 @@ from typing import Any, Dict, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class RoundLayer(BaseLayer): """ Performs a standard rounding operation on the input tensor. @@ -34,6 +34,9 @@ class RoundLayer(BaseLayer): - 'round' rounds to the nearest integer. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, round_type: str = "round", @@ -59,16 +62,16 @@ def __init__( self.round_type = round_type @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" - return [tf.float16, tf.float32, tf.float64] + return ["float16", "float32", "float64"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the rounding operation on the input tensor. @@ -79,11 +82,11 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: :returns: The input tensor with the rounding applied. """ if self.round_type == "ceil": - return tf.math.ceil(inputs) + return ops.ceil(inputs) elif self.round_type == "floor": - return tf.math.floor(inputs) + return ops.floor(inputs) elif self.round_type == "round": - return tf.math.round(inputs) + return ops.round(inputs) else: raise ValueError("""roundType must be one of 'ceil', 'floor' or 'round'.""") diff --git a/src/kamae/tensorflow/layers/round_to_decimal.py b/src/kamae/keras/core/layers/round_to_decimal.py similarity index 79% rename from src/kamae/tensorflow/layers/round_to_decimal.py rename to src/kamae/keras/core/layers/round_to_decimal.py index 503676d3..c61e275c 100644 --- a/src/kamae/tensorflow/layers/round_to_decimal.py +++ b/src/kamae/keras/core/layers/round_to_decimal.py @@ -14,16 +14,17 @@ from typing import Any, Dict, List, Optional -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input +from kamae.keras.core.utils.tensor_utils import get_dtype_max -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class RoundToDecimalLayer(BaseLayer): """ Performs a rounding to the nearest decimal operation on the input tensor. 
@@ -35,6 +36,9 @@ class RoundToDecimalLayer(BaseLayer): number of decimals. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, decimals: int = 1, @@ -59,16 +63,16 @@ def __init__( self.decimals = decimals @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.float16, tf.float32, tf.float64, tf.int32, tf.int64] + return ["float16", "float32", "float64", "int32", "int64"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the rounding operation on the input tensor. @@ -80,14 +84,16 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: """ # WARNING: Depending on the type of the input and the number of decimals, # this multiplier could overflow. - max_val = inputs.dtype.max + dtype_str = keras.backend.standardize_dtype(inputs.dtype) + max_val = get_dtype_max(dtype_str) + if 10**self.decimals > max_val: raise ValueError( """The number of decimals is too large for the input dtype. 
Overflow expected.""" ) - multiplier = tf.constant(10**self.decimals, dtype=inputs.dtype) - return tf.round(inputs * multiplier) / multiplier + multiplier = ops.cast(10**self.decimals, dtype=inputs.dtype) + return ops.divide(ops.round(ops.multiply(inputs, multiplier)), multiplier) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/standard_scale.py b/src/kamae/keras/core/layers/standard_scale.py similarity index 78% rename from src/kamae/tensorflow/layers/standard_scale.py rename to src/kamae/keras/core/layers/standard_scale.py index b582e601..00e719a7 100644 --- a/src/kamae/tensorflow/layers/standard_scale.py +++ b/src/kamae/keras/core/layers/standard_scale.py @@ -14,18 +14,22 @@ from typing import Any, Dict, List, Optional, Union +import keras import numpy as np -import tensorflow as tf +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import NormalizeLayer, enforce_single_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input +from kamae.keras.core.utils.normalize_layer import NormalizeLayer +from kamae.keras.core.utils.ops_utils import divide_no_nan -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class StandardScaleLayer(NormalizeLayer): """ Performs the standard scaling of the input. + This layer will shift and scale inputs into a distribution centered around 0 with standard deviation 1. It accomplishes this by precomputing the mean and variance of the data, and calling `(input - mean) / sqrt(var)` at @@ -34,6 +38,9 @@ class StandardScaleLayer(NormalizeLayer): the input value. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, mean: Union[List[float], np.array], @@ -46,7 +53,8 @@ def __init__( **kwargs: Any, ) -> None: """ - Intialise the StandardScaleLayer layer. + Initialise the StandardScaleLayer layer. + :param mean: The mean value(s) to use during normalization. The passed value(s) will be broadcast to the shape of the kept axes above; if the value(s) cannot be broadcast, an error will be raised when this layer's @@ -83,28 +91,34 @@ def __init__( self.mask_value = mask_value @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ - Performs normalization on the input tensor(s) by calling the keras - StandardScaleLayer layer. It ignores values which are equal to the - mask_value. + Performs normalization on the input tensor(s). It ignores values which + are equal to the mask_value. + Decorated with `@enforce_single_tensor_input` to ensure that the input is a single tensor. Raises an error if multiple tensors are passed in as an iterable. + :param inputs: Input tensor to perform the normalization on. :returns: The input tensor with the normalization applied. """ # Ensure mean and variance match input dtype. 
- mean = self._cast(self.mean, inputs.dtype.name) - variance = self._cast(self.variance, inputs.dtype.name) - normalized_outputs = tf.math.divide_no_nan( - tf.math.subtract(inputs, mean), - tf.math.maximum(tf.sqrt(variance), tf.constant(1e-8, dtype=inputs.dtype)), + input_dtype_str = keras.backend.standardize_dtype(inputs.dtype) + mean = self._cast(self.mean, input_dtype_str) + variance = self._cast(self.variance, input_dtype_str) + + # Compute (input - mean) / sqrt(variance) using safe division + numerator = ops.subtract(inputs, mean) + denominator = ops.maximum( + ops.sqrt(variance), ops.convert_to_tensor(1e-8, dtype=inputs.dtype) ) + normalized_outputs = divide_no_nan(numerator, denominator) + if self.mask_value is not None: - mask = tf.equal(inputs, self.mask_value) - normalized_outputs = tf.where( - mask, inputs, self._cast(normalized_outputs, inputs.dtype.name) + mask = ops.equal(inputs, self.mask_value) + normalized_outputs = ops.where( + mask, inputs, self._cast(normalized_outputs, input_dtype_str) ) return normalized_outputs @@ -116,7 +130,6 @@ def get_config(self) -> Dict[str, Any]: :returns: Dictionary of the configuration of the layer. """ config = super().get_config() - # Ensure mean and variance are lists for serialization. 
config.update( { "mask_value": self.mask_value, diff --git a/src/kamae/tensorflow/layers/subtract.py b/src/kamae/keras/core/layers/subtract.py similarity index 75% rename from src/kamae/tensorflow/layers/subtract.py rename to src/kamae/keras/core/layers/subtract.py index 393ee212..5f770e1f 100644 --- a/src/kamae/tensorflow/layers/subtract.py +++ b/src/kamae/keras/core/layers/subtract.py @@ -15,17 +15,24 @@ from functools import reduce from typing import Any, Dict, Iterable, List, Optional, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class SubtractLayer(BaseLayer): + """ + Performs the subtract(x, y) operation on a given input tensor. + """ + + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -49,31 +56,33 @@ def __init__( self.subtrahend = subtrahend @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.uint8, - tf.int8, - tf.uint16, - tf.int16, - tf.int32, - tf.int64, - tf.complex64, - tf.complex128, - tf.uint32, - tf.uint64, + "bfloat16", + "float16", + "float32", + "float64", + "uint8", + "int8", + "uint16", + "int16", + "int32", + "int64", + "complex64", + "complex128", + "uint32", + "uint64", ] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tensor: + def _call( + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> KerasTensor: """ Performs the subtract(x, y) operation on either an iterable of input tensors or a single input tensor and a constant. @@ -92,14 +101,11 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso cast_input, cast_subtrahend = self._force_cast_to_compatible_numeric_type( inputs[0], self.subtrahend ) - return tf.math.subtract( - cast_input, - cast_subtrahend, - ) + return ops.subtract(cast_input, cast_subtrahend) else: if not len(inputs) > 1: raise ValueError("If subtrahend is not set, must have multiple inputs") - return reduce(tf.math.subtract, inputs) + return reduce(ops.subtract, inputs) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/layers/sum.py b/src/kamae/keras/core/layers/sum.py similarity index 73% rename from src/kamae/tensorflow/layers/sum.py rename to src/kamae/keras/core/layers/sum.py index b09bd8ba..94dd523f 100644 --- a/src/kamae/tensorflow/layers/sum.py +++ b/src/kamae/keras/core/layers/sum.py @@ -15,23 +15,26 @@ from functools import reduce from typing import Any, Dict, Iterable, List, Optional, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import 
BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(package=kamae.__name__) +@keras.saving.register_keras_serializable(package=kamae.__name__) class SumLayer(BaseLayer): """ Performs the sum(x, y) operation on a given input tensor. - If added is not set, inputs are assumed to be a list of tensors and summed. - If added is set, inputs must be a tensor. + If addend is not set, inputs are assumed to be a list of tensors and summed. + If addend is set, inputs must be a tensor. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -52,31 +55,33 @@ def __init__( self.addend = addend @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.uint8, - tf.uint16, - tf.uint32, - tf.uint64, - tf.int8, - tf.int16, - tf.int32, - tf.int64, - tf.complex64, - tf.complex128, + "bfloat16", + "float16", + "float32", + "float64", + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "complex64", + "complex128", ] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tensor: + def _call( + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> KerasTensor: """ Performs the sum(x, y) operation on either an iterable of input tensors or a single input tensor and a constant. 
@@ -95,14 +100,11 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso cast_input, cast_addend = self._force_cast_to_compatible_numeric_type( inputs[0], self.addend ) - return tf.math.add( - cast_input, - cast_addend, - ) + return ops.add(cast_input, cast_addend) else: if not len(inputs) > 1: raise ValueError("If addend is not set, must have multiple inputs") - return reduce(tf.math.add, inputs) + return reduce(ops.add, inputs) def get_config(self) -> Dict[str, Any]: """ diff --git a/src/kamae/tensorflow/typing/__init__.py b/src/kamae/keras/core/utils/__init__.py similarity index 90% rename from src/kamae/tensorflow/typing/__init__.py rename to src/kamae/keras/core/utils/__init__.py index 2d013142..b5be51ee 100644 --- a/src/kamae/tensorflow/typing/__init__.py +++ b/src/kamae/keras/core/utils/__init__.py @@ -12,4 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .types import Tensor # noqa: F401 +""" +Utility functions for backend-agnostic Keras layers. +""" diff --git a/src/kamae/tensorflow/utils/input_utils.py b/src/kamae/keras/core/utils/input_utils.py similarity index 82% rename from src/kamae/tensorflow/utils/input_utils.py rename to src/kamae/keras/core/utils/input_utils.py index c10c27a2..f7a363ed 100644 --- a/src/kamae/tensorflow/utils/input_utils.py +++ b/src/kamae/keras/core/utils/input_utils.py @@ -12,12 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Provides utilities for tensorflow layer inputs""" +"""Multi-backend input validation utilities for backend-agnostic layers.""" + from typing import Any, Callable, Iterable, List, Union -import tensorflow as tf +import keras +from keras import KerasTensor, ops + + +def is_tensor(x: Any) -> bool: + """ + Checks if the input is a Keras tensor (backend-agnostic). + + Uses keras.ops.is_tensor() which works across all backends. 
-from kamae.tensorflow.typing import Tensor + :param x: Input to check + :returns: True if x is a Keras tensor + """ + return ops.is_tensor(x) def iter_values(x: Iterable) -> Iterable: @@ -48,15 +60,15 @@ def enforce_single_tensor_input(layer_call_method: Callable) -> Callable: def _enforce_single_tensor_input( self: Any, - inputs: Union[Tensor, Iterable[Tensor]], + inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any, - ) -> Tensor: - if tf.is_tensor(inputs): + ) -> KerasTensor: + if is_tensor(inputs): # If the inputs are a tensor, then we return the tensor. processed_inputs = inputs else: input_list = list(iter_values(inputs)) - if len(input_list) == 1 and tf.is_tensor(input_list[0]): + if len(input_list) == 1 and is_tensor(input_list[0]): # If the inputs are an iterable with a single tensor, # then we return the tensor. processed_inputs = input_list[0] @@ -85,19 +97,17 @@ def enforce_multiple_tensor_input(layer_call_method: Callable) -> Callable: def _enforce_multiple_tensor_input( self: Any, - inputs: Union[Tensor, Iterable[Tensor]], + inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any, - ) -> List[Tensor]: - if tf.is_tensor(inputs): + ) -> List[KerasTensor]: + if is_tensor(inputs): raise ValueError( """Expected inputs to be a iterable of tensors, but got a single tensor.""" ) else: input_list = list(iter_values(inputs)) - if len(input_list) > 1 and all( - [tf.is_tensor(input_tensor) for input_tensor in input_list] - ): + if len(input_list) > 1 and all([is_tensor(inp) for inp in input_list]): processed_inputs = input_list else: raise ValueError( @@ -121,14 +131,14 @@ def allow_single_or_multiple_tensor_input(layer_call_method: Callable) -> Callab def _allow_single_or_multiple_tensor_input( self: Any, - inputs: Union[Tensor, Iterable[Tensor]], + inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any, - ) -> List[Tensor]: - if tf.is_tensor(inputs): + ) -> List[KerasTensor]: + if is_tensor(inputs): processed_inputs = [inputs] 
else: input_list = list(iter_values(inputs)) - if all([tf.is_tensor(input_tensor) for input_tensor in input_list]): + if all([is_tensor(inp) for inp in input_list]): processed_inputs = input_list else: raise ValueError( diff --git a/src/kamae/tensorflow/utils/layer_utils.py b/src/kamae/keras/core/utils/normalize_layer.py similarity index 77% rename from src/kamae/tensorflow/utils/layer_utils.py rename to src/kamae/keras/core/utils/normalize_layer.py index 3962b911..4c534256 100644 --- a/src/kamae/tensorflow/utils/layer_utils.py +++ b/src/kamae/keras/core/utils/normalize_layer.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Multi-backend normalization base layer for backend-agnostic scaling operations. +""" + from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np -import tensorflow as tf +from keras import ops -from kamae.tensorflow.layers.base import BaseLayer -from kamae.tensorflow.utils import listify_tensors +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.tensor_utils import listify_tensors class NormalizeLayer(BaseLayer): @@ -26,6 +30,8 @@ class NormalizeLayer(BaseLayer): Intermediate layer for normalization layers. Reduces code duplication by providing a common interface for normalization layers. + + This is a backend-agnostic layer that works with TensorFlow, JAX, and PyTorch. """ def __init__( @@ -80,13 +86,13 @@ def __init__( self.epsilon = 1e-8 @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" - return [tf.bfloat16, tf.float16, tf.float32, tf.float64] + return ["bfloat16", "float16", "float32", "float64"] def build(self, input_shape: Tuple[int]) -> None: """ @@ -100,18 +106,23 @@ def build(self, input_shape: Tuple[int]) -> None: """ super().build(input_shape) - if isinstance(input_shape, (list, tuple)) and all( - isinstance(shape, (tf.TensorShape, list, tuple)) for shape in input_shape - ): - # This seems to be needed to handle sending in multiple inputs as a list. - # Although this layer should only have one input, so this is a bit of a - # hack. We catch this nicely in call method with a decorator. Maybe we - # should do the same here? - input_shape = input_shape[0] + # Save the original input_shape for serialization + # Store as tuple to ensure consistent format + if isinstance(input_shape, (list, tuple)): + self._build_input_shape = tuple(input_shape) + else: + self._build_input_shape = input_shape + + # Ensure input_shape is a list for easier manipulation + if not isinstance(input_shape, list): + input_shape = list(input_shape) + + # Handle Keras serialization quirk: when a tuple like (100, 10, 5) is saved + # and deserialized, Keras may wrap it as [(100, 10, 5)] + if len(input_shape) == 1 and isinstance(input_shape[0], (list, tuple)): + input_shape = list(input_shape[0]) - input_shape = tf.TensorShape(input_shape).as_list() ndim = len(input_shape) - self._build_input_shape = input_shape if any(a < -ndim or a >= ndim for a in self.axis): raise ValueError( @@ -132,15 +143,21 @@ def build(self, input_shape: Tuple[int]) -> None: ) # Broadcast any reduced axes. 
broadcast_shape = [input_shape[d] if d in keep_axis else 1 for d in range(ndim)] - mean_and_var_shape = tuple(input_shape[d] for d in keep_axis) + # Extract shape dimensions - handle both int and tuple (e.g., 5 or (5,)) + mean_and_var_shape = tuple( + int(input_shape[d][0]) + if isinstance(input_shape[d], tuple) + else int(input_shape[d]) + for d in keep_axis + ) mean = self.input_mean * np.ones(mean_and_var_shape) variance = self.input_variance * np.ones(mean_and_var_shape) - self.mean = tf.reshape(mean, broadcast_shape) - self.variance = tf.reshape(variance, broadcast_shape) + self.mean = ops.reshape(mean, broadcast_shape) + self.variance = ops.reshape(variance, broadcast_shape) def get_config(self) -> Dict[str, Any]: """ - Gets the configuration of the StandardScaleLayer layer. + Gets the configuration of the NormalizeLayer layer. Used for saving and loading from a model. Specifically adds additional parameters to the base configuration. :returns: Dictionary of the configuration of the layer. @@ -151,7 +168,7 @@ def get_config(self) -> Dict[str, Any]: { "mean": listify_tensors(self.input_mean), "variance": listify_tensors(self.input_variance), - "axis": self.axis, + "axis": list(self.axis) if self.axis else None, } ) return config diff --git a/src/kamae/keras/core/utils/ops_utils.py b/src/kamae/keras/core/utils/ops_utils.py new file mode 100644 index 00000000..178c0257 --- /dev/null +++ b/src/kamae/keras/core/utils/ops_utils.py @@ -0,0 +1,83 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Multi-backend operation utilities for backend-agnostic layers. + +Provides common operations that aren't directly available in keras.ops. +""" + +import math + +import keras +from keras import KerasTensor, ops + + +def divide_no_nan(x: KerasTensor, y: KerasTensor) -> KerasTensor: + """ + Multi-backend safe division that returns 0 where y == 0. + + This is a backend-agnostic equivalent of tf.math.divide_no_nan. + Instead of returning NaN or Inf when dividing by zero, returns 0. + + :param x: Numerator tensor + :param y: Denominator tensor + :returns: Result of x / y, with 0 where y == 0 + """ + is_zero = ops.equal(y, ops.convert_to_tensor(0.0, dtype=y.dtype)) + return ops.where(is_zero, ops.zeros_like(x), ops.divide(x, y)) + + +def get_radians(degrees: KerasTensor) -> KerasTensor: + """ + Converts degrees tensor to radians. We need to cast to float64 otherwise + pi / 180 will lose precision. + + :param degrees: Tensor of degrees. + :returns: Tensor of radians. + """ + return ops.cast(degrees, dtype="float64") * ops.convert_to_tensor( + math.pi / 180, dtype="float64" + ) + + +def get_degrees(radians: KerasTensor) -> KerasTensor: + """ + Converts radians tensor to degrees. + + :param radians: Tensor of radians. + :returns: Tensor of degrees. + """ + return ops.cast(radians, dtype="float64") * ops.convert_to_tensor( + 180 / math.pi, dtype="float64" + ) + + +def l2_normalize(x: KerasTensor, axis: int, epsilon: float = 1e-12) -> KerasTensor: + """ + L2 normalize a tensor along a specified axis. + + This is a backend-agnostic implementation of L2 normalization: + normalized = x / sqrt(sum(x^2)) + + :param x: Input tensor to normalize. + :param axis: Axis along which to normalize. + :param epsilon: Small constant to avoid division by zero. + :returns: L2-normalized tensor. 
+ """ + square_sum = ops.sum(ops.square(x), axis=axis, keepdims=True) + norm = ops.sqrt( + ops.maximum(square_sum, ops.convert_to_tensor(epsilon, dtype=x.dtype)) + ) + return x / norm diff --git a/src/kamae/tensorflow/utils/shape_utils.py b/src/kamae/keras/core/utils/shape_utils.py similarity index 64% rename from src/kamae/tensorflow/utils/shape_utils.py rename to src/kamae/keras/core/utils/shape_utils.py index 2b6b7e91..f52388c1 100644 --- a/src/kamae/tensorflow/utils/shape_utils.py +++ b/src/kamae/keras/core/utils/shape_utils.py @@ -12,17 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List +""" +Multi-backend shape utility functions for backend-agnostic operations. +""" -import tensorflow as tf +from typing import Iterable, List -from kamae.tensorflow.typing import Tensor +import keras +from keras import KerasTensor, ops -def reshape_to_equal_rank(inputs: Iterable[Tensor]) -> List[Tensor]: +def reshape_to_equal_rank(inputs: Iterable[KerasTensor]) -> List[KerasTensor]: """ Reshapes the input tensors to match the rank of the largest tensor. + This is a backend-agnostic version using keras.ops. + :param inputs: The input tensors to reshape. :return: The reshaped input tensors. 
""" @@ -31,14 +36,16 @@ def reshape_to_equal_rank(inputs: Iterable[Tensor]) -> List[Tensor]: for x in inputs: rank_diff = max_rank - len(x.shape) if rank_diff > 0: - reshape_dim = tf.concat( + # Get shape as tensor (handles both static and dynamic shapes) + shape_tensor = ops.convert_to_tensor(ops.shape(x)) + reshape_dim = ops.concatenate( [ - tf.shape(x)[:-1], - tf.ones(rank_diff, dtype=tf.int32), - tf.shape(x)[-1:], + shape_tensor[:-1], + ops.ones(rank_diff, dtype="int32"), + shape_tensor[-1:], ], axis=0, ) - x = tf.reshape(x, reshape_dim) + x = ops.reshape(x, reshape_dim) reshaped_inputs.append(x) return reshaped_inputs diff --git a/src/kamae/keras/core/utils/tensor_utils.py b/src/kamae/keras/core/utils/tensor_utils.py new file mode 100644 index 00000000..ba30cafe --- /dev/null +++ b/src/kamae/keras/core/utils/tensor_utils.py @@ -0,0 +1,57 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Multi-backend tensor utility functions for backend-agnostic operations. +""" + +from typing import Any, List, Union + +import keras +import numpy as np +from keras import ops + + +def listify_tensors(x: Union[Any, np.ndarray, List[Any]]) -> List[Any]: + """ + Converts any tensors or numpy arrays to lists for config serialization. + + Works with any backend (TensorFlow, JAX, PyTorch). + + :param x: The input tensor or numpy array. + :returns: The input as a list. 
+ """ + if hasattr(x, "numpy"): + # Most backend tensors have a .numpy() method + x = x.numpy() + if isinstance(x, np.ndarray): + x = x.tolist() + return x + + +def get_dtype_max(dtype_str: str) -> float: + """ + Get the maximum value for a given dtype using numpy's dtype info. + + :param dtype_str: Dtype string (e.g. 'float32', 'int64') + :returns: Maximum value for the dtype + """ + np_dtype = np.dtype(dtype_str) + if np.issubdtype(np_dtype, np.floating): + return np.finfo(np_dtype).max + elif np.issubdtype(np_dtype, np.integer): + return np.iinfo(np_dtype).max + else: + # Fallback for unsupported dtypes + return float("inf") diff --git a/src/kamae/keras/tensorflow/__init__.py b/src/kamae/keras/tensorflow/__init__.py new file mode 100644 index 00000000..5c2932a6 --- /dev/null +++ b/src/kamae/keras/tensorflow/__init__.py @@ -0,0 +1,22 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +TensorFlow-specific Keras layers. + +These layers require the TensorFlow backend and use TensorFlow-specific operations: +- String operations (tf.strings.*) +- Datetime parsing and manipulation +- TensorFlow-specific ops (tf.unique, tf.RaggedTensor, etc.) 
+""" diff --git a/src/kamae/tensorflow/layers/__init__.py b/src/kamae/keras/tensorflow/layers/__init__.py similarity index 60% rename from src/kamae/tensorflow/layers/__init__.py rename to src/kamae/keras/tensorflow/layers/__init__.py index da971958..f3df8d00 100644 --- a/src/kamae/tensorflow/layers/__init__.py +++ b/src/kamae/keras/tensorflow/layers/__init__.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .absolute_value import AbsoluteValueLayer # noqa: F401 -from .array_concatenate import ArrayConcatenateLayer # noqa: F401 -from .array_crop import ArrayCropLayer # noqa: F401 -from .array_reduce_max import ArrayReduceMaxLayer # noqa: F401 -from .array_split import ArraySplitLayer # noqa: F401 -from .array_subtract_minimum import ArraySubtractMinimumLayer # noqa: F401 -from .bearing_angle import BearingAngleLayer # noqa: F401 -from .bin import BinLayer # noqa: F401 +""" +TensorFlow-specific Keras layers. + +These layers use TensorFlow-specific operations and are the canonical location +for TF-only layers. All layers use the unified BaseLayer from kamae.keras.core.base. 
+""" + from .bloom_encode import BloomEncodeLayer # noqa: F401 from .bucketize import BucketizeLayer # noqa: F401 -from .conditional_standard_scale import ConditionalStandardScaleLayer # noqa: F401 -from .cosine_similarity import CosineSimilarityLayer # noqa: F401 from .current_date import CurrentDateLayer # noqa: F401 from .current_date_time import CurrentDateTimeLayer # noqa: F401 from .current_unix_timestamp import CurrentUnixTimestampLayer # noqa: F401 @@ -31,14 +28,8 @@ from .date_diff import DateDiffLayer # noqa: F401 from .date_parse import DateParseLayer # noqa: F401 from .date_time_to_unix_timestamp import DateTimeToUnixTimestampLayer # noqa: F401 -from .divide import DivideLayer # noqa: F401 -from .exp import ExpLayer # noqa: F401 -from .exponent import ExponentLayer # noqa: F401 from .hash_index import HashIndexLayer # noqa: F401 -from .haversine_distance import HaversineDistanceLayer # noqa: F401 -from .identity import IdentityLayer # noqa: F401 from .if_statement import IfStatementLayer # noqa: F401 -from .impute import ImputeLayer # noqa: F401 from .lambda_function import LambdaFunctionLayer # noqa: F401 from .list_max import ListMaxLayer # noqa: F401 from .list_mean import ListMeanLayer # noqa: F401 @@ -46,24 +37,9 @@ from .list_min import ListMinLayer # noqa: F401 from .list_rank import ListRankLayer # noqa: F401 from .list_std_dev import ListStdDevLayer # noqa: F401 -from .log import LogLayer # noqa: F401 -from .logical_and import LogicalAndLayer # noqa: F401 -from .logical_not import LogicalNotLayer # noqa: F401 -from .logical_or import LogicalOrLayer # noqa: F401 -from .max import MaxLayer # noqa: F401 -from .mean import MeanLayer # noqa: F401 -from .min import MinLayer # noqa: F401 from .min_hash_index import MinHashIndexLayer # noqa: F401 -from .min_max_scale import MinMaxScaleLayer # noqa: F401 -from .modulo import ModuloLayer # noqa: F401 -from .multiply import MultiplyLayer # noqa: F401 -from .numerical_if_statement import 
NumericalIfStatementLayer # noqa: F401 from .one_hot_encode import OneHotEncodeLayer, OneHotLayer # noqa: F401 from .ordinal_array_encode import OrdinalArrayEncodeLayer # noqa: F401 -from .pairwise_cosine_similarity import PairwiseCosineSimilarityLayer # noqa: F401 -from .round import RoundLayer # noqa: F401 -from .round_to_decimal import RoundToDecimalLayer # noqa: F401 -from .standard_scale import StandardScaleLayer # noqa: F401 from .string_affix import StringAffixLayer # noqa: F401 from .string_array_constant import StringArrayConstantLayer # noqa: F401 from .string_case import StringCaseLayer # noqa: F401 @@ -78,6 +54,44 @@ from .string_replace import StringReplaceLayer # noqa: F401 from .string_to_string_list import StringToStringListLayer # noqa: F401 from .sub_string_delim_at_index import SubStringDelimAtIndexLayer # noqa: F401 -from .subtract import SubtractLayer # noqa: F401 -from .sum import SumLayer # noqa: F401 from .unix_timestamp_to_date_time import UnixTimestampToDateTimeLayer # noqa: F401 + +__all__ = [ + "BloomEncodeLayer", + "BucketizeLayer", + "CurrentDateLayer", + "CurrentDateTimeLayer", + "CurrentUnixTimestampLayer", + "DateAddLayer", + "DateDiffLayer", + "DateParseLayer", + "DateTimeToUnixTimestampLayer", + "HashIndexLayer", + "IfStatementLayer", + "LambdaFunctionLayer", + "ListMaxLayer", + "ListMeanLayer", + "ListMedianLayer", + "ListMinLayer", + "ListRankLayer", + "ListStdDevLayer", + "MinHashIndexLayer", + "OneHotEncodeLayer", + "OneHotLayer", + "OrdinalArrayEncodeLayer", + "StringAffixLayer", + "StringArrayConstantLayer", + "StringCaseLayer", + "StringConcatenateLayer", + "StringContainsLayer", + "StringContainsListLayer", + "StringEqualsIfStatementLayer", + "StringIndexLayer", + "StringIsInListLayer", + "StringListToStringLayer", + "StringMapLayer", + "StringReplaceLayer", + "StringToStringListLayer", + "SubStringDelimAtIndexLayer", + "UnixTimestampToDateTimeLayer", +] diff --git a/src/kamae/tensorflow/layers/bloom_encode.py 
b/src/kamae/keras/tensorflow/layers/bloom_encode.py similarity index 94% rename from src/kamae/tensorflow/layers/bloom_encode.py rename to src/kamae/keras/tensorflow/layers/bloom_encode.py index b98a4709..3554e25b 100644 --- a/src/kamae/tensorflow/layers/bloom_encode.py +++ b/src/kamae/keras/tensorflow/layers/bloom_encode.py @@ -14,14 +14,15 @@ from typing import Any, Dict, List, Optional, Union +import keras import tensorflow as tf +from keras import KerasTensor from tensorflow.keras.layers import Hashing import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -37,6 +38,9 @@ class BloomEncodeLayer(BaseLayer): this can be seen as a psuedo-bloom encoding. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -117,16 +121,16 @@ def __init__( } @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the bloom encoding on the input tensor. 
diff --git a/src/kamae/tensorflow/layers/bucketize.py b/src/kamae/keras/tensorflow/layers/bucketize.py similarity index 87% rename from src/kamae/tensorflow/layers/bucketize.py rename to src/kamae/keras/tensorflow/layers/bucketize.py index 982c6470..f3449aca 100644 --- a/src/kamae/tensorflow/layers/bucketize.py +++ b/src/kamae/keras/tensorflow/layers/bucketize.py @@ -14,13 +14,14 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -35,6 +36,9 @@ class BucketizeLayer(BaseLayer): is reserved for padding values. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = True + def __init__( self, splits: List[float], @@ -59,16 +63,16 @@ def __init__( self.splits = splits @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.int32, tf.int64, tf.float32, tf.float64] + return ["int32", "int64", "float32", "float64"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the bucketing operation on the input tensor. 
diff --git a/src/kamae/tensorflow/layers/current_date.py b/src/kamae/keras/tensorflow/layers/current_date.py similarity index 84% rename from src/kamae/tensorflow/layers/current_date.py rename to src/kamae/keras/tensorflow/layers/current_date.py index 05b1a217..4ba4417e 100644 --- a/src/kamae/tensorflow/layers/current_date.py +++ b/src/kamae/keras/tensorflow/layers/current_date.py @@ -14,16 +14,15 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import ( - enforce_single_tensor_input, - unix_timestamp_to_datetime, -) - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input +from kamae.keras.tensorflow.utils.date_utils import unix_timestamp_to_datetime @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -32,6 +31,9 @@ class CurrentDateLayer(BaseLayer): Returns the current UTC date in yyyy-MM-dd format. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -51,7 +53,7 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. Returns `None` as the layer only returns the current date as a string. It does not transform any input. @@ -61,7 +63,7 @@ def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: return None @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Returns the current timestamp in yyyy-MM-dd format. Uses the input tensor to determine the shape of the output tensor. 
diff --git a/src/kamae/tensorflow/layers/current_date_time.py b/src/kamae/keras/tensorflow/layers/current_date_time.py similarity index 86% rename from src/kamae/tensorflow/layers/current_date_time.py rename to src/kamae/keras/tensorflow/layers/current_date_time.py index 4b034dca..a439fdd2 100644 --- a/src/kamae/tensorflow/layers/current_date_time.py +++ b/src/kamae/keras/tensorflow/layers/current_date_time.py @@ -14,16 +14,15 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import ( - enforce_single_tensor_input, - unix_timestamp_to_datetime, -) - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input +from kamae.keras.tensorflow.utils.date_utils import unix_timestamp_to_datetime @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -39,6 +38,9 @@ class CurrentDateTimeLayer(BaseLayer): It is recommended not to rely on parity at the millisecond level. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -58,7 +60,7 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. Returns `None` as the layer only returns the current date as a string. It does not transform any input. @@ -68,7 +70,7 @@ def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: return None @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Returns the current timestamp in yyyy-MM-dd HH:mm:ss format. 
Uses the input tensor to determine the shape of the output tensor. diff --git a/src/kamae/tensorflow/layers/current_unix_timestamp.py b/src/kamae/keras/tensorflow/layers/current_unix_timestamp.py similarity index 90% rename from src/kamae/tensorflow/layers/current_unix_timestamp.py rename to src/kamae/keras/tensorflow/layers/current_unix_timestamp.py index e37cfdce..7fca81f9 100644 --- a/src/kamae/tensorflow/layers/current_unix_timestamp.py +++ b/src/kamae/keras/tensorflow/layers/current_unix_timestamp.py @@ -14,13 +14,14 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -37,6 +38,9 @@ class CurrentUnixTimestampLayer(BaseLayer): It is recommended not to rely on parity at the millisecond level. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -66,7 +70,7 @@ def __init__( self.unit = unit @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. Returns `None` as the layer only returns the current date as a string. It does not transform any input. @@ -76,7 +80,7 @@ def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: return None @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Returns the current unix timestamp in either seconds or milliseconds. 
Uses the input tensor to determine the shape of the output tensor. diff --git a/src/kamae/tensorflow/layers/date_add.py b/src/kamae/keras/tensorflow/layers/date_add.py similarity index 87% rename from src/kamae/tensorflow/layers/date_add.py rename to src/kamae/keras/tensorflow/layers/date_add.py index e306c3ec..f50bd58e 100644 --- a/src/kamae/tensorflow/layers/date_add.py +++ b/src/kamae/keras/tensorflow/layers/date_add.py @@ -14,16 +14,15 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import ( - allow_single_or_multiple_tensor_input, - datetime_add_days, -) - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input +from kamae.keras.tensorflow.utils.date_utils import datetime_add_days @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -34,6 +33,9 @@ class DateAddLayer(BaseLayer): WARNING: This layer destroys the time component of the date column. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -66,16 +68,16 @@ def __init__( self.num_days = num_days @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string, tf.int8, tf.int16, tf.int32, tf.int64] + return ["string", "int8", "int16", "int32", "int64"] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Adds or subtracts a number of days from a date(time) string. 
""" @@ -98,7 +100,7 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: raise ValueError( "When `num_days` is not set, the input should be two tensors." ) - if not inputs[1].dtype.is_integer: + if "int" not in keras.backend.standardize_dtype(inputs[1].dtype): raise ValueError( f"""Expected second input dtype to be integer, but got {inputs[1].dtype}.""" diff --git a/src/kamae/tensorflow/layers/date_diff.py b/src/kamae/keras/tensorflow/layers/date_diff.py similarity index 86% rename from src/kamae/tensorflow/layers/date_diff.py rename to src/kamae/keras/tensorflow/layers/date_diff.py index eb20052b..c5ffcd0e 100644 --- a/src/kamae/tensorflow/layers/date_diff.py +++ b/src/kamae/keras/tensorflow/layers/date_diff.py @@ -14,13 +14,15 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import datetime_total_days, enforce_multiple_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_multiple_tensor_input +from kamae.keras.tensorflow.utils.date_utils import datetime_total_days @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -32,6 +34,9 @@ class DateDiffLayer(BaseLayer): The transformer will return a negative value if the order is reversed. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -53,16 +58,16 @@ def __init__( self.default_value = default_value @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" - return [tf.string] + return ["string"] @enforce_multiple_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the date difference operation on two input tensors. @@ -98,7 +103,9 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: outputs = self.date_difference(end_date, start_date) return outputs - def date_difference(self, end_date: Tensor, start_date: Tensor) -> Tensor: + def date_difference( + self, end_date: KerasTensor, start_date: KerasTensor + ) -> KerasTensor: """ Calculates the difference between two dates. diff --git a/src/kamae/tensorflow/layers/date_parse.py b/src/kamae/keras/tensorflow/layers/date_parse.py similarity index 91% rename from src/kamae/tensorflow/layers/date_parse.py rename to src/kamae/keras/tensorflow/layers/date_parse.py index 13a89a72..57193f28 100644 --- a/src/kamae/tensorflow/layers/date_parse.py +++ b/src/kamae/keras/tensorflow/layers/date_parse.py @@ -14,11 +14,15 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import ( +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input +from kamae.keras.tensorflow.utils.date_utils import ( datetime_day, datetime_day_of_year, datetime_hour, @@ -28,11 +32,8 @@ datetime_second, datetime_weekday, datetime_year, - enforce_single_tensor_input, ) -from .base import BaseLayer - @tf.keras.utils.register_keras_serializable(package=kamae.__name__) class DateParseLayer(BaseLayer): @@ -62,6 +63,9 @@ class DateParseLayer(BaseLayer): as "2020-02-30" no errors will be thrown and you will get a nonsense output. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, date_part: str, @@ -102,16 +106,16 @@ def __init__( self.default_value = default_value @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Extracts date part from date(time) string. @@ -142,7 +146,7 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: return outputs @staticmethod - def _parse_date(date_tensor: Tensor, date_part: str) -> Tensor: + def _parse_date(date_tensor: KerasTensor, date_part: str) -> KerasTensor: """ Parse date(time) string into a dictionary of date part tensors. diff --git a/src/kamae/tensorflow/layers/date_time_to_unix_timestamp.py b/src/kamae/keras/tensorflow/layers/date_time_to_unix_timestamp.py similarity index 86% rename from src/kamae/tensorflow/layers/date_time_to_unix_timestamp.py rename to src/kamae/keras/tensorflow/layers/date_time_to_unix_timestamp.py index 217f289d..a5c280e8 100644 --- a/src/kamae/tensorflow/layers/date_time_to_unix_timestamp.py +++ b/src/kamae/keras/tensorflow/layers/date_time_to_unix_timestamp.py @@ -14,16 +14,15 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import ( - datetime_to_unix_timestamp, - enforce_single_tensor_input, -) - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input +from kamae.keras.tensorflow.utils.date_utils import 
datetime_to_unix_timestamp @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -33,6 +32,9 @@ class DateTimeToUnixTimestampLayer(BaseLayer): or yyyy-MM-dd format. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -64,16 +66,16 @@ def __init__( self.unit = unit @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Returns the unix timestamp from a datetime in either yyyy-MM-dd HH:mm:ss.SSS or yyyy-MM-dd format. diff --git a/src/kamae/tensorflow/layers/hash_index.py b/src/kamae/keras/tensorflow/layers/hash_index.py similarity index 90% rename from src/kamae/tensorflow/layers/hash_index.py rename to src/kamae/keras/tensorflow/layers/hash_index.py index 3b982074..4946a95c 100644 --- a/src/kamae/tensorflow/layers/hash_index.py +++ b/src/kamae/keras/tensorflow/layers/hash_index.py @@ -14,14 +14,15 @@ from typing import Any, Dict, List, Optional, Union +import keras import tensorflow as tf +from keras import KerasTensor from tensorflow.keras.layers import Hashing import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -40,6 +41,9 @@ class HashIndexLayer(BaseLayer): input bits thoroughly. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, num_bins: int, @@ -77,16 +81,16 @@ def __init__( self.hash_indexer = Hashing(name=name, num_bins=num_bins - 1) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the hash indexing on the input tensor by calling the underlying Hashing layer. diff --git a/src/kamae/tensorflow/layers/if_statement.py b/src/kamae/keras/tensorflow/layers/if_statement.py similarity index 90% rename from src/kamae/tensorflow/layers/if_statement.py rename to src/kamae/keras/tensorflow/layers/if_statement.py index c08fec4a..cacda1aa 100644 --- a/src/kamae/tensorflow/layers/if_statement.py +++ b/src/kamae/keras/tensorflow/layers/if_statement.py @@ -14,19 +14,24 @@ from numbers import Number from typing import Any, Dict, Iterable, List, Optional, Union +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input from kamae.utils import get_condition_operator -from .base import BaseLayer - @tf.keras.utils.register_keras_serializable(package=kamae.__name__) class IfStatementLayer(BaseLayer): """ + Performs an if statement on the input tensor. + + This layer requires TensorFlow backend as it supports string operations. 
+ Performs an if statement on the input tensor, returning a tensor of the same shape as the input tensor. @@ -46,6 +51,9 @@ class IfStatementLayer(BaseLayer): not None, then inputs is expected to be a tensor. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, condition_operator: str, @@ -101,7 +109,7 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. @@ -166,7 +174,9 @@ def _create_casted_tensor_from_tensor_or_constant( ) @allow_single_or_multiple_tensor_input - def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tensor: + def _call( + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> KerasTensor: """ Performs the numerical if statement on the inputs. If the inputs are a tensor, we assume that the value_to_compare, result_if_true, and result_if_false are @@ -198,7 +208,8 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso "If inputs is a tensor, value_to_compare, result_if_true, and " "result_if_false must be specified." 
) - if inputs[0].dtype.is_floating or inputs[0].dtype.is_integer: + dtype_str = keras.backend.standardize_dtype(inputs[0].dtype) + if "float" in dtype_str or "int" in dtype_str: inputs, value_to_compare = self._force_cast_to_compatible_numeric_type( inputs[0], self.value_to_compare ) @@ -229,11 +240,12 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso # If the value to compare is a tensor, we cast it to the input dtype inputs = input_tensors[0] value_to_compare = self._cast( - input_tensors[1], cast_dtype=input_tensors[0].dtype.name + input_tensors[1], + cast_dtype=keras.backend.standardize_dtype(input_tensors[0].dtype), ) - elif ( - input_tensors[0].dtype.is_floating or input_tensors[0].dtype.is_integer - ): + elif "float" in keras.backend.standardize_dtype( + input_tensors[0].dtype + ) or "int" in keras.backend.standardize_dtype(input_tensors[0].dtype): # If the inputs are numeric we force cast it to a compatible dtype inputs, value_to_compare = self._force_cast_to_compatible_numeric_type( input_tensors[0], input_tensors[1] @@ -242,7 +254,8 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso # The inputs are not numeric, so we just do the regular casting inputs = input_tensors[0] value_to_compare = self._cast( - tf.constant(input_tensors[1]), inputs.dtype.name + tf.constant(input_tensors[1]), + keras.backend.standardize_dtype(inputs.dtype), ) cond = tf.where( diff --git a/src/kamae/tensorflow/layers/lambda_function.py b/src/kamae/keras/tensorflow/layers/lambda_function.py similarity index 83% rename from src/kamae/tensorflow/layers/lambda_function.py rename to src/kamae/keras/tensorflow/layers/lambda_function.py index b02e715f..4a298c7a 100644 --- a/src/kamae/tensorflow/layers/lambda_function.py +++ b/src/kamae/keras/tensorflow/layers/lambda_function.py @@ -14,13 +14,14 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Union +import keras import tensorflow as tf +from keras 
import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -36,9 +37,15 @@ class LambdaFunctionLayer(BaseLayer, tf.keras.layers.Lambda): they were saved. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, - function: Callable[[Union[Tensor, List[Tensor]]], Union[Tensor, List[Tensor]]], + function: Callable[ + [Union[KerasTensor, List[KerasTensor]]], + Union[KerasTensor, List[KerasTensor]], + ], name: Optional[str] = None, input_dtype: Optional[str] = None, output_dtype: Optional[str] = None, @@ -61,7 +68,7 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. @@ -71,8 +78,8 @@ def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: @allow_single_or_multiple_tensor_input def _call( - self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any - ) -> Union[Tensor, Iterable[Tensor]]: + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> Union[KerasTensor, Iterable[KerasTensor]]: """ Transforms the input tensor(s) by applying the lambda function. 
diff --git a/src/kamae/tensorflow/layers/list_max.py b/src/kamae/keras/tensorflow/layers/list_max.py similarity index 90% rename from src/kamae/tensorflow/layers/list_max.py rename to src/kamae/keras/tensorflow/layers/list_max.py index 3331b45d..12e91a52 100644 --- a/src/kamae/tensorflow/layers/list_max.py +++ b/src/kamae/keras/tensorflow/layers/list_max.py @@ -14,18 +14,16 @@ from typing import Any, Dict, Iterable, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import ( - allow_single_or_multiple_tensor_input, - get_top_n, - map_fn_w_axis, - segmented_operation, -) - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input +from kamae.keras.tensorflow.utils.list_utils import get_top_n, segmented_operation +from kamae.keras.tensorflow.utils.transform_utils import map_fn_w_axis @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -52,6 +50,9 @@ class ListMaxLayer(BaseLayer): items sorted by descending production. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -96,22 +97,22 @@ def __init__( self.with_segment = with_segment @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.string, + "bfloat16", + "float16", + "float32", + "float64", + "string", ] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: """ Calculate the listwise max, optionally sorting and filtering based on the second input tensor, or segmenting diff --git a/src/kamae/tensorflow/layers/list_mean.py b/src/kamae/keras/tensorflow/layers/list_mean.py similarity index 92% rename from src/kamae/tensorflow/layers/list_mean.py rename to src/kamae/keras/tensorflow/layers/list_mean.py index ec34e4fa..1a27360f 100644 --- a/src/kamae/tensorflow/layers/list_mean.py +++ b/src/kamae/keras/tensorflow/layers/list_mean.py @@ -14,18 +14,16 @@ from typing import Any, Dict, Iterable, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import ( - allow_single_or_multiple_tensor_input, - get_top_n, - map_fn_w_axis, - segmented_operation, -) - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input +from kamae.keras.tensorflow.utils.list_utils import get_top_n, segmented_operation +from kamae.keras.tensorflow.utils.transform_utils import map_fn_w_axis @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -50,6 +48,9 @@ class ListMeanLayer(BaseLayer): items sorted by descending production. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -94,22 +95,21 @@ def __init__( self.with_segment = with_segment @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.string, + "bfloat16", + "float16", + "float32", + "float64", ] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: """ Calculate the listwise mean, optionally sorting and filtering based on the second input tensor, or segmenting @@ -148,7 +148,7 @@ def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: if self.with_segment: - def segment_mean(values: List[Tensor]) -> Tensor: + def segment_mean(values: List[KerasTensor]) -> KerasTensor: mask = tf.math.is_finite(values[0]) val_tensor = values[0] segment_tensor = values[1] diff --git a/src/kamae/tensorflow/layers/list_median.py b/src/kamae/keras/tensorflow/layers/list_median.py similarity index 91% rename from src/kamae/tensorflow/layers/list_median.py rename to src/kamae/keras/tensorflow/layers/list_median.py index bf1c417c..3ff67a79 100644 --- a/src/kamae/tensorflow/layers/list_median.py +++ b/src/kamae/keras/tensorflow/layers/list_median.py @@ -14,13 +14,15 @@ from typing import Any, Dict, Iterable, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input, get_top_n - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils 
import allow_single_or_multiple_tensor_input +from kamae.keras.tensorflow.utils.list_utils import get_top_n @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -43,6 +45,9 @@ class ListMedianLayer(BaseLayer): WARNING: ListMedianLayer requires at least rank 3 tensor input. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -83,20 +88,20 @@ def __init__( self.axis = axis @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, + "bfloat16", + "float16", + "float32", + "float64", ] - def sort_with_nans_last(self, tensor: Tensor) -> Tensor: + def sort_with_nans_last(self, tensor: KerasTensor) -> KerasTensor: """ Sorts a tensor while placing NaN values at the end along the specified axis. @@ -119,7 +124,7 @@ def sort_with_nans_last(self, tensor: Tensor) -> Tensor: return sorted_masked_tensor @allow_single_or_multiple_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: """ Calculate the listwise median, optionally sorting and filtering based on the second input tensor. 
@@ -186,7 +191,7 @@ def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: ) # Fill nan - is_integer = listwise_median.dtype.is_integer + is_integer = "int" in keras.backend.standardize_dtype(listwise_median.dtype) nan_val = int(self.nan_fill_value) if is_integer else self.nan_fill_value listwise_median = tf.where( tf.math.is_nan(listwise_median), diff --git a/src/kamae/tensorflow/layers/list_min.py b/src/kamae/keras/tensorflow/layers/list_min.py similarity index 91% rename from src/kamae/tensorflow/layers/list_min.py rename to src/kamae/keras/tensorflow/layers/list_min.py index c1998aac..13b66d1d 100644 --- a/src/kamae/tensorflow/layers/list_min.py +++ b/src/kamae/keras/tensorflow/layers/list_min.py @@ -14,18 +14,16 @@ from typing import Any, Dict, Iterable, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import ( - allow_single_or_multiple_tensor_input, - get_top_n, - map_fn_w_axis, - segmented_operation, -) - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input +from kamae.keras.tensorflow.utils.list_utils import get_top_n, segmented_operation +from kamae.keras.tensorflow.utils.transform_utils import map_fn_w_axis @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -51,6 +49,9 @@ class ListMinLayer(BaseLayer): items sorted by descending production. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -95,22 +96,22 @@ def __init__( self.with_segment = with_segment @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.string, + "bfloat16", + "float16", + "float32", + "float64", + "string", ] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: """ Calculate the listwise min, optionally sorting and filtering based on the second input tensor, or segmenting diff --git a/src/kamae/tensorflow/layers/list_rank.py b/src/kamae/keras/tensorflow/layers/list_rank.py similarity index 82% rename from src/kamae/tensorflow/layers/list_rank.py rename to src/kamae/keras/tensorflow/layers/list_rank.py index 055eef9f..58698146 100644 --- a/src/kamae/tensorflow/layers/list_rank.py +++ b/src/kamae/keras/tensorflow/layers/list_rank.py @@ -14,13 +14,14 @@ from typing import Any, Dict, Iterable, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -31,6 +32,9 @@ class ListRankLayer(BaseLayer): Example: calculate the rank of items within a query, given the score. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -56,27 +60,27 @@ def __init__( self.axis = axis @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, - tf.uint8, - tf.int8, - tf.uint16, - tf.int16, - tf.int32, - tf.int64, + "bfloat16", + "float16", + "float32", + "float64", + "uint8", + "int8", + "uint16", + "int16", + "int32", + "int64", ] @enforce_single_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: """ Calculate the rank. diff --git a/src/kamae/tensorflow/layers/list_std_dev.py b/src/kamae/keras/tensorflow/layers/list_std_dev.py similarity index 91% rename from src/kamae/tensorflow/layers/list_std_dev.py rename to src/kamae/keras/tensorflow/layers/list_std_dev.py index 0e37485c..5e61a96e 100644 --- a/src/kamae/tensorflow/layers/list_std_dev.py +++ b/src/kamae/keras/tensorflow/layers/list_std_dev.py @@ -14,13 +14,15 @@ from typing import Any, Dict, Iterable, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input, get_top_n - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input +from kamae.keras.tensorflow.utils.list_utils import get_top_n @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -41,6 +43,9 @@ class ListStdDevLayer(BaseLayer): items sorted by descending production. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = True + def __init__( self, name: Optional[str] = None, @@ -81,21 +86,21 @@ def __init__( self.axis = axis @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" return [ - tf.bfloat16, - tf.float16, - tf.float32, - tf.float64, + "bfloat16", + "float16", + "float32", + "float64", ] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: """ Calculate the listwise average, optionally sorting and filtering based on the second input tensor. @@ -169,7 +174,7 @@ def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: listwise_stddev = tf.sqrt(listwise_variance) # Fill nan - is_integer = listwise_stddev.dtype.is_integer + is_integer = "int" in keras.backend.standardize_dtype(listwise_stddev.dtype) nan_val = int(self.nan_fill_value) if is_integer else self.nan_fill_value listwise_stddev = tf.where( tf.math.is_nan(listwise_stddev), diff --git a/src/kamae/tensorflow/layers/min_hash_index.py b/src/kamae/keras/tensorflow/layers/min_hash_index.py similarity index 92% rename from src/kamae/tensorflow/layers/min_hash_index.py rename to src/kamae/keras/tensorflow/layers/min_hash_index.py index 0b6ad686..c98de296 100644 --- a/src/kamae/tensorflow/layers/min_hash_index.py +++ b/src/kamae/keras/tensorflow/layers/min_hash_index.py @@ -14,14 +14,15 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor from tensorflow.keras.layers import Hashing import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -43,6 +44,9 @@ class MinHashIndexLayer(BaseLayer): The minimum is computed across the last dimension of the input tensor. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -81,16 +85,16 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the min hash indexing on the input tensor. diff --git a/src/kamae/tensorflow/layers/one_hot_encode.py b/src/kamae/keras/tensorflow/layers/one_hot_encode.py similarity index 92% rename from src/kamae/tensorflow/layers/one_hot_encode.py rename to src/kamae/keras/tensorflow/layers/one_hot_encode.py index 5c020e9c..408739f7 100644 --- a/src/kamae/tensorflow/layers/one_hot_encode.py +++ b/src/kamae/keras/tensorflow/layers/one_hot_encode.py @@ -15,13 +15,14 @@ import warnings from typing import Any, Dict, List, Optional, Union +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -36,6 +37,9 @@ class OneHotEncodeLayer(BaseLayer): dimension for the encoded output. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, vocabulary: Union[str, List[str]], @@ -87,16 +91,16 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.int16, tf.int32, tf.int64, tf.string] + return ["int16", "int32", "int64", "string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the one-hot encoding on the input tensor. @@ -159,6 +163,9 @@ def get_config(self) -> Dict[str, Any]: # it is maintained for backwards compatibility @tf.keras.utils.register_keras_serializable(package=kamae.__name__) class OneHotLayer(OneHotEncodeLayer): + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__(self, *args: Any, **kwargs: Any) -> None: warnings.warn( "OneHotLayer is deprecated and will be removed in a future release. 
" diff --git a/src/kamae/tensorflow/layers/ordinal_array_encode.py b/src/kamae/keras/tensorflow/layers/ordinal_array_encode.py similarity index 88% rename from src/kamae/tensorflow/layers/ordinal_array_encode.py rename to src/kamae/keras/tensorflow/layers/ordinal_array_encode.py index 2bfaede5..67af333c 100644 --- a/src/kamae/tensorflow/layers/ordinal_array_encode.py +++ b/src/kamae/keras/tensorflow/layers/ordinal_array_encode.py @@ -14,13 +14,15 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input, map_fn_w_axis - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input +from kamae.keras.tensorflow.utils.transform_utils import map_fn_w_axis @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -33,6 +35,9 @@ class OrdinalArrayEncodeLayer(BaseLayer): ignore the pad value if specified. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, pad_value: Optional[str] = None, @@ -58,16 +63,16 @@ def __init__( self.axis = axis @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the ordinal encoding on the input dataset. 
Example: @@ -88,7 +93,7 @@ def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: """ @tf.function - def _transform_row(input_row: Tensor) -> Tensor: + def _transform_row(input_row: KerasTensor) -> KerasTensor: if self.pad_value is None: converted_tensor = tf.unique(input_row).idx else: diff --git a/src/kamae/tensorflow/layers/string_affix.py b/src/kamae/keras/tensorflow/layers/string_affix.py similarity index 86% rename from src/kamae/tensorflow/layers/string_affix.py rename to src/kamae/keras/tensorflow/layers/string_affix.py index 806845b6..79b439e1 100644 --- a/src/kamae/tensorflow/layers/string_affix.py +++ b/src/kamae/keras/tensorflow/layers/string_affix.py @@ -14,21 +14,25 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(kamae.__name__) +@tf.keras.utils.register_keras_serializable(package=kamae.__name__) class StringAffixLayer(BaseLayer): """ Performs a prefixing and suffing on the input tensor. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -66,16 +70,16 @@ def validate_params(self) -> None: ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Prefixes and suffixes a given input tensor. diff --git a/src/kamae/tensorflow/layers/string_array_constant.py b/src/kamae/keras/tensorflow/layers/string_array_constant.py similarity index 88% rename from src/kamae/tensorflow/layers/string_array_constant.py rename to src/kamae/keras/tensorflow/layers/string_array_constant.py index d86aae94..892fe95e 100644 --- a/src/kamae/tensorflow/layers/string_array_constant.py +++ b/src/kamae/keras/tensorflow/layers/string_array_constant.py @@ -14,13 +14,14 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -29,6 +30,9 @@ class StringArrayConstantLayer(BaseLayer): Tensorflow keras layer that outputs a constant string array. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -50,7 +54,7 @@ def __init__( self.constant_string_array = constant_string_array @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. 
@@ -59,7 +63,7 @@ def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: return None @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Returns the constant string array with the same shape as the input tensor. diff --git a/src/kamae/tensorflow/layers/string_case.py b/src/kamae/keras/tensorflow/layers/string_case.py similarity index 87% rename from src/kamae/tensorflow/layers/string_case.py rename to src/kamae/keras/tensorflow/layers/string_case.py index 24be7011..02314ae6 100644 --- a/src/kamae/tensorflow/layers/string_case.py +++ b/src/kamae/keras/tensorflow/layers/string_case.py @@ -14,13 +14,14 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -30,6 +31,9 @@ class StringCaseLayer(BaseLayer): Supported string case types are 'upper' and 'lower'. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, string_case_type: str = "lower", @@ -53,16 +57,16 @@ def __init__( self.string_case_type = string_case_type @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs the string case transform on the input tensor. diff --git a/src/kamae/tensorflow/layers/string_concatenate.py b/src/kamae/keras/tensorflow/layers/string_concatenate.py similarity index 83% rename from src/kamae/tensorflow/layers/string_concatenate.py rename to src/kamae/keras/tensorflow/layers/string_concatenate.py index 6820b001..6f7c5298 100644 --- a/src/kamae/tensorflow/layers/string_concatenate.py +++ b/src/kamae/keras/tensorflow/layers/string_concatenate.py @@ -14,21 +14,25 @@ from typing import Any, Dict, Iterable, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_multiple_tensor_input +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_multiple_tensor_input -from .base import BaseLayer - -@tf.keras.utils.register_keras_serializable(kamae.__name__) +@tf.keras.utils.register_keras_serializable(package=kamae.__name__) class StringConcatenateLayer(BaseLayer): """ Performs a concatenation of the input tensors. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -50,16 +54,16 @@ def __init__( self.separator = separator @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" - return [tf.string] + return ["string"] @enforce_multiple_tensor_input - def _call(self, inputs: Iterable[Tensor], **kwargs: Any) -> Tensor: + def _call(self, inputs: Iterable[KerasTensor], **kwargs: Any) -> KerasTensor: """ Concatenates the input tensors. diff --git a/src/kamae/tensorflow/layers/string_contains.py b/src/kamae/keras/tensorflow/layers/string_contains.py similarity index 92% rename from src/kamae/tensorflow/layers/string_contains.py rename to src/kamae/keras/tensorflow/layers/string_contains.py index 5fc7c25c..623b5bb7 100644 --- a/src/kamae/tensorflow/layers/string_contains.py +++ b/src/kamae/keras/tensorflow/layers/string_contains.py @@ -14,13 +14,14 @@ from typing import Any, Dict, Iterable, List, Optional, Union +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -38,6 +39,9 @@ class StringContainsLayer(BaseLayer): does not support matching of newline characters. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, string_constant: Optional[str] = None, @@ -62,16 +66,18 @@ def __init__( self.string_constant = string_constant @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" - return [tf.string] + return ["string"] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tensor: + def _call( + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> KerasTensor: """ Checks for the existence of a substring/pattern within a tensor. WARNING: While it works, the use of tensors in matching @@ -122,7 +128,7 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso # Two tensors provided @tf.function - def tensor_match(x: List[Tensor]) -> Tensor: + def tensor_match(x: List[KerasTensor]) -> KerasTensor: match_substring = x[1] match_substring = self._escape_special_characters(match_substring) return tf.strings.regex_full_match( @@ -155,8 +161,8 @@ def tensor_match(x: List[Tensor]) -> Tensor: return output_tensor def _escape_special_characters( - self, string: Union[str, Tensor] - ) -> Union[str, Tensor]: + self, string: Union[str, KerasTensor] + ) -> Union[str, KerasTensor]: """ Escapes special characters in a string so they are not parsed as regex. :param string: The string or string tensor to escape special characters in. 
diff --git a/src/kamae/tensorflow/layers/string_contains_list.py b/src/kamae/keras/tensorflow/layers/string_contains_list.py similarity index 91% rename from src/kamae/tensorflow/layers/string_contains_list.py rename to src/kamae/keras/tensorflow/layers/string_contains_list.py index 2a2616d2..c4558e6f 100644 --- a/src/kamae/tensorflow/layers/string_contains_list.py +++ b/src/kamae/keras/tensorflow/layers/string_contains_list.py @@ -14,13 +14,14 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -33,6 +34,9 @@ class StringContainsListLayer(BaseLayer): strings. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, string_constant_list: List[str], @@ -57,16 +61,16 @@ def __init__( self.string_constant_list = string_constant_list @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Checks for the existence of any substring in the string_contains_list within a tensor. 
diff --git a/src/kamae/tensorflow/layers/string_equals_if_statement.py b/src/kamae/keras/tensorflow/layers/string_equals_if_statement.py similarity index 92% rename from src/kamae/tensorflow/layers/string_equals_if_statement.py rename to src/kamae/keras/tensorflow/layers/string_equals_if_statement.py index 67f52e50..cc798215 100644 --- a/src/kamae/tensorflow/layers/string_equals_if_statement.py +++ b/src/kamae/keras/tensorflow/layers/string_equals_if_statement.py @@ -14,16 +14,16 @@ from typing import Any, Dict, Iterable, List, Optional, Union +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input -from .base import BaseLayer - -# TODO: Deprecate this in favor of IfStatementLayer in next major release. @tf.keras.utils.register_keras_serializable(package=kamae.__name__) class StringEqualsIfStatementLayer(BaseLayer): """ @@ -42,6 +42,9 @@ class StringEqualsIfStatementLayer(BaseLayer): not None, then inputs is expected to be a tensor. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, value_to_compare: Optional[str] = None, @@ -73,15 +76,15 @@ def __init__( self.result_if_false = result_if_false @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" - return [tf.string] + return ["string"] - def _construct_input_tensors(self, inputs: List[Tensor]) -> List[Tensor]: + def _construct_input_tensors(self, inputs: List[KerasTensor]) -> List[KerasTensor]: """ Constructs the input tensors for the layer in the case where all the optional parameters are not specified. We need to run through the provided inputs and @@ -119,7 +122,9 @@ def _construct_input_tensors(self, inputs: List[Tensor]) -> List[Tensor]: return multiple_inputs @allow_single_or_multiple_tensor_input - def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tensor: + def _call( + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> KerasTensor: """ Performs the string if equals statement on the inputs. If the inputs are a tensor, we assume that the value_to_compare, result_if_true, and diff --git a/src/kamae/tensorflow/layers/string_index.py b/src/kamae/keras/tensorflow/layers/string_index.py similarity index 90% rename from src/kamae/tensorflow/layers/string_index.py rename to src/kamae/keras/tensorflow/layers/string_index.py index a8c4cbfd..4308e054 100644 --- a/src/kamae/tensorflow/layers/string_index.py +++ b/src/kamae/keras/tensorflow/layers/string_index.py @@ -14,14 +14,15 @@ from typing import Any, Dict, List, Optional, Union +import keras import tensorflow as tf +from keras import KerasTensor from tensorflow.keras.layers import StringLookup import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -34,6 +35,9 @@ class StringIndexLayer(BaseLayer): transformation of input strings. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, vocabulary: Union[str, List[str]], @@ -80,16 +84,16 @@ def __init__( ) @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Performs string indexing by calling the StringLookup layer. diff --git a/src/kamae/tensorflow/layers/string_isin_list.py b/src/kamae/keras/tensorflow/layers/string_isin_list.py similarity index 88% rename from src/kamae/tensorflow/layers/string_isin_list.py rename to src/kamae/keras/tensorflow/layers/string_isin_list.py index 01dc2293..08292ea4 100644 --- a/src/kamae/tensorflow/layers/string_isin_list.py +++ b/src/kamae/keras/tensorflow/layers/string_isin_list.py @@ -14,13 +14,14 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -30,6 +31,9 @@ class StringIsInListLayer(BaseLayer): the string constant list. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, string_constant_list: List[str], @@ -54,16 +58,16 @@ def __init__( self.string_constant_list = string_constant_list @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Checks if the input tensor is matching any string in the string_constant_list. diff --git a/src/kamae/tensorflow/layers/string_list_to_string.py b/src/kamae/keras/tensorflow/layers/string_list_to_string.py similarity index 88% rename from src/kamae/tensorflow/layers/string_list_to_string.py rename to src/kamae/keras/tensorflow/layers/string_list_to_string.py index ce424c97..a807805a 100644 --- a/src/kamae/tensorflow/layers/string_list_to_string.py +++ b/src/kamae/keras/tensorflow/layers/string_list_to_string.py @@ -14,13 +14,14 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -31,6 +32,9 @@ class StringListToStringLayer(BaseLayer): If `keepdims` is `True`, the shape is retained. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -61,16 +65,16 @@ def __init__( self.keepdims = keepdims @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Joins the strings along the specified axis with the specified separator. If `keepdims` is `True`, the shape is retained. Otherwise the shape is diff --git a/src/kamae/tensorflow/layers/string_map.py b/src/kamae/keras/tensorflow/layers/string_map.py similarity index 91% rename from src/kamae/tensorflow/layers/string_map.py rename to src/kamae/keras/tensorflow/layers/string_map.py index a05e1fab..b535b8b8 100644 --- a/src/kamae/tensorflow/layers/string_map.py +++ b/src/kamae/keras/tensorflow/layers/string_map.py @@ -14,13 +14,14 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -29,6 +30,9 @@ class StringMapLayer(BaseLayer): StringMapLayer layer for TensorFlow. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, string_match_values: List[str], @@ -59,16 +63,16 @@ def __init__( self.default_replace_value = default_replace_value @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Checks if the input tensor is matching any of the string_match_values and replaces it with the corresponding string_replace_values. diff --git a/src/kamae/tensorflow/layers/string_replace.py b/src/kamae/keras/tensorflow/layers/string_replace.py similarity index 93% rename from src/kamae/tensorflow/layers/string_replace.py rename to src/kamae/keras/tensorflow/layers/string_replace.py index 0e5fa906..5863edcc 100644 --- a/src/kamae/tensorflow/layers/string_replace.py +++ b/src/kamae/keras/tensorflow/layers/string_replace.py @@ -14,13 +14,14 @@ from typing import Any, Dict, Iterable, List, Optional, Union +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import allow_single_or_multiple_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import allow_single_or_multiple_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -29,6 +30,9 @@ class StringReplaceLayer(BaseLayer): StringReplaceLayer layer for TensorFlow. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, string_match_constant: Optional[str] = None, @@ -71,16 +75,18 @@ def __init__( self.regex = regex @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string] + return ["string"] @allow_single_or_multiple_tensor_input - def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tensor: + def _call( + self, inputs: Union[KerasTensor, Iterable[KerasTensor]], **kwargs: Any + ) -> KerasTensor: """ Checks for the existence of a substring/pattern within a tensor and replaces if there is a match. @@ -165,7 +171,7 @@ def _call(self, inputs: Union[Tensor, Iterable[Tensor]], **kwargs: Any) -> Tenso ) mappable_tensor = tf.reshape(mappable_tensor, [-1, 3]) - def _tensor_replace(x: List[Tensor]) -> Tensor: + def _tensor_replace(x: List[KerasTensor]) -> KerasTensor: match_substring = x[1] if not self.regex: match_substring = self._escape_special_characters(x[1]) @@ -191,8 +197,8 @@ def _tensor_replace(x: List[Tensor]) -> Tensor: return replaced_tensor def _escape_special_characters( - self, string_to_escape: Union[str, Tensor] - ) -> Union[str, Tensor]: + self, string_to_escape: Union[str, KerasTensor] + ) -> Union[str, KerasTensor]: """ Escapes special characters in a string so they are not parsed as regex. :param string_to_escape: The string or string tensor to escape special characters in. 
diff --git a/src/kamae/tensorflow/layers/string_to_string_list.py b/src/kamae/keras/tensorflow/layers/string_to_string_list.py similarity index 91% rename from src/kamae/tensorflow/layers/string_to_string_list.py rename to src/kamae/keras/tensorflow/layers/string_to_string_list.py index cd1db06f..e0bb1ac6 100644 --- a/src/kamae/tensorflow/layers/string_to_string_list.py +++ b/src/kamae/keras/tensorflow/layers/string_to_string_list.py @@ -14,13 +14,14 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -33,6 +34,9 @@ class StringToStringListLayer(BaseLayer): If the separator is empty, the string is split on bytes/characters. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -64,16 +68,16 @@ def __init__( self.default_value = default_value @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. """ - return [tf.string] + return ["string"] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Splits the input string tensor by the separator and returns the list of strings. 
A list_length parameter is used to ensure that the output tensor has a diff --git a/src/kamae/tensorflow/layers/sub_string_delim_at_index.py b/src/kamae/keras/tensorflow/layers/sub_string_delim_at_index.py similarity index 94% rename from src/kamae/tensorflow/layers/sub_string_delim_at_index.py rename to src/kamae/keras/tensorflow/layers/sub_string_delim_at_index.py index eeb37f40..58462961 100644 --- a/src/kamae/tensorflow/layers/sub_string_delim_at_index.py +++ b/src/kamae/keras/tensorflow/layers/sub_string_delim_at_index.py @@ -14,13 +14,14 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import enforce_single_tensor_input - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -33,6 +34,9 @@ class SubStringDelimAtIndexLayer(BaseLayer): If the index is out of bounds, the default value is returned. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -64,13 +68,13 @@ def __init__( self.default_value = default_value @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. :returns: The compatible dtypes of the layer. 
""" - return [tf.string] + return ["string"] @staticmethod def resolve_negative_indices( @@ -91,7 +95,7 @@ def resolve_negative_indices( return tf.math.add(ragged_row_lengths, index) @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Splits the input string tensor by the delimiter and returns the substring at the specified index. If the index is out of bounds, the default value diff --git a/src/kamae/tensorflow/layers/unix_timestamp_to_date_time.py b/src/kamae/keras/tensorflow/layers/unix_timestamp_to_date_time.py similarity index 88% rename from src/kamae/tensorflow/layers/unix_timestamp_to_date_time.py rename to src/kamae/keras/tensorflow/layers/unix_timestamp_to_date_time.py index f2710f18..21583fa7 100644 --- a/src/kamae/tensorflow/layers/unix_timestamp_to_date_time.py +++ b/src/kamae/keras/tensorflow/layers/unix_timestamp_to_date_time.py @@ -14,16 +14,15 @@ from typing import Any, Dict, List, Optional +import keras import tensorflow as tf +from keras import KerasTensor import kamae -from kamae.tensorflow.typing import Tensor -from kamae.tensorflow.utils import ( - enforce_single_tensor_input, - unix_timestamp_to_datetime, -) - -from .base import BaseLayer +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input +from kamae.keras.tensorflow.utils.date_utils import unix_timestamp_to_datetime @tf.keras.utils.register_keras_serializable(package=kamae.__name__) @@ -33,6 +32,9 @@ class UnixTimestampToDateTimeLayer(BaseLayer): If `include_time` is set to `False`, the output will be in yyyy-MM-dd format. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + def __init__( self, name: Optional[str] = None, @@ -68,7 +70,7 @@ def __init__( self.include_time = include_time @property - def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: + def compatible_dtypes(self) -> Optional[List[str]]: """ Returns the compatible dtypes of the layer. Returns `None` as the layer only returns the current date as a string. It does not transform any input. @@ -76,12 +78,12 @@ def compatible_dtypes(self) -> Optional[List[tf.dtypes.DType]]: :returns: The compatible dtypes of the layer. """ return [ - tf.float64, - tf.int64, + "float64", + "int64", ] @enforce_single_tensor_input - def _call(self, inputs: Tensor, **kwargs: Any) -> Tensor: + def _call(self, inputs: KerasTensor, **kwargs: Any) -> KerasTensor: """ Returns the datetime in yyyy-MM-dd HH:mm:ss.SSS format if `include_time` is set to `True`. Otherwise, returns the date in yyyy-MM-dd format. diff --git a/src/kamae/tensorflow/utils/__init__.py b/src/kamae/keras/tensorflow/utils/__init__.py similarity index 80% rename from src/kamae/tensorflow/utils/__init__.py rename to src/kamae/keras/tensorflow/utils/__init__.py index 29d2c3db..ca2949f6 100644 --- a/src/kamae/tensorflow/utils/__init__.py +++ b/src/kamae/keras/tensorflow/utils/__init__.py @@ -12,6 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +TensorFlow-specific utilities for TF-only layers. + +These utilities use TensorFlow-specific operations and are only available +when using the TensorFlow backend. 
+""" + from .date_utils import ( # noqa: F401 datetime_add_days, datetime_day, @@ -30,13 +37,5 @@ datetime_year, unix_timestamp_to_datetime, ) -from .input_utils import ( # noqa: F401 - allow_single_or_multiple_tensor_input, - enforce_multiple_tensor_input, - enforce_single_tensor_input, -) from .list_utils import get_top_n, listify_tensors, segmented_operation # noqa: F401 -from .shape_utils import reshape_to_equal_rank # noqa: F401 from .transform_utils import map_fn_w_axis # noqa: F401 - -from .layer_utils import NormalizeLayer # noqa: F401 # isort:skip diff --git a/src/kamae/tensorflow/utils/date_utils.py b/src/kamae/keras/tensorflow/utils/date_utils.py similarity index 100% rename from src/kamae/tensorflow/utils/date_utils.py rename to src/kamae/keras/tensorflow/utils/date_utils.py diff --git a/src/kamae/tensorflow/utils/list_utils.py b/src/kamae/keras/tensorflow/utils/list_utils.py similarity index 99% rename from src/kamae/tensorflow/utils/list_utils.py rename to src/kamae/keras/tensorflow/utils/list_utils.py index 24c5c23a..264fd028 100644 --- a/src/kamae/tensorflow/utils/list_utils.py +++ b/src/kamae/keras/tensorflow/utils/list_utils.py @@ -16,7 +16,7 @@ import numpy as np import tensorflow as tf -from kamae.tensorflow.typing import Tensor +from .typing import Tensor def get_top_n( diff --git a/src/kamae/tensorflow/utils/transform_utils.py b/src/kamae/keras/tensorflow/utils/transform_utils.py similarity index 99% rename from src/kamae/tensorflow/utils/transform_utils.py rename to src/kamae/keras/tensorflow/utils/transform_utils.py index a34e7861..4e1b45a9 100644 --- a/src/kamae/tensorflow/utils/transform_utils.py +++ b/src/kamae/keras/tensorflow/utils/transform_utils.py @@ -16,7 +16,7 @@ import tensorflow as tf -from kamae.tensorflow.typing import Tensor +from .typing import Tensor def map_fn_w_axis( diff --git a/src/kamae/tensorflow/typing/types.py b/src/kamae/keras/tensorflow/utils/typing.py similarity index 83% rename from 
src/kamae/tensorflow/typing/types.py rename to src/kamae/keras/tensorflow/utils/typing.py index 6db85a61..78e548e0 100644 --- a/src/kamae/tensorflow/typing/types.py +++ b/src/kamae/keras/tensorflow/utils/typing.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Creates typing objects for common tensorflow types.""" +"""TensorFlow-specific type hints for TF-only utilities.""" from typing import Union import tensorflow as tf +# TensorFlow-specific tensor type that includes sparse and ragged tensors Tensor = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor] diff --git a/src/kamae/sklearn/estimators/standard_scale.py b/src/kamae/sklearn/estimators/standard_scale.py deleted file mode 100644 index ae600975..00000000 --- a/src/kamae/sklearn/estimators/standard_scale.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any - -import pandas as pd -import tensorflow as tf -from sklearn.preprocessing import StandardScaler - -from kamae.sklearn.params import SingleInputSingleOutputMixin -from kamae.sklearn.transformers import BaseTransformerMixin -from kamae.tensorflow.layers import StandardScaleLayer - - -class StandardScaleEstimator( - StandardScaler, - BaseTransformerMixin, - SingleInputSingleOutputMixin, -): - """ - Standard Scikit-Learn Estimator for use in Scikit-Learn pipelines. 
- Wrapper over the existing implementation of the StandardScaler in Scikit-Learn, - however operates on array columns and returns array columns. This is to align - with the Spark implementation of the StandardScaler. - - Standardize features by removing the mean and scaling to unit variance. - - The standard score of a sample `x` is calculated as: - - z = (x - u) / s - - where `u` is the mean of the training samples - and `s` is the standard deviation of the training samples - """ - - def __init__(self, input_col: str, output_col: str, layer_name: str) -> None: - """ - Intializes a StandardScale estimator. - - :param input_col: Input column name. - :param output_col: Output column name. - :param layer_name: Name of the layer. Used as the name of the tensorflow layer - """ - super().__init__(with_mean=True, with_std=True) - self.input_col = input_col - self.output_col = output_col - self.layer_name = layer_name - - def fit( - self, X: pd.DataFrame, y: None = None, **kwargs: Any - ) -> "StandardScaleEstimator": - """ - Fits the transformer to the data. Since the scikit-learn StandardScaler - takes scalar values, we need to convert the numpy array to a list of scalars. - This is to mimic the behavior of the Spark StandardScaler. - - In this, the input to our transformer is an array, and the output is a scaled - array. - - :param X: Pandas dataframe to fit the transformer to. - :param y: Not used, present here for API consistency by convention. - :returns: Fit pipeline. - """ - # Get array column as a list of scalars - feature_array = X[self.input_col].tolist() - super().fit(X=feature_array, y=y, sample_weight=None) - return self - - def transform(self, X: pd.DataFrame, y: None = None) -> pd.DataFrame: - """ - Transforms the data using the transformer. Standardises the array `input_col`, - creating a new standardised `output_col`. - - :param X: Pandas dataframe to transform. - :param y: Not used, present here for API consistency by convention. 
- :returns: Transformed data. - """ - # Get array column as a list of scalars - feature_array = X[self.input_col].tolist() - # Transform the list of scalars - transformed_list_of_scalars = super().transform(feature_array) - # Set the output column to an array of the transformed list of scalars - X[self.output_col] = list(transformed_list_of_scalars) - return X - - def get_tf_layer(self) -> tf.keras.layers.Layer: - """ - Gets the tensorflow layer for the standard scaler transformer. - - :returns: Tensorflow keras layer with name equal to the layerName parameter - that performs the standardization. - """ - return StandardScaleLayer( - name=self.layer_name, mean=self.mean_, variance=self.var_ - ) diff --git a/src/kamae/sklearn/params/__init__.py b/src/kamae/sklearn/params/__init__.py deleted file mode 100644 index 1ba0ff80..00000000 --- a/src/kamae/sklearn/params/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from .name import LayerNameMixin # noqa: F401 # isort:skip -from .base import ( # noqa: F401 - MultiInputMultiOutputMixin, - MultiInputSingleOutputMixin, - SingleInputMultiOutputMixin, - SingleInputSingleOutputMixin, -) -from .utils import InputOutputExtractor # noqa: F401 diff --git a/src/kamae/sklearn/params/base.py b/src/kamae/sklearn/params/base.py deleted file mode 100644 index f377e952..00000000 --- a/src/kamae/sklearn/params/base.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List - -from .name import LayerNameMixin - - -class SingleInputMixin: - """ - Mixin class containing set methods for the single input column scenario. - """ - - _input_col: str - - @property - def input_col(self) -> str: - """ - Gets the input column name. - - :returns: Input column name. - """ - return self._input_col - - @input_col.setter - def input_col(self, value: str) -> None: - """ - Sets the input column name. - - :param value: String to set the input_col parameter to. - :returns: None, input_col is set to the given value. - """ - self._input_col = value - - -class MultiInputMixin: - """ - Mixin class containing set methods for the multiple input columns scenario. - """ - - _input_cols: List[str] - - @property - def input_cols(self) -> List[str]: - """ - Gets the input column names. - - :returns: List of strings of input column names. 
- """ - return self._input_cols - - @input_cols.setter - def input_cols(self, value: List[str]) -> None: - """ - Sets the input column names. to the given list of strings. - - :param value: List of strings to set the input_col parameter to. - :returns: None, input_col is set to the given value. - """ - self._input_cols = value - - -class SingleOutputMixin(LayerNameMixin): - """ - Mixin class containing set methods for the single output column scenario. - """ - - _output_col: str - - @property - def output_col(self) -> str: - """ - Gets the output column name. - - :returns: List of strings of output column names. - """ - return self._output_col - - @output_col.setter - def output_col(self, value: str) -> None: - """ - Sets the output column name to the given string value. - - :param value: String to set the output_col parameter to. - :returns: None, output_col is set to the given value. - """ - if value is None: - # Set default output column name - self._output_col = "output" - self._output_col = value - - @LayerNameMixin.layer_name.setter - def layer_name(self, value: str) -> None: - """ - Sets the layer name to the given string value. - - :param value: String to set the layer_name parameter to. - :returns: None, layer_name is set to the given value. - """ - self._layer_name = value if value is not None else self.__repr__() - - -class MultiOutputMixin(LayerNameMixin): - """ - Mixin class containing set methods for the multiple output columns scenario. - """ - - _output_cols: List[str] - - @property - def output_cols(self) -> List[str]: - """ - Gets the output column names. - - :returns: List of strings of output column names. - """ - return self._output_cols - - @LayerNameMixin.layer_name.setter - def layer_name(self, value: str) -> None: - """ - Sets the layer name to the given string value. - - :param value: String to set the layer_name parameter to. - :returns: None, layer_name is set to the given value. 
- """ - self._layer_name = value if value is not None else self.__repr__() - - @output_cols.setter - def output_cols(self, value: List[str]) -> None: - """ - Sets the output column names to the given list of strings. - - :param value: List of strings to set the output_cols parameter to. - :returns: None, output_cols is set to the given value. - """ - self._output_cols = value - - -class SingleInputSingleOutputMixin(SingleInputMixin, SingleOutputMixin): - """ - Mixin for a layer that takes a single input and returns a single output - """ - - -class SingleInputMultiOutputMixin(SingleInputMixin, MultiOutputMixin): - """ - Mixin for a layer that takes a single input and returns multiple outputs - """ - - -class MultiInputSingleOutputMixin(MultiInputMixin, SingleOutputMixin): - """ - Mixin for a layer that takes multiple inputs and returns a single output - """ - - -class MultiInputMultiOutputMixin(MultiInputMixin, MultiOutputMixin): - """ - Mixin for a layer that takes multiple inputs and returns multiple outputs - """ diff --git a/src/kamae/sklearn/params/name.py b/src/kamae/sklearn/params/name.py deleted file mode 100644 index fc1ab443..00000000 --- a/src/kamae/sklearn/params/name.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - - -class LayerNameMixin: - """ - Mixin class for a layer name. 
- """ - - _layer_name: Optional[str] - - @property - def layer_name(self) -> str: - """ - Gets the layer name. - - :returns: String of layer name. - """ - return self._layer_name diff --git a/src/kamae/sklearn/params/utils.py b/src/kamae/sklearn/params/utils.py deleted file mode 100644 index ca053b33..00000000 --- a/src/kamae/sklearn/params/utils.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Tuple - - -class InputOutputExtractor: - """ - Mixin class containing methods for extracting input and output column names. - """ - - def get_layer_inputs_outputs(self) -> Tuple[List[str], List[str]]: - """ - Gets the input & output information of the layer. Returns a tuple of lists, - the first containing the input column names and the second containing the - output column names. - - :returns: Tuple of lists containing the input and output column names. 
- """ - - if hasattr(self, "input_cols") and getattr(self, "input_cols") is not None: - inputs = self.input_cols - elif hasattr(self, "input_col") and getattr(self, "input_col") is not None: - inputs = [self.input_col] - else: - inputs = [] - - if hasattr(self, "output_cols") and getattr(self, "output_cols") is not None: - outputs = self.output_cols - elif hasattr(self, "output_col") and getattr(self, "output_col") is not None: - outputs = [self.output_col] - else: - outputs = [] - - return inputs, outputs diff --git a/src/kamae/sklearn/pipeline/pipeline.py b/src/kamae/sklearn/pipeline/pipeline.py deleted file mode 100644 index 03080937..00000000 --- a/src/kamae/sklearn/pipeline/pipeline.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import joblib -import keras_tuner as kt -import tensorflow as tf -from sklearn.pipeline import Pipeline - -from kamae.graph import PipelineGraph -from kamae.sklearn.transformers import BaseTransformer - - -class KamaeSklearnPipeline(Pipeline): - """ - KamaeSklearnPipeline is a subclass of sklearn.pipeline.Pipeline that is used to - chain together BaseTransformers. It maintains the same functionality - as sklearn.pipeline.Pipeline e.g. serialisation. 
- """ - - def __init__( - self, - steps: List[Tuple[str, BaseTransformer]], - *, - memory: Optional[Union[str, joblib.Memory]] = None, - verbose: bool = False, - ) -> None: - """ - Initializes a KamaeSklearnPipeline object. - - :param steps: List of tuples containing the name and LayerTransformer - :param memory: str or object with the joblib.Memory interface, default=None - Used to cache the fitted transformers of the pipeline. The last step - will never be cached, even if it is a transformer. By default, no - caching is performed. If a string is given, it is the path to the - caching directory. Enabling caching triggers a clone of the transformers - before fitting. Therefore, the transformer instance given to the - pipeline cannot be inspected directly. Use the attribute ``named_steps`` - or ``steps`` to inspect estimators within the pipeline. Caching the - transformers is advantageous when fitting is time consuming. - :param verbose: If True, the time elapsed while fitting each step - will be printed as it is completed. - """ - super().__init__(steps, memory=memory, verbose=verbose) - - def get_all_tf_layers(self) -> List[tf.keras.layers.Layer]: - """ - Gets a list of all tensorflow layers in the pipeline model. - - :returns: List of tensorflow layers within the pipeline model. - """ - return [step[1].get_tf_layer() for step in self.steps] - - def build_keras_model( - self, - tf_input_schema: Union[List[tf.TypeSpec], List[Dict[str, Any]]], - output_names: Optional[List[str]] = None, - ) -> tf.keras.Model: - """ - Builds a keras model from the pipeline using the PipelineGraph - helper class. - - :param tf_input_schema: List of dictionaries containing the input schema for - the model. Specifically the name, shape and dtype of each input. - These will be passed as is to the Keras Input layer. - :param output_names: Optional list of output names for the Keras model. If - provided, only the outputs specified are used as model outputs. - :returns: Keras model. 
- """ - stage_dict = { - step[1].layer_name: step[1].construct_layer_info() for step in self.steps - } - pipeline_graph = PipelineGraph(stage_dict=stage_dict) - return pipeline_graph.build_keras_model( - tf_input_schema=tf_input_schema, output_names=output_names - ) - - def get_keras_tuner_model_builder( - self, - tf_input_schema: Union[List[tf.TypeSpec], List[Dict[str, Any]]], - hp_dict: Dict[str, List[Dict[str, Any]]], - output_names: Optional[List[str]] = None, - ) -> Callable[[kt.HyperParameters], tf.keras.Model]: - """ - Builds a keras tuner model builder (function) from the pipeline model - using the PipelineGraph helper class. - - :param tf_input_schema: List of dictionaries containing the input schema for - the model. Specifically the name, shape and dtype of each input. - These will be passed as is to the Keras Input layer. - :param hp_dict: Dictionary containing the hyperparameters for the model. - :param output_names: Optional list of output names for the Keras model. If - provided, only the outputs specified are used as model outputs. - :returns: Keras tuner model builder (function). - """ - stage_dict = { - step[1].layer_name: step[1].construct_layer_info() for step in self.steps - } - pipeline_graph = PipelineGraph(stage_dict=stage_dict) - return pipeline_graph.get_keras_tuner_model_builder( - tf_input_schema=tf_input_schema, hp_dict=hp_dict, output_names=output_names - ) diff --git a/src/kamae/sklearn/transformers/__init__.py b/src/kamae/sklearn/transformers/__init__.py deleted file mode 100644 index 401391f8..00000000 --- a/src/kamae/sklearn/transformers/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .array_concatenate import ArrayConcatenateTransformer # noqa: F401 -from .array_split import ArraySplitTransformer # noqa: F401 -from .base import BaseTransformer, BaseTransformerMixin # noqa: F401 -from .identity import IdentityTransformer # noqa: F401 -from .log import LogTransformer # noqa: F401 diff --git a/src/kamae/sklearn/transformers/array_concatenate.py b/src/kamae/sklearn/transformers/array_concatenate.py deleted file mode 100644 index 2276dd89..00000000 --- a/src/kamae/sklearn/transformers/array_concatenate.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import List - -import numpy as np -import pandas as pd -import tensorflow as tf - -from kamae.sklearn.params import MultiInputSingleOutputMixin -from kamae.tensorflow.layers import ArrayConcatenateLayer - -from .base import BaseTransformer - - -class ArrayConcatenateTransformer( - BaseTransformer, - MultiInputSingleOutputMixin, -): - """ - Vector Assembler Scikit-Learn Transformer for use in Scikit-Learn pipelines. - This transformer assembles multiple columns into a single array column. - """ - - def __init__(self, input_cols: List[str], output_col: str, layer_name: str) -> None: - super().__init__() - self.input_cols = input_cols - self.output_col = output_col - self.layer_name = layer_name - - def fit(self, X: pd.DataFrame, y: None = None) -> "ArrayConcatenateTransformer": - """ - Fits the transformer to the data. Does nothing since - this is transformer not an estimator. - - :param X: Pandas dataframe to fit the transformer to. - :param y: Not used, present here for API consistency by convention. - :returns: Fit pipeline, in this case the transformer itself. - """ - return self - - def transform(self, X: pd.DataFrame, y: None = None) -> pd.DataFrame: - """ - Transform the input dataset. Creates a new column named outputCol which is a - concatenated array of all input columns. - - :param X: Pandas dataframe to transform. - :param y: Not used, present here for API consistency by convention. - :returns: Transformed data. - """ - - # Check which columns are arrays, this gives a dict like: - # {'col1': True, 'col2': False, 'col3': True} - is_col_an_array_dict = ( - X.head(1)[self.input_cols] - .applymap(lambda x: pd.api.types.is_list_like(x)) - .to_dict(orient="records")[0] - ) - - new_input_cols = [] - for col_name, col_an_array in is_col_an_array_dict.items(): - if col_an_array: - # If the column is an array then we need to create a - # numpy array of arrays - # TODO: Can we make this more this efficient? 
- values = X[col_name].to_numpy() - new_input_cols.append(np.array([np.array(x) for x in values])) - else: - # If the column is not an array then we just need to extend - # the numpy array to have an extra dimension. This is so we can concat - # the arrays later. - values = X[col_name].to_numpy() - new_input_cols.append(values[:, None]) - - # Concatenate the arrays, this creates an N x M array - # where N is the number of rows, M is the number of features - concatenated_array = np.concatenate(new_input_cols, axis=-1) - # Add this to the dataframe, convert the numpy array to a list - # of 1-D numpy arrays - X[self.output_col] = list(concatenated_array) - - return X - - def get_tf_layer(self) -> tf.keras.layers.Layer: - """ - Gets the tensorflow layer that concatenates the input tensors. - - :returns: Tensorflow keras layer with name equal to the layerName parameter - that concatenates the input tensors. - """ - return ArrayConcatenateLayer(name=self.layer_name, axis=-1) diff --git a/src/kamae/sklearn/transformers/array_split.py b/src/kamae/sklearn/transformers/array_split.py deleted file mode 100644 index d9af68ed..00000000 --- a/src/kamae/sklearn/transformers/array_split.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import List - -import pandas as pd -import tensorflow as tf - -from kamae.sklearn.params import SingleInputMultiOutputMixin -from kamae.tensorflow.layers import ArraySplitLayer - -from .base import BaseTransformer - - -class ArraySplitTransformer( - BaseTransformer, - SingleInputMultiOutputMixin, -): - """ - VectorSlicer Scikit-Learn Transformer for use in Scikit-Learn pipelines. - This transformer slices an array column into multiple columns. - """ - - def __init__(self, input_col: str, output_cols: List[str], layer_name: str) -> None: - super().__init__() - self.input_col = input_col - self.output_cols = output_cols - self.layer_name = layer_name - - def fit(self, X: pd.DataFrame, y: None = None) -> "ArraySplitTransformer": - """ - Fits the transformer to the data. Does nothing since - this is transformer not an estimator. - - :param X: Pandas dataframe to fit the transformer to. - :param y: Not used, present here for API consistency by convention. - :returns: Fit pipeline, in this case the transformer itself. - """ - return self - - def transform(self, X: pd.DataFrame, y: None = None) -> pd.DataFrame: - """ - Transforms the input dataset. Creates a new column for each output column equal - to the value of the input column at the given index. - - :param X: Pandas dataframe to transform. - :param y: Not used, present here for API consistency by convention. - :returns: Transformed data. - """ - X[self.output_cols] = pd.DataFrame(X[self.input_col].tolist(), index=X.index) - return X - - def get_tf_layer(self) -> tf.keras.layers.Layer: - """ - Gets the tensorflow layer for that unstacks the input tensor and reshapes - to the original shape. - - :returns: Tensorflow keras layer with name equal to the layerName parameter - that slices the input tensors. 
- """ - return ArraySplitLayer(name=self.layer_name, axis=-1) diff --git a/src/kamae/sklearn/transformers/base.py b/src/kamae/sklearn/transformers/base.py deleted file mode 100644 index c2b1aaaa..00000000 --- a/src/kamae/sklearn/transformers/base.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Union - -import tensorflow as tf -from sklearn.base import BaseEstimator, TransformerMixin - -from kamae.sklearn.params import InputOutputExtractor, LayerNameMixin - - -class BaseTransformerMixin(ABC, LayerNameMixin, InputOutputExtractor): - """ - Mixin abstract class defining methods needed for all kamae scikit-learn - transformers. - """ - - def __init__(self, **kwargs: Any) -> None: - """ - Initializes the transformer. - """ - super().__init__() - - @abstractmethod - def get_tf_layer(self) -> Union[tf.keras.layers.Layer, List[tf.keras.layers.Layer]]: - """ - Gets the tensorflow layer to be used in the model. - This is the only abstract method that must be implemented. - :returns: Tensorflow Layer - """ - raise NotImplementedError - - def construct_layer_info(self) -> Dict[str, Any]: - """ - Constructs the layer info dictionary. - Contains the layer name, the tensorflow layer, and the inputs and outputs. - This is used when constructing the pipeline graph. 
- - :returns: Dictionary containing layer information such as - name, tensorflow layer, inputs, and outputs. - """ - inputs, outputs = self.get_layer_inputs_outputs() - return { - "name": self.layer_name, - "layer": self.get_tf_layer(), - "inputs": inputs, - "outputs": outputs, - } - - -class BaseTransformer(BaseTransformerMixin, BaseEstimator, TransformerMixin, ABC): - """ - Abstract class for all scikit-learn transformers. Specifically, this class extends - the required scikit-learn classes BaseEstimator and TransformerMixin adding in the - kamae BaseTransformerMixin which defines the methods needed to work with the kamae - pipeline graph. - - The reason we keep this separate from the BaseTransformerMixin (which is not done - for Spark) is because on the scikit-learn side we want to allow the ability to - inherit from existing scikit-learn classes (such as the StandardScaler). In these - cases the existing class already inherits from BaseEstimator and TransformerMixin - and so only needs the BaseTransformerMixin (to add kamae specific functionality). - If you try and inherit these classes twice (once from the existing scikit-learn - class and once from BaseTransformer) you will get an error. Therefore, we keep - these separate. - - If you are building an entirely new transformer, then you can inherit from this - class directly, to save you from having to inherit from BaseEstimator and - TransformerMixin. - - In Spark, all existing (core) implementations are built in Scala and ported to - Python. In this case, the ability to re-use existing Spark transformers is very - difficult and not worth the effort. You can see that for the StandardScaleEstimator - the logic does not depend on the existing Spark StandardScaler. - - Therefore, we have a single BaseTransformer class for use by all Spark - transformers. 
- """ diff --git a/src/kamae/sklearn/transformers/identity.py b/src/kamae/sklearn/transformers/identity.py deleted file mode 100644 index b37ca267..00000000 --- a/src/kamae/sklearn/transformers/identity.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pandas as pd -import tensorflow as tf - -from kamae.sklearn.params import SingleInputSingleOutputMixin -from kamae.tensorflow.layers import IdentityLayer - -from .base import BaseTransformer - - -class IdentityTransformer(BaseTransformer, SingleInputSingleOutputMixin): - """ - Identity Scikit-Learn Transformer for use in Scikit-Learn pipelines. - This transformer simply passes the input to the output unchanged. - Used for cases where you want to keep the input the same. - """ - - def __init__(self, input_col: str, output_col: str, layer_name: str) -> None: - """ - Intializes an IdentityTransformer transformer. - - :param input_col: Input column name. - :param output_col: Output column name. - :param layer_name: Name of the layer. Used as the name of the tensorflow layer - in the keras model. - :returns: None - class instantialized. - """ - super().__init__() - self.input_col = input_col - self.output_col = output_col - self.layer_name = layer_name - - def fit(self, X: pd.DataFrame, y: None = None) -> "IdentityTransformer": - """ - Fits the transformer to the data. Does nothing since - this is an identity transformer. 
- - :param X: Pandas dataframe to fit the transformer to. - :param y: Not used, present here for API consistency by convention. - :returns: Fit pipeline, in this case the transformer itself. - """ - return self - - def transform(self, X: pd.DataFrame, y: None = None) -> pd.DataFrame: - """ - Transforms the data using the transformer. Creates a new column with name - `output_col`, which is the same as the `input_col`. - - :param X: Pandas dataframe to transform. - :param y: Not used, present here for API consistency by convention. - :returns: Transformed data. - """ - X[self.output_col] = X[self.input_col] - return X - - def get_tf_layer(self) -> tf.keras.layers.Layer: - """ - Gets the tensorflow layer for the identity transformer. - - :returns: Tensorflow keras layer with name equal to the layerName parameter that - performs an Identity operation. - """ - return IdentityLayer( - name=self.layer_name, - ) diff --git a/src/kamae/sklearn/transformers/log.py b/src/kamae/sklearn/transformers/log.py deleted file mode 100644 index addb8691..00000000 --- a/src/kamae/sklearn/transformers/log.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Optional - -import numpy as np -import pandas as pd -import tensorflow as tf - -from kamae.sklearn.params import SingleInputSingleOutputMixin -from kamae.tensorflow.layers import LogLayer - -from .base import BaseTransformer - - -class LogTransformer(BaseTransformer, SingleInputSingleOutputMixin): - """ - Log Scikit-Learn Transformer for use in Scikit-Learn pipelines. - This transformer applies a log(alpha + x) transform to the input column. - """ - - def __init__( - self, - input_col: str, - output_col: str, - layer_name: str, - alpha: Optional[float] = None, - ) -> None: - """ - Intializes a LogTransformLayer transformer. Sets the default values of: - - - alpha: 1 - - :param input_col: Input column name. - :param output_col: Output column name. - :param layer_name: Name of the layer. Used as the name of the tensorflow layer - :param alpha: Value to use in log transform: log(alpha + x). Default is 1. - :returns: None - class instantialized. - """ - super().__init__() - self.input_col = input_col - self.output_col = output_col - self.layer_name = layer_name - self.alpha = float(alpha) if alpha is not None else 1.0 - - def fit(self, X: pd.DataFrame, y: None = None) -> "LogTransformer": - """ - Fits the transformer. Does nothing since this is just a transformer. - - :param X: Pandas dataframe to fit the transformer to. - :param y: Not used, present here for API consistency by convention. - :returns: Fit pipeline, in this case the transformer itself. - """ - return self - - def transform(self, X: pd.DataFrame, y: None = None) -> pd.DataFrame: - """ - Transforms the data using the transformer. Creates a new column with name - `output_col`, which applies log(alpha + x) transform to the `input_col`. - - :param X: Pandas dataframe to transform. - :param y: Not used, present here for API consistency by convention. - :returns: Transformed data. 
- """ - X[self.output_col] = np.log(X[self.input_col] + self.alpha) - return X - - def get_tf_layer(self) -> tf.keras.layers.Layer: - """ - Gets the tensorflow layer that performs the log transform. - - :returns: Tensorflow keras layer with name equal to the layerName parameter - that performs the log(alpha + x) operation. - """ - alpha = self.alpha - return LogLayer(name=self.layer_name, alpha=alpha) diff --git a/src/kamae/spark/common/spark_operation.py b/src/kamae/spark/common/spark_operation.py index b07dd47c..625cb6da 100644 --- a/src/kamae/spark/common/spark_operation.py +++ b/src/kamae/spark/common/spark_operation.py @@ -22,6 +22,7 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, NumericType +from kamae.keras.core.backend import validate_backend from kamae.spark.params import ( HasInputDtype, HasLayerName, @@ -42,10 +43,14 @@ class SparkOperation( param setting, input/output dtype casting, and layer name setting. """ + supported_backends: frozenset + jit_compatible: bool + def __init__(self) -> None: """ Initializes the spark operation class. """ + validate_backend(self.__class__.__name__, self.supported_backends) super().__init__() self._setDefault(layerName=self.uid, inputDtype=None, outputDtype=None) self.tmp_column_suffix = self.generate_tmp_column_suffix() diff --git a/src/kamae/spark/estimators/conditional_standard_scale.py b/src/kamae/spark/estimators/conditional_standard_scale.py index a0b50f45..b6edfb2f 100644 --- a/src/kamae/spark/estimators/conditional_standard_scale.py +++ b/src/kamae/spark/estimators/conditional_standard_scale.py @@ -26,6 +26,7 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import ArrayType, DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS from kamae.spark.params import ( NanFillValueParams, SampleFractionParams, @@ -237,6 +238,9 @@ class ConditionalStandardScaleEstimator( shape across all rows. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, diff --git a/src/kamae/spark/estimators/impute.py b/src/kamae/spark/estimators/impute.py index abfc1814..56a4e83e 100644 --- a/src/kamae/spark/estimators/impute.py +++ b/src/kamae/spark/estimators/impute.py @@ -23,6 +23,7 @@ from pyspark.sql import DataFrame from pyspark.sql.types import ArrayType, DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS from kamae.spark.params import ( ImputeMethodParams, MaskValueParams, @@ -51,6 +52,9 @@ class ImputeEstimator( either null or equal to the supplied mask value. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, diff --git a/src/kamae/spark/estimators/min_max_scale.py b/src/kamae/spark/estimators/min_max_scale.py index 872d6c34..36a6430f 100644 --- a/src/kamae/spark/estimators/min_max_scale.py +++ b/src/kamae/spark/estimators/min_max_scale.py @@ -24,6 +24,7 @@ from pyspark.sql import DataFrame from pyspark.sql.types import ArrayType, DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS from kamae.spark.params import ( MaskValueParams, SampleFractionParams, @@ -51,6 +52,9 @@ class MinMaxScaleEstimator( shape across all rows. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, diff --git a/src/kamae/spark/estimators/one_hot_encode.py b/src/kamae/spark/estimators/one_hot_encode.py index 502642ca..80d18dbd 100644 --- a/src/kamae/spark/estimators/one_hot_encode.py +++ b/src/kamae/spark/estimators/one_hot_encode.py @@ -23,6 +23,7 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, IntegerType, LongType, ShortType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY from kamae.spark.params import ( DropUnseenParams, SingleInputSingleOutputParams, @@ -48,6 +49,9 @@ class OneHotEncodeEstimator( same string labels. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, diff --git a/src/kamae/spark/estimators/shared_one_hot_encode.py b/src/kamae/spark/estimators/shared_one_hot_encode.py index 45e9a4d6..43827c8f 100644 --- a/src/kamae/spark/estimators/shared_one_hot_encode.py +++ b/src/kamae/spark/estimators/shared_one_hot_encode.py @@ -23,6 +23,7 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, IntegerType, LongType, ShortType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY from kamae.spark.params import ( DropUnseenParams, MultiInputMultiOutputParams, @@ -48,6 +49,9 @@ class SharedOneHotEncodeEstimator( same string labels. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, diff --git a/src/kamae/spark/estimators/shared_string_index.py b/src/kamae/spark/estimators/shared_string_index.py index 4bbd3489..78110a4c 100644 --- a/src/kamae/spark/estimators/shared_string_index.py +++ b/src/kamae/spark/estimators/shared_string_index.py @@ -23,6 +23,7 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY from kamae.spark.params import MultiInputMultiOutputParams, StringIndexParams from kamae.spark.transformers import SharedStringIndexTransformer from kamae.spark.utils import collect_labels_array_from_multiple_columns @@ -43,6 +44,9 @@ class SharedStringIndexEstimator( to index additional feature columns using the same string labels. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, diff --git a/src/kamae/spark/estimators/single_feature_array_standard_scale.py b/src/kamae/spark/estimators/single_feature_array_standard_scale.py index 5e55c9c5..4c209893 100644 --- a/src/kamae/spark/estimators/single_feature_array_standard_scale.py +++ b/src/kamae/spark/estimators/single_feature_array_standard_scale.py @@ -20,6 +20,7 @@ from pyspark.sql import DataFrame from pyspark.sql.types import ArrayType, DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS from kamae.spark.params import ( MaskValueParams, SampleFractionParams, @@ -47,6 +48,9 @@ class SingleFeatureArrayStandardScaleEstimator( and standard deviation are calculated across all elements in all the arrays. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, diff --git a/src/kamae/spark/estimators/standard_scale.py b/src/kamae/spark/estimators/standard_scale.py index 178ac662..a1c654ea 100644 --- a/src/kamae/spark/estimators/standard_scale.py +++ b/src/kamae/spark/estimators/standard_scale.py @@ -24,6 +24,7 @@ from pyspark.sql import DataFrame from pyspark.sql.types import ArrayType, DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS from kamae.spark.params import ( MaskValueParams, SampleFractionParams, @@ -51,6 +52,9 @@ class StandardScaleEstimator( shape across all rows. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, diff --git a/src/kamae/spark/estimators/string_index.py b/src/kamae/spark/estimators/string_index.py index 32a1688e..a568a979 100644 --- a/src/kamae/spark/estimators/string_index.py +++ b/src/kamae/spark/estimators/string_index.py @@ -23,6 +23,7 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY from kamae.spark.params import SingleInputSingleOutputParams, StringIndexParams from kamae.spark.transformers import StringIndexTransformer from kamae.spark.utils import collect_labels_array @@ -42,6 +43,9 @@ class StringIndexEstimator( to index additional feature columns using the same string labels. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, diff --git a/src/kamae/spark/params/base.py b/src/kamae/spark/params/base.py index 154d0c54..071305ad 100644 --- a/src/kamae/spark/params/base.py +++ b/src/kamae/spark/params/base.py @@ -68,17 +68,17 @@ def getInputDtype(self) -> str: """ return self.getOrDefault(self.inputDtype) - def getInputTFDtype(self) -> Optional[str]: + def getInputKerasDtype(self) -> Optional[str]: """ - Gets the tensorflow datatype string from the inputDtype parameter. - Uses the DType enum within Kamae to map the inputDtype to the tensorflow + Gets the Keras datatype string from the inputDtype parameter. + Uses the DType enum within Kamae to map the inputDtype to the Keras datatype string. - :returns: String of the tensorflow datatype. + :returns: String of the Keras datatype. 
""" input_dtype = self.getInputDtype() if input_dtype is None: return None - dtypes_map = {dtype.dtype_name: dtype.tf_dtype.name for dtype in DType} + dtypes_map = {dtype.dtype_name: dtype.keras_dtype for dtype in DType} return dtypes_map[input_dtype] @@ -117,18 +117,18 @@ def getOutputDtype(self) -> str: """ return self.getOrDefault(self.outputDtype) - def getOutputTFDtype(self) -> Optional[str]: + def getOutputKerasDtype(self) -> Optional[str]: """ - Gets the tensorflow datatype string from the outputDtype parameter. - Uses the DType enum within Kamae to map the outputDtype to the tensorflow + Gets the Keras datatype string from the outputDtype parameter. + Uses the DType enum within Kamae to map the outputDtype to the Keras datatype string. - :returns: String of the tensorflow datatype. + :returns: String of the Keras datatype. """ output_dtype = self.getOutputDtype() if output_dtype is None: return None - dtypes_map = {dtype.dtype_name: dtype.tf_dtype.name for dtype in DType} + dtypes_map = {dtype.dtype_name: dtype.keras_dtype for dtype in DType} return dtypes_map[output_dtype] diff --git a/src/kamae/spark/pipeline/pipeline_model.py b/src/kamae/spark/pipeline/pipeline_model.py index 63132f33..ef6125f7 100644 --- a/src/kamae/spark/pipeline/pipeline_model.py +++ b/src/kamae/spark/pipeline/pipeline_model.py @@ -14,8 +14,8 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union, cast +import keras import keras_tuner as kt -import tensorflow as tf from pyspark.ml import PipelineModel from pyspark.ml.pipeline import ( PipelineModelReader, @@ -78,13 +78,13 @@ def read(cls) -> "KamaeSparkPipelineModelReader": """ return KamaeSparkPipelineModelReader(cls) - def get_all_tf_layers(self) -> List[tf.keras.layers.Layer]: + def get_all_keras_layers(self) -> List[keras.layers.Layer]: """ - Gets a list of all tensorflow layers in the pipeline model. + Gets a list of all Keras layers in the pipeline model. 
- :returns: List of tensorflow layers within the pipeline model. + :returns: List of Keras layers within the pipeline model. """ - return [stage.get_tf_layer() for stage in self.stages] + return [stage.get_keras_layer() for stage in self.stages] def expand_pipeline_stages(self) -> List[BaseTransformer]: """ @@ -105,14 +105,14 @@ def expand_pipeline_stages(self) -> List[BaseTransformer]: def build_keras_model( self, - tf_input_schema: Union[List[tf.TypeSpec], List[Dict[str, Any]]], + input_schema: List[Dict[str, Any]], output_names: Optional[List[str]] = None, - ) -> tf.keras.Model: + ) -> keras.Model: """ Builds a keras model from the pipeline model using the PipelineGraph helper class. - :param tf_input_schema: List of dictionaries containing the input schema for + :param input_schema: List of dictionaries containing the input schema for the model. Specifically the name, shape and dtype of each input. These will be passed as is to the Keras Input layer. :param output_names: Optional list of output names for the Keras model. If @@ -125,20 +125,20 @@ def build_keras_model( } pipeline_graph = PipelineGraph(stage_dict=stage_dict) return pipeline_graph.build_keras_model( - tf_input_schema=tf_input_schema, output_names=output_names + input_schema=input_schema, output_names=output_names ) def get_keras_tuner_model_builder( self, - tf_input_schema: Union[List[tf.TypeSpec], List[Dict[str, Any]]], + input_schema: List[Dict[str, Any]], hp_dict: Dict[str, List[Dict[str, Any]]], output_names: Optional[List[str]] = None, - ) -> Callable[[kt.HyperParameters], tf.keras.Model]: + ) -> Callable[[kt.HyperParameters], keras.Model]: """ Builds a keras tuner model builder (function) from the pipeline model using the PipelineGraph helper class. - :param tf_input_schema: List of dictionaries containing the input schema for + :param input_schema: List of dictionaries containing the input schema for the model. Specifically the name, shape and dtype of each input. 
These will be passed as is to the Keras Input layer. :param hp_dict: Dictionary containing the hyperparameters for the model. @@ -152,7 +152,7 @@ def get_keras_tuner_model_builder( } pipeline_graph = PipelineGraph(stage_dict=stage_dict) return pipeline_graph.get_keras_tuner_model_builder( - tf_input_schema=tf_input_schema, hp_dict=hp_dict, output_names=output_names + input_schema=input_schema, hp_dict=hp_dict, output_names=output_names ) diff --git a/src/kamae/spark/transformers/absolute_value.py b/src/kamae/spark/transformers/absolute_value.py index e1d16a23..9d1ce929 100644 --- a/src/kamae/spark/transformers/absolute_value.py +++ b/src/kamae/spark/transformers/absolute_value.py @@ -18,8 +18,8 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import ( @@ -32,9 +32,10 @@ ShortType, ) +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import AbsoluteValueLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import AbsoluteValueLayer from .base import BaseTransformer @@ -48,6 +49,9 @@ class AbsoluteValueTransformer( This transformer applies abs(x) operation to the input. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -66,7 +70,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. 
""" @@ -109,15 +113,15 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the absolute value transformer. + Gets the Keras layer for the absolute value transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs an absolute value operation. """ return AbsoluteValueLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), ) diff --git a/src/kamae/spark/transformers/array_concatenate.py b/src/kamae/spark/transformers/array_concatenate.py index 1d9a1f41..ab12359e 100644 --- a/src/kamae/spark/transformers/array_concatenate.py +++ b/src/kamae/spark/transformers/array_concatenate.py @@ -18,12 +18,14 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import Column, DataFrame from pyspark.sql.types import ArrayType, DataType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import ArrayConcatenateLayer from kamae.spark.params import AutoBroadcastParams, MultiInputSingleOutputParams from kamae.spark.utils import ( broadcast_scalar_column_to_array_with_inner_singleton_array, @@ -31,7 +33,6 @@ nested_arrays_zip, nested_transform, ) -from kamae.tensorflow.layers import ArrayConcatenateLayer from .base import BaseTransformer @@ -46,6 +47,9 @@ class ArrayConcatenateTransformer( This transformer assembles multiple columns into a single array column. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -65,7 +69,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param autoBroadcast: If True, the Keras transformer will broadcast scalar inputs to the biggest rank. Default is False. @@ -275,17 +279,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer that concatneates the input tensors. + Gets the Keras layer that concatneates the input tensors. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that concatenates the input tensors. 
""" return ArrayConcatenateLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), axis=-1, auto_broadcast=self.getAutoBroadcast(), ) diff --git a/src/kamae/spark/transformers/array_crop.py b/src/kamae/spark/transformers/array_crop.py index 1dc6d319..140caed1 100644 --- a/src/kamae/spark/transformers/array_crop.py +++ b/src/kamae/spark/transformers/array_crop.py @@ -14,19 +14,20 @@ from typing import List, Optional, Union +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, TypeConverters from pyspark.sql import DataFrame from pyspark.sql.types import BooleanType, DataType, FloatType, IntegerType, StringType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import ArrayCropLayer from kamae.spark.params import PadValueParams, SingleInputSingleOutputParams from kamae.spark.utils import ( get_array_nesting_level_and_element_dtype, single_input_single_output_array_transform, ) -from kamae.tensorflow.layers import ArrayCropLayer from .base import BaseTransformer @@ -73,6 +74,9 @@ class ArrayCropTransformer( padded with specified pad value. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -92,7 +96,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer :param arrayLength: The length to crop or pad the arrays to. Defaults to 128. :param padValue: The value pad the arrays with. Defaults to `None`. 
:returns: None @@ -201,17 +205,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer that performs the array cropping and padding. + Gets the Keras layer that performs the array cropping and padding. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that performs the array cropping and padding operation. """ return ArrayCropLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), array_length=self.getArrayLength(), pad_value=self.getPadValue(), ) diff --git a/src/kamae/spark/transformers/array_reduce_max.py b/src/kamae/spark/transformers/array_reduce_max.py index 0b498da7..38d5441e 100644 --- a/src/kamae/spark/transformers/array_reduce_max.py +++ b/src/kamae/spark/transformers/array_reduce_max.py @@ -14,16 +14,17 @@ from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import ArrayReduceMaxLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_array_transform -from kamae.tensorflow.layers import ArrayReduceMaxLayer from .base import BaseTransformer @@ -41,6 +42,9 @@ class ArrayReduceMaxTransformer( Returns defaultValue when the array is empty or null. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + defaultValue = Param( Params._dummy(), "defaultValue", @@ -86,10 +90,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: + """ + Gets the Keras layer that reduces an array to its maximum element. + + :returns: Keras layer with name equal to the layerName parameter + that performs the array reduce max operation. + """ return ArrayReduceMaxLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), default_value=self.getDefaultValue(), ) diff --git a/src/kamae/spark/transformers/array_split.py b/src/kamae/spark/transformers/array_split.py index 6ef35ffa..9f1da58a 100644 --- a/src/kamae/spark/transformers/array_split.py +++ b/src/kamae/spark/transformers/array_split.py @@ -18,15 +18,16 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import DataType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import ArraySplitLayer from kamae.spark.params import SingleInputMultiOutputParams from kamae.spark.utils import single_input_single_output_array_transform -from kamae.tensorflow.layers import ArraySplitLayer from .base import BaseTransformer @@ -40,6 +41,9 @@ class ArraySplitTransformer( This transformer splits an array column into multiple columns. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -58,7 +62,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column(s) to after transforming. 
- :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -99,17 +103,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: select_cols = original_columns + output_cols return dataset.select(select_cols) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for that unstacks the input tensor and reshapes + Gets the Keras layer for that unstacks the input tensor and reshapes to the original shape. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that slices the input tensors. """ return ArraySplitLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), axis=-1, ) diff --git a/src/kamae/spark/transformers/array_subtract_minimum.py b/src/kamae/spark/transformers/array_subtract_minimum.py index 3e6f6f65..f1472389 100644 --- a/src/kamae/spark/transformers/array_subtract_minimum.py +++ b/src/kamae/spark/transformers/array_subtract_minimum.py @@ -14,8 +14,8 @@ from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import Column, DataFrame @@ -30,9 +30,10 @@ ShortType, ) +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import ArraySubtractMinimumLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_array_transform -from kamae.tensorflow.layers import ArraySubtractMinimumLayer from .base 
import BaseTransformer @@ -81,6 +82,9 @@ class ArraySubtractMinimumTransformer( The main use case in mind for this is working with an array of timestamps. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -100,7 +104,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer :param padValue: The value to be considered as padding. Defaults to `None`. :returns: None """ @@ -180,16 +184,16 @@ def array_subtract_min(x: Column, pad_value: Optional[float]) -> Column: ) return dataset.withColumn(self.getOutputCol(), array_subtract) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the sequential difference transformer. + Gets the Keras layer for the sequential difference transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs the sequential difference operation. 
""" return ArraySubtractMinimumLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), pad_value=self.getPadValue(), ) diff --git a/src/kamae/spark/transformers/base.py b/src/kamae/spark/transformers/base.py index f819f1b1..602cd483 100644 --- a/src/kamae/spark/transformers/base.py +++ b/src/kamae/spark/transformers/base.py @@ -15,7 +15,7 @@ from abc import abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -import tensorflow as tf +import keras from pyspark.ml import Transformer from pyspark.sql import DataFrame @@ -89,27 +89,29 @@ def transform( ).with_traceback(e.__traceback__) @abstractmethod - def get_tf_layer(self) -> Union[tf.keras.layers.Layer, List[tf.keras.layers.Layer]]: + def get_keras_layer( + self, + ) -> Union[keras.layers.Layer, List[keras.layers.Layer]]: """ - Gets the tensorflow layer to be used in the model. + Gets the Keras layer to be used in the model. This is the only abstract method that must be implemented. - :returns: Tensorflow Layer + :returns: Keras Layer """ raise NotImplementedError def construct_layer_info(self) -> Dict[str, Any]: """ Constructs the layer info dictionary. - Contains the layer name, the tensorflow layer, and the inputs and outputs. + Contains the layer name, the Keras layer, and the inputs and outputs. This is used when constructing the pipeline graph. :returns: Dictionary containing layer information such as - name, tensorflow layer, inputs, and outputs. + name, Keras layer, inputs, and outputs. 
""" inputs, outputs = self.get_layer_inputs_outputs() return { "name": self.getOrDefault("layerName"), - "layer": self.get_tf_layer(), + "layer": self.get_keras_layer(), "inputs": inputs, "outputs": outputs, } diff --git a/src/kamae/spark/transformers/bearing_angle.py b/src/kamae/spark/transformers/bearing_angle.py index 330195ff..9479f30f 100644 --- a/src/kamae/spark/transformers/bearing_angle.py +++ b/src/kamae/spark/transformers/bearing_angle.py @@ -19,15 +19,16 @@ import math from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import BearingAngleLayer from kamae.spark.params import LatLonConstantParams, MultiInputSingleOutputParams from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import BearingAngleLayer from .base import BaseTransformer @@ -72,6 +73,9 @@ class BearingAngleTransformer( are out of bounds. For lat, this is [-90, 90] and for lon, this is [-180, 180]. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -93,7 +97,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param latLonConstant: Optional list of lat/lon constant to use. Must be in the order [lat, lon]. 
@@ -218,15 +222,15 @@ def bearing_calculate_transform( return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the bearing angle transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + Gets the Keras layer for the bearing angle transformer. + :returns: Keras layer with name equal to the layerName parameter that computes the bearing angle between two lat/lon pairs. """ return BearingAngleLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), lat_lon_constant=self.getLatLonConstant(), ) diff --git a/src/kamae/spark/transformers/bin.py b/src/kamae/spark/transformers/bin.py index bd360e82..511ce8cb 100644 --- a/src/kamae/spark/transformers/bin.py +++ b/src/kamae/spark/transformers/bin.py @@ -18,8 +18,8 @@ # pylint: disable=no-member from typing import Any, List, Optional, Union +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import Column, DataFrame @@ -33,9 +33,10 @@ ShortType, ) +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import BinLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import BinLayer from kamae.utils import get_condition_operator from .base import BaseTransformer @@ -203,6 +204,9 @@ class BinTransformer( If no conditions evaluate to True, the default label is returned. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -236,7 +240,7 @@ def __init__( :param binValues: Float values to compare to input column. :param binLabels: Bin labels to use when binning. :param defaultLabel: Default label to use when binning. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -305,17 +309,17 @@ def bin_func(x: Column) -> Column: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the bin transformer. + Gets the Keras layer for the bin transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs the binning operation. 
""" return BinLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), condition_operators=self.getConditionOperators(), bin_values=self.getBinValues(), bin_labels=self.getBinLabels(), diff --git a/src/kamae/spark/transformers/bloom_encode.py b/src/kamae/spark/transformers/bloom_encode.py index 3865e981..2b1caa49 100644 --- a/src/kamae/spark/transformers/bloom_encode.py +++ b/src/kamae/spark/transformers/bloom_encode.py @@ -25,13 +25,14 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import ArrayType, DataType, IntegerType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import BloomEncodeLayer from kamae.spark.params import HashIndexParams, SingleInputSingleOutputParams from kamae.spark.utils import ( hash_udf, single_input_single_output_array_udf_transform, single_input_single_output_scalar_transform, ) -from kamae.tensorflow.layers import BloomEncodeLayer from .base import BaseTransformer @@ -128,6 +129,9 @@ class BloomEncodeTransformer( See paper for more details: https://arxiv.org/pdf/1706.03993.pdf """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -151,7 +155,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param numHashFns: Number of hash functions to use. Defaults to 3. The paper suggests a range of 2-4 hash functions for optimal performance. 
@@ -254,17 +258,17 @@ def bloom_encode(x: List[str]) -> List[int]: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer that performs the bloom encoding. + Gets the Keras layer that performs the bloom encoding. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that performs the bloom encoding operation. """ return BloomEncodeLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), num_hash_fns=self.getNumHashFns(), num_bins=self.getNumBins(), mask_value=self.getMaskValue(), diff --git a/src/kamae/spark/transformers/bucketize.py b/src/kamae/spark/transformers/bucketize.py index 20065a29..b639f0cc 100644 --- a/src/kamae/spark/transformers/bucketize.py +++ b/src/kamae/spark/transformers/bucketize.py @@ -26,11 +26,12 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType, IntegerType, LongType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import BucketizeLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils.transform_utils import ( single_input_single_output_scalar_udf_transform, ) -from kamae.tensorflow.layers import BucketizeLayer from .base import BaseTransformer @@ -89,6 +90,10 @@ class BucketizeTransformer( The 0 index is reserved for masking/padding. """ + jit_compatible = True + + supported_backends = TENSORFLOW_ONLY + @keyword_only def __init__( self, @@ -108,7 +113,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. 
Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param splits: List of float values to use for bucketing. :returns: None - class instantiated. @@ -160,16 +165,16 @@ def bucketize(value: Optional[Union[float, int]]) -> Optional[int]: output_col, ) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the BucketizeLayer transformer. + Gets the Keras layer for the BucketizeLayer transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a bucketing operation. """ return BucketizeLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), splits=self.getSplits(), ) diff --git a/src/kamae/spark/transformers/conditional_standard_scale.py b/src/kamae/spark/transformers/conditional_standard_scale.py index eaea4570..867a13b8 100644 --- a/src/kamae/spark/transformers/conditional_standard_scale.py +++ b/src/kamae/spark/transformers/conditional_standard_scale.py @@ -18,20 +18,21 @@ # pylint: disable=no-member from typing import List, Optional +import keras import numpy as np import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import ArrayType, DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import ConditionalStandardScaleLayer from kamae.spark.params import ( SingleInputSingleOutputParams, StandardScaleSkipZerosParams, ) from kamae.spark.transformers.standard_scale import StandardScaleParams from kamae.spark.utils.transform_utils import 
single_input_single_output_array_transform -from kamae.tensorflow.layers import ConditionalStandardScaleLayer from .base import BaseTransformer @@ -54,6 +55,9 @@ class ConditionalStandardScaleTransformer( shape across all rows. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -76,7 +80,7 @@ def __init__( :param inputCol: Input column name to standardize. :param outputCol: Output column name. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. :param inputDtype: Input data type to cast input column to before transforming. @@ -152,19 +156,19 @@ def _transform(self, dataset: DataFrame) -> DataFrame: output_col = output_col.getItem(0) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the standard scaler transformer. + Gets the Keras layer for the standard scaler transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that performs the standardization. 
""" np_mean = np.array(self.getMean()) np_variance = np.array(self.getStddev()) ** 2 return ConditionalStandardScaleLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), mean=np_mean, variance=np_variance, skip_zeros=self.getSkipZeros(), diff --git a/src/kamae/spark/transformers/cosine_similarity.py b/src/kamae/spark/transformers/cosine_similarity.py index 178f0c06..9db45583 100644 --- a/src/kamae/spark/transformers/cosine_similarity.py +++ b/src/kamae/spark/transformers/cosine_similarity.py @@ -18,15 +18,16 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import Column, DataFrame from pyspark.sql.types import ArrayType, DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import CosineSimilarityLayer from kamae.spark.params import MultiInputSingleOutputParams from kamae.spark.utils import multi_input_single_output_array_transform -from kamae.tensorflow.layers import CosineSimilarityLayer from .base import BaseTransformer @@ -40,6 +41,9 @@ class CosineSimilarityTransformer( This transformer computes the cosine similarity between two array columns. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -58,7 +62,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. 
""" @@ -141,17 +145,17 @@ def norm(x: Column, col_name: str) -> Column: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the cosine similarity transformer. + Gets the Keras layer for the cosine similarity transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that computes the cosine similarity between two arrays. """ return CosineSimilarityLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), axis=-1, keepdims=True, ) diff --git a/src/kamae/spark/transformers/current_date.py b/src/kamae/spark/transformers/current_date.py index 4b6c3eeb..0b68d74f 100644 --- a/src/kamae/spark/transformers/current_date.py +++ b/src/kamae/spark/transformers/current_date.py @@ -24,10 +24,11 @@ from pyspark.sql import Column, DataFrame, SparkSession from pyspark.sql.types import DataType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import CurrentDateLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.transformers.base import BaseTransformer from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import CurrentDateLayer class CurrentDateTransformer(BaseTransformer, SingleInputSingleOutputParams): @@ -35,6 +36,9 @@ class CurrentDateTransformer(BaseTransformer, SingleInputSingleOutputParams): Returns the current UTC date in yyyy-MM-dd format. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -53,7 +57,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. 
- :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -113,14 +117,14 @@ def current_utc_date() -> Column: return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer. + Gets the Keras layer. - :returns: CurrentDateLayer Tensorflow layer. + :returns: CurrentDateLayer Keras layer. """ return CurrentDateLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), ) diff --git a/src/kamae/spark/transformers/current_date_time.py b/src/kamae/spark/transformers/current_date_time.py index 59827ad8..b6bf08b6 100644 --- a/src/kamae/spark/transformers/current_date_time.py +++ b/src/kamae/spark/transformers/current_date_time.py @@ -24,10 +24,11 @@ from pyspark.sql import Column, DataFrame, SparkSession from pyspark.sql.types import DataType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import CurrentDateTimeLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.transformers.base import BaseTransformer from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import CurrentDateTimeLayer class CurrentDateTimeTransformer(BaseTransformer, SingleInputSingleOutputParams): @@ -42,6 +43,9 @@ class CurrentDateTimeTransformer(BaseTransformer, SingleInputSingleOutputParams) It is recommended not to rely on parity at the millisecond level. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -60,7 +64,7 @@ def __init__( transforming. 
:param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -123,14 +127,14 @@ def current_utc_timestamp() -> Column: return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer. + Gets the Keras layer. - :returns: CurrentDateTimeLayer Tensorflow layer. + :returns: CurrentDateTimeLayer Keras layer. """ return CurrentDateTimeLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), ) diff --git a/src/kamae/spark/transformers/current_unix_timestamp.py b/src/kamae/spark/transformers/current_unix_timestamp.py index 099c621b..28afe920 100644 --- a/src/kamae/spark/transformers/current_unix_timestamp.py +++ b/src/kamae/spark/transformers/current_unix_timestamp.py @@ -24,10 +24,11 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import CurrentUnixTimestampLayer from kamae.spark.params import SingleInputSingleOutputParams, UnixTimestampParams from kamae.spark.transformers.base import BaseTransformer from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import CurrentUnixTimestampLayer class CurrentUnixTimestampTransformer( @@ -45,6 +46,9 @@ class CurrentUnixTimestampTransformer( It is recommended not to rely on parity at the millisecond level. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -64,7 +68,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param unit: Unit of the output timestamp. Can be either "s" (or "seconds") for seconds or "ms" (or "milliseconds") for milliseconds. Defaults to "s". @@ -129,15 +133,15 @@ def current_unix_timestamp() -> Column: return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer. + Gets the Keras layer. - :returns: CurrentUnixTimestampLayer Tensorflow layer. + :returns: CurrentUnixTimestampLayer Keras layer. """ return CurrentUnixTimestampLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), unit=self.getUnit(), ) diff --git a/src/kamae/spark/transformers/date_add.py b/src/kamae/spark/transformers/date_add.py index e1b66b26..405350f2 100644 --- a/src/kamae/spark/transformers/date_add.py +++ b/src/kamae/spark/transformers/date_add.py @@ -31,6 +31,8 @@ StringType, ) +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import DateAddLayer from kamae.spark.params import ( MultiInputSingleOutputParams, SingleInputSingleOutputParams, @@ -40,7 +42,6 @@ get_element_type, multi_input_single_output_scalar_transform, ) -from kamae.tensorflow.layers import DateAddLayer class DateAdditionParams(Params): @@ -88,6 +89,9 @@ class DateAddTransformer( WARNING: This transform destroys the time component of the date column. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -108,7 +112,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Layer name. Used as the name of the tensorflow layer + :param layerName: Layer name. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param numDays: Number of days to add/subtract. Negative values subtract. :returns: None - class instantiated. @@ -212,15 +216,15 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer. + Gets the Keras layer. - :returns: DateAddLayer Tensorflow layer. + :returns: DateAddLayer Keras layer. """ return DateAddLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), num_days=self.getNumDays(), ) diff --git a/src/kamae/spark/transformers/date_diff.py b/src/kamae/spark/transformers/date_diff.py index 6bc7a0c9..9ee42fab 100644 --- a/src/kamae/spark/transformers/date_diff.py +++ b/src/kamae/spark/transformers/date_diff.py @@ -24,9 +24,10 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import DateDiffLayer from kamae.spark.params import DefaultIntValueParams, MultiInputSingleOutputParams from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import DateDiffLayer from .base import BaseTransformer @@ -41,6 +42,9 @@ class DateDiffTransformer( This transformer calculates the difference between two dates. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -63,7 +67,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param defaultValue: Default value to use when one of the dates is the empty string. Empty strings can be used when the date is not available. @@ -132,16 +136,16 @@ def date_diff(x: Column) -> Column: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the absolute value transformer. + Gets the Keras layer for the absolute value transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs an absolute value operation. 
""" return DateDiffLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), default_value=self.getDefaultValue(), ) diff --git a/src/kamae/spark/transformers/date_parse.py b/src/kamae/spark/transformers/date_parse.py index 4ac301f5..5bf9eec2 100644 --- a/src/kamae/spark/transformers/date_parse.py +++ b/src/kamae/spark/transformers/date_parse.py @@ -26,10 +26,11 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import DateParseLayer from kamae.spark.params import DefaultIntValueParams, SingleInputSingleOutputParams from kamae.spark.transformers.base import BaseTransformer from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import DateParseLayer class DateParseParams(DefaultIntValueParams): @@ -103,6 +104,9 @@ class DateParseTransformer( fields will be returned as 0. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -126,7 +130,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Layer name. Used as the name of the tensorflow layer + :param layerName: Layer name. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -216,11 +220,11 @@ def _parse_date(self, column: Column) -> Column: return formatted_date.cast("int") - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer. + Gets the Keras layer. - :returns: DateParseLayer Tensorflow layer. + :returns: DateParseLayer Keras layer. 
""" if not self.isDefined("datePart"): @@ -229,8 +233,8 @@ def get_tf_layer(self) -> tf.keras.layers.Layer: return DateParseLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), date_part=date_part, default_value=self.getDefaultValue(), ) diff --git a/src/kamae/spark/transformers/date_time_to_unix_timestamp.py b/src/kamae/spark/transformers/date_time_to_unix_timestamp.py index 3fc90c57..dbd8b74c 100644 --- a/src/kamae/spark/transformers/date_time_to_unix_timestamp.py +++ b/src/kamae/spark/transformers/date_time_to_unix_timestamp.py @@ -24,10 +24,11 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import DateTimeToUnixTimestampLayer from kamae.spark.params import SingleInputSingleOutputParams, UnixTimestampParams from kamae.spark.transformers.base import BaseTransformer from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import DateTimeToUnixTimestampLayer class DateTimeToUnixTimestampTransformer( @@ -39,6 +40,9 @@ class DateTimeToUnixTimestampTransformer( The unix timestamp can be in milliseconds or seconds, set by the `unit` parameter. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -60,7 +64,7 @@ def __init__( transforming. :param unit: Unit of the output timestamp. Can be `milliseconds` (shorthand `ms`) or `seconds` (shorthand `s`). Default is `s` (seconds). - :param layerName: Layer name. Used as the name of the tensorflow layer + :param layerName: Layer name. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. 
""" @@ -131,15 +135,15 @@ def datetime_to_unix_timestamp(datetime: Column) -> Column: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer that performs the datetime to unix timestamp. + Gets the Keras layer that performs the datetime to unix timestamp. - :returns: Tensorflow layer that performs the unix timestamp to date transform. + :returns: Keras layer that performs the unix timestamp to date transform. """ return DateTimeToUnixTimestampLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), unit=self.getUnit(), ) diff --git a/src/kamae/spark/transformers/divide.py b/src/kamae/spark/transformers/divide.py index cac93e35..74791fdc 100644 --- a/src/kamae/spark/transformers/divide.py +++ b/src/kamae/spark/transformers/divide.py @@ -19,19 +19,20 @@ from functools import reduce from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import DivideLayer from kamae.spark.params import ( MathFloatConstantParams, MultiInputSingleOutputParams, SingleInputSingleOutputParams, ) from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import DivideLayer from .base import BaseTransformer @@ -47,6 +48,9 @@ class DivideTransformer( This transformer divides a column by a constant or another column. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -69,7 +73,7 @@ def __init__( transforming. 
:param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param mathFloatConstant: Optional constant to divide by. If not provided, then two input columns are required. @@ -127,16 +131,16 @@ def divide_no_nan(column1: Column, column2: Column) -> Column: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the divide transformer. + Gets the Keras layer for the divide transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a divide operation. """ return DivideLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), divisor=self.getMathFloatConstant(), ) diff --git a/src/kamae/spark/transformers/exp.py b/src/kamae/spark/transformers/exp.py index 7d4a38bd..f8c14886 100644 --- a/src/kamae/spark/transformers/exp.py +++ b/src/kamae/spark/transformers/exp.py @@ -18,15 +18,16 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import ExpLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import ExpLayer from .base 
import BaseTransformer @@ -40,6 +41,9 @@ class ExpTransformer( This transformer applies exp(x) operation to the input. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -58,7 +62,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -94,15 +98,15 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the exp value transformer. + Gets the Keras layer for the exp value transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs an exp value operation. 
""" return ExpLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), ) diff --git a/src/kamae/spark/transformers/exponent.py b/src/kamae/spark/transformers/exponent.py index 9438c8dd..beeb194e 100644 --- a/src/kamae/spark/transformers/exponent.py +++ b/src/kamae/spark/transformers/exponent.py @@ -18,19 +18,20 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import ExponentLayer from kamae.spark.params import ( MultiInputSingleOutputParams, SingleInputSingleOutputParams, ) from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import ExponentLayer from .base import BaseTransformer @@ -77,6 +78,9 @@ class ExponentTransformer( case of two inputs. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -100,7 +104,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param exponent: Optional exponent/power to raise the input to. If not provided, then two input columns are required. 
@@ -171,16 +175,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the exp value transformer. + Gets the Keras layer for the exp value transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs an exp value operation. """ return ExponentLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), exponent=self.getExponent(), ) diff --git a/src/kamae/spark/transformers/hash_index.py b/src/kamae/spark/transformers/hash_index.py index cb9551e2..d7c3b51e 100644 --- a/src/kamae/spark/transformers/hash_index.py +++ b/src/kamae/spark/transformers/hash_index.py @@ -24,9 +24,10 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, IntegerType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import HashIndexLayer from kamae.spark.params import HashIndexParams, SingleInputSingleOutputParams from kamae.spark.utils import hash_udf, single_input_single_output_scalar_udf_transform -from kamae.tensorflow.layers import HashIndexLayer from .base import BaseTransformer @@ -47,6 +48,9 @@ class HashIndexTransformer( characters. If you have null characters in your data, you should remove them. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -67,7 +71,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. 
Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param numBins: Number of bins to use for hash indexing. :param maskValue: Mask value to use for hash indexing. @@ -114,17 +118,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: output_col, ) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer that performs the hash indexing. + Gets the Keras layer that performs the hash indexing. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that performs the hash indexing operation. """ return HashIndexLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), num_bins=self.getNumBins(), mask_value=self.getMaskValue(), ) diff --git a/src/kamae/spark/transformers/haversine_distance.py b/src/kamae/spark/transformers/haversine_distance.py index bfc7ed84..e5833b97 100644 --- a/src/kamae/spark/transformers/haversine_distance.py +++ b/src/kamae/spark/transformers/haversine_distance.py @@ -19,16 +19,17 @@ import math from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import HaversineDistanceLayer from kamae.spark.params import LatLonConstantParams, MultiInputSingleOutputParams from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import HaversineDistanceLayer from .base import BaseTransformer @@ -99,6 +100,9 @@ class 
HaversineDistanceTransformer( are out of bounds. For lat, this is [-90, 90] and for lon, this is [-180, 180]. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -122,7 +126,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param latLonConstant: Optional list of lat/lon constant to use. Must be in the order [lat, lon]. @@ -256,17 +260,17 @@ def haversine_distance_transform( return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the haversine distance transformer. + Gets the Keras layer for the haversine distance transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that computes the haversine distance between two lat/lon pairs. 
""" return HaversineDistanceLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), lat_lon_constant=self.getLatLonConstant(), unit=self.getUnit(), ) diff --git a/src/kamae/spark/transformers/identity.py b/src/kamae/spark/transformers/identity.py index 14cb8dc6..8e4452f7 100644 --- a/src/kamae/spark/transformers/identity.py +++ b/src/kamae/spark/transformers/identity.py @@ -18,14 +18,15 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import DataType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import IdentityLayer from kamae.spark.params import SingleInputSingleOutputParams -from kamae.tensorflow.layers import IdentityLayer from .base import BaseTransformer @@ -40,6 +41,9 @@ class IdentityTransformer( Used for cases where you want to keep the input the same. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -58,7 +62,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -86,18 +90,15 @@ def _transform(self, dataset: DataFrame) -> DataFrame: """ return dataset.withColumn(self.getOutputCol(), F.col(self.getInputCol())) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the identity transformer. + Gets the Keras layer for the identity transformer. 
- :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs an IdentityLayer operation. """ - # Tensorflow <= 2.11 does not contain tf.keras.layers.IdentityLayer - # so we use a lambda layer instead. - # When we have a subclassed identity layer, we can use that. return IdentityLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), ) diff --git a/src/kamae/spark/transformers/if_statement.py b/src/kamae/spark/transformers/if_statement.py index 61c6a2be..8d7b6a84 100644 --- a/src/kamae/spark/transformers/if_statement.py +++ b/src/kamae/spark/transformers/if_statement.py @@ -27,12 +27,13 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import IfStatementLayer from kamae.spark.params import ( MultiInputSingleOutputParams, SingleInputSingleOutputParams, ) from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import IfStatementLayer from kamae.utils import get_condition_operator from .base import BaseTransformer @@ -194,6 +195,9 @@ class IfStatementTransformer( and columns. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -221,7 +225,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param conditionOperator: Operator to use in condition: eq, neq, lt, gt, leq, geq. 
@@ -383,20 +387,20 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the numerical if statement transformer. + Gets the Keras layer for the numerical if statement transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs the numerical if statement. """ if not self.isDefined("conditionOperator"): - raise ValueError("Must specify conditionOperator to use tensorflow layer.") + raise ValueError("Must specify conditionOperator to use Keras layer.") return IfStatementLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), condition_operator=self.getConditionOperator(), value_to_compare=self.getValueToCompare(), result_if_true=self.getResultIfTrue(), diff --git a/src/kamae/spark/transformers/impute.py b/src/kamae/spark/transformers/impute.py index 75a5b66f..7d09693d 100644 --- a/src/kamae/spark/transformers/impute.py +++ b/src/kamae/spark/transformers/impute.py @@ -18,16 +18,17 @@ # pylint: disable=no-member from typing import List, Optional, Union +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import DataFrame from pyspark.sql.types import DataType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import ImputeLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import ImputeLayer from .base import BaseTransformer @@ -97,6 +98,9 @@ 
class ImputeTransformer(BaseTransformer, ImputeParams, SingleInputSingleOutputPa value is null or equalling a mask """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -117,7 +121,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. :param imputeValue: String, float or int value to impute in place of mask or nulls. @@ -163,19 +167,19 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the imputation transformer. + Gets the Keras layer for the imputation transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that performs the imputation. 
""" mask_value = self.getMaskValue() return ImputeLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), impute_value=self.getImputeValue(), mask_value=mask_value, ) diff --git a/src/kamae/spark/transformers/lambda_function.py b/src/kamae/spark/transformers/lambda_function.py index d00fc23b..e97cac23 100644 --- a/src/kamae/spark/transformers/lambda_function.py +++ b/src/kamae/spark/transformers/lambda_function.py @@ -27,14 +27,15 @@ from pyspark.sql import DataFrame from pyspark.sql.types import ArrayType, DataType, StructField, StructType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import LambdaFunctionLayer +from kamae.keras.tensorflow.utils.typing import Tensor from kamae.spark.params import ( MultiInputMultiOutputParams, MultiInputSingleOutputParams, SingleInputMultiOutputParams, SingleInputSingleOutputParams, ) -from kamae.tensorflow.layers import LambdaFunctionLayer -from kamae.tensorflow.typing import Tensor from .base import BaseTransformer @@ -138,6 +139,9 @@ def my_tf_fn(x): native Spark functions. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -175,7 +179,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -306,7 +310,7 @@ def _apply_udf_func_to_dataset( a struct column is created and then the columns are extracted. :param dataset: Pyspark dataframe to transform. - :param func: Tensorflow function. + :param func: Keras function. :param input_col_names: List of input column names. 
:param output_col_names: List of output column names. :param function_return_types: List of return types of the lambda function. @@ -366,7 +370,7 @@ def tf_function_wrapper( If value is a list of size 1, return the single value. - If the output tensor is a string, decodes the bytes to a string. - :param fn: Tensorflow function. + :param fn: Keras function. :returns: Function that can be used within a Spark UDF. """ @@ -425,16 +429,16 @@ def wrapper(*args: Any) -> Union[Any, tuple[Any, ...]]: function_return_types=function_return_types, ) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the lambda function transformer. + Gets the Keras layer for the lambda function transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs the lambda function on the input. """ return LambdaFunctionLayer( function=self.getFunction(), name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), ) diff --git a/src/kamae/spark/transformers/list_max.py b/src/kamae/spark/transformers/list_max.py index 72a5a157..2fdc2834 100644 --- a/src/kamae/spark/transformers/list_max.py +++ b/src/kamae/spark/transformers/list_max.py @@ -20,6 +20,8 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import ListMaxLayer from kamae.spark.params import ( ListwiseStatisticsParams, MultiInputSingleOutputParams, @@ -27,7 +29,6 @@ SingleInputSingleOutputParams, ) from kamae.spark.utils import check_and_apply_listwise_op -from kamae.tensorflow.layers import ListMaxLayer from .base import BaseTransformer @@ -81,6 +82,10 @@ class 
ListMaxTransformer( :nanFillValue: Value to fill NaNs results with. Defaults to 0. """ + jit_compatible = True + + supported_backends = TENSORFLOW_ONLY + @keyword_only def __init__( self, @@ -168,17 +173,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the listwise-maximum transformer. + Gets the Keras layer for the listwise-maximum transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs an averaging operation. """ return ListMaxLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), top_n=self.getTopN(), sort_order=self.getSortOrder(), with_segment=self.getWithSegment(), diff --git a/src/kamae/spark/transformers/list_mean.py b/src/kamae/spark/transformers/list_mean.py index 38d37385..fb697f85 100644 --- a/src/kamae/spark/transformers/list_mean.py +++ b/src/kamae/spark/transformers/list_mean.py @@ -29,6 +29,8 @@ StringType, ) +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import ListMeanLayer from kamae.spark.params import ( ListwiseStatisticsParams, MultiInputSingleOutputParams, @@ -36,7 +38,6 @@ SingleInputSingleOutputParams, ) from kamae.spark.utils import check_and_apply_listwise_op -from kamae.tensorflow.layers import ListMeanLayer from .base import BaseTransformer @@ -90,6 +91,10 @@ class ListMeanTransformer( :nanFillValue: Value to fill NaNs results with. Defaults to 0. 
""" + jit_compatible = True + + supported_backends = TENSORFLOW_ONLY + @keyword_only def __init__( self, @@ -177,17 +182,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the listwise-mean transformer. + Gets the Keras layer for the listwise-mean transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs an averaging operation. """ return ListMeanLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), top_n=self.getTopN(), sort_order=self.getSortOrder(), with_segment=self.getWithSegment(), diff --git a/src/kamae/spark/transformers/list_median.py b/src/kamae/spark/transformers/list_median.py index 851973fc..5d10a86f 100644 --- a/src/kamae/spark/transformers/list_median.py +++ b/src/kamae/spark/transformers/list_median.py @@ -20,6 +20,8 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import ListMedianLayer from kamae.spark.params import ( ListwiseStatisticsParams, MultiInputSingleOutputParams, @@ -27,7 +29,6 @@ SingleInputSingleOutputParams, ) from kamae.spark.utils import check_listwise_columns, get_listwise_condition_and_window -from kamae.tensorflow.layers import ListMedianLayer from .base import BaseTransformer @@ -72,6 +73,10 @@ class ListMedianTransformer( :nanFillValue: Value to fill NaNs results with. Defaults to 0. 
""" + jit_compatible = True + + supported_backends = TENSORFLOW_ONLY + @keyword_only def __init__( self, @@ -176,17 +181,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the listwise-median transformer. + Gets the Keras layer for the listwise-median transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a median operation. """ return ListMedianLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), top_n=self.getTopN(), sort_order=self.getSortOrder(), min_filter_value=self.getMinFilterValue(), diff --git a/src/kamae/spark/transformers/list_min.py b/src/kamae/spark/transformers/list_min.py index 10057abd..229212d0 100644 --- a/src/kamae/spark/transformers/list_min.py +++ b/src/kamae/spark/transformers/list_min.py @@ -20,6 +20,8 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import ListMinLayer from kamae.spark.params import ( ListwiseStatisticsParams, MultiInputSingleOutputParams, @@ -27,7 +29,6 @@ SingleInputSingleOutputParams, ) from kamae.spark.utils import check_and_apply_listwise_op -from kamae.tensorflow.layers import ListMinLayer from .base import BaseTransformer @@ -81,6 +82,10 @@ class ListMinTransformer( :nanFillValue: Value to fill NaNs results with. Defaults to 0. 
""" + jit_compatible = True + + supported_backends = TENSORFLOW_ONLY + @keyword_only def __init__( self, @@ -168,17 +173,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the listwise-minimum transformer. + Gets the Keras layer for the listwise-minimum transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs an averaging operation. """ return ListMinLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), top_n=self.getTopN(), sort_order=self.getSortOrder(), with_segment=self.getWithSegment(), diff --git a/src/kamae/spark/transformers/list_rank.py b/src/kamae/spark/transformers/list_rank.py index 81c36b01..4a540331 100644 --- a/src/kamae/spark/transformers/list_rank.py +++ b/src/kamae/spark/transformers/list_rank.py @@ -28,9 +28,10 @@ ShortType, ) +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import ListRankLayer from kamae.spark.params import ListwiseParams, SingleInputSingleOutputParams from kamae.spark.utils import check_listwise_columns -from kamae.tensorflow.layers import ListRankLayer from .base import BaseTransformer @@ -56,6 +57,10 @@ class ListRankTransformer( for listwise operation. Default is 'desc'. """ + jit_compatible = True + + supported_backends = TENSORFLOW_ONLY + @keyword_only def __init__( self, @@ -127,16 +132,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the listwise-rank transformer. 
+ Gets the Keras layer for the listwise-rank transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a ranking operation. """ return ListRankLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), sort_order=self.getSortOrder(), ) diff --git a/src/kamae/spark/transformers/list_std_dev.py b/src/kamae/spark/transformers/list_std_dev.py index e770b6b6..2d64679f 100644 --- a/src/kamae/spark/transformers/list_std_dev.py +++ b/src/kamae/spark/transformers/list_std_dev.py @@ -20,6 +20,8 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import ListStdDevLayer from kamae.spark.params import ( ListwiseStatisticsParams, MultiInputSingleOutputParams, @@ -27,7 +29,6 @@ SingleInputSingleOutputParams, ) from kamae.spark.utils import check_listwise_columns, get_listwise_condition_and_window -from kamae.tensorflow.layers import ListStdDevLayer from .base import BaseTransformer @@ -72,6 +73,10 @@ class ListStdDevTransformer( :nanFillValue: Value to fill NaNs results with. Defaults to 0. """ + jit_compatible = True + + supported_backends = TENSORFLOW_ONLY + @keyword_only def __init__( self, @@ -156,17 +161,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the listwise-stddev transformer. + Gets the Keras layer for the listwise-stddev transformer. 
- :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs an averaging operation. """ return ListStdDevLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), top_n=self.getTopN(), sort_order=self.getSortOrder(), min_filter_value=self.getMinFilterValue(), diff --git a/src/kamae/spark/transformers/log.py b/src/kamae/spark/transformers/log.py index 5e285e7f..37016cef 100644 --- a/src/kamae/spark/transformers/log.py +++ b/src/kamae/spark/transformers/log.py @@ -18,16 +18,17 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import LogLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import LogLayer from .base import BaseTransformer @@ -72,6 +73,9 @@ class LogTransformer( This transformer applies a log(alpha + x) transform to the input column. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -93,7 +97,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param alpha: Value to use in log transform: log(alpha + x). 
Default is 0. :returns: None - class instantiated. @@ -132,16 +136,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer that performs the log transform. + Gets the Keras layer that performs the log transform. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that performs the log(alpha + x) operation. """ return LogLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), alpha=self.getAlpha(), ) diff --git a/src/kamae/spark/transformers/logical_and.py b/src/kamae/spark/transformers/logical_and.py index 5941a283..fab4e70b 100644 --- a/src/kamae/spark/transformers/logical_and.py +++ b/src/kamae/spark/transformers/logical_and.py @@ -20,15 +20,16 @@ from operator import and_ from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import BooleanType, DataType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import LogicalAndLayer from kamae.spark.params import MultiInputSingleOutputParams from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import LogicalAndLayer from .base import BaseTransformer @@ -42,6 +43,9 @@ class LogicalAndTransformer( This transformer performs an element-wise logical and operation on multiple columns. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -60,7 +64,7 @@ def __init__( transforming. 
:param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -112,15 +116,15 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the logical and transformer. + Gets the Keras layer for the logical and transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a logical and operation. """ return LogicalAndLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), ) diff --git a/src/kamae/spark/transformers/logical_not.py b/src/kamae/spark/transformers/logical_not.py index 6617573f..4855eb5e 100644 --- a/src/kamae/spark/transformers/logical_not.py +++ b/src/kamae/spark/transformers/logical_not.py @@ -18,15 +18,16 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import BooleanType, DataType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import LogicalNotLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import LogicalNotLayer from .base import BaseTransformer @@ -40,6 +41,9 @@ class 
LogicalNotTransformer( This transformer performs a logical not operation on a single column. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -58,7 +62,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -94,15 +98,15 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the logical not transformer. + Gets the Keras layer for the logical not transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a logical not operation. 
""" return LogicalNotLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), ) diff --git a/src/kamae/spark/transformers/logical_or.py b/src/kamae/spark/transformers/logical_or.py index be851066..e569049c 100644 --- a/src/kamae/spark/transformers/logical_or.py +++ b/src/kamae/spark/transformers/logical_or.py @@ -20,15 +20,16 @@ from operator import or_ from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import BooleanType, DataType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import LogicalOrLayer from kamae.spark.params import MultiInputSingleOutputParams from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import LogicalOrLayer from .base import BaseTransformer @@ -42,6 +43,9 @@ class LogicalOrTransformer( This transformer performs an element-wise logical or operation on multiple columns. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -60,7 +64,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -112,15 +116,15 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the logical or transformer. 
+ Gets the Keras layer for the logical or transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a logical or operation. """ return LogicalOrLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), ) diff --git a/src/kamae/spark/transformers/max.py b/src/kamae/spark/transformers/max.py index 476bb479..355f71c2 100644 --- a/src/kamae/spark/transformers/max.py +++ b/src/kamae/spark/transformers/max.py @@ -18,8 +18,8 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import ( @@ -32,13 +32,14 @@ ShortType, ) +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import MaxLayer from kamae.spark.params import ( MathFloatConstantParams, MultiInputSingleOutputParams, SingleInputSingleOutputParams, ) from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import MaxLayer from .base import BaseTransformer @@ -54,6 +55,9 @@ class MaxTransformer( This transformer gets the max of a column and a constant or another column. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -76,7 +80,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param mathFloatConstant: Optional constant to use for max op. 
If not provided, then two input columns are required. @@ -133,16 +137,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the max transformer. + Gets the Keras layer for the max transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a max operation. """ return MaxLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), max_constant=self.getMathFloatConstant(), ) diff --git a/src/kamae/spark/transformers/mean.py b/src/kamae/spark/transformers/mean.py index 71ad6c50..d4bb3778 100644 --- a/src/kamae/spark/transformers/mean.py +++ b/src/kamae/spark/transformers/mean.py @@ -20,7 +20,7 @@ from operator import add from typing import List, Optional -import tensorflow as tf +import keras from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import ( @@ -33,13 +33,14 @@ ShortType, ) +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import MeanLayer from kamae.spark.params import ( MathFloatConstantParams, MultiInputSingleOutputParams, SingleInputSingleOutputParams, ) from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import MeanLayer from .base import BaseTransformer @@ -55,6 +56,9 @@ class MeanTransformer( This transformer gets the mean of a column and a constant or another column. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -77,7 +81,7 @@ def __init__( transforming. 
:param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param mathFloatConstant: Optional constant to use for min op. If not provided, then two input columns are required. @@ -136,16 +140,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the mean transformer. + Gets the Keras layer for the mean transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a min operation. """ return MeanLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), mean_constant=self.getMathFloatConstant(), ) diff --git a/src/kamae/spark/transformers/min.py b/src/kamae/spark/transformers/min.py index 781af131..e6b4e48b 100644 --- a/src/kamae/spark/transformers/min.py +++ b/src/kamae/spark/transformers/min.py @@ -18,8 +18,8 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import ( @@ -32,13 +32,14 @@ ShortType, ) +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import MinLayer from kamae.spark.params import ( MathFloatConstantParams, MultiInputSingleOutputParams, SingleInputSingleOutputParams, ) from kamae.spark.utils import multi_input_single_output_scalar_transform -from 
kamae.tensorflow.layers import MinLayer from .base import BaseTransformer @@ -54,6 +55,9 @@ class MinTransformer( This transformer gets the min of a column and a constant or another column. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -76,7 +80,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param mathFloatConstant: Optional constant to use for min op. If not provided, then two input columns are required. @@ -133,16 +137,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the min transformer. + Gets the Keras layer for the min transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a min operation. 
""" return MinLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), min_constant=self.getMathFloatConstant(), ) diff --git a/src/kamae/spark/transformers/min_hash_index.py b/src/kamae/spark/transformers/min_hash_index.py index 6a706533..80783826 100644 --- a/src/kamae/spark/transformers/min_hash_index.py +++ b/src/kamae/spark/transformers/min_hash_index.py @@ -25,12 +25,13 @@ from pyspark.sql import DataFrame from pyspark.sql.types import ArrayType, DataType, IntegerType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import MinHashIndexLayer from kamae.spark.params import MaskStringValueParams, SingleInputSingleOutputParams from kamae.spark.utils import ( min_hash_udf, single_input_single_output_array_udf_transform, ) -from kamae.tensorflow.layers import MinHashIndexLayer from .base import BaseTransformer @@ -94,6 +95,9 @@ class MinHashIndexTransformer( characters. If you have null characters in your data, you should remove them. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -114,7 +118,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param numPermutations: Number of permutations of your output min hash. Defaults to 128. This is the length of the output array. @@ -171,17 +175,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: output_col, ) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer that performs the min hash indexing. 
+ Gets the Keras layer that performs the min hash indexing. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that performs the hash indexing operation. """ return MinHashIndexLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), num_permutations=self.getNumPermutations(), mask_value=self.getMaskValue(), ) diff --git a/src/kamae/spark/transformers/min_max_scale.py b/src/kamae/spark/transformers/min_max_scale.py index b5af36bb..c9653201 100644 --- a/src/kamae/spark/transformers/min_max_scale.py +++ b/src/kamae/spark/transformers/min_max_scale.py @@ -18,17 +18,18 @@ # pylint: disable=no-member from typing import List, Optional +import keras import numpy as np import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import DataFrame from pyspark.sql.types import ArrayType, DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import MinMaxScaleLayer from kamae.spark.params import MaskValueParams, SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_array_transform -from kamae.tensorflow.layers import MinMaxScaleLayer from .base import BaseTransformer @@ -112,6 +113,9 @@ class MinMaxScaleTransformer( shape across all rows. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -133,7 +137,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. 
:param min: List of minimum values corresponding to the input column. :param max: List of maximum values corresponding to the @@ -197,11 +201,11 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the min max transformation. + Gets the Keras layer for the min max transformation. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that performs the standardization. """ np_min = np.array(self.getMin()) @@ -209,8 +213,8 @@ def get_tf_layer(self) -> tf.keras.layers.Layer: mask_value = self.getMaskValue() return MinMaxScaleLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), min=np_min, max=np_max, mask_value=mask_value, diff --git a/src/kamae/spark/transformers/modulo.py b/src/kamae/spark/transformers/modulo.py index d247037a..3b105883 100644 --- a/src/kamae/spark/transformers/modulo.py +++ b/src/kamae/spark/transformers/modulo.py @@ -18,8 +18,8 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import DataFrame @@ -33,12 +33,13 @@ ShortType, ) +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import ModuloLayer from kamae.spark.params import ( MultiInputSingleOutputParams, SingleInputSingleOutputParams, ) from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import ModuloLayer from .base import BaseTransformer @@ -89,6 +90,9 @@ class ModuloTransformer( by the divisor 
parameter or another column. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -111,7 +115,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param divisor: Optional constant to use in modulo operation. If not provided, then two input columns are required. @@ -187,16 +191,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the modulo transformer. + Gets the Keras layer for the modulo transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a modulo operation. 
""" return ModuloLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), divisor=self.getDivisor(), ) diff --git a/src/kamae/spark/transformers/multiply.py b/src/kamae/spark/transformers/multiply.py index a9eb09be..9338647f 100644 --- a/src/kamae/spark/transformers/multiply.py +++ b/src/kamae/spark/transformers/multiply.py @@ -20,7 +20,7 @@ from operator import mul from typing import List, Optional -import tensorflow as tf +import keras from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import ( @@ -33,13 +33,14 @@ ShortType, ) +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import MultiplyLayer from kamae.spark.params import ( MathFloatConstantParams, MultiInputSingleOutputParams, SingleInputSingleOutputParams, ) from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import MultiplyLayer from .base import BaseTransformer @@ -55,6 +56,9 @@ class MultiplyTransformer( This transformer multiplies a column by a constant or another column. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -77,7 +81,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param mathFloatConstant: Optional constant to multiply by. If not provided, then input columns are required. 
@@ -133,16 +137,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the multiply transformer. + Gets the Keras layer for the multiply transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a multiply operation. """ return MultiplyLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), multiplier=self.getMathFloatConstant(), ) diff --git a/src/kamae/spark/transformers/numerical_if_statement.py b/src/kamae/spark/transformers/numerical_if_statement.py index 6f9e0195..d243d4ff 100644 --- a/src/kamae/spark/transformers/numerical_if_statement.py +++ b/src/kamae/spark/transformers/numerical_if_statement.py @@ -18,19 +18,20 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import NumericalIfStatementLayer from kamae.spark.params import ( MultiInputSingleOutputParams, SingleInputSingleOutputParams, ) from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import NumericalIfStatementLayer from kamae.utils import get_condition_operator from .base import BaseTransformer @@ -169,6 +170,9 @@ class NumericalIfStatementTransformer( and columns. 
""" + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -196,7 +200,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param conditionOperator: Operator to use in condition: eq, neq, lt, gt, leq, geq. @@ -358,20 +362,20 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the numerical if statement transformer. + Gets the Keras layer for the numerical if statement transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs the numerical if statement. 
""" if not self.isDefined("conditionOperator"): - raise ValueError("Must specify conditionOperator to use tensorflow layer.") + raise ValueError("Must specify conditionOperator to use Keras layer.") return NumericalIfStatementLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), condition_operator=self.getConditionOperator(), value_to_compare=self.getValueToCompare(), result_if_true=self.getResultIfTrue(), diff --git a/src/kamae/spark/transformers/one_hot_encode.py b/src/kamae/spark/transformers/one_hot_encode.py index e54c3bd6..14f355bb 100644 --- a/src/kamae/spark/transformers/one_hot_encode.py +++ b/src/kamae/spark/transformers/one_hot_encode.py @@ -32,6 +32,8 @@ StringType, ) +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import OneHotEncodeLayer from kamae.spark.params import ( DropUnseenParams, SingleInputSingleOutputParams, @@ -41,7 +43,6 @@ one_hot_encoding_udf, single_input_single_output_scalar_udf_transform, ) -from kamae.tensorflow.layers import OneHotEncodeLayer from .base import BaseTransformer @@ -63,6 +64,9 @@ class OneHotEncodeTransformer( characters. If you have null characters in your data, you should remove them. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -86,7 +90,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param labelsArray: List of string labels to use for one-hot encoding. :param stringOrderType: How to order the string indices. 
@@ -158,17 +162,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: output_col, ) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the one-hot encoder transformer. + Gets the Keras layer for the one-hot encoder transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that performs the one-hot encoding. """ return OneHotEncodeLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), vocabulary=self.getLabelsArray(), num_oov_indices=self.getNumOOVIndices(), mask_token=self.getMaskToken(), diff --git a/src/kamae/spark/transformers/ordinal_array_encode.py b/src/kamae/spark/transformers/ordinal_array_encode.py index 092b7f42..103e62a9 100644 --- a/src/kamae/spark/transformers/ordinal_array_encode.py +++ b/src/kamae/spark/transformers/ordinal_array_encode.py @@ -20,12 +20,13 @@ from pyspark.sql import DataFrame from pyspark.sql.types import ArrayType, DataType, IntegerType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import OrdinalArrayEncodeLayer from kamae.spark.params import PadValueParams, SingleInputSingleOutputParams from kamae.spark.utils import ( ordinal_array_encode_udf, single_input_single_output_array_udf_transform, ) -from kamae.tensorflow.layers import OrdinalArrayEncodeLayer from .base import BaseTransformer @@ -43,6 +44,9 @@ class OrdinalArrayEncodeTransformer( ignore the pad value if specified. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -61,7 +65,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. 
Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer :param padValue: The value to be considered as padding. Defaults to `None`. :returns: None """ @@ -128,17 +132,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: output_col, ) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer that performs the ordinal array encoding. + Gets the Keras layer that performs the ordinal array encoding. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that performs the ordinal array encoding operation. """ return OrdinalArrayEncodeLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), pad_value=self.getPadValue(), axis=-1, ) diff --git a/src/kamae/spark/transformers/pairwise_cosine_similarity.py b/src/kamae/spark/transformers/pairwise_cosine_similarity.py index 9165392a..f56eecb8 100644 --- a/src/kamae/spark/transformers/pairwise_cosine_similarity.py +++ b/src/kamae/spark/transformers/pairwise_cosine_similarity.py @@ -14,15 +14,16 @@ from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import Column, DataFrame from pyspark.sql.types import ArrayType, DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import PairwiseCosineSimilarityLayer from kamae.spark.params import MultiInputSingleOutputParams -from kamae.tensorflow.layers import PairwiseCosineSimilarityLayer from .base import BaseTransformer @@ -40,6 +41,9 @@ class PairwiseCosineSimilarityTransformer( Output: Array[Float] 
of size N containing cosine similarities. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + embeddingDim = Param( Params._dummy(), "embeddingDim", @@ -127,10 +131,10 @@ def cosine_sim_at_index(idx: Column) -> Column: similarities = F.transform(indices, cosine_sim_at_index) return dataset.withColumn(self.getOutputCol(), similarities) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: return PairwiseCosineSimilarityLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), embedding_dim=self.getEmbeddingDim(), ) diff --git a/src/kamae/spark/transformers/round.py b/src/kamae/spark/transformers/round.py index 65a655c2..734eeab3 100644 --- a/src/kamae/spark/transformers/round.py +++ b/src/kamae/spark/transformers/round.py @@ -18,16 +18,17 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import RoundLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import RoundLayer from .base import BaseTransformer @@ -77,6 +78,9 @@ class RoundTransformer( specified rounding type. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -96,7 +100,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. 
Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param roundType: Rounding type to use in round transform, one of 'floor', 'ceil' or 'round'. Defaults to 'round'. @@ -141,16 +145,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the round transformer. + Gets the Keras layer for the round transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a rounding operation. """ return RoundLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), round_type=self.getRoundType(), ) diff --git a/src/kamae/spark/transformers/round_to_decimal.py b/src/kamae/spark/transformers/round_to_decimal.py index a8d0234a..6dfb42dc 100644 --- a/src/kamae/spark/transformers/round_to_decimal.py +++ b/src/kamae/spark/transformers/round_to_decimal.py @@ -18,16 +18,17 @@ # pylint: disable=no-member from typing import List, Optional +import keras import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.ml.param import Param, Params, TypeConverters from pyspark.sql import DataFrame from pyspark.sql.types import DataType, DoubleType, FloatType, IntegerType, LongType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import RoundToDecimalLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import 
RoundToDecimalLayer from .base import BaseTransformer @@ -75,6 +76,9 @@ class RoundToDecimalTransformer( specified number of decimals. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -94,7 +98,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param decimals: Number of decimals to round to. :returns: None - class instantiated. @@ -132,16 +136,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the round transformer. + Gets the Keras layer for the round transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a rounding operation. 
""" return RoundToDecimalLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), decimals=self.getDecimals(), ) diff --git a/src/kamae/spark/transformers/shared_one_hot_encode.py b/src/kamae/spark/transformers/shared_one_hot_encode.py index a2b3c752..71e176b0 100644 --- a/src/kamae/spark/transformers/shared_one_hot_encode.py +++ b/src/kamae/spark/transformers/shared_one_hot_encode.py @@ -32,6 +32,8 @@ StringType, ) +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import OneHotEncodeLayer from kamae.spark.params import ( DropUnseenParams, MultiInputMultiOutputParams, @@ -41,7 +43,6 @@ one_hot_encoding_udf, single_input_single_output_scalar_udf_transform, ) -from kamae.tensorflow.layers import OneHotEncodeLayer from .base import BaseTransformer @@ -63,6 +64,9 @@ class SharedOneHotEncodeTransformer( characters. If you have null characters in your data, you should remove them. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -82,7 +86,7 @@ def __init__( :param inputCols: List of input column names. :param outputCols: List of output column name. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param labelsArray: List of string labels to use for one-hot encoding. :param stringOrderType: How to order the string indices. @@ -159,19 +163,19 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset.select(*select_cols) - def get_tf_layer(self) -> List[tf.keras.layers.Layer]: + def get_keras_layer(self) -> List[tf.keras.layers.Layer]: """ - Gets the list of tensorflow layers for the shared onehot encoder transformer. 
+ Gets the list of Keras layers for the shared onehot encoder transformer. We need to use a list as each layer could operate on differing input shapes. - :returns: List of Tensorflow keras layer with name equal to the layerName + :returns: List of Keras layer with name equal to the layerName parameter and the input column name, that performs the indexing. """ return [ OneHotEncodeLayer( name=f"{self.getLayerName()}_{input_name}", - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), vocabulary=self.getLabelsArray(), num_oov_indices=self.getNumOOVIndices(), mask_token=self.getMaskToken(), diff --git a/src/kamae/spark/transformers/shared_string_index.py b/src/kamae/spark/transformers/shared_string_index.py index 28cbb333..e4b1aec9 100644 --- a/src/kamae/spark/transformers/shared_string_index.py +++ b/src/kamae/spark/transformers/shared_string_index.py @@ -24,12 +24,13 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, IntegerType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringIndexLayer from kamae.spark.params import MultiInputMultiOutputParams, StringIndexParams from kamae.spark.utils import ( indexer_udf, single_input_single_output_scalar_udf_transform, ) -from kamae.tensorflow.layers import StringIndexLayer from .base import BaseTransformer @@ -50,6 +51,9 @@ class SharedStringIndexTransformer( characters. If you have null characters in your data, you should remove them. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -72,7 +76,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column(s) to after transforming. Must be the same length as inputCols. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. 
Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param stringOrderType: How to order the string indices. Options are 'frequencyAsc', 'frequencyDesc', 'alphabeticalAsc', @@ -139,19 +143,19 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset.select(*select_cols) - def get_tf_layer(self) -> List[tf.keras.layers.Layer]: + def get_keras_layer(self) -> List[tf.keras.layers.Layer]: """ - Gets the list of tensorflow layers for the shared string indexer transformer. + Gets the list of Keras layers for the shared string indexer transformer. We need to use a list as each layer could operate on differing input shapes. - :returns: List of Tensorflow keras layer with name equal to the layerName + :returns: List of Keras layer with name equal to the layerName parameter and the input column name, that performs the indexing. """ return [ StringIndexLayer( name=f"{self.getLayerName()}_{input_name}", - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), vocabulary=self.getLabelsArray(), mask_token=self.getMaskToken(), num_oov_indices=self.getNumOOVIndices(), diff --git a/src/kamae/spark/transformers/standard_scale.py b/src/kamae/spark/transformers/standard_scale.py index 9e3a76c9..79afe8e8 100644 --- a/src/kamae/spark/transformers/standard_scale.py +++ b/src/kamae/spark/transformers/standard_scale.py @@ -18,16 +18,17 @@ # pylint: disable=no-member from typing import List, Optional +import keras import numpy as np import pyspark.sql.functions as F -import tensorflow as tf from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import ArrayType, DataType, DoubleType, FloatType +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import StandardScaleLayer from kamae.spark.params import SingleInputSingleOutputParams, StandardScaleParams 
from kamae.spark.utils import single_input_single_output_array_transform -from kamae.tensorflow.layers import StandardScaleLayer from .base import BaseTransformer @@ -46,6 +47,9 @@ class StandardScaleTransformer( shape across all rows. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -67,7 +71,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. :param mean: List of mean values corresponding to the input column. :param stddev: List of standard deviation values corresponding to the @@ -130,11 +134,11 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the standard scaler transformer. + Gets the Keras layer for the standard scaler transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that performs the standardization. 
""" np_mean = np.array(self.getMean()) @@ -142,8 +146,8 @@ def get_tf_layer(self) -> tf.keras.layers.Layer: mask_value = self.getMaskValue() return StandardScaleLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), mean=np_mean, variance=np_variance, mask_value=mask_value, diff --git a/src/kamae/spark/transformers/string_affix.py b/src/kamae/spark/transformers/string_affix.py index bdf9c35e..b7ffb6b6 100644 --- a/src/kamae/spark/transformers/string_affix.py +++ b/src/kamae/spark/transformers/string_affix.py @@ -25,9 +25,10 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringAffixLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import StringAffixLayer from .base import BaseTransformer @@ -97,6 +98,9 @@ class StringAffixTransformer( Input columns must be of type string. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -112,7 +116,7 @@ def __init__( Initializes the string affix transformer. :param inputCol: column to combine with prefix or suffix. Must be type string. :param outputCol: column to output the affixed string to. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param inputDtype: Input data type to cast input column to before transforming. 
@@ -178,17 +182,17 @@ def add_prefix_suffix( return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the string affix transformer. + Gets the Keras layer for the string affix transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs prefixing and suffixing. """ return StringAffixLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), prefix=self.getPrefix(), suffix=self.getSuffix(), ) diff --git a/src/kamae/spark/transformers/string_array_constant.py b/src/kamae/spark/transformers/string_array_constant.py index 04d8d8ff..865014ad 100644 --- a/src/kamae/spark/transformers/string_array_constant.py +++ b/src/kamae/spark/transformers/string_array_constant.py @@ -24,9 +24,10 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringArrayConstantLayer from kamae.spark.params import ConstantStringArrayParams, SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import StringArrayConstantLayer from .base import BaseTransformer @@ -41,6 +42,9 @@ class StringArrayConstantTransformer( This transformer populates a column with a constant string array. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -55,9 +59,9 @@ def __init__( Initializes the String Array Constant Transformer. :param inputCol: Input column used to copy shape from. Ignored for Spark, used - for Tensorflow. + for Keras. 
:param outputCol: column to fill with the constant. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param inputDtype: Input data type to cast input column to before transforming. @@ -97,16 +101,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for generating the keras model that outputs + Gets the Keras layer for generating the keras model that outputs the constant string array. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter """ return StringArrayConstantLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), constant_string_array=self.getConstantStringArray(), ) diff --git a/src/kamae/spark/transformers/string_case.py b/src/kamae/spark/transformers/string_case.py index 370ff556..44945a79 100644 --- a/src/kamae/spark/transformers/string_case.py +++ b/src/kamae/spark/transformers/string_case.py @@ -25,9 +25,10 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringCaseLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import StringCaseLayer from .base import BaseTransformer @@ -84,6 +85,9 @@ class StringCaseTransformer( on the input column. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -103,7 +107,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param stringCaseType: How to change the case of the string. Must be one of: - 'upper' @@ -158,16 +162,16 @@ def string_case(x: Column, case_type: str) -> Column: return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the StringCaseLayer transformer. + Gets the Keras layer for the StringCaseLayer transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs the string casing operation. 
""" return StringCaseLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), string_case_type=self.getStringCaseType(), ) diff --git a/src/kamae/spark/transformers/string_concatenate.py b/src/kamae/spark/transformers/string_concatenate.py index 3117ea81..80c1e01c 100644 --- a/src/kamae/spark/transformers/string_concatenate.py +++ b/src/kamae/spark/transformers/string_concatenate.py @@ -25,9 +25,10 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringConcatenateLayer from kamae.spark.params import MultiInputSingleOutputParams from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import StringConcatenateLayer from .base import BaseTransformer @@ -74,6 +75,9 @@ class StringConcatenateTransformer( single column using a separator. Input columns must be of type string. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -88,7 +92,7 @@ def __init__( Initializes the string concatenate transformer. :param inputCols: columns to concatenate together. Must be of type string. :param outputCol: column to output the concatenated string to. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param inputDtype: Input data type to cast input column(s) to before transforming. 
@@ -140,16 +144,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the concatenate transformer. + Gets the Keras layer for the concatenate transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a concatenation. """ return StringConcatenateLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), separator=self.getSeparator(), ) diff --git a/src/kamae/spark/transformers/string_contains.py b/src/kamae/spark/transformers/string_contains.py index 4abc8db2..839928f2 100644 --- a/src/kamae/spark/transformers/string_contains.py +++ b/src/kamae/spark/transformers/string_contains.py @@ -24,6 +24,8 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringContainsLayer from kamae.spark.params import ( MultiInputSingleOutputParams, NegationParams, @@ -31,7 +33,6 @@ StringConstantParams, ) from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import StringContainsLayer from .base import BaseTransformer @@ -52,6 +53,9 @@ class StringContainsTransformer( Used for cases where you want to keep the input the same. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -78,7 +82,7 @@ def __init__( operation. Only used in single input scenario. :param negation: Whether to negate the string contains operation. - :param layerName: Name of the layer. 
Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -149,17 +153,17 @@ def string_contains( ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the StringContainsLayer transformer. + Gets the Keras layer for the StringContainsLayer transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a string contains operation. """ return StringContainsLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), negation=self.getNegation(), string_constant=self.getStringConstant(), ) diff --git a/src/kamae/spark/transformers/string_contains_list.py b/src/kamae/spark/transformers/string_contains_list.py index 423816a3..37669950 100644 --- a/src/kamae/spark/transformers/string_contains_list.py +++ b/src/kamae/spark/transformers/string_contains_list.py @@ -25,6 +25,8 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringContainsListLayer from kamae.spark.params import ( ConstantStringArrayParams, NegationParams, @@ -32,7 +34,6 @@ ) from kamae.spark.transformers.base import BaseTransformer from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import StringContainsListLayer class StringContainsListTransformer( @@ -47,6 +48,9 @@ class StringContainsListTransformer( constants in the passed constantStringArray. 
""" + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -70,7 +74,7 @@ def __init__( :param constantStringArray: String constant array to use in string contains list operation. :param negation: Whether to negate the string contains list operation. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -124,11 +128,11 @@ def string_contains_list( ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the StringContainsLayer transformer. + Gets the Keras layer for the StringContainsLayer transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a string contains operation. 
""" @@ -137,8 +141,8 @@ def get_tf_layer(self) -> tf.keras.layers.Layer: return StringContainsListLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), negation=self.getNegation(), string_constant_list=self.getConstantStringArray(), ) diff --git a/src/kamae/spark/transformers/string_equals_if_statement.py b/src/kamae/spark/transformers/string_equals_if_statement.py index 80b49051..e9dd7a05 100644 --- a/src/kamae/spark/transformers/string_equals_if_statement.py +++ b/src/kamae/spark/transformers/string_equals_if_statement.py @@ -25,12 +25,13 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringEqualsIfStatementLayer from kamae.spark.params import ( MultiInputSingleOutputParams, SingleInputSingleOutputParams, ) from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import StringEqualsIfStatementLayer from .base import BaseTransformer @@ -127,6 +128,9 @@ class StringEqualsIfStatementTransformer( and columns. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -153,7 +157,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param valueToCompare: Optional str value to compare to input column. If not specified, then assumed to be the first input column. 
@@ -311,17 +315,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the string if equal statement transformer. + Gets the Keras layer for the string if equal statement transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs the string if equals statement. """ return StringEqualsIfStatementLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), value_to_compare=self.getValueToCompare(), result_if_true=self.getResultIfTrue(), result_if_false=self.getResultIfFalse(), diff --git a/src/kamae/spark/transformers/string_index.py b/src/kamae/spark/transformers/string_index.py index 390dfefd..b5ffb25a 100644 --- a/src/kamae/spark/transformers/string_index.py +++ b/src/kamae/spark/transformers/string_index.py @@ -24,12 +24,13 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, IntegerType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringIndexLayer from kamae.spark.params import SingleInputSingleOutputParams, StringIndexParams from kamae.spark.utils import ( indexer_udf, single_input_single_output_scalar_udf_transform, ) -from kamae.tensorflow.layers import StringIndexLayer from .base import BaseTransformer @@ -50,6 +51,9 @@ class StringIndexTransformer( characters. If you have null characters in your data, you should remove them. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -72,7 +76,7 @@ def __init__( transforming. 
:param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param stringOrderType: How to order the string indices. Options are 'frequencyAsc', 'frequencyDesc', 'alphabeticalAsc', @@ -134,17 +138,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: output_col, ) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the string indexer transformer. + Gets the Keras layer for the string indexer transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter + :returns: Keras layer with name equal to the layerName parameter that performs the indexing. """ return StringIndexLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), vocabulary=self.getLabelsArray(), mask_token=self.getMaskToken(), num_oov_indices=self.getNumOOVIndices(), diff --git a/src/kamae/spark/transformers/string_isin_list.py b/src/kamae/spark/transformers/string_isin_list.py index cd96b33b..25dcff1f 100644 --- a/src/kamae/spark/transformers/string_isin_list.py +++ b/src/kamae/spark/transformers/string_isin_list.py @@ -24,13 +24,14 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringIsInListLayer from kamae.spark.params import ( ConstantStringArrayParams, NegationParams, SingleInputSingleOutputParams, ) from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import StringIsInListLayer from .base import 
BaseTransformer @@ -47,6 +48,9 @@ class StringIsInListTransformer( constants in the passed constantStringArray. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -70,7 +74,7 @@ def __init__( :param constantStringArray: String constant array to use in string isin list operation. :param negation: Whether to negate the string isin list operation. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -121,11 +125,11 @@ def string_isin_list( ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the StringIsInListLayer transformer. + Gets the Keras layer for the StringIsInListLayer transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a string isin operation. 
""" @@ -135,7 +139,7 @@ def get_tf_layer(self) -> tf.keras.layers.Layer: return StringIsInListLayer( name=self.getLayerName(), negation=self.getNegation(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), string_constant_list=self.getConstantStringArray(), ) diff --git a/src/kamae/spark/transformers/string_list_to_string.py b/src/kamae/spark/transformers/string_list_to_string.py index 1a3d1b97..5d1f13d8 100644 --- a/src/kamae/spark/transformers/string_list_to_string.py +++ b/src/kamae/spark/transformers/string_list_to_string.py @@ -25,9 +25,10 @@ from pyspark.sql import DataFrame from pyspark.sql.types import ArrayType, DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringListToStringLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_array_transform -from kamae.tensorflow.layers import StringListToStringLayer from .base import BaseTransformer @@ -73,6 +74,9 @@ class StringListToStringTransformer( This transformer takes a column of string lists and joins them into a single string. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -92,7 +96,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param separator: Separator to use when joining the string list. Default is the empty string. 
@@ -138,17 +142,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the StringListToStringLayer transformer. + Gets the Keras layer for the StringListToStringLayer transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that joins the string list. """ return StringListToStringLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), separator=self.getSeparator(), axis=-1, keepdims=True, diff --git a/src/kamae/spark/transformers/string_map.py b/src/kamae/spark/transformers/string_map.py index df368e13..cf13a57f 100644 --- a/src/kamae/spark/transformers/string_map.py +++ b/src/kamae/spark/transformers/string_map.py @@ -25,9 +25,10 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringMapLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import StringMapLayer from .base import BaseTransformer @@ -129,6 +130,9 @@ class StringMapTransformer( This transformer replaces a list of strings with the respective mapping value. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -154,7 +158,7 @@ def __init__( :param stringReplaceValues: List of string replace constants. :param defaultReplaceValue: Default value to replace the unmatched strings with. If None, the original string is kept unchanged. 
- :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -224,17 +228,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the StringMapLayer transformer. + Gets the Keras layer for the StringMapLayer transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a string replace operation. """ return StringMapLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), string_match_values=self.getStringMatchValues(), string_replace_values=self.getStringReplaceValues(), default_replace_value=self.getDefaultReplaceValue(), diff --git a/src/kamae/spark/transformers/string_replace.py b/src/kamae/spark/transformers/string_replace.py index d1065731..f7fa80f5 100644 --- a/src/kamae/spark/transformers/string_replace.py +++ b/src/kamae/spark/transformers/string_replace.py @@ -25,13 +25,14 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringReplaceLayer from kamae.spark.params import ( MultiInputSingleOutputParams, SingleInputSingleOutputParams, StringRegexParams, ) from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import StringReplaceLayer from .base import BaseTransformer @@ -108,6 +109,9 @@ class 
StringReplaceTransformer( This is consistent in both spark and tensorflow components. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -135,7 +139,7 @@ def __init__( operation. :param stringReplaceConstant: String constant to replace with. :param regex: Whether to allow regex-matching in the string matching. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -263,17 +267,17 @@ def string_replace( ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the StringReplaceLayer transformer. + Gets the Keras layer for the StringReplaceLayer transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a string replace operation. 
""" return StringReplaceLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), regex=self.getRegex(), string_match_constant=self.getStringMatchConstant(), string_replace_constant=self.getStringReplaceConstant(), diff --git a/src/kamae/spark/transformers/string_to_string_list.py b/src/kamae/spark/transformers/string_to_string_list.py index 05c40825..63dd701d 100644 --- a/src/kamae/spark/transformers/string_to_string_list.py +++ b/src/kamae/spark/transformers/string_to_string_list.py @@ -26,9 +26,10 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import StringToStringListLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import StringToStringListLayer from .base import BaseTransformer @@ -124,6 +125,9 @@ class StringToStringListTransformer( This transformer takes a column of string lists and joins them into a single string. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -145,7 +149,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param separator: Separator to use when joining the string list. Defaults to ",". 
@@ -209,17 +213,17 @@ def string_to_string_list(x: Column, separator: str) -> Column: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for the StringToStringListLayer transformer. + Gets the Keras layer for the StringToStringListLayer transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that splits the string into a list of strings. """ return StringToStringListLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), separator=self.getSeparator(), default_value=self.getDefaultValue(), list_length=self.getListLength(), diff --git a/src/kamae/spark/transformers/sub_string_delim_at_index.py b/src/kamae/spark/transformers/sub_string_delim_at_index.py index 0b2f5edd..f8e87204 100644 --- a/src/kamae/spark/transformers/sub_string_delim_at_index.py +++ b/src/kamae/spark/transformers/sub_string_delim_at_index.py @@ -26,9 +26,10 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType, StringType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import SubStringDelimAtIndexLayer from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import SubStringDelimAtIndexLayer from .base import BaseTransformer @@ -125,6 +126,9 @@ class SubStringDelimAtIndexTransformer( If the index is out of bounds, the default value is returned. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -146,7 +150,7 @@ def __init__( transforming. 
:param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param delimiter: Value to use to split the string into substrings. Default is "_". @@ -204,17 +208,17 @@ def _transform(self, dataset: DataFrame) -> DataFrame: ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer for SubStringDelimAtIndexTransformer. + Gets the Keras layer for SubStringDelimAtIndexTransformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs sub string at delimiter. """ return SubStringDelimAtIndexLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), delimiter=self.getDelimiter(), index=self.getIndex(), default_value=self.getDefaultValue(), diff --git a/src/kamae/spark/transformers/subtract.py b/src/kamae/spark/transformers/subtract.py index bf4d4ca4..4a74c04b 100644 --- a/src/kamae/spark/transformers/subtract.py +++ b/src/kamae/spark/transformers/subtract.py @@ -20,7 +20,7 @@ from operator import sub from typing import List, Optional -import tensorflow as tf +import keras from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import ( @@ -33,13 +33,14 @@ ShortType, ) +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import SubtractLayer from kamae.spark.params import ( MathFloatConstantParams, MultiInputSingleOutputParams, SingleInputSingleOutputParams, ) from kamae.spark.utils 
import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import SubtractLayer from .base import BaseTransformer @@ -55,6 +56,9 @@ class SubtractTransformer( This transformer subtracts a column by a constant or another column. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -77,7 +81,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param mathFloatConstant: Optional constant to divide by. If not provided, then two input columns are required. @@ -133,16 +137,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the divide transformer. + Gets the Keras layer for the divide transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a divide operation. 
""" return SubtractLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), subtrahend=self.getMathFloatConstant(), ) diff --git a/src/kamae/spark/transformers/sum.py b/src/kamae/spark/transformers/sum.py index bca6ffc6..407b0948 100644 --- a/src/kamae/spark/transformers/sum.py +++ b/src/kamae/spark/transformers/sum.py @@ -20,7 +20,7 @@ from operator import add from typing import List, Optional -import tensorflow as tf +import keras from pyspark import keyword_only from pyspark.sql import DataFrame from pyspark.sql.types import ( @@ -33,13 +33,14 @@ ShortType, ) +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.layers import SumLayer from kamae.spark.params import ( MathFloatConstantParams, MultiInputSingleOutputParams, SingleInputSingleOutputParams, ) from kamae.spark.utils import multi_input_single_output_scalar_transform -from kamae.tensorflow.layers import SumLayer from .base import BaseTransformer @@ -55,6 +56,9 @@ class SumTransformer( This transformer sums a column with a constant or another column. """ + supported_backends = ALL_BACKENDS + jit_compatible = True + @keyword_only def __init__( self, @@ -77,7 +81,7 @@ def __init__( transforming. :param outputDtype: Output data type to cast the output column to after transforming. - :param layerName: Name of the layer. Used as the name of the tensorflow layer + :param layerName: Name of the layer. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :param mathFloatConstant: Optional constant to sum. If not provided, then two input columns are required. 
@@ -133,16 +137,16 @@ def _transform(self, dataset: DataFrame) -> DataFrame: return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> keras.layers.Layer: """ - Gets the tensorflow layer for the sum transformer. + Gets the Keras layer for the sum transformer. - :returns: Tensorflow keras layer with name equal to the layerName parameter that + :returns: Keras layer with name equal to the layerName parameter that performs a sum operation. """ return SumLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), addend=self.getMathFloatConstant(), ) diff --git a/src/kamae/spark/transformers/unix_timestamp_to_date_time.py b/src/kamae/spark/transformers/unix_timestamp_to_date_time.py index 6a1b65cc..a404b0f2 100644 --- a/src/kamae/spark/transformers/unix_timestamp_to_date_time.py +++ b/src/kamae/spark/transformers/unix_timestamp_to_date_time.py @@ -24,6 +24,8 @@ from pyspark.sql import Column, DataFrame, SparkSession from pyspark.sql.types import DataType, DoubleType, LongType +from kamae.keras.core.backend import TENSORFLOW_ONLY +from kamae.keras.tensorflow.layers import UnixTimestampToDateTimeLayer from kamae.spark.params import ( DateTimeParams, SingleInputSingleOutputParams, @@ -31,7 +33,6 @@ ) from kamae.spark.transformers.base import BaseTransformer from kamae.spark.utils import single_input_single_output_scalar_transform -from kamae.tensorflow.layers import UnixTimestampToDateTimeLayer class UnixTimestampToDateTimeTransformer( @@ -46,6 +47,9 @@ class UnixTimestampToDateTimeTransformer( yyyy-MM-dd format. """ + supported_backends = TENSORFLOW_ONLY + jit_compatible = False + @keyword_only def __init__( self, @@ -69,7 +73,7 @@ def __init__( :param unit: Unit of the timestamp. Can be `milliseconds` (shorthand `ms`) or `seconds` (shorthand `s`). 
Default is `s` (seconds). :param includeTime: Whether to include the time in the output. Default is True. - :param layerName: Layer name. Used as the name of the tensorflow layer + :param layerName: Layer name. Used as the name of the Keras layer in the keras model. If not set, we use the uid of the Spark transformer. :returns: None - class instantiated. """ @@ -153,16 +157,16 @@ def unix_timestamp_to_datetime( ) return dataset.withColumn(self.getOutputCol(), output_col) - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: """ - Gets the tensorflow layer that performs the unix timestamp to date transform. + Gets the Keras layer that performs the unix timestamp to date transform. - :returns: Tensorflow layer that performs the unix timestamp to date transform. + :returns: Keras layer that performs the unix timestamp to date transform. """ return UnixTimestampToDateTimeLayer( name=self.getLayerName(), - input_dtype=self.getInputTFDtype(), - output_dtype=self.getOutputTFDtype(), + input_dtype=self.getInputKerasDtype(), + output_dtype=self.getOutputKerasDtype(), unit=self.getUnit(), include_time=self.getIncludeTime(), ) diff --git a/src/kamae/spark/utils/user_defined_functions.py b/src/kamae/spark/utils/user_defined_functions.py index db0cdce1..d7042039 100644 --- a/src/kamae/spark/utils/user_defined_functions.py +++ b/src/kamae/spark/utils/user_defined_functions.py @@ -14,7 +14,7 @@ from typing import List, Optional -import tensorflow as tf +import numpy as np from kamae.spark.utils.indexer_utils import safe_hash64 @@ -185,14 +185,15 @@ def min_hash_udf( # This matches the behavior of the TensorFlow layer. 
if mask_value is not None: hashed_vals = [ - tf.int32.max + np.iinfo(np.int32).max if label == mask_value - else hash_udf(label=f"{label}{i}", num_bins=tf.int32.max) + else hash_udf(label=f"{label}{i}", num_bins=np.iinfo(np.int32).max) for label in labels ] else: hashed_vals = [ - hash_udf(label=f"{label}{i}", num_bins=tf.int32.max) for label in labels + hash_udf(label=f"{label}{i}", num_bins=np.iinfo(np.int32).max) + for label in labels ] min_hash_val = min(hashed_vals) min_hash_bit = min_hash_val & 1 diff --git a/src/kamae/utils/dtype_enum.py b/src/kamae/utils/dtype_enum.py index d058e443..08edb97d 100644 --- a/src/kamae/utils/dtype_enum.py +++ b/src/kamae/utils/dtype_enum.py @@ -15,7 +15,6 @@ from enum import Enum from typing import Any, Dict -import tensorflow as tf from pyspark.sql.types import ( BooleanType, ByteType, @@ -33,7 +32,7 @@ class DType(Enum): """ Enum class for supported data types in Kamae. Contains a string name, the corresponding Spark data type, the corresponding - TensorFlow data type, and the number of bytes the data type takes up. + Keras data type, and the number of bytes the data type takes up. String is a special case, as it can be of any length, so the number of bytes is set to 0. 
""" @@ -41,31 +40,31 @@ class DType(Enum): STRING = ( "string", StringType(), - tf.string, + "string", 0, False, False, ) # String can be of any length - BIGINT = ("bigint", LongType(), tf.int64, 8, False, True) - INT = ("int", IntegerType(), tf.int32, 4, False, True) - SMALLINT = ("smallint", ShortType(), tf.int16, 2, False, True) - TINYINT = ("tinyint", ByteType(), tf.int8, 1, False, True) - FLOAT = ("float", FloatType(), tf.float32, 4, True, False) - DOUBLE = ("double", DoubleType(), tf.float64, 8, True, False) - BOOLEAN = ("boolean", BooleanType(), tf.bool, 1, False, False) + BIGINT = ("bigint", LongType(), "int64", 8, False, True) + INT = ("int", IntegerType(), "int32", 4, False, True) + SMALLINT = ("smallint", ShortType(), "int16", 2, False, True) + TINYINT = ("tinyint", ByteType(), "int8", 1, False, True) + FLOAT = ("float", FloatType(), "float32", 4, True, False) + DOUBLE = ("double", DoubleType(), "float64", 8, True, False) + BOOLEAN = ("boolean", BooleanType(), "bool", 1, False, False) def __init__( self, dtype_name: str, spark_dtype: DataType, - tf_dtype: tf.dtypes.DType, + keras_dtype: str, bytes: int, is_floating: bool = False, is_integer: bool = False, ) -> None: self.dtype_name = dtype_name self.spark_dtype = spark_dtype - self.tf_dtype = tf_dtype + self.keras_dtype = keras_dtype self.bytes = bytes self.is_floating = is_floating self.is_integer = is_integer @@ -74,7 +73,7 @@ def as_dict(self) -> Dict[str, Any]: return { "dtype_name": self.dtype_name, "spark_dtype": self.spark_dtype, - "tf_dtype": self.tf_dtype, + "keras_dtype": self.keras_dtype, "bytes": self.bytes, "is_floating": self.is_floating, "is_integer": self.is_integer, diff --git a/tests/kamae/graph/test_pipeline_graph.py b/tests/kamae/graph/test_pipeline_graph.py index fdef8e9d..fcd69517 100644 --- a/tests/kamae/graph/test_pipeline_graph.py +++ b/tests/kamae/graph/test_pipeline_graph.py @@ -115,6 +115,19 @@ def test_get_layer_output_from_layer_store(self, layer_name, expected): ("layer2", 
"layer2_output0"), ], ), + ( + { + "layer1": { + "name": "layer1", + "layer": None, + "inputs": ["input1"], + "outputs": ["layer1"], + }, + }, + [ + ("input1", "layer1"), + ], + ), ( { "layer1": { @@ -299,7 +312,7 @@ def test_sort_inputs(self, layer_name, stage_dict, input_dict, expected_outputs) assert outputs == expected_outputs @pytest.mark.parametrize( - "tf_input_schema, expected_inputs, expected_layer_store", + "input_schema, expected_inputs, expected_layer_store", [ ( [ @@ -323,7 +336,7 @@ def test_sort_inputs(self, layer_name, stage_dict, input_dict, expected_outputs) ) def test_build_keras_inputs( self, - tf_input_schema, + input_schema, expected_inputs, expected_layer_store, ): @@ -331,7 +344,7 @@ def test_build_keras_inputs( pipeline_graph = PipelineGraph(stage_dict={}) # when pipeline_graph.build_keras_inputs( - tf_input_schema=tf_input_schema, + input_schema=input_schema, ) # then for key, value in pipeline_graph.inputs.items(): diff --git a/src/kamae/sklearn/__init__.py b/tests/kamae/keras/core/__init__.py similarity index 100% rename from src/kamae/sklearn/__init__.py rename to tests/kamae/keras/core/__init__.py diff --git a/src/kamae/tensorflow/__init__.py b/tests/kamae/keras/core/layers/__init__.py similarity index 100% rename from src/kamae/tensorflow/__init__.py rename to tests/kamae/keras/core/layers/__init__.py diff --git a/tests/kamae/tensorflow/layers/test_absolute_value.py b/tests/kamae/keras/core/layers/test_absolute_value.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_absolute_value.py rename to tests/kamae/keras/core/layers/test_absolute_value.py index d6560973..241fcb66 100644 --- a/tests/kamae/tensorflow/layers/test_absolute_value.py +++ b/tests/kamae/keras/core/layers/test_absolute_value.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import AbsoluteValueLayer +from kamae.keras.core.layers import AbsoluteValueLayer class TestAbsoluteValue: diff --git 
a/tests/kamae/tensorflow/layers/test_array_concatenate.py b/tests/kamae/keras/core/layers/test_array_concatenate.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_array_concatenate.py rename to tests/kamae/keras/core/layers/test_array_concatenate.py index 4f738453..4b2ee981 100644 --- a/tests/kamae/tensorflow/layers/test_array_concatenate.py +++ b/tests/kamae/keras/core/layers/test_array_concatenate.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ArrayConcatenateLayer +from kamae.keras.core.layers import ArrayConcatenateLayer class TestArrayConcatenate: diff --git a/tests/kamae/tensorflow/layers/test_array_crop.py b/tests/kamae/keras/core/layers/test_array_crop.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_array_crop.py rename to tests/kamae/keras/core/layers/test_array_crop.py index 609513cc..7394f7be 100644 --- a/tests/kamae/tensorflow/layers/test_array_crop.py +++ b/tests/kamae/keras/core/layers/test_array_crop.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ArrayCropLayer +from kamae.keras.core.layers import ArrayCropLayer class TestArrayCrop: diff --git a/tests/kamae/keras/core/layers/test_array_reduce_max.py b/tests/kamae/keras/core/layers/test_array_reduce_max.py new file mode 100644 index 00000000..cb517d89 --- /dev/null +++ b/tests/kamae/keras/core/layers/test_array_reduce_max.py @@ -0,0 +1,83 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import tensorflow as tf + +from kamae.keras.core.layers import ArrayReduceMaxLayer + + +class TestArrayReduceMax: + @pytest.mark.parametrize( + "input_tensor, name, default_value, expected_output", + [ + ( + tf.constant([[1.0, 3.0, 2.0], [5.0, 4.0, 6.0]]), + "basic_max", + 0.0, + tf.constant([3.0, 6.0]), + ), + ( + tf.constant([[-5.0, -1.0, -3.0]]), + "negative_max", + 0.0, + tf.constant([-1.0]), + ), + ( + tf.constant([[7.0]]), + "single_element", + 0.0, + tf.constant([7.0]), + ), + ( + tf.constant([[float("nan"), 2.0, 3.0]]), + "nan_handling", + -1.0, + tf.constant([-1.0]), + ), + ( + tf.constant([[float("nan"), float("nan")]]), + "all_nan", + -99.0, + tf.constant([-99.0]), + ), + ( + tf.constant([[1.0, 2.0, 3.0]]), + "custom_default", + 42.0, + tf.constant([3.0]), + ), + ], + ) + def test_array_reduce_max(self, input_tensor, name, default_value, expected_output): + layer = ArrayReduceMaxLayer(name=name, default_value=default_value) + output_tensor = layer(input_tensor) + + assert layer.name == name + assert output_tensor.shape == expected_output.shape + tf.debugging.assert_near(output_tensor, expected_output, atol=1e-6) + + def test_array_reduce_max_batch(self): + input_tensor = tf.constant([[1.0, 5.0, 3.0], [9.0, 2.0, 7.0], [4.0, 4.0, 4.0]]) + layer = ArrayReduceMaxLayer(name="batch_test") + output_tensor = layer(input_tensor) + expected = tf.constant([5.0, 9.0, 4.0]) + tf.debugging.assert_near(output_tensor, expected, atol=1e-6) + + def test_get_config(self): + layer = ArrayReduceMaxLayer(name="config_test", default_value=5.0) + config = layer.get_config() + assert config["default_value"] == 5.0 + assert config["name"] == "config_test" diff --git a/tests/kamae/tensorflow/layers/test_array_split.py b/tests/kamae/keras/core/layers/test_array_split.py similarity index 98% rename from 
tests/kamae/tensorflow/layers/test_array_split.py rename to tests/kamae/keras/core/layers/test_array_split.py index 0f724022..0a328c84 100644 --- a/tests/kamae/tensorflow/layers/test_array_split.py +++ b/tests/kamae/keras/core/layers/test_array_split.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ArraySplitLayer +from kamae.keras.core.layers import ArraySplitLayer class TestArraySplit: diff --git a/tests/kamae/tensorflow/layers/test_array_subtract_minimum.py b/tests/kamae/keras/core/layers/test_array_subtract_minimum.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_array_subtract_minimum.py rename to tests/kamae/keras/core/layers/test_array_subtract_minimum.py index 9da386d2..9b4f73b8 100644 --- a/tests/kamae/tensorflow/layers/test_array_subtract_minimum.py +++ b/tests/kamae/keras/core/layers/test_array_subtract_minimum.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ArraySubtractMinimumLayer +from kamae.keras.core.layers import ArraySubtractMinimumLayer class TestArraySubtractMinimum: diff --git a/tests/kamae/keras/core/layers/test_base.py b/tests/kamae/keras/core/layers/test_base.py new file mode 100644 index 00000000..a841c332 --- /dev/null +++ b/tests/kamae/keras/core/layers/test_base.py @@ -0,0 +1,207 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for BaseLayer""" + +from typing import Any, List, Optional + +import keras +import pytest +import tensorflow as tf +from keras import ops + +from kamae.keras.core.backend import ALL_BACKENDS +from kamae.keras.core.base import BaseLayer +from kamae.keras.core.utils.input_utils import enforce_single_tensor_input + + +@keras.saving.register_keras_serializable(package="kamae_test") +class MockLayer(BaseLayer): + """Mock layer for testing BaseLayer""" + + supported_backends = ALL_BACKENDS + jit_compatible = False + + @property + def compatible_dtypes(self) -> Optional[List[str]]: + return None + + @enforce_single_tensor_input + def _call(self, inputs, **kwargs: Any): + return ops.multiply(inputs, 2.0) + + +@keras.saving.register_keras_serializable(package="kamae_test") +class MockLayerWithCompatibleDtypes(BaseLayer): + """Mock layer with specific compatible dtypes""" + + supported_backends = ALL_BACKENDS + jit_compatible = False + + @property + def compatible_dtypes(self) -> Optional[List[str]]: + return ["float32", "float64"] + + @enforce_single_tensor_input + def _call(self, inputs, **kwargs: Any): + return ops.multiply(inputs, 2.0) + + +class TestBaseLayer: + """Test suite for BaseLayer""" + + def test_instantiation(self): + """Test layer instantiation""" + layer = MockLayer(name="test_layer") + assert layer.name == "test_layer" + assert layer._input_dtype is None + assert layer._output_dtype is None + + def test_instantiation_with_dtypes(self): + """Test layer instantiation with dtype specification""" + layer = MockLayer( + name="test_layer", input_dtype="float32", output_dtype="float64" + ) + assert layer._input_dtype == "float32" + assert layer._output_dtype == "float64" + + def test_basic_call(self): + """Test basic layer call""" + layer = MockLayer() + x = tf.constant([[1.0, 2.0], [3.0, 4.0]]) + output = layer(x) + expected = tf.constant([[2.0, 4.0], [6.0, 8.0]]) + tf.debugging.assert_near(output, expected) + + def test_output_dtype_casting(self): + 
"""Test output dtype casting""" + layer = MockLayer(output_dtype="float64") + x = tf.constant([[1.0, 2.0]], dtype=tf.float32) + output = layer(x) + assert keras.backend.standardize_dtype(output.dtype) == "float64" + + def test_input_dtype_casting(self): + """Test input dtype casting""" + layer = MockLayer(input_dtype="float32") + x = tf.constant([[1, 2]], dtype=tf.int32) + output = layer(x) + # Layer should cast int32 to float32, compute, and return float32 + assert keras.backend.standardize_dtype(output.dtype) == "float32" + + def test_input_output_dtype_casting(self): + """Test combined input and output dtype casting""" + layer = MockLayer(input_dtype="float32", output_dtype="float64") + x = tf.constant([[1, 2]], dtype=tf.int32) + output = layer(x) + # Should cast int32 -> float32 (input), compute, cast -> float64 (output) + assert keras.backend.standardize_dtype(output.dtype) == "float64" + + def test_compatible_dtypes_validation_pass(self): + """Test compatible dtypes validation - should pass""" + layer = MockLayerWithCompatibleDtypes() + x = tf.constant([[1.0, 2.0]], dtype=tf.float32) + output = layer(x) # Should not raise + assert output is not None + + def test_compatible_dtypes_validation_fail(self): + """Test compatible dtypes validation - should fail""" + layer = MockLayerWithCompatibleDtypes() + x = tf.constant([[1, 2]], dtype=tf.int32) + with pytest.raises(TypeError, match="not a compatible dtype"): + layer(x) + + def test_compatible_dtypes_with_input_casting(self): + """Test compatible dtypes validation with input casting""" + layer = MockLayerWithCompatibleDtypes(input_dtype="float32") + x = tf.constant([[1, 2]], dtype=tf.int32) + # Should cast int32 to float32 first, then pass validation + output = layer(x) + assert output is not None + + def test_invalid_input_dtype_for_layer(self): + """Test that specifying incompatible input_dtype raises error""" + with pytest.raises(ValueError, match="not a compatible dtype"): + layer = 
MockLayerWithCompatibleDtypes(input_dtype="int32") + x = tf.constant([[1, 2]], dtype=tf.int32) + layer(x) + + def test_force_cast_float_input_float_constant(self): + """Test force cast with float input and float constant""" + layer = MockLayer() + x = tf.constant([1.5, 2.5], dtype=tf.float32) + cast_input, cast_const = layer._force_cast_to_compatible_numeric_type(x, 3.14) + assert keras.backend.standardize_dtype(cast_input.dtype) == "float32" + assert keras.backend.standardize_dtype(cast_const.dtype) == "float32" + tf.debugging.assert_near(cast_const, tf.constant(3.14, dtype=tf.float32)) + + def test_force_cast_int_input_int_constant(self): + """Test force cast with int input and int constant""" + layer = MockLayer() + x = tf.constant([1, 2, 3], dtype=tf.int32) + cast_input, cast_const = layer._force_cast_to_compatible_numeric_type(x, 5) + assert keras.backend.standardize_dtype(cast_input.dtype) == "int32" + assert keras.backend.standardize_dtype(cast_const.dtype) == "int32" + tf.debugging.assert_equal(cast_const, tf.constant(5, dtype=tf.int32)) + + def test_force_cast_int_input_float_constant(self): + """Test force cast with int input and float constant - should promote to float""" + layer = MockLayer() + x = tf.constant([1, 2, 3], dtype=tf.int64) + cast_input, cast_const = layer._force_cast_to_compatible_numeric_type(x, 3.14) + # Should promote to float64 + assert keras.backend.standardize_dtype(cast_input.dtype) == "float64" + assert keras.backend.standardize_dtype(cast_const.dtype) == "float64" + + def test_force_cast_int_input_integer_valued_float(self): + """Test force cast with int input and integer-valued float - should keep as int""" + layer = MockLayer() + x = tf.constant([1, 2, 3], dtype=tf.int32) + cast_input, cast_const = layer._force_cast_to_compatible_numeric_type(x, 5.0) + # 5.0 is integer-valued, so should keep as int32 + assert keras.backend.standardize_dtype(cast_input.dtype) == "int32" + assert keras.backend.standardize_dtype(cast_const.dtype) 
== "int32" + tf.debugging.assert_equal(cast_const, tf.constant(5, dtype=tf.int32)) + + def test_get_config(self): + """Test get_config returns correct configuration""" + layer = MockLayer( + name="test_layer", input_dtype="float32", output_dtype="float64" + ) + config = layer.get_config() + assert config["name"] == "test_layer" + assert config["input_dtype"] == "float32" + assert config["output_dtype"] == "float64" + + def test_serialization_round_trip(self): + """Test layer can be serialized and deserialized""" + original = MockLayer( + name="test_layer", input_dtype="float32", output_dtype="float64" + ) + config = original.get_config() + recreated = MockLayer.from_config(config) + + assert recreated.name == original.name + assert recreated._input_dtype == original._input_dtype + assert recreated._output_dtype == original._output_dtype + + # Test that recreated layer works + x = tf.constant([[1.0, 2.0]]) + output = recreated(x) + assert keras.backend.standardize_dtype(output.dtype) == "float64" + + def test_autocast_disabled(self): + """Test that autocast is disabled""" + layer = MockLayer() + assert layer._autocast is False + assert layer._convert_input_args is False diff --git a/tests/kamae/tensorflow/layers/test_bearing_angle.py b/tests/kamae/keras/core/layers/test_bearing_angle.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_bearing_angle.py rename to tests/kamae/keras/core/layers/test_bearing_angle.py index ffd2ac88..4443f889 100644 --- a/tests/kamae/tensorflow/layers/test_bearing_angle.py +++ b/tests/kamae/keras/core/layers/test_bearing_angle.py @@ -19,7 +19,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import BearingAngleLayer +from kamae.keras.core.layers import BearingAngleLayer class TestBearingAngle: diff --git a/tests/kamae/tensorflow/layers/test_bin.py b/tests/kamae/keras/core/layers/test_bin.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_bin.py rename to 
tests/kamae/keras/core/layers/test_bin.py index 43c378cf..676103e9 100644 --- a/tests/kamae/tensorflow/layers/test_bin.py +++ b/tests/kamae/keras/core/layers/test_bin.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import BinLayer +from kamae.keras.core.layers import BinLayer class TestBin: diff --git a/tests/kamae/tensorflow/layers/test_conditional_standard_scale.py b/tests/kamae/keras/core/layers/test_conditional_standard_scale.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_conditional_standard_scale.py rename to tests/kamae/keras/core/layers/test_conditional_standard_scale.py index 9a3f0806..08f232cc 100644 --- a/tests/kamae/tensorflow/layers/test_conditional_standard_scale.py +++ b/tests/kamae/keras/core/layers/test_conditional_standard_scale.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ConditionalStandardScaleLayer +from kamae.keras.core.layers import ConditionalStandardScaleLayer class TestConditionalStandardScale: diff --git a/tests/kamae/tensorflow/layers/test_cosine_similarity.py b/tests/kamae/keras/core/layers/test_cosine_similarity.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_cosine_similarity.py rename to tests/kamae/keras/core/layers/test_cosine_similarity.py index 28761e67..b196e9ee 100644 --- a/tests/kamae/tensorflow/layers/test_cosine_similarity.py +++ b/tests/kamae/keras/core/layers/test_cosine_similarity.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import CosineSimilarityLayer +from kamae.keras.core.layers import CosineSimilarityLayer class TestCosineSimilarity: diff --git a/tests/kamae/tensorflow/layers/test_divide.py b/tests/kamae/keras/core/layers/test_divide.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_divide.py rename to tests/kamae/keras/core/layers/test_divide.py index bb85cb8a..f2c9985d 100644 --- 
a/tests/kamae/tensorflow/layers/test_divide.py +++ b/tests/kamae/keras/core/layers/test_divide.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import DivideLayer +from kamae.keras.core.layers import DivideLayer class TestDivide: diff --git a/tests/kamae/tensorflow/layers/test_exp.py b/tests/kamae/keras/core/layers/test_exp.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_exp.py rename to tests/kamae/keras/core/layers/test_exp.py index 2385fc0d..94fbf1fc 100644 --- a/tests/kamae/tensorflow/layers/test_exp.py +++ b/tests/kamae/keras/core/layers/test_exp.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ExpLayer +from kamae.keras.core.layers import ExpLayer class TestExp: diff --git a/tests/kamae/tensorflow/layers/test_exponent.py b/tests/kamae/keras/core/layers/test_exponent.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_exponent.py rename to tests/kamae/keras/core/layers/test_exponent.py index 02e88a8f..452fcbc1 100644 --- a/tests/kamae/tensorflow/layers/test_exponent.py +++ b/tests/kamae/keras/core/layers/test_exponent.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ExponentLayer +from kamae.keras.core.layers import ExponentLayer class TestExponent: diff --git a/tests/kamae/tensorflow/layers/test_haversine_distance.py b/tests/kamae/keras/core/layers/test_haversine_distance.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_haversine_distance.py rename to tests/kamae/keras/core/layers/test_haversine_distance.py index f1344765..1b8f18ed 100644 --- a/tests/kamae/tensorflow/layers/test_haversine_distance.py +++ b/tests/kamae/keras/core/layers/test_haversine_distance.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import HaversineDistanceLayer +from kamae.keras.core.layers import HaversineDistanceLayer class TestHaversineDistance: 
diff --git a/tests/kamae/tensorflow/layers/test_identity.py b/tests/kamae/keras/core/layers/test_identity.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_identity.py rename to tests/kamae/keras/core/layers/test_identity.py index bdafe347..fa96fd38 100644 --- a/tests/kamae/tensorflow/layers/test_identity.py +++ b/tests/kamae/keras/core/layers/test_identity.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import IdentityLayer +from kamae.keras.core.layers import IdentityLayer class TestIdentity: diff --git a/tests/kamae/tensorflow/layers/test_impute.py b/tests/kamae/keras/core/layers/test_impute.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_impute.py rename to tests/kamae/keras/core/layers/test_impute.py index 89288d63..9c158452 100644 --- a/tests/kamae/tensorflow/layers/test_impute.py +++ b/tests/kamae/keras/core/layers/test_impute.py @@ -16,7 +16,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ImputeLayer +from kamae.keras.core.layers import ImputeLayer class TestImpute: diff --git a/tests/kamae/tensorflow/layers/test_log.py b/tests/kamae/keras/core/layers/test_log.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_log.py rename to tests/kamae/keras/core/layers/test_log.py index 9b669808..04405891 100644 --- a/tests/kamae/tensorflow/layers/test_log.py +++ b/tests/kamae/keras/core/layers/test_log.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import LogLayer +from kamae.keras.core.layers import LogLayer class TestLog: diff --git a/tests/kamae/tensorflow/layers/test_logical_and.py b/tests/kamae/keras/core/layers/test_logical_and.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_logical_and.py rename to tests/kamae/keras/core/layers/test_logical_and.py index 0f4d2a01..28ce9b93 100644 --- a/tests/kamae/tensorflow/layers/test_logical_and.py +++ 
b/tests/kamae/keras/core/layers/test_logical_and.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import LogicalAndLayer +from kamae.keras.core.layers import LogicalAndLayer class TestLogicalAnd: diff --git a/tests/kamae/tensorflow/layers/test_logical_not.py b/tests/kamae/keras/core/layers/test_logical_not.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_logical_not.py rename to tests/kamae/keras/core/layers/test_logical_not.py index 662d0da2..720e6abc 100644 --- a/tests/kamae/tensorflow/layers/test_logical_not.py +++ b/tests/kamae/keras/core/layers/test_logical_not.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import LogicalNotLayer +from kamae.keras.core.layers import LogicalNotLayer class TestLogicalNot: diff --git a/tests/kamae/tensorflow/layers/test_logical_or.py b/tests/kamae/keras/core/layers/test_logical_or.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_logical_or.py rename to tests/kamae/keras/core/layers/test_logical_or.py index ba66fb36..7f24c6e6 100644 --- a/tests/kamae/tensorflow/layers/test_logical_or.py +++ b/tests/kamae/keras/core/layers/test_logical_or.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import LogicalOrLayer +from kamae.keras.core.layers import LogicalOrLayer class TestLogicalOr: diff --git a/tests/kamae/tensorflow/layers/test_max.py b/tests/kamae/keras/core/layers/test_max.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_max.py rename to tests/kamae/keras/core/layers/test_max.py index a38bf520..8309b292 100644 --- a/tests/kamae/tensorflow/layers/test_max.py +++ b/tests/kamae/keras/core/layers/test_max.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import MaxLayer +from kamae.keras.core.layers import MaxLayer class TestMax: diff --git a/tests/kamae/tensorflow/layers/test_mean.py 
b/tests/kamae/keras/core/layers/test_mean.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_mean.py rename to tests/kamae/keras/core/layers/test_mean.py index 5aad1df2..eab98575 100644 --- a/tests/kamae/tensorflow/layers/test_mean.py +++ b/tests/kamae/keras/core/layers/test_mean.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import MeanLayer +from kamae.keras.core.layers import MeanLayer class TestMean: diff --git a/tests/kamae/tensorflow/layers/test_min.py b/tests/kamae/keras/core/layers/test_min.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_min.py rename to tests/kamae/keras/core/layers/test_min.py index 28b3bc4f..9fda2d61 100644 --- a/tests/kamae/tensorflow/layers/test_min.py +++ b/tests/kamae/keras/core/layers/test_min.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import MinLayer +from kamae.keras.core.layers import MinLayer class TestMin: diff --git a/tests/kamae/tensorflow/layers/test_min_max_scale.py b/tests/kamae/keras/core/layers/test_min_max_scale.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_min_max_scale.py rename to tests/kamae/keras/core/layers/test_min_max_scale.py index ccd810a9..39c64acf 100644 --- a/tests/kamae/tensorflow/layers/test_min_max_scale.py +++ b/tests/kamae/keras/core/layers/test_min_max_scale.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import MinMaxScaleLayer +from kamae.keras.core.layers import MinMaxScaleLayer class TestMinMaxScale: diff --git a/tests/kamae/tensorflow/layers/test_modulo.py b/tests/kamae/keras/core/layers/test_modulo.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_modulo.py rename to tests/kamae/keras/core/layers/test_modulo.py index 1a298356..96c07b31 100644 --- a/tests/kamae/tensorflow/layers/test_modulo.py +++ b/tests/kamae/keras/core/layers/test_modulo.py @@ -15,7 +15,7 @@ import pytest import 
tensorflow as tf -from kamae.tensorflow.layers import ModuloLayer +from kamae.keras.core.layers import ModuloLayer class TestModulo: diff --git a/tests/kamae/tensorflow/layers/test_multiply.py b/tests/kamae/keras/core/layers/test_multiply.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_multiply.py rename to tests/kamae/keras/core/layers/test_multiply.py index 89ba1ff9..43c4cc2a 100644 --- a/tests/kamae/tensorflow/layers/test_multiply.py +++ b/tests/kamae/keras/core/layers/test_multiply.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import MultiplyLayer +from kamae.keras.core.layers import MultiplyLayer class TestMultiply: diff --git a/tests/kamae/tensorflow/layers/test_numerical_if_statement.py b/tests/kamae/keras/core/layers/test_numerical_if_statement.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_numerical_if_statement.py rename to tests/kamae/keras/core/layers/test_numerical_if_statement.py index b26d93c7..af504f1f 100644 --- a/tests/kamae/tensorflow/layers/test_numerical_if_statement.py +++ b/tests/kamae/keras/core/layers/test_numerical_if_statement.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import NumericalIfStatementLayer +from kamae.keras.core.layers import NumericalIfStatementLayer class TestNumericalIfStatement: diff --git a/tests/kamae/keras/core/layers/test_pairwise_cosine_similarity.py b/tests/kamae/keras/core/layers/test_pairwise_cosine_similarity.py new file mode 100644 index 00000000..f888dda8 --- /dev/null +++ b/tests/kamae/keras/core/layers/test_pairwise_cosine_similarity.py @@ -0,0 +1,85 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import tensorflow as tf + +from kamae.keras.core.layers import PairwiseCosineSimilarityLayer + + +class TestPairwiseCosineSimilarity: + @pytest.mark.parametrize( + "query, flat_candidates, embedding_dim, expected_output", + [ + ( + tf.constant([[1.0, 0.0, 0.0]]), + tf.constant([[1.0, 0.0, 0.0]]), + 3, + tf.constant([[1.0]]), + ), + ( + tf.constant([[1.0, 0.0, 0.0]]), + tf.constant([[-1.0, 0.0, 0.0]]), + 3, + tf.constant([[-1.0]]), + ), + ( + tf.constant([[1.0, 0.0]]), + tf.constant([[0.0, 1.0]]), + 2, + tf.constant([[0.0]]), + ), + ( + tf.constant([[1.0, 0.0, 0.0]]), + tf.constant([[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]]), + 3, + tf.constant([[1.0, 0.0, 0.0]]), + ), + ( + tf.constant([[0.0, 0.0, 0.0]]), + tf.constant([[1.0, 0.0, 0.0]]), + 3, + tf.constant([[0.0]]), + ), + ], + ) + def test_pairwise_cosine_similarity( + self, query, flat_candidates, embedding_dim, expected_output + ): + layer = PairwiseCosineSimilarityLayer( + name="pairwise_cos", embedding_dim=embedding_dim + ) + output_tensor = layer([query, flat_candidates]) + + assert output_tensor.shape == expected_output.shape + tf.debugging.assert_near(output_tensor, expected_output, atol=1e-6) + + def test_batch_processing(self): + query = tf.constant([[1.0, 0.0], [0.0, 1.0]]) + flat_candidates = tf.constant([[1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0]]) + layer = PairwiseCosineSimilarityLayer(name="batch_test", embedding_dim=2) + output_tensor = layer([query, flat_candidates]) + expected = tf.constant([[1.0, 0.0], [0.0, 1.0]]) + tf.debugging.assert_near(output_tensor, 
expected, atol=1e-6) + + def test_wrong_number_of_inputs(self): + layer = PairwiseCosineSimilarityLayer(name="error_test", embedding_dim=3) + with pytest.raises(ValueError): + layer([tf.constant([[1.0, 0.0, 0.0]])]) + + def test_get_config(self): + layer = PairwiseCosineSimilarityLayer(name="config_test", embedding_dim=64) + config = layer.get_config() + assert config["embedding_dim"] == 64 + assert config["name"] == "config_test" diff --git a/tests/kamae/tensorflow/layers/test_round.py b/tests/kamae/keras/core/layers/test_round.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_round.py rename to tests/kamae/keras/core/layers/test_round.py index ce83afe3..000921e6 100644 --- a/tests/kamae/tensorflow/layers/test_round.py +++ b/tests/kamae/keras/core/layers/test_round.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import RoundLayer +from kamae.keras.core.layers import RoundLayer class TestRound: diff --git a/tests/kamae/tensorflow/layers/test_round_to_decimal.py b/tests/kamae/keras/core/layers/test_round_to_decimal.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_round_to_decimal.py rename to tests/kamae/keras/core/layers/test_round_to_decimal.py index 053b7a45..b00d6c10 100644 --- a/tests/kamae/tensorflow/layers/test_round_to_decimal.py +++ b/tests/kamae/keras/core/layers/test_round_to_decimal.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import RoundToDecimalLayer +from kamae.keras.core.layers import RoundToDecimalLayer class TestRoundToDecimal: diff --git a/tests/kamae/tensorflow/layers/test_standard_scale.py b/tests/kamae/keras/core/layers/test_standard_scale.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_standard_scale.py rename to tests/kamae/keras/core/layers/test_standard_scale.py index e4f0ce64..2c76e722 100644 --- a/tests/kamae/tensorflow/layers/test_standard_scale.py +++ 
b/tests/kamae/keras/core/layers/test_standard_scale.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StandardScaleLayer +from kamae.keras.core.layers import StandardScaleLayer class TestStandardScale: diff --git a/tests/kamae/tensorflow/layers/test_subtract.py b/tests/kamae/keras/core/layers/test_subtract.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_subtract.py rename to tests/kamae/keras/core/layers/test_subtract.py index 70da41c2..83499471 100644 --- a/tests/kamae/tensorflow/layers/test_subtract.py +++ b/tests/kamae/keras/core/layers/test_subtract.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import SubtractLayer +from kamae.keras.core.layers import SubtractLayer class TestSubtract: diff --git a/tests/kamae/tensorflow/layers/test_sum.py b/tests/kamae/keras/core/layers/test_sum.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_sum.py rename to tests/kamae/keras/core/layers/test_sum.py index ea80cd8b..cefaf771 100644 --- a/tests/kamae/tensorflow/layers/test_sum.py +++ b/tests/kamae/keras/core/layers/test_sum.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import SumLayer +from kamae.keras.core.layers import SumLayer class TestSum: diff --git a/tests/kamae/tensorflow/layers/test_bloom_encode.py b/tests/kamae/keras/tensorflow/layers/test_bloom_encode.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_bloom_encode.py rename to tests/kamae/keras/tensorflow/layers/test_bloom_encode.py index 266e413a..deaea0b7 100644 --- a/tests/kamae/tensorflow/layers/test_bloom_encode.py +++ b/tests/kamae/keras/tensorflow/layers/test_bloom_encode.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import BloomEncodeLayer +from kamae.keras.tensorflow.layers import BloomEncodeLayer class TestBloomEncode: diff --git 
a/tests/kamae/tensorflow/layers/test_bucketize.py b/tests/kamae/keras/tensorflow/layers/test_bucketize.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_bucketize.py rename to tests/kamae/keras/tensorflow/layers/test_bucketize.py index 5d9f1d05..2a092911 100644 --- a/tests/kamae/tensorflow/layers/test_bucketize.py +++ b/tests/kamae/keras/tensorflow/layers/test_bucketize.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import BucketizeLayer +from kamae.keras.tensorflow.layers import BucketizeLayer class TestBucketize: diff --git a/tests/kamae/tensorflow/layers/test_current_date.py b/tests/kamae/keras/tensorflow/layers/test_current_date.py similarity index 97% rename from tests/kamae/tensorflow/layers/test_current_date.py rename to tests/kamae/keras/tensorflow/layers/test_current_date.py index 7a946110..ff98e928 100644 --- a/tests/kamae/tensorflow/layers/test_current_date.py +++ b/tests/kamae/keras/tensorflow/layers/test_current_date.py @@ -19,7 +19,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import CurrentDateLayer +from kamae.keras.tensorflow.layers import CurrentDateLayer class TestCurrentDate: @@ -148,7 +148,7 @@ def test_current_date( ): # patch for tf.timestamp() in CurrentDateLayer layer with of 1622745600.0 is 2021-06-03 00:00:00 with patch( - "kamae.tensorflow.layers.current_date.tf.timestamp", + "kamae.keras.tensorflow.layers.current_date.tf.timestamp", lambda: tf.constant(test_timestamp, dtype=tf.float64), ): layer = CurrentDateLayer( @@ -185,7 +185,7 @@ def test_full_dates(self, min_date, max_date): def patch_date(x): with patch( - "kamae.tensorflow.layers.current_date.tf.timestamp", + "kamae.keras.tensorflow.layers.current_date.tf.timestamp", return_value=tf.constant([x], dtype=tf.float64), ): return current_date(tf.constant(1)) diff --git a/tests/kamae/tensorflow/layers/test_current_date_time.py b/tests/kamae/keras/tensorflow/layers/test_current_date_time.py 
similarity index 96% rename from tests/kamae/tensorflow/layers/test_current_date_time.py rename to tests/kamae/keras/tensorflow/layers/test_current_date_time.py index 6b17f576..5ae48cb5 100644 --- a/tests/kamae/tensorflow/layers/test_current_date_time.py +++ b/tests/kamae/keras/tensorflow/layers/test_current_date_time.py @@ -19,7 +19,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import CurrentDateTimeLayer +from kamae.keras.tensorflow.layers import CurrentDateTimeLayer class TestCurrentDateTime: @@ -147,7 +147,7 @@ def test_current_date_time( ): # patch for tf.timestamp() in CurrentDateTimeLayer layer with of 1622745600.0 is 2021-06-03 00:00:00 with patch( - "kamae.tensorflow.layers.current_date_time.tf.timestamp", + "kamae.keras.tensorflow.layers.current_date_time.tf.timestamp", lambda: tf.constant(test_timestamp, dtype=tf.float64), ): layer = CurrentDateTimeLayer( @@ -181,7 +181,7 @@ def test_full_hour(self, min_date, max_date): def patch_date(x): with patch( - "kamae.tensorflow.layers.current_date_time.tf.timestamp", + "kamae.keras.tensorflow.layers.current_date_time.tf.timestamp", return_value=tf.constant([x], dtype=tf.float64), ): return current_date_time(tf.constant(1)) diff --git a/tests/kamae/tensorflow/layers/test_current_unix_timestamp.py b/tests/kamae/keras/tensorflow/layers/test_current_unix_timestamp.py similarity index 96% rename from tests/kamae/tensorflow/layers/test_current_unix_timestamp.py rename to tests/kamae/keras/tensorflow/layers/test_current_unix_timestamp.py index c105c395..917f63fd 100644 --- a/tests/kamae/tensorflow/layers/test_current_unix_timestamp.py +++ b/tests/kamae/keras/tensorflow/layers/test_current_unix_timestamp.py @@ -17,7 +17,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import CurrentUnixTimestampLayer +from kamae.keras.tensorflow.layers import CurrentUnixTimestampLayer class TestCurrentUnixTimestamp: @@ -111,7 +111,7 @@ def test_current_unix_timestamp( ): # patch for 
tf.timestamp() in CurrentUnixTimestampLayer layer with of 1622745600.0 is 2021-06-03 00:00:00 with patch( - "kamae.tensorflow.layers.current_unix_timestamp.tf.timestamp", + "kamae.keras.tensorflow.layers.current_unix_timestamp.tf.timestamp", lambda: tf.constant(test_timestamp, dtype=tf.float64), ): layer = CurrentUnixTimestampLayer( diff --git a/tests/kamae/tensorflow/layers/test_date_add.py b/tests/kamae/keras/tensorflow/layers/test_date_add.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_date_add.py rename to tests/kamae/keras/tensorflow/layers/test_date_add.py index 7ed9ea06..3b7eafc2 100644 --- a/tests/kamae/tensorflow/layers/test_date_add.py +++ b/tests/kamae/keras/tensorflow/layers/test_date_add.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import DateAddLayer +from kamae.keras.tensorflow.layers import DateAddLayer class TestDateAdd: diff --git a/tests/kamae/tensorflow/layers/test_date_diff.py b/tests/kamae/keras/tensorflow/layers/test_date_diff.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_date_diff.py rename to tests/kamae/keras/tensorflow/layers/test_date_diff.py index 8ea495ca..afde95a0 100644 --- a/tests/kamae/tensorflow/layers/test_date_diff.py +++ b/tests/kamae/keras/tensorflow/layers/test_date_diff.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import DateDiffLayer +from kamae.keras.tensorflow.layers import DateDiffLayer class TestDateDiff: diff --git a/tests/kamae/tensorflow/layers/test_date_parse.py b/tests/kamae/keras/tensorflow/layers/test_date_parse.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_date_parse.py rename to tests/kamae/keras/tensorflow/layers/test_date_parse.py index 29d46bee..f2f2c9bc 100644 --- a/tests/kamae/tensorflow/layers/test_date_parse.py +++ b/tests/kamae/keras/tensorflow/layers/test_date_parse.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from 
kamae.tensorflow.layers import DateParseLayer +from kamae.keras.tensorflow.layers import DateParseLayer class TestDateParse: diff --git a/tests/kamae/tensorflow/layers/test_date_time_to_unix_timestamp.py b/tests/kamae/keras/tensorflow/layers/test_date_time_to_unix_timestamp.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_date_time_to_unix_timestamp.py rename to tests/kamae/keras/tensorflow/layers/test_date_time_to_unix_timestamp.py index 723cae8e..d7bfd6b4 100644 --- a/tests/kamae/tensorflow/layers/test_date_time_to_unix_timestamp.py +++ b/tests/kamae/keras/tensorflow/layers/test_date_time_to_unix_timestamp.py @@ -17,7 +17,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import DateTimeToUnixTimestampLayer +from kamae.keras.tensorflow.layers import DateTimeToUnixTimestampLayer class TestDateTimeToUnixTimestamp: diff --git a/tests/kamae/tensorflow/layers/test_hash_index.py b/tests/kamae/keras/tensorflow/layers/test_hash_index.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_hash_index.py rename to tests/kamae/keras/tensorflow/layers/test_hash_index.py index 9fd34e1f..220d9c7e 100644 --- a/tests/kamae/tensorflow/layers/test_hash_index.py +++ b/tests/kamae/keras/tensorflow/layers/test_hash_index.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import HashIndexLayer +from kamae.keras.tensorflow.layers import HashIndexLayer class TestHashIndex: diff --git a/tests/kamae/tensorflow/layers/test_if_statement.py b/tests/kamae/keras/tensorflow/layers/test_if_statement.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_if_statement.py rename to tests/kamae/keras/tensorflow/layers/test_if_statement.py index 77440222..cbf8d5e0 100644 --- a/tests/kamae/tensorflow/layers/test_if_statement.py +++ b/tests/kamae/keras/tensorflow/layers/test_if_statement.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import 
IfStatementLayer +from kamae.keras.tensorflow.layers import IfStatementLayer class TestIfStatement: diff --git a/tests/kamae/tensorflow/layers/test_lambda_function.py b/tests/kamae/keras/tensorflow/layers/test_lambda_function.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_lambda_function.py rename to tests/kamae/keras/tensorflow/layers/test_lambda_function.py index 30af917e..6c68e24d 100644 --- a/tests/kamae/tensorflow/layers/test_lambda_function.py +++ b/tests/kamae/keras/tensorflow/layers/test_lambda_function.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import LambdaFunctionLayer +from kamae.keras.tensorflow.layers import LambdaFunctionLayer class TestLambdaFunction: diff --git a/tests/kamae/tensorflow/layers/test_list_max.py b/tests/kamae/keras/tensorflow/layers/test_list_max.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_list_max.py rename to tests/kamae/keras/tensorflow/layers/test_list_max.py index 1c8b7bee..7ffa8db1 100644 --- a/tests/kamae/tensorflow/layers/test_list_max.py +++ b/tests/kamae/keras/tensorflow/layers/test_list_max.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ListMaxLayer +from kamae.keras.tensorflow.layers import ListMaxLayer class TestListMax: diff --git a/tests/kamae/tensorflow/layers/test_list_mean.py b/tests/kamae/keras/tensorflow/layers/test_list_mean.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_list_mean.py rename to tests/kamae/keras/tensorflow/layers/test_list_mean.py index 10dca8ec..769364b5 100644 --- a/tests/kamae/tensorflow/layers/test_list_mean.py +++ b/tests/kamae/keras/tensorflow/layers/test_list_mean.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ListMeanLayer +from kamae.keras.tensorflow.layers import ListMeanLayer class TestListMean: diff --git a/tests/kamae/tensorflow/layers/test_list_median.py 
b/tests/kamae/keras/tensorflow/layers/test_list_median.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_list_median.py rename to tests/kamae/keras/tensorflow/layers/test_list_median.py index 367eeb21..513c2c47 100644 --- a/tests/kamae/tensorflow/layers/test_list_median.py +++ b/tests/kamae/keras/tensorflow/layers/test_list_median.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ListMedianLayer +from kamae.keras.tensorflow.layers import ListMedianLayer class TestListMedian: diff --git a/tests/kamae/tensorflow/layers/test_list_min.py b/tests/kamae/keras/tensorflow/layers/test_list_min.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_list_min.py rename to tests/kamae/keras/tensorflow/layers/test_list_min.py index 8989c569..29d72f04 100644 --- a/tests/kamae/tensorflow/layers/test_list_min.py +++ b/tests/kamae/keras/tensorflow/layers/test_list_min.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ListMinLayer +from kamae.keras.tensorflow.layers import ListMinLayer class TestListMin: diff --git a/tests/kamae/tensorflow/layers/test_list_rank.py b/tests/kamae/keras/tensorflow/layers/test_list_rank.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_list_rank.py rename to tests/kamae/keras/tensorflow/layers/test_list_rank.py index 39e2736a..76d26bd5 100644 --- a/tests/kamae/tensorflow/layers/test_list_rank.py +++ b/tests/kamae/keras/tensorflow/layers/test_list_rank.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ListRankLayer +from kamae.keras.tensorflow.layers import ListRankLayer class TestListRank: diff --git a/tests/kamae/tensorflow/layers/test_list_std_dev.py b/tests/kamae/keras/tensorflow/layers/test_list_std_dev.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_list_std_dev.py rename to tests/kamae/keras/tensorflow/layers/test_list_std_dev.py 
index 4d86ed62..1c6602df 100644 --- a/tests/kamae/tensorflow/layers/test_list_std_dev.py +++ b/tests/kamae/keras/tensorflow/layers/test_list_std_dev.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import ListStdDevLayer +from kamae.keras.tensorflow.layers import ListStdDevLayer class TestListStdDev: diff --git a/tests/kamae/tensorflow/layers/test_min_hash_index.py b/tests/kamae/keras/tensorflow/layers/test_min_hash_index.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_min_hash_index.py rename to tests/kamae/keras/tensorflow/layers/test_min_hash_index.py index bfa583b1..c509bbe5 100644 --- a/tests/kamae/tensorflow/layers/test_min_hash_index.py +++ b/tests/kamae/keras/tensorflow/layers/test_min_hash_index.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import MinHashIndexLayer +from kamae.keras.tensorflow.layers import MinHashIndexLayer class TestMinHashIndex: diff --git a/tests/kamae/tensorflow/layers/test_one_hot_encode.py b/tests/kamae/keras/tensorflow/layers/test_one_hot_encode.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_one_hot_encode.py rename to tests/kamae/keras/tensorflow/layers/test_one_hot_encode.py index 07ff486a..b1b63b6b 100644 --- a/tests/kamae/tensorflow/layers/test_one_hot_encode.py +++ b/tests/kamae/keras/tensorflow/layers/test_one_hot_encode.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import OneHotEncodeLayer +from kamae.keras.tensorflow.layers import OneHotEncodeLayer class TestOneHotEncode: diff --git a/tests/kamae/tensorflow/layers/test_ordinal_array_encode.py b/tests/kamae/keras/tensorflow/layers/test_ordinal_array_encode.py similarity index 97% rename from tests/kamae/tensorflow/layers/test_ordinal_array_encode.py rename to tests/kamae/keras/tensorflow/layers/test_ordinal_array_encode.py index a5e171be..dd1e3a72 100644 --- 
a/tests/kamae/tensorflow/layers/test_ordinal_array_encode.py +++ b/tests/kamae/keras/tensorflow/layers/test_ordinal_array_encode.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers.ordinal_array_encode import OrdinalArrayEncodeLayer +from kamae.keras.tensorflow.layers import OrdinalArrayEncodeLayer class TestOrdinalArrayEncode: diff --git a/tests/kamae/tensorflow/layers/test_string_affix.py b/tests/kamae/keras/tensorflow/layers/test_string_affix.py similarity index 97% rename from tests/kamae/tensorflow/layers/test_string_affix.py rename to tests/kamae/keras/tensorflow/layers/test_string_affix.py index d3e43acc..25b1ad34 100644 --- a/tests/kamae/tensorflow/layers/test_string_affix.py +++ b/tests/kamae/keras/tensorflow/layers/test_string_affix.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StringAffixLayer +from kamae.keras.tensorflow.layers import StringAffixLayer class TestStringAffix: diff --git a/tests/kamae/tensorflow/layers/test_string_array_constant.py b/tests/kamae/keras/tensorflow/layers/test_string_array_constant.py similarity index 97% rename from tests/kamae/tensorflow/layers/test_string_array_constant.py rename to tests/kamae/keras/tensorflow/layers/test_string_array_constant.py index 10b99caa..ed93659f 100644 --- a/tests/kamae/tensorflow/layers/test_string_array_constant.py +++ b/tests/kamae/keras/tensorflow/layers/test_string_array_constant.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StringArrayConstantLayer +from kamae.keras.tensorflow.layers import StringArrayConstantLayer class TestStringArrayConstant: diff --git a/tests/kamae/tensorflow/layers/test_string_case.py b/tests/kamae/keras/tensorflow/layers/test_string_case.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_string_case.py rename to tests/kamae/keras/tensorflow/layers/test_string_case.py index f83c0f4a..b309ae11 100644 --- 
a/tests/kamae/tensorflow/layers/test_string_case.py +++ b/tests/kamae/keras/tensorflow/layers/test_string_case.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StringCaseLayer +from kamae.keras.tensorflow.layers import StringCaseLayer class TestStringCase: diff --git a/tests/kamae/tensorflow/layers/test_string_concatenate.py b/tests/kamae/keras/tensorflow/layers/test_string_concatenate.py similarity index 97% rename from tests/kamae/tensorflow/layers/test_string_concatenate.py rename to tests/kamae/keras/tensorflow/layers/test_string_concatenate.py index 03401a72..31abe72a 100644 --- a/tests/kamae/tensorflow/layers/test_string_concatenate.py +++ b/tests/kamae/keras/tensorflow/layers/test_string_concatenate.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StringConcatenateLayer +from kamae.keras.tensorflow.layers import StringConcatenateLayer class TestStringConcatenate: diff --git a/tests/kamae/tensorflow/layers/test_string_contains.py b/tests/kamae/keras/tensorflow/layers/test_string_contains.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_string_contains.py rename to tests/kamae/keras/tensorflow/layers/test_string_contains.py index 8a620e06..4fea6a9c 100644 --- a/tests/kamae/tensorflow/layers/test_string_contains.py +++ b/tests/kamae/keras/tensorflow/layers/test_string_contains.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StringContainsLayer +from kamae.keras.tensorflow.layers import StringContainsLayer class TestStringContains: diff --git a/tests/kamae/tensorflow/layers/test_string_contains_list.py b/tests/kamae/keras/tensorflow/layers/test_string_contains_list.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_string_contains_list.py rename to tests/kamae/keras/tensorflow/layers/test_string_contains_list.py index 24da1611..4eb799ae 100644 --- 
a/tests/kamae/tensorflow/layers/test_string_contains_list.py +++ b/tests/kamae/keras/tensorflow/layers/test_string_contains_list.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StringContainsListLayer +from kamae.keras.tensorflow.layers import StringContainsListLayer # TODO: Rename and repurpose diff --git a/tests/kamae/tensorflow/layers/test_string_equals_if_statement.py b/tests/kamae/keras/tensorflow/layers/test_string_equals_if_statement.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_string_equals_if_statement.py rename to tests/kamae/keras/tensorflow/layers/test_string_equals_if_statement.py index a6218814..4d9acbd8 100644 --- a/tests/kamae/tensorflow/layers/test_string_equals_if_statement.py +++ b/tests/kamae/keras/tensorflow/layers/test_string_equals_if_statement.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StringEqualsIfStatementLayer +from kamae.keras.tensorflow.layers import StringEqualsIfStatementLayer class TestStringEqualsIfStatement: diff --git a/tests/kamae/tensorflow/layers/test_string_index.py b/tests/kamae/keras/tensorflow/layers/test_string_index.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_string_index.py rename to tests/kamae/keras/tensorflow/layers/test_string_index.py index a457c4cf..2b98aa52 100644 --- a/tests/kamae/tensorflow/layers/test_string_index.py +++ b/tests/kamae/keras/tensorflow/layers/test_string_index.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StringIndexLayer +from kamae.keras.tensorflow.layers import StringIndexLayer class TestStringIndex: diff --git a/tests/kamae/tensorflow/layers/test_string_isin_list.py b/tests/kamae/keras/tensorflow/layers/test_string_isin_list.py similarity index 97% rename from tests/kamae/tensorflow/layers/test_string_isin_list.py rename to tests/kamae/keras/tensorflow/layers/test_string_isin_list.py index 
a8e72303..dea6cb3a 100644 --- a/tests/kamae/tensorflow/layers/test_string_isin_list.py +++ b/tests/kamae/keras/tensorflow/layers/test_string_isin_list.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StringIsInListLayer +from kamae.keras.tensorflow.layers import StringIsInListLayer class TestStringIsInList: @@ -44,7 +44,7 @@ class TestStringIsInList: tf.constant([["Mon"], ["mon"], [""], ["MON"]]), "input_3", "string", - "float", + "float32", ["mon"], False, tf.constant([[0.0], [1.0], [0.0], [0.0]], dtype=tf.float32), @@ -72,7 +72,7 @@ class TestStringIsInList: tf.constant([[1], [2], [3], [4]]), "input_3", "string", - "float", + "float32", ["1"], False, tf.constant([[1.0], [0.0], [0.0], [0.0]], dtype=tf.float32), diff --git a/tests/kamae/tensorflow/layers/test_string_list_to_string.py b/tests/kamae/keras/tensorflow/layers/test_string_list_to_string.py similarity index 97% rename from tests/kamae/tensorflow/layers/test_string_list_to_string.py rename to tests/kamae/keras/tensorflow/layers/test_string_list_to_string.py index ccb6c023..3fa9e224 100644 --- a/tests/kamae/tensorflow/layers/test_string_list_to_string.py +++ b/tests/kamae/keras/tensorflow/layers/test_string_list_to_string.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StringListToStringLayer +from kamae.keras.tensorflow.layers import StringListToStringLayer class TestStringListToString: diff --git a/tests/kamae/tensorflow/layers/test_string_map.py b/tests/kamae/keras/tensorflow/layers/test_string_map.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_string_map.py rename to tests/kamae/keras/tensorflow/layers/test_string_map.py index e7e838ea..e45e2e5b 100644 --- a/tests/kamae/tensorflow/layers/test_string_map.py +++ b/tests/kamae/keras/tensorflow/layers/test_string_map.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StringMapLayer +from 
kamae.keras.tensorflow.layers import StringMapLayer class TestStringMap: diff --git a/tests/kamae/tensorflow/layers/test_string_replace.py b/tests/kamae/keras/tensorflow/layers/test_string_replace.py similarity index 99% rename from tests/kamae/tensorflow/layers/test_string_replace.py rename to tests/kamae/keras/tensorflow/layers/test_string_replace.py index 786374f4..9bc39e6a 100644 --- a/tests/kamae/tensorflow/layers/test_string_replace.py +++ b/tests/kamae/keras/tensorflow/layers/test_string_replace.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StringReplaceLayer +from kamae.keras.tensorflow.layers import StringReplaceLayer class TestStringReplace: diff --git a/tests/kamae/tensorflow/layers/test_string_to_string_list.py b/tests/kamae/keras/tensorflow/layers/test_string_to_string_list.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_string_to_string_list.py rename to tests/kamae/keras/tensorflow/layers/test_string_to_string_list.py index 437c8e69..e312ef78 100644 --- a/tests/kamae/tensorflow/layers/test_string_to_string_list.py +++ b/tests/kamae/keras/tensorflow/layers/test_string_to_string_list.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import StringToStringListLayer +from kamae.keras.tensorflow.layers import StringToStringListLayer class TestStringToStringList: diff --git a/tests/kamae/tensorflow/layers/test_sub_string_delim_at_index.py b/tests/kamae/keras/tensorflow/layers/test_sub_string_delim_at_index.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_sub_string_delim_at_index.py rename to tests/kamae/keras/tensorflow/layers/test_sub_string_delim_at_index.py index 20d5c56c..ca05fdfe 100644 --- a/tests/kamae/tensorflow/layers/test_sub_string_delim_at_index.py +++ b/tests/kamae/keras/tensorflow/layers/test_sub_string_delim_at_index.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers 
import SubStringDelimAtIndexLayer +from kamae.keras.tensorflow.layers import SubStringDelimAtIndexLayer class TestSubStringDelimAtIndex: diff --git a/tests/kamae/tensorflow/layers/test_unix_timestamp_to_date_time.py b/tests/kamae/keras/tensorflow/layers/test_unix_timestamp_to_date_time.py similarity index 98% rename from tests/kamae/tensorflow/layers/test_unix_timestamp_to_date_time.py rename to tests/kamae/keras/tensorflow/layers/test_unix_timestamp_to_date_time.py index 02a1daa3..dbb337d1 100644 --- a/tests/kamae/tensorflow/layers/test_unix_timestamp_to_date_time.py +++ b/tests/kamae/keras/tensorflow/layers/test_unix_timestamp_to_date_time.py @@ -17,7 +17,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.layers import UnixTimestampToDateTimeLayer +from kamae.keras.tensorflow.layers import UnixTimestampToDateTimeLayer class TestUnixTimestampToDate: diff --git a/tests/kamae/tensorflow/utils/test_list_utils.py b/tests/kamae/keras/tensorflow/test_list_utils.py similarity index 98% rename from tests/kamae/tensorflow/utils/test_list_utils.py rename to tests/kamae/keras/tensorflow/test_list_utils.py index 8d210e77..71a4a06e 100644 --- a/tests/kamae/tensorflow/utils/test_list_utils.py +++ b/tests/kamae/keras/tensorflow/test_list_utils.py @@ -15,7 +15,7 @@ import pytest import tensorflow as tf -from kamae.tensorflow.utils import get_top_n +from kamae.keras.tensorflow.utils import get_top_n class TestGetTopN: diff --git a/tests/kamae/keras/test_jit_compatibility.py b/tests/kamae/keras/test_jit_compatibility.py new file mode 100644 index 00000000..7c1d1338 --- /dev/null +++ b/tests/kamae/keras/test_jit_compatibility.py @@ -0,0 +1,586 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for JIT compatibility of Keras layers.""" + +import keras +import pytest +import tensorflow as tf + +import kamae.keras.core.layers as core_layers_mod +import kamae.keras.tensorflow.layers as tf_layers_mod + +# Multi-backend layers +from kamae.keras.core.layers import ( + AbsoluteValueLayer, + ArrayConcatenateLayer, + ArrayCropLayer, + ArrayReduceMaxLayer, + ArraySplitLayer, + ArraySubtractMinimumLayer, + BearingAngleLayer, + BinLayer, + ConditionalStandardScaleLayer, + CosineSimilarityLayer, + DivideLayer, + ExpLayer, + ExponentLayer, + HaversineDistanceLayer, + IdentityLayer, + ImputeLayer, + LogicalAndLayer, + LogicalNotLayer, + LogicalOrLayer, + LogLayer, + MaxLayer, + MeanLayer, + MinLayer, + MinMaxScaleLayer, + ModuloLayer, + MultiplyLayer, + NumericalIfStatementLayer, + PairwiseCosineSimilarityLayer, + RoundLayer, + RoundToDecimalLayer, + StandardScaleLayer, + SubtractLayer, + SumLayer, +) + +# TF-only layers +from kamae.keras.tensorflow.layers import ( + BloomEncodeLayer, + BucketizeLayer, + CurrentDateLayer, + CurrentDateTimeLayer, + CurrentUnixTimestampLayer, + DateAddLayer, + DateDiffLayer, + DateParseLayer, + DateTimeToUnixTimestampLayer, + HashIndexLayer, + IfStatementLayer, + LambdaFunctionLayer, + ListMaxLayer, + ListMeanLayer, + ListMedianLayer, + ListMinLayer, + ListRankLayer, + ListStdDevLayer, + MinHashIndexLayer, + OneHotEncodeLayer, + OneHotLayer, + OrdinalArrayEncodeLayer, + StringAffixLayer, + StringArrayConstantLayer, + StringCaseLayer, + StringConcatenateLayer, + StringContainsLayer, + StringContainsListLayer, + 
StringEqualsIfStatementLayer, + StringIndexLayer, + StringIsInListLayer, + StringListToStringLayer, + StringMapLayer, + StringReplaceLayer, + StringToStringListLayer, + SubStringDelimAtIndexLayer, + UnixTimestampToDateTimeLayer, +) + +# JIT-compatible layers (jit_compatible = True) +JIT_COMPATIBLE_LAYERS = [ + # All 31 core layers + (AbsoluteValueLayer, [tf.random.normal((32, 10))], None), + ( + ArrayConcatenateLayer, + [tf.random.normal((32, 10, 100, 3)), tf.random.normal((32, 10, 100, 3))], + {"axis": -2}, + ), + (ArrayReduceMaxLayer, [tf.random.normal((32, 10))], {"default_value": 0.0}), + (ArraySplitLayer, [tf.random.normal((32, 10, 100, 3))], {"axis": -2}), + ( + ArraySubtractMinimumLayer, + [tf.random.normal((32, 10, 10, 3))], + {"axis": 1, "pad_value": 0}, + ), + ( + ArrayCropLayer, + [tf.constant(1.0, shape=(1, 4))], + {"array_length": 3, "pad_value": -1.0}, + ), + ( + BearingAngleLayer, + [ + tf.constant(0.0, shape=(100, 10, 1)), + tf.constant(90.0, shape=(100, 10, 1)), + ], + {"lat_lon_constant": [-45.9, 180.67]}, + ), + ( + BinLayer, + [tf.random.normal((100, 56, 3))], + { + "condition_operators": ["eq", "neq", "lt", "leq", "gt", "geq"], + "bin_values": [0, 1, 2, 3, 4, 5], + "bin_labels": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], + "default_label": 6.0, + }, + ), + ( + ConditionalStandardScaleLayer, + [tf.random.normal((100, 10, 5))], + { + "mean": [0.0, 1.0, 5.6, 7.8, 9.0], + "variance": [1.0, 1.0, 1.0, 1.0, 1.0], + "axis": -1, + "skip_zeros": True, + }, + ), + ( + CosineSimilarityLayer, + [tf.random.normal((100, 10, 10, 5)), tf.random.normal((100, 10, 10, 5))], + None, + ), + ( + PairwiseCosineSimilarityLayer, + [tf.random.normal((32, 4)), tf.random.normal((32, 12))], + {"embedding_dim": 4}, + ), + (DivideLayer, [tf.random.normal((100, 10, 5))], {"divisor": 2}), + (ExpLayer, [tf.random.normal((100, 10, 5))], None), + (ExponentLayer, [tf.random.normal((100, 10, 5))], {"exponent": 2}), + ( + HaversineDistanceLayer, + [ + tf.constant(-90.0, shape=(100, 10, 1)), + 
tf.constant(178.9, shape=(100, 10, 1)), + ], + {"lat_lon_constant": [-45.9, 180.67], "unit": "miles"}, + ), + (IdentityLayer, [tf.random.normal((100, 10, 5))], None), + ( + ImputeLayer, + [tf.constant([[[-999.0], [6.0], [9.0], [100.0]]])], + { + "impute_value": 2.0, + "mask_value": -999.0, + }, + ), + (LogLayer, [tf.random.normal((100, 10, 5))], None), + ( + LogicalAndLayer, + [tf.constant(True, shape=(10, 1, 5)), tf.constant(False, shape=(10, 1, 5))], + None, + ), + (LogicalNotLayer, [tf.constant(True, shape=(10, 1, 5))], None), + ( + LogicalOrLayer, + [tf.constant(True, shape=(10, 1, 5)), tf.constant(False, shape=(10, 1, 5))], + None, + ), + (MaxLayer, [tf.random.normal((100, 10, 5))], {"max_constant": 10}), + (MeanLayer, [tf.random.normal((100, 10, 5))], {"mean_constant": 10}), + (MinLayer, [tf.random.normal((100, 10, 5))], {"min_constant": 10}), + ( + MinMaxScaleLayer, + [ + tf.concat( + [ + tf.random.uniform((100, 10, 1), minval=-i, maxval=i) + for i in range(1, 6) + ], + axis=-1, + ) + ], + { + "min": [-i for i in range(1, 6)], + "max": [i for i in range(1, 6)], + "axis": -1, + }, + ), + (ModuloLayer, [tf.random.normal((1000, 32, 1))], {"divisor": 10}), + (MultiplyLayer, [tf.random.normal((1, 5))], {"multiplier": 50}), + ( + NumericalIfStatementLayer, + [tf.random.normal((100, 10, 5)), tf.random.normal((100, 10, 5))], + {"condition_operator": "gt", "value_to_compare": 5, "result_if_true": 1}, + ), + ( + RoundLayer, + [tf.random.normal((10, 10, 10, 1))], + {"round_type": "ceil"}, + ), + (RoundToDecimalLayer, [tf.random.normal((100, 5))], {"decimals": 2}), + ( + StandardScaleLayer, + [tf.random.normal((100, 10, 5))], + { + "mean": [0.0, 1.0, 5.6, 7.8, 9.0], + "variance": [1.0, 1.0, 1.0, 1.0, 1.0], + "axis": -1, + }, + ), + (SubtractLayer, [tf.random.normal((100, 10, 5))], {"subtrahend": 10}), + (SumLayer, [tf.random.normal((100, 10, 5))], {"addend": -1}), + # TF-only JIT-compatible layers + ( + ListRankLayer, + [tf.random.normal((1, 2, 3))], + {"axis": 1, 
"sort_order": "desc"}, + ), + ( + BucketizeLayer, + [tf.random.normal((100, 1))], + {"splits": [-0.5, 0, 0.1, 0.2, 3]}, + ), + (ListMaxLayer, [tf.random.normal((100, 10, 5))], None), + (ListMeanLayer, [tf.random.normal((100, 10, 5))], None), + ( + ListMedianLayer, + [tf.constant([[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]])], + { + "axis": 1, + "top_n": 5, + "sort_order": "descending", + "nan_fill_value": 0, + "min_filter_value": 0, + }, + ), + (ListMinLayer, [tf.random.normal((100, 10, 5))], None), + ( + ListStdDevLayer, + [tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])], + { + "axis": -1, + "top_n": 5, + "sort_order": "descending", + "nan_fill_value": 0, + "min_filter_value": 0, + }, + ), +] + + +# JIT-incompatible layers (jit_compatible = False) +JIT_INCOMPATIBLE_LAYERS = [ + ( + BloomEncodeLayer, + [tf.strings.as_string(tf.random.normal((100, 23, 32, 1)))], + {"num_hash_fns": 3, "num_bins": 100}, + ), + (CurrentDateLayer, [tf.constant(100, shape=(100, 10, 1))], None), + (CurrentDateTimeLayer, [tf.constant(100, shape=(100, 10, 1))], None), + ( + CurrentUnixTimestampLayer, + [tf.constant(100, shape=(100, 10, 1))], + {"unit": "ms"}, + ), + ( + DateAddLayer, + [ + tf.constant("2023-03-02", shape=(100, 10, 1)), + ], + {"num_days": 10}, + ), + ( + DateDiffLayer, + [ + tf.constant("2023-03-02", shape=(100, 10, 1)), + tf.constant("2023-02-02", shape=(100, 10, 1)), + ], + {"default_value": 1}, + ), + ( + DateParseLayer, + [tf.constant("2023-02-02", shape=(100, 10, 1))], + {"date_part": "DayOfWeek", "default_value": 1}, + ), + ( + DateTimeToUnixTimestampLayer, + [tf.constant("2021-07-14", shape=(100, 10, 1))], + {"unit": "s"}, + ), + ( + HashIndexLayer, + [tf.strings.as_string(tf.random.normal((100, 10, 5)))], + {"num_bins": 100}, + ), + ( + IfStatementLayer, + [tf.constant("hello", shape=(100, 10, 5))], + { + "condition_operator": "eq", + "value_to_compare": "world", + "result_if_true": "yes", + "result_if_false": "no", + }, + ), + ( + LambdaFunctionLayer, + 
[tf.constant([[1, 2, 3], [4, 5, 6]])], + { + "function": lambda x: tf.square(x), + "input_dtype": "float", + "output_dtype": "float", + "output_shape": (3,), + }, + ), + ( + MinHashIndexLayer, + [tf.strings.as_string(tf.random.normal((100, 10, 5)))], + {"num_permutations": 10, "mask_value": None, "axis": -1}, + ), + ( + OneHotEncodeLayer, + [tf.constant("a", shape=(100, 10, 1))], + {"num_oov_indices": 1, "vocabulary": ["a", "b"], "drop_unseen": True}, + ), + ( + OneHotLayer, + [tf.constant("a", shape=(100, 10, 1))], + {"num_oov_indices": 1, "vocabulary": ["a", "b"], "drop_unseen": True}, + ), + ( + OrdinalArrayEncodeLayer, + [tf.constant([["a", "a", "b", "-1"]])], + {"pad_value": "-1"}, + ), + ( + StringAffixLayer, + [tf.constant("a", shape=(100, 10, 1))], + {"prefix": "b", "suffix": "c"}, + ), + ( + StringArrayConstantLayer, + [tf.constant("a", shape=(100, 10, 1))], + {"constant_string_array": "b"}, + ), + ( + StringCaseLayer, + [tf.constant("hEllO wOrLd", shape=(100, 10, 1))], + {"string_case_type": "lower"}, + ), + ( + StringConcatenateLayer, + [ + tf.constant("a", shape=(10, 1, 1, 5, 2)), + tf.constant("b", shape=(10, 1, 1, 5, 2)), + ], + {"separator": "y"}, + ), + ( + StringContainsLayer, + [ + tf.constant("a", shape=(100, 10, 1)), + tf.constant("b", shape=(100, 10, 1)), + ], + {"negation": True}, + ), + ( + StringContainsListLayer, + [tf.constant("a", shape=(230, 67, 1))], + {"negation": True, "string_constant_list": ["a", "b", "c"]}, + ), + ( + StringEqualsIfStatementLayer, + [ + tf.constant("a", shape=(23, 1, 1, 67)), + tf.constant("b", shape=(23, 1, 1, 67)), + ], + {"result_if_true": "a", "result_if_false": "b"}, + ), + ( + StringIndexLayer, + [tf.constant("a", shape=(23, 5))], + { + "num_oov_indices": 2, + "encoding": "utf-8", + "vocabulary": ["a", "b"], + "mask_token": "c", + }, + ), + ( + StringIsInListLayer, + [tf.constant("a", shape=(23, 5))], + {"string_constant_list": ["a", "b", "c"], "negation": False}, + ), + ( + StringListToStringLayer, + 
[tf.constant("a", shape=(23, 5))], + {"separator": "b", "axis": -1}, + ), + ( + StringMapLayer, + [tf.constant("a", shape=(100, 5))], + { + "string_match_values": ["a", "c"], + "string_replace_values": ["b", "c"], + "default_replace_value": "z", + }, + ), + ( + StringReplaceLayer, + [tf.constant("a_b_c_d_e", shape=(1, 5, 45))], + { + "string_match_constant": "_", + "string_replace_constant": "-", + "regex": False, + }, + ), + ( + StringToStringListLayer, + [tf.constant("a", shape=(100, 5))], + {"separator": "b", "default_value": "hello", "list_length": 5}, + ), + ( + SubStringDelimAtIndexLayer, + [tf.constant("a_b_c_d_e", shape=(1, 5, 45))], + {"delimiter": "_", "index": 3, "default_value": "hello"}, + ), + ( + UnixTimestampToDateTimeLayer, + [tf.constant(100000, shape=(100, 10, 1), dtype=tf.int64)], + {"include_time": True, "unit": "s"}, + ), +] + + +@pytest.mark.parametrize("layer_cls, input_tensors, kwargs", JIT_COMPATIBLE_LAYERS) +def test_jit_compatible_layers_pass(layer_cls, input_tensors, kwargs): + """Test that layers marked jit_compatible=True can be JIT-compiled.""" + if kwargs is None: + kwargs = {} + + layer = layer_cls(**kwargs) + assert ( + layer.jit_compatible is True + ), f"{layer_cls.__name__} should have jit_compatible=True" + + @tf.function(jit_compile=True) + def jit_call(*inputs): + if len(inputs) == 1: + return layer(inputs[0]) + return layer(list(inputs)) + + # Must not raise + result = jit_call(*input_tensors) + assert result is not None + + +@pytest.mark.parametrize("layer_cls, input_tensors, kwargs", JIT_INCOMPATIBLE_LAYERS) +def test_jit_incompatible_layers_fail(layer_cls, input_tensors, kwargs): + """Test that layers marked jit_compatible=False fail JIT compilation. + + This ensures that if a layer becomes JIT-safe (e.g., TF upgrade), the test + breaks and prompts the developer to update the jit_compatible flag. 
+ """ + if kwargs is None: + kwargs = {} + + layer = layer_cls(**kwargs) + assert ( + layer.jit_compatible is False + ), f"{layer_cls.__name__} should have jit_compatible=False" + + @tf.function(jit_compile=True) + def jit_call(*inputs): + if len(inputs) == 1: + return layer(inputs[0]) + return layer(list(inputs)) + + # Must raise Exception when trying to JIT compile + with pytest.raises(Exception): + result = jit_call(*input_tensors) + # Force evaluation if result is symbolic + if hasattr(result, "numpy"): + result.numpy() + + +def test_all_layers_define_jit_compatible_and_supported_backends(): + """Test that all layers define jit_compatible and supported_backends directly (not inherited).""" + # Get all classes from kamae.keras.core.layers (multi-backend) + multi_backend_layers = [ + obj + for name, obj in vars(core_layers_mod).items() + if isinstance(obj, type) + and issubclass(obj, keras.Layer) + and obj is not keras.Layer + and name != "BaseLayer" # Exclude base class + ] + + # Get all classes from kamae.keras.tensorflow.layers (TF-only) + tf_only_layers = [ + obj + for name, obj in vars(tf_layers_mod).items() + if isinstance(obj, type) + and issubclass(obj, tf.keras.layers.Layer) + and obj is not tf.keras.layers.Layer + ] + + all_layers = multi_backend_layers + tf_only_layers + + for layer_cls in all_layers: + assert ( + "jit_compatible" in layer_cls.__dict__ + ), f"{layer_cls.__name__} must define 'jit_compatible' directly (not inherit it)" + assert isinstance( + layer_cls.jit_compatible, bool + ), f"{layer_cls.__name__}.jit_compatible must be bool, got {type(layer_cls.jit_compatible)}" + assert ( + "supported_backends" in layer_cls.__dict__ + ), f"{layer_cls.__name__} must define 'supported_backends' directly (not inherit it)" + assert isinstance( + layer_cls.supported_backends, frozenset + ), f"{layer_cls.__name__}.supported_backends must be frozenset" + + +def test_all_layers_in_jit_tests(): + """Test that all layers appear in exactly one of the JIT test 
lists.""" + # Get all layer classes + multi_backend_layers = [ + obj + for name, obj in vars(core_layers_mod).items() + if isinstance(obj, type) + and issubclass(obj, keras.Layer) + and obj is not keras.Layer + and name != "BaseLayer" + ] + + tf_only_layers = [ + obj + for name, obj in vars(tf_layers_mod).items() + if isinstance(obj, type) + and issubclass(obj, tf.keras.layers.Layer) + and obj is not tf.keras.layers.Layer + ] + + all_layers = set(multi_backend_layers + tf_only_layers) + + # Get tested layers + jit_compatible_tested = {param[0] for param in JIT_COMPATIBLE_LAYERS} + jit_incompatible_tested = {param[0] for param in JIT_INCOMPATIBLE_LAYERS} + + # Check coverage + tested_layers = jit_compatible_tested | jit_incompatible_tested + missing = all_layers - tested_layers + assert ( + not missing + ), f"Layers missing from JIT tests: {[l.__name__ for l in missing]}" + + # Check no duplicates + duplicates = jit_compatible_tested & jit_incompatible_tested + assert ( + not duplicates + ), f"Layers in both JIT test lists: {[l.__name__ for l in duplicates]}" diff --git a/tests/kamae/tensorflow/test_layer_serialisation.py b/tests/kamae/keras/test_layer_serialisation.py similarity index 92% rename from tests/kamae/tensorflow/test_layer_serialisation.py rename to tests/kamae/keras/test_layer_serialisation.py index 40ec92ba..1ce4fa52 100644 --- a/tests/kamae/tensorflow/test_layer_serialisation.py +++ b/tests/kamae/keras/test_layer_serialisation.py @@ -18,22 +18,18 @@ import numpy as np import pytest import tensorflow as tf -from packaging.version import Version - -import kamae.tensorflow.layers as layers_mod -keras_version = Version(keras.__version__) -# If keras >= 2.13.0, we need to enable unsafe deserialization in order to load the -# LambdaFunctionLayer. -# Before 2.13.0, keras the default behavior is to allow unsafe deserialization. 
-if keras_version >= Version("2.13.0"): - from keras.src.saving import serialization_lib +# Enable unsafe deserialization for LambdaFunctionLayer (Keras 3) +from keras.src.saving import serialization_lib +from packaging.version import Version - serialization_lib.enable_unsafe_deserialization() +import kamae.keras.core.layers as core_layers_mod +import kamae.keras.tensorflow.layers as layers_mod -is_keras_3 = keras_version >= Version("3.0.0") +serialization_lib.enable_unsafe_deserialization() -from kamae.tensorflow.layers import ( +# Multi-backend layers +from kamae.keras.core.layers import ( AbsoluteValueLayer, ArrayConcatenateLayer, ArrayCropLayer, @@ -42,51 +38,57 @@ ArraySubtractMinimumLayer, BearingAngleLayer, BinLayer, - BloomEncodeLayer, - BucketizeLayer, ConditionalStandardScaleLayer, CosineSimilarityLayer, - CurrentDateLayer, - CurrentDateTimeLayer, - CurrentUnixTimestampLayer, - DateAddLayer, - DateDiffLayer, - DateParseLayer, - DateTimeToUnixTimestampLayer, DivideLayer, ExpLayer, ExponentLayer, - HashIndexLayer, HaversineDistanceLayer, IdentityLayer, - IfStatementLayer, ImputeLayer, - LambdaFunctionLayer, - ListMaxLayer, - ListMeanLayer, - ListMedianLayer, - ListMinLayer, - ListRankLayer, - ListStdDevLayer, LogicalAndLayer, LogicalNotLayer, LogicalOrLayer, LogLayer, MaxLayer, MeanLayer, - MinHashIndexLayer, MinLayer, MinMaxScaleLayer, ModuloLayer, MultiplyLayer, NumericalIfStatementLayer, - OneHotEncodeLayer, - OneHotLayer, - OrdinalArrayEncodeLayer, PairwiseCosineSimilarityLayer, RoundLayer, RoundToDecimalLayer, StandardScaleLayer, + SubtractLayer, + SumLayer, +) + +# TF-only layers +from kamae.keras.tensorflow.layers import ( + BloomEncodeLayer, + BucketizeLayer, + CurrentDateLayer, + CurrentDateTimeLayer, + CurrentUnixTimestampLayer, + DateAddLayer, + DateDiffLayer, + DateParseLayer, + DateTimeToUnixTimestampLayer, + HashIndexLayer, + IfStatementLayer, + LambdaFunctionLayer, + ListMaxLayer, + ListMeanLayer, + ListMedianLayer, + ListMinLayer, + 
ListRankLayer, + ListStdDevLayer, + MinHashIndexLayer, + OneHotEncodeLayer, + OneHotLayer, + OrdinalArrayEncodeLayer, StringAffixLayer, StringArrayConstantLayer, StringCaseLayer, @@ -101,8 +103,6 @@ StringReplaceLayer, StringToStringListLayer, SubStringDelimAtIndexLayer, - SubtractLayer, - SumLayer, UnixTimestampToDateTimeLayer, ) @@ -119,6 +119,12 @@ {"axis": -2}, False, ), + ( + ArrayReduceMaxLayer, + [tf.random.normal((32, 10))], + {"default_value": 0.0}, + False, + ), (ArraySplitLayer, [tf.random.normal((32, 10, 100, 3))], {"axis": -2}, False), ( ArraySubtractMinimumLayer, @@ -132,12 +138,6 @@ {"array_length": 3, "pad_value": "-1"}, False, ), - ( - ArrayReduceMaxLayer, - [tf.random.normal((32, 10, 5))], - {"default_value": 0.0}, - False, - ), ( BearingAngleLayer, [ @@ -189,8 +189,8 @@ ), ( PairwiseCosineSimilarityLayer, - [tf.random.normal((32, 8)), tf.random.normal((32, 40))], - {"embedding_dim": 8}, + [tf.random.normal((32, 4)), tf.random.normal((32, 12))], + {"embedding_dim": 4}, False, ), (CurrentDateLayer, [tf.constant(100, shape=(100, 10, 1))], None, False), @@ -368,6 +368,7 @@ "function": lambda x: tf.square(x) - tf.math.log(x), "input_dtype": "float", "output_dtype": "float", + "output_shape": (3,), # Required for Keras 3 serialization }, False, ), @@ -558,9 +559,12 @@ def test_layer_serialisation( Tests whether a layer is serialisable in a Model and that the output from the model matches calling the layer directly. 
""" - if is_keras_3 and layer_cls == LambdaFunctionLayer: - # TODO: Understand why - pytest.skip(reason="LambdaFunctionLayer does not serialise properly in keras 3") + if layer_cls == LambdaFunctionLayer: + # LambdaFunctionLayer cannot serialize/deserialize lambda functions that reference + # external modules (like tf) - this is a fundamental limitation of Python lambda serialization + pytest.skip( + reason="LambdaFunctionLayer with module references cannot serialize in Keras 3" + ) if kwargs is None: kwargs = {} @@ -585,10 +589,8 @@ def test_layer_serialisation( # check with the functional API model = tf.keras.Model(inputs=model_inputs, outputs=model_outputs) - # Test saving and reloading - model_path = os.path.join(tmp_path, layer.name) - if is_keras_3: - model_path += ".keras" + # Test saving and reloading (Keras 3 .keras format) + model_path = os.path.join(tmp_path, layer.name + ".keras") model.save(model_path) reloaded_model = tf.keras.models.load_model(model_path) @@ -633,10 +635,20 @@ def test_layer_serialisation( def test_all_layers_tested_for_serialisation(): """ - Checks that all layers in kamae.tensorflow.layers have a serialisation test. + Checks that all layers (both multi-backend and TF-only) have a serialisation test. 
""" - # Get all classes defined in kamae.tensorflow.layers - all_layers = [ + # Get all classes from kamae.keras.core.layers (multi-backend) + multi_backend_layers = [ + obj + for name, obj in vars(core_layers_mod).items() + if isinstance(obj, type) + and issubclass(obj, keras.Layer) + and obj is not keras.Layer + and name != "BaseLayer" # Exclude base class + ] + + # Get all classes from kamae.tensorflow.layers (TF-only) + tf_only_layers = [ obj for name, obj in vars(layers_mod).items() if isinstance(obj, type) @@ -644,6 +656,8 @@ def test_all_layers_tested_for_serialisation(): and obj is not tf.keras.layers.Layer ] + all_layers = multi_backend_layers + tf_only_layers + # Extract all layer_cls from the test parameterization parametrize_mark = next( mark diff --git a/tests/kamae/sklearn/__init__.py b/tests/kamae/sklearn/__init__.py deleted file mode 100644 index d47f0081..00000000 --- a/tests/kamae/sklearn/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/kamae/sklearn/conftest.py b/tests/kamae/sklearn/conftest.py deleted file mode 100644 index c2916c61..00000000 --- a/tests/kamae/sklearn/conftest.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pandas as pd -import pytest -import tensorflow as tf - -from kamae.sklearn.params import SingleInputSingleOutputMixin -from kamae.sklearn.transformers import BaseTransformer - - -@pytest.fixture -def example_dataframe(): - example_df = pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - }, - ) - return example_df - - -@pytest.fixture -def example_dataframe_with_nulls(): - example_df = pd.DataFrame( - { - "col1": [None, 4, 7, 7], - "col2": [2, None, 2, 8], - "col3": [3, 6, None, None], - "col4": ["a", "b", None, "a"], - "col5": ["c", None, "a", "a"], - "col1_col2_col3": [[None, 2, 3], [4, None, 6], [7, 8, None], [7, 8, None]], - }, - ) - return example_df - - -@pytest.fixture -def layer_name(): - return "test_layer" - - -@pytest.fixture -def input_col(): - return "test_input" - - -@pytest.fixture -def output_col(): - return "test_output" - - -@pytest.fixture -def tf_layer(): - return tf.keras.layers.Dense(1) - - -@pytest.fixture -def base_transformer(layer_name, output_col, input_col, tf_layer): - class TestTransformer( - BaseTransformer, - SingleInputSingleOutputMixin, - ): - """Test transformer for testing abstract base class LayerTransformer""" - - def __init__(self, input_col, output_col, layer_name): - super().__init__( - input_col=input_col, output_col=output_col, layer_name=layer_name - ) - self.input_col = input_col - self.output_col = output_col - self.layer_name = layer_name - - def fit(self, X: 
pd.DataFrame, y=None, **kwargs): - return self - - def transform(self, X: pd.DataFrame, y=None, **kwargs): - return X - - def get_tf_layer(self) -> tf.keras.layers.Layer: - return tf_layer - - return TestTransformer( - input_col=input_col, output_col=output_col, layer_name=layer_name - ) diff --git a/tests/kamae/sklearn/estimators/test_standard_scale.py b/tests/kamae/sklearn/estimators/test_standard_scale.py deleted file mode 100644 index e68141e2..00000000 --- a/tests/kamae/sklearn/estimators/test_standard_scale.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import numpy as np -import pandas as pd -import pytest -import tensorflow as tf - -from kamae.sklearn.estimators import StandardScaleEstimator - - -class TestStandardScale: - @pytest.fixture(scope="class") - def standard_scaler_expected(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "scaled_features": [ - [-0.3278688524590164, 0.2886751345948129, -2.886751345948129], - [0.6557377049180328, 0.2886751345948129, -1.1547005383792517], - [1.639344262295082, 2.0207259421636903, -2.886751345948129], - ], - } - ) - - @pytest.mark.parametrize( - "input_col, output_col, expected_mean, expected_var", - [ - ( - "col1_col2_col3", - "scaled_features", - [4.0, 4.0, 4.0], - [6.0, 8.0, 2.0], - ), - ], - ) - def test_sklearn_standard_scaler_fit( - self, - example_dataframe, - input_col, - output_col, - expected_mean, - expected_var, - ): - # when - standard_scaler = StandardScaleEstimator( - input_col=input_col, - output_col=output_col, - layer_name="standard_scaler", - ) - actual = standard_scaler.fit(example_dataframe) - # then - actual_mean, actual_var = actual.mean_, actual.var_ - np.testing.assert_almost_equal(np.array(actual_mean), np.array(expected_mean)) - np.testing.assert_almost_equal(np.array(actual_var), np.array(expected_var)) - - @pytest.mark.parametrize( - "input_col, output_col, expected_mean, expected_var", - [ - ( - "col1_col2_col3", - "scaled_features", - [6.0, 6.0, 4.5], - [2.0, 8.0, 2.25], - ), - ], - ) - def test_sklearn_standard_scaler_fit_with_nulls( - self, - example_dataframe_with_nulls, - input_col, - output_col, - expected_mean, - expected_var, - ): - # when - standard_scaler = StandardScaleEstimator( - input_col=input_col, - output_col=output_col, - layer_name="standard_scaler", - ) - actual = standard_scaler.fit(example_dataframe_with_nulls) - # then - actual_mean, actual_var = actual.mean_, 
actual.var_ - np.testing.assert_almost_equal(np.array(actual_mean), np.array(expected_mean)) - np.testing.assert_almost_equal(np.array(actual_var), np.array(expected_var)) - - @pytest.mark.parametrize( - "input_col, output_col, mean, var, expected_dataframe", - [ - ( - "col1_col2_col3", - "scaled_features", - [2.0, 1.0, 8.0], - [9.3025, 12.0, 3.0], - "standard_scaler_expected", - ), - ], - ) - def test_sklearn_standard_scaler_transform( - self, - example_dataframe, - input_col, - output_col, - mean, - var, - expected_dataframe, - request, - ): - # given - expected = request.getfixturevalue(expected_dataframe) - # when - standard_scaler_model = StandardScaleEstimator( - input_col=input_col, - output_col=output_col, - layer_name="standard_scaler", - ) - standard_scaler_model.mean_ = mean - standard_scaler_model.var_ = var - standard_scaler_model.scale_ = np.sqrt(var) - actual = standard_scaler_model.transform(example_dataframe) - # then - pd.testing.assert_frame_equal(actual, expected) - - @pytest.mark.parametrize( - "input_tensor, mean, stddev", - [ - ( - tf.constant([[1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0, 10.0]]), - [3.0, 10.0, -1.0, 4.0, 2.0], - [2.0, 2.0, 1.0, 3.0, 4.0], - ), - ( - tf.constant( - [ - [1.0, 2.0, 3.0, 4.0, 5.0], - [6.0, 7.0, 8.0, 9.0, 10.0], - [-1.0, 51.0, 12.89, 0.0, 1.0], - ] - ), - [3.0, 10.0, -1.0, 4.0, 2.0], - [2.0, 2.0, 1.0, 3.0, 4.0], - ), - ( - tf.constant([[-1.0, -2.0, 3.0, 5.0], [6.0, -7.0, -9.0, 10.0]]), - [3.0, -1.0, 4.0, 2.0], - [2.0, 2.0, 1.0, 4.0], - ), - ( - tf.constant([[1.0, 2.0], [6.0, 10.0]]), - [-1.0, 4.0], - [2.0, 4.0], - ), - ], - ) - def test_standard_scaler_spark_tf_parity(self, input_tensor, mean, stddev): - # given - transformer = StandardScaleEstimator( - input_col="input", - output_col="output", - layer_name="standard_scaler", - ) - transformer.mean_ = mean - transformer.var_ = np.power(stddev, 2) - transformer.scale_ = stddev - - # when - pd_df = pd.DataFrame( - { - "input": input_tensor.numpy().tolist(), - } 
- ) - pd_values = transformer.transform(pd_df)["output"].values.tolist() - tensorflow_values = transformer.get_tf_layer()(input_tensor).numpy().tolist() - - # then - np.testing.assert_almost_equal( - pd_values, - tensorflow_values, - decimal=6, - err_msg="Spark and Tensorflow transform outputs are not equal", - ) diff --git a/tests/kamae/sklearn/pipeline/test_pipeline.py b/tests/kamae/sklearn/pipeline/test_pipeline.py deleted file mode 100644 index 69ab0cc6..00000000 --- a/tests/kamae/sklearn/pipeline/test_pipeline.py +++ /dev/null @@ -1,498 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -from shutil import rmtree - -import joblib -import pandas as pd -import pytest -import tensorflow as tf - -from kamae.sklearn.estimators import StandardScaleEstimator -from kamae.sklearn.pipeline import KamaeSklearnPipeline -from kamae.sklearn.transformers import ( - ArrayConcatenateTransformer, - ArraySplitTransformer, - IdentityTransformer, - LogTransformer, -) - - -class TestKamaeSklearnPipeline: - """ - Tests both the pipeline and the pipeline model (fit and transform) - """ - - @pytest.fixture(scope="class") - def test_dir(self): - path = "./tmp_sklearn_test" - os.makedirs(path, exist_ok=True) - yield path - rmtree(path) - - @pytest.fixture(scope="class") - def valid_stages_transforms_only_0(self): - return [ - LogTransformer( - input_col="col1", - output_col="log_col1", - alpha=0.1, - layer_name="log_transform_0", - ), - ArrayConcatenateTransformer( - input_cols=["log_col1", "col2", "col3"], - output_col="features", - layer_name="vector_assembler_0", - ), - ArraySplitTransformer( - input_col="features", - output_cols=["log_col1_sliced", "col2_sliced", "col3_sliced"], - layer_name="vector_slicer_0", - ), - ] - - @pytest.fixture(scope="class") - def valid_stages_transforms_only_1(self): - return [ - LogTransformer( - input_col="col2", - output_col="log_col2", - alpha=5, - layer_name="log_transform_1", - ), - IdentityTransformer( - input_col="col1", - output_col="col1_identity", - layer_name="identity_transform_1", - ), - ArrayConcatenateTransformer( - input_cols=["col1_identity", "log_col2", "col3"], - output_col="features", - layer_name="vector_assembler_1", - ), - ArraySplitTransformer( - input_col="features", - output_cols=["col1_sliced", "log_col2_sliced", "col3_sliced"], - layer_name="vector_slicer_1", - ), - ] - - @pytest.fixture(scope="class") - def valid_stages_0(self): - return [ - ArrayConcatenateTransformer( - input_cols=["col1", "col2", "col3"], - output_col="features", - layer_name="vector_assembler_0", - ), - StandardScaleEstimator( 
- input_col="features", - output_col="features_scaled", - layer_name="standard_scaler_0", - ), - ] - - @pytest.fixture(scope="class") - def expected_dataframe_stage_0(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "features": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "features_scaled": [ - [-1.2247448713915892, -0.7071067811865475, -0.7071067811865475], - [0.0, -0.7071067811865475, 1.414213562373095], - [1.2247448713915892, 1.414213562373095, -0.7071067811865475], - ], - } - ) - - @pytest.fixture(scope="class") - def valid_stages_1(self): - return [ - LogTransformer( - input_col="col3", - output_col="log_col3", - alpha=0.1, - layer_name="log_transform_2", - ), - ArrayConcatenateTransformer( - input_cols=["col1_col2_col3", "log_col3"], - output_col="features", - layer_name="vector_assembler_2", - ), - StandardScaleEstimator( - input_col="features", - output_col="features_scaled", - layer_name="standard_scaler_2", - ), - IdentityTransformer( - input_col="col4", - output_col="col4_identity", - layer_name="identity_transform_2", - ), - ] - - @pytest.fixture(scope="class") - def expected_dataframe_stage_1(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "log_col3": [ - 1.1314021114911006, - 1.8082887711792655, - 1.1314021114911006, - ], - "features": [ - [1, 2, 3, 1.1314021114911006], - [4, 2, 6, 1.8082887711792655], - [7, 8, 3, 1.1314021114911006], - ], - "features_scaled": [ - [ - -1.2247448713915892, - -0.7071067811865475, - -0.7071067811865475, - -0.7071067811865468, - ], - [0.0, -0.7071067811865475, 1.414213562373095, 1.4142135623730956], - [ - 1.2247448713915892, - 1.414213562373095, - -0.7071067811865475, - -0.7071067811865468, - ], - ], - 
"col4_identity": ["a", "b", "a"], - } - ) - - @pytest.mark.parametrize( - "stages", - [ - "valid_stages_0", - "valid_stages_1", - ], - ) - def test_sklearn_read_write_pipeline( - self, example_dataframe, test_dir, stages, request - ): - stages = request.getfixturevalue(stages) - pipeline = KamaeSklearnPipeline(steps=[(s.layer_name, s) for s in stages]) - joblib.dump(pipeline, f"{test_dir}/pipeline") - pipeline_loaded = joblib.load(f"{test_dir}/pipeline") - pipeline.fit(example_dataframe) - pipeline_loaded.fit(example_dataframe) - orig_actual = pipeline.transform(example_dataframe) - loaded_actual = pipeline_loaded.transform(example_dataframe) - pd.testing.assert_frame_equal(orig_actual, loaded_actual) - - @pytest.mark.parametrize( - "stages, expected_dataframe", - [ - ("valid_stages_0", "expected_dataframe_stage_0"), - ("valid_stages_1", "expected_dataframe_stage_1"), - ], - ) - def test_sklearn_pipeline( - self, stages, example_dataframe, expected_dataframe, request - ): - stages = request.getfixturevalue(stages) - pipeline = KamaeSklearnPipeline(steps=[(s.layer_name, s) for s in stages]) - - pipeline.fit(example_dataframe) - - transformed_df = pipeline.transform(example_dataframe) - expected = request.getfixturevalue(expected_dataframe) - pd.testing.assert_frame_equal(transformed_df, expected) - - @pytest.mark.parametrize( - "stages, input_tensors, tf_input_schema, expected_output", - [ - ( - "valid_stages_0", - { - "col1": tf.constant( - [ - [[1], [4], [7]], - ], - dtype=tf.float32, - ), - "col2": tf.constant( - [ - [[2], [2], [8]], - ], - dtype=tf.float32, - ), - "col3": tf.constant( - [ - [[3], [6], [3]], - ], - dtype=tf.float32, - ), - }, - [ - { - "name": "col1", - "dtype": "float32", - "shape": (None, 1), - }, - { - "name": "col2", - "dtype": "float32", - "shape": (None, 1), - }, - { - "name": "col3", - "dtype": "float32", - "shape": (None, 1), - }, - ], - tf.constant( - [ - [ - [-1.2247448, -0.70710677, -0.70710677], - [0.0, -0.70710677, 1.4142135], - 
[1.2247448, 1.4142135, -0.70710677], - ] - ] - ), - ), - ( - "valid_stages_1", - { - "col1_col2_col3": tf.constant( - [ - [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - ], - dtype=tf.float32, - ), - "col3": tf.constant( - [ - [[3], [6], [3]], - ], - dtype=tf.float32, - ), - "col4": tf.constant( - [ - [["a"], ["b"], ["a"]], - ], - dtype=tf.string, - ), - }, - [ - { - "name": "col1_col2_col3", - "dtype": "float32", - "shape": (None, 3), - }, - { - "name": "col3", - "dtype": "float32", - "shape": (None, 1), - }, - { - "name": "col4", - "dtype": "string", - "shape": (None, 1), - }, - ], - [ - tf.constant( - [ - [["a"], ["b"], ["a"]], - ], - dtype=tf.string, - ), - tf.constant( - [ - [ - [-1.2247448, -0.70710677, -0.70710677, -0.7071067], - [0.0, -0.70710677, 1.4142135, 1.4142138], - [1.2247448, 1.4142135, -0.70710677, -0.7071067], - ] - ], - dtype=tf.float32, - ), - ], - ), - ( - "valid_stages_transforms_only_0", - { - "col1": tf.constant( - [ - [[1.0], [4.0], [7.0]], - ], - dtype=tf.float32, - ), - "col2": tf.constant( - [ - [[2.0], [2.0], [8.0]], - ], - dtype=tf.float32, - ), - "col3": tf.constant( - [ - [[3.0], [6.0], [3.0]], - ], - dtype=tf.float32, - ), - }, - [ - { - "name": "col1", - "dtype": "float32", - "shape": (None, 1), - }, - { - "name": "col2", - "dtype": "float32", - "shape": (None, 1), - }, - { - "name": "col3", - "dtype": "float32", - "shape": (None, 1), - }, - ], - [ - tf.constant( - [ - [[0.0953102], [1.4109869], [1.9600948]], - ], - dtype=tf.float32, - ), - tf.constant( - [ - [[2.0], [2.0], [8.0]], - ], - dtype=tf.float32, - ), - tf.constant( - [ - [[3.0], [6.0], [3.0]], - ], - dtype=tf.float32, - ), - ], - ), - ( - "valid_stages_transforms_only_1", - { - "col1": tf.constant( - [ - [[1.0], [4.0], [7.0]], - ], - dtype=tf.float32, - ), - "col2": tf.constant( - [ - [[2.0], [2.0], [8.0]], - ], - dtype=tf.float32, - ), - "col3": tf.constant( - [ - [[3.0], [6.0], [3.0]], - ], - dtype=tf.float32, - ), - }, - [ - { - "name": "col1", - "dtype": "float32", - "shape": 
(None, 1), - }, - { - "name": "col2", - "dtype": "float32", - "shape": (None, 1), - }, - { - "name": "col3", - "dtype": "float32", - "shape": (None, 1), - }, - ], - [ - tf.constant( - [ - [[1.0], [4.0], [7.0]], - ], - dtype=tf.float32, - ), - tf.constant( - [ - [[1.9459101], [1.9459101], [2.5649493]], - ], - dtype=tf.float32, - ), - tf.constant( - [ - [[3.0], [6.0], [3.0]], - ], - dtype=tf.float32, - ), - ], - ), - ], - ) - def test_keras_model( - self, - stages, - input_tensors, - tf_input_schema, - expected_output, - example_dataframe, - request, - ): - stages = request.getfixturevalue(stages) - pipeline = KamaeSklearnPipeline( - steps=[(stage.layer_name, stage) for stage in stages] - ) - - pipeline.fit(example_dataframe) - - keras_model = pipeline.build_keras_model(tf_input_schema=tf_input_schema) - - actual = keras_model(input_tensors) - - if isinstance(actual, list): - for a, e in zip(actual, expected_output): - if a.dtype == "string": - tf.debugging.assert_equal(a, e) - else: - tf.debugging.assert_near(a, e, atol=1e-6) - elif isinstance(actual, dict): - for a, e in zip(actual.values(), expected_output): - if a.dtype == "string": - tf.debugging.assert_equal(a, e) - else: - tf.debugging.assert_near(a, e, atol=1e-6) - else: - if actual.dtype == "string": - tf.debugging.assert_equal(actual, expected_output) - else: - tf.debugging.assert_near(actual, expected_output, atol=1e-6) diff --git a/tests/kamae/sklearn/transformers/test_array_concatenate.py b/tests/kamae/sklearn/transformers/test_array_concatenate.py deleted file mode 100644 index 64c46801..00000000 --- a/tests/kamae/sklearn/transformers/test_array_concatenate.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import pandas as pd -import pytest -import tensorflow as tf - -from kamae.sklearn.transformers import ArrayConcatenateTransformer - - -class TestArrayConcatenate: - @pytest.fixture(scope="class") - def array_concatenate_col1_col2_expected(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "vec_col1_col2": [[1, 2], [4, 2], [7, 8]], - }, - ) - - @pytest.fixture(scope="class") - def array_concatenate_col1_col2_col3_expected(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "vec_col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - }, - ) - - @pytest.fixture(scope="class") - def array_concatenate_col4_col5_expected(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "vec_col4_col5": [["a", "c"], ["b", "c"], ["a", "a"]], - }, - ) - - @pytest.mark.parametrize( - "input_cols, output_col, expected_dataframe", - [ - (["col1", "col2"], "vec_col1_col2", "array_concatenate_col1_col2_expected"), - ( - ["col1", "col2", "col3"], - "vec_col1_col2_col3", - "array_concatenate_col1_col2_col3_expected", - ), - (["col4", "col5"], "vec_col4_col5", "array_concatenate_col4_col5_expected"), - ], - 
) - def test_sklearn_array_concatenate( - self, - example_dataframe, - input_cols, - output_col, - expected_dataframe, - request, - ): - # given - expected = request.getfixturevalue(expected_dataframe) - # when - transformer = ArrayConcatenateTransformer( - input_cols=input_cols, - output_col=output_col, - layer_name="array_concatenate", - ) - actual = transformer.transform(example_dataframe) - # then - pd.testing.assert_frame_equal(actual, expected) - - @pytest.mark.parametrize( - "input_tensors", - [ - [ - tf.constant([[1.1], [2.0], [3.0], [4.0], [5.0]]), - tf.constant([[6.05], [7.0], [8.0], [9.0], [10.0]]), - tf.constant([[11.01], [12.0], [13.0], [14.0], [15.0]]), - ], - [ - tf.constant([[6.7], [2.3], [3.7], [4.1], [5.0111]]), - tf.constant([[4.7], [5.3], [3.7], [6.1], [8.0111]]), - tf.constant([[2.7], [67.3], [3.7], [8.1], [9.0111]]), - tf.constant([[45.7], [3.3], [3.7], [8.1], [10.0111]]), - tf.constant([[6.9], [23.3], [3.7], [10.111], [15.0111]]), - ], - [ - tf.constant([[1.1], [2.0], [3.0], [4.0], [5.0], [7.90], [345.890]]), - tf.constant([[6.05], [7.0], [8.0], [9.0], [10.0], [4567.0], [1000.0]]), - ], - ], - ) - def test_array_concatenate_spark_tf_parity(self, input_tensors): - col_names = [f"input{i}" for i in range(len(input_tensors))] - - # given - transformer = ArrayConcatenateTransformer( - input_cols=col_names, - output_col="output", - layer_name="array_concatenate", - ) - - # when - pd_df = pd.DataFrame( - {f"input{i}": inp.numpy().tolist() for i, inp in enumerate(input_tensors)} - ) - pd_values = transformer.transform(pd_df)["output"].values.tolist() - tensorflow_values = transformer.get_tf_layer()(input_tensors).numpy().tolist() - - # then - np.testing.assert_almost_equal( - pd_values, - tensorflow_values, - decimal=6, - err_msg="Spark and Tensorflow transform outputs are not equal", - ) diff --git a/tests/kamae/sklearn/transformers/test_array_split.py b/tests/kamae/sklearn/transformers/test_array_split.py deleted file mode 100644 index 
9a84bdee..00000000 --- a/tests/kamae/sklearn/transformers/test_array_split.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import pandas as pd -import pytest -import tensorflow as tf - -from kamae.sklearn.transformers import ArraySplitTransformer - - -class TestArraySplit: - @pytest.fixture(scope="class") - def array_split_col1_col2_col3_expected(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "slice_col1": [1, 4, 7], - "slice_col2": [2, 2, 8], - "slice_col3": [3, 6, 3], - }, - ) - - @pytest.mark.parametrize( - "input_col, output_cols, expected_dataframe", - [ - ( - "col1_col2_col3", - ["slice_col1", "slice_col2", "slice_col3"], - "array_split_col1_col2_col3_expected", - ), - ], - ) - def test_sklearn_array_split( - self, - example_dataframe, - input_col, - output_cols, - expected_dataframe, - request, - ): - # given - expected = request.getfixturevalue(expected_dataframe) - # when - transformer = ArraySplitTransformer( - input_col=input_col, - output_cols=output_cols, - layer_name="array_split", - ) - actual = transformer.transform(example_dataframe) - # then - pd.testing.assert_frame_equal(actual, expected) - - @pytest.mark.parametrize( - "input_tensor", - [ - tf.constant( - [ - [1.0, 6.0, 11.0], - [2.0, 7.0, 12.0], - 
[3.0, 8.0, 13.0], - [4.0, 9.0, 14.0], - [5.0, 10.0, 15.0], - ] - ), - tf.constant( - [ - [6.7, 4.7, 2.7, 45.7, 6.9], - [2.3, 5.3, 67.3, 3.3, 23.3], - [3.7, 3.7, 3.7, 3.7, 3.7], - [4.1, 6.1, 8.1, 8.1, 10.111], - [5.0111, 8.0111, 9.0111, 10.0111, 15.0111], - ] - ), - tf.constant( - [ - [1.1, 6.05], - [2.0, 7.0], - [3.0, 8.0], - [4.0, 9.0], - [5.0, 10.0], - [7.90, 4567.0], - [345.890, 1000.0], - ] - ), - ], - ) - def test_array_split_sklearn_tf_parity(self, input_tensor): - col_names = [f"output{i}" for i in range(input_tensor.shape[1])] - # given - transformer = ArraySplitTransformer( - input_col="input", - output_cols=col_names, - layer_name="array_split", - ) - # when - pd_df = pd.DataFrame( - { - "input": input_tensor.numpy().tolist(), - } - ) - pd_values = [transformer.transform(pd_df)[c].values.tolist() for c in col_names] - tensorflow_values = [ - x.numpy().tolist() for x in transformer.get_tf_layer()(input_tensor) - ] - - # then - np.testing.assert_almost_equal( - np.array(pd_values).flatten(), - np.array(tensorflow_values).flatten(), - decimal=6, - err_msg="Scikit-Learn and Tensorflow transform outputs are not equal", - ) diff --git a/tests/kamae/sklearn/transformers/test_base.py b/tests/kamae/sklearn/transformers/test_base.py deleted file mode 100644 index 241f2b4e..00000000 --- a/tests/kamae/sklearn/transformers/test_base.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -class TestBaseTransformer: - def test_construct_layer_info( - self, - base_transformer, - layer_name, - output_col, - input_col, - tf_layer, - ): - # when - layer_info = base_transformer.construct_layer_info() - # then - assert layer_info["name"] == layer_name - assert layer_info["layer"] == tf_layer - assert layer_info["inputs"] == [input_col] - assert layer_info["outputs"] == [output_col] diff --git a/tests/kamae/sklearn/transformers/test_identity.py b/tests/kamae/sklearn/transformers/test_identity.py deleted file mode 100644 index 91612f0d..00000000 --- a/tests/kamae/sklearn/transformers/test_identity.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import numpy as np -import pandas as pd -import pytest -import tensorflow as tf - -from kamae.sklearn.transformers import IdentityTransformer - - -class TestIdentity: - @pytest.fixture(scope="class") - def identity_transform_col1_expected(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "iden_col1": [1, 4, 7], - }, - ) - - @pytest.fixture(scope="class") - def identity_transform_col2_expected(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "iden_col2": [2, 2, 8], - }, - ) - - @pytest.fixture(scope="class") - def identity_transform_col3_expected(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "iden_col3": [3, 6, 3], - }, - ) - - @pytest.fixture(scope="class") - def identity_transform_col4_expected(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "iden_col4": ["a", "b", "a"], - }, - ) - - @pytest.fixture(scope="class") - def identity_transform_col5_expected(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "iden_col5": ["c", "c", "a"], - }, - ) - - @pytest.mark.parametrize( - "input_col, output_col, expected_dataframe", - [ - ("col1", "iden_col1", "identity_transform_col1_expected"), - ("col2", "iden_col2", "identity_transform_col2_expected"), - ("col3", "iden_col3", 
"identity_transform_col3_expected"), - ("col4", "iden_col4", "identity_transform_col4_expected"), - ("col5", "iden_col5", "identity_transform_col5_expected"), - ], - ) - def test_sklearn_identity_transform( - self, - example_dataframe, - input_col, - output_col, - expected_dataframe, - request, - ): - # given - expected = request.getfixturevalue(expected_dataframe) - # when - transformer = IdentityTransformer( - input_col=input_col, - output_col=output_col, - layer_name="identity_transform", - ) - actual = transformer.transform(example_dataframe) - # then - pd.testing.assert_frame_equal(actual, expected) - - @pytest.mark.parametrize( - "input_tensor", - [ - (tf.constant([1.0, 4.0, 7.0, 8.0])), - (tf.constant([2.0, 5.0, 1.0])), - (tf.constant([-1.0, 7.0])), - (tf.constant([0.0, 6.0, 3.0])), - (tf.constant([2.0, 5.0, 1.0, 5.0, 2.5])), - ], - ) - def test_identity_transform_sklearn_tf_parity(self, input_tensor): - # given - transformer = IdentityTransformer( - input_col="input", output_col="output", layer_name="identity_transform" - ) - # when - pd_df = pd.DataFrame( - { - "input": input_tensor.numpy().tolist(), - } - ) - pd_values = transformer.transform(pd_df)["output"].values.tolist() - tensorflow_values = transformer.get_tf_layer()(input_tensor).numpy().tolist() - - # then - np.testing.assert_almost_equal( - pd_values, - tensorflow_values, - decimal=6, - err_msg="Sckit-Learn and Tensorflow transform outputs are not equal", - ) diff --git a/tests/kamae/sklearn/transformers/test_log.py b/tests/kamae/sklearn/transformers/test_log.py deleted file mode 100644 index e1f48cf2..00000000 --- a/tests/kamae/sklearn/transformers/test_log.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright [2024] Expedia, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import pandas as pd -import pytest -import tensorflow as tf - -from kamae.sklearn.transformers import LogTransformer - - -class TestLogTransformLayer: - @pytest.fixture(scope="class") - def log_transform_col1_alpha_1_expected(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "log_col1": [ - 0.6931471805599453, - 1.6094379124341003, - 2.0794415416798357, - ], - }, - ) - - @pytest.fixture(scope="class") - def log_transform_col2_alpha_5_expected(self): - return pd.DataFrame( - { - "col1": [1, 4, 7], - "col2": [2, 2, 8], - "col3": [3, 6, 3], - "col4": ["a", "b", "a"], - "col5": ["c", "c", "a"], - "col1_col2_col3": [[1, 2, 3], [4, 2, 6], [7, 8, 3]], - "log_col2": [ - 1.9459101490553132, - 1.9459101490553132, - 2.5649493574615367, - ], - }, - ) - - @pytest.mark.parametrize( - "input_col, output_col, alpha, expected_dataframe", - [ - ("col1", "log_col1", 1, "log_transform_col1_alpha_1_expected"), - ("col2", "log_col2", 5, "log_transform_col2_alpha_5_expected"), - ], - ) - def test_sklearn_log_transform( - self, - example_dataframe, - input_col, - output_col, - alpha, - expected_dataframe, - request, - ): - # given - expected = request.getfixturevalue(expected_dataframe) - # when - transformer = LogTransformer( - input_col=input_col, - output_col=output_col, - layer_name="log_transform", - alpha=alpha, - ) - actual = transformer.transform(example_dataframe) - # then - 
pd.testing.assert_frame_equal(actual, expected) - - @pytest.mark.parametrize( - "input_tensor, alpha", - [ - (tf.constant([1.0, 4.0, 7.0, 8.0]), 1), - (tf.constant([2.0, 5.0, 1.0]), 2), - (tf.constant([-1.0, 7.0]), 3), - (tf.constant([0.0, 6.0, 3.0]), 4), - (tf.constant([2.0, 5.0, 1.0, 5.0, 2.5]), 10), - ], - ) - def test_log_transform_sklearn_tf_parity(self, input_tensor, alpha): - # given - transformer = LogTransformer( - input_col="input", - output_col="output", - alpha=alpha, - layer_name="log_transform", - ) - # when - pd_df = pd.DataFrame( - { - "input": input_tensor.numpy().tolist(), - } - ) - pd_values = transformer.transform(pd_df)["output"].values.tolist() - tensorflow_values = transformer.get_tf_layer()(input_tensor).numpy().tolist() - - # then - np.testing.assert_almost_equal( - pd_values, - tensorflow_values, - decimal=6, - err_msg="Scikit-Learn and Tensorflow transform outputs are not equal", - ) diff --git a/tests/kamae/spark/conftest.py b/tests/kamae/spark/conftest.py index b17ad9f3..0d356a68 100644 --- a/tests/kamae/spark/conftest.py +++ b/tests/kamae/spark/conftest.py @@ -19,6 +19,7 @@ from pyspark.sql import DataFrame from pyspark.sql.types import DataType +from kamae.keras.core.backend import ALL_BACKENDS from kamae.spark.params import SingleInputSingleOutputParams from kamae.spark.transformers import BaseTransformer @@ -422,6 +423,9 @@ def test_base_transformer(layer_name, output_col, input_col, tf_layer): class TestTransformer(BaseTransformer, SingleInputSingleOutputParams): """Test transformer for testing abstract base class BaseTransformer""" + supported_backends = ALL_BACKENDS + jit_compatible = False + @property def compatible_dtypes(self) -> Optional[List[DataType]]: return None @@ -429,7 +433,7 @@ def compatible_dtypes(self) -> Optional[List[DataType]]: def _transform(self, dataset: DataFrame) -> DataFrame: return dataset - def get_tf_layer(self) -> tf.keras.layers.Layer: + def get_keras_layer(self) -> tf.keras.layers.Layer: return 
tf_layer return ( diff --git a/tests/kamae/spark/pipeline/test_pipeline.py b/tests/kamae/spark/pipeline/test_pipeline.py index 256d4855..03c7ddbd 100644 --- a/tests/kamae/spark/pipeline/test_pipeline.py +++ b/tests/kamae/spark/pipeline/test_pipeline.py @@ -592,7 +592,7 @@ def test_spark_pipeline_with_uid_same_as_input( transformed_df.count() @pytest.mark.parametrize( - "stages, input_tensors, tf_input_schema, output_names, expected_output", + "stages, input_tensors, input_schema, output_names, expected_output", [ ( "valid_stages_0", @@ -1004,7 +1004,7 @@ def test_keras_model( self, stages, input_tensors, - tf_input_schema, + input_schema, output_names, expected_output, example_dataframe, @@ -1016,7 +1016,7 @@ def test_keras_model( pipeline_model = pipeline.fit(example_dataframe) keras_model = pipeline_model.build_keras_model( - tf_input_schema=tf_input_schema, output_names=output_names + input_schema=input_schema, output_names=output_names ) actual = keras_model(input_tensors) diff --git a/tests/kamae/spark/test_jit_compatibility.py b/tests/kamae/spark/test_jit_compatibility.py new file mode 100644 index 00000000..038c5145 --- /dev/null +++ b/tests/kamae/spark/test_jit_compatibility.py @@ -0,0 +1,59 @@ +# Copyright [2024] Expedia, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for JIT compatibility attributes on Spark estimators and transformers.""" + +from pyspark.ml import Estimator, Transformer + +import kamae.spark.estimators as estimators_mod +import kamae.spark.transformers as transformers_mod + + +def test_all_spark_operations_define_jit_compatible_and_supported_backends(): + """Test that all Spark transformers and estimators define jit_compatible and supported_backends directly.""" + # Get all transformer classes + transformers = [ + obj + for name, obj in vars(transformers_mod).items() + if isinstance(obj, type) + and issubclass(obj, Transformer) + and obj is not Transformer + and name != "BaseTransformer" # Exclude base class + ] + + # Get all estimator classes + estimators = [ + obj + for name, obj in vars(estimators_mod).items() + if isinstance(obj, type) + and issubclass(obj, Estimator) + and obj is not Estimator + and name != "BaseEstimator" # Exclude base class + ] + + all_operations = transformers + estimators + + for op_cls in all_operations: + assert ( + "jit_compatible" in op_cls.__dict__ + ), f"{op_cls.__name__} must define 'jit_compatible' directly (not inherit it)" + assert isinstance( + op_cls.jit_compatible, bool + ), f"{op_cls.__name__}.jit_compatible must be bool, got {type(op_cls.jit_compatible)}" + assert ( + "supported_backends" in op_cls.__dict__ + ), f"{op_cls.__name__} must define 'supported_backends' directly (not inherit it)" + assert isinstance( + op_cls.supported_backends, frozenset + ), f"{op_cls.__name__}.supported_backends must be frozenset" diff --git a/tests/kamae/spark/transformers/test_absolute_value.py b/tests/kamae/spark/transformers/test_absolute_value.py index 6bac6298..ea05bc4c 100644 --- a/tests/kamae/spark/transformers/test_absolute_value.py +++ b/tests/kamae/spark/transformers/test_absolute_value.py @@ -226,7 +226,7 @@ def test_absolute_value_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in 
transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_array_concatenate.py b/tests/kamae/spark/transformers/test_array_concatenate.py index 1407b6c8..226d3802 100644 --- a/tests/kamae/spark/transformers/test_array_concatenate.py +++ b/tests/kamae/spark/transformers/test_array_concatenate.py @@ -264,7 +264,7 @@ def test_vector_assembler_spark_tf_parity( vec_decoder = np.vectorize(decoder) tensorflow_values = [ vec_decoder(v) if isinstance(v[0], bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_array_crop.py b/tests/kamae/spark/transformers/test_array_crop.py index 0a9b6363..24ab7246 100644 --- a/tests/kamae/spark/transformers/test_array_crop.py +++ b/tests/kamae/spark/transformers/test_array_crop.py @@ -483,7 +483,7 @@ def test_array_crop_spark_tf_parity( ) tensorflow_values = [ array_decoder(v) if isinstance(v[0], bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_array_reduce_max.py b/tests/kamae/spark/transformers/test_array_reduce_max.py index ea8430d4..6fe918aa 100644 --- a/tests/kamae/spark/transformers/test_array_reduce_max.py +++ b/tests/kamae/spark/transformers/test_array_reduce_max.py @@ -116,7 +116,7 @@ def test_spark_tf_parity( ) inputs = tf.constant(rows, dtype=tf.float32) - keras_values = transformer.get_tf_layer()(inputs).numpy().tolist() + keras_values = transformer.get_keras_layer()(inputs).numpy().tolist() np.testing.assert_almost_equal( spark_values, diff --git a/tests/kamae/spark/transformers/test_array_split.py b/tests/kamae/spark/transformers/test_array_split.py index 8f6a3da7..99dea5c1 
100644 --- a/tests/kamae/spark/transformers/test_array_split.py +++ b/tests/kamae/spark/transformers/test_array_split.py @@ -220,7 +220,7 @@ def test_vector_slicer_spark_tf_parity( vec_decoder = np.vectorize(decoder) tensorflow_values = [ vec_decoder(v.numpy().tolist()) if isinstance(v[0], bytes) else v - for v in transformer.get_tf_layer()(input_tensor) + for v in transformer.get_keras_layer()(input_tensor) ] # then diff --git a/tests/kamae/spark/transformers/test_array_subtract_minimum.py b/tests/kamae/spark/transformers/test_array_subtract_minimum.py index 42dc6702..390f5bad 100644 --- a/tests/kamae/spark/transformers/test_array_subtract_minimum.py +++ b/tests/kamae/spark/transformers/test_array_subtract_minimum.py @@ -271,7 +271,7 @@ def test_array_subtract_minimum_spark_tf_parity( ) tensorflow_values = [ array_decoder(v) if isinstance(v[0], bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_bearing_angle.py b/tests/kamae/spark/transformers/test_bearing_angle.py index a769ac79..3feff4df 100644 --- a/tests/kamae/spark/transformers/test_bearing_angle.py +++ b/tests/kamae/spark/transformers/test_bearing_angle.py @@ -146,7 +146,9 @@ def test_bearing_angle_transform_spark_tf_parity( .rdd.map(lambda r: r[0]) .collect() ) - tensorflow_values = transformer.get_tf_layer()(input_tensors).numpy().tolist() + tensorflow_values = ( + transformer.get_keras_layer()(input_tensors).numpy().tolist() + ) # then np.testing.assert_almost_equal( diff --git a/tests/kamae/spark/transformers/test_bin.py b/tests/kamae/spark/transformers/test_bin.py index c7156c7e..c94ca9bc 100644 --- a/tests/kamae/spark/transformers/test_bin.py +++ b/tests/kamae/spark/transformers/test_bin.py @@ -317,7 +317,7 @@ def test_bin_transform_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in 
transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_bloom_encode.py b/tests/kamae/spark/transformers/test_bloom_encode.py index 36bb57cd..f73edf27 100644 --- a/tests/kamae/spark/transformers/test_bloom_encode.py +++ b/tests/kamae/spark/transformers/test_bloom_encode.py @@ -281,7 +281,7 @@ def test_bloom_encoder_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_bucketize.py b/tests/kamae/spark/transformers/test_bucketize.py index 70e63713..bb5bf286 100644 --- a/tests/kamae/spark/transformers/test_bucketize.py +++ b/tests/kamae/spark/transformers/test_bucketize.py @@ -223,7 +223,7 @@ def test_bucketizer_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_conditional_standard_scale.py b/tests/kamae/spark/transformers/test_conditional_standard_scale.py index 90ca90af..a99bd87d 100644 --- a/tests/kamae/spark/transformers/test_conditional_standard_scale.py +++ b/tests/kamae/spark/transformers/test_conditional_standard_scale.py @@ -448,7 +448,7 @@ def test_cond_standard_scaler_spark_tf_parity( ) tensorflow_values = [ array_decoder(v) if isinstance(v[0], bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then if isinstance(spark_values[0][0], str): diff --git a/tests/kamae/spark/transformers/test_cosine_similarity.py 
b/tests/kamae/spark/transformers/test_cosine_similarity.py index f76d3660..f19f35bd 100644 --- a/tests/kamae/spark/transformers/test_cosine_similarity.py +++ b/tests/kamae/spark/transformers/test_cosine_similarity.py @@ -311,7 +311,7 @@ def test_cosine_similarity_transform_spark_tf_parity( .collect() ) tensorflow_values = ( - transformer.get_tf_layer()(input_tensors).numpy().flatten().tolist() + transformer.get_keras_layer()(input_tensors).numpy().flatten().tolist() ) # then diff --git a/tests/kamae/spark/transformers/test_current_date.py b/tests/kamae/spark/transformers/test_current_date.py index fa65c7f4..15e7f733 100644 --- a/tests/kamae/spark/transformers/test_current_date.py +++ b/tests/kamae/spark/transformers/test_current_date.py @@ -331,12 +331,12 @@ def test_current_date_transform_spark_tf_parity( ) with patch( - "kamae.tensorflow.layers.current_date.tf.timestamp", + "kamae.keras.tensorflow.layers.current_date.tf.timestamp", lambda: tf.constant(timestamp_seconds, dtype=tf.float64), ): tensorflow_values = [ v.decode("utf-8") - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] np.testing.assert_equal( spark_values, @@ -395,7 +395,7 @@ def test_current_date_transform_spark_tf_parity_no_patch( tensorflow_values = [ v.decode("utf-8") - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] np.testing.assert_equal( spark_values, diff --git a/tests/kamae/spark/transformers/test_current_date_time.py b/tests/kamae/spark/transformers/test_current_date_time.py index e69fd6c3..9c45ae42 100644 --- a/tests/kamae/spark/transformers/test_current_date_time.py +++ b/tests/kamae/spark/transformers/test_current_date_time.py @@ -370,12 +370,12 @@ def test_current_date_time_transform_spark_tf_parity( ) with patch( - "kamae.tensorflow.layers.current_date_time.tf.timestamp", + 
"kamae.keras.tensorflow.layers.current_date_time.tf.timestamp", lambda: tf.constant(timestamp_seconds, dtype=tf.float64), ): tensorflow_values = [ v.decode("utf-8") - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] np.testing.assert_equal( spark_values, @@ -434,7 +434,7 @@ def test_current_date_transform_spark_tf_parity_no_patch( tensorflow_values = [ v.decode("utf-8") - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # Only check correct to the minute, since some time may have passed between # the two calls diff --git a/tests/kamae/spark/transformers/test_current_unix_timestamp.py b/tests/kamae/spark/transformers/test_current_unix_timestamp.py index bb0ae582..f15dc7b9 100644 --- a/tests/kamae/spark/transformers/test_current_unix_timestamp.py +++ b/tests/kamae/spark/transformers/test_current_unix_timestamp.py @@ -361,11 +361,11 @@ def test_current_unix_timestamp_transform_spark_tf_parity( ) with patch( - "kamae.tensorflow.layers.current_unix_timestamp.tf.timestamp", + "kamae.keras.tensorflow.layers.current_unix_timestamp.tf.timestamp", lambda: tf.constant(timestamp_seconds, dtype=tf.float64), ): tensorflow_values = [ - v for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + v for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] np.testing.assert_equal( spark_values, @@ -426,7 +426,7 @@ def test_current_unix_timestamp_transform_spark_tf_parity_no_patch( ) tensorflow_values = [ - v for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + v for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # Set Spark and Tensorflow to numpy floats spark_values = np.array(spark_values).astype(np.float64) diff --git a/tests/kamae/spark/transformers/test_date_add.py b/tests/kamae/spark/transformers/test_date_add.py index 
26e49282..b4e1f8fd 100644 --- a/tests/kamae/spark/transformers/test_date_add.py +++ b/tests/kamae/spark/transformers/test_date_add.py @@ -374,7 +374,7 @@ def test_date_add_transform_single_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then @@ -455,7 +455,7 @@ def test_date_add_transform_multi_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_date_diff.py b/tests/kamae/spark/transformers/test_date_diff.py index e9bf4b7d..c27c416f 100644 --- a/tests/kamae/spark/transformers/test_date_diff.py +++ b/tests/kamae/spark/transformers/test_date_diff.py @@ -410,7 +410,7 @@ def test_date_diff_transform_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_date_parse.py b/tests/kamae/spark/transformers/test_date_parse.py index 571bcb89..b2c7f4f2 100644 --- a/tests/kamae/spark/transformers/test_date_parse.py +++ b/tests/kamae/spark/transformers/test_date_parse.py @@ -1660,7 +1660,7 @@ def test_date_parse_transform_spark_tf_parity( tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] np.testing.assert_equal( diff --git a/tests/kamae/spark/transformers/test_date_time_to_unix_timestamp.py b/tests/kamae/spark/transformers/test_date_time_to_unix_timestamp.py 
index a2d93aea..173261ee 100644 --- a/tests/kamae/spark/transformers/test_date_time_to_unix_timestamp.py +++ b/tests/kamae/spark/transformers/test_date_time_to_unix_timestamp.py @@ -322,7 +322,7 @@ def test_date_time_to_unix_timestamp_transform_spark_tf_parity( .collect() ) tensorflow_values = [ - v for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + v for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] np.testing.assert_equal( spark_values, diff --git a/tests/kamae/spark/transformers/test_divide.py b/tests/kamae/spark/transformers/test_divide.py index 6ce6531e..ec2fbeb0 100644 --- a/tests/kamae/spark/transformers/test_divide.py +++ b/tests/kamae/spark/transformers/test_divide.py @@ -219,7 +219,7 @@ def test_divide_transform_single_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then @@ -305,7 +305,9 @@ def test_divide_transform_multiple_input_spark_tf_parity( .rdd.map(lambda r: r[0]) .collect() ) - tensorflow_values = transformer.get_tf_layer()(input_tensors).numpy().tolist() + tensorflow_values = ( + transformer.get_keras_layer()(input_tensors).numpy().tolist() + ) # then if isinstance(spark_values[0], str): diff --git a/tests/kamae/spark/transformers/test_exp.py b/tests/kamae/spark/transformers/test_exp.py index 2e99ee51..0382837f 100644 --- a/tests/kamae/spark/transformers/test_exp.py +++ b/tests/kamae/spark/transformers/test_exp.py @@ -194,7 +194,7 @@ def test_exp_spark_tf_parity( .rdd.map(lambda r: r[0]) .collect() ) - tensorflow_values = transformer.get_tf_layer()(input_tensor).numpy().tolist() + tensorflow_values = transformer.get_keras_layer()(input_tensor).numpy().tolist() # then np.testing.assert_almost_equal( diff --git a/tests/kamae/spark/transformers/test_exponent.py b/tests/kamae/spark/transformers/test_exponent.py index 
e68d1d48..2497cedd 100644 --- a/tests/kamae/spark/transformers/test_exponent.py +++ b/tests/kamae/spark/transformers/test_exponent.py @@ -194,7 +194,7 @@ def test_exponent_transform_single_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then @@ -275,7 +275,9 @@ def test_exponent_transform_multiple_input_spark_tf_parity( .rdd.map(lambda r: r[0]) .collect() ) - tensorflow_values = transformer.get_tf_layer()(input_tensors).numpy().tolist() + tensorflow_values = ( + transformer.get_keras_layer()(input_tensors).numpy().tolist() + ) # then if isinstance(spark_values[0], str): diff --git a/tests/kamae/spark/transformers/test_hash_index.py b/tests/kamae/spark/transformers/test_hash_index.py index 5aaf54f9..1d32dcf7 100644 --- a/tests/kamae/spark/transformers/test_hash_index.py +++ b/tests/kamae/spark/transformers/test_hash_index.py @@ -274,7 +274,7 @@ def test_hash_indexer_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_haversine_distance.py b/tests/kamae/spark/transformers/test_haversine_distance.py index 137616e6..fe05e590 100644 --- a/tests/kamae/spark/transformers/test_haversine_distance.py +++ b/tests/kamae/spark/transformers/test_haversine_distance.py @@ -679,7 +679,9 @@ def test_haversine_distance_transform_spark_tf_parity( .rdd.map(lambda r: r[0]) .collect() ) - tensorflow_values = transformer.get_tf_layer()(input_tensors).numpy().tolist() + tensorflow_values = ( + transformer.get_keras_layer()(input_tensors).numpy().tolist() + ) # then np.testing.assert_almost_equal( diff --git a/tests/kamae/spark/transformers/test_identity.py 
b/tests/kamae/spark/transformers/test_identity.py index e2411458..29df46a7 100644 --- a/tests/kamae/spark/transformers/test_identity.py +++ b/tests/kamae/spark/transformers/test_identity.py @@ -159,7 +159,7 @@ def test_identity_transform_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_if_statement.py b/tests/kamae/spark/transformers/test_if_statement.py index 8ade848b..659878a7 100644 --- a/tests/kamae/spark/transformers/test_if_statement.py +++ b/tests/kamae/spark/transformers/test_if_statement.py @@ -455,7 +455,7 @@ def test_if_statement_transform_single_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then @@ -587,7 +587,7 @@ def test_if_statement_transform_multiple_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_impute.py b/tests/kamae/spark/transformers/test_impute.py index 7be66d1c..550482f3 100644 --- a/tests/kamae/spark/transformers/test_impute.py +++ b/tests/kamae/spark/transformers/test_impute.py @@ -207,7 +207,7 @@ def test_impute_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_lambda_function.py 
b/tests/kamae/spark/transformers/test_lambda_function.py index 754c529c..53142279 100644 --- a/tests/kamae/spark/transformers/test_lambda_function.py +++ b/tests/kamae/spark/transformers/test_lambda_function.py @@ -550,7 +550,7 @@ def test_lambda_function_transform_single_input_single_output_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then @@ -649,7 +649,9 @@ def test_lambda_function_transform_multiple_input_single_output_spark_tf_parity( .rdd.map(lambda r: r[0]) .collect() ) - tensorflow_values = transformer.get_tf_layer()(input_tensors).numpy().tolist() + tensorflow_values = ( + transformer.get_keras_layer()(input_tensors).numpy().tolist() + ) # then if isinstance(spark_values[0], str): @@ -709,7 +711,7 @@ def test_lambda_function_transform_single_input_multiple_output_spark_tf_parity( for c in output_col_names ] tensorflow_values = [ - v.numpy().tolist() for v in transformer.get_tf_layer()(input_tensor) + v.numpy().tolist() for v in transformer.get_keras_layer()(input_tensor) ] # then @@ -776,7 +778,7 @@ def test_lambda_function_transform_multiple_input_multiple_output_spark_tf_parit for c in output_col_names ] tensorflow_values = [ - v.numpy().tolist() for v in transformer.get_tf_layer()(input_tensors) + v.numpy().tolist() for v in transformer.get_keras_layer()(input_tensors) ] # then diff --git a/tests/kamae/spark/transformers/test_list_max.py b/tests/kamae/spark/transformers/test_list_max.py index 080acc0c..9f927a72 100644 --- a/tests/kamae/spark/transformers/test_list_max.py +++ b/tests/kamae/spark/transformers/test_list_max.py @@ -715,7 +715,7 @@ def test_list_max_transform_spark_tf_parity( tensorflow_values = np.reshape( [ np.squeeze(v) - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in 
transformer.get_keras_layer()(input_tensors).numpy().tolist() ], -1, ) diff --git a/tests/kamae/spark/transformers/test_list_mean.py b/tests/kamae/spark/transformers/test_list_mean.py index 95cf3bc4..13df8f16 100644 --- a/tests/kamae/spark/transformers/test_list_mean.py +++ b/tests/kamae/spark/transformers/test_list_mean.py @@ -714,7 +714,7 @@ def test_list_mean_transform_spark_tf_parity( tensorflow_values = np.reshape( [ np.squeeze(v) - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ], -1, ) diff --git a/tests/kamae/spark/transformers/test_list_median.py b/tests/kamae/spark/transformers/test_list_median.py index f5fe0c77..abe15a2d 100644 --- a/tests/kamae/spark/transformers/test_list_median.py +++ b/tests/kamae/spark/transformers/test_list_median.py @@ -557,7 +557,7 @@ def test_list_median_transform_spark_tf_parity( tensorflow_values = np.reshape( [ np.squeeze(v) - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ], -1, ) diff --git a/tests/kamae/spark/transformers/test_list_min.py b/tests/kamae/spark/transformers/test_list_min.py index 7b65fa22..9213a37f 100644 --- a/tests/kamae/spark/transformers/test_list_min.py +++ b/tests/kamae/spark/transformers/test_list_min.py @@ -715,7 +715,7 @@ def test_list_min_transform_spark_tf_parity( tensorflow_values = np.reshape( [ np.squeeze(v) - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ], -1, ) diff --git a/tests/kamae/spark/transformers/test_list_rank.py b/tests/kamae/spark/transformers/test_list_rank.py index 958411de..00fbe637 100644 --- a/tests/kamae/spark/transformers/test_list_rank.py +++ b/tests/kamae/spark/transformers/test_list_rank.py @@ -227,7 +227,7 @@ def test_list_rank_transform_spark_tf_parity( tensorflow_values = np.reshape( [ 
np.squeeze(v) - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ], -1, ) diff --git a/tests/kamae/spark/transformers/test_list_std_dev.py b/tests/kamae/spark/transformers/test_list_std_dev.py index 62c13394..4b9475f8 100644 --- a/tests/kamae/spark/transformers/test_list_std_dev.py +++ b/tests/kamae/spark/transformers/test_list_std_dev.py @@ -557,7 +557,7 @@ def test_list_average_transform_spark_tf_parity( tensorflow_values = np.reshape( [ np.squeeze(v) - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ], -1, ) diff --git a/tests/kamae/spark/transformers/test_log.py b/tests/kamae/spark/transformers/test_log.py index b8b78dbc..2fa007a6 100644 --- a/tests/kamae/spark/transformers/test_log.py +++ b/tests/kamae/spark/transformers/test_log.py @@ -204,7 +204,7 @@ def test_log_transform_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_logical_and.py b/tests/kamae/spark/transformers/test_logical_and.py index a7ae8865..24a69007 100644 --- a/tests/kamae/spark/transformers/test_logical_and.py +++ b/tests/kamae/spark/transformers/test_logical_and.py @@ -226,7 +226,7 @@ def test_logical_and_transform_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_logical_not.py b/tests/kamae/spark/transformers/test_logical_not.py index 68bab15b..887862f0 100644 --- a/tests/kamae/spark/transformers/test_logical_not.py +++ 
b/tests/kamae/spark/transformers/test_logical_not.py @@ -192,7 +192,7 @@ def test_logical_not_transform_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_logical_or.py b/tests/kamae/spark/transformers/test_logical_or.py index 63669b5c..7626cede 100644 --- a/tests/kamae/spark/transformers/test_logical_or.py +++ b/tests/kamae/spark/transformers/test_logical_or.py @@ -225,7 +225,7 @@ def test_logical_or_transform_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_max.py b/tests/kamae/spark/transformers/test_max.py index a9c04d08..0edf1457 100644 --- a/tests/kamae/spark/transformers/test_max.py +++ b/tests/kamae/spark/transformers/test_max.py @@ -285,7 +285,7 @@ def test_max_transform_single_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then @@ -373,7 +373,7 @@ def test_max_transform_multiple_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_mean.py b/tests/kamae/spark/transformers/test_mean.py index f03557b8..6ae94268 100644 --- a/tests/kamae/spark/transformers/test_mean.py +++ b/tests/kamae/spark/transformers/test_mean.py @@ -293,7 +293,7 @@ def 
test_mean_transform_single_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then @@ -381,7 +381,7 @@ def test_mean_transform_multiple_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_min.py b/tests/kamae/spark/transformers/test_min.py index 7b1fb60a..94b42ace 100644 --- a/tests/kamae/spark/transformers/test_min.py +++ b/tests/kamae/spark/transformers/test_min.py @@ -285,7 +285,7 @@ def test_min_transform_single_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then @@ -373,7 +373,7 @@ def test_min_transform_multiple_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_min_hash_index.py b/tests/kamae/spark/transformers/test_min_hash_index.py index 2098845e..0cf990b1 100644 --- a/tests/kamae/spark/transformers/test_min_hash_index.py +++ b/tests/kamae/spark/transformers/test_min_hash_index.py @@ -517,7 +517,7 @@ def test_min_hash_spark_tf_parity( tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git 
a/tests/kamae/spark/transformers/test_min_max_scale.py b/tests/kamae/spark/transformers/test_min_max_scale.py index 12dc8965..a12dfd81 100644 --- a/tests/kamae/spark/transformers/test_min_max_scale.py +++ b/tests/kamae/spark/transformers/test_min_max_scale.py @@ -441,7 +441,7 @@ def test_min_max_scaler_spark_tf_parity( .rdd.map(lambda r: r[0]) .collect() ) - tensorflow_values = transformer.get_tf_layer()(input_tensor).numpy().tolist() + tensorflow_values = transformer.get_keras_layer()(input_tensor).numpy().tolist() # then np.testing.assert_almost_equal( diff --git a/tests/kamae/spark/transformers/test_modulo.py b/tests/kamae/spark/transformers/test_modulo.py index 22736648..a6ec3ba6 100644 --- a/tests/kamae/spark/transformers/test_modulo.py +++ b/tests/kamae/spark/transformers/test_modulo.py @@ -244,7 +244,7 @@ def test_modulo_transform_single_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then if isinstance(spark_values[0], str): @@ -318,7 +318,7 @@ def test_modulo_transform_multiple_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_multiply.py b/tests/kamae/spark/transformers/test_multiply.py index 613bf082..868832d1 100644 --- a/tests/kamae/spark/transformers/test_multiply.py +++ b/tests/kamae/spark/transformers/test_multiply.py @@ -296,7 +296,7 @@ def test_multiply_transform_single_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then @@ 
-384,7 +384,7 @@ def test_multiply_transform_multiple_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_numerical_if_statement.py b/tests/kamae/spark/transformers/test_numerical_if_statement.py index d9a79ee3..a4f5226b 100644 --- a/tests/kamae/spark/transformers/test_numerical_if_statement.py +++ b/tests/kamae/spark/transformers/test_numerical_if_statement.py @@ -278,7 +278,7 @@ def test_numerical_if_statement_transform_single_input_spark_tf_parity( .rdd.map(lambda r: r[0]) .collect() ) - tensorflow_values = transformer.get_tf_layer()(input_tensor).numpy().tolist() + tensorflow_values = transformer.get_keras_layer()(input_tensor).numpy().tolist() # then np.testing.assert_almost_equal( @@ -367,7 +367,9 @@ def test_numerical_if_statement_transform_multiple_input_spark_tf_parity( .rdd.map(lambda r: r[0]) .collect() ) - tensorflow_values = transformer.get_tf_layer()(input_tensors).numpy().tolist() + tensorflow_values = ( + transformer.get_keras_layer()(input_tensors).numpy().tolist() + ) # then np.testing.assert_almost_equal( diff --git a/tests/kamae/spark/transformers/test_one_hot_encode.py b/tests/kamae/spark/transformers/test_one_hot_encode.py index e32e03fe..8b4b0af4 100644 --- a/tests/kamae/spark/transformers/test_one_hot_encode.py +++ b/tests/kamae/spark/transformers/test_one_hot_encode.py @@ -354,7 +354,7 @@ def test_one_hot_encoder_spark_tf_parity( vec_decoder = np.vectorize(decoder) tensorflow_values = [ vec_decoder(v) if isinstance(v[0], bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_ordinal_array_encode.py 
b/tests/kamae/spark/transformers/test_ordinal_array_encode.py index d5c02f3b..be2d4652 100644 --- a/tests/kamae/spark/transformers/test_ordinal_array_encode.py +++ b/tests/kamae/spark/transformers/test_ordinal_array_encode.py @@ -206,7 +206,7 @@ def test_ordinal_encoding_spark_tf_parity( .rdd.map(lambda r: r[0]) .collect() ) - tensorflow_values = transformer.get_tf_layer()(input_tensor).numpy().tolist() + tensorflow_values = transformer.get_keras_layer()(input_tensor).numpy().tolist() # then np.testing.assert_equal( diff --git a/tests/kamae/spark/transformers/test_pairwise_cosine_similarity.py b/tests/kamae/spark/transformers/test_pairwise_cosine_similarity.py index a7374518..7d424834 100644 --- a/tests/kamae/spark/transformers/test_pairwise_cosine_similarity.py +++ b/tests/kamae/spark/transformers/test_pairwise_cosine_similarity.py @@ -157,7 +157,7 @@ def test_spark_tf_parity( tf_queries = tf.constant(queries, dtype=tf.float32) tf_candidates = tf.constant(flat_candidates, dtype=tf.float32) keras_values = ( - transformer.get_tf_layer()([tf_queries, tf_candidates]).numpy().tolist() + transformer.get_keras_layer()([tf_queries, tf_candidates]).numpy().tolist() ) np.testing.assert_almost_equal( diff --git a/tests/kamae/spark/transformers/test_round.py b/tests/kamae/spark/transformers/test_round.py index 7584d89e..dfe29384 100644 --- a/tests/kamae/spark/transformers/test_round.py +++ b/tests/kamae/spark/transformers/test_round.py @@ -220,7 +220,7 @@ def test_round_transform_spark_tf_parity( .rdd.map(lambda r: r[0]) .collect() ) - tensorflow_values = transformer.get_tf_layer()(input_tensor).numpy().tolist() + tensorflow_values = transformer.get_keras_layer()(input_tensor).numpy().tolist() # then np.testing.assert_almost_equal( diff --git a/tests/kamae/spark/transformers/test_round_to_decimal.py b/tests/kamae/spark/transformers/test_round_to_decimal.py index 409794a7..28df0d12 100644 --- a/tests/kamae/spark/transformers/test_round_to_decimal.py +++ 
b/tests/kamae/spark/transformers/test_round_to_decimal.py @@ -224,7 +224,7 @@ def test_round_to_decimal_transform_spark_tf_parity( .rdd.map(lambda r: r[0]) .collect() ) - tensorflow_values = transformer.get_tf_layer()(input_tensor).numpy().tolist() + tensorflow_values = transformer.get_keras_layer()(input_tensor).numpy().tolist() # then np.testing.assert_almost_equal( diff --git a/tests/kamae/spark/transformers/test_shared_one_hot_encode.py b/tests/kamae/spark/transformers/test_shared_one_hot_encode.py index e2e57a32..84495a0d 100644 --- a/tests/kamae/spark/transformers/test_shared_one_hot_encode.py +++ b/tests/kamae/spark/transformers/test_shared_one_hot_encode.py @@ -283,7 +283,7 @@ def test_one_hot_encoder_spark_tf_parity( vec_decoder = np.vectorize(decoder) tensorflow_values = [ vec_decoder(v) if isinstance(v[0], bytes) else v - for v in transformer.get_tf_layer()[0](input_tensors[0]).numpy().tolist() + for v in transformer.get_keras_layer()[0](input_tensors[0]).numpy().tolist() ] # then np.testing.assert_equal( diff --git a/tests/kamae/spark/transformers/test_shared_string_index.py b/tests/kamae/spark/transformers/test_shared_string_index.py index 3771b74b..c1161d01 100644 --- a/tests/kamae/spark/transformers/test_shared_string_index.py +++ b/tests/kamae/spark/transformers/test_shared_string_index.py @@ -283,7 +283,7 @@ def test_shared_string_indexer_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()[0](input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()[0](input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_standard_scale.py b/tests/kamae/spark/transformers/test_standard_scale.py index 73ed603d..076b8101 100644 --- a/tests/kamae/spark/transformers/test_standard_scale.py +++ b/tests/kamae/spark/transformers/test_standard_scale.py @@ -439,7 +439,7 @@ def test_standard_scaler_spark_tf_parity( .rdd.map(lambda r: r[0]) .collect() ) 
- tensorflow_values = transformer.get_tf_layer()(input_tensor).numpy().tolist() + tensorflow_values = transformer.get_keras_layer()(input_tensor).numpy().tolist() # then np.testing.assert_almost_equal( diff --git a/tests/kamae/spark/transformers/test_string_affix.py b/tests/kamae/spark/transformers/test_string_affix.py index 65747d5b..22ac940f 100644 --- a/tests/kamae/spark/transformers/test_string_affix.py +++ b/tests/kamae/spark/transformers/test_string_affix.py @@ -243,7 +243,7 @@ def test_string_affix_transform_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then np.testing.assert_equal( @@ -285,4 +285,4 @@ def test_fail_string_affix_transform( with pytest.raises(expected_error): transformer.transform(spark_df) with pytest.raises(expected_error): - transformer.get_tf_layer()(input_tensor) + transformer.get_keras_layer()(input_tensor) diff --git a/tests/kamae/spark/transformers/test_string_array_constant.py b/tests/kamae/spark/transformers/test_string_array_constant.py index b726230e..b07f8204 100644 --- a/tests/kamae/spark/transformers/test_string_array_constant.py +++ b/tests/kamae/spark/transformers/test_string_array_constant.py @@ -237,7 +237,7 @@ def test_string_array_constant_transform_spark_tf_parity( # (this drops first dimension) # and put it in a list to bring back the dimension spark_values_reshape = [spark_values[0]] - tensorflow_values_np = transformer.get_tf_layer()(input_tensor).numpy() + tensorflow_values_np = transformer.get_keras_layer()(input_tensor).numpy() tensorflow_values = np.vectorize( lambda x: x.decode("utf-8") if isinstance(x, bytes) else x )(tensorflow_values_np).tolist() diff --git a/tests/kamae/spark/transformers/test_string_case.py b/tests/kamae/spark/transformers/test_string_case.py index dfa57baf..36231c1b 100644 --- 
a/tests/kamae/spark/transformers/test_string_case.py +++ b/tests/kamae/spark/transformers/test_string_case.py @@ -274,7 +274,7 @@ def test_string_case_transform_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_string_concatenate.py b/tests/kamae/spark/transformers/test_string_concatenate.py index b16049a2..51a4fab6 100644 --- a/tests/kamae/spark/transformers/test_string_concatenate.py +++ b/tests/kamae/spark/transformers/test_string_concatenate.py @@ -242,7 +242,7 @@ def test_string_concatenate_transform_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then np.testing.assert_equal( diff --git a/tests/kamae/spark/transformers/test_string_contains.py b/tests/kamae/spark/transformers/test_string_contains.py index 51dac2a7..6c8b9467 100644 --- a/tests/kamae/spark/transformers/test_string_contains.py +++ b/tests/kamae/spark/transformers/test_string_contains.py @@ -294,7 +294,7 @@ def test_string_contains_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_string_contains_list.py b/tests/kamae/spark/transformers/test_string_contains_list.py index b66c8491..1ec6706e 100644 --- a/tests/kamae/spark/transformers/test_string_contains_list.py +++ b/tests/kamae/spark/transformers/test_string_contains_list.py @@ -216,7 +216,7 @@ def test_string_contains_list_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if 
isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_string_equals_if_statement.py b/tests/kamae/spark/transformers/test_string_equals_if_statement.py index 714bc884..6b27c268 100644 --- a/tests/kamae/spark/transformers/test_string_equals_if_statement.py +++ b/tests/kamae/spark/transformers/test_string_equals_if_statement.py @@ -279,7 +279,7 @@ def test_string_if_statement_transform_single_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then @@ -358,7 +358,7 @@ def test_string_if_statement_transform_multiple_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_string_index.py b/tests/kamae/spark/transformers/test_string_index.py index 4dd8262a..a3b474b1 100644 --- a/tests/kamae/spark/transformers/test_string_index.py +++ b/tests/kamae/spark/transformers/test_string_index.py @@ -287,7 +287,7 @@ def test_string_indexer_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_string_isin_list.py b/tests/kamae/spark/transformers/test_string_isin_list.py index 20e8d266..6d833825 100644 --- a/tests/kamae/spark/transformers/test_string_isin_list.py +++ b/tests/kamae/spark/transformers/test_string_isin_list.py @@ -217,7 +217,7 @@ def 
test_string_isin_list_spark_tf_parity( .rdd.map(lambda x: x[0]) .collect() ) - tensorflow_values = transformer.get_tf_layer()(input_tensor).numpy() + tensorflow_values = transformer.get_keras_layer()(input_tensor).numpy() # then np.testing.assert_equal( diff --git a/tests/kamae/spark/transformers/test_string_list_to_string.py b/tests/kamae/spark/transformers/test_string_list_to_string.py index 8bd7dea7..1ae77abf 100644 --- a/tests/kamae/spark/transformers/test_string_list_to_string.py +++ b/tests/kamae/spark/transformers/test_string_list_to_string.py @@ -243,7 +243,7 @@ def test_string_list_to_string_spark_tf_parity( ) tensorflow_values = vec_decoder( - transformer.get_tf_layer()(input_tensor).numpy().flatten() + transformer.get_keras_layer()(input_tensor).numpy().flatten() ).tolist() # then diff --git a/tests/kamae/spark/transformers/test_string_map.py b/tests/kamae/spark/transformers/test_string_map.py index 6f4f8f6b..13e5e847 100644 --- a/tests/kamae/spark/transformers/test_string_map.py +++ b/tests/kamae/spark/transformers/test_string_map.py @@ -151,7 +151,7 @@ def test_string_map_spark_tf_parity_no_constants( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_string_replace.py b/tests/kamae/spark/transformers/test_string_replace.py index 975a78db..896de9ee 100644 --- a/tests/kamae/spark/transformers/test_string_replace.py +++ b/tests/kamae/spark/transformers/test_string_replace.py @@ -319,7 +319,7 @@ def test_string_replace_spark_tf_parity_no_constants( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git 
a/tests/kamae/spark/transformers/test_string_to_string_list.py b/tests/kamae/spark/transformers/test_string_to_string_list.py index 2299032d..8918f716 100644 --- a/tests/kamae/spark/transformers/test_string_to_string_list.py +++ b/tests/kamae/spark/transformers/test_string_to_string_list.py @@ -329,7 +329,7 @@ def test_string_to_string_list_spark_tf_parity( vec_decoder = np.vectorize(decoder) tensorflow_values = [ vec_decoder(v) if isinstance(v[0], bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_sub_string_delim_at_index.py b/tests/kamae/spark/transformers/test_sub_string_delim_at_index.py index 4da92288..d71f7bd0 100644 --- a/tests/kamae/spark/transformers/test_sub_string_delim_at_index.py +++ b/tests/kamae/spark/transformers/test_sub_string_delim_at_index.py @@ -427,7 +427,7 @@ def test_sub_string_delim_transform_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_subtract.py b/tests/kamae/spark/transformers/test_subtract.py index 247d3630..5e2c5e6a 100644 --- a/tests/kamae/spark/transformers/test_subtract.py +++ b/tests/kamae/spark/transformers/test_subtract.py @@ -291,7 +291,7 @@ def test_subtract_transform_single_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then @@ -381,7 +381,7 @@ def test_subtract_transform_multiple_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in 
transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_sum.py b/tests/kamae/spark/transformers/test_sum.py index ed9d88a5..0f9e1c91 100644 --- a/tests/kamae/spark/transformers/test_sum.py +++ b/tests/kamae/spark/transformers/test_sum.py @@ -285,7 +285,7 @@ def test_sum_transform_single_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] # then @@ -373,7 +373,7 @@ def test_sum_transform_multiple_input_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") if isinstance(v, bytes) else v - for v in transformer.get_tf_layer()(input_tensors).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensors).numpy().tolist() ] # then diff --git a/tests/kamae/spark/transformers/test_unix_timestamp_to_date_time.py b/tests/kamae/spark/transformers/test_unix_timestamp_to_date_time.py index e1637f2d..2de668b7 100644 --- a/tests/kamae/spark/transformers/test_unix_timestamp_to_date_time.py +++ b/tests/kamae/spark/transformers/test_unix_timestamp_to_date_time.py @@ -330,7 +330,7 @@ def test_unix_timestamp_to_date_time_transform_spark_tf_parity( ) tensorflow_values = [ v.decode("utf-8") - for v in transformer.get_tf_layer()(input_tensor).numpy().tolist() + for v in transformer.get_keras_layer()(input_tensor).numpy().tolist() ] np.testing.assert_equal( spark_values, diff --git a/tests/kamae/tensorflow/layers/test_array_reduce_max.py b/tests/kamae/tensorflow/layers/test_array_reduce_max.py deleted file mode 100644 index 92243923..00000000 --- a/tests/kamae/tensorflow/layers/test_array_reduce_max.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright [2024] Expedia, Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import numpy as np -import tensorflow as tf - -from kamae.tensorflow.layers import ArrayReduceMaxLayer - - -class TestArrayReduceMaxLayer: - def test_returns_maximum_of_each_row(self): - layer = ArrayReduceMaxLayer() - inputs = tf.constant([[3.0, 1.0, 2.0], [0.0, 5.0, 4.0]]) - - result = layer(inputs).numpy() - - np.testing.assert_array_almost_equal(result, [3.0, 5.0]) - - def test_negative_values(self): - layer = ArrayReduceMaxLayer() - inputs = tf.constant([[-3.0, -1.0, -2.0], [-10.0, -5.0, -7.0]]) - - result = layer(inputs).numpy() - - np.testing.assert_array_almost_equal(result, [-1.0, -5.0]) - - def test_single_element_array(self): - layer = ArrayReduceMaxLayer() - inputs = tf.constant([[42.0]]) - - result = layer(inputs).numpy() - - np.testing.assert_array_almost_equal(result, [42.0]) - - def test_default_value_returned_for_nan_input(self): - layer = ArrayReduceMaxLayer(default_value=-99.0) - inputs = tf.constant([[float("nan"), float("nan")]]) - - result = layer(inputs).numpy() - - np.testing.assert_array_almost_equal(result, [-99.0]) diff --git a/tests/kamae/tensorflow/layers/test_pairwise_cosine_similarity.py b/tests/kamae/tensorflow/layers/test_pairwise_cosine_similarity.py deleted file mode 100644 index d57043fe..00000000 --- a/tests/kamae/tensorflow/layers/test_pairwise_cosine_similarity.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright [2024] Expedia, Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import tensorflow as tf - -from kamae.tensorflow.layers import PairwiseCosineSimilarityLayer - - -class TestPairwiseCosineSimilarityLayer: - def test_identical_vectors_give_similarity_one(self): - # query [1, 0] vs candidate [1, 0] → cosine = 1.0 - layer = PairwiseCosineSimilarityLayer(embedding_dim=2) - query = tf.constant([[1.0, 0.0]]) - candidates = tf.constant([[1.0, 0.0]]) # 1 candidate, dim=2 - - result = layer([query, candidates]).numpy() - - np.testing.assert_array_almost_equal(result, [[1.0]]) - - def test_opposite_vectors_give_similarity_minus_one(self): - # query [1, 0] vs candidate [-1, 0] → cosine = -1.0 - layer = PairwiseCosineSimilarityLayer(embedding_dim=2) - query = tf.constant([[1.0, 0.0]]) - candidates = tf.constant([[-1.0, 0.0]]) - - result = layer([query, candidates]).numpy() - - np.testing.assert_array_almost_equal(result, [[-1.0]]) - - def test_orthogonal_vectors_give_similarity_zero(self): - # query [1, 0] vs candidate [0, 1] → cosine = 0.0 - layer = PairwiseCosineSimilarityLayer(embedding_dim=2) - query = tf.constant([[1.0, 0.0]]) - candidates = tf.constant([[0.0, 1.0]]) - - result = layer([query, candidates]).numpy() - - np.testing.assert_array_almost_equal(result, [[0.0]]) - - def test_multiple_candidates_flat_packed(self): - # query [1, 0], candidates: [1, 0] and [0, 1] packed as [1, 0, 0, 1] - # expected: [1.0, 0.0] - layer = 
PairwiseCosineSimilarityLayer(embedding_dim=2) - query = tf.constant([[1.0, 0.0]]) - candidates = tf.constant([[1.0, 0.0, 0.0, 1.0]]) - - result = layer([query, candidates]).numpy() - - np.testing.assert_array_almost_equal(result, [[1.0, 0.0]]) - - def test_zero_query_vector_gives_zero_similarity(self): - layer = PairwiseCosineSimilarityLayer(embedding_dim=2) - query = tf.constant([[0.0, 0.0]]) - candidates = tf.constant([[1.0, 0.0]]) - - result = layer([query, candidates]).numpy() - - np.testing.assert_array_almost_equal(result, [[0.0]]) - - def test_batch_of_queries(self): - # Two rows, each with query vs one candidate of dim=2 - layer = PairwiseCosineSimilarityLayer(embedding_dim=2) - query = tf.constant([[1.0, 0.0], [0.0, 1.0]]) - candidates = tf.constant([[1.0, 0.0], [0.0, 1.0]]) # same vectors - - result = layer([query, candidates]).numpy() - - np.testing.assert_array_almost_equal(result, [[1.0], [1.0]]) diff --git a/uv.lock b/uv.lock index 4a175d56..68ac512a 100644 --- a/uv.lock +++ b/uv.lock @@ -1,11 +1,9 @@ version = 1 -requires-python = ">=3.8.1, <3.13" +requires-python = ">=3.10, <3.13" resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] [[package]] @@ -21,9 +19,6 @@ wheels = [ name = "annotated-types" version = "0.7.0" source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.9'" }, -] sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081 } wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", 
hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, @@ -67,9 +62,6 @@ wheels = [ name = "babel" version = "2.17.0" source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pytz", marker = "python_full_version < '3.9'" }, -] sdist = { url = "https://files.pythonhosted.org/packages/7d/6b/d52e42361e1aa00709585ecc30b3f9684b3ab62530771402248b1b1d6240/babel-2.17.0.tar.gz", hash = "sha256:0c54cffb19f690cdcc52a3b50bcbf71e07a808d1c80d549f2459b9d2cf0afb9d", size = 9951852 } wheels = [ { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537 }, @@ -98,26 +90,9 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8b/3c/c9a03a4d5dd8c18c4af211e694bcc73dd305a2b85788eb311d3dbb14cfe9/black-23.10.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:e293e4c2f4a992b980032bbd62df07c1bcff82d6964d6c9496f2cd726e246ace", size = 1484835 }, { url = "https://files.pythonhosted.org/packages/80/4a/dd74ca838e8a536f3ac061cec9ef1d0c73e3ad2f3584be2127d53cd82f0f/black-23.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d56124b7a61d092cb52cce34182a5280e160e6aff3137172a68c2c2c4b76bcb", size = 1629860 }, { url = "https://files.pythonhosted.org/packages/bf/f6/1b039c5ea8fc18a3e710cc1e217fa65369e3fe9173eac9ec5080f89f9f38/black-23.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:3f157a8945a7b2d424da3335f7ace89c14a3b0625e6593d21139c2d8214d55ce", size = 1290854 }, - { url = "https://files.pythonhosted.org/packages/a2/5e/acf7eff1ce3cc035f7a140d7a1a2fab1f04175573ec1586331f8a64f7d30/black-23.10.1-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:cfcce6f0a384d0da692119f2d72d79ed07c7159879d0bb1bb32d2e443382bf3a", size = 1342161 }, - { url = 
"https://files.pythonhosted.org/packages/c6/43/e775dd9c571f6eac939fa25c885745cf7262cdd2c92d9a506302dad88f81/black-23.10.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:33d40f5b06be80c1bbce17b173cda17994fbad096ce60eb22054da021bf933d1", size = 1491509 }, - { url = "https://files.pythonhosted.org/packages/b0/66/1a67f40228061d9046fa7bf806b2748d17427f14e7bdd3ee98a11fb6e0c4/black-23.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:840015166dbdfbc47992871325799fd2dc0dcf9395e401ada6d88fe11498abad", size = 1632456 }, - { url = "https://files.pythonhosted.org/packages/5e/5d/a30a63bb5397648ec82dc74e25fd377185044040f88089c340a69dac4a85/black-23.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:037e9b4664cafda5f025a1728c50a9e9aedb99a759c89f760bd83730e76ba884", size = 1289456 }, - { url = "https://files.pythonhosted.org/packages/87/0f/0c665af27f6ce286145d747e1e37d9d4ed807af266401f4aa4d7d428fd9c/black-23.10.1-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:7cb5936e686e782fddb1c73f8aa6f459e1ad38a6a7b0e54b403f1f05a1507ee9", size = 1354727 }, - { url = "https://files.pythonhosted.org/packages/57/61/a91a66459dc4885a3b92c1bcf36e0556021f849e8c21732199a72ce9603c/black-23.10.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:7670242e90dc129c539e9ca17665e39a146a761e681805c54fbd86015c7c84f7", size = 1504025 }, - { url = "https://files.pythonhosted.org/packages/3c/32/56126f1991a4dfe31ce82adbf57b100b8bb11d4a8bf3b7ac716cfd52bf4d/black-23.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ed45ac9a613fb52dad3b61c8dea2ec9510bf3108d4db88422bacc7d1ba1243d", size = 1644413 }, - { url = "https://files.pythonhosted.org/packages/1b/e5/33e5ed299302607adbd9c23d651acd788ffb9095fe6cc0f169e9d71f41d4/black-23.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:6d23d7822140e3fef190734216cefb262521789367fbdc0b3f22af6744058982", size = 1280223 }, { url = 
"https://files.pythonhosted.org/packages/72/6e/3c49b5779a087979cb1916b1409e2bcee2d58bab1f880a4d2720251a3bfa/black-23.10.1-py3-none-any.whl", hash = "sha256:d431e6739f727bb2e0495df64a6c7a5310758e87505f5f8cde9ff6c0f2d7e4fe", size = 184603 }, ] -[[package]] -name = "cachetools" -version = "5.5.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080 }, -] - [[package]] name = "certifi" version = "2025.1.31" @@ -181,32 +156,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/13/0e/9c8d4cb99c98c1007cc11eda969ebfe837bbbd0acdb4736d228ccaabcd22/charset_normalizer-3.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e358e64305fe12299a08e08978f51fc21fac060dcfcddd95453eabe5b93ed0e1", size = 146192 }, { url = "https://files.pythonhosted.org/packages/b2/21/2b6b5b860781a0b49427309cb8670785aa543fb2178de875b87b9cc97746/charset_normalizer-3.4.1-cp312-cp312-win32.whl", hash = "sha256:9b23ca7ef998bc739bf6ffc077c2116917eabcc901f88da1b9856b210ef63f35", size = 95550 }, { url = "https://files.pythonhosted.org/packages/21/5b/1b390b03b1d16c7e382b561c5329f83cc06623916aab983e8ab9239c7d5c/charset_normalizer-3.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:6ff8a4a60c227ad87030d76e99cd1698345d4491638dfa6673027c48b3cd395f", size = 102785 }, - { url = "https://files.pythonhosted.org/packages/10/bd/6517ea94f2672e801011d50b5d06be2a0deaf566aea27bcdcd47e5195357/charset_normalizer-3.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = 
"sha256:ecddf25bee22fe4fe3737a399d0d177d72bc22be6913acfab364b40bce1ba83c", size = 195653 }, - { url = "https://files.pythonhosted.org/packages/e5/0d/815a2ba3f283b4eeaa5ece57acade365c5b4135f65a807a083c818716582/charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c60ca7339acd497a55b0ea5d506b2a2612afb2826560416f6894e8b5770d4a9", size = 140701 }, - { url = "https://files.pythonhosted.org/packages/aa/17/c94be7ee0d142687e047fe1de72060f6d6837f40eedc26e87e6e124a3fc6/charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b7b2d86dd06bfc2ade3312a83a5c364c7ec2e3498f8734282c6c3d4b07b346b8", size = 150495 }, - { url = "https://files.pythonhosted.org/packages/f7/33/557ac796c47165fc141e4fb71d7b0310f67e05cb420756f3a82e0a0068e0/charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd78cfcda14a1ef52584dbb008f7ac81c1328c0f58184bf9a84c49c605002da6", size = 142946 }, - { url = "https://files.pythonhosted.org/packages/1e/0d/38ef4ae41e9248d63fc4998d933cae22473b1b2ac4122cf908d0f5eb32aa/charset_normalizer-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e27f48bcd0957c6d4cb9d6fa6b61d192d0b13d5ef563e5f2ae35feafc0d179c", size = 144737 }, - { url = "https://files.pythonhosted.org/packages/43/01/754cdb29dd0560f58290aaaa284d43eea343ad0512e6ad3b8b5c11f08592/charset_normalizer-3.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01ad647cdd609225c5350561d084b42ddf732f4eeefe6e678765636791e78b9a", size = 147471 }, - { url = "https://files.pythonhosted.org/packages/ba/cd/861883ba5160c7a9bd242c30b2c71074cda2aefcc0addc91118e0d4e0765/charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:619a609aa74ae43d90ed2e89bdd784765de0a25ca761b93e196d938b8fd1dbbd", size = 140801 }, - { url = 
"https://files.pythonhosted.org/packages/6f/7f/0c0dad447819e90b93f8ed238cc8f11b91353c23c19e70fa80483a155bed/charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:89149166622f4db9b4b6a449256291dc87a99ee53151c74cbd82a53c8c2f6ccd", size = 149312 }, - { url = "https://files.pythonhosted.org/packages/8e/09/9f8abcc6fff60fb727268b63c376c8c79cc37b833c2dfe1f535dfb59523b/charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:7709f51f5f7c853f0fb938bcd3bc59cdfdc5203635ffd18bf354f6967ea0f824", size = 152347 }, - { url = "https://files.pythonhosted.org/packages/be/e5/3f363dad2e24378f88ccf63ecc39e817c29f32e308ef21a7a6d9c1201165/charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:345b0426edd4e18138d6528aed636de7a9ed169b4aaf9d61a8c19e39d26838ca", size = 149888 }, - { url = "https://files.pythonhosted.org/packages/e4/10/a78c0e91f487b4ad0ef7480ac765e15b774f83de2597f1b6ef0eaf7a2f99/charset_normalizer-3.4.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0907f11d019260cdc3f94fbdb23ff9125f6b5d1039b76003b5b0ac9d6a6c9d5b", size = 145169 }, - { url = "https://files.pythonhosted.org/packages/d3/81/396e7d7f5d7420da8273c91175d2e9a3f569288e3611d521685e4b9ac9cc/charset_normalizer-3.4.1-cp38-cp38-win32.whl", hash = "sha256:ea0d8d539afa5eb2728aa1932a988a9a7af94f18582ffae4bc10b3fbdad0626e", size = 95094 }, - { url = "https://files.pythonhosted.org/packages/40/bb/20affbbd9ea29c71ea123769dc568a6d42052ff5089c5fe23e21e21084a6/charset_normalizer-3.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:329ce159e82018d646c7ac45b01a430369d526569ec08516081727a20e9e4af4", size = 102139 }, - { url = "https://files.pythonhosted.org/packages/7f/c0/b913f8f02836ed9ab32ea643c6fe4d3325c3d8627cf6e78098671cafff86/charset_normalizer-3.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b97e690a2118911e39b4042088092771b4ae3fc3aa86518f84b8cf6888dbdb41", size = 197867 }, - { url = 
"https://files.pythonhosted.org/packages/0f/6c/2bee440303d705b6fb1e2ec789543edec83d32d258299b16eed28aad48e0/charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78baa6d91634dfb69ec52a463534bc0df05dbd546209b79a3880a34487f4b84f", size = 141385 }, - { url = "https://files.pythonhosted.org/packages/3d/04/cb42585f07f6f9fd3219ffb6f37d5a39b4fd2db2355b23683060029c35f7/charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1a2bc9f351a75ef49d664206d51f8e5ede9da246602dc2d2726837620ea034b2", size = 151367 }, - { url = "https://files.pythonhosted.org/packages/54/54/2412a5b093acb17f0222de007cc129ec0e0df198b5ad2ce5699355269dfe/charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:75832c08354f595c760a804588b9357d34ec00ba1c940c15e31e96d902093770", size = 143928 }, - { url = "https://files.pythonhosted.org/packages/5a/6d/e2773862b043dcf8a221342954f375392bb2ce6487bcd9f2c1b34e1d6781/charset_normalizer-3.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af291f4fe114be0280cdd29d533696a77b5b49cfde5467176ecab32353395c4", size = 146203 }, - { url = "https://files.pythonhosted.org/packages/b9/f8/ca440ef60d8f8916022859885f231abb07ada3c347c03d63f283bec32ef5/charset_normalizer-3.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0167ddc8ab6508fe81860a57dd472b2ef4060e8d378f0cc555707126830f2537", size = 148082 }, - { url = "https://files.pythonhosted.org/packages/04/d2/42fd330901aaa4b805a1097856c2edf5095e260a597f65def493f4b8c833/charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2a75d49014d118e4198bcee5ee0a6f25856b29b12dbf7cd012791f8a6cc5c496", size = 142053 }, - { url = "https://files.pythonhosted.org/packages/9e/af/3a97a4fa3c53586f1910dadfc916e9c4f35eeada36de4108f5096cb7215f/charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_i686.whl", hash = 
"sha256:363e2f92b0f0174b2f8238240a1a30142e3db7b957a5dd5689b0e75fb717cc78", size = 150625 }, - { url = "https://files.pythonhosted.org/packages/26/ae/23d6041322a3556e4da139663d02fb1b3c59a23ab2e2b56432bd2ad63ded/charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ab36c8eb7e454e34e60eb55ca5d241a5d18b2c6244f6827a30e451c42410b5f7", size = 153549 }, - { url = "https://files.pythonhosted.org/packages/94/22/b8f2081c6a77cb20d97e57e0b385b481887aa08019d2459dc2858ed64871/charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:4c0907b1928a36d5a998d72d64d8eaa7244989f7aaaf947500d3a800c83a3fd6", size = 150945 }, - { url = "https://files.pythonhosted.org/packages/c7/0b/c5ec5092747f801b8b093cdf5610e732b809d6cb11f4c51e35fc28d1d389/charset_normalizer-3.4.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:04432ad9479fa40ec0f387795ddad4437a2b50417c69fa275e212933519ff294", size = 146595 }, - { url = "https://files.pythonhosted.org/packages/0c/5a/0b59704c38470df6768aa154cc87b1ac7c9bb687990a1559dc8765e8627e/charset_normalizer-3.4.1-cp39-cp39-win32.whl", hash = "sha256:3bed14e9c89dcb10e8f3a29f9ccac4955aebe93c71ae803af79265c9ca5644c5", size = 95453 }, - { url = "https://files.pythonhosted.org/packages/85/2d/a9790237cb4d01a6d57afadc8573c8b73c609ade20b80f4cda30802009ee/charset_normalizer-3.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:49402233c892a461407c512a19435d1ce275543138294f7ef013f0b63d5d3765", size = 102811 }, { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 }, ] @@ -236,7 +185,7 @@ name = "coverage" version = "7.6.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] sdist = { url = 
"https://files.pythonhosted.org/packages/f7/08/7e37f82e4d1aead42a7443ff06a1e406aabf7302c4f00a546e4b320b994c/coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d", size = 798791 } wheels = [ @@ -270,26 +219,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/86/74/1dc7a20969725e917b1e07fe71a955eb34bc606b938316bcc799f228374b/coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d", size = 238897 }, { url = "https://files.pythonhosted.org/packages/b6/e9/d9cc3deceb361c491b81005c668578b0dfa51eed02cd081620e9a62f24ec/coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5", size = 209606 }, { url = "https://files.pythonhosted.org/packages/47/c8/5a2e41922ea6740f77d555c4d47544acd7dc3f251fe14199c09c0f5958d3/coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb", size = 210373 }, - { url = "https://files.pythonhosted.org/packages/81/d0/d9e3d554e38beea5a2e22178ddb16587dbcbe9a1ef3211f55733924bf7fa/coverage-7.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6db04803b6c7291985a761004e9060b2bca08da6d04f26a7f2294b8623a0c1a0", size = 206674 }, - { url = "https://files.pythonhosted.org/packages/38/ea/cab2dc248d9f45b2b7f9f1f596a4d75a435cb364437c61b51d2eb33ceb0e/coverage-7.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f1adfc8ac319e1a348af294106bc6a8458a0f1633cc62a1446aebc30c5fa186a", size = 207101 }, - { url = "https://files.pythonhosted.org/packages/ca/6f/f82f9a500c7c5722368978a5390c418d2a4d083ef955309a8748ecaa8920/coverage-7.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a95324a9de9650a729239daea117df21f4b9868ce32e63f8b650ebe6cef5595b", size = 236554 }, - { url = 
"https://files.pythonhosted.org/packages/a6/94/d3055aa33d4e7e733d8fa309d9adf147b4b06a82c1346366fc15a2b1d5fa/coverage-7.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b43c03669dc4618ec25270b06ecd3ee4fa94c7f9b3c14bae6571ca00ef98b0d3", size = 234440 }, - { url = "https://files.pythonhosted.org/packages/e4/6e/885bcd787d9dd674de4a7d8ec83faf729534c63d05d51d45d4fa168f7102/coverage-7.6.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8929543a7192c13d177b770008bc4e8119f2e1f881d563fc6b6305d2d0ebe9de", size = 235889 }, - { url = "https://files.pythonhosted.org/packages/f4/63/df50120a7744492710854860783d6819ff23e482dee15462c9a833cc428a/coverage-7.6.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:a09ece4a69cf399510c8ab25e0950d9cf2b42f7b3cb0374f95d2e2ff594478a6", size = 235142 }, - { url = "https://files.pythonhosted.org/packages/3a/5d/9d0acfcded2b3e9ce1c7923ca52ccc00c78a74e112fc2aee661125b7843b/coverage-7.6.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9054a0754de38d9dbd01a46621636689124d666bad1936d76c0341f7d71bf569", size = 233805 }, - { url = "https://files.pythonhosted.org/packages/c4/56/50abf070cb3cd9b1dd32f2c88f083aab561ecbffbcd783275cb51c17f11d/coverage-7.6.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0dbde0f4aa9a16fa4d754356a8f2e36296ff4d83994b2c9d8398aa32f222f989", size = 234655 }, - { url = "https://files.pythonhosted.org/packages/25/ee/b4c246048b8485f85a2426ef4abab88e48c6e80c74e964bea5cd4cd4b115/coverage-7.6.1-cp38-cp38-win32.whl", hash = "sha256:da511e6ad4f7323ee5702e6633085fb76c2f893aaf8ce4c51a0ba4fc07580ea7", size = 209296 }, - { url = "https://files.pythonhosted.org/packages/5c/1c/96cf86b70b69ea2b12924cdf7cabb8ad10e6130eab8d767a1099fbd2a44f/coverage-7.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:3f1156e3e8f2872197af3840d8ad307a9dd18e615dc64d9ee41696f287c57ad8", size = 210137 }, - { url = 
"https://files.pythonhosted.org/packages/19/d3/d54c5aa83268779d54c86deb39c1c4566e5d45c155369ca152765f8db413/coverage-7.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:abd5fd0db5f4dc9289408aaf34908072f805ff7792632250dcb36dc591d24255", size = 206688 }, - { url = "https://files.pythonhosted.org/packages/a5/fe/137d5dca72e4a258b1bc17bb04f2e0196898fe495843402ce826a7419fe3/coverage-7.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:547f45fa1a93154bd82050a7f3cddbc1a7a4dd2a9bf5cb7d06f4ae29fe94eaf8", size = 207120 }, - { url = "https://files.pythonhosted.org/packages/78/5b/a0a796983f3201ff5485323b225d7c8b74ce30c11f456017e23d8e8d1945/coverage-7.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645786266c8f18a931b65bfcefdbf6952dd0dea98feee39bd188607a9d307ed2", size = 235249 }, - { url = "https://files.pythonhosted.org/packages/4e/e1/76089d6a5ef9d68f018f65411fcdaaeb0141b504587b901d74e8587606ad/coverage-7.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e0b2df163b8ed01d515807af24f63de04bebcecbd6c3bfeff88385789fdf75a", size = 233237 }, - { url = "https://files.pythonhosted.org/packages/9a/6f/eef79b779a540326fee9520e5542a8b428cc3bfa8b7c8f1022c1ee4fc66c/coverage-7.6.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:609b06f178fe8e9f89ef676532760ec0b4deea15e9969bf754b37f7c40326dbc", size = 234311 }, - { url = "https://files.pythonhosted.org/packages/75/e1/656d65fb126c29a494ef964005702b012f3498db1a30dd562958e85a4049/coverage-7.6.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:702855feff378050ae4f741045e19a32d57d19f3e0676d589df0575008ea5004", size = 233453 }, - { url = "https://files.pythonhosted.org/packages/68/6a/45f108f137941a4a1238c85f28fd9d048cc46b5466d6b8dda3aba1bb9d4f/coverage-7.6.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2bdb062ea438f22d99cba0d7829c2ef0af1d768d1e4a4f528087224c90b132cb", size = 
231958 }, - { url = "https://files.pythonhosted.org/packages/9b/e7/47b809099168b8b8c72ae311efc3e88c8d8a1162b3ba4b8da3cfcdb85743/coverage-7.6.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9c56863d44bd1c4fe2abb8a4d6f5371d197f1ac0ebdee542f07f35895fc07f36", size = 232938 }, - { url = "https://files.pythonhosted.org/packages/52/80/052222ba7058071f905435bad0ba392cc12006380731c37afaf3fe749b88/coverage-7.6.1-cp39-cp39-win32.whl", hash = "sha256:6e2cd258d7d927d09493c8df1ce9174ad01b381d4729a9d8d4e38670ca24774c", size = 209352 }, - { url = "https://files.pythonhosted.org/packages/b8/d8/1b92e0b3adcf384e98770a00ca095da1b5f7b483e6563ae4eb5e935d24a1/coverage-7.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:06a737c882bd26d0d6ee7269b20b12f14a8704807a01056c80bb881a4b2ce6ca", size = 210153 }, { url = "https://files.pythonhosted.org/packages/a5/2b/0354ed096bca64dc8e32a7cbcae28b34cb5ad0b1fe2125d6d99583313ac0/coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df", size = 198926 }, ] @@ -300,8 +229,6 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/0c/d6/2b53ab3ee99f2262e6f0b8369a43f6d66658eab45510331c0b3d5c8c4272/coverage-7.6.12.tar.gz", hash = "sha256:48cfc4641d95d34766ad41d9573cc0f22a48aa88d22657a1fe01dca0dbae4de2", size = 805941 } wheels = [ @@ -335,20 +262,77 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1c/8e/5bb04f0318805e190984c6ce106b4c3968a9562a400180e549855d8211bd/coverage-7.6.12-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b076e625396e787448d27a411aefff867db2bffac8ed04e8f7056b07024eed5a", size = 241329 }, { url = 
"https://files.pythonhosted.org/packages/9e/9d/fa04d9e6c3f6459f4e0b231925277cfc33d72dfab7fa19c312c03e59da99/coverage-7.6.12-cp312-cp312-win32.whl", hash = "sha256:00b2086892cf06c7c2d74983c9595dc511acca00665480b3ddff749ec4fb2a95", size = 211289 }, { url = "https://files.pythonhosted.org/packages/53/40/53c7ffe3c0c3fff4d708bc99e65f3d78c129110d6629736faf2dbd60ad57/coverage-7.6.12-cp312-cp312-win_amd64.whl", hash = "sha256:7ae6eabf519bc7871ce117fb18bf14e0e343eeb96c377667e3e5dd12095e0288", size = 212079 }, - { url = "https://files.pythonhosted.org/packages/6c/eb/cf062b1c3dbdcafd64a2a154beea2e4aa8e9886c34e41f53fa04925c8b35/coverage-7.6.12-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e7575ab65ca8399c8c4f9a7d61bbd2d204c8b8e447aab9d355682205c9dd948d", size = 208343 }, - { url = "https://files.pythonhosted.org/packages/95/42/4ebad0ab065228e29869a060644712ab1b0821d8c29bfefa20c2118c9e19/coverage-7.6.12-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8161d9fbc7e9fe2326de89cd0abb9f3599bccc1287db0aba285cb68d204ce929", size = 208769 }, - { url = "https://files.pythonhosted.org/packages/44/9f/421e84f7f9455eca85ff85546f26cbc144034bb2587e08bfc214dd6e9c8f/coverage-7.6.12-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a1e465f398c713f1b212400b4e79a09829cd42aebd360362cd89c5bdc44eb87", size = 237553 }, - { url = "https://files.pythonhosted.org/packages/c9/c4/a2c4f274bcb711ed5db2ccc1b851ca1c45f35ed6077aec9d6c61845d80e3/coverage-7.6.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f25d8b92a4e31ff1bd873654ec367ae811b3a943583e05432ea29264782dc32c", size = 235473 }, - { url = "https://files.pythonhosted.org/packages/e0/10/a3d317e38e5627b06debe861d6c511b1611dd9dc0e2a47afbe6257ffd341/coverage-7.6.12-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a936309a65cc5ca80fa9f20a442ff9e2d06927ec9a4f54bcba9c14c066323f2", size = 236575 }, - { url = 
"https://files.pythonhosted.org/packages/4d/49/51cd991b56257d2e07e3d5cb053411e9de5b0f4e98047167ec05e4e19b55/coverage-7.6.12-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:aa6f302a3a0b5f240ee201297fff0bbfe2fa0d415a94aeb257d8b461032389bd", size = 235690 }, - { url = "https://files.pythonhosted.org/packages/f7/87/631e5883fe0a80683a1f20dadbd0f99b79e17a9d8ea9aff3a9b4cfe50b93/coverage-7.6.12-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:f973643ef532d4f9be71dd88cf7588936685fdb576d93a79fe9f65bc337d9d73", size = 234040 }, - { url = "https://files.pythonhosted.org/packages/7c/34/edd03f6933f766ec97dddd178a7295855f8207bb708dbac03777107ace5b/coverage-7.6.12-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:78f5243bb6b1060aed6213d5107744c19f9571ec76d54c99cc15938eb69e0e86", size = 235048 }, - { url = "https://files.pythonhosted.org/packages/ee/1e/d45045b7d3012fe518c617a57b9f9396cdaebe6455f1b404858b32c38cdd/coverage-7.6.12-cp39-cp39-win32.whl", hash = "sha256:69e62c5034291c845fc4df7f8155e8544178b6c774f97a99e2734b05eb5bed31", size = 211085 }, - { url = "https://files.pythonhosted.org/packages/df/ea/086cb06af14a84fe773b86aa140892006a906c5ec947e609ceb6a93f6257/coverage-7.6.12-cp39-cp39-win_amd64.whl", hash = "sha256:b01a840ecc25dce235ae4c1b6a0daefb2a203dba0e6e980637ee9c2f6ee0df57", size = 211965 }, { url = "https://files.pythonhosted.org/packages/7a/7f/05818c62c7afe75df11e0233bd670948d68b36cdbf2a339a095bc02624a8/coverage-7.6.12-pp39.pp310-none-any.whl", hash = "sha256:7e39e845c4d764208e7b8f6a21c541ade741e2c41afabdfa1caa28687a3c98cf", size = 200558 }, { url = "https://files.pythonhosted.org/packages/fb/b2/f655700e1024dec98b10ebaafd0cedbc25e40e4abe62a3c8e2ceef4f8f0a/coverage-7.6.12-py3-none-any.whl", hash = "sha256:eb8668cfbc279a536c633137deeb9435d2962caec279c3f8cf8b91fff6ff8953", size = 200552 }, ] +[[package]] +name = "cuda-bindings" +version = "13.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cuda-pathfinder", marker = 
"python_full_version >= '3.11'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/fe/7351d7e586a8b4c9f89731bfe4cf0148223e8f9903ff09571f78b3fb0682/cuda_bindings-13.2.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08b395f79cb89ce0cd8effff07c4a1e20101b873c256a1aeb286e8fd7bd0f556", size = 5744254 }, + { url = "https://files.pythonhosted.org/packages/aa/ef/184aa775e970fc089942cd9ec6302e6e44679d4c14549c6a7ea45bf7f798/cuda_bindings-13.2.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6f3682ec3c4769326aafc67c2ba669d97d688d0b7e63e659d36d2f8b72f32d6", size = 6329075 }, + { url = "https://files.pythonhosted.org/packages/e0/a9/3a8241c6e19483ac1f1dcf5c10238205dcb8a6e9d0d4d4709240dff28ff4/cuda_bindings-13.2.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:721104c603f059780d287969be3d194a18d0cc3b713ed9049065a1107706759d", size = 5730273 }, + { url = "https://files.pythonhosted.org/packages/e9/94/2748597f47bb1600cd466b20cab4159f1530a3a33fe7f70fee199b3abb9e/cuda_bindings-13.2.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1eba9504ac70667dd48313395fe05157518fd6371b532790e96fbb31bbb5a5e1", size = 6313924 }, + { url = "https://files.pythonhosted.org/packages/52/c8/b2589d68acf7e3d63e2be330b84bc25712e97ed799affbca7edd7eae25d6/cuda_bindings-13.2.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e865447abfb83d6a98ad5130ed3c70b1fc295ae3eeee39fd07b4ddb0671b6788", size = 5722404 }, + { url = "https://files.pythonhosted.org/packages/1f/92/f899f7bbb5617bb65ec52a6eac1e9a1447a86b916c4194f8a5001b8cde0c/cuda_bindings-13.2.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:46d8776a55d6d5da9dd6e9858fba2efcda2abe6743871dee47dd06eb8cb6d955", size = 6320619 }, +] + +[[package]] +name = "cuda-pathfinder" +version = "1.5.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/d3/d6/ac63065d33dd700fee7ebd7d287332401b54e31b9346e142f871e1f0b116/cuda_pathfinder-1.5.3-py3-none-any.whl", hash = "sha256:dff021123aedbb4117cc7ec81717bbfe198fb4e8b5f1ee57e0e084fec5c8577d", size = 49991 }, +] + +[[package]] +name = "cuda-toolkit" +version = "13.0.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/57/b2/453099f5f3b698d7d0eab38916aac44c7f76229f451709e2eb9db6615dcd/cuda_toolkit-13.0.2-py2.py3-none-any.whl", hash = "sha256:b198824cf2f54003f50d64ada3a0f184b42ca0846c1c94192fa269ecd97a66eb", size = 2364 }, +] + +[package.optional-dependencies] +cublas = [ + { name = "nvidia-cublas", marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, +] +cudart = [ + { name = "nvidia-cuda-runtime", marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, +] +cufft = [ + { name = "nvidia-cufft", marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, +] +cufile = [ + { name = "nvidia-cufile", marker = "python_full_version >= '3.11' and sys_platform == 'linux'" }, +] +cupti = [ + { name = "nvidia-cuda-cupti", marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, +] +curand = [ + { name = "nvidia-curand", marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, +] +cusolver = [ + { name = "nvidia-cusolver", marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, +] +cusparse = [ + { name = "nvidia-cusparse", marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or 
(python_full_version >= '3.11' and sys_platform == 'win32')" }, +] +nvjitlink = [ + { name = "nvidia-nvjitlink", marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, +] +nvrtc = [ + { name = "nvidia-cuda-nvrtc", marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, +] +nvtx = [ + { name = "nvidia-nvtx", marker = "(python_full_version >= '3.11' and sys_platform == 'linux') or (python_full_version >= '3.11' and sys_platform == 'win32')" }, +] + [[package]] name = "dill" version = "0.3.9" @@ -399,7 +383,7 @@ name = "filelock" version = "3.16.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] sdist = { url = "https://files.pythonhosted.org/packages/9d/db/3ef5bb276dae18d6ec2124224403d1d67bccdbefc17af4cc8f553e341ab1/filelock-3.16.1.tar.gz", hash = "sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435", size = 18037 } wheels = [ @@ -413,8 +397,6 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/dc/9c/0b15fb47b464e1b663b1acd1253a062aa5feecb07d4e597daea542ebd2b5/filelock-3.17.0.tar.gz", hash = "sha256:ee4e77401ef576ebb38cd7f13b9b28893194acc20a8e68e18730ba9c0e54660e", size = 18027 } wheels = [ @@ -469,12 +451,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/51/0b/0d7fee5919bccc1fdc1c2a7528b98f65c6f69b223a3fd8f809918c142c36/freezegun-1.5.1-py3-none-any.whl", hash = "sha256:bf111d7138a8abe55ab48a71755673dbaa4ab87f4cff5634a4442dfec34c15f1", size = 17569 }, ] +[[package]] +name = "fsspec" +version = "2025.10.0" +source = { registry = "https://pypi.org/simple" } 
+resolution-markers = [ + "python_full_version < '3.11'", +] +sdist = { url = "https://files.pythonhosted.org/packages/24/7f/2747c0d332b9acfa75dc84447a066fdf812b5a6b8d30472b74d309bfe8cb/fsspec-2025.10.0.tar.gz", hash = "sha256:b6789427626f068f9a83ca4e8a3cc050850b6c0f71f99ddb4f542b8266a26a59", size = 309285 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/02/a6b21098b1d5d6249b7c5ab69dde30108a71e4e819d4a9778f1de1d5b70d/fsspec-2025.10.0-py3-none-any.whl", hash = "sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d", size = 200966 }, +] + +[[package]] +name = "fsspec" +version = "2026.3.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", +] +sdist = { url = "https://files.pythonhosted.org/packages/e1/cf/b50ddf667c15276a9ab15a70ef5f257564de271957933ffea49d2cdbcdfb/fsspec-2026.3.0.tar.gz", hash = "sha256:1ee6a0e28677557f8c2f994e3eea77db6392b4de9cd1f5d7a9e87a0ae9d01b41", size = 313547 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/1f/5f4a3cd9e4440e9d9bc78ad0a91a1c8d46b4d429d5239ebe6793c9fe5c41/fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4", size = 202595 }, +] + [[package]] name = "gast" version = "0.4.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] sdist = { url = "https://files.pythonhosted.org/packages/83/4a/07c7e59cef23fb147454663c3271c21da68ba2ab141427c20548ae5a8a4d/gast-0.4.0.tar.gz", hash = "sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1", size = 13804 } wheels = [ @@ -488,8 +495,6 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] sdist = { url = 
"https://files.pythonhosted.org/packages/3c/14/c566f5ca00c115db7725263408ff952b8ae6d6a4e792ef9c84e77d9af7a1/gast-0.6.0.tar.gz", hash = "sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb", size = 27708 } wheels = [ @@ -532,33 +537,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1d/9a/4114a9057db2f1462d5c8f8390ab7383925fe1ac012eaa42402ad65c2963/GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110", size = 207599 }, ] -[[package]] -name = "google-auth" -version = "2.38.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cachetools", marker = "python_full_version < '3.9'" }, - { name = "pyasn1-modules", marker = "python_full_version < '3.9'" }, - { name = "rsa", marker = "python_full_version < '3.9'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c6/eb/d504ba1daf190af6b204a9d4714d457462b486043744901a6eeea711f913/google_auth-2.38.0.tar.gz", hash = "sha256:8285113607d3b80a3f1543b75962447ba8a09fe85783432a784fdeef6ac094c4", size = 270866 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9d/47/603554949a37bca5b7f894d51896a9c534b9eab808e2520a748e081669d0/google_auth-2.38.0-py2.py3-none-any.whl", hash = "sha256:e7dae6694313f434a2727bf2906f27ad259bae090d7aa896590d86feec3d9d4a", size = 210770 }, -] - -[[package]] -name = "google-auth-oauthlib" -version = "0.4.6" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-auth", marker = "python_full_version < '3.9'" }, - { name = "requests-oauthlib", marker = "python_full_version < '3.9'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/30/21/b84fa7ef834d4b126faad13da6e582c8f888e196326b9d6aab1ae303df4f/google-auth-oauthlib-0.4.6.tar.gz", hash = "sha256:a90a072f6993f2c327067bf65270046384cda5a8ecb20b94ea9a687f1f233a7a", size = 19516 } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/b1/0e/0636cc1448a7abc444fb1b3a63655e294e0d2d49092dc3de05241be6d43c/google_auth_oauthlib-0.4.6-py2.py3-none-any.whl", hash = "sha256:3f2a6e802eebbb6fb736a370fbf3b055edcb6b52878bf2f26330b5e041316c73", size = 18306 }, -] - [[package]] name = "google-pasta" version = "0.2.0" @@ -576,11 +554,10 @@ name = "griffe" version = "1.4.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] dependencies = [ - { name = "astunparse", marker = "python_full_version < '3.9'" }, - { name = "colorama", marker = "python_full_version < '3.9'" }, + { name = "colorama", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/05/e9/b2c86ad9d69053e497a24ceb25d661094fb321ab4ed39a8b71793dcbae82/griffe-1.4.0.tar.gz", hash = "sha256:8fccc585896d13f1221035d32c50dec65830c87d23f9adb9b1e6f3d63574f7f5", size = 381028 } wheels = [ @@ -594,11 +571,9 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] dependencies = [ - { name = "colorama", marker = "python_full_version >= '3.9'" }, + { name = "colorama", marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/59/80/13b6456bfbf8bc854875e58d3a3bad297ee19ebdd693ce62a10fab007e7a/griffe-1.5.7.tar.gz", hash = "sha256:465238c86deaf1137761f700fb343edd8ffc846d72f6de43c3c345ccdfbebe92", size = 391503 } wheels = [ @@ -638,24 +613,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/b2/6a97ac91042a2c59d18244c479ee3894e7fb6f8c3a90619bb5a7757fa30c/grpcio-1.70.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ac073fe1c4cd856ebcf49e9ed6240f4f84d7a4e6ee95baa5d66ea05d3dd0df7f", size = 6190055 }, { url = 
"https://files.pythonhosted.org/packages/86/2b/28db55c8c4d156053a8c6f4683e559cd0a6636f55a860f87afba1ac49a51/grpcio-1.70.0-cp312-cp312-win32.whl", hash = "sha256:cd24d2d9d380fbbee7a5ac86afe9787813f285e684b0271599f95a51bce33528", size = 3600214 }, { url = "https://files.pythonhosted.org/packages/17/c3/a7a225645a965029ed432e5b5e9ed959a574e62100afab553eef58be0e37/grpcio-1.70.0-cp312-cp312-win_amd64.whl", hash = "sha256:0495c86a55a04a874c7627fd33e5beaee771917d92c0e6d9d797628ac40e7655", size = 4292538 }, - { url = "https://files.pythonhosted.org/packages/38/5f/d7fe323c18a2ec98a2a9b38fb985f5e843f76990298d7c4ce095f44b46a7/grpcio-1.70.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:8058667a755f97407fca257c844018b80004ae8035565ebc2812cc550110718d", size = 5232027 }, - { url = "https://files.pythonhosted.org/packages/d4/4b/3d3b5548575b635f51883212a482cd237e8525535d4591b9dc7e5b2c2ddc/grpcio-1.70.0-cp38-cp38-macosx_10_14_universal2.whl", hash = "sha256:879a61bf52ff8ccacbedf534665bb5478ec8e86ad483e76fe4f729aaef867cab", size = 11448811 }, - { url = "https://files.pythonhosted.org/packages/8a/d7/9a0922fc12d339271c7e4e6691470172b7c13715fed7bd934274803f1527/grpcio-1.70.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:0ba0a173f4feacf90ee618fbc1a27956bfd21260cd31ced9bc707ef551ff7dc7", size = 5711890 }, - { url = "https://files.pythonhosted.org/packages/1e/ae/d4dbf8bff0f1d270f118d08558bc8dc0489e026d6620a4e3ee2d79d79041/grpcio-1.70.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:558c386ecb0148f4f99b1a65160f9d4b790ed3163e8610d11db47838d452512d", size = 6331933 }, - { url = "https://files.pythonhosted.org/packages/2c/64/66a74c02b00e00b919c245ca9da8e5c44e8692bf3fe7f27efbc97572566c/grpcio-1.70.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:412faabcc787bbc826f51be261ae5fa996b21263de5368a55dc2cf824dc5090e", size = 5950685 }, - { url = 
"https://files.pythonhosted.org/packages/b0/64/e992ac693118c37164e085676216d258804d7a5bbf3581d3f989c843a9a5/grpcio-1.70.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3b0f01f6ed9994d7a0b27eeddea43ceac1b7e6f3f9d86aeec0f0064b8cf50fdb", size = 6640974 }, - { url = "https://files.pythonhosted.org/packages/57/17/34d0a6af4477fd48b8b41d13782fb1e35b8841b17d6ac7a3eb24d2f3b17e/grpcio-1.70.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7385b1cb064734005204bc8994eed7dcb801ed6c2eda283f613ad8c6c75cf873", size = 6204792 }, - { url = "https://files.pythonhosted.org/packages/d3/e5/e45d8eb81929c0becd5bda413b60262f79d862e19cff632d496909aa3bd0/grpcio-1.70.0-cp38-cp38-win32.whl", hash = "sha256:07269ff4940f6fb6710951116a04cd70284da86d0a4368fd5a3b552744511f5a", size = 3620015 }, - { url = "https://files.pythonhosted.org/packages/87/7d/36009c38093e62969c708f20b86ab6761c2ba974b12ff10def6f397f24fa/grpcio-1.70.0-cp38-cp38-win_amd64.whl", hash = "sha256:aba19419aef9b254e15011b230a180e26e0f6864c90406fdbc255f01d83bc83c", size = 4307043 }, - { url = "https://files.pythonhosted.org/packages/9d/0e/64061c9746a2dd6e07cb0a0f3829f0a431344add77ec36397cc452541ff6/grpcio-1.70.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:4f1937f47c77392ccd555728f564a49128b6a197a05a5cd527b796d36f3387d0", size = 5231123 }, - { url = "https://files.pythonhosted.org/packages/72/9f/c93501d5f361aecee0146ab19300d5acb1c2747b00217c641f06fffbcd62/grpcio-1.70.0-cp39-cp39-macosx_10_14_universal2.whl", hash = "sha256:0cd430b9215a15c10b0e7d78f51e8a39d6cf2ea819fd635a7214fae600b1da27", size = 11467217 }, - { url = "https://files.pythonhosted.org/packages/0a/1a/980d115b701023450a304881bf3f6309f6fb15787f9b78d2728074f3bf86/grpcio-1.70.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:e27585831aa6b57b9250abaf147003e126cd3a6c6ca0c531a01996f31709bed1", size = 5710913 }, - { url = 
"https://files.pythonhosted.org/packages/a0/84/af420067029808f9790e98143b3dd0f943bebba434a4706755051a520c91/grpcio-1.70.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c1af8e15b0f0fe0eac75195992a63df17579553b0c4af9f8362cc7cc99ccddf4", size = 6330947 }, - { url = "https://files.pythonhosted.org/packages/24/1c/e1f06a7d29a1fa5053dcaf5352a50f8e1f04855fd194a65422a9d685d375/grpcio-1.70.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbce24409beaee911c574a3d75d12ffb8c3e3dd1b813321b1d7a96bbcac46bf4", size = 5943913 }, - { url = "https://files.pythonhosted.org/packages/41/8f/de13838e4467519a50cd0693e98b0b2bcc81d656013c38a1dd7dcb801526/grpcio-1.70.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ff4a8112a79464919bb21c18e956c54add43ec9a4850e3949da54f61c241a4a6", size = 6643236 }, - { url = "https://files.pythonhosted.org/packages/ac/73/d68c745d34e43a80440da4f3d79fa02c56cb118c2a26ba949f3cfd8316d7/grpcio-1.70.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5413549fdf0b14046c545e19cfc4eb1e37e9e1ebba0ca390a8d4e9963cab44d2", size = 6199038 }, - { url = "https://files.pythonhosted.org/packages/7e/dd/991f100b8c31636b4bb2a941dbbf54dbcc55d69c722cfa038c3d017eaa0c/grpcio-1.70.0-cp39-cp39-win32.whl", hash = "sha256:b745d2c41b27650095e81dea7091668c040457483c9bdb5d0d9de8f8eb25e59f", size = 3617512 }, - { url = "https://files.pythonhosted.org/packages/4d/80/1aa2ba791207a13e314067209b48e1a0893ed8d1f43ef012e194aaa6c2de/grpcio-1.70.0-cp39-cp39-win_amd64.whl", hash = "sha256:a31d7e3b529c94e930a117b2175b2efd179d96eb3c7a21ccb0289a8ab05b645c", size = 4303506 }, ] [[package]] @@ -663,10 +620,10 @@ name = "h5py" version = "3.11.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] dependencies = [ - { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "numpy", 
version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/52/8f/e557819155a282da36fb21f8de4730cfd10a964b52b3ae8d20157ac1c668/h5py-3.11.0.tar.gz", hash = "sha256:7b7e8f78072a2edec87c9836f25f34203fd492a4475709a18b417a33cfb21fa9", size = 406519 } wheels = [ @@ -682,14 +639,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9d/3f/cf80ef55e0a9b18aae96c763fbd275c54d0723e0f2cc54f954f87cc5c69a/h5py-3.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f3736fe21da2b7d8a13fe8fe415f1272d2a1ccdeff4849c1421d2fb30fd533bc", size = 2943214 }, { url = "https://files.pythonhosted.org/packages/db/7e/fedac8bb8c4729409e2dec5e4136a289116d701d54f69ce73c5617afc5f0/h5py-3.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa6ae84a14103e8dc19266ef4c3e5d7c00b68f21d07f2966f0ca7bdb6c2761fb", size = 5378375 }, { url = "https://files.pythonhosted.org/packages/2b/b2/0ee327933ffa37af1fc7915df7fc067e6009adcd8445d55ad07a9bec11b5/h5py-3.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:21dbdc5343f53b2e25404673c4f00a3335aef25521bd5fa8c707ec3833934892", size = 2970991 }, - { url = "https://files.pythonhosted.org/packages/33/97/c1a8f28329ad794d18fc61bf251268ac03959bf93b82fdd7701ac6931fed/h5py-3.11.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:754c0c2e373d13d6309f408325343b642eb0f40f1a6ad21779cfa9502209e150", size = 3470228 }, - { url = "https://files.pythonhosted.org/packages/a4/1d/fd0b88c51c37bc8aeedecc4f4b48397f7ce13c87073aaf6912faec06e9f6/h5py-3.11.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:731839240c59ba219d4cb3bc5880d438248533366f102402cfa0621b71796b62", size = 2935809 }, - { url = "https://files.pythonhosted.org/packages/86/43/fd0bd74462b3c3fb35d98568935d3e5a435c8ec24d45ef408ac8869166af/h5py-3.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:8ec9df3dd2018904c4cc06331951e274f3f3fd091e6d6cc350aaa90fa9b42a76", size = 5309045 }, - { url = "https://files.pythonhosted.org/packages/15/9a/b5456e1acc4abb382938d4a730600823bfe77a4bbfd29140ccbf01ba5596/h5py-3.11.0-cp38-cp38-win_amd64.whl", hash = "sha256:55106b04e2c83dfb73dc8732e9abad69d83a436b5b82b773481d95d17b9685e1", size = 2989172 }, - { url = "https://files.pythonhosted.org/packages/c2/1f/36a84945616881bd47e6c40dcdca7e929bc811725d78d001eddba6864185/h5py-3.11.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f4e025e852754ca833401777c25888acb96889ee2c27e7e629a19aee288833f0", size = 3490090 }, - { url = "https://files.pythonhosted.org/packages/3c/fb/e213586de5ea56f1747a843e725c62eef350512be57452186996ba660d52/h5py-3.11.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6c4b760082626120031d7902cd983d8c1f424cdba2809f1067511ef283629d4b", size = 2951710 }, - { url = "https://files.pythonhosted.org/packages/71/28/69a881e01f198ccdb65c36f7adcfef22bfe85e38ffbfdf833af24f58eb5e/h5py-3.11.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67462d0669f8f5459529de179f7771bd697389fcb3faab54d63bf788599a48ea", size = 5326481 }, - { url = "https://files.pythonhosted.org/packages/c3/61/0b35ad9aac0ab0a33365879556fdb824fc83013df69b247386690db59015/h5py-3.11.0-cp39-cp39-win_amd64.whl", hash = "sha256:d9c944d364688f827dc889cf83f1fca311caf4fa50b19f009d1f2b525edd33a3", size = 2978689 }, ] [[package]] @@ -699,11 +648,9 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] dependencies = [ - { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = 
"https://files.pythonhosted.org/packages/03/2e/a22d6a8bfa6f8be33e7febd985680fba531562795f0a9077ed1eb047bfb0/h5py-3.13.0.tar.gz", hash = "sha256:1870e46518720023da85d0895a1960ff2ce398c5671eac3b1a41ec696b7105c3", size = 414876 } wheels = [ @@ -722,11 +669,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/d9/aed99e1c858dc698489f916eeb7c07513bc864885d28ab3689d572ba0ea0/h5py-3.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:357e6dc20b101a805ccfd0024731fbaf6e8718c18c09baf3b5e4e9d198d13fca", size = 4669544 }, { url = "https://files.pythonhosted.org/packages/a7/da/3c137006ff5f0433f0fb076b1ebe4a7bf7b5ee1e8811b5486af98b500dd5/h5py-3.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d6f13f9b5ce549448c01e4dfe08ea8d1772e6078799af2c1c8d09e941230a90d", size = 4932139 }, { url = "https://files.pythonhosted.org/packages/25/61/d897952629cae131c19d4c41b2521e7dd6382f2d7177c87615c2e6dced1a/h5py-3.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:21daf38171753899b5905f3d82c99b0b1ec2cbbe282a037cad431feb620e62ec", size = 2954179 }, - { url = "https://files.pythonhosted.org/packages/cd/91/3e5b4e4c399bb57141a2451c67808597ab6993f799587566c9f11dbaefe9/h5py-3.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:82690e89c72b85addf4fc4d5058fb1e387b6c14eb063b0b879bf3f42c3b93c35", size = 3424729 }, - { url = "https://files.pythonhosted.org/packages/12/82/4e455e12e7ff26533c762eaf324edd6b076f84c3a003a40a1e52d805e0fb/h5py-3.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d571644958c5e19a61c793d8d23cd02479572da828e333498c9acc463f4a3997", size = 2926632 }, - { url = "https://files.pythonhosted.org/packages/ab/c9/fb430d3277e81eade92e54e87bd73e9f60c98240a86a5f43e3b85620d7d8/h5py-3.13.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:560e71220dc92dfa254b10a4dcb12d56b574d2d87e095db20466b32a93fec3f9", size = 4285580 }, - { url = 
"https://files.pythonhosted.org/packages/3f/9b/3e8cded7877ec84b707df82b9c6289cd1d7ad80fef9a10bb1389c5fee8f2/h5py-3.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c10f061764d8dce0a9592ce08bfd5f243a00703325c388f1086037e5d619c5f1", size = 4550898 }, - { url = "https://files.pythonhosted.org/packages/cb/47/8353102cff9290861135e13eefff5a916855d2ab23bd052ec7ac144f4c48/h5py-3.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:9c82ece71ed1c2b807b6628e3933bc6eae57ea21dac207dca3470e3ceaaf437c", size = 2960208 }, ] [[package]] @@ -734,7 +676,7 @@ name = "identify" version = "2.6.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] sdist = { url = "https://files.pythonhosted.org/packages/29/bb/25024dbcc93516c492b75919e76f389bac754a3e4248682fba32b250c880/identify-2.6.1.tar.gz", hash = "sha256:91478c5fb7c3aac5ff7bf9b4344f803843dc586832d5f110d672b19aa1984c98", size = 99097 } wheels = [ @@ -748,8 +690,6 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/f9/fa/5eb460539e6f5252a7c5a931b53426e49258cde17e3d50685031c300a8fd/identify-2.6.8.tar.gz", hash = "sha256:61491417ea2c0c5c670484fd8abbb34de34cdae1e5f39a73ee65e48e4bb663fc", size = 99249 } wheels = [ @@ -766,84 +706,146 @@ wheels = [ ] [[package]] -name = "importlib-metadata" -version = "8.5.0" +name = "importlib-resources" +version = "6.4.5" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", -] -dependencies = [ - { name = "zipp", version = "3.20.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + "python_full_version < '3.11'", ] -sdist = { url = 
"https://files.pythonhosted.org/packages/cd/12/33e59336dca5be0c398a7482335911a33aa0e20776128f038019f1a95f1b/importlib_metadata-8.5.0.tar.gz", hash = "sha256:71522656f0abace1d072b9e5481a48f07c138e00f079c38c8f883823f9c26bd7", size = 55304 } +sdist = { url = "https://files.pythonhosted.org/packages/98/be/f3e8c6081b684f176b761e6a2fef02a0be939740ed6f54109a2951d806f3/importlib_resources-6.4.5.tar.gz", hash = "sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065", size = 43372 } wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/d9/a1e041c5e7caa9a05c925f4bdbdfb7f006d1f74996af53467bc394c97be7/importlib_metadata-8.5.0-py3-none-any.whl", hash = "sha256:45e54197d28b7a7f1559e60b95e7c567032b602131fbd588f1497f47880aa68b", size = 26514 }, + { url = "https://files.pythonhosted.org/packages/e1/6a/4604f9ae2fa62ef47b9de2fa5ad599589d28c9fd1d335f32759813dfa91e/importlib_resources-6.4.5-py3-none-any.whl", hash = "sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717", size = 36115 }, ] [[package]] -name = "importlib-metadata" -version = "8.6.1" +name = "importlib-resources" +version = "6.5.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version == '3.9.*'", + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", ] -dependencies = [ - { name = "zipp", version = "3.21.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, +sdist = { url = "https://files.pythonhosted.org/packages/cf/8c/f834fbf984f691b4f7ff60f50b514cc3de5cc08abfc3295564dd89c5e2e7/importlib_resources-6.5.2.tar.gz", hash = "sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c", size = 44693 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 
37461 }, ] -sdist = { url = "https://files.pythonhosted.org/packages/33/08/c1395a292bb23fd03bdf572a1357c5a733d3eecbab877641ceacab23db6e/importlib_metadata-8.6.1.tar.gz", hash = "sha256:310b41d755445d74569f993ccfc22838295d9fe005425094fad953d7f15c8580", size = 55767 } + +[[package]] +name = "iniconfig" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 } wheels = [ - { url = "https://files.pythonhosted.org/packages/79/9d/0fb148dc4d6fa4a7dd1d8378168d9b4cd8d4560a6fbf6f0121c5fc34eb68/importlib_metadata-8.6.1-py3-none-any.whl", hash = "sha256:02a89390c1e15fdfdc0d7c6b25cb3e62650d0494005c97d6f148bf5b9787525e", size = 26971 }, + { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, ] [[package]] -name = "importlib-resources" -version = "6.4.5" +name = "isort" +version = "5.12.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/c4/dc00e42c158fc4dda2afebe57d2e948805c06d5169007f1724f0683010a9/isort-5.12.0.tar.gz", hash = "sha256:8bef7dde241278824a6d83f44a544709b065191b95b6e50894bdc722fcba0504", size = 174643 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/63/4036ae70eea279c63e2304b91ee0ac182f467f24f86394ecfe726092340b/isort-5.12.0-py3-none-any.whl", hash = "sha256:f84c2818376e66cf843d497486ea8fed8700b340f308f076c6fb1229dff318b6", size = 91198 }, +] + +[[package]] +name = "jax" +version = "0.4.30" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] 
dependencies = [ - { name = "zipp", version = "3.20.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "jaxlib", version = "0.4.30", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "ml-dtypes", marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "opt-einsum", marker = "python_full_version < '3.11'" }, + { name = "scipy", version = "1.10.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/98/be/f3e8c6081b684f176b761e6a2fef02a0be939740ed6f54109a2951d806f3/importlib_resources-6.4.5.tar.gz", hash = "sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065", size = 43372 } +sdist = { url = "https://files.pythonhosted.org/packages/15/41/d6dbafc31d6bd93eeec2e1c709adfa454266e83714ebeeed9de52a6ad881/jax-0.4.30.tar.gz", hash = "sha256:94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577", size = 1715462 } wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/6a/4604f9ae2fa62ef47b9de2fa5ad599589d28c9fd1d335f32759813dfa91e/importlib_resources-6.4.5-py3-none-any.whl", hash = "sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717", size = 36115 }, + { url = "https://files.pythonhosted.org/packages/fd/f2/9dbb75de3058acfd1600cf0839bcce7ea391148c9d2b4fa5f5666e66f09e/jax-0.4.30-py3-none-any.whl", hash = "sha256:289b30ae03b52f7f4baf6ef082a9f4e3e29c1080e22d13512c5ecf02d5f1a55b", size = 2009197 }, ] [[package]] -name = "importlib-resources" -version = "6.5.2" +name = "jax" +version = "0.4.34" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - 
"python_full_version == '3.9.*'", ] dependencies = [ - { name = "zipp", version = "3.21.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, + { name = "jaxlib", version = "0.4.34", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "ml-dtypes", marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "opt-einsum", marker = "python_full_version >= '3.11'" }, + { name = "scipy", version = "1.15.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cf/8c/f834fbf984f691b4f7ff60f50b514cc3de5cc08abfc3295564dd89c5e2e7/importlib_resources-6.5.2.tar.gz", hash = "sha256:185f87adef5bcc288449d98fb4fba07cea78bc036455dd44c5fc4a2fe78fed2c", size = 44693 } +sdist = { url = "https://files.pythonhosted.org/packages/19/6a/cacfcdf77841a4562e555ef35e0dbc5f8ca79c9f1010aaa4cf3973e79c69/jax-0.4.34.tar.gz", hash = "sha256:44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db", size = 1848472 } wheels = [ - { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461 }, + { url = "https://files.pythonhosted.org/packages/06/f3/c499d358dd7f267a63d7d38ef54aadad82e28d2c28bafff15360c3091946/jax-0.4.34-py3-none-any.whl", hash = "sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e", size = 2144294 }, ] [[package]] -name = "iniconfig" -version = "2.0.0" +name = "jaxlib" +version = "0.4.30" source = { registry = "https://pypi.org/simple" } -sdist = { url = 
"https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 } +resolution-markers = [ + "python_full_version < '3.11'", +] +dependencies = [ + { name = "ml-dtypes", marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "scipy", version = "1.10.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, +] wheels = [ - { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, + { url = "https://files.pythonhosted.org/packages/f3/18/ff7f2f6d6195853ed55c5b5d835f5c8c3c8b190c7221cb04a0cb81f5db10/jaxlib-0.4.30-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:c40856e28f300938c6824ab1a615166193d6997dec946578823f6d402ad454e5", size = 83542097 }, + { url = "https://files.pythonhosted.org/packages/d4/c0/ff65503ecfed3aee11e4abe4c4e9e8a3513f072e0b595f8247b9989d1510/jaxlib-0.4.30-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4bdfda6a3c7a2b0cc0a7131009eb279e98ca4a6f25679fabb5302dd135a5e349", size = 66694495 }, + { url = "https://files.pythonhosted.org/packages/b9/d7/82df748a31a1cfbd531a12979ea846d6b676d4adfa1e91114b848665b2aa/jaxlib-0.4.30-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:28e032c9b394ab7624d89b0d9d3bbcf4d1d71694fe8b3e09d3fe64122eda7b0c", size = 67781242 }, + { url = "https://files.pythonhosted.org/packages/4a/ca/561aabed63007bb2621a62f0d816aa2f68cfe947859c8b4e61519940344b/jaxlib-0.4.30-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:d83f36ef42a403bbf7c7f2da526b34ba286988e170f4df5e58b3bb735417868c", size = 79640266 }, + { url 
= "https://files.pythonhosted.org/packages/b0/90/8e5347eda95d3cb695cd5ebb82f850fa7866078a6a7a0568549e34125a82/jaxlib-0.4.30-cp310-cp310-win_amd64.whl", hash = "sha256:a56678b28f96b524ded6da8ef4b38e72a532356d139cfd434da804abf4234e14", size = 51945307 }, + { url = "https://files.pythonhosted.org/packages/33/2d/b6078f5d173d3087d32b1b49e5f65d406985fb3894ff1d21905972b9c89d/jaxlib-0.4.30-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:bfb5d85b69c29c3c6e8051a0ea715ac1e532d6e54494c8d9c3813dcc00deac30", size = 83539315 }, + { url = "https://files.pythonhosted.org/packages/12/95/399da9204c3b13696baefb93468402f3389416b0caecfd9126aa94742bf2/jaxlib-0.4.30-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:974998cd8a78550402e6c09935c1f8d850cad9cc19ccd7488bde45b6f7f99c12", size = 66690971 }, + { url = "https://files.pythonhosted.org/packages/a4/f8/b85a46cb0cc4bc228cea4366b0d15caf42656c6d43cf8c91d90f7399aa4d/jaxlib-0.4.30-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:e93eb0646b41ba213252b51b0b69096b9cd1d81a35ea85c9d06663b5d11efe45", size = 67780747 }, + { url = "https://files.pythonhosted.org/packages/a6/a3/951da3d1487b2f8995a2a14cc7e9496c9a7c93aa1f1d0b33e833e24dee92/jaxlib-0.4.30-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:16b2ab18ea90d2e15941bcf45de37afc2f289a029129c88c8d7aba0404dd0043", size = 79640352 }, + { url = "https://files.pythonhosted.org/packages/bb/1a/8f45ea28a5ca67e4d23ebd70fc78ea94be6fa20323f983c7607c32c6f9a5/jaxlib-0.4.30-cp311-cp311-win_amd64.whl", hash = "sha256:3a2e2c11c179f8851a72249ba1ae40ae817dfaee9877d23b3b8f7c6b7a012f76", size = 51943960 }, + { url = "https://files.pythonhosted.org/packages/19/40/ae943d3c1fc8b50947aebbaa3bad2842759e43bc9fc91e1758c1c20a81ab/jaxlib-0.4.30-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:7704db5962b32a2be3cc07185433cbbcc94ed90ee50c84021a3f8a1ecfd66ee3", size = 83587124 }, + { url = 
"https://files.pythonhosted.org/packages/c6/e3/97f8edff6f64245a500415be021869522b235e8b38cd930d358b91243583/jaxlib-0.4.30-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:57090d33477fd0f0c99dc686274882ea75c44c7d712ae42dd2460b10f896131d", size = 66724768 }, + { url = "https://files.pythonhosted.org/packages/4c/c7/ee1f48f8daa409d0ed039e0d8b5ae1a447e53db3acb2ff06239828ad96d5/jaxlib-0.4.30-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:0a3850e76278038e21685975a62b622bcf3708485f13125757a0561ee4512940", size = 67800348 }, + { url = "https://files.pythonhosted.org/packages/f2/fa/a2dddea0d6965b8e433bb99aeedbe5c8a9b47110c1c4f197a7b6239daf44/jaxlib-0.4.30-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:c58a8071c4e00898282118169f6a5a97eb15a79c2897858f3a732b17891c99ab", size = 79674030 }, + { url = "https://files.pythonhosted.org/packages/db/31/3500633d61b20b882a0fbcf8100013195c31b51f71249b0b38737851fc9a/jaxlib-0.4.30-cp312-cp312-win_amd64.whl", hash = "sha256:b7079a5b1ab6864a7d4f2afaa963841451186d22c90f39719a3ff85735ce3915", size = 51965689 }, ] [[package]] -name = "isort" -version = "5.12.0" +name = "jaxlib" +version = "0.4.34" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a9/c4/dc00e42c158fc4dda2afebe57d2e948805c06d5169007f1724f0683010a9/isort-5.12.0.tar.gz", hash = "sha256:8bef7dde241278824a6d83f44a544709b065191b95b6e50894bdc722fcba0504", size = 174643 } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", +] +dependencies = [ + { name = "ml-dtypes", marker = "python_full_version >= '3.11'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "scipy", version = "1.15.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] wheels = [ - { url = 
"https://files.pythonhosted.org/packages/0a/63/4036ae70eea279c63e2304b91ee0ac182f467f24f86394ecfe726092340b/isort-5.12.0-py3-none-any.whl", hash = "sha256:f84c2818376e66cf843d497486ea8fed8700b340f308f076c6fb1229dff318b6", size = 91198 }, + { url = "https://files.pythonhosted.org/packages/24/31/2e254fe2fc23201775a7d0ccd1bcde892cfa349eb805744b81b15e0dcf74/jaxlib-0.4.34-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:b7a212a3cb5c6acc201c32ae4f4b5f5a9ac09457fbb77ba8db5ce7e7d4adc214", size = 87399257 }, + { url = "https://files.pythonhosted.org/packages/1e/67/6a344c357caad33e84b871925cd043b4218fc13a427266d1a1dedcb1c095/jaxlib-0.4.34-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:45d719a2ce0ebf21255a277b71d756f3609b7b5be70cddc5d88fd58c35219de0", size = 67617952 }, + { url = "https://files.pythonhosted.org/packages/dd/ea/12c836126419ca80248228f2236831617eedb1e3640c34c942606f33bb08/jaxlib-0.4.34-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:3e60bc826933082e99b19b87c21818a8d26fcdb01f418d47cedff554746fd6cc", size = 69391770 }, + { url = "https://files.pythonhosted.org/packages/e4/b0/a5bd34643c070e50829beec217189eab1acdfea334df1f9ddb4e5f8bec0f/jaxlib-0.4.34-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232", size = 86094116 }, + { url = "https://files.pythonhosted.org/packages/d8/c9/35a4233fe74ddd5aabe89aac1b3992b0e463982564252d21fd263d4d9992/jaxlib-0.4.34-cp310-cp310-win_amd64.whl", hash = "sha256:b0001c8f0e2b1c7bc99e4f314b524a340d25653505c1a1484d4041a9d3617f6f", size = 55206389 }, + { url = "https://files.pythonhosted.org/packages/bf/14/00a3385532d72ab51bd8e9f8c3e19a2e257667955565e9fc10236771dd06/jaxlib-0.4.34-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:8ee3f93836e53c86556ccd9449a4ea43516ee05184d031a71dd692e81259f7d9", size = 87420889 }, + { url = 
"https://files.pythonhosted.org/packages/66/78/d1535ee73fe505dc6c8831c19c4846afdce7df5acefb9f8ee885aa73d700/jaxlib-0.4.34-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8", size = 67635880 }, + { url = "https://files.pythonhosted.org/packages/aa/06/3e09e794acf308e170905d732eca0d041449503c47505cc22e8ef78a989d/jaxlib-0.4.34-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:571ef03259835458111596a71a2f4a6fabf4ec34595df4cea555035362ac5bf0", size = 69421901 }, + { url = "https://files.pythonhosted.org/packages/c7/d0/6bc81c0b1d507f403e6085ce76a429e6d7f94749d742199252e299dd1424/jaxlib-0.4.34-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:3bcfa639ca3cfaf86c8ceebd5fc0d47300fd98a078014a1d0cc03133e1523d5f", size = 86114491 }, + { url = "https://files.pythonhosted.org/packages/9d/5d/7e71019af5f6fdebe6c10eab97d01f44b931d94609330da9e142cb155f8c/jaxlib-0.4.34-cp311-cp311-win_amd64.whl", hash = "sha256:133070d4fec5525ffea4dc72956398c1cf647a04dcb37f8a935ee82af78d9965", size = 55241262 }, + { url = "https://files.pythonhosted.org/packages/bc/42/5038983664494dfb50f8669a662d965d7ea62f9250e40d8cd36dcf9ac3dd/jaxlib-0.4.34-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10", size = 87473956 }, + { url = "https://files.pythonhosted.org/packages/87/2e/8a75d3107c019c370c50c01acc205da33f9d6fba830950401a772a8e9f6d/jaxlib-0.4.34-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:096f0ca309d41fa692a9d1f2f9baab1c5c8ca0749876ebb3f748e738a27c7ff4", size = 67650276 }, + { url = "https://files.pythonhosted.org/packages/af/09/cceae2d251a506b4297679d10ee9f5e905a6b992b0687d553c9470ffd1db/jaxlib-0.4.34-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:1a30771d85fa77f9ab8f18e63240f455ab3a3f87660ed7b8d5eea6ceecbe5c1e", size = 69431284 }, + { url = 
"https://files.pythonhosted.org/packages/e7/0d/4faf839e3c8ce2a5b615df64427be3e870899c72c0ebfb5859348150aba1/jaxlib-0.4.34-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:48272e9034ff868d4328cf0055a07882fd2be93f59dfb6283af7de491f9d1290", size = 86151183 }, + { url = "https://files.pythonhosted.org/packages/a4/bc/a38f99071fca6cc31ae949e508a23b0de5de559da594443bb625a1adb8f3/jaxlib-0.4.34-cp312-cp312-win_amd64.whl", hash = "sha256:901cb4040ed24eae40071d8114ea8d10dff436277fa74a1a5b9e7206f641151c", size = 55278745 }, ] [[package]] @@ -851,41 +853,41 @@ name = "jinja2" version = "3.1.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "markupsafe", version = "2.1.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "markupsafe", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "markupsafe", version = "2.1.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "markupsafe", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/af/92/b3130cbbf5591acf9ade8708c365f3238046ac7cb8ccba6e81abccb0ccff/jinja2-3.1.5.tar.gz", hash = "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb", size = 244674 } wheels = [ { url = "https://files.pythonhosted.org/packages/bd/0f/2ba5fbcd631e3e88689309dbe978c5769e883e4b84ebfe7da30b43275c5a/jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb", size = 134596 }, ] -[[package]] -name = "joblib" -version = "1.4.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/64/33/60135848598c076ce4b231e1b1895170f45fbcaeaa2c9d5e38b04db70c35/joblib-1.4.2.tar.gz", hash = 
"sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e", size = 2116621 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 }, -] - [[package]] name = "kamae" source = { editable = "." } dependencies = [ { name = "dill" }, - { name = "joblib" }, + { name = "keras" }, { name = "keras-tuner" }, { name = "networkx" }, - { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pandas", version = "1.5.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, { name = "pandas", version = "2.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, { name = "pyfarmhash" }, { name = "pyspark" }, - { name = "scikit-learn", version = "1.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "scikit-learn", version = "1.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "tensorflow", version = "2.11.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "tensorflow", version = "2.16.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "tensorflow" }, +] + 
+[package.optional-dependencies] +jax = [ + { name = "jax", version = "0.4.30", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "jax", version = "0.4.34", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "jaxlib", version = "0.4.30", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "jaxlib", version = "0.4.34", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] +torch = [ + { name = "torch", version = "2.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] [package.dev-dependencies] @@ -900,10 +902,9 @@ dev = [ { name = "mkdocs-literate-nav" }, { name = "mkdocs-material" }, { name = "mkdocs-section-index" }, - { name = "mkdocstrings", version = "0.26.1", source = { registry = "https://pypi.org/simple" }, extra = ["python"], marker = "python_full_version < '3.9'" }, - { name = "mkdocstrings", version = "0.28.2", source = { registry = "https://pypi.org/simple" }, extra = ["python"], marker = "python_full_version >= '3.9'" }, - { name = "pre-commit", version = "3.5.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "pre-commit", version = "3.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "mkdocstrings", extra = ["python"] }, + { name = "pre-commit", version = "3.5.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "pre-commit", version = "3.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pylint" }, { name = "pytest" }, { name = 
"pytest-cov" }, @@ -914,15 +915,17 @@ dev = [ [package.metadata] requires-dist = [ { name = "dill", specifier = ">=0.3.0,<1.0.0" }, - { name = "joblib", specifier = ">=1.0.0,<2.0.0" }, - { name = "keras-tuner", specifier = ">=1.0.4,<2.0.0" }, + { name = "jax", marker = "extra == 'jax'", specifier = ">=0.4.0" }, + { name = "jaxlib", marker = "extra == 'jax'", specifier = ">=0.4.0" }, + { name = "keras", specifier = ">=3.0.0,<4.0.0" }, + { name = "keras-tuner", specifier = ">=1.4.0,<2.0.0" }, { name = "networkx", specifier = ">=2.6.3,<3.0.0" }, { name = "numpy", specifier = ">=1.22.0,<2.0.0" }, { name = "pandas", specifier = ">=1.3.4,<3.0.0" }, { name = "pyfarmhash", specifier = ">=0.3.2,<0.4.0" }, { name = "pyspark", specifier = ">=3.4.0,<4.0.0" }, - { name = "scikit-learn", specifier = ">=1.0.0,<2.0.0" }, - { name = "tensorflow", specifier = ">=2.9.1,<2.19.0" }, + { name = "tensorflow", specifier = ">=2.16.0,<3.0.0" }, + { name = "torch", marker = "extra == 'torch'", specifier = ">=2.0.0" }, ] [package.metadata.requires-dev] @@ -946,36 +949,21 @@ dev = [ { name = "python-semantic-release", specifier = ">=8.0.0,<9" }, ] -[[package]] -name = "keras" -version = "2.11.0" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.9'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/de/44/bf1b0eef5b13e6201aef076ff34b91bc40aace8591cd273c1c2a94a9cc00/keras-2.11.0-py2.py3-none-any.whl", hash = "sha256:38c6fff0ea9a8b06a2717736565c92a73c8cd9b1c239e7125ccb188b7848f65e", size = 1685489 }, -] - [[package]] name = "keras" version = "3.8.0" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", -] dependencies = [ - { name = "absl-py", marker = "python_full_version >= '3.9'" }, - { name = "h5py", version = "3.13.0", source = { registry = 
"https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "ml-dtypes", marker = "python_full_version >= '3.9'" }, - { name = "namex", marker = "python_full_version >= '3.9'" }, - { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "optree", marker = "python_full_version >= '3.9'" }, - { name = "packaging", marker = "python_full_version >= '3.9'" }, - { name = "rich", marker = "python_full_version >= '3.9'" }, + { name = "absl-py" }, + { name = "h5py", version = "3.11.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "h5py", version = "3.13.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "ml-dtypes" }, + { name = "namex" }, + { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "optree" }, + { name = "packaging" }, + { name = "rich" }, ] sdist = { url = "https://files.pythonhosted.org/packages/cd/97/8b0b420e14008100a330d30e78df9bce04fd1845edc5d29b0a6f4d8ad061/keras-3.8.0.tar.gz", hash = "sha256:6289006e6f6cb2b68a563b58cf8ae5a45569449c5a791df6b2f54c1877f3f344", size = 975959 } wheels = [ @@ -987,8 +975,7 @@ name = "keras-tuner" version = "1.4.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "keras", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "keras", version = "3.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "keras" }, { name = "kt-legacy" }, { name = "packaging" }, { name = "requests" }, @@ -1028,10 +1015,6 @@ wheels = [ name = "markdown" 
version = "3.7" source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "importlib-metadata", version = "8.5.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "importlib-metadata", version = "8.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, -] sdist = { url = "https://files.pythonhosted.org/packages/54/28/3af612670f82f4c056911fbbbb42760255801b3068c48de792d354ff4472/markdown-3.7.tar.gz", hash = "sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2", size = 357086 } wheels = [ { url = "https://files.pythonhosted.org/packages/3f/08/83871f3c50fc983b88547c196d11cf8c3340e37c32d2e9d6152abe2c61f7/Markdown-3.7-py3-none-any.whl", hash = "sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803", size = 106349 }, @@ -1054,7 +1037,7 @@ name = "markupsafe" version = "2.1.5" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] sdist = { url = "https://files.pythonhosted.org/packages/87/5b/aae44c6655f3801e81aa3eef09dbbf012431987ba564d7231722f68df02d/MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b", size = 19384 } wheels = [ @@ -1088,26 +1071,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/07/2dc76aa51b481eb96a4c3198894f38b480490e834479611a4053fbf08623/MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169", size = 33038 }, { url = "https://files.pythonhosted.org/packages/96/0c/620c1fb3661858c0e37eb3cbffd8c6f732a67cd97296f725789679801b31/MarkupSafe-2.1.5-cp312-cp312-win32.whl", hash = "sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad", size = 16572 }, { url = 
"https://files.pythonhosted.org/packages/3f/14/c3554d512d5f9100a95e737502f4a2323a1959f6d0d01e0d0997b35f7b10/MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb", size = 17127 }, - { url = "https://files.pythonhosted.org/packages/f8/ff/2c942a82c35a49df5de3a630ce0a8456ac2969691b230e530ac12314364c/MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a", size = 18192 }, - { url = "https://files.pythonhosted.org/packages/4f/14/6f294b9c4f969d0c801a4615e221c1e084722ea6114ab2114189c5b8cbe0/MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46", size = 14072 }, - { url = "https://files.pythonhosted.org/packages/81/d4/fd74714ed30a1dedd0b82427c02fa4deec64f173831ec716da11c51a50aa/MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532", size = 26928 }, - { url = "https://files.pythonhosted.org/packages/c7/bd/50319665ce81bb10e90d1cf76f9e1aa269ea6f7fa30ab4521f14d122a3df/MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab", size = 26106 }, - { url = "https://files.pythonhosted.org/packages/4c/6f/f2b0f675635b05f6afd5ea03c094557bdb8622fa8e673387444fe8d8e787/MarkupSafe-2.1.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68", size = 25781 }, - { url = "https://files.pythonhosted.org/packages/51/e0/393467cf899b34a9d3678e78961c2c8cdf49fb902a959ba54ece01273fb1/MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0", size = 30518 }, - { url = 
"https://files.pythonhosted.org/packages/f6/02/5437e2ad33047290dafced9df741d9efc3e716b75583bbd73a9984f1b6f7/MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4", size = 29669 }, - { url = "https://files.pythonhosted.org/packages/0e/7d/968284145ffd9d726183ed6237c77938c021abacde4e073020f920e060b2/MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3", size = 29933 }, - { url = "https://files.pythonhosted.org/packages/bf/f3/ecb00fc8ab02b7beae8699f34db9357ae49d9f21d4d3de6f305f34fa949e/MarkupSafe-2.1.5-cp38-cp38-win32.whl", hash = "sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff", size = 16656 }, - { url = "https://files.pythonhosted.org/packages/92/21/357205f03514a49b293e214ac39de01fadd0970a6e05e4bf1ddd0ffd0881/MarkupSafe-2.1.5-cp38-cp38-win_amd64.whl", hash = "sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029", size = 17206 }, - { url = "https://files.pythonhosted.org/packages/0f/31/780bb297db036ba7b7bbede5e1d7f1e14d704ad4beb3ce53fb495d22bc62/MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf", size = 18193 }, - { url = "https://files.pythonhosted.org/packages/6c/77/d77701bbef72892affe060cdacb7a2ed7fd68dae3b477a8642f15ad3b132/MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2", size = 14073 }, - { url = "https://files.pythonhosted.org/packages/d9/a7/1e558b4f78454c8a3a0199292d96159eb4d091f983bc35ef258314fe7269/MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8", size = 26486 }, - { url = 
"https://files.pythonhosted.org/packages/5f/5a/360da85076688755ea0cceb92472923086993e86b5613bbae9fbc14136b0/MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3", size = 25685 }, - { url = "https://files.pythonhosted.org/packages/6a/18/ae5a258e3401f9b8312f92b028c54d7026a97ec3ab20bfaddbdfa7d8cce8/MarkupSafe-2.1.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465", size = 25338 }, - { url = "https://files.pythonhosted.org/packages/0b/cc/48206bd61c5b9d0129f4d75243b156929b04c94c09041321456fd06a876d/MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e", size = 30439 }, - { url = "https://files.pythonhosted.org/packages/d1/06/a41c112ab9ffdeeb5f77bc3e331fdadf97fa65e52e44ba31880f4e7f983c/MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea", size = 29531 }, - { url = "https://files.pythonhosted.org/packages/02/8c/ab9a463301a50dab04d5472e998acbd4080597abc048166ded5c7aa768c8/MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6", size = 29823 }, - { url = "https://files.pythonhosted.org/packages/bc/29/9bc18da763496b055d8e98ce476c8e718dcfd78157e17f555ce6dd7d0895/MarkupSafe-2.1.5-cp39-cp39-win32.whl", hash = "sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf", size = 16658 }, - { url = "https://files.pythonhosted.org/packages/f6/f8/4da07de16f10551ca1f640c92b5f316f9394088b183c6a57183df6de5ae4/MarkupSafe-2.1.5-cp39-cp39-win_amd64.whl", hash = "sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5", size = 17211 }, ] [[package]] @@ -1117,8 +1080,6 @@ source = { registry = 
"https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/b2/97/5d42485e71dfc078108a86d6de8fa46db44a1a9295e89c5d6d4a06e23a62/markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0", size = 20537 } wheels = [ @@ -1152,16 +1113,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/82/8be4c96ffee03c5b4a034e60a31294daf481e12c7c43ab8e34a1453ee48b/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48", size = 23352 }, { url = "https://files.pythonhosted.org/packages/51/ae/97827349d3fcffee7e184bdf7f41cd6b88d9919c80f0263ba7acd1bbcb18/MarkupSafe-3.0.2-cp312-cp312-win32.whl", hash = "sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30", size = 15097 }, { url = "https://files.pythonhosted.org/packages/c1/80/a61f99dc3a936413c3ee4e1eecac96c0da5ed07ad56fd975f1a9da5bc630/MarkupSafe-3.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87", size = 15601 }, - { url = "https://files.pythonhosted.org/packages/a7/ea/9b1530c3fdeeca613faeb0fb5cbcf2389d816072fab72a71b45749ef6062/MarkupSafe-3.0.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:eaa0a10b7f72326f1372a713e73c3f739b524b3af41feb43e4921cb529f5929a", size = 14344 }, - { url = "https://files.pythonhosted.org/packages/4b/c2/fbdbfe48848e7112ab05e627e718e854d20192b674952d9042ebd8c9e5de/MarkupSafe-3.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:48032821bbdf20f5799ff537c7ac3d1fba0ba032cfc06194faffa8cda8b560ff", size = 12389 }, - { url = 
"https://files.pythonhosted.org/packages/f0/25/7a7c6e4dbd4f867d95d94ca15449e91e52856f6ed1905d58ef1de5e211d0/MarkupSafe-3.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a9d3f5f0901fdec14d8d2f66ef7d035f2157240a433441719ac9a3fba440b13", size = 21607 }, - { url = "https://files.pythonhosted.org/packages/53/8f/f339c98a178f3c1e545622206b40986a4c3307fe39f70ccd3d9df9a9e425/MarkupSafe-3.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88b49a3b9ff31e19998750c38e030fc7bb937398b1f78cfa599aaef92d693144", size = 20728 }, - { url = "https://files.pythonhosted.org/packages/1a/03/8496a1a78308456dbd50b23a385c69b41f2e9661c67ea1329849a598a8f9/MarkupSafe-3.0.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cfad01eed2c2e0c01fd0ecd2ef42c492f7f93902e39a42fc9ee1692961443a29", size = 20826 }, - { url = "https://files.pythonhosted.org/packages/e6/cf/0a490a4bd363048c3022f2f475c8c05582179bb179defcee4766fb3dcc18/MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1225beacc926f536dc82e45f8a4d68502949dc67eea90eab715dea3a21c1b5f0", size = 21843 }, - { url = "https://files.pythonhosted.org/packages/19/a3/34187a78613920dfd3cdf68ef6ce5e99c4f3417f035694074beb8848cd77/MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:3169b1eefae027567d1ce6ee7cae382c57fe26e82775f460f0b2778beaad66c0", size = 21219 }, - { url = "https://files.pythonhosted.org/packages/17/d8/5811082f85bb88410ad7e452263af048d685669bbbfb7b595e8689152498/MarkupSafe-3.0.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:eb7972a85c54febfb25b5c4b4f3af4dcc731994c7da0d8a0b4a6eb0640e1d178", size = 20946 }, - { url = "https://files.pythonhosted.org/packages/7c/31/bd635fb5989440d9365c5e3c47556cfea121c7803f5034ac843e8f37c2f2/MarkupSafe-3.0.2-cp39-cp39-win32.whl", hash = "sha256:8c4e8c3ce11e1f92f6536ff07154f9d49677ebaaafc32db9db4620bc11ed480f", size = 15063 }, - { url = 
"https://files.pythonhosted.org/packages/b3/73/085399401383ce949f727afec55ec3abd76648d04b9f22e1c0e99cb4bec3/MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a", size = 15506 }, ] [[package]] @@ -1199,20 +1150,18 @@ dependencies = [ { name = "click" }, { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "ghp-import" }, - { name = "importlib-metadata", version = "8.5.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "importlib-metadata", version = "8.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, { name = "jinja2" }, { name = "markdown" }, - { name = "markupsafe", version = "2.1.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "markupsafe", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "markupsafe", version = "2.1.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "markupsafe", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "mergedeep" }, { name = "mkdocs-get-deps" }, { name = "packaging" }, { name = "pathspec" }, { name = "pyyaml" }, { name = "pyyaml-env-tag" }, - { name = "watchdog", version = "4.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "watchdog", version = "6.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "watchdog", version = "4.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "watchdog", version = "6.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist 
= { url = "https://files.pythonhosted.org/packages/bc/c6/bbd4f061bd16b378247f12953ffcb04786a618ce5e904b8c5a01a0309061/mkdocs-1.6.1.tar.gz", hash = "sha256:7b432f01d928c084353ab39c57282f29f92136665bdd6abf7c1ec8d822ef86f2", size = 3889159 } wheels = [ @@ -1224,12 +1173,12 @@ name = "mkdocs-autorefs" version = "1.2.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] dependencies = [ - { name = "markdown", marker = "python_full_version < '3.9'" }, - { name = "markupsafe", version = "2.1.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "mkdocs", marker = "python_full_version < '3.9'" }, + { name = "markdown", marker = "python_full_version < '3.11'" }, + { name = "markupsafe", version = "2.1.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "mkdocs", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/fb/ae/0f1154c614d6a8b8a36fff084e5b82af3a15f7d2060cf0dcdb1c53297a71/mkdocs_autorefs-1.2.0.tar.gz", hash = "sha256:a86b93abff653521bda71cf3fc5596342b7a23982093915cb74273f67522190f", size = 40262 } wheels = [ @@ -1243,13 +1192,11 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] dependencies = [ - { name = "markdown", marker = "python_full_version >= '3.9'" }, - { name = "markupsafe", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "mkdocs", marker = "python_full_version >= '3.9'" }, + { name = "markdown", marker = "python_full_version >= '3.11'" }, + { name = "markupsafe", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= 
'3.11'" }, + { name = "mkdocs", marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/83/79/e846eb3323d1546b25d2ae4c957f5edf1bdfb7e0b695d43feae034c61553/mkdocs_autorefs-1.4.0.tar.gz", hash = "sha256:a9c0aa9c90edbce302c09d050a3c4cb7c76f8b7b2c98f84a7a05f53d00392156", size = 3128903 } wheels = [ @@ -1273,8 +1220,6 @@ name = "mkdocs-get-deps" version = "0.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "importlib-metadata", version = "8.5.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "importlib-metadata", version = "8.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, { name = "mergedeep" }, { name = "platformdirs" }, { name = "pyyaml" }, @@ -1343,20 +1288,17 @@ wheels = [ name = "mkdocstrings" version = "0.26.1" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.9'", -] dependencies = [ - { name = "click", marker = "python_full_version < '3.9'" }, - { name = "importlib-metadata", version = "8.5.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "jinja2", marker = "python_full_version < '3.9'" }, - { name = "markdown", marker = "python_full_version < '3.9'" }, - { name = "markupsafe", version = "2.1.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "mkdocs", marker = "python_full_version < '3.9'" }, - { name = "mkdocs-autorefs", version = "1.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "platformdirs", marker = "python_full_version < '3.9'" }, - { name = "pymdown-extensions", marker = "python_full_version < '3.9'" }, - { name = "typing-extensions", marker = "python_full_version < '3.9'" }, + { name = "click" }, + { name = "jinja2" }, + { name = 
"markdown" }, + { name = "markupsafe", version = "2.1.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "markupsafe", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "mkdocs" }, + { name = "mkdocs-autorefs", version = "1.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "mkdocs-autorefs", version = "1.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "platformdirs" }, + { name = "pymdown-extensions" }, ] sdist = { url = "https://files.pythonhosted.org/packages/e6/bf/170ff04de72227f715d67da32950c7b8434449f3805b2ec3dd1085db4d7c/mkdocstrings-0.26.1.tar.gz", hash = "sha256:bb8b8854d6713d5348ad05b069a09f3b79edbc6a0f33a34c6821141adb03fe33", size = 92677 } wheels = [ @@ -1365,84 +1307,32 @@ wheels = [ [package.optional-dependencies] python = [ - { name = "mkdocstrings-python", version = "1.11.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, -] - -[[package]] -name = "mkdocstrings" -version = "0.28.2" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", -] -dependencies = [ - { name = "importlib-metadata", version = "8.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, - { name = "jinja2", marker = "python_full_version >= '3.9'" }, - { name = "markdown", marker = "python_full_version >= '3.9'" }, - { name = "markupsafe", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "mkdocs", marker = "python_full_version >= '3.9'" }, - { name = "mkdocs-autorefs", version = "1.4.0", source = { 
registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "mkdocs-get-deps", marker = "python_full_version >= '3.9'" }, - { name = "pymdown-extensions", marker = "python_full_version >= '3.9'" }, - { name = "typing-extensions", marker = "python_full_version == '3.9.*'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/e8/83/5eab81d31953c725942eb663b6a4cf36ad06d803633c8e1c6ddc708af62d/mkdocstrings-0.28.2.tar.gz", hash = "sha256:9b847266d7a588ea76a8385eaebe1538278b4361c0d1ce48ed005be59f053569", size = 5691916 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/32/60/15ef9759431cf8e60ffda7d5bba3914cc764f2bd8e7f62e1bd301ea292e0/mkdocstrings-0.28.2-py3-none-any.whl", hash = "sha256:57f79c557e2718d217d6f6a81bf75a0de097f10e922e7e5e00f085c3f0ff6895", size = 8056702 }, -] - -[package.optional-dependencies] -python = [ - { name = "mkdocstrings-python", version = "1.16.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "mkdocstrings-python" }, ] [[package]] name = "mkdocstrings-python" version = "1.11.1" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.9'", -] dependencies = [ - { name = "griffe", version = "1.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "mkdocs-autorefs", version = "1.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "mkdocstrings", version = "0.26.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "griffe", version = "1.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "griffe", version = "1.5.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "mkdocs-autorefs", version = "1.2.0", source = 
{ registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "mkdocs-autorefs", version = "1.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "mkdocstrings" }, ] sdist = { url = "https://files.pythonhosted.org/packages/fc/ba/534c934cd0a809f51c91332d6ed278782ee4126b8ba8db02c2003f162b47/mkdocstrings_python-1.11.1.tar.gz", hash = "sha256:8824b115c5359304ab0b5378a91f6202324a849e1da907a3485b59208b797322", size = 166890 } wheels = [ { url = "https://files.pythonhosted.org/packages/2f/f2/2a2c48fda645ac6bbe73bcc974587a579092b6868e6ff8bc6d177f4db38a/mkdocstrings_python-1.11.1-py3-none-any.whl", hash = "sha256:a21a1c05acef129a618517bb5aae3e33114f569b11588b1e7af3e9d4061a71af", size = 109297 }, ] -[[package]] -name = "mkdocstrings-python" -version = "1.16.2" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", -] -dependencies = [ - { name = "griffe", version = "1.5.7", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "mkdocs-autorefs", version = "1.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "mkdocstrings", version = "0.28.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "typing-extensions", marker = "python_full_version >= '3.9' and python_full_version < '3.11'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ed/a9/5990642e1bb2d90b049f655b92f46d0a77acb76ed59ef3233d5a6934312e/mkdocstrings_python-1.16.2.tar.gz", hash = "sha256:942ec1a2e0481d28f96f93be3d6e343cab92a21e5baf01c37dd2d7236c4d0bd7", size = 423492 } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/82/a2/60be7e17a2f2a9d4bfb7273cdb2071eeeb65bdca5c0d07ff16df63221ca2/mkdocstrings_python-1.16.2-py3-none-any.whl", hash = "sha256:ff7e719404e59ad1a72f1afbe854769984c889b8fa043c160f6c988e1ad9e966", size = 449141 }, -] - [[package]] name = "ml-dtypes" version = "0.3.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/39/7d/8d85fcba868758b3a546e6914e727abd8f29ea6918079f816975c9eecd63/ml_dtypes-0.3.2.tar.gz", hash = "sha256:533059bc5f1764fac071ef54598db358c167c51a718f68f5bb55e3dee79d2967", size = 692014 } wheels = [ @@ -1458,10 +1348,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/05/ec30199c791cf0d788a26f56d8efb8ee4133ede79a9680fd8cc05e706404/ml_dtypes-0.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a17ef2322e60858d93584e9c52a5be7dd6236b056b7fa1ec57f1bb6ba043e33", size = 2180925 }, { url = "https://files.pythonhosted.org/packages/e5/f1/93219c44bae4017e6e43391fa4433592de08e05def9d885227d3596f21a5/ml_dtypes-0.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8505946df1665db01332d885c2020b4cb9e84a8b1241eb4ba69d59591f65855", size = 2160573 }, { url = "https://files.pythonhosted.org/packages/47/f3/847da54c3d243ff2aa778078ecf09da199194d282744718ef325dd8afd41/ml_dtypes-0.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:f47619d978ab1ae7dfdc4052ea97c636c6263e1f19bd1be0e42c346b98d15ff4", size = 128649 }, - { url = 
"https://files.pythonhosted.org/packages/7b/bb/4513133bccda7e66eb56ee38f68d1a8bbc81f072d00a40ee369c43f25ba9/ml_dtypes-0.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c7b3fb3d4f6b39bcd4f6c4b98f406291f0d681a895490ee29a0f95bab850d53c", size = 389810 }, - { url = "https://files.pythonhosted.org/packages/ea/58/c56da71b1d9f9c6c1e61f63d27f901c3526e13da0589ec2ff993e9a72c04/ml_dtypes-0.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a4c3fcbf86fa52d0204f07cfd23947ef05b4ad743a1a988e163caa34a201e5e", size = 2180720 }, - { url = "https://files.pythonhosted.org/packages/86/29/b389f235add26220bc7b7f100362f4e3a84e14f7c837abd34a11347df1b0/ml_dtypes-0.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91f8783fd1f2c23fd3b9ee5ad66b785dafa58ba3cdb050c4458021fa4d1eb226", size = 2158181 }, - { url = "https://files.pythonhosted.org/packages/38/3c/5d058a50340759423b25cb99f930cb3691fc30ebe86d53fdf1bff55c2d71/ml_dtypes-0.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:7ba8e1fafc7fff3e643f453bffa7d082df1678a73286ce8187d3e825e776eb94", size = 127704 }, +] + +[[package]] +name = "mpmath" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198 }, ] [[package]] @@ -1505,7 +1400,7 @@ name = "numpy" version = "1.24.4" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] sdist = { url = 
"https://files.pythonhosted.org/packages/a4/9b/027bec52c633f6556dba6b722d9a0befb40498b9ceddd29cbe67a45a127c/numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463", size = 10911229 } wheels = [ @@ -1521,21 +1416,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/22/97/dfb1a31bb46686f09e68ea6ac5c63fdee0d22d7b23b8f3f7ea07712869ef/numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5", size = 17278923 }, { url = "https://files.pythonhosted.org/packages/35/e2/76a11e54139654a324d107da1d98f99e7aa2a7ef97cfd7c631fba7dbde71/numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d", size = 12422446 }, { url = "https://files.pythonhosted.org/packages/d8/ec/ebef2f7d7c28503f958f0f8b992e7ce606fb74f9e891199329d5f5f87404/numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694", size = 14834466 }, - { url = "https://files.pythonhosted.org/packages/11/10/943cfb579f1a02909ff96464c69893b1d25be3731b5d3652c2e0cf1281ea/numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61", size = 19780722 }, - { url = "https://files.pythonhosted.org/packages/a7/ae/f53b7b265fdc701e663fbb322a8e9d4b14d9cb7b2385f45ddfabfc4327e4/numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f", size = 13843102 }, - { url = "https://files.pythonhosted.org/packages/25/6f/2586a50ad72e8dbb1d8381f837008a0321a3516dfd7cb57fc8cf7e4bb06b/numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e", size = 14039616 }, - { url = 
"https://files.pythonhosted.org/packages/98/5d/5738903efe0ecb73e51eb44feafba32bdba2081263d40c5043568ff60faf/numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc", size = 17316263 }, - { url = "https://files.pythonhosted.org/packages/d1/57/8d328f0b91c733aa9aa7ee540dbc49b58796c862b4fbcb1146c701e888da/numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2", size = 12455660 }, - { url = "https://files.pythonhosted.org/packages/69/65/0d47953afa0ad569d12de5f65d964321c208492064c38fe3b0b9744f8d44/numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706", size = 14868112 }, - { url = "https://files.pythonhosted.org/packages/9a/cd/d5b0402b801c8a8b56b04c1e85c6165efab298d2f0ab741c2406516ede3a/numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400", size = 19816549 }, - { url = "https://files.pythonhosted.org/packages/14/27/638aaa446f39113a3ed38b37a66243e21b38110d021bfcb940c383e120f2/numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f", size = 13879950 }, - { url = "https://files.pythonhosted.org/packages/8f/27/91894916e50627476cff1a4e4363ab6179d01077d71b9afed41d9e1f18bf/numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9", size = 14030228 }, - { url = "https://files.pythonhosted.org/packages/7a/7c/d7b2a0417af6428440c0ad7cb9799073e507b1a465f827d058b826236964/numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d", size = 17311170 }, - { url = 
"https://files.pythonhosted.org/packages/18/9d/e02ace5d7dfccee796c37b995c63322674daf88ae2f4a4724c5dd0afcc91/numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835", size = 12454918 }, - { url = "https://files.pythonhosted.org/packages/63/38/6cc19d6b8bfa1d1a459daf2b3fe325453153ca7019976274b6f33d8b5663/numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8", size = 14867441 }, - { url = "https://files.pythonhosted.org/packages/a4/fd/8dff40e25e937c94257455c237b9b6bf5a30d42dd1cc11555533be099492/numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef", size = 19156590 }, - { url = "https://files.pythonhosted.org/packages/42/e7/4bf953c6e05df90c6d351af69966384fed8e988d0e8c54dad7103b59f3ba/numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a", size = 16705744 }, - { url = "https://files.pythonhosted.org/packages/fc/dd/9106005eb477d022b60b3817ed5937a43dad8fd1f20b0610ea8a32fcb407/numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2", size = 14734290 }, ] [[package]] @@ -1545,8 +1425,6 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", size = 15786129 } wheels = [ @@ -1574,26 +1452,281 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/76/8c/2ba3902e1a0fc1c74962ea9bb33a534bb05984ad7ff9515bf8d07527cadd/numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0", size = 17786643 }, { url = "https://files.pythonhosted.org/packages/28/4a/46d9e65106879492374999e76eb85f87b15328e06bd1550668f79f7b18c6/numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110", size = 5677803 }, { url = "https://files.pythonhosted.org/packages/16/2e/86f24451c2d530c88daf997cb8d6ac622c1d40d19f5a031ed68a4b73a374/numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818", size = 15517754 }, - { url = "https://files.pythonhosted.org/packages/7d/24/ce71dc08f06534269f66e73c04f5709ee024a1afe92a7b6e1d73f158e1f8/numpy-1.26.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c", size = 20636301 }, - { url = "https://files.pythonhosted.org/packages/ae/8c/ab03a7c25741f9ebc92684a20125fbc9fc1b8e1e700beb9197d750fdff88/numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be", size = 13971216 }, - { url = "https://files.pythonhosted.org/packages/6d/64/c3bcdf822269421d85fe0d64ba972003f9bb4aa9a419da64b86856c9961f/numpy-1.26.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764", size = 14226281 }, - { url = "https://files.pythonhosted.org/packages/54/30/c2a907b9443cf42b90c17ad10c1e8fa801975f01cb9764f3f8eb8aea638b/numpy-1.26.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3", size = 18249516 }, - { url = 
"https://files.pythonhosted.org/packages/43/12/01a563fc44c07095996d0129b8899daf89e4742146f7044cdbdb3101c57f/numpy-1.26.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd", size = 13882132 }, - { url = "https://files.pythonhosted.org/packages/16/ee/9df80b06680aaa23fc6c31211387e0db349e0e36d6a63ba3bd78c5acdf11/numpy-1.26.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c", size = 18084181 }, - { url = "https://files.pythonhosted.org/packages/28/7d/4b92e2fe20b214ffca36107f1a3e75ef4c488430e64de2d9af5db3a4637d/numpy-1.26.4-cp39-cp39-win32.whl", hash = "sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6", size = 5976360 }, - { url = "https://files.pythonhosted.org/packages/b5/42/054082bd8220bbf6f297f982f0a8f5479fcbc55c8b511d928df07b965869/numpy-1.26.4-cp39-cp39-win_amd64.whl", hash = "sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea", size = 15814633 }, - { url = "https://files.pythonhosted.org/packages/3f/72/3df6c1c06fc83d9cfe381cccb4be2532bbd38bf93fbc9fad087b6687f1c0/numpy-1.26.4-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30", size = 20455961 }, - { url = "https://files.pythonhosted.org/packages/8e/02/570545bac308b58ffb21adda0f4e220ba716fb658a63c151daecc3293350/numpy-1.26.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c", size = 18061071 }, - { url = "https://files.pythonhosted.org/packages/f4/5f/fafd8c51235f60d49f7a88e2275e13971e90555b67da52dd6416caec32fe/numpy-1.26.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0", size = 15709730 }, ] [[package]] -name = "oauthlib" -version = "3.2.2" +name = "nvidia-cublas" +version = "13.1.0.3" +source = { registry = 
"https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/a5/fce49e2ae977e0ccc084e5adafceb4f0ac0c8333cb6863501618a7277f67/nvidia_cublas-13.1.0.3-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c86fc7f7ae36d7528288c5d88098edcb7b02c633d262e7ddbb86b0ad91be5df2", size = 542851226 }, + { url = "https://files.pythonhosted.org/packages/e7/44/423ac00af4dd95a5aeb27207e2c0d9b7118702149bf4704c3ddb55bb7429/nvidia_cublas-13.1.0.3-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:ee8722c1f0145ab246bccb9e452153b5e0515fd094c3678df50b2a0888b8b171", size = 423133236 }, +] + +[[package]] +name = "nvidia-cublas-cu12" +version = "12.8.4.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921 }, +] + +[[package]] +name = "nvidia-cuda-cupti" +version = "13.0.85" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/2a/80353b103fc20ce05ef51e928daed4b6015db4aaa9162ed0997090fe2250/nvidia_cuda_cupti-13.0.85-py3-none-manylinux_2_25_aarch64.whl", hash = "sha256:796bd679890ee55fb14a94629b698b6db54bcfd833d391d5e94017dd9d7d3151", size = 10310827 }, + { url = "https://files.pythonhosted.org/packages/33/6d/737d164b4837a9bbd202f5ae3078975f0525a55730fe871d8ed4e3b952b0/nvidia_cuda_cupti-13.0.85-py3-none-manylinux_2_25_x86_64.whl", hash = "sha256:4eb01c08e859bf924d222250d2e8f8b8ff6d3db4721288cf35d14252a4d933c8", size = 10715597 }, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621 }, +] + +[[package]] +name = "nvidia-cuda-nvrtc" +version = "13.0.88" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/68/483a78f5e8f31b08fb1bb671559968c0ca3a065ac7acabfc7cee55214fd6/nvidia_cuda_nvrtc-13.0.88-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:ad9b6d2ead2435f11cbb6868809d2adeeee302e9bb94bcf0539c7a40d80e8575", size = 90215200 }, + { url = "https://files.pythonhosted.org/packages/b7/dc/6bb80850e0b7edd6588d560758f17e0550893a1feaf436807d64d2da040f/nvidia_cuda_nvrtc-13.0.88-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d27f20a0ca67a4bb34268a5e951033496c5b74870b868bacd046b1b8e0c3267b", size = 43015449 }, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.8.93" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6d/fa/fbf4001037904031639e6bfbfc02badfc7e12f137a8afa254df6c4c8a670/oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918", size = 177352 } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/80/cab10959dc1faead58dc8384a781dfbf93cb4d33d50988f7a69f1b7c9bbe/oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca", size = 151688 }, + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029 }, +] + +[[package]] +name 
= "nvidia-cuda-runtime" +version = "13.0.96" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/4f/17d7b9b8e285199c58ce28e31b5c5bbaa4d8271af06a89b6405258245de2/nvidia_cuda_runtime-13.0.96-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ef9bcbe90493a2b9d810e43d249adb3d02e98dd30200d86607d8d02687c43f55", size = 2261060 }, + { url = "https://files.pythonhosted.org/packages/2e/24/d1558f3b68b1d26e706813b1d10aa1d785e4698c425af8db8edc3dced472/nvidia_cuda_runtime-13.0.96-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7f82250d7782aa23b6cfe765ecc7db554bd3c2870c43f3d1821f1d18aebf0548", size = 2243632 }, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765 }, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.10.2.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "python_full_version < '3.11'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467 }, +] + +[[package]] +name = "nvidia-cudnn-cu13" +version = "9.19.0.56" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas", marker = "python_full_version >= '3.11'" }, +] +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/f1/84/26025437c1e6b61a707442184fa0c03d083b661adf3a3eecfd6d21677740/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:6ed29ffaee1176c612daf442e4dd6cfeb6a0caa43ddcbeb59da94953030b1be4", size = 433781201 }, + { url = "https://files.pythonhosted.org/packages/a3/22/0b4b932655d17a6da1b92fa92ab12844b053bb2ac2475e179ba6f043da1e/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:d20e1734305e9d68889a96e3f35094d733ff1f83932ebe462753973e53a572bf", size = 366066321 }, +] + +[[package]] +name = "nvidia-cufft" +version = "12.0.0.61" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink", marker = "python_full_version >= '3.11'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/ae/f417a75c0259e85c1d2f83ca4e960289a5f814ed0cea74d18c353d3e989d/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5", size = 214053554 }, + { url = "https://files.pythonhosted.org/packages/a8/2f/7b57e29836ea8714f81e9898409196f47d772d5ddedddf1592eadb8ab743/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6c44f692dce8fd5ffd3e3df134b6cdb9c2f72d99cf40b62c32dde45eea9ddad3", size = 214085489 }, +] + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.3.83" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "python_full_version < '3.11'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695 }, +] + +[[package]] +name = "nvidia-cufile" +version = "1.15.1.6" +source = 
{ registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/70/4f193de89a48b71714e74602ee14d04e4019ad36a5a9f20c425776e72cd6/nvidia_cufile-1.15.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08a3ecefae5a01c7f5117351c64f17c7c62efa5fffdbe24fc7d298da19cd0b44", size = 1223672 }, + { url = "https://files.pythonhosted.org/packages/ab/73/cc4a14c9813a8a0d509417cf5f4bdaba76e924d58beb9864f5a7baceefbf/nvidia_cufile-1.15.1.6-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:bdc0deedc61f548bddf7733bdc216456c2fdb101d020e1ab4b88d232d5e2f6d1", size = 1136992 }, +] + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.13.1.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834 }, +] + +[[package]] +name = "nvidia-curand" +version = "10.4.0.35" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/72/7c2ae24fb6b63a32e6ae5d241cc65263ea18d08802aaae087d9f013335a2/nvidia_curand-10.4.0.35-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:133df5a7509c3e292aaa2b477afd0194f06ce4ea24d714d616ff36439cee349a", size = 61962106 }, + { url = "https://files.pythonhosted.org/packages/a5/9f/be0a41ca4a4917abf5cb9ae0daff1a6060cc5de950aec0396de9f3b52bc5/nvidia_curand-10.4.0.35-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:1aee33a5da6e1db083fe2b90082def8915f30f3248d5896bcec36a579d941bfc", size = 59544258 }, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.9.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976 }, +] + +[[package]] +name = "nvidia-cusolver" +version = "12.0.4.66" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas", marker = "python_full_version >= '3.11'" }, + { name = "nvidia-cusparse", marker = "python_full_version >= '3.11'" }, + { name = "nvidia-nvjitlink", marker = "python_full_version >= '3.11'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/c3/b30c9e935fc01e3da443ec0116ed1b2a009bb867f5324d3f2d7e533e776b/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2", size = 223467760 }, + { url = "https://files.pythonhosted.org/packages/5f/67/cba3777620cdacb99102da4042883709c41c709f4b6323c10781a9c3aa34/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:0a759da5dea5c0ea10fd307de75cdeb59e7ea4fcb8add0924859b944babf1112", size = 200941980 }, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.3.90" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12", marker = "python_full_version < '3.11'" }, + { name = "nvidia-cusparse-cu12", marker = "python_full_version < '3.11'" }, + { name = "nvidia-nvjitlink-cu12", marker = "python_full_version < '3.11'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905 }, +] + +[[package]] +name = "nvidia-cusparse" +version = "12.6.3.3" +source = { registry = 
"https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink", marker = "python_full_version >= '3.11'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/94/5c26f33738ae35276672f12615a64bd008ed5be6d1ebcb23579285d960a9/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c", size = 162155568 }, + { url = "https://files.pythonhosted.org/packages/fa/18/623c77619c31d62efd55302939756966f3ecc8d724a14dab2b75f1508850/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2b3c89c88d01ee0e477cb7f82ef60a11a4bcd57b6b87c33f789350b59759360b", size = 145942937 }, +] + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.8.93" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "python_full_version < '3.11'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466 }, +] + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691 }, +] + +[[package]] +name = "nvidia-cusparselt-cu13" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/46/10/8dcd1175260706a2fc92a16a52e306b71d4c1ea0b0cc4a9484183399818a/nvidia_cusparselt_cu13-0.8.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:400c6ed1cf6780fc6efedd64ec9f1345871767e6a1a0a552a1ea0578117ea77c", size = 220791277 }, + { url = "https://files.pythonhosted.org/packages/fd/53/43b0d71f4e702fa9733f8b4571fdca50a8813f1e450b656c239beff12315/nvidia_cusparselt_cu13-0.8.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:25e30a8a7323935d4ad0340b95a0b69926eee755767e8e0b1cf8dd85b197d3fd", size = 169884119 }, +] + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.27.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134 }, +] + +[[package]] +name = "nvidia-nccl-cu13" +version = "2.28.9" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/55/1920646a2e43ffd4fc958536b276197ed740e9e0c54105b4bb3521591fc7/nvidia_nccl_cu13-2.28.9-py3-none-manylinux_2_18_aarch64.whl", hash = "sha256:01c873ba1626b54caa12272ed228dc5b2781545e0ae8ba3f432a8ef1c6d78643", size = 196561677 }, + { url = "https://files.pythonhosted.org/packages/b0/b4/878fefaad5b2bcc6fcf8d474a25e3e3774bc5133e4b58adff4d0bca238bc/nvidia_nccl_cu13-2.28.9-py3-none-manylinux_2_18_x86_64.whl", hash = "sha256:e4553a30f34195f3fa1da02a6da3d6337d28f2003943aa0a3d247bbc25fefc42", size = 196493177 }, +] + +[[package]] +name = "nvidia-nvjitlink" +version = "13.0.88" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/56/7a/123e033aaff487c77107195fa5a2b8686795ca537935a24efae476c41f05/nvidia_nvjitlink-13.0.88-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:13a74f429e23b921c1109976abefacc69835f2f433ebd323d3946e11d804e47b", size = 40713933 }, + { url = "https://files.pythonhosted.org/packages/ab/2c/93c5250e64df4f894f1cbb397c6fd71f79813f9fd79d7cd61de3f97b3c2d/nvidia_nvjitlink-13.0.88-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e931536ccc7d467a98ba1d8b89ff7fa7f1fa3b13f2b0069118cd7f47bff07d0c", size = 38768748 }, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836 }, +] + +[[package]] +name = "nvidia-nvshmem-cu13" +version = "3.4.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/0f/05cc9c720236dcd2db9c1ab97fff629e96821be2e63103569da0c9b72f19/nvidia_nvshmem_cu13-3.4.5-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6dc2a197f38e5d0376ad52cd1a2a3617d3cdc150fd5966f4aee9bcebb1d68fe9", size = 60215947 }, + { url = "https://files.pythonhosted.org/packages/3c/35/a9bf80a609e74e3b000fef598933235c908fcefcef9026042b8e6dfde2a9/nvidia_nvshmem_cu13-3.4.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:290f0a2ee94c9f3687a02502f3b9299a9f9fe826e6d0287ee18482e78d495b80", size = 60412546 }, +] + +[[package]] +name = "nvidia-nvtx" +version = "13.0.85" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/c2/f3/d86c845465a2723ad7e1e5c36dcd75ddb82898b3f53be47ebd429fb2fa5d/nvidia_nvtx-13.0.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:4936d1d6780fbe68db454f5e72a42ff64d1fd6397df9f363ae786930fd5c1cd4", size = 148047 }, + { url = "https://files.pythonhosted.org/packages/a8/64/3708a90d1ebe202ffdeb7185f878a3c84d15c2b2c31858da2ce0583e2def/nvidia_nvtx-13.0.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb7780edb6b14107373c835bf8b72e7a178bac7367e23da7acb108f973f157a6", size = 148878 }, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954 }, ] [[package]] @@ -1610,7 +1743,7 @@ name = "optree" version = "0.14.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "typing-extensions", marker = "python_full_version >= '3.9'" }, + { name = "typing-extensions" }, ] sdist = { url = "https://files.pythonhosted.org/packages/86/3a/313dae3303d526c333259544e9196207d33a43f0768cdca45f8e69cdd8ba/optree-0.14.0.tar.gz", hash = "sha256:d2b4b8784f5c7651a899997c9d6d4cd814c4222cd450c76d1fa386b8f5728d61", size = 158834 } wheels = [ @@ -1644,35 +1777,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/42/cd327132f2a481939d07315cf98393fd62912c31bc3288b83dd142a7d0d2/optree-0.14.0-cp312-cp312-win32.whl", hash = "sha256:c153bb5b5d2286109d1d8bee704b59f9303aed9c92822075e7002ea5362fa534", size = 268878 }, { url = "https://files.pythonhosted.org/packages/ce/e6/b1c08aa53a2db9d8102d439f680ae2065ca7a3ea7da62902b7f57f576236/optree-0.14.0-cp312-cp312-win_amd64.whl", hash = 
"sha256:c79cad5da479ee6931f2c96cacccf588ff75029072661021963117df895305d9", size = 299568 }, { url = "https://files.pythonhosted.org/packages/9d/42/db1e14970e3dd6ff0b2aea7767e92989769a0dc8b07f89850197515ecf97/optree-0.14.0-cp312-cp312-win_arm64.whl", hash = "sha256:c844427e28cc661782fdfba6a2a13d89acabc3b183f49f5e366f8b4fab9616f4", size = 295279 }, - { url = "https://files.pythonhosted.org/packages/78/b8/04fd39f998e68a057b4768dd5962f0311f4f105e44b038d7e8f67c861d37/optree-0.14.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:db73d8750deb66cd6402fee86c1b3a2df32a0bca1049448829eaa1023408f282", size = 599586 }, - { url = "https://files.pythonhosted.org/packages/1d/ee/54bb3740662a91af74f187b4afda5fd008f3966a2651f4452bf4a41ee6b0/optree-0.14.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:614c97c6e42a7e9a7765c051cff0ad3f482750205f2b6a113eecb5c381da38d5", size = 324113 }, - { url = "https://files.pythonhosted.org/packages/ff/1f/cdb2243c7b664adde6a3656a4270f6ce2b21bd924dd242a582e068479a26/optree-0.14.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3127e77bd5eabd28bd3388db3291f1ea15eaeedd86bb4e71770f8aba4bb68acb", size = 355926 }, - { url = "https://files.pythonhosted.org/packages/90/03/1aee947a7edaee888f2502b82e6403210dccd67779ca9264da2cd4656d5d/optree-0.14.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:faab435742987c8ea244e81b7526234c6f86cfc8fec5ec11d48184348e92aada", size = 400890 }, - { url = "https://files.pythonhosted.org/packages/a4/11/b2fb4045a01f39bb2de996bfed2a7ee66e66669ca06c3577b5928625bb09/optree-0.14.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4eee7d0248129465d1ad1c391ab38fe76f5af789571551823f131c81a008ceb1", size = 398002 }, - { url = "https://files.pythonhosted.org/packages/e9/39/986f91a11a846492a96a93344fec7a91bd5de6a20229ed7d5c9c9647b920/optree-0.14.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:4c0c65c764cda12841759a03ff86dec79404f96b2750f90859b042d60e9a2d82", size = 368519 }, - { url = "https://files.pythonhosted.org/packages/75/7b/a646501b649ae606cea5b63933251294a8ca3d63dd45c5870adec594ffa9/optree-0.14.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53f14de1c07d64e381acdb29254dbdd86bba84138e7c789a6d2be026d03a36a9", size = 391641 }, - { url = "https://files.pythonhosted.org/packages/f2/d7/6c14095995386c43dfbf7eb9c9aef57cc790fb563cf7450ae527e51516e3/optree-0.14.0-cp38-cp38-win32.whl", hash = "sha256:202e97dab0b7eae95738d8775cba4417a26e8539568f5b7e0a50e500263a3703", size = 262430 }, - { url = "https://files.pythonhosted.org/packages/9c/56/8c163760347b781fb6c2bfdd348192ae26abc3e0b364923ccdcb840730aa/optree-0.14.0-cp38-cp38-win_amd64.whl", hash = "sha256:9e1dfb12bcdf2d759602b7ad1bc6228ec5a19451c3504a80bd5445b9c8e53bab", size = 290767 }, - { url = "https://files.pythonhosted.org/packages/90/61/f754605df3dd1b15ad88a87ff7d97dafeaa8d458320a05de3842ed76b363/optree-0.14.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:80a70cc5f944d2db3eae1a225b41a935d957c928d324f7677f8387e4ab3e8626", size = 599843 }, - { url = "https://files.pythonhosted.org/packages/39/35/2207d20b4f7aed6ddf0b46ee33f1a178caef54ed8fa246363612f7c9c46f/optree-0.14.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8b1ca7d17007b46223c5f3c02ffa9effc812adff5bc30f561dbfe88f241a16ba", size = 324174 }, - { url = "https://files.pythonhosted.org/packages/7c/42/12cd07070bb815bb8ac6df0d0ea149dc06e6cb1cd4262565c65805957f6e/optree-0.14.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3a7704f7f3cd45caa684e0b762bac29207435ea811ca3da7b2d93cc2fa54310", size = 358070 }, - { url = "https://files.pythonhosted.org/packages/1a/14/e3aa38bd9e4cc0be7ab00884f750595315ba74dcad4657d4d1f3c61e324b/optree-0.14.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e0fd04f11bbb9862bedee4f4e7b3b1ed7476c34a3e7bf25a2169d43a1b23e90", size = 
401567 }, - { url = "https://files.pythonhosted.org/packages/07/3d/7fbef260a539bd90846e5f2d9ea673cbbddb38e45dc764137ce99d34108e/optree-0.14.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:27b66f1d542cf4cc9867268485cad3c719bee3e80731a3dc45649c9c57c66f25", size = 400194 }, - { url = "https://files.pythonhosted.org/packages/bf/d7/75ca91a87a2d4d434a1a2eac40c59738b9274db14246289fb928a2985fa2/optree-0.14.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d47cf9c991505aae3e93879404bf9bb47efaeb2c84951610d9b63453b8edfadb", size = 370467 }, - { url = "https://files.pythonhosted.org/packages/39/d2/97e53c017bf91441acd476563202c00238c62d679db8c0f1b4c8a9771bea/optree-0.14.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a08dcc8b5a7529ebef64533cba13444de46ba9e923a9c54a9c1dcceb4de2f55", size = 392136 }, - { url = "https://files.pythonhosted.org/packages/cd/95/90bf10b8da83258d64245bf257202b2a7cb8e4883ab7531490984ab35fa0/optree-0.14.0-cp39-cp39-win32.whl", hash = "sha256:e3aa3421fc50619cf15caaa457952c06b532a192df02d9e94a8a6aabe5acbebf", size = 262475 }, - { url = "https://files.pythonhosted.org/packages/7d/db/71537de2852bc5c86365315cfd52a70611cf18291d2106d4a76c6ecdb16c/optree-0.14.0-cp39-cp39-win_amd64.whl", hash = "sha256:b1f03ed925afee44fea9e26bf99a297111f313d88cfb69142463a3cb359f7953", size = 286052 }, - { url = "https://files.pythonhosted.org/packages/6c/af/bf110bd801b4598476892fdfb064f5e5fbab230acd6a11252f6be9e5bea5/optree-0.14.0-cp39-cp39-win_arm64.whl", hash = "sha256:81122a324237fccb4f8abe5dca1b00be12cf4c0a53d3a4872cfc1f060c713854", size = 285162 }, { url = "https://files.pythonhosted.org/packages/dc/f3/eb0379246428ef28484a40607f74248766c40986567b6d4e7d416dcaddfd/optree-0.14.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:a4934f4da6f79314760e9559f8c8484e00aa99ea79f8d3326f66cf8e11db71b0", size = 330719 }, { url = 
"https://files.pythonhosted.org/packages/12/48/71ca54dc7d4729af8b7d4706549d5c4236e2a24d9a9a41c20bd4b36d3442/optree-0.14.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78d33c499c102e2aba05abf99876025ba7f1d5ca98f2e3c75d5cddc9dc42cfa5", size = 360622 }, { url = "https://files.pythonhosted.org/packages/22/21/6438ee6c4894ff996e85e187e83975eef4d95bcd58978f1f2e473e0882c2/optree-0.14.0-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3eea1ab8fb32cf5745eead68671100db8547e6d22e8b5c3780376369560659c", size = 405706 }, { url = "https://files.pythonhosted.org/packages/e8/37/a12cfe33b5db4949905bc02dfeca494b153057d70eb680fd520e0b4b529a/optree-0.14.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3fe8f48cb16454e3b9c44f081b940062180e0d6c10fda0a098ed7855be8d0a9", size = 395076 }, { url = "https://files.pythonhosted.org/packages/da/5a/e9b94bbf183ab83565fd31146b509f39288c2b293208337deaeb9ff300f9/optree-0.14.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:3e53c3aa6303efb9a64ccef160ec6638bb4a97b41b77c3871a1204397e27a98a", size = 293687 }, - { url = "https://files.pythonhosted.org/packages/ab/5f/d17d44731df91457740799e99c4625a3ffc9959b38abfec8afb2c85e52cb/optree-0.14.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:ede3b9ccf4cfd5e1ec12db79b93bf45e14e5c1596b339761d3296ce85739ef7a", size = 330639 }, - { url = "https://files.pythonhosted.org/packages/a3/5b/606622cca7322bc16cc3e902aff7b5ef50b98394a6b2c042eb585204af73/optree-0.14.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68803a66b836f595c291347a2bff237852ca80fcfbb2606fee88d046764240de", size = 360331 }, - { url = "https://files.pythonhosted.org/packages/1f/70/f239ec4ef319a63b2bd48c12bf185a451f47f47d1b73eea34e63e050d411/optree-0.14.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aec7dfa57fc9a42e18a2e23bc8c011dbacdf16d8da0a62cc3b4b5ef0fba13d05", size = 
405750 }, - { url = "https://files.pythonhosted.org/packages/eb/9d/960dbfc47c99a2cc1e5698db848b4888107e490ff0d7669765f5c7aaf870/optree-0.14.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f505038e5be2a84155e642c396811bbf1e88a4c6aea6a8766b2c57b562bc65de", size = 394797 }, - { url = "https://files.pythonhosted.org/packages/e6/ee/189359bd4e81faa0b352a2c00291c069afa79d302afb5cf1e57522c8b46b/optree-0.14.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:9527a9b3a2f4f73334e9fdbebaec1d7001f717a0c2d195e8419cc5d0ba3183b6", size = 293705 }, ] [[package]] @@ -1699,13 +1808,11 @@ version = "1.5.3" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] dependencies = [ - { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9' and python_full_version < '3.12'" }, + { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.11.*'" }, { name = "python-dateutil", marker = "python_full_version < '3.12'" }, { name = "pytz", marker = "python_full_version < '3.12'" }, ] @@ -1723,20 +1830,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/63/8d/c2bd356b9d4baf1c5cf8d7e251fb4540e87083072c905430da48c2bb31eb/pandas-1.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e474390e60ed609cec869b0da796ad94f420bb057d86784191eefc62b65819ae", size = 11374218 }, { url = 
"https://files.pythonhosted.org/packages/56/73/3351beeb807dca69fcc3c4966bcccc51552bd01549a9b13c04ab00a43f21/pandas-1.5.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f2b952406a1588ad4cad5b3f55f520e82e902388a6d5a4a91baa8d38d23c7f6", size = 12017319 }, { url = "https://files.pythonhosted.org/packages/da/6d/1235da14daddaa6e47f74ba0c255358f0ce7a6ee05da8bf8eb49161aa6b5/pandas-1.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:bc4c368f42b551bf72fac35c5128963a171b40dce866fb066540eeaf46faa003", size = 10303385 }, - { url = "https://files.pythonhosted.org/packages/26/c1/469f5d7863a9901d92b795d9fc5c7c4acccd7df62b13367c7fac0d499c3b/pandas-1.5.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:14e45300521902689a81f3f41386dc86f19b8ba8dd5ac5a3c7010ef8d2932813", size = 18428032 }, - { url = "https://files.pythonhosted.org/packages/2b/63/fa344006a41dd696720328af0f1f914f530e9eca2f794607f6af9158897d/pandas-1.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9842b6f4b8479e41968eced654487258ed81df7d1c9b7b870ceea24ed9459b31", size = 11896315 }, - { url = "https://files.pythonhosted.org/packages/0e/1d/f964977eea9ed72d5f1c53af56038aca2ce781a0cc8bce8aeb33da039ca1/pandas-1.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:26d9c71772c7afb9d5046e6e9cf42d83dd147b5cf5bcb9d97252077118543792", size = 10825052 }, - { url = "https://files.pythonhosted.org/packages/b2/87/e0a0e9a0ab9ede47192aa40887b7e31d048c98326a41d6b57c658d1a809d/pandas-1.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fbcb19d6fceb9e946b3e23258757c7b225ba450990d9ed63ccceeb8cae609f7", size = 11465500 }, - { url = "https://files.pythonhosted.org/packages/54/a0/c62d63c5c69be9aae07836e4d7e25e7a6f5590be3d8f2d53f43eeec5c475/pandas-1.5.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:565fa34a5434d38e9d250af3c12ff931abaf88050551d9fbcdfafca50d62babf", size = 12189084 }, - { url = 
"https://files.pythonhosted.org/packages/bc/bb/359b304fb2d9a97c7344b6ceb585dc22fff864e4f3f1d1511166cd84865e/pandas-1.5.3-cp38-cp38-win32.whl", hash = "sha256:87bd9c03da1ac870a6d2c8902a0e1fd4267ca00f13bc494c9e5a9020920e1d51", size = 9753053 }, - { url = "https://files.pythonhosted.org/packages/ca/4e/d18db7d5ff9d28264cd2a7e2499b8701108f0e6c698e382cfd5d20685c21/pandas-1.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:41179ce559943d83a9b4bbacb736b04c928b095b5f25dd2b7389eda08f46f373", size = 10959031 }, - { url = "https://files.pythonhosted.org/packages/90/19/1a92d73cda1233326e787a4c14362a1fcce4c7d9f28316fd769308aefb99/pandas-1.5.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c74a62747864ed568f5a82a49a23a8d7fe171d0c69038b38cedf0976831296fa", size = 18722090 }, - { url = "https://files.pythonhosted.org/packages/02/4a/8e2513db9d15929b833147f975d8424dc6a3e18100ead10aab78756a1aad/pandas-1.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c4c00e0b0597c8e4f59e8d461f797e5d70b4d025880516a8261b2817c47759ee", size = 12049642 }, - { url = "https://files.pythonhosted.org/packages/a7/2b/c71df8794e8e75ba1ec9da1c1a2efc946590aa79a05148a4138405ef5f72/pandas-1.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a50d9a4336a9621cab7b8eb3fb11adb82de58f9b91d84c2cd526576b881a0c5a", size = 10962439 }, - { url = "https://files.pythonhosted.org/packages/7d/d6/92be61dca3880c7cec99a9b4acf6260b3dc00519673fdb3e6666ac6096ce/pandas-1.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd05f7783b3274aa206a1af06f0ceed3f9b412cf665b7247eacd83be41cf7bf0", size = 11471277 }, - { url = "https://files.pythonhosted.org/packages/e1/4d/3eb96e53a9208350ee21615f850c4be9a246d32bf1d34cd36682cb58c3b7/pandas-1.5.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f69c4029613de47816b1bb30ff5ac778686688751a5e9c99ad8c7031f6508e5", size = 12169732 }, - { url = 
"https://files.pythonhosted.org/packages/94/85/89f6547642b28fbd874504a6f548d6be4d88981837a23ab18d76cb773bea/pandas-1.5.3-cp39-cp39-win32.whl", hash = "sha256:7cec0bee9f294e5de5bbfc14d0573f65526071029d036b753ee6507d2a21480a", size = 9730624 }, - { url = "https://files.pythonhosted.org/packages/c2/45/801ecd8434eef0b39cc02795ffae273fe3df3cfcb3f6fff215efbe92d93c/pandas-1.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:dfd681c5dc216037e0b0a2c821f5ed99ba9f03ebcf119c7dac0e9a7b960b9ec9", size = 10932203 }, ] [[package]] @@ -1775,13 +1868,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/de/8b1895b107277d52f2b42d3a6806e69cfef0d5cf1d0ba343470b9d8e0a04/pandas-2.3.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a68e15f780eddf2b07d242e17a04aa187a7ee12b40b930bfdd78070556550e98", size = 12771002 }, { url = "https://files.pythonhosted.org/packages/87/21/84072af3187a677c5893b170ba2c8fbe450a6ff911234916da889b698220/pandas-2.3.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:371a4ab48e950033bcf52b6527eccb564f52dc826c02afd9a1bc0ab731bba084", size = 13450971 }, { url = "https://files.pythonhosted.org/packages/86/41/585a168330ff063014880a80d744219dbf1dd7a1c706e75ab3425a987384/pandas-2.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:a16dcec078a01eeef8ee61bf64074b4e524a2a3f4b3be9326420cabe59c4778b", size = 10992722 }, - { url = "https://files.pythonhosted.org/packages/56/b4/52eeb530a99e2a4c55ffcd352772b599ed4473a0f892d127f4147cf0f88e/pandas-2.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c503ba5216814e295f40711470446bc3fd00f0faea8a086cbc688808e26f92a2", size = 11567720 }, - { url = "https://files.pythonhosted.org/packages/48/4a/2d8b67632a021bced649ba940455ed441ca854e57d6e7658a6024587b083/pandas-2.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a637c5cdfa04b6d6e2ecedcb81fc52ffb0fd78ce2ebccc9ea964df9f658de8c8", size = 10810302 }, - { url = 
"https://files.pythonhosted.org/packages/13/e6/d2465010ee0569a245c975dc6967b801887068bc893e908239b1f4b6c1ac/pandas-2.3.3-cp39-cp39-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:854d00d556406bffe66a4c0802f334c9ad5a96b4f1f868adf036a21b11ef13ff", size = 12154874 }, - { url = "https://files.pythonhosted.org/packages/1f/18/aae8c0aa69a386a3255940e9317f793808ea79d0a525a97a903366bb2569/pandas-2.3.3-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bf1f8a81d04ca90e32a0aceb819d34dbd378a98bf923b6398b9a3ec0bf44de29", size = 12790141 }, - { url = "https://files.pythonhosted.org/packages/f7/26/617f98de789de00c2a444fbe6301bb19e66556ac78cff933d2c98f62f2b4/pandas-2.3.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:23ebd657a4d38268c7dfbdf089fbc31ea709d82e4923c5ffd4fbd5747133ce73", size = 13208697 }, - { url = "https://files.pythonhosted.org/packages/b9/fb/25709afa4552042bd0e15717c75e9b4a2294c3dc4f7e6ea50f03c5136600/pandas-2.3.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:5554c929ccc317d41a5e3d1234f3be588248e61f08a74dd17c9eabb535777dc9", size = 13879233 }, - { url = "https://files.pythonhosted.org/packages/98/af/7be05277859a7bc399da8ba68b88c96b27b48740b6cf49688899c6eb4176/pandas-2.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:d3e28b3e83862ccf4d85ff19cf8c20b2ae7e503881711ff2d534dc8f761131aa", size = 11359119 }, ] [[package]] @@ -1816,14 +1902,14 @@ name = "pre-commit" version = "3.5.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] dependencies = [ - { name = "cfgv", marker = "python_full_version < '3.9'" }, - { name = "identify", version = "2.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "nodeenv", marker = "python_full_version < '3.9'" }, - { name = "pyyaml", marker = "python_full_version < '3.9'" }, - { name = "virtualenv", marker = "python_full_version < '3.9'" }, + 
{ name = "cfgv", marker = "python_full_version < '3.11'" }, + { name = "identify", version = "2.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "nodeenv", marker = "python_full_version < '3.11'" }, + { name = "pyyaml", marker = "python_full_version < '3.11'" }, + { name = "virtualenv", marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/04/b3/4ae08d21eb097162f5aad37f4585f8069a86402ed7f5362cc9ae097f9572/pre_commit-3.5.0.tar.gz", hash = "sha256:5804465c675b659b0862f07907f96295d490822a450c4c40e747d0b1c6ebcb32", size = 177079 } wheels = [ @@ -1837,57 +1923,23 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] dependencies = [ - { name = "cfgv", marker = "python_full_version >= '3.9'" }, - { name = "identify", version = "2.6.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "nodeenv", marker = "python_full_version >= '3.9'" }, - { name = "pyyaml", marker = "python_full_version >= '3.9'" }, - { name = "virtualenv", marker = "python_full_version >= '3.9'" }, + { name = "cfgv", marker = "python_full_version >= '3.11'" }, + { name = "identify", version = "2.6.8", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "nodeenv", marker = "python_full_version >= '3.11'" }, + { name = "pyyaml", marker = "python_full_version >= '3.11'" }, + { name = "virtualenv", marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/64/10/97ee2fa54dff1e9da9badbc5e35d0bbaef0776271ea5907eccf64140f72f/pre_commit-3.8.0.tar.gz", hash = "sha256:8bb6494d4a20423842e198980c9ecf9f96607a07ea29549e180eef9ae80fe7af", size = 177815 } wheels = [ { url = 
"https://files.pythonhosted.org/packages/07/92/caae8c86e94681b42c246f0bca35c059a2f0529e5b92619f6aba4cf7e7b6/pre_commit-3.8.0-py2.py3-none-any.whl", hash = "sha256:9a90a53bf82fdd8778d58085faf8d83df56e40dfe18f45b19446e26bf1b3a63f", size = 204643 }, ] -[[package]] -name = "protobuf" -version = "3.19.6" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.9'", -] -sdist = { url = "https://files.pythonhosted.org/packages/51/d1/79bfd1f481469b661a2eddab551255536401892722189433282bfb13cfb1/protobuf-3.19.6.tar.gz", hash = "sha256:5f5540d57a43042389e87661c6eaa50f47c19c6176e8cf1c4f287aeefeccb5c4", size = 218071 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4b/3b/90f805b9e5ecacf8a216f2e5acabc2d3ad965b62803510be41804e6bfbfe/protobuf-3.19.6-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:010be24d5a44be7b0613750ab40bc8b8cedc796db468eae6c779b395f50d1fa1", size = 913631 }, - { url = "https://files.pythonhosted.org/packages/26/ef/bd6ba3b4ff9a35944bdd325e2c9ee56f71e855757f7d43938232499f0278/protobuf-3.19.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11478547958c2dfea921920617eb457bc26867b0d1aa065ab05f35080c5d9eb6", size = 1055327 }, - { url = "https://files.pythonhosted.org/packages/4a/25/85bcc155980b5d7754ebdf4cb32039105f6020b6b2b8f7536a866113fc1c/protobuf-3.19.6-cp310-cp310-win32.whl", hash = "sha256:559670e006e3173308c9254d63facb2c03865818f22204037ab76f7a0ff70b5f", size = 775745 }, - { url = "https://files.pythonhosted.org/packages/97/f9/a14bac5331f3e55bcbbed906a0c8b112f554152ddf09efeb6f5f95653ffd/protobuf-3.19.6-cp310-cp310-win_amd64.whl", hash = "sha256:347b393d4dd06fb93a77620781e11c058b3b0a5289262f094379ada2920a3730", size = 895657 }, - { url = "https://files.pythonhosted.org/packages/f4/c3/3e7c48cd8e5b0ce9c2e57f38a166cc1b894b9b6a92f28f14a3fa48766ee7/protobuf-3.19.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = 
"sha256:2b2d2913bcda0e0ec9a784d194bc490f5dc3d9d71d322d070b11a0ade32ff6ba", size = 980365 }, - { url = "https://files.pythonhosted.org/packages/af/53/7e26bb62753910e98243725c2348c5c37914596dd52d53b1d3287662dbe2/protobuf-3.19.6-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:d0b635cefebd7a8a0f92020562dead912f81f401af7e71f16bf9506ff3bdbb38", size = 913911 }, - { url = "https://files.pythonhosted.org/packages/3c/f8/b6d7fd81464553e24a07f9d444126db3beb902b6bff6fcd6524d8284097f/protobuf-3.19.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a552af4dc34793803f4e735aabe97ffc45962dfd3a237bdde242bff5a3de684", size = 1055475 }, - { url = "https://files.pythonhosted.org/packages/ac/dd/b5e3b826322295afd5153fadd2c0ee5ab1ed2ddefa6a7f49f935ca9b51d3/protobuf-3.19.6-cp38-cp38-win32.whl", hash = "sha256:0469bc66160180165e4e29de7f445e57a34ab68f49357392c5b2f54c656ab25e", size = 775927 }, - { url = "https://files.pythonhosted.org/packages/fd/38/cb53f28950a386c8d7e17fc305c97a158cf85d51d7e6caffe4f37336c138/protobuf-3.19.6-cp38-cp38-win_amd64.whl", hash = "sha256:91d5f1e139ff92c37e0ff07f391101df77e55ebb97f46bbc1535298d72019462", size = 896095 }, - { url = "https://files.pythonhosted.org/packages/17/e6/9554fb822d60c513898234722635d0c29a51f252b359449cfb351b16172a/protobuf-3.19.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c0ccd3f940fe7f3b35a261b1dd1b4fc850c8fde9f74207015431f174be5976b3", size = 980513 }, - { url = "https://files.pythonhosted.org/packages/bc/db/8b33c9558f1f27dd74e7f9ad730c6b32efab431419af556b1659e125b041/protobuf-3.19.6-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:30a15015d86b9c3b8d6bf78d5b8c7749f2512c29f168ca259c9d7727604d0e39", size = 913657 }, - { url = "https://files.pythonhosted.org/packages/51/61/e80b7a04f4e1b4eecc86582335205fd876abca0abafee4a6c001f70a375e/protobuf-3.19.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:878b4cd080a21ddda6ac6d1e163403ec6eea2e206cf225982ae04567d39be7b0", size = 1055457 
}, - { url = "https://files.pythonhosted.org/packages/26/6b/e2aca5a4e83f95796bc65ee81d3a2c06b13dcdba0db294517cad5e71b3f9/protobuf-3.19.6-cp39-cp39-win32.whl", hash = "sha256:5a0d7539a1b1fb7e76bf5faa0b44b30f812758e989e59c40f77a7dab320e79b9", size = 775891 }, - { url = "https://files.pythonhosted.org/packages/9b/6e/ffecb6488629407ac44ec956990c616eb56fd0069a81a9e28feeed8a2ca2/protobuf-3.19.6-cp39-cp39-win_amd64.whl", hash = "sha256:bbf5cea5048272e1c60d235c7bd12ce1b14b8a16e76917f371c718bd3005f045", size = 895879 }, - { url = "https://files.pythonhosted.org/packages/32/27/1141a8232723dcb10a595cc0ce4321dcbbd5215300bf4acfc142343205bf/protobuf-3.19.6-py2.py3-none-any.whl", hash = "sha256:14082457dc02be946f60b15aad35e9f5c69e738f80ebbc0900a19bc83734a5a4", size = 162648 }, -] - [[package]] name = "protobuf" version = "4.25.6" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", -] sdist = { url = "https://files.pythonhosted.org/packages/48/d5/cccc7e82bbda9909ced3e7a441a24205ea07fea4ce23a772743c0c7611fa/protobuf-4.25.6.tar.gz", hash = "sha256:f8cfbae7c5afd0d0eaccbe73267339bff605a2315860bb1ba08eb66670a9a91f", size = 380631 } wheels = [ { url = "https://files.pythonhosted.org/packages/42/41/0ff3559d9a0fbdb37c9452f2b84e61f7784d8d7b9850182c7ef493f523ee/protobuf-4.25.6-cp310-abi3-win32.whl", hash = "sha256:61df6b5786e2b49fc0055f636c1e8f0aff263808bb724b95b164685ac1bcc13a", size = 392454 }, @@ -1895,10 +1947,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/03/361e87cc824452376c2abcef0eabd18da78a7439479ec6541cf29076a4dc/protobuf-4.25.6-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:6d4381f2417606d7e01750e2729fe6fbcda3f9883aa0c32b51d23012bded6c91", size = 394246 }, { url = 
"https://files.pythonhosted.org/packages/64/d5/7dbeb69b74fa88f297c6d8f11b7c9cef0c2e2fb1fdf155c2ca5775cfa998/protobuf-4.25.6-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:5dd800da412ba7f6f26d2c08868a5023ce624e1fdb28bccca2dc957191e81fb5", size = 293714 }, { url = "https://files.pythonhosted.org/packages/d4/f0/6d5c100f6b18d973e86646aa5fc09bc12ee88a28684a56fd95511bceee68/protobuf-4.25.6-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:4434ff8bb5576f9e0c78f47c41cdf3a152c0b44de475784cd3fd170aef16205a", size = 294634 }, - { url = "https://files.pythonhosted.org/packages/ab/7e/fa5728ef2382291b5cb06b0ec4a05013ce9ab67c2e6c19c02d2d3acd99d2/protobuf-4.25.6-cp38-cp38-win32.whl", hash = "sha256:8bad0f9e8f83c1fbfcc34e573352b17dfce7d0519512df8519994168dc015d7d", size = 392493 }, - { url = "https://files.pythonhosted.org/packages/2f/b1/b625b3e86742420a0920f9ef43c9145c2256e8ffb5b6fc8d932d1ec28fbd/protobuf-4.25.6-cp38-cp38-win_amd64.whl", hash = "sha256:b6905b68cde3b8243a198268bb46fbec42b3455c88b6b02fb2529d2c306d18fc", size = 413389 }, - { url = "https://files.pythonhosted.org/packages/f2/2d/3d28a1c513ae75808bd8663f517a9f38693aaf448a120a88788af9931832/protobuf-4.25.6-cp39-cp39-win32.whl", hash = "sha256:3f3b0b39db04b509859361ac9bca65a265fe9342e6b9406eda58029f5b1d10b2", size = 392500 }, - { url = "https://files.pythonhosted.org/packages/9d/35/0705d3ff52364af2bdd2989b09fce93c268ea7c3fc03bdc7174ec630048c/protobuf-4.25.6-cp39-cp39-win_amd64.whl", hash = "sha256:6ef2045f89d4ad8d95fd43cd84621487832a61d15b49500e4c1350e8a0ef96be", size = 413389 }, { url = "https://files.pythonhosted.org/packages/71/eb/be11a1244d0e58ee04c17a1f939b100199063e26ecca8262c04827fe0bf5/protobuf-4.25.6-py3-none-any.whl", hash = "sha256:07972021c8e30b870cfc0863409d033af940213e0e7f64e27fe017b929d2c9f7", size = 156466 }, ] @@ -1911,27 +1959,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/30/a58b32568f1623aaad7db22aa9eafc4c6c194b429ff35bdc55ca2726da47/py4j-0.10.9.7-py2.py3-none-any.whl", 
hash = "sha256:85defdfd2b2376eb3abf5ca6474b51ab7e0de341c75a02f46dc9b5976f5a5c1b", size = 200481 }, ] -[[package]] -name = "pyasn1" -version = "0.6.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135 }, -] - -[[package]] -name = "pyasn1-modules" -version = "0.4.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyasn1", marker = "python_full_version < '3.9'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1d/67/6afbf0d507f73c32d21084a79946bfcfca5fbc62a72057e9c23797a737c9/pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c", size = 310028 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd", size = 181537 }, -] - [[package]] name = "pycodestyle" version = "2.12.1" @@ -2005,32 +2032,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9b/67/4e197c300976af185b7cef4c02203e175fb127e414125916bf1128b639a9/pydantic_core-2.27.2-cp312-cp312-win32.whl", hash = "sha256:1e2cb691ed9834cd6a8be61228471d0a503731abfb42f82458ff27be7b2186fc", size = 1834064 }, { url = "https://files.pythonhosted.org/packages/1f/ea/cd7209a889163b8dcca139fe32b9687dd05249161a3edda62860430457a5/pydantic_core-2.27.2-cp312-cp312-win_amd64.whl", hash = 
"sha256:cc3f1a99a4f4f9dd1de4fe0312c114e740b5ddead65bb4102884b384c15d8bc9", size = 1989046 }, { url = "https://files.pythonhosted.org/packages/bc/49/c54baab2f4658c26ac633d798dab66b4c3a9bbf47cff5284e9c182f4137a/pydantic_core-2.27.2-cp312-cp312-win_arm64.whl", hash = "sha256:3911ac9284cd8a1792d3cb26a2da18f3ca26c6908cc434a18f730dc0db7bfa3b", size = 1885092 }, - { url = "https://files.pythonhosted.org/packages/43/53/13e9917fc69c0a4aea06fd63ed6a8d6cda9cf140ca9584d49c1650b0ef5e/pydantic_core-2.27.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d3e8d504bdd3f10835468f29008d72fc8359d95c9c415ce6e767203db6127506", size = 1899595 }, - { url = "https://files.pythonhosted.org/packages/f4/20/26c549249769ed84877f862f7bb93f89a6ee08b4bee1ed8781616b7fbb5e/pydantic_core-2.27.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:521eb9b7f036c9b6187f0b47318ab0d7ca14bd87f776240b90b21c1f4f149320", size = 1775010 }, - { url = "https://files.pythonhosted.org/packages/35/eb/8234e05452d92d2b102ffa1b56d801c3567e628fdc63f02080fdfc68fd5e/pydantic_core-2.27.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85210c4d99a0114f5a9481b44560d7d1e35e32cc5634c656bc48e590b669b145", size = 1830727 }, - { url = "https://files.pythonhosted.org/packages/8f/df/59f915c8b929d5f61e5a46accf748a87110ba145156f9326d1a7d28912b2/pydantic_core-2.27.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d716e2e30c6f140d7560ef1538953a5cd1a87264c737643d481f2779fc247fe1", size = 1868393 }, - { url = "https://files.pythonhosted.org/packages/d5/52/81cf4071dca654d485c277c581db368b0c95b2b883f4d7b736ab54f72ddf/pydantic_core-2.27.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f66d89ba397d92f840f8654756196d93804278457b5fbede59598a1f9f90b228", size = 2040300 }, - { url = "https://files.pythonhosted.org/packages/9c/00/05197ce1614f5c08d7a06e1d39d5d8e704dc81971b2719af134b844e2eaf/pydantic_core-2.27.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", 
hash = "sha256:669e193c1c576a58f132e3158f9dfa9662969edb1a250c54d8fa52590045f046", size = 2738785 }, - { url = "https://files.pythonhosted.org/packages/f7/a3/5f19bc495793546825ab160e530330c2afcee2281c02b5ffafd0b32ac05e/pydantic_core-2.27.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdbe7629b996647b99c01b37f11170a57ae675375b14b8c13b8518b8320ced5", size = 1996493 }, - { url = "https://files.pythonhosted.org/packages/ed/e8/e0102c2ec153dc3eed88aea03990e1b06cfbca532916b8a48173245afe60/pydantic_core-2.27.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d262606bf386a5ba0b0af3b97f37c83d7011439e3dc1a9298f21efb292e42f1a", size = 1998544 }, - { url = "https://files.pythonhosted.org/packages/fb/a3/4be70845b555bd80aaee9f9812a7cf3df81550bce6dadb3cfee9c5d8421d/pydantic_core-2.27.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:cabb9bcb7e0d97f74df8646f34fc76fbf793b7f6dc2438517d7a9e50eee4f14d", size = 2007449 }, - { url = "https://files.pythonhosted.org/packages/e3/9f/b779ed2480ba355c054e6d7ea77792467631d674b13d8257085a4bc7dcda/pydantic_core-2.27.2-cp38-cp38-musllinux_1_1_armv7l.whl", hash = "sha256:d2d63f1215638d28221f664596b1ccb3944f6e25dd18cd3b86b0a4c408d5ebb9", size = 2129460 }, - { url = "https://files.pythonhosted.org/packages/a0/f0/a6ab0681f6e95260c7fbf552874af7302f2ea37b459f9b7f00698f875492/pydantic_core-2.27.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bca101c00bff0adb45a833f8451b9105d9df18accb8743b08107d7ada14bd7da", size = 2159609 }, - { url = "https://files.pythonhosted.org/packages/8a/2b/e1059506795104349712fbca647b18b3f4a7fd541c099e6259717441e1e0/pydantic_core-2.27.2-cp38-cp38-win32.whl", hash = "sha256:f6f8e111843bbb0dee4cb6594cdc73e79b3329b526037ec242a3e49012495b3b", size = 1819886 }, - { url = "https://files.pythonhosted.org/packages/aa/6d/df49c17f024dfc58db0bacc7b03610058018dd2ea2eaf748ccbada4c3d06/pydantic_core-2.27.2-cp38-cp38-win_amd64.whl", hash = 
"sha256:fd1aea04935a508f62e0d0ef1f5ae968774a32afc306fb8545e06f5ff5cdf3ad", size = 1980773 }, - { url = "https://files.pythonhosted.org/packages/27/97/3aef1ddb65c5ccd6eda9050036c956ff6ecbfe66cb7eb40f280f121a5bb0/pydantic_core-2.27.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c10eb4f1659290b523af58fa7cffb452a61ad6ae5613404519aee4bfbf1df993", size = 1896475 }, - { url = "https://files.pythonhosted.org/packages/ad/d3/5668da70e373c9904ed2f372cb52c0b996426f302e0dee2e65634c92007d/pydantic_core-2.27.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ef592d4bad47296fb11f96cd7dc898b92e795032b4894dfb4076cfccd43a9308", size = 1772279 }, - { url = "https://files.pythonhosted.org/packages/8a/9e/e44b8cb0edf04a2f0a1f6425a65ee089c1d6f9c4c2dcab0209127b6fdfc2/pydantic_core-2.27.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c61709a844acc6bf0b7dce7daae75195a10aac96a596ea1b776996414791ede4", size = 1829112 }, - { url = "https://files.pythonhosted.org/packages/1c/90/1160d7ac700102effe11616e8119e268770f2a2aa5afb935f3ee6832987d/pydantic_core-2.27.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42c5f762659e47fdb7b16956c71598292f60a03aa92f8b6351504359dbdba6cf", size = 1866780 }, - { url = "https://files.pythonhosted.org/packages/ee/33/13983426df09a36d22c15980008f8d9c77674fc319351813b5a2739b70f3/pydantic_core-2.27.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4c9775e339e42e79ec99c441d9730fccf07414af63eac2f0e48e08fd38a64d76", size = 2037943 }, - { url = "https://files.pythonhosted.org/packages/01/d7/ced164e376f6747e9158c89988c293cd524ab8d215ae4e185e9929655d5c/pydantic_core-2.27.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57762139821c31847cfb2df63c12f725788bd9f04bc2fb392790959b8f70f118", size = 2740492 }, - { url = 
"https://files.pythonhosted.org/packages/8b/1f/3dc6e769d5b7461040778816aab2b00422427bcaa4b56cc89e9c653b2605/pydantic_core-2.27.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d1e85068e818c73e048fe28cfc769040bb1f475524f4745a5dc621f75ac7630", size = 1995714 }, - { url = "https://files.pythonhosted.org/packages/07/d7/a0bd09bc39283530b3f7c27033a814ef254ba3bd0b5cfd040b7abf1fe5da/pydantic_core-2.27.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:097830ed52fd9e427942ff3b9bc17fab52913b2f50f2880dc4a5611446606a54", size = 1997163 }, - { url = "https://files.pythonhosted.org/packages/2d/bb/2db4ad1762e1c5699d9b857eeb41959191980de6feb054e70f93085e1bcd/pydantic_core-2.27.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:044a50963a614ecfae59bb1eaf7ea7efc4bc62f49ed594e18fa1e5d953c40e9f", size = 2005217 }, - { url = "https://files.pythonhosted.org/packages/53/5f/23a5a3e7b8403f8dd8fc8a6f8b49f6b55c7d715b77dcf1f8ae919eeb5628/pydantic_core-2.27.2-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:4e0b4220ba5b40d727c7f879eac379b822eee5d8fff418e9d3381ee45b3b0362", size = 2127899 }, - { url = "https://files.pythonhosted.org/packages/c2/ae/aa38bb8dd3d89c2f1d8362dd890ee8f3b967330821d03bbe08fa01ce3766/pydantic_core-2.27.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5e4f4bb20d75e9325cc9696c6802657b58bc1dbbe3022f32cc2b2b632c3fbb96", size = 2155726 }, - { url = "https://files.pythonhosted.org/packages/98/61/4f784608cc9e98f70839187117ce840480f768fed5d386f924074bf6213c/pydantic_core-2.27.2-cp39-cp39-win32.whl", hash = "sha256:cca63613e90d001b9f2f9a9ceb276c308bfa2a43fafb75c8031c4f66039e8c6e", size = 1817219 }, - { url = "https://files.pythonhosted.org/packages/57/82/bb16a68e4a1a858bb3768c2c8f1ff8d8978014e16598f001ea29a25bf1d1/pydantic_core-2.27.2-cp39-cp39-win_amd64.whl", hash = "sha256:77d1bca19b0f7021b3a982e6f903dcd5b2b06076def36a652e3907f596e29f67", size = 1985382 }, { url = 
"https://files.pythonhosted.org/packages/46/72/af70981a341500419e67d5cb45abe552a7c74b66326ac8877588488da1ac/pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2bf14caea37e91198329b828eae1618c068dfb8ef17bb33287a7ad4b61ac314e", size = 1891159 }, { url = "https://files.pythonhosted.org/packages/ad/3d/c5913cccdef93e0a6a95c2d057d2c2cba347815c845cda79ddd3c0f5e17d/pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:b0cb791f5b45307caae8810c2023a184c74605ec3bcbb67d13846c28ff731ff8", size = 1768331 }, { url = "https://files.pythonhosted.org/packages/f6/f0/a3ae8fbee269e4934f14e2e0e00928f9346c5943174f2811193113e58252/pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:688d3fd9fcb71f41c4c015c023d12a79d1c4c0732ec9eb35d96e3388a120dcf3", size = 1822467 }, @@ -2040,15 +2041,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/86/aa/837821ecf0c022bbb74ca132e117c358321e72e7f9702d1b6a03758545e2/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:0296abcb83a797db256b773f45773da397da75a08f5fcaef41f2044adec05f50", size = 2116582 }, { url = "https://files.pythonhosted.org/packages/81/b0/5e74656e95623cbaa0a6278d16cf15e10a51f6002e3ec126541e95c29ea3/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0d75070718e369e452075a6017fbf187f788e17ed67a3abd47fa934d001863d9", size = 2151985 }, { url = "https://files.pythonhosted.org/packages/63/37/3e32eeb2a451fddaa3898e2163746b0cffbbdbb4740d38372db0490d67f3/pydantic_core-2.27.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7e17b560be3c98a8e3aa66ce828bdebb9e9ac6ad5466fba92eb74c4c95cb1151", size = 2004715 }, - { url = "https://files.pythonhosted.org/packages/29/0e/dcaea00c9dbd0348b723cae82b0e0c122e0fa2b43fa933e1622fd237a3ee/pydantic_core-2.27.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c33939a82924da9ed65dab5a65d427205a73181d8098e79b6b426bdf8ad4e656", 
size = 1891733 }, - { url = "https://files.pythonhosted.org/packages/86/d3/e797bba8860ce650272bda6383a9d8cad1d1c9a75a640c9d0e848076f85e/pydantic_core-2.27.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:00bad2484fa6bda1e216e7345a798bd37c68fb2d97558edd584942aa41b7d278", size = 1768375 }, - { url = "https://files.pythonhosted.org/packages/41/f7/f847b15fb14978ca2b30262548f5fc4872b2724e90f116393eb69008299d/pydantic_core-2.27.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c817e2b40aba42bac6f457498dacabc568c3b7a986fc9ba7c8d9d260b71485fb", size = 1822307 }, - { url = "https://files.pythonhosted.org/packages/9c/63/ed80ec8255b587b2f108e514dc03eed1546cd00f0af281e699797f373f38/pydantic_core-2.27.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:251136cdad0cb722e93732cb45ca5299fb56e1344a833640bf93b2803f8d1bfd", size = 1979971 }, - { url = "https://files.pythonhosted.org/packages/a9/6d/6d18308a45454a0de0e975d70171cadaf454bc7a0bf86b9c7688e313f0bb/pydantic_core-2.27.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d2088237af596f0a524d3afc39ab3b036e8adb054ee57cbb1dcf8e09da5b29cc", size = 1987616 }, - { url = "https://files.pythonhosted.org/packages/82/8a/05f8780f2c1081b800a7ca54c1971e291c2d07d1a50fb23c7e4aef4ed403/pydantic_core-2.27.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d4041c0b966a84b4ae7a09832eb691a35aec90910cd2dbe7a208de59be77965b", size = 1998943 }, - { url = "https://files.pythonhosted.org/packages/5e/3e/fe5b6613d9e4c0038434396b46c5303f5ade871166900b357ada4766c5b7/pydantic_core-2.27.2-pp39-pypy39_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:8083d4e875ebe0b864ffef72a4304827015cff328a1be6e22cc850753bfb122b", size = 2116654 }, - { url = "https://files.pythonhosted.org/packages/db/ad/28869f58938fad8cc84739c4e592989730bfb69b7c90a8fff138dff18e1e/pydantic_core-2.27.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = 
"sha256:f141ee28a0ad2123b6611b6ceff018039df17f32ada8b534e6aa039545a3efb2", size = 2152292 }, - { url = "https://files.pythonhosted.org/packages/a1/0c/c5c5cd3689c32ed1fe8c5d234b079c12c281c051759770c05b8bed6412b5/pydantic_core-2.27.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:7d0c8399fcc1848491f00e0314bd59fb34a9c008761bcb422a057670c3f65e35", size = 2004961 }, ] [[package]] @@ -2058,8 +2050,6 @@ source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/c3/7f/256f1954343fc44641d04292e1410470337db3720bd57b510782e449d6db/pyfarmhash-0.3.2.tar.gz", hash = "sha256:4146308a0ed0b37d69003199c90fa59b155666c9deb0249b40e594cee10551ea", size = 99890 } wheels = [ { url = "https://files.pythonhosted.org/packages/99/e7/e3c97a5ba709e28db06f89684ad54e740efcdf8235cecc9ae2626b3188d2/pyfarmhash-0.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:dc3ef74dc64a19bb325d85749e0a7955ebaa6777d7cc357bfa4ba6e5864a4362", size = 14375 }, - { url = "https://files.pythonhosted.org/packages/0e/4f/0c7dddbb43e6da3be80c52182555951636c541a2bad5d7a4418e59a6d6e3/pyfarmhash-0.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:00eadc04a0a0595fbf05bf430bac3baf9788e00b3abcdd26cd478b4b3c244837", size = 14408 }, - { url = "https://files.pythonhosted.org/packages/7e/d3/659f24a6636df197d804db194f764bd3489d037b66a06f4f750eb6b14e60/pyfarmhash-0.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:9c125ffdf317672996e63e98bf1e84d0829fc2a85db3304ca62f873767bc0abf", size = 14372 }, ] [[package]] @@ -2093,7 +2083,6 @@ dependencies = [ { name = "platformdirs" }, { name = "tomli", marker = "python_full_version < '3.11'" }, { name = "tomlkit" }, - { name = "typing-extensions", marker = "python_full_version < '3.10'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/aa/f7/325b71d78faf9fcf1c246669a2448356fe3d7d69c5f93d48f41cc241a6bb/pylint-3.0.0.tar.gz", hash = "sha256:d22816c963816d7810b87afe0bdf5c80009e1078ecbb9c8f2e2a24d4430039b1", size = 441234 } wheels = [ @@ -2144,8 +2133,8 
@@ name = "pytest-cov" version = "2.12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "coverage", version = "7.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "coverage", version = "7.6.12", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "coverage", version = "7.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "coverage", version = "7.6.12", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pytest" }, { name = "toml" }, ] @@ -2200,8 +2189,8 @@ dependencies = [ { name = "click" }, { name = "dotty-dict" }, { name = "gitpython" }, - { name = "importlib-resources", version = "6.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "importlib-resources", version = "6.5.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "importlib-resources", version = "6.4.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "importlib-resources", version = "6.5.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "jinja2" }, { name = "pydantic" }, { name = "python-gitlab" }, @@ -2257,22 +2246,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/1f/4f998c900485e5c0ef43838363ba4a9723ac0ad73a9dc42068b12aaba4e4/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b", size = 756611 }, { url = "https://files.pythonhosted.org/packages/df/d1/f5a275fdb252768b7a11ec63585bc38d0e87c9e05668a139fea92b80634c/PyYAML-6.0.2-cp312-cp312-win32.whl", hash = 
"sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4", size = 140591 }, { url = "https://files.pythonhosted.org/packages/0c/e8/4f648c598b17c3d06e8753d7d13d57542b30d56e6c2dedf9c331ae56312e/PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", size = 156338 }, - { url = "https://files.pythonhosted.org/packages/74/d9/323a59d506f12f498c2097488d80d16f4cf965cee1791eab58b56b19f47a/PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a", size = 183218 }, - { url = "https://files.pythonhosted.org/packages/74/cc/20c34d00f04d785f2028737e2e2a8254e1425102e730fee1d6396f832577/PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5", size = 728067 }, - { url = "https://files.pythonhosted.org/packages/20/52/551c69ca1501d21c0de51ddafa8c23a0191ef296ff098e98358f69080577/PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d", size = 757812 }, - { url = "https://files.pythonhosted.org/packages/fd/7f/2c3697bba5d4aa5cc2afe81826d73dfae5f049458e44732c7a0938baa673/PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083", size = 746531 }, - { url = "https://files.pythonhosted.org/packages/8c/ab/6226d3df99900e580091bb44258fde77a8433511a86883bd4681ea19a858/PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706", size = 800820 }, - { url = "https://files.pythonhosted.org/packages/a0/99/a9eb0f3e710c06c5d922026f6736e920d431812ace24aae38228d0d64b04/PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a", size = 145514 }, 
- { url = "https://files.pythonhosted.org/packages/75/8a/ee831ad5fafa4431099aa4e078d4c8efd43cd5e48fbc774641d233b683a9/PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff", size = 162702 }, - { url = "https://files.pythonhosted.org/packages/65/d8/b7a1db13636d7fb7d4ff431593c510c8b8fca920ade06ca8ef20015493c5/PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d", size = 184777 }, - { url = "https://files.pythonhosted.org/packages/0a/02/6ec546cd45143fdf9840b2c6be8d875116a64076218b61d68e12548e5839/PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f", size = 172318 }, - { url = "https://files.pythonhosted.org/packages/0e/9a/8cc68be846c972bda34f6c2a93abb644fb2476f4dcc924d52175786932c9/PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290", size = 720891 }, - { url = "https://files.pythonhosted.org/packages/e9/6c/6e1b7f40181bc4805e2e07f4abc10a88ce4648e7e95ff1abe4ae4014a9b2/PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12", size = 722614 }, - { url = "https://files.pythonhosted.org/packages/3d/32/e7bd8535d22ea2874cef6a81021ba019474ace0d13a4819c2a4bce79bd6a/PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19", size = 737360 }, - { url = "https://files.pythonhosted.org/packages/d7/12/7322c1e30b9be969670b672573d45479edef72c9a0deac3bb2868f5d7469/PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e", size = 699006 }, - { url = 
"https://files.pythonhosted.org/packages/82/72/04fcad41ca56491995076630c3ec1e834be241664c0c09a64c9a2589b507/PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725", size = 723577 }, - { url = "https://files.pythonhosted.org/packages/ed/5e/46168b1f2757f1fcd442bc3029cd8767d88a98c9c05770d8b420948743bb/PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631", size = 144593 }, - { url = "https://files.pythonhosted.org/packages/19/87/5124b1c1f2412bb95c59ec481eaf936cd32f0fe2a7b16b97b81c4c017a6a/PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8", size = 162312 }, ] [[package]] @@ -2339,38 +2312,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/93/2d/dd56bb76bd8e95bbce684326302f287455b56242a4f9c61f1bc76e28360e/regex-2024.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0a86e7eeca091c09e021db8eb72d54751e527fa47b8d5787caf96d9831bd02ad", size = 787692 }, { url = "https://files.pythonhosted.org/packages/0b/55/31877a249ab7a5156758246b9c59539abbeba22461b7d8adc9e8475ff73e/regex-2024.11.6-cp312-cp312-win32.whl", hash = "sha256:32f9a4c643baad4efa81d549c2aadefaeba12249b2adc5af541759237eee1c54", size = 262135 }, { url = "https://files.pythonhosted.org/packages/38/ec/ad2d7de49a600cdb8dd78434a1aeffe28b9d6fc42eb36afab4a27ad23384/regex-2024.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:a93c194e2df18f7d264092dc8539b8ffb86b45b899ab976aa15d48214138e81b", size = 273567 }, - { url = "https://files.pythonhosted.org/packages/44/0f/207b37e6e08d548fac0aa00bf0b7464126315d58ab5161216b8cb3abb2aa/regex-2024.11.6-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3a51ccc315653ba012774efca4f23d1d2a8a8f278a6072e29c7147eee7da446b", size = 482777 }, - { url = 
"https://files.pythonhosted.org/packages/5a/5a/586bafa294c5d2451265d3685815606c61e620f469cac3b946fff0a4aa48/regex-2024.11.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ad182d02e40de7459b73155deb8996bbd8e96852267879396fb274e8700190e3", size = 287751 }, - { url = "https://files.pythonhosted.org/packages/08/92/9df786fad8a4e0766bfc9a2e334c5f0757356070c9639b2ec776b8cdef3d/regex-2024.11.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ba9b72e5643641b7d41fa1f6d5abda2c9a263ae835b917348fc3c928182ad467", size = 284552 }, - { url = "https://files.pythonhosted.org/packages/0a/27/0b3cf7d9fbe43301aa3473d54406019a7380abe4e3c9ae250bac13c4fdb3/regex-2024.11.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40291b1b89ca6ad8d3f2b82782cc33807f1406cf68c8d440861da6304d8ffbbd", size = 783587 }, - { url = "https://files.pythonhosted.org/packages/89/38/499b32cbb61163af60a5c5ff26aacea7836fe7e3d821e76af216e996088c/regex-2024.11.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cdf58d0e516ee426a48f7b2c03a332a4114420716d55769ff7108c37a09951bf", size = 822904 }, - { url = "https://files.pythonhosted.org/packages/3f/a4/e3b11c643e5ae1059a08aeef971973f0c803d2a9ae2e7a86f97c68146a6c/regex-2024.11.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a36fdf2af13c2b14738f6e973aba563623cb77d753bbbd8d414d18bfaa3105dd", size = 809900 }, - { url = "https://files.pythonhosted.org/packages/5a/c8/dc7153ceb5bcc344f5c4f0291ea45925a5f00009afa3849e91561ac2e847/regex-2024.11.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1cee317bfc014c2419a76bcc87f071405e3966da434e03e13beb45f8aced1a6", size = 785105 }, - { url = "https://files.pythonhosted.org/packages/2a/29/841489ea52013062b22625fbaf49b0916aeb62bae2e56425ac30f9dead46/regex-2024.11.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:50153825ee016b91549962f970d6a4442fa106832e14c918acd1c8e479916c4f", size = 773033 }, - { url = "https://files.pythonhosted.org/packages/3e/4e/4a0da5e87f7c2dc73a8505785d5af2b1a19c66f4645b93caa50b7eb08242/regex-2024.11.6-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ea1bfda2f7162605f6e8178223576856b3d791109f15ea99a9f95c16a7636fb5", size = 702374 }, - { url = "https://files.pythonhosted.org/packages/94/6e/444e66346600d11e8a0f4bb31611973cffa772d5033ba1cf1f15de8a0d52/regex-2024.11.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:df951c5f4a1b1910f1a99ff42c473ff60f8225baa1cdd3539fe2819d9543e9df", size = 769990 }, - { url = "https://files.pythonhosted.org/packages/da/28/95c3ed6cd51b27f54e59940400e2a3ddd3f8bbbc3aaf947e57a67104ecbd/regex-2024.11.6-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:072623554418a9911446278f16ecb398fb3b540147a7828c06e2011fa531e773", size = 775345 }, - { url = "https://files.pythonhosted.org/packages/07/5d/0cd19cf44d96a7aa31526611c24235d21d27c23b65201cb2c5cac508dd42/regex-2024.11.6-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:f654882311409afb1d780b940234208a252322c24a93b442ca714d119e68086c", size = 840379 }, - { url = "https://files.pythonhosted.org/packages/2a/13/ec3f8d85b789ee1c6ffbdfd4092fd901416716317ee17bf51aa2890bac96/regex-2024.11.6-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:89d75e7293d2b3e674db7d4d9b1bee7f8f3d1609428e293771d1a962617150cc", size = 845842 }, - { url = "https://files.pythonhosted.org/packages/50/cb/7170247e65afea2bf9204bcb2682f292b0a3a57d112478da199b84d59792/regex-2024.11.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:f65557897fc977a44ab205ea871b690adaef6b9da6afda4790a2484b04293a5f", size = 775026 }, - { url = "https://files.pythonhosted.org/packages/cc/06/c817c9201f09b7d9dd033039ba90d8197c91e9fe2984141f2d1de270c159/regex-2024.11.6-cp38-cp38-win32.whl", hash = 
"sha256:6f44ec28b1f858c98d3036ad5d7d0bfc568bdd7a74f9c24e25f41ef1ebfd81a4", size = 261738 }, - { url = "https://files.pythonhosted.org/packages/cf/69/c39e16320400842eb4358c982ef5fc680800866f35ebfd4dd38a22967ce0/regex-2024.11.6-cp38-cp38-win_amd64.whl", hash = "sha256:bb8f74f2f10dbf13a0be8de623ba4f9491faf58c24064f32b65679b021ed0001", size = 274094 }, - { url = "https://files.pythonhosted.org/packages/89/23/c4a86df398e57e26f93b13ae63acce58771e04bdde86092502496fa57f9c/regex-2024.11.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5704e174f8ccab2026bd2f1ab6c510345ae8eac818b613d7d73e785f1310f839", size = 482682 }, - { url = "https://files.pythonhosted.org/packages/3c/8b/45c24ab7a51a1658441b961b86209c43e6bb9d39caf1e63f46ce6ea03bc7/regex-2024.11.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:220902c3c5cc6af55d4fe19ead504de80eb91f786dc102fbd74894b1551f095e", size = 287679 }, - { url = "https://files.pythonhosted.org/packages/7a/d1/598de10b17fdafc452d11f7dada11c3be4e379a8671393e4e3da3c4070df/regex-2024.11.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e7e351589da0850c125f1600a4c4ba3c722efefe16b297de54300f08d734fbf", size = 284578 }, - { url = "https://files.pythonhosted.org/packages/49/70/c7eaa219efa67a215846766fde18d92d54cb590b6a04ffe43cef30057622/regex-2024.11.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5056b185ca113c88e18223183aa1a50e66507769c9640a6ff75859619d73957b", size = 782012 }, - { url = "https://files.pythonhosted.org/packages/89/e5/ef52c7eb117dd20ff1697968219971d052138965a4d3d9b95e92e549f505/regex-2024.11.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e34b51b650b23ed3354b5a07aab37034d9f923db2a40519139af34f485f77d0", size = 820580 }, - { url = "https://files.pythonhosted.org/packages/5f/3f/9f5da81aff1d4167ac52711acf789df13e789fe6ac9545552e49138e3282/regex-2024.11.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:5670bce7b200273eee1840ef307bfa07cda90b38ae56e9a6ebcc9f50da9c469b", size = 809110 }, - { url = "https://files.pythonhosted.org/packages/86/44/2101cc0890c3621b90365c9ee8d7291a597c0722ad66eccd6ffa7f1bcc09/regex-2024.11.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08986dce1339bc932923e7d1232ce9881499a0e02925f7402fb7c982515419ef", size = 780919 }, - { url = "https://files.pythonhosted.org/packages/ce/2e/3e0668d8d1c7c3c0d397bf54d92fc182575b3a26939aed5000d3cc78760f/regex-2024.11.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:93c0b12d3d3bc25af4ebbf38f9ee780a487e8bf6954c115b9f015822d3bb8e48", size = 771515 }, - { url = "https://files.pythonhosted.org/packages/a6/49/1bc4584254355e3dba930a3a2fd7ad26ccba3ebbab7d9100db0aff2eedb0/regex-2024.11.6-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:764e71f22ab3b305e7f4c21f1a97e1526a25ebdd22513e251cf376760213da13", size = 696957 }, - { url = "https://files.pythonhosted.org/packages/c8/dd/42879c1fc8a37a887cd08e358af3d3ba9e23038cd77c7fe044a86d9450ba/regex-2024.11.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:f056bf21105c2515c32372bbc057f43eb02aae2fda61052e2f7622c801f0b4e2", size = 768088 }, - { url = "https://files.pythonhosted.org/packages/89/96/c05a0fe173cd2acd29d5e13c1adad8b706bcaa71b169e1ee57dcf2e74584/regex-2024.11.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:69ab78f848845569401469da20df3e081e6b5a11cb086de3eed1d48f5ed57c95", size = 774752 }, - { url = "https://files.pythonhosted.org/packages/b5/f3/a757748066255f97f14506483436c5f6aded7af9e37bca04ec30c90ca683/regex-2024.11.6-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:86fddba590aad9208e2fa8b43b4c098bb0ec74f15718bb6a704e3c63e2cef3e9", size = 838862 }, - { url = 
"https://files.pythonhosted.org/packages/5c/93/c6d2092fd479dcaeea40fc8fa673822829181ded77d294a7f950f1dda6e2/regex-2024.11.6-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:684d7a212682996d21ca12ef3c17353c021fe9de6049e19ac8481ec35574a70f", size = 842622 }, - { url = "https://files.pythonhosted.org/packages/ff/9c/daa99532c72f25051a90ef90e1413a8d54413a9e64614d9095b0c1c154d0/regex-2024.11.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a03e02f48cd1abbd9f3b7e3586d97c8f7a9721c436f51a5245b3b9483044480b", size = 772713 }, - { url = "https://files.pythonhosted.org/packages/13/5d/61a533ccb8c231b474ac8e3a7d70155b00dfc61af6cafdccd1947df6d735/regex-2024.11.6-cp39-cp39-win32.whl", hash = "sha256:41758407fc32d5c3c5de163888068cfee69cb4c2be844e7ac517a52770f9af57", size = 261756 }, - { url = "https://files.pythonhosted.org/packages/dc/7b/e59b7f7c91ae110d154370c24133f947262525b5d6406df65f23422acc17/regex-2024.11.6-cp39-cp39-win_amd64.whl", hash = "sha256:b2837718570f95dd41675328e111345f9b7095d821bac435aac173ac80b19983", size = 274110 }, ] [[package]] @@ -2381,27 +2322,14 @@ dependencies = [ { name = "certifi" }, { name = "charset-normalizer" }, { name = "idna" }, - { name = "urllib3", version = "2.2.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "urllib3", version = "2.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "urllib3", version = "2.2.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "urllib3", version = "2.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218 } wheels = [ { url = 
"https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, ] -[[package]] -name = "requests-oauthlib" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "oauthlib", marker = "python_full_version < '3.9'" }, - { name = "requests", marker = "python_full_version < '3.9'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/42/f2/05f29bc3913aea15eb670be136045bf5c5bbf4b99ecb839da9b422bb2c85/requests-oauthlib-2.0.0.tar.gz", hash = "sha256:b3dffaebd884d8cd778494369603a9e7b58d29111bf6b41bdc2dcd87203af4e9", size = 55650 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/5d/63d4ae3b9daea098d5d6f5da83984853c1bbacd5dc826764b249fe119d24/requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36", size = 24179 }, -] - [[package]] name = "requests-toolbelt" version = "1.0.0" @@ -2428,110 +2356,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424 }, ] -[[package]] -name = "rsa" -version = "4.9" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyasn1", marker = "python_full_version < '3.9'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/aa/65/7d973b89c4d2351d7fb232c2e452547ddfa243e93131e7cfa766da627b52/rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21", size = 29711 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/49/97/fa78e3d2f65c02c8e1268b9aba606569fe97f6c8f7c2d74394553347c145/rsa-4.9-py3-none-any.whl", hash = 
"sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7", size = 34315 }, -] - -[[package]] -name = "scikit-learn" -version = "1.3.2" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.9'", -] -dependencies = [ - { name = "joblib", marker = "python_full_version < '3.9'" }, - { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "scipy", version = "1.10.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "threadpoolctl", marker = "python_full_version < '3.9'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/88/00/835e3d280fdd7784e76bdef91dd9487582d7951a7254f59fc8004fc8b213/scikit-learn-1.3.2.tar.gz", hash = "sha256:a2f54c76accc15a34bfb9066e6c7a56c1e7235dda5762b990792330b52ccfb05", size = 7510251 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/0d/53/570b55a6e10b8694ac1e3024d2df5cd443f1b4ff6d28430845da8b9019b3/scikit_learn-1.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e326c0eb5cf4d6ba40f93776a20e9a7a69524c4db0757e7ce24ba222471ee8a1", size = 10209999 }, - { url = "https://files.pythonhosted.org/packages/70/d0/50ace22129f79830e3cf682d0a2bd4843ef91573299d43112d52790163a8/scikit_learn-1.3.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:535805c2a01ccb40ca4ab7d081d771aea67e535153e35a1fd99418fcedd1648a", size = 9479353 }, - { url = "https://files.pythonhosted.org/packages/8f/46/fcc35ed7606c50d3072eae5a107a45cfa5b7f5fa8cc48610edd8cc8e8550/scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1215e5e58e9880b554b01187b8c9390bf4dc4692eedeaf542d3273f4785e342c", size = 10304705 }, - { url = "https://files.pythonhosted.org/packages/d0/0b/26ad95cf0b747be967b15fb71a06f5ac67aba0fd2f9cd174de6edefc4674/scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:0ee107923a623b9f517754ea2f69ea3b62fc898a3641766cb7deb2f2ce450161", size = 10827807 }, - { url = "https://files.pythonhosted.org/packages/69/8a/cf17d6443f5f537e099be81535a56ab68a473f9393fbffda38cd19899fc8/scikit_learn-1.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:35a22e8015048c628ad099da9df5ab3004cdbf81edc75b396fd0cff8699ac58c", size = 9255427 }, - { url = "https://files.pythonhosted.org/packages/08/5d/e5acecd6e99a6b656e42e7a7b18284e2f9c9f512e8ed6979e1e75d25f05f/scikit_learn-1.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6fb6bc98f234fda43163ddbe36df8bcde1d13ee176c6dc9b92bb7d3fc842eb66", size = 10116376 }, - { url = "https://files.pythonhosted.org/packages/40/c6/2e91eefb757822e70d351e02cc38d07c137212ae7c41ac12746415b4860a/scikit_learn-1.3.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:18424efee518a1cde7b0b53a422cde2f6625197de6af36da0b57ec502f126157", size = 9383415 }, - { url = "https://files.pythonhosted.org/packages/fa/fd/b3637639e73bb72b12803c5245f2a7299e09b2acd85a0f23937c53369a1c/scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3271552a5eb16f208a6f7f617b8cc6d1f137b52c8a1ef8edf547db0259b2c9fb", size = 10279163 }, - { url = "https://files.pythonhosted.org/packages/0c/2a/d3ff6091406bc2207e0adb832ebd15e40ac685811c7e2e3b432bfd969b71/scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc4144a5004a676d5022b798d9e573b05139e77f271253a4703eed295bde0433", size = 10884422 }, - { url = "https://files.pythonhosted.org/packages/4e/ba/ce9bd1cd4953336a0e213b29cb80bb11816f2a93de8c99f88ef0b446ad0c/scikit_learn-1.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:67f37d708f042a9b8d59551cf94d30431e01374e00dc2645fa186059c6c5d78b", size = 9207060 }, - { url = "https://files.pythonhosted.org/packages/26/7e/2c3b82c8c29aa384c8bf859740419278627d2cdd0050db503c8840e72477/scikit_learn-1.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = 
"sha256:8db94cd8a2e038b37a80a04df8783e09caac77cbe052146432e67800e430c028", size = 9979322 }, - { url = "https://files.pythonhosted.org/packages/cf/fc/6c52ffeb587259b6b893b7cac268f1eb1b5426bcce1aa20e53523bfe6944/scikit_learn-1.3.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:61a6efd384258789aa89415a410dcdb39a50e19d3d8410bd29be365bcdd512d5", size = 9270688 }, - { url = "https://files.pythonhosted.org/packages/e5/a7/6f4ae76f72ae9de162b97acbf1f53acbe404c555f968d13da21e4112a002/scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb06f8dce3f5ddc5dee1715a9b9f19f20d295bed8e3cd4fa51e1d050347de525", size = 10280398 }, - { url = "https://files.pythonhosted.org/packages/5d/b7/ee35904c07a0666784349529412fbb9814a56382b650d30fd9d6be5e5054/scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b2de18d86f630d68fe1f87af690d451388bb186480afc719e5f770590c2ef6c", size = 10796478 }, - { url = "https://files.pythonhosted.org/packages/fe/6b/db949ed5ac367987b1f250f070f340b7715d22f0c9c965bdf07de6ca75a3/scikit_learn-1.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:0402638c9a7c219ee52c94cbebc8fcb5eb9fe9c773717965c1f4185588ad3107", size = 9133979 }, - { url = "https://files.pythonhosted.org/packages/e3/52/fd60b0b022af41fbf3463587ddc719288f0f2d4e46603ab3184996cd5f04/scikit_learn-1.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a19f90f95ba93c1a7f7924906d0576a84da7f3b2282ac3bfb7a08a32801add93", size = 10064879 }, - { url = "https://files.pythonhosted.org/packages/a4/62/92e9cec3deca8b45abf62dd8f6469d688b3f28b9c170809fcc46f110b523/scikit_learn-1.3.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b8692e395a03a60cd927125eef3a8e3424d86dde9b2370d544f0ea35f78a8073", size = 9373934 }, - { url = "https://files.pythonhosted.org/packages/49/81/91585dc83ec81dcd52e934f6708bf350b06949d8bfa13bf3b711b851c3f4/scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:15e1e94cc23d04d39da797ee34236ce2375ddea158b10bee3c343647d615581d", size = 10499159 }, - { url = "https://files.pythonhosted.org/packages/3f/48/6fdd99f5717045f9984616b5c2ec683d6286d30c0ac234563062132b83ab/scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:785a2213086b7b1abf037aeadbbd6d67159feb3e30263434139c98425e3dcfcf", size = 11067392 }, - { url = "https://files.pythonhosted.org/packages/52/2d/ad6928a578c78bb0e44e34a5a922818b14c56716b81d145924f1f291416f/scikit_learn-1.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:64381066f8aa63c2710e6b56edc9f0894cc7bf59bd71b8ce5613a4559b6145e0", size = 9257871 }, - { url = "https://files.pythonhosted.org/packages/f8/67/584acfc492ae1bd293d80c7a8c57ba7456e4e415c64869b7c240679eaf78/scikit_learn-1.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6c43290337f7a4b969d207e620658372ba3c1ffb611f8bc2b6f031dc5c6d1d03", size = 10232286 }, - { url = "https://files.pythonhosted.org/packages/20/0f/51e3ccdc87c25e2e33bf7962249ff8c5ab1d6aed0144fb003348ce8bd352/scikit_learn-1.3.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:dc9002fc200bed597d5d34e90c752b74df516d592db162f756cc52836b38fe0e", size = 9504918 }, - { url = "https://files.pythonhosted.org/packages/61/2e/5bbf3c9689d2911b65297fb5861c4257e54c797b3158c9fca8a5c576644b/scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d08ada33e955c54355d909b9c06a4789a729977f165b8bae6f225ff0a60ec4a", size = 10358127 }, - { url = "https://files.pythonhosted.org/packages/25/89/dce01a35d354159dcc901e3c7e7eb3fe98de5cb3639c6cd39518d8830caa/scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f0ae4b79b0ff9cca0bf3716bcc9915bdacff3cebea15ec79652d1cc4fa5c9", size = 10890482 }, - { url = "https://files.pythonhosted.org/packages/1c/49/30ffcac5af06d08dfdd27da322ce31a373b733711bb272941877c1e4794a/scikit_learn-1.3.2-cp39-cp39-win_amd64.whl", hash = 
"sha256:ed932ea780517b00dae7431e031faae6b49b20eb6950918eb83bd043237950e0", size = 9331050 }, -] - -[[package]] -name = "scikit-learn" -version = "1.6.1" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", -] -dependencies = [ - { name = "joblib", marker = "python_full_version >= '3.9'" }, - { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "scipy", version = "1.13.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, - { name = "scipy", version = "1.15.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, - { name = "threadpoolctl", marker = "python_full_version >= '3.9'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9e/a5/4ae3b3a0755f7b35a280ac90b28817d1f380318973cff14075ab41ef50d9/scikit_learn-1.6.1.tar.gz", hash = "sha256:b4fc2525eca2c69a59260f583c56a7557c6ccdf8deafdba6e060f94c1c59738e", size = 7068312 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/2e/3a/f4597eb41049110b21ebcbb0bcb43e4035017545daa5eedcfeb45c08b9c5/scikit_learn-1.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d056391530ccd1e501056160e3c9673b4da4805eb67eb2bdf4e983e1f9c9204e", size = 12067702 }, - { url = "https://files.pythonhosted.org/packages/37/19/0423e5e1fd1c6ec5be2352ba05a537a473c1677f8188b9306097d684b327/scikit_learn-1.6.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0c8d036eb937dbb568c6242fa598d551d88fb4399c0344d95c001980ec1c7d36", size = 11112765 }, - { url = "https://files.pythonhosted.org/packages/70/95/d5cb2297a835b0f5fc9a77042b0a2d029866379091ab8b3f52cc62277808/scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:8634c4bd21a2a813e0a7e3900464e6d593162a29dd35d25bdf0103b3fce60ed5", size = 12643991 }, - { url = "https://files.pythonhosted.org/packages/b7/91/ab3c697188f224d658969f678be86b0968ccc52774c8ab4a86a07be13c25/scikit_learn-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:775da975a471c4f6f467725dff0ced5c7ac7bda5e9316b260225b48475279a1b", size = 13497182 }, - { url = "https://files.pythonhosted.org/packages/17/04/d5d556b6c88886c092cc989433b2bab62488e0f0dafe616a1d5c9cb0efb1/scikit_learn-1.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:8a600c31592bd7dab31e1c61b9bbd6dea1b3433e67d264d17ce1017dbdce8002", size = 11125517 }, - { url = "https://files.pythonhosted.org/packages/6c/2a/e291c29670795406a824567d1dfc91db7b699799a002fdaa452bceea8f6e/scikit_learn-1.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:72abc587c75234935e97d09aa4913a82f7b03ee0b74111dcc2881cba3c5a7b33", size = 12102620 }, - { url = "https://files.pythonhosted.org/packages/25/92/ee1d7a00bb6b8c55755d4984fd82608603a3cc59959245068ce32e7fb808/scikit_learn-1.6.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b3b00cdc8f1317b5f33191df1386c0befd16625f49d979fe77a8d44cae82410d", size = 11116234 }, - { url = "https://files.pythonhosted.org/packages/30/cd/ed4399485ef364bb25f388ab438e3724e60dc218c547a407b6e90ccccaef/scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc4765af3386811c3ca21638f63b9cf5ecf66261cc4815c1db3f1e7dc7b79db2", size = 12592155 }, - { url = "https://files.pythonhosted.org/packages/a8/f3/62fc9a5a659bb58a03cdd7e258956a5824bdc9b4bb3c5d932f55880be569/scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25fc636bdaf1cc2f4a124a116312d837148b5e10872147bdaf4887926b8c03d8", size = 13497069 }, - { url = "https://files.pythonhosted.org/packages/a1/a6/c5b78606743a1f28eae8f11973de6613a5ee87366796583fb74c67d54939/scikit_learn-1.6.1-cp311-cp311-win_amd64.whl", hash = 
"sha256:fa909b1a36e000a03c382aade0bd2063fd5680ff8b8e501660c0f59f021a6415", size = 11139809 }, - { url = "https://files.pythonhosted.org/packages/0a/18/c797c9b8c10380d05616db3bfb48e2a3358c767affd0857d56c2eb501caa/scikit_learn-1.6.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:926f207c804104677af4857b2c609940b743d04c4c35ce0ddc8ff4f053cddc1b", size = 12104516 }, - { url = "https://files.pythonhosted.org/packages/c4/b7/2e35f8e289ab70108f8cbb2e7a2208f0575dc704749721286519dcf35f6f/scikit_learn-1.6.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2c2cae262064e6a9b77eee1c8e768fc46aa0b8338c6a8297b9b6759720ec0ff2", size = 11167837 }, - { url = "https://files.pythonhosted.org/packages/a4/f6/ff7beaeb644bcad72bcfd5a03ff36d32ee4e53a8b29a639f11bcb65d06cd/scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1061b7c028a8663fb9a1a1baf9317b64a257fcb036dae5c8752b2abef31d136f", size = 12253728 }, - { url = "https://files.pythonhosted.org/packages/29/7a/8bce8968883e9465de20be15542f4c7e221952441727c4dad24d534c6d99/scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e69fab4ebfc9c9b580a7a80111b43d214ab06250f8a7ef590a4edf72464dd86", size = 13147700 }, - { url = "https://files.pythonhosted.org/packages/62/27/585859e72e117fe861c2079bcba35591a84f801e21bc1ab85bce6ce60305/scikit_learn-1.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:70b1d7e85b1c96383f872a519b3375f92f14731e279a7b4c6cfd650cf5dffc52", size = 11110613 }, - { url = "https://files.pythonhosted.org/packages/d2/37/b305b759cc65829fe1b8853ff3e308b12cdd9d8884aa27840835560f2b42/scikit_learn-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6849dd3234e87f55dce1db34c89a810b489ead832aaf4d4550b7ea85628be6c1", size = 12101868 }, - { url = "https://files.pythonhosted.org/packages/83/74/f64379a4ed5879d9db744fe37cfe1978c07c66684d2439c3060d19a536d8/scikit_learn-1.6.1-cp39-cp39-macosx_12_0_arm64.whl", hash = 
"sha256:e7be3fa5d2eb9be7d77c3734ff1d599151bb523674be9b834e8da6abe132f44e", size = 11144062 }, - { url = "https://files.pythonhosted.org/packages/fd/dc/d5457e03dc9c971ce2b0d750e33148dd060fefb8b7dc71acd6054e4bb51b/scikit_learn-1.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:44a17798172df1d3c1065e8fcf9019183f06c87609b49a124ebdf57ae6cb0107", size = 12693173 }, - { url = "https://files.pythonhosted.org/packages/79/35/b1d2188967c3204c78fa79c9263668cf1b98060e8e58d1a730fe5b2317bb/scikit_learn-1.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8b7a3b86e411e4bce21186e1c180d792f3d99223dcfa3b4f597ecc92fa1a422", size = 13518605 }, - { url = "https://files.pythonhosted.org/packages/fb/d8/8d603bdd26601f4b07e2363032b8565ab82eb857f93d86d0f7956fcf4523/scikit_learn-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:7a73d457070e3318e32bdb3aa79a8d990474f19035464dfd8bede2883ab5dc3b", size = 11155078 }, -] - [[package]] name = "scipy" version = "1.10.1" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] dependencies = [ - { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/84/a9/2bf119f3f9cff1f376f924e39cfae18dec92a1514784046d185731301281/scipy-1.10.1.tar.gz", hash = "sha256:2cf9dfb80a7b4589ba4c40ce7588986d6d5cebc5457cad2c2880f6bc2d42f3a5", size = 42407997 } wheels = [ @@ -2545,54 +2378,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a5/3d/b69746c50e44893da57a68457da3d7e5bb75f6a37fbace3769b70d017488/scipy-1.10.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaea0a6be54462ec027de54fca511540980d1e9eea68b2d5c1dbfe084797be35", 
size = 30687257 }, { url = "https://files.pythonhosted.org/packages/21/cd/fe2d4af234b80dc08c911ce63fdaee5badcdde3e9bcd9a68884580652ef0/scipy-1.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15a35c4242ec5f292c3dd364a7c71a61be87a3d4ddcc693372813c0b73c9af1d", size = 34124096 }, { url = "https://files.pythonhosted.org/packages/65/76/903324159e4a3566e518c558aeb21571d642f781d842d8dd0fd9c6b0645a/scipy-1.10.1-cp311-cp311-win_amd64.whl", hash = "sha256:43b8e0bcb877faf0abfb613d51026cd5cc78918e9530e375727bf0625c82788f", size = 42238704 }, - { url = "https://files.pythonhosted.org/packages/a0/e3/37508a11dae501349d7c16e4dd18c706a023629eedc650ee094593887a89/scipy-1.10.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5678f88c68ea866ed9ebe3a989091088553ba12c6090244fdae3e467b1139c35", size = 35041063 }, - { url = "https://files.pythonhosted.org/packages/93/4a/50c436de1353cce8b66b26e49a687f10b91fe7465bf34e4565d810153003/scipy-1.10.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:39becb03541f9e58243f4197584286e339029e8908c46f7221abeea4b749fa88", size = 28797694 }, - { url = "https://files.pythonhosted.org/packages/d2/b5/ff61b79ad0ebd15d87ade10e0f4e80114dd89fac34a5efade39e99048c91/scipy-1.10.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bce5869c8d68cf383ce240e44c1d9ae7c06078a9396df68ce88a1230f93a30c1", size = 31024657 }, - { url = "https://files.pythonhosted.org/packages/69/f0/fb07a9548e48b687b8bf2fa81d71aba9cfc548d365046ca1c791e24db99d/scipy-1.10.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07c3457ce0b3ad5124f98a86533106b643dd811dd61b548e78cf4c8786652f6f", size = 34540352 }, - { url = "https://files.pythonhosted.org/packages/32/8e/7f403535ddf826348c9b8417791e28712019962f7e90ff845896d6325d09/scipy-1.10.1-cp38-cp38-win_amd64.whl", hash = "sha256:049a8bbf0ad95277ffba9b3b7d23e5369cc39e66406d60422c8cfef40ccc8415", size = 42215036 }, - { url = 
"https://files.pythonhosted.org/packages/d9/7d/78b8035bc93c869b9f17261c87aae97a9cdb937f65f0d453c2831aa172fc/scipy-1.10.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cd9f1027ff30d90618914a64ca9b1a77a431159df0e2a195d8a9e8a04c78abf9", size = 35158611 }, - { url = "https://files.pythonhosted.org/packages/e7/f0/55d81813b1a4cb79ce7dc8290eac083bf38bfb36e1ada94ea13b7b1a5f79/scipy-1.10.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:79c8e5a6c6ffaf3a2262ef1be1e108a035cf4f05c14df56057b64acc5bebffb6", size = 28902591 }, - { url = "https://files.pythonhosted.org/packages/77/d1/722c457b319eed1d642e0a14c9be37eb475f0e6ed1f3401fa480d5d6d36e/scipy-1.10.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51af417a000d2dbe1ec6c372dfe688e041a7084da4fdd350aeb139bd3fb55353", size = 30960654 }, - { url = "https://files.pythonhosted.org/packages/5d/30/b2a2a5bf1a3beefb7609fb871dcc6aef7217c69cef19a4631b7ab5622a8a/scipy-1.10.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b4735d6c28aad3cdcf52117e0e91d6b39acd4272f3f5cd9907c24ee931ad601", size = 34458863 }, - { url = "https://files.pythonhosted.org/packages/35/20/0ec6246bbb43d18650c9a7cad6602e1a84fd8f9564a9b84cc5faf1e037d0/scipy-1.10.1-cp39-cp39-win_amd64.whl", hash = "sha256:7ff7f37b1bf4417baca958d254e8e2875d0cc23aaadbe65b3d5b3077b0eb23ea", size = 42509516 }, -] - -[[package]] -name = "scipy" -version = "1.13.1" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version == '3.9.*'", -] -dependencies = [ - { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.9.*'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ae/00/48c2f661e2816ccf2ecd77982f6605b2950afe60f60a52b4cbbc2504aa8f/scipy-1.13.1.tar.gz", hash = "sha256:095a87a0312b08dfd6a6155cbbd310a8c51800fc931b8c0b84003014b874ed3c", size = 57210720 } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/33/59/41b2529908c002ade869623b87eecff3e11e3ce62e996d0bdcb536984187/scipy-1.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:20335853b85e9a49ff7572ab453794298bcf0354d8068c5f6775a0eabf350aca", size = 39328076 }, - { url = "https://files.pythonhosted.org/packages/d5/33/f1307601f492f764062ce7dd471a14750f3360e33cd0f8c614dae208492c/scipy-1.13.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d605e9c23906d1994f55ace80e0125c587f96c020037ea6aa98d01b4bd2e222f", size = 30306232 }, - { url = "https://files.pythonhosted.org/packages/c0/66/9cd4f501dd5ea03e4a4572ecd874936d0da296bd04d1c45ae1a4a75d9c3a/scipy-1.13.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cfa31f1def5c819b19ecc3a8b52d28ffdcc7ed52bb20c9a7589669dd3c250989", size = 33743202 }, - { url = "https://files.pythonhosted.org/packages/a3/ba/7255e5dc82a65adbe83771c72f384d99c43063648456796436c9a5585ec3/scipy-1.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26264b282b9da0952a024ae34710c2aff7d27480ee91a2e82b7b7073c24722f", size = 38577335 }, - { url = "https://files.pythonhosted.org/packages/49/a5/bb9ded8326e9f0cdfdc412eeda1054b914dfea952bda2097d174f8832cc0/scipy-1.13.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:eccfa1906eacc02de42d70ef4aecea45415f5be17e72b61bafcfd329bdc52e94", size = 38820728 }, - { url = "https://files.pythonhosted.org/packages/12/30/df7a8fcc08f9b4a83f5f27cfaaa7d43f9a2d2ad0b6562cced433e5b04e31/scipy-1.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:2831f0dc9c5ea9edd6e51e6e769b655f08ec6db6e2e10f86ef39bd32eb11da54", size = 46210588 }, - { url = "https://files.pythonhosted.org/packages/b4/15/4a4bb1b15bbd2cd2786c4f46e76b871b28799b67891f23f455323a0cdcfb/scipy-1.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:27e52b09c0d3a1d5b63e1105f24177e544a222b43611aaf5bc44d4a0979e32f9", size = 39333805 }, - { url = 
"https://files.pythonhosted.org/packages/ba/92/42476de1af309c27710004f5cdebc27bec62c204db42e05b23a302cb0c9a/scipy-1.13.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:54f430b00f0133e2224c3ba42b805bfd0086fe488835effa33fa291561932326", size = 30317687 }, - { url = "https://files.pythonhosted.org/packages/80/ba/8be64fe225360a4beb6840f3cbee494c107c0887f33350d0a47d55400b01/scipy-1.13.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e89369d27f9e7b0884ae559a3a956e77c02114cc60a6058b4e5011572eea9299", size = 33694638 }, - { url = "https://files.pythonhosted.org/packages/36/07/035d22ff9795129c5a847c64cb43c1fa9188826b59344fee28a3ab02e283/scipy-1.13.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a78b4b3345f1b6f68a763c6e25c0c9a23a9fd0f39f5f3d200efe8feda560a5fa", size = 38569931 }, - { url = "https://files.pythonhosted.org/packages/d9/10/f9b43de37e5ed91facc0cfff31d45ed0104f359e4f9a68416cbf4e790241/scipy-1.13.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:45484bee6d65633752c490404513b9ef02475b4284c4cfab0ef946def50b3f59", size = 38838145 }, - { url = "https://files.pythonhosted.org/packages/4a/48/4513a1a5623a23e95f94abd675ed91cfb19989c58e9f6f7d03990f6caf3d/scipy-1.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:5713f62f781eebd8d597eb3f88b8bf9274e79eeabf63afb4a737abc6c84ad37b", size = 46196227 }, - { url = "https://files.pythonhosted.org/packages/f2/7b/fb6b46fbee30fc7051913068758414f2721003a89dd9a707ad49174e3843/scipy-1.13.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5d72782f39716b2b3509cd7c33cdc08c96f2f4d2b06d51e52fb45a19ca0c86a1", size = 39357301 }, - { url = "https://files.pythonhosted.org/packages/dc/5a/2043a3bde1443d94014aaa41e0b50c39d046dda8360abd3b2a1d3f79907d/scipy-1.13.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:017367484ce5498445aade74b1d5ab377acdc65e27095155e448c88497755a5d", size = 30363348 }, - { url = 
"https://files.pythonhosted.org/packages/e7/cb/26e4a47364bbfdb3b7fb3363be6d8a1c543bcd70a7753ab397350f5f189a/scipy-1.13.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:949ae67db5fa78a86e8fa644b9a6b07252f449dcf74247108c50e1d20d2b4627", size = 33406062 }, - { url = "https://files.pythonhosted.org/packages/88/ab/6ecdc526d509d33814835447bbbeedbebdec7cca46ef495a61b00a35b4bf/scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3ade0e53bc1f21358aa74ff4830235d716211d7d077e340c7349bc3542e884", size = 38218311 }, - { url = "https://files.pythonhosted.org/packages/0b/00/9f54554f0f8318100a71515122d8f4f503b1a2c4b4cfab3b4b68c0eb08fa/scipy-1.13.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2ac65fb503dad64218c228e2dc2d0a0193f7904747db43014645ae139c8fad16", size = 38442493 }, - { url = "https://files.pythonhosted.org/packages/3e/df/963384e90733e08eac978cd103c34df181d1fec424de383cdc443f418dd4/scipy-1.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:cdd7dacfb95fea358916410ec61bbc20440f7860333aee6d882bb8046264e949", size = 45910955 }, - { url = "https://files.pythonhosted.org/packages/7f/29/c2ea58c9731b9ecb30b6738113a95d147e83922986b34c685b8f6eefde21/scipy-1.13.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:436bbb42a94a8aeef855d755ce5a465479c721e9d684de76bf61a62e7c2b81d5", size = 39352927 }, - { url = "https://files.pythonhosted.org/packages/5c/c0/e71b94b20ccf9effb38d7147c0064c08c622309fd487b1b677771a97d18c/scipy-1.13.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8335549ebbca860c52bf3d02f80784e91a004b71b059e3eea9678ba994796a24", size = 30324538 }, - { url = "https://files.pythonhosted.org/packages/6d/0f/aaa55b06d474817cea311e7b10aab2ea1fd5d43bc6a2861ccc9caec9f418/scipy-1.13.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d533654b7d221a6a97304ab63c41c96473ff04459e404b83275b60aa8f4b7004", size = 33732190 }, - { url = 
"https://files.pythonhosted.org/packages/35/f5/d0ad1a96f80962ba65e2ce1de6a1e59edecd1f0a7b55990ed208848012e0/scipy-1.13.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:637e98dcf185ba7f8e663e122ebf908c4702420477ae52a04f9908707456ba4d", size = 38612244 }, - { url = "https://files.pythonhosted.org/packages/8d/02/1165905f14962174e6569076bcc3315809ae1291ed14de6448cc151eedfd/scipy-1.13.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a014c2b3697bde71724244f63de2476925596c24285c7a637364761f8710891c", size = 38845637 }, - { url = "https://files.pythonhosted.org/packages/3e/77/dab54fe647a08ee4253963bcd8f9cf17509c8ca64d6335141422fe2e2114/scipy-1.13.1-cp39-cp39-win_amd64.whl", hash = "sha256:392e4ec766654852c25ebad4f64e4e584cf19820b980bc04960bca0b0cd6eaa2", size = 46227440 }, ] [[package]] @@ -2602,10 +2387,9 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", ] dependencies = [ - { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b7/b9/31ba9cd990e626574baf93fbc1ac61cf9ed54faafd04c479117517661637/scipy-1.15.2.tar.gz", hash = "sha256:cd58a314d92838f7e6f755c8a2167ead4f27e1fd5c1251fd54289569ef3495ec", size = 59417316 } wheels = [ @@ -2643,7 +2427,7 @@ name = "setuptools" version = "75.3.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] sdist = { url = "https://files.pythonhosted.org/packages/ed/22/a438e0caa4576f8c383fa4d35f1cc01655a46c75be358960d815bfbb12bd/setuptools-75.3.0.tar.gz", hash = 
"sha256:fba5dd4d766e97be1b1681d98712680ae8f2f26d7881245f2ce9e40714f1a686", size = 1351577 } wheels = [ @@ -2657,8 +2441,6 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/76/95/faf61eb8363f26aa7e1d762267a8d602a1b26d4f3a1e758e92cb3cb8b054/setuptools-80.10.2.tar.gz", hash = "sha256:8b0e9d10c784bf7d262c4e5ec5d4ec94127ce206e8738f29a437945fbc219b70", size = 1200343 } wheels = [ @@ -2693,171 +2475,81 @@ wheels = [ ] [[package]] -name = "tensorboard" -version = "2.11.2" +name = "sympy" +version = "1.14.0" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.9'", -] dependencies = [ - { name = "absl-py", marker = "python_full_version < '3.9'" }, - { name = "google-auth", marker = "python_full_version < '3.9'" }, - { name = "google-auth-oauthlib", marker = "python_full_version < '3.9'" }, - { name = "grpcio", marker = "python_full_version < '3.9'" }, - { name = "markdown", marker = "python_full_version < '3.9'" }, - { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "protobuf", version = "3.19.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "requests", marker = "python_full_version < '3.9'" }, - { name = "setuptools", version = "75.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "tensorboard-data-server", version = "0.6.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "tensorboard-plugin-wit", marker = "python_full_version < '3.9'" }, - { name = "werkzeug", version = "3.0.6", source = { registry = "https://pypi.org/simple" }, marker = 
"python_full_version < '3.9'" }, - { name = "wheel", marker = "python_full_version < '3.9'" }, + { name = "mpmath" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921 } wheels = [ - { url = "https://files.pythonhosted.org/packages/6f/77/e624b4916531721e674aa105151ffa5223fb224d3ca4bd5c10574664f944/tensorboard-2.11.2-py3-none-any.whl", hash = "sha256:cbaa2210c375f3af1509f8571360a19ccc3ded1d9641533414874b5deca47e89", size = 5992449 }, + { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353 }, ] [[package]] name = "tensorboard" version = "2.16.2" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", -] dependencies = [ - { name = "absl-py", marker = "python_full_version >= '3.9'" }, - { name = "grpcio", marker = "python_full_version >= '3.9'" }, - { name = "markdown", marker = "python_full_version >= '3.9'" }, - { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "protobuf", version = "4.25.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "setuptools", version = "80.10.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "six", marker = "python_full_version >= '3.9'" }, - { name = "tensorboard-data-server", version = "0.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { 
name = "werkzeug", version = "3.1.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "absl-py" }, + { name = "grpcio" }, + { name = "markdown" }, + { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "protobuf" }, + { name = "setuptools", version = "75.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "setuptools", version = "80.10.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "six" }, + { name = "tensorboard-data-server" }, + { name = "werkzeug", version = "3.0.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "werkzeug", version = "3.1.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/3a/d0/b97889ffa769e2d1fdebb632084d5e8b53fc299d43a537acee7ec0c021a3/tensorboard-2.16.2-py3-none-any.whl", hash = "sha256:9f2b4e7dad86667615c0e5cd072f1ea8403fc032a299f0072d6f74855775cc45", size = 5490335 }, ] -[[package]] -name = "tensorboard-data-server" -version = "0.6.1" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.9'", -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/74/69/5747a957f95e2e1d252ca41476ae40ce79d70d38151d2e494feb7722860c/tensorboard_data_server-0.6.1-py3-none-any.whl", hash = "sha256:809fe9887682d35c1f7d1f54f0f40f98bb1f771b14265b453ca051e2ce58fca7", size = 2350 }, - { url = 
"https://files.pythonhosted.org/packages/3e/48/dd135dbb3cf16bfb923720163493cab70e7336db4b5f3103d49efa730404/tensorboard_data_server-0.6.1-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:fa8cef9be4fcae2f2363c88176638baf2da19c5ec90addb49b1cde05c95c88ee", size = 3546350 }, - { url = "https://files.pythonhosted.org/packages/60/f9/802efd84988bffd9f644c03b6e66fde8e76c3aa33db4279ddd11c5d61f4b/tensorboard_data_server-0.6.1-py3-none-manylinux2010_x86_64.whl", hash = "sha256:d8237580755e58eff68d1f3abefb5b1e39ae5c8b127cc40920f9c4fb33f4b98a", size = 4910134 }, -] - [[package]] name = "tensorboard-data-server" version = "0.7.2" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", -] wheels = [ { url = "https://files.pythonhosted.org/packages/7a/13/e503968fefabd4c6b2650af21e110aa8466fe21432cd7c43a84577a89438/tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb", size = 2356 }, { url = "https://files.pythonhosted.org/packages/b7/85/dabeaf902892922777492e1d253bb7e1264cadce3cea932f7ff599e53fea/tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60", size = 4823598 }, { url = "https://files.pythonhosted.org/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530", size = 6590363 }, ] -[[package]] -name = "tensorboard-plugin-wit" -version = "1.8.1" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e0/68/e8ecfac5dd594b676c23a7f07ea34c197d7d69b3313afdf8ac1b0a9905a2/tensorboard_plugin_wit-1.8.1-py3-none-any.whl", hash = 
"sha256:ff26bdd583d155aa951ee3b152b3d0cffae8005dc697f72b44a8e8c2a77a8cbe", size = 781327 }, -] - -[[package]] -name = "tensorflow" -version = "2.11.1" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.9'", -] -dependencies = [ - { name = "absl-py", marker = "python_full_version < '3.9'" }, - { name = "astunparse", marker = "python_full_version < '3.9'" }, - { name = "flatbuffers", marker = "python_full_version < '3.9'" }, - { name = "gast", version = "0.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "google-pasta", marker = "python_full_version < '3.9'" }, - { name = "grpcio", marker = "python_full_version < '3.9'" }, - { name = "h5py", version = "3.11.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "keras", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "libclang", marker = "python_full_version < '3.9'" }, - { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "opt-einsum", marker = "python_full_version < '3.9'" }, - { name = "packaging", marker = "python_full_version < '3.9'" }, - { name = "protobuf", version = "3.19.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "setuptools", version = "75.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "six", marker = "python_full_version < '3.9'" }, - { name = "tensorboard", version = "2.11.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "tensorflow-estimator", marker = "python_full_version < '3.9'" }, - { name = "tensorflow-io-gcs-filesystem", marker = "(python_full_version < '3.9' and platform_machine != 
'arm64') or (python_full_version < '3.9' and sys_platform != 'darwin')" }, - { name = "termcolor", version = "2.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "typing-extensions", marker = "python_full_version < '3.9'" }, - { name = "wrapt", marker = "python_full_version < '3.9'" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/23/f7/95a96ca7ccd190cc53973768cbfddf82eb6a3a073dd87ba34b6e72442af7/tensorflow-2.11.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:ac0e46c5de7985def49e4f688a0ca4180949a4d5dc62b89e9c6640db3c3982ba", size = 244334320 }, - { url = "https://files.pythonhosted.org/packages/fb/91/044e8cf52b062c87b57efe7421d0d36e5ee01114d324f7e71bf84f739a0f/tensorflow-2.11.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45b1669c523fa6dc240688bffe79f08dfbb76bf5e23a7fe10e722ba658637a44", size = 1936 }, - { url = "https://files.pythonhosted.org/packages/0d/f6/3ab09c7c161d2e08353e65f0df0512a8e4578d33497563edd61aa887f29d/tensorflow-2.11.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a96595e0c068d54717405fa12f36b4a5bb0a9fc53fb9065155a92cff944b35b", size = 588264045 }, - { url = "https://files.pythonhosted.org/packages/bc/e6/2276b171697d4f1649bc870be7db0af128925f60d4d81129942fc88acd98/tensorflow-2.11.1-cp310-cp310-win_amd64.whl", hash = "sha256:13197f18f31a52d3f2eac28743d1b06abb8efd86017f184110a1b16841b745b1", size = 1914 }, - { url = "https://files.pythonhosted.org/packages/db/59/3fdf9a29b40191629b99262fffa672e774b4fcccedea599895fc689f115e/tensorflow-2.11.1-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:9f030f1bc9e7763fa03ec5738323c42021ababcd562fe861b3a3f41e9ff10e43", size = 244307939 }, - { url = "https://files.pythonhosted.org/packages/f1/2c/5556df785e3accb1c30613ad335275fb4b336be8d92e1df0ff5016c1ab36/tensorflow-2.11.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:9f12855c1e8373c1327650061fd6a9a3d3772e1bac8241202ea8ccb56213d005", size = 1935 }, - { url = "https://files.pythonhosted.org/packages/d9/ab/038c68864bc84f2463936aa3dedf64136c61623f7ed300b9e9ea5783be2e/tensorflow-2.11.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76cd4279cb500074a8ab28af116af7f060f0b015651bef552769d51e55d6fd5c", size = 588236305 }, - { url = "https://files.pythonhosted.org/packages/80/19/d370201c6a0a4967b6e6217cdd2442f87c6b52408a164485b105d2b4579c/tensorflow-2.11.1-cp38-cp38-win_amd64.whl", hash = "sha256:f5a2f75f28cd5fb615a5306f2091eac7da3a8fff949ab8804ec06b8e3682f837", size = 1913 }, - { url = "https://files.pythonhosted.org/packages/5e/2b/70f34ed683896c9e86d96152d76c23fa9b7125e4527904f2b2a3bf21e9c5/tensorflow-2.11.1-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:ea93246ad6c90ff0422f06a82164836fe8098989a8a65c3b02c720eadbe15dde", size = 244335000 }, - { url = "https://files.pythonhosted.org/packages/15/ce/03b677055f1857727a7eab916b285ef4edd0406850569c9d9e842c75181c/tensorflow-2.11.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ba6b3c2f68037e965a19427a1f2a5f0351b7ceae6c686938a8485b08e1e1f3", size = 1935 }, - { url = "https://files.pythonhosted.org/packages/70/1b/a467b78e0ca747c20226a03ddf4779a1122f8b04236ec7f2980a39738ddd/tensorflow-2.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ddd5c61f68d8125c985370de96a24a80aee5e3f1604efacec7e1c34ca72de24", size = 588266406 }, - { url = "https://files.pythonhosted.org/packages/53/9b/92d939a18ed618a3b89ea490e1d71e20ee9236dd98d7a67d55040c4e8c63/tensorflow-2.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7d8834df3f72d7eab56bc2f34f2e52b82d705776b80b36bf5470b7538c9865c", size = 1912 }, -] - [[package]] name = "tensorflow" version = "2.16.2" source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version >= '3.12'", - "python_full_version == '3.11.*'", - "python_full_version == 
'3.10.*'", - "python_full_version == '3.9.*'", -] -dependencies = [ - { name = "absl-py", marker = "python_full_version >= '3.9'" }, - { name = "astunparse", marker = "python_full_version >= '3.9'" }, - { name = "flatbuffers", marker = "python_full_version >= '3.9'" }, - { name = "gast", version = "0.6.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "google-pasta", marker = "python_full_version >= '3.9'" }, - { name = "grpcio", marker = "python_full_version >= '3.9'" }, - { name = "h5py", version = "3.13.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "keras", version = "3.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "libclang", marker = "python_full_version >= '3.9'" }, - { name = "ml-dtypes", marker = "python_full_version >= '3.9'" }, - { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "opt-einsum", marker = "python_full_version >= '3.9'" }, - { name = "packaging", marker = "python_full_version >= '3.9'" }, - { name = "protobuf", version = "4.25.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "requests", marker = "python_full_version >= '3.9'" }, - { name = "setuptools", version = "80.10.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "six", marker = "python_full_version >= '3.9'" }, - { name = "tensorboard", version = "2.16.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, - { name = "tensorflow-io-gcs-filesystem", marker = "python_full_version >= '3.9' and python_full_version < '3.12'" }, - { name = "termcolor", version = "2.5.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= 
'3.9'" }, - { name = "typing-extensions", marker = "python_full_version >= '3.9'" }, - { name = "wrapt", marker = "python_full_version >= '3.9'" }, +dependencies = [ + { name = "absl-py" }, + { name = "astunparse" }, + { name = "flatbuffers" }, + { name = "gast", version = "0.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "gast", version = "0.6.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "google-pasta" }, + { name = "grpcio" }, + { name = "h5py", version = "3.11.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "h5py", version = "3.13.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "keras" }, + { name = "libclang" }, + { name = "ml-dtypes" }, + { name = "numpy", version = "1.24.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "opt-einsum" }, + { name = "packaging" }, + { name = "protobuf" }, + { name = "requests" }, + { name = "setuptools", version = "75.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "setuptools", version = "80.10.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "six" }, + { name = "tensorboard" }, + { name = "tensorflow-io-gcs-filesystem", marker = "python_full_version < '3.12'" }, + { name = "termcolor", version = "2.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "termcolor", version = "2.5.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "typing-extensions" }, + 
{ name = "wrapt" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/f0/da/f242771de50d12dc1816cc9a66dfa5b377e8cd6ea316a6ffc9a7d2c6dfb8/tensorflow-2.16.2-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:546dc68d0740fb4b75593a6bfa308da9526fe31f65c2181d48c8551c4a0ad02f", size = 259544836 }, @@ -2875,19 +2567,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/53/b8/6ef11d379b8079310b20b89c6e1ebd5fb44f0acf51c0caf26366c5c928cf/tensorflow-2.16.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7df529f8db271d3def80538aa7fcd6f5abe306f7b01cb5b580138df68afb499", size = 218991442 }, { url = "https://files.pythonhosted.org/packages/d6/5c/691ab570c3637ba26d76f24d743a71f6afd952fc74e42243c108690d9f66/tensorflow-2.16.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5badc6744672a3181c012b6ab2815975be34d0573db3b561383634acc0d46a55", size = 590776704 }, { url = "https://files.pythonhosted.org/packages/9b/cb/d3d450d41bd66813933b85f49bb872c66409852370e55d04bf426b8980f4/tensorflow-2.16.2-cp312-cp312-win_amd64.whl", hash = "sha256:505df82fde3b9c6a2a78bf679efb4d0a2e84f4f925202130477ca519ae1514e4", size = 2070 }, - { url = "https://files.pythonhosted.org/packages/05/c7/6a1be731753934a1965fa7d751dab30d5cdea1800ca34e0fe57c1d40ac35/tensorflow-2.16.2-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:2528a162e879b40d81db3568c08256718cec4a0356580badbd362cd8af02a41b", size = 259545482 }, - { url = "https://files.pythonhosted.org/packages/a2/18/6382ea38225ea302d21368d735b7a10eae0996ae26fdf07945bd4927b893/tensorflow-2.16.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:4c94106b73ecd044b7772e4338f8aa65a43ef2e290fe3fc27cc094138f50a341", size = 226982639 }, - { url = "https://files.pythonhosted.org/packages/0d/24/1f9c0f17c8f962fe7fa7b8cd81c349823fcd4a43ebb88bf360f574091f80/tensorflow-2.16.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:ec5c57e6828b074ddb460aa69fbaa2cd502c6080a4e200e0163f2a2c9e20acfc", size = 218861480 }, - { url = "https://files.pythonhosted.org/packages/48/1f/0c5eb76e1ca25d36489c3b6125ee87867dc3bfdd409386304eefc65a0e17/tensorflow-2.16.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b085fc4b296e0daf2e8a8b71bf433acba0ba30d6c30f3d07ad05f10477c7762c", size = 590617949 }, - { url = "https://files.pythonhosted.org/packages/6b/02/affe1945a988ad4cc49c154b91a42aa6db8334b27c17a0a019dda22a3a25/tensorflow-2.16.2-cp39-cp39-win_amd64.whl", hash = "sha256:5d5951e91435909d6023f8c5afcfde9cee946a65ed03020fc8b87e627c04c6d1", size = 2069 }, -] - -[[package]] -name = "tensorflow-estimator" -version = "2.11.0" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/bb/e2/8bf618c7c30a525054230ee6d40b036d3e5abc2c4ff67cf7c7420a519204/tensorflow_estimator-2.11.0-py2.py3-none-any.whl", hash = "sha256:ea3b64acfff3d9a244f06178c9bdedcbdd3f125b67d0888dba8229498d06468b", size = 439214 }, ] [[package]] @@ -2907,10 +2586,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/9b/be27588352d7bd971696874db92d370f578715c17c0ccb27e4b13e16751e/tensorflow_io_gcs_filesystem-0.37.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5", size = 3479614 }, { url = "https://files.pythonhosted.org/packages/d3/46/962f47af08bd39fc9feb280d3192825431a91a078c856d17a78ae4884eb1/tensorflow_io_gcs_filesystem-0.37.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f", size = 4842077 }, { url = "https://files.pythonhosted.org/packages/f0/9b/790d290c232bce9b691391cf16e95a96e469669c56abfb1d9d0f35fa437c/tensorflow_io_gcs_filesystem-0.37.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c", 
size = 5085733 }, - { url = "https://files.pythonhosted.org/packages/12/4f/798df777498fab9dc683a658688e962f0af56454eb040c90f836fd9fa67c/tensorflow_io_gcs_filesystem-0.37.1-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d", size = 2470221 }, - { url = "https://files.pythonhosted.org/packages/7a/f9/ce6a0efde262a79361f0d67392fdf0d0406781a1ee4fc48d0d8b0553b311/tensorflow_io_gcs_filesystem-0.37.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f", size = 3479613 }, - { url = "https://files.pythonhosted.org/packages/66/5f/334a011caa1eb97689274d1141df8e6b7a25e389f0390bdcd90235de9783/tensorflow_io_gcs_filesystem-0.37.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed", size = 4842075 }, - { url = "https://files.pythonhosted.org/packages/3d/cb/7dcee55fc5a7d7d8a862e12519322851cd5fe5b086f946fd71e4ae1ef281/tensorflow_io_gcs_filesystem-0.37.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95", size = 5087496 }, ] [[package]] @@ -2918,7 +2593,7 @@ name = "termcolor" version = "2.4.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] sdist = { url = "https://files.pythonhosted.org/packages/10/56/d7d66a84f96d804155f6ff2873d065368b25a07222a6fd51c4f24ef6d764/termcolor-2.4.0.tar.gz", hash = "sha256:aab9e56047c8ac41ed798fa36d892a37aca6b3e9159f3e0c24bc64a9b3ac7b7a", size = 12664 } wheels = [ @@ -2932,23 +2607,12 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] sdist = { url = 
"https://files.pythonhosted.org/packages/37/72/88311445fd44c455c7d553e61f95412cf89054308a1aa2434ab835075fc5/termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f", size = 13057 } wheels = [ { url = "https://files.pythonhosted.org/packages/7f/be/df630c387a0a054815d60be6a97eb4e8f17385d5d6fe660e1c02750062b4/termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8", size = 7755 }, ] -[[package]] -name = "threadpoolctl" -version = "3.5.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bd/55/b5148dcbf72f5cde221f8bfe3b6a540da7aa1842f6b491ad979a6c8b84af/threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107", size = 41936 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4b/2c/ffbf7a134b9ab11a67b0cf0726453cedd9c5043a4fe7a35d1cefa9a1bcfb/threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467", size = 18414 }, -] - [[package]] name = "toml" version = "0.10.2" @@ -2996,6 +2660,123 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/b6/a447b5e4ec71e13871be01ba81f5dfc9d0af7e473da256ff46bc0e24026f/tomlkit-0.13.2-py3-none-any.whl", hash = "sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde", size = 37955 }, ] +[[package]] +name = "torch" +version = "2.8.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] +dependencies = [ + { name = "filelock", version = "3.16.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "fsspec", version = "2025.10.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "jinja2", marker = "python_full_version < '3.11'" }, + { name = "networkx", marker = 
"python_full_version < '3.11'" }, + { name = "nvidia-cublas-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "sympy", marker = 
"python_full_version < '3.11'" }, + { name = "triton", version = "3.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/63/28/110f7274254f1b8476c561dada127173f994afa2b1ffc044efb773c15650/torch-2.8.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:0be92c08b44009d4131d1ff7a8060d10bafdb7ddcb7359ef8d8c5169007ea905", size = 102052793 }, + { url = "https://files.pythonhosted.org/packages/70/1c/58da560016f81c339ae14ab16c98153d51c941544ae568da3cb5b1ceb572/torch-2.8.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:89aa9ee820bb39d4d72b794345cccef106b574508dd17dbec457949678c76011", size = 888025420 }, + { url = "https://files.pythonhosted.org/packages/70/87/f69752d0dd4ba8218c390f0438130c166fa264a33b7025adb5014b92192c/torch-2.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e8e5bf982e87e2b59d932769938b698858c64cc53753894be25629bdf5cf2f46", size = 241363614 }, + { url = "https://files.pythonhosted.org/packages/ef/d6/e6d4c57e61c2b2175d3aafbfb779926a2cfd7c32eeda7c543925dceec923/torch-2.8.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:a3f16a58a9a800f589b26d47ee15aca3acf065546137fc2af039876135f4c760", size = 73611154 }, + { url = "https://files.pythonhosted.org/packages/8f/c4/3e7a3887eba14e815e614db70b3b529112d1513d9dae6f4d43e373360b7f/torch-2.8.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:220a06fd7af8b653c35d359dfe1aaf32f65aa85befa342629f716acb134b9710", size = 102073391 }, + { url = "https://files.pythonhosted.org/packages/5a/63/4fdc45a0304536e75a5e1b1bbfb1b56dd0e2743c48ee83ca729f7ce44162/torch-2.8.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:c12fa219f51a933d5f80eeb3a7a5d0cbe9168c0a14bbb4055f1979431660879b", size = 888063640 }, + { url = 
"https://files.pythonhosted.org/packages/84/57/2f64161769610cf6b1c5ed782bd8a780e18a3c9d48931319f2887fa9d0b1/torch-2.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:8c7ef765e27551b2fbfc0f41bcf270e1292d9bf79f8e0724848b1682be6e80aa", size = 241366752 }, + { url = "https://files.pythonhosted.org/packages/a4/5e/05a5c46085d9b97e928f3f037081d3d2b87fb4b4195030fc099aaec5effc/torch-2.8.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:5ae0524688fb6707c57a530c2325e13bb0090b745ba7b4a2cd6a3ce262572916", size = 73621174 }, + { url = "https://files.pythonhosted.org/packages/49/0c/2fd4df0d83a495bb5e54dca4474c4ec5f9c62db185421563deeb5dabf609/torch-2.8.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e2fab4153768d433f8ed9279c8133a114a034a61e77a3a104dcdf54388838705", size = 101906089 }, + { url = "https://files.pythonhosted.org/packages/99/a8/6acf48d48838fb8fe480597d98a0668c2beb02ee4755cc136de92a0a956f/torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b2aca0939fb7e4d842561febbd4ffda67a8e958ff725c1c27e244e85e982173c", size = 887913624 }, + { url = "https://files.pythonhosted.org/packages/af/8a/5c87f08e3abd825c7dfecef5a0f1d9aa5df5dd0e3fd1fa2f490a8e512402/torch-2.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:2f4ac52f0130275d7517b03a33d2493bab3693c83dcfadf4f81688ea82147d2e", size = 241326087 }, + { url = "https://files.pythonhosted.org/packages/be/66/5c9a321b325aaecb92d4d1855421e3a055abd77903b7dab6575ca07796db/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:619c2869db3ada2c0105487ba21b5008defcc472d23f8b80ed91ac4a380283b0", size = 73630478 }, +] + +[[package]] +name = "torch" +version = "2.11.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", +] +dependencies = [ + { name = "cuda-bindings", marker = "python_full_version >= '3.11' and sys_platform == 'linux'" }, + { name = "cuda-toolkit", extra = ["cublas", "cudart", "cufft", "cufile", "cupti", 
"curand", "cusolver", "cusparse", "nvjitlink", "nvrtc", "nvtx"], marker = "python_full_version >= '3.11' and sys_platform == 'linux'" }, + { name = "filelock", version = "3.17.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "fsspec", version = "2026.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "jinja2", marker = "python_full_version >= '3.11'" }, + { name = "networkx", marker = "python_full_version >= '3.11'" }, + { name = "nvidia-cudnn-cu13", marker = "python_full_version >= '3.11' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu13", marker = "python_full_version >= '3.11' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu13", marker = "python_full_version >= '3.11' and sys_platform == 'linux'" }, + { name = "nvidia-nvshmem-cu13", marker = "python_full_version >= '3.11' and sys_platform == 'linux'" }, + { name = "setuptools", version = "80.10.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "sympy", marker = "python_full_version >= '3.11'" }, + { name = "triton", version = "3.6.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "python_full_version >= '3.11'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ac/f2/c1690994afe461aae2d0cac62251e6802a703dec0a6c549c02ecd0de92a9/torch-2.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2c0d7fcfbc0c4e8bb5ebc3907cbc0c6a0da1b8f82b1fc6e14e914fa0b9baf74e", size = 80526521 }, + { url = "https://files.pythonhosted.org/packages/a4/f0/98ae802fa8c09d3149b0c8690741f3f5753c90e779bd28c9613257295945/torch-2.11.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:4cf8687f4aec3900f748d553483ef40e0ac38411c3c48d0a86a438f6d7a99b18", size = 419723025 }, + { url = 
"https://files.pythonhosted.org/packages/f9/1e/18a9b10b4bd34f12d4e561c52b0ae7158707b8193c6cfc0aad2b48167090/torch-2.11.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:1b32ceda909818a03b112006709b02be1877240c31750a8d9c6b7bf5f2d8a6e5", size = 530589207 }, + { url = "https://files.pythonhosted.org/packages/35/40/2d532e8c0e23705be9d1debce5bc37b68d59a39bda7584c26fe9668076fe/torch-2.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:b3c712ae6fb8e7a949051a953fc412fe0a6940337336c3b6f905e905dac5157f", size = 114518313 }, + { url = "https://files.pythonhosted.org/packages/ae/0d/98b410492609e34a155fa8b121b55c7dca229f39636851c3a9ec20edea21/torch-2.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7b6a60d48062809f58595509c524b88e6ddec3ebe25833d6462eeab81e5f2ce4", size = 80529712 }, + { url = "https://files.pythonhosted.org/packages/84/03/acea680005f098f79fd70c1d9d5ccc0cb4296ec2af539a0450108232fc0c/torch-2.11.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:d91aac77f24082809d2c5a93f52a5f085032740a1ebc9252a7b052ef5a4fddc6", size = 419718178 }, + { url = "https://files.pythonhosted.org/packages/8c/8b/d7be22fbec9ffee6cff31a39f8750d4b3a65d349a286cf4aec74c2375662/torch-2.11.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:7aa2f9bbc6d4595ba72138026b2074be1233186150e9292865e04b7a63b8c67a", size = 530604548 }, + { url = "https://files.pythonhosted.org/packages/d1/bd/9912d30b68845256aabbb4a40aeefeef3c3b20db5211ccda653544ada4b6/torch-2.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:73e24aaf8f36ab90d95cd1761208b2eb70841c2a9ca1a3f9061b39fc5331b708", size = 114519675 }, + { url = "https://files.pythonhosted.org/packages/6f/8b/69e3008d78e5cee2b30183340cc425081b78afc5eff3d080daab0adda9aa/torch-2.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4b5866312ee6e52ea625cd211dcb97d6a2cdc1131a5f15cc0d87eec948f6dd34", size = 80606338 }, + { url = 
"https://files.pythonhosted.org/packages/13/16/42e5915ebe4868caa6bac83a8ed59db57f12e9a61b7d749d584776ed53d5/torch-2.11.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f99924682ef0aa6a4ab3b1b76f40dc6e273fca09f367d15a524266db100a723f", size = 419731115 }, + { url = "https://files.pythonhosted.org/packages/1a/c9/82638ef24d7877510f83baf821f5619a61b45568ce21c0a87a91576510aa/torch-2.11.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:0f68f4ac6d95d12e896c3b7a912b5871619542ec54d3649cf48cc1edd4dd2756", size = 530712279 }, + { url = "https://files.pythonhosted.org/packages/1c/ff/6756f1c7ee302f6d202120e0f4f05b432b839908f9071157302cedfc5232/torch-2.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:fbf39280699d1b869f55eac536deceaa1b60bd6788ba74f399cc67e60a5fab10", size = 114556047 }, +] + +[[package]] +name = "triton" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11'", +] +dependencies = [ + { name = "setuptools", version = "75.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069 }, + { url = "https://files.pythonhosted.org/packages/7d/39/43325b3b651d50187e591eefa22e236b2981afcebaefd4f2fc0ea99df191/triton-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b70f5e6a41e52e48cfc087436c8a28c17ff98db369447bcaff3b887a3ab4467", size = 155531138 }, + { url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068 }, +] + +[[package]] +name = "triton" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/ba/b1b04f4b291a3205d95ebd24465de0e5bf010a2df27a4e58a9b5f039d8f2/triton-3.6.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c723cfb12f6842a0ae94ac307dba7e7a44741d720a40cf0e270ed4a4e3be781", size = 175972180 }, + { url = "https://files.pythonhosted.org/packages/8c/f7/f1c9d3424ab199ac53c2da567b859bcddbb9c9e7154805119f8bd95ec36f/triton-3.6.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6550fae429e0667e397e5de64b332d1e5695b73650ee75a6146e2e902770bea", size = 188105201 }, + { url = "https://files.pythonhosted.org/packages/0f/2c/96f92f3c60387e14cc45aed49487f3486f89ea27106c1b1376913c62abe4/triton-3.6.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49df5ef37379c0c2b5c0012286f80174fcf0e073e5ade1ca9a86c36814553651", size = 176081190 }, + { url = "https://files.pythonhosted.org/packages/e0/12/b05ba554d2c623bffa59922b94b0775673de251f468a9609bc9e45de95e9/triton-3.6.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e8e323d608e3a9bfcc2d9efcc90ceefb764a82b99dea12a86d643c72539ad5d3", size = 188214640 }, + { url = "https://files.pythonhosted.org/packages/17/5d/08201db32823bdf77a0e2b9039540080b2e5c23a20706ddba942924ebcd6/triton-3.6.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:374f52c11a711fd062b4bfbb201fd9ac0a5febd28a96fb41b4a0f51dde3157f4", size = 176128243 }, + { url = "https://files.pythonhosted.org/packages/ab/a8/cdf8b3e4c98132f965f88c2313a4b493266832ad47fb52f23d14d4f86bb5/triton-3.6.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:74caf5e34b66d9f3a429af689c1c7128daba1d8208df60e81106b115c00d6fca", size = 188266850 }, +] + [[package]] name = "typing-extensions" version = "4.12.2" @@ -3019,7 +2800,7 @@ name = "urllib3" version = "2.2.3" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] sdist = { url = "https://files.pythonhosted.org/packages/ed/63/22ba4ebfe7430b76388e7cd448d5478814d3032121827c12a2cc287e2260/urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9", size = 300677 } wheels = [ @@ -3033,8 +2814,6 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/aa/63/e53da845320b757bf29ef6a9062f5c669fe997973f966045cb019c3f4b66/urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d", size = 307268 } wheels = [ @@ -3047,8 +2826,8 @@ version = "20.29.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "distlib" }, - { name = "filelock", version = "3.16.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, - { name = "filelock", version = "3.17.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "filelock", version = "3.16.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "filelock", version = "3.17.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "platformdirs" }, ] sdist = { url = "https://files.pythonhosted.org/packages/f1/88/dacc875dd54a8acadb4bcbfd4e3e86df8be75527116c91d8f9784f5e9cab/virtualenv-20.29.2.tar.gz", hash = 
"sha256:fdaabebf6d03b5ba83ae0a02cfe96f48a716f4fae556461d180825866f75b728", size = 4320272 } @@ -3061,7 +2840,7 @@ name = "watchdog" version = "4.0.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] sdist = { url = "https://files.pythonhosted.org/packages/4f/38/764baaa25eb5e35c9a043d4c4588f9836edfe52a708950f4b6d5f714fd42/watchdog-4.0.2.tar.gz", hash = "sha256:b4dfbb6c49221be4535623ea4474a4d6ee0a9cef4a80b20c28db4d858b64e270", size = 126587 } wheels = [ @@ -3074,18 +2853,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/92/f5/ea22b095340545faea37ad9a42353b265ca751f543da3fb43f5d00cdcd21/watchdog-4.0.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1cdcfd8142f604630deef34722d695fb455d04ab7cfe9963055df1fc69e6727a", size = 100342 }, { url = "https://files.pythonhosted.org/packages/cb/d2/8ce97dff5e465db1222951434e3115189ae54a9863aef99c6987890cc9ef/watchdog-4.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d7ab624ff2f663f98cd03c8b7eedc09375a911794dfea6bf2a359fcc266bff29", size = 92306 }, { url = "https://files.pythonhosted.org/packages/49/c4/1aeba2c31b25f79b03b15918155bc8c0b08101054fc727900f1a577d0d54/watchdog-4.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:132937547a716027bd5714383dfc40dc66c26769f1ce8a72a859d6a48f371f3a", size = 92915 }, - { url = "https://files.pythonhosted.org/packages/55/08/1a9086a3380e8828f65b0c835b86baf29ebb85e5e94a2811a2eb4f889cfd/watchdog-4.0.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:aa160781cafff2719b663c8a506156e9289d111d80f3387cf3af49cedee1f040", size = 100255 }, - { url = "https://files.pythonhosted.org/packages/6c/3e/064974628cf305831f3f78264800bd03b3358ec181e3e9380a36ff156b93/watchdog-4.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f6ee8dedd255087bc7fe82adf046f0b75479b989185fb0bdf9a98b612170eac7", size = 92257 }, - { url = 
"https://files.pythonhosted.org/packages/23/69/1d2ad9c12d93bc1e445baa40db46bc74757f3ffc3a3be592ba8dbc51b6e5/watchdog-4.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0b4359067d30d5b864e09c8597b112fe0a0a59321a0f331498b013fb097406b4", size = 92886 }, - { url = "https://files.pythonhosted.org/packages/68/eb/34d3173eceab490d4d1815ba9a821e10abe1da7a7264a224e30689b1450c/watchdog-4.0.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:770eef5372f146997638d737c9a3c597a3b41037cfbc5c41538fc27c09c3a3f9", size = 100254 }, - { url = "https://files.pythonhosted.org/packages/18/a1/4bbafe7ace414904c2cc9bd93e472133e8ec11eab0b4625017f0e34caad8/watchdog-4.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eeea812f38536a0aa859972d50c76e37f4456474b02bd93674d1947cf1e39578", size = 92249 }, - { url = "https://files.pythonhosted.org/packages/f3/11/ec5684e0ca692950826af0de862e5db167523c30c9cbf9b3f4ce7ec9cc05/watchdog-4.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b2c45f6e1e57ebb4687690c05bc3a2c1fb6ab260550c4290b8abb1335e0fd08b", size = 92891 }, { url = "https://files.pythonhosted.org/packages/3b/9a/6f30f023324de7bad8a3eb02b0afb06bd0726003a3550e9964321315df5a/watchdog-4.0.2-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:10b6683df70d340ac3279eff0b2766813f00f35a1d37515d2c99959ada8f05fa", size = 91775 }, { url = "https://files.pythonhosted.org/packages/87/62/8be55e605d378a154037b9ba484e00a5478e627b69c53d0f63e3ef413ba6/watchdog-4.0.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:f7c739888c20f99824f7aa9d31ac8a97353e22d0c0e54703a547a218f6637eb3", size = 92255 }, - { url = "https://files.pythonhosted.org/packages/6b/59/12e03e675d28f450bade6da6bc79ad6616080b317c472b9ae688d2495a03/watchdog-4.0.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c100d09ac72a8a08ddbf0629ddfa0b8ee41740f9051429baa8e31bb903ad7508", size = 91682 }, - { url = 
"https://files.pythonhosted.org/packages/ef/69/241998de9b8e024f5c2fbdf4324ea628b4231925305011ca8b7e1c3329f6/watchdog-4.0.2-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:f5315a8c8dd6dd9425b974515081fc0aadca1d1d61e078d2246509fd756141ee", size = 92249 }, - { url = "https://files.pythonhosted.org/packages/70/3f/2173b4d9581bc9b5df4d7f2041b6c58b5e5448407856f68d4be9981000d0/watchdog-4.0.2-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:2d468028a77b42cc685ed694a7a550a8d1771bb05193ba7b24006b8241a571a1", size = 91773 }, - { url = "https://files.pythonhosted.org/packages/f0/de/6fff29161d5789048f06ef24d94d3ddcc25795f347202b7ea503c3356acb/watchdog-4.0.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f15edcae3830ff20e55d1f4e743e92970c847bcddc8b7509bcd172aa04de506e", size = 92250 }, { url = "https://files.pythonhosted.org/packages/8a/b1/25acf6767af6f7e44e0086309825bd8c098e301eed5868dc5350642124b9/watchdog-4.0.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:936acba76d636f70db8f3c66e76aa6cb5136a936fc2a5088b9ce1c7a3508fc83", size = 82947 }, { url = "https://files.pythonhosted.org/packages/e8/90/aebac95d6f954bd4901f5d46dcd83d68e682bfd21798fd125a95ae1c9dbf/watchdog-4.0.2-py3-none-manylinux2014_armv7l.whl", hash = "sha256:e252f8ca942a870f38cf785aef420285431311652d871409a64e2a0a52a2174c", size = 82942 }, { url = "https://files.pythonhosted.org/packages/15/3a/a4bd8f3b9381824995787488b9282aff1ed4667e1110f31a87b871ea851c/watchdog-4.0.2-py3-none-manylinux2014_i686.whl", hash = "sha256:0e83619a2d5d436a7e58a1aea957a3c1ccbf9782c43c0b4fed80580e5e4acd1a", size = 82947 }, @@ -3105,8 +2874,6 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/db/7d/7f3d619e951c88ed75c6037b246ddcf2d322812ee8ea189be89511721d54/watchdog-6.0.0.tar.gz", hash = 
"sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282", size = 131220 } wheels = [ @@ -3119,13 +2886,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/39/ea/3930d07dafc9e286ed356a679aa02d777c06e9bfd1164fa7c19c288a5483/watchdog-6.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:bdd4e6f14b8b18c334febb9c4425a878a2ac20efd1e0b231978e7b150f92a948", size = 96471 }, { url = "https://files.pythonhosted.org/packages/12/87/48361531f70b1f87928b045df868a9fd4e253d9ae087fa4cf3f7113be363/watchdog-6.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c7c15dda13c4eb00d6fb6fc508b3c0ed88b9d5d374056b239c4ad1611125c860", size = 88449 }, { url = "https://files.pythonhosted.org/packages/5b/7e/8f322f5e600812e6f9a31b75d242631068ca8f4ef0582dd3ae6e72daecc8/watchdog-6.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6f10cb2d5902447c7d0da897e2c6768bca89174d0c6e1e30abec5421af97a5b0", size = 89054 }, - { url = "https://files.pythonhosted.org/packages/05/52/7223011bb760fce8ddc53416beb65b83a3ea6d7d13738dde75eeb2c89679/watchdog-6.0.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e6f0e77c9417e7cd62af82529b10563db3423625c5fce018430b249bf977f9e8", size = 96390 }, - { url = "https://files.pythonhosted.org/packages/9c/62/d2b21bc4e706d3a9d467561f487c2938cbd881c69f3808c43ac1ec242391/watchdog-6.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:90c8e78f3b94014f7aaae121e6b909674df5b46ec24d6bebc45c44c56729af2a", size = 88386 }, - { url = "https://files.pythonhosted.org/packages/ea/22/1c90b20eda9f4132e4603a26296108728a8bfe9584b006bd05dd94548853/watchdog-6.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e7631a77ffb1f7d2eefa4445ebbee491c720a5661ddf6df3498ebecae5ed375c", size = 89017 }, { url = "https://files.pythonhosted.org/packages/30/ad/d17b5d42e28a8b91f8ed01cb949da092827afb9995d4559fd448d0472763/watchdog-6.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:c7ac31a19f4545dd92fc25d200694098f42c9a8e391bc00bdd362c5736dbf881", 
size = 87902 }, { url = "https://files.pythonhosted.org/packages/5c/ca/c3649991d140ff6ab67bfc85ab42b165ead119c9e12211e08089d763ece5/watchdog-6.0.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:9513f27a1a582d9808cf21a07dae516f0fab1cf2d7683a742c498b93eedabb11", size = 88380 }, - { url = "https://files.pythonhosted.org/packages/5b/79/69f2b0e8d3f2afd462029031baafb1b75d11bb62703f0e1022b2e54d49ee/watchdog-6.0.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7a0e56874cfbc4b9b05c60c8a1926fedf56324bb08cfbc188969777940aef3aa", size = 87903 }, - { url = "https://files.pythonhosted.org/packages/e2/2b/dc048dd71c2e5f0f7ebc04dd7912981ec45793a03c0dc462438e0591ba5d/watchdog-6.0.0-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:e6439e374fc012255b4ec786ae3c4bc838cd7309a540e5fe0952d03687d8804e", size = 88381 }, { url = "https://files.pythonhosted.org/packages/a9/c7/ca4bf3e518cb57a686b2feb4f55a1892fd9a3dd13f470fca14e00f80ea36/watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13", size = 79079 }, { url = "https://files.pythonhosted.org/packages/5c/51/d46dc9332f9a647593c947b4b88e2381c8dfc0942d15b8edc0310fa4abb1/watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379", size = 79078 }, { url = "https://files.pythonhosted.org/packages/d4/57/04edbf5e169cd318d5f07b4766fee38e825d64b6913ca157ca32d1a42267/watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e", size = 79076 }, @@ -3143,10 +2905,10 @@ name = "werkzeug" version = "3.0.6" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version < '3.9'", + "python_full_version < '3.11'", ] dependencies = [ - { name = "markupsafe", version = "2.1.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9'" }, + { name = 
"markupsafe", version = "2.1.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d4/f9/0ba83eaa0df9b9e9d1efeb2ea351d0677c37d41ee5d0f91e98423c7281c9/werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d", size = 805170 } wheels = [ @@ -3160,11 +2922,9 @@ source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.12'", "python_full_version == '3.11.*'", - "python_full_version == '3.10.*'", - "python_full_version == '3.9.*'", ] dependencies = [ - { name = "markupsafe", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9'" }, + { name = "markupsafe", version = "3.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/9f/69/83029f1f6300c5fb2471d621ab06f6ec6b3324685a2ce0f9777fd4a8b71e/werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746", size = 806925 } wheels = [ @@ -3206,48 +2966,4 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f2/31/cbce966b6760e62d005c237961e839a755bf0c907199248394e2ee03ab05/wrapt-1.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be", size = 83361 }, { url = "https://files.pythonhosted.org/packages/9a/aa/ab46fb18072b86e87e0965a402f8723217e8c0312d1b3e2a91308df924ab/wrapt-1.14.1-cp311-cp311-win32.whl", hash = "sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204", size = 33454 }, { url = "https://files.pythonhosted.org/packages/ba/7e/14113996bc6ee68eb987773b4139c87afd3ceff60e27e37648aa5eb2798a/wrapt-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224", size = 35616 }, - { url = 
"https://files.pythonhosted.org/packages/33/cd/7335d8b82ff0a442581ab37a8d275ad76b4c1f33ace63c1a4d7c23791eee/wrapt-1.14.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8c0ce1e99116d5ab21355d8ebe53d9460366704ea38ae4d9f6933188f327b456", size = 35231 }, - { url = "https://files.pythonhosted.org/packages/5e/d3/bd44864e0274b7e162e2a68c71fffbd8b3a7b620efd23320fd0f70333cff/wrapt-1.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e3fb1677c720409d5f671e39bac6c9e0e422584e5f518bfd50aa4cbbea02433f", size = 35933 }, - { url = "https://files.pythonhosted.org/packages/23/8b/e4de40ac2fa6d53e694310c576e160bec3db8a282fbdcd5596544f6bc69e/wrapt-1.14.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:642c2e7a804fcf18c222e1060df25fc210b9c58db7c91416fb055897fc27e8cc", size = 81192 }, - { url = "https://files.pythonhosted.org/packages/12/cd/da6611401655ac2b8496b316ad9e21a3fd4f8e62e2c3b3e3c50207770517/wrapt-1.14.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b7c050ae976e286906dd3f26009e117eb000fb2cf3533398c5ad9ccc86867b1", size = 73727 }, - { url = "https://files.pythonhosted.org/packages/36/ee/944dc7e5462662270e8a379755bcc543fc8f09029866288060dc163ed5b4/wrapt-1.14.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef3f72c9666bba2bab70d2a8b79f2c6d2c1a42a7f7e2b0ec83bb2f9e383950af", size = 81021 }, - { url = "https://files.pythonhosted.org/packages/94/4b/ff8d58aee32ed91744f1ff4970e590f0c8fdda3fa6d702dc82281e0309bd/wrapt-1.14.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:01c205616a89d09827986bc4e859bcabd64f5a0662a7fe95e0d359424e0e071b", size = 85435 }, - { url = "https://files.pythonhosted.org/packages/e8/f6/7e30a8c53d27ef8c1ff872dc4fb75247c99eb73d834c91a49a55d046c127/wrapt-1.14.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:5a0f54ce2c092aaf439813735584b9537cad479575a09892b8352fea5e988dc0", size = 78500 }, - { url = 
"https://files.pythonhosted.org/packages/da/f4/7af9e01b6c1126b2daef72d5ba2cbf59a7229fd57c5b23166f694d758a8f/wrapt-1.14.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2cf71233a0ed05ccdabe209c606fe0bac7379fdcf687f39b944420d2a09fdb57", size = 85457 }, - { url = "https://files.pythonhosted.org/packages/88/ef/05655df7648752ae0a57fe2b9820e340ff025cecec9341aad7936c589a2f/wrapt-1.14.1-cp38-cp38-win32.whl", hash = "sha256:aa31fdcc33fef9eb2552cbcbfee7773d5a6792c137b359e82879c101e98584c5", size = 33397 }, - { url = "https://files.pythonhosted.org/packages/c7/1b/0cdff572d22600fcf47353e8eb1077d83cab3f161ebfb4843565c6e07e66/wrapt-1.14.1-cp38-cp38-win_amd64.whl", hash = "sha256:d1967f46ea8f2db647c786e78d8cc7e4313dbd1b0aca360592d8027b8508e24d", size = 35564 }, - { url = "https://files.pythonhosted.org/packages/d9/ab/3ba5816dd466ffd7242913708771d258569825ab76fd29d7fd85b9361311/wrapt-1.14.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3232822c7d98d23895ccc443bbdf57c7412c5a65996c30442ebe6ed3df335383", size = 35234 }, - { url = "https://files.pythonhosted.org/packages/bb/70/73c54e24ea69a8b06ae9649e61d5e64f2b4bdfc6f202fc7794abeac1ed20/wrapt-1.14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:988635d122aaf2bdcef9e795435662bcd65b02f4f4c1ae37fbee7401c440b3a7", size = 35933 }, - { url = "https://files.pythonhosted.org/packages/38/38/5b338163b3b4f1ab718306984678c3d180b85a25d72654ea4c61aa6b0968/wrapt-1.14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cca3c2cdadb362116235fdbd411735de4328c61425b0aa9f872fd76d02c4e86", size = 77892 }, - { url = "https://files.pythonhosted.org/packages/0a/61/330f24065b8f2fc02f94321092a24e0c30aefcbac89ab5c860e180366c9f/wrapt-1.14.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d52a25136894c63de15a35bc0bdc5adb4b0e173b9c0d07a2be9d3ca64a332735", size = 70318 }, - { url = 
"https://files.pythonhosted.org/packages/e0/6a/3c660fa34c8106aa9719f2a6636c1c3ea7afd5931ae665eb197fdf4def84/wrapt-1.14.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40e7bc81c9e2b2734ea4bc1aceb8a8f0ceaac7c5299bc5d69e37c44d9081d43b", size = 77752 }, - { url = "https://files.pythonhosted.org/packages/e0/20/9716fb522d17a726364c4d032c8806ffe312268773dd46a394436b2787cc/wrapt-1.14.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b9b7a708dd92306328117d8c4b62e2194d00c365f18eff11a9b53c6f923b01e3", size = 82284 }, - { url = "https://files.pythonhosted.org/packages/6a/12/76bbe26dc39d05f1a7be8d570d91c87bb79297e08e885148ed670ed17b7b/wrapt-1.14.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6a9a25751acb379b466ff6be78a315e2b439d4c94c1e99cb7266d40a537995d3", size = 75170 }, - { url = "https://files.pythonhosted.org/packages/f9/3c/110e52b9da396a4ef3a0521552a1af9c7875a762361f48678c1ac272fd7e/wrapt-1.14.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:34aa51c45f28ba7f12accd624225e2b1e5a3a45206aa191f6f9aac931d9d56fe", size = 82281 }, - { url = "https://files.pythonhosted.org/packages/4b/07/782463e367a7c6b418af231ded753e4b2dd3293a157d9b0bb010806fc0c0/wrapt-1.14.1-cp39-cp39-win32.whl", hash = "sha256:dee0ce50c6a2dd9056c20db781e9c1cfd33e77d2d569f5d1d9321c641bb903d5", size = 33404 }, - { url = "https://files.pythonhosted.org/packages/5b/02/5ac7ea3b6722c84a2882d349ac581a9711b4047fe7a58475903832caa295/wrapt-1.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:dee60e1de1898bde3b238f18340eec6148986da0455d8ba7848d50470a7a32fb", size = 35557 }, -] - -[[package]] -name = "zipp" -version = "3.20.2" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version < '3.9'", -] -sdist = { url = "https://files.pythonhosted.org/packages/54/bf/5c0000c44ebc80123ecbdddba1f5dcd94a5ada602a9c225d84b5aaa55e86/zipp-3.20.2.tar.gz", hash = 
"sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29", size = 24199 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/62/8b/5ba542fa83c90e09eac972fc9baca7a88e7e7ca4b221a89251954019308b/zipp-3.20.2-py3-none-any.whl", hash = "sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350", size = 9200 }, -] - -[[package]] -name = "zipp" -version = "3.21.0" -source = { registry = "https://pypi.org/simple" } -resolution-markers = [ - "python_full_version == '3.9.*'", -] -sdist = { url = "https://files.pythonhosted.org/packages/3f/50/bad581df71744867e9468ebd0bcd6505de3b275e06f202c2cb016e3ff56f/zipp-3.21.0.tar.gz", hash = "sha256:2c9958f6430a2040341a52eb608ed6dd93ef4392e02ffe219417c1b28b5dd1f4", size = 24545 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/1a/7e4798e9339adc931158c9d69ecc34f5e6791489d469f5e50ec15e35f458/zipp-3.21.0-py3-none-any.whl", hash = "sha256:ac1bbe05fd2991f160ebce24ffbac5f6d11d83dc90891255885223d42b3cd931", size = 9630 }, ]