diff --git a/.gitignore b/.gitignore index 7d4676ce3..93ef00ce7 100644 --- a/.gitignore +++ b/.gitignore @@ -172,4 +172,7 @@ mase-trainer/ test-trainer/ # DiffLogic: tutorial files -docs/tutorials/difflogic/data-mnist/ \ No newline at end of file +docs/tutorials/difflogic/data-mnist/ + +# For testing the emit +generated-SV \ No newline at end of file diff --git a/docs/labs/bram/hardware/rtl/fc1_bias_source.sv b/docs/labs/bram/hardware/rtl/fc1_bias_source.sv new file mode 100644 index 000000000..066b2c58e --- /dev/null +++ b/docs/labs/bram/hardware/rtl/fc1_bias_source.sv @@ -0,0 +1,115 @@ + +// ===================================== +// Mase Hardware +// Parameter: fc1_bias +// 04/02/2026 21:26:55 +// ===================================== + +`timescale 1 ns / 1 ps +module fc1_bias_rom #( + parameter DWIDTH = 32, + parameter MEM_SIZE = 2, + parameter AWIDTH = $clog2(MEM_SIZE) + 1 +) ( + input clk, + input logic [AWIDTH-1:0] addr0, + input ce0, + output logic [DWIDTH-1:0] q0 +); + + logic [DWIDTH-1:0] ram[0:MEM_SIZE-1]; + logic [DWIDTH-1:0] q0_t0; + logic [DWIDTH-1:0] q0_t1; + + initial begin + $readmemh("./bram/hardware/rtl/fc1_bias_rom.dat", ram); + end + + assign q0 = q0_t1; + + always_ff @(posedge clk) if (ce0) q0_t1 <= q0_t0; + always_ff @(posedge clk) if (ce0) q0_t0 <= ram[addr0]; + +endmodule + +`timescale 1 ns / 1 ps +module fc1_bias #( + parameter DATA_WIDTH = 32'd32, + parameter ADDR_RANGE = 32'd2, + parameter ADDR_WIDTH = $clog2(ADDR_RANGE) + 1 +) ( + input reset, + input clk, + input logic [ADDR_WIDTH - 1:0] address0, + input ce0, + output logic [DATA_WIDTH - 1:0] q0 +); + + fc1_bias_rom fc1_bias_rom_U ( + .clk(clk), + .addr0(address0), + .ce0(ce0), + .q0(q0) + ); + +endmodule + + +`timescale 1ns / 1ps +module fc1_bias_source #( + parameter BIAS_TENSOR_SIZE_DIM_0 = 32, + parameter BIAS_TENSOR_SIZE_DIM_1 = 1, + parameter BIAS_PRECISION_0 = 16, + parameter BIAS_PRECISION_1 = 3, + + parameter BIAS_PARALLELISM_DIM_0 = 1, + parameter BIAS_PARALLELISM_DIM_1 = 1, 
+ parameter OUT_DEPTH = ((BIAS_TENSOR_SIZE_DIM_0 + BIAS_PARALLELISM_DIM_0 - 1) / BIAS_PARALLELISM_DIM_0) * ((BIAS_TENSOR_SIZE_DIM_1 + BIAS_PARALLELISM_DIM_1 - 1) / BIAS_PARALLELISM_DIM_1) +) ( + input clk, + input rst, + + output logic [BIAS_PRECISION_0-1:0] data_out [BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1-1:0], + output data_out_valid, + input data_out_ready +); + // 1-bit wider so IN_DEPTH also fits. + localparam COUNTER_WIDTH = $clog2(OUT_DEPTH); + logic [COUNTER_WIDTH:0] counter; + + always_ff @(posedge clk) + if (rst) counter <= 0; + else begin + if (data_out_ready) begin + if (counter == OUT_DEPTH - 1) counter <= 0; + else counter <= counter + 1; + end + end + + logic [1:0] clear; + always_ff @(posedge clk) + if (rst) clear <= 0; + else if ((data_out_ready == 1) && (clear != 2)) clear <= clear + 1; + logic ce0; + assign ce0 = data_out_ready; + + logic [BIAS_PRECISION_0*BIAS_PARALLELISM_DIM_0*BIAS_PARALLELISM_DIM_1-1:0] data_vector; + fc1_bias #( + .DATA_WIDTH(BIAS_PRECISION_0 * BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1), + .ADDR_RANGE(OUT_DEPTH) + ) fc1_bias_mem ( + .clk(clk), + .reset(rst), + .address0(counter), + .ce0(ce0), + .q0(data_vector) + ); + + // Cocotb/verilator does not support array flattening, so + // we need to manually add some reshaping process. 
+ for (genvar j = 0; j < BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1; j++) + assign data_out[j] = data_vector[BIAS_PRECISION_0*j+BIAS_PRECISION_0-1:BIAS_PRECISION_0*j]; + + assign data_out_valid = clear == 2; + +endmodule diff --git a/docs/labs/bram/hardware/rtl/fc1_weight_source.sv b/docs/labs/bram/hardware/rtl/fc1_weight_source.sv new file mode 100644 index 000000000..210aa1d09 --- /dev/null +++ b/docs/labs/bram/hardware/rtl/fc1_weight_source.sv @@ -0,0 +1,115 @@ + +// ===================================== +// Mase Hardware +// Parameter: fc1_weight +// 04/02/2026 21:26:55 +// ===================================== + +`timescale 1 ns / 1 ps +module fc1_weight_rom #( + parameter DWIDTH = 128, + parameter MEM_SIZE = 2, + parameter AWIDTH = $clog2(MEM_SIZE) + 1 +) ( + input clk, + input logic [AWIDTH-1:0] addr0, + input ce0, + output logic [DWIDTH-1:0] q0 +); + + logic [DWIDTH-1:0] ram[0:MEM_SIZE-1]; + logic [DWIDTH-1:0] q0_t0; + logic [DWIDTH-1:0] q0_t1; + + initial begin + $readmemh("./bram/hardware/rtl/fc1_weight_rom.dat", ram); + end + + assign q0 = q0_t1; + + always_ff @(posedge clk) if (ce0) q0_t1 <= q0_t0; + always_ff @(posedge clk) if (ce0) q0_t0 <= ram[addr0]; + +endmodule + +`timescale 1 ns / 1 ps +module fc1_weight #( + parameter DATA_WIDTH = 32'd128, + parameter ADDR_RANGE = 32'd2, + parameter ADDR_WIDTH = $clog2(ADDR_RANGE) + 1 +) ( + input reset, + input clk, + input logic [ADDR_WIDTH - 1:0] address0, + input ce0, + output logic [DATA_WIDTH - 1:0] q0 +); + + fc1_weight_rom fc1_weight_rom_U ( + .clk(clk), + .addr0(address0), + .ce0(ce0), + .q0(q0) + ); + +endmodule + + +`timescale 1ns / 1ps +module fc1_weight_source #( + parameter WEIGHT_TENSOR_SIZE_DIM_0 = 32, + parameter WEIGHT_TENSOR_SIZE_DIM_1 = 1, + parameter WEIGHT_PRECISION_0 = 16, + parameter WEIGHT_PRECISION_1 = 3, + + parameter WEIGHT_PARALLELISM_DIM_0 = 1, + parameter WEIGHT_PARALLELISM_DIM_1 = 1, + parameter OUT_DEPTH = ((WEIGHT_TENSOR_SIZE_DIM_0 + WEIGHT_PARALLELISM_DIM_0 - 1) / 
WEIGHT_PARALLELISM_DIM_0) * ((WEIGHT_TENSOR_SIZE_DIM_1 + WEIGHT_PARALLELISM_DIM_1 - 1) / WEIGHT_PARALLELISM_DIM_1) +) ( + input clk, + input rst, + + output logic [WEIGHT_PRECISION_0-1:0] data_out [WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1-1:0], + output data_out_valid, + input data_out_ready +); + // 1-bit wider so IN_DEPTH also fits. + localparam COUNTER_WIDTH = $clog2(OUT_DEPTH); + logic [COUNTER_WIDTH:0] counter; + + always_ff @(posedge clk) + if (rst) counter <= 0; + else begin + if (data_out_ready) begin + if (counter == OUT_DEPTH - 1) counter <= 0; + else counter <= counter + 1; + end + end + + logic [1:0] clear; + always_ff @(posedge clk) + if (rst) clear <= 0; + else if ((data_out_ready == 1) && (clear != 2)) clear <= clear + 1; + logic ce0; + assign ce0 = data_out_ready; + + logic [WEIGHT_PRECISION_0*WEIGHT_PARALLELISM_DIM_0*WEIGHT_PARALLELISM_DIM_1-1:0] data_vector; + fc1_weight #( + .DATA_WIDTH(WEIGHT_PRECISION_0 * WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1), + .ADDR_RANGE(OUT_DEPTH) + ) fc1_weight_mem ( + .clk(clk), + .reset(rst), + .address0(counter), + .ce0(ce0), + .q0(data_vector) + ); + + // Cocotb/verilator does not support array flattening, so + // we need to manually add some reshaping process. + for (genvar j = 0; j < WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1; j++) + assign data_out[j] = data_vector[WEIGHT_PRECISION_0*j+WEIGHT_PRECISION_0-1:WEIGHT_PRECISION_0*j]; + + assign data_out_valid = clear == 2; + +endmodule diff --git a/docs/labs/dram-minimal-sv-dependencies.md b/docs/labs/dram-minimal-sv-dependencies.md new file mode 100644 index 000000000..e556eeb16 --- /dev/null +++ b/docs/labs/dram-minimal-sv-dependencies.md @@ -0,0 +1,120 @@ +# Minimal New `.sv` Dependencies for DRAM-Streamed MLP (with Testbench) + +## Goal +Make the DRAM-based MLP path compile and run with the existing cocotb testbench, with the minimum number of new RTL files. 
+ +## Short Answer +- For the MLP compute path itself, **0 new compute `.sv` files are required**. +- `fixed_linear.sv` already supports streamed `weight`/`bias` with valid/ready. +- DRAM mode changes who drives those ports (external stream) rather than changing linear math hardware. + +## What to change in codegen (so no new compute module is needed) +1. Remove the `_dram` module renaming in `src/chop/passes/graph/transforms/verilog/emit_top.py`. +2. Keep module name as `fixed_linear`. +3. Keep DRAM behavior controlled by interface metadata (`storage="DRAM"`) and top-level parameter ports. + +## Where to add DRAM-only deployment dependencies +The primary file to control per-node dependency lists is: +- `src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py` + +Use `add_component_source(...)` to switch dependency files only when DRAM mode is requested. + +Simplest DRAM-only pattern: +```python +elif mase_op in INTERNAL_COMP.keys(): + node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL" + node.meta["mase"]["hardware"]["module"] = INTERNAL_COMP[mase_op][0]["name"] + node.meta["mase"]["hardware"]["dependence_files"] = INTERNAL_COMP[mase_op][0]["dependence_files"] + + if pass_args.get("interface", {}).get("storage", "BRAM") == "DRAM": + dram_extra = pass_args.get("interface", {}).get("dram_deployment_deps", []) + node.meta["mase"]["hardware"]["dependence_files"] = ( + node.meta["mase"]["hardware"]["dependence_files"] + dram_extra + ) +``` + +Notes: +- Keep compute dependencies unchanged for simulation. +- Append only deployment adapters (AXIS formatter/scheduler) under DRAM. +- Keep `storage` explicit as `"DRAM"` for off-chip flow. 
+ +If you want this keyed by op type (for example only `linear`), the cleanest place for static lists is: +- `src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py` + +Example structure (conceptual): +```python +DRAM_EXTRA_DEPENDENCIES = { + "linear": [ + "memory_adapters/rtl/axis_param_reformatter.sv", + "memory_adapters/rtl/param_stream_scheduler.sv", + ] +} +``` + +Then in `add_component_source(...)`, append `DRAM_EXTRA_DEPENDENCIES.get(mase_op, [])` when `storage == "DRAM"`. + +## `emit_dram` function (deployment only) +This function is only for real FPGA deployment. It is not needed for cocotb simulation. + +In simulation, cocotb already quantizes and streams parameter blocks directly into DRAM-backed top-level ports. +In deployment, we need deployable parameter images that PS software and DMA can read from DDR and stream to the accelerator. + +```python +def emit_dram_transform_pass(graph, pass_args={}): + """ + Emit deployable DRAM parameter images (not BRAM ROM files) for real FPGA deployment. + + Expected outputs per DRAM-backed parameter: + - packed tensor stream image (e.g. .bin/.hex) + - metadata sidecar (shape, precision, parallelism, beat count, ordering) + + Notes: + - Not required for cocotb simulation, because cocotb drivers already stream quantized parameters directly. + - Required for deployment where PS software / DMA reads from DDR and streams to top-level parameter ports. + """ +``` + +Recommended `pass_args` fields: +- `project_dir`: output root +- `format`: `bin` or `hex` +- `endianness`: `little` or `big` +- `align_bytes`: beat alignment for DMA +- `emit_metadata`: `True/False` +- `target`: board/deployment tag + +## Minimum deployment stack (real FPGA) +Vivado IP helps with transport, but Vivado IP alone is not enough to satisfy model-specific packing, ordering, replay, and control requirements. + +1. 
Parameter image emitter (software/pass side) +Converts quantized tensors into stream-ready images matching hardware beat format and ordering. +2. DDR buffer allocation + loader software (PS side) +Allocates contiguous buffers, loads emitted images, and passes addresses/lengths to DMA/control plane. +3. AXI DMA (Vivado IP) +Moves parameter beats from DDR to AXI-Stream. +4. AXIS packet/beat reformatter (custom RTL/HLS) +Adapts DMA stream width/packet structure to `fc1_weight` / `fc1_bias` port beat shape. +5. Parameter stream scheduler/control (custom RTL/HLS or PS-driven FSM) +Controls when each parameter stream starts, stops, and repeats per inference, while respecting back-pressure. +6. Optional AXIS FIFO/width converter (Vivado IP) +Used when clock-domain crossing, buffering, or width adaptation is needed. + +### Vivado IP vs custom logic +Vivado can provide: +- AXI DMA +- AXIS FIFO +- AXIS data width converter +- AXI interconnect / SmartConnect +- Clocking/reset and standard infrastructure + +You still need to implement: +- Parameter block ordering/packing contract used by `fixed_linear` +- Tensor replay policy (e.g. re-stream full weights per sample/batch) +- Stream scheduling across weight/bias channels +- Control/status integration between PS and accelerator +- Model-specific handshaking and sequencing correctness + +## Final handoff list for teammate (today) +- Do not create `fixed_linear_dram.sv`. +- Remove `_dram` renaming in `src/chop/passes/graph/transforms/verilog/emit_top.py` so top instantiates `fixed_linear`. +- Add DRAM-only deployment dependencies in `src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py` by appending an extra list only when `storage == "DRAM"`. +- Keep deployment adapters separate from compute RTL (compute stays `fixed_linear`). 
diff --git a/docs/labs/dram-mlp-testbenching.md b/docs/labs/dram-mlp-testbenching.md new file mode 100644 index 000000000..e48787c35 --- /dev/null +++ b/docs/labs/dram-mlp-testbenching.md @@ -0,0 +1,282 @@ +# DRAM-Backed MLP Testbenching in MASE + +## 1. Scope and objective + +This note documents how the MLP hardware testbench was adapted to validate off-chip parameter streaming, where weights and biases are supplied as runtime input streams instead of BRAM-backed parameter source modules. + +The focus is on additions in MASE required to support this verification flow. + +## 2. Verification target + +For the toy MLP: + +```python +class MLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(4, 8) + + def forward(self, x): + x = torch.flatten(x, start_dim=1, end_dim=-1) + return torch.nn.functional.relu(self.fc1(x)) +``` + +the hardware test should: + +- drive a valid activation stream into `data_in_0` +- drive valid parameter streams into `fc1_weight` and `fc1_bias` +- compare DUT output stream against software model expectation +- terminate quickly without waiting for timeout + +## 3. Off-chip metadata selection + +The DRAM mode is selected during hardware metadata generation: + +```python +mg, _ = add_hardware_metadata_analysis_pass(mg, {"interface": {"storage": "DRAM"}}) +``` + +In `add_component_source`, MASE selects DRAM-aware module/dependency mapping and marks tensor parameters as DRAM-backed interfaces: + +```python +storage_type = pass_args.get("interface", {}).get("storage", "BRAM") +if storage_type == "DRAM": + node.meta["mase"]["hardware"]["module"] = DRAM_INTERNAL_COMP[mase_op][0]["name"] + node.meta["mase"]["hardware"]["dependence_files"] = DRAM_INTERNAL_COMP[mase_op][0]["dependence_files"] +... +node.meta["mase"]["hardware"]["interface"][arg] = { + "storage": storage_type, + "transpose": False, +} +``` + +Source: `src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py` + +## 4. 
Top-level RTL interface changes for DRAM parameters + +The top emitter now exposes parameter streaming ports when storage is DRAM: + +```python +if hardware.get("interface", {}).get(arg, {}).get("storage", "BRAM") != "DRAM": + continue +... +interface += f""" +input [{node_name}_{arg_name}_PRECISION_0-1:0] {node_name}_{arg} [{'*'.join(parallelism_params)}-1:0], +input {node_name}_{arg}_valid, +output {node_name}_{arg}_ready,""" +``` + +This creates explicit DUT ports such as: + +- `fc1_weight`, `fc1_weight_valid`, `fc1_weight_ready` +- `fc1_bias`, `fc1_bias_valid`, `fc1_bias_ready` + +Source: `src/chop/passes/graph/transforms/verilog/emit_top.py` + +## 5. Core cocotb test flow + +The generated test computes expected output directly from the software model and then drives/monitors streams: + +```python +in_tensors = tb.generate_inputs(batches=1) +exp_out = tb.model(*list(in_tensors.values())) + +tb.load_drivers(in_tensors) +tb.load_monitors(exp_out) + +await tb.wait_end(timeout=2, timeout_unit="ms") +``` + +Source: generated `~/.mase/top/hardware/test/mase_top_tb/test.py` from `emit_cocotb_transform_pass`. + +### Expected-value definition + +For a single layer MLP with ReLU: + +$$ +y = \mathrm{ReLU}(xW^T + b) +$$ + +`exp_out` is computed by the PyTorch model and then quantized/packed into output beats by `fixed_preprocess_tensor` in `load_monitors`. + +## 6. MASE additions in testbench emitter for off-chip parameter driving + +### 6.1 Why an additional change was needed + +During debugging, parameter ports were present in generated RTL, but some serialized graph paths had incomplete FX-node metadata at cocotb runtime, causing DRAM discovery to miss all parameter streams. 
+ +### 6.2 Robust parameter-port discovery + +The testbench now binds parameter drivers from `model.named_parameters()` and DUT port existence, instead of relying only on FX-node metadata traversal: + +```python +self.dram_drivers = {} +self.dram_param_specs = {} +for full_name, param_tensor in self.model.named_parameters(): + if "." not in full_name: + continue + module_name, arg = full_name.rsplit(".", 1) + node_name = vf(module_name) + port_name = f"{node_name}_{arg}" + missing = [ + sig + for sig in [port_name, f"{port_name}_valid", f"{port_name}_ready"] + if not hasattr(dut, sig) + ] + if missing: + continue + + self.dram_drivers[port_name] = StreamDriver( + dut.clk, + getattr(dut, port_name), + getattr(dut, f"{port_name}_valid"), + getattr(dut, f"{port_name}_ready"), + ) + self.dram_param_specs[port_name] = (node_name, arg, param_tensor) +``` + +Source: `src/chop/passes/graph/transforms/verilog/emit_tb.py` + +### 6.3 Parameter tensor packing and streaming + +Each parameter tensor is quantized and packed according to emitted per-port precision and parallelism: + +```python +for port_name, (node_name, arg, param_tensor) in self.dram_param_specs.items(): + arg_cap = _cap(arg) + parallelism_0 = self.get_parameter(f"{node_name}_{arg_cap}_PARALLELISM_DIM_0") + parallelism_1 = self.get_parameter(f"{node_name}_{arg_cap}_PARALLELISM_DIM_1") + width = self.get_parameter(f"{node_name}_{arg_cap}_PRECISION_0") + frac_width = self.get_parameter(f"{node_name}_{arg_cap}_PRECISION_1") + + param_blocks = fixed_preprocess_tensor( + tensor=param_tensor, + q_config={"width": width, "frac_width": frac_width}, + parallelism=[parallelism_1, parallelism_0], + ) + + block_size = parallelism_0 * parallelism_1 + for block in param_blocks: + if len(block) < block_size: + block = block + [0] * (block_size - len(block)) + self.dram_drivers[port_name].append(block) +``` + +This makes parameter traffic follow the same valid/ready handshake pattern as regular input streams. 
+ +Source: `src/chop/passes/graph/transforms/verilog/emit_tb.py` + +## 7. Input driving and expected output monitor path + +Input stream driving remains unchanged in concept: random input tensors are quantized and block-packed using interface parameters, then appended to input drivers. + +Output monitoring uses quantized/packed expected blocks and checks observed beats in-order. + +Representative monitor setup: + +```python +self.output_monitors[result] = StreamMonitor( + dut.clk, + getattr(dut, result), + getattr(dut, f"{result}_valid"), + getattr(dut, f"{result}_ready"), + check=False, +) +``` + +Source: `src/chop/passes/graph/transforms/verilog/emit_tb.py` + +## 8. End-to-end sequence with external-memory pass + +To make the report pipeline explicit, we define a dedicated transform pass whose only job is to tag supported tensor parameters for off-chip storage. + +### 8.1 Proposed pass + +```python +def emit_external_memory_transform_pass(graph, pass_args=None): + """Mark supported parameter tensors as externally streamed (DRAM).""" + if pass_args is None: + pass_args = {} + + storage = pass_args.get("storage", "DRAM") + supported_ops = set(pass_args.get("supported_ops", ["linear"])) + + for node in graph.fx_graph.nodes: + if "mase" not in node.meta: + continue + mase = node.meta["mase"] + common = mase.get("common", {}) + hardware = mase.get("hardware", {}) + + if hardware.get("is_implicit", False): + continue + if common.get("mase_op") not in supported_ops: + continue + + iface = hardware.setdefault("interface", {}) + for arg, arg_info in common.get("args", {}).items(): + if "data_in" in arg or not isinstance(arg_info, dict): + continue + iface.setdefault(arg, {}) + iface[arg]["storage"] = storage + iface[arg].setdefault("transpose", False) + + return graph, {} +``` + +Design intent: + +- keep this pass minimal and declarative +- separate policy (which params are off-chip) from code generation +- let downstream emitters consume the same `hardware.interface` 
metadata + +### 8.2 Revised hardware emit pipeline + +```python +# 1) Build and quantize graph +mg, _ = init_metadata_analysis_pass(mg, None) +mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in, "add_value": False}) +mg, _ = quantize_transform_pass(mg, quan_args) + +# 2) Populate hardware defaults +mg, _ = add_hardware_metadata_analysis_pass(mg) + +# 3) Rewrite selected params for external memory +mg, _ = emit_external_memory_transform_pass( + mg, + {"storage": "DRAM", "supported_ops": ["linear"]}, +) + +# 4) Emit RTL and testbench +mg, _ = emit_verilog_top_transform_pass(mg) +mg, _ = emit_internal_rtl_transform_pass(mg) +mg, _ = emit_bram_transform_pass(mg) # still used for on-chip tensors +mg, _ = emit_cocotb_transform_pass(mg) + +# 5) Simulate +simulate(skip_build=False, skip_test=False, waves=True) +``` + +This sequence is equivalent to the current direct DRAM configuration but gives a cleaner report narrative: a dedicated external-memory transform stage followed by unchanged emit stages. + +Notebook source for current flow: `docs/labs/lab4-hardware.ipynb` + +## 9. Observed runtime behavior after additions + +After the off-chip testbench additions: + +- parameter stream drivers are bound for `fc1_weight` and `fc1_bias` +- DRAM parameter blocks are queued before main input driving +- simulation terminates in a short runtime instead of timing out + +This validates that the testbench can inject a valid test case into both activation and off-chip parameter channels. + +## 10. Suggested report framing (for 4-page writeup) + +A clear structure for the report section is: + +1. Problem statement: BRAM-only parameter sourcing limits off-chip verification. +2. Method: expose DRAM parameter ports, generate cocotb drivers for those ports, compute expected output from software model, compare stream outputs. +3. Implementation details: metadata switch, top-level port generation, robust parameter-port binding, quantized block packing. +4. 
Results: successful parameter+input driving and completion without timeout.
+5. Limitations and future work: stricter output checks (`check=True`), multi-layer scaling, AXI adapter-level traffic modeling.
diff --git a/docs/labs/dram-tb-dut-ports-implementation.md b/docs/labs/dram-tb-dut-ports-implementation.md
new file mode 100644
index 000000000..49a196249
--- /dev/null
+++ b/docs/labs/dram-tb-dut-ports-implementation.md
@@ -0,0 +1,61 @@
+# DRAM TB DUT Ports: Short Plan
+
+## Scope
+- DRAM mode only: expose parameter stream ports.
+- BRAM mode: no extra parameter ports.
+
+## Required DRAM Ports
+For each DRAM parameter:
+- `{node}_{param}`
+- `{node}_{param}_valid`
+- `{node}_{param}_ready`
+
+MLP example:
+- `fc1_weight`, `fc1_weight_valid`, `fc1_weight_ready`
+- `fc1_bias`, `fc1_bias_valid`, `fc1_bias_ready`
+
+## How original input is handled (reference)
+`data_in*` flow is:
+1. Create `StreamDriver(dut.{arg}, {arg}_valid, {arg}_ready)`.
+2. Generate tensors from metadata shape.
+3. Quantize+pack with `fixed_preprocess_tensor(...)`.
+4. Append blocks to driver.
+
+DRAM parameters should use the same pattern.
+
+## Verification plan
+1. Runtime TB check only
+- Discover expected DRAM ports from metadata (`storage == DRAM`).
+- At TB init, log expected port count.
+- For each expected port, verify DUT has all three signals:
+  - `{node}_{param}`
+  - `{node}_{param}_valid`
+  - `{node}_{param}_ready`
+- Bind a `StreamDriver` and log the bound port name.
+- Assert: if DRAM metadata exists, bound driver count must equal expected count.
+
+2. Runtime preload check
+- During `load_drivers`, log queued block count per DRAM port.
+- Log total queued DRAM blocks.
+- Assert total queued DRAM blocks > 0 in DRAM mode.
+
+## Driving plan
+For each DRAM port in TB:
+1. Bind `StreamDriver` to `{node}_{param}`, valid, ready.
+2. Read parameter tensor from model (`module.get_parameter(arg)`).
+3. Read DUT metadata parameters for that port:
+   - precision (`*_PRECISION_0`, `*_PRECISION_1`)
+   - parallelism (`*_PARALLELISM_DIM_0`, `*_PARALLELISM_DIM_1`)
+4.
Quantize and pack with `fixed_preprocess_tensor(...)`.
+5. Use block size = `PARALLELISM_DIM_0 * PARALLELISM_DIM_1`.
+6. Pad final block to full block size when needed.
+7. Append each block to the stream driver queue.
+
+Expected runtime logs location:
+- Cocotb simulator output stream in terminal while running tests.
+- Same messages in simulator log output under the generated project run directory.
+
+## Done criteria
+- DRAM `top.sv` has parameter stream ports.
+- TB runtime logs show expected ports, bound drivers, and queued DRAM blocks.
+- BRAM tests still pass unchanged.
diff --git a/docs/labs/dram-weight-streaming.md b/docs/labs/dram-weight-streaming.md
new file mode 100644
index 000000000..acb4ab7bb
--- /dev/null
+++ b/docs/labs/dram-weight-streaming.md
@@ -0,0 +1,146 @@
+# Weight Streaming in MASE Hardware Emit Flow
+
+## Overview
+
+By default, MASE generates FPGA hardware that stores all layer weights and biases
+**on-chip in BRAM ROM modules**. This guide explains how to instead route weights
+and biases from **off-chip DRAM**, exposing them as streaming handshake ports on
+the generated `top.sv` module.
+
+The flow is useful when:
+- Weight tensors are too large to fit in on-chip BRAM.
+- You want to connect to an external AXI memory controller (e.g., DDR via Zynq PS).
+- You are benchmarking off-chip memory bandwidth vs on-chip compute throughput.
+
+---
+
+## What Changes Between BRAM and DRAM Modes
+
+| Aspect | BRAM (default) | DRAM |
+|---|---|---|
+| Weight storage | On-chip ROM (`fc1_weight_source.sv`) | External — driven into top-level ports |
+| `top.sv` interface | Only `data_in`/`data_out` ports | Also exposes `{node}_{param}`, `{node}_{param}_valid`, `{node}_{param}_ready` |
+| `emit_bram` pass | Generates `.sv` ROM + `.dat` init file | Skips generation, deletes stale files |
+| Testbench | BRAM source drives weights internally | Cocotb `StreamDriver` drives weight ports |
+
+---
+
+## How to Use
+
+### Step 1 — Pass `storage: DRAM` to the hardware metadata pass
+
+```python
+mg, _ = add_hardware_metadata_analysis_pass(mg, {"interface": {"storage": "DRAM"}})
+```
+
+This tags every non-input tensor parameter (weight, bias, …) in the hardware
+metadata with `storage: "DRAM"` instead of the default `"BRAM"`.
+
+To use BRAM (the original behaviour), omit the argument or pass `"BRAM"` explicitly:
+
+```python
+mg, _ = add_hardware_metadata_analysis_pass(mg, {"interface": {"storage": "BRAM"}})
+# or simply:
+mg, _ = add_hardware_metadata_analysis_pass(mg)
+```
+
+### Step 2 — Run the emit passes as normal
+
+```python
+mg, _ = emit_verilog_top_transform_pass(mg)   # generates top.sv with DRAM ports
+mg, _ = emit_internal_rtl_transform_pass(mg)  # copies component RTL files
+mg, _ = emit_bram_transform_pass(mg)          # skips ROM generation for DRAM params;
+                                              # deletes any stale BRAM .sv/.dat files
+mg, _ = emit_cocotb_transform_pass(mg)        # testbench with StreamDrivers
+```
+
+### Step 3 — Simulate
+
+```python
+from chop.actions import simulate
+simulate(skip_build=False, skip_test=False, waves=True)
+```
+
+The cocotb testbench will automatically drive `fc1_weight`, `fc1_weight_valid`,
+`fc1_weight_ready` (and equivalent bias ports) from the preloaded quantised
+parameter tensors.
+
+---
+
+## Generated `top.sv` Interface
+
+For an MLP with a single `fc1` linear layer using DRAM storage, `top.sv` will
+expose the following additional ports:
+
+```systemverilog
+// Activation data (unchanged from BRAM flow)
+input [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0*...-1:0],
+input data_in_0_valid,
+output data_in_0_ready,
+
+output [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0*...-1:0],
+output data_out_0_valid,
+input data_out_0_ready,
+
+// DRAM weight streaming ports <-- NEW
+input [fc1_WEIGHT_PRECISION_0-1:0] fc1_weight [fc1_WEIGHT_PARALLELISM_DIM_0*fc1_WEIGHT_PARALLELISM_DIM_1-1:0],
+input fc1_weight_valid,
+output fc1_weight_ready,
+
+// DRAM bias streaming ports <-- NEW
+input [fc1_BIAS_PRECISION_0-1:0] fc1_bias [fc1_BIAS_PARALLELISM_DIM_0*fc1_BIAS_PARALLELISM_DIM_1-1:0],
+input fc1_bias_valid,
+output fc1_bias_ready,
+```
+
+Each parameter uses a **valid/ready handshake**, matching the existing streaming
+interface of the internal compute modules (e.g. `fixed_linear`).
+
+---
+
+## Connecting to a Real Memory Controller (FPGA Deployment)
+
+In simulation the cocotb testbench acts as the memory controller. For real FPGA
+deployment you need to replace that role with an AXI-stream or custom DMA engine.
+A typical integration looks like:
+
+```
+DDR ──► AXI DMA ──► AXI-Stream ──► (serialiser/reformatter) ──► fc1_weight port on top.sv
+```
+
+Key considerations:
+
+1. **Data width**: the `fc1_weight` port is `PRECISION_0 * PARALLELISM_DIM_0 * PARALLELISM_DIM_1`
+   bits wide per beat. Size your DMA burst accordingly.
+2. **Ordering**: weights must arrive in the same row-major block order that
+   `fixed_linear` expects — identical to how the BRAM ROM stored them.
+3. **Cycling**: for batch inference `fixed_linear` reads the full weight matrix
+   once per sample. Your controller must re-stream the weights for each input.
+4. **Back-pressure**: the `_ready` signal indicates the compute module can accept
+   a beat.
Your controller must respect this or data will be dropped.
+
+---
+
+## Files Changed
+
+| File | Change |
+|---|---|
+| `src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py` | `add_component_source` now reads `storage` from `pass_args["interface"]` instead of assuming BRAM |
+| `src/chop/passes/graph/transforms/verilog/emit_top.py` | `VerilogInterfaceEmitter` emits top-level DRAM ports; `VerilogInternalComponentEmitter` skips BRAM source instantiation for DRAM params |
+| `src/chop/passes/graph/transforms/verilog/emit_bram.py` | `emit_bram_handshake` skips `.sv`/`.dat` generation for DRAM params and deletes stale files from previous BRAM runs |
+| `src/chop/passes/graph/transforms/verilog/emit_tb.py` | `_emit_cocotb_tb` creates `StreamDriver` instances for DRAM ports and preloads quantised parameter blocks in `load_drivers` |
+
+---
+
+## Limitations / Future Work
+
+- **Per-parameter granularity**: currently `storage` applies to all parameters
+  globally. A future extension could allow e.g. weights on DRAM but bias on BRAM
+  by passing `{"interface": {"weight": {"storage": "DRAM"}, "bias": {"storage": "BRAM"}}}`.
+- **Multi-layer support**: the DRAM ports are per-node (prefixed with node name),
+  so multi-layer models automatically get separate ports per layer.
+- **Transposition**: the `transpose` flag in the interface metadata is preserved
+  but not yet acted on in the DRAM path — the data is assumed to arrive in the
+  same layout as BRAM.
+- **HLS toolchain**: `emit_bram_hls` has a symmetric stub that also needs
+  DRAM support if the HLS flow is used.
diff --git a/docs/labs/dump.fst b/docs/labs/dump.fst
new file mode 100644
index 000000000..47515d20e
Binary files /dev/null and b/docs/labs/dump.fst differ
diff --git a/docs/labs/initial-cocotb-plan.md b/docs/labs/initial-cocotb-plan.md
new file mode 100644
index 000000000..6e00e9f8b
--- /dev/null
+++ b/docs/labs/initial-cocotb-plan.md
@@ -0,0 +1,59 @@
+# Initial Cocotb Plan: DRAM-Style MLP Testbench Only
+
+1.
in `lab4-hardware.ipynb` run until `emit_internal_rtl_transform_pass` + + + +## Objective +Deliver working testbenches for off-chip weight streaming on a small MLP. +This plan is testbench-focused only and does not include real FPGA integration work. + +## Scope +In scope: +- Generate DRAM-style `top.sv` for an MLP +- Remove BRAM parameter source usage +- Drive `weight` and `bias` from cocotb (DRAM emulation) +- Verify functional correctness and basic throughput behavior in simulation + +Out of scope: +- AXI DMA integration +- Vivado block design and deployment adapters +- Board bring-up + +## 3-Step Execution Plan + +1. Take a normal generated MLP and regenerate in DRAM mode +- Start from the existing MLP emit flow. +- Use hardware metadata storage `"DRAM"` so parameters are off-chip streamed. +- Confirm `top.sv` is generated with top-level parameter ports: + - `{node}_{param}` + - `{node}_{param}_valid` + - `{node}_{param}_ready` + +2. Ensure `top.sv` has no BRAM-initialized parameter sources +- Verify there are no `*_source.sv` parameter module instantiations for weight/bias. +- Verify `emit_bram` skipped DRAM parameter ROM generation. +- Keep compute logic unchanged (`fixed_linear` stays the math core). + +3. Write cocotb to emulate DRAM and verify behavior +- In testbench, drive both activations and parameters. +- For parameters (`weight`, `bias`): + - quantize with the same precision as hardware params + - pack into blocks by parallelism + - stream with valid/ready handshake +- Verify: + - output correctness vs software model + - no handshake drops/deadlocks + - basic simulation-level performance counters (cycles/latency per sample) + +## Minimal Touch Points +- `src/chop/passes/graph/transforms/verilog/emit_top.py` +- `src/chop/passes/graph/transforms/verilog/emit_bram.py` +- `src/chop/passes/graph/transforms/verilog/emit_tb.py` +- One DRAM-mode test under `test/passes/graph/transforms/verilog/` + +## Acceptance Criteria +1. 
DRAM-mode `top.sv` compiles and has exposed parameter streaming ports. +2. No BRAM parameter source modules are instantiated for DRAM params. +3. Cocotb test passes for dummy MLP with DRAM-emulated `weight` and `bias` streams. +4. Collected basic performance metrics in simulation (for example cycle count to first output and total inference cycles). \ No newline at end of file diff --git a/docs/labs/internal_rtl/hardware/rtl/fixed_accumulator.sv b/docs/labs/internal_rtl/hardware/rtl/fixed_accumulator.sv new file mode 100644 index 000000000..1f6f48c64 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/fixed_accumulator.sv @@ -0,0 +1,67 @@ +`timescale 1ns / 1ps +module fixed_accumulator #( + parameter IN_DEPTH = 4, + parameter IN_WIDTH = 32, + parameter OUT_WIDTH = $clog2(IN_DEPTH) + IN_WIDTH +) ( + input logic clk, + input logic rst, + + input logic [IN_WIDTH-1:0] data_in, + input logic data_in_valid, + output logic data_in_ready, + + output logic [OUT_WIDTH-1:0] data_out, + output logic data_out_valid, + input logic data_out_ready +); + logic [OUT_WIDTH-1:0] reg_in; + logic reg_in_valid, reg_in_ready; + + skid_buffer #( + .DATA_WIDTH(OUT_WIDTH) + ) register_slice ( + .data_in(reg_in), + .data_in_valid(reg_in_valid), + .data_in_ready(reg_in_ready), + .* + ); + // 1-bit wider so IN_DEPTH also fits. 
+ localparam COUNTER_WIDTH = $clog2(IN_DEPTH); + logic [COUNTER_WIDTH:0] counter; + + // Sign extension before feeding into the accumulator + logic [ OUT_WIDTH-1:0] data_in_sext; + assign data_in_sext = {{(OUT_WIDTH - IN_WIDTH) {data_in[IN_WIDTH-1]}}, data_in}; + + /* verilator lint_off WIDTH */ + assign data_in_ready = (counter != IN_DEPTH) || reg_in_ready; + assign reg_in_valid = (counter == IN_DEPTH); + /* verilator lint_on WIDTH */ + + // counter + always_ff @(posedge clk) + if (rst) counter <= 0; + else begin + if (reg_in_valid) begin + if (reg_in_ready) begin + if (data_in_valid) counter <= 1; + else counter <= 0; + end + end else if (data_in_valid && data_in_ready) counter <= counter + 1; + end + + // data_out + always_ff @(posedge clk) + if (rst) reg_in <= '0; + else begin + if (reg_in_valid) begin + if (reg_in_ready) begin + if (data_in_valid) reg_in <= data_in_sext; + else reg_in <= '0; + end + end else if (data_in_valid && data_in_ready) reg_in <= reg_in + data_in_sext; + end + + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/fixed_adder_tree.sv b/docs/labs/internal_rtl/hardware/rtl/fixed_adder_tree.sv new file mode 100644 index 000000000..99d6f409f --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/fixed_adder_tree.sv @@ -0,0 +1,89 @@ +`timescale 1ns / 1ps + +// TODO: Add signed param. 
fixed_adder_tree_layer already supports signedness + +module fixed_adder_tree #( + parameter IN_SIZE = 2, + parameter IN_WIDTH = 32, + parameter OUT_WIDTH = $clog2(IN_SIZE) + IN_WIDTH +) ( + /* verilator lint_off UNUSEDSIGNAL */ + input logic clk, + input logic rst, + /* verilator lint_on UNUSEDSIGNAL */ + input logic [ IN_WIDTH-1:0] data_in [IN_SIZE-1:0], + input logic data_in_valid, + output logic data_in_ready, + output logic [OUT_WIDTH-1:0] data_out, + output logic data_out_valid, + input logic data_out_ready +); + + localparam LEVELS = $clog2(IN_SIZE); + + initial begin + assert (IN_SIZE > 0); + end + + generate + if (LEVELS == 0) begin : gen_skip_adder_tree + + assign data_out = data_in[0]; + assign data_out_valid = data_in_valid; + assign data_in_ready = data_out_ready; + + end else begin : gen_adder_tree + + // data & sum wires are oversized on purpose for vivado. + logic [OUT_WIDTH*IN_SIZE-1:0] data[LEVELS:0]; + logic [OUT_WIDTH*IN_SIZE-1:0] sum[LEVELS-1:0]; + logic valid[IN_SIZE-1:0]; + logic ready[IN_SIZE-1:0]; + + // Generate adder for each layer + for (genvar i = 0; i < LEVELS; i++) begin : level + + localparam LEVEL_IN_SIZE = (IN_SIZE + ((1 << i) - 1)) >> i; + localparam LEVEL_OUT_SIZE = (LEVEL_IN_SIZE + 1) / 2; + localparam LEVEL_IN_WIDTH = IN_WIDTH + i; + localparam LEVEL_OUT_WIDTH = LEVEL_IN_WIDTH + 1; + + fixed_adder_tree_layer #( + .IN_SIZE (LEVEL_IN_SIZE), + .IN_WIDTH(LEVEL_IN_WIDTH) + ) layer ( + .data_in (data[i]), // flattened LEVEL_IN_SIZE * LEVEL_IN_WIDTH + .data_out(sum[i]) // flattened LEVEL_OUT_SIZE * LEVEL_OUT_WIDTH + ); + + skid_buffer #( + .DATA_WIDTH(LEVEL_OUT_SIZE * LEVEL_OUT_WIDTH) + ) register_slice ( + .clk (clk), + .rst (rst), + .data_in (sum[i]), + .data_in_valid (valid[i]), + .data_in_ready (ready[i]), + .data_out (data[i+1]), + .data_out_valid(valid[i+1]), + .data_out_ready(ready[i+1]) + ); + + end + + for (genvar i = 0; i < IN_SIZE; i++) begin : gen_input_assign + assign data[0][(i+1)*IN_WIDTH-1 : i*IN_WIDTH] = data_in[i]; + 
end + + assign valid[0] = data_in_valid; + assign data_in_ready = ready[0]; + + assign data_out = data[LEVELS][OUT_WIDTH-1:0]; + assign data_out_valid = valid[LEVELS]; + assign ready[LEVELS] = data_out_ready; + + end + endgenerate + + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/fixed_adder_tree_layer.sv b/docs/labs/internal_rtl/hardware/rtl/fixed_adder_tree_layer.sv new file mode 100644 index 000000000..7d2e38861 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/fixed_adder_tree_layer.sv @@ -0,0 +1,42 @@ +`timescale 1ns / 1ps +module fixed_adder_tree_layer #( + parameter IN_SIZE = 2, + parameter IN_WIDTH = 16, + parameter SIGNED = 1, + + localparam OUT_WIDTH = IN_WIDTH + 1, + localparam OUT_SIZE = (IN_SIZE + 1) / 2 +) ( + input logic [ IN_SIZE*IN_WIDTH-1:0] data_in, + output logic [OUT_SIZE*OUT_WIDTH-1:0] data_out +); + + logic [ IN_WIDTH-1:0] data_in_unflat [ IN_SIZE-1:0]; + logic [OUT_WIDTH-1:0] data_out_unflat[OUT_SIZE-1:0]; + + for (genvar i = 0; i < IN_SIZE; i++) begin : in_unflat + assign data_in_unflat[i] = data_in[(i+1)*IN_WIDTH-1 : i*IN_WIDTH]; + end + + for (genvar i = 0; i < IN_SIZE / 2; i++) begin : pair + if (SIGNED) begin + assign data_out_unflat[i] = $signed(data_in_unflat[2*i]) + $signed(data_in_unflat[2*i+1]); + end else begin + assign data_out_unflat[i] = data_in_unflat[2*i] + data_in_unflat[2*i+1]; + end + end + + if (IN_SIZE % 2 != 0) begin : left + if (SIGNED) begin + assign data_out_unflat[OUT_SIZE-1] = $signed(data_in_unflat[IN_SIZE-1]); + end else begin + assign data_out_unflat[OUT_SIZE-1] = {1'b0, data_in_unflat[IN_SIZE-1]}; + end + end + + for (genvar i = 0; i < OUT_SIZE; i++) begin : out_flat + assign data_out[(i+1)*OUT_WIDTH-1 : i*OUT_WIDTH] = data_out_unflat[i]; + end + + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/fixed_cast.sv b/docs/labs/internal_rtl/hardware/rtl/fixed_cast.sv new file mode 100644 index 000000000..2a2a86f79 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/fixed_cast.sv @@ 
-0,0 +1,66 @@ +`timescale 1ns / 1ps +module fixed_cast #( + parameter IN_SIZE = 8, + parameter IN_WIDTH = 8, + parameter IN_FRAC_WIDTH = 4, + parameter OUT_WIDTH = 8, + parameter OUT_FRAC_WIDTH = 4 +) ( + input logic [ IN_WIDTH-1:0] data_in [IN_SIZE-1:0], + output logic [OUT_WIDTH-1:0] data_out[IN_SIZE-1:0] +); + + // TODO: Negative frac_width is not supported + + localparam IN_INT_WIDTH = IN_WIDTH - IN_FRAC_WIDTH; + localparam OUT_INT_WIDTH = OUT_WIDTH - OUT_FRAC_WIDTH; + + // Sign + for (genvar i = 0; i < IN_SIZE; i++) begin : out_sign + logic data; + assign data = data_in[i][IN_WIDTH-1]; + end + + // Fraction part + for (genvar i = 0; i < IN_SIZE; i++) begin : out_frac + logic [OUT_FRAC_WIDTH-1:0] data; + if (IN_FRAC_WIDTH > OUT_FRAC_WIDTH) + assign data = data_in[i][IN_FRAC_WIDTH-1:IN_FRAC_WIDTH-OUT_FRAC_WIDTH]; + /* verilator lint_off WIDTH */ + else + assign data = data_in[i][IN_FRAC_WIDTH-1:0] << (OUT_FRAC_WIDTH - IN_FRAC_WIDTH); + /* verilator lint_on WIDTH */ + end + + // Integer part + for (genvar i = 0; i < IN_SIZE; i++) begin : out_int + logic [OUT_INT_WIDTH-2:0] data; + if (IN_INT_WIDTH > OUT_INT_WIDTH) + assign data = data_in[i][OUT_INT_WIDTH-2+IN_FRAC_WIDTH:IN_FRAC_WIDTH]; + else + assign data = { + {(OUT_INT_WIDTH - IN_INT_WIDTH) {data_in[i][IN_WIDTH-1]}}, + data_in[i][IN_WIDTH-2:IN_FRAC_WIDTH] + }; + end + + for (genvar i = 0; i < IN_SIZE; i++) begin + + if (IN_INT_WIDTH > OUT_INT_WIDTH) begin + always_comb begin + // Saturation check + if (|({(IN_WIDTH-OUT_INT_WIDTH-IN_FRAC_WIDTH){data_in[i][IN_WIDTH-1]}} ^ data_in[i][IN_WIDTH-2:OUT_INT_WIDTH-1+IN_FRAC_WIDTH])) begin + /* saturate to b'100...001 or b' 011..111*/ + data_out[i] = {out_sign[i].data, {(OUT_WIDTH - 2) {~data_in[i][IN_WIDTH-1]}}, 1'b1}; + end else begin + data_out[i] = {out_sign[i].data, out_int[i].data, out_frac[i].data}; + end + end + end else begin + assign data_out[i] = {out_sign[i].data, out_int[i].data, out_frac[i].data}; + end + end + + +endmodule + diff --git 
a/docs/labs/internal_rtl/hardware/rtl/fixed_dot_product.sv b/docs/labs/internal_rtl/hardware/rtl/fixed_dot_product.sv new file mode 100644 index 000000000..8bd12f739 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/fixed_dot_product.sv @@ -0,0 +1,85 @@ +`timescale 1ns / 1ps +module fixed_dot_product #( + parameter IN_WIDTH = 32, + // this defines the number of elements in the vector, this is tunable + // when block arithmetics are applied, this is the same as the block size + parameter IN_SIZE = 4, + parameter WEIGHT_WIDTH = 16, + // this is the width for the product + // parameter PRODUCT_WIDTH = 8, + // this is the width for the summed product + parameter OUT_WIDTH = IN_WIDTH + WEIGHT_WIDTH + $clog2(IN_SIZE) +) ( + input clk, + input rst, + + // input port for activations + input logic [IN_WIDTH-1:0] data_in [IN_SIZE-1:0], + input data_in_valid, + output data_in_ready, + + // input port for weight + input logic [WEIGHT_WIDTH-1:0] weight [IN_SIZE-1:0], + input weight_valid, + output weight_ready, + + // output port + output logic [OUT_WIDTH-1:0] data_out, + output data_out_valid, + input data_out_ready + +); + + localparam PRODUCT_WIDTH = IN_WIDTH + WEIGHT_WIDTH; + + + logic [PRODUCT_WIDTH-1:0] pv [IN_SIZE-1:0]; + logic pv_valid; + logic pv_ready; + + logic [ OUT_WIDTH-1:0] sum; + logic sum_valid; + logic sum_ready; + + fixed_vector_mult #( + .IN_WIDTH(IN_WIDTH), + .WEIGHT_WIDTH(WEIGHT_WIDTH), + .IN_SIZE(IN_SIZE) + ) fixed_vector_mult_inst ( + .clk(clk), + .rst(rst), + .data_in(data_in), + .data_in_valid(data_in_valid), + .data_in_ready(data_in_ready), + .weight(weight), + .weight_valid(weight_valid), + .weight_ready(weight_ready), + .data_out(pv), + .data_out_valid(pv_valid), + .data_out_ready(pv_ready) + ); + + + // sum the products + // sum = sum(pv) + fixed_adder_tree #( + .IN_SIZE (IN_SIZE), + .IN_WIDTH(PRODUCT_WIDTH) + ) fixed_adder_tree_inst ( + .clk(clk), + .rst(rst), + .data_in(pv), + .data_in_valid(pv_valid), + .data_in_ready(pv_ready), + + 
.data_out(sum), + .data_out_valid(sum_valid), + .data_out_ready(sum_ready) + ); + + // Picking the end of the buffer, wire them to the output port + assign data_out = sum; + assign data_out_valid = sum_valid; + assign sum_ready = data_out_ready; + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/fixed_linear.sv b/docs/labs/internal_rtl/hardware/rtl/fixed_linear.sv new file mode 100644 index 000000000..1e0edce53 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/fixed_linear.sv @@ -0,0 +1,284 @@ +`timescale 1ns / 1ps + +/* + * + * The fixed_linear module implements torch.nn.functional.linear, which + * computes Y = X @ W^T + b + * + * Weight tensor is assumed to have shape (out_features, in_features) + * Data tensor is assumed to have shape (batch_size, in_features) + * Bias tensor is assumed to have shape (out_features) + * + * If WEIGHTS_PRE_TRANSPOSED is set to 0, the module will transpose the incoming + * weight matrix. Otherwise, it will assume that the incoming weight matrix is + * already transposed. 
+ * + */ + +module fixed_linear #( + /* verilator lint_off UNUSEDPARAM */ + parameter HAS_BIAS = 1, + parameter WEIGHTS_PRE_TRANSPOSED = 0, + + parameter DATA_IN_0_PRECISION_0 = 16, + parameter DATA_IN_0_PRECISION_1 = 3, + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 20, + parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 20, + parameter DATA_IN_0_PARALLELISM_DIM_0 = 4, // must equal WEIGHT_PARALLELISM_DIM_1 + parameter DATA_IN_0_PARALLELISM_DIM_1 = 4, + localparam IN_0_DEPTH_DIM_0 = DATA_IN_0_TENSOR_SIZE_DIM_0 / DATA_IN_0_PARALLELISM_DIM_0, + localparam IN_0_DEPTH_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1 / DATA_IN_0_PARALLELISM_DIM_1, + + parameter WEIGHT_PRECISION_0 = 16, + parameter WEIGHT_PRECISION_1 = 3, + parameter WEIGHT_TENSOR_SIZE_DIM_0 = 20, + parameter WEIGHT_TENSOR_SIZE_DIM_1 = 20, + parameter WEIGHT_PARALLELISM_DIM_0 = 4, + parameter WEIGHT_PARALLELISM_DIM_1 = 4, + + // Inferred precision of the output data + // if the data out precision will be replaced by the setting + parameter DATA_OUT_0_PRECISION_0 = DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 + $clog2( + DATA_IN_0_TENSOR_SIZE_DIM_0 + ) + 1, + parameter DATA_OUT_0_PRECISION_1 = DATA_IN_0_PRECISION_1 + WEIGHT_PRECISION_1, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = (WEIGHTS_PRE_TRANSPOSED == 0)? WEIGHT_TENSOR_SIZE_DIM_1:WEIGHT_TENSOR_SIZE_DIM_0, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = DATA_IN_0_TENSOR_SIZE_DIM_1, + parameter DATA_OUT_0_PARALLELISM_DIM_0 = (WEIGHTS_PRE_TRANSPOSED == 0)? 
WEIGHT_PARALLELISM_DIM_1:WEIGHT_PARALLELISM_DIM_0, + parameter DATA_OUT_0_PARALLELISM_DIM_1 = DATA_IN_0_PARALLELISM_DIM_1, + + parameter BIAS_PRECISION_0 = 16, + parameter BIAS_PRECISION_1 = 3, + parameter BIAS_TENSOR_SIZE_DIM_0 = DATA_OUT_0_TENSOR_SIZE_DIM_0, + parameter BIAS_TENSOR_SIZE_DIM_1 = 1, + parameter BIAS_PARALLELISM_DIM_0 = DATA_OUT_0_PARALLELISM_DIM_0, + parameter BIAS_PARALLELISM_DIM_1 = 1, + + localparam BIAS_DEPTH_DIM_0 = BIAS_TENSOR_SIZE_DIM_0 / BIAS_PARALLELISM_DIM_0, + localparam BIAS_DEPTH_DIM_1 = BIAS_TENSOR_SIZE_DIM_1 / BIAS_PARALLELISM_DIM_1 + +) ( + input clk, + input rst, + + // input port for data_inivations + input logic [DATA_IN_0_PRECISION_0-1:0] data_in_0 [DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0], + input logic data_in_0_valid, + output logic data_in_0_ready, + + // input port for weight + input logic [WEIGHT_PRECISION_0-1:0] weight [WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1-1:0], + input logic weight_valid, + output logic weight_ready, + + /* verilator lint_off UNUSEDSIGNAL */ + input logic [BIAS_PRECISION_0-1:0] bias[BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0], + input logic bias_valid, + /* verilator lint_on UNUSEDSIGNAL */ + output logic bias_ready, + + output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_0 [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0], + output logic data_out_0_valid, + input logic data_out_0_ready +); + localparam MMOUT_PRECISION_0 = DATA_IN_0_PRECISION_0 + WEIGHT_PRECISION_0 + $clog2( + DATA_IN_0_TENSOR_SIZE_DIM_0 + ); + localparam MMOUT_PRECISION_1 = DATA_IN_0_PRECISION_1 + WEIGHT_PRECISION_1; + + // The TENSOR_SIZE and PARALLELISM parameters for the weights are set by emit verilog according to the real + // tensor values. Here we account for the change when the weights are pre-transposed + localparam REAL_WEIGHT_TENSOR_SIZE_DIM_0 = (WEIGHTS_PRE_TRANSPOSED == 0) ? 
WEIGHT_TENSOR_SIZE_DIM_1 : WEIGHT_TENSOR_SIZE_DIM_0; + localparam REAL_WEIGHT_TENSOR_SIZE_DIM_1 = (WEIGHTS_PRE_TRANSPOSED == 0) ? WEIGHT_TENSOR_SIZE_DIM_0 : WEIGHT_TENSOR_SIZE_DIM_1; + localparam REAL_WEIGHT_PARALLELISM_DIM_0 = (WEIGHTS_PRE_TRANSPOSED == 0) ? WEIGHT_PARALLELISM_DIM_1 : WEIGHT_PARALLELISM_DIM_0; + localparam REAL_WEIGHT_PARALLELISM_DIM_1 = (WEIGHTS_PRE_TRANSPOSED == 0) ? WEIGHT_PARALLELISM_DIM_0 : WEIGHT_PARALLELISM_DIM_1; + + // * Declarations + // * --------------------------------------------------------------------------------------------------- + + logic [WEIGHT_PRECISION_0-1:0] weight_transposed [WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1-1:0]; + logic weight_transposed_valid; + logic weight_transposed_ready; + + logic [MMOUT_PRECISION_0-1:0] matmul_out [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; + logic matmul_out_valid; + logic matmul_out_ready; + + logic [BIAS_PRECISION_0-1:0] bias_buffered [BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0]; + logic bias_buffered_valid, bias_buffered_ready; + + logic [MMOUT_PRECISION_0-1:0] bias_casted [BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1 -1:0]; + logic [MMOUT_PRECISION_0:0] add_bias_in [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; + logic [DATA_OUT_0_PRECISION_0 - 1:0] add_bias_in_casted [DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0]; + logic add_bias_in_valid; + logic add_bias_in_ready; + + // * Instances + // * --------------------------------------------------------------------------------------------------- + + if (WEIGHTS_PRE_TRANSPOSED == 0) begin + matrix_stream_transpose #( + .TOTAL_DIM0 (WEIGHT_TENSOR_SIZE_DIM_0), + .TOTAL_DIM1 (WEIGHT_TENSOR_SIZE_DIM_1), + .COMPUTE_DIM0(WEIGHT_PARALLELISM_DIM_0), + .COMPUTE_DIM1(WEIGHT_PARALLELISM_DIM_1), + .DATA_WIDTH (WEIGHT_PRECISION_0) + ) weight_matrix_transpose_i ( + .clk, + .rst, + + .in_data (weight), + .in_valid(weight_valid), + .in_ready(weight_ready), + + .out_data 
(weight_transposed), + .out_valid(weight_transposed_valid), + .out_ready(weight_transposed_ready) + ); + end + matmul #( + // Total dimensions + .A_TOTAL_DIM0(DATA_IN_0_TENSOR_SIZE_DIM_0), + .A_TOTAL_DIM1(DATA_IN_0_TENSOR_SIZE_DIM_1), + .B_TOTAL_DIM0(REAL_WEIGHT_TENSOR_SIZE_DIM_0), + .B_TOTAL_DIM1(REAL_WEIGHT_TENSOR_SIZE_DIM_1), + + .A_COMPUTE_DIM0(DATA_IN_0_PARALLELISM_DIM_0), + .A_COMPUTE_DIM1(DATA_IN_0_PARALLELISM_DIM_1), + .B_COMPUTE_DIM0(REAL_WEIGHT_PARALLELISM_DIM_0), + .B_COMPUTE_DIM1(REAL_WEIGHT_PARALLELISM_DIM_1), + + .A_WIDTH (DATA_IN_0_PRECISION_0), + .A_FRAC_WIDTH(DATA_IN_0_PRECISION_1), + .B_WIDTH (WEIGHT_PRECISION_0), + .B_FRAC_WIDTH(WEIGHT_PRECISION_1), + + .OUT_WIDTH (MMOUT_PRECISION_0), + .OUT_FRAC_WIDTH(MMOUT_PRECISION_1), + .OUT_SYMMETRIC (0) + ) matmul_i ( + .clk, + .rst, + + .a_data (data_in_0), + .a_valid(data_in_0_valid), + .a_ready(data_in_0_ready), + + .b_data (weight_transposed), + .b_valid(weight_transposed_valid), + .b_ready(weight_transposed_ready), + + .out_data (matmul_out), + .out_valid(matmul_out_valid), + .out_ready(matmul_out_ready) + ); + + // Bias output + if (HAS_BIAS == 1) begin + + join2 join2_matmul_bias_i ( + .data_in_valid ({matmul_out_valid, bias_buffered_valid}), + .data_in_ready ({matmul_out_ready, bias_buffered_ready}), + .data_out_valid(add_bias_in_valid), + .data_out_ready(add_bias_in_ready) + ); + + input_buffer #( + .DATA_WIDTH (BIAS_PRECISION_0), + .IN_NUM (BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1), + .REPEAT (IN_0_DEPTH_DIM_1), + .BUFFER_SIZE(BIAS_DEPTH_DIM_0) + ) bias_buffer_inst ( + .clk, + .rst, + + // Input streaming port + .data_in(bias), + .data_in_valid(bias_valid), + .data_in_ready(bias_ready), + + // Output streaming port + .data_out(bias_buffered), + .data_out_valid(bias_buffered_valid), + .data_out_ready(bias_buffered_ready) + ); + fixed_rounding #( + .IN_SIZE (DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1), + .IN_WIDTH (MMOUT_PRECISION_0 + 1), + .IN_FRAC_WIDTH (MMOUT_PRECISION_1), 
+ .OUT_WIDTH (DATA_OUT_0_PRECISION_0), + .OUT_FRAC_WIDTH(DATA_OUT_0_PRECISION_1) + ) output_cast ( + .data_in (add_bias_in), + .data_out(add_bias_in_casted) + ); + unpacked_register_slice #( + .DATA_WIDTH(DATA_OUT_0_PRECISION_0), + .IN_SIZE (DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1) + ) register_slice_i ( + .clk(clk), + .rst(rst), + + .data_in(add_bias_in_casted), + .data_in_valid(add_bias_in_valid), + .data_in_ready(add_bias_in_ready), + + .data_out(data_out_0), + .data_out_valid(data_out_0_valid), + .data_out_ready(data_out_0_ready) + ); + end + + // * Logic + // * --------------------------------------------------------------------------------------------------- + + if (WEIGHTS_PRE_TRANSPOSED == 1) begin + always_comb begin + weight_transposed_valid = weight_valid; + weight_ready = weight_transposed_ready; + end + + for (genvar i = 0; i < WEIGHT_PARALLELISM_DIM_0 * WEIGHT_PARALLELISM_DIM_1; i++) begin + assign weight_transposed[i] = weight[i]; + end + end + + // * Add bias + if (HAS_BIAS == 1) begin + fixed_cast #( + .IN_SIZE (BIAS_PARALLELISM_DIM_0 * BIAS_PARALLELISM_DIM_1), + .IN_WIDTH (BIAS_PRECISION_0), + .IN_FRAC_WIDTH (BIAS_PRECISION_1), + .OUT_WIDTH (MMOUT_PRECISION_0), + .OUT_FRAC_WIDTH(MMOUT_PRECISION_1) + ) bias_cast_i ( + .data_in (bias_buffered), + .data_out(bias_casted) + ); + + for (genvar i_0 = 0; i_0 < DATA_OUT_0_PARALLELISM_DIM_0; i_0++) begin + for (genvar i_1 = 0; i_1 < DATA_OUT_0_PARALLELISM_DIM_1; i_1++) begin + assign add_bias_in[i_1*DATA_OUT_0_PARALLELISM_DIM_0+i_0] = $signed( + matmul_out[i_1*DATA_OUT_0_PARALLELISM_DIM_0+i_0] + ) + $signed( + bias_casted[i_0] + ); + end + end + + end else begin + fixed_rounding #( + .IN_SIZE (DATA_OUT_0_PARALLELISM_DIM_0 * DATA_OUT_0_PARALLELISM_DIM_1), + .IN_WIDTH (MMOUT_PRECISION_0), + .IN_FRAC_WIDTH (MMOUT_PRECISION_1), + .OUT_WIDTH (DATA_OUT_0_PRECISION_0), + .OUT_FRAC_WIDTH(DATA_OUT_0_PRECISION_1) + ) output_cast ( + .data_in (matmul_out), + .data_out(data_out_0) + ); + assign 
data_out_0_valid = matmul_out_valid; + assign matmul_out_ready = data_out_0_ready; + end + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/fixed_mult.sv b/docs/labs/internal_rtl/hardware/rtl/fixed_mult.sv new file mode 100644 index 000000000..a70744964 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/fixed_mult.sv @@ -0,0 +1,18 @@ +`timescale 1ns / 1ps +// fixed-point multiplier + +module fixed_mult #( + parameter IN_A_WIDTH = 32, + parameter IN_B_WIDTH = 32, + parameter type TYPE_A = logic [ IN_A_WIDTH-1:0], + parameter type TYPE_B = logic [ IN_B_WIDTH-1:0], + parameter type TYPE_PRODUCT = logic [IN_A_WIDTH+IN_B_WIDTH-1:0] +) ( + input TYPE_A data_a, + input TYPE_B data_b, + output TYPE_PRODUCT product +); + + assign product = $signed(data_a) * $signed(data_b); + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/fixed_relu.sv b/docs/labs/internal_rtl/hardware/rtl/fixed_relu.sv new file mode 100644 index 000000000..f117fba06 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/fixed_relu.sv @@ -0,0 +1,54 @@ +`timescale 1ns / 1ps + +module fixed_relu #( + /* verilator lint_off UNUSEDPARAM */ + parameter DATA_IN_0_PRECISION_0 = 8, + parameter DATA_IN_0_PRECISION_1 = 3, + parameter DATA_IN_0_TENSOR_SIZE_DIM_0 = 8, + parameter DATA_IN_0_TENSOR_SIZE_DIM_1 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_0 = 1, + parameter DATA_IN_0_PARALLELISM_DIM_1 = 1, + + parameter DATA_OUT_0_PRECISION_0 = 8, + parameter DATA_OUT_0_PRECISION_1 = 3, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_0 = 8, + parameter DATA_OUT_0_TENSOR_SIZE_DIM_1 = 1, + parameter DATA_OUT_0_PARALLELISM_DIM_0 = 1, + parameter DATA_OUT_0_PARALLELISM_DIM_1 = 1, + + parameter INPLACE = 0 +) ( + /* verilator lint_off UNUSEDSIGNAL */ + input rst, + input clk, + input logic [DATA_IN_0_PRECISION_0-1:0] data_in_0[DATA_IN_0_PARALLELISM_DIM_0*DATA_IN_0_PARALLELISM_DIM_1-1:0], + output logic [DATA_OUT_0_PRECISION_0-1:0] data_out_0[DATA_OUT_0_PARALLELISM_DIM_0*DATA_OUT_0_PARALLELISM_DIM_1-1:0], + + 
input logic data_in_0_valid, + output logic data_in_0_ready, + output logic data_out_0_valid, + input logic data_out_0_ready +); + + initial begin + assert (DATA_IN_0_PRECISION_0 == DATA_OUT_0_PRECISION_0) + else $error("ReLU: DATA_IN_0_PRECISION_0 must be equal to DATA_OUT_0_PRECISION_0"); + assert (DATA_IN_0_PRECISION_1 == DATA_OUT_0_PRECISION_1) + else $error("ReLU: DATA_IN_0_PRECISION_1 must be equal to DATA_OUT_0_PRECISION_1"); + end + + + /* verilator lint_off SELRANGE */ + for ( + genvar i = 0; i < DATA_IN_0_PARALLELISM_DIM_0 * DATA_IN_0_PARALLELISM_DIM_1; i++ + ) begin : ReLU + always_comb begin + if ($signed(data_in_0[i]) <= 0) data_out_0[i] = '0; + else data_out_0[i] = data_in_0[i]; + end + end + + assign data_out_0_valid = data_in_0_valid; + assign data_in_0_ready = data_out_0_ready; + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/fixed_vector_mult.sv b/docs/labs/internal_rtl/hardware/rtl/fixed_vector_mult.sv new file mode 100644 index 000000000..2a6b0757a --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/fixed_vector_mult.sv @@ -0,0 +1,92 @@ +`timescale 1ns / 1ps +module fixed_vector_mult #( + parameter IN_WIDTH = 32, + parameter WEIGHT_WIDTH = 16, + // this is the width for the product + // parameter PRODUCT_WIDTH = 8, + // this is the width for the summed product + parameter OUT_WIDTH = IN_WIDTH + WEIGHT_WIDTH, + + // this defines the number of elements in the vector, this is tunable + parameter IN_SIZE = 4 +) ( + input clk, + input rst, + + // input port for activations + input logic [IN_WIDTH-1:0] data_in [IN_SIZE-1:0], + input data_in_valid, + output data_in_ready, + + // input port for weight + input logic [WEIGHT_WIDTH-1:0] weight [IN_SIZE-1:0], + input weight_valid, + output weight_ready, + + // output port + output logic [OUT_WIDTH-1:0] data_out [IN_SIZE-1:0], + output data_out_valid, + input data_out_ready +); + + localparam PRODUCT_WIDTH = IN_WIDTH + WEIGHT_WIDTH; + + // pv[i] = data_in[i] * w[i] + logic 
[PRODUCT_WIDTH-1:0] product_vector[IN_SIZE-1:0]; + logic product_data_in_valid; + logic product_data_in_ready; + logic product_data_out_valid; + logic product_data_out_ready; + logic [$bits(product_vector)-1:0] product_data_in; + logic [$bits(product_vector)-1:0] product_data_out; + + for (genvar i = 0; i < IN_SIZE; i = i + 1) begin : parallel_mult + fixed_mult #( + .IN_A_WIDTH(IN_WIDTH), + .IN_B_WIDTH(WEIGHT_WIDTH) + ) fixed_mult_inst ( + .data_a (data_in[i]), + .data_b (weight[i]), + .product(product_vector[i]) + ); + end + + + + join2 #() join_inst ( + .data_in_ready ({weight_ready, data_in_ready}), + .data_in_valid ({weight_valid, data_in_valid}), + .data_out_valid(product_data_in_valid), + .data_out_ready(product_data_in_ready) + ); + + // Cocotb/verilator does not support array flattening, so + // we need to manually add some reshaping process. + + // Casting array for product vector + for (genvar i = 0; i < IN_SIZE; i++) begin : reshape_in + assign product_data_in[PRODUCT_WIDTH*i+PRODUCT_WIDTH-1:PRODUCT_WIDTH*i] = product_vector[i]; + end + + skid_buffer #( + .DATA_WIDTH($bits(product_vector)) + ) register_slice ( + .clk (clk), + .rst (rst), + .data_in (product_data_in), + .data_in_valid (product_data_in_valid), + .data_in_ready (product_data_in_ready), + .data_out (product_data_out), + .data_out_valid(product_data_out_valid), + .data_out_ready(product_data_out_ready) + ); + + // Casting array for product vector + for (genvar i = 0; i < IN_SIZE; i++) begin : reshape_out + assign data_out[i] = product_data_out[PRODUCT_WIDTH*i+PRODUCT_WIDTH-1:PRODUCT_WIDTH*i]; + end + + assign data_out_valid = product_data_out_valid; + assign product_data_out_ready = data_out_ready; + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/join2.sv b/docs/labs/internal_rtl/hardware/rtl/join2.sv new file mode 100644 index 000000000..391daf5a0 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/join2.sv @@ -0,0 +1,36 @@ +`timescale 1ns / 1ps +// Join2 synchronises two 
sets of input handshake signals with a set of output handshake signals +module join2 #( +) ( + input logic [1:0] data_in_valid, + output logic [1:0] data_in_ready, + output logic data_out_valid, + input logic data_out_ready +); + + // If only one of the inputs is valid - we need to stall that input and wait + // for the other input by setting one of the ready bits to 0. + // +-----------+-----------+------------+------------+------------+ + // | data_out_ready | data_in_valid_0 | data_in_valid_1 | data_in_ready_0 | data_in_ready_1 | + // +-----------+-----------+------------+------------+------------+ + // | 0 | 0 | 0 | 0 | 0 | + // +-----------+-----------+------------+------------+------------+ + // | 0 | 0 | 1 | 0 | 0 | + // +-----------+-----------+------------+------------+------------+ + // | 0 | 1 | 0 | 0 | 0 | + // +-----------+-----------+------------+------------+------------+ + // | 0 | 1 | 1 | 0 | 0 | + // +-----------+-----------+------------+------------+------------+ + // | 1 | 0 | 0 | 1 | 1 | + // +-----------+-----------+------------+------------+------------+ + // | 1 | 0 | 1 | 1 | 0 | + // +-----------+-----------+------------+------------+------------+ + // | 1 | 1 | 0 | 0 | 1 | + // +-----------+-----------+------------+------------+------------+ + // | 1 | 1 | 1 | 1 | 1 | + // +-----------+-----------+------------+------------+------------+ + assign data_in_ready[0] = data_out_ready & (!data_in_valid[0] | data_in_valid[1]); + assign data_in_ready[1] = data_out_ready & (!data_in_valid[1] | data_in_valid[0]); + assign data_out_valid = &data_in_valid; + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/matmul.sv b/docs/labs/internal_rtl/hardware/rtl/matmul.sv new file mode 100644 index 000000000..e8443426a --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/matmul.sv @@ -0,0 +1,363 @@ +/* +Module : matmul +Description : This module does a matrix multiplication between matrices A & B. 
+ + Try to use port B as the weights if you can optimise the bitwidth + more aggressively than the input data. This is because as the B + matrix is buffered, while the data in port A is streamed. + + Please refer to matrix multiplication documentation in + "docs/hardware" for the parameter naming conventions and + algorithm/order of matrix multiplications. + + Inputs are streamed row-wise (along dim0 first) in 2D sub-blocks + which have width and height equal to the compute dimensions. + + Python equivalent: + z = np.matmul(A, B) +*/ + +`timescale 1ns / 1ps + +/* verilator lint_off UNUSEDPARAM */ +module matmul #( + // Total dimensions + parameter A_TOTAL_DIM0 = 4, + parameter A_TOTAL_DIM1 = 4, + parameter B_TOTAL_DIM0 = 4, + parameter B_TOTAL_DIM1 = 4, // must equal A_TOTAL_DIM0 + + // Compute dimensions + parameter A_COMPUTE_DIM0 = 2, + parameter A_COMPUTE_DIM1 = 2, + parameter B_COMPUTE_DIM0 = 2, + parameter B_COMPUTE_DIM1 = 2, // must equal A_COMPUTE_DIM0 + + // Input fixed point widths + parameter A_WIDTH = 8, + parameter A_FRAC_WIDTH = 1, + parameter B_WIDTH = 8, + parameter B_FRAC_WIDTH = 1, + + // Output fixed point widths + parameter OUT_WIDTH = 16, + parameter OUT_FRAC_WIDTH = 2, + + // Output casting/rounding + parameter OUT_SYMMETRIC = 0, + + // Derived Dimensions (Constants) + localparam C_TOTAL_DIM0 = B_TOTAL_DIM0, + localparam C_TOTAL_DIM1 = A_TOTAL_DIM1, + localparam C_COMPUTE_DIM0 = B_COMPUTE_DIM0, + localparam C_COMPUTE_DIM1 = A_COMPUTE_DIM1, + + // Derived Depth (Constants) + localparam A_DEPTH_DIM0 = A_TOTAL_DIM0 / A_COMPUTE_DIM0, + localparam A_DEPTH_DIM1 = A_TOTAL_DIM1 / A_COMPUTE_DIM1, + localparam B_DEPTH_DIM0 = B_TOTAL_DIM0 / B_COMPUTE_DIM0, + localparam B_DEPTH_DIM1 = B_TOTAL_DIM1 / B_COMPUTE_DIM1, + localparam C_DEPTH_DIM0 = C_TOTAL_DIM0 / C_COMPUTE_DIM0, + localparam C_DEPTH_DIM1 = C_TOTAL_DIM1 / C_COMPUTE_DIM1 +) ( + input logic clk, + input logic rst, + + // Matix A - row-major order + input logic [A_WIDTH-1:0] a_data 
[A_COMPUTE_DIM0*A_COMPUTE_DIM1-1:0], + input logic a_valid, + output logic a_ready, + + // Matix B - row-major order + input logic [B_WIDTH-1:0] b_data [B_COMPUTE_DIM0*B_COMPUTE_DIM1-1:0], + input logic b_valid, + output logic b_ready, + + // Matrix C - row-major order + output logic [OUT_WIDTH-1:0] out_data [C_COMPUTE_DIM0*C_COMPUTE_DIM1-1:0], + output logic out_valid, + input logic out_ready +); + initial begin + // Check dimension constraint not violated + assert (A_TOTAL_DIM0 == B_TOTAL_DIM1) + else $fatal("A_TOTAL_DIM0 must equal B_TOTAL_DIM1!"); + assert (A_COMPUTE_DIM0 == B_COMPUTE_DIM1) + else $fatal("A_COMPUTE_DIM0 must equal B_COMPUTE_DIM1!"); + + // Check compute vs. total divisibility + assert (A_TOTAL_DIM0 % A_COMPUTE_DIM0 == 0) + else $fatal("A_DIM0 compute is not divisible!"); + assert (A_TOTAL_DIM1 % A_COMPUTE_DIM1 == 0) + else $fatal("A_DIM1 compute is not divisible!"); + assert (B_TOTAL_DIM0 % B_COMPUTE_DIM0 == 0) + else $fatal("B_DIM0 compute is not divisible!"); + assert (B_TOTAL_DIM1 % B_COMPUTE_DIM1 == 0) + else $fatal("B_DIM1 compute is not divisible!"); + end + + // ----- + // Params + // ----- + + localparam A_FLAT_WIDTH = A_WIDTH * A_COMPUTE_DIM0 * A_COMPUTE_DIM1; + localparam B_FLAT_WIDTH = B_WIDTH * B_COMPUTE_DIM0 * B_COMPUTE_DIM1; + + localparam SM_OUT_WIDTH = A_WIDTH + B_WIDTH + $clog2(A_COMPUTE_DIM0); + localparam SM_OUT_FRAC_WIDTH = A_FRAC_WIDTH + B_FRAC_WIDTH; + + localparam MAT_ACC_PTR_WIDTH = C_DEPTH_DIM0 == 1 ? 
1 : $clog2(C_DEPTH_DIM0); + localparam MAT_ACC_OUT_WIDTH = $clog2(B_DEPTH_DIM1) + SM_OUT_WIDTH; + + // ----- + // Wires + // ----- + + // Buffer unflatten out + logic a_buffer_out_valid, a_buffer_out_ready; + logic [A_WIDTH-1:0] a_buffer_out_data[A_COMPUTE_DIM0*A_COMPUTE_DIM1-1:0]; + + // Repeat each submatrix in Matrix A stream B_DEPTH_DIM0 times + // Only if (B_DEPTH_DIM0 > 1) + logic [A_FLAT_WIDTH-1:0] a_data_flat; + logic [A_FLAT_WIDTH-1:0] a_buffer_out_data_flat; + + // We need to buffer the B matrix + // TODO: unless A_DEPTH_DIM1 == 1 + + logic [B_FLAT_WIDTH-1:0] b_data_flat; + + // Buffer outputs + logic [B_FLAT_WIDTH-1:0] b_buffer_out_data_flat; + logic b_buffer_out_valid, b_buffer_out_ready; + + // Matrix unflatten output + logic [B_WIDTH-1:0] b_buffer_out_data[B_COMPUTE_DIM0*B_COMPUTE_DIM1-1:0]; + + logic [SM_OUT_WIDTH-1:0] sm_out_data[C_COMPUTE_DIM0*C_COMPUTE_DIM1]; + logic sm_out_valid, sm_out_ready; + + logic [C_DEPTH_DIM0-1:0] acc_in_valid; + logic [C_DEPTH_DIM0-1:0] acc_in_ready; + logic [C_DEPTH_DIM0-1:0] acc_out_valid; + logic [C_DEPTH_DIM0-1:0] acc_out_ready; + logic [MAT_ACC_OUT_WIDTH-1:0] acc_out_data[C_DEPTH_DIM0-1:0][C_COMPUTE_DIM0*C_COMPUTE_DIM1-1:0]; + + logic [MAT_ACC_OUT_WIDTH-1:0] cast_in_data[C_COMPUTE_DIM0*C_COMPUTE_DIM1-1:0]; + + + // ----- + // State + // ----- + + struct { + // Points to which matrix accumulator should store the simple_matmul output + logic [MAT_ACC_PTR_WIDTH-1:0] matrix_acc_ptr; + // Points at which output accumulator should be connected to the out stream + logic [MAT_ACC_PTR_WIDTH-1:0] output_acc_ptr; + } + self, next_self; + + + // ----- + // Logic + // ----- + + generate + + // A matrix Buffers + + if (B_DEPTH_DIM0 > 1) begin : gen_a_buffer + + matrix_flatten #( + .DATA_WIDTH(A_WIDTH), + .DIM0 (A_COMPUTE_DIM0), + .DIM1 (A_COMPUTE_DIM1) + ) weight_buffer_flatten_a ( + .data_in (a_data), + .data_out(a_data_flat) + ); + + single_element_repeat #( + .DATA_WIDTH(A_FLAT_WIDTH), + // Repeat for number of rows in matrix 
A + .REPEAT (B_DEPTH_DIM0) + ) input_stream_buffer ( + .clk (clk), + .rst (rst), + .in_data (a_data_flat), + .in_valid (a_valid), + .in_ready (a_ready), + .out_data (a_buffer_out_data_flat), + .out_valid(a_buffer_out_valid), + .out_ready(a_buffer_out_ready) + ); + + matrix_unflatten #( + .DATA_WIDTH(A_WIDTH), + .DIM0 (A_COMPUTE_DIM0), + .DIM1 (A_COMPUTE_DIM1) + ) weight_buffer_unflatten_a ( + .data_in (a_buffer_out_data_flat), + .data_out(a_buffer_out_data) + ); + + end else begin : gen_a_reg_slice + + // Add a register stage to cut any combinatoral paths to simple matmul + unpacked_skid_buffer #( + .DATA_WIDTH(A_WIDTH), + .IN_NUM (A_COMPUTE_DIM0 * A_COMPUTE_DIM1) + ) input_stream_reg_slice ( + .clk (clk), + .rst (rst), + .data_in (a_data), + .data_in_valid (a_valid), + .data_in_ready (a_ready), + .data_out (a_buffer_out_data), + .data_out_valid(a_buffer_out_valid), + .data_out_ready(a_buffer_out_ready) + ); + end + + // A matrix Buffers + + if (A_DEPTH_DIM1 > 1) begin : g_circular_buffer + + input_buffer #( + .DATA_WIDTH (B_WIDTH), + .IN_NUM (B_COMPUTE_DIM0 * B_COMPUTE_DIM1), + .REPEAT (A_DEPTH_DIM1), + .BUFFER_SIZE(B_DEPTH_DIM0 * B_DEPTH_DIM1) + ) weight_buffer ( + .clk, + .rst, + + // Input streaming port + .data_in(b_data), + .data_in_valid(b_valid), + .data_in_ready(b_ready), + + // Output streaming port + .data_out(b_buffer_out_data), + .data_out_valid(b_buffer_out_valid), + .data_out_ready(b_buffer_out_ready) + ); + end else begin + assign b_buffer_out_data = b_data; + assign b_buffer_out_valid = b_valid; + assign b_ready = b_buffer_out_ready; + end + endgenerate + + // Feed input A & buffered input B into simple matrix mult + + // Simple matrix multiply block's accumulator width + // We do not round at simple_matmul level as we want to keep high precision + // and round ourselves after the output accumulation in this matmul module. 
+ + simple_matmul #( + .N (A_COMPUTE_DIM1), + .M (A_COMPUTE_DIM0), // == B_COMPUTE_DIM1 + .K (B_COMPUTE_DIM0), + .X_WIDTH (A_WIDTH), + .X_FRAC_WIDTH (A_FRAC_WIDTH), + .Y_WIDTH (B_WIDTH), + .Y_FRAC_WIDTH (B_FRAC_WIDTH), + .OUT_WIDTH (SM_OUT_WIDTH), + .OUT_FRAC_WIDTH(SM_OUT_FRAC_WIDTH) + ) simple_matmul_inst ( + .clk (clk), + .rst (rst), + .x_data (a_buffer_out_data), + .x_valid (a_buffer_out_valid), + .x_ready (a_buffer_out_ready), + .y_data (b_buffer_out_data), + .y_valid (b_buffer_out_valid), + .y_ready (b_buffer_out_ready), + .out_data (sm_out_data), + .out_valid(sm_out_valid), + .out_ready(sm_out_ready) + ); + + // Direct the result of the simple matmul to the correct matrix_accumulator + + for (genvar i = 0; i < C_DEPTH_DIM0; i++) begin : gen_acc + matrix_accumulator #( + .IN_DEPTH(B_DEPTH_DIM1), + .IN_WIDTH(SM_OUT_WIDTH), + .DIM0 (C_COMPUTE_DIM0), + .DIM1 (C_COMPUTE_DIM1) + ) matrix_acc_inst ( + .clk (clk), + .rst (rst), + .in_data (sm_out_data), + .in_valid (acc_in_valid[i]), + .in_ready (acc_in_ready[i]), + .out_data (acc_out_data[i]), + .out_valid(acc_out_valid[i]), + .out_ready(acc_out_ready[i]) + ); + end + + for (genvar i = 0; i < C_DEPTH_DIM0; i++) begin : gen_handshake + // Change which accumulator the output of simple_matmul goes to + assign acc_in_valid[i] = self.matrix_acc_ptr == i ? sm_out_valid : 0; + + // Select which accumulator can output on out stream + assign acc_out_ready[i] = self.output_acc_ptr == i ? out_ready : 0; + end + + assign sm_out_ready = acc_in_ready[self.matrix_acc_ptr]; + + for (genvar i = 0; i < C_COMPUTE_DIM0 * C_COMPUTE_DIM1; i++) begin : gen_cast + fixed_signed_cast #( + .IN_WIDTH (MAT_ACC_OUT_WIDTH), + .IN_FRAC_WIDTH (SM_OUT_FRAC_WIDTH), + .OUT_WIDTH (OUT_WIDTH), + .OUT_FRAC_WIDTH(OUT_FRAC_WIDTH), + .SYMMETRIC (OUT_SYMMETRIC), + .ROUND_FLOOR (1) + ) output_cast ( + .in_data (cast_in_data[i]), + .out_data(out_data[i]) + ); + end + + // Logic to handle accumulator selection & output selection. 
+ always_comb begin + next_self = self; + + for (int i = 0; i < C_COMPUTE_DIM0 * C_COMPUTE_DIM1; i++) begin + cast_in_data[i] = acc_out_data[self.output_acc_ptr][i]; + end + out_valid = acc_out_valid[self.output_acc_ptr]; + + // Change accumulator pointer + if (sm_out_valid && sm_out_ready) begin + if (self.matrix_acc_ptr == C_DEPTH_DIM0 - 1) begin + next_self.matrix_acc_ptr = 0; + end else begin + next_self.matrix_acc_ptr += 1; + end + end + + // Change output pointer + if (|acc_out_ready && |acc_out_valid) begin + if (self.output_acc_ptr == C_DEPTH_DIM0 - 1) begin + next_self.output_acc_ptr = 0; + end else begin + next_self.output_acc_ptr += 1; + end + end + end + + always_ff @(posedge clk) begin + if (rst) begin + self <= '{default: 0}; + end else begin + self <= next_self; + end + end + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/matrix_accumulator.sv b/docs/labs/internal_rtl/hardware/rtl/matrix_accumulator.sv new file mode 100644 index 000000000..fe9b01cde --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/matrix_accumulator.sv @@ -0,0 +1,53 @@ +/* +Module : matrix_accumulator +Description : This module instantiated a 2D array of accumulators. 
+*/ + +`timescale 1ns / 1ps + +module matrix_accumulator #( + parameter IN_DEPTH = 4, + parameter IN_WIDTH = 32, + parameter DIM0 = 2, + parameter DIM1 = 2, + + // Derived parameter + localparam OUT_WIDTH = $clog2(IN_DEPTH) + IN_WIDTH +) ( + input logic clk, + input logic rst, + input logic [ IN_WIDTH-1:0] in_data [DIM0*DIM1-1:0], + input logic in_valid, + output logic in_ready, + output logic [OUT_WIDTH-1:0] out_data [DIM0*DIM1-1:0], + output logic out_valid, + input logic out_ready +); + + for (genvar i = 0; i < DIM1; i++) begin : rows + for (genvar j = 0; j < DIM0; j++) begin : columns + /* verilator lint_off UNUSEDSIGNAL */ + logic in_ready_signal, out_valid_signal; + /* verilator lint_on UNUSEDSIGNAL */ + fixed_accumulator #( + .IN_DEPTH(IN_DEPTH), + .IN_WIDTH(IN_WIDTH) + ) acc_inst ( + .clk (clk), + .rst (rst), + .data_in (in_data[i*DIM0+j]), + .data_in_valid (in_valid), + .data_in_ready (in_ready_signal), + .data_out (out_data[i*DIM0+j]), + .data_out_valid(out_valid_signal), + .data_out_ready(out_ready) + ); + end + end + + // All accumulators should be synchronised, so we can take a single signal for + // out_valid and in_ready + assign in_ready = rows[0].columns[0].in_ready_signal; + assign out_valid = rows[0].columns[0].out_valid_signal; + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/matrix_fifo.sv b/docs/labs/internal_rtl/hardware/rtl/matrix_fifo.sv new file mode 100644 index 000000000..b66af38e0 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/matrix_fifo.sv @@ -0,0 +1,67 @@ +/* +Module : matrix_fifo +Description : FIFO to buffer matrices or 2D data. 
+*/ + +`timescale 1ns / 1ps + +module matrix_fifo #( + // Dimensions + parameter DATA_WIDTH = 8, + parameter DIM0 = 4, + parameter DIM1 = 4, + parameter FIFO_SIZE = 32 +) ( + input logic clk, + input logic rst, + + input logic [DATA_WIDTH-1:0] in_data [DIM0*DIM1-1:0], + input logic in_valid, + output logic in_ready, + + output logic [DATA_WIDTH-1:0] out_data [DIM0*DIM1-1:0], + output logic out_valid, + input logic out_ready +); + + // Wires + localparam FLAT_DATA_WIDTH = DATA_WIDTH * DIM0 * DIM1; + + logic [FLAT_DATA_WIDTH-1:0] in_data_flat, out_data_flat; + + // Modules + matrix_flatten #( + .DATA_WIDTH(DATA_WIDTH), + .DIM0 (DIM0), + .DIM1 (DIM1) + ) input_flatten ( + .data_in (in_data), + .data_out(in_data_flat) + ); + + fifo #( + .DEPTH (FIFO_SIZE), + .DATA_WIDTH(FLAT_DATA_WIDTH) + ) input_fifo_inst ( + .clk (clk), + .rst (rst), + .in_data (in_data_flat), + .in_valid (in_valid), + .in_ready (in_ready), + .out_data (out_data_flat), + .out_valid(out_valid), + .out_ready(out_ready), + .empty (), + .full () + ); + + matrix_unflatten #( + .DATA_WIDTH(DATA_WIDTH), + .DIM0 (DIM0), + .DIM1 (DIM1) + ) fifo_unflatten ( + .data_in (out_data_flat), + .data_out(out_data) + ); + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/matrix_flatten.sv b/docs/labs/internal_rtl/hardware/rtl/matrix_flatten.sv new file mode 100644 index 000000000..f77036fb9 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/matrix_flatten.sv @@ -0,0 +1,27 @@ +/* +Module : matrix_flatten +Description : This module is a purely combinatorial block which flattens a + row-major matrix data stream into 1D bit vector. + + Assumptions: + data_in is already in row-major form. + + To reverse the result, you can use the "matrix_unflatten" module. 
+*/ + +`timescale 1ns / 1ps + +module matrix_flatten #( + parameter DATA_WIDTH = 32, + parameter DIM0 = 4, + parameter DIM1 = 4 +) ( + input logic [ DATA_WIDTH-1:0] data_in [DIM0*DIM1-1:0], + output logic [DATA_WIDTH*DIM0*DIM1-1:0] data_out +); + + for (genvar i = 0; i < DIM0 * DIM1; i++) begin + assign data_out[(i+1)*DATA_WIDTH-1:i*DATA_WIDTH] = data_in[i]; + end + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/matrix_stream_transpose.sv b/docs/labs/internal_rtl/hardware/rtl/matrix_stream_transpose.sv new file mode 100644 index 000000000..9548d37d8 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/matrix_stream_transpose.sv @@ -0,0 +1,255 @@ +/* +Module : matrix_stream_transpose +Description : This module handles transposing a matrix which is being streamed + one compute chunk at a time. + + Note: If you need a simpler combinatorial transpose for a + non-streaming architecture, see transpose.sv +*/ + +`timescale 1ns / 1ps + +// TODO: fix throughput problems + +module matrix_stream_transpose #( + // Total dimensions + parameter TOTAL_DIM0 = 4, + parameter TOTAL_DIM1 = 4, + + // Compute dimensions + parameter COMPUTE_DIM0 = 2, + parameter COMPUTE_DIM1 = 2, + + // Other params + parameter DATA_WIDTH = 8 +) ( + input logic clk, + input logic rst, + + // In Matrix + input logic [DATA_WIDTH-1:0] in_data [COMPUTE_DIM0*COMPUTE_DIM1-1:0], + input logic in_valid, + output logic in_ready, + + // Out Matrix + output logic [DATA_WIDTH-1:0] out_data [COMPUTE_DIM0*COMPUTE_DIM1-1:0], + output logic out_valid, + input logic out_ready +); + + initial begin + // Check compute vs. total divisibility + assert (TOTAL_DIM0 % COMPUTE_DIM0 == 0) + else $fatal("DIM0 compute is not divisible!"); + assert (TOTAL_DIM1 % COMPUTE_DIM1 == 0) + else $fatal("DIM1 compute is not divisible!"); + end + + // ----- + // Parameters + // ----- + // let max(a, b) = (a > b) ? 
a : b; + + localparam IN_DEPTH_DIM0 = TOTAL_DIM0 / COMPUTE_DIM0; + localparam IN_DEPTH_DIM1 = TOTAL_DIM1 / COMPUTE_DIM1; + localparam OUT_DEPTH_DIM0 = IN_DEPTH_DIM1; + localparam OUT_DEPTH_DIM1 = IN_DEPTH_DIM0; + localparam IN_ROW_COUNTER_WIDTH = $clog2(IN_DEPTH_DIM1) > 1 ? $clog2(IN_DEPTH_DIM1) : 1; + localparam IN_COL_COUNTER_WIDTH = $clog2(IN_DEPTH_DIM0) > 1 ? $clog2(IN_DEPTH_DIM0) : 1; + localparam OUT_ROW_COUNTER_WIDTH = $clog2(OUT_DEPTH_DIM1) > 1 ? $clog2(OUT_DEPTH_DIM1) : 1; + localparam OUT_COL_COUNTER_WIDTH = $clog2(OUT_DEPTH_DIM0) > 1 ? $clog2(OUT_DEPTH_DIM0) : 1; + + localparam FIFO_DEPTH = IN_DEPTH_DIM1; + localparam FIFO_DATA_WIDTH = DATA_WIDTH * COMPUTE_DIM0 * COMPUTE_DIM1; + + // ----- + // State + // ----- + + struct { + // Current row & col that the window is at for the input + logic [IN_ROW_COUNTER_WIDTH-1:0] in_row_count; + logic [IN_COL_COUNTER_WIDTH-1:0] in_col_count; + // Current row & col that the window is at for the output + logic [OUT_ROW_COUNTER_WIDTH-1:0] out_row_count; + logic [OUT_COL_COUNTER_WIDTH-1:0] out_col_count; + } + self, next_self; + + // ----- + // Wires + // ----- + + logic [FIFO_DATA_WIDTH-1:0] in_data_flat; + logic [FIFO_DATA_WIDTH-1:0] fifo_in_data[IN_DEPTH_DIM0-1:0]; + logic fifo_in_valid[IN_DEPTH_DIM0-1:0]; + logic fifo_in_ready[IN_DEPTH_DIM0-1:0]; + logic [FIFO_DATA_WIDTH-1:0] fifo_out_data_flat[IN_DEPTH_DIM0-1:0]; + logic fifo_out_valid[IN_DEPTH_DIM0-1:0]; + logic fifo_out_ready[IN_DEPTH_DIM0-1:0]; + + logic fifo_data_readys[IN_DEPTH_DIM0-1:0]; + + logic [FIFO_DATA_WIDTH-1:0] fifo_out_data_flat_mux_in[IN_DEPTH_DIM0-1:0]; + logic [FIFO_DATA_WIDTH-1:0] fifo_out_data_flat_mux_out; + logic fifo_out_valids[IN_DEPTH_DIM0-1:0]; + + logic [DATA_WIDTH-1:0] transpose_data_in[COMPUTE_DIM0*COMPUTE_DIM1-1:0]; + + + // FIFOs + // We want to generate IN_DEPTH_DIM0 FIFOs to buffer the input chunks. 
+ // Each FIFO will need to be IN_DEPTH_DIM1 elements deep and each element will + // be flattened to be size (DATA_WIDTH * COMPUTE_DIM0 * COMPUTE_DIM1) + + matrix_flatten #( + .DATA_WIDTH(DATA_WIDTH), + .DIM0 (COMPUTE_DIM0), + .DIM1 (COMPUTE_DIM1) + ) flatten_inst ( + .data_in (in_data), + .data_out(in_data_flat) + ); + + for (genvar i = 0; i < IN_DEPTH_DIM0; i++) begin : fifos + fifo #( + .DEPTH (FIFO_DEPTH), + .DATA_WIDTH(FIFO_DATA_WIDTH) + ) fifo_inst ( + .clk (clk), + .rst (rst), + .in_data (fifo_in_data[i]), + .in_valid (fifo_in_valid[i]), + .in_ready (fifo_in_ready[i]), + .out_data (fifo_out_data_flat[i]), + .out_valid(fifo_out_valid[i]), + .out_ready(fifo_out_ready[i]), + .empty (), + .full () + ); + end + + // Connect up wires to write to all of the fifos using in_col_count as index + // The valid and ready signals will be used to select which one is written to + for (genvar i = 0; i < IN_DEPTH_DIM0; i++) begin + assign fifo_in_data[i] = in_data_flat; + assign fifo_in_valid[i] = (self.in_col_count == i) ? in_valid : 0; + assign fifo_data_readys[i] = fifo_in_ready[i]; + end + + generate + if (IN_DEPTH_DIM0 > 1) begin : gen_in_ready_mux + mux #( + .NUM_INPUTS(IN_DEPTH_DIM0), + .DATA_WIDTH(1) + ) in_ready_mux ( + .data_in (fifo_data_readys), + .select (self.in_col_count), + .data_out(in_ready) + ); + end else begin : gen_fifo_in_ready + assign in_ready = fifo_data_readys[0]; + end + endgenerate + + // Connect up wires to read from all of the fifos using out_row_count to index + // into the column fifos which buffer the matrix + + for (genvar i = 0; i < IN_DEPTH_DIM0; i++) begin + assign fifo_out_data_flat_mux_in[i] = fifo_out_data_flat[i]; + assign fifo_out_valids[i] = fifo_out_valid[i]; + assign fifo_out_ready[i] = (self.out_row_count == i) ? 
out_ready : 0; + end + + generate + if (IN_DEPTH_DIM0 > 1) begin : gen_fifo_out_muxes + mux #( + .NUM_INPUTS(IN_DEPTH_DIM0), + .DATA_WIDTH(FIFO_DATA_WIDTH) + ) fifo_data_out_mux ( + .data_in (fifo_out_data_flat_mux_in), + .select (self.out_row_count), + .data_out(fifo_out_data_flat_mux_out) + ); + mux #( + .NUM_INPUTS(IN_DEPTH_DIM0), + .DATA_WIDTH(1) + ) fifo_data_valid_mux ( + .data_in (fifo_out_valids), + .select (self.out_row_count), + .data_out(out_valid) + ); + end else begin : gen_fifo_out + assign fifo_out_data_flat_mux_out = fifo_out_data_flat_mux_in[0]; + assign out_valid = fifo_out_valids[0]; + end + endgenerate + + // Unflatten FIFO data + matrix_unflatten #( + .DATA_WIDTH(DATA_WIDTH), + .DIM0 (COMPUTE_DIM0), + .DIM1 (COMPUTE_DIM1) + ) pre_transpose_flatten_inst ( + .data_in (fifo_out_data_flat_mux_out), + .data_out(transpose_data_in) + ); + + // Combinatorial transpose module + transpose #( + .WIDTH(DATA_WIDTH), + .DIM0 (COMPUTE_DIM0), + .DIM1 (COMPUTE_DIM1) + ) transpose_inst ( + .in_data (transpose_data_in), + .out_data(out_data) + ); + + always_comb begin + next_self = self; + + // Increment input side counters + if (in_valid && in_ready) begin + if (self.in_row_count == IN_DEPTH_DIM1 - 1 && self.in_col_count == IN_DEPTH_DIM0 - 1) begin + // End of matrix + next_self.in_row_count = 0; + next_self.in_col_count = 0; + end else if (self.in_col_count == IN_DEPTH_DIM0 - 1) begin + // End of row + next_self.in_row_count = self.in_row_count + 1; + next_self.in_col_count = 0; + end else begin + // Increment col counter + next_self.in_col_count = self.in_col_count + 1; + end + end + + // Increment output side counters + if (out_valid && out_ready) begin + if (self.out_row_count == OUT_DEPTH_DIM1 - 1 && + self.out_col_count == OUT_DEPTH_DIM0 - 1) begin + // End of matrix + next_self.out_row_count = 0; + next_self.out_col_count = 0; + end else if (self.out_col_count == OUT_DEPTH_DIM0 - 1) begin + // End of row + next_self.out_row_count = self.out_row_count + 1; 
+ next_self.out_col_count = 0; + end else begin + // Increment col counter + next_self.out_col_count = self.out_col_count + 1; + end + end + + end + + always_ff @(posedge clk) begin + if (rst) begin + self <= '{default: 0}; + end else begin + self <= next_self; + end + end + + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/matrix_unflatten.sv b/docs/labs/internal_rtl/hardware/rtl/matrix_unflatten.sv new file mode 100644 index 000000000..27c48789b --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/matrix_unflatten.sv @@ -0,0 +1,22 @@ +/* +Module : matrix_unflatten +Description : This module is a purely combinatorial block which unflattens a 1D + bit vector generated by the "matrix_flatten" module. +*/ + +`timescale 1ns / 1ps + +module matrix_unflatten #( + parameter DATA_WIDTH = 32, + parameter DIM0 = 4, + parameter DIM1 = 4 +) ( + input logic [DATA_WIDTH*DIM0*DIM1-1:0] data_in, + output logic [ DATA_WIDTH-1:0] data_out[DIM0*DIM1-1:0] +); + + for (genvar i = 0; i < DIM0 * DIM1; i++) begin + assign data_out[i] = data_in[(i+1)*DATA_WIDTH-1:i*DATA_WIDTH]; + end + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/register_slice.sv b/docs/labs/internal_rtl/hardware/rtl/register_slice.sv new file mode 100644 index 000000000..fc12055f7 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/register_slice.sv @@ -0,0 +1,91 @@ +`timescale 1ns / 1ps +module register_slice #( + parameter DATA_WIDTH = 32, + parameter type MYDATA = logic [DATA_WIDTH-1:0] +) ( + input logic clk, + input logic rst, + + input MYDATA data_in, + input logic data_in_valid, + output logic data_in_ready, + + output MYDATA data_out, + output logic data_out_valid, + input logic data_out_ready +); + + // The buffer stores the intermeidate data being computed in the register slice + logic [DATA_WIDTH-1:0] buffer; + // The shift register stores the validity of the data in the buffer + logic shift_reg; + logic to_load; + + + // There are eight cases: + // 
+----------------+----------+-----------+----------------+--------------------+---------------------+ + // | hold_any_input | in_valid | out_ready | accept_next_in | release_current_in | hold_any_input_next | + // | (shift_reg) | | | | | (shift_reg_next) | + // +----------------+----------+-----------+----------------+--------------------+---------------------+ + // | 0 | 0 | 0 | 0 | 0 | 0 | + // +----------------+----------+-----------+----------------+--------------------+---------------------+ + // | 0 | 0 | 1 | 0 | 0 | 0 | + // +----------------+----------+-----------+----------------+--------------------+---------------------+ + // | 0 | 1 | 0 | 1 | 0 | 1 | + // +----------------+----------+-----------+----------------+--------------------+---------------------+ + // | 0 | 1 | 1 | 1 | 0 | 1 | + // +----------------+----------+-----------+----------------+--------------------+---------------------+ + // | 1 | 0 | 0 | 0 | 0 | 1 | + // +----------------+----------+-----------+----------------+--------------------+---------------------+ + // | 1 | 0 | 1 | 0 | 1 | 0 | + // +----------------+----------+-----------+----------------+--------------------+---------------------+ + // | 1 | 1 | 0 | 0 | 0 | 1 | + // +----------------+----------+-----------+----------------+--------------------+---------------------+ + // | 1 | 1 | 1 | 1 | 1 | 1 | + // +----------------+----------+-----------+----------------+--------------------+---------------------+ + + // shift_register + always_ff @(posedge clk) begin + if (rst) shift_reg <= 1'b0; + else begin + shift_reg <= (shift_reg && (!data_out_ready)) || data_in_valid; + end + end + + // buffer + assign to_load = ((!shift_reg) && data_in_valid) || (data_in_valid && data_out_ready); + always_ff @(posedge clk) begin + if (rst) buffer <= 0; + else if (to_load) buffer <= data_in; + end + + // output + assign data_out = buffer; + + + // control logic + // +----------------+----------+-----------+----------+-----------+ + // | 
hold_any_input | in_valid | out_ready | in_ready | out_valid | + // | (shift_reg) | | | | | + // +----------------+----------+-----------+----------+-----------+ + // | 0 | 0 | 0 | 1 | 0 | + // +----------------+----------+-----------+----------+-----------+ + // | 0 | 0 | 1 | 1 | 0 | + // +----------------+----------+-----------+----------+-----------+ + // | 0 | 1 | 0 | 1 | 0 | + // +----------------+----------+-----------+----------+-----------+ + // | 0 | 1 | 1 | 1 | 0 | + // +----------------+----------+-----------+----------+-----------+ + // | 1 | 0 | 0 | 1 | 1 | + // +----------------+----------+-----------+----------+-----------+ + // | 1 | 0 | 1 | 1 | 1 | + // +----------------+----------+-----------+----------+-----------+ + // | 1 | 1 | 0 | 0 | 1 | + // +----------------+----------+-----------+----------+-----------+ + // | 1 | 1 | 1 | 1 | 1 | + // +----------------+----------+-----------+----------+-----------+ + + assign data_in_ready = (!shift_reg) || (!data_in_valid) || data_out_ready; + assign data_out_valid = shift_reg; + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/simple_matmul.sv b/docs/labs/internal_rtl/hardware/rtl/simple_matmul.sv new file mode 100644 index 000000000..703fa3648 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/simple_matmul.sv @@ -0,0 +1,156 @@ +/* +Module : simple_matmul +Description : This module does a matrix multiplcation between matrices X & Y. 
+ + The dimensions for the matrix multiplcation are: + n x m * m x k + + or in MASE naming convention + a_dim1 x a_dim0 * b_dim1 x bdim_0 + + Python equivalent: + out = np.matmul(X, Y) +*/ + +`timescale 1ns / 1ps + +module simple_matmul #( + // Dimensions + parameter N = 2, + parameter M = 2, + parameter K = 2, + // Input fixed point widths + parameter X_WIDTH = 8, + parameter X_FRAC_WIDTH = 1, + parameter Y_WIDTH = 8, + parameter Y_FRAC_WIDTH = 1, + // Output fixed point widths + // if OUTPUT_ROUNDING == 0: + // then out_width & out_frac_width must match accumulator widths + parameter OUTPUT_ROUNDING = 1, + parameter OUT_WIDTH = 16, + parameter OUT_FRAC_WIDTH = 2 +) ( + input logic clk, + input logic rst, + + // Input matrix X, row-wise ordering + input logic [X_WIDTH-1:0] x_data [N*M-1:0], + input logic x_valid, + output logic x_ready, + + // Input matrix Y, column-wise ordering + input logic [Y_WIDTH-1:0] y_data [M*K-1:0], + input logic y_valid, + output logic y_ready, + + // Output matrix + output logic [OUT_WIDTH-1:0] out_data [N*K-1:0], + output logic out_valid, + input logic out_ready +); + + // ----- + // Params + // ----- + + // Accumulator widths in linear layer + localparam ACC_WIDTH = X_WIDTH + Y_WIDTH + $clog2(M); + localparam ACC_FRAC_WIDTH = X_FRAC_WIDTH + Y_FRAC_WIDTH; + + initial begin + if (OUTPUT_ROUNDING == 0) begin + assert (ACC_WIDTH == OUT_WIDTH) + else $fatal("OUT_WIDTH must be %d if OUTPUT_ROUNDING == 0", ACC_WIDTH); + assert (ACC_FRAC_WIDTH == OUT_FRAC_WIDTH) + else $fatal("OUT_FRAC_WIDTH must be %d if OUTPUT_ROUNDING == 0", ACC_FRAC_WIDTH); + end + end + + + // ----- + // Wires + // ----- + + logic [Y_WIDTH-1:0] y_data_transpose[K*M-1:0]; + logic dot_product_ready; + logic inputs_valid, inputs_ready; + + logic [N*K-1:0] dot_product_valid; + logic [N*K-1:0] sync_ready; + logic [ACC_WIDTH-1:0] dot_product_data_out[N*K-1:0]; + logic [OUT_WIDTH-1:0] rounded_dot_product[N*K-1:0]; + + + // ----- + // Logic + // ----- + + // Need to synchronise 
x & y inputs + assign inputs_ready = sync_ready[0]; + join2 sync_handshake ( + .data_in_valid ({x_valid, y_valid}), + .data_in_ready ({x_ready, y_ready}), + .data_out_valid(inputs_valid), + .data_out_ready(inputs_ready) + ); + + // Transpose y to make column assignment easier, this module is just a rewire + // so it shouldn't contribute anything to comb path. + transpose #( + .WIDTH(Y_WIDTH), + .DIM0 (K), + .DIM1 (M) + ) y_transpose ( + .in_data (y_data), + .out_data(y_data_transpose) + ); + + // Instantiate N-by-K number of dot products + for (genvar i = 0; i < N; i++) begin : multi_row + for (genvar j = 0; j < K; j++) begin : multi_col + + fixed_dot_product #( + .IN_WIDTH (X_WIDTH), + .IN_SIZE (M), + .WEIGHT_WIDTH(Y_WIDTH) + ) dot_product_inst ( + .clk (clk), + .rst (rst), + .data_in (x_data[((i+1)*M)-1 : i*M]), + .data_in_valid (inputs_valid), + .data_in_ready (sync_ready[i*K+j]), + .weight (y_data_transpose[((j+1)*M)-1 : j*M]), + .weight_valid (inputs_valid), + /* verilator lint_off PINCONNECTEMPTY */ + // This pin is the same as data_in_ready pin + .weight_ready (), + /* verilator lint_on PINCONNECTEMPTY */ + .data_out (dot_product_data_out[i*K+j]), + .data_out_valid(dot_product_valid[i*K+j]), + .data_out_ready(dot_product_ready) + ); + + if (OUTPUT_ROUNDING) begin : rounding + // Rounded output + fixed_round #( + .IN_WIDTH (ACC_WIDTH), + .IN_FRAC_WIDTH (ACC_FRAC_WIDTH), + .OUT_WIDTH (OUT_WIDTH), + .OUT_FRAC_WIDTH(OUT_FRAC_WIDTH) + ) round_inst ( + .data_in (dot_product_data_out[i*K+j]), + .data_out(rounded_dot_product[i*K+j]) + ); + assign out_data[i*K+j] = rounded_dot_product[i*K+j]; + end else begin : no_rounding + assign out_data[i*K+j] = dot_product_data_out[i*K+j]; + end + + end + end + + assign out_valid = dot_product_valid[0]; + assign dot_product_ready = out_ready; + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/skid_buffer.sv b/docs/labs/internal_rtl/hardware/rtl/skid_buffer.sv new file mode 100644 index 000000000..ca1e76f27 --- 
/dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/skid_buffer.sv @@ -0,0 +1,92 @@ +`timescale 1ns / 1ps + +module skid_buffer #( + parameter DATA_WIDTH = 32 +) ( + input logic clk, + input logic rst, + + input logic [DATA_WIDTH - 1:0] data_in, + input logic data_in_valid, + output logic data_in_ready, + + output logic [DATA_WIDTH - 1:0] data_out, + output logic data_out_valid, + input logic data_out_ready +); + // feed the data_out either from + // data_in or a buffered copy of data_in + + logic [DATA_WIDTH - 1:0] data_buffer_out; + logic data_buffer_wren; + logic data_out_wren; + logic use_buffered_data; + logic [DATA_WIDTH - 1:0] selected_data; + logic insert, remove; + logic load, flow, fill, flush, unload; + + always_ff @(posedge clk) begin + if (rst) data_buffer_out <= 0; + else if (data_buffer_wren) data_buffer_out <= data_in; + end + + assign selected_data = (use_buffered_data) ? data_buffer_out : data_in; + always_ff @(posedge clk) begin + if (rst) data_out <= 0; + else if (data_out_wren) data_out <= selected_data; + end + // control path + // skid buffer has 4 states + // 1. Empty + // 2. Busy, holding data in the main register, but not transferring to output + // 3. 
Full, both two registers were hold + + enum { + EMPTY, + BUSY, + FULL + } + state, state_next; + + always_ff @(posedge clk) begin : handshake + if (rst) begin + data_in_ready <= 0; + data_out_valid <= 0; + end else begin + /* verilator lint_off WIDTH */ + data_in_ready <= (state_next != FULL); + data_out_valid <= (state_next != EMPTY); + /* verilator lint_on WIDTH */ + end + end + always_comb begin + insert = (data_in_valid && data_in_ready); + remove = (data_out_valid && data_out_ready); + end + + always_comb begin + load = (state == EMPTY) && ({insert, remove} == 2'b10); + flow = (state == BUSY) && ({insert, remove} == 2'b11); + fill = (state == BUSY) && ({insert, remove} == 2'b10); + unload = (state == BUSY) && ({insert, remove} == 2'b01); + flush = (state == FULL) && ({insert, remove} == 2'b01); + end + + always_comb + /* verilator lint_off WIDTH */ + case (state) + EMPTY: state_next = (load) ? BUSY : state; + BUSY: state_next = (fill) ? FULL : (flow) ? BUSY : (unload) ? EMPTY : state; + FULL: state_next = (flush) ? BUSY : state; + default: state_next = state; + endcase + always_ff @(posedge clk) + if (rst) state <= EMPTY; + else state <= state_next; + /* verilator lint_on WIDTH */ + always_comb begin + data_out_wren = (load == 1'b1) || (flow == 1'b1) || (flush == 1'b1); + data_buffer_wren = (fill == 1'b1); + use_buffered_data = (flush == 1'b1); + end +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/transpose.sv b/docs/labs/internal_rtl/hardware/rtl/transpose.sv new file mode 100644 index 000000000..0804384f5 --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/transpose.sv @@ -0,0 +1,20 @@ +/* +Module : transpose +Description : This module does a combinatorial transpose of a matrix. 
+*/ + +`timescale 1ns / 1ps + +module transpose #( + parameter WIDTH = 8, + parameter DIM0 = 4, + parameter DIM1 = 4 +) ( + input logic [WIDTH-1:0] in_data [DIM1*DIM0-1:0], + output logic [WIDTH-1:0] out_data[DIM1*DIM0-1:0] +); + + for (genvar i = 0; i < DIM1; i++) + for (genvar j = 0; j < DIM0; j++) assign out_data[j*DIM1+i] = in_data[i*DIM0+j]; + +endmodule diff --git a/docs/labs/internal_rtl/hardware/rtl/unpacked_repeat_circular_buffer.sv b/docs/labs/internal_rtl/hardware/rtl/unpacked_repeat_circular_buffer.sv new file mode 100644 index 000000000..8a5a1da6a --- /dev/null +++ b/docs/labs/internal_rtl/hardware/rtl/unpacked_repeat_circular_buffer.sv @@ -0,0 +1,58 @@ +/* +Module : repeat_circular_buffer +Description : This module is a repeating circular buffer. +*/ + +`timescale 1ns / 1ps + +module unpacked_repeat_circular_buffer #( + parameter DATA_WIDTH = 32, + parameter IN_NUM = 1, + parameter REPEAT = 2, + parameter SIZE = 4 +) ( + input logic clk, + input logic rst, + + // Input streaming port + input logic [DATA_WIDTH-1:0] in_data [IN_NUM-1:0], + input logic in_valid, + output logic in_ready, + + // Output streaming port + output logic [DATA_WIDTH-1:0] out_data [IN_NUM-1:0], + output logic out_valid, + input logic out_ready +); + + logic [DATA_WIDTH * IN_NUM - 1:0] data_in_flatten; + logic [DATA_WIDTH * IN_NUM - 1:0] data_out_flatten; + + for (genvar i = 0; i < IN_NUM; i++) begin : reshape + assign data_in_flatten[i*DATA_WIDTH+DATA_WIDTH-1:i*DATA_WIDTH] = in_data[i]; + end + + repeat_circular_buffer #( + .DATA_WIDTH(DATA_WIDTH * IN_NUM), + .REPEAT(REPEAT), + .SIZE(SIZE) + ) buffer_inst ( + .clk(clk), + .rst(rst), + + // Input streaming port + .in_data (data_in_flatten), + .in_valid(in_valid), + .in_ready(in_ready), + + // Output streaming port + .out_data (data_out_flatten), + .out_valid(out_valid), + .out_ready(out_ready) + ); + + for (genvar i = 0; i < IN_NUM; i++) begin : unreshape + assign out_data[i] = 
data_out_flatten[i*DATA_WIDTH+DATA_WIDTH-1:i*DATA_WIDTH]; + end + +endmodule diff --git a/docs/labs/lab4-hardware.ipynb b/docs/labs/lab4-hardware.ipynb index 56168f74a..dfc70f8fb 100644 --- a/docs/labs/lab4-hardware.ipynb +++ b/docs/labs/lab4-hardware.ipynb @@ -486,7 +486,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "mase (3.11.9)", "language": "python", "name": "python3" }, @@ -500,7 +500,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.11" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/docs/source/modules/documentation/tutorials/classification-model/classification_pruned_dram_metadata_graph.svg b/docs/source/modules/documentation/tutorials/classification-model/classification_pruned_dram_metadata_graph.svg new file mode 100644 index 000000000..c5e7f9d08 --- /dev/null +++ b/docs/source/modules/documentation/tutorials/classification-model/classification_pruned_dram_metadata_graph.svg @@ -0,0 +1,217 @@ + + + + + + +masegraph + + + +x + +name=%x + +op_code=placeholder + +target=x + +num_users=1 + + + +fc1 + +name=%fc1 + +op_code=call_module +chop.nn.quantized.modules.linear.LinearInteger + +in_features: 2 +out_features: 64 + + + +x->fc1 + + + + + +relu + +name=%relu + +op_code=call_function + +target=torch.relu + +num_users=1 + + + +fc1->relu + + + + + +fc1_weight + +fc1_weight + +op_code=get_parametertorch.float32[64, 2] + + + +fc1_weight->fc1 + + + + + +fc1_bias + +fc1_bias + +op_code=get_parametertorch.float32[64] + + + +fc1_bias->fc1 + + + + + +fc2 + +name=%fc2 + +op_code=call_module +chop.nn.quantized.modules.linear.LinearInteger + +in_features: 64 +out_features: 64 + + + +relu->fc2 + + + + + +relu_1 + +name=%relu_1 + +op_code=call_function + +target=torch.relu + +num_users=1 + + + +fc2->relu_1 + + + + + +fc2_weight + +fc2_weight + +op_code=get_parametertorch.float32[64, 64] + + + +fc2_weight->fc2 + + + + + +fc2_bias + +fc2_bias + 
+op_code=get_parametertorch.float32[64] + + + +fc2_bias->fc2 + + + + + +fc3 + +name=%fc3 + +op_code=call_module +chop.nn.quantized.modules.linear.LinearInteger + +in_features: 64 +out_features: 16 + + + +relu_1->fc3 + + + + + +output + +name=%output + +op_code=output + +target=output + +num_users=0 + + + +fc3->output + + + + + +fc3_weight + +fc3_weight + +op_code=get_parametertorch.float32[16, 64] + + + +fc3_weight->fc3 + + + + + +fc3_bias + +fc3_bias + +op_code=get_parametertorch.float32[16] + + + +fc3_bias->fc3 + + + + + diff --git a/docs/source/modules/documentation/tutorials/classification-model/classification_pruned_graph.svg b/docs/source/modules/documentation/tutorials/classification-model/classification_pruned_graph.svg new file mode 100644 index 000000000..70c31ce6b --- /dev/null +++ b/docs/source/modules/documentation/tutorials/classification-model/classification_pruned_graph.svg @@ -0,0 +1,272 @@ + + + + + + +masegraph + + + +x + +name=%x + +op_code=placeholder + +target=x + +num_users=1 + + + +net_0 + +name=%net_0 + +op_code=call_module +torch.nn.utils.parametrize.ParametrizedLinearInteger + +in_features: 2 +out_features: 64 + + + +x->net_0 + + + + + +net_1 + +name=%net_1 + +op_code=call_module +torch.nn.modules.activation.ReLU + +inplace: False + + + +net_0->net_1 + + + + + +net_0_bias + +net_0_bias + +op_code=get_parametertorch.float32[64] + + + +net_0_bias->net_0 + + + + + +net_0.parametrizations_weight.original + +net_0.parametrizations_weight.original + +op_code=get_parametertorch.float32[64, 2] + + + +net_0.parametrizations_weight.original->net_0 + + + + + +net_0.parametrizations_weight.0.mask + +buffer +torch.bool[64, 2] + + + +net_0.parametrizations_weight.0.mask->net_0 + + + + + +net_2 + +name=%net_2 + +op_code=call_module +torch.nn.modules.dropout.Dropout + +p: 0.1 +inplace: False + + + +net_1->net_2 + + + + + +net_3 + +name=%net_3 + +op_code=call_module +torch.nn.utils.parametrize.ParametrizedLinearInteger + +in_features: 64 +out_features: 64 
+ + + +net_2->net_3 + + + + + +net_4 + +name=%net_4 + +op_code=call_module +torch.nn.modules.activation.ReLU + +inplace: False + + + +net_3->net_4 + + + + + +net_3_bias + +net_3_bias + +op_code=get_parametertorch.float32[64] + + + +net_3_bias->net_3 + + + + + +net_3.parametrizations_weight.original + +net_3.parametrizations_weight.original + +op_code=get_parametertorch.float32[64, 64] + + + +net_3.parametrizations_weight.original->net_3 + + + + + +net_3.parametrizations_weight.0.mask + +buffer +torch.bool[64, 64] + + + +net_3.parametrizations_weight.0.mask->net_3 + + + + + +net_5 + +name=%net_5 + +op_code=call_module +torch.nn.utils.parametrize.ParametrizedLinearInteger + +in_features: 64 +out_features: 16 + + + +net_4->net_5 + + + + + +output + +name=%output + +op_code=output + +target=output + +num_users=0 + + + +net_5->output + + + + + +net_5_bias + +net_5_bias + +op_code=get_parametertorch.float32[16] + + + +net_5_bias->net_5 + + + + + +net_5.parametrizations_weight.original + +net_5.parametrizations_weight.original + +op_code=get_parametertorch.float32[16, 64] + + + +net_5.parametrizations_weight.original->net_5 + + + + + +net_5.parametrizations_weight.0.mask + +buffer +torch.bool[16, 64] + + + +net_5.parametrizations_weight.0.mask->net_5 + + + + + diff --git a/docs/source/modules/documentation/tutorials/tutorial_1_introduction_to_mase.ipynb b/docs/source/modules/documentation/tutorials/tutorial_1_introduction_to_mase.ipynb index 5899ee801..8e3e5fcce 100644 --- a/docs/source/modules/documentation/tutorials/tutorial_1_introduction_to_mase.ipynb +++ b/docs/source/modules/documentation/tutorials/tutorial_1_introduction_to_mase.ipynb @@ -18,9 +18,73 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: 
6010b400-a9a4-43de-834f-0653cb964024)')' thrown while requesting HEAD https://huggingface.co/prajjwal1/bert-tiny/resolve/main/config.json\n", + "WARNING:huggingface_hub.utils._http:'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: 6010b400-a9a4-43de-834f-0653cb964024)')' thrown while requesting HEAD https://huggingface.co/prajjwal1/bert-tiny/resolve/main/config.json\n", + "Retrying in 1s [Retry 1/5].\n", + "WARNING:huggingface_hub.utils._http:Retrying in 1s [Retry 1/5].\n", + "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BertForSequenceClassification(\n", + " (bert): BertModel(\n", + " (embeddings): BertEmbeddings(\n", + " (word_embeddings): Embedding(30522, 128, padding_idx=0)\n", + " (position_embeddings): Embedding(512, 128)\n", + " (token_type_embeddings): Embedding(2, 128)\n", + " (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): BertEncoder(\n", + " (layer): ModuleList(\n", + " (0-1): 2 x BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSdpaSelfAttention(\n", + " (query): Linear(in_features=128, out_features=128, bias=True)\n", + " (key): Linear(in_features=128, out_features=128, bias=True)\n", + " (value): Linear(in_features=128, out_features=128, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=128, out_features=128, bias=True)\n", + " (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " 
)\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=128, out_features=512, bias=True)\n", + " (intermediate_act_fn): GELUActivation()\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=512, out_features=128, bias=True)\n", + " (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): BertPooler(\n", + " (dense): Linear(in_features=128, out_features=128, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (classifier): Linear(in_features=128, out_features=2, bias=True)\n", + ")\n" + ] + } + ], "source": [ "from transformers import AutoModelForSequenceClassification\n", "\n", @@ -51,9 +115,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`past_key_values` were not specified as input names, but model.config.use_cache = True. 
Setting model.config.use_cache = False.\n" + ] + } + ], "source": [ "import os\n", "import platform\n", @@ -96,7 +168,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -218,9 +290,852 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 101, 9932, 2089, 2202, 2058, 1996, 2088, 2028, 2154, 102],\n", + " [ 101, 2023, 2003, 2339, 2017, 2323, 4553, 4748, 4877, 102]])\n", + "tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n", + "tensor([[ 101, 9932, 2089, 2202, 2058, 1996, 2088, 2028, 2154, 102],\n", + " [ 101, 2023, 2003, 2339, 2017, 2323, 4553, 4748, 4877, 102]])\n", + "tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", + "tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", + "tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", + "tensor([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]])\n", + "tensor([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 
1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]])\n", + "tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])\n", + "tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " 
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])\n", + "tensor([[[ 0.7973, 0.0109, -8.8405, ..., 1.4170, 0.1046, -0.1551],\n", + " [-1.1766, 1.2879, -1.0986, ..., 0.4749, -0.5899, 0.8746],\n", + " [-2.0560, 0.7748, -0.8909, ..., -0.4034, 0.5352, -1.3657],\n", + " ...,\n", + " [ 0.2317, -0.7896, 0.9634, ..., -0.8037, 0.4834, -0.5868],\n", + " [ 0.0243, -1.0235, -1.2771, ..., -2.2378, 1.8530, 0.1558],\n", + " [-1.3637, 0.7055, -0.2177, ..., 0.3557, -0.3971, -0.3107]],\n", + "\n", + " [[ 0.7973, 0.0109, -8.8405, ..., 1.4170, 0.1046, -0.1551],\n", + " [-2.6940, 0.6198, -0.4564, ..., -1.4367, -1.5705, -3.1260],\n", + " [-1.7524, 0.8535, -0.2155, ..., -0.5222, -1.2430, -1.7199],\n", + " ...,\n", + " [-0.0347, 0.7446, 1.4462, ..., -1.1578, -2.6197, 0.2612],\n", + " [ 2.4334, -0.3068, 0.8250, ..., 0.1475, 0.1790, 2.2907],\n", + " [-1.3637, 0.7055, -0.2177, ..., 0.3557, -0.3971, -0.3107]]],\n", + " grad_fn=)\n", + "tensor([[[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01],\n", + " [-3.5020e-02, -6.1047e-02, -1.0465e-01, ..., -8.1892e-01,\n", + " 1.1978e+00, 2.1808e+00],\n", + " [ 4.3982e-01, -1.9602e+00, -6.8830e-01, ..., -6.3025e-01,\n", + " -1.5967e-01, 1.3284e+00],\n", + " ...,\n", + " [ 1.4783e+00, 1.0907e-01, -1.5222e+00, ..., -3.0983e-01,\n", + " -1.2971e-01, 1.1265e+00],\n", + " [ 1.5890e+00, -1.6859e+00, 7.8703e-01, ..., -1.3174e+00,\n", + " 2.2258e-01, 8.8157e-01],\n", + " [-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]],\n", + "\n", + " [[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01],\n", + " [-4.4781e-01, -7.9224e-01, -2.1741e+00, ..., -5.9181e-01,\n", + 
" 1.4373e+00, 2.4267e+00],\n", + " [-2.5942e-01, 9.7163e-01, -3.2928e+00, ..., -5.9773e-01,\n", + " -3.0482e-01, 1.4038e+00],\n", + " ...,\n", + " [ 5.1575e-02, 3.5218e-01, -3.8926e-01, ..., -1.1508e+00,\n", + " 7.5490e-01, 8.2911e-01],\n", + " [ 1.6107e+00, 6.8170e-02, 9.2537e-01, ..., -1.5233e+00,\n", + " -6.0733e-01, 3.3097e-01],\n", + " [-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]]], grad_fn=)\n", + "tensor([[[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01],\n", + " [-3.5020e-02, -6.1047e-02, -1.0465e-01, ..., -8.1892e-01,\n", + " 1.1978e+00, 2.1808e+00],\n", + " [ 4.3982e-01, -1.9602e+00, -6.8830e-01, ..., -6.3025e-01,\n", + " -1.5967e-01, 1.3284e+00],\n", + " ...,\n", + " [ 1.4783e+00, 1.0907e-01, -1.5222e+00, ..., -3.0983e-01,\n", + " -1.2971e-01, 1.1265e+00],\n", + " [ 1.5890e+00, -1.6859e+00, 7.8703e-01, ..., -1.3174e+00,\n", + " 2.2258e-01, 8.8157e-01],\n", + " [-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]],\n", + "\n", + " [[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01],\n", + " [-4.4781e-01, -7.9224e-01, -2.1741e+00, ..., -5.9181e-01,\n", + " 1.4373e+00, 2.4267e+00],\n", + " [-2.5942e-01, 9.7163e-01, -3.2928e+00, ..., -5.9773e-01,\n", + " -3.0482e-01, 1.4038e+00],\n", + " ...,\n", + " [ 5.1575e-02, 3.5218e-01, -3.8926e-01, ..., -1.1508e+00,\n", + " 7.5490e-01, 8.2911e-01],\n", + " [ 1.6107e+00, 6.8170e-02, 9.2537e-01, ..., -1.5233e+00,\n", + " -6.0733e-01, 3.3097e-01],\n", + " [-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]]], grad_fn=)\n", + "tensor([[[[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., -5.2614e-01,\n", + " -3.5687e-01, -2.6793e-01],\n", + " [ 2.0335e-01, -5.4534e-01, 3.0686e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01]],\n", + "\n", + " [[-3.5020e-02, -6.1047e-02, -1.0465e-01, ..., -2.0138e+00,\n", + " 4.5529e-01, 
-7.8171e-01],\n", + " [ 1.1969e+00, 1.6337e+00, 2.5047e-01, ..., -8.1892e-01,\n", + " 1.1978e+00, 2.1808e+00]],\n", + "\n", + " [[ 4.3982e-01, -1.9602e+00, -6.8830e-01, ..., -2.2501e-01,\n", + " 7.2290e-02, -1.8290e+00],\n", + " [ 8.9952e-01, 1.0029e+00, 7.4536e-04, ..., -6.3025e-01,\n", + " -1.5967e-01, 1.3284e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.4783e+00, 1.0907e-01, -1.5222e+00, ..., 1.1867e+00,\n", + " -1.3561e+00, 6.5158e-01],\n", + " [ 9.5466e-01, 4.5887e-01, 7.8078e-01, ..., -3.0983e-01,\n", + " -1.2971e-01, 1.1265e+00]],\n", + "\n", + " [[ 1.5890e+00, -1.6859e+00, 7.8703e-01, ..., 6.5467e-01,\n", + " -6.8451e-01, 6.5081e-01],\n", + " [ 7.0729e-01, 1.4499e+00, -1.5089e-01, ..., -1.3174e+00,\n", + " 2.2258e-01, 8.8157e-01]],\n", + "\n", + " [[-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., 3.3130e-01,\n", + " -3.2756e-01, -6.3130e-01],\n", + " [ 8.6773e-01, 2.0996e-01, -3.4332e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]]],\n", + "\n", + "\n", + " [[[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., -5.2614e-01,\n", + " -3.5687e-01, -2.6793e-01],\n", + " [ 2.0335e-01, -5.4534e-01, 3.0686e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01]],\n", + "\n", + " [[-4.4781e-01, -7.9224e-01, -2.1741e+00, ..., 1.7508e+00,\n", + " -3.6708e-01, -1.3251e+00],\n", + " [ 7.9208e-01, -1.3537e-01, 2.3756e-01, ..., -5.9181e-01,\n", + " 1.4373e+00, 2.4267e+00]],\n", + "\n", + " [[-2.5942e-01, 9.7163e-01, -3.2928e+00, ..., -9.6646e-01,\n", + " -4.8876e-01, -1.4426e+00],\n", + " [ 1.0250e+00, -6.9093e-01, -1.2734e+00, ..., -5.9773e-01,\n", + " -3.0482e-01, 1.4038e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[ 5.1575e-02, 3.5218e-01, -3.8926e-01, ..., -1.2252e-02,\n", + " 1.0394e+00, 4.2402e-01],\n", + " [-4.7386e-01, 2.6401e+00, 1.7024e+00, ..., -1.1508e+00,\n", + " 7.5490e-01, 8.2911e-01]],\n", + "\n", + " [[ 1.6107e+00, 6.8170e-02, 9.2537e-01, ..., -6.1665e-01,\n", + " 2.7627e-01, -1.2083e+00],\n", + " [ 9.3395e-01, -9.7541e-01, -2.5442e-02, ..., -1.5233e+00,\n", + 
" -6.0733e-01, 3.3097e-01]],\n", + "\n", + " [[-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., 3.3130e-01,\n", + " -3.2756e-01, -6.3130e-01],\n", + " [ 8.6773e-01, 2.0996e-01, -3.4332e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]]]], grad_fn=)\n", + "tensor([[[-0.1709, 0.5230, -0.8713, ..., -1.3382, 0.5892, 0.4026],\n", + " [-0.5842, 0.9588, 1.5642, ..., -1.0731, -0.7330, 0.3132],\n", + " [-0.8601, -1.3756, 0.5042, ..., -0.0476, 0.2650, 1.2150],\n", + " ...,\n", + " [ 0.0520, 1.1719, -1.5471, ..., -0.7894, 0.1419, 1.6964],\n", + " [ 0.7654, -1.5053, -0.4142, ..., -1.4622, -0.8975, 1.4576],\n", + " [-1.2008, -0.6008, -1.4608, ..., -1.2105, -0.4289, 0.3827]],\n", + "\n", + " [[-0.1709, 0.5230, -0.8713, ..., -1.3382, 0.5892, 0.4026],\n", + " [-1.3806, 0.2626, -0.5207, ..., -1.6714, -0.0554, 1.0225],\n", + " [-1.7116, 1.8788, -2.5695, ..., -0.6958, 0.5728, 0.5461],\n", + " ...,\n", + " [-1.3246, 1.2196, -0.3034, ..., -1.1955, -0.6708, 0.5128],\n", + " [ 0.9854, 0.8260, 0.2892, ..., -0.6428, 0.3637, 0.4339],\n", + " [-1.2008, -0.6008, -1.4608, ..., -1.2105, -0.4289, 0.3827]]],\n", + " grad_fn=)\n", + "tensor([[[-0.1709, 0.5230, -0.8713, ..., -1.3382, 0.5892, 0.4026],\n", + " [-0.5842, 0.9588, 1.5642, ..., -1.0731, -0.7330, 0.3132],\n", + " [-0.8601, -1.3756, 0.5042, ..., -0.0476, 0.2650, 1.2150],\n", + " ...,\n", + " [ 0.0520, 1.1719, -1.5471, ..., -0.7894, 0.1419, 1.6964],\n", + " [ 0.7654, -1.5053, -0.4142, ..., -1.4622, -0.8975, 1.4576],\n", + " [-1.2008, -0.6008, -1.4608, ..., -1.2105, -0.4289, 0.3827]],\n", + "\n", + " [[-0.1709, 0.5230, -0.8713, ..., -1.3382, 0.5892, 0.4026],\n", + " [-1.3806, 0.2626, -0.5207, ..., -1.6714, -0.0554, 1.0225],\n", + " [-1.7116, 1.8788, -2.5695, ..., -0.6958, 0.5728, 0.5461],\n", + " ...,\n", + " [-1.3246, 1.2196, -0.3034, ..., -1.1955, -0.6708, 0.5128],\n", + " [ 0.9854, 0.8260, 0.2892, ..., -0.6428, 0.3637, 0.4339],\n", + " [-1.2008, -0.6008, -1.4608, ..., -1.2105, -0.4289, 0.3827]]],\n", + " grad_fn=)\n", + 
"tensor([[[[-0.1709, 0.5230, -0.8713, ..., 0.4365, 0.6238, -0.9414],\n", + " [-1.3731, 1.1521, 0.1321, ..., -1.3382, 0.5892, 0.4026]],\n", + "\n", + " [[-0.5842, 0.9588, 1.5642, ..., -1.5431, 0.4999, -1.1350],\n", + " [ 0.9615, 0.8694, 0.0998, ..., -1.0731, -0.7330, 0.3132]],\n", + "\n", + " [[-0.8601, -1.3756, 0.5042, ..., 0.9764, -0.8321, -1.0204],\n", + " [ 1.5175, 1.1454, 0.7791, ..., -0.0476, 0.2650, 1.2150]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0520, 1.1719, -1.5471, ..., 1.9402, -1.1294, 0.4793],\n", + " [ 1.0053, 0.8099, 1.6415, ..., -0.7894, 0.1419, 1.6964]],\n", + "\n", + " [[ 0.7654, -1.5053, -0.4142, ..., 1.7455, -0.7326, 1.5248],\n", + " [ 1.0806, 1.1457, 2.2163, ..., -1.4622, -0.8975, 1.4576]],\n", + "\n", + " [[-1.2008, -0.6008, -1.4608, ..., 2.0905, 1.8849, -1.5708],\n", + " [ 1.9999, 0.3493, -0.8524, ..., -1.2105, -0.4289, 0.3827]]],\n", + "\n", + "\n", + " [[[-0.1709, 0.5230, -0.8713, ..., 0.4365, 0.6238, -0.9414],\n", + " [-1.3731, 1.1521, 0.1321, ..., -1.3382, 0.5892, 0.4026]],\n", + "\n", + " [[-1.3806, 0.2626, -0.5207, ..., 1.6517, -0.2316, -1.3171],\n", + " [ 0.6812, -0.0090, 0.3803, ..., -1.6714, -0.0554, 1.0225]],\n", + "\n", + " [[-1.7116, 1.8788, -2.5695, ..., 0.4927, -0.4850, -1.0645],\n", + " [ 1.2646, 1.6481, 0.9055, ..., -0.6958, 0.5728, 0.5461]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.3246, 1.2196, -0.3034, ..., 1.2747, 1.2353, 0.2825],\n", + " [ 1.5373, 0.8648, 0.6062, ..., -1.1955, -0.6708, 0.5128]],\n", + "\n", + " [[ 0.9854, 0.8260, 0.2892, ..., 1.3848, -0.0103, -1.0700],\n", + " [ 1.3827, 2.9809, 0.0276, ..., -0.6428, 0.3637, 0.4339]],\n", + "\n", + " [[-1.2008, -0.6008, -1.4608, ..., 2.0905, 1.8849, -1.5708],\n", + " [ 1.9999, 0.3493, -0.8524, ..., -1.2105, -0.4289, 0.3827]]]],\n", + " grad_fn=)\n", + "tensor([[[-0.0123, 0.5761, 0.2209, ..., -0.1027, 1.1061, -2.5200],\n", + " [-1.1465, -1.5578, -0.6984, ..., 1.0310, 0.4824, -0.2291],\n", + " [-1.0361, -1.8192, -2.3055, ..., 1.5286, -1.5941, 1.1762],\n", + " ...,\n", + 
" [-0.7992, 0.0886, 0.4887, ..., -1.7941, 0.4835, 1.3780],\n", + " [-1.4692, -0.9135, -0.2802, ..., -0.9691, 0.3500, 1.8863],\n", + " [-0.5760, -0.0452, 0.4230, ..., -0.7179, -0.7858, 1.6879]],\n", + "\n", + " [[-0.0123, 0.5761, 0.2209, ..., -0.1027, 1.1061, -2.5200],\n", + " [-0.3700, -1.9754, -0.7315, ..., 0.2293, 0.6996, 3.1299],\n", + " [-0.6252, 0.2879, -1.4036, ..., -2.0560, -2.4623, -0.9584],\n", + " ...,\n", + " [-1.1306, -1.4343, -1.4422, ..., -1.6115, -0.0475, 1.3975],\n", + " [-0.9816, -1.4909, -1.0086, ..., -0.9284, 0.5260, 1.5330],\n", + " [-0.5760, -0.0452, 0.4230, ..., -0.7179, -0.7858, 1.6879]]],\n", + " grad_fn=)\n", + "tensor([[[-0.0123, 0.5761, 0.2209, ..., -0.1027, 1.1061, -2.5200],\n", + " [-1.1465, -1.5578, -0.6984, ..., 1.0310, 0.4824, -0.2291],\n", + " [-1.0361, -1.8192, -2.3055, ..., 1.5286, -1.5941, 1.1762],\n", + " ...,\n", + " [-0.7992, 0.0886, 0.4887, ..., -1.7941, 0.4835, 1.3780],\n", + " [-1.4692, -0.9135, -0.2802, ..., -0.9691, 0.3500, 1.8863],\n", + " [-0.5760, -0.0452, 0.4230, ..., -0.7179, -0.7858, 1.6879]],\n", + "\n", + " [[-0.0123, 0.5761, 0.2209, ..., -0.1027, 1.1061, -2.5200],\n", + " [-0.3700, -1.9754, -0.7315, ..., 0.2293, 0.6996, 3.1299],\n", + " [-0.6252, 0.2879, -1.4036, ..., -2.0560, -2.4623, -0.9584],\n", + " ...,\n", + " [-1.1306, -1.4343, -1.4422, ..., -1.6115, -0.0475, 1.3975],\n", + " [-0.9816, -1.4909, -1.0086, ..., -0.9284, 0.5260, 1.5330],\n", + " [-0.5760, -0.0452, 0.4230, ..., -0.7179, -0.7858, 1.6879]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.0123, 0.5761, 0.2209, ..., -0.1457, -0.7538, 0.1761],\n", + " [-0.0705, 0.9215, 0.7990, ..., -0.1027, 1.1061, -2.5200]],\n", + "\n", + " [[-1.1465, -1.5578, -0.6984, ..., 0.0289, -2.1112, -0.8728],\n", + " [ 0.6506, -1.6966, 1.4463, ..., 1.0310, 0.4824, -0.2291]],\n", + "\n", + " [[-1.0361, -1.8192, -2.3055, ..., -0.2195, -1.1732, 0.3182],\n", + " [-0.5841, -0.0227, 3.0901, ..., 1.5286, -1.5941, 1.1762]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.7992, 0.0886, 0.4887, 
..., 0.7859, -1.0127, -0.2676],\n", + " [-0.3055, 0.6270, -3.0705, ..., -1.7941, 0.4835, 1.3780]],\n", + "\n", + " [[-1.4692, -0.9135, -0.2802, ..., 0.1197, -0.7532, 0.0731],\n", + " [ 0.6096, -1.0893, -0.6959, ..., -0.9691, 0.3500, 1.8863]],\n", + "\n", + " [[-0.5760, -0.0452, 0.4230, ..., 0.8851, 0.3078, 0.8106],\n", + " [-1.1804, 0.9512, 0.3169, ..., -0.7179, -0.7858, 1.6879]]],\n", + "\n", + "\n", + " [[[-0.0123, 0.5761, 0.2209, ..., -0.1457, -0.7538, 0.1761],\n", + " [-0.0705, 0.9215, 0.7990, ..., -0.1027, 1.1061, -2.5200]],\n", + "\n", + " [[-0.3700, -1.9754, -0.7315, ..., 0.5756, -1.5559, 0.0326],\n", + " [ 1.4229, 2.3970, -0.4516, ..., 0.2293, 0.6996, 3.1299]],\n", + "\n", + " [[-0.6252, 0.2879, -1.4036, ..., 0.5306, -0.5608, 1.1861],\n", + " [-2.5980, 0.2673, 3.3016, ..., -2.0560, -2.4623, -0.9584]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.1306, -1.4343, -1.4422, ..., 0.3918, -1.5336, -0.5026],\n", + " [ 1.8587, 0.8501, -1.2402, ..., -1.6115, -0.0475, 1.3975]],\n", + "\n", + " [[-0.9816, -1.4909, -1.0086, ..., 0.2956, 0.0351, -1.0685],\n", + " [-0.6594, -0.0133, -1.1863, ..., -0.9284, 0.5260, 1.5330]],\n", + "\n", + " [[-0.5760, -0.0452, 0.4230, ..., 0.8851, 0.3078, 0.8106],\n", + " [-1.1804, 0.9512, 0.3169, ..., -0.7179, -0.7858, 1.6879]]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.5911, -0.4682, -0.4314, ..., 0.0366, -1.0405, -0.0579],\n", + " [-1.1141, -1.4785, -0.6477, ..., 0.0631, -1.9950, -0.7912],\n", + " [-0.7059, -0.9954, -1.3923, ..., -0.1557, -0.9998, 0.2774],\n", + " ...,\n", + " [-0.6300, 0.1252, 0.3486, ..., 0.5692, -0.9417, -0.1399],\n", + " [-1.0593, -0.6395, -0.3343, ..., 0.0571, -0.7842, 0.1042],\n", + " [-0.2081, 0.2785, 0.0613, ..., -0.0353, -0.8389, -0.0026]],\n", + "\n", + " [[-0.3266, 0.6360, 0.0214, ..., 0.1677, 0.3883, 0.4382],\n", + " [-0.6299, 0.6012, 0.7379, ..., 0.2989, -0.1569, 0.9508],\n", + " [-0.4097, 0.6374, -0.1589, ..., 0.0258, -0.0364, 0.9990],\n", + " ...,\n", + " [-0.0088, 0.1378, -0.4819, ..., -0.1261, 0.2908, 
1.0980],\n", + " [-0.5584, 0.9932, -0.1105, ..., 0.2486, 0.4005, 0.6046],\n", + " [-0.4516, 0.8092, -0.0513, ..., 0.1182, 0.4294, 0.4781]]],\n", + "\n", + "\n", + " [[[-0.6850, -0.2168, -0.6087, ..., 0.1402, -0.7267, -0.1502],\n", + " [-0.2084, -0.1325, -0.1151, ..., 0.0938, -0.8963, 0.1296],\n", + " [-0.3665, 0.3408, -0.6133, ..., 0.1938, -0.6681, 0.5929],\n", + " ...,\n", + " [-1.0985, -1.2861, -1.2531, ..., 0.3609, -1.4022, -0.4444],\n", + " [-0.9787, -1.4207, -0.9908, ..., 0.2989, -0.0159, -1.0060],\n", + " [-0.1970, 0.2371, -0.0115, ..., 0.0049, -0.8135, 0.1201]],\n", + "\n", + " [[ 0.0907, 0.7927, -0.0524, ..., -0.4610, -0.4295, 0.4391],\n", + " [-0.2352, 1.0586, -0.1117, ..., -0.4943, -0.6546, 0.6617],\n", + " [ 0.8269, 1.7861, -1.1207, ..., -0.0885, 0.2791, 1.4721],\n", + " ...,\n", + " [-0.1038, 0.9623, -0.7062, ..., -0.3340, -0.3050, 0.8792],\n", + " [ 0.0507, 0.6634, -0.2642, ..., -0.4959, -0.7944, 0.7687],\n", + " [ 0.3364, 0.9047, -0.2037, ..., -0.3705, -0.2893, 0.6787]]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.5911, -0.4682, -0.4314, ..., 0.0366, -1.0405, -0.0579],\n", + " [-0.3266, 0.6360, 0.0214, ..., 0.1677, 0.3883, 0.4382]],\n", + "\n", + " [[-1.1141, -1.4785, -0.6477, ..., 0.0631, -1.9950, -0.7912],\n", + " [-0.6299, 0.6012, 0.7379, ..., 0.2989, -0.1569, 0.9508]],\n", + "\n", + " [[-0.7059, -0.9954, -1.3923, ..., -0.1557, -0.9998, 0.2774],\n", + " [-0.4097, 0.6374, -0.1589, ..., 0.0258, -0.0364, 0.9990]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.6300, 0.1252, 0.3486, ..., 0.5692, -0.9417, -0.1399],\n", + " [-0.0088, 0.1378, -0.4819, ..., -0.1261, 0.2908, 1.0980]],\n", + "\n", + " [[-1.0593, -0.6395, -0.3343, ..., 0.0571, -0.7842, 0.1042],\n", + " [-0.5584, 0.9932, -0.1105, ..., 0.2486, 0.4005, 0.6046]],\n", + "\n", + " [[-0.2081, 0.2785, 0.0613, ..., -0.0353, -0.8389, -0.0026],\n", + " [-0.4516, 0.8092, -0.0513, ..., 0.1182, 0.4294, 0.4781]]],\n", + "\n", + "\n", + " [[[-0.6850, -0.2168, -0.6087, ..., 0.1402, -0.7267, -0.1502],\n", + " [ 
0.0907, 0.7927, -0.0524, ..., -0.4610, -0.4295, 0.4391]],\n", + "\n", + " [[-0.2084, -0.1325, -0.1151, ..., 0.0938, -0.8963, 0.1296],\n", + " [-0.2352, 1.0586, -0.1117, ..., -0.4943, -0.6546, 0.6617]],\n", + "\n", + " [[-0.3665, 0.3408, -0.6133, ..., 0.1938, -0.6681, 0.5929],\n", + " [ 0.8269, 1.7861, -1.1207, ..., -0.0885, 0.2791, 1.4721]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.0985, -1.2861, -1.2531, ..., 0.3609, -1.4022, -0.4444],\n", + " [-0.1038, 0.9623, -0.7062, ..., -0.3340, -0.3050, 0.8792]],\n", + "\n", + " [[-0.9787, -1.4207, -0.9908, ..., 0.2989, -0.0159, -1.0060],\n", + " [ 0.0507, 0.6634, -0.2642, ..., -0.4959, -0.7944, 0.7687]],\n", + "\n", + " [[-0.1970, 0.2371, -0.0115, ..., 0.0049, -0.8135, 0.1201],\n", + " [ 0.3364, 0.9047, -0.2037, ..., -0.3705, -0.2893, 0.6787]]]],\n", + " grad_fn=)\n", + "tensor([[[-0.9552, 0.6594, -6.5403, ..., -0.7144, 0.0906, 0.3369],\n", + " [-2.5251, 1.3955, -0.8914, ..., -2.1363, 0.0271, 1.1132],\n", + " [-3.7148, 0.6796, -0.8710, ..., -2.6492, 0.5694, -0.1085],\n", + " ...,\n", + " [-2.2403, -0.7594, 0.5414, ..., -3.0426, 0.8895, -0.0546],\n", + " [-1.6945, -0.6326, -0.8632, ..., -4.0678, 1.7219, 0.6481],\n", + " [-2.9625, 0.7451, -0.8037, ..., -2.5048, 0.3125, 0.5537]],\n", + "\n", + " [[-0.5150, 0.8150, -6.5015, ..., -0.5377, -0.4171, 0.1350],\n", + " [-2.9979, 1.0930, -0.2619, ..., -3.1811, -1.0048, -1.8349],\n", + " [-2.8788, 0.5405, -0.0789, ..., -2.3969, -0.7016, -0.7332],\n", + " ...,\n", + " [-1.7194, 1.5158, 1.0070, ..., -2.8931, -2.3309, 1.1685],\n", + " [ 0.0717, -0.1039, 0.5084, ..., -2.1932, 0.0751, 2.8236],\n", + " [-2.4774, 0.7563, -0.7502, ..., -2.1312, -0.0685, 0.8700]]],\n", + " grad_fn=)\n", + "tensor([[[-0.0455, 0.6529, 0.6297, ..., 0.4139, -0.9381, 0.6769],\n", + " [-3.1104, -3.7282, -2.3953, ..., -0.9155, -0.8280, -1.7070],\n", + " [-0.9324, -2.9333, -2.3249, ..., -0.8455, -0.0326, -0.6998],\n", + " ...,\n", + " [-1.9000, -1.1028, -1.1281, ..., -0.2809, 2.0206, -1.0802],\n", + " [-1.1088, 
-1.0420, -2.4026, ..., -0.4478, 0.7391, -0.0354],\n", + " [ 0.7349, 0.6742, -2.6697, ..., -0.5114, 1.5155, 2.0246]],\n", + "\n", + " [[-0.0799, 0.7813, 0.4918, ..., 0.6888, -0.7680, 0.9805],\n", + " [-2.8720, -1.0602, -2.3610, ..., -2.1143, 0.9664, -1.1212],\n", + " [-1.4705, -2.1384, -1.9955, ..., -0.9722, 1.5909, -0.1668],\n", + " ...,\n", + " [-2.9884, -1.1566, -2.5215, ..., 1.1460, 0.7120, -0.6320],\n", + " [-3.3666, -0.7966, -3.3154, ..., 0.5316, 1.7058, 2.1950],\n", + " [ 0.6626, 0.8537, -2.7251, ..., -0.0901, 1.5883, 2.3840]]],\n", + " grad_fn=)\n", + "tensor([[[-0.0455, 0.6529, 0.6297, ..., 0.4139, -0.9381, 0.6769],\n", + " [-3.1104, -3.7282, -2.3953, ..., -0.9155, -0.8280, -1.7070],\n", + " [-0.9324, -2.9333, -2.3249, ..., -0.8455, -0.0326, -0.6998],\n", + " ...,\n", + " [-1.9000, -1.1028, -1.1281, ..., -0.2809, 2.0206, -1.0802],\n", + " [-1.1088, -1.0420, -2.4026, ..., -0.4478, 0.7391, -0.0354],\n", + " [ 0.7349, 0.6742, -2.6697, ..., -0.5114, 1.5155, 2.0246]],\n", + "\n", + " [[-0.0799, 0.7813, 0.4918, ..., 0.6888, -0.7680, 0.9805],\n", + " [-2.8720, -1.0602, -2.3610, ..., -2.1143, 0.9664, -1.1212],\n", + " [-1.4705, -2.1384, -1.9955, ..., -0.9722, 1.5909, -0.1668],\n", + " ...,\n", + " [-2.9884, -1.1566, -2.5215, ..., 1.1460, 0.7120, -0.6320],\n", + " [-3.3666, -0.7966, -3.3154, ..., 0.5316, 1.7058, 2.1950],\n", + " [ 0.6626, 0.8537, -2.7251, ..., -0.0901, 1.5883, 2.3840]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.0455, 0.6529, 0.6297, ..., 1.0165, -1.6055, 0.0557],\n", + " [-0.9141, -1.5101, -0.8415, ..., 0.4139, -0.9381, 0.6769]],\n", + "\n", + " [[-3.1104, -3.7282, -2.3953, ..., -1.6195, 0.7426, -3.2794],\n", + " [-1.2694, 0.3821, -0.5687, ..., -0.9155, -0.8280, -1.7070]],\n", + "\n", + " [[-0.9324, -2.9333, -2.3249, ..., -1.0254, 1.8158, -1.8835],\n", + " [-1.5265, -0.3901, 0.2734, ..., -0.8455, -0.0326, -0.6998]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.9000, -1.1028, -1.1281, ..., -1.2688, -0.0851, -2.3190],\n", + " [-2.4374, 0.0718, -2.7276, 
..., -0.2809, 2.0206, -1.0802]],\n", + "\n", + " [[-1.1088, -1.0420, -2.4026, ..., -1.0658, 0.1932, -1.7012],\n", + " [-2.3622, -0.5291, -1.9931, ..., -0.4478, 0.7391, -0.0354]],\n", + "\n", + " [[ 0.7349, 0.6742, -2.6697, ..., -1.4630, -0.1686, -2.5682],\n", + " [-0.1401, -0.9712, -2.3801, ..., -0.5114, 1.5155, 2.0246]]],\n", + "\n", + "\n", + " [[[-0.0799, 0.7813, 0.4918, ..., 1.2364, -1.9500, -0.1275],\n", + " [-0.4080, -1.5069, -0.8504, ..., 0.6888, -0.7680, 0.9805]],\n", + "\n", + " [[-2.8720, -1.0602, -2.3610, ..., -2.3225, -0.0351, -2.7432],\n", + " [-0.2305, -0.5940, -1.1570, ..., -2.1143, 0.9664, -1.1212]],\n", + "\n", + " [[-1.4705, -2.1384, -1.9955, ..., -0.6524, -1.8025, -1.8321],\n", + " [-1.7742, -0.6800, -0.2172, ..., -0.9722, 1.5909, -0.1668]],\n", + "\n", + " ...,\n", + "\n", + " [[-2.9884, -1.1566, -2.5215, ..., -0.5054, -1.0314, -3.4883],\n", + " [-1.9535, 0.5573, -2.1564, ..., 1.1460, 0.7120, -0.6320]],\n", + "\n", + " [[-3.3666, -0.7966, -3.3154, ..., 0.7587, -0.6289, -3.4848],\n", + " [-1.4099, -2.0919, -1.5870, ..., 0.5316, 1.7058, 2.1950]],\n", + "\n", + " [[ 0.6626, 0.8537, -2.7251, ..., -1.1831, -0.7083, -2.7717],\n", + " [ 0.4486, -1.1639, -2.1203, ..., -0.0901, 1.5883, 2.3840]]]],\n", + " grad_fn=)\n", + "tensor([[[-0.8947, 0.0412, -1.2359, ..., 0.4410, -0.3965, 0.0106],\n", + " [-1.9196, 0.3326, 0.8482, ..., -1.5790, -1.1817, -1.0156],\n", + " [-2.1664, 0.3959, 0.7476, ..., -2.1767, -0.6488, 0.1889],\n", + " ...,\n", + " [-1.6009, 0.4887, -0.4818, ..., -1.1268, 0.4111, 0.7892],\n", + " [-0.1528, 1.1728, -0.5164, ..., -0.4340, 0.1499, 1.6704],\n", + " [ 1.0253, 1.4222, -0.1805, ..., -0.6130, -0.5380, 1.6164]],\n", + "\n", + " [[-0.6736, -0.0718, -1.1724, ..., 0.2001, -0.5481, 0.0232],\n", + " [-1.8698, -1.2184, 0.2913, ..., -1.1398, -1.3523, -0.7851],\n", + " [-1.3725, -0.8212, 0.1984, ..., -1.8218, -1.4800, -0.2956],\n", + " ...,\n", + " [-0.5946, 0.5680, 0.8938, ..., -1.6653, 0.8218, 1.1902],\n", + " [ 1.2800, 1.9566, 0.2540, ..., 
-1.2290, 0.5257, 1.2667],\n", + " [ 1.3511, 1.3329, -0.0782, ..., -0.8454, -0.7400, 1.5966]]],\n", + " grad_fn=)\n", + "tensor([[[-0.8947, 0.0412, -1.2359, ..., 0.4410, -0.3965, 0.0106],\n", + " [-1.9196, 0.3326, 0.8482, ..., -1.5790, -1.1817, -1.0156],\n", + " [-2.1664, 0.3959, 0.7476, ..., -2.1767, -0.6488, 0.1889],\n", + " ...,\n", + " [-1.6009, 0.4887, -0.4818, ..., -1.1268, 0.4111, 0.7892],\n", + " [-0.1528, 1.1728, -0.5164, ..., -0.4340, 0.1499, 1.6704],\n", + " [ 1.0253, 1.4222, -0.1805, ..., -0.6130, -0.5380, 1.6164]],\n", + "\n", + " [[-0.6736, -0.0718, -1.1724, ..., 0.2001, -0.5481, 0.0232],\n", + " [-1.8698, -1.2184, 0.2913, ..., -1.1398, -1.3523, -0.7851],\n", + " [-1.3725, -0.8212, 0.1984, ..., -1.8218, -1.4800, -0.2956],\n", + " ...,\n", + " [-0.5946, 0.5680, 0.8938, ..., -1.6653, 0.8218, 1.1902],\n", + " [ 1.2800, 1.9566, 0.2540, ..., -1.2290, 0.5257, 1.2667],\n", + " [ 1.3511, 1.3329, -0.0782, ..., -0.8454, -0.7400, 1.5966]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.8947, 0.0412, -1.2359, ..., 1.5140, -1.9812, -2.5532],\n", + " [-0.2951, -1.6086, -0.6381, ..., 0.4410, -0.3965, 0.0106]],\n", + "\n", + " [[-1.9196, 0.3326, 0.8482, ..., -2.3348, 1.3935, 1.1452],\n", + " [-0.5277, 0.1234, 0.7865, ..., -1.5790, -1.1817, -1.0156]],\n", + "\n", + " [[-2.1664, 0.3959, 0.7476, ..., -2.0817, 0.2852, 0.8173],\n", + " [-0.8414, 0.5154, -0.4553, ..., -2.1767, -0.6488, 0.1889]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.6009, 0.4887, -0.4818, ..., -1.8165, 1.4764, 0.5091],\n", + " [-0.6869, 0.4007, -1.5818, ..., -1.1268, 0.4111, 0.7892]],\n", + "\n", + " [[-0.1528, 1.1728, -0.5164, ..., -1.3611, 1.0621, 1.1810],\n", + " [-0.7595, -0.1699, -1.5305, ..., -0.4340, 0.1499, 1.6704]],\n", + "\n", + " [[ 1.0253, 1.4222, -0.1805, ..., -0.6989, 0.4721, 2.6129],\n", + " [-1.2381, -0.4573, -1.7561, ..., -0.6130, -0.5380, 1.6164]]],\n", + "\n", + "\n", + " [[[-0.6736, -0.0718, -1.1724, ..., 1.4816, -1.7920, -2.5177],\n", + " [-0.3929, -1.5120, -0.5353, ..., 0.2001, -0.5481, 
0.0232]],\n", + "\n", + " [[-1.8698, -1.2184, 0.2913, ..., -1.5227, 1.9764, 0.6389],\n", + " [-0.4202, 0.4572, -1.0780, ..., -1.1398, -1.3523, -0.7851]],\n", + "\n", + " [[-1.3725, -0.8212, 0.1984, ..., -2.1553, 1.7041, 0.7166],\n", + " [-1.0124, 0.9351, -0.0954, ..., -1.8218, -1.4800, -0.2956]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.5946, 0.5680, 0.8938, ..., -2.1904, 1.7986, 1.0902],\n", + " [-1.3820, 1.0268, -1.0041, ..., -1.6653, 0.8218, 1.1902]],\n", + "\n", + " [[ 1.2800, 1.9566, 0.2540, ..., -1.6180, 1.6176, 2.5636],\n", + " [-2.0592, 0.7059, -1.3359, ..., -1.2290, 0.5257, 1.2667]],\n", + "\n", + " [[ 1.3511, 1.3329, -0.0782, ..., -0.5836, 0.6491, 2.6554],\n", + " [-1.3686, -0.2348, -1.7438, ..., -0.8454, -0.7400, 1.5966]]]],\n", + " grad_fn=)\n", + "tensor([[[ 3.8307e-01, 1.8837e-02, -1.0314e+00, ..., 6.4335e-01,\n", + " -3.1830e-01, -1.7296e+00],\n", + " [ 1.5897e+00, 1.3689e-01, -6.8915e-01, ..., 1.5973e+00,\n", + " 1.1907e+00, -9.0454e-01],\n", + " [-5.8138e-01, 6.7943e-01, -1.3203e+00, ..., -5.3627e-01,\n", + " -1.0456e+00, -1.6301e+00],\n", + " ...,\n", + " [-1.7466e-01, 3.0707e-02, 7.5225e-01, ..., -1.3217e+00,\n", + " -1.3415e+00, -3.8328e-01],\n", + " [ 1.5170e-01, 5.1089e-01, 1.3993e-01, ..., -1.6600e-01,\n", + " -6.5011e-01, 2.1798e-02],\n", + " [-2.4311e-01, 1.6726e+00, 1.6682e-01, ..., 1.3441e-03,\n", + " -1.6754e+00, 3.1771e-01]],\n", + "\n", + " [[ 4.7844e-01, -3.1772e-01, -1.0617e+00, ..., 6.4928e-01,\n", + " -3.2944e-01, -2.4185e+00],\n", + " [ 5.9122e-01, -2.2648e-01, 1.7474e-01, ..., -1.8623e+00,\n", + " -1.1230e+00, 3.5013e-01],\n", + " [ 1.1642e-01, 1.2460e+00, 7.7941e-02, ..., 6.4975e-01,\n", + " -7.3862e-01, -2.1510e+00],\n", + " ...,\n", + " [ 1.1291e+00, -1.3637e+00, -1.5779e+00, ..., 1.7637e+00,\n", + " 9.1331e-01, -1.7033e+00],\n", + " [ 1.5909e+00, -1.4922e+00, 1.0060e+00, ..., 9.8096e-01,\n", + " 8.6736e-01, -2.3894e+00],\n", + " [ 1.8451e-01, 1.2740e+00, 4.2857e-01, ..., 6.2708e-01,\n", + " -1.3601e+00, -3.9984e-01]]], 
grad_fn=)\n", + "tensor([[[ 3.8307e-01, 1.8837e-02, -1.0314e+00, ..., 6.4335e-01,\n", + " -3.1830e-01, -1.7296e+00],\n", + " [ 1.5897e+00, 1.3689e-01, -6.8915e-01, ..., 1.5973e+00,\n", + " 1.1907e+00, -9.0454e-01],\n", + " [-5.8138e-01, 6.7943e-01, -1.3203e+00, ..., -5.3627e-01,\n", + " -1.0456e+00, -1.6301e+00],\n", + " ...,\n", + " [-1.7466e-01, 3.0707e-02, 7.5225e-01, ..., -1.3217e+00,\n", + " -1.3415e+00, -3.8328e-01],\n", + " [ 1.5170e-01, 5.1089e-01, 1.3993e-01, ..., -1.6600e-01,\n", + " -6.5011e-01, 2.1798e-02],\n", + " [-2.4311e-01, 1.6726e+00, 1.6682e-01, ..., 1.3441e-03,\n", + " -1.6754e+00, 3.1771e-01]],\n", + "\n", + " [[ 4.7844e-01, -3.1772e-01, -1.0617e+00, ..., 6.4928e-01,\n", + " -3.2944e-01, -2.4185e+00],\n", + " [ 5.9122e-01, -2.2648e-01, 1.7474e-01, ..., -1.8623e+00,\n", + " -1.1230e+00, 3.5013e-01],\n", + " [ 1.1642e-01, 1.2460e+00, 7.7941e-02, ..., 6.4975e-01,\n", + " -7.3862e-01, -2.1510e+00],\n", + " ...,\n", + " [ 1.1291e+00, -1.3637e+00, -1.5779e+00, ..., 1.7637e+00,\n", + " 9.1331e-01, -1.7033e+00],\n", + " [ 1.5909e+00, -1.4922e+00, 1.0060e+00, ..., 9.8096e-01,\n", + " 8.6736e-01, -2.3894e+00],\n", + " [ 1.8451e-01, 1.2740e+00, 4.2857e-01, ..., 6.2708e-01,\n", + " -1.3601e+00, -3.9984e-01]]], grad_fn=)\n", + "tensor([[[[ 3.8307e-01, 1.8837e-02, -1.0314e+00, ..., -8.0955e-01,\n", + " -2.8557e-01, -2.3318e-01],\n", + " [-9.9661e-02, 3.1722e-01, -3.0517e-01, ..., 6.4335e-01,\n", + " -3.1830e-01, -1.7296e+00]],\n", + "\n", + " [[ 1.5897e+00, 1.3689e-01, -6.8915e-01, ..., -7.9549e-01,\n", + " -6.9279e-01, -1.8082e-01],\n", + " [-9.9201e-01, 9.4938e-01, 4.4198e-02, ..., 1.5973e+00,\n", + " 1.1907e+00, -9.0454e-01]],\n", + "\n", + " [[-5.8138e-01, 6.7943e-01, -1.3203e+00, ..., -1.8905e+00,\n", + " 1.6226e-01, -1.2953e+00],\n", + " [-7.0312e-01, -6.4926e-01, -5.0913e-01, ..., -5.3627e-01,\n", + " -1.0456e+00, -1.6301e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.7466e-01, 3.0707e-02, 7.5225e-01, ..., -1.9281e+00,\n", + " 1.1489e+00, 
-2.4530e-01],\n", + " [-7.6225e-02, 8.5814e-01, -1.5467e+00, ..., -1.3217e+00,\n", + " -1.3415e+00, -3.8328e-01]],\n", + "\n", + " [[ 1.5170e-01, 5.1089e-01, 1.3993e-01, ..., -2.4168e+00,\n", + " 3.3385e-01, -6.2115e-02],\n", + " [-1.6390e+00, -1.6085e-01, -1.9118e+00, ..., -1.6600e-01,\n", + " -6.5011e-01, 2.1798e-02]],\n", + "\n", + " [[-2.4311e-01, 1.6726e+00, 1.6682e-01, ..., -1.0481e+00,\n", + " -2.7634e+00, 2.2741e-01],\n", + " [-1.4603e+00, 3.1239e-02, 3.8892e-01, ..., 1.3441e-03,\n", + " -1.6754e+00, 3.1771e-01]]],\n", + "\n", + "\n", + " [[[ 4.7844e-01, -3.1772e-01, -1.0617e+00, ..., -7.8511e-01,\n", + " -3.2510e-01, -2.1300e-01],\n", + " [-3.9893e-01, 6.1469e-01, -3.9206e-01, ..., 6.4928e-01,\n", + " -3.2944e-01, -2.4185e+00]],\n", + "\n", + " [[ 5.9122e-01, -2.2648e-01, 1.7474e-01, ..., -1.6488e+00,\n", + " -8.6854e-01, -8.0783e-01],\n", + " [-2.1516e+00, -2.4247e-01, -8.1713e-01, ..., -1.8623e+00,\n", + " -1.1230e+00, 3.5013e-01]],\n", + "\n", + " [[ 1.1642e-01, 1.2460e+00, 7.7941e-02, ..., -1.3121e+00,\n", + " -6.7044e-01, -1.1324e+00],\n", + " [ 4.3930e-01, 3.4082e-01, -1.2243e+00, ..., 6.4975e-01,\n", + " -7.3862e-01, -2.1510e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.1291e+00, -1.3637e+00, -1.5779e+00, ..., 8.5980e-01,\n", + " 3.2796e-01, -1.9442e+00],\n", + " [-8.9502e-01, 9.7357e-01, 8.5447e-01, ..., 1.7637e+00,\n", + " 9.1331e-01, -1.7033e+00]],\n", + "\n", + " [[ 1.5909e+00, -1.4922e+00, 1.0060e+00, ..., -1.5965e+00,\n", + " -3.9380e-01, -4.3585e-01],\n", + " [-2.2103e+00, 4.4127e-01, 1.1554e+00, ..., 9.8096e-01,\n", + " 8.6736e-01, -2.3894e+00]],\n", + "\n", + " [[ 1.8451e-01, 1.2740e+00, 4.2857e-01, ..., -1.1366e+00,\n", + " -2.8409e+00, 4.6711e-01],\n", + " [-1.9576e+00, 2.0176e-01, -4.1035e-02, ..., 6.2708e-01,\n", + " -1.3601e+00, -3.9984e-01]]]], grad_fn=)\n", + "tensor([[[[ 3.7393e-01, 3.7078e-01, -7.5553e-01, ..., -9.4280e-01,\n", + " -6.5765e-01, -1.4711e-01],\n", + " [ 3.9396e-01, 4.7658e-02, -1.0277e+00, ..., -8.4926e-01,\n", + 
" -2.7815e-01, -2.6627e-01],\n", + " [ 9.9586e-01, 2.4414e-01, -8.3459e-01, ..., -8.5795e-01,\n", + " -4.3860e-01, -2.1582e-01],\n", + " ...,\n", + " [ 1.4504e+00, -1.2077e-01, -4.8160e-01, ..., -1.6701e+00,\n", + " 6.3807e-01, -8.3788e-02],\n", + " [ 3.8821e-01, 1.4150e-01, -1.5051e-01, ..., -1.5785e+00,\n", + " 4.1589e-01, -1.8024e-01],\n", + " [ 1.0467e-01, 5.5196e-01, 1.0818e-01, ..., -2.0373e+00,\n", + " 8.4395e-02, -6.9034e-02]],\n", + "\n", + " [[-7.7230e-01, -1.1852e-01, -8.0274e-02, ..., -1.3794e-03,\n", + " -5.2249e-01, -4.2095e-01],\n", + " [-4.6209e-01, -1.2680e-01, -2.5711e-01, ..., 1.3235e-01,\n", + " -4.1385e-01, -1.3744e+00],\n", + " [-1.7985e-01, -2.6037e-01, 3.5678e-01, ..., 2.4736e-01,\n", + " -1.6626e-01, -7.0940e-01],\n", + " ...,\n", + " [-6.2069e-01, 2.5298e-01, -7.8416e-01, ..., 8.6187e-02,\n", + " -6.9561e-01, -8.5675e-01],\n", + " [-9.1553e-01, 1.4648e-01, -5.5621e-02, ..., 1.8643e-01,\n", + " -1.0965e+00, -4.8097e-01],\n", + " [-1.0678e+00, 9.0885e-02, -4.5400e-02, ..., 1.5424e-01,\n", + " -1.1762e+00, -2.9385e-01]]],\n", + "\n", + "\n", + " [[[ 4.9445e-01, 9.6665e-03, -5.7117e-01, ..., -9.7640e-01,\n", + " -9.0948e-01, -1.3761e-01],\n", + " [ 4.7672e-01, -2.3672e-01, -8.9246e-01, ..., -8.9927e-01,\n", + " -3.9088e-01, -3.1635e-01],\n", + " [ 4.5120e-01, -5.8035e-02, -6.8736e-01, ..., -1.0352e+00,\n", + " -4.7996e-01, -4.5624e-01],\n", + " ...,\n", + " [ 3.3361e-01, 1.5918e-03, -1.0108e+00, ..., -1.5704e+00,\n", + " -3.9079e-01, -2.1742e-01],\n", + " [ 1.0123e+00, -1.1147e+00, -1.3032e+00, ..., 3.4082e-01,\n", + " 1.3298e-01, -1.5481e+00],\n", + " [ 1.4273e+00, -1.2951e+00, 5.7719e-01, ..., -1.2490e+00,\n", + " -4.1355e-01, -5.8558e-01]],\n", + "\n", + " [[-1.0884e+00, 2.0838e-01, -5.0948e-01, ..., 2.2845e-01,\n", + " -8.2075e-01, -1.1496e+00],\n", + " [-3.4296e-01, 5.0833e-01, -8.6644e-01, ..., 3.3008e-01,\n", + " -5.5070e-01, -1.8440e+00],\n", + " [-3.2602e-01, 6.0671e-01, -7.5711e-01, ..., 4.7493e-01,\n", + " -4.6589e-01, 
-2.0454e+00],\n", + " ...,\n", + " [-1.7399e+00, 4.6359e-01, 6.5213e-01, ..., 9.0025e-01,\n", + " 3.1643e-01, -2.0784e+00],\n", + " [-1.8375e+00, 2.6226e-01, 3.7726e-02, ..., 6.6286e-01,\n", + " -1.0588e+00, -7.6781e-01],\n", + " [-1.6402e+00, 3.1250e-01, -6.6194e-04, ..., 6.6493e-01,\n", + " -9.1110e-01, -1.0382e+00]]]],\n", + " grad_fn=)\n", + "tensor([[[[ 3.7393e-01, 3.7078e-01, -7.5553e-01, ..., -9.4280e-01,\n", + " -6.5765e-01, -1.4711e-01],\n", + " [-7.7230e-01, -1.1852e-01, -8.0274e-02, ..., -1.3794e-03,\n", + " -5.2249e-01, -4.2095e-01]],\n", + "\n", + " [[ 3.9396e-01, 4.7658e-02, -1.0277e+00, ..., -8.4926e-01,\n", + " -2.7815e-01, -2.6627e-01],\n", + " [-4.6209e-01, -1.2680e-01, -2.5711e-01, ..., 1.3235e-01,\n", + " -4.1385e-01, -1.3744e+00]],\n", + "\n", + " [[ 9.9586e-01, 2.4414e-01, -8.3459e-01, ..., -8.5795e-01,\n", + " -4.3860e-01, -2.1582e-01],\n", + " [-1.7985e-01, -2.6037e-01, 3.5678e-01, ..., 2.4736e-01,\n", + " -1.6626e-01, -7.0940e-01]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.4504e+00, -1.2077e-01, -4.8160e-01, ..., -1.6701e+00,\n", + " 6.3807e-01, -8.3788e-02],\n", + " [-6.2069e-01, 2.5298e-01, -7.8416e-01, ..., 8.6187e-02,\n", + " -6.9561e-01, -8.5675e-01]],\n", + "\n", + " [[ 3.8821e-01, 1.4150e-01, -1.5051e-01, ..., -1.5785e+00,\n", + " 4.1589e-01, -1.8024e-01],\n", + " [-9.1553e-01, 1.4648e-01, -5.5621e-02, ..., 1.8643e-01,\n", + " -1.0965e+00, -4.8097e-01]],\n", + "\n", + " [[ 1.0467e-01, 5.5196e-01, 1.0818e-01, ..., -2.0373e+00,\n", + " 8.4395e-02, -6.9034e-02],\n", + " [-1.0678e+00, 9.0885e-02, -4.5400e-02, ..., 1.5424e-01,\n", + " -1.1762e+00, -2.9385e-01]]],\n", + "\n", + "\n", + " [[[ 4.9445e-01, 9.6665e-03, -5.7117e-01, ..., -9.7640e-01,\n", + " -9.0948e-01, -1.3761e-01],\n", + " [-1.0884e+00, 2.0838e-01, -5.0948e-01, ..., 2.2845e-01,\n", + " -8.2075e-01, -1.1496e+00]],\n", + "\n", + " [[ 4.7672e-01, -2.3672e-01, -8.9246e-01, ..., -8.9927e-01,\n", + " -3.9088e-01, -3.1635e-01],\n", + " [-3.4296e-01, 5.0833e-01, -8.6644e-01, ..., 
3.3008e-01,\n", + " -5.5070e-01, -1.8440e+00]],\n", + "\n", + " [[ 4.5120e-01, -5.8035e-02, -6.8736e-01, ..., -1.0352e+00,\n", + " -4.7996e-01, -4.5624e-01],\n", + " [-3.2602e-01, 6.0671e-01, -7.5711e-01, ..., 4.7493e-01,\n", + " -4.6589e-01, -2.0454e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[ 3.3361e-01, 1.5918e-03, -1.0108e+00, ..., -1.5704e+00,\n", + " -3.9079e-01, -2.1742e-01],\n", + " [-1.7399e+00, 4.6359e-01, 6.5213e-01, ..., 9.0025e-01,\n", + " 3.1643e-01, -2.0784e+00]],\n", + "\n", + " [[ 1.0123e+00, -1.1147e+00, -1.3032e+00, ..., 3.4082e-01,\n", + " 1.3298e-01, -1.5481e+00],\n", + " [-1.8375e+00, 2.6226e-01, 3.7726e-02, ..., 6.6286e-01,\n", + " -1.0588e+00, -7.6781e-01]],\n", + "\n", + " [[ 1.4273e+00, -1.2951e+00, 5.7719e-01, ..., -1.2490e+00,\n", + " -4.1355e-01, -5.8558e-01],\n", + " [-1.6402e+00, 3.1250e-01, -6.6194e-04, ..., 6.6493e-01,\n", + " -9.1110e-01, -1.0382e+00]]]], grad_fn=)\n" + ] + } + ], "source": [ "import torch\n", "import chop.passes as passes\n", @@ -262,9 +1177,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32mINFO \u001b[0m \u001b[34mFound dropout module: bert.embeddings.dropout\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mFound dropout module: bert.encoder.layer.0.attention.output.dropout\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mFound dropout module: bert.encoder.layer.0.output.dropout\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mFound dropout module: bert.encoder.layer.1.attention.output.dropout\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mFound dropout module: bert.encoder.layer.1.output.dropout\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mFound dropout module: dropout\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mDropout count is: 6\u001b[0m\n" + ] + } + ], "source": [ "from chop.tools import get_logger\n", "\n", @@ -312,9 +1241,22 @@ 
}, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32mINFO \u001b[0m \u001b[34mRemoving dropout module: bert.embeddings.dropout\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mRemoving dropout module: bert.encoder.layer.0.attention.output.dropout\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mRemoving dropout module: bert.encoder.layer.0.output.dropout\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mRemoving dropout module: bert.encoder.layer.1.attention.output.dropout\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mRemoving dropout module: bert.encoder.layer.1.output.dropout\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mRemoving dropout module: dropout\u001b[0m\n" + ] + } + ], "source": [ "import torch.fx as fx\n", "\n", @@ -360,9 +1302,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32mINFO \u001b[0m \u001b[34mExporting MaseGraph to /home/ism/tutorial_1.pt, /home/ism/tutorial_1.mz\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mExporting GraphModule to /home/ism/tutorial_1.pt\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mSaving full model format\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mExporting MaseMetadata to /home/ism/tutorial_1.mz\u001b[0m\n" + ] + } + ], "source": [ "from pathlib import Path\n", "\n", @@ -378,7 +1331,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -388,7 +1341,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "mase", "language": "python", "name": "python3" }, diff --git a/docs/source/modules/documentation/tutorials/tutorial_2_lora_finetune.ipynb 
b/docs/source/modules/documentation/tutorials/tutorial_2_lora_finetune.ipynb index eb6894e00..2d60a6ee3 100644 --- a/docs/source/modules/documentation/tutorials/tutorial_2_lora_finetune.ipynb +++ b/docs/source/modules/documentation/tutorials/tutorial_2_lora_finetune.ipynb @@ -56,10 +56,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "/mnt/data/zz7522/miniconda/envs/mase/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py:61: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.\n", + " import pynvml # type: ignore[import]\n", + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "\u001b[32mINFO \u001b[0m \u001b[34mTokenizing dataset imdb with AutoTokenizer for DeepWokLab/bert-tiny.\u001b[0m\n", - "Map: 100%|██████████| 25000/25000 [00:03<00:00, 6535.11 examples/s]\n" + "Map: 100%|██████████| 25000/25000 [00:03<00:00, 6362.44 examples/s]\n", + "Map: 100%|██████████| 25000/25000 [00:03<00:00, 6787.09 examples/s]\n", + "Map: 100%|██████████| 50000/50000 [00:07<00:00, 6692.81 examples/s]\n" ] } ], @@ -969,8 +973,8 @@ " -4.1355e-01, -5.8558e-01],\n", " [-1.6402e+00, 3.1250e-01, -6.6194e-04, ..., 6.6493e-01,\n", " -9.1110e-01, -1.0382e+00]]]], grad_fn=)\n", - "tensor([[-0.3776, -0.2373],\n", - " [-0.3180, -0.2760]], grad_fn=)\n", + "tensor([[0.1385, 0.0325],\n", + " [0.1314, 0.0993]], grad_fn=)\n", "tensor([1, 0])\n" ] } @@ -1161,7 +1165,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/zz7522/Projects/mase/src/chop/tools/huggingface.py:157: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", + "/home/ism/ADL/mase/src/chop/tools/huggingface.py:157: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", " trainer = Trainer(\n" ] } @@ -1189,22 +1193,14 @@ "execution_count": 7, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/mnt/data/zz7522/miniconda/envs/mase/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", - " warnings.warn(\n" - ] - }, { "data": { "text/html": [ "\n", "
\n", " \n", - " \n", - " [782/782 01:29]\n", + " \n", + " [3125/3125 01:05]\n", "
\n", " " ], @@ -1219,7 +1215,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Evaluation accuracy: 0.49944\n" + "Evaluation accuracy: 0.50896\n" ] } ], @@ -1241,22 +1237,14 @@ "execution_count": 8, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/mnt/data/zz7522/miniconda/envs/mase/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", - " warnings.warn(\n" - ] - }, { "data": { "text/html": [ "\n", "
\n", " \n", - " \n", - " [782/782 00:34, Epoch 1/1]\n", + " \n", + " [3125/3125 00:32, Epoch 1/1]\n", "
\n", " \n", " \n", @@ -1268,7 +1256,27 @@ " \n", " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
5000.5778000.632200
10000.518200
15000.471200
20000.440200
25000.426800
30000.424900

" @@ -1280,18 +1288,10 @@ "metadata": {}, "output_type": "display_data" }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/mnt/data/zz7522/miniconda/envs/mase/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", - " warnings.warn(\n" - ] - }, { "data": { "text/plain": [ - "TrainOutput(global_step=782, training_loss=0.5414889596612252, metrics={'train_runtime': 34.4035, 'train_samples_per_second': 726.671, 'train_steps_per_second': 22.73, 'total_flos': 0.0, 'train_loss': 0.5414889596612252, 'epoch': 1.0})" + "TrainOutput(global_step=3125, training_loss=0.48300584594726564, metrics={'train_runtime': 32.6982, 'train_samples_per_second': 764.567, 'train_steps_per_second': 95.571, 'total_flos': 0.0, 'train_loss': 0.48300584594726564, 'epoch': 1.0})" ] }, "execution_count": 8, @@ -1315,19 +1315,11 @@ "execution_count": 9, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/mnt/data/zz7522/miniconda/envs/mase/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", - " warnings.warn(\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Evaluation accuracy: 0.788\n" + "Evaluation accuracy: 0.8148\n" ] } ], @@ -1352,9 +1344,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32mINFO \u001b[0m \u001b[34mExporting MaseGraph to /home/zz7522/tutorial_2_sft.pt, /home/zz7522/tutorial_2_sft.mz\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mExporting GraphModule to /home/zz7522/tutorial_2_sft.pt\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mExporting MaseMetadata to /home/zz7522/tutorial_2_sft.mz\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mExporting MaseGraph to 
/home/ism/tutorial_2_sft.pt, /home/ism/tutorial_2_sft.mz\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mExporting GraphModule to /home/ism/tutorial_2_sft.pt\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mSaving full model format\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mExporting MaseMetadata to /home/ism/tutorial_2_sft.mz\u001b[0m\n", "\u001b[33mWARNING \u001b[0m \u001b[34mFailed to pickle call_function node: finfo\u001b[0m\n", "\u001b[33mWARNING \u001b[0m \u001b[34mcannot pickle 'torch.finfo' object\u001b[0m\n", "\u001b[33mWARNING \u001b[0m \u001b[34mFailed to pickle call_function node: getattr_2\u001b[0m\n", @@ -1687,10 +1680,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/zz7522/Projects/mase/src/chop/tools/huggingface.py:157: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", - " trainer = Trainer(\n", - "/mnt/data/zz7522/miniconda/envs/mase/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", - " warnings.warn(\n" + "/home/ism/ADL/mase/src/chop/tools/huggingface.py:157: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", + " trainer = Trainer(\n" ] }, { @@ -1699,8 +1690,8 @@ "\n", "

\n", " \n", - " \n", - " [782/782 00:44, Epoch 1/1]\n", + " \n", + " [3125/3125 00:52, Epoch 1/1]\n", "
\n", " \n", " \n", @@ -1712,7 +1703,27 @@ " \n", " \n", " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
5000.4413000.439900
10000.417000
15000.411300
20000.397800
25000.395300
30000.397300

" @@ -1724,24 +1735,14 @@ "metadata": {}, "output_type": "display_data" }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/mnt/data/zz7522/miniconda/envs/mase/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", - " warnings.warn(\n", - "/mnt/data/zz7522/miniconda/envs/mase/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", - " warnings.warn(\n" - ] - }, { "data": { "text/html": [ "\n", "

\n", " \n", - " \n", - " [782/782 01:02]\n", + " \n", + " [3125/3125 00:37]\n", "
\n", " " ], @@ -1756,7 +1757,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Evaluation accuracy: 0.82264\n" + "Evaluation accuracy: 0.83436\n" ] } ], @@ -1804,9 +1805,7 @@ "\u001b[32mINFO \u001b[0m \u001b[34mFusing LoRALinear weights for bert.encoder.layer.1.intermediate.dense.\u001b[0m\n", "\u001b[32mINFO \u001b[0m \u001b[34mFusing LoRALinear weights for bert.encoder.layer.1.output.dense.\u001b[0m\n", "\u001b[32mINFO \u001b[0m \u001b[34mFusing LoRALinear weights for bert.pooler.dense.\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mFusing LoRALinear weights for classifier.\u001b[0m\n", - "/mnt/data/zz7522/miniconda/envs/mase/lib/python3.11/site-packages/torch/nn/parallel/_functions.py:71: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n", - " warnings.warn(\n" + "\u001b[32mINFO \u001b[0m \u001b[34mFusing LoRALinear weights for classifier.\u001b[0m\n" ] } ], @@ -1824,7 +1823,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Evaluation accuracy: 0.82264\n" + "Evaluation accuracy: 0.83436\n" ] } ], @@ -1848,9 +1847,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32mINFO \u001b[0m \u001b[34mExporting MaseGraph to /home/zz7522/tutorial_2_lora.pt, /home/zz7522/tutorial_2_lora.mz\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mExporting GraphModule to /home/zz7522/tutorial_2_lora.pt\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mExporting MaseMetadata to /home/zz7522/tutorial_2_lora.mz\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mExporting MaseGraph to /home/ism/tutorial_2_lora.pt, /home/ism/tutorial_2_lora.mz\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mExporting GraphModule to /home/ism/tutorial_2_lora.pt\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mSaving full model format\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mExporting MaseMetadata to /home/ism/tutorial_2_lora.mz\u001b[0m\n", "\u001b[33mWARNING 
\u001b[0m \u001b[34mFailed to pickle call_function node: finfo\u001b[0m\n", "\u001b[33mWARNING \u001b[0m \u001b[34mcannot pickle 'torch.finfo' object\u001b[0m\n", "\u001b[33mWARNING \u001b[0m \u001b[34mFailed to pickle call_function node: getattr_2\u001b[0m\n", @@ -1895,7 +1895,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.11" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/docs/source/modules/documentation/tutorials/tutorial_3_qat.ipynb b/docs/source/modules/documentation/tutorials/tutorial_3_qat.ipynb index e04383982..3004d3f63 100644 --- a/docs/source/modules/documentation/tutorials/tutorial_3_qat.ipynb +++ b/docs/source/modules/documentation/tutorials/tutorial_3_qat.ipynb @@ -48,14 +48,859 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/yz10513/anaconda3/envs/mase/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py:61: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. 
If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.\n", + " import pynvml # type: ignore[import]\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", "`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting model.config.use_cache = False.\n", - "\u001b[32mINFO \u001b[0m \u001b[34mGetting dummy input for prajjwal1/bert-tiny.\u001b[0m\n", - "/Users/yz10513/anaconda3/envs/mase/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", - " warnings.warn(\n" + "\u001b[32mINFO \u001b[0m \u001b[34mGetting dummy input for prajjwal1/bert-tiny.\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 101, 9932, 2089, 2202, 2058, 1996, 2088, 2028, 2154, 102],\n", + " [ 101, 2023, 2003, 2339, 2017, 2323, 4553, 4748, 4877, 102]])\n", + "tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n", + "tensor([[ 101, 9932, 2089, 2202, 2058, 1996, 2088, 2028, 2154, 102],\n", + " [ 101, 2023, 2003, 2339, 2017, 2323, 4553, 4748, 4877, 102]])\n", + "tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])\n", + "tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])\n", + "tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])\n", + "tensor([[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]],\n", + "\n", + "\n", + " [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]])\n", + "tensor([[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", 
+ " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]],\n", + "\n", + "\n", + " [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]])\n", + "tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])\n", + "tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 
0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])\n", + "tensor([[[ 0.7973, 0.0109, -8.8405, ..., 1.4170, 0.1046, -0.1551],\n", + " [-1.1766, 1.2879, -1.0986, ..., 0.4749, -0.5899, 0.8746],\n", + " [-2.0560, 0.7748, -0.8909, ..., -0.4034, 0.5352, -1.3657],\n", + " ...,\n", + " [ 0.2317, -0.7896, 0.9634, ..., -0.8037, 0.4834, -0.5868],\n", + " [ 0.0243, -1.0235, -1.2771, ..., -2.2378, 1.8530, 0.1558],\n", + " [-1.3637, 0.7055, -0.2177, ..., 0.3557, -0.3971, -0.3107]],\n", + "\n", + " [[ 0.7973, 0.0109, -8.8405, ..., 1.4170, 0.1046, -0.1551],\n", + " [-2.6940, 0.6198, -0.4564, ..., -1.4367, -1.5705, -3.1260],\n", + " [-1.7524, 0.8535, -0.2155, ..., -0.5222, -1.2430, -1.7199],\n", + " ...,\n", + " [-0.0347, 0.7446, 1.4462, ..., -1.1578, -2.6197, 0.2612],\n", + " [ 2.4334, -0.3068, 0.8250, ..., 0.1475, 0.1790, 2.2907],\n", + " [-1.3637, 0.7055, -0.2177, ..., 0.3557, -0.3971, -0.3107]]],\n", + " grad_fn=)\n", + "tensor([[[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01],\n", + " [-3.5020e-02, -6.1047e-02, -1.0465e-01, ..., -8.1892e-01,\n", + " 1.1978e+00, 2.1808e+00],\n", + " [ 4.3982e-01, -1.9602e+00, -6.8830e-01, ..., -6.3025e-01,\n", + " -1.5967e-01, 1.3284e+00],\n", + " ...,\n", + " [ 
1.4783e+00, 1.0907e-01, -1.5222e+00, ..., -3.0983e-01,\n", + " -1.2971e-01, 1.1265e+00],\n", + " [ 1.5890e+00, -1.6859e+00, 7.8703e-01, ..., -1.3174e+00,\n", + " 2.2258e-01, 8.8157e-01],\n", + " [-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]],\n", + "\n", + " [[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01],\n", + " [-4.4781e-01, -7.9224e-01, -2.1741e+00, ..., -5.9181e-01,\n", + " 1.4373e+00, 2.4267e+00],\n", + " [-2.5942e-01, 9.7163e-01, -3.2928e+00, ..., -5.9773e-01,\n", + " -3.0482e-01, 1.4038e+00],\n", + " ...,\n", + " [ 5.1575e-02, 3.5218e-01, -3.8926e-01, ..., -1.1508e+00,\n", + " 7.5490e-01, 8.2911e-01],\n", + " [ 1.6107e+00, 6.8170e-02, 9.2537e-01, ..., -1.5233e+00,\n", + " -6.0733e-01, 3.3097e-01],\n", + " [-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]]], grad_fn=)\n", + "tensor([[[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01],\n", + " [-3.5020e-02, -6.1047e-02, -1.0465e-01, ..., -8.1892e-01,\n", + " 1.1978e+00, 2.1808e+00],\n", + " [ 4.3982e-01, -1.9602e+00, -6.8830e-01, ..., -6.3025e-01,\n", + " -1.5967e-01, 1.3284e+00],\n", + " ...,\n", + " [ 1.4783e+00, 1.0907e-01, -1.5222e+00, ..., -3.0983e-01,\n", + " -1.2971e-01, 1.1265e+00],\n", + " [ 1.5890e+00, -1.6859e+00, 7.8703e-01, ..., -1.3174e+00,\n", + " 2.2258e-01, 8.8157e-01],\n", + " [-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]],\n", + "\n", + " [[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01],\n", + " [-4.4781e-01, -7.9224e-01, -2.1741e+00, ..., -5.9181e-01,\n", + " 1.4373e+00, 2.4267e+00],\n", + " [-2.5942e-01, 9.7163e-01, -3.2928e+00, ..., -5.9773e-01,\n", + " -3.0482e-01, 1.4038e+00],\n", + " ...,\n", + " [ 5.1575e-02, 3.5218e-01, -3.8926e-01, ..., -1.1508e+00,\n", + " 7.5490e-01, 8.2911e-01],\n", + " [ 1.6107e+00, 6.8170e-02, 9.2537e-01, ..., 
-1.5233e+00,\n", + " -6.0733e-01, 3.3097e-01],\n", + " [-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]]], grad_fn=)\n", + "tensor([[[[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., -5.2614e-01,\n", + " -3.5687e-01, -2.6793e-01],\n", + " [ 2.0335e-01, -5.4534e-01, 3.0686e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01]],\n", + "\n", + " [[-3.5020e-02, -6.1047e-02, -1.0465e-01, ..., -2.0138e+00,\n", + " 4.5529e-01, -7.8171e-01],\n", + " [ 1.1969e+00, 1.6337e+00, 2.5047e-01, ..., -8.1892e-01,\n", + " 1.1978e+00, 2.1808e+00]],\n", + "\n", + " [[ 4.3982e-01, -1.9602e+00, -6.8830e-01, ..., -2.2501e-01,\n", + " 7.2290e-02, -1.8290e+00],\n", + " [ 8.9952e-01, 1.0029e+00, 7.4536e-04, ..., -6.3025e-01,\n", + " -1.5967e-01, 1.3284e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.4783e+00, 1.0907e-01, -1.5222e+00, ..., 1.1867e+00,\n", + " -1.3561e+00, 6.5158e-01],\n", + " [ 9.5466e-01, 4.5887e-01, 7.8078e-01, ..., -3.0983e-01,\n", + " -1.2971e-01, 1.1265e+00]],\n", + "\n", + " [[ 1.5890e+00, -1.6859e+00, 7.8703e-01, ..., 6.5467e-01,\n", + " -6.8451e-01, 6.5081e-01],\n", + " [ 7.0729e-01, 1.4499e+00, -1.5089e-01, ..., -1.3174e+00,\n", + " 2.2258e-01, 8.8157e-01]],\n", + "\n", + " [[-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., 3.3130e-01,\n", + " -3.2756e-01, -6.3130e-01],\n", + " [ 8.6773e-01, 2.0996e-01, -3.4332e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]]],\n", + "\n", + "\n", + " [[[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., -5.2614e-01,\n", + " -3.5687e-01, -2.6793e-01],\n", + " [ 2.0335e-01, -5.4534e-01, 3.0686e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01]],\n", + "\n", + " [[-4.4781e-01, -7.9224e-01, -2.1741e+00, ..., 1.7508e+00,\n", + " -3.6708e-01, -1.3251e+00],\n", + " [ 7.9208e-01, -1.3537e-01, 2.3756e-01, ..., -5.9181e-01,\n", + " 1.4373e+00, 2.4267e+00]],\n", + "\n", + " [[-2.5942e-01, 9.7163e-01, -3.2928e+00, ..., -9.6646e-01,\n", + " -4.8876e-01, -1.4426e+00],\n", + " [ 1.0250e+00, -6.9093e-01, 
-1.2734e+00, ..., -5.9773e-01,\n", + " -3.0482e-01, 1.4038e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[ 5.1575e-02, 3.5218e-01, -3.8926e-01, ..., -1.2252e-02,\n", + " 1.0394e+00, 4.2402e-01],\n", + " [-4.7386e-01, 2.6401e+00, 1.7024e+00, ..., -1.1508e+00,\n", + " 7.5490e-01, 8.2911e-01]],\n", + "\n", + " [[ 1.6107e+00, 6.8170e-02, 9.2537e-01, ..., -6.1665e-01,\n", + " 2.7627e-01, -1.2083e+00],\n", + " [ 9.3395e-01, -9.7541e-01, -2.5442e-02, ..., -1.5233e+00,\n", + " -6.0733e-01, 3.3097e-01]],\n", + "\n", + " [[-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., 3.3130e-01,\n", + " -3.2756e-01, -6.3130e-01],\n", + " [ 8.6773e-01, 2.0996e-01, -3.4332e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]]]], grad_fn=)\n", + "tensor([[[-0.1709, 0.5230, -0.8713, ..., -1.3382, 0.5892, 0.4026],\n", + " [-0.5842, 0.9588, 1.5642, ..., -1.0731, -0.7330, 0.3132],\n", + " [-0.8601, -1.3756, 0.5042, ..., -0.0476, 0.2650, 1.2150],\n", + " ...,\n", + " [ 0.0520, 1.1719, -1.5471, ..., -0.7894, 0.1419, 1.6964],\n", + " [ 0.7654, -1.5053, -0.4142, ..., -1.4622, -0.8975, 1.4576],\n", + " [-1.2008, -0.6008, -1.4608, ..., -1.2105, -0.4289, 0.3827]],\n", + "\n", + " [[-0.1709, 0.5230, -0.8713, ..., -1.3382, 0.5892, 0.4026],\n", + " [-1.3806, 0.2626, -0.5207, ..., -1.6714, -0.0554, 1.0225],\n", + " [-1.7116, 1.8788, -2.5695, ..., -0.6958, 0.5728, 0.5461],\n", + " ...,\n", + " [-1.3246, 1.2196, -0.3034, ..., -1.1955, -0.6708, 0.5128],\n", + " [ 0.9854, 0.8260, 0.2892, ..., -0.6428, 0.3637, 0.4339],\n", + " [-1.2008, -0.6008, -1.4608, ..., -1.2105, -0.4289, 0.3827]]],\n", + " grad_fn=)\n", + "tensor([[[-0.1709, 0.5230, -0.8713, ..., -1.3382, 0.5892, 0.4026],\n", + " [-0.5842, 0.9588, 1.5642, ..., -1.0731, -0.7330, 0.3132],\n", + " [-0.8601, -1.3756, 0.5042, ..., -0.0476, 0.2650, 1.2150],\n", + " ...,\n", + " [ 0.0520, 1.1719, -1.5471, ..., -0.7894, 0.1419, 1.6964],\n", + " [ 0.7654, -1.5053, -0.4142, ..., -1.4622, -0.8975, 1.4576],\n", + " [-1.2008, -0.6008, -1.4608, ..., -1.2105, -0.4289, 
0.3827]],\n", + "\n", + " [[-0.1709, 0.5230, -0.8713, ..., -1.3382, 0.5892, 0.4026],\n", + " [-1.3806, 0.2626, -0.5207, ..., -1.6714, -0.0554, 1.0225],\n", + " [-1.7116, 1.8788, -2.5695, ..., -0.6958, 0.5728, 0.5461],\n", + " ...,\n", + " [-1.3246, 1.2196, -0.3034, ..., -1.1955, -0.6708, 0.5128],\n", + " [ 0.9854, 0.8260, 0.2892, ..., -0.6428, 0.3637, 0.4339],\n", + " [-1.2008, -0.6008, -1.4608, ..., -1.2105, -0.4289, 0.3827]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.1709, 0.5230, -0.8713, ..., 0.4365, 0.6238, -0.9414],\n", + " [-1.3731, 1.1521, 0.1321, ..., -1.3382, 0.5892, 0.4026]],\n", + "\n", + " [[-0.5842, 0.9588, 1.5642, ..., -1.5431, 0.4999, -1.1350],\n", + " [ 0.9615, 0.8694, 0.0998, ..., -1.0731, -0.7330, 0.3132]],\n", + "\n", + " [[-0.8601, -1.3756, 0.5042, ..., 0.9764, -0.8321, -1.0204],\n", + " [ 1.5175, 1.1454, 0.7791, ..., -0.0476, 0.2650, 1.2150]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0520, 1.1719, -1.5471, ..., 1.9402, -1.1294, 0.4793],\n", + " [ 1.0053, 0.8099, 1.6415, ..., -0.7894, 0.1419, 1.6964]],\n", + "\n", + " [[ 0.7654, -1.5053, -0.4142, ..., 1.7455, -0.7326, 1.5248],\n", + " [ 1.0806, 1.1457, 2.2163, ..., -1.4622, -0.8975, 1.4576]],\n", + "\n", + " [[-1.2008, -0.6008, -1.4608, ..., 2.0905, 1.8849, -1.5708],\n", + " [ 1.9999, 0.3493, -0.8524, ..., -1.2105, -0.4289, 0.3827]]],\n", + "\n", + "\n", + " [[[-0.1709, 0.5230, -0.8713, ..., 0.4365, 0.6238, -0.9414],\n", + " [-1.3731, 1.1521, 0.1321, ..., -1.3382, 0.5892, 0.4026]],\n", + "\n", + " [[-1.3806, 0.2626, -0.5207, ..., 1.6517, -0.2316, -1.3171],\n", + " [ 0.6812, -0.0090, 0.3803, ..., -1.6714, -0.0554, 1.0225]],\n", + "\n", + " [[-1.7116, 1.8788, -2.5695, ..., 0.4927, -0.4850, -1.0645],\n", + " [ 1.2646, 1.6481, 0.9055, ..., -0.6958, 0.5728, 0.5461]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.3246, 1.2196, -0.3034, ..., 1.2747, 1.2353, 0.2825],\n", + " [ 1.5373, 0.8648, 0.6062, ..., -1.1955, -0.6708, 0.5128]],\n", + "\n", + " [[ 0.9854, 0.8260, 0.2892, ..., 1.3848, -0.0103, 
-1.0700],\n", + " [ 1.3827, 2.9809, 0.0276, ..., -0.6428, 0.3637, 0.4339]],\n", + "\n", + " [[-1.2008, -0.6008, -1.4608, ..., 2.0905, 1.8849, -1.5708],\n", + " [ 1.9999, 0.3493, -0.8524, ..., -1.2105, -0.4289, 0.3827]]]],\n", + " grad_fn=)\n", + "tensor([[[-0.0123, 0.5761, 0.2209, ..., -0.1027, 1.1061, -2.5200],\n", + " [-1.1465, -1.5578, -0.6984, ..., 1.0310, 0.4824, -0.2291],\n", + " [-1.0361, -1.8192, -2.3055, ..., 1.5286, -1.5941, 1.1762],\n", + " ...,\n", + " [-0.7992, 0.0886, 0.4887, ..., -1.7941, 0.4835, 1.3780],\n", + " [-1.4692, -0.9135, -0.2802, ..., -0.9691, 0.3500, 1.8863],\n", + " [-0.5760, -0.0452, 0.4230, ..., -0.7179, -0.7858, 1.6879]],\n", + "\n", + " [[-0.0123, 0.5761, 0.2209, ..., -0.1027, 1.1061, -2.5200],\n", + " [-0.3700, -1.9754, -0.7315, ..., 0.2293, 0.6996, 3.1299],\n", + " [-0.6252, 0.2879, -1.4036, ..., -2.0560, -2.4623, -0.9584],\n", + " ...,\n", + " [-1.1306, -1.4343, -1.4422, ..., -1.6115, -0.0475, 1.3975],\n", + " [-0.9816, -1.4909, -1.0086, ..., -0.9284, 0.5260, 1.5330],\n", + " [-0.5760, -0.0452, 0.4230, ..., -0.7179, -0.7858, 1.6879]]],\n", + " grad_fn=)\n", + "tensor([[[-0.0123, 0.5761, 0.2209, ..., -0.1027, 1.1061, -2.5200],\n", + " [-1.1465, -1.5578, -0.6984, ..., 1.0310, 0.4824, -0.2291],\n", + " [-1.0361, -1.8192, -2.3055, ..., 1.5286, -1.5941, 1.1762],\n", + " ...,\n", + " [-0.7992, 0.0886, 0.4887, ..., -1.7941, 0.4835, 1.3780],\n", + " [-1.4692, -0.9135, -0.2802, ..., -0.9691, 0.3500, 1.8863],\n", + " [-0.5760, -0.0452, 0.4230, ..., -0.7179, -0.7858, 1.6879]],\n", + "\n", + " [[-0.0123, 0.5761, 0.2209, ..., -0.1027, 1.1061, -2.5200],\n", + " [-0.3700, -1.9754, -0.7315, ..., 0.2293, 0.6996, 3.1299],\n", + " [-0.6252, 0.2879, -1.4036, ..., -2.0560, -2.4623, -0.9584],\n", + " ...,\n", + " [-1.1306, -1.4343, -1.4422, ..., -1.6115, -0.0475, 1.3975],\n", + " [-0.9816, -1.4909, -1.0086, ..., -0.9284, 0.5260, 1.5330],\n", + " [-0.5760, -0.0452, 0.4230, ..., -0.7179, -0.7858, 1.6879]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.0123, 
0.5761, 0.2209, ..., -0.1457, -0.7538, 0.1761],\n", + " [-0.0705, 0.9215, 0.7990, ..., -0.1027, 1.1061, -2.5200]],\n", + "\n", + " [[-1.1465, -1.5578, -0.6984, ..., 0.0289, -2.1112, -0.8728],\n", + " [ 0.6506, -1.6966, 1.4463, ..., 1.0310, 0.4824, -0.2291]],\n", + "\n", + " [[-1.0361, -1.8192, -2.3055, ..., -0.2195, -1.1732, 0.3182],\n", + " [-0.5841, -0.0227, 3.0901, ..., 1.5286, -1.5941, 1.1762]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.7992, 0.0886, 0.4887, ..., 0.7859, -1.0127, -0.2676],\n", + " [-0.3055, 0.6270, -3.0705, ..., -1.7941, 0.4835, 1.3780]],\n", + "\n", + " [[-1.4692, -0.9135, -0.2802, ..., 0.1197, -0.7532, 0.0731],\n", + " [ 0.6096, -1.0893, -0.6959, ..., -0.9691, 0.3500, 1.8863]],\n", + "\n", + " [[-0.5760, -0.0452, 0.4230, ..., 0.8851, 0.3078, 0.8106],\n", + " [-1.1804, 0.9512, 0.3169, ..., -0.7179, -0.7858, 1.6879]]],\n", + "\n", + "\n", + " [[[-0.0123, 0.5761, 0.2209, ..., -0.1457, -0.7538, 0.1761],\n", + " [-0.0705, 0.9215, 0.7990, ..., -0.1027, 1.1061, -2.5200]],\n", + "\n", + " [[-0.3700, -1.9754, -0.7315, ..., 0.5756, -1.5559, 0.0326],\n", + " [ 1.4229, 2.3970, -0.4516, ..., 0.2293, 0.6996, 3.1299]],\n", + "\n", + " [[-0.6252, 0.2879, -1.4036, ..., 0.5306, -0.5608, 1.1861],\n", + " [-2.5980, 0.2673, 3.3016, ..., -2.0560, -2.4623, -0.9584]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.1306, -1.4343, -1.4422, ..., 0.3918, -1.5336, -0.5026],\n", + " [ 1.8587, 0.8501, -1.2402, ..., -1.6115, -0.0475, 1.3975]],\n", + "\n", + " [[-0.9816, -1.4909, -1.0086, ..., 0.2956, 0.0351, -1.0685],\n", + " [-0.6594, -0.0133, -1.1863, ..., -0.9284, 0.5260, 1.5330]],\n", + "\n", + " [[-0.5760, -0.0452, 0.4230, ..., 0.8851, 0.3078, 0.8106],\n", + " [-1.1804, 0.9512, 0.3169, ..., -0.7179, -0.7858, 1.6879]]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.5911, -0.4682, -0.4314, ..., 0.0366, -1.0405, -0.0579],\n", + " [-1.1141, -1.4785, -0.6477, ..., 0.0631, -1.9950, -0.7912],\n", + " [-0.7059, -0.9954, -1.3923, ..., -0.1557, -0.9998, 0.2774],\n", + " ...,\n", + " 
[-0.6300, 0.1252, 0.3486, ..., 0.5692, -0.9417, -0.1399],\n", + " [-1.0593, -0.6395, -0.3343, ..., 0.0571, -0.7842, 0.1042],\n", + " [-0.2081, 0.2785, 0.0613, ..., -0.0353, -0.8389, -0.0026]],\n", + "\n", + " [[-0.3266, 0.6360, 0.0214, ..., 0.1677, 0.3883, 0.4382],\n", + " [-0.6299, 0.6012, 0.7379, ..., 0.2989, -0.1569, 0.9508],\n", + " [-0.4097, 0.6374, -0.1589, ..., 0.0258, -0.0364, 0.9990],\n", + " ...,\n", + " [-0.0088, 0.1378, -0.4819, ..., -0.1261, 0.2908, 1.0980],\n", + " [-0.5584, 0.9932, -0.1105, ..., 0.2486, 0.4005, 0.6046],\n", + " [-0.4516, 0.8092, -0.0513, ..., 0.1182, 0.4294, 0.4781]]],\n", + "\n", + "\n", + " [[[-0.6850, -0.2168, -0.6087, ..., 0.1402, -0.7267, -0.1502],\n", + " [-0.2084, -0.1325, -0.1151, ..., 0.0938, -0.8963, 0.1296],\n", + " [-0.3665, 0.3408, -0.6133, ..., 0.1938, -0.6681, 0.5929],\n", + " ...,\n", + " [-1.0985, -1.2861, -1.2531, ..., 0.3609, -1.4022, -0.4444],\n", + " [-0.9787, -1.4207, -0.9908, ..., 0.2989, -0.0159, -1.0060],\n", + " [-0.1970, 0.2371, -0.0115, ..., 0.0049, -0.8135, 0.1201]],\n", + "\n", + " [[ 0.0907, 0.7927, -0.0524, ..., -0.4610, -0.4295, 0.4391],\n", + " [-0.2352, 1.0586, -0.1117, ..., -0.4943, -0.6546, 0.6617],\n", + " [ 0.8269, 1.7861, -1.1207, ..., -0.0885, 0.2791, 1.4721],\n", + " ...,\n", + " [-0.1038, 0.9623, -0.7062, ..., -0.3340, -0.3050, 0.8792],\n", + " [ 0.0507, 0.6634, -0.2642, ..., -0.4959, -0.7944, 0.7687],\n", + " [ 0.3364, 0.9047, -0.2037, ..., -0.3705, -0.2893, 0.6787]]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.5911, -0.4682, -0.4314, ..., 0.0366, -1.0405, -0.0579],\n", + " [-0.3266, 0.6360, 0.0214, ..., 0.1677, 0.3883, 0.4382]],\n", + "\n", + " [[-1.1141, -1.4785, -0.6477, ..., 0.0631, -1.9950, -0.7912],\n", + " [-0.6299, 0.6012, 0.7379, ..., 0.2989, -0.1569, 0.9508]],\n", + "\n", + " [[-0.7059, -0.9954, -1.3923, ..., -0.1557, -0.9998, 0.2774],\n", + " [-0.4097, 0.6374, -0.1589, ..., 0.0258, -0.0364, 0.9990]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.6300, 0.1252, 0.3486, ..., 0.5692, 
-0.9417, -0.1399],\n", + " [-0.0088, 0.1378, -0.4819, ..., -0.1261, 0.2908, 1.0980]],\n", + "\n", + " [[-1.0593, -0.6395, -0.3343, ..., 0.0571, -0.7842, 0.1042],\n", + " [-0.5584, 0.9932, -0.1105, ..., 0.2486, 0.4005, 0.6046]],\n", + "\n", + " [[-0.2081, 0.2785, 0.0613, ..., -0.0353, -0.8389, -0.0026],\n", + " [-0.4516, 0.8092, -0.0513, ..., 0.1182, 0.4294, 0.4781]]],\n", + "\n", + "\n", + " [[[-0.6850, -0.2168, -0.6087, ..., 0.1402, -0.7267, -0.1502],\n", + " [ 0.0907, 0.7927, -0.0524, ..., -0.4610, -0.4295, 0.4391]],\n", + "\n", + " [[-0.2084, -0.1325, -0.1151, ..., 0.0938, -0.8963, 0.1296],\n", + " [-0.2352, 1.0586, -0.1117, ..., -0.4943, -0.6546, 0.6617]],\n", + "\n", + " [[-0.3665, 0.3408, -0.6133, ..., 0.1938, -0.6681, 0.5929],\n", + " [ 0.8269, 1.7861, -1.1207, ..., -0.0885, 0.2791, 1.4721]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.0985, -1.2861, -1.2531, ..., 0.3609, -1.4022, -0.4444],\n", + " [-0.1038, 0.9623, -0.7062, ..., -0.3340, -0.3050, 0.8792]],\n", + "\n", + " [[-0.9787, -1.4207, -0.9908, ..., 0.2989, -0.0159, -1.0060],\n", + " [ 0.0507, 0.6634, -0.2642, ..., -0.4959, -0.7944, 0.7687]],\n", + "\n", + " [[-0.1970, 0.2371, -0.0115, ..., 0.0049, -0.8135, 0.1201],\n", + " [ 0.3364, 0.9047, -0.2037, ..., -0.3705, -0.2893, 0.6787]]]],\n", + " grad_fn=)\n", + "tensor([[[-0.9552, 0.6594, -6.5403, ..., -0.7144, 0.0906, 0.3369],\n", + " [-2.5251, 1.3955, -0.8914, ..., -2.1363, 0.0271, 1.1132],\n", + " [-3.7148, 0.6796, -0.8710, ..., -2.6492, 0.5694, -0.1085],\n", + " ...,\n", + " [-2.2403, -0.7594, 0.5414, ..., -3.0426, 0.8895, -0.0546],\n", + " [-1.6945, -0.6326, -0.8632, ..., -4.0678, 1.7219, 0.6481],\n", + " [-2.9625, 0.7451, -0.8037, ..., -2.5048, 0.3125, 0.5537]],\n", + "\n", + " [[-0.5150, 0.8150, -6.5015, ..., -0.5377, -0.4171, 0.1350],\n", + " [-2.9979, 1.0930, -0.2619, ..., -3.1811, -1.0048, -1.8349],\n", + " [-2.8788, 0.5405, -0.0789, ..., -2.3969, -0.7016, -0.7332],\n", + " ...,\n", + " [-1.7194, 1.5158, 1.0070, ..., -2.8931, -2.3309, 
1.1685],\n", + " [ 0.0717, -0.1039, 0.5084, ..., -2.1932, 0.0751, 2.8236],\n", + " [-2.4774, 0.7563, -0.7502, ..., -2.1312, -0.0685, 0.8700]]],\n", + " grad_fn=)\n", + "tensor([[[-0.0455, 0.6529, 0.6297, ..., 0.4139, -0.9381, 0.6769],\n", + " [-3.1104, -3.7282, -2.3953, ..., -0.9155, -0.8280, -1.7070],\n", + " [-0.9324, -2.9333, -2.3249, ..., -0.8455, -0.0326, -0.6998],\n", + " ...,\n", + " [-1.9000, -1.1028, -1.1281, ..., -0.2809, 2.0206, -1.0802],\n", + " [-1.1088, -1.0420, -2.4026, ..., -0.4478, 0.7391, -0.0354],\n", + " [ 0.7349, 0.6742, -2.6697, ..., -0.5114, 1.5155, 2.0246]],\n", + "\n", + " [[-0.0799, 0.7813, 0.4918, ..., 0.6888, -0.7680, 0.9805],\n", + " [-2.8720, -1.0602, -2.3610, ..., -2.1143, 0.9664, -1.1212],\n", + " [-1.4705, -2.1384, -1.9955, ..., -0.9722, 1.5909, -0.1668],\n", + " ...,\n", + " [-2.9884, -1.1566, -2.5215, ..., 1.1460, 0.7120, -0.6320],\n", + " [-3.3666, -0.7966, -3.3154, ..., 0.5316, 1.7058, 2.1950],\n", + " [ 0.6626, 0.8537, -2.7251, ..., -0.0901, 1.5883, 2.3840]]],\n", + " grad_fn=)\n", + "tensor([[[-0.0455, 0.6529, 0.6297, ..., 0.4139, -0.9381, 0.6769],\n", + " [-3.1104, -3.7282, -2.3953, ..., -0.9155, -0.8280, -1.7070],\n", + " [-0.9324, -2.9333, -2.3249, ..., -0.8455, -0.0326, -0.6998],\n", + " ...,\n", + " [-1.9000, -1.1028, -1.1281, ..., -0.2809, 2.0206, -1.0802],\n", + " [-1.1088, -1.0420, -2.4026, ..., -0.4478, 0.7391, -0.0354],\n", + " [ 0.7349, 0.6742, -2.6697, ..., -0.5114, 1.5155, 2.0246]],\n", + "\n", + " [[-0.0799, 0.7813, 0.4918, ..., 0.6888, -0.7680, 0.9805],\n", + " [-2.8720, -1.0602, -2.3610, ..., -2.1143, 0.9664, -1.1212],\n", + " [-1.4705, -2.1384, -1.9955, ..., -0.9722, 1.5909, -0.1668],\n", + " ...,\n", + " [-2.9884, -1.1566, -2.5215, ..., 1.1460, 0.7120, -0.6320],\n", + " [-3.3666, -0.7966, -3.3154, ..., 0.5316, 1.7058, 2.1950],\n", + " [ 0.6626, 0.8537, -2.7251, ..., -0.0901, 1.5883, 2.3840]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.0455, 0.6529, 0.6297, ..., 1.0165, -1.6055, 0.0557],\n", + " [-0.9141, 
-1.5101, -0.8415, ..., 0.4139, -0.9381, 0.6769]],\n", + "\n", + " [[-3.1104, -3.7282, -2.3953, ..., -1.6195, 0.7426, -3.2794],\n", + " [-1.2694, 0.3821, -0.5687, ..., -0.9155, -0.8280, -1.7070]],\n", + "\n", + " [[-0.9324, -2.9333, -2.3249, ..., -1.0254, 1.8158, -1.8835],\n", + " [-1.5265, -0.3901, 0.2734, ..., -0.8455, -0.0326, -0.6998]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.9000, -1.1028, -1.1281, ..., -1.2688, -0.0851, -2.3190],\n", + " [-2.4374, 0.0718, -2.7276, ..., -0.2809, 2.0206, -1.0802]],\n", + "\n", + " [[-1.1088, -1.0420, -2.4026, ..., -1.0658, 0.1932, -1.7012],\n", + " [-2.3622, -0.5291, -1.9931, ..., -0.4478, 0.7391, -0.0354]],\n", + "\n", + " [[ 0.7349, 0.6742, -2.6697, ..., -1.4630, -0.1686, -2.5682],\n", + " [-0.1401, -0.9712, -2.3801, ..., -0.5114, 1.5155, 2.0246]]],\n", + "\n", + "\n", + " [[[-0.0799, 0.7813, 0.4918, ..., 1.2364, -1.9500, -0.1275],\n", + " [-0.4080, -1.5069, -0.8504, ..., 0.6888, -0.7680, 0.9805]],\n", + "\n", + " [[-2.8720, -1.0602, -2.3610, ..., -2.3225, -0.0351, -2.7432],\n", + " [-0.2305, -0.5940, -1.1570, ..., -2.1143, 0.9664, -1.1212]],\n", + "\n", + " [[-1.4705, -2.1384, -1.9955, ..., -0.6524, -1.8025, -1.8321],\n", + " [-1.7742, -0.6800, -0.2172, ..., -0.9722, 1.5909, -0.1668]],\n", + "\n", + " ...,\n", + "\n", + " [[-2.9884, -1.1566, -2.5215, ..., -0.5054, -1.0314, -3.4883],\n", + " [-1.9535, 0.5573, -2.1564, ..., 1.1460, 0.7120, -0.6320]],\n", + "\n", + " [[-3.3666, -0.7966, -3.3154, ..., 0.7587, -0.6289, -3.4848],\n", + " [-1.4099, -2.0919, -1.5870, ..., 0.5316, 1.7058, 2.1950]],\n", + "\n", + " [[ 0.6626, 0.8537, -2.7251, ..., -1.1831, -0.7083, -2.7717],\n", + " [ 0.4486, -1.1639, -2.1203, ..., -0.0901, 1.5883, 2.3840]]]],\n", + " grad_fn=)\n", + "tensor([[[-0.8947, 0.0412, -1.2359, ..., 0.4410, -0.3965, 0.0106],\n", + " [-1.9196, 0.3326, 0.8482, ..., -1.5790, -1.1817, -1.0156],\n", + " [-2.1664, 0.3959, 0.7476, ..., -2.1767, -0.6488, 0.1889],\n", + " ...,\n", + " [-1.6009, 0.4887, -0.4818, ..., -1.1268, 0.4111, 
0.7892],\n", + " [-0.1528, 1.1728, -0.5164, ..., -0.4340, 0.1499, 1.6704],\n", + " [ 1.0253, 1.4222, -0.1805, ..., -0.6130, -0.5380, 1.6164]],\n", + "\n", + " [[-0.6736, -0.0718, -1.1724, ..., 0.2001, -0.5481, 0.0232],\n", + " [-1.8698, -1.2184, 0.2913, ..., -1.1398, -1.3523, -0.7851],\n", + " [-1.3725, -0.8212, 0.1984, ..., -1.8218, -1.4800, -0.2956],\n", + " ...,\n", + " [-0.5946, 0.5680, 0.8938, ..., -1.6653, 0.8218, 1.1902],\n", + " [ 1.2800, 1.9566, 0.2540, ..., -1.2290, 0.5257, 1.2667],\n", + " [ 1.3511, 1.3329, -0.0782, ..., -0.8454, -0.7400, 1.5966]]],\n", + " grad_fn=)\n", + "tensor([[[-0.8947, 0.0412, -1.2359, ..., 0.4410, -0.3965, 0.0106],\n", + " [-1.9196, 0.3326, 0.8482, ..., -1.5790, -1.1817, -1.0156],\n", + " [-2.1664, 0.3959, 0.7476, ..., -2.1767, -0.6488, 0.1889],\n", + " ...,\n", + " [-1.6009, 0.4887, -0.4818, ..., -1.1268, 0.4111, 0.7892],\n", + " [-0.1528, 1.1728, -0.5164, ..., -0.4340, 0.1499, 1.6704],\n", + " [ 1.0253, 1.4222, -0.1805, ..., -0.6130, -0.5380, 1.6164]],\n", + "\n", + " [[-0.6736, -0.0718, -1.1724, ..., 0.2001, -0.5481, 0.0232],\n", + " [-1.8698, -1.2184, 0.2913, ..., -1.1398, -1.3523, -0.7851],\n", + " [-1.3725, -0.8212, 0.1984, ..., -1.8218, -1.4800, -0.2956],\n", + " ...,\n", + " [-0.5946, 0.5680, 0.8938, ..., -1.6653, 0.8218, 1.1902],\n", + " [ 1.2800, 1.9566, 0.2540, ..., -1.2290, 0.5257, 1.2667],\n", + " [ 1.3511, 1.3329, -0.0782, ..., -0.8454, -0.7400, 1.5966]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.8947, 0.0412, -1.2359, ..., 1.5140, -1.9812, -2.5532],\n", + " [-0.2951, -1.6086, -0.6381, ..., 0.4410, -0.3965, 0.0106]],\n", + "\n", + " [[-1.9196, 0.3326, 0.8482, ..., -2.3348, 1.3935, 1.1452],\n", + " [-0.5277, 0.1234, 0.7865, ..., -1.5790, -1.1817, -1.0156]],\n", + "\n", + " [[-2.1664, 0.3959, 0.7476, ..., -2.0817, 0.2852, 0.8173],\n", + " [-0.8414, 0.5154, -0.4553, ..., -2.1767, -0.6488, 0.1889]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.6009, 0.4887, -0.4818, ..., -1.8165, 1.4764, 0.5091],\n", + " [-0.6869, 0.4007, 
-1.5818, ..., -1.1268, 0.4111, 0.7892]],\n", + "\n", + " [[-0.1528, 1.1728, -0.5164, ..., -1.3611, 1.0621, 1.1810],\n", + " [-0.7595, -0.1699, -1.5305, ..., -0.4340, 0.1499, 1.6704]],\n", + "\n", + " [[ 1.0253, 1.4222, -0.1805, ..., -0.6989, 0.4721, 2.6129],\n", + " [-1.2381, -0.4573, -1.7561, ..., -0.6130, -0.5380, 1.6164]]],\n", + "\n", + "\n", + " [[[-0.6736, -0.0718, -1.1724, ..., 1.4816, -1.7920, -2.5177],\n", + " [-0.3929, -1.5120, -0.5353, ..., 0.2001, -0.5481, 0.0232]],\n", + "\n", + " [[-1.8698, -1.2184, 0.2913, ..., -1.5227, 1.9764, 0.6389],\n", + " [-0.4202, 0.4572, -1.0780, ..., -1.1398, -1.3523, -0.7851]],\n", + "\n", + " [[-1.3725, -0.8212, 0.1984, ..., -2.1553, 1.7041, 0.7166],\n", + " [-1.0124, 0.9351, -0.0954, ..., -1.8218, -1.4800, -0.2956]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.5946, 0.5680, 0.8938, ..., -2.1904, 1.7986, 1.0902],\n", + " [-1.3820, 1.0268, -1.0041, ..., -1.6653, 0.8218, 1.1902]],\n", + "\n", + " [[ 1.2800, 1.9566, 0.2540, ..., -1.6180, 1.6176, 2.5636],\n", + " [-2.0592, 0.7059, -1.3359, ..., -1.2290, 0.5257, 1.2667]],\n", + "\n", + " [[ 1.3511, 1.3329, -0.0782, ..., -0.5836, 0.6491, 2.6554],\n", + " [-1.3686, -0.2348, -1.7438, ..., -0.8454, -0.7400, 1.5966]]]],\n", + " grad_fn=)\n", + "tensor([[[ 3.8307e-01, 1.8837e-02, -1.0314e+00, ..., 6.4335e-01,\n", + " -3.1830e-01, -1.7296e+00],\n", + " [ 1.5897e+00, 1.3689e-01, -6.8915e-01, ..., 1.5973e+00,\n", + " 1.1907e+00, -9.0454e-01],\n", + " [-5.8138e-01, 6.7943e-01, -1.3203e+00, ..., -5.3627e-01,\n", + " -1.0456e+00, -1.6301e+00],\n", + " ...,\n", + " [-1.7466e-01, 3.0707e-02, 7.5225e-01, ..., -1.3217e+00,\n", + " -1.3415e+00, -3.8328e-01],\n", + " [ 1.5170e-01, 5.1089e-01, 1.3993e-01, ..., -1.6600e-01,\n", + " -6.5011e-01, 2.1798e-02],\n", + " [-2.4311e-01, 1.6726e+00, 1.6682e-01, ..., 1.3441e-03,\n", + " -1.6754e+00, 3.1771e-01]],\n", + "\n", + " [[ 4.7844e-01, -3.1772e-01, -1.0617e+00, ..., 6.4928e-01,\n", + " -3.2944e-01, -2.4185e+00],\n", + " [ 5.9122e-01, -2.2648e-01, 
1.7474e-01, ..., -1.8623e+00,\n", + " -1.1230e+00, 3.5013e-01],\n", + " [ 1.1642e-01, 1.2460e+00, 7.7941e-02, ..., 6.4975e-01,\n", + " -7.3862e-01, -2.1510e+00],\n", + " ...,\n", + " [ 1.1291e+00, -1.3637e+00, -1.5779e+00, ..., 1.7637e+00,\n", + " 9.1331e-01, -1.7033e+00],\n", + " [ 1.5909e+00, -1.4922e+00, 1.0060e+00, ..., 9.8096e-01,\n", + " 8.6736e-01, -2.3894e+00],\n", + " [ 1.8451e-01, 1.2740e+00, 4.2857e-01, ..., 6.2708e-01,\n", + " -1.3601e+00, -3.9984e-01]]], grad_fn=)\n", + "tensor([[[ 3.8307e-01, 1.8837e-02, -1.0314e+00, ..., 6.4335e-01,\n", + " -3.1830e-01, -1.7296e+00],\n", + " [ 1.5897e+00, 1.3689e-01, -6.8915e-01, ..., 1.5973e+00,\n", + " 1.1907e+00, -9.0454e-01],\n", + " [-5.8138e-01, 6.7943e-01, -1.3203e+00, ..., -5.3627e-01,\n", + " -1.0456e+00, -1.6301e+00],\n", + " ...,\n", + " [-1.7466e-01, 3.0707e-02, 7.5225e-01, ..., -1.3217e+00,\n", + " -1.3415e+00, -3.8328e-01],\n", + " [ 1.5170e-01, 5.1089e-01, 1.3993e-01, ..., -1.6600e-01,\n", + " -6.5011e-01, 2.1798e-02],\n", + " [-2.4311e-01, 1.6726e+00, 1.6682e-01, ..., 1.3441e-03,\n", + " -1.6754e+00, 3.1771e-01]],\n", + "\n", + " [[ 4.7844e-01, -3.1772e-01, -1.0617e+00, ..., 6.4928e-01,\n", + " -3.2944e-01, -2.4185e+00],\n", + " [ 5.9122e-01, -2.2648e-01, 1.7474e-01, ..., -1.8623e+00,\n", + " -1.1230e+00, 3.5013e-01],\n", + " [ 1.1642e-01, 1.2460e+00, 7.7941e-02, ..., 6.4975e-01,\n", + " -7.3862e-01, -2.1510e+00],\n", + " ...,\n", + " [ 1.1291e+00, -1.3637e+00, -1.5779e+00, ..., 1.7637e+00,\n", + " 9.1331e-01, -1.7033e+00],\n", + " [ 1.5909e+00, -1.4922e+00, 1.0060e+00, ..., 9.8096e-01,\n", + " 8.6736e-01, -2.3894e+00],\n", + " [ 1.8451e-01, 1.2740e+00, 4.2857e-01, ..., 6.2708e-01,\n", + " -1.3601e+00, -3.9984e-01]]], grad_fn=)\n", + "tensor([[[[ 3.8307e-01, 1.8837e-02, -1.0314e+00, ..., -8.0955e-01,\n", + " -2.8557e-01, -2.3318e-01],\n", + " [-9.9661e-02, 3.1722e-01, -3.0517e-01, ..., 6.4335e-01,\n", + " -3.1830e-01, -1.7296e+00]],\n", + "\n", + " [[ 1.5897e+00, 1.3689e-01, -6.8915e-01, ..., 
-7.9549e-01,\n", + " -6.9279e-01, -1.8082e-01],\n", + " [-9.9201e-01, 9.4938e-01, 4.4198e-02, ..., 1.5973e+00,\n", + " 1.1907e+00, -9.0454e-01]],\n", + "\n", + " [[-5.8138e-01, 6.7943e-01, -1.3203e+00, ..., -1.8905e+00,\n", + " 1.6226e-01, -1.2953e+00],\n", + " [-7.0312e-01, -6.4926e-01, -5.0913e-01, ..., -5.3627e-01,\n", + " -1.0456e+00, -1.6301e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.7466e-01, 3.0707e-02, 7.5225e-01, ..., -1.9281e+00,\n", + " 1.1489e+00, -2.4530e-01],\n", + " [-7.6225e-02, 8.5814e-01, -1.5467e+00, ..., -1.3217e+00,\n", + " -1.3415e+00, -3.8328e-01]],\n", + "\n", + " [[ 1.5170e-01, 5.1089e-01, 1.3993e-01, ..., -2.4168e+00,\n", + " 3.3385e-01, -6.2115e-02],\n", + " [-1.6390e+00, -1.6085e-01, -1.9118e+00, ..., -1.6600e-01,\n", + " -6.5011e-01, 2.1798e-02]],\n", + "\n", + " [[-2.4311e-01, 1.6726e+00, 1.6682e-01, ..., -1.0481e+00,\n", + " -2.7634e+00, 2.2741e-01],\n", + " [-1.4603e+00, 3.1239e-02, 3.8892e-01, ..., 1.3441e-03,\n", + " -1.6754e+00, 3.1771e-01]]],\n", + "\n", + "\n", + " [[[ 4.7844e-01, -3.1772e-01, -1.0617e+00, ..., -7.8511e-01,\n", + " -3.2510e-01, -2.1300e-01],\n", + " [-3.9893e-01, 6.1469e-01, -3.9206e-01, ..., 6.4928e-01,\n", + " -3.2944e-01, -2.4185e+00]],\n", + "\n", + " [[ 5.9122e-01, -2.2648e-01, 1.7474e-01, ..., -1.6488e+00,\n", + " -8.6854e-01, -8.0783e-01],\n", + " [-2.1516e+00, -2.4247e-01, -8.1713e-01, ..., -1.8623e+00,\n", + " -1.1230e+00, 3.5013e-01]],\n", + "\n", + " [[ 1.1642e-01, 1.2460e+00, 7.7941e-02, ..., -1.3121e+00,\n", + " -6.7044e-01, -1.1324e+00],\n", + " [ 4.3930e-01, 3.4082e-01, -1.2243e+00, ..., 6.4975e-01,\n", + " -7.3862e-01, -2.1510e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.1291e+00, -1.3637e+00, -1.5779e+00, ..., 8.5980e-01,\n", + " 3.2796e-01, -1.9442e+00],\n", + " [-8.9502e-01, 9.7357e-01, 8.5447e-01, ..., 1.7637e+00,\n", + " 9.1331e-01, -1.7033e+00]],\n", + "\n", + " [[ 1.5909e+00, -1.4922e+00, 1.0060e+00, ..., -1.5965e+00,\n", + " -3.9380e-01, -4.3585e-01],\n", + " [-2.2103e+00, 
4.4127e-01, 1.1554e+00, ..., 9.8096e-01,\n", + " 8.6736e-01, -2.3894e+00]],\n", + "\n", + " [[ 1.8451e-01, 1.2740e+00, 4.2857e-01, ..., -1.1366e+00,\n", + " -2.8409e+00, 4.6711e-01],\n", + " [-1.9576e+00, 2.0176e-01, -4.1035e-02, ..., 6.2708e-01,\n", + " -1.3601e+00, -3.9984e-01]]]], grad_fn=)\n", + "tensor([[[[ 3.7393e-01, 3.7078e-01, -7.5553e-01, ..., -9.4280e-01,\n", + " -6.5765e-01, -1.4711e-01],\n", + " [ 3.9396e-01, 4.7658e-02, -1.0277e+00, ..., -8.4926e-01,\n", + " -2.7815e-01, -2.6627e-01],\n", + " [ 9.9586e-01, 2.4414e-01, -8.3459e-01, ..., -8.5795e-01,\n", + " -4.3860e-01, -2.1582e-01],\n", + " ...,\n", + " [ 1.4504e+00, -1.2077e-01, -4.8160e-01, ..., -1.6701e+00,\n", + " 6.3807e-01, -8.3788e-02],\n", + " [ 3.8821e-01, 1.4150e-01, -1.5051e-01, ..., -1.5785e+00,\n", + " 4.1589e-01, -1.8024e-01],\n", + " [ 1.0467e-01, 5.5196e-01, 1.0818e-01, ..., -2.0373e+00,\n", + " 8.4395e-02, -6.9034e-02]],\n", + "\n", + " [[-7.7230e-01, -1.1852e-01, -8.0274e-02, ..., -1.3794e-03,\n", + " -5.2249e-01, -4.2095e-01],\n", + " [-4.6209e-01, -1.2680e-01, -2.5711e-01, ..., 1.3235e-01,\n", + " -4.1385e-01, -1.3744e+00],\n", + " [-1.7985e-01, -2.6037e-01, 3.5678e-01, ..., 2.4736e-01,\n", + " -1.6626e-01, -7.0940e-01],\n", + " ...,\n", + " [-6.2069e-01, 2.5298e-01, -7.8416e-01, ..., 8.6187e-02,\n", + " -6.9561e-01, -8.5675e-01],\n", + " [-9.1553e-01, 1.4648e-01, -5.5621e-02, ..., 1.8643e-01,\n", + " -1.0965e+00, -4.8097e-01],\n", + " [-1.0678e+00, 9.0885e-02, -4.5400e-02, ..., 1.5424e-01,\n", + " -1.1762e+00, -2.9385e-01]]],\n", + "\n", + "\n", + " [[[ 4.9445e-01, 9.6665e-03, -5.7117e-01, ..., -9.7640e-01,\n", + " -9.0948e-01, -1.3761e-01],\n", + " [ 4.7672e-01, -2.3672e-01, -8.9246e-01, ..., -8.9927e-01,\n", + " -3.9088e-01, -3.1635e-01],\n", + " [ 4.5120e-01, -5.8035e-02, -6.8736e-01, ..., -1.0352e+00,\n", + " -4.7996e-01, -4.5624e-01],\n", + " ...,\n", + " [ 3.3361e-01, 1.5918e-03, -1.0108e+00, ..., -1.5704e+00,\n", + " -3.9079e-01, -2.1742e-01],\n", + " [ 1.0123e+00, 
-1.1147e+00, -1.3032e+00, ..., 3.4082e-01,\n", + " 1.3298e-01, -1.5481e+00],\n", + " [ 1.4273e+00, -1.2951e+00, 5.7719e-01, ..., -1.2490e+00,\n", + " -4.1355e-01, -5.8558e-01]],\n", + "\n", + " [[-1.0884e+00, 2.0838e-01, -5.0948e-01, ..., 2.2845e-01,\n", + " -8.2075e-01, -1.1496e+00],\n", + " [-3.4296e-01, 5.0833e-01, -8.6644e-01, ..., 3.3008e-01,\n", + " -5.5070e-01, -1.8440e+00],\n", + " [-3.2602e-01, 6.0671e-01, -7.5711e-01, ..., 4.7493e-01,\n", + " -4.6589e-01, -2.0454e+00],\n", + " ...,\n", + " [-1.7399e+00, 4.6359e-01, 6.5213e-01, ..., 9.0025e-01,\n", + " 3.1643e-01, -2.0784e+00],\n", + " [-1.8375e+00, 2.6226e-01, 3.7726e-02, ..., 6.6286e-01,\n", + " -1.0588e+00, -7.6781e-01],\n", + " [-1.6402e+00, 3.1250e-01, -6.6194e-04, ..., 6.6493e-01,\n", + " -9.1110e-01, -1.0382e+00]]]],\n", + " grad_fn=)\n", + "tensor([[[[ 3.7393e-01, 3.7078e-01, -7.5553e-01, ..., -9.4280e-01,\n", + " -6.5765e-01, -1.4711e-01],\n", + " [-7.7230e-01, -1.1852e-01, -8.0274e-02, ..., -1.3794e-03,\n", + " -5.2249e-01, -4.2095e-01]],\n", + "\n", + " [[ 3.9396e-01, 4.7658e-02, -1.0277e+00, ..., -8.4926e-01,\n", + " -2.7815e-01, -2.6627e-01],\n", + " [-4.6209e-01, -1.2680e-01, -2.5711e-01, ..., 1.3235e-01,\n", + " -4.1385e-01, -1.3744e+00]],\n", + "\n", + " [[ 9.9586e-01, 2.4414e-01, -8.3459e-01, ..., -8.5795e-01,\n", + " -4.3860e-01, -2.1582e-01],\n", + " [-1.7985e-01, -2.6037e-01, 3.5678e-01, ..., 2.4736e-01,\n", + " -1.6626e-01, -7.0940e-01]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.4504e+00, -1.2077e-01, -4.8160e-01, ..., -1.6701e+00,\n", + " 6.3807e-01, -8.3788e-02],\n", + " [-6.2069e-01, 2.5298e-01, -7.8416e-01, ..., 8.6187e-02,\n", + " -6.9561e-01, -8.5675e-01]],\n", + "\n", + " [[ 3.8821e-01, 1.4150e-01, -1.5051e-01, ..., -1.5785e+00,\n", + " 4.1589e-01, -1.8024e-01],\n", + " [-9.1553e-01, 1.4648e-01, -5.5621e-02, ..., 1.8643e-01,\n", + " -1.0965e+00, -4.8097e-01]],\n", + "\n", + " [[ 1.0467e-01, 5.5196e-01, 1.0818e-01, ..., -2.0373e+00,\n", + " 8.4395e-02, -6.9034e-02],\n", + " 
[-1.0678e+00, 9.0885e-02, -4.5400e-02, ..., 1.5424e-01,\n", + " -1.1762e+00, -2.9385e-01]]],\n", + "\n", + "\n", + " [[[ 4.9445e-01, 9.6665e-03, -5.7117e-01, ..., -9.7640e-01,\n", + " -9.0948e-01, -1.3761e-01],\n", + " [-1.0884e+00, 2.0838e-01, -5.0948e-01, ..., 2.2845e-01,\n", + " -8.2075e-01, -1.1496e+00]],\n", + "\n", + " [[ 4.7672e-01, -2.3672e-01, -8.9246e-01, ..., -8.9927e-01,\n", + " -3.9088e-01, -3.1635e-01],\n", + " [-3.4296e-01, 5.0833e-01, -8.6644e-01, ..., 3.3008e-01,\n", + " -5.5070e-01, -1.8440e+00]],\n", + "\n", + " [[ 4.5120e-01, -5.8035e-02, -6.8736e-01, ..., -1.0352e+00,\n", + " -4.7996e-01, -4.5624e-01],\n", + " [-3.2602e-01, 6.0671e-01, -7.5711e-01, ..., 4.7493e-01,\n", + " -4.6589e-01, -2.0454e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[ 3.3361e-01, 1.5918e-03, -1.0108e+00, ..., -1.5704e+00,\n", + " -3.9079e-01, -2.1742e-01],\n", + " [-1.7399e+00, 4.6359e-01, 6.5213e-01, ..., 9.0025e-01,\n", + " 3.1643e-01, -2.0784e+00]],\n", + "\n", + " [[ 1.0123e+00, -1.1147e+00, -1.3032e+00, ..., 3.4082e-01,\n", + " 1.3298e-01, -1.5481e+00],\n", + " [-1.8375e+00, 2.6226e-01, 3.7726e-02, ..., 6.6286e-01,\n", + " -1.0588e+00, -7.6781e-01]],\n", + "\n", + " [[ 1.4273e+00, -1.2951e+00, 5.7719e-01, ..., -1.2490e+00,\n", + " -4.1355e-01, -5.8558e-01],\n", + " [-1.6402e+00, 3.1250e-01, -6.6194e-04, ..., 6.6493e-01,\n", + " -9.1110e-01, -1.0382e+00]]]], grad_fn=)\n", + "tensor([[0.1385, 0.0325],\n", + " [0.1314, 0.0993]], grad_fn=)\n", + "tensor([1, 0])\n" ] } ], @@ -94,11 +939,15 @@ "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[33mWARNING \u001b[0m \u001b[34mNode finfo not found in loaded metadata.\u001b[0m\n", - "\u001b[33mWARNING \u001b[0m \u001b[34mNode getattr_2 not found in loaded metadata.\u001b[0m\n" + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: '/home/ism/tutorial_2_lora.pt'", + "output_type": "error", + "traceback": [ + 
"\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mFileNotFoundError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 4\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpathlib\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Path\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mchop\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m MaseGraph\n\u001b[32m----> \u001b[39m\u001b[32m4\u001b[39m mg = \u001b[43mMaseGraph\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfrom_checkpoint\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43mf\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mPath\u001b[49m\u001b[43m.\u001b[49m\u001b[43mhome\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[33;43m/tutorial_2_lora\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/src/chop/ir/graph/mase_graph.py:384\u001b[39m, in \u001b[36mMaseGraph.from_checkpoint\u001b[39m\u001b[34m(cls, checkpoint, propagate_missing_metadata)\u001b[39m\n\u001b[32m 362\u001b[39m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[32m 363\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mfrom_checkpoint\u001b[39m(\n\u001b[32m 364\u001b[39m \u001b[38;5;28mcls\u001b[39m,\n\u001b[32m 365\u001b[39m checkpoint: \u001b[38;5;28mstr\u001b[39m,\n\u001b[32m 366\u001b[39m propagate_missing_metadata: \u001b[38;5;28mbool\u001b[39m = \u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[32m 367\u001b[39m ):\n\u001b[32m 368\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 369\u001b[39m \u001b[33;03m Load a MaseGraph from a checkpoint. 
A MaseGraph checkpoint consists of two files:\u001b[39;00m\n\u001b[32m 370\u001b[39m \u001b[33;03m {checkpoint}.pt and {checkpoint}.mz. {checkpoint}.pt contains the GraphModule,\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 382\u001b[39m \u001b[33;03m MaseGraph: Loaded MaseGraph.\u001b[39;00m\n\u001b[32m 383\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m384\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[33;43mf\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mcheckpoint\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[33;43m.pt\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mrb\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[32m 385\u001b[39m loaded_model = torch.load(f, weights_only=\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m 387\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\n\u001b[32m 388\u001b[39m loaded_model, fx.GraphModule\n\u001b[32m 389\u001b[39m ), \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mExpected fx.GraphModule, but received model: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(loaded_model)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n", + "\u001b[31mFileNotFoundError\u001b[39m: [Errno 2] No such file or directory: '/home/ism/tutorial_2_lora.pt'" ] } ], @@ -125,7 +974,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -203,7 +1052,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -247,7 +1096,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -292,7 +1141,7 @@ }, { "cell_type": 
"code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -327,7 +1176,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -466,7 +1315,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [ { diff --git a/docs/source/modules/documentation/tutorials/tutorial_4_pruning.ipynb b/docs/source/modules/documentation/tutorials/tutorial_4_pruning.ipynb index 79260257f..461b06b02 100644 --- a/docs/source/modules/documentation/tutorials/tutorial_4_pruning.ipynb +++ b/docs/source/modules/documentation/tutorials/tutorial_4_pruning.ipynb @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -51,19 +51,866 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py:61: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. 
If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.\n", + " import pynvml # type: ignore[import]\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", "`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting model.config.use_cache = False.\n", - "\u001b[32mINFO \u001b[0m \u001b[34mGetting dummy input for prajjwal1/bert-tiny.\u001b[0m\n", - "/Users/yz10513/anaconda3/envs/mase/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", - " warnings.warn(\n" + "\u001b[32mINFO \u001b[0m \u001b[34mGetting dummy input for prajjwal1/bert-tiny.\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 101, 9932, 2089, 2202, 2058, 1996, 2088, 2028, 2154, 102],\n", + " [ 101, 2023, 2003, 2339, 2017, 2323, 4553, 4748, 4877, 102]])\n", + "tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n", + "tensor([[ 101, 9932, 2089, 2202, 2058, 1996, 2088, 2028, 2154, 102],\n", + " [ 101, 2023, 2003, 2339, 2017, 2323, 4553, 4748, 4877, 102]])\n", + "tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])\n", + "tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])\n", + "tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])\n", + "tensor([[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]],\n", + "\n", + "\n", + " [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]])\n", + "tensor([[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", 
+ " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]],\n", + "\n", + "\n", + " [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]]])\n", + "tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])\n", + "tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 
0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])\n", + "tensor([[[ 0.7973, 0.0109, -8.8405, ..., 1.4170, 0.1046, -0.1551],\n", + " [-1.1766, 1.2879, -1.0986, ..., 0.4749, -0.5899, 0.8746],\n", + " [-2.0560, 0.7748, -0.8909, ..., -0.4034, 0.5352, -1.3657],\n", + " ...,\n", + " [ 0.2317, -0.7896, 0.9634, ..., -0.8037, 0.4834, -0.5868],\n", + " [ 0.0243, -1.0235, -1.2771, ..., -2.2378, 1.8530, 0.1558],\n", + " [-1.3637, 0.7055, -0.2177, ..., 0.3557, -0.3971, -0.3107]],\n", + "\n", + " [[ 0.7973, 0.0109, -8.8405, ..., 1.4170, 0.1046, -0.1551],\n", + " [-2.6940, 0.6198, -0.4564, ..., -1.4367, -1.5705, -3.1260],\n", + " [-1.7524, 0.8535, -0.2155, ..., -0.5222, -1.2430, -1.7199],\n", + " ...,\n", + " [-0.0347, 0.7446, 1.4462, ..., -1.1578, -2.6197, 0.2612],\n", + " [ 2.4334, -0.3068, 0.8250, ..., 0.1475, 0.1790, 2.2907],\n", + " [-1.3637, 0.7055, -0.2177, ..., 0.3557, -0.3971, -0.3107]]],\n", + " grad_fn=)\n", + "tensor([[[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01],\n", + " [-3.5020e-02, -6.1047e-02, -1.0465e-01, ..., -8.1892e-01,\n", + " 1.1978e+00, 2.1808e+00],\n", + " [ 4.3982e-01, -1.9602e+00, -6.8830e-01, ..., -6.3025e-01,\n", + " -1.5967e-01, 1.3284e+00],\n", + " ...,\n", + " [ 
1.4783e+00, 1.0907e-01, -1.5222e+00, ..., -3.0983e-01,\n", + " -1.2971e-01, 1.1265e+00],\n", + " [ 1.5890e+00, -1.6859e+00, 7.8703e-01, ..., -1.3174e+00,\n", + " 2.2258e-01, 8.8157e-01],\n", + " [-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]],\n", + "\n", + " [[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01],\n", + " [-4.4781e-01, -7.9224e-01, -2.1741e+00, ..., -5.9181e-01,\n", + " 1.4373e+00, 2.4267e+00],\n", + " [-2.5942e-01, 9.7163e-01, -3.2928e+00, ..., -5.9773e-01,\n", + " -3.0482e-01, 1.4038e+00],\n", + " ...,\n", + " [ 5.1575e-02, 3.5218e-01, -3.8926e-01, ..., -1.1508e+00,\n", + " 7.5490e-01, 8.2911e-01],\n", + " [ 1.6107e+00, 6.8170e-02, 9.2537e-01, ..., -1.5233e+00,\n", + " -6.0733e-01, 3.3097e-01],\n", + " [-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]]], grad_fn=)\n", + "tensor([[[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01],\n", + " [-3.5020e-02, -6.1047e-02, -1.0465e-01, ..., -8.1892e-01,\n", + " 1.1978e+00, 2.1808e+00],\n", + " [ 4.3982e-01, -1.9602e+00, -6.8830e-01, ..., -6.3025e-01,\n", + " -1.5967e-01, 1.3284e+00],\n", + " ...,\n", + " [ 1.4783e+00, 1.0907e-01, -1.5222e+00, ..., -3.0983e-01,\n", + " -1.2971e-01, 1.1265e+00],\n", + " [ 1.5890e+00, -1.6859e+00, 7.8703e-01, ..., -1.3174e+00,\n", + " 2.2258e-01, 8.8157e-01],\n", + " [-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]],\n", + "\n", + " [[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01],\n", + " [-4.4781e-01, -7.9224e-01, -2.1741e+00, ..., -5.9181e-01,\n", + " 1.4373e+00, 2.4267e+00],\n", + " [-2.5942e-01, 9.7163e-01, -3.2928e+00, ..., -5.9773e-01,\n", + " -3.0482e-01, 1.4038e+00],\n", + " ...,\n", + " [ 5.1575e-02, 3.5218e-01, -3.8926e-01, ..., -1.1508e+00,\n", + " 7.5490e-01, 8.2911e-01],\n", + " [ 1.6107e+00, 6.8170e-02, 9.2537e-01, ..., 
-1.5233e+00,\n", + " -6.0733e-01, 3.3097e-01],\n", + " [-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]]], grad_fn=)\n", + "tensor([[[[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., -5.2614e-01,\n", + " -3.5687e-01, -2.6793e-01],\n", + " [ 2.0335e-01, -5.4534e-01, 3.0686e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01]],\n", + "\n", + " [[-3.5020e-02, -6.1047e-02, -1.0465e-01, ..., -2.0138e+00,\n", + " 4.5529e-01, -7.8171e-01],\n", + " [ 1.1969e+00, 1.6337e+00, 2.5047e-01, ..., -8.1892e-01,\n", + " 1.1978e+00, 2.1808e+00]],\n", + "\n", + " [[ 4.3982e-01, -1.9602e+00, -6.8830e-01, ..., -2.2501e-01,\n", + " 7.2290e-02, -1.8290e+00],\n", + " [ 8.9952e-01, 1.0029e+00, 7.4536e-04, ..., -6.3025e-01,\n", + " -1.5967e-01, 1.3284e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.4783e+00, 1.0907e-01, -1.5222e+00, ..., 1.1867e+00,\n", + " -1.3561e+00, 6.5158e-01],\n", + " [ 9.5466e-01, 4.5887e-01, 7.8078e-01, ..., -3.0983e-01,\n", + " -1.2971e-01, 1.1265e+00]],\n", + "\n", + " [[ 1.5890e+00, -1.6859e+00, 7.8703e-01, ..., 6.5467e-01,\n", + " -6.8451e-01, 6.5081e-01],\n", + " [ 7.0729e-01, 1.4499e+00, -1.5089e-01, ..., -1.3174e+00,\n", + " 2.2258e-01, 8.8157e-01]],\n", + "\n", + " [[-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., 3.3130e-01,\n", + " -3.2756e-01, -6.3130e-01],\n", + " [ 8.6773e-01, 2.0996e-01, -3.4332e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]]],\n", + "\n", + "\n", + " [[[ 9.7740e-01, 2.5482e-03, -5.2921e-01, ..., -5.2614e-01,\n", + " -3.5687e-01, -2.6793e-01],\n", + " [ 2.0335e-01, -5.4534e-01, 3.0686e-01, ..., 1.4757e-01,\n", + " 1.8900e-01, 2.8282e-01]],\n", + "\n", + " [[-4.4781e-01, -7.9224e-01, -2.1741e+00, ..., 1.7508e+00,\n", + " -3.6708e-01, -1.3251e+00],\n", + " [ 7.9208e-01, -1.3537e-01, 2.3756e-01, ..., -5.9181e-01,\n", + " 1.4373e+00, 2.4267e+00]],\n", + "\n", + " [[-2.5942e-01, 9.7163e-01, -3.2928e+00, ..., -9.6646e-01,\n", + " -4.8876e-01, -1.4426e+00],\n", + " [ 1.0250e+00, -6.9093e-01, 
-1.2734e+00, ..., -5.9773e-01,\n", + " -3.0482e-01, 1.4038e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[ 5.1575e-02, 3.5218e-01, -3.8926e-01, ..., -1.2252e-02,\n", + " 1.0394e+00, 4.2402e-01],\n", + " [-4.7386e-01, 2.6401e+00, 1.7024e+00, ..., -1.1508e+00,\n", + " 7.5490e-01, 8.2911e-01]],\n", + "\n", + " [[ 1.6107e+00, 6.8170e-02, 9.2537e-01, ..., -6.1665e-01,\n", + " 2.7627e-01, -1.2083e+00],\n", + " [ 9.3395e-01, -9.7541e-01, -2.5442e-02, ..., -1.5233e+00,\n", + " -6.0733e-01, 3.3097e-01]],\n", + "\n", + " [[-3.7517e-01, 1.5191e+00, -2.6796e-01, ..., 3.3130e-01,\n", + " -3.2756e-01, -6.3130e-01],\n", + " [ 8.6773e-01, 2.0996e-01, -3.4332e-01, ..., -1.6159e+00,\n", + " 7.2677e-02, 1.1724e-01]]]], grad_fn=)\n", + "tensor([[[-0.1709, 0.5230, -0.8713, ..., -1.3382, 0.5892, 0.4026],\n", + " [-0.5842, 0.9588, 1.5642, ..., -1.0731, -0.7330, 0.3132],\n", + " [-0.8601, -1.3756, 0.5042, ..., -0.0476, 0.2650, 1.2150],\n", + " ...,\n", + " [ 0.0520, 1.1719, -1.5471, ..., -0.7894, 0.1419, 1.6964],\n", + " [ 0.7654, -1.5053, -0.4142, ..., -1.4622, -0.8975, 1.4576],\n", + " [-1.2008, -0.6008, -1.4608, ..., -1.2105, -0.4289, 0.3827]],\n", + "\n", + " [[-0.1709, 0.5230, -0.8713, ..., -1.3382, 0.5892, 0.4026],\n", + " [-1.3806, 0.2626, -0.5207, ..., -1.6714, -0.0554, 1.0225],\n", + " [-1.7116, 1.8788, -2.5695, ..., -0.6958, 0.5728, 0.5461],\n", + " ...,\n", + " [-1.3246, 1.2196, -0.3034, ..., -1.1955, -0.6708, 0.5128],\n", + " [ 0.9854, 0.8260, 0.2892, ..., -0.6428, 0.3637, 0.4339],\n", + " [-1.2008, -0.6008, -1.4608, ..., -1.2105, -0.4289, 0.3827]]],\n", + " grad_fn=)\n", + "tensor([[[-0.1709, 0.5230, -0.8713, ..., -1.3382, 0.5892, 0.4026],\n", + " [-0.5842, 0.9588, 1.5642, ..., -1.0731, -0.7330, 0.3132],\n", + " [-0.8601, -1.3756, 0.5042, ..., -0.0476, 0.2650, 1.2150],\n", + " ...,\n", + " [ 0.0520, 1.1719, -1.5471, ..., -0.7894, 0.1419, 1.6964],\n", + " [ 0.7654, -1.5053, -0.4142, ..., -1.4622, -0.8975, 1.4576],\n", + " [-1.2008, -0.6008, -1.4608, ..., -1.2105, -0.4289, 
0.3827]],\n", + "\n", + " [[-0.1709, 0.5230, -0.8713, ..., -1.3382, 0.5892, 0.4026],\n", + " [-1.3806, 0.2626, -0.5207, ..., -1.6714, -0.0554, 1.0225],\n", + " [-1.7116, 1.8788, -2.5695, ..., -0.6958, 0.5728, 0.5461],\n", + " ...,\n", + " [-1.3246, 1.2196, -0.3034, ..., -1.1955, -0.6708, 0.5128],\n", + " [ 0.9854, 0.8260, 0.2892, ..., -0.6428, 0.3637, 0.4339],\n", + " [-1.2008, -0.6008, -1.4608, ..., -1.2105, -0.4289, 0.3827]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.1709, 0.5230, -0.8713, ..., 0.4365, 0.6238, -0.9414],\n", + " [-1.3731, 1.1521, 0.1321, ..., -1.3382, 0.5892, 0.4026]],\n", + "\n", + " [[-0.5842, 0.9588, 1.5642, ..., -1.5431, 0.4999, -1.1350],\n", + " [ 0.9615, 0.8694, 0.0998, ..., -1.0731, -0.7330, 0.3132]],\n", + "\n", + " [[-0.8601, -1.3756, 0.5042, ..., 0.9764, -0.8321, -1.0204],\n", + " [ 1.5175, 1.1454, 0.7791, ..., -0.0476, 0.2650, 1.2150]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.0520, 1.1719, -1.5471, ..., 1.9402, -1.1294, 0.4793],\n", + " [ 1.0053, 0.8099, 1.6415, ..., -0.7894, 0.1419, 1.6964]],\n", + "\n", + " [[ 0.7654, -1.5053, -0.4142, ..., 1.7455, -0.7326, 1.5248],\n", + " [ 1.0806, 1.1457, 2.2163, ..., -1.4622, -0.8975, 1.4576]],\n", + "\n", + " [[-1.2008, -0.6008, -1.4608, ..., 2.0905, 1.8849, -1.5708],\n", + " [ 1.9999, 0.3493, -0.8524, ..., -1.2105, -0.4289, 0.3827]]],\n", + "\n", + "\n", + " [[[-0.1709, 0.5230, -0.8713, ..., 0.4365, 0.6238, -0.9414],\n", + " [-1.3731, 1.1521, 0.1321, ..., -1.3382, 0.5892, 0.4026]],\n", + "\n", + " [[-1.3806, 0.2626, -0.5207, ..., 1.6517, -0.2316, -1.3171],\n", + " [ 0.6812, -0.0090, 0.3803, ..., -1.6714, -0.0554, 1.0225]],\n", + "\n", + " [[-1.7116, 1.8788, -2.5695, ..., 0.4927, -0.4850, -1.0645],\n", + " [ 1.2646, 1.6481, 0.9055, ..., -0.6958, 0.5728, 0.5461]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.3246, 1.2196, -0.3034, ..., 1.2747, 1.2353, 0.2825],\n", + " [ 1.5373, 0.8648, 0.6062, ..., -1.1955, -0.6708, 0.5128]],\n", + "\n", + " [[ 0.9854, 0.8260, 0.2892, ..., 1.3848, -0.0103, 
-1.0700],\n", + " [ 1.3827, 2.9809, 0.0276, ..., -0.6428, 0.3637, 0.4339]],\n", + "\n", + " [[-1.2008, -0.6008, -1.4608, ..., 2.0905, 1.8849, -1.5708],\n", + " [ 1.9999, 0.3493, -0.8524, ..., -1.2105, -0.4289, 0.3827]]]],\n", + " grad_fn=)\n", + "tensor([[[-0.0123, 0.5761, 0.2209, ..., -0.1027, 1.1061, -2.5200],\n", + " [-1.1465, -1.5578, -0.6984, ..., 1.0310, 0.4824, -0.2291],\n", + " [-1.0361, -1.8192, -2.3055, ..., 1.5286, -1.5941, 1.1762],\n", + " ...,\n", + " [-0.7992, 0.0886, 0.4887, ..., -1.7941, 0.4835, 1.3780],\n", + " [-1.4692, -0.9135, -0.2802, ..., -0.9691, 0.3500, 1.8863],\n", + " [-0.5760, -0.0452, 0.4230, ..., -0.7179, -0.7858, 1.6879]],\n", + "\n", + " [[-0.0123, 0.5761, 0.2209, ..., -0.1027, 1.1061, -2.5200],\n", + " [-0.3700, -1.9754, -0.7315, ..., 0.2293, 0.6996, 3.1299],\n", + " [-0.6252, 0.2879, -1.4036, ..., -2.0560, -2.4623, -0.9584],\n", + " ...,\n", + " [-1.1306, -1.4343, -1.4422, ..., -1.6115, -0.0475, 1.3975],\n", + " [-0.9816, -1.4909, -1.0086, ..., -0.9284, 0.5260, 1.5330],\n", + " [-0.5760, -0.0452, 0.4230, ..., -0.7179, -0.7858, 1.6879]]],\n", + " grad_fn=)\n", + "tensor([[[-0.0123, 0.5761, 0.2209, ..., -0.1027, 1.1061, -2.5200],\n", + " [-1.1465, -1.5578, -0.6984, ..., 1.0310, 0.4824, -0.2291],\n", + " [-1.0361, -1.8192, -2.3055, ..., 1.5286, -1.5941, 1.1762],\n", + " ...,\n", + " [-0.7992, 0.0886, 0.4887, ..., -1.7941, 0.4835, 1.3780],\n", + " [-1.4692, -0.9135, -0.2802, ..., -0.9691, 0.3500, 1.8863],\n", + " [-0.5760, -0.0452, 0.4230, ..., -0.7179, -0.7858, 1.6879]],\n", + "\n", + " [[-0.0123, 0.5761, 0.2209, ..., -0.1027, 1.1061, -2.5200],\n", + " [-0.3700, -1.9754, -0.7315, ..., 0.2293, 0.6996, 3.1299],\n", + " [-0.6252, 0.2879, -1.4036, ..., -2.0560, -2.4623, -0.9584],\n", + " ...,\n", + " [-1.1306, -1.4343, -1.4422, ..., -1.6115, -0.0475, 1.3975],\n", + " [-0.9816, -1.4909, -1.0086, ..., -0.9284, 0.5260, 1.5330],\n", + " [-0.5760, -0.0452, 0.4230, ..., -0.7179, -0.7858, 1.6879]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.0123, 
0.5761, 0.2209, ..., -0.1457, -0.7538, 0.1761],\n", + " [-0.0705, 0.9215, 0.7990, ..., -0.1027, 1.1061, -2.5200]],\n", + "\n", + " [[-1.1465, -1.5578, -0.6984, ..., 0.0289, -2.1112, -0.8728],\n", + " [ 0.6506, -1.6966, 1.4463, ..., 1.0310, 0.4824, -0.2291]],\n", + "\n", + " [[-1.0361, -1.8192, -2.3055, ..., -0.2195, -1.1732, 0.3182],\n", + " [-0.5841, -0.0227, 3.0901, ..., 1.5286, -1.5941, 1.1762]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.7992, 0.0886, 0.4887, ..., 0.7859, -1.0127, -0.2676],\n", + " [-0.3055, 0.6270, -3.0705, ..., -1.7941, 0.4835, 1.3780]],\n", + "\n", + " [[-1.4692, -0.9135, -0.2802, ..., 0.1197, -0.7532, 0.0731],\n", + " [ 0.6096, -1.0893, -0.6959, ..., -0.9691, 0.3500, 1.8863]],\n", + "\n", + " [[-0.5760, -0.0452, 0.4230, ..., 0.8851, 0.3078, 0.8106],\n", + " [-1.1804, 0.9512, 0.3169, ..., -0.7179, -0.7858, 1.6879]]],\n", + "\n", + "\n", + " [[[-0.0123, 0.5761, 0.2209, ..., -0.1457, -0.7538, 0.1761],\n", + " [-0.0705, 0.9215, 0.7990, ..., -0.1027, 1.1061, -2.5200]],\n", + "\n", + " [[-0.3700, -1.9754, -0.7315, ..., 0.5756, -1.5559, 0.0326],\n", + " [ 1.4229, 2.3970, -0.4516, ..., 0.2293, 0.6996, 3.1299]],\n", + "\n", + " [[-0.6252, 0.2879, -1.4036, ..., 0.5306, -0.5608, 1.1861],\n", + " [-2.5980, 0.2673, 3.3016, ..., -2.0560, -2.4623, -0.9584]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.1306, -1.4343, -1.4422, ..., 0.3918, -1.5336, -0.5026],\n", + " [ 1.8587, 0.8501, -1.2402, ..., -1.6115, -0.0475, 1.3975]],\n", + "\n", + " [[-0.9816, -1.4909, -1.0086, ..., 0.2956, 0.0351, -1.0685],\n", + " [-0.6594, -0.0133, -1.1863, ..., -0.9284, 0.5260, 1.5330]],\n", + "\n", + " [[-0.5760, -0.0452, 0.4230, ..., 0.8851, 0.3078, 0.8106],\n", + " [-1.1804, 0.9512, 0.3169, ..., -0.7179, -0.7858, 1.6879]]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.5911, -0.4682, -0.4314, ..., 0.0366, -1.0405, -0.0579],\n", + " [-1.1141, -1.4785, -0.6477, ..., 0.0631, -1.9950, -0.7912],\n", + " [-0.7059, -0.9954, -1.3923, ..., -0.1557, -0.9998, 0.2774],\n", + " ...,\n", + " 
[-0.6300, 0.1252, 0.3486, ..., 0.5692, -0.9417, -0.1399],\n", + " [-1.0593, -0.6395, -0.3343, ..., 0.0571, -0.7842, 0.1042],\n", + " [-0.2081, 0.2785, 0.0613, ..., -0.0353, -0.8389, -0.0026]],\n", + "\n", + " [[-0.3266, 0.6360, 0.0214, ..., 0.1677, 0.3883, 0.4382],\n", + " [-0.6299, 0.6012, 0.7379, ..., 0.2989, -0.1569, 0.9508],\n", + " [-0.4097, 0.6374, -0.1589, ..., 0.0258, -0.0364, 0.9990],\n", + " ...,\n", + " [-0.0088, 0.1378, -0.4819, ..., -0.1261, 0.2908, 1.0980],\n", + " [-0.5584, 0.9932, -0.1105, ..., 0.2486, 0.4005, 0.6046],\n", + " [-0.4516, 0.8092, -0.0513, ..., 0.1182, 0.4294, 0.4781]]],\n", + "\n", + "\n", + " [[[-0.6850, -0.2168, -0.6087, ..., 0.1402, -0.7267, -0.1502],\n", + " [-0.2084, -0.1325, -0.1151, ..., 0.0938, -0.8963, 0.1296],\n", + " [-0.3665, 0.3408, -0.6133, ..., 0.1938, -0.6681, 0.5929],\n", + " ...,\n", + " [-1.0985, -1.2861, -1.2531, ..., 0.3609, -1.4022, -0.4444],\n", + " [-0.9787, -1.4207, -0.9908, ..., 0.2989, -0.0159, -1.0060],\n", + " [-0.1970, 0.2371, -0.0115, ..., 0.0049, -0.8135, 0.1201]],\n", + "\n", + " [[ 0.0907, 0.7927, -0.0524, ..., -0.4610, -0.4295, 0.4391],\n", + " [-0.2352, 1.0586, -0.1117, ..., -0.4943, -0.6546, 0.6617],\n", + " [ 0.8269, 1.7861, -1.1207, ..., -0.0885, 0.2791, 1.4721],\n", + " ...,\n", + " [-0.1038, 0.9623, -0.7062, ..., -0.3340, -0.3050, 0.8792],\n", + " [ 0.0507, 0.6634, -0.2642, ..., -0.4959, -0.7944, 0.7687],\n", + " [ 0.3364, 0.9047, -0.2037, ..., -0.3705, -0.2893, 0.6787]]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.5911, -0.4682, -0.4314, ..., 0.0366, -1.0405, -0.0579],\n", + " [-0.3266, 0.6360, 0.0214, ..., 0.1677, 0.3883, 0.4382]],\n", + "\n", + " [[-1.1141, -1.4785, -0.6477, ..., 0.0631, -1.9950, -0.7912],\n", + " [-0.6299, 0.6012, 0.7379, ..., 0.2989, -0.1569, 0.9508]],\n", + "\n", + " [[-0.7059, -0.9954, -1.3923, ..., -0.1557, -0.9998, 0.2774],\n", + " [-0.4097, 0.6374, -0.1589, ..., 0.0258, -0.0364, 0.9990]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.6300, 0.1252, 0.3486, ..., 0.5692, 
-0.9417, -0.1399],\n", + " [-0.0088, 0.1378, -0.4819, ..., -0.1261, 0.2908, 1.0980]],\n", + "\n", + " [[-1.0593, -0.6395, -0.3343, ..., 0.0571, -0.7842, 0.1042],\n", + " [-0.5584, 0.9932, -0.1105, ..., 0.2486, 0.4005, 0.6046]],\n", + "\n", + " [[-0.2081, 0.2785, 0.0613, ..., -0.0353, -0.8389, -0.0026],\n", + " [-0.4516, 0.8092, -0.0513, ..., 0.1182, 0.4294, 0.4781]]],\n", + "\n", + "\n", + " [[[-0.6850, -0.2168, -0.6087, ..., 0.1402, -0.7267, -0.1502],\n", + " [ 0.0907, 0.7927, -0.0524, ..., -0.4610, -0.4295, 0.4391]],\n", + "\n", + " [[-0.2084, -0.1325, -0.1151, ..., 0.0938, -0.8963, 0.1296],\n", + " [-0.2352, 1.0586, -0.1117, ..., -0.4943, -0.6546, 0.6617]],\n", + "\n", + " [[-0.3665, 0.3408, -0.6133, ..., 0.1938, -0.6681, 0.5929],\n", + " [ 0.8269, 1.7861, -1.1207, ..., -0.0885, 0.2791, 1.4721]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.0985, -1.2861, -1.2531, ..., 0.3609, -1.4022, -0.4444],\n", + " [-0.1038, 0.9623, -0.7062, ..., -0.3340, -0.3050, 0.8792]],\n", + "\n", + " [[-0.9787, -1.4207, -0.9908, ..., 0.2989, -0.0159, -1.0060],\n", + " [ 0.0507, 0.6634, -0.2642, ..., -0.4959, -0.7944, 0.7687]],\n", + "\n", + " [[-0.1970, 0.2371, -0.0115, ..., 0.0049, -0.8135, 0.1201],\n", + " [ 0.3364, 0.9047, -0.2037, ..., -0.3705, -0.2893, 0.6787]]]],\n", + " grad_fn=)\n", + "tensor([[[-0.9552, 0.6594, -6.5403, ..., -0.7144, 0.0906, 0.3369],\n", + " [-2.5251, 1.3955, -0.8914, ..., -2.1363, 0.0271, 1.1132],\n", + " [-3.7148, 0.6796, -0.8710, ..., -2.6492, 0.5694, -0.1085],\n", + " ...,\n", + " [-2.2403, -0.7594, 0.5414, ..., -3.0426, 0.8895, -0.0546],\n", + " [-1.6945, -0.6326, -0.8632, ..., -4.0678, 1.7219, 0.6481],\n", + " [-2.9625, 0.7451, -0.8037, ..., -2.5048, 0.3125, 0.5537]],\n", + "\n", + " [[-0.5150, 0.8150, -6.5015, ..., -0.5377, -0.4171, 0.1350],\n", + " [-2.9979, 1.0930, -0.2619, ..., -3.1811, -1.0048, -1.8349],\n", + " [-2.8788, 0.5405, -0.0789, ..., -2.3969, -0.7016, -0.7332],\n", + " ...,\n", + " [-1.7194, 1.5158, 1.0070, ..., -2.8931, -2.3309, 
1.1685],\n", + " [ 0.0717, -0.1039, 0.5084, ..., -2.1932, 0.0751, 2.8236],\n", + " [-2.4774, 0.7563, -0.7502, ..., -2.1312, -0.0685, 0.8700]]],\n", + " grad_fn=)\n", + "tensor([[[-0.0455, 0.6529, 0.6297, ..., 0.4139, -0.9381, 0.6769],\n", + " [-3.1104, -3.7282, -2.3953, ..., -0.9155, -0.8280, -1.7070],\n", + " [-0.9324, -2.9333, -2.3249, ..., -0.8455, -0.0326, -0.6998],\n", + " ...,\n", + " [-1.9000, -1.1028, -1.1281, ..., -0.2809, 2.0206, -1.0802],\n", + " [-1.1088, -1.0420, -2.4026, ..., -0.4478, 0.7391, -0.0354],\n", + " [ 0.7349, 0.6742, -2.6697, ..., -0.5114, 1.5155, 2.0246]],\n", + "\n", + " [[-0.0799, 0.7813, 0.4918, ..., 0.6888, -0.7680, 0.9805],\n", + " [-2.8720, -1.0602, -2.3610, ..., -2.1143, 0.9664, -1.1212],\n", + " [-1.4705, -2.1384, -1.9955, ..., -0.9722, 1.5909, -0.1668],\n", + " ...,\n", + " [-2.9884, -1.1566, -2.5215, ..., 1.1460, 0.7120, -0.6320],\n", + " [-3.3666, -0.7966, -3.3154, ..., 0.5316, 1.7058, 2.1950],\n", + " [ 0.6626, 0.8537, -2.7251, ..., -0.0901, 1.5883, 2.3840]]],\n", + " grad_fn=)\n", + "tensor([[[-0.0455, 0.6529, 0.6297, ..., 0.4139, -0.9381, 0.6769],\n", + " [-3.1104, -3.7282, -2.3953, ..., -0.9155, -0.8280, -1.7070],\n", + " [-0.9324, -2.9333, -2.3249, ..., -0.8455, -0.0326, -0.6998],\n", + " ...,\n", + " [-1.9000, -1.1028, -1.1281, ..., -0.2809, 2.0206, -1.0802],\n", + " [-1.1088, -1.0420, -2.4026, ..., -0.4478, 0.7391, -0.0354],\n", + " [ 0.7349, 0.6742, -2.6697, ..., -0.5114, 1.5155, 2.0246]],\n", + "\n", + " [[-0.0799, 0.7813, 0.4918, ..., 0.6888, -0.7680, 0.9805],\n", + " [-2.8720, -1.0602, -2.3610, ..., -2.1143, 0.9664, -1.1212],\n", + " [-1.4705, -2.1384, -1.9955, ..., -0.9722, 1.5909, -0.1668],\n", + " ...,\n", + " [-2.9884, -1.1566, -2.5215, ..., 1.1460, 0.7120, -0.6320],\n", + " [-3.3666, -0.7966, -3.3154, ..., 0.5316, 1.7058, 2.1950],\n", + " [ 0.6626, 0.8537, -2.7251, ..., -0.0901, 1.5883, 2.3840]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.0455, 0.6529, 0.6297, ..., 1.0165, -1.6055, 0.0557],\n", + " [-0.9141, 
-1.5101, -0.8415, ..., 0.4139, -0.9381, 0.6769]],\n", + "\n", + " [[-3.1104, -3.7282, -2.3953, ..., -1.6195, 0.7426, -3.2794],\n", + " [-1.2694, 0.3821, -0.5687, ..., -0.9155, -0.8280, -1.7070]],\n", + "\n", + " [[-0.9324, -2.9333, -2.3249, ..., -1.0254, 1.8158, -1.8835],\n", + " [-1.5265, -0.3901, 0.2734, ..., -0.8455, -0.0326, -0.6998]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.9000, -1.1028, -1.1281, ..., -1.2688, -0.0851, -2.3190],\n", + " [-2.4374, 0.0718, -2.7276, ..., -0.2809, 2.0206, -1.0802]],\n", + "\n", + " [[-1.1088, -1.0420, -2.4026, ..., -1.0658, 0.1932, -1.7012],\n", + " [-2.3622, -0.5291, -1.9931, ..., -0.4478, 0.7391, -0.0354]],\n", + "\n", + " [[ 0.7349, 0.6742, -2.6697, ..., -1.4630, -0.1686, -2.5682],\n", + " [-0.1401, -0.9712, -2.3801, ..., -0.5114, 1.5155, 2.0246]]],\n", + "\n", + "\n", + " [[[-0.0799, 0.7813, 0.4918, ..., 1.2364, -1.9500, -0.1275],\n", + " [-0.4080, -1.5069, -0.8504, ..., 0.6888, -0.7680, 0.9805]],\n", + "\n", + " [[-2.8720, -1.0602, -2.3610, ..., -2.3225, -0.0351, -2.7432],\n", + " [-0.2305, -0.5940, -1.1570, ..., -2.1143, 0.9664, -1.1212]],\n", + "\n", + " [[-1.4705, -2.1384, -1.9955, ..., -0.6524, -1.8025, -1.8321],\n", + " [-1.7742, -0.6800, -0.2172, ..., -0.9722, 1.5909, -0.1668]],\n", + "\n", + " ...,\n", + "\n", + " [[-2.9884, -1.1566, -2.5215, ..., -0.5054, -1.0314, -3.4883],\n", + " [-1.9535, 0.5573, -2.1564, ..., 1.1460, 0.7120, -0.6320]],\n", + "\n", + " [[-3.3666, -0.7966, -3.3154, ..., 0.7587, -0.6289, -3.4848],\n", + " [-1.4099, -2.0919, -1.5870, ..., 0.5316, 1.7058, 2.1950]],\n", + "\n", + " [[ 0.6626, 0.8537, -2.7251, ..., -1.1831, -0.7083, -2.7717],\n", + " [ 0.4486, -1.1639, -2.1203, ..., -0.0901, 1.5883, 2.3840]]]],\n", + " grad_fn=)\n", + "tensor([[[-0.8947, 0.0412, -1.2359, ..., 0.4410, -0.3965, 0.0106],\n", + " [-1.9196, 0.3326, 0.8482, ..., -1.5790, -1.1817, -1.0156],\n", + " [-2.1664, 0.3959, 0.7476, ..., -2.1767, -0.6488, 0.1889],\n", + " ...,\n", + " [-1.6009, 0.4887, -0.4818, ..., -1.1268, 0.4111, 
0.7892],\n", + " [-0.1528, 1.1728, -0.5164, ..., -0.4340, 0.1499, 1.6704],\n", + " [ 1.0253, 1.4222, -0.1805, ..., -0.6130, -0.5380, 1.6164]],\n", + "\n", + " [[-0.6736, -0.0718, -1.1724, ..., 0.2001, -0.5481, 0.0232],\n", + " [-1.8698, -1.2184, 0.2913, ..., -1.1398, -1.3523, -0.7851],\n", + " [-1.3725, -0.8212, 0.1984, ..., -1.8218, -1.4800, -0.2956],\n", + " ...,\n", + " [-0.5946, 0.5680, 0.8938, ..., -1.6653, 0.8218, 1.1902],\n", + " [ 1.2800, 1.9566, 0.2540, ..., -1.2290, 0.5257, 1.2667],\n", + " [ 1.3511, 1.3329, -0.0782, ..., -0.8454, -0.7400, 1.5966]]],\n", + " grad_fn=)\n", + "tensor([[[-0.8947, 0.0412, -1.2359, ..., 0.4410, -0.3965, 0.0106],\n", + " [-1.9196, 0.3326, 0.8482, ..., -1.5790, -1.1817, -1.0156],\n", + " [-2.1664, 0.3959, 0.7476, ..., -2.1767, -0.6488, 0.1889],\n", + " ...,\n", + " [-1.6009, 0.4887, -0.4818, ..., -1.1268, 0.4111, 0.7892],\n", + " [-0.1528, 1.1728, -0.5164, ..., -0.4340, 0.1499, 1.6704],\n", + " [ 1.0253, 1.4222, -0.1805, ..., -0.6130, -0.5380, 1.6164]],\n", + "\n", + " [[-0.6736, -0.0718, -1.1724, ..., 0.2001, -0.5481, 0.0232],\n", + " [-1.8698, -1.2184, 0.2913, ..., -1.1398, -1.3523, -0.7851],\n", + " [-1.3725, -0.8212, 0.1984, ..., -1.8218, -1.4800, -0.2956],\n", + " ...,\n", + " [-0.5946, 0.5680, 0.8938, ..., -1.6653, 0.8218, 1.1902],\n", + " [ 1.2800, 1.9566, 0.2540, ..., -1.2290, 0.5257, 1.2667],\n", + " [ 1.3511, 1.3329, -0.0782, ..., -0.8454, -0.7400, 1.5966]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.8947, 0.0412, -1.2359, ..., 1.5140, -1.9812, -2.5532],\n", + " [-0.2951, -1.6086, -0.6381, ..., 0.4410, -0.3965, 0.0106]],\n", + "\n", + " [[-1.9196, 0.3326, 0.8482, ..., -2.3348, 1.3935, 1.1452],\n", + " [-0.5277, 0.1234, 0.7865, ..., -1.5790, -1.1817, -1.0156]],\n", + "\n", + " [[-2.1664, 0.3959, 0.7476, ..., -2.0817, 0.2852, 0.8173],\n", + " [-0.8414, 0.5154, -0.4553, ..., -2.1767, -0.6488, 0.1889]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.6009, 0.4887, -0.4818, ..., -1.8165, 1.4764, 0.5091],\n", + " [-0.6869, 0.4007, 
-1.5818, ..., -1.1268, 0.4111, 0.7892]],\n", + "\n", + " [[-0.1528, 1.1728, -0.5164, ..., -1.3611, 1.0621, 1.1810],\n", + " [-0.7595, -0.1699, -1.5305, ..., -0.4340, 0.1499, 1.6704]],\n", + "\n", + " [[ 1.0253, 1.4222, -0.1805, ..., -0.6989, 0.4721, 2.6129],\n", + " [-1.2381, -0.4573, -1.7561, ..., -0.6130, -0.5380, 1.6164]]],\n", + "\n", + "\n", + " [[[-0.6736, -0.0718, -1.1724, ..., 1.4816, -1.7920, -2.5177],\n", + " [-0.3929, -1.5120, -0.5353, ..., 0.2001, -0.5481, 0.0232]],\n", + "\n", + " [[-1.8698, -1.2184, 0.2913, ..., -1.5227, 1.9764, 0.6389],\n", + " [-0.4202, 0.4572, -1.0780, ..., -1.1398, -1.3523, -0.7851]],\n", + "\n", + " [[-1.3725, -0.8212, 0.1984, ..., -2.1553, 1.7041, 0.7166],\n", + " [-1.0124, 0.9351, -0.0954, ..., -1.8218, -1.4800, -0.2956]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.5946, 0.5680, 0.8938, ..., -2.1904, 1.7986, 1.0902],\n", + " [-1.3820, 1.0268, -1.0041, ..., -1.6653, 0.8218, 1.1902]],\n", + "\n", + " [[ 1.2800, 1.9566, 0.2540, ..., -1.6180, 1.6176, 2.5636],\n", + " [-2.0592, 0.7059, -1.3359, ..., -1.2290, 0.5257, 1.2667]],\n", + "\n", + " [[ 1.3511, 1.3329, -0.0782, ..., -0.5836, 0.6491, 2.6554],\n", + " [-1.3686, -0.2348, -1.7438, ..., -0.8454, -0.7400, 1.5966]]]],\n", + " grad_fn=)\n", + "tensor([[[ 3.8307e-01, 1.8837e-02, -1.0314e+00, ..., 6.4335e-01,\n", + " -3.1830e-01, -1.7296e+00],\n", + " [ 1.5897e+00, 1.3689e-01, -6.8915e-01, ..., 1.5973e+00,\n", + " 1.1907e+00, -9.0454e-01],\n", + " [-5.8138e-01, 6.7943e-01, -1.3203e+00, ..., -5.3627e-01,\n", + " -1.0456e+00, -1.6301e+00],\n", + " ...,\n", + " [-1.7466e-01, 3.0707e-02, 7.5225e-01, ..., -1.3217e+00,\n", + " -1.3415e+00, -3.8328e-01],\n", + " [ 1.5170e-01, 5.1089e-01, 1.3993e-01, ..., -1.6600e-01,\n", + " -6.5011e-01, 2.1798e-02],\n", + " [-2.4311e-01, 1.6726e+00, 1.6682e-01, ..., 1.3441e-03,\n", + " -1.6754e+00, 3.1771e-01]],\n", + "\n", + " [[ 4.7844e-01, -3.1772e-01, -1.0617e+00, ..., 6.4928e-01,\n", + " -3.2944e-01, -2.4185e+00],\n", + " [ 5.9122e-01, -2.2648e-01, 
1.7474e-01, ..., -1.8623e+00,\n", + " -1.1230e+00, 3.5013e-01],\n", + " [ 1.1642e-01, 1.2460e+00, 7.7941e-02, ..., 6.4975e-01,\n", + " -7.3862e-01, -2.1510e+00],\n", + " ...,\n", + " [ 1.1291e+00, -1.3637e+00, -1.5779e+00, ..., 1.7637e+00,\n", + " 9.1331e-01, -1.7033e+00],\n", + " [ 1.5909e+00, -1.4922e+00, 1.0060e+00, ..., 9.8096e-01,\n", + " 8.6736e-01, -2.3894e+00],\n", + " [ 1.8451e-01, 1.2740e+00, 4.2857e-01, ..., 6.2708e-01,\n", + " -1.3601e+00, -3.9984e-01]]], grad_fn=)\n", + "tensor([[[ 3.8307e-01, 1.8837e-02, -1.0314e+00, ..., 6.4335e-01,\n", + " -3.1830e-01, -1.7296e+00],\n", + " [ 1.5897e+00, 1.3689e-01, -6.8915e-01, ..., 1.5973e+00,\n", + " 1.1907e+00, -9.0454e-01],\n", + " [-5.8138e-01, 6.7943e-01, -1.3203e+00, ..., -5.3627e-01,\n", + " -1.0456e+00, -1.6301e+00],\n", + " ...,\n", + " [-1.7466e-01, 3.0707e-02, 7.5225e-01, ..., -1.3217e+00,\n", + " -1.3415e+00, -3.8328e-01],\n", + " [ 1.5170e-01, 5.1089e-01, 1.3993e-01, ..., -1.6600e-01,\n", + " -6.5011e-01, 2.1798e-02],\n", + " [-2.4311e-01, 1.6726e+00, 1.6682e-01, ..., 1.3441e-03,\n", + " -1.6754e+00, 3.1771e-01]],\n", + "\n", + " [[ 4.7844e-01, -3.1772e-01, -1.0617e+00, ..., 6.4928e-01,\n", + " -3.2944e-01, -2.4185e+00],\n", + " [ 5.9122e-01, -2.2648e-01, 1.7474e-01, ..., -1.8623e+00,\n", + " -1.1230e+00, 3.5013e-01],\n", + " [ 1.1642e-01, 1.2460e+00, 7.7941e-02, ..., 6.4975e-01,\n", + " -7.3862e-01, -2.1510e+00],\n", + " ...,\n", + " [ 1.1291e+00, -1.3637e+00, -1.5779e+00, ..., 1.7637e+00,\n", + " 9.1331e-01, -1.7033e+00],\n", + " [ 1.5909e+00, -1.4922e+00, 1.0060e+00, ..., 9.8096e-01,\n", + " 8.6736e-01, -2.3894e+00],\n", + " [ 1.8451e-01, 1.2740e+00, 4.2857e-01, ..., 6.2708e-01,\n", + " -1.3601e+00, -3.9984e-01]]], grad_fn=)\n", + "tensor([[[[ 3.8307e-01, 1.8837e-02, -1.0314e+00, ..., -8.0955e-01,\n", + " -2.8557e-01, -2.3318e-01],\n", + " [-9.9661e-02, 3.1722e-01, -3.0517e-01, ..., 6.4335e-01,\n", + " -3.1830e-01, -1.7296e+00]],\n", + "\n", + " [[ 1.5897e+00, 1.3689e-01, -6.8915e-01, ..., 
-7.9549e-01,\n", + " -6.9279e-01, -1.8082e-01],\n", + " [-9.9201e-01, 9.4938e-01, 4.4198e-02, ..., 1.5973e+00,\n", + " 1.1907e+00, -9.0454e-01]],\n", + "\n", + " [[-5.8138e-01, 6.7943e-01, -1.3203e+00, ..., -1.8905e+00,\n", + " 1.6226e-01, -1.2953e+00],\n", + " [-7.0312e-01, -6.4926e-01, -5.0913e-01, ..., -5.3627e-01,\n", + " -1.0456e+00, -1.6301e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.7466e-01, 3.0707e-02, 7.5225e-01, ..., -1.9281e+00,\n", + " 1.1489e+00, -2.4530e-01],\n", + " [-7.6225e-02, 8.5814e-01, -1.5467e+00, ..., -1.3217e+00,\n", + " -1.3415e+00, -3.8328e-01]],\n", + "\n", + " [[ 1.5170e-01, 5.1089e-01, 1.3993e-01, ..., -2.4168e+00,\n", + " 3.3385e-01, -6.2115e-02],\n", + " [-1.6390e+00, -1.6085e-01, -1.9118e+00, ..., -1.6600e-01,\n", + " -6.5011e-01, 2.1798e-02]],\n", + "\n", + " [[-2.4311e-01, 1.6726e+00, 1.6682e-01, ..., -1.0481e+00,\n", + " -2.7634e+00, 2.2741e-01],\n", + " [-1.4603e+00, 3.1239e-02, 3.8892e-01, ..., 1.3441e-03,\n", + " -1.6754e+00, 3.1771e-01]]],\n", + "\n", + "\n", + " [[[ 4.7844e-01, -3.1772e-01, -1.0617e+00, ..., -7.8511e-01,\n", + " -3.2510e-01, -2.1300e-01],\n", + " [-3.9893e-01, 6.1469e-01, -3.9206e-01, ..., 6.4928e-01,\n", + " -3.2944e-01, -2.4185e+00]],\n", + "\n", + " [[ 5.9122e-01, -2.2648e-01, 1.7474e-01, ..., -1.6488e+00,\n", + " -8.6854e-01, -8.0783e-01],\n", + " [-2.1516e+00, -2.4247e-01, -8.1713e-01, ..., -1.8623e+00,\n", + " -1.1230e+00, 3.5013e-01]],\n", + "\n", + " [[ 1.1642e-01, 1.2460e+00, 7.7941e-02, ..., -1.3121e+00,\n", + " -6.7044e-01, -1.1324e+00],\n", + " [ 4.3930e-01, 3.4082e-01, -1.2243e+00, ..., 6.4975e-01,\n", + " -7.3862e-01, -2.1510e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.1291e+00, -1.3637e+00, -1.5779e+00, ..., 8.5980e-01,\n", + " 3.2796e-01, -1.9442e+00],\n", + " [-8.9502e-01, 9.7357e-01, 8.5447e-01, ..., 1.7637e+00,\n", + " 9.1331e-01, -1.7033e+00]],\n", + "\n", + " [[ 1.5909e+00, -1.4922e+00, 1.0060e+00, ..., -1.5965e+00,\n", + " -3.9380e-01, -4.3585e-01],\n", + " [-2.2103e+00, 
4.4127e-01, 1.1554e+00, ..., 9.8096e-01,\n", + " 8.6736e-01, -2.3894e+00]],\n", + "\n", + " [[ 1.8451e-01, 1.2740e+00, 4.2857e-01, ..., -1.1366e+00,\n", + " -2.8409e+00, 4.6711e-01],\n", + " [-1.9576e+00, 2.0176e-01, -4.1035e-02, ..., 6.2708e-01,\n", + " -1.3601e+00, -3.9984e-01]]]], grad_fn=)\n", + "tensor([[[[ 3.7393e-01, 3.7078e-01, -7.5553e-01, ..., -9.4280e-01,\n", + " -6.5765e-01, -1.4711e-01],\n", + " [ 3.9396e-01, 4.7658e-02, -1.0277e+00, ..., -8.4926e-01,\n", + " -2.7815e-01, -2.6627e-01],\n", + " [ 9.9586e-01, 2.4414e-01, -8.3459e-01, ..., -8.5795e-01,\n", + " -4.3860e-01, -2.1582e-01],\n", + " ...,\n", + " [ 1.4504e+00, -1.2077e-01, -4.8160e-01, ..., -1.6701e+00,\n", + " 6.3807e-01, -8.3788e-02],\n", + " [ 3.8821e-01, 1.4150e-01, -1.5051e-01, ..., -1.5785e+00,\n", + " 4.1589e-01, -1.8024e-01],\n", + " [ 1.0467e-01, 5.5196e-01, 1.0818e-01, ..., -2.0373e+00,\n", + " 8.4395e-02, -6.9034e-02]],\n", + "\n", + " [[-7.7230e-01, -1.1852e-01, -8.0274e-02, ..., -1.3794e-03,\n", + " -5.2249e-01, -4.2095e-01],\n", + " [-4.6209e-01, -1.2680e-01, -2.5711e-01, ..., 1.3235e-01,\n", + " -4.1385e-01, -1.3744e+00],\n", + " [-1.7985e-01, -2.6037e-01, 3.5678e-01, ..., 2.4736e-01,\n", + " -1.6626e-01, -7.0940e-01],\n", + " ...,\n", + " [-6.2069e-01, 2.5298e-01, -7.8416e-01, ..., 8.6187e-02,\n", + " -6.9561e-01, -8.5675e-01],\n", + " [-9.1553e-01, 1.4648e-01, -5.5621e-02, ..., 1.8643e-01,\n", + " -1.0965e+00, -4.8097e-01],\n", + " [-1.0678e+00, 9.0885e-02, -4.5400e-02, ..., 1.5424e-01,\n", + " -1.1762e+00, -2.9385e-01]]],\n", + "\n", + "\n", + " [[[ 4.9445e-01, 9.6665e-03, -5.7117e-01, ..., -9.7640e-01,\n", + " -9.0948e-01, -1.3761e-01],\n", + " [ 4.7672e-01, -2.3672e-01, -8.9246e-01, ..., -8.9927e-01,\n", + " -3.9088e-01, -3.1635e-01],\n", + " [ 4.5120e-01, -5.8035e-02, -6.8736e-01, ..., -1.0352e+00,\n", + " -4.7996e-01, -4.5624e-01],\n", + " ...,\n", + " [ 3.3361e-01, 1.5918e-03, -1.0108e+00, ..., -1.5704e+00,\n", + " -3.9079e-01, -2.1742e-01],\n", + " [ 1.0123e+00, 
-1.1147e+00, -1.3032e+00, ..., 3.4082e-01,\n", + " 1.3298e-01, -1.5481e+00],\n", + " [ 1.4273e+00, -1.2951e+00, 5.7719e-01, ..., -1.2490e+00,\n", + " -4.1355e-01, -5.8558e-01]],\n", + "\n", + " [[-1.0884e+00, 2.0838e-01, -5.0948e-01, ..., 2.2845e-01,\n", + " -8.2075e-01, -1.1496e+00],\n", + " [-3.4296e-01, 5.0833e-01, -8.6644e-01, ..., 3.3008e-01,\n", + " -5.5070e-01, -1.8440e+00],\n", + " [-3.2602e-01, 6.0671e-01, -7.5711e-01, ..., 4.7493e-01,\n", + " -4.6589e-01, -2.0454e+00],\n", + " ...,\n", + " [-1.7399e+00, 4.6359e-01, 6.5213e-01, ..., 9.0025e-01,\n", + " 3.1643e-01, -2.0784e+00],\n", + " [-1.8375e+00, 2.6226e-01, 3.7726e-02, ..., 6.6286e-01,\n", + " -1.0588e+00, -7.6781e-01],\n", + " [-1.6402e+00, 3.1250e-01, -6.6194e-04, ..., 6.6493e-01,\n", + " -9.1110e-01, -1.0382e+00]]]],\n", + " grad_fn=)\n", + "tensor([[[[ 3.7393e-01, 3.7078e-01, -7.5553e-01, ..., -9.4280e-01,\n", + " -6.5765e-01, -1.4711e-01],\n", + " [-7.7230e-01, -1.1852e-01, -8.0274e-02, ..., -1.3794e-03,\n", + " -5.2249e-01, -4.2095e-01]],\n", + "\n", + " [[ 3.9396e-01, 4.7658e-02, -1.0277e+00, ..., -8.4926e-01,\n", + " -2.7815e-01, -2.6627e-01],\n", + " [-4.6209e-01, -1.2680e-01, -2.5711e-01, ..., 1.3235e-01,\n", + " -4.1385e-01, -1.3744e+00]],\n", + "\n", + " [[ 9.9586e-01, 2.4414e-01, -8.3459e-01, ..., -8.5795e-01,\n", + " -4.3860e-01, -2.1582e-01],\n", + " [-1.7985e-01, -2.6037e-01, 3.5678e-01, ..., 2.4736e-01,\n", + " -1.6626e-01, -7.0940e-01]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.4504e+00, -1.2077e-01, -4.8160e-01, ..., -1.6701e+00,\n", + " 6.3807e-01, -8.3788e-02],\n", + " [-6.2069e-01, 2.5298e-01, -7.8416e-01, ..., 8.6187e-02,\n", + " -6.9561e-01, -8.5675e-01]],\n", + "\n", + " [[ 3.8821e-01, 1.4150e-01, -1.5051e-01, ..., -1.5785e+00,\n", + " 4.1589e-01, -1.8024e-01],\n", + " [-9.1553e-01, 1.4648e-01, -5.5621e-02, ..., 1.8643e-01,\n", + " -1.0965e+00, -4.8097e-01]],\n", + "\n", + " [[ 1.0467e-01, 5.5196e-01, 1.0818e-01, ..., -2.0373e+00,\n", + " 8.4395e-02, -6.9034e-02],\n", + " 
[-1.0678e+00, 9.0885e-02, -4.5400e-02, ..., 1.5424e-01,\n", + " -1.1762e+00, -2.9385e-01]]],\n", + "\n", + "\n", + " [[[ 4.9445e-01, 9.6665e-03, -5.7117e-01, ..., -9.7640e-01,\n", + " -9.0948e-01, -1.3761e-01],\n", + " [-1.0884e+00, 2.0838e-01, -5.0948e-01, ..., 2.2845e-01,\n", + " -8.2075e-01, -1.1496e+00]],\n", + "\n", + " [[ 4.7672e-01, -2.3672e-01, -8.9246e-01, ..., -8.9927e-01,\n", + " -3.9088e-01, -3.1635e-01],\n", + " [-3.4296e-01, 5.0833e-01, -8.6644e-01, ..., 3.3008e-01,\n", + " -5.5070e-01, -1.8440e+00]],\n", + "\n", + " [[ 4.5120e-01, -5.8035e-02, -6.8736e-01, ..., -1.0352e+00,\n", + " -4.7996e-01, -4.5624e-01],\n", + " [-3.2602e-01, 6.0671e-01, -7.5711e-01, ..., 4.7493e-01,\n", + " -4.6589e-01, -2.0454e+00]],\n", + "\n", + " ...,\n", + "\n", + " [[ 3.3361e-01, 1.5918e-03, -1.0108e+00, ..., -1.5704e+00,\n", + " -3.9079e-01, -2.1742e-01],\n", + " [-1.7399e+00, 4.6359e-01, 6.5213e-01, ..., 9.0025e-01,\n", + " 3.1643e-01, -2.0784e+00]],\n", + "\n", + " [[ 1.0123e+00, -1.1147e+00, -1.3032e+00, ..., 3.4082e-01,\n", + " 1.3298e-01, -1.5481e+00],\n", + " [-1.8375e+00, 2.6226e-01, 3.7726e-02, ..., 6.6286e-01,\n", + " -1.0588e+00, -7.6781e-01]],\n", + "\n", + " [[ 1.4273e+00, -1.2951e+00, 5.7719e-01, ..., -1.2490e+00,\n", + " -4.1355e-01, -5.8558e-01],\n", + " [-1.6402e+00, 3.1250e-01, -6.6194e-04, ..., 6.6493e-01,\n", + " -9.1110e-01, -1.0382e+00]]]], grad_fn=)\n", + "tensor([[0.1385, 0.0325],\n", + " [0.1314, 0.0993]], grad_fn=)\n", + "tensor([1, 0])\n" ] } ], @@ -98,7 +945,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -124,7 +971,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -132,39 +979,33 @@ "output_type": "stream", "text": [ "\u001b[32mINFO \u001b[0m \u001b[34mTokenizing dataset imdb with AutoTokenizer for bert-base-uncased.\u001b[0m\n", - "huggingface/tokenizers: The current process just got forked, after 
parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", - "To disable this warning, you can either:\n", - "\t- Avoid using `tokenizers` before the fork if possible\n", - "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[2024-12-01 15:14:08,830] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to mps (auto detect)\n" + "/home/ism/ADL/mase/src/chop/tools/huggingface.py:157: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", + " trainer = Trainer(\n" ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "W1201 15:14:09.744000 8580182592 torch/distributed/elastic/multiprocessing/redirects.py:27] NOTE: Redirects are currently not supported in Windows or MacOs.\n", - "100%|██████████| 3125/3125 [04:39<00:00, 11.16it/s]" - ] + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [3125/3125 00:25]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ - "Evaluation accuracy: 0.84232\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" + "Evaluation accuracy: 0.84224\n" ] } ], @@ -206,27 +1047,35 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_0_attention_self_query\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_0_attention_self_key\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_0_attention_self_value\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_0_attention_output_dense\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_0_intermediate_dense\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_0_output_dense\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_1_attention_self_query\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_1_attention_self_key\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_1_attention_self_value\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_1_attention_output_dense\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_1_intermediate_dense\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_1_output_dense\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_pooler_dense\u001b[0m\n", - "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: classifier\u001b[0m\n" + "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: 
bert_encoder_layer_0_attention_self_query\u001b[0m\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[6]\u001b[39m\u001b[32m, line 16\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mchop\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mpasses\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpasses\u001b[39;00m\n\u001b[32m 3\u001b[39m pruning_config = {\n\u001b[32m 4\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mweight\u001b[39m\u001b[33m\"\u001b[39m: {\n\u001b[32m 5\u001b[39m \u001b[33m\"\u001b[39m\u001b[33msparsity\u001b[39m\u001b[33m\"\u001b[39m: \u001b[32m0.5\u001b[39m,\n\u001b[32m (...)\u001b[39m\u001b[32m 13\u001b[39m },\n\u001b[32m 14\u001b[39m }\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m mg, _ = \u001b[43mpasses\u001b[49m\u001b[43m.\u001b[49m\u001b[43mprune_transform_pass\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmg\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpass_args\u001b[49m\u001b[43m=\u001b[49m\u001b[43mpruning_config\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/src/chop/passes/graph/transforms/pruning/prune.py:212\u001b[39m, in \u001b[36mprune_transform_pass\u001b[39m\u001b[34m(graph, pass_args)\u001b[39m\n\u001b[32m 179\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mprune_transform_pass\u001b[39m(graph, pass_args: \u001b[38;5;28mdict\u001b[39m = {}):\n\u001b[32m 180\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 181\u001b[39m 
\u001b[33;03m Apply pruning transformation to the given graph.\u001b[39;00m\n\u001b[32m 182\u001b[39m \u001b[33;03m This is achieved by adding a register_parametrization hook to weights\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 210\u001b[39m \u001b[33;03m :rtype: tuple\u001b[39;00m\n\u001b[32m 211\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m212\u001b[39m graph = \u001b[43mprune_graph_iterator\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgraph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpass_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 213\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m graph, {}\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/src/chop/passes/graph/transforms/pruning/prune.py:170\u001b[39m, in \u001b[36mprune_graph_iterator\u001b[39m\u001b[34m(graph, config)\u001b[39m\n\u001b[32m 168\u001b[39m register_name, parameterization = node_hooks[\u001b[33m\"\u001b[39m\u001b[33mw_hook\u001b[39m\u001b[33m\"\u001b[39m]\n\u001b[32m 169\u001b[39m \u001b[38;5;66;03m# apply weigh pruning\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m170\u001b[39m \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mnn\u001b[49m\u001b[43m.\u001b[49m\u001b[43mutils\u001b[49m\u001b[43m.\u001b[49m\u001b[43mparametrize\u001b[49m\u001b[43m.\u001b[49m\u001b[43mregister_parametrization\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 171\u001b[39m \u001b[43m \u001b[49m\u001b[43mgraph\u001b[49m\u001b[43m.\u001b[49m\u001b[43mmodules\u001b[49m\u001b[43m[\u001b[49m\u001b[43mnode\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mregister_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparameterization\u001b[49m\n\u001b[32m 172\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 173\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m node_hooks[\u001b[33m\"\u001b[39m\u001b[33ma_hook\u001b[39m\u001b[33m\"\u001b[39m] 
\u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 174\u001b[39m register_fn, hook_fn = node_hooks[\u001b[33m\"\u001b[39m\u001b[33ma_hook\u001b[39m\u001b[33m\"\u001b[39m]\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrize.py:606\u001b[39m, in \u001b[36mregister_parametrization\u001b[39m\u001b[34m(module, tensor_name, parametrization, unsafe)\u001b[39m\n\u001b[32m 604\u001b[39m original = \u001b[38;5;28mgetattr\u001b[39m(module, tensor_name)\n\u001b[32m 605\u001b[39m \u001b[38;5;66;03m# We create this early to check for possible errors\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m606\u001b[39m parametrizations = \u001b[43mParametrizationList\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 607\u001b[39m \u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[43mparametrization\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moriginal\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43munsafe\u001b[49m\u001b[43m=\u001b[49m\u001b[43munsafe\u001b[49m\n\u001b[32m 608\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 609\u001b[39m \u001b[38;5;66;03m# Delete the previous parameter or buffer\u001b[39;00m\n\u001b[32m 610\u001b[39m \u001b[38;5;28mdelattr\u001b[39m(module, tensor_name)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrize.py:214\u001b[39m, in \u001b[36mParametrizationList.__init__\u001b[39m\u001b[34m(self, modules, original, unsafe)\u001b[39m\n\u001b[32m 208\u001b[39m _register_parameter_or_buffer(\u001b[38;5;28mself\u001b[39m, \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33moriginal\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m, originali)\n\u001b[32m 210\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m.unsafe:\n\u001b[32m 211\u001b[39m \u001b[38;5;66;03m# Consistency checks:\u001b[39;00m\n\u001b[32m 212\u001b[39m \u001b[38;5;66;03m# Since f : A -> B, right_inverse : B -> A, Z and original should live in B\u001b[39;00m\n\u001b[32m 213\u001b[39m \u001b[38;5;66;03m# Z = forward(right_inverse(original))\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m214\u001b[39m Z = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 215\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(Z, Tensor):\n\u001b[32m 216\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m 217\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mA parametrization must return a tensor. Got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(Z).\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 218\u001b[39m )\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile 
\u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/torch/nn/utils/parametrize.py:300\u001b[39m, in \u001b[36mParametrizationList.forward\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 298\u001b[39m \u001b[38;5;66;03m# Unpack the originals for the first parametrization\u001b[39;00m\n\u001b[32m 299\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m 
\u001b[38;5;28mself\u001b[39m.is_tensor:\n\u001b[32m--> \u001b[39m\u001b[32m300\u001b[39m x = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43moriginal\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 301\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 302\u001b[39m originals = (\u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33moriginal\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m.ntensors))\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1737\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1738\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1745\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 
1746\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1747\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1748\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1749\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1750\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1752\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1753\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/src/chop/passes/graph/transforms/pruning/sparse_parameterization.py:22\u001b[39m, in \u001b[36mFakeSparseWeight.forward\u001b[39m\u001b[34m(self, x)\u001b[39m\n\u001b[32m 20\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[32m 21\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m.mask.shape == x.shape\n\u001b[32m---> \u001b[39m\u001b[32m22\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmask\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\n", + 
"\u001b[31mRuntimeError\u001b[39m: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!" ] } ], @@ -258,7 +1107,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -306,7 +1155,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -788,7 +1637,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [ { diff --git a/docs/source/modules/documentation/tutorials/tutorial_5_nas_optuna.ipynb b/docs/source/modules/documentation/tutorials/tutorial_5_nas_optuna.ipynb index 91dba759c..0dcc73fae 100644 --- a/docs/source/modules/documentation/tutorials/tutorial_5_nas_optuna.ipynb +++ b/docs/source/modules/documentation/tutorials/tutorial_5_nas_optuna.ipynb @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -44,9 +44,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py:61: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.\n", + " import pynvml # type: ignore[import]\n", + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "\u001b[32mINFO \u001b[0m \u001b[34mTokenizing dataset imdb with AutoTokenizer for bert-base-uncased.\u001b[0m\n" + ] + } + ], "source": [ "from chop.tools import get_tokenized_dataset\n", "\n", @@ -73,7 +85,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -108,7 +120,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -165,7 +177,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -216,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -234,9 +246,99 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[I 2026-02-03 12:28:16,990] A new study created in memory with name: bert-tiny-nas-study\n", + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/optuna/distributions.py:502: UserWarning: Choices for a categorical distribution should be a tuple of None, bool, int, float and str for persistent storage but contains which is of type type.\n", + " warnings.warn(message)\n", + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/optuna/distributions.py:502: UserWarning: Choices for a categorical distribution should be a tuple of None, bool, int, float and str for persistent storage but contains which is of type type.\n", + " warnings.warn(message)\n", + "/home/ism/ADL/mase/src/chop/tools/huggingface.py:157: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. 
Use `processing_class` instead.\n", + " trainer = Trainer(\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [3125/3125 00:48, Epoch 1/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
5000.693500
10000.583500
15000.493800
20000.455900
25000.422900
30000.421100

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "

\n", + " \n", + " \n", + " [3125/3125 00:22]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[I 2026-02-03 12:29:30,979] Trial 0 finished with value: 0.82792 and parameters: {'num_layers': 0, 'num_heads': 3, 'hidden_size': 1, 'intermediate_size': 2, 'bert.encoder.layer.0.attention.self.query_type': , 'bert.encoder.layer.0.attention.self.key_type': , 'bert.encoder.layer.0.attention.self.value_type': , 'bert.encoder.layer.0.attention.output.dense_type': , 'bert.encoder.layer.1.attention.self.query_type': , 'bert.encoder.layer.1.attention.self.key_type': , 'bert.encoder.layer.1.attention.self.value_type': , 'bert.encoder.layer.1.attention.output.dense_type': , 'bert.pooler.dense_type': }. Best is trial 0 with value: 0.82792.\n" + ] + } + ], "source": [ "import optuna\n", "\n", @@ -262,7 +364,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -291,9 +393,778 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`past_key_values` were not specified as input names, but model.config.use_cache = True. 
Setting model.config.use_cache = False.\n", + "\u001b[32mINFO \u001b[0m \u001b[34mGetting dummy input for prajjwal1/bert-tiny.\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_0_attention_self_query\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_0_attention_self_value\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_0_intermediate_dense\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_0_output_dense\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_1_attention_output_dense\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_1_intermediate_dense\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: bert_encoder_layer_1_output_dense\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mPruning module: classifier\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 101, 9932, 2089, 2202, 2058, 1996, 2088, 2028, 2154, 102],\n", + " [ 101, 2023, 2003, 2339, 2017, 2323, 4553, 4748, 4877, 102]])\n", + "tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n", + "tensor([[ 101, 9932, 2089, 2202, 2058, 1996, 2088, 2028, 2154, 102],\n", + " [ 101, 2023, 2003, 2339, 2017, 2323, 4553, 4748, 4877, 102]])\n", + "tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", + "tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", + "tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])\n", + "tensor([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]])\n", + "tensor([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 
1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]]])\n", + "tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])\n", + "tensor([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 
0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]]])\n", + "tensor([[[-0.8220, 0.3099, 2.2264, ..., -1.4326, -0.8934, -0.8233],\n", + " [-0.2583, -0.7032, 1.4261, ..., 0.0806, 0.4364, 1.0597],\n", + " [-1.4471, 0.5038, 0.8523, ..., -0.1455, -1.0188, -0.7765],\n", + " ...,\n", + " [-0.8178, -0.5693, 1.1870, ..., -0.6866, -0.3260, -0.0128],\n", + " [-0.8974, -0.0396, 2.0569, ..., 0.4507, -1.3209, -0.8374],\n", + " [-0.7668, -0.2170, 2.0509, ..., -1.3579, 1.3152, -0.2025]],\n", + "\n", + " [[-0.8220, 0.3099, 2.2264, ..., -1.4326, -0.8934, -0.8233],\n", + " [ 1.3565, -0.1324, 1.5082, ..., -0.6385, 0.3736, 0.0184],\n", + " [-1.0066, 0.3500, 1.1917, ..., -0.9159, -0.1423, -1.0671],\n", + " ...,\n", + " [-1.8726, -0.7850, 1.7803, ..., -1.2450, 0.7536, -0.5221],\n", + " [ 1.0883, 1.4327, 0.4604, ..., 0.2978, -2.3361, -0.4312],\n", + " [-0.7668, -0.2170, 2.0509, ..., -1.3579, 1.3152, -0.2025]]],\n", + " grad_fn=)\n", + "tensor([[[-1.1274, -0.4625, -1.0464, ..., 0.7091, -0.9648, 0.1702],\n", + " [-0.4016, -0.0190, -0.4851, ..., 0.3042, -0.3956, -0.1381],\n", + " [-0.3132, -0.0417, -0.2925, ..., 0.5213, -0.9875, 
0.3725],\n", + " ...,\n", + " [-0.1884, -0.0473, -0.6054, ..., -0.2130, -0.6436, 0.1464],\n", + " [-0.2528, -0.1195, -0.7870, ..., 0.3006, -0.5464, 0.3398],\n", + " [ 0.0312, 0.0341, -0.4793, ..., 0.4096, -0.2642, 0.4740]],\n", + "\n", + " [[-1.1274, -0.4625, -1.0464, ..., 0.7091, -0.9648, 0.1702],\n", + " [-0.3371, -0.1943, -0.4204, ..., 0.3750, -0.2669, -0.2592],\n", + " [-0.2761, -0.1219, -0.0997, ..., 0.3936, -0.8905, 0.1860],\n", + " ...,\n", + " [-0.4397, 0.1947, -0.4489, ..., -0.0815, -0.5882, 0.1392],\n", + " [-0.1914, -0.0449, -0.3146, ..., 0.0628, -0.5037, 0.3616],\n", + " [ 0.0312, 0.0341, -0.4793, ..., 0.4096, -0.2642, 0.4740]]],\n", + " grad_fn=)\n", + "tensor([[[-1.1274, -0.4625, -1.0464, ..., 0.7091, -0.9648, 0.1702],\n", + " [-0.4016, -0.0190, -0.4851, ..., 0.3042, -0.3956, -0.1381],\n", + " [-0.3132, -0.0417, -0.2925, ..., 0.5213, -0.9875, 0.3725],\n", + " ...,\n", + " [-0.1884, -0.0473, -0.6054, ..., -0.2130, -0.6436, 0.1464],\n", + " [-0.2528, -0.1195, -0.7870, ..., 0.3006, -0.5464, 0.3398],\n", + " [ 0.0312, 0.0341, -0.4793, ..., 0.4096, -0.2642, 0.4740]],\n", + "\n", + " [[-1.1274, -0.4625, -1.0464, ..., 0.7091, -0.9648, 0.1702],\n", + " [-0.3371, -0.1943, -0.4204, ..., 0.3750, -0.2669, -0.2592],\n", + " [-0.2761, -0.1219, -0.0997, ..., 0.3936, -0.8905, 0.1860],\n", + " ...,\n", + " [-0.4397, 0.1947, -0.4489, ..., -0.0815, -0.5882, 0.1392],\n", + " [-0.1914, -0.0449, -0.3146, ..., 0.0628, -0.5037, 0.3616],\n", + " [ 0.0312, 0.0341, -0.4793, ..., 0.4096, -0.2642, 0.4740]]],\n", + " grad_fn=)\n", + "tensor([[[[-1.1274, -0.4625, -1.0464, ..., 1.1679, 0.5382, 0.4827],\n", + " [ 0.7593, 1.5658, -1.4508, ..., 0.7091, -0.9648, 0.1702]],\n", + "\n", + " [[-0.4016, -0.0190, -0.4851, ..., 0.4042, -0.3666, 0.1064],\n", + " [ 0.4806, 0.3919, -0.1628, ..., 0.3042, -0.3956, -0.1381]],\n", + "\n", + " [[-0.3132, -0.0417, -0.2925, ..., 0.3926, 0.0654, 0.3780],\n", + " [ 0.3495, 0.1287, -0.6141, ..., 0.5213, -0.9875, 0.3725]],\n", + "\n", + " ...,\n", + "\n", + 
" [[-0.1884, -0.0473, -0.6054, ..., 0.4443, -0.1035, 0.0345],\n", + " [ 0.3649, 0.0647, -0.5253, ..., -0.2130, -0.6436, 0.1464]],\n", + "\n", + " [[-0.2528, -0.1195, -0.7870, ..., 0.7144, 0.4511, 0.3913],\n", + " [ 0.3807, 0.5916, -0.6800, ..., 0.3006, -0.5464, 0.3398]],\n", + "\n", + " [[ 0.0312, 0.0341, -0.4793, ..., 0.3417, 0.1047, 0.1033],\n", + " [ 0.6889, 0.3934, -0.2210, ..., 0.4096, -0.2642, 0.4740]]],\n", + "\n", + "\n", + " [[[-1.1274, -0.4625, -1.0464, ..., 1.1679, 0.5382, 0.4827],\n", + " [ 0.7593, 1.5658, -1.4508, ..., 0.7091, -0.9648, 0.1702]],\n", + "\n", + " [[-0.3371, -0.1943, -0.4204, ..., 0.2491, 0.0900, 0.0341],\n", + " [ 0.6577, 0.7498, -0.2878, ..., 0.3750, -0.2669, -0.2592]],\n", + "\n", + " [[-0.2761, -0.1219, -0.0997, ..., 0.3425, 0.1108, 0.1984],\n", + " [ 0.3615, 0.1557, -0.2360, ..., 0.3936, -0.8905, 0.1860]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.4397, 0.1947, -0.4489, ..., 0.6021, -0.3952, 0.4831],\n", + " [ 0.7183, 0.1970, -0.7061, ..., -0.0815, -0.5882, 0.1392]],\n", + "\n", + " [[-0.1914, -0.0449, -0.3146, ..., 0.3866, 0.2165, 0.3482],\n", + " [ 0.4748, 0.5136, -0.7021, ..., 0.0628, -0.5037, 0.3616]],\n", + "\n", + " [[ 0.0312, 0.0341, -0.4793, ..., 0.3417, 0.1047, 0.1033],\n", + " [ 0.6889, 0.3934, -0.2210, ..., 0.4096, -0.2642, 0.4740]]]],\n", + " grad_fn=)\n", + "tensor([[[-0.8220, 0.3099, 2.2264, ..., -1.4326, -0.8934, -0.8233],\n", + " [-0.2583, -0.7032, 1.4261, ..., 0.0806, 0.4364, 1.0597],\n", + " [-1.4471, 0.5038, 0.8523, ..., -0.1455, -1.0188, -0.7765],\n", + " ...,\n", + " [-0.8178, -0.5693, 1.1870, ..., -0.6866, -0.3260, -0.0128],\n", + " [-0.8974, -0.0396, 2.0569, ..., 0.4507, -1.3209, -0.8374],\n", + " [-0.7668, -0.2170, 2.0509, ..., -1.3579, 1.3152, -0.2025]],\n", + "\n", + " [[-0.8220, 0.3099, 2.2264, ..., -1.4326, -0.8934, -0.8233],\n", + " [ 1.3565, -0.1324, 1.5082, ..., -0.6385, 0.3736, 0.0184],\n", + " [-1.0066, 0.3500, 1.1917, ..., -0.9159, -0.1423, -1.0671],\n", + " ...,\n", + " [-1.8726, -0.7850, 1.7803, 
..., -1.2450, 0.7536, -0.5221],\n", + " [ 1.0883, 1.4327, 0.4604, ..., 0.2978, -2.3361, -0.4312],\n", + " [-0.7668, -0.2170, 2.0509, ..., -1.3579, 1.3152, -0.2025]]],\n", + " grad_fn=)\n", + "tensor([[[-0.8220, 0.3099, 2.2264, ..., -1.4326, -0.8934, -0.8233],\n", + " [-0.2583, -0.7032, 1.4261, ..., 0.0806, 0.4364, 1.0597],\n", + " [-1.4471, 0.5038, 0.8523, ..., -0.1455, -1.0188, -0.7765],\n", + " ...,\n", + " [-0.8178, -0.5693, 1.1870, ..., -0.6866, -0.3260, -0.0128],\n", + " [-0.8974, -0.0396, 2.0569, ..., 0.4507, -1.3209, -0.8374],\n", + " [-0.7668, -0.2170, 2.0509, ..., -1.3579, 1.3152, -0.2025]],\n", + "\n", + " [[-0.8220, 0.3099, 2.2264, ..., -1.4326, -0.8934, -0.8233],\n", + " [ 1.3565, -0.1324, 1.5082, ..., -0.6385, 0.3736, 0.0184],\n", + " [-1.0066, 0.3500, 1.1917, ..., -0.9159, -0.1423, -1.0671],\n", + " ...,\n", + " [-1.8726, -0.7850, 1.7803, ..., -1.2450, 0.7536, -0.5221],\n", + " [ 1.0883, 1.4327, 0.4604, ..., 0.2978, -2.3361, -0.4312],\n", + " [-0.7668, -0.2170, 2.0509, ..., -1.3579, 1.3152, -0.2025]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.8220, 0.3099, 2.2264, ..., -1.2132, 0.4539, 1.4555],\n", + " [-0.5901, -0.0639, -0.4127, ..., -1.4326, -0.8934, -0.8233]],\n", + "\n", + " [[-0.2583, -0.7032, 1.4261, ..., -0.6968, -1.8422, 2.4488],\n", + " [-1.7574, 1.4708, -0.6847, ..., 0.0806, 0.4364, 1.0597]],\n", + "\n", + " [[-1.4471, 0.5038, 0.8523, ..., -1.0438, -1.2575, 2.5008],\n", + " [-1.0973, -0.1343, 1.4040, ..., -0.1455, -1.0188, -0.7765]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.8178, -0.5693, 1.1870, ..., -0.4630, -0.7825, 0.1127],\n", + " [ 0.8923, 0.8109, -0.5716, ..., -0.6866, -0.3260, -0.0128]],\n", + "\n", + " [[-0.8974, -0.0396, 2.0569, ..., -0.3028, 0.3290, 0.8321],\n", + " [-2.1080, -0.5053, -0.0474, ..., 0.4507, -1.3209, -0.8374]],\n", + "\n", + " [[-0.7668, -0.2170, 2.0509, ..., -0.8833, -0.4169, 0.7638],\n", + " [ 0.9468, -0.7144, 1.3739, ..., -1.3579, 1.3152, -0.2025]]],\n", + "\n", + "\n", + " [[[-0.8220, 0.3099, 2.2264, ..., 
-1.2132, 0.4539, 1.4555],\n", + " [-0.5901, -0.0639, -0.4127, ..., -1.4326, -0.8934, -0.8233]],\n", + "\n", + " [[ 1.3565, -0.1324, 1.5082, ..., 0.2366, -0.3073, 1.2141],\n", + " [ 0.2354, 0.2350, -0.6717, ..., -0.6385, 0.3736, 0.0184]],\n", + "\n", + " [[-1.0066, 0.3500, 1.1917, ..., -0.6808, -1.1268, 0.4581],\n", + " [-0.2434, -0.1574, 1.6191, ..., -0.9159, -0.1423, -1.0671]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.8726, -0.7850, 1.7803, ..., -0.1853, -0.2930, -0.4072],\n", + " [-0.1870, 1.1098, 0.5261, ..., -1.2450, 0.7536, -0.5221]],\n", + "\n", + " [[ 1.0883, 1.4327, 0.4604, ..., -0.2177, -0.6685, 0.3263],\n", + " [-1.6381, -0.9548, 0.2220, ..., 0.2978, -2.3361, -0.4312]],\n", + "\n", + " [[-0.7668, -0.2170, 2.0509, ..., -0.8833, -0.4169, 0.7638],\n", + " [ 0.9468, -0.7144, 1.3739, ..., -1.3579, 1.3152, -0.2025]]]],\n", + " grad_fn=)\n", + "tensor([[[ 0.3644, -0.1260, -0.2689, ..., 0.4270, -0.1192, -0.1129],\n", + " [-0.0050, -0.1840, -0.2019, ..., 0.3252, -0.1175, -0.2035],\n", + " [ 0.4070, -0.0397, 0.0904, ..., 0.3217, 0.4439, 0.6481],\n", + " ...,\n", + " [ 0.4496, 0.1261, -0.1468, ..., 0.5157, 0.2950, 0.1301],\n", + " [ 0.1168, -0.3191, -0.7119, ..., -0.1015, -0.0120, 0.0217],\n", + " [-0.0178, 0.4604, -0.0222, ..., 0.5506, -0.1579, -0.2772]],\n", + "\n", + " [[ 0.3644, -0.1260, -0.2689, ..., 0.4270, -0.1192, -0.1129],\n", + " [ 0.0690, 0.2423, 0.3067, ..., 0.4794, -0.3135, 0.0875],\n", + " [ 0.2005, -0.1309, -0.1855, ..., 0.8148, 0.3684, 0.6840],\n", + " ...,\n", + " [ 0.1120, 0.2899, -0.4503, ..., 0.1702, 0.4037, 0.1114],\n", + " [ 0.6304, -0.1792, -0.8129, ..., -0.3294, -0.2011, 0.0722],\n", + " [-0.0178, 0.4604, -0.0222, ..., 0.5506, -0.1579, -0.2772]]],\n", + " grad_fn=)\n", + "tensor([[[ 0.3644, -0.1260, -0.2689, ..., 0.4270, -0.1192, -0.1129],\n", + " [-0.0050, -0.1840, -0.2019, ..., 0.3252, -0.1175, -0.2035],\n", + " [ 0.4070, -0.0397, 0.0904, ..., 0.3217, 0.4439, 0.6481],\n", + " ...,\n", + " [ 0.4496, 0.1261, -0.1468, ..., 0.5157, 0.2950, 
0.1301],\n", + " [ 0.1168, -0.3191, -0.7119, ..., -0.1015, -0.0120, 0.0217],\n", + " [-0.0178, 0.4604, -0.0222, ..., 0.5506, -0.1579, -0.2772]],\n", + "\n", + " [[ 0.3644, -0.1260, -0.2689, ..., 0.4270, -0.1192, -0.1129],\n", + " [ 0.0690, 0.2423, 0.3067, ..., 0.4794, -0.3135, 0.0875],\n", + " [ 0.2005, -0.1309, -0.1855, ..., 0.8148, 0.3684, 0.6840],\n", + " ...,\n", + " [ 0.1120, 0.2899, -0.4503, ..., 0.1702, 0.4037, 0.1114],\n", + " [ 0.6304, -0.1792, -0.8129, ..., -0.3294, -0.2011, 0.0722],\n", + " [-0.0178, 0.4604, -0.0222, ..., 0.5506, -0.1579, -0.2772]]],\n", + " grad_fn=)\n", + "tensor([[[[ 0.3644, -0.1260, -0.2689, ..., 0.5730, 0.1030, 0.5545],\n", + " [ 0.1456, -0.0234, 0.0240, ..., 0.4270, -0.1192, -0.1129]],\n", + "\n", + " [[-0.0050, -0.1840, -0.2019, ..., 0.1136, 0.3381, 0.3709],\n", + " [-0.1039, 0.3027, -0.0938, ..., 0.3252, -0.1175, -0.2035]],\n", + "\n", + " [[ 0.4070, -0.0397, 0.0904, ..., 0.1872, -0.2598, 0.2972],\n", + " [ 0.1394, -0.1357, 0.2393, ..., 0.3217, 0.4439, 0.6481]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.4496, 0.1261, -0.1468, ..., -0.4594, -0.2070, 0.1646],\n", + " [ 0.2591, 0.2134, -0.1316, ..., 0.5157, 0.2950, 0.1301]],\n", + "\n", + " [[ 0.1168, -0.3191, -0.7119, ..., 0.1885, 0.4110, -0.1372],\n", + " [-0.0099, 0.2168, 0.3851, ..., -0.1015, -0.0120, 0.0217]],\n", + "\n", + " [[-0.0178, 0.4604, -0.0222, ..., -0.0267, 0.0384, 0.1212],\n", + " [-0.1959, 0.2017, 0.3176, ..., 0.5506, -0.1579, -0.2772]]],\n", + "\n", + "\n", + " [[[ 0.3644, -0.1260, -0.2689, ..., 0.5730, 0.1030, 0.5545],\n", + " [ 0.1456, -0.0234, 0.0240, ..., 0.4270, -0.1192, -0.1129]],\n", + "\n", + " [[ 0.0690, 0.2423, 0.3067, ..., 0.0628, -0.1276, 0.2575],\n", + " [ 0.2787, 0.1589, 0.0475, ..., 0.4794, -0.3135, 0.0875]],\n", + "\n", + " [[ 0.2005, -0.1309, -0.1855, ..., -0.2191, -0.1409, 0.3676],\n", + " [-0.0104, 0.1166, 0.4504, ..., 0.8148, 0.3684, 0.6840]],\n", + "\n", + " ...,\n", + "\n", + " [[ 0.1120, 0.2899, -0.4503, ..., -0.2917, -0.2754, 0.3250],\n", + 
" [ 0.0132, 0.1670, 0.2878, ..., 0.1702, 0.4037, 0.1114]],\n", + "\n", + " [[ 0.6304, -0.1792, -0.8129, ..., 0.2192, 0.0447, -0.2742],\n", + " [-0.1292, 0.0644, 0.2108, ..., -0.3294, -0.2011, 0.0722]],\n", + "\n", + " [[-0.0178, 0.4604, -0.0222, ..., -0.0267, 0.0384, 0.1212],\n", + " [-0.1959, 0.2017, 0.3176, ..., 0.5506, -0.1579, -0.2772]]]],\n", + " grad_fn=)\n", + "tensor([[[[ 1.4107e-01, -1.3809e-01, -1.3683e-01, ..., 2.2532e-01,\n", + " -9.6405e-02, 2.3651e-01],\n", + " [ 1.3174e-01, -6.0345e-02, -2.1670e-01, ..., 1.8448e-01,\n", + " 8.5125e-03, 2.8495e-01],\n", + " [ 1.3133e-01, -7.3089e-02, -1.9466e-01, ..., 1.9548e-01,\n", + " -2.0421e-02, 2.7337e-01],\n", + " ...,\n", + " [ 1.3366e-01, -7.0629e-02, -2.3117e-01, ..., 2.0667e-01,\n", + " -1.6625e-04, 2.7329e-01],\n", + " [ 1.2739e-01, -8.9336e-02, -2.0693e-01, ..., 1.8931e-01,\n", + " -1.8426e-02, 2.5761e-01],\n", + " [ 1.3246e-01, -4.9686e-02, -1.9184e-01, ..., 1.8211e-01,\n", + " -3.9995e-02, 2.7305e-01]],\n", + "\n", + " [[-1.1139e-01, 1.6108e-01, 1.2078e-01, ..., 3.5810e-01,\n", + " 7.8786e-02, -2.6769e-02],\n", + " [-4.9526e-02, 9.9165e-02, 1.9893e-01, ..., 4.3015e-01,\n", + " 1.6079e-01, 8.4502e-02],\n", + " [-8.5124e-02, 6.7143e-02, 1.7566e-01, ..., 3.8132e-01,\n", + " 1.2087e-01, 5.1995e-02],\n", + " ...,\n", + " [-7.5963e-02, 9.7155e-02, 1.7322e-01, ..., 3.7556e-01,\n", + " 7.5732e-02, 1.5615e-02],\n", + " [-9.1983e-02, 9.5608e-02, 1.4545e-01, ..., 3.9614e-01,\n", + " 1.2145e-01, 3.0251e-02],\n", + " [-6.2113e-02, 9.0067e-02, 1.7930e-01, ..., 3.8618e-01,\n", + " 1.0374e-01, 5.3924e-02]]],\n", + "\n", + "\n", + " [[[ 7.8779e-02, 1.7202e-02, -2.4146e-01, ..., 3.1351e-01,\n", + " 1.7767e-01, 2.9509e-01],\n", + " [ 1.8064e-01, 7.4196e-02, -2.5236e-01, ..., 1.6780e-01,\n", + " 1.5041e-02, 2.9471e-01],\n", + " [ 1.2506e-01, 7.1351e-02, -2.5676e-01, ..., 1.4592e-01,\n", + " 2.3883e-04, 2.6633e-01],\n", + " ...,\n", + " [ 1.5771e-01, 8.0993e-02, -2.6345e-01, ..., 1.5931e-01,\n", + " 1.5851e-02, 
2.8181e-01],\n", + " [ 1.4276e-01, 6.9402e-02, -2.5167e-01, ..., 1.6150e-01,\n", + " 1.6533e-02, 2.7787e-01],\n", + " [ 1.5462e-01, 6.9859e-02, -2.5659e-01, ..., 1.5087e-01,\n", + " 1.9992e-02, 2.7064e-01]],\n", + "\n", + " [[-5.0975e-02, 5.2467e-03, 2.5165e-01, ..., 1.4124e-01,\n", + " 5.7959e-02, -1.2931e-02],\n", + " [-1.4636e-02, 1.8216e-02, 2.5491e-01, ..., 2.8809e-01,\n", + " 1.3363e-01, 2.7224e-03],\n", + " [-3.4466e-02, 1.7536e-02, 2.4431e-01, ..., 2.4604e-01,\n", + " 1.0138e-01, 2.7211e-02],\n", + " ...,\n", + " [-3.0669e-02, 8.0512e-04, 2.3962e-01, ..., 2.3068e-01,\n", + " 5.2713e-02, -2.8760e-02],\n", + " [-4.1895e-02, 1.5549e-02, 2.4193e-01, ..., 2.8668e-01,\n", + " 1.0456e-01, 3.9524e-02],\n", + " [-4.0898e-02, -1.0490e-02, 2.5776e-01, ..., 3.1455e-01,\n", + " 9.0539e-02, -2.9835e-02]]]],\n", + " grad_fn=)\n", + "tensor([[[[ 1.4107e-01, -1.3809e-01, -1.3683e-01, ..., 2.2532e-01,\n", + " -9.6405e-02, 2.3651e-01],\n", + " [-1.1139e-01, 1.6108e-01, 1.2078e-01, ..., 3.5810e-01,\n", + " 7.8786e-02, -2.6769e-02]],\n", + "\n", + " [[ 1.3174e-01, -6.0345e-02, -2.1670e-01, ..., 1.8448e-01,\n", + " 8.5125e-03, 2.8495e-01],\n", + " [-4.9526e-02, 9.9165e-02, 1.9893e-01, ..., 4.3015e-01,\n", + " 1.6079e-01, 8.4502e-02]],\n", + "\n", + " [[ 1.3133e-01, -7.3089e-02, -1.9466e-01, ..., 1.9548e-01,\n", + " -2.0421e-02, 2.7337e-01],\n", + " [-8.5124e-02, 6.7143e-02, 1.7566e-01, ..., 3.8132e-01,\n", + " 1.2087e-01, 5.1995e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.3366e-01, -7.0629e-02, -2.3117e-01, ..., 2.0667e-01,\n", + " -1.6625e-04, 2.7329e-01],\n", + " [-7.5963e-02, 9.7155e-02, 1.7322e-01, ..., 3.7556e-01,\n", + " 7.5732e-02, 1.5615e-02]],\n", + "\n", + " [[ 1.2739e-01, -8.9336e-02, -2.0693e-01, ..., 1.8931e-01,\n", + " -1.8426e-02, 2.5761e-01],\n", + " [-9.1983e-02, 9.5608e-02, 1.4545e-01, ..., 3.9614e-01,\n", + " 1.2145e-01, 3.0251e-02]],\n", + "\n", + " [[ 1.3246e-01, -4.9686e-02, -1.9184e-01, ..., 1.8211e-01,\n", + " -3.9995e-02, 2.7305e-01],\n", + " 
[-6.2113e-02, 9.0067e-02, 1.7930e-01, ..., 3.8618e-01,\n", + " 1.0374e-01, 5.3924e-02]]],\n", + "\n", + "\n", + " [[[ 7.8779e-02, 1.7202e-02, -2.4146e-01, ..., 3.1351e-01,\n", + " 1.7767e-01, 2.9509e-01],\n", + " [-5.0975e-02, 5.2467e-03, 2.5165e-01, ..., 1.4124e-01,\n", + " 5.7959e-02, -1.2931e-02]],\n", + "\n", + " [[ 1.8064e-01, 7.4196e-02, -2.5236e-01, ..., 1.6780e-01,\n", + " 1.5041e-02, 2.9471e-01],\n", + " [-1.4636e-02, 1.8216e-02, 2.5491e-01, ..., 2.8809e-01,\n", + " 1.3363e-01, 2.7224e-03]],\n", + "\n", + " [[ 1.2506e-01, 7.1351e-02, -2.5676e-01, ..., 1.4592e-01,\n", + " 2.3883e-04, 2.6633e-01],\n", + " [-3.4466e-02, 1.7536e-02, 2.4431e-01, ..., 2.4604e-01,\n", + " 1.0138e-01, 2.7211e-02]],\n", + "\n", + " ...,\n", + "\n", + " [[ 1.5771e-01, 8.0993e-02, -2.6345e-01, ..., 1.5931e-01,\n", + " 1.5851e-02, 2.8181e-01],\n", + " [-3.0669e-02, 8.0512e-04, 2.3962e-01, ..., 2.3068e-01,\n", + " 5.2713e-02, -2.8760e-02]],\n", + "\n", + " [[ 1.4276e-01, 6.9402e-02, -2.5167e-01, ..., 1.6150e-01,\n", + " 1.6533e-02, 2.7787e-01],\n", + " [-4.1895e-02, 1.5549e-02, 2.4193e-01, ..., 2.8668e-01,\n", + " 1.0456e-01, 3.9524e-02]],\n", + "\n", + " [[ 1.5462e-01, 6.9859e-02, -2.5659e-01, ..., 1.5087e-01,\n", + " 1.9992e-02, 2.7064e-01],\n", + " [-4.0898e-02, -1.0490e-02, 2.5776e-01, ..., 3.1455e-01,\n", + " 9.0539e-02, -2.9835e-02]]]], grad_fn=)\n", + "tensor([[[-0.5505, 0.1107, 2.0577, ..., -0.9555, -0.6820, -0.7022],\n", + " [-0.0198, -0.7570, 1.2181, ..., 0.5870, 0.6809, 1.1828],\n", + " [-1.2578, 0.4266, 0.6185, ..., 0.2530, -0.8848, -0.4981],\n", + " ...,\n", + " [-0.4563, -0.6531, 0.8596, ..., -0.2610, -0.2033, 0.1624],\n", + " [-0.6091, -0.0584, 1.8346, ..., 0.8145, -1.1799, -0.5306],\n", + " [-0.4906, -0.1911, 1.6930, ..., -0.9406, 1.5391, -0.0060]],\n", + "\n", + " [[-0.6618, 0.2852, 1.9661, ..., -1.2402, -0.7164, -0.7102],\n", + " [ 1.6142, -0.0938, 1.2072, ..., -0.2086, 0.5469, 0.1620],\n", + " [-0.5709, 0.5075, 0.9062, ..., -0.5256, 0.0475, -0.6480],\n", + " ...,\n", 
+ " [-1.4849, -0.7005, 1.4981, ..., -0.9821, 0.8339, -0.4232],\n", + " [ 1.2407, 1.3584, 0.4344, ..., 0.6352, -2.1573, -0.3631],\n", + " [-0.4876, -0.0971, 1.6321, ..., -1.0130, 1.5289, -0.1059]]],\n", + " grad_fn=)\n", + "tensor([[[-0.5505, 0.1107, 2.0577, ..., -0.9555, -0.6820, -0.7022],\n", + " [-0.0198, -0.7570, 1.2181, ..., 0.5870, 0.6809, 1.1828],\n", + " [-1.2578, 0.4266, 0.6185, ..., 0.2530, -0.8848, -0.4981],\n", + " ...,\n", + " [-0.4563, -0.6531, 0.8596, ..., -0.2610, -0.2033, 0.1624],\n", + " [-0.6091, -0.0584, 1.8346, ..., 0.8145, -1.1799, -0.5306],\n", + " [-0.4906, -0.1911, 1.6930, ..., -0.9406, 1.5391, -0.0060]],\n", + "\n", + " [[-0.6618, 0.2852, 1.9661, ..., -1.2402, -0.7164, -0.7102],\n", + " [ 1.6142, -0.0938, 1.2072, ..., -0.2086, 0.5469, 0.1620],\n", + " [-0.5709, 0.5075, 0.9062, ..., -0.5256, 0.0475, -0.6480],\n", + " ...,\n", + " [-1.4849, -0.7005, 1.4981, ..., -0.9821, 0.8339, -0.4232],\n", + " [ 1.2407, 1.3584, 0.4344, ..., 0.6352, -2.1573, -0.3631],\n", + " [-0.4876, -0.0971, 1.6321, ..., -1.0130, 1.5289, -0.1059]]],\n", + " grad_fn=)\n", + "tensor([[[-0.5505, 0.1107, 2.0577, ..., -0.9555, -0.6820, -0.7022],\n", + " [-0.0198, -0.7570, 1.2181, ..., 0.5870, 0.6809, 1.1828],\n", + " [-1.2578, 0.4266, 0.6185, ..., 0.2530, -0.8848, -0.4981],\n", + " ...,\n", + " [-0.4563, -0.6531, 0.8596, ..., -0.2610, -0.2033, 0.1624],\n", + " [-0.6091, -0.0584, 1.8346, ..., 0.8145, -1.1799, -0.5306],\n", + " [-0.4906, -0.1911, 1.6930, ..., -0.9406, 1.5391, -0.0060]],\n", + "\n", + " [[-0.6618, 0.2852, 1.9661, ..., -1.2402, -0.7164, -0.7102],\n", + " [ 1.6142, -0.0938, 1.2072, ..., -0.2086, 0.5469, 0.1620],\n", + " [-0.5709, 0.5075, 0.9062, ..., -0.5256, 0.0475, -0.6480],\n", + " ...,\n", + " [-1.4849, -0.7005, 1.4981, ..., -0.9821, 0.8339, -0.4232],\n", + " [ 1.2407, 1.3584, 0.4344, ..., 0.6352, -2.1573, -0.3631],\n", + " [-0.4876, -0.0971, 1.6321, ..., -1.0130, 1.5289, -0.1059]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.5505, 0.1107, 2.0577, ..., -1.3225, 
0.3084, 1.8929],\n", + " [-0.5827, 0.1236, -0.4949, ..., -0.9555, -0.6820, -0.7022]],\n", + "\n", + " [[-0.0198, -0.7570, 1.2181, ..., -0.7716, -1.8528, 2.7124],\n", + " [-1.6248, 1.5169, -0.4428, ..., 0.5870, 0.6809, 1.1828]],\n", + "\n", + " [[-1.2578, 0.4266, 0.6185, ..., -1.2173, -1.2894, 2.8724],\n", + " [-0.9846, -0.0505, 1.4717, ..., 0.2530, -0.8848, -0.4981]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.4563, -0.6531, 0.8596, ..., -0.4867, -0.6598, 0.4648],\n", + " [ 1.1106, 0.8209, -0.5784, ..., -0.2610, -0.2033, 0.1624]],\n", + "\n", + " [[-0.6091, -0.0584, 1.8346, ..., -0.2914, 0.2991, 1.0348],\n", + " [-2.0720, -0.6064, -0.0492, ..., 0.8145, -1.1799, -0.5306]],\n", + "\n", + " [[-0.4906, -0.1911, 1.6930, ..., -0.8351, -0.3648, 1.1076],\n", + " [ 0.9185, -0.5383, 1.3007, ..., -0.9406, 1.5391, -0.0060]]],\n", + "\n", + "\n", + " [[[-0.6618, 0.2852, 1.9661, ..., -1.1709, 0.6875, 1.9930],\n", + " [-0.5214, -0.0622, -0.3464, ..., -1.2402, -0.7164, -0.7102]],\n", + "\n", + " [[ 1.6142, -0.0938, 1.2072, ..., 0.3060, -0.3371, 1.5283],\n", + " [ 0.3286, 0.1655, -0.4952, ..., -0.2086, 0.5469, 0.1620]],\n", + "\n", + " [[-0.5709, 0.5075, 0.9062, ..., -1.0003, -1.0048, 1.0080],\n", + " [ 0.1434, 0.1086, 1.5418, ..., -0.5256, 0.0475, -0.6480]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.4849, -0.7005, 1.4981, ..., -0.1283, -0.1899, -0.0448],\n", + " [-0.0697, 1.0441, 0.5048, ..., -0.9821, 0.8339, -0.4232]],\n", + "\n", + " [[ 1.2407, 1.3584, 0.4344, ..., -0.0130, -0.6824, 0.6609],\n", + " [-1.5475, -1.0539, 0.1634, ..., 0.6352, -2.1573, -0.3631]],\n", + "\n", + " [[-0.4876, -0.0971, 1.6321, ..., -0.8579, -0.3016, 1.1005],\n", + " [ 0.9250, -0.6386, 1.3778, ..., -1.0130, 1.5289, -0.1059]]]],\n", + " grad_fn=)\n", + "tensor([[[-0.5505, 0.1107, 2.0577, ..., -0.9555, -0.6820, -0.7022],\n", + " [-0.0198, -0.7570, 1.2181, ..., 0.5870, 0.6809, 1.1828],\n", + " [-1.2578, 0.4266, 0.6185, ..., 0.2530, -0.8848, -0.4981],\n", + " ...,\n", + " [-0.4563, -0.6531, 0.8596, ..., 
-0.2610, -0.2033, 0.1624],\n", + " [-0.6091, -0.0584, 1.8346, ..., 0.8145, -1.1799, -0.5306],\n", + " [-0.4906, -0.1911, 1.6930, ..., -0.9406, 1.5391, -0.0060]],\n", + "\n", + " [[-0.6618, 0.2852, 1.9661, ..., -1.2402, -0.7164, -0.7102],\n", + " [ 1.6142, -0.0938, 1.2072, ..., -0.2086, 0.5469, 0.1620],\n", + " [-0.5709, 0.5075, 0.9062, ..., -0.5256, 0.0475, -0.6480],\n", + " ...,\n", + " [-1.4849, -0.7005, 1.4981, ..., -0.9821, 0.8339, -0.4232],\n", + " [ 1.2407, 1.3584, 0.4344, ..., 0.6352, -2.1573, -0.3631],\n", + " [-0.4876, -0.0971, 1.6321, ..., -1.0130, 1.5289, -0.1059]]],\n", + " grad_fn=)\n", + "tensor([[[-0.5505, 0.1107, 2.0577, ..., -0.9555, -0.6820, -0.7022],\n", + " [-0.0198, -0.7570, 1.2181, ..., 0.5870, 0.6809, 1.1828],\n", + " [-1.2578, 0.4266, 0.6185, ..., 0.2530, -0.8848, -0.4981],\n", + " ...,\n", + " [-0.4563, -0.6531, 0.8596, ..., -0.2610, -0.2033, 0.1624],\n", + " [-0.6091, -0.0584, 1.8346, ..., 0.8145, -1.1799, -0.5306],\n", + " [-0.4906, -0.1911, 1.6930, ..., -0.9406, 1.5391, -0.0060]],\n", + "\n", + " [[-0.6618, 0.2852, 1.9661, ..., -1.2402, -0.7164, -0.7102],\n", + " [ 1.6142, -0.0938, 1.2072, ..., -0.2086, 0.5469, 0.1620],\n", + " [-0.5709, 0.5075, 0.9062, ..., -0.5256, 0.0475, -0.6480],\n", + " ...,\n", + " [-1.4849, -0.7005, 1.4981, ..., -0.9821, 0.8339, -0.4232],\n", + " [ 1.2407, 1.3584, 0.4344, ..., 0.6352, -2.1573, -0.3631],\n", + " [-0.4876, -0.0971, 1.6321, ..., -1.0130, 1.5289, -0.1059]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.5505, 0.1107, 2.0577, ..., -1.3225, 0.3084, 1.8929],\n", + " [-0.5827, 0.1236, -0.4949, ..., -0.9555, -0.6820, -0.7022]],\n", + "\n", + " [[-0.0198, -0.7570, 1.2181, ..., -0.7716, -1.8528, 2.7124],\n", + " [-1.6248, 1.5169, -0.4428, ..., 0.5870, 0.6809, 1.1828]],\n", + "\n", + " [[-1.2578, 0.4266, 0.6185, ..., -1.2173, -1.2894, 2.8724],\n", + " [-0.9846, -0.0505, 1.4717, ..., 0.2530, -0.8848, -0.4981]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.4563, -0.6531, 0.8596, ..., -0.4867, -0.6598, 0.4648],\n", + " [ 
1.1106, 0.8209, -0.5784, ..., -0.2610, -0.2033, 0.1624]],\n", + "\n", + " [[-0.6091, -0.0584, 1.8346, ..., -0.2914, 0.2991, 1.0348],\n", + " [-2.0720, -0.6064, -0.0492, ..., 0.8145, -1.1799, -0.5306]],\n", + "\n", + " [[-0.4906, -0.1911, 1.6930, ..., -0.8351, -0.3648, 1.1076],\n", + " [ 0.9185, -0.5383, 1.3007, ..., -0.9406, 1.5391, -0.0060]]],\n", + "\n", + "\n", + " [[[-0.6618, 0.2852, 1.9661, ..., -1.1709, 0.6875, 1.9930],\n", + " [-0.5214, -0.0622, -0.3464, ..., -1.2402, -0.7164, -0.7102]],\n", + "\n", + " [[ 1.6142, -0.0938, 1.2072, ..., 0.3060, -0.3371, 1.5283],\n", + " [ 0.3286, 0.1655, -0.4952, ..., -0.2086, 0.5469, 0.1620]],\n", + "\n", + " [[-0.5709, 0.5075, 0.9062, ..., -1.0003, -1.0048, 1.0080],\n", + " [ 0.1434, 0.1086, 1.5418, ..., -0.5256, 0.0475, -0.6480]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.4849, -0.7005, 1.4981, ..., -0.1283, -0.1899, -0.0448],\n", + " [-0.0697, 1.0441, 0.5048, ..., -0.9821, 0.8339, -0.4232]],\n", + "\n", + " [[ 1.2407, 1.3584, 0.4344, ..., -0.0130, -0.6824, 0.6609],\n", + " [-1.5475, -1.0539, 0.1634, ..., 0.6352, -2.1573, -0.3631]],\n", + "\n", + " [[-0.4876, -0.0971, 1.6321, ..., -0.8579, -0.3016, 1.1005],\n", + " [ 0.9250, -0.6386, 1.3778, ..., -1.0130, 1.5289, -0.1059]]]],\n", + " grad_fn=)\n", + "tensor([[[-0.5505, 0.1107, 2.0577, ..., -0.9555, -0.6820, -0.7022],\n", + " [-0.0198, -0.7570, 1.2181, ..., 0.5870, 0.6809, 1.1828],\n", + " [-1.2578, 0.4266, 0.6185, ..., 0.2530, -0.8848, -0.4981],\n", + " ...,\n", + " [-0.4563, -0.6531, 0.8596, ..., -0.2610, -0.2033, 0.1624],\n", + " [-0.6091, -0.0584, 1.8346, ..., 0.8145, -1.1799, -0.5306],\n", + " [-0.4906, -0.1911, 1.6930, ..., -0.9406, 1.5391, -0.0060]],\n", + "\n", + " [[-0.6618, 0.2852, 1.9661, ..., -1.2402, -0.7164, -0.7102],\n", + " [ 1.6142, -0.0938, 1.2072, ..., -0.2086, 0.5469, 0.1620],\n", + " [-0.5709, 0.5075, 0.9062, ..., -0.5256, 0.0475, -0.6480],\n", + " ...,\n", + " [-1.4849, -0.7005, 1.4981, ..., -0.9821, 0.8339, -0.4232],\n", + " [ 1.2407, 1.3584, 0.4344, 
..., 0.6352, -2.1573, -0.3631],\n", + " [-0.4876, -0.0971, 1.6321, ..., -1.0130, 1.5289, -0.1059]]],\n", + " grad_fn=)\n", + "tensor([[[-0.5505, 0.1107, 2.0577, ..., -0.9555, -0.6820, -0.7022],\n", + " [-0.0198, -0.7570, 1.2181, ..., 0.5870, 0.6809, 1.1828],\n", + " [-1.2578, 0.4266, 0.6185, ..., 0.2530, -0.8848, -0.4981],\n", + " ...,\n", + " [-0.4563, -0.6531, 0.8596, ..., -0.2610, -0.2033, 0.1624],\n", + " [-0.6091, -0.0584, 1.8346, ..., 0.8145, -1.1799, -0.5306],\n", + " [-0.4906, -0.1911, 1.6930, ..., -0.9406, 1.5391, -0.0060]],\n", + "\n", + " [[-0.6618, 0.2852, 1.9661, ..., -1.2402, -0.7164, -0.7102],\n", + " [ 1.6142, -0.0938, 1.2072, ..., -0.2086, 0.5469, 0.1620],\n", + " [-0.5709, 0.5075, 0.9062, ..., -0.5256, 0.0475, -0.6480],\n", + " ...,\n", + " [-1.4849, -0.7005, 1.4981, ..., -0.9821, 0.8339, -0.4232],\n", + " [ 1.2407, 1.3584, 0.4344, ..., 0.6352, -2.1573, -0.3631],\n", + " [-0.4876, -0.0971, 1.6321, ..., -1.0130, 1.5289, -0.1059]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.5505, 0.1107, 2.0577, ..., -1.3225, 0.3084, 1.8929],\n", + " [-0.5827, 0.1236, -0.4949, ..., -0.9555, -0.6820, -0.7022]],\n", + "\n", + " [[-0.0198, -0.7570, 1.2181, ..., -0.7716, -1.8528, 2.7124],\n", + " [-1.6248, 1.5169, -0.4428, ..., 0.5870, 0.6809, 1.1828]],\n", + "\n", + " [[-1.2578, 0.4266, 0.6185, ..., -1.2173, -1.2894, 2.8724],\n", + " [-0.9846, -0.0505, 1.4717, ..., 0.2530, -0.8848, -0.4981]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.4563, -0.6531, 0.8596, ..., -0.4867, -0.6598, 0.4648],\n", + " [ 1.1106, 0.8209, -0.5784, ..., -0.2610, -0.2033, 0.1624]],\n", + "\n", + " [[-0.6091, -0.0584, 1.8346, ..., -0.2914, 0.2991, 1.0348],\n", + " [-2.0720, -0.6064, -0.0492, ..., 0.8145, -1.1799, -0.5306]],\n", + "\n", + " [[-0.4906, -0.1911, 1.6930, ..., -0.8351, -0.3648, 1.1076],\n", + " [ 0.9185, -0.5383, 1.3007, ..., -0.9406, 1.5391, -0.0060]]],\n", + "\n", + "\n", + " [[[-0.6618, 0.2852, 1.9661, ..., -1.1709, 0.6875, 1.9930],\n", + " [-0.5214, -0.0622, -0.3464, ..., -1.2402, 
-0.7164, -0.7102]],\n", + "\n", + " [[ 1.6142, -0.0938, 1.2072, ..., 0.3060, -0.3371, 1.5283],\n", + " [ 0.3286, 0.1655, -0.4952, ..., -0.2086, 0.5469, 0.1620]],\n", + "\n", + " [[-0.5709, 0.5075, 0.9062, ..., -1.0003, -1.0048, 1.0080],\n", + " [ 0.1434, 0.1086, 1.5418, ..., -0.5256, 0.0475, -0.6480]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.4849, -0.7005, 1.4981, ..., -0.1283, -0.1899, -0.0448],\n", + " [-0.0697, 1.0441, 0.5048, ..., -0.9821, 0.8339, -0.4232]],\n", + "\n", + " [[ 1.2407, 1.3584, 0.4344, ..., -0.0130, -0.6824, 0.6609],\n", + " [-1.5475, -1.0539, 0.1634, ..., 0.6352, -2.1573, -0.3631]],\n", + "\n", + " [[-0.4876, -0.0971, 1.6321, ..., -0.8579, -0.3016, 1.1005],\n", + " [ 0.9250, -0.6386, 1.3778, ..., -1.0130, 1.5289, -0.1059]]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.5129, 0.0739, 1.9879, ..., -1.2916, 0.2115, 1.9063],\n", + " [-0.0128, -0.7395, 1.2154, ..., -0.7787, -1.8157, 2.6908],\n", + " [-1.2401, 0.4203, 0.6226, ..., -1.2122, -1.2833, 2.8581],\n", + " ...,\n", + " [-0.4492, -0.6403, 0.8638, ..., -0.5032, -0.6584, 0.5015],\n", + " [-0.6029, -0.0583, 1.8284, ..., -0.2972, 0.2937, 1.0404],\n", + " [-0.4790, -0.1890, 1.6839, ..., -0.8395, -0.3631, 1.1144]],\n", + "\n", + " [[-0.5777, 0.1269, -0.4880, ..., -0.9530, -0.6797, -0.6980],\n", + " [-1.6160, 1.5094, -0.4383, ..., 0.5804, 0.6720, 1.1740],\n", + " [-0.9806, -0.0479, 1.4650, ..., 0.2488, -0.8827, -0.4981],\n", + " ...,\n", + " [ 1.1048, 0.8187, -0.5756, ..., -0.2628, -0.2044, 0.1607],\n", + " [-2.0494, -0.5935, -0.0452, ..., 0.7967, -1.1696, -0.5256],\n", + " [ 0.9034, -0.5271, 1.2857, ..., -0.9360, 1.5138, -0.0080]]],\n", + "\n", + "\n", + " [[[-0.6100, 0.2680, 1.9264, ..., -1.1444, 0.6306, 1.9625],\n", + " [ 1.6044, -0.0923, 1.2065, ..., 0.2951, -0.3332, 1.5284],\n", + " [-0.5660, 0.4978, 0.9149, ..., -0.9935, -1.0068, 1.0215],\n", + " ...,\n", + " [-1.4620, -0.6918, 1.4929, ..., -0.1373, -0.1919, -0.0248],\n", + " [ 1.2220, 1.3397, 0.4417, ..., -0.0220, -0.6782, 0.6692],\n", + " 
[-0.4826, -0.0980, 1.6257, ..., -0.8608, -0.3050, 1.1031]],\n", + "\n", + " [[-0.5214, -0.0432, -0.3264, ..., -1.2189, -0.7141, -0.7105],\n", + " [ 0.3189, 0.1723, -0.4790, ..., -0.2159, 0.5386, 0.1538],\n", + " [ 0.1334, 0.1188, 1.5190, ..., -0.5264, 0.0441, -0.6416],\n", + " ...,\n", + " [-0.0706, 1.0332, 0.5045, ..., -0.9781, 0.8242, -0.4226],\n", + " [-1.5423, -1.0456, 0.1644, ..., 0.6298, -2.1487, -0.3635],\n", + " [ 0.9071, -0.6167, 1.3615, ..., -1.0075, 1.5075, -0.1112]]]],\n", + " grad_fn=)\n", + "tensor([[[[-0.5129, 0.0739, 1.9879, ..., -1.2916, 0.2115, 1.9063],\n", + " [-0.5777, 0.1269, -0.4880, ..., -0.9530, -0.6797, -0.6980]],\n", + "\n", + " [[-0.0128, -0.7395, 1.2154, ..., -0.7787, -1.8157, 2.6908],\n", + " [-1.6160, 1.5094, -0.4383, ..., 0.5804, 0.6720, 1.1740]],\n", + "\n", + " [[-1.2401, 0.4203, 0.6226, ..., -1.2122, -1.2833, 2.8581],\n", + " [-0.9806, -0.0479, 1.4650, ..., 0.2488, -0.8827, -0.4981]],\n", + "\n", + " ...,\n", + "\n", + " [[-0.4492, -0.6403, 0.8638, ..., -0.5032, -0.6584, 0.5015],\n", + " [ 1.1048, 0.8187, -0.5756, ..., -0.2628, -0.2044, 0.1607]],\n", + "\n", + " [[-0.6029, -0.0583, 1.8284, ..., -0.2972, 0.2937, 1.0404],\n", + " [-2.0494, -0.5935, -0.0452, ..., 0.7967, -1.1696, -0.5256]],\n", + "\n", + " [[-0.4790, -0.1890, 1.6839, ..., -0.8395, -0.3631, 1.1144],\n", + " [ 0.9034, -0.5271, 1.2857, ..., -0.9360, 1.5138, -0.0080]]],\n", + "\n", + "\n", + " [[[-0.6100, 0.2680, 1.9264, ..., -1.1444, 0.6306, 1.9625],\n", + " [-0.5214, -0.0432, -0.3264, ..., -1.2189, -0.7141, -0.7105]],\n", + "\n", + " [[ 1.6044, -0.0923, 1.2065, ..., 0.2951, -0.3332, 1.5284],\n", + " [ 0.3189, 0.1723, -0.4790, ..., -0.2159, 0.5386, 0.1538]],\n", + "\n", + " [[-0.5660, 0.4978, 0.9149, ..., -0.9935, -1.0068, 1.0215],\n", + " [ 0.1334, 0.1188, 1.5190, ..., -0.5264, 0.0441, -0.6416]],\n", + "\n", + " ...,\n", + "\n", + " [[-1.4620, -0.6918, 1.4929, ..., -0.1373, -0.1919, -0.0248],\n", + " [-0.0706, 1.0332, 0.5045, ..., -0.9781, 0.8242, -0.4226]],\n", + "\n", 
+ " [[ 1.2220, 1.3397, 0.4417, ..., -0.0220, -0.6782, 0.6692],\n", + " [-1.5423, -1.0456, 0.1644, ..., 0.6298, -2.1487, -0.3635]],\n", + "\n", + " [[-0.4826, -0.0980, 1.6257, ..., -0.8608, -0.3050, 1.1031],\n", + " [ 0.9071, -0.6167, 1.3615, ..., -1.0075, 1.5075, -0.1112]]]],\n", + " grad_fn=)\n" + ] + } + ], "source": [ "from chop.pipelines import CompressionPipeline\n", "from chop import MaseGraph\n", @@ -355,9 +1226,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32mINFO \u001b[0m \u001b[34mExporting MaseGraph to /home/ism/tutorial_5_nas_compressed.pt, /home/ism/tutorial_5_nas_compressed.mz\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mExporting GraphModule to /home/ism/tutorial_5_nas_compressed.pt\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mSaving state_dict format\u001b[0m\n", + "\u001b[32mINFO \u001b[0m \u001b[34mExporting MaseMetadata to /home/ism/tutorial_5_nas_compressed.mz\u001b[0m\n", + "\u001b[33mWARNING \u001b[0m \u001b[34mFailed to pickle call_function node: finfo\u001b[0m\n", + "\u001b[33mWARNING \u001b[0m \u001b[34mcannot pickle 'torch.finfo' object\u001b[0m\n", + "\u001b[33mWARNING \u001b[0m \u001b[34mFailed to pickle call_function node: getattr_3\u001b[0m\n", + "\u001b[33mWARNING \u001b[0m \u001b[34mcannot pickle 'torch.finfo' object\u001b[0m\n" + ] + } + ], "source": [ "mg.export(f\"{Path.home()}/tutorial_5_nas_compressed\", save_format=\"state_dict\")" ] @@ -365,7 +1251,7 @@ ], "metadata": { "kernelspec": { - "display_name": ".venv", + "display_name": "mase", "language": "python", "name": "python3" }, diff --git a/docs/source/modules/documentation/tutorials/tutorial_6_mixed_precision_search.ipynb b/docs/source/modules/documentation/tutorials/tutorial_6_mixed_precision_search.ipynb index cfff717e4..d01248dea 100644 --- 
a/docs/source/modules/documentation/tutorials/tutorial_6_mixed_precision_search.ipynb +++ b/docs/source/modules/documentation/tutorials/tutorial_6_mixed_precision_search.ipynb @@ -53,7 +53,18 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py:61: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.\n", + " import pynvml # type: ignore[import]\n" + ] + } + ], "source": [ "from transformers import AutoModel\n", "\n", @@ -69,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -89,9 +100,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32mINFO \u001b[0m \u001b[34mTokenizing dataset imdb with AutoTokenizer for bert-base-uncased.\u001b[0m\n" + ] + } + ], "source": [ "from chop.tools import get_tokenized_dataset\n", "\n", @@ -296,9 +315,153 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[I 2026-02-03 12:31:37,285] A new study created in memory with name: bert-tiny-nas-study\n", + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/optuna/distributions.py:502: UserWarning: Choices for a 
categorical distribution should be a tuple of None, bool, int, float and str for persistent storage but contains which is of type type.\n", + " warnings.warn(message)\n", + "/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/optuna/distributions.py:502: UserWarning: Choices for a categorical distribution should be a tuple of None, bool, int, float and str for persistent storage but contains which is of type type.\n", + " warnings.warn(message)\n", + "/home/ism/ADL/mase/src/chop/tools/huggingface.py:157: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", + " trainer = Trainer(\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [2741/3125 00:47 < 00:06, 57.09 it/s, Epoch 0.88/1]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
StepTraining Loss
5000.371300
10000.324200
15000.321200
20000.343100
25000.329000

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[W 2026-02-03 12:32:27,486] Trial 0 failed with parameters: {'bert.encoder.layer.0.attention.self.query_type': , 'bert.encoder.layer.0.attention.self.value_type': , 'bert.encoder.layer.0.intermediate.dense_type': , 'bert.encoder.layer.0.output.dense_type': , 'bert.encoder.layer.1.attention.output.dense_type': , 'bert.encoder.layer.1.intermediate.dense_type': , 'bert.encoder.layer.1.output.dense_type': , 'classifier_type': } because of the following error: KeyboardInterrupt().\n", + "Traceback (most recent call last):\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/optuna/study/_optimize.py\", line 205, in _run_trial\n", + " value_or_values = func(trial)\n", + " ^^^^^^^^^^^\n", + " File \"/tmp/ipykernel_395833/319635677.py\", line 18, in objective\n", + " trainer.train()\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/transformers/trainer.py\", line 2245, in train\n", + " return inner_training_loop(\n", + " ^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/transformers/trainer.py\", line 2514, in _inner_training_loop\n", + " batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/transformers/trainer.py\", line 5243, in get_batch_samples\n", + " batch_samples.append(next(epoch_iterator))\n", + " ^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/accelerate/data_loader.py\", line 579, in __iter__\n", + " next_batch = next(dataloader_iter)\n", + " ^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py\", line 708, in __next__\n", + " data = 
self._next_data()\n", + " ^^^^^^^^^^^^^^^^^\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py\", line 764, in _next_data\n", + " data = self._dataset_fetcher.fetch(index) # may raise StopIteration\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py\", line 55, in fetch\n", + " return self.collate_fn(data)\n", + " ^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/transformers/data/data_collator.py\", line 272, in __call__\n", + " batch = pad_without_fast_tokenizer_warning(\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/transformers/data/data_collator.py\", line 67, in pad_without_fast_tokenizer_warning\n", + " padded = tokenizer.pad(*pad_args, **pad_kwargs)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py\", line 3407, in pad\n", + " return BatchEncoding(batch_outputs, tensor_type=return_tensors)\n", + " ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py\", line 241, in __init__\n", + " self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py\", line 777, in convert_to_tensors\n", + " tensor = as_tensor(value)\n", + " ^^^^^^^^^^^^^^^^\n", + " File \"/home/ism/ADL/mase/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py\", line 739, in as_tensor\n", + " return torch.tensor(value)\n", + " ^^^^^^^^^^^^^^^^^^^\n", + "KeyboardInterrupt\n", + "[W 2026-02-03 12:32:27,495] Trial 0 failed with value None.\n" + ] + }, + { + "ename": "KeyboardInterrupt", 
+ "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mKeyboardInterrupt\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[9]\u001b[39m\u001b[32m, line 9\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01moptuna\u001b[39;00m\n\u001b[32m 3\u001b[39m study = optuna.create_study(\n\u001b[32m 4\u001b[39m direction=\u001b[33m\"\u001b[39m\u001b[33mmaximize\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 5\u001b[39m study_name=\u001b[33m\"\u001b[39m\u001b[33mbert-tiny-nas-study\u001b[39m\u001b[33m\"\u001b[39m,\n\u001b[32m 6\u001b[39m sampler=sampler,\n\u001b[32m 7\u001b[39m )\n\u001b[32m----> \u001b[39m\u001b[32m9\u001b[39m \u001b[43mstudy\u001b[49m\u001b[43m.\u001b[49m\u001b[43moptimize\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 10\u001b[39m \u001b[43m \u001b[49m\u001b[43mobjective\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 11\u001b[39m \u001b[43m \u001b[49m\u001b[43mn_trials\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 12\u001b[39m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m60\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m60\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m24\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 13\u001b[39m \u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/optuna/study/study.py:490\u001b[39m, in \u001b[36mStudy.optimize\u001b[39m\u001b[34m(self, func, n_trials, timeout, n_jobs, catch, callbacks, gc_after_trial, show_progress_bar)\u001b[39m\n\u001b[32m 388\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34moptimize\u001b[39m(\n\u001b[32m 389\u001b[39m 
\u001b[38;5;28mself\u001b[39m,\n\u001b[32m 390\u001b[39m func: ObjectiveFuncType,\n\u001b[32m (...)\u001b[39m\u001b[32m 397\u001b[39m show_progress_bar: \u001b[38;5;28mbool\u001b[39m = \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[32m 398\u001b[39m ) -> \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 399\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Optimize an objective function.\u001b[39;00m\n\u001b[32m 400\u001b[39m \n\u001b[32m 401\u001b[39m \u001b[33;03m Optimization is done by choosing a suitable set of hyperparameter values from a given\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 488\u001b[39m \u001b[33;03m If nested invocation of this method occurs.\u001b[39;00m\n\u001b[32m 489\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m490\u001b[39m \u001b[43m_optimize\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 491\u001b[39m \u001b[43m \u001b[49m\u001b[43mstudy\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 492\u001b[39m \u001b[43m \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m=\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 493\u001b[39m \u001b[43m \u001b[49m\u001b[43mn_trials\u001b[49m\u001b[43m=\u001b[49m\u001b[43mn_trials\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 494\u001b[39m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 495\u001b[39m \u001b[43m \u001b[49m\u001b[43mn_jobs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mn_jobs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 496\u001b[39m \u001b[43m \u001b[49m\u001b[43mcatch\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcatch\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43misinstance\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mcatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mIterable\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mcatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 497\u001b[39m \u001b[43m \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 498\u001b[39m \u001b[43m \u001b[49m\u001b[43mgc_after_trial\u001b[49m\u001b[43m=\u001b[49m\u001b[43mgc_after_trial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 499\u001b[39m \u001b[43m \u001b[49m\u001b[43mshow_progress_bar\u001b[49m\u001b[43m=\u001b[49m\u001b[43mshow_progress_bar\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 500\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/optuna/study/_optimize.py:67\u001b[39m, in \u001b[36m_optimize\u001b[39m\u001b[34m(study, func, n_trials, timeout, n_jobs, catch, callbacks, gc_after_trial, show_progress_bar)\u001b[39m\n\u001b[32m 65\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 66\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m n_jobs == \u001b[32m1\u001b[39m:\n\u001b[32m---> \u001b[39m\u001b[32m67\u001b[39m \u001b[43m_optimize_sequential\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 68\u001b[39m \u001b[43m \u001b[49m\u001b[43mstudy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 69\u001b[39m \u001b[43m \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 70\u001b[39m \u001b[43m \u001b[49m\u001b[43mn_trials\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 71\u001b[39m \u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 72\u001b[39m \u001b[43m \u001b[49m\u001b[43mcatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 73\u001b[39m \u001b[43m \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 74\u001b[39m \u001b[43m 
\u001b[49m\u001b[43mgc_after_trial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 75\u001b[39m \u001b[43m \u001b[49m\u001b[43mreseed_sampler_rng\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 76\u001b[39m \u001b[43m \u001b[49m\u001b[43mtime_start\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[32m 77\u001b[39m \u001b[43m \u001b[49m\u001b[43mprogress_bar\u001b[49m\u001b[43m=\u001b[49m\u001b[43mprogress_bar\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 78\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 79\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 80\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m n_jobs == -\u001b[32m1\u001b[39m:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/optuna/study/_optimize.py:164\u001b[39m, in \u001b[36m_optimize_sequential\u001b[39m\u001b[34m(study, func, n_trials, timeout, catch, callbacks, gc_after_trial, reseed_sampler_rng, time_start, progress_bar)\u001b[39m\n\u001b[32m 161\u001b[39m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[32m 163\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m164\u001b[39m frozen_trial_id = \u001b[43m_run_trial\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstudy\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 165\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 166\u001b[39m \u001b[38;5;66;03m# The following line mitigates memory problems that can be occurred in some\u001b[39;00m\n\u001b[32m 167\u001b[39m \u001b[38;5;66;03m# environments (e.g., services that use computing containers such as GitHub Actions).\u001b[39;00m\n\u001b[32m 168\u001b[39m \u001b[38;5;66;03m# Please refer to the following PR for further details:\u001b[39;00m\n\u001b[32m 169\u001b[39m \u001b[38;5;66;03m# 
https://github.com/optuna/optuna/pull/325.\u001b[39;00m\n\u001b[32m 170\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m gc_after_trial:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/optuna/study/_optimize.py:262\u001b[39m, in \u001b[36m_run_trial\u001b[39m\u001b[34m(study, func, catch)\u001b[39m\n\u001b[32m 255\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[33m\"\u001b[39m\u001b[33mShould not reach.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 257\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[32m 258\u001b[39m updated_state == TrialState.FAIL\n\u001b[32m 259\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m func_err \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 260\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(func_err, catch)\n\u001b[32m 261\u001b[39m ):\n\u001b[32m--> \u001b[39m\u001b[32m262\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m func_err\n\u001b[32m 263\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m trial._trial_id\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/optuna/study/_optimize.py:205\u001b[39m, in \u001b[36m_run_trial\u001b[39m\u001b[34m(study, func, catch)\u001b[39m\n\u001b[32m 203\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m get_heartbeat_thread(trial._trial_id, study._storage):\n\u001b[32m 204\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m205\u001b[39m value_or_values = \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 206\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m exceptions.TrialPruned \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[32m 207\u001b[39m \u001b[38;5;66;03m# TODO(mamu): Handle multi-objective cases.\u001b[39;00m\n\u001b[32m 208\u001b[39m state = TrialState.PRUNED\n", + 
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[7]\u001b[39m\u001b[32m, line 18\u001b[39m, in \u001b[36mobjective\u001b[39m\u001b[34m(trial)\u001b[39m\n\u001b[32m 8\u001b[39m model = construct_model(trial)\n\u001b[32m 10\u001b[39m trainer = get_trainer(\n\u001b[32m 11\u001b[39m model=model,\n\u001b[32m 12\u001b[39m tokenized_dataset=dataset,\n\u001b[32m (...)\u001b[39m\u001b[32m 15\u001b[39m num_train_epochs=\u001b[32m1\u001b[39m,\n\u001b[32m 16\u001b[39m )\n\u001b[32m---> \u001b[39m\u001b[32m18\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 19\u001b[39m eval_results = trainer.evaluate()\n\u001b[32m 21\u001b[39m trial.set_user_attr(\u001b[33m\"\u001b[39m\u001b[33mmodel\u001b[39m\u001b[33m\"\u001b[39m, model)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/transformers/trainer.py:2245\u001b[39m, in \u001b[36mTrainer.train\u001b[39m\u001b[34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[39m\n\u001b[32m 2243\u001b[39m hf_hub_utils.enable_progress_bars()\n\u001b[32m 2244\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2245\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 2246\u001b[39m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m=\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2247\u001b[39m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m=\u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2248\u001b[39m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 2249\u001b[39m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m=\u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 
2250\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/transformers/trainer.py:2514\u001b[39m, in \u001b[36mTrainer._inner_training_loop\u001b[39m\u001b[34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[39m\n\u001b[32m 2512\u001b[39m update_step += \u001b[32m1\u001b[39m\n\u001b[32m 2513\u001b[39m num_batches = args.gradient_accumulation_steps \u001b[38;5;28;01mif\u001b[39;00m update_step != (total_updates - \u001b[32m1\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m remainder\n\u001b[32m-> \u001b[39m\u001b[32m2514\u001b[39m batch_samples, num_items_in_batch = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mget_batch_samples\u001b[49m\u001b[43m(\u001b[49m\u001b[43mepoch_iterator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_batches\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m.\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 2515\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i, inputs \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(batch_samples):\n\u001b[32m 2516\u001b[39m step += \u001b[32m1\u001b[39m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/transformers/trainer.py:5243\u001b[39m, in \u001b[36mTrainer.get_batch_samples\u001b[39m\u001b[34m(self, epoch_iterator, num_batches, device)\u001b[39m\n\u001b[32m 5241\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(num_batches):\n\u001b[32m 5242\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m5243\u001b[39m batch_samples.append(\u001b[38;5;28mnext\u001b[39m(epoch_iterator))\n\u001b[32m 5244\u001b[39m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[32m 5245\u001b[39m 
\u001b[38;5;28;01mbreak\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/accelerate/data_loader.py:579\u001b[39m, in \u001b[36mDataLoaderShard.__iter__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 577\u001b[39m current_batch = send_to_device(current_batch, \u001b[38;5;28mself\u001b[39m.device, non_blocking=\u001b[38;5;28mself\u001b[39m._non_blocking)\n\u001b[32m 578\u001b[39m \u001b[38;5;28mself\u001b[39m._update_state_dict()\n\u001b[32m--> \u001b[39m\u001b[32m579\u001b[39m next_batch = \u001b[38;5;28mnext\u001b[39m(dataloader_iter)\n\u001b[32m 580\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m batch_index >= \u001b[38;5;28mself\u001b[39m.skip_batches:\n\u001b[32m 581\u001b[39m \u001b[38;5;28;01myield\u001b[39;00m current_batch\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py:708\u001b[39m, in \u001b[36m_BaseDataLoaderIter.__next__\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 705\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 706\u001b[39m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[32m 707\u001b[39m \u001b[38;5;28mself\u001b[39m._reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m708\u001b[39m data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 709\u001b[39m \u001b[38;5;28mself\u001b[39m._num_yielded += \u001b[32m1\u001b[39m\n\u001b[32m 710\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[32m 711\u001b[39m \u001b[38;5;28mself\u001b[39m._dataset_kind == _DatasetKind.Iterable\n\u001b[32m 712\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m._IterableDataset_len_called 
\u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 713\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m._num_yielded > \u001b[38;5;28mself\u001b[39m._IterableDataset_len_called\n\u001b[32m 714\u001b[39m ):\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py:764\u001b[39m, in \u001b[36m_SingleProcessDataLoaderIter._next_data\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 762\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[32m 763\u001b[39m index = \u001b[38;5;28mself\u001b[39m._next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m764\u001b[39m data = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[32m 765\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._pin_memory:\n\u001b[32m 766\u001b[39m data = _utils.pin_memory.pin_memory(data, \u001b[38;5;28mself\u001b[39m._pin_memory_device)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:55\u001b[39m, in \u001b[36m_MapDatasetFetcher.fetch\u001b[39m\u001b[34m(self, possibly_batched_index)\u001b[39m\n\u001b[32m 53\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 54\u001b[39m data = \u001b[38;5;28mself\u001b[39m.dataset[possibly_batched_index]\n\u001b[32m---> \u001b[39m\u001b[32m55\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mcollate_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile 
\u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/transformers/data/data_collator.py:272\u001b[39m, in \u001b[36mDataCollatorWithPadding.__call__\u001b[39m\u001b[34m(self, features)\u001b[39m\n\u001b[32m 271\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, features: List[Dict[\u001b[38;5;28mstr\u001b[39m, Any]]) -> Dict[\u001b[38;5;28mstr\u001b[39m, Any]:\n\u001b[32m--> \u001b[39m\u001b[32m272\u001b[39m batch = \u001b[43mpad_without_fast_tokenizer_warning\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 273\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtokenizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 274\u001b[39m \u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 275\u001b[39m \u001b[43m \u001b[49m\u001b[43mpadding\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 276\u001b[39m \u001b[43m \u001b[49m\u001b[43mmax_length\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmax_length\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 277\u001b[39m \u001b[43m \u001b[49m\u001b[43mpad_to_multiple_of\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpad_to_multiple_of\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 278\u001b[39m \u001b[43m \u001b[49m\u001b[43mreturn_tensors\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mreturn_tensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 279\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 280\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[33m\"\u001b[39m\u001b[33mlabel\u001b[39m\u001b[33m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m batch:\n\u001b[32m 281\u001b[39m 
batch[\u001b[33m\"\u001b[39m\u001b[33mlabels\u001b[39m\u001b[33m\"\u001b[39m] = batch[\u001b[33m\"\u001b[39m\u001b[33mlabel\u001b[39m\u001b[33m\"\u001b[39m]\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/transformers/data/data_collator.py:67\u001b[39m, in \u001b[36mpad_without_fast_tokenizer_warning\u001b[39m\u001b[34m(tokenizer, *pad_args, **pad_kwargs)\u001b[39m\n\u001b[32m 64\u001b[39m tokenizer.deprecation_warnings[\u001b[33m\"\u001b[39m\u001b[33mAsking-to-pad-a-fast-tokenizer\u001b[39m\u001b[33m\"\u001b[39m] = \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[32m 66\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m---> \u001b[39m\u001b[32m67\u001b[39m padded = \u001b[43mtokenizer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpad\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43mpad_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mpad_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 68\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m 69\u001b[39m \u001b[38;5;66;03m# Restore the state of the warning.\u001b[39;00m\n\u001b[32m 70\u001b[39m tokenizer.deprecation_warnings[\u001b[33m\"\u001b[39m\u001b[33mAsking-to-pad-a-fast-tokenizer\u001b[39m\u001b[33m\"\u001b[39m] = warning_state\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:3407\u001b[39m, in \u001b[36mPreTrainedTokenizerBase.pad\u001b[39m\u001b[34m(self, encoded_inputs, padding, max_length, pad_to_multiple_of, padding_side, return_attention_mask, return_tensors, verbose)\u001b[39m\n\u001b[32m 3404\u001b[39m batch_outputs[key] = []\n\u001b[32m 3405\u001b[39m batch_outputs[key].append(value)\n\u001b[32m-> \u001b[39m\u001b[32m3407\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mBatchEncoding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mtensor_type\u001b[49m\u001b[43m=\u001b[49m\u001b[43mreturn_tensors\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:241\u001b[39m, in \u001b[36mBatchEncoding.__init__\u001b[39m\u001b[34m(self, data, encoding, tensor_type, prepend_batch_axis, n_sequences)\u001b[39m\n\u001b[32m 237\u001b[39m n_sequences = encoding[\u001b[32m0\u001b[39m].n_sequences\n\u001b[32m 239\u001b[39m \u001b[38;5;28mself\u001b[39m._n_sequences = n_sequences\n\u001b[32m--> \u001b[39m\u001b[32m241\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mconvert_to_tensors\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensor_type\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtensor_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprepend_batch_axis\u001b[49m\u001b[43m=\u001b[49m\u001b[43mprepend_batch_axis\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:777\u001b[39m, in \u001b[36mBatchEncoding.convert_to_tensors\u001b[39m\u001b[34m(self, tensor_type, prepend_batch_axis)\u001b[39m\n\u001b[32m 774\u001b[39m value = [value]\n\u001b[32m 776\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_tensor(value):\n\u001b[32m--> \u001b[39m\u001b[32m777\u001b[39m tensor = \u001b[43mas_tensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 779\u001b[39m \u001b[38;5;66;03m# Removing this for now in favor of controlling the shape with `prepend_batch_axis`\u001b[39;00m\n\u001b[32m 780\u001b[39m \u001b[38;5;66;03m# # at-least2d\u001b[39;00m\n\u001b[32m 781\u001b[39m \u001b[38;5;66;03m# if tensor.ndim > 2:\u001b[39;00m\n\u001b[32m 782\u001b[39m \u001b[38;5;66;03m# tensor = tensor.squeeze(0)\u001b[39;00m\n\u001b[32m 783\u001b[39m \u001b[38;5;66;03m# elif tensor.ndim < 
2:\u001b[39;00m\n\u001b[32m 784\u001b[39m \u001b[38;5;66;03m# tensor = tensor[None, :]\u001b[39;00m\n\u001b[32m 786\u001b[39m \u001b[38;5;28mself\u001b[39m[key] = tensor\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/ADL/mase/.venv/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:739\u001b[39m, in \u001b[36mBatchEncoding.convert_to_tensors..as_tensor\u001b[39m\u001b[34m(value, dtype)\u001b[39m\n\u001b[32m 737\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(value, \u001b[38;5;28mlist\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(value[\u001b[32m0\u001b[39m], np.ndarray):\n\u001b[32m 738\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m torch.from_numpy(np.array(value))\n\u001b[32m--> \u001b[39m\u001b[32m739\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtensor\u001b[49m\u001b[43m(\u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[31mKeyboardInterrupt\u001b[39m: " + ] + } + ], "source": [ "import optuna\n", "\n", diff --git a/src/chop/actions/simulate.py b/src/chop/actions/simulate.py index da4547c86..47333260e 100644 --- a/src/chop/actions/simulate.py +++ b/src/chop/actions/simulate.py @@ -36,7 +36,7 @@ def simulate( skip_test: bool = False, trace_depth: int = 3, gui: bool = False, - waves: bool = False, + waves: bool = True, simulator: str = "verilator", build_jobs: int = 1, unroll_count: int = 1024, @@ -89,6 +89,7 @@ def simulate( hdl_toplevel="top", build_args=build_args, parameters=[], # use default parameters, + waves=waves, ) build_end = time.time() diff --git a/src/chop/nn/quantizers/block_minifloat.py b/src/chop/nn/quantizers/block_minifloat.py index 34e00bbcb..d26d36860 100644 --- a/src/chop/nn/quantizers/block_minifloat.py +++ b/src/chop/nn/quantizers/block_minifloat.py @@ -81,15 +81,7 @@ def forward( ) @staticmethod - def backward( - ctx, - grad_output: Tensor, - width: int, - exponent_width: int, 
- exponent_bias_width: int, - block_size: list[int] | int = [16], - skip_first_dim: bool = False, - ): + def backward(ctx, grad_output: Tensor): return grad_output, None, None, None, None, None diff --git a/src/chop/nn/quantizers/utils.py b/src/chop/nn/quantizers/utils.py index ae881328d..85d8fd1ca 100644 --- a/src/chop/nn/quantizers/utils.py +++ b/src/chop/nn/quantizers/utils.py @@ -89,7 +89,8 @@ class BinaryBipolarScaled(InplaceFunction): @staticmethod def alpha(tensor): # determine batch means absvalue = tensor.abs() - alpha = absvalue.mean(dim=(1, 2, 3), keepdims=True) + dims = tuple(range(1, tensor.ndim)) + alpha = absvalue.mean(dim=dims, keepdims=True) return alpha.view(-1, 1) @staticmethod @@ -100,9 +101,9 @@ def forward(ctx, input, _threshold): pos_one = torch.where(input > 0, 1.0, 0.0) neg_one = pos_one - 1 out = torch.add(pos_one, neg_one) - output = out * alpha.view(-1, 1, 1, 1).expand( - -1, input.size()[1], input.size()[2], input.size()[3] - ) + shape = [1] * input.ndim + shape[0] = -1 + output = out * alpha.view(shape).expand_as(input) return output @@ -155,7 +156,8 @@ class BinaryZeroScaled(InplaceFunction): @staticmethod def alpha(tensor): # determine batch means absvalue = tensor.abs() - alpha = absvalue.mean(dim=(1, 2, 3), keepdims=True) + dims = tuple(range(1, tensor.ndim)) + alpha = absvalue.mean(dim=dims, keepdims=True) return alpha.view(-1, 1) @staticmethod @@ -163,9 +165,9 @@ def forward(ctx, input, _threshold): alpha = BinaryZeroScaled.alpha(input) pos_one = torch.where(input > 0, 1.0, 0.0) - output = pos_one * alpha.view(-1, 1, 1, 1).expand( - -1, input.size()[1], input.size()[2], input.size()[3] - ) + shape = [1] * input.ndim + shape[0] = -1 + output = pos_one * alpha.view(shape).expand_as(input) return output @staticmethod diff --git a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py index 7e4c80e3c..0bd104613 100644 --- 
a/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py +++ b/src/chop/passes/graph/analysis/add_metadata/add_hardware_metadata.py @@ -14,7 +14,10 @@ from torch import nn -from .hardware_metadata_layers import INTERNAL_COMP +from .hardware_metadata_layers import ( + INTERNAL_COMP, + DRAM_INTERNAL_COMP, +) logger = logging.getLogger(__name__) @@ -29,7 +32,7 @@ def _cap(name): return str(name).upper() -def add_component_source(node): +def add_component_source(node, pass_args={}): if node.meta["mase"]["hardware"]["is_implicit"]: return @@ -47,13 +50,29 @@ def add_component_source(node): node.meta["mase"]["hardware"]["dependence_files"] = op_info[ "dependence_files" ] - elif mase_op in INTERNAL_COMP.keys(): + elif mase_op in INTERNAL_COMP.keys() or mase_op in DRAM_INTERNAL_COMP.keys(): node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_RTL" - # take the first ip in the component list by default - node.meta["mase"]["hardware"]["module"] = INTERNAL_COMP[mase_op][0]["name"] - node.meta["mase"]["hardware"]["dependence_files"] = INTERNAL_COMP[mase_op][0][ - "dependence_files" - ] + storage_type = pass_args.get("interface", {}).get("storage", "BRAM") + if storage_type == "DRAM": + if mase_op not in DRAM_INTERNAL_COMP.keys(): + assert ( + False + ), f"DRAM_INTERNAL_COMP does not define mase_op '{mase_op}'" + # In DRAM mode, select module/dependencies only from DRAM_INTERNAL_COMP. + node.meta["mase"]["hardware"]["module"] = DRAM_INTERNAL_COMP[mase_op][0][ + "name" + ] + node.meta["mase"]["hardware"]["dependence_files"] = DRAM_INTERNAL_COMP[ + mase_op + ][0]["dependence_files"] + else: + # BRAM/default path keeps existing INTERNAL_COMP behavior unchanged. 
+ node.meta["mase"]["hardware"]["module"] = INTERNAL_COMP[mase_op][0][ + "name" + ] + node.meta["mase"]["hardware"]["dependence_files"] = INTERNAL_COMP[ + mase_op + ][0]["dependence_files"] else: node.meta["mase"]["hardware"]["toolchain"] = "INTERNAL_HLS" node.meta["mase"]["hardware"]["module"] = None @@ -72,8 +91,9 @@ def add_component_source(node): ]: continue elif isinstance(arg_info, dict): + storage_type = pass_args.get("interface", {}).get("storage", "BRAM") node.meta["mase"]["hardware"]["interface"][arg] = { - "storage": "BRAM", + "storage": storage_type, "transpose": False, } else: @@ -472,7 +492,7 @@ def add_hardware_metadata_analysis_pass(graph, pass_args={}): # Add component source for node in graph.nodes: - add_component_source(node) + add_component_source(node, pass_args) # * Fix max parallelism to small value to enable verilator simulation # ! TO DO: enable this to be overriden by user diff --git a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py index b084821ae..8b40e07d6 100644 --- a/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py +++ b/src/chop/passes/graph/analysis/add_metadata/hardware_metadata_layers.py @@ -33,6 +33,44 @@ ], } +DRAM_INTERNAL_COMP = { + "linear": [ + { + "name": "fixed_linear", + "dependence_files": [ + "cast/rtl/fixed_cast.sv", + "linear_layers/fixed_operators/rtl/fixed_dot_product.sv", + "linear_layers/fixed_operators/rtl/fixed_vector_mult.sv", + "linear_layers/fixed_operators/rtl/fixed_accumulator.sv", + "linear_layers/fixed_operators/rtl/fixed_adder_tree.sv", + "linear_layers/fixed_operators/rtl/fixed_adder_tree_layer.sv", + "linear_layers/fixed_operators/rtl/fixed_mult.sv", + "common/rtl/register_slice.sv", + "common/rtl/join2.sv", + "memory/rtl/unpacked_repeat_circular_buffer.sv", + "memory/rtl/skid_buffer.sv", + "linear_layers/fixed_linear_layer/rtl/fixed_linear.sv", + 
"linear_layers/matmul/rtl/matrix_flatten.sv", + "linear_layers/matmul/rtl/matrix_unflatten.sv", + "linear_layers/matmul/rtl/matrix_fifo.sv", + "linear_layers/matmul/rtl/matrix_accumulator.sv", + "linear_layers/matmul/rtl/simple_matmul.sv", + "linear_layers/matmul/rtl/matmul.sv", + "linear_layers/matmul/rtl/transpose.sv", + "linear_layers/matmul/rtl/matrix_stream_transpose.sv", + ], + }, + ], + "relu": [ + { + "name": "fixed_relu", + "dependence_files": [ + "activation_layers/rtl/fixed_relu.sv", + ], + }, + ], +} + INTERNAL_COMP = { "linear": [ { diff --git a/src/chop/passes/graph/transforms/pruning/prune.py b/src/chop/passes/graph/transforms/pruning/prune.py index 373249706..24df84d4d 100644 --- a/src/chop/passes/graph/transforms/pruning/prune.py +++ b/src/chop/passes/graph/transforms/pruning/prune.py @@ -85,8 +85,9 @@ def fetch_info(node, module): a_value = node.meta["mase"].parameters["common"]["args"]["data_in_0"]["value"] a_shape = node.meta["mase"].parameters["common"]["args"]["data_in_0"]["shape"] - w_value = node.meta["mase"].parameters["common"]["args"]["weight"]["value"] - w_shape = node.meta["mase"].parameters["common"]["args"]["weight"]["shape"] + # Use actual module weight to preserve device placement + w_value = module.weight.data + w_shape = module.weight.shape out = { "module_type": "conv2d", @@ -112,8 +113,9 @@ def fetch_info(node, module): a_value = node.meta["mase"].parameters["common"]["args"]["data_in_0"]["value"] a_shape = node.meta["mase"].parameters["common"]["args"]["data_in_0"]["shape"] - w_value = node.meta["mase"].parameters["common"]["args"]["weight"]["value"] - w_shape = node.meta["mase"].parameters["common"]["args"]["weight"]["shape"] + # Use actual module weight to preserve device placement + w_value = module.weight.data + w_shape = module.weight.shape out = { "module_type": "linear", "weight_value": w_value, diff --git a/src/chop/passes/graph/transforms/pruning/pruning_methods.py 
b/src/chop/passes/graph/transforms/pruning/pruning_methods.py index 665abc17e..bccc144d4 100644 --- a/src/chop/passes/graph/transforms/pruning/pruning_methods.py +++ b/src/chop/passes/graph/transforms/pruning/pruning_methods.py @@ -92,7 +92,7 @@ def neurons_random_rank( :return: a sparsity mask :rtype: torch.Tensor """ - mask = torch.ones(tensor.size(), dtype=torch.bool) + mask = torch.ones(tensor.size(), dtype=torch.bool, device=tensor.device) mask = mask.reshape(tensor.shape[0], -1) if layer_type == "Linear": for i in range(tensor.shape[0]): @@ -115,7 +115,7 @@ def neurons_random_fan_in( ) -> torch.Tensor: if fan_in == None: raise ValueError("fan_in is not been specified") - mask = torch.zeros(tensor.size(), dtype=torch.bool) + mask = torch.zeros(tensor.size(), dtype=torch.bool, device=tensor.device) mask = mask.reshape(tensor.shape[0], -1) if layer_type == "Linear": for i in range(tensor.shape[0]): diff --git a/src/chop/passes/graph/transforms/verilog/dram_emit_tb.py b/src/chop/passes/graph/transforms/verilog/dram_emit_tb.py new file mode 100644 index 000000000..a5ba8d6f3 --- /dev/null +++ b/src/chop/passes/graph/transforms/verilog/dram_emit_tb.py @@ -0,0 +1,333 @@ +import logging, torch +from pathlib import Path +from textwrap import indent + +from chop.passes.graph.utils import vf, v2p, init_project +from chop.nn.quantizers import ( + integer_quantizer_for_hw, + integer_quantizer, +) + +logger = logging.getLogger(__name__) + +from pathlib import Path + +torch.manual_seed(0) + +import cocotb +from mase_cocotb.testbench import Testbench +from mase_cocotb.interfaces.streaming import StreamDriver, StreamMonitor + + +import dill +import inspect + + +def _cap(name): + """ + capitalize a string + """ + return str(name).upper() + + +def _emit_cocotb_test(graph, pass_args={}): + + wait_time = pass_args.get("wait_time", 2) + wait_unit = pass_args.get("wait_units", "ms") + batch_size = pass_args.get("batch_size", 1) + + test_template = f""" +import cocotb + 
+@cocotb.test() +async def test(dut): + from pathlib import Path + import dill + from cocotb.triggers import Timer + + tb_path = Path.home() / ".mase" / "top" / "hardware" / "test" / "mase_top_tb" + with open(tb_path / "tb_obj.dill", "rb") as f: + tb = dill.load(f)(dut, fail_on_checks=True) + + await tb.initialize() + + in_tensors = tb.generate_inputs(batches={batch_size}) + exp_out = tb.model(*list(in_tensors.values())) + + tb.load_drivers(in_tensors) + tb.load_monitors(exp_out) + + await tb.wait_end(timeout={wait_time}, timeout_unit="{wait_unit}") +""" + + tb_path = Path.home() / ".mase" / "top" / "hardware" / "test" / "mase_top_tb" + tb_path.mkdir(parents=True, exist_ok=True) + with open(tb_path / "test.py", "w") as f: + f.write(test_template) + + +def _emit_cocotb_tb(graph): + class MaseGraphTB(Testbench): + def __init__(self, dut, fail_on_checks=True): + super().__init__(dut, dut.clk, dut.rst, fail_on_checks=fail_on_checks) + + # Instantiate as many drivers as required inputs to the model + self.input_drivers = {} + self.output_monitors = {} + + for node in graph.nodes_in: + for arg in node.meta["mase"]["common"]["args"].keys(): + if "data_in" not in arg: + continue + self.input_drivers[arg] = StreamDriver( + dut.clk, + getattr(dut, arg), + getattr(dut, f"{arg}_valid"), + getattr(dut, f"{arg}_ready"), + ) + self.input_drivers[arg].log.setLevel(logging.DEBUG) + + # Instantiate as many monitors as required outputs + for node in graph.nodes_out: + for result in node.meta["mase"]["common"]["results"].keys(): + if "data_out" not in result: + continue + self.output_monitors[result] = StreamMonitor( + dut.clk, + getattr(dut, result), + getattr(dut, f"{result}_valid"), + getattr(dut, f"{result}_ready"), + check=False, + ) + self.output_monitors[result].log.setLevel(logging.DEBUG) + + self.model = graph.model + + # To do: precision per input argument + self.input_precision = graph.meta["mase"]["common"]["args"]["data_in_0"][ + "precision" + ] + + # Create StreamDriver 
instances for any DRAM weight/bias parameters. + # These become top-level input ports on the generated top.sv when + # add_hardware_metadata_analysis_pass is called with + # {"interface": {"storage": "DRAM"}}. + self.dram_drivers = {} + for node in graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + node_name = vf(node.name) + for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): + if "data_in" in arg or not isinstance(arg_info, dict): + continue + iface = node.meta["mase"].parameters["hardware"]["interface"].get(arg, {}) + if iface.get("storage") != "DRAM": + continue + # DRAM branch placeholder: + # customize DRAM-side driver type/protocol mapping here. + pass + port_name = f"{node_name}_{arg}" + self.dram_drivers[port_name] = StreamDriver( + dut.clk, + getattr(dut, port_name), + getattr(dut, f"{port_name}_valid"), + getattr(dut, f"{port_name}_ready"), + ) + self.dram_drivers[port_name].log.setLevel(logging.DEBUG) + + def generate_inputs(self, batches): + """ + Generate inputs for the model by sampling a random tensor + for each input argument, according to its shape + + :param batches: number of batches to generate for each argument + :type batches: int + :return: a dictionary of input arguments and their corresponding tensors + :rtype: Dict + """ + # ! 
TO DO: iterate through graph.args instead to generalize + inputs = {} + for node in graph.nodes_in: + for arg, arg_info in node.meta["mase"]["common"]["args"].items(): + # Batch dimension always set to 1 in metadata + if "data_in" not in arg: + continue + # print(f"Generating data for node {node}, arg {arg}: {arg_info}") + inputs[f"{arg}"] = torch.rand(([batches] + arg_info["shape"][1:])) + return inputs + + def load_drivers(self, in_tensors): + for arg, arg_batches in in_tensors.items(): + # DiffLogic: do not need precision, fully unrolled so 1 batch only + if "difflogic" in graph.nodes_in[0].meta["mase"]["hardware"]["module"]: + block = arg_batches[0].round().int().tolist() + if isinstance(block[0], list): + out = [] + for row in block: + num = "" + for i in range(len(row) - 1, -1, -1): + num += str(row[i]) + num = int(num, 2) + out.append(num) + else: + num = "" + for i in range(len(block) - 1, -1, -1): + num += str(block[i]) + num = int(num, 2) + out = [num] + self.input_drivers[arg].append(out) + continue + + # Quantize input tensor according to precision + if len(self.input_precision) > 1: + from mase_cocotb.utils import fixed_preprocess_tensor + + in_data_blocks = fixed_preprocess_tensor( + tensor=arg_batches, + q_config={ + "width": self.get_parameter(f"{_cap(arg)}_PRECISION_0"), + "frac_width": self.get_parameter( + f"{_cap(arg)}_PRECISION_1" + ), + }, + parallelism=[ + self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_1"), + self.get_parameter(f"{_cap(arg)}_PARALLELISM_DIM_0"), + ], + ) + + else: + # TO DO: convert to integer equivalent of floating point representation + pass + + # Append all input blocks to input driver + # ! 
TO DO: generalize + block_size = self.get_parameter( + "DATA_IN_0_PARALLELISM_DIM_0" + ) * self.get_parameter("DATA_IN_0_PARALLELISM_DIM_1") + for block in in_data_blocks: + if len(block) < block_size: + block = block + [0] * (block_size - len(block)) + self.input_drivers[arg].append(block) + + # Preload DRAM weight/bias drivers with quantised parameter blocks. + # Each parameter tensor is split into blocks matching the hardware's + # parallelism dimensions, exactly as BRAM would have stored them. + if self.dram_drivers: + from mase_cocotb.utils import fixed_preprocess_tensor + + for node in graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + node_name = vf(node.name) + for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): + if "data_in" in arg or not isinstance(arg_info, dict): + continue + iface = node.meta["mase"].parameters["hardware"]["interface"].get(arg, {}) + if iface.get("storage") != "DRAM": + continue + # DRAM branch placeholder: + # customize parameter block ordering/packing for DRAM sources. 
+ pass + port_name = f"{node_name}_{arg}" + if port_name not in self.dram_drivers: + continue + + param_tensor = node.meta["mase"].module.get_parameter(arg) + arg_cap = _cap(arg) + parallelism_0 = self.get_parameter( + f"{node_name}_{arg_cap}_PARALLELISM_DIM_0" + ) + parallelism_1 = self.get_parameter( + f"{node_name}_{arg_cap}_PARALLELISM_DIM_1" + ) + param_blocks = fixed_preprocess_tensor( + tensor=param_tensor, + q_config={ + "width": self.get_parameter( + f"{node_name}_{arg_cap}_PRECISION_0" + ), + "frac_width": self.get_parameter( + f"{node_name}_{arg_cap}_PRECISION_1" + ), + }, + parallelism=[parallelism_1, parallelism_0], + ) + block_size = parallelism_0 * parallelism_1 + for block in param_blocks: + if len(block) < block_size: + block = block + [0] * (block_size - len(block)) + self.dram_drivers[port_name].append(block) + + def load_monitors(self, expectation): + # DiffLogic: do not need precision, fully unrolled so 1 batch only + if "difflogic" in graph.nodes_out[0].meta["mase"]["hardware"]["module"]: + self.output_monitors["data_out_0"].expect(expectation) + self.output_monitors["data_out_0"].in_flight = True + return + + from mase_cocotb.utils import fixed_preprocess_tensor + + # Process the expectation tensor + output_blocks = fixed_preprocess_tensor( + tensor=expectation, + q_config={ + "width": self.get_parameter(f"DATA_OUT_0_PRECISION_0"), + "frac_width": self.get_parameter(f"DATA_OUT_0_PRECISION_1"), + }, + parallelism=[ + self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_1"), + self.get_parameter(f"DATA_OUT_0_PARALLELISM_DIM_0"), + ], + ) + + # Set expectation for each monitor + for block in output_blocks: + # ! 
TO DO: generalize to multi-output models + if len(block) < self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0"): + block = block + [0] * ( + self.get_parameter("DATA_OUT_0_PARALLELISM_DIM_0") - len(block) + ) + self.output_monitors["data_out_0"].expect(block) + + # Drive the in-flight flag for each monitor + self.output_monitors["data_out_0"].in_flight = True + + # Serialize testbench object to be instantiated within test by cocotb runner + cls_obj = MaseGraphTB + tb_path = Path.home() / ".mase" / "top" / "hardware" / "test" / "mase_top_tb" + tb_path.mkdir(parents=True, exist_ok=True) + with open(tb_path / "tb_obj.dill", "wb") as file: + dill.dump(cls_obj, file) + with open(tb_path / "__init__.py", "w") as file: + file.write("from .test import test") + + +def emit_cocotb_transform_pass(graph, pass_args={}): + """ + Emit test bench and related files for simulation + + :param graph: a MaseGraph + :type graph: MaseGraph + :param pass_args: this pass requires additional arguments which is explained below, defaults to {} + :type pass_args: _type_, optional + :return: return a tuple of a MaseGraph and an empty dict (no additional info to return) + :rtype: tuple(MaseGraph, Dict) + + - pass_args + - project_dir -> str : the directory of the project + - trace -> bool : trace waves in the simulation + """ + logger.info("Emitting testbench...") + project_dir = ( + pass_args["project_dir"] + if "project_dir" in pass_args.keys() + else Path.home() / ".mase" / "top" + ) + + init_project(project_dir) + + _emit_cocotb_test(graph, pass_args=pass_args) + _emit_cocotb_tb(graph) + + return graph, None diff --git a/src/chop/passes/graph/transforms/verilog/emit_bram.py b/src/chop/passes/graph/transforms/verilog/emit_bram.py index 4f9190fce..2b5e449d9 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_bram.py +++ b/src/chop/passes/graph/transforms/verilog/emit_bram.py @@ -240,7 +240,8 @@ def emit_parameters_in_dat_internal(node, param_name, file_name): "precision" ][1] - if 
node.meta["mase"].module.config.get("floor", False): + module_config = getattr(node.meta["mase"].module, "config", None) or {} + if module_config.get("floor", False): base_quantizer = integer_floor_quantizer_for_hw else: base_quantizer = integer_quantizer_for_hw @@ -365,8 +366,31 @@ def emit_bram_handshake(node, rtl_dir): ) emit_parameters_in_mem_internal(node, param_name, verilog_name, data_name) emit_parameters_in_dat_internal(node, param_name, data_name) + elif ( + node.meta["mase"].parameters["hardware"]["interface"][param_verilog_name][ + "storage" + ] + == "DRAM" + ): + # DRAM branch placeholder: + # emit/copy DRAM-side helper SV files (controller adapter stubs, etc.) if needed. + # + pass + logger.debug( + f"Skipping DAT emission for node: {node_name}, parameter: {param_verilog_name} (DRAM)" + ) + # Remove stale BRAM source files from any previous BRAM runs so + # verilator does not attempt to compile them. + for stale in [ + os.path.join(rtl_dir, f"{node_name}_{param_verilog_name}_source.sv"), + os.path.join(rtl_dir, f"{node_name}_{param_verilog_name}_rom.dat"), + ]: + if os.path.exists(stale): + os.remove(stale) + logger.debug(f"Removed stale BRAM file: {stale}") + else: - assert False, "Emtting parameters in non-BRAM hardware is not supported." + assert False, "Emitting parameters in non-BRAM/DRAM hardware is not supported." 
def emit_parameters_in_mem_hls(node, param_name, file_name, data_name): @@ -468,15 +492,19 @@ def emit_bram_hls(node, rtl_dir): """ node_name = vf(node.name) for param_name, parameter in node.meta["mase"].module.named_parameters(): - if ( - node.meta["mase"].parameters["hardware"]["interface"][param_name]["storage"] - == "BRAM" - ): + storage = node.meta["mase"].parameters["hardware"]["interface"][param_name][ + "storage" + ] + if storage == "BRAM": # Verilog code of the ROM has been emitted using mlir passes verilog_name = os.path.join(rtl_dir, f"{node_name}_{param_name}.sv") data_name = os.path.join(rtl_dir, f"{node_name}_{param_name}_rom.dat") emit_parameters_in_mem_hls(node, param_name, verilog_name, data_name) emit_parameters_in_dat_hls(node, param_name, data_name) + elif storage == "DRAM": + # DRAM branch placeholder: + # support DRAM-backed parameter path for HLS components. + pass else: assert False, "Emtting parameters in non-BRAM hardware is not supported." diff --git a/src/chop/passes/graph/transforms/verilog/emit_internal.py b/src/chop/passes/graph/transforms/verilog/emit_internal.py index 922af30d5..57301000e 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_internal.py +++ b/src/chop/passes/graph/transforms/verilog/emit_internal.py @@ -35,6 +35,11 @@ def emit_internal_rtl_transform_pass(graph, pass_args={}): for node in graph.fx_graph.nodes: if node.meta["mase"].parameters["hardware"]["is_implicit"]: continue + interface = node.meta["mase"].parameters["hardware"].get("interface", {}) + if any(v.get("storage") == "DRAM" for v in interface.values() if isinstance(v, dict)): + # DRAM branch placeholder: + # inject DRAM-specific dependency handling before file inclusion. 
+ pass if "INTERNAL_RTL" == node.meta["mase"].parameters["hardware"]["toolchain"]: if ( hasattr(node.meta["mase"].module, "config") diff --git a/src/chop/passes/graph/transforms/verilog/emit_tb.py b/src/chop/passes/graph/transforms/verilog/emit_tb.py index e8c759d80..ae626dc78 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_tb.py +++ b/src/chop/passes/graph/transforms/verilog/emit_tb.py @@ -68,8 +68,119 @@ async def test(dut): def _emit_cocotb_tb(graph): class MaseGraphTB(Testbench): + def _dram_storage_enabled(self): + """Return True if any argument is explicitly marked with DRAM storage.""" + observed_interfaces = [] + interface_count = 0 + + for node in graph.fx_graph.nodes: + if not hasattr(node, "meta") or "mase" not in node.meta: + continue + + node_name = vf(getattr(node, "name", "unknown_node")) + node_mase = node.meta["mase"] + # Metadata can be stored both in object-style + # node.meta["mase"].parameters[...] and dict-style + # node.meta["mase"]["hardware"][...]. Check both. 
+ node_params = getattr(node_mase, "parameters", None) + if node_params is None and isinstance(node_mase, dict): + node_params = node_mase.get("parameters", {}) + + hardware_from_params = ( + node_params.get("hardware", {}) if isinstance(node_params, dict) else {} + ) + hardware_from_dict = ( + node_mase.get("hardware", {}) if isinstance(node_mase, dict) else {} + ) + + is_implicit = hardware_from_params.get( + "is_implicit", hardware_from_dict.get("is_implicit", False) + ) + if is_implicit: + self._sim_log.info( + "DRAM detect: skip implicit node '%s'", node_name + ) + continue + + interfaces = hardware_from_params.get("interface", {}) + source = "parameters.hardware.interface" + if not interfaces: + interfaces = hardware_from_dict.get("interface", {}) + source = "mase.hardware.interface" + + if not interfaces: + self._sim_log.info( + "DRAM detect: node '%s' has no interface metadata", node_name + ) + continue + + for iface_name, interface_cfg in interfaces.items(): + if not isinstance(interface_cfg, dict): + self._sim_log.info( + "DRAM detect: node '%s' iface '%s' from %s has non-dict config", + node_name, + iface_name, + source, + ) + continue + + storage = interface_cfg.get("storage", "") + interface_count += 1 + if len(observed_interfaces) < 24: + observed_interfaces.append(f"{node_name}.{iface_name}:{storage}") + + self._sim_log.info( + "DRAM detect: node='%s' iface='%s' storage='%s' source=%s", + node_name, + iface_name, + storage, + source, + ) + + if storage == "DRAM": + self._sim_log.info( + "DRAM detect: selected DRAM mode from node '%s' iface '%s'", + node_name, + iface_name, + ) + return True + + if interface_count == 0: + self._sim_log.warning( + "DRAM detect: no interface entries were found; defaulting to BRAM mode" + ) + else: + self._sim_log.warning( + "DRAM detect: scanned %d interface entries, no DRAM found. 
Sample: %s", + interface_count, + ", ".join(observed_interfaces), + ) + + return False + + def _discover_dram_param_specs_from_dut(self, dut): + """Discover streamable parameter ports by probing DUT signal names. + + This is a robust fallback when graph interface metadata is absent in + serialized testbench objects. + """ + specs = {} + for full_name, param_tensor in self.model.named_parameters(): + if "." not in full_name: + continue + module_name, arg = full_name.rsplit(".", 1) + node_name = vf(module_name) + port_name = f"{node_name}_{arg}" + required = [port_name, f"{port_name}_valid", f"{port_name}_ready"] + missing = [sig for sig in required if not hasattr(dut, sig)] + if missing: + continue + specs[port_name] = (node_name, arg, param_tensor) + return specs + def __init__(self, dut, fail_on_checks=True): super().__init__(dut, dut.clk, dut.rst, fail_on_checks=fail_on_checks) + self._sim_log = getattr(dut, "_log", logger) # Instantiate as many drivers as required inputs to the model self.input_drivers = {} @@ -108,6 +219,51 @@ def __init__(self, dut, fail_on_checks=True): "precision" ] + # Create StreamDriver instances for streamed parameter ports. 
+ # Map model parameters to top-level DUT ports by name, e.g.: + # fc1.weight -> fc1_weight / fc1_weight_valid / fc1_weight_ready + self.dram_drivers = {} + self.dram_param_specs = {} + metadata_dram_mode = self._dram_storage_enabled() + discovered_specs = self._discover_dram_param_specs_from_dut(dut) + + self.dram_mode = metadata_dram_mode or bool(discovered_specs) + if not metadata_dram_mode and discovered_specs: + self._sim_log.warning( + "DRAM detect: metadata did not expose DRAM interfaces, " + "but discovered %d DRAM-style DUT parameter ports; enabling DRAM mode", + len(discovered_specs), + ) + + if self.dram_mode: + for port_name, (node_name, arg, param_tensor) in discovered_specs.items(): + + self.dram_drivers[port_name] = StreamDriver( + dut.clk, + getattr(dut, port_name), + getattr(dut, f"{port_name}_valid"), + getattr(dut, f"{port_name}_ready"), + ) + self.dram_param_specs[port_name] = (node_name, arg, param_tensor) + self.dram_drivers[port_name].log.setLevel(logging.DEBUG) + self._sim_log.info( + "Bound parameter StreamDriver for DUT port '%s'", port_name + ) + + logger.debug("Discovered %d streamed parameter ports", len(self.dram_drivers)) + + if self.dram_drivers: + self._sim_log.info( + "DRAM drivers enabled for ports: %s", + ", ".join(sorted(self.dram_drivers.keys())), + ) + else: + self._sim_log.warning( + "DRAM mode enabled, but no parameter stream ports were discovered" + ) + else: + self._sim_log.info("BRAM mode detected; skipping DRAM parameter stream setup") + def generate_inputs(self, batches): """ Generate inputs for the model by sampling a random tensor @@ -130,6 +286,60 @@ def generate_inputs(self, batches): return inputs def load_drivers(self, in_tensors): + # Preload DRAM parameter drivers with quantized parameter blocks. 
+ if self.dram_mode and self.dram_drivers: + from mase_cocotb.utils import fixed_preprocess_tensor + + self._sim_log.info("Preloading DRAM parameter streams for %d ports", len(self.dram_drivers)) + total_blocks_queued = 0 + + for port_name, (node_name, arg, param_tensor) in self.dram_param_specs.items(): + + arg_cap = _cap(arg) + parallelism_0 = self.get_parameter( + f"{node_name}_{arg_cap}_PARALLELISM_DIM_0" + ) + parallelism_1 = self.get_parameter( + f"{node_name}_{arg_cap}_PARALLELISM_DIM_1" + ) + width = self.get_parameter(f"{node_name}_{arg_cap}_PRECISION_0") + frac_width = self.get_parameter(f"{node_name}_{arg_cap}_PRECISION_1") + + param_blocks = fixed_preprocess_tensor( + tensor=param_tensor, + q_config={ + "width": width, + "frac_width": frac_width, + }, + parallelism=[parallelism_1, parallelism_0], + ) + block_size = parallelism_0 * parallelism_1 + port_blocks_queued = 0 + for block in param_blocks: + if len(block) < block_size: + block = block + [0] * (block_size - len(block)) + self.dram_drivers[port_name].append(block) + port_blocks_queued += 1 + + total_blocks_queued += port_blocks_queued + self._sim_log.info( + "Queued %d DRAM blocks for port '%s' (block_size=%d)", + port_blocks_queued, + port_name, + block_size, + ) + + self._sim_log.info("Queued %d DRAM parameter blocks in total", total_blocks_queued) + assert total_blocks_queued > 0, ( + "DRAM mode detected, but zero DRAM parameter blocks were queued. " + "Check DRAM interface metadata/parameter extraction." 
+ ) + elif self.dram_mode: + self._sim_log.warning("DRAM mode enabled, but no DRAM parameter streams to preload") + else: + self._sim_log.info("No DRAM parameter streams to preload") + + self._sim_log.info("Loading input drivers for %d tensor inputs", len(in_tensors)) for arg, arg_batches in in_tensors.items(): # DiffLogic: do not need precision, fully unrolled so 1 batch only if "difflogic" in graph.nodes_in[0].meta["mase"]["hardware"]["module"]: diff --git a/src/chop/passes/graph/transforms/verilog/emit_top.py b/src/chop/passes/graph/transforms/verilog/emit_top.py index 5f0bdd94a..a9d89c284 100644 --- a/src/chop/passes/graph/transforms/verilog/emit_top.py +++ b/src/chop/passes/graph/transforms/verilog/emit_top.py @@ -194,7 +194,39 @@ def emit(self, graph, parameter_map): input data_out_{i}_ready,""" i += 1 - # TODO: emit off-chip parameter interface + # Emit DRAM parameter ports for off-chip storage + for node in self.graph.fx_graph.nodes: + if not hasattr(node, "meta") or "mase" not in node.meta: + continue + node_params = node.meta["mase"].parameters + hardware = node_params.get("hardware", {}) + if hardware.get("is_implicit", False): + continue + if "INTERNAL" not in hardware.get("toolchain", ""): + continue + node_name = vf(node.name) + for arg, arg_info in node_params.get("common", {}).get("args", {}).items(): + if "data_in" in arg or not isinstance(arg_info, dict): + continue + if hardware.get("interface", {}).get(arg, {}).get("storage", "BRAM") != "DRAM": + continue + interface += "\n // this is for DRAM" + interface += "\n // DRAM parameter streaming protocol:" + interface += "\n // - carries one packed beat of parameter elements" + interface += "\n // - _valid asserted by upstream DRAM adapter when beat is valid" + interface += "\n // - _ready asserted by compute when beat is consumed" + interface += "\n // - transfer occurs when both valid and ready are high" + interface += "\n // Future extension point: add sideband ports (addr/last/id/user) if needed" 
+ arg_name = v2p(arg) + parallelism_params = [ + param + for param in parameter_map + if f"{node_name}_{arg_name}_PARALLELISM_DIM" in param + ] + interface += f""" + input [{node_name}_{arg_name}_PRECISION_0-1:0] {node_name}_{arg} [{'*'.join(parallelism_params)}-1:0], + input {node_name}_{arg}_valid, + output {node_name}_{arg}_ready,""" return _remove_last_comma(interface) @@ -433,6 +465,21 @@ def _emit_getitem_signals(self, node): def emit(self, node, parameter_map): node_name = vf(node.name) component_name = node.meta["mase"].parameters["hardware"]["module"] + + # # For staged bring-up, allow fc1 to target a DRAM-specialized module name. + # if node_name == "fc1": + # has_dram_param = any( + # ( + # isinstance(arg_info, dict) + # and "data_in" not in arg + # and node.meta["mase"].parameters["hardware"]["interface"][arg]["storage"] + # == "DRAM" + # ) + # for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items() + # ) + # if has_dram_param: + # component_name = f"{component_name}_dram" + signals = "" # Emit component instantiation parameters @@ -506,12 +553,20 @@ def emit(self, node, parameter_map): # if node.meta["mase"]["hardware"].get("module") in ["fixed_difflogic_logic", "fixed_difflogic_logic"]: # return components - # Emit module parameter instances (e.g. weights and biases) + # Emit module parameter instances (e.g. weights and biases) - BRAM only for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): if "data_in" in arg: continue if not isinstance(arg_info, dict): continue + # DRAM parameters are driven from top-level ports, not internal BRAM sources + if node.meta["mase"].parameters["hardware"]["interface"][arg]["storage"] == "DRAM": + components += f""" +// {node_name}.{arg}: DRAM-backed parameter +// Protocol owner is top-level integration (e.g. DRAM/AXI adapter), +// so no internal *_source BRAM module is instantiated here. 
+""" + continue components += self._emit_module_parameters_top_internal( arg, arg_info, node, parameter_map @@ -698,6 +753,29 @@ def _emit_top_wires(self): i += 1 # TODO: emit off-chip parameter interface + has_dram_params = False + for node in self.graph.fx_graph.nodes: + if node.meta["mase"].parameters["hardware"]["is_implicit"]: + continue + for arg, arg_info in node.meta["mase"].parameters["common"]["args"].items(): + if "data_in" in arg or not isinstance(arg_info, dict): + continue + if node.meta["mase"].parameters["hardware"]["interface"][arg]["storage"] == "DRAM": + has_dram_params = True + break + if has_dram_params: + break + if has_dram_params: + wires += """ +// -------------------------- +// DRAM protocol notes +// -------------------------- +// DRAM-backed parameter ports are connected directly at top-level. +// Beat transfer rule: param_valid && param_ready. +// If a memory adapter is required (AXI, burst unpacking, address generation), +// instantiate it in the top-level integration layer and drive the emitted +// _ / __valid / __ready ports. +""" return wires diff --git a/src/chop/passes/graph/transforms/verilog/util.py b/src/chop/passes/graph/transforms/verilog/util.py index a0414f9e6..44cef273b 100644 --- a/src/chop/passes/graph/transforms/verilog/util.py +++ b/src/chop/passes/graph/transforms/verilog/util.py @@ -94,4 +94,10 @@ def include_ip_to_project(node): Copy internal files to the project """ mase_op = node.meta["mase"].parameters["common"]["mase_op"] + interface = node.meta["mase"].parameters["hardware"].get("interface", {}) + if any(v.get("storage") == "DRAM" for v in interface.values() if isinstance(v, dict)): + # DRAM branch placeholder: + # select an alternate dependency file list for DRAM-specific SV templates. + # Current behavior intentionally returns default dependence_files. 
+ pass return node.meta["mase"].parameters["hardware"]["dependence_files"] diff --git a/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py b/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py index e10e61c3e..1b6a532e7 100644 --- a/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py +++ b/test/passes/graph/transforms/verilog/test_emit_verilog_linear.py @@ -116,7 +116,7 @@ def test_emit_verilog_linear(): mg, _ = passes.emit_bram_transform_pass(mg) mg, _ = passes.emit_internal_rtl_transform_pass(mg) mg, _ = passes.emit_cocotb_transform_pass( - mg, pass_args={"wait_time": 100, "wait_unit": "ms", "batch_size": batch_size} + mg, pass_args={"wait_time": 2, "wait_units": "ms", "batch_size": batch_size} ) mg, _ = passes.emit_vivado_project_transform_pass(mg)