diff --git a/optionlab/engine.py b/optionlab/engine.py index 93051a4..eb835dd 100644 --- a/optionlab/engine.py +++ b/optionlab/engine.py @@ -66,7 +66,7 @@ def run_strategy(inputs_data: Inputs | dict) -> Outputs: def _init_inputs(inputs: Inputs) -> EngineData: data = EngineData( - stock_price_array=create_price_seq(inputs.min_stock, inputs.max_stock), + stock_price_array=create_price_seq(inputs.min_stock, inputs.max_stock, inputs.price_increment), terminal_stock_prices=inputs.array if inputs.model == "array" else array([]), inputs=inputs, ) diff --git a/optionlab/models.py b/optionlab/models.py index 20b0e47..74e30ce 100644 --- a/optionlab/models.py +++ b/optionlab/models.py @@ -329,6 +329,14 @@ class Inputs(BaseModel): The default is an empty array. """ + price_increment: float = Field(0.01, gt=0.0) + """ + Price increment for generating the stock price array used in profit/loss calculations. + Smaller values create more data points but increase computation time. + + The default is 0.01 (one cent increments). + """ + model_config = ConfigDict(arbitrary_types_allowed=True) @field_validator("strategy") diff --git a/optionlab/support.py b/optionlab/support.py index bb32aad..cdf4fee 100644 --- a/optionlab/support.py +++ b/optionlab/support.py @@ -174,12 +174,11 @@ def get_pl_profile_bs( return profile, n * cost - commission - @lru_cache -def create_price_seq(min_price: float, max_price: float) -> np.ndarray: +def create_price_seq(min_price: float, max_price: float, increment: float = 0.01) -> np.ndarray: """ Generates a sequence of stock prices from a minimum to a maximum price with - increment $0.01. + the specified increment. Parameters ---------- @@ -187,17 +186,28 @@ def create_price_seq(min_price: float, max_price: float) -> np.ndarray: `max_price`: maximum stock price in the range. + `increment`: price increment between consecutive values. The default is 0.01. + Returns ------- Array of sequential stock prices. """ if max_price > min_price: - return round((arange((max_price - min_price) * 100 + 1) * 0.01 + min_price), 2) + if increment <= 0.0: + raise ValueError("Increment must be greater than 0!") + num_points = int((max_price - min_price) / increment) + 1 + # Round to appropriate decimal places based on increment size + if increment >= 1.0: + decimal_places = 0 + elif increment >= 0.1: + decimal_places = 1 + else: + decimal_places = 2 + return round((arange(num_points) * increment + min_price), decimal_places) else: raise ValueError("Maximum price cannot be less than minimum price!") - def get_pop( s: np.ndarray, profit: np.ndarray,