Intended for use in a client-server setup: the SamEncoder.kt interface queries image embeddings from a Triton server over gRPC.
The returned embeddings are then passed to a SamDecoder.kt, along with a prompt, to run the decoding locally.
The use case is to offload the slower, more expensive image embedding to a server while running the decoder locally, which allows for real-time prompt refinement.
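The encode-once, decode-many pattern described above can be sketched as follows. This is only an illustration under stated assumptions: `Embedding`, `PointPrompt`, `RemoteEncoder`, `LocalDecoder`, and `refine` are hypothetical stand-ins invented for this sketch and are not the actual SamEncoder.kt or SamDecoder.kt interfaces.

```kotlin
// Hypothetical stand-in types for illustration only.
// They are NOT the library's SamEncoder/SamDecoder API.
class Embedding(val values: FloatArray)

data class PointPrompt(val x: Float, val y: Float)

// Stands in for the expensive, server-side embedding step.
fun interface RemoteEncoder {
    fun encode(imageId: String): Embedding
}

// Stands in for the cheap, local decoding step.
fun interface LocalDecoder {
    fun decode(embedding: Embedding, prompt: PointPrompt): FloatArray
}

/**
 * Encodes the image once, then decodes every prompt against the same
 * embedding, so refining a prompt never triggers another server call.
 */
fun refine(
    encoder: RemoteEncoder,
    decoder: LocalDecoder,
    imageId: String,
    prompts: List<PointPrompt>,
): List<FloatArray> {
    val embedding = encoder.encode(imageId) // single server round trip
    return prompts.map { prompt -> decoder.decode(embedding, prompt) } // local calls only
}
```

With any number of prompts, `refine` calls the encoder exactly once. That is the property that makes interactive prompt refinement cheap when the embedding step runs remotely.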
The available encoder models depend on what the Triton server endpoint supports. However, the decoder model
must match the encoder model. Below is the current list of provided decoders and their matching encoder
models. All decoder models are ONNX models.
The sam_vit_h_4b8983 model is supported by Sam1Decoder.kt.
The sam2.1_hiera_large_decoder model is supported by Sam2Decoder.kt.
For SAM 3, only the tracker model is currently supported. An fp16 variant is provided as sam3_tracker_fp16_decoder and
is supported by Sam3Decoder.kt.
See licenses_readme.md for license information on the provided models.
Encode against a Triton server, decode locally, threshold the logits at 0, and write the binary
mask to disk:
import java.awt.image.BufferedImage
import java.awt.image.DataBufferByte
import java.io.File
import javax.imageio.ImageIO
import kotlinx.coroutines.runBlocking
/* Library imports (Sam2TritonEncoder, Sam2Decoder, DecoderModel, SamPrompt, SamPointLabel) are omitted here. */

/* Any BufferedImage will do */
val image = ImageIO.read(File("input.png"))

/* Point the encoder at the server */
val encoder = Sam2TritonEncoder(
    host = "triton.local",
    port = 8001,
    model = "sam2.1_hiera_large_encoder"
)
val decoder = Sam2Decoder(DecoderModel.SAM2.load())

/* use {} closes the encoder, decoder, and embedding when the block exits */
encoder.use { enc ->
    decoder.use { dec ->
        runBlocking {
            enc.encode(image).use { embedding ->
                /* Adds a FOREGROUND point at the center of the image. */
                val prompt = SamPrompt().addPoint(
                    image.width / 2f,
                    image.height / 2f,
                    SamPointLabel.FOREGROUND,
                )
                val result = dec.decode(embedding, prompt)
                val mask = result.bestMask
                val size = result.maskSize
                val out = BufferedImage(size, size, BufferedImage.TYPE_BYTE_GRAY)
                val pixels = (out.raster.dataBuffer as DataBufferByte).data
                /* Threshold the mask: logits above 0 become white (255), the rest black (0) */
                for (i in mask.indices)
                    pixels[i] = if (mask[i] > 0f) 255.toByte() else 0
                ImageIO.write(out, "png", File("mask.png"))
            }
        }
    }
}
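The thresholding step in the example above can be pulled into a small standalone helper. The sketch below is not part of the library: `logitsToMaskImage` and `logitThreshold` are hypothetical names, and it assumes the mask is a row-major FloatArray of logits. Because sigmoid(0) = 0.5, thresholding the logits at 0 is the same as keeping pixels whose foreground probability exceeds 0.5. A different probability cutoff p corresponds to the logit threshold ln(p / (1 - p)).

```kotlin
import java.awt.image.BufferedImage
import java.awt.image.DataBufferByte
import kotlin.math.ln

/** Hypothetical helper: converts a probability cutoff in (0, 1) to the equivalent logit threshold. */
fun logitThreshold(probability: Double): Float {
    require(probability > 0.0 && probability < 1.0) { "probability must be in (0, 1)" }
    return ln(probability / (1.0 - probability)).toFloat()
}

/**
 * Hypothetical helper: turns row-major mask logits into an 8-bit grayscale image
 * where pixels above the threshold are white (255) and all others are black (0).
 */
fun logitsToMaskImage(
    logits: FloatArray,
    width: Int,
    height: Int,
    threshold: Float = 0f, // 0 in logit space corresponds to probability 0.5
): BufferedImage {
    require(logits.size == width * height) { "expected ${width * height} logits, got ${logits.size}" }
    val out = BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY)
    // A TYPE_BYTE_GRAY image created this way stores one byte per pixel, row by row.
    val pixels = (out.raster.dataBuffer as DataBufferByte).data
    for (i in logits.indices) {
        pixels[i] = if (logits[i] > threshold) 0xFF.toByte() else 0.toByte()
    }
    return out
}
```

For a stricter mask, pass a higher cutoff, for example `logitsToMaskImage(mask, size, size, logitThreshold(0.8))`.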