Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions weightslab/examples/PyTorch/ws-classification/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _build_filepath_mapping(self):
# For each index, construct a meaningful filepath
# MNIST doesn't have original individual files, so we create virtual paths
for idx in range(len(self.mnist)):
if self.max_samples is not None and idx >= self.max_samples:
if self.max_samples != None and idx >= self.max_samples:
break
label = self.mnist.targets[idx].item() if hasattr(self.mnist.targets[idx], 'item') else self.mnist.targets[idx]
split = 'train' if self.train else 'test'
Expand All @@ -105,7 +105,7 @@ def _build_filepath_mapping(self):
self.filepaths[idx] = virtual_path

def __len__(self):
if self.max_samples is not None:
if self.max_samples != None:
return min(len(self.mnist), self.max_samples)
return len(self.mnist)

Expand Down Expand Up @@ -371,15 +371,15 @@ def test(loader, model, criterion_mlt, metric_mlt, device, test_loader_len):
# Setup clean progress bar with custom format
if tqdm_display:
train_range = tqdm.tqdm(
range(training_steps_to_do) if training_steps_to_do is not None else itertools.count(),
range(training_steps_to_do) if training_steps_to_do != None else itertools.count(),
desc="Training",
bar_format="{desc}: {n}/{total} [{elapsed}<{remaining}, {rate_fmt}] {bar} | {postfix}",
ncols=140,
position=0,
leave=True
)
else:
train_range = range(training_steps_to_do) if training_steps_to_do is not None else itertools.count()
train_range = range(training_steps_to_do) if training_steps_to_do != None else itertools.count()

# ================
# Training Loop
Expand Down
2 changes: 1 addition & 1 deletion weightslab/examples/PyTorch/ws-detection/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(
val_set = set(all_imgs[::k])
selected = [f for f in all_imgs if f not in val_set]

selected = selected[:max_samples] if max_samples is not None else selected
selected = selected[:max_samples] if max_samples != None else selected

self.images = []
self.masks = []
Expand Down
2 changes: 1 addition & 1 deletion weightslab/examples/PyTorch/ws-segmentation/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
for f in os.listdir(img_dir)
if f.lower().endswith((".jpg", ".jpeg", ".png"))
]
image_files = sorted(set(image_files))[:max_samples] if max_samples is not None else sorted(set(image_files)) # Optionally limit number of samples for faster testing
image_files = sorted(set(image_files))[:max_samples] if max_samples != None else sorted(set(image_files)) # Optionally limit number of samples for faster testing

self.images = []
self.masks = []
Expand Down
2 changes: 1 addition & 1 deletion weightslab/examples/Ultralytics/ws-detection/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def main():
trainer=WLAwareTrainer,
data=data_root,
imgsz=image_size,
epochs=1000 if max_steps is None else max(1, int(max_steps)),
epochs=1000 if max_steps == None else max(1, int(max_steps)),
device=device,
project=project, name=name, # → UL save_dir → WL logger log_dir/name
resume=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __init__(
else:
val_set = set(frames[::k])
selected = [f for f in frames if f not in val_set]
self.frames = selected[:max_samples] if max_samples is not None else selected
self.frames = selected[:max_samples] if max_samples != None else selected

def __len__(self):
return len(self.frames)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def __init__(
else:
val_set = set(frames[::k])
selected = [f for f in frames if f not in val_set]
self.frames = selected[:max_samples] if max_samples is not None else selected
self.frames = selected[:max_samples] if max_samples != None else selected

if len(self.frames) == 0:
raise RuntimeError(f"No LiDAR frames found (source={source}, root={root})")
Expand Down
2 changes: 2 additions & 0 deletions weightslab/proto/experiment_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@ message HistogramResponse {
string message = 2;
int64 total_rows = 3; // rows in the view that were binned
repeated HistogramBin bins = 4;
}

// --- Metadata retrieval (separated from GetDataSamples) ---
message GetMetaDataRequest {
int32 start_index = 1; // grid slice start (current view order)
Expand Down
247 changes: 70 additions & 177 deletions weightslab/proto/experiment_service_pb2.py

Large diffs are not rendered by default.

40 changes: 35 additions & 5 deletions weightslab/proto/experiment_service_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from weightslab.proto import experiment_service_pb2 as weightslab_dot_proto_dot_experiment__service__pb2

GRPC_GENERATED_VERSION = '1.81.1'
GRPC_GENERATED_VERSION = '1.76.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

Expand All @@ -25,7 +25,7 @@
)


class ExperimentServiceStub:
class ExperimentServiceStub(object):
"""Missing associated documentation comment in .proto file."""

def __init__(self, channel):
Expand Down Expand Up @@ -78,6 +78,7 @@ def __init__(self, channel):
'/ExperimentService/GetHistogram',
request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramRequest.SerializeToString,
response_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramResponse.FromString,
_registered_method=True)
self.GetMetaData = channel.unary_unary(
'/ExperimentService/GetMetaData',
request_serializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.SerializeToString,
Expand Down Expand Up @@ -145,7 +146,7 @@ def __init__(self, channel):
_registered_method=True)


class ExperimentServiceServicer:
class ExperimentServiceServicer(object):
"""Missing associated documentation comment in .proto file."""

def GetLatestLoggerData(self, request, context):
Expand Down Expand Up @@ -199,6 +200,11 @@ def GetDataSamples(self, request, context):

def GetHistogram(self, request, context):
"""Server-side histogram binning of one metadata/signal column.
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetMetaData(self, request, context):
"""Metadata-only retrieval (dataframe columns). Returns every metadata column
name for the WHOLE dataset, the current grid slice's per-sample metadata, and
Expand Down Expand Up @@ -332,6 +338,7 @@ def add_ExperimentServiceServicer_to_server(servicer, server):
servicer.GetHistogram,
request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramRequest.FromString,
response_serializer=weightslab_dot_proto_dot_experiment__service__pb2.HistogramResponse.SerializeToString,
),
'GetMetaData': grpc.unary_unary_rpc_method_handler(
servicer.GetMetaData,
request_deserializer=weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.FromString,
Expand Down Expand Up @@ -405,7 +412,7 @@ def add_ExperimentServiceServicer_to_server(servicer, server):


# This class is part of an EXPERIMENTAL API.
class ExperimentService:
class ExperimentService(object):
"""Missing associated documentation comment in .proto file."""

@staticmethod
Expand Down Expand Up @@ -626,7 +633,6 @@ def GetDataSamples(request,

@staticmethod
def GetHistogram(request,
def GetMetaData(request,
target,
options=(),
channel_credentials=None,
Expand All @@ -642,6 +648,30 @@ def GetMetaData(request,
'/ExperimentService/GetHistogram',
weightslab_dot_proto_dot_experiment__service__pb2.HistogramRequest.SerializeToString,
weightslab_dot_proto_dot_experiment__service__pb2.HistogramResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@staticmethod
def GetMetaData(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/ExperimentService/GetMetaData',
weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataRequest.SerializeToString,
weightslab_dot_proto_dot_experiment__service__pb2.GetMetaDataResponse.FromString,
Expand Down