Prithvi Geospatial MAE IO Processor
Source examples/offline_inference/prithvi_geospatial_mae_io_processor.py.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import os
import torch
from vllm import LLM
from vllm.pooling_params import PoolingParams
# This example shows how to perform an offline inference that generates
# multimodal data. In this specific case this example will take a geotiff
# image as input, process it using the multimodal data processor, and
# perform inference.
# Requirement - install the plugin from:
# https://github.com/christian-pinto/prithvi_io_processor_plugin
def main():
    """Run one offline Prithvi inference on a GeoTIFF and save the result.

    The prompt points at a remote GeoTIFF (``data_format="url"``) and asks
    for the result as base64 (``out_data_format="b64_json"``). The model is
    loaded with the ``prithvi_to_tiff_india`` IO processor plugin, the prompt
    is passed to ``LLM.encode``, and the ``data`` field of the first result
    is base64-decoded and written to ``offline_prediction.tiff`` in the
    current working directory.

    Running this requires the plugin referenced at the top of the file to be
    installed; it also downloads the model and the input image.
    """
    torch.set_default_dtype(torch.float16)

    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif"  # noqa: E501

    # Request description handed to the IO processor plugin: where the input
    # image lives, its format, and the encoding wanted for the output.
    img_prompt = dict(
        data=image_url,
        data_format="url",
        image_format="tiff",
        out_data_format="b64_json",
    )

    llm = LLM(
        model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
        skip_tokenizer_init=True,
        trust_remote_code=True,
        enforce_eager=True,
        # Limit the maximum number of parallel requests
        # to avoid the model going OOM.
        # The maximum number depends on the available GPU memory
        max_num_seqs=32,
        io_processor_plugin="prithvi_to_tiff_india",
    )

    pooling_params = PoolingParams(task="encode", softmax=False)
    pooler_output = llm.encode(
        img_prompt,
        pooling_params=pooling_params,
    )

    # Only a single prompt was submitted, so only the first result is used.
    output = pooler_output[0].outputs
    print(output)

    # The output was requested as "b64_json"; presumably `output.data` holds
    # the base64-encoded image, so decode it back to raw bytes before saving.
    decoded_data = base64.b64decode(output.data)

    file_path = os.path.join(os.getcwd(), "offline_prediction.tiff")
    with open(file_path, "wb") as f:
        f.write(decoded_data)

    print(f"Output file path: {file_path}")
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()