<config lang="json">
{
"name": "mi-core",
"type": "web-python",
"tags": [],
"flags": [],
"ui": "",
"version": "0.1.0",
"cover": "",
"description": "Connect to the bioengine server, and execute operations.",
"icon": "extension",
"inputs": null,
"outputs": null,
"api_version": "0.1.8",
"env": "",
"permissions": [],
"requirements": ["imageio", "xarray"],
"dependencies": []
}
</config>
<script lang="python">
import io
import imageio
import numpy as np
from xarray import DataArray
from imjoy import api
from imjoy_rpc.hypha import connect_to_server
try:
    import pyodide  # only importable when running in the browser via Pyodide
    is_pyodide = True
except ImportError:
    is_pyodide = False

async def fetch_file_content(url) -> bytes:
    """Fetch file content from url, return bytes.
    Compatible with both pyodide and native-python.
    Reference:
    https://github.com/imjoy-team/kaibu-utils/blob/ecc25337adb0c94e6345f09bba80aa0637ce9af0/kaibu_utils/__init__.py#L403-L425
    """
    await api.log("Fetch URL: " + url)
    if is_pyodide:
        # in the browser, use the JavaScript fetch API
        from js import fetch
        response = await fetch(url)
        bytes_ = await response.arrayBuffer()
        bytes_ = bytes_.to_py()
    else:
        # in native Python, fall back to the requests library
        import requests
        bytes_ = requests.get(url).content
    await api.log("Fetched bytes: " + str(len(bytes_)))
    return bytes_
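
# Hypothetical usage sketch (the URL is a placeholder, not part of the original code):
#     img_bytes = await fetch_file_content("https://example.com/cells.tif")
#     image = imageio.volread(io.BytesIO(img_bytes))
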
def is_channel_first(shape):
    if len(shape) == 5:  # with batch dimension
        shape = shape[1:]
    min_dim = np.argmin(list(shape))
    if min_dim == 0:  # easy case: channel first
        return True
    elif min_dim == len(shape) - 1:  # easy case: channel last
        return False
    else:  # hard case: can't figure it out, just guess channel first
        return True
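
# Examples of the heuristic (shapes are illustrative only):
#     is_channel_first((3, 512, 512))  -> True   (smallest axis comes first)
#     is_channel_first((512, 512, 3))  -> False  (smallest axis comes last)
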
def get_default_image_axes(shape, input_tensor_axes):
    ndim = len(shape)
    has_z_axis = "z" in input_tensor_axes
    if ndim == 2:
        axes = "yx"
    elif ndim == 3 and has_z_axis:
        axes = "zyx"
    elif ndim == 3:
        channel_first = is_channel_first(shape)
        axes = "cyx" if channel_first else "yxc"
    elif ndim == 4 and has_z_axis:
        channel_first = is_channel_first(shape)
        axes = "czyx" if channel_first else "zyxc"
    elif ndim == 4:
        channel_first = is_channel_first(shape)
        axes = "bcyx" if channel_first else "byxc"
    elif ndim == 5:
        channel_first = is_channel_first(shape)
        axes = "bczyx" if channel_first else "bzyxc"
    else:
        raise ValueError(f"Invalid number of image dimensions: {ndim}")
    return axes
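
# Illustrative example: a 3D array of shape (3, 256, 256) for a model whose input
# axes are "bcyx" (no "z"), so it is guessed to be channel-first 2D data:
#     get_default_image_axes((3, 256, 256), "bcyx")  -> "cyx"
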
def map_axes(
    input_array,
    input_axes,
    output_axes,
    # spatial axes: drop at the middle coordinate; other axes (channel or batch): drop at coordinate 0
    drop_function=lambda ax_name, ax_len: ax_len // 2 if ax_name in "zyx" else 0
):
    assert len(input_axes) == input_array.ndim, f"Number of axes {len(input_axes)} and dimension of input {input_array.ndim} don't match"
    shape = {ax_name: sh for ax_name, sh in zip(input_axes, input_array.shape)}
    output = DataArray(input_array, dims=tuple(input_axes))
    # drop axes not part of the output
    drop_axis_names = tuple(set(input_axes) - set(output_axes))
    drop_axes = {ax_name: drop_function(ax_name, shape[ax_name]) for ax_name in drop_axis_names}
    output = output[drop_axes]
    # expand axes missing from the input
    missing_axes = tuple(set(output_axes) - set(input_axes))
    output = output.expand_dims(dim=missing_axes)
    # transpose to the desired axis order
    output = output.transpose(*tuple(output_axes))
    # return a plain numpy array
    return output.values
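
# Illustrative example (shapes are made up): expand a plain "yx" image to the
# "bcyx" layout expected by many models:
#     map_axes(np.zeros((256, 256)), "yx", "bcyx").shape  -> (1, 1, 256, 256)
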
def transform_input(image: np.ndarray, image_axes: str, output_axes: str):
    """Transform the input image into an output tensor with output_axes
    Args:
        image: the input image
        image_axes: the axes of the input image as a simple string
        output_axes: the axes of the output tensor that will be returned
    """
    return map_axes(image, image_axes, output_axes)

class Plugin():
    async def setup(self):
        await api.log("Connector plugin is ready.")
        # connect to the BioEngine server and get the triton-client service
        server = await connect_to_server(
            {"name": "client", "server_url": "https://ai.imjoy.io", "method_timeout": 30}
        )
        self.triton = await server.get_service("triton-client")
        self.image = None
        self.vtk_viewer = await api.createWindow(
            src="https://oeway.github.io/itk-vtk-viewer/",
            fullscreen=False
        )

    def set_current_image(self, image):
        assert isinstance(image, np.ndarray)
        self.image = image

    async def get_current_image_shape(self):
        assert self.image is not None
        return self.image.shape

    async def show_image_vtk(self):
        if self.image is None:
            await api.alert("Please load an image first.")
            return
        image = self.image
        await self.vtk_viewer.setImage(image)

    async def bioengine_execute(self, model_id, inputs=None, return_rdf=False, weight_format=None):
        kwargs = {"model_id": model_id, "inputs": inputs, "return_rdf": return_rdf, "weight_format": weight_format}
        ret = await self.triton.execute(
            inputs=[kwargs],
            model_name="bioengine-model-runner",
            serialization="imjoy"
        )
        return ret["result"]

    async def get_model_rdf(self, model_id):
        ret = await self.bioengine_execute(model_id, return_rdf=True)
        return ret["rdf"]
    async def load_image_from_bytes(self, file_name, img_bytes):
        _file = io.BytesIO(img_bytes)
        _file.name = file_name
        if file_name.endswith(".tif") or file_name.endswith(".tiff"):
            image = imageio.volread(_file)
        else:
            image = imageio.imread(_file)
        await api.log(
            "Image loaded with shape: " + str(image.shape) +
            " and dtype: " + str(image.dtype)
        )
        self.set_current_image(image)

    async def load_image_from_url(self, url):
        # strip the query string and any trailing slash to recover the file name
        file_name = url.split("?")[0].rstrip('/').split("/")[-1]
        await api.log(file_name)
        bytes_ = await fetch_file_content(url)
        await self.load_image_from_bytes(file_name, bytes_)
    async def run_model(self, model_id, rdf, image_axes=None, weight_format=None):
        if self.image is None:
            await api.alert("Please load an image first.")
            return False
        img = self.image
        input_spec = rdf['inputs'][0]
        input_tensor_axes = input_spec['axes']
        await api.log(f"input_tensor_axes: {input_tensor_axes}")
        if image_axes is None:
            shape = img.shape
            image_axes = get_default_image_axes(shape, input_tensor_axes)
            await api.log(f"Image axes were not provided. They were automatically determined to be {image_axes}")
        else:
            await api.log(f"Image axes were provided as {image_axes}")
        assert len(image_axes) == img.ndim
        await api.log("Transforming input image...")
        img = transform_input(img, image_axes, input_tensor_axes)
        await api.log(f"Input image was transformed into shape {img.shape} to fit the model")
        await api.log("Data loaded, running model...")
        try:
            result = await self.bioengine_execute(
                model_id, inputs=[img], weight_format=weight_format)
        except Exception as exp:
            await api.alert(f"Failed to run the model ({model_id}) in the BioEngine, error: {exp}")
            return False
        if not result['success']:
            await api.alert(f"Failed to run the model ({model_id}) in the BioEngine, error: {result['error']}")
            return False
        output = result['outputs'][0]
        await api.showMessage(f"Model execution completed! Got output tensor of shape {output.shape}")
        output_tensor_axes = rdf['outputs'][0]['axes']
        # map the output tensor back onto the axes of the input image
        transformed_output = map_axes(output, output_tensor_axes, image_axes)
        self.set_current_image(transformed_output)
        return True
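
    # Hypothetical end-to-end sketch (URL and model id are placeholders, assuming
    # `plugin` is an instance of this class):
    #     await plugin.load_image_from_url("https://example.com/nuclei.tif")
    #     rdf = await plugin.get_model_rdf("<model-id>")
    #     success = await plugin.run_model("<model-id>", rdf)
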
    async def run(self, ctx):
        pass


api.export(Plugin())
</script>